- Add optional `cost_guard` and `task_history` parameters to `Dispatcher.__init__`.
- Check the daily cost limit before dequeuing tasks.
- Record usage with `CostGuard` after a task completes successfully.
- Record task history (both success and failure) with `TaskHistory`.
- Maintain backward compatibility (`cost_guard=None` and `task_history=None` by default).
- Add tests for cost recording and daily-limit blocking.
217 lines
7.7 KiB
Python
217 lines
7.7 KiB
Python
"""Background task dispatcher.

Polls the TaskQueue and runs the agent for each task.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import os
|
|
from datetime import datetime, timezone
|
|
from typing import Any
|
|
|
|
from agent.task_queue import PersistentTaskQueue
|
|
|
|
# Module-level logger; every Dispatcher diagnostic below is emitted through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class Dispatcher:
    """Background task consumer.

    Polls a ``PersistentTaskQueue`` and runs the agent for each dequeued task.
    Optionally consults ``cost_guard`` before every dequeue (daily limit) and
    reports token usage to it after each successful task, and writes one
    ``task_history`` entry per finished task. Both are ``None`` by default, so
    callers that do not pass them get the plain queue-consumer behavior.
    """

    def __init__(
        self,
        task_queue: PersistentTaskQueue,
        poll_interval: float = 2.0,
        # NOTE(review): CostGuard / TaskHistory are not imported in this module;
        # the quoted annotations are never evaluated at runtime.
        cost_guard: "CostGuard | None" = None,
        task_history: "TaskHistory | None" = None,
    ):
        """Create a dispatcher.

        Args:
            task_queue: Queue polled for pending tasks.
            poll_interval: Seconds to sleep between polls.
            cost_guard: Optional guard whose ``check_daily_limit`` gates each
                dequeue and whose ``record_usage`` receives token counts after
                each successful task.
            task_history: Optional store that receives one ``record`` call per
                finished task (completed or failed).
        """
        self._queue = task_queue
        self._poll_interval = poll_interval
        self._cost_guard = cost_guard
        self._task_history = task_history
        self._running = False
        self._task: asyncio.Task | None = None

    async def start(self) -> None:
        """Start the background polling loop.

        Calling ``start`` while the loop task is still alive is a no-op, so a
        second call cannot spawn a duplicate consumer that ``stop`` would be
        unable to cancel.
        """
        if self._task is not None and not self._task.done():
            logger.debug("Dispatcher already running")
            return
        self._running = True
        self._task = asyncio.create_task(self._poll_loop())
        logger.info("Dispatcher started (poll_interval=%.1fs)", self._poll_interval)

    async def stop(self) -> None:
        """Stop the polling loop and wait for its task to finish.

        The loop task is cancelled, so a task that is mid-execution is
        interrupted at its current ``await``. ``CancelledError`` is not an
        ``Exception`` subclass, so such a task is not marked failed here.
        """
        self._running = False
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass
            self._task = None
        logger.info("Dispatcher stopped")

    async def _poll_loop(self) -> None:
        """Call ``_poll_once`` until stopped, sleeping ``poll_interval`` between calls.

        An exception from one iteration is logged and the loop keeps going.
        """
        while self._running:
            try:
                await self._poll_once()
            except Exception:
                logger.exception("Dispatcher poll error")
            await asyncio.sleep(self._poll_interval)

    async def _poll_once(self) -> None:
        """Dequeue at most one task and process it.

        Returns without dequeuing when the cost guard reports the daily limit
        is exceeded, and returns quietly when the queue is empty. Only a failure
        of the agent run (or of marking the task completed) marks the task
        failed; errors while recording cost/history afterwards are logged and
        never flip a completed task to failed.
        """
        # Check the daily limit before dequeuing so no task is claimed that
        # the guard says we cannot afford to run.
        if self._cost_guard and not await self._cost_guard.check_daily_limit():
            logger.warning("Daily cost limit exceeded, skipping task dequeue")
            return

        task = await self._queue.dequeue()
        if not task:
            return

        logger.info("Processing task %s (thread %s)", task["id"], task["thread_id"])

        start_time = datetime.now(timezone.utc)
        try:
            result = await self._run_agent_for_task(task)
            await self._queue.mark_completed(task["id"], result=result)
        except Exception as e:
            error = str(e)
            logger.exception("Task %s failed", task["id"])
            await self._queue.mark_failed(task["id"], error=error)
            await self._notify_failure(task, error)
            if self._task_history:
                try:
                    await self._record_history(
                        task,
                        status="failed",
                        start_time=start_time,
                        error_message=error,
                    )
                except Exception:
                    logger.exception("Failed to record history for task %s", task["id"])
            return

        logger.info("Task %s completed successfully", task["id"])

        # From here on the task is already marked completed: bookkeeping
        # errors are logged but must never re-mark it as failed.
        tokens_input = result.get("tokens_input", 0)
        tokens_output = result.get("tokens_output", 0)

        if self._cost_guard:
            try:
                await self._cost_guard.record_usage(
                    task["id"],
                    tokens_input=tokens_input,
                    tokens_output=tokens_output,
                )
            except Exception:
                logger.exception("Failed to record cost usage for task %s", task["id"])

        if self._task_history:
            try:
                cost_usd = (
                    self._cost_guard.calculate_cost(tokens_input, tokens_output)
                    if self._cost_guard
                    else 0.0
                )
                await self._record_history(
                    task,
                    status="completed",
                    start_time=start_time,
                    tokens_input=tokens_input,
                    tokens_output=tokens_output,
                    cost_usd=cost_usd,
                )
            except Exception:
                logger.exception("Failed to record history for task %s", task["id"])

    async def _record_history(
        self,
        task: dict,
        *,
        status: str,
        start_time: datetime,
        tokens_input: int = 0,
        tokens_output: int = 0,
        cost_usd: float = 0.0,
        error_message: str | None = None,
    ) -> None:
        """Write one ``task_history`` entry for a finished task.

        Must only be called when ``self._task_history`` is set; errors raised
        by the history store propagate to the caller.

        Args:
            task: The dequeued task dict (reads id, thread_id, payload, source,
                created_at).
            status: History status string, e.g. ``"completed"`` or ``"failed"``.
            start_time: UTC time processing began; used for the duration.
            tokens_input / tokens_output / cost_usd: Usage figures to store.
            error_message: Failure description; omitted from the call when None.
        """
        end_time = datetime.now(timezone.utc)
        payload = task["payload"]
        # error_message is passed only for failures, so a successful record
        # leaves the store's own default for that field in place.
        extra = {} if error_message is None else {"error_message": error_message}
        await self._task_history.record(
            task_id=task["id"],
            thread_id=task["thread_id"],
            issue_number=payload.get("issue_number", 0),
            repo_name=payload.get("repo_name", ""),
            source=task["source"],
            status=status,
            created_at=task["created_at"],
            completed_at=end_time.isoformat(),
            duration_seconds=(end_time - start_time).total_seconds(),
            tokens_input=tokens_input,
            tokens_output=tokens_output,
            cost_usd=cost_usd,
            **extra,
        )

    async def _run_agent_for_task(self, task: dict) -> dict[str, Any]:
        """Run the agent on one task and return a small result summary.

        The repository owner/name come from the task payload, falling back to
        the ``DEFAULT_REPO_OWNER`` / ``DEFAULT_REPO_NAME`` environment variables
        and then to hard-coded defaults. A start notification is sent before
        the agent is built. When the payload has a non-zero ``issue_number``,
        the prompt is prefixed with the issue number and title.

        Returns:
            ``{"status": "completed", "messages_count": <number of messages
            in the agent result>}``.
        """
        # Deferred import — presumably avoids an import cycle with agent.server
        # (TODO confirm).
        from agent.server import get_agent

        payload = task["payload"]
        thread_id = task["thread_id"]

        config = {
            "configurable": {
                "thread_id": thread_id,
                "__is_for_execution__": True,
                "repo": {
                    "owner": payload.get("repo_owner", os.environ.get("DEFAULT_REPO_OWNER", "quant")),
                    "name": payload.get("repo_name", os.environ.get("DEFAULT_REPO_NAME", "galaxis-po")),
                },
            },
            "metadata": {},
        }

        await self._notify_start(task)

        agent = await get_agent(config)

        issue_number = payload.get("issue_number", 0)
        message = payload.get("message", "")
        title = payload.get("title", "")

        if issue_number:
            input_text = f"이슈 #{issue_number}: {title}\n\n{message}"
        else:
            input_text = message

        result = await agent.ainvoke(
            {"messages": [{"role": "human", "content": input_text}]},
            config=config,
        )

        # NOTE(review): this summary carries no "tokens_input"/"tokens_output"
        # keys, so _poll_once always reports 0 tokens to the cost guard and the
        # task history. Unless CostGuard tracks usage elsewhere, the daily limit
        # will presumably never trip — TODO populate these from the agent result.
        return {"status": "completed", "messages_count": len(result.get("messages", []))}

    async def _notify_start(self, task: dict) -> None:
        """Send a best-effort "task started" notification.

        Gitea tasks with a non-zero issue number get an issue comment; Discord
        tasks get a Discord reply. Failures are logged at debug level only.
        """
        payload = task["payload"]
        issue_number = payload.get("issue_number", 0)
        source = task["source"]

        if source == "gitea" and issue_number:
            try:
                from agent.tools.gitea_comment import gitea_comment

                # The comment helper is synchronous; run it off the event loop.
                await asyncio.to_thread(
                    gitea_comment,
                    message=f"작업을 시작합니다: {payload.get('title', '')}",
                    issue_number=issue_number,
                )
            except Exception:
                logger.debug("Failed to post start comment to Gitea", exc_info=True)

        if source == "discord":
            try:
                from agent.tools.discord_reply import discord_reply

                await asyncio.to_thread(
                    discord_reply,
                    message=f"작업을 시작합니다: {payload.get('message', '')[:100]}",
                )
            except Exception:
                logger.debug("Failed to send start message to Discord", exc_info=True)

    async def _notify_failure(self, task: dict, error: str) -> None:
        """Send a best-effort Discord "task failed" notification (first 200 chars of the error)."""
        try:
            from agent.tools.discord_reply import discord_reply

            await asyncio.to_thread(discord_reply, message=f"작업 실패: {error[:200]}")
        except Exception:
            logger.debug("Failed to send failure notification", exc_info=True)