feat: 에이전트 기능 추가 (LLM 서비스, 에이전트 API, 테스트)
All checks were successful
Deploy to Production / deploy (push) Successful in 3m10s
All checks were successful
Deploy to Production / deploy (push) Successful in 3m10s
This commit is contained in:
parent
34d09d9d34
commit
76e3220e77
0
backend/app/agents/__init__.py
Normal file
0
backend/app/agents/__init__.py
Normal file
15
backend/app/agents/core/__init__.py
Normal file
15
backend/app/agents/core/__init__.py
Normal file
@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.core.agent import Agent
|
||||
from app.agents.core.compact import ContextCompactor
|
||||
from app.agents.core.prompts import SystemPromptBuilder
|
||||
from app.agents.core.scratchpad import Scratchpad
|
||||
from app.agents.core.tool_executor import ToolExecutor
|
||||
|
||||
__all__ = [
|
||||
"Agent",
|
||||
"ContextCompactor",
|
||||
"Scratchpad",
|
||||
"SystemPromptBuilder",
|
||||
"ToolExecutor",
|
||||
]
|
||||
199
backend/app/agents/core/agent.py
Normal file
199
backend/app/agents/core/agent.py
Normal file
@ -0,0 +1,199 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.tools import StructuredTool
|
||||
|
||||
from app.agents.core.compact import ContextCompactor
|
||||
from app.agents.core.prompts import SystemPromptBuilder
|
||||
from app.agents.core.rules import RulesLoader
|
||||
from app.agents.core.scratchpad import Scratchpad
|
||||
from app.agents.core.tool_executor import ToolCall, ToolExecutor
|
||||
from app.agents.skills.registry import SkillRegistry
|
||||
from app.agents.tools.registry import ToolRegistry
|
||||
from app.agents.tools.types import RegisteredTool
|
||||
from app.services.llm import ModelTier, get_chat_model
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentConfig:
|
||||
model_tier: ModelTier = ModelTier.STRONG
|
||||
max_iterations: int = 20
|
||||
temperature: float = 0.0
|
||||
compact_threshold: int = 50000
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentEvent:
|
||||
type: str
|
||||
data: Any = field(default_factory=dict)
|
||||
|
||||
|
||||
def _make_langchain_tool(rt: RegisteredTool) -> StructuredTool:
|
||||
async def _fn(**kwargs: Any) -> str:
|
||||
result = await rt.execute(kwargs)
|
||||
return result.to_str()
|
||||
|
||||
return StructuredTool.from_function(
|
||||
coroutine=_fn,
|
||||
name=rt.name,
|
||||
description=rt.description,
|
||||
)
|
||||
|
||||
|
||||
class Agent:
|
||||
def __init__(
|
||||
self,
|
||||
config: AgentConfig | None = None,
|
||||
tool_registry: ToolRegistry | None = None,
|
||||
max_iterations: int | None = None,
|
||||
) -> None:
|
||||
self._config = config or AgentConfig()
|
||||
self._tool_registry = tool_registry
|
||||
if max_iterations is not None:
|
||||
self._config.max_iterations = max_iterations
|
||||
self._scratchpad = Scratchpad()
|
||||
self._compactor = ContextCompactor()
|
||||
|
||||
async def run(self, query: str) -> AsyncGenerator[AgentEvent, None]:
|
||||
if self._tool_registry is not None:
|
||||
tool_registry = self._tool_registry
|
||||
else:
|
||||
tool_registry = ToolRegistry.auto_register()
|
||||
|
||||
SkillRegistry.discover()
|
||||
|
||||
tools_section = tool_registry.build_compact_descriptions()
|
||||
skills_section = SkillRegistry.build_skills_section()
|
||||
rules = RulesLoader.load_rules()
|
||||
|
||||
system_prompt = SystemPromptBuilder.build(
|
||||
tools_section=tools_section,
|
||||
skills_section=skills_section,
|
||||
rules=rules,
|
||||
)
|
||||
|
||||
registered_tools = tool_registry.list_tools()
|
||||
lc_tools = [_make_langchain_tool(rt) for rt in registered_tools]
|
||||
|
||||
llm = get_chat_model(
|
||||
tier=self._config.model_tier,
|
||||
temperature=self._config.temperature,
|
||||
)
|
||||
model = llm.bind_tools(lc_tools) if lc_tools else llm
|
||||
|
||||
messages: list = [
|
||||
SystemMessage(content=system_prompt),
|
||||
HumanMessage(content=query),
|
||||
]
|
||||
|
||||
executor = ToolExecutor(tool_registry)
|
||||
|
||||
for _iteration in range(self._config.max_iterations):
|
||||
yield AgentEvent(type="thinking")
|
||||
|
||||
response: AIMessage = await model.ainvoke(messages)
|
||||
messages.append(response)
|
||||
|
||||
if not response.tool_calls:
|
||||
content = (
|
||||
response.content
|
||||
if isinstance(response.content, str)
|
||||
else str(response.content)
|
||||
)
|
||||
yield AgentEvent(type="response", data={"content": content})
|
||||
yield AgentEvent(
|
||||
type="done", data={"final_response": content}
|
||||
)
|
||||
return
|
||||
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
id=tc["id"],
|
||||
name=tc["name"],
|
||||
params=tc.get("args", {}),
|
||||
)
|
||||
for tc in response.tool_calls
|
||||
]
|
||||
|
||||
for tc in tool_calls:
|
||||
yield AgentEvent(
|
||||
type="tool_start",
|
||||
data={"tool_name": tc.name, "params": tc.params},
|
||||
)
|
||||
|
||||
async for event in executor.execute_tool_calls(tool_calls):
|
||||
if event.error:
|
||||
yield AgentEvent(
|
||||
type="tool_error",
|
||||
data={
|
||||
"tool_name": event.tool_name,
|
||||
"error": event.error,
|
||||
},
|
||||
)
|
||||
result_str = f"오류: {event.error}"
|
||||
else:
|
||||
result_str = (
|
||||
event.result.to_str() if event.result else ""
|
||||
)
|
||||
yield AgentEvent(
|
||||
type="tool_end",
|
||||
data={
|
||||
"tool_name": event.tool_name,
|
||||
"result": result_str,
|
||||
},
|
||||
)
|
||||
|
||||
self._scratchpad.add(
|
||||
tool_name=event.tool_name,
|
||||
params={},
|
||||
result=result_str,
|
||||
)
|
||||
|
||||
messages.append(
|
||||
ToolMessage(
|
||||
content=result_str,
|
||||
tool_call_id=event.call_id,
|
||||
)
|
||||
)
|
||||
|
||||
msg_dicts = [
|
||||
{
|
||||
"role": getattr(m, "type", "unknown"),
|
||||
"content": str(getattr(m, "content", "")),
|
||||
}
|
||||
for m in messages
|
||||
]
|
||||
if self._compactor.should_compact(
|
||||
msg_dicts, self._config.compact_threshold
|
||||
):
|
||||
compacted = await self._compactor.compact(
|
||||
msg_dicts, system_prompt
|
||||
)
|
||||
messages = [
|
||||
SystemMessage(content=compacted[0]["content"]),
|
||||
HumanMessage(content=compacted[1]["content"]),
|
||||
]
|
||||
yield AgentEvent(type="compaction")
|
||||
|
||||
yield AgentEvent(
|
||||
type="response",
|
||||
data={
|
||||
"content": "최대 반복 횟수에 도달했습니다. 분석을 완료하지 못했습니다."
|
||||
},
|
||||
)
|
||||
yield AgentEvent(
|
||||
type="done",
|
||||
data={
|
||||
"final_response": "최대 반복 횟수에 도달했습니다. 분석을 완료하지 못했습니다."
|
||||
},
|
||||
)
|
||||
84
backend/app/agents/core/compact.py
Normal file
84
backend/app/agents/core/compact.py
Normal file
@ -0,0 +1,84 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
|
||||
from app.services.llm import ModelTier, get_chat_model
|
||||
|
||||
COMPACTION_PROMPT = """당신은 대화 요약 전문가입니다. 아래 대화를 구조화된 요약으로 압축하세요.
|
||||
|
||||
반드시 다음 섹션을 포함하세요:
|
||||
## 원래 질문
|
||||
## 핵심 데이터
|
||||
## 분석 진행 상황
|
||||
## 수치 데이터 (정확한 값 유지)
|
||||
## 오류/이슈
|
||||
## 다음 단계"""
|
||||
|
||||
|
||||
class ContextCompactor:
|
||||
"""대화 컨텍스트가 너무 커지면 LLM으로 요약하여 압축합니다."""
|
||||
|
||||
@staticmethod
|
||||
def _estimate_tokens(text: str) -> int:
|
||||
"""텍스트의 토큰 수를 추정합니다 (~4자당 1토큰)."""
|
||||
return len(text) // 4
|
||||
|
||||
def should_compact(
|
||||
self, messages: list[dict], threshold: int = 50000
|
||||
) -> bool:
|
||||
"""메시지의 토큰 추정치가 임계값을 초과하는지 확인합니다.
|
||||
|
||||
Args:
|
||||
messages: 대화 메시지 목록.
|
||||
threshold: 토큰 임계값 (기본 50,000).
|
||||
|
||||
Returns:
|
||||
압축이 필요하면 True.
|
||||
"""
|
||||
total_text = "".join(
|
||||
str(m.get("content", "")) for m in messages
|
||||
)
|
||||
return self._estimate_tokens(total_text) > threshold
|
||||
|
||||
async def compact(
|
||||
self, messages: list[dict], system_prompt: str
|
||||
) -> list[dict]:
|
||||
"""대화를 요약하여 압축된 메시지 목록을 반환합니다.
|
||||
|
||||
Args:
|
||||
messages: 기존 대화 메시지 목록.
|
||||
system_prompt: 원래 시스템 프롬프트.
|
||||
|
||||
Returns:
|
||||
압축된 메시지 목록 (시스템 메시지 + 요약 메시지).
|
||||
"""
|
||||
# 대화 내용을 텍스트로 변환
|
||||
conversation_lines: list[str] = []
|
||||
for msg in messages:
|
||||
role = msg.get("role", "unknown")
|
||||
content = str(msg.get("content", ""))
|
||||
conversation_lines.append(f"[{role}]: {content}")
|
||||
|
||||
conversation_text = "\n".join(conversation_lines)
|
||||
|
||||
llm = get_chat_model(tier=ModelTier.FAST)
|
||||
response = await llm.ainvoke(
|
||||
[
|
||||
SystemMessage(content=COMPACTION_PROMPT),
|
||||
HumanMessage(content=conversation_text),
|
||||
]
|
||||
)
|
||||
|
||||
summary_content = (
|
||||
response.content
|
||||
if isinstance(response.content, str)
|
||||
else str(response.content)
|
||||
)
|
||||
|
||||
return [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"[이전 대화 요약]\n{summary_content}",
|
||||
},
|
||||
]
|
||||
59
backend/app/agents/core/prompts.py
Normal file
59
backend/app/agents/core/prompts.py
Normal file
@ -0,0 +1,59 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
|
||||
_SYSTEM_PROMPT_TEMPLATE = """당신은 Galaxis-Po 투자 분석 에이전트입니다.
|
||||
한국 퇴직연금(DC) 포트폴리오 관리를 위한 퀀트 분석을 수행합니다.
|
||||
|
||||
## 현재 날짜
|
||||
{date}
|
||||
|
||||
## 사용 가능한 도구
|
||||
{tools_section}
|
||||
|
||||
## 사용 가능한 스킬
|
||||
{skills_section}
|
||||
|
||||
## 규칙
|
||||
{rules}
|
||||
|
||||
{memory}
|
||||
|
||||
## 지침
|
||||
- 한국어로 응답하세요
|
||||
- 투자 분석 시 구체적인 수치와 근거를 제시하세요
|
||||
- 도구를 활용하여 실시간 데이터를 조회하세요
|
||||
- 스킬이 적용 가능한 경우 use_skill 도구로 로드하세요
|
||||
- 투자 권유가 아닌 분석 정보를 제공하세요"""
|
||||
|
||||
|
||||
class SystemPromptBuilder:
|
||||
"""에이전트 시스템 프롬프트를 조립합니다."""
|
||||
|
||||
@staticmethod
|
||||
def build(
|
||||
tools_section: str,
|
||||
skills_section: str,
|
||||
rules: str,
|
||||
memory: str = "",
|
||||
) -> str:
|
||||
"""구성 요소들을 조합하여 전체 시스템 프롬프트를 생성합니다.
|
||||
|
||||
Args:
|
||||
tools_section: 도구 설명 텍스트.
|
||||
skills_section: 스킬 설명 텍스트.
|
||||
rules: 프로젝트 규칙.
|
||||
memory: 추가 메모리/컨텍스트 (선택).
|
||||
|
||||
Returns:
|
||||
완성된 시스템 프롬프트 문자열.
|
||||
"""
|
||||
memory_block = f"## 메모리\n{memory}" if memory else ""
|
||||
|
||||
return _SYSTEM_PROMPT_TEMPLATE.format(
|
||||
date=date.today().isoformat(),
|
||||
tools_section=tools_section or "등록된 도구 없음",
|
||||
skills_section=skills_section or "등록된 스킬 없음",
|
||||
rules=rules or "추가 규칙 없음",
|
||||
memory=memory_block,
|
||||
)
|
||||
35
backend/app/agents/core/rules.py
Normal file
35
backend/app/agents/core/rules.py
Normal file
@ -0,0 +1,35 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class RulesLoader:
|
||||
"""프로젝트 규칙 파일(RULES.md)을 로드합니다."""
|
||||
|
||||
@staticmethod
|
||||
def load_rules(project_root: Path | None = None) -> str:
|
||||
"""프로젝트 루트에서 RULES.md를 찾아 내용을 반환합니다.
|
||||
|
||||
탐색 순서:
|
||||
1. {project_root}/RULES.md
|
||||
2. {project_root}/.claude/RULES.md
|
||||
|
||||
Args:
|
||||
project_root: 프로젝트 루트 디렉토리. None이면 빈 문자열 반환.
|
||||
|
||||
Returns:
|
||||
RULES.md 내용 또는 빈 문자열.
|
||||
"""
|
||||
if project_root is None:
|
||||
return ""
|
||||
|
||||
candidates = [
|
||||
project_root / "RULES.md",
|
||||
project_root / ".claude" / "RULES.md",
|
||||
]
|
||||
|
||||
for path in candidates:
|
||||
if path.is_file():
|
||||
return path.read_text(encoding="utf-8")
|
||||
|
||||
return ""
|
||||
68
backend/app/agents/core/scratchpad.py
Normal file
68
backend/app/agents/core/scratchpad.py
Normal file
@ -0,0 +1,68 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScratchpadEntry:
|
||||
"""도구 실행 기록 항목."""
|
||||
|
||||
tool_name: str
|
||||
params: dict
|
||||
result: str
|
||||
timestamp: datetime
|
||||
token_estimate: int
|
||||
|
||||
|
||||
class Scratchpad:
|
||||
"""도구 실행 이력과 토큰 추정치를 추적합니다."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.entries: list[ScratchpadEntry] = []
|
||||
|
||||
def add(self, tool_name: str, params: dict, result: str) -> None:
|
||||
"""도구 실행 결과를 기록합니다.
|
||||
|
||||
Args:
|
||||
tool_name: 실행된 도구 이름.
|
||||
params: 도구에 전달된 매개변수.
|
||||
result: 도구 실행 결과 문자열.
|
||||
"""
|
||||
token_estimate = max(len(result) // 4, 1)
|
||||
entry = ScratchpadEntry(
|
||||
tool_name=tool_name,
|
||||
params=params,
|
||||
result=result,
|
||||
timestamp=datetime.now(),
|
||||
token_estimate=token_estimate,
|
||||
)
|
||||
self.entries.append(entry)
|
||||
|
||||
def total_tokens(self) -> int:
|
||||
"""전체 기록의 추정 토큰 수를 반환합니다."""
|
||||
return sum(e.token_estimate for e in self.entries)
|
||||
|
||||
def format_history(self) -> str:
|
||||
"""도구 실행 이력을 포맷된 문자열로 반환합니다."""
|
||||
if not self.entries:
|
||||
return "도구 실행 이력 없음"
|
||||
|
||||
lines: list[str] = []
|
||||
for i, entry in enumerate(self.entries, 1):
|
||||
lines.append(
|
||||
f"[{i}] {entry.tool_name} "
|
||||
f"({entry.timestamp.strftime('%H:%M:%S')}) "
|
||||
f"토큰≈{entry.token_estimate}"
|
||||
)
|
||||
lines.append(f" params: {entry.params}")
|
||||
# 결과가 너무 길면 잘라냄
|
||||
result_preview = entry.result[:200]
|
||||
if len(entry.result) > 200:
|
||||
result_preview += "..."
|
||||
lines.append(f" result: {result_preview}")
|
||||
return "\n".join(lines)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""모든 기록을 삭제합니다."""
|
||||
self.entries.clear()
|
||||
102
backend/app/agents/core/tool_executor.py
Normal file
102
backend/app/agents/core/tool_executor.py
Normal file
@ -0,0 +1,102 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.types import ToolResult
|
||||
from app.agents.tools.registry import ToolRegistry
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCall:
|
||||
"""도구 호출 요청."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
params: dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolEvent:
|
||||
"""도구 실행 결과 이벤트."""
|
||||
|
||||
call_id: str
|
||||
tool_name: str
|
||||
result: ToolResult | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
"""도구 호출을 동시성 안전 여부에 따라 분류하여 실행합니다."""
|
||||
|
||||
def __init__(
|
||||
self, registry: ToolRegistry, max_concurrency: int = 5
|
||||
) -> None:
|
||||
self._registry = registry
|
||||
self._semaphore = asyncio.Semaphore(max_concurrency)
|
||||
|
||||
async def _execute_single(self, call: ToolCall) -> ToolEvent:
|
||||
"""단일 도구를 실행하고 ToolEvent를 반환합니다."""
|
||||
tool = self._registry.get(call.name)
|
||||
if tool is None:
|
||||
return ToolEvent(
|
||||
call_id=call.id,
|
||||
tool_name=call.name,
|
||||
error=f"도구를 찾을 수 없습니다: {call.name}",
|
||||
)
|
||||
|
||||
try:
|
||||
async with self._semaphore:
|
||||
result = await tool.execute(call.params)
|
||||
return ToolEvent(
|
||||
call_id=call.id,
|
||||
tool_name=call.name,
|
||||
result=result,
|
||||
)
|
||||
except Exception as exc:
|
||||
return ToolEvent(
|
||||
call_id=call.id,
|
||||
tool_name=call.name,
|
||||
error=f"{type(exc).__name__}: {exc}",
|
||||
)
|
||||
|
||||
async def execute_tool_calls(
|
||||
self, tool_calls: list[ToolCall]
|
||||
) -> AsyncGenerator[ToolEvent, None]:
|
||||
"""도구 호출 목록을 실행합니다.
|
||||
|
||||
동시성 안전한 도구는 asyncio.gather로 병렬 실행하고,
|
||||
그렇지 않은 도구는 순차적으로 실행합니다.
|
||||
|
||||
Args:
|
||||
tool_calls: 실행할 도구 호출 목록.
|
||||
|
||||
Yields:
|
||||
각 도구 실행 결과에 대한 ToolEvent.
|
||||
"""
|
||||
concurrency_map = self._registry.get_concurrency_map()
|
||||
|
||||
concurrent_calls: list[ToolCall] = []
|
||||
serial_calls: list[ToolCall] = []
|
||||
|
||||
for call in tool_calls:
|
||||
if concurrency_map.get(call.name, False):
|
||||
concurrent_calls.append(call)
|
||||
else:
|
||||
serial_calls.append(call)
|
||||
|
||||
# 병렬 실행 가능한 도구들
|
||||
if concurrent_calls:
|
||||
tasks = [
|
||||
self._execute_single(call) for call in concurrent_calls
|
||||
]
|
||||
results = await asyncio.gather(*tasks)
|
||||
for event in results:
|
||||
yield event
|
||||
|
||||
# 순차 실행해야 하는 도구들
|
||||
for call in serial_calls:
|
||||
event = await self._execute_single(call)
|
||||
yield event
|
||||
12
backend/app/agents/skills/__init__.py
Normal file
12
backend/app/agents/skills/__init__.py
Normal file
@ -0,0 +1,12 @@
|
||||
from app.agents.skills.loader import SkillLoader
|
||||
from app.agents.skills.registry import SkillRegistry
|
||||
from app.agents.skills.tool import create_skill_tool
|
||||
from app.agents.skills.types import Skill, SkillMetadata
|
||||
|
||||
__all__ = [
|
||||
"Skill",
|
||||
"SkillLoader",
|
||||
"SkillMetadata",
|
||||
"SkillRegistry",
|
||||
"create_skill_tool",
|
||||
]
|
||||
0
backend/app/agents/skills/builtin/__init__.py
Normal file
0
backend/app/agents/skills/builtin/__init__.py
Normal file
89
backend/app/agents/skills/builtin/dcf/SKILL.md
Normal file
89
backend/app/agents/skills/builtin/dcf/SKILL.md
Normal file
@ -0,0 +1,89 @@
|
||||
---
|
||||
name: dcf-kr
|
||||
description: 한국 주식시장용 DCF(할인현금흐름) 밸류에이션 분석
|
||||
---
|
||||
|
||||
# DCF 밸류에이션 분석 (한국 시장)
|
||||
|
||||
## 실행 체크리스트
|
||||
|
||||
- [ ] 1단계: 재무 데이터 수집
|
||||
- [ ] 2단계: FCF 성장률 산출
|
||||
- [ ] 3단계: WACC 추정
|
||||
- [ ] 4단계: 미래 현금흐름 추정
|
||||
- [ ] 5단계: 현재가치 할인 및 주당 내재가치 산출
|
||||
- [ ] 6단계: 민감도 분석
|
||||
- [ ] 7단계: 검증
|
||||
- [ ] 8단계: 최종 보고서 작성
|
||||
|
||||
## 1단계: 재무 데이터 수집
|
||||
|
||||
get_financials 도구로 다음 데이터를 수집:
|
||||
- 최근 5년 현금흐름표 (영업활동CF, CAPEX, FCF)
|
||||
- 재무비율 (ROE, ROA, 부채비율)
|
||||
- 대차대조표 (총부채, 현금성자산, 발행주식수)
|
||||
- 시가총액, 현재 주가
|
||||
|
||||
## 2단계: FCF 성장률 산출
|
||||
|
||||
- 5년 FCF CAGR 계산
|
||||
- 매출 성장률, 영업이익률 추이와 교차 검증
|
||||
- 한국 시장 특성: 성장률 상한 15%
|
||||
|
||||
## 3단계: WACC 추정
|
||||
|
||||
한국 시장 기본 가정:
|
||||
- 무위험이자율: 3.0-3.5% (국고채 10년)
|
||||
- 시장위험프리미엄: 6-8%
|
||||
- 베타: 업종별 차등 적용
|
||||
|
||||
### 업종별 WACC 참고 범위
|
||||
| 업종 | WACC 범위 |
|
||||
|------|-----------|
|
||||
| 반도체 | 9-11% |
|
||||
| 자동차 | 8-10% |
|
||||
| 바이오 | 11-14% |
|
||||
| 인터넷/플랫폼 | 9-12% |
|
||||
| 유통/소매 | 8-10% |
|
||||
| 금융 | 7-9% |
|
||||
| 건설 | 9-11% |
|
||||
| 에너지 | 8-10% |
|
||||
| 통신 | 7-9% |
|
||||
| 유틸리티 | 6-8% |
|
||||
|
||||
## 4단계: 미래 현금흐름 추정 (5년)
|
||||
|
||||
- Year 1-5 FCF = 전년 FCF × (1 + 성장률) × 감쇠계수
|
||||
- 감쇠계수: [1.0, 0.95, 0.90, 0.85, 0.80]
|
||||
- 영구성장가치(Terminal Value) = Year5 FCF × (1 + 영구성장률) / (WACC - 영구성장률)
|
||||
- 영구성장률: 2.0-2.5% (한국 GDP 성장률 수준)
|
||||
|
||||
## 5단계: 현재가치 할인 및 주당 내재가치
|
||||
|
||||
- 각 연도 FCF를 WACC로 할인
|
||||
- Enterprise Value = PV(FCF) + PV(Terminal Value)
|
||||
- Equity Value = EV - 순부채 (총부채 - 현금)
|
||||
- 주당 내재가치 = Equity Value / 발행주식수
|
||||
|
||||
## 6단계: 민감도 분석
|
||||
|
||||
3×3 매트릭스:
|
||||
- WACC: 기본값 ± 1%
|
||||
- 영구성장률: 2.0%, 2.5%, 3.0%
|
||||
|
||||
## 7단계: 검증
|
||||
|
||||
- [ ] EV 산출값과 시가총액 비교 (±50% 이내)
|
||||
- [ ] Terminal Value 비중 < 75%
|
||||
- [ ] EV/FCF 배수 합리성 (10-30x)
|
||||
- [ ] 동종업계 PER과 비교
|
||||
|
||||
## 8단계: 최종 보고서
|
||||
|
||||
1. 기업 개요 및 투자 논거
|
||||
2. 핵심 가정 (성장률, WACC, 영구성장률)
|
||||
3. DCF 밸류에이션 결과 (주당 내재가치, 현재가 대비 괴리율)
|
||||
4. 민감도 분석 매트릭스
|
||||
5. 리스크 요인 및 제한사항
|
||||
|
||||
> 본 분석은 참고용이며 투자 권유가 아닙니다. 실제 투자 결정 전 전문가 상담을 권장합니다.
|
||||
@ -0,0 +1,72 @@
|
||||
---
|
||||
name: kim-jong-bong-strategy
|
||||
description: 김종봉 단기매매 전략 - 상대강도 기반 종목 선정 및 매매 시그널 생성
|
||||
---
|
||||
|
||||
# 김종봉 전략 분석
|
||||
|
||||
## 실행 체크리스트
|
||||
|
||||
- [ ] 1단계: 유니버스 구성 (시가총액 Top 30, 거래대금 2000억+)
|
||||
- [ ] 2단계: 상대강도 분석 (종목 수익률 vs 코스피)
|
||||
- [ ] 3단계: 차트 패턴 감지 (박스권 돌파, 장대양봉)
|
||||
- [ ] 4단계: 매수 시그널 생성
|
||||
- [ ] 5단계: 포트폴리오 구성 및 리스크 관리
|
||||
- [ ] 6단계: 청산 규칙 적용
|
||||
|
||||
## 전략 개요
|
||||
|
||||
- 목표 수익률: 월 10%
|
||||
- 종목 수: 5-10개
|
||||
- 현금 비중: 30%
|
||||
- 보유 기간: 수일~수주
|
||||
|
||||
## 1단계: 유니버스 구성
|
||||
|
||||
get_market_data 도구로 데이터 수집:
|
||||
- 시가총액 상위 30개 종목
|
||||
- 일 거래대금 >= 2,000억원 필터
|
||||
|
||||
## 2단계: 상대강도 분석
|
||||
|
||||
- RS = (종목 수익률 / 코스피 수익률) × 100
|
||||
- RS > 100이면 시장 대비 강세
|
||||
- 기본 lookback: 10일 (2주)
|
||||
|
||||
## 3단계: 차트 패턴 감지
|
||||
|
||||
### 박스권 돌파
|
||||
- 종가 > 직전 20일 최고가
|
||||
|
||||
### 장대양봉
|
||||
- 일 수익률 >= 5%
|
||||
- 거래량 >= 20일 평균 × 1.5
|
||||
|
||||
## 4단계: 매수 시그널
|
||||
|
||||
매수 조건 = RS > 100 AND (박스권 돌파 OR 장대양봉)
|
||||
|
||||
## 5단계: 포트폴리오 구성
|
||||
|
||||
| 항목 | 기준 |
|
||||
|------|------|
|
||||
| 종목당 투자금 | 총자산 / 최대종목수 |
|
||||
| 최대 종목 수 | 10개 |
|
||||
| 현금 비중 | 30% |
|
||||
| 리밸런싱 | 분기별 |
|
||||
|
||||
## 6단계: 청산 규칙
|
||||
|
||||
| 유형 | 기준 |
|
||||
|------|------|
|
||||
| 손절 | -3% |
|
||||
| 1차 익절 | +5% (50% 물량) |
|
||||
| 2차 익절 | +10% (나머지) |
|
||||
| 트레일링 스톱 | +5% 도달 시 손절선을 본전으로 상향 |
|
||||
|
||||
## 리스크 관리
|
||||
|
||||
- 종목당 최대 손실: -3%
|
||||
- 포트폴리오 최대 손실: -10%
|
||||
- 승률 목표: 60%+
|
||||
- 손익비: 1:1.5+
|
||||
61
backend/app/agents/skills/loader.py
Normal file
61
backend/app/agents/skills/loader.py
Normal file
@ -0,0 +1,61 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
from app.agents.skills.types import Skill, SkillMetadata, SkillSource
|
||||
|
||||
|
||||
class SkillLoader:
|
||||
"""SKILL.md 파일 파싱 및 로딩."""
|
||||
|
||||
@staticmethod
|
||||
def parse_skill_file(content: str, path: str, source: SkillSource) -> Skill:
|
||||
"""YAML frontmatter + markdown body를 파싱하여 Skill 객체 반환."""
|
||||
frontmatter, instructions = SkillLoader._split_frontmatter(content)
|
||||
meta = yaml.safe_load(frontmatter) or {}
|
||||
return Skill(
|
||||
name=meta.get("name", ""),
|
||||
description=meta.get("description", ""),
|
||||
path=path,
|
||||
source=source,
|
||||
instructions=instructions.strip(),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load_from_path(path: str | Path, source: SkillSource) -> Skill:
|
||||
"""파일 경로에서 Skill을 로드."""
|
||||
p = Path(path)
|
||||
content = p.read_text(encoding="utf-8")
|
||||
return SkillLoader.parse_skill_file(content, str(p), source)
|
||||
|
||||
@staticmethod
|
||||
def extract_metadata(path: str | Path, source: SkillSource) -> SkillMetadata:
|
||||
"""frontmatter만 파싱하여 경량 메타데이터 반환."""
|
||||
p = Path(path)
|
||||
content = p.read_text(encoding="utf-8")
|
||||
frontmatter, _ = SkillLoader._split_frontmatter(content)
|
||||
meta = yaml.safe_load(frontmatter) or {}
|
||||
return SkillMetadata(
|
||||
name=meta.get("name", ""),
|
||||
description=meta.get("description", ""),
|
||||
path=str(p),
|
||||
source=source,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _split_frontmatter(content: str) -> tuple[str, str]:
|
||||
"""--- 마커 사이의 YAML frontmatter와 나머지 body를 분리."""
|
||||
stripped = content.strip()
|
||||
if not stripped.startswith("---"):
|
||||
return "", content
|
||||
|
||||
# 두 번째 --- 찾기
|
||||
end_idx = stripped.find("---", 3)
|
||||
if end_idx == -1:
|
||||
return "", content
|
||||
|
||||
frontmatter = stripped[3:end_idx].strip()
|
||||
body = stripped[end_idx + 3 :]
|
||||
return frontmatter, body
|
||||
52
backend/app/agents/skills/registry.py
Normal file
52
backend/app/agents/skills/registry.py
Normal file
@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from app.agents.skills.loader import SkillLoader
|
||||
from app.agents.skills.types import Skill, SkillMetadata
|
||||
|
||||
|
||||
class SkillRegistry:
|
||||
"""builtin 스킬 탐색 및 캐싱."""
|
||||
|
||||
BUILTIN_DIR: Path = Path(__file__).parent / "builtin"
|
||||
_cache: dict[str, SkillMetadata] = {}
|
||||
|
||||
@classmethod
|
||||
def discover(cls) -> list[SkillMetadata]:
|
||||
"""BUILTIN_DIR 하위의 */SKILL.md 패턴을 스캔하여 메타데이터 캐싱."""
|
||||
cls._cache.clear()
|
||||
for skill_file in sorted(cls.BUILTIN_DIR.glob("*/SKILL.md")):
|
||||
meta = SkillLoader.extract_metadata(skill_file, "builtin")
|
||||
if meta.name:
|
||||
cls._cache[meta.name] = meta
|
||||
return list(cls._cache.values())
|
||||
|
||||
@classmethod
|
||||
def get(cls, name: str) -> Skill | None:
|
||||
"""캐시에서 이름으로 조회 후 전체 Skill(instructions 포함) 로드."""
|
||||
if not cls._cache:
|
||||
cls.discover()
|
||||
meta = cls._cache.get(name)
|
||||
if meta is None:
|
||||
return None
|
||||
return SkillLoader.load_from_path(meta.path, meta.source)
|
||||
|
||||
@classmethod
|
||||
def list_skills(cls) -> list[SkillMetadata]:
|
||||
return list(cls._cache.values())
|
||||
|
||||
@classmethod
|
||||
def build_skills_section(cls) -> str:
|
||||
"""시스템 프롬프트에 삽입할 스킬 목록 텍스트 생성."""
|
||||
skills = cls.list_skills()
|
||||
if not skills:
|
||||
return ""
|
||||
lines = ["## 사용 가능한 스킬\n"]
|
||||
for s in skills:
|
||||
lines.append(f"- **{s.name}**: {s.description}")
|
||||
return "\n".join(lines)
|
||||
|
||||
@classmethod
|
||||
def clear_cache(cls) -> None:
|
||||
cls._cache.clear()
|
||||
35
backend/app/agents/skills/tool.py
Normal file
35
backend/app/agents/skills/tool.py
Normal file
@ -0,0 +1,35 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.agents.skills.registry import SkillRegistry
|
||||
from app.agents.tools.types import RegisteredTool, ToolResult
|
||||
|
||||
|
||||
def create_skill_tool() -> RegisteredTool:
|
||||
"""use_skill 도구를 생성하여 반환."""
|
||||
|
||||
async def execute(args: dict[str, Any]) -> ToolResult:
|
||||
skill_name: str = args["skill_name"]
|
||||
skill = SkillRegistry.get(skill_name)
|
||||
if skill is None:
|
||||
available = SkillRegistry.list_skills()
|
||||
names = [s.name for s in available]
|
||||
return ToolResult(
|
||||
data=(
|
||||
f"스킬 '{skill_name}'을(를) 찾을 수 없습니다. "
|
||||
f"사용 가능한 스킬: {', '.join(names) or '없음'}"
|
||||
),
|
||||
)
|
||||
return ToolResult(data=skill.instructions)
|
||||
|
||||
return RegisteredTool(
|
||||
name="use_skill",
|
||||
description=(
|
||||
"스킬을 로드하여 전문 분석 워크플로우를 실행합니다. "
|
||||
"사용 가능한 스킬 목록에서 선택하세요."
|
||||
),
|
||||
compact_description="스킬 실행",
|
||||
execute=execute,
|
||||
concurrency_safe=True,
|
||||
)
|
||||
23
backend/app/agents/skills/types.py
Normal file
23
backend/app/agents/skills/types.py
Normal file
@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
SkillSource = Literal["builtin", "user", "project"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillMetadata:
|
||||
name: str
|
||||
description: str
|
||||
path: str
|
||||
source: SkillSource
|
||||
|
||||
|
||||
@dataclass
|
||||
class Skill:
|
||||
name: str
|
||||
description: str
|
||||
path: str
|
||||
source: SkillSource
|
||||
instructions: str
|
||||
6
backend/app/agents/tools/__init__.py
Normal file
6
backend/app/agents/tools/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.tools.types import ToolResult, RegisteredTool
|
||||
from app.agents.tools.registry import ToolRegistry
|
||||
|
||||
__all__ = ["ToolResult", "RegisteredTool", "ToolRegistry"]
|
||||
0
backend/app/agents/tools/filesystem/__init__.py
Normal file
0
backend/app/agents/tools/filesystem/__init__.py
Normal file
69
backend/app/agents/tools/filesystem/edit_file.py
Normal file
69
backend/app/agents/tools/filesystem/edit_file.py
Normal file
@ -0,0 +1,69 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.types import RegisteredTool, ToolResult
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[5]
|
||||
|
||||
|
||||
def _validate_path(raw_path: str) -> Path:
|
||||
"""경로를 검증하고 절대 경로로 변환합니다."""
|
||||
resolved = (PROJECT_ROOT / raw_path).resolve()
|
||||
if not resolved.is_relative_to(PROJECT_ROOT):
|
||||
raise ValueError(f"프로젝트 외부 경로 접근 불가: {raw_path}")
|
||||
return resolved
|
||||
|
||||
|
||||
def _edit_sync(path: Path, old_string: str, new_string: str) -> str:
|
||||
content = path.read_text(encoding="utf-8")
|
||||
count = content.count(old_string)
|
||||
if count == 0:
|
||||
raise ValueError("old_string을 파일에서 찾을 수 없습니다")
|
||||
if count > 1:
|
||||
raise ValueError(f"old_string이 {count}회 발견됨 — 고유한 문자열을 제공해 주세요")
|
||||
updated = content.replace(old_string, new_string, 1)
|
||||
path.write_text(updated, encoding="utf-8")
|
||||
return "파일 수정 완료"
|
||||
|
||||
|
||||
def create_edit_file_tool() -> RegisteredTool:
|
||||
"""파일 수정 도구를 생성합니다."""
|
||||
|
||||
async def execute(params: dict[str, Any]) -> ToolResult:
|
||||
raw_path: str = params["path"]
|
||||
old_string: str = params["old_string"]
|
||||
new_string: str = params["new_string"]
|
||||
|
||||
try:
|
||||
path = _validate_path(raw_path)
|
||||
except ValueError as e:
|
||||
return ToolResult(data=f"오류: {e}")
|
||||
|
||||
if not path.exists():
|
||||
return ToolResult(data=f"파일 없음: {raw_path}")
|
||||
if not path.is_file():
|
||||
return ToolResult(data=f"파일이 아님: {raw_path}")
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
result = await loop.run_in_executor(
|
||||
None, partial(_edit_sync, path, old_string, new_string)
|
||||
)
|
||||
except ValueError as e:
|
||||
return ToolResult(data=f"수정 실패: {e}")
|
||||
except Exception as e:
|
||||
return ToolResult(data=f"파일 수정 실패: {e}")
|
||||
|
||||
return ToolResult(data=result)
|
||||
|
||||
return RegisteredTool(
|
||||
name="edit_file",
|
||||
description="파일의 특정 부분을 수정합니다. 정확한 old_string을 제공해야 합니다.",
|
||||
compact_description="파일 수정",
|
||||
concurrency_safe=False,
|
||||
execute=execute,
|
||||
)
|
||||
65
backend/app/agents/tools/filesystem/read_file.py
Normal file
65
backend/app/agents/tools/filesystem/read_file.py
Normal file
@ -0,0 +1,65 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.types import RegisteredTool, ToolResult
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[5]
|
||||
|
||||
|
||||
def _validate_path(raw_path: str) -> Path:
|
||||
"""경로를 검증하고 절대 경로로 변환합니다.
|
||||
|
||||
프로젝트 루트 외부 접근 및 경로 순회 공격을 차단합니다.
|
||||
"""
|
||||
resolved = (PROJECT_ROOT / raw_path).resolve()
|
||||
if not resolved.is_relative_to(PROJECT_ROOT):
|
||||
raise ValueError(f"프로젝트 외부 경로 접근 불가: {raw_path}")
|
||||
return resolved
|
||||
|
||||
|
||||
def _read_sync(path: Path, offset: int, limit: int) -> str:
|
||||
text = path.read_text(encoding="utf-8")
|
||||
lines = text.splitlines(keepends=True)
|
||||
selected = lines[offset : offset + limit]
|
||||
return "".join(selected)
|
||||
|
||||
|
||||
def create_read_file_tool() -> RegisteredTool:
|
||||
"""파일 읽기 도구를 생성합니다."""
|
||||
|
||||
async def execute(params: dict[str, Any]) -> ToolResult:
|
||||
raw_path: str = params["path"]
|
||||
offset: int = params.get("offset", 0)
|
||||
limit: int = params.get("limit", 2000)
|
||||
|
||||
try:
|
||||
path = _validate_path(raw_path)
|
||||
except ValueError as e:
|
||||
return ToolResult(data=f"오류: {e}")
|
||||
|
||||
if not path.exists():
|
||||
return ToolResult(data=f"파일 없음: {raw_path}")
|
||||
if not path.is_file():
|
||||
return ToolResult(data=f"파일이 아님: {raw_path}")
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
content = await loop.run_in_executor(
|
||||
None, partial(_read_sync, path, offset, limit)
|
||||
)
|
||||
except Exception as e:
|
||||
return ToolResult(data=f"파일 읽기 실패: {e}")
|
||||
|
||||
return ToolResult(data=content)
|
||||
|
||||
return RegisteredTool(
|
||||
name="read_file",
|
||||
description="파일 내용을 읽습니다. 프로젝트 내 파일만 접근 가능합니다.",
|
||||
compact_description="파일 읽기",
|
||||
concurrency_safe=True,
|
||||
execute=execute,
|
||||
)
|
||||
52
backend/app/agents/tools/filesystem/write_file.py
Normal file
52
backend/app/agents/tools/filesystem/write_file.py
Normal file
@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.types import RegisteredTool, ToolResult
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[5]
|
||||
|
||||
|
||||
def _validate_path(raw_path: str) -> Path:
|
||||
"""경로를 검증하고 절대 경로로 변환합니다."""
|
||||
resolved = (PROJECT_ROOT / raw_path).resolve()
|
||||
if not resolved.is_relative_to(PROJECT_ROOT):
|
||||
raise ValueError(f"프로젝트 외부 경로 접근 불가: {raw_path}")
|
||||
return resolved
|
||||
|
||||
|
||||
def _write_sync(path: Path, content: str) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(content, encoding="utf-8")
|
||||
|
||||
|
||||
def create_write_file_tool() -> RegisteredTool:
|
||||
"""파일 쓰기 도구를 생성합니다."""
|
||||
|
||||
async def execute(params: dict[str, Any]) -> ToolResult:
|
||||
raw_path: str = params["path"]
|
||||
content: str = params["content"]
|
||||
|
||||
try:
|
||||
path = _validate_path(raw_path)
|
||||
except ValueError as e:
|
||||
return ToolResult(data=f"오류: {e}")
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
await loop.run_in_executor(None, partial(_write_sync, path, content))
|
||||
except Exception as e:
|
||||
return ToolResult(data=f"파일 쓰기 실패: {e}")
|
||||
|
||||
return ToolResult(data=f"파일 작성 완료: {raw_path}")
|
||||
|
||||
return RegisteredTool(
|
||||
name="write_file",
|
||||
description="파일에 내용을 씁니다. 프로젝트 내 파일만 수정 가능합니다.",
|
||||
compact_description="파일 쓰기",
|
||||
concurrency_safe=False,
|
||||
execute=execute,
|
||||
)
|
||||
0
backend/app/agents/tools/finance/__init__.py
Normal file
0
backend/app/agents/tools/finance/__init__.py
Normal file
29
backend/app/agents/tools/finance/get_financials.py
Normal file
29
backend/app/agents/tools/finance/get_financials.py
Normal file
@ -0,0 +1,29 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.tools.meta_tool import MetaTool, SubTool
|
||||
from app.agents.tools.finance.sub_tools import (
|
||||
get_financial_statements,
|
||||
get_market_metrics,
|
||||
)
|
||||
|
||||
|
||||
class FinancialsTool(MetaTool):
|
||||
"""재무제표, 재무비율, 밸류에이션 지표를 조회하는 메타 도구."""
|
||||
|
||||
NAME = "get_financials"
|
||||
DESCRIPTION = "재무제표, 재무비율, 밸류에이션 지표 조회"
|
||||
|
||||
@property
|
||||
def sub_tools(self) -> list[SubTool]:
|
||||
return [
|
||||
SubTool(
|
||||
name="get_financial_statements",
|
||||
description="종목의 BPS, PER, PBR, EPS 등 재무제표 데이터 조회 (ticker, year)",
|
||||
handler=get_financial_statements,
|
||||
),
|
||||
SubTool(
|
||||
name="get_market_metrics",
|
||||
description="종목의 시가총액, PER, PBR 등 시장 밸류에이션 지표 조회 (ticker, date)",
|
||||
handler=get_market_metrics,
|
||||
),
|
||||
]
|
||||
26
backend/app/agents/tools/finance/get_market_data.py
Normal file
26
backend/app/agents/tools/finance/get_market_data.py
Normal file
@ -0,0 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.tools.meta_tool import MetaTool, SubTool
|
||||
from app.agents.tools.finance.sub_tools import get_stock_price, get_market_index
|
||||
|
||||
|
||||
class MarketDataTool(MetaTool):
|
||||
"""주가, 거래량, 시장 지수 등 시장 데이터를 조회하는 메타 도구."""
|
||||
|
||||
NAME = "get_market_data"
|
||||
DESCRIPTION = "주가, 거래량, 시장 지수 등 시장 데이터 조회"
|
||||
|
||||
@property
|
||||
def sub_tools(self) -> list[SubTool]:
|
||||
return [
|
||||
SubTool(
|
||||
name="get_stock_price",
|
||||
description="종목 OHLCV(시가/고가/저가/종가/거래량) 데이터 조회 (ticker, start_date, end_date)",
|
||||
handler=get_stock_price,
|
||||
),
|
||||
SubTool(
|
||||
name="get_market_index",
|
||||
description="KOSPI/KOSDAQ 지수 OHLCV 데이터 조회 (index_code, start_date, end_date)",
|
||||
handler=get_market_index,
|
||||
),
|
||||
]
|
||||
124
backend/app/agents/tools/finance/sub_tools.py
Normal file
124
backend/app/agents/tools/finance/sub_tools.py
Normal file
@ -0,0 +1,124 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.types import ToolResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_stock_price(params: dict[str, Any]) -> ToolResult:
|
||||
"""종목 OHLCV 데이터를 pykrx로 조회한다."""
|
||||
from pykrx import stock as pykrx_stock
|
||||
|
||||
ticker = params.get("ticker", "")
|
||||
start_date = params.get("start_date", "")
|
||||
end_date = params.get("end_date", "")
|
||||
|
||||
if not all([ticker, start_date, end_date]):
|
||||
return ToolResult(data="Error: ticker, start_date, end_date 모두 필수입니다.")
|
||||
|
||||
try:
|
||||
df = pykrx_stock.get_market_ohlcv_by_date(start_date, end_date, ticker)
|
||||
if df.empty:
|
||||
return ToolResult(data=f"데이터 없음: {ticker} ({start_date}~{end_date})")
|
||||
|
||||
records = df.reset_index().to_dict(orient="records")
|
||||
return ToolResult(
|
||||
data={"ticker": ticker, "period": f"{start_date}~{end_date}", "ohlcv": records}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("get_stock_price failed for %s", ticker)
|
||||
return ToolResult(data=f"Error: {e}")
|
||||
|
||||
|
||||
async def get_financial_statements(params: dict[str, Any]) -> ToolResult:
|
||||
"""종목의 재무제표 데이터를 pykrx로 조회한다."""
|
||||
from pykrx import stock as pykrx_stock
|
||||
|
||||
ticker = params.get("ticker", "")
|
||||
year = params.get("year", "")
|
||||
|
||||
if not all([ticker, year]):
|
||||
return ToolResult(data="Error: ticker, year 모두 필수입니다.")
|
||||
|
||||
try:
|
||||
# pykrx 기본 재무 데이터: BPS, PER, PBR, EPS, DIV, DPS
|
||||
start_date = f"{year}0101"
|
||||
end_date = f"{year}1231"
|
||||
df = pykrx_stock.get_market_fundamental_by_date(start_date, end_date, ticker)
|
||||
|
||||
if df.empty:
|
||||
return ToolResult(data=f"재무 데이터 없음: {ticker} ({year})")
|
||||
|
||||
# 연말 기준 최신 데이터
|
||||
latest = df.iloc[-1].to_dict()
|
||||
return ToolResult(
|
||||
data={"ticker": ticker, "year": year, "fundamentals": latest}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("get_financial_statements failed for %s", ticker)
|
||||
return ToolResult(data=f"Error: {e}")
|
||||
|
||||
|
||||
async def get_market_metrics(params: dict[str, Any]) -> ToolResult:
|
||||
"""종목의 시가총액, PER, PBR 등 시장 지표를 조회한다."""
|
||||
from pykrx import stock as pykrx_stock
|
||||
|
||||
ticker = params.get("ticker", "")
|
||||
date = params.get("date", "")
|
||||
|
||||
if not all([ticker, date]):
|
||||
return ToolResult(data="Error: ticker, date 모두 필수입니다.")
|
||||
|
||||
try:
|
||||
# 시가총액
|
||||
cap_df = pykrx_stock.get_market_cap_by_date(date, date, ticker)
|
||||
# 밸류에이션
|
||||
fund_df = pykrx_stock.get_market_fundamental_by_date(date, date, ticker)
|
||||
|
||||
result: dict[str, Any] = {"ticker": ticker, "date": date}
|
||||
|
||||
if not cap_df.empty:
|
||||
result["market_cap"] = cap_df.iloc[-1].to_dict()
|
||||
if not fund_df.empty:
|
||||
result["valuation"] = fund_df.iloc[-1].to_dict()
|
||||
|
||||
if len(result) == 2:
|
||||
return ToolResult(data=f"시장 지표 없음: {ticker} ({date})")
|
||||
|
||||
return ToolResult(data=result)
|
||||
except Exception as e:
|
||||
logger.exception("get_market_metrics failed for %s", ticker)
|
||||
return ToolResult(data=f"Error: {e}")
|
||||
|
||||
|
||||
async def get_market_index(params: dict[str, Any]) -> ToolResult:
|
||||
"""KOSPI/KOSDAQ 지수 데이터를 조회한다."""
|
||||
from pykrx import stock as pykrx_stock
|
||||
|
||||
index_code = params.get("index_code", "1001") # 기본: 코스피
|
||||
start_date = params.get("start_date", "")
|
||||
end_date = params.get("end_date", "")
|
||||
|
||||
if not all([start_date, end_date]):
|
||||
return ToolResult(data="Error: start_date, end_date 모두 필수입니다.")
|
||||
|
||||
try:
|
||||
df = pykrx_stock.get_index_ohlcv_by_date(start_date, end_date, index_code)
|
||||
|
||||
if df.empty:
|
||||
return ToolResult(data=f"지수 데이터 없음: {index_code} ({start_date}~{end_date})")
|
||||
|
||||
records = df.reset_index().to_dict(orient="records")
|
||||
return ToolResult(
|
||||
data={
|
||||
"index_code": index_code,
|
||||
"period": f"{start_date}~{end_date}",
|
||||
"ohlcv": records,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("get_market_index failed for %s", index_code)
|
||||
return ToolResult(data=f"Error: {e}")
|
||||
80
backend/app/agents/tools/meta_tool.py
Normal file
80
backend/app/agents/tools/meta_tool.py
Normal file
@ -0,0 +1,80 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
from app.agents.tools.types import ToolResult
|
||||
from app.services.llm import ModelTier, get_chat_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SubTool:
|
||||
name: str
|
||||
description: str
|
||||
handler: Callable[[dict[str, Any]], Awaitable[ToolResult]]
|
||||
|
||||
|
||||
class MetaTool(ABC):
|
||||
"""NL 쿼리를 LLM으로 라우팅해 적절한 sub-tool을 실행하는 추상 베이스 클래스."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def sub_tools(self) -> list[SubTool]: ...
|
||||
|
||||
async def route(self, query: str) -> SubTool:
|
||||
"""LLM을 사용해 쿼리에 가장 적합한 sub-tool을 선택한다."""
|
||||
tool_list = "\n".join(
|
||||
f"- {st.name}: {st.description}" for st in self.sub_tools
|
||||
)
|
||||
|
||||
prompt = (
|
||||
"You are a tool router. Given a user query and available tools, "
|
||||
"select the best tool.\n\n"
|
||||
f"Available tools:\n{tool_list}\n\n"
|
||||
f"User query: {query}\n\n"
|
||||
'Respond with ONLY a JSON object: {{"tool": "tool_name"}}'
|
||||
)
|
||||
|
||||
llm = get_chat_model(tier=ModelTier.FAST, temperature=0.0)
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
content = response.content
|
||||
if isinstance(content, list):
|
||||
content = content[0] if content else ""
|
||||
content = str(content).strip()
|
||||
|
||||
if content.startswith("```"):
|
||||
content = content.split("\n", 1)[-1]
|
||||
content = content.rsplit("```", 1)[0]
|
||||
|
||||
parsed = json.loads(content)
|
||||
tool_name = parsed["tool"]
|
||||
|
||||
tool_map = {st.name: st for st in self.sub_tools}
|
||||
sub_tool = tool_map.get(tool_name)
|
||||
if sub_tool is None:
|
||||
raise ValueError(
|
||||
f"LLM selected unknown sub-tool '{tool_name}'. "
|
||||
f"Available: {list(tool_map.keys())}"
|
||||
)
|
||||
|
||||
return sub_tool
|
||||
|
||||
async def execute(self, params: dict[str, Any]) -> ToolResult:
|
||||
"""NL 쿼리를 받아 라우팅 후 sub-tool을 실행한다."""
|
||||
query = params.get("query", "")
|
||||
if not query:
|
||||
return ToolResult(data="Error: 'query' parameter is required.")
|
||||
|
||||
try:
|
||||
sub_tool = await self.route(query)
|
||||
logger.info("MetaTool routed to %s", sub_tool.name)
|
||||
return await sub_tool.handler(params)
|
||||
except Exception as e:
|
||||
logger.exception("MetaTool execution failed")
|
||||
return ToolResult(data=f"Error: {e}")
|
||||
56
backend/app/agents/tools/registry.py
Normal file
56
backend/app/agents/tools/registry.py
Normal file
@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.tools.types import RegisteredTool
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
"""등록된 도구를 관리하는 레지스트리."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._tools: dict[str, RegisteredTool] = {}
|
||||
|
||||
def register(self, tool: RegisteredTool) -> None:
|
||||
self._tools[tool.name] = tool
|
||||
|
||||
def get(self, name: str) -> RegisteredTool | None:
|
||||
return self._tools.get(name)
|
||||
|
||||
def list_tools(self) -> list[RegisteredTool]:
|
||||
return list(self._tools.values())
|
||||
|
||||
def get_concurrency_map(self) -> dict[str, bool]:
|
||||
return {t.name: t.concurrency_safe for t in self._tools.values()}
|
||||
|
||||
def build_compact_descriptions(self) -> str:
|
||||
lines = [
|
||||
f"- {t.name}: {t.compact_description}" for t in self._tools.values()
|
||||
]
|
||||
return "\n".join(lines)
|
||||
|
||||
@classmethod
|
||||
def auto_register(cls) -> ToolRegistry:
|
||||
"""설정에 따라 사용 가능한 도구를 자동 등록한다."""
|
||||
from app.agents.tools.finance.get_market_data import MarketDataTool
|
||||
from app.agents.tools.finance.get_financials import FinancialsTool
|
||||
|
||||
registry = cls()
|
||||
|
||||
market = MarketDataTool()
|
||||
registry.register(RegisteredTool(
|
||||
name=MarketDataTool.NAME,
|
||||
description=MarketDataTool.DESCRIPTION,
|
||||
compact_description=MarketDataTool.DESCRIPTION,
|
||||
concurrency_safe=True,
|
||||
execute=market.execute,
|
||||
))
|
||||
|
||||
financials = FinancialsTool()
|
||||
registry.register(RegisteredTool(
|
||||
name=FinancialsTool.NAME,
|
||||
description=FinancialsTool.DESCRIPTION,
|
||||
compact_description=FinancialsTool.DESCRIPTION,
|
||||
concurrency_safe=True,
|
||||
execute=financials.execute,
|
||||
))
|
||||
|
||||
return registry
|
||||
0
backend/app/agents/tools/search/__init__.py
Normal file
0
backend/app/agents/tools/search/__init__.py
Normal file
40
backend/app/agents/tools/search/news_search.py
Normal file
40
backend/app/agents/tools/search/news_search.py
Normal file
@ -0,0 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.types import RegisteredTool, ToolResult
|
||||
|
||||
|
||||
def create_news_search_tool() -> RegisteredTool:
|
||||
"""금융 뉴스 검색 도구를 생성합니다."""
|
||||
|
||||
async def execute(params: dict[str, Any]) -> ToolResult:
|
||||
query: str = params["query"]
|
||||
num_results: int = params.get("num_results", 5)
|
||||
|
||||
# TODO: settings에서 API 키 로드 후 뉴스 검색 API 연동
|
||||
try:
|
||||
from app.core.config import settings
|
||||
|
||||
api_key = getattr(settings, "news_search_api_key", None)
|
||||
except Exception:
|
||||
api_key = None
|
||||
|
||||
if not api_key:
|
||||
return ToolResult(
|
||||
data=f"검색 API 미설정: '{query}' (요청 결과 수: {num_results}). "
|
||||
"news_search_api_key 환경변수를 설정해 주세요.",
|
||||
)
|
||||
|
||||
# 향후 실제 뉴스 검색 API 호출 구현
|
||||
# async with httpx.AsyncClient() as client:
|
||||
# response = await client.get(...)
|
||||
return ToolResult(data=f"뉴스 검색 결과 없음: '{query}'")
|
||||
|
||||
return RegisteredTool(
|
||||
name="news_search",
|
||||
description="금융 뉴스 및 시장 동향을 검색합니다",
|
||||
compact_description="뉴스 검색",
|
||||
concurrency_safe=True,
|
||||
execute=execute,
|
||||
)
|
||||
40
backend/app/agents/tools/search/web_search.py
Normal file
40
backend/app/agents/tools/search/web_search.py
Normal file
@ -0,0 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.agents.tools.types import RegisteredTool, ToolResult
|
||||
|
||||
|
||||
def create_web_search_tool() -> RegisteredTool:
|
||||
"""웹 검색 도구를 생성합니다."""
|
||||
|
||||
async def execute(params: dict[str, Any]) -> ToolResult:
|
||||
query: str = params["query"]
|
||||
num_results: int = params.get("num_results", 5)
|
||||
|
||||
# TODO: settings에서 API 키 로드 후 Tavily/Exa/SerpAPI 연동
|
||||
try:
|
||||
from app.core.config import settings
|
||||
|
||||
api_key = getattr(settings, "web_search_api_key", None)
|
||||
except Exception:
|
||||
api_key = None
|
||||
|
||||
if not api_key:
|
||||
return ToolResult(
|
||||
data=f"검색 API 미설정: '{query}' (요청 결과 수: {num_results}). "
|
||||
"web_search_api_key 환경변수를 설정해 주세요.",
|
||||
)
|
||||
|
||||
# 향후 실제 검색 API 호출 구현
|
||||
# async with httpx.AsyncClient() as client:
|
||||
# response = await client.get(...)
|
||||
return ToolResult(data=f"검색 결과 없음: '{query}'")
|
||||
|
||||
return RegisteredTool(
|
||||
name="web_search",
|
||||
description="웹 검색을 수행하여 최신 정보를 조회합니다",
|
||||
compact_description="웹 검색",
|
||||
concurrency_safe=True,
|
||||
execute=execute,
|
||||
)
|
||||
25
backend/app/agents/tools/types.py
Normal file
25
backend/app/agents/tools/types.py
Normal file
@ -0,0 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolResult:
|
||||
data: Any
|
||||
source_urls: list[str] = field(default_factory=list)
|
||||
|
||||
def to_str(self) -> str:
|
||||
if isinstance(self.data, str):
|
||||
return self.data
|
||||
return json.dumps(self.data, ensure_ascii=False, default=str)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RegisteredTool:
|
||||
name: str
|
||||
description: str
|
||||
compact_description: str
|
||||
concurrency_safe: bool
|
||||
execute: Callable[[dict[str, Any]], Awaitable[ToolResult]]
|
||||
141
backend/app/api/agents.py
Normal file
141
backend/app/api/agents.py
Normal file
@ -0,0 +1,141 @@
|
||||
"""
|
||||
Agent API endpoints — 자연어 쿼리 기반 투자 분석 에이전트.
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.api.deps import CurrentUser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/agent", tags=["agent"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class AgentQueryRequest(BaseModel):
|
||||
query: str
|
||||
model_tier: str = "strong"
|
||||
|
||||
|
||||
class ToolCallLog(BaseModel):
|
||||
tool_name: str
|
||||
params: dict[str, Any]
|
||||
result: str
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class AgentQueryResponse(BaseModel):
|
||||
response: str
|
||||
tool_calls: list[ToolCallLog]
|
||||
iterations: int
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _parse_model_tier(raw: str) -> "ModelTier":
|
||||
"""model_tier 문자열을 ModelTier enum으로 변환."""
|
||||
from app.agents.core.agent import ModelTier
|
||||
|
||||
mapping = {
|
||||
"strong": ModelTier.STRONG,
|
||||
"fast": ModelTier.FAST,
|
||||
}
|
||||
tier = mapping.get(raw.lower())
|
||||
if tier is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail=f"Invalid model_tier '{raw}'. Must be one of: {list(mapping.keys())}",
|
||||
)
|
||||
return tier
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/query", response_model=AgentQueryResponse)
|
||||
async def agent_query(body: AgentQueryRequest, user: CurrentUser):
|
||||
"""동기 응답 — 에이전트 실행 후 최종 결과 반환."""
|
||||
from app.agents.core.agent import Agent, AgentConfig
|
||||
|
||||
model_tier = _parse_model_tier(body.model_tier)
|
||||
config = AgentConfig(model_tier=model_tier)
|
||||
agent = Agent(config=config)
|
||||
|
||||
tool_calls: list[ToolCallLog] = []
|
||||
response_text = ""
|
||||
iterations = 0
|
||||
|
||||
try:
|
||||
async for event in agent.run(body.query):
|
||||
if event.type == "tool_end":
|
||||
tool_calls.append(
|
||||
ToolCallLog(
|
||||
tool_name=event.data.get("tool_name", ""),
|
||||
params=event.data.get("params", {}),
|
||||
result=event.data.get("result", ""),
|
||||
error=event.data.get("error"),
|
||||
)
|
||||
)
|
||||
elif event.type == "response":
|
||||
response_text = event.data.get("text", "")
|
||||
elif event.type == "done":
|
||||
iterations = event.data.get("iterations", 0)
|
||||
except Exception:
|
||||
logger.exception("Agent query failed")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="에이전트 실행 중 오류가 발생했습니다.",
|
||||
)
|
||||
|
||||
return AgentQueryResponse(
|
||||
response=response_text,
|
||||
tool_calls=tool_calls,
|
||||
iterations=iterations,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/stream")
|
||||
async def agent_stream(body: AgentQueryRequest, user: CurrentUser):
|
||||
"""SSE 스트리밍 — 도구 실행 과정을 실시간으로 전달."""
|
||||
from app.agents.core.agent import Agent, AgentConfig
|
||||
|
||||
model_tier = _parse_model_tier(body.model_tier)
|
||||
config = AgentConfig(model_tier=model_tier)
|
||||
agent = Agent(config=config)
|
||||
|
||||
async def _event_generator():
|
||||
try:
|
||||
async for event in agent.run(body.query):
|
||||
payload = json.dumps(
|
||||
{"type": event.type, "data": event.data},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
yield f"data: {payload}\n\n"
|
||||
except Exception:
|
||||
logger.exception("Agent stream failed")
|
||||
error_payload = json.dumps(
|
||||
{"type": "error", "data": {"message": "에이전트 실행 중 오류가 발생했습니다."}},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
yield f"data: {error_payload}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
_event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
92
backend/app/services/llm.py
Normal file
92
backend/app/services/llm.py
Normal file
@ -0,0 +1,92 @@
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
|
||||
from app.core.config import get_settings
|
||||
|
||||
|
||||
class LLMProvider(str, Enum):
|
||||
ANTHROPIC = "anthropic"
|
||||
OPENAI = "openai"
|
||||
|
||||
|
||||
class ModelTier(str, Enum):
|
||||
FAST = "fast"
|
||||
STRONG = "strong"
|
||||
|
||||
|
||||
_DEFAULTS = {
|
||||
LLMProvider.ANTHROPIC: {
|
||||
ModelTier.FAST: "claude-haiku-4-5-20251001",
|
||||
ModelTier.STRONG: "claude-sonnet-4-6",
|
||||
},
|
||||
LLMProvider.OPENAI: {
|
||||
ModelTier.FAST: "gpt-4.1-mini",
|
||||
ModelTier.STRONG: "gpt-4.1",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def detect_provider() -> LLMProvider:
|
||||
settings = get_settings()
|
||||
if settings.anthropic_api_key:
|
||||
return LLMProvider.ANTHROPIC
|
||||
if settings.openai_api_key:
|
||||
return LLMProvider.OPENAI
|
||||
raise RuntimeError(
|
||||
"No LLM API key configured. Set ANTHROPIC_API_KEY or OPENAI_API_KEY."
|
||||
)
|
||||
|
||||
|
||||
def _resolve_model_name(provider: LLMProvider, tier: ModelTier) -> str:
|
||||
settings = get_settings()
|
||||
if tier == ModelTier.FAST and settings.llm_fast_model:
|
||||
return settings.llm_fast_model
|
||||
if tier == ModelTier.STRONG and settings.llm_strong_model:
|
||||
return settings.llm_strong_model
|
||||
return _DEFAULTS[provider][tier]
|
||||
|
||||
|
||||
def _build_chat_model(
|
||||
provider: LLMProvider, model_name: str, temperature: float
|
||||
) -> BaseChatModel:
|
||||
settings = get_settings()
|
||||
|
||||
if provider == LLMProvider.ANTHROPIC:
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
|
||||
return ChatAnthropic(
|
||||
model=model_name,
|
||||
api_key=settings.anthropic_api_key,
|
||||
temperature=temperature,
|
||||
max_tokens=4096,
|
||||
)
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
return ChatOpenAI(
|
||||
model=model_name,
|
||||
api_key=settings.openai_api_key,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
|
||||
def get_chat_model(
|
||||
tier: ModelTier = ModelTier.STRONG,
|
||||
temperature: float = 0.0,
|
||||
provider: LLMProvider | None = None,
|
||||
) -> BaseChatModel:
|
||||
resolved_provider = provider or detect_provider()
|
||||
model_name = _resolve_model_name(resolved_provider, tier)
|
||||
return _build_chat_model(resolved_provider, model_name, temperature)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_fast_model() -> BaseChatModel:
|
||||
return get_chat_model(tier=ModelTier.FAST)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_strong_model() -> BaseChatModel:
|
||||
return get_chat_model(tier=ModelTier.STRONG)
|
||||
379
backend/tests/unit/test_agent_core.py
Normal file
379
backend/tests/unit/test_agent_core.py
Normal file
@ -0,0 +1,379 @@
|
||||
"""Tests for agent core: scratchpad, context compactor, tool executor, and agent."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
|
||||
from app.agents.tools.types import ToolResult, RegisteredTool
|
||||
from app.agents.tools.registry import ToolRegistry
|
||||
from app.agents.core.scratchpad import Scratchpad
|
||||
from app.agents.core.compact import ContextCompactor
|
||||
from app.agents.core.tool_executor import ToolExecutor, ToolCall, ToolEvent
|
||||
from app.agents.core.agent import Agent, AgentEvent
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scratchpad
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestScratchpad:
|
||||
def test_add_and_total_tokens(self):
|
||||
sp = Scratchpad()
|
||||
sp.add(tool_name="search", params={"q": "test"}, result="found 3 results")
|
||||
sp.add(tool_name="calc", params={"expr": "2+2"}, result="4")
|
||||
assert sp.total_tokens() > 0
|
||||
|
||||
def test_format_history(self):
|
||||
sp = Scratchpad()
|
||||
sp.add(tool_name="search", params={"q": "python"}, result="python docs")
|
||||
history = sp.format_history()
|
||||
assert "search" in history
|
||||
assert "python" in history
|
||||
|
||||
def test_format_history_empty(self):
|
||||
sp = Scratchpad()
|
||||
history = sp.format_history()
|
||||
assert isinstance(history, str)
|
||||
|
||||
def test_clear(self):
|
||||
sp = Scratchpad()
|
||||
sp.add(tool_name="tool1", params={}, result="result_one")
|
||||
sp.add(tool_name="tool2", params={}, result="result_two")
|
||||
assert sp.total_tokens() > 0
|
||||
sp.clear()
|
||||
assert sp.total_tokens() == 0
|
||||
|
||||
def test_multiple_entries_tracked(self):
|
||||
sp = Scratchpad()
|
||||
for i in range(5):
|
||||
sp.add(tool_name=f"tool_{i}", params={"i": i}, result=f"result_{i}")
|
||||
history = sp.format_history()
|
||||
for i in range(5):
|
||||
assert f"tool_{i}" in history
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ContextCompactor
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestContextCompactor:
|
||||
def test_should_compact_false_below_threshold(self):
|
||||
compactor = ContextCompactor()
|
||||
messages = [{"role": "user", "content": "short message"}]
|
||||
assert compactor.should_compact(messages, threshold=10000) is False
|
||||
|
||||
def test_should_compact_true_above_threshold(self):
|
||||
compactor = ContextCompactor()
|
||||
long_content = "x" * 5000
|
||||
messages = [
|
||||
{"role": "user", "content": long_content},
|
||||
{"role": "assistant", "content": long_content},
|
||||
{"role": "user", "content": long_content},
|
||||
]
|
||||
assert compactor.should_compact(messages, threshold=100) is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.agents.core.compact.get_chat_model")
|
||||
async def test_compact_calls_llm(self, mock_get_model):
|
||||
mock_llm = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = "Summary: user asked about Python, assistant explained basics."
|
||||
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
||||
mock_get_model.return_value = mock_llm
|
||||
|
||||
compactor = ContextCompactor()
|
||||
messages = [
|
||||
{"role": "user", "content": "Tell me about Python"},
|
||||
{"role": "assistant", "content": "Python is a programming language..."},
|
||||
{"role": "user", "content": "What about async?"},
|
||||
{"role": "assistant", "content": "Async in Python uses asyncio..."},
|
||||
]
|
||||
result = await compactor.compact(messages, system_prompt="You are helpful.")
|
||||
assert isinstance(result, list)
|
||||
assert len(result) > 0
|
||||
mock_llm.ainvoke.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.agents.core.compact.get_chat_model")
|
||||
async def test_compact_preserves_system_context(self, mock_get_model):
|
||||
mock_llm = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = "Summarized conversation."
|
||||
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
||||
mock_get_model.return_value = mock_llm
|
||||
|
||||
compactor = ContextCompactor()
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
result = await compactor.compact(messages, system_prompt="System prompt here")
|
||||
assert isinstance(result, list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ToolExecutor
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_registered_tool(
|
||||
name: str, concurrency_safe: bool = True, result_data: str = "ok"
|
||||
) -> RegisteredTool:
|
||||
return RegisteredTool(
|
||||
name=name,
|
||||
description=f"Desc {name}",
|
||||
compact_description=f"compact {name}",
|
||||
concurrency_safe=concurrency_safe,
|
||||
execute=AsyncMock(return_value=ToolResult(data=result_data)),
|
||||
)
|
||||
|
||||
|
||||
class TestToolExecutor:
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_safe_tools(self):
|
||||
tool_a = _make_registered_tool("tool_a", concurrency_safe=True, result_data="A")
|
||||
tool_b = _make_registered_tool("tool_b", concurrency_safe=True, result_data="B")
|
||||
|
||||
registry = ToolRegistry()
|
||||
registry.register(tool_a)
|
||||
registry.register(tool_b)
|
||||
|
||||
executor = ToolExecutor(registry=registry)
|
||||
calls = [
|
||||
ToolCall(id="c1", name="tool_a", params={"q": "hello"}),
|
||||
ToolCall(id="c2", name="tool_b", params={"q": "world"}),
|
||||
]
|
||||
|
||||
events = []
|
||||
async for event in executor.execute_tool_calls(calls):
|
||||
events.append(event)
|
||||
|
||||
assert len(events) == 2
|
||||
names = {e.tool_name for e in events}
|
||||
assert "tool_a" in names
|
||||
assert "tool_b" in names
|
||||
|
||||
tool_a.execute.assert_called_once()
|
||||
tool_b.execute.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_serial_tools_run_sequentially(self):
|
||||
tool_s = _make_registered_tool("serial_tool", concurrency_safe=False, result_data="S")
|
||||
|
||||
registry = ToolRegistry()
|
||||
registry.register(tool_s)
|
||||
|
||||
executor = ToolExecutor(registry=registry)
|
||||
calls = [
|
||||
ToolCall(id="c1", name="serial_tool", params={"x": 1}),
|
||||
]
|
||||
|
||||
events = []
|
||||
async for event in executor.execute_tool_calls(calls):
|
||||
events.append(event)
|
||||
|
||||
assert len(events) == 1
|
||||
assert events[0].tool_name == "serial_tool"
|
||||
assert events[0].error is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling_for_failed_tool(self):
|
||||
failing_tool = RegisteredTool(
|
||||
name="fail_tool",
|
||||
description="A tool that fails",
|
||||
compact_description="fail",
|
||||
concurrency_safe=True,
|
||||
execute=AsyncMock(side_effect=RuntimeError("Tool exploded")),
|
||||
)
|
||||
|
||||
registry = ToolRegistry()
|
||||
registry.register(failing_tool)
|
||||
|
||||
executor = ToolExecutor(registry=registry)
|
||||
calls = [ToolCall(id="c1", name="fail_tool", params={})]
|
||||
|
||||
events = []
|
||||
async for event in executor.execute_tool_calls(calls):
|
||||
events.append(event)
|
||||
|
||||
assert len(events) == 1
|
||||
assert events[0].error is not None
|
||||
assert "exploded" in events[0].error.lower() or events[0].error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_concurrent_and_serial(self):
|
||||
safe = _make_registered_tool("safe", concurrency_safe=True, result_data="safe_r")
|
||||
unsafe = _make_registered_tool("unsafe", concurrency_safe=False, result_data="unsafe_r")
|
||||
|
||||
registry = ToolRegistry()
|
||||
registry.register(safe)
|
||||
registry.register(unsafe)
|
||||
|
||||
executor = ToolExecutor(registry=registry)
|
||||
calls = [
|
||||
ToolCall(id="c1", name="safe", params={}),
|
||||
ToolCall(id="c2", name="unsafe", params={}),
|
||||
]
|
||||
|
||||
events = []
|
||||
async for event in executor.execute_tool_calls(calls):
|
||||
events.append(event)
|
||||
|
||||
assert len(events) == 2
|
||||
results = {e.tool_name: e.result for e in events}
|
||||
assert results["safe"].data == "safe_r"
|
||||
assert results["unsafe"].data == "unsafe_r"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_tool_returns_error(self):
|
||||
registry = ToolRegistry()
|
||||
executor = ToolExecutor(registry=registry)
|
||||
calls = [ToolCall(id="c1", name="nonexistent", params={})]
|
||||
|
||||
events = []
|
||||
async for event in executor.execute_tool_calls(calls):
|
||||
events.append(event)
|
||||
|
||||
assert len(events) == 1
|
||||
assert events[0].error is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ToolCall & ToolEvent dataclass tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestToolCallAndEvent:
|
||||
def test_tool_call_creation(self):
|
||||
tc = ToolCall(id="abc", name="my_tool", params={"key": "val"})
|
||||
assert tc.id == "abc"
|
||||
assert tc.name == "my_tool"
|
||||
assert tc.params == {"key": "val"}
|
||||
|
||||
def test_tool_event_success(self):
|
||||
result = ToolResult(data="success")
|
||||
ev = ToolEvent(call_id="c1", tool_name="t1", result=result, error=None)
|
||||
assert ev.result.data == "success"
|
||||
assert ev.error is None
|
||||
|
||||
def test_tool_event_error(self):
|
||||
ev = ToolEvent(call_id="c1", tool_name="t1", result=None, error="something went wrong")
|
||||
assert ev.result is None
|
||||
assert ev.error == "something went wrong"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Agent
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAgent:
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.agents.core.agent.get_chat_model")
|
||||
async def test_run_yields_events(self, mock_get_model):
|
||||
mock_llm = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = "Here is the answer to your question."
|
||||
mock_response.tool_calls = []
|
||||
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
||||
mock_llm.bind_tools = MagicMock(return_value=mock_llm)
|
||||
mock_get_model.return_value = mock_llm
|
||||
|
||||
registry = ToolRegistry()
|
||||
agent = Agent(tool_registry=registry)
|
||||
|
||||
events = []
|
||||
async for event in agent.run("What is 2+2?"):
|
||||
events.append(event)
|
||||
|
||||
assert len(events) >= 1
|
||||
event_types = [e.type for e in events]
|
||||
assert any(t in event_types for t in ["message", "answer", "response", "text"])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.agents.core.agent.get_chat_model")
|
||||
async def test_run_with_tool_call(self, mock_get_model):
|
||||
tool_call_response = MagicMock()
|
||||
tool_call_response.content = ""
|
||||
tool_call_response.tool_calls = [
|
||||
{"id": "call_1", "name": "search", "args": {"q": "python"}}
|
||||
]
|
||||
|
||||
final_response = MagicMock()
|
||||
final_response.content = "Python is a programming language."
|
||||
final_response.tool_calls = []
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.ainvoke = AsyncMock(side_effect=[tool_call_response, final_response])
|
||||
mock_llm.bind_tools = MagicMock(return_value=mock_llm)
|
||||
mock_get_model.return_value = mock_llm
|
||||
|
||||
search_tool = _make_registered_tool("search", result_data="Python docs found")
|
||||
registry = ToolRegistry()
|
||||
registry.register(search_tool)
|
||||
|
||||
agent = Agent(tool_registry=registry)
|
||||
|
||||
events = []
|
||||
async for event in agent.run("Tell me about Python"):
|
||||
events.append(event)
|
||||
|
||||
assert len(events) >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.agents.core.agent.get_chat_model")
|
||||
async def test_max_iterations_limit(self, mock_get_model):
|
||||
tool_call_response = MagicMock()
|
||||
tool_call_response.content = ""
|
||||
tool_call_response.tool_calls = [
|
||||
{"id": "call_n", "name": "loop_tool", "args": {}}
|
||||
]
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.ainvoke = AsyncMock(return_value=tool_call_response)
|
||||
mock_llm.bind_tools = MagicMock(return_value=mock_llm)
|
||||
mock_get_model.return_value = mock_llm
|
||||
|
||||
loop_tool = _make_registered_tool("loop_tool", result_data="looping")
|
||||
registry = ToolRegistry()
|
||||
registry.register(loop_tool)
|
||||
|
||||
agent = Agent(tool_registry=registry, max_iterations=3)
|
||||
|
||||
events = []
|
||||
async for event in agent.run("infinite loop query"):
|
||||
events.append(event)
|
||||
|
||||
assert loop_tool.execute.call_count <= 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.agents.core.agent.get_chat_model")
|
||||
async def test_run_empty_query(self, mock_get_model):
|
||||
mock_llm = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = "I need more information."
|
||||
mock_response.tool_calls = []
|
||||
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
||||
mock_llm.bind_tools = MagicMock(return_value=mock_llm)
|
||||
mock_get_model.return_value = mock_llm
|
||||
|
||||
registry = ToolRegistry()
|
||||
agent = Agent(tool_registry=registry)
|
||||
|
||||
events = []
|
||||
async for event in agent.run(""):
|
||||
events.append(event)
|
||||
|
||||
assert len(events) >= 1
|
||||
|
||||
|
||||
class TestAgentEvent:
|
||||
def test_creation(self):
|
||||
event = AgentEvent(type="message", data={"text": "hello"})
|
||||
assert event.type == "message"
|
||||
assert event.data == {"text": "hello"}
|
||||
|
||||
def test_creation_with_string_data(self):
|
||||
event = AgentEvent(type="error", data="something failed")
|
||||
assert event.type == "error"
|
||||
assert event.data == "something failed"
|
||||
108
backend/tests/unit/test_llm.py
Normal file
108
backend/tests/unit/test_llm.py
Normal file
@ -0,0 +1,108 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from app.services.llm import (
|
||||
LLMProvider,
|
||||
ModelTier,
|
||||
detect_provider,
|
||||
get_chat_model,
|
||||
_resolve_model_name,
|
||||
_DEFAULTS,
|
||||
)
|
||||
|
||||
|
||||
class TestDetectProvider:
|
||||
def test_anthropic_preferred_when_both_keys_set(self):
|
||||
with patch("app.services.llm.get_settings") as mock:
|
||||
mock.return_value = MagicMock(
|
||||
anthropic_api_key="sk-ant-xxx", openai_api_key="sk-xxx"
|
||||
)
|
||||
assert detect_provider() == LLMProvider.ANTHROPIC
|
||||
|
||||
def test_anthropic_when_only_anthropic_key(self):
|
||||
with patch("app.services.llm.get_settings") as mock:
|
||||
mock.return_value = MagicMock(anthropic_api_key="sk-ant-xxx", openai_api_key="")
|
||||
assert detect_provider() == LLMProvider.ANTHROPIC
|
||||
|
||||
def test_openai_when_only_openai_key(self):
|
||||
with patch("app.services.llm.get_settings") as mock:
|
||||
mock.return_value = MagicMock(anthropic_api_key="", openai_api_key="sk-xxx")
|
||||
assert detect_provider() == LLMProvider.OPENAI
|
||||
|
||||
def test_raises_when_no_keys(self):
|
||||
with patch("app.services.llm.get_settings") as mock:
|
||||
mock.return_value = MagicMock(anthropic_api_key="", openai_api_key="")
|
||||
with pytest.raises(RuntimeError, match="No LLM API key configured"):
|
||||
detect_provider()
|
||||
|
||||
|
||||
class TestResolveModelName:
|
||||
def test_defaults_anthropic_fast(self):
|
||||
with patch("app.services.llm.get_settings") as mock:
|
||||
mock.return_value = MagicMock(llm_fast_model="", llm_strong_model="")
|
||||
name = _resolve_model_name(LLMProvider.ANTHROPIC, ModelTier.FAST)
|
||||
assert name == _DEFAULTS[LLMProvider.ANTHROPIC][ModelTier.FAST]
|
||||
|
||||
def test_defaults_openai_strong(self):
|
||||
with patch("app.services.llm.get_settings") as mock:
|
||||
mock.return_value = MagicMock(llm_fast_model="", llm_strong_model="")
|
||||
name = _resolve_model_name(LLMProvider.OPENAI, ModelTier.STRONG)
|
||||
assert name == _DEFAULTS[LLMProvider.OPENAI][ModelTier.STRONG]
|
||||
|
||||
def test_override_fast_model(self):
|
||||
with patch("app.services.llm.get_settings") as mock:
|
||||
mock.return_value = MagicMock(llm_fast_model="custom-fast", llm_strong_model="")
|
||||
name = _resolve_model_name(LLMProvider.ANTHROPIC, ModelTier.FAST)
|
||||
assert name == "custom-fast"
|
||||
|
||||
def test_override_strong_model(self):
|
||||
with patch("app.services.llm.get_settings") as mock:
|
||||
mock.return_value = MagicMock(llm_fast_model="", llm_strong_model="custom-strong")
|
||||
name = _resolve_model_name(LLMProvider.OPENAI, ModelTier.STRONG)
|
||||
assert name == "custom-strong"
|
||||
|
||||
|
||||
class TestGetChatModel:
|
||||
@patch("app.services.llm.get_settings")
|
||||
def test_returns_anthropic_model(self, mock_settings):
|
||||
mock_settings.return_value = MagicMock(
|
||||
anthropic_api_key="sk-ant-xxx",
|
||||
openai_api_key="",
|
||||
llm_fast_model="",
|
||||
llm_strong_model="",
|
||||
)
|
||||
model = get_chat_model(tier=ModelTier.STRONG)
|
||||
assert type(model).__name__ == "ChatAnthropic"
|
||||
|
||||
@patch("app.services.llm.get_settings")
|
||||
def test_returns_openai_model(self, mock_settings):
|
||||
mock_settings.return_value = MagicMock(
|
||||
anthropic_api_key="",
|
||||
openai_api_key="sk-xxx",
|
||||
llm_fast_model="",
|
||||
llm_strong_model="",
|
||||
)
|
||||
model = get_chat_model(tier=ModelTier.FAST)
|
||||
assert type(model).__name__ == "ChatOpenAI"
|
||||
|
||||
@patch("app.services.llm.get_settings")
|
||||
def test_explicit_provider_override(self, mock_settings):
|
||||
mock_settings.return_value = MagicMock(
|
||||
anthropic_api_key="sk-ant-xxx",
|
||||
openai_api_key="sk-xxx",
|
||||
llm_fast_model="",
|
||||
llm_strong_model="",
|
||||
)
|
||||
model = get_chat_model(tier=ModelTier.FAST, provider=LLMProvider.OPENAI)
|
||||
assert type(model).__name__ == "ChatOpenAI"
|
||||
|
||||
@patch("app.services.llm.get_settings")
|
||||
def test_custom_temperature(self, mock_settings):
|
||||
mock_settings.return_value = MagicMock(
|
||||
anthropic_api_key="sk-ant-xxx",
|
||||
openai_api_key="",
|
||||
llm_fast_model="",
|
||||
llm_strong_model="",
|
||||
)
|
||||
model = get_chat_model(tier=ModelTier.STRONG, temperature=0.7)
|
||||
assert model.temperature == 0.7
|
||||
227
backend/tests/unit/test_meta_tool.py
Normal file
227
backend/tests/unit/test_meta_tool.py
Normal file
@ -0,0 +1,227 @@
|
||||
"""Tests for agent tool types, registry, and meta-tool."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
from dataclasses import dataclass
|
||||
|
||||
from app.agents.tools.types import ToolResult, RegisteredTool
|
||||
from app.agents.tools.registry import ToolRegistry
|
||||
from app.agents.tools.meta_tool import MetaTool, SubTool
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ToolResult
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestToolResult:
|
||||
def test_to_str_with_string_data(self):
|
||||
result = ToolResult(data="hello world")
|
||||
assert result.to_str() == "hello world"
|
||||
|
||||
def test_to_str_with_dict_data(self):
|
||||
result = ToolResult(data={"key": "value", "count": 42})
|
||||
text = result.to_str()
|
||||
assert "key" in text
|
||||
assert "value" in text
|
||||
|
||||
def test_to_str_with_list_data(self):
|
||||
result = ToolResult(data=["a", "b", "c"])
|
||||
text = result.to_str()
|
||||
assert "a" in text
|
||||
assert "b" in text
|
||||
|
||||
def test_source_urls_default_empty(self):
|
||||
result = ToolResult(data="test")
|
||||
assert result.source_urls == []
|
||||
|
||||
def test_source_urls_preserved(self):
|
||||
urls = ["https://example.com/a", "https://example.com/b"]
|
||||
result = ToolResult(data="test", source_urls=urls)
|
||||
assert result.source_urls == urls
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RegisteredTool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRegisteredTool:
|
||||
def test_creation_and_attributes(self):
|
||||
handler = AsyncMock(return_value=ToolResult(data="ok"))
|
||||
tool = RegisteredTool(
|
||||
name="test_tool",
|
||||
description="A test tool for unit tests",
|
||||
compact_description="test tool",
|
||||
concurrency_safe=True,
|
||||
execute=handler,
|
||||
)
|
||||
assert tool.name == "test_tool"
|
||||
assert tool.description == "A test tool for unit tests"
|
||||
assert tool.compact_description == "test tool"
|
||||
assert tool.concurrency_safe is True
|
||||
assert tool.execute is handler
|
||||
|
||||
def test_concurrency_unsafe(self):
|
||||
handler = AsyncMock(return_value=ToolResult(data="ok"))
|
||||
tool = RegisteredTool(
|
||||
name="serial_tool",
|
||||
description="desc",
|
||||
compact_description="serial",
|
||||
concurrency_safe=False,
|
||||
execute=handler,
|
||||
)
|
||||
assert tool.concurrency_safe is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ToolRegistry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_tool(name: str, concurrency_safe: bool = True) -> RegisteredTool:
|
||||
return RegisteredTool(
|
||||
name=name,
|
||||
description=f"Description of {name}",
|
||||
compact_description=f"compact {name}",
|
||||
concurrency_safe=concurrency_safe,
|
||||
execute=AsyncMock(return_value=ToolResult(data=f"{name} result")),
|
||||
)
|
||||
|
||||
|
||||
class TestToolRegistry:
|
||||
def test_register_and_get(self):
|
||||
registry = ToolRegistry()
|
||||
tool = _make_tool("alpha")
|
||||
registry.register(tool)
|
||||
assert registry.get("alpha") is tool
|
||||
|
||||
def test_get_unknown_returns_none(self):
|
||||
registry = ToolRegistry()
|
||||
assert registry.get("nonexistent") is None
|
||||
|
||||
def test_list_tools(self):
|
||||
registry = ToolRegistry()
|
||||
registry.register(_make_tool("a"))
|
||||
registry.register(_make_tool("b"))
|
||||
names = [t.name for t in registry.list_tools()]
|
||||
assert "a" in names
|
||||
assert "b" in names
|
||||
|
||||
def test_get_concurrency_map(self):
|
||||
registry = ToolRegistry()
|
||||
registry.register(_make_tool("safe_tool", concurrency_safe=True))
|
||||
registry.register(_make_tool("unsafe_tool", concurrency_safe=False))
|
||||
cmap = registry.get_concurrency_map()
|
||||
assert cmap["safe_tool"] is True
|
||||
assert cmap["unsafe_tool"] is False
|
||||
|
||||
def test_build_compact_descriptions(self):
|
||||
registry = ToolRegistry()
|
||||
registry.register(_make_tool("tool_x"))
|
||||
registry.register(_make_tool("tool_y"))
|
||||
text = registry.build_compact_descriptions()
|
||||
assert "tool_x" in text
|
||||
assert "tool_y" in text
|
||||
assert "compact tool_x" in text
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MetaTool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ConcreteMetaTool(MetaTool):
|
||||
"""Concrete subclass for testing abstract MetaTool."""
|
||||
|
||||
def __init__(self):
|
||||
self._sub_tools = [
|
||||
SubTool(
|
||||
name="search",
|
||||
description="Search the web for information",
|
||||
handler=AsyncMock(
|
||||
return_value=ToolResult(data="search result for query")
|
||||
),
|
||||
),
|
||||
SubTool(
|
||||
name="calculate",
|
||||
description="Perform mathematical calculations",
|
||||
handler=AsyncMock(
|
||||
return_value=ToolResult(data="42")
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def sub_tools(self) -> list[SubTool]:
|
||||
return self._sub_tools
|
||||
|
||||
|
||||
class TestSubTool:
|
||||
def test_creation(self):
|
||||
handler = AsyncMock()
|
||||
st = SubTool(name="test", description="desc", handler=handler)
|
||||
assert st.name == "test"
|
||||
assert st.description == "desc"
|
||||
assert st.handler is handler
|
||||
|
||||
|
||||
class TestMetaTool:
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.agents.tools.meta_tool.get_chat_model")
|
||||
async def test_route_selects_correct_subtool(self, mock_get_model):
|
||||
mock_llm = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = '{"tool": "search"}'
|
||||
mock_llm.invoke = MagicMock(return_value=mock_response)
|
||||
mock_get_model.return_value = mock_llm
|
||||
|
||||
meta = ConcreteMetaTool()
|
||||
selected = await meta.route("find information about Python")
|
||||
assert selected.name == "search"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.agents.tools.meta_tool.get_chat_model")
|
||||
async def test_route_selects_calculate(self, mock_get_model):
|
||||
mock_llm = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = '{"tool": "calculate"}'
|
||||
mock_llm.invoke = MagicMock(return_value=mock_response)
|
||||
mock_get_model.return_value = mock_llm
|
||||
|
||||
meta = ConcreteMetaTool()
|
||||
selected = await meta.route("what is 6 times 7")
|
||||
assert selected.name == "calculate"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.agents.tools.meta_tool.get_chat_model")
|
||||
async def test_execute_calls_routed_subtool(self, mock_get_model):
|
||||
mock_llm = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = '{"tool": "search"}'
|
||||
mock_llm.invoke = MagicMock(return_value=mock_response)
|
||||
mock_get_model.return_value = mock_llm
|
||||
|
||||
meta = ConcreteMetaTool()
|
||||
result = await meta.execute({"query": "find Python docs"})
|
||||
|
||||
search_handler = meta.sub_tools[0].handler
|
||||
search_handler.assert_called_once()
|
||||
assert result.data == "search result for query"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.agents.tools.meta_tool.get_chat_model")
|
||||
async def test_execute_passes_params_to_subtool(self, mock_get_model):
|
||||
mock_llm = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = '{"tool": "calculate"}'
|
||||
mock_llm.invoke = MagicMock(return_value=mock_response)
|
||||
mock_get_model.return_value = mock_llm
|
||||
|
||||
meta = ConcreteMetaTool()
|
||||
params = {"query": "compute 2+2", "extra": "data"}
|
||||
result = await meta.execute(params)
|
||||
|
||||
calc_handler = meta.sub_tools[1].handler
|
||||
calc_handler.assert_called_once_with(params)
|
||||
assert result.data == "42"
|
||||
209
backend/tests/unit/test_skills.py
Normal file
209
backend/tests/unit/test_skills.py
Normal file
@ -0,0 +1,209 @@
|
||||
"""Tests for agent skill types, loader, registry, and tool."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.skills.types import SkillMetadata, Skill, SkillSource
|
||||
from app.agents.skills.loader import SkillLoader
|
||||
from app.agents.skills.registry import SkillRegistry
|
||||
from app.agents.skills.tool import create_skill_tool
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SkillMetadata & Skill
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSkillMetadata:
|
||||
def test_creation(self) -> None:
|
||||
meta = SkillMetadata(
|
||||
name="test-skill",
|
||||
description="A test skill",
|
||||
path="/skills/test.md",
|
||||
source="builtin",
|
||||
)
|
||||
assert meta.name == "test-skill"
|
||||
assert meta.description == "A test skill"
|
||||
assert meta.path == "/skills/test.md"
|
||||
assert meta.source == "builtin"
|
||||
|
||||
|
||||
class TestSkill:
|
||||
def test_creation_with_instructions(self) -> None:
|
||||
skill = Skill(
|
||||
name="s1",
|
||||
description="desc",
|
||||
path="/skills/s1.md",
|
||||
source="builtin",
|
||||
instructions="Do this and that.",
|
||||
)
|
||||
assert skill.name == "s1"
|
||||
assert skill.instructions == "Do this and that."
|
||||
assert skill.source == "builtin"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SkillLoader
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_VALID_SKILL_CONTENT = """\
|
||||
---
|
||||
name: test-skill
|
||||
description: A skill for testing
|
||||
---
|
||||
|
||||
## Instructions
|
||||
|
||||
Follow these steps to perform the test skill.
|
||||
|
||||
1. Step one
|
||||
2. Step two
|
||||
"""
|
||||
|
||||
_NO_FRONTMATTER_CONTENT = """\
|
||||
## Instructions
|
||||
|
||||
This file has no YAML frontmatter.
|
||||
"""
|
||||
|
||||
_EMPTY_FRONTMATTER_CONTENT = """\
|
||||
---
|
||||
---
|
||||
|
||||
## Instructions
|
||||
|
||||
Empty frontmatter.
|
||||
"""
|
||||
|
||||
|
||||
class TestSkillLoader:
|
||||
def test_parse_valid(self) -> None:
|
||||
skill = SkillLoader.parse_skill_file(
|
||||
content=_VALID_SKILL_CONTENT,
|
||||
path="/skills/test-skill.md",
|
||||
source="builtin",
|
||||
)
|
||||
assert skill.name == "test-skill"
|
||||
assert skill.description == "A skill for testing"
|
||||
assert "Step one" in skill.instructions
|
||||
assert skill.source == "builtin"
|
||||
|
||||
def test_parse_no_frontmatter(self) -> None:
|
||||
"""frontmatter가 없으면 name/description이 빈 문자열."""
|
||||
skill = SkillLoader.parse_skill_file(
|
||||
content=_NO_FRONTMATTER_CONTENT,
|
||||
path="/skills/no-front.md",
|
||||
source="builtin",
|
||||
)
|
||||
assert skill.name == ""
|
||||
assert skill.description == ""
|
||||
assert "no YAML frontmatter" in skill.instructions
|
||||
|
||||
def test_parse_empty_frontmatter(self) -> None:
|
||||
"""빈 frontmatter도 정상 처리."""
|
||||
skill = SkillLoader.parse_skill_file(
|
||||
content=_EMPTY_FRONTMATTER_CONTENT,
|
||||
path="/skills/empty-front.md",
|
||||
source="builtin",
|
||||
)
|
||||
assert skill.name == ""
|
||||
assert skill.description == ""
|
||||
|
||||
def test_load_from_path(self, tmp_path: Path) -> None:
|
||||
skill_file = tmp_path / "SKILL.md"
|
||||
skill_file.write_text(_VALID_SKILL_CONTENT, encoding="utf-8")
|
||||
|
||||
skill = SkillLoader.load_from_path(skill_file, "builtin")
|
||||
assert skill.name == "test-skill"
|
||||
assert "Step one" in skill.instructions
|
||||
|
||||
def test_extract_metadata(self, tmp_path: Path) -> None:
|
||||
skill_file = tmp_path / "SKILL.md"
|
||||
skill_file.write_text(_VALID_SKILL_CONTENT, encoding="utf-8")
|
||||
|
||||
meta = SkillLoader.extract_metadata(skill_file, "builtin")
|
||||
assert isinstance(meta, SkillMetadata)
|
||||
assert meta.name == "test-skill"
|
||||
assert meta.description == "A skill for testing"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SkillRegistry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSkillRegistry:
|
||||
def setup_method(self) -> None:
|
||||
SkillRegistry.clear_cache()
|
||||
|
||||
def test_discover_finds_builtin_skills(self) -> None:
|
||||
"""실제 builtin 디렉토리에서 dcf-kr, kim-jong-bong-strategy 발견."""
|
||||
skills = SkillRegistry.discover()
|
||||
names = [s.name for s in skills]
|
||||
assert "dcf-kr" in names
|
||||
assert "kim-jong-bong-strategy" in names
|
||||
|
||||
def test_get_returns_skill(self) -> None:
|
||||
SkillRegistry.discover()
|
||||
skill = SkillRegistry.get("dcf-kr")
|
||||
assert skill is not None
|
||||
assert skill.name == "dcf-kr"
|
||||
assert "WACC" in skill.instructions
|
||||
|
||||
def test_get_returns_none_for_unknown(self) -> None:
|
||||
SkillRegistry.discover()
|
||||
assert SkillRegistry.get("nonexistent-skill") is None
|
||||
|
||||
def test_list_skills(self) -> None:
|
||||
SkillRegistry.discover()
|
||||
skills = SkillRegistry.list_skills()
|
||||
assert len(skills) >= 2
|
||||
|
||||
def test_build_skills_section(self) -> None:
|
||||
SkillRegistry.discover()
|
||||
section = SkillRegistry.build_skills_section()
|
||||
assert "dcf-kr" in section
|
||||
assert "kim-jong-bong-strategy" in section
|
||||
|
||||
def test_build_skills_section_empty(self) -> None:
|
||||
"""캐시 비어있으면 빈 문자열 반환."""
|
||||
section = SkillRegistry.build_skills_section()
|
||||
assert section == ""
|
||||
|
||||
def test_clear_cache(self) -> None:
|
||||
SkillRegistry.discover()
|
||||
assert len(SkillRegistry.list_skills()) >= 2
|
||||
SkillRegistry.clear_cache()
|
||||
assert len(SkillRegistry._cache) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# create_skill_tool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSkillTool:
|
||||
def setup_method(self) -> None:
|
||||
SkillRegistry.clear_cache()
|
||||
SkillRegistry.discover()
|
||||
|
||||
def test_tool_metadata(self) -> None:
|
||||
tool = create_skill_tool()
|
||||
assert tool.name == "use_skill"
|
||||
assert tool.concurrency_safe is True
|
||||
|
||||
def test_execute_valid_skill(self) -> None:
|
||||
tool = create_skill_tool()
|
||||
result = asyncio.run(tool.execute({"skill_name": "dcf-kr"}))
|
||||
assert "WACC" in result.data
|
||||
|
||||
def test_execute_unknown_skill(self) -> None:
|
||||
tool = create_skill_tool()
|
||||
result = asyncio.run(tool.execute({"skill_name": "no-such-skill"}))
|
||||
assert "찾을 수 없습니다" in result.data
|
||||
assert "dcf-kr" in result.data
|
||||
68
docs/plans/agents-architecture.md
Normal file
68
docs/plans/agents-architecture.md
Normal file
@ -0,0 +1,68 @@
|
||||
# Agent Architecture
|
||||
|
||||
## 개요
|
||||
Galaxis-Po 투자 분석 에이전트 시스템. 자연어 쿼리를 받아 도구를 활용하여 투자 분석을 수행합니다.
|
||||
|
||||
## 구조
|
||||
|
||||
```
|
||||
backend/app/agents/
|
||||
├── __init__.py
|
||||
├── tools/ # Meta-tool 패턴
|
||||
│ ├── types.py # ToolResult, RegisteredTool
|
||||
│ ├── registry.py # 도구 레지스트리
|
||||
│ ├── meta_tool.py # MetaTool 베이스 클래스
|
||||
│ ├── finance/ # 금융 데이터 도구
|
||||
│ │ ├── get_financials.py
|
||||
│ │ ├── get_market_data.py
|
||||
│ │ └── sub_tools.py
|
||||
│ ├── search/ # 검색 도구
|
||||
│ │ ├── web_search.py
|
||||
│ │ └── news_search.py
|
||||
│ └── filesystem/ # 파일 도구
|
||||
│ ├── read_file.py
|
||||
│ ├── write_file.py
|
||||
│ └── edit_file.py
|
||||
├── skills/ # SKILL.md 시스템
|
||||
│ ├── types.py # SkillMetadata, Skill
|
||||
│ ├── loader.py # SKILL.md 파서
|
||||
│ ├── registry.py # 스킬 디스커버리
|
||||
│ ├── tool.py # use_skill 도구
|
||||
│ └── builtin/ # 내장 스킬
|
||||
│ ├── dcf/SKILL.md
|
||||
│ └── kim-jong-bong-strategy/SKILL.md
|
||||
└── core/ # 에이전트 코어
|
||||
├── agent.py # 에이전트 루프
|
||||
├── compact.py # 컨텍스트 압축
|
||||
├── prompts.py # 시스템 프롬프트
|
||||
├── scratchpad.py # 실행 로그
|
||||
├── tool_executor.py # 동시 도구 실행
|
||||
└── rules.py # RULES.md 로더
|
||||
```
|
||||
|
||||
## 핵심 패턴
|
||||
|
||||
### Meta-tool
|
||||
NL 쿼리를 받아 LLM(fast tier)으로 적절한 sub-tool을 선택하여 실행합니다.
|
||||
|
||||
### SKILL.md
|
||||
YAML frontmatter + 마크다운 본문으로 구성된 전문 분석 워크플로우입니다.
|
||||
에이전트가 use_skill 도구로 스킬을 로드하면 해당 워크플로우를 따릅니다.
|
||||
|
||||
### 도구 동시성
|
||||
concurrency_safe 플래그로 읽기 전용 도구는 병렬 실행, 쓰기 도구는 직렬 실행합니다.
|
||||
|
||||
### 컨텍스트 압축
|
||||
토큰 수가 임계값을 초과하면 LLM으로 대화를 요약하여 컨텍스트를 압축합니다.
|
||||
|
||||
## API
|
||||
|
||||
### POST /api/agent/query
|
||||
동기 응답. 에이전트 실행 후 최종 결과 반환.
|
||||
|
||||
### POST /api/agent/stream
|
||||
SSE 스트리밍. 도구 실행 과정을 실시간으로 전달.
|
||||
|
||||
## LLM 연동
|
||||
app.services.llm.get_chat_model()을 통해 LangChain BaseChatModel 사용.
|
||||
ModelTier.FAST(라우팅/압축), ModelTier.STRONG(분석/추론).
|
||||
Loading…
x
Reference in New Issue
Block a user