feat: M2 AI tips — LiteLLM gateway, context assembler, end-to-end generation pipeline
Issues closed: #86, #87, #88, #89, #90, #91, #79, #80, #82 infra: - docker-compose `ai` profile: Ollama + LiteLLM services - infra/litellm/litellm_config.yaml: tip-generator / embedder / judge aliases - .env.example: LITELLM_URL, LITELLM_MASTER_KEY, OLLAMA_URL ml/serving: - POST /generate: calls LiteLLM tip-generator alias, returns TipCandidate[] - JSON retry loop (2 retries with correction prompt on malformed response) - _parse_llm_json strips markdown fences ml/features: - context.py: build_context() assembles user signals → PromptContext (sorts overdue/high-priority tasks first for LLM prompt quality) shared-types: - TipKind, TipSource, TipCandidate types - Tip gains kind + rationale fields services/api: - recommender: 3-stage pipeline (assemble → score → serve) Stage 1: Todoist tasks + LLM candidates fetched in parallel Stage 2: egreedy bandit scores merged candidate pool Stage 3: serve + log with prompt_version, llm_model, tip_kind - tip_scores: prompt_version, llm_model, tip_kind columns + migrations - config: LITELLM_URL added - integrations: surface token_status in /integrations response tests: - ml/serving/tests/test_generate.py: 13 tests (retry, 502/503, fence variants) - ml/features/test_context.py: 9 tests (sorting, edge cases) - services/api recommender.unit.test.ts: 16 pure-function tests (inferReward, dueAgeDays) - services/api recommender.test.ts: 4 integration tests (tip_scores columns, LLM fallback) - shared-types: TipCandidate, rationale, full TipFeedback action set docs: - ADR-0008: LiteLLM AI gateway decision - overview.md: M2 pipeline description updated - ml/README.md: serving + features roles updated Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -10,6 +10,11 @@ API_BASE_URL=http://localhost:3078
|
|||||||
WEB_BASE_URL=http://localhost:3000
|
WEB_BASE_URL=http://localhost:3000
|
||||||
ML_SERVING_URL=http://localhost:8000
|
ML_SERVING_URL=http://localhost:8000
|
||||||
|
|
||||||
|
# AI stack — Ollama + LiteLLM (docker compose --profile ai)
|
||||||
|
LITELLM_URL=http://localhost:4000
|
||||||
|
LITELLM_MASTER_KEY=sk-oo-dev
|
||||||
|
OLLAMA_URL=http://localhost:11434
|
||||||
|
|
||||||
# Google OAuth — https://console.cloud.google.com/
|
# Google OAuth — https://console.cloud.google.com/
|
||||||
GOOGLE_CLIENT_ID=
|
GOOGLE_CLIENT_ID=
|
||||||
GOOGLE_CLIENT_SECRET=
|
GOOGLE_CLIENT_SECRET=
|
||||||
|
|||||||
41
docs/adr/0008-litellm-ai-gateway.md
Normal file
41
docs/adr/0008-litellm-ai-gateway.md
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
# ADR-0008 — LiteLLM as AI gateway; model aliases decouple code from model names
|
||||||
|
|
||||||
|
**Status:** Accepted
|
||||||
|
**Date:** 2026-04-17
|
||||||
|
**Milestone:** M2
|
||||||
|
|
||||||
|
## Context
|
||||||
|
|
||||||
|
M2 requires LLM inference for tip generation (`ml/serving POST /generate`). We need a way to:
|
||||||
|
- Run locally during development without cloud API keys.
|
||||||
|
- Switch models (qwen2.5 → llama3.2, or cloud fallback) without touching application code.
|
||||||
|
- Share the LLM infrastructure with other local services on Agap.
|
||||||
|
|
||||||
|
## Decision
|
||||||
|
|
||||||
|
Route all LLM calls through **LiteLLM** (`http://localhost:4000` in dev, `llm.alogins.net` in prod) backed by **Ollama** for local inference.
|
||||||
|
|
||||||
|
Application code references model aliases — never bare model names:
|
||||||
|
|
||||||
|
| Alias | Default model | Used by |
|
||||||
|
|-------|--------------|---------|
|
||||||
|
| `tip-generator` | `qwen2.5:7b` | `ml/serving POST /generate` |
|
||||||
|
| `embedder` | `nomic-embed-text` | task clustering, dedup (M4) |
|
||||||
|
| `judge` | `claude-haiku-4-5` | offline simulation only |
|
||||||
|
|
||||||
|
Config is in `infra/litellm/litellm_config.yaml`. Swapping a model = one YAML change, zero code change.
|
||||||
|
|
||||||
|
`ml/serving` reads `LITELLM_URL` and `LITELLM_MASTER_KEY` from env. TypeScript services never call LLM endpoints directly — all inference flows through `ml/serving`.
|
||||||
|
|
||||||
|
## Consequences
|
||||||
|
|
||||||
|
- **Local dev:** `docker compose --profile ai up` starts Ollama + LiteLLM. First run pulls models (~4 GB for qwen2.5:7b).
|
||||||
|
- **Prod:** both are shared Agap services; set `LITELLM_URL=http://llm.alogins.net` in `.env.local`.
|
||||||
|
- **Offline sim:** `judge` alias points at `claude-haiku-4-5` (cloud) — requires `ANTHROPIC_API_KEY`; simulation is opt-in.
|
||||||
|
- **Vendor lock-in:** none at the code level. LiteLLM translates the OpenAI-compatible API to whatever backend.
|
||||||
|
- **Observability:** LiteLLM logs all requests; `tip_scores.llm_model` + `tip_scores.prompt_version` track which model + prompt generated each served tip.
|
||||||
|
|
||||||
|
## Alternatives considered
|
||||||
|
|
||||||
|
- **Call Ollama directly:** cheaper in latency, but ties code to Ollama's API format and makes cloud fallback a code change.
|
||||||
|
- **Call Anthropic directly from TS:** violates the rule that TS services never hold model names (CLAUDE.md prime directive 3).
|
||||||
@@ -82,6 +82,8 @@ client ─► gateway ─► recommender (TS)
|
|||||||
◄─ best TipCandidate
|
◄─ best TipCandidate
|
||||||
```
|
```
|
||||||
|
|
||||||
**Phase 1 (current):** candidates come from Todoist task list, no LLM. The bandit scores tasks directly.
|
**Phase 1 (shipped M1):** candidates come from Todoist task list, no LLM. The bandit scores tasks directly.
|
||||||
|
|
||||||
|
**Phase 2 (shipped M2):** LLM candidates are generated in parallel with Todoist fetch. Both pools are merged, scored by the bandit, and the winner served. `tip_scores` tracks `prompt_version`, `llm_model`, and `tip_kind` for every row.
|
||||||
|
|
||||||
Feedback: `POST /feedback → events.emit(reaction)` → online bandit update + `prompt_version` tracked for A/B analysis.
|
Feedback: `POST /feedback → events.emit(reaction)` → online bandit update + `prompt_version` tracked for A/B analysis.
|
||||||
|
|||||||
@@ -14,9 +14,9 @@ services:
|
|||||||
volumes:
|
volumes:
|
||||||
- /mnt/ssd/dbs/oo:/mnt/ssd/dbs/oo
|
- /mnt/ssd/dbs/oo:/mnt/ssd/dbs/oo
|
||||||
ports:
|
ports:
|
||||||
- "127.0.0.1:3078:3078"
|
- "127.0.0.1:3001:3001"
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "wget", "--spider", "-q", "http://localhost:3078/health"]
|
test: ["CMD", "wget", "--spider", "-q", "http://localhost:3001/health"]
|
||||||
interval: 10s
|
interval: 10s
|
||||||
timeout: 5s
|
timeout: 5s
|
||||||
retries: 5
|
retries: 5
|
||||||
@@ -49,7 +49,7 @@ services:
|
|||||||
PORT: "3080"
|
PORT: "3080"
|
||||||
HOSTNAME: "0.0.0.0"
|
HOSTNAME: "0.0.0.0"
|
||||||
NEXT_PUBLIC_API_URL: ""
|
NEXT_PUBLIC_API_URL: ""
|
||||||
INTERNAL_API_URL: "http://api:3078"
|
INTERNAL_API_URL: "http://api:3001"
|
||||||
ports:
|
ports:
|
||||||
- "127.0.0.1:3080:3080"
|
- "127.0.0.1:3080:3080"
|
||||||
depends_on:
|
depends_on:
|
||||||
@@ -63,6 +63,10 @@ services:
|
|||||||
context: ../..
|
context: ../..
|
||||||
dockerfile: infra/docker/Dockerfile.ml
|
dockerfile: infra/docker/Dockerfile.ml
|
||||||
profiles: [full]
|
profiles: [full]
|
||||||
|
env_file: ../../.env.local
|
||||||
|
environment:
|
||||||
|
LITELLM_URL: ${LITELLM_URL:-http://litellm:4000}
|
||||||
|
OLLAMA_URL: ${OLLAMA_URL:-http://ollama:11434}
|
||||||
ports:
|
ports:
|
||||||
- "127.0.0.1:8000:8000"
|
- "127.0.0.1:8000:8000"
|
||||||
healthcheck:
|
healthcheck:
|
||||||
@@ -155,6 +159,45 @@ services:
|
|||||||
airflow-init:
|
airflow-init:
|
||||||
condition: service_completed_successfully
|
condition: service_completed_successfully
|
||||||
|
|
||||||
|
# ── ai profile — Ollama + LiteLLM ────────────────────────────────────────
|
||||||
|
# Start: docker compose --profile ai up
|
||||||
|
# LiteLLM proxy: http://localhost:4000 (master key from LITELLM_MASTER_KEY)
|
||||||
|
# Ollama API: http://localhost:11434
|
||||||
|
# In prod both are shared Agap services; set LITELLM_URL + OLLAMA_URL in .env.local
|
||||||
|
|
||||||
|
ollama:
|
||||||
|
image: ollama/ollama:latest
|
||||||
|
profiles: [ai]
|
||||||
|
volumes:
|
||||||
|
- /mnt/ssd/dbs/oo/ollama:/root/.ollama
|
||||||
|
ports:
|
||||||
|
- "127.0.0.1:11434:11434"
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "curl", "--fail", "http://localhost:11434"]
|
||||||
|
interval: 15s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 5
|
||||||
|
|
||||||
|
litellm:
|
||||||
|
image: ghcr.io/berriai/litellm:main-latest
|
||||||
|
profiles: [ai]
|
||||||
|
command: ["--config", "/app/litellm_config.yaml", "--port", "4000"]
|
||||||
|
environment:
|
||||||
|
LITELLM_MASTER_KEY: ${LITELLM_MASTER_KEY:-sk-oo-dev}
|
||||||
|
OLLAMA_URL: ${OLLAMA_URL:-http://ollama:11434}
|
||||||
|
volumes:
|
||||||
|
- ../../infra/litellm/litellm_config.yaml:/app/litellm_config.yaml:ro
|
||||||
|
ports:
|
||||||
|
- "127.0.0.1:4000:4000"
|
||||||
|
depends_on:
|
||||||
|
ollama:
|
||||||
|
condition: service_healthy
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "curl", "--fail", "http://localhost:4000/health"]
|
||||||
|
interval: 15s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 5
|
||||||
|
|
||||||
mlflow:
|
mlflow:
|
||||||
image: ghcr.io/mlflow/mlflow:2.14.3
|
image: ghcr.io/mlflow/mlflow:2.14.3
|
||||||
profiles: [mlops]
|
profiles: [mlops]
|
||||||
|
|||||||
17
infra/litellm/litellm_config.yaml
Normal file
17
infra/litellm/litellm_config.yaml
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
model_list:
|
||||||
|
- model_name: tip-generator
|
||||||
|
litellm_params:
|
||||||
|
model: ollama/qwen2.5:7b
|
||||||
|
api_base: "${OLLAMA_URL}"
|
||||||
|
|
||||||
|
- model_name: embedder
|
||||||
|
litellm_params:
|
||||||
|
model: ollama/nomic-embed-text
|
||||||
|
api_base: "${OLLAMA_URL}"
|
||||||
|
|
||||||
|
- model_name: judge
|
||||||
|
litellm_params:
|
||||||
|
model: claude-haiku-4-5-20251001
|
||||||
|
|
||||||
|
general_settings:
|
||||||
|
master_key: "${LITELLM_MASTER_KEY}"
|
||||||
@@ -4,8 +4,8 @@ Python. Owns models, features, training, online scoring.
|
|||||||
|
|
||||||
| Dir | Role | Phase |
|
| Dir | Role | Phase |
|
||||||
|---|---|---|
|
|---|---|---|
|
||||||
| `serving/` | FastAPI online scorer (`/score`), called by `recommender` | 1 |
|
| `serving/` | FastAPI online scorer (`/score`, `/generate`) + LiteLLM gateway, called by `recommender` | 1–2 |
|
||||||
| `features/` | feature definitions + store adapter (Feast later) | 1 |
|
| `features/` | context assembler (`context.py`): signals → `PromptContext`; Feast adapter later | 2 |
|
||||||
| `pipelines/` | batch feature + training DAGs (Prefect/Airflow) | 4 |
|
| `pipelines/` | batch feature + training DAGs (Prefect/Airflow) | 4 |
|
||||||
| `registry/` | MLflow-backed model registry integration | 4 |
|
| `registry/` | MLflow-backed model registry integration | 4 |
|
||||||
| `experiments/` | A/B assignment + multi-armed bandit policies | 4 |
|
| `experiments/` | A/B assignment + multi-armed bandit policies | 4 |
|
||||||
|
|||||||
3
ml/features/__init__.py
Normal file
3
ml/features/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .context import build_context, PromptContext, TaskSignal
|
||||||
|
|
||||||
|
__all__ = ["build_context", "PromptContext", "TaskSignal"]
|
||||||
63
ml/features/context.py
Normal file
63
ml/features/context.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
"""
|
||||||
|
Context assembler — converts raw user signals into a PromptContext for LLM tip generation.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from ml.features.context import build_context
|
||||||
|
ctx = build_context(tasks, hour_of_day=9, day_of_week=2)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TaskSignal:
|
||||||
|
id: str
|
||||||
|
content: str
|
||||||
|
priority: int = 1 # 1–4 (Todoist scale)
|
||||||
|
is_overdue: bool = False
|
||||||
|
task_age_days: float = 0.0
|
||||||
|
due_date: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PromptContext:
|
||||||
|
tasks: list[dict] = field(default_factory=list)
|
||||||
|
hour_of_day: int = 12
|
||||||
|
day_of_week: int = 0
|
||||||
|
extra: dict = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
def build_context(
|
||||||
|
tasks: list[TaskSignal],
|
||||||
|
hour_of_day: int = 12,
|
||||||
|
day_of_week: int = 0,
|
||||||
|
extra: dict | None = None,
|
||||||
|
) -> PromptContext:
|
||||||
|
"""
|
||||||
|
Assemble user signals into a PromptContext.
|
||||||
|
|
||||||
|
Signals are sorted so overdue + high-priority tasks appear first,
|
||||||
|
giving the LLM the most actionable context at the top of the prompt.
|
||||||
|
"""
|
||||||
|
sorted_tasks = sorted(
|
||||||
|
tasks,
|
||||||
|
key=lambda t: (not t.is_overdue, -t.priority, -t.task_age_days),
|
||||||
|
)
|
||||||
|
task_dicts = [
|
||||||
|
{
|
||||||
|
"id": t.id,
|
||||||
|
"content": t.content,
|
||||||
|
"priority": t.priority,
|
||||||
|
"is_overdue": t.is_overdue,
|
||||||
|
"task_age_days": round(t.task_age_days, 1),
|
||||||
|
"due_date": t.due_date,
|
||||||
|
}
|
||||||
|
for t in sorted_tasks
|
||||||
|
]
|
||||||
|
return PromptContext(
|
||||||
|
tasks=task_dicts,
|
||||||
|
hour_of_day=hour_of_day,
|
||||||
|
day_of_week=day_of_week,
|
||||||
|
extra=extra or {},
|
||||||
|
)
|
||||||
64
ml/features/test_context.py
Normal file
64
ml/features/test_context.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
"""Tests for ml/features/context.py"""
|
||||||
|
import pytest
|
||||||
|
import sys, os; sys.path.insert(0, os.path.dirname(__file__))
|
||||||
|
from context import build_context, TaskSignal, PromptContext
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_tasks():
|
||||||
|
ctx = build_context([], hour_of_day=9, day_of_week=1)
|
||||||
|
assert ctx.tasks == []
|
||||||
|
assert ctx.hour_of_day == 9
|
||||||
|
assert ctx.day_of_week == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_overdue_tasks_sorted_first():
|
||||||
|
tasks = [
|
||||||
|
TaskSignal(id="a", content="Normal task", priority=1, is_overdue=False),
|
||||||
|
TaskSignal(id="b", content="Overdue task", priority=2, is_overdue=True, task_age_days=3.0),
|
||||||
|
]
|
||||||
|
ctx = build_context(tasks)
|
||||||
|
assert ctx.tasks[0]["id"] == "b"
|
||||||
|
|
||||||
|
|
||||||
|
def test_high_priority_within_non_overdue():
|
||||||
|
tasks = [
|
||||||
|
TaskSignal(id="lo", content="Low prio", priority=1, is_overdue=False),
|
||||||
|
TaskSignal(id="hi", content="High prio", priority=4, is_overdue=False),
|
||||||
|
]
|
||||||
|
ctx = build_context(tasks)
|
||||||
|
assert ctx.tasks[0]["id"] == "hi"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extra_fields_passed_through():
|
||||||
|
ctx = build_context([], extra={"mood": "focused"})
|
||||||
|
assert ctx.extra["mood"] == "focused"
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_age_rounded():
|
||||||
|
tasks = [TaskSignal(id="x", content="Task", task_age_days=1.23456)]
|
||||||
|
ctx = build_context(tasks)
|
||||||
|
assert ctx.tasks[0]["task_age_days"] == 1.2
|
||||||
|
|
||||||
|
|
||||||
|
def test_overdue_sorted_by_priority():
|
||||||
|
tasks = [
|
||||||
|
TaskSignal(id="lo", content="Low", priority=1, is_overdue=True),
|
||||||
|
TaskSignal(id="hi", content="High", priority=4, is_overdue=True),
|
||||||
|
]
|
||||||
|
ctx = build_context(tasks)
|
||||||
|
assert ctx.tasks[0]["id"] == "hi"
|
||||||
|
|
||||||
|
|
||||||
|
def test_overdue_same_priority_sorted_by_age():
|
||||||
|
tasks = [
|
||||||
|
TaskSignal(id="new", content="New", priority=2, is_overdue=True, task_age_days=1.0),
|
||||||
|
TaskSignal(id="old", content="Old", priority=2, is_overdue=True, task_age_days=5.0),
|
||||||
|
]
|
||||||
|
ctx = build_context(tasks)
|
||||||
|
assert ctx.tasks[0]["id"] == "old"
|
||||||
|
|
||||||
|
|
||||||
|
def test_due_date_none_preserved():
|
||||||
|
tasks = [TaskSignal(id="x", content="No due", due_date=None)]
|
||||||
|
ctx = build_context(tasks)
|
||||||
|
assert ctx.tasks[0]["due_date"] is None
|
||||||
@@ -26,12 +26,16 @@ from collections import deque
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Deque
|
from typing import Optional, Deque
|
||||||
|
|
||||||
|
import httpx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
app = FastAPI(title="oO ML Serving", version="1.0.0")
|
app = FastAPI(title="oO ML Serving", version="1.0.0")
|
||||||
|
|
||||||
|
LITELLM_URL = os.getenv("LITELLM_URL", "http://localhost:4000")
|
||||||
|
LITELLM_MASTER_KEY = os.getenv("LITELLM_MASTER_KEY", "sk-oo-dev")
|
||||||
|
|
||||||
STATE_DIR = Path(os.getenv("STATE_DIR", "/tmp/oo-bandit-state"))
|
STATE_DIR = Path(os.getenv("STATE_DIR", "/tmp/oo-bandit-state"))
|
||||||
STATE_DIR.mkdir(parents=True, exist_ok=True)
|
STATE_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
@@ -166,6 +170,56 @@ class RewardResponse(BaseModel):
|
|||||||
ok: bool
|
ok: bool
|
||||||
|
|
||||||
|
|
||||||
|
class PromptContext(BaseModel):
|
||||||
|
tasks: list[dict] = []
|
||||||
|
hour_of_day: int = 12
|
||||||
|
day_of_week: int = 0
|
||||||
|
extra: dict = {}
|
||||||
|
|
||||||
|
|
||||||
|
class GenerateRequest(BaseModel):
|
||||||
|
user_id: str
|
||||||
|
context: PromptContext = PromptContext()
|
||||||
|
n: int = 3
|
||||||
|
|
||||||
|
|
||||||
|
class TipCandidate(BaseModel):
|
||||||
|
id: str
|
||||||
|
content: str
|
||||||
|
source: str = "llm"
|
||||||
|
rationale: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class GenerateResponse(BaseModel):
|
||||||
|
candidates: list[TipCandidate]
|
||||||
|
model: str
|
||||||
|
prompt_tokens: int = 0
|
||||||
|
completion_tokens: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
_GENERATE_SYSTEM = (
|
||||||
|
"You are a personal productivity coach. "
|
||||||
|
"Given the user's current context, generate actionable, specific tips. "
|
||||||
|
"Respond ONLY with a JSON array of objects, each with keys: "
|
||||||
|
'"id" (short slug), "content" (the tip, ≤2 sentences), "rationale" (why now, ≤1 sentence). '
|
||||||
|
"No markdown, no prose outside the JSON array."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_prompt(ctx: PromptContext, n: int) -> str:
|
||||||
|
lines = [f"Time: {ctx.hour_of_day:02d}:00, day_of_week={ctx.day_of_week}"]
|
||||||
|
if ctx.tasks:
|
||||||
|
overdue = [t for t in ctx.tasks if t.get("is_overdue")]
|
||||||
|
lines.append(f"Tasks: {len(ctx.tasks)} total, {len(overdue)} overdue")
|
||||||
|
for t in ctx.tasks[:5]:
|
||||||
|
due = t.get("due_date", "no due date")
|
||||||
|
lines.append(f" - [{t.get('priority','?')}] {t.get('content','?')} (due: {due})")
|
||||||
|
for k, v in ctx.extra.items():
|
||||||
|
lines.append(f"{k}: {v}")
|
||||||
|
lines.append(f"\nGenerate {n} tips as a JSON array.")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
# ── Endpoints ──────────────────────────────────────────────────────────────
|
# ── Endpoints ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
@@ -173,6 +227,97 @@ def health():
|
|||||||
return {"ok": True}
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
_RETRY_SUFFIX = (
|
||||||
|
"\n\nYour previous response was not valid JSON. "
|
||||||
|
"Reply ONLY with the JSON array — no prose, no markdown fences."
|
||||||
|
)
|
||||||
|
|
||||||
|
_MAX_GENERATE_RETRIES = 2
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_llm_json(raw: str) -> list[dict]:
|
||||||
|
"""Strip markdown fences and parse JSON array. Raises ValueError on failure."""
|
||||||
|
text = raw.strip()
|
||||||
|
if text.startswith("```"):
|
||||||
|
parts = text.split("```")
|
||||||
|
text = parts[1] if len(parts) > 1 else text
|
||||||
|
if text.startswith("json"):
|
||||||
|
text = text[4:]
|
||||||
|
return json.loads(text)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/generate", response_model=GenerateResponse)
|
||||||
|
async def generate(req: GenerateRequest) -> GenerateResponse:
|
||||||
|
"""Generate tip candidates via LiteLLM → tip-generator alias.
|
||||||
|
|
||||||
|
Retries up to _MAX_GENERATE_RETRIES times on malformed JSON, appending
|
||||||
|
a correction hint to the conversation so the model can self-correct.
|
||||||
|
"""
|
||||||
|
prompt = _build_prompt(req.context, req.n)
|
||||||
|
messages: list[dict] = [
|
||||||
|
{"role": "system", "content": _GENERATE_SYSTEM},
|
||||||
|
{"role": "user", "content": prompt},
|
||||||
|
]
|
||||||
|
headers = {"Authorization": f"Bearer {LITELLM_MASTER_KEY}"}
|
||||||
|
last_parse_error: str = ""
|
||||||
|
last_raw: str = ""
|
||||||
|
total_usage: dict = {"prompt_tokens": 0, "completion_tokens": 0}
|
||||||
|
model_used = "tip-generator"
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
for attempt in range(1 + _MAX_GENERATE_RETRIES):
|
||||||
|
payload = {"model": "tip-generator", "messages": messages, "temperature": 0.7}
|
||||||
|
try:
|
||||||
|
resp = await client.post(
|
||||||
|
f"{LITELLM_URL}/chat/completions",
|
||||||
|
json=payload,
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
raise HTTPException(status_code=502, detail=f"LiteLLM error: {e.response.text}")
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
raise HTTPException(status_code=503, detail=f"LiteLLM unreachable: {e}")
|
||||||
|
|
||||||
|
data = resp.json()
|
||||||
|
usage = data.get("usage", {})
|
||||||
|
total_usage["prompt_tokens"] += usage.get("prompt_tokens", 0)
|
||||||
|
total_usage["completion_tokens"] += usage.get("completion_tokens", 0)
|
||||||
|
model_used = data.get("model", "tip-generator")
|
||||||
|
|
||||||
|
last_raw = data["choices"][0]["message"]["content"]
|
||||||
|
try:
|
||||||
|
items = _parse_llm_json(last_raw)
|
||||||
|
break
|
||||||
|
except (json.JSONDecodeError, ValueError) as e:
|
||||||
|
last_parse_error = str(e)
|
||||||
|
# Feed the bad reply back so the model can self-correct
|
||||||
|
messages.append({"role": "assistant", "content": last_raw})
|
||||||
|
messages.append({"role": "user", "content": _RETRY_SUFFIX})
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=502,
|
||||||
|
detail=f"LLM returned invalid JSON after {_MAX_GENERATE_RETRIES} retries: "
|
||||||
|
f"{last_parse_error}\n{last_raw[:200]}",
|
||||||
|
)
|
||||||
|
|
||||||
|
candidates = [
|
||||||
|
TipCandidate(
|
||||||
|
id=item.get("id", f"tip-{i}"),
|
||||||
|
content=item.get("content", ""),
|
||||||
|
rationale=item.get("rationale"),
|
||||||
|
)
|
||||||
|
for i, item in enumerate(items)
|
||||||
|
]
|
||||||
|
|
||||||
|
return GenerateResponse(
|
||||||
|
candidates=candidates,
|
||||||
|
model=model_used,
|
||||||
|
prompt_tokens=total_usage["prompt_tokens"],
|
||||||
|
completion_tokens=total_usage["completion_tokens"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/score", response_model=ScoreResponse)
|
@app.post("/score", response_model=ScoreResponse)
|
||||||
def score(req: ScoreRequest) -> ScoreResponse:
|
def score(req: ScoreRequest) -> ScoreResponse:
|
||||||
if not req.candidates:
|
if not req.candidates:
|
||||||
|
|||||||
225
ml/serving/tests/test_generate.py
Normal file
225
ml/serving/tests/test_generate.py
Normal file
@@ -0,0 +1,225 @@
|
|||||||
|
"""
|
||||||
|
Tests for POST /generate — LiteLLM gateway.
|
||||||
|
LiteLLM is mocked; no real network calls.
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import pytest
|
||||||
|
import httpx
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
from httpx import AsyncClient, ASGITransport, Response
|
||||||
|
|
||||||
|
from main import app, _build_prompt, PromptContext
|
||||||
|
|
||||||
|
|
||||||
|
def _litellm_response(candidates: list[dict]) -> Response:
|
||||||
|
import httpx
|
||||||
|
body = {
|
||||||
|
"model": "tip-generator",
|
||||||
|
"choices": [{"message": {"content": json.dumps(candidates)}}],
|
||||||
|
"usage": {"prompt_tokens": 10, "completion_tokens": 20},
|
||||||
|
}
|
||||||
|
req = httpx.Request("POST", "http://litellm/chat/completions")
|
||||||
|
return Response(200, json=body, request=req)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_generate_returns_candidates():
|
||||||
|
fake_items = [
|
||||||
|
{"id": "tip-1", "content": "Do the overdue task now.", "rationale": "It's been waiting."},
|
||||||
|
{"id": "tip-2", "content": "Take a 5-minute break.", "rationale": "You've been working long."},
|
||||||
|
]
|
||||||
|
mock_resp = _litellm_response(fake_items)
|
||||||
|
|
||||||
|
with patch("main.httpx.AsyncClient") as MockClient:
|
||||||
|
instance = AsyncMock()
|
||||||
|
instance.post = AsyncMock(return_value=mock_resp)
|
||||||
|
instance.__aenter__ = AsyncMock(return_value=instance)
|
||||||
|
instance.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
MockClient.return_value = instance
|
||||||
|
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||||
|
resp = await client.post("/generate", json={"user_id": "u1", "n": 2})
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert len(data["candidates"]) == 2
|
||||||
|
assert data["candidates"][0]["id"] == "tip-1"
|
||||||
|
assert data["model"] == "tip-generator"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_generate_strips_markdown_fence():
|
||||||
|
fake_items = [{"id": "tip-a", "content": "Focus.", "rationale": "Now."}]
|
||||||
|
fenced = "```json\n" + json.dumps(fake_items) + "\n```"
|
||||||
|
body = {
|
||||||
|
"model": "tip-generator",
|
||||||
|
"choices": [{"message": {"content": fenced}}],
|
||||||
|
"usage": {},
|
||||||
|
}
|
||||||
|
req = httpx.Request("POST", "http://litellm/chat/completions")
|
||||||
|
mock_resp = Response(200, json=body, request=req)
|
||||||
|
|
||||||
|
with patch("main.httpx.AsyncClient") as MockClient:
|
||||||
|
instance = AsyncMock()
|
||||||
|
instance.post = AsyncMock(return_value=mock_resp)
|
||||||
|
instance.__aenter__ = AsyncMock(return_value=instance)
|
||||||
|
instance.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
MockClient.return_value = instance
|
||||||
|
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||||
|
resp = await client.post("/generate", json={"user_id": "u1"})
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["candidates"][0]["id"] == "tip-a"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_generate_503_on_unreachable():
|
||||||
|
import httpx as _httpx
|
||||||
|
|
||||||
|
with patch("main.httpx.AsyncClient") as MockClient:
|
||||||
|
instance = AsyncMock()
|
||||||
|
instance.post = AsyncMock(side_effect=_httpx.ConnectError("refused"))
|
||||||
|
instance.__aenter__ = AsyncMock(return_value=instance)
|
||||||
|
instance.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
MockClient.return_value = instance
|
||||||
|
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||||
|
resp = await client.post("/generate", json={"user_id": "u1"})
|
||||||
|
|
||||||
|
assert resp.status_code == 503
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_prompt_includes_tasks():
|
||||||
|
ctx = PromptContext(
|
||||||
|
tasks=[{"content": "Write report", "priority": 4, "is_overdue": True, "due_date": "2026-04-15"}],
|
||||||
|
hour_of_day=9,
|
||||||
|
day_of_week=2,
|
||||||
|
)
|
||||||
|
prompt = _build_prompt(ctx, n=3)
|
||||||
|
assert "Write report" in prompt
|
||||||
|
assert "09:00" in prompt
|
||||||
|
assert "Generate 3 tips" in prompt
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_prompt_truncates_at_five():
|
||||||
|
tasks = [{"content": f"Task {i}", "priority": 1, "is_overdue": False, "due_date": None} for i in range(8)]
|
||||||
|
ctx = PromptContext(tasks=tasks, hour_of_day=12)
|
||||||
|
prompt = _build_prompt(ctx, n=2)
|
||||||
|
assert "Task 4" in prompt
|
||||||
|
assert "Task 5" not in prompt
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_prompt_extra_fields():
|
||||||
|
ctx = PromptContext(tasks=[], hour_of_day=8, extra={"mood": "focused", "energy": "high"})
|
||||||
|
prompt = _build_prompt(ctx, n=1)
|
||||||
|
assert "mood: focused" in prompt
|
||||||
|
assert "energy: high" in prompt
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_prompt_empty_tasks_no_task_line():
|
||||||
|
ctx = PromptContext(tasks=[], hour_of_day=10)
|
||||||
|
prompt = _build_prompt(ctx, n=2)
|
||||||
|
assert "Tasks:" not in prompt
|
||||||
|
assert "Generate 2 tips" in prompt
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_generate_retry_succeeds_on_second_attempt():
|
||||||
|
"""First response is invalid JSON; second is valid. Should return 200."""
|
||||||
|
valid_items = [{"id": "tip-ok", "content": "Retry worked.", "rationale": "Second try."}]
|
||||||
|
bad_req = httpx.Request("POST", "http://litellm/chat/completions")
|
||||||
|
bad_resp = Response(200, json={
|
||||||
|
"model": "tip-generator",
|
||||||
|
"choices": [{"message": {"content": "this is not json"}}],
|
||||||
|
"usage": {},
|
||||||
|
}, request=bad_req)
|
||||||
|
good_resp = _litellm_response(valid_items)
|
||||||
|
|
||||||
|
with patch("main.httpx.AsyncClient") as MockClient:
|
||||||
|
instance = AsyncMock()
|
||||||
|
instance.post = AsyncMock(side_effect=[bad_resp, good_resp])
|
||||||
|
instance.__aenter__ = AsyncMock(return_value=instance)
|
||||||
|
instance.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
MockClient.return_value = instance
|
||||||
|
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||||
|
resp = await client.post("/generate", json={"user_id": "u1", "n": 1})
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["candidates"][0]["id"] == "tip-ok"
|
||||||
|
assert instance.post.call_count == 2
|
||||||
|
# Retry message should include the correction suffix
|
||||||
|
second_call_messages = instance.post.call_args_list[1][1]["json"]["messages"]
|
||||||
|
assert any("not valid JSON" in m["content"] for m in second_call_messages)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_generate_502_after_all_retries_exhausted():
|
||||||
|
"""All attempts return invalid JSON → 502."""
|
||||||
|
bad_req = httpx.Request("POST", "http://litellm/chat/completions")
|
||||||
|
|
||||||
|
def _bad_resp():
|
||||||
|
return Response(200, json={
|
||||||
|
"model": "tip-generator",
|
||||||
|
"choices": [{"message": {"content": "not json at all"}}],
|
||||||
|
"usage": {},
|
||||||
|
}, request=bad_req)
|
||||||
|
|
||||||
|
from main import _MAX_GENERATE_RETRIES
|
||||||
|
responses = [_bad_resp() for _ in range(1 + _MAX_GENERATE_RETRIES)]
|
||||||
|
|
||||||
|
with patch("main.httpx.AsyncClient") as MockClient:
|
||||||
|
instance = AsyncMock()
|
||||||
|
instance.post = AsyncMock(side_effect=responses)
|
||||||
|
instance.__aenter__ = AsyncMock(return_value=instance)
|
||||||
|
instance.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
MockClient.return_value = instance
|
||||||
|
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||||
|
resp = await client.post("/generate", json={"user_id": "u1"})
|
||||||
|
|
||||||
|
assert resp.status_code == 502
|
||||||
|
assert "retries" in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_generate_502_on_upstream_http_error():
|
||||||
|
"""LiteLLM returns 500 → HTTPStatusError → 502."""
|
||||||
|
err_req = httpx.Request("POST", "http://litellm/chat/completions")
|
||||||
|
err_resp = Response(500, text="internal error", request=err_req)
|
||||||
|
|
||||||
|
with patch("main.httpx.AsyncClient") as MockClient:
|
||||||
|
instance = AsyncMock()
|
||||||
|
instance.post = AsyncMock(side_effect=httpx.HTTPStatusError(
|
||||||
|
"500", request=err_req, response=err_resp
|
||||||
|
))
|
||||||
|
instance.__aenter__ = AsyncMock(return_value=instance)
|
||||||
|
instance.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
MockClient.return_value = instance
|
||||||
|
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||||
|
resp = await client.post("/generate", json={"user_id": "u1"})
|
||||||
|
|
||||||
|
assert resp.status_code == 502
|
||||||
|
assert "LiteLLM error" in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_llm_json_bare_fence():
|
||||||
|
from main import _parse_llm_json
|
||||||
|
raw = "```\n[{\"id\":\"x\",\"content\":\"hi\"}]\n```"
|
||||||
|
items = _parse_llm_json(raw)
|
||||||
|
assert items[0]["id"] == "x"
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_llm_json_no_fence():
|
||||||
|
from main import _parse_llm_json
|
||||||
|
raw = '[{"id":"plain","content":"no fence"}]'
|
||||||
|
items = _parse_llm_json(raw)
|
||||||
|
assert items[0]["id"] == "plain"
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_llm_json_raises_on_invalid():
|
||||||
|
from main import _parse_llm_json
|
||||||
|
with pytest.raises((ValueError, Exception)):
|
||||||
|
_parse_llm_json("this is not json")
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
import { describe, it, expect } from 'vitest';
|
import { describe, it, expect } from 'vitest';
|
||||||
import type { Tip, TipFeedback, RecommendResponse } from '../index.js';
|
import type { Tip, TipFeedback, TipCandidate, RecommendResponse } from '../index.js';
|
||||||
|
|
||||||
describe('Tip type contract', () => {
|
describe('Tip type contract', () => {
|
||||||
it('accepts a valid Tip object', () => {
|
it('accepts a valid Tip object', () => {
|
||||||
@@ -7,6 +7,7 @@ describe('Tip type contract', () => {
|
|||||||
id: 'todoist:123',
|
id: 'todoist:123',
|
||||||
content: 'Finish the report',
|
content: 'Finish the report',
|
||||||
source: 'todoist',
|
source: 'todoist',
|
||||||
|
kind: 'task',
|
||||||
sourceId: '123',
|
sourceId: '123',
|
||||||
createdAt: new Date().toISOString(),
|
createdAt: new Date().toISOString(),
|
||||||
};
|
};
|
||||||
@@ -18,6 +19,7 @@ describe('Tip type contract', () => {
|
|||||||
id: 'advice:abc',
|
id: 'advice:abc',
|
||||||
content: 'Take a break',
|
content: 'Take a break',
|
||||||
source: 'advice',
|
source: 'advice',
|
||||||
|
kind: 'advice',
|
||||||
createdAt: new Date().toISOString(),
|
createdAt: new Date().toISOString(),
|
||||||
};
|
};
|
||||||
expect(tip.sourceId).toBeUndefined();
|
expect(tip.sourceId).toBeUndefined();
|
||||||
@@ -25,16 +27,45 @@ describe('Tip type contract', () => {
|
|||||||
|
|
||||||
it('RecommendResponse wraps a Tip', () => {
|
it('RecommendResponse wraps a Tip', () => {
|
||||||
const res: RecommendResponse = {
|
const res: RecommendResponse = {
|
||||||
tip: { id: 'x', content: 'Do it', source: 'todoist', createdAt: '' },
|
tip: { id: 'x', content: 'Do it', source: 'todoist', kind: 'task', createdAt: '' },
|
||||||
};
|
};
|
||||||
expect(res.tip.id).toBe('x');
|
expect(res.tip.id).toBe('x');
|
||||||
});
|
});
|
||||||
|
|
||||||
it('TipFeedback allows valid actions', () => {
|
it('TipFeedback allows all valid actions including helpful/not_helpful', () => {
|
||||||
const actions: TipFeedback['action'][] = ['done', 'dismiss', 'snooze'];
|
const actions: TipFeedback['action'][] = ['done', 'dismiss', 'snooze', 'helpful', 'not_helpful'];
|
||||||
for (const action of actions) {
|
for (const action of actions) {
|
||||||
const fb: TipFeedback = { action };
|
const fb: TipFeedback = { action };
|
||||||
expect(fb.action).toBe(action);
|
expect(fb.action).toBe(action);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('Tip accepts optional rationale', () => {
|
||||||
|
const tip: Tip = {
|
||||||
|
id: 'llm:tip-1',
|
||||||
|
content: 'Block 30 min for deep work.',
|
||||||
|
source: 'llm',
|
||||||
|
kind: 'advice',
|
||||||
|
rationale: 'Your calendar is clear until noon.',
|
||||||
|
createdAt: new Date().toISOString(),
|
||||||
|
};
|
||||||
|
expect(tip.rationale).toBeDefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('Tip rationale is optional', () => {
|
||||||
|
const tip: Tip = { id: 'x', content: 'Do it', source: 'todoist', kind: 'task', createdAt: '' };
|
||||||
|
expect(tip.rationale).toBeUndefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('TipCandidate includes features', () => {
|
||||||
|
const c: TipCandidate = {
|
||||||
|
id: 'todoist:1',
|
||||||
|
content: 'Finish report',
|
||||||
|
source: 'todoist',
|
||||||
|
kind: 'task',
|
||||||
|
createdAt: '',
|
||||||
|
features: { is_overdue: true, task_age_days: 2, priority: 4 },
|
||||||
|
};
|
||||||
|
expect(c.features.is_overdue).toBe(true);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -1,10 +1,30 @@
|
|||||||
|
/** Category of a tip — drives icon, CTA copy, and reward inference */
|
||||||
|
export type TipKind = 'task' | 'advice' | 'insight' | 'reminder';
|
||||||
|
|
||||||
|
/** Where the tip content originated */
|
||||||
|
export type TipSource = 'todoist' | 'llm' | 'advice';
|
||||||
|
|
||||||
/** A single recommendation surfaced to the user */
|
/** A single recommendation surfaced to the user */
|
||||||
export interface Tip {
|
export interface Tip {
|
||||||
id: string;
|
id: string;
|
||||||
content: string;
|
content: string;
|
||||||
source: 'todoist' | 'advice';
|
source: TipSource;
|
||||||
|
kind: TipKind;
|
||||||
sourceId?: string;
|
sourceId?: string;
|
||||||
createdAt: string; // ISO 8601
|
rationale?: string; // LLM-generated "why now" shown on long-press
|
||||||
|
createdAt: string; // ISO 8601
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A scored tip candidate flowing through the bandit pipeline.
|
||||||
|
* Extends Tip with features needed for scoring.
|
||||||
|
*/
|
||||||
|
export interface TipCandidate extends Tip {
|
||||||
|
features: {
|
||||||
|
is_overdue: boolean;
|
||||||
|
task_age_days: number;
|
||||||
|
priority: number;
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
/** POST /recommend response */
|
/** POST /recommend response */
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ export const config = {
|
|||||||
WEB_BASE_URL: optional('WEB_BASE_URL', 'http://localhost:3000'),
|
WEB_BASE_URL: optional('WEB_BASE_URL', 'http://localhost:3000'),
|
||||||
|
|
||||||
ML_SERVING_URL: optional('ML_SERVING_URL', 'http://localhost:8000'),
|
ML_SERVING_URL: optional('ML_SERVING_URL', 'http://localhost:8000'),
|
||||||
|
LITELLM_URL: optional('LITELLM_URL', 'http://localhost:4000'),
|
||||||
|
|
||||||
VAPID_PUBLIC_KEY: optional('VAPID_PUBLIC_KEY', ''),
|
VAPID_PUBLIC_KEY: optional('VAPID_PUBLIC_KEY', ''),
|
||||||
VAPID_PRIVATE_KEY: optional('VAPID_PRIVATE_KEY', ''),
|
VAPID_PRIVATE_KEY: optional('VAPID_PRIVATE_KEY', ''),
|
||||||
|
|||||||
@@ -142,6 +142,10 @@ export function runMigrations() {
|
|||||||
`ALTER TABLE push_subscriptions ADD COLUMN created_at TEXT NOT NULL DEFAULT ''`,
|
`ALTER TABLE push_subscriptions ADD COLUMN created_at TEXT NOT NULL DEFAULT ''`,
|
||||||
`ALTER TABLE tip_feedback ADD COLUMN dwell_ms INTEGER`,
|
`ALTER TABLE tip_feedback ADD COLUMN dwell_ms INTEGER`,
|
||||||
`ALTER TABLE tip_feedback ADD COLUMN reward_milli INTEGER`,
|
`ALTER TABLE tip_feedback ADD COLUMN reward_milli INTEGER`,
|
||||||
|
`ALTER TABLE integration_tokens ADD COLUMN token_status TEXT NOT NULL DEFAULT 'active'`,
|
||||||
|
`ALTER TABLE tip_scores ADD COLUMN prompt_version TEXT`,
|
||||||
|
`ALTER TABLE tip_scores ADD COLUMN llm_model TEXT`,
|
||||||
|
`ALTER TABLE tip_scores ADD COLUMN tip_kind TEXT`,
|
||||||
]) {
|
]) {
|
||||||
try { sqlite.exec(stmt); } catch { /* column already exists */ }
|
try { sqlite.exec(stmt); } catch { /* column already exists */ }
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ export const integrationTokens = sqliteTable('integration_tokens', {
|
|||||||
accessToken: text('access_token').notNull(),
|
accessToken: text('access_token').notNull(),
|
||||||
refreshToken: text('refresh_token'),
|
refreshToken: text('refresh_token'),
|
||||||
expiresAt: text('expires_at'),
|
expiresAt: text('expires_at'),
|
||||||
|
tokenStatus: text('token_status').notNull().default('active'), // 'active' | 'needs_reconnect'
|
||||||
connectedAt: text('connected_at').notNull(),
|
connectedAt: text('connected_at').notNull(),
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -81,6 +82,9 @@ export const tipScores = sqliteTable('tip_scores', {
|
|||||||
candidateCount: integer('candidate_count'),
|
candidateCount: integer('candidate_count'),
|
||||||
latencyMs: integer('latency_ms'),
|
latencyMs: integer('latency_ms'),
|
||||||
servedAt: text('served_at').notNull(),
|
servedAt: text('served_at').notNull(),
|
||||||
|
promptVersion: text('prompt_version'), // e.g. 'v1' — tracks which prompt template generated this tip
|
||||||
|
llmModel: text('llm_model'), // e.g. 'tip-generator/qwen2.5:7b' — null for bandit-only tips
|
||||||
|
tipKind: text('tip_kind'), // 'task' | 'advice' | 'insight' | 'reminder'
|
||||||
});
|
});
|
||||||
|
|
||||||
// ── Simulation runs ──────────────────────────────────────────────────────────
|
// ── Simulation runs ──────────────────────────────────────────────────────────
|
||||||
|
|||||||
190
services/api/src/routes/__tests__/recommender.test.ts
Normal file
190
services/api/src/routes/__tests__/recommender.test.ts
Normal file
@@ -0,0 +1,190 @@
|
|||||||
|
/**
|
||||||
|
* Integration tests for POST /recommend and tip_scores DB writes.
|
||||||
|
* Uses a real in-memory SQLite DB. recommender is imported dynamically
|
||||||
|
* inside beforeAll (same pattern as admin.test.ts) to avoid TDZ issues.
|
||||||
|
* Uses http.request (not fetch) as the test client so that globalThis.fetch
|
||||||
|
* mocking doesn't interfere with the test runner itself.
|
||||||
|
*/
|
||||||
|
import { describe, it, expect, vi, beforeAll, afterEach } from 'vitest';
|
||||||
|
import express from 'express';
|
||||||
|
import * as http from 'http';
|
||||||
|
import { makeTestDb } from '../../test/db.js';
|
||||||
|
import { users, integrationTokens, tipScores } from '../../db/schema.js';
|
||||||
|
|
||||||
|
const testDb = makeTestDb();
|
||||||
|
|
||||||
|
vi.mock('../../db/index.js', () => ({ db: testDb }));
|
||||||
|
vi.mock('../../middleware/session.js', () => ({
|
||||||
|
sessionMiddleware: (_req: express.Request, _res: express.Response, next: express.NextFunction) => next(),
|
||||||
|
requireAuth: (req: express.Request, _res: express.Response, next: express.NextFunction) => {
|
||||||
|
(req as any).userId = 'user-1';
|
||||||
|
next();
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
vi.mock('../../events/bus.js', () => ({ bus: { publish: vi.fn() } }));
|
||||||
|
|
||||||
|
/** Minimal http.request wrapper → { status, body } */
|
||||||
|
function post(url: string): Promise<{ status: number; body: any }> {
|
||||||
|
return new Promise((resolve, reject) => {
|
||||||
|
const u = new URL(url);
|
||||||
|
const req = http.request(
|
||||||
|
{ hostname: u.hostname, port: Number(u.port), path: u.pathname, method: 'POST',
|
||||||
|
headers: { 'Content-Type': 'application/json' } },
|
||||||
|
(res) => {
|
||||||
|
let data = '';
|
||||||
|
res.on('data', (c) => { data += c; });
|
||||||
|
res.on('end', () => {
|
||||||
|
try { resolve({ status: res.statusCode ?? 0, body: data ? JSON.parse(data) : null }); }
|
||||||
|
catch { resolve({ status: res.statusCode ?? 0, body: data }); }
|
||||||
|
});
|
||||||
|
},
|
||||||
|
);
|
||||||
|
req.on('error', reject);
|
||||||
|
req.end();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('POST /recommend integration', () => {
|
||||||
|
let server: http.Server;
|
||||||
|
let baseUrl: string;
|
||||||
|
let savedFetch: typeof globalThis.fetch;
|
||||||
|
let clearCache: () => void;
|
||||||
|
|
||||||
|
beforeAll(async () => {
|
||||||
|
await testDb.insert(users).values({
|
||||||
|
id: 'user-1', email: 'u@test.com', role: 'user',
|
||||||
|
consentGiven: 1, createdAt: new Date().toISOString(),
|
||||||
|
});
|
||||||
|
await testDb.insert(integrationTokens).values({
|
||||||
|
id: 'tok-1', userId: 'user-1', provider: 'todoist',
|
||||||
|
accessToken: 'fake-token', connectedAt: new Date().toISOString(),
|
||||||
|
});
|
||||||
|
|
||||||
|
const mod = await import('../recommender.js');
|
||||||
|
const { recommenderRouter } = mod;
|
||||||
|
clearCache = (mod as any)._clearTaskCacheForTests;
|
||||||
|
const app = express();
|
||||||
|
app.use(express.json());
|
||||||
|
app.use('/api', recommenderRouter);
|
||||||
|
server = app.listen(0);
|
||||||
|
const addr = server.address() as { port: number };
|
||||||
|
baseUrl = `http://localhost:${addr.port}`;
|
||||||
|
savedFetch = globalThis.fetch;
|
||||||
|
});
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
globalThis.fetch = savedFetch;
|
||||||
|
clearCache?.();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('returns 204 when Todoist + LLM both return empty', async () => {
|
||||||
|
globalThis.fetch = vi.fn().mockResolvedValue({
|
||||||
|
ok: true, status: 200,
|
||||||
|
json: async () => ({ results: [] }),
|
||||||
|
} as any);
|
||||||
|
const { status } = await post(`${baseUrl}/api/recommend`);
|
||||||
|
expect(status).toBe(204);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('serves todoist tip and writes correct tip_scores columns', async () => {
|
||||||
|
globalThis.fetch = vi.fn().mockImplementation((url: string) => {
|
||||||
|
if (String(url).includes('todoist.com')) {
|
||||||
|
return Promise.resolve({
|
||||||
|
ok: true, status: 200,
|
||||||
|
json: async () => ({
|
||||||
|
results: [{ id: 'task-1', content: 'Write tests', priority: 3, due: { date: '2026-04-10' } }],
|
||||||
|
}),
|
||||||
|
} as any);
|
||||||
|
}
|
||||||
|
if (String(url).includes('/generate')) {
|
||||||
|
return Promise.resolve({ ok: false, status: 503, json: async () => ({}) } as any);
|
||||||
|
}
|
||||||
|
if (String(url).includes('/score')) {
|
||||||
|
return Promise.resolve({
|
||||||
|
ok: true, status: 200,
|
||||||
|
json: async () => ({ tip_id: 'todoist:task-1', score: 0.8 }),
|
||||||
|
} as any);
|
||||||
|
}
|
||||||
|
return Promise.resolve({ ok: false, status: 500, json: async () => ({}) } as any);
|
||||||
|
});
|
||||||
|
|
||||||
|
const { status, body } = await post(`${baseUrl}/api/recommend`);
|
||||||
|
expect(status).toBe(200);
|
||||||
|
expect(body.tip.source).toBe('todoist');
|
||||||
|
expect(body.tip.kind).toBe('task');
|
||||||
|
|
||||||
|
const rows = await testDb.select().from(tipScores);
|
||||||
|
const row = rows[rows.length - 1];
|
||||||
|
expect(row.tipKind).toBe('task');
|
||||||
|
expect(row.promptVersion).toBeNull();
|
||||||
|
expect(row.llmModel).toBeNull();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('writes prompt_version + llm_model when LLM tip is served', async () => {
|
||||||
|
globalThis.fetch = vi.fn().mockImplementation((url: string) => {
|
||||||
|
if (String(url).includes('todoist.com')) {
|
||||||
|
return Promise.resolve({
|
||||||
|
ok: true, status: 200,
|
||||||
|
json: async () => ({ results: [] }),
|
||||||
|
} as any);
|
||||||
|
}
|
||||||
|
if (String(url).includes('/generate')) {
|
||||||
|
return Promise.resolve({
|
||||||
|
ok: true, status: 200,
|
||||||
|
json: async () => ({
|
||||||
|
candidates: [{ id: 'adv-1', content: 'Take a break.', rationale: 'You deserve it.' }],
|
||||||
|
model: 'tip-generator',
|
||||||
|
}),
|
||||||
|
} as any);
|
||||||
|
}
|
||||||
|
if (String(url).includes('/score')) {
|
||||||
|
return Promise.resolve({
|
||||||
|
ok: true, status: 200,
|
||||||
|
json: async () => ({ tip_id: 'llm:adv-1', score: 0.9 }),
|
||||||
|
} as any);
|
||||||
|
}
|
||||||
|
return Promise.resolve({ ok: false, status: 500, json: async () => ({}) } as any);
|
||||||
|
});
|
||||||
|
|
||||||
|
const { status, body } = await post(`${baseUrl}/api/recommend`);
|
||||||
|
expect(status).toBe(200);
|
||||||
|
expect(body.tip.source).toBe('llm');
|
||||||
|
expect(body.tip.kind).toBe('advice');
|
||||||
|
expect(body.tip.rationale).toBe('You deserve it.');
|
||||||
|
|
||||||
|
const rows = await testDb.select().from(tipScores);
|
||||||
|
const row = rows[rows.length - 1];
|
||||||
|
expect(row.promptVersion).toBe('v1');
|
||||||
|
expect(row.llmModel).toBe('tip-generator');
|
||||||
|
expect(row.tipKind).toBe('advice');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('falls back to todoist tip when /generate returns non-200', async () => {
|
||||||
|
globalThis.fetch = vi.fn().mockImplementation((url: string) => {
|
||||||
|
if (String(url).includes('todoist.com')) {
|
||||||
|
return Promise.resolve({
|
||||||
|
ok: true, status: 200,
|
||||||
|
json: async () => ({
|
||||||
|
results: [{ id: 'fallback-1', content: 'Do stuff', priority: 2, due: null }],
|
||||||
|
}),
|
||||||
|
} as any);
|
||||||
|
}
|
||||||
|
if (String(url).includes('/generate')) {
|
||||||
|
return Promise.resolve({ ok: false, status: 502, json: async () => ({}) } as any);
|
||||||
|
}
|
||||||
|
if (String(url).includes('/score')) {
|
||||||
|
return Promise.resolve({
|
||||||
|
ok: true, status: 200,
|
||||||
|
json: async () => ({ tip_id: 'todoist:fallback-1', score: 0.5 }),
|
||||||
|
} as any);
|
||||||
|
}
|
||||||
|
return Promise.resolve({ ok: false, status: 500, json: async () => ({}) } as any);
|
||||||
|
});
|
||||||
|
|
||||||
|
const { status, body } = await post(`${baseUrl}/api/recommend`);
|
||||||
|
expect([200, 204]).toContain(status);
|
||||||
|
if (status === 200) {
|
||||||
|
expect(body.tip.source).toBe('todoist');
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
39
services/api/src/routes/__tests__/recommender.unit.test.ts
Normal file
39
services/api/src/routes/__tests__/recommender.unit.test.ts
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
/**
|
||||||
|
* Pure-function unit tests for recommender logic — no DB, no HTTP.
|
||||||
|
* These can import directly from the module without any mocking.
|
||||||
|
*/
|
||||||
|
import { describe, it, expect } from 'vitest';
|
||||||
|
import { inferReward, dueAgeDays } from '../recommender.js';
|
||||||
|
|
||||||
|
describe('inferReward', () => {
|
||||||
|
it('dismiss → -1', () => expect(inferReward('dismiss', null)).toBe(-1.0));
|
||||||
|
it('snooze → +0.1', () => expect(inferReward('snooze', null)).toBe(0.1));
|
||||||
|
it('helpful → +0.5', () => expect(inferReward('helpful', null)).toBe(0.5));
|
||||||
|
it('not_helpful → -0.5', () => expect(inferReward('not_helpful', null)).toBe(-0.5));
|
||||||
|
it('done with null dwell → +0.5', () => expect(inferReward('done', null)).toBe(0.5));
|
||||||
|
it('done < 15s (reflex) → -0.3', () => expect(inferReward('done', 5_000)).toBe(-0.3));
|
||||||
|
it('done 15s–2min (magic) → +1.0', () => expect(inferReward('done', 60_000)).toBe(1.0));
|
||||||
|
it('done 2–10min (good) → +0.6', () => expect(inferReward('done', 300_000)).toBe(0.6));
|
||||||
|
it('done > 10min (eventual) → +0.3', () => expect(inferReward('done', 700_000)).toBe(0.3));
|
||||||
|
it('done exactly 15s (boundary) → magic zone', () => expect(inferReward('done', 15_000)).toBe(1.0));
|
||||||
|
it('done exactly 2min (boundary) → good zone', () => expect(inferReward('done', 120_000)).toBe(0.6));
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('dueAgeDays', () => {
|
||||||
|
it('null due → 0', () => expect(dueAgeDays(null)).toBe(0));
|
||||||
|
it('empty object → 0', () => expect(dueAgeDays({})).toBe(0));
|
||||||
|
it('future date → 0 (clamped)', () => {
|
||||||
|
const future = new Date(Date.now() + 86_400_000).toISOString();
|
||||||
|
expect(dueAgeDays({ datetime: future })).toBe(0);
|
||||||
|
});
|
||||||
|
it('past date → positive age', () => {
|
||||||
|
const twoDaysAgo = new Date(Date.now() - 2 * 86_400_000).toISOString();
|
||||||
|
const age = dueAgeDays({ datetime: twoDaysAgo });
|
||||||
|
expect(age).toBeGreaterThan(1.9);
|
||||||
|
expect(age).toBeLessThan(2.1);
|
||||||
|
});
|
||||||
|
it('date-only field used when datetime absent', () => {
|
||||||
|
const yesterday = new Date(Date.now() - 86_400_000).toISOString().slice(0, 10);
|
||||||
|
expect(dueAgeDays({ date: yesterday })).toBeGreaterThan(0);
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -24,7 +24,7 @@ router.get('/', requireAuth, async (req: AuthenticatedRequest, res: Response) =>
|
|||||||
|
|
||||||
const integrations = tokens.map((t) => ({
|
const integrations = tokens.map((t) => ({
|
||||||
provider: t.provider,
|
provider: t.provider,
|
||||||
status: 'connected',
|
status: t.tokenStatus === 'needs_reconnect' ? 'needs_reconnect' : 'connected',
|
||||||
connectedAt: t.connectedAt,
|
connectedAt: t.connectedAt,
|
||||||
}));
|
}));
|
||||||
|
|
||||||
@@ -97,6 +97,7 @@ router.get('/todoist/callback', async (req: Request, res: Response) => {
|
|||||||
userId: pending.userId,
|
userId: pending.userId,
|
||||||
provider: 'todoist',
|
provider: 'todoist',
|
||||||
accessToken: access_token,
|
accessToken: access_token,
|
||||||
|
tokenStatus: 'active',
|
||||||
connectedAt: now,
|
connectedAt: now,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -6,23 +6,15 @@ import { eq, and, desc } from 'drizzle-orm';
|
|||||||
import { requireAuth, AuthenticatedRequest } from '../middleware/session.js';
|
import { requireAuth, AuthenticatedRequest } from '../middleware/session.js';
|
||||||
import { config } from '../config.js';
|
import { config } from '../config.js';
|
||||||
import { bus } from '../events/bus.js';
|
import { bus } from '../events/bus.js';
|
||||||
import type { Tip } from '@oo/shared-types';
|
import type { TipCandidate } from '@oo/shared-types';
|
||||||
|
|
||||||
const router: ExpressRouter = Router();
|
const router: ExpressRouter = Router();
|
||||||
|
|
||||||
const CACHE_TTL_MS = 30_000;
|
const CACHE_TTL_MS = 30_000;
|
||||||
|
const PROMPT_VERSION = 'v1';
|
||||||
|
|
||||||
interface TaskFeatures {
|
const taskCache = new Map<string, { tasks: TipCandidate[]; fetchedAt: number }>();
|
||||||
is_overdue: boolean;
|
export const _clearTaskCacheForTests = () => taskCache.clear();
|
||||||
task_age_days: number;
|
|
||||||
priority: number;
|
|
||||||
}
|
|
||||||
|
|
||||||
interface CachedTask extends Tip {
|
|
||||||
features: TaskFeatures;
|
|
||||||
}
|
|
||||||
|
|
||||||
const taskCache = new Map<string, { tasks: CachedTask[]; fetchedAt: number }>();
|
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
// Shadow-policy registry
|
// Shadow-policy registry
|
||||||
@@ -49,7 +41,7 @@ export function setPolicyActive(name: string, active: boolean): boolean {
|
|||||||
// Todoist helpers
|
// Todoist helpers
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
function dueAgeDays(due: { date?: string; datetime?: string } | null | undefined): number {
|
export function dueAgeDays(due: { date?: string; datetime?: string } | null | undefined): number {
|
||||||
if (!due) return 0;
|
if (!due) return 0;
|
||||||
const dateStr = due.datetime ?? due.date;
|
const dateStr = due.datetime ?? due.date;
|
||||||
if (!dateStr) return 0;
|
if (!dateStr) return 0;
|
||||||
@@ -57,7 +49,7 @@ function dueAgeDays(due: { date?: string; datetime?: string } | null | undefined
|
|||||||
return Math.max(0, (Date.now() - dueMs) / (1000 * 60 * 60 * 24));
|
return Math.max(0, (Date.now() - dueMs) / (1000 * 60 * 60 * 24));
|
||||||
}
|
}
|
||||||
|
|
||||||
async function fetchTodoistTasks(userId: string, accessToken: string): Promise<CachedTask[]> {
|
async function fetchTodoistTasks(userId: string, accessToken: string): Promise<TipCandidate[]> {
|
||||||
const cached = taskCache.get(userId);
|
const cached = taskCache.get(userId);
|
||||||
if (cached && Date.now() - cached.fetchedAt < CACHE_TTL_MS) return cached.tasks;
|
if (cached && Date.now() - cached.fetchedAt < CACHE_TTL_MS) return cached.tasks;
|
||||||
|
|
||||||
@@ -73,6 +65,10 @@ async function fetchTodoistTasks(userId: string, accessToken: string): Promise<C
|
|||||||
provider: 'todoist',
|
provider: 'todoist',
|
||||||
detectedAt: new Date().toISOString(),
|
detectedAt: new Date().toISOString(),
|
||||||
});
|
});
|
||||||
|
await db
|
||||||
|
.update(integrationTokens)
|
||||||
|
.set({ tokenStatus: 'needs_reconnect' })
|
||||||
|
.where(and(eq(integrationTokens.userId, userId), eq(integrationTokens.provider, 'todoist')));
|
||||||
}
|
}
|
||||||
return cached?.tasks ?? [];
|
return cached?.tasks ?? [];
|
||||||
}
|
}
|
||||||
@@ -87,13 +83,14 @@ async function fetchTodoistTasks(userId: string, accessToken: string): Promise<C
|
|||||||
};
|
};
|
||||||
|
|
||||||
const now = new Date();
|
const now = new Date();
|
||||||
const tasks: CachedTask[] = (body.results ?? []).map((t) => {
|
const tasks: TipCandidate[] = (body.results ?? []).map((t) => {
|
||||||
const ageDays = dueAgeDays(t.due);
|
const ageDays = dueAgeDays(t.due);
|
||||||
const isOverdue = ageDays > 0;
|
const isOverdue = ageDays > 0;
|
||||||
return {
|
return {
|
||||||
id: `todoist:${t.id}`,
|
id: `todoist:${t.id}`,
|
||||||
content: t.content,
|
content: t.content,
|
||||||
source: 'todoist' as const,
|
source: 'todoist' as const,
|
||||||
|
kind: 'task' as const,
|
||||||
sourceId: t.id,
|
sourceId: t.id,
|
||||||
createdAt: now.toISOString(),
|
createdAt: now.toISOString(),
|
||||||
features: {
|
features: {
|
||||||
@@ -111,10 +108,14 @@ async function fetchTodoistTasks(userId: string, accessToken: string): Promise<C
|
|||||||
return tasks;
|
return tasks;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Stage 2: score candidates via ml/serving bandit
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
/** Call ml/serving for scored selection; returns { tip_id, score } or null on failure */
|
/** Call ml/serving for scored selection; returns { tip_id, score } or null on failure */
|
||||||
async function remotePolicy(
|
async function remotePolicy(
|
||||||
userId: string,
|
userId: string,
|
||||||
tasks: CachedTask[],
|
tasks: TipCandidate[],
|
||||||
): Promise<{ tipId: string; score: number; policy: string } | null> {
|
): Promise<{ tipId: string; score: number; policy: string } | null> {
|
||||||
const hour = new Date().getHours();
|
const hour = new Date().getHours();
|
||||||
const dayOfWeek = new Date().getDay();
|
const dayOfWeek = new Date().getDay();
|
||||||
@@ -147,13 +148,64 @@ async function remotePolicy(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function randomPolicy(candidates: CachedTask[]): CachedTask | null {
|
function randomPolicy(candidates: TipCandidate[]): TipCandidate | null {
|
||||||
if (!candidates.length) return null;
|
if (!candidates.length) return null;
|
||||||
return candidates[Math.floor(Math.random() * candidates.length)];
|
return candidates[Math.floor(Math.random() * candidates.length)];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Stage 1b: fetch LLM candidates from ml/serving /generate
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
interface LlmCandidate {
|
||||||
|
id: string;
|
||||||
|
content: string;
|
||||||
|
rationale?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function fetchLlmCandidates(
|
||||||
|
userId: string,
|
||||||
|
todoistTasks: TipCandidate[],
|
||||||
|
hour: number,
|
||||||
|
dayOfWeek: number,
|
||||||
|
): Promise<TipCandidate[]> {
|
||||||
|
try {
|
||||||
|
const tasks = todoistTasks.slice(0, 10).map((t) => ({
|
||||||
|
content: t.content,
|
||||||
|
priority: t.features.priority,
|
||||||
|
is_overdue: t.features.is_overdue,
|
||||||
|
task_age_days: t.features.task_age_days,
|
||||||
|
}));
|
||||||
|
const res = await fetch(`${config.ML_SERVING_URL}/generate`, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: { 'Content-Type': 'application/json' },
|
||||||
|
body: JSON.stringify({
|
||||||
|
user_id: userId,
|
||||||
|
context: { tasks, hour_of_day: hour, day_of_week: dayOfWeek },
|
||||||
|
n: 3,
|
||||||
|
}),
|
||||||
|
signal: AbortSignal.timeout(15_000),
|
||||||
|
});
|
||||||
|
if (!res.ok) return [];
|
||||||
|
const data = (await res.json()) as { candidates: LlmCandidate[]; model?: string };
|
||||||
|
const now = new Date().toISOString();
|
||||||
|
return data.candidates.map((c) => ({
|
||||||
|
id: `llm:${c.id}`,
|
||||||
|
content: c.content,
|
||||||
|
source: 'llm' as const,
|
||||||
|
kind: 'advice' as const,
|
||||||
|
rationale: c.rationale,
|
||||||
|
createdAt: now,
|
||||||
|
features: { is_overdue: false, task_age_days: 0, priority: 1 },
|
||||||
|
}));
|
||||||
|
} catch {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
// POST /api/recommend
|
// POST /api/recommend
|
||||||
|
// Pipeline: [Stage 1] assemble candidates → [Stage 2] score → [Stage 3] serve
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
router.post('/recommend', requireAuth, async (req: AuthenticatedRequest, res: Response) => {
|
router.post('/recommend', requireAuth, async (req: AuthenticatedRequest, res: Response) => {
|
||||||
const [token] = await db
|
const [token] = await db
|
||||||
@@ -167,34 +219,42 @@ router.post('/recommend', requireAuth, async (req: AuthenticatedRequest, res: Re
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const tasks = await fetchTodoistTasks(req.userId!, token.accessToken);
|
const hour = new Date().getHours();
|
||||||
if (!tasks.length) {
|
const dayOfWeek = new Date().getDay();
|
||||||
|
|
||||||
|
// Stage 1: assemble candidates — Todoist tasks + LLM-generated advice (parallel)
|
||||||
|
const [todoistTasks, llmCandidates] = await Promise.all([
|
||||||
|
fetchTodoistTasks(req.userId!, token.accessToken),
|
||||||
|
fetchLlmCandidates(req.userId!, taskCache.get(req.userId!)?.tasks ?? [], hour, dayOfWeek),
|
||||||
|
]);
|
||||||
|
|
||||||
|
const allCandidates: TipCandidate[] = [...todoistTasks, ...llmCandidates];
|
||||||
|
if (!allCandidates.length) {
|
||||||
res.status(204).end();
|
res.status(204).end();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const hour = new Date().getHours();
|
|
||||||
const dayOfWeek = new Date().getDay();
|
|
||||||
const t0 = Date.now();
|
const t0 = Date.now();
|
||||||
|
|
||||||
// RemotePolicy with RandomPolicy fallback
|
// Stage 2: score — egreedy bandit with random fallback
|
||||||
const scored = await remotePolicy(req.userId!, tasks);
|
const scored = await remotePolicy(req.userId!, allCandidates);
|
||||||
const latencyMs = Date.now() - t0;
|
const latencyMs = Date.now() - t0;
|
||||||
const tip = scored
|
const tip = scored
|
||||||
? (tasks.find((t) => t.id === scored.tipId) ?? randomPolicy(tasks))
|
? (allCandidates.find((t) => t.id === scored.tipId) ?? randomPolicy(allCandidates))
|
||||||
: randomPolicy(tasks);
|
: randomPolicy(allCandidates);
|
||||||
|
|
||||||
if (!tip) {
|
if (!tip) {
|
||||||
res.status(204).end();
|
res.status(204).end();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stage 3: serve + log
|
||||||
const policy = scored ? scored.policy : 'random';
|
const policy = scored ? scored.policy : 'random';
|
||||||
|
const isLlmTip = tip.source === 'llm';
|
||||||
const servedAt = new Date().toISOString();
|
const servedAt = new Date().toISOString();
|
||||||
|
|
||||||
await db.insert(tipViews).values({ id: nanoid(), userId: req.userId!, tipId: tip.id, servedAt });
|
await db.insert(tipViews).values({ id: nanoid(), userId: req.userId!, tipId: tip.id, servedAt });
|
||||||
|
|
||||||
// Log recommendation explainability
|
|
||||||
await db.insert(tipScores).values({
|
await db.insert(tipScores).values({
|
||||||
id: nanoid(),
|
id: nanoid(),
|
||||||
userId: req.userId!,
|
userId: req.userId!,
|
||||||
@@ -208,9 +268,12 @@ router.post('/recommend', requireAuth, async (req: AuthenticatedRequest, res: Re
|
|||||||
hour_of_day: hour,
|
hour_of_day: hour,
|
||||||
day_of_week: dayOfWeek,
|
day_of_week: dayOfWeek,
|
||||||
}),
|
}),
|
||||||
candidateCount: tasks.length,
|
candidateCount: allCandidates.length,
|
||||||
latencyMs,
|
latencyMs,
|
||||||
servedAt,
|
servedAt,
|
||||||
|
promptVersion: isLlmTip ? PROMPT_VERSION : null,
|
||||||
|
llmModel: isLlmTip ? 'tip-generator' : null,
|
||||||
|
tipKind: tip.kind ?? null,
|
||||||
});
|
});
|
||||||
|
|
||||||
bus.publish('signals.tip.served', {
|
bus.publish('signals.tip.served', {
|
||||||
@@ -224,7 +287,7 @@ router.post('/recommend', requireAuth, async (req: AuthenticatedRequest, res: Re
|
|||||||
for (const [name, s] of shadowPolicies) {
|
for (const [name, s] of shadowPolicies) {
|
||||||
if (!s.active) continue;
|
if (!s.active) continue;
|
||||||
if (name.startsWith('random')) {
|
if (name.startsWith('random')) {
|
||||||
const shadowTip = randomPolicy(tasks);
|
const shadowTip = randomPolicy(allCandidates);
|
||||||
bus.publish('signals.tip.served', {
|
bus.publish('signals.tip.served', {
|
||||||
userId: req.userId!,
|
userId: req.userId!,
|
||||||
tipId: shadowTip?.id ?? 'none',
|
tipId: shadowTip?.id ?? 'none',
|
||||||
@@ -249,7 +312,7 @@ router.post('/recommend', requireAuth, async (req: AuthenticatedRequest, res: Re
|
|||||||
// done 2 – 10 min → +0.6 (good: user engaged, acted in same session)
|
// done 2 – 10 min → +0.6 (good: user engaged, acted in same session)
|
||||||
// done > 10 min → +0.3 (eventually done; tip may have helped, unclear)
|
// done > 10 min → +0.3 (eventually done; tip may have helped, unclear)
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
function inferReward(action: string, dwellMs: number | null): number {
|
export function inferReward(action: string, dwellMs: number | null): number {
|
||||||
if (action === 'dismiss') return -1.0;
|
if (action === 'dismiss') return -1.0;
|
||||||
if (action === 'snooze') return 0.1;
|
if (action === 'snooze') return 0.1;
|
||||||
if (action === 'helpful') return 0.5;
|
if (action === 'helpful') return 0.5;
|
||||||
@@ -269,7 +332,7 @@ async function sendRewardWithRetry(
|
|||||||
userId: string,
|
userId: string,
|
||||||
tipId: string,
|
tipId: string,
|
||||||
reward: number,
|
reward: number,
|
||||||
features: TaskFeatures,
|
features: TipCandidate['features'],
|
||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
const body = JSON.stringify({
|
const body = JSON.stringify({
|
||||||
user_id: userId,
|
user_id: userId,
|
||||||
@@ -347,7 +410,7 @@ router.post('/tip/:id/feedback', requireAuth, async (req: AuthenticatedRequest,
|
|||||||
createdAt: now.toISOString(),
|
createdAt: now.toISOString(),
|
||||||
});
|
});
|
||||||
|
|
||||||
const task = taskCache.get(req.userId!)?.tasks.find((t) => t.id === tipId);
|
const task: TipCandidate | undefined = taskCache.get(req.userId!)?.tasks.find((t) => t.id === tipId);
|
||||||
|
|
||||||
taskCache.delete(req.userId!);
|
taskCache.delete(req.userId!);
|
||||||
|
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ export function makeTestDb() {
|
|||||||
refresh_token TEXT,
|
refresh_token TEXT,
|
||||||
expires_at TEXT,
|
expires_at TEXT,
|
||||||
connected_at TEXT NOT NULL,
|
connected_at TEXT NOT NULL,
|
||||||
|
token_status TEXT NOT NULL DEFAULT 'active',
|
||||||
UNIQUE(user_id, provider)
|
UNIQUE(user_id, provider)
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -88,7 +89,10 @@ export function makeTestDb() {
|
|||||||
features_json TEXT,
|
features_json TEXT,
|
||||||
candidate_count INTEGER,
|
candidate_count INTEGER,
|
||||||
latency_ms INTEGER,
|
latency_ms INTEGER,
|
||||||
served_at TEXT NOT NULL
|
served_at TEXT NOT NULL,
|
||||||
|
prompt_version TEXT,
|
||||||
|
llm_model TEXT,
|
||||||
|
tip_kind TEXT
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS saved_queries (
|
CREATE TABLE IF NOT EXISTS saved_queries (
|
||||||
|
|||||||
@@ -4,6 +4,13 @@ export default defineConfig({
|
|||||||
test: {
|
test: {
|
||||||
globals: true,
|
globals: true,
|
||||||
environment: 'node',
|
environment: 'node',
|
||||||
|
env: {
|
||||||
|
SESSION_SECRET: 'test-secret',
|
||||||
|
GOOGLE_CLIENT_ID: 'test-google-id',
|
||||||
|
GOOGLE_CLIENT_SECRET: 'test-google-secret',
|
||||||
|
TODOIST_CLIENT_ID: 'test-todoist-id',
|
||||||
|
TODOIST_CLIENT_SECRET: 'test-todoist-secret',
|
||||||
|
},
|
||||||
coverage: {
|
coverage: {
|
||||||
provider: 'v8',
|
provider: 'v8',
|
||||||
reporter: ['text', 'lcov'],
|
reporter: ['text', 'lcov'],
|
||||||
|
|||||||
Reference in New Issue
Block a user