"""Unit tests for vram_manager.py — VRAMManager flush/poll/prewarm logic."""

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from vram_manager import VRAMManager

# Every test in this module is an ``async def``. Without an asyncio marker,
# pytest-asyncio in its default "strict" mode does not execute these
# coroutines, so they would pass/skip vacuously. The marker is a harmless
# no-op when the project configures pytest-asyncio in "auto" mode.
pytestmark = pytest.mark.asyncio

BASE_URL = "http://localhost:11434"


def _make_manager() -> VRAMManager:
    """Build a VRAMManager pointed at the test ``BASE_URL``."""
    return VRAMManager(base_url=BASE_URL)


def _mock_client(
    get_response=None,
    post_response=None,
    *,
    get_side_effect=None,
    post_side_effect=None,
):
    """Return a context-manager mock for httpx.AsyncClient.

    ``async with httpx.AsyncClient(...) as c`` binds ``c`` to the result of
    ``__aenter__``; it returns the mock itself so tests can inspect the
    ``get``/``post`` awaits made by the code under test.

    Args:
        get_response: value returned by ``await client.get(...)`` if given.
        post_response: value returned by ``await client.post(...)`` if given.
        get_side_effect: exception (or other side effect) for ``client.get``.
        post_side_effect: exception (or other side effect) for ``client.post``.
    """
    client = AsyncMock()
    client.__aenter__ = AsyncMock(return_value=client)
    # Falsy return from __aexit__ means exceptions inside the block propagate.
    client.__aexit__ = AsyncMock(return_value=False)
    if get_response is not None:
        client.get = AsyncMock(return_value=get_response)
    if post_response is not None:
        client.post = AsyncMock(return_value=post_response)
    # Attribute children of an AsyncMock are AsyncMocks, so side effects can be
    # attached whether or not a response was configured above.
    if get_side_effect is not None:
        client.get.side_effect = get_side_effect
    if post_side_effect is not None:
        client.post.side_effect = post_side_effect
    return client


def _post_call(client):
    """Return ``(url, kwargs)`` for the single ``client.post`` await.

    Asserts ``post`` was awaited exactly once. The URL is accepted either
    positionally or as ``url=``. httpx's ``AsyncClient.post`` takes ``json``
    keyword-only, so a JSON body is always found in ``kwargs["json"]``.
    """
    client.post.assert_awaited_once()
    args, kwargs = client.post.await_args
    url = args[0] if args else kwargs["url"]
    return url, kwargs


# ── _flush ─────────────────────────────────────────────────────────────────────


class TestFlush:
    """``_flush`` posts ``keep_alive=0`` for a model to ``/api/generate``."""

    async def test_sends_keep_alive_zero(self):
        client = _mock_client(post_response=MagicMock())
        with patch("vram_manager.httpx.AsyncClient", return_value=client):
            mgr = _make_manager()
            await mgr._flush("qwen3:4b")
        _, kwargs = _post_call(client)
        body = kwargs["json"]
        assert body["model"] == "qwen3:4b"
        assert body["keep_alive"] == 0

    async def test_posts_to_correct_endpoint(self):
        client = _mock_client(post_response=MagicMock())
        with patch("vram_manager.httpx.AsyncClient", return_value=client):
            mgr = _make_manager()
            await mgr._flush("qwen3:8b")
        url, _ = _post_call(client)
        assert url == f"{BASE_URL}/api/generate"

    async def test_ignores_exceptions_silently(self):
        client = _mock_client(post_side_effect=Exception("connection refused"))
        with patch("vram_manager.httpx.AsyncClient", return_value=client):
            mgr = _make_manager()
            # Should not raise
            await mgr._flush("qwen3:4b")
        # Guard against a vacuous pass: the failing POST must actually have run.
        client.post.assert_awaited_once()
# ── _prewarm ───────────────────────────────────────────────────────────────────


class TestPrewarm:
    """``_prewarm`` posts a request for a model with ``keep_alive=300``."""

    async def test_sends_keep_alive_300(self):
        client = _mock_client(post_response=MagicMock())
        with patch("vram_manager.httpx.AsyncClient", return_value=client):
            mgr = _make_manager()
            await mgr._prewarm("qwen3:4b")
        client.post.assert_awaited_once()
        # httpx's AsyncClient.post takes ``json`` keyword-only, so the body is
        # always in the call's kwargs.
        _, kwargs = client.post.await_args
        body = kwargs["json"]
        assert body["keep_alive"] == 300
        assert body["model"] == "qwen3:4b"

    async def test_ignores_exceptions_silently(self):
        client = _mock_client()
        client.post = AsyncMock(side_effect=Exception("timeout"))
        with patch("vram_manager.httpx.AsyncClient", return_value=client):
            mgr = _make_manager()
            # Should not raise
            await mgr._prewarm("qwen3:4b")
        # Guard against a vacuous pass: the failing POST must actually have run.
        client.post.assert_awaited_once()


# ── _poll_evicted ──────────────────────────────────────────────────────────────


def _ps_response(*names):
    """Fake /api/ps response whose ``json()`` lists models with the given names."""
    resp = MagicMock()
    resp.json.return_value = {"models": [{"name": name} for name in names]}
    return resp


class TestPollEvicted:
    """``_poll_evicted`` polls until none of the given models are listed, or times out."""

    async def test_returns_true_when_models_absent(self):
        client = _mock_client(get_response=_ps_response("some_other_model"))
        with patch("vram_manager.httpx.AsyncClient", return_value=client):
            mgr = _make_manager()
            result = await mgr._poll_evicted(["qwen3:4b", "qwen2.5:1.5b"], timeout=5)
        assert result is True

    async def test_returns_false_on_timeout_when_model_still_loaded(self):
        client = _mock_client(get_response=_ps_response("qwen3:4b"))
        with patch("vram_manager.httpx.AsyncClient", return_value=client):
            mgr = _make_manager()
            result = await mgr._poll_evicted(["qwen3:4b"], timeout=0.1)
        assert result is False

    async def test_returns_true_immediately_if_already_empty(self):
        client = _mock_client(get_response=_ps_response())
        with patch("vram_manager.httpx.AsyncClient", return_value=client):
            mgr = _make_manager()
            result = await mgr._poll_evicted(["qwen3:4b"], timeout=5)
        assert result is True
        # "Immediately": an empty model list is accepted on the first poll.
        assert client.get.await_count == 1

    async def test_handles_poll_error_and_continues(self):
        """If /api/ps errors, polling continues until timeout."""
        client = _mock_client()
        client.get = AsyncMock(side_effect=Exception("network error"))
        with patch("vram_manager.httpx.AsyncClient", return_value=client):
            mgr = _make_manager()
            result = await mgr._poll_evicted(["qwen3:4b"], timeout=0.2)
        assert result is False
        # The error path must actually have been exercised.
        client.get.assert_awaited()


# ── enter_complex_mode / exit_complex_mode ─────────────────────────────────────


def _manager_with_stubbed_io(evicted=True):
    """Return a manager whose network-facing helpers are all AsyncMocks.

    Stubbing ``_flush``, ``_prewarm`` and ``_poll_evicted`` together keeps the
    mode-switch tests from reaching a real server at ``BASE_URL`` (which on a
    dev machine could actually load/unload models, or wait on real timeouts)
    regardless of which helpers enter/exit_complex_mode happen to call.

    Args:
        evicted: value returned by the stubbed ``_poll_evicted``.
    """
    mgr = _make_manager()
    mgr._flush = AsyncMock()
    mgr._prewarm = AsyncMock()
    mgr._poll_evicted = AsyncMock(return_value=evicted)
    return mgr


class TestComplexMode:
    """Model swapping when entering and leaving complex mode."""

    async def test_enter_complex_mode_returns_true_on_success(self):
        mgr = _manager_with_stubbed_io(evicted=True)
        result = await mgr.enter_complex_mode()
        assert result is True
        mgr._poll_evicted.assert_awaited()

    async def test_enter_complex_mode_flushes_medium_models(self):
        mgr = _manager_with_stubbed_io(evicted=True)
        await mgr.enter_complex_mode()
        flushed = {c.args[0] for c in mgr._flush.call_args_list}
        assert "qwen3:4b" in flushed
        assert "qwen2.5:1.5b" in flushed

    async def test_enter_complex_mode_returns_false_on_eviction_timeout(self):
        mgr = _manager_with_stubbed_io(evicted=False)
        result = await mgr.enter_complex_mode()
        assert result is False

    async def test_exit_complex_mode_flushes_complex_and_prewarms_medium(self):
        mgr = _manager_with_stubbed_io(evicted=True)
        await mgr.exit_complex_mode()
        # Must flush 8b
        flushed = {c.args[0] for c in mgr._flush.call_args_list}
        assert "qwen3:8b" in flushed
        # Must prewarm medium models
        prewarmed = {c.args[0] for c in mgr._prewarm.call_args_list}
        assert "qwen3:4b" in prewarmed
        assert "qwen2.5:1.5b" in prewarmed