feat(serving): replace MLflow run logging with native trace spans
Convert ml-serving from isolated MLflow runs to nested traces using mlflow.start_span_no_context(). The recommend endpoint now emits a full span tree: recommend (CHAIN) → build_context (TOOL), agent:* (AGENT) ×N, llm_orchestrator (LLM). Compute and infer endpoints each emit a single span. Supporting changes: - mlflow-skinny>=3.1.0 added to requirements - MLflow configured with --serve-artifacts + mlflow-artifacts:/ default root for cross-container artifact proxy (spans now persist from ml-serving) - --allowed-hosts extended to include mlflow:5000 (SDK includes port in Host) - science_destiny slider wired through prompts.py and recommend endpoint - Config page exposes science/destiny slider (0=data-driven, 100=intuitive) - Tip page shows rationale inline on tap Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -28,9 +28,11 @@ from fastapi import FastAPI, HTTPException, Request
|
||||
from pydantic import BaseModel
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
import mlflow
|
||||
from mlflow.entities import SpanType
|
||||
|
||||
import logging_config
|
||||
import nats_consumer
|
||||
from mlflow_client import MLflowClient
|
||||
from prompts import get_prompt, build_orchestrator_messages
|
||||
|
||||
# Make ml.agents importable regardless of working directory.
|
||||
@@ -83,36 +85,69 @@ LITELLM_MASTER_KEY = os.getenv("LITELLM_MASTER_KEY", "sk-oo-dev")
|
||||
STATE_DIR = Path(os.getenv("STATE_DIR", "/tmp/oo-serving-state"))
|
||||
|
||||
# ── MLflow tracing (optional) ───────────────────────────────────────────────
|
||||
# Set MLFLOW_TRACKING_URI to enable. All calls are fire-and-forget; any error
|
||||
# is logged at WARNING and never propagates to the caller.
|
||||
# Set MLFLOW_TRACKING_URI to enable. Spans are fire-and-forget; errors are
|
||||
# logged at WARNING and never propagate to the caller.
|
||||
# MLflow --allowed-hosts must include "mlflow" (the container DNS name) so the
|
||||
# SDK can reach the server from inside other containers.
|
||||
|
||||
_MLFLOW_URI = os.getenv("MLFLOW_TRACKING_URI", "")
|
||||
_mlflow: MLflowClient | None = (
|
||||
MLflowClient(
|
||||
tracking_uri=_MLFLOW_URI,
|
||||
username=os.getenv("MLFLOW_TRACKING_USERNAME", "admin"),
|
||||
password=os.getenv("MLFLOW_TRACKING_PASSWORD") or os.getenv("MLFLOW_ADMIN_PASSWORD", "password"),
|
||||
host_header="localhost",
|
||||
)
|
||||
if _MLFLOW_URI else None
|
||||
)
|
||||
_MLFLOW_EXP = "oO/serving"
|
||||
_mlflow_exp_id: str | None = None
|
||||
|
||||
if _MLFLOW_URI:
|
||||
try:
|
||||
mlflow.set_tracking_uri(_MLFLOW_URI)
|
||||
_mlflow_exp_id = mlflow.set_experiment(_MLFLOW_EXP).experiment_id
|
||||
except Exception as _exc:
|
||||
log.warning("mlflow_init_failed", error=str(_exc))
|
||||
|
||||
|
||||
def _mlflow_run(run_name: str, params: dict, metrics: dict, tags: dict) -> None:
|
||||
"""Create a finished MLflow run. Silently no-ops if MLflow is not configured."""
|
||||
if _mlflow is None:
|
||||
class _NoOpSpan:
|
||||
"""Returned when MLflow is disabled or span creation fails."""
|
||||
def set_inputs(self, *a, **k): pass
|
||||
def set_outputs(self, *a, **k): pass
|
||||
def set_attribute(self, *a, **k): pass
|
||||
def set_attributes(self, *a, **k): pass
|
||||
def end(self, *a, **k): pass
|
||||
|
||||
|
||||
_NOOP = _NoOpSpan()
|
||||
|
||||
|
||||
def _start_span(name: str, span_type: str, *, parent=_NOOP, inputs=None):
|
||||
"""Start an MLflow span. Returns _NOOP on failure or when tracing is off.
|
||||
|
||||
experiment_id is only passed for root spans (no parent) — passing it to
|
||||
child spans causes the SDK to fail with '_Span has no attribute _span'.
|
||||
"""
|
||||
if _mlflow_exp_id is None:
|
||||
return _NOOP
|
||||
try:
|
||||
kw: dict = {"span_type": span_type}
|
||||
if isinstance(parent, _NoOpSpan):
|
||||
kw["experiment_id"] = _mlflow_exp_id # root span only
|
||||
else:
|
||||
kw["parent_span"] = parent
|
||||
if inputs is not None:
|
||||
kw["inputs"] = inputs
|
||||
return mlflow.start_span_no_context(name, **kw)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("mlflow_span_start_failed", name=name, error=str(exc))
|
||||
return _NOOP
|
||||
|
||||
|
||||
def _end_span(span, *, status: str = "OK", outputs=None, attributes: dict | None = None) -> None:
|
||||
"""End a span safely, ignoring _NoOpSpan and swallowing exceptions."""
|
||||
if isinstance(span, _NoOpSpan):
|
||||
return
|
||||
try:
|
||||
exp_id = _mlflow.get_or_create_experiment(_MLFLOW_EXP)
|
||||
run_id = _mlflow.create_run(exp_id, run_name, tags={"source": "ml-serving"})
|
||||
_mlflow.log_params(run_id, {k: str(v)[:250] for k, v in params.items()})
|
||||
_mlflow.log_metrics(run_id, metrics)
|
||||
for k, v in tags.items():
|
||||
_mlflow.log_text(run_id, str(v), k)
|
||||
_mlflow.end_run(run_id)
|
||||
if attributes:
|
||||
span.set_attributes(attributes)
|
||||
span.end(status=status, outputs=outputs)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("mlflow_log_failed", error=str(exc))
|
||||
log.warning("mlflow_span_end_failed", error=str(exc))
|
||||
|
||||
|
||||
STATE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
@@ -197,6 +232,7 @@ class RecommendRequest(BaseModel):
|
||||
tasks: list[dict] = []
|
||||
hour_of_day: int = 12
|
||||
day_of_week: int = 0
|
||||
science_destiny: int = 50 # 0=science (data-driven), 100=destiny (intuitive)
|
||||
|
||||
|
||||
class TipResult(BaseModel):
|
||||
@@ -285,12 +321,15 @@ async def compute_agent(agent_id: str, req: AgentComputeRequest) -> AgentCompute
|
||||
raise HTTPException(status_code=500, detail=f"Agent compute failed: {exc}")
|
||||
|
||||
log.info("agent_computed", agent_id=agent_id, user_id=req.user_id, expires_at=output.expires_at)
|
||||
_mlflow_run(
|
||||
run_name=f"compute/{agent_id}",
|
||||
params={"agent_id": agent_id, "user_id": req.user_id, "agent_version": output.agent_version},
|
||||
metrics={"task_count": len(req.tasks), "feedback_count": len(req.feedback_history)},
|
||||
tags={"prompt_text": output.prompt_text, "signals_snapshot": json.dumps(output.signals_snapshot)},
|
||||
span = _start_span(
|
||||
f"compute:{agent_id}",
|
||||
SpanType.AGENT,
|
||||
inputs={"user_id": req.user_id, "agent_id": agent_id,
|
||||
"task_count": len(req.tasks), "feedback_count": len(req.feedback_history)},
|
||||
)
|
||||
_end_span(span,
|
||||
outputs={"prompt_text": output.prompt_text, "signals_snapshot": output.signals_snapshot},
|
||||
attributes={"agent_version": output.agent_version, "expires_at": output.expires_at})
|
||||
return AgentComputeResponse(
|
||||
user_id=output.user_id,
|
||||
agent_id=output.agent_id,
|
||||
@@ -347,12 +386,15 @@ async def infer_agent(agent_id: str, req: AgentInferRequest) -> AgentInferRespon
|
||||
history_len=len(events),
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
_mlflow_run(
|
||||
run_name=f"infer/{agent_id}",
|
||||
params={"agent_id": agent_id, "user_id": req.user_id},
|
||||
metrics={"latency_ms": latency_ms, "history_len": len(events), "n_params": len(inferred)},
|
||||
tags={"inferred_prefs": json.dumps(inferred)},
|
||||
span = _start_span(
|
||||
f"infer:{agent_id}",
|
||||
SpanType.CHAIN,
|
||||
inputs={"user_id": req.user_id, "agent_id": agent_id,
|
||||
"history_len": len(events), "completion_count": len(completions)},
|
||||
)
|
||||
_end_span(span,
|
||||
outputs={"inferred_prefs": inferred},
|
||||
attributes={"latency_ms": str(latency_ms), "n_params": str(len(inferred))})
|
||||
return AgentInferResponse(user_id=req.user_id, agent_id=agent_id, inferred_prefs=inferred)
|
||||
|
||||
|
||||
@@ -364,99 +406,132 @@ async def recommend(req: RecommendRequest) -> RecommendResponse:
|
||||
the fresh rows from agent_outputs table (fetched by the TypeScript recommender
|
||||
before calling this endpoint). Falls back to raw task context if empty.
|
||||
"""
|
||||
t0_recommend = time.monotonic()
|
||||
messages = build_orchestrator_messages(
|
||||
agent_outputs=[s.model_dump() for s in req.agent_outputs],
|
||||
tasks=req.tasks,
|
||||
hour_of_day=req.hour_of_day,
|
||||
day_of_week=req.day_of_week,
|
||||
)
|
||||
headers = {"Authorization": f"Bearer {LITELLM_MASTER_KEY}"}
|
||||
last_raw = ""
|
||||
last_parse_error = ""
|
||||
total_usage: dict = {"prompt_tokens": 0, "completion_tokens": 0}
|
||||
model_used = "tip-generator"
|
||||
t0 = time.monotonic()
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
for _attempt in range(1 + _MAX_GENERATE_RETRIES):
|
||||
payload = {"model": "tip-generator", "messages": messages, "temperature": 0.7}
|
||||
try:
|
||||
resp = await client.post(
|
||||
f"{LITELLM_URL}/chat/completions", json=payload, headers=headers
|
||||
)
|
||||
resp.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise HTTPException(status_code=502, detail=f"LiteLLM error: {e.response.text}")
|
||||
except httpx.RequestError as e:
|
||||
raise HTTPException(status_code=503, detail=f"LiteLLM unreachable: {e}")
|
||||
# ── root span ──────────────────────────────────────────────────────────
|
||||
root = _start_span("recommend", SpanType.CHAIN, inputs={
|
||||
"user_id": req.user_id,
|
||||
"agent_ids": [s.agent_id for s in req.agent_outputs],
|
||||
"hour_of_day": req.hour_of_day,
|
||||
"day_of_week": req.day_of_week,
|
||||
"science_destiny": req.science_destiny,
|
||||
})
|
||||
|
||||
data = resp.json()
|
||||
usage = data.get("usage", {})
|
||||
total_usage["prompt_tokens"] += usage.get("prompt_tokens", 0)
|
||||
total_usage["completion_tokens"] += usage.get("completion_tokens", 0)
|
||||
model_used = data.get("model", "tip-generator")
|
||||
last_raw = data["choices"][0]["message"]["content"]
|
||||
|
||||
try:
|
||||
text = last_raw.strip()
|
||||
if text.startswith("```"):
|
||||
parts = text.split("```")
|
||||
text = parts[1] if len(parts) > 1 else text
|
||||
if text.startswith("json"):
|
||||
text = text[4:]
|
||||
parsed = json.loads(text)
|
||||
item: dict = parsed[0] if isinstance(parsed, list) else parsed
|
||||
break
|
||||
except (json.JSONDecodeError, ValueError, IndexError) as exc:
|
||||
last_parse_error = str(exc)
|
||||
messages.append({"role": "assistant", "content": last_raw})
|
||||
messages.append({"role": "user", "content": _RETRY_SUFFIX_OBJ})
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"LLM returned invalid JSON after {_MAX_GENERATE_RETRIES} retries: "
|
||||
f"{last_parse_error}\n{last_raw[:200]}",
|
||||
)
|
||||
|
||||
tip = TipResult(
|
||||
id=item.get("id", f"tip-{req.user_id[:8]}"),
|
||||
content=item.get("content", ""),
|
||||
rationale=item.get("rationale"),
|
||||
)
|
||||
latency_ms_recommend = round((time.monotonic() - t0_recommend) * 1000, 1)
|
||||
log.info(
|
||||
"recommend_served",
|
||||
user_id=req.user_id,
|
||||
agent_count=len(req.agent_outputs),
|
||||
tip_id=tip.id,
|
||||
)
|
||||
_mlflow_run(
|
||||
run_name="recommend",
|
||||
params={
|
||||
"user_id": req.user_id,
|
||||
"agent_ids": ",".join(s.agent_id for s in req.agent_outputs),
|
||||
"model": model_used,
|
||||
"hour_of_day": req.hour_of_day,
|
||||
"day_of_week": req.day_of_week,
|
||||
},
|
||||
metrics={
|
||||
"prompt_tokens": total_usage["prompt_tokens"],
|
||||
"completion_tokens": total_usage["completion_tokens"],
|
||||
try:
|
||||
# ── build_context span ─────────────────────────────────────────────
|
||||
ctx_span = _start_span("build_context", SpanType.TOOL, parent=root, inputs={
|
||||
"agent_count": len(req.agent_outputs),
|
||||
"latency_ms": latency_ms_recommend,
|
||||
},
|
||||
tags={
|
||||
"prompt_messages": json.dumps(messages),
|
||||
"tip_content": tip.content,
|
||||
"tip_rationale": tip.rationale or "",
|
||||
},
|
||||
)
|
||||
return RecommendResponse(
|
||||
tip=tip,
|
||||
model=model_used,
|
||||
prompt_tokens=total_usage["prompt_tokens"],
|
||||
completion_tokens=total_usage["completion_tokens"],
|
||||
)
|
||||
"task_count": len(req.tasks),
|
||||
"science_destiny": req.science_destiny,
|
||||
})
|
||||
messages = build_orchestrator_messages(
|
||||
agent_outputs=[s.model_dump() for s in req.agent_outputs],
|
||||
tasks=req.tasks,
|
||||
hour_of_day=req.hour_of_day,
|
||||
day_of_week=req.day_of_week,
|
||||
science_destiny=req.science_destiny,
|
||||
)
|
||||
_end_span(ctx_span, outputs={"message_count": len(messages)})
|
||||
|
||||
# ── one span per pre-computed agent snippet ────────────────────────
|
||||
for snippet in req.agent_outputs:
|
||||
a_span = _start_span(
|
||||
f"agent:{snippet.agent_id}", SpanType.AGENT, parent=root,
|
||||
inputs={"agent_id": snippet.agent_id},
|
||||
)
|
||||
_end_span(a_span, outputs={"prompt_text": snippet.prompt_text})
|
||||
|
||||
# ── LLM orchestrator span (wraps retry loop) ───────────────────────
|
||||
llm_span = _start_span("llm_orchestrator", SpanType.LLM, parent=root, inputs={
|
||||
"messages": messages,
|
||||
"model": "tip-generator",
|
||||
"temperature": 0.7,
|
||||
})
|
||||
|
||||
headers = {"Authorization": f"Bearer {LITELLM_MASTER_KEY}"}
|
||||
last_raw = ""
|
||||
last_parse_error = ""
|
||||
total_usage: dict = {"prompt_tokens": 0, "completion_tokens": 0}
|
||||
model_used = "tip-generator"
|
||||
_attempt = 0
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
for _attempt in range(1 + _MAX_GENERATE_RETRIES):
|
||||
payload = {"model": "tip-generator", "messages": messages, "temperature": 0.7}
|
||||
try:
|
||||
resp = await client.post(
|
||||
f"{LITELLM_URL}/chat/completions", json=payload, headers=headers
|
||||
)
|
||||
resp.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
_end_span(llm_span, status="ERROR")
|
||||
_end_span(root, status="ERROR")
|
||||
raise HTTPException(status_code=502, detail=f"LiteLLM error: {e.response.text}")
|
||||
except httpx.RequestError as e:
|
||||
_end_span(llm_span, status="ERROR")
|
||||
_end_span(root, status="ERROR")
|
||||
raise HTTPException(status_code=503, detail=f"LiteLLM unreachable: {e}")
|
||||
|
||||
data = resp.json()
|
||||
usage = data.get("usage", {})
|
||||
total_usage["prompt_tokens"] += usage.get("prompt_tokens", 0)
|
||||
total_usage["completion_tokens"] += usage.get("completion_tokens", 0)
|
||||
model_used = data.get("model", "tip-generator")
|
||||
last_raw = data["choices"][0]["message"]["content"]
|
||||
|
||||
try:
|
||||
text = last_raw.strip()
|
||||
if text.startswith("```"):
|
||||
parts = text.split("```")
|
||||
text = parts[1] if len(parts) > 1 else text
|
||||
if text.startswith("json"):
|
||||
text = text[4:]
|
||||
parsed = json.loads(text)
|
||||
item: dict = parsed[0] if isinstance(parsed, list) else parsed
|
||||
break
|
||||
except (json.JSONDecodeError, ValueError, IndexError) as exc:
|
||||
last_parse_error = str(exc)
|
||||
messages.append({"role": "assistant", "content": last_raw})
|
||||
messages.append({"role": "user", "content": _RETRY_SUFFIX_OBJ})
|
||||
else:
|
||||
_end_span(llm_span, status="ERROR")
|
||||
_end_span(root, status="ERROR")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"LLM returned invalid JSON after {_MAX_GENERATE_RETRIES} retries: "
|
||||
f"{last_parse_error}\n{last_raw[:200]}",
|
||||
)
|
||||
|
||||
tip = TipResult(
|
||||
id=item.get("id", f"tip-{req.user_id[:8]}"),
|
||||
content=item.get("content", ""),
|
||||
rationale=item.get("rationale"),
|
||||
)
|
||||
_end_span(llm_span, outputs={"content": tip.content, "rationale": tip.rationale or ""},
|
||||
attributes={
|
||||
"prompt_tokens": str(total_usage["prompt_tokens"]),
|
||||
"completion_tokens": str(total_usage["completion_tokens"]),
|
||||
"model": model_used,
|
||||
"attempts": str(_attempt + 1),
|
||||
})
|
||||
|
||||
latency_ms = round((time.monotonic() - t0) * 1000, 1)
|
||||
log.info("recommend_served", user_id=req.user_id, agent_count=len(req.agent_outputs), tip_id=tip.id)
|
||||
_end_span(root, outputs={"tip_id": tip.id, "content": tip.content, "rationale": tip.rationale or ""},
|
||||
attributes={"latency_ms": str(latency_ms), "agent_count": str(len(req.agent_outputs))})
|
||||
|
||||
return RecommendResponse(
|
||||
tip=tip,
|
||||
model=model_used,
|
||||
prompt_tokens=total_usage["prompt_tokens"],
|
||||
completion_tokens=total_usage["completion_tokens"],
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
_end_span(root, status="ERROR")
|
||||
raise
|
||||
|
||||
_MAX_GENERATE_RETRIES = 2
|
||||
|
||||
|
||||
Reference in New Issue
Block a user