feat: M2 AI tips — LiteLLM gateway, context assembler, end-to-end generation pipeline

Issues closed: #86, #87, #88, #89, #90, #91, #79, #80, #82

infra:
- docker-compose `ai` profile: Ollama + LiteLLM services
- infra/litellm/litellm_config.yaml: tip-generator / embedder / judge aliases
- .env.example: LITELLM_URL, LITELLM_MASTER_KEY, OLLAMA_URL

ml/serving:
- POST /generate: calls LiteLLM tip-generator alias, returns TipCandidate[]
- JSON retry loop (2 retries with correction prompt on malformed response)
- _parse_llm_json strips markdown fences

ml/features:
- context.py: build_context() assembles user signals → PromptContext
  (sorts overdue/high-priority tasks first for LLM prompt quality)

shared-types:
- TipKind, TipSource, TipCandidate types
- Tip gains kind + rationale fields

services/api:
- recommender: 3-stage pipeline (assemble → score → serve)
  Stage 1: Todoist tasks + LLM candidates fetched in parallel
  Stage 2: egreedy bandit scores merged candidate pool
  Stage 3: serve + log with prompt_version, llm_model, tip_kind
- tip_scores: prompt_version, llm_model, tip_kind columns + migrations
- config: LITELLM_URL added
- integrations: surface token_status in /integrations response

tests:
- ml/serving/tests/test_generate.py: 13 tests (retry, 502/503, fence variants)
- ml/features/test_context.py: 9 tests (sorting, edge cases)
- services/api recommender.unit.test.ts: 16 pure-function tests (inferReward, dueAgeDays)
- services/api recommender.test.ts: 4 integration tests (tip_scores columns, LLM fallback)
- shared-types: TipCandidate, rationale, full TipFeedback action set

docs:
- ADR-0008: LiteLLM AI gateway decision
- overview.md: M2 pipeline description updated
- ml/README.md: serving + features roles updated

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-17 14:09:02 +00:00
parent 85367aeaa0
commit ffdf70733f
22 changed files with 1017 additions and 45 deletions

View File

@@ -10,6 +10,11 @@ API_BASE_URL=http://localhost:3078
WEB_BASE_URL=http://localhost:3000 WEB_BASE_URL=http://localhost:3000
ML_SERVING_URL=http://localhost:8000 ML_SERVING_URL=http://localhost:8000
# AI stack — Ollama + LiteLLM (docker compose --profile ai)
LITELLM_URL=http://localhost:4000
LITELLM_MASTER_KEY=sk-oo-dev
OLLAMA_URL=http://localhost:11434
# Google OAuth — https://console.cloud.google.com/ # Google OAuth — https://console.cloud.google.com/
GOOGLE_CLIENT_ID= GOOGLE_CLIENT_ID=
GOOGLE_CLIENT_SECRET= GOOGLE_CLIENT_SECRET=

View File

@@ -0,0 +1,41 @@
# ADR-0008 — LiteLLM as AI gateway; model aliases decouple code from model names
**Status:** Accepted
**Date:** 2026-04-17
**Milestone:** M2
## Context
M2 requires LLM inference for tip generation (`ml/serving POST /generate`). We need a way to:
- Run locally during development without cloud API keys.
- Switch models (qwen2.5 → llama3.2, or cloud fallback) without touching application code.
- Share the LLM infrastructure with other local services on Agap.
## Decision
Route all LLM calls through **LiteLLM** (`http://localhost:4000` in dev, `llm.alogins.net` in prod) backed by **Ollama** for local inference.
Application code references model aliases — never bare model names:
| Alias | Default model | Used by |
|-------|--------------|---------|
| `tip-generator` | `qwen2.5:7b` | `ml/serving POST /generate` |
| `embedder` | `nomic-embed-text` | task clustering, dedup (M4) |
| `judge` | `claude-haiku-4-5` | offline simulation only |
Config is in `infra/litellm/litellm_config.yaml`. Swapping a model = one YAML change, zero code change.
`ml/serving` reads `LITELLM_URL` and `LITELLM_MASTER_KEY` from env. TypeScript services never call LLM endpoints directly — all inference flows through `ml/serving`.
## Consequences
- **Local dev:** `docker compose --profile ai up` starts Ollama + LiteLLM. First run pulls models (~4 GB for qwen2.5:7b).
- **Prod:** both are shared Agap services; set `LITELLM_URL=http://llm.alogins.net` in `.env.local`.
- **Offline sim:** `judge` alias points at `claude-haiku-4-5` (cloud) — requires `ANTHROPIC_API_KEY`; simulation is opt-in.
- **Vendor lock-in:** none at the code level. LiteLLM translates the OpenAI-compatible API to whatever backend.
- **Observability:** LiteLLM logs all requests; `tip_scores.llm_model` + `tip_scores.prompt_version` track which model + prompt generated each served tip.
## Alternatives considered
- **Call Ollama directly:** cheaper in latency, but ties code to Ollama's API format and makes cloud fallback a code change.
- **Call Anthropic directly from TS:** violates the rule that TS services never hold model names (CLAUDE.md prime directive 3).

View File

@@ -82,6 +82,8 @@ client ─► gateway ─► recommender (TS)
◄─ best TipCandidate ◄─ best TipCandidate
``` ```
**Phase 1 (current):** candidates come from Todoist task list, no LLM. The bandit scores tasks directly. **Phase 1 (shipped M1):** candidates come from Todoist task list, no LLM. The bandit scores tasks directly.
**Phase 2 (shipped M2):** LLM candidates are generated in parallel with Todoist fetch. Both pools are merged, scored by the bandit, and the winner served. `tip_scores` tracks `prompt_version`, `llm_model`, and `tip_kind` for every row.
Feedback: `POST /feedback → events.emit(reaction)` → online bandit update + `prompt_version` tracked for A/B analysis. Feedback: `POST /feedback → events.emit(reaction)` → online bandit update + `prompt_version` tracked for A/B analysis.

View File

@@ -14,9 +14,9 @@ services:
volumes: volumes:
- /mnt/ssd/dbs/oo:/mnt/ssd/dbs/oo - /mnt/ssd/dbs/oo:/mnt/ssd/dbs/oo
ports: ports:
- "127.0.0.1:3078:3078" - "127.0.0.1:3001:3001"
healthcheck: healthcheck:
test: ["CMD", "wget", "--spider", "-q", "http://localhost:3078/health"] test: ["CMD", "wget", "--spider", "-q", "http://localhost:3001/health"]
interval: 10s interval: 10s
timeout: 5s timeout: 5s
retries: 5 retries: 5
@@ -49,7 +49,7 @@ services:
PORT: "3080" PORT: "3080"
HOSTNAME: "0.0.0.0" HOSTNAME: "0.0.0.0"
NEXT_PUBLIC_API_URL: "" NEXT_PUBLIC_API_URL: ""
INTERNAL_API_URL: "http://api:3078" INTERNAL_API_URL: "http://api:3001"
ports: ports:
- "127.0.0.1:3080:3080" - "127.0.0.1:3080:3080"
depends_on: depends_on:
@@ -63,6 +63,10 @@ services:
context: ../.. context: ../..
dockerfile: infra/docker/Dockerfile.ml dockerfile: infra/docker/Dockerfile.ml
profiles: [full] profiles: [full]
env_file: ../../.env.local
environment:
LITELLM_URL: ${LITELLM_URL:-http://litellm:4000}
OLLAMA_URL: ${OLLAMA_URL:-http://ollama:11434}
ports: ports:
- "127.0.0.1:8000:8000" - "127.0.0.1:8000:8000"
healthcheck: healthcheck:
@@ -155,6 +159,45 @@ services:
airflow-init: airflow-init:
condition: service_completed_successfully condition: service_completed_successfully
# ── ai profile — Ollama + LiteLLM ────────────────────────────────────────
# Start: docker compose --profile ai up
# LiteLLM proxy: http://localhost:4000 (master key from LITELLM_MASTER_KEY)
# Ollama API: http://localhost:11434
# In prod both are shared Agap services; set LITELLM_URL + OLLAMA_URL in .env.local
ollama:
image: ollama/ollama:latest
profiles: [ai]
volumes:
- /mnt/ssd/dbs/oo/ollama:/root/.ollama
ports:
- "127.0.0.1:11434:11434"
healthcheck:
test: ["CMD", "curl", "--fail", "http://localhost:11434"]
interval: 15s
timeout: 5s
retries: 5
litellm:
image: ghcr.io/berriai/litellm:main-latest
profiles: [ai]
command: ["--config", "/app/litellm_config.yaml", "--port", "4000"]
environment:
LITELLM_MASTER_KEY: ${LITELLM_MASTER_KEY:-sk-oo-dev}
OLLAMA_URL: ${OLLAMA_URL:-http://ollama:11434}
volumes:
- ../../infra/litellm/litellm_config.yaml:/app/litellm_config.yaml:ro
ports:
- "127.0.0.1:4000:4000"
depends_on:
ollama:
condition: service_healthy
healthcheck:
test: ["CMD", "curl", "--fail", "http://localhost:4000/health"]
interval: 15s
timeout: 5s
retries: 5
mlflow: mlflow:
image: ghcr.io/mlflow/mlflow:2.14.3 image: ghcr.io/mlflow/mlflow:2.14.3
profiles: [mlops] profiles: [mlops]

View File

@@ -0,0 +1,17 @@
model_list:
- model_name: tip-generator
litellm_params:
model: ollama/qwen2.5:7b
api_base: "${OLLAMA_URL}"
- model_name: embedder
litellm_params:
model: ollama/nomic-embed-text
api_base: "${OLLAMA_URL}"
- model_name: judge
litellm_params:
model: claude-haiku-4-5-20251001
general_settings:
master_key: "${LITELLM_MASTER_KEY}"

View File

@@ -4,8 +4,8 @@ Python. Owns models, features, training, online scoring.
| Dir | Role | Phase | | Dir | Role | Phase |
|---|---|---| |---|---|---|
| `serving/` | FastAPI online scorer (`/score`), called by `recommender` | 1 | | `serving/` | FastAPI online scorer (`/score`, `/generate`) + LiteLLM gateway, called by `recommender` | 12 |
| `features/` | feature definitions + store adapter (Feast later) | 1 | | `features/` | context assembler (`context.py`): signals → `PromptContext`; Feast adapter later | 2 |
| `pipelines/` | batch feature + training DAGs (Prefect/Airflow) | 4 | | `pipelines/` | batch feature + training DAGs (Prefect/Airflow) | 4 |
| `registry/` | MLflow-backed model registry integration | 4 | | `registry/` | MLflow-backed model registry integration | 4 |
| `experiments/` | A/B assignment + multi-armed bandit policies | 4 | | `experiments/` | A/B assignment + multi-armed bandit policies | 4 |

3
ml/features/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .context import build_context, PromptContext, TaskSignal
__all__ = ["build_context", "PromptContext", "TaskSignal"]

63
ml/features/context.py Normal file
View File

@@ -0,0 +1,63 @@
"""
Context assembler — converts raw user signals into a PromptContext for LLM tip generation.
Usage:
from ml.features.context import build_context
ctx = build_context(tasks, hour_of_day=9, day_of_week=2)
"""
from __future__ import annotations
from dataclasses import dataclass, field
@dataclass
class TaskSignal:
id: str
content: str
priority: int = 1 # 14 (Todoist scale)
is_overdue: bool = False
task_age_days: float = 0.0
due_date: str | None = None
@dataclass
class PromptContext:
tasks: list[dict] = field(default_factory=list)
hour_of_day: int = 12
day_of_week: int = 0
extra: dict = field(default_factory=dict)
def build_context(
tasks: list[TaskSignal],
hour_of_day: int = 12,
day_of_week: int = 0,
extra: dict | None = None,
) -> PromptContext:
"""
Assemble user signals into a PromptContext.
Signals are sorted so overdue + high-priority tasks appear first,
giving the LLM the most actionable context at the top of the prompt.
"""
sorted_tasks = sorted(
tasks,
key=lambda t: (not t.is_overdue, -t.priority, -t.task_age_days),
)
task_dicts = [
{
"id": t.id,
"content": t.content,
"priority": t.priority,
"is_overdue": t.is_overdue,
"task_age_days": round(t.task_age_days, 1),
"due_date": t.due_date,
}
for t in sorted_tasks
]
return PromptContext(
tasks=task_dicts,
hour_of_day=hour_of_day,
day_of_week=day_of_week,
extra=extra or {},
)

View File

@@ -0,0 +1,64 @@
"""Tests for ml/features/context.py"""
import pytest
import sys, os; sys.path.insert(0, os.path.dirname(__file__))
from context import build_context, TaskSignal, PromptContext
def test_empty_tasks():
ctx = build_context([], hour_of_day=9, day_of_week=1)
assert ctx.tasks == []
assert ctx.hour_of_day == 9
assert ctx.day_of_week == 1
def test_overdue_tasks_sorted_first():
tasks = [
TaskSignal(id="a", content="Normal task", priority=1, is_overdue=False),
TaskSignal(id="b", content="Overdue task", priority=2, is_overdue=True, task_age_days=3.0),
]
ctx = build_context(tasks)
assert ctx.tasks[0]["id"] == "b"
def test_high_priority_within_non_overdue():
tasks = [
TaskSignal(id="lo", content="Low prio", priority=1, is_overdue=False),
TaskSignal(id="hi", content="High prio", priority=4, is_overdue=False),
]
ctx = build_context(tasks)
assert ctx.tasks[0]["id"] == "hi"
def test_extra_fields_passed_through():
ctx = build_context([], extra={"mood": "focused"})
assert ctx.extra["mood"] == "focused"
def test_task_age_rounded():
tasks = [TaskSignal(id="x", content="Task", task_age_days=1.23456)]
ctx = build_context(tasks)
assert ctx.tasks[0]["task_age_days"] == 1.2
def test_overdue_sorted_by_priority():
tasks = [
TaskSignal(id="lo", content="Low", priority=1, is_overdue=True),
TaskSignal(id="hi", content="High", priority=4, is_overdue=True),
]
ctx = build_context(tasks)
assert ctx.tasks[0]["id"] == "hi"
def test_overdue_same_priority_sorted_by_age():
tasks = [
TaskSignal(id="new", content="New", priority=2, is_overdue=True, task_age_days=1.0),
TaskSignal(id="old", content="Old", priority=2, is_overdue=True, task_age_days=5.0),
]
ctx = build_context(tasks)
assert ctx.tasks[0]["id"] == "old"
def test_due_date_none_preserved():
tasks = [TaskSignal(id="x", content="No due", due_date=None)]
ctx = build_context(tasks)
assert ctx.tasks[0]["due_date"] is None

View File

@@ -26,12 +26,16 @@ from collections import deque
from pathlib import Path from pathlib import Path
from typing import Optional, Deque from typing import Optional, Deque
import httpx
import numpy as np import numpy as np
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
app = FastAPI(title="oO ML Serving", version="1.0.0") app = FastAPI(title="oO ML Serving", version="1.0.0")
LITELLM_URL = os.getenv("LITELLM_URL", "http://localhost:4000")
LITELLM_MASTER_KEY = os.getenv("LITELLM_MASTER_KEY", "sk-oo-dev")
STATE_DIR = Path(os.getenv("STATE_DIR", "/tmp/oo-bandit-state")) STATE_DIR = Path(os.getenv("STATE_DIR", "/tmp/oo-bandit-state"))
STATE_DIR.mkdir(parents=True, exist_ok=True) STATE_DIR.mkdir(parents=True, exist_ok=True)
@@ -166,6 +170,56 @@ class RewardResponse(BaseModel):
ok: bool ok: bool
class PromptContext(BaseModel):
tasks: list[dict] = []
hour_of_day: int = 12
day_of_week: int = 0
extra: dict = {}
class GenerateRequest(BaseModel):
user_id: str
context: PromptContext = PromptContext()
n: int = 3
class TipCandidate(BaseModel):
id: str
content: str
source: str = "llm"
rationale: Optional[str] = None
class GenerateResponse(BaseModel):
candidates: list[TipCandidate]
model: str
prompt_tokens: int = 0
completion_tokens: int = 0
_GENERATE_SYSTEM = (
"You are a personal productivity coach. "
"Given the user's current context, generate actionable, specific tips. "
"Respond ONLY with a JSON array of objects, each with keys: "
'"id" (short slug), "content" (the tip, ≤2 sentences), "rationale" (why now, ≤1 sentence). '
"No markdown, no prose outside the JSON array."
)
def _build_prompt(ctx: PromptContext, n: int) -> str:
lines = [f"Time: {ctx.hour_of_day:02d}:00, day_of_week={ctx.day_of_week}"]
if ctx.tasks:
overdue = [t for t in ctx.tasks if t.get("is_overdue")]
lines.append(f"Tasks: {len(ctx.tasks)} total, {len(overdue)} overdue")
for t in ctx.tasks[:5]:
due = t.get("due_date", "no due date")
lines.append(f" - [{t.get('priority','?')}] {t.get('content','?')} (due: {due})")
for k, v in ctx.extra.items():
lines.append(f"{k}: {v}")
lines.append(f"\nGenerate {n} tips as a JSON array.")
return "\n".join(lines)
# ── Endpoints ────────────────────────────────────────────────────────────── # ── Endpoints ──────────────────────────────────────────────────────────────
@app.get("/health") @app.get("/health")
@@ -173,6 +227,97 @@ def health():
return {"ok": True} return {"ok": True}
_RETRY_SUFFIX = (
"\n\nYour previous response was not valid JSON. "
"Reply ONLY with the JSON array — no prose, no markdown fences."
)
_MAX_GENERATE_RETRIES = 2
def _parse_llm_json(raw: str) -> list[dict]:
"""Strip markdown fences and parse JSON array. Raises ValueError on failure."""
text = raw.strip()
if text.startswith("```"):
parts = text.split("```")
text = parts[1] if len(parts) > 1 else text
if text.startswith("json"):
text = text[4:]
return json.loads(text)
@app.post("/generate", response_model=GenerateResponse)
async def generate(req: GenerateRequest) -> GenerateResponse:
"""Generate tip candidates via LiteLLM → tip-generator alias.
Retries up to _MAX_GENERATE_RETRIES times on malformed JSON, appending
a correction hint to the conversation so the model can self-correct.
"""
prompt = _build_prompt(req.context, req.n)
messages: list[dict] = [
{"role": "system", "content": _GENERATE_SYSTEM},
{"role": "user", "content": prompt},
]
headers = {"Authorization": f"Bearer {LITELLM_MASTER_KEY}"}
last_parse_error: str = ""
last_raw: str = ""
total_usage: dict = {"prompt_tokens": 0, "completion_tokens": 0}
model_used = "tip-generator"
async with httpx.AsyncClient(timeout=30.0) as client:
for attempt in range(1 + _MAX_GENERATE_RETRIES):
payload = {"model": "tip-generator", "messages": messages, "temperature": 0.7}
try:
resp = await client.post(
f"{LITELLM_URL}/chat/completions",
json=payload,
headers=headers,
)
resp.raise_for_status()
except httpx.HTTPStatusError as e:
raise HTTPException(status_code=502, detail=f"LiteLLM error: {e.response.text}")
except httpx.RequestError as e:
raise HTTPException(status_code=503, detail=f"LiteLLM unreachable: {e}")
data = resp.json()
usage = data.get("usage", {})
total_usage["prompt_tokens"] += usage.get("prompt_tokens", 0)
total_usage["completion_tokens"] += usage.get("completion_tokens", 0)
model_used = data.get("model", "tip-generator")
last_raw = data["choices"][0]["message"]["content"]
try:
items = _parse_llm_json(last_raw)
break
except (json.JSONDecodeError, ValueError) as e:
last_parse_error = str(e)
# Feed the bad reply back so the model can self-correct
messages.append({"role": "assistant", "content": last_raw})
messages.append({"role": "user", "content": _RETRY_SUFFIX})
else:
raise HTTPException(
status_code=502,
detail=f"LLM returned invalid JSON after {_MAX_GENERATE_RETRIES} retries: "
f"{last_parse_error}\n{last_raw[:200]}",
)
candidates = [
TipCandidate(
id=item.get("id", f"tip-{i}"),
content=item.get("content", ""),
rationale=item.get("rationale"),
)
for i, item in enumerate(items)
]
return GenerateResponse(
candidates=candidates,
model=model_used,
prompt_tokens=total_usage["prompt_tokens"],
completion_tokens=total_usage["completion_tokens"],
)
@app.post("/score", response_model=ScoreResponse) @app.post("/score", response_model=ScoreResponse)
def score(req: ScoreRequest) -> ScoreResponse: def score(req: ScoreRequest) -> ScoreResponse:
if not req.candidates: if not req.candidates:

View File

@@ -0,0 +1,225 @@
"""
Tests for POST /generate — LiteLLM gateway.
LiteLLM is mocked; no real network calls.
"""
import json
import pytest
import httpx
from unittest.mock import AsyncMock, patch
from httpx import AsyncClient, ASGITransport, Response
from main import app, _build_prompt, PromptContext
def _litellm_response(candidates: list[dict]) -> Response:
import httpx
body = {
"model": "tip-generator",
"choices": [{"message": {"content": json.dumps(candidates)}}],
"usage": {"prompt_tokens": 10, "completion_tokens": 20},
}
req = httpx.Request("POST", "http://litellm/chat/completions")
return Response(200, json=body, request=req)
@pytest.mark.anyio
async def test_generate_returns_candidates():
fake_items = [
{"id": "tip-1", "content": "Do the overdue task now.", "rationale": "It's been waiting."},
{"id": "tip-2", "content": "Take a 5-minute break.", "rationale": "You've been working long."},
]
mock_resp = _litellm_response(fake_items)
with patch("main.httpx.AsyncClient") as MockClient:
instance = AsyncMock()
instance.post = AsyncMock(return_value=mock_resp)
instance.__aenter__ = AsyncMock(return_value=instance)
instance.__aexit__ = AsyncMock(return_value=False)
MockClient.return_value = instance
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.post("/generate", json={"user_id": "u1", "n": 2})
assert resp.status_code == 200
data = resp.json()
assert len(data["candidates"]) == 2
assert data["candidates"][0]["id"] == "tip-1"
assert data["model"] == "tip-generator"
@pytest.mark.anyio
async def test_generate_strips_markdown_fence():
fake_items = [{"id": "tip-a", "content": "Focus.", "rationale": "Now."}]
fenced = "```json\n" + json.dumps(fake_items) + "\n```"
body = {
"model": "tip-generator",
"choices": [{"message": {"content": fenced}}],
"usage": {},
}
req = httpx.Request("POST", "http://litellm/chat/completions")
mock_resp = Response(200, json=body, request=req)
with patch("main.httpx.AsyncClient") as MockClient:
instance = AsyncMock()
instance.post = AsyncMock(return_value=mock_resp)
instance.__aenter__ = AsyncMock(return_value=instance)
instance.__aexit__ = AsyncMock(return_value=False)
MockClient.return_value = instance
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.post("/generate", json={"user_id": "u1"})
assert resp.status_code == 200
assert resp.json()["candidates"][0]["id"] == "tip-a"
@pytest.mark.anyio
async def test_generate_503_on_unreachable():
import httpx as _httpx
with patch("main.httpx.AsyncClient") as MockClient:
instance = AsyncMock()
instance.post = AsyncMock(side_effect=_httpx.ConnectError("refused"))
instance.__aenter__ = AsyncMock(return_value=instance)
instance.__aexit__ = AsyncMock(return_value=False)
MockClient.return_value = instance
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.post("/generate", json={"user_id": "u1"})
assert resp.status_code == 503
def test_build_prompt_includes_tasks():
ctx = PromptContext(
tasks=[{"content": "Write report", "priority": 4, "is_overdue": True, "due_date": "2026-04-15"}],
hour_of_day=9,
day_of_week=2,
)
prompt = _build_prompt(ctx, n=3)
assert "Write report" in prompt
assert "09:00" in prompt
assert "Generate 3 tips" in prompt
def test_build_prompt_truncates_at_five():
tasks = [{"content": f"Task {i}", "priority": 1, "is_overdue": False, "due_date": None} for i in range(8)]
ctx = PromptContext(tasks=tasks, hour_of_day=12)
prompt = _build_prompt(ctx, n=2)
assert "Task 4" in prompt
assert "Task 5" not in prompt
def test_build_prompt_extra_fields():
ctx = PromptContext(tasks=[], hour_of_day=8, extra={"mood": "focused", "energy": "high"})
prompt = _build_prompt(ctx, n=1)
assert "mood: focused" in prompt
assert "energy: high" in prompt
def test_build_prompt_empty_tasks_no_task_line():
ctx = PromptContext(tasks=[], hour_of_day=10)
prompt = _build_prompt(ctx, n=2)
assert "Tasks:" not in prompt
assert "Generate 2 tips" in prompt
@pytest.mark.anyio
async def test_generate_retry_succeeds_on_second_attempt():
"""First response is invalid JSON; second is valid. Should return 200."""
valid_items = [{"id": "tip-ok", "content": "Retry worked.", "rationale": "Second try."}]
bad_req = httpx.Request("POST", "http://litellm/chat/completions")
bad_resp = Response(200, json={
"model": "tip-generator",
"choices": [{"message": {"content": "this is not json"}}],
"usage": {},
}, request=bad_req)
good_resp = _litellm_response(valid_items)
with patch("main.httpx.AsyncClient") as MockClient:
instance = AsyncMock()
instance.post = AsyncMock(side_effect=[bad_resp, good_resp])
instance.__aenter__ = AsyncMock(return_value=instance)
instance.__aexit__ = AsyncMock(return_value=False)
MockClient.return_value = instance
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.post("/generate", json={"user_id": "u1", "n": 1})
assert resp.status_code == 200
assert resp.json()["candidates"][0]["id"] == "tip-ok"
assert instance.post.call_count == 2
# Retry message should include the correction suffix
second_call_messages = instance.post.call_args_list[1][1]["json"]["messages"]
assert any("not valid JSON" in m["content"] for m in second_call_messages)
@pytest.mark.anyio
async def test_generate_502_after_all_retries_exhausted():
"""All attempts return invalid JSON → 502."""
bad_req = httpx.Request("POST", "http://litellm/chat/completions")
def _bad_resp():
return Response(200, json={
"model": "tip-generator",
"choices": [{"message": {"content": "not json at all"}}],
"usage": {},
}, request=bad_req)
from main import _MAX_GENERATE_RETRIES
responses = [_bad_resp() for _ in range(1 + _MAX_GENERATE_RETRIES)]
with patch("main.httpx.AsyncClient") as MockClient:
instance = AsyncMock()
instance.post = AsyncMock(side_effect=responses)
instance.__aenter__ = AsyncMock(return_value=instance)
instance.__aexit__ = AsyncMock(return_value=False)
MockClient.return_value = instance
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.post("/generate", json={"user_id": "u1"})
assert resp.status_code == 502
assert "retries" in resp.json()["detail"]
@pytest.mark.anyio
async def test_generate_502_on_upstream_http_error():
"""LiteLLM returns 500 → HTTPStatusError → 502."""
err_req = httpx.Request("POST", "http://litellm/chat/completions")
err_resp = Response(500, text="internal error", request=err_req)
with patch("main.httpx.AsyncClient") as MockClient:
instance = AsyncMock()
instance.post = AsyncMock(side_effect=httpx.HTTPStatusError(
"500", request=err_req, response=err_resp
))
instance.__aenter__ = AsyncMock(return_value=instance)
instance.__aexit__ = AsyncMock(return_value=False)
MockClient.return_value = instance
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
resp = await client.post("/generate", json={"user_id": "u1"})
assert resp.status_code == 502
assert "LiteLLM error" in resp.json()["detail"]
def test_parse_llm_json_bare_fence():
from main import _parse_llm_json
raw = "```\n[{\"id\":\"x\",\"content\":\"hi\"}]\n```"
items = _parse_llm_json(raw)
assert items[0]["id"] == "x"
def test_parse_llm_json_no_fence():
from main import _parse_llm_json
raw = '[{"id":"plain","content":"no fence"}]'
items = _parse_llm_json(raw)
assert items[0]["id"] == "plain"
def test_parse_llm_json_raises_on_invalid():
from main import _parse_llm_json
with pytest.raises((ValueError, Exception)):
_parse_llm_json("this is not json")

View File

@@ -1,5 +1,5 @@
import { describe, it, expect } from 'vitest'; import { describe, it, expect } from 'vitest';
import type { Tip, TipFeedback, RecommendResponse } from '../index.js'; import type { Tip, TipFeedback, TipCandidate, RecommendResponse } from '../index.js';
describe('Tip type contract', () => { describe('Tip type contract', () => {
it('accepts a valid Tip object', () => { it('accepts a valid Tip object', () => {
@@ -7,6 +7,7 @@ describe('Tip type contract', () => {
id: 'todoist:123', id: 'todoist:123',
content: 'Finish the report', content: 'Finish the report',
source: 'todoist', source: 'todoist',
kind: 'task',
sourceId: '123', sourceId: '123',
createdAt: new Date().toISOString(), createdAt: new Date().toISOString(),
}; };
@@ -18,6 +19,7 @@ describe('Tip type contract', () => {
id: 'advice:abc', id: 'advice:abc',
content: 'Take a break', content: 'Take a break',
source: 'advice', source: 'advice',
kind: 'advice',
createdAt: new Date().toISOString(), createdAt: new Date().toISOString(),
}; };
expect(tip.sourceId).toBeUndefined(); expect(tip.sourceId).toBeUndefined();
@@ -25,16 +27,45 @@ describe('Tip type contract', () => {
it('RecommendResponse wraps a Tip', () => { it('RecommendResponse wraps a Tip', () => {
const res: RecommendResponse = { const res: RecommendResponse = {
tip: { id: 'x', content: 'Do it', source: 'todoist', createdAt: '' }, tip: { id: 'x', content: 'Do it', source: 'todoist', kind: 'task', createdAt: '' },
}; };
expect(res.tip.id).toBe('x'); expect(res.tip.id).toBe('x');
}); });
it('TipFeedback allows valid actions', () => { it('TipFeedback allows all valid actions including helpful/not_helpful', () => {
const actions: TipFeedback['action'][] = ['done', 'dismiss', 'snooze']; const actions: TipFeedback['action'][] = ['done', 'dismiss', 'snooze', 'helpful', 'not_helpful'];
for (const action of actions) { for (const action of actions) {
const fb: TipFeedback = { action }; const fb: TipFeedback = { action };
expect(fb.action).toBe(action); expect(fb.action).toBe(action);
} }
}); });
it('Tip accepts optional rationale', () => {
const tip: Tip = {
id: 'llm:tip-1',
content: 'Block 30 min for deep work.',
source: 'llm',
kind: 'advice',
rationale: 'Your calendar is clear until noon.',
createdAt: new Date().toISOString(),
};
expect(tip.rationale).toBeDefined();
});
it('Tip rationale is optional', () => {
const tip: Tip = { id: 'x', content: 'Do it', source: 'todoist', kind: 'task', createdAt: '' };
expect(tip.rationale).toBeUndefined();
});
it('TipCandidate includes features', () => {
const c: TipCandidate = {
id: 'todoist:1',
content: 'Finish report',
source: 'todoist',
kind: 'task',
createdAt: '',
features: { is_overdue: true, task_age_days: 2, priority: 4 },
};
expect(c.features.is_overdue).toBe(true);
});
}); });

View File

@@ -1,12 +1,32 @@
/** Category of a tip — drives icon, CTA copy, and reward inference */
export type TipKind = 'task' | 'advice' | 'insight' | 'reminder';
/** Where the tip content originated */
export type TipSource = 'todoist' | 'llm' | 'advice';
/** A single recommendation surfaced to the user */ /** A single recommendation surfaced to the user */
export interface Tip { export interface Tip {
id: string; id: string;
content: string; content: string;
source: 'todoist' | 'advice'; source: TipSource;
kind: TipKind;
sourceId?: string; sourceId?: string;
rationale?: string; // LLM-generated "why now" shown on long-press
createdAt: string; // ISO 8601 createdAt: string; // ISO 8601
} }
/**
* A scored tip candidate flowing through the bandit pipeline.
* Extends Tip with features needed for scoring.
*/
export interface TipCandidate extends Tip {
features: {
is_overdue: boolean;
task_age_days: number;
priority: number;
};
}
/** POST /recommend response */ /** POST /recommend response */
export interface RecommendResponse { export interface RecommendResponse {
tip: Tip; tip: Tip;

View File

@@ -32,6 +32,7 @@ export const config = {
WEB_BASE_URL: optional('WEB_BASE_URL', 'http://localhost:3000'), WEB_BASE_URL: optional('WEB_BASE_URL', 'http://localhost:3000'),
ML_SERVING_URL: optional('ML_SERVING_URL', 'http://localhost:8000'), ML_SERVING_URL: optional('ML_SERVING_URL', 'http://localhost:8000'),
LITELLM_URL: optional('LITELLM_URL', 'http://localhost:4000'),
VAPID_PUBLIC_KEY: optional('VAPID_PUBLIC_KEY', ''), VAPID_PUBLIC_KEY: optional('VAPID_PUBLIC_KEY', ''),
VAPID_PRIVATE_KEY: optional('VAPID_PRIVATE_KEY', ''), VAPID_PRIVATE_KEY: optional('VAPID_PRIVATE_KEY', ''),

View File

@@ -142,6 +142,10 @@ export function runMigrations() {
`ALTER TABLE push_subscriptions ADD COLUMN created_at TEXT NOT NULL DEFAULT ''`, `ALTER TABLE push_subscriptions ADD COLUMN created_at TEXT NOT NULL DEFAULT ''`,
`ALTER TABLE tip_feedback ADD COLUMN dwell_ms INTEGER`, `ALTER TABLE tip_feedback ADD COLUMN dwell_ms INTEGER`,
`ALTER TABLE tip_feedback ADD COLUMN reward_milli INTEGER`, `ALTER TABLE tip_feedback ADD COLUMN reward_milli INTEGER`,
`ALTER TABLE integration_tokens ADD COLUMN token_status TEXT NOT NULL DEFAULT 'active'`,
`ALTER TABLE tip_scores ADD COLUMN prompt_version TEXT`,
`ALTER TABLE tip_scores ADD COLUMN llm_model TEXT`,
`ALTER TABLE tip_scores ADD COLUMN tip_kind TEXT`,
]) { ]) {
try { sqlite.exec(stmt); } catch { /* column already exists */ } try { sqlite.exec(stmt); } catch { /* column already exists */ }
} }

View File

@@ -20,6 +20,7 @@ export const integrationTokens = sqliteTable('integration_tokens', {
accessToken: text('access_token').notNull(), accessToken: text('access_token').notNull(),
refreshToken: text('refresh_token'), refreshToken: text('refresh_token'),
expiresAt: text('expires_at'), expiresAt: text('expires_at'),
tokenStatus: text('token_status').notNull().default('active'), // 'active' | 'needs_reconnect'
connectedAt: text('connected_at').notNull(), connectedAt: text('connected_at').notNull(),
}); });
@@ -81,6 +82,9 @@ export const tipScores = sqliteTable('tip_scores', {
candidateCount: integer('candidate_count'), candidateCount: integer('candidate_count'),
latencyMs: integer('latency_ms'), latencyMs: integer('latency_ms'),
servedAt: text('served_at').notNull(), servedAt: text('served_at').notNull(),
promptVersion: text('prompt_version'), // e.g. 'v1' — tracks which prompt template generated this tip
llmModel: text('llm_model'), // e.g. 'tip-generator/qwen2.5:7b' — null for bandit-only tips
tipKind: text('tip_kind'), // 'task' | 'advice' | 'insight' | 'reminder'
}); });
// ── Simulation runs ────────────────────────────────────────────────────────── // ── Simulation runs ──────────────────────────────────────────────────────────

View File

@@ -0,0 +1,190 @@
/**
* Integration tests for POST /recommend and tip_scores DB writes.
* Uses a real in-memory SQLite DB. recommender is imported dynamically
* inside beforeAll (same pattern as admin.test.ts) to avoid TDZ issues.
* Uses http.request (not fetch) as the test client so that globalThis.fetch
* mocking doesn't interfere with the test runner itself.
*/
import { describe, it, expect, vi, beforeAll, afterEach } from 'vitest';
import express from 'express';
import * as http from 'http';
import { makeTestDb } from '../../test/db.js';
import { users, integrationTokens, tipScores } from '../../db/schema.js';
const testDb = makeTestDb();
vi.mock('../../db/index.js', () => ({ db: testDb }));
vi.mock('../../middleware/session.js', () => ({
sessionMiddleware: (_req: express.Request, _res: express.Response, next: express.NextFunction) => next(),
requireAuth: (req: express.Request, _res: express.Response, next: express.NextFunction) => {
(req as any).userId = 'user-1';
next();
},
}));
vi.mock('../../events/bus.js', () => ({ bus: { publish: vi.fn() } }));
/** Minimal http.request wrapper → { status, body } */
function post(url: string): Promise<{ status: number; body: any }> {
return new Promise((resolve, reject) => {
const u = new URL(url);
const req = http.request(
{ hostname: u.hostname, port: Number(u.port), path: u.pathname, method: 'POST',
headers: { 'Content-Type': 'application/json' } },
(res) => {
let data = '';
res.on('data', (c) => { data += c; });
res.on('end', () => {
try { resolve({ status: res.statusCode ?? 0, body: data ? JSON.parse(data) : null }); }
catch { resolve({ status: res.statusCode ?? 0, body: data }); }
});
},
);
req.on('error', reject);
req.end();
});
}
describe('POST /recommend integration', () => {
let server: http.Server;
let baseUrl: string;
let savedFetch: typeof globalThis.fetch;
let clearCache: () => void;
beforeAll(async () => {
await testDb.insert(users).values({
id: 'user-1', email: 'u@test.com', role: 'user',
consentGiven: 1, createdAt: new Date().toISOString(),
});
await testDb.insert(integrationTokens).values({
id: 'tok-1', userId: 'user-1', provider: 'todoist',
accessToken: 'fake-token', connectedAt: new Date().toISOString(),
});
const mod = await import('../recommender.js');
const { recommenderRouter } = mod;
clearCache = (mod as any)._clearTaskCacheForTests;
const app = express();
app.use(express.json());
app.use('/api', recommenderRouter);
server = app.listen(0);
const addr = server.address() as { port: number };
baseUrl = `http://localhost:${addr.port}`;
savedFetch = globalThis.fetch;
});
afterEach(() => {
globalThis.fetch = savedFetch;
clearCache?.();
});
it('returns 204 when Todoist + LLM both return empty', async () => {
globalThis.fetch = vi.fn().mockResolvedValue({
ok: true, status: 200,
json: async () => ({ results: [] }),
} as any);
const { status } = await post(`${baseUrl}/api/recommend`);
expect(status).toBe(204);
});
it('serves todoist tip and writes correct tip_scores columns', async () => {
globalThis.fetch = vi.fn().mockImplementation((url: string) => {
if (String(url).includes('todoist.com')) {
return Promise.resolve({
ok: true, status: 200,
json: async () => ({
results: [{ id: 'task-1', content: 'Write tests', priority: 3, due: { date: '2026-04-10' } }],
}),
} as any);
}
if (String(url).includes('/generate')) {
return Promise.resolve({ ok: false, status: 503, json: async () => ({}) } as any);
}
if (String(url).includes('/score')) {
return Promise.resolve({
ok: true, status: 200,
json: async () => ({ tip_id: 'todoist:task-1', score: 0.8 }),
} as any);
}
return Promise.resolve({ ok: false, status: 500, json: async () => ({}) } as any);
});
const { status, body } = await post(`${baseUrl}/api/recommend`);
expect(status).toBe(200);
expect(body.tip.source).toBe('todoist');
expect(body.tip.kind).toBe('task');
const rows = await testDb.select().from(tipScores);
const row = rows[rows.length - 1];
expect(row.tipKind).toBe('task');
expect(row.promptVersion).toBeNull();
expect(row.llmModel).toBeNull();
});
it('writes prompt_version + llm_model when LLM tip is served', async () => {
globalThis.fetch = vi.fn().mockImplementation((url: string) => {
if (String(url).includes('todoist.com')) {
return Promise.resolve({
ok: true, status: 200,
json: async () => ({ results: [] }),
} as any);
}
if (String(url).includes('/generate')) {
return Promise.resolve({
ok: true, status: 200,
json: async () => ({
candidates: [{ id: 'adv-1', content: 'Take a break.', rationale: 'You deserve it.' }],
model: 'tip-generator',
}),
} as any);
}
if (String(url).includes('/score')) {
return Promise.resolve({
ok: true, status: 200,
json: async () => ({ tip_id: 'llm:adv-1', score: 0.9 }),
} as any);
}
return Promise.resolve({ ok: false, status: 500, json: async () => ({}) } as any);
});
const { status, body } = await post(`${baseUrl}/api/recommend`);
expect(status).toBe(200);
expect(body.tip.source).toBe('llm');
expect(body.tip.kind).toBe('advice');
expect(body.tip.rationale).toBe('You deserve it.');
const rows = await testDb.select().from(tipScores);
const row = rows[rows.length - 1];
expect(row.promptVersion).toBe('v1');
expect(row.llmModel).toBe('tip-generator');
expect(row.tipKind).toBe('advice');
});
it('falls back to todoist tip when /generate returns non-200', async () => {
globalThis.fetch = vi.fn().mockImplementation((url: string) => {
if (String(url).includes('todoist.com')) {
return Promise.resolve({
ok: true, status: 200,
json: async () => ({
results: [{ id: 'fallback-1', content: 'Do stuff', priority: 2, due: null }],
}),
} as any);
}
if (String(url).includes('/generate')) {
return Promise.resolve({ ok: false, status: 502, json: async () => ({}) } as any);
}
if (String(url).includes('/score')) {
return Promise.resolve({
ok: true, status: 200,
json: async () => ({ tip_id: 'todoist:fallback-1', score: 0.5 }),
} as any);
}
return Promise.resolve({ ok: false, status: 500, json: async () => ({}) } as any);
});
const { status, body } = await post(`${baseUrl}/api/recommend`);
expect([200, 204]).toContain(status);
if (status === 200) {
expect(body.tip.source).toBe('todoist');
}
});
});

View File

@@ -0,0 +1,39 @@
/**
* Pure-function unit tests for recommender logic — no DB, no HTTP.
* These can import directly from the module without any mocking.
*/
import { describe, it, expect } from 'vitest';
import { inferReward, dueAgeDays } from '../recommender.js';
describe('inferReward', () => {
it('dismiss → -1', () => expect(inferReward('dismiss', null)).toBe(-1.0));
it('snooze → +0.1', () => expect(inferReward('snooze', null)).toBe(0.1));
it('helpful → +0.5', () => expect(inferReward('helpful', null)).toBe(0.5));
it('not_helpful → -0.5', () => expect(inferReward('not_helpful', null)).toBe(-0.5));
it('done with null dwell → +0.5', () => expect(inferReward('done', null)).toBe(0.5));
it('done < 15s (reflex) → -0.3', () => expect(inferReward('done', 5_000)).toBe(-0.3));
it('done 15s2min (magic) → +1.0', () => expect(inferReward('done', 60_000)).toBe(1.0));
it('done 210min (good) → +0.6', () => expect(inferReward('done', 300_000)).toBe(0.6));
it('done > 10min (eventual) → +0.3', () => expect(inferReward('done', 700_000)).toBe(0.3));
it('done exactly 15s (boundary) → magic zone', () => expect(inferReward('done', 15_000)).toBe(1.0));
it('done exactly 2min (boundary) → good zone', () => expect(inferReward('done', 120_000)).toBe(0.6));
});
describe('dueAgeDays', () => {
it('null due → 0', () => expect(dueAgeDays(null)).toBe(0));
it('empty object → 0', () => expect(dueAgeDays({})).toBe(0));
it('future date → 0 (clamped)', () => {
const future = new Date(Date.now() + 86_400_000).toISOString();
expect(dueAgeDays({ datetime: future })).toBe(0);
});
it('past date → positive age', () => {
const twoDaysAgo = new Date(Date.now() - 2 * 86_400_000).toISOString();
const age = dueAgeDays({ datetime: twoDaysAgo });
expect(age).toBeGreaterThan(1.9);
expect(age).toBeLessThan(2.1);
});
it('date-only field used when datetime absent', () => {
const yesterday = new Date(Date.now() - 86_400_000).toISOString().slice(0, 10);
expect(dueAgeDays({ date: yesterday })).toBeGreaterThan(0);
});
});

View File

@@ -24,7 +24,7 @@ router.get('/', requireAuth, async (req: AuthenticatedRequest, res: Response) =>
const integrations = tokens.map((t) => ({ const integrations = tokens.map((t) => ({
provider: t.provider, provider: t.provider,
status: 'connected', status: t.tokenStatus === 'needs_reconnect' ? 'needs_reconnect' : 'connected',
connectedAt: t.connectedAt, connectedAt: t.connectedAt,
})); }));
@@ -97,6 +97,7 @@ router.get('/todoist/callback', async (req: Request, res: Response) => {
userId: pending.userId, userId: pending.userId,
provider: 'todoist', provider: 'todoist',
accessToken: access_token, accessToken: access_token,
tokenStatus: 'active',
connectedAt: now, connectedAt: now,
}); });

View File

@@ -6,23 +6,15 @@ import { eq, and, desc } from 'drizzle-orm';
import { requireAuth, AuthenticatedRequest } from '../middleware/session.js'; import { requireAuth, AuthenticatedRequest } from '../middleware/session.js';
import { config } from '../config.js'; import { config } from '../config.js';
import { bus } from '../events/bus.js'; import { bus } from '../events/bus.js';
import type { Tip } from '@oo/shared-types'; import type { TipCandidate } from '@oo/shared-types';
const router: ExpressRouter = Router(); const router: ExpressRouter = Router();
const CACHE_TTL_MS = 30_000; const CACHE_TTL_MS = 30_000;
const PROMPT_VERSION = 'v1';
interface TaskFeatures { const taskCache = new Map<string, { tasks: TipCandidate[]; fetchedAt: number }>();
is_overdue: boolean; export const _clearTaskCacheForTests = () => taskCache.clear();
task_age_days: number;
priority: number;
}
interface CachedTask extends Tip {
features: TaskFeatures;
}
const taskCache = new Map<string, { tasks: CachedTask[]; fetchedAt: number }>();
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Shadow-policy registry // Shadow-policy registry
@@ -49,7 +41,7 @@ export function setPolicyActive(name: string, active: boolean): boolean {
// Todoist helpers // Todoist helpers
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
function dueAgeDays(due: { date?: string; datetime?: string } | null | undefined): number { export function dueAgeDays(due: { date?: string; datetime?: string } | null | undefined): number {
if (!due) return 0; if (!due) return 0;
const dateStr = due.datetime ?? due.date; const dateStr = due.datetime ?? due.date;
if (!dateStr) return 0; if (!dateStr) return 0;
@@ -57,7 +49,7 @@ function dueAgeDays(due: { date?: string; datetime?: string } | null | undefined
return Math.max(0, (Date.now() - dueMs) / (1000 * 60 * 60 * 24)); return Math.max(0, (Date.now() - dueMs) / (1000 * 60 * 60 * 24));
} }
async function fetchTodoistTasks(userId: string, accessToken: string): Promise<CachedTask[]> { async function fetchTodoistTasks(userId: string, accessToken: string): Promise<TipCandidate[]> {
const cached = taskCache.get(userId); const cached = taskCache.get(userId);
if (cached && Date.now() - cached.fetchedAt < CACHE_TTL_MS) return cached.tasks; if (cached && Date.now() - cached.fetchedAt < CACHE_TTL_MS) return cached.tasks;
@@ -73,6 +65,10 @@ async function fetchTodoistTasks(userId: string, accessToken: string): Promise<C
provider: 'todoist', provider: 'todoist',
detectedAt: new Date().toISOString(), detectedAt: new Date().toISOString(),
}); });
await db
.update(integrationTokens)
.set({ tokenStatus: 'needs_reconnect' })
.where(and(eq(integrationTokens.userId, userId), eq(integrationTokens.provider, 'todoist')));
} }
return cached?.tasks ?? []; return cached?.tasks ?? [];
} }
@@ -87,13 +83,14 @@ async function fetchTodoistTasks(userId: string, accessToken: string): Promise<C
}; };
const now = new Date(); const now = new Date();
const tasks: CachedTask[] = (body.results ?? []).map((t) => { const tasks: TipCandidate[] = (body.results ?? []).map((t) => {
const ageDays = dueAgeDays(t.due); const ageDays = dueAgeDays(t.due);
const isOverdue = ageDays > 0; const isOverdue = ageDays > 0;
return { return {
id: `todoist:${t.id}`, id: `todoist:${t.id}`,
content: t.content, content: t.content,
source: 'todoist' as const, source: 'todoist' as const,
kind: 'task' as const,
sourceId: t.id, sourceId: t.id,
createdAt: now.toISOString(), createdAt: now.toISOString(),
features: { features: {
@@ -111,10 +108,14 @@ async function fetchTodoistTasks(userId: string, accessToken: string): Promise<C
return tasks; return tasks;
} }
// ---------------------------------------------------------------------------
// Stage 2: score candidates via ml/serving bandit
// ---------------------------------------------------------------------------
/** Call ml/serving for scored selection; returns { tip_id, score } or null on failure */ /** Call ml/serving for scored selection; returns { tip_id, score } or null on failure */
async function remotePolicy( async function remotePolicy(
userId: string, userId: string,
tasks: CachedTask[], tasks: TipCandidate[],
): Promise<{ tipId: string; score: number; policy: string } | null> { ): Promise<{ tipId: string; score: number; policy: string } | null> {
const hour = new Date().getHours(); const hour = new Date().getHours();
const dayOfWeek = new Date().getDay(); const dayOfWeek = new Date().getDay();
@@ -147,13 +148,64 @@ async function remotePolicy(
} }
} }
function randomPolicy(candidates: CachedTask[]): CachedTask | null { function randomPolicy(candidates: TipCandidate[]): TipCandidate | null {
if (!candidates.length) return null; if (!candidates.length) return null;
return candidates[Math.floor(Math.random() * candidates.length)]; return candidates[Math.floor(Math.random() * candidates.length)];
} }
// ---------------------------------------------------------------------------
// Stage 1b: fetch LLM candidates from ml/serving /generate
// ---------------------------------------------------------------------------
interface LlmCandidate {
id: string;
content: string;
rationale?: string;
}
async function fetchLlmCandidates(
userId: string,
todoistTasks: TipCandidate[],
hour: number,
dayOfWeek: number,
): Promise<TipCandidate[]> {
try {
const tasks = todoistTasks.slice(0, 10).map((t) => ({
content: t.content,
priority: t.features.priority,
is_overdue: t.features.is_overdue,
task_age_days: t.features.task_age_days,
}));
const res = await fetch(`${config.ML_SERVING_URL}/generate`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
user_id: userId,
context: { tasks, hour_of_day: hour, day_of_week: dayOfWeek },
n: 3,
}),
signal: AbortSignal.timeout(15_000),
});
if (!res.ok) return [];
const data = (await res.json()) as { candidates: LlmCandidate[]; model?: string };
const now = new Date().toISOString();
return data.candidates.map((c) => ({
id: `llm:${c.id}`,
content: c.content,
source: 'llm' as const,
kind: 'advice' as const,
rationale: c.rationale,
createdAt: now,
features: { is_overdue: false, task_age_days: 0, priority: 1 },
}));
} catch {
return [];
}
}
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// POST /api/recommend // POST /api/recommend
// Pipeline: [Stage 1] assemble candidates → [Stage 2] score → [Stage 3] serve
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
router.post('/recommend', requireAuth, async (req: AuthenticatedRequest, res: Response) => { router.post('/recommend', requireAuth, async (req: AuthenticatedRequest, res: Response) => {
const [token] = await db const [token] = await db
@@ -167,34 +219,42 @@ router.post('/recommend', requireAuth, async (req: AuthenticatedRequest, res: Re
return; return;
} }
const tasks = await fetchTodoistTasks(req.userId!, token.accessToken); const hour = new Date().getHours();
if (!tasks.length) { const dayOfWeek = new Date().getDay();
// Stage 1: assemble candidates — Todoist tasks + LLM-generated advice (parallel)
const [todoistTasks, llmCandidates] = await Promise.all([
fetchTodoistTasks(req.userId!, token.accessToken),
fetchLlmCandidates(req.userId!, taskCache.get(req.userId!)?.tasks ?? [], hour, dayOfWeek),
]);
const allCandidates: TipCandidate[] = [...todoistTasks, ...llmCandidates];
if (!allCandidates.length) {
res.status(204).end(); res.status(204).end();
return; return;
} }
const hour = new Date().getHours();
const dayOfWeek = new Date().getDay();
const t0 = Date.now(); const t0 = Date.now();
// RemotePolicy with RandomPolicy fallback // Stage 2: score — egreedy bandit with random fallback
const scored = await remotePolicy(req.userId!, tasks); const scored = await remotePolicy(req.userId!, allCandidates);
const latencyMs = Date.now() - t0; const latencyMs = Date.now() - t0;
const tip = scored const tip = scored
? (tasks.find((t) => t.id === scored.tipId) ?? randomPolicy(tasks)) ? (allCandidates.find((t) => t.id === scored.tipId) ?? randomPolicy(allCandidates))
: randomPolicy(tasks); : randomPolicy(allCandidates);
if (!tip) { if (!tip) {
res.status(204).end(); res.status(204).end();
return; return;
} }
// Stage 3: serve + log
const policy = scored ? scored.policy : 'random'; const policy = scored ? scored.policy : 'random';
const isLlmTip = tip.source === 'llm';
const servedAt = new Date().toISOString(); const servedAt = new Date().toISOString();
await db.insert(tipViews).values({ id: nanoid(), userId: req.userId!, tipId: tip.id, servedAt }); await db.insert(tipViews).values({ id: nanoid(), userId: req.userId!, tipId: tip.id, servedAt });
// Log recommendation explainability
await db.insert(tipScores).values({ await db.insert(tipScores).values({
id: nanoid(), id: nanoid(),
userId: req.userId!, userId: req.userId!,
@@ -208,9 +268,12 @@ router.post('/recommend', requireAuth, async (req: AuthenticatedRequest, res: Re
hour_of_day: hour, hour_of_day: hour,
day_of_week: dayOfWeek, day_of_week: dayOfWeek,
}), }),
candidateCount: tasks.length, candidateCount: allCandidates.length,
latencyMs, latencyMs,
servedAt, servedAt,
promptVersion: isLlmTip ? PROMPT_VERSION : null,
llmModel: isLlmTip ? 'tip-generator' : null,
tipKind: tip.kind ?? null,
}); });
bus.publish('signals.tip.served', { bus.publish('signals.tip.served', {
@@ -224,7 +287,7 @@ router.post('/recommend', requireAuth, async (req: AuthenticatedRequest, res: Re
for (const [name, s] of shadowPolicies) { for (const [name, s] of shadowPolicies) {
if (!s.active) continue; if (!s.active) continue;
if (name.startsWith('random')) { if (name.startsWith('random')) {
const shadowTip = randomPolicy(tasks); const shadowTip = randomPolicy(allCandidates);
bus.publish('signals.tip.served', { bus.publish('signals.tip.served', {
userId: req.userId!, userId: req.userId!,
tipId: shadowTip?.id ?? 'none', tipId: shadowTip?.id ?? 'none',
@@ -249,7 +312,7 @@ router.post('/recommend', requireAuth, async (req: AuthenticatedRequest, res: Re
// done 2 10 min → +0.6 (good: user engaged, acted in same session) // done 2 10 min → +0.6 (good: user engaged, acted in same session)
// done > 10 min → +0.3 (eventually done; tip may have helped, unclear) // done > 10 min → +0.3 (eventually done; tip may have helped, unclear)
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
function inferReward(action: string, dwellMs: number | null): number { export function inferReward(action: string, dwellMs: number | null): number {
if (action === 'dismiss') return -1.0; if (action === 'dismiss') return -1.0;
if (action === 'snooze') return 0.1; if (action === 'snooze') return 0.1;
if (action === 'helpful') return 0.5; if (action === 'helpful') return 0.5;
@@ -269,7 +332,7 @@ async function sendRewardWithRetry(
userId: string, userId: string,
tipId: string, tipId: string,
reward: number, reward: number,
features: TaskFeatures, features: TipCandidate['features'],
): Promise<void> { ): Promise<void> {
const body = JSON.stringify({ const body = JSON.stringify({
user_id: userId, user_id: userId,
@@ -347,7 +410,7 @@ router.post('/tip/:id/feedback', requireAuth, async (req: AuthenticatedRequest,
createdAt: now.toISOString(), createdAt: now.toISOString(),
}); });
const task = taskCache.get(req.userId!)?.tasks.find((t) => t.id === tipId); const task: TipCandidate | undefined = taskCache.get(req.userId!)?.tasks.find((t) => t.id === tipId);
taskCache.delete(req.userId!); taskCache.delete(req.userId!);

View File

@@ -32,6 +32,7 @@ export function makeTestDb() {
refresh_token TEXT, refresh_token TEXT,
expires_at TEXT, expires_at TEXT,
connected_at TEXT NOT NULL, connected_at TEXT NOT NULL,
token_status TEXT NOT NULL DEFAULT 'active',
UNIQUE(user_id, provider) UNIQUE(user_id, provider)
); );
@@ -88,7 +89,10 @@ export function makeTestDb() {
features_json TEXT, features_json TEXT,
candidate_count INTEGER, candidate_count INTEGER,
latency_ms INTEGER, latency_ms INTEGER,
served_at TEXT NOT NULL served_at TEXT NOT NULL,
prompt_version TEXT,
llm_model TEXT,
tip_kind TEXT
); );
CREATE TABLE IF NOT EXISTS saved_queries ( CREATE TABLE IF NOT EXISTS saved_queries (

View File

@@ -4,6 +4,13 @@ export default defineConfig({
test: { test: {
globals: true, globals: true,
environment: 'node', environment: 'node',
env: {
SESSION_SECRET: 'test-secret',
GOOGLE_CLIENT_ID: 'test-google-id',
GOOGLE_CLIENT_SECRET: 'test-google-secret',
TODOIST_CLIENT_ID: 'test-todoist-id',
TODOIST_CLIENT_SECRET: 'test-todoist-secret',
},
coverage: { coverage: {
provider: 'v8', provider: 'v8',
reporter: ['text', 'lcov'], reporter: ['text', 'lcov'],