diff --git a/ml/serving/tests/test_schemas_and_consumer.py b/ml/serving/tests/test_schemas_and_consumer.py new file mode 100644 index 0000000..3a06415 --- /dev/null +++ b/ml/serving/tests/test_schemas_and_consumer.py @@ -0,0 +1,169 @@ +""" +Tests for schemas.py and nats_consumer._handle. +""" +import json +import pytest +import tempfile +from pathlib import Path +from pydantic import ValidationError +from unittest.mock import AsyncMock + +from schemas import ( + TaskSyncedPayload, + TipServedPayload, + TipFeedbackPayload, + TipRewardFailedPayload, + IntegrationTokenExpiredPayload, +) +from nats_consumer import _handle, _sync_meta_path + + +# ── Schema validation ───────────────────────────────────────────────────────── + +class TestTaskSyncedPayload: + def test_valid(self): + p = TaskSyncedPayload.model_validate( + {"userId": "u1", "source": "todoist", "count": 5, "syncedAt": "2026-04-25T10:00:00Z"} + ) + assert p.userId == "u1" + assert p.count == 5 + + def test_missing_field_raises(self): + with pytest.raises(ValidationError): + TaskSyncedPayload.model_validate({"userId": "u1", "source": "todoist"}) + + def test_wrong_type_raises(self): + with pytest.raises(ValidationError): + TaskSyncedPayload.model_validate( + {"userId": "u1", "source": "todoist", "count": "not-an-int", "syncedAt": "2026-04-25T10:00:00Z"} + ) + + +class TestTipFeedbackPayload: + def test_valid_without_dwell(self): + p = TipFeedbackPayload.model_validate( + {"userId": "u1", "tipId": "t1", "action": "done", "reward": 1.0, "createdAt": "2026-04-25T10:00:00Z"} + ) + assert p.dwellMs is None + + def test_valid_with_dwell(self): + p = TipFeedbackPayload.model_validate( + {"userId": "u1", "tipId": "t1", "action": "helpful", "reward": 0.5, + "dwellMs": 3200, "createdAt": "2026-04-25T10:00:00Z"} + ) + assert p.dwellMs == 3200 + + def test_invalid_action_raises(self): + with pytest.raises(ValidationError): + TipFeedbackPayload.model_validate( + {"userId": "u1", "tipId": "t1", "action": "like", "reward": 1.0, "createdAt": "2026-04-25T10:00:00Z"} + ) + + def test_all_valid_actions(self): + for action in ("done", "dismiss", "snooze", "helpful", "not_helpful"): + p = TipFeedbackPayload.model_validate( + {"userId": "u1", "tipId": "t1", "action": action, "reward": 0.0, "createdAt": "2026-04-25T10:00:00Z"} + ) + assert p.action == action + + +class TestOtherPayloads: + def test_tip_served(self): + p = TipServedPayload.model_validate( + {"userId": "u1", "tipId": "t1", "policy": "egreedy-v2", "servedAt": "2026-04-25T10:00:00Z"} + ) + assert p.policy == "egreedy-v2" + + def test_tip_reward_failed(self): + p = TipRewardFailedPayload.model_validate( + {"userId": "u1", "tipId": "t1", "reward": 1.0, "attempts": 3, + "error": "timeout", "failedAt": "2026-04-25T10:00:00Z"} + ) + assert p.attempts == 3 + + def test_integration_token_expired(self): + p = IntegrationTokenExpiredPayload.model_validate( + {"userId": "u1", "provider": "todoist", "detectedAt": "2026-04-25T10:00:00Z"} + ) + assert p.provider == "todoist" + + +# ── _handle behaviour ───────────────────────────────────────────────────────── + +TASK_SYNCED = { + "userId": "user-abc", + "source": "todoist", + "count": 7, + "syncedAt": "2026-04-25T10:00:00Z", +} + +TIP_FEEDBACK = { + "userId": "user-abc", + "tipId": "tip-xyz", + "action": "done", + "reward": 1.0, + "dwellMs": 4200, + "createdAt": "2026-04-25T10:00:00Z", +} + + +class TestHandle: + @pytest.mark.asyncio + async def test_task_synced_writes_meta_file(self): + with tempfile.TemporaryDirectory() as tmp: + state_dir = Path(tmp) + await _handle("signals.task.synced", TASK_SYNCED, state_dir) + meta_path = _sync_meta_path(state_dir, "user-abc") + assert meta_path.exists() + data = json.loads(meta_path.read_text()) + assert data["task_count"] == 7 + assert data["last_sync_ts"] == "2026-04-25T10:00:00Z" + + @pytest.mark.asyncio + async def test_task_synced_bad_payload_raises(self): + with tempfile.TemporaryDirectory() as tmp: + with pytest.raises(ValidationError): + await _handle("signals.task.synced", {"userId": "u1"}, Path(tmp)) + + @pytest.mark.asyncio + async def test_tip_feedback_valid_does_not_raise(self): + with tempfile.TemporaryDirectory() as tmp: + # should log and return cleanly + await _handle("signals.tip.feedback", TIP_FEEDBACK, Path(tmp)) + + @pytest.mark.asyncio + async def test_tip_feedback_bad_action_raises(self): + bad = {**TIP_FEEDBACK, "action": "unknown"} + with tempfile.TemporaryDirectory() as tmp: + with pytest.raises(ValidationError): + await _handle("signals.tip.feedback", bad, Path(tmp)) + + @pytest.mark.asyncio + async def test_unhandled_subject_is_ignored(self): + with tempfile.TemporaryDirectory() as tmp: + # should not raise for unknown subjects + await _handle("signals.something.new", {"any": "data"}, Path(tmp)) + + @pytest.mark.asyncio + async def test_make_handler_acks_on_success(self): + from nats_consumer import _make_handler + with tempfile.TemporaryDirectory() as tmp: + handler = _make_handler("signals", Path(tmp)) + msg = AsyncMock() + msg.subject = "signals.task.synced" + msg.data = json.dumps(TASK_SYNCED).encode() + await handler(msg) + msg.ack.assert_awaited_once() + msg.nak.assert_not_awaited() + + @pytest.mark.asyncio + async def test_make_handler_naks_on_validation_error(self): + from nats_consumer import _make_handler + with tempfile.TemporaryDirectory() as tmp: + handler = _make_handler("signals", Path(tmp)) + msg = AsyncMock() + msg.subject = "signals.task.synced" + msg.data = json.dumps({"userId": "u1"}).encode() # missing fields + await handler(msg) + msg.nak.assert_awaited_once() + msg.ack.assert_not_awaited()