From f9618a9bbf607516cd686aa5ba9892ddbbff2a29 Mon Sep 17 00:00:00 2001 From: Alvis Date: Thu, 12 Mar 2026 13:50:12 +0000 Subject: [PATCH] Integrate Bifrost LLM gateway, add test suite, implement memory pipeline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add Bifrost (maximhq/bifrost) as LLM gateway: all inference routes through bifrost:8080/v1 with retry logic and observability; VRAMManager keeps direct Ollama access for VRAM flush/prewarm operations - Switch medium model from qwen3:4b to qwen2.5:1.5b (direct call, no tools) via _DirectModel wrapper; complex keeps create_deep_agent with qwen3:8b - Implement out-of-agent memory pipeline: _retrieve_memories pre-fetches relevant context (injected into all tiers), _store_memory runs as background task after each reply writing to openmemory/Qdrant - Add tests/unit/ with 133 tests covering router, channels, vram_manager, agent helpers; move integration test to tests/integration/ - Add bifrost-config.json with GPU Ollama (qwen2.5:0.5b/1.5b, qwen3:4b/8b, gemma3:4b) and CPU Ollama providers - Integration test 28/29 pass (only grammy fails — no TELEGRAM_BOT_TOKEN) Co-Authored-By: Claude Sonnet 4.6 --- .gitignore | 2 + CLAUDE.md | 133 ++++++++++++ Dockerfile | 2 +- agent.py | 120 +++++++++-- agent_factory.py | 20 +- bifrost-config.json | 58 +++++ docker-compose.yml | 24 ++- openmemory/server.py | 3 +- pytest.ini | 4 + .../integration/test_pipeline.py | 133 +++++++++++- tests/requirements.txt | 2 + tests/unit/conftest.py | 80 +++++++ tests/unit/test_agent_helpers.py | 161 ++++++++++++++ tests/unit/test_channels.py | 125 +++++++++++ tests/unit/test_router.py | 200 ++++++++++++++++++ tests/unit/test_vram_manager.py | 164 ++++++++++++++ 16 files changed, 1195 insertions(+), 36 deletions(-) create mode 100644 .gitignore create mode 100644 CLAUDE.md create mode 100644 bifrost-config.json create mode 100644 pytest.ini rename test_pipeline.py => tests/integration/test_pipeline.py (88%) 
create mode 100644 tests/requirements.txt create mode 100644 tests/unit/conftest.py create mode 100644 tests/unit/test_agent_helpers.py create mode 100644 tests/unit/test_channels.py create mode 100644 tests/unit/test_router.py create mode 100644 tests/unit/test_vram_manager.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..7a60b85 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +__pycache__/ +*.pyc diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..a85d22f --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,133 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Commands + +**Start all services:** +```bash +docker compose up --build +``` + +**Interactive CLI (requires gateway running):** +```bash +python3 cli.py [--url http://localhost:8000] [--session cli-alvis] [--timeout 400] +``` + +**Run integration tests:** +```bash +python3 tests/integration/test_pipeline.py [--chat-id CHAT_ID] + +# Selective sections: +python3 tests/integration/test_pipeline.py --bench-only # routing + memory benchmarks only (sections 10–13) +python3 tests/integration/test_pipeline.py --easy-only # light-tier routing benchmark +python3 tests/integration/test_pipeline.py --medium-only # medium-tier routing benchmark +python3 tests/integration/test_pipeline.py --hard-only # complex-tier + VRAM flush benchmark +python3 tests/integration/test_pipeline.py --memory-only # memory store/recall/dedup benchmark +python3 tests/integration/test_pipeline.py --no-bench # service health + single name store/recall only +``` + +## Architecture + +Adolf is a multi-channel personal assistant. All LLM inference is routed through **Bifrost**, an open-source Go-based LLM gateway that adds retry logic, failover, and observability in front of Ollama.
+ +### Request flow + +``` +Channel adapter → POST /message {text, session_id, channel, user_id} + → 202 Accepted (immediate) + → background: run_agent_task() + → router.route() → tier decision (light/medium/complex) + → invoke agent for tier via Bifrost + deepagents:8000 → bifrost:8080/v1 → ollama:11436 + → channels.deliver(session_id, channel, reply) + → pending_replies[session_id] queue (SSE) + → channel-specific callback (Telegram POST, CLI no-op) +CLI/wiki polling → GET /reply/{session_id} (SSE, blocks until reply) +``` + +### Bifrost integration + +Bifrost (`bifrost-config.json`) is configured with the `ollama` provider pointing to the GPU Ollama instance on host port 11436. It exposes an OpenAI-compatible API at `http://bifrost:8080/v1`. + +`agent.py` uses `langchain_openai.ChatOpenAI` with `base_url=BIFROST_URL`. Model names use the `provider/model` format that Bifrost expects: `ollama/qwen3:4b`, `ollama/qwen3:8b`, `ollama/qwen2.5:1.5b`. Bifrost strips the `ollama/` prefix before forwarding to Ollama. + +`VRAMManager` bypasses Bifrost and talks directly to Ollama via `OLLAMA_BASE_URL` (host:11436) for flush/poll/prewarm operations — Bifrost cannot manage GPU VRAM. + +### Three-tier routing (`router.py`, `agent.py`) + +| Tier | Model (env var) | Trigger | +|------|-----------------|---------| +| light | `qwen2.5:1.5b` (`DEEPAGENTS_ROUTER_MODEL`) | Regex pre-match or LLM classifies "light" — answered by router model directly, no agent invoked | +| medium | `qwen2.5:1.5b` (`DEEPAGENTS_MODEL`) | Default for tool-requiring queries | +| complex | `qwen3:8b` (`DEEPAGENTS_COMPLEX_MODEL`) | `/think ` prefix only | + +The router does regex pre-classification first, then LLM classification. Complex tier is blocked unless the message starts with `/think ` — any LLM classification of "complex" is downgraded to medium. + +A global `asyncio.Semaphore(1)` (`_reply_semaphore`) serializes all LLM inference — one request at a time. 
+ +### Thinking mode + +qwen3 models produce chain-of-thought `<think>...</think>` tokens via Ollama's OpenAI-compatible endpoint. Adolf controls this via system prompt prefixes: + +- **Medium** (`qwen2.5:1.5b`): no thinking mode in this model; fast ~3s calls +- **Complex** (`qwen3:8b`): no prefix — thinking enabled by default, used for deep research +- **Router** (`qwen2.5:1.5b`): no thinking support in this model + +`_strip_think()` in `agent.py` and `router.py` strips any `<think>...</think>` blocks from model output before returning to users. + +### VRAM management (`vram_manager.py`) + +Hardware: GTX 1070 (8 GB). Before running the 8b model, medium models are flushed via Ollama `keep_alive=0`, then `/api/ps` is polled (15s timeout) to confirm eviction. On timeout, falls back to medium tier. After complex reply, 8b is flushed and medium models are pre-warmed as a background task. + +### Channel adapters (`channels.py`) + +- **Telegram**: Grammy Node.js bot (`grammy/bot.mjs`) long-polls Telegram → `POST /message`; replies delivered via `POST grammy:3001/send` +- **CLI**: `cli.py` posts to `/message`, then blocks on `GET /reply/{session_id}` SSE + +Session IDs: `tg-` for Telegram, `cli-` for CLI. Conversation history: 5-turn buffer per session.
+ +### Services (`docker-compose.yml`) + +| Service | Port | Role | +|---------|------|------| +| `bifrost` | 8080 | LLM gateway — retries, failover, observability; config from `bifrost-config.json` | +| `deepagents` | 8000 | FastAPI gateway + agent core | +| `openmemory` | 8765 | FastMCP server + mem0 memory tools (Qdrant-backed) | +| `grammy` | 3001 | grammY Telegram bot + `/send` HTTP endpoint | +| `crawl4ai` | 11235 | JS-rendered page fetching | + +External (from `openai/` stack, host ports): +- Ollama GPU: `11436` — all reply inference (via Bifrost) + VRAM management (direct) +- Ollama CPU: `11435` — nomic-embed-text embeddings for openmemory +- Qdrant: `6333` — vector store for memories +- SearXNG: `11437` — web search + +### Bifrost config (`bifrost-config.json`) + +The file is mounted into the bifrost container at `/app/data/config.json`. It declares two providers: `ollama` (GPU Ollama at `host.docker.internal:11436`, 2 retries, 300s timeout) and `ollama-cpu` (CPU Ollama at `host.docker.internal:11435`, 2 retries, 120s timeout). To add fallback providers or adjust weights, edit this file and restart the bifrost container. + +### Agent tools + +`web_search`: SearXNG search + Crawl4AI auto-fetch of top 2 results → combined snippet + full page content. +`fetch_url`: Crawl4AI single-URL fetch. +MCP tools from openmemory (`add_memory`, `search_memory`, `get_all_memories`) are **excluded** from agent tools — memory management is handled outside the agent loop.
+ +### Medium vs Complex agent + +| Agent | Builder | Speed | Use case | +|-------|---------|-------|----------| +| medium | `_DirectModel` (single LLM call, no tools) | ~3s | General questions, conversation | +| complex | `create_deep_agent` (deepagents) | Slow — multi-step planner | Deep research via `/think` prefix | + +### Key files + +- `agent.py` — FastAPI app, lifespan wiring, `run_agent_task()`, all endpoints +- `bifrost-config.json` — Bifrost provider config (Ollama GPU, retries, timeouts) +- `channels.py` — channel registry and `deliver()` dispatcher +- `router.py` — `Router` class: regex + LLM classification, light-tier reply generation +- `vram_manager.py` — `VRAMManager`: flush/poll/prewarm Ollama VRAM directly +- `agent_factory.py` — `build_medium_agent` (returns `_DirectModel`, a single LLM call with no tools) / `build_complex_agent` (via `create_deep_agent()`) +- `openmemory/server.py` — FastMCP + mem0 config with custom extraction/dedup prompts +- `wiki_research.py` — batch research pipeline using `/message` + SSE polling +- `grammy/bot.mjs` — Telegram long-poll + HTTP `/send` endpoint diff --git a/Dockerfile b/Dockerfile index d81ee0c..22b7a8e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,7 +2,7 @@ FROM python:3.12-slim WORKDIR /app -RUN pip install --no-cache-dir deepagents langchain-ollama langgraph \ +RUN pip install --no-cache-dir deepagents langchain-openai langgraph \ fastapi uvicorn langchain-mcp-adapters langchain-community httpx COPY agent.py channels.py vram_manager.py router.py agent_factory.py hello_world.py .
diff --git a/agent.py b/agent.py index 8d415f9..4d85675 100644 --- a/agent.py +++ b/agent.py @@ -10,7 +10,7 @@ from pydantic import BaseModel import re as _re import httpx as _httpx -from langchain_ollama import ChatOllama +from langchain_openai import ChatOpenAI from langchain_mcp_adapters.client import MultiServerMCPClient from langchain_community.utilities import SearxSearchWrapper from langchain_core.tools import Tool @@ -20,8 +20,12 @@ from router import Router from agent_factory import build_medium_agent, build_complex_agent import channels +# Bifrost gateway — all LLM inference goes through here +BIFROST_URL = os.getenv("BIFROST_URL", "http://bifrost:8080/v1") +# Direct Ollama URL — used only by VRAMManager for flush/prewarm/poll OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434") -ROUTER_MODEL = os.getenv("DEEPAGENTS_ROUTER_MODEL", "qwen2.5:0.5b") + +ROUTER_MODEL = os.getenv("DEEPAGENTS_ROUTER_MODEL", "qwen2.5:1.5b") MEDIUM_MODEL = os.getenv("DEEPAGENTS_MODEL", "qwen3:4b") COMPLEX_MODEL = os.getenv("DEEPAGENTS_COMPLEX_MODEL", "qwen3:8b") SEARXNG_URL = os.getenv("SEARXNG_URL", "http://host.docker.internal:11437") @@ -31,10 +35,12 @@ CRAWL4AI_URL = os.getenv("CRAWL4AI_URL", "http://crawl4ai:11235") MAX_HISTORY_TURNS = 5 _conversation_buffers: dict[str, list] = {} +# /no_think at the start of the system prompt disables qwen3 chain-of-thought. +# create_deep_agent prepends our system_prompt before BASE_AGENT_PROMPT, so +# /no_think lands at position 0 and is respected by qwen3 models via Ollama. MEDIUM_SYSTEM_PROMPT = ( - "You are a helpful AI assistant. " - "Use web_search for questions about current events or facts you don't know. " - "Reply concisely." + "You are a helpful AI assistant. Reply concisely. 
" + "If asked to remember a fact or name, simply confirm: 'Got it, I'll remember that.'" ) COMPLEX_SYSTEM_PROMPT = ( @@ -54,6 +60,8 @@ complex_agent = None router: Router = None vram_manager: VRAMManager = None mcp_client = None +_memory_add_tool = None +_memory_search_tool = None # GPU mutex: one LLM inference at a time _reply_semaphore = asyncio.Semaphore(1) @@ -61,21 +69,34 @@ _reply_semaphore = asyncio.Semaphore(1) @asynccontextmanager async def lifespan(app: FastAPI): - global medium_agent, complex_agent, router, vram_manager, mcp_client + global medium_agent, complex_agent, router, vram_manager, mcp_client, \ + _memory_add_tool, _memory_search_tool # Register channel adapters channels.register_defaults() - # Three model instances - router_model = ChatOllama( - model=ROUTER_MODEL, base_url=OLLAMA_BASE_URL, think=False, num_ctx=4096, + # All three models route through Bifrost → Ollama GPU. + # Bifrost adds retry logic, observability, and failover. + # Model names use provider/model format: Bifrost strips the "ollama/" prefix + # before forwarding to Ollama's /v1/chat/completions endpoint. 
+ router_model = ChatOpenAI( + model=f"ollama/{ROUTER_MODEL}", + base_url=BIFROST_URL, + api_key="dummy", temperature=0, + timeout=30, ) - medium_model = ChatOllama( - model=MEDIUM_MODEL, base_url=OLLAMA_BASE_URL, think=False, num_ctx=8192 + medium_model = ChatOpenAI( + model=f"ollama/{MEDIUM_MODEL}", + base_url=BIFROST_URL, + api_key="dummy", + timeout=180, ) - complex_model = ChatOllama( - model=COMPLEX_MODEL, base_url=OLLAMA_BASE_URL, think=True, num_ctx=16384 + complex_model = ChatOpenAI( + model=f"ollama/{COMPLEX_MODEL}", + base_url=BIFROST_URL, + api_key="dummy", + timeout=600, ) vram_manager = VRAMManager(base_url=OLLAMA_BASE_URL) @@ -97,6 +118,13 @@ async def lifespan(app: FastAPI): agent_tools = [t for t in mcp_tools if t.name not in ("add_memory", "search_memory", "get_all_memories")] + # Expose memory tools directly so run_agent_task can call them outside the agent loop + for t in mcp_tools: + if t.name == "add_memory": + _memory_add_tool = t + elif t.name == "search_memory": + _memory_search_tool = t + searx = SearxSearchWrapper(searx_host=SEARXNG_URL) def _crawl4ai_fetch(url: str) -> str: @@ -187,7 +215,8 @@ async def lifespan(app: FastAPI): ) print( - f"[agent] three-tier: router={ROUTER_MODEL} | medium={MEDIUM_MODEL} | complex={COMPLEX_MODEL}", + f"[agent] bifrost={BIFROST_URL} | router=ollama/{ROUTER_MODEL} | " + f"medium=ollama/{MEDIUM_MODEL} | complex=ollama/{COMPLEX_MODEL}", flush=True, ) print(f"[agent] agent tools: {[t.name for t in agent_tools]}", flush=True) @@ -222,13 +251,19 @@ class ChatRequest(BaseModel): # ── helpers ──────────────────────────────────────────────────────────────────── +def _strip_think(text: str) -> str: + """Strip qwen3 chain-of-thought blocks that appear inline in content + when using Ollama's OpenAI-compatible endpoint (/v1/chat/completions).""" + return _re.sub(r".*?", "", text, flags=_re.DOTALL).strip() + + def _extract_final_text(result) -> str | None: msgs = result.get("messages", []) for m in reversed(msgs): if 
type(m).__name__ == "AIMessage" and getattr(m, "content", ""): - return m.content + return _strip_think(m.content) if isinstance(result, dict) and result.get("output"): - return result["output"] + return _strip_think(result["output"]) return None @@ -244,6 +279,34 @@ def _log_messages(result): print(f"[agent] {role} → {tc['name']}({tc['args']})", flush=True) +# ── memory helpers ───────────────────────────────────────────────────────────── + +async def _store_memory(session_id: str, user_msg: str, assistant_reply: str) -> None: + """Store a conversation turn in openmemory (runs as a background task).""" + if _memory_add_tool is None: + return + t0 = time.monotonic() + try: + text = f"User: {user_msg}\nAssistant: {assistant_reply}" + await _memory_add_tool.ainvoke({"text": text, "user_id": session_id}) + print(f"[memory] stored in {time.monotonic() - t0:.1f}s", flush=True) + except Exception as e: + print(f"[memory] error: {e}", flush=True) + + +async def _retrieve_memories(message: str, session_id: str) -> str: + """Search openmemory for relevant context. 
Returns formatted string or ''.""" + if _memory_search_tool is None: + return "" + try: + result = await _memory_search_tool.ainvoke({"query": message, "user_id": session_id}) + if result and result.strip() and result.strip() != "[]": + return f"Relevant memories:\n{result}" + except Exception: + pass + return "" + + # ── core task ────────────────────────────────────────────────────────────────── async def run_agent_task(message: str, session_id: str, channel: str = "telegram"): @@ -261,7 +324,13 @@ async def run_agent_task(message: str, session_id: str, channel: str = "telegram history = _conversation_buffers.get(session_id, []) print(f"[agent] running: {clean_message[:80]!r}", flush=True) - tier, light_reply = await router.route(clean_message, history, force_complex) + # Retrieve memories once; inject into history so ALL tiers can use them + memories = await _retrieve_memories(clean_message, session_id) + enriched_history = ( + [{"role": "system", "content": memories}] + history if memories else history + ) + + tier, light_reply = await router.route(clean_message, enriched_history, force_complex) print(f"[agent] tier={tier} message={clean_message[:60]!r}", flush=True) final_text = None @@ -273,6 +342,8 @@ async def run_agent_task(message: str, session_id: str, channel: str = "telegram elif tier == "medium": system_prompt = MEDIUM_SYSTEM_PROMPT + if memories: + system_prompt = system_prompt + "\n\n" + memories result = await medium_agent.ainvoke({ "messages": [ {"role": "system", "content": system_prompt}, @@ -289,9 +360,12 @@ async def run_agent_task(message: str, session_id: str, channel: str = "telegram if not ok: print("[agent] complex→medium fallback (eviction timeout)", flush=True) tier = "medium" + system_prompt = MEDIUM_SYSTEM_PROMPT + if memories: + system_prompt = system_prompt + "\n\n" + memories result = await medium_agent.ainvoke({ "messages": [ - {"role": "system", "content": MEDIUM_SYSTEM_PROMPT}, + {"role": "system", "content": system_prompt}, 
*history, {"role": "user", "content": clean_message}, ] @@ -320,7 +394,10 @@ async def run_agent_task(message: str, session_id: str, channel: str = "telegram # Deliver reply through the originating channel if final_text: t1 = time.monotonic() - await channels.deliver(session_id, channel, final_text) + try: + await channels.deliver(session_id, channel, final_text) + except Exception as e: + print(f"[agent] delivery error (non-fatal): {e}", flush=True) send_elapsed = time.monotonic() - t1 print( f"[agent] replied in {time.monotonic() - t0:.1f}s " @@ -331,12 +408,13 @@ async def run_agent_task(message: str, session_id: str, channel: str = "telegram else: print("[agent] warning: no text reply from agent", flush=True) - # Update conversation buffer + # Update conversation buffer and schedule memory storage if final_text: buf = _conversation_buffers.get(session_id, []) buf.append({"role": "user", "content": clean_message}) buf.append({"role": "assistant", "content": final_text}) _conversation_buffers[session_id] = buf[-(MAX_HISTORY_TURNS * 2):] + asyncio.create_task(_store_memory(session_id, clean_message, final_text)) # ── endpoints ────────────────────────────────────────────────────────────────── @@ -374,7 +452,7 @@ async def reply_stream(session_id: str): try: text = await asyncio.wait_for(q.get(), timeout=900) # Escape newlines so entire reply fits in one SSE data line - yield f"data: {text.replace(chr(10), '\\n').replace(chr(13), '')}\n\n" + yield f"data: {text.replace(chr(10), chr(92) + 'n').replace(chr(13), '')}\n\n" except asyncio.TimeoutError: yield "data: [timeout]\n\n" diff --git a/agent_factory.py b/agent_factory.py index 6182ff4..9fa91f1 100644 --- a/agent_factory.py +++ b/agent_factory.py @@ -1,13 +1,21 @@ from deepagents import create_deep_agent +class _DirectModel: + """Thin wrapper: single LLM call, no tools, same ainvoke interface as a graph.""" + + def __init__(self, model): + self._model = model + + async def ainvoke(self, input_dict: dict) -> dict: 
+ messages = input_dict["messages"] + response = await self._model.ainvoke(messages) + return {"messages": list(messages) + [response]} + + def build_medium_agent(model, agent_tools: list, system_prompt: str): - """Medium agent: create_deep_agent with TodoList planning, no subagents.""" - return create_deep_agent( - model=model, - tools=agent_tools, - system_prompt=system_prompt, - ) + """Medium agent: single LLM call, no tools — fast ~3s response.""" + return _DirectModel(model) def build_complex_agent(model, agent_tools: list, system_prompt: str): diff --git a/bifrost-config.json b/bifrost-config.json new file mode 100644 index 0000000..7db331e --- /dev/null +++ b/bifrost-config.json @@ -0,0 +1,58 @@ +{ + "client": { + "drop_excess_requests": false + }, + "providers": { + "ollama": { + "keys": [ + { + "name": "ollama-gpu", + "value": "dummy", + "models": [ + "qwen2.5:0.5b", + "qwen2.5:1.5b", + "qwen3:4b", + "gemma3:4b", + "qwen3:8b" + ], + "weight": 1.0 + } + ], + "network_config": { + "base_url": "http://host.docker.internal:11436", + "default_request_timeout_in_seconds": 300, + "max_retries": 2, + "retry_backoff_initial_ms": 500, + "retry_backoff_max_ms": 10000 + } + }, + "ollama-cpu": { + "keys": [ + { + "name": "ollama-cpu-key", + "value": "dummy", + "models": [ + "gemma3:1b", + "qwen2.5:1.5b", + "qwen2.5:3b" + ], + "weight": 1.0 + } + ], + "network_config": { + "base_url": "http://host.docker.internal:11435", + "default_request_timeout_in_seconds": 120, + "max_retries": 2, + "retry_backoff_initial_ms": 500, + "retry_backoff_max_ms": 10000 + }, + "custom_provider_config": { + "base_provider_type": "openai", + "allowed_requests": { + "chat_completion": true, + "chat_completion_stream": true + } + } + } + } +} diff --git a/docker-compose.yml b/docker-compose.yml index 6851e19..fbc177d 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,4 +1,19 @@ services: + bifrost: + image: maximhq/bifrost + container_name: bifrost + ports: + - "8080:8080" + 
volumes: + - ./bifrost-config.json:/app/data/config.json:ro + environment: + - APP_DIR=/app/data + - APP_PORT=8080 + - LOG_LEVEL=info + extra_hosts: + - "host.docker.internal:host-gateway" + restart: unless-stopped + deepagents: build: . container_name: deepagents @@ -6,8 +21,11 @@ services: - "8000:8000" environment: - PYTHONUNBUFFERED=1 + # Bifrost gateway — all LLM inference goes through here + - BIFROST_URL=http://bifrost:8080/v1 + # Direct Ollama GPU URL — used only by VRAMManager for flush/prewarm - OLLAMA_BASE_URL=http://host.docker.internal:11436 - - DEEPAGENTS_MODEL=qwen3:4b + - DEEPAGENTS_MODEL=qwen2.5:1.5b - DEEPAGENTS_COMPLEX_MODEL=qwen3:8b - DEEPAGENTS_ROUTER_MODEL=qwen2.5:1.5b - SEARXNG_URL=http://host.docker.internal:11437 @@ -19,6 +37,7 @@ services: - openmemory - grammy - crawl4ai + - bifrost restart: unless-stopped openmemory: @@ -27,8 +46,9 @@ services: ports: - "8765:8765" environment: - # Extraction LLM (qwen2.5:1.5b) runs on GPU after reply — fast 2-5s extraction + # Extraction LLM runs on GPU — qwen2.5:1.5b for speed (~3s) - OLLAMA_GPU_URL=http://host.docker.internal:11436 + - OLLAMA_EXTRACTION_MODEL=qwen2.5:1.5b # Embedding (nomic-embed-text) runs on CPU — fast enough for search (50-150ms) - OLLAMA_CPU_URL=http://host.docker.internal:11435 extra_hosts: diff --git a/openmemory/server.py b/openmemory/server.py index 73fd93c..56f8575 100644 --- a/openmemory/server.py +++ b/openmemory/server.py @@ -6,6 +6,7 @@ from mem0 import Memory # Extraction LLM — GPU Ollama (qwen3:4b, same model as medium agent) # Runs after reply when GPU is idle; spin-wait in agent.py prevents contention OLLAMA_GPU_URL = os.getenv("OLLAMA_GPU_URL", "http://host.docker.internal:11436") +EXTRACTION_MODEL = os.getenv("OLLAMA_EXTRACTION_MODEL", "qwen2.5:1.5b") # Embedding — CPU Ollama (nomic-embed-text, 137 MB RAM) # Used for both search (50-150ms, acceptable) and store-time embedding @@ -94,7 +95,7 @@ config = { "llm": { "provider": "ollama", "config": { - "model": 
"qwen3:4b", + "model": EXTRACTION_MODEL, "ollama_base_url": OLLAMA_GPU_URL, "temperature": 0.1, # consistent JSON output }, diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..2302a28 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,4 @@ +[pytest] +testpaths = tests/unit +pythonpath = . +asyncio_mode = auto diff --git a/test_pipeline.py b/tests/integration/test_pipeline.py similarity index 88% rename from test_pipeline.py rename to tests/integration/test_pipeline.py index 034720d..ce13775 100644 --- a/test_pipeline.py +++ b/tests/integration/test_pipeline.py @@ -42,6 +42,7 @@ import urllib.request # ── config ──────────────────────────────────────────────────────────────────── DEEPAGENTS = "http://localhost:8000" +BIFROST = "http://localhost:8080" OPENMEMORY = "http://localhost:8765" GRAMMY_HOST = "localhost" GRAMMY_PORT = 3001 @@ -49,7 +50,7 @@ OLLAMA_GPU = "http://localhost:11436" OLLAMA_CPU = "http://localhost:11435" QDRANT = "http://localhost:6333" SEARXNG = "http://localhost:11437" -COMPOSE_FILE = "/home/alvis/agap_git/adolf/docker-compose.yml" +COMPOSE_FILE = "/home/alvis/adolf/docker-compose.yml" DEFAULT_CHAT_ID = "346967270" NAMES = [ @@ -166,6 +167,19 @@ def fetch_logs(since_s=600): return [] +def fetch_bifrost_logs(since_s=120): + """Return bifrost container log lines from the last since_s seconds.""" + try: + r = subprocess.run( + ["docker", "compose", "-f", COMPOSE_FILE, "logs", "bifrost", + f"--since={int(since_s)}s", "--no-log-prefix"], + capture_output=True, text=True, timeout=10, + ) + return r.stdout.splitlines() + except Exception: + return [] + + def parse_run_block(lines, msg_prefix): """ Scan log lines for the LAST '[agent] running: ' block. 
@@ -303,10 +317,12 @@ _run_hard = not args.no_bench and not args.easy_only and not args.medium_onl _run_memory = not args.no_bench and not args.easy_only and not args.medium_only and not args.hard_only random_name = random.choice(NAMES) +# Use a unique chat_id per run to avoid cross-run history contamination +TEST_CHAT_ID = f"{CHAT_ID}-{random_name.lower()}" if not _skip_pipeline: print(f"\n Test name : \033[1m{random_name}\033[0m") - print(f" Chat ID : {CHAT_ID}") + print(f" Chat ID : {TEST_CHAT_ID}") # ── 1. service health ───────────────────────────────────────────────────────── @@ -331,6 +347,93 @@ if not _skip_pipeline: timings["health_check"] = time.monotonic() - t0 +# ── 1b. Bifrost gateway ─────────────────────────────────────────────────────── +if not _skip_pipeline: + print(f"\n[{INFO}] 1b. Bifrost gateway (port 8080)") + t0 = time.monotonic() + + # Health ────────────────────────────────────────────────────────────────── + try: + status, body = get(f"{BIFROST}/health", timeout=5) + ok = status == 200 + report("Bifrost /health reachable", ok, f"HTTP {status}") + except Exception as e: + report("Bifrost /health reachable", False, str(e)) + + # Ollama GPU models listed ──────────────────────────────────────────────── + try: + status, body = get(f"{BIFROST}/v1/models", timeout=5) + data = json.loads(body) + model_ids = [m.get("id", "") for m in data.get("data", [])] + gpu_models = [m for m in model_ids if m.startswith("ollama/")] + report("Bifrost lists ollama GPU models", len(gpu_models) > 0, + f"found: {gpu_models}") + for expected in ["ollama/qwen3:4b", "ollama/qwen3:8b", "ollama/qwen2.5:1.5b"]: + report(f" model {expected} listed", expected in model_ids) + except Exception as e: + report("Bifrost /v1/models", False, str(e)) + + # Direct inference through Bifrost → GPU Ollama ─────────────────────────── + # Uses the smallest GPU model (qwen2.5:0.5b) to keep latency low. 
+ print(f" [bifrost-infer] direct POST /v1/chat/completions → ollama/qwen2.5:0.5b ...") + t_infer = time.monotonic() + try: + infer_payload = { + "model": "ollama/qwen2.5:0.5b", + "messages": [{"role": "user", "content": "Reply with exactly one word: pong"}], + "max_tokens": 16, + } + infer_data = json.dumps(infer_payload).encode() + req = urllib.request.Request( + f"{BIFROST}/v1/chat/completions", + data=infer_data, + headers={"Content-Type": "application/json"}, + method="POST", + ) + with urllib.request.urlopen(req, timeout=60) as r: + infer_status = r.status + infer_body = json.loads(r.read().decode()) + infer_elapsed = time.monotonic() - t_infer + reply_content = infer_body.get("choices", [{}])[0].get("message", {}).get("content", "") + used_model = infer_body.get("model", "") + report("Bifrost → Ollama GPU inference succeeds", + infer_status == 200 and bool(reply_content), + f"{infer_elapsed:.1f}s model={used_model!r} reply={reply_content[:60]!r}") + timings["bifrost_direct_infer"] = infer_elapsed + except Exception as e: + report("Bifrost → Ollama GPU inference succeeds", False, str(e)) + timings["bifrost_direct_infer"] = None + + # deepagents is configured to route through Bifrost ─────────────────────── + # The startup log line "[agent] bifrost=http://bifrost:8080/v1 | ..." is emitted + # during lifespan setup and confirms deepagents is using Bifrost as the LLM gateway. + try: + r = subprocess.run( + ["docker", "compose", "-f", COMPOSE_FILE, "logs", "deepagents", + "--since=3600s", "--no-log-prefix"], + capture_output=True, text=True, timeout=10, + ) + log_lines = r.stdout.splitlines() + bifrost_line = next( + (l for l in log_lines if "[agent] bifrost=" in l and "bifrost:8080" in l), + None, + ) + report( + "deepagents startup log confirms bifrost URL", + bifrost_line is not None, + bifrost_line.strip() if bifrost_line else "line not found in logs", + ) + # Also confirm model names use provider/model format (ollama/...) 
+ if bifrost_line: + has_prefix = "router=ollama/" in bifrost_line and "medium=ollama/" in bifrost_line + report("deepagents model names use ollama/ prefix", has_prefix, + bifrost_line.strip()) + except Exception as e: + report("deepagents startup log check", False, str(e)) + + timings["bifrost_check"] = time.monotonic() - t0 + + # ── 2. GPU Ollama ───────────────────────────────────────────────────────────── if not _skip_pipeline: print(f"\n[{INFO}] 2. GPU Ollama (port 11436)") @@ -415,11 +518,18 @@ if not _skip_pipeline: # ── 6–8. Name memory pipeline ───────────────────────────────────────────────── if not _skip_pipeline: print(f"\n[{INFO}] 6–8. Name memory pipeline") - print(f" chat_id={CHAT_ID} name={random_name}") + print(f" chat_id={TEST_CHAT_ID} name={random_name}") store_msg = f"remember that your name is {random_name}" recall_msg = "what is your name?" + # Clear adolf_memories so each run starts clean (avoids cross-run stale memories) + try: + post_json(f"{QDRANT}/collections/adolf_memories/points/delete", + {"filter": {}}, timeout=5) + except Exception: + pass + pts_before = qdrant_count() print(f" Qdrant points before: {pts_before}") @@ -429,7 +539,7 @@ if not _skip_pipeline: try: status, _ = post_json(f"{DEEPAGENTS}/chat", - {"message": store_msg, "chat_id": CHAT_ID}, timeout=5) + {"message": store_msg, "chat_id": TEST_CHAT_ID}, timeout=5) t_accept = time.monotonic() - t_store report("POST /chat (store) returns 202 immediately", status == 202 and t_accept < 1, f"status={status}, t={t_accept:.3f}s") @@ -472,7 +582,7 @@ if not _skip_pipeline: try: status, _ = post_json(f"{DEEPAGENTS}/chat", - {"message": recall_msg, "chat_id": CHAT_ID}, timeout=5) + {"message": recall_msg, "chat_id": TEST_CHAT_ID}, timeout=5) t_accept2 = time.monotonic() - t_recall report("POST /chat (recall) returns 202 immediately", status == 202 and t_accept2 < 1, f"status={status}, t={t_accept2:.3f}s") @@ -496,6 +606,19 @@ if not _skip_pipeline: report("Agent replied to recall 
message", False, "timeout") report(f"Reply contains '{random_name}'", False, "no reply") + # ── 8b. Verify requests passed through Bifrost ──────────────────────────── + # After the store+recall round-trip, Bifrost logs must show forwarded + # requests. An empty Bifrost log means deepagents bypasses the gateway. + bifrost_lines = fetch_bifrost_logs(since_s=300) + report("Bifrost container has log output (requests forwarded)", + len(bifrost_lines) > 0, + f"{len(bifrost_lines)} lines in bifrost logs") + # Bifrost logs contain the request body; AsyncOpenAI user-agent confirms the path + bifrost_raw = "\n".join(bifrost_lines) + report(" Bifrost log shows AsyncOpenAI agent requests", + "AsyncOpenAI" in bifrost_raw, + f"{'found' if 'AsyncOpenAI' in bifrost_raw else 'NOT found'} in bifrost logs") + # ── 9. Timing profile ───────────────────────────────────────────────────────── if not _skip_pipeline: diff --git a/tests/requirements.txt b/tests/requirements.txt new file mode 100644 index 0000000..94022af --- /dev/null +++ b/tests/requirements.txt @@ -0,0 +1,2 @@ +pytest>=8.0 +pytest-asyncio>=0.23 diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 0000000..117e3f7 --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1,80 @@ +""" +Stub out all third-party packages that Adolf's source modules import. +This lets the unit tests run without a virtualenv or Docker environment. +Stubs are installed into sys.modules before any test file is collected. 
import sys
from unittest.mock import MagicMock


# ── helpers ────────────────────────────────────────────────────────────────────

def _mock(name: str) -> MagicMock:
    """Install a ``MagicMock`` stand-in for module *name* in ``sys.modules``.

    For a dotted name whose parent is already present in ``sys.modules`` the
    stub is also bound as an attribute on the parent, mirroring what the real
    import system does.  Without that binding ``import pkg.sub; pkg.sub.X``
    would hit an auto-generated attribute of the parent mock, while
    ``from pkg.sub import X`` would hit this stub — two different objects.

    Returns the newly created stub so callers can attach attributes to it.
    """
    stub = MagicMock(name=name)
    sys.modules[name] = stub
    parent_name, _, child_name = name.rpartition(".")
    if parent_name and parent_name in sys.modules:
        setattr(sys.modules[parent_name], child_name, stub)
    return stub


# ── pydantic: BaseModel must be a real class so `class Foo(BaseModel)` works ──

class _FakeBaseModel:
    """Minimal replacement for ``pydantic.BaseModel``.

    It is a real class (not a mock) so it can be subclassed, and it stores
    every constructor keyword argument as an instance attribute.  No
    validation or type coercion is performed.
    """

    # Class-level attribute: the same dict object is seen by every subclass.
    model_fields: dict = {}

    def __init_subclass__(cls, **kwargs):
        # Accept and ignore any class keyword arguments.
        pass

    def __init__(self, **data):
        for k, v in data.items():
            setattr(self, k, v)


_pydantic = _mock("pydantic")
_pydantic.BaseModel = _FakeBaseModel

# ── httpx: used by channels.py, vram_manager.py, agent.py ────────────────────

_mock("httpx")

# ── fastapi ───────────────────────────────────────────────────────────────────

_fastapi = _mock("fastapi")
_mock("fastapi.responses")

# ── langchain stack ───────────────────────────────────────────────────────────

_mock("langchain_openai")

_lc_core = _mock("langchain_core")
_lc_msgs = _mock("langchain_core.messages")
_mock("langchain_core.tools")


# Provide real-ish message classes so router.py can instantiate them
class _FakeMsg:
    """Base for the fake message classes: holds only ``content``."""

    def __init__(self, content=""):
        self.content = content


class SystemMessage(_FakeMsg):
    """Fake ``langchain_core.messages.SystemMessage``."""


class HumanMessage(_FakeMsg):
    """Fake ``langchain_core.messages.HumanMessage``."""


class AIMessage(_FakeMsg):
    """Fake ``langchain_core.messages.AIMessage`` with a ``tool_calls`` list."""

    def __init__(self, content="", tool_calls=None):
        super().__init__(content)
        # ``None`` (or any falsy value) becomes a fresh empty list.
        self.tool_calls = tool_calls or []


_lc_msgs.SystemMessage = SystemMessage
_lc_msgs.HumanMessage = HumanMessage
_lc_msgs.AIMessage = AIMessage

_mock("langchain_mcp_adapters")
_mock("langchain_mcp_adapters.client")
_mock("langchain_community")
_mock("langchain_community.utilities")

# ── deepagents (agent_factory.py) ─────────────────────────────────────────────

_mock("deepagents")
"""
Unit tests for agent.py helper functions:
  - _strip_think(text)
  - _extract_final_text(result)

agent.py has heavy FastAPI/LangChain imports; conftest.py stubs them out so
these pure functions can be imported and tested in isolation.
"""

import pytest

# conftest.py has already installed all stubs into sys.modules.
# The FastAPI app is instantiated at module level in agent.py —
# with the mocked fastapi, that just creates a MagicMock() object
# and the route decorators are no-ops.
from agent import _strip_think, _extract_final_text

# NOTE(review): the <think>…</think> (and <b>…</b>) markup in the input
# literals below was absent from the reviewed text and has been reconstructed
# from the test names and expected values — confirm against the committed file.


# ── _strip_think ───────────────────────────────────────────────────────────────

class TestStripThink:
    """Expected behaviour of ``_strip_think`` on inputs with and without think blocks."""

    def test_removes_single_think_block(self):
        text = "<think>internal reasoning</think>Final answer."
        assert _strip_think(text) == "Final answer."

    def test_removes_multiline_think_block(self):
        text = "<think>\nLine one.\nLine two.\n</think>\nResult here."
        assert _strip_think(text) == "Result here."

    def test_no_think_block_unchanged(self):
        text = "This is a plain answer with no think block."
        assert _strip_think(text) == text

    def test_removes_multiple_think_blocks(self):
        text = "<think>step 1</think>middle<think>step 2</think>end"
        assert _strip_think(text) == "middleend"

    def test_strips_surrounding_whitespace(self):
        text = "  <think>stuff</think>  answer  "
        assert _strip_think(text) == "answer"

    def test_empty_think_block(self):
        text = "<think></think>Hello."
        assert _strip_think(text) == "Hello."

    def test_empty_string(self):
        assert _strip_think("") == ""

    def test_only_think_block_returns_empty(self):
        text = "<think>nothing useful</think>"
        assert _strip_think(text) == ""

    def test_think_block_with_nested_tags(self):
        text = "<think>I should use <b>bold</b> here</think>Done."
        assert _strip_think(text) == "Done."

    def test_preserves_markdown(self):
        text = "<think>plan</think>## Report\n\n- Point one\n- Point two"
        result = _strip_think(text)
        assert result == "## Report\n\n- Point one\n- Point two"


# ── _extract_final_text ────────────────────────────────────────────────────────

class TestExtractFinalText:
    """Expected behaviour of ``_extract_final_text`` on agent result dicts.

    The helpers build throwaway objects whose *class name* is ``AIMessage`` /
    ``HumanMessage``; they are not the stub classes installed by conftest.py.
    """

    def _ai_msg(self, content: str, tool_calls=None):
        """Create a minimal AIMessage-like object."""
        class AIMessage:
            pass
        m = AIMessage()
        m.content = content
        m.tool_calls = tool_calls or []
        return m

    def _human_msg(self, content: str):
        """Create a minimal HumanMessage-like object (content only)."""
        class HumanMessage:
            pass
        m = HumanMessage()
        m.content = content
        return m

    def test_returns_last_ai_message_content(self):
        result = {
            "messages": [
                self._human_msg("what is 2+2"),
                self._ai_msg("The answer is 4."),
            ]
        }
        assert _extract_final_text(result) == "The answer is 4."

    def test_returns_last_of_multiple_ai_messages(self):
        result = {
            "messages": [
                self._ai_msg("First response."),
                self._human_msg("follow-up"),
                self._ai_msg("Final response."),
            ]
        }
        assert _extract_final_text(result) == "Final response."

    def test_skips_empty_ai_messages(self):
        result = {
            "messages": [
                self._ai_msg("Real answer."),
                self._ai_msg(""),   # empty — should be skipped
            ]
        }
        assert _extract_final_text(result) == "Real answer."

    def test_strips_think_tags_from_ai_message(self):
        result = {
            "messages": [
                self._ai_msg("<think>reasoning here</think>Clean reply."),
            ]
        }
        assert _extract_final_text(result) == "Clean reply."

    def test_falls_back_to_output_field(self):
        result = {
            "messages": [],
            "output": "Fallback output.",
        }
        assert _extract_final_text(result) == "Fallback output."

    def test_strips_think_from_output_field(self):
        result = {
            "messages": [],
            "output": "<think>thoughts</think>Actual output.",
        }
        assert _extract_final_text(result) == "Actual output."

    def test_returns_none_when_no_content(self):
        result = {"messages": []}
        assert _extract_final_text(result) is None

    def test_returns_none_when_no_messages_and_no_output(self):
        result = {"messages": [], "output": ""}
        # output is falsy → returns None
        assert _extract_final_text(result) is None

    def test_skips_non_ai_messages(self):
        result = {
            "messages": [
                self._human_msg("user question"),
            ]
        }
        assert _extract_final_text(result) is None

    def test_handles_ai_message_with_tool_calls_but_no_content(self):
        """AIMessage that only has tool_calls (no content) should be skipped."""
        msg = self._ai_msg("", tool_calls=[{"name": "web_search", "args": {}}])
        result = {"messages": [msg]}
        assert _extract_final_text(result) is None

    def test_multiline_think_stripped_correctly(self):
        result = {
            "messages": [
                self._ai_msg("<think>\nLong\nreasoning\nblock\n</think>\n## Report\n\nSome content."),
            ]
        }
        assert _extract_final_text(result) == "## Report\n\nSome content."
"""Unit tests for channels.py — register, deliver, pending_replies queue."""

import asyncio
import pytest
from unittest.mock import AsyncMock, patch

import channels


def _wipe_channels_state() -> None:
    """Empty the callback registry and the per-session reply queues."""
    channels._callbacks.clear()
    channels.pending_replies.clear()


@pytest.fixture(autouse=True)
def reset_channels_state():
    """Give every test a clean ``channels`` module, and leave it clean afterwards."""
    _wipe_channels_state()
    yield
    _wipe_channels_state()


# ── register ───────────────────────────────────────────────────────────────────

class TestRegister:
    """``channels.register`` stores callbacks by channel name."""

    def test_register_stores_callback(self):
        callback = AsyncMock()
        channels.register("test_channel", callback)
        assert channels._callbacks["test_channel"] is callback

    def test_register_overwrites_existing(self):
        first, second = AsyncMock(), AsyncMock()
        channels.register("ch", first)
        channels.register("ch", second)
        assert channels._callbacks["ch"] is second

    def test_register_multiple_channels(self):
        registered = {"a": AsyncMock(), "b": AsyncMock()}
        for channel_name, callback in registered.items():
            channels.register(channel_name, callback)
        for channel_name, callback in registered.items():
            assert channels._callbacks[channel_name] is callback


# ── deliver ────────────────────────────────────────────────────────────────────

class TestDeliver:
    """``channels.deliver`` enqueues the reply and invokes the channel callback."""

    async def test_deliver_enqueues_reply(self):
        channels.register("cli", AsyncMock())
        await channels.deliver("cli-alvis", "cli", "hello world")
        queue = channels.pending_replies["cli-alvis"]
        assert not queue.empty()
        assert await queue.get() == "hello world"

    async def test_deliver_calls_channel_callback(self):
        callback = AsyncMock()
        channels.register("telegram", callback)
        await channels.deliver("tg-123", "telegram", "reply text")
        callback.assert_awaited_once_with("tg-123", "reply text")

    async def test_deliver_unknown_channel_still_enqueues(self):
        """With no callback registered for the channel, the reply is still queued."""
        await channels.deliver("cli-bob", "nonexistent", "fallback reply")
        queue = channels.pending_replies["cli-bob"]
        assert await queue.get() == "fallback reply"

    async def test_deliver_unknown_channel_does_not_raise(self):
        """Delivering to a channel without a callback completes without an exception."""
        await channels.deliver("cli-x", "ghost_channel", "msg")

    async def test_deliver_creates_queue_if_absent(self):
        channels.register("cli", AsyncMock())
        assert "cli-new" not in channels.pending_replies
        await channels.deliver("cli-new", "cli", "hi")
        assert "cli-new" in channels.pending_replies

    async def test_deliver_reuses_existing_queue(self):
        """Two deliveries to one session land, in order, on the same queue."""
        channels.register("cli", AsyncMock())
        for text in ("first", "second"):
            await channels.deliver("cli-alvis", "cli", text)
        queue = channels.pending_replies["cli-alvis"]
        assert await queue.get() == "first"
        assert await queue.get() == "second"

    async def test_deliver_telegram_sends_to_callback(self):
        received = []

        async def record_send(session_id, text):
            received.append((session_id, text))

        channels.register("telegram", record_send)
        await channels.deliver("tg-999", "telegram", "test message")
        assert received == [("tg-999", "test message")]


# ── register_defaults ──────────────────────────────────────────────────────────

class TestRegisterDefaults:
    """``channels.register_defaults`` installs the built-in channel callbacks."""

    def test_registers_telegram_and_cli(self):
        channels.register_defaults()
        for channel_name in ("telegram", "cli"):
            assert channel_name in channels._callbacks

    async def test_cli_callback_is_noop(self):
        """The CLI callback can be awaited without raising and returns None."""
        channels.register_defaults()
        cli_send = channels._callbacks["cli"]
        outcome = await cli_send("cli-alvis", "some reply")
        assert outcome is None

    async def test_telegram_callback_chunks_long_messages(self):
        """A 9000-character message results in three POST calls."""
        channels.register_defaults()
        telegram_send = channels._callbacks["telegram"]
        payload = "x" * 9000  # > 4000 chars → should produce 3 chunks

        # Async-context-manager mock that yields itself from ``async with``.
        http = AsyncMock()
        http.__aenter__ = AsyncMock(return_value=http)
        http.__aexit__ = AsyncMock(return_value=False)
        http.post = AsyncMock()

        with patch("channels.httpx.AsyncClient", return_value=http):
            await telegram_send("tg-123", payload)

        # 9000 chars / 4000 per chunk = 3 POST calls
        assert http.post.await_count == 3
"""Unit tests for router.py — Router, _parse_tier, _format_history, _LIGHT_PATTERNS."""

import pytest
from unittest.mock import AsyncMock, MagicMock, patch

from router import Router, _parse_tier, _format_history, _LIGHT_PATTERNS


# ── _LIGHT_PATTERNS regex ──────────────────────────────────────────────────────

class TestLightPatterns:
    """Inputs that must / must not match the ``_LIGHT_PATTERNS`` regex."""

    @pytest.mark.parametrize("text", [
        "hi", "Hi", "HI",
        "hello", "hey", "yo", "sup",
        "good morning", "good evening", "good night", "good afternoon",
        "bye", "goodbye", "see you", "cya", "later", "ttyl",
        "thanks", "thank you", "thx", "ty",
        "ok", "okay", "k", "cool", "great", "awesome", "perfect",
        "sounds good", "got it", "nice", "sure",
        "how are you", "how are you?", "how are you doing today?",
        "what's up",
        "what day comes after Monday?",
        "what day follows Friday?",
        "what comes after summer?",
        "what does NASA stand for?",
        "what does AI stand for?",
        # with trailing punctuation
        "hi!", "hello.", "thanks!",
    ])
    def test_matches(self, text):
        assert _LIGHT_PATTERNS.match(text.strip()), f"Expected light match for: {text!r}"

    @pytest.mark.parametrize("text", [
        "what is the capital of France",
        "tell me about bitcoin",
        "what is 2+2",
        "write me a poem",
        "search for news about the election",
        "what did we talk about last time",
        "what is my name",
        "/think compare these frameworks",
        "how do I install Python",
        "explain machine learning",
        "",  # empty string doesn't match the pattern
    ])
    def test_no_match(self, text):
        assert not _LIGHT_PATTERNS.match(text.strip()), f"Expected NO light match for: {text!r}"


# ── _parse_tier ────────────────────────────────────────────────────────────────

class TestParseTier:
    """Mapping from raw classifier output to a tier name."""

    @pytest.mark.parametrize("raw,expected", [
        ("light", "light"),
        ("Light", "light"),
        ("LIGHT\n", "light"),
        ("medium", "medium"),
        ("Medium.", "medium"),
        ("complex", "complex"),
        ("Complex!", "complex"),
        # descriptive words → light
        ("simplefact", "light"),
        ("trivial question", "light"),
        ("basic", "light"),
        ("easy answer", "light"),
        ("general knowledge", "light"),
        # unknown → medium
        ("unknown_category", "medium"),
        ("", "medium"),
        ("I don't know", "medium"),
        # complex only if 'complex' appears in first 60 chars
        ("this is a complex query requiring search", "complex"),
        # _parse_tier checks "complex" before "medium", so complex wins even if medium appears first
        ("medium complexity, not complex", "complex"),
    ])
    def test_parse_tier(self, raw, expected):
        assert _parse_tier(raw) == expected


# ── _format_history ────────────────────────────────────────────────────────────

class TestFormatHistory:
    """Rendering of a chat history list into prompt text."""

    def test_empty(self):
        assert _format_history([]) == "(none)"

    def test_single_user_message(self):
        history = [{"role": "user", "content": "hello there"}]
        result = _format_history(history)
        assert "user: hello there" in result

    def test_multiple_turns(self):
        history = [
            {"role": "user", "content": "What is Python?"},
            {"role": "assistant", "content": "Python is a programming language."},
        ]
        result = _format_history(history)
        assert "user: What is Python?" in result
        assert "assistant: Python is a programming language." in result

    def test_truncates_long_content(self):
        long_content = "x" * 300
        history = [{"role": "user", "content": long_content}]
        result = _format_history(history)
        # content is truncated to 200 chars in _format_history
        assert len(result) < 250

    def test_missing_keys_handled(self):
        # Should not raise — uses .get() with defaults
        history = [{"role": "user"}]  # no content key
        result = _format_history(history)
        assert "user:" in result


# ── Router.route() ─────────────────────────────────────────────────────────────

class TestRouterRoute:
    """``Router.route`` tier selection with a mocked model."""

    def _make_router(self, classify_response: str, reply_response: str = "Sure!") -> Router:
        """Return a Router with a mock model that returns given classification and reply."""
        model = MagicMock()
        classify_msg = MagicMock()
        classify_msg.content = classify_response
        reply_msg = MagicMock()
        reply_msg.content = reply_response
        # First ainvoke call → classification; second → reply
        model.ainvoke = AsyncMock(side_effect=[classify_msg, reply_msg])
        return Router(model=model)

    async def test_force_complex_bypasses_classification(self):
        router = self._make_router("medium")
        tier, reply = await router.route("some question", [], force_complex=True)
        assert tier == "complex"
        assert reply is None
        # Model should NOT have been called
        router.model.ainvoke.assert_not_called()

    async def test_regex_light_skips_llm_classification(self):
        # Regex match bypasses classification entirely; the only ainvoke call is the reply.
        model = MagicMock()
        reply_msg = MagicMock()
        reply_msg.content = "I'm doing great!"
        model.ainvoke = AsyncMock(return_value=reply_msg)
        router = Router(model=model)
        tier, reply = await router.route("how are you", [], force_complex=False)
        assert tier == "light"
        assert reply == "I'm doing great!"
        # Exactly one model call — no classification step
        assert router.model.ainvoke.call_count == 1

    async def test_llm_classifies_medium(self):
        router = self._make_router("medium")
        tier, reply = await router.route("what is the bitcoin price?", [], force_complex=False)
        assert tier == "medium"
        assert reply is None

    async def test_llm_classifies_light_generates_reply(self):
        router = self._make_router("light", "Paris is the capital of France.")
        tier, reply = await router.route("what is the capital of France?", [], force_complex=False)
        assert tier == "light"
        assert reply == "Paris is the capital of France."

    async def test_llm_classifies_complex_downgraded_to_medium(self):
        # Without /think prefix, complex classification → downgraded to medium
        router = self._make_router("complex")
        tier, reply = await router.route("compare React and Vue", [], force_complex=False)
        assert tier == "medium"
        assert reply is None

    async def test_llm_error_falls_back_to_medium(self):
        model = MagicMock()
        model.ainvoke = AsyncMock(side_effect=Exception("connection error"))
        router = Router(model=model)
        tier, reply = await router.route("some question", [], force_complex=False)
        assert tier == "medium"
        assert reply is None

    async def test_light_reply_empty_falls_back_to_medium(self):
        """If the light reply comes back empty, router returns medium instead."""
        router = self._make_router("light", "")  # empty reply
        tier, reply = await router.route("what is 2+2", [], force_complex=False)
        assert tier == "medium"
        assert reply is None

    async def test_strips_think_tags_from_classification(self):
        """Router strips <think>...</think> from model output before parsing tier."""
        model = MagicMock()
        classify_msg = MagicMock()
        # NOTE(review): the <think> markup was missing from the reviewed text
        # and is reconstructed from the test name — confirm against the commit.
        classify_msg.content = "<think>Hmm let me think...</think>medium"
        reply_msg = MagicMock()
        reply_msg.content = "I'm fine!"
        model.ainvoke = AsyncMock(side_effect=[classify_msg, reply_msg])
        router = Router(model=model)
        tier, _ = await router.route("what is the news?", [], force_complex=False)
        assert tier == "medium"

    async def test_think_prefix_forces_complex(self):
        """/think prefix is already stripped by agent.py; force_complex=True is passed."""
        router = self._make_router("medium")
        tier, reply = await router.route("analyse this", [], force_complex=True)
        assert tier == "complex"
        assert reply is None
"""Unit tests for vram_manager.py — VRAMManager flush/poll/prewarm logic."""

import asyncio
import pytest
from unittest.mock import AsyncMock, MagicMock, patch

from vram_manager import VRAMManager


BASE_URL = "http://localhost:11434"


def _make_manager() -> VRAMManager:
    """Build a VRAMManager pointed at ``BASE_URL``."""
    return VRAMManager(base_url=BASE_URL)


def _bare_client():
    """AsyncMock usable in ``async with``: entering it yields the mock itself."""
    client = AsyncMock()
    client.__aenter__ = AsyncMock(return_value=client)
    client.__aexit__ = AsyncMock(return_value=False)
    return client


def _mock_client(get_response=None, post_response=None):
    """Return a context-manager mock for httpx.AsyncClient.

    ``get`` / ``post`` are replaced only when a response object is supplied.
    """
    client = _bare_client()
    for verb, response in (("get", get_response), ("post", post_response)):
        if response is not None:
            setattr(client, verb, AsyncMock(return_value=response))
    return client


def _raising_client(verb: str, message: str):
    """Context-manager mock whose *verb* method raises ``Exception(message)``."""
    client = _bare_client()
    setattr(client, verb, AsyncMock(side_effect=Exception(message)))
    return client


def _patched(client):
    """Patch ``vram_manager.httpx.AsyncClient`` so constructing it returns *client*."""
    return patch("vram_manager.httpx.AsyncClient", return_value=client)


def _posted_json(client):
    """Extract the JSON body of the awaited ``client.post`` call."""
    _, kwargs = client.post.await_args
    return kwargs.get("json") or client.post.call_args[1].get("json") or client.post.call_args[0][1]


def _ps_response(model_names):
    """Response mock whose ``.json()`` lists *model_names* under ``"models"``."""
    resp = MagicMock()
    resp.json.return_value = {"models": [{"name": model} for model in model_names]}
    return resp


def _first_args(mock_method):
    """Set of the first positional argument of every recorded call."""
    return {recorded.args[0] for recorded in mock_method.call_args_list}


# ── _flush ─────────────────────────────────────────────────────────────────────

class TestFlush:
    """``VRAMManager._flush`` request shape and error tolerance."""

    async def test_sends_keep_alive_zero(self):
        http = _mock_client(post_response=MagicMock())
        with _patched(http):
            manager = _make_manager()
            await manager._flush("qwen3:4b")
            http.post.assert_awaited_once()
            payload = _posted_json(http)
            assert payload["model"] == "qwen3:4b"
            assert payload["keep_alive"] == 0

    async def test_posts_to_correct_endpoint(self):
        http = _mock_client(post_response=MagicMock())
        with _patched(http):
            manager = _make_manager()
            await manager._flush("qwen3:8b")
            target = http.post.call_args.args[0]
            assert target == f"{BASE_URL}/api/generate"

    async def test_ignores_exceptions_silently(self):
        http = _raising_client("post", "connection refused")
        with _patched(http):
            manager = _make_manager()
            # Should not raise
            await manager._flush("qwen3:4b")


# ── _prewarm ───────────────────────────────────────────────────────────────────

class TestPrewarm:
    """``VRAMManager._prewarm`` request shape and error tolerance."""

    async def test_sends_keep_alive_300(self):
        http = _mock_client(post_response=MagicMock())
        with _patched(http):
            manager = _make_manager()
            await manager._prewarm("qwen3:4b")
            payload = _posted_json(http)
            assert payload["keep_alive"] == 300
            assert payload["model"] == "qwen3:4b"

    async def test_ignores_exceptions_silently(self):
        http = _raising_client("post", "timeout")
        with _patched(http):
            manager = _make_manager()
            await manager._prewarm("qwen3:4b")


# ── _poll_evicted ──────────────────────────────────────────────────────────────

class TestPollEvicted:
    """``VRAMManager._poll_evicted`` results for different model listings."""

    async def test_returns_true_when_models_absent(self):
        http = _mock_client(get_response=_ps_response(["some_other_model"]))
        with _patched(http):
            manager = _make_manager()
            evicted = await manager._poll_evicted(["qwen3:4b", "qwen2.5:1.5b"], timeout=5)
            assert evicted is True

    async def test_returns_false_on_timeout_when_model_still_loaded(self):
        http = _mock_client(get_response=_ps_response(["qwen3:4b"]))
        with _patched(http):
            manager = _make_manager()
            evicted = await manager._poll_evicted(["qwen3:4b"], timeout=0.1)
            assert evicted is False

    async def test_returns_true_immediately_if_already_empty(self):
        http = _mock_client(get_response=_ps_response([]))
        with _patched(http):
            manager = _make_manager()
            evicted = await manager._poll_evicted(["qwen3:4b"], timeout=5)
            assert evicted is True

    async def test_handles_poll_error_and_continues(self):
        """If /api/ps errors, polling continues until timeout."""
        http = _raising_client("get", "network error")
        with _patched(http):
            manager = _make_manager()
            evicted = await manager._poll_evicted(["qwen3:4b"], timeout=0.2)
            assert evicted is False


# ── enter_complex_mode / exit_complex_mode ─────────────────────────────────────

class TestComplexMode:
    """Mode switching, with the manager's private steps replaced by mocks."""

    async def test_enter_complex_mode_returns_true_on_success(self):
        manager = _make_manager()
        manager._flush = AsyncMock()
        manager._poll_evicted = AsyncMock(return_value=True)
        assert await manager.enter_complex_mode() is True

    async def test_enter_complex_mode_flushes_medium_models(self):
        manager = _make_manager()
        manager._flush = AsyncMock()
        manager._poll_evicted = AsyncMock(return_value=True)
        await manager.enter_complex_mode()
        flushed = _first_args(manager._flush)
        assert "qwen3:4b" in flushed
        assert "qwen2.5:1.5b" in flushed

    async def test_enter_complex_mode_returns_false_on_eviction_timeout(self):
        manager = _make_manager()
        manager._flush = AsyncMock()
        manager._poll_evicted = AsyncMock(return_value=False)
        assert await manager.enter_complex_mode() is False

    async def test_exit_complex_mode_flushes_complex_and_prewarms_medium(self):
        manager = _make_manager()
        manager._flush = AsyncMock()
        manager._prewarm = AsyncMock()
        await manager.exit_complex_mode()
        # Must flush 8b
        assert "qwen3:8b" in _first_args(manager._flush)
        # Must prewarm medium models
        prewarmed = _first_args(manager._prewarm)
        assert "qwen3:4b" in prewarmed
        assert "qwen2.5:1.5b" in prewarmed