Phase 1:
- Real-time signal alerts (Discord/Telegram webhook)
- Trading journal with entry/exit tracking
- Position sizing calculator (Fixed/Kelly/ATR)

Phase 2:
- Pension asset allocation (DC/IRP 70% risk limit)
- Drawdown monitoring with SVG gauge
- Benchmark dashboard (portfolio vs. KOSPI vs. deposit)

Phase 3:
- Tax benefit simulation (Korean pension tax rules)
- Correlation matrix heatmap
- Parameter optimizer with grid search and overfitting detection
221 lines
8.1 KiB
Python
221 lines
8.1 KiB
Python
"""
|
|
Unit tests for correlation analysis service.
|
|
"""
|
|
from datetime import date, timedelta
|
|
from decimal import Decimal
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
import pytest
|
|
|
|
from app.services.correlation import CorrelationService
|
|
|
|
|
|
@pytest.fixture
def db():
    """Provide a fresh ``MagicMock`` standing in for the database session.

    Individual tests configure its ``query(...)`` chains to return canned rows.
    """
    session = MagicMock()
    return session
|
|
|
|
|
|
@pytest.fixture
def service(db):
    """Build the ``CorrelationService`` under test on top of the mocked ``db``."""
    correlation_service = CorrelationService(db)
    return correlation_service
|
|
|
|
|
|
class TestCalculateCorrelationMatrix:
    """Tests for ``CorrelationService.calculate_correlation_matrix``.

    ``db`` is a ``MagicMock``; each test stubs the
    ``query().filter().order_by().all()`` chain with mock price rows.
    """

    _DATES = [date(2025, 1, i) for i in range(1, 11)]
    # B is exactly A / 2, so the two series are perfectly correlated whether
    # the service correlates price levels, differences or returns.
    _CLOSES_A = [100, 102, 104, 103, 105, 107, 106, 108, 110, 112]
    _CLOSES_B = [50, 51, 52, 51.5, 52.5, 53.5, 53, 54, 55, 56]

    def _make_prices(self, ticker: str, dates: list, closes: list) -> list:
        """Return one mock price row per (date, close) pair for ``ticker``.

        Each row exposes ``ticker``, ``date`` and ``close``; ``close`` is a
        ``Decimal`` built from ``str(c)`` so floats use their short repr.

        Raises:
            ValueError: if ``dates`` and ``closes`` differ in length (``zip``
                would otherwise silently drop the extra entries).
        """
        if len(dates) != len(closes):
            raise ValueError(
                f"dates ({len(dates)}) and closes ({len(closes)}) differ in length"
            )
        return [
            MagicMock(ticker=ticker, date=d, close=Decimal(str(c)))
            for d, c in zip(dates, closes)
        ]

    @staticmethod
    def _stub_prices(db, prices: list) -> None:
        """Make the mocked price query on ``db`` return ``prices``."""
        db.query.return_value.filter.return_value.order_by.return_value.all.return_value = prices

    def test_two_stocks_positive_correlation(self, service, db):
        prices_a = self._make_prices("A", self._DATES, self._CLOSES_A)
        prices_b = self._make_prices("B", self._DATES, self._CLOSES_B)
        self._stub_prices(db, prices_a + prices_b)

        result = service.calculate_correlation_matrix(["A", "B"], period_days=60)

        assert "A" in result["stock_codes"]
        assert "B" in result["stock_codes"]
        assert len(result["matrix"]) == 2
        assert all(len(row) == 2 for row in result["matrix"])
        # Diagonal should be 1.0
        assert result["matrix"][0][0] == pytest.approx(1.0, abs=0.01)
        assert result["matrix"][1][1] == pytest.approx(1.0, abs=0.01)
        # B is exactly A / 2, so the off-diagonal must be ~1.0 and the matrix
        # symmetric; a loose ``> 0.5`` bound would hide regressions.
        assert result["matrix"][0][1] == pytest.approx(1.0, abs=0.01)
        assert result["matrix"][1][0] == pytest.approx(result["matrix"][0][1])

    def test_single_stock_returns_identity(self, service, db):
        dates = [date(2025, 1, i) for i in range(1, 6)]
        self._stub_prices(db, self._make_prices("A", dates, [100, 102, 101, 103, 105]))

        result = service.calculate_correlation_matrix(["A"], period_days=60)

        assert result["matrix"] == [[1.0]]

    def test_empty_stock_codes(self, service, db):
        result = service.calculate_correlation_matrix([], period_days=60)

        assert result["stock_codes"] == []
        assert result["matrix"] == []

    def test_insufficient_data_returns_nan_as_none(self, service, db):
        dates = [date(2025, 1, 1)]
        prices_a = self._make_prices("A", dates, [100])
        prices_b = self._make_prices("B", dates, [50])
        self._stub_prices(db, prices_a + prices_b)

        result = service.calculate_correlation_matrix(["A", "B"], period_days=60)

        # With only 1 data point no returns can be calculated, so both
        # off-diagonal cells must be None (not just one of them).
        assert result["matrix"][0][1] is None
        assert result["matrix"][1][0] is None
|
|
|
|
|
|
class TestCalculatePortfolioDiversification:
    """Tests for ``CorrelationService.calculate_portfolio_diversification``.

    ``db`` is a ``MagicMock`` whose query chains are stubbed per test:
    ``query().filter().first()`` yields the portfolio,
    ``query().filter().order_by().first()`` the snapshot, and
    ``query().filter().order_by().all()`` the price rows.
    """

    _DATES = [date(2025, 1, i) for i in range(1, 22)]

    # Deterministic series instead of ``np.random.seed`` (which mutates the
    # process-wide NumPy RNG and leaks state into other tests).
    # A's daily returns alternate up/down while B's go up, up, down, down,
    # so the Pearson correlation of their percentage returns is zero
    # (up to float rounding).
    _UNCORRELATED_A = [(100, 110)[i % 2] for i in range(21)]
    _UNCORRELATED_B = [(100, 110, 121, 110)[i % 4] for i in range(21)]
    # B is exactly A / 2: a perfectly correlated pair.
    _CORRELATED_A = [100 + i for i in range(21)]
    _CORRELATED_B = [(100 + i) / 2 for i in range(21)]

    def _make_holding(self, ticker: str, value: float, ratio: float):
        """Return a mock holding with ``Decimal`` ``value`` and ``current_ratio``."""
        return MagicMock(
            ticker=ticker,
            value=Decimal(str(value)),
            current_ratio=Decimal(str(ratio)),
        )

    def _make_prices(self, ticker: str, dates: list, closes: list) -> list:
        """Return one mock price row per (date, close) pair for ``ticker``.

        Raises:
            ValueError: if ``dates`` and ``closes`` differ in length (``zip``
                would otherwise silently drop the extra entries).
        """
        if len(dates) != len(closes):
            raise ValueError(
                f"dates ({len(dates)}) and closes ({len(closes)}) differ in length"
            )
        return [
            MagicMock(ticker=ticker, date=d, close=Decimal(str(c)))
            for d, c in zip(dates, closes)
        ]

    @staticmethod
    def _stub_db(db, portfolio, snapshot, prices=None) -> None:
        """Point ``db``'s query chains at ``portfolio``, ``snapshot`` and ``prices``.

        The price query is left unstubbed when ``prices`` is ``None``.
        """
        filtered = db.query.return_value.filter.return_value
        filtered.first.return_value = portfolio
        filtered.order_by.return_value.first.return_value = snapshot
        if prices is not None:
            filtered.order_by.return_value.all.return_value = prices

    def _score_pair(self, service, db, closes_a: list, closes_b: list):
        """Score an equal-weight A/B portfolio priced by ``closes_a``/``closes_b``."""
        prices = (
            self._make_prices("A", self._DATES, closes_a)
            + self._make_prices("B", self._DATES, closes_b)
        )
        snapshot = MagicMock()
        snapshot.holdings = [
            self._make_holding("A", 5000, 50),
            self._make_holding("B", 5000, 50),
        ]
        self._stub_db(db, MagicMock(id=1, user_id=1), snapshot, prices)
        return service.calculate_portfolio_diversification(portfolio_id=1)

    def test_diversified_portfolio(self, service, db):
        """Low correlation stocks -> high diversification score."""
        score = self._score_pair(
            service, db, self._UNCORRELATED_A, self._UNCORRELATED_B
        )

        assert 0 <= score <= 1

    def test_uncorrelated_scores_higher_than_correlated(self, service, db):
        """Zero-correlation holdings must score above perfectly correlated ones."""
        uncorrelated = self._score_pair(
            service, db, self._UNCORRELATED_A, self._UNCORRELATED_B
        )
        correlated = self._score_pair(
            service, db, self._CORRELATED_A, self._CORRELATED_B
        )

        assert uncorrelated > correlated

    def test_portfolio_not_found(self, service, db):
        db.query.return_value.filter.return_value.first.return_value = None

        with pytest.raises(ValueError, match="Portfolio not found"):
            service.calculate_portfolio_diversification(portfolio_id=999)

    def test_no_holdings(self, service, db):
        snapshot = MagicMock()
        snapshot.holdings = []
        self._stub_db(db, MagicMock(id=1), snapshot)

        score = service.calculate_portfolio_diversification(portfolio_id=1)

        # An empty portfolio is expected to report the maximum score.
        assert score == 1.0

    def test_single_holding(self, service, db):
        prices = self._make_prices("A", self._DATES[:10], [100 + i for i in range(10)])
        snapshot = MagicMock()
        snapshot.holdings = [self._make_holding("A", 10000, 100)]
        self._stub_db(db, MagicMock(id=1), snapshot, prices)

        score = service.calculate_portfolio_diversification(portfolio_id=1)

        # Single stock = no diversification benefit, score should be 0
        assert score == 0.0
|
|
|
|
|
|
class TestGetCorrelationData:
    """Tests for ``CorrelationService.get_correlation_data`` (heatmap payload).

    ``db`` is a ``MagicMock``; each test stubs the
    ``query().filter().order_by().all()`` chain with mock price rows.
    """

    _DATES = [date(2025, 1, i) for i in range(1, 11)]
    # B is exactly A / 2, so the pair is perfectly correlated.
    _CLOSES_A = [100, 102, 104, 103, 105, 107, 106, 108, 110, 112]
    _CLOSES_B = [50, 51, 52, 51.5, 52.5, 53.5, 53, 54, 55, 56]

    def _make_prices(self, ticker: str, dates: list, closes: list) -> list:
        """Return one mock price row per (date, close) pair for ``ticker``.

        Raises:
            ValueError: if ``dates`` and ``closes`` differ in length (``zip``
                would otherwise silently drop the extra entries).
        """
        if len(dates) != len(closes):
            raise ValueError(
                f"dates ({len(dates)}) and closes ({len(closes)}) differ in length"
            )
        return [
            MagicMock(ticker=ticker, date=d, close=Decimal(str(c)))
            for d, c in zip(dates, closes)
        ]

    @staticmethod
    def _stub_prices(db, prices: list) -> None:
        """Make the mocked price query on ``db`` return ``prices``."""
        db.query.return_value.filter.return_value.order_by.return_value.all.return_value = prices

    @staticmethod
    def _pair_tickers(pair) -> set:
        """Return the unordered set of tickers in a high-correlation pair."""
        return {pair["stock_a"], pair["stock_b"]}

    def test_heatmap_data_structure(self, service, db):
        prices = (
            self._make_prices("A", self._DATES, self._CLOSES_A)
            + self._make_prices("B", self._DATES, self._CLOSES_B)
        )
        self._stub_prices(db, prices)

        result = service.get_correlation_data(["A", "B"], period_days=60)

        assert "stock_codes" in result
        assert "matrix" in result
        assert "high_correlation_pairs" in result

        pairs = result["high_correlation_pairs"]
        # high_correlation_pairs should have pairs with corr > 0.7
        for pair in pairs:
            assert "stock_a" in pair
            assert "stock_b" in pair
            assert "correlation" in pair
            assert abs(pair["correlation"]) > 0.7
        # A and B move in lockstep, so their pair must be reported; without
        # this the loop above passes vacuously on an empty list.
        assert any(self._pair_tickers(p) == {"A", "B"} for p in pairs)

    def test_no_high_correlation_pairs_when_uncorrelated(self, service, db):
        dates = [date(2025, 1, i) for i in range(1, 22)]
        # Deterministic zero-correlation series, replacing a seeded random
        # walk whose outcome was never actually checked: A's returns
        # alternate up/down while B's go up, up, down, down.
        closes_a = [(100, 110)[i % 2] for i in range(21)]
        closes_b = [(100, 110, 121, 110)[i % 4] for i in range(21)]
        self._stub_prices(
            db,
            self._make_prices("A", dates, closes_a)
            + self._make_prices("B", dates, closes_b),
        )

        result = service.get_correlation_data(["A", "B"], period_days=60)

        # The A/B pair must not be flagged as highly correlated.
        assert not any(
            self._pair_tickers(p) == {"A", "B"}
            for p in result["high_correlation_pairs"]
        )
|