"""
Agent API endpoints — natural-language-query-driven investment analysis agent.

Exposes two routes under ``/api/agent``:

* ``POST /query``  — runs the agent to completion and returns the final answer
  together with a log of every tool call.
* ``POST /stream`` — runs the agent and relays each event to the client as
  Server-Sent Events (SSE).
"""

import json
import logging
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING, Any

from fastapi import APIRouter, HTTPException, status
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from app.api.deps import CurrentUser

if TYPE_CHECKING:
    # Imported for annotations only. The runtime imports stay inside the
    # functions below, as in the original (presumably to defer loading the
    # agent package / avoid an import cycle — TODO confirm).
    from app.agents.core.agent import Agent, ModelTier

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/agent", tags=["agent"])

# Generic client-facing failure message shared by both endpoints. The
# underlying exception is only written to the log (logger.exception), it is
# never included in the response.
_AGENT_ERROR_MESSAGE = "에이전트 실행 중 오류가 발생했습니다."


# ---------------------------------------------------------------------------
# Schemas
# ---------------------------------------------------------------------------
# NOTE: the models below intentionally use `#` comments rather than class
# docstrings, because pydantic turns a class docstring into the JSON-schema
# "description" that FastAPI publishes in OpenAPI.


class AgentQueryRequest(BaseModel):
    # Request body shared by /query and /stream.
    query: str  # natural-language question passed to Agent.run
    model_tier: str = "strong"  # case-insensitive; see _parse_model_tier


class ToolCallLog(BaseModel):
    # One entry per "tool_end" event emitted by the agent.
    tool_name: str
    params: dict[str, Any]
    result: str
    error: str | None = None


class AgentQueryResponse(BaseModel):
    # Final answer text, the tool calls made along the way, and the
    # iteration count reported by the agent's "done" event.
    response: str
    tool_calls: list[ToolCallLog]
    iterations: int


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _parse_model_tier(raw: str) -> "ModelTier":
    """Convert a ``model_tier`` string into a ``ModelTier`` enum member.

    Matching is case-insensitive (``"FAST"`` and ``"fast"`` are equivalent).

    Raises:
        HTTPException: 422 if ``raw`` is not a known tier name.
    """
    from app.agents.core.agent import ModelTier

    mapping = {
        "strong": ModelTier.STRONG,
        "fast": ModelTier.FAST,
    }
    tier = mapping.get(raw.lower())
    if tier is None:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Invalid model_tier '{raw}'. Must be one of: {list(mapping.keys())}",
        )
    return tier


def _build_agent(raw_model_tier: str) -> "Agent":
    """Create an ``Agent`` configured for the requested model tier.

    Raises:
        HTTPException: 422 if ``raw_model_tier`` is not a known tier name.
    """
    from app.agents.core.agent import Agent, AgentConfig

    config = AgentConfig(model_tier=_parse_model_tier(raw_model_tier))
    return Agent(config=config)


def _sse(message: dict[str, Any]) -> str:
    """Serialize ``message`` as a single SSE ``data:`` frame."""
    # ensure_ascii=False emits non-ASCII (e.g. Korean) text as-is instead of
    # \uXXXX escapes. json.dumps without `indent` never emits a raw newline,
    # so each message is exactly one `data:` line followed by a blank line.
    payload = json.dumps(message, ensure_ascii=False)
    return f"data: {payload}\n\n"


# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------


@router.post("/query", response_model=AgentQueryResponse)
async def agent_query(body: AgentQueryRequest, user: CurrentUser):
    """동기 응답 — 에이전트 실행 후 최종 결과 반환."""
    # The Korean docstring above is kept verbatim because FastAPI uses it as
    # this operation's OpenAPI description. In English: "Synchronous
    # response — run the agent, then return the final result."
    #
    # `user` is not read in the body; declaring the CurrentUser dependency
    # presumably enforces authentication — TODO confirm in app.api.deps.
    agent = _build_agent(body.model_tier)

    tool_calls: list[ToolCallLog] = []
    response_text = ""
    iterations = 0

    try:
        async for event in agent.run(body.query):
            if event.type == "tool_end":
                tool_calls.append(
                    ToolCallLog(
                        tool_name=event.data.get("tool_name", ""),
                        params=event.data.get("params", {}),
                        result=event.data.get("result", ""),
                        error=event.data.get("error"),
                    )
                )
            elif event.type == "response":
                # If several "response" events arrive, the last one wins.
                response_text = event.data.get("text", "")
            elif event.type == "done":
                iterations = event.data.get("iterations", 0)
    except Exception as exc:
        logger.exception("Agent query failed")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=_AGENT_ERROR_MESSAGE,
        ) from exc

    return AgentQueryResponse(
        response=response_text,
        tool_calls=tool_calls,
        iterations=iterations,
    )


@router.post("/stream")
async def agent_stream(body: AgentQueryRequest, user: CurrentUser):
    """SSE 스트리밍 — 도구 실행 과정을 실시간으로 전달."""
    # Korean docstring kept verbatim for the OpenAPI description. In English:
    # "SSE streaming — relay the tool-execution process in real time."
    #
    # The agent is built (and model_tier validated) before the
    # StreamingResponse is returned, so an invalid tier still produces a 422.
    # Once streaming has begun, the status code can no longer change and
    # failures can only be reported in-band.
    agent = _build_agent(body.model_tier)

    async def _event_generator() -> AsyncIterator[str]:
        try:
            async for event in agent.run(body.query):
                yield _sse({"type": event.type, "data": event.data})
        except Exception:
            # Report the failure as a final "error" event rather than an
            # HTTP error status (the response is already streaming).
            logger.exception("Agent stream failed")
            yield _sse({"type": "error", "data": {"message": _AGENT_ERROR_MESSAGE}})

    return StreamingResponse(
        _event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Asks nginx-style proxies not to buffer the response, so events
            # reach the client as they are produced.
            "X-Accel-Buffering": "no",
        },
    )