- Implement SQLite-based MessageStore for follow-up messages - Replace LangGraph store with MessageStore in check_message_queue middleware - Preserve all multimodal content parsing logic - Add comprehensive tests for MessageStore (4 tests, all passing) - All 85 tests pass
87 lines
2.8 KiB
Python
87 lines
2.8 KiB
Python
"""SQLite 기반 메시지 스토어.
|
|
|
|
에이전트 작업 중 도착하는 follow-up 메시지를 저장한다.
|
|
check_message_queue 미들웨어가 다음 모델 호출 시 주입한다.
|
|
"""
|
|
|
|
from __future__ import annotations

import asyncio
import json
import logging
import os
from datetime import datetime, timezone

import aiosqlite
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
_CREATE_TABLE = """
|
|
CREATE TABLE IF NOT EXISTS pending_messages (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
thread_id TEXT NOT NULL,
|
|
message TEXT NOT NULL,
|
|
created_at TEXT NOT NULL
|
|
)
|
|
"""
|
|
|
|
_store: MessageStore | None = None
|
|
|
|
|
|
class MessageStore:
    """SQLite-backed store for follow-up messages.

    Messages that arrive while an agent is working are queued per
    ``thread_id`` so the ``check_message_queue`` middleware can inject them
    on the next model call.
    """

    # Serves the per-thread ``WHERE thread_id = ? ... ORDER BY id`` lookups
    # from an index instead of scanning the whole table.
    _CREATE_INDEX = (
        "CREATE INDEX IF NOT EXISTS idx_pending_messages_thread_id "
        "ON pending_messages (thread_id, id)"
    )

    def __init__(self, db_path: str = "/data/message_store.db") -> None:
        """Create an unopened store; call ``initialize()`` before use.

        Args:
            db_path: Path of the SQLite database file to open.
        """
        self._db_path = db_path
        self._db: aiosqlite.Connection | None = None

    def _conn(self) -> aiosqlite.Connection:
        """Return the open connection, or raise if the store is not open."""
        if self._db is None:
            raise RuntimeError(
                "MessageStore is not initialized; call initialize() first"
            )
        return self._db

    async def initialize(self) -> None:
        """Open the database connection and create the schema if missing.

        Calling this on an already-initialized store is a no-op, so repeated
        calls do not leak connections.
        """
        if self._db is not None:
            return
        db = await aiosqlite.connect(self._db_path)
        try:
            await db.execute(_CREATE_TABLE)
            await db.execute(self._CREATE_INDEX)
            await db.commit()
        except BaseException:
            # Don't leave a half-initialized connection open.
            await db.close()
            raise
        # Publish the connection only once the schema exists.
        self._db = db

    async def close(self) -> None:
        """Close the database connection; safe to call more than once."""
        db, self._db = self._db, None
        if db is not None:
            await db.close()

    async def push_message(self, thread_id: str, message: dict) -> None:
        """Append a pending message to ``thread_id``'s queue.

        Args:
            thread_id: Thread the message belongs to.
            message: JSON-serializable payload, stored via ``json.dumps``.

        Raises:
            RuntimeError: If the store has not been initialized.
        """
        db = self._conn()
        now = datetime.now(timezone.utc).isoformat()
        await db.execute(
            "INSERT INTO pending_messages (thread_id, message, created_at) VALUES (?, ?, ?)",
            (thread_id, json.dumps(message), now),
        )
        await db.commit()
        logger.debug("Pushed message for thread %s", thread_id)

    async def _fetch_rows(self, thread_id: str) -> list[tuple[int, dict]]:
        """Return ``(id, decoded message)`` pairs for ``thread_id``, oldest first."""
        cursor = await self._conn().execute(
            "SELECT id, message FROM pending_messages WHERE thread_id = ? ORDER BY id ASC",
            (thread_id,),
        )
        try:
            rows = await cursor.fetchall()
        finally:
            await cursor.close()
        return [(row_id, json.loads(raw)) for row_id, raw in rows]

    async def get_messages(self, thread_id: str) -> list[dict]:
        """Return ``thread_id``'s pending messages in insertion order, without deleting them.

        Raises:
            RuntimeError: If the store has not been initialized.
        """
        return [message for _, message in await self._fetch_rows(thread_id)]

    async def consume_messages(self, thread_id: str) -> list[dict]:
        """Return ``thread_id``'s pending messages in insertion order and delete them.

        Only the rows actually read are deleted. The DELETE is bounded by the
        highest ``id`` the SELECT returned, so a message pushed by another
        coroutine between the two statements stays queued for the next call
        instead of being silently dropped.

        Raises:
            RuntimeError: If the store has not been initialized.
        """
        rows = await self._fetch_rows(thread_id)
        if rows:
            db = self._conn()
            await db.execute(
                "DELETE FROM pending_messages WHERE thread_id = ? AND id <= ?",
                (thread_id, rows[-1][0]),
            )
            await db.commit()
        messages = [message for _, message in rows]
        logger.debug("Consumed %d messages for thread %s", len(messages), thread_id)
        return messages
|
|
|
|
|
|
# Serializes first-time construction of ``_store``. It is created lazily on
# first use, inside the running event loop.
_store_lock: asyncio.Lock | None = None


async def get_message_store() -> MessageStore:
    """Return the process-wide MessageStore, creating it on first use.

    The database path is read from the ``MESSAGE_STORE_DB`` environment
    variable (default ``/data/message_store.db``).

    ``_store`` is only assigned after ``initialize()`` completes, and
    construction is guarded by a lock. A concurrent first caller therefore
    never receives a store whose connection is not open yet, and at most one
    connection is created.
    """
    global _store, _store_lock
    if _store is not None:
        return _store
    if _store_lock is None:
        # No await between the check and the assignment, so this cannot race
        # within a single event loop.
        _store_lock = asyncio.Lock()
    async with _store_lock:
        if _store is None:
            db_path = os.environ.get("MESSAGE_STORE_DB", "/data/message_store.db")
            store = MessageStore(db_path=db_path)
            await store.initialize()
            _store = store
    return _store
|