feat(ml): JetStream durable consumers in ml/serving (#98)
Adds a NATS JetStream consumer to ml/serving so the feature pipeline can react to events without the API triggering every read. - nats_consumer.py: durable push consumers for signals.> and feedback.> streams; acks on success, naks for redeliver, up to NATS_MAX_DELIVER attempts; per-consumer health state (last_msg_ts, processed, errors) - main.py: FastAPI lifespan wires start/stop; /health exposes nats state - requirements.txt: adds nats-py>=2.9.0 - Dockerfile.ml: copy all *.py from ml/serving (was missing prompts.py) Handled subjects: signals.task.synced → writes per-user sync metadata to STATE_DIR signals.tip.feedback → logged for observability (reward via HTTP path) Config: NATS_URL (empty = disabled), NATS_DURABLE_PREFIX, NATS_MAX_DELIVER Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -28,6 +28,7 @@ import math
|
||||
import os
|
||||
import time
|
||||
from collections import deque
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import Optional, Deque
|
||||
|
||||
@@ -36,9 +37,18 @@ import numpy as np
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
import nats_consumer
|
||||
from prompts import get_prompt
|
||||
|
||||
app = FastAPI(title="oO ML Serving", version="1.0.0")
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
await nats_consumer.start(STATE_DIR)
|
||||
yield
|
||||
await nats_consumer.stop()
|
||||
|
||||
|
||||
app = FastAPI(title="oO ML Serving", version="1.0.0", lifespan=lifespan)
|
||||
|
||||
LITELLM_URL = os.getenv("LITELLM_URL", "http://localhost:4000")
|
||||
LITELLM_MASTER_KEY = os.getenv("LITELLM_MASTER_KEY", "sk-oo-dev")
|
||||
@@ -315,7 +325,13 @@ class GenerateResponse(BaseModel):
|
||||
|
||||
@app.get("/health")
|
||||
def health():
|
||||
return {"ok": True}
|
||||
return {
|
||||
"ok": True,
|
||||
"nats": {
|
||||
"enabled": bool(nats_consumer.NATS_URL),
|
||||
"consumers": nats_consumer.consumer_health,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
_RETRY_SUFFIX = (
|
||||
|
||||
148
ml/serving/nats_consumer.py
Normal file
148
ml/serving/nats_consumer.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""
|
||||
JetStream durable consumers for ml/serving.
|
||||
|
||||
Streams:
|
||||
signals (subjects: signals.>) — durable: {prefix}-signals
|
||||
feedback (subjects: feedback.>) — durable: {prefix}-feedback
|
||||
|
||||
Handled subjects:
|
||||
signals.task.synced → write per-user sync metadata to STATE_DIR
|
||||
signals.tip.feedback → log for observability (reward is applied via HTTP path)
|
||||
|
||||
Config (env vars):
|
||||
NATS_URL — broker URL; empty = consumers disabled (default: "")
|
||||
NATS_DURABLE_PREFIX — prefix for durable consumer names (default: "feature-pipeline")
|
||||
NATS_MAX_DELIVER — max redelivery attempts before dropping (default: 5)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
NATS_URL = os.getenv("NATS_URL", "")
|
||||
NATS_DURABLE_PREFIX = os.getenv("NATS_DURABLE_PREFIX", "feature-pipeline")
|
||||
NATS_MAX_DELIVER = int(os.getenv("NATS_MAX_DELIVER", "5"))
|
||||
|
||||
# Exposed to /health
|
||||
consumer_health: dict[str, dict] = {
|
||||
"signals": {"last_msg_ts": None, "processed": 0, "errors": 0},
|
||||
"feedback": {"last_msg_ts": None, "processed": 0, "errors": 0},
|
||||
}
|
||||
|
||||
_nc = None # nats.aio.Client
|
||||
_subs: list = [] # active JetStream subscriptions
|
||||
|
||||
|
||||
# ── Subject handlers ───────────────────────────────────────────────────────
|
||||
|
||||
def _sync_meta_path(state_dir: Path, user_id: str) -> Path:
|
||||
safe = "".join(c if c.isalnum() else "_" for c in user_id)
|
||||
return state_dir / f"{safe}_sync.json"
|
||||
|
||||
|
||||
async def _handle(subject: str, payload: dict, state_dir: Path) -> None:
|
||||
if subject == "signals.task.synced":
|
||||
user_id = payload.get("userId", "")
|
||||
if user_id:
|
||||
p = _sync_meta_path(state_dir, user_id)
|
||||
p.write_text(json.dumps({
|
||||
"last_sync_ts": payload.get("syncedAt") or time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
|
||||
"task_count": payload.get("count", 0),
|
||||
}))
|
||||
logger.info("[nats] task_synced user=%s count=%s", user_id, payload.get("count"))
|
||||
elif subject == "signals.tip.feedback":
|
||||
logger.info(
|
||||
"[nats] tip_feedback user=%s tip=%s action=%s reward=%s",
|
||||
payload.get("userId"), payload.get("tipId"),
|
||||
payload.get("action"), payload.get("reward"),
|
||||
)
|
||||
else:
|
||||
logger.debug("[nats] unhandled subject=%s", subject)
|
||||
|
||||
|
||||
# ── Consumer factory ───────────────────────────────────────────────────────
|
||||
|
||||
def _make_handler(key: str, state_dir: Path):
|
||||
"""Return an async push-consumer callback that acks on success, naks on error."""
|
||||
async def handler(msg) -> None:
|
||||
consumer_health[key]["last_msg_ts"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
|
||||
try:
|
||||
payload = json.loads(msg.data)
|
||||
await _handle(msg.subject, payload, state_dir)
|
||||
await msg.ack()
|
||||
consumer_health[key]["processed"] += 1
|
||||
except Exception as exc:
|
||||
consumer_health[key]["errors"] += 1
|
||||
logger.warning("[nats] processing error key=%s subject=%s: %s", key, msg.subject, exc)
|
||||
await msg.nak()
|
||||
return handler
|
||||
|
||||
|
||||
# ── Lifecycle ──────────────────────────────────────────────────────────────
|
||||
|
||||
async def start(state_dir: Path) -> None:
|
||||
"""Connect to NATS and register durable push consumers. No-op if NATS_URL is unset."""
|
||||
global _nc
|
||||
if not NATS_URL:
|
||||
logger.info("[nats] NATS_URL unset — JetStream consumers disabled")
|
||||
return
|
||||
|
||||
try:
|
||||
import nats as nats_lib
|
||||
from nats.js.api import ConsumerConfig, AckPolicy
|
||||
|
||||
_nc = await nats_lib.connect(
|
||||
NATS_URL,
|
||||
name="ml-serving",
|
||||
reconnect_time_wait=5,
|
||||
max_reconnect_attempts=-1,
|
||||
)
|
||||
js = _nc.jetstream()
|
||||
logger.info("[nats] connected to %s", NATS_URL)
|
||||
except Exception as exc:
|
||||
logger.warning("[nats] connection failed: %s — consumers disabled", exc)
|
||||
_nc = None
|
||||
return
|
||||
|
||||
config = ConsumerConfig(
|
||||
ack_policy=AckPolicy.EXPLICIT,
|
||||
max_deliver=NATS_MAX_DELIVER,
|
||||
)
|
||||
|
||||
for key, subject in [("signals", "signals.>"), ("feedback", "feedback.>")]:
|
||||
durable = f"{NATS_DURABLE_PREFIX}-{key}"
|
||||
try:
|
||||
sub = await js.subscribe(
|
||||
subject,
|
||||
durable=durable,
|
||||
cb=_make_handler(key, state_dir),
|
||||
config=config,
|
||||
)
|
||||
_subs.append(sub)
|
||||
logger.info("[nats] subscribed subject=%s durable=%s", subject, durable)
|
||||
except Exception as exc:
|
||||
logger.warning("[nats] subscribe failed key=%s: %s", key, exc)
|
||||
|
||||
|
||||
async def stop() -> None:
|
||||
"""Drain subscriptions and close NATS connection."""
|
||||
global _nc
|
||||
for sub in _subs:
|
||||
try:
|
||||
await sub.unsubscribe()
|
||||
except Exception:
|
||||
pass
|
||||
_subs.clear()
|
||||
if _nc:
|
||||
try:
|
||||
await _nc.drain()
|
||||
except Exception:
|
||||
pass
|
||||
_nc = None
|
||||
logger.info("[nats] disconnected")
|
||||
@@ -4,3 +4,4 @@ pydantic==2.10.4
|
||||
numpy>=1.26.0
|
||||
httpx>=0.27.0
|
||||
anthropic>=0.40.0
|
||||
nats-py>=2.9.0
|
||||
|
||||
Reference in New Issue
Block a user