""" JetStream durable consumers for ml/serving. Streams: signals (subjects: signals.>) — durable: {prefix}-signals feedback (subjects: feedback.>) — durable: {prefix}-feedback Handled subjects: signals.task.synced → write per-user sync metadata to STATE_DIR signals.tip.feedback → log for observability (reward is applied via HTTP path) Config (env vars): NATS_URL — broker URL; empty = consumers disabled (default: "") NATS_DURABLE_PREFIX — prefix for durable consumer names (default: "feature-pipeline") NATS_MAX_DELIVER — max redelivery attempts before dropping (default: 5) """ from __future__ import annotations import json import os import time from pathlib import Path from typing import Optional import structlog from schemas import TaskSyncedPayload, TipFeedbackPayload log = structlog.get_logger(__name__) NATS_URL = os.getenv("NATS_URL", "") NATS_DURABLE_PREFIX = os.getenv("NATS_DURABLE_PREFIX", "feature-pipeline") NATS_MAX_DELIVER = int(os.getenv("NATS_MAX_DELIVER", "5")) # Exposed to /health consumer_health: dict[str, dict] = { "signals": {"last_msg_ts": None, "processed": 0, "errors": 0}, "feedback": {"last_msg_ts": None, "processed": 0, "errors": 0}, } _nc = None # nats.aio.Client _subs: list = [] # active JetStream subscriptions # ── Subject handlers ─────────────────────────────────────────────────────── def _sync_meta_path(state_dir: Path, user_id: str) -> Path: safe = "".join(c if c.isalnum() else "_" for c in user_id) return state_dir / f"{safe}_sync.json" async def _handle(subject: str, payload: dict, state_dir: Path) -> None: if subject == "signals.task.synced": msg = TaskSyncedPayload.model_validate(payload) p = _sync_meta_path(state_dir, msg.userId) p.write_text(json.dumps({ "last_sync_ts": msg.syncedAt, "task_count": msg.count, })) log.info("nats: task_synced", user_id=msg.userId, count=msg.count) elif subject == "signals.tip.feedback": msg = TipFeedbackPayload.model_validate(payload) log.info("nats: tip_feedback", user_id=msg.userId, tip_id=msg.tipId, action=msg.action, reward=msg.reward) else: log.debug("nats: unhandled subject", subject=subject) # ── Consumer factory ─────────────────────────────────────────────────────── def _make_handler(key: str, state_dir: Path): """Return an async push-consumer callback that acks on success, naks on error.""" async def handler(msg) -> None: consumer_health[key]["last_msg_ts"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) try: payload = json.loads(msg.data) await _handle(msg.subject, payload, state_dir) await msg.ack() consumer_health[key]["processed"] += 1 except Exception as exc: consumer_health[key]["errors"] += 1 log.warning("nats: processing error", key=key, subject=msg.subject, exc=str(exc)) await msg.nak() return handler # ── Lifecycle ────────────────────────────────────────────────────────────── async def start(state_dir: Path) -> None: """Connect to NATS and register durable push consumers. No-op if NATS_URL is unset.""" global _nc if not NATS_URL: log.info("nats: NATS_URL unset — JetStream consumers disabled") return try: import nats as nats_lib from nats.js.api import ConsumerConfig, AckPolicy _nc = await nats_lib.connect( NATS_URL, name="ml-serving", reconnect_time_wait=5, max_reconnect_attempts=-1, ) js = _nc.jetstream() log.info("nats: connected", url=NATS_URL) except Exception as exc: log.warning("nats: connection failed — consumers disabled", exc=str(exc)) _nc = None return config = ConsumerConfig( ack_policy=AckPolicy.EXPLICIT, max_deliver=NATS_MAX_DELIVER, ) for key, subject in [("signals", "signals.>"), ("feedback", "feedback.>")]: durable = f"{NATS_DURABLE_PREFIX}-{key}" try: sub = await js.subscribe( subject, durable=durable, cb=_make_handler(key, state_dir), config=config, ) _subs.append(sub) log.info("nats: subscribed", subject=subject, durable=durable) except Exception as exc: log.warning("nats: subscribe failed", key=key, exc=str(exc)) async def stop() -> None: """Drain subscriptions and close NATS connection.""" global _nc for sub in _subs: try: await sub.unsubscribe() except Exception: pass _subs.clear() if _nc: try: await _nc.drain() except Exception: pass _nc = None log.info("nats: disconnected")