- Add proto schemas in packages/shared-types/events/ (oo.events.v1): envelope.proto, signals.proto, integration.proto
- buf.yaml with STANDARD lint + FILE breaking-change rules
- .gitea/workflows/buf-check.yaml: lint + breaking check on every PR touching events/ (needs a Gitea Actions runner to execute)
- scripts/buf-check.sh: local equivalent of the CI check
- NormalizedEvent TS envelope gains eventId, schemaVersion, producer to align with the proto Envelope message
- ml/serving/schemas.py: pydantic models mirroring the v1 proto types
- nats_consumer.py: validate payloads via pydantic instead of raw .get()

A field-rename PR will now fail buf breaking with exit code 100 and show the offending messages. To make a breaking change: keep the old field reserved, add the new one, bump schema_version to v2.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
150 lines
5.2 KiB
Python
150 lines
5.2 KiB
Python
"""
JetStream durable consumers for ml/serving.

Streams:
  signals   (subjects: signals.>)   — durable: {prefix}-signals
  feedback  (subjects: feedback.>)  — durable: {prefix}-feedback

Handled subjects:
  signals.task.synced   → write per-user sync metadata to STATE_DIR
  signals.tip.feedback  → log for observability (reward is applied via HTTP path)

Config (env vars):
  NATS_URL             — broker URL; empty = consumers disabled (default: "")
  NATS_DURABLE_PREFIX  — prefix for durable consumer names (default: "feature-pipeline")
  NATS_MAX_DELIVER     — max redelivery attempts before dropping (default: 5)
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
from schemas import TaskSyncedPayload, TipFeedbackPayload
|
|
|
|
logger = logging.getLogger(__name__)

# Broker settings, read from the environment once at import time.
NATS_URL = os.environ.get("NATS_URL", "")
NATS_DURABLE_PREFIX = os.environ.get("NATS_DURABLE_PREFIX", "feature-pipeline")
NATS_MAX_DELIVER = int(os.environ.get("NATS_MAX_DELIVER", "5"))

# Exposed to /health
# One independent counter dict per consumer key.
consumer_health: dict[str, dict] = {
    name: {"last_msg_ts": None, "processed": 0, "errors": 0}
    for name in ("signals", "feedback")
}

_nc = None  # nats.aio.Client
_subs: list = []  # active JetStream subscriptions
|
|
|
|
|
|
# ── Subject handlers ───────────────────────────────────────────────────────
|
|
|
|
def _sync_meta_path(state_dir: Path, user_id: str) -> Path:
    """Return the per-user sync-metadata file path inside *state_dir*.

    Each character of *user_id* that is not alphanumeric (as judged by
    ``str.isalnum``) is replaced with ``_``; the result is suffixed with
    ``_sync.json`` to form the file name.
    """
    sanitized = [ch if ch.isalnum() else "_" for ch in user_id]
    file_name = "".join(sanitized) + "_sync.json"
    return state_dir / file_name
|
|
|
|
|
|
async def _handle(subject: str, payload: dict, state_dir: Path) -> None:
    """Validate *payload* for *subject* and apply the matching side effect.

    ``signals.task.synced``
        Validated as ``TaskSyncedPayload``; ``last_sync_ts`` and
        ``task_count`` are written as JSON to the file returned by
        ``_sync_meta_path(state_dir, userId)``.
    ``signals.tip.feedback``
        Validated as ``TipFeedbackPayload`` and logged only.
    anything else
        Logged at debug level and otherwise ignored.

    Validation errors and file-system errors are not caught here; they
    propagate to the caller.
    """
    if subject == "signals.task.synced":
        synced = TaskSyncedPayload.model_validate(payload)
        target = _sync_meta_path(state_dir, synced.userId)
        body = json.dumps({
            "last_sync_ts": synced.syncedAt,
            "task_count": synced.count,
        })
        # Create the directory on demand so a missing state dir does not make
        # every message fail.
        target.parent.mkdir(parents=True, exist_ok=True)
        # Write to a sibling temp file and rename it over the target, so the
        # target is never left truncated if the write fails part-way.
        tmp = target.with_name(target.name + ".tmp")
        tmp.write_text(body, encoding="utf-8")
        tmp.replace(target)
        logger.info("[nats] task_synced user=%s count=%s", synced.userId, synced.count)
    elif subject == "signals.tip.feedback":
        feedback = TipFeedbackPayload.model_validate(payload)
        logger.info(
            "[nats] tip_feedback user=%s tip=%s action=%s reward=%s",
            feedback.userId, feedback.tipId, feedback.action, feedback.reward,
        )
    else:
        logger.debug("[nats] unhandled subject=%s", subject)
|
|
|
|
|
|
# ── Consumer factory ───────────────────────────────────────────────────────
|
|
|
|
def _make_handler(key: str, state_dir: Path):
    """Return an async push-consumer callback that acks on success, naks on error.

    Args:
        key: Entry of ``consumer_health`` whose counters the callback updates.
        state_dir: Directory passed through to ``_handle``.

    The returned callback stamps ``last_msg_ts`` with the current UTC time,
    decodes ``msg.data`` as JSON, hands it to ``_handle`` and acks. Any
    exception on that path bumps the ``errors`` counter, is logged, and the
    message is nak'ed. A failure to send the nak is logged and swallowed so
    it neither masks the original error nor escapes the callback.
    """
    async def handler(msg) -> None:
        stats = consumer_health[key]
        stats["last_msg_ts"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
        try:
            payload = json.loads(msg.data)
            await _handle(msg.subject, payload, state_dir)
            await msg.ack()
        except Exception as exc:
            # Deliberately broad: this is the boundary between the NATS
            # client and our processing code.
            stats["errors"] += 1
            logger.warning("[nats] processing error key=%s subject=%s: %s", key, msg.subject, exc)
            try:
                await msg.nak()
            except Exception as nak_exc:
                # The message simply stays unacknowledged in this case.
                logger.warning("[nats] nak failed key=%s subject=%s: %s", key, msg.subject, nak_exc)
        else:
            stats["processed"] += 1

    return handler
|
|
|
|
|
|
# ── Lifecycle ──────────────────────────────────────────────────────────────
|
|
|
|
async def start(state_dir: Path) -> None:
    """Connect to NATS and register durable push consumers. No-op if NATS_URL is unset.

    Args:
        state_dir: Directory handed to the message handlers.

    A durable push subscription is created for ``signals.>`` and for
    ``feedback.>``, named ``{NATS_DURABLE_PREFIX}-{key}``. Connection or
    import failures are logged and leave the consumers disabled; a failure
    to subscribe to one subject is logged and does not prevent the other.
    Calling ``start`` again while a connection is held is ignored.
    """
    global _nc
    if not NATS_URL:
        logger.info("[nats] NATS_URL unset — JetStream consumers disabled")
        return
    if _nc is not None:
        # A second start would leak the existing connection and register
        # duplicate subscriptions.
        logger.warning("[nats] start() called while already connected — ignoring")
        return

    try:
        # Imported lazily so the module loads without nats-py when NATS_URL is unset.
        import nats as nats_lib
        from nats.js.api import ConsumerConfig, AckPolicy

        _nc = await nats_lib.connect(
            NATS_URL,
            name="ml-serving",
            reconnect_time_wait=5,
            max_reconnect_attempts=-1,
        )
        js = _nc.jetstream()
        logger.info("[nats] connected to %s", NATS_URL)
    except Exception as exc:
        logger.warning("[nats] connection failed: %s — consumers disabled", exc)
        _nc = None
        return

    for key, subject in [("signals", "signals.>"), ("feedback", "feedback.>")]:
        durable = f"{NATS_DURABLE_PREFIX}-{key}"
        # Fresh config per subscription: js.subscribe may fill in per-consumer
        # fields (durable name, deliver subject) on the object it is given, so
        # sharing one instance could leak the first consumer's values into the
        # second.
        config = ConsumerConfig(
            ack_policy=AckPolicy.EXPLICIT,
            max_deliver=NATS_MAX_DELIVER,
        )
        try:
            sub = await js.subscribe(
                subject,
                durable=durable,
                cb=_make_handler(key, state_dir),
                config=config,
                # The handler acks/naks itself; turn off the client's
                # automatic ack after the callback returns.
                manual_ack=True,
            )
            _subs.append(sub)
            logger.info("[nats] subscribed subject=%s durable=%s", subject, durable)
        except Exception as exc:
            logger.warning("[nats] subscribe failed key=%s: %s", key, exc)
|
|
|
|
|
|
async def stop() -> None:
    """Unsubscribe all consumers, then drain and close the NATS connection.

    Shutdown is best-effort: a failure to unsubscribe or to drain is logged
    and does not stop the rest of the teardown. Afterwards ``_subs`` is
    empty and ``_nc`` is ``None``. Safe to call when ``start`` never
    connected.
    """
    global _nc
    for sub in _subs:
        try:
            await sub.unsubscribe()
        except Exception as exc:
            logger.warning("[nats] unsubscribe failed: %s", exc)
    _subs.clear()

    if _nc is not None:
        try:
            await _nc.drain()
        except Exception as exc:
            logger.warning("[nats] drain failed: %s", exc)
        _nc = None
        logger.info("[nats] disconnected")
|