Logs one MLflow run per /recommend (params, token metrics, latency,
full prompt + tip as artifacts) and per /agents/{id}/compute and
/infer call (signals snapshot, inferred prefs, latency).
Tracing is a no-op when MLFLOW_TRACKING_URI is unset; ml-serving
starts and serves tips correctly without MLflow configured.
Refs #118 (M4: remove from production / move off critical path).
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
545 lines
19 KiB
Python
545 lines
19 KiB
Python
"""
|
|
oO ML Serving — multi-agent orchestrator (ADR-0013).
|
|
|
|
Contract:
|
|
POST /agents/{agent_id}/compute run a sub-agent, return prompt snippet
|
|
POST /agents/{agent_id}/infer run inference framework for a user, return inferred prefs
|
|
POST /recommend orchestrate agent snippets → one tip via LiteLLM
|
|
POST /generate LLM tip candidates (legacy; kept for bench/eval)
|
|
GET /health { ok, agents: [...] }
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
import sys
|
|
import time
|
|
from contextlib import asynccontextmanager
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import httpx
|
|
import sentry_sdk
|
|
import structlog
|
|
import structlog.contextvars
|
|
from fastapi import FastAPI, HTTPException, Request
|
|
from pydantic import BaseModel
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
|
|
import logging_config
|
|
import nats_consumer
|
|
from mlflow_client import MLflowClient
|
|
from prompts import get_prompt, build_orchestrator_messages
|
|
|
|
# Ensure the repository root is on sys.path so ``ml.agents`` resolves no matter
# which directory the process is started from.
#   * Docker (WORKDIR=/app/ml/serving, PYTHONPATH=/app): /app is already on the path.
#   * Local dev (run from ml/serving/): the repo root sits two levels up.
_repo_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
)
if _repo_root not in sys.path:
    sys.path.insert(0, _repo_root)
|
|
|
|
from ml.agents.base import AgentInput # noqa: E402
|
|
from ml.agents.registry import get_agent, all_agents, all_manifests, get_manifest # noqa: E402
|
|
from ml.agents.inference import run_inference, FeedbackEvent, TaskCompletion, UserHistory # noqa: E402
|
|
|
|
# Set up logging through the project-local logging_config module.
logging_config.configure()

# Sentry is opt-in: it is initialised only when SENTRY_DSN is set.
_SENTRY_DSN = os.getenv("SENTRY_DSN")
if _SENTRY_DSN:
    # ENV supplies the Sentry environment label and falls back to "development".
    sentry_sdk.init(dsn=_SENTRY_DSN, environment=os.getenv("ENV", "development"))

# Module-level structlog logger used by the handlers in this file.
log = structlog.get_logger()
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run the NATS consumers for the lifetime of the application.

    Starts ``nats_consumer`` with ``STATE_DIR`` before the app begins serving
    and stops it again when the app shuts down.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    await nats_consumer.start(STATE_DIR)
    try:
        yield
    finally:
        # Stop the consumers even when the context is left through an
        # exception or cancellation, so a started consumer is never leaked.
        await nats_consumer.stop()
|
|
|
|
|
|
# The ASGI application; ``lifespan`` is run around the app's lifetime.
app = FastAPI(title="oO ML Serving", version="1.0.0", lifespan=lifespan)
|
|
|
|
|
|
class _TracingMiddleware(BaseHTTPMiddleware):
    """Bind the trace id from the ``traceparent`` request header to the structlog context."""

    async def dispatch(self, request: Request, call_next):
        """Reset the log context, bind ``trace_id`` when the header is usable, then continue."""
        # Drop whatever was bound to structlog's context variables before this request.
        structlog.contextvars.clear_contextvars()

        fields = request.headers.get("traceparent", "").split("-")
        # Only a four-field header whose second field is 32 characters long is
        # accepted; a missing/empty header splits into a single field and is ignored.
        if len(fields) == 4 and len(fields[1]) == 32:
            structlog.contextvars.bind_contextvars(trace_id=fields[1])

        return await call_next(request)
|
|
|
|
|
|
# Install the middleware that binds the request's trace id to the log context.
app.add_middleware(_TracingMiddleware)

# LiteLLM endpoint and bearer credential used by /recommend and /generate.
# NOTE(review): the fallback "sk-oo-dev" looks like a development credential —
# confirm that deployed environments always set LITELLM_MASTER_KEY.
LITELLM_URL = os.getenv("LITELLM_URL", "http://localhost:4000")
LITELLM_MASTER_KEY = os.getenv("LITELLM_MASTER_KEY", "sk-oo-dev")

# Directory passed to nats_consumer.start() by the lifespan handler.
STATE_DIR = Path(os.getenv("STATE_DIR", "/tmp/oo-serving-state"))
|
|
|
|
# ── MLflow tracing (optional) ───────────────────────────────────────────────
# Set MLFLOW_TRACKING_URI to enable. Any error raised while logging is caught
# in _mlflow_run, logged at WARNING and never propagates to the caller.
# NOTE(review): _mlflow_run invokes the client inline from the request
# handlers; whether MLflowClient defers its network I/O is not visible from
# this file (see #118 about moving tracing off the critical path).

_MLFLOW_URI = os.getenv("MLFLOW_TRACKING_URI", "")
# None when tracing is disabled; _mlflow_run checks for this and returns early.
_mlflow: MLflowClient | None = MLflowClient(tracking_uri=_MLFLOW_URI) if _MLFLOW_URI else None
# Experiment name under which every run from this service is created.
_MLFLOW_EXP = "oO/serving"
|
|
|
|
|
|
def _mlflow_run(run_name: str, params: dict, metrics: dict, tags: dict) -> None:
    """Record one completed MLflow run; do nothing when MLflow is not configured.

    Args:
        run_name: Name given to the created run.
        params: Logged as run params; every value is stringified and cut to
            250 characters.
        metrics: Handed unchanged to ``log_metrics``.
        tags: Despite the name, each entry is written with ``log_text``
            (stringified value, key as the third argument) rather than set as
            a run tag. The run itself is tagged ``source=ml-serving``.

    Any exception raised along the way is reported as a ``mlflow_log_failed``
    warning and swallowed.
    """
    client = _mlflow
    if client is None:
        return
    try:
        experiment_id = client.get_or_create_experiment(_MLFLOW_EXP)
        run_id = client.create_run(experiment_id, run_name, tags={"source": "ml-serving"})

        truncated_params = {}
        for name, value in params.items():
            truncated_params[name] = str(value)[:250]
        client.log_params(run_id, truncated_params)

        client.log_metrics(run_id, metrics)

        for text_name, text_value in tags.items():
            client.log_text(run_id, str(text_value), text_name)

        # NOTE(review): when any call above raises, end_run is never reached
        # for the run that was already created — confirm that is acceptable.
        client.end_run(run_id)
    except Exception as err:  # noqa: BLE001
        log.warning("mlflow_log_failed", error=str(err))
|
|
# Create the state directory (and any missing parents) when the module is imported.
STATE_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
# ── API models ─────────────────────────────────────────────────────────────
|
|
|
|
class PromptContext(BaseModel):
    """Context object that /generate passes to the prompt template's ``build_user``."""

    tasks: list[dict] = []
    hour_of_day: int = 12
    day_of_week: int = 0
    extra: dict = {}
    # /generate overwrites this with the request-level ``profile_features``
    # before building the user message.
    profile_features: Optional[dict] = None
|
|
|
|
|
|
class GenerateRequest(BaseModel):
    """Request body of POST /generate."""

    user_id: str
    context: PromptContext = PromptContext()
    # Passed to the prompt template's build_user together with the context.
    n: int = 3
    prompt_version: Optional[str] = None  # None → server default (env DEFAULT_PROMPT_VERSION)
    # User-level features (#81 phase A). Accepted by the contract; not yet
    # injected into the prompt — that's a #84-style prompt-design decision.
    # NOTE(review): /generate does copy this onto the context handed to
    # build_user; whether any template reads it is not visible from this file.
    profile_features: Optional[dict] = None
|
|
|
|
|
|
class TipCandidate(BaseModel):
    """One tip candidate returned by /generate."""

    # /generate falls back to "tip-<index>" when the LLM item carries no id.
    id: str
    # /generate falls back to "" when the LLM item carries no content.
    content: str
    source: str = "llm"
    rationale: Optional[str] = None
|
|
|
|
|
|
class GenerateResponse(BaseModel):
    """Response body of POST /generate."""

    candidates: list[TipCandidate]
    # Model name reported by the last LiteLLM response ("tip-generator" if absent).
    model: str
    # Version attribute of the prompt template that was used.
    prompt_version: str
    # Token counts are summed over every attempt, including retries.
    prompt_tokens: int = 0
    completion_tokens: int = 0
|
|
|
|
|
|
# ── Multi-agent models ─────────────────────────────────────────────────────
|
|
|
|
class AgentComputeRequest(BaseModel):
    """Request body of POST /agents/{agent_id}/compute; forwarded to the agent as AgentInput."""

    user_id: str
    tasks: list[dict] = []
    profile: dict[str, Optional[float]] = {}
    feedback_history: list[dict] = []
    # A trailing "Z" is accepted; a value without an offset is treated as UTC.
    now_iso: Optional[str] = None  # ISO 8601; defaults to utcnow
    # Per-agent prefs from user_preferences (merged: user source overrides inferred).
    agent_prefs: dict = {}
|
|
|
|
|
|
class AgentComputeResponse(BaseModel):
    """Response body of POST /agents/{agent_id}/compute.

    Every field is copied from the same-named attribute of the agent's output.
    """

    user_id: str
    agent_id: str
    prompt_text: str
    signals_snapshot: dict
    computed_at: str
    expires_at: str
    agent_version: str
|
|
|
|
|
|
class AgentInferRequest(BaseModel):
    """Request body of POST /agents/{agent_id}/infer."""

    user_id: str
    # Each entry is converted to a FeedbackEvent; missing keys fall back to
    # "" (action, created_at) or None (dwell_ms).
    feedback_history: list[dict] = []  # [{action, dwell_ms, created_at}, …]
    # Each entry is converted to a TaskCompletion; missing keys fall back to
    # None (project_id) or "" (completed_at, due_at).
    task_completions: list[dict] = []  # [{project_id, completed_at, due_at}, …]
|
|
|
|
|
|
class AgentInferResponse(BaseModel):
    """Response body of POST /agents/{agent_id}/infer."""

    user_id: str
    agent_id: str
    # {key: inferred_value} — caller persists to user_preferences with source='inferred'
    # Empty when the agent's manifest declares no inferred params.
    inferred_prefs: dict
|
|
|
|
|
|
class AgentOutputSnippet(BaseModel):
    """One pre-computed agent output supplied to /recommend."""

    agent_id: str
    prompt_text: str
|
|
|
|
|
|
class RecommendRequest(BaseModel):
    """Request body of POST /recommend.

    Everything except ``user_id`` is forwarded to ``build_orchestrator_messages``.
    """

    user_id: str
    agent_outputs: list[AgentOutputSnippet] = []
    tasks: list[dict] = []
    hour_of_day: int = 12
    day_of_week: int = 0
|
|
|
|
|
|
class TipResult(BaseModel):
    """The single tip returned by /recommend."""

    # /recommend falls back to "tip-" plus the first 8 characters of user_id.
    id: str
    content: str
    # /recommend never sets source or kind, so the defaults below apply there.
    source: str = "llm"
    kind: str = "advice"
    rationale: Optional[str] = None
|
|
|
|
|
|
class RecommendResponse(BaseModel):
    """Response body of POST /recommend."""

    tip: TipResult
    # Model name reported by the last LiteLLM response ("tip-generator" if absent).
    model: str
    # Token counts are summed over every attempt, including retries.
    prompt_tokens: int = 0
    completion_tokens: int = 0
|
|
|
|
|
|
# ── Endpoints ──────────────────────────────────────────────────────────────
|
|
|
|
@app.get("/health")
|
|
def health():
|
|
return {
|
|
"ok": True,
|
|
"agents": [a.agent_id for a in all_agents()],
|
|
"nats": {
|
|
"enabled": bool(nats_consumer.NATS_URL),
|
|
"consumers": nats_consumer.consumer_health,
|
|
},
|
|
}
|
|
|
|
|
|
@app.get("/agents/registry")
|
|
def agents_registry():
|
|
"""Manifest list for every registered agent (ADR-0014).
|
|
|
|
Consumers: TS recommender (eligibility filter), admin UI (auto-rendered
|
|
pref forms), inference framework (#111). Static at process boot.
|
|
"""
|
|
return {"agents": [m.to_dict() for m in all_manifests()]}
|
|
|
|
|
|
# Follow-up user message that /generate appends after a reply that failed to
# parse; the hint asks the model for a bare JSON array.
_RETRY_SUFFIX = (
    "\n\nYour previous response was not valid JSON. "
    "Reply ONLY with the JSON array — no prose, no markdown fences."
)

# Same correction hint as used by /recommend; this one asks for a JSON object.
_RETRY_SUFFIX_OBJ = (
    "\n\nYour previous response was not valid JSON. "
    "Reply ONLY with the JSON object — no prose, no markdown fences."
)
|
|
|
|
|
|
@app.post("/agents/{agent_id}/compute", response_model=AgentComputeResponse)
|
|
async def compute_agent(agent_id: str, req: AgentComputeRequest) -> AgentComputeResponse:
|
|
"""Run a single sub-agent for a user and return its prompt snippet.
|
|
|
|
Called by the precompute pipeline for each (user_id, agent_id) pair.
|
|
The caller is responsible for persisting the result to agent_outputs via the
|
|
TypeScript API callback.
|
|
"""
|
|
try:
|
|
agent = get_agent(agent_id)
|
|
except KeyError:
|
|
raise HTTPException(status_code=404, detail=f"Unknown agent: {agent_id!r}")
|
|
|
|
now = (
|
|
datetime.fromisoformat(req.now_iso.replace("Z", "+00:00"))
|
|
if req.now_iso
|
|
else datetime.now(timezone.utc)
|
|
)
|
|
if now.tzinfo is None:
|
|
now = now.replace(tzinfo=timezone.utc)
|
|
|
|
inp = AgentInput(
|
|
user_id=req.user_id,
|
|
tasks=req.tasks,
|
|
profile=req.profile,
|
|
feedback_history=req.feedback_history,
|
|
now=now,
|
|
agent_prefs=req.agent_prefs,
|
|
)
|
|
try:
|
|
output = agent.compute(inp)
|
|
except Exception as exc:
|
|
log.error("agent_compute_failed", agent_id=agent_id, user_id=req.user_id, error=str(exc))
|
|
raise HTTPException(status_code=500, detail=f"Agent compute failed: {exc}")
|
|
|
|
log.info("agent_computed", agent_id=agent_id, user_id=req.user_id, expires_at=output.expires_at)
|
|
_mlflow_run(
|
|
run_name=f"compute/{agent_id}",
|
|
params={"agent_id": agent_id, "user_id": req.user_id, "agent_version": output.agent_version},
|
|
metrics={"task_count": len(req.tasks), "feedback_count": len(req.feedback_history)},
|
|
tags={"prompt_text": output.prompt_text, "signals_snapshot": json.dumps(output.signals_snapshot)},
|
|
)
|
|
return AgentComputeResponse(
|
|
user_id=output.user_id,
|
|
agent_id=output.agent_id,
|
|
prompt_text=output.prompt_text,
|
|
signals_snapshot=output.signals_snapshot,
|
|
computed_at=output.computed_at,
|
|
expires_at=output.expires_at,
|
|
agent_version=output.agent_version,
|
|
)
|
|
|
|
|
|
@app.post("/agents/{agent_id}/infer", response_model=AgentInferResponse)
|
|
async def infer_agent(agent_id: str, req: AgentInferRequest) -> AgentInferResponse:
|
|
"""Run the inference framework for one agent and return inferred preference values.
|
|
|
|
The caller (TS agent-outputs.ts) persists results to user_preferences
|
|
with source='inferred', skipping keys where source='user' already exists.
|
|
"""
|
|
try:
|
|
manifest = get_manifest(agent_id)
|
|
except KeyError:
|
|
raise HTTPException(status_code=404, detail=f"Unknown agent: {agent_id!r}")
|
|
|
|
if not manifest.inferred_params:
|
|
return AgentInferResponse(user_id=req.user_id, agent_id=agent_id, inferred_prefs={})
|
|
|
|
events = [
|
|
FeedbackEvent(
|
|
action=e.get("action", ""),
|
|
dwell_ms=e.get("dwell_ms"),
|
|
created_at=e.get("created_at", ""),
|
|
)
|
|
for e in req.feedback_history
|
|
]
|
|
completions = [
|
|
TaskCompletion(
|
|
project_id=c.get("project_id"),
|
|
completed_at=c.get("completed_at", ""),
|
|
due_at=c.get("due_at", ""),
|
|
)
|
|
for c in req.task_completions
|
|
]
|
|
history = UserHistory(user_id=req.user_id, events=events, task_completions=completions)
|
|
|
|
t0 = __import__("time").monotonic()
|
|
inferred = run_inference(manifest, history)
|
|
latency_ms = round((__import__("time").monotonic() - t0) * 1000, 1)
|
|
|
|
log.info(
|
|
"inference_run",
|
|
agent_id=agent_id,
|
|
user_id=req.user_id,
|
|
n_params=len(inferred),
|
|
history_len=len(events),
|
|
latency_ms=latency_ms,
|
|
)
|
|
_mlflow_run(
|
|
run_name=f"infer/{agent_id}",
|
|
params={"agent_id": agent_id, "user_id": req.user_id},
|
|
metrics={"latency_ms": latency_ms, "history_len": len(events), "n_params": len(inferred)},
|
|
tags={"inferred_prefs": json.dumps(inferred)},
|
|
)
|
|
return AgentInferResponse(user_id=req.user_id, agent_id=agent_id, inferred_prefs=inferred)
|
|
|
|
|
|
@app.post("/recommend", response_model=RecommendResponse)
|
|
async def recommend(req: RecommendRequest) -> RecommendResponse:
|
|
"""Orchestrator: combine pre-computed agent outputs into one tip via LLM.
|
|
|
|
Called in real time when a user requests a tip. agent_outputs should be
|
|
the fresh rows from agent_outputs table (fetched by the TypeScript recommender
|
|
before calling this endpoint). Falls back to raw task context if empty.
|
|
"""
|
|
t0_recommend = time.monotonic()
|
|
messages = build_orchestrator_messages(
|
|
agent_outputs=[s.model_dump() for s in req.agent_outputs],
|
|
tasks=req.tasks,
|
|
hour_of_day=req.hour_of_day,
|
|
day_of_week=req.day_of_week,
|
|
)
|
|
headers = {"Authorization": f"Bearer {LITELLM_MASTER_KEY}"}
|
|
last_raw = ""
|
|
last_parse_error = ""
|
|
total_usage: dict = {"prompt_tokens": 0, "completion_tokens": 0}
|
|
model_used = "tip-generator"
|
|
|
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
|
for _attempt in range(1 + _MAX_GENERATE_RETRIES):
|
|
payload = {"model": "tip-generator", "messages": messages, "temperature": 0.7}
|
|
try:
|
|
resp = await client.post(
|
|
f"{LITELLM_URL}/chat/completions", json=payload, headers=headers
|
|
)
|
|
resp.raise_for_status()
|
|
except httpx.HTTPStatusError as e:
|
|
raise HTTPException(status_code=502, detail=f"LiteLLM error: {e.response.text}")
|
|
except httpx.RequestError as e:
|
|
raise HTTPException(status_code=503, detail=f"LiteLLM unreachable: {e}")
|
|
|
|
data = resp.json()
|
|
usage = data.get("usage", {})
|
|
total_usage["prompt_tokens"] += usage.get("prompt_tokens", 0)
|
|
total_usage["completion_tokens"] += usage.get("completion_tokens", 0)
|
|
model_used = data.get("model", "tip-generator")
|
|
last_raw = data["choices"][0]["message"]["content"]
|
|
|
|
try:
|
|
text = last_raw.strip()
|
|
if text.startswith("```"):
|
|
parts = text.split("```")
|
|
text = parts[1] if len(parts) > 1 else text
|
|
if text.startswith("json"):
|
|
text = text[4:]
|
|
parsed = json.loads(text)
|
|
item: dict = parsed[0] if isinstance(parsed, list) else parsed
|
|
break
|
|
except (json.JSONDecodeError, ValueError, IndexError) as exc:
|
|
last_parse_error = str(exc)
|
|
messages.append({"role": "assistant", "content": last_raw})
|
|
messages.append({"role": "user", "content": _RETRY_SUFFIX_OBJ})
|
|
else:
|
|
raise HTTPException(
|
|
status_code=502,
|
|
detail=f"LLM returned invalid JSON after {_MAX_GENERATE_RETRIES} retries: "
|
|
f"{last_parse_error}\n{last_raw[:200]}",
|
|
)
|
|
|
|
tip = TipResult(
|
|
id=item.get("id", f"tip-{req.user_id[:8]}"),
|
|
content=item.get("content", ""),
|
|
rationale=item.get("rationale"),
|
|
)
|
|
latency_ms_recommend = round((time.monotonic() - t0_recommend) * 1000, 1)
|
|
log.info(
|
|
"recommend_served",
|
|
user_id=req.user_id,
|
|
agent_count=len(req.agent_outputs),
|
|
tip_id=tip.id,
|
|
)
|
|
_mlflow_run(
|
|
run_name="recommend",
|
|
params={
|
|
"user_id": req.user_id,
|
|
"agent_ids": ",".join(s.agent_id for s in req.agent_outputs),
|
|
"model": model_used,
|
|
"hour_of_day": req.hour_of_day,
|
|
"day_of_week": req.day_of_week,
|
|
},
|
|
metrics={
|
|
"prompt_tokens": total_usage["prompt_tokens"],
|
|
"completion_tokens": total_usage["completion_tokens"],
|
|
"agent_count": len(req.agent_outputs),
|
|
"latency_ms": latency_ms_recommend,
|
|
},
|
|
tags={
|
|
"prompt_messages": json.dumps(messages),
|
|
"tip_content": tip.content,
|
|
"tip_rationale": tip.rationale or "",
|
|
},
|
|
)
|
|
return RecommendResponse(
|
|
tip=tip,
|
|
model=model_used,
|
|
prompt_tokens=total_usage["prompt_tokens"],
|
|
completion_tokens=total_usage["completion_tokens"],
|
|
)
|
|
|
|
# Extra LLM attempts after the first one when a reply cannot be parsed; both
# /recommend and /generate loop ``1 + _MAX_GENERATE_RETRIES`` times.
_MAX_GENERATE_RETRIES = 2
|
|
|
|
|
|
def _parse_llm_json(raw: str) -> list[dict]:
    """Strip an optional markdown code fence from *raw* and parse the rest as JSON.

    Args:
        raw: The LLM reply, either bare JSON or JSON wrapped in a fenced code
            block with an optional ``json`` language tag (any letter case).

    Returns:
        Whatever ``json.loads`` produces. Callers expect an array of objects,
        but the shape is not validated here.

    Raises:
        ValueError: *raw* is not a string or does not contain valid JSON
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    if not isinstance(raw, str):
        # Keep the documented contract: callers only catch ValueError.
        raise ValueError(f"expected the LLM reply as str, got {type(raw).__name__}")
    text = raw.strip()
    if text.startswith("```"):
        parts = text.split("```")
        # parts[0] is the empty string before the opening fence; parts[1] is its content.
        text = parts[1] if len(parts) > 1 else text
        # Drop the fence's language tag; models emit "json", "JSON", "Json", ...
        if text[:4].lower() == "json":
            text = text[4:]
    return json.loads(text)
|
|
|
|
|
|
@app.post("/generate", response_model=GenerateResponse)
|
|
async def generate(req: GenerateRequest) -> GenerateResponse:
|
|
"""Generate tip candidates via LiteLLM → tip-generator alias.
|
|
|
|
Retries up to _MAX_GENERATE_RETRIES times on malformed JSON, appending
|
|
a correction hint to the conversation so the model can self-correct.
|
|
"""
|
|
try:
|
|
prompt_template = get_prompt(req.prompt_version)
|
|
except KeyError as e:
|
|
raise HTTPException(status_code=422, detail=f"Unknown prompt_version: {e.args[0]}")
|
|
ctx = req.context.model_copy(update={"profile_features": req.profile_features})
|
|
user_msg = prompt_template.build_user(ctx, req.n)
|
|
messages: list[dict] = [
|
|
{"role": "system", "content": prompt_template.system},
|
|
{"role": "user", "content": user_msg},
|
|
]
|
|
headers = {"Authorization": f"Bearer {LITELLM_MASTER_KEY}"}
|
|
last_parse_error: str = ""
|
|
last_raw: str = ""
|
|
total_usage: dict = {"prompt_tokens": 0, "completion_tokens": 0}
|
|
model_used = "tip-generator"
|
|
|
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
|
for attempt in range(1 + _MAX_GENERATE_RETRIES):
|
|
payload = {"model": "tip-generator", "messages": messages, "temperature": 0.7}
|
|
try:
|
|
resp = await client.post(
|
|
f"{LITELLM_URL}/chat/completions",
|
|
json=payload,
|
|
headers=headers,
|
|
)
|
|
resp.raise_for_status()
|
|
except httpx.HTTPStatusError as e:
|
|
raise HTTPException(status_code=502, detail=f"LiteLLM error: {e.response.text}")
|
|
except httpx.RequestError as e:
|
|
raise HTTPException(status_code=503, detail=f"LiteLLM unreachable: {e}")
|
|
|
|
data = resp.json()
|
|
usage = data.get("usage", {})
|
|
total_usage["prompt_tokens"] += usage.get("prompt_tokens", 0)
|
|
total_usage["completion_tokens"] += usage.get("completion_tokens", 0)
|
|
model_used = data.get("model", "tip-generator")
|
|
|
|
last_raw = data["choices"][0]["message"]["content"]
|
|
try:
|
|
items = _parse_llm_json(last_raw)
|
|
break
|
|
except (json.JSONDecodeError, ValueError) as e:
|
|
last_parse_error = str(e)
|
|
# Feed the bad reply back so the model can self-correct
|
|
messages.append({"role": "assistant", "content": last_raw})
|
|
messages.append({"role": "user", "content": _RETRY_SUFFIX})
|
|
else:
|
|
raise HTTPException(
|
|
status_code=502,
|
|
detail=f"LLM returned invalid JSON after {_MAX_GENERATE_RETRIES} retries: "
|
|
f"{last_parse_error}\n{last_raw[:200]}",
|
|
)
|
|
|
|
candidates = [
|
|
TipCandidate(
|
|
id=item.get("id", f"tip-{i}"),
|
|
content=item.get("content", ""),
|
|
rationale=item.get("rationale"),
|
|
)
|
|
for i, item in enumerate(items)
|
|
]
|
|
|
|
return GenerateResponse(
|
|
candidates=candidates,
|
|
model=model_used,
|
|
prompt_version=prompt_template.version,
|
|
prompt_tokens=total_usage["prompt_tokens"],
|
|
completion_tokens=total_usage["completion_tokens"],
|
|
)
|
|
|
|
|