103 lines
3.7 KiB
Python
103 lines
3.7 KiB
Python
from typing import Any
|
|
from uuid import uuid4
|
|
|
|
from langchain.agents.middleware import AgentState, after_model
|
|
from langchain_core.messages import AnyMessage, ToolMessage
|
|
from langgraph.runtime import Runtime
|
|
|
|
|
|
def get_every_message_since_last_human(state: AgentState) -> list[AnyMessage]:
|
|
messages = state["messages"]
|
|
last_human_idx = -1
|
|
for i in range(len(messages) - 1, -1, -1):
|
|
if messages[i].type == "human":
|
|
last_human_idx = i
|
|
break
|
|
return messages[last_human_idx + 1 :]
|
|
|
|
|
|
def check_if_model_already_called_commit_and_open_pr(messages: list[AnyMessage]) -> bool:
|
|
for msg in messages:
|
|
if msg.type == "tool" and msg.name == "commit_and_open_pr":
|
|
return True
|
|
return False
|
|
|
|
|
|
def check_if_model_messaged_user(messages: list[AnyMessage]) -> bool:
|
|
for msg in messages:
|
|
if msg.type == "tool" and msg.name in [
|
|
"slack_thread_reply",
|
|
"linear_comment",
|
|
"github_comment",
|
|
]:
|
|
return True
|
|
return False
|
|
|
|
|
|
def check_if_confirming_completion(messages: list[AnyMessage]) -> bool:
|
|
for msg in messages:
|
|
if msg.type == "tool" and msg.name == "confirming_completion":
|
|
return True
|
|
return False
|
|
|
|
|
|
def check_if_no_op(messages: list[AnyMessage]) -> bool:
|
|
for msg in messages:
|
|
if msg.type == "tool" and msg.name == "no_op":
|
|
return True
|
|
return False
|
|
|
|
|
|
@after_model
|
|
def ensure_no_empty_msg(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
|
|
last_msg = state["messages"][-1]
|
|
has_contents = bool(last_msg.text())
|
|
has_tool_calls = bool(last_msg.tool_calls)
|
|
if not has_tool_calls and not has_contents:
|
|
messages_since_last_human = get_every_message_since_last_human(state)
|
|
if check_if_no_op(messages_since_last_human):
|
|
return None
|
|
|
|
if check_if_model_already_called_commit_and_open_pr(
|
|
messages_since_last_human
|
|
) and check_if_model_messaged_user(messages_since_last_human):
|
|
return None
|
|
|
|
tc_id = str(uuid4())
|
|
last_msg.tool_calls = [{"name": "no_op", "args": {}, "id": tc_id}]
|
|
no_op_tool_msg = ToolMessage(
|
|
content="No operation performed."
|
|
+ "Please continue with the task, ensuring you ALWAYS call at least one tool in"
|
|
+ " every message unless you are absolutely sure the task has been fully completed.",
|
|
tool_call_id=tc_id,
|
|
)
|
|
|
|
return {"messages": [last_msg, no_op_tool_msg]}
|
|
|
|
if has_contents and not has_tool_calls:
|
|
# See if the model already called open_pr or it sent a slack/linear message
|
|
# First, get every message since the last human message
|
|
messages_since_last_human = get_every_message_since_last_human(state)
|
|
|
|
# If it opened a PR, we don't need to do anything
|
|
if (
|
|
check_if_model_already_called_commit_and_open_pr(messages_since_last_human)
|
|
or check_if_model_messaged_user(messages_since_last_human)
|
|
or check_if_confirming_completion(messages_since_last_human)
|
|
):
|
|
return None
|
|
|
|
tc_id = str(uuid4())
|
|
last_msg.tool_calls = [{"name": "confirming_completion", "args": {}, "id": tc_id}]
|
|
no_op_tool_msg = ToolMessage(
|
|
content="Confirming task completion. I see you did not call a tool, which would end the task, however you haven't called a tool to message the user or open a pull request."
|
|
+ "This may indicate premature termination - please ensure you fully complete the task before ending it. "
|
|
+ "If you do not call any tools it will end the task.",
|
|
name="confirming_completion",
|
|
tool_call_id=tc_id,
|
|
)
|
|
|
|
return {"messages": [last_msg, no_op_tool_msg]}
|
|
|
|
return None
|