diff --git a/agent/message_store.py b/agent/message_store.py new file mode 100644 index 0000000..1035a32 --- /dev/null +++ b/agent/message_store.py @@ -0,0 +1,86 @@ +"""SQLite 기반 메시지 스토어. + +에이전트 작업 중 도착하는 follow-up 메시지를 저장한다. +check_message_queue 미들웨어가 다음 모델 호출 시 주입한다. +""" + +from __future__ import annotations + +import json +import logging +import os +from datetime import datetime, timezone + +import aiosqlite + +logger = logging.getLogger(__name__) + +_CREATE_TABLE = """ +CREATE TABLE IF NOT EXISTS pending_messages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + thread_id TEXT NOT NULL, + message TEXT NOT NULL, + created_at TEXT NOT NULL +) +""" + +_store: MessageStore | None = None + + +class MessageStore: + """SQLite 기반 follow-up 메시지 스토어.""" + + def __init__(self, db_path: str = "/data/message_store.db"): + self._db_path = db_path + self._db: aiosqlite.Connection | None = None + + async def initialize(self) -> None: + """Initialize the database connection and create tables.""" + self._db = await aiosqlite.connect(self._db_path) + await self._db.execute(_CREATE_TABLE) + await self._db.commit() + + async def close(self) -> None: + """Close the database connection.""" + if self._db: + await self._db.close() + + async def push_message(self, thread_id: str, message: dict) -> None: + """스레드에 pending 메시지를 추가한다.""" + now = datetime.now(timezone.utc).isoformat() + await self._db.execute( + "INSERT INTO pending_messages (thread_id, message, created_at) VALUES (?, ?, ?)", + (thread_id, json.dumps(message), now), + ) + await self._db.commit() + logger.debug("Pushed message for thread %s", thread_id) + + async def get_messages(self, thread_id: str) -> list[dict]: + """스레드의 pending 메시지를 조회한다 (삭제하지 않음).""" + cursor = await self._db.execute( + "SELECT message FROM pending_messages WHERE thread_id = ? ORDER BY id ASC", + (thread_id,), + ) + rows = await cursor.fetchall() + return [json.loads(row[0]) for row in rows] + + async def consume_messages(self, thread_id: str) -> list[dict]: + """스레드의 pending 메시지를 조회하고 삭제한다.""" + messages = await self.get_messages(thread_id) + await self._db.execute( + "DELETE FROM pending_messages WHERE thread_id = ?", + (thread_id,), + ) + await self._db.commit() + logger.debug("Consumed %d messages for thread %s", len(messages), thread_id) + return messages + + +async def get_message_store() -> MessageStore: + """MessageStore 싱글턴을 반환한다.""" + global _store + if _store is None: + db_path = os.environ.get("MESSAGE_STORE_DB", "/data/message_store.db") + _store = MessageStore(db_path=db_path) + await _store.initialize() + return _store diff --git a/agent/middleware/check_message_queue.py b/agent/middleware/check_message_queue.py index a882576..f64a7b4 100644 --- a/agent/middleware/check_message_queue.py +++ b/agent/middleware/check_message_queue.py @@ -1,6 +1,6 @@ """Before-model middleware that injects queued messages into state. -Checks the LangGraph store for pending messages (e.g. follow-up Linear +Checks the MessageStore for pending messages (e.g. follow-up Linear comments that arrived while the agent was busy) and injects them as new human messages before the next model call. """ @@ -12,9 +12,10 @@ from typing import Any import httpx from langchain.agents.middleware import AgentState, before_model -from langgraph.config import get_config, get_store +from langgraph.config import get_config from langgraph.runtime import Runtime +from ..message_store import get_message_store from ..utils.multimodal import fetch_image_block logger = logging.getLogger(__name__) @@ -68,31 +69,17 @@ async def check_message_queue_before_model( # noqa: PLR0911 return None try: - store = get_store() + store = await get_message_store() except Exception as e: # noqa: BLE001 - logger.debug("Could not get store from context: %s", e) + logger.warning("Could not get message store: %s", e) return None - if store is None: - return None - - namespace = ("queue", thread_id) - try: - queued_item = await store.aget(namespace, "pending_messages") + queued_messages = await store.consume_messages(thread_id) except Exception as e: # noqa: BLE001 - logger.warning("Failed to get queued item: %s", e) + logger.warning("Failed to consume messages: %s", e) return None - if queued_item is None: - return None - - queued_value = queued_item.value - queued_messages = queued_value.get("messages", []) - - # Delete early to prevent duplicate processing if middleware runs again - await store.adelete(namespace, "pending_messages") - if not queued_messages: return None diff --git a/tests/test_message_store.py b/tests/test_message_store.py new file mode 100644 index 0000000..faecb11 --- /dev/null +++ b/tests/test_message_store.py @@ -0,0 +1,62 @@ +"""Tests for MessageStore.""" + +import os +import tempfile + +import pytest + +from agent.message_store import MessageStore + + +@pytest.fixture +async def store(): + """Create a temporary MessageStore for testing.""" + fd, db_path = tempfile.mkstemp(suffix=".db") + os.close(fd) + s = MessageStore(db_path=db_path) + await s.initialize() + yield s + await s.close() + os.unlink(db_path) + + +@pytest.mark.asyncio +async def test_push_and_get_messages(store): + """메시지를 push하고 get한다.""" + await store.push_message("thread-1", {"role": "human", "content": "추가 요청"}) + messages = await store.get_messages("thread-1") + assert len(messages) == 1 + assert messages[0]["content"] == "추가 요청" + + +@pytest.mark.asyncio +async def test_consume_messages(store): + """메시지를 소비하면 삭제된다.""" + await store.push_message("thread-1", {"role": "human", "content": "msg1"}) + await store.push_message("thread-1", {"role": "human", "content": "msg2"}) + + messages = await store.consume_messages("thread-1") + assert len(messages) == 2 + + remaining = await store.get_messages("thread-1") + assert len(remaining) == 0 + + +@pytest.mark.asyncio +async def test_no_messages_returns_empty(store): + """메시지가 없으면 빈 리스트를 반환한다.""" + messages = await store.get_messages("thread-999") + assert messages == [] + + +@pytest.mark.asyncio +async def test_messages_isolated_by_thread(store): + """스레드별로 메시지가 격리된다.""" + await store.push_message("thread-1", {"content": "for thread 1"}) + await store.push_message("thread-2", {"content": "for thread 2"}) + + msgs1 = await store.get_messages("thread-1") + msgs2 = await store.get_messages("thread-2") + assert len(msgs1) == 1 + assert len(msgs2) == 1 + assert msgs1[0]["content"] == "for thread 1"