galaxis-po/backend/tests/unit/test_drawdown.py
머니페니 12d235a1f1 feat: add 9 new modules - notification alerts, trading journal, position sizing, pension allocation, drawdown monitoring, benchmark dashboard, tax simulation, correlation analysis, parameter optimizer
Phase 1:
- Real-time signal alerts (Discord/Telegram webhook)
- Trading journal with entry/exit tracking
- Position sizing calculator (Fixed/Kelly/ATR)

Phase 2:
- Pension asset allocation (DC/IRP 70% risk limit)
- Drawdown monitoring with SVG gauge
- Benchmark dashboard (portfolio vs KOSPI vs deposit)

Phase 3:
- Tax benefit simulation (Korean pension tax rules)
- Correlation matrix heatmap
- Parameter optimizer with grid search + overfit detection
2026-03-29 10:03:08 +09:00

279 lines
9.6 KiB
Python

"""
Tests for drawdown service and API endpoints.
"""
import pytest
from datetime import date
from decimal import Decimal
from app.models.portfolio import Portfolio, PortfolioSnapshot
from app.services.drawdown import (
calculate_drawdown,
calculate_rolling_drawdown,
check_drawdown_alert,
get_alert_threshold,
set_alert_threshold,
DEFAULT_ALERT_THRESHOLD,
_drawdown_settings,
)
# --- Helper ---
def _create_portfolio_with_snapshots(db, user_id, values_and_dates):
    """Persist a "general" portfolio together with a snapshot time series.

    Args:
        db: Database session fixture (used for add/flush/commit/refresh).
        user_id: Id of the user that owns the new portfolio.
        values_and_dates: Iterable of ``(snapshot_date, total_value)`` pairs,
            one snapshot per pair.

    Returns:
        The committed and refreshed ``Portfolio`` instance.
    """
    new_portfolio = Portfolio(
        user_id=user_id,
        name="테스트 포트폴리오",
        portfolio_type="general",
    )
    db.add(new_portfolio)
    # Flush so the portfolio's primary key is assigned before the snapshots
    # below reference it via portfolio_id.
    db.flush()

    db.add_all([
        PortfolioSnapshot(
            portfolio_id=new_portfolio.id,
            # Go through str() so ints/floats become exact decimals rather
            # than the binary expansion of a float.
            total_value=Decimal(str(value)),
            snapshot_date=when,
        )
        for when, value in values_and_dates
    ])
    db.commit()
    db.refresh(new_portfolio)
    return new_portfolio
# --- calculate_drawdown tests ---
class TestCalculateDrawdown:
    """Unit tests for ``calculate_drawdown`` over portfolio snapshot series."""

    def test_no_snapshots(self, db, test_user):
        # A portfolio without snapshots reports zero drawdown and no peak.
        portfolio = Portfolio(user_id=test_user.id, name="빈 포트폴리오")
        db.add(portfolio)
        db.commit()
        result = calculate_drawdown(db, portfolio.id)
        assert result["current_drawdown_pct"] == Decimal("0")
        assert result["max_drawdown_pct"] == Decimal("0")
        assert result["peak_value"] is None

    def test_no_drawdown_monotonic_increase(self, db, test_user):
        portfolio = _create_portfolio_with_snapshots(db, test_user.id, [
            (date(2025, 1, 1), 1_000_000),
            (date(2025, 2, 1), 1_100_000),
            (date(2025, 3, 1), 1_200_000),
        ])
        result = calculate_drawdown(db, portfolio.id)
        assert result["current_drawdown_pct"] == Decimal("0")
        assert result["max_drawdown_pct"] == Decimal("0")
        assert result["peak_value"] == Decimal("1200000")

    def test_simple_drawdown(self, db, test_user):
        portfolio = _create_portfolio_with_snapshots(db, test_user.id, [
            (date(2025, 1, 1), 1_000_000),
            (date(2025, 2, 1), 1_200_000),  # peak
            (date(2025, 3, 1), 1_080_000),  # -10%
        ])
        result = calculate_drawdown(db, portfolio.id)
        assert result["current_drawdown_pct"] == Decimal("10.00")
        assert result["max_drawdown_pct"] == Decimal("10.00")
        assert result["peak_value"] == Decimal("1200000")
        assert result["peak_date"] == date(2025, 2, 1)
        assert result["trough_value"] == Decimal("1080000")

    def test_recovery_after_drawdown(self, db, test_user):
        portfolio = _create_portfolio_with_snapshots(db, test_user.id, [
            (date(2025, 1, 1), 1_000_000),
            (date(2025, 2, 1), 1_200_000),  # peak
            (date(2025, 3, 1), 960_000),    # -20% (max dd)
            (date(2025, 4, 1), 1_300_000),  # new peak, recovery
        ])
        result = calculate_drawdown(db, portfolio.id)
        assert result["current_drawdown_pct"] == Decimal("0")
        assert result["max_drawdown_pct"] == Decimal("20.00")
        assert result["peak_value"] == Decimal("1300000")
        assert result["max_drawdown_date"] == date(2025, 3, 1)

    def test_multiple_drawdowns_picks_worst(self, db, test_user):
        portfolio = _create_portfolio_with_snapshots(db, test_user.id, [
            (date(2025, 1, 1), 1_000_000),
            (date(2025, 2, 1), 900_000),    # -10%
            (date(2025, 3, 1), 1_100_000),  # new peak
            (date(2025, 4, 1), 880_000),    # -20% from 1.1M
        ])
        result = calculate_drawdown(db, portfolio.id)
        assert result["max_drawdown_pct"] == Decimal("20.00")
        assert result["current_drawdown_pct"] == Decimal("20.00")

    def test_earlier_worse_drawdown_is_kept(self, db, test_user):
        """An older, deeper drawdown must win over a newer, shallower one.

        In ``test_multiple_drawdowns_picks_worst`` the deepest drawdown is
        also the most recent, so that test cannot distinguish "keeps the
        worst drawdown" from "reports the latest drawdown". Here the order
        is reversed: the -20% dip happens first, then a -10% dip is ongoing.
        """
        portfolio = _create_portfolio_with_snapshots(db, test_user.id, [
            (date(2025, 1, 1), 1_000_000),  # initial peak
            (date(2025, 2, 1), 800_000),    # -20% (max dd)
            (date(2025, 3, 1), 1_100_000),  # new peak
            (date(2025, 4, 1), 990_000),    # -10% from 1.1M (current dd)
        ])
        result = calculate_drawdown(db, portfolio.id)
        assert result["max_drawdown_pct"] == Decimal("20.00")
        assert result["current_drawdown_pct"] == Decimal("10.00")
        assert result["peak_value"] == Decimal("1100000")
        assert result["max_drawdown_date"] == date(2025, 2, 1)
# --- calculate_rolling_drawdown tests ---
class TestCalculateRollingDrawdown:
    """Tests for the per-snapshot running-peak drawdown series."""

    def test_empty_snapshots(self, db, test_user):
        # No snapshots -> an empty series, not an error.
        empty = Portfolio(user_id=test_user.id, name="빈 포트폴리오")
        db.add(empty)
        db.commit()
        assert calculate_rolling_drawdown(db, empty.id) == []

    def test_rolling_drawdown_series(self, db, test_user):
        portfolio = _create_portfolio_with_snapshots(db, test_user.id, [
            (date(2025, 1, 1), 1_000_000),
            (date(2025, 2, 1), 1_200_000),
            (date(2025, 3, 1), 1_080_000),
        ])
        series = calculate_rolling_drawdown(db, portfolio.id)

        # Expected (drawdown_pct, running peak) for each snapshot, in order.
        expected = [
            (Decimal("0"), Decimal("1000000")),      # first point: nothing to fall from
            (Decimal("0"), Decimal("1200000")),      # new high resets drawdown to zero
            (Decimal("10.00"), Decimal("1200000")),  # 1.08M sits 10% below the 1.2M peak
        ]
        assert len(series) == len(expected)
        for point, (want_drawdown, want_peak) in zip(series, expected):
            assert point["drawdown_pct"] == want_drawdown
            assert point["peak"] == want_peak
# --- check_drawdown_alert tests ---
class TestCheckDrawdownAlert:
    """Tests for ``check_drawdown_alert`` with default and custom thresholds.

    ``_drawdown_settings`` is module-level state in the service, so it is
    reset both before and after every test. The reset before keeps earlier
    tests from changing the threshold seen here. The reset after keeps the
    custom threshold set in ``test_alert_with_custom_threshold`` from leaking
    into later tests, which may get the same portfolio id (e.g. when the test
    DB is recreated).
    """

    def setup_method(self):
        _drawdown_settings.clear()

    def teardown_method(self):
        _drawdown_settings.clear()

    def test_no_alert_under_threshold(self, db, test_user):
        # Assumes DEFAULT_ALERT_THRESHOLD is above 5%.
        portfolio = _create_portfolio_with_snapshots(db, test_user.id, [
            (date(2025, 1, 1), 1_000_000),
            (date(2025, 2, 1), 950_000),  # -5%
        ])
        result = check_drawdown_alert(db, portfolio.id)
        assert result is None

    def test_alert_above_default_threshold(self, db, test_user):
        # Assumes DEFAULT_ALERT_THRESHOLD is below 22%.
        portfolio = _create_portfolio_with_snapshots(db, test_user.id, [
            (date(2025, 1, 1), 1_000_000),
            (date(2025, 2, 1), 780_000),  # -22%
        ])
        result = check_drawdown_alert(db, portfolio.id)
        assert result is not None
        # The alert text is Korean ("경고" = "warning") and names the portfolio.
        assert "Drawdown 경고" in result
        assert "테스트 포트폴리오" in result

    def test_alert_with_custom_threshold(self, db, test_user):
        portfolio = _create_portfolio_with_snapshots(db, test_user.id, [
            (date(2025, 1, 1), 1_000_000),
            (date(2025, 2, 1), 900_000),  # -10%
        ])
        # Lowering the threshold to 5% makes a 10% drawdown trigger an alert.
        set_alert_threshold(portfolio.id, Decimal("5"))
        result = check_drawdown_alert(db, portfolio.id)
        assert result is not None
        assert "경고" in result

    def test_no_alert_empty_portfolio(self, db, test_user):
        portfolio = Portfolio(user_id=test_user.id, name="빈 포트폴리오")
        db.add(portfolio)
        db.commit()
        result = check_drawdown_alert(db, portfolio.id)
        assert result is None
# --- settings tests ---
class TestDrawdownSettings:
    """Tests for the per-portfolio alert-threshold getters and setters.

    These tests write thresholds for hard-coded portfolio ids 1 and 2, which
    are exactly the ids real portfolios tend to get in a fresh test DB. The
    module-level ``_drawdown_settings`` store is therefore cleared after each
    test as well as before it, so those values cannot leak into later tests.
    """

    def setup_method(self):
        _drawdown_settings.clear()

    def teardown_method(self):
        _drawdown_settings.clear()

    def test_default_threshold(self):
        # A portfolio with no explicit setting falls back to the default.
        assert get_alert_threshold(999) == DEFAULT_ALERT_THRESHOLD

    def test_set_and_get_threshold(self):
        set_alert_threshold(1, Decimal("15"))
        assert get_alert_threshold(1) == Decimal("15")

    def test_different_portfolios_independent(self):
        set_alert_threshold(1, Decimal("10"))
        set_alert_threshold(2, Decimal("25"))
        assert get_alert_threshold(1) == Decimal("10")
        assert get_alert_threshold(2) == Decimal("25")
# --- API endpoint tests ---
class TestDrawdownAPI:
    """HTTP-level tests for the ``/api/drawdown`` endpoints."""

    def setup_method(self):
        _drawdown_settings.clear()

    def teardown_method(self):
        # test_update_settings changes an alert threshold through the API.
        # Presumably that lands in the service's in-memory _drawdown_settings
        # (TODO confirm against the router); if it does not, clearing is a
        # harmless no-op.
        _drawdown_settings.clear()

    def _create_portfolio_via_api(self, client, auth_headers):
        """Create a portfolio through the API and return its id."""
        resp = client.post(
            "/api/portfolios",
            headers=auth_headers,
            json={"name": "테스트", "portfolio_type": "general"},
        )
        # Fail here with the server's response body instead of an opaque
        # KeyError/JSON error on the next line.
        assert 200 <= resp.status_code < 300, resp.text
        return resp.json()["id"]

    def _add_snapshots(self, db, portfolio_id, values_and_dates):
        """Insert ``(snapshot_date, total_value)`` snapshots directly via the DB."""
        for snap_date, total_value in values_and_dates:
            snapshot = PortfolioSnapshot(
                portfolio_id=portfolio_id,
                total_value=Decimal(str(total_value)),
                snapshot_date=snap_date,
            )
            db.add(snapshot)
        db.commit()

    def test_get_drawdown(self, client, auth_headers, db):
        pid = self._create_portfolio_via_api(client, auth_headers)
        self._add_snapshots(db, pid, [
            (date(2025, 1, 1), 1_000_000),
            (date(2025, 2, 1), 1_200_000),
            (date(2025, 3, 1), 1_080_000),
        ])
        resp = client.get(f"/api/drawdown/{pid}", headers=auth_headers)
        assert resp.status_code == 200
        data = resp.json()
        assert data["portfolio_id"] == pid
        # Percentages are compared as JSON numbers, not Decimal strings.
        assert data["current_drawdown_pct"] == 10.0
        assert data["max_drawdown_pct"] == 10.0

    def test_get_drawdown_history(self, client, auth_headers, db):
        pid = self._create_portfolio_via_api(client, auth_headers)
        self._add_snapshots(db, pid, [
            (date(2025, 1, 1), 1_000_000),
            (date(2025, 2, 1), 1_200_000),
            (date(2025, 3, 1), 1_080_000),
        ])
        resp = client.get(f"/api/drawdown/{pid}/history", headers=auth_headers)
        assert resp.status_code == 200
        data = resp.json()
        # One history point per snapshot.
        assert len(data["data"]) == 3
        assert data["max_drawdown_pct"] == 10.0

    def test_update_settings(self, client, auth_headers, db):
        pid = self._create_portfolio_via_api(client, auth_headers)
        resp = client.put(
            f"/api/drawdown/settings/{pid}",
            headers=auth_headers,
            json={"alert_threshold_pct": 15.0},
        )
        assert resp.status_code == 200
        assert resp.json()["alert_threshold_pct"] == 15.0

    def test_drawdown_nonexistent_portfolio(self, client, auth_headers):
        resp = client.get("/api/drawdown/9999", headers=auth_headers)
        assert resp.status_code == 404

    def test_unauthenticated_access(self, client):
        # No auth headers at all.
        resp = client.get("/api/drawdown/1")
        assert resp.status_code == 401