feat(simulate): MLflow tracking, Airflow DAG integration, health checks for mlflow/airflow
- sim_runs schema: add judge_mode, n_policies, airflow_dag_run_id, mlflow_run_id columns - admin health endpoint: add mlflow + airflow checks (Basic auth for Airflow API) - admin nav: add Simulations page link; rename section label - runner.py: optional MLflow experiment tracking; multi-policy support - sim_dag.py: Airflow DAG for offline sim pipeline - admin simulate page + API client methods for sim runs - shared-types tsconfig: exclude test files from build Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
220
apps/admin/src/app/simulate/page.tsx
Normal file
220
apps/admin/src/app/simulate/page.tsx
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
'use client';
|
||||||
|
|
||||||
|
import { useEffect, useState } from 'react';
|
||||||
|
import { AdminShell } from '@/components/AdminShell';
|
||||||
|
import {
|
||||||
|
startSimulation,
|
||||||
|
getSimulationRuns,
|
||||||
|
getSimulationRun,
|
||||||
|
SimRun,
|
||||||
|
} from '@/lib/api';
|
||||||
|
|
||||||
|
const POLICIES = ['linucb-v1', 'egreedy-v1', 'egreedy-v2'];
|
||||||
|
const mlflowBase = process.env.NEXT_PUBLIC_MLFLOW_URL ?? '/mlflow';
|
||||||
|
const airflowBase = process.env.NEXT_PUBLIC_AIRFLOW_URL ?? '/airflow';
|
||||||
|
|
||||||
|
function mlflowRunUrl(runId: string) {
|
||||||
|
return `${mlflowBase}/#/experiments/1/runs/${runId}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
function airflowRunUrl(dagRunId: string) {
|
||||||
|
return `${airflowBase}/dags/bandit_sim/grid?dag_run_id=${encodeURIComponent(dagRunId)}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
function StatusBadge({ status }: { status: string }) {
|
||||||
|
const cls: Record<string, string> = {
|
||||||
|
running: 'bg-blue-900 text-blue-300 border-blue-800',
|
||||||
|
done: 'bg-green-900 text-green-300 border-green-800',
|
||||||
|
failed: 'bg-red-900 text-red-300 border-red-800',
|
||||||
|
pending: 'bg-gray-800 text-gray-400 border-gray-700',
|
||||||
|
};
|
||||||
|
return (
|
||||||
|
<span className={`text-xs px-2 py-0.5 rounded border ${cls[status] ?? cls.pending}`}>
|
||||||
|
{status}
|
||||||
|
</span>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
function SummaryRow({ run }: { run: SimRun }) {
|
||||||
|
const summary = run.summaryJson ? JSON.parse(run.summaryJson) as Record<string, { total_reward: number; mean_reward: number; n_pulls: number }> : null;
|
||||||
|
return (
|
||||||
|
<div className="bg-gray-900 border border-gray-800 rounded p-4 space-y-2">
|
||||||
|
<div className="flex items-center justify-between">
|
||||||
|
<div className="space-y-0.5">
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
<span className="font-mono text-xs text-gray-500">{run.id}</span>
|
||||||
|
<StatusBadge status={run.status} />
|
||||||
|
{run.winner && <span className="text-xs text-indigo-400">winner: {run.winner}</span>}
|
||||||
|
</div>
|
||||||
|
<div className="text-xs text-gray-600">
|
||||||
|
{run.nUsers}u × {run.nRounds}r × {run.tasksPerRound}t/r — {run.judgeMode} judge
|
||||||
|
{' · '}{new Date(run.createdAt).toLocaleString()}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div className="flex items-center gap-2 flex-shrink-0">
|
||||||
|
{run.mlflowRunId && (
|
||||||
|
<a href={mlflowRunUrl(run.mlflowRunId)} target="_blank" rel="noreferrer"
|
||||||
|
className="text-xs text-indigo-400 hover:underline">MLflow ↗</a>
|
||||||
|
)}
|
||||||
|
{run.airflowDagRunId && (
|
||||||
|
<a href={airflowRunUrl(run.airflowDagRunId)} target="_blank" rel="noreferrer"
|
||||||
|
className="text-xs text-indigo-400 hover:underline">Airflow ↗</a>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{summary && (
|
||||||
|
<div className="grid grid-cols-2 gap-2 pt-1 lg:grid-cols-3">
|
||||||
|
{Object.entries(summary).map(([policy, s]) => (
|
||||||
|
<div key={policy} className={`rounded border p-2 text-xs ${policy === run.winner ? 'border-indigo-700 bg-indigo-950' : 'border-gray-800'}`}>
|
||||||
|
<div className="font-mono font-medium text-gray-300 mb-1">{policy}</div>
|
||||||
|
<div className="text-gray-500 space-y-0.5">
|
||||||
|
<div>total <span className="text-gray-300">{s.total_reward.toFixed(2)}</span></div>
|
||||||
|
<div>mean <span className="text-gray-300">{s.mean_reward.toFixed(4)}</span></div>
|
||||||
|
<div>pulls <span className="text-gray-300">{s.n_pulls}</span></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function SimulatePage() {
|
||||||
|
const [runs, setRuns] = useState<SimRun[]>([]);
|
||||||
|
const [loading, setLoading] = useState(true);
|
||||||
|
const [launching, setLaunching] = useState(false);
|
||||||
|
const [error, setError] = useState('');
|
||||||
|
const [msg, setMsg] = useState('');
|
||||||
|
|
||||||
|
const [nUsers, setNUsers] = useState(5);
|
||||||
|
const [nRounds, setNRounds] = useState(20);
|
||||||
|
const [tasksPerRound, setTasksPerRound] = useState(8);
|
||||||
|
const [judgeMode, setJudgeMode] = useState<'rule' | 'llm'>('rule');
|
||||||
|
const [selectedPolicies, setSelectedPolicies] = useState<string[]>(['linucb-v1', 'egreedy-v1']);
|
||||||
|
|
||||||
|
const refresh = () =>
|
||||||
|
getSimulationRuns()
|
||||||
|
.then((r) => setRuns(r.runs))
|
||||||
|
.catch((e) => setError(e.message))
|
||||||
|
.finally(() => setLoading(false));
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
refresh();
|
||||||
|
const t = setInterval(refresh, 8_000);
|
||||||
|
return () => clearInterval(t);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const togglePolicy = (p: string) =>
|
||||||
|
setSelectedPolicies((prev) =>
|
||||||
|
prev.includes(p) ? prev.filter((x) => x !== p) : [...prev, p],
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleLaunch = async () => {
|
||||||
|
if (selectedPolicies.length < 2) { setError('Select at least 2 policies.'); return; }
|
||||||
|
setLaunching(true); setError(''); setMsg('');
|
||||||
|
try {
|
||||||
|
const r = await startSimulation({ nUsers, nRounds, tasksPerRound, judgeMode, policies: selectedPolicies });
|
||||||
|
setMsg(r.airflow_dag_run_id
|
||||||
|
? `Launched via Airflow — dag_run_id: ${r.airflow_dag_run_id}`
|
||||||
|
: `Launched locally — run id: ${r.id}`);
|
||||||
|
await refresh();
|
||||||
|
} catch (e: unknown) {
|
||||||
|
setError((e as Error).message);
|
||||||
|
} finally {
|
||||||
|
setLaunching(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<AdminShell>
|
||||||
|
<div className="space-y-8 max-w-4xl">
|
||||||
|
<h1 className="text-xl font-semibold">Simulations</h1>
|
||||||
|
{error && <p className="text-red-400 text-sm">{error}</p>}
|
||||||
|
{msg && <p className="text-green-400 text-sm">{msg}</p>}
|
||||||
|
|
||||||
|
{/* Launch form */}
|
||||||
|
<section className="bg-gray-900 border border-gray-800 rounded p-5 space-y-4">
|
||||||
|
<h2 className="text-base font-medium text-gray-300">New simulation</h2>
|
||||||
|
|
||||||
|
<div className="grid grid-cols-3 gap-4 text-sm">
|
||||||
|
<label className="space-y-1">
|
||||||
|
<span className="text-gray-500">Users</span>
|
||||||
|
<input type="number" min={1} max={50} value={nUsers}
|
||||||
|
onChange={(e) => setNUsers(Number(e.target.value))}
|
||||||
|
className="w-full bg-gray-950 border border-gray-700 rounded px-2 py-1 text-gray-300" />
|
||||||
|
</label>
|
||||||
|
<label className="space-y-1">
|
||||||
|
<span className="text-gray-500">Rounds</span>
|
||||||
|
<input type="number" min={1} max={200} value={nRounds}
|
||||||
|
onChange={(e) => setNRounds(Number(e.target.value))}
|
||||||
|
className="w-full bg-gray-950 border border-gray-700 rounded px-2 py-1 text-gray-300" />
|
||||||
|
</label>
|
||||||
|
<label className="space-y-1">
|
||||||
|
<span className="text-gray-500">Tasks/round</span>
|
||||||
|
<input type="number" min={1} max={20} value={tasksPerRound}
|
||||||
|
onChange={(e) => setTasksPerRound(Number(e.target.value))}
|
||||||
|
className="w-full bg-gray-950 border border-gray-700 rounded px-2 py-1 text-gray-300" />
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="space-y-1 text-sm">
|
||||||
|
<span className="text-gray-500">Policies (select ≥ 2)</span>
|
||||||
|
<div className="flex gap-2 flex-wrap pt-1">
|
||||||
|
{POLICIES.map((p) => (
|
||||||
|
<button key={p} onClick={() => togglePolicy(p)}
|
||||||
|
className={`px-3 py-1 rounded border text-xs font-mono ${
|
||||||
|
selectedPolicies.includes(p)
|
||||||
|
? 'bg-indigo-900 border-indigo-700 text-indigo-200'
|
||||||
|
: 'border-gray-700 text-gray-500 hover:border-gray-500'
|
||||||
|
}`}>
|
||||||
|
{p}
|
||||||
|
</button>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="space-y-1 text-sm">
|
||||||
|
<span className="text-gray-500">Judge</span>
|
||||||
|
<div className="flex gap-2 pt-1">
|
||||||
|
{(['rule', 'llm'] as const).map((m) => (
|
||||||
|
<button key={m} onClick={() => setJudgeMode(m)}
|
||||||
|
className={`px-3 py-1 rounded border text-xs ${
|
||||||
|
judgeMode === m
|
||||||
|
? 'bg-gray-700 border-gray-500 text-white'
|
||||||
|
: 'border-gray-700 text-gray-500 hover:border-gray-500'
|
||||||
|
}`}>
|
||||||
|
{m}
|
||||||
|
</button>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
{judgeMode === 'llm' && (
|
||||||
|
<p className="text-xs text-yellow-600 mt-1">LLM judge requires ANTHROPIC_API_KEY in ml/serving env.</p>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<button onClick={handleLaunch} disabled={launching}
|
||||||
|
className="bg-indigo-600 hover:bg-indigo-500 disabled:opacity-50 text-white rounded px-4 py-2 text-sm">
|
||||||
|
{launching ? 'Launching…' : 'Launch simulation'}
|
||||||
|
</button>
|
||||||
|
<p className="text-xs text-gray-600">
|
||||||
|
Runs via <a href={airflowBase} target="_blank" rel="noreferrer" className="text-indigo-500 hover:underline">Airflow</a> (mlops profile) when available; falls back to local subprocess.
|
||||||
|
Results logged to <a href={mlflowBase} target="_blank" rel="noreferrer" className="text-indigo-500 hover:underline">MLflow</a>.
|
||||||
|
</p>
|
||||||
|
</section>
|
||||||
|
|
||||||
|
{/* Run history */}
|
||||||
|
<section className="space-y-3">
|
||||||
|
<h2 className="text-base font-medium text-gray-300">
|
||||||
|
Run history
|
||||||
|
{loading && <span className="text-xs text-gray-600 ml-2">loading…</span>}
|
||||||
|
</h2>
|
||||||
|
{runs.length === 0 && !loading && (
|
||||||
|
<p className="text-gray-600 text-sm">No simulations yet.</p>
|
||||||
|
)}
|
||||||
|
{runs.map((r) => <SummaryRow key={r.id} run={r} />)}
|
||||||
|
</section>
|
||||||
|
</div>
|
||||||
|
</AdminShell>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import Link from 'next/link';
|
import Link from 'next/link';
|
||||||
import { usePathname } from 'next/navigation';
|
import { usePathname } from 'next/navigation';
|
||||||
|
import { useEffect, useState } from 'react';
|
||||||
|
|
||||||
const mlflowUrl = process.env.NEXT_PUBLIC_MLFLOW_URL ?? '/mlflow';
|
const mlflowUrl = process.env.NEXT_PUBLIC_MLFLOW_URL ?? '/mlflow';
|
||||||
const airflowUrl = process.env.NEXT_PUBLIC_AIRFLOW_URL ?? '/airflow';
|
const airflowUrl = process.env.NEXT_PUBLIC_AIRFLOW_URL ?? '/airflow';
|
||||||
@@ -10,6 +11,7 @@ type NavItem = {
|
|||||||
href: string;
|
href: string;
|
||||||
label: string;
|
label: string;
|
||||||
external?: boolean;
|
external?: boolean;
|
||||||
|
svcName?: string; // key in the health services map
|
||||||
};
|
};
|
||||||
|
|
||||||
type NavSection = {
|
type NavSection = {
|
||||||
@@ -31,10 +33,11 @@ const NAV: NavSection[] = [
|
|||||||
],
|
],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
label: 'Recommender status',
|
label: 'Recommender',
|
||||||
items: [
|
items: [
|
||||||
{ href: '/tips', label: 'Tips' },
|
{ href: '/tips', label: 'Tips' },
|
||||||
{ href: '/reward-analytics', label: 'Rewards' },
|
{ href: '/reward-analytics', label: 'Rewards' },
|
||||||
|
{ href: '/simulate', label: 'Simulations' },
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -50,14 +53,33 @@ const NAV: NavSection[] = [
|
|||||||
label: 'Resources',
|
label: 'Resources',
|
||||||
items: [
|
items: [
|
||||||
{ href: '/docs', label: 'Docs' },
|
{ href: '/docs', label: 'Docs' },
|
||||||
{ href: mlflowUrl, label: 'MLflow ↗', external: true },
|
{ href: mlflowUrl, label: 'MLflow ↗', external: true, svcName: 'mlflow' },
|
||||||
{ href: airflowUrl, label: 'Airflow ↗', external: true },
|
{ href: airflowUrl, label: 'Airflow ↗', external: true, svcName: 'airflow' },
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
|
|
||||||
|
const STATUS_DOT: Record<string, string> = {
|
||||||
|
ok: 'bg-green-500',
|
||||||
|
degraded: 'bg-yellow-400',
|
||||||
|
down: 'bg-red-500',
|
||||||
|
};
|
||||||
|
|
||||||
export function AdminShell({ children }: { children: React.ReactNode }) {
|
export function AdminShell({ children }: { children: React.ReactNode }) {
|
||||||
const pathname = usePathname();
|
const pathname = usePathname();
|
||||||
|
const [svcStatus, setSvcStatus] = useState<Record<string, string>>({});
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
fetch('/api/admin/health', { credentials: 'include' })
|
||||||
|
.then((r) => r.json())
|
||||||
|
.then((data: { services?: { name: string; status: string }[] }) => {
|
||||||
|
const map: Record<string, string> = {};
|
||||||
|
for (const s of data.services ?? []) map[s.name] = s.status;
|
||||||
|
setSvcStatus(map);
|
||||||
|
})
|
||||||
|
.catch(() => {});
|
||||||
|
}, []);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="flex min-h-screen">
|
<div className="flex min-h-screen">
|
||||||
{/* Sidebar */}
|
{/* Sidebar */}
|
||||||
@@ -83,13 +105,19 @@ export function AdminShell({ children }: { children: React.ReactNode }) {
|
|||||||
const active =
|
const active =
|
||||||
!item.external &&
|
!item.external &&
|
||||||
(item.href === '/' ? pathname === '/' : pathname.startsWith(item.href));
|
(item.href === '/' ? pathname === '/' : pathname.startsWith(item.href));
|
||||||
const className = `flex items-center px-3 py-2 rounded text-sm transition-colors ${
|
const className = `flex items-center gap-2 px-3 py-2 rounded text-sm transition-colors ${
|
||||||
active
|
active
|
||||||
? 'bg-gray-800 text-white font-medium'
|
? 'bg-gray-800 text-white font-medium'
|
||||||
: item.external
|
: item.external
|
||||||
? 'text-gray-500 hover:text-white hover:bg-gray-900'
|
? 'text-gray-500 hover:text-white hover:bg-gray-900'
|
||||||
: 'text-gray-400 hover:text-white hover:bg-gray-900'
|
: 'text-gray-400 hover:text-white hover:bg-gray-900'
|
||||||
}`;
|
}`;
|
||||||
|
const dot = item.svcName
|
||||||
|
? svcStatus[item.svcName]
|
||||||
|
? <span className={`inline-block w-1.5 h-1.5 rounded-full flex-shrink-0 ${STATUS_DOT[svcStatus[item.svcName]] ?? STATUS_DOT.down}`} />
|
||||||
|
: <span className="inline-block w-1.5 h-1.5 rounded-full flex-shrink-0 bg-gray-700" />
|
||||||
|
: null;
|
||||||
|
|
||||||
return item.external ? (
|
return item.external ? (
|
||||||
<a
|
<a
|
||||||
key={item.href}
|
key={item.href}
|
||||||
@@ -98,6 +126,7 @@ export function AdminShell({ children }: { children: React.ReactNode }) {
|
|||||||
rel="noreferrer"
|
rel="noreferrer"
|
||||||
className={className}
|
className={className}
|
||||||
>
|
>
|
||||||
|
{dot}
|
||||||
{item.label}
|
{item.label}
|
||||||
</a>
|
</a>
|
||||||
) : (
|
) : (
|
||||||
|
|||||||
@@ -262,3 +262,49 @@ export function saveQuery(name: string, querySql: string) {
|
|||||||
export function deleteSavedQuery(id: string) {
|
export function deleteSavedQuery(id: string) {
|
||||||
return apiFetch<{ ok: boolean }>(`/admin/saved-queries/${id}`, { method: 'DELETE' });
|
return apiFetch<{ ok: boolean }>(`/admin/saved-queries/${id}`, { method: 'DELETE' });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── Simulations ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
export interface SimRun {
|
||||||
|
id: string;
|
||||||
|
policyA: string;
|
||||||
|
policyB: string;
|
||||||
|
nUsers: number;
|
||||||
|
nRounds: number;
|
||||||
|
tasksPerRound: number;
|
||||||
|
judgeMode: string;
|
||||||
|
nPolicies: number;
|
||||||
|
status: 'pending' | 'running' | 'done' | 'failed';
|
||||||
|
summaryJson: string | null;
|
||||||
|
winner: string | null;
|
||||||
|
personaBreakdownJson: string | null;
|
||||||
|
airflowDagRunId: string | null;
|
||||||
|
mlflowRunId: string | null;
|
||||||
|
createdAt: string;
|
||||||
|
finishedAt: string | null;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface SimStartRequest {
|
||||||
|
nUsers?: number;
|
||||||
|
nRounds?: number;
|
||||||
|
tasksPerRound?: number;
|
||||||
|
judgeMode?: 'rule' | 'llm';
|
||||||
|
policies?: string[];
|
||||||
|
}
|
||||||
|
|
||||||
|
export function startSimulation(req: SimStartRequest) {
|
||||||
|
return apiFetch<{ id: string; status: string; airflow_dag_run_id?: string }>(
|
||||||
|
'/admin/simulate/start',
|
||||||
|
{ method: 'POST', body: JSON.stringify(req) },
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getSimulationRuns() {
|
||||||
|
return apiFetch<{ runs: SimRun[] }>('/admin/simulate/runs');
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getSimulationRun(id: string) {
|
||||||
|
return apiFetch<{ run: SimRun & { isRunning: boolean }; events: unknown[] }>(
|
||||||
|
`/admin/simulate/${id}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -26,6 +26,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import random
|
import random
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
@@ -40,6 +41,12 @@ from llm_judge import ACTIONS, infer_reward, judge
|
|||||||
from personas import PERSONAS, Persona
|
from personas import PERSONAS, Persona
|
||||||
from task_generator import generate_task_pool
|
from task_generator import generate_task_pool
|
||||||
|
|
||||||
|
try:
|
||||||
|
import mlflow
|
||||||
|
_MLFLOW_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
_MLFLOW_AVAILABLE = False
|
||||||
|
|
||||||
POLICY_SCORE_ENDPOINTS: dict[str, str] = {
|
POLICY_SCORE_ENDPOINTS: dict[str, str] = {
|
||||||
"linucb-v1": "/score",
|
"linucb-v1": "/score",
|
||||||
"egreedy-v1": "/score/egreedy",
|
"egreedy-v1": "/score/egreedy",
|
||||||
@@ -107,14 +114,30 @@ def _call_reward(
|
|||||||
|
|
||||||
# ── Standard single-pass runner (rule / llm modes) ─────────────────────────
|
# ── Standard single-pass runner (rule / llm modes) ─────────────────────────
|
||||||
|
|
||||||
|
def _init_mlflow(mlflow_url: str | None, experiment: str) -> str | None:
|
||||||
|
"""Set up MLflow tracking and return the active run_id, or None if unavailable."""
|
||||||
|
if not _MLFLOW_AVAILABLE or not mlflow_url:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
mlflow.set_tracking_uri(mlflow_url)
|
||||||
|
mlflow.set_experiment(experiment)
|
||||||
|
return "ready"
|
||||||
|
except Exception as e:
|
||||||
|
print(f" [warn] MLflow init failed: {e}", file=sys.stderr)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def run_simulation(
|
def run_simulation(
|
||||||
n_users: int, n_rounds: int, tasks_per_round: int,
|
n_users: int, n_rounds: int, tasks_per_round: int,
|
||||||
ml_url: str, policies: list[str], use_llm: bool, seed: int,
|
ml_url: str, policies: list[str], use_llm: bool, seed: int,
|
||||||
|
mlflow_url: str | None = None, mlflow_experiment: str = "bandit_simulation",
|
||||||
) -> dict:
|
) -> dict:
|
||||||
rng = random.Random(seed)
|
rng = random.Random(seed)
|
||||||
run_id = str(uuid.uuid4())[:8]
|
run_id = str(uuid.uuid4())[:8]
|
||||||
started_at = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
|
started_at = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
|
||||||
|
|
||||||
|
_init_mlflow(mlflow_url, mlflow_experiment)
|
||||||
|
|
||||||
user_personas = [
|
user_personas = [
|
||||||
(f"sim-{run_id}-u{i}", PERSONAS[i % len(PERSONAS)])
|
(f"sim-{run_id}-u{i}", PERSONAS[i % len(PERSONAS)])
|
||||||
for i in range(n_users)
|
for i in range(n_users)
|
||||||
@@ -130,6 +153,26 @@ def run_simulation(
|
|||||||
}
|
}
|
||||||
events: list[dict] = []
|
events: list[dict] = []
|
||||||
|
|
||||||
|
mlflow_run_id: str | None = None
|
||||||
|
mlflow_ctx = (
|
||||||
|
mlflow.start_run(run_name=run_id)
|
||||||
|
if (_MLFLOW_AVAILABLE and mlflow_url)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if mlflow_ctx:
|
||||||
|
active = mlflow_ctx.__enter__()
|
||||||
|
mlflow_run_id = active.info.run_id
|
||||||
|
mlflow.log_params({
|
||||||
|
"n_users": n_users,
|
||||||
|
"n_rounds": n_rounds,
|
||||||
|
"tasks_per_round": tasks_per_round,
|
||||||
|
"policies": ",".join(policies),
|
||||||
|
"judge": "llm" if use_llm else "rule",
|
||||||
|
"seed": seed,
|
||||||
|
})
|
||||||
|
|
||||||
with httpx.Client(trust_env=False) as client:
|
with httpx.Client(trust_env=False) as client:
|
||||||
for rnd in range(n_rounds):
|
for rnd in range(n_rounds):
|
||||||
hour = rng.randint(6, 22)
|
hour = rng.randint(6, 22)
|
||||||
@@ -139,8 +182,6 @@ def run_simulation(
|
|||||||
for user_id, persona in user_personas:
|
for user_id, persona in user_personas:
|
||||||
seed_tasks = rnd * 997 + abs(hash(user_id)) % 997
|
seed_tasks = rnd * 997 + abs(hash(user_id)) % 997
|
||||||
tasks = generate_task_pool(n=tasks_per_round, seed=seed_tasks)
|
tasks = generate_task_pool(n=tasks_per_round, seed=seed_tasks)
|
||||||
|
|
||||||
# Per-persona profile features for v2 (synthetic for sim — see ADR-0012)
|
|
||||||
profile = persona.profile_features(hour) if hasattr(persona, "profile_features") else None
|
profile = persona.profile_features(hour) if hasattr(persona, "profile_features") else None
|
||||||
|
|
||||||
for policy in policies:
|
for policy in policies:
|
||||||
@@ -179,13 +220,34 @@ def run_simulation(
|
|||||||
prev = acc[p]["cumulative_rewards"][-1] if acc[p]["cumulative_rewards"] else 0.0
|
prev = acc[p]["cumulative_rewards"][-1] if acc[p]["cumulative_rewards"] else 0.0
|
||||||
acc[p]["cumulative_rewards"].append(prev + round_rewards[p])
|
acc[p]["cumulative_rewards"].append(prev + round_rewards[p])
|
||||||
|
|
||||||
|
if mlflow_ctx:
|
||||||
|
for p in policies:
|
||||||
|
mlflow.log_metric(f"{p}_cumulative_reward",
|
||||||
|
acc[p]["cumulative_rewards"][-1], step=rnd)
|
||||||
|
|
||||||
mode = "llm" if use_llm else "rule"
|
mode = "llm" if use_llm else "rule"
|
||||||
print(f" Round {rnd+1:>3}/{n_rounds} [{mode}] " + " ".join(
|
print(f" Round {rnd+1:>3}/{n_rounds} [{mode}] " + " ".join(
|
||||||
f"{p}={acc[p]['cumulative_rewards'][-1]:+.2f}" for p in policies
|
f"{p}={acc[p]['cumulative_rewards'][-1]:+.2f}" for p in policies
|
||||||
))
|
))
|
||||||
|
|
||||||
return _build_result(run_id, started_at, policies, acc, events,
|
result = _build_result(run_id, started_at, policies, acc, events,
|
||||||
n_users, n_rounds, tasks_per_round, use_llm, seed)
|
n_users, n_rounds, tasks_per_round, use_llm, seed)
|
||||||
|
result["mlflow_run_id"] = mlflow_run_id
|
||||||
|
|
||||||
|
if mlflow_ctx:
|
||||||
|
for p, s in result["summary"].items():
|
||||||
|
mlflow.log_metrics({
|
||||||
|
f"{p}_total_reward": s["total_reward"],
|
||||||
|
f"{p}_mean_reward": s["mean_reward"],
|
||||||
|
f"{p}_n_pulls": s["n_pulls"],
|
||||||
|
})
|
||||||
|
mlflow.set_tag("winner", result["winner"])
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if mlflow_ctx:
|
||||||
|
mlflow_ctx.__exit__(None, None, None)
|
||||||
|
|
||||||
|
|
||||||
# ── Claude Code judge — phase 1: score ─────────────────────────────────────
|
# ── Claude Code judge — phase 1: score ─────────────────────────────────────
|
||||||
@@ -494,6 +556,9 @@ if __name__ == "__main__":
|
|||||||
help="Alias for --judge rule (backwards compat)")
|
help="Alias for --judge rule (backwards compat)")
|
||||||
parser.add_argument("--seed", type=int, default=42)
|
parser.add_argument("--seed", type=int, default=42)
|
||||||
parser.add_argument("--out", default=None)
|
parser.add_argument("--out", default=None)
|
||||||
|
parser.add_argument("--mlflow-url", default=os.environ.get("MLFLOW_TRACKING_URI"),
|
||||||
|
help="MLflow tracking URI (e.g. http://mlflow:5000/mlflow)")
|
||||||
|
parser.add_argument("--mlflow-experiment", default="bandit_simulation")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.no_llm:
|
if args.no_llm:
|
||||||
@@ -534,6 +599,7 @@ if __name__ == "__main__":
|
|||||||
n_users=args.n_users, n_rounds=args.n_rounds,
|
n_users=args.n_users, n_rounds=args.n_rounds,
|
||||||
tasks_per_round=args.tasks_per_round, ml_url=args.ml_url,
|
tasks_per_round=args.tasks_per_round, ml_url=args.ml_url,
|
||||||
policies=args.policies, use_llm=use_llm, seed=args.seed,
|
policies=args.policies, use_llm=use_llm, seed=args.seed,
|
||||||
|
mlflow_url=args.mlflow_url, mlflow_experiment=args.mlflow_experiment,
|
||||||
)
|
)
|
||||||
Path(out_path).write_text(json.dumps(result, indent=2))
|
Path(out_path).write_text(json.dumps(result, indent=2))
|
||||||
print()
|
print()
|
||||||
|
|||||||
124
ml/pipelines/sim_dag.py
Normal file
124
ml/pipelines/sim_dag.py
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
"""
|
||||||
|
Airflow DAG: bandit_sim
|
||||||
|
|
||||||
|
Runs a bandit policy simulation and logs results to MLflow.
|
||||||
|
Triggered on-demand from the oO admin panel or manually from the Airflow UI.
|
||||||
|
|
||||||
|
Required conf keys (passed via dag_run.conf):
|
||||||
|
sim_run_id str — oO SQLite run ID for callback correlation
|
||||||
|
n_users int — number of synthetic users
|
||||||
|
n_rounds int — rounds per user
|
||||||
|
tasks_per_round int — candidate pool size per round
|
||||||
|
policies list — policy names to compare
|
||||||
|
judge_mode str — "rule" | "llm"
|
||||||
|
ml_url str — ml/serving URL (e.g. http://ml-serving:8000)
|
||||||
|
mlflow_url str — MLflow tracking URI (e.g. http://mlflow:5000/mlflow)
|
||||||
|
callback_url str — oO API callback endpoint
|
||||||
|
internal_token str — x-internal-token header value
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
from airflow import DAG
|
||||||
|
from airflow.operators.python import PythonOperator
|
||||||
|
|
||||||
|
|
||||||
|
def _run_sim(**context: object) -> dict:
|
||||||
|
conf: dict = context["dag_run"].conf or {}
|
||||||
|
|
||||||
|
n_users = int(conf.get("n_users", 5))
|
||||||
|
n_rounds = int(conf.get("n_rounds", 20))
|
||||||
|
tasks_per_round = int(conf.get("tasks_per_round", 8))
|
||||||
|
policies = list(conf.get("policies", ["linucb-v1", "egreedy-v1"]))
|
||||||
|
judge_mode = str(conf.get("judge_mode", "rule"))
|
||||||
|
ml_url = str(conf.get("ml_url", "http://ml-serving:8000"))
|
||||||
|
mlflow_url = str(conf.get("mlflow_url", os.environ.get("MLFLOW_TRACKING_URI", "")))
|
||||||
|
mlflow_experiment = "bandit_simulation"
|
||||||
|
|
||||||
|
sys.path.insert(0, "/opt/airflow/ml/experiments/sim")
|
||||||
|
from runner import run_simulation # type: ignore[import]
|
||||||
|
|
||||||
|
use_llm = judge_mode == "llm"
|
||||||
|
result = run_simulation(
|
||||||
|
n_users=n_users,
|
||||||
|
n_rounds=n_rounds,
|
||||||
|
tasks_per_round=tasks_per_round,
|
||||||
|
ml_url=ml_url,
|
||||||
|
policies=policies,
|
||||||
|
use_llm=use_llm,
|
||||||
|
seed=42,
|
||||||
|
mlflow_url=mlflow_url or None,
|
||||||
|
mlflow_experiment=mlflow_experiment,
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _callback(**context: object) -> None:
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
conf: dict = context["dag_run"].conf or {}
|
||||||
|
callback_url: str = str(conf.get("callback_url", ""))
|
||||||
|
internal_token: str = str(conf.get("internal_token", ""))
|
||||||
|
|
||||||
|
if not callback_url or not internal_token:
|
||||||
|
print("No callback_url or internal_token — skipping result push.", flush=True)
|
||||||
|
return
|
||||||
|
|
||||||
|
result: dict = context["ti"].xcom_pull(task_ids="run_sim")
|
||||||
|
if not result:
|
||||||
|
print("No result from run_sim task — callback skipped.", flush=True)
|
||||||
|
return
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"summary": result.get("summary", {}),
|
||||||
|
"winner": result.get("winner", ""),
|
||||||
|
"persona_breakdown": result.get("persona_breakdown", {}),
|
||||||
|
"events": result.get("events", []),
|
||||||
|
"mlflow_run_id": result.get("mlflow_run_id"),
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
r = httpx.post(
|
||||||
|
callback_url,
|
||||||
|
json=payload,
|
||||||
|
headers={"x-internal-token": internal_token},
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
r.raise_for_status()
|
||||||
|
print(f"Callback OK: {r.status_code}", flush=True)
|
||||||
|
except Exception as exc:
|
||||||
|
print(f"Callback failed: {exc}", flush=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
with DAG(
|
||||||
|
dag_id="bandit_sim",
|
||||||
|
description="On-demand bandit policy simulation with MLflow tracking",
|
||||||
|
schedule_interval=None,
|
||||||
|
start_date=datetime(2025, 1, 1),
|
||||||
|
catchup=False,
|
||||||
|
tags=["bandit", "simulation", "ml"],
|
||||||
|
default_args={
|
||||||
|
"retries": 1,
|
||||||
|
"retry_delay": timedelta(minutes=2),
|
||||||
|
},
|
||||||
|
) as dag:
|
||||||
|
|
||||||
|
run_sim = PythonOperator(
|
||||||
|
task_id="run_sim",
|
||||||
|
python_callable=_run_sim,
|
||||||
|
provide_context=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
push_results = PythonOperator(
|
||||||
|
task_id="push_results",
|
||||||
|
python_callable=_callback,
|
||||||
|
provide_context=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
run_sim >> push_results
|
||||||
@@ -4,5 +4,6 @@
|
|||||||
"outDir": "dist",
|
"outDir": "dist",
|
||||||
"rootDir": "src"
|
"rootDir": "src"
|
||||||
},
|
},
|
||||||
"include": ["src"]
|
"include": ["src"],
|
||||||
|
"exclude": ["src/__tests__", "**/*.test.ts"]
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -156,6 +156,10 @@ export function runMigrations() {
|
|||||||
`ALTER TABLE tip_scores ADD COLUMN prompt_version TEXT`,
|
`ALTER TABLE tip_scores ADD COLUMN prompt_version TEXT`,
|
||||||
`ALTER TABLE tip_scores ADD COLUMN llm_model TEXT`,
|
`ALTER TABLE tip_scores ADD COLUMN llm_model TEXT`,
|
||||||
`ALTER TABLE tip_scores ADD COLUMN tip_kind TEXT`,
|
`ALTER TABLE tip_scores ADD COLUMN tip_kind TEXT`,
|
||||||
|
`ALTER TABLE sim_runs ADD COLUMN airflow_dag_run_id TEXT`,
|
||||||
|
`ALTER TABLE sim_runs ADD COLUMN mlflow_run_id TEXT`,
|
||||||
|
`ALTER TABLE sim_runs ADD COLUMN judge_mode TEXT NOT NULL DEFAULT 'rule'`,
|
||||||
|
`ALTER TABLE sim_runs ADD COLUMN n_policies INTEGER NOT NULL DEFAULT 2`,
|
||||||
]) {
|
]) {
|
||||||
try { sqlite.exec(stmt); } catch { /* column already exists */ }
|
try { sqlite.exec(stmt); } catch { /* column already exists */ }
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -112,9 +112,13 @@ export const simRuns = sqliteTable('sim_runs', {
|
|||||||
tasksPerRound: integer('tasks_per_round').notNull().default(8),
|
tasksPerRound: integer('tasks_per_round').notNull().default(8),
|
||||||
useLlm: integer('use_llm', { mode: 'boolean' }).notNull().default(false),
|
useLlm: integer('use_llm', { mode: 'boolean' }).notNull().default(false),
|
||||||
status: text('status').notNull().default('pending'), // 'pending'|'running'|'done'|'failed'
|
status: text('status').notNull().default('pending'), // 'pending'|'running'|'done'|'failed'
|
||||||
|
judgeMode: text('judge_mode').notNull().default('rule'),
|
||||||
|
nPolicies: integer('n_policies').notNull().default(2),
|
||||||
summaryJson: text('summary_json'), // JSON: { [policy]: PolicySummary }
|
summaryJson: text('summary_json'), // JSON: { [policy]: PolicySummary }
|
||||||
winner: text('winner'),
|
winner: text('winner'),
|
||||||
personaBreakdownJson: text('persona_breakdown_json'), // JSON: { [persona]: { [policy]: {reward,n} } }
|
personaBreakdownJson: text('persona_breakdown_json'), // JSON: { [persona]: { [policy]: {reward,n} } }
|
||||||
|
airflowDagRunId: text('airflow_dag_run_id'),
|
||||||
|
mlflowRunId: text('mlflow_run_id'),
|
||||||
createdAt: text('created_at').notNull(),
|
createdAt: text('created_at').notNull(),
|
||||||
finishedAt: text('finished_at'),
|
finishedAt: text('finished_at'),
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import { integrationsRouter } from './routes/integrations.js';
|
|||||||
import { recommenderRouter } from './routes/recommender.js';
|
import { recommenderRouter } from './routes/recommender.js';
|
||||||
import { userRouter } from './routes/user.js';
|
import { userRouter } from './routes/user.js';
|
||||||
import { pushRouter } from './routes/push.js';
|
import { pushRouter } from './routes/push.js';
|
||||||
import { adminRouter } from './routes/admin.js';
|
import { adminRouter, adminInternalRouter } from './routes/admin.js';
|
||||||
import { mkdir } from 'fs/promises';
|
import { mkdir } from 'fs/promises';
|
||||||
import { dirname } from 'path';
|
import { dirname } from 'path';
|
||||||
import { requireAuth } from './middleware/session.js';
|
import { requireAuth } from './middleware/session.js';
|
||||||
@@ -65,6 +65,7 @@ app.use('/api', recommenderRouter);
|
|||||||
app.use('/api/user', userRouter);
|
app.use('/api/user', userRouter);
|
||||||
app.use('/api/push', pushRouter);
|
app.use('/api/push', pushRouter);
|
||||||
app.use('/api/admin', adminRouter);
|
app.use('/api/admin', adminRouter);
|
||||||
|
app.use('/api/admin', adminInternalRouter);
|
||||||
|
|
||||||
app.use('/api/ml', requireAuth as any, requireAdmin as any, async (req: Request, res: Response) => {
|
app.use('/api/ml', requireAuth as any, requireAdmin as any, async (req: Request, res: Response) => {
|
||||||
const mlUrl = config.ML_SERVING_URL;
|
const mlUrl = config.ML_SERVING_URL;
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
* A real Express app + in-memory SQLite DB per test suite.
|
* A real Express app + in-memory SQLite DB per test suite.
|
||||||
* Auth and admin middleware are mocked so we can focus on route logic.
|
* Auth and admin middleware are mocked so we can focus on route logic.
|
||||||
*/
|
*/
|
||||||
import { describe, it, expect, vi, beforeAll } from 'vitest';
|
import { describe, it, expect, vi, beforeAll, afterEach } from 'vitest';
|
||||||
import express from 'express';
|
import express from 'express';
|
||||||
import * as http from 'http';
|
import * as http from 'http';
|
||||||
import { makeTestDb } from '../../test/db.js';
|
import { makeTestDb } from '../../test/db.js';
|
||||||
@@ -385,16 +385,126 @@ describe('GET /api/admin/events', () => {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Health endpoint — mock fetch so tests don't depend on running services.
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
describe('GET /api/admin/health', () => {
|
describe('GET /api/admin/health', () => {
|
||||||
it('returns 200 with ok, services array, and checkedAt', async () => {
|
const EXPECTED_HTTP_SERVICES = ['api', 'ml-serving', 'mlflow', 'airflow'] as const;
|
||||||
|
const EXPECTED_INTERNAL = ['sqlite', 'event-bus'] as const;
|
||||||
|
const VALID_STATUSES = new Set(['ok', 'degraded', 'down']);
|
||||||
|
|
||||||
|
type ServiceRow = { name: string; status: string; latencyMs: number };
|
||||||
|
type HealthBody = { ok: boolean; services: ServiceRow[]; checkedAt: string };
|
||||||
|
|
||||||
|
function mockFetch(upServices: Set<string>) {
|
||||||
|
// Resolve service name by port (matches defaults in config.ts).
|
||||||
|
// Up services return HTTP 200; absent ones throw (simulates connection refused → 'down').
|
||||||
|
vi.stubGlobal('fetch', async (url: string) => {
|
||||||
|
const s = String(url);
|
||||||
|
let name: string;
|
||||||
|
if (s.includes(':8000')) name = 'ml-serving';
|
||||||
|
else if (s.includes(':5000')) name = 'mlflow';
|
||||||
|
else if (s.includes(':8080')) name = 'airflow';
|
||||||
|
else name = 'api';
|
||||||
|
|
||||||
|
if (!upServices.has(name)) throw new Error(`ECONNREFUSED ${name}`);
|
||||||
|
return { ok: true, json: async () => ({ ok: true, status: 'healthy' }) };
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
afterEach(() => vi.unstubAllGlobals());
|
||||||
|
|
||||||
|
it('shape: 200, typed fields, all expected services present', async () => {
|
||||||
|
mockFetch(new Set(['api', 'ml-serving', 'mlflow', 'airflow']));
|
||||||
const { server, call } = await startServer(buildApp());
|
const { server, call } = await startServer(buildApp());
|
||||||
try {
|
try {
|
||||||
const { status, body } = await call('GET', '/api/admin/health');
|
const { status, body } = await call('GET', '/api/admin/health');
|
||||||
const b = body as { ok: boolean; services: { name: string; status: string }[]; checkedAt: string };
|
const b = body as HealthBody;
|
||||||
expect(status).toBe(200);
|
expect(status).toBe(200);
|
||||||
expect(typeof b.ok).toBe('boolean');
|
expect(typeof b.ok).toBe('boolean');
|
||||||
expect(Array.isArray(b.services)).toBe(true);
|
expect(Array.isArray(b.services)).toBe(true);
|
||||||
expect(typeof b.checkedAt).toBe('string');
|
expect(typeof b.checkedAt).toBe('string');
|
||||||
|
expect(new Date(b.checkedAt).getTime()).toBeGreaterThan(0);
|
||||||
|
|
||||||
|
const names = b.services.map((s) => s.name);
|
||||||
|
for (const svc of [...EXPECTED_HTTP_SERVICES, ...EXPECTED_INTERNAL]) {
|
||||||
|
expect(names).toContain(svc);
|
||||||
|
}
|
||||||
|
for (const svc of b.services) {
|
||||||
|
expect(VALID_STATUSES).toContain(svc.status);
|
||||||
|
expect(typeof svc.latencyMs).toBe('number');
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
server.close();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
it('ok=true when all HTTP services respond 200', async () => {
|
||||||
|
mockFetch(new Set(['api', 'ml-serving', 'mlflow', 'airflow']));
|
||||||
|
const { server, call } = await startServer(buildApp());
|
||||||
|
try {
|
||||||
|
const { body } = await call('GET', '/api/admin/health');
|
||||||
|
const b = body as HealthBody;
|
||||||
|
for (const name of EXPECTED_HTTP_SERVICES) {
|
||||||
|
const svc = b.services.find((s) => s.name === name);
|
||||||
|
expect(svc?.status, `${name} should be ok`).toBe('ok');
|
||||||
|
}
|
||||||
|
expect(b.ok).toBe(true);
|
||||||
|
} finally {
|
||||||
|
server.close();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
it('ml-serving=down and ok=false when ml-serving is unreachable', async () => {
|
||||||
|
mockFetch(new Set(['api', 'mlflow', 'airflow'])); // ml-serving absent
|
||||||
|
const { server, call } = await startServer(buildApp());
|
||||||
|
try {
|
||||||
|
const { body } = await call('GET', '/api/admin/health');
|
||||||
|
const b = body as HealthBody;
|
||||||
|
const mlSvc = b.services.find((s) => s.name === 'ml-serving');
|
||||||
|
expect(mlSvc?.status).toBe('down');
|
||||||
|
expect(b.ok).toBe(false);
|
||||||
|
} finally {
|
||||||
|
server.close();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
it('airflow=down and ok=false when airflow is unreachable', async () => {
|
||||||
|
mockFetch(new Set(['api', 'ml-serving', 'mlflow'])); // airflow absent
|
||||||
|
const { server, call } = await startServer(buildApp());
|
||||||
|
try {
|
||||||
|
const { body } = await call('GET', '/api/admin/health');
|
||||||
|
const b = body as HealthBody;
|
||||||
|
const svc = b.services.find((s) => s.name === 'airflow');
|
||||||
|
expect(svc?.status).toBe('down');
|
||||||
|
expect(b.ok).toBe(false);
|
||||||
|
} finally {
|
||||||
|
server.close();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
it('mlflow=down and ok=false when mlflow is unreachable', async () => {
|
||||||
|
mockFetch(new Set(['api', 'ml-serving', 'airflow'])); // mlflow absent
|
||||||
|
const { server, call } = await startServer(buildApp());
|
||||||
|
try {
|
||||||
|
const { body } = await call('GET', '/api/admin/health');
|
||||||
|
const b = body as HealthBody;
|
||||||
|
const svc = b.services.find((s) => s.name === 'mlflow');
|
||||||
|
expect(svc?.status).toBe('down');
|
||||||
|
expect(b.ok).toBe(false);
|
||||||
|
} finally {
|
||||||
|
server.close();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
it('sqlite and event-bus are always present regardless of HTTP service status', async () => {
|
||||||
|
mockFetch(new Set()); // all HTTP services down
|
||||||
|
const { server, call } = await startServer(buildApp());
|
||||||
|
try {
|
||||||
|
const { body } = await call('GET', '/api/admin/health');
|
||||||
|
const b = body as HealthBody;
|
||||||
|
expect(b.services.find((s) => s.name === 'sqlite')?.status).toBe('ok');
|
||||||
|
expect(b.services.find((s) => s.name === 'event-bus')?.status).toBe('ok');
|
||||||
} finally {
|
} finally {
|
||||||
server.close();
|
server.close();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import { type Router as ExpressRouter, Router, Response } from 'express';
|
import { type Router as ExpressRouter, Router, Response, type Request } from 'express';
|
||||||
import { logger } from '../logger.js';
|
import { logger } from '../logger.js';
|
||||||
import { db, rawSqlite } from '../db/index.js';
|
import { db, rawSqlite } from '../db/index.js';
|
||||||
import {
|
import {
|
||||||
@@ -524,16 +524,24 @@ router.get('/data-quality', async (req: AuthenticatedRequest, res: Response) =>
|
|||||||
// Fan-out to all subsystem /health endpoints.
|
// Fan-out to all subsystem /health endpoints.
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
router.get('/health', async (_req: AuthenticatedRequest, res: Response) => {
|
router.get('/health', async (_req: AuthenticatedRequest, res: Response) => {
|
||||||
const checks: Array<{ name: string; url: string }> = [
|
const airflowAuth = Buffer.from(`${config.AIRFLOW_API_USER}:${config.AIRFLOW_API_PASSWORD}`).toString('base64');
|
||||||
{ name: 'api', url: `http://localhost:${process.env.PORT ?? 3001}/health` },
|
|
||||||
|
const checks: Array<{ name: string; url: string; headers?: Record<string, string> }> = [
|
||||||
|
{ name: 'api', url: `http://localhost:${config.PORT}/health` },
|
||||||
{ name: 'ml-serving', url: `${config.ML_SERVING_URL}/health` },
|
{ name: 'ml-serving', url: `${config.ML_SERVING_URL}/health` },
|
||||||
|
{ name: 'mlflow', url: `${config.MLFLOW_URL}/health` },
|
||||||
|
{ name: 'airflow', url: `${config.AIRFLOW_URL}/api/v1/health`,
|
||||||
|
headers: { Authorization: `Basic ${airflowAuth}` } },
|
||||||
];
|
];
|
||||||
|
|
||||||
const results = await Promise.allSettled(
|
const results = await Promise.allSettled(
|
||||||
checks.map(async ({ name, url }) => {
|
checks.map(async ({ name, url, headers }) => {
|
||||||
const t0 = Date.now();
|
const t0 = Date.now();
|
||||||
try {
|
try {
|
||||||
const r = await fetch(url, { signal: AbortSignal.timeout(3000) });
|
const r = await fetch(url, {
|
||||||
|
headers,
|
||||||
|
signal: AbortSignal.timeout(3000),
|
||||||
|
});
|
||||||
return { name, status: r.ok ? 'ok' : 'degraded', latencyMs: Date.now() - t0 };
|
return { name, status: r.ok ? 'ok' : 'degraded', latencyMs: Date.now() - t0 };
|
||||||
} catch {
|
} catch {
|
||||||
return { name, status: 'down', latencyMs: Date.now() - t0 };
|
return { name, status: 'down', latencyMs: Date.now() - t0 };
|
||||||
@@ -549,15 +557,12 @@ router.get('/health', async (_req: AuthenticatedRequest, res: Response) => {
|
|||||||
dbStatus = 'down';
|
dbStatus = 'down';
|
||||||
}
|
}
|
||||||
|
|
||||||
// Event bus: always ok if process is alive
|
|
||||||
const eventBusStatus = 'ok';
|
|
||||||
|
|
||||||
const services = results.map((r) =>
|
const services = results.map((r) =>
|
||||||
r.status === 'fulfilled' ? r.value : { name: 'unknown', status: 'down', latencyMs: 0 },
|
r.status === 'fulfilled' ? r.value : { name: 'unknown', status: 'down', latencyMs: 0 },
|
||||||
);
|
);
|
||||||
|
|
||||||
services.push({ name: 'sqlite', status: dbStatus, latencyMs: 0 });
|
services.push({ name: 'sqlite', status: dbStatus, latencyMs: 0 });
|
||||||
services.push({ name: 'event-bus', status: eventBusStatus, latencyMs: 0 });
|
services.push({ name: 'event-bus', status: 'ok', latencyMs: 0 });
|
||||||
|
|
||||||
const allOk = services.every((s) => s.status === 'ok');
|
const allOk = services.every((s) => s.status === 'ok');
|
||||||
res.json({ ok: allOk, services, checkedAt: new Date().toISOString() });
|
res.json({ ok: allOk, services, checkedAt: new Date().toISOString() });
|
||||||
@@ -700,22 +705,21 @@ router.delete('/saved-queries/:id', async (req: AuthenticatedRequest, res: Respo
|
|||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
// POST /api/admin/simulate/start
|
// POST /api/admin/simulate/start
|
||||||
// Spawn ml/experiments/sim/runner.py in the background; return run_id.
|
// Trigger an Airflow DAG run (bandit_sim). Falls back to a local subprocess
|
||||||
|
// when AIRFLOW_URL is not reachable, so local dev still works.
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
router.post('/simulate/start', async (req: AuthenticatedRequest, res: Response) => {
|
router.post('/simulate/start', async (req: AuthenticatedRequest, res: Response) => {
|
||||||
const {
|
const {
|
||||||
nUsers = 5,
|
nUsers = 5,
|
||||||
nRounds = 20,
|
nRounds = 20,
|
||||||
tasksPerRound = 8,
|
tasksPerRound = 8,
|
||||||
useLlm = false,
|
|
||||||
judgeMode = 'rule',
|
judgeMode = 'rule',
|
||||||
policies = ['linucb-v1', 'egreedy-v1'],
|
policies = ['linucb-v1', 'egreedy-v1'],
|
||||||
} = req.body as {
|
} = req.body as {
|
||||||
nUsers?: number;
|
nUsers?: number;
|
||||||
nRounds?: number;
|
nRounds?: number;
|
||||||
tasksPerRound?: number;
|
tasksPerRound?: number;
|
||||||
useLlm?: boolean;
|
judgeMode?: 'rule' | 'llm';
|
||||||
judgeMode?: 'rule' | 'llm' | 'claude-code';
|
|
||||||
policies?: string[];
|
policies?: string[];
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -734,17 +738,69 @@ router.post('/simulate/start', async (req: AuthenticatedRequest, res: Response)
|
|||||||
nUsers,
|
nUsers,
|
||||||
nRounds,
|
nRounds,
|
||||||
tasksPerRound,
|
tasksPerRound,
|
||||||
useLlm,
|
useLlm: judgeMode === 'llm',
|
||||||
|
judgeMode,
|
||||||
|
nPolicies: policies.length,
|
||||||
status: 'running',
|
status: 'running',
|
||||||
createdAt: now,
|
createdAt: now,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// ── Try Airflow first ────────────────────────────────────────────────────
|
||||||
|
if (config.AIRFLOW_URL && config.INTERNAL_API_TOKEN) {
|
||||||
|
try {
|
||||||
|
const airflowAuth = Buffer.from(
|
||||||
|
`${config.AIRFLOW_API_USER}:${config.AIRFLOW_API_PASSWORD}`,
|
||||||
|
).toString('base64');
|
||||||
|
|
||||||
|
const dagRes = await fetch(
|
||||||
|
`${config.AIRFLOW_URL}/api/v1/dags/bandit_sim/dagRuns`,
|
||||||
|
{
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
Authorization: `Basic ${airflowAuth}`,
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
conf: {
|
||||||
|
sim_run_id: id,
|
||||||
|
n_users: nUsers,
|
||||||
|
n_rounds: nRounds,
|
||||||
|
tasks_per_round: tasksPerRound,
|
||||||
|
policies,
|
||||||
|
judge_mode: judgeMode,
|
||||||
|
ml_url: config.ML_SERVING_URL,
|
||||||
|
mlflow_url: config.MLFLOW_URL,
|
||||||
|
callback_url: `${config.API_BASE_URL}/api/admin/simulate/${id}/complete`,
|
||||||
|
internal_token: config.INTERNAL_API_TOKEN,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
signal: AbortSignal.timeout(5000),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
if (dagRes.ok) {
|
||||||
|
const dagBody = await dagRes.json() as { dag_run_id: string };
|
||||||
|
await db
|
||||||
|
.update(simRuns)
|
||||||
|
.set({ airflowDagRunId: dagBody.dag_run_id })
|
||||||
|
.where(eq(simRuns.id, id));
|
||||||
|
|
||||||
|
res.json({ id, status: 'running', airflow_dag_run_id: dagBody.dag_run_id });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
logger.warn({ status: dagRes.status }, 'sim: Airflow trigger failed, falling back to subprocess');
|
||||||
|
} catch (err) {
|
||||||
|
logger.warn({ err }, 'sim: Airflow unreachable, falling back to subprocess');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Subprocess fallback (local dev / Airflow not configured) ────────────
|
||||||
const runnerPath = resolve(__dirname, '../../../../ml/experiments/sim/runner.py');
|
const runnerPath = resolve(__dirname, '../../../../ml/experiments/sim/runner.py');
|
||||||
const venvPython = resolve(__dirname, '../../../../ml/serving/.venv/bin/python');
|
const venvPython = resolve(__dirname, '../../../../ml/serving/.venv/bin/python');
|
||||||
const pythonBin = existsSync(venvPython) ? venvPython : 'python3';
|
const pythonBin = existsSync(venvPython) ? venvPython : 'python3';
|
||||||
const outPath = `/tmp/oo-sim-${id}.json`;
|
const outPath = `/tmp/oo-sim-${id}.json`;
|
||||||
|
|
||||||
const args = [
|
const child = spawn(pythonBin, [
|
||||||
runnerPath,
|
runnerPath,
|
||||||
'--n-users', String(nUsers),
|
'--n-users', String(nUsers),
|
||||||
'--n-rounds', String(nRounds),
|
'--n-rounds', String(nRounds),
|
||||||
@@ -752,32 +808,22 @@ router.post('/simulate/start', async (req: AuthenticatedRequest, res: Response)
|
|||||||
'--ml-url', config.ML_SERVING_URL,
|
'--ml-url', config.ML_SERVING_URL,
|
||||||
'--policies', ...policies,
|
'--policies', ...policies,
|
||||||
'--out', outPath,
|
'--out', outPath,
|
||||||
'--judge', judgeMode === 'llm' ? 'llm' : judgeMode === 'claude-code' ? 'rule' : 'rule',
|
'--judge', judgeMode,
|
||||||
// claude-code mode isn't auto-runnable from the API (requires human in the loop)
|
'--mlflow-url', config.MLFLOW_URL,
|
||||||
// it falls back to rule judge when triggered from the panel
|
'--mlflow-experiment', 'bandit_simulation',
|
||||||
];
|
], { stdio: ['ignore', 'pipe', 'pipe'] });
|
||||||
|
|
||||||
const child = spawn(pythonBin, args, { stdio: ['ignore', 'pipe', 'pipe'] });
|
if (child.pid) _simProcesses.set(id, { pid: child.pid, startedAt: now });
|
||||||
|
|
||||||
if (child.pid) {
|
|
||||||
_simProcesses.set(id, { pid: child.pid, startedAt: now });
|
|
||||||
}
|
|
||||||
|
|
||||||
// Without this listener, a spawn failure (ENOENT when python3 is absent
|
|
||||||
// — e.g. in the alpine api container) would emit an unhandled 'error' event
|
|
||||||
// and crash the whole API process.
|
|
||||||
child.on('error', async (err) => {
|
child.on('error', async (err) => {
|
||||||
logger.error({ err }, 'sim: spawn error');
|
logger.error({ err }, 'sim: spawn error');
|
||||||
_simProcesses.delete(id);
|
_simProcesses.delete(id);
|
||||||
await db
|
await db.update(simRuns)
|
||||||
.update(simRuns)
|
|
||||||
.set({ status: 'failed', finishedAt: new Date().toISOString() })
|
.set({ status: 'failed', finishedAt: new Date().toISOString() })
|
||||||
.where(eq(simRuns.id, id));
|
.where(eq(simRuns.id, id));
|
||||||
});
|
});
|
||||||
|
|
||||||
// Capture stderr for debugging
|
child.stderr?.on('data', (d: Buffer) => logger.debug({ stderr: d.toString() }, 'sim stderr'));
|
||||||
const stderrLines: string[] = [];
|
|
||||||
child.stderr?.on('data', (d: Buffer) => stderrLines.push(d.toString()));
|
|
||||||
|
|
||||||
child.on('exit', async (code) => {
|
child.on('exit', async (code) => {
|
||||||
_simProcesses.delete(id);
|
_simProcesses.delete(id);
|
||||||
@@ -786,8 +832,6 @@ router.post('/simulate/start', async (req: AuthenticatedRequest, res: Response)
|
|||||||
if (code === 0 && existsSync(outPath)) {
|
if (code === 0 && existsSync(outPath)) {
|
||||||
try {
|
try {
|
||||||
const raw = JSON.parse(readFileSync(outPath, 'utf-8'));
|
const raw = JSON.parse(readFileSync(outPath, 'utf-8'));
|
||||||
|
|
||||||
// Bulk-insert sim events
|
|
||||||
const eventRows = (raw.events ?? []).map((ev: Record<string, unknown>) => ({
|
const eventRows = (raw.events ?? []).map((ev: Record<string, unknown>) => ({
|
||||||
id: nanoid(),
|
id: nanoid(),
|
||||||
runId: id,
|
runId: id,
|
||||||
@@ -805,21 +849,19 @@ router.post('/simulate/start', async (req: AuthenticatedRequest, res: Response)
|
|||||||
dayOfWeek: Number(ev.day_of_week),
|
dayOfWeek: Number(ev.day_of_week),
|
||||||
createdAt: now,
|
createdAt: now,
|
||||||
}));
|
}));
|
||||||
|
|
||||||
for (const row of eventRows) {
|
for (const row of eventRows) {
|
||||||
await db.insert(simEvents).values(row).catch(() => {});
|
await db.insert(simEvents).values(row).catch(() => {});
|
||||||
}
|
}
|
||||||
|
|
||||||
await db.update(simRuns).set({
|
await db.update(simRuns).set({
|
||||||
status: 'done',
|
status: 'done',
|
||||||
summaryJson: JSON.stringify(raw.summary),
|
summaryJson: JSON.stringify(raw.summary),
|
||||||
winner: raw.winner,
|
winner: raw.winner,
|
||||||
personaBreakdownJson: JSON.stringify(raw.persona_breakdown),
|
personaBreakdownJson: JSON.stringify(raw.persona_breakdown),
|
||||||
|
mlflowRunId: raw.mlflow_run_id ?? null,
|
||||||
finishedAt,
|
finishedAt,
|
||||||
}).where(eq(simRuns.id, id));
|
}).where(eq(simRuns.id, id));
|
||||||
|
|
||||||
try { unlinkSync(outPath); } catch { /* ignore */ }
|
try { unlinkSync(outPath); } catch { /* ignore */ }
|
||||||
} catch (e) {
|
} catch {
|
||||||
await db.update(simRuns).set({ status: 'failed', finishedAt }).where(eq(simRuns.id, id));
|
await db.update(simRuns).set({ status: 'failed', finishedAt }).where(eq(simRuns.id, id));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -864,4 +906,68 @@ router.get('/simulate/:id', async (req: AuthenticatedRequest, res: Response) =>
|
|||||||
res.json({ run: { ...run, isRunning }, events });
|
res.json({ run: { ...run, isRunning }, events });
|
||||||
});
|
});
|
||||||
|
|
||||||
export { router as adminRouter };
|
// ---------------------------------------------------------------------------
|
||||||
|
// internalRouter — no session auth; only INTERNAL_API_TOKEN header check.
|
||||||
|
// Mounted separately in index.ts at /api/admin to avoid router.use() auth.
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
const internalRouter: ExpressRouter = Router();
|
||||||
|
|
||||||
|
internalRouter.post('/simulate/:id/complete', async (req: Request, res: Response) => {
|
||||||
|
const token = req.headers['x-internal-token'];
|
||||||
|
if (!config.INTERNAL_API_TOKEN || token !== config.INTERNAL_API_TOKEN) {
|
||||||
|
res.status(401).json({ error: 'Unauthorized' });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const { id } = req.params as { id: string };
|
||||||
|
const { summary, winner, persona_breakdown, events: rawEvents, mlflow_run_id } =
|
||||||
|
req.body as {
|
||||||
|
summary: Record<string, unknown>;
|
||||||
|
winner: string;
|
||||||
|
persona_breakdown: Record<string, unknown>;
|
||||||
|
events: Record<string, unknown>[];
|
||||||
|
mlflow_run_id?: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
const finishedAt = new Date().toISOString();
|
||||||
|
const now = finishedAt;
|
||||||
|
|
||||||
|
try {
|
||||||
|
const eventRows = (rawEvents ?? []).map((ev) => ({
|
||||||
|
id: nanoid(),
|
||||||
|
runId: id,
|
||||||
|
round: Number(ev['round']),
|
||||||
|
userId: String(ev['user_id']),
|
||||||
|
persona: String(ev['persona']),
|
||||||
|
policy: String(ev['policy']),
|
||||||
|
tipContent: String(ev['tip_content']),
|
||||||
|
priority: Number(ev['priority']),
|
||||||
|
isOverdue: Boolean(ev['is_overdue']),
|
||||||
|
action: String(ev['action']),
|
||||||
|
dwellMs: ev['dwell_ms'] != null ? Number(ev['dwell_ms']) : null,
|
||||||
|
rewardMilli: Math.round(Number(ev['reward']) * 1000),
|
||||||
|
hour: Number(ev['hour']),
|
||||||
|
dayOfWeek: Number(ev['day_of_week']),
|
||||||
|
createdAt: now,
|
||||||
|
}));
|
||||||
|
for (const row of eventRows) {
|
||||||
|
await db.insert(simEvents).values(row).catch(() => {});
|
||||||
|
}
|
||||||
|
await db.update(simRuns).set({
|
||||||
|
status: 'done',
|
||||||
|
summaryJson: JSON.stringify(summary),
|
||||||
|
winner,
|
||||||
|
personaBreakdownJson: JSON.stringify(persona_breakdown),
|
||||||
|
mlflowRunId: mlflow_run_id ?? null,
|
||||||
|
finishedAt,
|
||||||
|
}).where(eq(simRuns.id, id));
|
||||||
|
|
||||||
|
res.json({ ok: true });
|
||||||
|
} catch (err) {
|
||||||
|
logger.error({ err }, 'sim: complete callback failed');
|
||||||
|
await db.update(simRuns).set({ status: 'failed', finishedAt }).where(eq(simRuns.id, id));
|
||||||
|
res.status(500).json({ error: 'Failed to store results' });
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
export { router as adminRouter, internalRouter as adminInternalRouter };
|
||||||
|
|||||||
Reference in New Issue
Block a user