Switch from Bifrost to LiteLLM; add Matrix channel; update rules
Infrastructure: - docker-compose.yml: replace bifrost container with LiteLLM proxy (host.docker.internal:4000); complex model → deepseek-r1:free via OpenRouter; add Matrix URL env var; mount logs volume - bifrost-config.json: add auth_config + postgres config_store (archived) Routing: - router.py: full semantic 3-tier classifier rewrite — nomic-embed-text centroids for light/medium/complex; regex pre-classifiers for all tiers; Russian utterance sets expanded - agent.py: wire LiteLLM URL; add dry_run support; add Matrix channel Channels: - channels.py: add Matrix adapter (_matrix_send via mx- session prefix) Rules / docs: - agent-pipeline.md: remove /think prefix requirement; document automatic complex tier classification - llm-inference.md: update BIFROST_URL → LITELLM_URL references; add remote model note for complex tier - ARCHITECTURE.md: deleted (superseded by README.md) Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
421
agent.py
421
agent.py
@@ -1,7 +1,9 @@
|
||||
import asyncio
|
||||
import json as _json_module
|
||||
import os
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI, BackgroundTasks, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
@@ -16,6 +18,7 @@ _URL_RE = _re.compile(r'https?://[^\s<>"\']+')
|
||||
def _extract_urls(text: str) -> list[str]:
|
||||
return _URL_RE.findall(text)
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||
from langchain_community.utilities import SearxSearchWrapper
|
||||
@@ -27,8 +30,9 @@ from agent_factory import build_medium_agent, build_complex_agent
|
||||
from fast_tools import FastToolRunner, WeatherTool, CommuteTool
|
||||
import channels
|
||||
|
||||
# Bifrost gateway — all LLM inference goes through here
|
||||
BIFROST_URL = os.getenv("BIFROST_URL", "http://bifrost:8080/v1")
|
||||
# LiteLLM proxy — all LLM inference goes through here
|
||||
LITELLM_URL = os.getenv("LITELLM_URL", "http://host.docker.internal:4000/v1")
|
||||
LITELLM_API_KEY = os.getenv("LITELLM_API_KEY", "dummy")
|
||||
# Direct Ollama URL — used only by VRAMManager for flush/prewarm/poll
|
||||
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
|
||||
|
||||
@@ -44,6 +48,45 @@ ROUTECHECK_TOKEN = os.getenv("ROUTECHECK_TOKEN", "")
|
||||
MAX_HISTORY_TURNS = 5
|
||||
_conversation_buffers: dict[str, list] = {}
|
||||
|
||||
# ── Interaction logging (RLHF data collection) ─────────────────────────────────
|
||||
_LOG_DIR = Path(os.getenv("ADOLF_LOG_DIR", "/app/logs"))
|
||||
_INTERACTIONS_LOG = _LOG_DIR / "interactions.jsonl"
|
||||
|
||||
def _ensure_log_dir() -> None:
|
||||
try:
|
||||
_LOG_DIR.mkdir(parents=True, exist_ok=True)
|
||||
except Exception as e:
|
||||
print(f"[log] cannot create log dir {_LOG_DIR}: {e}", flush=True)
|
||||
|
||||
|
||||
async def _log_interaction(
|
||||
session_id: str,
|
||||
channel: str,
|
||||
tier: str,
|
||||
input_text: str,
|
||||
response_text: str | None,
|
||||
latency_ms: int,
|
||||
metadata: dict | None = None,
|
||||
) -> None:
|
||||
"""Append one interaction record to the JSONL log for future RLHF/finetuning."""
|
||||
record = {
|
||||
"ts": time.time(),
|
||||
"session_id": session_id,
|
||||
"channel": channel,
|
||||
"tier": tier,
|
||||
"input": input_text,
|
||||
"output": response_text or "",
|
||||
"latency_ms": latency_ms,
|
||||
}
|
||||
if metadata:
|
||||
record["metadata"] = metadata
|
||||
try:
|
||||
_ensure_log_dir()
|
||||
with open(_INTERACTIONS_LOG, "a", encoding="utf-8") as f:
|
||||
f.write(_json_module.dumps(record, ensure_ascii=False) + "\n")
|
||||
except Exception as e:
|
||||
print(f"[log] write error: {e}", flush=True)
|
||||
|
||||
# Per-session streaming queues — filled during inference, read by /stream/{session_id}
|
||||
_stream_queues: dict[str, asyncio.Queue] = {}
|
||||
|
||||
@@ -140,31 +183,30 @@ async def lifespan(app: FastAPI):
|
||||
channels.register_defaults()
|
||||
|
||||
# All three models route through Bifrost → Ollama GPU.
|
||||
# Bifrost adds retry logic, observability, and failover.
|
||||
# Model names use provider/model format: Bifrost strips the "ollama/" prefix
|
||||
# before forwarding to Ollama's /v1/chat/completions endpoint.
|
||||
router_model = ChatOpenAI(
|
||||
model=f"ollama/{ROUTER_MODEL}",
|
||||
base_url=BIFROST_URL,
|
||||
api_key="dummy",
|
||||
base_url=LITELLM_URL,
|
||||
api_key=LITELLM_API_KEY,
|
||||
temperature=0,
|
||||
timeout=30,
|
||||
)
|
||||
embedder = AsyncOpenAI(base_url=LITELLM_URL, api_key=LITELLM_API_KEY)
|
||||
medium_model = ChatOpenAI(
|
||||
model=f"ollama/{MEDIUM_MODEL}",
|
||||
base_url=BIFROST_URL,
|
||||
api_key="dummy",
|
||||
base_url=LITELLM_URL,
|
||||
api_key=LITELLM_API_KEY,
|
||||
timeout=180,
|
||||
)
|
||||
complex_model = ChatOpenAI(
|
||||
model=f"ollama/{COMPLEX_MODEL}",
|
||||
base_url=BIFROST_URL,
|
||||
api_key="dummy",
|
||||
model=COMPLEX_MODEL, # full model name — may be remote (OpenRouter) or local ollama/*
|
||||
base_url=LITELLM_URL,
|
||||
api_key=LITELLM_API_KEY,
|
||||
timeout=600,
|
||||
)
|
||||
|
||||
vram_manager = VRAMManager(base_url=OLLAMA_BASE_URL)
|
||||
router = Router(model=router_model, fast_tool_runner=_fast_tool_runner)
|
||||
router = Router(model=router_model, embedder=embedder, fast_tool_runner=_fast_tool_runner)
|
||||
await router.initialize()
|
||||
|
||||
mcp_connections = {
|
||||
"openmemory": {"transport": "sse", "url": f"{OPENMEMORY_URL}/sse"},
|
||||
@@ -279,8 +321,8 @@ async def lifespan(app: FastAPI):
|
||||
)
|
||||
|
||||
print(
|
||||
f"[agent] bifrost={BIFROST_URL} | router=ollama/{ROUTER_MODEL} | "
|
||||
f"medium=ollama/{MEDIUM_MODEL} | complex=ollama/{COMPLEX_MODEL}",
|
||||
f"[agent] litellm={LITELLM_URL} | router=semantic(ollama/{ROUTER_MODEL}+nomic-embed-text) | "
|
||||
f"medium=ollama/{MEDIUM_MODEL} | complex={COMPLEX_MODEL}",
|
||||
flush=True,
|
||||
)
|
||||
print(f"[agent] agent tools: {[t.name for t in agent_tools]}", flush=True)
|
||||
@@ -346,6 +388,12 @@ def _log_messages(result):
|
||||
|
||||
# ── memory helpers ─────────────────────────────────────────────────────────────
|
||||
|
||||
def _resolve_user_id(session_id: str) -> str:
|
||||
"""Map any session_id to a canonical user identity for openmemory.
|
||||
All channels share the same memory pool for the single user."""
|
||||
return "alvis"
|
||||
|
||||
|
||||
async def _store_memory(session_id: str, user_msg: str, assistant_reply: str) -> None:
|
||||
"""Store a conversation turn in openmemory (runs as a background task)."""
|
||||
if _memory_add_tool is None:
|
||||
@@ -353,7 +401,8 @@ async def _store_memory(session_id: str, user_msg: str, assistant_reply: str) ->
|
||||
t0 = time.monotonic()
|
||||
try:
|
||||
text = f"User: {user_msg}\nAssistant: {assistant_reply}"
|
||||
await _memory_add_tool.ainvoke({"text": text, "user_id": session_id})
|
||||
user_id = _resolve_user_id(session_id)
|
||||
await _memory_add_tool.ainvoke({"text": text, "user_id": user_id})
|
||||
print(f"[memory] stored in {time.monotonic() - t0:.1f}s", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[memory] error: {e}", flush=True)
|
||||
@@ -364,7 +413,8 @@ async def _retrieve_memories(message: str, session_id: str) -> str:
|
||||
if _memory_search_tool is None:
|
||||
return ""
|
||||
try:
|
||||
result = await _memory_search_tool.ainvoke({"query": message, "user_id": session_id})
|
||||
user_id = _resolve_user_id(session_id)
|
||||
result = await _memory_search_tool.ainvoke({"query": message, "user_id": user_id})
|
||||
if result and result.strip() and result.strip() != "[]":
|
||||
return f"Relevant memories:\n{result}"
|
||||
except Exception:
|
||||
@@ -372,36 +422,41 @@ async def _retrieve_memories(message: str, session_id: str) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
# ── core task ──────────────────────────────────────────────────────────────────
|
||||
# ── core pipeline ──────────────────────────────────────────────────────────────
|
||||
|
||||
async def run_agent_task(message: str, session_id: str, channel: str = "telegram"):
|
||||
print(f"[agent] queued: {message[:80]!r} chat={session_id}", flush=True)
|
||||
from typing import AsyncGenerator
|
||||
|
||||
force_complex = False
|
||||
clean_message = message
|
||||
if message.startswith("/think "):
|
||||
force_complex = True
|
||||
clean_message = message[len("/think "):]
|
||||
print("[agent] /think prefix → force_complex=True", flush=True)
|
||||
async def _run_agent_pipeline(
|
||||
message: str,
|
||||
history: list[dict],
|
||||
session_id: str,
|
||||
tier_override: str | None = None,
|
||||
dry_run: bool = False,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Core pipeline: pre-flight → routing → inference. Yields text chunks.
|
||||
|
||||
tier_override: "light" | "medium" | "complex" | None (auto-route)
|
||||
dry_run: if True and tier=complex, log tier=complex but use medium model (avoids API cost)
|
||||
Caller is responsible for scheduling _store_memory after consuming all chunks.
|
||||
"""
|
||||
async with _reply_semaphore:
|
||||
t0 = time.monotonic()
|
||||
history = _conversation_buffers.get(session_id, [])
|
||||
clean_message = message
|
||||
print(f"[agent] running: {clean_message[:80]!r}", flush=True)
|
||||
|
||||
# Fetch URL content, memories, and fast-tool context concurrently — all IO-bound
|
||||
# Fetch URL content, memories, and fast-tool context concurrently
|
||||
url_context, memories, fast_context = await asyncio.gather(
|
||||
_fetch_urls_from_message(clean_message),
|
||||
_retrieve_memories(clean_message, session_id),
|
||||
_fast_tool_runner.run_matching(clean_message),
|
||||
)
|
||||
if url_context:
|
||||
print(f"[agent] crawl4ai: {len(url_context)} chars fetched from message URLs", flush=True)
|
||||
print(f"[agent] crawl4ai: {len(url_context)} chars fetched", flush=True)
|
||||
if fast_context:
|
||||
names = _fast_tool_runner.matching_names(clean_message)
|
||||
print(f"[agent] fast_tools={names}: {len(fast_context)} chars injected", flush=True)
|
||||
|
||||
# Build enriched history: memories + url_context + fast_context for ALL tiers
|
||||
# Build enriched history
|
||||
enriched_history = list(history)
|
||||
if url_context:
|
||||
enriched_history = [{"role": "system", "content": url_context}] + enriched_history
|
||||
@@ -410,45 +465,58 @@ async def run_agent_task(message: str, session_id: str, channel: str = "telegram
|
||||
if memories:
|
||||
enriched_history = [{"role": "system", "content": memories}] + enriched_history
|
||||
|
||||
# Short-circuit: fast tool result is already a complete reply — skip router+LLM
|
||||
if fast_context and not force_complex and not url_context:
|
||||
tier = "fast"
|
||||
final_text = fast_context
|
||||
llm_elapsed = time.monotonic() - t0
|
||||
names = _fast_tool_runner.matching_names(clean_message)
|
||||
print(f"[agent] tier=fast tools={names} — delivering directly", flush=True)
|
||||
await _push_stream_chunk(session_id, final_text)
|
||||
await _end_stream(session_id)
|
||||
else:
|
||||
tier, light_reply = await router.route(clean_message, enriched_history, force_complex)
|
||||
final_text = None
|
||||
llm_elapsed = 0.0
|
||||
|
||||
# Messages with URL content must be handled by at least medium tier
|
||||
if url_context and tier == "light":
|
||||
tier = "medium"
|
||||
light_reply = None
|
||||
print("[agent] URL in message → upgraded light→medium", flush=True)
|
||||
print(f"[agent] tier={tier} message={clean_message[:60]!r}", flush=True)
|
||||
try:
|
||||
# Short-circuit: fast tool already has the answer
|
||||
if fast_context and tier_override is None and not url_context:
|
||||
tier = "fast"
|
||||
final_text = fast_context
|
||||
llm_elapsed = time.monotonic() - t0
|
||||
names = _fast_tool_runner.matching_names(clean_message)
|
||||
print(f"[agent] tier=fast tools={names} — delivering directly", flush=True)
|
||||
yield final_text
|
||||
|
||||
else:
|
||||
# Determine tier
|
||||
if tier_override in ("light", "medium", "complex"):
|
||||
tier = tier_override
|
||||
light_reply = None
|
||||
if tier_override == "light":
|
||||
tier, light_reply = await router.route(clean_message, enriched_history)
|
||||
tier = "light"
|
||||
else:
|
||||
tier, light_reply = await router.route(clean_message, enriched_history)
|
||||
if url_context and tier == "light":
|
||||
tier = "medium"
|
||||
light_reply = None
|
||||
print("[agent] URL in message → upgraded light→medium", flush=True)
|
||||
|
||||
# Dry-run: log as complex but infer with medium (no remote API call)
|
||||
effective_tier = tier
|
||||
if dry_run and tier == "complex":
|
||||
effective_tier = "medium"
|
||||
print(f"[agent] tier=complex (dry-run) → using medium model, message={clean_message[:60]!r}", flush=True)
|
||||
else:
|
||||
print(f"[agent] tier={tier} message={clean_message[:60]!r}", flush=True)
|
||||
tier = effective_tier
|
||||
|
||||
if tier != "fast":
|
||||
final_text = None
|
||||
try:
|
||||
if tier == "light":
|
||||
final_text = light_reply
|
||||
llm_elapsed = time.monotonic() - t0
|
||||
print(f"[agent] light path: answered by router", flush=True)
|
||||
await _push_stream_chunk(session_id, final_text)
|
||||
await _end_stream(session_id)
|
||||
print("[agent] light path: answered by router", flush=True)
|
||||
yield final_text
|
||||
|
||||
elif tier == "medium":
|
||||
system_prompt = MEDIUM_SYSTEM_PROMPT
|
||||
if memories:
|
||||
system_prompt = system_prompt + "\n\n" + memories
|
||||
system_prompt += "\n\n" + memories
|
||||
if url_context:
|
||||
system_prompt = system_prompt + "\n\n" + url_context
|
||||
system_prompt += "\n\n" + url_context
|
||||
if fast_context:
|
||||
system_prompt = system_prompt + "\n\nLive web search results (use these to answer):\n\n" + fast_context
|
||||
system_prompt += "\n\nLive web search results (use these to answer):\n\n" + fast_context
|
||||
|
||||
# Stream tokens directly — filter out qwen3 <think> blocks
|
||||
in_think = False
|
||||
response_parts = []
|
||||
async for chunk in medium_model.astream([
|
||||
@@ -464,91 +532,117 @@ async def run_agent_task(message: str, session_id: str, channel: str = "telegram
|
||||
in_think = False
|
||||
after = token.split("</think>", 1)[1]
|
||||
if after:
|
||||
await _push_stream_chunk(session_id, after)
|
||||
yield after
|
||||
response_parts.append(after)
|
||||
else:
|
||||
if "<think>" in token:
|
||||
in_think = True
|
||||
before = token.split("<think>", 1)[0]
|
||||
if before:
|
||||
await _push_stream_chunk(session_id, before)
|
||||
yield before
|
||||
response_parts.append(before)
|
||||
else:
|
||||
await _push_stream_chunk(session_id, token)
|
||||
yield token
|
||||
response_parts.append(token)
|
||||
|
||||
await _end_stream(session_id)
|
||||
llm_elapsed = time.monotonic() - t0
|
||||
final_text = "".join(response_parts).strip() or None
|
||||
|
||||
else: # complex
|
||||
ok = await vram_manager.enter_complex_mode()
|
||||
if not ok:
|
||||
print("[agent] complex→medium fallback (eviction timeout)", flush=True)
|
||||
tier = "medium"
|
||||
system_prompt = MEDIUM_SYSTEM_PROMPT
|
||||
if memories:
|
||||
system_prompt = system_prompt + "\n\n" + memories
|
||||
if url_context:
|
||||
system_prompt = system_prompt + "\n\n" + url_context
|
||||
result = await medium_agent.ainvoke({
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
*history,
|
||||
{"role": "user", "content": clean_message},
|
||||
]
|
||||
})
|
||||
else:
|
||||
system_prompt = COMPLEX_SYSTEM_PROMPT.format(user_id=session_id)
|
||||
if url_context:
|
||||
system_prompt = system_prompt + "\n\n[Pre-fetched URL content from user's message:]\n" + url_context
|
||||
result = await complex_agent.ainvoke({
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
*history,
|
||||
{"role": "user", "content": clean_message},
|
||||
]
|
||||
})
|
||||
asyncio.create_task(vram_manager.exit_complex_mode())
|
||||
else: # complex — remote model, no VRAM management needed
|
||||
system_prompt = COMPLEX_SYSTEM_PROMPT.format(user_id=session_id)
|
||||
if url_context:
|
||||
system_prompt += "\n\n[Pre-fetched URL content from user's message:]\n" + url_context
|
||||
result = await complex_agent.ainvoke({
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
*history,
|
||||
{"role": "user", "content": clean_message},
|
||||
]
|
||||
})
|
||||
|
||||
llm_elapsed = time.monotonic() - t0
|
||||
_log_messages(result)
|
||||
final_text = _extract_final_text(result)
|
||||
if final_text:
|
||||
await _push_stream_chunk(session_id, final_text)
|
||||
await _end_stream(session_id)
|
||||
yield final_text
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
llm_elapsed = time.monotonic() - t0
|
||||
print(f"[agent] error after {llm_elapsed:.1f}s for chat {session_id}: {e}", flush=True)
|
||||
traceback.print_exc()
|
||||
await _end_stream(session_id)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
llm_elapsed = time.monotonic() - t0
|
||||
print(f"[agent] error after {llm_elapsed:.1f}s for {session_id}: {e}", flush=True)
|
||||
traceback.print_exc()
|
||||
|
||||
# Deliver reply through the originating channel
|
||||
print(f"[agent] pipeline done in {time.monotonic() - t0:.1f}s tier={tier if 'tier' in dir() else '?'}", flush=True)
|
||||
|
||||
# Store memory as side-effect (non-blocking, best-effort)
|
||||
if final_text:
|
||||
t1 = time.monotonic()
|
||||
asyncio.create_task(_store_memory(session_id, clean_message, final_text))
|
||||
|
||||
|
||||
# ── core task (Telegram / Matrix / CLI wrapper) ─────────────────────────────────
|
||||
|
||||
async def run_agent_task(
|
||||
message: str,
|
||||
session_id: str,
|
||||
channel: str = "telegram",
|
||||
metadata: dict | None = None,
|
||||
):
|
||||
print(f"[agent] queued: {message[:80]!r} chat={session_id}", flush=True)
|
||||
t0 = time.monotonic()
|
||||
|
||||
meta = metadata or {}
|
||||
dry_run = bool(meta.get("dry_run", False))
|
||||
is_benchmark = bool(meta.get("benchmark", False))
|
||||
|
||||
history = _conversation_buffers.get(session_id, [])
|
||||
final_text = None
|
||||
actual_tier = "unknown"
|
||||
|
||||
# Patch pipeline to capture tier for logging
|
||||
# We read it from logs post-hoc; capture via a wrapper
|
||||
async for chunk in _run_agent_pipeline(message, history, session_id, dry_run=dry_run):
|
||||
await _push_stream_chunk(session_id, chunk)
|
||||
if final_text is None:
|
||||
final_text = chunk
|
||||
else:
|
||||
final_text += chunk
|
||||
|
||||
await _end_stream(session_id)
|
||||
|
||||
elapsed_ms = int((time.monotonic() - t0) * 1000)
|
||||
|
||||
if final_text:
|
||||
final_text = final_text.strip()
|
||||
|
||||
# Skip channel delivery for benchmark sessions (no Telegram spam)
|
||||
if not is_benchmark:
|
||||
try:
|
||||
await channels.deliver(session_id, channel, final_text)
|
||||
except Exception as e:
|
||||
print(f"[agent] delivery error (non-fatal): {e}", flush=True)
|
||||
send_elapsed = time.monotonic() - t1
|
||||
print(
|
||||
f"[agent] replied in {time.monotonic() - t0:.1f}s "
|
||||
f"(llm={llm_elapsed:.1f}s, send={send_elapsed:.1f}s) tier={tier}",
|
||||
flush=True,
|
||||
)
|
||||
print(f"[agent] reply_text: {final_text}", flush=True)
|
||||
else:
|
||||
print("[agent] warning: no text reply from agent", flush=True)
|
||||
|
||||
# Update conversation buffer and schedule memory storage
|
||||
if final_text:
|
||||
buf = _conversation_buffers.get(session_id, [])
|
||||
buf.append({"role": "user", "content": clean_message})
|
||||
buf.append({"role": "assistant", "content": final_text})
|
||||
_conversation_buffers[session_id] = buf[-(MAX_HISTORY_TURNS * 2):]
|
||||
asyncio.create_task(_store_memory(session_id, clean_message, final_text))
|
||||
print(f"[agent] replied in {elapsed_ms / 1000:.1f}s", flush=True)
|
||||
print(f"[agent] reply_text: {final_text[:200]}", flush=True)
|
||||
|
||||
# Update conversation buffer
|
||||
buf = _conversation_buffers.get(session_id, [])
|
||||
buf.append({"role": "user", "content": message})
|
||||
buf.append({"role": "assistant", "content": final_text})
|
||||
_conversation_buffers[session_id] = buf[-(MAX_HISTORY_TURNS * 2):]
|
||||
|
||||
# Log interaction for RLHF data collection (skip benchmark sessions to avoid noise)
|
||||
if not is_benchmark:
|
||||
asyncio.create_task(_log_interaction(
|
||||
session_id=session_id,
|
||||
channel=channel,
|
||||
tier=actual_tier,
|
||||
input_text=message,
|
||||
response_text=final_text,
|
||||
latency_ms=elapsed_ms,
|
||||
metadata=meta if meta else None,
|
||||
))
|
||||
else:
|
||||
print("[agent] warning: no text reply from agent", flush=True)
|
||||
|
||||
|
||||
# ── endpoints ──────────────────────────────────────────────────────────────────
|
||||
@@ -560,7 +654,7 @@ async def message(request: InboundMessage, background_tasks: BackgroundTasks):
|
||||
return JSONResponse(status_code=503, content={"error": "Agent not ready"})
|
||||
session_id = request.session_id
|
||||
channel = request.channel
|
||||
background_tasks.add_task(run_agent_task, request.text, session_id, channel)
|
||||
background_tasks.add_task(run_agent_task, request.text, session_id, channel, request.metadata)
|
||||
return JSONResponse(status_code=202, content={"status": "accepted"})
|
||||
|
||||
|
||||
@@ -622,3 +716,96 @@ async def stream_reply(session_id: str):
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok", "agent_ready": medium_agent is not None}
|
||||
|
||||
|
||||
# ── OpenAI-compatible API (for OpenWebUI and other clients) ────────────────────
|
||||
|
||||
_TIER_MAP = {
|
||||
"adolf": None,
|
||||
"adolf-light": "light",
|
||||
"adolf-medium": "medium",
|
||||
"adolf-deep": "complex",
|
||||
}
|
||||
|
||||
|
||||
@app.get("/v1/models")
|
||||
async def list_models():
|
||||
return {
|
||||
"object": "list",
|
||||
"data": [
|
||||
{"id": "adolf", "object": "model", "owned_by": "adolf"},
|
||||
{"id": "adolf-light", "object": "model", "owned_by": "adolf"},
|
||||
{"id": "adolf-medium", "object": "model", "owned_by": "adolf"},
|
||||
{"id": "adolf-deep", "object": "model", "owned_by": "adolf"},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def chat_completions(request: Request):
|
||||
if medium_agent is None:
|
||||
return JSONResponse(status_code=503, content={"error": {"message": "Agent not ready", "type": "server_error"}})
|
||||
|
||||
body = await request.json()
|
||||
model = body.get("model", "adolf")
|
||||
messages = body.get("messages", [])
|
||||
stream = body.get("stream", True)
|
||||
|
||||
# Extract current user message and history
|
||||
user_messages = [m for m in messages if m.get("role") == "user"]
|
||||
if not user_messages:
|
||||
return JSONResponse(status_code=400, content={"error": {"message": "No user message", "type": "invalid_request_error"}})
|
||||
|
||||
current_message = user_messages[-1]["content"]
|
||||
# History = everything before the last user message (excluding system messages from OpenWebUI)
|
||||
last_user_idx = len(messages) - 1 - next(
|
||||
i for i, m in enumerate(reversed(messages)) if m.get("role") == "user"
|
||||
)
|
||||
history = [m for m in messages[:last_user_idx] if m.get("role") in ("user", "assistant")]
|
||||
|
||||
session_id = request.headers.get("X-Session-Id", "owui-default")
|
||||
tier_override = _TIER_MAP.get(model)
|
||||
|
||||
import json as _json
|
||||
import uuid as _uuid
|
||||
|
||||
response_id = f"chatcmpl-{_uuid.uuid4().hex[:12]}"
|
||||
|
||||
if stream:
|
||||
async def event_stream():
|
||||
# Opening chunk with role
|
||||
opening = {
|
||||
"id": response_id, "object": "chat.completion.chunk",
|
||||
"choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}]
|
||||
}
|
||||
yield f"data: {_json.dumps(opening)}\n\n"
|
||||
|
||||
async for chunk in _run_agent_pipeline(current_message, history, session_id, tier_override):
|
||||
data = {
|
||||
"id": response_id, "object": "chat.completion.chunk",
|
||||
"choices": [{"index": 0, "delta": {"content": chunk}, "finish_reason": None}]
|
||||
}
|
||||
yield f"data: {_json.dumps(data)}\n\n"
|
||||
|
||||
# Final chunk
|
||||
final = {
|
||||
"id": response_id, "object": "chat.completion.chunk",
|
||||
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]
|
||||
}
|
||||
yield f"data: {_json.dumps(final)}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(event_stream(), media_type="text/event-stream")
|
||||
|
||||
else:
|
||||
# Non-streaming: collect all chunks
|
||||
parts = []
|
||||
async for chunk in _run_agent_pipeline(current_message, history, session_id, tier_override):
|
||||
if chunk:
|
||||
parts.append(chunk)
|
||||
full_text = "".join(parts).strip()
|
||||
return {
|
||||
"id": response_id, "object": "chat.completion",
|
||||
"choices": [{"index": 0, "message": {"role": "assistant", "content": full_text}, "finish_reason": "stop"}],
|
||||
"model": model,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user