feat: add MessageStore and update check_message_queue middleware

- Implement SQLite-based MessageStore for follow-up messages
- Replace LangGraph store with MessageStore in check_message_queue middleware
- Preserve all multimodal content parsing logic
- Add comprehensive tests for MessageStore (4 tests, all passing)
- All 85 tests pass
This commit is contained in:
머니페니 2026-03-20 18:11:15 +09:00
parent 0136823462
commit 9242badeff
3 changed files with 155 additions and 20 deletions

86
agent/message_store.py Normal file
View File

@ -0,0 +1,86 @@
"""SQLite 기반 메시지 스토어.
에이전트 작업 도착하는 follow-up 메시지를 저장한다.
check_message_queue 미들웨어가 다음 모델 호출 주입한다.
"""
from __future__ import annotations
import json
import logging
import os
from datetime import datetime, timezone
import aiosqlite
logger = logging.getLogger(__name__)
_CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS pending_messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
thread_id TEXT NOT NULL,
message TEXT NOT NULL,
created_at TEXT NOT NULL
)
"""
_store: MessageStore | None = None
class MessageStore:
"""SQLite 기반 follow-up 메시지 스토어."""
def __init__(self, db_path: str = "/data/message_store.db"):
self._db_path = db_path
self._db: aiosqlite.Connection | None = None
async def initialize(self) -> None:
"""Initialize the database connection and create tables."""
self._db = await aiosqlite.connect(self._db_path)
await self._db.execute(_CREATE_TABLE)
await self._db.commit()
async def close(self) -> None:
"""Close the database connection."""
if self._db:
await self._db.close()
async def push_message(self, thread_id: str, message: dict) -> None:
"""스레드에 pending 메시지를 추가한다."""
now = datetime.now(timezone.utc).isoformat()
await self._db.execute(
"INSERT INTO pending_messages (thread_id, message, created_at) VALUES (?, ?, ?)",
(thread_id, json.dumps(message), now),
)
await self._db.commit()
logger.debug("Pushed message for thread %s", thread_id)
async def get_messages(self, thread_id: str) -> list[dict]:
"""스레드의 pending 메시지를 조회한다 (삭제하지 않음)."""
cursor = await self._db.execute(
"SELECT message FROM pending_messages WHERE thread_id = ? ORDER BY id ASC",
(thread_id,),
)
rows = await cursor.fetchall()
return [json.loads(row[0]) for row in rows]
async def consume_messages(self, thread_id: str) -> list[dict]:
"""스레드의 pending 메시지를 조회하고 삭제한다."""
messages = await self.get_messages(thread_id)
await self._db.execute(
"DELETE FROM pending_messages WHERE thread_id = ?",
(thread_id,),
)
await self._db.commit()
logger.debug("Consumed %d messages for thread %s", len(messages), thread_id)
return messages
async def get_message_store() -> MessageStore:
"""MessageStore 싱글턴을 반환한다."""
global _store
if _store is None:
db_path = os.environ.get("MESSAGE_STORE_DB", "/data/message_store.db")
_store = MessageStore(db_path=db_path)
await _store.initialize()
return _store

View File

@ -1,6 +1,6 @@
"""Before-model middleware that injects queued messages into state.
Checks the LangGraph store for pending messages (e.g. follow-up Linear
Checks the MessageStore for pending messages (e.g. follow-up Linear
comments that arrived while the agent was busy) and injects them as new
human messages before the next model call.
"""
@ -12,9 +12,10 @@ from typing import Any
import httpx
from langchain.agents.middleware import AgentState, before_model
from langgraph.config import get_config, get_store
from langgraph.config import get_config
from langgraph.runtime import Runtime
from ..message_store import get_message_store
from ..utils.multimodal import fetch_image_block
logger = logging.getLogger(__name__)
@ -68,31 +69,17 @@ async def check_message_queue_before_model( # noqa: PLR0911
return None
try:
store = get_store()
store = await get_message_store()
except Exception as e: # noqa: BLE001
logger.debug("Could not get store from context: %s", e)
logger.warning("Could not get message store: %s", e)
return None
if store is None:
return None
namespace = ("queue", thread_id)
try:
queued_item = await store.aget(namespace, "pending_messages")
queued_messages = await store.consume_messages(thread_id)
except Exception as e: # noqa: BLE001
logger.warning("Failed to get queued item: %s", e)
logger.warning("Failed to consume messages: %s", e)
return None
if queued_item is None:
return None
queued_value = queued_item.value
queued_messages = queued_value.get("messages", [])
# Delete early to prevent duplicate processing if middleware runs again
await store.adelete(namespace, "pending_messages")
if not queued_messages:
return None

View File

@ -0,0 +1,62 @@
"""Tests for MessageStore."""
import os
import tempfile
import pytest
from agent.message_store import MessageStore
@pytest.fixture
async def store():
"""Create a temporary MessageStore for testing."""
fd, db_path = tempfile.mkstemp(suffix=".db")
os.close(fd)
s = MessageStore(db_path=db_path)
await s.initialize()
yield s
await s.close()
os.unlink(db_path)
@pytest.mark.asyncio
async def test_push_and_get_messages(store):
"""메시지를 push하고 get한다."""
await store.push_message("thread-1", {"role": "human", "content": "추가 요청"})
messages = await store.get_messages("thread-1")
assert len(messages) == 1
assert messages[0]["content"] == "추가 요청"
@pytest.mark.asyncio
async def test_consume_messages(store):
"""메시지를 소비하면 삭제된다."""
await store.push_message("thread-1", {"role": "human", "content": "msg1"})
await store.push_message("thread-1", {"role": "human", "content": "msg2"})
messages = await store.consume_messages("thread-1")
assert len(messages) == 2
remaining = await store.get_messages("thread-1")
assert len(remaining) == 0
@pytest.mark.asyncio
async def test_no_messages_returns_empty(store):
"""메시지가 없으면 빈 리스트를 반환한다."""
messages = await store.get_messages("thread-999")
assert messages == []
@pytest.mark.asyncio
async def test_messages_isolated_by_thread(store):
"""스레드별로 메시지가 격리된다."""
await store.push_message("thread-1", {"content": "for thread 1"})
await store.push_message("thread-2", {"content": "for thread 2"})
msgs1 = await store.get_messages("thread-1")
msgs2 = await store.get_messages("thread-2")
assert len(msgs1) == 1
assert len(msgs2) == 1
assert msgs1[0]["content"] == "for thread 1"