feat(serving): add MLflow tracing to ml-serving for all agent calls
Logs one MLflow run per /recommend (params, token metrics, latency,
full prompt + tip as artifacts) and per /agents/{id}/compute and
/infer call (signals snapshot, inferred prefs, latency).
Tracing is a no-op when MLFLOW_TRACKING_URI is unset; ml-serving
starts and serves tips correctly without MLflow configured.
Refs #118 (M4: remove from production / move off critical path).
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -71,6 +71,7 @@ services:
|
|||||||
environment:
|
environment:
|
||||||
LITELLM_URL: ${LITELLM_URL:-http://host.docker.internal:4000}
|
LITELLM_URL: ${LITELLM_URL:-http://host.docker.internal:4000}
|
||||||
OLLAMA_URL: ${OLLAMA_URL:-http://host.docker.internal:11434}
|
OLLAMA_URL: ${OLLAMA_URL:-http://host.docker.internal:11434}
|
||||||
|
MLFLOW_TRACKING_URI: ${MLFLOW_TRACKING_URI:-}
|
||||||
extra_hosts:
|
extra_hosts:
|
||||||
- "host.docker.internal:host-gateway"
|
- "host.docker.internal:host-gateway"
|
||||||
ports:
|
ports:
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from __future__ import annotations
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import time
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -29,6 +30,7 @@ from starlette.middleware.base import BaseHTTPMiddleware
|
|||||||
|
|
||||||
import logging_config
|
import logging_config
|
||||||
import nats_consumer
|
import nats_consumer
|
||||||
|
from mlflow_client import MLflowClient
|
||||||
from prompts import get_prompt, build_orchestrator_messages
|
from prompts import get_prompt, build_orchestrator_messages
|
||||||
|
|
||||||
# Make ml.agents importable regardless of working directory.
|
# Make ml.agents importable regardless of working directory.
|
||||||
@@ -79,6 +81,30 @@ LITELLM_URL = os.getenv("LITELLM_URL", "http://localhost:4000")
|
|||||||
LITELLM_MASTER_KEY = os.getenv("LITELLM_MASTER_KEY", "sk-oo-dev")
|
LITELLM_MASTER_KEY = os.getenv("LITELLM_MASTER_KEY", "sk-oo-dev")
|
||||||
|
|
||||||
STATE_DIR = Path(os.getenv("STATE_DIR", "/tmp/oo-serving-state"))
|
STATE_DIR = Path(os.getenv("STATE_DIR", "/tmp/oo-serving-state"))
|
||||||
|
|
||||||
|
# ── MLflow tracing (optional) ───────────────────────────────────────────────
|
||||||
|
# Set MLFLOW_TRACKING_URI to enable. All calls are fire-and-forget; any error
|
||||||
|
# is logged at WARNING and never propagates to the caller.
|
||||||
|
|
||||||
|
_MLFLOW_URI = os.getenv("MLFLOW_TRACKING_URI", "")
|
||||||
|
_mlflow: MLflowClient | None = MLflowClient(tracking_uri=_MLFLOW_URI) if _MLFLOW_URI else None
|
||||||
|
_MLFLOW_EXP = "oO/serving"
|
||||||
|
|
||||||
|
|
||||||
|
def _mlflow_run(run_name: str, params: dict, metrics: dict, tags: dict) -> None:
|
||||||
|
"""Create a finished MLflow run. Silently no-ops if MLflow is not configured."""
|
||||||
|
if _mlflow is None:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
exp_id = _mlflow.get_or_create_experiment(_MLFLOW_EXP)
|
||||||
|
run_id = _mlflow.create_run(exp_id, run_name, tags={"source": "ml-serving"})
|
||||||
|
_mlflow.log_params(run_id, {k: str(v)[:250] for k, v in params.items()})
|
||||||
|
_mlflow.log_metrics(run_id, metrics)
|
||||||
|
for k, v in tags.items():
|
||||||
|
_mlflow.log_text(run_id, str(v), k)
|
||||||
|
_mlflow.end_run(run_id)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
log.warning("mlflow_log_failed", error=str(exc))
|
||||||
STATE_DIR.mkdir(parents=True, exist_ok=True)
|
STATE_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
@@ -251,6 +277,12 @@ async def compute_agent(agent_id: str, req: AgentComputeRequest) -> AgentCompute
|
|||||||
raise HTTPException(status_code=500, detail=f"Agent compute failed: {exc}")
|
raise HTTPException(status_code=500, detail=f"Agent compute failed: {exc}")
|
||||||
|
|
||||||
log.info("agent_computed", agent_id=agent_id, user_id=req.user_id, expires_at=output.expires_at)
|
log.info("agent_computed", agent_id=agent_id, user_id=req.user_id, expires_at=output.expires_at)
|
||||||
|
_mlflow_run(
|
||||||
|
run_name=f"compute/{agent_id}",
|
||||||
|
params={"agent_id": agent_id, "user_id": req.user_id, "agent_version": output.agent_version},
|
||||||
|
metrics={"task_count": len(req.tasks), "feedback_count": len(req.feedback_history)},
|
||||||
|
tags={"prompt_text": output.prompt_text, "signals_snapshot": json.dumps(output.signals_snapshot)},
|
||||||
|
)
|
||||||
return AgentComputeResponse(
|
return AgentComputeResponse(
|
||||||
user_id=output.user_id,
|
user_id=output.user_id,
|
||||||
agent_id=output.agent_id,
|
agent_id=output.agent_id,
|
||||||
@@ -307,6 +339,12 @@ async def infer_agent(agent_id: str, req: AgentInferRequest) -> AgentInferRespon
|
|||||||
history_len=len(events),
|
history_len=len(events),
|
||||||
latency_ms=latency_ms,
|
latency_ms=latency_ms,
|
||||||
)
|
)
|
||||||
|
_mlflow_run(
|
||||||
|
run_name=f"infer/{agent_id}",
|
||||||
|
params={"agent_id": agent_id, "user_id": req.user_id},
|
||||||
|
metrics={"latency_ms": latency_ms, "history_len": len(events), "n_params": len(inferred)},
|
||||||
|
tags={"inferred_prefs": json.dumps(inferred)},
|
||||||
|
)
|
||||||
return AgentInferResponse(user_id=req.user_id, agent_id=agent_id, inferred_prefs=inferred)
|
return AgentInferResponse(user_id=req.user_id, agent_id=agent_id, inferred_prefs=inferred)
|
||||||
|
|
||||||
|
|
||||||
@@ -318,6 +356,7 @@ async def recommend(req: RecommendRequest) -> RecommendResponse:
|
|||||||
the fresh rows from agent_outputs table (fetched by the TypeScript recommender
|
the fresh rows from agent_outputs table (fetched by the TypeScript recommender
|
||||||
before calling this endpoint). Falls back to raw task context if empty.
|
before calling this endpoint). Falls back to raw task context if empty.
|
||||||
"""
|
"""
|
||||||
|
t0_recommend = time.monotonic()
|
||||||
messages = build_orchestrator_messages(
|
messages = build_orchestrator_messages(
|
||||||
agent_outputs=[s.model_dump() for s in req.agent_outputs],
|
agent_outputs=[s.model_dump() for s in req.agent_outputs],
|
||||||
tasks=req.tasks,
|
tasks=req.tasks,
|
||||||
@@ -376,12 +415,34 @@ async def recommend(req: RecommendRequest) -> RecommendResponse:
|
|||||||
content=item.get("content", ""),
|
content=item.get("content", ""),
|
||||||
rationale=item.get("rationale"),
|
rationale=item.get("rationale"),
|
||||||
)
|
)
|
||||||
|
latency_ms_recommend = round((time.monotonic() - t0_recommend) * 1000, 1)
|
||||||
log.info(
|
log.info(
|
||||||
"recommend_served",
|
"recommend_served",
|
||||||
user_id=req.user_id,
|
user_id=req.user_id,
|
||||||
agent_count=len(req.agent_outputs),
|
agent_count=len(req.agent_outputs),
|
||||||
tip_id=tip.id,
|
tip_id=tip.id,
|
||||||
)
|
)
|
||||||
|
_mlflow_run(
|
||||||
|
run_name="recommend",
|
||||||
|
params={
|
||||||
|
"user_id": req.user_id,
|
||||||
|
"agent_ids": ",".join(s.agent_id for s in req.agent_outputs),
|
||||||
|
"model": model_used,
|
||||||
|
"hour_of_day": req.hour_of_day,
|
||||||
|
"day_of_week": req.day_of_week,
|
||||||
|
},
|
||||||
|
metrics={
|
||||||
|
"prompt_tokens": total_usage["prompt_tokens"],
|
||||||
|
"completion_tokens": total_usage["completion_tokens"],
|
||||||
|
"agent_count": len(req.agent_outputs),
|
||||||
|
"latency_ms": latency_ms_recommend,
|
||||||
|
},
|
||||||
|
tags={
|
||||||
|
"prompt_messages": json.dumps(messages),
|
||||||
|
"tip_content": tip.content,
|
||||||
|
"tip_rationale": tip.rationale or "",
|
||||||
|
},
|
||||||
|
)
|
||||||
return RecommendResponse(
|
return RecommendResponse(
|
||||||
tip=tip,
|
tip=tip,
|
||||||
model=model_used,
|
model=model_used,
|
||||||
|
|||||||
201
ml/serving/mlflow_client.py
Normal file
201
ml/serving/mlflow_client.py
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
"""Thin MLflow REST wrapper.
|
||||||
|
|
||||||
|
Why not the official ``mlflow`` SDK? Two reasons specific to the oO setup:
|
||||||
|
|
||||||
|
1. The MLflow server (3.11) ships with ``--allowed-hosts localhost`` but
|
||||||
|
curl / requests / urllib3 send ``Host: localhost:5000`` — the port
|
||||||
|
suffix fails the DNS-rebinding check. We override the Host header per
|
||||||
|
request, which the SDK doesn't expose.
|
||||||
|
2. The collect/judge phases only need ~6 endpoints (create/search/log).
|
||||||
|
Pulling a 200MB SDK transitively for that is excess weight.
|
||||||
|
|
||||||
|
All calls are synchronous httpx with explicit ``Host`` so the script can
|
||||||
|
run from the host shell or from inside docker without further config.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_path(uri: str) -> tuple[str, str]:
|
||||||
|
"""Return (origin, path_prefix) — handles both /mlflow and / roots.
|
||||||
|
|
||||||
|
``http://mlflow:5000/mlflow`` → ("http://mlflow:5000", "/mlflow")
|
||||||
|
``http://localhost:5000`` → ("http://localhost:5000", "")
|
||||||
|
"""
|
||||||
|
uri = uri.rstrip("/")
|
||||||
|
if "/" not in uri.split("://", 1)[1]:
|
||||||
|
return uri, ""
|
||||||
|
scheme_host, _, rest = uri.partition("://")
|
||||||
|
host, _, path = rest.partition("/")
|
||||||
|
return f"{scheme_host}://{host}", "/" + path if path else ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MLflowClient:
|
||||||
|
tracking_uri: str
|
||||||
|
username: str | None = None
|
||||||
|
password: str | None = None
|
||||||
|
host_header: str | None = None # override for DNS-rebinding sidestep
|
||||||
|
timeout: float = 30.0
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
self._origin, self._ui_prefix = _strip_path(self.tracking_uri)
|
||||||
|
# MLflow 3.x exposes the REST API at the root, *not* under the
|
||||||
|
# ``/mlflow`` UI prefix. Empirically verified against the running
|
||||||
|
# ghcr.io/mlflow/mlflow:v3.11.1 container.
|
||||||
|
self._api = f"{self._origin}/api/2.0/mlflow"
|
||||||
|
self._auth = (self.username, self.password) if self.username else None
|
||||||
|
# If user did not pass a host header, derive from origin. Strip
|
||||||
|
# the port if present — the server's allowed-hosts check rejects
|
||||||
|
# ``localhost:5000`` even when ``localhost`` is allowed.
|
||||||
|
if self.host_header is None:
|
||||||
|
host = self._origin.split("://", 1)[1]
|
||||||
|
self.host_header = host.split(":", 1)[0]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_env(cls) -> "MLflowClient":
|
||||||
|
return cls(
|
||||||
|
tracking_uri=os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5000"),
|
||||||
|
username=os.environ.get("MLFLOW_TRACKING_USERNAME") or "admin",
|
||||||
|
password=os.environ.get("MLFLOW_TRACKING_PASSWORD") or "password",
|
||||||
|
host_header=os.environ.get("MLFLOW_HOST_HEADER"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _headers(self) -> dict[str, str]:
|
||||||
|
return {"Host": self.host_header or "localhost"}
|
||||||
|
|
||||||
|
def _post(self, path: str, body: dict) -> dict:
|
||||||
|
with httpx.Client(trust_env=False, timeout=self.timeout) as c:
|
||||||
|
r = c.post(f"{self._api}{path}", json=body, headers=self._headers(), auth=self._auth)
|
||||||
|
r.raise_for_status()
|
||||||
|
return r.json()
|
||||||
|
|
||||||
|
def _get(self, path: str, params: dict | None = None) -> dict:
|
||||||
|
with httpx.Client(trust_env=False, timeout=self.timeout) as c:
|
||||||
|
r = c.get(f"{self._api}{path}", params=params or {}, headers=self._headers(), auth=self._auth)
|
||||||
|
r.raise_for_status()
|
||||||
|
return r.json()
|
||||||
|
|
||||||
|
# ── Experiments ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def get_or_create_experiment(self, name: str) -> str:
|
||||||
|
try:
|
||||||
|
r = self._get("/experiments/get-by-name", {"experiment_name": name})
|
||||||
|
return r["experiment"]["experiment_id"]
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
if e.response.status_code not in (404, 400):
|
||||||
|
raise
|
||||||
|
r = self._post("/experiments/create", {"name": name})
|
||||||
|
return r["experiment_id"]
|
||||||
|
|
||||||
|
# ── Runs ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def create_run(
|
||||||
|
self,
|
||||||
|
experiment_id: str,
|
||||||
|
run_name: str,
|
||||||
|
tags: dict[str, str] | None = None,
|
||||||
|
) -> str:
|
||||||
|
body: dict[str, Any] = {
|
||||||
|
"experiment_id": experiment_id,
|
||||||
|
"start_time": int(time.time() * 1000),
|
||||||
|
"run_name": run_name,
|
||||||
|
"tags": [
|
||||||
|
{"key": k, "value": str(v)}
|
||||||
|
for k, v in (tags or {}).items()
|
||||||
|
],
|
||||||
|
}
|
||||||
|
r = self._post("/runs/create", body)
|
||||||
|
return r["run"]["info"]["run_id"]
|
||||||
|
|
||||||
|
def log_param(self, run_id: str, key: str, value: Any) -> None:
|
||||||
|
self._post("/runs/log-parameter", {"run_id": run_id, "key": key, "value": str(value)})
|
||||||
|
|
||||||
|
def log_params(self, run_id: str, params: dict[str, Any]) -> None:
|
||||||
|
for k, v in params.items():
|
||||||
|
self.log_param(run_id, k, v)
|
||||||
|
|
||||||
|
def log_metric(self, run_id: str, key: str, value: float, step: int = 0) -> None:
|
||||||
|
self._post("/runs/log-metric", {
|
||||||
|
"run_id": run_id,
|
||||||
|
"key": key,
|
||||||
|
"value": float(value),
|
||||||
|
"timestamp": int(time.time() * 1000),
|
||||||
|
"step": step,
|
||||||
|
})
|
||||||
|
|
||||||
|
def log_metrics(self, run_id: str, metrics: dict[str, float]) -> None:
|
||||||
|
for k, v in metrics.items():
|
||||||
|
self.log_metric(run_id, k, v)
|
||||||
|
|
||||||
|
def set_tag(self, run_id: str, key: str, value: str) -> None:
|
||||||
|
self._post("/runs/set-tag", {"run_id": run_id, "key": key, "value": str(value)})
|
||||||
|
|
||||||
|
def set_tags(self, run_id: str, tags: dict[str, str]) -> None:
|
||||||
|
for k, v in tags.items():
|
||||||
|
self.set_tag(run_id, k, v)
|
||||||
|
|
||||||
|
# MLflow tag values are capped at 5000 chars by the server (RESOURCE_DOES_NOT_EXIST
|
||||||
|
# below that, INVALID_PARAMETER_VALUE above). 4500 leaves headroom for
|
||||||
|
# internal metadata MLflow may append on its own.
|
||||||
|
_TAG_VALUE_LIMIT = 4500
|
||||||
|
|
||||||
|
def log_text(self, run_id: str, text: str, artifact_path: str) -> None:
|
||||||
|
"""Persist short text alongside the run.
|
||||||
|
|
||||||
|
The MLflow server in this deployment uses a ``file://`` artifact
|
||||||
|
backend, which is only reachable from inside the container — not
|
||||||
|
via the REST proxy. We instead stash short payloads as tags
|
||||||
|
keyed ``artifact:<path>``. Anything longer than 4500 chars is
|
||||||
|
chunked into ``artifact:<path>:0``, ``:1`` …; ``get_artifact_text``
|
||||||
|
re-stitches them in order.
|
||||||
|
"""
|
||||||
|
key_base = f"artifact:{artifact_path}"
|
||||||
|
if len(text) <= self._TAG_VALUE_LIMIT:
|
||||||
|
self.set_tag(run_id, key_base, text)
|
||||||
|
return
|
||||||
|
# chunk
|
||||||
|
for i in range(0, len(text), self._TAG_VALUE_LIMIT):
|
||||||
|
self.set_tag(run_id, f"{key_base}:{i // self._TAG_VALUE_LIMIT}",
|
||||||
|
text[i:i + self._TAG_VALUE_LIMIT])
|
||||||
|
|
||||||
|
def get_artifact_text(self, run_id: str, artifact_path: str) -> str:
|
||||||
|
run = self._get("/runs/get", {"run_id": run_id})["run"]
|
||||||
|
tags = {t["key"]: t["value"] for t in run["data"].get("tags", [])}
|
||||||
|
key_base = f"artifact:{artifact_path}"
|
||||||
|
if key_base in tags:
|
||||||
|
return tags[key_base]
|
||||||
|
# chunked form
|
||||||
|
chunks = sorted(
|
||||||
|
(k for k in tags if k.startswith(f"{key_base}:")),
|
||||||
|
key=lambda k: int(k.rsplit(":", 1)[1]),
|
||||||
|
)
|
||||||
|
return "".join(tags[k] for k in chunks)
|
||||||
|
|
||||||
|
def end_run(self, run_id: str, status: str = "FINISHED") -> None:
|
||||||
|
self._post("/runs/update", {
|
||||||
|
"run_id": run_id,
|
||||||
|
"status": status,
|
||||||
|
"end_time": int(time.time() * 1000),
|
||||||
|
})
|
||||||
|
|
||||||
|
def search_runs(
|
||||||
|
self,
|
||||||
|
experiment_id: str,
|
||||||
|
filter_string: str = "",
|
||||||
|
max_results: int = 1000,
|
||||||
|
) -> list[dict]:
|
||||||
|
body = {
|
||||||
|
"experiment_ids": [experiment_id],
|
||||||
|
"filter": filter_string,
|
||||||
|
"max_results": max_results,
|
||||||
|
}
|
||||||
|
r = self._post("/runs/search", body)
|
||||||
|
return r.get("runs", [])
|
||||||
Reference in New Issue
Block a user