2026-03-20 14:38:07 +09:00
|
|
|
"""Before-model middleware that injects queued messages into state.

Checks the MessageStore for pending messages (e.g. follow-up Linear
comments that arrived while the agent was busy) and injects them, combined
into a single new human message, before the next model call.
"""
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
import logging
|
|
|
|
|
from typing import Any
|
|
|
|
|
|
|
|
|
|
import httpx
|
|
|
|
|
from langchain.agents.middleware import AgentState, before_model
|
2026-03-20 18:11:15 +09:00
|
|
|
from langgraph.config import get_config
|
2026-03-20 14:38:07 +09:00
|
|
|
from langgraph.runtime import Runtime
|
|
|
|
|
|
2026-03-20 18:11:15 +09:00
|
|
|
from ..message_store import get_message_store
|
2026-03-20 14:38:07 +09:00
|
|
|
from ..utils.multimodal import fetch_image_block
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LinearNotifyState(AgentState):
    """Agent state schema extended with a Linear notification counter.

    This adds one integer field on top of the base ``AgentState`` keys. The
    queue-checking middleware in this module declares this schema but does
    not itself read or update the counter.
    """

    # Count of Linear messages sent so far. NOTE(review): presumably
    # incremented by another middleware/tool; confirm where it is maintained.
    linear_messages_sent_count: int
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _build_blocks_from_payload(
    payload: dict[str, Any],
) -> list[dict[str, Any]]:
    """Convert a queued ``{"text": ..., "image_urls": [...]}`` payload into content blocks.

    A non-empty ``text`` becomes a single ``{"type": "text", "text": ...}``
    block placed first. Each URL in ``image_urls`` is then fetched with
    ``fetch_image_block`` over one shared ``httpx.AsyncClient``. Every truthy
    result is appended in URL order.

    Image failures are isolated per URL: an exception while fetching one image
    is logged and only that image is skipped. The middleware in this module
    consumes the message from the store before calling this helper. Letting one
    bad URL raise would therefore discard the text and all other images of the
    message.

    Args:
        payload: Queued message content. Both ``text`` and ``image_urls`` are
            optional; a falsy ``image_urls`` means "no images", and a bare
            string is treated as a single URL.

    Returns:
        The content blocks for the payload, possibly empty.
    """
    text = payload.get("text", "")
    image_urls = payload.get("image_urls", []) or []
    if isinstance(image_urls, str):
        # Iterating a bare string would yield one "URL" per character.
        image_urls = [image_urls]

    blocks: list[dict[str, Any]] = []
    if text:
        blocks.append({"type": "text", "text": text})

    if not image_urls:
        # Text-only payload: avoid opening an HTTP client at all.
        return blocks

    async with httpx.AsyncClient() as client:
        for image_url in image_urls:
            try:
                image_block = await fetch_image_block(image_url, client)
            except Exception as e:  # noqa: BLE001
                # Best-effort: skip this image but keep the rest of the message.
                logger.warning("Failed to fetch queued image, skipping it: %s", e)
                continue
            # A falsy result (e.g. None) means the image could not be used.
            if image_block:
                blocks.append(image_block)

    return blocks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _queued_content_to_blocks(content: Any) -> list[Any]:
    """Normalize one queued message's ``content`` into a list of content blocks.

    Supported shapes:
      * dict with ``text`` and/or ``image_urls`` -> built via
        ``_build_blocks_from_payload``; if that raises, fall back to the text
        alone so the (already consumed) message is not lost entirely;
      * list -> used as-is (already a list of content blocks);
      * non-empty str -> a single text block.
    Anything else yields no blocks; non-empty unsupported content is logged.
    """
    if isinstance(content, dict) and ("text" in content or "image_urls" in content):
        logger.debug("Queued message contains text + image URLs")
        try:
            return await _build_blocks_from_payload(content)
        except Exception as e:  # noqa: BLE001
            logger.warning("Failed to build blocks from queued payload: %s", e)
            text = content.get("text")
            return [{"type": "text", "text": text}] if text else []

    if isinstance(content, list):
        logger.debug("Queued message contains %d content block(s)", len(content))
        return list(content)

    if isinstance(content, str) and content:
        logger.debug("Queued message contains text content")
        return [{"type": "text", "text": content}]

    if content:
        # Previously such content was dropped silently; surface it instead.
        logger.warning(
            "Skipping queued message with unsupported content type %s",
            type(content).__name__,
        )
    return []


@before_model(state_schema=LinearNotifyState)
async def check_message_queue_before_model(  # noqa: PLR0911
    state: LinearNotifyState,  # noqa: ARG001
    runtime: Runtime,  # noqa: ARG001
) -> dict[str, Any] | None:
    """Middleware that checks for queued messages before each model call.

    Reads ``thread_id`` from the current LangGraph config and calls
    ``consume_messages(thread_id)`` on the message store. Any returned messages
    are converted into content blocks in the order the store returns them.
    The blocks are then merged into a single new ``user`` message appended to
    the conversation state.

    This lets the agent see follow-up comments that arrived while it was busy
    and incorporate them into its response.

    The middleware is best-effort: a missing thread id, store errors, or any
    unexpected exception are logged and result in ``None`` (no state update).
    Because messages are consumed before conversion, a failure converting one
    message falls back to its text rather than discarding the whole batch.

    Returns:
        ``{"messages": [new_message]}`` when there is something to inject,
        otherwise ``None``.
    """
    try:
        config = get_config()
        configurable = config.get("configurable") or {}
        thread_id = configurable.get("thread_id")
        if not thread_id:
            return None

        try:
            store = await get_message_store()
        except Exception as e:  # noqa: BLE001
            logger.warning("Could not get message store: %s", e)
            return None

        try:
            queued_messages = await store.consume_messages(thread_id)
        except Exception as e:  # noqa: BLE001
            logger.warning("Failed to consume messages: %s", e)
            return None

        if not queued_messages:
            return None

        logger.info(
            "Found %d queued message(s) for thread %s, injecting into state",
            len(queued_messages),
            thread_id,
        )

        content_blocks: list[Any] = []
        for msg in queued_messages:
            content_blocks.extend(await _queued_content_to_blocks(msg.get("content")))

        if not content_blocks:
            return None

        new_message = {
            "role": "user",
            "content": content_blocks,
        }

        # Report the number of queued messages, not the number of blocks.
        logger.info(
            "Injected %d queued message(s) into state for thread %s",
            len(queued_messages),
            thread_id,
        )

        return {"messages": [new_message]}  # noqa: TRY300
    except Exception:
        logger.exception("Error in check_message_queue_before_model")
        return None
|