galaxis-po/backend/app/services/backtest/walkforward_engine.py

236 lines
7.7 KiB
Python
Raw Normal View History

"""
Walk-forward analysis engine.
Splits the backtest period into rolling train/test windows and simulates the
strategy on each test window using helper methods of the existing
BacktestEngine (DailyBacktestEngine is not used by this module).
Train windows are not simulated: their dates are only recorded alongside each
test result (no parameter optimisation yet).
"""
import logging
from dataclasses import dataclass
from datetime import date
from decimal import Decimal
from typing import List
from dateutil.relativedelta import relativedelta
from sqlalchemy.orm import Session
from app.models.backtest import (
Backtest,
BacktestResult,
BacktestEquityCurve,
WalkForwardResult,
)
from app.services.backtest.engine import BacktestEngine
from app.services.backtest.metrics import MetricsCalculator
logger = logging.getLogger(__name__)
@dataclass
class Window:
    """Date boundaries of a single walk-forward window (all dates inclusive)."""

    # Zero-based position of this window in the generated sequence.
    index: int
    # In-sample (training) period; only recorded, never simulated.
    train_start: date
    train_end: date
    # Out-of-sample (test) period that is actually simulated.
    test_start: date
    test_end: date
class WalkForwardEngine:
    """
    Walk-forward analysis built on the helper methods of BacktestEngine.

    For every generated window only the out-of-sample (test) period is
    simulated; the train period dates are stored with the result but no
    simulation or parameter fitting is performed on them.

    Parameters
    ----------
    db : Session
        SQLAlchemy session used to load the Backtest, passed to
        BacktestEngine for its data queries, and used to persist the
        WalkForwardResult rows.
    """

    def __init__(self, db: Session):
        self.db = db

    # ------------------------------------------------------------------
    # public
    # ------------------------------------------------------------------
    def run(
        self,
        backtest_id: int,
        train_months: int = 12,
        test_months: int = 3,
        step_months: int = 3,
    ) -> List[WalkForwardResult]:
        """
        Run walk-forward analysis for a stored backtest and persist results.

        Previously stored WalkForwardResult rows for ``backtest_id`` are
        deleted first; the new rows are committed together at the end.

        Parameters
        ----------
        backtest_id : int
            Primary key of the Backtest to analyse.
        train_months : int
            Length of the in-sample (training) window, in months (>= 0).
        test_months : int
            Length of the out-of-sample (test) window, in months (>= 1).
        step_months : int
            How far the window slides each iteration, in months (>= 1).

        Returns
        -------
        List[WalkForwardResult]
            One result per window, in window order.

        Raises
        ------
        ValueError
            If the backtest does not exist, the window lengths are invalid,
            or the period is too short to produce a single window.
        """
        backtest = self.db.query(Backtest).get(backtest_id)
        if not backtest:
            raise ValueError(f"Backtest {backtest_id} not found")

        windows = self._generate_windows(
            start=backtest.start_date,
            end=backtest.end_date,
            train_months=train_months,
            test_months=test_months,
            step_months=step_months,
        )
        if not windows:
            raise ValueError(
                "기간이 너무 짧아 walk-forward 윈도우를 생성할 수 없습니다. "
                f"최소 {train_months + test_months}개월 필요"
            )

        # Delete previous walk-forward results for this backtest
        self.db.query(WalkForwardResult).filter(
            WalkForwardResult.backtest_id == backtest_id
        ).delete()
        self.db.flush()

        engine = BacktestEngine(self.db)
        results: List[WalkForwardResult] = []

        for win in windows:
            logger.info(
                "Walk-forward window %s: test %s ~ %s",
                win.index, win.test_start, win.test_end,
            )
            test_return, test_sharpe, test_mdd = self._run_window(
                engine, backtest, win
            )
            wf = WalkForwardResult(
                backtest_id=backtest_id,
                window_index=win.index,
                train_start=win.train_start,
                train_end=win.train_end,
                test_start=win.test_start,
                test_end=win.test_end,
                test_return=test_return,
                test_sharpe=test_sharpe,
                test_mdd=test_mdd,
            )
            self.db.add(wf)
            results.append(wf)

        self.db.commit()
        return results

    # ------------------------------------------------------------------
    # window generation
    # ------------------------------------------------------------------
    @staticmethod
    def _generate_windows(
        start: date,
        end: date,
        train_months: int,
        test_months: int,
        step_months: int,
    ) -> List[Window]:
        """
        Build consecutive train/test windows covering ``start``..``end``.

        Window ``i`` starts training at ``start + i * step_months`` months.
        The test period immediately follows the train period. The last
        window's test period is truncated to ``end`` if it would run past it.
        Generation stops at the first window whose test period would start
        after ``end``.

        Raises
        ------
        ValueError
            If ``test_months`` or ``step_months`` is < 1, or ``train_months``
            is negative. (``step_months <= 0`` previously looped forever.)
        """
        if step_months < 1 or test_months < 1 or train_months < 0:
            raise ValueError(
                "Invalid walk-forward window lengths: "
                f"train_months={train_months} (must be >= 0), "
                f"test_months={test_months} (must be >= 1), "
                f"step_months={step_months} (must be >= 1)"
            )

        windows: List[Window] = []
        idx = 0
        while True:
            # Offset from the original start rather than accumulating onto a
            # cursor, so month-end start dates do not drift (Jan 31 -> Apr 30
            # -> Jul 30 ...).
            train_start = start + relativedelta(months=idx * step_months)
            train_end = train_start + relativedelta(months=train_months) - relativedelta(days=1)
            test_start = train_end + relativedelta(days=1)
            if test_start > end:
                break
            # Allow a partial last window truncated at the backtest end date.
            test_end = min(
                test_start + relativedelta(months=test_months) - relativedelta(days=1),
                end,
            )
            windows.append(Window(
                index=idx,
                train_start=train_start,
                train_end=train_end,
                test_start=test_start,
                test_end=test_end,
            ))
            idx += 1
        return windows

    # ------------------------------------------------------------------
    # single window execution
    # ------------------------------------------------------------------
    def _run_window(
        self,
        engine: BacktestEngine,
        backtest: Backtest,
        win: Window,
    ) -> tuple:
        """
        Simulate the strategy over the test period of ``win``.

        A fresh VirtualPortfolio seeded with the backtest's initial capital is
        rebalanced to the strategy's picks on each rebalance date and valued
        on every trading day. The benchmark is normalised to the same capital.

        Returns
        -------
        tuple
            ``(total_return, sharpe_ratio, mdd)`` from MetricsCalculator.
            Three ``Decimal("0")`` values are returned when the window has no
            trading days, fewer than two equity points, or the simulation
            raises (best-effort: the error is logged, not propagated).
        """
        zero = (Decimal("0"), Decimal("0"), Decimal("0"))
        try:
            trading_days = engine._get_trading_days(win.test_start, win.test_end)
            if not trading_days:
                return zero

            benchmark_prices = engine._load_benchmark_prices(
                backtest.benchmark, win.test_start, win.test_end
            )
            strategy = engine._create_strategy(
                backtest.strategy_type,
                backtest.strategy_params or {},
                backtest.top_n,
            )

            # Function-scope imports kept from the original implementation
            # (presumably to avoid an import cycle -- TODO confirm).
            from app.services.backtest.portfolio import VirtualPortfolio
            from app.schemas.strategy import UniverseFilter

            portfolio = VirtualPortfolio(backtest.initial_capital)
            # Set gives O(1) membership checks inside the daily loop.
            rebalance_dates = set(
                engine._generate_rebalance_dates(
                    win.test_start, win.test_end, backtest.rebalance_period,
                )
            )

            # Fall back to 1 when the first day's benchmark price is missing
            # or zero, so the normalisation below never divides by zero.
            initial_benchmark = benchmark_prices.get(trading_days[0], Decimal("1"))
            if initial_benchmark == 0:
                initial_benchmark = Decimal("1")

            equity_curve: List[Decimal] = []
            benchmark_curve: List[Decimal] = []

            for trading_date in trading_days:
                prices = engine._get_prices_for_date(trading_date)

                if trading_date in rebalance_dates:
                    # Names are only needed by rebalance(), so they are no
                    # longer fetched on non-rebalance days.
                    names = engine._get_stock_names()
                    target_stocks = strategy.run(
                        universe_filter=UniverseFilter(),
                        top_n=backtest.top_n,
                        base_date=trading_date,
                    )
                    target_tickers = [s.ticker for s in target_stocks.stocks]
                    portfolio.rebalance(
                        target_tickers=target_tickers,
                        prices=prices,
                        names=names,
                        commission_rate=backtest.commission_rate,
                        slippage_rate=backtest.slippage_rate,
                    )

                portfolio_value = portfolio.get_value(prices)
                benchmark_value = benchmark_prices.get(trading_date, initial_benchmark)
                normalized_benchmark = (
                    benchmark_value / initial_benchmark * backtest.initial_capital
                )
                equity_curve.append(Decimal(str(portfolio_value)))
                benchmark_curve.append(Decimal(str(normalized_benchmark)))

            if len(equity_curve) < 2:
                return zero

            metrics = MetricsCalculator.calculate_all(equity_curve, benchmark_curve)
            return (metrics.total_return, metrics.sharpe_ratio, metrics.mdd)

        except Exception as e:
            # Best-effort: one failing window yields zero metrics instead of
            # aborting the whole analysis; keep the traceback for diagnosis.
            # NOTE(review): after a database error the session may need a
            # rollback before later queries / the final commit succeed.
            logger.warning(
                "Walk-forward window %s failed: %s", win.index, e, exc_info=True
            )
            return zero