""" Airflow DAG: bandit_sim Runs a bandit policy simulation and logs results to MLflow. Triggered on-demand from the oO admin panel or manually from the Airflow UI. Required conf keys (passed via dag_run.conf): sim_run_id str — oO SQLite run ID for callback correlation n_users int — number of synthetic users n_rounds int — rounds per user tasks_per_round int — candidate pool size per round policies list — policy names to compare judge_mode str — "rule" | "llm" ml_url str — ml/serving URL (e.g. http://ml-serving:8000) mlflow_url str — MLflow tracking URI (e.g. http://mlflow:5000/mlflow) callback_url str — oO API callback endpoint internal_token str — x-internal-token header value """ from __future__ import annotations import json import os import sys from datetime import datetime, timedelta from airflow import DAG from airflow.operators.python import PythonOperator def _run_sim(**context: object) -> dict: conf: dict = context["dag_run"].conf or {} n_users = int(conf.get("n_users", 5)) n_rounds = int(conf.get("n_rounds", 20)) tasks_per_round = int(conf.get("tasks_per_round", 8)) policies = list(conf.get("policies", ["linucb-v1", "egreedy-v1"])) judge_mode = str(conf.get("judge_mode", "rule")) ml_url = str(conf.get("ml_url", "http://ml-serving:8000")) mlflow_url = str(conf.get("mlflow_url", os.environ.get("MLFLOW_TRACKING_URI", ""))) mlflow_experiment = "bandit_simulation" sys.path.insert(0, "/opt/airflow/ml/experiments/sim") from runner import run_simulation # type: ignore[import] use_llm = judge_mode == "llm" result = run_simulation( n_users=n_users, n_rounds=n_rounds, tasks_per_round=tasks_per_round, ml_url=ml_url, policies=policies, use_llm=use_llm, seed=42, mlflow_url=mlflow_url or None, mlflow_experiment=mlflow_experiment, ) return result def _callback(**context: object) -> None: import httpx conf: dict = context["dag_run"].conf or {} callback_url: str = str(conf.get("callback_url", "")) internal_token: str = str(conf.get("internal_token", "")) if not callback_url or not internal_token: print("No callback_url or internal_token — skipping result push.", flush=True) return result: dict = context["ti"].xcom_pull(task_ids="run_sim") if not result: print("No result from run_sim task — callback skipped.", flush=True) return payload = { "summary": result.get("summary", {}), "winner": result.get("winner", ""), "persona_breakdown": result.get("persona_breakdown", {}), "events": result.get("events", []), "mlflow_run_id": result.get("mlflow_run_id"), } try: r = httpx.post( callback_url, json=payload, headers={"x-internal-token": internal_token}, timeout=30.0, ) r.raise_for_status() print(f"Callback OK: {r.status_code}", flush=True) except Exception as exc: print(f"Callback failed: {exc}", flush=True) raise with DAG( dag_id="bandit_sim", description="On-demand bandit policy simulation with MLflow tracking", schedule_interval=None, start_date=datetime(2025, 1, 1), catchup=False, tags=["bandit", "simulation", "ml"], default_args={ "retries": 1, "retry_delay": timedelta(minutes=2), }, ) as dag: run_sim = PythonOperator( task_id="run_sim", python_callable=_run_sim, provide_context=True, ) push_results = PythonOperator( task_id="push_results", python_callable=_callback, provide_context=True, ) run_sim >> push_results