"""Tests for agent core: scratchpad, context compactor, tool executor, and agent.""" import pytest from unittest.mock import patch, MagicMock, AsyncMock from app.agents.tools.types import ToolResult, RegisteredTool from app.agents.tools.registry import ToolRegistry from app.agents.core.scratchpad import Scratchpad from app.agents.core.compact import ContextCompactor from app.agents.core.tool_executor import ToolExecutor, ToolCall, ToolEvent from app.agents.core.agent import Agent, AgentEvent # --------------------------------------------------------------------------- # Scratchpad # --------------------------------------------------------------------------- class TestScratchpad: def test_add_and_total_tokens(self): sp = Scratchpad() sp.add(tool_name="search", params={"q": "test"}, result="found 3 results") sp.add(tool_name="calc", params={"expr": "2+2"}, result="4") assert sp.total_tokens() > 0 def test_format_history(self): sp = Scratchpad() sp.add(tool_name="search", params={"q": "python"}, result="python docs") history = sp.format_history() assert "search" in history assert "python" in history def test_format_history_empty(self): sp = Scratchpad() history = sp.format_history() assert isinstance(history, str) def test_clear(self): sp = Scratchpad() sp.add(tool_name="tool1", params={}, result="result_one") sp.add(tool_name="tool2", params={}, result="result_two") assert sp.total_tokens() > 0 sp.clear() assert sp.total_tokens() == 0 def test_multiple_entries_tracked(self): sp = Scratchpad() for i in range(5): sp.add(tool_name=f"tool_{i}", params={"i": i}, result=f"result_{i}") history = sp.format_history() for i in range(5): assert f"tool_{i}" in history # --------------------------------------------------------------------------- # ContextCompactor # --------------------------------------------------------------------------- class TestContextCompactor: def test_should_compact_false_below_threshold(self): compactor = ContextCompactor() messages = [{"role": "user", "content": "short message"}] assert compactor.should_compact(messages, threshold=10000) is False def test_should_compact_true_above_threshold(self): compactor = ContextCompactor() long_content = "x" * 5000 messages = [ {"role": "user", "content": long_content}, {"role": "assistant", "content": long_content}, {"role": "user", "content": long_content}, ] assert compactor.should_compact(messages, threshold=100) is True @pytest.mark.asyncio @patch("app.agents.core.compact.get_chat_model") async def test_compact_calls_llm(self, mock_get_model): mock_llm = MagicMock() mock_response = MagicMock() mock_response.content = "Summary: user asked about Python, assistant explained basics." mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_get_model.return_value = mock_llm compactor = ContextCompactor() messages = [ {"role": "user", "content": "Tell me about Python"}, {"role": "assistant", "content": "Python is a programming language..."}, {"role": "user", "content": "What about async?"}, {"role": "assistant", "content": "Async in Python uses asyncio..."}, ] result = await compactor.compact(messages, system_prompt="You are helpful.") assert isinstance(result, list) assert len(result) > 0 mock_llm.ainvoke.assert_called_once() @pytest.mark.asyncio @patch("app.agents.core.compact.get_chat_model") async def test_compact_preserves_system_context(self, mock_get_model): mock_llm = MagicMock() mock_response = MagicMock() mock_response.content = "Summarized conversation." mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_get_model.return_value = mock_llm compactor = ContextCompactor() messages = [ {"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi there!"}, ] result = await compactor.compact(messages, system_prompt="System prompt here") assert isinstance(result, list) # --------------------------------------------------------------------------- # ToolExecutor # --------------------------------------------------------------------------- def _make_registered_tool( name: str, concurrency_safe: bool = True, result_data: str = "ok" ) -> RegisteredTool: return RegisteredTool( name=name, description=f"Desc {name}", compact_description=f"compact {name}", concurrency_safe=concurrency_safe, execute=AsyncMock(return_value=ToolResult(data=result_data)), ) class TestToolExecutor: @pytest.mark.asyncio async def test_concurrent_safe_tools(self): tool_a = _make_registered_tool("tool_a", concurrency_safe=True, result_data="A") tool_b = _make_registered_tool("tool_b", concurrency_safe=True, result_data="B") registry = ToolRegistry() registry.register(tool_a) registry.register(tool_b) executor = ToolExecutor(registry=registry) calls = [ ToolCall(id="c1", name="tool_a", params={"q": "hello"}), ToolCall(id="c2", name="tool_b", params={"q": "world"}), ] events = [] async for event in executor.execute_tool_calls(calls): events.append(event) assert len(events) == 2 names = {e.tool_name for e in events} assert "tool_a" in names assert "tool_b" in names tool_a.execute.assert_called_once() tool_b.execute.assert_called_once() @pytest.mark.asyncio async def test_serial_tools_run_sequentially(self): tool_s = _make_registered_tool("serial_tool", concurrency_safe=False, result_data="S") registry = ToolRegistry() registry.register(tool_s) executor = ToolExecutor(registry=registry) calls = [ ToolCall(id="c1", name="serial_tool", params={"x": 1}), ] events = [] async for event in executor.execute_tool_calls(calls): events.append(event) assert len(events) == 1 assert events[0].tool_name == "serial_tool" assert events[0].error is None @pytest.mark.asyncio async def test_error_handling_for_failed_tool(self): failing_tool = RegisteredTool( name="fail_tool", description="A tool that fails", compact_description="fail", concurrency_safe=True, execute=AsyncMock(side_effect=RuntimeError("Tool exploded")), ) registry = ToolRegistry() registry.register(failing_tool) executor = ToolExecutor(registry=registry) calls = [ToolCall(id="c1", name="fail_tool", params={})] events = [] async for event in executor.execute_tool_calls(calls): events.append(event) assert len(events) == 1 assert events[0].error is not None assert "exploded" in events[0].error.lower() or events[0].error @pytest.mark.asyncio async def test_mixed_concurrent_and_serial(self): safe = _make_registered_tool("safe", concurrency_safe=True, result_data="safe_r") unsafe = _make_registered_tool("unsafe", concurrency_safe=False, result_data="unsafe_r") registry = ToolRegistry() registry.register(safe) registry.register(unsafe) executor = ToolExecutor(registry=registry) calls = [ ToolCall(id="c1", name="safe", params={}), ToolCall(id="c2", name="unsafe", params={}), ] events = [] async for event in executor.execute_tool_calls(calls): events.append(event) assert len(events) == 2 results = {e.tool_name: e.result for e in events} assert results["safe"].data == "safe_r" assert results["unsafe"].data == "unsafe_r" @pytest.mark.asyncio async def test_unknown_tool_returns_error(self): registry = ToolRegistry() executor = ToolExecutor(registry=registry) calls = [ToolCall(id="c1", name="nonexistent", params={})] events = [] async for event in executor.execute_tool_calls(calls): events.append(event) assert len(events) == 1 assert events[0].error is not None # --------------------------------------------------------------------------- # ToolCall & ToolEvent dataclass tests # --------------------------------------------------------------------------- class TestToolCallAndEvent: def test_tool_call_creation(self): tc = ToolCall(id="abc", name="my_tool", params={"key": "val"}) assert tc.id == "abc" assert tc.name == "my_tool" assert tc.params == {"key": "val"} def test_tool_event_success(self): result = ToolResult(data="success") ev = ToolEvent(call_id="c1", tool_name="t1", result=result, error=None) assert ev.result.data == "success" assert ev.error is None def test_tool_event_error(self): ev = ToolEvent(call_id="c1", tool_name="t1", result=None, error="something went wrong") assert ev.result is None assert ev.error == "something went wrong" # --------------------------------------------------------------------------- # Agent # --------------------------------------------------------------------------- class TestAgent: @pytest.mark.asyncio @patch("app.agents.core.agent.get_chat_model") async def test_run_yields_events(self, mock_get_model): mock_llm = MagicMock() mock_response = MagicMock() mock_response.content = "Here is the answer to your question." mock_response.tool_calls = [] mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_llm.bind_tools = MagicMock(return_value=mock_llm) mock_get_model.return_value = mock_llm registry = ToolRegistry() agent = Agent(tool_registry=registry) events = [] async for event in agent.run("What is 2+2?"): events.append(event) assert len(events) >= 1 event_types = [e.type for e in events] assert any(t in event_types for t in ["message", "answer", "response", "text"]) @pytest.mark.asyncio @patch("app.agents.core.agent.get_chat_model") async def test_run_with_tool_call(self, mock_get_model): tool_call_response = MagicMock() tool_call_response.content = "" tool_call_response.tool_calls = [ {"id": "call_1", "name": "search", "args": {"q": "python"}} ] final_response = MagicMock() final_response.content = "Python is a programming language." final_response.tool_calls = [] mock_llm = MagicMock() mock_llm.ainvoke = AsyncMock(side_effect=[tool_call_response, final_response]) mock_llm.bind_tools = MagicMock(return_value=mock_llm) mock_get_model.return_value = mock_llm search_tool = _make_registered_tool("search", result_data="Python docs found") registry = ToolRegistry() registry.register(search_tool) agent = Agent(tool_registry=registry) events = [] async for event in agent.run("Tell me about Python"): events.append(event) assert len(events) >= 1 @pytest.mark.asyncio @patch("app.agents.core.agent.get_chat_model") async def test_max_iterations_limit(self, mock_get_model): tool_call_response = MagicMock() tool_call_response.content = "" tool_call_response.tool_calls = [ {"id": "call_n", "name": "loop_tool", "args": {}} ] mock_llm = MagicMock() mock_llm.ainvoke = AsyncMock(return_value=tool_call_response) mock_llm.bind_tools = MagicMock(return_value=mock_llm) mock_get_model.return_value = mock_llm loop_tool = _make_registered_tool("loop_tool", result_data="looping") registry = ToolRegistry() registry.register(loop_tool) agent = Agent(tool_registry=registry, max_iterations=3) events = [] async for event in agent.run("infinite loop query"): events.append(event) assert loop_tool.execute.call_count <= 3 @pytest.mark.asyncio @patch("app.agents.core.agent.get_chat_model") async def test_run_empty_query(self, mock_get_model): mock_llm = MagicMock() mock_response = MagicMock() mock_response.content = "I need more information." mock_response.tool_calls = [] mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_llm.bind_tools = MagicMock(return_value=mock_llm) mock_get_model.return_value = mock_llm registry = ToolRegistry() agent = Agent(tool_registry=registry) events = [] async for event in agent.run(""): events.append(event) assert len(events) >= 1 class TestAgentEvent: def test_creation(self): event = AgentEvent(type="message", data={"text": "hello"}) assert event.type == "message" assert event.data == {"text": "hello"} def test_creation_with_string_data(self): event = AgentEvent(type="error", data="something failed") assert event.type == "error" assert event.data == "something failed"