Each unique task title is now enriched by LiteLLM once and cached in the DB. Subsequent agent compute cycles (every 12h) fetch the cache before calling ml-serving, so only new titles hit the tip-generator.

- DB: task_enrichments(content_hash PK, description, model, created_at)
- TS: fetchEnrichmentCache / persistEnrichments helpers in agent-outputs.ts; enrichment_cache is passed in the compute request, and new_enrichments from the response are persisted
- Python: AgentComputeRequest.enrichment_cache / AgentComputeResponse.new_enrichments; AgentInput.enrichment_cache; _enrich_batch returns (descriptions, new_entries); cluster_tasks returns (clusters, new_enrichments)
- FocusAreaAgent stashes new_enrichments in signals_snapshot under _new_enrichments; the compute_agent endpoint pops it before storing the snapshot

Closes part of #129

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
280 lines
11 KiB
Python
280 lines
11 KiB
Python
"""Unit tests for all sub-agents and the registry."""
|
|
from __future__ import annotations
|
|
|
|
import sys, os
|
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
|
|
|
from datetime import datetime, timezone
|
|
import pytest
|
|
|
|
from ml.agents.base import AgentInput, AgentOutput
|
|
from ml.agents.overdue_task import OverdueTaskAgent
|
|
from ml.agents.momentum import MomentumAgent
|
|
from ml.agents.time_of_day import TimeOfDayAgent
|
|
from ml.agents.recent_patterns import RecentPatternsAgent
|
|
from ml.agents.focus_area import FocusAreaAgent
|
|
from ml.agents.registry import get_agent, all_agents
|
|
|
|
_NOW = datetime(2026, 5, 1, 9, 0, 0, tzinfo=timezone.utc) # Thursday 09:00 UTC
|
|
|
|
|
|
def _inp(**kwargs) -> AgentInput:
    """Build an ``AgentInput`` for user ``"u1"`` at ``_NOW`` with no data.

    Every keyword argument overrides the matching default field, so each test
    only spells out the inputs it actually cares about. Fresh empty
    containers are created on every call.
    """
    fields = {
        "user_id": "u1",
        "tasks": [],
        "profile": {},
        "feedback_history": [],
        "now": _NOW,
        **kwargs,
    }
    return AgentInput(**fields)
|
|
|
|
|
|
def _task(content="Do thing", is_overdue=False, task_age_days=0.0, priority=1, project_id=None):
|
|
t = {"id": "t1", "content": content, "is_overdue": is_overdue,
|
|
"task_age_days": task_age_days, "priority": priority}
|
|
if project_id:
|
|
t["project_id"] = project_id
|
|
return t
|
|
|
|
|
|
# ── helpers ──────────────────────────────────────────────────────────────────
|
|
|
|
def _check_output(out: AgentOutput, agent) -> None:
    """Assert the invariants every agent's output must satisfy.

    ``out`` must be an ``AgentOutput`` for the fixture user ``"u1"``, tagged
    with ``agent``'s id and version. It must also carry a non-empty prompt and
    a ``computed_at`` value, with an expiry strictly later than ``computed_at``.
    """
    assert isinstance(out, AgentOutput)
    # Attribution: fixture user and the producing agent.
    assert out.user_id == "u1"
    assert out.agent_id == agent.agent_id
    # Payload and creation timestamp must both be present...
    assert out.prompt_text
    assert out.computed_at
    # ...and the expiry must fall strictly after computation.
    assert out.expires_at > out.computed_at
    # Version stamp must match the agent that produced it.
    assert out.agent_version == agent.version
|
|
|
|
|
|
# ── OverdueTaskAgent ──────────────────────────────────────────────────────────
|
|
|
|
class TestOverdueTaskAgent:
    """OverdueTaskAgent: counting overdue tasks and surfacing the oldest ones."""

    agent = OverdueTaskAgent()

    def test_no_overdue(self):
        out = self.agent.compute(_inp(tasks=[_task("Read book")]))
        _check_output(out, self.agent)
        assert "no overdue" in out.prompt_text.lower()
        assert out.signals_snapshot["overdue_count"] == 0

    def test_single_overdue(self):
        out = self.agent.compute(_inp(tasks=[_task("Call dentist", is_overdue=True, task_age_days=3)]))
        _check_output(out, self.agent)
        assert "1 overdue" in out.prompt_text
        assert "Call dentist" in out.prompt_text
        assert "3 day" in out.prompt_text

    def test_multiple_overdue_top3(self):
        # Five overdue tasks aged 1..5 days.
        tasks = [
            _task(f"Task {i}", is_overdue=True, task_age_days=float(i))
            for i in range(1, 6)
        ]
        out = self.agent.compute(_inp(tasks=tasks))
        _check_output(out, self.agent)
        assert "5 overdue" in out.prompt_text
        assert out.signals_snapshot["overdue_count"] == 5
        assert len(out.signals_snapshot["top_overdue"]) == 3
        # Top 3 must be the three oldest (5, 4, 3 days), oldest first. Checking
        # only that the list is descending would also accept e.g. [3, 2, 1].
        ages = [t["task_age_days"] for t in out.signals_snapshot["top_overdue"]]
        assert ages == [5.0, 4.0, 3.0]

    def test_ttl_respected(self):
        out = self.agent.compute(_inp())
        assert out.expires_at > out.computed_at
|
|
|
|
|
|
# ── MomentumAgent ─────────────────────────────────────────────────────────────
|
|
|
|
class TestMomentumAgent:
    """MomentumAgent: engagement commentary derived from the user profile."""

    agent = MomentumAgent()

    def _prompt_for(self, profile):
        """Run the agent on an otherwise-default input with ``profile``; return its prompt."""
        return self.agent.compute(_inp(profile=profile)).prompt_text

    def test_no_profile(self):
        out = self.agent.compute(_inp(profile={}))
        _check_output(out, self.agent)
        lowered = out.prompt_text.lower()
        assert "new user" in lowered or "no " in lowered

    def test_strong_engagement(self):
        text = self._prompt_for({"completion_rate_30d": 0.65, "dismiss_rate_30d": 0.05})
        assert "strong engagement" in text

    def test_low_completion_warns(self):
        text = self._prompt_for({"completion_rate_30d": 0.1})
        assert "low engagement" in text
        assert "actionable" in text

    def test_high_dismiss_warns(self):
        text = self._prompt_for({"completion_rate_30d": 0.3, "dismiss_rate_30d": 0.5})
        assert "dismiss rate is high" in text.lower()

    def test_early_stage_user(self):
        assert "early-stage" in self._prompt_for({"tip_volume_30d": 2.0})
|
|
|
|
|
|
# ── TimeOfDayAgent ────────────────────────────────────────────────────────────
|
|
|
|
class TestTimeOfDayAgent:
    """TimeOfDayAgent: clock-, weekday- and peak-hour-aware prompt text."""

    agent = TimeOfDayAgent()

    @staticmethod
    def _may(day, hour):
        """UTC timestamp on ``day`` May 2026 at ``hour``:00 (1 May 2026 is a Friday)."""
        return datetime(2026, 5, day, hour, 0, tzinfo=timezone.utc)

    def _prompt(self, **overrides):
        """Run the agent on ``_inp(**overrides)`` and return the prompt text."""
        return self.agent.compute(_inp(**overrides)).prompt_text

    def test_morning_label(self):
        text = self._prompt(now=self._may(1, 8))  # Friday
        assert "morning" in text
        assert "08:00" in text

    def test_weekend_note(self):
        text = self._prompt(now=self._may(2, 10))  # Saturday
        assert "weekend" in text.lower()

    def test_peak_hour_exact(self):
        text = self._prompt(now=self._may(1, 10), profile={"preferred_hour": 10.0})
        assert "peak productivity hour" in text

    def test_approaching_peak(self):
        text = self._prompt(now=self._may(1, 9), profile={"preferred_hour": 10.0})
        assert "approaching" in text.lower()

    def test_no_preferred_hour(self):
        assert "no preferred-hour" in self._prompt().lower()

    def test_snapshot_keys(self):
        snapshot = self.agent.compute(_inp()).signals_snapshot
        expected = {"hour", "day_of_week", "preferred_hour", "quiet_start", "quiet_end",
                    "peak_hours", "in_quiet", "in_peak", "tz"}
        assert set(snapshot) == expected
|
|
|
|
|
|
# ── RecentPatternsAgent ───────────────────────────────────────────────────────
|
|
|
|
class TestRecentPatternsAgent:
    """RecentPatternsAgent: summarising recent tip feedback and dwell time."""

    agent = RecentPatternsAgent()

    @staticmethod
    def _reaction(action, dwell_ms, created_at=None):
        """One feedback row, stamped at ``_NOW`` unless ``created_at`` is given."""
        return {
            "action": action,
            "dwell_ms": dwell_ms,
            "created_at": _NOW.isoformat() if created_at is None else created_at,
        }

    def _prompt(self, **overrides):
        """Run the agent on ``_inp(**overrides)`` and return the prompt text."""
        return self.agent.compute(_inp(**overrides)).prompt_text

    def test_no_feedback(self):
        assert "no tip reactions" in self._prompt().lower()

    def test_recent_feedback_summary(self):
        history = [
            self._reaction("done", 30000),
            self._reaction("done", 45000),
            self._reaction("dismiss", 2000),
        ]
        text = self._prompt(feedback_history=history)
        for fragment in ("3 tip reactions", "2 completed", "1 dismissed"):
            assert fragment in text

    def test_old_feedback_excluded(self):
        # Stamped 10 days before _NOW, i.e. outside the 7-day window.
        history = [self._reaction("done", 5000, "2026-04-21T09:00:00+00:00")]
        assert "no tip reactions" in self._prompt(feedback_history=history).lower()

    def test_short_dwell_note(self):
        lowered = self._prompt(
            feedback_history=[self._reaction("done", 5000)],
            profile={"mean_dwell_ms_30d": 5000.0},
        ).lower()
        assert "auto-pilot" in lowered or "short" in lowered

    def test_long_dwell_note(self):
        lowered = self._prompt(
            feedback_history=[self._reaction("done", 90000)],
            profile={"mean_dwell_ms_30d": 90000.0},
        ).lower()
        assert "deliberate" in lowered or "reflection" in lowered
|
|
|
|
|
|
# ── FocusAreaAgent ────────────────────────────────────────────────────────────
|
|
|
|
class TestFocusAreaAgent:
    """FocusAreaAgent: picking the most congested task area."""

    agent = FocusAreaAgent()

    def test_no_tasks(self):
        out = self.agent.compute(_inp())
        assert "no tasks" in out.prompt_text.lower()

    def test_single_project(self):
        tasks = [_task(f"T{i}", project_id="Work") for i in range(3)]
        out = self.agent.compute(_inp(tasks=tasks))
        assert '"Work"' in out.prompt_text
        assert "3 tasks" in out.prompt_text

    def test_most_congested_wins(self):
        tasks = (
            [_task(f"W{i}", project_id="Work") for i in range(5)]
            + [_task(f"H{i}", project_id="Home") for i in range(2)]
        )
        out = self.agent.compute(_inp(tasks=tasks))
        assert '"Work"' in out.prompt_text

    def test_overdue_weighting(self):
        # Home: 2 tasks, both overdue. Work: 3 tasks, none overdue.
        # By raw task count Work would win (3 > 2). If each overdue task adds
        # one point, Home scores 2 + 2 = 4 > Work's 3. Home can only be picked
        # if overdue weighting is actually applied.
        tasks = (
            [_task("Home1", project_id="Home", is_overdue=True),
             _task("Home2", project_id="Home", is_overdue=True)]
            + [_task(f"W{i}", project_id="Work") for i in range(3)]
        )
        out = self.agent.compute(_inp(tasks=tasks))
        assert out.signals_snapshot["top_cluster_label"] == "Home"
        assert '"Home"' in out.prompt_text

    def test_default_project_fallback(self):
        out = self.agent.compute(_inp(tasks=[_task("No project task")]))
        # Tasks without project_id fall back to a "Tasks" bucket
        assert "Tasks" in out.prompt_text

    def test_snapshot_keys(self):
        out = self.agent.compute(_inp(tasks=[_task("T1", project_id="A")]))
        # Underscore-prefixed keys (e.g. _new_enrichments) are internal and
        # stripped before storage, so only the public keys are pinned here.
        public_keys = {k for k in out.signals_snapshot if not k.startswith("_")}
        assert {"top_cluster_label", "top_task_count", "top_overdue_count", "cluster_count",
                "strategy", "preferred_areas"} == public_keys
|
|
|
|
|
|
# ── Registry ─────────────────────────────────────────────────────────────────
|
|
|
|
class TestRegistry:
    """Registry lookups, plus a smoke test that every registered agent computes."""

    def test_all_agents_present(self):
        expected_ids = {
            "overdue-task", "momentum", "time-of-day",
            "recent-patterns", "focus-area", "health-vitals",
        }
        assert {a.agent_id for a in all_agents()} == expected_ids

    def test_get_agent(self):
        assert get_agent("momentum").agent_id == "momentum"

    def test_get_unknown_raises(self):
        with pytest.raises(KeyError, match="Unknown agent"):
            get_agent("nonexistent")

    def test_all_agents_compute(self):
        # A single input shared by every agent: one overdue task in a project,
        # a populated profile and one feedback row stamped at _NOW.
        shared = _inp(
            tasks=[_task("Buy milk", is_overdue=True, task_age_days=2, project_id="Personal")],
            profile={"completion_rate_30d": 0.4, "tip_volume_30d": 10.0, "preferred_hour": 9.0},
            feedback_history=[
                {"action": "done", "dwell_ms": 25000, "created_at": _NOW.isoformat()}
            ],
        )
        for agent in all_agents():
            _check_output(agent.compute(shared), agent)
|