- sim_runs schema: add judge_mode, n_policies, airflow_dag_run_id, mlflow_run_id columns - admin health endpoint: add mlflow + airflow checks (Basic auth for Airflow API) - admin nav: add Simulations page link; rename section label - runner.py: optional MLflow experiment tracking; multi-policy support - sim_dag.py: Airflow DAG for offline sim pipeline - admin simulate page + API client methods for sim runs - shared-types tsconfig: exclude test files from build Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
610 lines
24 KiB
Python
610 lines
24 KiB
Python
"""
|
||
oO simulation runner — compares two recommendation policies.
|
||
|
||
Judge modes:
|
||
rule Deterministic persona-based rules (default, no external deps)
|
||
llm Claude Haiku via Anthropic API (requires ANTHROPIC_API_KEY)
|
||
claude-code Two-phase: Claude Code acts as the judge (you are the judge)
|
||
|
||
Usage — rule/llm (single pass):
|
||
python runner.py --n-users 5 --n-rounds 10 --no-llm
|
||
python runner.py --n-users 5 --n-rounds 10
|
||
|
||
Usage — claude-code judge (two phases):
|
||
# Phase 1: score candidates, write judgment requests
|
||
python runner.py --judge claude-code --phase score \\
|
||
--n-users 5 --n-rounds 10 --out /tmp/oo-cc-sim.json
|
||
|
||
# (Claude Code reads /tmp/oo-cc-sim-requests.json and writes /tmp/oo-cc-sim-responses.json)
|
||
|
||
# Phase 2: apply responses, run rewards, produce results
|
||
python runner.py --judge claude-code --phase reward --plan /tmp/oo-cc-sim-plan.json \\
|
||
--out /tmp/oo-cc-sim.json
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import json
|
||
import os
|
||
import random
|
||
import sys
|
||
import time
|
||
import uuid
|
||
from pathlib import Path
|
||
|
||
sys.path.insert(0, str(Path(__file__).parent))
|
||
|
||
import httpx
|
||
|
||
from llm_judge import ACTIONS, infer_reward, judge
|
||
from personas import PERSONAS, Persona
|
||
from task_generator import generate_task_pool
|
||
|
||
try:
|
||
import mlflow
|
||
_MLFLOW_AVAILABLE = True
|
||
except ImportError:
|
||
_MLFLOW_AVAILABLE = False
|
||
|
||
# Route suffix shared by each policy's score and reward endpoints on the ML
# service. Both public tables below are derived from this single mapping so
# a policy can never be registered for scoring but not for rewards.
_POLICY_ROUTE_SUFFIXES: dict[str, str] = {
    "linucb-v1": "",
    "egreedy-v1": "/egreedy",
    "egreedy-v2": "/egreedy/v2",
}

# Scoring endpoint path (appended to --ml-url) for each known policy.
POLICY_SCORE_ENDPOINTS: dict[str, str] = {
    name: f"/score{suffix}" for name, suffix in _POLICY_ROUTE_SUFFIXES.items()
}

# Reward-feedback endpoint path (appended to --ml-url) for each known policy.
POLICY_REWARD_ENDPOINTS: dict[str, str] = {
    name: f"/reward{suffix}" for name, suffix in _POLICY_ROUTE_SUFFIXES.items()
}
|
||
|
||
|
||
def _call_score(
    client: httpx.Client, ml_url: str, policy: str,
    user_id: str, tasks: list[dict], hour: int, dow: int,
    profile_features: dict | None = None,
) -> dict | None:
    """Ask ``policy`` on the ML service to pick a tip from ``tasks``.

    Every task becomes a candidate carrying its ``is_overdue``,
    ``task_age_days`` and ``priority`` features plus the round's ``hour``;
    ``hour`` and ``dow`` are also sent as the request context.
    ``profile_features`` is included only when not None. Policies missing
    from POLICY_SCORE_ENDPOINTS fall back to the ``/score`` route.

    Returns:
        The decoded JSON response, or None after printing a warning to
        stderr if the request, the HTTP status check, or decoding fails.
    """
    def _as_candidate(task: dict) -> dict:
        feats = task["features"]
        return {
            "id": task["id"],
            "content": task["content"],
            "source": task["source"],
            "source_id": None,
            "features": {
                "hour_of_day": hour,
                "is_overdue": feats["is_overdue"],
                "task_age_days": feats["task_age_days"],
                "priority": feats["priority"],
            },
        }

    route = POLICY_SCORE_ENDPOINTS.get(policy, "/score")
    payload: dict = {
        "user_id": user_id,
        "candidates": [_as_candidate(task) for task in tasks],
        "context": {"hour_of_day": hour, "day_of_week": dow},
    }
    if profile_features is not None:
        payload["profile_features"] = profile_features

    try:
        response = client.post(f"{ml_url}{route}", json=payload, timeout=5.0)
        response.raise_for_status()
        return response.json()
    except Exception as exc:
        # Best-effort: a failed score just skips this (user, policy) session.
        print(f" [warn] score {policy}: {exc}", file=sys.stderr)
        return None
|
||
|
||
|
||
def _call_reward(
    client: httpx.Client, ml_url: str, policy: str,
    user_id: str, tip_id: str, reward: float, features: dict,
    day_of_week: int = 0,
    profile_features: dict | None = None,
) -> None:
    """Send a reward signal for a previously scored tip to ``policy``.

    Best-effort: transport errors AND non-2xx HTTP replies are reported as a
    warning on stderr and otherwise ignored, so one failed update does not
    abort a simulation run.

    Args:
        client: Shared HTTP client.
        ml_url: Base URL of the ML service.
        policy: Policy name; names missing from POLICY_REWARD_ENDPOINTS fall
            back to the ``/reward`` route.
        user_id: ML-side user id used when the tip was scored.
        tip_id: Id of the tip the policy chose.
        reward: Reward value to apply.
        features: Feature dict for the chosen tip.
        day_of_week: Day-of-week context (0-6).
        profile_features: Optional persona profile features; omitted from the
            request body when None.
    """
    endpoint = POLICY_REWARD_ENDPOINTS.get(policy, "/reward")
    body: dict = {
        "user_id": user_id, "tip_id": tip_id, "reward": reward,
        "features": features, "day_of_week": day_of_week,
    }
    if profile_features is not None:
        body["profile_features"] = profile_features
    try:
        r = client.post(f"{ml_url}{endpoint}", json=body, timeout=5.0)
        # Without this check a 4xx/5xx (e.g. a rejected payload) was dropped
        # silently, while the caller still counted the reward as applied.
        r.raise_for_status()
    except Exception as e:
        print(f" [warn] reward {policy}: {e}", file=sys.stderr)
|
||
|
||
|
||
# ── Standard single-pass runner (rule / llm modes) ─────────────────────────
|
||
|
||
def _init_mlflow(mlflow_url: str | None, experiment: str) -> str | None:
    """Point MLflow at ``mlflow_url`` and select ``experiment``.

    This only configures tracking (tracking URI + experiment); it does NOT
    start a run and therefore has no run id to return.

    Returns:
        The sentinel string ``"ready"`` when configuration succeeded, or None
        when mlflow is not installed, ``mlflow_url`` is empty/None, or the
        configuration calls raised (a warning is printed to stderr). Callers
        should treat None as "do not track".
    """
    if not _MLFLOW_AVAILABLE or not mlflow_url:
        return None
    try:
        mlflow.set_tracking_uri(mlflow_url)
        mlflow.set_experiment(experiment)
        return "ready"
    except Exception as e:
        # Tracking is optional: degrade to "no tracking" instead of failing.
        print(f" [warn] MLflow init failed: {e}", file=sys.stderr)
        return None
|
||
|
||
|
||
def run_simulation(
    n_users: int, n_rounds: int, tasks_per_round: int,
    ml_url: str, policies: list[str], use_llm: bool, seed: int,
    mlflow_url: str | None = None, mlflow_experiment: str = "bandit_simulation",
) -> dict:
    """Run the single-pass (rule / llm judge) simulation and return its results.

    Each of ``n_rounds`` rounds draws a random hour (6-22) and weekday (0-6).
    Then, for every simulated user and every policy, a task pool is generated,
    the policy's score endpoint picks a tip, the persona judge reacts to it,
    and the resulting reward is sent back to that same policy. Each
    (user, policy) pair uses its own ML-side user id (``<user>-<policy>``),
    so policies never share per-user state on the ML service.

    Args:
        n_users: Number of simulated users; personas are assigned round-robin.
        n_rounds: Number of rounds to simulate.
        tasks_per_round: Size of each generated candidate pool.
        ml_url: Base URL of the ML scoring service.
        policies: Policy names to compare; names unknown to the endpoint
            tables fall back to the default ``/score`` and ``/reward`` routes.
        use_llm: Judge with the LLM instead of the deterministic rules.
        seed: Seed for round context, judge randomness and task pools.
        mlflow_url: Optional MLflow tracking URI. Tracking is skipped when it
            is falsy, mlflow is not installed, or initialisation fails.
        mlflow_experiment: MLflow experiment name.

    Returns:
        The ``_build_result`` dict plus ``"mlflow_run_id"`` (None if untracked).
    """
    import contextlib  # only needed for the no-tracking fallback context

    mode = "llm" if use_llm else "rule"
    rng = random.Random(seed)
    run_id = str(uuid.uuid4())[:8]
    started_at = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())

    # Track only if initialisation actually succeeded. Checking availability
    # and URL alone would still call start_run() after a failed init.
    tracking = _init_mlflow(mlflow_url, mlflow_experiment) is not None

    # run_id makes ML-side user ids unique per run (presumably so each run
    # starts from fresh per-user policy state on the ML service).
    user_personas = [
        (f"sim-{run_id}-u{i}", PERSONAS[i % len(PERSONAS)])
        for i in range(n_users)
    ]

    acc: dict[str, dict] = {
        p: {
            "total_reward": 0.0, "n_pulls": 0,
            "cumulative_rewards": [],
            "action_counts": {a: 0 for a in ACTIONS},
        }
        for p in policies
    }
    events: list[dict] = []

    # A real ``with`` block (instead of manual __enter__/__exit__(None, None,
    # None)) forwards exceptions to MLflow, so a crashed simulation is
    # recorded as FAILED rather than FINISHED.
    run_ctx = (
        mlflow.start_run(run_name=run_id) if tracking else contextlib.nullcontext()
    )
    with run_ctx as active_run:
        mlflow_run_id: str | None = (
            active_run.info.run_id if active_run is not None else None
        )
        if active_run is not None:
            mlflow.log_params({
                "n_users": n_users,
                "n_rounds": n_rounds,
                "tasks_per_round": tasks_per_round,
                "policies": ",".join(policies),
                "judge": "llm" if use_llm else "rule",
                "seed": seed,
            })

        with httpx.Client(trust_env=False) as client:
            for rnd in range(n_rounds):
                hour = rng.randint(6, 22)
                dow = rng.randint(0, 6)
                round_rewards = {p: 0.0 for p in policies}

                for user_idx, (user_id, persona) in enumerate(user_personas):
                    # Deterministic in (seed, round, user). The previous
                    # ``abs(hash(user_id))`` mixed in the random run_id and
                    # per-process str-hash randomisation, so --seed could not
                    # reproduce the task pools.
                    seed_tasks = rnd * 997 + (seed * 31 + user_idx) % 997
                    tasks = generate_task_pool(n=tasks_per_round, seed=seed_tasks)
                    profile = (
                        persona.profile_features(hour)
                        if hasattr(persona, "profile_features") else None
                    )

                    for policy in policies:
                        p_user = f"{user_id}-{policy}"
                        scored = _call_score(client, ml_url, policy, p_user, tasks,
                                             hour, dow, profile_features=profile)
                        if not scored:
                            continue
                        tip_id = scored.get("tip_id")
                        tip = next((t for t in tasks if t["id"] == tip_id), None)
                        if not tip:
                            continue

                        action, dwell_ms, reward = judge(
                            persona, tip, hour, dow, rng, use_llm=use_llm)
                        _call_reward(client, ml_url, policy, p_user, tip_id, reward, {
                            "hour_of_day": hour,
                            "is_overdue": tip["features"]["is_overdue"],
                            "task_age_days": tip["features"]["task_age_days"],
                            "priority": tip["features"]["priority"],
                        }, day_of_week=dow, profile_features=profile)

                        acc[policy]["total_reward"] += reward
                        acc[policy]["n_pulls"] += 1
                        acc[policy]["action_counts"][action] += 1
                        round_rewards[policy] += reward
                        events.append({
                            "round": rnd, "user_id": user_id, "persona": persona.name,
                            "policy": policy, "tip_content": tip["content"],
                            "priority": tip["features"]["priority"],
                            "is_overdue": tip["features"]["is_overdue"],
                            "action": action, "dwell_ms": dwell_ms, "reward": reward,
                            "hour": hour, "day_of_week": dow,
                        })

                for p in policies:
                    prev = acc[p]["cumulative_rewards"][-1] if acc[p]["cumulative_rewards"] else 0.0
                    acc[p]["cumulative_rewards"].append(prev + round_rewards[p])

                if active_run is not None:
                    for p in policies:
                        mlflow.log_metric(f"{p}_cumulative_reward",
                                          acc[p]["cumulative_rewards"][-1], step=rnd)

                print(f" Round {rnd+1:>3}/{n_rounds} [{mode}] " + " ".join(
                    f"{p}={acc[p]['cumulative_rewards'][-1]:+.2f}" for p in policies
                ))

        result = _build_result(run_id, started_at, policies, acc, events,
                               n_users, n_rounds, tasks_per_round, use_llm, seed)
        result["mlflow_run_id"] = mlflow_run_id

        if active_run is not None:
            for p, s in result["summary"].items():
                mlflow.log_metrics({
                    f"{p}_total_reward": s["total_reward"],
                    f"{p}_mean_reward": s["mean_reward"],
                    f"{p}_n_pulls": s["n_pulls"],
                })
            mlflow.set_tag("winner", result["winner"])

        return result
|
||
|
||
|
||
# ── Claude Code judge — phase 1: score ─────────────────────────────────────
|
||
|
||
def run_score_phase(
    n_users: int, n_rounds: int, tasks_per_round: int,
    ml_url: str, policies: list[str], seed: int, out_path: str,
) -> None:
    """Score all candidates and write judgment requests for Claude Code.

    Phase 1 of the ``claude-code`` judge. For every round, user and policy,
    the policy's score endpoint picks a tip. Each pick is recorded as a
    session in the plan and as a judgment request. Two files are written
    next to ``out_path`` (its trailing ``.json`` stripped to form ``<base>``):

    * ``<base>-plan.json``     — run config, personas and scored sessions
    * ``<base>-requests.json`` — one judgment request per session

    Instructions for writing ``<base>-responses.json`` and running phase 2
    are then printed.

    NOTE: every round is scored before any reward is sent (rewards are only
    applied by run_reward_phase), so the policies receive no feedback
    between rounds during this phase.
    """
    rng = random.Random(seed)
    run_id = str(uuid.uuid4())[:8]
    started_at = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())

    user_personas = [
        (f"sim-{run_id}-u{i}", PERSONAS[i % len(PERSONAS)])
        for i in range(n_users)
    ]

    plan_rounds: list[dict] = []
    judgment_requests: list[dict] = []

    print(f"[Phase 1] Scoring {n_rounds} rounds × {n_users} users × {len(policies)} policies…")

    with httpx.Client(trust_env=False) as client:
        for rnd in range(n_rounds):
            hour = rng.randint(6, 22)
            dow = rng.randint(0, 6)
            round_sessions: list[dict] = []

            for user_idx, (user_id, persona) in enumerate(user_personas):
                # Deterministic in (seed, round, user); ``abs(hash(user_id))``
                # depended on the random run_id and per-process str hashing.
                seed_tasks = rnd * 997 + (seed * 31 + user_idx) % 997
                tasks = generate_task_pool(n=tasks_per_round, seed=seed_tasks)
                profile = (
                    persona.profile_features(hour)
                    if hasattr(persona, "profile_features") else None
                )

                for policy in policies:
                    p_user = f"{user_id}-{policy}"
                    scored = _call_score(client, ml_url, policy, p_user, tasks, hour, dow,
                                         profile_features=profile)
                    if not scored:
                        continue
                    tip_id = scored.get("tip_id")
                    tip = next((t for t in tasks if t["id"] == tip_id), None)
                    if not tip:
                        continue

                    req_id = f"r{rnd}_{user_id.split('-')[-1]}_{policy}"
                    round_sessions.append({
                        "req_id": req_id,
                        "p_user": p_user,
                        "policy": policy,
                        "user_id": user_id,
                        "persona_name": persona.name,
                        "tip_id": tip_id,
                        "tip_features": tip["features"],
                        "tip_content": tip["content"],
                        "ml_score": scored.get("score"),
                        "profile_features": profile,
                    })
                    judgment_requests.append({
                        "id": req_id,
                        "round": rnd,
                        "hour": hour,
                        "day_of_week": dow,
                        "policy": policy,
                        "persona_name": persona.name,
                        "persona_description": persona.description,
                        "tip_content": tip["content"],
                        "priority": tip["features"]["priority"],
                        "is_overdue": tip["features"]["is_overdue"],
                        "age_days": tip["features"]["task_age_days"],
                        "ml_score": scored.get("score"),
                    })

            plan_rounds.append({
                "round": rnd, "hour": hour, "dow": dow,
                "sessions": round_sessions,
            })
            print(f" Round {rnd+1:>3}/{n_rounds}: {len(round_sessions)} sessions scored")

    plan = {
        "run_id": run_id,
        "started_at": started_at,
        "config": {
            "n_users": n_users, "n_rounds": n_rounds,
            "tasks_per_round": tasks_per_round, "policies": policies,
            "use_llm": False, "seed": seed,
        },
        "user_personas": [
            {"user_id": uid, "persona_name": p.name, "persona_description": p.description}
            for uid, p in user_personas
        ],
        "rounds": plan_rounds,
    }

    # Strip only a trailing ".json"; str.replace would also mangle ".json"
    # occurring elsewhere in the path (e.g. a directory name).
    base = out_path[: -len(".json")] if out_path.endswith(".json") else out_path
    plan_path = f"{base}-plan.json"
    requests_path = f"{base}-requests.json"
    responses_path = f"{base}-responses.json"

    Path(plan_path).write_text(json.dumps(plan, indent=2))
    Path(requests_path).write_text(json.dumps(judgment_requests, indent=2))

    print()
    print("=" * 60)
    print(f"Phase 1 complete — {len(judgment_requests)} judgment requests.")
    print()
    print(f" Requests : {requests_path}")
    print(f" Plan : {plan_path}")
    print()
    print('Claude Code: read the requests file, judge each tip for the persona,')
    print(f'then write your responses to: {responses_path}')
    print()
    print('Response format: { "<id>": "<action>" | { "action": "<action>", "dwell_ms": <int> } }')
    print('Valid actions: done | snooze | dismiss')
    print()
    print('For "done", optionally specify dwell_ms (ms between tip appearing and user acting):')
    print(' { "r0_u0_linucb-v1": { "action": "done", "dwell_ms": 45000 } } # magic zone')
    print(' { "r0_u0_linucb-v1": "snooze" } # plain string also ok (uses default 60s dwell for done)')
    print()
    print('Reward is inferred from action + dwell_ms:')
    print(' dismiss → -1.0')
    print(' snooze → 0.1')
    print(' done < 15s → -0.3 (stale task)')
    print(' done 15s–2min → 1.0 (magic!)')
    print(' done 2–10min → 0.6 (good)')
    print(' done > 10min → 0.3 (eventually)')
    print()
    print('Then run Phase 2:')
    print(' python runner.py --judge claude-code --phase reward \\')
    print(f' --plan {plan_path} --out {out_path}')
|
||
|
||
|
||
# ── Claude Code judge — phase 2: reward ────────────────────────────────────
|
||
|
||
def run_reward_phase(plan_path: str, out_path: str, ml_url: str) -> dict:
    """Apply Claude Code judgments, send reward signals, compute metrics.

    Phase 2 of the ``claude-code`` judge. Reads the plan written by
    run_score_phase and the sibling ``<base>-responses.json``, where
    ``<base>`` is ``plan_path`` without its trailing ``-plan.json``. Each
    judgment is converted to a reward via ``infer_reward`` and sent to the
    policy that chose the tip. The result dict is written to ``out_path``
    and returned.

    Responses may be ``{id: "action"}`` (dwell defaults to 60 s) or
    ``{id: {"action": ..., "dwell_ms": ...}}``. Sessions without a response
    default to ``"snooze"`` with a 10 s dwell.

    Exits the process with status 1 if the responses file is missing, is not
    a JSON object, or contains malformed entries or invalid actions.
    """
    plan = json.loads(Path(plan_path).read_text())
    plan_suffix = "-plan.json"
    # Strip only the trailing suffix; str.replace would also rewrite an
    # occurrence elsewhere in the path.
    base = plan_path[: -len(plan_suffix)] if plan_path.endswith(plan_suffix) else plan_path
    responses_path = f"{base}-responses.json"

    if not Path(responses_path).exists():
        print(f"ERROR: responses file not found: {responses_path}", file=sys.stderr)
        sys.exit(1)

    raw_responses = json.loads(Path(responses_path).read_text())
    if not isinstance(raw_responses, dict):
        print(f"ERROR: responses file must contain a JSON object keyed by request id: "
              f"{responses_path}", file=sys.stderr)
        sys.exit(1)

    # Responses can be either { id: "action" } or { id: { action, dwell_ms } }
    def _parse_response(v) -> tuple[str, int]:
        if isinstance(v, dict):
            return str(v["action"]), int(v.get("dwell_ms", 60_000))
        return str(v), 60_000  # plain string → assume 60s dwell for "done"

    responses: dict[str, tuple[str, int]] = {}
    malformed: dict[str, str] = {}
    for key, value in raw_responses.items():
        try:
            responses[key] = _parse_response(value)
        except (KeyError, TypeError, ValueError) as exc:
            # e.g. an object without "action", or a non-numeric dwell_ms
            malformed[key] = f"{type(exc).__name__}: {exc}"
    if malformed:
        print(f"ERROR: malformed responses: {malformed}", file=sys.stderr)
        sys.exit(1)

    invalid = {k: v[0] for k, v in responses.items() if v[0] not in ACTIONS}
    if invalid:
        print(f"ERROR: invalid actions in responses: {invalid}", file=sys.stderr)
        sys.exit(1)

    # Ids that match no scored session are ignored; surface them because a
    # typo'd id otherwise shows up only as a "no response" default.
    known_ids = {s["req_id"] for r in plan["rounds"] for s in r["sessions"]}
    unknown_ids = sorted(set(responses) - known_ids)
    if unknown_ids:
        print(f" [warn] {len(unknown_ids)} responses match no request and are ignored: {unknown_ids}")

    policies: list[str] = plan["config"]["policies"]
    acc: dict[str, dict] = {
        p: {
            "total_reward": 0.0, "n_pulls": 0,
            "cumulative_rewards": [],
            "action_counts": {a: 0 for a in ACTIONS},
        }
        for p in policies
    }
    events: list[dict] = []
    persona_map = {u["user_id"]: u["persona_name"] for u in plan["user_personas"]}
    missing_responses = 0

    print(f"[Phase 2] Applying {len(responses)} judgments → reward calls…")

    with httpx.Client(trust_env=False) as client:
        for rnd_data in plan["rounds"]:
            rnd = rnd_data["round"]
            round_rewards = {p: 0.0 for p in policies}

            for session in rnd_data["sessions"]:
                req_id = session["req_id"]
                resp = responses.get(req_id)
                if resp is None:
                    print(f" [warn] no response for {req_id}, defaulting to snooze")
                    action, dwell_ms = "snooze", 10_000
                    missing_responses += 1
                else:
                    action, dwell_ms = resp

                reward = infer_reward(action, dwell_ms)
                tip_features = session["tip_features"]
                # Send exactly the feature set used at scoring time (and by
                # run_simulation). Spreading all of tip_features would forward
                # any extra keys and could even override hour_of_day.
                _call_reward(
                    client, ml_url, session["policy"], session["p_user"],
                    session["tip_id"], reward,
                    {
                        "hour_of_day": rnd_data["hour"],
                        "is_overdue": tip_features["is_overdue"],
                        "task_age_days": tip_features["task_age_days"],
                        "priority": tip_features["priority"],
                    },
                    day_of_week=rnd_data["dow"],
                    profile_features=session.get("profile_features"),
                )

                p = session["policy"]
                acc[p]["total_reward"] += reward
                acc[p]["n_pulls"] += 1
                acc[p]["action_counts"][action] += 1
                round_rewards[p] += reward

                events.append({
                    "round": rnd,
                    "user_id": session["user_id"],
                    "persona": persona_map.get(session["user_id"], "?"),
                    "policy": p,
                    "tip_content": session["tip_content"],
                    "priority": tip_features["priority"],
                    "is_overdue": tip_features["is_overdue"],
                    "action": action,
                    "dwell_ms": dwell_ms,
                    "reward": reward,
                    "hour": rnd_data["hour"],
                    "day_of_week": rnd_data["dow"],
                })

            for p in policies:
                prev = acc[p]["cumulative_rewards"][-1] if acc[p]["cumulative_rewards"] else 0.0
                acc[p]["cumulative_rewards"].append(prev + round_rewards[p])

            print(f" Round {rnd+1:>3}/{plan['config']['n_rounds']} [cc] " + " ".join(
                f"{p}={acc[p]['cumulative_rewards'][-1]:+.2f}" for p in policies
            ))

    if missing_responses:
        print(f" [warn] {missing_responses} requests had no response (defaulted to snooze)")

    cfg = plan["config"]
    result = _build_result(
        plan["run_id"], plan["started_at"], policies, acc, events,
        cfg["n_users"], cfg["n_rounds"], cfg["tasks_per_round"],
        use_llm=False, seed=cfg["seed"],
    )
    result["judge_mode"] = "claude-code"
    Path(out_path).write_text(json.dumps(result, indent=2))
    return result
|
||
|
||
|
||
# ── Shared result builder ───────────────────────────────────────────────────
|
||
|
||
def _build_result(
|
||
run_id: str, started_at: str, policies: list[str],
|
||
acc: dict, events: list[dict],
|
||
n_users: int, n_rounds: int, tasks_per_round: int,
|
||
use_llm: bool, seed: int,
|
||
) -> dict:
|
||
summary = {
|
||
p: {
|
||
"total_reward": acc[p]["total_reward"],
|
||
"mean_reward": (
|
||
acc[p]["total_reward"] / acc[p]["n_pulls"]
|
||
if acc[p]["n_pulls"] > 0 else 0.0
|
||
),
|
||
"n_pulls": acc[p]["n_pulls"],
|
||
"cumulative_rewards": acc[p]["cumulative_rewards"],
|
||
"action_counts": acc[p]["action_counts"],
|
||
}
|
||
for p in policies
|
||
}
|
||
winner = max(policies, key=lambda p: summary[p]["total_reward"])
|
||
|
||
persona_breakdown: dict[str, dict] = {}
|
||
for ev in events:
|
||
pname = ev["persona"]
|
||
pol = ev["policy"]
|
||
persona_breakdown.setdefault(pname, {}).setdefault(pol, {"reward": 0.0, "n": 0})
|
||
persona_breakdown[pname][pol]["reward"] += ev["reward"]
|
||
persona_breakdown[pname][pol]["n"] += 1
|
||
|
||
return {
|
||
"run_id": run_id,
|
||
"started_at": started_at,
|
||
"finished_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
|
||
"config": {
|
||
"n_users": n_users, "n_rounds": n_rounds,
|
||
"tasks_per_round": tasks_per_round, "policies": policies,
|
||
"use_llm": use_llm, "seed": seed,
|
||
},
|
||
"summary": summary,
|
||
"winner": winner,
|
||
"persona_breakdown": persona_breakdown,
|
||
"events": events,
|
||
}
|
||
|
||
|
||
# ── CLI ─────────────────────────────────────────────────────────────────────
|
||
|
||
if __name__ == "__main__":
|
||
parser = argparse.ArgumentParser(description="oO simulation runner")
|
||
parser.add_argument("--judge", choices=["rule", "llm", "claude-code"], default="rule")
|
||
parser.add_argument("--phase", choices=["score", "reward"], default=None,
|
||
help="For --judge claude-code only")
|
||
parser.add_argument("--plan", default=None,
|
||
help="Plan file path (for --judge claude-code --phase reward)")
|
||
parser.add_argument("--n-users", type=int, default=5)
|
||
parser.add_argument("--n-rounds", type=int, default=20)
|
||
parser.add_argument("--tasks-per-round", type=int, default=8)
|
||
parser.add_argument("--ml-url", default="http://localhost:5001")
|
||
parser.add_argument("--policies", nargs="+", default=["linucb-v1", "egreedy-v1"])
|
||
parser.add_argument("--no-llm", action="store_true",
|
||
help="Alias for --judge rule (backwards compat)")
|
||
parser.add_argument("--seed", type=int, default=42)
|
||
parser.add_argument("--out", default=None)
|
||
parser.add_argument("--mlflow-url", default=os.environ.get("MLFLOW_TRACKING_URI"),
|
||
help="MLflow tracking URI (e.g. http://mlflow:5000/mlflow)")
|
||
parser.add_argument("--mlflow-experiment", default="bandit_simulation")
|
||
args = parser.parse_args()
|
||
|
||
if args.no_llm:
|
||
args.judge = "rule"
|
||
|
||
out_path = args.out or f"/tmp/oo-sim-{int(time.time())}.json"
|
||
|
||
if args.judge == "claude-code":
|
||
if args.phase == "score":
|
||
run_score_phase(
|
||
n_users=args.n_users, n_rounds=args.n_rounds,
|
||
tasks_per_round=args.tasks_per_round, ml_url=args.ml_url,
|
||
policies=args.policies, seed=args.seed, out_path=out_path,
|
||
)
|
||
elif args.phase == "reward":
|
||
if not args.plan:
|
||
print("ERROR: --plan is required for --phase reward", file=sys.stderr)
|
||
sys.exit(1)
|
||
result = run_reward_phase(args.plan, out_path, args.ml_url)
|
||
print()
|
||
print(f"Winner : {result['winner']}")
|
||
for p, s in result["summary"].items():
|
||
print(f" {p:20s} total={s['total_reward']:+.2f} mean={s['mean_reward']:+.4f} pulls={s['n_pulls']}")
|
||
print(f"Results: {out_path}")
|
||
else:
|
||
print("ERROR: --judge claude-code requires --phase score or --phase reward",
|
||
file=sys.stderr)
|
||
sys.exit(1)
|
||
else:
|
||
use_llm = (args.judge == "llm")
|
||
print(f"oO simulation: {args.n_users} users × {args.n_rounds} rounds")
|
||
print(f"Policies : {args.policies}")
|
||
print(f"ML URL : {args.ml_url}")
|
||
print(f"Judge : {args.judge}")
|
||
print()
|
||
|
||
result = run_simulation(
|
||
n_users=args.n_users, n_rounds=args.n_rounds,
|
||
tasks_per_round=args.tasks_per_round, ml_url=args.ml_url,
|
||
policies=args.policies, use_llm=use_llm, seed=args.seed,
|
||
mlflow_url=args.mlflow_url, mlflow_experiment=args.mlflow_experiment,
|
||
)
|
||
Path(out_path).write_text(json.dumps(result, indent=2))
|
||
print()
|
||
print(f"Winner : {result['winner']}")
|
||
for p, s in result["summary"].items():
|
||
print(f" {p:20s} total={s['total_reward']:+.2f} mean={s['mean_reward']:+.4f} pulls={s['n_pulls']}")
|
||
print(f"Results: {out_path}")
|