galaxis-agent/agent/message_store.py

87 lines
2.8 KiB
Python
Raw Normal View History

"""SQLite 기반 메시지 스토어.
에이전트 작업 중에 도착하는 follow-up 메시지를 저장한다.
check_message_queue 미들웨어가 다음 모델 호출 시 이 메시지들을 주입한다.
"""
from __future__ import annotations
import json
import logging
import os
from datetime import datetime, timezone
import aiosqlite
# Module-level logger; push/consume activity is logged at DEBUG level.
logger = logging.getLogger(__name__)

# Process-wide MessageStore singleton, created lazily by get_message_store().
_store: MessageStore | None = None

# Schema for queued follow-up messages. `id` is AUTOINCREMENT, so ids only
# ever increase and ORDER BY id yields insertion order. `message` holds the
# JSON-encoded payload; `created_at` holds an ISO-8601 UTC timestamp string.
_CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS pending_messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
thread_id TEXT NOT NULL,
message TEXT NOT NULL,
created_at TEXT NOT NULL
)
"""
class MessageStore:
    """SQLite-backed store for pending follow-up messages, keyed by thread id.

    Payloads are JSON-encoded on write and returned in insertion order.
    Call ``initialize()`` before any other method and ``close()`` when done.
    """

    def __init__(self, db_path: str = "/data/message_store.db"):
        """Create an unopened store.

        Args:
            db_path: SQLite database file path handed to ``aiosqlite.connect``.
        """
        self._db_path = db_path
        # Opened by initialize(); reset to None by close().
        self._db: aiosqlite.Connection | None = None

    def _require_db(self) -> aiosqlite.Connection:
        """Return the open connection.

        Raises:
            RuntimeError: If ``initialize()`` has not run or ``close()`` was called.
        """
        if self._db is None:
            raise RuntimeError("MessageStore is not initialized; call initialize() first")
        return self._db

    async def initialize(self) -> None:
        """Open the database connection and create the table if it is missing.

        Calling this on an already-open store is a no-op, so a repeated call
        does not leak the existing connection.
        """
        if self._db is not None:
            return
        db = await aiosqlite.connect(self._db_path)
        try:
            await db.execute(_CREATE_TABLE)
            await db.commit()
        except BaseException:
            # Schema setup failed: release the connection instead of leaking it.
            await db.close()
            raise
        self._db = db

    async def close(self) -> None:
        """Close the database connection if it is open. Safe to call repeatedly."""
        # Detach first so later calls fail with a clear "not initialized"
        # error instead of using a closed connection.
        db, self._db = self._db, None
        if db is not None:
            await db.close()

    async def push_message(self, thread_id: str, message: dict) -> None:
        """Append a pending message for ``thread_id``.

        Args:
            thread_id: Thread the message belongs to.
            message: Payload; must be serializable by ``json.dumps``.

        Raises:
            RuntimeError: If the store is not initialized.
        """
        db = self._require_db()
        now = datetime.now(timezone.utc).isoformat()
        await db.execute(
            "INSERT INTO pending_messages (thread_id, message, created_at) VALUES (?, ?, ?)",
            (thread_id, json.dumps(message), now),
        )
        await db.commit()
        logger.debug("Pushed message for thread %s", thread_id)

    async def get_messages(self, thread_id: str) -> list[dict]:
        """Return pending messages for ``thread_id`` in insertion order, without removing them.

        Raises:
            RuntimeError: If the store is not initialized.
        """
        db = self._require_db()
        async with db.execute(
            "SELECT message FROM pending_messages WHERE thread_id = ? ORDER BY id ASC",
            (thread_id,),
        ) as cursor:
            rows = await cursor.fetchall()
        return [json.loads(row[0]) for row in rows]

    async def consume_messages(self, thread_id: str) -> list[dict]:
        """Return and remove pending messages for ``thread_id``, in insertion order.

        Only the rows actually returned are deleted. A message pushed while this
        coroutine is suspended between the SELECT and the DELETE stays queued
        for the next call instead of being silently dropped.

        Raises:
            RuntimeError: If the store is not initialized.
        """
        db = self._require_db()
        async with db.execute(
            "SELECT id, message FROM pending_messages WHERE thread_id = ? ORDER BY id ASC",
            (thread_id,),
        ) as cursor:
            rows = await cursor.fetchall()
        # Decode before deleting so a corrupt row raises without losing the batch.
        messages = [json.loads(raw) for _, raw in rows]
        if rows:
            # AUTOINCREMENT ids only grow, so any row inserted after the SELECT
            # has an id greater than the last one read and survives this DELETE.
            await db.execute(
                "DELETE FROM pending_messages WHERE thread_id = ? AND id <= ?",
                (thread_id, rows[-1][0]),
            )
            await db.commit()
        logger.debug("Consumed %d messages for thread %s", len(messages), thread_id)
        return messages
async def get_message_store() -> MessageStore:
    """Return the process-wide MessageStore, creating and initializing it on first use.

    The database path comes from the ``MESSAGE_STORE_DB`` environment variable
    (default ``/data/message_store.db``). It is read only when the store is
    first created.

    The singleton is published only after ``initialize()`` succeeds. As a
    result, concurrent callers on the event loop never receive a
    half-initialized store, and a failed initialization is retried on the
    next call rather than cached.
    """
    global _store
    if _store is None:
        db_path = os.environ.get("MESSAGE_STORE_DB", "/data/message_store.db")
        store = MessageStore(db_path=db_path)
        await store.initialize()
        if _store is None:
            _store = store
        else:
            # Another coroutine finished initializing while we were awaiting;
            # keep its instance and release our redundant connection.
            await store.close()
    return _store