Integrate Bifrost LLM gateway, add test suite, implement memory pipeline

- Add Bifrost (maximhq/bifrost) as LLM gateway: all inference routes through
  bifrost:8080/v1 with retry logic and observability; VRAMManager keeps direct
  Ollama access for VRAM flush/prewarm operations
- Switch medium model from qwen3:4b to qwen2.5:1.5b (direct call, no tools)
  via _DirectModel wrapper; complex keeps create_deep_agent with qwen3:8b
- Implement out-of-agent memory pipeline: _retrieve_memories pre-fetches
  relevant context (injected into all tiers), _store_memory runs as background
  task after each reply writing to openmemory/Qdrant
- Add tests/unit/ with 133 tests covering router, channels, vram_manager,
  agent helpers; move integration test to tests/integration/
- Add bifrost-config.json with GPU Ollama (qwen2.5:0.5b/1.5b, qwen3:4b/8b,
  gemma3:4b) and CPU Ollama providers
- Integration test 28/29 pass (only grammy fails — no TELEGRAM_BOT_TOKEN)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Alvis
2026-03-12 13:50:12 +00:00
parent ec45d255f0
commit f9618a9bbf
16 changed files with 1195 additions and 36 deletions

120
agent.py
View File

@@ -10,7 +10,7 @@ from pydantic import BaseModel
import re as _re
import httpx as _httpx
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_community.utilities import SearxSearchWrapper
from langchain_core.tools import Tool
@@ -20,8 +20,12 @@ from router import Router
from agent_factory import build_medium_agent, build_complex_agent
import channels
# Bifrost gateway — all LLM inference goes through here
BIFROST_URL = os.getenv("BIFROST_URL", "http://bifrost:8080/v1")
# Direct Ollama URL — used only by VRAMManager for flush/prewarm/poll
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
ROUTER_MODEL = os.getenv("DEEPAGENTS_ROUTER_MODEL", "qwen2.5:0.5b")
ROUTER_MODEL = os.getenv("DEEPAGENTS_ROUTER_MODEL", "qwen2.5:1.5b")
MEDIUM_MODEL = os.getenv("DEEPAGENTS_MODEL", "qwen3:4b")
COMPLEX_MODEL = os.getenv("DEEPAGENTS_COMPLEX_MODEL", "qwen3:8b")
SEARXNG_URL = os.getenv("SEARXNG_URL", "http://host.docker.internal:11437")
@@ -31,10 +35,12 @@ CRAWL4AI_URL = os.getenv("CRAWL4AI_URL", "http://crawl4ai:11235")
MAX_HISTORY_TURNS = 5
_conversation_buffers: dict[str, list] = {}
# /no_think at the start of the system prompt disables qwen3 chain-of-thought.
# create_deep_agent prepends our system_prompt before BASE_AGENT_PROMPT, so
# /no_think lands at position 0 and is respected by qwen3 models via Ollama.
MEDIUM_SYSTEM_PROMPT = (
"You are a helpful AI assistant. "
"Use web_search for questions about current events or facts you don't know. "
"Reply concisely."
"You are a helpful AI assistant. Reply concisely. "
"If asked to remember a fact or name, simply confirm: 'Got it, I'll remember that.'"
)
COMPLEX_SYSTEM_PROMPT = (
@@ -54,6 +60,8 @@ complex_agent = None
router: Router = None
vram_manager: VRAMManager = None
mcp_client = None
_memory_add_tool = None
_memory_search_tool = None
# GPU mutex: one LLM inference at a time
_reply_semaphore = asyncio.Semaphore(1)
@@ -61,21 +69,34 @@ _reply_semaphore = asyncio.Semaphore(1)
@asynccontextmanager
async def lifespan(app: FastAPI):
global medium_agent, complex_agent, router, vram_manager, mcp_client
global medium_agent, complex_agent, router, vram_manager, mcp_client, \
_memory_add_tool, _memory_search_tool
# Register channel adapters
channels.register_defaults()
# Three model instances
router_model = ChatOllama(
model=ROUTER_MODEL, base_url=OLLAMA_BASE_URL, think=False, num_ctx=4096,
# All three models route through Bifrost → Ollama GPU.
# Bifrost adds retry logic, observability, and failover.
# Model names use provider/model format: Bifrost strips the "ollama/" prefix
# before forwarding to Ollama's /v1/chat/completions endpoint.
router_model = ChatOpenAI(
model=f"ollama/{ROUTER_MODEL}",
base_url=BIFROST_URL,
api_key="dummy",
temperature=0,
timeout=30,
)
medium_model = ChatOllama(
model=MEDIUM_MODEL, base_url=OLLAMA_BASE_URL, think=False, num_ctx=8192
medium_model = ChatOpenAI(
model=f"ollama/{MEDIUM_MODEL}",
base_url=BIFROST_URL,
api_key="dummy",
timeout=180,
)
complex_model = ChatOllama(
model=COMPLEX_MODEL, base_url=OLLAMA_BASE_URL, think=True, num_ctx=16384
complex_model = ChatOpenAI(
model=f"ollama/{COMPLEX_MODEL}",
base_url=BIFROST_URL,
api_key="dummy",
timeout=600,
)
vram_manager = VRAMManager(base_url=OLLAMA_BASE_URL)
@@ -97,6 +118,13 @@ async def lifespan(app: FastAPI):
agent_tools = [t for t in mcp_tools if t.name not in ("add_memory", "search_memory", "get_all_memories")]
# Expose memory tools directly so run_agent_task can call them outside the agent loop
for t in mcp_tools:
if t.name == "add_memory":
_memory_add_tool = t
elif t.name == "search_memory":
_memory_search_tool = t
searx = SearxSearchWrapper(searx_host=SEARXNG_URL)
def _crawl4ai_fetch(url: str) -> str:
@@ -187,7 +215,8 @@ async def lifespan(app: FastAPI):
)
print(
f"[agent] three-tier: router={ROUTER_MODEL} | medium={MEDIUM_MODEL} | complex={COMPLEX_MODEL}",
f"[agent] bifrost={BIFROST_URL} | router=ollama/{ROUTER_MODEL} | "
f"medium=ollama/{MEDIUM_MODEL} | complex=ollama/{COMPLEX_MODEL}",
flush=True,
)
print(f"[agent] agent tools: {[t.name for t in agent_tools]}", flush=True)
@@ -222,13 +251,19 @@ class ChatRequest(BaseModel):
# ── helpers ────────────────────────────────────────────────────────────────────
def _strip_think(text: str) -> str:
"""Strip qwen3 chain-of-thought blocks that appear inline in content
when using Ollama's OpenAI-compatible endpoint (/v1/chat/completions)."""
return _re.sub(r"<think>.*?</think>", "", text, flags=_re.DOTALL).strip()
def _extract_final_text(result) -> str | None:
msgs = result.get("messages", [])
for m in reversed(msgs):
if type(m).__name__ == "AIMessage" and getattr(m, "content", ""):
return m.content
return _strip_think(m.content)
if isinstance(result, dict) and result.get("output"):
return result["output"]
return _strip_think(result["output"])
return None
@@ -244,6 +279,34 @@ def _log_messages(result):
print(f"[agent] {role}{tc['name']}({tc['args']})", flush=True)
# ── memory helpers ─────────────────────────────────────────────────────────────
async def _store_memory(session_id: str, user_msg: str, assistant_reply: str) -> None:
"""Store a conversation turn in openmemory (runs as a background task)."""
if _memory_add_tool is None:
return
t0 = time.monotonic()
try:
text = f"User: {user_msg}\nAssistant: {assistant_reply}"
await _memory_add_tool.ainvoke({"text": text, "user_id": session_id})
print(f"[memory] stored in {time.monotonic() - t0:.1f}s", flush=True)
except Exception as e:
print(f"[memory] error: {e}", flush=True)
async def _retrieve_memories(message: str, session_id: str) -> str:
"""Search openmemory for relevant context. Returns formatted string or ''."""
if _memory_search_tool is None:
return ""
try:
result = await _memory_search_tool.ainvoke({"query": message, "user_id": session_id})
if result and result.strip() and result.strip() != "[]":
return f"Relevant memories:\n{result}"
except Exception:
pass
return ""
# ── core task ──────────────────────────────────────────────────────────────────
async def run_agent_task(message: str, session_id: str, channel: str = "telegram"):
@@ -261,7 +324,13 @@ async def run_agent_task(message: str, session_id: str, channel: str = "telegram
history = _conversation_buffers.get(session_id, [])
print(f"[agent] running: {clean_message[:80]!r}", flush=True)
tier, light_reply = await router.route(clean_message, history, force_complex)
# Retrieve memories once; inject into history so ALL tiers can use them
memories = await _retrieve_memories(clean_message, session_id)
enriched_history = (
[{"role": "system", "content": memories}] + history if memories else history
)
tier, light_reply = await router.route(clean_message, enriched_history, force_complex)
print(f"[agent] tier={tier} message={clean_message[:60]!r}", flush=True)
final_text = None
@@ -273,6 +342,8 @@ async def run_agent_task(message: str, session_id: str, channel: str = "telegram
elif tier == "medium":
system_prompt = MEDIUM_SYSTEM_PROMPT
if memories:
system_prompt = system_prompt + "\n\n" + memories
result = await medium_agent.ainvoke({
"messages": [
{"role": "system", "content": system_prompt},
@@ -289,9 +360,12 @@ async def run_agent_task(message: str, session_id: str, channel: str = "telegram
if not ok:
print("[agent] complex→medium fallback (eviction timeout)", flush=True)
tier = "medium"
system_prompt = MEDIUM_SYSTEM_PROMPT
if memories:
system_prompt = system_prompt + "\n\n" + memories
result = await medium_agent.ainvoke({
"messages": [
{"role": "system", "content": MEDIUM_SYSTEM_PROMPT},
{"role": "system", "content": system_prompt},
*history,
{"role": "user", "content": clean_message},
]
@@ -320,7 +394,10 @@ async def run_agent_task(message: str, session_id: str, channel: str = "telegram
# Deliver reply through the originating channel
if final_text:
t1 = time.monotonic()
await channels.deliver(session_id, channel, final_text)
try:
await channels.deliver(session_id, channel, final_text)
except Exception as e:
print(f"[agent] delivery error (non-fatal): {e}", flush=True)
send_elapsed = time.monotonic() - t1
print(
f"[agent] replied in {time.monotonic() - t0:.1f}s "
@@ -331,12 +408,13 @@ async def run_agent_task(message: str, session_id: str, channel: str = "telegram
else:
print("[agent] warning: no text reply from agent", flush=True)
# Update conversation buffer
# Update conversation buffer and schedule memory storage
if final_text:
buf = _conversation_buffers.get(session_id, [])
buf.append({"role": "user", "content": clean_message})
buf.append({"role": "assistant", "content": final_text})
_conversation_buffers[session_id] = buf[-(MAX_HISTORY_TURNS * 2):]
asyncio.create_task(_store_memory(session_id, clean_message, final_text))
# ── endpoints ──────────────────────────────────────────────────────────────────
@@ -374,7 +452,7 @@ async def reply_stream(session_id: str):
try:
text = await asyncio.wait_for(q.get(), timeout=900)
# Escape newlines so entire reply fits in one SSE data line
yield f"data: {text.replace(chr(10), '\\n').replace(chr(13), '')}\n\n"
yield f"data: {text.replace(chr(10), chr(92) + 'n').replace(chr(13), '')}\n\n"
except asyncio.TimeoutError:
yield "data: [timeout]\n\n"