galaxis-agent/agent/middleware/check_message_queue.py
2026-03-20 14:38:07 +09:00

139 lines
4.4 KiB
Python

"""Before-model middleware that injects queued messages into state.
Checks the LangGraph store for pending messages (e.g. follow-up Linear
comments that arrived while the agent was busy) and injects them as new
human messages before the next model call.
"""
from __future__ import annotations
import logging
from typing import Any
import httpx
from langchain.agents.middleware import AgentState, before_model
from langgraph.config import get_config, get_store
from langgraph.runtime import Runtime
from ..utils.multimodal import fetch_image_block
logger = logging.getLogger(__name__)
class LinearNotifyState(AgentState):
    """Extended agent state for tracking Linear notifications.

    Adds a counter on top of the base ``AgentState`` so middleware can
    track how many Linear notification messages have been sent for the
    current run.
    """
    # Running count of Linear messages already sent for this thread.
    linear_messages_sent_count: int
async def _build_blocks_from_payload(
    payload: dict[str, Any],
) -> list[dict[str, Any]]:
    """Convert a queued message payload into model content blocks.

    The payload may contain a ``text`` entry (emitted as one text block)
    and an ``image_urls`` list; each URL that yields a fetchable image
    block is appended after the text. An HTTP client is only opened when
    there are image URLs to fetch.
    """
    blocks: list[dict[str, Any]] = []

    message_text = payload.get("text", "")
    if message_text:
        blocks.append({"type": "text", "text": message_text})

    urls = payload.get("image_urls", []) or []
    if urls:
        async with httpx.AsyncClient() as http_client:
            for url in urls:
                block = await fetch_image_block(url, http_client)
                if block:
                    blocks.append(block)

    return blocks
@before_model(state_schema=LinearNotifyState)
async def check_message_queue_before_model(  # noqa: PLR0911
    state: LinearNotifyState,  # noqa: ARG001
    runtime: Runtime,  # noqa: ARG001
) -> dict[str, Any] | None:
    """Middleware that checks for queued messages before each model call.

    If messages are found in the queue for this thread, it extracts all messages,
    adds them to the conversation state as new human messages, and clears the queue.
    Messages are processed in FIFO order (oldest first).
    This enables handling of follow-up comments that arrive while the agent is busy.
    The agent will see the new messages and can incorporate them into its response.

    Returns:
        A state update ``{"messages": [...]}`` containing one user message
        whose content aggregates all queued messages' blocks, or ``None``
        when there is nothing to inject (no thread id, no store, empty
        queue, or any error — this middleware is best-effort and must
        never break the model call).
    """
    try:
        config = get_config()
        configurable = config.get("configurable", {})
        thread_id = configurable.get("thread_id")
        if not thread_id:
            # Without a thread id we cannot address a queue namespace.
            return None
        try:
            store = get_store()
        except Exception as e:  # noqa: BLE001
            logger.debug("Could not get store from context: %s", e)
            return None
        if store is None:
            return None
        namespace = ("queue", thread_id)
        try:
            queued_item = await store.aget(namespace, "pending_messages")
        except Exception as e:  # noqa: BLE001
            logger.warning("Failed to get queued item: %s", e)
            return None
        if queued_item is None:
            return None
        queued_value = queued_item.value
        queued_messages = queued_value.get("messages", [])
        # Delete early to prevent duplicate processing if middleware runs again
        await store.adelete(namespace, "pending_messages")
        if not queued_messages:
            return None
        logger.info(
            "Found %d queued message(s) for thread %s, injecting into state",
            len(queued_messages),
            thread_id,
        )
        # Flatten every queued message into a single list of content blocks.
        content_blocks: list[dict[str, Any]] = []
        for msg in queued_messages:
            content = msg.get("content")
            # Structured payload: text plus image URLs to be fetched.
            if isinstance(content, dict) and ("text" in content or "image_urls" in content):
                logger.debug("Queued message contains text + image URLs")
                blocks = await _build_blocks_from_payload(content)
                content_blocks.extend(blocks)
                continue
            # Already a list of content blocks: pass through unchanged.
            if isinstance(content, list):
                logger.debug("Queued message contains %d content block(s)", len(content))
                content_blocks.extend(content)
                continue
            # Plain non-empty string: wrap as a text block.
            if isinstance(content, str) and content:
                logger.debug("Queued message contains text content")
                content_blocks.append({"type": "text", "text": content})
        if not content_blocks:
            return None
        new_message = {
            "role": "user",
            "content": content_blocks,
        }
        # BUGFIX: the previous log line reported len(content_blocks) as a
        # message count; one queued message can yield several blocks, so the
        # counts differ. Report both accurately.
        logger.info(
            "Injected %d content block(s) from %d queued message(s) into state for thread %s",
            len(content_blocks),
            len(queued_messages),
            thread_id,
        )
        return {"messages": [new_message]}  # noqa: TRY300
    except Exception:
        # Best-effort boundary: never let queue handling break the model call.
        logger.exception("Error in check_message_queue_before_model")
        return None