236 lines
7.7 KiB
Python
236 lines
7.7 KiB
Python
|
|
"""
Walk-forward analysis engine.

Splits a backtest's period into rolling train/test windows and runs the
existing BacktestEngine's simulation helpers on each test window.
Train window boundaries are recorded with each result for reference only;
train windows are not simulated (no parameter optimisation yet).
"""
|
|||
|
|
import logging
|
|||
|
|
from dataclasses import dataclass
|
|||
|
|
from datetime import date
|
|||
|
|
from decimal import Decimal
|
|||
|
|
from typing import List
|
|||
|
|
|
|||
|
|
from dateutil.relativedelta import relativedelta
|
|||
|
|
from sqlalchemy.orm import Session
|
|||
|
|
|
|||
|
|
from app.models.backtest import (
|
|||
|
|
Backtest,
|
|||
|
|
BacktestResult,
|
|||
|
|
BacktestEquityCurve,
|
|||
|
|
WalkForwardResult,
|
|||
|
|
)
|
|||
|
|
from app.services.backtest.engine import BacktestEngine
|
|||
|
|
from app.services.backtest.metrics import MetricsCalculator
|
|||
|
|
|
|||
|
|
# Module-level logger named after this module; used for per-window progress and failure messages.
logger = logging.getLogger(__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class Window:
    """One walk-forward split: an in-sample train range followed by an out-of-sample test range.

    All four dates are inclusive bounds; the test range begins the day after
    the train range ends.
    """

    index: int  # zero-based position of this window in the generated sequence
    train_start: date  # first day of the in-sample (train) range
    train_end: date  # last day of the in-sample (train) range
    test_start: date  # first day of the out-of-sample (test) range
    test_end: date  # last day of the out-of-sample (test) range
|
|||
|
|
|
|||
|
|
|
|||
|
|
class WalkForwardEngine:
    """
    Walk-forward analysis using the existing BacktestEngine.

    The period of a stored ``Backtest`` is split into rolling train/test
    windows; the window sizes are arguments of ``run``. Each window's test
    range is simulated with the backtest's own strategy settings. Train
    ranges are recorded alongside the results but are not simulated.

    Parameters
    ----------
    db : Session
        SQLAlchemy session used to load the backtest, drive the
        BacktestEngine helpers and persist ``WalkForwardResult`` rows.
    """

    # (return, sharpe, mdd) reported for a window that cannot be evaluated.
    _EMPTY_METRICS = (Decimal("0"), Decimal("0"), Decimal("0"))

    def __init__(self, db: Session):
        self.db = db

    # ------------------------------------------------------------------
    # public
    # ------------------------------------------------------------------

    def run(
        self,
        backtest_id: int,
        train_months: int = 12,
        test_months: int = 3,
        step_months: int = 3,
    ) -> List[WalkForwardResult]:
        """
        Run walk-forward analysis for a stored backtest and persist the results.

        Existing ``WalkForwardResult`` rows for ``backtest_id`` are replaced
        by the newly computed ones and the session is committed. If anything
        fails after the old rows are deleted, the session is rolled back and
        the exception is re-raised.

        Parameters
        ----------
        backtest_id : int
            Primary key of the ``Backtest`` to analyse.
        train_months : int
            Length of the in-sample (training) window, in months (>= 0).
        test_months : int
            Length of the out-of-sample (test) window, in months (> 0).
        step_months : int
            How far the window slides on each iteration, in months (> 0).

        Returns
        -------
        List[WalkForwardResult]
            One result per generated window, in window order.

        Raises
        ------
        ValueError
            If the backtest does not exist, a month argument is out of
            range, or the backtest period is too short for any window.
        """
        backtest = self.db.query(Backtest).get(backtest_id)
        if not backtest:
            raise ValueError(f"Backtest {backtest_id} not found")

        # Also validates the month arguments (raises ValueError).
        windows = self._generate_windows(
            start=backtest.start_date,
            end=backtest.end_date,
            train_months=train_months,
            test_months=test_months,
            step_months=step_months,
        )

        if not windows:
            # "Period too short to build walk-forward windows; at least N months required."
            raise ValueError(
                "기간이 너무 짧아 walk-forward 윈도우를 생성할 수 없습니다. "
                f"최소 {train_months + test_months}개월 필요"
            )

        try:
            # Replace any results left by a previous run for this backtest.
            self.db.query(WalkForwardResult).filter(
                WalkForwardResult.backtest_id == backtest_id
            ).delete()
            self.db.flush()

            engine = BacktestEngine(self.db)
            results: List[WalkForwardResult] = []

            for win in windows:
                logger.info(
                    "Walk-forward window %s: test %s ~ %s",
                    win.index, win.test_start, win.test_end,
                )

                test_return, test_sharpe, test_mdd = self._run_window(
                    engine, backtest, win
                )

                wf = WalkForwardResult(
                    backtest_id=backtest_id,
                    window_index=win.index,
                    train_start=win.train_start,
                    train_end=win.train_end,
                    test_start=win.test_start,
                    test_end=win.test_end,
                    test_return=test_return,
                    test_sharpe=test_sharpe,
                    test_mdd=test_mdd,
                )
                self.db.add(wf)
                results.append(wf)

            self.db.commit()
        except Exception:
            # Without this, a failure part-way through would leave the old
            # results deleted (flushed) in a dirty session for the caller.
            self.db.rollback()
            raise

        return results

    # ------------------------------------------------------------------
    # window generation
    # ------------------------------------------------------------------

    @staticmethod
    def _generate_windows(
        start: date,
        end: date,
        train_months: int,
        test_months: int,
        step_months: int,
    ) -> List[Window]:
        """
        Split ``[start, end]`` into rolling train/test windows.

        Window ``i`` trains from ``start + i * step_months`` months for
        ``train_months`` months (inclusive end). Its test range starts the
        next day and lasts ``test_months`` months. A test range that would
        overrun ``end`` is truncated to ``end``. Generation stops once a
        test range would begin after ``end``.

        Raises
        ------
        ValueError
            If ``train_months`` is negative, or ``test_months`` or
            ``step_months`` is not positive. A non-positive step would never
            advance the window and loop forever.
        """
        if train_months < 0:
            raise ValueError(f"train_months must be >= 0, got {train_months}")
        if test_months <= 0:
            raise ValueError(f"test_months must be > 0, got {test_months}")
        if step_months <= 0:
            raise ValueError(f"step_months must be > 0, got {step_months}")

        one_day = relativedelta(days=1)
        windows: List[Window] = []
        idx = 0

        while True:
            # Offset from the original start instead of accumulating onto the
            # previous window. Otherwise month-end clamping (Jan 31 -> Feb 28)
            # would permanently shift every later window (-> Mar 28, ...).
            train_start = start + relativedelta(months=step_months * idx)
            train_end = train_start + relativedelta(months=train_months) - one_day
            test_start = train_end + one_day

            if test_start > end:
                break

            # Allow a partial last window: clip the test range to the period end.
            test_end = min(
                test_start + relativedelta(months=test_months) - one_day, end
            )

            windows.append(Window(
                index=idx,
                train_start=train_start,
                train_end=train_end,
                test_start=test_start,
                test_end=test_end,
            ))
            idx += 1

        return windows

    # ------------------------------------------------------------------
    # single window execution
    # ------------------------------------------------------------------

    def _run_window(
        self,
        engine: BacktestEngine,
        backtest: Backtest,
        win: Window,
    ) -> tuple:
        """
        Simulate ``backtest``'s strategy over the test range of ``win``.

        Returns
        -------
        tuple
            ``(total_return, sharpe_ratio, mdd)`` as computed by
            ``MetricsCalculator.calculate_all``. ``(Decimal("0"),) * 3`` is
            returned when the window has no trading days, produces fewer
            than two equity points, or raises. Failures are logged with a
            traceback; this is best-effort so that one bad window does not
            abort the whole analysis.
        """
        try:
            trading_days = engine._get_trading_days(win.test_start, win.test_end)
            if not trading_days:
                return self._EMPTY_METRICS

            benchmark_prices = engine._load_benchmark_prices(
                backtest.benchmark, win.test_start, win.test_end
            )

            strategy = engine._create_strategy(
                backtest.strategy_type,
                backtest.strategy_params or {},
                backtest.top_n,
            )

            # Function-level imports kept from the original. They are presumably
            # there to avoid an import cycle -- TODO confirm.
            from app.services.backtest.portfolio import VirtualPortfolio
            from app.schemas.strategy import UniverseFilter

            portfolio = VirtualPortfolio(backtest.initial_capital)

            rebalance_dates = engine._generate_rebalance_dates(
                win.test_start, win.test_end, backtest.rebalance_period,
            )
            # NOTE(review): portfolio.rebalance is only called on days found in
            # rebalance_dates. If the first trading day is not among them, the
            # window presumably starts in cash. Confirm that
            # _generate_rebalance_dates aligns with trading days.

            # The benchmark is rescaled to start at initial_capital. A missing or
            # zero first price falls back to 1 to avoid dividing by zero.
            initial_benchmark = benchmark_prices.get(trading_days[0], Decimal("1"))
            if initial_benchmark == 0:
                initial_benchmark = Decimal("1")

            equity_curve: List[Decimal] = []
            benchmark_curve: List[Decimal] = []

            for trading_date in trading_days:
                prices = engine._get_prices_for_date(trading_date)

                if trading_date in rebalance_dates:
                    # Names are only needed when trades are placed, so fetch
                    # them here rather than on every trading day.
                    names = engine._get_stock_names()
                    target_stocks = strategy.run(
                        universe_filter=UniverseFilter(),
                        top_n=backtest.top_n,
                        base_date=trading_date,
                    )
                    target_tickers = [s.ticker for s in target_stocks.stocks]
                    portfolio.rebalance(
                        target_tickers=target_tickers,
                        prices=prices,
                        names=names,
                        commission_rate=backtest.commission_rate,
                        slippage_rate=backtest.slippage_rate,
                    )

                portfolio_value = portfolio.get_value(prices)
                benchmark_value = benchmark_prices.get(trading_date, initial_benchmark)
                normalized_benchmark = (
                    benchmark_value / initial_benchmark * backtest.initial_capital
                )
                equity_curve.append(Decimal(str(portfolio_value)))
                benchmark_curve.append(Decimal(str(normalized_benchmark)))

            if len(equity_curve) < 2:
                return self._EMPTY_METRICS

            metrics = MetricsCalculator.calculate_all(equity_curve, benchmark_curve)
            return (metrics.total_return, metrics.sharpe_ratio, metrics.mdd)

        except Exception as e:
            # Keep the best-effort zero result, but log the traceback: the
            # zeros are otherwise indistinguishable from a genuinely flat window.
            logger.warning(
                "Walk-forward window %s failed: %s", win.index, e, exc_info=True
            )
            return self._EMPTY_METRICS
|