- Add tier_capture param to _run_agent_pipeline; append tier after determination - Capture actual_tier in run_agent_task from tier_capture list - Log tier in replied-in line: [agent] replied in Xs tier=Y - Remove reply_text[:200] truncation (was breaking benchmark keyword matching) - Update parse_run_block regex to match new log format; llm/send fields now None Fixes #1, #3, #4 Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
815 lines
31 KiB
Python
815 lines
31 KiB
Python
import asyncio
|
|
import json as _json_module
|
|
import os
|
|
import time
|
|
from contextlib import asynccontextmanager
|
|
from pathlib import Path
|
|
|
|
from fastapi import FastAPI, BackgroundTasks, Request
|
|
from fastapi.responses import JSONResponse, StreamingResponse
|
|
from pydantic import BaseModel
|
|
|
|
import re as _re
|
|
import httpx as _httpx
|
|
|
|
_URL_RE = _re.compile(r'https?://[^\s<>"\']+')
|
|
|
|
|
|
def _extract_urls(text: str) -> list[str]:
|
|
return _URL_RE.findall(text)
|
|
|
|
from openai import AsyncOpenAI
|
|
from langchain_openai import ChatOpenAI
|
|
from langchain_mcp_adapters.client import MultiServerMCPClient
|
|
from langchain_community.utilities import SearxSearchWrapper
|
|
from langchain_core.tools import Tool
|
|
|
|
from vram_manager import VRAMManager
|
|
from router import Router
|
|
from agent_factory import build_medium_agent, build_complex_agent
|
|
from fast_tools import FastToolRunner, WeatherTool, CommuteTool
|
|
import channels
|
|
|
|
# LiteLLM proxy — all LLM inference goes through here
|
|
LITELLM_URL = os.getenv("LITELLM_URL", "http://host.docker.internal:4000/v1")
|
|
LITELLM_API_KEY = os.getenv("LITELLM_API_KEY", "dummy")
|
|
# Direct Ollama URL — used only by VRAMManager for flush/prewarm/poll
|
|
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
|
|
|
|
ROUTER_MODEL = os.getenv("DEEPAGENTS_ROUTER_MODEL", "qwen2.5:1.5b")
|
|
MEDIUM_MODEL = os.getenv("DEEPAGENTS_MODEL", "qwen3:4b")
|
|
COMPLEX_MODEL = os.getenv("DEEPAGENTS_COMPLEX_MODEL", "qwen3:8b")
|
|
SEARXNG_URL = os.getenv("SEARXNG_URL", "http://host.docker.internal:11437")
|
|
OPENMEMORY_URL = os.getenv("OPENMEMORY_URL", "http://openmemory:8765")
|
|
CRAWL4AI_URL = os.getenv("CRAWL4AI_URL", "http://crawl4ai:11235")
|
|
ROUTECHECK_URL = os.getenv("ROUTECHECK_URL", "http://routecheck:8090")
|
|
ROUTECHECK_TOKEN = os.getenv("ROUTECHECK_TOKEN", "")
|
|
|
|
MAX_HISTORY_TURNS = 5
|
|
_conversation_buffers: dict[str, list] = {}
|
|
|
|
# ── Interaction logging (RLHF data collection) ─────────────────────────────────
|
|
_LOG_DIR = Path(os.getenv("ADOLF_LOG_DIR", "/app/logs"))
|
|
_INTERACTIONS_LOG = _LOG_DIR / "interactions.jsonl"
|
|
|
|
def _ensure_log_dir() -> None:
|
|
try:
|
|
_LOG_DIR.mkdir(parents=True, exist_ok=True)
|
|
except Exception as e:
|
|
print(f"[log] cannot create log dir {_LOG_DIR}: {e}", flush=True)
|
|
|
|
|
|
async def _log_interaction(
    session_id: str,
    channel: str,
    tier: str,
    input_text: str,
    response_text: str | None,
    latency_ms: int,
    metadata: dict | None = None,
) -> None:
    """Append one JSON line describing an exchange to ``_INTERACTIONS_LOG``.

    The record holds ts, session_id, channel, tier, input, output (``""`` when
    *response_text* is None) and latency_ms, plus ``metadata`` when it is
    non-empty. Intended as RLHF/fine-tuning data. Any I/O or serialization
    error is printed and swallowed.
    """
    entry = dict(
        ts=time.time(),
        session_id=session_id,
        channel=channel,
        tier=tier,
        input=input_text,
        output=response_text or "",
        latency_ms=latency_ms,
    )
    if metadata:
        entry["metadata"] = metadata
    try:
        _ensure_log_dir()
        with _INTERACTIONS_LOG.open("a", encoding="utf-8") as fh:
            fh.write(_json_module.dumps(entry, ensure_ascii=False) + "\n")
    except Exception as err:
        print(f"[log] write error: {err}", flush=True)
|
|
|
|
# Per-session streaming queues — filled during inference, read by /stream/{session_id}
|
|
_stream_queues: dict[str, asyncio.Queue] = {}
|
|
|
|
|
|
async def _push_stream_chunk(session_id: str, chunk: str) -> None:
|
|
q = _stream_queues.setdefault(session_id, asyncio.Queue())
|
|
await q.put(chunk)
|
|
|
|
|
|
async def _end_stream(session_id: str) -> None:
|
|
q = _stream_queues.setdefault(session_id, asyncio.Queue())
|
|
await q.put("[DONE]")
|
|
|
|
|
|
async def _crawl4ai_fetch_async(url: str) -> str:
    """Fetch *url* through Crawl4AI asynchronously and return its markdown.

    Returns:
        Up to 5000 characters of the first result's ``raw_markdown`` (or the
        stringified ``markdown`` field when it is not a dict). Returns ``""``
        when Crawl4AI reports no successful result, and
        ``"[fetch error: ...]"`` when any exception occurs.
    """
    try:
        async with _httpx.AsyncClient(timeout=60) as client:
            resp = await client.post(f"{CRAWL4AI_URL}/crawl", json={"urls": [url]})
            resp.raise_for_status()
            crawled = resp.json().get("results", [])
            if not crawled or not crawled[0].get("success"):
                return ""
            markdown = crawled[0].get("markdown") or {}
            if isinstance(markdown, dict):
                text = markdown.get("raw_markdown")
            else:
                text = str(markdown)
            return (text or "")[:5000]
    except Exception as e:
        return f"[fetch error: {e}]"
|
|
|
|
|
|
async def _fetch_urls_from_message(message: str, max_urls: int = 3) -> str:
    """Fetch the content of URLs in *message* concurrently via Crawl4AI.

    Args:
        message: raw user text.
        max_urls: maximum number of distinct URLs to fetch (default 3).

    Returns:
        A context block ("User's message contains URLs. Fetched content: ...")
        with up to 3000 chars per page, or ``""`` if there are no URLs or every
        fetch failed or came back empty.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order, so a URL
    # repeated in the message is crawled once and doesn't eat the budget.
    urls = list(dict.fromkeys(_extract_urls(message)))[:max_urls]
    if not urls:
        return ""

    results = await asyncio.gather(*(_crawl4ai_fetch_async(u) for u in urls))
    parts = [
        f"### {url}\n{content[:3000]}"
        for url, content in zip(urls, results)
        if content and not content.startswith("[fetch error")
    ]
    if not parts:
        return ""
    return "User's message contains URLs. Fetched content:\n\n" + "\n\n".join(parts)
|
|
|
|
|
|
|
|
# System prompt for the medium tier. _run_agent_pipeline sends it as the system
# message when streaming the medium model directly, and lifespan passes it to
# build_medium_agent.
# NOTE: neither prompt contains a "/no_think" directive, so qwen3 models may
# still emit <think>…</think> reasoning; that is removed from replies by
# _strip_think and by the streaming filter in _run_agent_pipeline.
MEDIUM_SYSTEM_PROMPT = (
    "You are a helpful AI assistant. Reply concisely. "
    "If asked to remember a fact or name, simply confirm: 'Got it, I'll remember that.'"
)

# System prompt for the complex (deep research) tier.
# It contains no str.format placeholders, so the existing
# COMPLEX_SYSTEM_PROMPT.format(user_id=...) calls return it unchanged; any
# literal braces added here would have to be doubled ("{{", "}}").
COMPLEX_SYSTEM_PROMPT = (
    "You are a deep research assistant. "
    "web_search automatically fetches full page content from top results — use it 6+ times with different queries. "
    "Also call fetch_url on any specific URL you want to read in full.\n\n"
    "Run searches in English AND Russian/Latvian. "
    "After getting results, run follow-up searches based on new facts found.\n\n"
    "Write a structured markdown report with sections: "
    "Overview, Education, Career, Publications, Online Presence, Interesting Findings.\n"
    "Every fact must link to the real URL it came from: [fact](url). "
    "NEVER invent URLs. End with: **Sources checked: N**"
)
|
|
|
|
# Handles populated by lifespan() at startup and reset to None on shutdown.
medium_model = None
medium_agent = None
complex_agent = None
router: Router | None = None
vram_manager: VRAMManager | None = None
mcp_client = None

# OpenMemory MCP tools that lifespan() pulls out of the agent tool list so the
# pipeline can call them directly (memory search before, memory store after).
_memory_add_tool = None
_memory_search_tool = None

# Fast tools run before the LLM — classifier + context enricher.
_fast_tool_runner = FastToolRunner([
    WeatherTool(),
    CommuteTool(routecheck_url=ROUTECHECK_URL, internal_token=ROUTECHECK_TOKEN),
])

# GPU mutex: _run_agent_pipeline holds it for its whole run, so only one
# inference happens at a time.
_reply_semaphore = asyncio.Semaphore(1)
|
|
|
|
|
|
# Startup retry policy for the OpenMemory MCP server, which may come up after us.
_MCP_CONNECT_ATTEMPTS = 12
_MCP_RETRY_DELAY_S = 5

# MCP tools hidden from the agents; memory is read/written by the pipeline itself.
_MEMORY_TOOL_NAMES = ("add_memory", "search_memory", "get_all_memories")


def _crawl4ai_fetch(url: str) -> str:
    """Synchronously fetch *url* via Crawl4AI and return up to 5000 chars of markdown.

    Synchronous because it backs the plain-function LangChain tools below.
    Returns ``""`` when Crawl4AI reports no successful result and
    ``"[fetch error: ...]"`` on any exception.
    """
    try:
        r = _httpx.post(f"{CRAWL4AI_URL}/crawl", json={"urls": [url]}, timeout=60)
        r.raise_for_status()
        results = r.json().get("results", [])
        if not results or not results[0].get("success"):
            return ""
        md_obj = results[0].get("markdown") or {}
        md = md_obj.get("raw_markdown") if isinstance(md_obj, dict) else str(md_obj)
        return (md or "")[:5000]
    except Exception as e:
        return f"[fetch error: {e}]"


def _search_and_read(query: str) -> str:
    """Search SearXNG and inline the full content of the top results.

    Returns a text report with title/URL/snippet for up to 5 results, followed
    by up to 3000 chars of crawled content for the first 2 URLs. Returns
    ``"[search error: ...]"`` if the search request fails.
    """
    try:
        r = _httpx.get(
            f"{SEARXNG_URL}/search",
            params={"q": query, "format": "json"},
            timeout=15,
        )
        r.raise_for_status()
        items = r.json().get("results", [])[:5]
    except Exception as e:
        return f"[search error: {e}]"

    if not items:
        return "No results found."

    out = [f"Search: {query}\n"]
    for i, item in enumerate(items, 1):
        # SearXNG may return explicit nulls; `or ""` keeps slicing safe.
        url = item.get("url") or ""
        title = item.get("title") or ""
        snippet = (item.get("content") or "")[:300]
        out.append(f"\n[{i}] {title}\nURL: {url}\nSnippet: {snippet}")

    # Auto-fetch top 2 URLs for full content.
    out.append("\n\n--- Full page content ---")
    for item in items[:2]:
        url = item.get("url") or ""
        if not url:
            continue
        content = _crawl4ai_fetch(url)
        if content and not content.startswith("[fetch error"):
            out.append(f"\n### {url}\n{content[:3000]}")

    return "\n".join(out)


def _fetch_url(url: str) -> str:
    """Fetch and read the full text content of a URL (tool entry point)."""
    # Model-generated tool input often carries stray whitespace/newlines.
    content = _crawl4ai_fetch(url.strip())
    return content if content else "[fetch_url: empty or blocked]"


@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook.

    Startup: registers channel adapters, builds the router/medium/complex
    models, connects to OpenMemory over MCP (retrying while it boots), wires
    the web tools and builds both agents. Shutdown: clears the module-level
    handles so readiness checks report "not ready".
    """
    global medium_model, medium_agent, complex_agent, router, vram_manager, mcp_client, \
        _memory_add_tool, _memory_search_tool

    # Register channel adapters.
    channels.register_defaults()

    # All models are reached through the LiteLLM proxy; the complex model name
    # is passed through as-is and may refer to a remote provider.
    router_model = ChatOpenAI(
        model=f"ollama/{ROUTER_MODEL}",
        base_url=LITELLM_URL,
        api_key=LITELLM_API_KEY,
        temperature=0,
        timeout=30,
    )
    embedder = AsyncOpenAI(base_url=LITELLM_URL, api_key=LITELLM_API_KEY)
    medium_model = ChatOpenAI(
        model=f"ollama/{MEDIUM_MODEL}",
        base_url=LITELLM_URL,
        api_key=LITELLM_API_KEY,
        timeout=180,
    )
    complex_model = ChatOpenAI(
        model=COMPLEX_MODEL,  # full model name — may be remote (OpenRouter) or local ollama/*
        base_url=LITELLM_URL,
        api_key=LITELLM_API_KEY,
        timeout=600,
    )

    vram_manager = VRAMManager(base_url=OLLAMA_BASE_URL)
    router = Router(model=router_model, embedder=embedder, fast_tool_runner=_fast_tool_runner)
    await router.initialize()

    mcp_connections = {
        "openmemory": {"transport": "sse", "url": f"{OPENMEMORY_URL}/sse"},
    }
    mcp_client = MultiServerMCPClient(mcp_connections)
    for attempt in range(1, _MCP_CONNECT_ATTEMPTS + 1):
        try:
            mcp_tools = await mcp_client.get_tools()
            break
        except Exception as e:
            if attempt == _MCP_CONNECT_ATTEMPTS:
                raise
            print(
                f"[agent] MCP not ready (attempt {attempt}/{_MCP_CONNECT_ATTEMPTS}): {e}. "
                f"Retrying in {_MCP_RETRY_DELAY_S}s...",
                flush=True,
            )
            await asyncio.sleep(_MCP_RETRY_DELAY_S)

    agent_tools = [t for t in mcp_tools if t.name not in _MEMORY_TOOL_NAMES]

    # Expose memory tools directly so the pipeline can call them outside the agent loop.
    for t in mcp_tools:
        if t.name == "add_memory":
            _memory_add_tool = t
        elif t.name == "search_memory":
            _memory_search_tool = t

    agent_tools.append(Tool(
        name="web_search",
        func=_search_and_read,
        description=(
            "Search the web and read full content of top results. "
            "Returns search snippets AND full page text from the top URLs. "
            "Use multiple different queries to research a topic thoroughly."
        ),
    ))
    agent_tools.append(Tool(
        name="fetch_url",
        func=_fetch_url,
        description=(
            "Fetch and read the full text content of a specific URL. "
            "Use for URLs not covered by web_search. Input: a single URL string."
        ),
    ))

    medium_agent = build_medium_agent(
        model=medium_model,
        agent_tools=agent_tools,
        system_prompt=MEDIUM_SYSTEM_PROMPT,
    )
    complex_agent = build_complex_agent(
        model=complex_model,
        agent_tools=agent_tools,
        system_prompt=COMPLEX_SYSTEM_PROMPT.format(user_id="{user_id}"),
    )

    print(
        f"[agent] litellm={LITELLM_URL} | router=semantic(ollama/{ROUTER_MODEL}+nomic-embed-text) | "
        f"medium=ollama/{MEDIUM_MODEL} | complex={COMPLEX_MODEL}",
        flush=True,
    )
    print(f"[agent] agent tools: {[t.name for t in agent_tools]}", flush=True)

    yield

    medium_model = None
    medium_agent = None
    complex_agent = None
    router = None
    vram_manager = None
    mcp_client = None
    _memory_add_tool = None
    _memory_search_tool = None
|
|
|
|
|
|
app = FastAPI(lifespan=lifespan)


# ── request models ─────────────────────────────────────────────────────────────

class InboundMessage(BaseModel):
    """Body of POST /message, the unified entry point for every channel."""

    text: str
    session_id: str  # e.g. "tg-346967270", "cli-alvis"
    channel: str  # "telegram" | "cli"
    user_id: str = ""  # human identity; not read by the /message handler at present
    # Free-form flags forwarded to run_agent_task (e.g. "dry_run", "benchmark").
    # Pydantic copies this mutable default per instance.
    metadata: dict = {}


class ChatRequest(BaseModel):
    """Body of the legacy POST /chat endpoint (kept for test_pipeline.py compatibility)."""

    message: str
    chat_id: str
|
|
|
|
|
|
# ── helpers ────────────────────────────────────────────────────────────────────
|
|
|
|
def _strip_think(text: str) -> str:
|
|
"""Strip qwen3 chain-of-thought blocks that appear inline in content
|
|
when using Ollama's OpenAI-compatible endpoint (/v1/chat/completions)."""
|
|
return _re.sub(r"<think>.*?</think>", "", text, flags=_re.DOTALL).strip()
|
|
|
|
|
|
def _extract_final_text(result) -> str | None:
    """Return the final answer text from an agent ``ainvoke`` result.

    Uses the content of the last non-empty ``AIMessage`` in ``result["messages"]``;
    list-style content (content parts) is flattened to its text parts. Falls
    back to ``result["output"]``. Think blocks are stripped. Returns None when
    *result* is not a dict or carries no text.
    """
    if not isinstance(result, dict):
        return None

    for m in reversed(result.get("messages") or []):
        if type(m).__name__ != "AIMessage":
            continue
        content = getattr(m, "content", "")
        if isinstance(content, list):
            # Some providers return a list of parts instead of a string.
            content = "".join(
                part if isinstance(part, str) else str(part.get("text") or "")
                for part in content
                if isinstance(part, str) or (isinstance(part, dict) and part.get("type") == "text")
            )
        if content:
            return _strip_think(content)

    output = result.get("output")
    if output:
        return _strip_think(output)
    return None
|
|
|
|
|
|
def _log_messages(result):
|
|
msgs = result.get("messages", [])
|
|
for m in msgs:
|
|
role = type(m).__name__
|
|
content = getattr(m, "content", "")
|
|
tool_calls = getattr(m, "tool_calls", [])
|
|
if content:
|
|
print(f"[agent] {role}: {str(content)[:150]}", flush=True)
|
|
for tc in tool_calls:
|
|
print(f"[agent] {role} → {tc['name']}({tc['args']})", flush=True)
|
|
|
|
|
|
# ── memory helpers ─────────────────────────────────────────────────────────────
|
|
|
|
def _resolve_user_id(session_id: str) -> str:
|
|
"""Map any session_id to a canonical user identity for openmemory.
|
|
All channels share the same memory pool for the single user."""
|
|
return "alvis"
|
|
|
|
|
|
async def _store_memory(session_id: str, user_msg: str, assistant_reply: str) -> None:
    """Save one user/assistant exchange to OpenMemory.

    Meant to run in the background. It is a no-op until the add_memory tool has
    been discovered, and any failure is printed rather than raised.
    """
    tool = _memory_add_tool
    if tool is None:
        return
    started = time.monotonic()
    try:
        payload = {
            "text": f"User: {user_msg}\nAssistant: {assistant_reply}",
            "user_id": _resolve_user_id(session_id),
        }
        await tool.ainvoke(payload)
        print(f"[memory] stored in {time.monotonic() - started:.1f}s", flush=True)
    except Exception as e:
        print(f"[memory] error: {e}", flush=True)
|
|
|
|
|
|
async def _retrieve_memories(message: str, session_id: str) -> str:
    """Search OpenMemory for context relevant to *message*.

    Returns:
        ``"Relevant memories:\\n<raw result>"`` when the search tool returns a
        non-empty string other than ``"[]"``; otherwise ``""``.

    Best-effort: a failing or malformed search is logged and yields ``""`` so
    the pipeline continues without memories (previously errors were silently
    swallowed, hiding a broken memory service).
    """
    if _memory_search_tool is None:
        return ""
    try:
        user_id = _resolve_user_id(session_id)
        result = await _memory_search_tool.ainvoke({"query": message, "user_id": user_id})
    except Exception as e:
        print(f"[memory] search error: {e}", flush=True)
        return ""

    if not result:
        return ""
    if not isinstance(result, str):
        print(f"[memory] unexpected search result type: {type(result).__name__}", flush=True)
        return ""
    stripped = result.strip()
    if stripped and stripped != "[]":
        return f"Relevant memories:\n{result}"
    return ""
|
|
|
|
|
|
# ── core pipeline ──────────────────────────────────────────────────────────────
|
|
|
|
from typing import AsyncGenerator
|
|
|
|
async def _run_agent_pipeline(
|
|
message: str,
|
|
history: list[dict],
|
|
session_id: str,
|
|
tier_override: str | None = None,
|
|
dry_run: bool = False,
|
|
tier_capture: list | None = None,
|
|
) -> AsyncGenerator[str, None]:
|
|
"""Core pipeline: pre-flight → routing → inference. Yields text chunks.
|
|
|
|
tier_override: "light" | "medium" | "complex" | None (auto-route)
|
|
dry_run: if True and tier=complex, log tier=complex but use medium model (avoids API cost)
|
|
Caller is responsible for scheduling _store_memory after consuming all chunks.
|
|
"""
|
|
async with _reply_semaphore:
|
|
t0 = time.monotonic()
|
|
clean_message = message
|
|
print(f"[agent] running: {clean_message[:80]!r}", flush=True)
|
|
|
|
# Fetch URL content, memories, and fast-tool context concurrently
|
|
url_context, memories, fast_context = await asyncio.gather(
|
|
_fetch_urls_from_message(clean_message),
|
|
_retrieve_memories(clean_message, session_id),
|
|
_fast_tool_runner.run_matching(clean_message),
|
|
)
|
|
if url_context:
|
|
print(f"[agent] crawl4ai: {len(url_context)} chars fetched", flush=True)
|
|
if fast_context:
|
|
names = _fast_tool_runner.matching_names(clean_message)
|
|
print(f"[agent] fast_tools={names}: {len(fast_context)} chars injected", flush=True)
|
|
|
|
# Build enriched history
|
|
enriched_history = list(history)
|
|
if url_context:
|
|
enriched_history = [{"role": "system", "content": url_context}] + enriched_history
|
|
if fast_context:
|
|
enriched_history = [{"role": "system", "content": fast_context}] + enriched_history
|
|
if memories:
|
|
enriched_history = [{"role": "system", "content": memories}] + enriched_history
|
|
|
|
final_text = None
|
|
llm_elapsed = 0.0
|
|
|
|
try:
|
|
# Short-circuit: fast tool already has the answer
|
|
if fast_context and tier_override is None and not url_context:
|
|
tier = "fast"
|
|
final_text = fast_context
|
|
llm_elapsed = time.monotonic() - t0
|
|
names = _fast_tool_runner.matching_names(clean_message)
|
|
print(f"[agent] tier=fast tools={names} — delivering directly", flush=True)
|
|
yield final_text
|
|
|
|
else:
|
|
# Determine tier
|
|
if tier_override in ("light", "medium", "complex"):
|
|
tier = tier_override
|
|
light_reply = None
|
|
if tier_override == "light":
|
|
tier, light_reply = await router.route(clean_message, enriched_history)
|
|
tier = "light"
|
|
else:
|
|
tier, light_reply = await router.route(clean_message, enriched_history)
|
|
if url_context and tier == "light":
|
|
tier = "medium"
|
|
light_reply = None
|
|
print("[agent] URL in message → upgraded light→medium", flush=True)
|
|
|
|
# Dry-run: log as complex but infer with medium (no remote API call)
|
|
effective_tier = tier
|
|
if dry_run and tier == "complex":
|
|
effective_tier = "medium"
|
|
print(f"[agent] tier=complex (dry-run) → using medium model, message={clean_message[:60]!r}", flush=True)
|
|
else:
|
|
print(f"[agent] tier={tier} message={clean_message[:60]!r}", flush=True)
|
|
tier = effective_tier
|
|
if tier_capture is not None:
|
|
tier_capture.append(tier)
|
|
|
|
if tier == "light":
|
|
final_text = light_reply
|
|
llm_elapsed = time.monotonic() - t0
|
|
print("[agent] light path: answered by router", flush=True)
|
|
yield final_text
|
|
|
|
elif tier == "medium":
|
|
system_prompt = MEDIUM_SYSTEM_PROMPT
|
|
if memories:
|
|
system_prompt += "\n\n" + memories
|
|
if url_context:
|
|
system_prompt += "\n\n" + url_context
|
|
if fast_context:
|
|
system_prompt += "\n\nLive web search results (use these to answer):\n\n" + fast_context
|
|
|
|
in_think = False
|
|
response_parts = []
|
|
async for chunk in medium_model.astream([
|
|
{"role": "system", "content": system_prompt},
|
|
*history,
|
|
{"role": "user", "content": clean_message},
|
|
]):
|
|
token = chunk.content or ""
|
|
if not token:
|
|
continue
|
|
if in_think:
|
|
if "</think>" in token:
|
|
in_think = False
|
|
after = token.split("</think>", 1)[1]
|
|
if after:
|
|
yield after
|
|
response_parts.append(after)
|
|
else:
|
|
if "<think>" in token:
|
|
in_think = True
|
|
before = token.split("<think>", 1)[0]
|
|
if before:
|
|
yield before
|
|
response_parts.append(before)
|
|
else:
|
|
yield token
|
|
response_parts.append(token)
|
|
|
|
llm_elapsed = time.monotonic() - t0
|
|
final_text = "".join(response_parts).strip() or None
|
|
|
|
else: # complex — remote model, no VRAM management needed
|
|
system_prompt = COMPLEX_SYSTEM_PROMPT.format(user_id=session_id)
|
|
if url_context:
|
|
system_prompt += "\n\n[Pre-fetched URL content from user's message:]\n" + url_context
|
|
result = await complex_agent.ainvoke({
|
|
"messages": [
|
|
{"role": "system", "content": system_prompt},
|
|
*history,
|
|
{"role": "user", "content": clean_message},
|
|
]
|
|
})
|
|
|
|
llm_elapsed = time.monotonic() - t0
|
|
_log_messages(result)
|
|
final_text = _extract_final_text(result)
|
|
if final_text:
|
|
yield final_text
|
|
|
|
except Exception as e:
|
|
import traceback
|
|
llm_elapsed = time.monotonic() - t0
|
|
print(f"[agent] error after {llm_elapsed:.1f}s for {session_id}: {e}", flush=True)
|
|
traceback.print_exc()
|
|
|
|
print(f"[agent] pipeline done in {time.monotonic() - t0:.1f}s tier={tier if 'tier' in dir() else '?'}", flush=True)
|
|
|
|
# Store memory as side-effect (non-blocking, best-effort)
|
|
if final_text:
|
|
asyncio.create_task(_store_memory(session_id, clean_message, final_text))
|
|
|
|
|
|
# ── core task (Telegram / Matrix / CLI wrapper) ─────────────────────────────────
|
|
|
|
async def run_agent_task(
    message: str,
    session_id: str,
    channel: str = "telegram",
    metadata: dict | None = None,
):
    """Run one inbound message through the pipeline and deliver the reply.

    Chunks are pushed to the /stream/{session_id} queue as they arrive. The
    end-of-stream sentinel is always sent, even if the pipeline raises. The
    final text is delivered via ``channels.deliver``, appended to the
    per-session conversation buffer and logged for RLHF.

    metadata keys read here:
        "dry_run": complex-routed messages are answered by the medium model.
        "benchmark": skip channel delivery and interaction logging.
    """
    print(f"[agent] queued: {message[:80]!r} chat={session_id}", flush=True)
    t0 = time.monotonic()

    meta = metadata or {}
    dry_run = bool(meta.get("dry_run", False))
    is_benchmark = bool(meta.get("benchmark", False))

    history = _conversation_buffers.get(session_id, [])
    tier_capture: list = []
    parts: list[str] = []

    try:
        async for chunk in _run_agent_pipeline(message, history, session_id, dry_run=dry_run, tier_capture=tier_capture):
            if not chunk:
                continue  # never forward empty/None chunks to SSE readers
            await _push_stream_chunk(session_id, chunk)
            parts.append(chunk)
    finally:
        # Release /stream/{session_id} readers even on failure; otherwise they
        # would wait for the SSE timeout.
        await _end_stream(session_id)

    actual_tier = tier_capture[0] if tier_capture else "unknown"
    elapsed_ms = int((time.monotonic() - t0) * 1000)
    final_text = "".join(parts).strip()

    if not final_text:
        print("[agent] warning: no text reply from agent", flush=True)
        return

    # Skip channel delivery for benchmark sessions (no Telegram spam).
    if not is_benchmark:
        try:
            await channels.deliver(session_id, channel, final_text)
        except Exception as e:
            print(f"[agent] delivery error (non-fatal): {e}", flush=True)

    # Keep these two log formats stable: benchmark log parsing matches them.
    print(f"[agent] replied in {elapsed_ms / 1000:.1f}s tier={actual_tier}", flush=True)
    print(f"[agent] reply_text: {final_text}", flush=True)

    # Update conversation buffer (bounded to the last MAX_HISTORY_TURNS turns).
    buf = _conversation_buffers.get(session_id, [])
    buf.append({"role": "user", "content": message})
    buf.append({"role": "assistant", "content": final_text})
    _conversation_buffers[session_id] = buf[-(MAX_HISTORY_TURNS * 2):]

    # Log interaction for RLHF data collection (skip benchmark sessions to avoid noise).
    # Awaited directly: this function already runs as a background task, and an
    # unreferenced create_task() could be garbage-collected mid-write.
    if not is_benchmark:
        await _log_interaction(
            session_id=session_id,
            channel=channel,
            tier=actual_tier,
            input_text=message,
            response_text=final_text,
            latency_ms=elapsed_ms,
            metadata=meta if meta else None,
        )
|
|
|
|
|
|
# ── endpoints ──────────────────────────────────────────────────────────────────
|
|
|
|
@app.post("/message")
async def message(request: InboundMessage, background_tasks: BackgroundTasks):
    """Unified inbound endpoint for all channels.

    Queues ``run_agent_task`` and answers 202 immediately; replies 503 until
    lifespan startup has built the agents.
    """
    agent_ready = medium_agent is not None
    if not agent_ready:
        return JSONResponse(status_code=503, content={"error": "Agent not ready"})
    background_tasks.add_task(
        run_agent_task,
        request.text,
        request.session_id,
        request.channel,
        request.metadata,
    )
    return JSONResponse(status_code=202, content={"status": "accepted"})
|
|
|
|
|
|
@app.post("/chat")
async def chat(request: ChatRequest, background_tasks: BackgroundTasks):
    """Legacy endpoint — runs the message in session ``tg-<chat_id>`` on the telegram channel."""
    if medium_agent is not None:
        legacy_session = f"tg-{request.chat_id}"
        background_tasks.add_task(run_agent_task, request.message, legacy_session, "telegram")
        return JSONResponse(status_code=202, content={"status": "accepted"})
    return JSONResponse(status_code=503, content={"error": "Agent not ready"})
|
|
|
|
|
|
@app.get("/reply/{session_id}")
async def reply_stream(session_id: str):
    """
    SSE endpoint — sends one event carrying the session's next reply, then closes.

    Waits up to 900 s on ``channels.pending_replies[session_id]``; on timeout
    sends ``[timeout]`` instead. Used by the CLI client and wiki_research.py
    instead of log polling.
    """
    queue = channels.pending_replies.setdefault(session_id, asyncio.Queue())

    async def _events():
        try:
            reply = await asyncio.wait_for(queue.get(), timeout=900)
        except asyncio.TimeoutError:
            yield "data: [timeout]\n\n"
            return
        # Newlines become a literal "\n" so the whole reply fits in one SSE data line.
        one_line = reply.replace("\n", "\\n").replace("\r", "")
        yield f"data: {one_line}\n\n"

    return StreamingResponse(_events(), media_type="text/event-stream")
|
|
|
|
|
|
@app.get("/stream/{session_id}")
async def stream_reply(session_id: str):
    """
    SSE endpoint — relays reply chunks from the session's stream queue as they arrive.

    Each chunk is sent as ``data: <chunk>`` with newlines escaped to a literal
    ``\\n``. The stream ends after the ``[DONE]`` sentinel is relayed, or with a
    synthetic ``[DONE]`` if no chunk arrives within 900 s.

    Medium tier: real token-by-token streaming (think blocks filtered out).
    Light and complex tiers: full reply delivered as one chunk then [DONE].
    """
    queue = _stream_queues.setdefault(session_id, asyncio.Queue())

    async def _events():
        try:
            while True:
                piece = await asyncio.wait_for(queue.get(), timeout=900)
                yield "data: " + piece.replace("\n", "\\n").replace("\r", "") + "\n\n"
                if piece == "[DONE]":
                    return
        except asyncio.TimeoutError:
            yield "data: [DONE]\n\n"

    return StreamingResponse(_events(), media_type="text/event-stream")
|
|
|
|
|
|
@app.get("/health")
async def health():
    """Liveness probe. ``agent_ready`` becomes true once lifespan startup has built the medium agent."""
    ready = medium_agent is not None
    return {"status": "ok", "agent_ready": ready}
|
|
|
|
|
|
# ── OpenAI-compatible API (for OpenWebUI and other clients) ────────────────────
|
|
|
|
# Model id requested by an OpenAI-compatible client -> tier_override for
# _run_agent_pipeline (None = automatic routing). Unknown ids also auto-route.
# Single source of truth for GET /v1/models.
_TIER_MAP = {
    "adolf": None,
    "adolf-light": "light",
    "adolf-medium": "medium",
    "adolf-deep": "complex",
}


@app.get("/v1/models")
async def list_models():
    """OpenAI-compatible model listing: one entry per id in ``_TIER_MAP``.

    Derived from ``_TIER_MAP`` (dict order is preserved) so the advertised
    models can never drift from the ones chat_completions understands.
    """
    return {
        "object": "list",
        "data": [
            {"id": model_id, "object": "model", "owned_by": "adolf"}
            for model_id in _TIER_MAP
        ],
    }
|
|
|
|
|
|
def _content_to_text(content) -> str:
    """Flatten an OpenAI message ``content`` value to plain text.

    Accepts a string, None (-> ``""``), or a list of content parts, where bare
    strings and ``{"type": "text", "text": ...}`` parts contribute (joined by
    newlines) and other part types (images, …) are ignored.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        pieces = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict) and part.get("type") == "text":
                pieces.append(str(part.get("text") or ""))
        return "\n".join(pieces)
    return str(content)


@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """OpenAI-compatible chat completions (for OpenWebUI and other clients).

    - ``model`` selects the tier via ``_TIER_MAP`` (unknown ids auto-route).
    - The last user message is the prompt; earlier user/assistant messages
      become history (system messages from the client are dropped).
    - ``X-Session-Id`` header picks the memory/session id (default "owui-default").
    - ``stream`` defaults to True here (unlike OpenAI's default of False).

    Returns 400 for a non-object JSON body or when no user message is present,
    and 503 until the agents are ready.
    """
    import json as _json
    import uuid as _uuid

    if medium_agent is None:
        return JSONResponse(status_code=503, content={"error": {"message": "Agent not ready", "type": "server_error"}})

    try:
        body = await request.json()
    except ValueError:  # invalid JSON / undecodable bytes
        body = None
    if not isinstance(body, dict):
        return JSONResponse(
            status_code=400,
            content={"error": {"message": "Request body must be a JSON object", "type": "invalid_request_error"}},
        )

    model = body.get("model", "adolf")
    raw_messages = body.get("messages") or []
    messages = [m for m in raw_messages if isinstance(m, dict)] if isinstance(raw_messages, list) else []
    stream = body.get("stream", True)

    # Index of the last user turn; everything before it is candidate history.
    last_user_idx = next(
        (i for i in range(len(messages) - 1, -1, -1) if messages[i].get("role") == "user"),
        None,
    )
    if last_user_idx is None:
        return JSONResponse(status_code=400, content={"error": {"message": "No user message", "type": "invalid_request_error"}})

    current_message = _content_to_text(messages[last_user_idx].get("content"))
    history = [m for m in messages[:last_user_idx] if m.get("role") in ("user", "assistant")]

    session_id = request.headers.get("X-Session-Id", "owui-default")
    tier_override = _TIER_MAP.get(model)

    response_id = f"chatcmpl-{_uuid.uuid4().hex[:12]}"
    created = int(time.time())

    def _sse_chunk(delta: dict, finish_reason: str | None = None) -> str:
        # chat.completion.chunk objects carry id, object, created, model, choices.
        payload = {
            "id": response_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model,
            "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
        }
        return f"data: {_json.dumps(payload)}\n\n"

    if stream:
        async def event_stream():
            # Opening chunk with role.
            yield _sse_chunk({"role": "assistant"})
            async for chunk in _run_agent_pipeline(current_message, history, session_id, tier_override):
                if chunk:
                    yield _sse_chunk({"content": chunk})
            # Final chunk, then the OpenAI stream terminator.
            yield _sse_chunk({}, "stop")
            yield "data: [DONE]\n\n"

        return StreamingResponse(event_stream(), media_type="text/event-stream")

    # Non-streaming: collect all chunks.
    parts = []
    async for chunk in _run_agent_pipeline(current_message, history, session_id, tier_override):
        if chunk:
            parts.append(chunk)
    full_text = "".join(parts).strip()
    return {
        "id": response_id,
        "object": "chat.completion",
        "created": created,
        "model": model,
        "choices": [{"index": 0, "message": {"role": "assistant", "content": full_text}, "finish_reason": "stop"}],
    }
|