galaxis-po/backend/app/services/backtest/walkforward_engine.py
머니페니 f818bd3290 feat: add walk-forward analysis for backtests
- Add WalkForwardResult model with train/test window metrics
- Create WalkForwardEngine that reuses existing BacktestEngine
  with rolling train/test window splits
- Add POST/GET /api/backtest/{id}/walkforward endpoints
- Add Walk-forward tab to backtest detail page with parameter
  controls, cumulative return chart, and window results table
- Add Alembic migration for walkforward_results table
- Add 8 unit tests for window generation logic (100 total passed)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-18 22:33:41 +09:00

236 lines
7.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Walk-forward analysis engine.
Splits the backtest period into rolling train/test windows and simulates each
test window with the helpers of the existing BacktestEngine.
Train windows are only recorded alongside the results; they are not simulated
(no parameter optimisation yet).
"""
import logging
from dataclasses import dataclass
from datetime import date
from decimal import Decimal
from typing import List
from dateutil.relativedelta import relativedelta
from sqlalchemy.orm import Session
from app.models.backtest import (
Backtest,
BacktestResult,
BacktestEquityCurve,
WalkForwardResult,
)
from app.services.backtest.engine import BacktestEngine
from app.services.backtest.metrics import MetricsCalculator
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
@dataclass
class Window:
index: int
train_start: date
train_end: date
test_start: date
test_end: date
class WalkForwardEngine:
"""
Walk-forward analysis using existing BacktestEngine.
Parameters
----------
train_months : int length of in-sample (training) window
test_months : int length of out-of-sample (test) window
step_months : int how far the window slides each iteration
"""
def __init__(self, db: Session):
self.db = db
# ------------------------------------------------------------------
# public
# ------------------------------------------------------------------
def run(
self,
backtest_id: int,
train_months: int = 12,
test_months: int = 3,
step_months: int = 3,
) -> List[WalkForwardResult]:
backtest = self.db.query(Backtest).get(backtest_id)
if not backtest:
raise ValueError(f"Backtest {backtest_id} not found")
windows = self._generate_windows(
start=backtest.start_date,
end=backtest.end_date,
train_months=train_months,
test_months=test_months,
step_months=step_months,
)
if not windows:
raise ValueError(
"기간이 너무 짧아 walk-forward 윈도우를 생성할 수 없습니다. "
f"최소 {train_months + test_months}개월 필요"
)
# Delete previous walk-forward results for this backtest
self.db.query(WalkForwardResult).filter(
WalkForwardResult.backtest_id == backtest_id
).delete()
self.db.flush()
engine = BacktestEngine(self.db)
results: List[WalkForwardResult] = []
for win in windows:
logger.info(
f"Walk-forward window {win.index}: "
f"test {win.test_start} ~ {win.test_end}"
)
test_return, test_sharpe, test_mdd = self._run_window(
engine, backtest, win
)
wf = WalkForwardResult(
backtest_id=backtest_id,
window_index=win.index,
train_start=win.train_start,
train_end=win.train_end,
test_start=win.test_start,
test_end=win.test_end,
test_return=test_return,
test_sharpe=test_sharpe,
test_mdd=test_mdd,
)
self.db.add(wf)
results.append(wf)
self.db.commit()
return results
# ------------------------------------------------------------------
# window generation
# ------------------------------------------------------------------
@staticmethod
def _generate_windows(
start: date,
end: date,
train_months: int,
test_months: int,
step_months: int,
) -> List[Window]:
windows: List[Window] = []
idx = 0
cursor = start
while True:
train_start = cursor
train_end = train_start + relativedelta(months=train_months) - relativedelta(days=1)
test_start = train_end + relativedelta(days=1)
test_end = test_start + relativedelta(months=test_months) - relativedelta(days=1)
if test_end > end:
# Allow partial last window if test_start is before end
if test_start <= end:
test_end = end
else:
break
windows.append(Window(
index=idx,
train_start=train_start,
train_end=train_end,
test_start=test_start,
test_end=test_end,
))
idx += 1
cursor += relativedelta(months=step_months)
return windows
# ------------------------------------------------------------------
# single window execution
# ------------------------------------------------------------------
def _run_window(
self,
engine: BacktestEngine,
backtest: Backtest,
win: Window,
) -> tuple:
"""Run backtest on the test window and return (return, sharpe, mdd)."""
try:
trading_days = engine._get_trading_days(win.test_start, win.test_end)
if not trading_days:
return (Decimal("0"), Decimal("0"), Decimal("0"))
benchmark_prices = engine._load_benchmark_prices(
backtest.benchmark, win.test_start, win.test_end
)
strategy = engine._create_strategy(
backtest.strategy_type,
backtest.strategy_params or {},
backtest.top_n,
)
from app.services.backtest.portfolio import VirtualPortfolio
from app.schemas.strategy import UniverseFilter
portfolio = VirtualPortfolio(backtest.initial_capital)
rebalance_dates = engine._generate_rebalance_dates(
win.test_start, win.test_end, backtest.rebalance_period,
)
initial_benchmark = benchmark_prices.get(trading_days[0], Decimal("1"))
if initial_benchmark == 0:
initial_benchmark = Decimal("1")
equity_curve: List[Decimal] = []
benchmark_curve: List[Decimal] = []
for trading_date in trading_days:
prices = engine._get_prices_for_date(trading_date)
names = engine._get_stock_names()
if trading_date in rebalance_dates:
target_stocks = strategy.run(
universe_filter=UniverseFilter(),
top_n=backtest.top_n,
base_date=trading_date,
)
target_tickers = [s.ticker for s in target_stocks.stocks]
portfolio.rebalance(
target_tickers=target_tickers,
prices=prices,
names=names,
commission_rate=backtest.commission_rate,
slippage_rate=backtest.slippage_rate,
)
portfolio_value = portfolio.get_value(prices)
benchmark_value = benchmark_prices.get(trading_date, initial_benchmark)
normalized_benchmark = (
benchmark_value / initial_benchmark * backtest.initial_capital
)
equity_curve.append(Decimal(str(portfolio_value)))
benchmark_curve.append(Decimal(str(normalized_benchmark)))
if len(equity_curve) < 2:
return (Decimal("0"), Decimal("0"), Decimal("0"))
metrics = MetricsCalculator.calculate_all(equity_curve, benchmark_curve)
return (metrics.total_return, metrics.sharpe_ratio, metrics.mdd)
except Exception as e:
logger.warning(f"Walk-forward window {win.index} failed: {e}")
return (Decimal("0"), Decimal("0"), Decimal("0"))