fix(clustering): route embeddings through LiteLLM instead of Ollama directly
The old code called Ollama's /api/embeddings one task at a time, which caused silent fallback to project-based grouping when host.docker.internal:11434 was unreachable from the ml-serving container.

- Switch to LiteLLM /embeddings (model alias "embedder") as primary path
- Batch all task contents in one request instead of N serial calls
- Fall back to Ollama /api/embed (updated to current API) when LITELLM_URL is absent
- Update tests to mock _embed_batch instead of the removed _embed

Fixes #123

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,11 +1,15 @@
|
|||||||
"""Semantic task clustering via nomic-embed-text (issue #97).
|
"""Semantic task clustering via nomic-embed-text (issue #97).
|
||||||
|
|
||||||
Public API:
|
Public API:
|
||||||
cluster_tasks(tasks, ollama_url) -> list[Cluster]
|
cluster_tasks(tasks) -> list[Cluster]
|
||||||
|
|
||||||
Each task dict must have a "content" key. Tasks without content are placed in a
|
Each task dict must have a "content" key. Tasks without content are placed in a
|
||||||
fallback "other" bucket. If Ollama is unreachable, falls back to grouping by
|
fallback "other" bucket. If the embedding service is unreachable, falls back to
|
||||||
project_id so compute() always returns something useful.
|
grouping by project_id so compute() always returns something useful.
|
||||||
|
|
||||||
|
Embeddings are fetched via LiteLLM's OpenAI-compatible /embeddings endpoint
|
||||||
|
(model alias "embedder") when LITELLM_URL is set. Falls back to the Ollama
|
||||||
|
/api/embed endpoint when only OLLAMA_URL is available (local dev without LiteLLM).
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@@ -22,7 +26,7 @@ log = logging.getLogger(__name__)
|
|||||||
_SIM_THRESHOLD = 0.72
|
_SIM_THRESHOLD = 0.72
|
||||||
# Never produce more than this many clusters regardless of task count.
|
# Never produce more than this many clusters regardless of task count.
|
||||||
_MAX_CLUSTERS = 6
|
_MAX_CLUSTERS = 6
|
||||||
_EMBED_TIMEOUT = 10.0
|
_EMBED_TIMEOUT = 15.0
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -39,20 +43,59 @@ class Cluster:
|
|||||||
return sum(1 for t in self.tasks if t.get("is_overdue"))
|
return sum(1 for t in self.tasks if t.get("is_overdue"))
|
||||||
|
|
||||||
|
|
||||||
def _embed(text: str, ollama_url: str) -> list[float] | None:
|
def _embed_via_litellm(texts: list[str], litellm_url: str) -> list[list[float]] | None:
|
||||||
|
"""Batch embed via LiteLLM OpenAI-compatible /embeddings endpoint."""
|
||||||
try:
|
try:
|
||||||
with httpx.Client(trust_env=False, timeout=_EMBED_TIMEOUT) as c:
|
with httpx.Client(trust_env=False, timeout=_EMBED_TIMEOUT) as c:
|
||||||
r = c.post(
|
r = c.post(
|
||||||
f"{ollama_url}/api/embeddings",
|
f"{litellm_url}/embeddings",
|
||||||
json={"model": "nomic-embed-text", "prompt": text, "keep_alive": 0},
|
json={"model": "embedder", "input": texts},
|
||||||
)
|
)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
return r.json().get("embedding")
|
data = r.json().get("data", [])
|
||||||
|
ordered = sorted(data, key=lambda x: x["index"])
|
||||||
|
return [item["embedding"] for item in ordered]
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
log.debug("embed_failed text=%r error=%s", text[:40], exc)
|
log.debug("litellm_embed_failed error=%s", exc)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _embed_via_ollama(texts: list[str], ollama_url: str) -> list[list[float]] | None:
|
||||||
|
"""Embed each text via the Ollama /api/embed endpoint, one request per text."""
|
||||||
|
try:
|
||||||
|
results = []
|
||||||
|
with httpx.Client(trust_env=False, timeout=_EMBED_TIMEOUT) as c:
|
||||||
|
for text in texts:
|
||||||
|
r = c.post(
|
||||||
|
f"{ollama_url}/api/embed",
|
||||||
|
json={"model": "nomic-embed-text", "input": text},
|
||||||
|
)
|
||||||
|
r.raise_for_status()
|
||||||
|
body = r.json()
|
||||||
|
# /api/embed returns {"embeddings": [[...]]}
|
||||||
|
embeddings = body.get("embeddings")
|
||||||
|
if not embeddings:
|
||||||
|
return None
|
||||||
|
results.append(embeddings[0])
|
||||||
|
return results
|
||||||
|
except Exception as exc:
|
||||||
|
log.debug("ollama_embed_failed error=%s", exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _embed_batch(texts: list[str]) -> list[list[float]] | None:
|
||||||
|
"""Embed a list of texts, preferring LiteLLM over direct Ollama."""
|
||||||
|
litellm_url = os.getenv("LITELLM_URL")
|
||||||
|
if litellm_url:
|
||||||
|
vecs = _embed_via_litellm(texts, litellm_url)
|
||||||
|
if vecs is not None:
|
||||||
|
return vecs
|
||||||
|
log.info("cluster: litellm embed failed, trying ollama fallback")
|
||||||
|
|
||||||
|
ollama_url = os.getenv("OLLAMA_URL", "http://host.docker.internal:11434")
|
||||||
|
return _embed_via_ollama(texts, ollama_url)
|
||||||
|
|
||||||
|
|
||||||
def _cosine(a: list[float], b: list[float]) -> float:
|
def _cosine(a: list[float], b: list[float]) -> float:
|
||||||
dot = sum(x * y for x, y in zip(a, b))
|
dot = sum(x * y for x, y in zip(a, b))
|
||||||
na = math.sqrt(sum(x * x for x in a))
|
na = math.sqrt(sum(x * x for x in a))
|
||||||
@@ -109,18 +152,16 @@ def _fallback_by_project(tasks: list[dict]) -> list[Cluster]:
|
|||||||
|
|
||||||
def cluster_tasks(
|
def cluster_tasks(
|
||||||
tasks: list[dict],
|
tasks: list[dict],
|
||||||
ollama_url: str | None = None,
|
ollama_url: str | None = None, # kept for test compatibility; env vars take precedence
|
||||||
) -> list[Cluster]:
|
) -> list[Cluster]:
|
||||||
"""Cluster tasks by semantic similarity.
|
"""Cluster tasks by semantic similarity.
|
||||||
|
|
||||||
Returns a non-empty list of Cluster objects. Falls back to project-based
|
Returns a non-empty list of Cluster objects. Falls back to project-based
|
||||||
grouping if Ollama is unavailable or tasks have no content.
|
grouping if the embedding service is unavailable or tasks have no content.
|
||||||
"""
|
"""
|
||||||
if not tasks:
|
if not tasks:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
url = ollama_url or os.getenv("OLLAMA_URL", "http://host.docker.internal:11434")
|
|
||||||
|
|
||||||
# Separate tasks with usable content from those without.
|
# Separate tasks with usable content from those without.
|
||||||
with_content = [(t, t.get("content", "").strip()) for t in tasks]
|
with_content = [(t, t.get("content", "").strip()) for t in tasks]
|
||||||
embeddable = [(t, c) for t, c in with_content if c]
|
embeddable = [(t, c) for t, c in with_content if c]
|
||||||
@@ -129,20 +170,17 @@ def cluster_tasks(
|
|||||||
if not embeddable:
|
if not embeddable:
|
||||||
return _fallback_by_project(tasks)
|
return _fallback_by_project(tasks)
|
||||||
|
|
||||||
# Fetch embeddings (best-effort; None means Ollama unavailable).
|
# Batch-embed all task contents in one call.
|
||||||
embedded: list[tuple[dict, list[float]]] = []
|
task_objs = [t for t, _ in embeddable]
|
||||||
failed = False
|
contents = [c for _, c in embeddable]
|
||||||
for task, content in embeddable:
|
vecs = _embed_batch(contents)
|
||||||
vec = _embed(content, url)
|
|
||||||
if vec is None:
|
|
||||||
failed = True
|
|
||||||
break
|
|
||||||
embedded.append((task, vec))
|
|
||||||
|
|
||||||
if failed or not embedded:
|
if vecs is None or len(vecs) != len(contents):
|
||||||
log.info("cluster_tasks: ollama unavailable, falling back to project grouping")
|
log.info("cluster_tasks: embedding unavailable, falling back to project grouping")
|
||||||
return _fallback_by_project(tasks)
|
return _fallback_by_project(tasks)
|
||||||
|
|
||||||
|
embedded = list(zip(task_objs, vecs))
|
||||||
|
|
||||||
clusters = _greedy_cluster(embedded)
|
clusters = _greedy_cluster(embedded)
|
||||||
|
|
||||||
# Tasks without content get their own bucket if any.
|
# Tasks without content get their own bucket if any.
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
|||||||
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from ml.agents.clustering import cluster_tasks, Cluster, _greedy_cluster, _cosine
|
from ml.agents.clustering import cluster_tasks, Cluster, _greedy_cluster, _cosine, _embed_batch
|
||||||
|
|
||||||
|
|
||||||
# ── helpers ──────────────────────────────────────────────────────────────────
|
# ── helpers ──────────────────────────────────────────────────────────────────
|
||||||
@@ -89,8 +89,8 @@ class TestClusterTasks:
|
|||||||
result = cluster_tasks([])
|
result = cluster_tasks([])
|
||||||
assert result == []
|
assert result == []
|
||||||
|
|
||||||
def test_fallback_when_ollama_unavailable(self):
|
def test_fallback_when_embed_unavailable(self):
|
||||||
with patch("ml.agents.clustering._embed", return_value=None):
|
with patch("ml.agents.clustering._embed_batch", return_value=None):
|
||||||
tasks = [_task("A", "p1"), _task("B", "p2"), _task("C", "p1")]
|
tasks = [_task("A", "p1"), _task("B", "p2"), _task("C", "p1")]
|
||||||
clusters = cluster_tasks(tasks)
|
clusters = cluster_tasks(tasks)
|
||||||
assert len(clusters) == 2
|
assert len(clusters) == 2
|
||||||
@@ -98,7 +98,7 @@ class TestClusterTasks:
|
|||||||
assert "p1" in labels and "p2" in labels
|
assert "p1" in labels and "p2" in labels
|
||||||
|
|
||||||
def test_fallback_groups_by_project(self):
|
def test_fallback_groups_by_project(self):
|
||||||
with patch("ml.agents.clustering._embed", return_value=None):
|
with patch("ml.agents.clustering._embed_batch", return_value=None):
|
||||||
tasks = [_task("A", "work")] * 3 + [_task("B", "home")] * 2
|
tasks = [_task("A", "work")] * 3 + [_task("B", "home")] * 2
|
||||||
clusters = cluster_tasks(tasks)
|
clusters = cluster_tasks(tasks)
|
||||||
by_label = {c.label: c.task_count for c in clusters}
|
by_label = {c.label: c.task_count for c in clusters}
|
||||||
@@ -107,7 +107,7 @@ class TestClusterTasks:
|
|||||||
|
|
||||||
def test_tasks_without_content_go_to_other(self):
|
def test_tasks_without_content_go_to_other(self):
|
||||||
v = [1.0, 0.0]
|
v = [1.0, 0.0]
|
||||||
with patch("ml.agents.clustering._embed", return_value=v):
|
with patch("ml.agents.clustering._embed_batch", return_value=[v]):
|
||||||
tasks = [_task("Has content"), {"is_overdue": False}]
|
tasks = [_task("Has content"), {"is_overdue": False}]
|
||||||
clusters = cluster_tasks(tasks)
|
clusters = cluster_tasks(tasks)
|
||||||
labels = {c.label for c in clusters}
|
labels = {c.label for c in clusters}
|
||||||
@@ -116,8 +116,8 @@ class TestClusterTasks:
|
|||||||
def test_semantic_clustering_groups_similar(self):
|
def test_semantic_clustering_groups_similar(self):
|
||||||
v_work = [1.0, 0.0, 0.0]
|
v_work = [1.0, 0.0, 0.0]
|
||||||
v_home = [0.0, 1.0, 0.0]
|
v_home = [0.0, 1.0, 0.0]
|
||||||
side_effects = [v_work, v_work, v_home, v_home]
|
batch_result = [v_work, v_work, v_home, v_home]
|
||||||
with patch("ml.agents.clustering._embed", side_effect=side_effects):
|
with patch("ml.agents.clustering._embed_batch", return_value=batch_result):
|
||||||
tasks = [
|
tasks = [
|
||||||
_task("Write report"),
|
_task("Write report"),
|
||||||
_task("Review PR"),
|
_task("Review PR"),
|
||||||
|
|||||||
Reference in New Issue
Block a user