""" Main backtest engine. """ import logging from dataclasses import dataclass, field from datetime import date, timedelta from decimal import Decimal from typing import List, Dict, Optional from dateutil.relativedelta import relativedelta from sqlalchemy import func from sqlalchemy.orm import Session from app.models.backtest import ( Backtest, BacktestResult, BacktestEquityCurve, BacktestHolding, BacktestTransaction, RebalancePeriod, ) from app.models.stock import Stock, Price from app.services.backtest.portfolio import VirtualPortfolio, Transaction from app.services.backtest.metrics import MetricsCalculator from app.services.strategy import MultiFactorStrategy, QualityStrategy, ValueMomentumStrategy from app.schemas.strategy import UniverseFilter, FactorWeights logger = logging.getLogger(__name__) @dataclass class DataValidationResult: """Result of pre-backtest data validation.""" is_valid: bool = True errors: List[str] = field(default_factory=list) warnings: List[str] = field(default_factory=list) class BacktestEngine: """ Main backtest engine that simulates strategy performance over time. """ def __init__(self, db: Session): self.db = db def run(self, backtest_id: int) -> None: """Execute backtest and save results.""" backtest = self.db.query(Backtest).get(backtest_id) if not backtest: raise ValueError(f"Backtest {backtest_id} not found") # Initialize portfolio portfolio = VirtualPortfolio(backtest.initial_capital) # Generate rebalance dates rebalance_dates = self._generate_rebalance_dates( backtest.start_date, backtest.end_date, backtest.rebalance_period, ) # Get trading days trading_days = self._get_trading_days( backtest.start_date, backtest.end_date, ) if not trading_days: raise ValueError("No trading days found in the specified period") # Load benchmark data benchmark_prices = self._load_benchmark_prices( backtest.benchmark, backtest.start_date, backtest.end_date, ) # Pre-backtest data validation validation = self._validate_data( trading_days=trading_days, benchmark_prices=benchmark_prices, benchmark=backtest.benchmark, start_date=backtest.start_date, end_date=backtest.end_date, ) for warning in validation.warnings: logger.warning(f"Backtest {backtest_id}: {warning}") if not validation.is_valid: raise ValueError( "데이터 검증 실패:\n" + "\n".join(validation.errors) ) # Create strategy instance strategy = self._create_strategy( backtest.strategy_type, backtest.strategy_params, backtest.top_n, ) # Simulation equity_curve_data: List[Dict] = [] all_transactions: List[Transaction] = [] holdings_history: List[Dict] = [] initial_benchmark = benchmark_prices.get(trading_days[0], Decimal("1")) if initial_benchmark == 0: initial_benchmark = Decimal("1") for trading_date in trading_days: # Get prices for this date prices = self._get_prices_for_date(trading_date) names = self._get_stock_names() # Warn about holdings with missing prices missing = [ t for t in portfolio.holdings if portfolio.holdings[t] > 0 and t not in prices ] if missing: logger.warning( f"{trading_date}: 보유 종목 가격 누락 {missing} " f"(0원으로 처리됨)" ) # Rebalance if needed if trading_date in rebalance_dates: # Run strategy to get target stocks target_stocks = strategy.run( universe_filter=UniverseFilter(), top_n=backtest.top_n, base_date=trading_date, ) target_tickers = [s.ticker for s in target_stocks.stocks] # Execute rebalance transactions = portfolio.rebalance( target_tickers=target_tickers, prices=prices, names=names, commission_rate=backtest.commission_rate, slippage_rate=backtest.slippage_rate, ) all_transactions.extend([ (trading_date, txn) for txn in transactions ]) # Record holdings holdings = portfolio.get_holdings_with_weights(prices, names) holdings_history.append({ 'date': trading_date, 'holdings': holdings, }) # Record daily value portfolio_value = portfolio.get_value(prices) benchmark_value = benchmark_prices.get(trading_date, initial_benchmark) # Normalize benchmark to initial capital normalized_benchmark = ( benchmark_value / initial_benchmark * backtest.initial_capital ) equity_curve_data.append({ 'date': trading_date, 'portfolio_value': portfolio_value, 'benchmark_value': normalized_benchmark, }) # Calculate metrics portfolio_values = [Decimal(str(e['portfolio_value'])) for e in equity_curve_data] benchmark_values = [Decimal(str(e['benchmark_value'])) for e in equity_curve_data] metrics = MetricsCalculator.calculate_all(portfolio_values, benchmark_values) drawdowns = MetricsCalculator.calculate_drawdown_series(portfolio_values) # Save results self._save_results( backtest_id=backtest_id, metrics=metrics, equity_curve_data=equity_curve_data, drawdowns=drawdowns, holdings_history=holdings_history, transactions=all_transactions, ) def _validate_data( self, trading_days: List[date], benchmark_prices: Dict[date, Decimal], benchmark: str, start_date: date, end_date: date, ) -> DataValidationResult: """ Validate price data completeness before running backtest. Checks: 1. Minimum trading days requirement 2. Benchmark data coverage 3. Overall price data density (tickers per trading day) 4. Large date gaps in trading days """ result = DataValidationResult() total_days = trading_days num_trading_days = len(total_days) calendar_days = (end_date - start_date).days # 1. Minimum trading days check if calendar_days > 30: # Expect at least 60% of calendar days to be trading days # (weekends ~28%, holidays ~3% => ~69% expected) expected_min = int(calendar_days * 0.5) if num_trading_days < expected_min: result.errors.append( f"거래일 수 부족: {num_trading_days}일 " f"(기간 {calendar_days}일 중 최소 {expected_min}일 필요)" ) result.is_valid = False # 2. Benchmark data coverage benchmark_ticker = "069500" if benchmark == "KOSPI" else "069500" benchmark_coverage = sum( 1 for d in total_days if d in benchmark_prices ) benchmark_pct = ( benchmark_coverage / num_trading_days * 100 if num_trading_days > 0 else 0 ) if benchmark_coverage == 0: result.errors.append( f"벤치마크({benchmark_ticker}) 가격 데이터 없음" ) result.is_valid = False elif benchmark_pct < 90: result.warnings.append( f"벤치마크({benchmark_ticker}) 데이터 커버리지 낮음: " f"{benchmark_coverage}/{num_trading_days}일 ({benchmark_pct:.1f}%)" ) # 3. Price data density per trading day (sample check) # Check first, middle, last trading days sample_dates = [ total_days[0], total_days[num_trading_days // 2], total_days[-1], ] for sample_date in sample_dates: ticker_count = ( self.db.query(func.count(Price.ticker)) .filter(Price.date == sample_date) .scalar() ) if ticker_count == 0: result.errors.append( f"{sample_date} 가격 데이터 없음 (종목 0개)" ) result.is_valid = False elif ticker_count < 100: result.warnings.append( f"{sample_date} 종목 수 적음: {ticker_count}개" ) # 4. Large gaps in trading days (> 7 calendar days excluding normal weekends) for i in range(1, num_trading_days): gap = (total_days[i] - total_days[i - 1]).days if gap > 7: result.warnings.append( f"거래일 갭 발견: {total_days[i-1]} ~ {total_days[i]} " f"({gap}일)" ) if result.is_valid and not result.warnings: logger.info( f"데이터 검증 통과: 거래일 {num_trading_days}일, " f"벤치마크 커버리지 {benchmark_pct:.1f}%" ) return result def _generate_rebalance_dates( self, start_date: date, end_date: date, period: RebalancePeriod, ) -> List[date]: """Generate list of rebalance dates.""" dates = [] current = start_date if period == RebalancePeriod.MONTHLY: delta = relativedelta(months=1) elif period == RebalancePeriod.QUARTERLY: delta = relativedelta(months=3) elif period == RebalancePeriod.SEMI_ANNUAL: delta = relativedelta(months=6) else: # ANNUAL delta = relativedelta(years=1) while current <= end_date: dates.append(current) current = current + delta return dates def _get_trading_days(self, start_date: date, end_date: date) -> List[date]: """Get list of trading days from price data.""" prices = ( self.db.query(Price.date) .filter(Price.date >= start_date) .filter(Price.date <= end_date) .distinct() .order_by(Price.date) .all() ) return [p[0] for p in prices] def _load_benchmark_prices( self, benchmark: str, start_date: date, end_date: date, ) -> Dict[date, Decimal]: """Load benchmark index prices.""" # For KOSPI, we'll use a representative ETF or index # Using KODEX 200 (069500) as KOSPI proxy benchmark_ticker = "069500" if benchmark == "KOSPI" else "069500" prices = ( self.db.query(Price) .filter(Price.ticker == benchmark_ticker) .filter(Price.date >= start_date) .filter(Price.date <= end_date) .all() ) return {p.date: p.close for p in prices} def _get_prices_for_date(self, trading_date: date) -> Dict[str, Decimal]: """Get all stock prices for a specific date.""" prices = ( self.db.query(Price) .filter(Price.date == trading_date) .all() ) return {p.ticker: p.close for p in prices} def _get_stock_names(self) -> Dict[str, str]: """Get all stock names.""" stocks = self.db.query(Stock).all() return {s.ticker: s.name for s in stocks} def _create_strategy( self, strategy_type: str, strategy_params: dict, top_n: int, ): """Create strategy instance based on type.""" if strategy_type == "multi_factor": strategy = MultiFactorStrategy(self.db) strategy._weights = FactorWeights(**strategy_params.get("weights", {})) elif strategy_type == "quality": strategy = QualityStrategy(self.db) strategy._min_fscore = strategy_params.get("min_fscore", 7) elif strategy_type == "value_momentum": strategy = ValueMomentumStrategy(self.db) strategy._value_weight = Decimal(str(strategy_params.get("value_weight", 0.5))) strategy._momentum_weight = Decimal(str(strategy_params.get("momentum_weight", 0.5))) else: raise ValueError(f"Unknown strategy type: {strategy_type}") return strategy def _save_results( self, backtest_id: int, metrics, equity_curve_data: List[Dict], drawdowns: List[Decimal], holdings_history: List[Dict], transactions: List, ) -> None: """Save all backtest results to database.""" # Save metrics result = BacktestResult( backtest_id=backtest_id, total_return=metrics.total_return, cagr=metrics.cagr, mdd=metrics.mdd, sharpe_ratio=metrics.sharpe_ratio, volatility=metrics.volatility, benchmark_return=metrics.benchmark_return, excess_return=metrics.excess_return, ) self.db.add(result) # Save equity curve for i, point in enumerate(equity_curve_data): curve_point = BacktestEquityCurve( backtest_id=backtest_id, date=point['date'], portfolio_value=point['portfolio_value'], benchmark_value=point['benchmark_value'], drawdown=drawdowns[i] if i < len(drawdowns) else Decimal("0"), ) self.db.add(curve_point) # Save holdings history for record in holdings_history: for holding in record['holdings']: h = BacktestHolding( backtest_id=backtest_id, rebalance_date=record['date'], ticker=holding.ticker, name=holding.name, weight=holding.weight, shares=holding.shares, price=holding.price, ) self.db.add(h) # Save transactions for trading_date, txn in transactions: t = BacktestTransaction( backtest_id=backtest_id, date=trading_date, ticker=txn.ticker, action=txn.action, shares=txn.shares, price=txn.price, commission=txn.commission, ) self.db.add(t) self.db.commit()