Add three-tier model routing with VRAM management and benchmark suite

- Three-tier routing: light (router answers directly ~3s), medium (qwen3:4b
  + tools ~60s), complex (/think prefix → qwen3:8b + subagents ~140s)
- Router: qwen2.5:1.5b, temp=0, regex pre-classifier + raw-text LLM classify
- VRAMManager: explicit flush/poll/prewarm to prevent Ollama CPU-spill bug
- agent_factory: build_medium_agent and build_complex_agent using deepagents
  (TodoListMiddleware + SubAgentMiddleware with research/memory subagents)
- Fix: split Telegram replies >4000 chars into multiple messages
- Benchmark: 30 questions (easy/medium/hard) — 10/10/10 verified passing
  easy→light, medium→medium, hard→complex with VRAM flush confirmed

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Alvis
2026-02-28 17:54:51 +00:00
parent ff20f8942d
commit 09a93c661e
8 changed files with 1400 additions and 308 deletions

View File

@@ -11,15 +11,23 @@ from langchain_ollama import ChatOllama
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_community.utilities import SearxSearchWrapper
from langchain_core.tools import Tool
from langgraph.prebuilt import create_react_agent
from vram_manager import VRAMManager
from router import Router
from agent_factory import build_medium_agent, build_complex_agent
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
MODEL = os.getenv("DEEPAGENTS_MODEL", "qwen3:8b")
ROUTER_MODEL = os.getenv("DEEPAGENTS_ROUTER_MODEL", "qwen2.5:0.5b")
MEDIUM_MODEL = os.getenv("DEEPAGENTS_MODEL", "qwen3:4b")
COMPLEX_MODEL = os.getenv("DEEPAGENTS_COMPLEX_MODEL", "qwen3:8b")
SEARXNG_URL = os.getenv("SEARXNG_URL", "http://host.docker.internal:11437")
OPENMEMORY_URL = os.getenv("OPENMEMORY_URL", "http://openmemory:8765")
GRAMMY_URL = os.getenv("GRAMMY_URL", "http://grammy:3001")
SYSTEM_PROMPT_TEMPLATE = (
MAX_HISTORY_TURNS = 5
_conversation_buffers: dict[str, list] = {}
MEDIUM_SYSTEM_PROMPT = (
"You are a helpful AI assistant talking to a user via Telegram. "
"The user's ID is {user_id}. "
"IMPORTANT: When calling any memory tool (search_memory, get_all_memories), "
@@ -28,33 +36,62 @@ SYSTEM_PROMPT_TEMPLATE = (
"you do NOT need to explicitly store anything. "
"NEVER tell the user you cannot remember or store information. "
"If the user asks you to remember something, acknowledge it and confirm it will be remembered. "
"Always call search_memory before answering to recall relevant past context. "
"Use web_search for questions about current events. "
"Use search_memory when context from past conversations may be relevant. "
"Use web_search for questions about current events or facts you don't know. "
"Reply concisely."
)
agent = None
COMPLEX_SYSTEM_PROMPT = (
"You are a capable AI assistant tackling a complex, multi-step task for a Telegram user. "
"The user's ID is {user_id}. "
"IMPORTANT: When calling any memory tool (search_memory, get_all_memories), "
"always use user_id=\"{user_id}\". "
"Plan your work using write_todos before diving in. "
"Delegate: use the 'research' subagent for thorough web research across multiple queries, "
"and the 'memory' subagent to gather comprehensive context from past conversations. "
"Every conversation is automatically saved to memory — you do NOT need to store anything. "
"NEVER tell the user you cannot remember or store information. "
"Produce a thorough, well-structured reply."
)
medium_agent = None
complex_agent = None
router: Router = None
vram_manager: VRAMManager = None
mcp_client = None
send_tool = None
add_memory_tool = None
# GPU semaphore: one LLM inference at a time
# GPU mutex: one LLM inference at a time
_reply_semaphore = asyncio.Semaphore(1)
# CPU semaphore: one memory store at a time (runs on CPU Ollama, no GPU contention)
# Memory semaphore: one async extraction at a time
_memory_semaphore = asyncio.Semaphore(1)
@asynccontextmanager
async def lifespan(app: FastAPI):
global agent, mcp_client, send_tool, add_memory_tool
global medium_agent, complex_agent, router, vram_manager
global mcp_client, send_tool, add_memory_tool
model = ChatOllama(model=MODEL, base_url=OLLAMA_BASE_URL, think=False, num_ctx=8192)
# Three model instances
router_model = ChatOllama(
model=ROUTER_MODEL, base_url=OLLAMA_BASE_URL, think=False, num_ctx=4096,
temperature=0, # deterministic classification
)
medium_model = ChatOllama(
model=MEDIUM_MODEL, base_url=OLLAMA_BASE_URL, think=False, num_ctx=8192
)
complex_model = ChatOllama(
model=COMPLEX_MODEL, base_url=OLLAMA_BASE_URL, think=True, num_ctx=16384
)
vram_manager = VRAMManager(base_url=OLLAMA_BASE_URL)
router = Router(model=router_model)
mcp_connections = {
"openmemory": {"transport": "sse", "url": f"{OPENMEMORY_URL}/sse"},
"grammy": {"transport": "sse", "url": f"{GRAMMY_URL}/sse"},
}
mcp_client = MultiServerMCPClient(mcp_connections)
for attempt in range(12):
try:
@@ -66,10 +103,8 @@ async def lifespan(app: FastAPI):
print(f"[agent] MCP not ready (attempt {attempt + 1}/12): {e}. Retrying in 5s...")
await asyncio.sleep(5)
# Split tools: send is called by us, add_memory runs async after reply
send_tool = next((t for t in mcp_tools if t.name == "send_telegram_message"), None)
add_memory_tool = next((t for t in mcp_tools if t.name == "add_memory"), None)
# Agent only gets read/search tools — no add_memory (would block reply)
agent_tools = [t for t in mcp_tools if t.name not in ("send_telegram_message", "add_memory")]
searx = SearxSearchWrapper(searx_host=SEARXNG_URL)
@@ -79,13 +114,30 @@ async def lifespan(app: FastAPI):
description="Search the web for current information",
))
agent = create_react_agent(model, agent_tools)
print(f"[agent] ready — agent tools: {[t.name for t in agent_tools]}", flush=True)
print(f"[agent] async memory: add_memory via CPU Ollama (qwen2.5:1.5b + nomic-embed-text)", flush=True)
# Build agents (system_prompt filled per-request with user_id)
medium_agent = build_medium_agent(
model=medium_model,
agent_tools=agent_tools,
system_prompt=MEDIUM_SYSTEM_PROMPT.format(user_id="{user_id}"),
)
complex_agent = build_complex_agent(
model=complex_model,
agent_tools=agent_tools,
system_prompt=COMPLEX_SYSTEM_PROMPT.format(user_id="{user_id}"),
)
print(
f"[agent] three-tier: router={ROUTER_MODEL} | medium={MEDIUM_MODEL} | complex={COMPLEX_MODEL}",
flush=True,
)
print(f"[agent] agent tools: {[t.name for t in agent_tools]}", flush=True)
yield
agent = None
medium_agent = None
complex_agent = None
router = None
vram_manager = None
mcp_client = None
send_tool = None
add_memory_tool = None
@@ -100,7 +152,13 @@ class ChatRequest(BaseModel):
async def store_memory_async(conversation: str, user_id: str):
"""Fire-and-forget: extract and store memories using CPU Ollama. Never blocks replies."""
"""Fire-and-forget: extract and store memories after GPU is free."""
t_wait = time.monotonic()
while _reply_semaphore.locked():
if time.monotonic() - t_wait > 60:
print(f"[memory] spin-wait timeout 60s, proceeding for user {user_id}", flush=True)
break
await asyncio.sleep(0.5)
async with _memory_semaphore:
t0 = time.monotonic()
try:
@@ -110,60 +168,137 @@ async def store_memory_async(conversation: str, user_id: str):
print(f"[memory] error after {time.monotonic() - t0:.1f}s: {e}", flush=True)
def _extract_final_text(result) -> str | None:
"""Extract last AIMessage content from agent result."""
msgs = result.get("messages", [])
for m in reversed(msgs):
if type(m).__name__ == "AIMessage" and getattr(m, "content", ""):
return m.content
# deepagents may return output differently
if isinstance(result, dict) and result.get("output"):
return result["output"]
return None
def _log_messages(result):
msgs = result.get("messages", [])
for m in msgs:
role = type(m).__name__
content = getattr(m, "content", "")
tool_calls = getattr(m, "tool_calls", [])
if content:
print(f"[agent] {role}: {str(content)[:150]}", flush=True)
for tc in tool_calls:
print(f"[agent] {role}{tc['name']}({tc['args']})", flush=True)
async def run_agent_task(message: str, chat_id: str):
print(f"[agent] queued: {message[:80]!r} chat={chat_id}", flush=True)
# Pre-check: /think prefix forces complex tier
force_complex = False
clean_message = message
if message.startswith("/think "):
force_complex = True
clean_message = message[len("/think "):]
print("[agent] /think prefix → force_complex=True", flush=True)
async with _reply_semaphore:
t0 = time.monotonic()
print(f"[agent] running: {message[:80]!r}", flush=True)
history = _conversation_buffers.get(chat_id, [])
print(f"[agent] running: {clean_message[:80]!r}", flush=True)
# Route the message
tier, light_reply = await router.route(clean_message, history, force_complex)
print(f"[agent] tier={tier} message={clean_message[:60]!r}", flush=True)
final_text = None
try:
system_prompt = SYSTEM_PROMPT_TEMPLATE.format(user_id=chat_id)
result = await agent.ainvoke(
{"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": message},
]}
)
llm_elapsed = time.monotonic() - t0
if tier == "light":
final_text = light_reply
llm_elapsed = time.monotonic() - t0
print(f"[agent] light path: answered by router", flush=True)
# Log trace
msgs = result.get("messages", [])
for m in msgs:
role = type(m).__name__
content = getattr(m, "content", "")
tool_calls = getattr(m, "tool_calls", [])
if content:
print(f"[agent] {role}: {str(content)[:150]}", flush=True)
for tc in tool_calls:
print(f"[agent] {role}{tc['name']}({tc['args']})", flush=True)
elif tier == "medium":
system_prompt = MEDIUM_SYSTEM_PROMPT.format(user_id=chat_id)
result = await medium_agent.ainvoke({
"messages": [
{"role": "system", "content": system_prompt},
*history,
{"role": "user", "content": clean_message},
]
})
llm_elapsed = time.monotonic() - t0
_log_messages(result)
final_text = _extract_final_text(result)
# Send reply immediately
final_text = None
for m in reversed(msgs):
if type(m).__name__ == "AIMessage" and getattr(m, "content", ""):
final_text = m.content
break
else: # complex
ok = await vram_manager.enter_complex_mode()
if not ok:
print("[agent] complex→medium fallback (eviction timeout)", flush=True)
tier = "medium"
system_prompt = MEDIUM_SYSTEM_PROMPT.format(user_id=chat_id)
result = await medium_agent.ainvoke({
"messages": [
{"role": "system", "content": system_prompt},
*history,
{"role": "user", "content": clean_message},
]
})
else:
system_prompt = COMPLEX_SYSTEM_PROMPT.format(user_id=chat_id)
result = await complex_agent.ainvoke({
"messages": [
{"role": "system", "content": system_prompt},
*history,
{"role": "user", "content": clean_message},
]
})
asyncio.create_task(vram_manager.exit_complex_mode())
if final_text and send_tool:
t1 = time.monotonic()
await send_tool.ainvoke({"chat_id": chat_id, "text": final_text})
print(f"[agent] replied in {time.monotonic() - t0:.1f}s (llm={llm_elapsed:.1f}s, send={time.monotonic()-t1:.1f}s)", flush=True)
elif not final_text:
print(f"[agent] warning: no text reply from agent", flush=True)
# Async memoization: runs on CPU Ollama, does not block next reply
if add_memory_tool and final_text:
conversation = f"User: {message}\nAssistant: {final_text}"
asyncio.create_task(store_memory_async(conversation, chat_id))
llm_elapsed = time.monotonic() - t0
_log_messages(result)
final_text = _extract_final_text(result)
except Exception as e:
import traceback
print(f"[agent] error after {time.monotonic()-t0:.1f}s for chat {chat_id}: {e}", flush=True)
llm_elapsed = time.monotonic() - t0
print(f"[agent] error after {llm_elapsed:.1f}s for chat {chat_id}: {e}", flush=True)
traceback.print_exc()
# Send reply via grammy MCP (split if > Telegram's 4096-char limit)
if final_text and send_tool:
t1 = time.monotonic()
MAX_TG = 4000 # leave headroom below the 4096 hard limit
chunks = [final_text[i:i + MAX_TG] for i in range(0, len(final_text), MAX_TG)]
for chunk in chunks:
await send_tool.ainvoke({"chat_id": chat_id, "text": chunk})
send_elapsed = time.monotonic() - t1
# Log in format compatible with test_pipeline.py parser
print(
f"[agent] replied in {time.monotonic() - t0:.1f}s "
f"(llm={llm_elapsed:.1f}s, send={send_elapsed:.1f}s) tier={tier}",
flush=True,
)
elif not final_text:
print("[agent] warning: no text reply from agent", flush=True)
# Update conversation buffer
if final_text:
buf = _conversation_buffers.get(chat_id, [])
buf.append({"role": "user", "content": clean_message})
buf.append({"role": "assistant", "content": final_text})
_conversation_buffers[chat_id] = buf[-(MAX_HISTORY_TURNS * 2):]
# Async memory storage (fire-and-forget)
if add_memory_tool and final_text:
conversation = f"User: {clean_message}\nAssistant: {final_text}"
asyncio.create_task(store_memory_async(conversation, chat_id))
@app.post("/chat")
async def chat(request: ChatRequest, background_tasks: BackgroundTasks):
if agent is None:
if medium_agent is None:
return JSONResponse(status_code=503, content={"error": "Agent not ready"})
background_tasks.add_task(run_agent_task, request.message, request.chat_id)
return JSONResponse(status_code=202, content={"status": "accepted"})
@@ -171,4 +306,4 @@ async def chat(request: ChatRequest, background_tasks: BackgroundTasks):
@app.get("/health")
async def health():
return {"status": "ok", "agent_ready": agent is not None}
return {"status": "ok", "agent_ready": medium_agent is not None}