galaxis-agent/agent/middleware/ensure_no_empty_msg.py
2026-03-20 14:38:07 +09:00

103 lines
3.7 KiB
Python

from typing import Any
from uuid import uuid4
from langchain.agents.middleware import AgentState, after_model
from langchain_core.messages import AnyMessage, ToolMessage
from langgraph.runtime import Runtime
def get_every_message_since_last_human(state: AgentState) -> list[AnyMessage]:
    """Return the slice of ``state["messages"]`` after the newest human message.

    The human message itself is excluded. When the history contains no human
    message at all, the whole message list is returned.
    """
    history = state["messages"]
    cut = 0
    # Walk backwards; ``offset`` counts from the end, so the matching
    # forward index is ``len(history) - 1 - offset`` and we slice just past it.
    for offset, message in enumerate(reversed(history)):
        if message.type == "human":
            cut = len(history) - offset
            break
    return history[cut:]
def check_if_model_already_called_commit_and_open_pr(messages: list[AnyMessage]) -> bool:
    """Report whether any tool result in ``messages`` came from ``commit_and_open_pr``."""
    return any(
        m.type == "tool" and m.name == "commit_and_open_pr" for m in messages
    )
def check_if_model_messaged_user(messages: list[AnyMessage]) -> bool:
    """Report whether ``messages`` holds a result from a user-facing messaging tool.

    The recognised tools are ``slack_thread_reply``, ``linear_comment`` and
    ``github_comment``.
    """
    user_facing_tools = (
        "slack_thread_reply",
        "linear_comment",
        "github_comment",
    )
    return any(
        m.type == "tool" and m.name in user_facing_tools for m in messages
    )
def check_if_confirming_completion(messages: list[AnyMessage]) -> bool:
    """Report whether a ``confirming_completion`` tool result is present in ``messages``."""
    return any(
        m.name == "confirming_completion" for m in messages if m.type == "tool"
    )
def check_if_no_op(messages: list[AnyMessage]) -> bool:
    """Report whether a tool result named ``no_op`` is present in ``messages``."""
    return any(m.name == "no_op" for m in messages if m.type == "tool")
_NO_OP_CONTENT = (
    "No operation performed. "
    "Please continue with the task, ensuring you ALWAYS call at least one tool in"
    " every message unless you are absolutely sure the task has been fully completed."
)

_CONFIRM_COMPLETION_CONTENT = (
    "Confirming task completion. I see you did not call a tool, which would end the task, "
    "however you haven't called a tool to message the user or open a pull request. "
    "This may indicate premature termination - please ensure you fully complete the task "
    "before ending it. "
    "If you do not call any tools it will end the task."
)


def _attach_synthetic_tool_call(
    last_msg: AnyMessage, tool_name: str, content: str
) -> dict[str, Any]:
    """Attach a fake ``tool_name`` call to ``last_msg`` and pair it with a result.

    ``last_msg.tool_calls`` is overwritten in place with a single call to
    ``tool_name`` under a fresh uuid4 id. A ``ToolMessage`` with the same id
    and ``name=tool_name`` is built to answer it. Setting ``name`` matters:
    the ``check_if_*`` helpers identify tool results by ``msg.name``.

    Returns:
        A state update containing the modified message followed by the
        synthetic tool result.
    """
    tc_id = str(uuid4())
    last_msg.tool_calls = [{"name": tool_name, "args": {}, "id": tc_id}]
    tool_msg = ToolMessage(content=content, name=tool_name, tool_call_id=tc_id)
    return {"messages": [last_msg, tool_msg]}


@after_model
def ensure_no_empty_msg(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    """Nudge the model when it ends a turn without calling any tool.

    Inspects the newest message after each model call:

    * It has tool calls: nothing is changed.
    * It has neither text nor tool calls: a synthetic ``no_op`` call/result
      pair is appended, asking the model to continue. This is skipped if a
      ``no_op`` result already exists since the last human message. It is
      also skipped if the model both called ``commit_and_open_pr`` and
      messaged the user.
    * It has text but no tool calls: a synthetic ``confirming_completion``
      pair is appended, warning about premature termination. This is skipped
      if, since the last human message, the model opened a PR, messaged the
      user, or already received a ``confirming_completion`` result.

    Args:
        state: Agent state; only ``state["messages"]`` is read.
        runtime: Not used by this hook.

    Returns:
        ``None`` to leave the state unchanged. Otherwise a
        ``{"messages": [...]}`` update holding the in-place-modified last
        message plus the synthetic tool result.
    """
    last_msg = state["messages"][-1]
    # A message without a ``tool_calls`` attribute cannot receive a synthetic
    # call, so leave it alone instead of raising AttributeError below.
    if not hasattr(last_msg, "tool_calls") or last_msg.tool_calls:
        return None

    recent = get_every_message_since_last_human(state)

    if not last_msg.text():
        # Empty turn. The no_op nudge is sent at most once per human turn:
        # the injected ToolMessage is named "no_op", so check_if_no_op
        # detects it on the next empty turn.
        if check_if_no_op(recent):
            return None
        if check_if_model_already_called_commit_and_open_pr(
            recent
        ) and check_if_model_messaged_user(recent):
            return None
        return _attach_synthetic_tool_call(last_msg, "no_op", _NO_OP_CONTENT)

    # Text-only turn: skip the nudge if the model already opened a PR,
    # messaged the user, or was already asked to confirm completion.
    if (
        check_if_model_already_called_commit_and_open_pr(recent)
        or check_if_model_messaged_user(recent)
        or check_if_confirming_completion(recent)
    ):
        return None
    return _attach_synthetic_tool_call(
        last_msg, "confirming_completion", _CONFIRM_COMPLETION_CONTENT
    )