- sim_runs schema: add judge_mode, n_policies, airflow_dag_run_id, mlflow_run_id columns - admin health endpoint: add mlflow + airflow checks (Basic auth for Airflow API) - admin nav: add Simulations page link; rename section label - runner.py: optional MLflow experiment tracking; multi-policy support - sim_dag.py: Airflow DAG for offline sim pipeline - admin simulate page + API client methods for sim runs - shared-types tsconfig: exclude test files from build Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
125 lines
3.8 KiB
Python
125 lines
3.8 KiB
Python
"""
|
|
Airflow DAG: bandit_sim
|
|
|
|
Runs a bandit policy simulation and logs results to MLflow.
|
|
Triggered on-demand from the oO admin panel or manually from the Airflow UI.
|
|
|
|
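Conf keys (passed via dag_run.conf; all optional — omitted simulation keys fall back to defaults, and the result push is skipped without callback_url and internal_token):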
|
|
sim_run_id str — oO SQLite run ID for callback correlation
|
|
n_users int — number of synthetic users
|
|
n_rounds int — rounds per user
|
|
tasks_per_round int — candidate pool size per round
|
|
policies list — policy names to compare
|
|
judge_mode str — "rule" | "llm"
|
|
ml_url str — ml/serving URL (e.g. http://ml-serving:8000)
|
|
mlflow_url str — MLflow tracking URI (e.g. http://mlflow:5000/mlflow)
|
|
callback_url str — oO API callback endpoint
|
|
internal_token str — x-internal-token header value
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
import sys
|
|
from datetime import datetime, timedelta
|
|
|
|
from airflow import DAG
|
|
from airflow.operators.python import PythonOperator
|
|
|
|
|
|
def _run_sim(**context: object) -> dict:
|
|
conf: dict = context["dag_run"].conf or {}
|
|
|
|
n_users = int(conf.get("n_users", 5))
|
|
n_rounds = int(conf.get("n_rounds", 20))
|
|
tasks_per_round = int(conf.get("tasks_per_round", 8))
|
|
policies = list(conf.get("policies", ["linucb-v1", "egreedy-v1"]))
|
|
judge_mode = str(conf.get("judge_mode", "rule"))
|
|
ml_url = str(conf.get("ml_url", "http://ml-serving:8000"))
|
|
mlflow_url = str(conf.get("mlflow_url", os.environ.get("MLFLOW_TRACKING_URI", "")))
|
|
mlflow_experiment = "bandit_simulation"
|
|
|
|
sys.path.insert(0, "/opt/airflow/ml/experiments/sim")
|
|
from runner import run_simulation # type: ignore[import]
|
|
|
|
use_llm = judge_mode == "llm"
|
|
result = run_simulation(
|
|
n_users=n_users,
|
|
n_rounds=n_rounds,
|
|
tasks_per_round=tasks_per_round,
|
|
ml_url=ml_url,
|
|
policies=policies,
|
|
use_llm=use_llm,
|
|
seed=42,
|
|
mlflow_url=mlflow_url or None,
|
|
mlflow_experiment=mlflow_experiment,
|
|
)
|
|
return result
|
|
|
|
|
|
def _callback(**context: object) -> None:
|
|
import httpx
|
|
|
|
conf: dict = context["dag_run"].conf or {}
|
|
callback_url: str = str(conf.get("callback_url", ""))
|
|
internal_token: str = str(conf.get("internal_token", ""))
|
|
|
|
if not callback_url or not internal_token:
|
|
print("No callback_url or internal_token — skipping result push.", flush=True)
|
|
return
|
|
|
|
result: dict = context["ti"].xcom_pull(task_ids="run_sim")
|
|
if not result:
|
|
print("No result from run_sim task — callback skipped.", flush=True)
|
|
return
|
|
|
|
payload = {
|
|
"summary": result.get("summary", {}),
|
|
"winner": result.get("winner", ""),
|
|
"persona_breakdown": result.get("persona_breakdown", {}),
|
|
"events": result.get("events", []),
|
|
"mlflow_run_id": result.get("mlflow_run_id"),
|
|
}
|
|
|
|
try:
|
|
r = httpx.post(
|
|
callback_url,
|
|
json=payload,
|
|
headers={"x-internal-token": internal_token},
|
|
timeout=30.0,
|
|
)
|
|
r.raise_for_status()
|
|
print(f"Callback OK: {r.status_code}", flush=True)
|
|
except Exception as exc:
|
|
print(f"Callback failed: {exc}", flush=True)
|
|
raise
|
|
|
|
|
|
# On-demand DAG: simulate (run_sim) then push the result to the oO API (push_results).
with DAG(
    dag_id="bandit_sim",
    description="On-demand bandit policy simulation with MLflow tracking",
    # No schedule: the DAG runs only when triggered (admin panel or Airflow UI).
    # NOTE(review): `schedule_interval` is deprecated since Airflow 2.4 in
    # favour of `schedule`; switch once the minimum Airflow version allows.
    schedule_interval=None,
    start_date=datetime(2025, 1, 1),
    catchup=False,
    tags=["bandit", "simulation", "ml"],
    default_args={
        "retries": 1,
        "retry_delay": timedelta(minutes=2),
    },
) as dag:
    # `provide_context` was removed from PythonOperator in Airflow 2: the task
    # context is always passed to callables accepting **kwargs. By default
    # (operators.allow_illegal_arguments = False) the stale kwarg is rejected
    # as an invalid argument when the DAG is parsed.
    run_sim = PythonOperator(
        task_id="run_sim",
        python_callable=_run_sim,
    )

    # NOTE(review): with the default trigger rule this task runs only after
    # run_sim succeeds, so a failed simulation is never reported to the oO
    # API — consider an on_failure_callback that marks the run as failed.
    push_results = PythonOperator(
        task_id="push_results",
        python_callable=_callback,
    )

    run_sim >> push_results
|