From 0136823462ca112e1e4f631b33cfa3fd234b8626 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EB=A8=B8=EB=8B=88=ED=8E=98=EB=8B=88?= Date: Fri, 20 Mar 2026 18:04:57 +0900 Subject: [PATCH] feat: add PersistentTaskQueue with SQLite backend MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements Task 1 of Phase 3: SQLite-based persistent task queue with: - FIFO ordering (created_at ASC) - Concurrency limit (default 1 running task) - State machine: pending → running → completed|failed|timeout - Methods: enqueue, dequeue, mark_completed, mark_failed, get_pending, has_running_task - Thread-aware task tracking - Singleton pattern with lazy initialization All 8 tests passing. --- agent/task_queue.py | 148 +++++++++++++++++++++++++++++++++++++++ pyproject.toml | 1 + tests/test_task_queue.py | 117 +++++++++++++++++++++++++++++++ uv.lock | 11 +++ 4 files changed, 277 insertions(+) create mode 100644 agent/task_queue.py create mode 100644 tests/test_task_queue.py diff --git a/agent/task_queue.py b/agent/task_queue.py new file mode 100644 index 0000000..59c8446 --- /dev/null +++ b/agent/task_queue.py @@ -0,0 +1,148 @@ +"""SQLite 기반 영속 작업 큐. + +동시 작업 수를 제한하고, 서버 재시작 시에도 작업이 유실되지 않도록 한다. +""" + +from __future__ import annotations + +import json +import logging +import os +import uuid +from datetime import datetime, timezone + +import aiosqlite + +logger = logging.getLogger(__name__) + +_CREATE_TABLE = """ +CREATE TABLE IF NOT EXISTS tasks ( + id TEXT PRIMARY KEY, + thread_id TEXT NOT NULL, + source TEXT NOT NULL, + payload TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'pending', + created_at TEXT NOT NULL, + started_at TEXT, + completed_at TEXT, + result TEXT +) +""" + + +class PersistentTaskQueue: + """SQLite 기반 영속 작업 큐.""" + + def __init__(self, db_path: str = "/data/task_queue.db", max_concurrent: int = 1): + self._db_path = db_path + self._max_concurrent = max_concurrent + self._db: aiosqlite.Connection | None = None + + async def initialize(self) -> None: + """DB 연결 및 테이블 생성.""" + self._db = await aiosqlite.connect(self._db_path) + self._db.row_factory = aiosqlite.Row + await self._db.execute(_CREATE_TABLE) + await self._db.commit() + + async def close(self) -> None: + """DB 연결 종료.""" + if self._db: + await self._db.close() + + async def enqueue( + self, + thread_id: str, + source: str, + payload: dict, + ) -> str: + task_id = str(uuid.uuid4()) + now = datetime.now(timezone.utc).isoformat() + await self._db.execute( + "INSERT INTO tasks (id, thread_id, source, payload, status, created_at) " + "VALUES (?, ?, ?, ?, 'pending', ?)", + (task_id, thread_id, source, json.dumps(payload), now), + ) + await self._db.commit() + logger.info("Enqueued task %s for thread %s from %s", task_id, thread_id, source) + return task_id + + async def dequeue(self) -> dict | None: + cursor = await self._db.execute( + "SELECT COUNT(*) as cnt FROM tasks WHERE status = 'running'" + ) + row = await cursor.fetchone() + if row["cnt"] >= self._max_concurrent: + return None + + cursor = await self._db.execute( + "SELECT * FROM tasks WHERE status = 'pending' ORDER BY created_at ASC LIMIT 1" + ) + row = await cursor.fetchone() + if not row: + return None + + now = datetime.now(timezone.utc).isoformat() + await self._db.execute( + "UPDATE tasks SET status = 'running', started_at = ? WHERE id = ?", + (now, row["id"]), + ) + await self._db.commit() + + task = dict(row) + task["payload"] = json.loads(task["payload"]) + task["status"] = "running" + logger.info("Dequeued task %s (thread %s)", task["id"], task["thread_id"]) + return task + + async def mark_completed(self, task_id: str, result: dict | None = None) -> None: + now = datetime.now(timezone.utc).isoformat() + await self._db.execute( + "UPDATE tasks SET status = 'completed', completed_at = ?, result = ? WHERE id = ?", + (now, json.dumps(result or {}), task_id), + ) + await self._db.commit() + logger.info("Task %s completed", task_id) + + async def mark_failed(self, task_id: str, error: str = "") -> None: + now = datetime.now(timezone.utc).isoformat() + await self._db.execute( + "UPDATE tasks SET status = 'failed', completed_at = ?, result = ? WHERE id = ?", + (now, json.dumps({"error": error}), task_id), + ) + await self._db.commit() + logger.info("Task %s failed: %s", task_id, error) + + async def get_pending(self) -> list[dict]: + cursor = await self._db.execute( + "SELECT * FROM tasks WHERE status = 'pending' ORDER BY created_at ASC" + ) + rows = await cursor.fetchall() + result = [] + for row in rows: + task = dict(row) + task["payload"] = json.loads(task["payload"]) + result.append(task) + return result + + async def has_running_task(self, thread_id: str) -> bool: + cursor = await self._db.execute( + "SELECT COUNT(*) as cnt FROM tasks WHERE thread_id = ? AND status = 'running'", + (thread_id,), + ) + row = await cursor.fetchone() + return row["cnt"] > 0 + + +# 지연 초기화 싱글턴 +_queue: PersistentTaskQueue | None = None + + +async def get_task_queue() -> PersistentTaskQueue: + """PersistentTaskQueue 싱글턴을 반환한다.""" + global _queue + if _queue is None: + db_path = os.environ.get("TASK_QUEUE_DB", "/data/task_queue.db") + _queue = PersistentTaskQueue(db_path=db_path) + await _queue.initialize() + return _queue diff --git a/pyproject.toml b/pyproject.toml index b2f287e..3143690 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "pydantic-settings>=2.0.0", "slowapi>=0.1.9", "discord.py>=2.3.0", + "aiosqlite>=0.20.0", ] [project.optional-dependencies] diff --git a/tests/test_task_queue.py b/tests/test_task_queue.py new file mode 100644 index 0000000..b956b14 --- /dev/null +++ b/tests/test_task_queue.py @@ -0,0 +1,117 @@ +import os +import tempfile + +import pytest + +from agent.task_queue import PersistentTaskQueue + + +@pytest.fixture +async def task_queue(): + """임시 SQLite DB로 TaskQueue를 생성한다.""" + fd, db_path = tempfile.mkstemp(suffix=".db") + os.close(fd) + queue = PersistentTaskQueue(db_path=db_path) + await queue.initialize() + yield queue + await queue.close() + os.unlink(db_path) + + +@pytest.mark.asyncio +async def test_enqueue_and_dequeue(task_queue): + """작업을 enqueue하고 dequeue한다.""" + task_id = await task_queue.enqueue( + thread_id="thread-1", + source="gitea", + payload={"issue_number": 42, "body": "Fix bug"}, + ) + assert task_id is not None + + task = await task_queue.dequeue() + assert task is not None + assert task["thread_id"] == "thread-1" + assert task["status"] == "running" + + +@pytest.mark.asyncio +async def test_dequeue_empty_returns_none(task_queue): + """큐가 비어있으면 None을 반환한다.""" + task = await task_queue.dequeue() + assert task is None + + +@pytest.mark.asyncio +async def test_fifo_order(task_queue): + """FIFO 순서로 dequeue한다.""" + await task_queue.enqueue("thread-1", "gitea", {"order": 1}) + await task_queue.enqueue("thread-2", "discord", {"order": 2}) + + task1 = await task_queue.dequeue() + assert task1["payload"]["order"] == 1 + + # Complete first task before dequeuing second + await task_queue.mark_completed(task1["id"]) + + task2 = await task_queue.dequeue() + assert task2["payload"]["order"] == 2 + + +@pytest.mark.asyncio +async def test_concurrency_limit(task_queue): + """running 작업이 있으면 dequeue하지 않는다.""" + await task_queue.enqueue("thread-1", "gitea", {"msg": "first"}) + await task_queue.enqueue("thread-2", "gitea", {"msg": "second"}) + + task1 = await task_queue.dequeue() + assert task1 is not None + + task2 = await task_queue.dequeue() + assert task2 is None + + +@pytest.mark.asyncio +async def test_mark_completed(task_queue): + """작업을 completed로 표시한다.""" + await task_queue.enqueue("thread-1", "gitea", {}) + task = await task_queue.dequeue() + assert task is not None + + await task_queue.mark_completed(task["id"], result={"pr_url": "http://..."}) + + await task_queue.enqueue("thread-2", "gitea", {}) + task2 = await task_queue.dequeue() + assert task2 is not None + + +@pytest.mark.asyncio +async def test_mark_failed(task_queue): + """작업을 failed로 표시한다.""" + await task_queue.enqueue("thread-1", "gitea", {}) + task = await task_queue.dequeue() + + await task_queue.mark_failed(task["id"], error="Something broke") + + await task_queue.enqueue("thread-2", "gitea", {}) + task2 = await task_queue.dequeue() + assert task2 is not None + + +@pytest.mark.asyncio +async def test_get_pending(task_queue): + """미처리 작업 목록을 반환한다.""" + await task_queue.enqueue("thread-1", "gitea", {}) + await task_queue.enqueue("thread-2", "discord", {}) + + pending = await task_queue.get_pending() + assert len(pending) == 2 + + +@pytest.mark.asyncio +async def test_has_running_task_for_thread(task_queue): + """특정 스레드에 실행 중인 작업이 있는지 확인한다.""" + await task_queue.enqueue("thread-1", "gitea", {}) + task = await task_queue.dequeue() # → running + + assert await task_queue.has_running_task("thread-1") is True + assert await task_queue.has_running_task("thread-2") is False diff --git a/uv.lock b/uv.lock index 0ccd14a..dd3b0c6 100644 --- a/uv.lock +++ b/uv.lock @@ -113,6 +113,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" }, ] +[[package]] +name = "aiosqlite" +version = "0.22.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/4e/8a/64761f4005f17809769d23e518d915db74e6310474e733e3593cfc854ef1/aiosqlite-0.22.1.tar.gz", hash = "sha256:043e0bd78d32888c0a9ca90fc788b38796843360c855a7262a532813133a0650", size = 14821, upload-time = "2025-12-23T19:25:43.997Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/00/b7/e3bf5133d697a08128598c8d0abc5e16377b51465a33756de24fa7dee953/aiosqlite-0.22.1-py3-none-any.whl", hash = "sha256:21c002eb13823fad740196c5a2e9d8e62f6243bd9e7e4a1f87fb5e44ecb4fceb", size = 17405, upload-time = "2025-12-23T19:25:42.139Z" }, +] + [[package]] name = "annotated-doc" version = "0.0.4" @@ -694,6 +703,7 @@ name = "galaxis-agent" version = "0.1.0" source = { editable = "." } dependencies = [ + { name = "aiosqlite" }, { name = "cryptography" }, { name = "deepagents" }, { name = "discord-py" }, @@ -720,6 +730,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "aiosqlite", specifier = ">=0.20.0" }, { name = "cryptography", specifier = ">=41.0.0" }, { name = "deepagents", specifier = ">=0.4.3" }, { name = "discord-py", specifier = ">=2.3.0" },