feat(ml): JetStream durable consumers in ml/serving (#98)

Adds a NATS JetStream consumer to ml/serving so the feature pipeline
can react to events without the API triggering every read.

- nats_consumer.py: durable push consumers for the signals (signals.>) and
  feedback (feedback.>) streams; acks on success, naks on failure so the
  message is redelivered, up to NATS_MAX_DELIVER attempts; per-consumer
  health state (last_msg_ts, processed, errors)
- main.py: FastAPI lifespan wires start/stop; /health exposes nats state
- requirements.txt: adds nats-py>=2.9.0
- Dockerfile.ml: copy all *.py from ml/serving (was missing prompts.py)

Handled subjects:
  signals.task.synced   → writes per-user sync metadata to STATE_DIR
  signals.tip.feedback  → logged for observability (reward via HTTP path)

Config: NATS_URL (empty = disabled), NATS_DURABLE_PREFIX, NATS_MAX_DELIVER

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-25 10:19:47 +00:00
parent 2d7cf217a9
commit 4652e4b582
4 changed files with 168 additions and 3 deletions

View File

@@ -2,5 +2,5 @@ FROM python:3.12-slim
WORKDIR /app WORKDIR /app
COPY ml/serving/requirements.txt . COPY ml/serving/requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt RUN pip install --no-cache-dir -r requirements.txt
COPY ml/serving/main.py . COPY ml/serving/*.py .
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"] CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]

View File

@@ -28,6 +28,7 @@ import math
import os import os
import time import time
from collections import deque from collections import deque
from contextlib import asynccontextmanager
from pathlib import Path from pathlib import Path
from typing import Optional, Deque from typing import Optional, Deque
@@ -36,9 +37,18 @@ import numpy as np
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
import nats_consumer
from prompts import get_prompt from prompts import get_prompt
app = FastAPI(title="oO ML Serving", version="1.0.0")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: start the NATS consumers, stop them on shutdown.

    ``nats_consumer.start`` runs before the application begins serving and
    ``nats_consumer.stop`` runs when the context exits.
    """
    await nats_consumer.start(STATE_DIR)
    try:
        yield
    finally:
        # Run cleanup even when the context is left through an exception or
        # cancellation raised at the ``yield`` point, not only on clean exit.
        await nats_consumer.stop()
app = FastAPI(title="oO ML Serving", version="1.0.0", lifespan=lifespan)
LITELLM_URL = os.getenv("LITELLM_URL", "http://localhost:4000") LITELLM_URL = os.getenv("LITELLM_URL", "http://localhost:4000")
LITELLM_MASTER_KEY = os.getenv("LITELLM_MASTER_KEY", "sk-oo-dev") LITELLM_MASTER_KEY = os.getenv("LITELLM_MASTER_KEY", "sk-oo-dev")
@@ -315,7 +325,13 @@ class GenerateResponse(BaseModel):
@app.get("/health") @app.get("/health")
def health(): def health():
return {"ok": True} return {
"ok": True,
"nats": {
"enabled": bool(nats_consumer.NATS_URL),
"consumers": nats_consumer.consumer_health,
},
}
_RETRY_SUFFIX = ( _RETRY_SUFFIX = (

148
ml/serving/nats_consumer.py Normal file
View File

@@ -0,0 +1,148 @@
"""
JetStream durable consumers for ml/serving.
Streams:
signals (subjects: signals.>) — durable: {prefix}-signals
feedback (subjects: feedback.>) — durable: {prefix}-feedback
Handled subjects:
signals.task.synced → write per-user sync metadata to STATE_DIR
signals.tip.feedback → log for observability (reward is applied via HTTP path)
Config (env vars):
NATS_URL — broker URL; empty = consumers disabled (default: "")
NATS_DURABLE_PREFIX — prefix for durable consumer names (default: "feature-pipeline")
NATS_MAX_DELIVER — max redelivery attempts before dropping (default: 5)
"""
from __future__ import annotations
import json
import logging
import os
import time
from pathlib import Path
from typing import Optional
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Configuration is read from the environment once, when the module is imported.
# An empty NATS_URL makes start() return without connecting.
NATS_URL = os.getenv("NATS_URL", "")
# Prefix for durable consumer names; start() builds "{prefix}-{key}".
NATS_DURABLE_PREFIX = os.getenv("NATS_DURABLE_PREFIX", "feature-pipeline")
# Passed as max_deliver to the consumer config in start().
# NOTE: int() raises ValueError at import time if the variable is not numeric.
NATS_MAX_DELIVER = int(os.getenv("NATS_MAX_DELIVER", "5"))
# Exposed to /health. One entry per consumer key; the message handler updates
# last_msg_ts on every delivery and increments processed / errors.
consumer_health: dict[str, dict] = {
"signals": {"last_msg_ts": None, "processed": 0, "errors": 0},
"feedback": {"last_msg_ts": None, "processed": 0, "errors": 0},
}
# Connection object set by start() and reset to None by stop().
_nc = None # nats.aio.Client
# Subscriptions appended by start(); stop() unsubscribes each and clears the list.
_subs: list = [] # active JetStream subscriptions
# ── Subject handlers ───────────────────────────────────────────────────────
def _sync_meta_path(state_dir: Path, user_id: str) -> Path:
safe = "".join(c if c.isalnum() else "_" for c in user_id)
return state_dir / f"{safe}_sync.json"
async def _handle(subject: str, payload: dict, state_dir: Path) -> None:
    """Process one decoded message according to its subject.

    * ``signals.task.synced`` — writes ``last_sync_ts`` and ``task_count``
      for the payload's ``userId`` to that user's sync file in *state_dir*.
    * ``signals.tip.feedback`` — logged only.
    * anything else — logged at debug level and ignored.

    Args:
        subject: Subject the message was published on.
        payload: JSON-decoded message body.
        state_dir: Directory that holds the per-user sync files.

    Raises:
        OSError: If the sync file cannot be written.
    """
    if subject == "signals.task.synced":
        user_id = payload.get("userId", "")
        if not user_id:
            # Nothing can be stored without a user; leave a trace instead of
            # skipping silently.
            logger.warning("[nats] task_synced without userId — ignored")
            return
        # A numeric id in the JSON would otherwise fail in the path sanitizer.
        user_id = str(user_id)
        meta = {
            "last_sync_ts": payload.get("syncedAt") or time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "task_count": payload.get("count", 0),
        }
        target = _sync_meta_path(state_dir, user_id)
        # Write to a sibling temp file and rename it into place so a concurrent
        # reader never sees a half-written JSON document.
        tmp = target.with_name(target.name + ".tmp")
        tmp.write_text(json.dumps(meta), encoding="utf-8")
        tmp.replace(target)
        logger.info("[nats] task_synced user=%s count=%s", user_id, payload.get("count"))
    elif subject == "signals.tip.feedback":
        logger.info(
            "[nats] tip_feedback user=%s tip=%s action=%s reward=%s",
            payload.get("userId"), payload.get("tipId"),
            payload.get("action"), payload.get("reward"),
        )
    else:
        logger.debug("[nats] unhandled subject=%s", subject)
# ── Consumer factory ───────────────────────────────────────────────────────
def _make_handler(key: str, state_dir: Path):
    """Build the async push-consumer callback for one consumer.

    The callback stamps ``consumer_health[key]["last_msg_ts"]`` on every
    delivery, decodes the message body as JSON, passes it to ``_handle`` and
    acks the message. On any exception the error counter is incremented and
    the message is nak'd.

    Args:
        key: Entry of ``consumer_health`` that the callback updates.
        state_dir: Directory forwarded to ``_handle``.

    Returns:
        A coroutine function taking the delivered message.
    """
    async def handler(msg) -> None:
        stats = consumer_health[key]
        stats["last_msg_ts"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
        try:
            payload = json.loads(msg.data)
            await _handle(msg.subject, payload, state_dir)
            await msg.ack()
        except Exception as exc:
            stats["errors"] += 1
            logger.warning("[nats] processing error key=%s subject=%s: %s", key, msg.subject, exc)
            try:
                await msg.nak()
            except Exception as nak_exc:
                # The nak itself can fail (e.g. the connection is closing).
                # Keep that failure inside the callback instead of letting it
                # escape as an unhandled callback error.
                logger.warning("[nats] nak failed key=%s subject=%s: %s", key, msg.subject, nak_exc)
        else:
            # Counted only after the ack went through.
            stats["processed"] += 1
    return handler
# ── Lifecycle ──────────────────────────────────────────────────────────────
async def start(state_dir: Path) -> None:
    """Connect to NATS and register the durable push consumers.

    Does nothing when ``NATS_URL`` is empty. A failed connection (or a missing
    ``nats`` package) is logged and leaves the consumers disabled; a failed
    subscription is logged and the remaining subscriptions are still
    attempted. Successful subscriptions are appended to ``_subs``.

    Args:
        state_dir: Directory handed to the message handlers.
    """
    global _nc
    if not NATS_URL:
        logger.info("[nats] NATS_URL unset — JetStream consumers disabled")
        return
    try:
        # Imported here so the module can be loaded without nats-py installed
        # when the consumers are disabled.
        import nats as nats_lib
        from nats.js.api import ConsumerConfig, AckPolicy
        _nc = await nats_lib.connect(
            NATS_URL,
            name="ml-serving",
            reconnect_time_wait=5,
            max_reconnect_attempts=-1,
        )
        js = _nc.jetstream()
        logger.info("[nats] connected to %s", NATS_URL)
    except Exception as exc:
        logger.warning("[nats] connection failed: %s — consumers disabled", exc)
        _nc = None
        return
    for key, subject in [("signals", "signals.>"), ("feedback", "feedback.>")]:
        durable = f"{NATS_DURABLE_PREFIX}-{key}"
        # Build a fresh config for every subscription: subscribe() may fill in
        # per-consumer fields on the object it is given, so one shared
        # instance would leak the first consumer's values into the second.
        config = ConsumerConfig(
            ack_policy=AckPolicy.EXPLICIT,
            max_deliver=NATS_MAX_DELIVER,
        )
        try:
            sub = await js.subscribe(
                subject,
                durable=durable,
                cb=_make_handler(key, state_dir),
                config=config,
            )
            _subs.append(sub)
            logger.info("[nats] subscribed subject=%s durable=%s", subject, durable)
        except Exception as exc:
            logger.warning("[nats] subscribe failed key=%s: %s", key, exc)
async def stop() -> None:
    """Unsubscribe every active subscription, then drain the NATS connection.

    Shutdown is best effort: failures are logged and never raised, so one
    broken subscription cannot prevent the rest of the cleanup. ``_subs`` is
    emptied and ``_nc`` is reset to ``None``.
    """
    global _nc
    # NOTE(review): unsubscribing before the drain presumably discards
    # messages still queued client-side rather than processing them — confirm
    # this is intended.
    for sub in _subs:
        try:
            await sub.unsubscribe()
        except Exception as exc:
            logger.debug("[nats] unsubscribe failed: %s", exc)
    _subs.clear()
    if _nc:
        try:
            await _nc.drain()
        except Exception as exc:
            logger.warning("[nats] drain failed: %s", exc)
        _nc = None
        logger.info("[nats] disconnected")

View File

@@ -4,3 +4,4 @@ pydantic==2.10.4
numpy>=1.26.0 numpy>=1.26.0
httpx>=0.27.0 httpx>=0.27.0
anthropic>=0.40.0 anthropic>=0.40.0
nats-py>=2.9.0