Integrate Bifrost LLM gateway, add test suite, implement memory pipeline
- Add Bifrost (maximhq/bifrost) as LLM gateway: all inference routes through bifrost:8080/v1 with retry logic and observability; VRAMManager keeps direct Ollama access for VRAM flush/prewarm operations - Switch medium model from qwen3:4b to qwen2.5:1.5b (direct call, no tools) via _DirectModel wrapper; complex keeps create_deep_agent with qwen3:8b - Implement out-of-agent memory pipeline: _retrieve_memories pre-fetches relevant context (injected into all tiers), _store_memory runs as background task after each reply writing to openmemory/Qdrant - Add tests/unit/ with 133 tests covering router, channels, vram_manager, agent helpers; move integration test to tests/integration/ - Add bifrost-config.json with GPU Ollama (qwen2.5:0.5b/1.5b, qwen3:4b/8b, gemma3:4b) and CPU Ollama providers - Integration test 28/29 pass (only grammy fails — no TELEGRAM_BOT_TOKEN) Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
80
tests/unit/conftest.py
Normal file
80
tests/unit/conftest.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""
|
||||
Stub out all third-party packages that Adolf's source modules import.
|
||||
This lets the unit tests run without a virtualenv or Docker environment.
|
||||
Stubs are installed into sys.modules before any test file is collected.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
||||
# ── helpers ────────────────────────────────────────────────────────────────────
|
||||
|
||||
def _mock(name: str) -> MagicMock:
|
||||
m = MagicMock(name=name)
|
||||
sys.modules[name] = m
|
||||
return m
|
||||
|
||||
|
||||
# ── pydantic: BaseModel must be a real class so `class Foo(BaseModel)` works ──
|
||||
|
||||
class _FakeBaseModel:
|
||||
model_fields: dict = {}
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
pass
|
||||
|
||||
def __init__(self, **data):
|
||||
for k, v in data.items():
|
||||
setattr(self, k, v)
|
||||
|
||||
|
||||
# ── pydantic: BaseModel must be a real class so `class Foo(BaseModel)` works ──

_pydantic = _mock("pydantic")
_pydantic.BaseModel = _FakeBaseModel

# ── httpx: used by channels.py, vram_manager.py, agent.py ────────────────────

_mock("httpx")

# ── fastapi ───────────────────────────────────────────────────────────────────

_fastapi = _mock("fastapi")
# Submodules must be stubbed individually — `import fastapi.responses`
# performs its own sys.modules lookup.
_mock("fastapi.responses")

# ── langchain stack ───────────────────────────────────────────────────────────

_mock("langchain_openai")

_lc_core = _mock("langchain_core")
_lc_msgs = _mock("langchain_core.messages")
_mock("langchain_core.tools")
|
||||
|
||||
# Provide real-ish message classes so router.py can instantiate them
|
||||
class _FakeMsg:
|
||||
def __init__(self, content=""):
|
||||
self.content = content
|
||||
|
||||
class SystemMessage(_FakeMsg):
|
||||
pass
|
||||
|
||||
class HumanMessage(_FakeMsg):
|
||||
pass
|
||||
|
||||
class AIMessage(_FakeMsg):
|
||||
def __init__(self, content="", tool_calls=None):
|
||||
super().__init__(content)
|
||||
self.tool_calls = tool_calls or []
|
||||
|
||||
# Attach the fake message classes where router.py expects to import them from.
_lc_msgs.SystemMessage = SystemMessage
_lc_msgs.HumanMessage = HumanMessage
_lc_msgs.AIMessage = AIMessage

_mock("langchain_mcp_adapters")
_mock("langchain_mcp_adapters.client")
_mock("langchain_community")
_mock("langchain_community.utilities")

# ── deepagents (agent_factory.py) ─────────────────────────────────────────────

_mock("deepagents")
|
||||
|
||||
161
tests/unit/test_agent_helpers.py
Normal file
161
tests/unit/test_agent_helpers.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""
|
||||
Unit tests for agent.py helper functions:
|
||||
- _strip_think(text)
|
||||
- _extract_final_text(result)
|
||||
|
||||
agent.py has heavy FastAPI/LangChain imports; conftest.py stubs them out so
|
||||
these pure functions can be imported and tested in isolation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
# conftest.py has already installed all stubs into sys.modules.
|
||||
# The FastAPI app is instantiated at module level in agent.py —
|
||||
# with the mocked fastapi, that just creates a MagicMock() object
|
||||
# and the route decorators are no-ops.
|
||||
from agent import _strip_think, _extract_final_text
|
||||
|
||||
|
||||
# ── _strip_think ───────────────────────────────────────────────────────────────
|
||||
|
||||
class TestStripThink:
    """Behavioural checks for _strip_think: remove <think>…</think> blocks and trim."""

    def test_removes_single_think_block(self):
        assert _strip_think("<think>internal reasoning</think>Final answer.") == "Final answer."

    def test_removes_multiline_think_block(self):
        raw = "<think>\nLine one.\nLine two.\n</think>\nResult here."
        assert _strip_think(raw) == "Result here."

    def test_no_think_block_unchanged(self):
        plain = "This is a plain answer with no think block."
        assert _strip_think(plain) == plain

    def test_removes_multiple_think_blocks(self):
        raw = "<think>step 1</think>middle<think>step 2</think>end"
        assert _strip_think(raw) == "middleend"

    def test_strips_surrounding_whitespace(self):
        assert _strip_think(" <think>stuff</think> answer ") == "answer"

    def test_empty_think_block(self):
        assert _strip_think("<think></think>Hello.") == "Hello."

    def test_empty_string(self):
        assert _strip_think("") == ""

    def test_only_think_block_returns_empty(self):
        assert _strip_think("<think>nothing useful</think>") == ""

    def test_think_block_with_nested_tags(self):
        # Inner tags must not confuse the stripper.
        assert _strip_think("<think>I should use <b>bold</b> here</think>Done.") == "Done."

    def test_preserves_markdown(self):
        raw = "<think>plan</think>## Report\n\n- Point one\n- Point two"
        assert _strip_think(raw) == "## Report\n\n- Point one\n- Point two"
|
||||
|
||||
|
||||
# ── _extract_final_text ────────────────────────────────────────────────────────
|
||||
|
||||
class TestExtractFinalText:
    """Tests for _extract_final_text(result): return the last non-empty AI
    message (think-tags stripped), else fall back to result["output"],
    else None.

    NOTE(review): the helper factories below deliberately *name* their local
    classes AIMessage / HumanMessage — presumably _extract_final_text
    dispatches on the class name rather than isinstance; confirm in agent.py.
    """

    def _ai_msg(self, content: str, tool_calls=None):
        """Create a minimal AIMessage-like object."""
        class AIMessage:
            pass
        m = AIMessage()
        m.content = content
        m.tool_calls = tool_calls or []
        return m

    def _human_msg(self, content: str):
        # Minimal HumanMessage-like object: only `.content`, no tool_calls.
        class HumanMessage:
            pass
        m = HumanMessage()
        m.content = content
        return m

    def test_returns_last_ai_message_content(self):
        result = {
            "messages": [
                self._human_msg("what is 2+2"),
                self._ai_msg("The answer is 4."),
            ]
        }
        assert _extract_final_text(result) == "The answer is 4."

    def test_returns_last_of_multiple_ai_messages(self):
        result = {
            "messages": [
                self._ai_msg("First response."),
                self._human_msg("follow-up"),
                self._ai_msg("Final response."),
            ]
        }
        assert _extract_final_text(result) == "Final response."

    def test_skips_empty_ai_messages(self):
        result = {
            "messages": [
                self._ai_msg("Real answer."),
                self._ai_msg(""),  # empty — should be skipped
            ]
        }
        assert _extract_final_text(result) == "Real answer."

    def test_strips_think_tags_from_ai_message(self):
        result = {
            "messages": [
                self._ai_msg("<think>reasoning here</think>Clean reply."),
            ]
        }
        assert _extract_final_text(result) == "Clean reply."

    def test_falls_back_to_output_field(self):
        # No AI messages at all → "output" key is used instead.
        result = {
            "messages": [],
            "output": "Fallback output.",
        }
        assert _extract_final_text(result) == "Fallback output."

    def test_strips_think_from_output_field(self):
        result = {
            "messages": [],
            "output": "<think>thoughts</think>Actual output.",
        }
        assert _extract_final_text(result) == "Actual output."

    def test_returns_none_when_no_content(self):
        result = {"messages": []}
        assert _extract_final_text(result) is None

    def test_returns_none_when_no_messages_and_no_output(self):
        result = {"messages": [], "output": ""}
        # output is falsy → returns None
        assert _extract_final_text(result) is None

    def test_skips_non_ai_messages(self):
        result = {
            "messages": [
                self._human_msg("user question"),
            ]
        }
        assert _extract_final_text(result) is None

    def test_handles_ai_message_with_tool_calls_but_no_content(self):
        """AIMessage that only has tool_calls (no content) should be skipped."""
        msg = self._ai_msg("", tool_calls=[{"name": "web_search", "args": {}}])
        result = {"messages": [msg]}
        assert _extract_final_text(result) is None

    def test_multiline_think_stripped_correctly(self):
        result = {
            "messages": [
                self._ai_msg("<think>\nLong\nreasoning\nblock\n</think>\n## Report\n\nSome content."),
            ]
        }
        assert _extract_final_text(result) == "## Report\n\nSome content."
|
||||
125
tests/unit/test_channels.py
Normal file
125
tests/unit/test_channels.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""Unit tests for channels.py — register, deliver, pending_replies queue."""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import channels
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def reset_channels_state():
    """Start and finish every test with empty module-level channel state."""
    def _wipe():
        channels._callbacks.clear()
        channels.pending_replies.clear()

    _wipe()
    yield
    _wipe()
|
||||
|
||||
|
||||
# ── register ───────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestRegister:
    """channels.register(name, cb) stores callbacks keyed by channel name."""

    def test_register_stores_callback(self):
        callback = AsyncMock()
        channels.register("test_channel", callback)
        assert channels._callbacks["test_channel"] is callback

    def test_register_overwrites_existing(self):
        first, second = AsyncMock(), AsyncMock()
        channels.register("ch", first)
        channels.register("ch", second)
        # Last registration wins.
        assert channels._callbacks["ch"] is second

    def test_register_multiple_channels(self):
        callbacks = {"a": AsyncMock(), "b": AsyncMock()}
        for name, cb in callbacks.items():
            channels.register(name, cb)
        assert channels._callbacks["a"] is callbacks["a"]
        assert channels._callbacks["b"] is callbacks["b"]
|
||||
|
||||
|
||||
# ── deliver ────────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestDeliver:
    """channels.deliver(session_id, channel, text): enqueue reply and notify
    the channel callback (when one is registered)."""

    async def test_deliver_enqueues_reply(self):
        channels.register("cli", AsyncMock())
        await channels.deliver("cli-alvis", "cli", "hello world")
        q = channels.pending_replies["cli-alvis"]
        assert not q.empty()
        assert await q.get() == "hello world"

    async def test_deliver_calls_channel_callback(self):
        cb = AsyncMock()
        channels.register("telegram", cb)
        await channels.deliver("tg-123", "telegram", "reply text")
        # Callback receives (session_id, text) and is awaited exactly once.
        cb.assert_awaited_once_with("tg-123", "reply text")

    async def test_deliver_unknown_channel_still_enqueues(self):
        """No registered callback for channel → reply still goes to the queue."""
        await channels.deliver("cli-bob", "nonexistent", "fallback reply")
        q = channels.pending_replies["cli-bob"]
        assert await q.get() == "fallback reply"

    async def test_deliver_unknown_channel_does_not_raise(self):
        """Missing callback must not raise an exception."""
        await channels.deliver("cli-x", "ghost_channel", "msg")

    async def test_deliver_creates_queue_if_absent(self):
        channels.register("cli", AsyncMock())
        assert "cli-new" not in channels.pending_replies
        await channels.deliver("cli-new", "cli", "hi")
        assert "cli-new" in channels.pending_replies

    async def test_deliver_reuses_existing_queue(self):
        """Second deliver to the same session appends to the same queue."""
        channels.register("cli", AsyncMock())
        await channels.deliver("cli-alvis", "cli", "first")
        await channels.deliver("cli-alvis", "cli", "second")
        q = channels.pending_replies["cli-alvis"]
        # FIFO ordering must be preserved.
        assert await q.get() == "first"
        assert await q.get() == "second"

    async def test_deliver_telegram_sends_to_callback(self):
        # Hand-rolled async callback that records (session_id, text) pairs.
        sent = []

        async def fake_tg(session_id, text):
            sent.append((session_id, text))

        channels.register("telegram", fake_tg)
        await channels.deliver("tg-999", "telegram", "test message")
        assert sent == [("tg-999", "test message")]
|
||||
|
||||
|
||||
# ── register_defaults ──────────────────────────────────────────────────────────
|
||||
|
||||
class TestRegisterDefaults:
    """register_defaults() wires up the built-in 'telegram' and 'cli' channels."""

    def test_registers_telegram_and_cli(self):
        channels.register_defaults()
        assert "telegram" in channels._callbacks
        assert "cli" in channels._callbacks

    async def test_cli_callback_is_noop(self):
        """CLI send callback does nothing (replies are handled via SSE queue)."""
        channels.register_defaults()
        cb = channels._callbacks["cli"]
        # Should not raise and should return None
        result = await cb("cli-alvis", "some reply")
        assert result is None

    async def test_telegram_callback_chunks_long_messages(self):
        """Telegram callback splits messages > 4000 chars into chunks."""
        channels.register_defaults()
        cb = channels._callbacks["telegram"]
        long_text = "x" * 9000  # > 4000 chars → should produce 3 chunks
        # Patch httpx.AsyncClient so no real network traffic happens; the
        # patched constructor returns an async-context-manager mock.
        with patch("channels.httpx.AsyncClient") as mock_client_cls:
            mock_client = AsyncMock()
            mock_client.__aenter__ = AsyncMock(return_value=mock_client)
            mock_client.__aexit__ = AsyncMock(return_value=False)
            mock_client.post = AsyncMock()
            mock_client_cls.return_value = mock_client
            await cb("tg-123", long_text)
            # 9000 chars / 4000 per chunk = 3 POST calls
            assert mock_client.post.await_count == 3
|
||||
200
tests/unit/test_router.py
Normal file
200
tests/unit/test_router.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""Unit tests for router.py — Router, _parse_tier, _format_history, _LIGHT_PATTERNS."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from router import Router, _parse_tier, _format_history, _LIGHT_PATTERNS
|
||||
|
||||
|
||||
# ── _LIGHT_PATTERNS regex ──────────────────────────────────────────────────────
|
||||
|
||||
class TestLightPatterns:
    """_LIGHT_PATTERNS is the regex fast-path that routes chit-chat straight
    to the 'light' tier without an LLM classification call."""

    @pytest.mark.parametrize("text", [
        "hi", "Hi", "HI",
        "hello", "hey", "yo", "sup",
        "good morning", "good evening", "good night", "good afternoon",
        "bye", "goodbye", "see you", "cya", "later", "ttyl",
        "thanks", "thank you", "thx", "ty",
        "ok", "okay", "k", "cool", "great", "awesome", "perfect",
        "sounds good", "got it", "nice", "sure",
        "how are you", "how are you?", "how are you doing today?",
        "what's up",
        "what day comes after Monday?",
        "what day follows Friday?",
        "what comes after summer?",
        "what does NASA stand for?",
        "what does AI stand for?",
        # with trailing punctuation
        "hi!", "hello.", "thanks!",
    ])
    def test_matches(self, text):
        # Greetings, acknowledgements and trivia-style questions must match.
        assert _LIGHT_PATTERNS.match(text.strip()), f"Expected light match for: {text!r}"

    @pytest.mark.parametrize("text", [
        "what is the capital of France",
        "tell me about bitcoin",
        "what is 2+2",
        "write me a poem",
        "search for news about the election",
        "what did we talk about last time",
        "what is my name",
        "/think compare these frameworks",
        "how do I install Python",
        "explain machine learning",
        "",  # empty string doesn't match the pattern
    ])
    def test_no_match(self, text):
        # Substantive questions must fall through to LLM classification.
        assert not _LIGHT_PATTERNS.match(text.strip()), f"Expected NO light match for: {text!r}"
|
||||
|
||||
|
||||
# ── _parse_tier ────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestParseTier:
    """_parse_tier maps raw classifier output onto 'light'/'medium'/'complex',
    defaulting to 'medium' when the output is unrecognised."""

    @pytest.mark.parametrize("raw,expected", [
        ("light", "light"),
        ("Light", "light"),
        ("LIGHT\n", "light"),
        ("medium", "medium"),
        ("Medium.", "medium"),
        ("complex", "complex"),
        ("Complex!", "complex"),
        # descriptive words → light
        ("simplefact", "light"),
        ("trivial question", "light"),
        ("basic", "light"),
        ("easy answer", "light"),
        ("general knowledge", "light"),
        # unknown → medium
        ("unknown_category", "medium"),
        ("", "medium"),
        ("I don't know", "medium"),
        # complex only if 'complex' appears in first 60 chars
        ("this is a complex query requiring search", "complex"),
        # _parse_tier checks "complex" before "medium", so complex wins even if medium appears first
        ("medium complexity, not complex", "complex"),
    ])
    def test_parse_tier(self, raw, expected):
        assert _parse_tier(raw) == expected
|
||||
|
||||
|
||||
# ── _format_history ────────────────────────────────────────────────────────────
|
||||
|
||||
class TestFormatHistory:
    """_format_history renders a list of {"role", "content"} dicts into the
    'role: content' transcript text used in prompts."""

    def test_empty(self):
        # Empty history renders the "(none)" placeholder.
        assert _format_history([]) == "(none)"

    def test_single_user_message(self):
        history = [{"role": "user", "content": "hello there"}]
        result = _format_history(history)
        assert "user: hello there" in result

    def test_multiple_turns(self):
        history = [
            {"role": "user", "content": "What is Python?"},
            {"role": "assistant", "content": "Python is a programming language."},
        ]
        result = _format_history(history)
        assert "user: What is Python?" in result
        assert "assistant: Python is a programming language." in result

    def test_truncates_long_content(self):
        long_content = "x" * 300
        history = [{"role": "user", "content": long_content}]
        result = _format_history(history)
        # content is truncated to 200 chars in _format_history
        assert len(result) < 250

    def test_missing_keys_handled(self):
        # Should not raise — uses .get() with defaults
        history = [{"role": "user"}]  # no content key
        result = _format_history(history)
        assert "user:" in result
|
||||
|
||||
|
||||
# ── Router.route() ─────────────────────────────────────────────────────────────
|
||||
|
||||
class TestRouterRoute:
    """End-to-end behaviour of Router.route(question, history, force_complex):
    regex fast-path, LLM classification, fallbacks, and the force flag."""

    def _make_router(self, classify_response: str, reply_response: str = "Sure!") -> Router:
        """Return a Router with a mock model that returns given classification and reply."""
        model = MagicMock()
        classify_msg = MagicMock()
        classify_msg.content = classify_response
        reply_msg = MagicMock()
        reply_msg.content = reply_response
        # First ainvoke call → classification; second → reply
        model.ainvoke = AsyncMock(side_effect=[classify_msg, reply_msg])
        return Router(model=model)

    async def test_force_complex_bypasses_classification(self):
        router = self._make_router("medium")
        tier, reply = await router.route("some question", [], force_complex=True)
        assert tier == "complex"
        assert reply is None
        # Model should NOT have been called
        router.model.ainvoke.assert_not_called()

    async def test_regex_light_skips_llm_classification(self):
        # Regex match bypasses classification entirely; the only ainvoke call is the reply.
        model = MagicMock()
        reply_msg = MagicMock()
        reply_msg.content = "I'm doing great!"
        model.ainvoke = AsyncMock(return_value=reply_msg)
        router = Router(model=model)
        tier, reply = await router.route("how are you", [], force_complex=False)
        assert tier == "light"
        assert reply == "I'm doing great!"
        # Exactly one model call — no classification step
        assert router.model.ainvoke.call_count == 1

    async def test_llm_classifies_medium(self):
        router = self._make_router("medium")
        tier, reply = await router.route("what is the bitcoin price?", [], force_complex=False)
        assert tier == "medium"
        assert reply is None

    async def test_llm_classifies_light_generates_reply(self):
        router = self._make_router("light", "Paris is the capital of France.")
        tier, reply = await router.route("what is the capital of France?", [], force_complex=False)
        assert tier == "light"
        assert reply == "Paris is the capital of France."

    async def test_llm_classifies_complex_downgraded_to_medium(self):
        # Without /think prefix, complex classification → downgraded to medium
        router = self._make_router("complex")
        tier, reply = await router.route("compare React and Vue", [], force_complex=False)
        assert tier == "medium"
        assert reply is None

    async def test_llm_error_falls_back_to_medium(self):
        # Classification call raising must degrade gracefully, not propagate.
        model = MagicMock()
        model.ainvoke = AsyncMock(side_effect=Exception("connection error"))
        router = Router(model=model)
        tier, reply = await router.route("some question", [], force_complex=False)
        assert tier == "medium"
        assert reply is None

    async def test_light_reply_empty_falls_back_to_medium(self):
        """If the light reply comes back empty, router returns medium instead."""
        router = self._make_router("light", "")  # empty reply
        tier, reply = await router.route("what is 2+2", [], force_complex=False)
        assert tier == "medium"
        assert reply is None

    async def test_strips_think_tags_from_classification(self):
        """Router strips <think>...</think> from model output before parsing tier."""
        model = MagicMock()
        classify_msg = MagicMock()
        classify_msg.content = "<think>Hmm let me think...</think>medium"
        reply_msg = MagicMock()
        reply_msg.content = "I'm fine!"
        model.ainvoke = AsyncMock(side_effect=[classify_msg, reply_msg])
        router = Router(model=model)
        tier, _ = await router.route("what is the news?", [], force_complex=False)
        assert tier == "medium"

    async def test_think_prefix_forces_complex(self):
        """/think prefix is already stripped by agent.py; force_complex=True is passed."""
        router = self._make_router("medium")
        tier, reply = await router.route("analyse this", [], force_complex=True)
        assert tier == "complex"
        assert reply is None
|
||||
164
tests/unit/test_vram_manager.py
Normal file
164
tests/unit/test_vram_manager.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""Unit tests for vram_manager.py — VRAMManager flush/poll/prewarm logic."""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from vram_manager import VRAMManager
|
||||
|
||||
|
||||
# Direct Ollama endpoint every VRAMManager under test is pointed at.
BASE_URL = "http://localhost:11434"


def _make_manager() -> VRAMManager:
    """Return a fresh VRAMManager wired to the test base URL."""
    return VRAMManager(base_url=BASE_URL)
|
||||
|
||||
|
||||
def _mock_client(get_response=None, post_response=None):
|
||||
"""Return a context-manager mock for httpx.AsyncClient."""
|
||||
client = AsyncMock()
|
||||
client.__aenter__ = AsyncMock(return_value=client)
|
||||
client.__aexit__ = AsyncMock(return_value=False)
|
||||
if get_response is not None:
|
||||
client.get = AsyncMock(return_value=get_response)
|
||||
if post_response is not None:
|
||||
client.post = AsyncMock(return_value=post_response)
|
||||
return client
|
||||
|
||||
|
||||
# ── _flush ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestFlush:
    """VRAMManager._flush(model) must POST {"model": ..., "keep_alive": 0}
    to {base_url}/api/generate and swallow any transport error."""

    @staticmethod
    def _posted_json(client):
        """Return the JSON body of the single recorded POST.

        Accepts both `post(url, json=...)` and a positional body. Replaces
        the original `kwargs.get("json") or call_args[1].get("json") or ...`
        chain, whose first two terms were the identical lookup and whose
        `or`-fallthrough would misread a falsy `{}` body as missing and then
        raise IndexError on the positional lookup.
        """
        args, kwargs = client.post.call_args
        if "json" in kwargs:
            return kwargs["json"]
        return args[1]

    async def test_sends_keep_alive_zero(self):
        client = _mock_client(post_response=MagicMock())
        with patch("vram_manager.httpx.AsyncClient", return_value=client):
            mgr = _make_manager()
            await mgr._flush("qwen3:4b")
        client.post.assert_awaited_once()
        body = self._posted_json(client)
        assert body["model"] == "qwen3:4b"
        # keep_alive=0 is Ollama's "unload immediately" signal.
        assert body["keep_alive"] == 0

    async def test_posts_to_correct_endpoint(self):
        client = _mock_client(post_response=MagicMock())
        with patch("vram_manager.httpx.AsyncClient", return_value=client):
            mgr = _make_manager()
            await mgr._flush("qwen3:8b")
        # First positional argument of post() is the URL.
        url = client.post.call_args[0][0]
        assert url == f"{BASE_URL}/api/generate"

    async def test_ignores_exceptions_silently(self):
        # _flush is best-effort: transport failures must not propagate.
        client = AsyncMock()
        client.__aenter__ = AsyncMock(return_value=client)
        client.__aexit__ = AsyncMock(return_value=False)
        client.post = AsyncMock(side_effect=Exception("connection refused"))
        with patch("vram_manager.httpx.AsyncClient", return_value=client):
            mgr = _make_manager()
            # Should not raise
            await mgr._flush("qwen3:4b")
|
||||
|
||||
|
||||
# ── _prewarm ───────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestPrewarm:
    """VRAMManager._prewarm(model) must POST keep_alive=300 (keep loaded
    for 5 minutes) and swallow any transport error."""

    @staticmethod
    def _posted_json(client):
        """Return the JSON body of the single recorded POST.

        Handles both `post(url, json=...)` and a positional body, replacing
        the original redundant double kwargs lookup (`await_args` and
        `call_args` are the same record for an awaited AsyncMock).
        """
        args, kwargs = client.post.call_args
        if "json" in kwargs:
            return kwargs["json"]
        return args[1]

    async def test_sends_keep_alive_300(self):
        client = _mock_client(post_response=MagicMock())
        with patch("vram_manager.httpx.AsyncClient", return_value=client):
            mgr = _make_manager()
            await mgr._prewarm("qwen3:4b")
        body = self._posted_json(client)
        assert body["keep_alive"] == 300
        assert body["model"] == "qwen3:4b"

    async def test_ignores_exceptions_silently(self):
        # _prewarm is best-effort: transport failures must not propagate.
        client = AsyncMock()
        client.__aenter__ = AsyncMock(return_value=client)
        client.__aexit__ = AsyncMock(return_value=False)
        client.post = AsyncMock(side_effect=Exception("timeout"))
        with patch("vram_manager.httpx.AsyncClient", return_value=client):
            mgr = _make_manager()
            await mgr._prewarm("qwen3:4b")
|
||||
|
||||
|
||||
# ── _poll_evicted ──────────────────────────────────────────────────────────────
|
||||
|
||||
class TestPollEvicted:
    """_poll_evicted(models, timeout) polls Ollama's /api/ps until the named
    models are no longer resident, returning True on eviction and False on
    timeout."""

    async def test_returns_true_when_models_absent(self):
        resp = MagicMock()
        resp.json.return_value = {"models": [{"name": "some_other_model"}]}
        client = _mock_client(get_response=resp)
        with patch("vram_manager.httpx.AsyncClient", return_value=client):
            mgr = _make_manager()
            result = await mgr._poll_evicted(["qwen3:4b", "qwen2.5:1.5b"], timeout=5)
        assert result is True

    async def test_returns_false_on_timeout_when_model_still_loaded(self):
        # Model never leaves /api/ps → short timeout must yield False.
        resp = MagicMock()
        resp.json.return_value = {"models": [{"name": "qwen3:4b"}]}
        client = _mock_client(get_response=resp)
        with patch("vram_manager.httpx.AsyncClient", return_value=client):
            mgr = _make_manager()
            result = await mgr._poll_evicted(["qwen3:4b"], timeout=0.1)
        assert result is False

    async def test_returns_true_immediately_if_already_empty(self):
        resp = MagicMock()
        resp.json.return_value = {"models": []}
        client = _mock_client(get_response=resp)
        with patch("vram_manager.httpx.AsyncClient", return_value=client):
            mgr = _make_manager()
            result = await mgr._poll_evicted(["qwen3:4b"], timeout=5)
        assert result is True

    async def test_handles_poll_error_and_continues(self):
        """If /api/ps errors, polling continues until timeout."""
        client = AsyncMock()
        client.__aenter__ = AsyncMock(return_value=client)
        client.__aexit__ = AsyncMock(return_value=False)
        client.get = AsyncMock(side_effect=Exception("network error"))
        with patch("vram_manager.httpx.AsyncClient", return_value=client):
            mgr = _make_manager()
            result = await mgr._poll_evicted(["qwen3:4b"], timeout=0.2)
        assert result is False
|
||||
|
||||
|
||||
# ── enter_complex_mode / exit_complex_mode ─────────────────────────────────────
|
||||
|
||||
class TestComplexMode:
    """enter_complex_mode/exit_complex_mode swap VRAM occupancy between the
    medium-tier models (qwen3:4b, qwen2.5:1.5b) and the complex-tier qwen3:8b.
    Internal _flush/_poll_evicted/_prewarm are replaced with AsyncMocks so
    only the orchestration logic is under test."""

    async def test_enter_complex_mode_returns_true_on_success(self):
        mgr = _make_manager()
        mgr._flush = AsyncMock()
        mgr._poll_evicted = AsyncMock(return_value=True)
        result = await mgr.enter_complex_mode()
        assert result is True

    async def test_enter_complex_mode_flushes_medium_models(self):
        mgr = _make_manager()
        mgr._flush = AsyncMock()
        mgr._poll_evicted = AsyncMock(return_value=True)
        await mgr.enter_complex_mode()
        # Collect every model name passed to _flush across all calls.
        flushed = {call.args[0] for call in mgr._flush.call_args_list}
        assert "qwen3:4b" in flushed
        assert "qwen2.5:1.5b" in flushed

    async def test_enter_complex_mode_returns_false_on_eviction_timeout(self):
        # Eviction never completing must surface as a False return, not a hang.
        mgr = _make_manager()
        mgr._flush = AsyncMock()
        mgr._poll_evicted = AsyncMock(return_value=False)
        result = await mgr.enter_complex_mode()
        assert result is False

    async def test_exit_complex_mode_flushes_complex_and_prewarms_medium(self):
        mgr = _make_manager()
        mgr._flush = AsyncMock()
        mgr._prewarm = AsyncMock()
        await mgr.exit_complex_mode()
        # Must flush 8b
        flushed = {call.args[0] for call in mgr._flush.call_args_list}
        assert "qwen3:8b" in flushed
        # Must prewarm medium models
        prewarmed = {call.args[0] for call in mgr._prewarm.call_args_list}
        assert "qwen3:4b" in prewarmed
        assert "qwen2.5:1.5b" in prewarmed
|
||||
Reference in New Issue
Block a user