Add Rich token streaming: server SSE + CLI live display + CLI container
Server (agent.py):
- _stream_queues: per-session asyncio.Queue for token chunks
- _push_stream_chunk() / _end_stream() helpers
- Medium tier: astream() with <think> block filtering — real token streaming
- Light tier: full reply pushed as single chunk then [DONE]
- Complex tier: full reply pushed after agent completes then [DONE]
- GET /stream/{session_id} SSE endpoint (data: <chunk>\n\n, data: [DONE]\n\n)
- medium_model promoted to module-level global for astream() access
CLI (cli.py):
- stream_reply(): reads /stream/ SSE, renders tokens live with Rich Live (transient)
- Final reply rendered as Markdown after stream completes
- os.getlogin() replaced with os.getenv("USER") for container compatibility
Dockerfile.cli + docker-compose cli service (profiles: tools):
- Run: docker compose --profile tools run --rm -it cli
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in: agent.py (91 changed lines)
@@ -41,6 +41,19 @@ CRAWL4AI_URL = os.getenv("CRAWL4AI_URL", "http://crawl4ai:11235")
|
||||
MAX_HISTORY_TURNS = 5
|
||||
_conversation_buffers: dict[str, list] = {}
|
||||
|
||||
# Per-session streaming queues — filled during inference, read by /stream/{session_id}
|
||||
_stream_queues: dict[str, asyncio.Queue] = {}
|
||||
|
||||
|
||||
async def _push_stream_chunk(session_id: str, chunk: str) -> None:
    """Enqueue one reply token for *session_id*, creating its queue on first use."""
    queue = _stream_queues.setdefault(session_id, asyncio.Queue())
    await queue.put(chunk)
|
||||
|
||||
|
||||
async def _end_stream(session_id: str) -> None:
    """Signal end-of-reply by enqueueing the "[DONE]" sentinel for *session_id*."""
    queue = _stream_queues.setdefault(session_id, asyncio.Queue())
    await queue.put("[DONE]")
|
||||
|
||||
|
||||
async def _crawl4ai_fetch_async(url: str) -> str:
|
||||
"""Async fetch via Crawl4AI — JS-rendered, bot-bypass, returns clean markdown."""
|
||||
@@ -95,6 +108,7 @@ COMPLEX_SYSTEM_PROMPT = (
|
||||
"NEVER invent URLs. End with: **Sources checked: N**"
|
||||
)
|
||||
|
||||
medium_model = None
|
||||
medium_agent = None
|
||||
complex_agent = None
|
||||
router: Router = None
|
||||
@@ -109,7 +123,7 @@ _reply_semaphore = asyncio.Semaphore(1)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
global medium_agent, complex_agent, router, vram_manager, mcp_client, \
|
||||
global medium_model, medium_agent, complex_agent, router, vram_manager, mcp_client, \
|
||||
_memory_add_tool, _memory_search_tool
|
||||
|
||||
# Register channel adapters
|
||||
@@ -263,6 +277,7 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
yield
|
||||
|
||||
medium_model = None
|
||||
medium_agent = None
|
||||
complex_agent = None
|
||||
router = None
|
||||
@@ -394,6 +409,8 @@ async def run_agent_task(message: str, session_id: str, channel: str = "telegram
|
||||
final_text = light_reply
|
||||
llm_elapsed = time.monotonic() - t0
|
||||
print(f"[agent] light path: answered by router", flush=True)
|
||||
await _push_stream_chunk(session_id, final_text)
|
||||
await _end_stream(session_id)
|
||||
|
||||
elif tier == "medium":
|
||||
system_prompt = MEDIUM_SYSTEM_PROMPT
|
||||
@@ -401,16 +418,39 @@ async def run_agent_task(message: str, session_id: str, channel: str = "telegram
|
||||
system_prompt = system_prompt + "\n\n" + memories
|
||||
if url_context:
|
||||
system_prompt = system_prompt + "\n\n" + url_context
|
||||
result = await medium_agent.ainvoke({
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
*history,
|
||||
{"role": "user", "content": clean_message},
|
||||
]
|
||||
})
|
||||
|
||||
# Stream tokens directly — filter out qwen3 <think> blocks
|
||||
in_think = False
|
||||
response_parts = []
|
||||
async for chunk in medium_model.astream([
|
||||
{"role": "system", "content": system_prompt},
|
||||
*history,
|
||||
{"role": "user", "content": clean_message},
|
||||
]):
|
||||
token = chunk.content or ""
|
||||
if not token:
|
||||
continue
|
||||
if in_think:
|
||||
if "</think>" in token:
|
||||
in_think = False
|
||||
after = token.split("</think>", 1)[1]
|
||||
if after:
|
||||
await _push_stream_chunk(session_id, after)
|
||||
response_parts.append(after)
|
||||
else:
|
||||
if "<think>" in token:
|
||||
in_think = True
|
||||
before = token.split("<think>", 1)[0]
|
||||
if before:
|
||||
await _push_stream_chunk(session_id, before)
|
||||
response_parts.append(before)
|
||||
else:
|
||||
await _push_stream_chunk(session_id, token)
|
||||
response_parts.append(token)
|
||||
|
||||
await _end_stream(session_id)
|
||||
llm_elapsed = time.monotonic() - t0
|
||||
_log_messages(result)
|
||||
final_text = _extract_final_text(result)
|
||||
final_text = "".join(response_parts).strip() or None
|
||||
|
||||
else: # complex
|
||||
ok = await vram_manager.enter_complex_mode()
|
||||
@@ -432,7 +472,6 @@ async def run_agent_task(message: str, session_id: str, channel: str = "telegram
|
||||
else:
|
||||
system_prompt = COMPLEX_SYSTEM_PROMPT.format(user_id=session_id)
|
||||
if url_context:
|
||||
# Inject pre-fetched content — complex agent can still re-fetch or follow links
|
||||
system_prompt = system_prompt + "\n\n[Pre-fetched URL content from user's message:]\n" + url_context
|
||||
result = await complex_agent.ainvoke({
|
||||
"messages": [
|
||||
@@ -446,12 +485,16 @@ async def run_agent_task(message: str, session_id: str, channel: str = "telegram
|
||||
llm_elapsed = time.monotonic() - t0
|
||||
_log_messages(result)
|
||||
final_text = _extract_final_text(result)
|
||||
if final_text:
|
||||
await _push_stream_chunk(session_id, final_text)
|
||||
await _end_stream(session_id)
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
llm_elapsed = time.monotonic() - t0
|
||||
print(f"[agent] error after {llm_elapsed:.1f}s for chat {session_id}: {e}", flush=True)
|
||||
traceback.print_exc()
|
||||
await _end_stream(session_id)
|
||||
|
||||
# Deliver reply through the originating channel
|
||||
if final_text:
|
||||
@@ -521,6 +564,32 @@ async def reply_stream(session_id: str):
|
||||
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@app.get("/stream/{session_id}")
async def stream_reply(session_id: str):
    """
    SSE endpoint — streams reply tokens as they are generated.
    Each chunk: data: <token>\\n\\n
    Signals completion: data: [DONE]\\n\\n

    Medium tier: real token-by-token streaming (think blocks filtered out).
    Light and complex tiers: full reply delivered as one chunk then [DONE].
    """
    q = _stream_queues.setdefault(session_id, asyncio.Queue())

    async def event_generator():
        try:
            while True:
                chunk = await asyncio.wait_for(q.get(), timeout=900)
                # Raw newlines would prematurely terminate the SSE "data:" field;
                # escape them so the client can un-escape on receipt.
                escaped = chunk.replace("\n", "\\n").replace("\r", "")
                yield f"data: {escaped}\n\n"
                if chunk == "[DONE]":
                    break
        except asyncio.TimeoutError:
            # No token for 15 minutes — end the stream so the client stops waiting.
            yield "data: [DONE]\n\n"
        finally:
            # Drop the per-session queue once this stream finishes (or the client
            # disconnects) so _stream_queues does not grow without bound; the
            # producer side re-creates it via setdefault() if a new reply starts.
            _stream_queues.pop(session_id, None)

    return StreamingResponse(event_generator(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok", "agent_ready": medium_agent is not None}
|
||||
|
||||
Reference in New Issue
Block a user