diff --git a/infra/docker/docker-compose.yml b/infra/docker/docker-compose.yml index 0e90247..05070b4 100644 --- a/infra/docker/docker-compose.yml +++ b/infra/docker/docker-compose.yml @@ -71,6 +71,7 @@ services: environment: LITELLM_URL: ${LITELLM_URL:-http://host.docker.internal:4000} OLLAMA_URL: ${OLLAMA_URL:-http://host.docker.internal:11434} + MLFLOW_TRACKING_URI: ${MLFLOW_TRACKING_URI:-} extra_hosts: - "host.docker.internal:host-gateway" ports: diff --git a/ml/serving/main.py b/ml/serving/main.py index a852319..7f5867f 100644 --- a/ml/serving/main.py +++ b/ml/serving/main.py @@ -14,6 +14,7 @@ from __future__ import annotations import json import os import sys +import time from contextlib import asynccontextmanager from datetime import datetime, timezone from pathlib import Path @@ -29,6 +30,7 @@ from starlette.middleware.base import BaseHTTPMiddleware import logging_config import nats_consumer +from mlflow_client import MLflowClient from prompts import get_prompt, build_orchestrator_messages # Make ml.agents importable regardless of working directory. @@ -79,6 +81,30 @@ LITELLM_URL = os.getenv("LITELLM_URL", "http://localhost:4000") LITELLM_MASTER_KEY = os.getenv("LITELLM_MASTER_KEY", "sk-oo-dev") STATE_DIR = Path(os.getenv("STATE_DIR", "/tmp/oo-serving-state")) + +# ── MLflow tracing (optional) ─────────────────────────────────────────────── +# Set MLFLOW_TRACKING_URI to enable. All calls are fire-and-forget; any error +# is logged at WARNING and never propagates to the caller. + +_MLFLOW_URI = os.getenv("MLFLOW_TRACKING_URI", "") +_mlflow: MLflowClient | None = MLflowClient(tracking_uri=_MLFLOW_URI) if _MLFLOW_URI else None +_MLFLOW_EXP = "oO/serving" + + +def _mlflow_run(run_name: str, params: dict, metrics: dict, tags: dict) -> None: + """Create a finished MLflow run. Silently no-ops if MLflow is not configured.""" + if _mlflow is None: + return + try: + exp_id = _mlflow.get_or_create_experiment(_MLFLOW_EXP) + run_id = _mlflow.create_run(exp_id, run_name, tags={"source": "ml-serving"}) + _mlflow.log_params(run_id, {k: str(v)[:250] for k, v in params.items()}) + _mlflow.log_metrics(run_id, metrics) + for k, v in tags.items(): + _mlflow.log_text(run_id, str(v), k) + _mlflow.end_run(run_id) + except Exception as exc: # noqa: BLE001 + log.warning("mlflow_log_failed", error=str(exc)) STATE_DIR.mkdir(parents=True, exist_ok=True) @@ -251,6 +277,12 @@ async def compute_agent(agent_id: str, req: AgentComputeRequest) -> AgentCompute raise HTTPException(status_code=500, detail=f"Agent compute failed: {exc}") log.info("agent_computed", agent_id=agent_id, user_id=req.user_id, expires_at=output.expires_at) + _mlflow_run( + run_name=f"compute/{agent_id}", + params={"agent_id": agent_id, "user_id": req.user_id, "agent_version": output.agent_version}, + metrics={"task_count": len(req.tasks), "feedback_count": len(req.feedback_history)}, + tags={"prompt_text": output.prompt_text, "signals_snapshot": json.dumps(output.signals_snapshot)}, + ) return AgentComputeResponse( user_id=output.user_id, agent_id=output.agent_id, @@ -307,6 +339,12 @@ async def infer_agent(agent_id: str, req: AgentInferRequest) -> AgentInferRespon history_len=len(events), latency_ms=latency_ms, ) + _mlflow_run( + run_name=f"infer/{agent_id}", + params={"agent_id": agent_id, "user_id": req.user_id}, + metrics={"latency_ms": latency_ms, "history_len": len(events), "n_params": len(inferred)}, + tags={"inferred_prefs": json.dumps(inferred)}, + ) return AgentInferResponse(user_id=req.user_id, agent_id=agent_id, inferred_prefs=inferred) @@ -318,6 +356,7 @@ async def recommend(req: RecommendRequest) -> RecommendResponse: the fresh rows from agent_outputs table (fetched by the TypeScript recommender before calling this endpoint). Falls back to raw task context if empty. """ + t0_recommend = time.monotonic() messages = build_orchestrator_messages( agent_outputs=[s.model_dump() for s in req.agent_outputs], tasks=req.tasks, @@ -376,12 +415,34 @@ async def recommend(req: RecommendRequest) -> RecommendResponse: content=item.get("content", ""), rationale=item.get("rationale"), ) + latency_ms_recommend = round((time.monotonic() - t0_recommend) * 1000, 1) log.info( "recommend_served", user_id=req.user_id, agent_count=len(req.agent_outputs), tip_id=tip.id, ) + _mlflow_run( + run_name="recommend", + params={ + "user_id": req.user_id, + "agent_ids": ",".join(s.agent_id for s in req.agent_outputs), + "model": model_used, + "hour_of_day": req.hour_of_day, + "day_of_week": req.day_of_week, + }, + metrics={ + "prompt_tokens": total_usage["prompt_tokens"], + "completion_tokens": total_usage["completion_tokens"], + "agent_count": len(req.agent_outputs), + "latency_ms": latency_ms_recommend, + }, + tags={ + "prompt_messages": json.dumps(messages), + "tip_content": tip.content, + "tip_rationale": tip.rationale or "", + }, + ) return RecommendResponse( tip=tip, model=model_used, diff --git a/ml/serving/mlflow_client.py b/ml/serving/mlflow_client.py new file mode 100644 index 0000000..9657b23 --- /dev/null +++ b/ml/serving/mlflow_client.py @@ -0,0 +1,201 @@ +"""Thin MLflow REST wrapper. + +Why not the official ``mlflow`` SDK? Two reasons specific to the oO setup: + +1. The MLflow server (3.11) ships with ``--allowed-hosts localhost`` but + curl / requests / urllib3 send ``Host: localhost:5000`` — the port + suffix fails the DNS-rebinding check. We override the Host header per + request, which the SDK doesn't expose. +2. The collect/judge phases only need ~6 endpoints (create/search/log). + Pulling a 200MB SDK transitively for that is excess weight. + +All calls are synchronous httpx with explicit ``Host`` so the script can +run from the host shell or from inside docker without further config. +""" + +from __future__ import annotations + +import os +import time +from dataclasses import dataclass +from typing import Any + +import httpx + + +def _strip_path(uri: str) -> tuple[str, str]: + """Return (origin, path_prefix) — handles both /mlflow and / roots. + + ``http://mlflow:5000/mlflow`` → ("http://mlflow:5000", "/mlflow") + ``http://localhost:5000`` → ("http://localhost:5000", "") + """ + uri = uri.rstrip("/") + if "/" not in uri.split("://", 1)[1]: + return uri, "" + scheme_host, _, rest = uri.partition("://") + host, _, path = rest.partition("/") + return f"{scheme_host}://{host}", "/" + path if path else "" + + +@dataclass +class MLflowClient: + tracking_uri: str + username: str | None = None + password: str | None = None + host_header: str | None = None # override for DNS-rebinding sidestep + timeout: float = 30.0 + + def __post_init__(self) -> None: + self._origin, self._ui_prefix = _strip_path(self.tracking_uri) + # MLflow 3.x exposes the REST API at the root, *not* under the + # ``/mlflow`` UI prefix. Empirically verified against the running + # ghcr.io/mlflow/mlflow:v3.11.1 container. + self._api = f"{self._origin}/api/2.0/mlflow" + self._auth = (self.username, self.password) if self.username else None + # If user did not pass a host header, derive from origin. Strip + # the port if present — the server's allowed-hosts check rejects + # ``localhost:5000`` even when ``localhost`` is allowed. + if self.host_header is None: + host = self._origin.split("://", 1)[1] + self.host_header = host.split(":", 1)[0] + + @classmethod + def from_env(cls) -> "MLflowClient": + return cls( + tracking_uri=os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5000"), + username=os.environ.get("MLFLOW_TRACKING_USERNAME") or "admin", + password=os.environ.get("MLFLOW_TRACKING_PASSWORD") or "password", + host_header=os.environ.get("MLFLOW_HOST_HEADER"), + ) + + def _headers(self) -> dict[str, str]: + return {"Host": self.host_header or "localhost"} + + def _post(self, path: str, body: dict) -> dict: + with httpx.Client(trust_env=False, timeout=self.timeout) as c: + r = c.post(f"{self._api}{path}", json=body, headers=self._headers(), auth=self._auth) + r.raise_for_status() + return r.json() + + def _get(self, path: str, params: dict | None = None) -> dict: + with httpx.Client(trust_env=False, timeout=self.timeout) as c: + r = c.get(f"{self._api}{path}", params=params or {}, headers=self._headers(), auth=self._auth) + r.raise_for_status() + return r.json() + + # ── Experiments ──────────────────────────────────────────────────── + + def get_or_create_experiment(self, name: str) -> str: + try: + r = self._get("/experiments/get-by-name", {"experiment_name": name}) + return r["experiment"]["experiment_id"] + except httpx.HTTPStatusError as e: + if e.response.status_code not in (404, 400): + raise + r = self._post("/experiments/create", {"name": name}) + return r["experiment_id"] + + # ── Runs ─────────────────────────────────────────────────────────── + + def create_run( + self, + experiment_id: str, + run_name: str, + tags: dict[str, str] | None = None, + ) -> str: + body: dict[str, Any] = { + "experiment_id": experiment_id, + "start_time": int(time.time() * 1000), + "run_name": run_name, + "tags": [ + {"key": k, "value": str(v)} + for k, v in (tags or {}).items() + ], + } + r = self._post("/runs/create", body) + return r["run"]["info"]["run_id"] + + def log_param(self, run_id: str, key: str, value: Any) -> None: + self._post("/runs/log-parameter", {"run_id": run_id, "key": key, "value": str(value)}) + + def log_params(self, run_id: str, params: dict[str, Any]) -> None: + for k, v in params.items(): + self.log_param(run_id, k, v) + + def log_metric(self, run_id: str, key: str, value: float, step: int = 0) -> None: + self._post("/runs/log-metric", { + "run_id": run_id, + "key": key, + "value": float(value), + "timestamp": int(time.time() * 1000), + "step": step, + }) + + def log_metrics(self, run_id: str, metrics: dict[str, float]) -> None: + for k, v in metrics.items(): + self.log_metric(run_id, k, v) + + def set_tag(self, run_id: str, key: str, value: str) -> None: + self._post("/runs/set-tag", {"run_id": run_id, "key": key, "value": str(value)}) + + def set_tags(self, run_id: str, tags: dict[str, str]) -> None: + for k, v in tags.items(): + self.set_tag(run_id, k, v) + + # MLflow tag values are capped at 5000 chars by the server (RESOURCE_DOES_NOT_EXIST + # below that, INVALID_PARAMETER_VALUE above). 4500 leaves headroom for + # internal metadata MLflow may append on its own. + _TAG_VALUE_LIMIT = 4500 + + def log_text(self, run_id: str, text: str, artifact_path: str) -> None: + """Persist short text alongside the run. + + The MLflow server in this deployment uses a ``file://`` artifact + backend, which is only reachable from inside the container — not + via the REST proxy. We instead stash short payloads as tags + keyed ``artifact:``. Anything longer than 4500 chars is + chunked into ``artifact::0``, ``:1`` …; ``get_artifact_text`` + re-stitches them in order. + """ + key_base = f"artifact:{artifact_path}" + if len(text) <= self._TAG_VALUE_LIMIT: + self.set_tag(run_id, key_base, text) + return + # chunk + for i in range(0, len(text), self._TAG_VALUE_LIMIT): + self.set_tag(run_id, f"{key_base}:{i // self._TAG_VALUE_LIMIT}", + text[i:i + self._TAG_VALUE_LIMIT]) + + def get_artifact_text(self, run_id: str, artifact_path: str) -> str: + run = self._get("/runs/get", {"run_id": run_id})["run"] + tags = {t["key"]: t["value"] for t in run["data"].get("tags", [])} + key_base = f"artifact:{artifact_path}" + if key_base in tags: + return tags[key_base] + # chunked form + chunks = sorted( + (k for k in tags if k.startswith(f"{key_base}:")), + key=lambda k: int(k.rsplit(":", 1)[1]), + ) + return "".join(tags[k] for k in chunks) + + def end_run(self, run_id: str, status: str = "FINISHED") -> None: + self._post("/runs/update", { + "run_id": run_id, + "status": status, + "end_time": int(time.time() * 1000), + }) + + def search_runs( + self, + experiment_id: str, + filter_string: str = "", + max_results: int = 1000, + ) -> list[dict]: + body = { + "experiment_ids": [experiment_id], + "filter": filter_string, + "max_results": max_results, + } + r = self._post("/runs/search", body) + return r.get("runs", [])