"""SQLite 기반 영속 작업 큐. 동시 작업 수를 제한하고, 서버 재시작 시에도 작업이 유실되지 않도록 한다. """ from __future__ import annotations import json import logging import os import uuid from datetime import datetime, timezone import aiosqlite logger = logging.getLogger(__name__) _CREATE_TABLE = """ CREATE TABLE IF NOT EXISTS tasks ( id TEXT PRIMARY KEY, thread_id TEXT NOT NULL, source TEXT NOT NULL, payload TEXT NOT NULL, status TEXT NOT NULL DEFAULT 'pending', created_at TEXT NOT NULL, started_at TEXT, completed_at TEXT, result TEXT ) """ class PersistentTaskQueue: """SQLite 기반 영속 작업 큐.""" def __init__(self, db_path: str = "/data/task_queue.db", max_concurrent: int = 1): self._db_path = db_path self._max_concurrent = max_concurrent self._db: aiosqlite.Connection | None = None async def initialize(self) -> None: """DB 연결 및 테이블 생성.""" self._db = await aiosqlite.connect(self._db_path) self._db.row_factory = aiosqlite.Row await self._db.execute(_CREATE_TABLE) await self._db.commit() async def close(self) -> None: """DB 연결 종료.""" if self._db: await self._db.close() async def enqueue( self, thread_id: str, source: str, payload: dict, ) -> str: task_id = str(uuid.uuid4()) now = datetime.now(timezone.utc).isoformat() await self._db.execute( "INSERT INTO tasks (id, thread_id, source, payload, status, created_at) " "VALUES (?, ?, ?, ?, 'pending', ?)", (task_id, thread_id, source, json.dumps(payload), now), ) await self._db.commit() logger.info("Enqueued task %s for thread %s from %s", task_id, thread_id, source) return task_id async def dequeue(self) -> dict | None: cursor = await self._db.execute( "SELECT COUNT(*) as cnt FROM tasks WHERE status = 'running'" ) row = await cursor.fetchone() if row["cnt"] >= self._max_concurrent: return None cursor = await self._db.execute( "SELECT * FROM tasks WHERE status = 'pending' ORDER BY created_at ASC LIMIT 1" ) row = await cursor.fetchone() if not row: return None now = datetime.now(timezone.utc).isoformat() await self._db.execute( "UPDATE tasks SET status = 'running', started_at = ? WHERE id = ?", (now, row["id"]), ) await self._db.commit() task = dict(row) task["payload"] = json.loads(task["payload"]) task["status"] = "running" logger.info("Dequeued task %s (thread %s)", task["id"], task["thread_id"]) return task async def mark_completed(self, task_id: str, result: dict | None = None) -> None: now = datetime.now(timezone.utc).isoformat() await self._db.execute( "UPDATE tasks SET status = 'completed', completed_at = ?, result = ? WHERE id = ?", (now, json.dumps(result or {}), task_id), ) await self._db.commit() logger.info("Task %s completed", task_id) async def mark_failed(self, task_id: str, error: str = "") -> None: now = datetime.now(timezone.utc).isoformat() await self._db.execute( "UPDATE tasks SET status = 'failed', completed_at = ?, result = ? WHERE id = ?", (now, json.dumps({"error": error}), task_id), ) await self._db.commit() logger.info("Task %s failed: %s", task_id, error) async def get_pending(self) -> list[dict]: cursor = await self._db.execute( "SELECT * FROM tasks WHERE status = 'pending' ORDER BY created_at ASC" ) rows = await cursor.fetchall() result = [] for row in rows: task = dict(row) task["payload"] = json.loads(task["payload"]) result.append(task) return result async def has_running_task(self, thread_id: str) -> bool: cursor = await self._db.execute( "SELECT COUNT(*) as cnt FROM tasks WHERE thread_id = ? AND status = 'running'", (thread_id,), ) row = await cursor.fetchone() return row["cnt"] > 0 async def reset_running_to_pending(self) -> int: """running 상태 작업을 pending으로 리셋한다 (복구용). Returns: 리셋된 작업 수. """ cursor = await self._db.execute( "SELECT COUNT(*) as cnt FROM tasks WHERE status = 'running'" ) row = await cursor.fetchone() count = row["cnt"] if count: await self._db.execute( "UPDATE tasks SET status = 'pending', started_at = NULL WHERE status = 'running'" ) await self._db.commit() return count # 지연 초기화 싱글턴 _queue: PersistentTaskQueue | None = None async def get_task_queue() -> PersistentTaskQueue: """PersistentTaskQueue 싱글턴을 반환한다.""" global _queue if _queue is None: db_path = os.environ.get("TASK_QUEUE_DB", "/data/task_queue.db") _queue = PersistentTaskQueue(db_path=db_path) await _queue.initialize() return _queue