From 4267e6ac68590eaecd01e4f547ef294186886b11 Mon Sep 17 00:00:00 2001 From: alvis Date: Mon, 27 Apr 2026 13:46:16 +0000 Subject: [PATCH] feat(ml/serving): inject profile features + sort tasks in tip prompt (#79) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - prompts.py: sort tasks overdue-first → priority desc → age desc before rendering into the LLM prompt (same ordering as ml/features/context.py) - prompts.py: render User profile summary line (completion_rate, dismiss_rate, preferred_hour) when profile_features are present - main.py: add profile_features field to PromptContext; plumb from GenerateRequest into the prompt builder via model_copy - logging_config.py: drop add_logger_name processor (incompatible with PrintLoggerFactory — caused test ordering failures) - test_generate.py: 6 new tests covering sort order, profile rendering, partial fields, empty profile, and end-to-end plumbing through /generate Co-Authored-By: Claude Sonnet 4.6 --- ml/serving/logging_config.py | 1 - ml/serving/main.py | 4 +- ml/serving/prompts.py | 25 +++++++++-- ml/serving/tests/test_generate.py | 72 +++++++++++++++++++++++++++++++ 4 files changed, 96 insertions(+), 6 deletions(-) diff --git a/ml/serving/logging_config.py b/ml/serving/logging_config.py index b40ae40..52f338a 100644 --- a/ml/serving/logging_config.py +++ b/ml/serving/logging_config.py @@ -8,7 +8,6 @@ def configure() -> None: processors=[ structlog.contextvars.merge_contextvars, structlog.stdlib.add_log_level, - structlog.stdlib.add_logger_name, structlog.processors.TimeStamper(fmt="iso"), structlog.processors.StackInfoRenderer(), structlog.processors.JSONRenderer(), diff --git a/ml/serving/main.py b/ml/serving/main.py index 111eb2b..1a8bb26 100644 --- a/ml/serving/main.py +++ b/ml/serving/main.py @@ -322,6 +322,7 @@ class PromptContext(BaseModel): hour_of_day: int = 12 day_of_week: int = 0 extra: dict = {} + profile_features: Optional[dict] = None class GenerateRequest(BaseModel): @@ -392,7 +393,8 @@ async def generate(req: GenerateRequest) -> GenerateResponse: prompt_template = get_prompt(req.prompt_version) except KeyError as e: raise HTTPException(status_code=422, detail=f"Unknown prompt_version: {e.args[0]}") - user_msg = prompt_template.build_user(req.context, req.n) + ctx = req.context.model_copy(update={"profile_features": req.profile_features}) + user_msg = prompt_template.build_user(ctx, req.n) messages: list[dict] = [ {"role": "system", "content": prompt_template.system}, {"role": "user", "content": user_msg}, diff --git a/ml/serving/prompts.py b/ml/serving/prompts.py index 2eec45d..6a04a98 100644 --- a/ml/serving/prompts.py +++ b/ml/serving/prompts.py @@ -23,6 +23,7 @@ class _Ctx(Protocol): hour_of_day: int day_of_week: int extra: dict + profile_features: "dict | None" @dataclass(frozen=True) @@ -33,13 +34,29 @@ class Prompt: def _base_user_lines(ctx: "_Ctx") -> list[str]: + # Overdue tasks first, then high-priority, then oldest — most actionable context at top + tasks = sorted( + ctx.tasks, + key=lambda t: (not t.get("is_overdue", False), -t.get("priority", 1), -t.get("task_age_days", 0.0)), + ) lines = [f"Time: {ctx.hour_of_day:02d}:00, day_of_week={ctx.day_of_week}"] - if ctx.tasks: - overdue = [t for t in ctx.tasks if t.get("is_overdue")] - lines.append(f"Tasks: {len(ctx.tasks)} total, {len(overdue)} overdue") - for t in ctx.tasks[:5]: + if tasks: + overdue = [t for t in tasks if t.get("is_overdue")] + lines.append(f"Tasks: {len(tasks)} total, {len(overdue)} overdue") + for t in tasks[:5]: due = t.get("due_date", "no due date") lines.append(f" - [{t.get('priority','?')}] {t.get('content','?')} (due: {due})") + p = getattr(ctx, "profile_features", None) or {} + if p: + parts: list[str] = [] + if (v := p.get("completion_rate_30d")) is not None: + parts.append(f"completion_rate={float(v):.0%}") + if (v := p.get("dismiss_rate_30d")) is not None: + parts.append(f"dismiss_rate={float(v):.0%}") + if (v := p.get("preferred_hour")) is not None: + parts.append(f"preferred_hour={int(v):02d}:00") + if parts: + lines.append(f"User profile: {', '.join(parts)}") for k, v in ctx.extra.items(): lines.append(f"{k}: {v}") return lines diff --git a/ml/serving/tests/test_generate.py b/ml/serving/tests/test_generate.py index 180875c..b25f079 100644 --- a/ml/serving/tests/test_generate.py +++ b/ml/serving/tests/test_generate.py @@ -127,6 +127,46 @@ def test_build_prompt_empty_tasks_no_task_line(): assert "Generate 2 tips" in prompt +def test_build_prompt_tasks_sorted_overdue_first(): + tasks = [ + {"content": "Low priority", "priority": 1, "is_overdue": False, "task_age_days": 0}, + {"content": "Overdue task", "priority": 2, "is_overdue": True, "task_age_days": 3}, + ] + ctx = PromptContext(tasks=tasks, hour_of_day=9) + prompt = _build_user_v1(ctx, n=2) + assert prompt.index("Overdue task") < prompt.index("Low priority") + + +def test_build_prompt_includes_profile_features(): + ctx = PromptContext( + tasks=[], + hour_of_day=14, + profile_features={"completion_rate_30d": 0.75, "dismiss_rate_30d": 0.1, "preferred_hour": 9}, + ) + prompt = _build_user_v1(ctx, n=1) + assert "User profile:" in prompt + assert "completion_rate=75%" in prompt + assert "dismiss_rate=10%" in prompt + assert "preferred_hour=09:00" in prompt + + +def test_build_prompt_no_profile_line_when_empty(): + ctx = PromptContext(tasks=[], hour_of_day=10, profile_features={}) + prompt = _build_user_v1(ctx, n=1) + assert "User profile:" not in prompt + + +def test_build_prompt_profile_partial_fields(): + ctx = PromptContext( + tasks=[], + hour_of_day=10, + profile_features={"completion_rate_30d": 0.5}, + ) + prompt = _build_user_v1(ctx, n=1) + assert "completion_rate=50%" in prompt + assert "dismiss_rate" not in prompt + + @pytest.mark.anyio async def test_generate_retry_succeeds_on_second_attempt(): """First response is invalid JSON; second is valid. Should return 200.""" @@ -271,6 +311,38 @@ async def test_generate_echoes_selected_prompt_version(): assert resp.json()["prompt_version"] == "v2-mentor" +@pytest.mark.anyio +async def test_generate_passes_profile_features_to_prompt(): + """profile_features from GenerateRequest should appear in the user message sent to LiteLLM.""" + fake_items = [{"id": "tip-1", "content": "x", "rationale": "y"}] + mock_resp = _litellm_response(fake_items) + captured_payload: list[dict] = [] + + async def _capture(url, *, json, headers): + captured_payload.append(json) + return mock_resp + + with patch("main.httpx.AsyncClient") as MockClient: + instance = AsyncMock() + instance.post = AsyncMock(side_effect=_capture) + instance.__aenter__ = AsyncMock(return_value=instance) + instance.__aexit__ = AsyncMock(return_value=False) + MockClient.return_value = instance + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + resp = await client.post("/generate", json={ + "user_id": "u1", + "n": 1, + "profile_features": {"completion_rate_30d": 0.8, "preferred_hour": 10}, + }) + + assert resp.status_code == 200 + user_msg = captured_payload[0]["messages"][1]["content"] + assert "User profile:" in user_msg + assert "completion_rate=80%" in user_msg + assert "preferred_hour=10:00" in user_msg + + @pytest.mark.anyio async def test_generate_422_on_unknown_prompt_version(): async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: