feat: add MessageStore and update check_message_queue middleware

- Implement SQLite-based MessageStore for follow-up messages
- Replace LangGraph store with MessageStore in check_message_queue middleware
- Preserve all multimodal content parsing logic
- Add comprehensive tests for MessageStore (4 tests, all passing)
- All 85 tests pass
This commit is contained in:
머니페니 2026-03-20 18:11:15 +09:00
parent 0136823462
commit 9242badeff
3 changed files with 155 additions and 20 deletions

86
agent/message_store.py Normal file
View File

@ -0,0 +1,86 @@
"""SQLite 기반 메시지 스토어.
에이전트 작업 도착하는 follow-up 메시지를 저장한다.
check_message_queue 미들웨어가 다음 모델 호출 주입한다.
"""
from __future__ import annotations
import json
import logging
import os
from datetime import datetime, timezone
import aiosqlite
logger = logging.getLogger(__name__)
_CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS pending_messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
thread_id TEXT NOT NULL,
message TEXT NOT NULL,
created_at TEXT NOT NULL
)
"""
_store: MessageStore | None = None
class MessageStore:
"""SQLite 기반 follow-up 메시지 스토어."""
def __init__(self, db_path: str = "/data/message_store.db"):
self._db_path = db_path
self._db: aiosqlite.Connection | None = None
async def initialize(self) -> None:
"""Initialize the database connection and create tables."""
self._db = await aiosqlite.connect(self._db_path)
await self._db.execute(_CREATE_TABLE)
await self._db.commit()
async def close(self) -> None:
"""Close the database connection."""
if self._db:
await self._db.close()
async def push_message(self, thread_id: str, message: dict) -> None:
"""스레드에 pending 메시지를 추가한다."""
now = datetime.now(timezone.utc).isoformat()
await self._db.execute(
"INSERT INTO pending_messages (thread_id, message, created_at) VALUES (?, ?, ?)",
(thread_id, json.dumps(message), now),
)
await self._db.commit()
logger.debug("Pushed message for thread %s", thread_id)
async def get_messages(self, thread_id: str) -> list[dict]:
"""스레드의 pending 메시지를 조회한다 (삭제하지 않음)."""
cursor = await self._db.execute(
"SELECT message FROM pending_messages WHERE thread_id = ? ORDER BY id ASC",
(thread_id,),
)
rows = await cursor.fetchall()
return [json.loads(row[0]) for row in rows]
async def consume_messages(self, thread_id: str) -> list[dict]:
"""스레드의 pending 메시지를 조회하고 삭제한다."""
messages = await self.get_messages(thread_id)
await self._db.execute(
"DELETE FROM pending_messages WHERE thread_id = ?",
(thread_id,),
)
await self._db.commit()
logger.debug("Consumed %d messages for thread %s", len(messages), thread_id)
return messages
async def get_message_store() -> MessageStore:
"""MessageStore 싱글턴을 반환한다."""
global _store
if _store is None:
db_path = os.environ.get("MESSAGE_STORE_DB", "/data/message_store.db")
_store = MessageStore(db_path=db_path)
await _store.initialize()
return _store

View File

@ -1,6 +1,6 @@
"""Before-model middleware that injects queued messages into state. """Before-model middleware that injects queued messages into state.
Checks the LangGraph store for pending messages (e.g. follow-up Linear Checks the MessageStore for pending messages (e.g. follow-up Linear
comments that arrived while the agent was busy) and injects them as new comments that arrived while the agent was busy) and injects them as new
human messages before the next model call. human messages before the next model call.
""" """
@ -12,9 +12,10 @@ from typing import Any
import httpx import httpx
from langchain.agents.middleware import AgentState, before_model from langchain.agents.middleware import AgentState, before_model
from langgraph.config import get_config, get_store from langgraph.config import get_config
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
from ..message_store import get_message_store
from ..utils.multimodal import fetch_image_block from ..utils.multimodal import fetch_image_block
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -68,31 +69,17 @@ async def check_message_queue_before_model( # noqa: PLR0911
return None return None
try: try:
store = get_store() store = await get_message_store()
except Exception as e: # noqa: BLE001 except Exception as e: # noqa: BLE001
logger.debug("Could not get store from context: %s", e) logger.warning("Could not get message store: %s", e)
return None return None
if store is None:
return None
namespace = ("queue", thread_id)
try: try:
queued_item = await store.aget(namespace, "pending_messages") queued_messages = await store.consume_messages(thread_id)
except Exception as e: # noqa: BLE001 except Exception as e: # noqa: BLE001
logger.warning("Failed to get queued item: %s", e) logger.warning("Failed to consume messages: %s", e)
return None return None
if queued_item is None:
return None
queued_value = queued_item.value
queued_messages = queued_value.get("messages", [])
# Delete early to prevent duplicate processing if middleware runs again
await store.adelete(namespace, "pending_messages")
if not queued_messages: if not queued_messages:
return None return None

View File

@ -0,0 +1,62 @@
"""Tests for MessageStore."""
import os
import tempfile
import pytest
from agent.message_store import MessageStore
@pytest.fixture
async def store():
"""Create a temporary MessageStore for testing."""
fd, db_path = tempfile.mkstemp(suffix=".db")
os.close(fd)
s = MessageStore(db_path=db_path)
await s.initialize()
yield s
await s.close()
os.unlink(db_path)
@pytest.mark.asyncio
async def test_push_and_get_messages(store):
"""메시지를 push하고 get한다."""
await store.push_message("thread-1", {"role": "human", "content": "추가 요청"})
messages = await store.get_messages("thread-1")
assert len(messages) == 1
assert messages[0]["content"] == "추가 요청"
@pytest.mark.asyncio
async def test_consume_messages(store):
"""메시지를 소비하면 삭제된다."""
await store.push_message("thread-1", {"role": "human", "content": "msg1"})
await store.push_message("thread-1", {"role": "human", "content": "msg2"})
messages = await store.consume_messages("thread-1")
assert len(messages) == 2
remaining = await store.get_messages("thread-1")
assert len(remaining) == 0
@pytest.mark.asyncio
async def test_no_messages_returns_empty(store):
"""메시지가 없으면 빈 리스트를 반환한다."""
messages = await store.get_messages("thread-999")
assert messages == []
@pytest.mark.asyncio
async def test_messages_isolated_by_thread(store):
"""스레드별로 메시지가 격리된다."""
await store.push_message("thread-1", {"content": "for thread 1"})
await store.push_message("thread-2", {"content": "for thread 2"})
msgs1 = await store.get_messages("thread-1")
msgs2 = await store.get_messages("thread-2")
assert len(msgs1) == 1
assert len(msgs2) == 1
assert msgs1[0]["content"] == "for thread 1"