feat: add MessageStore and update check_message_queue middleware
- Implement SQLite-based MessageStore for follow-up messages - Replace LangGraph store with MessageStore in check_message_queue middleware - Preserve all multimodal content parsing logic - Add comprehensive tests for MessageStore (4 tests, all passing) - All 85 tests pass
This commit is contained in:
parent
0136823462
commit
9242badeff
86
agent/message_store.py
Normal file
86
agent/message_store.py
Normal file
@ -0,0 +1,86 @@
|
||||
"""SQLite 기반 메시지 스토어.
|
||||
|
||||
에이전트 작업 중 도착하는 follow-up 메시지를 저장한다.
|
||||
check_message_queue 미들웨어가 다음 모델 호출 시 주입한다.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import aiosqlite
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_CREATE_TABLE = """
|
||||
CREATE TABLE IF NOT EXISTS pending_messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
thread_id TEXT NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL
|
||||
)
|
||||
"""
|
||||
|
||||
_store: MessageStore | None = None
|
||||
|
||||
|
||||
class MessageStore:
|
||||
"""SQLite 기반 follow-up 메시지 스토어."""
|
||||
|
||||
def __init__(self, db_path: str = "/data/message_store.db"):
|
||||
self._db_path = db_path
|
||||
self._db: aiosqlite.Connection | None = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the database connection and create tables."""
|
||||
self._db = await aiosqlite.connect(self._db_path)
|
||||
await self._db.execute(_CREATE_TABLE)
|
||||
await self._db.commit()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the database connection."""
|
||||
if self._db:
|
||||
await self._db.close()
|
||||
|
||||
async def push_message(self, thread_id: str, message: dict) -> None:
|
||||
"""스레드에 pending 메시지를 추가한다."""
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
await self._db.execute(
|
||||
"INSERT INTO pending_messages (thread_id, message, created_at) VALUES (?, ?, ?)",
|
||||
(thread_id, json.dumps(message), now),
|
||||
)
|
||||
await self._db.commit()
|
||||
logger.debug("Pushed message for thread %s", thread_id)
|
||||
|
||||
async def get_messages(self, thread_id: str) -> list[dict]:
|
||||
"""스레드의 pending 메시지를 조회한다 (삭제하지 않음)."""
|
||||
cursor = await self._db.execute(
|
||||
"SELECT message FROM pending_messages WHERE thread_id = ? ORDER BY id ASC",
|
||||
(thread_id,),
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
return [json.loads(row[0]) for row in rows]
|
||||
|
||||
async def consume_messages(self, thread_id: str) -> list[dict]:
|
||||
"""스레드의 pending 메시지를 조회하고 삭제한다."""
|
||||
messages = await self.get_messages(thread_id)
|
||||
await self._db.execute(
|
||||
"DELETE FROM pending_messages WHERE thread_id = ?",
|
||||
(thread_id,),
|
||||
)
|
||||
await self._db.commit()
|
||||
logger.debug("Consumed %d messages for thread %s", len(messages), thread_id)
|
||||
return messages
|
||||
|
||||
|
||||
async def get_message_store() -> MessageStore:
|
||||
"""MessageStore 싱글턴을 반환한다."""
|
||||
global _store
|
||||
if _store is None:
|
||||
db_path = os.environ.get("MESSAGE_STORE_DB", "/data/message_store.db")
|
||||
_store = MessageStore(db_path=db_path)
|
||||
await _store.initialize()
|
||||
return _store
|
||||
@ -1,6 +1,6 @@
|
||||
"""Before-model middleware that injects queued messages into state.
|
||||
|
||||
Checks the LangGraph store for pending messages (e.g. follow-up Linear
|
||||
Checks the MessageStore for pending messages (e.g. follow-up Linear
|
||||
comments that arrived while the agent was busy) and injects them as new
|
||||
human messages before the next model call.
|
||||
"""
|
||||
@ -12,9 +12,10 @@ from typing import Any
|
||||
|
||||
import httpx
|
||||
from langchain.agents.middleware import AgentState, before_model
|
||||
from langgraph.config import get_config, get_store
|
||||
from langgraph.config import get_config
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from ..message_store import get_message_store
|
||||
from ..utils.multimodal import fetch_image_block
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -68,31 +69,17 @@ async def check_message_queue_before_model( # noqa: PLR0911
|
||||
return None
|
||||
|
||||
try:
|
||||
store = get_store()
|
||||
store = await get_message_store()
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.debug("Could not get store from context: %s", e)
|
||||
logger.warning("Could not get message store: %s", e)
|
||||
return None
|
||||
|
||||
if store is None:
|
||||
return None
|
||||
|
||||
namespace = ("queue", thread_id)
|
||||
|
||||
try:
|
||||
queued_item = await store.aget(namespace, "pending_messages")
|
||||
queued_messages = await store.consume_messages(thread_id)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.warning("Failed to get queued item: %s", e)
|
||||
logger.warning("Failed to consume messages: %s", e)
|
||||
return None
|
||||
|
||||
if queued_item is None:
|
||||
return None
|
||||
|
||||
queued_value = queued_item.value
|
||||
queued_messages = queued_value.get("messages", [])
|
||||
|
||||
# Delete early to prevent duplicate processing if middleware runs again
|
||||
await store.adelete(namespace, "pending_messages")
|
||||
|
||||
if not queued_messages:
|
||||
return None
|
||||
|
||||
|
||||
62
tests/test_message_store.py
Normal file
62
tests/test_message_store.py
Normal file
@ -0,0 +1,62 @@
|
||||
"""Tests for MessageStore."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from agent.message_store import MessageStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def store():
|
||||
"""Create a temporary MessageStore for testing."""
|
||||
fd, db_path = tempfile.mkstemp(suffix=".db")
|
||||
os.close(fd)
|
||||
s = MessageStore(db_path=db_path)
|
||||
await s.initialize()
|
||||
yield s
|
||||
await s.close()
|
||||
os.unlink(db_path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_push_and_get_messages(store):
|
||||
"""메시지를 push하고 get한다."""
|
||||
await store.push_message("thread-1", {"role": "human", "content": "추가 요청"})
|
||||
messages = await store.get_messages("thread-1")
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["content"] == "추가 요청"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consume_messages(store):
|
||||
"""메시지를 소비하면 삭제된다."""
|
||||
await store.push_message("thread-1", {"role": "human", "content": "msg1"})
|
||||
await store.push_message("thread-1", {"role": "human", "content": "msg2"})
|
||||
|
||||
messages = await store.consume_messages("thread-1")
|
||||
assert len(messages) == 2
|
||||
|
||||
remaining = await store.get_messages("thread-1")
|
||||
assert len(remaining) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_messages_returns_empty(store):
|
||||
"""메시지가 없으면 빈 리스트를 반환한다."""
|
||||
messages = await store.get_messages("thread-999")
|
||||
assert messages == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messages_isolated_by_thread(store):
|
||||
"""스레드별로 메시지가 격리된다."""
|
||||
await store.push_message("thread-1", {"content": "for thread 1"})
|
||||
await store.push_message("thread-2", {"content": "for thread 2"})
|
||||
|
||||
msgs1 = await store.get_messages("thread-1")
|
||||
msgs2 = await store.get_messages("thread-2")
|
||||
assert len(msgs1) == 1
|
||||
assert len(msgs2) == 1
|
||||
assert msgs1[0]["content"] == "for thread 1"
|
||||
Loading…
x
Reference in New Issue
Block a user