feat(clustering): LLM-enrichment before embedding (port from taskpile #129)
Ported from taskpile experiments/clustering_eval (prompt v1, qwen2.5:1.5b). The experiment showed ARI 0.22→0.77 and AUROC 0.76→0.91 on synthetic tasks when embedding LLM-expanded descriptions instead of raw titles. - Expand each task title via LiteLLM tip-generator before embedding - Prefix with "clustering: " (nomic-embed-text task instruction prefix) - Cache expansions in-memory by content hash within a compute cycle - Falls back to raw title if enrichment fails; no change to fallback behaviour Fixes #129 Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
"""Semantic task clustering via nomic-embed-text (issue #97).
|
||||
"""Semantic task clustering via nomic-embed-text (issue #97, #129).
|
||||
|
||||
Public API:
|
||||
cluster_tasks(tasks) -> list[Cluster]
|
||||
@@ -7,12 +7,18 @@ Each task dict must have a "content" key. Tasks without content are placed in a
|
||||
fallback "other" bucket. If the embedding service is unreachable, falls back to
|
||||
grouping by project_id so compute() always returns something useful.
|
||||
|
||||
Embeddings are fetched via LiteLLM's OpenAI-compatible /embeddings endpoint
|
||||
(model alias "embedder") when LITELLM_URL is set. Falls back to the Ollama
|
||||
/api/embed endpoint when only OLLAMA_URL is available (local dev without LiteLLM).
|
||||
Pipeline (ported from taskpile experiments/clustering_eval, prompt v1):
|
||||
1. Expand each raw title via LiteLLM `tip-generator` (qwen2.5:1.5b) into a
|
||||
3-sentence description. Cached in-memory by content hash within a compute
|
||||
cycle so duplicate titles cost one LLM call.
|
||||
2. Prefix the expanded text with "clustering: " (nomic-embed-text task prefix).
|
||||
3. Batch-embed via LiteLLM `embedder` (nomic-embed-text).
|
||||
Falls back to embedding raw titles when LLM expansion fails, and to
|
||||
project-based grouping when embeddings are unavailable.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
@@ -27,6 +33,16 @@ _SIM_THRESHOLD = 0.72
|
||||
# Never produce more than this many clusters regardless of task count.
|
||||
_MAX_CLUSTERS = 6
|
||||
_EMBED_TIMEOUT = 15.0
|
||||
_ENRICH_TIMEOUT = 30.0
|
||||
|
||||
_ENRICH_PROMPT_V1 = (
|
||||
"You are helping categorize a personal task. "
|
||||
"Write exactly 3 sentences in English describing what the task likely involves, "
|
||||
"what context or skills it needs, and why it might matter. "
|
||||
"Be concise and specific. Do not use bullet points or numbering.\n"
|
||||
"Task: {title}\n"
|
||||
"Description:"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -43,6 +59,60 @@ class Cluster:
|
||||
return sum(1 for t in self.tasks if t.get("is_overdue"))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLM enrichment
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _content_hash(text: str) -> str:
|
||||
return hashlib.md5(text.encode()).hexdigest()
|
||||
|
||||
|
||||
def _enrich_title(title: str, litellm_url: str) -> str | None:
|
||||
"""Expand a terse task title into a 3-sentence description via LiteLLM."""
|
||||
try:
|
||||
with httpx.Client(trust_env=False, timeout=_ENRICH_TIMEOUT) as c:
|
||||
r = c.post(
|
||||
f"{litellm_url}/chat/completions",
|
||||
json={
|
||||
"model": "tip-generator",
|
||||
"messages": [{"role": "user", "content": _ENRICH_PROMPT_V1.format(title=title)}],
|
||||
"max_tokens": 120,
|
||||
"temperature": 0.3,
|
||||
},
|
||||
)
|
||||
r.raise_for_status()
|
||||
return r.json()["choices"][0]["message"]["content"].strip()
|
||||
except Exception as exc:
|
||||
log.debug("enrich_failed title=%r error=%s", title[:40], exc)
|
||||
return None
|
||||
|
||||
|
||||
def _enrich_batch(titles: list[str]) -> list[str]:
|
||||
"""Return enriched descriptions for each title; falls back to raw title on failure.
|
||||
|
||||
Results are cached in-memory by content hash so duplicate titles within
|
||||
a single compute() call cost only one LLM round-trip.
|
||||
"""
|
||||
litellm_url = os.getenv("LITELLM_URL")
|
||||
if not litellm_url:
|
||||
log.debug("enrich_batch: no LITELLM_URL, skipping enrichment")
|
||||
return titles
|
||||
|
||||
cache: dict[str, str] = {}
|
||||
results = []
|
||||
for title in titles:
|
||||
h = _content_hash(title)
|
||||
if h not in cache:
|
||||
desc = _enrich_title(title, litellm_url)
|
||||
cache[h] = desc if desc else title
|
||||
results.append(cache[h])
|
||||
return results
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Embedding
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _embed_via_litellm(texts: list[str], litellm_url: str) -> list[list[float]] | None:
|
||||
"""Batch embed via LiteLLM OpenAI-compatible /embeddings endpoint."""
|
||||
try:
|
||||
@@ -96,6 +166,10 @@ def _embed_batch(texts: list[str]) -> list[list[float]] | None:
|
||||
return _embed_via_ollama(texts, ollama_url)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Clustering
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _cosine(a: list[float], b: list[float]) -> float:
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
na = math.sqrt(sum(x * x for x in a))
|
||||
@@ -170,20 +244,23 @@ def cluster_tasks(
|
||||
if not embeddable:
|
||||
return _fallback_by_project(tasks)
|
||||
|
||||
# Batch-embed all task contents in one call.
|
||||
task_objs = [t for t, _ in embeddable]
|
||||
contents = [c for _, c in embeddable]
|
||||
vecs = _embed_batch(contents)
|
||||
raw_titles = [c for _, c in embeddable]
|
||||
|
||||
if vecs is None or len(vecs) != len(contents):
|
||||
# Step 1: LLM-enrich titles → richer semantic signal before embedding.
|
||||
descriptions = _enrich_batch(raw_titles)
|
||||
|
||||
# Step 2: Prefix with nomic-embed-text task prefix, then batch-embed.
|
||||
prefixed = [f"clustering: {d}" for d in descriptions]
|
||||
vecs = _embed_batch(prefixed)
|
||||
|
||||
if vecs is None or len(vecs) != len(prefixed):
|
||||
log.info("cluster_tasks: embedding unavailable, falling back to project grouping")
|
||||
return _fallback_by_project(tasks)
|
||||
|
||||
embedded = list(zip(task_objs, vecs))
|
||||
|
||||
clusters = _greedy_cluster(embedded)
|
||||
|
||||
# Tasks without content get their own bucket if any.
|
||||
if no_content:
|
||||
clusters.append(Cluster(label="Other tasks", tasks=no_content))
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Unit tests for ml.agents.clustering (issue #97).
|
||||
"""Unit tests for ml.agents.clustering (issue #97, #129).
|
||||
|
||||
Embedding calls are mocked so tests run without Ollama.
|
||||
LLM and embedding calls are mocked so tests run without Ollama or LiteLLM.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -9,7 +9,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from ml.agents.clustering import cluster_tasks, Cluster, _greedy_cluster, _cosine, _embed_batch
|
||||
from ml.agents.clustering import cluster_tasks, Cluster, _greedy_cluster, _cosine, _embed_batch, _enrich_batch
|
||||
|
||||
|
||||
# ── helpers ──────────────────────────────────────────────────────────────────
|
||||
@@ -82,15 +82,51 @@ class TestGreedyClustering:
|
||||
assert clusters[0].label == "Write report"
|
||||
|
||||
|
||||
# ── enrichment ───────────────────────────────────────────────────────────────
|
||||
|
||||
class TestEnrichBatch:
|
||||
def test_falls_back_to_raw_when_no_litellm_url(self, monkeypatch):
|
||||
monkeypatch.delenv("LITELLM_URL", raising=False)
|
||||
result = _enrich_batch(["Buy milk", "Fix bug"])
|
||||
assert result == ["Buy milk", "Fix bug"]
|
||||
|
||||
def test_uses_description_when_litellm_available(self, monkeypatch):
|
||||
monkeypatch.setenv("LITELLM_URL", "http://fake-litellm")
|
||||
with patch("ml.agents.clustering._enrich_title", return_value="Expanded description."):
|
||||
result = _enrich_batch(["Buy milk"])
|
||||
assert result == ["Expanded description."]
|
||||
|
||||
def test_falls_back_to_raw_title_on_enrich_failure(self, monkeypatch):
|
||||
monkeypatch.setenv("LITELLM_URL", "http://fake-litellm")
|
||||
with patch("ml.agents.clustering._enrich_title", return_value=None):
|
||||
result = _enrich_batch(["Buy milk"])
|
||||
assert result == ["Buy milk"]
|
||||
|
||||
def test_deduplicates_identical_titles(self, monkeypatch):
|
||||
monkeypatch.setenv("LITELLM_URL", "http://fake-litellm")
|
||||
call_count = {"n": 0}
|
||||
def fake_enrich(title, url):
|
||||
call_count["n"] += 1
|
||||
return f"desc:{title}"
|
||||
with patch("ml.agents.clustering._enrich_title", side_effect=fake_enrich):
|
||||
result = _enrich_batch(["Buy milk", "Buy milk", "Fix bug"])
|
||||
assert call_count["n"] == 2 # only 2 unique titles
|
||||
assert result == ["desc:Buy milk", "desc:Buy milk", "desc:Fix bug"]
|
||||
|
||||
|
||||
# ── cluster_tasks integration ─────────────────────────────────────────────────
|
||||
|
||||
class TestClusterTasks:
|
||||
def _no_enrich(self, titles):
|
||||
return titles # pass-through; enrichment tested separately
|
||||
|
||||
def test_empty_tasks(self):
|
||||
result = cluster_tasks([])
|
||||
assert result == []
|
||||
|
||||
def test_fallback_when_embed_unavailable(self):
|
||||
with patch("ml.agents.clustering._embed_batch", return_value=None):
|
||||
with patch("ml.agents.clustering._enrich_batch", side_effect=self._no_enrich), \
|
||||
patch("ml.agents.clustering._embed_batch", return_value=None):
|
||||
tasks = [_task("A", "p1"), _task("B", "p2"), _task("C", "p1")]
|
||||
clusters = cluster_tasks(tasks)
|
||||
assert len(clusters) == 2
|
||||
@@ -98,7 +134,8 @@ class TestClusterTasks:
|
||||
assert "p1" in labels and "p2" in labels
|
||||
|
||||
def test_fallback_groups_by_project(self):
|
||||
with patch("ml.agents.clustering._embed_batch", return_value=None):
|
||||
with patch("ml.agents.clustering._enrich_batch", side_effect=self._no_enrich), \
|
||||
patch("ml.agents.clustering._embed_batch", return_value=None):
|
||||
tasks = [_task("A", "work")] * 3 + [_task("B", "home")] * 2
|
||||
clusters = cluster_tasks(tasks)
|
||||
by_label = {c.label: c.task_count for c in clusters}
|
||||
@@ -107,7 +144,8 @@ class TestClusterTasks:
|
||||
|
||||
def test_tasks_without_content_go_to_other(self):
|
||||
v = [1.0, 0.0]
|
||||
with patch("ml.agents.clustering._embed_batch", return_value=[v]):
|
||||
with patch("ml.agents.clustering._enrich_batch", side_effect=self._no_enrich), \
|
||||
patch("ml.agents.clustering._embed_batch", return_value=[v]):
|
||||
tasks = [_task("Has content"), {"is_overdue": False}]
|
||||
clusters = cluster_tasks(tasks)
|
||||
labels = {c.label for c in clusters}
|
||||
@@ -117,7 +155,8 @@ class TestClusterTasks:
|
||||
v_work = [1.0, 0.0, 0.0]
|
||||
v_home = [0.0, 1.0, 0.0]
|
||||
batch_result = [v_work, v_work, v_home, v_home]
|
||||
with patch("ml.agents.clustering._embed_batch", return_value=batch_result):
|
||||
with patch("ml.agents.clustering._enrich_batch", side_effect=self._no_enrich), \
|
||||
patch("ml.agents.clustering._embed_batch", return_value=batch_result):
|
||||
tasks = [
|
||||
_task("Write report"),
|
||||
_task("Review PR"),
|
||||
@@ -133,3 +172,15 @@ class TestClusterTasks:
|
||||
{"project_id": "p2", "is_overdue": False}]
|
||||
clusters = cluster_tasks(tasks)
|
||||
assert len(clusters) == 2
|
||||
|
||||
def test_enrich_called_before_embed(self):
|
||||
"""Verify enrichment output (not raw title) is what gets embedded."""
|
||||
v = [1.0, 0.0]
|
||||
captured = {}
|
||||
def fake_embed(texts):
|
||||
captured["texts"] = texts
|
||||
return [v] * len(texts)
|
||||
with patch("ml.agents.clustering._enrich_batch", return_value=["Expanded desc."]), \
|
||||
patch("ml.agents.clustering._embed_batch", side_effect=fake_embed):
|
||||
cluster_tasks([_task("Buy milk")])
|
||||
assert captured["texts"] == ["clustering: Expanded desc."]
|
||||
|
||||
Reference in New Issue
Block a user