"""Unit tests for all sub-agents and the registry.""" from __future__ import annotations import sys, os sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..")) from datetime import datetime, timezone import pytest from ml.agents.base import AgentInput, AgentOutput from ml.agents.overdue_task import OverdueTaskAgent from ml.agents.momentum import MomentumAgent from ml.agents.time_of_day import TimeOfDayAgent from ml.agents.recent_patterns import RecentPatternsAgent from ml.agents.focus_area import FocusAreaAgent from ml.agents.registry import get_agent, all_agents _NOW = datetime(2026, 5, 1, 9, 0, 0, tzinfo=timezone.utc) # Thursday 09:00 UTC def _inp(**kwargs) -> AgentInput: defaults = dict( user_id="u1", tasks=[], profile={}, feedback_history=[], now=_NOW, ) defaults.update(kwargs) return AgentInput(**defaults) def _task(content="Do thing", is_overdue=False, task_age_days=0.0, priority=1, project_id=None): t = {"id": "t1", "content": content, "is_overdue": is_overdue, "task_age_days": task_age_days, "priority": priority} if project_id: t["project_id"] = project_id return t # ── helpers ────────────────────────────────────────────────────────────────── def _check_output(out: AgentOutput, agent) -> None: assert isinstance(out, AgentOutput) assert out.user_id == "u1" assert out.agent_id == agent.agent_id assert out.prompt_text assert out.computed_at assert out.expires_at > out.computed_at assert out.agent_version == agent.version # ── OverdueTaskAgent ────────────────────────────────────────────────────────── class TestOverdueTaskAgent: agent = OverdueTaskAgent() def test_no_overdue(self): out = self.agent.compute(_inp(tasks=[_task("Read book")])) _check_output(out, self.agent) assert "no overdue" in out.prompt_text.lower() assert out.signals_snapshot["overdue_count"] == 0 def test_single_overdue(self): out = self.agent.compute(_inp(tasks=[_task("Call dentist", is_overdue=True, task_age_days=3)])) _check_output(out, self.agent) assert "1 overdue" in out.prompt_text assert "Call dentist" in out.prompt_text assert "3 day" in out.prompt_text def test_multiple_overdue_top3(self): tasks = [ _task(f"Task {i}", is_overdue=True, task_age_days=float(i)) for i in range(1, 6) ] out = self.agent.compute(_inp(tasks=tasks)) _check_output(out, self.agent) assert "5 overdue" in out.prompt_text assert out.signals_snapshot["overdue_count"] == 5 assert len(out.signals_snapshot["top_overdue"]) == 3 # Top 3 should be highest age: 5, 4, 3 ages = [t["task_age_days"] for t in out.signals_snapshot["top_overdue"]] assert ages == sorted(ages, reverse=True) def test_ttl_respected(self): out = self.agent.compute(_inp()) assert out.expires_at > out.computed_at # ── MomentumAgent ───────────────────────────────────────────────────────────── class TestMomentumAgent: agent = MomentumAgent() def test_no_profile(self): out = self.agent.compute(_inp(profile={})) _check_output(out, self.agent) assert "new user" in out.prompt_text.lower() or "no " in out.prompt_text.lower() def test_strong_engagement(self): out = self.agent.compute(_inp(profile={"completion_rate_30d": 0.65, "dismiss_rate_30d": 0.05})) assert "strong engagement" in out.prompt_text def test_low_completion_warns(self): out = self.agent.compute(_inp(profile={"completion_rate_30d": 0.1})) assert "low engagement" in out.prompt_text assert "actionable" in out.prompt_text def test_high_dismiss_warns(self): out = self.agent.compute(_inp(profile={"completion_rate_30d": 0.3, "dismiss_rate_30d": 0.5})) assert "dismiss rate is high" in out.prompt_text.lower() def test_early_stage_user(self): out = self.agent.compute(_inp(profile={"tip_volume_30d": 2.0})) assert "early-stage" in out.prompt_text # ── TimeOfDayAgent ──────────────────────────────────────────────────────────── class TestTimeOfDayAgent: agent = TimeOfDayAgent() def test_morning_label(self): inp = _inp(now=datetime(2026, 5, 1, 8, 0, tzinfo=timezone.utc)) # Friday out = self.agent.compute(inp) assert "morning" in out.prompt_text assert "08:00" in out.prompt_text def test_weekend_note(self): inp = _inp(now=datetime(2026, 5, 2, 10, 0, tzinfo=timezone.utc)) # Saturday out = self.agent.compute(inp) assert "weekend" in out.prompt_text.lower() def test_peak_hour_exact(self): inp = _inp( now=datetime(2026, 5, 1, 10, 0, tzinfo=timezone.utc), profile={"preferred_hour": 10.0}, ) out = self.agent.compute(inp) assert "peak productivity hour" in out.prompt_text def test_approaching_peak(self): inp = _inp( now=datetime(2026, 5, 1, 9, 0, tzinfo=timezone.utc), profile={"preferred_hour": 10.0}, ) out = self.agent.compute(inp) assert "approaching" in out.prompt_text.lower() def test_no_preferred_hour(self): out = self.agent.compute(_inp()) assert "no preferred-hour" in out.prompt_text.lower() def test_snapshot_keys(self): out = self.agent.compute(_inp()) assert {"hour", "day_of_week", "preferred_hour", "quiet_start", "quiet_end", "peak_hours", "in_quiet", "in_peak", "tz"} == set(out.signals_snapshot) # ── RecentPatternsAgent ─────────────────────────────────────────────────────── class TestRecentPatternsAgent: agent = RecentPatternsAgent() def test_no_feedback(self): out = self.agent.compute(_inp()) assert "no tip reactions" in out.prompt_text.lower() def test_recent_feedback_summary(self): now_iso = _NOW.isoformat() feedback = [ {"action": "done", "dwell_ms": 30000, "created_at": now_iso}, {"action": "done", "dwell_ms": 45000, "created_at": now_iso}, {"action": "dismiss", "dwell_ms": 2000, "created_at": now_iso}, ] out = self.agent.compute(_inp(feedback_history=feedback)) assert "3 tip reactions" in out.prompt_text assert "2 completed" in out.prompt_text assert "1 dismissed" in out.prompt_text def test_old_feedback_excluded(self): # 10 days ago — should be excluded from 7-day window old_iso = "2026-04-21T09:00:00+00:00" feedback = [{"action": "done", "dwell_ms": 5000, "created_at": old_iso}] out = self.agent.compute(_inp(feedback_history=feedback)) assert "no tip reactions" in out.prompt_text.lower() def test_short_dwell_note(self): now_iso = _NOW.isoformat() feedback = [{"action": "done", "dwell_ms": 5000, "created_at": now_iso}] out = self.agent.compute(_inp( feedback_history=feedback, profile={"mean_dwell_ms_30d": 5000.0}, )) assert "auto-pilot" in out.prompt_text.lower() or "short" in out.prompt_text.lower() def test_long_dwell_note(self): now_iso = _NOW.isoformat() feedback = [{"action": "done", "dwell_ms": 90000, "created_at": now_iso}] out = self.agent.compute(_inp( feedback_history=feedback, profile={"mean_dwell_ms_30d": 90000.0}, )) assert "deliberate" in out.prompt_text.lower() or "reflection" in out.prompt_text.lower() # ── FocusAreaAgent ──────────────────────────────────────────────────────────── class TestFocusAreaAgent: agent = FocusAreaAgent() def test_no_tasks(self): out = self.agent.compute(_inp()) assert "no tasks" in out.prompt_text.lower() def test_single_project(self): tasks = [_task(f"T{i}", project_id="Work") for i in range(3)] out = self.agent.compute(_inp(tasks=tasks)) assert '"Work"' in out.prompt_text assert "3 tasks" in out.prompt_text def test_most_congested_wins(self): tasks = ( [_task(f"W{i}", project_id="Work") for i in range(5)] + [_task(f"H{i}", project_id="Home") for i in range(2)] ) out = self.agent.compute(_inp(tasks=tasks)) assert '"Work"' in out.prompt_text def test_overdue_weighting(self): # Home has 2 tasks (1 overdue), Work has 3 non-overdue tasks # Home score = 2+1 = 3; Work score = 3 — Home should win due to overdue weight tasks = ( [_task("Home1", project_id="Home", is_overdue=True), _task("Home2", project_id="Home")] + [_task(f"W{i}", project_id="Work") for i in range(3)] ) out = self.agent.compute(_inp(tasks=tasks)) assert '"Work"' not in out.prompt_text or '"Home"' in out.prompt_text def test_default_project_fallback(self): out = self.agent.compute(_inp(tasks=[_task("No project task")])) assert "default project" in out.prompt_text def test_snapshot_keys(self): out = self.agent.compute(_inp(tasks=[_task("T1", project_id="A")])) assert {"top_project", "top_task_count", "top_overdue_count", "project_count", "preferred_areas"} == set(out.signals_snapshot) # ── Registry ───────────────────────────────────────────────────────────────── class TestRegistry: def test_all_agents_present(self): agents = all_agents() ids = {a.agent_id for a in agents} assert ids == {"overdue-task", "momentum", "time-of-day", "recent-patterns", "focus-area"} def test_get_agent(self): a = get_agent("momentum") assert a.agent_id == "momentum" def test_get_unknown_raises(self): with pytest.raises(KeyError, match="Unknown agent"): get_agent("nonexistent") def test_all_agents_compute(self): inp = _inp( tasks=[_task("Buy milk", is_overdue=True, task_age_days=2, project_id="Personal")], profile={"completion_rate_30d": 0.4, "tip_volume_30d": 10.0, "preferred_hour": 9.0}, feedback_history=[ {"action": "done", "dwell_ms": 25000, "created_at": _NOW.isoformat()} ], ) for agent in all_agents(): out = agent.compute(inp) _check_output(out, agent)