Add three-tier model routing with VRAM management and benchmark suite
- Three-tier routing: light (router answers directly, ~3 s), medium (qwen3:4b + tools, ~60 s), complex (/think prefix → qwen3:8b + subagents, ~140 s)
- Router: qwen2.5:1.5b, temp=0, regex pre-classifier + raw-text LLM classification
- VRAMManager: explicit flush/poll/prewarm to prevent the Ollama CPU-spill bug
- agent_factory: build_medium_agent and build_complex_agent using deepagents (TodoListMiddleware + SubAgentMiddleware with research/memory subagents)
- Fix: split Telegram replies >4000 chars into multiple messages
- Benchmark: 30 questions (10 easy / 10 medium / 10 hard) — verified easy→light, medium→medium, hard→complex routing with VRAM flush confirmed

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
71
vram_manager.py
Normal file
71
vram_manager.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import asyncio
|
||||
import os
|
||||
import httpx
|
||||
|
||||
# Base URL of the Ollama HTTP API; overridable via the OLLAMA_BASE_URL env var.
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
|
||||
|
||||
|
||||
class VRAMManager:
    """Manage which Ollama models are resident in GPU VRAM.

    The GPU cannot hold the medium-tier models and the complex model at the
    same time, so callers must explicitly flush one tier before loading the
    other.  Explicit eviction (``keep_alive=0``) followed by polling
    ``/api/ps`` works around Ollama silently spilling a model to CPU when
    VRAM is exhausted (per the commit message — TODO confirm against the
    deployed Ollama version).
    """

    # Models kept warm for the light/medium tiers (router + tool agent).
    MEDIUM_MODELS = ["qwen3:4b", "qwen2.5:1.5b"]
    # Large model loaded only for complex requests.
    COMPLEX_MODEL = "qwen3:8b"

    def __init__(self, base_url: str = OLLAMA_BASE_URL):
        """Point the manager at the Ollama server at *base_url*."""
        self.base_url = base_url

    async def enter_complex_mode(self) -> bool:
        """Flush medium models before loading 8b. Returns False if eviction timed out."""
        print("[vram] enter_complex_mode: flushing medium models", flush=True)
        # Flush both medium models concurrently, then confirm eviction via /api/ps.
        await asyncio.gather(*(self._flush(m) for m in self.MEDIUM_MODELS))
        ok = await self._poll_evicted(self.MEDIUM_MODELS, timeout=15)
        if ok:
            print("[vram] enter_complex_mode: eviction confirmed, loading qwen3:8b", flush=True)
        else:
            print("[vram] enter_complex_mode: eviction timeout — falling back to medium", flush=True)
        return ok

    async def exit_complex_mode(self):
        """Flush 8b and pre-warm medium models. Run as background task after complex reply."""
        print("[vram] exit_complex_mode: flushing qwen3:8b", flush=True)
        await self._flush(self.COMPLEX_MODEL)
        print("[vram] exit_complex_mode: pre-warming medium models", flush=True)
        await asyncio.gather(*(self._prewarm(m) for m in self.MEDIUM_MODELS))
        print("[vram] exit_complex_mode: done", flush=True)

    async def _flush(self, model: str):
        """Send keep_alive=0 to force immediate unload from VRAM."""
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                # Empty prompt + keep_alive=0 tells Ollama to unload the model
                # now rather than after its idle timeout.
                await client.post(
                    f"{self.base_url}/api/generate",
                    json={"model": model, "prompt": "", "keep_alive": 0},
                )
        except Exception as e:
            # Best-effort: a failed flush must never crash the request path.
            print(f"[vram] flush {model} error: {e}", flush=True)

    async def _poll_evicted(self, models: list[str], timeout: float) -> bool:
        """Poll /api/ps until none of the given models appear (or timeout)."""
        # Fix: asyncio.get_event_loop() inside a coroutine is deprecated
        # (DeprecationWarning since 3.10); hoist the running loop once.
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while loop.time() < deadline:
            try:
                async with httpx.AsyncClient(timeout=5.0) as client:
                    resp = await client.get(f"{self.base_url}/api/ps")
                    data = resp.json()
                    loaded = {m.get("name", "") for m in data.get("models", [])}
                    if not any(m in loaded for m in models):
                        return True
            except Exception as e:
                # Transient /api/ps failures are logged and retried until deadline.
                print(f"[vram] poll_evicted error: {e}", flush=True)
            await asyncio.sleep(0.5)
        return False

    async def _prewarm(self, model: str):
        """Load model into VRAM with keep_alive=300 (5 min)."""
        try:
            async with httpx.AsyncClient(timeout=60.0) as client:
                # Empty prompt forces a load without generating tokens;
                # keep_alive=300 keeps the model resident for 5 minutes.
                await client.post(
                    f"{self.base_url}/api/generate",
                    json={"model": model, "prompt": "", "keep_alive": 300},
                )
            print(f"[vram] pre-warmed {model}", flush=True)
        except Exception as e:
            # Best-effort: pre-warm failure only costs latency on the next request.
            print(f"[vram] prewarm {model} error: {e}", flush=True)
|
||||
Reference in New Issue
Block a user