Files
oO/ml/serving/tests/test_generate.py
alvis 4267e6ac68 feat(ml/serving): inject profile features + sort tasks in tip prompt (#79)
- prompts.py: sort tasks overdue-first → priority desc → age desc before
  rendering into the LLM prompt (same ordering as ml/features/context.py)
- prompts.py: render User profile summary line (completion_rate, dismiss_rate,
  preferred_hour) when profile_features are present
- main.py: add profile_features field to PromptContext; plumb from
  GenerateRequest into the prompt builder via model_copy
- logging_config.py: drop add_logger_name processor (incompatible with
  PrintLoggerFactory — caused test ordering failures)
- test_generate.py: 6 new tests covering sort order, profile rendering,
  partial fields, empty profile, and end-to-end plumbing through /generate

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-27 13:46:16 +00:00

355 lines
13 KiB
Python

"""
Tests for POST /generate — LiteLLM gateway.
LiteLLM is mocked; no real network calls.
"""
import json
import pytest
import httpx
from unittest.mock import AsyncMock, patch
from httpx import AsyncClient, ASGITransport, Response
from main import app, PromptContext
from prompts import PROMPTS, get_prompt
# Shortcut to the "v1" prompt's user-message builder; the test_build_prompt_*
# tests in this module call it directly as _build_user_v1(ctx, n=...).
_build_user_v1 = PROMPTS["v1"].build_user
def _litellm_response(candidates: list[dict]) -> Response:
    """Build a fake, successful LiteLLM chat-completion response.

    Args:
        candidates: Tip dicts; they are JSON-encoded into the ``content`` of
            the single choice's message, the way the tests expect the
            gateway to receive them.

    Returns:
        A status-200 ``httpx.Response`` whose JSON body carries the model
        name ``"tip-generator"``, one choice, and a small ``usage`` dict.
    """
    # Relies on the module-level ``httpx`` import; the former function-local
    # re-import was redundant.
    body = {
        "model": "tip-generator",
        "choices": [{"message": {"content": json.dumps(candidates)}}],
        "usage": {"prompt_tokens": 10, "completion_tokens": 20},
    }
    # A request is attached to the response -- presumably so the code under
    # test can call raise_for_status() on it (TODO confirm against main).
    req = httpx.Request("POST", "http://litellm/chat/completions")
    return Response(200, json=body, request=req)
@pytest.mark.anyio
async def test_generate_returns_candidates():
    """POST /generate returns the candidates and model name from the mocked LiteLLM reply."""
    tips = [
        {"id": "tip-1", "content": "Do the overdue task now.", "rationale": "It's been waiting."},
        {"id": "tip-2", "content": "Take a 5-minute break.", "rationale": "You've been working long."},
    ]
    upstream_reply = _litellm_response(tips)

    with patch("main.httpx.AsyncClient") as client_cls:
        # The mock works under ``async with``: entering it yields the mock itself.
        fake_upstream = AsyncMock()
        fake_upstream.post = AsyncMock(return_value=upstream_reply)
        fake_upstream.__aenter__ = AsyncMock(return_value=fake_upstream)
        fake_upstream.__aexit__ = AsyncMock(return_value=False)
        client_cls.return_value = fake_upstream

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as api:
            resp = await api.post("/generate", json={"user_id": "u1", "n": 2})

    assert resp.status_code == 200
    payload = resp.json()
    returned = payload["candidates"]
    assert len(returned) == 2
    assert returned[0]["id"] == "tip-1"
    assert payload["model"] == "tip-generator"
@pytest.mark.anyio
async def test_generate_strips_markdown_fence():
    """A reply wrapped in a Markdown json code fence is still parsed into candidates."""
    tips = [{"id": "tip-a", "content": "Focus.", "rationale": "Now."}]
    fenced_content = f"```json\n{json.dumps(tips)}\n```"
    upstream_reply = Response(
        200,
        json={
            "model": "tip-generator",
            "choices": [{"message": {"content": fenced_content}}],
            "usage": {},
        },
        request=httpx.Request("POST", "http://litellm/chat/completions"),
    )

    with patch("main.httpx.AsyncClient") as client_cls:
        fake_upstream = AsyncMock()
        fake_upstream.post = AsyncMock(return_value=upstream_reply)
        fake_upstream.__aenter__ = AsyncMock(return_value=fake_upstream)
        fake_upstream.__aexit__ = AsyncMock(return_value=False)
        client_cls.return_value = fake_upstream

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as api:
            resp = await api.post("/generate", json={"user_id": "u1"})

    assert resp.status_code == 200
    first_candidate = resp.json()["candidates"][0]
    assert first_candidate["id"] == "tip-a"
@pytest.mark.anyio
async def test_generate_503_on_unreachable():
    """A connection error raised by the upstream client's post() yields HTTP 503."""
    # Uses the module-level ``httpx`` import, like the other tests in this
    # file; the aliased function-local re-import was redundant.
    with patch("main.httpx.AsyncClient") as MockClient:
        instance = AsyncMock()
        instance.post = AsyncMock(side_effect=httpx.ConnectError("refused"))
        # Make the mock usable under ``async with`` and yield itself.
        instance.__aenter__ = AsyncMock(return_value=instance)
        instance.__aexit__ = AsyncMock(return_value=False)
        MockClient.return_value = instance

        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            resp = await client.post("/generate", json={"user_id": "u1"})

    assert resp.status_code == 503
def test_build_prompt_includes_tasks():
    """The v1 user prompt contains the task text, the hour as 09:00, and the requested tip count."""
    overdue_task = {"content": "Write report", "priority": 4, "is_overdue": True, "due_date": "2026-04-15"}
    context = PromptContext(tasks=[overdue_task], hour_of_day=9, day_of_week=2)

    rendered = _build_user_v1(context, n=3)

    for fragment in ("Write report", "09:00", "Generate 3 tips"):
        assert fragment in rendered
def test_build_prompt_truncates_at_five():
    """With eight tasks supplied, "Task 4" is rendered but "Task 5" is not."""
    eight_tasks = []
    for idx in range(8):
        eight_tasks.append(
            {"content": f"Task {idx}", "priority": 1, "is_overdue": False, "due_date": None}
        )
    context = PromptContext(tasks=eight_tasks, hour_of_day=12)

    rendered = _build_user_v1(context, n=2)

    assert "Task 4" in rendered
    assert "Task 5" not in rendered
def test_build_prompt_extra_fields():
    """Entries in ``extra`` are rendered as "key: value" fragments in the prompt."""
    extras = {"mood": "focused", "energy": "high"}
    context = PromptContext(tasks=[], hour_of_day=8, extra=extras)

    rendered = _build_user_v1(context, n=1)

    assert "mood: focused" in rendered
    assert "energy: high" in rendered
def test_build_prompt_empty_tasks_no_task_line():
    """An empty task list produces no "Tasks:" line but still asks for the tips."""
    context = PromptContext(tasks=[], hour_of_day=10)

    rendered = _build_user_v1(context, n=2)

    assert "Tasks:" not in rendered
    assert "Generate 2 tips" in rendered
def test_build_prompt_tasks_sorted_overdue_first():
    """An overdue task appears in the prompt before a non-overdue one supplied ahead of it."""
    not_overdue = {"content": "Low priority", "priority": 1, "is_overdue": False, "task_age_days": 0}
    overdue = {"content": "Overdue task", "priority": 2, "is_overdue": True, "task_age_days": 3}
    # The non-overdue task is deliberately passed first.
    context = PromptContext(tasks=[not_overdue, overdue], hour_of_day=9)

    rendered = _build_user_v1(context, n=2)

    overdue_pos = rendered.index("Overdue task")
    low_pos = rendered.index("Low priority")
    assert overdue_pos < low_pos
def test_build_prompt_includes_profile_features():
    """A full profile is rendered as a "User profile:" line with percentages and HH:00 hour."""
    profile = {"completion_rate_30d": 0.75, "dismiss_rate_30d": 0.1, "preferred_hour": 9}
    context = PromptContext(tasks=[], hour_of_day=14, profile_features=profile)

    rendered = _build_user_v1(context, n=1)

    expected_fragments = (
        "User profile:",
        "completion_rate=75%",
        "dismiss_rate=10%",
        "preferred_hour=09:00",
    )
    for fragment in expected_fragments:
        assert fragment in rendered
def test_build_prompt_no_profile_line_when_empty():
    """An empty ``profile_features`` dict produces no "User profile:" line."""
    context = PromptContext(tasks=[], hour_of_day=10, profile_features={})

    rendered = _build_user_v1(context, n=1)

    assert "User profile:" not in rendered
def test_build_prompt_profile_partial_fields():
    """Only the profile fields that are present get rendered; absent ones are omitted."""
    partial_profile = {"completion_rate_30d": 0.5}
    context = PromptContext(tasks=[], hour_of_day=10, profile_features=partial_profile)

    rendered = _build_user_v1(context, n=1)

    assert "completion_rate=50%" in rendered
    assert "dismiss_rate" not in rendered
@pytest.mark.anyio
async def test_generate_retry_succeeds_on_second_attempt():
    """An unparseable first reply followed by a valid one results in HTTP 200 after two calls."""
    tips = [{"id": "tip-ok", "content": "Retry worked.", "rationale": "Second try."}]
    unparseable_reply = Response(
        200,
        json={
            "model": "tip-generator",
            "choices": [{"message": {"content": "this is not json"}}],
            "usage": {},
        },
        request=httpx.Request("POST", "http://litellm/chat/completions"),
    )
    valid_reply = _litellm_response(tips)

    with patch("main.httpx.AsyncClient") as client_cls:
        fake_upstream = AsyncMock()
        # Replies are handed out in order: the bad one first, then the good one.
        fake_upstream.post = AsyncMock(side_effect=[unparseable_reply, valid_reply])
        fake_upstream.__aenter__ = AsyncMock(return_value=fake_upstream)
        fake_upstream.__aexit__ = AsyncMock(return_value=False)
        client_cls.return_value = fake_upstream

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as api:
            resp = await api.post("/generate", json={"user_id": "u1", "n": 1})

    assert resp.status_code == 200
    assert resp.json()["candidates"][0]["id"] == "tip-ok"
    assert fake_upstream.post.call_count == 2

    # The messages sent on the second call must carry the correction suffix.
    retry_call = fake_upstream.post.call_args_list[1]
    retry_messages = retry_call.kwargs["json"]["messages"]
    assert any("not valid JSON" in msg["content"] for msg in retry_messages)
@pytest.mark.anyio
async def test_generate_502_after_all_retries_exhausted():
    """When every attempt yields unparseable content, the endpoint answers 502."""
    from main import _MAX_GENERATE_RETRIES

    upstream_request = httpx.Request("POST", "http://litellm/chat/completions")

    def _make_unparseable_reply():
        # A fresh Response per attempt, all sharing the same request object.
        payload = {
            "model": "tip-generator",
            "choices": [{"message": {"content": "not json at all"}}],
            "usage": {},
        }
        return Response(200, json=payload, request=upstream_request)

    # One initial attempt plus the configured number of retries.
    total_attempts = 1 + _MAX_GENERATE_RETRIES
    replies = []
    for _ in range(total_attempts):
        replies.append(_make_unparseable_reply())

    with patch("main.httpx.AsyncClient") as client_cls:
        fake_upstream = AsyncMock()
        fake_upstream.post = AsyncMock(side_effect=replies)
        fake_upstream.__aenter__ = AsyncMock(return_value=fake_upstream)
        fake_upstream.__aexit__ = AsyncMock(return_value=False)
        client_cls.return_value = fake_upstream

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as api:
            resp = await api.post("/generate", json={"user_id": "u1"})

    assert resp.status_code == 502
    detail = resp.json()["detail"]
    assert "retries" in detail
@pytest.mark.anyio
async def test_generate_502_on_upstream_http_error():
    """An HTTPStatusError (upstream 500) raised by post() is reported as 502 with a LiteLLM error detail."""
    upstream_request = httpx.Request("POST", "http://litellm/chat/completions")
    upstream_failure = httpx.HTTPStatusError(
        "500",
        request=upstream_request,
        response=Response(500, text="internal error", request=upstream_request),
    )

    with patch("main.httpx.AsyncClient") as client_cls:
        fake_upstream = AsyncMock()
        fake_upstream.post = AsyncMock(side_effect=upstream_failure)
        fake_upstream.__aenter__ = AsyncMock(return_value=fake_upstream)
        fake_upstream.__aexit__ = AsyncMock(return_value=False)
        client_cls.return_value = fake_upstream

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as api:
            resp = await api.post("/generate", json={"user_id": "u1"})

    assert resp.status_code == 502
    detail = resp.json()["detail"]
    assert "LiteLLM error" in detail
def test_parse_llm_json_bare_fence():
    """Content wrapped in a bare (language-less) code fence is parsed into a list of items."""
    from main import _parse_llm_json

    fenced = '```\n[{"id":"x","content":"hi"}]\n```'
    parsed = _parse_llm_json(fenced)
    first = parsed[0]
    assert first["id"] == "x"
def test_parse_llm_json_no_fence():
    """Plain JSON content without any fence is parsed into a list of items."""
    from main import _parse_llm_json

    unfenced = '[{"id":"plain","content":"no fence"}]'
    parsed = _parse_llm_json(unfenced)
    first = parsed[0]
    assert first["id"] == "plain"
def test_parse_llm_json_raises_on_invalid():
    """Non-JSON content makes ``_parse_llm_json`` raise."""
    from main import _parse_llm_json

    not_json = "this is not json"
    # NOTE(review): ``Exception`` matches exactly what the former
    # ``(ValueError, Exception)`` tuple matched. Narrow it once the concrete
    # exception type raised by _parse_llm_json is confirmed.
    with pytest.raises(Exception):
        _parse_llm_json(not_json)
# ── Prompt registry / selection (#84) ──────────────────────────────────────
def test_prompt_registry_contains_expected_versions():
    """The registry exposes v1, v2-mentor and v3-few-shot with distinguishable system prompts."""
    required_versions = {"v1", "v2-mentor", "v3-few-shot"}
    registered_versions = set(PROMPTS.keys())
    assert registered_versions >= required_versions

    # The mentor variant must differ in tone from v1; comparing the system
    # prompts is the simplest check for that.
    baseline_system = PROMPTS["v1"].system
    mentor_system = PROMPTS["v2-mentor"].system
    assert baseline_system != mentor_system

    # The few-shot variant must embed curated examples in its system prompt.
    few_shot_system = PROMPTS["v3-few-shot"].system
    assert "Examples" in few_shot_system
def test_get_prompt_unknown_raises_keyerror():
    """Requesting an unregistered prompt version raises ``KeyError``."""
    missing_version = "does-not-exist"
    with pytest.raises(KeyError):
        get_prompt(missing_version)
def test_get_prompt_default_when_none():
    """Passing ``None`` selects the default prompt, which this test pins to "v1"."""
    default_prompt = get_prompt(None)
    # "v1" is the current DEFAULT_PROMPT_VERSION.
    assert default_prompt.version == "v1"
@pytest.mark.anyio
async def test_generate_echoes_selected_prompt_version():
    """The response reports the prompt_version that was requested ("v2-mentor")."""
    tips = [{"id": "tip-1", "content": "x", "rationale": "y"}]
    upstream_reply = _litellm_response(tips)
    request_body = {"user_id": "u1", "n": 1, "prompt_version": "v2-mentor"}

    with patch("main.httpx.AsyncClient") as client_cls:
        fake_upstream = AsyncMock()
        fake_upstream.post = AsyncMock(return_value=upstream_reply)
        fake_upstream.__aenter__ = AsyncMock(return_value=fake_upstream)
        fake_upstream.__aexit__ = AsyncMock(return_value=False)
        client_cls.return_value = fake_upstream

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as api:
            resp = await api.post("/generate", json=request_body)

    assert resp.status_code == 200
    echoed_version = resp.json()["prompt_version"]
    assert echoed_version == "v2-mentor"
@pytest.mark.anyio
async def test_generate_passes_profile_features_to_prompt():
    """profile_features sent to /generate show up in the user message posted to LiteLLM."""
    tips = [{"id": "tip-1", "content": "x", "rationale": "y"}]
    upstream_reply = _litellm_response(tips)
    sent_bodies: list[dict] = []

    async def _record_and_reply(url, *, json, headers):
        # ``json`` here is the outgoing request body (it shadows the module
        # only inside this helper); keep it for inspection after the call.
        sent_bodies.append(json)
        return upstream_reply

    request_body = {
        "user_id": "u1",
        "n": 1,
        "profile_features": {"completion_rate_30d": 0.8, "preferred_hour": 10},
    }

    with patch("main.httpx.AsyncClient") as client_cls:
        fake_upstream = AsyncMock()
        fake_upstream.post = AsyncMock(side_effect=_record_and_reply)
        fake_upstream.__aenter__ = AsyncMock(return_value=fake_upstream)
        fake_upstream.__aexit__ = AsyncMock(return_value=False)
        client_cls.return_value = fake_upstream

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as api:
            resp = await api.post("/generate", json=request_body)

    assert resp.status_code == 200
    # messages[1] is taken to be the user message (index 0 presumably the
    # system prompt -- confirm against main).
    user_message = sent_bodies[0]["messages"][1]["content"]
    for fragment in ("User profile:", "completion_rate=80%", "preferred_hour=10:00"):
        assert fragment in user_message
@pytest.mark.anyio
async def test_generate_422_on_unknown_prompt_version():
    """An unrecognised prompt_version is rejected with 422 and an explanatory detail."""
    request_body = {"user_id": "u1", "n": 1, "prompt_version": "nonsense"}

    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as api:
        resp = await api.post("/generate", json=request_body)

    assert resp.status_code == 422
    detail = resp.json()["detail"]
    assert "Unknown prompt_version" in detail