feat(serving): replace MLflow run logging with native trace spans

Convert ml-serving from isolated MLflow runs to nested traces using
mlflow.start_span_no_context(). The recommend endpoint now emits a full
span tree: recommend (CHAIN) → build_context (TOOL), agent:* (AGENT) ×N,
llm_orchestrator (LLM). Compute and infer endpoints each emit a single span.

Supporting changes:
- mlflow-skinny>=3.1.0 added to requirements
- MLflow configured with --serve-artifacts + mlflow-artifacts:/ default root
  for cross-container artifact proxy (spans now persist from ml-serving)
- --allowed-hosts extended to include mlflow:5000 (SDK includes port in Host)
- science_destiny slider wired through prompts.py and recommend endpoint
- Config page exposes science/destiny slider (0=data-driven, 100=intuitive)
- Tip page shows rationale inline on tap

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-11 08:26:05 +00:00
parent afacc34969
commit 161e654027
14 changed files with 419 additions and 141 deletions

View File

@@ -28,9 +28,11 @@ from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel
from starlette.middleware.base import BaseHTTPMiddleware
import mlflow
from mlflow.entities import SpanType
import logging_config
import nats_consumer
from mlflow_client import MLflowClient
from prompts import get_prompt, build_orchestrator_messages
# Make ml.agents importable regardless of working directory.
@@ -83,36 +85,69 @@ LITELLM_MASTER_KEY = os.getenv("LITELLM_MASTER_KEY", "sk-oo-dev")
STATE_DIR = Path(os.getenv("STATE_DIR", "/tmp/oo-serving-state"))
# ── MLflow tracing (optional) ───────────────────────────────────────────────
# Set MLFLOW_TRACKING_URI to enable. All calls are fire-and-forget; any error
# is logged at WARNING and never propagates to the caller.
# Set MLFLOW_TRACKING_URI to enable. Spans are fire-and-forget; errors are
# logged at WARNING and never propagate to the caller.
# MLflow --allowed-hosts must include "mlflow:5000" (the container DNS name plus
# port — the SDK sends the port in the Host header) so the SDK can reach the
# server from inside other containers.
_MLFLOW_URI = os.getenv("MLFLOW_TRACKING_URI", "")
_mlflow: MLflowClient | None = (
MLflowClient(
tracking_uri=_MLFLOW_URI,
username=os.getenv("MLFLOW_TRACKING_USERNAME", "admin"),
password=os.getenv("MLFLOW_TRACKING_PASSWORD") or os.getenv("MLFLOW_ADMIN_PASSWORD", "password"),
host_header="localhost",
)
if _MLFLOW_URI else None
)
_MLFLOW_EXP = "oO/serving"
_mlflow_exp_id: str | None = None
if _MLFLOW_URI:
try:
mlflow.set_tracking_uri(_MLFLOW_URI)
_mlflow_exp_id = mlflow.set_experiment(_MLFLOW_EXP).experiment_id
except Exception as _exc:
log.warning("mlflow_init_failed", error=str(_exc))
def _mlflow_run(run_name: str, params: dict, metrics: dict, tags: dict) -> None:
"""Create a finished MLflow run. Silently no-ops if MLflow is not configured."""
if _mlflow is None:
class _NoOpSpan:
"""Returned when MLflow is disabled or span creation fails."""
def set_inputs(self, *a, **k): pass
def set_outputs(self, *a, **k): pass
def set_attribute(self, *a, **k): pass
def set_attributes(self, *a, **k): pass
def end(self, *a, **k): pass
_NOOP = _NoOpSpan()
def _start_span(name: str, span_type: str, *, parent=_NOOP, inputs=None):
    """Open an MLflow span, degrading to _NOOP instead of raising.

    Returns _NOOP when tracing is off (no experiment id was resolved) or
    when span creation raises; the failure is logged at WARNING.

    A span whose parent is a _NoOpSpan (the default) is treated as a root
    span and is the only kind that receives experiment_id — passing it to
    child spans causes the SDK to fail with '_Span has no attribute _span'.
    Any other parent is forwarded as parent_span. inputs is forwarded only
    when it is not None.
    """
    if _mlflow_exp_id is None:
        return _NOOP
    try:
        # Root spans carry the experiment id; child spans carry their parent.
        lineage: dict = (
            {"experiment_id": _mlflow_exp_id}
            if isinstance(parent, _NoOpSpan)
            else {"parent_span": parent}
        )
        payload: dict = {} if inputs is None else {"inputs": inputs}
        return mlflow.start_span_no_context(
            name, span_type=span_type, **lineage, **payload
        )
    except Exception as exc:  # noqa: BLE001
        log.warning("mlflow_span_start_failed", name=name, error=str(exc))
        return _NOOP
def _end_span(span, *, status: str = "OK", outputs=None, attributes: dict | None = None) -> None:
"""End a span safely, ignoring _NoOpSpan and swallowing exceptions."""
if isinstance(span, _NoOpSpan):
return
try:
exp_id = _mlflow.get_or_create_experiment(_MLFLOW_EXP)
run_id = _mlflow.create_run(exp_id, run_name, tags={"source": "ml-serving"})
_mlflow.log_params(run_id, {k: str(v)[:250] for k, v in params.items()})
_mlflow.log_metrics(run_id, metrics)
for k, v in tags.items():
_mlflow.log_text(run_id, str(v), k)
_mlflow.end_run(run_id)
if attributes:
span.set_attributes(attributes)
span.end(status=status, outputs=outputs)
except Exception as exc: # noqa: BLE001
log.warning("mlflow_log_failed", error=str(exc))
log.warning("mlflow_span_end_failed", error=str(exc))
STATE_DIR.mkdir(parents=True, exist_ok=True)
@@ -197,6 +232,7 @@ class RecommendRequest(BaseModel):
tasks: list[dict] = []
hour_of_day: int = 12
day_of_week: int = 0
science_destiny: int = 50 # 0=science (data-driven), 100=destiny (intuitive)
class TipResult(BaseModel):
@@ -285,12 +321,15 @@ async def compute_agent(agent_id: str, req: AgentComputeRequest) -> AgentCompute
raise HTTPException(status_code=500, detail=f"Agent compute failed: {exc}")
log.info("agent_computed", agent_id=agent_id, user_id=req.user_id, expires_at=output.expires_at)
_mlflow_run(
run_name=f"compute/{agent_id}",
params={"agent_id": agent_id, "user_id": req.user_id, "agent_version": output.agent_version},
metrics={"task_count": len(req.tasks), "feedback_count": len(req.feedback_history)},
tags={"prompt_text": output.prompt_text, "signals_snapshot": json.dumps(output.signals_snapshot)},
span = _start_span(
f"compute:{agent_id}",
SpanType.AGENT,
inputs={"user_id": req.user_id, "agent_id": agent_id,
"task_count": len(req.tasks), "feedback_count": len(req.feedback_history)},
)
_end_span(span,
outputs={"prompt_text": output.prompt_text, "signals_snapshot": output.signals_snapshot},
attributes={"agent_version": output.agent_version, "expires_at": output.expires_at})
return AgentComputeResponse(
user_id=output.user_id,
agent_id=output.agent_id,
@@ -347,12 +386,15 @@ async def infer_agent(agent_id: str, req: AgentInferRequest) -> AgentInferRespon
history_len=len(events),
latency_ms=latency_ms,
)
_mlflow_run(
run_name=f"infer/{agent_id}",
params={"agent_id": agent_id, "user_id": req.user_id},
metrics={"latency_ms": latency_ms, "history_len": len(events), "n_params": len(inferred)},
tags={"inferred_prefs": json.dumps(inferred)},
span = _start_span(
f"infer:{agent_id}",
SpanType.CHAIN,
inputs={"user_id": req.user_id, "agent_id": agent_id,
"history_len": len(events), "completion_count": len(completions)},
)
_end_span(span,
outputs={"inferred_prefs": inferred},
attributes={"latency_ms": str(latency_ms), "n_params": str(len(inferred))})
return AgentInferResponse(user_id=req.user_id, agent_id=agent_id, inferred_prefs=inferred)
@@ -364,99 +406,132 @@ async def recommend(req: RecommendRequest) -> RecommendResponse:
the fresh rows from agent_outputs table (fetched by the TypeScript recommender
before calling this endpoint). Falls back to raw task context if empty.
"""
t0_recommend = time.monotonic()
messages = build_orchestrator_messages(
agent_outputs=[s.model_dump() for s in req.agent_outputs],
tasks=req.tasks,
hour_of_day=req.hour_of_day,
day_of_week=req.day_of_week,
)
headers = {"Authorization": f"Bearer {LITELLM_MASTER_KEY}"}
last_raw = ""
last_parse_error = ""
total_usage: dict = {"prompt_tokens": 0, "completion_tokens": 0}
model_used = "tip-generator"
t0 = time.monotonic()
async with httpx.AsyncClient(timeout=30.0) as client:
for _attempt in range(1 + _MAX_GENERATE_RETRIES):
payload = {"model": "tip-generator", "messages": messages, "temperature": 0.7}
try:
resp = await client.post(
f"{LITELLM_URL}/chat/completions", json=payload, headers=headers
)
resp.raise_for_status()
except httpx.HTTPStatusError as e:
raise HTTPException(status_code=502, detail=f"LiteLLM error: {e.response.text}")
except httpx.RequestError as e:
raise HTTPException(status_code=503, detail=f"LiteLLM unreachable: {e}")
# ── root span ──────────────────────────────────────────────────────────
root = _start_span("recommend", SpanType.CHAIN, inputs={
"user_id": req.user_id,
"agent_ids": [s.agent_id for s in req.agent_outputs],
"hour_of_day": req.hour_of_day,
"day_of_week": req.day_of_week,
"science_destiny": req.science_destiny,
})
data = resp.json()
usage = data.get("usage", {})
total_usage["prompt_tokens"] += usage.get("prompt_tokens", 0)
total_usage["completion_tokens"] += usage.get("completion_tokens", 0)
model_used = data.get("model", "tip-generator")
last_raw = data["choices"][0]["message"]["content"]
try:
text = last_raw.strip()
if text.startswith("```"):
parts = text.split("```")
text = parts[1] if len(parts) > 1 else text
if text.startswith("json"):
text = text[4:]
parsed = json.loads(text)
item: dict = parsed[0] if isinstance(parsed, list) else parsed
break
except (json.JSONDecodeError, ValueError, IndexError) as exc:
last_parse_error = str(exc)
messages.append({"role": "assistant", "content": last_raw})
messages.append({"role": "user", "content": _RETRY_SUFFIX_OBJ})
else:
raise HTTPException(
status_code=502,
detail=f"LLM returned invalid JSON after {_MAX_GENERATE_RETRIES} retries: "
f"{last_parse_error}\n{last_raw[:200]}",
)
tip = TipResult(
id=item.get("id", f"tip-{req.user_id[:8]}"),
content=item.get("content", ""),
rationale=item.get("rationale"),
)
latency_ms_recommend = round((time.monotonic() - t0_recommend) * 1000, 1)
log.info(
"recommend_served",
user_id=req.user_id,
agent_count=len(req.agent_outputs),
tip_id=tip.id,
)
_mlflow_run(
run_name="recommend",
params={
"user_id": req.user_id,
"agent_ids": ",".join(s.agent_id for s in req.agent_outputs),
"model": model_used,
"hour_of_day": req.hour_of_day,
"day_of_week": req.day_of_week,
},
metrics={
"prompt_tokens": total_usage["prompt_tokens"],
"completion_tokens": total_usage["completion_tokens"],
try:
# ── build_context span ─────────────────────────────────────────────
ctx_span = _start_span("build_context", SpanType.TOOL, parent=root, inputs={
"agent_count": len(req.agent_outputs),
"latency_ms": latency_ms_recommend,
},
tags={
"prompt_messages": json.dumps(messages),
"tip_content": tip.content,
"tip_rationale": tip.rationale or "",
},
)
return RecommendResponse(
tip=tip,
model=model_used,
prompt_tokens=total_usage["prompt_tokens"],
completion_tokens=total_usage["completion_tokens"],
)
"task_count": len(req.tasks),
"science_destiny": req.science_destiny,
})
messages = build_orchestrator_messages(
agent_outputs=[s.model_dump() for s in req.agent_outputs],
tasks=req.tasks,
hour_of_day=req.hour_of_day,
day_of_week=req.day_of_week,
science_destiny=req.science_destiny,
)
_end_span(ctx_span, outputs={"message_count": len(messages)})
# ── one span per pre-computed agent snippet ────────────────────────
for snippet in req.agent_outputs:
a_span = _start_span(
f"agent:{snippet.agent_id}", SpanType.AGENT, parent=root,
inputs={"agent_id": snippet.agent_id},
)
_end_span(a_span, outputs={"prompt_text": snippet.prompt_text})
# ── LLM orchestrator span (wraps retry loop) ───────────────────────
llm_span = _start_span("llm_orchestrator", SpanType.LLM, parent=root, inputs={
"messages": messages,
"model": "tip-generator",
"temperature": 0.7,
})
headers = {"Authorization": f"Bearer {LITELLM_MASTER_KEY}"}
last_raw = ""
last_parse_error = ""
total_usage: dict = {"prompt_tokens": 0, "completion_tokens": 0}
model_used = "tip-generator"
_attempt = 0
async with httpx.AsyncClient(timeout=30.0) as client:
for _attempt in range(1 + _MAX_GENERATE_RETRIES):
payload = {"model": "tip-generator", "messages": messages, "temperature": 0.7}
try:
resp = await client.post(
f"{LITELLM_URL}/chat/completions", json=payload, headers=headers
)
resp.raise_for_status()
except httpx.HTTPStatusError as e:
_end_span(llm_span, status="ERROR")
_end_span(root, status="ERROR")
raise HTTPException(status_code=502, detail=f"LiteLLM error: {e.response.text}")
except httpx.RequestError as e:
_end_span(llm_span, status="ERROR")
_end_span(root, status="ERROR")
raise HTTPException(status_code=503, detail=f"LiteLLM unreachable: {e}")
data = resp.json()
usage = data.get("usage", {})
total_usage["prompt_tokens"] += usage.get("prompt_tokens", 0)
total_usage["completion_tokens"] += usage.get("completion_tokens", 0)
model_used = data.get("model", "tip-generator")
last_raw = data["choices"][0]["message"]["content"]
try:
text = last_raw.strip()
if text.startswith("```"):
parts = text.split("```")
text = parts[1] if len(parts) > 1 else text
if text.startswith("json"):
text = text[4:]
parsed = json.loads(text)
item: dict = parsed[0] if isinstance(parsed, list) else parsed
break
except (json.JSONDecodeError, ValueError, IndexError) as exc:
last_parse_error = str(exc)
messages.append({"role": "assistant", "content": last_raw})
messages.append({"role": "user", "content": _RETRY_SUFFIX_OBJ})
else:
_end_span(llm_span, status="ERROR")
_end_span(root, status="ERROR")
raise HTTPException(
status_code=502,
detail=f"LLM returned invalid JSON after {_MAX_GENERATE_RETRIES} retries: "
f"{last_parse_error}\n{last_raw[:200]}",
)
tip = TipResult(
id=item.get("id", f"tip-{req.user_id[:8]}"),
content=item.get("content", ""),
rationale=item.get("rationale"),
)
_end_span(llm_span, outputs={"content": tip.content, "rationale": tip.rationale or ""},
attributes={
"prompt_tokens": str(total_usage["prompt_tokens"]),
"completion_tokens": str(total_usage["completion_tokens"]),
"model": model_used,
"attempts": str(_attempt + 1),
})
latency_ms = round((time.monotonic() - t0) * 1000, 1)
log.info("recommend_served", user_id=req.user_id, agent_count=len(req.agent_outputs), tip_id=tip.id)
_end_span(root, outputs={"tip_id": tip.id, "content": tip.content, "rationale": tip.rationale or ""},
attributes={"latency_ms": str(latency_ms), "agent_count": str(len(req.agent_outputs))})
return RecommendResponse(
tip=tip,
model=model_used,
prompt_tokens=total_usage["prompt_tokens"],
completion_tokens=total_usage["completion_tokens"],
)
except HTTPException:
raise
except Exception:
_end_span(root, status="ERROR")
raise
_MAX_GENERATE_RETRIES = 2

View File

@@ -124,17 +124,52 @@ _SYS_V4_ORCHESTRATOR = (
)
def _science_destiny_instruction(science_destiny: int) -> str:
"""Translate 0-100 slider into a prompt instruction.
0 = pure science: prioritise patterns, data, measurable progress.
100 = pure destiny: prioritise meaning, intuition, deeper purpose.
50 = balanced (no extra instruction injected).
"""
if science_destiny <= 20:
return (
"The user strongly prefers data-driven advice. "
"Ground every tip in observable patterns, streaks, or measurable progress. "
"Avoid abstract or motivational language."
)
if science_destiny <= 40:
return (
"The user leans toward evidence-based guidance. "
"Anchor tips in patterns and metrics where possible."
)
if science_destiny >= 80:
return (
"The user strongly believes in intuition and meaning. "
"Frame tips around purpose, values, and deeper intention rather than metrics."
)
if science_destiny >= 60:
return (
"The user leans toward intuitive, meaning-driven advice. "
"Weave in purpose and intention alongside practicality."
)
return "" # balanced — no extra instruction
def build_orchestrator_messages(
agent_outputs: list[dict],
tasks: list[dict],
hour_of_day: int,
day_of_week: int,
science_destiny: int = 50,
) -> list[dict]:
"""Build the [system, user] message list for the orchestrator LLM call.
agent_outputs: list of {agent_id, prompt_text} dicts.
Falls back to raw task summary when agent_outputs is empty.
"""
style_hint = _science_destiny_instruction(science_destiny)
system = _SYS_V4_ORCHESTRATOR + (f"\n\n{style_hint}" if style_hint else "")
lines = [f"Current time: {hour_of_day:02d}:00, day_of_week={day_of_week}", ""]
if agent_outputs:
lines.append("Context from analysis agents:")
@@ -150,7 +185,7 @@ def build_orchestrator_messages(
lines.append(f" - {t.get('content', '?')}")
lines.append("\nGenerate one tip as a JSON object. Write the tip content in English only.")
return [
{"role": "system", "content": _SYS_V4_ORCHESTRATOR},
{"role": "system", "content": system},
{"role": "user", "content": "\n".join(lines)},
]

View File

@@ -7,3 +7,4 @@ anthropic>=0.40.0
nats-py>=2.9.0
structlog>=24.1.0
sentry-sdk>=2.0.0
mlflow-skinny>=3.1.0