2026-02-03 11:51:29 +09:00
|
|
|
"""
|
|
|
|
|
Backtest API endpoints.
|
|
|
|
|
"""
|
|
|
|
|
from typing import List
|
|
|
|
|
|
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException
|
2026-03-18 22:20:29 +09:00
|
|
|
from sqlalchemy.orm import Session, joinedload
|
2026-02-03 11:51:29 +09:00
|
|
|
|
|
|
|
|
from app.core.database import get_db
|
|
|
|
|
from app.api.deps import CurrentUser
|
|
|
|
|
from app.models.backtest import (
|
|
|
|
|
Backtest, BacktestResult, BacktestEquityCurve,
|
|
|
|
|
BacktestHolding, BacktestTransaction, BacktestStatus,
|
2026-03-18 22:33:41 +09:00
|
|
|
WalkForwardResult,
|
2026-02-03 11:51:29 +09:00
|
|
|
)
|
|
|
|
|
from app.schemas.backtest import (
|
|
|
|
|
BacktestCreate, BacktestResponse, BacktestListItem, BacktestMetrics,
|
|
|
|
|
EquityCurvePoint, RebalanceHoldings, HoldingItem, TransactionItem,
|
2026-03-18 22:33:41 +09:00
|
|
|
WalkForwardRequest, WalkForwardWindowResult, WalkForwardResponse,
|
2026-02-03 11:51:29 +09:00
|
|
|
)
|
|
|
|
|
from app.services.backtest import submit_backtest
|
2026-03-18 22:33:41 +09:00
|
|
|
from app.services.backtest.walkforward_engine import WalkForwardEngine
|
2026-02-16 12:50:21 +09:00
|
|
|
from app.services.rebalance import RebalanceService
|
2026-02-03 11:51:29 +09:00
|
|
|
|
|
|
|
|
# Every endpoint in this module is mounted under /api/backtest and is
# listed under the "backtest" tag in the generated OpenAPI schema.
router = APIRouter(
    prefix="/api/backtest",
    tags=["backtest"],
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("", response_model=dict)
async def create_backtest(
    request: BacktestCreate,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Create and start a new backtest.

    Persists a ``Backtest`` row owned by the current user with status
    ``BacktestStatus.PENDING``. It then hands the new id to
    ``submit_backtest`` for background execution.

    Returns:
        ``{"id": <new backtest id>, "status": "pending"}``.
    """
    # Configuration fields are copied straight from the request payload.
    settings = dict(
        strategy_type=request.strategy_type,
        strategy_params=request.strategy_params,
        start_date=request.start_date,
        end_date=request.end_date,
        rebalance_period=request.rebalance_period,
        initial_capital=request.initial_capital,
        commission_rate=request.commission_rate,
        slippage_rate=request.slippage_rate,
        benchmark=request.benchmark,
        top_n=request.top_n,
    )
    new_backtest = Backtest(
        user_id=current_user.id,
        **settings,
        status=BacktestStatus.PENDING,
    )

    db.add(new_backtest)
    db.commit()
    # Reload the row so its database-generated primary key is populated.
    db.refresh(new_backtest)

    new_id = new_backtest.id
    # Queue the run for background execution; this handler does not wait for it.
    submit_backtest(new_id)
    return {"id": new_id, "status": "pending"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("", response_model=List[BacktestListItem])
async def list_backtests(
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """List all backtests for current user.

    Backtests are returned newest first (``created_at`` descending). The
    ``result`` relationship is eager-loaded with a join, so the headline
    metrics can be read without a separate query per row. Those metrics are
    ``None`` for backtests that have no result yet.
    """
    owned = (
        db.query(Backtest)
        .options(joinedload(Backtest.result))
        .filter(Backtest.user_id == current_user.id)
        .order_by(Backtest.created_at.desc())
        .all()
    )

    def summarize(bt):
        # Metrics are only available once a result row is attached.
        res = bt.result
        return BacktestListItem(
            id=bt.id,
            strategy_type=bt.strategy_type,
            start_date=bt.start_date,
            end_date=bt.end_date,
            rebalance_period=bt.rebalance_period.value,
            status=bt.status.value,
            created_at=bt.created_at,
            total_return=res.total_return if res else None,
            cagr=res.cagr if res else None,
            mdd=res.mdd if res else None,
        )

    return [summarize(bt) for bt in owned]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/{backtest_id}", response_model=BacktestResponse)
async def get_backtest(
    backtest_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Get backtest details and results.

    The ``result`` field of the response is ``None`` when the backtest has
    no attached result row.

    Raises:
        HTTPException: 404 if no backtest has ``backtest_id``; 403 if it
            belongs to another user.
    """
    bt = db.query(Backtest).filter(Backtest.id == backtest_id).first()
    if not bt:
        raise HTTPException(status_code=404, detail="Backtest not found")
    if bt.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="Not authorized")

    res = bt.result
    metrics = (
        BacktestMetrics(
            total_return=res.total_return,
            cagr=res.cagr,
            mdd=res.mdd,
            sharpe_ratio=res.sharpe_ratio,
            volatility=res.volatility,
            benchmark_return=res.benchmark_return,
            excess_return=res.excess_return,
        )
        if res
        else None
    )

    return BacktestResponse(
        id=bt.id,
        user_id=bt.user_id,
        strategy_type=bt.strategy_type,
        # Missing params are reported as an empty mapping, never as null.
        strategy_params=bt.strategy_params or {},
        start_date=bt.start_date,
        end_date=bt.end_date,
        rebalance_period=bt.rebalance_period.value,
        initial_capital=bt.initial_capital,
        commission_rate=bt.commission_rate,
        slippage_rate=bt.slippage_rate,
        benchmark=bt.benchmark,
        status=bt.status.value,
        created_at=bt.created_at,
        completed_at=bt.completed_at,
        error_message=bt.error_message,
        result=metrics,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/{backtest_id}/equity-curve", response_model=List[EquityCurvePoint])
async def get_equity_curve(
    backtest_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Get equity curve data for chart.

    Returns one point per stored ``BacktestEquityCurve`` row, in ascending
    date order.

    Raises:
        HTTPException: 404 if no backtest has ``backtest_id``; 403 if it
            belongs to another user.
    """
    bt = db.query(Backtest).filter(Backtest.id == backtest_id).first()
    if not bt:
        raise HTTPException(status_code=404, detail="Backtest not found")
    if bt.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="Not authorized")

    rows = (
        db.query(BacktestEquityCurve)
        .filter(BacktestEquityCurve.backtest_id == backtest_id)
        .order_by(BacktestEquityCurve.date)
        .all()
    )

    def to_point(row):
        return EquityCurvePoint(
            date=row.date,
            portfolio_value=row.portfolio_value,
            benchmark_value=row.benchmark_value,
            drawdown=row.drawdown,
        )

    return list(map(to_point, rows))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/{backtest_id}/holdings", response_model=List[RebalanceHoldings])
async def get_holdings(
    backtest_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Get holdings at each rebalance date.

    Holdings are grouped by ``rebalance_date``. Groups are returned in
    ascending date order, and within a group positions are ordered by
    descending weight, as fetched from the query.

    Raises:
        HTTPException: 404 if no backtest has ``backtest_id``; 403 if it
            belongs to another user.
    """
    bt = db.query(Backtest).filter(Backtest.id == backtest_id).first()
    if not bt:
        raise HTTPException(status_code=404, detail="Backtest not found")
    if bt.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="Not authorized")

    rows = (
        db.query(BacktestHolding)
        .filter(BacktestHolding.backtest_id == backtest_id)
        .order_by(BacktestHolding.rebalance_date, BacktestHolding.weight.desc())
        .all()
    )

    by_date = {}
    for row in rows:
        by_date.setdefault(row.rebalance_date, []).append(
            HoldingItem(
                ticker=row.ticker,
                name=row.name,
                weight=row.weight,
                shares=row.shares,
                price=row.price,
            )
        )

    return [
        RebalanceHoldings(rebalance_date=day, holdings=by_date[day])
        for day in sorted(by_date)
    ]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/{backtest_id}/transactions", response_model=List[TransactionItem])
async def get_transactions(
    backtest_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Get all transactions.

    Transactions are ordered by date and then by id. Each one is annotated
    with a stock name resolved through ``RebalanceService.get_stock_names``.

    Raises:
        HTTPException: 404 if no backtest has ``backtest_id``; 403 if it
            belongs to another user.
    """
    bt = db.query(Backtest).filter(Backtest.id == backtest_id).first()
    if not bt:
        raise HTTPException(status_code=404, detail="Backtest not found")
    if bt.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="Not authorized")

    txns = (
        db.query(BacktestTransaction)
        .filter(BacktestTransaction.backtest_id == backtest_id)
        .order_by(BacktestTransaction.date, BacktestTransaction.id)
        .all()
    )

    # Resolve display names for all distinct tickers in a single call.
    # Tickers with no known name map to None via .get().
    distinct_tickers = list({t.ticker for t in txns})
    stock_names = RebalanceService(db).get_stock_names(distinct_tickers)

    return [
        TransactionItem(
            id=txn.id,
            date=txn.date,
            ticker=txn.ticker,
            name=stock_names.get(txn.ticker),
            action=txn.action,
            shares=txn.shares,
            price=txn.price,
            commission=txn.commission,
        )
        for txn in txns
    ]
|
|
|
|
|
|
|
|
|
|
|
2026-03-18 22:33:41 +09:00
|
|
|
@router.post("/{backtest_id}/walkforward", response_model=dict)
async def run_walkforward(
    backtest_id: int,
    request: WalkForwardRequest,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Run walk-forward analysis on a completed backtest.

    The backtest must exist, belong to the current user, and have status
    ``BacktestStatus.COMPLETED``.

    ``WalkForwardEngine.run`` is a plain synchronous call. It is therefore
    executed in FastAPI's worker thread pool rather than directly on the
    event loop. Calling it inline from this ``async`` handler would stall
    every other request served by the same loop until it returned.

    Args:
        backtest_id: Primary key of the backtest to analyse.
        request: Window configuration (``train_months``, ``test_months``,
            ``step_months``) forwarded to the engine.

    Returns:
        ``{"status": "completed", "backtest_id": backtest_id}`` after the
        engine call has returned.

    Raises:
        HTTPException: 404 if the backtest does not exist; 403 if it belongs
            to another user; 400 if it is not completed, or if the engine
            raises ``ValueError`` (the message is passed through as detail).
    """
    # Function-scope import keeps this change self-contained; fastapi is
    # already a dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    backtest = db.query(Backtest).filter(Backtest.id == backtest_id).first()

    if not backtest:
        raise HTTPException(status_code=404, detail="Backtest not found")

    if backtest.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="Not authorized")

    if backtest.status != BacktestStatus.COMPLETED:
        raise HTTPException(
            status_code=400,
            detail="백테스트가 완료된 상태에서만 walk-forward 분석이 가능합니다",
        )

    engine = WalkForwardEngine(db)
    try:
        # The session is used by the worker thread only while this coroutine
        # is suspended awaiting it, so it is never touched concurrently.
        await run_in_threadpool(
            engine.run,
            backtest_id=backtest_id,
            train_months=request.train_months,
            test_months=request.test_months,
            step_months=request.step_months,
        )
    except ValueError as e:
        # Engine-side validation errors are client errors.
        raise HTTPException(status_code=400, detail=str(e)) from e

    return {"status": "completed", "backtest_id": backtest_id}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/{backtest_id}/walkforward", response_model=WalkForwardResponse)
async def get_walkforward(
    backtest_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Get walk-forward analysis results.

    Returns every stored ``WalkForwardResult`` window for the backtest,
    ordered by ``window_index``. The list is empty if no analysis has been
    stored.

    Raises:
        HTTPException: 404 if no backtest has ``backtest_id``; 403 if it
            belongs to another user.
    """
    bt = db.query(Backtest).filter(Backtest.id == backtest_id).first()
    if not bt:
        raise HTTPException(status_code=404, detail="Backtest not found")
    if bt.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="Not authorized")

    rows = (
        db.query(WalkForwardResult)
        .filter(WalkForwardResult.backtest_id == backtest_id)
        .order_by(WalkForwardResult.window_index)
        .all()
    )

    def to_window(row):
        return WalkForwardWindowResult(
            window_index=row.window_index,
            train_start=row.train_start,
            train_end=row.train_end,
            test_start=row.test_start,
            test_end=row.test_end,
            test_return=row.test_return,
            test_sharpe=row.test_sharpe,
            test_mdd=row.test_mdd,
        )

    windows = [to_window(row) for row in rows]
    return WalkForwardResponse(backtest_id=backtest_id, windows=windows)
|
|
|
|
|
|
|
|
|
|
|
2026-02-03 11:51:29 +09:00
|
|
|
@router.delete("/{backtest_id}")
async def delete_backtest(
    backtest_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Delete a backtest and all its data.

    Rows in every child table keyed by ``backtest_id`` are bulk-deleted
    first, then the ``Backtest`` row itself. All of it is committed by the
    single ``commit`` at the end.

    Returns:
        ``{"message": "Backtest deleted"}``.

    Raises:
        HTTPException: 404 if no backtest has ``backtest_id``; 403 if it
            belongs to another user.
    """
    bt = db.query(Backtest).filter(Backtest.id == backtest_id).first()
    if not bt:
        raise HTTPException(status_code=404, detail="Backtest not found")
    if bt.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="Not authorized")

    # Children go before the parent row, presumably so foreign-key
    # constraints are never violated; the order here is significant.
    child_models = (
        WalkForwardResult,
        BacktestTransaction,
        BacktestHolding,
        BacktestEquityCurve,
        BacktestResult,
    )
    for model in child_models:
        db.query(model).filter(model.backtest_id == backtest_id).delete()

    db.delete(bt)
    db.commit()

    return {"message": "Backtest deleted"}
|