galaxis-po/backend/tests/unit/test_agent_core.py
머니페니 76e3220e77
All checks were successful
Deploy to Production / deploy (push) Successful in 3m10s
feat: 에이전트 기능 추가 (LLM 서비스, 에이전트 API, 테스트)
2026-05-06 20:56:45 +09:00

380 lines
13 KiB
Python

"""Tests for agent core: scratchpad, context compactor, tool executor, and agent."""
import pytest
from unittest.mock import patch, MagicMock, AsyncMock
from app.agents.tools.types import ToolResult, RegisteredTool
from app.agents.tools.registry import ToolRegistry
from app.agents.core.scratchpad import Scratchpad
from app.agents.core.compact import ContextCompactor
from app.agents.core.tool_executor import ToolExecutor, ToolCall, ToolEvent
from app.agents.core.agent import Agent, AgentEvent
# ---------------------------------------------------------------------------
# Scratchpad
# ---------------------------------------------------------------------------
class TestScratchpad:
    """Exercise ``Scratchpad``: recording entries, rendering history, clearing."""

    def test_add_and_total_tokens(self):
        """After recording two tool calls the token total is positive."""
        pad = Scratchpad()
        recorded = (
            ("search", {"q": "test"}, "found 3 results"),
            ("calc", {"expr": "2+2"}, "4"),
        )
        for tool, arguments, output in recorded:
            pad.add(tool_name=tool, params=arguments, result=output)
        assert pad.total_tokens() > 0

    def test_format_history(self):
        """The rendered history mentions the tool name and the query text."""
        pad = Scratchpad()
        pad.add(tool_name="search", params={"q": "python"}, result="python docs")
        rendered = pad.format_history()
        for fragment in ("search", "python"):
            assert fragment in rendered

    def test_format_history_empty(self):
        """An empty scratchpad still renders its history as a string."""
        rendered = Scratchpad().format_history()
        assert isinstance(rendered, str)

    def test_clear(self):
        """``clear`` brings the token total back to zero."""
        pad = Scratchpad()
        for tool, output in (("tool1", "result_one"), ("tool2", "result_two")):
            pad.add(tool_name=tool, params={}, result=output)
        assert pad.total_tokens() > 0

        pad.clear()
        assert pad.total_tokens() == 0

    def test_multiple_entries_tracked(self):
        """Every one of five recorded tool names shows up in the history."""
        pad = Scratchpad()
        indices = range(5)
        for idx in indices:
            pad.add(tool_name=f"tool_{idx}", params={"i": idx}, result=f"result_{idx}")

        rendered = pad.format_history()
        for idx in indices:
            assert f"tool_{idx}" in rendered
# ---------------------------------------------------------------------------
# ContextCompactor
# ---------------------------------------------------------------------------
class TestContextCompactor:
    """Cover ``ContextCompactor`` threshold checks and LLM-backed compaction."""

    @staticmethod
    def _stub_chat_model(factory_mock, summary_text):
        """Make the patched ``get_chat_model`` return a mock LLM.

        The mock's ``ainvoke`` is an ``AsyncMock`` resolving to a response
        whose ``content`` is *summary_text*. Returns the mock LLM so the
        caller can inspect how it was invoked.
        """
        reply = MagicMock()
        reply.content = summary_text
        llm = MagicMock()
        llm.ainvoke = AsyncMock(return_value=reply)
        factory_mock.return_value = llm
        return llm

    def test_should_compact_false_below_threshold(self):
        """A single short message with a 10000 threshold does not trigger compaction."""
        conversation = [{"role": "user", "content": "short message"}]
        verdict = ContextCompactor().should_compact(conversation, threshold=10000)
        assert verdict is False

    def test_should_compact_true_above_threshold(self):
        """Three 5000-character messages with a 100 threshold trigger compaction."""
        filler = "x" * 5000
        conversation = [
            {"role": speaker, "content": filler}
            for speaker in ("user", "assistant", "user")
        ]
        verdict = ContextCompactor().should_compact(conversation, threshold=100)
        assert verdict is True

    @pytest.mark.asyncio
    @patch("app.agents.core.compact.get_chat_model")
    async def test_compact_calls_llm(self, mock_get_model):
        """``compact`` returns a non-empty list and invokes the LLM exactly once."""
        llm = self._stub_chat_model(
            mock_get_model,
            "Summary: user asked about Python, assistant explained basics.",
        )
        turns = (
            ("user", "Tell me about Python"),
            ("assistant", "Python is a programming language..."),
            ("user", "What about async?"),
            ("assistant", "Async in Python uses asyncio..."),
        )
        conversation = [{"role": speaker, "content": text} for speaker, text in turns]

        compacted = await ContextCompactor().compact(
            conversation, system_prompt="You are helpful."
        )

        assert isinstance(compacted, list)
        assert len(compacted) > 0
        llm.ainvoke.assert_called_once()

    @pytest.mark.asyncio
    @patch("app.agents.core.compact.get_chat_model")
    async def test_compact_preserves_system_context(self, mock_get_model):
        """``compact`` returns a list when given a system prompt."""
        self._stub_chat_model(mock_get_model, "Summarized conversation.")
        conversation = [
            {"role": speaker, "content": text}
            for speaker, text in (("user", "Hello"), ("assistant", "Hi there!"))
        ]

        compacted = await ContextCompactor().compact(
            conversation, system_prompt="System prompt here"
        )

        assert isinstance(compacted, list)
# ---------------------------------------------------------------------------
# ToolExecutor
# ---------------------------------------------------------------------------
def _make_registered_tool(
    name: str, concurrency_safe: bool = True, result_data: str = "ok"
) -> RegisteredTool:
    """Build a ``RegisteredTool`` whose ``execute`` is an ``AsyncMock``.

    Args:
        name: Tool name; also interpolated into both description strings.
        concurrency_safe: Passed through as the tool's ``concurrency_safe`` flag.
        result_data: Payload wrapped in the ``ToolResult`` the mock returns.

    Returns:
        The configured ``RegisteredTool``.
    """
    canned_result = ToolResult(data=result_data)
    tool_fields = {
        "name": name,
        "description": f"Desc {name}",
        "compact_description": f"compact {name}",
        "concurrency_safe": concurrency_safe,
        "execute": AsyncMock(return_value=canned_result),
    }
    return RegisteredTool(**tool_fields)
class TestToolExecutor:
    """Check the events ``ToolExecutor.execute_tool_calls`` yields for tool calls."""

    @pytest.mark.asyncio
    async def test_concurrent_safe_tools(self):
        """Two concurrency-safe tools each run once and each produce an event."""
        tool_a = _make_registered_tool("tool_a", concurrency_safe=True, result_data="A")
        tool_b = _make_registered_tool("tool_b", concurrency_safe=True, result_data="B")
        registry = ToolRegistry()
        registry.register(tool_a)
        registry.register(tool_b)
        executor = ToolExecutor(registry=registry)
        calls = [
            ToolCall(id="c1", name="tool_a", params={"q": "hello"}),
            ToolCall(id="c2", name="tool_b", params={"q": "world"}),
        ]

        events = [event async for event in executor.execute_tool_calls(calls)]

        assert len(events) == 2
        names = {e.tool_name for e in events}
        assert "tool_a" in names
        assert "tool_b" in names
        tool_a.execute.assert_called_once()
        tool_b.execute.assert_called_once()

    @pytest.mark.asyncio
    async def test_serial_tools_run_sequentially(self):
        """A single non-concurrency-safe call yields one error-free event."""
        tool_s = _make_registered_tool("serial_tool", concurrency_safe=False, result_data="S")
        registry = ToolRegistry()
        registry.register(tool_s)
        executor = ToolExecutor(registry=registry)
        calls = [
            ToolCall(id="c1", name="serial_tool", params={"x": 1}),
        ]

        events = [event async for event in executor.execute_tool_calls(calls)]

        assert len(events) == 1
        assert events[0].tool_name == "serial_tool"
        assert events[0].error is None

    @pytest.mark.asyncio
    async def test_error_handling_for_failed_tool(self):
        """A tool raising ``RuntimeError`` yields one event carrying the failure text."""
        failing_tool = RegisteredTool(
            name="fail_tool",
            description="A tool that fails",
            compact_description="fail",
            concurrency_safe=True,
            execute=AsyncMock(side_effect=RuntimeError("Tool exploded")),
        )
        registry = ToolRegistry()
        registry.register(failing_tool)
        executor = ToolExecutor(registry=registry)
        calls = [ToolCall(id="c1", name="fail_tool", params={})]

        events = [event async for event in executor.execute_tool_calls(calls)]

        assert len(events) == 1
        assert events[0].error is not None
        # Assert on the message itself. A trailing `or events[0].error` would
        # make this pass for any non-empty error, because `in` binds tighter
        # than `or`, leaving the substring check with no effect.
        assert "exploded" in events[0].error.lower()

    @pytest.mark.asyncio
    async def test_mixed_concurrent_and_serial(self):
        """A safe and an unsafe tool in one batch both return their own data."""
        safe = _make_registered_tool("safe", concurrency_safe=True, result_data="safe_r")
        unsafe = _make_registered_tool("unsafe", concurrency_safe=False, result_data="unsafe_r")
        registry = ToolRegistry()
        registry.register(safe)
        registry.register(unsafe)
        executor = ToolExecutor(registry=registry)
        calls = [
            ToolCall(id="c1", name="safe", params={}),
            ToolCall(id="c2", name="unsafe", params={}),
        ]

        events = [event async for event in executor.execute_tool_calls(calls)]

        assert len(events) == 2
        results = {e.tool_name: e.result for e in events}
        assert results["safe"].data == "safe_r"
        assert results["unsafe"].data == "unsafe_r"

    @pytest.mark.asyncio
    async def test_unknown_tool_returns_error(self):
        """Calling a tool absent from the registry yields one event with an error."""
        registry = ToolRegistry()
        executor = ToolExecutor(registry=registry)
        calls = [ToolCall(id="c1", name="nonexistent", params={})]

        events = [event async for event in executor.execute_tool_calls(calls)]

        assert len(events) == 1
        assert events[0].error is not None
# ---------------------------------------------------------------------------
# ToolCall & ToolEvent dataclass tests
# ---------------------------------------------------------------------------
class TestToolCallAndEvent:
    """Verify that ``ToolCall`` and ``ToolEvent`` expose their constructor values."""

    def test_tool_call_creation(self):
        """``ToolCall`` keeps the id, name and params it was built with."""
        expected_params = {"key": "val"}
        call = ToolCall(id="abc", name="my_tool", params={"key": "val"})
        assert call.id == "abc"
        assert call.name == "my_tool"
        assert call.params == expected_params

    def test_tool_event_success(self):
        """A successful event carries its result and no error."""
        event = ToolEvent(
            call_id="c1",
            tool_name="t1",
            result=ToolResult(data="success"),
            error=None,
        )
        assert event.result.data == "success"
        assert event.error is None

    def test_tool_event_error(self):
        """A failed event carries the error text and no result."""
        message = "something went wrong"
        event = ToolEvent(call_id="c1", tool_name="t1", result=None, error=message)
        assert event.result is None
        assert event.error == "something went wrong"
# ---------------------------------------------------------------------------
# Agent
# ---------------------------------------------------------------------------
class TestAgent:
    """Drive ``Agent.run`` against a mocked chat model and inspect the outcome."""

    @staticmethod
    def _llm_reply(content, tool_calls=None):
        """Return a mock LLM response with *content* and *tool_calls* (default ``[]``)."""
        reply = MagicMock()
        reply.content = content
        reply.tool_calls = [] if tool_calls is None else tool_calls
        return reply

    @staticmethod
    def _install_llm(factory_mock, ainvoke):
        """Make the patched ``get_chat_model`` return a mock LLM using *ainvoke*.

        ``bind_tools`` on the mock returns the mock itself.
        """
        llm = MagicMock()
        llm.ainvoke = ainvoke
        llm.bind_tools = MagicMock(return_value=llm)
        factory_mock.return_value = llm
        return llm

    @pytest.mark.asyncio
    @patch("app.agents.core.agent.get_chat_model")
    async def test_run_yields_events(self, mock_get_model):
        """A plain answer produces at least one event of an accepted type."""
        answer = self._llm_reply("Here is the answer to your question.")
        self._install_llm(mock_get_model, AsyncMock(return_value=answer))
        agent = Agent(tool_registry=ToolRegistry())

        emitted = [event async for event in agent.run("What is 2+2?")]

        assert len(emitted) >= 1
        accepted = ["message", "answer", "response", "text"]
        seen_types = [event.type for event in emitted]
        assert any(kind in seen_types for kind in accepted)

    @pytest.mark.asyncio
    @patch("app.agents.core.agent.get_chat_model")
    async def test_run_with_tool_call(self, mock_get_model):
        """A tool-call reply followed by a final reply still yields events."""
        wants_tool = self._llm_reply(
            "",
            [{"id": "call_1", "name": "search", "args": {"q": "python"}}],
        )
        final = self._llm_reply("Python is a programming language.")
        self._install_llm(mock_get_model, AsyncMock(side_effect=[wants_tool, final]))

        registry = ToolRegistry()
        registry.register(_make_registered_tool("search", result_data="Python docs found"))
        agent = Agent(tool_registry=registry)

        emitted = [event async for event in agent.run("Tell me about Python")]

        assert len(emitted) >= 1

    @pytest.mark.asyncio
    @patch("app.agents.core.agent.get_chat_model")
    async def test_max_iterations_limit(self, mock_get_model):
        """With ``max_iterations=3`` the looping tool runs at most three times."""
        always_tool = self._llm_reply(
            "",
            [{"id": "call_n", "name": "loop_tool", "args": {}}],
        )
        self._install_llm(mock_get_model, AsyncMock(return_value=always_tool))

        loop_tool = _make_registered_tool("loop_tool", result_data="looping")
        registry = ToolRegistry()
        registry.register(loop_tool)
        agent = Agent(tool_registry=registry, max_iterations=3)

        # Drain the stream; only the tool's call count matters here.
        async for _ in agent.run("infinite loop query"):
            pass

        assert loop_tool.execute.call_count <= 3

    @pytest.mark.asyncio
    @patch("app.agents.core.agent.get_chat_model")
    async def test_run_empty_query(self, mock_get_model):
        """An empty query still produces at least one event."""
        answer = self._llm_reply("I need more information.")
        self._install_llm(mock_get_model, AsyncMock(return_value=answer))
        agent = Agent(tool_registry=ToolRegistry())

        emitted = [event async for event in agent.run("")]

        assert len(emitted) >= 1
class TestAgentEvent:
    """Verify that ``AgentEvent`` exposes the type and data it was built with."""

    def test_creation(self):
        """A dict payload is readable back from ``data``."""
        payload = {"text": "hello"}
        created = AgentEvent(type="message", data=payload)
        assert created.type == "message"
        assert created.data == {"text": "hello"}

    def test_creation_with_string_data(self):
        """A plain string payload is readable back from ``data``."""
        created = AgentEvent(type="error", data="something failed")
        assert created.type == "error"
        assert created.data == "something failed"