200 lines
6.3 KiB
Python
200 lines
6.3 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from collections.abc import AsyncGenerator
|
|
from dataclasses import dataclass, field
|
|
from typing import Any
|
|
|
|
from langchain_core.messages import (
|
|
AIMessage,
|
|
HumanMessage,
|
|
SystemMessage,
|
|
ToolMessage,
|
|
)
|
|
from langchain_core.tools import StructuredTool
|
|
|
|
from app.agents.core.compact import ContextCompactor
|
|
from app.agents.core.prompts import SystemPromptBuilder
|
|
from app.agents.core.rules import RulesLoader
|
|
from app.agents.core.scratchpad import Scratchpad
|
|
from app.agents.core.tool_executor import ToolCall, ToolExecutor
|
|
from app.agents.skills.registry import SkillRegistry
|
|
from app.agents.tools.registry import ToolRegistry
|
|
from app.agents.tools.types import RegisteredTool
|
|
from app.services.llm import ModelTier, get_chat_model
|
|
|
|
|
|
@dataclass
class AgentConfig:
    """Tunable settings consumed by ``Agent``.

    Every field has a default, so ``AgentConfig()`` yields a usable
    configuration; fields may be overridden positionally (in declaration
    order) or by keyword.
    """

    # Tier passed as ``tier=`` to ``get_chat_model`` when the agent builds
    # its chat model.
    model_tier: ModelTier = ModelTier.STRONG
    # Maximum number of model calls in one ``Agent.run``; when exhausted the
    # run ends with a fixed "max iterations reached" response.
    max_iterations: int = 20
    # Sampling temperature passed to ``get_chat_model``.
    temperature: float = 0.0
    # Threshold handed to ``ContextCompactor.should_compact``; its unit is
    # defined by the compactor (presumably characters or tokens — confirm).
    compact_threshold: int = 50_000
|
|
|
|
|
|
@dataclass
class AgentEvent:
    """One progress notification streamed out of ``Agent.run``.

    ``type`` names the kind of event (the agent emits values such as
    "thinking", "tool_start", "tool_end", "tool_error", "compaction",
    "response" and "done"); ``data`` holds the event's payload.
    """

    # Event kind identifier.
    type: str
    # Event payload; each instance gets its own fresh empty dict by default.
    data: Any = field(default_factory=dict)
|
|
|
|
|
|
def _make_langchain_tool(rt: RegisteredTool) -> StructuredTool:
    """Expose a project ``RegisteredTool`` to LangChain as a ``StructuredTool``.

    The resulting tool carries ``rt``'s name and description. Invoking it
    forwards all keyword arguments as one dict to ``rt.execute`` and returns
    the outcome rendered via its ``to_str()`` method.

    NOTE(review): the wrapped coroutine declares only ``**kwargs``, so the
    argument schema LangChain infers from it most likely does not describe
    ``rt``'s real parameters. The agent uses these tools for ``bind_tools``
    (execution goes through ``ToolExecutor``), so the model may be shown an
    imprecise schema — consider passing an explicit ``args_schema`` if
    ``RegisteredTool`` exposes one.
    """

    # Kept without a docstring and named ``_fn`` on purpose: LangChain falls
    # back to the coroutine's __doc__/__name__ if description/name are empty.
    async def _fn(**kwargs: Any) -> str:
        return (await rt.execute(kwargs)).to_str()

    return StructuredTool.from_function(
        name=rt.name,
        description=rt.description,
        coroutine=_fn,
    )
|
|
|
|
|
|
class Agent:
    """LLM agent that alternates chat-model calls with tool execution.

    Each :meth:`run` builds a system prompt from the tool registry, the
    discovered skills and the loaded rules, binds the registered tools to the
    chat model, and loops until the model answers without requesting tools
    or ``max_iterations`` model calls have been made. Progress is streamed to
    the caller as :class:`AgentEvent` objects.
    """

    # Reported as both the response and the final response when the loop
    # exhausts ``max_iterations`` without a tool-free answer from the model.
    _MAX_ITERATIONS_MESSAGE = (
        "최대 반복 횟수에 도달했습니다. 분석을 완료하지 못했습니다."
    )

    def __init__(
        self,
        config: AgentConfig | None = None,
        tool_registry: ToolRegistry | None = None,
        max_iterations: int | None = None,
    ) -> None:
        """Create an agent.

        Args:
            config: Agent settings; ``AgentConfig()`` is used when omitted.
                The object passed in is never modified.
            tool_registry: Tools to expose to the model. When ``None``,
                ``ToolRegistry.auto_register()`` is called on every run.
            max_iterations: Optional override of ``config.max_iterations``.
        """
        from dataclasses import replace

        base_config = config if config is not None else AgentConfig()
        if max_iterations is not None:
            # Work on a copy: assigning in place would silently change a
            # config object the caller may share with other agents.
            base_config = replace(base_config, max_iterations=max_iterations)
        self._config = base_config
        self._tool_registry = tool_registry
        self._scratchpad = Scratchpad()
        self._compactor = ContextCompactor()

    async def run(self, query: str) -> AsyncGenerator[AgentEvent, None]:
        """Answer ``query``, streaming progress events.

        Per iteration the agent emits ``thinking`` before the model call,
        ``tool_start`` (``data``: ``tool_name``, ``params``) for every tool
        call the model requested, then for each executed call either
        ``tool_error`` (``data``: ``tool_name``, ``error``) or ``tool_end``
        (``data``: ``tool_name``, ``result``), and ``compaction`` when the
        history was replaced by a compacted version. The run finishes with
        ``response`` (``data``: ``content``) followed by ``done`` (``data``:
        ``final_response``), carrying either the model's answer or a fixed
        "max iterations reached" message.

        Args:
            query: The user's request, sent as the first human message.

        Yields:
            AgentEvent objects as described above.
        """
        tool_registry = (
            self._tool_registry
            if self._tool_registry is not None
            else ToolRegistry.auto_register()
        )
        system_prompt = self._build_system_prompt(tool_registry)

        lc_tools = [
            _make_langchain_tool(rt) for rt in tool_registry.list_tools()
        ]
        llm = get_chat_model(
            tier=self._config.model_tier,
            temperature=self._config.temperature,
        )
        # Bind tools only when there are any; otherwise use the bare model.
        model = llm.bind_tools(lc_tools) if lc_tools else llm

        messages: list[Any] = [
            SystemMessage(content=system_prompt),
            HumanMessage(content=query),
        ]
        executor = ToolExecutor(tool_registry)

        for _ in range(self._config.max_iterations):
            yield AgentEvent(type="thinking")

            response: AIMessage = await model.ainvoke(messages)
            messages.append(response)

            if not response.tool_calls:
                # No tool requests: this is the model's final answer.
                content = self._content_to_text(response.content)
                yield AgentEvent(type="response", data={"content": content})
                yield AgentEvent(type="done", data={"final_response": content})
                return

            tool_calls = [
                ToolCall(
                    id=tc["id"],
                    name=tc["name"],
                    params=tc.get("args", {}),
                )
                for tc in response.tool_calls
            ]
            # Arguments keyed by call id, so each execution event (matched by
            # its call_id below) can be recorded with the arguments it used.
            params_by_id = {
                tc["id"]: tc.get("args", {}) for tc in response.tool_calls
            }

            for tc in tool_calls:
                yield AgentEvent(
                    type="tool_start",
                    data={"tool_name": tc.name, "params": tc.params},
                )

            async for event in executor.execute_tool_calls(tool_calls):
                if event.error:
                    yield AgentEvent(
                        type="tool_error",
                        data={
                            "tool_name": event.tool_name,
                            "error": event.error,
                        },
                    )
                    # The error text is still sent back to the model below.
                    result_str = f"오류: {event.error}"
                else:
                    result_str = event.result.to_str() if event.result else ""
                    yield AgentEvent(
                        type="tool_end",
                        data={
                            "tool_name": event.tool_name,
                            "result": result_str,
                        },
                    )

                self._scratchpad.add(
                    tool_name=event.tool_name,
                    params=params_by_id.get(event.call_id, {}),
                    result=result_str,
                )
                # Feed the result back, linked to its originating tool call.
                messages.append(
                    ToolMessage(content=result_str, tool_call_id=event.call_id)
                )

            compacted = await self._maybe_compact(messages, system_prompt)
            if compacted is not None:
                messages = compacted
                yield AgentEvent(type="compaction")

        yield AgentEvent(
            type="response",
            data={"content": self._MAX_ITERATIONS_MESSAGE},
        )
        yield AgentEvent(
            type="done",
            data={"final_response": self._MAX_ITERATIONS_MESSAGE},
        )

    @staticmethod
    def _build_system_prompt(tool_registry: ToolRegistry) -> str:
        """Discover skills and assemble the system prompt for ``tool_registry``."""
        SkillRegistry.discover()
        return SystemPromptBuilder.build(
            tools_section=tool_registry.build_compact_descriptions(),
            skills_section=SkillRegistry.build_skills_section(),
            rules=RulesLoader.load_rules(),
        )

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Flatten a chat-message ``content`` value into plain text.

        Strings are returned unchanged. A list of content blocks is reduced
        to its textual parts: bare string entries plus the ``"text"`` value
        of dict blocks whose ``"type"`` is ``"text"`` (other blocks, e.g.
        tool-use or reasoning blocks, are dropped instead of being leaked as
        a Python repr). Any other value falls back to ``str(content)``.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts: list[str] = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and block.get("type") == "text":
                    parts.append(str(block.get("text", "")))
            return "".join(parts)
        return str(content)

    async def _maybe_compact(
        self, messages: list[Any], system_prompt: str
    ) -> list[Any] | None:
        """Return a compacted replacement history, or ``None`` if not needed.

        The history is flattened to ``{"role", "content"}`` dicts for the
        compactor; when it reports that ``compact_threshold`` is exceeded,
        the compacted result is rebuilt as a system message followed by a
        single human message.
        """
        msg_dicts = [
            {
                "role": getattr(m, "type", "unknown"),
                "content": str(getattr(m, "content", "")),
            }
            for m in messages
        ]
        if not self._compactor.should_compact(
            msg_dicts, self._config.compact_threshold
        ):
            return None

        compacted = await self._compactor.compact(msg_dicts, system_prompt)
        # NOTE(review): assumes compact() returns at least two dicts — the
        # system prompt first, then a summary usable as the human turn.
        # Confirm against ContextCompactor.
        return [
            SystemMessage(content=compacted[0]["content"]),
            HumanMessage(content=compacted[1]["content"]),
        ]
|