81 lines
2.5 KiB
Python
81 lines
2.5 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass
|
|
from typing import Any, Awaitable, Callable
|
|
|
|
from app.agents.tools.types import ToolResult
|
|
from app.services.llm import ModelTier, get_chat_model
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class SubTool:
    """One routable tool: an identifier, a description, and an async handler.

    ``MetaTool.route`` renders ``name`` and ``description`` into the routing
    prompt and looks the chosen tool up by ``name``; ``MetaTool.execute``
    awaits ``handler`` with the caller's parameter dict.
    """

    # Identifier the router LLM is expected to echo back in its JSON reply.
    name: str
    # Human-readable summary shown to the router LLM next to the name.
    description: str
    # Async callable that receives the params dict and produces a ToolResult.
    handler: Callable[[dict[str, Any]], Awaitable[ToolResult]]
|
|
|
|
|
|
class MetaTool(ABC):
    """Abstract base class that routes a natural-language query to a sub-tool.

    Subclasses provide ``sub_tools``. ``execute`` asks an LLM (via ``route``)
    which sub-tool fits the query and then awaits that sub-tool's handler.
    """

    @property
    @abstractmethod
    def sub_tools(self) -> list[SubTool]:
        """Return the sub-tools this meta-tool can dispatch to."""

    @staticmethod
    def _build_prompt(tool_list: str, query: str) -> str:
        """Build the routing prompt from the rendered tool list and the query."""
        return (
            "You are a tool router. Given a user query and available tools, "
            "select the best tool.\n\n"
            f"Available tools:\n{tool_list}\n\n"
            f"User query: {query}\n\n"
            # Plain (non-f) string: braces must be single, otherwise the model
            # is shown a literal `{{...}}` as the JSON shape to imitate.
            'Respond with ONLY a JSON object: {"tool": "tool_name"}'
        )

    @staticmethod
    def _parse_tool_choice(content: Any) -> Any:
        """Extract the value of ``"tool"`` from the raw LLM response content.

        Args:
            content: ``response.content`` as returned by the chat model;
                either a string or a list whose first element is used.

        Returns:
            The value stored under the ``"tool"`` key of the JSON reply.

        Raises:
            json.JSONDecodeError: If no JSON value can be parsed from the reply.
            KeyError: If the parsed object has no ``"tool"`` key.
        """
        if isinstance(content, list):
            content = content[0] if content else ""
            # NOTE(review): presumably list content holds blocks shaped like
            # {"type": "text", "text": "..."}; str() of such a dict is a Python
            # repr, never valid JSON, so pull out the text. Confirm the block
            # schema against the chat-model client.
            if isinstance(content, dict):
                content = content.get("text", "")
        raw = str(content).strip()

        # Primary path: drop a leading ```lang line and anything after the
        # last closing fence, then parse the remainder as JSON.
        text = raw
        if text.startswith("```"):
            text = text.split("\n", 1)[-1]
            text = text.rsplit("```", 1)[0]

        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            # Fallback for replies the primary path cannot handle (fence
            # without a newline, prose around the JSON): decode the first
            # JSON object found, ignoring whatever follows it.
            start = raw.find("{")
            if start == -1:
                raise
            parsed, _ = json.JSONDecoder().raw_decode(raw, start)

        return parsed["tool"]

    async def route(self, query: str) -> SubTool:
        """Use an LLM to pick the sub-tool best suited to ``query``.

        Args:
            query: The natural-language user query.

        Returns:
            The ``SubTool`` whose ``name`` matches the LLM's choice.

        Raises:
            ValueError: If the LLM names a tool that is not in ``sub_tools``.
            json.JSONDecodeError: If the reply contains no parseable JSON.
            KeyError: If the JSON reply has no ``"tool"`` key.
        """
        # Local import keeps this block self-contained.
        import asyncio

        # Read the property once so the prompt and the lookup table are
        # built from the same list.
        sub_tools = self.sub_tools
        tool_list = "\n".join(
            f"- {st.name}: {st.description}" for st in sub_tools
        )
        prompt = self._build_prompt(tool_list, query)

        llm = get_chat_model(tier=ModelTier.FAST, temperature=0.0)
        # `invoke` is a synchronous call; running it directly inside this
        # coroutine would stall the event loop until the model answers, so
        # it is handed to a worker thread instead.
        # NOTE(review): if the model object exposes a native async call
        # (e.g. `ainvoke`), prefer awaiting that — confirm with the client.
        response = await asyncio.to_thread(llm.invoke, prompt)

        tool_name = self._parse_tool_choice(response.content)

        tool_map = {st.name: st for st in sub_tools}
        sub_tool = tool_map.get(tool_name)
        if sub_tool is None:
            raise ValueError(
                f"LLM selected unknown sub-tool '{tool_name}'. "
                f"Available: {list(tool_map.keys())}"
            )

        return sub_tool

    async def execute(self, params: dict[str, Any]) -> ToolResult:
        """Route the natural-language query in ``params`` and run the sub-tool.

        Args:
            params: Tool parameters; must contain a non-empty ``"query"``.
                The whole dict is passed on to the selected handler.

        Returns:
            The handler's ``ToolResult``, or a ``ToolResult`` whose ``data``
            is an ``"Error: ..."`` string when the query is missing or when
            routing / the handler raises.
        """
        query = params.get("query", "")
        if not query:
            return ToolResult(data="Error: 'query' parameter is required.")

        try:
            sub_tool = await self.route(query)
            logger.info("MetaTool routed to %s", sub_tool.name)
            return await sub_tool.handler(params)
        except Exception as e:
            # Tool boundary: failures are logged with traceback and reported
            # to the caller as data rather than propagated.
            logger.exception("MetaTool execution failed")
            return ToolResult(data=f"Error: {e}")
|