Files
oO/ml/pipelines/bench_dag.py
alvis 0474ad4deb feat(airflow): integrate bench harness into bench_collect DAG
New DAG (`ml/pipelines/bench_dag.py`) with three linked tasks:
1. collect.py — generates candidates, logs to MLflow
2. export_for_judge — exports pending runs for Claude Code scoring
3. compare — generates leaderboard by (model, prompt) cell

Config via dag_run.conf supports all collect.py options (models, prompts,
n_tips, n_scenarios, temperature, experiment name, max_model_b).

New admin API endpoints (`services/api/src/routes/bench.ts`):
- GET /api/bench/experiments — list tip-bench-* experiments
- POST /api/bench/run — trigger DAG with custom config
- GET /api/bench/runs/:experiment — list runs in experiment
- GET /api/bench/leaderboard/:experiment — leaderboard by (model, prompt)

All endpoints require admin auth. Human judge (Claude Code) scores are
applied manually post-export; future enhancement: add webhook to DAG.

Admin UI can now trigger and monitor benchmarks from a dashboard panel.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-04-27 11:54:30 +00:00

169 lines
5.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Airflow DAG: bench_collect
Runs the tip-generation benchmark (model × prompt evaluation). Triggered
on-demand from the admin UI or manually, collects candidates per cell,
exports for Claude Code judgment, and generates a leaderboard.
Mirrors the manual flow:
1. collect.py → generates candidates, logs to MLflow with judge_pending=true
2. (human: judge_cli.py --export, Claude Code scores, judge_cli.py --apply)
3. compare.py → leaderboard
For now, step 2 is manual. Future: add a webhook to trigger the human
judge from the admin UI or set up an async task queue.
Supported conf keys (passed via dag_run.conf; all optional):
models str — comma-separated model tags (default: "qwen2.5:0.5b,qwen2.5:1.5b,gemma3:1b,llama3.2:3b")
prompts str — comma-separated prompt versions (default: "v1,v2-mentor,v3-few-shot")
n_tips int — candidates to generate per scenario (default: 5)
n_scenarios int — cap scenario count; 0 = all (default: 0)
temperature float — LLM generation temperature (default: 0.7)
experiment str — MLflow experiment name (default: "tip-bench-auto")
max_model_b float — reject models larger than this (default: 4.0)
ollama_url str — Ollama endpoint (env OLLAMA_URL or http://localhost:11434)
mlflow_url str — MLflow tracking URI (env MLFLOW_TRACKING_URI or http://localhost:5000)
"""
from __future__ import annotations
import json
import os
import sys
from datetime import datetime, timedelta
from pathlib import Path
from airflow import DAG
from airflow.operators.python import PythonOperator
def _collect(**context: object) -> dict:
    """Run collect.py's ``main()`` with CLI arguments built from ``dag_run.conf``.

    Every option is read from ``context["dag_run"].conf`` and falls back to a
    default when the key is absent. ``ollama_url`` and ``mlflow_url`` fall back
    to the ``OLLAMA_URL`` / ``MLFLOW_TRACKING_URI`` environment variables
    before their hard-coded defaults. ``--n-scenarios`` is only passed when
    ``n_scenarios`` is positive.

    Returns:
        A dict with ``status``, ``exit_code`` and ``experiment`` keys.

    Raises:
        RuntimeError: if ``main()`` returns a non-zero exit code, so that
            Airflow marks the task failed (enabling retries and stopping
            downstream tasks) instead of recording a "failed" status inside
            a task that succeeded.
    """
    conf: dict = context["dag_run"].conf or {}
    models = str(conf.get("models", "qwen2.5:0.5b,qwen2.5:1.5b,gemma3:1b,llama3.2:3b"))
    prompts = str(conf.get("prompts", "v1,v2-mentor,v3-few-shot"))
    n_tips = int(conf.get("n_tips", 5))
    n_scenarios = int(conf.get("n_scenarios", 0))
    temperature = float(conf.get("temperature", 0.7))
    experiment = str(conf.get("experiment", "tip-bench-auto"))
    max_model_b = float(conf.get("max_model_b", 4.0))
    ollama_url = str(conf.get("ollama_url", os.environ.get("OLLAMA_URL", "http://localhost:11434")))
    mlflow_url = str(conf.get("mlflow_url", os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5000")))

    # Make the bench scripts importable; insert only once so repeated task
    # runs in the same process do not keep growing sys.path.
    bench_dir = "/opt/airflow/ml/experiments/bench"
    if bench_dir not in sys.path:
        sys.path.insert(0, bench_dir)
    from collect import main as collect_main  # type: ignore

    # Build args for collect.py
    args = [
        "--models", models,
        "--prompts", prompts,
        "--experiment", experiment,
        "--n-tips", str(n_tips),
        "--temperature", str(temperature),
        "--max-model-b", str(max_model_b),
        "--ollama-url", ollama_url,
        "--mlflow-url", mlflow_url,
    ]
    if n_scenarios > 0:
        args.extend(["--n-scenarios", str(n_scenarios)])

    # Inject args into sys.argv so argparse picks them up, and always restore
    # the worker's own argv afterwards.
    old_argv = sys.argv
    try:
        sys.argv = ["collect.py"] + args
        result = collect_main()
    finally:
        sys.argv = old_argv

    # A real non-zero exit code must fail the task. A None return is left
    # alone (reported as before): it cannot be told apart from a main() that
    # simply has no return statement.
    if result not in (0, None):
        raise RuntimeError(f"collect.py exited with code {result!r}")

    return {
        "status": "success" if result == 0 else "failed",
        "exit_code": result,
        "experiment": experiment,
    }
def _compare(**context: object) -> dict:
    """Run compare.py's ``main()`` to generate the leaderboard.

    Reads ``experiment`` and ``mlflow_url`` from ``context["dag_run"].conf``;
    ``mlflow_url`` falls back to the ``MLFLOW_TRACKING_URI`` environment
    variable and then to ``http://localhost:5000``.

    Returns:
        A dict with ``status``, ``exit_code`` and ``experiment`` keys.

    Raises:
        RuntimeError: if ``main()`` returns a non-zero exit code, so that
            Airflow marks the task failed instead of recording a "failed"
            status inside a task that succeeded.
    """
    conf: dict = context["dag_run"].conf or {}
    experiment = str(conf.get("experiment", "tip-bench-auto"))
    mlflow_url = str(conf.get("mlflow_url", os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5000")))

    # Make the bench scripts importable; insert only once so repeated task
    # runs in the same process do not keep growing sys.path.
    bench_dir = "/opt/airflow/ml/experiments/bench"
    if bench_dir not in sys.path:
        sys.path.insert(0, bench_dir)
    from compare import main as compare_main  # type: ignore

    # compare.py reads its options from sys.argv; swap it in for the call and
    # always restore the worker's own argv afterwards.
    old_argv = sys.argv
    try:
        sys.argv = [
            "compare.py",
            "--experiment", experiment,
            "--mlflow-url", mlflow_url,
        ]
        result = compare_main()
    finally:
        sys.argv = old_argv

    # A real non-zero exit code must fail the task. A None return is left
    # alone (reported as before): it cannot be told apart from a main() that
    # simply has no return statement.
    if result not in (0, None):
        raise RuntimeError(f"compare.py exited with code {result!r}")

    return {
        "status": "success" if result == 0 else "failed",
        "exit_code": result,
        "experiment": experiment,
    }
def _export_for_judge(**context: object) -> str:
    """Export pending runs for Claude Code judgment.

    Reads ``experiment`` and ``mlflow_url`` from ``context["dag_run"].conf``,
    builds an ``MLflowClient`` and calls ``judge_cli.export`` to write the
    export to a JSON file under ``/tmp``. The file name contains the
    experiment name and the task instance's start timestamp.

    Returns:
        The export file path. It is also pushed to XCom under the key
        ``export_path``.
    """
    import re

    conf: dict = context["dag_run"].conf or {}
    experiment = str(conf.get("experiment", "tip-bench-auto"))
    mlflow_url = str(conf.get("mlflow_url", os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5000")))

    # The experiment name comes from dag_run.conf. Keep path separators and
    # other unusual characters out of the file name so the export always
    # lands directly in /tmp. Ordinary names pass through unchanged.
    safe_name = re.sub(r"[^A-Za-z0-9._-]+", "_", experiment)
    started_at = int(context["ti"].start_date.timestamp())
    export_path = f"/tmp/oo-bench-{safe_name}-{started_at}.json"

    # Make the bench scripts importable; insert only once so repeated task
    # runs in the same process do not keep growing sys.path.
    bench_dir = "/opt/airflow/ml/experiments/bench"
    if bench_dir not in sys.path:
        sys.path.insert(0, bench_dir)
    from judge_cli import export  # type: ignore
    from mlflow_client import MLflowClient  # type: ignore

    # NOTE(review): falls back to hard-coded admin/password credentials when
    # the MLFLOW_TRACKING_* variables are unset or empty. Consider failing
    # fast instead of shipping a default password.
    client = MLflowClient(
        tracking_uri=mlflow_url,
        username=os.environ.get("MLFLOW_TRACKING_USERNAME") or "admin",
        password=os.environ.get("MLFLOW_TRACKING_PASSWORD") or "password",
    )
    # The export is looked up by the real (unsanitized) experiment name.
    export(client, experiment, export_path)

    # XCom: push path so next task can find it
    context["ti"].xcom_push(key="export_path", value=export_path)
    return export_path
with DAG(
    dag_id="bench_collect",
    description="Tip-generation benchmark: model & prompt evaluation via MLflow",
    # No schedule: runs are triggered on demand with a dag_run.conf payload.
    # NOTE(review): `schedule_interval` is deprecated in favor of `schedule`
    # from Airflow 2.4; kept here because the deployed version is not known.
    schedule_interval=None,
    start_date=datetime(2025, 1, 1),
    catchup=False,
    tags=["bench", "ml", "evaluation"],
    default_args={
        "retries": 1,
        "retry_delay": timedelta(minutes=5),
    },
) as dag:
    # Airflow 2 passes the task context to python_callable automatically, so
    # the deprecated `provide_context=True` flag is not needed.
    collect = PythonOperator(
        task_id="collect",
        python_callable=_collect,
    )
    export_judge = PythonOperator(
        task_id="export_for_judge",
        python_callable=_export_for_judge,
    )
    compare = PythonOperator(
        task_id="compare",
        python_callable=_compare,
    )

    # Linear pipeline: generate candidates, export them for judging, then
    # build the leaderboard.
    collect >> export_judge >> compare