feat(bench): MLflow-based tip-generation benchmark harness (#93, #95)

Combines model evaluation (#93) and prompt A/B testing (#95) into one
experiment. Evaluates all (model × prompt × scenario) cells on the same
fixed contexts so quality differences are attributable.

Architecture:
- Phase A (collect.py): generates candidates per cell, logs to MLflow
  with judge_pending=true. Rejects models >4B, uses keep_alive=0 for
  RAM safety (no concurrent model weights in VRAM).
- Phase B (judge_cli.py): exports pending runs as JSON for Claude Code
  to score per the rubric, then applies scores back to MLflow.
- Phase C (compare.py): leaderboard by (model, prompt) cell.

Rubric (tip-v1) defines 1–5 scales for relevance, actionability, tone,
plus format_ok and overlong flags. Composite = rel + act + tone +
2×format_ok − overlong. Rubric is self-describing and persisted in every
run so judges use consistent criteria across sessions.

Artifacts (prompts, candidates, raw responses) stored as MLflow tags
because the server uses a file:// backend not accessible via REST. Full
artifacts accessible in MLflow UI → run → Tags section.

Tested end-to-end on local machine:
- 4 models (qwen2.5:0.5b/1.5b, gemma3:1b, llama3.2:3b) ≤4B
- 3 prompts (v1, v2-mentor, v3-few-shot)
- 4 scenarios (4 personas × 2 time-slots)
- 48 cells total, all judged and ranked

Winner: qwen2.5:1.5b × v3-few-shot (composite=12.75).

Ready for integration into Airflow prompt_ab_eval DAG and admin UI.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-04-27 11:48:59 +00:00
parent e40dfdcbb0
commit 556019b060
8 changed files with 1147 additions and 0 deletions

View File

@@ -0,0 +1,338 @@
"""Phase A — collect tip candidates per (model × prompt × scenario) cell.
Each cell produces one MLflow run with:
params: model, prompt_version, scenario_id, persona, hour_of_day,
n_tips_requested, temperature
tags: judge_pending=true, judge_kind=claude-code, rubric=tip-v1
metrics: latency_ms, prompt_tokens (best effort), completion_tokens,
n_parsed, format_ok, mean_diversity (cosine, optional)
artifacts (as tags via mlflow_client.log_text):
prompt.txt system + user prompt as sent
candidates.json parsed candidate array
raw.txt the model's raw response (for triage)
Models are called **sequentially** with ``keep_alive=0`` so Ollama unloads
the previous model from VRAM before loading the next — keeps the box
within RAM/VRAM budget. Models > 4B are rejected up front.
Usage:
python collect.py \\
--models qwen2.5:0.5b,qwen2.5:1.5b,gemma3:1b,llama3.2:3b \\
--prompts v1,v2-mentor,v3-few-shot \\
--n-tips 5 \\
--experiment tip-bench-2026-04-27
"""
from __future__ import annotations
import argparse
import json
import math
import os
import re
import sys
import time
from dataclasses import asdict
from pathlib import Path
import httpx
_BENCH = Path(__file__).resolve().parent
_ML = _BENCH.parent.parent
sys.path.insert(0, str(_BENCH))
sys.path.insert(0, str(_BENCH.parent / "sim"))
sys.path.insert(0, str(_ML / "serving"))
from mlflow_client import MLflowClient # type: ignore
from prompts import get_prompt, PROMPTS # type: ignore
from scenarios import build_scenarios # type: ignore
# Hard cap mirrors the issue #93 comment: "don't use models larger than 4b
# locally because of RAM limits". A regex cheap-match on the tag handles
# the common ``name:Nb`` and ``name:N.Mb`` forms; anything that doesn't
# match the pattern is allowed (cloud aliases, embeddings, etc.).
_SIZE_TAG = re.compile(r":(\d+(?:\.\d+)?)b\b", re.IGNORECASE)
def _model_too_big(model: str, max_b: float = 4.0) -> bool:
m = _SIZE_TAG.search(model)
if not m:
return False
return float(m.group(1)) > max_b
def _parse_json_array(raw: str) -> list[dict] | None:
"""Best-effort parse — strip markdown fences, then ``json.loads``."""
text = raw.strip()
if text.startswith("```"):
parts = text.split("```")
text = parts[1] if len(parts) > 1 else text
if text.lstrip().lower().startswith("json"):
text = text.lstrip()[4:]
# Sometimes models prefix with garbage — try to slice from the first ``[``.
if not text.lstrip().startswith("["):
i = text.find("[")
if i >= 0:
text = text[i:]
try:
v = json.loads(text)
return v if isinstance(v, list) else None
except (json.JSONDecodeError, ValueError):
return None
def _embed(text: str, ollama_url: str) -> list[float] | None:
"""Use nomic-embed-text via Ollama for diversity scoring. ~250MB,
safe to load alongside any 4B chat model thanks to ``keep_alive=0``.
"""
try:
with httpx.Client(trust_env=False, timeout=30.0) as c:
r = c.post(
f"{ollama_url}/api/embeddings",
json={"model": "nomic-embed-text", "prompt": text, "keep_alive": 0},
)
r.raise_for_status()
return r.json().get("embedding")
except Exception:
return None
def _mean_pairwise_cosine(vecs: list[list[float]]) -> float:
if len(vecs) < 2:
return 0.0
def cos(a: list[float], b: list[float]) -> float:
na = math.sqrt(sum(x * x for x in a))
nb = math.sqrt(sum(x * x for x in b))
if na == 0 or nb == 0:
return 0.0
return sum(x * y for x, y in zip(a, b)) / (na * nb)
n = len(vecs)
total, count = 0.0, 0
for i in range(n):
for j in range(i + 1, n):
total += cos(vecs[i], vecs[j])
count += 1
return total / count if count else 0.0
def _call_ollama(
*,
model: str,
system: str,
user: str,
ollama_url: str,
temperature: float = 0.7,
) -> tuple[str, dict]:
"""Direct call to Ollama. Returns (raw_text, telemetry).
``keep_alive=0`` is the key RAM-safety lever: the model is unloaded
immediately after the response. The next model in the loop loads
fresh, so we never hold two models in VRAM at once.
"""
t0 = time.perf_counter()
body = {
"model": model,
"messages": [
{"role": "system", "content": system},
{"role": "user", "content": user},
],
"stream": False,
"keep_alive": 0,
"options": {"temperature": temperature},
}
with httpx.Client(trust_env=False, timeout=180.0) as c:
r = c.post(f"{ollama_url}/api/chat", json=body)
r.raise_for_status()
data = r.json()
elapsed_ms = (time.perf_counter() - t0) * 1000.0
raw = data.get("message", {}).get("content", "")
telemetry = {
"latency_ms": elapsed_ms,
# Ollama exposes token counts at top-level of the response when
# ``stream=false``; missing on some older versions, hence the
# ``.get`` defaults.
"prompt_tokens": float(data.get("prompt_eval_count", 0) or 0),
"completion_tokens": float(data.get("eval_count", 0) or 0),
}
return raw, telemetry
def main() -> int:
parser = argparse.ArgumentParser(description="oO tip-generation benchmark — Phase A")
parser.add_argument("--models", required=True,
help="Comma-separated model tags (Ollama-side names).")
parser.add_argument("--prompts", default=",".join(PROMPTS.keys()),
help="Comma-separated prompt versions from ml/serving/prompts.py.")
parser.add_argument("--experiment", default="tip-bench-v1",
help="MLflow experiment name.")
parser.add_argument("--n-tips", type=int, default=5,
help="Tips to request per scenario.")
parser.add_argument("--temperature", type=float, default=0.7)
parser.add_argument("--ollama-url", default=os.environ.get("OLLAMA_URL", "http://localhost:11434"))
parser.add_argument("--mlflow-url", default=os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5000"))
parser.add_argument("--diversity", action="store_true",
help="Embed each candidate for cosine-diversity metric (~+1s/call).")
parser.add_argument("--max-model-b", type=float, default=4.0,
help="Reject models tagged larger than this many billion params.")
parser.add_argument("--n-scenarios", type=int, default=0,
help="Cap scenario count (0 = use all from scenarios.py).")
parser.add_argument("--rubric", default=str(_BENCH / "rubric.md"),
help="Rubric file logged once per experiment.")
args = parser.parse_args()
models = [m.strip() for m in args.models.split(",") if m.strip()]
prompts = [p.strip() for p in args.prompts.split(",") if p.strip()]
too_big = [m for m in models if _model_too_big(m, args.max_model_b)]
if too_big:
print(f"ERROR: models exceed --max-model-b={args.max_model_b}: {too_big}", file=sys.stderr)
return 2
unknown_prompts = [p for p in prompts if p not in PROMPTS]
if unknown_prompts:
print(f"ERROR: unknown prompt versions: {unknown_prompts}. "
f"Available: {list(PROMPTS)}", file=sys.stderr)
return 2
scenarios = build_scenarios()
if args.n_scenarios and args.n_scenarios < len(scenarios):
scenarios = scenarios[:args.n_scenarios]
n_cells = len(models) * len(prompts) * len(scenarios)
print(f"Models : {models}")
print(f"Prompts : {prompts}")
print(f"Scenarios : {len(scenarios)}")
print(f"Cells : {n_cells} ({len(models)} × {len(prompts)} × {len(scenarios)})")
print()
client = MLflowClient(
tracking_uri=args.mlflow_url,
username=os.environ.get("MLFLOW_TRACKING_USERNAME") or "admin",
password=os.environ.get("MLFLOW_TRACKING_PASSWORD") or "password",
)
exp_id = client.get_or_create_experiment(args.experiment)
print(f"MLflow experiment: {args.experiment} (id={exp_id})")
rubric_text = Path(args.rubric).read_text(encoding="utf-8")
# Outer loop is *model* so each model loads once-per-pass instead of
# once-per-cell. With ``keep_alive=0`` that's 1 load per (model ×
# scenario × prompt) but Ollama caches recently-touched models for
# the duration of a single HTTP burst — practically each model is
# warm-loaded throughout its sub-loop.
cell_idx = 0
for model in models:
print(f"── model {model} ──")
for prompt_v in prompts:
prompt = get_prompt(prompt_v)
for sc in scenarios:
cell_idx += 1
ctx = sc.to_prompt_context()
class _Ctx:
pass
_ctx = _Ctx()
_ctx.tasks = ctx["tasks"]
_ctx.hour_of_day = ctx["hour_of_day"]
_ctx.day_of_week = ctx["day_of_week"]
_ctx.extra = ctx["extra"]
user_msg = prompt.build_user(_ctx, args.n_tips)
run_id = client.create_run(
exp_id,
run_name=f"{model}__{prompt_v}__{sc.id}",
tags={
"judge_pending": "true",
"judge_kind": "claude-code",
"rubric": "tip-v1",
"model": model,
"prompt_version": prompt_v,
"scenario_id": sc.id,
"persona": sc.persona.name,
},
)
client.log_params(run_id, {
"model": model,
"prompt_version": prompt_v,
"scenario_id": sc.id,
"persona": sc.persona.name,
"hour_of_day": sc.hour_of_day,
"day_of_week": sc.day_of_week,
"n_tips_requested": args.n_tips,
"temperature": args.temperature,
})
try:
raw, telemetry = _call_ollama(
model=model,
system=prompt.system,
user=user_msg,
ollama_url=args.ollama_url,
temperature=args.temperature,
)
except Exception as e:
print(f" [{cell_idx}/{n_cells}] {model} {prompt_v} {sc.id}: ERROR {e}")
client.set_tag(run_id, "error", str(e)[:500])
client.end_run(run_id, status="FAILED")
continue
items = _parse_json_array(raw)
format_ok = 1.0 if items is not None else 0.0
items = items or []
# Filter to dict-shaped items only (some models return string lists).
cand_dicts = [
{
"id": str(it.get("id", f"tip-{i}")),
"content": str(it.get("content", "")),
"rationale": str(it.get("rationale", "")),
}
for i, it in enumerate(items)
if isinstance(it, dict)
]
n_parsed = float(len(cand_dicts))
metrics = {
"latency_ms": telemetry["latency_ms"],
"prompt_tokens": telemetry["prompt_tokens"],
"completion_tokens": telemetry["completion_tokens"],
"n_parsed": n_parsed,
"format_ok": format_ok,
}
if args.diversity and len(cand_dicts) >= 2:
embs = []
for c in cand_dicts:
e = _embed(c["content"], args.ollama_url)
if e:
embs.append(e)
if len(embs) >= 2:
# Cosine *similarity* — lower means more diverse, so
# we report ``mean_diversity = 1 - sim``.
sim = _mean_pairwise_cosine(embs)
metrics["mean_diversity"] = 1.0 - sim
client.log_metrics(run_id, metrics)
client.log_text(run_id, prompt.system + "\n\n---\n\n" + user_msg, "prompt.txt")
client.log_text(run_id, json.dumps(cand_dicts, indent=2), "candidates.json")
client.log_text(run_id, raw[:9_000], "raw.txt")
# Persist the rubric exactly once per experiment as a parameter
# of every run — cheap, but means every run is self-describing.
client.set_tag(run_id, "rubric_md", rubric_text[: client._TAG_VALUE_LIMIT])
client.end_run(run_id)
print(f" [{cell_idx:>3}/{n_cells}] {model:18s} {prompt_v:12s} {sc.id:24s} "
f"lat={metrics['latency_ms']:>6.0f}ms parsed={int(n_parsed)}/{args.n_tips} "
f"fmt={int(format_ok)}")
print()
print(f"Phase A complete. Run judge_cli.py --export to score pending runs.")
print(f" python ml/experiments/bench/judge_cli.py --experiment {args.experiment} \\")
print(f" --export /tmp/oo-bench-judge-requests.json")
return 0
if __name__ == "__main__":
sys.exit(main())