""" oO ML Serving — multi-agent orchestrator (ADR-0013). Contract: POST /agents/{agent_id}/compute run a sub-agent, return prompt snippet POST /agents/{agent_id}/infer run inference framework for a user, return inferred prefs POST /recommend orchestrate agent snippets → one tip via LiteLLM POST /generate LLM tip candidates (legacy; kept for bench/eval) GET /health { ok, agents: [...] } """ from __future__ import annotations import json import os import sys from contextlib import asynccontextmanager from datetime import datetime, timezone from pathlib import Path from typing import Optional import httpx import sentry_sdk import structlog import structlog.contextvars from fastapi import FastAPI, HTTPException, Request from pydantic import BaseModel from starlette.middleware.base import BaseHTTPMiddleware import logging_config import nats_consumer from prompts import get_prompt, build_orchestrator_messages # Make ml.agents importable regardless of working directory. # In Docker (WORKDIR=/app/ml/serving, PYTHONPATH=/app): /app already on path. # In local dev (run from ml/serving/): repo root is two levels up. 
_repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
if _repo_root not in sys.path:
    sys.path.insert(0, _repo_root)

import time  # noqa: E402
from typing import Any, Callable  # noqa: E402

from ml.agents.base import AgentInput  # noqa: E402
from ml.agents.registry import get_agent, all_agents, all_manifests, get_manifest  # noqa: E402
from ml.agents.inference import run_inference, FeedbackEvent, TaskCompletion, UserHistory  # noqa: E402

logging_config.configure()

_SENTRY_DSN = os.getenv("SENTRY_DSN")
if _SENTRY_DSN:
    sentry_sdk.init(dsn=_SENTRY_DSN, environment=os.getenv("ENV", "development"))

log = structlog.get_logger()


# ── Configuration ──────────────────────────────────────────────────────────

LITELLM_URL = os.getenv("LITELLM_URL", "http://localhost:4000")
LITELLM_MASTER_KEY = os.getenv("LITELLM_MASTER_KEY", "sk-oo-dev")
STATE_DIR = Path(os.getenv("STATE_DIR", "/tmp/oo-serving-state"))
STATE_DIR.mkdir(parents=True, exist_ok=True)

# Model alias, timeout and sampling temperature used for every LiteLLM call.
_LLM_MODEL_ALIAS = "tip-generator"
_LLM_TIMEOUT_S = 30.0
_LLM_TEMPERATURE = 0.7

# Number of extra attempts after the first one when the LLM reply is not
# parseable JSON (so the total number of calls is 1 + this value).
_MAX_GENERATE_RETRIES = 2

_RETRY_SUFFIX = (
    "\n\nYour previous response was not valid JSON. "
    "Reply ONLY with the JSON array — no prose, no markdown fences."
)

_RETRY_SUFFIX_OBJ = (
    "\n\nYour previous response was not valid JSON. "
    "Reply ONLY with the JSON object — no prose, no markdown fences."
)


# ── Application ────────────────────────────────────────────────────────────

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start the NATS consumer for the app's lifetime and stop it afterwards.

    The stop call sits in ``finally`` so the consumer is also shut down when
    an exception is thrown into the generator while the app is running.
    """
    await nats_consumer.start(STATE_DIR)
    try:
        yield
    finally:
        await nats_consumer.stop()


app = FastAPI(title="oO ML Serving", version="1.0.0", lifespan=lifespan)


class _TracingMiddleware(BaseHTTPMiddleware):
    """Bind the trace id from the ``traceparent`` header to the log context."""

    async def dispatch(self, request: Request, call_next):
        """Reset structlog context vars, bind ``trace_id`` if present, continue.

        The header is split on ``-``; the second field is used as the trace id
        only when there are exactly four fields and that field is 32 chars long.
        Anything else is ignored and the request proceeds without a trace id.
        """
        # Clear first so values bound while serving an earlier request are
        # never carried over into this one.
        structlog.contextvars.clear_contextvars()
        traceparent = request.headers.get("traceparent", "")
        if traceparent:
            parts = traceparent.split("-")
            trace_id = parts[1] if len(parts) == 4 and len(parts[1]) == 32 else None
            if trace_id:
                structlog.contextvars.bind_contextvars(trace_id=trace_id)
        return await call_next(request)


app.add_middleware(_TracingMiddleware)


# ── API models ─────────────────────────────────────────────────────────────

class PromptContext(BaseModel):
    """Context handed to the prompt template by ``/generate``."""

    tasks: list[dict] = []
    hour_of_day: int = 12
    day_of_week: int = 0
    extra: dict = {}
    profile_features: Optional[dict] = None


class GenerateRequest(BaseModel):
    """Request body for ``POST /generate``."""

    user_id: str
    context: PromptContext = PromptContext()
    n: int = 3
    prompt_version: Optional[str] = None  # None → server default (env DEFAULT_PROMPT_VERSION)
    # User-level features (#81 phase A). Accepted by the contract; not yet
    # injected into the prompt — that's a #84-style prompt-design decision.
    profile_features: Optional[dict] = None


class TipCandidate(BaseModel):
    """One tip candidate returned by ``/generate``."""

    id: str
    content: str
    source: str = "llm"
    rationale: Optional[str] = None


class GenerateResponse(BaseModel):
    """Response body for ``POST /generate``; token counts are summed over all attempts."""

    candidates: list[TipCandidate]
    model: str
    prompt_version: str
    prompt_tokens: int = 0
    completion_tokens: int = 0


# ── Multi-agent models ─────────────────────────────────────────────────────

class AgentComputeRequest(BaseModel):
    """Request body for ``POST /agents/{agent_id}/compute``."""

    user_id: str
    tasks: list[dict] = []
    profile: dict[str, Optional[float]] = {}
    feedback_history: list[dict] = []
    now_iso: Optional[str] = None  # ISO 8601; defaults to utcnow
    # Per-agent prefs from user_preferences (merged: user source overrides inferred).
    agent_prefs: dict = {}


class AgentComputeResponse(BaseModel):
    """Response body for ``POST /agents/{agent_id}/compute`` (copied from the agent output)."""

    user_id: str
    agent_id: str
    prompt_text: str
    signals_snapshot: dict
    computed_at: str
    expires_at: str
    agent_version: str


class AgentInferRequest(BaseModel):
    """Request body for ``POST /agents/{agent_id}/infer``."""

    user_id: str
    feedback_history: list[dict] = []  # [{action, dwell_ms, created_at}, …]
    task_completions: list[dict] = []  # [{project_id, completed_at, due_at}, …]


class AgentInferResponse(BaseModel):
    """Response body for ``POST /agents/{agent_id}/infer``."""

    user_id: str
    agent_id: str
    # {key: inferred_value} — caller persists to user_preferences with source='inferred'
    inferred_prefs: dict


class AgentOutputSnippet(BaseModel):
    """A pre-computed agent prompt snippet passed to ``/recommend``."""

    agent_id: str
    prompt_text: str


class RecommendRequest(BaseModel):
    """Request body for ``POST /recommend``."""

    user_id: str
    agent_outputs: list[AgentOutputSnippet] = []
    tasks: list[dict] = []
    hour_of_day: int = 12
    day_of_week: int = 0


class TipResult(BaseModel):
    """The single tip returned by ``/recommend``."""

    id: str
    content: str
    source: str = "llm"
    kind: str = "advice"
    rationale: Optional[str] = None


class RecommendResponse(BaseModel):
    """Response body for ``POST /recommend``; token counts are summed over all attempts."""

    tip: TipResult
    model: str
    prompt_tokens: int = 0
    completion_tokens: int = 0


# ── LLM helpers ────────────────────────────────────────────────────────────

def _strip_code_fence(raw: str) -> str:
    """Return *raw* without a surrounding markdown code fence or ``json`` tag.

    When the stripped text starts with a triple backtick, the text between the
    first and second fence marker is kept. A leading ``json`` language tag is
    then dropped. Text without a fence is returned stripped but otherwise
    unchanged.
    """
    text = raw.strip()
    if text.startswith("```"):
        # The text starts with a fence, so the split always yields at least
        # two parts: "" (before the fence) and the fenced content.
        text = text.split("```", 2)[1]
    return text.removeprefix("json")


def _parse_llm_json(raw: str) -> list[dict]:
    """Strip markdown fences and parse a JSON array of objects.

    A single JSON object is accepted and wrapped in a one-element list so a
    model that ignores the "array" instruction still yields one candidate.

    Raises:
        ValueError: if *raw* is not valid JSON (``json.JSONDecodeError`` is a
            ``ValueError``) or does not decode to an object / array of objects.
    """
    parsed = json.loads(_strip_code_fence(raw))
    if isinstance(parsed, dict):
        parsed = [parsed]
    if not isinstance(parsed, list) or not all(isinstance(entry, dict) for entry in parsed):
        raise ValueError("expected a JSON array of objects")
    return parsed


def _parse_llm_object(raw: str) -> dict:
    """Strip markdown fences and parse a single JSON object.

    If the model replied with an array, its first element is used.

    Raises:
        ValueError: if *raw* is not valid JSON, is an empty array, or the
            selected value is not a JSON object.
    """
    parsed = json.loads(_strip_code_fence(raw))
    if isinstance(parsed, list):
        if not parsed:
            raise ValueError("expected a JSON object, got an empty array")
        parsed = parsed[0]
    if not isinstance(parsed, dict):
        raise ValueError("expected a JSON object")
    return parsed


def _tip_fields(item: dict, default_id: str) -> dict:
    """Build the ``id`` / ``content`` / ``rationale`` kwargs for a tip model.

    Values come straight from LLM output, so they are coerced to ``str``:
    a numeric ``id`` or a null ``content`` would otherwise fail model
    validation and surface as an HTTP 500. A missing or null ``id`` falls
    back to *default_id*; a missing or null ``rationale`` stays ``None``.
    """
    raw_id = item.get("id")
    rationale = item.get("rationale")
    return {
        "id": default_id if raw_id is None else str(raw_id),
        "content": str(item.get("content") or ""),
        "rationale": None if rationale is None else str(rationale),
    }


async def _chat_json(
    messages: list[dict],
    retry_suffix: str,
    parse: Callable[[str], Any],
) -> tuple[Any, str, dict[str, int]]:
    """Call LiteLLM chat completions until *parse* accepts the reply.

    Makes up to ``1 + _MAX_GENERATE_RETRIES`` calls. After each reply that
    *parse* rejects, the reply and *retry_suffix* are appended to the
    conversation so the model can self-correct.

    Args:
        messages: Initial chat messages. The list is copied, not mutated.
        retry_suffix: User message appended after an unparseable reply.
        parse: Callable turning the raw reply text into the wanted value;
            it must raise ``ValueError`` for an unusable reply.

    Returns:
        ``(parsed_value, model_name, usage)`` where ``usage`` holds
        ``prompt_tokens`` and ``completion_tokens`` summed over all attempts.

    Raises:
        HTTPException: 502 if LiteLLM returns an error status, a malformed
            response envelope, or no parseable reply after all attempts;
            503 if LiteLLM cannot be reached.
    """
    conversation = list(messages)
    headers = {"Authorization": f"Bearer {LITELLM_MASTER_KEY}"}
    usage_total = {"prompt_tokens": 0, "completion_tokens": 0}
    model_used = _LLM_MODEL_ALIAS
    last_raw = ""
    last_parse_error = ""

    async with httpx.AsyncClient(timeout=_LLM_TIMEOUT_S) as client:
        for _ in range(1 + _MAX_GENERATE_RETRIES):
            payload = {
                "model": _LLM_MODEL_ALIAS,
                "messages": conversation,
                "temperature": _LLM_TEMPERATURE,
            }
            try:
                resp = await client.post(
                    f"{LITELLM_URL}/chat/completions", json=payload, headers=headers
                )
                resp.raise_for_status()
            except httpx.HTTPStatusError as exc:
                raise HTTPException(
                    status_code=502, detail=f"LiteLLM error: {exc.response.text}"
                ) from exc
            except httpx.RequestError as exc:
                raise HTTPException(
                    status_code=503, detail=f"LiteLLM unreachable: {exc}"
                ) from exc

            # A 2xx response with a broken envelope is an upstream fault, so
            # report it as 502 rather than letting it escape as a 500.
            try:
                data = resp.json()
                content = data["choices"][0]["message"]["content"]
            except (ValueError, KeyError, IndexError, TypeError) as exc:
                raise HTTPException(
                    status_code=502,
                    detail=f"LiteLLM returned an unexpected response: {exc}",
                ) from exc

            # "usage" and its counters may be absent or null.
            usage = data.get("usage") or {}
            usage_total["prompt_tokens"] += usage.get("prompt_tokens") or 0
            usage_total["completion_tokens"] += usage.get("completion_tokens") or 0
            model_used = data.get("model", _LLM_MODEL_ALIAS)

            # Null / non-text content is treated as an empty (unparseable) reply.
            last_raw = content if isinstance(content, str) else ""
            try:
                return parse(last_raw), model_used, usage_total
            except ValueError as exc:
                last_parse_error = str(exc)
                # Feed the bad reply back so the model can self-correct.
                conversation.append({"role": "assistant", "content": last_raw})
                conversation.append({"role": "user", "content": retry_suffix})

    raise HTTPException(
        status_code=502,
        detail=f"LLM returned invalid JSON after {_MAX_GENERATE_RETRIES} retries: "
        f"{last_parse_error}\n{last_raw[:200]}",
    )


def _resolve_now(now_iso: Optional[str]) -> datetime:
    """Return a timezone-aware datetime for *now_iso*, or the current UTC time.

    A trailing ``Z`` is rewritten to ``+00:00`` before parsing, and a parsed
    value without tzinfo is tagged as UTC.

    Raises:
        ValueError: if *now_iso* is not a valid ISO 8601 string.
    """
    if not now_iso:
        return datetime.now(timezone.utc)
    moment = datetime.fromisoformat(now_iso.replace("Z", "+00:00"))
    if moment.tzinfo is None:
        moment = moment.replace(tzinfo=timezone.utc)
    return moment


# ── Endpoints ──────────────────────────────────────────────────────────────

@app.get("/health")
def health():
    """Report liveness, registered agent ids and NATS consumer status."""
    return {
        "ok": True,
        "agents": [a.agent_id for a in all_agents()],
        "nats": {
            "enabled": bool(nats_consumer.NATS_URL),
            "consumers": nats_consumer.consumer_health,
        },
    }


@app.get("/agents/registry")
def agents_registry():
    """Manifest list for every registered agent (ADR-0014).

    Consumers: TS recommender (eligibility filter), admin UI (auto-rendered
    pref forms), inference framework (#111). Static at process boot.
    """
    return {"agents": [m.to_dict() for m in all_manifests()]}


@app.post("/agents/{agent_id}/compute", response_model=AgentComputeResponse)
async def compute_agent(agent_id: str, req: AgentComputeRequest) -> AgentComputeResponse:
    """Run a single sub-agent for a user and return its prompt snippet.

    Called by the precompute pipeline for each (user_id, agent_id) pair.
    The caller is responsible for persisting the result to agent_outputs
    via the TypeScript API callback.

    Raises:
        HTTPException: 404 for an unknown agent, 422 for an unparseable
            ``now_iso``, 500 if the agent's ``compute`` raises.
    """
    try:
        agent = get_agent(agent_id)
    except KeyError:
        raise HTTPException(status_code=404, detail=f"Unknown agent: {agent_id!r}") from None

    try:
        now = _resolve_now(req.now_iso)
    except ValueError as exc:
        # Bad client input: report it as such instead of an unhandled 500.
        raise HTTPException(
            status_code=422, detail=f"Invalid now_iso {req.now_iso!r}: {exc}"
        ) from exc

    inp = AgentInput(
        user_id=req.user_id,
        tasks=req.tasks,
        profile=req.profile,
        feedback_history=req.feedback_history,
        now=now,
        agent_prefs=req.agent_prefs,
    )
    try:
        output = agent.compute(inp)
    except Exception as exc:
        # Boundary handler: any agent failure is logged and mapped to a 500.
        log.error("agent_compute_failed", agent_id=agent_id, user_id=req.user_id, error=str(exc))
        raise HTTPException(status_code=500, detail=f"Agent compute failed: {exc}") from exc

    log.info("agent_computed", agent_id=agent_id, user_id=req.user_id, expires_at=output.expires_at)
    return AgentComputeResponse(
        user_id=output.user_id,
        agent_id=output.agent_id,
        prompt_text=output.prompt_text,
        signals_snapshot=output.signals_snapshot,
        computed_at=output.computed_at,
        expires_at=output.expires_at,
        agent_version=output.agent_version,
    )


@app.post("/agents/{agent_id}/infer", response_model=AgentInferResponse)
async def infer_agent(agent_id: str, req: AgentInferRequest) -> AgentInferResponse:
    """Run the inference framework for one agent and return inferred preference values.

    The caller (TS agent-outputs.ts) persists results to user_preferences with
    source='inferred', skipping keys where source='user' already exists.

    Returns an empty ``inferred_prefs`` without running inference when the
    agent's manifest declares no inferred params.

    Raises:
        HTTPException: 404 for an unknown agent.
    """
    try:
        manifest = get_manifest(agent_id)
    except KeyError:
        raise HTTPException(status_code=404, detail=f"Unknown agent: {agent_id!r}") from None

    if not manifest.inferred_params:
        return AgentInferResponse(user_id=req.user_id, agent_id=agent_id, inferred_prefs={})

    events = [
        FeedbackEvent(
            action=e.get("action", ""),
            dwell_ms=e.get("dwell_ms"),
            created_at=e.get("created_at", ""),
        )
        for e in req.feedback_history
    ]
    completions = [
        TaskCompletion(
            project_id=c.get("project_id"),
            completed_at=c.get("completed_at", ""),
            due_at=c.get("due_at", ""),
        )
        for c in req.task_completions
    ]
    history = UserHistory(user_id=req.user_id, events=events, task_completions=completions)

    started = time.monotonic()
    inferred = run_inference(manifest, history)
    latency_ms = round((time.monotonic() - started) * 1000, 1)

    log.info(
        "inference_run",
        agent_id=agent_id,
        user_id=req.user_id,
        n_params=len(inferred),
        history_len=len(events),
        latency_ms=latency_ms,
    )
    return AgentInferResponse(user_id=req.user_id, agent_id=agent_id, inferred_prefs=inferred)


@app.post("/recommend", response_model=RecommendResponse)
async def recommend(req: RecommendRequest) -> RecommendResponse:
    """Orchestrator: combine pre-computed agent outputs into one tip via LLM.

    Called in real time when a user requests a tip. agent_outputs should be
    the fresh rows from agent_outputs table (fetched by the TypeScript
    recommender before calling this endpoint). Falls back to raw task context
    if empty.

    Retries up to _MAX_GENERATE_RETRIES times on malformed JSON, appending a
    correction hint to the conversation so the model can self-correct.

    Raises:
        HTTPException: 502 / 503 as raised by the LiteLLM round-trip.
    """
    messages = build_orchestrator_messages(
        agent_outputs=[s.model_dump() for s in req.agent_outputs],
        tasks=req.tasks,
        hour_of_day=req.hour_of_day,
        day_of_week=req.day_of_week,
    )
    item, model_used, usage = await _chat_json(messages, _RETRY_SUFFIX_OBJ, _parse_llm_object)

    tip = TipResult(**_tip_fields(item, f"tip-{req.user_id[:8]}"))
    log.info(
        "recommend_served",
        user_id=req.user_id,
        agent_count=len(req.agent_outputs),
        tip_id=tip.id,
    )
    return RecommendResponse(
        tip=tip,
        model=model_used,
        prompt_tokens=usage["prompt_tokens"],
        completion_tokens=usage["completion_tokens"],
    )


@app.post("/generate", response_model=GenerateResponse)
async def generate(req: GenerateRequest) -> GenerateResponse:
    """Generate tip candidates via LiteLLM → tip-generator alias.

    Retries up to _MAX_GENERATE_RETRIES times on malformed JSON, appending a
    correction hint to the conversation so the model can self-correct.

    Raises:
        HTTPException: 422 for an unknown ``prompt_version``; 502 / 503 as
            raised by the LiteLLM round-trip.
    """
    try:
        prompt_template = get_prompt(req.prompt_version)
    except KeyError as exc:
        raise HTTPException(
            status_code=422, detail=f"Unknown prompt_version: {exc.args[0]}"
        ) from exc

    # NOTE(review): this always overwrites context.profile_features with the
    # request-level value, even when the latter is None — confirm that
    # dropping a context-supplied value is intended.
    ctx = req.context.model_copy(update={"profile_features": req.profile_features})
    messages: list[dict] = [
        {"role": "system", "content": prompt_template.system},
        {"role": "user", "content": prompt_template.build_user(ctx, req.n)},
    ]
    items, model_used, usage = await _chat_json(messages, _RETRY_SUFFIX, _parse_llm_json)

    candidates = [
        TipCandidate(**_tip_fields(item, f"tip-{i}"))
        for i, item in enumerate(items)
    ]
    return GenerateResponse(
        candidates=candidates,
        model=model_used,
        prompt_version=prompt_template.version,
        prompt_tokens=usage["prompt_tokens"],
        completion_tokens=usage["completion_tokens"],
    )