feat(ml): multi-agent context framework + v4 orchestrator prompt
Adds ml/agents/ — five specialised sub-agents (overdue_task, momentum, time_of_day, recent_patterns, focus_area) each producing a prompt snippet from user signals. A registry wires them up; the orchestrator prompt in ml/serving/prompts.py synthesises their outputs into one tip via LiteLLM. Also wires /api/agents route in the API and updates the Dockerfile to copy the full ml/ tree with PYTHONPATH=/app so agent imports resolve correctly. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,6 +1,8 @@
|
|||||||
FROM python:3.12-slim
|
FROM python:3.12-slim
|
||||||
WORKDIR /app
|
WORKDIR /app/ml/serving
|
||||||
COPY ml/serving/requirements.txt .
|
COPY ml/serving/requirements.txt .
|
||||||
RUN pip install --no-cache-dir -r requirements.txt
|
RUN pip install --no-cache-dir -r requirements.txt
|
||||||
COPY ml/serving/*.py .
|
COPY ml/ /app/ml/
|
||||||
|
# PYTHONPATH=/app lets 'import ml.agents.*' resolve from /app/ml/agents/
|
||||||
|
ENV PYTHONPATH=/app
|
||||||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
|
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||||
|
|||||||
0
ml/__init__.py
Normal file
0
ml/__init__.py
Normal file
4
ml/agents/__init__.py
Normal file
4
ml/agents/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
from .base import BaseAgent, AgentInput, AgentOutput
|
||||||
|
from .registry import get_agent, all_agents
|
||||||
|
|
||||||
|
__all__ = ["BaseAgent", "AgentInput", "AgentOutput", "get_agent", "all_agents"]
|
||||||
53
ml/agents/base.py
Normal file
53
ml/agents/base.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
"""Base class and shared data structures for all recommendation sub-agents."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AgentInput:
|
||||||
|
"""Everything an agent may need to produce its prompt snippet."""
|
||||||
|
user_id: str
|
||||||
|
tasks: list[dict] # task signal dicts (content, priority, is_overdue, …)
|
||||||
|
profile: dict[str, float | None] # profile feature values keyed by feature name
|
||||||
|
feedback_history: list[dict] = field(default_factory=list) # [{action, dwell_ms, created_at}, …]
|
||||||
|
now: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AgentOutput:
|
||||||
|
"""Result produced by an agent; persisted to agent_outputs table."""
|
||||||
|
user_id: str
|
||||||
|
agent_id: str
|
||||||
|
prompt_text: str # snippet passed to the orchestrator
|
||||||
|
signals_snapshot: dict # inputs consumed (for explainability / debugging)
|
||||||
|
computed_at: str # ISO 8601
|
||||||
|
expires_at: str # ISO 8601
|
||||||
|
agent_version: str
|
||||||
|
|
||||||
|
|
||||||
|
class BaseAgent(ABC):
|
||||||
|
agent_id: ClassVar[str]
|
||||||
|
ttl_seconds: ClassVar[int]
|
||||||
|
version: ClassVar[str]
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def compute(self, inp: AgentInput) -> AgentOutput:
|
||||||
|
"""Analyse inp and return a prompt snippet describing what was found."""
|
||||||
|
...
|
||||||
|
|
||||||
|
def _make_output(self, inp: AgentInput, prompt_text: str, snapshot: dict) -> AgentOutput:
|
||||||
|
computed_at = inp.now.astimezone(timezone.utc).isoformat()
|
||||||
|
expires_at = (inp.now.astimezone(timezone.utc) + timedelta(seconds=self.ttl_seconds)).isoformat()
|
||||||
|
return AgentOutput(
|
||||||
|
user_id=inp.user_id,
|
||||||
|
agent_id=self.agent_id,
|
||||||
|
prompt_text=prompt_text,
|
||||||
|
signals_snapshot=snapshot,
|
||||||
|
computed_at=computed_at,
|
||||||
|
expires_at=expires_at,
|
||||||
|
agent_version=self.version,
|
||||||
|
)
|
||||||
46
ml/agents/focus_area.py
Normal file
46
ml/agents/focus_area.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import ClassVar
|
||||||
|
from .base import BaseAgent, AgentInput, AgentOutput
|
||||||
|
|
||||||
|
|
||||||
|
class FocusAreaAgent(BaseAgent):
|
||||||
|
"""Identifies the most congested project/area in the user's task list."""
|
||||||
|
agent_id: ClassVar[str] = "focus-area"
|
||||||
|
ttl_seconds: ClassVar[int] = 43_200 # 12h
|
||||||
|
version: ClassVar[str] = "1.0.0"
|
||||||
|
|
||||||
|
def compute(self, inp: AgentInput) -> AgentOutput:
|
||||||
|
by_project: dict[str, list[dict]] = defaultdict(list)
|
||||||
|
for task in inp.tasks:
|
||||||
|
project = task.get("project_id") or task.get("project") or "default"
|
||||||
|
by_project[project].append(task)
|
||||||
|
|
||||||
|
if not by_project:
|
||||||
|
prompt = "No tasks available to identify a focus area."
|
||||||
|
return self._make_output(inp, prompt, {"project_count": 0})
|
||||||
|
|
||||||
|
# Score each project: overdue tasks count double
|
||||||
|
def score(tasks: list[dict]) -> float:
|
||||||
|
return sum(2.0 if t.get("is_overdue") else 1.0 for t in tasks)
|
||||||
|
|
||||||
|
top_project, top_tasks = max(by_project.items(), key=lambda kv: score(kv[1]))
|
||||||
|
overdue_in_top = sum(1 for t in top_tasks if t.get("is_overdue"))
|
||||||
|
label = "the default project" if top_project == "default" else f'"{top_project}"'
|
||||||
|
n = len(top_tasks)
|
||||||
|
|
||||||
|
parts = [
|
||||||
|
f"The user's most congested area is {label} "
|
||||||
|
f"({n} task{'s' if n != 1 else ''}, {overdue_in_top} overdue)."
|
||||||
|
]
|
||||||
|
if overdue_in_top >= 3:
|
||||||
|
parts.append("Consider surfacing an action from this area.")
|
||||||
|
|
||||||
|
prompt = " ".join(parts)
|
||||||
|
snapshot = {
|
||||||
|
"top_project": top_project,
|
||||||
|
"top_task_count": n,
|
||||||
|
"top_overdue_count": overdue_in_top,
|
||||||
|
"project_count": len(by_project),
|
||||||
|
}
|
||||||
|
return self._make_output(inp, prompt, snapshot)
|
||||||
49
ml/agents/momentum.py
Normal file
49
ml/agents/momentum.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from typing import ClassVar
|
||||||
|
from .base import BaseAgent, AgentInput, AgentOutput
|
||||||
|
|
||||||
|
|
||||||
|
class MomentumAgent(BaseAgent):
|
||||||
|
"""Characterises the user's recent engagement trend from profile features."""
|
||||||
|
agent_id: ClassVar[str] = "momentum"
|
||||||
|
ttl_seconds: ClassVar[int] = 21600 # 6h
|
||||||
|
version: ClassVar[str] = "1.0.0"
|
||||||
|
|
||||||
|
def compute(self, inp: AgentInput) -> AgentOutput:
|
||||||
|
completion = inp.profile.get("completion_rate_30d")
|
||||||
|
dismiss = inp.profile.get("dismiss_rate_30d")
|
||||||
|
volume = inp.profile.get("tip_volume_30d")
|
||||||
|
|
||||||
|
parts: list[str] = []
|
||||||
|
|
||||||
|
if completion is not None:
|
||||||
|
pct = round(completion * 100)
|
||||||
|
if pct >= 50:
|
||||||
|
parts.append(f"The user completes {pct}% of tips (strong engagement).")
|
||||||
|
elif pct >= 25:
|
||||||
|
parts.append(f"The user completes {pct}% of tips (moderate engagement).")
|
||||||
|
else:
|
||||||
|
parts.append(
|
||||||
|
f"The user completes {pct}% of tips "
|
||||||
|
f"(low engagement — prefer simple, immediately actionable tips)."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
parts.append("No completion-rate data yet (new user).")
|
||||||
|
|
||||||
|
if dismiss is not None:
|
||||||
|
dpct = round(dismiss * 100)
|
||||||
|
if dpct >= 40:
|
||||||
|
parts.append(f"Dismiss rate is high ({dpct}%) — avoid repetitive or irrelevant tips.")
|
||||||
|
elif dpct <= 10:
|
||||||
|
parts.append(f"Dismiss rate is low ({dpct}%).")
|
||||||
|
|
||||||
|
if volume is not None and int(volume) < 5:
|
||||||
|
parts.append("Very few tips served so far — this is an early-stage user.")
|
||||||
|
|
||||||
|
prompt = " ".join(parts) if parts else "No engagement data available yet."
|
||||||
|
snapshot = {
|
||||||
|
"completion_rate_30d": completion,
|
||||||
|
"dismiss_rate_30d": dismiss,
|
||||||
|
"tip_volume_30d": volume,
|
||||||
|
}
|
||||||
|
return self._make_output(inp, prompt, snapshot)
|
||||||
42
ml/agents/overdue_task.py
Normal file
42
ml/agents/overdue_task.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from typing import ClassVar
|
||||||
|
from .base import BaseAgent, AgentInput, AgentOutput
|
||||||
|
|
||||||
|
|
||||||
|
class OverdueTaskAgent(BaseAgent):
|
||||||
|
"""Reports the user's overdue tasks by count and age."""
|
||||||
|
agent_id: ClassVar[str] = "overdue-task"
|
||||||
|
ttl_seconds: ClassVar[int] = 3600 # 1h — overdue status changes infrequently
|
||||||
|
version: ClassVar[str] = "1.0.0"
|
||||||
|
|
||||||
|
def compute(self, inp: AgentInput) -> AgentOutput:
|
||||||
|
overdue = [t for t in inp.tasks if t.get("is_overdue")]
|
||||||
|
top = sorted(overdue, key=lambda t: -t.get("task_age_days", 0))[:3]
|
||||||
|
|
||||||
|
if not overdue:
|
||||||
|
prompt = "The user has no overdue tasks at this time."
|
||||||
|
elif len(overdue) == 1:
|
||||||
|
t = top[0]
|
||||||
|
age = round(t.get("task_age_days", 0))
|
||||||
|
prompt = (
|
||||||
|
f'The user has 1 overdue task: "{t["content"]}" '
|
||||||
|
f"({age} day{'s' if age != 1 else ''} overdue)."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
items = ", ".join(
|
||||||
|
f'"{t["content"]}" ({round(t.get("task_age_days", 0))}d)'
|
||||||
|
for t in top
|
||||||
|
)
|
||||||
|
prompt = (
|
||||||
|
f"The user has {len(overdue)} overdue tasks. "
|
||||||
|
f"Top {len(top)}: {items}."
|
||||||
|
)
|
||||||
|
|
||||||
|
snapshot = {
|
||||||
|
"overdue_count": len(overdue),
|
||||||
|
"top_overdue": [
|
||||||
|
{"content": t["content"], "task_age_days": t.get("task_age_days", 0)}
|
||||||
|
for t in top
|
||||||
|
],
|
||||||
|
}
|
||||||
|
return self._make_output(inp, prompt, snapshot)
|
||||||
68
ml/agents/recent_patterns.py
Normal file
68
ml/agents/recent_patterns.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from collections import Counter
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import ClassVar
|
||||||
|
from .base import BaseAgent, AgentInput, AgentOutput
|
||||||
|
|
||||||
|
_SEVEN_DAYS_S = 7 * 86_400
|
||||||
|
|
||||||
|
|
||||||
|
class RecentPatternsAgent(BaseAgent):
|
||||||
|
"""Surfaces the user's reaction pattern from the last 7 days of feedback."""
|
||||||
|
agent_id: ClassVar[str] = "recent-patterns"
|
||||||
|
ttl_seconds: ClassVar[int] = 86_400 # 24h
|
||||||
|
version: ClassVar[str] = "1.0.0"
|
||||||
|
|
||||||
|
def compute(self, inp: AgentInput) -> AgentOutput:
|
||||||
|
now_ts = inp.now.timestamp()
|
||||||
|
recent = [
|
||||||
|
f for f in inp.feedback_history
|
||||||
|
if self._age_s(f.get("created_at", ""), now_ts) <= _SEVEN_DAYS_S
|
||||||
|
]
|
||||||
|
|
||||||
|
counts: Counter[str] = Counter(f.get("action") for f in recent)
|
||||||
|
total = len(recent)
|
||||||
|
dwell_ms = inp.profile.get("mean_dwell_ms_30d")
|
||||||
|
|
||||||
|
if total == 0:
|
||||||
|
prompt = "No tip reactions recorded in the last 7 days."
|
||||||
|
else:
|
||||||
|
done = counts.get("done", 0)
|
||||||
|
dismissed = counts.get("dismiss", 0)
|
||||||
|
snoozed = counts.get("snooze", 0)
|
||||||
|
parts = [
|
||||||
|
f"Last 7 days: {total} tip reaction{'s' if total != 1 else ''} — "
|
||||||
|
f"{done} completed, {dismissed} dismissed, {snoozed} snoozed."
|
||||||
|
]
|
||||||
|
if dwell_ms is not None:
|
||||||
|
dwell_s = round(dwell_ms / 1000)
|
||||||
|
if dwell_s < 15:
|
||||||
|
parts.append(
|
||||||
|
"Average dwell is very short — user may be acting on auto-pilot; vary tip content."
|
||||||
|
)
|
||||||
|
elif dwell_s < 60:
|
||||||
|
parts.append(f"Average dwell {dwell_s}s — tips are being read.")
|
||||||
|
else:
|
||||||
|
parts.append(
|
||||||
|
f"Average dwell {dwell_s}s — user deliberates; prefer tips that reward reflection."
|
||||||
|
)
|
||||||
|
prompt = " ".join(parts)
|
||||||
|
|
||||||
|
snapshot = {
|
||||||
|
"recent_total": total,
|
||||||
|
"action_counts": dict(counts),
|
||||||
|
"mean_dwell_ms_30d": dwell_ms,
|
||||||
|
}
|
||||||
|
return self._make_output(inp, prompt, snapshot)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _age_s(iso: str, now_ts: float) -> float:
|
||||||
|
if not iso:
|
||||||
|
return float("inf")
|
||||||
|
try:
|
||||||
|
dt = datetime.fromisoformat(iso.replace("Z", "+00:00"))
|
||||||
|
if dt.tzinfo is None:
|
||||||
|
dt = dt.replace(tzinfo=timezone.utc)
|
||||||
|
return now_ts - dt.timestamp()
|
||||||
|
except Exception:
|
||||||
|
return float("inf")
|
||||||
28
ml/agents/registry.py
Normal file
28
ml/agents/registry.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from .base import BaseAgent
|
||||||
|
from .overdue_task import OverdueTaskAgent
|
||||||
|
from .momentum import MomentumAgent
|
||||||
|
from .time_of_day import TimeOfDayAgent
|
||||||
|
from .recent_patterns import RecentPatternsAgent
|
||||||
|
from .focus_area import FocusAreaAgent
|
||||||
|
|
||||||
|
_AGENTS: dict[str, BaseAgent] = {
|
||||||
|
a.agent_id: a
|
||||||
|
for a in [
|
||||||
|
OverdueTaskAgent(),
|
||||||
|
MomentumAgent(),
|
||||||
|
TimeOfDayAgent(),
|
||||||
|
RecentPatternsAgent(),
|
||||||
|
FocusAreaAgent(),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_agent(agent_id: str) -> BaseAgent:
|
||||||
|
if agent_id not in _AGENTS:
|
||||||
|
raise KeyError(f"Unknown agent: {agent_id!r}. Known: {sorted(_AGENTS)}")
|
||||||
|
return _AGENTS[agent_id]
|
||||||
|
|
||||||
|
|
||||||
|
def all_agents() -> list[BaseAgent]:
|
||||||
|
return list(_AGENTS.values())
|
||||||
0
ml/agents/tests/__init__.py
Normal file
0
ml/agents/tests/__init__.py
Normal file
275
ml/agents/tests/test_agents.py
Normal file
275
ml/agents/tests/test_agents.py
Normal file
@@ -0,0 +1,275 @@
|
|||||||
|
"""Unit tests for all sub-agents and the registry."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sys, os
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from ml.agents.base import AgentInput, AgentOutput
|
||||||
|
from ml.agents.overdue_task import OverdueTaskAgent
|
||||||
|
from ml.agents.momentum import MomentumAgent
|
||||||
|
from ml.agents.time_of_day import TimeOfDayAgent
|
||||||
|
from ml.agents.recent_patterns import RecentPatternsAgent
|
||||||
|
from ml.agents.focus_area import FocusAreaAgent
|
||||||
|
from ml.agents.registry import get_agent, all_agents
|
||||||
|
|
||||||
|
_NOW = datetime(2026, 5, 1, 9, 0, 0, tzinfo=timezone.utc) # Thursday 09:00 UTC
|
||||||
|
|
||||||
|
|
||||||
|
def _inp(**kwargs) -> AgentInput:
|
||||||
|
defaults = dict(
|
||||||
|
user_id="u1",
|
||||||
|
tasks=[],
|
||||||
|
profile={},
|
||||||
|
feedback_history=[],
|
||||||
|
now=_NOW,
|
||||||
|
)
|
||||||
|
defaults.update(kwargs)
|
||||||
|
return AgentInput(**defaults)
|
||||||
|
|
||||||
|
|
||||||
|
def _task(content="Do thing", is_overdue=False, task_age_days=0.0, priority=1, project_id=None):
|
||||||
|
t = {"id": "t1", "content": content, "is_overdue": is_overdue,
|
||||||
|
"task_age_days": task_age_days, "priority": priority}
|
||||||
|
if project_id:
|
||||||
|
t["project_id"] = project_id
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
|
# ── helpers ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _check_output(out: AgentOutput, agent) -> None:
|
||||||
|
assert isinstance(out, AgentOutput)
|
||||||
|
assert out.user_id == "u1"
|
||||||
|
assert out.agent_id == agent.agent_id
|
||||||
|
assert out.prompt_text
|
||||||
|
assert out.computed_at
|
||||||
|
assert out.expires_at > out.computed_at
|
||||||
|
assert out.agent_version == agent.version
|
||||||
|
|
||||||
|
|
||||||
|
# ── OverdueTaskAgent ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestOverdueTaskAgent:
|
||||||
|
agent = OverdueTaskAgent()
|
||||||
|
|
||||||
|
def test_no_overdue(self):
|
||||||
|
out = self.agent.compute(_inp(tasks=[_task("Read book")]))
|
||||||
|
_check_output(out, self.agent)
|
||||||
|
assert "no overdue" in out.prompt_text.lower()
|
||||||
|
assert out.signals_snapshot["overdue_count"] == 0
|
||||||
|
|
||||||
|
def test_single_overdue(self):
|
||||||
|
out = self.agent.compute(_inp(tasks=[_task("Call dentist", is_overdue=True, task_age_days=3)]))
|
||||||
|
_check_output(out, self.agent)
|
||||||
|
assert "1 overdue" in out.prompt_text
|
||||||
|
assert "Call dentist" in out.prompt_text
|
||||||
|
assert "3 day" in out.prompt_text
|
||||||
|
|
||||||
|
def test_multiple_overdue_top3(self):
|
||||||
|
tasks = [
|
||||||
|
_task(f"Task {i}", is_overdue=True, task_age_days=float(i))
|
||||||
|
for i in range(1, 6)
|
||||||
|
]
|
||||||
|
out = self.agent.compute(_inp(tasks=tasks))
|
||||||
|
_check_output(out, self.agent)
|
||||||
|
assert "5 overdue" in out.prompt_text
|
||||||
|
assert out.signals_snapshot["overdue_count"] == 5
|
||||||
|
assert len(out.signals_snapshot["top_overdue"]) == 3
|
||||||
|
# Top 3 should be highest age: 5, 4, 3
|
||||||
|
ages = [t["task_age_days"] for t in out.signals_snapshot["top_overdue"]]
|
||||||
|
assert ages == sorted(ages, reverse=True)
|
||||||
|
|
||||||
|
def test_ttl_respected(self):
|
||||||
|
out = self.agent.compute(_inp())
|
||||||
|
assert out.expires_at > out.computed_at
|
||||||
|
|
||||||
|
|
||||||
|
# ── MomentumAgent ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestMomentumAgent:
|
||||||
|
agent = MomentumAgent()
|
||||||
|
|
||||||
|
def test_no_profile(self):
|
||||||
|
out = self.agent.compute(_inp(profile={}))
|
||||||
|
_check_output(out, self.agent)
|
||||||
|
assert "new user" in out.prompt_text.lower() or "no " in out.prompt_text.lower()
|
||||||
|
|
||||||
|
def test_strong_engagement(self):
|
||||||
|
out = self.agent.compute(_inp(profile={"completion_rate_30d": 0.65, "dismiss_rate_30d": 0.05}))
|
||||||
|
assert "strong engagement" in out.prompt_text
|
||||||
|
|
||||||
|
def test_low_completion_warns(self):
|
||||||
|
out = self.agent.compute(_inp(profile={"completion_rate_30d": 0.1}))
|
||||||
|
assert "low engagement" in out.prompt_text
|
||||||
|
assert "actionable" in out.prompt_text
|
||||||
|
|
||||||
|
def test_high_dismiss_warns(self):
|
||||||
|
out = self.agent.compute(_inp(profile={"completion_rate_30d": 0.3, "dismiss_rate_30d": 0.5}))
|
||||||
|
assert "dismiss rate is high" in out.prompt_text.lower()
|
||||||
|
|
||||||
|
def test_early_stage_user(self):
|
||||||
|
out = self.agent.compute(_inp(profile={"tip_volume_30d": 2.0}))
|
||||||
|
assert "early-stage" in out.prompt_text
|
||||||
|
|
||||||
|
|
||||||
|
# ── TimeOfDayAgent ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestTimeOfDayAgent:
|
||||||
|
agent = TimeOfDayAgent()
|
||||||
|
|
||||||
|
def test_morning_label(self):
|
||||||
|
inp = _inp(now=datetime(2026, 5, 1, 8, 0, tzinfo=timezone.utc)) # Friday
|
||||||
|
out = self.agent.compute(inp)
|
||||||
|
assert "morning" in out.prompt_text
|
||||||
|
assert "08:00" in out.prompt_text
|
||||||
|
|
||||||
|
def test_weekend_note(self):
|
||||||
|
inp = _inp(now=datetime(2026, 5, 2, 10, 0, tzinfo=timezone.utc)) # Saturday
|
||||||
|
out = self.agent.compute(inp)
|
||||||
|
assert "weekend" in out.prompt_text.lower()
|
||||||
|
|
||||||
|
def test_peak_hour_exact(self):
|
||||||
|
inp = _inp(
|
||||||
|
now=datetime(2026, 5, 1, 10, 0, tzinfo=timezone.utc),
|
||||||
|
profile={"preferred_hour": 10.0},
|
||||||
|
)
|
||||||
|
out = self.agent.compute(inp)
|
||||||
|
assert "peak productivity hour" in out.prompt_text
|
||||||
|
|
||||||
|
def test_approaching_peak(self):
|
||||||
|
inp = _inp(
|
||||||
|
now=datetime(2026, 5, 1, 9, 0, tzinfo=timezone.utc),
|
||||||
|
profile={"preferred_hour": 10.0},
|
||||||
|
)
|
||||||
|
out = self.agent.compute(inp)
|
||||||
|
assert "approaching" in out.prompt_text.lower()
|
||||||
|
|
||||||
|
def test_no_preferred_hour(self):
|
||||||
|
out = self.agent.compute(_inp())
|
||||||
|
assert "no preferred-hour" in out.prompt_text.lower()
|
||||||
|
|
||||||
|
def test_snapshot_keys(self):
|
||||||
|
out = self.agent.compute(_inp())
|
||||||
|
assert {"hour", "day_of_week", "preferred_hour"} == set(out.signals_snapshot)
|
||||||
|
|
||||||
|
|
||||||
|
# ── RecentPatternsAgent ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestRecentPatternsAgent:
|
||||||
|
agent = RecentPatternsAgent()
|
||||||
|
|
||||||
|
def test_no_feedback(self):
|
||||||
|
out = self.agent.compute(_inp())
|
||||||
|
assert "no tip reactions" in out.prompt_text.lower()
|
||||||
|
|
||||||
|
def test_recent_feedback_summary(self):
|
||||||
|
now_iso = _NOW.isoformat()
|
||||||
|
feedback = [
|
||||||
|
{"action": "done", "dwell_ms": 30000, "created_at": now_iso},
|
||||||
|
{"action": "done", "dwell_ms": 45000, "created_at": now_iso},
|
||||||
|
{"action": "dismiss", "dwell_ms": 2000, "created_at": now_iso},
|
||||||
|
]
|
||||||
|
out = self.agent.compute(_inp(feedback_history=feedback))
|
||||||
|
assert "3 tip reactions" in out.prompt_text
|
||||||
|
assert "2 completed" in out.prompt_text
|
||||||
|
assert "1 dismissed" in out.prompt_text
|
||||||
|
|
||||||
|
def test_old_feedback_excluded(self):
|
||||||
|
# 10 days ago — should be excluded from 7-day window
|
||||||
|
old_iso = "2026-04-21T09:00:00+00:00"
|
||||||
|
feedback = [{"action": "done", "dwell_ms": 5000, "created_at": old_iso}]
|
||||||
|
out = self.agent.compute(_inp(feedback_history=feedback))
|
||||||
|
assert "no tip reactions" in out.prompt_text.lower()
|
||||||
|
|
||||||
|
def test_short_dwell_note(self):
|
||||||
|
now_iso = _NOW.isoformat()
|
||||||
|
feedback = [{"action": "done", "dwell_ms": 5000, "created_at": now_iso}]
|
||||||
|
out = self.agent.compute(_inp(
|
||||||
|
feedback_history=feedback,
|
||||||
|
profile={"mean_dwell_ms_30d": 5000.0},
|
||||||
|
))
|
||||||
|
assert "auto-pilot" in out.prompt_text.lower() or "short" in out.prompt_text.lower()
|
||||||
|
|
||||||
|
def test_long_dwell_note(self):
|
||||||
|
now_iso = _NOW.isoformat()
|
||||||
|
feedback = [{"action": "done", "dwell_ms": 90000, "created_at": now_iso}]
|
||||||
|
out = self.agent.compute(_inp(
|
||||||
|
feedback_history=feedback,
|
||||||
|
profile={"mean_dwell_ms_30d": 90000.0},
|
||||||
|
))
|
||||||
|
assert "deliberate" in out.prompt_text.lower() or "reflection" in out.prompt_text.lower()
|
||||||
|
|
||||||
|
|
||||||
|
# ── FocusAreaAgent ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestFocusAreaAgent:
|
||||||
|
agent = FocusAreaAgent()
|
||||||
|
|
||||||
|
def test_no_tasks(self):
|
||||||
|
out = self.agent.compute(_inp())
|
||||||
|
assert "no tasks" in out.prompt_text.lower()
|
||||||
|
|
||||||
|
def test_single_project(self):
|
||||||
|
tasks = [_task(f"T{i}", project_id="Work") for i in range(3)]
|
||||||
|
out = self.agent.compute(_inp(tasks=tasks))
|
||||||
|
assert '"Work"' in out.prompt_text
|
||||||
|
assert "3 tasks" in out.prompt_text
|
||||||
|
|
||||||
|
def test_most_congested_wins(self):
|
||||||
|
tasks = (
|
||||||
|
[_task(f"W{i}", project_id="Work") for i in range(5)]
|
||||||
|
+ [_task(f"H{i}", project_id="Home") for i in range(2)]
|
||||||
|
)
|
||||||
|
out = self.agent.compute(_inp(tasks=tasks))
|
||||||
|
assert '"Work"' in out.prompt_text
|
||||||
|
|
||||||
|
def test_overdue_weighting(self):
|
||||||
|
# Home has 2 tasks (1 overdue), Work has 3 non-overdue tasks
|
||||||
|
# Home score = 2+1 = 3; Work score = 3 — Home should win due to overdue weight
|
||||||
|
tasks = (
|
||||||
|
[_task("Home1", project_id="Home", is_overdue=True),
|
||||||
|
_task("Home2", project_id="Home")]
|
||||||
|
+ [_task(f"W{i}", project_id="Work") for i in range(3)]
|
||||||
|
)
|
||||||
|
out = self.agent.compute(_inp(tasks=tasks))
|
||||||
|
assert '"Work"' not in out.prompt_text or '"Home"' in out.prompt_text
|
||||||
|
|
||||||
|
def test_default_project_fallback(self):
|
||||||
|
out = self.agent.compute(_inp(tasks=[_task("No project task")]))
|
||||||
|
assert "default project" in out.prompt_text
|
||||||
|
|
||||||
|
def test_snapshot_keys(self):
|
||||||
|
out = self.agent.compute(_inp(tasks=[_task("T1", project_id="A")]))
|
||||||
|
assert {"top_project", "top_task_count", "top_overdue_count", "project_count"} == set(out.signals_snapshot)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Registry ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestRegistry:
|
||||||
|
def test_all_agents_present(self):
|
||||||
|
agents = all_agents()
|
||||||
|
ids = {a.agent_id for a in agents}
|
||||||
|
assert ids == {"overdue-task", "momentum", "time-of-day", "recent-patterns", "focus-area"}
|
||||||
|
|
||||||
|
def test_get_agent(self):
|
||||||
|
a = get_agent("momentum")
|
||||||
|
assert a.agent_id == "momentum"
|
||||||
|
|
||||||
|
def test_get_unknown_raises(self):
|
||||||
|
with pytest.raises(KeyError, match="Unknown agent"):
|
||||||
|
get_agent("nonexistent")
|
||||||
|
|
||||||
|
def test_all_agents_compute(self):
|
||||||
|
inp = _inp(
|
||||||
|
tasks=[_task("Buy milk", is_overdue=True, task_age_days=2, project_id="Personal")],
|
||||||
|
profile={"completion_rate_30d": 0.4, "tip_volume_30d": 10.0, "preferred_hour": 9.0},
|
||||||
|
feedback_history=[
|
||||||
|
{"action": "done", "dwell_ms": 25000, "created_at": _NOW.isoformat()}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
for agent in all_agents():
|
||||||
|
out = agent.compute(inp)
|
||||||
|
_check_output(out, agent)
|
||||||
50
ml/agents/time_of_day.py
Normal file
50
ml/agents/time_of_day.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from typing import ClassVar
|
||||||
|
from .base import BaseAgent, AgentInput, AgentOutput
|
||||||
|
|
||||||
|
_DOW_NAMES = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"]
|
||||||
|
|
||||||
|
|
||||||
|
class TimeOfDayAgent(BaseAgent):
|
||||||
|
"""Frames the current moment relative to the user's productive peak."""
|
||||||
|
agent_id: ClassVar[str] = "time-of-day"
|
||||||
|
ttl_seconds: ClassVar[int] = 900 # 15m — must stay current-hour accurate
|
||||||
|
version: ClassVar[str] = "1.0.0"
|
||||||
|
|
||||||
|
def compute(self, inp: AgentInput) -> AgentOutput:
|
||||||
|
hour = inp.now.hour
|
||||||
|
dow = inp.now.weekday() # 0=Monday … 6=Sunday
|
||||||
|
preferred = inp.profile.get("preferred_hour")
|
||||||
|
is_weekend = dow >= 5
|
||||||
|
|
||||||
|
parts = [f"It is {hour:02d}:00 on {_DOW_NAMES[dow]} ({self._label(hour)})."]
|
||||||
|
|
||||||
|
if is_weekend:
|
||||||
|
parts.append("Weekend context — prefer personal or reflective tips over work tasks.")
|
||||||
|
|
||||||
|
if preferred is not None:
|
||||||
|
ph = int(preferred)
|
||||||
|
delta = min(abs(hour - ph), 24 - abs(hour - ph)) # circular distance
|
||||||
|
if delta == 0:
|
||||||
|
parts.append(
|
||||||
|
f"This is the user's peak productivity hour ({ph:02d}:00) — "
|
||||||
|
f"a high-impact tip is appropriate."
|
||||||
|
)
|
||||||
|
elif delta <= 2:
|
||||||
|
parts.append(f"Approaching the user's peak productivity window ({ph:02d}:00).")
|
||||||
|
else:
|
||||||
|
parts.append("No preferred-hour data yet.")
|
||||||
|
|
||||||
|
prompt = " ".join(parts)
|
||||||
|
snapshot = {"hour": hour, "day_of_week": dow, "preferred_hour": preferred}
|
||||||
|
return self._make_output(inp, prompt, snapshot)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _label(hour: int) -> str:
|
||||||
|
if 5 <= hour < 12:
|
||||||
|
return "morning"
|
||||||
|
if 12 <= hour < 17:
|
||||||
|
return "afternoon"
|
||||||
|
if 17 <= hour < 21:
|
||||||
|
return "evening"
|
||||||
|
return "night"
|
||||||
@@ -108,6 +108,52 @@ PROMPTS: dict[str, Prompt] = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── v4-orchestrator ────────────────────────────────────────────────────────
|
||||||
|
# Not a Prompt entry — takes pre-computed agent snippets, not a _Ctx.
|
||||||
|
|
||||||
|
_SYS_V4_ORCHESTRATOR = (
|
||||||
|
"You are a personal advisor generating a single, perfectly-timed tip. "
|
||||||
|
"Multiple specialized agents have analyzed the user's current context and provided "
|
||||||
|
"their insights below. Synthesize their combined perspective to generate exactly ONE "
|
||||||
|
"tip that is specific, actionable, and relevant right now. "
|
||||||
|
"Respond ONLY with a JSON object with keys: "
|
||||||
|
'"id" (short slug), "content" (the tip, ≤2 sentences), '
|
||||||
|
'"rationale" (why now, ≤1 sentence). '
|
||||||
|
"No markdown, no prose outside the JSON object."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_orchestrator_messages(
|
||||||
|
agent_outputs: list[dict],
|
||||||
|
tasks: list[dict],
|
||||||
|
hour_of_day: int,
|
||||||
|
day_of_week: int,
|
||||||
|
) -> list[dict]:
|
||||||
|
"""Build the [system, user] message list for the orchestrator LLM call.
|
||||||
|
|
||||||
|
agent_outputs: list of {agent_id, prompt_text} dicts.
|
||||||
|
Falls back to raw task summary when agent_outputs is empty.
|
||||||
|
"""
|
||||||
|
lines = [f"Current time: {hour_of_day:02d}:00, day_of_week={day_of_week}", ""]
|
||||||
|
if agent_outputs:
|
||||||
|
lines.append("Context from analysis agents:")
|
||||||
|
for s in agent_outputs:
|
||||||
|
lines.append(f"[{s['agent_id']}] {s['prompt_text']}")
|
||||||
|
else:
|
||||||
|
overdue = [t for t in tasks if t.get("is_overdue")]
|
||||||
|
lines.append(
|
||||||
|
f"No pre-computed agent context available. "
|
||||||
|
f"Tasks: {len(tasks)} total, {len(overdue)} overdue."
|
||||||
|
)
|
||||||
|
for t in tasks[:3]:
|
||||||
|
lines.append(f" - {t.get('content', '?')}")
|
||||||
|
lines.append("\nGenerate one tip as a JSON object.")
|
||||||
|
return [
|
||||||
|
{"role": "system", "content": _SYS_V4_ORCHESTRATOR},
|
||||||
|
{"role": "user", "content": "\n".join(lines)},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def default_version() -> str:
|
def default_version() -> str:
|
||||||
return os.getenv("DEFAULT_PROMPT_VERSION", "v1")
|
return os.getenv("DEFAULT_PROMPT_VERSION", "v1")
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import { userRouter } from './routes/user.js';
|
|||||||
import { pushRouter } from './routes/push.js';
|
import { pushRouter } from './routes/push.js';
|
||||||
import { adminRouter, adminInternalRouter } from './routes/admin.js';
|
import { adminRouter, adminInternalRouter } from './routes/admin.js';
|
||||||
import benchRouter from './routes/bench.js';
|
import benchRouter from './routes/bench.js';
|
||||||
|
import agentOutputsRouter from './routes/agent-outputs.js';
|
||||||
import { mkdir } from 'fs/promises';
|
import { mkdir } from 'fs/promises';
|
||||||
import { dirname } from 'path';
|
import { dirname } from 'path';
|
||||||
import { requireAuth } from './middleware/session.js';
|
import { requireAuth } from './middleware/session.js';
|
||||||
@@ -68,6 +69,7 @@ app.use('/api/push', pushRouter);
|
|||||||
app.use('/api/admin', adminRouter);
|
app.use('/api/admin', adminRouter);
|
||||||
app.use('/api/admin', adminInternalRouter);
|
app.use('/api/admin', adminInternalRouter);
|
||||||
app.use('/api/bench', requireAuth as any, requireAdmin as any, benchRouter);
|
app.use('/api/bench', requireAuth as any, requireAdmin as any, benchRouter);
|
||||||
|
app.use('/api/agents', agentOutputsRouter);
|
||||||
|
|
||||||
app.use('/api/ml', requireAuth as any, requireAdmin as any, async (req: Request, res: Response) => {
|
app.use('/api/ml', requireAuth as any, requireAdmin as any, async (req: Request, res: Response) => {
|
||||||
const mlUrl = config.ML_SERVING_URL;
|
const mlUrl = config.ML_SERVING_URL;
|
||||||
|
|||||||
Reference in New Issue
Block a user