""" oO ML Serving — multi-agent orchestrator (ADR-0013). Contract: POST /agents/{agent_id}/compute run a sub-agent, return prompt snippet POST /agents/{agent_id}/infer run inference framework for a user, return inferred prefs POST /recommend orchestrate agent snippets → one tip via LiteLLM POST /generate LLM tip candidates (legacy; kept for bench/eval) GET /health { ok, agents: [...] } """ from __future__ import annotations import json import os import sys import time from contextlib import asynccontextmanager from datetime import datetime, timezone from pathlib import Path from typing import Optional import httpx import sentry_sdk import structlog import structlog.contextvars from fastapi import FastAPI, HTTPException, Request from pydantic import BaseModel from starlette.middleware.base import BaseHTTPMiddleware import mlflow from mlflow.entities import SpanType import logging_config import nats_consumer from prompts import get_prompt, build_orchestrator_messages # Make ml.agents importable regardless of working directory. # In Docker (WORKDIR=/app/ml/serving, PYTHONPATH=/app): /app already on path. # In local dev (run from ml/serving/): repo root is two levels up. 
_repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
if _repo_root not in sys.path:
    sys.path.insert(0, _repo_root)

from ml.agents.base import AgentInput  # noqa: E402
from ml.agents.registry import get_agent, all_agents, all_manifests, get_manifest  # noqa: E402
from ml.agents.inference import run_inference, FeedbackEvent, TaskCompletion, UserHistory  # noqa: E402

logging_config.configure()

_SENTRY_DSN = os.getenv("SENTRY_DSN")
if _SENTRY_DSN:
    sentry_sdk.init(dsn=_SENTRY_DSN, environment=os.getenv("ENV", "development"))

log = structlog.get_logger()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start the NATS consumer for the app's lifetime and stop it on exit.

    ``STATE_DIR`` is defined further down in this module; it is resolved when
    the lifespan actually runs, not at definition time.
    """
    await nats_consumer.start(STATE_DIR)
    try:
        yield
    finally:
        # Always stop the consumer, even when the context is exited by an
        # exception or cancellation rather than a clean shutdown.
        await nats_consumer.stop()


app = FastAPI(title="oO ML Serving", version="1.0.0", lifespan=lifespan)


class _TracingMiddleware(BaseHTTPMiddleware):
    """Bind the trace id from an incoming ``traceparent`` header to structlog."""

    async def dispatch(self, request: Request, call_next):
        # Drop context bound by a previous request before binding new values.
        structlog.contextvars.clear_contextvars()
        traceparent = request.headers.get("traceparent", "")
        if traceparent:
            parts = traceparent.split("-")
            # Only accept a four-field header whose second field is 32 chars long.
            trace_id = parts[1] if len(parts) == 4 and len(parts[1]) == 32 else None
            if trace_id:
                structlog.contextvars.bind_contextvars(trace_id=trace_id)
        return await call_next(request)


app.add_middleware(_TracingMiddleware)

LITELLM_URL = os.getenv("LITELLM_URL", "http://localhost:4000")
LITELLM_MASTER_KEY = os.getenv("LITELLM_MASTER_KEY", "sk-oo-dev")
STATE_DIR = Path(os.getenv("STATE_DIR", "/tmp/oo-serving-state"))

# ── MLflow tracing (optional) ───────────────────────────────────────────────
# Set MLFLOW_TRACKING_URI to enable. Spans are fire-and-forget; errors are
# logged at WARNING and never propagate to the caller.
# MLflow --allowed-hosts must include "mlflow" (the container DNS name) so the
# SDK can reach the server from inside other containers.
_MLFLOW_URI = os.getenv("MLFLOW_TRACKING_URI", "")
_MLFLOW_EXP = "oO/serving"
_mlflow_exp_id: str | None = None
if _MLFLOW_URI:
    try:
        mlflow.set_tracking_uri(_MLFLOW_URI)
        _mlflow_exp_id = mlflow.set_experiment(_MLFLOW_EXP).experiment_id
    except Exception as _exc:
        log.warning("mlflow_init_failed", error=str(_exc))


class _NoOpSpan:
    """Returned when MLflow is disabled or span creation fails."""

    def set_inputs(self, *a, **k):
        pass

    def set_outputs(self, *a, **k):
        pass

    def set_attribute(self, *a, **k):
        pass

    def set_attributes(self, *a, **k):
        pass

    def end(self, *a, **k):
        pass


_NOOP = _NoOpSpan()


def _start_span(name: str, span_type: str, *, parent=_NOOP, inputs=None):
    """Start an MLflow span. Returns _NOOP on failure or when tracing is off.

    experiment_id is only passed for root spans (no parent) — passing it to
    child spans causes the SDK to fail with '_Span has no attribute _span'.
    """
    if _mlflow_exp_id is None:
        return _NOOP
    try:
        kw: dict = {"span_type": span_type}
        if isinstance(parent, _NoOpSpan):
            kw["experiment_id"] = _mlflow_exp_id  # root span only
        else:
            kw["parent_span"] = parent
        if inputs is not None:
            kw["inputs"] = inputs
        return mlflow.start_span_no_context(name, **kw)
    except Exception as exc:  # noqa: BLE001
        log.warning("mlflow_span_start_failed", name=name, error=str(exc))
        return _NOOP


def _end_span(span, *, status: str = "OK", outputs=None, attributes: dict | None = None) -> None:
    """End a span safely, ignoring _NoOpSpan and swallowing exceptions."""
    if isinstance(span, _NoOpSpan):
        return
    try:
        if attributes:
            span.set_attributes(attributes)
        span.end(status=status, outputs=outputs)
    except Exception as exc:  # noqa: BLE001
        log.warning("mlflow_span_end_failed", error=str(exc))


STATE_DIR.mkdir(parents=True, exist_ok=True)


# ── API models ─────────────────────────────────────────────────────────────

class PromptContext(BaseModel):
    tasks: list[dict] = []
    hour_of_day: int = 12
    day_of_week: int = 0
    extra: dict = {}
    profile_features: Optional[dict] = None


class GenerateRequest(BaseModel):
    user_id: str
    context: PromptContext = PromptContext()
    n: int = 3
    prompt_version: Optional[str] = None  # None → server default (env DEFAULT_PROMPT_VERSION)
    # User-level features (#81 phase A). Accepted by the contract; not yet
    # injected into the prompt — that's a #84-style prompt-design decision.
    profile_features: Optional[dict] = None


class TipCandidate(BaseModel):
    id: str
    content: str
    source: str = "llm"
    rationale: Optional[str] = None


class GenerateResponse(BaseModel):
    candidates: list[TipCandidate]
    model: str
    prompt_version: str
    prompt_tokens: int = 0
    completion_tokens: int = 0


# ── Multi-agent models ─────────────────────────────────────────────────────

class AgentComputeRequest(BaseModel):
    user_id: str
    tasks: list[dict] = []
    profile: dict[str, Optional[float]] = {}
    feedback_history: list[dict] = []
    now_iso: Optional[str] = None  # ISO 8601; defaults to utcnow
    # Per-agent prefs from user_preferences (merged: user source overrides inferred).
    agent_prefs: dict = {}


class AgentComputeResponse(BaseModel):
    user_id: str
    agent_id: str
    prompt_text: str
    signals_snapshot: dict
    computed_at: str
    expires_at: str
    agent_version: str


class AgentInferRequest(BaseModel):
    user_id: str
    feedback_history: list[dict] = []  # [{action, dwell_ms, created_at}, …]
    task_completions: list[dict] = []  # [{project_id, completed_at, due_at}, …]


class AgentInferResponse(BaseModel):
    user_id: str
    agent_id: str
    # {key: inferred_value} — caller persists to user_preferences with source='inferred'
    inferred_prefs: dict


class AgentOutputSnippet(BaseModel):
    agent_id: str
    prompt_text: str


class RecommendRequest(BaseModel):
    user_id: str
    agent_outputs: list[AgentOutputSnippet] = []
    tasks: list[dict] = []
    hour_of_day: int = 12
    day_of_week: int = 0
    science_destiny: int = 50  # 0=science (data-driven), 100=destiny (intuitive)
    recent_tip: Optional[str] = None  # content of last snoozed tip; LLM avoids repeating it


class TipResult(BaseModel):
    id: str
    content: str
    source: str = "llm"
    kind: str = "advice"
    rationale: Optional[str] = None


class RecommendResponse(BaseModel):
    tip: TipResult
    model: str
    prompt_tokens: int = 0
    completion_tokens: int = 0


# ── Endpoints ──────────────────────────────────────────────────────────────

@app.get("/health")
def health():
    """Report liveness, the registered agent ids and NATS consumer status."""
    return {
        "ok": True,
        "agents": [a.agent_id for a in all_agents()],
        "nats": {
            "enabled": bool(nats_consumer.NATS_URL),
            "consumers": nats_consumer.consumer_health,
        },
    }


@app.get("/agents/registry")
def agents_registry():
    """Manifest list for every registered agent (ADR-0014).

    Consumers: TS recommender (eligibility filter), admin UI (auto-rendered
    pref forms), inference framework (#111). Static at process boot.
    """
    return {"agents": [m.to_dict() for m in all_manifests()]}


_RETRY_SUFFIX = (
    "\n\nYour previous response was not valid JSON. "
    "Reply ONLY with the JSON array — no prose, no markdown fences."
)
_RETRY_SUFFIX_OBJ = (
    "\n\nYour previous response was not valid JSON. "
    "Reply ONLY with the JSON object — no prose, no markdown fences."
)

# Number of extra LLM calls allowed after the first reply fails to parse.
_MAX_GENERATE_RETRIES = 2


# ── LLM reply parsing / LiteLLM call helpers ───────────────────────────────

def _strip_code_fence(raw: str) -> str:
    """Return *raw* stripped of surrounding whitespace and a leading ``` fence."""
    text = raw.strip()
    if text.startswith("```"):
        parts = text.split("```")
        text = parts[1] if len(parts) > 1 else text
        if text.startswith("json"):
            text = text[4:]
    return text


def _parse_llm_json(raw: str) -> list[dict]:
    """Strip markdown fences and parse JSON array. Raises ValueError on failure.

    A reply that is valid JSON but not an array of objects also raises
    ValueError, so the caller's retry loop handles it instead of failing
    later on ``item.get``.
    """
    parsed = json.loads(_strip_code_fence(raw))
    if not isinstance(parsed, list) or not all(isinstance(entry, dict) for entry in parsed):
        raise ValueError("expected a JSON array of objects")
    return parsed


def _parse_llm_object(raw: str) -> dict:
    """Strip markdown fences and parse one JSON object. Raises ValueError on failure.

    A JSON array is accepted and its first element is used. An empty array or
    a non-object value raises ValueError.
    """
    parsed = json.loads(_strip_code_fence(raw))
    if isinstance(parsed, list):
        if not parsed:
            raise ValueError("expected a JSON object, got an empty array")
        parsed = parsed[0]
    if not isinstance(parsed, dict):
        raise ValueError("expected a JSON object")
    return parsed


async def _request_json_completion(messages: list[dict], *, retry_hint: str, parse):
    """Call the LiteLLM ``tip-generator`` alias until the reply parses.

    Args:
        messages: Chat messages to send. Mutated in place: after each
            unparseable reply, the reply and *retry_hint* are appended so the
            model can self-correct.
        retry_hint: User message appended after an unparseable reply.
        parse: Callable taking the raw reply text and returning the parsed
            value, raising ValueError when the reply is unusable.

    Returns:
        ``(parsed, model_used, total_usage, attempts)`` where *total_usage*
        sums ``prompt_tokens`` / ``completion_tokens`` across all attempts.

    Raises:
        HTTPException: 502 on a LiteLLM HTTP error, a malformed LiteLLM
            payload, or when every attempt returned unparseable content;
            503 when LiteLLM cannot be reached.
    """
    headers = {"Authorization": f"Bearer {LITELLM_MASTER_KEY}"}
    total_usage: dict = {"prompt_tokens": 0, "completion_tokens": 0}
    model_used = "tip-generator"
    last_raw = ""
    last_parse_error = ""

    async with httpx.AsyncClient(timeout=30.0) as client:
        for attempt in range(1 + _MAX_GENERATE_RETRIES):
            payload = {"model": "tip-generator", "messages": messages, "temperature": 0.7}
            try:
                resp = await client.post(
                    f"{LITELLM_URL}/chat/completions", json=payload, headers=headers
                )
                resp.raise_for_status()
            except httpx.HTTPStatusError as e:
                raise HTTPException(
                    status_code=502, detail=f"LiteLLM error: {e.response.text}"
                ) from e
            except httpx.RequestError as e:
                raise HTTPException(status_code=503, detail=f"LiteLLM unreachable: {e}") from e

            # A 2xx response with an unexpected shape is an upstream fault,
            # not an internal error: report it as 502 rather than crashing.
            try:
                data = resp.json()
                content = data["choices"][0]["message"]["content"]
                usage = data.get("usage") or {}
                total_usage["prompt_tokens"] += usage.get("prompt_tokens") or 0
                total_usage["completion_tokens"] += usage.get("completion_tokens") or 0
                model_used = data.get("model", "tip-generator")
            except (ValueError, KeyError, IndexError, TypeError, AttributeError) as e:
                raise HTTPException(
                    status_code=502, detail=f"LiteLLM returned a malformed response: {e!r}"
                ) from e
            if not isinstance(content, str):
                raise HTTPException(
                    status_code=502,
                    detail="LiteLLM returned a malformed response: message content is not text",
                )
            last_raw = content

            try:
                # json.JSONDecodeError is a ValueError subclass.
                return parse(last_raw), model_used, total_usage, attempt + 1
            except ValueError as e:
                last_parse_error = str(e)
                # Feed the bad reply back so the model can self-correct.
                messages.append({"role": "assistant", "content": last_raw})
                messages.append({"role": "user", "content": retry_hint})

    raise HTTPException(
        status_code=502,
        detail=f"LLM returned invalid JSON after {_MAX_GENERATE_RETRIES} retries: "
               f"{last_parse_error}\n{last_raw[:200]}",
    )


@app.post("/agents/{agent_id}/compute", response_model=AgentComputeResponse)
async def compute_agent(agent_id: str, req: AgentComputeRequest) -> AgentComputeResponse:
    """Run a single sub-agent for a user and return its prompt snippet.

    Called by the precompute pipeline for each (user_id, agent_id) pair.
    The caller is responsible for persisting the result to agent_outputs
    via the TypeScript API callback.

    Raises:
        HTTPException: 404 for an unknown agent, 422 for an unparseable
            ``now_iso``, 500 when the agent's ``compute`` raises.
    """
    try:
        agent = get_agent(agent_id)
    except KeyError:
        raise HTTPException(status_code=404, detail=f"Unknown agent: {agent_id!r}") from None

    if req.now_iso:
        try:
            now = datetime.fromisoformat(req.now_iso.replace("Z", "+00:00"))
        except ValueError as exc:
            # Bad client input, not a server fault.
            raise HTTPException(
                status_code=422, detail=f"Invalid now_iso: {req.now_iso!r}"
            ) from exc
    else:
        now = datetime.now(timezone.utc)
    if now.tzinfo is None:
        # Naive timestamps are interpreted as UTC.
        now = now.replace(tzinfo=timezone.utc)

    inp = AgentInput(
        user_id=req.user_id,
        tasks=req.tasks,
        profile=req.profile,
        feedback_history=req.feedback_history,
        now=now,
        agent_prefs=req.agent_prefs,
    )

    try:
        output = agent.compute(inp)
    except Exception as exc:
        log.error("agent_compute_failed", agent_id=agent_id, user_id=req.user_id, error=str(exc))
        raise HTTPException(status_code=500, detail=f"Agent compute failed: {exc}") from exc

    log.info("agent_computed", agent_id=agent_id, user_id=req.user_id, expires_at=output.expires_at)

    # The span is opened and closed after the fact; it records inputs/outputs only.
    span = _start_span(
        f"compute:{agent_id}", SpanType.AGENT,
        inputs={"user_id": req.user_id, "agent_id": agent_id,
                "task_count": len(req.tasks), "feedback_count": len(req.feedback_history)},
    )
    _end_span(span,
              outputs={"prompt_text": output.prompt_text,
                       "signals_snapshot": output.signals_snapshot},
              attributes={"agent_version": output.agent_version,
                          "expires_at": output.expires_at})

    return AgentComputeResponse(
        user_id=output.user_id,
        agent_id=output.agent_id,
        prompt_text=output.prompt_text,
        signals_snapshot=output.signals_snapshot,
        computed_at=output.computed_at,
        expires_at=output.expires_at,
        agent_version=output.agent_version,
    )


@app.post("/agents/{agent_id}/infer", response_model=AgentInferResponse)
async def infer_agent(agent_id: str, req: AgentInferRequest) -> AgentInferResponse:
    """Run the inference framework for one agent and return inferred preference values.

    The caller (TS agent-outputs.ts) persists results to user_preferences with
    source='inferred', skipping keys where source='user' already exists.

    Raises:
        HTTPException: 404 for an unknown agent.
    """
    try:
        manifest = get_manifest(agent_id)
    except KeyError:
        raise HTTPException(status_code=404, detail=f"Unknown agent: {agent_id!r}") from None

    if not manifest.inferred_params:
        # Nothing to infer for this agent.
        return AgentInferResponse(user_id=req.user_id, agent_id=agent_id, inferred_prefs={})

    events = [
        FeedbackEvent(
            action=e.get("action", ""),
            dwell_ms=e.get("dwell_ms"),
            created_at=e.get("created_at", ""),
        )
        for e in req.feedback_history
    ]
    completions = [
        TaskCompletion(
            project_id=c.get("project_id"),
            completed_at=c.get("completed_at", ""),
            due_at=c.get("due_at", ""),
        )
        for c in req.task_completions
    ]
    history = UserHistory(user_id=req.user_id, events=events, task_completions=completions)

    t0 = time.monotonic()
    inferred = run_inference(manifest, history)
    latency_ms = round((time.monotonic() - t0) * 1000, 1)

    log.info(
        "inference_run",
        agent_id=agent_id,
        user_id=req.user_id,
        n_params=len(inferred),
        history_len=len(events),
        latency_ms=latency_ms,
    )

    span = _start_span(
        f"infer:{agent_id}", SpanType.CHAIN,
        inputs={"user_id": req.user_id, "agent_id": agent_id,
                "history_len": len(events), "completion_count": len(completions)},
    )
    _end_span(span,
              outputs={"inferred_prefs": inferred},
              attributes={"latency_ms": str(latency_ms), "n_params": str(len(inferred))})

    return AgentInferResponse(user_id=req.user_id, agent_id=agent_id, inferred_prefs=inferred)


@app.post("/recommend", response_model=RecommendResponse)
async def recommend(req: RecommendRequest) -> RecommendResponse:
    """Orchestrator: combine pre-computed agent outputs into one tip via LLM.

    Called in real time when a user requests a tip. agent_outputs should be
    the fresh rows from agent_outputs table (fetched by the TypeScript
    recommender before calling this endpoint). Falls back to raw task context
    if empty.

    Raises:
        HTTPException: 502/503 as raised by ``_request_json_completion``.
    """
    t0 = time.monotonic()

    # ── root span ──────────────────────────────────────────────────────────
    root = _start_span("recommend", SpanType.CHAIN, inputs={
        "user_id": req.user_id,
        "agent_ids": [s.agent_id for s in req.agent_outputs],
        "hour_of_day": req.hour_of_day,
        "day_of_week": req.day_of_week,
        "science_destiny": req.science_destiny,
    })
    # Child span currently open; closed with ERROR if anything below raises.
    open_child = _NOOP

    try:
        # ── build_context span ─────────────────────────────────────────────
        open_child = _start_span("build_context", SpanType.TOOL, parent=root, inputs={
            "agent_count": len(req.agent_outputs),
            "task_count": len(req.tasks),
            "science_destiny": req.science_destiny,
        })
        messages = build_orchestrator_messages(
            agent_outputs=[s.model_dump() for s in req.agent_outputs],
            tasks=req.tasks,
            hour_of_day=req.hour_of_day,
            day_of_week=req.day_of_week,
            science_destiny=req.science_destiny,
            recent_tip=req.recent_tip,
        )
        _end_span(open_child, outputs={"message_count": len(messages)})
        open_child = _NOOP

        # ── one span per pre-computed agent snippet ────────────────────────
        for snippet in req.agent_outputs:
            a_span = _start_span(
                f"agent:{snippet.agent_id}", SpanType.AGENT, parent=root,
                inputs={"agent_id": snippet.agent_id},
            )
            _end_span(a_span, outputs={"prompt_text": snippet.prompt_text})

        # ── LLM orchestrator span (wraps retry loop) ───────────────────────
        open_child = _start_span("llm_orchestrator", SpanType.LLM, parent=root, inputs={
            "messages": messages,
            "model": "tip-generator",
            "temperature": 0.7,
        })
        item, model_used, total_usage, attempts = await _request_json_completion(
            messages, retry_hint=_RETRY_SUFFIX_OBJ, parse=_parse_llm_object,
        )

        tip = TipResult(
            id=item.get("id", f"tip-{req.user_id[:8]}"),
            content=item.get("content", ""),
            rationale=item.get("rationale"),
        )
        _end_span(open_child,
                  outputs={"content": tip.content, "rationale": tip.rationale or ""},
                  attributes={
                      "prompt_tokens": str(total_usage["prompt_tokens"]),
                      "completion_tokens": str(total_usage["completion_tokens"]),
                      "model": model_used,
                      "attempts": str(attempts),
                  })
        open_child = _NOOP

        latency_ms = round((time.monotonic() - t0) * 1000, 1)
        log.info("recommend_served", user_id=req.user_id,
                 agent_count=len(req.agent_outputs), tip_id=tip.id)

        # Build the response before closing the root span so nothing can
        # raise after the span has been ended with status OK.
        response = RecommendResponse(
            tip=tip,
            model=model_used,
            prompt_tokens=total_usage["prompt_tokens"],
            completion_tokens=total_usage["completion_tokens"],
        )
    except Exception:
        # Covers HTTPException from the LLM call as well as unexpected errors.
        _end_span(open_child, status="ERROR")
        _end_span(root, status="ERROR")
        raise

    _end_span(root,
              outputs={"tip_id": tip.id, "content": tip.content,
                       "rationale": tip.rationale or ""},
              attributes={"latency_ms": str(latency_ms),
                          "agent_count": str(len(req.agent_outputs))})
    return response


@app.post("/generate", response_model=GenerateResponse)
async def generate(req: GenerateRequest) -> GenerateResponse:
    """Generate tip candidates via LiteLLM → tip-generator alias.

    Retries up to _MAX_GENERATE_RETRIES times on malformed JSON, appending a
    correction hint to the conversation so the model can self-correct.

    Raises:
        HTTPException: 422 for an unknown ``prompt_version``; 502/503 as
            raised by ``_request_json_completion``.
    """
    try:
        prompt_template = get_prompt(req.prompt_version)
    except KeyError as e:
        raise HTTPException(
            status_code=422, detail=f"Unknown prompt_version: {e.args[0]}"
        ) from e

    # NOTE(review): this overwrites context.profile_features even when
    # req.profile_features is None — confirm that is intended.
    ctx = req.context.model_copy(update={"profile_features": req.profile_features})
    user_msg = prompt_template.build_user(ctx, req.n)

    messages: list[dict] = [
        {"role": "system", "content": prompt_template.system},
        {"role": "user", "content": user_msg},
    ]

    items, model_used, total_usage, _attempts = await _request_json_completion(
        messages, retry_hint=_RETRY_SUFFIX, parse=_parse_llm_json,
    )

    candidates = [
        TipCandidate(
            id=item.get("id", f"tip-{i}"),
            content=item.get("content", ""),
            rationale=item.get("rationale"),
        )
        for i, item in enumerate(items)
    ]

    return GenerateResponse(
        candidates=candidates,
        model=model_used,
        prompt_version=prompt_template.version,
        prompt_tokens=total_usage["prompt_tokens"],
        completion_tokens=total_usage["completion_tokens"],
    )