Compare commits
8 Commits
60aeebf7a7
...
a58bbca9b7
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a58bbca9b7 | ||
|
|
d35efae12e | ||
|
|
7e95aeb8ce | ||
|
|
da9caca791 | ||
|
|
5a471907fa | ||
|
|
8c274b4be2 | ||
|
|
9242badeff | ||
|
|
0136823462 |
149
agent/dispatcher.py
Normal file
149
agent/dispatcher.py
Normal file
@ -0,0 +1,149 @@
|
||||
"""백그라운드 작업 디스패처.
|
||||
|
||||
TaskQueue를 폴링하여 에이전트를 실행한다.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from agent.task_queue import PersistentTaskQueue
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Dispatcher:
|
||||
"""백그라운드 작업 소비자."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task_queue: PersistentTaskQueue,
|
||||
poll_interval: float = 2.0,
|
||||
):
|
||||
self._queue = task_queue
|
||||
self._poll_interval = poll_interval
|
||||
self._running = False
|
||||
self._task: asyncio.Task | None = None
|
||||
|
||||
async def start(self) -> None:
|
||||
"""백그라운드 폴링 루프를 시작한다."""
|
||||
self._running = True
|
||||
self._task = asyncio.create_task(self._poll_loop())
|
||||
logger.info("Dispatcher started (poll_interval=%.1fs)", self._poll_interval)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""폴링 루프를 중지한다."""
|
||||
self._running = False
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
try:
|
||||
await self._task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("Dispatcher stopped")
|
||||
|
||||
async def _poll_loop(self) -> None:
|
||||
"""주기적으로 큐를 폴링한다."""
|
||||
while self._running:
|
||||
try:
|
||||
await self._poll_once()
|
||||
except Exception:
|
||||
logger.exception("Dispatcher poll error")
|
||||
await asyncio.sleep(self._poll_interval)
|
||||
|
||||
async def _poll_once(self) -> None:
|
||||
"""큐에서 작업을 하나 꺼내 처리한다."""
|
||||
task = await self._queue.dequeue()
|
||||
if not task:
|
||||
return
|
||||
|
||||
logger.info("Processing task %s (thread %s)", task["id"], task["thread_id"])
|
||||
|
||||
try:
|
||||
result = await self._run_agent_for_task(task)
|
||||
await self._queue.mark_completed(task["id"], result=result)
|
||||
logger.info("Task %s completed successfully", task["id"])
|
||||
except Exception as e:
|
||||
logger.exception("Task %s failed", task["id"])
|
||||
await self._queue.mark_failed(task["id"], error=str(e))
|
||||
await self._notify_failure(task, str(e))
|
||||
|
||||
async def _run_agent_for_task(self, task: dict) -> dict[str, Any]:
|
||||
"""작업에 대해 에이전트를 실행한다."""
|
||||
from agent.server import get_agent
|
||||
|
||||
payload = task["payload"]
|
||||
thread_id = task["thread_id"]
|
||||
|
||||
config = {
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
"__is_for_execution__": True,
|
||||
"repo": {
|
||||
"owner": payload.get("repo_owner", os.environ.get("DEFAULT_REPO_OWNER", "quant")),
|
||||
"name": payload.get("repo_name", os.environ.get("DEFAULT_REPO_NAME", "galaxis-po")),
|
||||
},
|
||||
},
|
||||
"metadata": {},
|
||||
}
|
||||
|
||||
await self._notify_start(task)
|
||||
|
||||
agent = await get_agent(config)
|
||||
|
||||
issue_number = payload.get("issue_number", 0)
|
||||
message = payload.get("message", "")
|
||||
title = payload.get("title", "")
|
||||
|
||||
if issue_number:
|
||||
input_text = f"이슈 #{issue_number}: {title}\n\n{message}"
|
||||
else:
|
||||
input_text = message
|
||||
|
||||
result = await agent.ainvoke(
|
||||
{"messages": [{"role": "human", "content": input_text}]},
|
||||
config=config,
|
||||
)
|
||||
|
||||
return {"status": "completed", "messages_count": len(result.get("messages", []))}
|
||||
|
||||
async def _notify_start(self, task: dict) -> None:
|
||||
"""작업 시작 알림을 전송한다."""
|
||||
payload = task["payload"]
|
||||
issue_number = payload.get("issue_number", 0)
|
||||
source = task["source"]
|
||||
|
||||
if source == "gitea" and issue_number:
|
||||
try:
|
||||
from agent.tools.gitea_comment import gitea_comment
|
||||
|
||||
await asyncio.to_thread(
|
||||
gitea_comment,
|
||||
message=f"작업을 시작합니다: {payload.get('title', '')}",
|
||||
issue_number=issue_number,
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Failed to post start comment to Gitea")
|
||||
|
||||
if source == "discord":
|
||||
try:
|
||||
from agent.tools.discord_reply import discord_reply
|
||||
|
||||
await asyncio.to_thread(
|
||||
discord_reply,
|
||||
message=f"작업을 시작합니다: {payload.get('message', '')[:100]}",
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Failed to send start message to Discord")
|
||||
|
||||
async def _notify_failure(self, task: dict, error: str) -> None:
|
||||
"""작업 실패 알림을 전송한다."""
|
||||
try:
|
||||
from agent.tools.discord_reply import discord_reply
|
||||
|
||||
await asyncio.to_thread(discord_reply, message=f"작업 실패: {error[:200]}")
|
||||
except Exception:
|
||||
logger.debug("Failed to send failure notification")
|
||||
120
agent/integrations/discord_handler.py
Normal file
120
agent/integrations/discord_handler.py
Normal file
@ -0,0 +1,120 @@
|
||||
# agent/integrations/discord_handler.py
|
||||
"""Discord Bot Gateway 수신 핸들러.
|
||||
|
||||
discord.py를 사용하여 @agent 멘션을 수신하고 작업을 큐에 추가한다.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_discord_message(content: str, bot_user_id: int) -> dict:
|
||||
"""Discord 메시지를 파싱하여 작업 정보를 추출한다."""
|
||||
# 봇 멘션 제거 (<@123456>)
|
||||
cleaned = re.sub(rf"<@!?{bot_user_id}>", "", content).strip()
|
||||
|
||||
# 이슈 번호 추출
|
||||
issue_match = re.search(r"#(\d+)", cleaned)
|
||||
issue_number = int(issue_match.group(1)) if issue_match else 0
|
||||
|
||||
# 리포 이름 추출
|
||||
repo_name = os.environ.get("DEFAULT_REPO_NAME", "galaxis-po")
|
||||
repo_match = re.search(r"\b(galaxis-\w+)\b", cleaned)
|
||||
if repo_match:
|
||||
repo_name = repo_match.group(1)
|
||||
|
||||
# 순수 메시지
|
||||
message = re.sub(r"@agent\b", "", cleaned, flags=re.IGNORECASE)
|
||||
message = re.sub(rf"\b{re.escape(repo_name)}\b", "", message)
|
||||
message = re.sub(r"이슈\s*#\d+", "", message)
|
||||
message = re.sub(r"#\d+", "", message)
|
||||
message = message.strip()
|
||||
|
||||
return {
|
||||
"issue_number": issue_number,
|
||||
"repo_name": repo_name,
|
||||
"message": message or cleaned,
|
||||
}
|
||||
|
||||
|
||||
def generate_discord_thread_id(channel_id: int, message_id: int) -> str:
|
||||
"""Discord 메시지에서 결정론적 스레드 ID를 생성한다."""
|
||||
raw = hashlib.sha256(f"discord:{channel_id}:{message_id}".encode()).hexdigest()
|
||||
return f"{raw[:8]}-{raw[8:12]}-{raw[12:16]}-{raw[16:20]}-{raw[20:32]}"
|
||||
|
||||
|
||||
class DiscordHandler:
|
||||
"""Discord Bot Gateway 핸들러."""
|
||||
|
||||
def __init__(self):
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
intents.guild_messages = True
|
||||
self.bot = commands.Bot(command_prefix="!", intents=intents)
|
||||
self._setup_events()
|
||||
|
||||
def _setup_events(self):
|
||||
@self.bot.event
|
||||
async def on_ready():
|
||||
logger.info("Discord bot connected as %s", self.bot.user)
|
||||
|
||||
@self.bot.event
|
||||
async def on_message(message: discord.Message):
|
||||
if message.author == self.bot.user:
|
||||
return
|
||||
if not self.bot.user or not self.bot.user.mentioned_in(message):
|
||||
return
|
||||
await self._handle_mention(message)
|
||||
|
||||
async def _handle_mention(self, message: discord.Message):
|
||||
"""@agent 멘션을 처리한다."""
|
||||
parsed = parse_discord_message(
|
||||
message.content, self.bot.user.id if self.bot.user else 0
|
||||
)
|
||||
thread_id = generate_discord_thread_id(message.channel.id, message.id)
|
||||
repo_owner = os.environ.get("DEFAULT_REPO_OWNER", "quant")
|
||||
|
||||
from agent.task_queue import get_task_queue
|
||||
from agent.message_store import get_message_store
|
||||
|
||||
task_queue = await get_task_queue()
|
||||
|
||||
if parsed["issue_number"] and await task_queue.has_running_task(thread_id):
|
||||
store = await get_message_store()
|
||||
await store.push_message(thread_id, {
|
||||
"role": "human",
|
||||
"content": parsed["message"],
|
||||
})
|
||||
await message.reply("메시지를 대기열에 추가했습니다. 현재 작업이 완료되면 확인하겠습니다.")
|
||||
return
|
||||
|
||||
task_id = await task_queue.enqueue(
|
||||
thread_id=thread_id,
|
||||
source="discord",
|
||||
payload={
|
||||
"issue_number": parsed["issue_number"],
|
||||
"repo_owner": repo_owner,
|
||||
"repo_name": parsed["repo_name"],
|
||||
"message": parsed["message"],
|
||||
"channel_id": str(message.channel.id),
|
||||
"message_id": str(message.id),
|
||||
},
|
||||
)
|
||||
await message.reply(f"작업을 대기열에 추가했습니다. (task: {task_id[:8]})")
|
||||
logger.info("Discord task enqueued: %s (thread %s)", task_id, thread_id)
|
||||
|
||||
async def start(self, token: str):
|
||||
"""Bot Gateway를 시작한다."""
|
||||
await self.bot.start(token)
|
||||
|
||||
async def close(self):
|
||||
"""Bot Gateway를 종료한다."""
|
||||
await self.bot.close()
|
||||
86
agent/message_store.py
Normal file
86
agent/message_store.py
Normal file
@ -0,0 +1,86 @@
|
||||
"""SQLite 기반 메시지 스토어.
|
||||
|
||||
에이전트 작업 중 도착하는 follow-up 메시지를 저장한다.
|
||||
check_message_queue 미들웨어가 다음 모델 호출 시 주입한다.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import aiosqlite
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_CREATE_TABLE = """
|
||||
CREATE TABLE IF NOT EXISTS pending_messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
thread_id TEXT NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL
|
||||
)
|
||||
"""
|
||||
|
||||
_store: MessageStore | None = None
|
||||
|
||||
|
||||
class MessageStore:
|
||||
"""SQLite 기반 follow-up 메시지 스토어."""
|
||||
|
||||
def __init__(self, db_path: str = "/data/message_store.db"):
|
||||
self._db_path = db_path
|
||||
self._db: aiosqlite.Connection | None = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the database connection and create tables."""
|
||||
self._db = await aiosqlite.connect(self._db_path)
|
||||
await self._db.execute(_CREATE_TABLE)
|
||||
await self._db.commit()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the database connection."""
|
||||
if self._db:
|
||||
await self._db.close()
|
||||
|
||||
async def push_message(self, thread_id: str, message: dict) -> None:
|
||||
"""스레드에 pending 메시지를 추가한다."""
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
await self._db.execute(
|
||||
"INSERT INTO pending_messages (thread_id, message, created_at) VALUES (?, ?, ?)",
|
||||
(thread_id, json.dumps(message), now),
|
||||
)
|
||||
await self._db.commit()
|
||||
logger.debug("Pushed message for thread %s", thread_id)
|
||||
|
||||
async def get_messages(self, thread_id: str) -> list[dict]:
|
||||
"""스레드의 pending 메시지를 조회한다 (삭제하지 않음)."""
|
||||
cursor = await self._db.execute(
|
||||
"SELECT message FROM pending_messages WHERE thread_id = ? ORDER BY id ASC",
|
||||
(thread_id,),
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
return [json.loads(row[0]) for row in rows]
|
||||
|
||||
async def consume_messages(self, thread_id: str) -> list[dict]:
|
||||
"""스레드의 pending 메시지를 조회하고 삭제한다."""
|
||||
messages = await self.get_messages(thread_id)
|
||||
await self._db.execute(
|
||||
"DELETE FROM pending_messages WHERE thread_id = ?",
|
||||
(thread_id,),
|
||||
)
|
||||
await self._db.commit()
|
||||
logger.debug("Consumed %d messages for thread %s", len(messages), thread_id)
|
||||
return messages
|
||||
|
||||
|
||||
async def get_message_store() -> MessageStore:
|
||||
"""MessageStore 싱글턴을 반환한다."""
|
||||
global _store
|
||||
if _store is None:
|
||||
db_path = os.environ.get("MESSAGE_STORE_DB", "/data/message_store.db")
|
||||
_store = MessageStore(db_path=db_path)
|
||||
await _store.initialize()
|
||||
return _store
|
||||
@ -1,6 +1,6 @@
|
||||
"""Before-model middleware that injects queued messages into state.
|
||||
|
||||
Checks the LangGraph store for pending messages (e.g. follow-up Linear
|
||||
Checks the MessageStore for pending messages (e.g. follow-up Linear
|
||||
comments that arrived while the agent was busy) and injects them as new
|
||||
human messages before the next model call.
|
||||
"""
|
||||
@ -12,9 +12,10 @@ from typing import Any
|
||||
|
||||
import httpx
|
||||
from langchain.agents.middleware import AgentState, before_model
|
||||
from langgraph.config import get_config, get_store
|
||||
from langgraph.config import get_config
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from ..message_store import get_message_store
|
||||
from ..utils.multimodal import fetch_image_block
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -68,31 +69,17 @@ async def check_message_queue_before_model( # noqa: PLR0911
|
||||
return None
|
||||
|
||||
try:
|
||||
store = get_store()
|
||||
store = await get_message_store()
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.debug("Could not get store from context: %s", e)
|
||||
logger.warning("Could not get message store: %s", e)
|
||||
return None
|
||||
|
||||
if store is None:
|
||||
return None
|
||||
|
||||
namespace = ("queue", thread_id)
|
||||
|
||||
try:
|
||||
queued_item = await store.aget(namespace, "pending_messages")
|
||||
queued_messages = await store.consume_messages(thread_id)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.warning("Failed to get queued item: %s", e)
|
||||
logger.warning("Failed to consume messages: %s", e)
|
||||
return None
|
||||
|
||||
if queued_item is None:
|
||||
return None
|
||||
|
||||
queued_value = queued_item.value
|
||||
queued_messages = queued_value.get("messages", [])
|
||||
|
||||
# Delete early to prevent duplicate processing if middleware runs again
|
||||
await store.adelete(namespace, "pending_messages")
|
||||
|
||||
if not queued_messages:
|
||||
return None
|
||||
|
||||
|
||||
148
agent/task_queue.py
Normal file
148
agent/task_queue.py
Normal file
@ -0,0 +1,148 @@
|
||||
"""SQLite 기반 영속 작업 큐.
|
||||
|
||||
동시 작업 수를 제한하고, 서버 재시작 시에도 작업이 유실되지 않도록 한다.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import aiosqlite
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_CREATE_TABLE = """
|
||||
CREATE TABLE IF NOT EXISTS tasks (
|
||||
id TEXT PRIMARY KEY,
|
||||
thread_id TEXT NOT NULL,
|
||||
source TEXT NOT NULL,
|
||||
payload TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
created_at TEXT NOT NULL,
|
||||
started_at TEXT,
|
||||
completed_at TEXT,
|
||||
result TEXT
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
class PersistentTaskQueue:
|
||||
"""SQLite 기반 영속 작업 큐."""
|
||||
|
||||
def __init__(self, db_path: str = "/data/task_queue.db", max_concurrent: int = 1):
|
||||
self._db_path = db_path
|
||||
self._max_concurrent = max_concurrent
|
||||
self._db: aiosqlite.Connection | None = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""DB 연결 및 테이블 생성."""
|
||||
self._db = await aiosqlite.connect(self._db_path)
|
||||
self._db.row_factory = aiosqlite.Row
|
||||
await self._db.execute(_CREATE_TABLE)
|
||||
await self._db.commit()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""DB 연결 종료."""
|
||||
if self._db:
|
||||
await self._db.close()
|
||||
|
||||
async def enqueue(
|
||||
self,
|
||||
thread_id: str,
|
||||
source: str,
|
||||
payload: dict,
|
||||
) -> str:
|
||||
task_id = str(uuid.uuid4())
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
await self._db.execute(
|
||||
"INSERT INTO tasks (id, thread_id, source, payload, status, created_at) "
|
||||
"VALUES (?, ?, ?, ?, 'pending', ?)",
|
||||
(task_id, thread_id, source, json.dumps(payload), now),
|
||||
)
|
||||
await self._db.commit()
|
||||
logger.info("Enqueued task %s for thread %s from %s", task_id, thread_id, source)
|
||||
return task_id
|
||||
|
||||
async def dequeue(self) -> dict | None:
|
||||
cursor = await self._db.execute(
|
||||
"SELECT COUNT(*) as cnt FROM tasks WHERE status = 'running'"
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
if row["cnt"] >= self._max_concurrent:
|
||||
return None
|
||||
|
||||
cursor = await self._db.execute(
|
||||
"SELECT * FROM tasks WHERE status = 'pending' ORDER BY created_at ASC LIMIT 1"
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
await self._db.execute(
|
||||
"UPDATE tasks SET status = 'running', started_at = ? WHERE id = ?",
|
||||
(now, row["id"]),
|
||||
)
|
||||
await self._db.commit()
|
||||
|
||||
task = dict(row)
|
||||
task["payload"] = json.loads(task["payload"])
|
||||
task["status"] = "running"
|
||||
logger.info("Dequeued task %s (thread %s)", task["id"], task["thread_id"])
|
||||
return task
|
||||
|
||||
async def mark_completed(self, task_id: str, result: dict | None = None) -> None:
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
await self._db.execute(
|
||||
"UPDATE tasks SET status = 'completed', completed_at = ?, result = ? WHERE id = ?",
|
||||
(now, json.dumps(result or {}), task_id),
|
||||
)
|
||||
await self._db.commit()
|
||||
logger.info("Task %s completed", task_id)
|
||||
|
||||
async def mark_failed(self, task_id: str, error: str = "") -> None:
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
await self._db.execute(
|
||||
"UPDATE tasks SET status = 'failed', completed_at = ?, result = ? WHERE id = ?",
|
||||
(now, json.dumps({"error": error}), task_id),
|
||||
)
|
||||
await self._db.commit()
|
||||
logger.info("Task %s failed: %s", task_id, error)
|
||||
|
||||
async def get_pending(self) -> list[dict]:
|
||||
cursor = await self._db.execute(
|
||||
"SELECT * FROM tasks WHERE status = 'pending' ORDER BY created_at ASC"
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
result = []
|
||||
for row in rows:
|
||||
task = dict(row)
|
||||
task["payload"] = json.loads(task["payload"])
|
||||
result.append(task)
|
||||
return result
|
||||
|
||||
async def has_running_task(self, thread_id: str) -> bool:
|
||||
cursor = await self._db.execute(
|
||||
"SELECT COUNT(*) as cnt FROM tasks WHERE thread_id = ? AND status = 'running'",
|
||||
(thread_id,),
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
return row["cnt"] > 0
|
||||
|
||||
|
||||
# 지연 초기화 싱글턴
|
||||
_queue: PersistentTaskQueue | None = None
|
||||
|
||||
|
||||
async def get_task_queue() -> PersistentTaskQueue:
|
||||
"""PersistentTaskQueue 싱글턴을 반환한다."""
|
||||
global _queue
|
||||
if _queue is None:
|
||||
db_path = os.environ.get("TASK_QUEUE_DB", "/data/task_queue.db")
|
||||
_queue = PersistentTaskQueue(db_path=db_path)
|
||||
await _queue.initialize()
|
||||
return _queue
|
||||
189
agent/webapp.py
189
agent/webapp.py
@ -1,13 +1,59 @@
|
||||
"""galaxis-agent webhook server."""
|
||||
import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, Request, HTTPException
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
from slowapi.middleware import SlowAPIMiddleware
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = FastAPI(title="galaxis-agent")
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""애플리케이션 시작/종료 시 리소스를 관리한다."""
|
||||
from agent.task_queue import get_task_queue
|
||||
from agent.message_store import get_message_store
|
||||
from agent.dispatcher import Dispatcher
|
||||
from agent.integrations.discord_handler import DiscordHandler
|
||||
|
||||
task_queue = await get_task_queue()
|
||||
message_store = await get_message_store()
|
||||
|
||||
dispatcher = Dispatcher(task_queue=task_queue)
|
||||
await dispatcher.start()
|
||||
app.state.dispatcher = dispatcher
|
||||
|
||||
discord_token = os.environ.get("DISCORD_TOKEN", "")
|
||||
discord_handler = None
|
||||
if discord_token:
|
||||
discord_handler = DiscordHandler()
|
||||
discord_task = asyncio.create_task(discord_handler.start(discord_token))
|
||||
app.state.discord_handler = discord_handler
|
||||
logger.info("Discord bot starting...")
|
||||
|
||||
yield
|
||||
|
||||
await dispatcher.stop()
|
||||
if discord_handler:
|
||||
await discord_handler.close()
|
||||
await task_queue.close()
|
||||
await message_store.close()
|
||||
logger.info("Application shutdown complete")
|
||||
|
||||
|
||||
app = FastAPI(title="galaxis-agent", lifespan=lifespan)
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
app.state.limiter = limiter
|
||||
app.add_middleware(SlowAPIMiddleware)
|
||||
|
||||
|
||||
def verify_gitea_signature(payload: bytes, signature: str, secret: str) -> bool:
|
||||
@ -22,21 +68,150 @@ def generate_thread_id(repo: str, issue_id: int) -> str:
|
||||
return f"{raw[:8]}-{raw[8:12]}-{raw[12:16]}-{raw[16:20]}-{raw[20:32]}"
|
||||
|
||||
|
||||
def parse_gitea_event(event_type: str, payload: dict) -> dict:
|
||||
"""Gitea webhook 페이로드를 파싱하여 처리 대상인지 판단한다."""
|
||||
repo = payload.get("repository", {})
|
||||
repo_name = repo.get("name", "")
|
||||
full_name = repo.get("full_name", "")
|
||||
repo_owner = full_name.split("/")[0] if "/" in full_name else ""
|
||||
|
||||
base = {
|
||||
"should_process": False,
|
||||
"issue_number": 0,
|
||||
"repo_name": repo_name,
|
||||
"repo_owner": repo_owner,
|
||||
"message": "",
|
||||
"event_type": event_type,
|
||||
"title": "",
|
||||
}
|
||||
|
||||
if event_type == "issue_comment":
|
||||
action = payload.get("action", "")
|
||||
if action != "created":
|
||||
return base
|
||||
comment_body = payload.get("comment", {}).get("body", "")
|
||||
issue = payload.get("issue", {})
|
||||
if "@agent" not in comment_body.lower():
|
||||
return base
|
||||
message = re.sub(r"@agent\b", "", comment_body, flags=re.IGNORECASE).strip()
|
||||
base.update({
|
||||
"should_process": True,
|
||||
"issue_number": issue.get("number", 0),
|
||||
"message": message,
|
||||
"title": issue.get("title", ""),
|
||||
})
|
||||
return base
|
||||
|
||||
if event_type == "issues":
|
||||
label = payload.get("label", {})
|
||||
if label.get("name") == "agent-fix":
|
||||
issue = payload.get("issue", {})
|
||||
base.update({
|
||||
"should_process": True,
|
||||
"issue_number": issue.get("number", 0),
|
||||
"message": issue.get("body", ""),
|
||||
"title": issue.get("title", ""),
|
||||
})
|
||||
return base
|
||||
|
||||
if event_type == "pull_request":
|
||||
action = payload.get("action", "")
|
||||
if action == "review_requested":
|
||||
pr = payload.get("pull_request", {})
|
||||
base.update({
|
||||
"should_process": True,
|
||||
"issue_number": pr.get("number", 0),
|
||||
"message": pr.get("body", ""),
|
||||
"title": pr.get("title", ""),
|
||||
})
|
||||
return base
|
||||
|
||||
return base
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@app.get("/health/gitea")
|
||||
async def health_gitea():
|
||||
"""Gitea 연결 상태를 확인한다."""
|
||||
try:
|
||||
from agent.utils.gitea_client import get_gitea_client
|
||||
client = get_gitea_client()
|
||||
resp = await client._client.get("/settings/api")
|
||||
return {"status": "ok", "gitea_status_code": resp.status_code}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
|
||||
@app.get("/health/discord")
|
||||
async def health_discord(request: Request):
|
||||
"""Discord 봇 연결 상태를 확인한다."""
|
||||
handler = getattr(request.app.state, "discord_handler", None)
|
||||
if not handler:
|
||||
return {"status": "not_configured"}
|
||||
bot = handler.bot
|
||||
if bot.is_ready():
|
||||
return {"status": "ok", "user": str(bot.user)}
|
||||
return {"status": "connecting"}
|
||||
|
||||
|
||||
@app.get("/health/queue")
|
||||
async def health_queue():
|
||||
"""작업 큐 상태를 반환한다."""
|
||||
from agent.task_queue import get_task_queue
|
||||
task_queue = await get_task_queue()
|
||||
pending = await task_queue.get_pending()
|
||||
return {
|
||||
"status": "ok",
|
||||
"pending_tasks": len(pending),
|
||||
}
|
||||
|
||||
|
||||
@app.post("/webhooks/gitea")
|
||||
@limiter.limit("10/minute")
|
||||
async def gitea_webhook(request: Request):
|
||||
"""Gitea webhook endpoint. Full implementation in Phase 3."""
|
||||
import os
|
||||
body = await request.body()
|
||||
"""Gitea webhook endpoint with event parsing and task dispatch."""
|
||||
payload_bytes = await request.body()
|
||||
signature = request.headers.get("X-Gitea-Signature", "")
|
||||
secret = os.environ.get("GITEA_WEBHOOK_SECRET", "")
|
||||
|
||||
if not verify_gitea_signature(body, signature, secret):
|
||||
if not verify_gitea_signature(payload_bytes, signature, secret):
|
||||
raise HTTPException(status_code=401, detail="Invalid signature")
|
||||
|
||||
logger.info("Gitea webhook received (not yet implemented)")
|
||||
return {"status": "received"}
|
||||
payload = json.loads(payload_bytes)
|
||||
event_type = request.headers.get("X-Gitea-Event", "")
|
||||
|
||||
event = parse_gitea_event(event_type, payload)
|
||||
if not event["should_process"]:
|
||||
return {"status": "ignored"}
|
||||
|
||||
thread_id = generate_thread_id(event["repo_name"], event["issue_number"])
|
||||
|
||||
from agent.message_store import get_message_store
|
||||
from agent.task_queue import get_task_queue
|
||||
task_queue = await get_task_queue()
|
||||
|
||||
if await task_queue.has_running_task(thread_id):
|
||||
store = await get_message_store()
|
||||
await store.push_message(thread_id, {
|
||||
"role": "human",
|
||||
"content": event["message"],
|
||||
})
|
||||
return {"status": "queued_message", "thread_id": thread_id}
|
||||
|
||||
task_id = await task_queue.enqueue(
|
||||
thread_id=thread_id,
|
||||
source="gitea",
|
||||
payload={
|
||||
"issue_number": event["issue_number"],
|
||||
"repo_owner": event["repo_owner"],
|
||||
"repo_name": event["repo_name"],
|
||||
"message": event["message"],
|
||||
"title": event["title"],
|
||||
"event_type": event["event_type"],
|
||||
},
|
||||
)
|
||||
return {"status": "enqueued", "task_id": task_id, "thread_id": thread_id}
|
||||
|
||||
@ -21,6 +21,7 @@ dependencies = [
|
||||
"pydantic-settings>=2.0.0",
|
||||
"slowapi>=0.1.9",
|
||||
"discord.py>=2.3.0",
|
||||
"aiosqlite>=0.20.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
43
tests/test_discord_handler.py
Normal file
43
tests/test_discord_handler.py
Normal file
@ -0,0 +1,43 @@
|
||||
# tests/test_discord_handler.py
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
|
||||
|
||||
def test_parse_discord_mention_with_issue():
|
||||
"""이슈 번호가 포함된 Discord 멘션을 파싱한다."""
|
||||
from agent.integrations.discord_handler import parse_discord_message
|
||||
|
||||
result = parse_discord_message("이슈 #42 해결해줘", bot_user_id=123)
|
||||
assert result["issue_number"] == 42
|
||||
assert result["repo_name"] == "galaxis-po"
|
||||
assert "해결해줘" in result["message"]
|
||||
|
||||
|
||||
def test_parse_discord_mention_with_repo():
|
||||
"""리포가 명시된 Discord 멘션을 파싱한다."""
|
||||
from agent.integrations.discord_handler import parse_discord_message
|
||||
|
||||
result = parse_discord_message("galaxis-po 이슈 #10 수정해줘", bot_user_id=123)
|
||||
assert result["issue_number"] == 10
|
||||
assert result["repo_name"] == "galaxis-po"
|
||||
|
||||
|
||||
def test_parse_discord_mention_freeform():
|
||||
"""이슈 번호 없는 자유형 요청을 파싱한다."""
|
||||
from agent.integrations.discord_handler import parse_discord_message
|
||||
|
||||
result = parse_discord_message("factor_calculator에 듀얼 모멘텀 추가해줘", bot_user_id=123)
|
||||
assert result["issue_number"] == 0
|
||||
assert result["message"] == "factor_calculator에 듀얼 모멘텀 추가해줘"
|
||||
|
||||
|
||||
def test_generate_discord_thread_id():
|
||||
"""Discord 메시지에서 결정론적 thread_id를 생성한다."""
|
||||
from agent.integrations.discord_handler import generate_discord_thread_id
|
||||
|
||||
tid1 = generate_discord_thread_id(channel_id=111, message_id=222)
|
||||
tid2 = generate_discord_thread_id(channel_id=111, message_id=222)
|
||||
tid3 = generate_discord_thread_id(channel_id=111, message_id=333)
|
||||
assert tid1 == tid2
|
||||
assert tid1 != tid3
|
||||
assert len(tid1) == 36
|
||||
87
tests/test_dispatcher.py
Normal file
87
tests/test_dispatcher.py
Normal file
@ -0,0 +1,87 @@
|
||||
"""Tests for the agent dispatcher."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from agent.dispatcher import Dispatcher
|
||||
from agent.task_queue import PersistentTaskQueue
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def task_queue():
|
||||
"""Create a temporary task queue for testing."""
|
||||
fd, db_path = tempfile.mkstemp(suffix=".db")
|
||||
os.close(fd)
|
||||
queue = PersistentTaskQueue(db_path=db_path)
|
||||
await queue.initialize()
|
||||
yield queue
|
||||
await queue.close()
|
||||
os.unlink(db_path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatcher_processes_task(task_queue):
|
||||
"""Dispatcher가 큐에서 작업을 꺼내 처리한다."""
|
||||
await task_queue.enqueue(
|
||||
"thread-1",
|
||||
"gitea",
|
||||
{
|
||||
"issue_number": 42,
|
||||
"repo_owner": "quant",
|
||||
"repo_name": "galaxis-po",
|
||||
"message": "Fix the bug",
|
||||
},
|
||||
)
|
||||
|
||||
mock_run_agent = AsyncMock(return_value={"pr_url": "http://..."})
|
||||
|
||||
dispatcher = Dispatcher(task_queue=task_queue)
|
||||
dispatcher._run_agent_for_task = mock_run_agent
|
||||
|
||||
await dispatcher._poll_once()
|
||||
|
||||
mock_run_agent.assert_called_once()
|
||||
pending = await task_queue.get_pending()
|
||||
assert len(pending) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatcher_skips_when_empty(task_queue):
|
||||
"""큐가 비어있으면 아무 작업도 하지 않는다."""
|
||||
mock_run_agent = AsyncMock()
|
||||
|
||||
dispatcher = Dispatcher(task_queue=task_queue)
|
||||
dispatcher._run_agent_for_task = mock_run_agent
|
||||
|
||||
await dispatcher._poll_once()
|
||||
mock_run_agent.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatcher_handles_failure(task_queue):
|
||||
"""에이전트 실행 실패 시 작업을 failed로 표시한다."""
|
||||
await task_queue.enqueue(
|
||||
"thread-1",
|
||||
"gitea",
|
||||
{
|
||||
"issue_number": 42,
|
||||
"repo_owner": "quant",
|
||||
"repo_name": "galaxis-po",
|
||||
"message": "Fix",
|
||||
},
|
||||
)
|
||||
|
||||
mock_run_agent = AsyncMock(side_effect=Exception("Agent crashed"))
|
||||
|
||||
dispatcher = Dispatcher(task_queue=task_queue)
|
||||
dispatcher._run_agent_for_task = mock_run_agent
|
||||
|
||||
await dispatcher._poll_once()
|
||||
|
||||
# 실패 후 다음 작업 dequeue 가능해야 함
|
||||
await task_queue.enqueue("thread-2", "gitea", {"message": "Next"})
|
||||
task = await task_queue.dequeue()
|
||||
assert task is not None
|
||||
94
tests/test_gitea_webhook.py
Normal file
94
tests/test_gitea_webhook.py
Normal file
@ -0,0 +1,94 @@
|
||||
"""Tests for Gitea webhook event parsing and signature verification."""
|
||||
import hashlib
|
||||
import hmac
|
||||
|
||||
|
||||
def make_signature(payload: bytes, secret: str) -> str:
|
||||
"""Gitea HMAC-SHA256 서명을 생성한다."""
|
||||
return hmac.new(secret.encode(), payload, hashlib.sha256).hexdigest()
|
||||
|
||||
|
||||
def test_verify_signature():
|
||||
"""Gitea webhook 서명을 검증한다."""
|
||||
from agent.webapp import verify_gitea_signature
|
||||
|
||||
payload = b'{"action": "created"}'
|
||||
secret = "test-secret"
|
||||
sig = make_signature(payload, secret)
|
||||
assert verify_gitea_signature(payload, sig, secret) is True
|
||||
assert verify_gitea_signature(payload, "wrong", secret) is False
|
||||
|
||||
|
||||
def test_generate_thread_id():
|
||||
"""결정론적 스레드 ID를 생성한다."""
|
||||
from agent.webapp import generate_thread_id
|
||||
|
||||
tid1 = generate_thread_id("galaxis-po", 42)
|
||||
tid2 = generate_thread_id("galaxis-po", 42)
|
||||
tid3 = generate_thread_id("galaxis-po", 43)
|
||||
assert tid1 == tid2
|
||||
assert tid1 != tid3
|
||||
assert len(tid1) == 36
|
||||
assert tid1.count("-") == 4
|
||||
|
||||
|
||||
def test_parse_issue_comment_with_mention():
|
||||
"""이슈 코멘트에서 @agent 멘션을 감지한다."""
|
||||
from agent.webapp import parse_gitea_event
|
||||
|
||||
payload = {
|
||||
"action": "created",
|
||||
"comment": {"body": "@agent factor_calculator에 듀얼 모멘텀 추가해줘"},
|
||||
"issue": {"number": 42, "title": "Feature request", "body": "description"},
|
||||
"repository": {"full_name": "quant/galaxis-po", "name": "galaxis-po"},
|
||||
}
|
||||
result = parse_gitea_event("issue_comment", payload)
|
||||
assert result is not None
|
||||
assert result["should_process"] is True
|
||||
assert result["issue_number"] == 42
|
||||
assert result["repo_name"] == "galaxis-po"
|
||||
assert "@agent" not in result["message"]
|
||||
|
||||
|
||||
def test_parse_issue_comment_without_mention():
|
||||
"""@agent 멘션이 없는 코멘트는 무시한다."""
|
||||
from agent.webapp import parse_gitea_event
|
||||
|
||||
payload = {
|
||||
"action": "created",
|
||||
"comment": {"body": "일반 코멘트입니다"},
|
||||
"issue": {"number": 42, "title": "Bug", "body": "desc"},
|
||||
"repository": {"full_name": "quant/galaxis-po", "name": "galaxis-po"},
|
||||
}
|
||||
result = parse_gitea_event("issue_comment", payload)
|
||||
assert result["should_process"] is False
|
||||
|
||||
|
||||
def test_parse_issue_label_agent_fix():
|
||||
"""agent-fix 라벨 부착 시 작업을 트리거한다."""
|
||||
from agent.webapp import parse_gitea_event
|
||||
|
||||
payload = {
|
||||
"action": "label_updated",
|
||||
"issue": {"number": 10, "title": "Fix login", "body": "Login fails"},
|
||||
"label": {"name": "agent-fix"},
|
||||
"repository": {"full_name": "quant/galaxis-po", "name": "galaxis-po"},
|
||||
}
|
||||
result = parse_gitea_event("issues", payload)
|
||||
assert result is not None
|
||||
assert result["should_process"] is True
|
||||
|
||||
|
||||
def test_parse_pr_review_requested():
|
||||
"""PR 리뷰 요청을 감지한다."""
|
||||
from agent.webapp import parse_gitea_event
|
||||
|
||||
payload = {
|
||||
"action": "review_requested",
|
||||
"pull_request": {"number": 5, "title": "feat: add feature", "body": "desc"},
|
||||
"repository": {"full_name": "quant/galaxis-po", "name": "galaxis-po"},
|
||||
}
|
||||
result = parse_gitea_event("pull_request", payload)
|
||||
assert result is not None
|
||||
assert result["should_process"] is True
|
||||
assert result["issue_number"] == 5
|
||||
155
tests/test_health.py
Normal file
155
tests/test_health.py
Normal file
@ -0,0 +1,155 @@
|
||||
"""Health check 엔드포인트 테스트."""
|
||||
import pytest
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_lifespan(app):
|
||||
"""테스트용 mock lifespan."""
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_basic():
|
||||
"""기본 health check가 200을 반환한다."""
|
||||
from agent.webapp import app
|
||||
# Override lifespan for testing
|
||||
app.router.lifespan_context = mock_lifespan
|
||||
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.get("/health")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["status"] == "ok"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_queue():
|
||||
"""큐 health check가 pending 카운트를 반환한다."""
|
||||
from agent.webapp import app
|
||||
app.router.lifespan_context = mock_lifespan
|
||||
|
||||
mock_queue = MagicMock()
|
||||
mock_queue.get_pending = AsyncMock(return_value=[{"id": "1"}, {"id": "2"}])
|
||||
|
||||
with patch("agent.task_queue.get_task_queue", new_callable=AsyncMock, return_value=mock_queue):
|
||||
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.get("/health/queue")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["status"] == "ok"
|
||||
assert data["pending_tasks"] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_queue_empty():
|
||||
"""큐가 비어있을 때 0을 반환한다."""
|
||||
from agent.webapp import app
|
||||
app.router.lifespan_context = mock_lifespan
|
||||
|
||||
mock_queue = MagicMock()
|
||||
mock_queue.get_pending = AsyncMock(return_value=[])
|
||||
|
||||
with patch("agent.task_queue.get_task_queue", new_callable=AsyncMock, return_value=mock_queue):
|
||||
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.get("/health/queue")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["status"] == "ok"
|
||||
assert data["pending_tasks"] == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_discord_not_configured():
|
||||
"""Discord가 설정되지 않았을 때."""
|
||||
from agent.webapp import app
|
||||
app.router.lifespan_context = mock_lifespan
|
||||
# Ensure no discord_handler on app.state
|
||||
if hasattr(app.state, "discord_handler"):
|
||||
delattr(app.state, "discord_handler")
|
||||
|
||||
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.get("/health/discord")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["status"] == "not_configured"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_discord_ready():
|
||||
"""Discord 봇이 준비되었을 때."""
|
||||
from agent.webapp import app
|
||||
app.router.lifespan_context = mock_lifespan
|
||||
|
||||
mock_bot = MagicMock()
|
||||
mock_bot.is_ready.return_value = True
|
||||
mock_bot.user = "TestBot#1234"
|
||||
|
||||
mock_handler = MagicMock()
|
||||
mock_handler.bot = mock_bot
|
||||
app.state.discord_handler = mock_handler
|
||||
|
||||
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.get("/health/discord")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["status"] == "ok"
|
||||
assert data["user"] == "TestBot#1234"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_discord_connecting():
|
||||
"""Discord 봇이 연결 중일 때."""
|
||||
from agent.webapp import app
|
||||
app.router.lifespan_context = mock_lifespan
|
||||
|
||||
mock_bot = MagicMock()
|
||||
mock_bot.is_ready.return_value = False
|
||||
|
||||
mock_handler = MagicMock()
|
||||
mock_handler.bot = mock_bot
|
||||
app.state.discord_handler = mock_handler
|
||||
|
||||
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.get("/health/discord")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["status"] == "connecting"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_gitea_ok():
|
||||
"""Gitea API 연결이 성공할 때."""
|
||||
from agent.webapp import app
|
||||
app.router.lifespan_context = mock_lifespan
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client._client.get = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch("agent.utils.gitea_client.get_gitea_client", return_value=mock_client):
|
||||
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.get("/health/gitea")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["status"] == "ok"
|
||||
assert data["gitea_status_code"] == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_gitea_error():
|
||||
"""Gitea API 연결이 실패할 때."""
|
||||
from agent.webapp import app
|
||||
app.router.lifespan_context = mock_lifespan
|
||||
|
||||
with patch("agent.utils.gitea_client.get_gitea_client", side_effect=Exception("Connection failed")):
|
||||
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://test") as client:
|
||||
resp = await client.get("/health/gitea")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["status"] == "error"
|
||||
assert "Connection failed" in data["error"]
|
||||
67
tests/test_integration_webhook_flow.py
Normal file
67
tests/test_integration_webhook_flow.py
Normal file
@ -0,0 +1,67 @@
|
||||
"""Integration tests for webhook-to-dispatcher flow."""
|
||||
|
||||
import pytest
|
||||
import os
|
||||
import tempfile
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from agent.task_queue import PersistentTaskQueue
|
||||
from agent.dispatcher import Dispatcher
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def task_queue():
|
||||
"""Create a temporary task queue for testing."""
|
||||
fd, db_path = tempfile.mkstemp(suffix=".db")
|
||||
os.close(fd)
|
||||
queue = PersistentTaskQueue(db_path=db_path)
|
||||
await queue.initialize()
|
||||
yield queue
|
||||
await queue.close()
|
||||
os.unlink(db_path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_to_dispatcher_flow(task_queue):
|
||||
"""Gitea webhook → TaskQueue → Dispatcher 전체 흐름."""
|
||||
from agent.webapp import parse_gitea_event, generate_thread_id
|
||||
|
||||
# 1. Webhook 이벤트 파싱
|
||||
payload = {
|
||||
"action": "created",
|
||||
"comment": {"body": "@agent factor_calculator에 듀얼 모멘텀 추가해줘"},
|
||||
"issue": {"number": 42, "title": "Feature request", "body": "desc"},
|
||||
"repository": {"full_name": "quant/galaxis-po", "name": "galaxis-po"},
|
||||
}
|
||||
event = parse_gitea_event("issue_comment", payload)
|
||||
assert event["should_process"] is True
|
||||
|
||||
# 2. TaskQueue에 enqueue
|
||||
thread_id = generate_thread_id("galaxis-po", 42)
|
||||
task_id = await task_queue.enqueue(
|
||||
thread_id=thread_id,
|
||||
source="gitea",
|
||||
payload={
|
||||
"issue_number": event["issue_number"],
|
||||
"repo_owner": event["repo_owner"],
|
||||
"repo_name": event["repo_name"],
|
||||
"message": event["message"],
|
||||
},
|
||||
)
|
||||
|
||||
# 3. Dispatcher가 처리
|
||||
mock_run_agent = AsyncMock(return_value={"status": "completed"})
|
||||
dispatcher = Dispatcher(task_queue=task_queue)
|
||||
dispatcher._run_agent_for_task = mock_run_agent
|
||||
|
||||
await dispatcher._poll_once()
|
||||
|
||||
# 4. 에이전트가 호출되었는지 확인
|
||||
mock_run_agent.assert_called_once()
|
||||
call_task = mock_run_agent.call_args[0][0]
|
||||
assert call_task["thread_id"] == thread_id
|
||||
assert call_task["payload"]["issue_number"] == 42
|
||||
|
||||
# 5. 작업이 완료 처리되었는지 확인
|
||||
pending = await task_queue.get_pending()
|
||||
assert len(pending) == 0
|
||||
62
tests/test_message_store.py
Normal file
62
tests/test_message_store.py
Normal file
@ -0,0 +1,62 @@
|
||||
"""Tests for MessageStore."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from agent.message_store import MessageStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def store():
|
||||
"""Create a temporary MessageStore for testing."""
|
||||
fd, db_path = tempfile.mkstemp(suffix=".db")
|
||||
os.close(fd)
|
||||
s = MessageStore(db_path=db_path)
|
||||
await s.initialize()
|
||||
yield s
|
||||
await s.close()
|
||||
os.unlink(db_path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_push_and_get_messages(store):
|
||||
"""메시지를 push하고 get한다."""
|
||||
await store.push_message("thread-1", {"role": "human", "content": "추가 요청"})
|
||||
messages = await store.get_messages("thread-1")
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["content"] == "추가 요청"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consume_messages(store):
|
||||
"""메시지를 소비하면 삭제된다."""
|
||||
await store.push_message("thread-1", {"role": "human", "content": "msg1"})
|
||||
await store.push_message("thread-1", {"role": "human", "content": "msg2"})
|
||||
|
||||
messages = await store.consume_messages("thread-1")
|
||||
assert len(messages) == 2
|
||||
|
||||
remaining = await store.get_messages("thread-1")
|
||||
assert len(remaining) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_messages_returns_empty(store):
|
||||
"""메시지가 없으면 빈 리스트를 반환한다."""
|
||||
messages = await store.get_messages("thread-999")
|
||||
assert messages == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messages_isolated_by_thread(store):
|
||||
"""스레드별로 메시지가 격리된다."""
|
||||
await store.push_message("thread-1", {"content": "for thread 1"})
|
||||
await store.push_message("thread-2", {"content": "for thread 2"})
|
||||
|
||||
msgs1 = await store.get_messages("thread-1")
|
||||
msgs2 = await store.get_messages("thread-2")
|
||||
assert len(msgs1) == 1
|
||||
assert len(msgs2) == 1
|
||||
assert msgs1[0]["content"] == "for thread 1"
|
||||
117
tests/test_task_queue.py
Normal file
117
tests/test_task_queue.py
Normal file
@ -0,0 +1,117 @@
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from agent.task_queue import PersistentTaskQueue
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def task_queue():
|
||||
"""임시 SQLite DB로 TaskQueue를 생성한다."""
|
||||
fd, db_path = tempfile.mkstemp(suffix=".db")
|
||||
os.close(fd)
|
||||
queue = PersistentTaskQueue(db_path=db_path)
|
||||
await queue.initialize()
|
||||
yield queue
|
||||
await queue.close()
|
||||
os.unlink(db_path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_and_dequeue(task_queue):
|
||||
"""작업을 enqueue하고 dequeue한다."""
|
||||
task_id = await task_queue.enqueue(
|
||||
thread_id="thread-1",
|
||||
source="gitea",
|
||||
payload={"issue_number": 42, "body": "Fix bug"},
|
||||
)
|
||||
assert task_id is not None
|
||||
|
||||
task = await task_queue.dequeue()
|
||||
assert task is not None
|
||||
assert task["thread_id"] == "thread-1"
|
||||
assert task["status"] == "running"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dequeue_empty_returns_none(task_queue):
|
||||
"""큐가 비어있으면 None을 반환한다."""
|
||||
task = await task_queue.dequeue()
|
||||
assert task is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fifo_order(task_queue):
|
||||
"""FIFO 순서로 dequeue한다."""
|
||||
await task_queue.enqueue("thread-1", "gitea", {"order": 1})
|
||||
await task_queue.enqueue("thread-2", "discord", {"order": 2})
|
||||
|
||||
task1 = await task_queue.dequeue()
|
||||
assert task1["payload"]["order"] == 1
|
||||
|
||||
# Complete first task before dequeuing second
|
||||
await task_queue.mark_completed(task1["id"])
|
||||
|
||||
task2 = await task_queue.dequeue()
|
||||
assert task2["payload"]["order"] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrency_limit(task_queue):
|
||||
"""running 작업이 있으면 dequeue하지 않는다."""
|
||||
await task_queue.enqueue("thread-1", "gitea", {"msg": "first"})
|
||||
await task_queue.enqueue("thread-2", "gitea", {"msg": "second"})
|
||||
|
||||
task1 = await task_queue.dequeue()
|
||||
assert task1 is not None
|
||||
|
||||
task2 = await task_queue.dequeue()
|
||||
assert task2 is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mark_completed(task_queue):
|
||||
"""작업을 completed로 표시한다."""
|
||||
await task_queue.enqueue("thread-1", "gitea", {})
|
||||
task = await task_queue.dequeue()
|
||||
assert task is not None
|
||||
|
||||
await task_queue.mark_completed(task["id"], result={"pr_url": "http://..."})
|
||||
|
||||
await task_queue.enqueue("thread-2", "gitea", {})
|
||||
task2 = await task_queue.dequeue()
|
||||
assert task2 is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mark_failed(task_queue):
|
||||
"""작업을 failed로 표시한다."""
|
||||
await task_queue.enqueue("thread-1", "gitea", {})
|
||||
task = await task_queue.dequeue()
|
||||
|
||||
await task_queue.mark_failed(task["id"], error="Something broke")
|
||||
|
||||
await task_queue.enqueue("thread-2", "gitea", {})
|
||||
task2 = await task_queue.dequeue()
|
||||
assert task2 is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_pending(task_queue):
|
||||
"""미처리 작업 목록을 반환한다."""
|
||||
await task_queue.enqueue("thread-1", "gitea", {})
|
||||
await task_queue.enqueue("thread-2", "discord", {})
|
||||
|
||||
pending = await task_queue.get_pending()
|
||||
assert len(pending) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_running_task_for_thread(task_queue):
|
||||
"""특정 스레드에 실행 중인 작업이 있는지 확인한다."""
|
||||
await task_queue.enqueue("thread-1", "gitea", {})
|
||||
task = await task_queue.dequeue() # → running
|
||||
|
||||
assert await task_queue.has_running_task("thread-1") is True
|
||||
assert await task_queue.has_running_task("thread-2") is False
|
||||
11
uv.lock
generated
11
uv.lock
generated
@ -113,6 +113,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aiosqlite"
|
||||
version = "0.22.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/4e/8a/64761f4005f17809769d23e518d915db74e6310474e733e3593cfc854ef1/aiosqlite-0.22.1.tar.gz", hash = "sha256:043e0bd78d32888c0a9ca90fc788b38796843360c855a7262a532813133a0650", size = 14821, upload-time = "2025-12-23T19:25:43.997Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/00/b7/e3bf5133d697a08128598c8d0abc5e16377b51465a33756de24fa7dee953/aiosqlite-0.22.1-py3-none-any.whl", hash = "sha256:21c002eb13823fad740196c5a2e9d8e62f6243bd9e7e4a1f87fb5e44ecb4fceb", size = 17405, upload-time = "2025-12-23T19:25:42.139Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "annotated-doc"
|
||||
version = "0.0.4"
|
||||
@ -694,6 +703,7 @@ name = "galaxis-agent"
|
||||
version = "0.1.0"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "aiosqlite" },
|
||||
{ name = "cryptography" },
|
||||
{ name = "deepagents" },
|
||||
{ name = "discord-py" },
|
||||
@ -720,6 +730,7 @@ dev = [
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "aiosqlite", specifier = ">=0.20.0" },
|
||||
{ name = "cryptography", specifier = ">=41.0.0" },
|
||||
{ name = "deepagents", specifier = ">=0.4.3" },
|
||||
{ name = "discord-py", specifier = ">=2.3.0" },
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user