Files
oO/ml/agents/clustering.py
alvis 1ca2351488 fix(clustering): route embeddings through LiteLLM instead of Ollama directly
The old code called Ollama's /api/embeddings one task at a time, which caused
silent fallback to project-based grouping when host.docker.internal:11434 was
unreachable from the ml-serving container.

- Switch to LiteLLM /embeddings (model alias "embedder") as primary path
- Batch all task contents in one request instead of N serial calls
- Fall back to Ollama /api/embed (updated to current API) when LITELLM_URL is absent
- Update tests to mock _embed_batch instead of the removed _embed

Fixes #123

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 13:42:53 +00:00

191 lines
6.7 KiB
Python

"""Semantic task clustering via text embeddings (issue #97).

Public API:
    cluster_tasks(tasks) -> list[Cluster]

Each task dict is expected to carry a "content" string. Tasks whose content is
missing or blank are not embedded; they are collected into a separate
"Other tasks" bucket. If the embedding service is unreachable (or returns an
unusable response), all tasks are grouped by project_id instead, so a non-empty
task list always yields clusters.

Embeddings are fetched via LiteLLM's OpenAI-compatible /embeddings endpoint
(model alias "embedder") when LITELLM_URL is set. If LITELLM_URL is unset, or
the LiteLLM request fails, the Ollama /api/embed endpoint (model
"nomic-embed-text") at OLLAMA_URL is used instead, e.g. for local dev without
LiteLLM.
"""
from __future__ import annotations
import logging
import math
import os
from dataclasses import dataclass, field
import httpx
log = logging.getLogger(__name__)

# --- Tuning knobs -------------------------------------------------------------
# Timeout (seconds) handed to the httpx clients used by the embedding helpers.
_EMBED_TIMEOUT: float = 15.0
# Minimum cosine similarity for a task to be merged into an existing cluster.
_SIM_THRESHOLD: float = 0.72
# Upper bound on the number of semantic clusters, however many tasks there are.
_MAX_CLUSTERS: int = 6
@dataclass
class Cluster:
    """A named group of related tasks produced by the clustering helpers."""

    # Display label. Depending on the code path, it is one of:
    #   - embedding path: the truncated content of the task that founded the cluster
    #   - project fallback: the project key (project_id / project), or "Tasks"
    #   - content-less tasks: "Other tasks"
    label: str
    # Task dicts assigned to this cluster, in insertion order.
    tasks: list[dict] = field(default_factory=list)

    @property
    def task_count(self) -> int:
        """Number of tasks in the cluster."""
        return len(self.tasks)

    @property
    def overdue_count(self) -> int:
        """Number of tasks whose "is_overdue" value is truthy."""
        return sum(1 for t in self.tasks if t.get("is_overdue"))
def _embed_via_litellm(texts: list[str], litellm_url: str) -> list[list[float]] | None:
    """Batch-embed *texts* with a single call to LiteLLM's /embeddings endpoint.

    Args:
        texts: Strings to embed; all are sent together in one request.
        litellm_url: Base URL of the LiteLLM proxy (a trailing slash is tolerated).

    Returns:
        One embedding per input text, in input order, or ``None`` when the
        request fails or the response does not contain exactly one vector per
        text. Returning ``None`` (rather than a short or empty list) is what lets
        ``_embed_batch`` fall back to Ollama.
    """
    if not texts:
        return []
    url = f"{litellm_url.rstrip('/')}/embeddings"
    try:
        with httpx.Client(trust_env=False, timeout=_EMBED_TIMEOUT) as c:
            r = c.post(url, json={"model": "embedder", "input": texts})
            r.raise_for_status()
            data = r.json().get("data") or []
        # Each OpenAI-style item carries the "index" of its input; sort on it
        # rather than trusting the order of the response list.
        ordered = sorted(data, key=lambda item: item["index"])
        vecs = [item["embedding"] for item in ordered]
    except Exception as exc:  # best-effort: any failure means "try the next backend"
        log.warning("litellm_embed_failed error=%s", exc)
        return None
    if len(vecs) != len(texts):
        log.warning("litellm_embed_bad_count expected=%d got=%d", len(texts), len(vecs))
        return None
    return vecs
def _embed_via_ollama(texts: list[str], ollama_url: str) -> list[list[float]] | None:
    """Batch-embed *texts* with a single call to Ollama's /api/embed endpoint.

    /api/embed accepts a list under "input" and answers
    ``{"embeddings": [[...], ...]}`` with one vector per input, in order. So all
    texts go out in one request instead of one request per text.

    Args:
        texts: Strings to embed with the "nomic-embed-text" model.
        ollama_url: Base URL of the Ollama server (a trailing slash is tolerated).

    Returns:
        One embedding per input text, in input order, or ``None`` when the
        request fails or the response does not contain exactly one vector per text.
    """
    if not texts:
        return []
    url = f"{ollama_url.rstrip('/')}/api/embed"
    # NOTE(review): the whole batch must now complete within _EMBED_TIMEOUT;
    # confirm that is enough for the largest expected task list on CPU-only hosts.
    try:
        with httpx.Client(trust_env=False, timeout=_EMBED_TIMEOUT) as c:
            r = c.post(url, json={"model": "nomic-embed-text", "input": texts})
            r.raise_for_status()
            embeddings = r.json().get("embeddings") or []
    except Exception as exc:  # best-effort: caller falls back to project grouping
        log.warning("ollama_embed_failed error=%s", exc)
        return None
    if len(embeddings) != len(texts):
        log.warning("ollama_embed_bad_count expected=%d got=%d", len(texts), len(embeddings))
        return None
    return embeddings
def _embed_batch(texts: list[str]) -> list[list[float]] | None:
    """Embed *texts*, preferring LiteLLM and falling back to Ollama.

    LiteLLM is tried only when the LITELLM_URL environment variable is set and
    non-empty. If it is unset, or the LiteLLM helper returns ``None``, the
    Ollama server at OLLAMA_URL (default ``http://host.docker.internal:11434``)
    is used. The result is whatever the backend that answered returned: a list
    of vectors, or ``None``.
    """
    litellm_base = os.getenv("LITELLM_URL")
    vectors = _embed_via_litellm(texts, litellm_base) if litellm_base else None
    if vectors is not None:
        return vectors
    if litellm_base:
        log.info("cluster: litellm embed failed, trying ollama fallback")
    ollama_base = os.getenv("OLLAMA_URL", "http://host.docker.internal:11434")
    return _embed_via_ollama(texts, ollama_base)
def _cosine(a: list[float], b: list[float]) -> float:
    """Return the cosine similarity between vectors *a* and *b*.

    If either vector has zero length (norm), 0.0 is returned instead of
    dividing by zero. Elements are paired with ``zip``, so the dot product
    ignores any extra tail on the longer vector.
    """
    dot = sum(x * y for x, y in zip(a, b))
    norms = [math.sqrt(sum(x * x for x in vec)) for vec in (a, b)]
    if 0.0 in norms:
        return 0.0
    return dot / (norms[0] * norms[1])
def _greedy_cluster(items: list[tuple[dict, list[float]]]) -> list[Cluster]:
    """Single-pass greedy clustering of ``(task, embedding)`` pairs.

    Each item is compared against every existing cluster centroid:

    * If the most similar centroid scores at least ``_SIM_THRESHOLD`` (with a
      1e-9 tolerance), the item joins that cluster. This applies even when
      ``_MAX_CLUSTERS`` has been reached.
    * Otherwise, if fewer than ``_MAX_CLUSTERS`` clusters exist, the item founds
      a new cluster labelled with the first 60 characters of its stripped
      content ("Tasks" if it has none).
    * Otherwise (cap reached, nothing similar enough) the item is attached to
      the closest cluster anyway.

    Every member contributes to its cluster's centroid, which is kept as the
    running mean of the members' vectors. Clusters are returned in creation order.
    """
    centroids: list[list[float]] = []  # centroids[i] belongs to clusters[i]
    clusters: list[Cluster] = []
    # Tolerance so a similarity of exactly the threshold still counts as a match.
    floor = _SIM_THRESHOLD - 1e-9
    for task, vec in items:
        sims = [_cosine(centroid, vec) for centroid in centroids]
        # First index with the highest similarity (max keeps the earliest tie).
        best_idx = max(range(len(sims)), key=sims.__getitem__) if sims else -1
        similar = best_idx >= 0 and sims[best_idx] > floor
        # `not clusters` guarantees at least one cluster even if the cap is <= 0.
        if not similar and (len(clusters) < _MAX_CLUSTERS or not clusters):
            label = (task.get("content") or "").strip()[:60] or "Tasks"
            clusters.append(Cluster(label=label, tasks=[task]))
            centroids.append(vec)
            continue
        # Either a similar cluster exists, or the cap is reached and the task is
        # attached to the closest cluster even though it is below the threshold.
        cluster = clusters[best_idx]
        cluster.tasks.append(task)
        n = len(cluster.tasks)
        centroids[best_idx] = [(c * (n - 1) + v) / n for c, v in zip(centroids[best_idx], vec)]
    return clusters
def _fallback_by_project(tasks: list[dict]) -> list[Cluster]:
    """Group *tasks* by project when embeddings are unavailable.

    The grouping key is ``project_id``, then ``project``, then "default" for
    tasks where both are missing or falsy. Clusters are returned in order of
    first appearance. Each is labelled with its key converted to ``str``
    (``Cluster.label`` is a string even when ``project_id`` is numeric), except
    the "default" bucket, which is labelled "Tasks".
    """
    buckets: dict[object, Cluster] = {}
    for task in tasks:
        key = task.get("project_id") or task.get("project") or "default"
        bucket = buckets.get(key)
        if bucket is None:
            label = "Tasks" if key == "default" else str(key)
            bucket = buckets[key] = Cluster(label=label)
        bucket.tasks.append(task)
    return list(buckets.values())
def cluster_tasks(
    tasks: list[dict],
    ollama_url: str | None = None,  # ignored; kept so existing callers/tests still work
) -> list[Cluster]:
    """Cluster *tasks* by semantic similarity of their "content" text.

    Args:
        tasks: Task dicts. Text is read from the "content" key. Tasks whose
            content is missing, ``None``, not a string, or blank after stripping
            are not embedded; they are collected into a trailing "Other tasks"
            cluster.
        ollama_url: Unused by this function. The Ollama endpoint is read from
            the OLLAMA_URL environment variable by ``_embed_batch``.

    Returns:
        ``[]`` when *tasks* is empty; otherwise a non-empty list of clusters.
        If no task has usable content, or embedding fails or returns the wrong
        number of vectors, all tasks are grouped by project instead.
    """
    if not tasks:
        return []

    # Separate tasks with usable text from those without. Guard against
    # content=None or non-string values, which would crash a bare .strip().
    embeddable: list[tuple[dict, str]] = []
    no_content: list[dict] = []
    for task in tasks:
        raw = task.get("content")
        text = raw.strip() if isinstance(raw, str) else ""
        if text:
            embeddable.append((task, text))
        else:
            no_content.append(task)

    if not embeddable:
        return _fallback_by_project(tasks)

    # Batch-embed all task contents in one call.
    task_objs = [t for t, _ in embeddable]
    contents = [c for _, c in embeddable]
    vecs = _embed_batch(contents)
    if vecs is None or len(vecs) != len(contents):
        log.info("cluster_tasks: embedding unavailable, falling back to project grouping")
        return _fallback_by_project(tasks)

    clusters = _greedy_cluster(list(zip(task_objs, vecs)))
    if no_content:
        clusters.append(Cluster(label="Other tasks", tasks=no_content))
    return clusters