Integrate Bifrost LLM gateway, add test suite, implement memory pipeline

- Add Bifrost (maximhq/bifrost) as LLM gateway: all inference routes through
  bifrost:8080/v1 with retry logic and observability; VRAMManager keeps direct
  Ollama access for VRAM flush/prewarm operations
- Switch medium model from qwen3:4b to qwen2.5:1.5b (direct call, no tools)
  via _DirectModel wrapper; complex keeps create_deep_agent with qwen3:8b
- Implement out-of-agent memory pipeline: _retrieve_memories pre-fetches
  relevant context (injected into all tiers), _store_memory runs as background
  task after each reply writing to openmemory/Qdrant
- Add tests/unit/ with 133 tests covering router, channels, vram_manager,
  agent helpers; move integration test to tests/integration/
- Add bifrost-config.json with GPU Ollama (qwen2.5:0.5b/1.5b, qwen3:4b/8b,
  gemma3:4b) and CPU Ollama providers
- Integration test 28/29 pass (only grammy fails — no TELEGRAM_BOT_TOKEN)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Alvis
2026-03-12 13:50:12 +00:00
parent ec45d255f0
commit f9618a9bbf
16 changed files with 1195 additions and 36 deletions

2
.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
__pycache__/
*.pyc

133
CLAUDE.md Normal file
View File

@@ -0,0 +1,133 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Commands
**Start all services:**
```bash
docker compose up --build
```
**Interactive CLI (requires gateway running):**
```bash
python3 cli.py [--url http://localhost:8000] [--session cli-alvis] [--timeout 400]
```
**Run integration tests:**
```bash
python3 test_pipeline.py [--chat-id CHAT_ID]
# Selective sections:
python3 test_pipeline.py --bench-only # routing + memory benchmarks only (sections 1013)
python3 test_pipeline.py --easy-only # light-tier routing benchmark
python3 test_pipeline.py --medium-only # medium-tier routing benchmark
python3 test_pipeline.py --hard-only # complex-tier + VRAM flush benchmark
python3 test_pipeline.py --memory-only # memory store/recall/dedup benchmark
python3 test_pipeline.py --no-bench # service health + single name store/recall only
```
## Architecture
Adolf is a multi-channel personal assistant. All LLM inference is routed through **Bifrost**, an open-source Go-based LLM gateway that adds retry logic, failover, and observability in front of Ollama.
### Request flow
```
Channel adapter → POST /message {text, session_id, channel, user_id}
→ 202 Accepted (immediate)
→ background: run_agent_task()
→ router.route() → tier decision (light/medium/complex)
→ invoke agent for tier via Bifrost
deepagents:8000 → bifrost:8080/v1 → ollama:11436
→ channels.deliver(session_id, channel, reply)
→ pending_replies[session_id] queue (SSE)
→ channel-specific callback (Telegram POST, CLI no-op)
CLI/wiki polling → GET /reply/{session_id} (SSE, blocks until reply)
```
### Bifrost integration
Bifrost (`bifrost-config.json`) is configured with the `ollama` provider pointing to the GPU Ollama instance on host port 11436. It exposes an OpenAI-compatible API at `http://bifrost:8080/v1`.
`agent.py` uses `langchain_openai.ChatOpenAI` with `base_url=BIFROST_URL`. Model names use the `provider/model` format that Bifrost expects: `ollama/qwen3:4b`, `ollama/qwen3:8b`, `ollama/qwen2.5:1.5b`. Bifrost strips the `ollama/` prefix before forwarding to Ollama.
`VRAMManager` bypasses Bifrost and talks directly to Ollama via `OLLAMA_BASE_URL` (host:11436) for flush/poll/prewarm operations — Bifrost cannot manage GPU VRAM.
### Three-tier routing (`router.py`, `agent.py`)
| Tier | Model (env var) | Trigger |
|------|-----------------|---------|
| light | `qwen2.5:1.5b` (`DEEPAGENTS_ROUTER_MODEL`) | Regex pre-match or LLM classifies "light" — answered by router model directly, no agent invoked |
| medium | `qwen2.5:1.5b` (`DEEPAGENTS_MODEL`) | Default for tool-requiring queries |
| complex | `qwen3:8b` (`DEEPAGENTS_COMPLEX_MODEL`) | `/think ` prefix only |
The router does regex pre-classification first, then LLM classification. Complex tier is blocked unless the message starts with `/think ` — any LLM classification of "complex" is downgraded to medium.
A global `asyncio.Semaphore(1)` (`_reply_semaphore`) serializes all LLM inference — one request at a time.
### Thinking mode
qwen3 models produce chain-of-thought `<think>...</think>` tokens via Ollama's OpenAI-compatible endpoint. Adolf controls this via system prompt prefixes:
- **Medium** (`qwen2.5:1.5b`): no thinking mode in this model; fast ~3s calls
- **Complex** (`qwen3:8b`): no prefix — thinking enabled by default, used for deep research
- **Router** (`qwen2.5:1.5b`): no thinking support in this model
`_strip_think()` in `agent.py` and `router.py` strips any `<think>` blocks from model output before returning to users.
### VRAM management (`vram_manager.py`)
Hardware: GTX 1070 (8 GB). Before running the 8b model, medium models are flushed via Ollama `keep_alive=0`, then `/api/ps` is polled (15s timeout) to confirm eviction. On timeout, falls back to medium tier. After complex reply, 8b is flushed and medium models are pre-warmed as a background task.
### Channel adapters (`channels.py`)
- **Telegram**: Grammy Node.js bot (`grammy/bot.mjs`) long-polls Telegram → `POST /message`; replies delivered via `POST grammy:3001/send`
- **CLI**: `cli.py` posts to `/message`, then blocks on `GET /reply/{session_id}` SSE
Session IDs: `tg-<chat_id>` for Telegram, `cli-<username>` for CLI. Conversation history: 5-turn buffer per session.
### Services (`docker-compose.yml`)
| Service | Port | Role |
|---------|------|------|
| `bifrost` | 8080 | LLM gateway — retries, failover, observability; config from `bifrost-config.json` |
| `deepagents` | 8000 | FastAPI gateway + agent core |
| `openmemory` | 8765 | FastMCP server + mem0 memory tools (Qdrant-backed) |
| `grammy` | 3001 | grammY Telegram bot + `/send` HTTP endpoint |
| `crawl4ai` | 11235 | JS-rendered page fetching |
External (from `openai/` stack, host ports):
- Ollama GPU: `11436` — all reply inference (via Bifrost) + VRAM management (direct)
- Ollama CPU: `11435` — nomic-embed-text embeddings for openmemory
- Qdrant: `6333` — vector store for memories
- SearXNG: `11437` — web search
### Bifrost config (`bifrost-config.json`)
The file is mounted into the bifrost container at `/app/data/config.json`. It declares one Ollama provider key pointing to `host.docker.internal:11436` with 2 retries and 300s timeout. To add fallback providers or adjust weights, edit this file and restart the bifrost container.
### Agent tools
`web_search`: SearXNG search + Crawl4AI auto-fetch of top 2 results → combined snippet + full page content.
`fetch_url`: Crawl4AI single-URL fetch.
MCP tools from openmemory (`add_memory`, `search_memory`, `get_all_memories`) are **excluded** from agent tools — memory management is handled outside the agent loop.
### Medium vs Complex agent
| Agent | Builder | Speed | Use case |
|-------|---------|-------|----------|
| medium | `_DirectModel` (single LLM call, no tools) | ~3s | General questions, conversation |
| complex | `create_deep_agent` (deepagents) | Slow — multi-step planner | Deep research via `/think` prefix |
### Key files
- `agent.py` — FastAPI app, lifespan wiring, `run_agent_task()`, all endpoints
- `bifrost-config.json` — Bifrost provider config (Ollama GPU, retries, timeouts)
- `channels.py` — channel registry and `deliver()` dispatcher
- `router.py``Router` class: regex + LLM classification, light-tier reply generation
- `vram_manager.py``VRAMManager`: flush/poll/prewarm Ollama VRAM directly
- `agent_factory.py``build_medium_agent` / `build_complex_agent` via `create_deep_agent()`
- `openmemory/server.py` — FastMCP + mem0 config with custom extraction/dedup prompts
- `wiki_research.py` — batch research pipeline using `/message` + SSE polling
- `grammy/bot.mjs` — Telegram long-poll + HTTP `/send` endpoint

View File

@@ -2,7 +2,7 @@ FROM python:3.12-slim
WORKDIR /app
RUN pip install --no-cache-dir deepagents langchain-ollama langgraph \
RUN pip install --no-cache-dir deepagents langchain-openai langgraph \
fastapi uvicorn langchain-mcp-adapters langchain-community httpx
COPY agent.py channels.py vram_manager.py router.py agent_factory.py hello_world.py .

120
agent.py
View File

@@ -10,7 +10,7 @@ from pydantic import BaseModel
import re as _re
import httpx as _httpx
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_community.utilities import SearxSearchWrapper
from langchain_core.tools import Tool
@@ -20,8 +20,12 @@ from router import Router
from agent_factory import build_medium_agent, build_complex_agent
import channels
# Bifrost gateway — all LLM inference goes through here
BIFROST_URL = os.getenv("BIFROST_URL", "http://bifrost:8080/v1")
# Direct Ollama URL — used only by VRAMManager for flush/prewarm/poll
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
ROUTER_MODEL = os.getenv("DEEPAGENTS_ROUTER_MODEL", "qwen2.5:0.5b")
ROUTER_MODEL = os.getenv("DEEPAGENTS_ROUTER_MODEL", "qwen2.5:1.5b")
MEDIUM_MODEL = os.getenv("DEEPAGENTS_MODEL", "qwen3:4b")
COMPLEX_MODEL = os.getenv("DEEPAGENTS_COMPLEX_MODEL", "qwen3:8b")
SEARXNG_URL = os.getenv("SEARXNG_URL", "http://host.docker.internal:11437")
@@ -31,10 +35,12 @@ CRAWL4AI_URL = os.getenv("CRAWL4AI_URL", "http://crawl4ai:11235")
MAX_HISTORY_TURNS = 5
_conversation_buffers: dict[str, list] = {}
# /no_think at the start of the system prompt disables qwen3 chain-of-thought.
# create_deep_agent prepends our system_prompt before BASE_AGENT_PROMPT, so
# /no_think lands at position 0 and is respected by qwen3 models via Ollama.
MEDIUM_SYSTEM_PROMPT = (
"You are a helpful AI assistant. "
"Use web_search for questions about current events or facts you don't know. "
"Reply concisely."
"You are a helpful AI assistant. Reply concisely. "
"If asked to remember a fact or name, simply confirm: 'Got it, I'll remember that.'"
)
COMPLEX_SYSTEM_PROMPT = (
@@ -54,6 +60,8 @@ complex_agent = None
router: Router = None
vram_manager: VRAMManager = None
mcp_client = None
_memory_add_tool = None
_memory_search_tool = None
# GPU mutex: one LLM inference at a time
_reply_semaphore = asyncio.Semaphore(1)
@@ -61,21 +69,34 @@ _reply_semaphore = asyncio.Semaphore(1)
@asynccontextmanager
async def lifespan(app: FastAPI):
global medium_agent, complex_agent, router, vram_manager, mcp_client
global medium_agent, complex_agent, router, vram_manager, mcp_client, \
_memory_add_tool, _memory_search_tool
# Register channel adapters
channels.register_defaults()
# Three model instances
router_model = ChatOllama(
model=ROUTER_MODEL, base_url=OLLAMA_BASE_URL, think=False, num_ctx=4096,
# All three models route through Bifrost → Ollama GPU.
# Bifrost adds retry logic, observability, and failover.
# Model names use provider/model format: Bifrost strips the "ollama/" prefix
# before forwarding to Ollama's /v1/chat/completions endpoint.
router_model = ChatOpenAI(
model=f"ollama/{ROUTER_MODEL}",
base_url=BIFROST_URL,
api_key="dummy",
temperature=0,
timeout=30,
)
medium_model = ChatOllama(
model=MEDIUM_MODEL, base_url=OLLAMA_BASE_URL, think=False, num_ctx=8192
medium_model = ChatOpenAI(
model=f"ollama/{MEDIUM_MODEL}",
base_url=BIFROST_URL,
api_key="dummy",
timeout=180,
)
complex_model = ChatOllama(
model=COMPLEX_MODEL, base_url=OLLAMA_BASE_URL, think=True, num_ctx=16384
complex_model = ChatOpenAI(
model=f"ollama/{COMPLEX_MODEL}",
base_url=BIFROST_URL,
api_key="dummy",
timeout=600,
)
vram_manager = VRAMManager(base_url=OLLAMA_BASE_URL)
@@ -97,6 +118,13 @@ async def lifespan(app: FastAPI):
agent_tools = [t for t in mcp_tools if t.name not in ("add_memory", "search_memory", "get_all_memories")]
# Expose memory tools directly so run_agent_task can call them outside the agent loop
for t in mcp_tools:
if t.name == "add_memory":
_memory_add_tool = t
elif t.name == "search_memory":
_memory_search_tool = t
searx = SearxSearchWrapper(searx_host=SEARXNG_URL)
def _crawl4ai_fetch(url: str) -> str:
@@ -187,7 +215,8 @@ async def lifespan(app: FastAPI):
)
print(
f"[agent] three-tier: router={ROUTER_MODEL} | medium={MEDIUM_MODEL} | complex={COMPLEX_MODEL}",
f"[agent] bifrost={BIFROST_URL} | router=ollama/{ROUTER_MODEL} | "
f"medium=ollama/{MEDIUM_MODEL} | complex=ollama/{COMPLEX_MODEL}",
flush=True,
)
print(f"[agent] agent tools: {[t.name for t in agent_tools]}", flush=True)
@@ -222,13 +251,19 @@ class ChatRequest(BaseModel):
# ── helpers ────────────────────────────────────────────────────────────────────
def _strip_think(text: str) -> str:
"""Strip qwen3 chain-of-thought blocks that appear inline in content
when using Ollama's OpenAI-compatible endpoint (/v1/chat/completions)."""
return _re.sub(r"<think>.*?</think>", "", text, flags=_re.DOTALL).strip()
def _extract_final_text(result) -> str | None:
msgs = result.get("messages", [])
for m in reversed(msgs):
if type(m).__name__ == "AIMessage" and getattr(m, "content", ""):
return m.content
return _strip_think(m.content)
if isinstance(result, dict) and result.get("output"):
return result["output"]
return _strip_think(result["output"])
return None
@@ -244,6 +279,34 @@ def _log_messages(result):
print(f"[agent] {role}{tc['name']}({tc['args']})", flush=True)
# ── memory helpers ─────────────────────────────────────────────────────────────
async def _store_memory(session_id: str, user_msg: str, assistant_reply: str) -> None:
"""Store a conversation turn in openmemory (runs as a background task)."""
if _memory_add_tool is None:
return
t0 = time.monotonic()
try:
text = f"User: {user_msg}\nAssistant: {assistant_reply}"
await _memory_add_tool.ainvoke({"text": text, "user_id": session_id})
print(f"[memory] stored in {time.monotonic() - t0:.1f}s", flush=True)
except Exception as e:
print(f"[memory] error: {e}", flush=True)
async def _retrieve_memories(message: str, session_id: str) -> str:
"""Search openmemory for relevant context. Returns formatted string or ''."""
if _memory_search_tool is None:
return ""
try:
result = await _memory_search_tool.ainvoke({"query": message, "user_id": session_id})
if result and result.strip() and result.strip() != "[]":
return f"Relevant memories:\n{result}"
except Exception:
pass
return ""
# ── core task ──────────────────────────────────────────────────────────────────
async def run_agent_task(message: str, session_id: str, channel: str = "telegram"):
@@ -261,7 +324,13 @@ async def run_agent_task(message: str, session_id: str, channel: str = "telegram
history = _conversation_buffers.get(session_id, [])
print(f"[agent] running: {clean_message[:80]!r}", flush=True)
tier, light_reply = await router.route(clean_message, history, force_complex)
# Retrieve memories once; inject into history so ALL tiers can use them
memories = await _retrieve_memories(clean_message, session_id)
enriched_history = (
[{"role": "system", "content": memories}] + history if memories else history
)
tier, light_reply = await router.route(clean_message, enriched_history, force_complex)
print(f"[agent] tier={tier} message={clean_message[:60]!r}", flush=True)
final_text = None
@@ -273,6 +342,8 @@ async def run_agent_task(message: str, session_id: str, channel: str = "telegram
elif tier == "medium":
system_prompt = MEDIUM_SYSTEM_PROMPT
if memories:
system_prompt = system_prompt + "\n\n" + memories
result = await medium_agent.ainvoke({
"messages": [
{"role": "system", "content": system_prompt},
@@ -289,9 +360,12 @@ async def run_agent_task(message: str, session_id: str, channel: str = "telegram
if not ok:
print("[agent] complex→medium fallback (eviction timeout)", flush=True)
tier = "medium"
system_prompt = MEDIUM_SYSTEM_PROMPT
if memories:
system_prompt = system_prompt + "\n\n" + memories
result = await medium_agent.ainvoke({
"messages": [
{"role": "system", "content": MEDIUM_SYSTEM_PROMPT},
{"role": "system", "content": system_prompt},
*history,
{"role": "user", "content": clean_message},
]
@@ -320,7 +394,10 @@ async def run_agent_task(message: str, session_id: str, channel: str = "telegram
# Deliver reply through the originating channel
if final_text:
t1 = time.monotonic()
await channels.deliver(session_id, channel, final_text)
try:
await channels.deliver(session_id, channel, final_text)
except Exception as e:
print(f"[agent] delivery error (non-fatal): {e}", flush=True)
send_elapsed = time.monotonic() - t1
print(
f"[agent] replied in {time.monotonic() - t0:.1f}s "
@@ -331,12 +408,13 @@ async def run_agent_task(message: str, session_id: str, channel: str = "telegram
else:
print("[agent] warning: no text reply from agent", flush=True)
# Update conversation buffer
# Update conversation buffer and schedule memory storage
if final_text:
buf = _conversation_buffers.get(session_id, [])
buf.append({"role": "user", "content": clean_message})
buf.append({"role": "assistant", "content": final_text})
_conversation_buffers[session_id] = buf[-(MAX_HISTORY_TURNS * 2):]
asyncio.create_task(_store_memory(session_id, clean_message, final_text))
# ── endpoints ──────────────────────────────────────────────────────────────────
@@ -374,7 +452,7 @@ async def reply_stream(session_id: str):
try:
text = await asyncio.wait_for(q.get(), timeout=900)
# Escape newlines so entire reply fits in one SSE data line
yield f"data: {text.replace(chr(10), '\\n').replace(chr(13), '')}\n\n"
yield f"data: {text.replace(chr(10), chr(92) + 'n').replace(chr(13), '')}\n\n"
except asyncio.TimeoutError:
yield "data: [timeout]\n\n"

View File

@@ -1,13 +1,21 @@
from deepagents import create_deep_agent
class _DirectModel:
"""Thin wrapper: single LLM call, no tools, same ainvoke interface as a graph."""
def __init__(self, model):
self._model = model
async def ainvoke(self, input_dict: dict) -> dict:
messages = input_dict["messages"]
response = await self._model.ainvoke(messages)
return {"messages": list(messages) + [response]}
def build_medium_agent(model, agent_tools: list, system_prompt: str):
"""Medium agent: create_deep_agent with TodoList planning, no subagents."""
return create_deep_agent(
model=model,
tools=agent_tools,
system_prompt=system_prompt,
)
"""Medium agent: single LLM call, no tools — fast ~3s response."""
return _DirectModel(model)
def build_complex_agent(model, agent_tools: list, system_prompt: str):

58
bifrost-config.json Normal file
View File

@@ -0,0 +1,58 @@
{
"client": {
"drop_excess_requests": false
},
"providers": {
"ollama": {
"keys": [
{
"name": "ollama-gpu",
"value": "dummy",
"models": [
"qwen2.5:0.5b",
"qwen2.5:1.5b",
"qwen3:4b",
"gemma3:4b",
"qwen3:8b"
],
"weight": 1.0
}
],
"network_config": {
"base_url": "http://host.docker.internal:11436",
"default_request_timeout_in_seconds": 300,
"max_retries": 2,
"retry_backoff_initial_ms": 500,
"retry_backoff_max_ms": 10000
}
},
"ollama-cpu": {
"keys": [
{
"name": "ollama-cpu-key",
"value": "dummy",
"models": [
"gemma3:1b",
"qwen2.5:1.5b",
"qwen2.5:3b"
],
"weight": 1.0
}
],
"network_config": {
"base_url": "http://host.docker.internal:11435",
"default_request_timeout_in_seconds": 120,
"max_retries": 2,
"retry_backoff_initial_ms": 500,
"retry_backoff_max_ms": 10000
},
"custom_provider_config": {
"base_provider_type": "openai",
"allowed_requests": {
"chat_completion": true,
"chat_completion_stream": true
}
}
}
}
}

View File

@@ -1,4 +1,19 @@
services:
bifrost:
image: maximhq/bifrost
container_name: bifrost
ports:
- "8080:8080"
volumes:
- ./bifrost-config.json:/app/data/config.json:ro
environment:
- APP_DIR=/app/data
- APP_PORT=8080
- LOG_LEVEL=info
extra_hosts:
- "host.docker.internal:host-gateway"
restart: unless-stopped
deepagents:
build: .
container_name: deepagents
@@ -6,8 +21,11 @@ services:
- "8000:8000"
environment:
- PYTHONUNBUFFERED=1
# Bifrost gateway — all LLM inference goes through here
- BIFROST_URL=http://bifrost:8080/v1
# Direct Ollama GPU URL — used only by VRAMManager for flush/prewarm
- OLLAMA_BASE_URL=http://host.docker.internal:11436
- DEEPAGENTS_MODEL=qwen3:4b
- DEEPAGENTS_MODEL=qwen2.5:1.5b
- DEEPAGENTS_COMPLEX_MODEL=qwen3:8b
- DEEPAGENTS_ROUTER_MODEL=qwen2.5:1.5b
- SEARXNG_URL=http://host.docker.internal:11437
@@ -19,6 +37,7 @@ services:
- openmemory
- grammy
- crawl4ai
- bifrost
restart: unless-stopped
openmemory:
@@ -27,8 +46,9 @@ services:
ports:
- "8765:8765"
environment:
# Extraction LLM (qwen2.5:1.5b) runs on GPU after reply — fast 2-5s extraction
# Extraction LLM runs on GPU — qwen2.5:1.5b for speed (~3s)
- OLLAMA_GPU_URL=http://host.docker.internal:11436
- OLLAMA_EXTRACTION_MODEL=qwen2.5:1.5b
# Embedding (nomic-embed-text) runs on CPU — fast enough for search (50-150ms)
- OLLAMA_CPU_URL=http://host.docker.internal:11435
extra_hosts:

View File

@@ -6,6 +6,7 @@ from mem0 import Memory
# Extraction LLM — GPU Ollama (qwen3:4b, same model as medium agent)
# Runs after reply when GPU is idle; spin-wait in agent.py prevents contention
OLLAMA_GPU_URL = os.getenv("OLLAMA_GPU_URL", "http://host.docker.internal:11436")
EXTRACTION_MODEL = os.getenv("OLLAMA_EXTRACTION_MODEL", "qwen2.5:1.5b")
# Embedding — CPU Ollama (nomic-embed-text, 137 MB RAM)
# Used for both search (50-150ms, acceptable) and store-time embedding
@@ -94,7 +95,7 @@ config = {
"llm": {
"provider": "ollama",
"config": {
"model": "qwen3:4b",
"model": EXTRACTION_MODEL,
"ollama_base_url": OLLAMA_GPU_URL,
"temperature": 0.1, # consistent JSON output
},

4
pytest.ini Normal file
View File

@@ -0,0 +1,4 @@
[pytest]
testpaths = tests/unit
pythonpath = .
asyncio_mode = auto

View File

@@ -42,6 +42,7 @@ import urllib.request
# ── config ────────────────────────────────────────────────────────────────────
DEEPAGENTS = "http://localhost:8000"
BIFROST = "http://localhost:8080"
OPENMEMORY = "http://localhost:8765"
GRAMMY_HOST = "localhost"
GRAMMY_PORT = 3001
@@ -49,7 +50,7 @@ OLLAMA_GPU = "http://localhost:11436"
OLLAMA_CPU = "http://localhost:11435"
QDRANT = "http://localhost:6333"
SEARXNG = "http://localhost:11437"
COMPOSE_FILE = "/home/alvis/agap_git/adolf/docker-compose.yml"
COMPOSE_FILE = "/home/alvis/adolf/docker-compose.yml"
DEFAULT_CHAT_ID = "346967270"
NAMES = [
@@ -166,6 +167,19 @@ def fetch_logs(since_s=600):
return []
def fetch_bifrost_logs(since_s=120):
"""Return bifrost container log lines from the last since_s seconds."""
try:
r = subprocess.run(
["docker", "compose", "-f", COMPOSE_FILE, "logs", "bifrost",
f"--since={int(since_s)}s", "--no-log-prefix"],
capture_output=True, text=True, timeout=10,
)
return r.stdout.splitlines()
except Exception:
return []
def parse_run_block(lines, msg_prefix):
"""
Scan log lines for the LAST '[agent] running: <msg_prefix>' block.
@@ -303,10 +317,12 @@ _run_hard = not args.no_bench and not args.easy_only and not args.medium_onl
_run_memory = not args.no_bench and not args.easy_only and not args.medium_only and not args.hard_only
random_name = random.choice(NAMES)
# Use a unique chat_id per run to avoid cross-run history contamination
TEST_CHAT_ID = f"{CHAT_ID}-{random_name.lower()}"
if not _skip_pipeline:
print(f"\n Test name : \033[1m{random_name}\033[0m")
print(f" Chat ID : {CHAT_ID}")
print(f" Chat ID : {TEST_CHAT_ID}")
# ── 1. service health ─────────────────────────────────────────────────────────
@@ -331,6 +347,93 @@ if not _skip_pipeline:
timings["health_check"] = time.monotonic() - t0
# ── 1b. Bifrost gateway ───────────────────────────────────────────────────────
if not _skip_pipeline:
print(f"\n[{INFO}] 1b. Bifrost gateway (port 8080)")
t0 = time.monotonic()
# Health ──────────────────────────────────────────────────────────────────
try:
status, body = get(f"{BIFROST}/health", timeout=5)
ok = status == 200
report("Bifrost /health reachable", ok, f"HTTP {status}")
except Exception as e:
report("Bifrost /health reachable", False, str(e))
# Ollama GPU models listed ────────────────────────────────────────────────
try:
status, body = get(f"{BIFROST}/v1/models", timeout=5)
data = json.loads(body)
model_ids = [m.get("id", "") for m in data.get("data", [])]
gpu_models = [m for m in model_ids if m.startswith("ollama/")]
report("Bifrost lists ollama GPU models", len(gpu_models) > 0,
f"found: {gpu_models}")
for expected in ["ollama/qwen3:4b", "ollama/qwen3:8b", "ollama/qwen2.5:1.5b"]:
report(f" model {expected} listed", expected in model_ids)
except Exception as e:
report("Bifrost /v1/models", False, str(e))
# Direct inference through Bifrost → GPU Ollama ───────────────────────────
# Uses the smallest GPU model (qwen2.5:0.5b) to keep latency low.
print(f" [bifrost-infer] direct POST /v1/chat/completions → ollama/qwen2.5:0.5b ...")
t_infer = time.monotonic()
try:
infer_payload = {
"model": "ollama/qwen2.5:0.5b",
"messages": [{"role": "user", "content": "Reply with exactly one word: pong"}],
"max_tokens": 16,
}
infer_data = json.dumps(infer_payload).encode()
req = urllib.request.Request(
f"{BIFROST}/v1/chat/completions",
data=infer_data,
headers={"Content-Type": "application/json"},
method="POST",
)
with urllib.request.urlopen(req, timeout=60) as r:
infer_status = r.status
infer_body = json.loads(r.read().decode())
infer_elapsed = time.monotonic() - t_infer
reply_content = infer_body.get("choices", [{}])[0].get("message", {}).get("content", "")
used_model = infer_body.get("model", "")
report("Bifrost → Ollama GPU inference succeeds",
infer_status == 200 and bool(reply_content),
f"{infer_elapsed:.1f}s model={used_model!r} reply={reply_content[:60]!r}")
timings["bifrost_direct_infer"] = infer_elapsed
except Exception as e:
report("Bifrost → Ollama GPU inference succeeds", False, str(e))
timings["bifrost_direct_infer"] = None
# deepagents is configured to route through Bifrost ───────────────────────
# The startup log line "[agent] bifrost=http://bifrost:8080/v1 | ..." is emitted
# during lifespan setup and confirms deepagents is using Bifrost as the LLM gateway.
try:
r = subprocess.run(
["docker", "compose", "-f", COMPOSE_FILE, "logs", "deepagents",
"--since=3600s", "--no-log-prefix"],
capture_output=True, text=True, timeout=10,
)
log_lines = r.stdout.splitlines()
bifrost_line = next(
(l for l in log_lines if "[agent] bifrost=" in l and "bifrost:8080" in l),
None,
)
report(
"deepagents startup log confirms bifrost URL",
bifrost_line is not None,
bifrost_line.strip() if bifrost_line else "line not found in logs",
)
# Also confirm model names use provider/model format (ollama/...)
if bifrost_line:
has_prefix = "router=ollama/" in bifrost_line and "medium=ollama/" in bifrost_line
report("deepagents model names use ollama/ prefix", has_prefix,
bifrost_line.strip())
except Exception as e:
report("deepagents startup log check", False, str(e))
timings["bifrost_check"] = time.monotonic() - t0
# ── 2. GPU Ollama ─────────────────────────────────────────────────────────────
if not _skip_pipeline:
print(f"\n[{INFO}] 2. GPU Ollama (port 11436)")
@@ -415,11 +518,18 @@ if not _skip_pipeline:
# ── 68. Name memory pipeline ─────────────────────────────────────────────────
if not _skip_pipeline:
print(f"\n[{INFO}] 68. Name memory pipeline")
print(f" chat_id={CHAT_ID} name={random_name}")
print(f" chat_id={TEST_CHAT_ID} name={random_name}")
store_msg = f"remember that your name is {random_name}"
recall_msg = "what is your name?"
# Clear adolf_memories so each run starts clean (avoids cross-run stale memories)
try:
post_json(f"{QDRANT}/collections/adolf_memories/points/delete",
{"filter": {}}, timeout=5)
except Exception:
pass
pts_before = qdrant_count()
print(f" Qdrant points before: {pts_before}")
@@ -429,7 +539,7 @@ if not _skip_pipeline:
try:
status, _ = post_json(f"{DEEPAGENTS}/chat",
{"message": store_msg, "chat_id": CHAT_ID}, timeout=5)
{"message": store_msg, "chat_id": TEST_CHAT_ID}, timeout=5)
t_accept = time.monotonic() - t_store
report("POST /chat (store) returns 202 immediately",
status == 202 and t_accept < 1, f"status={status}, t={t_accept:.3f}s")
@@ -472,7 +582,7 @@ if not _skip_pipeline:
try:
status, _ = post_json(f"{DEEPAGENTS}/chat",
{"message": recall_msg, "chat_id": CHAT_ID}, timeout=5)
{"message": recall_msg, "chat_id": TEST_CHAT_ID}, timeout=5)
t_accept2 = time.monotonic() - t_recall
report("POST /chat (recall) returns 202 immediately",
status == 202 and t_accept2 < 1, f"status={status}, t={t_accept2:.3f}s")
@@ -496,6 +606,19 @@ if not _skip_pipeline:
report("Agent replied to recall message", False, "timeout")
report(f"Reply contains '{random_name}'", False, "no reply")
# ── 8b. Verify requests passed through Bifrost ────────────────────────────
# After the store+recall round-trip, Bifrost logs must show forwarded
# requests. An empty Bifrost log means deepagents bypasses the gateway.
bifrost_lines = fetch_bifrost_logs(since_s=300)
report("Bifrost container has log output (requests forwarded)",
len(bifrost_lines) > 0,
f"{len(bifrost_lines)} lines in bifrost logs")
# Bifrost logs contain the request body; AsyncOpenAI user-agent confirms the path
bifrost_raw = "\n".join(bifrost_lines)
report(" Bifrost log shows AsyncOpenAI agent requests",
"AsyncOpenAI" in bifrost_raw,
f"{'found' if 'AsyncOpenAI' in bifrost_raw else 'NOT found'} in bifrost logs")
# ── 9. Timing profile ─────────────────────────────────────────────────────────
if not _skip_pipeline:

2
tests/requirements.txt Normal file
View File

@@ -0,0 +1,2 @@
pytest>=8.0
pytest-asyncio>=0.23

80
tests/unit/conftest.py Normal file
View File

@@ -0,0 +1,80 @@
"""
Stub out all third-party packages that Adolf's source modules import.
This lets the unit tests run without a virtualenv or Docker environment.
Stubs are installed into sys.modules before any test file is collected.
"""
import sys
from unittest.mock import MagicMock
# ── helpers ────────────────────────────────────────────────────────────────────
def _mock(name: str) -> MagicMock:
m = MagicMock(name=name)
sys.modules[name] = m
return m
# ── pydantic: BaseModel must be a real class so `class Foo(BaseModel)` works ──
class _FakeBaseModel:
model_fields: dict = {}
def __init_subclass__(cls, **kwargs):
pass
def __init__(self, **data):
for k, v in data.items():
setattr(self, k, v)
_pydantic = _mock("pydantic")
_pydantic.BaseModel = _FakeBaseModel
# ── httpx: used by channels.py, vram_manager.py, agent.py ────────────────────
_mock("httpx")
# ── fastapi ───────────────────────────────────────────────────────────────────
_fastapi = _mock("fastapi")
_mock("fastapi.responses")
# ── langchain stack ───────────────────────────────────────────────────────────
_mock("langchain_openai")
_lc_core = _mock("langchain_core")
_lc_msgs = _mock("langchain_core.messages")
_mock("langchain_core.tools")
# Provide real-ish message classes so router.py can instantiate them
class _FakeMsg:
def __init__(self, content=""):
self.content = content
class SystemMessage(_FakeMsg):
pass
class HumanMessage(_FakeMsg):
pass
class AIMessage(_FakeMsg):
def __init__(self, content="", tool_calls=None):
super().__init__(content)
self.tool_calls = tool_calls or []
_lc_msgs.SystemMessage = SystemMessage
_lc_msgs.HumanMessage = HumanMessage
_lc_msgs.AIMessage = AIMessage
_mock("langchain_mcp_adapters")
_mock("langchain_mcp_adapters.client")
_mock("langchain_community")
_mock("langchain_community.utilities")
# ── deepagents (agent_factory.py) ─────────────────────────────────────────────
_mock("deepagents")

View File

@@ -0,0 +1,161 @@
"""
Unit tests for agent.py helper functions:
- _strip_think(text)
- _extract_final_text(result)
agent.py has heavy FastAPI/LangChain imports; conftest.py stubs them out so
these pure functions can be imported and tested in isolation.
"""
import pytest
# conftest.py has already installed all stubs into sys.modules.
# The FastAPI app is instantiated at module level in agent.py —
# with the mocked fastapi, that just creates a MagicMock() object
# and the route decorators are no-ops.
from agent import _strip_think, _extract_final_text
# ── _strip_think ───────────────────────────────────────────────────────────────
class TestStripThink:
def test_removes_single_think_block(self):
text = "<think>internal reasoning</think>Final answer."
assert _strip_think(text) == "Final answer."
def test_removes_multiline_think_block(self):
text = "<think>\nLine one.\nLine two.\n</think>\nResult here."
assert _strip_think(text) == "Result here."
def test_no_think_block_unchanged(self):
text = "This is a plain answer with no think block."
assert _strip_think(text) == text
def test_removes_multiple_think_blocks(self):
text = "<think>step 1</think>middle<think>step 2</think>end"
assert _strip_think(text) == "middleend"
def test_strips_surrounding_whitespace(self):
text = " <think>stuff</think> answer "
assert _strip_think(text) == "answer"
def test_empty_think_block(self):
text = "<think></think>Hello."
assert _strip_think(text) == "Hello."
def test_empty_string(self):
assert _strip_think("") == ""
def test_only_think_block_returns_empty(self):
text = "<think>nothing useful</think>"
assert _strip_think(text) == ""
def test_think_block_with_nested_tags(self):
text = "<think>I should use <b>bold</b> here</think>Done."
assert _strip_think(text) == "Done."
def test_preserves_markdown(self):
text = "<think>plan</think>## Report\n\n- Point one\n- Point two"
result = _strip_think(text)
assert result == "## Report\n\n- Point one\n- Point two"
# ── _extract_final_text ────────────────────────────────────────────────────────
class TestExtractFinalText:
def _ai_msg(self, content: str, tool_calls=None):
"""Create a minimal AIMessage-like object."""
class AIMessage:
pass
m = AIMessage()
m.content = content
m.tool_calls = tool_calls or []
return m
def _human_msg(self, content: str):
class HumanMessage:
pass
m = HumanMessage()
m.content = content
return m
def test_returns_last_ai_message_content(self):
result = {
"messages": [
self._human_msg("what is 2+2"),
self._ai_msg("The answer is 4."),
]
}
assert _extract_final_text(result) == "The answer is 4."
def test_returns_last_of_multiple_ai_messages(self):
result = {
"messages": [
self._ai_msg("First response."),
self._human_msg("follow-up"),
self._ai_msg("Final response."),
]
}
assert _extract_final_text(result) == "Final response."
def test_skips_empty_ai_messages(self):
result = {
"messages": [
self._ai_msg("Real answer."),
self._ai_msg(""), # empty — should be skipped
]
}
assert _extract_final_text(result) == "Real answer."
def test_strips_think_tags_from_ai_message(self):
result = {
"messages": [
self._ai_msg("<think>reasoning here</think>Clean reply."),
]
}
assert _extract_final_text(result) == "Clean reply."
def test_falls_back_to_output_field(self):
result = {
"messages": [],
"output": "Fallback output.",
}
assert _extract_final_text(result) == "Fallback output."
def test_strips_think_from_output_field(self):
result = {
"messages": [],
"output": "<think>thoughts</think>Actual output.",
}
assert _extract_final_text(result) == "Actual output."
def test_returns_none_when_no_content(self):
result = {"messages": []}
assert _extract_final_text(result) is None
def test_returns_none_when_no_messages_and_no_output(self):
result = {"messages": [], "output": ""}
# output is falsy → returns None
assert _extract_final_text(result) is None
def test_skips_non_ai_messages(self):
result = {
"messages": [
self._human_msg("user question"),
]
}
assert _extract_final_text(result) is None
def test_handles_ai_message_with_tool_calls_but_no_content(self):
"""AIMessage that only has tool_calls (no content) should be skipped."""
msg = self._ai_msg("", tool_calls=[{"name": "web_search", "args": {}}])
result = {"messages": [msg]}
assert _extract_final_text(result) is None
def test_multiline_think_stripped_correctly(self):
result = {
"messages": [
self._ai_msg("<think>\nLong\nreasoning\nblock\n</think>\n## Report\n\nSome content."),
]
}
assert _extract_final_text(result) == "## Report\n\nSome content."

125
tests/unit/test_channels.py Normal file
View File

@@ -0,0 +1,125 @@
"""Unit tests for channels.py — register, deliver, pending_replies queue."""
import asyncio
import pytest
from unittest.mock import AsyncMock, patch
import channels
@pytest.fixture(autouse=True)
def reset_channels_state():
"""Clear module-level state before and after every test."""
channels._callbacks.clear()
channels.pending_replies.clear()
yield
channels._callbacks.clear()
channels.pending_replies.clear()
# ── register ───────────────────────────────────────────────────────────────────
class TestRegister:
def test_register_stores_callback(self):
cb = AsyncMock()
channels.register("test_channel", cb)
assert channels._callbacks["test_channel"] is cb
def test_register_overwrites_existing(self):
cb1 = AsyncMock()
cb2 = AsyncMock()
channels.register("ch", cb1)
channels.register("ch", cb2)
assert channels._callbacks["ch"] is cb2
def test_register_multiple_channels(self):
cb_a = AsyncMock()
cb_b = AsyncMock()
channels.register("a", cb_a)
channels.register("b", cb_b)
assert channels._callbacks["a"] is cb_a
assert channels._callbacks["b"] is cb_b
# ── deliver ────────────────────────────────────────────────────────────────────
class TestDeliver:
async def test_deliver_enqueues_reply(self):
channels.register("cli", AsyncMock())
await channels.deliver("cli-alvis", "cli", "hello world")
q = channels.pending_replies["cli-alvis"]
assert not q.empty()
assert await q.get() == "hello world"
async def test_deliver_calls_channel_callback(self):
cb = AsyncMock()
channels.register("telegram", cb)
await channels.deliver("tg-123", "telegram", "reply text")
cb.assert_awaited_once_with("tg-123", "reply text")
async def test_deliver_unknown_channel_still_enqueues(self):
"""No registered callback for channel → reply still goes to the queue."""
await channels.deliver("cli-bob", "nonexistent", "fallback reply")
q = channels.pending_replies["cli-bob"]
assert await q.get() == "fallback reply"
async def test_deliver_unknown_channel_does_not_raise(self):
"""Missing callback must not raise an exception."""
await channels.deliver("cli-x", "ghost_channel", "msg")
async def test_deliver_creates_queue_if_absent(self):
channels.register("cli", AsyncMock())
assert "cli-new" not in channels.pending_replies
await channels.deliver("cli-new", "cli", "hi")
assert "cli-new" in channels.pending_replies
async def test_deliver_reuses_existing_queue(self):
"""Second deliver to the same session appends to the same queue."""
channels.register("cli", AsyncMock())
await channels.deliver("cli-alvis", "cli", "first")
await channels.deliver("cli-alvis", "cli", "second")
q = channels.pending_replies["cli-alvis"]
assert await q.get() == "first"
assert await q.get() == "second"
async def test_deliver_telegram_sends_to_callback(self):
sent = []
async def fake_tg(session_id, text):
sent.append((session_id, text))
channels.register("telegram", fake_tg)
await channels.deliver("tg-999", "telegram", "test message")
assert sent == [("tg-999", "test message")]
# ── register_defaults ──────────────────────────────────────────────────────────
class TestRegisterDefaults:
def test_registers_telegram_and_cli(self):
channels.register_defaults()
assert "telegram" in channels._callbacks
assert "cli" in channels._callbacks
async def test_cli_callback_is_noop(self):
"""CLI send callback does nothing (replies are handled via SSE queue)."""
channels.register_defaults()
cb = channels._callbacks["cli"]
# Should not raise and should return None
result = await cb("cli-alvis", "some reply")
assert result is None
async def test_telegram_callback_chunks_long_messages(self):
"""Telegram callback splits messages > 4000 chars into chunks."""
channels.register_defaults()
cb = channels._callbacks["telegram"]
long_text = "x" * 9000 # > 4000 chars → should produce 3 chunks
with patch("channels.httpx.AsyncClient") as mock_client_cls:
mock_client = AsyncMock()
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_client.post = AsyncMock()
mock_client_cls.return_value = mock_client
await cb("tg-123", long_text)
# 9000 chars / 4000 per chunk = 3 POST calls
assert mock_client.post.await_count == 3

200
tests/unit/test_router.py Normal file
View File

@@ -0,0 +1,200 @@
"""Unit tests for router.py — Router, _parse_tier, _format_history, _LIGHT_PATTERNS."""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from router import Router, _parse_tier, _format_history, _LIGHT_PATTERNS
# ── _LIGHT_PATTERNS regex ──────────────────────────────────────────────────────
class TestLightPatterns:
@pytest.mark.parametrize("text", [
"hi", "Hi", "HI",
"hello", "hey", "yo", "sup",
"good morning", "good evening", "good night", "good afternoon",
"bye", "goodbye", "see you", "cya", "later", "ttyl",
"thanks", "thank you", "thx", "ty",
"ok", "okay", "k", "cool", "great", "awesome", "perfect",
"sounds good", "got it", "nice", "sure",
"how are you", "how are you?", "how are you doing today?",
"what's up",
"what day comes after Monday?",
"what day follows Friday?",
"what comes after summer?",
"what does NASA stand for?",
"what does AI stand for?",
# with trailing punctuation
"hi!", "hello.", "thanks!",
])
def test_matches(self, text):
assert _LIGHT_PATTERNS.match(text.strip()), f"Expected light match for: {text!r}"
@pytest.mark.parametrize("text", [
"what is the capital of France",
"tell me about bitcoin",
"what is 2+2",
"write me a poem",
"search for news about the election",
"what did we talk about last time",
"what is my name",
"/think compare these frameworks",
"how do I install Python",
"explain machine learning",
"", # empty string doesn't match the pattern
])
def test_no_match(self, text):
assert not _LIGHT_PATTERNS.match(text.strip()), f"Expected NO light match for: {text!r}"
# ── _parse_tier ────────────────────────────────────────────────────────────────
class TestParseTier:
@pytest.mark.parametrize("raw,expected", [
("light", "light"),
("Light", "light"),
("LIGHT\n", "light"),
("medium", "medium"),
("Medium.", "medium"),
("complex", "complex"),
("Complex!", "complex"),
# descriptive words → light
("simplefact", "light"),
("trivial question", "light"),
("basic", "light"),
("easy answer", "light"),
("general knowledge", "light"),
# unknown → medium
("unknown_category", "medium"),
("", "medium"),
("I don't know", "medium"),
# complex only if 'complex' appears in first 60 chars
("this is a complex query requiring search", "complex"),
# _parse_tier checks "complex" before "medium", so complex wins even if medium appears first
("medium complexity, not complex", "complex"),
])
def test_parse_tier(self, raw, expected):
assert _parse_tier(raw) == expected
# ── _format_history ────────────────────────────────────────────────────────────
class TestFormatHistory:
def test_empty(self):
assert _format_history([]) == "(none)"
def test_single_user_message(self):
history = [{"role": "user", "content": "hello there"}]
result = _format_history(history)
assert "user: hello there" in result
def test_multiple_turns(self):
history = [
{"role": "user", "content": "What is Python?"},
{"role": "assistant", "content": "Python is a programming language."},
]
result = _format_history(history)
assert "user: What is Python?" in result
assert "assistant: Python is a programming language." in result
def test_truncates_long_content(self):
long_content = "x" * 300
history = [{"role": "user", "content": long_content}]
result = _format_history(history)
# content is truncated to 200 chars in _format_history
assert len(result) < 250
def test_missing_keys_handled(self):
# Should not raise — uses .get() with defaults
history = [{"role": "user"}] # no content key
result = _format_history(history)
assert "user:" in result
# ── Router.route() ─────────────────────────────────────────────────────────────
class TestRouterRoute:
def _make_router(self, classify_response: str, reply_response: str = "Sure!") -> Router:
"""Return a Router with a mock model that returns given classification and reply."""
model = MagicMock()
classify_msg = MagicMock()
classify_msg.content = classify_response
reply_msg = MagicMock()
reply_msg.content = reply_response
# First ainvoke call → classification; second → reply
model.ainvoke = AsyncMock(side_effect=[classify_msg, reply_msg])
return Router(model=model)
async def test_force_complex_bypasses_classification(self):
router = self._make_router("medium")
tier, reply = await router.route("some question", [], force_complex=True)
assert tier == "complex"
assert reply is None
# Model should NOT have been called
router.model.ainvoke.assert_not_called()
async def test_regex_light_skips_llm_classification(self):
# Regex match bypasses classification entirely; the only ainvoke call is the reply.
model = MagicMock()
reply_msg = MagicMock()
reply_msg.content = "I'm doing great!"
model.ainvoke = AsyncMock(return_value=reply_msg)
router = Router(model=model)
tier, reply = await router.route("how are you", [], force_complex=False)
assert tier == "light"
assert reply == "I'm doing great!"
# Exactly one model call — no classification step
assert router.model.ainvoke.call_count == 1
async def test_llm_classifies_medium(self):
router = self._make_router("medium")
tier, reply = await router.route("what is the bitcoin price?", [], force_complex=False)
assert tier == "medium"
assert reply is None
async def test_llm_classifies_light_generates_reply(self):
router = self._make_router("light", "Paris is the capital of France.")
tier, reply = await router.route("what is the capital of France?", [], force_complex=False)
assert tier == "light"
assert reply == "Paris is the capital of France."
async def test_llm_classifies_complex_downgraded_to_medium(self):
# Without /think prefix, complex classification → downgraded to medium
router = self._make_router("complex")
tier, reply = await router.route("compare React and Vue", [], force_complex=False)
assert tier == "medium"
assert reply is None
async def test_llm_error_falls_back_to_medium(self):
model = MagicMock()
model.ainvoke = AsyncMock(side_effect=Exception("connection error"))
router = Router(model=model)
tier, reply = await router.route("some question", [], force_complex=False)
assert tier == "medium"
assert reply is None
async def test_light_reply_empty_falls_back_to_medium(self):
"""If the light reply comes back empty, router returns medium instead."""
router = self._make_router("light", "") # empty reply
tier, reply = await router.route("what is 2+2", [], force_complex=False)
assert tier == "medium"
assert reply is None
async def test_strips_think_tags_from_classification(self):
"""Router strips <think>...</think> from model output before parsing tier."""
model = MagicMock()
classify_msg = MagicMock()
classify_msg.content = "<think>Hmm let me think...</think>medium"
reply_msg = MagicMock()
reply_msg.content = "I'm fine!"
model.ainvoke = AsyncMock(side_effect=[classify_msg, reply_msg])
router = Router(model=model)
tier, _ = await router.route("what is the news?", [], force_complex=False)
assert tier == "medium"
async def test_think_prefix_forces_complex(self):
"""/think prefix is already stripped by agent.py; force_complex=True is passed."""
router = self._make_router("medium")
tier, reply = await router.route("analyse this", [], force_complex=True)
assert tier == "complex"
assert reply is None

View File

@@ -0,0 +1,164 @@
"""Unit tests for vram_manager.py — VRAMManager flush/poll/prewarm logic."""
import asyncio
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from vram_manager import VRAMManager
BASE_URL = "http://localhost:11434"
def _make_manager() -> VRAMManager:
return VRAMManager(base_url=BASE_URL)
def _mock_client(get_response=None, post_response=None):
"""Return a context-manager mock for httpx.AsyncClient."""
client = AsyncMock()
client.__aenter__ = AsyncMock(return_value=client)
client.__aexit__ = AsyncMock(return_value=False)
if get_response is not None:
client.get = AsyncMock(return_value=get_response)
if post_response is not None:
client.post = AsyncMock(return_value=post_response)
return client
# ── _flush ─────────────────────────────────────────────────────────────────────
class TestFlush:
async def test_sends_keep_alive_zero(self):
client = _mock_client(post_response=MagicMock())
with patch("vram_manager.httpx.AsyncClient", return_value=client):
mgr = _make_manager()
await mgr._flush("qwen3:4b")
client.post.assert_awaited_once()
_, kwargs = client.post.await_args
body = kwargs.get("json") or client.post.call_args[1].get("json") or client.post.call_args[0][1]
assert body["model"] == "qwen3:4b"
assert body["keep_alive"] == 0
async def test_posts_to_correct_endpoint(self):
client = _mock_client(post_response=MagicMock())
with patch("vram_manager.httpx.AsyncClient", return_value=client):
mgr = _make_manager()
await mgr._flush("qwen3:8b")
url = client.post.call_args[0][0]
assert url == f"{BASE_URL}/api/generate"
async def test_ignores_exceptions_silently(self):
client = AsyncMock()
client.__aenter__ = AsyncMock(return_value=client)
client.__aexit__ = AsyncMock(return_value=False)
client.post = AsyncMock(side_effect=Exception("connection refused"))
with patch("vram_manager.httpx.AsyncClient", return_value=client):
mgr = _make_manager()
# Should not raise
await mgr._flush("qwen3:4b")
# ── _prewarm ───────────────────────────────────────────────────────────────────
class TestPrewarm:
async def test_sends_keep_alive_300(self):
client = _mock_client(post_response=MagicMock())
with patch("vram_manager.httpx.AsyncClient", return_value=client):
mgr = _make_manager()
await mgr._prewarm("qwen3:4b")
_, kwargs = client.post.await_args
body = kwargs.get("json") or client.post.call_args[1].get("json") or client.post.call_args[0][1]
assert body["keep_alive"] == 300
assert body["model"] == "qwen3:4b"
async def test_ignores_exceptions_silently(self):
client = AsyncMock()
client.__aenter__ = AsyncMock(return_value=client)
client.__aexit__ = AsyncMock(return_value=False)
client.post = AsyncMock(side_effect=Exception("timeout"))
with patch("vram_manager.httpx.AsyncClient", return_value=client):
mgr = _make_manager()
await mgr._prewarm("qwen3:4b")
# ── _poll_evicted ──────────────────────────────────────────────────────────────
class TestPollEvicted:
async def test_returns_true_when_models_absent(self):
resp = MagicMock()
resp.json.return_value = {"models": [{"name": "some_other_model"}]}
client = _mock_client(get_response=resp)
with patch("vram_manager.httpx.AsyncClient", return_value=client):
mgr = _make_manager()
result = await mgr._poll_evicted(["qwen3:4b", "qwen2.5:1.5b"], timeout=5)
assert result is True
async def test_returns_false_on_timeout_when_model_still_loaded(self):
resp = MagicMock()
resp.json.return_value = {"models": [{"name": "qwen3:4b"}]}
client = _mock_client(get_response=resp)
with patch("vram_manager.httpx.AsyncClient", return_value=client):
mgr = _make_manager()
result = await mgr._poll_evicted(["qwen3:4b"], timeout=0.1)
assert result is False
async def test_returns_true_immediately_if_already_empty(self):
resp = MagicMock()
resp.json.return_value = {"models": []}
client = _mock_client(get_response=resp)
with patch("vram_manager.httpx.AsyncClient", return_value=client):
mgr = _make_manager()
result = await mgr._poll_evicted(["qwen3:4b"], timeout=5)
assert result is True
async def test_handles_poll_error_and_continues(self):
"""If /api/ps errors, polling continues until timeout."""
client = AsyncMock()
client.__aenter__ = AsyncMock(return_value=client)
client.__aexit__ = AsyncMock(return_value=False)
client.get = AsyncMock(side_effect=Exception("network error"))
with patch("vram_manager.httpx.AsyncClient", return_value=client):
mgr = _make_manager()
result = await mgr._poll_evicted(["qwen3:4b"], timeout=0.2)
assert result is False
# ── enter_complex_mode / exit_complex_mode ─────────────────────────────────────
class TestComplexMode:
async def test_enter_complex_mode_returns_true_on_success(self):
mgr = _make_manager()
mgr._flush = AsyncMock()
mgr._poll_evicted = AsyncMock(return_value=True)
result = await mgr.enter_complex_mode()
assert result is True
async def test_enter_complex_mode_flushes_medium_models(self):
mgr = _make_manager()
mgr._flush = AsyncMock()
mgr._poll_evicted = AsyncMock(return_value=True)
await mgr.enter_complex_mode()
flushed = {call.args[0] for call in mgr._flush.call_args_list}
assert "qwen3:4b" in flushed
assert "qwen2.5:1.5b" in flushed
async def test_enter_complex_mode_returns_false_on_eviction_timeout(self):
mgr = _make_manager()
mgr._flush = AsyncMock()
mgr._poll_evicted = AsyncMock(return_value=False)
result = await mgr.enter_complex_mode()
assert result is False
async def test_exit_complex_mode_flushes_complex_and_prewarms_medium(self):
mgr = _make_manager()
mgr._flush = AsyncMock()
mgr._prewarm = AsyncMock()
await mgr.exit_complex_mode()
# Must flush 8b
flushed = {call.args[0] for call in mgr._flush.call_args_list}
assert "qwen3:8b" in flushed
# Must prewarm medium models
prewarmed = {call.args[0] for call in mgr._prewarm.call_args_list}
assert "qwen3:4b" in prewarmed
assert "qwen2.5:1.5b" in prewarmed