feat(clustering): persistent enrichment cache in task_enrichments table

Each unique task title is now enriched by LiteLLM once and cached in the DB.
Subsequent agent compute cycles (every 12h) fetch the cache before calling
ml-serving; only new titles hit the tip-generator.

- DB: task_enrichments(content_hash PK, description, model, created_at)
- TS: fetchEnrichmentCache / persistEnrichments helpers in agent-outputs.ts;
  enrichment_cache passed in compute request, new_enrichments persisted from response
- Python: AgentComputeRequest.enrichment_cache / AgentComputeResponse.new_enrichments;
  AgentInput.enrichment_cache; _enrich_batch returns (descriptions, new_entries);
  cluster_tasks returns (clusters, new_enrichments)
- FocusAreaAgent stashes new_enrichments in signals_snapshot under _new_enrichments;
  compute_agent endpoint pops it before storing the snapshot

Closes part of #129

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-12 14:39:35 +00:00
parent 08d08ad7b0
commit 9ddeea6cac
9 changed files with 158 additions and 40 deletions

View File

@@ -20,6 +20,9 @@ class AgentInput:
# precedence over 'inferred' source; the caller resolves priority before
# passing this dict in.
agent_prefs: dict = field(default_factory=dict)
# Pre-fetched enrichment cache: {content_hash -> description}. Populated by
# the TS caller from the task_enrichments DB table to avoid redundant LLM calls.
enrichment_cache: dict = field(default_factory=dict)
@dataclass

View File

@@ -87,26 +87,41 @@ def _enrich_title(title: str, litellm_url: str) -> str | None:
return None
def _enrich_batch(titles: list[str]) -> list[str]:
"""Return enriched descriptions for each title; falls back to raw title on failure.
def _enrich_batch(
titles: list[str],
persistent_cache: dict[str, str] | None = None,
) -> tuple[list[str], dict[str, str]]:
"""Return (descriptions, new_entries) for each title.
Results are cached in-memory by content hash so duplicate titles within
a single compute() call cost only one LLM round-trip.
Checks persistent_cache (pre-fetched from DB) first, then falls back to
calling LiteLLM. new_entries contains only hashes generated this call —
the caller should persist these to the DB.
"""
litellm_url = os.getenv("LITELLM_URL")
if not litellm_url:
log.debug("enrich_batch: no LITELLM_URL, skipping enrichment")
return titles
return titles, {}
cache: dict[str, str] = {}
db_cache = persistent_cache or {}
session_cache: dict[str, str] = {} # dedup within this call
new_entries: dict[str, str] = {}
results = []
for title in titles:
h = _content_hash(title)
if h not in cache:
if h in db_cache:
results.append(db_cache[h])
elif h in session_cache:
results.append(session_cache[h])
else:
desc = _enrich_title(title, litellm_url)
cache[h] = desc if desc else title
results.append(cache[h])
return results
value = desc if desc else title
session_cache[h] = value
if desc: # only persist successful enrichments
new_entries[h] = desc
results.append(value)
return results, new_entries
# ---------------------------------------------------------------------------
@@ -227,14 +242,17 @@ def _fallback_by_project(tasks: list[dict]) -> list[Cluster]:
def cluster_tasks(
tasks: list[dict],
ollama_url: str | None = None, # kept for test compatibility; env vars take precedence
) -> list[Cluster]:
enrichment_cache: dict[str, str] | None = None,
) -> tuple[list[Cluster], dict[str, str]]:
"""Cluster tasks by semantic similarity.
Returns a non-empty list of Cluster objects. Falls back to project-based
grouping if the embedding service is unavailable or tasks have no content.
Returns (clusters, new_enrichments). new_enrichments contains LLM-generated
descriptions produced this call that were not in the persistent cache — the
caller should persist these. Falls back to project-based grouping if the
embedding service is unavailable or tasks have no content.
"""
if not tasks:
return []
return [], {}
# Separate tasks with usable content from those without.
with_content = [(t, t.get("content", "").strip()) for t in tasks]
@@ -242,13 +260,13 @@ def cluster_tasks(
no_content = [t for t, c in with_content if not c]
if not embeddable:
return _fallback_by_project(tasks)
return _fallback_by_project(tasks), {}
task_objs = [t for t, _ in embeddable]
raw_titles = [c for _, c in embeddable]
# Step 1: LLM-enrich titles → richer semantic signal before embedding.
descriptions = _enrich_batch(raw_titles)
descriptions, new_enrichments = _enrich_batch(raw_titles, persistent_cache=enrichment_cache)
# Step 2: Prefix with nomic-embed-text task prefix, then batch-embed.
prefixed = [f"clustering: {d}" for d in descriptions]
@@ -256,7 +274,7 @@ def cluster_tasks(
if vecs is None or len(vecs) != len(prefixed):
log.info("cluster_tasks: embedding unavailable, falling back to project grouping")
return _fallback_by_project(tasks)
return _fallback_by_project(tasks), new_enrichments
embedded = list(zip(task_objs, vecs))
clusters = _greedy_cluster(embedded)
@@ -264,4 +282,4 @@ def cluster_tasks(
if no_content:
clusters.append(Cluster(label="Other tasks", tasks=no_content))
return clusters
return clusters, new_enrichments

View File

@@ -35,7 +35,7 @@ MANIFEST = AgentManifest(
},
},
context_schema=["todoist.tasks"],
required_consents=["data:core", "data:todoist", "agent:focus-area"],
required_consents=["data:core", "data:todoist"],
output_contract={"type": "snippet", "format": "free_text"},
ttl_sec=43_200,
inferred_params=[
@@ -66,7 +66,7 @@ class FocusAreaAgent(BaseAgent):
{"cluster_count": 0, "strategy": "none"},
)
clusters = cluster_tasks(inp.tasks)
clusters, new_enrichments = cluster_tasks(inp.tasks, enrichment_cache=inp.enrichment_cache)
if not clusters:
return self._make_output(
@@ -109,5 +109,7 @@ class FocusAreaAgent(BaseAgent):
"cluster_count": len(clusters),
"strategy": strategy,
"preferred_areas": preferred,
# Consumed by compute_agent endpoint; stripped before storing the snapshot.
"_new_enrichments": new_enrichments,
}
return self._make_output(inp, " ".join(parts), snapshot)

View File

@@ -245,8 +245,9 @@ class TestFocusAreaAgent:
def test_snapshot_keys(self):
out = self.agent.compute(_inp(tasks=[_task("T1", project_id="A")]))
public_keys = {k for k in out.signals_snapshot if not k.startswith("_")}
assert {"top_cluster_label", "top_task_count", "top_overdue_count", "cluster_count",
"strategy", "preferred_areas"} == set(out.signals_snapshot)
"strategy", "preferred_areas"} == public_keys
# ── Registry ─────────────────────────────────────────────────────────────────

View File

@@ -87,20 +87,22 @@ class TestGreedyClustering:
class TestEnrichBatch:
def test_falls_back_to_raw_when_no_litellm_url(self, monkeypatch):
monkeypatch.delenv("LITELLM_URL", raising=False)
result = _enrich_batch(["Buy milk", "Fix bug"])
assert result == ["Buy milk", "Fix bug"]
result, new = _enrich_batch(["Buy milk", "Fix bug"])
assert result == ["Buy milk", "Fix bug"] and new == {}
def test_uses_description_when_litellm_available(self, monkeypatch):
monkeypatch.setenv("LITELLM_URL", "http://fake-litellm")
with patch("ml.agents.clustering._enrich_title", return_value="Expanded description."):
result = _enrich_batch(["Buy milk"])
result, new = _enrich_batch(["Buy milk"])
assert result == ["Expanded description."]
assert len(new) == 1
def test_falls_back_to_raw_title_on_enrich_failure(self, monkeypatch):
monkeypatch.setenv("LITELLM_URL", "http://fake-litellm")
with patch("ml.agents.clustering._enrich_title", return_value=None):
result = _enrich_batch(["Buy milk"])
result, new = _enrich_batch(["Buy milk"])
assert result == ["Buy milk"]
assert new == {} # failed enrichments are not persisted
def test_deduplicates_identical_titles(self, monkeypatch):
monkeypatch.setenv("LITELLM_URL", "http://fake-litellm")
@@ -109,26 +111,40 @@ class TestEnrichBatch:
call_count["n"] += 1
return f"desc:{title}"
with patch("ml.agents.clustering._enrich_title", side_effect=fake_enrich):
result = _enrich_batch(["Buy milk", "Buy milk", "Fix bug"])
result, new = _enrich_batch(["Buy milk", "Buy milk", "Fix bug"])
assert call_count["n"] == 2 # only 2 unique titles
assert result == ["desc:Buy milk", "desc:Buy milk", "desc:Fix bug"]
def test_uses_persistent_cache(self, monkeypatch):
monkeypatch.setenv("LITELLM_URL", "http://fake-litellm")
from ml.agents.clustering import _content_hash
h = _content_hash("Buy milk")
call_count = {"n": 0}
def fake_enrich(title, url):
call_count["n"] += 1
return "new desc"
with patch("ml.agents.clustering._enrich_title", side_effect=fake_enrich):
result, new = _enrich_batch(["Buy milk"], persistent_cache={h: "cached desc"})
assert call_count["n"] == 0 # cache hit, no LLM call
assert result == ["cached desc"]
assert new == {}
# ── cluster_tasks integration ─────────────────────────────────────────────────
class TestClusterTasks:
def _no_enrich(self, titles):
return titles # pass-through; enrichment tested separately
def _no_enrich(self, titles, persistent_cache=None):
return titles, {}
def test_empty_tasks(self):
result = cluster_tasks([])
assert result == []
clusters, new = cluster_tasks([])
assert clusters == [] and new == {}
def test_fallback_when_embed_unavailable(self):
with patch("ml.agents.clustering._enrich_batch", side_effect=self._no_enrich), \
patch("ml.agents.clustering._embed_batch", return_value=None):
tasks = [_task("A", "p1"), _task("B", "p2"), _task("C", "p1")]
clusters = cluster_tasks(tasks)
clusters, _ = cluster_tasks(tasks)
assert len(clusters) == 2
labels = {c.label for c in clusters}
assert "p1" in labels and "p2" in labels
@@ -137,7 +153,7 @@ class TestClusterTasks:
with patch("ml.agents.clustering._enrich_batch", side_effect=self._no_enrich), \
patch("ml.agents.clustering._embed_batch", return_value=None):
tasks = [_task("A", "work")] * 3 + [_task("B", "home")] * 2
clusters = cluster_tasks(tasks)
clusters, _ = cluster_tasks(tasks)
by_label = {c.label: c.task_count for c in clusters}
assert by_label["work"] == 3
assert by_label["home"] == 2
@@ -147,7 +163,7 @@ class TestClusterTasks:
with patch("ml.agents.clustering._enrich_batch", side_effect=self._no_enrich), \
patch("ml.agents.clustering._embed_batch", return_value=[v]):
tasks = [_task("Has content"), {"is_overdue": False}]
clusters = cluster_tasks(tasks)
clusters, _ = cluster_tasks(tasks)
labels = {c.label for c in clusters}
assert "Other tasks" in labels
@@ -163,15 +179,15 @@ class TestClusterTasks:
_task("Buy groceries"),
_task("Cook dinner"),
]
clusters = cluster_tasks(tasks)
clusters, _ = cluster_tasks(tasks)
assert len(clusters) == 2
assert all(c.task_count == 2 for c in clusters)
def test_all_tasks_no_content_fallback_by_project(self):
tasks = [{"project_id": "p1", "is_overdue": False},
{"project_id": "p2", "is_overdue": False}]
clusters = cluster_tasks(tasks)
assert len(clusters) == 2
clusters, new = cluster_tasks(tasks)
assert len(clusters) == 2 and new == {}
def test_enrich_called_before_embed(self):
"""Verify enrichment output (not raw title) is what gets embedded."""
@@ -180,7 +196,14 @@ class TestClusterTasks:
def fake_embed(texts):
captured["texts"] = texts
return [v] * len(texts)
with patch("ml.agents.clustering._enrich_batch", return_value=["Expanded desc."]), \
with patch("ml.agents.clustering._enrich_batch", return_value=(["Expanded desc."], {})), \
patch("ml.agents.clustering._embed_batch", side_effect=fake_embed):
cluster_tasks([_task("Buy milk")])
assert captured["texts"] == ["clustering: Expanded desc."]
def test_new_enrichments_returned(self):
v = [1.0, 0.0]
with patch("ml.agents.clustering._enrich_batch", return_value=(["desc"], {"abc123": "desc"})), \
patch("ml.agents.clustering._embed_batch", return_value=[v]):
_, new = cluster_tasks([_task("Buy milk")])
assert new == {"abc123": "desc"}