import asyncio
import os

import httpx

OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")


class VRAMManager:
    """Swap Ollama models in and out of VRAM around a "complex" request.

    Before the complex model is used, the medium models are unloaded and
    eviction is confirmed via ``/api/ps``; afterwards the complex model is
    unloaded and the medium models are loaded again. All HTTP calls are
    best-effort: failures are printed, not raised.
    """

    # Models that are resident during normal (non-complex) operation.
    MEDIUM_MODELS = ["qwen3:4b", "qwen2.5:1.5b"]
    # Larger model that requires the medium models to be evicted first.
    COMPLEX_MODEL = "qwen3:8b"

    def __init__(self, base_url: str = OLLAMA_BASE_URL):
        self.base_url = base_url

    async def enter_complex_mode(self) -> bool:
        """Flush the medium models so the complex model can be loaded.

        Returns:
            True once ``/api/ps`` successfully reports none of the medium
            models as loaded; False if that was not confirmed within
            15 seconds (callers are expected to fall back to a medium model).
        """
        print("[vram] enter_complex_mode: flushing medium models", flush=True)
        await asyncio.gather(*(self._flush(m) for m in self.MEDIUM_MODELS))
        ok = await self._poll_evicted(self.MEDIUM_MODELS, timeout=15)
        if ok:
            print(
                "[vram] enter_complex_mode: eviction confirmed, "
                f"loading {self.COMPLEX_MODEL}",
                flush=True,
            )
        else:
            print(
                "[vram] enter_complex_mode: eviction timeout — falling back to medium",
                flush=True,
            )
        return ok

    async def exit_complex_mode(self):
        """Flush the complex model and pre-warm the medium models.

        Intended to run as a background task after a complex reply.
        """
        print(f"[vram] exit_complex_mode: flushing {self.COMPLEX_MODEL}", flush=True)
        await self._flush(self.COMPLEX_MODEL)
        print("[vram] exit_complex_mode: pre-warming medium models", flush=True)
        await asyncio.gather(*(self._prewarm(m) for m in self.MEDIUM_MODELS))
        print("[vram] exit_complex_mode: done", flush=True)

    async def _flush(self, model: str):
        """Send keep_alive=0 to force immediate unload from VRAM.

        Best-effort: transport and HTTP-status errors are printed, not raised.
        """
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                resp = await client.post(
                    f"{self.base_url}/api/generate",
                    json={"model": model, "prompt": "", "keep_alive": 0},
                )
                # Surface server-side failures (e.g. unknown model) in the log
                # instead of silently treating them as a successful flush.
                resp.raise_for_status()
        except Exception as e:
            print(f"[vram] flush {model} error: {e}", flush=True)

    async def _poll_evicted(self, models: list[str], timeout: float) -> bool:
        """Poll ``/api/ps`` until none of *models* appears (or timeout).

        Args:
            models: Model names that must no longer be listed as loaded.
            timeout: Overall polling budget in seconds.

        Returns:
            True if a successful ``/api/ps`` response listed none of
            *models* before the deadline, otherwise False. Failed requests
            and non-2xx responses never count as confirmation.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        targets = set(models)
        # One client for the whole polling loop, so connections are reused.
        async with httpx.AsyncClient(timeout=5.0) as client:
            while loop.time() < deadline:
                try:
                    resp = await client.get(f"{self.base_url}/api/ps")
                    # An error body has no "models" key and would otherwise be
                    # misread as "nothing loaded" -> false eviction success.
                    resp.raise_for_status()
                    data = resp.json()
                    loaded = {m.get("name", "") for m in data.get("models", [])}
                    if targets.isdisjoint(loaded):
                        return True
                except Exception as e:
                    print(f"[vram] poll_evicted error: {e}", flush=True)
                await asyncio.sleep(0.5)
        return False

    async def _prewarm(self, model: str):
        """Load model into VRAM with keep_alive=300 (5 min).

        Best-effort: errors are printed, and the success message is printed
        only when the server answered with a 2xx status.
        """
        try:
            async with httpx.AsyncClient(timeout=60.0) as client:
                resp = await client.post(
                    f"{self.base_url}/api/generate",
                    json={"model": model, "prompt": "", "keep_alive": 300},
                )
                resp.raise_for_status()
            print(f"[vram] pre-warmed {model}", flush=True)
        except Exception as e:
            print(f"[vram] prewarm {model} error: {e}", flush=True)