feat(agents): p50-lateness tolerance + per-project realness for overdue-task (#115)
Replaces snooze-rate heuristic with p50 of actual task lateness (completedAt − dueAt).
Adds project_realness inference: projects with chronic lateness get realness < 1 and
the agent softens its snippet language from "overdue" to "past target date".
- TaskCompletion added to UserHistory with lateness_days computed property
- _infer_lateness_tolerance: p50 of task_completions, clipped at 0, float
- _infer_project_realness: per-project median lateness normalised by global median
- Both InferredParams use 7d TTL; cold_start = 0.0 / {}
- AgentInferRequest accepts task_completions; endpoint wires them through
- 12 new tests covering punctual/chronic/mixed users and language softening
- Agent bumped to v1.2.0
Closes #115
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -4,6 +4,6 @@ Each agent's manifest declares InferredParams; this package owns the
|
|||||||
scheduling contract, history data model, and write path to user_preferences.
|
scheduling contract, history data model, and write path to user_preferences.
|
||||||
"""
|
"""
|
||||||
from .framework import run_inference
|
from .framework import run_inference
|
||||||
from .history import FeedbackEvent, UserHistory
|
from .history import FeedbackEvent, TaskCompletion, UserHistory
|
||||||
|
|
||||||
__all__ = ["run_inference", "FeedbackEvent", "UserHistory"]
|
__all__ = ["run_inference", "FeedbackEvent", "TaskCompletion", "UserHistory"]
|
||||||
|
|||||||
@@ -23,7 +23,27 @@ class FeedbackEvent:
|
|||||||
return dt.hour
|
return dt.hour
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TaskCompletion:
|
||||||
|
"""A completed task that had a due date — used for lateness inference."""
|
||||||
|
project_id: str | None
|
||||||
|
completed_at: str # ISO 8601
|
||||||
|
due_at: str # ISO 8601
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lateness_days(self) -> float:
|
||||||
|
"""Days between due_at and completed_at. Negative = completed early."""
|
||||||
|
try:
|
||||||
|
def _parse(s: str) -> datetime:
|
||||||
|
dt = datetime.fromisoformat(s.replace("Z", "+00:00"))
|
||||||
|
return dt if dt.tzinfo else dt.replace(tzinfo=timezone.utc)
|
||||||
|
return (_parse(self.completed_at) - _parse(self.due_at)).total_seconds() / 86_400
|
||||||
|
except ValueError:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class UserHistory:
|
class UserHistory:
|
||||||
user_id: str
|
user_id: str
|
||||||
events: list[FeedbackEvent] = field(default_factory=list)
|
events: list[FeedbackEvent] = field(default_factory=list)
|
||||||
|
task_completions: list[TaskCompletion] = field(default_factory=list)
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import statistics
|
||||||
from typing import ClassVar
|
from typing import ClassVar
|
||||||
|
|
||||||
from .base import BaseAgent, AgentInput, AgentOutput
|
from .base import BaseAgent, AgentInput, AgentOutput
|
||||||
@@ -7,36 +8,64 @@ from .inference.history import UserHistory
|
|||||||
from .manifest import AgentManifest, InferredParam
|
from .manifest import AgentManifest, InferredParam
|
||||||
|
|
||||||
|
|
||||||
def _infer_lateness_tolerance(history: UserHistory) -> int:
|
def _infer_lateness_tolerance(history: UserHistory) -> float:
|
||||||
"""Estimate how many days past due a task needs to be before the user acts.
|
"""p50 lateness (days) across completed tasks that had a due date, clipped at 0.
|
||||||
|
|
||||||
High snooze rate → user doesn't act immediately → raise tolerance so the
|
Negative lateness (finished early) pulls the percentile down; we clip at 0
|
||||||
agent doesn't nag them about tasks they'll handle in their own time.
|
so punctual users always get tolerance=0, never a negative offset.
|
||||||
"""
|
"""
|
||||||
total = len(history.events)
|
lateness = [c.lateness_days for c in history.task_completions]
|
||||||
if total == 0:
|
if not lateness:
|
||||||
return 0
|
return 0.0
|
||||||
snooze_rate = sum(1 for e in history.events if e.action == "snooze") / total
|
return max(0.0, statistics.median(lateness))
|
||||||
if snooze_rate > 0.40:
|
|
||||||
return 2
|
|
||||||
if snooze_rate > 0.20:
|
def _infer_project_realness(history: UserHistory) -> dict[str, float]:
|
||||||
return 1
|
"""Per-project realness: 1 − (median project lateness / global median lateness).
|
||||||
return 0
|
|
||||||
|
Projects whose tasks are consistently completed on time get realness ≈ 1.
|
||||||
|
Aspirational projects (chronic lateness) get realness closer to 0.
|
||||||
|
"""
|
||||||
|
completions = [c for c in history.task_completions if c.project_id]
|
||||||
|
if not completions:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
global_median = statistics.median(c.lateness_days for c in completions)
|
||||||
|
if global_median <= 0:
|
||||||
|
# Everyone finishes early — no project is less real than another.
|
||||||
|
return {pid: 1.0 for pid in {c.project_id for c in completions}} # type: ignore[misc]
|
||||||
|
|
||||||
|
by_project: dict[str, list[float]] = {}
|
||||||
|
for c in completions:
|
||||||
|
by_project.setdefault(c.project_id, []).append(c.lateness_days) # type: ignore[index]
|
||||||
|
|
||||||
|
result: dict[str, float] = {}
|
||||||
|
for pid, days in by_project.items():
|
||||||
|
project_median = statistics.median(days)
|
||||||
|
realness = 1.0 - (project_median / global_median)
|
||||||
|
result[pid] = round(max(0.0, min(1.0, realness)), 3)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
MANIFEST = AgentManifest(
|
MANIFEST = AgentManifest(
|
||||||
id="overdue-task",
|
id="overdue-task",
|
||||||
version="1.1.0", # bumped: lateness_tolerance_days InferredParam added (#115)
|
version="1.2.0", # #115: p50-lateness tolerance + per-project realness
|
||||||
description="Reports the user's overdue tasks by count and age.",
|
description="Reports the user's overdue tasks by count and age.",
|
||||||
pref_schema={
|
pref_schema={
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"additionalProperties": False,
|
"additionalProperties": False,
|
||||||
"properties": {
|
"properties": {
|
||||||
"lateness_tolerance_days": {
|
"lateness_tolerance_days": {
|
||||||
"type": "integer",
|
"type": "number",
|
||||||
"minimum": 0,
|
"minimum": 0,
|
||||||
"default": 0,
|
"default": 0,
|
||||||
"description": "Days past due before a task is considered overdue. 0 = the moment it's late.",
|
"description": "Days past due before a task is flagged. p50 of historical lateness.",
|
||||||
|
},
|
||||||
|
"project_realness": {
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": {"type": "number", "minimum": 0, "maximum": 1},
|
||||||
|
"default": {},
|
||||||
|
"description": "Per-project realness score [0,1]. Low = aspirational due dates.",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -48,15 +77,40 @@ MANIFEST = AgentManifest(
|
|||||||
inferred_params=[
|
inferred_params=[
|
||||||
InferredParam(
|
InferredParam(
|
||||||
key="lateness_tolerance_days",
|
key="lateness_tolerance_days",
|
||||||
ttl_sec=86_400, # recompute daily — snooze pattern shifts slowly
|
ttl_sec=7 * 86_400, # recompute weekly — lateness habits shift slowly
|
||||||
cold_start_default=0,
|
cold_start_default=0.0,
|
||||||
min_history=10,
|
min_history=10,
|
||||||
infer=_infer_lateness_tolerance,
|
infer=_infer_lateness_tolerance,
|
||||||
),
|
),
|
||||||
|
InferredParam(
|
||||||
|
key="project_realness",
|
||||||
|
ttl_sec=7 * 86_400,
|
||||||
|
cold_start_default={},
|
||||||
|
min_history=10,
|
||||||
|
infer=_infer_project_realness,
|
||||||
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _realness(project_id: str | None, project_realness: dict[str, float]) -> float:
|
||||||
|
"""Return realness for a project, defaulting to 1.0 (treat as real)."""
|
||||||
|
if not project_id or not project_realness:
|
||||||
|
return 1.0
|
||||||
|
return project_realness.get(project_id, 1.0)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_task(task: dict, project_realness: dict[str, float]) -> str:
|
||||||
|
content = task["content"]
|
||||||
|
age = round(task.get("task_age_days", 0))
|
||||||
|
pid = task.get("project_id")
|
||||||
|
r = _realness(pid, project_realness)
|
||||||
|
unit = "day" if age == 1 else "days"
|
||||||
|
if r < 0.4:
|
||||||
|
return f'"{content}" ({age} {unit} past target date)'
|
||||||
|
return f'"{content}" ({age} {unit} overdue)'
|
||||||
|
|
||||||
|
|
||||||
class OverdueTaskAgent(BaseAgent):
|
class OverdueTaskAgent(BaseAgent):
|
||||||
"""Reports the user's overdue tasks by count and age."""
|
"""Reports the user's overdue tasks by count and age."""
|
||||||
agent_id: ClassVar[str] = MANIFEST.id
|
agent_id: ClassVar[str] = MANIFEST.id
|
||||||
@@ -64,7 +118,9 @@ class OverdueTaskAgent(BaseAgent):
|
|||||||
version: ClassVar[str] = MANIFEST.version
|
version: ClassVar[str] = MANIFEST.version
|
||||||
|
|
||||||
def compute(self, inp: AgentInput) -> AgentOutput:
|
def compute(self, inp: AgentInput) -> AgentOutput:
|
||||||
tolerance = max(0, int(inp.agent_prefs.get("lateness_tolerance_days", 0)))
|
tolerance = max(0.0, float(inp.agent_prefs.get("lateness_tolerance_days", 0)))
|
||||||
|
project_realness: dict[str, float] = inp.agent_prefs.get("project_realness", {})
|
||||||
|
|
||||||
overdue = [
|
overdue = [
|
||||||
t for t in inp.tasks
|
t for t in inp.tasks
|
||||||
if t.get("is_overdue") and t.get("task_age_days", 0) >= tolerance
|
if t.get("is_overdue") and t.get("task_age_days", 0) >= tolerance
|
||||||
@@ -75,18 +131,21 @@ class OverdueTaskAgent(BaseAgent):
|
|||||||
prompt = "The user has no overdue tasks at this time."
|
prompt = "The user has no overdue tasks at this time."
|
||||||
elif len(overdue) == 1:
|
elif len(overdue) == 1:
|
||||||
t = top[0]
|
t = top[0]
|
||||||
age = round(t.get("task_age_days", 0))
|
r = _realness(t.get("project_id"), project_realness)
|
||||||
prompt = (
|
item = _format_task(t, project_realness)
|
||||||
f'The user has 1 overdue task: "{t["content"]}" '
|
if r < 0.4:
|
||||||
f"({age} day{'s' if age != 1 else ''} overdue)."
|
prompt = f"The user has 1 task past its target date: {item}."
|
||||||
)
|
else:
|
||||||
|
prompt = f"The user has 1 overdue task: {item}."
|
||||||
else:
|
else:
|
||||||
items = ", ".join(
|
items = ", ".join(_format_task(t, project_realness) for t in top)
|
||||||
f'"{t["content"]}" ({round(t.get("task_age_days", 0))}d)'
|
avg_realness = (
|
||||||
for t in top
|
sum(_realness(t.get("project_id"), project_realness) for t in overdue)
|
||||||
|
/ len(overdue)
|
||||||
)
|
)
|
||||||
|
label = "tasks past their target dates" if avg_realness < 0.4 else "overdue tasks"
|
||||||
prompt = (
|
prompt = (
|
||||||
f"The user has {len(overdue)} overdue tasks. "
|
f"The user has {len(overdue)} {label}. "
|
||||||
f"Top {len(top)}: {items}."
|
f"Top {len(top)}: {items}."
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -94,7 +153,12 @@ class OverdueTaskAgent(BaseAgent):
|
|||||||
"overdue_count": len(overdue),
|
"overdue_count": len(overdue),
|
||||||
"lateness_tolerance_days": tolerance,
|
"lateness_tolerance_days": tolerance,
|
||||||
"top_overdue": [
|
"top_overdue": [
|
||||||
{"content": t["content"], "task_age_days": t.get("task_age_days", 0)}
|
{
|
||||||
|
"content": t["content"],
|
||||||
|
"task_age_days": t.get("task_age_days", 0),
|
||||||
|
"project_id": t.get("project_id"),
|
||||||
|
"realness": _realness(t.get("project_id"), project_realness),
|
||||||
|
}
|
||||||
for t in top
|
for t in top
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
|||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from ml.agents.inference.history import FeedbackEvent, UserHistory
|
from ml.agents.inference.history import FeedbackEvent, TaskCompletion, UserHistory
|
||||||
from ml.agents.inference.framework import run_inference
|
from ml.agents.inference.framework import run_inference
|
||||||
from ml.agents.momentum import MomentumAgent, MANIFEST as MOMENTUM_MANIFEST
|
from ml.agents.momentum import MomentumAgent, MANIFEST as MOMENTUM_MANIFEST
|
||||||
from ml.agents.overdue_task import OverdueTaskAgent, MANIFEST as OVERDUE_MANIFEST
|
from ml.agents.overdue_task import OverdueTaskAgent, MANIFEST as OVERDUE_MANIFEST
|
||||||
@@ -32,8 +32,20 @@ def _event(action: str, days_ago: float = 1.0) -> FeedbackEvent:
|
|||||||
return FeedbackEvent(action=action, dwell_ms=dwell, created_at=ts)
|
return FeedbackEvent(action=action, dwell_ms=dwell, created_at=ts)
|
||||||
|
|
||||||
|
|
||||||
def _history(*events: FeedbackEvent) -> UserHistory:
|
def _history(*events: FeedbackEvent, completions: list[TaskCompletion] | None = None) -> UserHistory:
|
||||||
return UserHistory(user_id="u1", events=list(events))
|
return UserHistory(user_id="u1", events=list(events), task_completions=completions or [])
|
||||||
|
|
||||||
|
|
||||||
|
def _completion(project_id: str | None, lateness_days: float) -> TaskCompletion:
|
||||||
|
"""Build a TaskCompletion where completed_at is lateness_days after due_at."""
|
||||||
|
from datetime import timedelta
|
||||||
|
due = _NOW - timedelta(days=30)
|
||||||
|
completed = due + timedelta(days=lateness_days)
|
||||||
|
return TaskCompletion(
|
||||||
|
project_id=project_id,
|
||||||
|
completed_at=completed.isoformat(),
|
||||||
|
due_at=due.isoformat(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── momentum: engagement_trend ───────────────────────────────────────────────
|
# ── momentum: engagement_trend ───────────────────────────────────────────────
|
||||||
@@ -82,49 +94,94 @@ class TestMomentumInference:
|
|||||||
assert MOMENTUM_MANIFEST.version == "1.1.0"
|
assert MOMENTUM_MANIFEST.version == "1.1.0"
|
||||||
|
|
||||||
|
|
||||||
# ── overdue-task: lateness_tolerance_days ────────────────────────────────────
|
# ── overdue-task: lateness_tolerance_days + project_realness (#115) ──────────
|
||||||
|
|
||||||
class TestOverdueTaskInference:
|
class TestOverdueTaskInference:
|
||||||
def test_cold_start_returns_zero(self):
|
# -- lateness_tolerance_days inference --
|
||||||
history = _history(*[_event("done") for _ in range(5)])
|
|
||||||
result = run_inference(OVERDUE_MANIFEST, history)
|
|
||||||
assert result["lateness_tolerance_days"] == 0
|
|
||||||
|
|
||||||
def test_high_snooze_rate_returns_two(self):
|
def test_cold_start_returns_zero_when_few_completions(self):
|
||||||
events = [_event("snooze")] * 8 + [_event("done")] * 2
|
# Below min_history=10 task completions → cold start
|
||||||
history = _history(*events)
|
cs = [_completion("p1", 2.0) for _ in range(5)]
|
||||||
|
history = _history(*[_event("done")] * 5, completions=cs)
|
||||||
result = run_inference(OVERDUE_MANIFEST, history)
|
result = run_inference(OVERDUE_MANIFEST, history)
|
||||||
assert result["lateness_tolerance_days"] == 2
|
assert result["lateness_tolerance_days"] == 0.0
|
||||||
|
|
||||||
def test_moderate_snooze_returns_one(self):
|
def test_punctual_user_zero_tolerance(self):
|
||||||
events = [_event("snooze")] * 3 + [_event("done")] * 7
|
# User always finishes early or on time (negative lateness) → tolerance 0
|
||||||
history = _history(*events)
|
cs = [_completion("p1", -1.0) for _ in range(12)]
|
||||||
|
history = _history(*[_event("done")] * 12, completions=cs)
|
||||||
result = run_inference(OVERDUE_MANIFEST, history)
|
result = run_inference(OVERDUE_MANIFEST, history)
|
||||||
assert result["lateness_tolerance_days"] == 1
|
assert result["lateness_tolerance_days"] == 0.0
|
||||||
|
|
||||||
def test_low_snooze_returns_zero(self):
|
def test_chronic_late_user_positive_tolerance(self):
|
||||||
events = [_event("done")] * 9 + [_event("snooze")] * 1
|
# User consistently finishes 5 days late → p50 = 5
|
||||||
history = _history(*events)
|
cs = [_completion("p1", 5.0) for _ in range(12)]
|
||||||
|
history = _history(*[_event("done")] * 12, completions=cs)
|
||||||
result = run_inference(OVERDUE_MANIFEST, history)
|
result = run_inference(OVERDUE_MANIFEST, history)
|
||||||
assert result["lateness_tolerance_days"] == 0
|
assert result["lateness_tolerance_days"] == pytest.approx(5.0)
|
||||||
|
|
||||||
|
def test_mixed_lateness_uses_median(self):
|
||||||
|
# 6 tasks at +1d, 6 tasks at +3d → median = 2
|
||||||
|
cs = [_completion("p1", 1.0)] * 6 + [_completion("p1", 3.0)] * 6
|
||||||
|
history = _history(*[_event("done")] * 12, completions=cs)
|
||||||
|
result = run_inference(OVERDUE_MANIFEST, history)
|
||||||
|
assert result["lateness_tolerance_days"] == pytest.approx(2.0)
|
||||||
|
|
||||||
|
# -- project_realness inference --
|
||||||
|
|
||||||
|
def test_project_realness_cold_start_empty(self):
|
||||||
|
cs = [_completion("p1", 1.0) for _ in range(5)] # below min_history
|
||||||
|
history = _history(*[_event("done")] * 5, completions=cs)
|
||||||
|
result = run_inference(OVERDUE_MANIFEST, history)
|
||||||
|
assert result["project_realness"] == {}
|
||||||
|
|
||||||
|
def test_project_realness_punctual_project_scores_high(self):
|
||||||
|
# p1 always on time (0d late), p2 always 10d late → p1 should be realness ≈ 1
|
||||||
|
cs = [_completion("p1", 0.0)] * 6 + [_completion("p2", 10.0)] * 6
|
||||||
|
history = _history(*[_event("done")] * 12, completions=cs)
|
||||||
|
result = run_inference(OVERDUE_MANIFEST, history)
|
||||||
|
assert result["project_realness"]["p1"] > result["project_realness"]["p2"]
|
||||||
|
|
||||||
|
def test_project_realness_values_clipped_01(self):
|
||||||
|
cs = [_completion("p1", 0.0)] * 6 + [_completion("p2", 100.0)] * 6
|
||||||
|
history = _history(*[_event("done")] * 12, completions=cs)
|
||||||
|
result = run_inference(OVERDUE_MANIFEST, history)
|
||||||
|
for v in result["project_realness"].values():
|
||||||
|
assert 0.0 <= v <= 1.0
|
||||||
|
|
||||||
|
# -- compute() reads inferred prefs --
|
||||||
|
|
||||||
def test_tolerance_filters_tasks(self):
|
def test_tolerance_filters_tasks(self):
|
||||||
tasks = [
|
tasks = [
|
||||||
{"content": "Fresh overdue", "is_overdue": True, "task_age_days": 0.5},
|
{"content": "Fresh overdue", "is_overdue": True, "task_age_days": 0.5},
|
||||||
{"content": "Old overdue", "is_overdue": True, "task_age_days": 3.0},
|
{"content": "Old overdue", "is_overdue": True, "task_age_days": 3.0},
|
||||||
]
|
]
|
||||||
# tolerance=2 → only the 3-day task should count
|
|
||||||
out = OverdueTaskAgent().compute(_inp(tasks=tasks, agent_prefs={"lateness_tolerance_days": 2}))
|
out = OverdueTaskAgent().compute(_inp(tasks=tasks, agent_prefs={"lateness_tolerance_days": 2}))
|
||||||
assert "1 overdue task" in out.prompt_text
|
assert "1 overdue task" in out.prompt_text
|
||||||
assert "Old overdue" in out.prompt_text
|
assert "Old overdue" in out.prompt_text
|
||||||
|
|
||||||
def test_snapshot_includes_tolerance(self):
|
def test_low_realness_softens_language(self):
|
||||||
tasks = [{"content": "T", "is_overdue": True, "task_age_days": 1.0}]
|
tasks = [{"content": "Wishlist", "is_overdue": True, "task_age_days": 3.0,
|
||||||
out = OverdueTaskAgent().compute(_inp(tasks=tasks, agent_prefs={"lateness_tolerance_days": 0}))
|
"project_id": "aspirational"}]
|
||||||
assert "lateness_tolerance_days" in out.signals_snapshot
|
prefs = {"lateness_tolerance_days": 0, "project_realness": {"aspirational": 0.2}}
|
||||||
|
out = OverdueTaskAgent().compute(_inp(tasks=tasks, agent_prefs=prefs))
|
||||||
|
assert "target date" in out.prompt_text
|
||||||
|
|
||||||
|
def test_high_realness_uses_overdue_language(self):
|
||||||
|
tasks = [{"content": "Critical", "is_overdue": True, "task_age_days": 3.0,
|
||||||
|
"project_id": "work"}]
|
||||||
|
prefs = {"lateness_tolerance_days": 0, "project_realness": {"work": 0.9}}
|
||||||
|
out = OverdueTaskAgent().compute(_inp(tasks=tasks, agent_prefs=prefs))
|
||||||
|
assert "overdue" in out.prompt_text
|
||||||
|
|
||||||
|
def test_snapshot_includes_realness(self):
|
||||||
|
tasks = [{"content": "T", "is_overdue": True, "task_age_days": 1.0, "project_id": "p1"}]
|
||||||
|
prefs = {"lateness_tolerance_days": 0, "project_realness": {"p1": 0.8}}
|
||||||
|
out = OverdueTaskAgent().compute(_inp(tasks=tasks, agent_prefs=prefs))
|
||||||
|
assert "realness" in out.signals_snapshot["top_overdue"][0]
|
||||||
|
|
||||||
def test_version_bumped(self):
|
def test_version_bumped(self):
|
||||||
assert OVERDUE_MANIFEST.version == "1.1.0"
|
assert OVERDUE_MANIFEST.version == "1.2.0"
|
||||||
|
|
||||||
|
|
||||||
# ── recent-patterns: window_days ─────────────────────────────────────────────
|
# ── recent-patterns: window_days ─────────────────────────────────────────────
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ if _repo_root not in sys.path:
|
|||||||
|
|
||||||
from ml.agents.base import AgentInput # noqa: E402
|
from ml.agents.base import AgentInput # noqa: E402
|
||||||
from ml.agents.registry import get_agent, all_agents, all_manifests, get_manifest # noqa: E402
|
from ml.agents.registry import get_agent, all_agents, all_manifests, get_manifest # noqa: E402
|
||||||
from ml.agents.inference import run_inference, FeedbackEvent, UserHistory # noqa: E402
|
from ml.agents.inference import run_inference, FeedbackEvent, TaskCompletion, UserHistory # noqa: E402
|
||||||
|
|
||||||
logging_config.configure()
|
logging_config.configure()
|
||||||
|
|
||||||
@@ -141,7 +141,8 @@ class AgentComputeResponse(BaseModel):
|
|||||||
|
|
||||||
class AgentInferRequest(BaseModel):
|
class AgentInferRequest(BaseModel):
|
||||||
user_id: str
|
user_id: str
|
||||||
feedback_history: list[dict] = [] # [{action, dwell_ms, created_at}, …]
|
feedback_history: list[dict] = [] # [{action, dwell_ms, created_at}, …]
|
||||||
|
task_completions: list[dict] = [] # [{project_id, completed_at, due_at}, …]
|
||||||
|
|
||||||
|
|
||||||
class AgentInferResponse(BaseModel):
|
class AgentInferResponse(BaseModel):
|
||||||
@@ -284,7 +285,15 @@ async def infer_agent(agent_id: str, req: AgentInferRequest) -> AgentInferRespon
|
|||||||
)
|
)
|
||||||
for e in req.feedback_history
|
for e in req.feedback_history
|
||||||
]
|
]
|
||||||
history = UserHistory(user_id=req.user_id, events=events)
|
completions = [
|
||||||
|
TaskCompletion(
|
||||||
|
project_id=c.get("project_id"),
|
||||||
|
completed_at=c.get("completed_at", ""),
|
||||||
|
due_at=c.get("due_at", ""),
|
||||||
|
)
|
||||||
|
for c in req.task_completions
|
||||||
|
]
|
||||||
|
history = UserHistory(user_id=req.user_id, events=events, task_completions=completions)
|
||||||
|
|
||||||
t0 = __import__("time").monotonic()
|
t0 = __import__("time").monotonic()
|
||||||
inferred = run_inference(manifest, history)
|
inferred = run_inference(manifest, history)
|
||||||
|
|||||||
Reference in New Issue
Block a user