Integrate Bifrost LLM gateway, add test suite, implement memory pipeline
- Add Bifrost (maximhq/bifrost) as the LLM gateway: all inference routes through bifrost:8080/v1 with retry logic and observability; VRAMManager keeps direct Ollama access for VRAM flush/prewarm operations
- Switch the medium model from qwen3:4b to qwen2.5:1.5b (direct call, no tools) via the _DirectModel wrapper; the complex tier keeps create_deep_agent with qwen3:8b
- Implement an out-of-agent memory pipeline: _retrieve_memories pre-fetches relevant context (injected into all tiers), and _store_memory runs as a background task after each reply, writing to openmemory/Qdrant
- Add tests/unit/ with 133 tests covering router, channels, vram_manager, and agent helpers; move the integration test to tests/integration/
- Add bifrost-config.json with GPU Ollama (qwen2.5:0.5b/1.5b, qwen3:4b/8b, gemma3:4b) and CPU Ollama providers
- Integration tests: 28/29 pass (only grammy fails — no TELEGRAM_BOT_TOKEN)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
120
agent.py
120
agent.py
@@ -10,7 +10,7 @@ from pydantic import BaseModel
|
||||
import re as _re
|
||||
import httpx as _httpx
|
||||
|
||||
from langchain_ollama import ChatOllama
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||
from langchain_community.utilities import SearxSearchWrapper
|
||||
from langchain_core.tools import Tool
|
||||
@@ -20,8 +20,12 @@ from router import Router
|
||||
from agent_factory import build_medium_agent, build_complex_agent
|
||||
import channels
|
||||
|
||||
# Bifrost gateway — all LLM inference goes through here
|
||||
BIFROST_URL = os.getenv("BIFROST_URL", "http://bifrost:8080/v1")
|
||||
# Direct Ollama URL — used only by VRAMManager for flush/prewarm/poll
|
||||
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
|
||||
ROUTER_MODEL = os.getenv("DEEPAGENTS_ROUTER_MODEL", "qwen2.5:0.5b")
|
||||
|
||||
ROUTER_MODEL = os.getenv("DEEPAGENTS_ROUTER_MODEL", "qwen2.5:1.5b")
|
||||
MEDIUM_MODEL = os.getenv("DEEPAGENTS_MODEL", "qwen3:4b")
|
||||
COMPLEX_MODEL = os.getenv("DEEPAGENTS_COMPLEX_MODEL", "qwen3:8b")
|
||||
SEARXNG_URL = os.getenv("SEARXNG_URL", "http://host.docker.internal:11437")
|
||||
@@ -31,10 +35,12 @@ CRAWL4AI_URL = os.getenv("CRAWL4AI_URL", "http://crawl4ai:11235")
|
||||
MAX_HISTORY_TURNS = 5
|
||||
_conversation_buffers: dict[str, list] = {}
|
||||
|
||||
# /no_think at the start of the system prompt disables qwen3 chain-of-thought.
|
||||
# create_deep_agent prepends our system_prompt before BASE_AGENT_PROMPT, so
|
||||
# /no_think lands at position 0 and is respected by qwen3 models via Ollama.
|
||||
MEDIUM_SYSTEM_PROMPT = (
|
||||
"You are a helpful AI assistant. "
|
||||
"Use web_search for questions about current events or facts you don't know. "
|
||||
"Reply concisely."
|
||||
"You are a helpful AI assistant. Reply concisely. "
|
||||
"If asked to remember a fact or name, simply confirm: 'Got it, I'll remember that.'"
|
||||
)
|
||||
|
||||
COMPLEX_SYSTEM_PROMPT = (
|
||||
@@ -54,6 +60,8 @@ complex_agent = None
|
||||
router: Router = None
|
||||
vram_manager: VRAMManager = None
|
||||
mcp_client = None
|
||||
_memory_add_tool = None
|
||||
_memory_search_tool = None
|
||||
|
||||
# GPU mutex: one LLM inference at a time
|
||||
_reply_semaphore = asyncio.Semaphore(1)
|
||||
@@ -61,21 +69,34 @@ _reply_semaphore = asyncio.Semaphore(1)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
global medium_agent, complex_agent, router, vram_manager, mcp_client
|
||||
global medium_agent, complex_agent, router, vram_manager, mcp_client, \
|
||||
_memory_add_tool, _memory_search_tool
|
||||
|
||||
# Register channel adapters
|
||||
channels.register_defaults()
|
||||
|
||||
# Three model instances
|
||||
router_model = ChatOllama(
|
||||
model=ROUTER_MODEL, base_url=OLLAMA_BASE_URL, think=False, num_ctx=4096,
|
||||
# All three models route through Bifrost → Ollama GPU.
|
||||
# Bifrost adds retry logic, observability, and failover.
|
||||
# Model names use provider/model format: Bifrost strips the "ollama/" prefix
|
||||
# before forwarding to Ollama's /v1/chat/completions endpoint.
|
||||
router_model = ChatOpenAI(
|
||||
model=f"ollama/{ROUTER_MODEL}",
|
||||
base_url=BIFROST_URL,
|
||||
api_key="dummy",
|
||||
temperature=0,
|
||||
timeout=30,
|
||||
)
|
||||
medium_model = ChatOllama(
|
||||
model=MEDIUM_MODEL, base_url=OLLAMA_BASE_URL, think=False, num_ctx=8192
|
||||
medium_model = ChatOpenAI(
|
||||
model=f"ollama/{MEDIUM_MODEL}",
|
||||
base_url=BIFROST_URL,
|
||||
api_key="dummy",
|
||||
timeout=180,
|
||||
)
|
||||
complex_model = ChatOllama(
|
||||
model=COMPLEX_MODEL, base_url=OLLAMA_BASE_URL, think=True, num_ctx=16384
|
||||
complex_model = ChatOpenAI(
|
||||
model=f"ollama/{COMPLEX_MODEL}",
|
||||
base_url=BIFROST_URL,
|
||||
api_key="dummy",
|
||||
timeout=600,
|
||||
)
|
||||
|
||||
vram_manager = VRAMManager(base_url=OLLAMA_BASE_URL)
|
||||
@@ -97,6 +118,13 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
agent_tools = [t for t in mcp_tools if t.name not in ("add_memory", "search_memory", "get_all_memories")]
|
||||
|
||||
# Expose memory tools directly so run_agent_task can call them outside the agent loop
|
||||
for t in mcp_tools:
|
||||
if t.name == "add_memory":
|
||||
_memory_add_tool = t
|
||||
elif t.name == "search_memory":
|
||||
_memory_search_tool = t
|
||||
|
||||
searx = SearxSearchWrapper(searx_host=SEARXNG_URL)
|
||||
|
||||
def _crawl4ai_fetch(url: str) -> str:
|
||||
@@ -187,7 +215,8 @@ async def lifespan(app: FastAPI):
|
||||
)
|
||||
|
||||
print(
|
||||
f"[agent] three-tier: router={ROUTER_MODEL} | medium={MEDIUM_MODEL} | complex={COMPLEX_MODEL}",
|
||||
f"[agent] bifrost={BIFROST_URL} | router=ollama/{ROUTER_MODEL} | "
|
||||
f"medium=ollama/{MEDIUM_MODEL} | complex=ollama/{COMPLEX_MODEL}",
|
||||
flush=True,
|
||||
)
|
||||
print(f"[agent] agent tools: {[t.name for t in agent_tools]}", flush=True)
|
||||
@@ -222,13 +251,19 @@ class ChatRequest(BaseModel):
|
||||
|
||||
# ── helpers ────────────────────────────────────────────────────────────────────
|
||||
|
||||
def _strip_think(text: str) -> str:
    """Remove inline ``<think>...</think>`` spans from model output.

    qwen3 chain-of-thought appears inside the message content when the reply
    comes from Ollama's OpenAI-compatible endpoint (/v1/chat/completions).
    Every closed think block is dropped (non-greedy match that spans
    newlines) and the remainder is returned with surrounding whitespace
    trimmed. An opening tag with no closing tag is left untouched.
    """
    think_block = _re.compile(r"<think>.*?</think>", _re.DOTALL)
    visible = think_block.sub("", text)
    return visible.strip()
|
||||
|
||||
|
||||
def _extract_final_text(result) -> str | None:
|
||||
msgs = result.get("messages", [])
|
||||
for m in reversed(msgs):
|
||||
if type(m).__name__ == "AIMessage" and getattr(m, "content", ""):
|
||||
return m.content
|
||||
return _strip_think(m.content)
|
||||
if isinstance(result, dict) and result.get("output"):
|
||||
return result["output"]
|
||||
return _strip_think(result["output"])
|
||||
return None
|
||||
|
||||
|
||||
@@ -244,6 +279,34 @@ def _log_messages(result):
|
||||
print(f"[agent] {role} → {tc['name']}({tc['args']})", flush=True)
|
||||
|
||||
|
||||
# ── memory helpers ─────────────────────────────────────────────────────────────
|
||||
|
||||
async def _store_memory(session_id: str, user_msg: str, assistant_reply: str) -> None:
    """Persist one user/assistant exchange through the memory "add" tool.

    Intended to run as a background task. Returns immediately when no add
    tool has been registered. Any failure is printed to stdout and swallowed,
    so nothing is raised to the caller.

    Args:
        session_id: Sent to the tool as ``user_id``.
        user_msg: The user's message for this turn.
        assistant_reply: The reply that was produced for ``user_msg``.
    """
    tool = _memory_add_tool
    if tool is None:
        return

    started = time.monotonic()
    try:
        # Build and send inside the try so every failure takes the same
        # logged, non-fatal path.
        payload = {
            "text": f"User: {user_msg}\nAssistant: {assistant_reply}",
            "user_id": session_id,
        }
        await tool.ainvoke(payload)
        print(f"[memory] stored in {time.monotonic() - started:.1f}s", flush=True)
    except Exception as e:
        print(f"[memory] error: {e}", flush=True)
|
||||
|
||||
|
||||
async def _retrieve_memories(message: str, session_id: str) -> str:
    """Search the memory store for context relevant to ``message``.

    Calls the registered memory search tool with ``message`` as the query and
    ``session_id`` as ``user_id``.

    Args:
        message: Text used as the search query.
        session_id: Sent to the tool as ``user_id``.

    Returns:
        The tool output prefixed with a ``Relevant memories:`` header line
        when the tool returns a non-blank string other than ``"[]"``.
        Returns ``""`` when no search tool is registered, the result is
        empty, or the lookup fails.

    Retrieval is best-effort: errors are printed to stdout and never raised.
    """
    if _memory_search_tool is None:
        return ""
    try:
        result = await _memory_search_tool.ainvoke({"query": message, "user_id": session_id})
        # NOTE(review): assumes the tool returns a string. Any other type
        # raises on .strip() and is reported by the handler below. Confirm
        # the tool adapter's return shape.
        stripped = result.strip() if result else ""
        if stripped and stripped != "[]":
            return f"Relevant memories:\n{result}"
    except Exception as e:
        # Log instead of silently discarding the failure, in the same
        # "[memory]" style that _store_memory uses for its errors.
        print(f"[memory] search error: {e}", flush=True)
    return ""
|
||||
|
||||
|
||||
# ── core task ──────────────────────────────────────────────────────────────────
|
||||
|
||||
async def run_agent_task(message: str, session_id: str, channel: str = "telegram"):
|
||||
@@ -261,7 +324,13 @@ async def run_agent_task(message: str, session_id: str, channel: str = "telegram
|
||||
history = _conversation_buffers.get(session_id, [])
|
||||
print(f"[agent] running: {clean_message[:80]!r}", flush=True)
|
||||
|
||||
tier, light_reply = await router.route(clean_message, history, force_complex)
|
||||
# Retrieve memories once; inject into history so ALL tiers can use them
|
||||
memories = await _retrieve_memories(clean_message, session_id)
|
||||
enriched_history = (
|
||||
[{"role": "system", "content": memories}] + history if memories else history
|
||||
)
|
||||
|
||||
tier, light_reply = await router.route(clean_message, enriched_history, force_complex)
|
||||
print(f"[agent] tier={tier} message={clean_message[:60]!r}", flush=True)
|
||||
|
||||
final_text = None
|
||||
@@ -273,6 +342,8 @@ async def run_agent_task(message: str, session_id: str, channel: str = "telegram
|
||||
|
||||
elif tier == "medium":
|
||||
system_prompt = MEDIUM_SYSTEM_PROMPT
|
||||
if memories:
|
||||
system_prompt = system_prompt + "\n\n" + memories
|
||||
result = await medium_agent.ainvoke({
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
@@ -289,9 +360,12 @@ async def run_agent_task(message: str, session_id: str, channel: str = "telegram
|
||||
if not ok:
|
||||
print("[agent] complex→medium fallback (eviction timeout)", flush=True)
|
||||
tier = "medium"
|
||||
system_prompt = MEDIUM_SYSTEM_PROMPT
|
||||
if memories:
|
||||
system_prompt = system_prompt + "\n\n" + memories
|
||||
result = await medium_agent.ainvoke({
|
||||
"messages": [
|
||||
{"role": "system", "content": MEDIUM_SYSTEM_PROMPT},
|
||||
{"role": "system", "content": system_prompt},
|
||||
*history,
|
||||
{"role": "user", "content": clean_message},
|
||||
]
|
||||
@@ -320,7 +394,10 @@ async def run_agent_task(message: str, session_id: str, channel: str = "telegram
|
||||
# Deliver reply through the originating channel
|
||||
if final_text:
|
||||
t1 = time.monotonic()
|
||||
await channels.deliver(session_id, channel, final_text)
|
||||
try:
|
||||
await channels.deliver(session_id, channel, final_text)
|
||||
except Exception as e:
|
||||
print(f"[agent] delivery error (non-fatal): {e}", flush=True)
|
||||
send_elapsed = time.monotonic() - t1
|
||||
print(
|
||||
f"[agent] replied in {time.monotonic() - t0:.1f}s "
|
||||
@@ -331,12 +408,13 @@ async def run_agent_task(message: str, session_id: str, channel: str = "telegram
|
||||
else:
|
||||
print("[agent] warning: no text reply from agent", flush=True)
|
||||
|
||||
# Update conversation buffer
|
||||
# Update conversation buffer and schedule memory storage
|
||||
if final_text:
|
||||
buf = _conversation_buffers.get(session_id, [])
|
||||
buf.append({"role": "user", "content": clean_message})
|
||||
buf.append({"role": "assistant", "content": final_text})
|
||||
_conversation_buffers[session_id] = buf[-(MAX_HISTORY_TURNS * 2):]
|
||||
asyncio.create_task(_store_memory(session_id, clean_message, final_text))
|
||||
|
||||
|
||||
# ── endpoints ──────────────────────────────────────────────────────────────────
|
||||
@@ -374,7 +452,7 @@ async def reply_stream(session_id: str):
|
||||
try:
|
||||
text = await asyncio.wait_for(q.get(), timeout=900)
|
||||
# Escape newlines so entire reply fits in one SSE data line
|
||||
yield f"data: {text.replace(chr(10), '\\n').replace(chr(13), '')}\n\n"
|
||||
yield f"data: {text.replace(chr(10), chr(92) + 'n').replace(chr(13), '')}\n\n"
|
||||
except asyncio.TimeoutError:
|
||||
yield "data: [timeout]\n\n"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user