"""
Correlation analysis API endpoints.
"""

from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session

from app.core.database import get_db
from app.api.deps import CurrentUser
from app.models.portfolio import Portfolio, PortfolioSnapshot
from app.schemas.correlation import (
    CorrelationMatrixRequest,
    CorrelationMatrixResponse,
    DiversificationResponse,
    HighCorrelationPair,
)
from app.services.correlation import CorrelationService

router = APIRouter(prefix="/api/correlation", tags=["correlation"])

# Look-back window (days) passed to CorrelationService.get_correlation_data when
# listing highly correlated pairs among a portfolio's latest holdings.
PORTFOLIO_CORRELATION_PERIOD_DAYS = 60

# NOTE(review): both handlers are `async def` but call the synchronous SQLAlchemy
# Session and CorrelationService directly, which blocks the event loop while they
# run. Consider declaring them with plain `def` so FastAPI runs them in its
# threadpool.


@router.post("/matrix", response_model=CorrelationMatrixResponse)
async def calculate_correlation_matrix(
    request: CorrelationMatrixRequest,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Compute the return-correlation matrix between the requested stocks.

    Args:
        request: Body carrying ``stock_codes`` and ``period_days``.
        current_user: Authenticated caller (required for access; not otherwise used).
        db: Database session.

    Returns:
        CorrelationMatrixResponse with the stock codes, the matrix, and the
        pairs the service flags as highly correlated.

    Raises:
        HTTPException: 400 if the service rejects the input with ``ValueError``.
    """
    service = CorrelationService(db)
    try:
        result = service.get_correlation_data(request.stock_codes, request.period_days)
    except ValueError as e:
        # The service reports bad input via ValueError (the diversification
        # endpoint below relies on the same convention); surface it as a client
        # error instead of an opaque 500.
        raise HTTPException(status_code=400, detail=str(e)) from e

    return CorrelationMatrixResponse(
        stock_codes=result["stock_codes"],
        matrix=result["matrix"],
        high_correlation_pairs=[
            HighCorrelationPair(**p) for p in result["high_correlation_pairs"]
        ],
    )


@router.get("/portfolio/{portfolio_id}", response_model=DiversificationResponse)
async def get_portfolio_diversification(
    portfolio_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return the diversification score for one of the caller's portfolios.

    Args:
        portfolio_id: Portfolio to evaluate; must be owned by ``current_user``.
        current_user: Authenticated caller.
        db: Database session.

    Returns:
        DiversificationResponse with the score, the number of holdings in the
        most recent snapshot, and highly correlated pairs among those holdings
        (empty when fewer than two holdings exist).

    Raises:
        HTTPException: 404 if the portfolio does not exist, is not owned by the
            caller, or the service raises ``ValueError`` while scoring it.
    """
    # Authorize first: the portfolio must exist AND belong to the caller before
    # any per-portfolio computation runs or any service error text is echoed.
    portfolio = (
        db.query(Portfolio)
        .filter(
            Portfolio.id == portfolio_id,
            Portfolio.user_id == current_user.id,
        )
        .first()
    )
    if portfolio is None:
        raise HTTPException(status_code=404, detail="Portfolio not found")

    service = CorrelationService(db)
    try:
        score = service.calculate_portfolio_diversification(portfolio_id)
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e

    # Latest snapshot supplies the holdings used for the pair listing.
    snapshot = (
        db.query(PortfolioSnapshot)
        .filter(PortfolioSnapshot.portfolio_id == portfolio_id)
        .order_by(PortfolioSnapshot.snapshot_date.desc())
        .first()
    )

    high_pairs = []
    stock_count = 0
    if snapshot and snapshot.holdings:
        tickers = [h.ticker for h in snapshot.holdings]
        stock_count = len(tickers)
        # A correlation pair needs at least two instruments.
        if len(tickers) >= 2:
            corr_data = service.get_correlation_data(
                tickers, period_days=PORTFOLIO_CORRELATION_PERIOD_DAYS
            )
            high_pairs = [
                HighCorrelationPair(**p) for p in corr_data["high_correlation_pairs"]
            ]

    return DiversificationResponse(
        portfolio_id=portfolio_id,
        diversification_score=score,
        stock_count=stock_count,
        high_correlation_pairs=high_pairs,
    )