import contextlib
import os
import tempfile
from unittest.mock import AsyncMock, patch

import pytest

from agent.task_queue import PersistentTaskQueue
from agent.cost_guard import CostGuard
from agent.task_history import TaskHistory
from agent.dispatcher import Dispatcher


@pytest.fixture
async def resources():
    """Yield ``(queue, guard, history)`` backed by throwaway SQLite files.

    Each store gets its own temporary ``.db`` file. Cleanup is registered on an
    ``AsyncExitStack`` as each resource is acquired. As a result, every store
    whose ``initialize()`` completed is closed, and every temp file is deleted.
    This holds even when setup fails partway or one of the ``close()`` calls
    raises; the original code leaked files in both cases.
    """
    # NOTE(review): in pytest-asyncio "strict" mode, async fixtures must be
    # declared with @pytest_asyncio.fixture. A plain @pytest.fixture relies on
    # asyncio_mode=auto being configured for the project. Confirm.
    async with contextlib.AsyncExitStack() as stack:
        paths = []
        for _ in range(3):
            fd, path = tempfile.mkstemp(suffix=".db")
            # Only the path is needed; each store opens the file itself.
            os.close(fd)
            # Registered before any store, so the unlink runs after the
            # stores are closed (exit-stack callbacks run LIFO).
            stack.callback(os.unlink, path)
            paths.append(path)

        queue = PersistentTaskQueue(db_path=paths[0])
        await queue.initialize()
        stack.push_async_callback(queue.close)

        guard = CostGuard(db_path=paths[1], daily_limit=10.0, per_task_limit=3.0)
        await guard.initialize()
        stack.push_async_callback(guard.close)

        history = TaskHistory(db_path=paths[2])
        await history.initialize()
        stack.push_async_callback(history.close)

        yield queue, guard, history


@pytest.mark.asyncio
async def test_dispatcher_records_cost(resources):
    """A completed task's token usage is charged to the guard and logged to history."""
    queue, guard, history = resources
    await queue.enqueue("thread-1", "gitea", {
        "issue_number": 42,
        "repo_owner": "quant",
        "repo_name": "galaxis-po",
        "message": "Fix",
    })

    # Stub out the real agent run; the dispatcher should record the reported usage.
    mock_run = AsyncMock(return_value={
        "status": "completed",
        "tokens_input": 5000,
        "tokens_output": 2000,
    })
    dispatcher = Dispatcher(task_queue=queue, cost_guard=guard, task_history=history)
    dispatcher._run_agent_for_task = mock_run

    await dispatcher._poll_once()

    daily = await guard.get_daily_cost()
    assert daily > 0

    records = await history.get_recent()
    assert len(records) == 1
    assert records[0]["status"] == "completed"


@pytest.mark.asyncio
async def test_dispatcher_blocks_when_daily_limit_exceeded(resources):
    """When the daily budget is exhausted, the dispatcher leaves the task queued."""
    queue, guard, history = resources

    # Seed prior spend. The test assumes five runs of 1M input / 200k output
    # tokens exceed daily_limit=10.0 under CostGuard's pricing (pricing is not
    # visible here; confirm if the rates change).
    for i in range(5):
        await guard.record_usage(f"prev-{i}", tokens_input=1_000_000, tokens_output=200_000)

    await queue.enqueue("thread-1", "gitea", {"message": "Should be blocked"})

    mock_run = AsyncMock()
    dispatcher = Dispatcher(task_queue=queue, cost_guard=guard, task_history=history)
    dispatcher._run_agent_for_task = mock_run

    await dispatcher._poll_once()

    # The agent must never run, and the task must still be pending for a later poll.
    mock_run.assert_not_called()
    pending = await queue.get_pending()
    assert len(pending) == 1