test(schema): unit tests for schemas.py and nats_consumer._handle (#54)
17 tests covering: pydantic model validation (all payload types, optional fields, invalid enum values, missing required fields), _handle write path for task_synced, validation errors surfaced through _make_handler causing nak instead of ack. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
169
ml/serving/tests/test_schemas_and_consumer.py
Normal file
169
ml/serving/tests/test_schemas_and_consumer.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""
|
||||
Tests for schemas.py and nats_consumer._handle.
|
||||
"""
|
||||
import json
|
||||
import pytest
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from pydantic import ValidationError
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from schemas import (
|
||||
TaskSyncedPayload,
|
||||
TipServedPayload,
|
||||
TipFeedbackPayload,
|
||||
TipRewardFailedPayload,
|
||||
IntegrationTokenExpiredPayload,
|
||||
)
|
||||
from nats_consumer import _handle, _sync_meta_path
|
||||
|
||||
|
||||
# ── Schema validation ─────────────────────────────────────────────────────────
|
||||
|
||||
class TestTaskSyncedPayload:
|
||||
def test_valid(self):
|
||||
p = TaskSyncedPayload.model_validate(
|
||||
{"userId": "u1", "source": "todoist", "count": 5, "syncedAt": "2026-04-25T10:00:00Z"}
|
||||
)
|
||||
assert p.userId == "u1"
|
||||
assert p.count == 5
|
||||
|
||||
def test_missing_field_raises(self):
|
||||
with pytest.raises(ValidationError):
|
||||
TaskSyncedPayload.model_validate({"userId": "u1", "source": "todoist"})
|
||||
|
||||
def test_wrong_type_raises(self):
|
||||
with pytest.raises(ValidationError):
|
||||
TaskSyncedPayload.model_validate(
|
||||
{"userId": "u1", "source": "todoist", "count": "not-an-int", "syncedAt": "2026-04-25T10:00:00Z"}
|
||||
)
|
||||
|
||||
|
||||
class TestTipFeedbackPayload:
|
||||
def test_valid_without_dwell(self):
|
||||
p = TipFeedbackPayload.model_validate(
|
||||
{"userId": "u1", "tipId": "t1", "action": "done", "reward": 1.0, "createdAt": "2026-04-25T10:00:00Z"}
|
||||
)
|
||||
assert p.dwellMs is None
|
||||
|
||||
def test_valid_with_dwell(self):
|
||||
p = TipFeedbackPayload.model_validate(
|
||||
{"userId": "u1", "tipId": "t1", "action": "helpful", "reward": 0.5,
|
||||
"dwellMs": 3200, "createdAt": "2026-04-25T10:00:00Z"}
|
||||
)
|
||||
assert p.dwellMs == 3200
|
||||
|
||||
def test_invalid_action_raises(self):
|
||||
with pytest.raises(ValidationError):
|
||||
TipFeedbackPayload.model_validate(
|
||||
{"userId": "u1", "tipId": "t1", "action": "like", "reward": 1.0, "createdAt": "2026-04-25T10:00:00Z"}
|
||||
)
|
||||
|
||||
def test_all_valid_actions(self):
|
||||
for action in ("done", "dismiss", "snooze", "helpful", "not_helpful"):
|
||||
p = TipFeedbackPayload.model_validate(
|
||||
{"userId": "u1", "tipId": "t1", "action": action, "reward": 0.0, "createdAt": "2026-04-25T10:00:00Z"}
|
||||
)
|
||||
assert p.action == action
|
||||
|
||||
|
||||
class TestOtherPayloads:
|
||||
def test_tip_served(self):
|
||||
p = TipServedPayload.model_validate(
|
||||
{"userId": "u1", "tipId": "t1", "policy": "egreedy-v2", "servedAt": "2026-04-25T10:00:00Z"}
|
||||
)
|
||||
assert p.policy == "egreedy-v2"
|
||||
|
||||
def test_tip_reward_failed(self):
|
||||
p = TipRewardFailedPayload.model_validate(
|
||||
{"userId": "u1", "tipId": "t1", "reward": 1.0, "attempts": 3,
|
||||
"error": "timeout", "failedAt": "2026-04-25T10:00:00Z"}
|
||||
)
|
||||
assert p.attempts == 3
|
||||
|
||||
def test_integration_token_expired(self):
|
||||
p = IntegrationTokenExpiredPayload.model_validate(
|
||||
{"userId": "u1", "provider": "todoist", "detectedAt": "2026-04-25T10:00:00Z"}
|
||||
)
|
||||
assert p.provider == "todoist"
|
||||
|
||||
|
||||
# ── _handle behaviour ─────────────────────────────────────────────────────────
|
||||
|
||||
TASK_SYNCED = {
|
||||
"userId": "user-abc",
|
||||
"source": "todoist",
|
||||
"count": 7,
|
||||
"syncedAt": "2026-04-25T10:00:00Z",
|
||||
}
|
||||
|
||||
TIP_FEEDBACK = {
|
||||
"userId": "user-abc",
|
||||
"tipId": "tip-xyz",
|
||||
"action": "done",
|
||||
"reward": 1.0,
|
||||
"dwellMs": 4200,
|
||||
"createdAt": "2026-04-25T10:00:00Z",
|
||||
}
|
||||
|
||||
|
||||
class TestHandle:
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_synced_writes_meta_file(self):
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
state_dir = Path(tmp)
|
||||
await _handle("signals.task.synced", TASK_SYNCED, state_dir)
|
||||
meta_path = _sync_meta_path(state_dir, "user-abc")
|
||||
assert meta_path.exists()
|
||||
data = json.loads(meta_path.read_text())
|
||||
assert data["task_count"] == 7
|
||||
assert data["last_sync_ts"] == "2026-04-25T10:00:00Z"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_synced_bad_payload_raises(self):
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
with pytest.raises(ValidationError):
|
||||
await _handle("signals.task.synced", {"userId": "u1"}, Path(tmp))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tip_feedback_valid_does_not_raise(self):
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
# should log and return cleanly
|
||||
await _handle("signals.tip.feedback", TIP_FEEDBACK, Path(tmp))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tip_feedback_bad_action_raises(self):
|
||||
bad = {**TIP_FEEDBACK, "action": "unknown"}
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
with pytest.raises(ValidationError):
|
||||
await _handle("signals.tip.feedback", bad, Path(tmp))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unhandled_subject_is_ignored(self):
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
# should not raise for unknown subjects
|
||||
await _handle("signals.something.new", {"any": "data"}, Path(tmp))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_make_handler_acks_on_success(self):
|
||||
from nats_consumer import _make_handler
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
handler = _make_handler("signals", Path(tmp))
|
||||
msg = AsyncMock()
|
||||
msg.subject = "signals.task.synced"
|
||||
msg.data = json.dumps(TASK_SYNCED).encode()
|
||||
await handler(msg)
|
||||
msg.ack.assert_awaited_once()
|
||||
msg.nak.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_make_handler_naks_on_validation_error(self):
|
||||
from nats_consumer import _make_handler
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
handler = _make_handler("signals", Path(tmp))
|
||||
msg = AsyncMock()
|
||||
msg.subject = "signals.task.synced"
|
||||
msg.data = json.dumps({"userId": "u1"}).encode() # missing fields
|
||||
await handler(msg)
|
||||
msg.nak.assert_awaited_once()
|
||||
msg.ack.assert_not_awaited()
|
||||
Reference in New Issue
Block a user