"""
Main backtest engine.
"""
import bisect
from datetime import date, timedelta
from decimal import Decimal
from typing import List, Dict, Optional, Set, Tuple

from dateutil.relativedelta import relativedelta
from sqlalchemy.orm import Session

from app.models.backtest import (
    Backtest,
    BacktestResult,
    BacktestEquityCurve,
    BacktestHolding,
    BacktestTransaction,
    RebalancePeriod,
)
from app.models.stock import Stock, Price
from app.services.backtest.portfolio import VirtualPortfolio, Transaction
from app.services.backtest.metrics import MetricsCalculator
from app.services.strategy import MultiFactorStrategy, QualityStrategy, ValueMomentumStrategy
from app.schemas.strategy import UniverseFilter, FactorWeights


class BacktestEngine:
    """
    Main backtest engine that simulates strategy performance over time.
    """

    # Ticker used as a proxy for each benchmark index. KODEX 200 (069500)
    # is used as the KOSPI proxy; unknown benchmarks fall back to it as well.
    _DEFAULT_BENCHMARK_TICKER = "069500"
    _BENCHMARK_TICKERS: Dict[str, str] = {"KOSPI": "069500"}

    def __init__(self, db: Session):
        self.db = db

    def run(self, backtest_id: int) -> None:
        """Execute a backtest and persist its results.

        Args:
            backtest_id: Primary key of the ``Backtest`` row to simulate.

        Raises:
            ValueError: If the backtest does not exist or no trading days
                exist between its start and end dates.
        """
        backtest = self.db.query(Backtest).get(backtest_id)
        if not backtest:
            raise ValueError(f"Backtest {backtest_id} not found")

        portfolio = VirtualPortfolio(backtest.initial_capital)

        trading_days = self._get_trading_days(
            backtest.start_date,
            backtest.end_date,
        )
        if not trading_days:
            raise ValueError("No trading days found in the specified period")

        # Scheduled dates are calendar dates; snap each one onto the first
        # trading day on/after it so weekend/holiday dates are not skipped.
        scheduled_dates = self._generate_rebalance_dates(
            backtest.start_date,
            backtest.end_date,
            backtest.rebalance_period,
        )
        rebalance_days = self._align_rebalance_dates(scheduled_dates, trading_days)

        benchmark_prices = self._load_benchmark_prices(
            backtest.benchmark,
            backtest.start_date,
            backtest.end_date,
        )

        strategy = self._create_strategy(
            backtest.strategy_type,
            backtest.strategy_params,
            backtest.top_n,
        )

        # Stock names are loaded once per run rather than once per trading day.
        names = self._get_stock_names()

        equity_curve_data: List[Dict] = []
        all_transactions: List[Tuple[date, Transaction]] = []
        holdings_history: List[Dict] = []

        # Normalize against the first usable (non-null, non-zero) benchmark
        # price; fall back to 1 when there is no benchmark data at all.
        initial_benchmark = next(
            (benchmark_prices[d] for d in trading_days if benchmark_prices.get(d)),
            Decimal("1"),
        )
        last_benchmark = initial_benchmark

        for trading_date in trading_days:
            prices = self._get_prices_for_date(trading_date)

            if trading_date in rebalance_days:
                target_stocks = strategy.run(
                    universe_filter=UniverseFilter(),
                    top_n=backtest.top_n,
                    base_date=trading_date,
                )
                target_tickers = [s.ticker for s in target_stocks.stocks]

                transactions = portfolio.rebalance(
                    target_tickers=target_tickers,
                    prices=prices,
                    names=names,
                    commission_rate=backtest.commission_rate,
                    slippage_rate=backtest.slippage_rate,
                )
                all_transactions.extend((trading_date, txn) for txn in transactions)

                holdings_history.append({
                    'date': trading_date,
                    'holdings': portfolio.get_holdings_with_weights(prices, names),
                })

            portfolio_value = portfolio.get_value(prices)

            # Carry the last known benchmark price forward over gaps instead
            # of snapping back to the initial level.
            todays_benchmark = benchmark_prices.get(trading_date)
            if todays_benchmark:
                last_benchmark = todays_benchmark
            normalized_benchmark = (
                last_benchmark / initial_benchmark * backtest.initial_capital
            )

            equity_curve_data.append({
                'date': trading_date,
                'portfolio_value': portfolio_value,
                'benchmark_value': normalized_benchmark,
            })

        portfolio_values = [Decimal(str(e['portfolio_value'])) for e in equity_curve_data]
        benchmark_values = [Decimal(str(e['benchmark_value'])) for e in equity_curve_data]

        metrics = MetricsCalculator.calculate_all(portfolio_values, benchmark_values)
        drawdowns = MetricsCalculator.calculate_drawdown_series(portfolio_values)

        self._save_results(
            backtest_id=backtest_id,
            metrics=metrics,
            equity_curve_data=equity_curve_data,
            drawdowns=drawdowns,
            holdings_history=holdings_history,
            transactions=all_transactions,
        )

    def _generate_rebalance_dates(
        self,
        start_date: date,
        end_date: date,
        period: RebalancePeriod,
    ) -> List[date]:
        """Generate calendar rebalance dates from ``start_date`` to ``end_date``.

        Each date is computed as an offset from ``start_date`` rather than by
        repeatedly adding one period. This avoids month-end drift
        (Jan 31 -> Feb 28 -> Mar 28 ...). Any period other than monthly,
        quarterly or semi-annual is treated as annual.
        """
        if period == RebalancePeriod.MONTHLY:
            step_months = 1
        elif period == RebalancePeriod.QUARTERLY:
            step_months = 3
        elif period == RebalancePeriod.SEMI_ANNUAL:
            step_months = 6
        else:  # ANNUAL
            step_months = 12

        dates: List[date] = []
        step = 0
        current = start_date
        while current <= end_date:
            dates.append(current)
            step += 1
            current = start_date + relativedelta(months=step_months * step)
        return dates

    @staticmethod
    def _align_rebalance_dates(
        scheduled_dates: List[date],
        trading_days: List[date],
    ) -> Set[date]:
        """Map each scheduled date to the first trading day on or after it.

        ``trading_days`` must be sorted ascending. Scheduled dates that fall
        after the last trading day are dropped.
        """
        aligned: Set[date] = set()
        for scheduled in scheduled_dates:
            idx = bisect.bisect_left(trading_days, scheduled)
            if idx < len(trading_days):
                aligned.add(trading_days[idx])
        return aligned

    def _get_trading_days(self, start_date: date, end_date: date) -> List[date]:
        """Return the distinct dates with any price data in range, ascending."""
        rows = (
            self.db.query(Price.date)
            .filter(Price.date >= start_date)
            .filter(Price.date <= end_date)
            .distinct()
            .order_by(Price.date)
            .all()
        )
        return [row[0] for row in rows]

    def _load_benchmark_prices(
        self,
        benchmark: str,
        start_date: date,
        end_date: date,
    ) -> Dict[date, Decimal]:
        """Load benchmark proxy closing prices keyed by date."""
        benchmark_ticker = self._BENCHMARK_TICKERS.get(
            benchmark, self._DEFAULT_BENCHMARK_TICKER
        )
        prices = (
            self.db.query(Price)
            .filter(Price.ticker == benchmark_ticker)
            .filter(Price.date >= start_date)
            .filter(Price.date <= end_date)
            .all()
        )
        return {p.date: p.close for p in prices}

    def _get_prices_for_date(self, trading_date: date) -> Dict[str, Decimal]:
        """Return closing prices keyed by ticker for one date."""
        # NOTE(review): tickers with no price row on this date are absent from
        # the dict; how VirtualPortfolio values such holdings is not visible here.
        prices = (
            self.db.query(Price)
            .filter(Price.date == trading_date)
            .all()
        )
        return {p.ticker: p.close for p in prices}

    def _get_stock_names(self) -> Dict[str, str]:
        """Return a ticker -> name mapping for every stock."""
        stocks = self.db.query(Stock).all()
        return {s.ticker: s.name for s in stocks}

    def _create_strategy(
        self,
        strategy_type: str,
        strategy_params: dict,
        top_n: int,
    ):
        """Create and configure a strategy instance.

        Args:
            strategy_type: One of ``"multi_factor"``, ``"quality"`` or
                ``"value_momentum"``.
            strategy_params: Strategy-specific settings; ``None`` is treated
                as empty.
            top_n: Currently unused here; ``top_n`` is passed to
                ``strategy.run`` by the caller.

        Raises:
            ValueError: If ``strategy_type`` is not recognized.
        """
        params = strategy_params or {}

        if strategy_type == "multi_factor":
            strategy = MultiFactorStrategy(self.db)
            strategy._weights = FactorWeights(**(params.get("weights") or {}))
        elif strategy_type == "quality":
            strategy = QualityStrategy(self.db)
            strategy._min_fscore = params.get("min_fscore", 7)
        elif strategy_type == "value_momentum":
            strategy = ValueMomentumStrategy(self.db)
            # Route through str() so float params become exact Decimals.
            strategy._value_weight = Decimal(str(params.get("value_weight", 0.5)))
            strategy._momentum_weight = Decimal(str(params.get("momentum_weight", 0.5)))
        else:
            raise ValueError(f"Unknown strategy type: {strategy_type}")

        return strategy

    def _save_results(
        self,
        backtest_id: int,
        metrics,
        equity_curve_data: List[Dict],
        drawdowns: List[Decimal],
        holdings_history: List[Dict],
        transactions: List,
    ) -> None:
        """Persist metrics, equity curve, holdings and transactions, then commit.

        ``transactions`` is a list of ``(date, Transaction)`` tuples. Equity
        points without a matching drawdown entry get a drawdown of 0.
        """
        self.db.add(BacktestResult(
            backtest_id=backtest_id,
            total_return=metrics.total_return,
            cagr=metrics.cagr,
            mdd=metrics.mdd,
            sharpe_ratio=metrics.sharpe_ratio,
            volatility=metrics.volatility,
            benchmark_return=metrics.benchmark_return,
            excess_return=metrics.excess_return,
        ))

        self.db.add_all(
            BacktestEquityCurve(
                backtest_id=backtest_id,
                date=point['date'],
                portfolio_value=point['portfolio_value'],
                benchmark_value=point['benchmark_value'],
                drawdown=drawdowns[i] if i < len(drawdowns) else Decimal("0"),
            )
            for i, point in enumerate(equity_curve_data)
        )

        self.db.add_all(
            BacktestHolding(
                backtest_id=backtest_id,
                rebalance_date=record['date'],
                ticker=holding.ticker,
                name=holding.name,
                weight=holding.weight,
                shares=holding.shares,
                price=holding.price,
            )
            for record in holdings_history
            for holding in record['holdings']
        )

        self.db.add_all(
            BacktestTransaction(
                backtest_id=backtest_id,
                date=trading_date,
                ticker=txn.ticker,
                action=txn.action,
                shares=txn.shares,
                price=txn.price,
                commission=txn.commission,
            )
            for trading_date, txn in transactions
        )

        self.db.commit()