galaxis-agent/agent/task_queue.py

167 lines
5.3 KiB
Python
Raw Normal View History

"""SQLite-backed persistent task queue.
Limits the number of concurrently running tasks and keeps tasks from being lost across server restarts.
"""
from __future__ import annotations
import json
import logging
import os
import uuid
from datetime import datetime, timezone
import aiosqlite
logger = logging.getLogger(__name__)
# Schema of the single ``tasks`` table used by PersistentTaskQueue:
# - id:           random UUID4 string generated at enqueue time.
# - thread_id / source: caller-supplied identifiers, stored verbatim.
# - payload:      JSON-encoded dict passed to enqueue().
# - status:       'pending' | 'running' | 'completed' | 'failed' (the values this module writes).
# - created_at / started_at / completed_at: ISO-8601 UTC timestamps stored as TEXT.
# - result:       JSON-encoded dict written by mark_completed() / mark_failed().
_CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS tasks (
id TEXT PRIMARY KEY,
thread_id TEXT NOT NULL,
source TEXT NOT NULL,
payload TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'pending',
created_at TEXT NOT NULL,
started_at TEXT,
completed_at TEXT,
result TEXT
)
"""
class PersistentTaskQueue:
    """SQLite-backed persistent task queue.

    Tasks live in the ``tasks`` table, so they survive a server restart.
    ``dequeue`` refuses to start a new task while ``max_concurrent`` tasks
    are already in the ``running`` state.

    Status transitions written by this class:
    ``pending`` -> ``running`` -> ``completed`` | ``failed``;
    ``reset_running_to_pending`` moves ``running`` back to ``pending``.
    """

    # Speeds up the status-filtered lookups (dequeue, get_pending,
    # has_running_task) once finished tasks accumulate in the table.
    _CREATE_STATUS_INDEX = (
        "CREATE INDEX IF NOT EXISTS idx_tasks_status_created "
        "ON tasks (status, created_at)"
    )

    def __init__(self, db_path: str = "/data/task_queue.db", max_concurrent: int = 1):
        """Configure the queue; no I/O happens until ``initialize()``.

        Args:
            db_path: Path of the SQLite database file.
            max_concurrent: Maximum number of tasks allowed in ``running``.
        """
        self._db_path = db_path
        self._max_concurrent = max_concurrent
        self._db: aiosqlite.Connection | None = None

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _now_iso() -> str:
        """Return the current UTC time as an ISO-8601 string.

        Microseconds are always included, so every stored timestamp has the
        same width and ``ORDER BY created_at`` is a plain chronological sort.
        """
        return datetime.now(timezone.utc).isoformat(timespec="microseconds")

    @staticmethod
    def _row_to_task(row) -> dict:
        """Convert a ``tasks`` row into a dict with ``payload`` JSON-decoded."""
        task = dict(row)
        task["payload"] = json.loads(task["payload"])
        return task

    def _require_db(self) -> aiosqlite.Connection:
        """Return the open connection.

        Raises:
            RuntimeError: if ``initialize()`` has not run or ``close()`` was called.
        """
        if self._db is None:
            raise RuntimeError(
                "PersistentTaskQueue is not initialized; call initialize() first"
            )
        return self._db

    async def _count_running(self) -> int:
        """Return how many tasks are currently in the ``running`` state."""
        cursor = await self._require_db().execute(
            "SELECT COUNT(*) AS cnt FROM tasks WHERE status = 'running'"
        )
        row = await cursor.fetchone()
        return row["cnt"]

    async def _finish(self, task_id: str, status: str, result: dict) -> bool:
        """Set a terminal ``status``, ``completed_at`` and ``result`` on a task.

        Returns:
            True if a row was updated, False if ``task_id`` does not exist.
        """
        db = self._require_db()
        cursor = await db.execute(
            "UPDATE tasks SET status = ?, completed_at = ?, result = ? WHERE id = ?",
            (status, self._now_iso(), json.dumps(result), task_id),
        )
        await db.commit()
        if cursor.rowcount == 0:
            logger.warning("Cannot mark unknown task %s as %s", task_id, status)
            return False
        return True

    # -------------------------------------------------------------- lifecycle

    async def initialize(self) -> None:
        """Open the DB connection and create the table/index if missing.

        Calling it again while already open is a no-op. If schema creation
        fails, the freshly opened connection is closed before re-raising, so
        the connection is not leaked.
        """
        if self._db is not None:
            return
        db = await aiosqlite.connect(self._db_path)
        try:
            db.row_factory = aiosqlite.Row
            await db.execute(_CREATE_TABLE)
            await db.execute(self._CREATE_STATUS_INDEX)
            await db.commit()
        except BaseException:
            await db.close()
            raise
        self._db = db

    async def close(self) -> None:
        """Close the DB connection; safe to call more than once."""
        db, self._db = self._db, None
        if db is not None:
            await db.close()

    # ------------------------------------------------------------- operations

    async def enqueue(
        self,
        thread_id: str,
        source: str,
        payload: dict,
    ) -> str:
        """Persist a new ``pending`` task and return its generated id.

        Args:
            thread_id: Identifier stored with the task (used by has_running_task).
            source: Free-form origin label, stored verbatim.
            payload: JSON-serializable dict stored with the task.
        """
        db = self._require_db()
        task_id = str(uuid.uuid4())
        await db.execute(
            "INSERT INTO tasks (id, thread_id, source, payload, status, created_at) "
            "VALUES (?, ?, ?, ?, 'pending', ?)",
            (task_id, thread_id, source, json.dumps(payload), self._now_iso()),
        )
        await db.commit()
        logger.info("Enqueued task %s for thread %s from %s", task_id, thread_id, source)
        return task_id

    async def dequeue(self) -> dict | None:
        """Claim the oldest ``pending`` task and mark it ``running``.

        Returns:
            The claimed task as a dict (payload decoded, ``status`` and
            ``started_at`` reflecting the claim), or None when the
            concurrency limit is reached or nothing is pending.
        """
        db = self._require_db()
        while await self._count_running() < self._max_concurrent:
            cursor = await db.execute(
                "SELECT * FROM tasks WHERE status = 'pending' "
                "ORDER BY created_at ASC, rowid ASC LIMIT 1"
            )
            row = await cursor.fetchone()
            if row is None:
                return None
            now = self._now_iso()
            # The SELECTs above and this UPDATE are separate awaits, so another
            # coroutine may have claimed this row or taken the last slot in the
            # meantime. Re-checking both conditions inside one statement makes
            # the claim atomic; rowcount tells us whether we won.
            cursor = await db.execute(
                "UPDATE tasks SET status = 'running', started_at = ? "
                "WHERE id = ? AND status = 'pending' "
                "AND (SELECT COUNT(*) FROM tasks WHERE status = 'running') < ?",
                (now, row["id"], self._max_concurrent),
            )
            await db.commit()
            if cursor.rowcount == 1:
                task = self._row_to_task(row)
                task["status"] = "running"
                task["started_at"] = now
                logger.info("Dequeued task %s (thread %s)", task["id"], task["thread_id"])
                return task
            # Lost the race: loop re-checks capacity and pending work.
        return None

    async def mark_completed(self, task_id: str, result: dict | None = None) -> None:
        """Mark a task ``completed`` and store ``result`` (default ``{}``) as JSON."""
        if await self._finish(task_id, "completed", result or {}):
            logger.info("Task %s completed", task_id)

    async def mark_failed(self, task_id: str, error: str = "") -> None:
        """Mark a task ``failed`` and store ``{"error": error}`` as its result."""
        if await self._finish(task_id, "failed", {"error": error}):
            logger.info("Task %s failed: %s", task_id, error)

    async def get_pending(self) -> list[dict]:
        """Return all ``pending`` tasks, oldest first, with payloads decoded."""
        cursor = await self._require_db().execute(
            "SELECT * FROM tasks WHERE status = 'pending' "
            "ORDER BY created_at ASC, rowid ASC"
        )
        rows = await cursor.fetchall()
        return [self._row_to_task(row) for row in rows]

    async def has_running_task(self, thread_id: str) -> bool:
        """Return True if any task for ``thread_id`` is currently ``running``."""
        cursor = await self._require_db().execute(
            "SELECT EXISTS("
            "SELECT 1 FROM tasks WHERE thread_id = ? AND status = 'running'"
            ") AS found",
            (thread_id,),
        )
        row = await cursor.fetchone()
        return bool(row["found"])

    async def reset_running_to_pending(self) -> int:
        """Reset every ``running`` task back to ``pending`` (recovery helper).

        A single UPDATE both performs the reset and reports how many rows it
        touched, so no separate COUNT query is needed.

        Returns:
            Number of tasks that were reset.
        """
        db = self._require_db()
        cursor = await db.execute(
            "UPDATE tasks SET status = 'pending', started_at = NULL WHERE status = 'running'"
        )
        count = cursor.rowcount
        await db.commit()
        if count:
            logger.info("Reset %d running task(s) to pending", count)
        return count
# Lazily initialized module-level singleton (see get_task_queue).
_queue: PersistentTaskQueue | None = None


async def get_task_queue() -> PersistentTaskQueue:
    """Return the process-wide PersistentTaskQueue, creating it on first use.

    The database path is read from the ``TASK_QUEUE_DB`` environment variable
    (default ``/data/task_queue.db``).

    The global is published only after ``initialize()`` succeeds. A concurrent
    caller therefore never receives a queue whose connection is not open yet,
    and a failed initialization is not cached. If two callers race to create
    the queue, the loser closes its extra connection and both return the same
    instance.
    """
    global _queue
    if _queue is None:
        db_path = os.environ.get("TASK_QUEUE_DB", "/data/task_queue.db")
        queue = PersistentTaskQueue(db_path=db_path)
        await queue.initialize()
        # initialize() awaited, so another coroutine may have finished first.
        if _queue is None:
            _queue = queue
        else:
            await queue.close()
    return _queue