- Add WalkForwardResult model with train/test window metrics
- Create WalkForwardEngine that reuses existing BacktestEngine
with rolling train/test window splits
- Add POST/GET /api/backtest/{id}/walkforward endpoints
- Add Walk-forward tab to backtest detail page with parameter
controls, cumulative return chart, and window results table
- Add Alembic migration for walkforward_results table
- Add 8 unit tests for window generation logic (100 total passed)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
236 lines
7.7 KiB
Python
236 lines
7.7 KiB
Python
"""
|
||
Walk-forward analysis engine.
|
||
|
||
Splits backtest period into rolling train/test windows and runs
|
||
the existing BacktestEngine (or DailyBacktestEngine) on each test window.
|
||
Train window results are used for validation only (no parameter optimisation yet).
|
||
"""
|
||
import logging
|
||
from dataclasses import dataclass
|
||
from datetime import date
|
||
from decimal import Decimal
|
||
from typing import List
|
||
|
||
from dateutil.relativedelta import relativedelta
|
||
from sqlalchemy.orm import Session
|
||
|
||
from app.models.backtest import (
|
||
Backtest,
|
||
BacktestResult,
|
||
BacktestEquityCurve,
|
||
WalkForwardResult,
|
||
)
|
||
from app.services.backtest.engine import BacktestEngine
|
||
from app.services.backtest.metrics import MetricsCalculator
|
||
|
||
# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
|
||
|
||
|
||
@dataclass
class Window:
    """A single walk-forward split.

    An in-sample (train) date range is immediately followed by an
    out-of-sample (test) date range. Both ranges are inclusive of their
    start and end dates.
    """

    # Zero-based position of this window in generation order.
    index: int
    # First and last day (inclusive) of the training range.
    train_start: date
    train_end: date
    # First and last day (inclusive) of the test range. For the final
    # window(s), test_end may be clipped to the end of the backtest period.
    test_start: date
    test_end: date
|
||
|
||
|
||
class WalkForwardEngine:
    """
    Walk-forward analysis built on the existing BacktestEngine.

    For every generated window only the out-of-sample (test) period is
    simulated. The in-sample (train) period is stored as a date range, but
    it is not backtested and no parameters are fitted on it.

    Parameters of :meth:`run`
    -------------------------
    train_months : int – length of the in-sample (training) window, > 0
    test_months : int – length of the out-of-sample (test) window, > 0
    step_months : int – how far the window slides each iteration, > 0
    """

    # (return, sharpe, mdd) recorded for a window that yields no usable
    # equity curve: no trading days, a single data point, or a failed run.
    # Decimal is immutable, so sharing one tuple is safe.
    _EMPTY_METRICS = (Decimal("0"), Decimal("0"), Decimal("0"))

    def __init__(self, db: Session):
        self.db = db

    # ------------------------------------------------------------------
    # public
    # ------------------------------------------------------------------

    def run(
        self,
        backtest_id: int,
        train_months: int = 12,
        test_months: int = 3,
        step_months: int = 3,
    ) -> List[WalkForwardResult]:
        """Run walk-forward analysis for a stored backtest and persist it.

        Any WalkForwardResult rows already stored for ``backtest_id`` are
        deleted. They are replaced by one row per generated window, and the
        session is committed.

        Args:
            backtest_id: primary key of the Backtest to analyse.
            train_months: in-sample window length in months (> 0).
            test_months: out-of-sample window length in months (> 0).
            step_months: slide between consecutive windows in months (> 0).

        Returns:
            The newly created WalkForwardResult objects, in window order.

        Raises:
            ValueError: the backtest does not exist, a window length is not
                positive, or the period is too short for a single window.
            Exception: any error raised while replacing the stored results.
                It is re-raised after the session has been rolled back.
        """
        # NOTE(review): Query.get is a legacy API in SQLAlchemy 1.4/2.0;
        # self.db.get(Backtest, backtest_id) is the replacement if the
        # project's SQLAlchemy version allows it.
        backtest = self.db.query(Backtest).get(backtest_id)
        if not backtest:
            raise ValueError(f"Backtest {backtest_id} not found")

        windows = self._generate_windows(
            start=backtest.start_date,
            end=backtest.end_date,
            train_months=train_months,
            test_months=test_months,
            step_months=step_months,
        )
        if not windows:
            raise ValueError(
                "기간이 너무 짧아 walk-forward 윈도우를 생성할 수 없습니다. "
                f"최소 {train_months + test_months}개월 필요"
            )

        results: List[WalkForwardResult] = []
        try:
            # Delete previous walk-forward results for this backtest.
            self.db.query(WalkForwardResult).filter(
                WalkForwardResult.backtest_id == backtest_id
            ).delete()
            self.db.flush()

            engine = BacktestEngine(self.db)
            for win in windows:
                logger.info(
                    "Walk-forward window %s: test %s ~ %s",
                    win.index,
                    win.test_start,
                    win.test_end,
                )
                test_return, test_sharpe, test_mdd = self._run_window(
                    engine, backtest, win
                )
                wf = WalkForwardResult(
                    backtest_id=backtest_id,
                    window_index=win.index,
                    train_start=win.train_start,
                    train_end=win.train_end,
                    test_start=win.test_start,
                    test_end=win.test_end,
                    test_return=test_return,
                    test_sharpe=test_sharpe,
                    test_mdd=test_mdd,
                )
                self.db.add(wf)
                results.append(wf)

            self.db.commit()
        except Exception:
            # Do not leave the session holding a half-applied delete/insert
            # (or in a failed-transaction state); surface the original error.
            self.db.rollback()
            raise

        return results

    # ------------------------------------------------------------------
    # window generation
    # ------------------------------------------------------------------

    @staticmethod
    def _generate_windows(
        start: date,
        end: date,
        train_months: int,
        test_months: int,
        step_months: int,
    ) -> List[Window]:
        """Split [start, end] into rolling train/test windows.

        Each window's train range is ``train_months`` long, beginning at a
        cursor that starts at ``start``. Its test range is the ``test_months``
        that immediately follow. After each window the cursor advances by
        ``step_months``.

        A test range that runs past ``end`` but starts on or before it is
        clipped to ``end``, giving a partial window. Generation stops once a
        test range would start after ``end``.

        Raises:
            ValueError: if any month argument is not positive. A non-positive
                step would never advance the cursor, so the loop would run
                forever. A non-positive train/test length would produce
                windows that end before they start.
        """
        for label, value in (
            ("train_months", train_months),
            ("test_months", test_months),
            ("step_months", step_months),
        ):
            if value <= 0:
                raise ValueError(f"{label} must be a positive integer, got {value}")

        windows: List[Window] = []
        idx = 0
        cursor = start

        while True:
            train_start = cursor
            train_end = train_start + relativedelta(months=train_months) - relativedelta(days=1)
            test_start = train_end + relativedelta(days=1)
            test_end = test_start + relativedelta(months=test_months) - relativedelta(days=1)

            if test_start > end:
                break
            # Allow a partial last window: clip the test range to the period end.
            test_end = min(test_end, end)

            windows.append(Window(
                index=idx,
                train_start=train_start,
                train_end=train_end,
                test_start=test_start,
                test_end=test_end,
            ))
            idx += 1
            # NOTE(review): the cursor accumulates month steps, so a start on
            # the 29th-31st can drift to an earlier day of month (e.g. Jan 31
            # -> Apr 30 -> Jul 30). Left unchanged to keep existing window
            # dates stable.
            cursor += relativedelta(months=step_months)

        return windows

    # ------------------------------------------------------------------
    # single window execution
    # ------------------------------------------------------------------

    def _run_window(
        self,
        engine: BacktestEngine,
        backtest: Backtest,
        win: Window,
    ) -> tuple:
        """Simulate the strategy over ``win``'s test range.

        A fresh VirtualPortfolio funded with ``backtest.initial_capital`` is
        created for every window, so windows are evaluated independently.

        Returns:
            ``(total_return, sharpe_ratio, mdd)`` from MetricsCalculator.
            Three ``Decimal("0")`` values are returned instead when the window
            has fewer than two trading days or the simulation raises.
            Failures are logged with their traceback rather than propagated,
            so one bad window does not abort the whole analysis.
        """
        # Imported lazily as in the original code, presumably to avoid an
        # import cycle (confirm before hoisting). These imports sit outside
        # the try below so that a broken deployment fails loudly instead of
        # silently recording zero metrics for every window.
        from app.services.backtest.portfolio import VirtualPortfolio
        from app.schemas.strategy import UniverseFilter

        try:
            trading_days = engine._get_trading_days(win.test_start, win.test_end)
            if not trading_days:
                return self._EMPTY_METRICS

            benchmark_prices = engine._load_benchmark_prices(
                backtest.benchmark, win.test_start, win.test_end
            )
            strategy = engine._create_strategy(
                backtest.strategy_type,
                backtest.strategy_params or {},
                backtest.top_n,
            )
            portfolio = VirtualPortfolio(backtest.initial_capital)

            # NOTE(review): the portfolio stays in cash until the first date
            # in rebalance_dates that is also a trading day. Confirm that
            # _generate_rebalance_dates yields trading days starting at
            # win.test_start; otherwise part of the window is not invested.
            rebalance_dates = engine._generate_rebalance_dates(
                win.test_start, win.test_end, backtest.rebalance_period,
            )

            # Normalise the benchmark to the first trading day that has a
            # non-zero price. Fall back to 1 only when no such day exists, so
            # a 1.0 base is never mixed with real index levels.
            initial_benchmark = next(
                (benchmark_prices[d] for d in trading_days if benchmark_prices.get(d)),
                Decimal("1"),
            )
            last_benchmark = initial_benchmark

            # Stock names take no date argument, so fetch them once per window.
            names = engine._get_stock_names()

            equity_curve: List[Decimal] = []
            benchmark_curve: List[Decimal] = []

            for trading_date in trading_days:
                prices = engine._get_prices_for_date(trading_date)

                if trading_date in rebalance_dates:
                    # NOTE(review): uses a default UniverseFilter(), not any
                    # filter stored on the backtest — confirm this matches
                    # BacktestEngine's behaviour.
                    target_stocks = strategy.run(
                        universe_filter=UniverseFilter(),
                        top_n=backtest.top_n,
                        base_date=trading_date,
                    )
                    portfolio.rebalance(
                        target_tickers=[s.ticker for s in target_stocks.stocks],
                        prices=prices,
                        names=names,
                        commission_rate=backtest.commission_rate,
                        slippage_rate=backtest.slippage_rate,
                    )

                # Carry the last known benchmark price over missing or zero
                # days instead of snapping back to the base value.
                price_today = benchmark_prices.get(trading_date)
                if price_today:
                    last_benchmark = price_today
                normalized_benchmark = (
                    last_benchmark / initial_benchmark * backtest.initial_capital
                )

                equity_curve.append(Decimal(str(portfolio.get_value(prices))))
                benchmark_curve.append(Decimal(str(normalized_benchmark)))

            if len(equity_curve) < 2:
                return self._EMPTY_METRICS

            metrics = MetricsCalculator.calculate_all(equity_curve, benchmark_curve)
            return (metrics.total_return, metrics.sharpe_ratio, metrics.mdd)

        except Exception as e:
            # Deliberate best effort: record neutral metrics for this window,
            # but keep the traceback so the failure is diagnosable.
            logger.warning(
                "Walk-forward window %s failed: %s", win.index, e, exc_info=True
            )
            return self._EMPTY_METRICS
|