"""백그라운드 작업 디스패처. TaskQueue를 폴링하여 에이전트를 실행한다. """ from __future__ import annotations import asyncio import logging import os from datetime import datetime, timezone from typing import Any from agent.task_queue import PersistentTaskQueue logger = logging.getLogger(__name__) class Dispatcher: """백그라운드 작업 소비자.""" def __init__( self, task_queue: PersistentTaskQueue, poll_interval: float = 2.0, cost_guard: "CostGuard | None" = None, task_history: "TaskHistory | None" = None, ): self._queue = task_queue self._poll_interval = poll_interval self._cost_guard = cost_guard self._task_history = task_history self._running = False self._task: asyncio.Task | None = None async def start(self) -> None: """백그라운드 폴링 루프를 시작한다.""" self._running = True self._task = asyncio.create_task(self._poll_loop()) logger.info("Dispatcher started (poll_interval=%.1fs)", self._poll_interval) async def stop(self) -> None: """폴링 루프를 중지한다.""" self._running = False if self._task: self._task.cancel() try: await self._task except asyncio.CancelledError: pass logger.info("Dispatcher stopped") async def _poll_loop(self) -> None: """주기적으로 큐를 폴링한다.""" while self._running: try: await self._poll_once() except Exception: logger.exception("Dispatcher poll error") await asyncio.sleep(self._poll_interval) async def _poll_once(self) -> None: """큐에서 작업을 하나 꺼내 처리한다.""" # Check daily limit before dequeuing if self._cost_guard: if not await self._cost_guard.check_daily_limit(): logger.warning("Daily cost limit exceeded, skipping task dequeue") return task = await self._queue.dequeue() if not task: return logger.info("Processing task %s (thread %s)", task["id"], task["thread_id"]) start_time = datetime.now(timezone.utc) try: result = await self._run_agent_for_task(task) await self._queue.mark_completed(task["id"], result=result) logger.info("Task %s completed successfully", task["id"]) # Record cost and history after successful completion tokens_input = result.get("tokens_input", 0) tokens_output = result.get("tokens_output", 0) if self._cost_guard: await self._cost_guard.record_usage( task["id"], tokens_input=tokens_input, tokens_output=tokens_output, ) if self._task_history: end_time = datetime.now(timezone.utc) duration_seconds = (end_time - start_time).total_seconds() cost_usd = self._cost_guard.calculate_cost(tokens_input, tokens_output) if self._cost_guard else 0.0 payload = task["payload"] await self._task_history.record( task_id=task["id"], thread_id=task["thread_id"], issue_number=payload.get("issue_number", 0), repo_name=payload.get("repo_name", ""), source=task["source"], status="completed", created_at=task["created_at"], completed_at=end_time.isoformat(), duration_seconds=duration_seconds, tokens_input=tokens_input, tokens_output=tokens_output, cost_usd=cost_usd, ) except Exception as e: logger.exception("Task %s failed", task["id"]) await self._queue.mark_failed(task["id"], error=str(e)) await self._notify_failure(task, str(e)) # Record history after failure if self._task_history: end_time = datetime.now(timezone.utc) duration_seconds = (end_time - start_time).total_seconds() payload = task["payload"] await self._task_history.record( task_id=task["id"], thread_id=task["thread_id"], issue_number=payload.get("issue_number", 0), repo_name=payload.get("repo_name", ""), source=task["source"], status="failed", created_at=task["created_at"], completed_at=end_time.isoformat(), duration_seconds=duration_seconds, tokens_input=0, tokens_output=0, cost_usd=0.0, error_message=str(e), ) async def _run_agent_for_task(self, task: dict) -> dict[str, Any]: """작업에 대해 에이전트를 실행한다.""" from agent.server import get_agent payload = task["payload"] thread_id = task["thread_id"] config = { "configurable": { "thread_id": thread_id, "__is_for_execution__": True, "repo": { "owner": payload.get("repo_owner", os.environ.get("DEFAULT_REPO_OWNER", "quant")), "name": payload.get("repo_name", os.environ.get("DEFAULT_REPO_NAME", "galaxis-po")), }, }, "metadata": {}, } await self._notify_start(task) agent = await get_agent(config) issue_number = payload.get("issue_number", 0) message = payload.get("message", "") title = payload.get("title", "") if issue_number: input_text = f"이슈 #{issue_number}: {title}\n\n{message}" else: input_text = message result = await agent.ainvoke( {"messages": [{"role": "human", "content": input_text}]}, config=config, ) return {"status": "completed", "messages_count": len(result.get("messages", []))} async def _notify_start(self, task: dict) -> None: """작업 시작 알림을 전송한다.""" payload = task["payload"] issue_number = payload.get("issue_number", 0) source = task["source"] if source == "gitea" and issue_number: try: from agent.tools.gitea_comment import gitea_comment await asyncio.to_thread( gitea_comment, message=f"작업을 시작합니다: {payload.get('title', '')}", issue_number=issue_number, ) except Exception: logger.debug("Failed to post start comment to Gitea") if source == "discord": try: from agent.tools.discord_reply import discord_reply await asyncio.to_thread( discord_reply, message=f"작업을 시작합니다: {payload.get('message', '')[:100]}", ) except Exception: logger.debug("Failed to send start message to Discord") async def _notify_failure(self, task: dict, error: str) -> None: """작업 실패 알림을 전송한다.""" try: from agent.tools.discord_reply import discord_reply await asyncio.to_thread(discord_reply, message=f"작업 실패: {error[:200]}") except Exception: logger.debug("Failed to send failure notification")