feat: add PersistentTaskQueue with SQLite backend
Implements Task 1 of Phase 3: SQLite-based persistent task queue with: - FIFO ordering (created_at ASC) - Concurrency limit (default 1 running task) - State machine: pending → running → completed|failed|timeout - Methods: enqueue, dequeue, mark_completed, mark_failed, get_pending, has_running_task - Thread-aware task tracking - Singleton pattern with lazy initialization All 8 tests passing.
This commit is contained in:
parent
60aeebf7a7
commit
0136823462
148
agent/task_queue.py
Normal file
148
agent/task_queue.py
Normal file
@ -0,0 +1,148 @@
|
||||
"""SQLite 기반 영속 작업 큐.
|
||||
|
||||
동시 작업 수를 제한하고, 서버 재시작 시에도 작업이 유실되지 않도록 한다.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import aiosqlite
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_CREATE_TABLE = """
|
||||
CREATE TABLE IF NOT EXISTS tasks (
|
||||
id TEXT PRIMARY KEY,
|
||||
thread_id TEXT NOT NULL,
|
||||
source TEXT NOT NULL,
|
||||
payload TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
created_at TEXT NOT NULL,
|
||||
started_at TEXT,
|
||||
completed_at TEXT,
|
||||
result TEXT
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
class PersistentTaskQueue:
|
||||
"""SQLite 기반 영속 작업 큐."""
|
||||
|
||||
def __init__(self, db_path: str = "/data/task_queue.db", max_concurrent: int = 1):
|
||||
self._db_path = db_path
|
||||
self._max_concurrent = max_concurrent
|
||||
self._db: aiosqlite.Connection | None = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""DB 연결 및 테이블 생성."""
|
||||
self._db = await aiosqlite.connect(self._db_path)
|
||||
self._db.row_factory = aiosqlite.Row
|
||||
await self._db.execute(_CREATE_TABLE)
|
||||
await self._db.commit()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""DB 연결 종료."""
|
||||
if self._db:
|
||||
await self._db.close()
|
||||
|
||||
async def enqueue(
|
||||
self,
|
||||
thread_id: str,
|
||||
source: str,
|
||||
payload: dict,
|
||||
) -> str:
|
||||
task_id = str(uuid.uuid4())
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
await self._db.execute(
|
||||
"INSERT INTO tasks (id, thread_id, source, payload, status, created_at) "
|
||||
"VALUES (?, ?, ?, ?, 'pending', ?)",
|
||||
(task_id, thread_id, source, json.dumps(payload), now),
|
||||
)
|
||||
await self._db.commit()
|
||||
logger.info("Enqueued task %s for thread %s from %s", task_id, thread_id, source)
|
||||
return task_id
|
||||
|
||||
async def dequeue(self) -> dict | None:
|
||||
cursor = await self._db.execute(
|
||||
"SELECT COUNT(*) as cnt FROM tasks WHERE status = 'running'"
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
if row["cnt"] >= self._max_concurrent:
|
||||
return None
|
||||
|
||||
cursor = await self._db.execute(
|
||||
"SELECT * FROM tasks WHERE status = 'pending' ORDER BY created_at ASC LIMIT 1"
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
await self._db.execute(
|
||||
"UPDATE tasks SET status = 'running', started_at = ? WHERE id = ?",
|
||||
(now, row["id"]),
|
||||
)
|
||||
await self._db.commit()
|
||||
|
||||
task = dict(row)
|
||||
task["payload"] = json.loads(task["payload"])
|
||||
task["status"] = "running"
|
||||
logger.info("Dequeued task %s (thread %s)", task["id"], task["thread_id"])
|
||||
return task
|
||||
|
||||
async def mark_completed(self, task_id: str, result: dict | None = None) -> None:
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
await self._db.execute(
|
||||
"UPDATE tasks SET status = 'completed', completed_at = ?, result = ? WHERE id = ?",
|
||||
(now, json.dumps(result or {}), task_id),
|
||||
)
|
||||
await self._db.commit()
|
||||
logger.info("Task %s completed", task_id)
|
||||
|
||||
async def mark_failed(self, task_id: str, error: str = "") -> None:
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
await self._db.execute(
|
||||
"UPDATE tasks SET status = 'failed', completed_at = ?, result = ? WHERE id = ?",
|
||||
(now, json.dumps({"error": error}), task_id),
|
||||
)
|
||||
await self._db.commit()
|
||||
logger.info("Task %s failed: %s", task_id, error)
|
||||
|
||||
async def get_pending(self) -> list[dict]:
|
||||
cursor = await self._db.execute(
|
||||
"SELECT * FROM tasks WHERE status = 'pending' ORDER BY created_at ASC"
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
result = []
|
||||
for row in rows:
|
||||
task = dict(row)
|
||||
task["payload"] = json.loads(task["payload"])
|
||||
result.append(task)
|
||||
return result
|
||||
|
||||
async def has_running_task(self, thread_id: str) -> bool:
|
||||
cursor = await self._db.execute(
|
||||
"SELECT COUNT(*) as cnt FROM tasks WHERE thread_id = ? AND status = 'running'",
|
||||
(thread_id,),
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
return row["cnt"] > 0
|
||||
|
||||
|
||||
# 지연 초기화 싱글턴
|
||||
_queue: PersistentTaskQueue | None = None
|
||||
|
||||
|
||||
async def get_task_queue() -> PersistentTaskQueue:
|
||||
"""PersistentTaskQueue 싱글턴을 반환한다."""
|
||||
global _queue
|
||||
if _queue is None:
|
||||
db_path = os.environ.get("TASK_QUEUE_DB", "/data/task_queue.db")
|
||||
_queue = PersistentTaskQueue(db_path=db_path)
|
||||
await _queue.initialize()
|
||||
return _queue
|
||||
@ -21,6 +21,7 @@ dependencies = [
|
||||
"pydantic-settings>=2.0.0",
|
||||
"slowapi>=0.1.9",
|
||||
"discord.py>=2.3.0",
|
||||
"aiosqlite>=0.20.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
117
tests/test_task_queue.py
Normal file
117
tests/test_task_queue.py
Normal file
@ -0,0 +1,117 @@
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from agent.task_queue import PersistentTaskQueue
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def task_queue():
|
||||
"""임시 SQLite DB로 TaskQueue를 생성한다."""
|
||||
fd, db_path = tempfile.mkstemp(suffix=".db")
|
||||
os.close(fd)
|
||||
queue = PersistentTaskQueue(db_path=db_path)
|
||||
await queue.initialize()
|
||||
yield queue
|
||||
await queue.close()
|
||||
os.unlink(db_path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_and_dequeue(task_queue):
|
||||
"""작업을 enqueue하고 dequeue한다."""
|
||||
task_id = await task_queue.enqueue(
|
||||
thread_id="thread-1",
|
||||
source="gitea",
|
||||
payload={"issue_number": 42, "body": "Fix bug"},
|
||||
)
|
||||
assert task_id is not None
|
||||
|
||||
task = await task_queue.dequeue()
|
||||
assert task is not None
|
||||
assert task["thread_id"] == "thread-1"
|
||||
assert task["status"] == "running"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dequeue_empty_returns_none(task_queue):
|
||||
"""큐가 비어있으면 None을 반환한다."""
|
||||
task = await task_queue.dequeue()
|
||||
assert task is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fifo_order(task_queue):
|
||||
"""FIFO 순서로 dequeue한다."""
|
||||
await task_queue.enqueue("thread-1", "gitea", {"order": 1})
|
||||
await task_queue.enqueue("thread-2", "discord", {"order": 2})
|
||||
|
||||
task1 = await task_queue.dequeue()
|
||||
assert task1["payload"]["order"] == 1
|
||||
|
||||
# Complete first task before dequeuing second
|
||||
await task_queue.mark_completed(task1["id"])
|
||||
|
||||
task2 = await task_queue.dequeue()
|
||||
assert task2["payload"]["order"] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrency_limit(task_queue):
|
||||
"""running 작업이 있으면 dequeue하지 않는다."""
|
||||
await task_queue.enqueue("thread-1", "gitea", {"msg": "first"})
|
||||
await task_queue.enqueue("thread-2", "gitea", {"msg": "second"})
|
||||
|
||||
task1 = await task_queue.dequeue()
|
||||
assert task1 is not None
|
||||
|
||||
task2 = await task_queue.dequeue()
|
||||
assert task2 is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mark_completed(task_queue):
|
||||
"""작업을 completed로 표시한다."""
|
||||
await task_queue.enqueue("thread-1", "gitea", {})
|
||||
task = await task_queue.dequeue()
|
||||
assert task is not None
|
||||
|
||||
await task_queue.mark_completed(task["id"], result={"pr_url": "http://..."})
|
||||
|
||||
await task_queue.enqueue("thread-2", "gitea", {})
|
||||
task2 = await task_queue.dequeue()
|
||||
assert task2 is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mark_failed(task_queue):
|
||||
"""작업을 failed로 표시한다."""
|
||||
await task_queue.enqueue("thread-1", "gitea", {})
|
||||
task = await task_queue.dequeue()
|
||||
|
||||
await task_queue.mark_failed(task["id"], error="Something broke")
|
||||
|
||||
await task_queue.enqueue("thread-2", "gitea", {})
|
||||
task2 = await task_queue.dequeue()
|
||||
assert task2 is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_pending(task_queue):
|
||||
"""미처리 작업 목록을 반환한다."""
|
||||
await task_queue.enqueue("thread-1", "gitea", {})
|
||||
await task_queue.enqueue("thread-2", "discord", {})
|
||||
|
||||
pending = await task_queue.get_pending()
|
||||
assert len(pending) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_running_task_for_thread(task_queue):
|
||||
"""특정 스레드에 실행 중인 작업이 있는지 확인한다."""
|
||||
await task_queue.enqueue("thread-1", "gitea", {})
|
||||
task = await task_queue.dequeue() # → running
|
||||
|
||||
assert await task_queue.has_running_task("thread-1") is True
|
||||
assert await task_queue.has_running_task("thread-2") is False
|
||||
11
uv.lock
generated
11
uv.lock
generated
@ -113,6 +113,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aiosqlite"
|
||||
version = "0.22.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/4e/8a/64761f4005f17809769d23e518d915db74e6310474e733e3593cfc854ef1/aiosqlite-0.22.1.tar.gz", hash = "sha256:043e0bd78d32888c0a9ca90fc788b38796843360c855a7262a532813133a0650", size = 14821, upload-time = "2025-12-23T19:25:43.997Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/00/b7/e3bf5133d697a08128598c8d0abc5e16377b51465a33756de24fa7dee953/aiosqlite-0.22.1-py3-none-any.whl", hash = "sha256:21c002eb13823fad740196c5a2e9d8e62f6243bd9e7e4a1f87fb5e44ecb4fceb", size = 17405, upload-time = "2025-12-23T19:25:42.139Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "annotated-doc"
|
||||
version = "0.0.4"
|
||||
@ -694,6 +703,7 @@ name = "galaxis-agent"
|
||||
version = "0.1.0"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "aiosqlite" },
|
||||
{ name = "cryptography" },
|
||||
{ name = "deepagents" },
|
||||
{ name = "discord-py" },
|
||||
@ -720,6 +730,7 @@ dev = [
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "aiosqlite", specifier = ">=0.20.0" },
|
||||
{ name = "cryptography", specifier = ">=41.0.0" },
|
||||
{ name = "deepagents", specifier = ">=0.4.3" },
|
||||
{ name = "discord-py", specifier = ">=2.3.0" },
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user