"""
Tests for schemas.py and nats_consumer._handle / nats_consumer._make_handler.
"""

import json
import tempfile
from pathlib import Path
from unittest.mock import AsyncMock

import pytest
from pydantic import ValidationError

from schemas import (
    TaskSyncedPayload,
    TipServedPayload,
    TipFeedbackPayload,
    TipRewardFailedPayload,
    IntegrationTokenExpiredPayload,
)
from nats_consumer import _handle, _make_handler, _sync_meta_path

# Single ISO-8601 timestamp reused by every payload below.
TS = "2026-04-25T10:00:00Z"

# Action values these tests expect TipFeedbackPayload to accept.
VALID_ACTIONS = ("done", "dismiss", "snooze", "helpful", "not_helpful")


# ── Schema validation ─────────────────────────────────────────────────────────


class TestTaskSyncedPayload:
    """Validation rules of TaskSyncedPayload."""

    def test_valid(self):
        p = TaskSyncedPayload.model_validate(
            {"userId": "u1", "source": "todoist", "count": 5, "syncedAt": TS}
        )
        assert p.userId == "u1"
        assert p.count == 5

    def test_missing_field_raises(self):
        # "count" and "syncedAt" are omitted.
        with pytest.raises(ValidationError):
            TaskSyncedPayload.model_validate({"userId": "u1", "source": "todoist"})

    def test_wrong_type_raises(self):
        with pytest.raises(ValidationError):
            TaskSyncedPayload.model_validate(
                {
                    "userId": "u1",
                    "source": "todoist",
                    "count": "not-an-int",
                    "syncedAt": TS,
                }
            )


class TestTipFeedbackPayload:
    """Validation rules of TipFeedbackPayload, including the action enum."""

    def test_valid_without_dwell(self):
        p = TipFeedbackPayload.model_validate(
            {"userId": "u1", "tipId": "t1", "action": "done", "reward": 1.0, "createdAt": TS}
        )
        # dwellMs is optional and comes back as None when absent.
        assert p.dwellMs is None

    def test_valid_with_dwell(self):
        p = TipFeedbackPayload.model_validate(
            {
                "userId": "u1",
                "tipId": "t1",
                "action": "helpful",
                "reward": 0.5,
                "dwellMs": 3200,
                "createdAt": TS,
            }
        )
        assert p.dwellMs == 3200

    def test_invalid_action_raises(self):
        with pytest.raises(ValidationError):
            TipFeedbackPayload.model_validate(
                {"userId": "u1", "tipId": "t1", "action": "like", "reward": 1.0, "createdAt": TS}
            )

    # Parametrized so each action is reported as its own case and one
    # failing action does not mask the others.
    @pytest.mark.parametrize("action", VALID_ACTIONS)
    def test_all_valid_actions(self, action):
        p = TipFeedbackPayload.model_validate(
            {"userId": "u1", "tipId": "t1", "action": action, "reward": 0.0, "createdAt": TS}
        )
        assert p.action == action


class TestOtherPayloads:
    """Happy-path validation of the remaining payload models."""

    def test_tip_served(self):
        p = TipServedPayload.model_validate(
            {"userId": "u1", "tipId": "t1", "policy": "egreedy-v2", "servedAt": TS}
        )
        assert p.policy == "egreedy-v2"

    def test_tip_reward_failed(self):
        p = TipRewardFailedPayload.model_validate(
            {
                "userId": "u1",
                "tipId": "t1",
                "reward": 1.0,
                "attempts": 3,
                "error": "timeout",
                "failedAt": TS,
            }
        )
        assert p.attempts == 3

    def test_integration_token_expired(self):
        p = IntegrationTokenExpiredPayload.model_validate(
            {"userId": "u1", "provider": "todoist", "detectedAt": TS}
        )
        assert p.provider == "todoist"


# ── _handle behaviour ─────────────────────────────────────────────────────────

TASK_SYNCED = {
    "userId": "user-abc",
    "source": "todoist",
    "count": 7,
    "syncedAt": TS,
}

TIP_FEEDBACK = {
    "userId": "user-abc",
    "tipId": "tip-xyz",
    "action": "done",
    "reward": 1.0,
    "dwellMs": 4200,
    "createdAt": TS,
}


def _make_msg(subject: str, payload: dict) -> AsyncMock:
    """Return a mock message with *subject* and *payload* JSON-encoded as bytes.

    Child attributes such as ``ack``/``nak`` are AsyncMocks, so tests can
    assert whether the handler awaited them.
    """
    msg = AsyncMock()
    msg.subject = subject
    msg.data = json.dumps(payload).encode()
    return msg


class TestHandle:
    """Dispatch behaviour of _handle and ack/nak behaviour of _make_handler.

    Each test gets an isolated state directory from pytest's ``tmp_path``.
    """

    @pytest.mark.asyncio
    async def test_task_synced_writes_meta_file(self, tmp_path: Path):
        await _handle("signals.task.synced", TASK_SYNCED, tmp_path)

        meta_path = _sync_meta_path(tmp_path, "user-abc")
        assert meta_path.exists()
        data = json.loads(meta_path.read_text())
        assert data["task_count"] == 7
        assert data["last_sync_ts"] == TS

    @pytest.mark.asyncio
    async def test_task_synced_bad_payload_raises(self, tmp_path: Path):
        with pytest.raises(ValidationError):
            await _handle("signals.task.synced", {"userId": "u1"}, tmp_path)

    @pytest.mark.asyncio
    async def test_tip_feedback_valid_does_not_raise(self, tmp_path: Path):
        # A valid payload should be accepted and return cleanly.
        await _handle("signals.tip.feedback", TIP_FEEDBACK, tmp_path)

    @pytest.mark.asyncio
    async def test_tip_feedback_bad_action_raises(self, tmp_path: Path):
        bad = {**TIP_FEEDBACK, "action": "unknown"}
        with pytest.raises(ValidationError):
            await _handle("signals.tip.feedback", bad, tmp_path)

    @pytest.mark.asyncio
    async def test_unhandled_subject_is_ignored(self, tmp_path: Path):
        # Unknown subjects must not raise.
        await _handle("signals.something.new", {"any": "data"}, tmp_path)

    @pytest.mark.asyncio
    async def test_make_handler_acks_on_success(self, tmp_path: Path):
        handler = _make_handler("signals", tmp_path)
        msg = _make_msg("signals.task.synced", TASK_SYNCED)

        await handler(msg)

        msg.ack.assert_awaited_once()
        msg.nak.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_make_handler_naks_on_validation_error(self, tmp_path: Path):
        handler = _make_handler("signals", tmp_path)
        # Missing "source", "count" and "syncedAt".
        msg = _make_msg("signals.task.synced", {"userId": "u1"})

        await handler(msg)

        msg.nak.assert_awaited_once()
        msg.ack.assert_not_awaited()