"""
Walk-forward analysis engine.

Splits the backtest period into rolling train/test windows and, for each
window, simulates the out-of-sample (test) period using the building blocks
of the existing BacktestEngine.

Train windows are only recorded alongside each result: they are not
simulated, and no parameter optimisation is performed yet.
"""
import logging
from dataclasses import dataclass
from datetime import date
from decimal import Decimal
from typing import List

from dateutil.relativedelta import relativedelta
from sqlalchemy.orm import Session

from app.models.backtest import (
    Backtest,
    BacktestResult,
    BacktestEquityCurve,
    WalkForwardResult,
)
from app.services.backtest.engine import BacktestEngine
from app.services.backtest.metrics import MetricsCalculator

logger = logging.getLogger(__name__)

# (return, sharpe, mdd) reported for windows that cannot be evaluated.
# Decimal is immutable, so sharing one tuple is safe.
_ZERO_METRICS = (Decimal("0"), Decimal("0"), Decimal("0"))


@dataclass
class Window:
    """One walk-forward window: an in-sample period followed by a test period."""

    index: int  # 0-based position of the window in generation order
    train_start: date
    train_end: date  # inclusive
    test_start: date  # day after train_end
    test_end: date  # inclusive; may be truncated to the backtest end date


class WalkForwardEngine:
    """
    Walk-forward analysis built on top of ``BacktestEngine``.

    For every rolling window the test period is simulated with the stored
    backtest's strategy settings. The resulting return, Sharpe ratio and
    MDD are persisted as ``WalkForwardResult`` rows.
    """

    def __init__(self, db: Session):
        self.db = db

    # ------------------------------------------------------------------
    # public
    # ------------------------------------------------------------------

    def run(
        self,
        backtest_id: int,
        train_months: int = 12,
        test_months: int = 3,
        step_months: int = 3,
    ) -> List[WalkForwardResult]:
        """
        Run walk-forward analysis for a stored backtest and persist results.

        Any previous walk-forward results for the same backtest are deleted
        first. The new rows are committed in a single commit at the end.

        Parameters
        ----------
        backtest_id : int
            Primary key of the ``Backtest`` whose period and settings are used.
        train_months : int
            Length of the in-sample (training) window, in months.
        test_months : int
            Length of the out-of-sample (test) window, in months.
        step_months : int
            How far the window slides each iteration, in months.

        Returns
        -------
        List[WalkForwardResult]
            One row per generated window, in window order.

        Raises
        ------
        ValueError
            Raised in three cases:

            - the backtest does not exist;
            - a month argument is not positive;
            - the backtest period is too short to form any window.
        """
        backtest = self.db.query(Backtest).get(backtest_id)
        if not backtest:
            raise ValueError(f"Backtest {backtest_id} not found")

        # Generated (and validated) before anything is deleted, so a bad
        # argument leaves existing results untouched.
        windows = self._generate_windows(
            start=backtest.start_date,
            end=backtest.end_date,
            train_months=train_months,
            test_months=test_months,
            step_months=step_months,
        )
        if not windows:
            raise ValueError(
                "기간이 너무 짧아 walk-forward 윈도우를 생성할 수 없습니다. "
                f"최소 {train_months + test_months}개월 필요"
            )

        # Replace previous walk-forward results for this backtest.
        self.db.query(WalkForwardResult).filter(
            WalkForwardResult.backtest_id == backtest_id
        ).delete()
        self.db.flush()

        engine = BacktestEngine(self.db)
        results: List[WalkForwardResult] = []

        for win in windows:
            logger.info(
                "Walk-forward window %s: test %s ~ %s",
                win.index,
                win.test_start,
                win.test_end,
            )
            test_return, test_sharpe, test_mdd = self._run_window(
                engine, backtest, win
            )
            wf = WalkForwardResult(
                backtest_id=backtest_id,
                window_index=win.index,
                train_start=win.train_start,
                train_end=win.train_end,
                test_start=win.test_start,
                test_end=win.test_end,
                test_return=test_return,
                test_sharpe=test_sharpe,
                test_mdd=test_mdd,
            )
            self.db.add(wf)
            results.append(wf)

        self.db.commit()
        return results

    # ------------------------------------------------------------------
    # window generation
    # ------------------------------------------------------------------

    @staticmethod
    def _generate_windows(
        start: date,
        end: date,
        train_months: int,
        test_months: int,
        step_months: int,
    ) -> List[Window]:
        """
        Build rolling train/test windows covering ``start`` .. ``end``.

        Window ``k`` trains for ``train_months`` months starting at
        ``start + k * step_months`` months. It then tests on the
        ``test_months`` months immediately after. A test window that would
        run past ``end`` is truncated to ``end``. Generation stops once a
        test window would start after ``end``.

        Raises
        ------
        ValueError
            If any month count is not positive. A non-positive step would
            never advance the window and would loop forever.
        """
        for name, value in (
            ("train_months", train_months),
            ("test_months", test_months),
            ("step_months", step_months),
        ):
            if value <= 0:
                raise ValueError(f"{name} must be a positive integer, got {value}")

        one_day = relativedelta(days=1)
        windows: List[Window] = []
        idx = 0

        while True:
            # Offset from the original start each time rather than adding
            # step_months to the previous window start. relativedelta clamps
            # to month end (2020-01-31 + 1 month -> 2020-02-29), and repeated
            # addition would carry that clamped day into every later window.
            train_start = start + relativedelta(months=step_months * idx)
            train_end = train_start + relativedelta(months=train_months) - one_day
            test_start = train_end + one_day
            test_end = test_start + relativedelta(months=test_months) - one_day

            if test_start > end:
                break
            # Allow a partial last window ending at ``end``.
            test_end = min(test_end, end)

            windows.append(
                Window(
                    index=idx,
                    train_start=train_start,
                    train_end=train_end,
                    test_start=test_start,
                    test_end=test_end,
                )
            )
            idx += 1

        return windows

    # ------------------------------------------------------------------
    # single window execution
    # ------------------------------------------------------------------

    def _run_window(
        self,
        engine: BacktestEngine,
        backtest: Backtest,
        win: Window,
    ) -> tuple:
        """
        Run the backtest on the test window and return (return, sharpe, mdd).

        A fresh ``VirtualPortfolio`` seeded with ``backtest.initial_capital``
        is rebalanced on the engine-generated rebalance dates. Its daily
        value is compared against the benchmark rescaled to the same
        starting capital.

        Returns
        -------
        tuple
            ``(total_return, sharpe_ratio, mdd)`` from ``MetricsCalculator``.
            All three are ``Decimal("0")`` in two cases:

            - the window has fewer than two trading days;
            - the simulation raises. The error is logged with its
              traceback rather than aborting the whole run.
        """
        # Kept as function-local imports, as in the original (presumably to
        # avoid an import cycle). They sit outside the try so a broken
        # import surfaces instead of silently zeroing every window.
        from app.services.backtest.portfolio import VirtualPortfolio
        from app.schemas.strategy import UniverseFilter

        try:
            trading_days = engine._get_trading_days(win.test_start, win.test_end)
            if not trading_days:
                return _ZERO_METRICS

            benchmark_prices = engine._load_benchmark_prices(
                backtest.benchmark, win.test_start, win.test_end
            )
            strategy = engine._create_strategy(
                backtest.strategy_type,
                backtest.strategy_params or {},
                backtest.top_n,
            )

            portfolio = VirtualPortfolio(backtest.initial_capital)
            # Converted once to a set for O(1) membership tests in the
            # daily loop.
            rebalance_dates = set(
                engine._generate_rebalance_dates(
                    win.test_start,
                    win.test_end,
                    backtest.rebalance_period,
                )
            )

            # Guard against a zero benchmark price, which would make the
            # normalisation below divide by zero.
            initial_benchmark = benchmark_prices.get(trading_days[0], Decimal("1"))
            if initial_benchmark == 0:
                initial_benchmark = Decimal("1")

            equity_curve: List[Decimal] = []
            benchmark_curve: List[Decimal] = []

            for trading_date in trading_days:
                prices = engine._get_prices_for_date(trading_date)

                if trading_date in rebalance_dates:
                    # Names are only consumed by rebalance(), so fetch them
                    # only on rebalance days.
                    names = engine._get_stock_names()
                    target_stocks = strategy.run(
                        universe_filter=UniverseFilter(),
                        top_n=backtest.top_n,
                        base_date=trading_date,
                    )
                    target_tickers = [s.ticker for s in target_stocks.stocks]
                    portfolio.rebalance(
                        target_tickers=target_tickers,
                        prices=prices,
                        names=names,
                        commission_rate=backtest.commission_rate,
                        slippage_rate=backtest.slippage_rate,
                    )

                portfolio_value = portfolio.get_value(prices)
                # Days without a benchmark price reuse the initial level.
                benchmark_value = benchmark_prices.get(trading_date, initial_benchmark)
                normalized_benchmark = (
                    benchmark_value / initial_benchmark * backtest.initial_capital
                )

                equity_curve.append(Decimal(str(portfolio_value)))
                benchmark_curve.append(Decimal(str(normalized_benchmark)))

            if len(equity_curve) < 2:
                return _ZERO_METRICS

            metrics = MetricsCalculator.calculate_all(equity_curve, benchmark_curve)
            return (metrics.total_return, metrics.sharpe_ratio, metrics.mdd)

        except Exception as e:
            # Best-effort: one failing window should not abort the whole
            # analysis.
            # NOTE(review): if the failure came from the database, the
            # session may need a rollback before later queries succeed.
            # That is not handled here — confirm.
            logger.warning(
                "Walk-forward window %s failed: %s", win.index, e, exc_info=True
            )
            return _ZERO_METRICS