galaxis-agent/agent/middleware/check_message_queue.py

126 lines
4.1 KiB
Python
Raw Permalink Normal View History

2026-03-20 14:38:07 +09:00
"""Before-model middleware that injects queued messages into state.

Checks the MessageStore for pending messages (e.g. follow-up Linear
comments that arrived while the agent was busy) and injects them, merged
into a single new human message, before the next model call.
"""
from __future__ import annotations
import logging
from typing import Any
import httpx
from langchain.agents.middleware import AgentState, before_model
from langgraph.config import get_config
2026-03-20 14:38:07 +09:00
from langgraph.runtime import Runtime
from ..message_store import get_message_store
2026-03-20 14:38:07 +09:00
from ..utils.multimodal import fetch_image_block
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class LinearNotifyState(AgentState):
    """Extended agent state for tracking Linear notifications.

    Passed as ``state_schema`` to the ``before_model`` middleware in this
    module. That middleware does not itself read or write the extra field.
    """

    # Number of Linear messages sent so far. Presumably maintained by other
    # middleware/tools; TODO confirm where it is incremented.
    linear_messages_sent_count: int
async def _build_blocks_from_payload(
    payload: dict[str, Any],
) -> list[dict[str, Any]]:
    """Turn a ``text`` / ``image_urls`` payload into model content blocks.

    A non-empty ``text`` becomes the first block, as
    ``{"type": "text", "text": ...}``. Each URL in ``image_urls`` is then passed
    to ``fetch_image_block`` in order, and every truthy result is appended.

    Args:
        payload: Mapping that may contain ``text`` and/or ``image_urls``.
            A missing or falsy ``image_urls`` is treated as no images.

    Returns:
        The content blocks. The list is empty when there is no text and no
        image was fetched.
    """
    message_text = payload.get("text", "")
    urls = payload.get("image_urls", []) or []

    result: list[dict[str, Any]] = (
        [{"type": "text", "text": message_text}] if message_text else []
    )
    if urls:
        # A single client is shared by every download (httpx clients pool
        # connections); no client is opened at all when there are no URLs.
        async with httpx.AsyncClient() as http:
            for url in urls:
                fetched = await fetch_image_block(url, http)
                if fetched:
                    result.append(fetched)
    return result
async def _queued_content_to_blocks(content: Any) -> list[dict[str, Any]]:
    """Convert the ``content`` of one queued message into content blocks.

    Handled shapes:
    - a dict with ``text`` and/or ``image_urls``: text block plus fetched
      image blocks. If building them raises, falls back to the text alone,
      because the message has already been consumed from the store and
      would otherwise be lost.
    - a list: passed through as ready-made content blocks.
    - a non-empty string: a single text block.

    Anything else yields no blocks. Unrecognised shapes are logged rather
    than dropped silently.
    """
    if isinstance(content, dict) and ("text" in content or "image_urls" in content):
        logger.debug("Queued message contains text + image URLs")
        try:
            return await _build_blocks_from_payload(content)
        except Exception as e:  # noqa: BLE001
            logger.warning(
                "Failed to build image blocks for queued message, keeping text only: %s",
                e,
            )
            text = content.get("text", "")
            return [{"type": "text", "text": text}] if text else []
    if isinstance(content, list):
        logger.debug("Queued message contains %d content block(s)", len(content))
        return list(content)
    if isinstance(content, str):
        if not content:
            return []
        logger.debug("Queued message contains text content")
        return [{"type": "text", "text": content}]
    logger.warning(
        "Skipping queued message with unsupported content type: %s",
        type(content).__name__,
    )
    return []


@before_model(state_schema=LinearNotifyState)
async def check_message_queue_before_model(  # noqa: PLR0911
    state: LinearNotifyState,  # noqa: ARG001
    runtime: Runtime,  # noqa: ARG001
) -> dict[str, Any] | None:
    """Inject queued follow-up messages before each model call.

    The steps are:
    1. Read ``thread_id`` from the current run's ``configurable`` config.
    2. Consume every pending message for that thread from the message store.
    3. Convert each message's content into content blocks, in the order
       ``consume_messages`` returns them (presumably oldest first; confirm
       against the store implementation).
    4. Merge all of the blocks into a single new ``user`` message.

    This lets the agent pick up follow-up comments that arrived while it was
    busy.

    Args:
        state: Current agent state (unused).
        runtime: LangGraph runtime (unused).

    Returns:
        ``{"messages": [new_message]}`` as a state update when there was
        queued content, otherwise ``None``. Errors are logged and never
        propagated, so a store failure does not abort the model call.
    """
    try:
        config = get_config()
        # Guard against an explicit ``configurable=None`` in the config.
        configurable = config.get("configurable") or {}
        thread_id = configurable.get("thread_id")
        if not thread_id:
            return None

        try:
            store = await get_message_store()
        except Exception as e:  # noqa: BLE001
            logger.warning("Could not get message store: %s", e)
            return None

        try:
            queued_messages = await store.consume_messages(thread_id)
        except Exception as e:  # noqa: BLE001
            logger.warning("Failed to consume messages: %s", e)
            return None

        if not queued_messages:
            return None

        logger.info(
            "Found %d queued message(s) for thread %s, injecting into state",
            len(queued_messages),
            thread_id,
        )

        # Messages are already removed from the store at this point, so each
        # one is converted independently; a failure in one cannot drop the rest.
        content_blocks: list[dict[str, Any]] = []
        for msg in queued_messages:
            content_blocks.extend(await _queued_content_to_blocks(msg.get("content")))

        if not content_blocks:
            return None

        new_message = {
            "role": "user",
            "content": content_blocks,
        }
        logger.info(
            "Injected %d queued message(s) (%d content block(s)) into state for thread %s",
            len(queued_messages),
            len(content_blocks),
            thread_id,
        )
        return {"messages": [new_message]}  # noqa: TRY300
    except Exception:
        logger.exception("Error in check_message_queue_before_model")
        return None