"""Tests for drawdown service and API endpoints."""
from datetime import date
from decimal import Decimal

import pytest

from app.models.portfolio import Portfolio, PortfolioSnapshot
from app.services.drawdown import (
    calculate_drawdown,
    calculate_rolling_drawdown,
    check_drawdown_alert,
    get_alert_threshold,
    set_alert_threshold,
    DEFAULT_ALERT_THRESHOLD,
    _drawdown_settings,
)


# --- Helper ---

def _create_portfolio_with_snapshots(db, user_id, values_and_dates):
    """Create a portfolio with a snapshot time series.

    Args:
        db: Database session used to persist the rows.
        user_id: Owner of the new portfolio.
        values_and_dates: Iterable of ``(snapshot_date, total_value)`` pairs;
            one ``PortfolioSnapshot`` is created per pair.

    Returns:
        The committed and refreshed ``Portfolio``.
    """
    portfolio = Portfolio(
        user_id=user_id,
        name="테스트 포트폴리오",
        portfolio_type="general",
    )
    db.add(portfolio)
    # Flush so portfolio.id is populated before the snapshots reference it.
    db.flush()
    for snap_date, total_value in values_and_dates:
        snapshot = PortfolioSnapshot(
            portfolio_id=portfolio.id,
            # Convert via str() so int/float inputs become exact Decimals
            # instead of carrying binary-float artifacts.
            total_value=Decimal(str(total_value)),
            snapshot_date=snap_date,
        )
        db.add(snapshot)
    db.commit()
    db.refresh(portfolio)
    return portfolio


# --- calculate_drawdown tests ---

class TestCalculateDrawdown:
    """Tests for calculate_drawdown() over a portfolio's snapshot series."""

    def test_no_snapshots(self, db, test_user):
        portfolio = Portfolio(user_id=test_user.id, name="빈 포트폴리오")
        db.add(portfolio)
        db.commit()

        result = calculate_drawdown(db, portfolio.id)
        assert result["current_drawdown_pct"] == Decimal("0")
        assert result["max_drawdown_pct"] == Decimal("0")
        assert result["peak_value"] is None

    def test_no_drawdown_monotonic_increase(self, db, test_user):
        portfolio = _create_portfolio_with_snapshots(db, test_user.id, [
            (date(2025, 1, 1), 1_000_000),
            (date(2025, 2, 1), 1_100_000),
            (date(2025, 3, 1), 1_200_000),
        ])
        result = calculate_drawdown(db, portfolio.id)
        assert result["current_drawdown_pct"] == Decimal("0")
        assert result["max_drawdown_pct"] == Decimal("0")
        assert result["peak_value"] == Decimal("1200000")

    def test_simple_drawdown(self, db, test_user):
        portfolio = _create_portfolio_with_snapshots(db, test_user.id, [
            (date(2025, 1, 1), 1_000_000),
            (date(2025, 2, 1), 1_200_000),  # peak
            (date(2025, 3, 1), 1_080_000),  # -10%
        ])
        result = calculate_drawdown(db, portfolio.id)
        assert result["current_drawdown_pct"] == Decimal("10.00")
        assert result["max_drawdown_pct"] == Decimal("10.00")
        assert result["peak_value"] == Decimal("1200000")
        assert result["peak_date"] == date(2025, 2, 1)
        assert result["trough_value"] == Decimal("1080000")

    def test_recovery_after_drawdown(self, db, test_user):
        portfolio = _create_portfolio_with_snapshots(db, test_user.id, [
            (date(2025, 1, 1), 1_000_000),
            (date(2025, 2, 1), 1_200_000),  # peak
            (date(2025, 3, 1), 960_000),    # -20% (max dd)
            (date(2025, 4, 1), 1_300_000),  # new peak, recovery
        ])
        result = calculate_drawdown(db, portfolio.id)
        assert result["current_drawdown_pct"] == Decimal("0")
        assert result["max_drawdown_pct"] == Decimal("20.00")
        assert result["peak_value"] == Decimal("1300000")
        assert result["max_drawdown_date"] == date(2025, 3, 1)

    def test_multiple_drawdowns_picks_worst(self, db, test_user):
        portfolio = _create_portfolio_with_snapshots(db, test_user.id, [
            (date(2025, 1, 1), 1_000_000),
            (date(2025, 2, 1), 900_000),    # -10%
            (date(2025, 3, 1), 1_100_000),  # new peak
            (date(2025, 4, 1), 880_000),    # -20% from 1.1M
        ])
        result = calculate_drawdown(db, portfolio.id)
        assert result["max_drawdown_pct"] == Decimal("20.00")
        assert result["current_drawdown_pct"] == Decimal("20.00")


# --- calculate_rolling_drawdown tests ---

class TestCalculateRollingDrawdown:
    """Tests for the per-snapshot series from calculate_rolling_drawdown()."""

    def test_empty_snapshots(self, db, test_user):
        portfolio = Portfolio(user_id=test_user.id, name="빈 포트폴리오")
        db.add(portfolio)
        db.commit()

        result = calculate_rolling_drawdown(db, portfolio.id)
        assert result == []

    def test_rolling_drawdown_series(self, db, test_user):
        portfolio = _create_portfolio_with_snapshots(db, test_user.id, [
            (date(2025, 1, 1), 1_000_000),
            (date(2025, 2, 1), 1_200_000),
            (date(2025, 3, 1), 1_080_000),
        ])
        result = calculate_rolling_drawdown(db, portfolio.id)
        assert len(result) == 3

        # First point: no drawdown
        assert result[0]["drawdown_pct"] == Decimal("0")
        assert result[0]["peak"] == Decimal("1000000")

        # Second point: new peak, no drawdown
        assert result[1]["drawdown_pct"] == Decimal("0")
        assert result[1]["peak"] == Decimal("1200000")

        # Third point: drawdown from peak
        assert result[2]["drawdown_pct"] == Decimal("10.00")
        assert result[2]["peak"] == Decimal("1200000")


# --- check_drawdown_alert tests ---

class TestCheckDrawdownAlert:
    """Tests for check_drawdown_alert() with default and custom thresholds."""

    def setup_method(self):
        _drawdown_settings.clear()

    def teardown_method(self):
        # Thresholds set here live in a module-global dict; clear them so
        # they cannot leak into later tests that reuse portfolio ids.
        _drawdown_settings.clear()

    def test_no_alert_under_threshold(self, db, test_user):
        portfolio = _create_portfolio_with_snapshots(db, test_user.id, [
            (date(2025, 1, 1), 1_000_000),
            (date(2025, 2, 1), 950_000),  # -5%
        ])
        result = check_drawdown_alert(db, portfolio.id)
        assert result is None

    def test_alert_above_default_threshold(self, db, test_user):
        portfolio = _create_portfolio_with_snapshots(db, test_user.id, [
            (date(2025, 1, 1), 1_000_000),
            (date(2025, 2, 1), 780_000),  # -22%
        ])
        result = check_drawdown_alert(db, portfolio.id)
        assert result is not None
        assert "Drawdown 경고" in result
        assert "테스트 포트폴리오" in result

    def test_alert_with_custom_threshold(self, db, test_user):
        portfolio = _create_portfolio_with_snapshots(db, test_user.id, [
            (date(2025, 1, 1), 1_000_000),
            (date(2025, 2, 1), 900_000),  # -10%
        ])
        set_alert_threshold(portfolio.id, Decimal("5"))
        result = check_drawdown_alert(db, portfolio.id)
        assert result is not None
        assert "경고" in result

    def test_no_alert_empty_portfolio(self, db, test_user):
        portfolio = Portfolio(user_id=test_user.id, name="빈 포트폴리오")
        db.add(portfolio)
        db.commit()

        result = check_drawdown_alert(db, portfolio.id)
        assert result is None


# --- settings tests ---

class TestDrawdownSettings:
    """Tests for the per-portfolio alert threshold getters/setters."""

    def setup_method(self):
        _drawdown_settings.clear()

    def teardown_method(self):
        # Don't leave thresholds for ids 1/2 behind for other tests.
        _drawdown_settings.clear()

    def test_default_threshold(self):
        assert get_alert_threshold(999) == DEFAULT_ALERT_THRESHOLD

    def test_set_and_get_threshold(self):
        set_alert_threshold(1, Decimal("15"))
        assert get_alert_threshold(1) == Decimal("15")

    def test_different_portfolios_independent(self):
        set_alert_threshold(1, Decimal("10"))
        set_alert_threshold(2, Decimal("25"))
        assert get_alert_threshold(1) == Decimal("10")
        assert get_alert_threshold(2) == Decimal("25")
# --- API endpoint tests ---

class TestDrawdownAPI:
    """End-to-end tests for the /api/drawdown endpoints."""

    def setup_method(self):
        _drawdown_settings.clear()

    def teardown_method(self):
        # Reset the service's in-memory threshold store so a threshold written
        # by test_update_settings cannot leak into other tests.
        # NOTE(review): assumes the settings endpoint stores into
        # _drawdown_settings (e.g. via set_alert_threshold) — confirm.
        _drawdown_settings.clear()

    def _create_portfolio_via_api(self, client, auth_headers):
        """Create a portfolio through the API and return its id."""
        resp = client.post(
            "/api/portfolios",
            headers=auth_headers,
            json={"name": "테스트", "portfolio_type": "general"},
        )
        # Fail with the server's response body rather than an opaque
        # KeyError/TypeError on resp.json()["id"] if creation was rejected.
        assert 200 <= resp.status_code < 300, resp.text
        return resp.json()["id"]

    def _add_snapshots(self, db, portfolio_id, values_and_dates):
        """Persist one PortfolioSnapshot per (snapshot_date, total_value) pair."""
        for snap_date, total_value in values_and_dates:
            snapshot = PortfolioSnapshot(
                portfolio_id=portfolio_id,
                # str() first so numeric inputs become exact Decimals.
                total_value=Decimal(str(total_value)),
                snapshot_date=snap_date,
            )
            db.add(snapshot)
        db.commit()

    def test_get_drawdown(self, client, auth_headers, db):
        pid = self._create_portfolio_via_api(client, auth_headers)
        self._add_snapshots(db, pid, [
            (date(2025, 1, 1), 1_000_000),
            (date(2025, 2, 1), 1_200_000),
            (date(2025, 3, 1), 1_080_000),
        ])
        resp = client.get(f"/api/drawdown/{pid}", headers=auth_headers)
        assert resp.status_code == 200
        data = resp.json()
        assert data["portfolio_id"] == pid
        assert data["current_drawdown_pct"] == 10.0
        assert data["max_drawdown_pct"] == 10.0

    def test_get_drawdown_history(self, client, auth_headers, db):
        pid = self._create_portfolio_via_api(client, auth_headers)
        self._add_snapshots(db, pid, [
            (date(2025, 1, 1), 1_000_000),
            (date(2025, 2, 1), 1_200_000),
            (date(2025, 3, 1), 1_080_000),
        ])
        resp = client.get(f"/api/drawdown/{pid}/history", headers=auth_headers)
        assert resp.status_code == 200
        data = resp.json()
        assert len(data["data"]) == 3
        assert data["max_drawdown_pct"] == 10.0

    def test_update_settings(self, client, auth_headers, db):
        pid = self._create_portfolio_via_api(client, auth_headers)
        resp = client.put(
            f"/api/drawdown/settings/{pid}",
            headers=auth_headers,
            json={"alert_threshold_pct": 15.0},
        )
        assert resp.status_code == 200
        assert resp.json()["alert_threshold_pct"] == 15.0

    def test_drawdown_nonexistent_portfolio(self, client, auth_headers):
        resp = client.get("/api/drawdown/9999", headers=auth_headers)
        assert resp.status_code == 404

    def test_unauthenticated_access(self, client):
        resp = client.get("/api/drawdown/1")
        assert resp.status_code == 401