From 1ca2351488aa6965985630ff64215b5ae5f539fd Mon Sep 17 00:00:00 2001 From: alvis Date: Tue, 12 May 2026 13:42:53 +0000 Subject: [PATCH] fix(clustering): route embeddings through LiteLLM instead of Ollama directly The old code called Ollama's /api/embeddings one task at a time, which caused silent fallback to project-based grouping when host.docker.internal:11434 was unreachable from the ml-serving container. - Switch to LiteLLM /embeddings (model alias "embedder") as primary path - Batch all task contents in one request instead of N serial calls - Fall back to Ollama /api/embed (updated to current API) when LITELLM_URL is absent - Update tests to mock _embed_batch instead of the removed _embed Fixes #123 Co-Authored-By: Claude Sonnet 4.6 --- ml/agents/clustering.py | 86 +++++++++++++++++++++--------- ml/agents/tests/test_clustering.py | 14 ++--- 2 files changed, 69 insertions(+), 31 deletions(-) diff --git a/ml/agents/clustering.py b/ml/agents/clustering.py index b3a2a29..8d8bc62 100644 --- a/ml/agents/clustering.py +++ b/ml/agents/clustering.py @@ -1,11 +1,15 @@ """Semantic task clustering via nomic-embed-text (issue #97). Public API: - cluster_tasks(tasks, ollama_url) -> list[Cluster] + cluster_tasks(tasks) -> list[Cluster] Each task dict must have a "content" key. Tasks without content are placed in a -fallback "other" bucket. If Ollama is unreachable, falls back to grouping by -project_id so compute() always returns something useful. +fallback "other" bucket. If the embedding service is unreachable, falls back to +grouping by project_id so compute() always returns something useful. + +Embeddings are fetched via LiteLLM's OpenAI-compatible /embeddings endpoint +(model alias "embedder") when LITELLM_URL is set. Falls back to the Ollama +/api/embed endpoint when only OLLAMA_URL is available (local dev without LiteLLM). """ from __future__ import annotations @@ -22,7 +26,7 @@ log = logging.getLogger(__name__) _SIM_THRESHOLD = 0.72 # Never produce more than this many clusters regardless of task count. _MAX_CLUSTERS = 6 -_EMBED_TIMEOUT = 10.0 +_EMBED_TIMEOUT = 15.0 @dataclass @@ -39,20 +43,59 @@ class Cluster: return sum(1 for t in self.tasks if t.get("is_overdue")) -def _embed(text: str, ollama_url: str) -> list[float] | None: +def _embed_via_litellm(texts: list[str], litellm_url: str) -> list[list[float]] | None: + """Batch embed via LiteLLM OpenAI-compatible /embeddings endpoint.""" try: with httpx.Client(trust_env=False, timeout=_EMBED_TIMEOUT) as c: r = c.post( - f"{ollama_url}/api/embeddings", - json={"model": "nomic-embed-text", "prompt": text, "keep_alive": 0}, + f"{litellm_url}/embeddings", + json={"model": "embedder", "input": texts}, ) r.raise_for_status() - return r.json().get("embedding") + data = r.json().get("data", []) + ordered = sorted(data, key=lambda x: x["index"]) + return [item["embedding"] for item in ordered] except Exception as exc: - log.debug("embed_failed text=%r error=%s", text[:40], exc) + log.debug("litellm_embed_failed error=%s", exc) return None +def _embed_via_ollama(texts: list[str], ollama_url: str) -> list[list[float]] | None: + """Batch embed via Ollama /api/embed endpoint.""" + try: + results = [] + with httpx.Client(trust_env=False, timeout=_EMBED_TIMEOUT) as c: + for text in texts: + r = c.post( + f"{ollama_url}/api/embed", + json={"model": "nomic-embed-text", "input": text}, + ) + r.raise_for_status() + body = r.json() + # /api/embed returns {"embeddings": [[...]]} + embeddings = body.get("embeddings") + if not embeddings: + return None + results.append(embeddings[0]) + return results + except Exception as exc: + log.debug("ollama_embed_failed error=%s", exc) + return None + + +def _embed_batch(texts: list[str]) -> list[list[float]] | None: + """Embed a list of texts, preferring LiteLLM over direct Ollama.""" + litellm_url = os.getenv("LITELLM_URL") + if litellm_url: + vecs = _embed_via_litellm(texts, litellm_url) + if vecs is not None: + return vecs + log.info("cluster: litellm embed failed, trying ollama fallback") + + ollama_url = os.getenv("OLLAMA_URL", "http://host.docker.internal:11434") + return _embed_via_ollama(texts, ollama_url) + + def _cosine(a: list[float], b: list[float]) -> float: dot = sum(x * y for x, y in zip(a, b)) na = math.sqrt(sum(x * x for x in a)) @@ -109,18 +152,16 @@ def _fallback_by_project(tasks: list[dict]) -> list[Cluster]: def cluster_tasks( tasks: list[dict], - ollama_url: str | None = None, + ollama_url: str | None = None, # kept for test compatibility; env vars take precedence ) -> list[Cluster]: """Cluster tasks by semantic similarity. Returns a non-empty list of Cluster objects. Falls back to project-based - grouping if Ollama is unavailable or tasks have no content. + grouping if the embedding service is unavailable or tasks have no content. """ if not tasks: return [] - url = ollama_url or os.getenv("OLLAMA_URL", "http://host.docker.internal:11434") - # Separate tasks with usable content from those without. with_content = [(t, t.get("content", "").strip()) for t in tasks] embeddable = [(t, c) for t, c in with_content if c] @@ -129,20 +170,17 @@ def cluster_tasks( if not embeddable: return _fallback_by_project(tasks) - # Fetch embeddings (best-effort; None means Ollama unavailable). - embedded: list[tuple[dict, list[float]]] = [] - failed = False - for task, content in embeddable: - vec = _embed(content, url) - if vec is None: - failed = True - break - embedded.append((task, vec)) + # Batch-embed all task contents in one call. + task_objs = [t for t, _ in embeddable] + contents = [c for _, c in embeddable] + vecs = _embed_batch(contents) - if failed or not embedded: - log.info("cluster_tasks: ollama unavailable, falling back to project grouping") + if vecs is None or len(vecs) != len(contents): + log.info("cluster_tasks: embedding unavailable, falling back to project grouping") return _fallback_by_project(tasks) + embedded = list(zip(task_objs, vecs)) + clusters = _greedy_cluster(embedded) # Tasks without content get their own bucket if any. diff --git a/ml/agents/tests/test_clustering.py b/ml/agents/tests/test_clustering.py index 85e2943..e33f0e7 100644 --- a/ml/agents/tests/test_clustering.py +++ b/ml/agents/tests/test_clustering.py @@ -9,7 +9,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..")) from unittest.mock import patch -from ml.agents.clustering import cluster_tasks, Cluster, _greedy_cluster, _cosine +from ml.agents.clustering import cluster_tasks, Cluster, _greedy_cluster, _cosine, _embed_batch # ── helpers ────────────────────────────────────────────────────────────────── @@ -89,8 +89,8 @@ class TestClusterTasks: result = cluster_tasks([]) assert result == [] - def test_fallback_when_ollama_unavailable(self): - with patch("ml.agents.clustering._embed", return_value=None): + def test_fallback_when_embed_unavailable(self): + with patch("ml.agents.clustering._embed_batch", return_value=None): tasks = [_task("A", "p1"), _task("B", "p2"), _task("C", "p1")] clusters = cluster_tasks(tasks) assert len(clusters) == 2 @@ -98,7 +98,7 @@ class TestClusterTasks: assert "p1" in labels and "p2" in labels def test_fallback_groups_by_project(self): - with patch("ml.agents.clustering._embed", return_value=None): + with patch("ml.agents.clustering._embed_batch", return_value=None): tasks = [_task("A", "work")] * 3 + [_task("B", "home")] * 2 clusters = cluster_tasks(tasks) by_label = {c.label: c.task_count for c in clusters} @@ -107,7 +107,7 @@ class TestClusterTasks: def test_tasks_without_content_go_to_other(self): v = [1.0, 0.0] - with patch("ml.agents.clustering._embed", return_value=v): + with patch("ml.agents.clustering._embed_batch", return_value=[v]): tasks = [_task("Has content"), {"is_overdue": False}] clusters = cluster_tasks(tasks) labels = {c.label for c in clusters} @@ -116,8 +116,8 @@ class TestClusterTasks: def test_semantic_clustering_groups_similar(self): v_work = [1.0, 0.0, 0.0] v_home = [0.0, 1.0, 0.0] - side_effects = [v_work, v_work, v_home, v_home] - with patch("ml.agents.clustering._embed", side_effect=side_effects): + batch_result = [v_work, v_work, v_home, v_home] + with patch("ml.agents.clustering._embed_batch", return_value=batch_result): tasks = [ _task("Write report"), _task("Review PR"),