diff --git a/infra/docker/Dockerfile.ml b/infra/docker/Dockerfile.ml
index b3f7fb7..27701c1 100644
--- a/infra/docker/Dockerfile.ml
+++ b/infra/docker/Dockerfile.ml
@@ -2,5 +2,5 @@ FROM python:3.12-slim
 WORKDIR /app
 COPY ml/serving/requirements.txt .
 RUN pip install --no-cache-dir -r requirements.txt
-COPY ml/serving/main.py .
+COPY ml/serving/*.py .
 CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
diff --git a/ml/serving/main.py b/ml/serving/main.py
index 473f043..f03e5fc 100644
--- a/ml/serving/main.py
+++ b/ml/serving/main.py
@@ -28,6 +28,7 @@ import math
 import os
 import time
 from collections import deque
+from contextlib import asynccontextmanager
 from pathlib import Path
 from typing import Optional, Deque
 
@@ -36,9 +37,18 @@ import numpy as np
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 
+import nats_consumer
 from prompts import get_prompt
 
-app = FastAPI(title="oO ML Serving", version="1.0.0")
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    await nats_consumer.start(STATE_DIR)
+    yield
+    await nats_consumer.stop()
+
+
+app = FastAPI(title="oO ML Serving", version="1.0.0", lifespan=lifespan)
 
 LITELLM_URL = os.getenv("LITELLM_URL", "http://localhost:4000")
 LITELLM_MASTER_KEY = os.getenv("LITELLM_MASTER_KEY", "sk-oo-dev")
@@ -315,7 +325,13 @@ class GenerateResponse(BaseModel):
 
 @app.get("/health")
 def health():
-    return {"ok": True}
+    return {
+        "ok": True,
+        "nats": {
+            "enabled": bool(nats_consumer.NATS_URL),
+            "consumers": nats_consumer.consumer_health,
+        },
+    }
 
 
 _RETRY_SUFFIX = (
diff --git a/ml/serving/nats_consumer.py b/ml/serving/nats_consumer.py
new file mode 100644
index 0000000..30b1fc1
--- /dev/null
+++ b/ml/serving/nats_consumer.py
@@ -0,0 +1,223 @@
+"""
+JetStream durable consumers for ml/serving.
+
+Streams:
+  signals   (subjects: signals.>)   — durable: {prefix}-signals
+  feedback  (subjects: feedback.>)  — durable: {prefix}-feedback
+
+Handled subjects:
+  signals.task.synced   → write per-user sync metadata to STATE_DIR
+  signals.tip.feedback  → log for observability (reward is applied via HTTP path)
+
+Config (env vars):
+  NATS_URL              — broker URL; empty = consumers disabled (default: "")
+  NATS_DURABLE_PREFIX   — prefix for durable consumer names (default: "feature-pipeline")
+  NATS_MAX_DELIVER      — max delivery attempts per message; invalid values fall back to the default (default: 5)
+"""
+from __future__ import annotations
+
+import json
+import logging
+import os
+import time
+from pathlib import Path
+from typing import Optional
+
+logger = logging.getLogger(__name__)
+
+
+def _env_int(name: str, default: int) -> int:
+    """Read an integer env var, falling back to *default* if unset or invalid.
+
+    main.py imports this module at startup, so a malformed value must degrade
+    to the default instead of raising and taking the HTTP service down.
+    """
+    raw = os.getenv(name, "").strip()
+    if not raw:
+        return default
+    try:
+        return int(raw)
+    except ValueError:
+        logger.warning("[nats] invalid %s=%r — using default %d", name, raw, default)
+        return default
+
+
+NATS_URL = os.getenv("NATS_URL", "")
+NATS_DURABLE_PREFIX = os.getenv("NATS_DURABLE_PREFIX", "feature-pipeline")
+NATS_MAX_DELIVER = _env_int("NATS_MAX_DELIVER", 5)
+
+# Exposed to /health
+consumer_health: dict[str, dict] = {
+    "signals": {"last_msg_ts": None, "processed": 0, "errors": 0},
+    "feedback": {"last_msg_ts": None, "processed": 0, "errors": 0},
+}
+
+_nc = None  # nats.aio.Client
+_subs: list = []  # active JetStream subscriptions
+
+
+def _utc_now() -> str:
+    """Return the current UTC time as an ISO-8601 string with second precision."""
+    return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
+
+
+# ── Subject handlers ───────────────────────────────────────────────────────
+
+def _sync_meta_path(state_dir: Path, user_id: str) -> Path:
+    """Return the sync-metadata file path for *user_id* inside *state_dir*.
+
+    Non-alphanumeric characters are replaced with "_" so the id cannot add
+    path separators or ".." components to the file name.
+    """
+    safe = "".join(c if c.isalnum() else "_" for c in user_id)
+    return state_dir / f"{safe}_sync.json"
+
+
+async def _handle(subject: str, payload: dict, state_dir: Path) -> None:
+    """Process one decoded message according to its subject.
+
+    Exceptions (e.g. OSError from the metadata write) propagate to the caller,
+    which decides how the message is settled.
+    """
+    if subject == "signals.task.synced":
+        user_id = payload.get("userId", "")
+        if user_id:
+            # str(): tolerate numeric ids, which would otherwise fail in _sync_meta_path.
+            p = _sync_meta_path(state_dir, str(user_id))
+            p.write_text(json.dumps({
+                "last_sync_ts": payload.get("syncedAt") or _utc_now(),
+                "task_count": payload.get("count", 0),
+            }))
+            logger.info("[nats] task_synced user=%s count=%s", user_id, payload.get("count"))
+    elif subject == "signals.tip.feedback":
+        logger.info(
+            "[nats] tip_feedback user=%s tip=%s action=%s reward=%s",
+            payload.get("userId"), payload.get("tipId"),
+            payload.get("action"), payload.get("reward"),
+        )
+    else:
+        logger.debug("[nats] unhandled subject=%s", subject)
+
+
+# ── Consumer factory ───────────────────────────────────────────────────────
+
+async def _settle(action, key: str, subject: str) -> None:
+    """Await a nak/term call, logging (not raising) if it fails.
+
+    Keeps a broken connection from turning an already-handled error into an
+    exception escaping the subscription callback.
+    """
+    try:
+        await action()
+    except Exception as exc:
+        logger.warning(
+            "[nats] %s failed key=%s subject=%s: %s",
+            getattr(action, "__name__", "settle"), key, subject, exc,
+        )
+
+
+def _make_handler(key: str, state_dir: Path):
+    """Return an async push-consumer callback for the *key* consumer.
+
+    Settlement per message:
+      handled successfully          → ack
+      payload is not a JSON object  → term (redelivery cannot fix it)
+      any other failure             → nak (redelivered up to NATS_MAX_DELIVER)
+    """
+    async def handler(msg) -> None:
+        stats = consumer_health[key]
+        stats["last_msg_ts"] = _utc_now()
+
+        try:
+            payload = json.loads(msg.data)
+            if not isinstance(payload, dict):
+                raise ValueError(f"expected a JSON object, got {type(payload).__name__}")
+        except ValueError as exc:  # JSONDecodeError and UnicodeDecodeError included
+            stats["errors"] += 1
+            logger.warning("[nats] bad payload key=%s subject=%s: %s", key, msg.subject, exc)
+            await _settle(msg.term, key, msg.subject)
+            return
+
+        try:
+            await _handle(msg.subject, payload, state_dir)
+            await msg.ack()
+        except Exception as exc:
+            stats["errors"] += 1
+            logger.warning("[nats] processing error key=%s subject=%s: %s", key, msg.subject, exc)
+            await _settle(msg.nak, key, msg.subject)
+        else:
+            stats["processed"] += 1
+    return handler
+
+
+# ── Lifecycle ──────────────────────────────────────────────────────────────
+
+async def start(state_dir: Path) -> None:
+    """Connect to NATS and register durable push consumers.
+
+    No-op if NATS_URL is unset. Import, connection and subscription failures
+    are logged and leave the consumers disabled; nothing is raised to the caller.
+    """
+    global _nc
+    if not NATS_URL:
+        logger.info("[nats] NATS_URL unset — JetStream consumers disabled")
+        return
+
+    try:
+        # Imported only once NATS_URL is set, so nats-py is not needed when disabled.
+        import nats as nats_lib
+        from nats.js.api import ConsumerConfig, AckPolicy
+
+        _nc = await nats_lib.connect(
+            NATS_URL,
+            name="ml-serving",
+            reconnect_time_wait=5,
+            max_reconnect_attempts=-1,
+        )
+        js = _nc.jetstream()
+        logger.info("[nats] connected to %s", NATS_URL)
+    except Exception as exc:
+        logger.warning("[nats] connection failed: %s — consumers disabled", exc)
+        _nc = None
+        return
+
+    config = ConsumerConfig(
+        ack_policy=AckPolicy.EXPLICIT,
+        max_deliver=NATS_MAX_DELIVER,
+    )
+
+    for key, subject in [("signals", "signals.>"), ("feedback", "feedback.>")]:
+        durable = f"{NATS_DURABLE_PREFIX}-{key}"
+        try:
+            sub = await js.subscribe(
+                subject,
+                durable=durable,
+                cb=_make_handler(key, state_dir),
+                config=config,
+                # The handler settles every message itself; without this the
+                # client wraps the callback and auto-acks after it returns.
+                manual_ack=True,
+            )
+            _subs.append(sub)
+            logger.info("[nats] subscribed subject=%s durable=%s", subject, durable)
+        except Exception as exc:
+            logger.warning("[nats] subscribe failed key=%s: %s", key, exc)
+
+
+async def stop() -> None:
+    """Drain subscriptions and close NATS connection.
+
+    The connection-level drain puts every subscription into drain mode, so
+    messages already delivered are still processed and settled before the
+    connection closes. Subscriptions are therefore not unsubscribed first.
+    """
+    global _nc
+    _subs.clear()
+    if _nc is None:
+        return
+    try:
+        await _nc.drain()
+    except Exception as exc:
+        logger.warning("[nats] drain failed: %s", exc)
+    _nc = None
+    logger.info("[nats] disconnected")
diff --git a/ml/serving/requirements.txt b/ml/serving/requirements.txt
index 05f3cdd..af17243 100644
--- a/ml/serving/requirements.txt
+++ b/ml/serving/requirements.txt
@@ -4,3 +4,4 @@ pydantic==2.10.4
 numpy>=1.26.0
 httpx>=0.27.0
 anthropic>=0.40.0
+nats-py>=2.9.0