galaxis-agent/tests/test_dispatcher_cost.py
머니페니 c0cb4b7499 feat: integrate CostGuard and TaskHistory into Dispatcher
- Add cost_guard and task_history optional parameters to Dispatcher.__init__
- Check daily cost limit before dequeuing tasks
- Record usage with CostGuard after successful task completion
- Record task history (success and failure) with TaskHistory
- Maintain backward compatibility (cost_guard=None, task_history=None)
- Add tests for cost recording and daily limit blocking
2026-03-20 18:44:22 +09:00

81 lines
2.2 KiB
Python

import os
import tempfile
from contextlib import AsyncExitStack
from unittest.mock import AsyncMock, patch

import pytest

from agent.cost_guard import CostGuard
from agent.dispatcher import Dispatcher
from agent.task_history import TaskHistory
from agent.task_queue import PersistentTaskQueue
@pytest.fixture
async def resources():
    """Provide initialized queue, cost-guard and history stores on temp DB files.

    Yields:
        A ``(queue, guard, history)`` tuple:
        - a ``PersistentTaskQueue``;
        - a ``CostGuard`` configured with ``daily_limit=10.0`` and
          ``per_task_limit=3.0``;
        - a ``TaskHistory``.
        Each is backed by its own temporary ``.db`` file.

    Teardown closes every store that finished ``initialize()`` and deletes all
    temp files. Cleanup is driven by an ``AsyncExitStack``, so it still runs
    when setup fails part-way or when one of the ``close()`` calls raises.
    """
    async with AsyncExitStack() as stack:
        paths = []
        for _ in range(3):
            fd, path = tempfile.mkstemp(suffix=".db")
            # Only the path is needed; each store opens the file itself.
            os.close(fd)
            # Registered before any store is opened. The stack unwinds LIFO,
            # so files are unlinked only after every store has been closed.
            stack.callback(os.unlink, path)
            paths.append(path)

        queue = PersistentTaskQueue(db_path=paths[0])
        await queue.initialize()
        stack.push_async_callback(queue.close)

        guard = CostGuard(db_path=paths[1], daily_limit=10.0, per_task_limit=3.0)
        await guard.initialize()
        stack.push_async_callback(guard.close)

        history = TaskHistory(db_path=paths[2])
        await history.initialize()
        stack.push_async_callback(history.close)

        yield queue, guard, history
@pytest.mark.asyncio
async def test_dispatcher_records_cost(resources):
    """A completed task must be run once, charged to CostGuard and logged in TaskHistory."""
    queue, guard, history = resources
    await queue.enqueue("thread-1", "gitea", {
        "issue_number": 42,
        "repo_owner": "quant",
        "repo_name": "galaxis-po",
        "message": "Fix",
    })
    # Stub out the real agent run. The returned token counts are presumably
    # what the dispatcher bills to CostGuard on completion.
    mock_run = AsyncMock(return_value={
        "status": "completed", "tokens_input": 5000, "tokens_output": 2000,
    })
    dispatcher = Dispatcher(task_queue=queue, cost_guard=guard, task_history=history)
    dispatcher._run_agent_for_task = mock_run

    await dispatcher._poll_once()

    # Counterpart of the blocked-path test's assert_not_called(): here the
    # task must actually have been executed, exactly once.
    mock_run.assert_awaited_once()
    daily = await guard.get_daily_cost()
    assert daily > 0
    records = await history.get_recent()
    assert len(records) == 1
    assert records[0]["status"] == "completed"
@pytest.mark.asyncio
async def test_dispatcher_blocks_when_daily_limit_exceeded(resources):
    """The dispatcher must not start a task once the daily budget is spent.

    Five prior usages are recorded first. The test assumes their combined
    cost is enough to trip CostGuard's daily-limit check; the fixture sets
    ``daily_limit=10.0``. The queued task must then stay pending, and the
    agent runner must never be invoked.
    """
    task_queue, cost_guard, task_history = resources

    # Seed enough historical spend to exhaust the daily budget.
    for n in range(5):
        await cost_guard.record_usage(
            f"prev-{n}",
            tokens_input=1_000_000,
            tokens_output=200_000,
        )
    await task_queue.enqueue("thread-1", "gitea", {"message": "Should be blocked"})

    runner = AsyncMock()
    dispatcher = Dispatcher(
        task_queue=task_queue,
        cost_guard=cost_guard,
        task_history=task_history,
    )
    dispatcher._run_agent_for_task = runner

    await dispatcher._poll_once()

    runner.assert_not_called()
    assert len(await task_queue.get_pending()) == 1