Combines model evaluation (#93) and prompt A/B testing (#95) into one experiment. Evaluates all (model × prompt × scenario) cells on the same fixed contexts so quality differences are attributable.

Architecture:
- Phase A (collect.py): generates candidates per cell, logs to MLflow with judge_pending=true. Rejects models >4B, uses keep_alive=0 for RAM safety (no concurrent model weights in VRAM).
- Phase B (judge_cli.py): exports pending runs as JSON for Claude Code to score per the rubric, then applies scores back to MLflow.
- Phase C (compare.py): leaderboard by (model, prompt) cell.

Rubric (tip-v1) defines 1–5 scales for relevance, actionability, tone, plus format_ok and overlong flags. Composite = rel + act + tone + 2×format_ok − overlong. Rubric is self-describing and persisted in every run so judges use consistent criteria across sessions.

Artifacts (prompts, candidates, raw responses) stored as MLflow tags because the server uses a file:// backend not accessible via REST. Full artifacts accessible in MLflow UI → run → Tags section.

Tested end-to-end on local machine:
- 4 models (qwen2.5:0.5b/1.5b, gemma3:1b, llama3.2:3b) ≤4B
- 3 prompts (v1, v2-mentor, v3-few-shot)
- 4 scenarios (4 personas × 2 time-slots)
- 48 cells total, all judged and ranked

Winner: qwen2.5:1.5b × v3-few-shot (composite=12.75). Ready for integration into Airflow prompt_ab_eval DAG and admin UI.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
338
ml/experiments/bench/collect.py
Normal file
338
ml/experiments/bench/collect.py
Normal file
@@ -0,0 +1,338 @@
|
||||
"""Phase A — collect tip candidates per (model × prompt × scenario) cell.
|
||||
|
||||
Each cell produces one MLflow run with:
|
||||
|
||||
params: model, prompt_version, scenario_id, persona, hour_of_day,
day_of_week, n_tips_requested, temperature
|
||||
tags: judge_pending=true, judge_kind=claude-code, rubric=tip-v1
|
||||
metrics: latency_ms, prompt_tokens (best effort), completion_tokens,
|
||||
n_parsed, format_ok, mean_diversity (cosine, optional)
|
||||
artifacts (as tags via mlflow_client.log_text):
|
||||
prompt.txt system + user prompt as sent
|
||||
candidates.json parsed candidate array
|
||||
raw.txt the model's raw response (for triage)
|
||||
|
||||
Models are called **sequentially** with ``keep_alive=0`` so Ollama unloads
|
||||
the previous model from VRAM before loading the next — keeps the box
|
||||
within RAM/VRAM budget. Models > 4B are rejected up front.
|
||||
|
||||
Usage:
|
||||
|
||||
python collect.py \\
|
||||
--models qwen2.5:0.5b,qwen2.5:1.5b,gemma3:1b,llama3.2:3b \\
|
||||
--prompts v1,v2-mentor,v3-few-shot \\
|
||||
--n-tips 5 \\
|
||||
--experiment tip-bench-2026-04-27
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
||||
# Directory containing this script, and the directory two levels above it.
_BENCH = Path(__file__).resolve().parent
_ML = _BENCH.parent.parent

# Put the benchmark directory, its sibling ``sim`` directory and the
# ``serving`` directory at the front of ``sys.path`` so the plain-name
# imports that follow can resolve when this file is run as a script.
# Every insert targets position 0, so the last directory listed here
# ends up first on the path.
for _helper_dir in (_BENCH, _BENCH.parent / "sim", _ML / "serving"):
    sys.path.insert(0, str(_helper_dir))
del _helper_dir
|
||||
|
||||
from mlflow_client import MLflowClient # type: ignore
|
||||
from prompts import get_prompt, PROMPTS # type: ignore
|
||||
from scenarios import build_scenarios # type: ignore
|
||||
|
||||
|
||||
# Hard cap mirrors the issue #93 comment: "don't use models larger than 4b
|
||||
# locally because of RAM limits". A regex cheap-match on the tag handles
|
||||
# the common ``name:Nb`` and ``name:N.Mb`` forms; anything that doesn't
|
||||
# match the pattern is allowed (cloud aliases, embeddings, etc.).
|
||||
_SIZE_TAG = re.compile(r":(\d+(?:\.\d+)?)b\b", re.IGNORECASE)


def _model_too_big(model: str, max_b: float = 4.0) -> bool:
    """Tell whether a model tag advertises more than ``max_b`` billion params.

    The size is taken from the first ``:<number>b`` fragment of the tag
    (for example ``:3b`` or ``:1.5b``, case-insensitive). A tag with no
    such fragment is never reported as too big.

    Args:
        model: Model tag to inspect.
        max_b: Largest accepted size, in billions of parameters.

    Returns:
        ``True`` only when a size fragment is present and strictly
        greater than ``max_b``.
    """
    size = _SIZE_TAG.search(model)
    return size is not None and float(size.group(1)) > max_b
|
||||
|
||||
|
||||
def _parse_json_array(raw: str) -> list[dict] | None:
    """Best-effort extraction of a JSON array from a model response.

    Handles the shapes small models commonly produce:

    * a bare JSON array;
    * an array wrapped in a Markdown code fence (optionally tagged ``json``);
    * an array preceded by prose ("Sure! Here are your tips: [...]");
    * an array followed by prose or a closing fence.

    Decoding starts at the first ``[`` and stops at the end of the first
    complete JSON value, so anything after the array is ignored instead of
    invalidating an otherwise well-formed answer.

    Args:
        raw: The model's raw text response.

    Returns:
        The decoded list, or ``None`` when no ``[`` is present, the text
        from the first ``[`` is not valid JSON, or the decoded value is
        not a list.
    """
    text = raw.strip()
    if text.startswith("```"):
        # Response opens with a fence: keep only what is inside the first
        # fenced section and drop an optional ``json`` language tag.
        parts = text.split("```")
        text = parts[1] if len(parts) > 1 else text
        if text.lstrip().lower().startswith("json"):
            text = text.lstrip()[4:]

    start = text.find("[")
    if start < 0:
        # Without an opening bracket there cannot be an array.
        return None

    try:
        # ``raw_decode`` parses one JSON value and reports where it ended,
        # which lets trailing prose / fences be tolerated the same way
        # leading prose already is.
        value, _end = json.JSONDecoder().raw_decode(text[start:])
    except (json.JSONDecodeError, ValueError):
        return None
    return value if isinstance(value, list) else None
|
||||
|
||||
|
||||
def _embed(text: str, ollama_url: str) -> list[float] | None:
    """Fetch an embedding for ``text`` from Ollama's ``nomic-embed-text`` model.

    Used for the optional diversity metric. The request carries
    ``keep_alive=0``, the same setting the chat calls use.

    Args:
        text: Text to embed.
        ollama_url: Base URL of the Ollama server (no trailing path).

    Returns:
        The value of the response's ``embedding`` field, or ``None`` when
        the field is missing or anything at all goes wrong (network error,
        non-2xx status, malformed body). Failures are swallowed on purpose:
        the embedding is a nice-to-have, never a reason to abort a cell.
    """
    payload = {"model": "nomic-embed-text", "prompt": text, "keep_alive": 0}
    try:
        with httpx.Client(trust_env=False, timeout=30.0) as session:
            response = session.post(f"{ollama_url}/api/embeddings", json=payload)
            response.raise_for_status()
            return response.json().get("embedding")
    except Exception:
        return None
|
||||
|
||||
|
||||
def _mean_pairwise_cosine(vecs: list[list[float]]) -> float:
    """Average the cosine similarity over every unordered pair of vectors.

    Args:
        vecs: Vectors to compare. When two vectors differ in length the
            dot product only covers their common prefix (``zip``).

    Returns:
        The mean similarity, or ``0.0`` when fewer than two vectors are
        given. A pair in which either vector has zero norm counts as
        similarity ``0.0``.
    """
    if len(vecs) < 2:
        return 0.0

    # A vector's Euclidean norm is the same in every pair it belongs to,
    # so compute each one a single time.
    norms = [math.sqrt(sum(x * x for x in v)) for v in vecs]

    running = 0.0
    pairs = 0
    for i, left in enumerate(vecs):
        for j, right in enumerate(vecs[i + 1:], start=i + 1):
            if norms[i] == 0 or norms[j] == 0:
                similarity = 0.0
            else:
                dot = sum(x * y for x, y in zip(left, right))
                similarity = dot / (norms[i] * norms[j])
            running += similarity
            pairs += 1
    return running / pairs if pairs else 0.0
|
||||
|
||||
|
||||
def _call_ollama(
    *,
    model: str,
    system: str,
    user: str,
    ollama_url: str,
    temperature: float = 0.7,
) -> tuple[str, dict]:
    """Send one non-streaming chat request to Ollama.

    The request always carries ``keep_alive=0``. That is the RAM-safety
    lever of the benchmark: it asks Ollama to drop the model right after
    answering, so the next model in the sweep does not have to share
    memory with the previous one.

    Args:
        model: Ollama model tag.
        system: System prompt text.
        user: User prompt text.
        ollama_url: Base URL of the Ollama server.
        temperature: Sampling temperature forwarded in ``options``.

    Returns:
        ``(raw_text, telemetry)`` where ``raw_text`` is the reply's
        ``message.content`` (empty string if absent) and ``telemetry`` has
        the float fields ``latency_ms``, ``prompt_tokens`` and
        ``completion_tokens``. Token counts fall back to ``0.0`` when the
        response does not report them.

    Raises:
        Whatever ``httpx`` raises for transport errors or a non-2xx status.
    """
    started = time.perf_counter()
    conversation = [
        {"role": "system", "content": system},
        {"role": "user", "content": user},
    ]
    payload = {
        "model": model,
        "messages": conversation,
        "stream": False,
        "keep_alive": 0,
        "options": {"temperature": temperature},
    }
    with httpx.Client(trust_env=False, timeout=180.0) as session:
        response = session.post(f"{ollama_url}/api/chat", json=payload)
        response.raise_for_status()
        reply = response.json()
    latency_ms = (time.perf_counter() - started) * 1000.0

    def _count(field: str) -> float:
        # Token counters sit at the top level of a non-streaming reply;
        # they may be absent or null, hence the double fallback to zero.
        return float(reply.get(field, 0) or 0)

    content = reply.get("message", {}).get("content", "")
    return content, {
        "latency_ms": latency_ms,
        "prompt_tokens": _count("prompt_eval_count"),
        "completion_tokens": _count("eval_count"),
    }
|
||||
|
||||
|
||||
def _split_csv(value: str) -> list[str]:
    """Split a comma-separated CLI value, trimming spaces and dropping blanks."""
    return [part.strip() for part in value.split(",") if part.strip()]


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the Phase A collector."""
    parser = argparse.ArgumentParser(description="oO tip-generation benchmark — Phase A")
    parser.add_argument("--models", required=True,
                        help="Comma-separated model tags (Ollama-side names).")
    parser.add_argument("--prompts", default=",".join(PROMPTS.keys()),
                        help="Comma-separated prompt versions from ml/serving/prompts.py.")
    parser.add_argument("--experiment", default="tip-bench-v1",
                        help="MLflow experiment name.")
    parser.add_argument("--n-tips", type=int, default=5,
                        help="Tips to request per scenario.")
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--ollama-url", default=os.environ.get("OLLAMA_URL", "http://localhost:11434"))
    parser.add_argument("--mlflow-url", default=os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5000"))
    parser.add_argument("--diversity", action="store_true",
                        help="Embed each candidate for cosine-diversity metric (~+1s/call).")
    parser.add_argument("--max-model-b", type=float, default=4.0,
                        help="Reject models tagged larger than this many billion params.")
    parser.add_argument("--n-scenarios", type=int, default=0,
                        help="Cap scenario count (0 = use all from scenarios.py).")
    parser.add_argument("--rubric", default=str(_BENCH / "rubric.md"),
                        help="Rubric file stored as the rubric_md tag on every run.")
    return parser


def _normalise_candidates(items: list) -> list[dict]:
    """Keep dict-shaped items only and coerce id/content/rationale to strings.

    The fallback id ``tip-<i>`` uses the item's position in the raw list,
    counting any non-dict entries that get dropped.
    """
    return [
        {
            "id": str(it.get("id", f"tip-{i}")),
            "content": str(it.get("content", "")),
            "rationale": str(it.get("rationale", "")),
        }
        for i, it in enumerate(items)
        if isinstance(it, dict)
    ]


def main() -> int:
    """Run Phase A: generate and log tip candidates for every benchmark cell.

    A cell is one (model, prompt version, scenario) combination. For each
    cell an MLflow run is created, the model is queried through Ollama, the
    response is parsed, and metrics plus text artifacts are logged. A cell
    whose Ollama call raises is tagged with the error, ended as ``FAILED``
    and skipped; the sweep carries on.

    Returns:
        ``0`` once the sweep has finished (individual cell failures do not
        change the exit code), ``2`` when the arguments are rejected: empty
        model/prompt list, a model above ``--max-model-b``, or an unknown
        prompt version.

    Raises:
        OSError: If the rubric file cannot be read.
    """
    # Only main() needs this, so keep the import local to the block.
    from types import SimpleNamespace

    args = _build_arg_parser().parse_args()

    models = _split_csv(args.models)
    prompts = _split_csv(args.prompts)
    if not models or not prompts:
        # e.g. ``--models ","`` — would otherwise "complete" with zero cells.
        print("ERROR: --models and --prompts must each name at least one entry.",
              file=sys.stderr)
        return 2
    too_big = [m for m in models if _model_too_big(m, args.max_model_b)]
    if too_big:
        print(f"ERROR: models exceed --max-model-b={args.max_model_b}: {too_big}", file=sys.stderr)
        return 2
    unknown_prompts = [p for p in prompts if p not in PROMPTS]
    if unknown_prompts:
        print(f"ERROR: unknown prompt versions: {unknown_prompts}. "
              f"Available: {list(PROMPTS)}", file=sys.stderr)
        return 2

    scenarios = build_scenarios()
    if args.n_scenarios and args.n_scenarios < len(scenarios):
        scenarios = scenarios[:args.n_scenarios]
    n_cells = len(models) * len(prompts) * len(scenarios)
    print(f"Models : {models}")
    print(f"Prompts : {prompts}")
    print(f"Scenarios : {len(scenarios)}")
    print(f"Cells : {n_cells} ({len(models)} × {len(prompts)} × {len(scenarios)})")
    print()

    # Read the rubric before touching MLflow so that a bad --rubric path
    # fails before an experiment has been created.
    rubric_text = Path(args.rubric).read_text(encoding="utf-8")

    client = MLflowClient(
        tracking_uri=args.mlflow_url,
        # Fallback credentials apply when the env vars are unset or empty.
        username=os.environ.get("MLFLOW_TRACKING_USERNAME") or "admin",
        password=os.environ.get("MLFLOW_TRACKING_PASSWORD") or "password",
    )
    exp_id = client.get_or_create_experiment(args.experiment)
    print(f"MLflow experiment: {args.experiment} (id={exp_id})")

    # The same truncated rubric text is attached to every run; slice it once.
    rubric_tag = rubric_text[: client._TAG_VALUE_LIMIT]

    # Model is the outermost loop so that all cells of one model are issued
    # back to back before moving on to the next model. Every request still
    # carries ``keep_alive=0`` (see ``_call_ollama``).
    # NOTE(review): whether Ollama keeps a model warm across such a burst
    # despite keep_alive=0 cannot be verified from this file.
    cell_idx = 0
    for model in models:
        print(f"── model {model} ──")
        for prompt_v in prompts:
            prompt = get_prompt(prompt_v)
            for sc in scenarios:
                cell_idx += 1
                ctx = sc.to_prompt_context()
                # ``build_user`` is handed an attribute-style object, so
                # re-expose the four context fields as attributes.
                prompt_ctx = SimpleNamespace(
                    tasks=ctx["tasks"],
                    hour_of_day=ctx["hour_of_day"],
                    day_of_week=ctx["day_of_week"],
                    extra=ctx["extra"],
                )
                user_msg = prompt.build_user(prompt_ctx, args.n_tips)

                run_id = client.create_run(
                    exp_id,
                    run_name=f"{model}__{prompt_v}__{sc.id}",
                    tags={
                        "judge_pending": "true",
                        "judge_kind": "claude-code",
                        "rubric": "tip-v1",
                        "model": model,
                        "prompt_version": prompt_v,
                        "scenario_id": sc.id,
                        "persona": sc.persona.name,
                    },
                )
                client.log_params(run_id, {
                    "model": model,
                    "prompt_version": prompt_v,
                    "scenario_id": sc.id,
                    "persona": sc.persona.name,
                    "hour_of_day": sc.hour_of_day,
                    "day_of_week": sc.day_of_week,
                    "n_tips_requested": args.n_tips,
                    "temperature": args.temperature,
                })

                try:
                    raw, telemetry = _call_ollama(
                        model=model,
                        system=prompt.system,
                        user=user_msg,
                        ollama_url=args.ollama_url,
                        temperature=args.temperature,
                    )
                except Exception as e:
                    # One bad cell must not abort the sweep: record the
                    # error on the run and move on.
                    # NOTE(review): the run keeps judge_pending="true";
                    # confirm Phase B skips FAILED runs.
                    print(f" [{cell_idx}/{n_cells}] {model} {prompt_v} {sc.id}: ERROR {e}")
                    client.set_tag(run_id, "error", str(e)[:500])
                    client.end_run(run_id, status="FAILED")
                    continue

                items = _parse_json_array(raw)
                format_ok = 1.0 if items is not None else 0.0
                # Some models return lists of strings; those entries are dropped.
                cand_dicts = _normalise_candidates(items or [])
                n_parsed = float(len(cand_dicts))

                metrics = {
                    "latency_ms": telemetry["latency_ms"],
                    "prompt_tokens": telemetry["prompt_tokens"],
                    "completion_tokens": telemetry["completion_tokens"],
                    "n_parsed": n_parsed,
                    "format_ok": format_ok,
                }

                if args.diversity and len(cand_dicts) >= 2:
                    fetched = (_embed(c["content"], args.ollama_url) for c in cand_dicts)
                    embs = [emb for emb in fetched if emb]
                    if len(embs) >= 2:
                        # Cosine *similarity* — lower means more diverse,
                        # so report ``mean_diversity = 1 - sim``.
                        sim = _mean_pairwise_cosine(embs)
                        metrics["mean_diversity"] = 1.0 - sim

                client.log_metrics(run_id, metrics)
                client.log_text(run_id, prompt.system + "\n\n---\n\n" + user_msg, "prompt.txt")
                client.log_text(run_id, json.dumps(cand_dicts, indent=2), "candidates.json")
                client.log_text(run_id, raw[:9_000], "raw.txt")
                # Attach the (truncated) rubric to every run as a tag so
                # each run is self-describing.
                client.set_tag(run_id, "rubric_md", rubric_tag)

                client.end_run(run_id)
                print(f" [{cell_idx:>3}/{n_cells}] {model:18s} {prompt_v:12s} {sc.id:24s} "
                      f"lat={metrics['latency_ms']:>6.0f}ms parsed={int(n_parsed)}/{args.n_tips} "
                      f"fmt={int(format_ok)}")

    print()
    print("Phase A complete. Run judge_cli.py --export to score pending runs.")
    print(f" python ml/experiments/bench/judge_cli.py --experiment {args.experiment} \\")
    print(" --export /tmp/oo-bench-judge-requests.json")
    return 0
|
||||
|
||||
|
||||
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|
||||
Reference in New Issue
Block a user