"""Tests for agent tool types, registry, and meta-tool."""
|
|
|
|
import pytest
|
|
from unittest.mock import patch, MagicMock, AsyncMock
|
|
from dataclasses import dataclass
|
|
|
|
from app.agents.tools.types import ToolResult, RegisteredTool
|
|
from app.agents.tools.registry import ToolRegistry
|
|
from app.agents.tools.meta_tool import MetaTool, SubTool
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# ToolResult
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestToolResult:
    """Tests for ``ToolResult.to_str`` rendering and ``source_urls`` handling."""

    def test_to_str_with_string_data(self):
        # A plain string payload is expected to be returned verbatim.
        result = ToolResult(data="hello world")
        assert result.to_str() == "hello world"

    def test_to_str_with_dict_data(self):
        # Every key and every value (including non-string values) should
        # survive rendering, not just the first pair.
        result = ToolResult(data={"key": "value", "count": 42})
        text = result.to_str()
        assert "key" in text
        assert "value" in text
        assert "count" in text
        assert "42" in text

    def test_to_str_with_list_data(self):
        # All list elements should appear in the rendered text.
        result = ToolResult(data=["a", "b", "c"])
        text = result.to_str()
        for item in ("a", "b", "c"):
            assert item in text

    def test_source_urls_default_empty(self):
        result = ToolResult(data="test")
        assert result.source_urls == []

    def test_source_urls_default_not_shared(self):
        # Guards against a mutable-default bug: two results built without
        # explicit URLs must not alias the same list object.
        first = ToolResult(data="a")
        second = ToolResult(data="b")
        assert first.source_urls is not second.source_urls

    def test_source_urls_preserved(self):
        urls = ["https://example.com/a", "https://example.com/b"]
        result = ToolResult(data="test", source_urls=urls)
        assert result.source_urls == urls
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# RegisteredTool
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestRegisteredTool:
    """Checks that ``RegisteredTool`` exposes exactly what it was built with."""

    def test_creation_and_attributes(self):
        execute_fn = AsyncMock(return_value=ToolResult(data="ok"))
        expected = {
            "name": "test_tool",
            "description": "A test tool for unit tests",
            "compact_description": "test tool",
        }
        registered = RegisteredTool(
            **expected,
            concurrency_safe=True,
            execute=execute_fn,
        )

        for attr, value in expected.items():
            assert getattr(registered, attr) == value
        assert registered.concurrency_safe is True
        # The handler must be stored by identity, not copied or wrapped.
        assert registered.execute is execute_fn

    def test_concurrency_unsafe(self):
        registered = RegisteredTool(
            name="serial_tool",
            description="desc",
            compact_description="serial",
            concurrency_safe=False,
            execute=AsyncMock(return_value=ToolResult(data="ok")),
        )
        assert registered.concurrency_safe is False
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# ToolRegistry
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _make_tool(name: str, concurrency_safe: bool = True) -> RegisteredTool:
    """Build a stub ``RegisteredTool`` for registry tests.

    The descriptions are derived from *name* ("Description of <name>" and
    "compact <name>") so tests can find them by substring. ``execute`` is an
    ``AsyncMock`` that resolves to ``ToolResult(data="<name> result")``.
    """
    canned_result = ToolResult(data=f"{name} result")
    return RegisteredTool(
        name=name,
        description=f"Description of {name}",
        compact_description=f"compact {name}",
        concurrency_safe=concurrency_safe,
        execute=AsyncMock(return_value=canned_result),
    )
|
|
|
|
|
|
class TestToolRegistry:
    """Tests for ``ToolRegistry`` registration, lookup and derived views."""

    def test_register_and_get(self):
        registry = ToolRegistry()
        tool = _make_tool("alpha")
        registry.register(tool)
        # Lookup must return the very object that was registered.
        assert registry.get("alpha") is tool

    def test_get_unknown_returns_none(self):
        registry = ToolRegistry()
        assert registry.get("nonexistent") is None

    def test_list_tools(self):
        registry = ToolRegistry()
        registry.register(_make_tool("a"))
        registry.register(_make_tool("b"))
        names = [t.name for t in registry.list_tools()]
        assert "a" in names
        assert "b" in names
        # Each registered tool should be listed exactly once.
        assert names.count("a") == 1
        assert names.count("b") == 1

    def test_get_concurrency_map(self):
        registry = ToolRegistry()
        registry.register(_make_tool("safe_tool", concurrency_safe=True))
        registry.register(_make_tool("unsafe_tool", concurrency_safe=False))
        cmap = registry.get_concurrency_map()
        assert cmap["safe_tool"] is True
        assert cmap["unsafe_tool"] is False

    def test_build_compact_descriptions(self):
        registry = ToolRegistry()
        registry.register(_make_tool("tool_x"))
        registry.register(_make_tool("tool_y"))
        text = registry.build_compact_descriptions()
        assert "tool_x" in text
        assert "tool_y" in text
        # Both tools' compact descriptions must be rendered, not just the first.
        assert "compact tool_x" in text
        assert "compact tool_y" in text
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# MetaTool
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class ConcreteMetaTool(MetaTool):
    """Concrete ``MetaTool`` used by the tests, with two canned sub-tools.

    The ``search`` handler resolves to
    ``ToolResult(data="search result for query")``. The ``calculate`` handler
    resolves to ``ToolResult(data="42")``. Both handlers are ``AsyncMock``
    instances, so tests can inspect how they were called.
    """

    def __init__(self):
        # NOTE(review): MetaTool.__init__ is not invoked here (same as the
        # original fixture); confirm the base class needs no initialization.
        search_tool = SubTool(
            name="search",
            description="Search the web for information",
            handler=AsyncMock(
                return_value=ToolResult(data="search result for query")
            ),
        )
        calculate_tool = SubTool(
            name="calculate",
            description="Perform mathematical calculations",
            handler=AsyncMock(return_value=ToolResult(data="42")),
        )
        self._sub_tools = [search_tool, calculate_tool]

    @property
    def sub_tools(self) -> list[SubTool]:
        """Return the sub-tools in declaration order: ``search``, then ``calculate``."""
        return self._sub_tools
|
|
|
|
|
|
class TestSubTool:
    """Checks that ``SubTool`` keeps the values it was constructed with."""

    def test_creation(self):
        fake_handler = AsyncMock()
        sub = SubTool(name="test", description="desc", handler=fake_handler)
        assert (sub.name, sub.description) == ("test", "desc")
        # The handler is stored by identity.
        assert sub.handler is fake_handler
|
|
|
|
|
|
class TestMetaTool:
    """Tests for LLM-driven routing and dispatch in ``MetaTool``.

    ``get_chat_model`` is patched in ``app.agents.tools.meta_tool``, so no real
    model is used. The fake model's ``invoke`` returns a response whose
    ``content`` is a JSON string naming the sub-tool to select.
    """

    @staticmethod
    def _route_to(mock_get_model, tool_name: str) -> MagicMock:
        """Make the patched ``get_chat_model`` yield an LLM that picks *tool_name*.

        Returns the fake LLM so a test can inspect its calls if needed.
        """
        mock_response = MagicMock()
        # Produces e.g. '{"tool": "search"}'.
        mock_response.content = f'{{"tool": "{tool_name}"}}'
        mock_llm = MagicMock()
        mock_llm.invoke = MagicMock(return_value=mock_response)
        mock_get_model.return_value = mock_llm
        return mock_llm

    @pytest.mark.asyncio
    @patch("app.agents.tools.meta_tool.get_chat_model")
    async def test_route_selects_correct_subtool(self, mock_get_model):
        self._route_to(mock_get_model, "search")
        meta = ConcreteMetaTool()
        selected = await meta.route("find information about Python")
        assert selected.name == "search"

    @pytest.mark.asyncio
    @patch("app.agents.tools.meta_tool.get_chat_model")
    async def test_route_selects_calculate(self, mock_get_model):
        self._route_to(mock_get_model, "calculate")
        meta = ConcreteMetaTool()
        selected = await meta.route("what is 6 times 7")
        assert selected.name == "calculate"

    @pytest.mark.asyncio
    @patch("app.agents.tools.meta_tool.get_chat_model")
    async def test_execute_calls_routed_subtool(self, mock_get_model):
        self._route_to(mock_get_model, "search")
        meta = ConcreteMetaTool()

        result = await meta.execute({"query": "find Python docs"})

        search_handler = meta.sub_tools[0].handler
        calc_handler = meta.sub_tools[1].handler
        search_handler.assert_awaited_once()
        # Only the routed sub-tool may run.
        calc_handler.assert_not_called()
        assert result.data == "search result for query"

    @pytest.mark.asyncio
    @patch("app.agents.tools.meta_tool.get_chat_model")
    async def test_execute_passes_params_to_subtool(self, mock_get_model):
        self._route_to(mock_get_model, "calculate")
        meta = ConcreteMetaTool()
        params = {"query": "compute 2+2", "extra": "data"}

        result = await meta.execute(params)

        search_handler = meta.sub_tools[0].handler
        calc_handler = meta.sub_tools[1].handler
        # The full params dict is forwarded unchanged to the chosen sub-tool.
        calc_handler.assert_awaited_once_with(params)
        search_handler.assert_not_called()
        assert result.data == "42"
|