Files
oO/ml/serving/main.py
alvis 95e1b342b4 fix(serving): wire MLflow auth and Host header for container-to-container calls
- Pass MLFLOW_ADMIN_PASSWORD as fallback password credential
- Set host_header='localhost' to satisfy MLflow's --allowed-hosts check
  (MLflow rejects Host: mlflow but accepts Host: localhost)
- Default MLFLOW_TRACKING_URI to http://mlflow:5000 in compose so the
  env_file value is not silently overridden to empty

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-06 10:39:08 +00:00

553 lines
20 KiB
Python

"""
oO ML Serving — multi-agent orchestrator (ADR-0013).
Contract:
POST /agents/{agent_id}/compute — run a sub-agent, return its prompt snippet
POST /agents/{agent_id}/infer — run the inference framework for a user, return inferred prefs
POST /recommend — orchestrate agent snippets → one tip via LiteLLM
POST /generate — LLM tip candidates (legacy; kept for bench/eval)
GET /agents/registry — manifests of every registered agent (ADR-0014)
GET /health — { ok, agents: [...], nats: { enabled, consumers } }
"""
from __future__ import annotations
import json
import os
import sys
import time
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
import httpx
import sentry_sdk
import structlog
import structlog.contextvars
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel
from starlette.middleware.base import BaseHTTPMiddleware
import logging_config
import nats_consumer
from mlflow_client import MLflowClient
from prompts import get_prompt, build_orchestrator_messages
# ml.agents has to be importable no matter which directory the process starts from.
#   Docker (WORKDIR=/app/ml/serving, PYTHONPATH=/app): /app is already on sys.path.
#   Local dev (started from ml/serving/): the repo root is two directories above
#   the directory holding this file, so it is prepended here.
_repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), *[os.pardir] * 2))
if _repo_root not in sys.path:
    sys.path.insert(0, _repo_root)
from ml.agents.base import AgentInput # noqa: E402
from ml.agents.registry import get_agent, all_agents, all_manifests, get_manifest # noqa: E402
from ml.agents.inference import run_inference, FeedbackEvent, TaskCompletion, UserHistory # noqa: E402
logging_config.configure()
# Sentry error reporting is opt-in: it is initialised only when SENTRY_DSN is set.
if _SENTRY_DSN := os.getenv("SENTRY_DSN"):
    sentry_sdk.init(dsn=_SENTRY_DSN, environment=os.getenv("ENV", "development"))
log = structlog.get_logger()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: start ``nats_consumer`` before serving, stop it afterwards.

    ``nats_consumer.start`` receives ``STATE_DIR``. The ``try``/``finally`` ensures
    ``nats_consumer.stop()`` still runs when the server leaves the lifespan via an
    exception (e.g. one thrown into this generator at ``yield``), instead of
    skipping cleanup.
    """
    await nats_consumer.start(STATE_DIR)
    try:
        yield
    finally:
        await nats_consumer.stop()
# ASGI application; startup/shutdown work is delegated to `lifespan`.
app = FastAPI(lifespan=lifespan, title="oO ML Serving", version="1.0.0")
_TRACE_ID_HEX_DIGITS = frozenset("0123456789abcdef")


def _trace_id_from_traceparent(traceparent: str) -> Optional[str]:
    """Return the trace-id field of a W3C ``traceparent`` header value, or None if malformed."""
    parts = traceparent.strip().split("-")
    if len(parts) != 4:
        return None
    trace_id = parts[1]
    # W3C Trace Context: trace-id is exactly 32 lowercase hex digits, and the
    # all-zero value is explicitly invalid — such headers must be ignored.
    if len(trace_id) != 32 or not _TRACE_ID_HEX_DIGITS.issuperset(trace_id):
        return None
    if trace_id == "0" * 32:
        return None
    return trace_id


class _TracingMiddleware(BaseHTTPMiddleware):
    """Bind the caller's W3C trace-id into structlog's context vars for each request.

    Context vars are cleared first, so values bound earlier in this context do
    not carry over. A missing or malformed ``traceparent`` header leaves
    ``trace_id`` unbound rather than logging a bogus id.
    """

    async def dispatch(self, request: Request, call_next):
        structlog.contextvars.clear_contextvars()
        trace_id = _trace_id_from_traceparent(request.headers.get("traceparent", ""))
        if trace_id is not None:
            structlog.contextvars.bind_contextvars(trace_id=trace_id)
        return await call_next(request)
# Run _TracingMiddleware on every request so log lines carry the caller's trace_id
# parsed from the traceparent header.
app.add_middleware(_TracingMiddleware)
# LiteLLM endpoint (requests go to {LITELLM_URL}/chat/completions) and the bearer
# key sent with them.
# NOTE(review): "sk-oo-dev" is a development default — production deployments
# should set LITELLM_MASTER_KEY explicitly.
LITELLM_URL = os.environ.get("LITELLM_URL", "http://localhost:4000")
LITELLM_MASTER_KEY = os.environ.get("LITELLM_MASTER_KEY", "sk-oo-dev")
# Local state directory; handed to nats_consumer.start() during startup.
STATE_DIR = Path(os.environ.get("STATE_DIR", "/tmp/oo-serving-state"))
# ── MLflow tracing (optional) ───────────────────────────────────────────────
# Enabled only when MLFLOW_TRACKING_URI is set; otherwise _mlflow stays None and
# _mlflow_run() returns immediately. The client is called synchronously (no
# await); _mlflow_run() logs failures at WARNING instead of raising.
_MLFLOW_URI = os.getenv("MLFLOW_TRACKING_URI", "")
_mlflow: MLflowClient | None
if _MLFLOW_URI:
    _mlflow = MLflowClient(
        tracking_uri=_MLFLOW_URI,
        username=os.getenv("MLFLOW_TRACKING_USERNAME", "admin"),
        # MLFLOW_TRACKING_PASSWORD wins; MLFLOW_ADMIN_PASSWORD is the fallback.
        # NOTE(review): the final "password" default is a hard-coded credential,
        # acceptable for local dev only.
        password=os.getenv("MLFLOW_TRACKING_PASSWORD") or os.getenv("MLFLOW_ADMIN_PASSWORD", "password"),
        # MLflow's --allowed-hosts check rejects "Host: mlflow" but accepts
        # "Host: localhost", so container-to-container calls send the latter.
        host_header="localhost",
    )
else:
    _mlflow = None
# Experiment that every _mlflow_run() record is filed under.
_MLFLOW_EXP = "oO/serving"
def _mlflow_run(run_name: str, params: dict, metrics: dict, tags: dict) -> None:
    """Record one finished MLflow run under the ``_MLFLOW_EXP`` experiment.

    Args:
        run_name: display name of the run.
        params: logged via ``log_params``; values are stringified and truncated
            to 250 characters.
        metrics: logged as-is via ``log_metrics``.
        tags: despite the name, each entry is passed to ``log_text`` as
            ``(str(value), key)`` — presumably stored as a text artifact named
            after the key; confirm in mlflow_client.

    Best-effort: no-ops when MLflow is not configured, and any error is logged
    at WARNING (``mlflow_log_failed``) and swallowed. Once a run has been
    created it is always ended, even if a later logging call fails, so it is
    not left open.

    NOTE(review): these client calls are synchronous; when invoked from an
    async endpoint they run on the event loop thread and, if MLflowClient does
    network I/O (confirm), block other requests for the round-trips.
    """
    if _mlflow is None:
        return
    run_id = None
    try:
        exp_id = _mlflow.get_or_create_experiment(_MLFLOW_EXP)
        run_id = _mlflow.create_run(exp_id, run_name, tags={"source": "ml-serving"})
        _mlflow.log_params(run_id, {k: str(v)[:250] for k, v in params.items()})
        _mlflow.log_metrics(run_id, metrics)
        for k, v in tags.items():
            _mlflow.log_text(run_id, str(v), k)
    except Exception as exc:  # noqa: BLE001
        log.warning("mlflow_log_failed", error=str(exc))
    finally:
        # Close the run whether or not the logging calls above succeeded.
        if run_id is not None:
            try:
                _mlflow.end_run(run_id)
            except Exception as exc:  # noqa: BLE001
                log.warning("mlflow_log_failed", error=str(exc))
# Create the state directory (and any missing parents) at import time, before
# lifespan hands it to nats_consumer.start(); an existing directory is fine.
os.makedirs(STATE_DIR, exist_ok=True)
# ── API models ─────────────────────────────────────────────────────────────
class PromptContext(BaseModel):
    # Per-request context handed to a prompt template's build_user() by /generate.
    # Task dicts are passed through untyped; their keys are defined by the caller.
    tasks: list[dict] = []
    # NOTE(review): presumably an hour 0-23 — not validated here.
    hour_of_day: int = 12
    # NOTE(review): which weekday 0 denotes is not established in this file — confirm.
    day_of_week: int = 0
    # Free-form extras, presumably read by the prompt template.
    extra: dict = {}
    # Optional user-level features (see GenerateRequest.profile_features).
    profile_features: Optional[dict] = None
class GenerateRequest(BaseModel):
    # Request body for POST /generate (legacy candidate-generation endpoint).
    user_id: str
    context: PromptContext = PromptContext()
    # Number of candidates requested; forwarded to build_user(). The reply is
    # not truncated to n, so the model may return a different count.
    n: int = 3
    prompt_version: Optional[str] = None # None → server default (env DEFAULT_PROMPT_VERSION)
    # User-level features (#81 phase A). Accepted by the contract; not yet
    # injected into the prompt — that's a #84-style prompt-design decision.
    # NOTE(review): /generate passes this to build_user() on the PromptContext;
    # whether the template renders it is not visible here, so the note above
    # may be stale.
    profile_features: Optional[dict] = None
class TipCandidate(BaseModel):
    # One candidate tip in a /generate response, built from one element of the
    # LLM's JSON array.
    id: str
    content: str
    source: str = "llm"  # /generate never overrides this default
    rationale: Optional[str] = None
class GenerateResponse(BaseModel):
    # Response body for POST /generate.
    candidates: list[TipCandidate]
    # Model name reported by LiteLLM (falls back to the "tip-generator" alias).
    model: str
    # Version reported by the resolved prompt template.
    prompt_version: str
    # Token counts are summed over every LLM attempt, including JSON-retry attempts.
    prompt_tokens: int = 0
    completion_tokens: int = 0
# ── Multi-agent models ─────────────────────────────────────────────────────
class AgentComputeRequest(BaseModel):
    # Request body for POST /agents/{agent_id}/compute; fields are copied onto AgentInput.
    user_id: str
    tasks: list[dict] = []
    # NOTE(review): None values presumably mark missing profile features — confirm with the agents.
    profile: dict[str, Optional[float]] = {}
    feedback_history: list[dict] = []
    # ISO 8601 ("Z" suffix accepted); defaults to the current UTC time, and a
    # naive timestamp is treated as UTC.
    now_iso: Optional[str] = None # ISO 8601; defaults to utcnow
    # Per-agent prefs from user_preferences (merged: user source overrides inferred).
    agent_prefs: dict = {}
class AgentComputeResponse(BaseModel):
    # Response body for POST /agents/{agent_id}/compute; every field is copied
    # from the agent's output object.
    user_id: str
    agent_id: str
    prompt_text: str
    signals_snapshot: dict
    # NOTE(review): timestamp format is whatever the agent emits — presumably ISO 8601; confirm.
    computed_at: str
    expires_at: str
    agent_version: str
class AgentInferRequest(BaseModel):
    # Request body for POST /agents/{agent_id}/infer; each dict is converted to a
    # FeedbackEvent / TaskCompletion, with missing keys defaulted.
    user_id: str
    feedback_history: list[dict] = [] # [{action, dwell_ms, created_at}, …]
    task_completions: list[dict] = [] # [{project_id, completed_at, due_at}, …]
class AgentInferResponse(BaseModel):
    # Response body for POST /agents/{agent_id}/infer; inferred_prefs is empty
    # when the agent's manifest declares no inferred params.
    user_id: str
    agent_id: str
    # {key: inferred_value} — caller persists to user_preferences with source='inferred'
    inferred_prefs: dict
class AgentOutputSnippet(BaseModel):
    # One pre-computed agent snippet; /recommend forwards these (via model_dump)
    # to build_orchestrator_messages().
    agent_id: str
    prompt_text: str
class RecommendRequest(BaseModel):
    # Request body for POST /recommend; all fields except user_id are passed to
    # build_orchestrator_messages().
    user_id: str
    # Fresh agent_outputs rows; when empty the orchestrator relies on raw task context.
    agent_outputs: list[AgentOutputSnippet] = []
    tasks: list[dict] = []
    # NOTE(review): value ranges / weekday origin are not validated here — confirm with the caller.
    hour_of_day: int = 12
    day_of_week: int = 0
class TipResult(BaseModel):
    # The single tip returned by /recommend.
    id: str
    content: str
    source: str = "llm"
    # /recommend never sets kind, so its responses always carry "advice".
    kind: str = "advice"
    rationale: Optional[str] = None
class RecommendResponse(BaseModel):
    # Response body for POST /recommend.
    tip: TipResult
    # Model name reported by LiteLLM (falls back to the "tip-generator" alias).
    model: str
    # Token counts are summed over every LLM attempt, including JSON-retry attempts.
    prompt_tokens: int = 0
    completion_tokens: int = 0
# ── Endpoints ──────────────────────────────────────────────────────────────
@app.get("/health")
def health():
    """Health/status endpoint.

    Reports ``ok``, the ids of all registered agents, and NATS consumer status:
    ``enabled`` is true when ``nats_consumer.NATS_URL`` is set, and
    ``consumers`` is ``nats_consumer.consumer_health`` as-is.
    """
    registered = [agent.agent_id for agent in all_agents()]
    nats_status = {
        "enabled": bool(nats_consumer.NATS_URL),
        "consumers": nats_consumer.consumer_health,
    }
    return {"ok": True, "agents": registered, "nats": nats_status}
@app.get("/agents/registry")
def agents_registry():
    """List the manifest of every registered agent (ADR-0014).

    Read by the TS recommender (eligibility filter), the admin UI (auto-rendered
    pref forms) and the inference framework (#111). The registry is fixed at
    process boot.
    """
    manifests = all_manifests()
    serialized = [manifest.to_dict() for manifest in manifests]
    return {"agents": serialized}
# Correction hints appended as a user turn after an unparseable LLM reply:
# _RETRY_SUFFIX is used by /generate (expects an array), _RETRY_SUFFIX_OBJ by
# /recommend (expects an object).
_RETRY_SUFFIX, _RETRY_SUFFIX_OBJ = (
    "\n\nYour previous response was not valid JSON. "
    f"Reply ONLY with the JSON {shape} — no prose, no markdown fences."
    for shape in ("array", "object")
)
@app.post("/agents/{agent_id}/compute", response_model=AgentComputeResponse)
async def compute_agent(agent_id: str, req: AgentComputeRequest) -> AgentComputeResponse:
    """Run a single sub-agent for a user and return its prompt snippet.

    Called by the precompute pipeline for each (user_id, agent_id) pair.
    The caller is responsible for persisting the result to agent_outputs via the
    TypeScript API callback.

    Raises:
        HTTPException 404: ``agent_id`` is not registered.
        HTTPException 422: ``now_iso`` is not a parseable ISO 8601 timestamp.
        HTTPException 500: the agent's ``compute`` raised (logged as
            ``agent_compute_failed``).
    """
    try:
        agent = get_agent(agent_id)
    except KeyError:
        raise HTTPException(status_code=404, detail=f"Unknown agent: {agent_id!r}") from None

    if req.now_iso:
        try:
            # datetime.fromisoformat() only accepts a trailing "Z" from Python 3.11 on.
            now = datetime.fromisoformat(req.now_iso.replace("Z", "+00:00"))
        except ValueError as exc:
            # Client-supplied input: report it as a 422 instead of an unhandled 500.
            raise HTTPException(
                status_code=422, detail=f"Invalid now_iso: {req.now_iso!r}"
            ) from exc
    else:
        now = datetime.now(timezone.utc)
    if now.tzinfo is None:
        # Naive timestamps are interpreted as UTC.
        now = now.replace(tzinfo=timezone.utc)

    inp = AgentInput(
        user_id=req.user_id,
        tasks=req.tasks,
        profile=req.profile,
        feedback_history=req.feedback_history,
        now=now,
        agent_prefs=req.agent_prefs,
    )
    try:
        output = agent.compute(inp)
    except Exception as exc:  # noqa: BLE001 — request boundary: log and map to 500
        log.error("agent_compute_failed", agent_id=agent_id, user_id=req.user_id, error=str(exc))
        raise HTTPException(status_code=500, detail=f"Agent compute failed: {exc}") from exc

    log.info("agent_computed", agent_id=agent_id, user_id=req.user_id, expires_at=output.expires_at)
    _mlflow_run(
        run_name=f"compute/{agent_id}",
        params={"agent_id": agent_id, "user_id": req.user_id, "agent_version": output.agent_version},
        metrics={"task_count": len(req.tasks), "feedback_count": len(req.feedback_history)},
        tags={
            "prompt_text": output.prompt_text,
            # default=str: this is evaluated even when MLflow is disabled, and a
            # non-JSON value (e.g. a datetime) must not fail the request.
            "signals_snapshot": json.dumps(output.signals_snapshot, default=str),
        },
    )
    return AgentComputeResponse(
        user_id=output.user_id,
        agent_id=output.agent_id,
        prompt_text=output.prompt_text,
        signals_snapshot=output.signals_snapshot,
        computed_at=output.computed_at,
        expires_at=output.expires_at,
        agent_version=output.agent_version,
    )
@app.post("/agents/{agent_id}/infer", response_model=AgentInferResponse)
async def infer_agent(agent_id: str, req: AgentInferRequest) -> AgentInferResponse:
    """Run the inference framework for one agent and return inferred preference values.

    The caller (TS agent-outputs.ts) persists results to user_preferences
    with source='inferred', skipping keys where source='user' already exists.
    Agents whose manifest declares no ``inferred_params`` short-circuit to an
    empty result without running inference.

    Raises:
        HTTPException 404: ``agent_id`` is not registered.
        HTTPException 500: ``run_inference`` raised (logged as ``inference_failed``).
    """
    try:
        manifest = get_manifest(agent_id)
    except KeyError:
        raise HTTPException(status_code=404, detail=f"Unknown agent: {agent_id!r}") from None
    if not manifest.inferred_params:
        return AgentInferResponse(user_id=req.user_id, agent_id=agent_id, inferred_prefs={})

    events = [
        FeedbackEvent(
            action=e.get("action", ""),
            dwell_ms=e.get("dwell_ms"),
            created_at=e.get("created_at", ""),
        )
        for e in req.feedback_history
    ]
    completions = [
        TaskCompletion(
            project_id=c.get("project_id"),
            completed_at=c.get("completed_at", ""),
            due_at=c.get("due_at", ""),
        )
        for c in req.task_completions
    ]
    history = UserHistory(user_id=req.user_id, events=events, task_completions=completions)

    t0 = time.monotonic()
    try:
        inferred = run_inference(manifest, history)
    except Exception as exc:  # noqa: BLE001 — request boundary, same handling as compute_agent
        log.error("inference_failed", agent_id=agent_id, user_id=req.user_id, error=str(exc))
        raise HTTPException(status_code=500, detail=f"Inference failed: {exc}") from exc
    latency_ms = round((time.monotonic() - t0) * 1000, 1)

    log.info(
        "inference_run",
        agent_id=agent_id,
        user_id=req.user_id,
        n_params=len(inferred),
        history_len=len(events),
        latency_ms=latency_ms,
    )
    _mlflow_run(
        run_name=f"infer/{agent_id}",
        params={"agent_id": agent_id, "user_id": req.user_id},
        metrics={"latency_ms": latency_ms, "history_len": len(events), "n_params": len(inferred)},
        # default=str: evaluated even when MLflow is disabled, so a non-JSON
        # inferred value must not fail the request.
        tags={"inferred_prefs": json.dumps(inferred, default=str)},
    )
    return AgentInferResponse(user_id=req.user_id, agent_id=agent_id, inferred_prefs=inferred)
def _parse_llm_json_object(raw: str) -> dict:
    """Parse the orchestrator LLM's reply into a single JSON object.

    Strips surrounding whitespace, a markdown code fence and a leading ``json``
    language tag (any case), then decodes. A top-level JSON array is accepted
    and its first element is used.

    Raises:
        ValueError: the text is not valid JSON (``json.JSONDecodeError`` is a
            subclass), the array is empty, or the result is not an object.
    """
    text = raw.strip()
    if text.startswith("```"):
        # Keep what sits between the opening fence and the next one.
        text = text.split("```")[1]
    # A JSON document never starts with the letters "json", so dropping a
    # language tag here cannot corrupt a valid payload.
    if text[:4].lower() == "json":
        text = text[4:]
    parsed = json.loads(text)
    if isinstance(parsed, list):
        if not parsed:
            raise ValueError("expected a JSON object, got an empty array")
        parsed = parsed[0]
    if not isinstance(parsed, dict):
        raise ValueError(f"expected a JSON object, got {type(parsed).__name__}")
    return parsed


@app.post("/recommend", response_model=RecommendResponse)
async def recommend(req: RecommendRequest) -> RecommendResponse:
    """Orchestrator: combine pre-computed agent outputs into one tip via LLM.

    Called in real time when a user requests a tip. agent_outputs should be
    the fresh rows from agent_outputs table (fetched by the TypeScript recommender
    before calling this endpoint). Falls back to raw task context if empty.

    A reply that is not a JSON object is fed back to the model with a correction
    hint and retried up to ``_MAX_GENERATE_RETRIES`` times; token usage is
    summed across attempts.

    Raises:
        HTTPException 502: LiteLLM returned an error status or a malformed body,
            or the reply was still not a JSON object after all retries.
        HTTPException 503: LiteLLM could not be reached.
    """
    t0_recommend = time.monotonic()
    messages = build_orchestrator_messages(
        agent_outputs=[s.model_dump() for s in req.agent_outputs],
        tasks=req.tasks,
        hour_of_day=req.hour_of_day,
        day_of_week=req.day_of_week,
    )
    headers = {"Authorization": f"Bearer {LITELLM_MASTER_KEY}"}
    last_raw = ""
    last_parse_error = ""
    total_usage: dict = {"prompt_tokens": 0, "completion_tokens": 0}
    model_used = "tip-generator"
    item: dict = {}
    async with httpx.AsyncClient(timeout=30.0) as client:
        for _attempt in range(1 + _MAX_GENERATE_RETRIES):
            payload = {"model": "tip-generator", "messages": messages, "temperature": 0.7}
            try:
                resp = await client.post(
                    f"{LITELLM_URL}/chat/completions", json=payload, headers=headers
                )
                resp.raise_for_status()
            except httpx.HTTPStatusError as e:
                raise HTTPException(status_code=502, detail=f"LiteLLM error: {e.response.text}") from e
            except httpx.RequestError as e:
                raise HTTPException(status_code=503, detail=f"LiteLLM unreachable: {e}") from e
            try:
                data = resp.json()
                # OpenAI-style responses allow a null content; treat it as an empty
                # (hence unparseable, retried) reply rather than crashing.
                last_raw = data["choices"][0]["message"]["content"] or ""
            except (ValueError, KeyError, IndexError, TypeError) as e:
                raise HTTPException(
                    status_code=502, detail=f"LiteLLM returned a malformed response: {e!r}"
                ) from e
            # `or` guards against keys that are present but null.
            usage = data.get("usage") or {}
            total_usage["prompt_tokens"] += usage.get("prompt_tokens") or 0
            total_usage["completion_tokens"] += usage.get("completion_tokens") or 0
            model_used = data.get("model") or "tip-generator"
            try:
                item = _parse_llm_json_object(last_raw)
                break
            except ValueError as exc:
                last_parse_error = str(exc)
                # Feed the bad reply back so the model can self-correct.
                messages.append({"role": "assistant", "content": last_raw})
                messages.append({"role": "user", "content": _RETRY_SUFFIX_OBJ})
        else:
            raise HTTPException(
                status_code=502,
                detail=f"LLM returned invalid JSON after {_MAX_GENERATE_RETRIES} retries: "
                f"{last_parse_error}\n{last_raw[:200]}",
            )

    tip = TipResult(
        # `or` fallbacks cover keys present but null; str() accepts a numeric id.
        id=str(item.get("id") or f"tip-{req.user_id[:8]}"),
        content=item.get("content") or "",
        rationale=item.get("rationale"),
    )
    latency_ms_recommend = round((time.monotonic() - t0_recommend) * 1000, 1)
    log.info(
        "recommend_served",
        user_id=req.user_id,
        agent_count=len(req.agent_outputs),
        tip_id=tip.id,
        latency_ms=latency_ms_recommend,
    )
    _mlflow_run(
        run_name="recommend",
        params={
            "user_id": req.user_id,
            "agent_ids": ",".join(s.agent_id for s in req.agent_outputs),
            "model": model_used,
            "hour_of_day": req.hour_of_day,
            "day_of_week": req.day_of_week,
        },
        metrics={
            "prompt_tokens": total_usage["prompt_tokens"],
            "completion_tokens": total_usage["completion_tokens"],
            "agent_count": len(req.agent_outputs),
            "latency_ms": latency_ms_recommend,
        },
        tags={
            "prompt_messages": json.dumps(messages),
            "tip_content": tip.content,
            "tip_rationale": tip.rationale or "",
        },
    )
    return RecommendResponse(
        tip=tip,
        model=model_used,
        prompt_tokens=total_usage["prompt_tokens"],
        completion_tokens=total_usage["completion_tokens"],
    )
# Extra LLM attempts allowed after the first when the reply cannot be parsed;
# /recommend and /generate each make at most 1 + this many completion calls.
_MAX_GENERATE_RETRIES: int = 2
def _parse_llm_json(raw: str) -> list[dict]:
    """Strip markdown fences and parse a JSON array of objects from an LLM reply.

    Accepts a bare array, or one wrapped in a markdown code fence with an
    optional ``json`` language tag (any case).

    Raises:
        ValueError: the text is not valid JSON (``json.JSONDecodeError`` is a
            subclass), or it decodes to something other than a list of objects.
            Raising here lets callers' retry loops handle wrong-shaped replies
            instead of crashing on them later.
    """
    text = raw.strip()
    if text.startswith("```"):
        # Keep what sits between the opening fence and the next one.
        text = text.split("```")[1]
    # A JSON document never starts with the letters "json", so dropping a
    # language tag here cannot corrupt a valid payload.
    if text[:4].lower() == "json":
        text = text[4:]
    parsed = json.loads(text)
    if not isinstance(parsed, list):
        raise ValueError(f"expected a JSON array, got {type(parsed).__name__}")
    if not all(isinstance(item, dict) for item in parsed):
        raise ValueError("expected every array element to be a JSON object")
    return parsed
@app.post("/generate", response_model=GenerateResponse)
async def generate(req: GenerateRequest) -> GenerateResponse:
    """Generate tip candidates via LiteLLM → tip-generator alias.

    Retries up to ``_MAX_GENERATE_RETRIES`` times when the reply is not a JSON
    array of objects, appending a correction hint to the conversation so the
    model can self-correct. Token usage is summed across attempts.

    Raises:
        HTTPException 422: unknown ``prompt_version``.
        HTTPException 502: LiteLLM returned an error status or a malformed body,
            or the reply was still unparseable after all retries.
        HTTPException 503: LiteLLM could not be reached.
    """
    try:
        prompt_template = get_prompt(req.prompt_version)
    except KeyError as e:
        unknown = e.args[0] if e.args else req.prompt_version
        raise HTTPException(status_code=422, detail=f"Unknown prompt_version: {unknown}") from e

    # Prefer the top-level profile_features, but fall back to any value already
    # on the context rather than overwriting it with None.
    profile_features = (
        req.profile_features if req.profile_features is not None else req.context.profile_features
    )
    ctx = req.context.model_copy(update={"profile_features": profile_features})
    user_msg = prompt_template.build_user(ctx, req.n)
    messages: list[dict] = [
        {"role": "system", "content": prompt_template.system},
        {"role": "user", "content": user_msg},
    ]
    headers = {"Authorization": f"Bearer {LITELLM_MASTER_KEY}"}
    last_parse_error: str = ""
    last_raw: str = ""
    total_usage: dict = {"prompt_tokens": 0, "completion_tokens": 0}
    model_used = "tip-generator"
    items: list[dict] = []
    async with httpx.AsyncClient(timeout=30.0) as client:
        for _attempt in range(1 + _MAX_GENERATE_RETRIES):
            payload = {"model": "tip-generator", "messages": messages, "temperature": 0.7}
            try:
                resp = await client.post(
                    f"{LITELLM_URL}/chat/completions",
                    json=payload,
                    headers=headers,
                )
                resp.raise_for_status()
            except httpx.HTTPStatusError as e:
                raise HTTPException(status_code=502, detail=f"LiteLLM error: {e.response.text}") from e
            except httpx.RequestError as e:
                raise HTTPException(status_code=503, detail=f"LiteLLM unreachable: {e}") from e
            try:
                data = resp.json()
                # OpenAI-style responses allow a null content; treat it as an empty
                # (hence unparseable, retried) reply rather than crashing.
                last_raw = data["choices"][0]["message"]["content"] or ""
            except (ValueError, KeyError, IndexError, TypeError) as e:
                raise HTTPException(
                    status_code=502, detail=f"LiteLLM returned a malformed response: {e!r}"
                ) from e
            # `or` guards against keys that are present but null.
            usage = data.get("usage") or {}
            total_usage["prompt_tokens"] += usage.get("prompt_tokens") or 0
            total_usage["completion_tokens"] += usage.get("completion_tokens") or 0
            model_used = data.get("model") or "tip-generator"
            try:
                items = _parse_llm_json(last_raw)
                # Enforce the shape here too, so a dict or scalar reply is retried
                # instead of crashing when candidates are built below.
                if not isinstance(items, list) or not all(isinstance(it, dict) for it in items):
                    raise ValueError("expected a JSON array of objects")
                break
            except ValueError as e:  # json.JSONDecodeError is a ValueError subclass
                last_parse_error = str(e)
                # Feed the bad reply back so the model can self-correct
                messages.append({"role": "assistant", "content": last_raw})
                messages.append({"role": "user", "content": _RETRY_SUFFIX})
        else:
            raise HTTPException(
                status_code=502,
                detail=f"LLM returned invalid JSON after {_MAX_GENERATE_RETRIES} retries: "
                f"{last_parse_error}\n{last_raw[:200]}",
            )

    candidates = [
        TipCandidate(
            # `or` fallbacks cover keys present but null; str() accepts a numeric id.
            id=str(item.get("id") or f"tip-{i}"),
            content=item.get("content") or "",
            rationale=item.get("rationale"),
        )
        for i, item in enumerate(items)
    ]
    return GenerateResponse(
        candidates=candidates,
        model=model_used,
        prompt_version=prompt_template.version,
        prompt_tokens=total_usage["prompt_tokens"],
        completion_tokens=total_usage["completion_tokens"],
    )