Combines model evaluation (#93) and prompt A/B testing (#95) into one experiment. Evaluates all (model × prompt × scenario) cells on the same fixed contexts so quality differences are attributable. Architecture: - Phase A (collect.py): generates candidates per cell, logs to MLflow with judge_pending=true. Rejects models >4B, uses keep_alive=0 for RAM safety (no concurrent model weights in VRAM). - Phase B (judge_cli.py): exports pending runs as JSON for Claude Code to score per the rubric, then applies scores back to MLflow. - Phase C (compare.py): leaderboard by (model, prompt) cell. Rubric (tip-v1) defines 1–5 scales for relevance, actionability, tone, plus format_ok and overlong flags. Composite = rel + act + tone + 2×format_ok − overlong. Rubric is self-describing and persisted in every run so judges use consistent criteria across sessions. Artifacts (prompts, candidates, raw responses) stored as MLflow tags because the server uses a file:// backend not accessible via REST. Full artifacts accessible in MLflow UI → run → Tags section. Tested end-to-end on local machine: - 4 models (qwen2.5:0.5b/1.5b, gemma3:1b, llama3.2:3b) ≤4B - 3 prompts (v1, v2-mentor, v3-few-shot) - 4 scenarios (4 personas × 2 time-slots) - 48 cells total, all judged and ranked Winner: qwen2.5:1.5b × v3-few-shot (composite=12.75). Ready for integration into Airflow prompt_ab_eval DAG and admin UI. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
202
ml/experiments/bench/mlflow_client.py
Normal file
202
ml/experiments/bench/mlflow_client.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""Thin MLflow REST wrapper.
|
||||
|
||||
Why not the official ``mlflow`` SDK? Two reasons specific to the oO setup:
|
||||
|
||||
1. The MLflow server (3.11) ships with ``--allowed-hosts localhost`` but
|
||||
curl / requests / urllib3 send ``Host: localhost:5000`` — the port
|
||||
suffix fails the DNS-rebinding check. We override the Host header per
|
||||
request, which the SDK doesn't expose.
|
||||
2. The collect/judge phases only need ~6 endpoints (create/search/log).
|
||||
Pulling a 200MB SDK transitively for that is excess weight.
|
||||
|
||||
All calls are synchronous httpx with explicit ``Host`` so the script can
|
||||
run from the host shell, from inside docker, or from Airflow workers
|
||||
without further config.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
def _strip_path(uri: str) -> tuple[str, str]:
|
||||
"""Return (origin, path_prefix) — handles both /mlflow and / roots.
|
||||
|
||||
``http://mlflow:5000/mlflow`` → ("http://mlflow:5000", "/mlflow")
|
||||
``http://localhost:5000`` → ("http://localhost:5000", "")
|
||||
"""
|
||||
uri = uri.rstrip("/")
|
||||
if "/" not in uri.split("://", 1)[1]:
|
||||
return uri, ""
|
||||
scheme_host, _, rest = uri.partition("://")
|
||||
host, _, path = rest.partition("/")
|
||||
return f"{scheme_host}://{host}", "/" + path if path else ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class MLflowClient:
|
||||
tracking_uri: str
|
||||
username: str | None = None
|
||||
password: str | None = None
|
||||
host_header: str | None = None # override for DNS-rebinding sidestep
|
||||
timeout: float = 30.0
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self._origin, self._ui_prefix = _strip_path(self.tracking_uri)
|
||||
# MLflow 3.x exposes the REST API at the root, *not* under the
|
||||
# ``/mlflow`` UI prefix. Empirically verified against the running
|
||||
# ghcr.io/mlflow/mlflow:v3.11.1 container.
|
||||
self._api = f"{self._origin}/api/2.0/mlflow"
|
||||
self._auth = (self.username, self.password) if self.username else None
|
||||
# If user did not pass a host header, derive from origin. Strip
|
||||
# the port if present — the server's allowed-hosts check rejects
|
||||
# ``localhost:5000`` even when ``localhost`` is allowed.
|
||||
if self.host_header is None:
|
||||
host = self._origin.split("://", 1)[1]
|
||||
self.host_header = host.split(":", 1)[0]
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> "MLflowClient":
|
||||
return cls(
|
||||
tracking_uri=os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5000"),
|
||||
username=os.environ.get("MLFLOW_TRACKING_USERNAME") or "admin",
|
||||
password=os.environ.get("MLFLOW_TRACKING_PASSWORD") or "password",
|
||||
host_header=os.environ.get("MLFLOW_HOST_HEADER"),
|
||||
)
|
||||
|
||||
def _headers(self) -> dict[str, str]:
|
||||
return {"Host": self.host_header or "localhost"}
|
||||
|
||||
def _post(self, path: str, body: dict) -> dict:
|
||||
with httpx.Client(trust_env=False, timeout=self.timeout) as c:
|
||||
r = c.post(f"{self._api}{path}", json=body, headers=self._headers(), auth=self._auth)
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
|
||||
def _get(self, path: str, params: dict | None = None) -> dict:
|
||||
with httpx.Client(trust_env=False, timeout=self.timeout) as c:
|
||||
r = c.get(f"{self._api}{path}", params=params or {}, headers=self._headers(), auth=self._auth)
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
|
||||
# ── Experiments ────────────────────────────────────────────────────
|
||||
|
||||
def get_or_create_experiment(self, name: str) -> str:
|
||||
try:
|
||||
r = self._get("/experiments/get-by-name", {"experiment_name": name})
|
||||
return r["experiment"]["experiment_id"]
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code not in (404, 400):
|
||||
raise
|
||||
r = self._post("/experiments/create", {"name": name})
|
||||
return r["experiment_id"]
|
||||
|
||||
# ── Runs ───────────────────────────────────────────────────────────
|
||||
|
||||
def create_run(
|
||||
self,
|
||||
experiment_id: str,
|
||||
run_name: str,
|
||||
tags: dict[str, str] | None = None,
|
||||
) -> str:
|
||||
body: dict[str, Any] = {
|
||||
"experiment_id": experiment_id,
|
||||
"start_time": int(time.time() * 1000),
|
||||
"run_name": run_name,
|
||||
"tags": [
|
||||
{"key": k, "value": str(v)}
|
||||
for k, v in (tags or {}).items()
|
||||
],
|
||||
}
|
||||
r = self._post("/runs/create", body)
|
||||
return r["run"]["info"]["run_id"]
|
||||
|
||||
def log_param(self, run_id: str, key: str, value: Any) -> None:
|
||||
self._post("/runs/log-parameter", {"run_id": run_id, "key": key, "value": str(value)})
|
||||
|
||||
def log_params(self, run_id: str, params: dict[str, Any]) -> None:
|
||||
for k, v in params.items():
|
||||
self.log_param(run_id, k, v)
|
||||
|
||||
def log_metric(self, run_id: str, key: str, value: float, step: int = 0) -> None:
|
||||
self._post("/runs/log-metric", {
|
||||
"run_id": run_id,
|
||||
"key": key,
|
||||
"value": float(value),
|
||||
"timestamp": int(time.time() * 1000),
|
||||
"step": step,
|
||||
})
|
||||
|
||||
def log_metrics(self, run_id: str, metrics: dict[str, float]) -> None:
|
||||
for k, v in metrics.items():
|
||||
self.log_metric(run_id, k, v)
|
||||
|
||||
def set_tag(self, run_id: str, key: str, value: str) -> None:
|
||||
self._post("/runs/set-tag", {"run_id": run_id, "key": key, "value": str(value)})
|
||||
|
||||
def set_tags(self, run_id: str, tags: dict[str, str]) -> None:
|
||||
for k, v in tags.items():
|
||||
self.set_tag(run_id, k, v)
|
||||
|
||||
# MLflow tag values are capped at 5000 chars by the server (RESOURCE_DOES_NOT_EXIST
|
||||
# below that, INVALID_PARAMETER_VALUE above). 4500 leaves headroom for
|
||||
# internal metadata MLflow may append on its own.
|
||||
_TAG_VALUE_LIMIT = 4500
|
||||
|
||||
def log_text(self, run_id: str, text: str, artifact_path: str) -> None:
|
||||
"""Persist short text alongside the run.
|
||||
|
||||
The MLflow server in this deployment uses a ``file://`` artifact
|
||||
backend, which is only reachable from inside the container — not
|
||||
via the REST proxy. We instead stash short payloads as tags
|
||||
keyed ``artifact:<path>``. Anything longer than 4500 chars is
|
||||
chunked into ``artifact:<path>:0``, ``:1`` …; ``get_artifact_text``
|
||||
re-stitches them in order.
|
||||
"""
|
||||
key_base = f"artifact:{artifact_path}"
|
||||
if len(text) <= self._TAG_VALUE_LIMIT:
|
||||
self.set_tag(run_id, key_base, text)
|
||||
return
|
||||
# chunk
|
||||
for i in range(0, len(text), self._TAG_VALUE_LIMIT):
|
||||
self.set_tag(run_id, f"{key_base}:{i // self._TAG_VALUE_LIMIT}",
|
||||
text[i:i + self._TAG_VALUE_LIMIT])
|
||||
|
||||
def get_artifact_text(self, run_id: str, artifact_path: str) -> str:
|
||||
run = self._get("/runs/get", {"run_id": run_id})["run"]
|
||||
tags = {t["key"]: t["value"] for t in run["data"].get("tags", [])}
|
||||
key_base = f"artifact:{artifact_path}"
|
||||
if key_base in tags:
|
||||
return tags[key_base]
|
||||
# chunked form
|
||||
chunks = sorted(
|
||||
(k for k in tags if k.startswith(f"{key_base}:")),
|
||||
key=lambda k: int(k.rsplit(":", 1)[1]),
|
||||
)
|
||||
return "".join(tags[k] for k in chunks)
|
||||
|
||||
def end_run(self, run_id: str, status: str = "FINISHED") -> None:
|
||||
self._post("/runs/update", {
|
||||
"run_id": run_id,
|
||||
"status": status,
|
||||
"end_time": int(time.time() * 1000),
|
||||
})
|
||||
|
||||
def search_runs(
|
||||
self,
|
||||
experiment_id: str,
|
||||
filter_string: str = "",
|
||||
max_results: int = 1000,
|
||||
) -> list[dict]:
|
||||
body = {
|
||||
"experiment_ids": [experiment_id],
|
||||
"filter": filter_string,
|
||||
"max_results": max_results,
|
||||
}
|
||||
r = self._post("/runs/search", body)
|
||||
return r.get("runs", [])
|
||||
Reference in New Issue
Block a user