feat: M2 AI tips — LiteLLM gateway, context assembler, end-to-end generation pipeline

Issues closed: #86, #87, #88, #89, #90, #91, #79, #80, #82

infra:
- docker-compose `ai` profile: Ollama + LiteLLM services
- infra/litellm/litellm_config.yaml: tip-generator / embedder / judge aliases
- .env.example: LITELLM_URL, LITELLM_MASTER_KEY, OLLAMA_URL

ml/serving:
- POST /generate: calls LiteLLM tip-generator alias, returns TipCandidate[]
- JSON retry loop (2 retries with correction prompt on malformed response)
- _parse_llm_json strips markdown fences

ml/features:
- context.py: build_context() assembles user signals → PromptContext
  (sorts overdue/high-priority tasks first for LLM prompt quality)

shared-types:
- TipKind, TipSource, TipCandidate types
- Tip gains kind + rationale fields

services/api:
- recommender: 3-stage pipeline (assemble → score → serve)
  Stage 1: Todoist tasks + LLM candidates fetched in parallel
  Stage 2: egreedy bandit scores merged candidate pool
  Stage 3: serve + log with prompt_version, llm_model, tip_kind
- tip_scores: prompt_version, llm_model, tip_kind columns + migrations
- config: LITELLM_URL added
- integrations: surface token_status in /integrations response

tests:
- ml/serving/tests/test_generate.py: 13 tests (retry, 502/503, fence variants)
- ml/features/test_context.py: 9 tests (sorting, edge cases)
- services/api recommender.unit.test.ts: 16 pure-function tests (inferReward, dueAgeDays)
- services/api recommender.test.ts: 4 integration tests (tip_scores columns, LLM fallback)
- shared-types: TipCandidate, rationale, full TipFeedback action set

docs:
- ADR-0008: LiteLLM AI gateway decision
- overview.md: M2 pipeline description updated
- ml/README.md: serving + features roles updated

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-17 14:09:02 +00:00
parent 85367aeaa0
commit ffdf70733f
22 changed files with 1017 additions and 45 deletions

View File

@@ -4,8 +4,8 @@ Python. Owns models, features, training, online scoring.
| Dir | Role | Phase |
|---|---|---|
| `serving/` | FastAPI online scorer (`/score`), called by `recommender` | 1 |
| `features/` | feature definitions + store adapter (Feast later) | 1 |
| `serving/` | FastAPI online scorer (`/score`, `/generate`) + LiteLLM gateway, called by `recommender` | 12 |
| `features/` | context assembler (`context.py`): signals → `PromptContext`; Feast adapter later | 2 |
| `pipelines/` | batch feature + training DAGs (Prefect/Airflow) | 4 |
| `registry/` | MLflow-backed model registry integration | 4 |
| `experiments/` | A/B assignment + multi-armed bandit policies | 4 |

3
ml/features/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .context import build_context, PromptContext, TaskSignal
__all__ = ["build_context", "PromptContext", "TaskSignal"]

63
ml/features/context.py Normal file
View File

@@ -0,0 +1,63 @@
"""
Context assembler — converts raw user signals into a PromptContext for LLM tip generation.
Usage:
from ml.features.context import build_context
ctx = build_context(tasks, hour_of_day=9, day_of_week=2)
"""
from __future__ import annotations
from dataclasses import dataclass, field
@dataclass
class TaskSignal:
id: str
content: str
priority: int = 1 # 14 (Todoist scale)
is_overdue: bool = False
task_age_days: float = 0.0
due_date: str | None = None
@dataclass
class PromptContext:
tasks: list[dict] = field(default_factory=list)
hour_of_day: int = 12
day_of_week: int = 0
extra: dict = field(default_factory=dict)
def build_context(
tasks: list[TaskSignal],
hour_of_day: int = 12,
day_of_week: int = 0,
extra: dict | None = None,
) -> PromptContext:
"""
Assemble user signals into a PromptContext.
Signals are sorted so overdue + high-priority tasks appear first,
giving the LLM the most actionable context at the top of the prompt.
"""
sorted_tasks = sorted(
tasks,
key=lambda t: (not t.is_overdue, -t.priority, -t.task_age_days),
)
task_dicts = [
{
"id": t.id,
"content": t.content,
"priority": t.priority,
"is_overdue": t.is_overdue,
"task_age_days": round(t.task_age_days, 1),
"due_date": t.due_date,
}
for t in sorted_tasks
]
return PromptContext(
tasks=task_dicts,
hour_of_day=hour_of_day,
day_of_week=day_of_week,
extra=extra or {},
)

View File

@@ -0,0 +1,64 @@
"""Tests for ml/features/context.py"""
import pytest
import sys, os; sys.path.insert(0, os.path.dirname(__file__))
from context import build_context, TaskSignal, PromptContext
def test_empty_tasks():
ctx = build_context([], hour_of_day=9, day_of_week=1)
assert ctx.tasks == []
assert ctx.hour_of_day == 9
assert ctx.day_of_week == 1
def test_overdue_tasks_sorted_first():
tasks = [
TaskSignal(id="a", content="Normal task", priority=1, is_overdue=False),
TaskSignal(id="b", content="Overdue task", priority=2, is_overdue=True, task_age_days=3.0),
]
ctx = build_context(tasks)
assert ctx.tasks[0]["id"] == "b"
def test_high_priority_within_non_overdue():
tasks = [
TaskSignal(id="lo", content="Low prio", priority=1, is_overdue=False),
TaskSignal(id="hi", content="High prio", priority=4, is_overdue=False),
]
ctx = build_context(tasks)
assert ctx.tasks[0]["id"] == "hi"
def test_extra_fields_passed_through():
ctx = build_context([], extra={"mood": "focused"})
assert ctx.extra["mood"] == "focused"
def test_task_age_rounded():
tasks = [TaskSignal(id="x", content="Task", task_age_days=1.23456)]
ctx = build_context(tasks)
assert ctx.tasks[0]["task_age_days"] == 1.2
def test_overdue_sorted_by_priority():
tasks = [
TaskSignal(id="lo", content="Low", priority=1, is_overdue=True),
TaskSignal(id="hi", content="High", priority=4, is_overdue=True),
]
ctx = build_context(tasks)
assert ctx.tasks[0]["id"] == "hi"
def test_overdue_same_priority_sorted_by_age():
tasks = [
TaskSignal(id="new", content="New", priority=2, is_overdue=True, task_age_days=1.0),
TaskSignal(id="old", content="Old", priority=2, is_overdue=True, task_age_days=5.0),
]
ctx = build_context(tasks)
assert ctx.tasks[0]["id"] == "old"
def test_due_date_none_preserved():
tasks = [TaskSignal(id="x", content="No due", due_date=None)]
ctx = build_context(tasks)
assert ctx.tasks[0]["due_date"] is None

View File

@@ -26,12 +26,16 @@ from collections import deque
from pathlib import Path
from typing import Optional, Deque
import httpx
import numpy as np
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
app = FastAPI(title="oO ML Serving", version="1.0.0")
LITELLM_URL = os.getenv("LITELLM_URL", "http://localhost:4000")
LITELLM_MASTER_KEY = os.getenv("LITELLM_MASTER_KEY", "sk-oo-dev")
STATE_DIR = Path(os.getenv("STATE_DIR", "/tmp/oo-bandit-state"))
STATE_DIR.mkdir(parents=True, exist_ok=True)
@@ -166,6 +170,56 @@ class RewardResponse(BaseModel):
ok: bool
class PromptContext(BaseModel):
tasks: list[dict] = []
hour_of_day: int = 12
day_of_week: int = 0
extra: dict = {}
class GenerateRequest(BaseModel):
user_id: str
context: PromptContext = PromptContext()
n: int = 3
class TipCandidate(BaseModel):
id: str
content: str
source: str = "llm"
rationale: Optional[str] = None
class GenerateResponse(BaseModel):
candidates: list[TipCandidate]
model: str
prompt_tokens: int = 0
completion_tokens: int = 0
_GENERATE_SYSTEM = (
"You are a personal productivity coach. "
"Given the user's current context, generate actionable, specific tips. "
"Respond ONLY with a JSON array of objects, each with keys: "
'"id" (short slug), "content" (the tip, ≤2 sentences), "rationale" (why now, ≤1 sentence). '
"No markdown, no prose outside the JSON array."
)
def _build_prompt(ctx: PromptContext, n: int) -> str:
lines = [f"Time: {ctx.hour_of_day:02d}:00, day_of_week={ctx.day_of_week}"]
if ctx.tasks:
overdue = [t for t in ctx.tasks if t.get("is_overdue")]
lines.append(f"Tasks: {len(ctx.tasks)} total, {len(overdue)} overdue")
for t in ctx.tasks[:5]:
due = t.get("due_date", "no due date")
lines.append(f" - [{t.get('priority','?')}] {t.get('content','?')} (due: {due})")
for k, v in ctx.extra.items():
lines.append(f"{k}: {v}")
lines.append(f"\nGenerate {n} tips as a JSON array.")
return "\n".join(lines)
# ── Endpoints ──────────────────────────────────────────────────────────────
@app.get("/health")
@@ -173,6 +227,97 @@ def health():
return {"ok": True}
_RETRY_SUFFIX = (
"\n\nYour previous response was not valid JSON. "
"Reply ONLY with the JSON array — no prose, no markdown fences."
)
_MAX_GENERATE_RETRIES = 2
def _parse_llm_json(raw: str) -> list[dict]:
"""Strip markdown fences and parse JSON array. Raises ValueError on failure."""
text = raw.strip()
if text.startswith("```"):
parts = text.split("```")
text = parts[1] if len(parts) > 1 else text
if text.startswith("json"):
text = text[4:]
return json.loads(text)
@app.post("/generate", response_model=GenerateResponse)
async def generate(req: GenerateRequest) -> GenerateResponse:
"""Generate tip candidates via LiteLLM → tip-generator alias.
Retries up to _MAX_GENERATE_RETRIES times on malformed JSON, appending
a correction hint to the conversation so the model can self-correct.
"""
prompt = _build_prompt(req.context, req.n)
messages: list[dict] = [
{"role": "system", "content": _GENERATE_SYSTEM},
{"role": "user", "content": prompt},
]
headers = {"Authorization": f"Bearer {LITELLM_MASTER_KEY}"}
last_parse_error: str = ""
last_raw: str = ""
total_usage: dict = {"prompt_tokens": 0, "completion_tokens": 0}
model_used = "tip-generator"
async with httpx.AsyncClient(timeout=30.0) as client:
for attempt in range(1 + _MAX_GENERATE_RETRIES):
payload = {"model": "tip-generator", "messages": messages, "temperature": 0.7}
try:
resp = await client.post(
f"{LITELLM_URL}/chat/completions",
json=payload,
headers=headers,
)
resp.raise_for_status()
except httpx.HTTPStatusError as e:
raise HTTPException(status_code=502, detail=f"LiteLLM error: {e.response.text}")
except httpx.RequestError as e:
raise HTTPException(status_code=503, detail=f"LiteLLM unreachable: {e}")
data = resp.json()
usage = data.get("usage", {})
total_usage["prompt_tokens"] += usage.get("prompt_tokens", 0)
total_usage["completion_tokens"] += usage.get("completion_tokens", 0)
model_used = data.get("model", "tip-generator")
last_raw = data["choices"][0]["message"]["content"]
try:
items = _parse_llm_json(last_raw)
break
except (json.JSONDecodeError, ValueError) as e:
last_parse_error = str(e)
# Feed the bad reply back so the model can self-correct
messages.append({"role": "assistant", "content": last_raw})
messages.append({"role": "user", "content": _RETRY_SUFFIX})
else:
raise HTTPException(
status_code=502,
detail=f"LLM returned invalid JSON after {_MAX_GENERATE_RETRIES} retries: "
f"{last_parse_error}\n{last_raw[:200]}",
)
candidates = [
TipCandidate(
id=item.get("id", f"tip-{i}"),
content=item.get("content", ""),
rationale=item.get("rationale"),
)
for i, item in enumerate(items)
]
return GenerateResponse(
candidates=candidates,
model=model_used,
prompt_tokens=total_usage["prompt_tokens"],
completion_tokens=total_usage["completion_tokens"],
)
@app.post("/score", response_model=ScoreResponse)
def score(req: ScoreRequest) -> ScoreResponse:
if not req.candidates:

View File

@@ -0,0 +1,225 @@
"""
Tests for POST /generate — LiteLLM gateway.
LiteLLM is mocked; no real network calls.
"""
import json
import pytest
import httpx
from unittest.mock import AsyncMock, patch
from httpx import AsyncClient, ASGITransport, Response
from main import app, _build_prompt, PromptContext
def _litellm_response(candidates: list[dict]) -> Response:
import httpx
body = {
"model": "tip-generator",
"choices": [{"message": {"content": json.dumps(candidates)}}],
"usage": {"prompt_tokens": 10, "completion_tokens": 20},
}
req = httpx.Request("POST", "http://litellm/chat/completions")
return Response(200, json=body, request=req)
@pytest.mark.anyio
async def test_generate_returns_candidates():
fake_items = [
{"id": "tip-1", "content": "Do the overdue task now.", "rationale": "It's been waiting."},
{"id": "tip-2", "content": "Take a 5-minute break.", "rationale": "You've been working long."},
]
mock_resp = _litellm_response(fake_items)
with patch("main.httpx.AsyncClient") as MockClient:
instance = AsyncMock()
instance.post = AsyncMock(return_value=mock_resp)
instance.__aenter__ = AsyncMock(return_value=instance)
instance.__aexit__ = AsyncMock(return_value=False)
MockClient.return_value = instance
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.post("/generate", json={"user_id": "u1", "n": 2})
assert resp.status_code == 200
data = resp.json()
assert len(data["candidates"]) == 2
assert data["candidates"][0]["id"] == "tip-1"
assert data["model"] == "tip-generator"
@pytest.mark.anyio
async def test_generate_strips_markdown_fence():
fake_items = [{"id": "tip-a", "content": "Focus.", "rationale": "Now."}]
fenced = "```json\n" + json.dumps(fake_items) + "\n```"
body = {
"model": "tip-generator",
"choices": [{"message": {"content": fenced}}],
"usage": {},
}
req = httpx.Request("POST", "http://litellm/chat/completions")
mock_resp = Response(200, json=body, request=req)
with patch("main.httpx.AsyncClient") as MockClient:
instance = AsyncMock()
instance.post = AsyncMock(return_value=mock_resp)
instance.__aenter__ = AsyncMock(return_value=instance)
instance.__aexit__ = AsyncMock(return_value=False)
MockClient.return_value = instance
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.post("/generate", json={"user_id": "u1"})
assert resp.status_code == 200
assert resp.json()["candidates"][0]["id"] == "tip-a"
@pytest.mark.anyio
async def test_generate_503_on_unreachable():
import httpx as _httpx
with patch("main.httpx.AsyncClient") as MockClient:
instance = AsyncMock()
instance.post = AsyncMock(side_effect=_httpx.ConnectError("refused"))
instance.__aenter__ = AsyncMock(return_value=instance)
instance.__aexit__ = AsyncMock(return_value=False)
MockClient.return_value = instance
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.post("/generate", json={"user_id": "u1"})
assert resp.status_code == 503
def test_build_prompt_includes_tasks():
ctx = PromptContext(
tasks=[{"content": "Write report", "priority": 4, "is_overdue": True, "due_date": "2026-04-15"}],
hour_of_day=9,
day_of_week=2,
)
prompt = _build_prompt(ctx, n=3)
assert "Write report" in prompt
assert "09:00" in prompt
assert "Generate 3 tips" in prompt
def test_build_prompt_truncates_at_five():
tasks = [{"content": f"Task {i}", "priority": 1, "is_overdue": False, "due_date": None} for i in range(8)]
ctx = PromptContext(tasks=tasks, hour_of_day=12)
prompt = _build_prompt(ctx, n=2)
assert "Task 4" in prompt
assert "Task 5" not in prompt
def test_build_prompt_extra_fields():
ctx = PromptContext(tasks=[], hour_of_day=8, extra={"mood": "focused", "energy": "high"})
prompt = _build_prompt(ctx, n=1)
assert "mood: focused" in prompt
assert "energy: high" in prompt
def test_build_prompt_empty_tasks_no_task_line():
ctx = PromptContext(tasks=[], hour_of_day=10)
prompt = _build_prompt(ctx, n=2)
assert "Tasks:" not in prompt
assert "Generate 2 tips" in prompt
@pytest.mark.anyio
async def test_generate_retry_succeeds_on_second_attempt():
"""First response is invalid JSON; second is valid. Should return 200."""
valid_items = [{"id": "tip-ok", "content": "Retry worked.", "rationale": "Second try."}]
bad_req = httpx.Request("POST", "http://litellm/chat/completions")
bad_resp = Response(200, json={
"model": "tip-generator",
"choices": [{"message": {"content": "this is not json"}}],
"usage": {},
}, request=bad_req)
good_resp = _litellm_response(valid_items)
with patch("main.httpx.AsyncClient") as MockClient:
instance = AsyncMock()
instance.post = AsyncMock(side_effect=[bad_resp, good_resp])
instance.__aenter__ = AsyncMock(return_value=instance)
instance.__aexit__ = AsyncMock(return_value=False)
MockClient.return_value = instance
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.post("/generate", json={"user_id": "u1", "n": 1})
assert resp.status_code == 200
assert resp.json()["candidates"][0]["id"] == "tip-ok"
assert instance.post.call_count == 2
# Retry message should include the correction suffix
second_call_messages = instance.post.call_args_list[1][1]["json"]["messages"]
assert any("not valid JSON" in m["content"] for m in second_call_messages)
@pytest.mark.anyio
async def test_generate_502_after_all_retries_exhausted():
"""All attempts return invalid JSON → 502."""
bad_req = httpx.Request("POST", "http://litellm/chat/completions")
def _bad_resp():
return Response(200, json={
"model": "tip-generator",
"choices": [{"message": {"content": "not json at all"}}],
"usage": {},
}, request=bad_req)
from main import _MAX_GENERATE_RETRIES
responses = [_bad_resp() for _ in range(1 + _MAX_GENERATE_RETRIES)]
with patch("main.httpx.AsyncClient") as MockClient:
instance = AsyncMock()
instance.post = AsyncMock(side_effect=responses)
instance.__aenter__ = AsyncMock(return_value=instance)
instance.__aexit__ = AsyncMock(return_value=False)
MockClient.return_value = instance
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.post("/generate", json={"user_id": "u1"})
assert resp.status_code == 502
assert "retries" in resp.json()["detail"]
@pytest.mark.anyio
async def test_generate_502_on_upstream_http_error():
"""LiteLLM returns 500 → HTTPStatusError → 502."""
err_req = httpx.Request("POST", "http://litellm/chat/completions")
err_resp = Response(500, text="internal error", request=err_req)
with patch("main.httpx.AsyncClient") as MockClient:
instance = AsyncMock()
instance.post = AsyncMock(side_effect=httpx.HTTPStatusError(
"500", request=err_req, response=err_resp
))
instance.__aenter__ = AsyncMock(return_value=instance)
instance.__aexit__ = AsyncMock(return_value=False)
MockClient.return_value = instance
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.post("/generate", json={"user_id": "u1"})
assert resp.status_code == 502
assert "LiteLLM error" in resp.json()["detail"]
def test_parse_llm_json_bare_fence():
from main import _parse_llm_json
raw = "```\n[{\"id\":\"x\",\"content\":\"hi\"}]\n```"
items = _parse_llm_json(raw)
assert items[0]["id"] == "x"
def test_parse_llm_json_no_fence():
from main import _parse_llm_json
raw = '[{"id":"plain","content":"no fence"}]'
items = _parse_llm_json(raw)
assert items[0]["id"] == "plain"
def test_parse_llm_json_raises_on_invalid():
from main import _parse_llm_json
with pytest.raises((ValueError, Exception)):
_parse_llm_json("this is not json")