Files
oO/ml/experiments/bench/collect.py
alvis 556019b060 feat(bench): MLflow-based tip-generation benchmark harness (#93, #95)
Combines model evaluation (#93) and prompt A/B testing (#95) into one
experiment. Evaluates all (model × prompt × scenario) cells on the same
fixed contexts so quality differences are attributable.

Architecture:
- Phase A (collect.py): generates candidates per cell, logs to MLflow
  with judge_pending=true. Rejects models >4B, uses keep_alive=0 for
  RAM safety (no concurrent model weights in VRAM).
- Phase B (judge_cli.py): exports pending runs as JSON for Claude Code
  to score per the rubric, then applies scores back to MLflow.
- Phase C (compare.py): leaderboard by (model, prompt) cell.

Rubric (tip-v1) defines 1–5 scales for relevance, actionability, tone,
plus format_ok and overlong flags. Composite = rel + act + tone +
2×format_ok − overlong. Rubric is self-describing and persisted in every
run so judges use consistent criteria across sessions.

Artifacts (prompts, candidates, raw responses) stored as MLflow tags
because the server uses a file:// backend not accessible via REST. Full
artifacts accessible in MLflow UI → run → Tags section.

Tested end-to-end on local machine:
- 4 models (qwen2.5:0.5b/1.5b, gemma3:1b, llama3.2:3b) ≤4B
- 3 prompts (v1, v2-mentor, v3-few-shot)
- 4 scenarios (4 personas × 2 time-slots)
- 48 cells total, all judged and ranked

Winner: qwen2.5:1.5b × v3-few-shot (composite=12.75).

Ready for integration into Airflow prompt_ab_eval DAG and admin UI.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-04-27 11:48:59 +00:00

339 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Phase A — collect tip candidates per (model × prompt × scenario) cell.
Each cell produces one MLflow run with:
params: model, prompt_version, scenario_id, persona, hour_of_day,
day_of_week, n_tips_requested, temperature
tags: judge_pending=true, judge_kind=claude-code, rubric=tip-v1
metrics: latency_ms, prompt_tokens (best effort), completion_tokens,
n_parsed, format_ok, mean_diversity (cosine, optional)
artifacts (as tags via mlflow_client.log_text):
prompt.txt system + user prompt as sent
candidates.json parsed candidate array
raw.txt the model's raw response (for triage)
Models are called **sequentially** with ``keep_alive=0`` so Ollama unloads
the previous model from VRAM before loading the next — keeps the box
within RAM/VRAM budget. Models > 4B are rejected up front.
Usage:
python collect.py \\
--models qwen2.5:0.5b,qwen2.5:1.5b,gemma3:1b,llama3.2:3b \\
--prompts v1,v2-mentor,v3-few-shot \\
--n-tips 5 \\
--experiment tip-bench-2026-04-27
"""
from __future__ import annotations
import argparse
import json
import math
import os
import re
import sys
import time
from dataclasses import asdict
from pathlib import Path
import httpx
# Directory that contains this script.
_BENCH = Path(__file__).resolve().parent
# Two levels above the script directory; the "serving" directory is looked
# up beneath it.
_ML = _BENCH.parent.parent
# Extend the import search path so the helper modules imported just below
# resolve when this file is run directly as a script.  Every insert goes to
# the front of ``sys.path``, so the directory inserted last (``serving``)
# is searched first if two directories provide the same module name.
sys.path.insert(0, str(_BENCH))
sys.path.insert(0, str(_BENCH.parent / "sim"))
sys.path.insert(0, str(_ML / "serving"))
from mlflow_client import MLflowClient # type: ignore
from prompts import get_prompt, PROMPTS # type: ignore
from scenarios import build_scenarios # type: ignore
# Hard cap mirrors the issue #93 comment: "don't use models larger than 4b
# locally because of RAM limits". A regex cheap-match on the tag handles
# the common ``name:Nb`` and ``name:N.Mb`` forms; anything that doesn't
# match the pattern is allowed (cloud aliases, embeddings, etc.).
_SIZE_TAG = re.compile(r":(\d+(?:\.\d+)?)b\b", re.IGNORECASE)
def _model_too_big(model: str, max_b: float = 4.0) -> bool:
m = _SIZE_TAG.search(model)
if not m:
return False
return float(m.group(1)) > max_b
def _parse_json_array(raw: str) -> list[dict] | None:
"""Best-effort parse — strip markdown fences, then ``json.loads``."""
text = raw.strip()
if text.startswith("```"):
parts = text.split("```")
text = parts[1] if len(parts) > 1 else text
if text.lstrip().lower().startswith("json"):
text = text.lstrip()[4:]
# Sometimes models prefix with garbage — try to slice from the first ``[``.
if not text.lstrip().startswith("["):
i = text.find("[")
if i >= 0:
text = text[i:]
try:
v = json.loads(text)
return v if isinstance(v, list) else None
except (json.JSONDecodeError, ValueError):
return None
def _embed(text: str, ollama_url: str) -> list[float] | None:
    """Request an embedding for ``text`` from Ollama's ``nomic-embed-text``.

    Used for the optional diversity metric.  The request carries
    ``keep_alive=0``.  This is best-effort: any failure (connection error,
    non-2xx status, unexpected body) results in ``None`` rather than an
    exception.
    """
    endpoint = f"{ollama_url}/api/embeddings"
    payload = {"model": "nomic-embed-text", "prompt": text, "keep_alive": 0}
    try:
        with httpx.Client(trust_env=False, timeout=30.0) as http:
            response = http.post(endpoint, json=payload)
            response.raise_for_status()
            body = response.json()
        return body.get("embedding")
    except Exception:
        # Diversity is optional; callers treat None as "no embedding".
        return None
def _mean_pairwise_cosine(vecs: list[list[float]]) -> float:
if len(vecs) < 2:
return 0.0
def cos(a: list[float], b: list[float]) -> float:
na = math.sqrt(sum(x * x for x in a))
nb = math.sqrt(sum(x * x for x in b))
if na == 0 or nb == 0:
return 0.0
return sum(x * y for x, y in zip(a, b)) / (na * nb)
n = len(vecs)
total, count = 0.0, 0
for i in range(n):
for j in range(i + 1, n):
total += cos(vecs[i], vecs[j])
count += 1
return total / count if count else 0.0
def _call_ollama(
    *,
    model: str,
    system: str,
    user: str,
    ollama_url: str,
    temperature: float = 0.7,
) -> tuple[str, dict]:
    """Send one non-streaming chat request to Ollama.

    Returns ``(raw_text, telemetry)`` where ``telemetry`` holds
    ``latency_ms``, ``prompt_tokens`` and ``completion_tokens`` as floats.
    HTTP and transport errors propagate to the caller.

    The request sets ``keep_alive=0``, which this harness relies on as its
    RAM-safety lever so that two chat models are not resident at once.
    """
    started = time.perf_counter()
    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ],
        "stream": False,
        "keep_alive": 0,
        "options": {"temperature": temperature},
    }
    with httpx.Client(trust_env=False, timeout=180.0) as http:
        response = http.post(f"{ollama_url}/api/chat", json=payload)
        response.raise_for_status()
        data = response.json()
    latency_ms = (time.perf_counter() - started) * 1000.0

    message = data.get("message", {})
    raw = message.get("content", "")

    # Token counts are read from the top level of the response; a missing or
    # null count is reported as 0.
    prompt_tokens = float(data.get("prompt_eval_count", 0) or 0)
    completion_tokens = float(data.get("eval_count", 0) or 0)

    telemetry = {
        "latency_ms": latency_ms,
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
    }
    return raw, telemetry
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for Phase A."""
    parser = argparse.ArgumentParser(description="oO tip-generation benchmark — Phase A")
    parser.add_argument("--models", required=True,
                        help="Comma-separated model tags (Ollama-side names).")
    parser.add_argument("--prompts", default=",".join(PROMPTS.keys()),
                        help="Comma-separated prompt versions from ml/serving/prompts.py.")
    parser.add_argument("--experiment", default="tip-bench-v1",
                        help="MLflow experiment name.")
    parser.add_argument("--n-tips", type=int, default=5,
                        help="Tips to request per scenario.")
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--ollama-url", default=os.environ.get("OLLAMA_URL", "http://localhost:11434"))
    parser.add_argument("--mlflow-url", default=os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5000"))
    parser.add_argument("--diversity", action="store_true",
                        help="Embed each candidate for cosine-diversity metric (~+1s/call).")
    parser.add_argument("--max-model-b", type=float, default=4.0,
                        help="Reject models tagged larger than this many billion params.")
    parser.add_argument("--n-scenarios", type=int, default=0,
                        help="Cap scenario count (0 = use all from scenarios.py).")
    parser.add_argument("--rubric", default=str(_BENCH / "rubric.md"),
                        help="Rubric file logged once per experiment.")
    return parser


def _split_csv(value: str) -> list[str]:
    """Split a comma-separated option into stripped, non-empty entries."""
    return [part.strip() for part in value.split(",") if part.strip()]


def _normalise_candidates(items: list) -> list[dict]:
    """Keep dict-shaped items only, coercing id/content/rationale to ``str``.

    Some models return a list of plain strings; those entries are dropped.
    """
    return [
        {
            "id": str(it.get("id", f"tip-{i}")),
            "content": str(it.get("content", "")),
            "rationale": str(it.get("rationale", "")),
        }
        for i, it in enumerate(items)
        if isinstance(it, dict)
    ]


def _collect_cell(
    *,
    client,
    run_id,
    model: str,
    prompt,
    user_msg: str,
    rubric_text: str,
    ollama_url: str,
    temperature: float,
    diversity: bool,
) -> dict:
    """Query the model for one cell, log everything to its run, and end it.

    Returns the metrics dict that was logged.  Any exception (model call,
    parsing, MLflow logging) propagates; the caller is responsible for
    marking the run as failed.
    """
    raw, telemetry = _call_ollama(
        model=model,
        system=prompt.system,
        user=user_msg,
        ollama_url=ollama_url,
        temperature=temperature,
    )
    items = _parse_json_array(raw)
    format_ok = 1.0 if items is not None else 0.0
    cand_dicts = _normalise_candidates(items or [])
    metrics = {
        "latency_ms": telemetry["latency_ms"],
        "prompt_tokens": telemetry["prompt_tokens"],
        "completion_tokens": telemetry["completion_tokens"],
        "n_parsed": float(len(cand_dicts)),
        "format_ok": format_ok,
    }
    if diversity and len(cand_dicts) >= 2:
        # _embed is best-effort and may return None; keep only real vectors.
        embs = [
            emb
            for emb in (_embed(c["content"], ollama_url) for c in cand_dicts)
            if emb
        ]
        if len(embs) >= 2:
            # Cosine *similarity* — lower means more diverse, so we report
            # ``mean_diversity = 1 - sim``.
            metrics["mean_diversity"] = 1.0 - _mean_pairwise_cosine(embs)

    client.log_metrics(run_id, metrics)
    client.log_text(run_id, prompt.system + "\n\n---\n\n" + user_msg, "prompt.txt")
    client.log_text(run_id, json.dumps(cand_dicts, indent=2), "candidates.json")
    client.log_text(run_id, raw[:9_000], "raw.txt")
    # The rubric text is stored as a tag on *every* run (truncated to the
    # client's tag-value limit) so each run is self-describing.
    client.set_tag(run_id, "rubric_md", rubric_text[: client._TAG_VALUE_LIMIT])
    client.end_run(run_id)
    return metrics


def main() -> int:
    """Run Phase A: one MLflow run per (model × prompt × scenario) cell.

    Returns the process exit status: ``0`` when the sweep finished (individual
    cells may still have failed and are marked ``FAILED`` in MLflow), ``2``
    when the command line is rejected before any run is created (no models or
    prompts, a model above ``--max-model-b``, an unknown prompt version, or an
    unreadable rubric file).
    """
    # Local import: ``types`` is not among this file's top-level imports.
    from types import SimpleNamespace

    args = _build_arg_parser().parse_args()
    models = _split_csv(args.models)
    prompts = _split_csv(args.prompts)

    # ``--models ","`` and similar would otherwise produce a zero-cell sweep
    # that still reports success.
    if not models or not prompts:
        print("ERROR: --models and --prompts must each name at least one entry.",
              file=sys.stderr)
        return 2

    too_big = [m for m in models if _model_too_big(m, args.max_model_b)]
    if too_big:
        print(f"ERROR: models exceed --max-model-b={args.max_model_b}: {too_big}", file=sys.stderr)
        return 2

    unknown_prompts = [p for p in prompts if p not in PROMPTS]
    if unknown_prompts:
        print(f"ERROR: unknown prompt versions: {unknown_prompts}. "
              f"Available: {list(PROMPTS)}", file=sys.stderr)
        return 2

    # Read the rubric before touching MLflow so a bad path fails without
    # having created an experiment first.
    try:
        rubric_text = Path(args.rubric).read_text(encoding="utf-8")
    except OSError as exc:
        print(f"ERROR: cannot read rubric file {args.rubric}: {exc}", file=sys.stderr)
        return 2

    scenarios = build_scenarios()
    # Only a positive cap applies; 0 (and any negative value) means "all".
    # A negative value must not reach the slice, where it would silently
    # drop scenarios from the end.
    if 0 < args.n_scenarios < len(scenarios):
        scenarios = scenarios[:args.n_scenarios]

    n_cells = len(models) * len(prompts) * len(scenarios)
    print(f"Models : {models}")
    print(f"Prompts : {prompts}")
    print(f"Scenarios : {len(scenarios)}")
    print(f"Cells : {n_cells} ({len(models)} × {len(prompts)} × {len(scenarios)})")
    print()

    # NOTE(review): falls back to the literal credentials below when the
    # environment variables are unset or empty.
    client = MLflowClient(
        tracking_uri=args.mlflow_url,
        username=os.environ.get("MLFLOW_TRACKING_USERNAME") or "admin",
        password=os.environ.get("MLFLOW_TRACKING_PASSWORD") or "password",
    )
    exp_id = client.get_or_create_experiment(args.experiment)
    print(f"MLflow experiment: {args.experiment} (id={exp_id})")

    # Outer loop is *model* so all cells of one model run back-to-back and
    # different models are never interleaved.
    # NOTE(review): every request sets ``keep_alive=0``, so the model may be
    # reloaded for each cell — confirm against Ollama's keep_alive behaviour.
    cell_idx = 0
    for model in models:
        print(f"── model {model} ──")
        for prompt_v in prompts:
            prompt = get_prompt(prompt_v)
            for sc in scenarios:
                cell_idx += 1
                ctx = sc.to_prompt_context()
                # ``build_user`` receives an attribute-style view of the
                # scenario's context dict.
                prompt_ctx = SimpleNamespace(
                    tasks=ctx["tasks"],
                    hour_of_day=ctx["hour_of_day"],
                    day_of_week=ctx["day_of_week"],
                    extra=ctx["extra"],
                )
                user_msg = prompt.build_user(prompt_ctx, args.n_tips)
                run_id = client.create_run(
                    exp_id,
                    run_name=f"{model}__{prompt_v}__{sc.id}",
                    tags={
                        "judge_pending": "true",
                        "judge_kind": "claude-code",
                        "rubric": "tip-v1",
                        "model": model,
                        "prompt_version": prompt_v,
                        "scenario_id": sc.id,
                        "persona": sc.persona.name,
                    },
                )
                # From here on the run exists, so every failure path must end
                # it; otherwise it would stay open and tagged judge_pending.
                try:
                    client.log_params(run_id, {
                        "model": model,
                        "prompt_version": prompt_v,
                        "scenario_id": sc.id,
                        "persona": sc.persona.name,
                        "hour_of_day": sc.hour_of_day,
                        "day_of_week": sc.day_of_week,
                        "n_tips_requested": args.n_tips,
                        "temperature": args.temperature,
                    })
                    metrics = _collect_cell(
                        client=client,
                        run_id=run_id,
                        model=model,
                        prompt=prompt,
                        user_msg=user_msg,
                        rubric_text=rubric_text,
                        ollama_url=args.ollama_url,
                        temperature=args.temperature,
                        diversity=args.diversity,
                    )
                except Exception as e:
                    # Per-cell boundary: one bad cell must not abort the sweep.
                    print(f" [{cell_idx}/{n_cells}] {model} {prompt_v} {sc.id}: ERROR {e}")
                    client.set_tag(run_id, "error", str(e)[:500])
                    client.end_run(run_id, status="FAILED")
                    continue

                n_parsed = metrics["n_parsed"]
                format_ok = metrics["format_ok"]
                print(f" [{cell_idx:>3}/{n_cells}] {model:18s} {prompt_v:12s} {sc.id:24s} "
                      f"lat={metrics['latency_ms']:>6.0f}ms parsed={int(n_parsed)}/{args.n_tips} "
                      f"fmt={int(format_ok)}")

    print()
    print("Phase A complete. Run judge_cli.py --export to score pending runs.")
    print(f" python ml/experiments/bench/judge_cli.py --experiment {args.experiment} \\")
    print(" --export /tmp/oo-bench-judge-requests.json")
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())