Files
oO/ml/serving/tests/test_generate.py
alvis 430804e9a5 feat(ml): prompt registry + per-request variant selection
Replaces the hardcoded "v1" label with a real prompt registry:

  ml/serving/prompts.py       — keyed by version: v1 (baseline),
                                v2-mentor (calm/specific persona),
                                v3-few-shot (v1 persona + curated examples)
  ml/serving/main.py          — POST /generate accepts optional prompt_version,
                                422 on unknown, echoes the version actually used
                                back in the response
  services/api/src/config.ts  — TIP_PROMPT_VERSION: empty / single / comma-list
                                (uniform random per request)
  services/api/src/routes/recommender.ts
                              — pickPromptVersion() drives selection; the
                                response's prompt_version (not a stale TS
                                constant) is what lands in tip_scores so the
                                #92 reward-analytics dashboard shows real
                                per-variant reaction rates

Closes #84.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-04-24 15:44:04 +00:00

283 lines
10 KiB
Python

"""
Tests for POST /generate — LiteLLM gateway.
LiteLLM is mocked; no real network calls.
"""
import json
import pytest
import httpx
from unittest.mock import AsyncMock, patch
from httpx import AsyncClient, ASGITransport, Response
from main import app, PromptContext
from prompts import PROMPTS, get_prompt
# Shortcut to the user-prompt builder registered under "v1"; the
# test_build_prompt_* tests call it directly as _build_user_v1(ctx, n=...).
_build_user_v1 = PROMPTS["v1"].build_user
def _litellm_response(candidates: list[dict]) -> Response:
    """Build a fake successful LiteLLM chat-completions reply.

    Args:
        candidates: Tip dicts that are JSON-encoded into the message content
            of the single choice in the payload.

    Returns:
        A 200 ``httpx.Response`` carrying the payload, with a request attached.
    """
    payload = {
        "model": "tip-generator",
        "choices": [{"message": {"content": json.dumps(candidates)}}],
        "usage": {"prompt_tokens": 10, "completion_tokens": 20},
    }
    # Uses the module-level httpx import.
    request = httpx.Request("POST", "http://litellm/chat/completions")
    return Response(200, json=payload, request=request)
@pytest.mark.anyio
async def test_generate_returns_candidates():
    """Happy path: two upstream candidates come back in a 200 with the model name."""
    upstream_tips = [
        {"id": "tip-1", "content": "Do the overdue task now.", "rationale": "It's been waiting."},
        {"id": "tip-2", "content": "Take a 5-minute break.", "rationale": "You've been working long."},
    ]

    # Stand-in for the client returned by httpx.AsyncClient(); __aenter__ hands
    # back the same mock so it also works under `async with`.
    fake_client = AsyncMock()
    fake_client.post = AsyncMock(return_value=_litellm_response(upstream_tips))
    fake_client.__aenter__ = AsyncMock(return_value=fake_client)
    fake_client.__aexit__ = AsyncMock(return_value=False)

    with patch("main.httpx.AsyncClient") as client_cls:
        client_cls.return_value = fake_client
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as api:
            reply = await api.post("/generate", json={"user_id": "u1", "n": 2})

    assert reply.status_code == 200
    body = reply.json()
    assert len(body["candidates"]) == 2
    assert body["candidates"][0]["id"] == "tip-1"
    assert body["model"] == "tip-generator"
@pytest.mark.anyio
async def test_generate_strips_markdown_fence():
    """Content wrapped in a ```json fence should still yield the candidate."""
    tips = [{"id": "tip-a", "content": "Focus.", "rationale": "Now."}]
    fenced_content = "```json\n" + json.dumps(tips) + "\n```"
    upstream_reply = Response(
        200,
        json={
            "model": "tip-generator",
            "choices": [{"message": {"content": fenced_content}}],
            "usage": {},
        },
        request=httpx.Request("POST", "http://litellm/chat/completions"),
    )

    fake_client = AsyncMock()
    fake_client.post = AsyncMock(return_value=upstream_reply)
    fake_client.__aenter__ = AsyncMock(return_value=fake_client)
    fake_client.__aexit__ = AsyncMock(return_value=False)

    with patch("main.httpx.AsyncClient") as client_cls:
        client_cls.return_value = fake_client
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as api:
            reply = await api.post("/generate", json={"user_id": "u1"})

    assert reply.status_code == 200
    first_candidate = reply.json()["candidates"][0]
    assert first_candidate["id"] == "tip-a"
@pytest.mark.anyio
async def test_generate_503_on_unreachable():
    """A connection error raised by the upstream POST should map to HTTP 503."""
    with patch("main.httpx.AsyncClient") as MockClient:
        instance = AsyncMock()
        # Every post() raises as if the LiteLLM host refused the connection.
        instance.post = AsyncMock(side_effect=httpx.ConnectError("refused"))
        instance.__aenter__ = AsyncMock(return_value=instance)
        instance.__aexit__ = AsyncMock(return_value=False)
        MockClient.return_value = instance
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            resp = await client.post("/generate", json={"user_id": "u1"})

    assert resp.status_code == 503
def test_build_prompt_includes_tasks():
    """The v1 user prompt mentions the task text, the hour as HH:00, and the tip count."""
    overdue_task = {
        "content": "Write report",
        "priority": 4,
        "is_overdue": True,
        "due_date": "2026-04-15",
    }
    context = PromptContext(tasks=[overdue_task], hour_of_day=9, day_of_week=2)

    rendered = _build_user_v1(context, n=3)

    for expected_fragment in ("Write report", "09:00", "Generate 3 tips"):
        assert expected_fragment in rendered
def test_build_prompt_truncates_at_five():
    """With eight tasks supplied, the prompt mentions Task 4 but not Task 5."""
    eight_tasks = [
        {"content": f"Task {idx}", "priority": 1, "is_overdue": False, "due_date": None}
        for idx in range(8)
    ]
    context = PromptContext(tasks=eight_tasks, hour_of_day=12)

    rendered = _build_user_v1(context, n=2)

    assert "Task 4" in rendered  # index 4 is the fifth task: still included
    assert "Task 5" not in rendered  # index 5 is the sixth task: cut off
def test_build_prompt_extra_fields():
    """Each entry of ``extra`` should be rendered into the prompt as ``key: value``."""
    extra_fields = {"mood": "focused", "energy": "high"}
    context = PromptContext(tasks=[], hour_of_day=8, extra=extra_fields)

    rendered = _build_user_v1(context, n=1)

    assert "mood: focused" in rendered
    assert "energy: high" in rendered
def test_build_prompt_empty_tasks_no_task_line():
    """With no tasks the prompt omits the "Tasks:" line but still requests tips."""
    context = PromptContext(tasks=[], hour_of_day=10)

    rendered = _build_user_v1(context, n=2)

    assert "Tasks:" not in rendered
    assert "Generate 2 tips" in rendered
@pytest.mark.anyio
async def test_generate_retry_succeeds_on_second_attempt():
    """First response is invalid JSON; second is valid. Should return 200."""
    unparseable_reply = Response(
        200,
        json={
            "model": "tip-generator",
            "choices": [{"message": {"content": "this is not json"}}],
            "usage": {},
        },
        request=httpx.Request("POST", "http://litellm/chat/completions"),
    )
    parseable_reply = _litellm_response(
        [{"id": "tip-ok", "content": "Retry worked.", "rationale": "Second try."}]
    )

    fake_client = AsyncMock()
    # A side_effect list hands out one reply per successive post() call.
    fake_client.post = AsyncMock(side_effect=[unparseable_reply, parseable_reply])
    fake_client.__aenter__ = AsyncMock(return_value=fake_client)
    fake_client.__aexit__ = AsyncMock(return_value=False)

    with patch("main.httpx.AsyncClient") as client_cls:
        client_cls.return_value = fake_client
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as api:
            reply = await api.post("/generate", json={"user_id": "u1", "n": 1})

    assert reply.status_code == 200
    assert reply.json()["candidates"][0]["id"] == "tip-ok"
    assert fake_client.post.call_count == 2

    # Retry message should include the correction suffix
    retry_call = fake_client.post.call_args_list[1]
    retry_messages = retry_call.kwargs["json"]["messages"]
    assert any("not valid JSON" in message["content"] for message in retry_messages)
@pytest.mark.anyio
async def test_generate_502_after_all_retries_exhausted():
    """All attempts return invalid JSON → 502."""
    from main import _MAX_GENERATE_RETRIES

    upstream_request = httpx.Request("POST", "http://litellm/chat/completions")
    garbage_payload = {
        "model": "tip-generator",
        "choices": [{"message": {"content": "not json at all"}}],
        "usage": {},
    }
    # One fresh reply for the initial attempt plus one per allowed retry.
    garbage_replies = [
        Response(200, json=garbage_payload, request=upstream_request)
        for _ in range(1 + _MAX_GENERATE_RETRIES)
    ]

    fake_client = AsyncMock()
    fake_client.post = AsyncMock(side_effect=garbage_replies)
    fake_client.__aenter__ = AsyncMock(return_value=fake_client)
    fake_client.__aexit__ = AsyncMock(return_value=False)

    with patch("main.httpx.AsyncClient") as client_cls:
        client_cls.return_value = fake_client
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as api:
            reply = await api.post("/generate", json={"user_id": "u1"})

    assert reply.status_code == 502
    assert "retries" in reply.json()["detail"]
@pytest.mark.anyio
async def test_generate_502_on_upstream_http_error():
    """LiteLLM returns 500 → HTTPStatusError → 502."""
    upstream_request = httpx.Request("POST", "http://litellm/chat/completions")
    upstream_failure = httpx.HTTPStatusError(
        "500",
        request=upstream_request,
        response=Response(500, text="internal error", request=upstream_request),
    )

    fake_client = AsyncMock()
    # post() itself raises, standing in for a failed status check on the reply.
    fake_client.post = AsyncMock(side_effect=upstream_failure)
    fake_client.__aenter__ = AsyncMock(return_value=fake_client)
    fake_client.__aexit__ = AsyncMock(return_value=False)

    with patch("main.httpx.AsyncClient") as client_cls:
        client_cls.return_value = fake_client
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as api:
            reply = await api.post("/generate", json={"user_id": "u1"})

    assert reply.status_code == 502
    assert "LiteLLM error" in reply.json()["detail"]
def test_parse_llm_json_bare_fence():
    """A fence without a language tag around the JSON array should still parse."""
    from main import _parse_llm_json

    fenced_text = "```\n[{\"id\":\"x\",\"content\":\"hi\"}]\n```"
    parsed = _parse_llm_json(fenced_text)
    first_item = parsed[0]
    assert first_item["id"] == "x"
def test_parse_llm_json_no_fence():
    """A plain JSON array with no surrounding fence should parse as-is."""
    from main import _parse_llm_json

    plain_text = '[{"id":"plain","content":"no fence"}]'
    parsed = _parse_llm_json(plain_text)
    first_item = parsed[0]
    assert first_item["id"] == "plain"
def test_parse_llm_json_raises_on_invalid():
    """Text that is not JSON must make ``_parse_llm_json`` raise."""
    from main import _parse_llm_json

    not_json = "this is not json"
    # NOTE(review): Exception already covers ValueError, so any exception
    # satisfies this check — consider narrowing once the intended error type
    # of _parse_llm_json is confirmed.
    with pytest.raises((ValueError, Exception)):
        _parse_llm_json(not_json)
# ── Prompt registry / selection (#84) ──────────────────────────────────────
def test_prompt_registry_contains_expected_versions():
    """The registry has v1, v2-mentor and v3-few-shot, with distinguishable system prompts."""
    required_versions = {"v1", "v2-mentor", "v3-few-shot"}
    assert required_versions.issubset(PROMPTS.keys())

    baseline = PROMPTS["v1"]
    mentor = PROMPTS["v2-mentor"]
    few_shot = PROMPTS["v3-few-shot"]

    # v2-mentor must differ from v1 in tone — easiest assertion: different system prompt.
    assert baseline.system != mentor.system
    # v3-few-shot must include curated example content in its system prompt.
    assert "Examples" in few_shot.system
def test_get_prompt_unknown_raises_keyerror():
    """Looking up a version that is not registered raises ``KeyError``."""
    unregistered_version = "does-not-exist"
    with pytest.raises(KeyError):
        get_prompt(unregistered_version)
def test_get_prompt_default_when_none():
    """Passing ``None`` should resolve to the v1 prompt."""
    default_prompt = get_prompt(None)
    # "v1" is the current DEFAULT_PROMPT_VERSION.
    assert default_prompt.version == "v1"
@pytest.mark.anyio
async def test_generate_echoes_selected_prompt_version():
    """Server should report back which prompt_version it actually used."""
    upstream_reply = _litellm_response([{"id": "tip-1", "content": "x", "rationale": "y"}])

    fake_client = AsyncMock()
    fake_client.post = AsyncMock(return_value=upstream_reply)
    fake_client.__aenter__ = AsyncMock(return_value=fake_client)
    fake_client.__aexit__ = AsyncMock(return_value=False)

    request_body = {"user_id": "u1", "n": 1, "prompt_version": "v2-mentor"}

    with patch("main.httpx.AsyncClient") as client_cls:
        client_cls.return_value = fake_client
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as api:
            reply = await api.post("/generate", json=request_body)

    assert reply.status_code == 200
    assert reply.json()["prompt_version"] == "v2-mentor"
@pytest.mark.anyio
async def test_generate_422_on_unknown_prompt_version():
    """An unregistered prompt_version is rejected with 422 and an explanatory detail."""
    request_body = {"user_id": "u1", "n": 1, "prompt_version": "nonsense"}

    # No upstream mock is installed for this test.
    async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as api:
        reply = await api.post("/generate", json=request_body)

    assert reply.status_code == 422
    assert "Unknown prompt_version" in reply.json()["detail"]