feat: add walk-forward analysis for backtests
- Add WalkForwardResult model with train/test window metrics
- Create WalkForwardEngine that reuses existing BacktestEngine
with rolling train/test window splits
- Add POST/GET /api/backtest/{id}/walkforward endpoints
- Add Walk-forward tab to backtest detail page with parameter
controls, cumulative return chart, and window results table
- Add Alembic migration for walkforward_results table
- Add 8 unit tests for window generation logic (100 total passed)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
741b7fa7dd
commit
f818bd3290
@ -0,0 +1,45 @@
|
||||
"""add walkforward_results table
|
||||
|
||||
Revision ID: 59807c4e84ee
|
||||
Revises: b7c8d9e0f1a2
|
||||
Create Date: 2026-03-18 22:28:53.955519
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '59807c4e84ee'
|
||||
down_revision: Union[str, None] = 'b7c8d9e0f1a2'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('walkforward_results',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('backtest_id', sa.Integer(), nullable=False),
|
||||
sa.Column('window_index', sa.Integer(), nullable=False),
|
||||
sa.Column('train_start', sa.Date(), nullable=False),
|
||||
sa.Column('train_end', sa.Date(), nullable=False),
|
||||
sa.Column('test_start', sa.Date(), nullable=False),
|
||||
sa.Column('test_end', sa.Date(), nullable=False),
|
||||
sa.Column('test_return', sa.Numeric(precision=10, scale=4), nullable=True),
|
||||
sa.Column('test_sharpe', sa.Numeric(precision=10, scale=4), nullable=True),
|
||||
sa.Column('test_mdd', sa.Numeric(precision=10, scale=4), nullable=True),
|
||||
sa.ForeignKeyConstraint(['backtest_id'], ['backtests.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_walkforward_results_id'), 'walkforward_results', ['id'], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f('ix_walkforward_results_id'), table_name='walkforward_results')
|
||||
op.drop_table('walkforward_results')
|
||||
# ### end Alembic commands ###
|
||||
@ -11,12 +11,15 @@ from app.api.deps import CurrentUser
|
||||
from app.models.backtest import (
|
||||
Backtest, BacktestResult, BacktestEquityCurve,
|
||||
BacktestHolding, BacktestTransaction, BacktestStatus,
|
||||
WalkForwardResult,
|
||||
)
|
||||
from app.schemas.backtest import (
|
||||
BacktestCreate, BacktestResponse, BacktestListItem, BacktestMetrics,
|
||||
EquityCurvePoint, RebalanceHoldings, HoldingItem, TransactionItem,
|
||||
WalkForwardRequest, WalkForwardWindowResult, WalkForwardResponse,
|
||||
)
|
||||
from app.services.backtest import submit_backtest
|
||||
from app.services.backtest.walkforward_engine import WalkForwardEngine
|
||||
from app.services.rebalance import RebalanceService
|
||||
|
||||
router = APIRouter(prefix="/api/backtest", tags=["backtest"])
|
||||
@ -250,6 +253,82 @@ async def get_transactions(
|
||||
]
|
||||
|
||||
|
||||
@router.post("/{backtest_id}/walkforward", response_model=dict)
|
||||
async def run_walkforward(
|
||||
backtest_id: int,
|
||||
request: WalkForwardRequest,
|
||||
current_user: CurrentUser,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Run walk-forward analysis on a completed backtest."""
|
||||
backtest = db.query(Backtest).filter(Backtest.id == backtest_id).first()
|
||||
|
||||
if not backtest:
|
||||
raise HTTPException(status_code=404, detail="Backtest not found")
|
||||
|
||||
if backtest.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
|
||||
if backtest.status != BacktestStatus.COMPLETED:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="백테스트가 완료된 상태에서만 walk-forward 분석이 가능합니다",
|
||||
)
|
||||
|
||||
engine = WalkForwardEngine(db)
|
||||
try:
|
||||
engine.run(
|
||||
backtest_id=backtest_id,
|
||||
train_months=request.train_months,
|
||||
test_months=request.test_months,
|
||||
step_months=request.step_months,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
return {"status": "completed", "backtest_id": backtest_id}
|
||||
|
||||
|
||||
@router.get("/{backtest_id}/walkforward", response_model=WalkForwardResponse)
|
||||
async def get_walkforward(
|
||||
backtest_id: int,
|
||||
current_user: CurrentUser,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get walk-forward analysis results."""
|
||||
backtest = db.query(Backtest).filter(Backtest.id == backtest_id).first()
|
||||
|
||||
if not backtest:
|
||||
raise HTTPException(status_code=404, detail="Backtest not found")
|
||||
|
||||
if backtest.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
|
||||
results = (
|
||||
db.query(WalkForwardResult)
|
||||
.filter(WalkForwardResult.backtest_id == backtest_id)
|
||||
.order_by(WalkForwardResult.window_index)
|
||||
.all()
|
||||
)
|
||||
|
||||
return WalkForwardResponse(
|
||||
backtest_id=backtest_id,
|
||||
windows=[
|
||||
WalkForwardWindowResult(
|
||||
window_index=r.window_index,
|
||||
train_start=r.train_start,
|
||||
train_end=r.train_end,
|
||||
test_start=r.test_start,
|
||||
test_end=r.test_end,
|
||||
test_return=r.test_return,
|
||||
test_sharpe=r.test_sharpe,
|
||||
test_mdd=r.test_mdd,
|
||||
)
|
||||
for r in results
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{backtest_id}")
|
||||
async def delete_backtest(
|
||||
backtest_id: int,
|
||||
@ -266,6 +345,9 @@ async def delete_backtest(
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
|
||||
# Delete related data
|
||||
db.query(WalkForwardResult).filter(
|
||||
WalkForwardResult.backtest_id == backtest_id
|
||||
).delete()
|
||||
db.query(BacktestTransaction).filter(
|
||||
BacktestTransaction.backtest_id == backtest_id
|
||||
).delete()
|
||||
|
||||
@ -51,6 +51,7 @@ class Backtest(Base):
|
||||
equity_curve = relationship("BacktestEquityCurve", back_populates="backtest")
|
||||
holdings = relationship("BacktestHolding", back_populates="backtest")
|
||||
transactions = relationship("BacktestTransaction", back_populates="backtest")
|
||||
walkforward_results = relationship("WalkForwardResult", back_populates="backtest")
|
||||
|
||||
|
||||
class BacktestResult(Base):
|
||||
@ -107,3 +108,20 @@ class BacktestTransaction(Base):
|
||||
commission = Column(Numeric(12, 2), nullable=False)
|
||||
|
||||
backtest = relationship("Backtest", back_populates="transactions")
|
||||
|
||||
|
||||
class WalkForwardResult(Base):
|
||||
__tablename__ = "walkforward_results"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
backtest_id = Column(Integer, ForeignKey("backtests.id"), nullable=False)
|
||||
window_index = Column(Integer, nullable=False)
|
||||
train_start = Column(Date, nullable=False)
|
||||
train_end = Column(Date, nullable=False)
|
||||
test_start = Column(Date, nullable=False)
|
||||
test_end = Column(Date, nullable=False)
|
||||
test_return = Column(Numeric(10, 4), nullable=True)
|
||||
test_sharpe = Column(Numeric(10, 4), nullable=True)
|
||||
test_mdd = Column(Numeric(10, 4), nullable=True)
|
||||
|
||||
backtest = relationship("Backtest", back_populates="walkforward_results")
|
||||
|
||||
@ -129,3 +129,31 @@ class TransactionItem(BaseModel):
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class WalkForwardRequest(BaseModel):
|
||||
"""Request to run walk-forward analysis."""
|
||||
train_months: int = Field(default=12, ge=3, le=60)
|
||||
test_months: int = Field(default=3, ge=1, le=24)
|
||||
step_months: int = Field(default=3, ge=1, le=24)
|
||||
|
||||
|
||||
class WalkForwardWindowResult(BaseModel):
|
||||
"""Single walk-forward window result."""
|
||||
window_index: int
|
||||
train_start: date
|
||||
train_end: date
|
||||
test_start: date
|
||||
test_end: date
|
||||
test_return: FloatDecimal | None = None
|
||||
test_sharpe: FloatDecimal | None = None
|
||||
test_mdd: FloatDecimal | None = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class WalkForwardResponse(BaseModel):
|
||||
"""Walk-forward analysis results."""
|
||||
backtest_id: int
|
||||
windows: List[WalkForwardWindowResult]
|
||||
|
||||
@ -4,6 +4,7 @@ from app.services.backtest.metrics import MetricsCalculator, BacktestMetrics
|
||||
from app.services.backtest.worker import submit_backtest, get_executor_status
|
||||
from app.services.backtest.daily_engine import DailyBacktestEngine
|
||||
from app.services.backtest.trading_portfolio import TradingPortfolio, TradingTransaction
|
||||
from app.services.backtest.walkforward_engine import WalkForwardEngine
|
||||
|
||||
__all__ = [
|
||||
"BacktestEngine",
|
||||
@ -18,4 +19,5 @@ __all__ = [
|
||||
"DailyBacktestEngine",
|
||||
"TradingPortfolio",
|
||||
"TradingTransaction",
|
||||
"WalkForwardEngine",
|
||||
]
|
||||
|
||||
235
backend/app/services/backtest/walkforward_engine.py
Normal file
235
backend/app/services/backtest/walkforward_engine.py
Normal file
@ -0,0 +1,235 @@
|
||||
"""
|
||||
Walk-forward analysis engine.
|
||||
|
||||
Splits backtest period into rolling train/test windows and runs
|
||||
the existing BacktestEngine (or DailyBacktestEngine) on each test window.
|
||||
Train window results are used for validation only (no parameter optimisation yet).
|
||||
"""
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
from typing import List
|
||||
|
||||
from dateutil.relativedelta import relativedelta
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.backtest import (
|
||||
Backtest,
|
||||
BacktestResult,
|
||||
BacktestEquityCurve,
|
||||
WalkForwardResult,
|
||||
)
|
||||
from app.services.backtest.engine import BacktestEngine
|
||||
from app.services.backtest.metrics import MetricsCalculator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Window:
|
||||
index: int
|
||||
train_start: date
|
||||
train_end: date
|
||||
test_start: date
|
||||
test_end: date
|
||||
|
||||
|
||||
class WalkForwardEngine:
|
||||
"""
|
||||
Walk-forward analysis using existing BacktestEngine.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
train_months : int – length of in-sample (training) window
|
||||
test_months : int – length of out-of-sample (test) window
|
||||
step_months : int – how far the window slides each iteration
|
||||
"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# public
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def run(
|
||||
self,
|
||||
backtest_id: int,
|
||||
train_months: int = 12,
|
||||
test_months: int = 3,
|
||||
step_months: int = 3,
|
||||
) -> List[WalkForwardResult]:
|
||||
backtest = self.db.query(Backtest).get(backtest_id)
|
||||
if not backtest:
|
||||
raise ValueError(f"Backtest {backtest_id} not found")
|
||||
|
||||
windows = self._generate_windows(
|
||||
start=backtest.start_date,
|
||||
end=backtest.end_date,
|
||||
train_months=train_months,
|
||||
test_months=test_months,
|
||||
step_months=step_months,
|
||||
)
|
||||
|
||||
if not windows:
|
||||
raise ValueError(
|
||||
"기간이 너무 짧아 walk-forward 윈도우를 생성할 수 없습니다. "
|
||||
f"최소 {train_months + test_months}개월 필요"
|
||||
)
|
||||
|
||||
# Delete previous walk-forward results for this backtest
|
||||
self.db.query(WalkForwardResult).filter(
|
||||
WalkForwardResult.backtest_id == backtest_id
|
||||
).delete()
|
||||
self.db.flush()
|
||||
|
||||
engine = BacktestEngine(self.db)
|
||||
results: List[WalkForwardResult] = []
|
||||
|
||||
for win in windows:
|
||||
logger.info(
|
||||
f"Walk-forward window {win.index}: "
|
||||
f"test {win.test_start} ~ {win.test_end}"
|
||||
)
|
||||
|
||||
test_return, test_sharpe, test_mdd = self._run_window(
|
||||
engine, backtest, win
|
||||
)
|
||||
|
||||
wf = WalkForwardResult(
|
||||
backtest_id=backtest_id,
|
||||
window_index=win.index,
|
||||
train_start=win.train_start,
|
||||
train_end=win.train_end,
|
||||
test_start=win.test_start,
|
||||
test_end=win.test_end,
|
||||
test_return=test_return,
|
||||
test_sharpe=test_sharpe,
|
||||
test_mdd=test_mdd,
|
||||
)
|
||||
self.db.add(wf)
|
||||
results.append(wf)
|
||||
|
||||
self.db.commit()
|
||||
return results
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# window generation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _generate_windows(
|
||||
start: date,
|
||||
end: date,
|
||||
train_months: int,
|
||||
test_months: int,
|
||||
step_months: int,
|
||||
) -> List[Window]:
|
||||
windows: List[Window] = []
|
||||
idx = 0
|
||||
cursor = start
|
||||
|
||||
while True:
|
||||
train_start = cursor
|
||||
train_end = train_start + relativedelta(months=train_months) - relativedelta(days=1)
|
||||
test_start = train_end + relativedelta(days=1)
|
||||
test_end = test_start + relativedelta(months=test_months) - relativedelta(days=1)
|
||||
|
||||
if test_end > end:
|
||||
# Allow partial last window if test_start is before end
|
||||
if test_start <= end:
|
||||
test_end = end
|
||||
else:
|
||||
break
|
||||
|
||||
windows.append(Window(
|
||||
index=idx,
|
||||
train_start=train_start,
|
||||
train_end=train_end,
|
||||
test_start=test_start,
|
||||
test_end=test_end,
|
||||
))
|
||||
idx += 1
|
||||
cursor += relativedelta(months=step_months)
|
||||
|
||||
return windows
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# single window execution
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _run_window(
|
||||
self,
|
||||
engine: BacktestEngine,
|
||||
backtest: Backtest,
|
||||
win: Window,
|
||||
) -> tuple:
|
||||
"""Run backtest on the test window and return (return, sharpe, mdd)."""
|
||||
try:
|
||||
trading_days = engine._get_trading_days(win.test_start, win.test_end)
|
||||
if not trading_days:
|
||||
return (Decimal("0"), Decimal("0"), Decimal("0"))
|
||||
|
||||
benchmark_prices = engine._load_benchmark_prices(
|
||||
backtest.benchmark, win.test_start, win.test_end
|
||||
)
|
||||
|
||||
strategy = engine._create_strategy(
|
||||
backtest.strategy_type,
|
||||
backtest.strategy_params or {},
|
||||
backtest.top_n,
|
||||
)
|
||||
|
||||
from app.services.backtest.portfolio import VirtualPortfolio
|
||||
from app.schemas.strategy import UniverseFilter
|
||||
|
||||
portfolio = VirtualPortfolio(backtest.initial_capital)
|
||||
|
||||
rebalance_dates = engine._generate_rebalance_dates(
|
||||
win.test_start, win.test_end, backtest.rebalance_period,
|
||||
)
|
||||
|
||||
initial_benchmark = benchmark_prices.get(trading_days[0], Decimal("1"))
|
||||
if initial_benchmark == 0:
|
||||
initial_benchmark = Decimal("1")
|
||||
|
||||
equity_curve: List[Decimal] = []
|
||||
benchmark_curve: List[Decimal] = []
|
||||
|
||||
for trading_date in trading_days:
|
||||
prices = engine._get_prices_for_date(trading_date)
|
||||
names = engine._get_stock_names()
|
||||
|
||||
if trading_date in rebalance_dates:
|
||||
target_stocks = strategy.run(
|
||||
universe_filter=UniverseFilter(),
|
||||
top_n=backtest.top_n,
|
||||
base_date=trading_date,
|
||||
)
|
||||
target_tickers = [s.ticker for s in target_stocks.stocks]
|
||||
portfolio.rebalance(
|
||||
target_tickers=target_tickers,
|
||||
prices=prices,
|
||||
names=names,
|
||||
commission_rate=backtest.commission_rate,
|
||||
slippage_rate=backtest.slippage_rate,
|
||||
)
|
||||
|
||||
portfolio_value = portfolio.get_value(prices)
|
||||
benchmark_value = benchmark_prices.get(trading_date, initial_benchmark)
|
||||
normalized_benchmark = (
|
||||
benchmark_value / initial_benchmark * backtest.initial_capital
|
||||
)
|
||||
equity_curve.append(Decimal(str(portfolio_value)))
|
||||
benchmark_curve.append(Decimal(str(normalized_benchmark)))
|
||||
|
||||
if len(equity_curve) < 2:
|
||||
return (Decimal("0"), Decimal("0"), Decimal("0"))
|
||||
|
||||
metrics = MetricsCalculator.calculate_all(equity_curve, benchmark_curve)
|
||||
return (metrics.total_return, metrics.sharpe_ratio, metrics.mdd)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Walk-forward window {win.index} failed: {e}")
|
||||
return (Decimal("0"), Decimal("0"), Decimal("0"))
|
||||
115
backend/tests/unit/test_walkforward.py
Normal file
115
backend/tests/unit/test_walkforward.py
Normal file
@ -0,0 +1,115 @@
|
||||
"""
|
||||
Unit tests for WalkForwardEngine window generation logic.
|
||||
"""
|
||||
from datetime import date
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.backtest.walkforward_engine import WalkForwardEngine, Window
|
||||
|
||||
|
||||
class TestGenerateWindows:
|
||||
"""Test _generate_windows static method."""
|
||||
|
||||
def test_basic_windows(self):
|
||||
"""2-year period with 12m train, 3m test, 3m step -> 4 windows."""
|
||||
windows = WalkForwardEngine._generate_windows(
|
||||
start=date(2020, 1, 1),
|
||||
end=date(2021, 12, 31),
|
||||
train_months=12,
|
||||
test_months=3,
|
||||
step_months=3,
|
||||
)
|
||||
assert len(windows) == 4
|
||||
assert windows[0].index == 0
|
||||
assert windows[0].train_start == date(2020, 1, 1)
|
||||
assert windows[0].train_end == date(2020, 12, 31)
|
||||
assert windows[0].test_start == date(2021, 1, 1)
|
||||
assert windows[0].test_end == date(2021, 3, 31)
|
||||
|
||||
def test_single_window(self):
|
||||
"""Exactly 15 months -> 1 window with 12m train + 3m test."""
|
||||
windows = WalkForwardEngine._generate_windows(
|
||||
start=date(2020, 1, 1),
|
||||
end=date(2021, 3, 31),
|
||||
train_months=12,
|
||||
test_months=3,
|
||||
step_months=3,
|
||||
)
|
||||
assert len(windows) == 1
|
||||
assert windows[0].train_start == date(2020, 1, 1)
|
||||
assert windows[0].test_end == date(2021, 3, 31)
|
||||
|
||||
def test_no_windows_period_too_short(self):
|
||||
"""Period shorter than train + test -> 0 windows."""
|
||||
windows = WalkForwardEngine._generate_windows(
|
||||
start=date(2020, 1, 1),
|
||||
end=date(2020, 12, 31),
|
||||
train_months=12,
|
||||
test_months=3,
|
||||
step_months=3,
|
||||
)
|
||||
assert len(windows) == 0
|
||||
|
||||
def test_partial_last_window(self):
|
||||
"""Last window with partial test period is included."""
|
||||
windows = WalkForwardEngine._generate_windows(
|
||||
start=date(2020, 1, 1),
|
||||
end=date(2021, 2, 15),
|
||||
train_months=12,
|
||||
test_months=3,
|
||||
step_months=3,
|
||||
)
|
||||
assert len(windows) == 1
|
||||
assert windows[0].test_end == date(2021, 2, 15)
|
||||
|
||||
def test_step_larger_than_test(self):
|
||||
"""step_months > test_months creates non-overlapping test windows."""
|
||||
windows = WalkForwardEngine._generate_windows(
|
||||
start=date(2019, 1, 1),
|
||||
end=date(2022, 12, 31),
|
||||
train_months=12,
|
||||
test_months=3,
|
||||
step_months=6,
|
||||
)
|
||||
assert len(windows) >= 2
|
||||
# test windows should not overlap
|
||||
for i in range(1, len(windows)):
|
||||
assert windows[i].test_start > windows[i - 1].test_end
|
||||
|
||||
def test_monthly_step(self):
|
||||
"""step_months=1 creates many overlapping windows."""
|
||||
windows = WalkForwardEngine._generate_windows(
|
||||
start=date(2020, 1, 1),
|
||||
end=date(2021, 6, 30),
|
||||
train_months=6,
|
||||
test_months=3,
|
||||
step_months=1,
|
||||
)
|
||||
assert len(windows) >= 9
|
||||
|
||||
def test_window_indices_sequential(self):
|
||||
"""Window indices should be sequential starting from 0."""
|
||||
windows = WalkForwardEngine._generate_windows(
|
||||
start=date(2019, 1, 1),
|
||||
end=date(2022, 12, 31),
|
||||
train_months=12,
|
||||
test_months=3,
|
||||
step_months=3,
|
||||
)
|
||||
for i, w in enumerate(windows):
|
||||
assert w.index == i
|
||||
|
||||
def test_window_dates_consistent(self):
|
||||
"""train_end < test_start and test_start <= test_end for all windows."""
|
||||
windows = WalkForwardEngine._generate_windows(
|
||||
start=date(2019, 1, 1),
|
||||
end=date(2023, 12, 31),
|
||||
train_months=12,
|
||||
test_months=6,
|
||||
step_months=3,
|
||||
)
|
||||
for w in windows:
|
||||
assert w.train_start < w.train_end
|
||||
assert w.train_end < w.test_start
|
||||
assert w.test_start <= w.test_end
|
||||
@ -67,6 +67,22 @@ interface TransactionItem {
|
||||
commission: number;
|
||||
}
|
||||
|
||||
interface WalkForwardWindow {
|
||||
window_index: number;
|
||||
train_start: string;
|
||||
train_end: string;
|
||||
test_start: string;
|
||||
test_end: string;
|
||||
test_return: number | null;
|
||||
test_sharpe: number | null;
|
||||
test_mdd: number | null;
|
||||
}
|
||||
|
||||
interface WalkForwardResponse {
|
||||
backtest_id: number;
|
||||
windows: WalkForwardWindow[];
|
||||
}
|
||||
|
||||
const strategyLabels: Record<string, string> = {
|
||||
multi_factor: '멀티 팩터',
|
||||
quality: '슈퍼 퀄리티',
|
||||
@ -90,8 +106,13 @@ export default function BacktestDetailPage() {
|
||||
const [equityCurve, setEquityCurve] = useState<EquityCurvePoint[]>([]);
|
||||
const [holdings, setHoldings] = useState<RebalanceHoldings[]>([]);
|
||||
const [transactions, setTransactions] = useState<TransactionItem[]>([]);
|
||||
const [activeTab, setActiveTab] = useState<'holdings' | 'transactions'>('holdings');
|
||||
const [activeTab, setActiveTab] = useState<'holdings' | 'transactions' | 'walkforward'>('holdings');
|
||||
const [selectedRebalance, setSelectedRebalance] = useState<string | null>(null);
|
||||
const [wfWindows, setWfWindows] = useState<WalkForwardWindow[]>([]);
|
||||
const [wfLoading, setWfLoading] = useState(false);
|
||||
const [wfTrainMonths, setWfTrainMonths] = useState(12);
|
||||
const [wfTestMonths, setWfTestMonths] = useState(3);
|
||||
const [wfStepMonths, setWfStepMonths] = useState(3);
|
||||
|
||||
const fetchBacktest = useCallback(async () => {
|
||||
try {
|
||||
@ -141,6 +162,37 @@ export default function BacktestDetailPage() {
|
||||
}
|
||||
}, [backtest, fetchBacktest]);
|
||||
|
||||
const fetchWalkForward = useCallback(async () => {
|
||||
try {
|
||||
const data = await api.get<WalkForwardResponse>(`/api/backtest/${backtestId}/walkforward`);
|
||||
setWfWindows(data.windows);
|
||||
} catch {
|
||||
setWfWindows([]);
|
||||
}
|
||||
}, [backtestId]);
|
||||
|
||||
useEffect(() => {
|
||||
if (activeTab === 'walkforward' && backtest?.status === 'completed') {
|
||||
fetchWalkForward();
|
||||
}
|
||||
}, [activeTab, backtest?.status, fetchWalkForward]);
|
||||
|
||||
const runWalkForward = async () => {
|
||||
setWfLoading(true);
|
||||
try {
|
||||
await api.post(`/api/backtest/${backtestId}/walkforward`, {
|
||||
train_months: wfTrainMonths,
|
||||
test_months: wfTestMonths,
|
||||
step_months: wfStepMonths,
|
||||
});
|
||||
await fetchWalkForward();
|
||||
} catch (err) {
|
||||
console.error('Walk-forward failed:', err);
|
||||
} finally {
|
||||
setWfLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const formatNumber = (value: number | null | undefined, decimals: number = 2) => {
|
||||
if (value === null || value === undefined) return '-';
|
||||
return value.toFixed(decimals);
|
||||
@ -356,6 +408,16 @@ export default function BacktestDetailPage() {
|
||||
>
|
||||
거래 내역
|
||||
</button>
|
||||
<button
|
||||
onClick={() => setActiveTab('walkforward')}
|
||||
className={`px-6 py-3 text-sm font-medium ${
|
||||
activeTab === 'walkforward'
|
||||
? 'border-b-2 border-primary text-primary'
|
||||
: 'text-muted-foreground hover:text-foreground'
|
||||
}`}
|
||||
>
|
||||
Walk-forward
|
||||
</button>
|
||||
</nav>
|
||||
</div>
|
||||
|
||||
@ -448,6 +510,116 @@ export default function BacktestDetailPage() {
|
||||
</table>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Walk-forward Tab */}
|
||||
{activeTab === 'walkforward' && (
|
||||
<CardContent className="p-4">
|
||||
<div className="flex flex-wrap gap-4 items-end mb-6">
|
||||
<div className="space-y-1">
|
||||
<Label htmlFor="wf-train">학습 기간 (월)</Label>
|
||||
<input
|
||||
id="wf-train"
|
||||
type="number"
|
||||
min={3}
|
||||
max={60}
|
||||
value={wfTrainMonths}
|
||||
onChange={(e) => setWfTrainMonths(Number(e.target.value))}
|
||||
className="flex h-10 w-24 rounded-md border border-input bg-background px-3 py-2 text-sm"
|
||||
/>
|
||||
</div>
|
||||
<div className="space-y-1">
|
||||
<Label htmlFor="wf-test">검증 기간 (월)</Label>
|
||||
<input
|
||||
id="wf-test"
|
||||
type="number"
|
||||
min={1}
|
||||
max={24}
|
||||
value={wfTestMonths}
|
||||
onChange={(e) => setWfTestMonths(Number(e.target.value))}
|
||||
className="flex h-10 w-24 rounded-md border border-input bg-background px-3 py-2 text-sm"
|
||||
/>
|
||||
</div>
|
||||
<div className="space-y-1">
|
||||
<Label htmlFor="wf-step">스텝 (월)</Label>
|
||||
<input
|
||||
id="wf-step"
|
||||
type="number"
|
||||
min={1}
|
||||
max={24}
|
||||
value={wfStepMonths}
|
||||
onChange={(e) => setWfStepMonths(Number(e.target.value))}
|
||||
className="flex h-10 w-24 rounded-md border border-input bg-background px-3 py-2 text-sm"
|
||||
/>
|
||||
</div>
|
||||
<button
|
||||
onClick={runWalkForward}
|
||||
disabled={wfLoading}
|
||||
className="h-10 px-4 rounded-md bg-primary text-primary-foreground text-sm font-medium hover:bg-primary/90 disabled:opacity-50"
|
||||
>
|
||||
{wfLoading ? '분석 중...' : '분석 실행'}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{wfWindows.length > 0 && (
|
||||
<>
|
||||
<div className="mb-6">
|
||||
<h3 className="text-sm font-medium text-muted-foreground mb-2">누적 수익률 추이</h3>
|
||||
<AreaChart
|
||||
data={wfWindows.map((w) => {
|
||||
const cumReturn = wfWindows
|
||||
.filter((x) => x.window_index <= w.window_index)
|
||||
.reduce((acc, x) => acc * (1 + (x.test_return ?? 0) / 100), 1);
|
||||
return {
|
||||
date: w.test_end,
|
||||
value: (cumReturn - 1) * 100,
|
||||
};
|
||||
})}
|
||||
height={200}
|
||||
color="#3b82f6"
|
||||
showLegend={false}
|
||||
formatValue={(v) => `${v.toFixed(2)}%`}
|
||||
formatXAxis={(v) => v.slice(5)}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="overflow-x-auto">
|
||||
<table className="w-full">
|
||||
<thead className="bg-muted">
|
||||
<tr>
|
||||
<th scope="col" className="px-4 py-3 text-center text-sm font-medium text-muted-foreground">#</th>
|
||||
<th scope="col" className="px-4 py-3 text-left text-sm font-medium text-muted-foreground">학습 기간</th>
|
||||
<th scope="col" className="px-4 py-3 text-left text-sm font-medium text-muted-foreground">검증 기간</th>
|
||||
<th scope="col" className="px-4 py-3 text-right text-sm font-medium text-muted-foreground">수익률</th>
|
||||
<th scope="col" className="px-4 py-3 text-right text-sm font-medium text-muted-foreground">샤프</th>
|
||||
<th scope="col" className="px-4 py-3 text-right text-sm font-medium text-muted-foreground">MDD</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody className="divide-y divide-border">
|
||||
{wfWindows.map((w) => (
|
||||
<tr key={w.window_index} className="hover:bg-muted/50">
|
||||
<td className="px-4 py-3 text-sm text-center">{w.window_index + 1}</td>
|
||||
<td className="px-4 py-3 text-sm">{w.train_start} ~ {w.train_end}</td>
|
||||
<td className="px-4 py-3 text-sm">{w.test_start} ~ {w.test_end}</td>
|
||||
<td className={`px-4 py-3 text-sm text-right font-medium ${(w.test_return ?? 0) >= 0 ? 'text-green-600' : 'text-red-600'}`}>
|
||||
{formatNumber(w.test_return)}%
|
||||
</td>
|
||||
<td className="px-4 py-3 text-sm text-right">{formatNumber(w.test_sharpe)}</td>
|
||||
<td className="px-4 py-3 text-sm text-right text-red-600">{formatNumber(w.test_mdd)}%</td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
|
||||
{wfWindows.length === 0 && !wfLoading && (
|
||||
<div className="py-8 text-center text-muted-foreground">
|
||||
파라미터를 설정하고 분석을 실행해주세요.
|
||||
</div>
|
||||
)}
|
||||
</CardContent>
|
||||
)}
|
||||
</Card>
|
||||
</>
|
||||
)}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user