"""Thin MLflow REST wrapper. Why not the official ``mlflow`` SDK? Two reasons specific to the oO setup: 1. The MLflow server (3.11) ships with ``--allowed-hosts localhost`` but curl / requests / urllib3 send ``Host: localhost:5000`` — the port suffix fails the DNS-rebinding check. We override the Host header per request, which the SDK doesn't expose. 2. The collect/judge phases only need ~6 endpoints (create/search/log). Pulling a 200MB SDK transitively for that is excess weight. All calls are synchronous httpx with explicit ``Host`` so the script can run from the host shell, from inside docker, or from Airflow workers without further config. """ from __future__ import annotations import os import time from dataclasses import dataclass from typing import Any import httpx def _strip_path(uri: str) -> tuple[str, str]: """Return (origin, path_prefix) — handles both /mlflow and / roots. ``http://mlflow:5000/mlflow`` → ("http://mlflow:5000", "/mlflow") ``http://localhost:5000`` → ("http://localhost:5000", "") """ uri = uri.rstrip("/") if "/" not in uri.split("://", 1)[1]: return uri, "" scheme_host, _, rest = uri.partition("://") host, _, path = rest.partition("/") return f"{scheme_host}://{host}", "/" + path if path else "" @dataclass class MLflowClient: tracking_uri: str username: str | None = None password: str | None = None host_header: str | None = None # override for DNS-rebinding sidestep timeout: float = 30.0 def __post_init__(self) -> None: self._origin, self._ui_prefix = _strip_path(self.tracking_uri) # MLflow 3.x exposes the REST API at the root, *not* under the # ``/mlflow`` UI prefix. Empirically verified against the running # ghcr.io/mlflow/mlflow:v3.11.1 container. self._api = f"{self._origin}/api/2.0/mlflow" self._auth = (self.username, self.password) if self.username else None # If user did not pass a host header, derive from origin. Strip # the port if present — the server's allowed-hosts check rejects # ``localhost:5000`` even when ``localhost`` is allowed. if self.host_header is None: host = self._origin.split("://", 1)[1] self.host_header = host.split(":", 1)[0] @classmethod def from_env(cls) -> "MLflowClient": return cls( tracking_uri=os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5000"), username=os.environ.get("MLFLOW_TRACKING_USERNAME") or "admin", password=os.environ.get("MLFLOW_TRACKING_PASSWORD") or "password", host_header=os.environ.get("MLFLOW_HOST_HEADER"), ) def _headers(self) -> dict[str, str]: return {"Host": self.host_header or "localhost"} def _post(self, path: str, body: dict) -> dict: with httpx.Client(trust_env=False, timeout=self.timeout) as c: r = c.post(f"{self._api}{path}", json=body, headers=self._headers(), auth=self._auth) r.raise_for_status() return r.json() def _get(self, path: str, params: dict | None = None) -> dict: with httpx.Client(trust_env=False, timeout=self.timeout) as c: r = c.get(f"{self._api}{path}", params=params or {}, headers=self._headers(), auth=self._auth) r.raise_for_status() return r.json() # ── Experiments ──────────────────────────────────────────────────── def get_or_create_experiment(self, name: str) -> str: try: r = self._get("/experiments/get-by-name", {"experiment_name": name}) return r["experiment"]["experiment_id"] except httpx.HTTPStatusError as e: if e.response.status_code not in (404, 400): raise r = self._post("/experiments/create", {"name": name}) return r["experiment_id"] # ── Runs ─────────────────────────────────────────────────────────── def create_run( self, experiment_id: str, run_name: str, tags: dict[str, str] | None = None, ) -> str: body: dict[str, Any] = { "experiment_id": experiment_id, "start_time": int(time.time() * 1000), "run_name": run_name, "tags": [ {"key": k, "value": str(v)} for k, v in (tags or {}).items() ], } r = self._post("/runs/create", body) return r["run"]["info"]["run_id"] def log_param(self, run_id: str, key: str, value: Any) -> None: self._post("/runs/log-parameter", {"run_id": run_id, "key": key, "value": str(value)}) def log_params(self, run_id: str, params: dict[str, Any]) -> None: for k, v in params.items(): self.log_param(run_id, k, v) def log_metric(self, run_id: str, key: str, value: float, step: int = 0) -> None: self._post("/runs/log-metric", { "run_id": run_id, "key": key, "value": float(value), "timestamp": int(time.time() * 1000), "step": step, }) def log_metrics(self, run_id: str, metrics: dict[str, float]) -> None: for k, v in metrics.items(): self.log_metric(run_id, k, v) def set_tag(self, run_id: str, key: str, value: str) -> None: self._post("/runs/set-tag", {"run_id": run_id, "key": key, "value": str(value)}) def set_tags(self, run_id: str, tags: dict[str, str]) -> None: for k, v in tags.items(): self.set_tag(run_id, k, v) # MLflow tag values are capped at 5000 chars by the server (RESOURCE_DOES_NOT_EXIST # below that, INVALID_PARAMETER_VALUE above). 4500 leaves headroom for # internal metadata MLflow may append on its own. _TAG_VALUE_LIMIT = 4500 def log_text(self, run_id: str, text: str, artifact_path: str) -> None: """Persist short text alongside the run. The MLflow server in this deployment uses a ``file://`` artifact backend, which is only reachable from inside the container — not via the REST proxy. We instead stash short payloads as tags keyed ``artifact:``. Anything longer than 4500 chars is chunked into ``artifact::0``, ``:1`` …; ``get_artifact_text`` re-stitches them in order. """ key_base = f"artifact:{artifact_path}" if len(text) <= self._TAG_VALUE_LIMIT: self.set_tag(run_id, key_base, text) return # chunk for i in range(0, len(text), self._TAG_VALUE_LIMIT): self.set_tag(run_id, f"{key_base}:{i // self._TAG_VALUE_LIMIT}", text[i:i + self._TAG_VALUE_LIMIT]) def get_artifact_text(self, run_id: str, artifact_path: str) -> str: run = self._get("/runs/get", {"run_id": run_id})["run"] tags = {t["key"]: t["value"] for t in run["data"].get("tags", [])} key_base = f"artifact:{artifact_path}" if key_base in tags: return tags[key_base] # chunked form chunks = sorted( (k for k in tags if k.startswith(f"{key_base}:")), key=lambda k: int(k.rsplit(":", 1)[1]), ) return "".join(tags[k] for k in chunks) def end_run(self, run_id: str, status: str = "FINISHED") -> None: self._post("/runs/update", { "run_id": run_id, "status": status, "end_time": int(time.time() * 1000), }) def search_runs( self, experiment_id: str, filter_string: str = "", max_results: int = 1000, ) -> list[dict]: body = { "experiment_ids": [experiment_id], "filter": filter_string, "max_results": max_results, } r = self._post("/runs/search", body) return r.get("runs", [])