feat: integrate CostGuard and TaskHistory into Dispatcher

- Add cost_guard and task_history optional parameters to Dispatcher.__init__
- Check daily cost limit before dequeuing tasks
- Record usage with CostGuard after successful task completion
- Record task history (success and failure) with TaskHistory
- Maintain backward compatibility (cost_guard=None, task_history=None)
- Add tests for cost recording and daily limit blocking
This commit is contained in:
머니페니 2026-03-20 18:44:22 +09:00
parent e82dfe18f9
commit c0cb4b7499
2 changed files with 147 additions and 0 deletions

View File

@ -8,6 +8,7 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import os import os
from datetime import datetime, timezone
from typing import Any from typing import Any
from agent.task_queue import PersistentTaskQueue from agent.task_queue import PersistentTaskQueue
@ -22,9 +23,13 @@ class Dispatcher:
self, self,
task_queue: PersistentTaskQueue, task_queue: PersistentTaskQueue,
poll_interval: float = 2.0, poll_interval: float = 2.0,
cost_guard: "CostGuard | None" = None,
task_history: "TaskHistory | None" = None,
): ):
self._queue = task_queue self._queue = task_queue
self._poll_interval = poll_interval self._poll_interval = poll_interval
self._cost_guard = cost_guard
self._task_history = task_history
self._running = False self._running = False
self._task: asyncio.Task | None = None self._task: asyncio.Task | None = None
@ -56,21 +61,83 @@ class Dispatcher:
async def _poll_once(self) -> None: async def _poll_once(self) -> None:
"""큐에서 작업을 하나 꺼내 처리한다.""" """큐에서 작업을 하나 꺼내 처리한다."""
# Check daily limit before dequeuing
if self._cost_guard:
if not await self._cost_guard.check_daily_limit():
logger.warning("Daily cost limit exceeded, skipping task dequeue")
return
task = await self._queue.dequeue() task = await self._queue.dequeue()
if not task: if not task:
return return
logger.info("Processing task %s (thread %s)", task["id"], task["thread_id"]) logger.info("Processing task %s (thread %s)", task["id"], task["thread_id"])
start_time = datetime.now(timezone.utc)
try: try:
result = await self._run_agent_for_task(task) result = await self._run_agent_for_task(task)
await self._queue.mark_completed(task["id"], result=result) await self._queue.mark_completed(task["id"], result=result)
logger.info("Task %s completed successfully", task["id"]) logger.info("Task %s completed successfully", task["id"])
# Record cost and history after successful completion
tokens_input = result.get("tokens_input", 0)
tokens_output = result.get("tokens_output", 0)
if self._cost_guard:
await self._cost_guard.record_usage(
task["id"],
tokens_input=tokens_input,
tokens_output=tokens_output,
)
if self._task_history:
end_time = datetime.now(timezone.utc)
duration_seconds = (end_time - start_time).total_seconds()
cost_usd = self._cost_guard.calculate_cost(tokens_input, tokens_output) if self._cost_guard else 0.0
payload = task["payload"]
await self._task_history.record(
task_id=task["id"],
thread_id=task["thread_id"],
issue_number=payload.get("issue_number", 0),
repo_name=payload.get("repo_name", ""),
source=task["source"],
status="completed",
created_at=task["created_at"],
completed_at=end_time.isoformat(),
duration_seconds=duration_seconds,
tokens_input=tokens_input,
tokens_output=tokens_output,
cost_usd=cost_usd,
)
except Exception as e: except Exception as e:
logger.exception("Task %s failed", task["id"]) logger.exception("Task %s failed", task["id"])
await self._queue.mark_failed(task["id"], error=str(e)) await self._queue.mark_failed(task["id"], error=str(e))
await self._notify_failure(task, str(e)) await self._notify_failure(task, str(e))
# Record history after failure
if self._task_history:
end_time = datetime.now(timezone.utc)
duration_seconds = (end_time - start_time).total_seconds()
payload = task["payload"]
await self._task_history.record(
task_id=task["id"],
thread_id=task["thread_id"],
issue_number=payload.get("issue_number", 0),
repo_name=payload.get("repo_name", ""),
source=task["source"],
status="failed",
created_at=task["created_at"],
completed_at=end_time.isoformat(),
duration_seconds=duration_seconds,
tokens_input=0,
tokens_output=0,
cost_usd=0.0,
error_message=str(e),
)
async def _run_agent_for_task(self, task: dict) -> dict[str, Any]: async def _run_agent_for_task(self, task: dict) -> dict[str, Any]:
"""작업에 대해 에이전트를 실행한다.""" """작업에 대해 에이전트를 실행한다."""
from agent.server import get_agent from agent.server import get_agent

View File

@ -0,0 +1,80 @@
import pytest
import os
import tempfile
from unittest.mock import AsyncMock, patch
from agent.task_queue import PersistentTaskQueue
from agent.cost_guard import CostGuard
from agent.task_history import TaskHistory
from agent.dispatcher import Dispatcher
@pytest.fixture
async def resources():
paths = []
for _ in range(3):
fd, p = tempfile.mkstemp(suffix=".db")
os.close(fd)
paths.append(p)
queue = PersistentTaskQueue(db_path=paths[0])
await queue.initialize()
guard = CostGuard(db_path=paths[1], daily_limit=10.0, per_task_limit=3.0)
await guard.initialize()
history = TaskHistory(db_path=paths[2])
await history.initialize()
yield queue, guard, history
await queue.close()
await guard.close()
await history.close()
for p in paths:
os.unlink(p)
@pytest.mark.asyncio
async def test_dispatcher_records_cost(resources):
queue, guard, history = resources
await queue.enqueue("thread-1", "gitea", {
"issue_number": 42, "repo_owner": "quant",
"repo_name": "galaxis-po", "message": "Fix",
})
mock_run = AsyncMock(return_value={
"status": "completed", "tokens_input": 5000, "tokens_output": 2000,
})
dispatcher = Dispatcher(task_queue=queue, cost_guard=guard, task_history=history)
dispatcher._run_agent_for_task = mock_run
await dispatcher._poll_once()
daily = await guard.get_daily_cost()
assert daily > 0
records = await history.get_recent()
assert len(records) == 1
assert records[0]["status"] == "completed"
@pytest.mark.asyncio
async def test_dispatcher_blocks_when_daily_limit_exceeded(resources):
queue, guard, history = resources
for i in range(5):
await guard.record_usage(f"prev-{i}", tokens_input=1_000_000, tokens_output=200_000)
await queue.enqueue("thread-1", "gitea", {"message": "Should be blocked"})
mock_run = AsyncMock()
dispatcher = Dispatcher(task_queue=queue, cost_guard=guard, task_history=history)
dispatcher._run_agent_for_task = mock_run
await dispatcher._poll_once()
mock_run.assert_not_called()
pending = await queue.get_pending()
assert len(pending) == 1