149 lines
4.7 KiB
Python
149 lines
4.7 KiB
Python
|
|
"""SQLite 기반 영속 작업 큐.
|
||
|
|
|
||
|
|
동시 작업 수를 제한하고, 서버 재시작 시에도 작업이 유실되지 않도록 한다.
|
||
|
|
"""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import json
|
||
|
|
import logging
|
||
|
|
import os
|
||
|
|
import uuid
|
||
|
|
from datetime import datetime, timezone
|
||
|
|
|
||
|
|
import aiosqlite
|
||
|
|
|
||
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# DDL for the single ``tasks`` table. ``IF NOT EXISTS`` makes it safe to run
# on every startup without touching rows that are already stored.
#
# Columns, as used by PersistentTaskQueue in this file:
#   id            primary key; enqueue() stores a UUID4 string
#   thread_id     caller-supplied string (also used to look up running tasks)
#   source        caller-supplied string
#   payload       JSON-encoded dict
#   status        'pending' (default), 'running', 'completed' or 'failed'
#   created_at    ISO-8601 UTC timestamp written at enqueue time
#   started_at    ISO-8601 UTC timestamp written when the task is dequeued
#   completed_at  ISO-8601 UTC timestamp written on completion or failure
#   result        JSON-encoded result, or {"error": ...} for failed tasks
_CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS tasks (
id TEXT PRIMARY KEY,
thread_id TEXT NOT NULL,
source TEXT NOT NULL,
payload TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'pending',
created_at TEXT NOT NULL,
started_at TEXT,
completed_at TEXT,
result TEXT
)
"""
|
||
|
|
|
||
|
|
|
||
|
|
class PersistentTaskQueue:
    """SQLite-backed persistent task queue.

    Each task is one row of the ``tasks`` table and moves through the
    statuses ``pending`` -> ``running`` -> ``completed`` / ``failed``.
    At most ``max_concurrent`` tasks are handed out as ``running`` at any
    time. Because the rows live in a SQLite file, queued tasks are still
    present after the process restarts.

    ``initialize()`` must be awaited before any other coroutine method is
    used; otherwise those methods raise ``RuntimeError``.
    """

    def __init__(self, db_path: str = "/data/task_queue.db", max_concurrent: int = 1):
        """Store the configuration; no I/O happens until ``initialize()``.

        Args:
            db_path: Path of the SQLite database file.
            max_concurrent: Maximum number of tasks allowed in the
                ``running`` status at the same time.
        """
        self._db_path = db_path
        self._max_concurrent = max_concurrent
        self._db: aiosqlite.Connection | None = None

    def _require_db(self) -> aiosqlite.Connection:
        """Return the open connection or fail with a clear error."""
        if self._db is None:
            raise RuntimeError(
                "PersistentTaskQueue is not initialized; call initialize() first"
            )
        return self._db

    @staticmethod
    def _row_to_task(row) -> dict:
        """Convert a ``tasks`` row into a dict with the payload JSON-decoded."""
        task = dict(row)
        task["payload"] = json.loads(task["payload"])
        return task

    async def initialize(self) -> None:
        """Open the database connection and create the table if missing.

        Calling this again while a connection is already open is a no-op,
        so a second call cannot leak the first connection.
        """
        if self._db is not None:
            return
        db = await aiosqlite.connect(self._db_path)
        try:
            db.row_factory = aiosqlite.Row
            await db.execute(_CREATE_TABLE)
            await db.commit()
        except BaseException:
            # Do not leave a half-configured connection (and its worker
            # thread) behind when setup fails or is cancelled.
            await db.close()
            raise
        self._db = db

    async def close(self) -> None:
        """Close the database connection, if one is open."""
        # Drop the reference first so a repeated close() is a no-op and any
        # later use reports "not initialized" instead of hitting a closed
        # connection.
        db, self._db = self._db, None
        if db is not None:
            await db.close()

    async def enqueue(
        self,
        thread_id: str,
        source: str,
        payload: dict,
    ) -> str:
        """Insert a new ``pending`` task and return its generated id.

        Args:
            thread_id: Stored in the ``thread_id`` column.
            source: Stored in the ``source`` column.
            payload: JSON-serializable dict stored in the ``payload`` column.

        Returns:
            The UUID4 string used as the task id.

        Raises:
            RuntimeError: If ``initialize()`` has not been awaited.
        """
        db = self._require_db()
        task_id = str(uuid.uuid4())
        now = datetime.now(timezone.utc).isoformat()
        await db.execute(
            "INSERT INTO tasks (id, thread_id, source, payload, status, created_at) "
            "VALUES (?, ?, ?, ?, 'pending', ?)",
            (task_id, thread_id, source, json.dumps(payload), now),
        )
        await db.commit()
        logger.info("Enqueued task %s for thread %s from %s", task_id, thread_id, source)
        return task_id

    async def dequeue(self) -> dict | None:
        """Claim the oldest ``pending`` task and mark it ``running``.

        Returns:
            The claimed task as a dict (payload JSON-decoded, ``status`` set
            to ``"running"``, ``started_at`` set to the claim time), or
            ``None`` when the concurrency limit is reached, nothing is
            pending, or another caller claimed the task first.

        Raises:
            RuntimeError: If ``initialize()`` has not been awaited.
        """
        db = self._require_db()

        # Cheap early exit when the limit is already reached. The limit is
        # enforced again inside the UPDATE below, which is what makes the
        # claim safe.
        # NOTE(review): rows left in 'running' by a crashed process still
        # count here and nothing in this class resets them - confirm that
        # recovery is handled by the caller.
        cursor = await db.execute(
            "SELECT COUNT(*) as cnt FROM tasks WHERE status = 'running'"
        )
        row = await cursor.fetchone()
        if row["cnt"] >= self._max_concurrent:
            return None

        cursor = await db.execute(
            "SELECT * FROM tasks WHERE status = 'pending' ORDER BY created_at ASC LIMIT 1"
        )
        row = await cursor.fetchone()
        if row is None:
            return None

        now = datetime.now(timezone.utc).isoformat()
        # Claim in a single statement: the row must still be pending and the
        # running count must still be below the limit. Other coroutines can
        # run between the awaits above, so the earlier reads alone are not
        # enough to prevent a double claim.
        cursor = await db.execute(
            "UPDATE tasks SET status = 'running', started_at = ? "
            "WHERE id = ? AND status = 'pending' "
            "AND (SELECT COUNT(*) FROM tasks WHERE status = 'running') < ?",
            (now, row["id"], self._max_concurrent),
        )
        await db.commit()
        if cursor.rowcount != 1:
            # Lost the race: someone else took the task or filled the slot.
            return None

        task = self._row_to_task(row)
        # The row was read before the UPDATE; mirror what is now stored.
        task["status"] = "running"
        task["started_at"] = now
        logger.info("Dequeued task %s (thread %s)", task["id"], task["thread_id"])
        return task

    async def mark_completed(self, task_id: str, result: dict | None = None) -> None:
        """Mark a task ``completed`` and store its JSON-encoded result.

        Args:
            task_id: Id of the task to update.
            result: Result dict; ``None`` (or an empty dict) is stored as ``{}``.

        Raises:
            RuntimeError: If ``initialize()`` has not been awaited.
        """
        db = self._require_db()
        now = datetime.now(timezone.utc).isoformat()
        await db.execute(
            "UPDATE tasks SET status = 'completed', completed_at = ?, result = ? WHERE id = ?",
            (now, json.dumps(result or {}), task_id),
        )
        await db.commit()
        logger.info("Task %s completed", task_id)

    async def mark_failed(self, task_id: str, error: str = "") -> None:
        """Mark a task ``failed`` and store the error as ``{"error": error}``.

        Args:
            task_id: Id of the task to update.
            error: Error message recorded in the ``result`` column.

        Raises:
            RuntimeError: If ``initialize()`` has not been awaited.
        """
        db = self._require_db()
        now = datetime.now(timezone.utc).isoformat()
        await db.execute(
            "UPDATE tasks SET status = 'failed', completed_at = ?, result = ? WHERE id = ?",
            (now, json.dumps({"error": error}), task_id),
        )
        await db.commit()
        logger.info("Task %s failed: %s", task_id, error)

    async def get_pending(self) -> list[dict]:
        """Return all ``pending`` tasks, oldest first, payloads JSON-decoded.

        Raises:
            RuntimeError: If ``initialize()`` has not been awaited.
        """
        db = self._require_db()
        cursor = await db.execute(
            "SELECT * FROM tasks WHERE status = 'pending' ORDER BY created_at ASC"
        )
        rows = await cursor.fetchall()
        return [self._row_to_task(row) for row in rows]

    async def has_running_task(self, thread_id: str) -> bool:
        """Return True if any task with this ``thread_id`` is ``running``.

        Raises:
            RuntimeError: If ``initialize()`` has not been awaited.
        """
        db = self._require_db()
        cursor = await db.execute(
            "SELECT COUNT(*) as cnt FROM tasks WHERE thread_id = ? AND status = 'running'",
            (thread_id,),
        )
        row = await cursor.fetchone()
        return row["cnt"] > 0
|
||
|
|
|
||
|
|
|
||
|
|
# Lazily initialized singleton; set only after initialize() has succeeded.
_queue: PersistentTaskQueue | None = None


async def get_task_queue() -> PersistentTaskQueue:
    """Return the process-wide PersistentTaskQueue, creating it on first use.

    The database path is read from the ``TASK_QUEUE_DB`` environment
    variable and falls back to ``/data/task_queue.db``.

    The instance is published to the module global only after
    ``initialize()`` has completed. A caller that arrives while the first
    initialization is still awaiting therefore never receives a queue
    without a connection, and a failed initialization is not cached.

    Returns:
        The shared, initialized queue instance.
    """
    global _queue
    if _queue is None:
        db_path = os.environ.get("TASK_QUEUE_DB", "/data/task_queue.db")
        candidate = PersistentTaskQueue(db_path=db_path)
        await candidate.initialize()
        if _queue is None:
            _queue = candidate
        else:
            # Another caller finished initializing while this one was
            # awaiting; keep theirs and release the duplicate connection.
            await candidate.close()
    return _queue
|