142 lines
4.4 KiB
Python
142 lines
4.4 KiB
Python
"""
|
|
Agent API endpoints — 자연어 쿼리 기반 투자 분석 에이전트.
|
|
"""
|
|
import json
import logging
from typing import TYPE_CHECKING, Any

from fastapi import APIRouter, HTTPException, status
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from app.api.deps import CurrentUser

if TYPE_CHECKING:
    # Type-checking only: lets the "ModelTier" return annotation resolve for
    # static analysis without importing the agent package at runtime.
    from app.agents.core.agent import ModelTier
|
|
|
|
# Logger named after this module.
logger = logging.getLogger(__name__)

# Every endpoint below is mounted under /api/agent and grouped under the
# "agent" tag in the generated OpenAPI docs.
router = APIRouter(tags=["agent"], prefix="/api/agent")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Schemas
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class AgentQueryRequest(BaseModel):
    # Request body accepted by both POST /api/agent/query and
    # POST /api/agent/stream. (Documented with comments rather than a
    # docstring so the generated OpenAPI schema is unchanged.)

    # Natural-language question, passed as-is to the agent.
    query: str

    # Tier name, resolved case-insensitively by _parse_model_tier
    # ("strong" or "fast"); defaults to the strong tier.
    model_tier: str = "strong"
|
|
|
|
|
|
class ToolCallLog(BaseModel):
    # One record per "tool_end" event collected by the /query endpoint.
    # (Comments instead of a docstring keep the OpenAPI schema unchanged.)

    # Name of the tool that ran ("" if the event payload omitted it).
    tool_name: str
    # Arguments the tool was invoked with, as reported in the event payload.
    params: dict[str, Any]
    # Tool output as text.
    result: str
    # Error text reported for this call, if any; None when the payload has none.
    error: str | None = None
|
|
|
|
|
|
class AgentQueryResponse(BaseModel):
    # Response body of POST /api/agent/query.
    # (Comments instead of a docstring keep the OpenAPI schema unchanged.)

    # Text of the last "response" event ("" if the agent emitted none).
    response: str
    # Tool calls, in the order their "tool_end" events arrived.
    tool_calls: list[ToolCallLog]
    # Iteration count taken from the "done" event (0 if none was emitted).
    iterations: int
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _parse_model_tier(raw: str) -> "ModelTier":
    """Convert a ``model_tier`` request string into a ``ModelTier`` member.

    Matching is case-insensitive and ignores surrounding whitespace, so
    ``"Fast"`` and ``" strong "`` are both accepted.

    Args:
        raw: Tier name supplied by the client (``"strong"`` or ``"fast"``).

    Returns:
        The corresponding ``ModelTier`` member.

    Raises:
        HTTPException: 422 Unprocessable Entity when ``raw`` names no known tier.
    """
    # Imported lazily, like in the endpoints below; presumably this defers
    # loading the agent package until a request actually needs it.
    from app.agents.core.agent import ModelTier

    mapping = {
        "strong": ModelTier.STRONG,
        "fast": ModelTier.FAST,
    }
    # Normalise before lookup so stray whitespace or capitalisation from the
    # client does not turn a valid tier into a 422.
    tier = mapping.get(raw.strip().lower())
    if tier is None:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Invalid model_tier '{raw}'. Must be one of: {list(mapping.keys())}",
        )
    return tier
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Endpoints
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@router.post("/query", response_model=AgentQueryResponse)
async def agent_query(body: AgentQueryRequest, user: CurrentUser):
    """동기 응답 — 에이전트 실행 후 최종 결과 반환."""
    # Synchronous variant: run the agent to completion, then return the final
    # answer together with a log of every tool call it made.
    # NOTE: the docstring above is intentionally left as-is, because FastAPI
    # publishes it as this operation's OpenAPI description.
    # NOTE(review): `user` is never read here; presumably it is declared only so
    # the CurrentUser dependency enforces authentication — confirm in app.api.deps.
    from app.agents.core.agent import Agent, AgentConfig

    # An invalid tier raises a 422 from _parse_model_tier before the agent runs.
    agent = Agent(config=AgentConfig(model_tier=_parse_model_tier(body.model_tier)))

    collected_calls: list[ToolCallLog] = []
    final_text = ""
    iteration_count = 0

    try:
        async for event in agent.run(body.query):
            kind = event.type
            if kind == "tool_end":
                payload = event.data
                collected_calls.append(
                    ToolCallLog(
                        tool_name=payload.get("tool_name", ""),
                        params=payload.get("params", {}),
                        result=payload.get("result", ""),
                        error=payload.get("error"),
                    )
                )
            elif kind == "response":
                # A later "response" event overwrites an earlier one.
                final_text = event.data.get("text", "")
            elif kind == "done":
                iteration_count = event.data.get("iterations", 0)
            # Any other event type is ignored by this endpoint.
    except Exception:
        # Keep the traceback in the server log; the client gets a generic message.
        logger.exception("Agent query failed")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="에이전트 실행 중 오류가 발생했습니다.",
        )

    return AgentQueryResponse(
        response=final_text,
        tool_calls=collected_calls,
        iterations=iteration_count,
    )
|
|
|
|
|
|
@router.post("/stream")
async def agent_stream(body: AgentQueryRequest, user: CurrentUser):
    """SSE 스트리밍 — 도구 실행 과정을 실시간으로 전달."""
    # Server-Sent Events variant: forward every agent event to the client as
    # soon as it is produced, one `data: <json>` frame per event.
    # NOTE: the docstring above is intentionally left as-is, because FastAPI
    # publishes it as this operation's OpenAPI description.
    # NOTE(review): `user` is never read here; presumably it is declared only so
    # the CurrentUser dependency enforces authentication — confirm in app.api.deps.
    from app.agents.core.agent import Agent, AgentConfig

    # Resolved before the response is created, so an invalid tier still produces
    # a regular 422 instead of a stream.
    agent = Agent(config=AgentConfig(model_tier=_parse_model_tier(body.model_tier)))

    def _sse_frame(message: dict) -> str:
        # ensure_ascii=False keeps non-ASCII text (e.g. Korean) unescaped.
        return f"data: {json.dumps(message, ensure_ascii=False)}\n\n"

    async def _event_generator():
        try:
            async for event in agent.run(body.query):
                yield _sse_frame({"type": event.type, "data": event.data})
        except Exception:
            # Once streaming has started the HTTP status cannot change, so the
            # failure is reported in-band as a final "error" frame.
            logger.exception("Agent stream failed")
            yield _sse_frame(
                {"type": "error", "data": {"message": "에이전트 실행 중 오류가 발생했습니다."}}
            )

    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        # Asks nginx not to buffer the response, so frames reach the client immediately.
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(
        _event_generator(),
        media_type="text/event-stream",
        headers=sse_headers,
    )
|