105 lines
3.1 KiB
Python
105 lines
3.1 KiB
Python
|
|
"""Tool error handling middleware.
|
||
|
|
|
||
|
|
Wraps all tool calls in try/except so that unhandled exceptions are
|
||
|
|
returned as error ToolMessages instead of crashing the agent run.
|
||
|
|
"""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import json
|
||
|
|
import logging
|
||
|
|
from collections.abc import Awaitable, Callable
|
||
|
|
|
||
|
|
from langchain.agents.middleware.types import (
|
||
|
|
AgentMiddleware,
|
||
|
|
AgentState,
|
||
|
|
)
|
||
|
|
from langchain_core.messages import ToolMessage
|
||
|
|
from langgraph.prebuilt.tool_node import ToolCallRequest
|
||
|
|
from langgraph.types import Command
|
||
|
|
|
||
|
|
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
def _get_name(candidate: object) -> str | None:
    """Return a non-empty name derived from ``candidate``, or ``None``.

    A truthy string is returned unchanged. For a dict the ``"name"`` key is
    looked up; for any other object the ``name`` attribute is read. Falsy
    candidates, and looked-up names that are not non-empty strings, give
    ``None``.
    """
    if not candidate:
        return None
    if isinstance(candidate, str):
        return candidate

    found = (
        candidate.get("name")
        if isinstance(candidate, dict)
        else getattr(candidate, "name", None)
    )
    # Only a non-empty string counts as a usable name.
    if not (isinstance(found, str) and found):
        return None
    return found
|
||
|
|
|
||
|
|
|
||
|
|
def _extract_tool_name(request: ToolCallRequest | None) -> str | None:
    """Return the first usable tool name found on ``request``, or ``None``.

    The ``tool_call``, ``tool_name`` and ``name`` attributes are probed in
    that order; each value is passed through ``_get_name`` and the first
    non-empty result wins. A ``None`` request yields ``None``.
    """
    if request is None:
        return None

    # Generators keep the probing lazy: later attributes are only read when
    # the earlier ones did not produce a name.
    candidates = (
        getattr(request, field, None)
        for field in ("tool_call", "tool_name", "name")
    )
    return next((found for found in map(_get_name, candidates) if found), None)
|
||
|
|
|
||
|
|
|
||
|
|
def _to_error_payload(e: Exception, request: ToolCallRequest | None = None) -> dict[str, str]:
    """Build the error payload dict describing exception ``e``.

    The result always contains ``"error"`` (``str(e)``), ``"error_type"``
    (the exception class name) and ``"status"`` (``"error"``). A ``"name"``
    entry is appended when ``_extract_tool_name`` finds a tool name on
    ``request``.
    """
    base: dict[str, str] = {
        "error": str(e),
        "error_type": e.__class__.__name__,
        "status": "error",
    }
    tool_name = _extract_tool_name(request)
    if not tool_name:
        return base
    return {**base, "name": tool_name}
|
||
|
|
|
||
|
|
|
||
|
|
def _get_tool_call_id(request: ToolCallRequest) -> str | None:
    """Return the ``"id"`` entry of ``request.tool_call``.

    Gives ``None`` when ``tool_call`` is not a dict or has no ``"id"`` key.
    """
    tool_call = request.tool_call
    return tool_call.get("id") if isinstance(tool_call, dict) else None
|
||
|
|
|
||
|
|
|
||
|
|
class ToolErrorMiddleware(AgentMiddleware):
    """Normalize tool execution errors into predictable payloads.

    Catches any exception thrown during a tool call and converts it into
    a ToolMessage with status="error" so the LLM can see the failure and
    self-correct, rather than crashing the entire agent run.
    """

    state_schema = AgentState

    @staticmethod
    def _to_error_message(e: Exception, request: ToolCallRequest) -> ToolMessage:
        """Log ``e`` and convert it into an error ``ToolMessage``.

        Shared by the sync and async wrappers so both produce exactly the
        same log line and message.

        Args:
            e: The exception raised while handling the tool call.
            request: The tool call request that was being handled.

        Returns:
            A ``ToolMessage`` with ``status="error"`` whose content is the
            JSON-encoded payload from ``_to_error_payload``.
        """
        # Pass the exception explicitly so the traceback is logged even though
        # this helper is not itself the ``except`` block.
        logger.exception(
            "Error during tool call handling; request=%r", request, exc_info=e
        )
        data = _to_error_payload(e, request)
        return ToolMessage(
            content=json.dumps(data),
            # ToolMessage needs a string tool_call_id; fall back to "" when the
            # request carries no id so that building the error message cannot
            # itself raise and crash the run this middleware is protecting.
            tool_call_id=_get_tool_call_id(request) or "",
            status="error",
        )

    def wrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], ToolMessage | Command],
    ) -> ToolMessage | Command:
        """Run ``handler(request)``, turning any exception into an error message.

        Args:
            request: The tool call request to execute.
            handler: Callable that performs the tool call.

        Returns:
            The handler's result, or an error ``ToolMessage`` if it raised.
        """
        try:
            return handler(request)
        except Exception as e:
            return self._to_error_message(e, request)

    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
    ) -> ToolMessage | Command:
        """Await ``handler(request)``, turning any exception into an error message.

        Args:
            request: The tool call request to execute.
            handler: Async callable that performs the tool call.

        Returns:
            The handler's result, or an error ``ToolMessage`` if it raised.
        """
        try:
            return await handler(request)
        except Exception as e:
            return self._to_error_message(e, request)
|