feat: add Dispatcher for background task processing
This commit is contained in:
parent
5a471907fa
commit
da9caca791
149
agent/dispatcher.py
Normal file
149
agent/dispatcher.py
Normal file
@ -0,0 +1,149 @@
|
||||
"""백그라운드 작업 디스패처.
|
||||
|
||||
TaskQueue를 폴링하여 에이전트를 실행한다.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from agent.task_queue import PersistentTaskQueue
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Dispatcher:
|
||||
"""백그라운드 작업 소비자."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task_queue: PersistentTaskQueue,
|
||||
poll_interval: float = 2.0,
|
||||
):
|
||||
self._queue = task_queue
|
||||
self._poll_interval = poll_interval
|
||||
self._running = False
|
||||
self._task: asyncio.Task | None = None
|
||||
|
||||
async def start(self) -> None:
|
||||
"""백그라운드 폴링 루프를 시작한다."""
|
||||
self._running = True
|
||||
self._task = asyncio.create_task(self._poll_loop())
|
||||
logger.info("Dispatcher started (poll_interval=%.1fs)", self._poll_interval)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""폴링 루프를 중지한다."""
|
||||
self._running = False
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
try:
|
||||
await self._task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("Dispatcher stopped")
|
||||
|
||||
async def _poll_loop(self) -> None:
|
||||
"""주기적으로 큐를 폴링한다."""
|
||||
while self._running:
|
||||
try:
|
||||
await self._poll_once()
|
||||
except Exception:
|
||||
logger.exception("Dispatcher poll error")
|
||||
await asyncio.sleep(self._poll_interval)
|
||||
|
||||
async def _poll_once(self) -> None:
|
||||
"""큐에서 작업을 하나 꺼내 처리한다."""
|
||||
task = await self._queue.dequeue()
|
||||
if not task:
|
||||
return
|
||||
|
||||
logger.info("Processing task %s (thread %s)", task["id"], task["thread_id"])
|
||||
|
||||
try:
|
||||
result = await self._run_agent_for_task(task)
|
||||
await self._queue.mark_completed(task["id"], result=result)
|
||||
logger.info("Task %s completed successfully", task["id"])
|
||||
except Exception as e:
|
||||
logger.exception("Task %s failed", task["id"])
|
||||
await self._queue.mark_failed(task["id"], error=str(e))
|
||||
await self._notify_failure(task, str(e))
|
||||
|
||||
async def _run_agent_for_task(self, task: dict) -> dict[str, Any]:
|
||||
"""작업에 대해 에이전트를 실행한다."""
|
||||
from agent.server import get_agent
|
||||
|
||||
payload = task["payload"]
|
||||
thread_id = task["thread_id"]
|
||||
|
||||
config = {
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
"__is_for_execution__": True,
|
||||
"repo": {
|
||||
"owner": payload.get("repo_owner", os.environ.get("DEFAULT_REPO_OWNER", "quant")),
|
||||
"name": payload.get("repo_name", os.environ.get("DEFAULT_REPO_NAME", "galaxis-po")),
|
||||
},
|
||||
},
|
||||
"metadata": {},
|
||||
}
|
||||
|
||||
await self._notify_start(task)
|
||||
|
||||
agent = await get_agent(config)
|
||||
|
||||
issue_number = payload.get("issue_number", 0)
|
||||
message = payload.get("message", "")
|
||||
title = payload.get("title", "")
|
||||
|
||||
if issue_number:
|
||||
input_text = f"이슈 #{issue_number}: {title}\n\n{message}"
|
||||
else:
|
||||
input_text = message
|
||||
|
||||
result = await agent.ainvoke(
|
||||
{"messages": [{"role": "human", "content": input_text}]},
|
||||
config=config,
|
||||
)
|
||||
|
||||
return {"status": "completed", "messages_count": len(result.get("messages", []))}
|
||||
|
||||
async def _notify_start(self, task: dict) -> None:
|
||||
"""작업 시작 알림을 전송한다."""
|
||||
payload = task["payload"]
|
||||
issue_number = payload.get("issue_number", 0)
|
||||
source = task["source"]
|
||||
|
||||
if source == "gitea" and issue_number:
|
||||
try:
|
||||
from agent.tools.gitea_comment import gitea_comment
|
||||
|
||||
await asyncio.to_thread(
|
||||
gitea_comment,
|
||||
message=f"작업을 시작합니다: {payload.get('title', '')}",
|
||||
issue_number=issue_number,
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Failed to post start comment to Gitea")
|
||||
|
||||
if source == "discord":
|
||||
try:
|
||||
from agent.tools.discord_reply import discord_reply
|
||||
|
||||
await asyncio.to_thread(
|
||||
discord_reply,
|
||||
message=f"작업을 시작합니다: {payload.get('message', '')[:100]}",
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Failed to send start message to Discord")
|
||||
|
||||
async def _notify_failure(self, task: dict, error: str) -> None:
|
||||
"""작업 실패 알림을 전송한다."""
|
||||
try:
|
||||
from agent.tools.discord_reply import discord_reply
|
||||
|
||||
await asyncio.to_thread(discord_reply, message=f"작업 실패: {error[:200]}")
|
||||
except Exception:
|
||||
logger.debug("Failed to send failure notification")
|
||||
87
tests/test_dispatcher.py
Normal file
87
tests/test_dispatcher.py
Normal file
@ -0,0 +1,87 @@
|
||||
"""Tests for the agent dispatcher."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from agent.dispatcher import Dispatcher
|
||||
from agent.task_queue import PersistentTaskQueue
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def task_queue():
|
||||
"""Create a temporary task queue for testing."""
|
||||
fd, db_path = tempfile.mkstemp(suffix=".db")
|
||||
os.close(fd)
|
||||
queue = PersistentTaskQueue(db_path=db_path)
|
||||
await queue.initialize()
|
||||
yield queue
|
||||
await queue.close()
|
||||
os.unlink(db_path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatcher_processes_task(task_queue):
|
||||
"""Dispatcher가 큐에서 작업을 꺼내 처리한다."""
|
||||
await task_queue.enqueue(
|
||||
"thread-1",
|
||||
"gitea",
|
||||
{
|
||||
"issue_number": 42,
|
||||
"repo_owner": "quant",
|
||||
"repo_name": "galaxis-po",
|
||||
"message": "Fix the bug",
|
||||
},
|
||||
)
|
||||
|
||||
mock_run_agent = AsyncMock(return_value={"pr_url": "http://..."})
|
||||
|
||||
dispatcher = Dispatcher(task_queue=task_queue)
|
||||
dispatcher._run_agent_for_task = mock_run_agent
|
||||
|
||||
await dispatcher._poll_once()
|
||||
|
||||
mock_run_agent.assert_called_once()
|
||||
pending = await task_queue.get_pending()
|
||||
assert len(pending) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatcher_skips_when_empty(task_queue):
|
||||
"""큐가 비어있으면 아무 작업도 하지 않는다."""
|
||||
mock_run_agent = AsyncMock()
|
||||
|
||||
dispatcher = Dispatcher(task_queue=task_queue)
|
||||
dispatcher._run_agent_for_task = mock_run_agent
|
||||
|
||||
await dispatcher._poll_once()
|
||||
mock_run_agent.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatcher_handles_failure(task_queue):
|
||||
"""에이전트 실행 실패 시 작업을 failed로 표시한다."""
|
||||
await task_queue.enqueue(
|
||||
"thread-1",
|
||||
"gitea",
|
||||
{
|
||||
"issue_number": 42,
|
||||
"repo_owner": "quant",
|
||||
"repo_name": "galaxis-po",
|
||||
"message": "Fix",
|
||||
},
|
||||
)
|
||||
|
||||
mock_run_agent = AsyncMock(side_effect=Exception("Agent crashed"))
|
||||
|
||||
dispatcher = Dispatcher(task_queue=task_queue)
|
||||
dispatcher._run_agent_for_task = mock_run_agent
|
||||
|
||||
await dispatcher._poll_once()
|
||||
|
||||
# 실패 후 다음 작업 dequeue 가능해야 함
|
||||
await task_queue.enqueue("thread-2", "gitea", {"message": "Next"})
|
||||
task = await task_queue.dequeue()
|
||||
assert task is not None
|
||||
Loading…
x
Reference in New Issue
Block a user