galaxis-po/backend/app/api/correlation.py

87 lines
2.7 KiB
Python
Raw Normal View History

"""
Correlation analysis API endpoints.
"""
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.api.deps import CurrentUser
from app.schemas.correlation import (
CorrelationMatrixRequest,
CorrelationMatrixResponse,
DiversificationResponse,
HighCorrelationPair,
)
from app.services.correlation import CorrelationService
# Router for the correlation-analysis endpoints: every path declared below is
# mounted under "/api/correlation" and grouped under the "correlation" tag.
router = APIRouter(
    prefix="/api/correlation",
    tags=["correlation"],
)
@router.post("/matrix", response_model=CorrelationMatrixResponse)
async def calculate_correlation_matrix(
    request: CorrelationMatrixRequest,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Compute the return-correlation matrix between the requested stocks.

    Args:
        request: Request body; its ``stock_codes`` and ``period_days`` are
            forwarded to ``CorrelationService.get_correlation_data``.
        current_user: Authenticated-user dependency. The body never reads it;
            presumably its presence enforces authentication — TODO confirm.
        db: Database session supplied by ``get_db``.

    Returns:
        CorrelationMatrixResponse built from the service result's
        ``stock_codes``, ``matrix`` and ``high_correlation_pairs`` entries.
    """
    # NOTE(review): this is an ``async def`` handler, but ``Session`` is the
    # synchronous SQLAlchemy session, so any DB work inside the service
    # presumably blocks the event loop. Consider a plain ``def`` (FastAPI then
    # runs it in a threadpool) once no caller awaits this function directly.
    correlation = CorrelationService(db).get_correlation_data(
        request.stock_codes, request.period_days
    )

    codes = correlation["stock_codes"]
    matrix = correlation["matrix"]
    pairs = [
        HighCorrelationPair(**pair_fields)
        for pair_fields in correlation["high_correlation_pairs"]
    ]

    return CorrelationMatrixResponse(
        stock_codes=codes,
        matrix=matrix,
        high_correlation_pairs=pairs,
    )
@router.get("/portfolio/{portfolio_id}", response_model=DiversificationResponse)
async def get_portfolio_diversification(
    portfolio_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return the diversification score of one of the current user's portfolios.

    Ownership is verified *before* any analysis runs. Other users' portfolios
    are therefore never processed, and a missing portfolio and a portfolio
    owned by someone else produce the same 404.

    Args:
        portfolio_id: ID of the portfolio to analyse.
        current_user: Authenticated user. Only portfolios whose ``user_id``
            equals ``current_user.id`` are accessible.
        db: Database session supplied by ``get_db``.

    Returns:
        DiversificationResponse containing:
        - the score from ``CorrelationService.calculate_portfolio_diversification``;
        - the number of holdings in the latest snapshot;
        - the highly correlated pairs among those holdings, computed over a
          60-day window.

    Raises:
        HTTPException: 404 in any of these cases:
        - the portfolio does not exist;
        - it is not owned by the current user;
        - the service raises ``ValueError``.
    """
    # Kept as a function-local import like the original; presumably this
    # avoids an import cycle with the models package — TODO confirm.
    from app.models.portfolio import Portfolio, PortfolioSnapshot

    # Authorize first. Previously the score was computed before this check.
    # That processed other users' portfolios and exposed the service's
    # ValueError text, which differs from "Portfolio not found" and so
    # revealed whether a given portfolio ID exists.
    portfolio = (
        db.query(Portfolio)
        .filter(
            Portfolio.id == portfolio_id,
            Portfolio.user_id == current_user.id,
        )
        .first()
    )
    if not portfolio:
        raise HTTPException(status_code=404, detail="Portfolio not found")

    service = CorrelationService(db)
    try:
        score = service.calculate_portfolio_diversification(portfolio_id)
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e

    # The most recent snapshot supplies the holdings used for pair analysis.
    snapshot = (
        db.query(PortfolioSnapshot)
        .filter(PortfolioSnapshot.portfolio_id == portfolio_id)
        .order_by(PortfolioSnapshot.snapshot_date.desc())
        .first()
    )

    high_pairs = []
    stock_count = 0
    if snapshot and snapshot.holdings:
        tickers = [holding.ticker for holding in snapshot.holdings]
        stock_count = len(tickers)
        # A correlation pair needs at least two series.
        if stock_count >= 2:
            corr_data = service.get_correlation_data(tickers, period_days=60)
            high_pairs = [
                HighCorrelationPair(**pair)
                for pair in corr_data["high_correlation_pairs"]
            ]

    return DiversificationResponse(
        portfolio_id=portfolio_id,
        diversification_score=score,
        stock_count=stock_count,
        high_correlation_pairs=high_pairs,
    )