380 lines
13 KiB
Python
380 lines
13 KiB
Python
|
|
"""Tests for agent core: scratchpad, context compactor, tool executor, and agent."""
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
from unittest.mock import patch, MagicMock, AsyncMock
|
||
|
|
|
||
|
|
from app.agents.tools.types import ToolResult, RegisteredTool
|
||
|
|
from app.agents.tools.registry import ToolRegistry
|
||
|
|
from app.agents.core.scratchpad import Scratchpad
|
||
|
|
from app.agents.core.compact import ContextCompactor
|
||
|
|
from app.agents.core.tool_executor import ToolExecutor, ToolCall, ToolEvent
|
||
|
|
from app.agents.core.agent import Agent, AgentEvent
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# Scratchpad
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
|
||
|
|
class TestScratchpad:
    """Tests for :class:`Scratchpad`, which records tool invocations (name, params, result)."""

    def test_add_and_total_tokens(self):
        """Recording entries produces a positive token count."""
        pad = Scratchpad()
        entries = [
            ("search", {"q": "test"}, "found 3 results"),
            ("calc", {"expr": "2+2"}, "4"),
        ]
        for tool, args, output in entries:
            pad.add(tool_name=tool, params=args, result=output)
        assert pad.total_tokens() > 0

    def test_format_history(self):
        """The rendered history mentions the tool name and the recorded data."""
        pad = Scratchpad()
        pad.add(tool_name="search", params={"q": "python"}, result="python docs")
        rendered = pad.format_history()
        for needle in ("search", "python"):
            assert needle in rendered

    def test_format_history_empty(self):
        """An empty scratchpad still renders to a ``str``."""
        assert isinstance(Scratchpad().format_history(), str)

    def test_clear(self):
        """``clear()`` drops every entry, bringing the token count back to zero."""
        pad = Scratchpad()
        for idx, label in enumerate(("one", "two"), start=1):
            pad.add(tool_name=f"tool{idx}", params={}, result=f"result_{label}")
        assert pad.total_tokens() > 0
        pad.clear()
        assert pad.total_tokens() == 0

    def test_multiple_entries_tracked(self):
        """Every one of several recorded tools appears in the rendered history."""
        pad = Scratchpad()
        names = [f"tool_{n}" for n in range(5)]
        for n, name in enumerate(names):
            pad.add(tool_name=name, params={"i": n}, result=f"result_{n}")
        rendered = pad.format_history()
        missing = [name for name in names if name not in rendered]
        assert not missing
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# ContextCompactor
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
|
||
|
|
class TestContextCompactor:
    """Tests for :class:`ContextCompactor` threshold checks and LLM-backed compaction."""

    @staticmethod
    def _install_llm(mock_get_model, summary):
        """Make the patched ``get_chat_model`` return a stub whose ``ainvoke`` replies with *summary*."""
        reply = MagicMock()
        reply.content = summary
        llm = MagicMock()
        llm.ainvoke = AsyncMock(return_value=reply)
        mock_get_model.return_value = llm
        return llm

    def test_should_compact_false_below_threshold(self):
        history = [{"role": "user", "content": "short message"}]
        assert ContextCompactor().should_compact(history, threshold=10000) is False

    def test_should_compact_true_above_threshold(self):
        filler = "x" * 5000
        history = [
            {"role": role, "content": filler}
            for role in ("user", "assistant", "user")
        ]
        assert ContextCompactor().should_compact(history, threshold=100) is True

    @pytest.mark.asyncio
    @patch("app.agents.core.compact.get_chat_model")
    async def test_compact_calls_llm(self, mock_get_model):
        llm = self._install_llm(
            mock_get_model,
            "Summary: user asked about Python, assistant explained basics.",
        )
        turns = [
            ("user", "Tell me about Python"),
            ("assistant", "Python is a programming language..."),
            ("user", "What about async?"),
            ("assistant", "Async in Python uses asyncio..."),
        ]
        history = [{"role": role, "content": text} for role, text in turns]

        compacted = await ContextCompactor().compact(history, system_prompt="You are helpful.")

        assert isinstance(compacted, list)
        assert compacted
        llm.ainvoke.assert_called_once()

    @pytest.mark.asyncio
    @patch("app.agents.core.compact.get_chat_model")
    async def test_compact_preserves_system_context(self, mock_get_model):
        self._install_llm(mock_get_model, "Summarized conversation.")
        history = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ]

        compacted = await ContextCompactor().compact(history, system_prompt="System prompt here")

        # NOTE(review): despite its name, this test only checks the return type;
        # it does not verify that the system prompt is carried into the output.
        assert isinstance(compacted, list)
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# ToolExecutor
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
|
||
|
|
def _make_registered_tool(
    name: str, concurrency_safe: bool = True, result_data: str = "ok"
) -> RegisteredTool:
    """Build a ``RegisteredTool`` test double whose ``execute`` is an ``AsyncMock``.

    Args:
        name: Tool name; also used to derive both description strings.
        concurrency_safe: Forwarded unchanged to ``RegisteredTool``.
        result_data: ``data`` of the ``ToolResult`` every ``execute`` await returns.

    Returns:
        A ``RegisteredTool`` whose ``execute`` mock can be inspected for calls.
    """
    fake_execute = AsyncMock(return_value=ToolResult(data=result_data))
    fields = {
        "name": name,
        "description": f"Desc {name}",
        "compact_description": f"compact {name}",
        "concurrency_safe": concurrency_safe,
        "execute": fake_execute,
    }
    return RegisteredTool(**fields)
|
||
|
|
|
||
|
|
|
||
|
|
class TestToolExecutor:
    """Tests for :class:`ToolExecutor` dispatching ``ToolCall`` batches to registered tools."""

    @staticmethod
    async def _collect(executor, calls):
        """Drain ``executor.execute_tool_calls(calls)`` into a list of events."""
        return [event async for event in executor.execute_tool_calls(calls)]

    @pytest.mark.asyncio
    async def test_concurrent_safe_tools(self):
        tool_a = _make_registered_tool("tool_a", concurrency_safe=True, result_data="A")
        tool_b = _make_registered_tool("tool_b", concurrency_safe=True, result_data="B")

        registry = ToolRegistry()
        registry.register(tool_a)
        registry.register(tool_b)

        executor = ToolExecutor(registry=registry)
        calls = [
            ToolCall(id="c1", name="tool_a", params={"q": "hello"}),
            ToolCall(id="c2", name="tool_b", params={"q": "world"}),
        ]

        events = await self._collect(executor, calls)

        assert len(events) == 2
        assert {e.tool_name for e in events} == {"tool_a", "tool_b"}
        # Each event must carry its own tool's result, not a neighbour's.
        assert {e.tool_name: e.result.data for e in events} == {"tool_a": "A", "tool_b": "B"}

        tool_a.execute.assert_called_once()
        tool_b.execute.assert_called_once()

    @pytest.mark.asyncio
    async def test_serial_tools_run_sequentially(self):
        import asyncio  # local import: only this test needs to yield to the event loop

        in_flight = 0
        peak_in_flight = 0

        async def _tracked_execute(*args, **kwargs):
            # Count overlapping executions. sleep(0) yields control, so a
            # scheduler running these tools concurrently would start the next
            # one before this one finishes and push the peak above 1.
            nonlocal in_flight, peak_in_flight
            in_flight += 1
            peak_in_flight = max(peak_in_flight, in_flight)
            await asyncio.sleep(0)
            in_flight -= 1
            return ToolResult(data="S")

        serial_tools = [
            _make_registered_tool(name, concurrency_safe=False, result_data="S")
            for name in ("serial_one", "serial_two")
        ]
        registry = ToolRegistry()
        for tool in serial_tools:
            tool.execute.side_effect = _tracked_execute
            registry.register(tool)

        executor = ToolExecutor(registry=registry)
        calls = [
            ToolCall(id="c1", name="serial_one", params={"x": 1}),
            ToolCall(id="c2", name="serial_two", params={"x": 2}),
        ]

        events = await self._collect(executor, calls)

        assert len(events) == 2
        assert {e.tool_name for e in events} == {"serial_one", "serial_two"}
        assert all(e.error is None for e in events)
        assert all(e.result.data == "S" for e in events)
        assert peak_in_flight == 1, "non-concurrency-safe tools were executed concurrently"
        for tool in serial_tools:
            tool.execute.assert_called_once()

    @pytest.mark.asyncio
    async def test_error_handling_for_failed_tool(self):
        failing_tool = RegisteredTool(
            name="fail_tool",
            description="A tool that fails",
            compact_description="fail",
            concurrency_safe=True,
            execute=AsyncMock(side_effect=RuntimeError("Tool exploded")),
        )

        registry = ToolRegistry()
        registry.register(failing_tool)

        executor = ToolExecutor(registry=registry)
        calls = [ToolCall(id="c1", name="fail_tool", params={})]

        events = await self._collect(executor, calls)

        assert len(events) == 1
        error = events[0].error
        assert error is not None
        # The raised exception's message should surface in the event's error text.
        assert "exploded" in error.lower()

    @pytest.mark.asyncio
    async def test_mixed_concurrent_and_serial(self):
        safe = _make_registered_tool("safe", concurrency_safe=True, result_data="safe_r")
        unsafe = _make_registered_tool("unsafe", concurrency_safe=False, result_data="unsafe_r")

        registry = ToolRegistry()
        registry.register(safe)
        registry.register(unsafe)

        executor = ToolExecutor(registry=registry)
        calls = [
            ToolCall(id="c1", name="safe", params={}),
            ToolCall(id="c2", name="unsafe", params={}),
        ]

        events = await self._collect(executor, calls)

        assert len(events) == 2
        results = {e.tool_name: e.result for e in events}
        assert results["safe"].data == "safe_r"
        assert results["unsafe"].data == "unsafe_r"

    @pytest.mark.asyncio
    async def test_unknown_tool_returns_error(self):
        executor = ToolExecutor(registry=ToolRegistry())
        calls = [ToolCall(id="c1", name="nonexistent", params={})]

        events = await self._collect(executor, calls)

        assert len(events) == 1
        assert events[0].error is not None
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# ToolCall & ToolEvent dataclass tests
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
|
||
|
|
class TestToolCallAndEvent:
    """Field round-trip tests for the ``ToolCall`` and ``ToolEvent`` records."""

    def test_tool_call_creation(self):
        call = ToolCall(id="abc", name="my_tool", params={"key": "val"})
        assert (call.id, call.name) == ("abc", "my_tool")
        assert call.params == {"key": "val"}

    def test_tool_event_success(self):
        """A successful event keeps its result and has no error."""
        ev = ToolEvent(call_id="c1", tool_name="t1", result=ToolResult(data="success"), error=None)
        assert ev.result.data == "success"
        assert ev.error is None

    def test_tool_event_error(self):
        """A failed event keeps its error message and has no result."""
        message = "something went wrong"
        ev = ToolEvent(call_id="c1", tool_name="t1", result=None, error=message)
        assert ev.result is None
        assert ev.error == "something went wrong"
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# Agent
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
|
||
|
|
class TestAgent:
    """Tests for ``Agent.run`` driven by a stub model behind a patched ``get_chat_model``."""

    # Upper bound on events drained from ``Agent.run`` in the iteration-limit
    # test, so a broken iteration cap fails that test instead of hanging it.
    _EVENT_CAP = 200

    @staticmethod
    def _reply(content, tool_calls=None):
        """Build a stub model reply exposing ``content`` and ``tool_calls``."""
        reply = MagicMock()
        reply.content = content
        reply.tool_calls = [] if tool_calls is None else tool_calls
        return reply

    @staticmethod
    def _install_llm(mock_get_model, **ainvoke_kwargs):
        """Point the patched ``get_chat_model`` at a stub LLM and return it.

        ``ainvoke_kwargs`` (``return_value=`` or ``side_effect=``) configure the
        stub's ``AsyncMock`` ``ainvoke``. ``bind_tools`` returns the stub itself,
        so tool-bound calls reach the same ``ainvoke``.
        """
        llm = MagicMock()
        llm.ainvoke = AsyncMock(**ainvoke_kwargs)
        llm.bind_tools = MagicMock(return_value=llm)
        mock_get_model.return_value = llm
        return llm

    @pytest.mark.asyncio
    @patch("app.agents.core.agent.get_chat_model")
    async def test_run_yields_events(self, mock_get_model):
        self._install_llm(
            mock_get_model,
            return_value=self._reply("Here is the answer to your question."),
        )
        agent = Agent(tool_registry=ToolRegistry())

        events = [event async for event in agent.run("What is 2+2?")]

        assert len(events) >= 1
        event_types = [e.type for e in events]
        assert any(t in event_types for t in ["message", "answer", "response", "text"])

    @pytest.mark.asyncio
    @patch("app.agents.core.agent.get_chat_model")
    async def test_run_with_tool_call(self, mock_get_model):
        search_call = {"id": "call_1", "name": "search", "args": {"q": "python"}}
        self._install_llm(
            mock_get_model,
            side_effect=[
                self._reply("", [search_call]),
                self._reply("Python is a programming language."),
            ],
        )

        search_tool = _make_registered_tool("search", result_data="Python docs found")
        registry = ToolRegistry()
        registry.register(search_tool)
        agent = Agent(tool_registry=registry)

        events = [event async for event in agent.run("Tell me about Python")]

        assert len(events) >= 1
        # The first model reply requests ``search``; the agent must actually run it.
        search_tool.execute.assert_called_once()

    @pytest.mark.asyncio
    @patch("app.agents.core.agent.get_chat_model")
    async def test_max_iterations_limit(self, mock_get_model):
        max_iterations = 3
        # Every model reply requests ``loop_tool`` again, so only the agent's
        # iteration cap can end the run.
        self._install_llm(
            mock_get_model,
            return_value=self._reply("", [{"id": "call_n", "name": "loop_tool", "args": {}}]),
        )

        loop_tool = _make_registered_tool("loop_tool", result_data="looping")
        registry = ToolRegistry()
        registry.register(loop_tool)
        agent = Agent(tool_registry=registry, max_iterations=max_iterations)

        events = []
        async for event in agent.run("infinite loop query"):
            events.append(event)
            # Stop draining once the cap has clearly been exceeded, so the
            # assertions below fail instead of the test looping forever.
            if len(events) > self._EVENT_CAP or loop_tool.execute.call_count > max_iterations:
                break

        # The tool must run at least once (the model asks for it every turn)
        # and no more than the configured number of iterations.
        assert 1 <= loop_tool.execute.call_count <= max_iterations
        assert len(events) <= self._EVENT_CAP

    @pytest.mark.asyncio
    @patch("app.agents.core.agent.get_chat_model")
    async def test_run_empty_query(self, mock_get_model):
        self._install_llm(
            mock_get_model,
            return_value=self._reply("I need more information."),
        )
        agent = Agent(tool_registry=ToolRegistry())

        events = [event async for event in agent.run("")]

        assert len(events) >= 1
|
||
|
|
|
||
|
|
|
||
|
|
class TestAgentEvent:
    """Construction tests for the ``AgentEvent`` record."""

    def test_creation(self):
        """Dict payloads are stored as given."""
        payload = {"text": "hello"}
        ev = AgentEvent(type="message", data=payload)
        assert ev.type == "message"
        assert ev.data == {"text": "hello"}

    def test_creation_with_string_data(self):
        """Plain string payloads are accepted too."""
        ev = AgentEvent(type="error", data="something failed")
        assert (ev.type, ev.data) == ("error", "something failed")
|