"""
Tests for POST /generate — LiteLLM gateway.

LiteLLM is mocked; no real network calls.
"""
import json
from contextlib import contextmanager
from unittest.mock import AsyncMock, patch

import httpx
import pytest
from httpx import AsyncClient, ASGITransport, Response

from main import app, _build_prompt, PromptContext

# Fake upstream endpoint attached to every canned response/request object.
_LITELLM_URL = "http://litellm/chat/completions"


def _chat_response(content: str, usage: dict) -> Response:
    """Build a 200 chat-completions response whose message content is *content*.

    Args:
        content: Raw string placed in ``choices[0].message.content``.
        usage: Value stored under the ``usage`` key of the response body.

    Returns:
        An ``httpx.Response`` with a request attached.
    """
    body = {
        "model": "tip-generator",
        "choices": [{"message": {"content": content}}],
        "usage": usage,
    }
    return Response(200, json=body, request=httpx.Request("POST", _LITELLM_URL))


def _litellm_response(candidates: list[dict]) -> Response:
    """Build a successful LiteLLM response carrying *candidates* as JSON text."""
    return _chat_response(
        json.dumps(candidates),
        {"prompt_tokens": 10, "completion_tokens": 20},
    )


@contextmanager
def _patched_litellm(**post_kwargs):
    """Patch ``main.httpx.AsyncClient`` with an async-context-manager mock.

    Args:
        **post_kwargs: Forwarded to the ``AsyncMock`` that replaces ``post``
            (typically ``return_value=...`` or ``side_effect=...``).

    Yields:
        The mock client instance, so tests can inspect ``instance.post``.
    """
    with patch("main.httpx.AsyncClient") as mock_cls:
        instance = AsyncMock()
        instance.post = AsyncMock(**post_kwargs)
        # Make ``async with httpx.AsyncClient(...) as c`` hand back the mock.
        instance.__aenter__ = AsyncMock(return_value=instance)
        instance.__aexit__ = AsyncMock(return_value=False)
        mock_cls.return_value = instance
        yield instance


async def _post_generate(payload: dict) -> Response:
    """POST *payload* to ``/generate`` on the in-process ASGI app.

    Uses the ``AsyncClient`` name imported at module load rather than
    ``httpx.AsyncClient``: the patch above presumably replaces the attribute
    on the shared ``httpx`` module, so a late lookup would return the mock.
    """
    async with AsyncClient(
        transport=ASGITransport(app=app), base_url="http://test"
    ) as client:
        return await client.post("/generate", json=payload)


@pytest.mark.anyio
async def test_generate_returns_candidates():
    """A valid upstream reply is returned as candidates with the model name."""
    fake_items = [
        {"id": "tip-1", "content": "Do the overdue task now.", "rationale": "It's been waiting."},
        {"id": "tip-2", "content": "Take a 5-minute break.", "rationale": "You've been working long."},
    ]

    with _patched_litellm(return_value=_litellm_response(fake_items)):
        resp = await _post_generate({"user_id": "u1", "n": 2})

    assert resp.status_code == 200
    data = resp.json()
    assert len(data["candidates"]) == 2
    assert data["candidates"][0]["id"] == "tip-1"
    assert data["model"] == "tip-generator"


@pytest.mark.anyio
async def test_generate_strips_markdown_fence():
    """Content wrapped in a ```json fence is still parsed into candidates."""
    fake_items = [{"id": "tip-a", "content": "Focus.", "rationale": "Now."}]
    fenced = "```json\n" + json.dumps(fake_items) + "\n```"

    with _patched_litellm(return_value=_chat_response(fenced, {})):
        resp = await _post_generate({"user_id": "u1"})

    assert resp.status_code == 200
    assert resp.json()["candidates"][0]["id"] == "tip-a"


@pytest.mark.anyio
async def test_generate_503_on_unreachable():
    """A connection error from the upstream client maps to 503."""
    with _patched_litellm(side_effect=httpx.ConnectError("refused")):
        resp = await _post_generate({"user_id": "u1"})

    assert resp.status_code == 503


def test_build_prompt_includes_tasks():
    """Task content, formatted hour and requested tip count appear in the prompt."""
    ctx = PromptContext(
        tasks=[{"content": "Write report", "priority": 4, "is_overdue": True, "due_date": "2026-04-15"}],
        hour_of_day=9,
        day_of_week=2,
    )
    prompt = _build_prompt(ctx, n=3)
    assert "Write report" in prompt
    assert "09:00" in prompt
    assert "Generate 3 tips" in prompt


def test_build_prompt_truncates_at_five():
    """Only the first five of eight tasks are rendered (Task 0..Task 4)."""
    tasks = [
        {"content": f"Task {i}", "priority": 1, "is_overdue": False, "due_date": None}
        for i in range(8)
    ]
    ctx = PromptContext(tasks=tasks, hour_of_day=12)
    prompt = _build_prompt(ctx, n=2)
    assert "Task 4" in prompt
    assert "Task 5" not in prompt


def test_build_prompt_extra_fields():
    """Entries of ``extra`` are rendered as ``key: value`` in the prompt."""
    ctx = PromptContext(tasks=[], hour_of_day=8, extra={"mood": "focused", "energy": "high"})
    prompt = _build_prompt(ctx, n=1)
    assert "mood: focused" in prompt
    assert "energy: high" in prompt


def test_build_prompt_empty_tasks_no_task_line():
    """With no tasks the prompt omits the ``Tasks:`` section entirely."""
    ctx = PromptContext(tasks=[], hour_of_day=10)
    prompt = _build_prompt(ctx, n=2)
    assert "Tasks:" not in prompt
    assert "Generate 2 tips" in prompt


@pytest.mark.anyio
async def test_generate_retry_succeeds_on_second_attempt():
    """First response is invalid JSON; second is valid. Should return 200."""
    valid_items = [{"id": "tip-ok", "content": "Retry worked.", "rationale": "Second try."}]
    bad_resp = _chat_response("this is not json", {})
    good_resp = _litellm_response(valid_items)

    with _patched_litellm(side_effect=[bad_resp, good_resp]) as instance:
        resp = await _post_generate({"user_id": "u1", "n": 1})

    assert resp.status_code == 200
    assert resp.json()["candidates"][0]["id"] == "tip-ok"
    assert instance.post.call_count == 2
    # Retry message should include the correction suffix
    second_call_messages = instance.post.call_args_list[1].kwargs["json"]["messages"]
    assert any("not valid JSON" in m["content"] for m in second_call_messages)


@pytest.mark.anyio
async def test_generate_502_after_all_retries_exhausted():
    """All attempts return invalid JSON → 502."""
    from main import _MAX_GENERATE_RETRIES

    # One initial attempt plus every retry, each with its own response object.
    responses = [
        _chat_response("not json at all", {})
        for _ in range(1 + _MAX_GENERATE_RETRIES)
    ]

    with _patched_litellm(side_effect=responses):
        resp = await _post_generate({"user_id": "u1"})

    assert resp.status_code == 502
    assert "retries" in resp.json()["detail"]


@pytest.mark.anyio
async def test_generate_502_on_upstream_http_error():
    """LiteLLM returns 500 → HTTPStatusError → 502."""
    err_req = httpx.Request("POST", _LITELLM_URL)
    err_resp = Response(500, text="internal error", request=err_req)
    # The mock raises straight from ``post``; the test only pins that an
    # HTTPStatusError reaching the handler is reported as 502.
    upstream_error = httpx.HTTPStatusError("500", request=err_req, response=err_resp)

    with _patched_litellm(side_effect=upstream_error):
        resp = await _post_generate({"user_id": "u1"})

    assert resp.status_code == 502
    assert "LiteLLM error" in resp.json()["detail"]


def test_parse_llm_json_bare_fence():
    """A fence without a language tag is stripped before parsing."""
    from main import _parse_llm_json

    raw = "```\n[{\"id\":\"x\",\"content\":\"hi\"}]\n```"
    items = _parse_llm_json(raw)
    assert items[0]["id"] == "x"


def test_parse_llm_json_no_fence():
    """Plain JSON text with no fence parses unchanged."""
    from main import _parse_llm_json

    raw = '[{"id":"plain","content":"no fence"}]'
    items = _parse_llm_json(raw)
    assert items[0]["id"] == "plain"


def test_parse_llm_json_raises_on_invalid():
    """Non-JSON input makes the parser raise."""
    from main import _parse_llm_json

    # NOTE(review): ``Exception`` already covers ``ValueError``, so the former
    # ``(ValueError, Exception)`` tuple was redundant. Narrow this to the
    # concrete exception type once the parser's contract is confirmed.
    with pytest.raises(Exception):
        _parse_llm_json("this is not json")