- Add optional `cost_guard` and `task_history` parameters to `Dispatcher.__init__`
- Check the daily cost limit before dequeuing tasks
- Record usage with `CostGuard` after a task completes successfully
- Record task history (success and failure) with `TaskHistory`
- Maintain backward compatibility (`cost_guard=None`, `task_history=None`)
- Add tests for cost recording and daily-limit blocking
81 lines
2.2 KiB
Python
81 lines
2.2 KiB
Python
import os
import tempfile
from contextlib import AsyncExitStack
from unittest.mock import AsyncMock, patch

import pytest

from agent.cost_guard import CostGuard
from agent.dispatcher import Dispatcher
from agent.task_history import TaskHistory
from agent.task_queue import PersistentTaskQueue


@pytest.fixture
async def resources(tmp_path):
    """Provide an initialized task queue, cost guard, and task history.

    Each component gets its own ``.db`` file inside pytest's per-test
    ``tmp_path`` directory. Pytest owns that directory, so any extra files a
    backend writes next to its database are managed along with it. No manual
    unlinking is needed.

    The cost guard is configured with ``daily_limit=10.0`` and
    ``per_task_limit=3.0``; the tests in this module rely on those values.

    Yields:
        tuple: ``(queue, guard, history)``, all already initialized.
    """
    # NOTE(review): with pytest-asyncio in strict mode an async fixture must be
    # declared with @pytest_asyncio.fixture; plain @pytest.fixture relies on
    # asyncio_mode = "auto" — confirm the project configuration.
    async with AsyncExitStack() as stack:
        # Each close is registered only after its initialize succeeds. A failing
        # setup step therefore still closes everything opened before it. On
        # teardown, components close in reverse order of creation, and an error
        # in one close does not skip the others.
        queue = PersistentTaskQueue(db_path=str(tmp_path / "task_queue.db"))
        await queue.initialize()
        stack.push_async_callback(queue.close)

        guard = CostGuard(
            db_path=str(tmp_path / "cost_guard.db"),
            daily_limit=10.0,
            per_task_limit=3.0,
        )
        await guard.initialize()
        stack.push_async_callback(guard.close)

        history = TaskHistory(db_path=str(tmp_path / "task_history.db"))
        await history.initialize()
        stack.push_async_callback(history.close)

        yield queue, guard, history


@pytest.mark.asyncio
async def test_dispatcher_records_cost(resources):
    """A completed task's token usage is charged and its outcome is logged.

    The agent run is replaced by a mock that returns a "completed" result with
    token counts. One poll cycle must hand the queued task to that mock exactly
    once. It must also record a non-zero daily cost in the ``CostGuard`` and
    write a single "completed" entry to ``TaskHistory``.
    """
    queue, guard, history = resources

    await queue.enqueue("thread-1", "gitea", {
        "issue_number": 42,
        "repo_owner": "quant",
        "repo_name": "galaxis-po",
        "message": "Fix",
    })

    mock_run = AsyncMock(return_value={
        "status": "completed",
        "tokens_input": 5000,
        "tokens_output": 2000,
    })
    dispatcher = Dispatcher(task_queue=queue, cost_guard=guard, task_history=history)
    # Replace the real agent run with the mock so its result is fully controlled.
    dispatcher._run_agent_for_task = mock_run

    await dispatcher._poll_once()

    # The queued task must actually have been handed to the agent, exactly once;
    # otherwise the cost/history checks below would be testing nothing.
    mock_run.assert_awaited_once()

    daily = await guard.get_daily_cost()
    assert daily > 0

    records = await history.get_recent()
    assert len(records) == 1
    assert records[0]["status"] == "completed"


@pytest.mark.asyncio
async def test_dispatcher_blocks_when_daily_limit_exceeded(resources):
    """No task is run once today's recorded cost is over the daily limit.

    Usage is pre-recorded directly on the ``CostGuard`` until the fixture's
    ``daily_limit`` (10.0) is exceeded. A poll cycle must then leave the newly
    enqueued task pending without invoking the agent.
    """
    queue, guard, history = resources

    for i in range(5):
        await guard.record_usage(f"prev-{i}", tokens_input=1_000_000, tokens_output=200_000)
    # The setup only works if CostGuard's pricing turns the usage above into
    # more than the fixture's daily_limit of 10.0. Check that explicitly, so a
    # pricing change fails here with a clear cause rather than in the
    # dispatcher assertions below.
    assert await guard.get_daily_cost() > 10.0

    await queue.enqueue("thread-1", "gitea", {"message": "Should be blocked"})

    mock_run = AsyncMock()
    dispatcher = Dispatcher(task_queue=queue, cost_guard=guard, task_history=history)
    dispatcher._run_agent_for_task = mock_run

    await dispatcher._poll_once()

    # The agent must not run, and the task must remain pending in the queue.
    mock_run.assert_not_called()

    pending = await queue.get_pending()
    assert len(pending) == 1