# Phase 1:
#   - Real-time signal alerts (Discord/Telegram webhook)
#   - Trading journal with entry/exit tracking
#   - Position sizing calculator (Fixed/Kelly/ATR)
# Phase 2:
#   - Pension asset allocation (DC/IRP 70% risk limit)
#   - Drawdown monitoring with SVG gauge
#   - Benchmark dashboard (portfolio vs. KOSPI vs. deposit)
# Phase 3:
#   - Tax benefit simulation (Korean pension tax rules)
#   - Correlation matrix heatmap
#   - Parameter optimizer with grid search and overfit detection
#
# (File view: 279 lines, 9.6 KiB, Python)
"""
|
|
Tests for drawdown service and API endpoints.
|
|
"""
|
|
import pytest
|
|
from datetime import date
|
|
from decimal import Decimal
|
|
|
|
from app.models.portfolio import Portfolio, PortfolioSnapshot
|
|
from app.services.drawdown import (
|
|
calculate_drawdown,
|
|
calculate_rolling_drawdown,
|
|
check_drawdown_alert,
|
|
get_alert_threshold,
|
|
set_alert_threshold,
|
|
DEFAULT_ALERT_THRESHOLD,
|
|
_drawdown_settings,
|
|
)
|
|
|
|
|
|
# --- Helper ---
|
|
|
|
def _create_portfolio_with_snapshots(db, user_id, values_and_dates):
    """Persist a test portfolio together with a series of value snapshots.

    Args:
        db: Database session fixture. The portfolio and its snapshots are
            added and committed, and the portfolio is refreshed before return.
        user_id: Owner id for the new portfolio.
        values_and_dates: Iterable of ``(snapshot_date, total_value)`` pairs
            -- note the date comes first despite the parameter name. Each
            value goes through ``str`` before ``Decimal`` so int/float inputs
            convert exactly as written.

    Returns:
        The refreshed ``Portfolio`` (type ``"general"``) carrying the fixed
        test name that the alert-message assertions look for.
    """
    portfolio = Portfolio(
        user_id=user_id,
        name="테스트 포트폴리오",
        portfolio_type="general",
    )
    db.add(portfolio)
    # Flush rather than commit: this populates portfolio.id for the snapshot
    # rows below while everything is still committed together at the end.
    db.flush()

    rows = [
        PortfolioSnapshot(
            portfolio_id=portfolio.id,
            total_value=Decimal(str(amount)),
            snapshot_date=when,
        )
        for when, amount in values_and_dates
    ]
    for row in rows:
        db.add(row)

    db.commit()
    db.refresh(portfolio)
    return portfolio
|
|
|
|
|
|
# --- calculate_drawdown tests ---
|
|
|
|
class TestCalculateDrawdown:
    """Point-in-time drawdown summary returned by ``calculate_drawdown``."""

    def test_no_snapshots(self, db, test_user):
        """No snapshots: zero drawdown and no recorded peak."""
        empty = Portfolio(user_id=test_user.id, name="빈 포트폴리오")
        db.add(empty)
        db.commit()

        summary = calculate_drawdown(db, empty.id)

        assert summary["current_drawdown_pct"] == Decimal("0")
        assert summary["max_drawdown_pct"] == Decimal("0")
        assert summary["peak_value"] is None

    def test_no_drawdown_monotonic_increase(self, db, test_user):
        """A strictly rising series never dips below its running peak."""
        rising = [
            (date(2025, 1, 1), 1_000_000),
            (date(2025, 2, 1), 1_100_000),
            (date(2025, 3, 1), 1_200_000),
        ]
        pf = _create_portfolio_with_snapshots(db, test_user.id, rising)

        summary = calculate_drawdown(db, pf.id)

        assert summary["current_drawdown_pct"] == Decimal("0")
        assert summary["max_drawdown_pct"] == Decimal("0")
        # The last (and highest) value is the peak.
        assert summary["peak_value"] == Decimal("1200000")

    def test_simple_drawdown(self, db, test_user):
        """One dip: 1.2M peak in February, then 1.08M in March (-10%)."""
        one_dip = [
            (date(2025, 1, 1), 1_000_000),
            (date(2025, 2, 1), 1_200_000),
            (date(2025, 3, 1), 1_080_000),
        ]
        pf = _create_portfolio_with_snapshots(db, test_user.id, one_dip)

        summary = calculate_drawdown(db, pf.id)

        # Still sitting at the low, so current and max drawdown coincide.
        assert summary["current_drawdown_pct"] == Decimal("10.00")
        assert summary["max_drawdown_pct"] == Decimal("10.00")
        assert summary["peak_value"] == Decimal("1200000")
        assert summary["peak_date"] == date(2025, 2, 1)
        assert summary["trough_value"] == Decimal("1080000")

    def test_recovery_after_drawdown(self, db, test_user):
        """A -20% dip followed by a new high: current resets, max is kept."""
        dip_then_high = [
            (date(2025, 1, 1), 1_000_000),
            (date(2025, 2, 1), 1_200_000),  # peak before the dip
            (date(2025, 3, 1), 960_000),    # 20% below 1.2M
            (date(2025, 4, 1), 1_300_000),  # new all-time high
        ]
        pf = _create_portfolio_with_snapshots(db, test_user.id, dip_then_high)

        summary = calculate_drawdown(db, pf.id)

        assert summary["current_drawdown_pct"] == Decimal("0")
        assert summary["max_drawdown_pct"] == Decimal("20.00")
        assert summary["peak_value"] == Decimal("1300000")
        assert summary["max_drawdown_date"] == date(2025, 3, 1)

    def test_multiple_drawdowns_picks_worst(self, db, test_user):
        """Two dips (-10% from 1M, -20% from 1.1M): the deeper one wins."""
        two_dips = [
            (date(2025, 1, 1), 1_000_000),
            (date(2025, 2, 1), 900_000),    # 10% below 1M
            (date(2025, 3, 1), 1_100_000),  # new peak
            (date(2025, 4, 1), 880_000),    # 20% below 1.1M
        ]
        pf = _create_portfolio_with_snapshots(db, test_user.id, two_dips)

        summary = calculate_drawdown(db, pf.id)

        assert summary["max_drawdown_pct"] == Decimal("20.00")
        assert summary["current_drawdown_pct"] == Decimal("20.00")
|
|
|
|
|
|
# --- calculate_rolling_drawdown tests ---
|
|
|
|
class TestCalculateRollingDrawdown:
    """Per-snapshot drawdown series returned by ``calculate_rolling_drawdown``."""

    def test_empty_snapshots(self, db, test_user):
        """A portfolio without snapshots yields an empty series."""
        empty = Portfolio(user_id=test_user.id, name="빈 포트폴리오")
        db.add(empty)
        db.commit()

        assert calculate_rolling_drawdown(db, empty.id) == []

    def test_rolling_drawdown_series(self, db, test_user):
        """Each point carries its own drawdown and the running peak so far."""
        pf = _create_portfolio_with_snapshots(db, test_user.id, [
            (date(2025, 1, 1), 1_000_000),
            (date(2025, 2, 1), 1_200_000),
            (date(2025, 3, 1), 1_080_000),
        ])

        series = calculate_rolling_drawdown(db, pf.id)

        # Expected (drawdown_pct, peak) for each snapshot, in date order.
        expected = [
            (Decimal("0"), Decimal("1000000")),      # first point: nothing to fall from
            (Decimal("0"), Decimal("1200000")),      # new peak, no drawdown
            (Decimal("10.00"), Decimal("1200000")),  # 10% below the February peak
        ]
        assert len(series) == len(expected)
        for point, (want_dd, want_peak) in zip(series, expected):
            assert point["drawdown_pct"] == want_dd
            assert point["peak"] == want_peak
|
|
|
|
|
|
# --- check_drawdown_alert tests ---
|
|
|
|
class TestCheckDrawdownAlert:
    """Alert messages produced by ``check_drawdown_alert``.

    Per-portfolio thresholds appear to be held in the module-level
    ``_drawdown_settings`` mapping, so it is cleared before *and* after every
    test. Without the teardown, the custom threshold written in
    ``test_alert_with_custom_threshold`` would stay in the global mapping and
    could affect later tests (in this or other modules) that reuse the same
    portfolio id.
    """

    def setup_method(self):
        _drawdown_settings.clear()

    def teardown_method(self):
        _drawdown_settings.clear()

    def test_no_alert_under_threshold(self, db, test_user):
        """A 5% drop stays under the default threshold, so no alert."""
        # Make the test's premise explicit: if the default threshold changes,
        # fail here with a clear message rather than as a puzzling mismatch.
        assert DEFAULT_ALERT_THRESHOLD > Decimal("5"), DEFAULT_ALERT_THRESHOLD
        portfolio = _create_portfolio_with_snapshots(db, test_user.id, [
            (date(2025, 1, 1), 1_000_000),
            (date(2025, 2, 1), 950_000),  # -5%
        ])

        result = check_drawdown_alert(db, portfolio.id)
        assert result is None

    def test_alert_above_default_threshold(self, db, test_user):
        """A 22% drop exceeds the default threshold, so a message is returned."""
        # Premise: the default threshold is below the 22% drop used here.
        assert DEFAULT_ALERT_THRESHOLD < Decimal("22"), DEFAULT_ALERT_THRESHOLD
        portfolio = _create_portfolio_with_snapshots(db, test_user.id, [
            (date(2025, 1, 1), 1_000_000),
            (date(2025, 2, 1), 780_000),  # -22%
        ])

        result = check_drawdown_alert(db, portfolio.id)
        assert result is not None
        assert "Drawdown 경고" in result
        # The message names the portfolio (the helper's fixed test name).
        assert "테스트 포트폴리오" in result

    def test_alert_with_custom_threshold(self, db, test_user):
        """Lowering the threshold to 5% makes a 10% drop trigger an alert."""
        portfolio = _create_portfolio_with_snapshots(db, test_user.id, [
            (date(2025, 1, 1), 1_000_000),
            (date(2025, 2, 1), 900_000),  # -10%
        ])

        set_alert_threshold(portfolio.id, Decimal("5"))
        result = check_drawdown_alert(db, portfolio.id)
        assert result is not None
        # The message contains the Korean word for "warning".
        assert "경고" in result

    def test_no_alert_empty_portfolio(self, db, test_user):
        """A portfolio with no snapshots never produces an alert."""
        portfolio = Portfolio(user_id=test_user.id, name="빈 포트폴리오")
        db.add(portfolio)
        db.commit()

        result = check_drawdown_alert(db, portfolio.id)
        assert result is None
|
|
|
|
|
|
# --- settings tests ---
|
|
|
|
class TestDrawdownSettings:
    """Per-portfolio alert thresholds via get/set_alert_threshold.

    Thresholds appear to be held in the module-level ``_drawdown_settings``
    mapping, which is cleared before *and* after every test. These tests write
    entries for portfolio ids 1 and 2. Without the teardown, those entries
    would remain visible to any later test that looks up the same ids.
    """

    def setup_method(self):
        _drawdown_settings.clear()

    def teardown_method(self):
        _drawdown_settings.clear()

    def test_default_threshold(self):
        """An id with no stored setting falls back to the module default."""
        assert get_alert_threshold(999) == DEFAULT_ALERT_THRESHOLD

    def test_set_and_get_threshold(self):
        """A stored threshold is returned for the same portfolio id."""
        set_alert_threshold(1, Decimal("15"))
        assert get_alert_threshold(1) == Decimal("15")

    def test_different_portfolios_independent(self):
        """Thresholds for different portfolio ids do not overwrite each other."""
        set_alert_threshold(1, Decimal("10"))
        set_alert_threshold(2, Decimal("25"))
        assert get_alert_threshold(1) == Decimal("10")
        assert get_alert_threshold(2) == Decimal("25")
|
|
|
|
|
|
# --- API endpoint tests ---
|
|
|
|
class TestDrawdownAPI:
    """HTTP-level tests for the ``/api/drawdown`` endpoints."""

    def setup_method(self):
        _drawdown_settings.clear()

    def teardown_method(self):
        # test_update_settings changes a portfolio's alert threshold through
        # the API. Assuming the endpoint stores it in the service's
        # _drawdown_settings mapping (TODO confirm), clear it so the setting
        # does not leak into later tests.
        _drawdown_settings.clear()

    def _create_portfolio_via_api(self, client, auth_headers):
        """Create a portfolio through the public API and return its id."""
        resp = client.post(
            "/api/portfolios",
            headers=auth_headers,
            json={"name": "테스트", "portfolio_type": "general"},
        )
        # If creation is rejected, fail here and show the server's response
        # body, instead of raising an opaque KeyError/JSON error on the next line.
        assert 200 <= resp.status_code < 300, resp.text
        return resp.json()["id"]

    def _add_snapshots(self, db, portfolio_id, values_and_dates):
        """Insert ``(snapshot_date, total_value)`` pairs for an existing portfolio."""
        for snap_date, total_value in values_and_dates:
            snapshot = PortfolioSnapshot(
                portfolio_id=portfolio_id,
                total_value=Decimal(str(total_value)),
                snapshot_date=snap_date,
            )
            db.add(snapshot)
        db.commit()

    def test_get_drawdown(self, client, auth_headers, db):
        """GET /api/drawdown/{id} reports current and max drawdown (10%)."""
        pid = self._create_portfolio_via_api(client, auth_headers)
        self._add_snapshots(db, pid, [
            (date(2025, 1, 1), 1_000_000),
            (date(2025, 2, 1), 1_200_000),
            (date(2025, 3, 1), 1_080_000),
        ])

        resp = client.get(f"/api/drawdown/{pid}", headers=auth_headers)
        assert resp.status_code == 200
        data = resp.json()
        assert data["portfolio_id"] == pid
        assert data["current_drawdown_pct"] == 10.0
        assert data["max_drawdown_pct"] == 10.0

    def test_get_drawdown_history(self, client, auth_headers, db):
        """GET /api/drawdown/{id}/history returns one point per snapshot."""
        pid = self._create_portfolio_via_api(client, auth_headers)
        self._add_snapshots(db, pid, [
            (date(2025, 1, 1), 1_000_000),
            (date(2025, 2, 1), 1_200_000),
            (date(2025, 3, 1), 1_080_000),
        ])

        resp = client.get(f"/api/drawdown/{pid}/history", headers=auth_headers)
        assert resp.status_code == 200
        data = resp.json()
        assert len(data["data"]) == 3
        assert data["max_drawdown_pct"] == 10.0

    def test_update_settings(self, client, auth_headers, db):
        """PUT /api/drawdown/settings/{id} echoes the new alert threshold."""
        pid = self._create_portfolio_via_api(client, auth_headers)

        resp = client.put(
            f"/api/drawdown/settings/{pid}",
            headers=auth_headers,
            json={"alert_threshold_pct": 15.0},
        )
        assert resp.status_code == 200
        assert resp.json()["alert_threshold_pct"] == 15.0

    def test_drawdown_nonexistent_portfolio(self, client, auth_headers):
        """An unknown portfolio id yields 404."""
        resp = client.get("/api/drawdown/9999", headers=auth_headers)
        assert resp.status_code == 404

    def test_unauthenticated_access(self, client):
        """Requests without auth headers are rejected with 401."""
        resp = client.get("/api/drawdown/1")
        assert resp.status_code == 401
|