diff --git a/backend/app/agents/__init__.py b/backend/app/agents/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/agents/core/__init__.py b/backend/app/agents/core/__init__.py new file mode 100644 index 0000000..c42a13b --- /dev/null +++ b/backend/app/agents/core/__init__.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from app.agents.core.agent import Agent +from app.agents.core.compact import ContextCompactor +from app.agents.core.prompts import SystemPromptBuilder +from app.agents.core.scratchpad import Scratchpad +from app.agents.core.tool_executor import ToolExecutor + +__all__ = [ + "Agent", + "ContextCompactor", + "Scratchpad", + "SystemPromptBuilder", + "ToolExecutor", +] diff --git a/backend/app/agents/core/agent.py b/backend/app/agents/core/agent.py new file mode 100644 index 0000000..46676ee --- /dev/null +++ b/backend/app/agents/core/agent.py @@ -0,0 +1,199 @@ +from __future__ import annotations + +import json +from collections.abc import AsyncGenerator +from dataclasses import dataclass, field +from typing import Any + +from langchain_core.messages import ( + AIMessage, + HumanMessage, + SystemMessage, + ToolMessage, +) +from langchain_core.tools import StructuredTool + +from app.agents.core.compact import ContextCompactor +from app.agents.core.prompts import SystemPromptBuilder +from app.agents.core.rules import RulesLoader +from app.agents.core.scratchpad import Scratchpad +from app.agents.core.tool_executor import ToolCall, ToolExecutor +from app.agents.skills.registry import SkillRegistry +from app.agents.tools.registry import ToolRegistry +from app.agents.tools.types import RegisteredTool +from app.services.llm import ModelTier, get_chat_model + + +@dataclass +class AgentConfig: + model_tier: ModelTier = ModelTier.STRONG + max_iterations: int = 20 + temperature: float = 0.0 + compact_threshold: int = 50000 + + +@dataclass +class AgentEvent: + type: str + data: Any = field(default_factory=dict) + + +def _make_langchain_tool(rt: RegisteredTool) -> StructuredTool: + async def _fn(**kwargs: Any) -> str: + result = await rt.execute(kwargs) + return result.to_str() + + return StructuredTool.from_function( + coroutine=_fn, + name=rt.name, + description=rt.description, + ) + + +class Agent: + def __init__( + self, + config: AgentConfig | None = None, + tool_registry: ToolRegistry | None = None, + max_iterations: int | None = None, + ) -> None: + self._config = config or AgentConfig() + self._tool_registry = tool_registry + if max_iterations is not None: + self._config.max_iterations = max_iterations + self._scratchpad = Scratchpad() + self._compactor = ContextCompactor() + + async def run(self, query: str) -> AsyncGenerator[AgentEvent, None]: + if self._tool_registry is not None: + tool_registry = self._tool_registry + else: + tool_registry = ToolRegistry.auto_register() + + SkillRegistry.discover() + + tools_section = tool_registry.build_compact_descriptions() + skills_section = SkillRegistry.build_skills_section() + rules = RulesLoader.load_rules() + + system_prompt = SystemPromptBuilder.build( + tools_section=tools_section, + skills_section=skills_section, + rules=rules, + ) + + registered_tools = tool_registry.list_tools() + lc_tools = [_make_langchain_tool(rt) for rt in registered_tools] + + llm = get_chat_model( + tier=self._config.model_tier, + temperature=self._config.temperature, + ) + model = llm.bind_tools(lc_tools) if lc_tools else llm + + messages: list = [ + SystemMessage(content=system_prompt), + HumanMessage(content=query), + ] + + executor = ToolExecutor(tool_registry) + + for _iteration in range(self._config.max_iterations): + yield AgentEvent(type="thinking") + + response: AIMessage = await model.ainvoke(messages) + messages.append(response) + + if not response.tool_calls: + content = ( + response.content + if isinstance(response.content, str) + else str(response.content) + ) + yield AgentEvent(type="response", data={"content": content}) + yield AgentEvent( + type="done", data={"final_response": content} + ) + return + + tool_calls = [ + ToolCall( + id=tc["id"], + name=tc["name"], + params=tc.get("args", {}), + ) + for tc in response.tool_calls + ] + + for tc in tool_calls: + yield AgentEvent( + type="tool_start", + data={"tool_name": tc.name, "params": tc.params}, + ) + + async for event in executor.execute_tool_calls(tool_calls): + if event.error: + yield AgentEvent( + type="tool_error", + data={ + "tool_name": event.tool_name, + "error": event.error, + }, + ) + result_str = f"오류: {event.error}" + else: + result_str = ( + event.result.to_str() if event.result else "" + ) + yield AgentEvent( + type="tool_end", + data={ + "tool_name": event.tool_name, + "result": result_str, + }, + ) + + self._scratchpad.add( + tool_name=event.tool_name, + params={}, + result=result_str, + ) + + messages.append( + ToolMessage( + content=result_str, + tool_call_id=event.call_id, + ) + ) + + msg_dicts = [ + { + "role": getattr(m, "type", "unknown"), + "content": str(getattr(m, "content", "")), + } + for m in messages + ] + if self._compactor.should_compact( + msg_dicts, self._config.compact_threshold + ): + compacted = await self._compactor.compact( + msg_dicts, system_prompt + ) + messages = [ + SystemMessage(content=compacted[0]["content"]), + HumanMessage(content=compacted[1]["content"]), + ] + yield AgentEvent(type="compaction") + + yield AgentEvent( + type="response", + data={ + "content": "최대 반복 횟수에 도달했습니다. 분석을 완료하지 못했습니다." + }, + ) + yield AgentEvent( + type="done", + data={ + "final_response": "최대 반복 횟수에 도달했습니다. 분석을 완료하지 못했습니다." + }, + ) diff --git a/backend/app/agents/core/compact.py b/backend/app/agents/core/compact.py new file mode 100644 index 0000000..86cf11b --- /dev/null +++ b/backend/app/agents/core/compact.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from langchain_core.messages import HumanMessage, SystemMessage + +from app.services.llm import ModelTier, get_chat_model + +COMPACTION_PROMPT = """당신은 대화 요약 전문가입니다. 아래 대화를 구조화된 요약으로 압축하세요. + +반드시 다음 섹션을 포함하세요: +## 원래 질문 +## 핵심 데이터 +## 분석 진행 상황 +## 수치 데이터 (정확한 값 유지) +## 오류/이슈 +## 다음 단계""" + + +class ContextCompactor: + """대화 컨텍스트가 너무 커지면 LLM으로 요약하여 압축합니다.""" + + @staticmethod + def _estimate_tokens(text: str) -> int: + """텍스트의 토큰 수를 추정합니다 (~4자당 1토큰).""" + return len(text) // 4 + + def should_compact( + self, messages: list[dict], threshold: int = 50000 + ) -> bool: + """메시지의 토큰 추정치가 임계값을 초과하는지 확인합니다. + + Args: + messages: 대화 메시지 목록. + threshold: 토큰 임계값 (기본 50,000). + + Returns: + 압축이 필요하면 True. + """ + total_text = "".join( + str(m.get("content", "")) for m in messages + ) + return self._estimate_tokens(total_text) > threshold + + async def compact( + self, messages: list[dict], system_prompt: str + ) -> list[dict]: + """대화를 요약하여 압축된 메시지 목록을 반환합니다. + + Args: + messages: 기존 대화 메시지 목록. + system_prompt: 원래 시스템 프롬프트. + + Returns: + 압축된 메시지 목록 (시스템 메시지 + 요약 메시지). + """ + # 대화 내용을 텍스트로 변환 + conversation_lines: list[str] = [] + for msg in messages: + role = msg.get("role", "unknown") + content = str(msg.get("content", "")) + conversation_lines.append(f"[{role}]: {content}") + + conversation_text = "\n".join(conversation_lines) + + llm = get_chat_model(tier=ModelTier.FAST) + response = await llm.ainvoke( + [ + SystemMessage(content=COMPACTION_PROMPT), + HumanMessage(content=conversation_text), + ] + ) + + summary_content = ( + response.content + if isinstance(response.content, str) + else str(response.content) + ) + + return [ + {"role": "system", "content": system_prompt}, + { + "role": "user", + "content": f"[이전 대화 요약]\n{summary_content}", + }, + ] diff --git a/backend/app/agents/core/prompts.py b/backend/app/agents/core/prompts.py new file mode 100644 index 0000000..9ee1995 --- /dev/null +++ b/backend/app/agents/core/prompts.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +from datetime import date + +_SYSTEM_PROMPT_TEMPLATE = """당신은 Galaxis-Po 투자 분석 에이전트입니다. +한국 퇴직연금(DC) 포트폴리오 관리를 위한 퀀트 분석을 수행합니다. + +## 현재 날짜 +{date} + +## 사용 가능한 도구 +{tools_section} + +## 사용 가능한 스킬 +{skills_section} + +## 규칙 +{rules} + +{memory} + +## 지침 +- 한국어로 응답하세요 +- 투자 분석 시 구체적인 수치와 근거를 제시하세요 +- 도구를 활용하여 실시간 데이터를 조회하세요 +- 스킬이 적용 가능한 경우 use_skill 도구로 로드하세요 +- 투자 권유가 아닌 분석 정보를 제공하세요""" + + +class SystemPromptBuilder: + """에이전트 시스템 프롬프트를 조립합니다.""" + + @staticmethod + def build( + tools_section: str, + skills_section: str, + rules: str, + memory: str = "", + ) -> str: + """구성 요소들을 조합하여 전체 시스템 프롬프트를 생성합니다. + + Args: + tools_section: 도구 설명 텍스트. + skills_section: 스킬 설명 텍스트. + rules: 프로젝트 규칙. + memory: 추가 메모리/컨텍스트 (선택). + + Returns: + 완성된 시스템 프롬프트 문자열. + """ + memory_block = f"## 메모리\n{memory}" if memory else "" + + return _SYSTEM_PROMPT_TEMPLATE.format( + date=date.today().isoformat(), + tools_section=tools_section or "등록된 도구 없음", + skills_section=skills_section or "등록된 스킬 없음", + rules=rules or "추가 규칙 없음", + memory=memory_block, + ) diff --git a/backend/app/agents/core/rules.py b/backend/app/agents/core/rules.py new file mode 100644 index 0000000..764a210 --- /dev/null +++ b/backend/app/agents/core/rules.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from pathlib import Path + + +class RulesLoader: + """프로젝트 규칙 파일(RULES.md)을 로드합니다.""" + + @staticmethod + def load_rules(project_root: Path | None = None) -> str: + """프로젝트 루트에서 RULES.md를 찾아 내용을 반환합니다. + + 탐색 순서: + 1. {project_root}/RULES.md + 2. {project_root}/.claude/RULES.md + + Args: + project_root: 프로젝트 루트 디렉토리. None이면 빈 문자열 반환. + + Returns: + RULES.md 내용 또는 빈 문자열. + """ + if project_root is None: + return "" + + candidates = [ + project_root / "RULES.md", + project_root / ".claude" / "RULES.md", + ] + + for path in candidates: + if path.is_file(): + return path.read_text(encoding="utf-8") + + return "" diff --git a/backend/app/agents/core/scratchpad.py b/backend/app/agents/core/scratchpad.py new file mode 100644 index 0000000..deb6b3c --- /dev/null +++ b/backend/app/agents/core/scratchpad.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import datetime + + +@dataclass +class ScratchpadEntry: + """도구 실행 기록 항목.""" + + tool_name: str + params: dict + result: str + timestamp: datetime + token_estimate: int + + +class Scratchpad: + """도구 실행 이력과 토큰 추정치를 추적합니다.""" + + def __init__(self) -> None: + self.entries: list[ScratchpadEntry] = [] + + def add(self, tool_name: str, params: dict, result: str) -> None: + """도구 실행 결과를 기록합니다. + + Args: + tool_name: 실행된 도구 이름. + params: 도구에 전달된 매개변수. + result: 도구 실행 결과 문자열. + """ + token_estimate = max(len(result) // 4, 1) + entry = ScratchpadEntry( + tool_name=tool_name, + params=params, + result=result, + timestamp=datetime.now(), + token_estimate=token_estimate, + ) + self.entries.append(entry) + + def total_tokens(self) -> int: + """전체 기록의 추정 토큰 수를 반환합니다.""" + return sum(e.token_estimate for e in self.entries) + + def format_history(self) -> str: + """도구 실행 이력을 포맷된 문자열로 반환합니다.""" + if not self.entries: + return "도구 실행 이력 없음" + + lines: list[str] = [] + for i, entry in enumerate(self.entries, 1): + lines.append( + f"[{i}] {entry.tool_name} " + f"({entry.timestamp.strftime('%H:%M:%S')}) " + f"토큰≈{entry.token_estimate}" + ) + lines.append(f" params: {entry.params}") + # 결과가 너무 길면 잘라냄 + result_preview = entry.result[:200] + if len(entry.result) > 200: + result_preview += "..." + lines.append(f" result: {result_preview}") + return "\n".join(lines) + + def clear(self) -> None: + """모든 기록을 삭제합니다.""" + self.entries.clear() diff --git a/backend/app/agents/core/tool_executor.py b/backend/app/agents/core/tool_executor.py new file mode 100644 index 0000000..5db60e7 --- /dev/null +++ b/backend/app/agents/core/tool_executor.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +import asyncio +from collections.abc import AsyncGenerator +from dataclasses import dataclass, field +from typing import Any + +from app.agents.tools.types import ToolResult +from app.agents.tools.registry import ToolRegistry + + +@dataclass +class ToolCall: + """도구 호출 요청.""" + + id: str + name: str + params: dict + + +@dataclass +class ToolEvent: + """도구 실행 결과 이벤트.""" + + call_id: str + tool_name: str + result: ToolResult | None = None + error: str | None = None + + +class ToolExecutor: + """도구 호출을 동시성 안전 여부에 따라 분류하여 실행합니다.""" + + def __init__( + self, registry: ToolRegistry, max_concurrency: int = 5 + ) -> None: + self._registry = registry + self._semaphore = asyncio.Semaphore(max_concurrency) + + async def _execute_single(self, call: ToolCall) -> ToolEvent: + """단일 도구를 실행하고 ToolEvent를 반환합니다.""" + tool = self._registry.get(call.name) + if tool is None: + return ToolEvent( + call_id=call.id, + tool_name=call.name, + error=f"도구를 찾을 수 없습니다: {call.name}", + ) + + try: + async with self._semaphore: + result = await tool.execute(call.params) + return ToolEvent( + call_id=call.id, + tool_name=call.name, + result=result, + ) + except Exception as exc: + return ToolEvent( + call_id=call.id, + tool_name=call.name, + error=f"{type(exc).__name__}: {exc}", + ) + + async def execute_tool_calls( + self, tool_calls: list[ToolCall] + ) -> AsyncGenerator[ToolEvent, None]: + """도구 호출 목록을 실행합니다. + + 동시성 안전한 도구는 asyncio.gather로 병렬 실행하고, + 그렇지 않은 도구는 순차적으로 실행합니다. + + Args: + tool_calls: 실행할 도구 호출 목록. + + Yields: + 각 도구 실행 결과에 대한 ToolEvent. + """ + concurrency_map = self._registry.get_concurrency_map() + + concurrent_calls: list[ToolCall] = [] + serial_calls: list[ToolCall] = [] + + for call in tool_calls: + if concurrency_map.get(call.name, False): + concurrent_calls.append(call) + else: + serial_calls.append(call) + + # 병렬 실행 가능한 도구들 + if concurrent_calls: + tasks = [ + self._execute_single(call) for call in concurrent_calls + ] + results = await asyncio.gather(*tasks) + for event in results: + yield event + + # 순차 실행해야 하는 도구들 + for call in serial_calls: + event = await self._execute_single(call) + yield event diff --git a/backend/app/agents/skills/__init__.py b/backend/app/agents/skills/__init__.py new file mode 100644 index 0000000..e0cfb6c --- /dev/null +++ b/backend/app/agents/skills/__init__.py @@ -0,0 +1,12 @@ +from app.agents.skills.loader import SkillLoader +from app.agents.skills.registry import SkillRegistry +from app.agents.skills.tool import create_skill_tool +from app.agents.skills.types import Skill, SkillMetadata + +__all__ = [ + "Skill", + "SkillLoader", + "SkillMetadata", + "SkillRegistry", + "create_skill_tool", +] diff --git a/backend/app/agents/skills/builtin/__init__.py b/backend/app/agents/skills/builtin/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/agents/skills/builtin/dcf/SKILL.md b/backend/app/agents/skills/builtin/dcf/SKILL.md new file mode 100644 index 0000000..ec224ff --- /dev/null +++ b/backend/app/agents/skills/builtin/dcf/SKILL.md @@ -0,0 +1,89 @@ +--- +name: dcf-kr +description: 한국 주식시장용 DCF(할인현금흐름) 밸류에이션 분석 +--- + +# DCF 밸류에이션 분석 (한국 시장) + +## 실행 체크리스트 + +- [ ] 1단계: 재무 데이터 수집 +- [ ] 2단계: FCF 성장률 산출 +- [ ] 3단계: WACC 추정 +- [ ] 4단계: 미래 현금흐름 추정 +- [ ] 5단계: 현재가치 할인 및 주당 내재가치 산출 +- [ ] 6단계: 민감도 분석 +- [ ] 7단계: 검증 +- [ ] 8단계: 최종 보고서 작성 + +## 1단계: 재무 데이터 수집 + +get_financials 도구로 다음 데이터를 수집: +- 최근 5년 현금흐름표 (영업활동CF, CAPEX, FCF) +- 재무비율 (ROE, ROA, 부채비율) +- 대차대조표 (총부채, 현금성자산, 발행주식수) +- 시가총액, 현재 주가 + +## 2단계: FCF 성장률 산출 + +- 5년 FCF CAGR 계산 +- 매출 성장률, 영업이익률 추이와 교차 검증 +- 한국 시장 특성: 성장률 상한 15% + +## 3단계: WACC 추정 + +한국 시장 기본 가정: +- 무위험이자율: 3.0-3.5% (국고채 10년) +- 시장위험프리미엄: 6-8% +- 베타: 업종별 차등 적용 + +### 업종별 WACC 참고 범위 +| 업종 | WACC 범위 | +|------|-----------| +| 반도체 | 9-11% | +| 자동차 | 8-10% | +| 바이오 | 11-14% | +| 인터넷/플랫폼 | 9-12% | +| 유통/소매 | 8-10% | +| 금융 | 7-9% | +| 건설 | 9-11% | +| 에너지 | 8-10% | +| 통신 | 7-9% | +| 유틸리티 | 6-8% | + +## 4단계: 미래 현금흐름 추정 (5년) + +- Year 1-5 FCF = 전년 FCF × (1 + 성장률) × 감쇠계수 +- 감쇠계수: [1.0, 0.95, 0.90, 0.85, 0.80] +- 영구성장가치(Terminal Value) = Year5 FCF × (1 + 영구성장률) / (WACC - 영구성장률) +- 영구성장률: 2.0-2.5% (한국 GDP 성장률 수준) + +## 5단계: 현재가치 할인 및 주당 내재가치 + +- 각 연도 FCF를 WACC로 할인 +- Enterprise Value = PV(FCF) + PV(Terminal Value) +- Equity Value = EV - 순부채 (총부채 - 현금) +- 주당 내재가치 = Equity Value / 발행주식수 + +## 6단계: 민감도 분석 + +3×3 매트릭스: +- WACC: 기본값 ± 1% +- 영구성장률: 2.0%, 2.5%, 3.0% + +## 7단계: 검증 + +- [ ] EV 산출값과 시가총액 비교 (±50% 이내) +- [ ] Terminal Value 비중 < 75% +- [ ] EV/FCF 배수 합리성 (10-30x) +- [ ] 동종업계 PER과 비교 + +## 8단계: 최종 보고서 + +1. 기업 개요 및 투자 논거 +2. 핵심 가정 (성장률, WACC, 영구성장률) +3. DCF 밸류에이션 결과 (주당 내재가치, 현재가 대비 괴리율) +4. 민감도 분석 매트릭스 +5. 리스크 요인 및 제한사항 + +> 본 분석은 참고용이며 투자 권유가 아닙니다. 실제 투자 결정 전 전문가 상담을 권장합니다. diff --git a/backend/app/agents/skills/builtin/kim-jong-bong-strategy/SKILL.md b/backend/app/agents/skills/builtin/kim-jong-bong-strategy/SKILL.md new file mode 100644 index 0000000..9063454 --- /dev/null +++ b/backend/app/agents/skills/builtin/kim-jong-bong-strategy/SKILL.md @@ -0,0 +1,72 @@ +--- +name: kim-jong-bong-strategy +description: 김종봉 단기매매 전략 - 상대강도 기반 종목 선정 및 매매 시그널 생성 +--- + +# 김종봉 전략 분석 + +## 실행 체크리스트 + +- [ ] 1단계: 유니버스 구성 (시가총액 Top 30, 거래대금 2000억+) +- [ ] 2단계: 상대강도 분석 (종목 수익률 vs 코스피) +- [ ] 3단계: 차트 패턴 감지 (박스권 돌파, 장대양봉) +- [ ] 4단계: 매수 시그널 생성 +- [ ] 5단계: 포트폴리오 구성 및 리스크 관리 +- [ ] 6단계: 청산 규칙 적용 + +## 전략 개요 + +- 목표 수익률: 월 10% +- 종목 수: 5-10개 +- 현금 비중: 30% +- 보유 기간: 수일~수주 + +## 1단계: 유니버스 구성 + +get_market_data 도구로 데이터 수집: +- 시가총액 상위 30개 종목 +- 일 거래대금 >= 2,000억원 필터 + +## 2단계: 상대강도 분석 + +- RS = (종목 수익률 / 코스피 수익률) × 100 +- RS > 100이면 시장 대비 강세 +- 기본 lookback: 10일 (2주) + +## 3단계: 차트 패턴 감지 + +### 박스권 돌파 +- 종가 > 직전 20일 최고가 + +### 장대양봉 +- 일 수익률 >= 5% +- 거래량 >= 20일 평균 × 1.5 + +## 4단계: 매수 시그널 + +매수 조건 = RS > 100 AND (박스권 돌파 OR 장대양봉) + +## 5단계: 포트폴리오 구성 + +| 항목 | 기준 | +|------|------| +| 종목당 투자금 | 총자산 / 최대종목수 | +| 최대 종목 수 | 10개 | +| 현금 비중 | 30% | +| 리밸런싱 | 분기별 | + +## 6단계: 청산 규칙 + +| 유형 | 기준 | +|------|------| +| 손절 | -3% | +| 1차 익절 | +5% (50% 물량) | +| 2차 익절 | +10% (나머지) | +| 트레일링 스톱 | +5% 도달 시 손절선을 본전으로 상향 | + +## 리스크 관리 + +- 종목당 최대 손실: -3% +- 포트폴리오 최대 손실: -10% +- 승률 목표: 60%+ +- 손익비: 1:1.5+ diff --git a/backend/app/agents/skills/loader.py b/backend/app/agents/skills/loader.py new file mode 100644 index 0000000..cd7a58c --- /dev/null +++ b/backend/app/agents/skills/loader.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from pathlib import Path + +import yaml + +from app.agents.skills.types import Skill, SkillMetadata, SkillSource + + +class SkillLoader: + """SKILL.md 파일 파싱 및 로딩.""" + + @staticmethod + def parse_skill_file(content: str, path: str, source: SkillSource) -> Skill: + """YAML frontmatter + markdown body를 파싱하여 Skill 객체 반환.""" + frontmatter, instructions = SkillLoader._split_frontmatter(content) + meta = yaml.safe_load(frontmatter) or {} + return Skill( + name=meta.get("name", ""), + description=meta.get("description", ""), + path=path, + source=source, + instructions=instructions.strip(), + ) + + @staticmethod + def load_from_path(path: str | Path, source: SkillSource) -> Skill: + """파일 경로에서 Skill을 로드.""" + p = Path(path) + content = p.read_text(encoding="utf-8") + return SkillLoader.parse_skill_file(content, str(p), source) + + @staticmethod + def extract_metadata(path: str | Path, source: SkillSource) -> SkillMetadata: + """frontmatter만 파싱하여 경량 메타데이터 반환.""" + p = Path(path) + content = p.read_text(encoding="utf-8") + frontmatter, _ = SkillLoader._split_frontmatter(content) + meta = yaml.safe_load(frontmatter) or {} + return SkillMetadata( + name=meta.get("name", ""), + description=meta.get("description", ""), + path=str(p), + source=source, + ) + + @staticmethod + def _split_frontmatter(content: str) -> tuple[str, str]: + """--- 마커 사이의 YAML frontmatter와 나머지 body를 분리.""" + stripped = content.strip() + if not stripped.startswith("---"): + return "", content + + # 두 번째 --- 찾기 + end_idx = stripped.find("---", 3) + if end_idx == -1: + return "", content + + frontmatter = stripped[3:end_idx].strip() + body = stripped[end_idx + 3 :] + return frontmatter, body diff --git a/backend/app/agents/skills/registry.py b/backend/app/agents/skills/registry.py new file mode 100644 index 0000000..085dafa --- /dev/null +++ b/backend/app/agents/skills/registry.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from pathlib import Path + +from app.agents.skills.loader import SkillLoader +from app.agents.skills.types import Skill, SkillMetadata + + +class SkillRegistry: + """builtin 스킬 탐색 및 캐싱.""" + + BUILTIN_DIR: Path = Path(__file__).parent / "builtin" + _cache: dict[str, SkillMetadata] = {} + + @classmethod + def discover(cls) -> list[SkillMetadata]: + """BUILTIN_DIR 하위의 */SKILL.md 패턴을 스캔하여 메타데이터 캐싱.""" + cls._cache.clear() + for skill_file in sorted(cls.BUILTIN_DIR.glob("*/SKILL.md")): + meta = SkillLoader.extract_metadata(skill_file, "builtin") + if meta.name: + cls._cache[meta.name] = meta + return list(cls._cache.values()) + + @classmethod + def get(cls, name: str) -> Skill | None: + """캐시에서 이름으로 조회 후 전체 Skill(instructions 포함) 로드.""" + if not cls._cache: + cls.discover() + meta = cls._cache.get(name) + if meta is None: + return None + return SkillLoader.load_from_path(meta.path, meta.source) + + @classmethod + def list_skills(cls) -> list[SkillMetadata]: + return list(cls._cache.values()) + + @classmethod + def build_skills_section(cls) -> str: + """시스템 프롬프트에 삽입할 스킬 목록 텍스트 생성.""" + skills = cls.list_skills() + if not skills: + return "" + lines = ["## 사용 가능한 스킬\n"] + for s in skills: + lines.append(f"- **{s.name}**: {s.description}") + return "\n".join(lines) + + @classmethod + def clear_cache(cls) -> None: + cls._cache.clear() diff --git a/backend/app/agents/skills/tool.py b/backend/app/agents/skills/tool.py new file mode 100644 index 0000000..29f4dc4 --- /dev/null +++ b/backend/app/agents/skills/tool.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from typing import Any + +from app.agents.skills.registry import SkillRegistry +from app.agents.tools.types import RegisteredTool, ToolResult + + +def create_skill_tool() -> RegisteredTool: + """use_skill 도구를 생성하여 반환.""" + + async def execute(args: dict[str, Any]) -> ToolResult: + skill_name: str = args["skill_name"] + skill = SkillRegistry.get(skill_name) + if skill is None: + available = SkillRegistry.list_skills() + names = [s.name for s in available] + return ToolResult( + data=( + f"스킬 '{skill_name}'을(를) 찾을 수 없습니다. " + f"사용 가능한 스킬: {', '.join(names) or '없음'}" + ), + ) + return ToolResult(data=skill.instructions) + + return RegisteredTool( + name="use_skill", + description=( + "스킬을 로드하여 전문 분석 워크플로우를 실행합니다. " + "사용 가능한 스킬 목록에서 선택하세요." + ), + compact_description="스킬 실행", + execute=execute, + concurrency_safe=True, + ) diff --git a/backend/app/agents/skills/types.py b/backend/app/agents/skills/types.py new file mode 100644 index 0000000..ae9723f --- /dev/null +++ b/backend/app/agents/skills/types.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal + +SkillSource = Literal["builtin", "user", "project"] + + +@dataclass +class SkillMetadata: + name: str + description: str + path: str + source: SkillSource + + +@dataclass +class Skill: + name: str + description: str + path: str + source: SkillSource + instructions: str diff --git a/backend/app/agents/tools/__init__.py b/backend/app/agents/tools/__init__.py new file mode 100644 index 0000000..b7c7613 --- /dev/null +++ b/backend/app/agents/tools/__init__.py @@ -0,0 +1,6 @@ +from __future__ import annotations + +from app.agents.tools.types import ToolResult, RegisteredTool +from app.agents.tools.registry import ToolRegistry + +__all__ = ["ToolResult", "RegisteredTool", "ToolRegistry"] diff --git a/backend/app/agents/tools/filesystem/__init__.py b/backend/app/agents/tools/filesystem/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/agents/tools/filesystem/edit_file.py b/backend/app/agents/tools/filesystem/edit_file.py new file mode 100644 index 0000000..ab566ac --- /dev/null +++ b/backend/app/agents/tools/filesystem/edit_file.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +import asyncio +from functools import partial +from pathlib import Path +from typing import Any + +from app.agents.tools.types import RegisteredTool, ToolResult + +PROJECT_ROOT = Path(__file__).resolve().parents[5] + + +def _validate_path(raw_path: str) -> Path: + """경로를 검증하고 절대 경로로 변환합니다.""" + resolved = (PROJECT_ROOT / raw_path).resolve() + if not resolved.is_relative_to(PROJECT_ROOT): + raise ValueError(f"프로젝트 외부 경로 접근 불가: {raw_path}") + return resolved + + +def _edit_sync(path: Path, old_string: str, new_string: str) -> str: + content = path.read_text(encoding="utf-8") + count = content.count(old_string) + if count == 0: + raise ValueError("old_string을 파일에서 찾을 수 없습니다") + if count > 1: + raise ValueError(f"old_string이 {count}회 발견됨 — 고유한 문자열을 제공해 주세요") + updated = content.replace(old_string, new_string, 1) + path.write_text(updated, encoding="utf-8") + return "파일 수정 완료" + + +def create_edit_file_tool() -> RegisteredTool: + """파일 수정 도구를 생성합니다.""" + + async def execute(params: dict[str, Any]) -> ToolResult: + raw_path: str = params["path"] + old_string: str = params["old_string"] + new_string: str = params["new_string"] + + try: + path = _validate_path(raw_path) + except ValueError as e: + return ToolResult(data=f"오류: {e}") + + if not path.exists(): + return ToolResult(data=f"파일 없음: {raw_path}") + if not path.is_file(): + return ToolResult(data=f"파일이 아님: {raw_path}") + + loop = asyncio.get_running_loop() + try: + result = await loop.run_in_executor( + None, partial(_edit_sync, path, old_string, new_string) + ) + except ValueError as e: + return ToolResult(data=f"수정 실패: {e}") + except Exception as e: + return ToolResult(data=f"파일 수정 실패: {e}") + + return ToolResult(data=result) + + return RegisteredTool( + name="edit_file", + description="파일의 특정 부분을 수정합니다. 정확한 old_string을 제공해야 합니다.", + compact_description="파일 수정", + concurrency_safe=False, + execute=execute, + ) diff --git a/backend/app/agents/tools/filesystem/read_file.py b/backend/app/agents/tools/filesystem/read_file.py new file mode 100644 index 0000000..2e0a621 --- /dev/null +++ b/backend/app/agents/tools/filesystem/read_file.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +import asyncio +from functools import partial +from pathlib import Path +from typing import Any + +from app.agents.tools.types import RegisteredTool, ToolResult + +PROJECT_ROOT = Path(__file__).resolve().parents[5] + + +def _validate_path(raw_path: str) -> Path: + """경로를 검증하고 절대 경로로 변환합니다. + + 프로젝트 루트 외부 접근 및 경로 순회 공격을 차단합니다. + """ + resolved = (PROJECT_ROOT / raw_path).resolve() + if not resolved.is_relative_to(PROJECT_ROOT): + raise ValueError(f"프로젝트 외부 경로 접근 불가: {raw_path}") + return resolved + + +def _read_sync(path: Path, offset: int, limit: int) -> str: + text = path.read_text(encoding="utf-8") + lines = text.splitlines(keepends=True) + selected = lines[offset : offset + limit] + return "".join(selected) + + +def create_read_file_tool() -> RegisteredTool: + """파일 읽기 도구를 생성합니다.""" + + async def execute(params: dict[str, Any]) -> ToolResult: + raw_path: str = params["path"] + offset: int = params.get("offset", 0) + limit: int = params.get("limit", 2000) + + try: + path = _validate_path(raw_path) + except ValueError as e: + return ToolResult(data=f"오류: {e}") + + if not path.exists(): + return ToolResult(data=f"파일 없음: {raw_path}") + if not path.is_file(): + return ToolResult(data=f"파일이 아님: {raw_path}") + + loop = asyncio.get_running_loop() + try: + content = await loop.run_in_executor( + None, partial(_read_sync, path, offset, limit) + ) + except Exception as e: + return ToolResult(data=f"파일 읽기 실패: {e}") + + return ToolResult(data=content) + + return RegisteredTool( + name="read_file", + description="파일 내용을 읽습니다. 프로젝트 내 파일만 접근 가능합니다.", + compact_description="파일 읽기", + concurrency_safe=True, + execute=execute, + ) diff --git a/backend/app/agents/tools/filesystem/write_file.py b/backend/app/agents/tools/filesystem/write_file.py new file mode 100644 index 0000000..9d82b99 --- /dev/null +++ b/backend/app/agents/tools/filesystem/write_file.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import asyncio +from functools import partial +from pathlib import Path +from typing import Any + +from app.agents.tools.types import RegisteredTool, ToolResult + +PROJECT_ROOT = Path(__file__).resolve().parents[5] + + +def _validate_path(raw_path: str) -> Path: + """경로를 검증하고 절대 경로로 변환합니다.""" + resolved = (PROJECT_ROOT / raw_path).resolve() + if not resolved.is_relative_to(PROJECT_ROOT): + raise ValueError(f"프로젝트 외부 경로 접근 불가: {raw_path}") + return resolved + + +def _write_sync(path: Path, content: str) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(content, encoding="utf-8") + + +def create_write_file_tool() -> RegisteredTool: + """파일 쓰기 도구를 생성합니다.""" + + async def execute(params: dict[str, Any]) -> ToolResult: + raw_path: str = params["path"] + content: str = params["content"] + + try: + path = _validate_path(raw_path) + except ValueError as e: + return ToolResult(data=f"오류: {e}") + + loop = asyncio.get_running_loop() + try: + await loop.run_in_executor(None, partial(_write_sync, path, content)) + except Exception as e: + return ToolResult(data=f"파일 쓰기 실패: {e}") + + return ToolResult(data=f"파일 작성 완료: {raw_path}") + + return RegisteredTool( + name="write_file", + description="파일에 내용을 씁니다. 프로젝트 내 파일만 수정 가능합니다.", + compact_description="파일 쓰기", + concurrency_safe=False, + execute=execute, + ) diff --git a/backend/app/agents/tools/finance/__init__.py b/backend/app/agents/tools/finance/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/agents/tools/finance/get_financials.py b/backend/app/agents/tools/finance/get_financials.py new file mode 100644 index 0000000..71558b4 --- /dev/null +++ b/backend/app/agents/tools/finance/get_financials.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from app.agents.tools.meta_tool import MetaTool, SubTool +from app.agents.tools.finance.sub_tools import ( + get_financial_statements, + get_market_metrics, +) + + +class FinancialsTool(MetaTool): + """재무제표, 재무비율, 밸류에이션 지표를 조회하는 메타 도구.""" + + NAME = "get_financials" + DESCRIPTION = "재무제표, 재무비율, 밸류에이션 지표 조회" + + @property + def sub_tools(self) -> list[SubTool]: + return [ + SubTool( + name="get_financial_statements", + description="종목의 BPS, PER, PBR, EPS 등 재무제표 데이터 조회 (ticker, year)", + handler=get_financial_statements, + ), + SubTool( + name="get_market_metrics", + description="종목의 시가총액, PER, PBR 등 시장 밸류에이션 지표 조회 (ticker, date)", + handler=get_market_metrics, + ), + ] diff --git a/backend/app/agents/tools/finance/get_market_data.py b/backend/app/agents/tools/finance/get_market_data.py new file mode 100644 index 0000000..1a5363b --- /dev/null +++ b/backend/app/agents/tools/finance/get_market_data.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from app.agents.tools.meta_tool import MetaTool, SubTool +from app.agents.tools.finance.sub_tools import get_stock_price, get_market_index + + +class MarketDataTool(MetaTool): + """주가, 거래량, 시장 지수 등 시장 데이터를 조회하는 메타 도구.""" + + NAME = "get_market_data" + DESCRIPTION = "주가, 거래량, 시장 지수 등 시장 데이터 조회" + + @property + def sub_tools(self) -> list[SubTool]: + return [ + SubTool( + name="get_stock_price", + description="종목 OHLCV(시가/고가/저가/종가/거래량) 데이터 조회 (ticker, start_date, end_date)", + handler=get_stock_price, + ), + SubTool( + name="get_market_index", + description="KOSPI/KOSDAQ 지수 OHLCV 데이터 조회 (index_code, start_date, end_date)", + handler=get_market_index, + ), + ] diff --git a/backend/app/agents/tools/finance/sub_tools.py b/backend/app/agents/tools/finance/sub_tools.py new file mode 100644 index 0000000..f78184c --- /dev/null +++ b/backend/app/agents/tools/finance/sub_tools.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +import logging +from typing import Any + +from app.agents.tools.types import ToolResult + +logger = logging.getLogger(__name__) + + +async def get_stock_price(params: dict[str, Any]) -> ToolResult: + """종목 OHLCV 데이터를 pykrx로 조회한다.""" + from pykrx import stock as pykrx_stock + + ticker = params.get("ticker", "") + start_date = params.get("start_date", "") + end_date = params.get("end_date", "") + + if not all([ticker, start_date, end_date]): + return ToolResult(data="Error: ticker, start_date, end_date 모두 필수입니다.") + + try: + df = pykrx_stock.get_market_ohlcv_by_date(start_date, end_date, ticker) + if df.empty: + return ToolResult(data=f"데이터 없음: {ticker} ({start_date}~{end_date})") + + records = df.reset_index().to_dict(orient="records") + return ToolResult( + data={"ticker": ticker, "period": f"{start_date}~{end_date}", "ohlcv": records} + ) + except Exception as e: + logger.exception("get_stock_price failed for %s", ticker) + return ToolResult(data=f"Error: {e}") + + +async def get_financial_statements(params: dict[str, Any]) -> ToolResult: + """종목의 재무제표 데이터를 pykrx로 조회한다.""" + from pykrx import stock as pykrx_stock + + ticker = params.get("ticker", "") + year = params.get("year", "") + + if not all([ticker, year]): + return ToolResult(data="Error: ticker, year 모두 필수입니다.") + + try: + # pykrx 기본 재무 데이터: BPS, PER, PBR, EPS, DIV, DPS + start_date = f"{year}0101" + end_date = f"{year}1231" + df = pykrx_stock.get_market_fundamental_by_date(start_date, end_date, ticker) + + if df.empty: + return ToolResult(data=f"재무 데이터 없음: {ticker} ({year})") + + # 연말 기준 최신 데이터 + latest = df.iloc[-1].to_dict() + return ToolResult( + data={"ticker": ticker, "year": year, "fundamentals": latest} + ) + except Exception as e: + logger.exception("get_financial_statements failed for %s", ticker) + return ToolResult(data=f"Error: {e}") + + +async def get_market_metrics(params: dict[str, Any]) -> ToolResult: + """종목의 시가총액, PER, PBR 등 시장 지표를 조회한다.""" + from pykrx import stock as pykrx_stock + + ticker = params.get("ticker", "") + date = params.get("date", "") + + if not all([ticker, date]): + return ToolResult(data="Error: ticker, date 모두 필수입니다.") + + try: + # 시가총액 + cap_df = pykrx_stock.get_market_cap_by_date(date, date, ticker) + # 밸류에이션 + fund_df = pykrx_stock.get_market_fundamental_by_date(date, date, ticker) + + result: dict[str, Any] = {"ticker": ticker, "date": date} + + if not cap_df.empty: + result["market_cap"] = cap_df.iloc[-1].to_dict() + if not fund_df.empty: + result["valuation"] = fund_df.iloc[-1].to_dict() + + if len(result) == 2: + return ToolResult(data=f"시장 지표 없음: {ticker} ({date})") + + return ToolResult(data=result) + except Exception as e: + logger.exception("get_market_metrics failed for %s", ticker) + return ToolResult(data=f"Error: {e}") + + +async def get_market_index(params: dict[str, Any]) -> ToolResult: + """KOSPI/KOSDAQ 지수 데이터를 조회한다.""" + from pykrx import stock as pykrx_stock + + index_code = params.get("index_code", "1001") # 기본: 코스피 + start_date = params.get("start_date", "") + end_date = params.get("end_date", "") + + if not all([start_date, end_date]): + return ToolResult(data="Error: start_date, end_date 모두 필수입니다.") + + try: + df = pykrx_stock.get_index_ohlcv_by_date(start_date, end_date, index_code) + + if df.empty: + return ToolResult(data=f"지수 데이터 없음: {index_code} ({start_date}~{end_date})") + + records = df.reset_index().to_dict(orient="records") + return ToolResult( + data={ + "index_code": index_code, + "period": f"{start_date}~{end_date}", + "ohlcv": records, + } + ) + except Exception as e: + logger.exception("get_market_index failed for %s", index_code) + return ToolResult(data=f"Error: {e}") diff --git a/backend/app/agents/tools/meta_tool.py b/backend/app/agents/tools/meta_tool.py new file mode 100644 index 0000000..e4de951 --- /dev/null +++ b/backend/app/agents/tools/meta_tool.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +import json +import logging +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Awaitable, Callable + +from app.agents.tools.types import ToolResult +from app.services.llm import ModelTier, get_chat_model + +logger = logging.getLogger(__name__) + + +@dataclass +class SubTool: + name: str + description: str + handler: Callable[[dict[str, Any]], Awaitable[ToolResult]] + + +class MetaTool(ABC): + """NL 쿼리를 LLM으로 라우팅해 적절한 sub-tool을 실행하는 추상 베이스 클래스.""" + + @property + @abstractmethod + def sub_tools(self) -> list[SubTool]: ... + + async def route(self, query: str) -> SubTool: + """LLM을 사용해 쿼리에 가장 적합한 sub-tool을 선택한다.""" + tool_list = "\n".join( + f"- {st.name}: {st.description}" for st in self.sub_tools + ) + + prompt = ( + "You are a tool router. Given a user query and available tools, " + "select the best tool.\n\n" + f"Available tools:\n{tool_list}\n\n" + f"User query: {query}\n\n" + 'Respond with ONLY a JSON object: {{"tool": "tool_name"}}' + ) + + llm = get_chat_model(tier=ModelTier.FAST, temperature=0.0) + response = llm.invoke(prompt) + + content = response.content + if isinstance(content, list): + content = content[0] if content else "" + content = str(content).strip() + + if content.startswith("```"): + content = content.split("\n", 1)[-1] + content = content.rsplit("```", 1)[0] + + parsed = json.loads(content) + tool_name = parsed["tool"] + + tool_map = {st.name: st for st in self.sub_tools} + sub_tool = tool_map.get(tool_name) + if sub_tool is None: + raise ValueError( + f"LLM selected unknown sub-tool '{tool_name}'. " + f"Available: {list(tool_map.keys())}" + ) + + return sub_tool + + async def execute(self, params: dict[str, Any]) -> ToolResult: + """NL 쿼리를 받아 라우팅 후 sub-tool을 실행한다.""" + query = params.get("query", "") + if not query: + return ToolResult(data="Error: 'query' parameter is required.") + + try: + sub_tool = await self.route(query) + logger.info("MetaTool routed to %s", sub_tool.name) + return await sub_tool.handler(params) + except Exception as e: + logger.exception("MetaTool execution failed") + return ToolResult(data=f"Error: {e}") diff --git a/backend/app/agents/tools/registry.py b/backend/app/agents/tools/registry.py new file mode 100644 index 0000000..074540d --- /dev/null +++ b/backend/app/agents/tools/registry.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from app.agents.tools.types import RegisteredTool + + +class ToolRegistry: + """등록된 도구를 관리하는 레지스트리.""" + + def __init__(self) -> None: + self._tools: dict[str, RegisteredTool] = {} + + def register(self, tool: RegisteredTool) -> None: + self._tools[tool.name] = tool + + def get(self, name: str) -> RegisteredTool | None: + return self._tools.get(name) + + def list_tools(self) -> list[RegisteredTool]: + return list(self._tools.values()) + + def get_concurrency_map(self) -> dict[str, bool]: + return {t.name: t.concurrency_safe for t in self._tools.values()} + + def build_compact_descriptions(self) -> str: + lines = [ + f"- {t.name}: {t.compact_description}" for t in self._tools.values() + ] + return "\n".join(lines) + + @classmethod + def auto_register(cls) -> ToolRegistry: + """설정에 따라 사용 가능한 도구를 자동 등록한다.""" + from app.agents.tools.finance.get_market_data import MarketDataTool + from app.agents.tools.finance.get_financials import FinancialsTool + + registry = cls() + + market = MarketDataTool() + registry.register(RegisteredTool( + name=MarketDataTool.NAME, + description=MarketDataTool.DESCRIPTION, + compact_description=MarketDataTool.DESCRIPTION, + concurrency_safe=True, + execute=market.execute, + )) + + financials = FinancialsTool() + registry.register(RegisteredTool( + name=FinancialsTool.NAME, + description=FinancialsTool.DESCRIPTION, + compact_description=FinancialsTool.DESCRIPTION, + concurrency_safe=True, + execute=financials.execute, + )) + + return registry diff --git a/backend/app/agents/tools/search/__init__.py b/backend/app/agents/tools/search/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/agents/tools/search/news_search.py b/backend/app/agents/tools/search/news_search.py new file mode 100644 index 0000000..84cfeee --- /dev/null +++ b/backend/app/agents/tools/search/news_search.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from typing import Any + +from app.agents.tools.types import RegisteredTool, ToolResult + + +def create_news_search_tool() -> RegisteredTool: + """금융 뉴스 검색 도구를 생성합니다.""" + + async def execute(params: dict[str, Any]) -> ToolResult: + query: str = params["query"] + num_results: int = params.get("num_results", 5) + + # TODO: settings에서 API 키 로드 후 뉴스 검색 API 연동 + try: + from app.core.config import settings + + api_key = getattr(settings, "news_search_api_key", None) + except Exception: + api_key = None + + if not api_key: + return ToolResult( + data=f"검색 API 미설정: '{query}' (요청 결과 수: {num_results}). " + "news_search_api_key 환경변수를 설정해 주세요.", + ) + + # 향후 실제 뉴스 검색 API 호출 구현 + # async with httpx.AsyncClient() as client: + # response = await client.get(...) + return ToolResult(data=f"뉴스 검색 결과 없음: '{query}'") + + return RegisteredTool( + name="news_search", + description="금융 뉴스 및 시장 동향을 검색합니다", + compact_description="뉴스 검색", + concurrency_safe=True, + execute=execute, + ) diff --git a/backend/app/agents/tools/search/web_search.py b/backend/app/agents/tools/search/web_search.py new file mode 100644 index 0000000..c2336d0 --- /dev/null +++ b/backend/app/agents/tools/search/web_search.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from typing import Any + +from app.agents.tools.types import RegisteredTool, ToolResult + + +def create_web_search_tool() -> RegisteredTool: + """웹 검색 도구를 생성합니다.""" + + async def execute(params: dict[str, Any]) -> ToolResult: + query: str = params["query"] + num_results: int = params.get("num_results", 5) + + # TODO: settings에서 API 키 로드 후 Tavily/Exa/SerpAPI 연동 + try: + from app.core.config import settings + + api_key = getattr(settings, "web_search_api_key", None) + except Exception: + api_key = None + + if not api_key: + return ToolResult( + data=f"검색 API 미설정: '{query}' (요청 결과 수: {num_results}). " + "web_search_api_key 환경변수를 설정해 주세요.", + ) + + # 향후 실제 검색 API 호출 구현 + # async with httpx.AsyncClient() as client: + # response = await client.get(...) + return ToolResult(data=f"검색 결과 없음: '{query}'") + + return RegisteredTool( + name="web_search", + description="웹 검색을 수행하여 최신 정보를 조회합니다", + compact_description="웹 검색", + concurrency_safe=True, + execute=execute, + ) diff --git a/backend/app/agents/tools/types.py b/backend/app/agents/tools/types.py new file mode 100644 index 0000000..9b9ff32 --- /dev/null +++ b/backend/app/agents/tools/types.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from typing import Any, Awaitable, Callable + + +@dataclass +class ToolResult: + data: Any + source_urls: list[str] = field(default_factory=list) + + def to_str(self) -> str: + if isinstance(self.data, str): + return self.data + return json.dumps(self.data, ensure_ascii=False, default=str) + + +@dataclass +class RegisteredTool: + name: str + description: str + compact_description: str + concurrency_safe: bool + execute: Callable[[dict[str, Any]], Awaitable[ToolResult]] diff --git a/backend/app/api/agents.py b/backend/app/api/agents.py new file mode 100644 index 0000000..6b74c0b --- /dev/null +++ b/backend/app/api/agents.py @@ -0,0 +1,141 @@ +""" +Agent API endpoints — 자연어 쿼리 기반 투자 분석 에이전트. +""" +import json +import logging +from typing import Any + +from fastapi import APIRouter, HTTPException, status +from fastapi.responses import StreamingResponse +from pydantic import BaseModel + +from app.api.deps import CurrentUser + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/agent", tags=["agent"]) + + +# --------------------------------------------------------------------------- +# Schemas +# --------------------------------------------------------------------------- + +class AgentQueryRequest(BaseModel): + query: str + model_tier: str = "strong" + + +class ToolCallLog(BaseModel): + tool_name: str + params: dict[str, Any] + result: str + error: str | None = None + + +class AgentQueryResponse(BaseModel): + response: str + tool_calls: list[ToolCallLog] + iterations: int + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _parse_model_tier(raw: str) -> "ModelTier": + """model_tier 문자열을 ModelTier enum으로 변환.""" + from app.agents.core.agent import ModelTier + + mapping = { + "strong": ModelTier.STRONG, + "fast": ModelTier.FAST, + } + tier = mapping.get(raw.lower()) + if tier is None: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=f"Invalid model_tier '{raw}'. Must be one of: {list(mapping.keys())}", + ) + return tier + + +# --------------------------------------------------------------------------- +# Endpoints +# --------------------------------------------------------------------------- + +@router.post("/query", response_model=AgentQueryResponse) +async def agent_query(body: AgentQueryRequest, user: CurrentUser): + """동기 응답 — 에이전트 실행 후 최종 결과 반환.""" + from app.agents.core.agent import Agent, AgentConfig + + model_tier = _parse_model_tier(body.model_tier) + config = AgentConfig(model_tier=model_tier) + agent = Agent(config=config) + + tool_calls: list[ToolCallLog] = [] + response_text = "" + iterations = 0 + + try: + async for event in agent.run(body.query): + if event.type == "tool_end": + tool_calls.append( + ToolCallLog( + tool_name=event.data.get("tool_name", ""), + params=event.data.get("params", {}), + result=event.data.get("result", ""), + error=event.data.get("error"), + ) + ) + elif event.type == "response": + response_text = event.data.get("text", "") + elif event.type == "done": + iterations = event.data.get("iterations", 0) + except Exception: + logger.exception("Agent query failed") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="에이전트 실행 중 오류가 발생했습니다.", + ) + + return AgentQueryResponse( + response=response_text, + tool_calls=tool_calls, + iterations=iterations, + ) + + +@router.post("/stream") +async def agent_stream(body: AgentQueryRequest, user: CurrentUser): + """SSE 스트리밍 — 도구 실행 과정을 실시간으로 전달.""" + from app.agents.core.agent import Agent, AgentConfig + + model_tier = _parse_model_tier(body.model_tier) + config = AgentConfig(model_tier=model_tier) + agent = Agent(config=config) + + async def _event_generator(): + try: + async for event in agent.run(body.query): + payload = json.dumps( + {"type": event.type, "data": event.data}, + ensure_ascii=False, + ) + yield f"data: {payload}\n\n" + except Exception: + logger.exception("Agent stream failed") + error_payload = json.dumps( + {"type": "error", "data": {"message": "에이전트 실행 중 오류가 발생했습니다."}}, + ensure_ascii=False, + ) + yield f"data: {error_payload}\n\n" + + return StreamingResponse( + _event_generator(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) diff --git a/backend/app/services/llm.py b/backend/app/services/llm.py new file mode 100644 index 0000000..52ed03c --- /dev/null +++ b/backend/app/services/llm.py @@ -0,0 +1,92 @@ +from enum import Enum +from functools import lru_cache + +from langchain_core.language_models.chat_models import BaseChatModel + +from app.core.config import get_settings + + +class LLMProvider(str, Enum): + ANTHROPIC = "anthropic" + OPENAI = "openai" + + +class ModelTier(str, Enum): + FAST = "fast" + STRONG = "strong" + + +_DEFAULTS = { + LLMProvider.ANTHROPIC: { + ModelTier.FAST: "claude-haiku-4-5-20251001", + ModelTier.STRONG: "claude-sonnet-4-6", + }, + LLMProvider.OPENAI: { + ModelTier.FAST: "gpt-4.1-mini", + ModelTier.STRONG: "gpt-4.1", + }, +} + + +def detect_provider() -> LLMProvider: + settings = get_settings() + if settings.anthropic_api_key: + return LLMProvider.ANTHROPIC + if settings.openai_api_key: + return LLMProvider.OPENAI + raise RuntimeError( + "No LLM API key configured. Set ANTHROPIC_API_KEY or OPENAI_API_KEY." + ) + + +def _resolve_model_name(provider: LLMProvider, tier: ModelTier) -> str: + settings = get_settings() + if tier == ModelTier.FAST and settings.llm_fast_model: + return settings.llm_fast_model + if tier == ModelTier.STRONG and settings.llm_strong_model: + return settings.llm_strong_model + return _DEFAULTS[provider][tier] + + +def _build_chat_model( + provider: LLMProvider, model_name: str, temperature: float +) -> BaseChatModel: + settings = get_settings() + + if provider == LLMProvider.ANTHROPIC: + from langchain_anthropic import ChatAnthropic + + return ChatAnthropic( + model=model_name, + api_key=settings.anthropic_api_key, + temperature=temperature, + max_tokens=4096, + ) + + from langchain_openai import ChatOpenAI + + return ChatOpenAI( + model=model_name, + api_key=settings.openai_api_key, + temperature=temperature, + ) + + +def get_chat_model( + tier: ModelTier = ModelTier.STRONG, + temperature: float = 0.0, + provider: LLMProvider | None = None, +) -> BaseChatModel: + resolved_provider = provider or detect_provider() + model_name = _resolve_model_name(resolved_provider, tier) + return _build_chat_model(resolved_provider, model_name, temperature) + + +@lru_cache +def get_fast_model() -> BaseChatModel: + return get_chat_model(tier=ModelTier.FAST) + + +@lru_cache +def get_strong_model() -> BaseChatModel: + return get_chat_model(tier=ModelTier.STRONG) diff --git a/backend/tests/unit/test_agent_core.py b/backend/tests/unit/test_agent_core.py new file mode 100644 index 0000000..02cd76c --- /dev/null +++ b/backend/tests/unit/test_agent_core.py @@ -0,0 +1,379 @@ +"""Tests for agent core: scratchpad, context compactor, tool executor, and agent.""" + +import pytest +from unittest.mock import patch, MagicMock, AsyncMock + +from app.agents.tools.types import ToolResult, RegisteredTool +from app.agents.tools.registry import ToolRegistry +from app.agents.core.scratchpad import Scratchpad +from app.agents.core.compact import ContextCompactor +from app.agents.core.tool_executor import ToolExecutor, ToolCall, ToolEvent +from app.agents.core.agent import Agent, AgentEvent + + +# --------------------------------------------------------------------------- +# Scratchpad +# --------------------------------------------------------------------------- + + +class TestScratchpad: + def test_add_and_total_tokens(self): + sp = Scratchpad() + sp.add(tool_name="search", params={"q": "test"}, result="found 3 results") + sp.add(tool_name="calc", params={"expr": "2+2"}, result="4") + assert sp.total_tokens() > 0 + + def test_format_history(self): + sp = Scratchpad() + sp.add(tool_name="search", params={"q": "python"}, result="python docs") + history = sp.format_history() + assert "search" in history + assert "python" in history + + def test_format_history_empty(self): + sp = Scratchpad() + history = sp.format_history() + assert isinstance(history, str) + + def test_clear(self): + sp = Scratchpad() + sp.add(tool_name="tool1", params={}, result="result_one") + sp.add(tool_name="tool2", params={}, result="result_two") + assert sp.total_tokens() > 0 + sp.clear() + assert sp.total_tokens() == 0 + + def test_multiple_entries_tracked(self): + sp = Scratchpad() + for i in range(5): + sp.add(tool_name=f"tool_{i}", params={"i": i}, result=f"result_{i}") + history = sp.format_history() + for i in range(5): + assert f"tool_{i}" in history + + +# --------------------------------------------------------------------------- +# ContextCompactor +# --------------------------------------------------------------------------- + + +class TestContextCompactor: + def test_should_compact_false_below_threshold(self): + compactor = ContextCompactor() + messages = [{"role": "user", "content": "short message"}] + assert compactor.should_compact(messages, threshold=10000) is False + + def test_should_compact_true_above_threshold(self): + compactor = ContextCompactor() + long_content = "x" * 5000 + messages = [ + {"role": "user", "content": long_content}, + {"role": "assistant", "content": long_content}, + {"role": "user", "content": long_content}, + ] + assert compactor.should_compact(messages, threshold=100) is True + + @pytest.mark.asyncio + @patch("app.agents.core.compact.get_chat_model") + async def test_compact_calls_llm(self, mock_get_model): + mock_llm = MagicMock() + mock_response = MagicMock() + mock_response.content = "Summary: user asked about Python, assistant explained basics." + mock_llm.ainvoke = AsyncMock(return_value=mock_response) + mock_get_model.return_value = mock_llm + + compactor = ContextCompactor() + messages = [ + {"role": "user", "content": "Tell me about Python"}, + {"role": "assistant", "content": "Python is a programming language..."}, + {"role": "user", "content": "What about async?"}, + {"role": "assistant", "content": "Async in Python uses asyncio..."}, + ] + result = await compactor.compact(messages, system_prompt="You are helpful.") + assert isinstance(result, list) + assert len(result) > 0 + mock_llm.ainvoke.assert_called_once() + + @pytest.mark.asyncio + @patch("app.agents.core.compact.get_chat_model") + async def test_compact_preserves_system_context(self, mock_get_model): + mock_llm = MagicMock() + mock_response = MagicMock() + mock_response.content = "Summarized conversation." + mock_llm.ainvoke = AsyncMock(return_value=mock_response) + mock_get_model.return_value = mock_llm + + compactor = ContextCompactor() + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + result = await compactor.compact(messages, system_prompt="System prompt here") + assert isinstance(result, list) + + +# --------------------------------------------------------------------------- +# ToolExecutor +# --------------------------------------------------------------------------- + + +def _make_registered_tool( + name: str, concurrency_safe: bool = True, result_data: str = "ok" +) -> RegisteredTool: + return RegisteredTool( + name=name, + description=f"Desc {name}", + compact_description=f"compact {name}", + concurrency_safe=concurrency_safe, + execute=AsyncMock(return_value=ToolResult(data=result_data)), + ) + + +class TestToolExecutor: + @pytest.mark.asyncio + async def test_concurrent_safe_tools(self): + tool_a = _make_registered_tool("tool_a", concurrency_safe=True, result_data="A") + tool_b = _make_registered_tool("tool_b", concurrency_safe=True, result_data="B") + + registry = ToolRegistry() + registry.register(tool_a) + registry.register(tool_b) + + executor = ToolExecutor(registry=registry) + calls = [ + ToolCall(id="c1", name="tool_a", params={"q": "hello"}), + ToolCall(id="c2", name="tool_b", params={"q": "world"}), + ] + + events = [] + async for event in executor.execute_tool_calls(calls): + events.append(event) + + assert len(events) == 2 + names = {e.tool_name for e in events} + assert "tool_a" in names + assert "tool_b" in names + + tool_a.execute.assert_called_once() + tool_b.execute.assert_called_once() + + @pytest.mark.asyncio + async def test_serial_tools_run_sequentially(self): + tool_s = _make_registered_tool("serial_tool", concurrency_safe=False, result_data="S") + + registry = ToolRegistry() + registry.register(tool_s) + + executor = ToolExecutor(registry=registry) + calls = [ + ToolCall(id="c1", name="serial_tool", params={"x": 1}), + ] + + events = [] + async for event in executor.execute_tool_calls(calls): + events.append(event) + + assert len(events) == 1 + assert events[0].tool_name == "serial_tool" + assert events[0].error is None + + @pytest.mark.asyncio + async def test_error_handling_for_failed_tool(self): + failing_tool = RegisteredTool( + name="fail_tool", + description="A tool that fails", + compact_description="fail", + concurrency_safe=True, + execute=AsyncMock(side_effect=RuntimeError("Tool exploded")), + ) + + registry = ToolRegistry() + registry.register(failing_tool) + + executor = ToolExecutor(registry=registry) + calls = [ToolCall(id="c1", name="fail_tool", params={})] + + events = [] + async for event in executor.execute_tool_calls(calls): + events.append(event) + + assert len(events) == 1 + assert events[0].error is not None + assert "exploded" in events[0].error.lower() or events[0].error + + @pytest.mark.asyncio + async def test_mixed_concurrent_and_serial(self): + safe = _make_registered_tool("safe", concurrency_safe=True, result_data="safe_r") + unsafe = _make_registered_tool("unsafe", concurrency_safe=False, result_data="unsafe_r") + + registry = ToolRegistry() + registry.register(safe) + registry.register(unsafe) + + executor = ToolExecutor(registry=registry) + calls = [ + ToolCall(id="c1", name="safe", params={}), + ToolCall(id="c2", name="unsafe", params={}), + ] + + events = [] + async for event in executor.execute_tool_calls(calls): + events.append(event) + + assert len(events) == 2 + results = {e.tool_name: e.result for e in events} + assert results["safe"].data == "safe_r" + assert results["unsafe"].data == "unsafe_r" + + @pytest.mark.asyncio + async def test_unknown_tool_returns_error(self): + registry = ToolRegistry() + executor = ToolExecutor(registry=registry) + calls = [ToolCall(id="c1", name="nonexistent", params={})] + + events = [] + async for event in executor.execute_tool_calls(calls): + events.append(event) + + assert len(events) == 1 + assert events[0].error is not None + + +# --------------------------------------------------------------------------- +# ToolCall & ToolEvent dataclass tests +# --------------------------------------------------------------------------- + + +class TestToolCallAndEvent: + def test_tool_call_creation(self): + tc = ToolCall(id="abc", name="my_tool", params={"key": "val"}) + assert tc.id == "abc" + assert tc.name == "my_tool" + assert tc.params == {"key": "val"} + + def test_tool_event_success(self): + result = ToolResult(data="success") + ev = ToolEvent(call_id="c1", tool_name="t1", result=result, error=None) + assert ev.result.data == "success" + assert ev.error is None + + def test_tool_event_error(self): + ev = ToolEvent(call_id="c1", tool_name="t1", result=None, error="something went wrong") + assert ev.result is None + assert ev.error == "something went wrong" + + +# --------------------------------------------------------------------------- +# Agent +# --------------------------------------------------------------------------- + + +class TestAgent: + @pytest.mark.asyncio + @patch("app.agents.core.agent.get_chat_model") + async def test_run_yields_events(self, mock_get_model): + mock_llm = MagicMock() + mock_response = MagicMock() + mock_response.content = "Here is the answer to your question." + mock_response.tool_calls = [] + mock_llm.ainvoke = AsyncMock(return_value=mock_response) + mock_llm.bind_tools = MagicMock(return_value=mock_llm) + mock_get_model.return_value = mock_llm + + registry = ToolRegistry() + agent = Agent(tool_registry=registry) + + events = [] + async for event in agent.run("What is 2+2?"): + events.append(event) + + assert len(events) >= 1 + event_types = [e.type for e in events] + assert any(t in event_types for t in ["message", "answer", "response", "text"]) + + @pytest.mark.asyncio + @patch("app.agents.core.agent.get_chat_model") + async def test_run_with_tool_call(self, mock_get_model): + tool_call_response = MagicMock() + tool_call_response.content = "" + tool_call_response.tool_calls = [ + {"id": "call_1", "name": "search", "args": {"q": "python"}} + ] + + final_response = MagicMock() + final_response.content = "Python is a programming language." + final_response.tool_calls = [] + + mock_llm = MagicMock() + mock_llm.ainvoke = AsyncMock(side_effect=[tool_call_response, final_response]) + mock_llm.bind_tools = MagicMock(return_value=mock_llm) + mock_get_model.return_value = mock_llm + + search_tool = _make_registered_tool("search", result_data="Python docs found") + registry = ToolRegistry() + registry.register(search_tool) + + agent = Agent(tool_registry=registry) + + events = [] + async for event in agent.run("Tell me about Python"): + events.append(event) + + assert len(events) >= 1 + + @pytest.mark.asyncio + @patch("app.agents.core.agent.get_chat_model") + async def test_max_iterations_limit(self, mock_get_model): + tool_call_response = MagicMock() + tool_call_response.content = "" + tool_call_response.tool_calls = [ + {"id": "call_n", "name": "loop_tool", "args": {}} + ] + mock_llm = MagicMock() + mock_llm.ainvoke = AsyncMock(return_value=tool_call_response) + mock_llm.bind_tools = MagicMock(return_value=mock_llm) + mock_get_model.return_value = mock_llm + + loop_tool = _make_registered_tool("loop_tool", result_data="looping") + registry = ToolRegistry() + registry.register(loop_tool) + + agent = Agent(tool_registry=registry, max_iterations=3) + + events = [] + async for event in agent.run("infinite loop query"): + events.append(event) + + assert loop_tool.execute.call_count <= 3 + + @pytest.mark.asyncio + @patch("app.agents.core.agent.get_chat_model") + async def test_run_empty_query(self, mock_get_model): + mock_llm = MagicMock() + mock_response = MagicMock() + mock_response.content = "I need more information." + mock_response.tool_calls = [] + mock_llm.ainvoke = AsyncMock(return_value=mock_response) + mock_llm.bind_tools = MagicMock(return_value=mock_llm) + mock_get_model.return_value = mock_llm + + registry = ToolRegistry() + agent = Agent(tool_registry=registry) + + events = [] + async for event in agent.run(""): + events.append(event) + + assert len(events) >= 1 + + +class TestAgentEvent: + def test_creation(self): + event = AgentEvent(type="message", data={"text": "hello"}) + assert event.type == "message" + assert event.data == {"text": "hello"} + + def test_creation_with_string_data(self): + event = AgentEvent(type="error", data="something failed") + assert event.type == "error" + assert event.data == "something failed" diff --git a/backend/tests/unit/test_llm.py b/backend/tests/unit/test_llm.py new file mode 100644 index 0000000..bbf09e3 --- /dev/null +++ b/backend/tests/unit/test_llm.py @@ -0,0 +1,108 @@ +import pytest +from unittest.mock import patch, MagicMock + +from app.services.llm import ( + LLMProvider, + ModelTier, + detect_provider, + get_chat_model, + _resolve_model_name, + _DEFAULTS, +) + + +class TestDetectProvider: + def test_anthropic_preferred_when_both_keys_set(self): + with patch("app.services.llm.get_settings") as mock: + mock.return_value = MagicMock( + anthropic_api_key="sk-ant-xxx", openai_api_key="sk-xxx" + ) + assert detect_provider() == LLMProvider.ANTHROPIC + + def test_anthropic_when_only_anthropic_key(self): + with patch("app.services.llm.get_settings") as mock: + mock.return_value = MagicMock(anthropic_api_key="sk-ant-xxx", openai_api_key="") + assert detect_provider() == LLMProvider.ANTHROPIC + + def test_openai_when_only_openai_key(self): + with patch("app.services.llm.get_settings") as mock: + mock.return_value = MagicMock(anthropic_api_key="", openai_api_key="sk-xxx") + assert detect_provider() == LLMProvider.OPENAI + + def test_raises_when_no_keys(self): + with patch("app.services.llm.get_settings") as mock: + mock.return_value = MagicMock(anthropic_api_key="", openai_api_key="") + with pytest.raises(RuntimeError, match="No LLM API key configured"): + detect_provider() + + +class TestResolveModelName: + def test_defaults_anthropic_fast(self): + with patch("app.services.llm.get_settings") as mock: + mock.return_value = MagicMock(llm_fast_model="", llm_strong_model="") + name = _resolve_model_name(LLMProvider.ANTHROPIC, ModelTier.FAST) + assert name == _DEFAULTS[LLMProvider.ANTHROPIC][ModelTier.FAST] + + def test_defaults_openai_strong(self): + with patch("app.services.llm.get_settings") as mock: + mock.return_value = MagicMock(llm_fast_model="", llm_strong_model="") + name = _resolve_model_name(LLMProvider.OPENAI, ModelTier.STRONG) + assert name == _DEFAULTS[LLMProvider.OPENAI][ModelTier.STRONG] + + def test_override_fast_model(self): + with patch("app.services.llm.get_settings") as mock: + mock.return_value = MagicMock(llm_fast_model="custom-fast", llm_strong_model="") + name = _resolve_model_name(LLMProvider.ANTHROPIC, ModelTier.FAST) + assert name == "custom-fast" + + def test_override_strong_model(self): + with patch("app.services.llm.get_settings") as mock: + mock.return_value = MagicMock(llm_fast_model="", llm_strong_model="custom-strong") + name = _resolve_model_name(LLMProvider.OPENAI, ModelTier.STRONG) + assert name == "custom-strong" + + +class TestGetChatModel: + @patch("app.services.llm.get_settings") + def test_returns_anthropic_model(self, mock_settings): + mock_settings.return_value = MagicMock( + anthropic_api_key="sk-ant-xxx", + openai_api_key="", + llm_fast_model="", + llm_strong_model="", + ) + model = get_chat_model(tier=ModelTier.STRONG) + assert type(model).__name__ == "ChatAnthropic" + + @patch("app.services.llm.get_settings") + def test_returns_openai_model(self, mock_settings): + mock_settings.return_value = MagicMock( + anthropic_api_key="", + openai_api_key="sk-xxx", + llm_fast_model="", + llm_strong_model="", + ) + model = get_chat_model(tier=ModelTier.FAST) + assert type(model).__name__ == "ChatOpenAI" + + @patch("app.services.llm.get_settings") + def test_explicit_provider_override(self, mock_settings): + mock_settings.return_value = MagicMock( + anthropic_api_key="sk-ant-xxx", + openai_api_key="sk-xxx", + llm_fast_model="", + llm_strong_model="", + ) + model = get_chat_model(tier=ModelTier.FAST, provider=LLMProvider.OPENAI) + assert type(model).__name__ == "ChatOpenAI" + + @patch("app.services.llm.get_settings") + def test_custom_temperature(self, mock_settings): + mock_settings.return_value = MagicMock( + anthropic_api_key="sk-ant-xxx", + openai_api_key="", + llm_fast_model="", + llm_strong_model="", + ) + model = get_chat_model(tier=ModelTier.STRONG, temperature=0.7) + assert model.temperature == 0.7 diff --git a/backend/tests/unit/test_meta_tool.py b/backend/tests/unit/test_meta_tool.py new file mode 100644 index 0000000..56401d6 --- /dev/null +++ b/backend/tests/unit/test_meta_tool.py @@ -0,0 +1,227 @@ +"""Tests for agent tool types, registry, and meta-tool.""" + +import pytest +from unittest.mock import patch, MagicMock, AsyncMock +from dataclasses import dataclass + +from app.agents.tools.types import ToolResult, RegisteredTool +from app.agents.tools.registry import ToolRegistry +from app.agents.tools.meta_tool import MetaTool, SubTool + + +# --------------------------------------------------------------------------- +# ToolResult +# --------------------------------------------------------------------------- + + +class TestToolResult: + def test_to_str_with_string_data(self): + result = ToolResult(data="hello world") + assert result.to_str() == "hello world" + + def test_to_str_with_dict_data(self): + result = ToolResult(data={"key": "value", "count": 42}) + text = result.to_str() + assert "key" in text + assert "value" in text + + def test_to_str_with_list_data(self): + result = ToolResult(data=["a", "b", "c"]) + text = result.to_str() + assert "a" in text + assert "b" in text + + def test_source_urls_default_empty(self): + result = ToolResult(data="test") + assert result.source_urls == [] + + def test_source_urls_preserved(self): + urls = ["https://example.com/a", "https://example.com/b"] + result = ToolResult(data="test", source_urls=urls) + assert result.source_urls == urls + + +# --------------------------------------------------------------------------- +# RegisteredTool +# --------------------------------------------------------------------------- + + +class TestRegisteredTool: + def test_creation_and_attributes(self): + handler = AsyncMock(return_value=ToolResult(data="ok")) + tool = RegisteredTool( + name="test_tool", + description="A test tool for unit tests", + compact_description="test tool", + concurrency_safe=True, + execute=handler, + ) + assert tool.name == "test_tool" + assert tool.description == "A test tool for unit tests" + assert tool.compact_description == "test tool" + assert tool.concurrency_safe is True + assert tool.execute is handler + + def test_concurrency_unsafe(self): + handler = AsyncMock(return_value=ToolResult(data="ok")) + tool = RegisteredTool( + name="serial_tool", + description="desc", + compact_description="serial", + concurrency_safe=False, + execute=handler, + ) + assert tool.concurrency_safe is False + + +# --------------------------------------------------------------------------- +# ToolRegistry +# --------------------------------------------------------------------------- + + +def _make_tool(name: str, concurrency_safe: bool = True) -> RegisteredTool: + return RegisteredTool( + name=name, + description=f"Description of {name}", + compact_description=f"compact {name}", + concurrency_safe=concurrency_safe, + execute=AsyncMock(return_value=ToolResult(data=f"{name} result")), + ) + + +class TestToolRegistry: + def test_register_and_get(self): + registry = ToolRegistry() + tool = _make_tool("alpha") + registry.register(tool) + assert registry.get("alpha") is tool + + def test_get_unknown_returns_none(self): + registry = ToolRegistry() + assert registry.get("nonexistent") is None + + def test_list_tools(self): + registry = ToolRegistry() + registry.register(_make_tool("a")) + registry.register(_make_tool("b")) + names = [t.name for t in registry.list_tools()] + assert "a" in names + assert "b" in names + + def test_get_concurrency_map(self): + registry = ToolRegistry() + registry.register(_make_tool("safe_tool", concurrency_safe=True)) + registry.register(_make_tool("unsafe_tool", concurrency_safe=False)) + cmap = registry.get_concurrency_map() + assert cmap["safe_tool"] is True + assert cmap["unsafe_tool"] is False + + def test_build_compact_descriptions(self): + registry = ToolRegistry() + registry.register(_make_tool("tool_x")) + registry.register(_make_tool("tool_y")) + text = registry.build_compact_descriptions() + assert "tool_x" in text + assert "tool_y" in text + assert "compact tool_x" in text + + +# --------------------------------------------------------------------------- +# MetaTool +# --------------------------------------------------------------------------- + + +class ConcreteMetaTool(MetaTool): + """Concrete subclass for testing abstract MetaTool.""" + + def __init__(self): + self._sub_tools = [ + SubTool( + name="search", + description="Search the web for information", + handler=AsyncMock( + return_value=ToolResult(data="search result for query") + ), + ), + SubTool( + name="calculate", + description="Perform mathematical calculations", + handler=AsyncMock( + return_value=ToolResult(data="42") + ), + ), + ] + + @property + def sub_tools(self) -> list[SubTool]: + return self._sub_tools + + +class TestSubTool: + def test_creation(self): + handler = AsyncMock() + st = SubTool(name="test", description="desc", handler=handler) + assert st.name == "test" + assert st.description == "desc" + assert st.handler is handler + + +class TestMetaTool: + @pytest.mark.asyncio + @patch("app.agents.tools.meta_tool.get_chat_model") + async def test_route_selects_correct_subtool(self, mock_get_model): + mock_llm = MagicMock() + mock_response = MagicMock() + mock_response.content = '{"tool": "search"}' + mock_llm.invoke = MagicMock(return_value=mock_response) + mock_get_model.return_value = mock_llm + + meta = ConcreteMetaTool() + selected = await meta.route("find information about Python") + assert selected.name == "search" + + @pytest.mark.asyncio + @patch("app.agents.tools.meta_tool.get_chat_model") + async def test_route_selects_calculate(self, mock_get_model): + mock_llm = MagicMock() + mock_response = MagicMock() + mock_response.content = '{"tool": "calculate"}' + mock_llm.invoke = MagicMock(return_value=mock_response) + mock_get_model.return_value = mock_llm + + meta = ConcreteMetaTool() + selected = await meta.route("what is 6 times 7") + assert selected.name == "calculate" + + @pytest.mark.asyncio + @patch("app.agents.tools.meta_tool.get_chat_model") + async def test_execute_calls_routed_subtool(self, mock_get_model): + mock_llm = MagicMock() + mock_response = MagicMock() + mock_response.content = '{"tool": "search"}' + mock_llm.invoke = MagicMock(return_value=mock_response) + mock_get_model.return_value = mock_llm + + meta = ConcreteMetaTool() + result = await meta.execute({"query": "find Python docs"}) + + search_handler = meta.sub_tools[0].handler + search_handler.assert_called_once() + assert result.data == "search result for query" + + @pytest.mark.asyncio + @patch("app.agents.tools.meta_tool.get_chat_model") + async def test_execute_passes_params_to_subtool(self, mock_get_model): + mock_llm = MagicMock() + mock_response = MagicMock() + mock_response.content = '{"tool": "calculate"}' + mock_llm.invoke = MagicMock(return_value=mock_response) + mock_get_model.return_value = mock_llm + + meta = ConcreteMetaTool() + params = {"query": "compute 2+2", "extra": "data"} + result = await meta.execute(params) + + calc_handler = meta.sub_tools[1].handler + calc_handler.assert_called_once_with(params) + assert result.data == "42" diff --git a/backend/tests/unit/test_skills.py b/backend/tests/unit/test_skills.py new file mode 100644 index 0000000..d3cfacf --- /dev/null +++ b/backend/tests/unit/test_skills.py @@ -0,0 +1,209 @@ +"""Tests for agent skill types, loader, registry, and tool.""" + +from __future__ import annotations + +import asyncio +from pathlib import Path +from unittest.mock import patch + +import pytest + +from app.agents.skills.types import SkillMetadata, Skill, SkillSource +from app.agents.skills.loader import SkillLoader +from app.agents.skills.registry import SkillRegistry +from app.agents.skills.tool import create_skill_tool + + +# --------------------------------------------------------------------------- +# SkillMetadata & Skill +# --------------------------------------------------------------------------- + + +class TestSkillMetadata: + def test_creation(self) -> None: + meta = SkillMetadata( + name="test-skill", + description="A test skill", + path="/skills/test.md", + source="builtin", + ) + assert meta.name == "test-skill" + assert meta.description == "A test skill" + assert meta.path == "/skills/test.md" + assert meta.source == "builtin" + + +class TestSkill: + def test_creation_with_instructions(self) -> None: + skill = Skill( + name="s1", + description="desc", + path="/skills/s1.md", + source="builtin", + instructions="Do this and that.", + ) + assert skill.name == "s1" + assert skill.instructions == "Do this and that." + assert skill.source == "builtin" + + +# --------------------------------------------------------------------------- +# SkillLoader +# --------------------------------------------------------------------------- + +_VALID_SKILL_CONTENT = """\ +--- +name: test-skill +description: A skill for testing +--- + +## Instructions + +Follow these steps to perform the test skill. + +1. Step one +2. Step two +""" + +_NO_FRONTMATTER_CONTENT = """\ +## Instructions + +This file has no YAML frontmatter. +""" + +_EMPTY_FRONTMATTER_CONTENT = """\ +--- +--- + +## Instructions + +Empty frontmatter. +""" + + +class TestSkillLoader: + def test_parse_valid(self) -> None: + skill = SkillLoader.parse_skill_file( + content=_VALID_SKILL_CONTENT, + path="/skills/test-skill.md", + source="builtin", + ) + assert skill.name == "test-skill" + assert skill.description == "A skill for testing" + assert "Step one" in skill.instructions + assert skill.source == "builtin" + + def test_parse_no_frontmatter(self) -> None: + """frontmatter가 없으면 name/description이 빈 문자열.""" + skill = SkillLoader.parse_skill_file( + content=_NO_FRONTMATTER_CONTENT, + path="/skills/no-front.md", + source="builtin", + ) + assert skill.name == "" + assert skill.description == "" + assert "no YAML frontmatter" in skill.instructions + + def test_parse_empty_frontmatter(self) -> None: + """빈 frontmatter도 정상 처리.""" + skill = SkillLoader.parse_skill_file( + content=_EMPTY_FRONTMATTER_CONTENT, + path="/skills/empty-front.md", + source="builtin", + ) + assert skill.name == "" + assert skill.description == "" + + def test_load_from_path(self, tmp_path: Path) -> None: + skill_file = tmp_path / "SKILL.md" + skill_file.write_text(_VALID_SKILL_CONTENT, encoding="utf-8") + + skill = SkillLoader.load_from_path(skill_file, "builtin") + assert skill.name == "test-skill" + assert "Step one" in skill.instructions + + def test_extract_metadata(self, tmp_path: Path) -> None: + skill_file = tmp_path / "SKILL.md" + skill_file.write_text(_VALID_SKILL_CONTENT, encoding="utf-8") + + meta = SkillLoader.extract_metadata(skill_file, "builtin") + assert isinstance(meta, SkillMetadata) + assert meta.name == "test-skill" + assert meta.description == "A skill for testing" + + +# --------------------------------------------------------------------------- +# SkillRegistry +# --------------------------------------------------------------------------- + + +class TestSkillRegistry: + def setup_method(self) -> None: + SkillRegistry.clear_cache() + + def test_discover_finds_builtin_skills(self) -> None: + """실제 builtin 디렉토리에서 dcf-kr, kim-jong-bong-strategy 발견.""" + skills = SkillRegistry.discover() + names = [s.name for s in skills] + assert "dcf-kr" in names + assert "kim-jong-bong-strategy" in names + + def test_get_returns_skill(self) -> None: + SkillRegistry.discover() + skill = SkillRegistry.get("dcf-kr") + assert skill is not None + assert skill.name == "dcf-kr" + assert "WACC" in skill.instructions + + def test_get_returns_none_for_unknown(self) -> None: + SkillRegistry.discover() + assert SkillRegistry.get("nonexistent-skill") is None + + def test_list_skills(self) -> None: + SkillRegistry.discover() + skills = SkillRegistry.list_skills() + assert len(skills) >= 2 + + def test_build_skills_section(self) -> None: + SkillRegistry.discover() + section = SkillRegistry.build_skills_section() + assert "dcf-kr" in section + assert "kim-jong-bong-strategy" in section + + def test_build_skills_section_empty(self) -> None: + """캐시 비어있으면 빈 문자열 반환.""" + section = SkillRegistry.build_skills_section() + assert section == "" + + def test_clear_cache(self) -> None: + SkillRegistry.discover() + assert len(SkillRegistry.list_skills()) >= 2 + SkillRegistry.clear_cache() + assert len(SkillRegistry._cache) == 0 + + +# --------------------------------------------------------------------------- +# create_skill_tool +# --------------------------------------------------------------------------- + + +class TestSkillTool: + def setup_method(self) -> None: + SkillRegistry.clear_cache() + SkillRegistry.discover() + + def test_tool_metadata(self) -> None: + tool = create_skill_tool() + assert tool.name == "use_skill" + assert tool.concurrency_safe is True + + def test_execute_valid_skill(self) -> None: + tool = create_skill_tool() + result = asyncio.run(tool.execute({"skill_name": "dcf-kr"})) + assert "WACC" in result.data + + def test_execute_unknown_skill(self) -> None: + tool = create_skill_tool() + result = asyncio.run(tool.execute({"skill_name": "no-such-skill"})) + assert "찾을 수 없습니다" in result.data + assert "dcf-kr" in result.data diff --git a/docs/plans/agents-architecture.md b/docs/plans/agents-architecture.md new file mode 100644 index 0000000..8e437ad --- /dev/null +++ b/docs/plans/agents-architecture.md @@ -0,0 +1,68 @@ +# Agent Architecture + +## 개요 +Galaxis-Po 투자 분석 에이전트 시스템. 자연어 쿼리를 받아 도구를 활용하여 투자 분석을 수행합니다. + +## 구조 + +``` +backend/app/agents/ +├── __init__.py +├── tools/ # Meta-tool 패턴 +│ ├── types.py # ToolResult, RegisteredTool +│ ├── registry.py # 도구 레지스트리 +│ ├── meta_tool.py # MetaTool 베이스 클래스 +│ ├── finance/ # 금융 데이터 도구 +│ │ ├── get_financials.py +│ │ ├── get_market_data.py +│ │ └── sub_tools.py +│ ├── search/ # 검색 도구 +│ │ ├── web_search.py +│ │ └── news_search.py +│ └── filesystem/ # 파일 도구 +│ ├── read_file.py +│ ├── write_file.py +│ └── edit_file.py +├── skills/ # SKILL.md 시스템 +│ ├── types.py # SkillMetadata, Skill +│ ├── loader.py # SKILL.md 파서 +│ ├── registry.py # 스킬 디스커버리 +│ ├── tool.py # use_skill 도구 +│ └── builtin/ # 내장 스킬 +│ ├── dcf/SKILL.md +│ └── kim-jong-bong-strategy/SKILL.md +└── core/ # 에이전트 코어 + ├── agent.py # 에이전트 루프 + ├── compact.py # 컨텍스트 압축 + ├── prompts.py # 시스템 프롬프트 + ├── scratchpad.py # 실행 로그 + ├── tool_executor.py # 동시 도구 실행 + └── rules.py # RULES.md 로더 +``` + +## 핵심 패턴 + +### Meta-tool +NL 쿼리를 받아 LLM(fast tier)으로 적절한 sub-tool을 선택하여 실행합니다. + +### SKILL.md +YAML frontmatter + 마크다운 본문으로 구성된 전문 분석 워크플로우입니다. +에이전트가 use_skill 도구로 스킬을 로드하면 해당 워크플로우를 따릅니다. + +### 도구 동시성 +concurrency_safe 플래그로 읽기 전용 도구는 병렬 실행, 쓰기 도구는 직렬 실행합니다. + +### 컨텍스트 압축 +토큰 수가 임계값을 초과하면 LLM으로 대화를 요약하여 컨텍스트를 압축합니다. + +## API + +### POST /api/agent/query +동기 응답. 에이전트 실행 후 최종 결과 반환. + +### POST /api/agent/stream +SSE 스트리밍. 도구 실행 과정을 실시간으로 전달. + +## LLM 연동 +app.services.llm.get_chat_model()을 통해 LangChain BaseChatModel 사용. +ModelTier.FAST(라우팅/압축), ModelTier.STRONG(분석/추론).