From c0cb4b7499ff9a7e3b426b4d420b1860bc429c8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EB=A8=B8=EB=8B=88=ED=8E=98=EB=8B=88?= Date: Fri, 20 Mar 2026 18:44:22 +0900 Subject: [PATCH] feat: integrate CostGuard and TaskHistory into Dispatcher - Add cost_guard and task_history optional parameters to Dispatcher.__init__ - Check daily cost limit before dequeuing tasks - Record usage with CostGuard after successful task completion - Record task history (success and failure) with TaskHistory - Maintain backward compatibility (cost_guard=None, task_history=None) - Add tests for cost recording and daily limit blocking --- agent/dispatcher.py | 67 +++++++++++++++++++++++++++++ tests/test_dispatcher_cost.py | 80 +++++++++++++++++++++++++++++++++++ 2 files changed, 147 insertions(+) create mode 100644 tests/test_dispatcher_cost.py diff --git a/agent/dispatcher.py b/agent/dispatcher.py index df4b2e3..41a9f1e 100644 --- a/agent/dispatcher.py +++ b/agent/dispatcher.py @@ -8,6 +8,7 @@ from __future__ import annotations import asyncio import logging import os +from datetime import datetime, timezone from typing import Any from agent.task_queue import PersistentTaskQueue @@ -22,9 +23,13 @@ class Dispatcher: self, task_queue: PersistentTaskQueue, poll_interval: float = 2.0, + cost_guard: "CostGuard | None" = None, + task_history: "TaskHistory | None" = None, ): self._queue = task_queue self._poll_interval = poll_interval + self._cost_guard = cost_guard + self._task_history = task_history self._running = False self._task: asyncio.Task | None = None @@ -56,21 +61,83 @@ class Dispatcher: async def _poll_once(self) -> None: """큐에서 작업을 하나 꺼내 처리한다.""" + # Check daily limit before dequeuing + if self._cost_guard: + if not await self._cost_guard.check_daily_limit(): + logger.warning("Daily cost limit exceeded, skipping task dequeue") + return + task = await self._queue.dequeue() if not task: return logger.info("Processing task %s (thread %s)", task["id"], task["thread_id"]) + start_time = datetime.now(timezone.utc) try: result = await self._run_agent_for_task(task) await self._queue.mark_completed(task["id"], result=result) logger.info("Task %s completed successfully", task["id"]) + + # Record cost and history after successful completion + tokens_input = result.get("tokens_input", 0) + tokens_output = result.get("tokens_output", 0) + + if self._cost_guard: + await self._cost_guard.record_usage( + task["id"], + tokens_input=tokens_input, + tokens_output=tokens_output, + ) + + if self._task_history: + end_time = datetime.now(timezone.utc) + duration_seconds = (end_time - start_time).total_seconds() + cost_usd = self._cost_guard.calculate_cost(tokens_input, tokens_output) if self._cost_guard else 0.0 + payload = task["payload"] + + await self._task_history.record( + task_id=task["id"], + thread_id=task["thread_id"], + issue_number=payload.get("issue_number", 0), + repo_name=payload.get("repo_name", ""), + source=task["source"], + status="completed", + created_at=task["created_at"], + completed_at=end_time.isoformat(), + duration_seconds=duration_seconds, + tokens_input=tokens_input, + tokens_output=tokens_output, + cost_usd=cost_usd, + ) + except Exception as e: logger.exception("Task %s failed", task["id"]) await self._queue.mark_failed(task["id"], error=str(e)) await self._notify_failure(task, str(e)) + # Record history after failure + if self._task_history: + end_time = datetime.now(timezone.utc) + duration_seconds = (end_time - start_time).total_seconds() + payload = task["payload"] + + await self._task_history.record( + task_id=task["id"], + thread_id=task["thread_id"], + issue_number=payload.get("issue_number", 0), + repo_name=payload.get("repo_name", ""), + source=task["source"], + status="failed", + created_at=task["created_at"], + completed_at=end_time.isoformat(), + duration_seconds=duration_seconds, + tokens_input=0, + tokens_output=0, + cost_usd=0.0, + error_message=str(e), + ) + async def _run_agent_for_task(self, task: dict) -> dict[str, Any]: """작업에 대해 에이전트를 실행한다.""" from agent.server import get_agent diff --git a/tests/test_dispatcher_cost.py b/tests/test_dispatcher_cost.py new file mode 100644 index 0000000..5b8bec6 --- /dev/null +++ b/tests/test_dispatcher_cost.py @@ -0,0 +1,80 @@ +import pytest +import os +import tempfile +from unittest.mock import AsyncMock, patch + +from agent.task_queue import PersistentTaskQueue +from agent.cost_guard import CostGuard +from agent.task_history import TaskHistory +from agent.dispatcher import Dispatcher + + +@pytest.fixture +async def resources(): + paths = [] + for _ in range(3): + fd, p = tempfile.mkstemp(suffix=".db") + os.close(fd) + paths.append(p) + + queue = PersistentTaskQueue(db_path=paths[0]) + await queue.initialize() + guard = CostGuard(db_path=paths[1], daily_limit=10.0, per_task_limit=3.0) + await guard.initialize() + history = TaskHistory(db_path=paths[2]) + await history.initialize() + + yield queue, guard, history + + await queue.close() + await guard.close() + await history.close() + for p in paths: + os.unlink(p) + + +@pytest.mark.asyncio +async def test_dispatcher_records_cost(resources): + queue, guard, history = resources + + await queue.enqueue("thread-1", "gitea", { + "issue_number": 42, "repo_owner": "quant", + "repo_name": "galaxis-po", "message": "Fix", + }) + + mock_run = AsyncMock(return_value={ + "status": "completed", "tokens_input": 5000, "tokens_output": 2000, + }) + + dispatcher = Dispatcher(task_queue=queue, cost_guard=guard, task_history=history) + dispatcher._run_agent_for_task = mock_run + + await dispatcher._poll_once() + + daily = await guard.get_daily_cost() + assert daily > 0 + + records = await history.get_recent() + assert len(records) == 1 + assert records[0]["status"] == "completed" + + +@pytest.mark.asyncio +async def test_dispatcher_blocks_when_daily_limit_exceeded(resources): + queue, guard, history = resources + + for i in range(5): + await guard.record_usage(f"prev-{i}", tokens_input=1_000_000, tokens_output=200_000) + + await queue.enqueue("thread-1", "gitea", {"message": "Should be blocked"}) + + mock_run = AsyncMock() + dispatcher = Dispatcher(task_queue=queue, cost_guard=guard, task_history=history) + dispatcher._run_agent_for_task = mock_run + + await dispatcher._poll_once() + + mock_run.assert_not_called() + + pending = await queue.get_pending() + assert len(pending) == 1