diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..7a60b85
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+__pycache__/
+*.pyc
diff --git a/CLAUDE.md b/CLAUDE.md
new file mode 100644
index 0000000..a85d22f
--- /dev/null
+++ b/CLAUDE.md
@@ -0,0 +1,133 @@
+# CLAUDE.md
+
+This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
+
+## Commands
+
+**Start all services:**
+```bash
+docker compose up --build
+```
+
+**Interactive CLI (requires gateway running):**
+```bash
+python3 cli.py [--url http://localhost:8000] [--session cli-alvis] [--timeout 400]
+```
+
+**Run integration tests:**
+```bash
+python3 test_pipeline.py [--chat-id CHAT_ID]
+
+# Selective sections:
+python3 test_pipeline.py --bench-only # routing + memory benchmarks only (sections 10–13)
+python3 test_pipeline.py --easy-only # light-tier routing benchmark
+python3 test_pipeline.py --medium-only # medium-tier routing benchmark
+python3 test_pipeline.py --hard-only # complex-tier + VRAM flush benchmark
+python3 test_pipeline.py --memory-only # memory store/recall/dedup benchmark
+python3 test_pipeline.py --no-bench # service health + single name store/recall only
+```
+
+## Architecture
+
+Adolf is a multi-channel personal assistant. All LLM inference is routed through **Bifrost**, an open-source Go-based LLM gateway that adds retry logic, failover, and observability in front of Ollama.
+
+### Request flow
+
+```
+Channel adapter → POST /message {text, session_id, channel, user_id}
+ → 202 Accepted (immediate)
+ → background: run_agent_task()
+ → router.route() → tier decision (light/medium/complex)
+ → invoke agent for tier via Bifrost
+ deepagents:8000 → bifrost:8080/v1 → ollama:11436
+ → channels.deliver(session_id, channel, reply)
+ → pending_replies[session_id] queue (SSE)
+ → channel-specific callback (Telegram POST, CLI no-op)
+CLI/wiki polling → GET /reply/{session_id} (SSE, blocks until reply)
+```
+
+### Bifrost integration
+
+Bifrost (`bifrost-config.json`) is configured with the `ollama` provider pointing to the GPU Ollama instance on host port 11436. It exposes an OpenAI-compatible API at `http://bifrost:8080/v1`.
+
+`agent.py` uses `langchain_openai.ChatOpenAI` with `base_url=BIFROST_URL`. Model names use the `provider/model` format that Bifrost expects: `ollama/qwen3:4b`, `ollama/qwen3:8b`, `ollama/qwen2.5:1.5b`. Bifrost strips the `ollama/` prefix before forwarding to Ollama.
+
+`VRAMManager` bypasses Bifrost and talks directly to Ollama via `OLLAMA_BASE_URL` (host:11436) for flush/poll/prewarm operations — Bifrost cannot manage GPU VRAM.
+
+### Three-tier routing (`router.py`, `agent.py`)
+
+| Tier | Model (env var) | Trigger |
+|------|-----------------|---------|
+| light | `qwen2.5:1.5b` (`DEEPAGENTS_ROUTER_MODEL`) | Regex pre-match or LLM classifies "light" — answered by router model directly, no agent invoked |
+| medium | `qwen2.5:1.5b` (`DEEPAGENTS_MODEL`) | Default for tool-requiring queries |
+| complex | `qwen3:8b` (`DEEPAGENTS_COMPLEX_MODEL`) | `/think ` prefix only |
+
+The router does regex pre-classification first, then LLM classification. Complex tier is blocked unless the message starts with `/think ` — any LLM classification of "complex" is downgraded to medium.
+
+A global `asyncio.Semaphore(1)` (`_reply_semaphore`) serializes all LLM inference — one request at a time.
+
+### Thinking mode
+
+qwen3 models emit chain-of-thought `<think>...</think>` tokens via Ollama's OpenAI-compatible endpoint. Adolf handles thinking per tier:
+
+- **Medium** (`qwen2.5:1.5b`): no thinking mode in this model; fast ~3s calls
+- **Complex** (`qwen3:8b`): no prefix — thinking enabled by default, used for deep research
+- **Router** (`qwen2.5:1.5b`): no thinking support in this model
+
+`_strip_think()` in `agent.py` and `router.py` strips any `<think>...</think>` blocks from model output before returning to users.
+
+### VRAM management (`vram_manager.py`)
+
+Hardware: GTX 1070 (8 GB). Before running the 8b model, medium models are flushed via Ollama `keep_alive=0`, then `/api/ps` is polled (15s timeout) to confirm eviction. On timeout, falls back to medium tier. After complex reply, 8b is flushed and medium models are pre-warmed as a background task.
+
+### Channel adapters (`channels.py`)
+
+- **Telegram**: Grammy Node.js bot (`grammy/bot.mjs`) long-polls Telegram → `POST /message`; replies delivered via `POST grammy:3001/send`
+- **CLI**: `cli.py` posts to `/message`, then blocks on `GET /reply/{session_id}` SSE
+
+Session IDs: `tg-` for Telegram, `cli-` for CLI. Conversation history: 5-turn buffer per session.
+
+### Services (`docker-compose.yml`)
+
+| Service | Port | Role |
+|---------|------|------|
+| `bifrost` | 8080 | LLM gateway — retries, failover, observability; config from `bifrost-config.json` |
+| `deepagents` | 8000 | FastAPI gateway + agent core |
+| `openmemory` | 8765 | FastMCP server + mem0 memory tools (Qdrant-backed) |
+| `grammy` | 3001 | grammY Telegram bot + `/send` HTTP endpoint |
+| `crawl4ai` | 11235 | JS-rendered page fetching |
+
+External (from `openai/` stack, host ports):
+- Ollama GPU: `11436` — all reply inference (via Bifrost) + VRAM management (direct)
+- Ollama CPU: `11435` — nomic-embed-text embeddings for openmemory
+- Qdrant: `6333` — vector store for memories
+- SearXNG: `11437` — web search
+
+### Bifrost config (`bifrost-config.json`)
+
+The file is mounted into the bifrost container at `/app/data/config.json`. It declares two providers: `ollama` (GPU, `host.docker.internal:11436`, 300s request timeout) and `ollama-cpu` (`host.docker.internal:11435`, an OpenAI-compatible custom provider), each with 2 retries. To add fallback providers or adjust weights, edit this file and restart the bifrost container.
+
+### Agent tools
+
+`web_search`: SearXNG search + Crawl4AI auto-fetch of top 2 results → combined snippet + full page content.
+`fetch_url`: Crawl4AI single-URL fetch.
+MCP tools from openmemory (`add_memory`, `search_memory`, `get_all_memories`) are **excluded** from agent tools — memory management is handled outside the agent loop.
+
+### Medium vs Complex agent
+
+| Agent | Builder | Speed | Use case |
+|-------|---------|-------|----------|
+| medium | `_DirectModel` (single LLM call, no tools) | ~3s | General questions, conversation |
+| complex | `create_deep_agent` (deepagents) | Slow — multi-step planner | Deep research via `/think` prefix |
+
+### Key files
+
+- `agent.py` — FastAPI app, lifespan wiring, `run_agent_task()`, all endpoints
+- `bifrost-config.json` — Bifrost provider config (Ollama GPU, retries, timeouts)
+- `channels.py` — channel registry and `deliver()` dispatcher
+- `router.py` — `Router` class: regex + LLM classification, light-tier reply generation
+- `vram_manager.py` — `VRAMManager`: flush/poll/prewarm Ollama VRAM directly
+- `agent_factory.py` — `build_medium_agent` / `build_complex_agent` via `create_deep_agent()`
+- `openmemory/server.py` — FastMCP + mem0 config with custom extraction/dedup prompts
+- `wiki_research.py` — batch research pipeline using `/message` + SSE polling
+- `grammy/bot.mjs` — Telegram long-poll + HTTP `/send` endpoint
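The request flow CLAUDE.md documents above can be sketched as a minimal CLI-style client. This is illustrative only: the payload field names (`text`, `session_id`, `channel`, `user_id`) and endpoint paths are taken from the flow diagram, and `send_and_wait` assumes the gateway is running locally.

```python
import json
import urllib.request

GATEWAY = "http://localhost:8000"  # deepagents gateway, per docker-compose.yml

def build_message_payload(text, session_id, channel="cli", user_id="local"):
    """Payload shape assumed from the request-flow diagram in CLAUDE.md."""
    return json.dumps({"text": text, "session_id": session_id,
                       "channel": channel, "user_id": user_id}).encode()

def parse_sse_data(line):
    """Replies arrive as a single 'data: ...' SSE line with newlines escaped."""
    if line.startswith("data: "):
        return line[len("data: "):].rstrip("\n").replace("\\n", "\n")
    return None

def send_and_wait(text, session_id):
    req = urllib.request.Request(
        f"{GATEWAY}/message", data=build_message_payload(text, session_id),
        headers={"Content-Type": "application/json"})
    urllib.request.urlopen(req)  # 202 Accepted; reply arrives asynchronously
    with urllib.request.urlopen(f"{GATEWAY}/reply/{session_id}") as stream:
        for raw in stream:  # blocks until run_agent_task delivers a reply
            reply = parse_sse_data(raw.decode())
            if reply is not None:
                return reply
```

This mirrors what `cli.py` does: fire-and-forget POST, then block on the per-session SSE queue.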
diff --git a/Dockerfile b/Dockerfile
index d81ee0c..22b7a8e 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -2,7 +2,7 @@ FROM python:3.12-slim
WORKDIR /app
-RUN pip install --no-cache-dir deepagents langchain-ollama langgraph \
+RUN pip install --no-cache-dir deepagents langchain-openai langgraph \
fastapi uvicorn langchain-mcp-adapters langchain-community httpx
COPY agent.py channels.py vram_manager.py router.py agent_factory.py hello_world.py .
diff --git a/agent.py b/agent.py
index 8d415f9..4d85675 100644
--- a/agent.py
+++ b/agent.py
@@ -10,7 +10,7 @@ from pydantic import BaseModel
import re as _re
import httpx as _httpx
-from langchain_ollama import ChatOllama
+from langchain_openai import ChatOpenAI
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_community.utilities import SearxSearchWrapper
from langchain_core.tools import Tool
@@ -20,8 +20,12 @@ from router import Router
from agent_factory import build_medium_agent, build_complex_agent
import channels
+# Bifrost gateway — all LLM inference goes through here
+BIFROST_URL = os.getenv("BIFROST_URL", "http://bifrost:8080/v1")
+# Direct Ollama URL — used only by VRAMManager for flush/prewarm/poll
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
-ROUTER_MODEL = os.getenv("DEEPAGENTS_ROUTER_MODEL", "qwen2.5:0.5b")
+
+ROUTER_MODEL = os.getenv("DEEPAGENTS_ROUTER_MODEL", "qwen2.5:1.5b")
MEDIUM_MODEL = os.getenv("DEEPAGENTS_MODEL", "qwen3:4b")
COMPLEX_MODEL = os.getenv("DEEPAGENTS_COMPLEX_MODEL", "qwen3:8b")
SEARXNG_URL = os.getenv("SEARXNG_URL", "http://host.docker.internal:11437")
@@ -31,10 +35,12 @@ CRAWL4AI_URL = os.getenv("CRAWL4AI_URL", "http://crawl4ai:11235")
MAX_HISTORY_TURNS = 5
_conversation_buffers: dict[str, list] = {}
+# /no_think at the start of the system prompt disables qwen3 chain-of-thought.
+# create_deep_agent prepends our system_prompt before BASE_AGENT_PROMPT, so
+# /no_think lands at position 0 and is respected by qwen3 models via Ollama.
MEDIUM_SYSTEM_PROMPT = (
- "You are a helpful AI assistant. "
- "Use web_search for questions about current events or facts you don't know. "
- "Reply concisely."
+ "You are a helpful AI assistant. Reply concisely. "
+ "If asked to remember a fact or name, simply confirm: 'Got it, I'll remember that.'"
)
COMPLEX_SYSTEM_PROMPT = (
@@ -54,6 +60,8 @@ complex_agent = None
router: Router = None
vram_manager: VRAMManager = None
mcp_client = None
+_memory_add_tool = None
+_memory_search_tool = None
# GPU mutex: one LLM inference at a time
_reply_semaphore = asyncio.Semaphore(1)
@@ -61,21 +69,34 @@ _reply_semaphore = asyncio.Semaphore(1)
@asynccontextmanager
async def lifespan(app: FastAPI):
- global medium_agent, complex_agent, router, vram_manager, mcp_client
+ global medium_agent, complex_agent, router, vram_manager, mcp_client, \
+ _memory_add_tool, _memory_search_tool
# Register channel adapters
channels.register_defaults()
- # Three model instances
- router_model = ChatOllama(
- model=ROUTER_MODEL, base_url=OLLAMA_BASE_URL, think=False, num_ctx=4096,
+ # All three models route through Bifrost → Ollama GPU.
+ # Bifrost adds retry logic, observability, and failover.
+ # Model names use provider/model format: Bifrost strips the "ollama/" prefix
+ # before forwarding to Ollama's /v1/chat/completions endpoint.
+ router_model = ChatOpenAI(
+ model=f"ollama/{ROUTER_MODEL}",
+ base_url=BIFROST_URL,
+ api_key="dummy",
temperature=0,
+ timeout=30,
)
- medium_model = ChatOllama(
- model=MEDIUM_MODEL, base_url=OLLAMA_BASE_URL, think=False, num_ctx=8192
+ medium_model = ChatOpenAI(
+ model=f"ollama/{MEDIUM_MODEL}",
+ base_url=BIFROST_URL,
+ api_key="dummy",
+ timeout=180,
)
- complex_model = ChatOllama(
- model=COMPLEX_MODEL, base_url=OLLAMA_BASE_URL, think=True, num_ctx=16384
+ complex_model = ChatOpenAI(
+ model=f"ollama/{COMPLEX_MODEL}",
+ base_url=BIFROST_URL,
+ api_key="dummy",
+ timeout=600,
)
vram_manager = VRAMManager(base_url=OLLAMA_BASE_URL)
@@ -97,6 +118,13 @@ async def lifespan(app: FastAPI):
agent_tools = [t for t in mcp_tools if t.name not in ("add_memory", "search_memory", "get_all_memories")]
+ # Expose memory tools directly so run_agent_task can call them outside the agent loop
+ for t in mcp_tools:
+ if t.name == "add_memory":
+ _memory_add_tool = t
+ elif t.name == "search_memory":
+ _memory_search_tool = t
+
searx = SearxSearchWrapper(searx_host=SEARXNG_URL)
def _crawl4ai_fetch(url: str) -> str:
@@ -187,7 +215,8 @@ async def lifespan(app: FastAPI):
)
print(
- f"[agent] three-tier: router={ROUTER_MODEL} | medium={MEDIUM_MODEL} | complex={COMPLEX_MODEL}",
+ f"[agent] bifrost={BIFROST_URL} | router=ollama/{ROUTER_MODEL} | "
+ f"medium=ollama/{MEDIUM_MODEL} | complex=ollama/{COMPLEX_MODEL}",
flush=True,
)
print(f"[agent] agent tools: {[t.name for t in agent_tools]}", flush=True)
@@ -222,13 +251,19 @@ class ChatRequest(BaseModel):
# ── helpers ────────────────────────────────────────────────────────────────────
+def _strip_think(text: str) -> str:
+ """Strip qwen3 chain-of-thought blocks that appear inline in content
+ when using Ollama's OpenAI-compatible endpoint (/v1/chat/completions)."""
+ return _re.sub(r".*?", "", text, flags=_re.DOTALL).strip()
+
+
def _extract_final_text(result) -> str | None:
msgs = result.get("messages", [])
for m in reversed(msgs):
if type(m).__name__ == "AIMessage" and getattr(m, "content", ""):
- return m.content
+ return _strip_think(m.content)
if isinstance(result, dict) and result.get("output"):
- return result["output"]
+ return _strip_think(result["output"])
return None
@@ -244,6 +279,34 @@ def _log_messages(result):
print(f"[agent] {role} → {tc['name']}({tc['args']})", flush=True)
+# ── memory helpers ─────────────────────────────────────────────────────────────
+
+async def _store_memory(session_id: str, user_msg: str, assistant_reply: str) -> None:
+ """Store a conversation turn in openmemory (runs as a background task)."""
+ if _memory_add_tool is None:
+ return
+ t0 = time.monotonic()
+ try:
+ text = f"User: {user_msg}\nAssistant: {assistant_reply}"
+ await _memory_add_tool.ainvoke({"text": text, "user_id": session_id})
+ print(f"[memory] stored in {time.monotonic() - t0:.1f}s", flush=True)
+ except Exception as e:
+ print(f"[memory] error: {e}", flush=True)
+
+
+async def _retrieve_memories(message: str, session_id: str) -> str:
+ """Search openmemory for relevant context. Returns formatted string or ''."""
+ if _memory_search_tool is None:
+ return ""
+ try:
+ result = await _memory_search_tool.ainvoke({"query": message, "user_id": session_id})
+ if result and result.strip() and result.strip() != "[]":
+ return f"Relevant memories:\n{result}"
+ except Exception:
+ pass
+ return ""
+
+
# ── core task ──────────────────────────────────────────────────────────────────
async def run_agent_task(message: str, session_id: str, channel: str = "telegram"):
@@ -261,7 +324,13 @@ async def run_agent_task(message: str, session_id: str, channel: str = "telegram
history = _conversation_buffers.get(session_id, [])
print(f"[agent] running: {clean_message[:80]!r}", flush=True)
- tier, light_reply = await router.route(clean_message, history, force_complex)
+ # Retrieve memories once; inject into history so ALL tiers can use them
+ memories = await _retrieve_memories(clean_message, session_id)
+ enriched_history = (
+ [{"role": "system", "content": memories}] + history if memories else history
+ )
+
+ tier, light_reply = await router.route(clean_message, enriched_history, force_complex)
print(f"[agent] tier={tier} message={clean_message[:60]!r}", flush=True)
final_text = None
@@ -273,6 +342,8 @@ async def run_agent_task(message: str, session_id: str, channel: str = "telegram
elif tier == "medium":
system_prompt = MEDIUM_SYSTEM_PROMPT
+ if memories:
+ system_prompt = system_prompt + "\n\n" + memories
result = await medium_agent.ainvoke({
"messages": [
{"role": "system", "content": system_prompt},
@@ -289,9 +360,12 @@ async def run_agent_task(message: str, session_id: str, channel: str = "telegram
if not ok:
print("[agent] complex→medium fallback (eviction timeout)", flush=True)
tier = "medium"
+ system_prompt = MEDIUM_SYSTEM_PROMPT
+ if memories:
+ system_prompt = system_prompt + "\n\n" + memories
result = await medium_agent.ainvoke({
"messages": [
- {"role": "system", "content": MEDIUM_SYSTEM_PROMPT},
+ {"role": "system", "content": system_prompt},
*history,
{"role": "user", "content": clean_message},
]
@@ -320,7 +394,10 @@ async def run_agent_task(message: str, session_id: str, channel: str = "telegram
# Deliver reply through the originating channel
if final_text:
t1 = time.monotonic()
- await channels.deliver(session_id, channel, final_text)
+ try:
+ await channels.deliver(session_id, channel, final_text)
+ except Exception as e:
+ print(f"[agent] delivery error (non-fatal): {e}", flush=True)
send_elapsed = time.monotonic() - t1
print(
f"[agent] replied in {time.monotonic() - t0:.1f}s "
@@ -331,12 +408,13 @@ async def run_agent_task(message: str, session_id: str, channel: str = "telegram
else:
print("[agent] warning: no text reply from agent", flush=True)
- # Update conversation buffer
+ # Update conversation buffer and schedule memory storage
if final_text:
buf = _conversation_buffers.get(session_id, [])
buf.append({"role": "user", "content": clean_message})
buf.append({"role": "assistant", "content": final_text})
_conversation_buffers[session_id] = buf[-(MAX_HISTORY_TURNS * 2):]
+ asyncio.create_task(_store_memory(session_id, clean_message, final_text))
# ── endpoints ──────────────────────────────────────────────────────────────────
@@ -374,7 +452,7 @@ async def reply_stream(session_id: str):
try:
text = await asyncio.wait_for(q.get(), timeout=900)
# Escape newlines so entire reply fits in one SSE data line
- yield f"data: {text.replace(chr(10), '\\n').replace(chr(13), '')}\n\n"
+ yield f"data: {text.replace(chr(10), chr(92) + 'n').replace(chr(13), '')}\n\n"
except asyncio.TimeoutError:
yield "data: [timeout]\n\n"
diff --git a/agent_factory.py b/agent_factory.py
index 6182ff4..9fa91f1 100644
--- a/agent_factory.py
+++ b/agent_factory.py
@@ -1,13 +1,21 @@
from deepagents import create_deep_agent
+class _DirectModel:
+ """Thin wrapper: single LLM call, no tools, same ainvoke interface as a graph."""
+
+ def __init__(self, model):
+ self._model = model
+
+ async def ainvoke(self, input_dict: dict) -> dict:
+ messages = input_dict["messages"]
+ response = await self._model.ainvoke(messages)
+ return {"messages": list(messages) + [response]}
+
+
def build_medium_agent(model, agent_tools: list, system_prompt: str):
- """Medium agent: create_deep_agent with TodoList planning, no subagents."""
- return create_deep_agent(
- model=model,
- tools=agent_tools,
- system_prompt=system_prompt,
- )
+ """Medium agent: single LLM call, no tools — fast ~3s response."""
+ return _DirectModel(model)
def build_complex_agent(model, agent_tools: list, system_prompt: str):
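The `_DirectModel` contract above can be exercised without LangChain installed; `EchoModel` here is a hypothetical stand-in for `ChatOpenAI`:

```python
import asyncio

class _DirectModel:
    """Same wrapper as in agent_factory.py: one LLM call, graph-like ainvoke."""
    def __init__(self, model):
        self._model = model

    async def ainvoke(self, input_dict: dict) -> dict:
        messages = input_dict["messages"]
        response = await self._model.ainvoke(messages)
        return {"messages": list(messages) + [response]}

class EchoModel:
    """Hypothetical stand-in for ChatOpenAI — echoes the last message back."""
    async def ainvoke(self, messages):
        return {"role": "assistant", "content": messages[-1]["content"]}

agent = _DirectModel(EchoModel())
result = asyncio.run(agent.ainvoke(
    {"messages": [{"role": "user", "content": "ping"}]}))
print(result["messages"][-1]["content"])  # ping
```

Because the return shape matches what `create_deep_agent` graphs produce, `_extract_final_text` in `agent.py` works unchanged on both tiers.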
diff --git a/bifrost-config.json b/bifrost-config.json
new file mode 100644
index 0000000..7db331e
--- /dev/null
+++ b/bifrost-config.json
@@ -0,0 +1,58 @@
+{
+ "client": {
+ "drop_excess_requests": false
+ },
+ "providers": {
+ "ollama": {
+ "keys": [
+ {
+ "name": "ollama-gpu",
+ "value": "dummy",
+ "models": [
+ "qwen2.5:0.5b",
+ "qwen2.5:1.5b",
+ "qwen3:4b",
+ "gemma3:4b",
+ "qwen3:8b"
+ ],
+ "weight": 1.0
+ }
+ ],
+ "network_config": {
+ "base_url": "http://host.docker.internal:11436",
+ "default_request_timeout_in_seconds": 300,
+ "max_retries": 2,
+ "retry_backoff_initial_ms": 500,
+ "retry_backoff_max_ms": 10000
+ }
+ },
+ "ollama-cpu": {
+ "keys": [
+ {
+ "name": "ollama-cpu-key",
+ "value": "dummy",
+ "models": [
+ "gemma3:1b",
+ "qwen2.5:1.5b",
+ "qwen2.5:3b"
+ ],
+ "weight": 1.0
+ }
+ ],
+ "network_config": {
+ "base_url": "http://host.docker.internal:11435",
+ "default_request_timeout_in_seconds": 120,
+ "max_retries": 2,
+ "retry_backoff_initial_ms": 500,
+ "retry_backoff_max_ms": 10000
+ },
+ "custom_provider_config": {
+ "base_provider_type": "openai",
+ "allowed_requests": {
+ "chat_completion": true,
+ "chat_completion_stream": true
+ }
+ }
+ }
+ }
+}
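The `provider/model` naming convention this config implies can be sketched as a pair of helpers. This is purely illustrative — Bifrost performs the prefix stripping internally before forwarding to Ollama:

```python
def to_bifrost_model(model: str, provider: str = "ollama") -> str:
    # Client side: prefix the raw Ollama model name with its provider key.
    return f"{provider}/{model}"

def to_ollama_model(bifrost_model: str) -> str:
    # Gateway side: strip the provider prefix before forwarding to Ollama.
    _, _, model = bifrost_model.partition("/")
    return model or bifrost_model  # pass through names with no prefix

print(to_bifrost_model("qwen3:8b"))            # ollama/qwen3:8b
print(to_ollama_model("ollama/qwen2.5:1.5b"))  # qwen2.5:1.5b
```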
diff --git a/docker-compose.yml b/docker-compose.yml
index 6851e19..fbc177d 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -1,4 +1,19 @@
services:
+ bifrost:
+ image: maximhq/bifrost
+ container_name: bifrost
+ ports:
+ - "8080:8080"
+ volumes:
+ - ./bifrost-config.json:/app/data/config.json:ro
+ environment:
+ - APP_DIR=/app/data
+ - APP_PORT=8080
+ - LOG_LEVEL=info
+ extra_hosts:
+ - "host.docker.internal:host-gateway"
+ restart: unless-stopped
+
deepagents:
build: .
container_name: deepagents
@@ -6,8 +21,11 @@ services:
- "8000:8000"
environment:
- PYTHONUNBUFFERED=1
+ # Bifrost gateway — all LLM inference goes through here
+ - BIFROST_URL=http://bifrost:8080/v1
+ # Direct Ollama GPU URL — used only by VRAMManager for flush/prewarm
- OLLAMA_BASE_URL=http://host.docker.internal:11436
- - DEEPAGENTS_MODEL=qwen3:4b
+ - DEEPAGENTS_MODEL=qwen2.5:1.5b
- DEEPAGENTS_COMPLEX_MODEL=qwen3:8b
- DEEPAGENTS_ROUTER_MODEL=qwen2.5:1.5b
- SEARXNG_URL=http://host.docker.internal:11437
@@ -19,6 +37,7 @@ services:
- openmemory
- grammy
- crawl4ai
+ - bifrost
restart: unless-stopped
openmemory:
@@ -27,8 +46,9 @@ services:
ports:
- "8765:8765"
environment:
- # Extraction LLM (qwen2.5:1.5b) runs on GPU after reply — fast 2-5s extraction
+ # Extraction LLM runs on GPU — qwen2.5:1.5b for speed (~3s)
- OLLAMA_GPU_URL=http://host.docker.internal:11436
+ - OLLAMA_EXTRACTION_MODEL=qwen2.5:1.5b
# Embedding (nomic-embed-text) runs on CPU — fast enough for search (50-150ms)
- OLLAMA_CPU_URL=http://host.docker.internal:11435
extra_hosts:
diff --git a/openmemory/server.py b/openmemory/server.py
index 73fd93c..56f8575 100644
--- a/openmemory/server.py
+++ b/openmemory/server.py
@@ -6,6 +6,7 @@ from mem0 import Memory
# Extraction LLM — GPU Ollama (qwen3:4b, same model as medium agent)
# Runs after reply when GPU is idle; spin-wait in agent.py prevents contention
OLLAMA_GPU_URL = os.getenv("OLLAMA_GPU_URL", "http://host.docker.internal:11436")
+EXTRACTION_MODEL = os.getenv("OLLAMA_EXTRACTION_MODEL", "qwen2.5:1.5b")
# Embedding — CPU Ollama (nomic-embed-text, 137 MB RAM)
# Used for both search (50-150ms, acceptable) and store-time embedding
@@ -94,7 +95,7 @@ config = {
"llm": {
"provider": "ollama",
"config": {
- "model": "qwen3:4b",
+ "model": EXTRACTION_MODEL,
"ollama_base_url": OLLAMA_GPU_URL,
"temperature": 0.1, # consistent JSON output
},
diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 0000000..2302a28
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,4 @@
+[pytest]
+testpaths = tests/unit
+pythonpath = .
+asyncio_mode = auto
diff --git a/test_pipeline.py b/tests/integration/test_pipeline.py
similarity index 88%
rename from test_pipeline.py
rename to tests/integration/test_pipeline.py
index 034720d..ce13775 100644
--- a/test_pipeline.py
+++ b/tests/integration/test_pipeline.py
@@ -42,6 +42,7 @@ import urllib.request
# ── config ────────────────────────────────────────────────────────────────────
DEEPAGENTS = "http://localhost:8000"
+BIFROST = "http://localhost:8080"
OPENMEMORY = "http://localhost:8765"
GRAMMY_HOST = "localhost"
GRAMMY_PORT = 3001
@@ -49,7 +50,7 @@ OLLAMA_GPU = "http://localhost:11436"
OLLAMA_CPU = "http://localhost:11435"
QDRANT = "http://localhost:6333"
SEARXNG = "http://localhost:11437"
-COMPOSE_FILE = "/home/alvis/agap_git/adolf/docker-compose.yml"
+COMPOSE_FILE = "/home/alvis/adolf/docker-compose.yml"
DEFAULT_CHAT_ID = "346967270"
NAMES = [
@@ -166,6 +167,19 @@ def fetch_logs(since_s=600):
return []
+def fetch_bifrost_logs(since_s=120):
+ """Return bifrost container log lines from the last since_s seconds."""
+ try:
+ r = subprocess.run(
+ ["docker", "compose", "-f", COMPOSE_FILE, "logs", "bifrost",
+ f"--since={int(since_s)}s", "--no-log-prefix"],
+ capture_output=True, text=True, timeout=10,
+ )
+ return r.stdout.splitlines()
+ except Exception:
+ return []
+
+
def parse_run_block(lines, msg_prefix):
"""
Scan log lines for the LAST '[agent] running: ' block.
@@ -303,10 +317,12 @@ _run_hard = not args.no_bench and not args.easy_only and not args.medium_onl
_run_memory = not args.no_bench and not args.easy_only and not args.medium_only and not args.hard_only
random_name = random.choice(NAMES)
+# Use a unique chat_id per run to avoid cross-run history contamination
+TEST_CHAT_ID = f"{CHAT_ID}-{random_name.lower()}"
if not _skip_pipeline:
print(f"\n Test name : \033[1m{random_name}\033[0m")
- print(f" Chat ID : {CHAT_ID}")
+ print(f" Chat ID : {TEST_CHAT_ID}")
# ── 1. service health ─────────────────────────────────────────────────────────
@@ -331,6 +347,93 @@ if not _skip_pipeline:
timings["health_check"] = time.monotonic() - t0
+# ── 1b. Bifrost gateway ───────────────────────────────────────────────────────
+if not _skip_pipeline:
+ print(f"\n[{INFO}] 1b. Bifrost gateway (port 8080)")
+ t0 = time.monotonic()
+
+ # Health ──────────────────────────────────────────────────────────────────
+ try:
+ status, body = get(f"{BIFROST}/health", timeout=5)
+ ok = status == 200
+ report("Bifrost /health reachable", ok, f"HTTP {status}")
+ except Exception as e:
+ report("Bifrost /health reachable", False, str(e))
+
+ # Ollama GPU models listed ────────────────────────────────────────────────
+ try:
+ status, body = get(f"{BIFROST}/v1/models", timeout=5)
+ data = json.loads(body)
+ model_ids = [m.get("id", "") for m in data.get("data", [])]
+ gpu_models = [m for m in model_ids if m.startswith("ollama/")]
+ report("Bifrost lists ollama GPU models", len(gpu_models) > 0,
+ f"found: {gpu_models}")
+ for expected in ["ollama/qwen3:4b", "ollama/qwen3:8b", "ollama/qwen2.5:1.5b"]:
+ report(f" model {expected} listed", expected in model_ids)
+ except Exception as e:
+ report("Bifrost /v1/models", False, str(e))
+
+ # Direct inference through Bifrost → GPU Ollama ───────────────────────────
+ # Uses the smallest GPU model (qwen2.5:0.5b) to keep latency low.
+ print(f" [bifrost-infer] direct POST /v1/chat/completions → ollama/qwen2.5:0.5b ...")
+ t_infer = time.monotonic()
+ try:
+ infer_payload = {
+ "model": "ollama/qwen2.5:0.5b",
+ "messages": [{"role": "user", "content": "Reply with exactly one word: pong"}],
+ "max_tokens": 16,
+ }
+ infer_data = json.dumps(infer_payload).encode()
+ req = urllib.request.Request(
+ f"{BIFROST}/v1/chat/completions",
+ data=infer_data,
+ headers={"Content-Type": "application/json"},
+ method="POST",
+ )
+ with urllib.request.urlopen(req, timeout=60) as r:
+ infer_status = r.status
+ infer_body = json.loads(r.read().decode())
+ infer_elapsed = time.monotonic() - t_infer
+ reply_content = infer_body.get("choices", [{}])[0].get("message", {}).get("content", "")
+ used_model = infer_body.get("model", "")
+ report("Bifrost → Ollama GPU inference succeeds",
+ infer_status == 200 and bool(reply_content),
+ f"{infer_elapsed:.1f}s model={used_model!r} reply={reply_content[:60]!r}")
+ timings["bifrost_direct_infer"] = infer_elapsed
+ except Exception as e:
+ report("Bifrost → Ollama GPU inference succeeds", False, str(e))
+ timings["bifrost_direct_infer"] = None
+
+ # deepagents is configured to route through Bifrost ───────────────────────
+ # The startup log line "[agent] bifrost=http://bifrost:8080/v1 | ..." is emitted
+ # during lifespan setup and confirms deepagents is using Bifrost as the LLM gateway.
+ try:
+ r = subprocess.run(
+ ["docker", "compose", "-f", COMPOSE_FILE, "logs", "deepagents",
+ "--since=3600s", "--no-log-prefix"],
+ capture_output=True, text=True, timeout=10,
+ )
+ log_lines = r.stdout.splitlines()
+ bifrost_line = next(
+ (l for l in log_lines if "[agent] bifrost=" in l and "bifrost:8080" in l),
+ None,
+ )
+ report(
+ "deepagents startup log confirms bifrost URL",
+ bifrost_line is not None,
+ bifrost_line.strip() if bifrost_line else "line not found in logs",
+ )
+ # Also confirm model names use provider/model format (ollama/...)
+ if bifrost_line:
+ has_prefix = "router=ollama/" in bifrost_line and "medium=ollama/" in bifrost_line
+ report("deepagents model names use ollama/ prefix", has_prefix,
+ bifrost_line.strip())
+ except Exception as e:
+ report("deepagents startup log check", False, str(e))
+
+ timings["bifrost_check"] = time.monotonic() - t0
+
+
# ── 2. GPU Ollama ─────────────────────────────────────────────────────────────
if not _skip_pipeline:
print(f"\n[{INFO}] 2. GPU Ollama (port 11436)")
@@ -415,11 +518,18 @@ if not _skip_pipeline:
# ── 6–8. Name memory pipeline ─────────────────────────────────────────────────
if not _skip_pipeline:
print(f"\n[{INFO}] 6–8. Name memory pipeline")
- print(f" chat_id={CHAT_ID} name={random_name}")
+ print(f" chat_id={TEST_CHAT_ID} name={random_name}")
store_msg = f"remember that your name is {random_name}"
recall_msg = "what is your name?"
+ # Clear adolf_memories so each run starts clean (avoids cross-run stale memories)
+ try:
+ post_json(f"{QDRANT}/collections/adolf_memories/points/delete",
+ {"filter": {}}, timeout=5)
+ except Exception:
+ pass
+
pts_before = qdrant_count()
print(f" Qdrant points before: {pts_before}")
@@ -429,7 +539,7 @@ if not _skip_pipeline:
try:
status, _ = post_json(f"{DEEPAGENTS}/chat",
- {"message": store_msg, "chat_id": CHAT_ID}, timeout=5)
+ {"message": store_msg, "chat_id": TEST_CHAT_ID}, timeout=5)
t_accept = time.monotonic() - t_store
report("POST /chat (store) returns 202 immediately",
status == 202 and t_accept < 1, f"status={status}, t={t_accept:.3f}s")
@@ -472,7 +582,7 @@ if not _skip_pipeline:
try:
status, _ = post_json(f"{DEEPAGENTS}/chat",
- {"message": recall_msg, "chat_id": CHAT_ID}, timeout=5)
+ {"message": recall_msg, "chat_id": TEST_CHAT_ID}, timeout=5)
t_accept2 = time.monotonic() - t_recall
report("POST /chat (recall) returns 202 immediately",
status == 202 and t_accept2 < 1, f"status={status}, t={t_accept2:.3f}s")
@@ -496,6 +606,19 @@ if not _skip_pipeline:
report("Agent replied to recall message", False, "timeout")
report(f"Reply contains '{random_name}'", False, "no reply")
+ # ── 8b. Verify requests passed through Bifrost ────────────────────────────
+ # After the store+recall round-trip, Bifrost logs must show forwarded
+ # requests. An empty Bifrost log means deepagents bypasses the gateway.
+ bifrost_lines = fetch_bifrost_logs(since_s=300)
+ report("Bifrost container has log output (requests forwarded)",
+ len(bifrost_lines) > 0,
+ f"{len(bifrost_lines)} lines in bifrost logs")
+ # Bifrost logs contain the request body; AsyncOpenAI user-agent confirms the path
+ bifrost_raw = "\n".join(bifrost_lines)
+ report(" Bifrost log shows AsyncOpenAI agent requests",
+ "AsyncOpenAI" in bifrost_raw,
+ f"{'found' if 'AsyncOpenAI' in bifrost_raw else 'NOT found'} in bifrost logs")
+
# ── 9. Timing profile ─────────────────────────────────────────────────────────
if not _skip_pipeline:
diff --git a/tests/requirements.txt b/tests/requirements.txt
new file mode 100644
index 0000000..94022af
--- /dev/null
+++ b/tests/requirements.txt
@@ -0,0 +1,2 @@
+pytest>=8.0
+pytest-asyncio>=0.23
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
new file mode 100644
index 0000000..117e3f7
--- /dev/null
+++ b/tests/unit/conftest.py
@@ -0,0 +1,80 @@
+"""
+Stub out all third-party packages that Adolf's source modules import.
+This lets the unit tests run without a virtualenv or Docker environment.
+Stubs are installed into sys.modules before any test file is collected.
+"""
+
+import sys
+from unittest.mock import MagicMock
+
+
+# ── helpers ────────────────────────────────────────────────────────────────────
+
+def _mock(name: str) -> MagicMock:
+ m = MagicMock(name=name)
+ sys.modules[name] = m
+ return m
+
+
+# ── pydantic: BaseModel must be a real class so `class Foo(BaseModel)` works ──
+
+class _FakeBaseModel:
+ model_fields: dict = {}
+
+ def __init_subclass__(cls, **kwargs):
+ pass
+
+ def __init__(self, **data):
+ for k, v in data.items():
+ setattr(self, k, v)
+
+
+_pydantic = _mock("pydantic")
+_pydantic.BaseModel = _FakeBaseModel
+
+# ── httpx: used by channels.py, vram_manager.py, agent.py ────────────────────
+
+_mock("httpx")
+
+# ── fastapi ───────────────────────────────────────────────────────────────────
+
+_fastapi = _mock("fastapi")
+_mock("fastapi.responses")
+
+# ── langchain stack ───────────────────────────────────────────────────────────
+
+_mock("langchain_openai")
+
+_lc_core = _mock("langchain_core")
+_lc_msgs = _mock("langchain_core.messages")
+_mock("langchain_core.tools")
+
+# Provide real-ish message classes so router.py can instantiate them
+class _FakeMsg:
+ def __init__(self, content=""):
+ self.content = content
+
+class SystemMessage(_FakeMsg):
+ pass
+
+class HumanMessage(_FakeMsg):
+ pass
+
+class AIMessage(_FakeMsg):
+ def __init__(self, content="", tool_calls=None):
+ super().__init__(content)
+ self.tool_calls = tool_calls or []
+
+_lc_msgs.SystemMessage = SystemMessage
+_lc_msgs.HumanMessage = HumanMessage
+_lc_msgs.AIMessage = AIMessage
+
+_mock("langchain_mcp_adapters")
+_mock("langchain_mcp_adapters.client")
+_mock("langchain_community")
+_mock("langchain_community.utilities")
+
+# ── deepagents (agent_factory.py) ─────────────────────────────────────────────
+
+_mock("deepagents")
+
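The sys.modules pre-seeding trick that conftest.py relies on can be sketched in isolation. `demo_pkg` is a hypothetical package name used purely for illustration:

```python
import sys
from unittest.mock import MagicMock

# Install the stub *before* the import: Python consults sys.modules first,
# so the real package never needs to be installed.
sys.modules["demo_pkg"] = MagicMock(name="demo_pkg")

import demo_pkg  # resolves to the MagicMock stub, not a real module

# Attribute access on a MagicMock fabricates children on demand, which is
# why `from demo_pkg import Anything` also works against the stub.
client = demo_pkg.Client(api_key="fake")
print(type(demo_pkg).__name__)  # → MagicMock
```

This is why the stubs must be installed at conftest import time, before pytest collects any test module that does `import agent` or `import router`.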
diff --git a/tests/unit/test_agent_helpers.py b/tests/unit/test_agent_helpers.py
new file mode 100644
index 0000000..9df77d1
--- /dev/null
+++ b/tests/unit/test_agent_helpers.py
@@ -0,0 +1,161 @@
+"""
+Unit tests for agent.py helper functions:
+ - _strip_think(text)
+ - _extract_final_text(result)
+
+agent.py has heavy FastAPI/LangChain imports; conftest.py stubs them out so
+these pure functions can be imported and tested in isolation.
+"""
+
+import pytest
+
+# conftest.py has already installed all stubs into sys.modules.
+# The FastAPI app is instantiated at module level in agent.py —
+# with the mocked fastapi, that just creates a MagicMock() object
+# and the route decorators are no-ops.
+from agent import _strip_think, _extract_final_text
+
+
+# ── _strip_think ───────────────────────────────────────────────────────────────
+
+class TestStripThink:
+ def test_removes_single_think_block(self):
+ text = "<think>internal reasoning</think>Final answer."
+ assert _strip_think(text) == "Final answer."
+
+ def test_removes_multiline_think_block(self):
+ text = "<think>\nLine one.\nLine two.\n</think>\nResult here."
+ assert _strip_think(text) == "Result here."
+
+ def test_no_think_block_unchanged(self):
+ text = "This is a plain answer with no think block."
+ assert _strip_think(text) == text
+
+ def test_removes_multiple_think_blocks(self):
+ text = "<think>step 1</think>middle<think>step 2</think>end"
+ assert _strip_think(text) == "middleend"
+
+ def test_strips_surrounding_whitespace(self):
+ text = "  <think>stuff</think> answer  "
+ assert _strip_think(text) == "answer"
+
+ def test_empty_think_block(self):
+ text = "<think></think>Hello."
+ assert _strip_think(text) == "Hello."
+
+ def test_empty_string(self):
+ assert _strip_think("") == ""
+
+ def test_only_think_block_returns_empty(self):
+ text = "<think>nothing useful</think>"
+ assert _strip_think(text) == ""
+
+ def test_think_block_with_nested_tags(self):
+ text = "<think>I should use <b>bold</b> here</think>Done."
+ assert _strip_think(text) == "Done."
+
+ def test_preserves_markdown(self):
+ text = "<think>plan</think>## Report\n\n- Point one\n- Point two"
+ result = _strip_think(text)
+ assert result == "## Report\n\n- Point one\n- Point two"
+
+
+# ── _extract_final_text ────────────────────────────────────────────────────────
+
+class TestExtractFinalText:
+ def _ai_msg(self, content: str, tool_calls=None):
+ """Create a minimal AIMessage-like object."""
+ class AIMessage:
+ pass
+ m = AIMessage()
+ m.content = content
+ m.tool_calls = tool_calls or []
+ return m
+
+ def _human_msg(self, content: str):
+ class HumanMessage:
+ pass
+ m = HumanMessage()
+ m.content = content
+ return m
+
+ def test_returns_last_ai_message_content(self):
+ result = {
+ "messages": [
+ self._human_msg("what is 2+2"),
+ self._ai_msg("The answer is 4."),
+ ]
+ }
+ assert _extract_final_text(result) == "The answer is 4."
+
+ def test_returns_last_of_multiple_ai_messages(self):
+ result = {
+ "messages": [
+ self._ai_msg("First response."),
+ self._human_msg("follow-up"),
+ self._ai_msg("Final response."),
+ ]
+ }
+ assert _extract_final_text(result) == "Final response."
+
+ def test_skips_empty_ai_messages(self):
+ result = {
+ "messages": [
+ self._ai_msg("Real answer."),
+ self._ai_msg(""), # empty — should be skipped
+ ]
+ }
+ assert _extract_final_text(result) == "Real answer."
+
+ def test_strips_think_tags_from_ai_message(self):
+ result = {
+ "messages": [
+ self._ai_msg("<think>reasoning here</think>Clean reply."),
+ ]
+ }
+ assert _extract_final_text(result) == "Clean reply."
+
+ def test_falls_back_to_output_field(self):
+ result = {
+ "messages": [],
+ "output": "Fallback output.",
+ }
+ assert _extract_final_text(result) == "Fallback output."
+
+ def test_strips_think_from_output_field(self):
+ result = {
+ "messages": [],
+ "output": "<think>thoughts</think>Actual output.",
+ }
+ assert _extract_final_text(result) == "Actual output."
+
+ def test_returns_none_when_no_content(self):
+ result = {"messages": []}
+ assert _extract_final_text(result) is None
+
+ def test_returns_none_when_no_messages_and_no_output(self):
+ result = {"messages": [], "output": ""}
+ # output is falsy → returns None
+ assert _extract_final_text(result) is None
+
+ def test_skips_non_ai_messages(self):
+ result = {
+ "messages": [
+ self._human_msg("user question"),
+ ]
+ }
+ assert _extract_final_text(result) is None
+
+ def test_handles_ai_message_with_tool_calls_but_no_content(self):
+ """AIMessage that only has tool_calls (no content) should be skipped."""
+ msg = self._ai_msg("", tool_calls=[{"name": "web_search", "args": {}}])
+ result = {"messages": [msg]}
+ assert _extract_final_text(result) is None
+
+ def test_multiline_think_stripped_correctly(self):
+ result = {
+ "messages": [
+ self._ai_msg("<think>\nLong\nreasoning\nblock\n</think>\n## Report\n\nSome content."),
+ ]
+ }
+ assert _extract_final_text(result) == "## Report\n\nSome content."
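These tests pin down `_strip_think`'s contract. A minimal regex-based sketch consistent with them (the real implementation in agent.py may differ) is:

```python
import re

# Non-greedy + DOTALL so each <think>...</think> block is removed on its
# own, including blocks that span multiple lines.
_THINK_RE = re.compile(r"<think>.*?</think>", re.DOTALL)

def strip_think(text: str) -> str:
    # Drop every think block, then trim surrounding whitespace
    return _THINK_RE.sub("", text).strip()

print(strip_think("<think>plan</think>## Report"))  # → ## Report
```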
diff --git a/tests/unit/test_channels.py b/tests/unit/test_channels.py
new file mode 100644
index 0000000..b0418bc
--- /dev/null
+++ b/tests/unit/test_channels.py
@@ -0,0 +1,125 @@
+"""Unit tests for channels.py — register, deliver, pending_replies queue."""
+
+import asyncio
+import pytest
+from unittest.mock import AsyncMock, patch
+
+import channels
+
+
+@pytest.fixture(autouse=True)
+def reset_channels_state():
+ """Clear module-level state before and after every test."""
+ channels._callbacks.clear()
+ channels.pending_replies.clear()
+ yield
+ channels._callbacks.clear()
+ channels.pending_replies.clear()
+
+
+# ── register ───────────────────────────────────────────────────────────────────
+
+class TestRegister:
+ def test_register_stores_callback(self):
+ cb = AsyncMock()
+ channels.register("test_channel", cb)
+ assert channels._callbacks["test_channel"] is cb
+
+ def test_register_overwrites_existing(self):
+ cb1 = AsyncMock()
+ cb2 = AsyncMock()
+ channels.register("ch", cb1)
+ channels.register("ch", cb2)
+ assert channels._callbacks["ch"] is cb2
+
+ def test_register_multiple_channels(self):
+ cb_a = AsyncMock()
+ cb_b = AsyncMock()
+ channels.register("a", cb_a)
+ channels.register("b", cb_b)
+ assert channels._callbacks["a"] is cb_a
+ assert channels._callbacks["b"] is cb_b
+
+
+# ── deliver ────────────────────────────────────────────────────────────────────
+
+class TestDeliver:
+ async def test_deliver_enqueues_reply(self):
+ channels.register("cli", AsyncMock())
+ await channels.deliver("cli-alvis", "cli", "hello world")
+ q = channels.pending_replies["cli-alvis"]
+ assert not q.empty()
+ assert await q.get() == "hello world"
+
+ async def test_deliver_calls_channel_callback(self):
+ cb = AsyncMock()
+ channels.register("telegram", cb)
+ await channels.deliver("tg-123", "telegram", "reply text")
+ cb.assert_awaited_once_with("tg-123", "reply text")
+
+ async def test_deliver_unknown_channel_still_enqueues(self):
+ """No registered callback for channel → reply still goes to the queue."""
+ await channels.deliver("cli-bob", "nonexistent", "fallback reply")
+ q = channels.pending_replies["cli-bob"]
+ assert await q.get() == "fallback reply"
+
+ async def test_deliver_unknown_channel_does_not_raise(self):
+ """Missing callback must not raise an exception."""
+ await channels.deliver("cli-x", "ghost_channel", "msg")
+
+ async def test_deliver_creates_queue_if_absent(self):
+ channels.register("cli", AsyncMock())
+ assert "cli-new" not in channels.pending_replies
+ await channels.deliver("cli-new", "cli", "hi")
+ assert "cli-new" in channels.pending_replies
+
+ async def test_deliver_reuses_existing_queue(self):
+ """Second deliver to the same session appends to the same queue."""
+ channels.register("cli", AsyncMock())
+ await channels.deliver("cli-alvis", "cli", "first")
+ await channels.deliver("cli-alvis", "cli", "second")
+ q = channels.pending_replies["cli-alvis"]
+ assert await q.get() == "first"
+ assert await q.get() == "second"
+
+ async def test_deliver_telegram_sends_to_callback(self):
+ sent = []
+
+ async def fake_tg(session_id, text):
+ sent.append((session_id, text))
+
+ channels.register("telegram", fake_tg)
+ await channels.deliver("tg-999", "telegram", "test message")
+ assert sent == [("tg-999", "test message")]
+
+
+# ── register_defaults ──────────────────────────────────────────────────────────
+
+class TestRegisterDefaults:
+ def test_registers_telegram_and_cli(self):
+ channels.register_defaults()
+ assert "telegram" in channels._callbacks
+ assert "cli" in channels._callbacks
+
+ async def test_cli_callback_is_noop(self):
+ """CLI send callback does nothing (replies are handled via SSE queue)."""
+ channels.register_defaults()
+ cb = channels._callbacks["cli"]
+ # Should not raise and should return None
+ result = await cb("cli-alvis", "some reply")
+ assert result is None
+
+ async def test_telegram_callback_chunks_long_messages(self):
+ """Telegram callback splits messages > 4000 chars into chunks."""
+ channels.register_defaults()
+ cb = channels._callbacks["telegram"]
+ long_text = "x" * 9000 # > 4000 chars → should produce 3 chunks
+ with patch("channels.httpx.AsyncClient") as mock_client_cls:
+ mock_client = AsyncMock()
+ mock_client.__aenter__ = AsyncMock(return_value=mock_client)
+ mock_client.__aexit__ = AsyncMock(return_value=False)
+ mock_client.post = AsyncMock()
+ mock_client_cls.return_value = mock_client
+ await cb("tg-123", long_text)
+ # ceil(9000 chars / 4000 per chunk) = 3 POST calls
+ assert mock_client.post.await_count == 3
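The behaviour these tests assert — always enqueue for SSE pollers, then invoke the channel callback if one is registered — can be sketched as follows. This is an illustrative reconstruction, not the actual channels.py source:

```python
import asyncio
from typing import Awaitable, Callable

_callbacks: dict[str, Callable[[str, str], Awaitable[None]]] = {}
pending_replies: dict[str, asyncio.Queue] = {}

def register(channel: str, cb) -> None:
    _callbacks[channel] = cb  # last registration for a channel wins

async def deliver(session_id: str, channel: str, text: str) -> None:
    # Enqueue first, so SSE pollers get the reply even with no callback.
    queue = pending_replies.setdefault(session_id, asyncio.Queue())
    await queue.put(text)
    cb = _callbacks.get(channel)
    if cb is not None:
        await cb(session_id, text)

async def demo():
    sent = []
    async def fake_tg(sid, txt):
        sent.append((sid, txt))
    register("telegram", fake_tg)
    await deliver("tg-1", "telegram", "hi")
    await deliver("cli-1", "ghost", "fallback")  # unknown channel: queue only
    return sent, pending_replies["cli-1"].get_nowait()

result = asyncio.run(demo())
print(result)  # → ([('tg-1', 'hi')], 'fallback')
```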
diff --git a/tests/unit/test_router.py b/tests/unit/test_router.py
new file mode 100644
index 0000000..5d382b4
--- /dev/null
+++ b/tests/unit/test_router.py
@@ -0,0 +1,200 @@
+"""Unit tests for router.py — Router, _parse_tier, _format_history, _LIGHT_PATTERNS."""
+
+import pytest
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from router import Router, _parse_tier, _format_history, _LIGHT_PATTERNS
+
+
+# ── _LIGHT_PATTERNS regex ──────────────────────────────────────────────────────
+
+class TestLightPatterns:
+ @pytest.mark.parametrize("text", [
+ "hi", "Hi", "HI",
+ "hello", "hey", "yo", "sup",
+ "good morning", "good evening", "good night", "good afternoon",
+ "bye", "goodbye", "see you", "cya", "later", "ttyl",
+ "thanks", "thank you", "thx", "ty",
+ "ok", "okay", "k", "cool", "great", "awesome", "perfect",
+ "sounds good", "got it", "nice", "sure",
+ "how are you", "how are you?", "how are you doing today?",
+ "what's up",
+ "what day comes after Monday?",
+ "what day follows Friday?",
+ "what comes after summer?",
+ "what does NASA stand for?",
+ "what does AI stand for?",
+ # with trailing punctuation
+ "hi!", "hello.", "thanks!",
+ ])
+ def test_matches(self, text):
+ assert _LIGHT_PATTERNS.match(text.strip()), f"Expected light match for: {text!r}"
+
+ @pytest.mark.parametrize("text", [
+ "what is the capital of France",
+ "tell me about bitcoin",
+ "what is 2+2",
+ "write me a poem",
+ "search for news about the election",
+ "what did we talk about last time",
+ "what is my name",
+ "/think compare these frameworks",
+ "how do I install Python",
+ "explain machine learning",
+ "", # empty string doesn't match the pattern
+ ])
+ def test_no_match(self, text):
+ assert not _LIGHT_PATTERNS.match(text.strip()), f"Expected NO light match for: {text!r}"
+
+
+# ── _parse_tier ────────────────────────────────────────────────────────────────
+
+class TestParseTier:
+ @pytest.mark.parametrize("raw,expected", [
+ ("light", "light"),
+ ("Light", "light"),
+ ("LIGHT\n", "light"),
+ ("medium", "medium"),
+ ("Medium.", "medium"),
+ ("complex", "complex"),
+ ("Complex!", "complex"),
+ # descriptive words → light
+ ("simple fact", "light"),
+ ("trivial question", "light"),
+ ("basic", "light"),
+ ("easy answer", "light"),
+ ("general knowledge", "light"),
+ # unknown → medium
+ ("unknown_category", "medium"),
+ ("", "medium"),
+ ("I don't know", "medium"),
+ # complex only if 'complex' appears in first 60 chars
+ ("this is a complex query requiring search", "complex"),
+ # _parse_tier checks "complex" before "medium", so complex wins even if medium appears first
+ ("medium complexity, not complex", "complex"),
+ ])
+ def test_parse_tier(self, raw, expected):
+ assert _parse_tier(raw) == expected
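A keyword-matching sketch consistent with all of the parametrized cases above (the real `_parse_tier` may differ in detail) looks like:

```python
def parse_tier(raw: str) -> str:
    text = raw.strip().lower()
    if "complex" in text[:60]:   # checked first, so "complex" beats "medium"
        return "complex"
    if "medium" in text:
        return "medium"
    # Descriptive synonyms the classifier model sometimes emits for light
    if any(w in text for w in ("light", "simple", "trivial", "basic", "easy", "general")):
        return "light"
    return "medium"              # unknown or empty output → safe default

print(parse_tier("medium complexity, not complex"))  # → complex
```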
+
+
+# ── _format_history ────────────────────────────────────────────────────────────
+
+class TestFormatHistory:
+ def test_empty(self):
+ assert _format_history([]) == "(none)"
+
+ def test_single_user_message(self):
+ history = [{"role": "user", "content": "hello there"}]
+ result = _format_history(history)
+ assert "user: hello there" in result
+
+ def test_multiple_turns(self):
+ history = [
+ {"role": "user", "content": "What is Python?"},
+ {"role": "assistant", "content": "Python is a programming language."},
+ ]
+ result = _format_history(history)
+ assert "user: What is Python?" in result
+ assert "assistant: Python is a programming language." in result
+
+ def test_truncates_long_content(self):
+ long_content = "x" * 300
+ history = [{"role": "user", "content": long_content}]
+ result = _format_history(history)
+ # content is truncated to 200 chars in _format_history
+ assert len(result) < 250
+
+ def test_missing_keys_handled(self):
+ # Should not raise — uses .get() with defaults
+ history = [{"role": "user"}] # no content key
+ result = _format_history(history)
+ assert "user:" in result
+
+
+# ── Router.route() ─────────────────────────────────────────────────────────────
+
+class TestRouterRoute:
+ def _make_router(self, classify_response: str, reply_response: str = "Sure!") -> Router:
+ """Return a Router with a mock model that returns given classification and reply."""
+ model = MagicMock()
+ classify_msg = MagicMock()
+ classify_msg.content = classify_response
+ reply_msg = MagicMock()
+ reply_msg.content = reply_response
+ # First ainvoke call → classification; second → reply
+ model.ainvoke = AsyncMock(side_effect=[classify_msg, reply_msg])
+ return Router(model=model)
+
+ async def test_force_complex_bypasses_classification(self):
+ router = self._make_router("medium")
+ tier, reply = await router.route("some question", [], force_complex=True)
+ assert tier == "complex"
+ assert reply is None
+ # Model should NOT have been called
+ router.model.ainvoke.assert_not_called()
+
+ async def test_regex_light_skips_llm_classification(self):
+ # Regex match bypasses classification entirely; the only ainvoke call is the reply.
+ model = MagicMock()
+ reply_msg = MagicMock()
+ reply_msg.content = "I'm doing great!"
+ model.ainvoke = AsyncMock(return_value=reply_msg)
+ router = Router(model=model)
+ tier, reply = await router.route("how are you", [], force_complex=False)
+ assert tier == "light"
+ assert reply == "I'm doing great!"
+ # Exactly one model call — no classification step
+ assert router.model.ainvoke.call_count == 1
+
+ async def test_llm_classifies_medium(self):
+ router = self._make_router("medium")
+ tier, reply = await router.route("what is the bitcoin price?", [], force_complex=False)
+ assert tier == "medium"
+ assert reply is None
+
+ async def test_llm_classifies_light_generates_reply(self):
+ router = self._make_router("light", "Paris is the capital of France.")
+ tier, reply = await router.route("what is the capital of France?", [], force_complex=False)
+ assert tier == "light"
+ assert reply == "Paris is the capital of France."
+
+ async def test_llm_classifies_complex_downgraded_to_medium(self):
+ # Without /think prefix, complex classification → downgraded to medium
+ router = self._make_router("complex")
+ tier, reply = await router.route("compare React and Vue", [], force_complex=False)
+ assert tier == "medium"
+ assert reply is None
+
+ async def test_llm_error_falls_back_to_medium(self):
+ model = MagicMock()
+ model.ainvoke = AsyncMock(side_effect=Exception("connection error"))
+ router = Router(model=model)
+ tier, reply = await router.route("some question", [], force_complex=False)
+ assert tier == "medium"
+ assert reply is None
+
+ async def test_light_reply_empty_falls_back_to_medium(self):
+ """If the light reply comes back empty, router returns medium instead."""
+ router = self._make_router("light", "") # empty reply
+ tier, reply = await router.route("what is 2+2", [], force_complex=False)
+ assert tier == "medium"
+ assert reply is None
+
+ async def test_strips_think_tags_from_classification(self):
+ """Router strips <think>...</think> from model output before parsing tier."""
+ model = MagicMock()
+ classify_msg = MagicMock()
+ classify_msg.content = "<think>Hmm let me think...</think>medium"
+ reply_msg = MagicMock()
+ reply_msg.content = "I'm fine!"
+ model.ainvoke = AsyncMock(side_effect=[classify_msg, reply_msg])
+ router = Router(model=model)
+ tier, _ = await router.route("what is the news?", [], force_complex=False)
+ assert tier == "medium"
+
+ async def test_think_prefix_forces_complex(self):
+ """/think prefix is already stripped by agent.py; force_complex=True is passed."""
+ router = self._make_router("medium")
+ tier, reply = await router.route("analyse this", [], force_complex=True)
+ assert tier == "complex"
+ assert reply is None
diff --git a/tests/unit/test_vram_manager.py b/tests/unit/test_vram_manager.py
new file mode 100644
index 0000000..ae6df35
--- /dev/null
+++ b/tests/unit/test_vram_manager.py
@@ -0,0 +1,164 @@
+"""Unit tests for vram_manager.py — VRAMManager flush/poll/prewarm logic."""
+
+import asyncio
+import pytest
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from vram_manager import VRAMManager
+
+
+BASE_URL = "http://localhost:11434"
+
+
+def _make_manager() -> VRAMManager:
+ return VRAMManager(base_url=BASE_URL)
+
+
+def _mock_client(get_response=None, post_response=None):
+ """Return a context-manager mock for httpx.AsyncClient."""
+ client = AsyncMock()
+ client.__aenter__ = AsyncMock(return_value=client)
+ client.__aexit__ = AsyncMock(return_value=False)
+ if get_response is not None:
+ client.get = AsyncMock(return_value=get_response)
+ if post_response is not None:
+ client.post = AsyncMock(return_value=post_response)
+ return client
+
+
+# ── _flush ─────────────────────────────────────────────────────────────────────
+
+class TestFlush:
+ async def test_sends_keep_alive_zero(self):
+ client = _mock_client(post_response=MagicMock())
+ with patch("vram_manager.httpx.AsyncClient", return_value=client):
+ mgr = _make_manager()
+ await mgr._flush("qwen3:4b")
+ client.post.assert_awaited_once()
+ _, kwargs = client.post.await_args
+ # payload arrives as the json= kwarg; fall back to the 2nd positional arg
+ body = kwargs.get("json") or client.post.call_args[0][1]
+ assert body["model"] == "qwen3:4b"
+ assert body["keep_alive"] == 0
+
+ async def test_posts_to_correct_endpoint(self):
+ client = _mock_client(post_response=MagicMock())
+ with patch("vram_manager.httpx.AsyncClient", return_value=client):
+ mgr = _make_manager()
+ await mgr._flush("qwen3:8b")
+ url = client.post.call_args[0][0]
+ assert url == f"{BASE_URL}/api/generate"
+
+ async def test_ignores_exceptions_silently(self):
+ client = AsyncMock()
+ client.__aenter__ = AsyncMock(return_value=client)
+ client.__aexit__ = AsyncMock(return_value=False)
+ client.post = AsyncMock(side_effect=Exception("connection refused"))
+ with patch("vram_manager.httpx.AsyncClient", return_value=client):
+ mgr = _make_manager()
+ # Should not raise
+ await mgr._flush("qwen3:4b")
+
+
+# ── _prewarm ───────────────────────────────────────────────────────────────────
+
+class TestPrewarm:
+ async def test_sends_keep_alive_300(self):
+ client = _mock_client(post_response=MagicMock())
+ with patch("vram_manager.httpx.AsyncClient", return_value=client):
+ mgr = _make_manager()
+ await mgr._prewarm("qwen3:4b")
+ _, kwargs = client.post.await_args
+ # payload arrives as the json= kwarg; fall back to the 2nd positional arg
+ body = kwargs.get("json") or client.post.call_args[0][1]
+ assert body["keep_alive"] == 300
+ assert body["model"] == "qwen3:4b"
+
+ async def test_ignores_exceptions_silently(self):
+ client = AsyncMock()
+ client.__aenter__ = AsyncMock(return_value=client)
+ client.__aexit__ = AsyncMock(return_value=False)
+ client.post = AsyncMock(side_effect=Exception("timeout"))
+ with patch("vram_manager.httpx.AsyncClient", return_value=client):
+ mgr = _make_manager()
+ await mgr._prewarm("qwen3:4b")
+
+
+# ── _poll_evicted ──────────────────────────────────────────────────────────────
+
+class TestPollEvicted:
+ async def test_returns_true_when_models_absent(self):
+ resp = MagicMock()
+ resp.json.return_value = {"models": [{"name": "some_other_model"}]}
+ client = _mock_client(get_response=resp)
+ with patch("vram_manager.httpx.AsyncClient", return_value=client):
+ mgr = _make_manager()
+ result = await mgr._poll_evicted(["qwen3:4b", "qwen2.5:1.5b"], timeout=5)
+ assert result is True
+
+ async def test_returns_false_on_timeout_when_model_still_loaded(self):
+ resp = MagicMock()
+ resp.json.return_value = {"models": [{"name": "qwen3:4b"}]}
+ client = _mock_client(get_response=resp)
+ with patch("vram_manager.httpx.AsyncClient", return_value=client):
+ mgr = _make_manager()
+ result = await mgr._poll_evicted(["qwen3:4b"], timeout=0.1)
+ assert result is False
+
+ async def test_returns_true_immediately_if_already_empty(self):
+ resp = MagicMock()
+ resp.json.return_value = {"models": []}
+ client = _mock_client(get_response=resp)
+ with patch("vram_manager.httpx.AsyncClient", return_value=client):
+ mgr = _make_manager()
+ result = await mgr._poll_evicted(["qwen3:4b"], timeout=5)
+ assert result is True
+
+ async def test_handles_poll_error_and_continues(self):
+ """If /api/ps errors, polling continues until timeout."""
+ client = AsyncMock()
+ client.__aenter__ = AsyncMock(return_value=client)
+ client.__aexit__ = AsyncMock(return_value=False)
+ client.get = AsyncMock(side_effect=Exception("network error"))
+ with patch("vram_manager.httpx.AsyncClient", return_value=client):
+ mgr = _make_manager()
+ result = await mgr._poll_evicted(["qwen3:4b"], timeout=0.2)
+ assert result is False
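The eviction-polling contract exercised here — retry on error, succeed as soon as the models disappear from `/api/ps`, give up at the deadline — can be sketched like this (illustrative only; `get_loaded` stands in for the HTTP call to Ollama):

```python
import asyncio
import time

async def poll_evicted(get_loaded, names, timeout=1.0, interval=0.05):
    """True once none of `names` are loaded; False if `timeout` expires first."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            loaded = await get_loaded()   # stand-in for GET /api/ps
        except Exception:
            loaded = None                 # poll error: keep retrying
        if loaded is not None and not any(n in loaded for n in names):
            return True
        await asyncio.sleep(interval)
    return False

async def ps_empty():
    return []

async def ps_stuck():
    return ["qwen3:4b"]

print(asyncio.run(poll_evicted(ps_empty, ["qwen3:4b"])))               # → True
print(asyncio.run(poll_evicted(ps_stuck, ["qwen3:4b"], timeout=0.2)))  # → False
```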
+
+
+# ── enter_complex_mode / exit_complex_mode ─────────────────────────────────────
+
+class TestComplexMode:
+ async def test_enter_complex_mode_returns_true_on_success(self):
+ mgr = _make_manager()
+ mgr._flush = AsyncMock()
+ mgr._poll_evicted = AsyncMock(return_value=True)
+ result = await mgr.enter_complex_mode()
+ assert result is True
+
+ async def test_enter_complex_mode_flushes_medium_models(self):
+ mgr = _make_manager()
+ mgr._flush = AsyncMock()
+ mgr._poll_evicted = AsyncMock(return_value=True)
+ await mgr.enter_complex_mode()
+ flushed = {call.args[0] for call in mgr._flush.call_args_list}
+ assert "qwen3:4b" in flushed
+ assert "qwen2.5:1.5b" in flushed
+
+ async def test_enter_complex_mode_returns_false_on_eviction_timeout(self):
+ mgr = _make_manager()
+ mgr._flush = AsyncMock()
+ mgr._poll_evicted = AsyncMock(return_value=False)
+ result = await mgr.enter_complex_mode()
+ assert result is False
+
+ async def test_exit_complex_mode_flushes_complex_and_prewarms_medium(self):
+ mgr = _make_manager()
+ mgr._flush = AsyncMock()
+ mgr._prewarm = AsyncMock()
+ await mgr.exit_complex_mode()
+ # Must flush 8b
+ flushed = {call.args[0] for call in mgr._flush.call_args_list}
+ assert "qwen3:8b" in flushed
+ # Must prewarm medium models
+ prewarmed = {call.args[0] for call in mgr._prewarm.call_args_list}
+ assert "qwen3:4b" in prewarmed
+ assert "qwen2.5:1.5b" in prewarmed