"""
Drawdown calculation service using PortfolioSnapshot.total_value time series.
"""

import logging
from datetime import date
from decimal import Decimal
from typing import Optional

from sqlalchemy.orm import Session

from app.models.portfolio import Portfolio, PortfolioSnapshot

logger = logging.getLogger(__name__)

# In-memory per-portfolio alert thresholds (no separate table needed).
# NOTE: this is a plain module-level dict, so values exist only in the
# current process and are lost on restart.
_drawdown_settings: dict[int, Decimal] = {}

DEFAULT_ALERT_THRESHOLD = Decimal("20")

# Drawdown percentages are rounded to two decimal places.
_PCT_QUANT = Decimal("0.01")


def get_alert_threshold(portfolio_id: int) -> Decimal:
    """Return the drawdown alert threshold (percent) for *portfolio_id*.

    Falls back to ``DEFAULT_ALERT_THRESHOLD`` when none has been set.
    """
    return _drawdown_settings.get(portfolio_id, DEFAULT_ALERT_THRESHOLD)


def set_alert_threshold(portfolio_id: int, threshold_pct: Decimal) -> None:
    """Store the drawdown alert threshold (percent) for *portfolio_id*."""
    _drawdown_settings[portfolio_id] = threshold_pct


def _load_snapshots(db: Session, portfolio_id: int) -> list[PortfolioSnapshot]:
    """Fetch all snapshots of a portfolio, ordered by snapshot_date ascending."""
    return (
        db.query(PortfolioSnapshot)
        .filter(PortfolioSnapshot.portfolio_id == portfolio_id)
        .order_by(PortfolioSnapshot.snapshot_date)
        .all()
    )


def _to_decimal(value) -> Decimal:
    """Convert a stored total_value to Decimal (via str to avoid float artifacts)."""
    return Decimal(str(value))


def _drawdown_pct(peak: Decimal, value: Decimal) -> Decimal:
    """Return the percentage decline of *value* below *peak*, rounded to 0.01.

    Returns 0 when *peak* is not positive, since a relative decline from a
    non-positive peak is undefined.
    """
    if peak <= 0:
        return Decimal("0")
    return ((peak - value) / peak * 100).quantize(_PCT_QUANT)


def calculate_drawdown(
    db: Session,
    portfolio_id: int,
) -> dict:
    """Calculate current and max drawdown from the snapshot time series.

    Walks the snapshots in date order, tracking the running peak of
    ``total_value`` and the largest percentage decline from that peak.

    Returns a dict with:
        current_drawdown_pct: decline of the latest snapshot from the
            running peak (%).
        max_drawdown_pct: largest decline from a running peak seen in the
            series (%).
        peak_value / peak_date: the running peak at the end of the series
            and the first date it was reached.
        trough_value / trough_date: value and date of the snapshot where the
            max drawdown occurred (the first snapshot if there was none).
        max_drawdown_date: date of the max drawdown, or None if there was none.
        current_value: total_value of the latest snapshot.
        alert_threshold_pct: the configured alert threshold (%).

    All value/date fields are None when the portfolio has no snapshots.
    """
    snapshots = _load_snapshots(db, portfolio_id)

    if not snapshots:
        return {
            "current_drawdown_pct": Decimal("0"),
            "max_drawdown_pct": Decimal("0"),
            "peak_value": None,
            "peak_date": None,
            "trough_value": None,
            "trough_date": None,
            "max_drawdown_date": None,
            "current_value": None,
            "alert_threshold_pct": get_alert_threshold(portfolio_id),
        }

    peak = _to_decimal(snapshots[0].total_value)
    peak_date = snapshots[0].snapshot_date
    max_dd = Decimal("0")
    max_dd_date: Optional[date] = None
    trough_value = peak
    trough_date = peak_date

    for snap in snapshots:
        value = _to_decimal(snap.total_value)
        # Strict ">" keeps peak_date at the first time the peak was reached.
        if value > peak:
            peak = value
            peak_date = snap.snapshot_date

        dd = _drawdown_pct(peak, value)
        if dd > max_dd:
            max_dd = dd
            max_dd_date = snap.snapshot_date
            trough_value = value
            trough_date = snap.snapshot_date

    # Current drawdown = drawdown of the last snapshot from the running peak.
    current_value = _to_decimal(snapshots[-1].total_value)
    current_dd = _drawdown_pct(peak, current_value)

    return {
        "current_drawdown_pct": current_dd,
        "max_drawdown_pct": max_dd,
        "peak_value": peak,
        "peak_date": peak_date,
        "trough_value": trough_value,
        "trough_date": trough_date,
        "max_drawdown_date": max_dd_date,
        "current_value": current_value,
        "alert_threshold_pct": get_alert_threshold(portfolio_id),
    }


def calculate_rolling_drawdown(
    db: Session,
    portfolio_id: int,
) -> list[dict]:
    """Calculate the rolling drawdown time series.

    Returns one dict per snapshot, in date order, with keys
    ``date``, ``total_value``, ``peak`` (running peak up to and including
    that snapshot) and ``drawdown_pct``. Returns [] when there are no
    snapshots.
    """
    result: list[dict] = []
    # Starting at 0 means a series of non-positive values reports 0% drawdown.
    peak = Decimal("0")

    for snap in _load_snapshots(db, portfolio_id):
        value = _to_decimal(snap.total_value)
        if value > peak:
            peak = value
        result.append({
            "date": snap.snapshot_date,
            "total_value": value,
            "peak": peak,
            "drawdown_pct": _drawdown_pct(peak, value),
        })

    return result


def check_drawdown_alert(
    db: Session,
    portfolio_id: int,
) -> Optional[str]:
    """Check whether the current drawdown has reached the alert threshold.

    Returns an alert message when the current drawdown is greater than or
    equal to the portfolio's threshold, otherwise None. Also returns None
    when the portfolio has no snapshots (there is nothing to compare).
    """
    data = calculate_drawdown(db, portfolio_id)

    # No snapshot data: avoid a spurious alert (e.g. threshold 0) and a
    # TypeError from formatting None values below.
    if data["peak_value"] is None:
        return None

    threshold = data["alert_threshold_pct"]
    current_dd = data["current_drawdown_pct"]
    if current_dd < threshold:
        return None

    portfolio = db.query(Portfolio).filter(Portfolio.id == portfolio_id).first()
    name = portfolio.name if portfolio else f"Portfolio #{portfolio_id}"
    # Korean message: "[Drawdown warning] {name}: current drawdown X% exceeded
    # the limit Y%. (peak: ... won, current: ... won)". The "current" figure is
    # the latest snapshot value, not the historical max-drawdown trough.
    return (
        f"[Drawdown 경고] {name}: "
        f"현재 낙폭 {current_dd}%가 한도 {threshold}%를 초과했습니다. "
        f"(고점: {data['peak_value']:,.0f}원, 현재: {data['current_value']:,.0f}원)"
    )