feat: integrate CostGuard and TaskHistory into Dispatcher
- Add cost_guard and task_history optional parameters to Dispatcher.__init__
- Check daily cost limit before dequeuing tasks
- Record usage with CostGuard after successful task completion
- Record task history (success and failure) with TaskHistory
- Maintain backward compatibility (cost_guard=None, task_history=None)
- Add tests for cost recording and daily limit blocking
This commit is contained in:
parent
e82dfe18f9
commit
c0cb4b7499
@ -8,6 +8,7 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from agent.task_queue import PersistentTaskQueue
|
from agent.task_queue import PersistentTaskQueue
|
||||||
@ -22,9 +23,13 @@ class Dispatcher:
|
|||||||
self,
|
self,
|
||||||
task_queue: PersistentTaskQueue,
|
task_queue: PersistentTaskQueue,
|
||||||
poll_interval: float = 2.0,
|
poll_interval: float = 2.0,
|
||||||
|
cost_guard: "CostGuard | None" = None,
|
||||||
|
task_history: "TaskHistory | None" = None,
|
||||||
):
|
):
|
||||||
self._queue = task_queue
|
self._queue = task_queue
|
||||||
self._poll_interval = poll_interval
|
self._poll_interval = poll_interval
|
||||||
|
self._cost_guard = cost_guard
|
||||||
|
self._task_history = task_history
|
||||||
self._running = False
|
self._running = False
|
||||||
self._task: asyncio.Task | None = None
|
self._task: asyncio.Task | None = None
|
||||||
|
|
||||||
@ -56,21 +61,83 @@ class Dispatcher:
|
|||||||
|
|
||||||
async def _poll_once(self) -> None:
    """Dequeue at most one task from the queue and process it.

    When a cost guard is configured and its ``check_daily_limit()`` returns
    a falsy value, nothing is dequeued, so pending tasks stay in the queue.

    A task is treated as failed only when running the agent or marking the
    task completed raises.  Cost and history bookkeeping runs *after* the
    task has been marked completed; an error raised there is logged and
    swallowed so that a finished task is never re-marked as failed (and no
    failure notification is sent) because of an accounting problem.
    """
    # Check the daily limit before dequeuing so a blocked task stays pending.
    if self._cost_guard and not await self._cost_guard.check_daily_limit():
        logger.warning("Daily cost limit exceeded, skipping task dequeue")
        return

    task = await self._queue.dequeue()
    if not task:
        return

    logger.info("Processing task %s (thread %s)", task["id"], task["thread_id"])

    start_time = datetime.now(timezone.utc)
    try:
        result = await self._run_agent_for_task(task)
        await self._queue.mark_completed(task["id"], result=result)
        logger.info("Task %s completed successfully", task["id"])
    except Exception as e:
        logger.exception("Task %s failed", task["id"])
        await self._queue.mark_failed(task["id"], error=str(e))
        await self._notify_failure(task, str(e))

        # Record history after failure; no token usage is known here.
        await self._record_history(
            task,
            start_time,
            status="failed",
            error_message=str(e),
        )
    else:
        # Bookkeeping is deliberately outside the ``try`` above: the task is
        # already marked completed, so errors here must not reach the
        # failure handler.
        try:
            # Tolerate a non-dict result or explicit ``None`` token counts.
            usage = result if isinstance(result, dict) else {}
            tokens_input = usage.get("tokens_input") or 0
            tokens_output = usage.get("tokens_output") or 0

            if self._cost_guard:
                await self._cost_guard.record_usage(
                    task["id"],
                    tokens_input=tokens_input,
                    tokens_output=tokens_output,
                )

            if self._task_history:
                cost_usd = (
                    self._cost_guard.calculate_cost(tokens_input, tokens_output)
                    if self._cost_guard
                    else 0.0
                )
                await self._record_history(
                    task,
                    start_time,
                    status="completed",
                    tokens_input=tokens_input,
                    tokens_output=tokens_output,
                    cost_usd=cost_usd,
                )
        except Exception:
            logger.exception(
                "Failed to record cost/history for completed task %s", task["id"]
            )


async def _record_history(
    self,
    task: dict,
    start_time: datetime,
    *,
    status: str,
    tokens_input: int = 0,
    tokens_output: int = 0,
    cost_usd: float = 0.0,
    error_message: str | None = None,
) -> None:
    """Write one task-history record for a finished task.

    Does nothing when no task history store is configured.

    Args:
        task: The dequeued task; ``id``, ``thread_id``, ``source``,
            ``created_at`` and ``payload`` are read from it.
        start_time: Timezone-aware moment processing started; the duration
            is measured from it to the time of this call.
        status: Status string stored with the record.
        tokens_input: Input token count to store.
        tokens_output: Output token count to store.
        cost_usd: Cost to store.
        error_message: Forwarded to the history store only when not None.
    """
    if not self._task_history:
        return

    end_time = datetime.now(timezone.utc)
    payload = task["payload"]
    # Only pass ``error_message`` when there is one, matching the
    # keyword set each call site used before this helper existed.
    extra: dict[str, Any] = (
        {} if error_message is None else {"error_message": error_message}
    )

    await self._task_history.record(
        task_id=task["id"],
        thread_id=task["thread_id"],
        issue_number=payload.get("issue_number", 0),
        repo_name=payload.get("repo_name", ""),
        source=task["source"],
        status=status,
        created_at=task["created_at"],
        completed_at=end_time.isoformat(),
        duration_seconds=(end_time - start_time).total_seconds(),
        tokens_input=tokens_input,
        tokens_output=tokens_output,
        cost_usd=cost_usd,
        **extra,
    )
|
||||||
|
|
||||||
async def _run_agent_for_task(self, task: dict) -> dict[str, Any]:
|
async def _run_agent_for_task(self, task: dict) -> dict[str, Any]:
|
||||||
"""작업에 대해 에이전트를 실행한다."""
|
"""작업에 대해 에이전트를 실행한다."""
|
||||||
from agent.server import get_agent
|
from agent.server import get_agent
|
||||||
|
|||||||
80
tests/test_dispatcher_cost.py
Normal file
80
tests/test_dispatcher_cost.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
import pytest
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
from agent.task_queue import PersistentTaskQueue
|
||||||
|
from agent.cost_guard import CostGuard
|
||||||
|
from agent.task_history import TaskHistory
|
||||||
|
from agent.dispatcher import Dispatcher
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
async def resources():
    """Yield an initialized ``(queue, guard, history)`` triple.

    Each store is backed by its own temporary ``.db`` file.  Teardown closes
    every store that was successfully initialized and removes all temp files,
    even when setup fails part-way or a ``close()`` call raises.
    """
    paths = []
    opened = []
    try:
        for _ in range(3):
            fd, p = tempfile.mkstemp(suffix=".db")
            # Only the path is needed; the stores open the file themselves.
            os.close(fd)
            paths.append(p)

        queue = PersistentTaskQueue(db_path=paths[0])
        await queue.initialize()
        opened.append(queue)

        guard = CostGuard(db_path=paths[1], daily_limit=10.0, per_task_limit=3.0)
        await guard.initialize()
        opened.append(guard)

        history = TaskHistory(db_path=paths[2])
        await history.initialize()
        opened.append(history)

        yield queue, guard, history
    finally:
        try:
            # Close in creation order: queue, guard, history.
            for store in opened:
                await store.close()
        finally:
            for p in paths:
                os.unlink(p)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_dispatcher_records_cost(resources):
    """A successful poll records usage with the guard and one history row."""
    task_queue, cost_guard, task_history = resources

    issue_payload = {
        "issue_number": 42,
        "repo_owner": "quant",
        "repo_name": "galaxis-po",
        "message": "Fix",
    }
    await task_queue.enqueue("thread-1", "gitea", issue_payload)

    # Stand-in for the real agent run: reports a fixed token usage.
    agent_result = {
        "status": "completed",
        "tokens_input": 5000,
        "tokens_output": 2000,
    }
    fake_agent = AsyncMock(return_value=agent_result)

    dispatcher = Dispatcher(
        task_queue=task_queue,
        cost_guard=cost_guard,
        task_history=task_history,
    )
    dispatcher._run_agent_for_task = fake_agent

    await dispatcher._poll_once()

    # The guard must have accumulated a non-zero cost for today.
    assert await cost_guard.get_daily_cost() > 0

    # Exactly one history record, marked completed.
    recent = await task_history.get_recent()
    assert len(recent) == 1
    assert recent[0]["status"] == "completed"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_dispatcher_blocks_when_daily_limit_exceeded(resources):
    """With prior heavy usage recorded, a poll must leave the task pending.

    The five pre-recorded usages are intended to push the guard's daily cost
    past the fixture's limit; the test then checks that the agent is never
    run and that the enqueued task is still pending.
    """
    task_queue, cost_guard, task_history = resources

    # Seed the guard with earlier usage before anything is enqueued.
    prior_task_ids = [f"prev-{i}" for i in range(5)]
    for prior_id in prior_task_ids:
        await cost_guard.record_usage(
            prior_id,
            tokens_input=1_000_000,
            tokens_output=200_000,
        )

    await task_queue.enqueue("thread-1", "gitea", {"message": "Should be blocked"})

    fake_agent = AsyncMock()
    dispatcher = Dispatcher(
        task_queue=task_queue,
        cost_guard=cost_guard,
        task_history=task_history,
    )
    dispatcher._run_agent_for_task = fake_agent

    await dispatcher._poll_once()

    # The agent must not have been invoked at all.
    fake_agent.assert_not_called()

    # The task was never dequeued, so it is still pending.
    still_pending = await task_queue.get_pending()
    assert len(still_pending) == 1
|
||||||
Loading…
x
Reference in New Issue
Block a user