feat(agents): per-agent inference — momentum, overdue-task, recent-patterns, focus-area (ADR-0014 step 7)
All four agents bumped to v1.1.0.

momentum (#114): infers engagement_trend ('up'|'stable'|'down') by comparing the done-rate in the last 7 days with the prior 7 days. The agent surfaces the trend in its snippet ("trending up" / "trending down").

overdue-task (#115): infers lateness_tolerance_days (0/1/2) from the snooze rate. The agent now filters tasks against the tolerance so low-urgency users aren't nagged about tasks that are only hours overdue.

recent-patterns (#116): infers window_days (7/14/30) from feedback event density (approximated by total event count); sparse users get a wider window so the snippet isn't always empty.

focus-area (#113): no inferred params (project-level feedback linkage needed, tracked under #78). The preferred_areas pref was declared but ignored; the agent now honours it as a tiebreaker and mentions it in the snippet.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
14
CLAUDE.md
@@ -106,11 +106,10 @@ Recent completions:
|
||||
- Model benchmarking for tip generation (#93, #95)
|
||||
- Admin UX refinements: feedback consolidation, settings placement (#100–102)
|
||||
- ADR-0012 — ε-greedy v2 (D=12) — 2026-04-26 (now superseded by ADR-0013)
|
||||
- ADR-0014 steps 1–6: unified Profile schema + backfill, manifest plumbing, `/api/profile` read-through, registry-driven eligibility filter, inference framework + time-of-day migration — 2026-05-05
|
||||
- ADR-0014 steps 1–7: unified Profile schema + backfill, manifest plumbing, `/api/profile` read-through, registry-driven eligibility filter, inference framework + per-agent inference (#112–#116) — 2026-05-05
|
||||
|
||||
Active work (M2):
|
||||
- ADR-0014 step 7 — per-agent inference: focus-area (#113), momentum (#114), overdue-task (#115), recent-patterns (#116)
|
||||
- ADR-0014 step 8 — drop `users.consentGiven` column
|
||||
- ADR-0014 step 8 — drop `users.consentGiven` column (one release after step 2)
|
||||
- Signal abstraction for multi-source support (#78)
|
||||
- Per-user feature freshness SLAs (#61, ADR-0011 phase B)
|
||||
|
||||
@@ -133,7 +132,14 @@ Lives in `ml/agents/inference/`. `run_inference(manifest, history)` evaluates al
|
||||
- `infer()` error → emit `cold_start_default` (never crashes)
- Results written to `user_preferences` with `source='inferred'`; keys with `source='user'` are never overwritten
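
The fallback rules above can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not the framework's actual code: `ParamSpec`, `infer_all`, and `merge_prefs` are hypothetical names, the history is modelled as a plain list of events, and the stored-preference shape (`{"value": ..., "source": ...}`) is assumed.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class ParamSpec:
    """Hypothetical stand-in for an inferred-param declaration."""
    key: str
    cold_start_default: Any
    min_history: int
    infer: Callable[[list], Any]


def infer_all(specs: list[ParamSpec], events: list) -> dict[str, Any]:
    """Evaluate every spec, falling back to its cold-start default."""
    out: dict[str, Any] = {}
    for spec in specs:
        if len(events) < spec.min_history:
            out[spec.key] = spec.cold_start_default  # not enough history yet
            continue
        try:
            out[spec.key] = spec.infer(events)
        except Exception:
            out[spec.key] = spec.cold_start_default  # infer() failed; never crash
    return out


def merge_prefs(stored: dict[str, dict], inferred: dict[str, Any]) -> dict[str, dict]:
    """Write inferred values without touching keys the user set explicitly."""
    merged = dict(stored)
    for key, value in inferred.items():
        if merged.get(key, {}).get("source") == "user":
            continue  # user-set keys are never overwritten
        merged[key] = {"value": value, "source": "inferred"}
    return merged
```

Keeping the merge step separate from inference means the precedence rule (user over inferred) lives in one place, regardless of how many params an agent declares.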
|
||||
|
||||
Time-of-day agent (`1.1.0`) is the proof agent (#112): infers `preferred_hour` (mode done-hour) and reads `quiet_start`/`quiet_end` from prefs.
|
||||
All five agents are at v1.1.0. Per-agent inferred params:
|
||||
| Agent | Inferred param | Logic |
|-------|---------------|-------|
| `time-of-day` | `preferred_hour` | Mode done-hour from feedback history |
| `momentum` | `engagement_trend` | Done-rate last 7d vs prior 7d |
| `overdue-task` | `lateness_tolerance_days` | Snooze rate → 0/1/2 days |
| `recent-patterns` | `window_days` | Event density → 7/14/30 days |
| `focus-area` | *(none yet)* | Needs project-level feedback linkage (#78) |
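
As a worked example of the thresholds behind the table, the sketch below maps raw counts to each inferred value. The cut-offs (a 0.10 done-rate delta, 0.20/0.40 snooze rates, 7/14 events) mirror the agent code in this commit; the function names and the plain-integer inputs are simplifications introduced here for illustration only.

```python
def engagement_trend(recent_done: int, recent_total: int,
                     older_done: int, older_total: int) -> str:
    """Done-rate in the last 7 days compared with the prior 7 days."""
    if older_total < 3:
        return "stable"  # too little baseline to compare against
    delta = recent_done / max(recent_total, 1) - older_done / max(older_total, 1)
    if delta > 0.10:
        return "up"
    if delta < -0.10:
        return "down"
    return "stable"


def lateness_tolerance_days(snoozes: int, total: int) -> int:
    """Higher snooze rate means more days of slack before nagging."""
    if total == 0:
        return 0
    rate = snoozes / total
    if rate > 0.40:
        return 2
    if rate > 0.20:
        return 1
    return 0


def window_days(n_events: int) -> int:
    """Fewer feedback events means a wider lookback window."""
    if n_events >= 14:
        return 7
    if n_events >= 7:
        return 14
    return 30


print(engagement_trend(6, 8, 3, 6))    # 0.75 vs 0.50 -> "up"
print(lateness_tolerance_days(3, 10))  # snooze rate 0.30 -> 1
print(window_days(5))                  # sparse history -> 30
```

Note that both rate comparisons are strict, so a snooze rate of exactly 0.20 still maps to a tolerance of 0.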
|
||||
|
||||
## What NOT to do
|
||||
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import ClassVar
|
||||
|
||||
from .base import BaseAgent, AgentInput, AgentOutput
|
||||
from .manifest import AgentManifest
|
||||
|
||||
|
||||
MANIFEST = AgentManifest(
|
||||
id="focus-area",
|
||||
version="1.0.0",
|
||||
version="1.1.0", # bumped: preferred_areas pref is now honoured in compute (#113)
|
||||
description="Identifies the most congested project/area in the user's task list.",
|
||||
pref_schema={
|
||||
"type": "object",
|
||||
@@ -25,6 +27,9 @@ MANIFEST = AgentManifest(
|
||||
required_consents=["data:core", "data:todoist", "agent:focus-area"],
|
||||
output_contract={"type": "snippet", "format": "free_text"},
|
||||
ttl_sec=43_200,
|
||||
# No inferred_params: preferred_areas requires project-level feedback linkage
|
||||
# that isn't available in feedback_history alone. Revisit with #78 (signal
|
||||
# abstraction) once per-task reactions can be traced back to a project.
|
||||
)
|
||||
|
||||
|
||||
@@ -35,6 +40,7 @@ class FocusAreaAgent(BaseAgent):
|
||||
version: ClassVar[str] = MANIFEST.version
|
||||
|
||||
def compute(self, inp: AgentInput) -> AgentOutput:
|
||||
preferred: list[str] = inp.agent_prefs.get("preferred_areas", [])
|
||||
by_project: dict[str, list[dict]] = defaultdict(list)
|
||||
for task in inp.tasks:
|
||||
project = task.get("project_id") or task.get("project") or "default"
|
||||
@@ -44,19 +50,27 @@ class FocusAreaAgent(BaseAgent):
|
||||
prompt = "No tasks available to identify a focus area."
|
||||
return self._make_output(inp, prompt, {"project_count": 0})
|
||||
|
||||
# Score each project: overdue tasks count double
|
||||
def score(tasks: list[dict]) -> float:
|
||||
return sum(2.0 if t.get("is_overdue") else 1.0 for t in tasks)
|
||||
def score(project: str, tasks: list[dict]) -> tuple[float, bool]:
|
||||
base = sum(2.0 if t.get("is_overdue") else 1.0 for t in tasks)
|
||||
# Boost preferred areas to break ties in their favour
|
||||
boosted = project in preferred or any(p in project for p in preferred)
|
||||
return (base + (0.5 if boosted else 0.0), boosted)
|
||||
|
||||
top_project, top_tasks = max(by_project.items(), key=lambda kv: score(kv[1]))
|
||||
top_project, top_tasks = max(
|
||||
by_project.items(),
|
||||
key=lambda kv: score(kv[0], kv[1]),
|
||||
)
|
||||
overdue_in_top = sum(1 for t in top_tasks if t.get("is_overdue"))
|
||||
label = "the default project" if top_project == "default" else f'"{top_project}"'
|
||||
n = len(top_tasks)
|
||||
boosted = top_project in preferred or any(p in top_project for p in preferred)
|
||||
|
||||
parts = [
|
||||
f"The user's most congested area is {label} "
|
||||
f"({n} task{'s' if n != 1 else ''}, {overdue_in_top} overdue)."
|
||||
]
|
||||
if boosted:
|
||||
parts.append("This area matches the user's stated focus preferences.")
|
||||
if overdue_in_top >= 3:
|
||||
parts.append("Consider surfacing an action from this area.")
|
||||
|
||||
@@ -66,5 +80,6 @@ class FocusAreaAgent(BaseAgent):
|
||||
"top_task_count": n,
|
||||
"top_overdue_count": overdue_in_top,
|
||||
"project_count": len(by_project),
|
||||
"preferred_areas": preferred,
|
||||
}
|
||||
return self._make_output(inp, prompt, snapshot)
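
To see why the 0.5 boost acts only as a tiebreaker, consider the self-contained sketch below. It reimplements the scoring on plain dicts; `pick_focus` is a hypothetical helper written for this illustration, not part of the agent, and it assumes project ids are strings. Because every task contributes at least 1.0 to a project's score, a 0.5 boost can settle a tie but cannot outrank a project that is a full task ahead.

```python
from collections import defaultdict


def pick_focus(tasks: list[dict], preferred: list[str]) -> str:
    """Return the project with the highest congestion score.

    Overdue tasks count double; preferred areas get a +0.5 boost.
    Raises ValueError if `tasks` is empty.
    """
    by_project: dict[str, list[dict]] = defaultdict(list)
    for task in tasks:
        by_project[task.get("project_id") or "default"].append(task)

    def score(project: str, items: list[dict]) -> tuple[float, bool]:
        base = sum(2.0 if t.get("is_overdue") else 1.0 for t in items)
        # Substring match also covers the exact-match case
        boosted = any(p in project for p in preferred)
        return (base + (0.5 if boosted else 0.0), boosted)

    return max(by_project, key=lambda p: score(p, by_project[p]))


tied = [{"project_id": "work"}, {"project_id": "home"}]
print(pick_focus(tied, preferred=["home"]))  # tie broken in favour of "home"

uneven = tied + [{"project_id": "home"}]
print(pick_focus(uneven, preferred=["work"]))  # "home" (2.0) still beats "work" (1.5)
```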
|
||||
|
||||
@@ -1,12 +1,57 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import ClassVar
|
||||
|
||||
from .base import BaseAgent, AgentInput, AgentOutput
|
||||
from .manifest import AgentManifest
|
||||
from .inference.history import UserHistory
|
||||
from .manifest import AgentManifest, InferredParam
|
||||
|
||||
|
||||
def _infer_engagement_trend(history: UserHistory) -> str:
    """Compare done-rate in the most recent 7 days vs the 7 days before that."""
    events = sorted(history.events, key=lambda e: e.created_at)
    if not events:
        return "stable"

    try:
        latest = datetime.fromisoformat(events[-1].created_at.replace("Z", "+00:00"))
    except ValueError:
        return "stable"
    if latest.tzinfo is None:
        # Treat naive timestamps as UTC, matching _parse_dt(), so the
        # comparisons below never mix naive and aware datetimes.
        latest = latest.replace(tzinfo=timezone.utc)

    cutoff_recent = latest - timedelta(days=7)
    cutoff_older = latest - timedelta(days=14)

    recent = [e for e in events if _parse_dt(e.created_at) >= cutoff_recent]
    older = [e for e in events if cutoff_older <= _parse_dt(e.created_at) < cutoff_recent]

    if len(older) < 3:
        return "stable"  # not enough baseline to compare

    recent_rate = sum(1 for e in recent if e.action == "done") / max(len(recent), 1)
    older_rate = sum(1 for e in older if e.action == "done") / max(len(older), 1)

    delta = recent_rate - older_rate
    if delta > 0.10:
        return "up"
    if delta < -0.10:
        return "down"
    return "stable"


def _parse_dt(iso: str) -> datetime:
    try:
        dt = datetime.fromisoformat(iso.replace("Z", "+00:00"))
        if dt.tzinfo is None:
            dt = dt.replace(tzinfo=timezone.utc)
        return dt
    except ValueError:
        return datetime.min.replace(tzinfo=timezone.utc)
|
||||
|
||||
|
||||
MANIFEST = AgentManifest(
|
||||
id="momentum",
|
||||
version="1.0.0",
|
||||
version="1.1.0", # bumped: engagement_trend InferredParam added (#114)
|
||||
description="Characterises the user's recent engagement trend from profile features.",
|
||||
pref_schema={
|
||||
"type": "object",
|
||||
@@ -25,6 +70,15 @@ MANIFEST = AgentManifest(
|
||||
required_consents=["data:core", "agent:momentum"],
|
||||
output_contract={"type": "snippet", "format": "free_text"},
|
||||
ttl_sec=21_600,
|
||||
inferred_params=[
|
||||
InferredParam(
|
||||
key="engagement_trend",
|
||||
ttl_sec=21_600, # recompute every 6 hours alongside snippet
|
||||
cold_start_default="stable",
|
||||
min_history=10,
|
||||
infer=_infer_engagement_trend,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@@ -38,6 +92,7 @@ class MomentumAgent(BaseAgent):
|
||||
completion = inp.profile.get("completion_rate_30d")
|
||||
dismiss = inp.profile.get("dismiss_rate_30d")
|
||||
volume = inp.profile.get("tip_volume_30d")
|
||||
trend: str = inp.agent_prefs.get("engagement_trend", "stable")
|
||||
|
||||
parts: list[str] = []
|
||||
|
||||
@@ -65,10 +120,16 @@ class MomentumAgent(BaseAgent):
|
||||
if volume is not None and int(volume) < 5:
|
||||
parts.append("Very few tips served so far — this is an early-stage user.")
|
||||
|
||||
if trend == "up":
|
||||
parts.append("Engagement is trending up compared to last week — build on the momentum.")
|
||||
elif trend == "down":
|
||||
parts.append("Engagement is trending down — a motivational or easy-win tip may help.")
|
||||
|
||||
prompt = " ".join(parts) if parts else "No engagement data available yet."
|
||||
snapshot = {
|
||||
"completion_rate_30d": completion,
|
||||
"dismiss_rate_30d": dismiss,
|
||||
"tip_volume_30d": volume,
|
||||
"engagement_trend": trend,
|
||||
}
|
||||
return self._make_output(inp, prompt, snapshot)
|
||||
|
||||
@@ -1,12 +1,32 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
from .base import BaseAgent, AgentInput, AgentOutput
|
||||
from .manifest import AgentManifest
|
||||
from .inference.history import UserHistory
|
||||
from .manifest import AgentManifest, InferredParam
|
||||
|
||||
|
||||
def _infer_lateness_tolerance(history: UserHistory) -> int:
    """Estimate how many days past due a task needs to be before the user acts.

    High snooze rate → user doesn't act immediately → raise tolerance so the
    agent doesn't nag them about tasks they'll handle in their own time.
    """
    total = len(history.events)
    if total == 0:
        return 0
    snooze_rate = sum(1 for e in history.events if e.action == "snooze") / total
    if snooze_rate > 0.40:
        return 2
    if snooze_rate > 0.20:
        return 1
    return 0
|
||||
|
||||
|
||||
MANIFEST = AgentManifest(
|
||||
id="overdue-task",
|
||||
version="1.0.0",
|
||||
version="1.1.0", # bumped: lateness_tolerance_days InferredParam added (#115)
|
||||
description="Reports the user's overdue tasks by count and age.",
|
||||
pref_schema={
|
||||
"type": "object",
|
||||
@@ -25,6 +45,15 @@ MANIFEST = AgentManifest(
|
||||
output_contract={"type": "snippet", "format": "free_text"},
|
||||
ttl_sec=3600,
|
||||
silenced_in_contexts=["vacation"],
|
||||
inferred_params=[
|
||||
InferredParam(
|
||||
key="lateness_tolerance_days",
|
||||
ttl_sec=86_400, # recompute daily — snooze pattern shifts slowly
|
||||
cold_start_default=0,
|
||||
min_history=10,
|
||||
infer=_infer_lateness_tolerance,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@@ -35,7 +64,11 @@ class OverdueTaskAgent(BaseAgent):
|
||||
version: ClassVar[str] = MANIFEST.version
|
||||
|
||||
def compute(self, inp: AgentInput) -> AgentOutput:
|
||||
overdue = [t for t in inp.tasks if t.get("is_overdue")]
|
||||
tolerance = max(0, int(inp.agent_prefs.get("lateness_tolerance_days", 0)))
|
||||
overdue = [
|
||||
t for t in inp.tasks
|
||||
if t.get("is_overdue") and t.get("task_age_days", 0) >= tolerance
|
||||
]
|
||||
top = sorted(overdue, key=lambda t: -t.get("task_age_days", 0))[:3]
|
||||
|
||||
if not overdue:
|
||||
@@ -59,6 +92,7 @@ class OverdueTaskAgent(BaseAgent):
|
||||
|
||||
snapshot = {
|
||||
"overdue_count": len(overdue),
|
||||
"lateness_tolerance_days": tolerance,
|
||||
"top_overdue": [
|
||||
{"content": t["content"], "task_age_days": t.get("task_age_days", 0)}
|
||||
for t in top
|
||||
|
||||
@@ -1,17 +1,32 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import Counter
|
||||
from datetime import datetime, timezone
|
||||
from typing import ClassVar
|
||||
from .base import BaseAgent, AgentInput, AgentOutput
|
||||
from .manifest import AgentManifest
|
||||
|
||||
_SEVEN_DAYS_S = 7 * 86_400
|
||||
from .base import BaseAgent, AgentInput, AgentOutput
|
||||
from .inference.history import UserHistory
|
||||
from .manifest import AgentManifest, InferredParam
|
||||
|
||||
|
||||
def _infer_window_days(history: UserHistory) -> int:
    """Infer the optimal lookback window from feedback event density.

    More events per day → a shorter window captures the user's current state
    accurately. Sparse feedback → widen the window to gather signal.
    """
    n = len(history.events)
    if n >= 14:
        return 7
    if n >= 7:
        return 14
    return 30
|
||||
|
||||
|
||||
MANIFEST = AgentManifest(
|
||||
id="recent-patterns",
|
||||
version="1.0.0",
|
||||
description="Surfaces the user's reaction pattern from the last 7 days of feedback.",
|
||||
version="1.1.0", # bumped: window_days InferredParam added (#116)
|
||||
description="Surfaces the user's reaction pattern from recent feedback.",
|
||||
pref_schema={
|
||||
"type": "object",
|
||||
"additionalProperties": False,
|
||||
@@ -29,20 +44,32 @@ MANIFEST = AgentManifest(
|
||||
required_consents=["data:core", "agent:recent-patterns"],
|
||||
output_contract={"type": "snippet", "format": "free_text"},
|
||||
ttl_sec=86_400,
|
||||
inferred_params=[
|
||||
InferredParam(
|
||||
key="window_days",
|
||||
ttl_sec=86_400, # recompute daily alongside snippet
|
||||
cold_start_default=7,
|
||||
min_history=5,
|
||||
infer=_infer_window_days,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class RecentPatternsAgent(BaseAgent):
|
||||
"""Surfaces the user's reaction pattern from the last 7 days of feedback."""
|
||||
"""Surfaces the user's reaction pattern from recent feedback."""
|
||||
agent_id: ClassVar[str] = MANIFEST.id
|
||||
ttl_seconds: ClassVar[int] = MANIFEST.ttl_sec
|
||||
version: ClassVar[str] = MANIFEST.version
|
||||
|
||||
def compute(self, inp: AgentInput) -> AgentOutput:
|
||||
window_days = max(1, int(inp.agent_prefs.get("window_days", 7)))
|
||||
window_s = window_days * 86_400
|
||||
now_ts = inp.now.timestamp()
|
||||
|
||||
recent = [
|
||||
f for f in inp.feedback_history
|
||||
if self._age_s(f.get("created_at", ""), now_ts) <= _SEVEN_DAYS_S
|
||||
if self._age_s(f.get("created_at", ""), now_ts) <= window_s
|
||||
]
|
||||
|
||||
counts: Counter[str] = Counter(f.get("action") for f in recent)
|
||||
@@ -50,13 +77,13 @@ class RecentPatternsAgent(BaseAgent):
|
||||
dwell_ms = inp.profile.get("mean_dwell_ms_30d")
|
||||
|
||||
if total == 0:
|
||||
prompt = "No tip reactions recorded in the last 7 days."
|
||||
prompt = f"No tip reactions recorded in the last {window_days} days."
|
||||
else:
|
||||
done = counts.get("done", 0)
|
||||
dismissed = counts.get("dismiss", 0)
|
||||
snoozed = counts.get("snooze", 0)
|
||||
parts = [
|
||||
f"Last 7 days: {total} tip reaction{'s' if total != 1 else ''} — "
|
||||
f"Last {window_days} days: {total} tip reaction{'s' if total != 1 else ''} — "
|
||||
f"{done} completed, {dismissed} dismissed, {snoozed} snoozed."
|
||||
]
|
||||
if dwell_ms is not None:
|
||||
@@ -74,6 +101,7 @@ class RecentPatternsAgent(BaseAgent):
|
||||
prompt = " ".join(parts)
|
||||
|
||||
snapshot = {
|
||||
"window_days": window_days,
|
||||
"recent_total": total,
|
||||
"action_counts": dict(counts),
|
||||
"mean_dwell_ms_30d": dwell_ms,
|
||||
|
||||
@@ -243,7 +243,7 @@ class TestFocusAreaAgent:
|
||||
|
||||
def test_snapshot_keys(self):
|
||||
out = self.agent.compute(_inp(tasks=[_task("T1", project_id="A")]))
|
||||
assert {"top_project", "top_task_count", "top_overdue_count", "project_count"} == set(out.signals_snapshot)
|
||||
assert {"top_project", "top_task_count", "top_overdue_count", "project_count", "preferred_areas"} == set(out.signals_snapshot)
|
||||
|
||||
|
||||
# ── Registry ─────────────────────────────────────────────────────────────────
|
||||
|
||||
213
ml/agents/tests/test_per_agent_inference.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""Per-agent inference tests: momentum (#114), overdue-task (#115), recent-patterns (#116),
and focus-area (#113) preferred_areas wiring."""
from __future__ import annotations

import sys, os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", ".."))

from datetime import datetime, timezone
import pytest

from ml.agents.inference.history import FeedbackEvent, UserHistory
from ml.agents.inference.framework import run_inference
from ml.agents.momentum import MomentumAgent, MANIFEST as MOMENTUM_MANIFEST
from ml.agents.overdue_task import OverdueTaskAgent, MANIFEST as OVERDUE_MANIFEST
from ml.agents.recent_patterns import RecentPatternsAgent, MANIFEST as RECENT_MANIFEST
from ml.agents.focus_area import FocusAreaAgent
from ml.agents.base import AgentInput

_NOW = datetime(2026, 5, 8, 14, 0, 0, tzinfo=timezone.utc)


def _inp(**kwargs) -> AgentInput:
    defaults = dict(user_id="u1", tasks=[], profile={}, now=_NOW, agent_prefs={})
    defaults.update(kwargs)
    return AgentInput(**defaults)


def _event(action: str, days_ago: float = 1.0) -> FeedbackEvent:
    from datetime import timedelta
    ts = (_NOW - timedelta(days=days_ago)).isoformat()
    dwell = 60_000 if action == "done" else 500
    return FeedbackEvent(action=action, dwell_ms=dwell, created_at=ts)


def _history(*events: FeedbackEvent) -> UserHistory:
    return UserHistory(user_id="u1", events=list(events))
|
||||
|
||||
|
||||
# ── momentum: engagement_trend ───────────────────────────────────────────────

class TestMomentumInference:
    def test_cold_start_below_min_history(self):
        history = _history(*[_event("done", days_ago=i) for i in range(5)])
        result = run_inference(MOMENTUM_MANIFEST, history)
        assert result["engagement_trend"] == "stable"  # cold_start_default

    def test_trend_up_when_recent_done_rate_higher(self):
        # "done" events dominate the recent window, dismissals dominate the prior one → trending up
        recent = [_event("done", days_ago=i) for i in range(1, 9)]
        older = [_event("dismiss", days_ago=i) for i in range(8, 15)]
        older[0] = _event("done", days_ago=8)  # one done at the window boundary
        history = _history(*recent, *older)
        result = run_inference(MOMENTUM_MANIFEST, history)
        assert result["engagement_trend"] == "up"

    def test_trend_down_when_recent_done_rate_lower(self):
        recent = [_event("dismiss", days_ago=i) for i in range(1, 8)]
        older = [_event("done", days_ago=i) for i in range(8, 15)]
        history = _history(*recent, *older)
        result = run_inference(MOMENTUM_MANIFEST, history)
        assert result["engagement_trend"] == "down"

    def test_trend_stable_when_similar(self):
        events = [_event("done" if i % 2 == 0 else "dismiss", days_ago=i) for i in range(1, 15)]
        history = _history(*events)
        result = run_inference(MOMENTUM_MANIFEST, history)
        assert result["engagement_trend"] == "stable"

    def test_agent_uses_trend_in_snippet(self):
        out = MomentumAgent().compute(_inp(agent_prefs={"engagement_trend": "up"}))
        assert "trending up" in out.prompt_text

    def test_agent_uses_down_trend_in_snippet(self):
        out = MomentumAgent().compute(_inp(agent_prefs={"engagement_trend": "down"}))
        assert "trending down" in out.prompt_text

    def test_snapshot_includes_trend(self):
        out = MomentumAgent().compute(_inp(agent_prefs={"engagement_trend": "stable"}))
        assert "engagement_trend" in out.signals_snapshot

    def test_version_bumped(self):
        assert MOMENTUM_MANIFEST.version == "1.1.0"
|
||||
|
||||
|
||||
# ── overdue-task: lateness_tolerance_days ────────────────────────────────────

class TestOverdueTaskInference:
    def test_cold_start_returns_zero(self):
        history = _history(*[_event("done") for _ in range(5)])
        result = run_inference(OVERDUE_MANIFEST, history)
        assert result["lateness_tolerance_days"] == 0

    def test_high_snooze_rate_returns_two(self):
        events = [_event("snooze")] * 8 + [_event("done")] * 2
        history = _history(*events)
        result = run_inference(OVERDUE_MANIFEST, history)
        assert result["lateness_tolerance_days"] == 2

    def test_moderate_snooze_returns_one(self):
        events = [_event("snooze")] * 3 + [_event("done")] * 7
        history = _history(*events)
        result = run_inference(OVERDUE_MANIFEST, history)
        assert result["lateness_tolerance_days"] == 1

    def test_low_snooze_returns_zero(self):
        events = [_event("done")] * 9 + [_event("snooze")] * 1
        history = _history(*events)
        result = run_inference(OVERDUE_MANIFEST, history)
        assert result["lateness_tolerance_days"] == 0

    def test_tolerance_filters_tasks(self):
        tasks = [
            {"content": "Fresh overdue", "is_overdue": True, "task_age_days": 0.5},
            {"content": "Old overdue", "is_overdue": True, "task_age_days": 3.0},
        ]
        # tolerance=2 → only the 3-day task should count
        out = OverdueTaskAgent().compute(_inp(tasks=tasks, agent_prefs={"lateness_tolerance_days": 2}))
        assert "1 overdue task" in out.prompt_text
        assert "Old overdue" in out.prompt_text

    def test_snapshot_includes_tolerance(self):
        tasks = [{"content": "T", "is_overdue": True, "task_age_days": 1.0}]
        out = OverdueTaskAgent().compute(_inp(tasks=tasks, agent_prefs={"lateness_tolerance_days": 0}))
        assert "lateness_tolerance_days" in out.signals_snapshot

    def test_version_bumped(self):
        assert OVERDUE_MANIFEST.version == "1.1.0"
|
||||
|
||||
|
||||
# ── recent-patterns: window_days ─────────────────────────────────────────────

class TestRecentPatternsInference:
    def test_cold_start_default_7(self):
        history = _history(*[_event("done") for _ in range(3)])  # below min_history=5
        result = run_inference(RECENT_MANIFEST, history)
        assert result["window_days"] == 7  # cold_start_default

    def test_sparse_history_widens_window(self):
        history = _history(*[_event("done") for _ in range(5)])  # 5 events, n < 7 → 30 days
        result = run_inference(RECENT_MANIFEST, history)
        assert result["window_days"] == 30

    def test_moderate_history_14_days(self):
        history = _history(*[_event("done") for _ in range(10)])  # 7 ≤ n < 14 → 14 days
        result = run_inference(RECENT_MANIFEST, history)
        assert result["window_days"] == 14

    def test_dense_history_stays_7(self):
        history = _history(*[_event("done") for _ in range(20)])  # n ≥ 14 → 7 days
        result = run_inference(RECENT_MANIFEST, history)
        assert result["window_days"] == 7

    def test_agent_uses_window_days_pref(self):
        from datetime import timedelta
        # 5 feedback events, all within 14 days but older than 7 days
        feedback = [
            {"action": "done", "dwell_ms": 60000,
             "created_at": (_NOW - timedelta(days=10)).isoformat()}
        ] * 5
        # With window_days=7 → 0 events seen; with window_days=14 → 5 events
        out_narrow = RecentPatternsAgent().compute(
            _inp(feedback_history=feedback, agent_prefs={"window_days": 7})
        )
        out_wide = RecentPatternsAgent().compute(
            _inp(feedback_history=feedback, agent_prefs={"window_days": 14})
        )
        assert "No tip reactions" in out_narrow.prompt_text
        assert "5 tip reactions" in out_wide.prompt_text

    def test_snapshot_includes_window_days(self):
        out = RecentPatternsAgent().compute(_inp(agent_prefs={"window_days": 14}))
        assert out.signals_snapshot["window_days"] == 14

    def test_version_bumped(self):
        assert RECENT_MANIFEST.version == "1.1.0"
|
||||
|
||||
|
||||
# ── focus-area: preferred_areas wiring ───────────────────────────────────────

class TestFocusAreaPreferredAreas:
    agent = FocusAreaAgent()

    def _task(self, content: str, project_id: str, is_overdue: bool = False) -> dict:
        return {"id": "t1", "content": content, "is_overdue": is_overdue,
                "task_age_days": 2.0, "priority": 1, "project_id": project_id}

    def test_preferred_area_wins_tie(self):
        tasks = [
            self._task("Work thing", "work"),
            self._task("Home thing", "home"),
        ]
        out = self.agent.compute(_inp(tasks=tasks, agent_prefs={"preferred_areas": ["work"]}))
        assert "work" in out.prompt_text
        assert "matches the user's stated focus preferences" in out.prompt_text

    def test_no_preferred_areas_uses_congestion_score(self):
        tasks = [
            self._task("W1", "work"),
            self._task("H1", "home"),
            self._task("H2", "home"),
        ]
        out = self.agent.compute(_inp(tasks=tasks))
        # home has more tasks → wins without any preference
        assert "home" in out.prompt_text

    def test_snapshot_includes_preferred_areas(self):
        tasks = [self._task("T", "work")]
        out = self.agent.compute(_inp(tasks=tasks, agent_prefs={"preferred_areas": ["work"]}))
        assert out.signals_snapshot["preferred_areas"] == ["work"]

    def test_version_bumped(self):
        from ml.agents.focus_area import MANIFEST as FA_MANIFEST
        assert FA_MANIFEST.version == "1.1.0"
|
||||
@@ -36,10 +36,10 @@ async def test_infer_time_of_day_enough_history():
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_infer_agent_with_no_inferred_params():
|
||||
"""Agents with no inferred_params return an empty dict."""
|
||||
"""Agents with no inferred_params return an empty dict (focus-area has none)."""
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
resp = await client.post("/agents/overdue-task/infer", json={"user_id": "u1", "feedback_history": []})
|
||||
resp = await client.post("/agents/focus-area/infer", json={"user_id": "u1", "feedback_history": []})
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["inferred_prefs"] == {}
|
||||
|
||||
|
||||