feat: add MessageStore and update check_message_queue middleware
- Implement SQLite-based MessageStore for follow-up messages - Replace LangGraph store with MessageStore in check_message_queue middleware - Preserve all multimodal content parsing logic - Add comprehensive tests for MessageStore (4 tests, all passing) - All 85 tests pass
This commit is contained in:
parent
0136823462
commit
9242badeff
86
agent/message_store.py
Normal file
86
agent/message_store.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
"""SQLite 기반 메시지 스토어.
|
||||||
|
|
||||||
|
에이전트 작업 중 도착하는 follow-up 메시지를 저장한다.
|
||||||
|
check_message_queue 미들웨어가 다음 모델 호출 시 주입한다.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_CREATE_TABLE = """
|
||||||
|
CREATE TABLE IF NOT EXISTS pending_messages (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
thread_id TEXT NOT NULL,
|
||||||
|
message TEXT NOT NULL,
|
||||||
|
created_at TEXT NOT NULL
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
_store: MessageStore | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class MessageStore:
|
||||||
|
"""SQLite 기반 follow-up 메시지 스토어."""
|
||||||
|
|
||||||
|
def __init__(self, db_path: str = "/data/message_store.db"):
|
||||||
|
self._db_path = db_path
|
||||||
|
self._db: aiosqlite.Connection | None = None
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
"""Initialize the database connection and create tables."""
|
||||||
|
self._db = await aiosqlite.connect(self._db_path)
|
||||||
|
await self._db.execute(_CREATE_TABLE)
|
||||||
|
await self._db.commit()
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Close the database connection."""
|
||||||
|
if self._db:
|
||||||
|
await self._db.close()
|
||||||
|
|
||||||
|
async def push_message(self, thread_id: str, message: dict) -> None:
|
||||||
|
"""스레드에 pending 메시지를 추가한다."""
|
||||||
|
now = datetime.now(timezone.utc).isoformat()
|
||||||
|
await self._db.execute(
|
||||||
|
"INSERT INTO pending_messages (thread_id, message, created_at) VALUES (?, ?, ?)",
|
||||||
|
(thread_id, json.dumps(message), now),
|
||||||
|
)
|
||||||
|
await self._db.commit()
|
||||||
|
logger.debug("Pushed message for thread %s", thread_id)
|
||||||
|
|
||||||
|
async def get_messages(self, thread_id: str) -> list[dict]:
|
||||||
|
"""스레드의 pending 메시지를 조회한다 (삭제하지 않음)."""
|
||||||
|
cursor = await self._db.execute(
|
||||||
|
"SELECT message FROM pending_messages WHERE thread_id = ? ORDER BY id ASC",
|
||||||
|
(thread_id,),
|
||||||
|
)
|
||||||
|
rows = await cursor.fetchall()
|
||||||
|
return [json.loads(row[0]) for row in rows]
|
||||||
|
|
||||||
|
async def consume_messages(self, thread_id: str) -> list[dict]:
|
||||||
|
"""스레드의 pending 메시지를 조회하고 삭제한다."""
|
||||||
|
messages = await self.get_messages(thread_id)
|
||||||
|
await self._db.execute(
|
||||||
|
"DELETE FROM pending_messages WHERE thread_id = ?",
|
||||||
|
(thread_id,),
|
||||||
|
)
|
||||||
|
await self._db.commit()
|
||||||
|
logger.debug("Consumed %d messages for thread %s", len(messages), thread_id)
|
||||||
|
return messages
|
||||||
|
|
||||||
|
|
||||||
|
async def get_message_store() -> MessageStore:
|
||||||
|
"""MessageStore 싱글턴을 반환한다."""
|
||||||
|
global _store
|
||||||
|
if _store is None:
|
||||||
|
db_path = os.environ.get("MESSAGE_STORE_DB", "/data/message_store.db")
|
||||||
|
_store = MessageStore(db_path=db_path)
|
||||||
|
await _store.initialize()
|
||||||
|
return _store
|
||||||
@ -1,6 +1,6 @@
|
|||||||
"""Before-model middleware that injects queued messages into state.
|
"""Before-model middleware that injects queued messages into state.
|
||||||
|
|
||||||
Checks the LangGraph store for pending messages (e.g. follow-up Linear
|
Checks the MessageStore for pending messages (e.g. follow-up Linear
|
||||||
comments that arrived while the agent was busy) and injects them as new
|
comments that arrived while the agent was busy) and injects them as new
|
||||||
human messages before the next model call.
|
human messages before the next model call.
|
||||||
"""
|
"""
|
||||||
@ -12,9 +12,10 @@ from typing import Any
|
|||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from langchain.agents.middleware import AgentState, before_model
|
from langchain.agents.middleware import AgentState, before_model
|
||||||
from langgraph.config import get_config, get_store
|
from langgraph.config import get_config
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
|
from ..message_store import get_message_store
|
||||||
from ..utils.multimodal import fetch_image_block
|
from ..utils.multimodal import fetch_image_block
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -68,31 +69,17 @@ async def check_message_queue_before_model( # noqa: PLR0911
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
store = get_store()
|
store = await get_message_store()
|
||||||
except Exception as e: # noqa: BLE001
|
except Exception as e: # noqa: BLE001
|
||||||
logger.debug("Could not get store from context: %s", e)
|
logger.warning("Could not get message store: %s", e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if store is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
namespace = ("queue", thread_id)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
queued_item = await store.aget(namespace, "pending_messages")
|
queued_messages = await store.consume_messages(thread_id)
|
||||||
except Exception as e: # noqa: BLE001
|
except Exception as e: # noqa: BLE001
|
||||||
logger.warning("Failed to get queued item: %s", e)
|
logger.warning("Failed to consume messages: %s", e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if queued_item is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
queued_value = queued_item.value
|
|
||||||
queued_messages = queued_value.get("messages", [])
|
|
||||||
|
|
||||||
# Delete early to prevent duplicate processing if middleware runs again
|
|
||||||
await store.adelete(namespace, "pending_messages")
|
|
||||||
|
|
||||||
if not queued_messages:
|
if not queued_messages:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
62
tests/test_message_store.py
Normal file
62
tests/test_message_store.py
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
"""Tests for MessageStore."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agent.message_store import MessageStore
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def store():
|
||||||
|
"""Create a temporary MessageStore for testing."""
|
||||||
|
fd, db_path = tempfile.mkstemp(suffix=".db")
|
||||||
|
os.close(fd)
|
||||||
|
s = MessageStore(db_path=db_path)
|
||||||
|
await s.initialize()
|
||||||
|
yield s
|
||||||
|
await s.close()
|
||||||
|
os.unlink(db_path)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_push_and_get_messages(store):
|
||||||
|
"""메시지를 push하고 get한다."""
|
||||||
|
await store.push_message("thread-1", {"role": "human", "content": "추가 요청"})
|
||||||
|
messages = await store.get_messages("thread-1")
|
||||||
|
assert len(messages) == 1
|
||||||
|
assert messages[0]["content"] == "추가 요청"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_consume_messages(store):
|
||||||
|
"""메시지를 소비하면 삭제된다."""
|
||||||
|
await store.push_message("thread-1", {"role": "human", "content": "msg1"})
|
||||||
|
await store.push_message("thread-1", {"role": "human", "content": "msg2"})
|
||||||
|
|
||||||
|
messages = await store.consume_messages("thread-1")
|
||||||
|
assert len(messages) == 2
|
||||||
|
|
||||||
|
remaining = await store.get_messages("thread-1")
|
||||||
|
assert len(remaining) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_messages_returns_empty(store):
|
||||||
|
"""메시지가 없으면 빈 리스트를 반환한다."""
|
||||||
|
messages = await store.get_messages("thread-999")
|
||||||
|
assert messages == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_messages_isolated_by_thread(store):
|
||||||
|
"""스레드별로 메시지가 격리된다."""
|
||||||
|
await store.push_message("thread-1", {"content": "for thread 1"})
|
||||||
|
await store.push_message("thread-2", {"content": "for thread 2"})
|
||||||
|
|
||||||
|
msgs1 = await store.get_messages("thread-1")
|
||||||
|
msgs2 = await store.get_messages("thread-2")
|
||||||
|
assert len(msgs1) == 1
|
||||||
|
assert len(msgs2) == 1
|
||||||
|
assert msgs1[0]["content"] == "for thread 1"
|
||||||
Loading…
x
Reference in New Issue
Block a user