diff --git a/agent/dispatcher.py b/agent/dispatcher.py new file mode 100644 index 0000000..df4b2e3 --- /dev/null +++ b/agent/dispatcher.py @@ -0,0 +1,149 @@ +"""백그라운드 작업 디스패처. + +TaskQueue를 폴링하여 에이전트를 실행한다. +""" + +from __future__ import annotations + +import asyncio +import logging +import os +from typing import Any + +from agent.task_queue import PersistentTaskQueue + +logger = logging.getLogger(__name__) + + +class Dispatcher: + """백그라운드 작업 소비자.""" + + def __init__( + self, + task_queue: PersistentTaskQueue, + poll_interval: float = 2.0, + ): + self._queue = task_queue + self._poll_interval = poll_interval + self._running = False + self._task: asyncio.Task | None = None + + async def start(self) -> None: + """백그라운드 폴링 루프를 시작한다.""" + self._running = True + self._task = asyncio.create_task(self._poll_loop()) + logger.info("Dispatcher started (poll_interval=%.1fs)", self._poll_interval) + + async def stop(self) -> None: + """폴링 루프를 중지한다.""" + self._running = False + if self._task: + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + logger.info("Dispatcher stopped") + + async def _poll_loop(self) -> None: + """주기적으로 큐를 폴링한다.""" + while self._running: + try: + await self._poll_once() + except Exception: + logger.exception("Dispatcher poll error") + await asyncio.sleep(self._poll_interval) + + async def _poll_once(self) -> None: + """큐에서 작업을 하나 꺼내 처리한다.""" + task = await self._queue.dequeue() + if not task: + return + + logger.info("Processing task %s (thread %s)", task["id"], task["thread_id"]) + + try: + result = await self._run_agent_for_task(task) + await self._queue.mark_completed(task["id"], result=result) + logger.info("Task %s completed successfully", task["id"]) + except Exception as e: + logger.exception("Task %s failed", task["id"]) + await self._queue.mark_failed(task["id"], error=str(e)) + await self._notify_failure(task, str(e)) + + async def _run_agent_for_task(self, task: dict) -> dict[str, Any]: + """작업에 대해 에이전트를 실행한다.""" + from agent.server import get_agent + + payload = task["payload"] + thread_id = task["thread_id"] + + config = { + "configurable": { + "thread_id": thread_id, + "__is_for_execution__": True, + "repo": { + "owner": payload.get("repo_owner", os.environ.get("DEFAULT_REPO_OWNER", "quant")), + "name": payload.get("repo_name", os.environ.get("DEFAULT_REPO_NAME", "galaxis-po")), + }, + }, + "metadata": {}, + } + + await self._notify_start(task) + + agent = await get_agent(config) + + issue_number = payload.get("issue_number", 0) + message = payload.get("message", "") + title = payload.get("title", "") + + if issue_number: + input_text = f"이슈 #{issue_number}: {title}\n\n{message}" + else: + input_text = message + + result = await agent.ainvoke( + {"messages": [{"role": "human", "content": input_text}]}, + config=config, + ) + + return {"status": "completed", "messages_count": len(result.get("messages", []))} + + async def _notify_start(self, task: dict) -> None: + """작업 시작 알림을 전송한다.""" + payload = task["payload"] + issue_number = payload.get("issue_number", 0) + source = task["source"] + + if source == "gitea" and issue_number: + try: + from agent.tools.gitea_comment import gitea_comment + + await asyncio.to_thread( + gitea_comment, + message=f"작업을 시작합니다: {payload.get('title', '')}", + issue_number=issue_number, + ) + except Exception: + logger.debug("Failed to post start comment to Gitea") + + if source == "discord": + try: + from agent.tools.discord_reply import discord_reply + + await asyncio.to_thread( + discord_reply, + message=f"작업을 시작합니다: {payload.get('message', '')[:100]}", + ) + except Exception: + logger.debug("Failed to send start message to Discord") + + async def _notify_failure(self, task: dict, error: str) -> None: + """작업 실패 알림을 전송한다.""" + try: + from agent.tools.discord_reply import discord_reply + + await asyncio.to_thread(discord_reply, message=f"작업 실패: {error[:200]}") + except Exception: + logger.debug("Failed to send failure notification") diff --git a/tests/test_dispatcher.py b/tests/test_dispatcher.py new file mode 100644 index 0000000..efb67c3 --- /dev/null +++ b/tests/test_dispatcher.py @@ -0,0 +1,87 @@ +"""Tests for the agent dispatcher.""" + +import os +import tempfile + +import pytest +from unittest.mock import AsyncMock + +from agent.dispatcher import Dispatcher +from agent.task_queue import PersistentTaskQueue + + +@pytest.fixture +async def task_queue(): + """Create a temporary task queue for testing.""" + fd, db_path = tempfile.mkstemp(suffix=".db") + os.close(fd) + queue = PersistentTaskQueue(db_path=db_path) + await queue.initialize() + yield queue + await queue.close() + os.unlink(db_path) + + +@pytest.mark.asyncio +async def test_dispatcher_processes_task(task_queue): + """Dispatcher가 큐에서 작업을 꺼내 처리한다.""" + await task_queue.enqueue( + "thread-1", + "gitea", + { + "issue_number": 42, + "repo_owner": "quant", + "repo_name": "galaxis-po", + "message": "Fix the bug", + }, + ) + + mock_run_agent = AsyncMock(return_value={"pr_url": "http://..."}) + + dispatcher = Dispatcher(task_queue=task_queue) + dispatcher._run_agent_for_task = mock_run_agent + + await dispatcher._poll_once() + + mock_run_agent.assert_called_once() + pending = await task_queue.get_pending() + assert len(pending) == 0 + + +@pytest.mark.asyncio +async def test_dispatcher_skips_when_empty(task_queue): + """큐가 비어있으면 아무 작업도 하지 않는다.""" + mock_run_agent = AsyncMock() + + dispatcher = Dispatcher(task_queue=task_queue) + dispatcher._run_agent_for_task = mock_run_agent + + await dispatcher._poll_once() + mock_run_agent.assert_not_called() + + +@pytest.mark.asyncio +async def test_dispatcher_handles_failure(task_queue): + """에이전트 실행 실패 시 작업을 failed로 표시한다.""" + await task_queue.enqueue( + "thread-1", + "gitea", + { + "issue_number": 42, + "repo_owner": "quant", + "repo_name": "galaxis-po", + "message": "Fix", + }, + ) + + mock_run_agent = AsyncMock(side_effect=Exception("Agent crashed")) + + dispatcher = Dispatcher(task_queue=task_queue) + dispatcher._run_agent_for_task = mock_run_agent + + await dispatcher._poll_once() + + # 실패 후 다음 작업 dequeue 가능해야 함 + await task_queue.enqueue("thread-2", "gitea", {"message": "Next"}) + task = await task_queue.dequeue() + assert task is not None