galaxis-po/backend/app/services/correlation.py

189 lines
5.6 KiB
Python
Raw Normal View History

"""
Correlation analysis service.
Calculates inter-stock correlation matrices and portfolio diversification scores.
"""
import logging
from datetime import date, timedelta
from typing import List, Optional
import numpy as np
import pandas as pd
from sqlalchemy.orm import Session
from app.models.stock import Price
from app.models.portfolio import Portfolio, PortfolioSnapshot
# Module-level logger named after this module, so records can be filtered by source.
logger: logging.Logger = logging.getLogger(name=__name__)
class CorrelationService:
    """Correlation and diversification analytics built from stored close prices.

    Price history is read from the ``Price`` table through the SQLAlchemy
    session passed to the constructor.
    """

    def __init__(self, db: "Session"):
        self.db = db

    def calculate_correlation_matrix(
        self, stock_codes: List[str], period_days: int = 60
    ) -> dict:
        """Return the pairwise correlation matrix of returns for *stock_codes*.

        Close prices from the last *period_days* calendar days (ending today)
        are converted to simple returns (see ``_prices_to_returns_df``) and
        correlated with ``DataFrame.corr``.

        Returns:
            ``{"stock_codes": stock_codes, "matrix": rows}`` where ``rows[i][j]``
            is the correlation of ``stock_codes[i]`` with ``stock_codes[j]``
            rounded to 4 decimals, ``None`` when it cannot be computed (no
            data, too few dates, zero variance), and ``1.0`` whenever the two
            codes are the same.
        """
        if not stock_codes:
            return {"stock_codes": [], "matrix": []}
        prices = self._fetch_prices(stock_codes, period_days)
        returns_df = self._prices_to_returns_df(prices, stock_codes)
        return {
            "stock_codes": stock_codes,
            "matrix": self._build_correlation_matrix(returns_df, stock_codes),
        }

    def calculate_portfolio_diversification(
        self, portfolio_id: int, period_days: int = 60
    ) -> float:
        """Score how diversified a portfolio's latest snapshot is, in [0, 1].

        The score is ``1 - sigma_p / sum_i(w_i * sigma_i)``. That is one minus
        the ratio of the portfolio's return volatility to the value-weighted
        average volatility of its holdings, measured over the last
        *period_days* calendar days. 0 means no diversification benefit.

        Holdings of the same ticker are merged. Holdings without usable price
        data are left out, and the remaining weights are renormalised.

        Returns:
            The clipped score rounded to 4 decimals, or one of these fixed
            values:

            - 1.0 when there is no snapshot, the snapshot has no holdings,
              the total value is zero, or volatility is near zero.
            - 0.0 for a single position.
            - 0.5 when there is not enough price data.

        Raises:
            ValueError: if no portfolio with *portfolio_id* exists.
        """
        portfolio = (
            self.db.query(Portfolio)
            .filter(Portfolio.id == portfolio_id)
            .first()
        )
        if not portfolio:
            raise ValueError("Portfolio not found")

        # Only the most recent snapshot is scored.
        snapshot = (
            self.db.query(PortfolioSnapshot)
            .filter(PortfolioSnapshot.portfolio_id == portfolio_id)
            .order_by(PortfolioSnapshot.snapshot_date.desc())
            .first()
        )
        if not snapshot or not snapshot.holdings:
            return 1.0

        # Merge lots of the same ticker so each ticker maps to exactly one
        # weight and one returns column.
        values_by_ticker: dict = {}
        for holding in snapshot.holdings:
            values_by_ticker[holding.ticker] = (
                values_by_ticker.get(holding.ticker, 0.0) + float(holding.value)
            )
        if len(values_by_ticker) == 1:
            return 0.0
        if sum(values_by_ticker.values()) == 0:
            return 1.0

        tickers = list(values_by_ticker)
        prices = self._fetch_prices(tickers, period_days)
        returns_df = self._prices_to_returns_df(prices, tickers)
        missing = [t for t in tickers if t not in returns_df.columns]
        if missing:
            logger.debug(
                "Portfolio %s: no usable price data for %s; excluded from "
                "diversification score",
                portfolio_id,
                missing,
            )
        score = self._diversification_ratio(returns_df, values_by_ticker)
        return 0.5 if score is None else score

    def get_correlation_data(
        self,
        stock_codes: List[str],
        period_days: int = 60,
        *,
        threshold: float = 0.7,
    ) -> dict:
        """Return the correlation matrix plus its strongly correlated pairs.

        The result of ``calculate_correlation_matrix`` is extended with a
        ``"high_correlation_pairs"`` list. It holds every pair of distinct
        codes whose absolute correlation is strictly greater than
        *threshold*, so strong negative correlations are included.
        """
        result = self.calculate_correlation_matrix(stock_codes, period_days)
        codes = result["stock_codes"]
        matrix = result["matrix"]
        result["high_correlation_pairs"] = [
            {"stock_a": codes[i], "stock_b": codes[j], "correlation": matrix[i][j]}
            for i in range(len(codes))
            for j in range(i + 1, len(codes))
            # Skip undefined correlations. Also skip a code repeated in the
            # input, which would otherwise be reported against itself.
            if matrix[i][j] is not None
            and codes[i] != codes[j]
            and abs(matrix[i][j]) > threshold
        ]
        return result

    def _fetch_prices(self, tickers: List[str], period_days: int) -> list:
        """Return ``Price`` rows for *tickers* dated within the last *period_days* days.

        The window is inclusive of today, and rows are ordered by date.
        """
        end_date = date.today()
        start_date = end_date - timedelta(days=period_days)
        return (
            self.db.query(Price)
            .filter(
                Price.ticker.in_(tickers),
                Price.date >= start_date,
                Price.date <= end_date,
            )
            .order_by(Price.date)
            .all()
        )

    def _prices_to_returns_df(
        self, prices: list, stock_codes: List[str]
    ) -> pd.DataFrame:
        """Pivot price rows into a date-indexed DataFrame of simple returns.

        Columns follow the first-occurrence order of *stock_codes*, with
        duplicates removed. Codes without any price row are omitted.

        A date missing for one ticker takes its previous close (a 0 return
        that day). Any date on which some ticker has no return yet (up to and
        including its first close) is dropped, so every column covers the
        same dates.

        Returns an empty DataFrame when there are no usable prices.
        """
        if not prices:
            return pd.DataFrame()
        closes: dict = {}
        for p in prices:
            # NOTE(review): assumes close may be NULL in the table; such rows
            # are skipped instead of crashing in float(). Confirm the schema.
            if p.close is None:
                continue
            closes.setdefault(p.ticker, {})[p.date] = float(p.close)
        if not closes:
            return pd.DataFrame()

        df = pd.DataFrame(closes)
        df.index = pd.to_datetime(df.index)
        df = df.sort_index()
        # De-duplicate, because duplicate column labels would make .loc return
        # frames instead of scalars downstream.
        columns = list(dict.fromkeys(c for c in stock_codes if c in df.columns))
        df = df[columns]
        # Explicit forward fill replaces pct_change's deprecated implicit
        # fill_method='pad' with identical results.
        return df.ffill().pct_change(fill_method=None).dropna()

    @staticmethod
    def _build_correlation_matrix(
        returns_df: pd.DataFrame, stock_codes: List[str]
    ) -> List[List[Optional[float]]]:
        """Build correlation rows aligned with *stock_codes*.

        Each entry is rounded to 4 decimals, or is ``None`` when undefined.
        Identical codes always get 1.0. With fewer than two return rows,
        every other entry is ``None``.
        """
        if returns_df.empty or len(returns_df) < 2:
            corr = None
        else:
            corr = returns_df.corr()

        rows = []
        for code in stock_codes:
            row = []
            for other in stock_codes:
                if code == other:
                    row.append(1.0)
                elif corr is not None and code in corr.columns and other in corr.columns:
                    val = corr.loc[code, other]
                    row.append(None if pd.isna(val) else round(float(val), 4))
                else:
                    row.append(None)
            rows.append(row)
        return rows

    @staticmethod
    def _diversification_ratio(
        returns_df: pd.DataFrame, values_by_ticker: dict
    ) -> Optional[float]:
        """Compute ``1 - sigma_p / sum(w_i * sigma_i)`` from aligned returns.

        Weights come from *values_by_ticker*, restricted to tickers that have
        a column in *returns_df* and renormalised to sum to 1. This keeps the
        weight vector the same length as the covariance matrix.

        Returns:
            The score clipped to [0, 1] and rounded to 4 decimals. Returns 1.0
            when the weighted average volatility is below 1e-5, and ``None``
            when there are fewer than two return rows or no covered value to
            weight by.
        """
        if returns_df.empty or len(returns_df) < 2:
            return None
        columns = [c for c in returns_df.columns if c in values_by_ticker]
        if not columns:
            return None
        values = np.array([float(values_by_ticker[c]) for c in columns])
        covered_value = float(values.sum())
        if covered_value == 0:
            return None
        weights = values / covered_value

        returns = returns_df[columns]
        cov_matrix = returns.cov().values
        stds = returns.std().values
        # Rounding can push the quadratic form slightly below zero; clamp it
        # so sqrt never yields NaN.
        portfolio_vol = np.sqrt(max(float(weights @ cov_matrix @ weights), 0.0))
        weighted_avg_vol = float(np.sum(weights * stds))
        # Near-constant prices (weighted vol < 1e-5, i.e. variance < 1e-10)
        # give no meaningful ratio.
        if weighted_avg_vol < 1e-5:
            return 1.0
        ratio = 1.0 - portfolio_vol / weighted_avg_vol
        return round(float(np.clip(ratio, 0.0, 1.0)), 4)