feat: add walk-forward analysis for backtests

- Add WalkForwardResult model with train/test window metrics
- Create WalkForwardEngine that reuses existing BacktestEngine
  with rolling train/test window splits
- Add POST/GET /api/backtest/{id}/walkforward endpoints
- Add Walk-forward tab to backtest detail page with parameter
  controls, cumulative return chart, and window results table
- Add Alembic migration for walkforward_results table
- Add 8 unit tests for window generation logic (100 total passed)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
머니페니 2026-03-18 22:33:41 +09:00
parent 741b7fa7dd
commit f818bd3290
8 changed files with 698 additions and 1 deletions

View File

@ -0,0 +1,45 @@
"""add walkforward_results table
Revision ID: 59807c4e84ee
Revises: b7c8d9e0f1a2
Create Date: 2026-03-18 22:28:53.955519
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '59807c4e84ee'
down_revision: Union[str, None] = 'b7c8d9e0f1a2'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('walkforward_results',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('backtest_id', sa.Integer(), nullable=False),
sa.Column('window_index', sa.Integer(), nullable=False),
sa.Column('train_start', sa.Date(), nullable=False),
sa.Column('train_end', sa.Date(), nullable=False),
sa.Column('test_start', sa.Date(), nullable=False),
sa.Column('test_end', sa.Date(), nullable=False),
sa.Column('test_return', sa.Numeric(precision=10, scale=4), nullable=True),
sa.Column('test_sharpe', sa.Numeric(precision=10, scale=4), nullable=True),
sa.Column('test_mdd', sa.Numeric(precision=10, scale=4), nullable=True),
sa.ForeignKeyConstraint(['backtest_id'], ['backtests.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_walkforward_results_id'), 'walkforward_results', ['id'], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_walkforward_results_id'), table_name='walkforward_results')
op.drop_table('walkforward_results')
# ### end Alembic commands ###

View File

@ -11,12 +11,15 @@ from app.api.deps import CurrentUser
from app.models.backtest import (
Backtest, BacktestResult, BacktestEquityCurve,
BacktestHolding, BacktestTransaction, BacktestStatus,
WalkForwardResult,
)
from app.schemas.backtest import (
BacktestCreate, BacktestResponse, BacktestListItem, BacktestMetrics,
EquityCurvePoint, RebalanceHoldings, HoldingItem, TransactionItem,
WalkForwardRequest, WalkForwardWindowResult, WalkForwardResponse,
)
from app.services.backtest import submit_backtest
from app.services.backtest.walkforward_engine import WalkForwardEngine
from app.services.rebalance import RebalanceService
router = APIRouter(prefix="/api/backtest", tags=["backtest"])
@ -250,6 +253,82 @@ async def get_transactions(
]
@router.post("/{backtest_id}/walkforward", response_model=dict)
async def run_walkforward(
backtest_id: int,
request: WalkForwardRequest,
current_user: CurrentUser,
db: Session = Depends(get_db),
):
"""Run walk-forward analysis on a completed backtest."""
backtest = db.query(Backtest).filter(Backtest.id == backtest_id).first()
if not backtest:
raise HTTPException(status_code=404, detail="Backtest not found")
if backtest.user_id != current_user.id:
raise HTTPException(status_code=403, detail="Not authorized")
if backtest.status != BacktestStatus.COMPLETED:
raise HTTPException(
status_code=400,
detail="백테스트가 완료된 상태에서만 walk-forward 분석이 가능합니다",
)
engine = WalkForwardEngine(db)
try:
engine.run(
backtest_id=backtest_id,
train_months=request.train_months,
test_months=request.test_months,
step_months=request.step_months,
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return {"status": "completed", "backtest_id": backtest_id}
@router.get("/{backtest_id}/walkforward", response_model=WalkForwardResponse)
async def get_walkforward(
backtest_id: int,
current_user: CurrentUser,
db: Session = Depends(get_db),
):
"""Get walk-forward analysis results."""
backtest = db.query(Backtest).filter(Backtest.id == backtest_id).first()
if not backtest:
raise HTTPException(status_code=404, detail="Backtest not found")
if backtest.user_id != current_user.id:
raise HTTPException(status_code=403, detail="Not authorized")
results = (
db.query(WalkForwardResult)
.filter(WalkForwardResult.backtest_id == backtest_id)
.order_by(WalkForwardResult.window_index)
.all()
)
return WalkForwardResponse(
backtest_id=backtest_id,
windows=[
WalkForwardWindowResult(
window_index=r.window_index,
train_start=r.train_start,
train_end=r.train_end,
test_start=r.test_start,
test_end=r.test_end,
test_return=r.test_return,
test_sharpe=r.test_sharpe,
test_mdd=r.test_mdd,
)
for r in results
],
)
@router.delete("/{backtest_id}")
async def delete_backtest(
backtest_id: int,
@ -266,6 +345,9 @@ async def delete_backtest(
raise HTTPException(status_code=403, detail="Not authorized")
# Delete related data
db.query(WalkForwardResult).filter(
WalkForwardResult.backtest_id == backtest_id
).delete()
db.query(BacktestTransaction).filter(
BacktestTransaction.backtest_id == backtest_id
).delete()

View File

@ -51,6 +51,7 @@ class Backtest(Base):
equity_curve = relationship("BacktestEquityCurve", back_populates="backtest")
holdings = relationship("BacktestHolding", back_populates="backtest")
transactions = relationship("BacktestTransaction", back_populates="backtest")
walkforward_results = relationship("WalkForwardResult", back_populates="backtest")
class BacktestResult(Base):
@ -107,3 +108,20 @@ class BacktestTransaction(Base):
commission = Column(Numeric(12, 2), nullable=False)
backtest = relationship("Backtest", back_populates="transactions")
class WalkForwardResult(Base):
__tablename__ = "walkforward_results"
id = Column(Integer, primary_key=True, index=True)
backtest_id = Column(Integer, ForeignKey("backtests.id"), nullable=False)
window_index = Column(Integer, nullable=False)
train_start = Column(Date, nullable=False)
train_end = Column(Date, nullable=False)
test_start = Column(Date, nullable=False)
test_end = Column(Date, nullable=False)
test_return = Column(Numeric(10, 4), nullable=True)
test_sharpe = Column(Numeric(10, 4), nullable=True)
test_mdd = Column(Numeric(10, 4), nullable=True)
backtest = relationship("Backtest", back_populates="walkforward_results")

View File

@ -129,3 +129,31 @@ class TransactionItem(BaseModel):
class Config:
from_attributes = True
class WalkForwardRequest(BaseModel):
"""Request to run walk-forward analysis."""
train_months: int = Field(default=12, ge=3, le=60)
test_months: int = Field(default=3, ge=1, le=24)
step_months: int = Field(default=3, ge=1, le=24)
class WalkForwardWindowResult(BaseModel):
"""Single walk-forward window result."""
window_index: int
train_start: date
train_end: date
test_start: date
test_end: date
test_return: FloatDecimal | None = None
test_sharpe: FloatDecimal | None = None
test_mdd: FloatDecimal | None = None
class Config:
from_attributes = True
class WalkForwardResponse(BaseModel):
"""Walk-forward analysis results."""
backtest_id: int
windows: List[WalkForwardWindowResult]

View File

@ -4,6 +4,7 @@ from app.services.backtest.metrics import MetricsCalculator, BacktestMetrics
from app.services.backtest.worker import submit_backtest, get_executor_status
from app.services.backtest.daily_engine import DailyBacktestEngine
from app.services.backtest.trading_portfolio import TradingPortfolio, TradingTransaction
from app.services.backtest.walkforward_engine import WalkForwardEngine
__all__ = [
"BacktestEngine",
@ -18,4 +19,5 @@ __all__ = [
"DailyBacktestEngine",
"TradingPortfolio",
"TradingTransaction",
"WalkForwardEngine",
]

View File

@ -0,0 +1,235 @@
"""
Walk-forward analysis engine.
Splits backtest period into rolling train/test windows and runs
the existing BacktestEngine (or DailyBacktestEngine) on each test window.
Train window results are used for validation only (no parameter optimisation yet).
"""
import logging
from dataclasses import dataclass
from datetime import date
from decimal import Decimal
from typing import List
from dateutil.relativedelta import relativedelta
from sqlalchemy.orm import Session
from app.models.backtest import (
Backtest,
BacktestResult,
BacktestEquityCurve,
WalkForwardResult,
)
from app.services.backtest.engine import BacktestEngine
from app.services.backtest.metrics import MetricsCalculator
logger = logging.getLogger(__name__)
@dataclass
class Window:
index: int
train_start: date
train_end: date
test_start: date
test_end: date
class WalkForwardEngine:
"""
Walk-forward analysis using existing BacktestEngine.
Parameters
----------
train_months : int length of in-sample (training) window
test_months : int length of out-of-sample (test) window
step_months : int how far the window slides each iteration
"""
def __init__(self, db: Session):
self.db = db
# ------------------------------------------------------------------
# public
# ------------------------------------------------------------------
def run(
self,
backtest_id: int,
train_months: int = 12,
test_months: int = 3,
step_months: int = 3,
) -> List[WalkForwardResult]:
backtest = self.db.query(Backtest).get(backtest_id)
if not backtest:
raise ValueError(f"Backtest {backtest_id} not found")
windows = self._generate_windows(
start=backtest.start_date,
end=backtest.end_date,
train_months=train_months,
test_months=test_months,
step_months=step_months,
)
if not windows:
raise ValueError(
"기간이 너무 짧아 walk-forward 윈도우를 생성할 수 없습니다. "
f"최소 {train_months + test_months}개월 필요"
)
# Delete previous walk-forward results for this backtest
self.db.query(WalkForwardResult).filter(
WalkForwardResult.backtest_id == backtest_id
).delete()
self.db.flush()
engine = BacktestEngine(self.db)
results: List[WalkForwardResult] = []
for win in windows:
logger.info(
f"Walk-forward window {win.index}: "
f"test {win.test_start} ~ {win.test_end}"
)
test_return, test_sharpe, test_mdd = self._run_window(
engine, backtest, win
)
wf = WalkForwardResult(
backtest_id=backtest_id,
window_index=win.index,
train_start=win.train_start,
train_end=win.train_end,
test_start=win.test_start,
test_end=win.test_end,
test_return=test_return,
test_sharpe=test_sharpe,
test_mdd=test_mdd,
)
self.db.add(wf)
results.append(wf)
self.db.commit()
return results
# ------------------------------------------------------------------
# window generation
# ------------------------------------------------------------------
@staticmethod
def _generate_windows(
start: date,
end: date,
train_months: int,
test_months: int,
step_months: int,
) -> List[Window]:
windows: List[Window] = []
idx = 0
cursor = start
while True:
train_start = cursor
train_end = train_start + relativedelta(months=train_months) - relativedelta(days=1)
test_start = train_end + relativedelta(days=1)
test_end = test_start + relativedelta(months=test_months) - relativedelta(days=1)
if test_end > end:
# Allow partial last window if test_start is before end
if test_start <= end:
test_end = end
else:
break
windows.append(Window(
index=idx,
train_start=train_start,
train_end=train_end,
test_start=test_start,
test_end=test_end,
))
idx += 1
cursor += relativedelta(months=step_months)
return windows
# ------------------------------------------------------------------
# single window execution
# ------------------------------------------------------------------
def _run_window(
self,
engine: BacktestEngine,
backtest: Backtest,
win: Window,
) -> tuple:
"""Run backtest on the test window and return (return, sharpe, mdd)."""
try:
trading_days = engine._get_trading_days(win.test_start, win.test_end)
if not trading_days:
return (Decimal("0"), Decimal("0"), Decimal("0"))
benchmark_prices = engine._load_benchmark_prices(
backtest.benchmark, win.test_start, win.test_end
)
strategy = engine._create_strategy(
backtest.strategy_type,
backtest.strategy_params or {},
backtest.top_n,
)
from app.services.backtest.portfolio import VirtualPortfolio
from app.schemas.strategy import UniverseFilter
portfolio = VirtualPortfolio(backtest.initial_capital)
rebalance_dates = engine._generate_rebalance_dates(
win.test_start, win.test_end, backtest.rebalance_period,
)
initial_benchmark = benchmark_prices.get(trading_days[0], Decimal("1"))
if initial_benchmark == 0:
initial_benchmark = Decimal("1")
equity_curve: List[Decimal] = []
benchmark_curve: List[Decimal] = []
for trading_date in trading_days:
prices = engine._get_prices_for_date(trading_date)
names = engine._get_stock_names()
if trading_date in rebalance_dates:
target_stocks = strategy.run(
universe_filter=UniverseFilter(),
top_n=backtest.top_n,
base_date=trading_date,
)
target_tickers = [s.ticker for s in target_stocks.stocks]
portfolio.rebalance(
target_tickers=target_tickers,
prices=prices,
names=names,
commission_rate=backtest.commission_rate,
slippage_rate=backtest.slippage_rate,
)
portfolio_value = portfolio.get_value(prices)
benchmark_value = benchmark_prices.get(trading_date, initial_benchmark)
normalized_benchmark = (
benchmark_value / initial_benchmark * backtest.initial_capital
)
equity_curve.append(Decimal(str(portfolio_value)))
benchmark_curve.append(Decimal(str(normalized_benchmark)))
if len(equity_curve) < 2:
return (Decimal("0"), Decimal("0"), Decimal("0"))
metrics = MetricsCalculator.calculate_all(equity_curve, benchmark_curve)
return (metrics.total_return, metrics.sharpe_ratio, metrics.mdd)
except Exception as e:
logger.warning(f"Walk-forward window {win.index} failed: {e}")
return (Decimal("0"), Decimal("0"), Decimal("0"))

View File

@ -0,0 +1,115 @@
"""
Unit tests for WalkForwardEngine window generation logic.
"""
from datetime import date
import pytest
from app.services.backtest.walkforward_engine import WalkForwardEngine, Window
class TestGenerateWindows:
"""Test _generate_windows static method."""
def test_basic_windows(self):
"""2-year period with 12m train, 3m test, 3m step -> 4 windows."""
windows = WalkForwardEngine._generate_windows(
start=date(2020, 1, 1),
end=date(2021, 12, 31),
train_months=12,
test_months=3,
step_months=3,
)
assert len(windows) == 4
assert windows[0].index == 0
assert windows[0].train_start == date(2020, 1, 1)
assert windows[0].train_end == date(2020, 12, 31)
assert windows[0].test_start == date(2021, 1, 1)
assert windows[0].test_end == date(2021, 3, 31)
def test_single_window(self):
"""Exactly 15 months -> 1 window with 12m train + 3m test."""
windows = WalkForwardEngine._generate_windows(
start=date(2020, 1, 1),
end=date(2021, 3, 31),
train_months=12,
test_months=3,
step_months=3,
)
assert len(windows) == 1
assert windows[0].train_start == date(2020, 1, 1)
assert windows[0].test_end == date(2021, 3, 31)
def test_no_windows_period_too_short(self):
"""Period shorter than train + test -> 0 windows."""
windows = WalkForwardEngine._generate_windows(
start=date(2020, 1, 1),
end=date(2020, 12, 31),
train_months=12,
test_months=3,
step_months=3,
)
assert len(windows) == 0
def test_partial_last_window(self):
"""Last window with partial test period is included."""
windows = WalkForwardEngine._generate_windows(
start=date(2020, 1, 1),
end=date(2021, 2, 15),
train_months=12,
test_months=3,
step_months=3,
)
assert len(windows) == 1
assert windows[0].test_end == date(2021, 2, 15)
def test_step_larger_than_test(self):
"""step_months > test_months creates non-overlapping test windows."""
windows = WalkForwardEngine._generate_windows(
start=date(2019, 1, 1),
end=date(2022, 12, 31),
train_months=12,
test_months=3,
step_months=6,
)
assert len(windows) >= 2
# test windows should not overlap
for i in range(1, len(windows)):
assert windows[i].test_start > windows[i - 1].test_end
def test_monthly_step(self):
"""step_months=1 creates many overlapping windows."""
windows = WalkForwardEngine._generate_windows(
start=date(2020, 1, 1),
end=date(2021, 6, 30),
train_months=6,
test_months=3,
step_months=1,
)
assert len(windows) >= 9
def test_window_indices_sequential(self):
"""Window indices should be sequential starting from 0."""
windows = WalkForwardEngine._generate_windows(
start=date(2019, 1, 1),
end=date(2022, 12, 31),
train_months=12,
test_months=3,
step_months=3,
)
for i, w in enumerate(windows):
assert w.index == i
def test_window_dates_consistent(self):
"""train_end < test_start and test_start <= test_end for all windows."""
windows = WalkForwardEngine._generate_windows(
start=date(2019, 1, 1),
end=date(2023, 12, 31),
train_months=12,
test_months=6,
step_months=3,
)
for w in windows:
assert w.train_start < w.train_end
assert w.train_end < w.test_start
assert w.test_start <= w.test_end

View File

@ -67,6 +67,22 @@ interface TransactionItem {
commission: number;
}
interface WalkForwardWindow {
window_index: number;
train_start: string;
train_end: string;
test_start: string;
test_end: string;
test_return: number | null;
test_sharpe: number | null;
test_mdd: number | null;
}
interface WalkForwardResponse {
backtest_id: number;
windows: WalkForwardWindow[];
}
const strategyLabels: Record<string, string> = {
multi_factor: '멀티 팩터',
quality: '슈퍼 퀄리티',
@ -90,8 +106,13 @@ export default function BacktestDetailPage() {
const [equityCurve, setEquityCurve] = useState<EquityCurvePoint[]>([]);
const [holdings, setHoldings] = useState<RebalanceHoldings[]>([]);
const [transactions, setTransactions] = useState<TransactionItem[]>([]);
const [activeTab, setActiveTab] = useState<'holdings' | 'transactions'>('holdings');
const [activeTab, setActiveTab] = useState<'holdings' | 'transactions' | 'walkforward'>('holdings');
const [selectedRebalance, setSelectedRebalance] = useState<string | null>(null);
const [wfWindows, setWfWindows] = useState<WalkForwardWindow[]>([]);
const [wfLoading, setWfLoading] = useState(false);
const [wfTrainMonths, setWfTrainMonths] = useState(12);
const [wfTestMonths, setWfTestMonths] = useState(3);
const [wfStepMonths, setWfStepMonths] = useState(3);
const fetchBacktest = useCallback(async () => {
try {
@ -141,6 +162,37 @@ export default function BacktestDetailPage() {
}
}, [backtest, fetchBacktest]);
const fetchWalkForward = useCallback(async () => {
try {
const data = await api.get<WalkForwardResponse>(`/api/backtest/${backtestId}/walkforward`);
setWfWindows(data.windows);
} catch {
setWfWindows([]);
}
}, [backtestId]);
useEffect(() => {
if (activeTab === 'walkforward' && backtest?.status === 'completed') {
fetchWalkForward();
}
}, [activeTab, backtest?.status, fetchWalkForward]);
const runWalkForward = async () => {
setWfLoading(true);
try {
await api.post(`/api/backtest/${backtestId}/walkforward`, {
train_months: wfTrainMonths,
test_months: wfTestMonths,
step_months: wfStepMonths,
});
await fetchWalkForward();
} catch (err) {
console.error('Walk-forward failed:', err);
} finally {
setWfLoading(false);
}
};
const formatNumber = (value: number | null | undefined, decimals: number = 2) => {
if (value === null || value === undefined) return '-';
return value.toFixed(decimals);
@ -356,6 +408,16 @@ export default function BacktestDetailPage() {
>
</button>
<button
onClick={() => setActiveTab('walkforward')}
className={`px-6 py-3 text-sm font-medium ${
activeTab === 'walkforward'
? 'border-b-2 border-primary text-primary'
: 'text-muted-foreground hover:text-foreground'
}`}
>
Walk-forward
</button>
</nav>
</div>
@ -448,6 +510,116 @@ export default function BacktestDetailPage() {
</table>
</div>
)}
{/* Walk-forward Tab */}
{activeTab === 'walkforward' && (
<CardContent className="p-4">
<div className="flex flex-wrap gap-4 items-end mb-6">
<div className="space-y-1">
<Label htmlFor="wf-train"> ()</Label>
<input
id="wf-train"
type="number"
min={3}
max={60}
value={wfTrainMonths}
onChange={(e) => setWfTrainMonths(Number(e.target.value))}
className="flex h-10 w-24 rounded-md border border-input bg-background px-3 py-2 text-sm"
/>
</div>
<div className="space-y-1">
<Label htmlFor="wf-test"> ()</Label>
<input
id="wf-test"
type="number"
min={1}
max={24}
value={wfTestMonths}
onChange={(e) => setWfTestMonths(Number(e.target.value))}
className="flex h-10 w-24 rounded-md border border-input bg-background px-3 py-2 text-sm"
/>
</div>
<div className="space-y-1">
<Label htmlFor="wf-step"> ()</Label>
<input
id="wf-step"
type="number"
min={1}
max={24}
value={wfStepMonths}
onChange={(e) => setWfStepMonths(Number(e.target.value))}
className="flex h-10 w-24 rounded-md border border-input bg-background px-3 py-2 text-sm"
/>
</div>
<button
onClick={runWalkForward}
disabled={wfLoading}
className="h-10 px-4 rounded-md bg-primary text-primary-foreground text-sm font-medium hover:bg-primary/90 disabled:opacity-50"
>
{wfLoading ? '분석 중...' : '분석 실행'}
</button>
</div>
{wfWindows.length > 0 && (
<>
<div className="mb-6">
<h3 className="text-sm font-medium text-muted-foreground mb-2"> </h3>
<AreaChart
data={wfWindows.map((w) => {
const cumReturn = wfWindows
.filter((x) => x.window_index <= w.window_index)
.reduce((acc, x) => acc * (1 + (x.test_return ?? 0) / 100), 1);
return {
date: w.test_end,
value: (cumReturn - 1) * 100,
};
})}
height={200}
color="#3b82f6"
showLegend={false}
formatValue={(v) => `${v.toFixed(2)}%`}
formatXAxis={(v) => v.slice(5)}
/>
</div>
<div className="overflow-x-auto">
<table className="w-full">
<thead className="bg-muted">
<tr>
<th scope="col" className="px-4 py-3 text-center text-sm font-medium text-muted-foreground">#</th>
<th scope="col" className="px-4 py-3 text-left text-sm font-medium text-muted-foreground"> </th>
<th scope="col" className="px-4 py-3 text-left text-sm font-medium text-muted-foreground"> </th>
<th scope="col" className="px-4 py-3 text-right text-sm font-medium text-muted-foreground"></th>
<th scope="col" className="px-4 py-3 text-right text-sm font-medium text-muted-foreground"></th>
<th scope="col" className="px-4 py-3 text-right text-sm font-medium text-muted-foreground">MDD</th>
</tr>
</thead>
<tbody className="divide-y divide-border">
{wfWindows.map((w) => (
<tr key={w.window_index} className="hover:bg-muted/50">
<td className="px-4 py-3 text-sm text-center">{w.window_index + 1}</td>
<td className="px-4 py-3 text-sm">{w.train_start} ~ {w.train_end}</td>
<td className="px-4 py-3 text-sm">{w.test_start} ~ {w.test_end}</td>
<td className={`px-4 py-3 text-sm text-right font-medium ${(w.test_return ?? 0) >= 0 ? 'text-green-600' : 'text-red-600'}`}>
{formatNumber(w.test_return)}%
</td>
<td className="px-4 py-3 text-sm text-right">{formatNumber(w.test_sharpe)}</td>
<td className="px-4 py-3 text-sm text-right text-red-600">{formatNumber(w.test_mdd)}%</td>
</tr>
))}
</tbody>
</table>
</div>
</>
)}
{wfWindows.length === 0 && !wfLoading && (
<div className="py-8 text-center text-muted-foreground">
.
</div>
)}
</CardContent>
)}
</Card>
</>
)}