Files
oO/ml/serving/main.py
alvis 161e654027 feat(serving): replace MLflow run logging with native trace spans
Convert ml-serving from isolated MLflow runs to nested traces using
mlflow.start_span_no_context(). The recommend endpoint now emits a full
span tree: recommend (CHAIN) → build_context (TOOL), agent:* (AGENT) ×N,
llm_orchestrator (LLM). Compute and infer endpoints each emit a single span.

Supporting changes:
- mlflow-skinny>=3.1.0 added to requirements
- MLflow configured with --serve-artifacts + mlflow-artifacts:/ default root
  for cross-container artifact proxy (spans now persist from ml-serving)
- --allowed-hosts extended to include mlflow:5000 (the SDK sends the port in the Host header)
- science_destiny slider wired through prompts.py and recommend endpoint
- Config page exposes science/destiny slider (0=data-driven, 100=intuitive)
- Tip page shows rationale inline on tap

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-11 08:26:05 +00:00

628 lines
23 KiB
Python

"""
oO ML Serving — multi-agent orchestrator (ADR-0013).
Contract:
POST /agents/{agent_id}/compute — run a sub-agent, return its prompt snippet
POST /agents/{agent_id}/infer — run the inference framework for a user, return inferred prefs
POST /recommend — orchestrate agent snippets → one tip via LiteLLM
POST /generate — LLM tip candidates (legacy; kept for bench/eval)
GET /agents/registry — manifest list for every registered agent (ADR-0014)
GET /health — { ok, agents: [...], nats: {...} }
"""
from __future__ import annotations
import json
import os
import sys
import time
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
import httpx
import sentry_sdk
import structlog
import structlog.contextvars
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel
from starlette.middleware.base import BaseHTTPMiddleware
import mlflow
from mlflow.entities import SpanType
import logging_config
import nats_consumer
from prompts import get_prompt, build_orchestrator_messages
# Make ``ml.agents`` importable no matter which directory the process starts in.
# Docker (WORKDIR=/app/ml/serving, PYTHONPATH=/app) already has /app on the path;
# in local dev (run from ml/serving/) the repo root is two directories up, so it
# is prepended to sys.path only when it is not already there.
_repo_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
)
if _repo_root not in sys.path:
    sys.path.insert(0, _repo_root)
from ml.agents.base import AgentInput # noqa: E402
from ml.agents.registry import get_agent, all_agents, all_manifests, get_manifest # noqa: E402
from ml.agents.inference import run_inference, FeedbackEvent, TaskCompletion, UserHistory # noqa: E402
# Configure logging first, before the module-level logger below is created.
logging_config.configure()

# Sentry error reporting is opt-in: it is initialised only when SENTRY_DSN is
# set, tagged with the ENV environment name (default "development").
_SENTRY_DSN = os.environ.get("SENTRY_DSN")
if _SENTRY_DSN:
    sentry_sdk.init(
        dsn=_SENTRY_DSN,
        environment=os.environ.get("ENV", "development"),
    )

log = structlog.get_logger()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start ``nats_consumer`` on application startup and stop it on shutdown.

    ``nats_consumer.start`` receives STATE_DIR. ``stop()`` runs in a
    ``finally`` block, so the consumer is torn down even when the lifespan is
    exited with an exception rather than a clean shutdown.
    """
    await nats_consumer.start(STATE_DIR)
    try:
        yield
    finally:
        await nats_consumer.stop()
# The ASGI application; ``lifespan`` manages the NATS consumer lifecycle.
app = FastAPI(
    title="oO ML Serving",
    version="1.0.0",
    lifespan=lifespan,
)
# Hex digits accepted in a traceparent trace-id. The W3C spec requires
# lowercase; uppercase is tolerated here rather than dropped.
_TRACE_ID_HEX = frozenset("0123456789abcdefABCDEF")


def _trace_id_from_traceparent(header: str) -> Optional[str]:
    """Return the trace-id field of a W3C ``traceparent`` header, or None.

    The header has four dash-separated fields
    (``version-traceid-parentid-flags``). The trace-id is accepted only when it
    is exactly 32 hex digits and not all zeros (the all-zero id is invalid per
    the spec). Anything else yields None, so malformed values never reach the
    log context.
    """
    parts = header.split("-")
    if len(parts) != 4:
        return None
    trace_id = parts[1]
    if len(trace_id) != 32 or not _TRACE_ID_HEX.issuperset(trace_id):
        return None
    if not trace_id.strip("0"):
        return None
    return trace_id


class _TracingMiddleware(BaseHTTPMiddleware):
    """Bind the caller's trace id (from ``traceparent``) into structlog context."""

    async def dispatch(self, request: Request, call_next):
        # Start each request with an empty structlog context before binding.
        structlog.contextvars.clear_contextvars()
        trace_id = _trace_id_from_traceparent(request.headers.get("traceparent", ""))
        if trace_id:
            structlog.contextvars.bind_contextvars(trace_id=trace_id)
        return await call_next(request)


app.add_middleware(_TracingMiddleware)
# LiteLLM proxy endpoint and bearer key (the key default looks like a
# development placeholder — presumably overridden in deployed environments).
LITELLM_URL = os.environ.get("LITELLM_URL", "http://localhost:4000")
LITELLM_MASTER_KEY = os.environ.get("LITELLM_MASTER_KEY", "sk-oo-dev")
# Directory for serving state; handed to nats_consumer.start() on startup.
STATE_DIR = Path(os.environ.get("STATE_DIR", "/tmp/oo-serving-state"))
# ── MLflow tracing (optional) ───────────────────────────────────────────────
# Set MLFLOW_TRACKING_URI to enable. Span creation/ending goes through
# _start_span/_end_span, which log failures at WARNING and never propagate
# them to the caller. If the initialisation below fails, _mlflow_exp_id stays
# None and tracing is disabled for the life of the process.
# The MLflow server's --allowed-hosts must list the host exactly as the SDK
# sends it in the Host header, i.e. including the port (e.g. "mlflow:5000"
# for the container DNS name), not just "mlflow".
_MLFLOW_URI = os.getenv("MLFLOW_TRACKING_URI", "")
_MLFLOW_EXP = "oO/serving"
_mlflow_exp_id: str | None = None
if _MLFLOW_URI:
    try:
        mlflow.set_tracking_uri(_MLFLOW_URI)
        _mlflow_exp_id = mlflow.set_experiment(_MLFLOW_EXP).experiment_id
    except Exception as _exc:  # noqa: BLE001 — tracing is optional; never block startup
        log.warning("mlflow_init_failed", error=str(_exc))
class _NoOpSpan:
    """Placeholder span used when MLflow is disabled or span creation failed.

    Every span method the module calls accepts arbitrary arguments and does
    nothing, so callers never need to check whether tracing is on.
    """

    def _discard(self, *args, **kwargs) -> None:
        """Accept any arguments and ignore them."""
        return None

    set_inputs = _discard
    set_outputs = _discard
    set_attribute = _discard
    set_attributes = _discard
    end = _discard


# Shared placeholder instance; also the default ``parent`` of _start_span.
_NOOP = _NoOpSpan()
def _start_span(name: str, span_type: str, *, parent=_NOOP, inputs=None):
    """Open an MLflow span that is detached from the active-span context.

    Returns the shared ``_NOOP`` placeholder instead of a span when tracing is
    off (``_mlflow_exp_id`` is None) or when MLflow raises while creating it.
    Such failures are logged as ``mlflow_span_start_failed`` and swallowed.

    A ``_NoOpSpan`` parent (the default) makes this a root span. Only root
    spans receive ``experiment_id``: passing it to child spans causes the SDK
    to fail with '_Span has no attribute _span'. Any other ``parent`` is
    forwarded as ``parent_span``.
    """
    if _mlflow_exp_id is None:
        return _NOOP

    span_kwargs: dict = {"span_type": span_type}
    if isinstance(parent, _NoOpSpan):
        span_kwargs["experiment_id"] = _mlflow_exp_id  # root span only
    else:
        span_kwargs["parent_span"] = parent
    if inputs is not None:
        span_kwargs["inputs"] = inputs

    try:
        return mlflow.start_span_no_context(name, **span_kwargs)
    except Exception as exc:  # noqa: BLE001
        log.warning("mlflow_span_start_failed", name=name, error=str(exc))
        return _NOOP
def _end_span(span, *, status: str = "OK", outputs=None, attributes: dict | None = None) -> None:
    """End *span* with *status* and *outputs*, setting *attributes* first; best-effort.

    ``_NoOpSpan`` placeholders are ignored. MLflow errors are logged at WARNING
    and never raised. Setting attributes and ending the span are guarded
    separately, so a failure while setting attributes no longer skips
    ``span.end()``.
    """
    if isinstance(span, _NoOpSpan):
        return
    if attributes:
        try:
            span.set_attributes(attributes)
        except Exception as exc:  # noqa: BLE001
            log.warning("mlflow_span_attributes_failed", error=str(exc))
    try:
        span.end(status=status, outputs=outputs)
    except Exception as exc:  # noqa: BLE001
        log.warning("mlflow_span_end_failed", error=str(exc))
# Ensure the state directory (and any missing parents) exists at import time;
# an already-existing directory is not an error.
STATE_DIR.mkdir(exist_ok=True, parents=True)
# ── API models ─────────────────────────────────────────────────────────────
# Prompt inputs for the legacy /generate endpoint; passed to
# prompt_template.build_user() together with GenerateRequest.n.
class PromptContext(BaseModel):
    tasks: list[dict] = []
    # NOTE(review): presumably an hour 0-23 and a weekday index; neither is
    # validated here — confirm the expected ranges with prompts.py.
    hour_of_day: int = 12
    day_of_week: int = 0
    extra: dict = {}
    # Overwritten in /generate with GenerateRequest.profile_features
    # (even when that value is None).
    profile_features: Optional[dict] = None


# Request body for POST /generate (legacy; kept for bench/eval).
class GenerateRequest(BaseModel):
    user_id: str
    context: PromptContext = PromptContext()
    # Passed to prompt_template.build_user(); presumably the number of tip
    # candidates to request. Not bounded here.
    n: int = 3
    prompt_version: Optional[str] = None # None → server default (env DEFAULT_PROMPT_VERSION)
    # User-level features (#81 phase A). /generate copies this onto
    # context.profile_features before calling prompt_template.build_user().
    # NOTE(review): an earlier note said these are not yet injected into the
    # prompt; whether a template actually renders them is a #84-style
    # prompt-design decision — confirm per prompt version.
    profile_features: Optional[dict] = None


# One tip candidate returned by /generate; a missing id falls back to
# "tip-<index>" there.
class TipCandidate(BaseModel):
    id: str
    content: str
    source: str = "llm"
    rationale: Optional[str] = None


# Response body for POST /generate. Token counts are summed over every LiteLLM
# call made for the request, including JSON-repair retries.
class GenerateResponse(BaseModel):
    candidates: list[TipCandidate]
    model: str
    prompt_version: str
    prompt_tokens: int = 0
    completion_tokens: int = 0
# ── Multi-agent models ─────────────────────────────────────────────────────
# Request body for POST /agents/{agent_id}/compute; the fields are copied into
# an AgentInput for the agent's compute().
class AgentComputeRequest(BaseModel):
    user_id: str
    tasks: list[dict] = []
    profile: dict[str, Optional[float]] = {}
    feedback_history: list[dict] = []
    # compute_agent rewrites a "Z" suffix to "+00:00" and treats naive
    # timestamps as UTC.
    now_iso: Optional[str] = None # ISO 8601; defaults to utcnow
    # Per-agent prefs from user_preferences (merged: user source overrides inferred).
    agent_prefs: dict = {}


# Response of POST /agents/{agent_id}/compute, copied field-for-field from the
# agent's output; computed_at / expires_at are the strings the agent produced.
class AgentComputeResponse(BaseModel):
    user_id: str
    agent_id: str
    prompt_text: str
    signals_snapshot: dict
    computed_at: str
    expires_at: str
    agent_version: str


# Request body for POST /agents/{agent_id}/infer; each dict is read with
# .get() for the keys listed below.
class AgentInferRequest(BaseModel):
    user_id: str
    feedback_history: list[dict] = [] # [{action, dwell_ms, created_at}, …]
    task_completions: list[dict] = [] # [{project_id, completed_at, due_at}, …]


# Response of POST /agents/{agent_id}/infer; inferred_prefs is empty when the
# agent's manifest declares no inferred_params.
class AgentInferResponse(BaseModel):
    user_id: str
    agent_id: str
    # {key: inferred_value} — caller persists to user_preferences with source='inferred'
    inferred_prefs: dict


# One pre-computed agent snippet; forwarded to build_orchestrator_messages()
# and traced as an agent:<id> span by /recommend.
class AgentOutputSnippet(BaseModel):
    agent_id: str
    prompt_text: str


# Request body for POST /recommend.
class RecommendRequest(BaseModel):
    user_id: str
    agent_outputs: list[AgentOutputSnippet] = []
    tasks: list[dict] = []
    hour_of_day: int = 12
    day_of_week: int = 0
    # Forwarded to build_orchestrator_messages(); the 0-100 range is not
    # validated here.
    science_destiny: int = 50 # 0=science (data-driven), 100=destiny (intuitive)


# The single tip returned by /recommend (which leaves source and kind at their
# defaults).
class TipResult(BaseModel):
    id: str
    content: str
    source: str = "llm"
    kind: str = "advice"
    rationale: Optional[str] = None


# Response of POST /recommend; token counts are summed over all LiteLLM calls
# for the request, including JSON-repair retries.
class RecommendResponse(BaseModel):
    tip: TipResult
    model: str
    prompt_tokens: int = 0
    completion_tokens: int = 0
# ── Endpoints ──────────────────────────────────────────────────────────────
@app.get("/health")
def health():
    """Report liveness, the registered agent ids and NATS consumer status."""
    agent_ids = [agent.agent_id for agent in all_agents()]
    nats_status = {
        "enabled": bool(nats_consumer.NATS_URL),
        "consumers": nats_consumer.consumer_health,
    }
    return {"ok": True, "agents": agent_ids, "nats": nats_status}
@app.get("/agents/registry")
def agents_registry():
    """Manifest list for every registered agent (ADR-0014).

    Consumers: TS recommender (eligibility filter), admin UI (auto-rendered
    pref forms), inference framework (#111). Static at process boot.
    """
    manifests = [manifest.to_dict() for manifest in all_manifests()]
    return {"agents": manifests}
def _json_retry_hint(shape: str) -> str:
    """Build the correction turn sent after an unparseable reply.

    *shape* names the expected top-level JSON value ("array" or "object").
    """
    return (
        "\n\nYour previous response was not valid JSON. "
        f"Reply ONLY with the JSON {shape} — no prose, no markdown fences."
    )


# /generate expects an array of candidates; /recommend expects one object.
_RETRY_SUFFIX = _json_retry_hint("array")
_RETRY_SUFFIX_OBJ = _json_retry_hint("object")
def _parse_now_iso(now_iso: Optional[str]) -> datetime:
    """Resolve AgentComputeRequest.now_iso to a timezone-aware datetime.

    Empty or None means "now" in UTC. A "Z" is rewritten to "+00:00" so the
    value parses on Python versions whose ``fromisoformat`` rejects "Z", and
    naive timestamps are assumed to be UTC.

    Raises:
        HTTPException: 422 when the value is not a valid ISO-8601 timestamp,
            rather than letting the ValueError surface as a 500.
    """
    if not now_iso:
        return datetime.now(timezone.utc)
    try:
        now = datetime.fromisoformat(now_iso.replace("Z", "+00:00"))
    except ValueError as exc:
        raise HTTPException(status_code=422, detail=f"Invalid now_iso: {now_iso!r}") from exc
    if now.tzinfo is None:
        now = now.replace(tzinfo=timezone.utc)
    return now


@app.post("/agents/{agent_id}/compute", response_model=AgentComputeResponse)
async def compute_agent(agent_id: str, req: AgentComputeRequest) -> AgentComputeResponse:
    """Run a single sub-agent for a user and return its prompt snippet.

    Called by the precompute pipeline for each (user_id, agent_id) pair.
    The caller is responsible for persisting the result to agent_outputs via the
    TypeScript API callback.

    Raises:
        HTTPException: 404 for an unknown agent, 422 for a malformed now_iso,
            500 when the agent's compute() raises.
    """
    try:
        agent = get_agent(agent_id)
    except KeyError:
        raise HTTPException(status_code=404, detail=f"Unknown agent: {agent_id!r}") from None
    inp = AgentInput(
        user_id=req.user_id,
        tasks=req.tasks,
        profile=req.profile,
        feedback_history=req.feedback_history,
        now=_parse_now_iso(req.now_iso),
        agent_prefs=req.agent_prefs,
    )
    # Open the span before compute() so its duration covers the agent's work
    # and a failing compute is recorded as an ERROR span.
    span = _start_span(
        f"compute:{agent_id}",
        SpanType.AGENT,
        inputs={"user_id": req.user_id, "agent_id": agent_id,
                "task_count": len(req.tasks), "feedback_count": len(req.feedback_history)},
    )
    try:
        output = agent.compute(inp)
    except Exception as exc:
        _end_span(span, status="ERROR")
        log.error("agent_compute_failed", agent_id=agent_id, user_id=req.user_id, error=str(exc))
        raise HTTPException(status_code=500, detail=f"Agent compute failed: {exc}") from exc
    log.info("agent_computed", agent_id=agent_id, user_id=req.user_id, expires_at=output.expires_at)
    _end_span(span,
              outputs={"prompt_text": output.prompt_text, "signals_snapshot": output.signals_snapshot},
              attributes={"agent_version": output.agent_version, "expires_at": output.expires_at})
    return AgentComputeResponse(
        user_id=output.user_id,
        agent_id=output.agent_id,
        prompt_text=output.prompt_text,
        signals_snapshot=output.signals_snapshot,
        computed_at=output.computed_at,
        expires_at=output.expires_at,
        agent_version=output.agent_version,
    )
@app.post("/agents/{agent_id}/infer", response_model=AgentInferResponse)
async def infer_agent(agent_id: str, req: AgentInferRequest) -> AgentInferResponse:
    """Run the inference framework for one agent and return inferred preference values.

    The caller (TS agent-outputs.ts) persists results to user_preferences
    with source='inferred', skipping keys where source='user' already exists.
    Agents whose manifest declares no inferred_params return an empty result
    without running inference or emitting a span.

    Raises:
        HTTPException: 404 for an unknown agent, 500 when run_inference raises.
    """
    try:
        manifest = get_manifest(agent_id)
    except KeyError:
        raise HTTPException(status_code=404, detail=f"Unknown agent: {agent_id!r}") from None
    if not manifest.inferred_params:
        return AgentInferResponse(user_id=req.user_id, agent_id=agent_id, inferred_prefs={})
    events = [
        FeedbackEvent(
            action=e.get("action", ""),
            dwell_ms=e.get("dwell_ms"),
            created_at=e.get("created_at", ""),
        )
        for e in req.feedback_history
    ]
    completions = [
        TaskCompletion(
            project_id=c.get("project_id"),
            completed_at=c.get("completed_at", ""),
            due_at=c.get("due_at", ""),
        )
        for c in req.task_completions
    ]
    history = UserHistory(user_id=req.user_id, events=events, task_completions=completions)
    # Open the span before inference so its duration reflects the actual work
    # and failures are traced as ERROR.
    span = _start_span(
        f"infer:{agent_id}",
        SpanType.CHAIN,
        inputs={"user_id": req.user_id, "agent_id": agent_id,
                "history_len": len(events), "completion_count": len(completions)},
    )
    t0 = time.monotonic()
    try:
        inferred = run_inference(manifest, history)
    except Exception as exc:
        _end_span(span, status="ERROR")
        log.error("inference_failed", agent_id=agent_id, user_id=req.user_id, error=str(exc))
        raise HTTPException(status_code=500, detail=f"Inference failed: {exc}") from exc
    latency_ms = round((time.monotonic() - t0) * 1000, 1)
    log.info(
        "inference_run",
        agent_id=agent_id,
        user_id=req.user_id,
        n_params=len(inferred),
        history_len=len(events),
        latency_ms=latency_ms,
    )
    _end_span(span,
              outputs={"inferred_prefs": inferred},
              attributes={"latency_ms": str(latency_ms), "n_params": str(len(inferred))})
    return AgentInferResponse(user_id=req.user_id, agent_id=agent_id, inferred_prefs=inferred)
def _recommend_build_messages(req: RecommendRequest, parent) -> list[dict]:
    """Build the orchestrator chat messages for *req* inside a ``build_context`` span.

    The span ends OK with the message count, or ERROR if
    build_orchestrator_messages raises (the exception is re-raised).
    """
    ctx_span = _start_span("build_context", SpanType.TOOL, parent=parent, inputs={
        "agent_count": len(req.agent_outputs),
        "task_count": len(req.tasks),
        "science_destiny": req.science_destiny,
    })
    try:
        messages = build_orchestrator_messages(
            agent_outputs=[s.model_dump() for s in req.agent_outputs],
            tasks=req.tasks,
            hour_of_day=req.hour_of_day,
            day_of_week=req.day_of_week,
            science_destiny=req.science_destiny,
        )
    except BaseException:
        _end_span(ctx_span, status="ERROR")
        raise
    _end_span(ctx_span, outputs={"message_count": len(messages)})
    return messages


def _recommend_pick_tip_object(raw: str) -> dict:
    """Extract the tip object from an orchestrator reply.

    Accepts a JSON object, or a JSON array whose first element is an object.
    Raises ValueError (json.JSONDecodeError included) for anything else, such
    as an empty array, a scalar, or an array of non-objects, so the caller
    retries instead of crashing on ``.get``.
    """
    parsed = _parse_llm_json(raw)
    item = parsed[0] if isinstance(parsed, list) and parsed else parsed
    if not isinstance(item, dict):
        raise ValueError(f"expected a JSON object, got {type(item).__name__}")
    return item


async def _recommend_call_llm(messages: list[dict], parent) -> tuple[dict, str, dict]:
    """Ask the tip-generator model for one tip object, retrying on unusable JSON.

    Owns the ``llm_orchestrator`` span (child of *parent*) and ends it on every
    exit path: OK with outputs and usage on success, ERROR otherwise.
    *messages* is extended in place with a correction turn after each failed
    parse.

    Returns:
        (tip object, model name reported by LiteLLM, summed token usage).

    Raises:
        HTTPException: 502 for LiteLLM HTTP errors, malformed LiteLLM bodies,
            or replies still unusable after _MAX_GENERATE_RETRIES retries;
            503 when LiteLLM is unreachable.
    """
    llm_span = _start_span("llm_orchestrator", SpanType.LLM, parent=parent, inputs={
        # Snapshot: retries append to `messages` after the span is opened.
        "messages": list(messages),
        "model": "tip-generator",
        "temperature": 0.7,
    })
    headers = {"Authorization": f"Bearer {LITELLM_MASTER_KEY}"}
    total_usage = {"prompt_tokens": 0, "completion_tokens": 0}
    model_used = "tip-generator"
    last_raw = ""
    last_parse_error = ""
    item = None
    attempts = 0
    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            for attempts in range(1, _MAX_GENERATE_RETRIES + 2):
                payload = {"model": "tip-generator", "messages": messages, "temperature": 0.7}
                try:
                    resp = await client.post(
                        f"{LITELLM_URL}/chat/completions", json=payload, headers=headers
                    )
                    resp.raise_for_status()
                except httpx.HTTPStatusError as e:
                    raise HTTPException(status_code=502, detail=f"LiteLLM error: {e.response.text}") from e
                except httpx.RequestError as e:
                    raise HTTPException(status_code=503, detail=f"LiteLLM unreachable: {e}") from e
                try:
                    data = resp.json()
                    content = data["choices"][0]["message"]["content"]
                except (ValueError, KeyError, IndexError, TypeError) as e:
                    raise HTTPException(
                        status_code=502, detail=f"LiteLLM returned a malformed response: {e!r}"
                    ) from e
                usage = data.get("usage") or {}
                total_usage["prompt_tokens"] += usage.get("prompt_tokens") or 0
                total_usage["completion_tokens"] += usage.get("completion_tokens") or 0
                model_used = data.get("model") or "tip-generator"
                # Null/non-string content is treated as an empty (unparseable) reply.
                last_raw = content if isinstance(content, str) else ""
                try:
                    item = _recommend_pick_tip_object(last_raw)
                except ValueError as exc:
                    last_parse_error = str(exc)
                    messages.append({"role": "assistant", "content": last_raw})
                    messages.append({"role": "user", "content": _RETRY_SUFFIX_OBJ})
                    continue
                break
        if item is None:
            raise HTTPException(
                status_code=502,
                detail=f"LLM returned invalid JSON after {_MAX_GENERATE_RETRIES} retries: "
                f"{last_parse_error}\n{last_raw[:200]}",
            )
    except BaseException:  # includes request cancellation
        _end_span(llm_span, status="ERROR")
        raise
    _end_span(
        llm_span,
        outputs={"content": item.get("content", ""), "rationale": item.get("rationale") or ""},
        attributes={
            "prompt_tokens": str(total_usage["prompt_tokens"]),
            "completion_tokens": str(total_usage["completion_tokens"]),
            "model": model_used,
            "attempts": str(attempts),
        },
    )
    return item, model_used, total_usage


@app.post("/recommend", response_model=RecommendResponse)
async def recommend(req: RecommendRequest) -> RecommendResponse:
    """Orchestrator: combine pre-computed agent outputs into one tip via LLM.

    Called in real time when a user requests a tip. agent_outputs should be
    the fresh rows from agent_outputs table (fetched by the TypeScript recommender
    before calling this endpoint). Falls back to raw task context if empty.

    Emits the span tree recommend (CHAIN) → build_context (TOOL),
    agent:<id> (AGENT) per snippet, llm_orchestrator (LLM). Each span is ended
    on every exit path and marked ERROR on failure.

    Raises:
        HTTPException: 502/503 for LiteLLM failures or unusable replies.
    """
    t0 = time.monotonic()
    root = _start_span("recommend", SpanType.CHAIN, inputs={
        "user_id": req.user_id,
        "agent_ids": [s.agent_id for s in req.agent_outputs],
        "hour_of_day": req.hour_of_day,
        "day_of_week": req.day_of_week,
        "science_destiny": req.science_destiny,
    })
    try:
        messages = _recommend_build_messages(req, root)
        # One span per pre-computed agent snippet, for visibility in the trace.
        for snippet in req.agent_outputs:
            a_span = _start_span(
                f"agent:{snippet.agent_id}", SpanType.AGENT, parent=root,
                inputs={"agent_id": snippet.agent_id},
            )
            _end_span(a_span, outputs={"prompt_text": snippet.prompt_text})
        item, model_used, total_usage = await _recommend_call_llm(messages, root)
        tip = TipResult(
            # `or` also replaces an explicit null/empty id from the model.
            id=item.get("id") or f"tip-{req.user_id[:8]}",
            content=item.get("content", ""),
            rationale=item.get("rationale"),
        )
        response = RecommendResponse(
            tip=tip,
            model=model_used,
            prompt_tokens=total_usage["prompt_tokens"],
            completion_tokens=total_usage["completion_tokens"],
        )
    except BaseException:  # includes HTTPException and request cancellation
        # Child spans were already ended by the helpers that opened them.
        _end_span(root, status="ERROR")
        raise
    latency_ms = round((time.monotonic() - t0) * 1000, 1)
    log.info("recommend_served", user_id=req.user_id, agent_count=len(req.agent_outputs), tip_id=tip.id)
    _end_span(root, outputs={"tip_id": tip.id, "content": tip.content, "rationale": tip.rationale or ""},
              attributes={"latency_ms": str(latency_ms), "agent_count": str(len(req.agent_outputs))})
    return response
# Extra LLM attempts allowed after the first when a reply cannot be parsed
# (so at most 1 + _MAX_GENERATE_RETRIES calls); used by /generate and /recommend.
_MAX_GENERATE_RETRIES: int = 2
def _parse_llm_json(raw: str) -> object:
    """Parse the JSON value in an LLM reply, tolerating a markdown code fence.

    Surrounding whitespace is ignored. If the reply starts with a ``` fence,
    only the fenced content is parsed, and an optional language tag on the
    fence line ("json", "JSON", " json", ...) is dropped.

    Returns:
        Whatever JSON value was parsed (array, object, scalar). The shape is
        NOT validated here; callers must check it.

    Raises:
        ValueError: the text is not valid JSON (``json.JSONDecodeError`` is a
            ValueError subclass).
    """
    text = raw.strip()
    if text.startswith("```"):
        parts = text.split("```")
        text = parts[1] if len(parts) > 1 else text
        # Drop a language tag regardless of case or leading spaces.
        body = text.lstrip()
        if body[:4].lower() == "json":
            text = body[4:]
    return json.loads(text)
def _generate_candidate_items(raw: str) -> list[dict]:
    """Parse a /generate reply into a list of candidate objects.

    Raises:
        ValueError: the reply is not JSON (json.JSONDecodeError), or it is not
            a JSON array whose elements are all objects. Either case triggers
            a retry instead of crashing on ``.get``.
    """
    parsed = _parse_llm_json(raw)
    if not isinstance(parsed, list) or not all(isinstance(item, dict) for item in parsed):
        raise ValueError(f"expected a JSON array of objects, got {type(parsed).__name__}")
    return parsed


@app.post("/generate", response_model=GenerateResponse)
async def generate(req: GenerateRequest) -> GenerateResponse:
    """Generate tip candidates via LiteLLM → tip-generator alias.

    Retries up to _MAX_GENERATE_RETRIES times when the reply is not a JSON
    array of objects, appending a correction hint to the conversation so the
    model can self-correct.

    Raises:
        HTTPException: 422 for an unknown prompt_version; 502 for LiteLLM HTTP
            errors, malformed LiteLLM bodies, or replies still unusable after
            all retries; 503 when LiteLLM is unreachable.
    """
    try:
        prompt_template = get_prompt(req.prompt_version)
    except KeyError as e:
        raise HTTPException(status_code=422, detail=f"Unknown prompt_version: {e.args[0]}") from None
    ctx = req.context.model_copy(update={"profile_features": req.profile_features})
    user_msg = prompt_template.build_user(ctx, req.n)
    messages: list[dict] = [
        {"role": "system", "content": prompt_template.system},
        {"role": "user", "content": user_msg},
    ]
    headers = {"Authorization": f"Bearer {LITELLM_MASTER_KEY}"}
    last_parse_error: str = ""
    last_raw: str = ""
    total_usage: dict = {"prompt_tokens": 0, "completion_tokens": 0}
    model_used = "tip-generator"
    async with httpx.AsyncClient(timeout=30.0) as client:
        for _ in range(1 + _MAX_GENERATE_RETRIES):
            payload = {"model": "tip-generator", "messages": messages, "temperature": 0.7}
            try:
                resp = await client.post(
                    f"{LITELLM_URL}/chat/completions",
                    json=payload,
                    headers=headers,
                )
                resp.raise_for_status()
            except httpx.HTTPStatusError as e:
                raise HTTPException(status_code=502, detail=f"LiteLLM error: {e.response.text}") from e
            except httpx.RequestError as e:
                raise HTTPException(status_code=503, detail=f"LiteLLM unreachable: {e}") from e
            try:
                data = resp.json()
                content = data["choices"][0]["message"]["content"]
            except (ValueError, KeyError, IndexError, TypeError) as e:
                raise HTTPException(
                    status_code=502, detail=f"LiteLLM returned a malformed response: {e!r}"
                ) from e
            usage = data.get("usage") or {}
            total_usage["prompt_tokens"] += usage.get("prompt_tokens") or 0
            total_usage["completion_tokens"] += usage.get("completion_tokens") or 0
            model_used = data.get("model") or "tip-generator"
            # Null/non-string content is treated as an empty (unparseable) reply.
            last_raw = content if isinstance(content, str) else ""
            try:
                items = _generate_candidate_items(last_raw)
                break
            except ValueError as e:
                last_parse_error = str(e)
                # Feed the bad reply back so the model can self-correct
                messages.append({"role": "assistant", "content": last_raw})
                messages.append({"role": "user", "content": _RETRY_SUFFIX})
        else:
            raise HTTPException(
                status_code=502,
                detail=f"LLM returned invalid JSON after {_MAX_GENERATE_RETRIES} retries: "
                f"{last_parse_error}\n{last_raw[:200]}",
            )
    candidates = [
        TipCandidate(
            # `or` also replaces an explicit null/empty id from the model.
            id=item.get("id") or f"tip-{i}",
            content=item.get("content", ""),
            rationale=item.get("rationale"),
        )
        for i, item in enumerate(items)
    ]
    return GenerateResponse(
        candidates=candidates,
        model=model_used,
        prompt_version=prompt_template.version,
        prompt_tokens=total_usage["prompt_tokens"],
        completion_tokens=total_usage["completion_tokens"],
    )