feat(serving): add MLflow tracing to ml-serving for all agent calls

Logs one MLflow run per /recommend (params, token metrics, latency,
full prompt + tip as artifacts) and per /agents/{id}/compute and
/infer call (signals snapshot, inferred prefs, latency).

Tracing is a no-op when MLFLOW_TRACKING_URI is unset; ml-serving
starts and serves tips correctly without MLflow configured.

Refs #118 (M4: remove from production / move off critical path).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-06 10:30:24 +00:00
parent 488a764519
commit c43dbaf23d
3 changed files with 263 additions and 0 deletions

View File

@@ -71,6 +71,7 @@ services:
environment: environment:
LITELLM_URL: ${LITELLM_URL:-http://host.docker.internal:4000} LITELLM_URL: ${LITELLM_URL:-http://host.docker.internal:4000}
OLLAMA_URL: ${OLLAMA_URL:-http://host.docker.internal:11434} OLLAMA_URL: ${OLLAMA_URL:-http://host.docker.internal:11434}
MLFLOW_TRACKING_URI: ${MLFLOW_TRACKING_URI:-}
extra_hosts: extra_hosts:
- "host.docker.internal:host-gateway" - "host.docker.internal:host-gateway"
ports: ports:

View File

@@ -14,6 +14,7 @@ from __future__ import annotations
import json import json
import os import os
import sys import sys
import time
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
@@ -29,6 +30,7 @@ from starlette.middleware.base import BaseHTTPMiddleware
import logging_config import logging_config
import nats_consumer import nats_consumer
from mlflow_client import MLflowClient
from prompts import get_prompt, build_orchestrator_messages from prompts import get_prompt, build_orchestrator_messages
# Make ml.agents importable regardless of working directory. # Make ml.agents importable regardless of working directory.
@@ -79,6 +81,30 @@ LITELLM_URL = os.getenv("LITELLM_URL", "http://localhost:4000")
LITELLM_MASTER_KEY = os.getenv("LITELLM_MASTER_KEY", "sk-oo-dev") LITELLM_MASTER_KEY = os.getenv("LITELLM_MASTER_KEY", "sk-oo-dev")
STATE_DIR = Path(os.getenv("STATE_DIR", "/tmp/oo-serving-state")) STATE_DIR = Path(os.getenv("STATE_DIR", "/tmp/oo-serving-state"))
# ── MLflow tracing (optional) ───────────────────────────────────────────────
# Set MLFLOW_TRACKING_URI to enable. Tracing is best-effort: any error is
# logged at WARNING and never propagates to the caller.
# Tracking server address; an empty value leaves tracing switched off.
_MLFLOW_URI = os.getenv("MLFLOW_TRACKING_URI", "")
# Experiment under which every ml-serving run is filed.
_MLFLOW_EXP = "oO/serving"
# Shared client instance, or None when no tracking URI is configured.
_mlflow: MLflowClient | None = None
if _MLFLOW_URI:
    _mlflow = MLflowClient(tracking_uri=_MLFLOW_URI)
def _mlflow_log_run(client, run_name: str, params: dict, metrics: dict, texts: dict) -> None:
    """Create, populate and close one MLflow run; never raises.

    Runs on a worker thread started by ``_mlflow_run``. Once the run exists it
    is always closed: ``FINISHED`` when every log call succeeded, ``FAILED``
    otherwise, so a partial failure does not leave a run stuck in RUNNING.
    """
    try:
        exp_id = client.get_or_create_experiment(_MLFLOW_EXP)
        run_id = client.create_run(exp_id, run_name, tags={"source": "ml-serving"})
        status = "FAILED"
        try:
            client.log_params(run_id, params)
            client.log_metrics(run_id, metrics)
            for key, text in texts.items():
                client.log_text(run_id, text, key)
            status = "FINISHED"
        finally:
            client.end_run(run_id, status)
    except Exception as exc:  # noqa: BLE001
        log.warning("mlflow_log_failed", error=str(exc))


def _mlflow_run(run_name: str, params: dict, metrics: dict, tags: dict) -> None:
    """Record one finished MLflow run without blocking the caller.

    Silently no-ops if MLflow is not configured (``_mlflow is None``).

    The MLflow client is synchronous httpx and this function is called from
    ``async`` request handlers, so the REST calls are handed to a daemon thread
    instead of being issued inline; running them inline would stall the event
    loop for every HTTP round-trip. Being a daemon thread, a run that is still
    in flight at interpreter shutdown may be lost.

    Args:
        run_name: Display name of the run.
        params: Logged as run params; each value is stringified and truncated
            to 250 characters.
        metrics: Logged as run metrics.
        tags: Text payloads persisted through ``log_text`` under their key.

    Any failure is logged at WARNING as ``mlflow_log_failed`` and never
    propagates to the caller.
    """
    if _mlflow is None:
        return

    import threading

    try:
        # Snapshot the payload before handing it over so later mutation by the
        # caller cannot race with the worker thread.
        worker = threading.Thread(
            target=_mlflow_log_run,
            args=(
                _mlflow,
                run_name,
                {k: str(v)[:250] for k, v in params.items()},
                dict(metrics),
                {k: str(v) for k, v in tags.items()},
            ),
            name="mlflow-run",
            daemon=True,
        )
        worker.start()
    except Exception as exc:  # noqa: BLE001
        log.warning("mlflow_log_failed", error=str(exc))
STATE_DIR.mkdir(parents=True, exist_ok=True) STATE_DIR.mkdir(parents=True, exist_ok=True)
@@ -251,6 +277,12 @@ async def compute_agent(agent_id: str, req: AgentComputeRequest) -> AgentCompute
raise HTTPException(status_code=500, detail=f"Agent compute failed: {exc}") raise HTTPException(status_code=500, detail=f"Agent compute failed: {exc}")
log.info("agent_computed", agent_id=agent_id, user_id=req.user_id, expires_at=output.expires_at) log.info("agent_computed", agent_id=agent_id, user_id=req.user_id, expires_at=output.expires_at)
_mlflow_run(
run_name=f"compute/{agent_id}",
params={"agent_id": agent_id, "user_id": req.user_id, "agent_version": output.agent_version},
metrics={"task_count": len(req.tasks), "feedback_count": len(req.feedback_history)},
tags={"prompt_text": output.prompt_text, "signals_snapshot": json.dumps(output.signals_snapshot)},
)
return AgentComputeResponse( return AgentComputeResponse(
user_id=output.user_id, user_id=output.user_id,
agent_id=output.agent_id, agent_id=output.agent_id,
@@ -307,6 +339,12 @@ async def infer_agent(agent_id: str, req: AgentInferRequest) -> AgentInferRespon
history_len=len(events), history_len=len(events),
latency_ms=latency_ms, latency_ms=latency_ms,
) )
_mlflow_run(
run_name=f"infer/{agent_id}",
params={"agent_id": agent_id, "user_id": req.user_id},
metrics={"latency_ms": latency_ms, "history_len": len(events), "n_params": len(inferred)},
tags={"inferred_prefs": json.dumps(inferred)},
)
return AgentInferResponse(user_id=req.user_id, agent_id=agent_id, inferred_prefs=inferred) return AgentInferResponse(user_id=req.user_id, agent_id=agent_id, inferred_prefs=inferred)
@@ -318,6 +356,7 @@ async def recommend(req: RecommendRequest) -> RecommendResponse:
the fresh rows from agent_outputs table (fetched by the TypeScript recommender the fresh rows from agent_outputs table (fetched by the TypeScript recommender
before calling this endpoint). Falls back to raw task context if empty. before calling this endpoint). Falls back to raw task context if empty.
""" """
t0_recommend = time.monotonic()
messages = build_orchestrator_messages( messages = build_orchestrator_messages(
agent_outputs=[s.model_dump() for s in req.agent_outputs], agent_outputs=[s.model_dump() for s in req.agent_outputs],
tasks=req.tasks, tasks=req.tasks,
@@ -376,12 +415,34 @@ async def recommend(req: RecommendRequest) -> RecommendResponse:
content=item.get("content", ""), content=item.get("content", ""),
rationale=item.get("rationale"), rationale=item.get("rationale"),
) )
latency_ms_recommend = round((time.monotonic() - t0_recommend) * 1000, 1)
log.info( log.info(
"recommend_served", "recommend_served",
user_id=req.user_id, user_id=req.user_id,
agent_count=len(req.agent_outputs), agent_count=len(req.agent_outputs),
tip_id=tip.id, tip_id=tip.id,
) )
_mlflow_run(
run_name="recommend",
params={
"user_id": req.user_id,
"agent_ids": ",".join(s.agent_id for s in req.agent_outputs),
"model": model_used,
"hour_of_day": req.hour_of_day,
"day_of_week": req.day_of_week,
},
metrics={
"prompt_tokens": total_usage["prompt_tokens"],
"completion_tokens": total_usage["completion_tokens"],
"agent_count": len(req.agent_outputs),
"latency_ms": latency_ms_recommend,
},
tags={
"prompt_messages": json.dumps(messages),
"tip_content": tip.content,
"tip_rationale": tip.rationale or "",
},
)
return RecommendResponse( return RecommendResponse(
tip=tip, tip=tip,
model=model_used, model=model_used,

201
ml/serving/mlflow_client.py Normal file
View File

@@ -0,0 +1,201 @@
"""Thin MLflow REST wrapper.
Why not the official ``mlflow`` SDK? Two reasons specific to the oO setup:
1. The MLflow server (3.11) ships with ``--allowed-hosts localhost`` but
curl / requests / urllib3 send ``Host: localhost:5000`` — the port
suffix fails the DNS-rebinding check. We override the Host header per
request, which the SDK doesn't expose.
2. The collect/judge phases only need ~6 endpoints (create/search/log).
Pulling a 200MB SDK transitively for that is excess weight.
All calls are synchronous httpx with explicit ``Host`` so the script can
run from the host shell or from inside docker without further config.
"""
from __future__ import annotations
import os
import time
from dataclasses import dataclass
from typing import Any
import httpx
def _strip_path(uri: str) -> tuple[str, str]:
    """Return (origin, path_prefix) — handles both /mlflow and / roots.

    ``http://mlflow:5000/mlflow`` → ("http://mlflow:5000", "/mlflow")
    ``http://localhost:5000``     → ("http://localhost:5000", "")

    Trailing slashes are ignored.

    Raises:
        ValueError: if *uri* contains no ``://`` separator (for example
            ``localhost:5000``); this REST client needs an explicit scheme.
    """
    scheme, sep, rest = uri.rstrip("/").partition("://")
    if not sep:
        raise ValueError(
            f"MLflow tracking URI must look like 'http://host[:port][/prefix]', got {uri!r}"
        )
    # Everything up to the first "/" after the scheme is the authority; the
    # remainder, if any, is the UI path prefix.
    host, _, path = rest.partition("/")
    return f"{scheme}://{host}", f"/{path}" if path else ""
@dataclass
class MLflowClient:
    """Small synchronous client for the MLflow tracking REST API.

    Every call opens a short-lived ``httpx.Client`` (``trust_env=False``) and
    sends an explicit ``Host`` header. HTTP error statuses surface as
    ``httpx.HTTPStatusError`` through ``raise_for_status``.
    """

    tracking_uri: str  # e.g. http://mlflow:5000 or http://mlflow:5000/mlflow
    username: str | None = None  # basic auth is sent only when this is truthy
    password: str | None = None
    host_header: str | None = None  # override for DNS-rebinding sidestep
    timeout: float = 30.0  # passed to httpx.Client(timeout=...)

    def __post_init__(self) -> None:
        """Derive the API base URL, the auth tuple and the default Host header."""
        self._origin, self._ui_prefix = _strip_path(self.tracking_uri)
        # MLflow 3.x exposes the REST API at the root, *not* under the
        # ``/mlflow`` UI prefix. Empirically verified against the running
        # ghcr.io/mlflow/mlflow:v3.11.1 container.
        self._api = f"{self._origin}/api/2.0/mlflow"
        self._auth = (self.username, self.password) if self.username else None
        # If user did not pass a host header, derive from origin. Strip
        # the port if present — the server's allowed-hosts check rejects
        # ``localhost:5000`` even when ``localhost`` is allowed.
        if self.host_header is None:
            host = self._origin.split("://", 1)[1]
            self.host_header = host.split(":", 1)[0]

    @classmethod
    def from_env(cls) -> "MLflowClient":
        """Build a client from ``MLFLOW_*`` environment variables.

        Unset or empty credentials fall back to ``admin`` / ``password``, and
        an unset tracking URI falls back to ``http://localhost:5000``.
        """
        return cls(
            tracking_uri=os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5000"),
            username=os.environ.get("MLFLOW_TRACKING_USERNAME") or "admin",
            password=os.environ.get("MLFLOW_TRACKING_PASSWORD") or "password",
            host_header=os.environ.get("MLFLOW_HOST_HEADER"),
        )

    def _headers(self) -> dict[str, str]:
        """Return the headers sent with every request (only ``Host``)."""
        return {"Host": self.host_header or "localhost"}

    def _post(self, path: str, body: dict) -> dict:
        """POST *body* as JSON to ``<api><path>`` and return the decoded reply."""
        with httpx.Client(trust_env=False, timeout=self.timeout) as c:
            r = c.post(f"{self._api}{path}", json=body, headers=self._headers(), auth=self._auth)
            r.raise_for_status()
            return r.json()

    def _get(self, path: str, params: dict | None = None) -> dict:
        """GET ``<api><path>`` with query *params* and return the decoded reply."""
        with httpx.Client(trust_env=False, timeout=self.timeout) as c:
            r = c.get(f"{self._api}{path}", params=params or {}, headers=self._headers(), auth=self._auth)
            r.raise_for_status()
            return r.json()

    # ── Experiments ────────────────────────────────────────────────────

    def get_or_create_experiment(self, name: str) -> str:
        """Return the id of experiment *name*, creating it when lookup fails.

        A 404 or 400 from the lookup is treated as "does not exist"; any other
        HTTP error status is re-raised.
        """
        try:
            r = self._get("/experiments/get-by-name", {"experiment_name": name})
            return r["experiment"]["experiment_id"]
        except httpx.HTTPStatusError as e:
            if e.response.status_code not in (404, 400):
                raise
        r = self._post("/experiments/create", {"name": name})
        return r["experiment_id"]

    # ── Runs ───────────────────────────────────────────────────────────

    def create_run(
        self,
        experiment_id: str,
        run_name: str,
        tags: dict[str, str] | None = None,
    ) -> str:
        """Start a run in *experiment_id* and return its run id.

        The start time is the current wall clock in milliseconds; tag values
        are stringified before sending.
        """
        body: dict[str, Any] = {
            "experiment_id": experiment_id,
            "start_time": int(time.time() * 1000),
            "run_name": run_name,
            "tags": [
                {"key": k, "value": str(v)}
                for k, v in (tags or {}).items()
            ],
        }
        r = self._post("/runs/create", body)
        return r["run"]["info"]["run_id"]

    def log_param(self, run_id: str, key: str, value: Any) -> None:
        """Log one param; *value* is sent as ``str(value)``."""
        self._post("/runs/log-parameter", {"run_id": run_id, "key": key, "value": str(value)})

    def log_params(self, run_id: str, params: dict[str, Any]) -> None:
        """Log every entry of *params*, one request per entry."""
        for k, v in params.items():
            self.log_param(run_id, k, v)

    def log_metric(self, run_id: str, key: str, value: float, step: int = 0) -> None:
        """Log one metric point, timestamped with the current time in ms."""
        self._post("/runs/log-metric", {
            "run_id": run_id,
            "key": key,
            "value": float(value),
            "timestamp": int(time.time() * 1000),
            "step": step,
        })

    def log_metrics(self, run_id: str, metrics: dict[str, float]) -> None:
        """Log every entry of *metrics* at step 0, one request per entry."""
        for k, v in metrics.items():
            self.log_metric(run_id, k, v)

    def set_tag(self, run_id: str, key: str, value: str) -> None:
        """Set one run tag; *value* is sent as ``str(value)``."""
        self._post("/runs/set-tag", {"run_id": run_id, "key": key, "value": str(value)})

    def set_tags(self, run_id: str, tags: dict[str, str]) -> None:
        """Set every entry of *tags*, one request per entry."""
        for k, v in tags.items():
            self.set_tag(run_id, k, v)

    # Per the original author's note, the server caps tag values at 5000
    # chars; 4500 leaves headroom for internal metadata MLflow may append on
    # its own. NOTE(review): confirm the cap against the deployed MLflow version.
    _TAG_VALUE_LIMIT = 4500

    def log_text(self, run_id: str, text: str, artifact_path: str) -> None:
        """Persist short text alongside the run.

        The MLflow server in this deployment uses a ``file://`` artifact
        backend, which is only reachable from inside the container — not
        via the REST proxy. We instead stash short payloads as tags
        keyed ``artifact:<path>``. Anything longer than 4500 chars is
        chunked into ``artifact:<path>:0``, ``:1`` …; ``get_artifact_text``
        re-stitches them in order.
        """
        key_base = f"artifact:{artifact_path}"
        limit = self._TAG_VALUE_LIMIT
        if len(text) <= limit:
            self.set_tag(run_id, key_base, text)
            return
        # Chunk index is the offset divided by the chunk size: 0, 1, 2, …
        for offset in range(0, len(text), limit):
            self.set_tag(run_id, f"{key_base}:{offset // limit}", text[offset:offset + limit])

    def get_artifact_text(self, run_id: str, artifact_path: str) -> str:
        """Return the text stored by ``log_text`` under *artifact_path*.

        An un-chunked ``artifact:<path>`` tag wins when present. Otherwise the
        ``artifact:<path>:<n>`` chunks are joined in numeric order. Returns
        ``""`` when nothing is stored under the path.

        Only keys whose suffix after ``artifact:<path>:`` is purely numeric
        count as chunks, so a sibling artifact whose path merely extends this
        one (``prompt`` vs ``prompt:v2``) is neither mixed in nor able to
        break the integer sort.
        """
        run = self._get("/runs/get", {"run_id": run_id})["run"]
        tags = {t["key"]: t["value"] for t in run["data"].get("tags", [])}
        key_base = f"artifact:{artifact_path}"
        if key_base in tags:
            return tags[key_base]
        # chunked form
        prefix = f"{key_base}:"
        indexed: list[tuple[int, str]] = []
        for key, value in tags.items():
            if not key.startswith(prefix):
                continue
            suffix = key[len(prefix):]
            if suffix.isascii() and suffix.isdigit():
                indexed.append((int(suffix), value))
        indexed.sort(key=lambda pair: pair[0])
        return "".join(value for _, value in indexed)

    def end_run(self, run_id: str, status: str = "FINISHED") -> None:
        """Close the run with *status*, stamping the current time as end time."""
        self._post("/runs/update", {
            "run_id": run_id,
            "status": status,
            "end_time": int(time.time() * 1000),
        })

    def search_runs(
        self,
        experiment_id: str,
        filter_string: str = "",
        max_results: int = 1000,
    ) -> list[dict]:
        """Return runs of *experiment_id* matching *filter_string*.

        Sends a single request and returns its ``runs`` list (empty when the
        key is absent); further result pages are not fetched.
        """
        body = {
            "experiment_ids": [experiment_id],
            "filter": filter_string,
            "max_results": max_results,
        }
        r = self._post("/runs/search", body)
        return r.get("runs", [])