diff --git a/Dockerfile b/Dockerfile
index 22b7a8e..6b8f6e7 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -5,6 +5,6 @@ WORKDIR /app
 RUN pip install --no-cache-dir deepagents langchain-openai langgraph \
     fastapi uvicorn langchain-mcp-adapters langchain-community httpx
-COPY agent.py channels.py vram_manager.py router.py agent_factory.py hello_world.py .
+COPY agent.py channels.py vram_manager.py router.py agent_factory.py fast_tools.py hello_world.py .
 
 CMD ["uvicorn", "agent:app", "--host", "0.0.0.0", "--port", "8000"]
diff --git a/agent.py b/agent.py
index 896079c..3db9ace 100644
--- a/agent.py
+++ b/agent.py
@@ -12,16 +12,6 @@
 import httpx as _httpx
 
 _URL_RE = _re.compile(r'https?://[^\s<>"\']+')
 
-# Queries that need live data — trigger pre-search enrichment for medium tier
-_REALTIME_RE = _re.compile(
-    r"\b(weather|forecast|temperature|rain(ing)?|snow(ing)?|humidity|wind speed"
-    r"|today.?s news|breaking news|latest news|news today|current events"
-    r"|bitcoin price|crypto price|stock price|exchange rate"
-    r"|right now|currently|at the moment|live score|score now|score today"
-    r"|open now|hours today|is .+ open)\b",
-    _re.IGNORECASE,
-)
-
 def _extract_urls(text: str) -> list[str]:
     return _URL_RE.findall(text)
@@ -34,6 +24,7 @@ from langchain_core.tools import Tool
 from vram_manager import VRAMManager
 from router import Router
 from agent_factory import build_medium_agent, build_complex_agent
+from fast_tools import FastToolRunner, RealTimeSearchTool
 import channels
 
 # Bifrost gateway — all LLM inference goes through here
@@ -98,29 +89,6 @@ async def _fetch_urls_from_message(message: str) -> str:
     return "User's message contains URLs. Fetched content:\n\n" + "\n\n".join(parts)
 
 
-async def _searxng_search_async(query: str) -> str:
-    """Run a SearXNG search and return top result snippets as text for prompt injection.
-    Kept short (snippets only) so medium model context stays within streaming timeout."""
-    try:
-        async with _httpx.AsyncClient(timeout=15) as client:
-            r = await client.get(
-                f"{SEARXNG_URL}/search",
-                params={"q": query, "format": "json"},
-            )
-            r.raise_for_status()
-            items = r.json().get("results", [])[:4]
-    except Exception as e:
-        return f"[search error: {e}]"
-    if not items:
-        return ""
-    lines = [f"Web search results for: {query}\n"]
-    for i, item in enumerate(items, 1):
-        title = item.get("title", "")
-        url = item.get("url", "")
-        snippet = item.get("content", "")[:400]
-        lines.append(f"[{i}] {title}\nURL: {url}\n{snippet}\n")
-    return "\n".join(lines)
-
 
 # /no_think at the start of the system prompt disables qwen3 chain-of-thought.
 # create_deep_agent prepends our system_prompt before BASE_AGENT_PROMPT, so
@@ -151,6 +119,11 @@
 mcp_client = None
 _memory_add_tool = None
 _memory_search_tool = None
 
+# Fast tools run before the LLM — classifier + context enricher
+_fast_tool_runner = FastToolRunner([
+    RealTimeSearchTool(searxng_url=SEARXNG_URL),
+])
+
 # GPU mutex: one LLM inference at a time
 _reply_semaphore = asyncio.Semaphore(1)
@@ -188,7 +161,7 @@ async def lifespan(app: FastAPI):
     )
 
     vram_manager = VRAMManager(base_url=OLLAMA_BASE_URL)
-    router = Router(model=router_model)
+    router = Router(model=router_model, fast_tool_runner=_fast_tool_runner)
 
     mcp_connections = {
         "openmemory": {"transport": "sse", "url": f"{OPENMEMORY_URL}/sse"},
@@ -413,33 +386,24 @@ async def run_agent_task(message: str, session_id: str, channel: str = "telegram
     history = _conversation_buffers.get(session_id, [])
     print(f"[agent] running: {clean_message[:80]!r}", flush=True)
 
-    # Fetch URL content, memories, and (for real-time queries) web search — all IO-bound
-    is_realtime = bool(_REALTIME_RE.search(clean_message))
-    if is_realtime:
-        url_context, memories, search_context = await asyncio.gather(
-            _fetch_urls_from_message(clean_message),
-            _retrieve_memories(clean_message, session_id),
-            _searxng_search_async(clean_message),
-        )
-        if search_context and not search_context.startswith("[search error"):
-            print(f"[agent] pre-search: {len(search_context)} chars for real-time query", flush=True)
-        else:
-            search_context = ""
-    else:
-        url_context, memories = await asyncio.gather(
-            _fetch_urls_from_message(clean_message),
-            _retrieve_memories(clean_message, session_id),
-        )
-        search_context = ""
+    # Fetch URL content, memories, and fast-tool context concurrently — all IO-bound
+    url_context, memories, fast_context = await asyncio.gather(
+        _fetch_urls_from_message(clean_message),
+        _retrieve_memories(clean_message, session_id),
+        _fast_tool_runner.run_matching(clean_message),
+    )
 
     if url_context:
         print(f"[agent] crawl4ai: {len(url_context)} chars fetched from message URLs", flush=True)
+    if fast_context:
+        names = _fast_tool_runner.matching_names(clean_message)
+        print(f"[agent] fast_tools={names}: {len(fast_context)} chars injected", flush=True)
 
-    # Build enriched history: memories + url_context + search_context for ALL tiers
+    # Build enriched history: memories + url_context + fast_context for ALL tiers
     enriched_history = list(history)
     if url_context:
         enriched_history = [{"role": "system", "content": url_context}] + enriched_history
-    if search_context:
-        enriched_history = [{"role": "system", "content": search_context}] + enriched_history
+    if fast_context:
+        enriched_history = [{"role": "system", "content": fast_context}] + enriched_history
     if memories:
         enriched_history = [{"role": "system", "content": memories}] + enriched_history
@@ -467,8 +431,8 @@
         system_prompt = system_prompt + "\n\n" + memories
     if url_context:
         system_prompt = system_prompt + "\n\n" + url_context
-    if search_context:
-        system_prompt = system_prompt + "\n\nLive web search results (use these to answer):\n\n" + search_context
+    if fast_context:
+        system_prompt = system_prompt + "\n\nLive web search results (use these to answer):\n\n" + fast_context
 
     # Stream tokens directly — filter out qwen3 blocks
     in_think = False
diff --git a/fast_tools.py b/fast_tools.py
new file mode 100644
index 0000000..b4ac2e4
--- /dev/null
+++ b/fast_tools.py
@@ -0,0 +1,116 @@
+"""
+Fast Tools — pre-flight tools invoked by a classifier before the main LLM call.
+
+Each FastTool has:
+  - matches(message) → bool : regex classifier that determines if this tool applies
+  - run(message) → str : async fetch that returns enrichment context
+
+FastToolRunner holds a list of FastTools. The Router uses any_matches() to force
+the tier to medium before LLM classification. run_agent_task() calls run_matching()
+to build extra context that is injected into the system prompt.
+
+To add a new fast tool:
+  1. Subclass FastTool, implement name/matches/run
+  2. Add an instance to the list passed to FastToolRunner in agent.py
+"""
+
+import asyncio
+import re
+from abc import ABC, abstractmethod
+
+import httpx
+
+
+class FastTool(ABC):
+    """Base class for all pre-flight fast tools."""
+
+    @property
+    @abstractmethod
+    def name(self) -> str: ...
+
+    @abstractmethod
+    def matches(self, message: str) -> bool: ...
+
+    @abstractmethod
+    async def run(self, message: str) -> str: ...
+
+
+class RealTimeSearchTool(FastTool):
+    """
+    Injects live SearXNG search snippets for queries that require real-time data:
+    weather, news, prices, scores, business hours.
+
+    Matched queries are also forced to medium tier by the Router so the richer
+    model handles the injected context.
+    """
+
+    _PATTERN = re.compile(
+        r"\b(weather|forecast|temperature|rain(ing)?|snow(ing)?|humidity|wind\s*speed"
+        r"|today.?s news|breaking news|latest news|news today|current events"
+        r"|bitcoin price|crypto price|stock price|exchange rate"
+        r"|right now|currently|at the moment|live score|score now|score today"
+        r"|open now|hours today|is .+ open)\b",
+        re.IGNORECASE,
+    )
+
+    def __init__(self, searxng_url: str):
+        self._searxng_url = searxng_url
+
+    @property
+    def name(self) -> str:
+        return "real_time_search"
+
+    def matches(self, message: str) -> bool:
+        return bool(self._PATTERN.search(message))
+
+    async def run(self, message: str) -> str:
+        """Search SearXNG and return top snippets as a context block."""
+        try:
+            async with httpx.AsyncClient(timeout=15) as client:
+                r = await client.get(
+                    f"{self._searxng_url}/search",
+                    params={"q": message, "format": "json"},
+                )
+                r.raise_for_status()
+                items = r.json().get("results", [])[:4]
+        except Exception as e:
+            return f"[real_time_search error: {e}]"
+        if not items:
+            return ""
+        lines = [f"Live web search results for: {message}\n"]
+        for i, item in enumerate(items, 1):
+            title = item.get("title", "")
+            url = item.get("url", "")
+            snippet = item.get("content", "")[:400]
+            lines.append(f"[{i}] {title}\nURL: {url}\n{snippet}\n")
+        return "\n".join(lines)
+
+
+class FastToolRunner:
+    """
+    Classifier + executor for fast tools.
+
+    Used in two places:
+      - Router.route(): any_matches() forces medium tier before LLM classification
+      - run_agent_task(): run_matching() builds enrichment context in the pre-flight gather
+    """
+
+    def __init__(self, tools: list[FastTool]):
+        self._tools = tools
+
+    def any_matches(self, message: str) -> bool:
+        """True if any fast tool applies to this message."""
+        return any(t.matches(message) for t in self._tools)
+
+    def matching_names(self, message: str) -> list[str]:
+        """Names of tools that match this message (for logging)."""
+        return [t.name for t in self._tools if t.matches(message)]
+
+    async def run_matching(self, message: str) -> str:
+        """Run all matching tools concurrently and return combined context."""
+        matching = [t for t in self._tools if t.matches(message)]
+        if not matching:
+            return ""
+        results = await asyncio.gather(*[t.run(message) for t in matching])
+        parts = [r for r in results if r and not r.startswith("[")]
+        return "\n\n".join(parts)
diff --git a/router.py b/router.py
index 03f1559..a68e7c3 100644
--- a/router.py
+++ b/router.py
@@ -1,6 +1,7 @@
 import re
 from typing import Optional
 from langchain_core.messages import SystemMessage, HumanMessage
+from fast_tools import FastToolRunner
 
 # ── Regex pre-classifier ──────────────────────────────────────────────────────
 # Catches obvious light-tier patterns before calling the LLM.
@@ -23,16 +24,6 @@
     re.IGNORECASE,
 )
 
-# Queries that require live data — never answer from static knowledge
-_MEDIUM_FORCE_PATTERNS = re.compile(
-    r"\b(weather|forecast|temperature|rain(ing)?|snow(ing)?|humidity|wind speed"
-    r"|today.?s news|breaking news|latest news|news today|current events"
-    r"|bitcoin price|crypto price|stock price|exchange rate|usd|eur|btc"
-    r"|right now|currently|at the moment|live score|score now|score today"
-    r"|open now|hours today|is .+ open)\b",
-    re.IGNORECASE,
-)
-
 # ── LLM classification prompt ─────────────────────────────────────────────────
 CLASSIFY_PROMPT = """Classify the message. Output ONLY one word: light, medium, or complex.
@@ -83,8 +74,9 @@ def _parse_tier(text: str) -> str:
 
 class Router:
-    def __init__(self, model):
+    def __init__(self, model, fast_tool_runner: FastToolRunner | None = None):
         self.model = model
+        self._fast_tool_runner = fast_tool_runner
 
     async def route(
         self,
@@ -100,9 +92,10 @@
         if force_complex:
             return "complex", None
 
-        # Step 0a: force medium for real-time / live-data queries
-        if _MEDIUM_FORCE_PATTERNS.search(message.strip()):
-            print(f"[router] regex→medium (real-time query)", flush=True)
+        # Step 0a: force medium if any fast tool matches (live-data queries)
+        if self._fast_tool_runner and self._fast_tool_runner.any_matches(message.strip()):
+            names = self._fast_tool_runner.matching_names(message.strip())
+            print(f"[router] fast_tool_match={names} → medium", flush=True)
             return "medium", None
 
         # Step 0b: regex pre-classification for obvious light patterns