fix(clustering): route embeddings through LiteLLM instead of Ollama directly

The old code called Ollama's /api/embeddings one task at a time, which caused
silent fallback to project-based grouping when host.docker.internal:11434 was
unreachable from the ml-serving container.

- Switch to LiteLLM /embeddings (model alias "embedder") as primary path
- Batch all task contents in one request instead of N serial calls
- Fall back to Ollama /api/embed (updated to current API) when LITELLM_URL is absent
- Update tests to mock _embed_batch instead of the removed _embed

Fixes #123

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-12 13:42:53 +00:00
parent 4e9210fcef
commit 1ca2351488
2 changed files with 69 additions and 31 deletions

View File

@@ -1,11 +1,15 @@
"""Semantic task clustering via nomic-embed-text (issue #97). """Semantic task clustering via nomic-embed-text (issue #97).
Public API: Public API:
cluster_tasks(tasks, ollama_url) -> list[Cluster] cluster_tasks(tasks) -> list[Cluster]
Each task dict must have a "content" key. Tasks without content are placed in a Each task dict must have a "content" key. Tasks without content are placed in a
fallback "other" bucket. If Ollama is unreachable, falls back to grouping by fallback "other" bucket. If the embedding service is unreachable, falls back to
project_id so compute() always returns something useful. grouping by project_id so compute() always returns something useful.
Embeddings are fetched via LiteLLM's OpenAI-compatible /embeddings endpoint
(model alias "embedder") when LITELLM_URL is set. Falls back to the Ollama
/api/embed endpoint when only OLLAMA_URL is available (local dev without LiteLLM).
""" """
from __future__ import annotations from __future__ import annotations
@@ -22,7 +26,7 @@ log = logging.getLogger(__name__)
_SIM_THRESHOLD = 0.72 _SIM_THRESHOLD = 0.72
# Never produce more than this many clusters regardless of task count. # Never produce more than this many clusters regardless of task count.
_MAX_CLUSTERS = 6 _MAX_CLUSTERS = 6
_EMBED_TIMEOUT = 10.0 _EMBED_TIMEOUT = 15.0
@dataclass @dataclass
@@ -39,20 +43,59 @@ class Cluster:
return sum(1 for t in self.tasks if t.get("is_overdue")) return sum(1 for t in self.tasks if t.get("is_overdue"))
def _embed(text: str, ollama_url: str) -> list[float] | None: def _embed_via_litellm(texts: list[str], litellm_url: str) -> list[list[float]] | None:
"""Batch embed via LiteLLM OpenAI-compatible /embeddings endpoint."""
try: try:
with httpx.Client(trust_env=False, timeout=_EMBED_TIMEOUT) as c: with httpx.Client(trust_env=False, timeout=_EMBED_TIMEOUT) as c:
r = c.post( r = c.post(
f"{ollama_url}/api/embeddings", f"{litellm_url}/embeddings",
json={"model": "nomic-embed-text", "prompt": text, "keep_alive": 0}, json={"model": "embedder", "input": texts},
) )
r.raise_for_status() r.raise_for_status()
return r.json().get("embedding") data = r.json().get("data", [])
ordered = sorted(data, key=lambda x: x["index"])
return [item["embedding"] for item in ordered]
except Exception as exc: except Exception as exc:
log.debug("embed_failed text=%r error=%s", text[:40], exc) log.debug("litellm_embed_failed error=%s", exc)
return None return None
def _embed_via_ollama(texts: list[str], ollama_url: str) -> list[list[float]] | None:
"""Batch embed via Ollama /api/embed endpoint."""
try:
results = []
with httpx.Client(trust_env=False, timeout=_EMBED_TIMEOUT) as c:
for text in texts:
r = c.post(
f"{ollama_url}/api/embed",
json={"model": "nomic-embed-text", "input": text},
)
r.raise_for_status()
body = r.json()
# /api/embed returns {"embeddings": [[...]]}
embeddings = body.get("embeddings")
if not embeddings:
return None
results.append(embeddings[0])
return results
except Exception as exc:
log.debug("ollama_embed_failed error=%s", exc)
return None
def _embed_batch(texts: list[str]) -> list[list[float]] | None:
"""Embed a list of texts, preferring LiteLLM over direct Ollama."""
litellm_url = os.getenv("LITELLM_URL")
if litellm_url:
vecs = _embed_via_litellm(texts, litellm_url)
if vecs is not None:
return vecs
log.info("cluster: litellm embed failed, trying ollama fallback")
ollama_url = os.getenv("OLLAMA_URL", "http://host.docker.internal:11434")
return _embed_via_ollama(texts, ollama_url)
def _cosine(a: list[float], b: list[float]) -> float: def _cosine(a: list[float], b: list[float]) -> float:
dot = sum(x * y for x, y in zip(a, b)) dot = sum(x * y for x, y in zip(a, b))
na = math.sqrt(sum(x * x for x in a)) na = math.sqrt(sum(x * x for x in a))
@@ -109,18 +152,16 @@ def _fallback_by_project(tasks: list[dict]) -> list[Cluster]:
def cluster_tasks( def cluster_tasks(
tasks: list[dict], tasks: list[dict],
ollama_url: str | None = None, ollama_url: str | None = None, # kept for test compatibility; env vars take precedence
) -> list[Cluster]: ) -> list[Cluster]:
"""Cluster tasks by semantic similarity. """Cluster tasks by semantic similarity.
Returns a non-empty list of Cluster objects. Falls back to project-based Returns a non-empty list of Cluster objects. Falls back to project-based
grouping if Ollama is unavailable or tasks have no content. grouping if the embedding service is unavailable or tasks have no content.
""" """
if not tasks: if not tasks:
return [] return []
url = ollama_url or os.getenv("OLLAMA_URL", "http://host.docker.internal:11434")
# Separate tasks with usable content from those without. # Separate tasks with usable content from those without.
with_content = [(t, t.get("content", "").strip()) for t in tasks] with_content = [(t, t.get("content", "").strip()) for t in tasks]
embeddable = [(t, c) for t, c in with_content if c] embeddable = [(t, c) for t, c in with_content if c]
@@ -129,20 +170,17 @@ def cluster_tasks(
if not embeddable: if not embeddable:
return _fallback_by_project(tasks) return _fallback_by_project(tasks)
# Fetch embeddings (best-effort; None means Ollama unavailable). # Batch-embed all task contents in one call.
embedded: list[tuple[dict, list[float]]] = [] task_objs = [t for t, _ in embeddable]
failed = False contents = [c for _, c in embeddable]
for task, content in embeddable: vecs = _embed_batch(contents)
vec = _embed(content, url)
if vec is None:
failed = True
break
embedded.append((task, vec))
if failed or not embedded: if vecs is None or len(vecs) != len(contents):
log.info("cluster_tasks: ollama unavailable, falling back to project grouping") log.info("cluster_tasks: embedding unavailable, falling back to project grouping")
return _fallback_by_project(tasks) return _fallback_by_project(tasks)
embedded = list(zip(task_objs, vecs))
clusters = _greedy_cluster(embedded) clusters = _greedy_cluster(embedded)
# Tasks without content get their own bucket if any. # Tasks without content get their own bucket if any.

View File

@@ -9,7 +9,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", ".."))
from unittest.mock import patch from unittest.mock import patch
from ml.agents.clustering import cluster_tasks, Cluster, _greedy_cluster, _cosine from ml.agents.clustering import cluster_tasks, Cluster, _greedy_cluster, _cosine, _embed_batch
# ── helpers ────────────────────────────────────────────────────────────────── # ── helpers ──────────────────────────────────────────────────────────────────
@@ -89,8 +89,8 @@ class TestClusterTasks:
result = cluster_tasks([]) result = cluster_tasks([])
assert result == [] assert result == []
def test_fallback_when_ollama_unavailable(self): def test_fallback_when_embed_unavailable(self):
with patch("ml.agents.clustering._embed", return_value=None): with patch("ml.agents.clustering._embed_batch", return_value=None):
tasks = [_task("A", "p1"), _task("B", "p2"), _task("C", "p1")] tasks = [_task("A", "p1"), _task("B", "p2"), _task("C", "p1")]
clusters = cluster_tasks(tasks) clusters = cluster_tasks(tasks)
assert len(clusters) == 2 assert len(clusters) == 2
@@ -98,7 +98,7 @@ class TestClusterTasks:
assert "p1" in labels and "p2" in labels assert "p1" in labels and "p2" in labels
def test_fallback_groups_by_project(self): def test_fallback_groups_by_project(self):
with patch("ml.agents.clustering._embed", return_value=None): with patch("ml.agents.clustering._embed_batch", return_value=None):
tasks = [_task("A", "work")] * 3 + [_task("B", "home")] * 2 tasks = [_task("A", "work")] * 3 + [_task("B", "home")] * 2
clusters = cluster_tasks(tasks) clusters = cluster_tasks(tasks)
by_label = {c.label: c.task_count for c in clusters} by_label = {c.label: c.task_count for c in clusters}
@@ -107,7 +107,7 @@ class TestClusterTasks:
def test_tasks_without_content_go_to_other(self): def test_tasks_without_content_go_to_other(self):
v = [1.0, 0.0] v = [1.0, 0.0]
with patch("ml.agents.clustering._embed", return_value=v): with patch("ml.agents.clustering._embed_batch", return_value=[v]):
tasks = [_task("Has content"), {"is_overdue": False}] tasks = [_task("Has content"), {"is_overdue": False}]
clusters = cluster_tasks(tasks) clusters = cluster_tasks(tasks)
labels = {c.label for c in clusters} labels = {c.label for c in clusters}
@@ -116,8 +116,8 @@ class TestClusterTasks:
def test_semantic_clustering_groups_similar(self): def test_semantic_clustering_groups_similar(self):
v_work = [1.0, 0.0, 0.0] v_work = [1.0, 0.0, 0.0]
v_home = [0.0, 1.0, 0.0] v_home = [0.0, 1.0, 0.0]
side_effects = [v_work, v_work, v_home, v_home] batch_result = [v_work, v_work, v_home, v_home]
with patch("ml.agents.clustering._embed", side_effect=side_effects): with patch("ml.agents.clustering._embed_batch", return_value=batch_result):
tasks = [ tasks = [
_task("Write report"), _task("Write report"),
_task("Review PR"), _task("Review PR"),