Each unique task title is now enriched by LiteLLM once and cached in the DB. Subsequent agent compute cycles (every 12h) fetch the cache before calling ml-serving; only new titles hit the tip-generator. - DB: task_enrichments(content_hash PK, description, model, created_at) - TS: fetchEnrichmentCache / persistEnrichments helpers in agent-outputs.ts; enrichment_cache passed in compute request, new_enrichments persisted from response - Python: AgentComputeRequest.enrichment_cache / AgentComputeResponse.new_enrichments; AgentInput.enrichment_cache; _enrich_batch returns (descriptions, new_entries); cluster_tasks returns (clusters, new_enrichments) - FocusAreaAgent stashes new_enrichments in signals_snapshot under _new_enrichments; compute_agent endpoint pops it before storing the snapshot Closes part of #129 Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
639 lines
24 KiB
Python
639 lines
24 KiB
Python
"""
|
|
oO ML Serving — multi-agent orchestrator (ADR-0013).
|
|
|
|
Contract:
|
|
POST /agents/{agent_id}/compute run a sub-agent, return prompt snippet
|
|
POST /agents/{agent_id}/infer run inference framework for a user, return inferred prefs
|
|
POST /recommend orchestrate agent snippets → one tip via LiteLLM
|
|
POST /generate LLM tip candidates (legacy; kept for bench/eval)
|
|
GET /health { ok, agents: [...] }
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
import sys
|
|
import time
|
|
from contextlib import asynccontextmanager
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import httpx
|
|
import sentry_sdk
|
|
import structlog
|
|
import structlog.contextvars
|
|
from fastapi import FastAPI, HTTPException, Request
|
|
from pydantic import BaseModel
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
|
|
import mlflow
|
|
from mlflow.entities import SpanType
|
|
|
|
import logging_config
|
|
import nats_consumer
|
|
from prompts import get_prompt, build_orchestrator_messages
|
|
|
|
# Make ml.agents importable regardless of working directory.
|
|
# In Docker (WORKDIR=/app/ml/serving, PYTHONPATH=/app): /app already on path.
|
|
# In local dev (run from ml/serving/): repo root is two levels up.
|
|
_repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
|
|
if _repo_root not in sys.path:
|
|
sys.path.insert(0, _repo_root)
|
|
|
|
from ml.agents.base import AgentInput # noqa: E402
|
|
from ml.agents.registry import get_agent, all_agents, all_manifests, get_manifest # noqa: E402
|
|
from ml.agents.inference import run_inference, FeedbackEvent, TaskCompletion, UserHistory # noqa: E402
|
|
|
|
# Apply the project's logging configuration (see logging_config).
logging_config.configure()

# Sentry is opt-in: it is initialised only when SENTRY_DSN is set and non-empty.
_SENTRY_DSN = os.environ.get("SENTRY_DSN")
if _SENTRY_DSN:
    sentry_sdk.init(
        dsn=_SENTRY_DSN,
        environment=os.environ.get("ENV", "development"),
    )

# Module-wide structlog logger used by every endpoint below.
log = structlog.get_logger()
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start the NATS consumer when the app boots and stop it on shutdown.

    The consumer is started with STATE_DIR before the application begins
    serving.  The ``finally`` clause guarantees ``nats_consumer.stop()`` also
    runs when the lifespan is torn down by an exception thrown in at the
    ``yield`` point, so the consumer is never left running after a failed
    shutdown.
    """
    await nats_consumer.start(STATE_DIR)
    try:
        yield
    finally:
        await nats_consumer.stop()
|
|
|
|
|
|
# ASGI application object; startup/shutdown work is delegated to `lifespan`.
app = FastAPI(title="oO ML Serving", version="1.0.0", lifespan=lifespan)
|
|
|
|
|
|
class _TracingMiddleware(BaseHTTPMiddleware):
    """Bind the trace id from the ``traceparent`` request header to structlog."""

    async def dispatch(self, request: Request, call_next):
        # Reset first so nothing bound while serving an earlier request leaks in.
        structlog.contextvars.clear_contextvars()

        header = request.headers.get("traceparent", "")
        fields = header.split("-") if header else []
        # Only the four-field form whose second field is 32 characters long is
        # accepted; anything else is ignored and no trace_id is bound.
        if len(fields) == 4 and len(fields[1]) == 32:
            structlog.contextvars.bind_contextvars(trace_id=fields[1])

        return await call_next(request)
|
|
|
|
|
|
# Register the middleware that binds the incoming trace id to the log context.
app.add_middleware(_TracingMiddleware)
|
|
|
|
# LiteLLM proxy base URL and bearer key used by /recommend and /generate.
# The defaults are development values; override them through the environment.
LITELLM_URL = os.environ.get("LITELLM_URL", "http://localhost:4000")
LITELLM_MASTER_KEY = os.environ.get("LITELLM_MASTER_KEY", "sk-oo-dev")

# Directory passed to the NATS consumer at startup.
STATE_DIR = Path(os.environ.get("STATE_DIR", "/tmp/oo-serving-state"))
|
|
|
|
# ── MLflow tracing (optional) ───────────────────────────────────────────────
# Tracing is enabled by setting MLFLOW_TRACKING_URI.  Spans are fire-and-forget:
# errors are logged at WARNING and never propagate to the caller.
# MLflow --allowed-hosts must include "mlflow" (the container DNS name) so the
# SDK can reach the server from inside other containers.

_MLFLOW_URI = os.environ.get("MLFLOW_TRACKING_URI", "")
_MLFLOW_EXP = "oO/serving"
# Stays None when tracing is off or initialisation failed; _start_span checks it.
_mlflow_exp_id: str | None = None

if _MLFLOW_URI:
    try:
        mlflow.set_tracking_uri(_MLFLOW_URI)
        _mlflow_exp_id = mlflow.set_experiment(_MLFLOW_EXP).experiment_id
    except Exception as _exc:
        # Best effort: a tracking server that cannot be reached must not stop
        # the service from booting.
        log.warning("mlflow_init_failed", error=str(_exc))
|
|
|
|
|
|
class _NoOpSpan:
    """Stand-in span returned when MLflow is disabled or span creation fails.

    Every method accepts any arguments and does nothing, so call sites can use
    a span without checking whether tracing is active.
    """

    def set_inputs(self, *args, **kwargs):
        return None

    def set_outputs(self, *args, **kwargs):
        return None

    def set_attribute(self, *args, **kwargs):
        return None

    def set_attributes(self, *args, **kwargs):
        return None

    def end(self, *args, **kwargs):
        return None


# Shared instance; also the default `parent` of _start_span, where it means "no parent".
_NOOP = _NoOpSpan()
|
|
|
|
|
|
def _start_span(name: str, span_type: str, *, parent=_NOOP, inputs=None):
    """Open an MLflow span, or return _NOOP when tracing is off or opening fails.

    A span whose ``parent`` is a _NoOpSpan is treated as a root span and is
    given ``experiment_id``; any other parent is forwarded as ``parent_span``.
    experiment_id is only passed for root spans — passing it to child spans
    causes the SDK to fail with '_Span has no attribute _span'.

    Args:
        name: Span name.
        span_type: Span type forwarded to MLflow.
        parent: Parent span; the default _NOOP means "no parent".
        inputs: Optional inputs recorded on the span; omitted when None.

    Returns:
        The span created by ``mlflow.start_span_no_context``, or _NOOP.
    """
    if _mlflow_exp_id is None:
        return _NOOP

    try:
        is_root = isinstance(parent, _NoOpSpan)
        span_kwargs: dict = {"span_type": span_type}
        span_kwargs.update(
            {"experiment_id": _mlflow_exp_id} if is_root else {"parent_span": parent}
        )
        if inputs is not None:
            span_kwargs["inputs"] = inputs
        return mlflow.start_span_no_context(name, **span_kwargs)
    except Exception as err:  # noqa: BLE001
        # Tracing is best effort; never let it break the request.
        log.warning("mlflow_span_start_failed", name=name, error=str(err))
        return _NOOP
|
|
|
|
|
|
def _end_span(span, *, status: str = "OK", outputs=None, attributes: dict | None = None) -> None:
    """Close a span without ever raising.

    _NoOpSpan instances are skipped entirely.  For real spans, non-empty
    ``attributes`` are applied first, then the span is ended with ``status``
    and ``outputs``.  Any exception is logged at WARNING and swallowed.
    """
    if not isinstance(span, _NoOpSpan):
        try:
            if attributes:
                span.set_attributes(attributes)
            span.end(status=status, outputs=outputs)
        except Exception as err:  # noqa: BLE001
            log.warning("mlflow_span_end_failed", error=str(err))
|
|
|
|
|
|
# Create the state directory (and missing parents) at import time; no-op if it exists.
STATE_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
# ── API models ─────────────────────────────────────────────────────────────
|
|
|
|
class PromptContext(BaseModel):
    # Context object that /generate passes to the prompt template's build_user().
    tasks: list[dict] = []
    hour_of_day: int = 12
    day_of_week: int = 0
    extra: dict = {}
    # /generate replaces this with the request-level profile_features value.
    profile_features: Optional[dict] = None
|
|
|
|
|
|
class GenerateRequest(BaseModel):
    # Request body of POST /generate.
    user_id: str
    context: PromptContext = PromptContext()
    # Forwarded to the prompt template's build_user() alongside the context.
    n: int = 3
    # None → server default (env DEFAULT_PROMPT_VERSION).
    prompt_version: Optional[str] = None
    # User-level features (#81 phase A).  Accepted by the contract; not yet
    # injected into the prompt — that's a #84-style prompt-design decision.
    profile_features: Optional[dict] = None
|
|
|
|
|
|
class TipCandidate(BaseModel):
    # One candidate tip returned by POST /generate.
    id: str
    content: str
    source: str = "llm"
    rationale: Optional[str] = None
|
|
|
|
|
|
class GenerateResponse(BaseModel):
    # Response body of POST /generate.
    candidates: list[TipCandidate]
    model: str
    prompt_version: str
    # Token counts are summed over every LLM attempt, retries included.
    prompt_tokens: int = 0
    completion_tokens: int = 0
|
|
|
|
|
|
# ── Multi-agent models ─────────────────────────────────────────────────────
|
|
|
|
class AgentComputeRequest(BaseModel):
    # Request body of POST /agents/{agent_id}/compute.
    user_id: str
    tasks: list[dict] = []
    profile: dict[str, Optional[float]] = {}
    feedback_history: list[dict] = []
    # ISO 8601 timestamp; when omitted the endpoint uses the current UTC time.
    now_iso: Optional[str] = None
    # Per-agent prefs from user_preferences (merged: user source overrides inferred).
    agent_prefs: dict = {}
    # Pre-fetched enrichment cache: {content_hash -> description}.  Avoids
    # re-calling LiteLLM for task titles already expanded in a prior compute cycle.
    enrichment_cache: dict[str, str] = {}
|
|
|
|
|
|
class AgentComputeResponse(BaseModel):
    # Response body of POST /agents/{agent_id}/compute.
    user_id: str
    agent_id: str
    prompt_text: str
    signals_snapshot: dict
    computed_at: str
    expires_at: str
    agent_version: str
    # New enrichments generated during this compute cycle; the caller persists
    # them to the DB.
    new_enrichments: dict[str, str] = {}
|
|
|
|
|
|
class AgentInferRequest(BaseModel):
    # Request body of POST /agents/{agent_id}/infer.
    user_id: str
    # [{action, dwell_ms, created_at}, …]
    feedback_history: list[dict] = []
    # [{project_id, completed_at, due_at}, …]
    task_completions: list[dict] = []
|
|
|
|
|
|
class AgentInferResponse(BaseModel):
    # Response body of POST /agents/{agent_id}/infer.
    user_id: str
    agent_id: str
    # {key: inferred_value} — the caller persists these to user_preferences
    # with source='inferred'.
    inferred_prefs: dict
|
|
|
|
|
|
class AgentOutputSnippet(BaseModel):
    # One pre-computed agent output, as supplied to POST /recommend.
    agent_id: str
    prompt_text: str
|
|
|
|
|
|
class RecommendRequest(BaseModel):
    # Request body of POST /recommend.
    user_id: str
    agent_outputs: list[AgentOutputSnippet] = []
    tasks: list[dict] = []
    hour_of_day: int = 12
    day_of_week: int = 0
    # 0 = science (data-driven), 100 = destiny (intuitive).
    science_destiny: int = 50
    # Content of the last snoozed tip; the LLM avoids repeating it.
    recent_tip: Optional[str] = None
|
|
|
|
|
|
class TipResult(BaseModel):
    # The single tip returned by POST /recommend.
    id: str
    content: str
    source: str = "llm"
    kind: str = "advice"
    rationale: Optional[str] = None
|
|
|
|
|
|
class RecommendResponse(BaseModel):
    # Response body of POST /recommend.
    tip: TipResult
    model: str
    # Token counts are summed over every LLM attempt, retries included.
    prompt_tokens: int = 0
    completion_tokens: int = 0
|
|
|
|
|
|
# ── Endpoints ──────────────────────────────────────────────────────────────
|
|
|
|
@app.get("/health")
|
|
def health():
|
|
return {
|
|
"ok": True,
|
|
"agents": [a.agent_id for a in all_agents()],
|
|
"nats": {
|
|
"enabled": bool(nats_consumer.NATS_URL),
|
|
"consumers": nats_consumer.consumer_health,
|
|
},
|
|
}
|
|
|
|
|
|
@app.get("/agents/registry")
|
|
def agents_registry():
|
|
"""Manifest list for every registered agent (ADR-0014).
|
|
|
|
Consumers: TS recommender (eligibility filter), admin UI (auto-rendered
|
|
pref forms), inference framework (#111). Static at process boot.
|
|
"""
|
|
return {"agents": [m.to_dict() for m in all_manifests()]}
|
|
|
|
|
|
# Correction hint appended as a user message when /generate gets an unparsable reply.
_RETRY_SUFFIX = (
    "\n\n"
    "Your previous response was not valid JSON. Reply ONLY with the JSON array"
    " — no prose, no markdown fences."
)

# Same hint for /recommend, which asks for a single JSON object.
_RETRY_SUFFIX_OBJ = (
    "\n\n"
    "Your previous response was not valid JSON. Reply ONLY with the JSON object"
    " — no prose, no markdown fences."
)
|
|
|
|
|
|
@app.post("/agents/{agent_id}/compute", response_model=AgentComputeResponse)
async def compute_agent(agent_id: str, req: AgentComputeRequest) -> AgentComputeResponse:
    """Run a single sub-agent for a user and return its prompt snippet.

    Called by the precompute pipeline for each (user_id, agent_id) pair.
    The caller is responsible for persisting the result to agent_outputs via the
    TypeScript API callback, and for persisting ``new_enrichments``.

    Raises:
        HTTPException: 404 when ``agent_id`` is not registered, 422 when
            ``now_iso`` is not a parseable ISO 8601 timestamp, 500 when the
            agent's ``compute`` raises.
    """
    try:
        agent = get_agent(agent_id)
    except KeyError:
        raise HTTPException(status_code=404, detail=f"Unknown agent: {agent_id!r}") from None

    if req.now_iso:
        try:
            # fromisoformat rejects a trailing "Z" on older Pythons; normalise it.
            now = datetime.fromisoformat(req.now_iso.replace("Z", "+00:00"))
        except ValueError:
            # A malformed timestamp is a client error, not a server crash.
            raise HTTPException(
                status_code=422, detail=f"Invalid now_iso: {req.now_iso!r}"
            ) from None
        if now.tzinfo is None:
            # Naive timestamps are interpreted as UTC.
            now = now.replace(tzinfo=timezone.utc)
    else:
        now = datetime.now(timezone.utc)

    inp = AgentInput(
        user_id=req.user_id,
        tasks=req.tasks,
        profile=req.profile,
        feedback_history=req.feedback_history,
        now=now,
        agent_prefs=req.agent_prefs,
        enrichment_cache=req.enrichment_cache,
    )

    # Open the span before the agent runs so its duration covers the compute
    # and failed runs show up as ERROR spans.
    span = _start_span(
        f"compute:{agent_id}",
        SpanType.AGENT,
        inputs={
            "user_id": req.user_id,
            "agent_id": agent_id,
            "task_count": len(req.tasks),
            "feedback_count": len(req.feedback_history),
        },
    )
    try:
        output = agent.compute(inp)
    except Exception as exc:
        log.error("agent_compute_failed", agent_id=agent_id, user_id=req.user_id, error=str(exc))
        _end_span(span, status="ERROR")
        raise HTTPException(status_code=500, detail=f"Agent compute failed: {exc}") from exc

    # Agents stash freshly generated enrichments under a private key; pop it so
    # it is returned separately and never stored or traced with the snapshot.
    new_enrichments: dict[str, str] = output.signals_snapshot.pop("_new_enrichments", None) or {}

    log.info("agent_computed", agent_id=agent_id, user_id=req.user_id, expires_at=output.expires_at)
    _end_span(
        span,
        outputs={"prompt_text": output.prompt_text, "signals_snapshot": output.signals_snapshot},
        attributes={"agent_version": output.agent_version, "expires_at": output.expires_at},
    )
    return AgentComputeResponse(
        user_id=output.user_id,
        agent_id=output.agent_id,
        prompt_text=output.prompt_text,
        signals_snapshot=output.signals_snapshot,
        computed_at=output.computed_at,
        expires_at=output.expires_at,
        agent_version=output.agent_version,
        new_enrichments=new_enrichments,
    )
|
|
|
|
|
|
@app.post("/agents/{agent_id}/infer", response_model=AgentInferResponse)
async def infer_agent(agent_id: str, req: AgentInferRequest) -> AgentInferResponse:
    """Run the inference framework for one agent and return inferred preference values.

    The caller (TS agent-outputs.ts) persists results to user_preferences
    with source='inferred', skipping keys where source='user' already exists.
    Agents whose manifest declares no inferred params get an empty result
    without running inference.

    Raises:
        HTTPException: 404 when ``agent_id`` has no registered manifest.
    """
    try:
        manifest = get_manifest(agent_id)
    except KeyError:
        raise HTTPException(status_code=404, detail=f"Unknown agent: {agent_id!r}") from None

    if not manifest.inferred_params:
        return AgentInferResponse(user_id=req.user_id, agent_id=agent_id, inferred_prefs={})

    events = [
        FeedbackEvent(
            action=e.get("action", ""),
            dwell_ms=e.get("dwell_ms"),
            created_at=e.get("created_at", ""),
        )
        for e in req.feedback_history
    ]
    completions = [
        TaskCompletion(
            project_id=c.get("project_id"),
            completed_at=c.get("completed_at", ""),
            due_at=c.get("due_at", ""),
        )
        for c in req.task_completions
    ]
    history = UserHistory(user_id=req.user_id, events=events, task_completions=completions)

    # Open the span before inference so its duration is real and a failure is
    # recorded as an ERROR span instead of leaving no trace.
    span = _start_span(
        f"infer:{agent_id}",
        SpanType.CHAIN,
        inputs={
            "user_id": req.user_id,
            "agent_id": agent_id,
            "history_len": len(events),
            "completion_count": len(completions),
        },
    )
    t0 = time.monotonic()
    try:
        inferred = run_inference(manifest, history)
    except Exception:
        _end_span(span, status="ERROR")
        raise
    latency_ms = round((time.monotonic() - t0) * 1000, 1)

    log.info(
        "inference_run",
        agent_id=agent_id,
        user_id=req.user_id,
        n_params=len(inferred),
        history_len=len(events),
        latency_ms=latency_ms,
    )
    _end_span(
        span,
        outputs={"inferred_prefs": inferred},
        attributes={"latency_ms": str(latency_ms), "n_params": str(len(inferred))},
    )
    return AgentInferResponse(user_id=req.user_id, agent_id=agent_id, inferred_prefs=inferred)
|
|
|
|
|
|
def _extract_tip_item(raw: str) -> dict:
    """Parse an orchestrator reply into a single tip object.

    Strips an optional markdown fence, then accepts either a JSON object or a
    non-empty JSON array whose first element is an object.

    Raises:
        ValueError: if the text is not valid JSON (json.JSONDecodeError is a
            ValueError subclass) or does not yield a JSON object.
    """
    text = raw.strip()
    if text.startswith("```"):
        # "```json\n{...}\n```" splits into ["", "json\n{...}\n", ""].
        text = text.split("```")[1]
        if text.startswith("json"):
            text = text[4:]
    parsed = json.loads(text)
    if isinstance(parsed, list):
        if not parsed:
            raise ValueError("LLM returned an empty JSON array")
        parsed = parsed[0]
    if not isinstance(parsed, dict):
        raise ValueError(f"expected a JSON object, got {type(parsed).__name__}")
    return parsed


@app.post("/recommend", response_model=RecommendResponse)
async def recommend(req: RecommendRequest) -> RecommendResponse:
    """Orchestrator: combine pre-computed agent outputs into one tip via LLM.

    Called in real time when a user requests a tip.  agent_outputs should be
    the fresh rows from agent_outputs table (fetched by the TypeScript recommender
    before calling this endpoint).  Falls back to raw task context if empty.

    A reply that cannot be parsed into a tip object is fed back to the model
    with a correction hint, for at most _MAX_GENERATE_RETRIES extra attempts.
    Every span opened here is closed on every exit path.

    Raises:
        HTTPException: 502 when LiteLLM answers with an error status or a
            malformed body, or when no attempt yields a usable JSON object;
            503 when LiteLLM cannot be reached.
    """
    t0 = time.monotonic()

    # ── root span ──────────────────────────────────────────────────────────
    root = _start_span("recommend", SpanType.CHAIN, inputs={
        "user_id": req.user_id,
        "agent_ids": [s.agent_id for s in req.agent_outputs],
        "hour_of_day": req.hour_of_day,
        "day_of_week": req.day_of_week,
        "science_destiny": req.science_destiny,
    })

    try:
        # ── build_context span ─────────────────────────────────────────────
        ctx_span = _start_span("build_context", SpanType.TOOL, parent=root, inputs={
            "agent_count": len(req.agent_outputs),
            "task_count": len(req.tasks),
            "science_destiny": req.science_destiny,
        })
        try:
            messages = build_orchestrator_messages(
                agent_outputs=[s.model_dump() for s in req.agent_outputs],
                tasks=req.tasks,
                hour_of_day=req.hour_of_day,
                day_of_week=req.day_of_week,
                science_destiny=req.science_destiny,
                recent_tip=req.recent_tip,
            )
        except Exception:
            _end_span(ctx_span, status="ERROR")
            raise
        _end_span(ctx_span, outputs={"message_count": len(messages)})

        # ── one span per pre-computed agent snippet ────────────────────────
        for snippet in req.agent_outputs:
            a_span = _start_span(
                f"agent:{snippet.agent_id}", SpanType.AGENT, parent=root,
                inputs={"agent_id": snippet.agent_id},
            )
            _end_span(a_span, outputs={"prompt_text": snippet.prompt_text})

        # ── LLM orchestrator span (wraps retry loop) ───────────────────────
        llm_span = _start_span("llm_orchestrator", SpanType.LLM, parent=root, inputs={
            "messages": messages,
            "model": "tip-generator",
            "temperature": 0.7,
        })
        try:
            headers = {"Authorization": f"Bearer {LITELLM_MASTER_KEY}"}
            last_raw = ""
            last_parse_error = ""
            total_usage: dict = {"prompt_tokens": 0, "completion_tokens": 0}
            model_used = "tip-generator"
            attempts = 0
            item: dict = {}

            async with httpx.AsyncClient(timeout=30.0) as client:
                for attempts in range(1, 2 + _MAX_GENERATE_RETRIES):
                    payload = {"model": "tip-generator", "messages": messages, "temperature": 0.7}
                    try:
                        resp = await client.post(
                            f"{LITELLM_URL}/chat/completions", json=payload, headers=headers
                        )
                        resp.raise_for_status()
                    except httpx.HTTPStatusError as e:
                        raise HTTPException(
                            status_code=502, detail=f"LiteLLM error: {e.response.text}"
                        ) from e
                    except httpx.RequestError as e:
                        raise HTTPException(
                            status_code=503, detail=f"LiteLLM unreachable: {e}"
                        ) from e

                    # The proxy's envelope is untrusted: a missing key or a
                    # non-JSON body is an upstream failure (502), not a crash.
                    try:
                        data = resp.json()
                        usage = data.get("usage") or {}
                        total_usage["prompt_tokens"] += usage.get("prompt_tokens") or 0
                        total_usage["completion_tokens"] += usage.get("completion_tokens") or 0
                        model_used = data.get("model", "tip-generator")
                        content = data["choices"][0]["message"]["content"]
                    except (ValueError, KeyError, IndexError, TypeError, AttributeError) as e:
                        raise HTTPException(
                            status_code=502,
                            detail=f"LiteLLM returned a malformed response: {e}",
                        ) from e
                    # Non-string content (e.g. null) is treated as an empty,
                    # unparsable reply so the retry path handles it.
                    last_raw = content if isinstance(content, str) else ""

                    try:
                        item = _extract_tip_item(last_raw)
                        break
                    except ValueError as exc:
                        last_parse_error = str(exc)
                        # Feed the bad reply back so the model can self-correct.
                        messages.append({"role": "assistant", "content": last_raw})
                        messages.append({"role": "user", "content": _RETRY_SUFFIX_OBJ})
                else:
                    raise HTTPException(
                        status_code=502,
                        detail=f"LLM returned invalid JSON after {_MAX_GENERATE_RETRIES} retries: "
                               f"{last_parse_error}\n{last_raw[:200]}",
                    )

            tip = TipResult(
                id=item.get("id", f"tip-{req.user_id[:8]}"),
                content=item.get("content", ""),
                rationale=item.get("rationale"),
            )
        except Exception:
            # HTTPException included: the LLM span must not stay open.
            _end_span(llm_span, status="ERROR")
            raise

        _end_span(
            llm_span,
            outputs={"content": tip.content, "rationale": tip.rationale or ""},
            attributes={
                "prompt_tokens": str(total_usage["prompt_tokens"]),
                "completion_tokens": str(total_usage["completion_tokens"]),
                "model": model_used,
                "attempts": str(attempts),
            },
        )

        latency_ms = round((time.monotonic() - t0) * 1000, 1)
        log.info("recommend_served", user_id=req.user_id, agent_count=len(req.agent_outputs), tip_id=tip.id)

        response = RecommendResponse(
            tip=tip,
            model=model_used,
            prompt_tokens=total_usage["prompt_tokens"],
            completion_tokens=total_usage["completion_tokens"],
        )
    except Exception:
        # Single place where the root span is failed, whatever went wrong.
        _end_span(root, status="ERROR")
        raise

    _end_span(
        root,
        outputs={"tip_id": tip.id, "content": tip.content, "rationale": tip.rationale or ""},
        attributes={"latency_ms": str(latency_ms), "agent_count": str(len(req.agent_outputs))},
    )
    return response
|
|
|
|
# Extra LLM attempts allowed after the first one when a reply cannot be parsed;
# both /recommend and /generate loop over range(1 + _MAX_GENERATE_RETRIES).
_MAX_GENERATE_RETRIES = 2
|
|
|
|
|
|
def _parse_llm_json(raw: str) -> list[dict]:
    """Strip markdown fences and parse an LLM reply into a list of JSON objects.

    A bare JSON object is wrapped in a one-element list, so a model that
    answers with a single object instead of an array is still usable.

    Args:
        raw: The reply text, optionally wrapped in a markdown code fence with
            an optional ``json`` language tag (any letter case).

    Returns:
        A list whose elements are all dicts (possibly empty).

    Raises:
        ValueError: if the text is not valid JSON (json.JSONDecodeError is a
            ValueError subclass) or decodes to anything other than an object
            or an array of objects.
    """
    text = raw.strip()
    if text.startswith("```"):
        # "```json\n[...]\n```" splits into ["", "json\n[...]\n", ""]; index 1
        # always exists because the text starts with the fence.
        text = text.split("```")[1]
        if text[:4].lower() == "json":
            text = text[4:]

    parsed = json.loads(text)
    if isinstance(parsed, dict):
        return [parsed]
    if not isinstance(parsed, list) or not all(isinstance(entry, dict) for entry in parsed):
        raise ValueError("expected a JSON array of objects")
    return parsed
|
|
|
|
|
|
@app.post("/generate", response_model=GenerateResponse)
async def generate(req: GenerateRequest) -> GenerateResponse:
    """Generate tip candidates via LiteLLM → tip-generator alias.

    Retries up to _MAX_GENERATE_RETRIES times on malformed JSON, appending
    a correction hint to the conversation so the model can self-correct.
    A reply that is valid JSON but not an array of objects counts as
    malformed; a single object is accepted as a one-element array.

    Raises:
        HTTPException: 422 for an unknown ``prompt_version``; 502 when LiteLLM
            answers with an error status or a malformed body, or when no
            attempt yields usable JSON; 503 when LiteLLM cannot be reached.
    """
    try:
        prompt_template = get_prompt(req.prompt_version)
    except KeyError as e:
        raise HTTPException(
            status_code=422, detail=f"Unknown prompt_version: {e.args[0]}"
        ) from e

    # The request-level profile_features override whatever the context carried.
    ctx = req.context.model_copy(update={"profile_features": req.profile_features})
    user_msg = prompt_template.build_user(ctx, req.n)
    messages: list[dict] = [
        {"role": "system", "content": prompt_template.system},
        {"role": "user", "content": user_msg},
    ]
    headers = {"Authorization": f"Bearer {LITELLM_MASTER_KEY}"}
    last_parse_error: str = ""
    last_raw: str = ""
    total_usage: dict = {"prompt_tokens": 0, "completion_tokens": 0}
    model_used = "tip-generator"
    items: list[dict] = []

    async with httpx.AsyncClient(timeout=30.0) as client:
        for _ in range(1 + _MAX_GENERATE_RETRIES):
            payload = {"model": "tip-generator", "messages": messages, "temperature": 0.7}
            try:
                resp = await client.post(
                    f"{LITELLM_URL}/chat/completions",
                    json=payload,
                    headers=headers,
                )
                resp.raise_for_status()
            except httpx.HTTPStatusError as e:
                raise HTTPException(
                    status_code=502, detail=f"LiteLLM error: {e.response.text}"
                ) from e
            except httpx.RequestError as e:
                raise HTTPException(
                    status_code=503, detail=f"LiteLLM unreachable: {e}"
                ) from e

            # The proxy's envelope is untrusted: a missing key or a non-JSON
            # body is an upstream failure (502), not an unhandled crash.
            try:
                data = resp.json()
                usage = data.get("usage") or {}
                total_usage["prompt_tokens"] += usage.get("prompt_tokens") or 0
                total_usage["completion_tokens"] += usage.get("completion_tokens") or 0
                model_used = data.get("model", "tip-generator")
                content = data["choices"][0]["message"]["content"]
            except (ValueError, KeyError, IndexError, TypeError, AttributeError) as e:
                raise HTTPException(
                    status_code=502,
                    detail=f"LiteLLM returned a malformed response: {e}",
                ) from e
            # Non-string content (e.g. null) becomes an empty, unparsable reply
            # so the retry path below handles it.
            last_raw = content if isinstance(content, str) else ""

            try:
                parsed = _parse_llm_json(last_raw)
                if isinstance(parsed, dict):
                    parsed = [parsed]
                if not isinstance(parsed, list) or not all(
                    isinstance(entry, dict) for entry in parsed
                ):
                    raise ValueError("expected a JSON array of objects")
                items = parsed
                break
            except ValueError as e:  # json.JSONDecodeError is a ValueError
                last_parse_error = str(e)
                # Feed the bad reply back so the model can self-correct
                messages.append({"role": "assistant", "content": last_raw})
                messages.append({"role": "user", "content": _RETRY_SUFFIX})
        else:
            raise HTTPException(
                status_code=502,
                detail=f"LLM returned invalid JSON after {_MAX_GENERATE_RETRIES} retries: "
                       f"{last_parse_error}\n{last_raw[:200]}",
            )

    candidates = [
        TipCandidate(
            id=item.get("id", f"tip-{i}"),
            content=item.get("content", ""),
            rationale=item.get("rationale"),
        )
        for i, item in enumerate(items)
    ]

    return GenerateResponse(
        candidates=candidates,
        model=model_used,
        prompt_version=prompt_template.version,
        prompt_tokens=total_usage["prompt_tokens"],
        completion_tokens=total_usage["completion_tokens"],
    )
|
|
|
|
|