galaxis-po/backend/app/api/backtest.py

387 lines
11 KiB
Python
Raw Normal View History

"""
Backtest API endpoints.
"""
from typing import List
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session, joinedload
from app.core.database import get_db
from app.api.deps import CurrentUser
from app.models.backtest import (
Backtest,
BacktestResult,
BacktestEquityCurve,
BacktestHolding,
BacktestTransaction,
BacktestStatus,
WalkForwardResult,
)
from app.schemas.backtest import (
BacktestCreate,
BacktestResponse,
BacktestListItem,
BacktestMetrics,
EquityCurvePoint,
RebalanceHoldings,
HoldingItem,
TransactionItem,
WalkForwardRequest,
WalkForwardWindowResult,
WalkForwardResponse,
)
from app.services.backtest import submit_backtest
from app.services.backtest.walkforward_engine import WalkForwardEngine
from app.services.rebalance import RebalanceService
# Router for every backtest endpoint in this module; all paths below are
# relative to the "/api/backtest" prefix and grouped under the "backtest" tag.
router = APIRouter(prefix="/api/backtest", tags=["backtest"])
@router.post("", response_model=dict)
async def create_backtest(
    request: BacktestCreate,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Persist a new backtest for the caller and hand it off for execution.

    The row is committed with status ``PENDING`` first, so it has an id
    before being passed to ``submit_backtest`` (background execution).

    Returns:
        ``{"id": <new backtest id>, "status": "pending"}``.
    """
    # Configuration fields copied verbatim from the request body.
    copied_fields = (
        "strategy_type",
        "strategy_params",
        "start_date",
        "end_date",
        "rebalance_period",
        "initial_capital",
        "commission_rate",
        "slippage_rate",
        "benchmark",
        "top_n",
    )
    new_backtest = Backtest(
        user_id=current_user.id,
        status=BacktestStatus.PENDING,
        **{field: getattr(request, field) for field in copied_fields},
    )
    db.add(new_backtest)
    db.commit()
    # Reload the committed row so attributes such as ``id`` are populated.
    db.refresh(new_backtest)

    submit_backtest(new_backtest.id)
    return {"id": new_backtest.id, "status": "pending"}
@router.get("", response_model=List[BacktestListItem])
async def list_backtests(
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return the caller's backtests, newest first, with headline metrics.

    ``total_return``, ``cagr`` and ``mdd`` come from the associated result
    row when one is present, and are ``None`` otherwise.
    """
    owned = (
        db.query(Backtest)
        # Eager-load results in the same query instead of per row.
        .options(joinedload(Backtest.result))
        .filter(Backtest.user_id == current_user.id)
        .order_by(Backtest.created_at.desc())
        .all()
    )

    def _to_item(bt):
        summary = bt.result
        return BacktestListItem(
            id=bt.id,
            strategy_type=bt.strategy_type,
            start_date=bt.start_date,
            end_date=bt.end_date,
            rebalance_period=bt.rebalance_period.value,
            status=bt.status.value,
            created_at=bt.created_at,
            total_return=summary.total_return if summary else None,
            cagr=summary.cagr if summary else None,
            mdd=summary.mdd if summary else None,
        )

    return [_to_item(bt) for bt in owned]
@router.get("/{backtest_id}", response_model=BacktestResponse)
async def get_backtest(
    backtest_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return one of the caller's backtests together with its metrics.

    ``result`` is ``None`` while no result row is attached to the backtest;
    ``strategy_params`` falls back to ``{}`` when the stored value is falsy.

    Raises:
        HTTPException(404): no backtest with this id belongs to the caller.
    """
    bt = (
        db.query(Backtest)
        .filter(Backtest.id == backtest_id, Backtest.user_id == current_user.id)
        .first()
    )
    if not bt:
        raise HTTPException(status_code=404, detail="Backtest not found")

    stored = bt.result
    metrics = (
        BacktestMetrics(
            total_return=stored.total_return,
            cagr=stored.cagr,
            mdd=stored.mdd,
            sharpe_ratio=stored.sharpe_ratio,
            volatility=stored.volatility,
            benchmark_return=stored.benchmark_return,
            excess_return=stored.excess_return,
        )
        if stored
        else None
    )

    return BacktestResponse(
        id=bt.id,
        user_id=bt.user_id,
        strategy_type=bt.strategy_type,
        strategy_params=bt.strategy_params or {},
        start_date=bt.start_date,
        end_date=bt.end_date,
        rebalance_period=bt.rebalance_period.value,
        initial_capital=bt.initial_capital,
        commission_rate=bt.commission_rate,
        slippage_rate=bt.slippage_rate,
        benchmark=bt.benchmark,
        status=bt.status.value,
        created_at=bt.created_at,
        completed_at=bt.completed_at,
        error_message=bt.error_message,
        result=metrics,
    )
@router.get("/{backtest_id}/equity-curve", response_model=List[EquityCurvePoint])
async def get_equity_curve(
    backtest_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return the stored equity-curve points of a backtest, ordered by date.

    Each point carries the portfolio value, benchmark value and drawdown
    recorded in ``BacktestEquityCurve``.

    Raises:
        HTTPException(404): no backtest with this id belongs to the caller.
    """
    owner_match = (
        Backtest.id == backtest_id,
        Backtest.user_id == current_user.id,
    )
    if not db.query(Backtest).filter(*owner_match).first():
        raise HTTPException(status_code=404, detail="Backtest not found")

    stored_points = (
        db.query(BacktestEquityCurve)
        .filter(BacktestEquityCurve.backtest_id == backtest_id)
        .order_by(BacktestEquityCurve.date)
        .all()
    )
    return [
        EquityCurvePoint(
            date=p.date,
            portfolio_value=p.portfolio_value,
            benchmark_value=p.benchmark_value,
            drawdown=p.drawdown,
        )
        for p in stored_points
    ]
@router.get("/{backtest_id}/holdings", response_model=List[RebalanceHoldings])
async def get_holdings(
    backtest_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return the portfolio snapshot taken at each rebalance date.

    Snapshots are ordered by rebalance date. Within a snapshot, holdings keep
    the query order, which is heaviest weight first.

    Raises:
        HTTPException(404): no backtest with this id belongs to the caller.
    """
    owned = (
        db.query(Backtest)
        .filter(Backtest.id == backtest_id, Backtest.user_id == current_user.id)
        .first()
    )
    if not owned:
        raise HTTPException(status_code=404, detail="Backtest not found")

    rows = (
        db.query(BacktestHolding)
        .filter(BacktestHolding.backtest_id == backtest_id)
        .order_by(BacktestHolding.rebalance_date, BacktestHolding.weight.desc())
        .all()
    )

    # Bucket holdings by the rebalance date they belong to.
    by_date = {}
    for row in rows:
        by_date.setdefault(row.rebalance_date, []).append(
            HoldingItem(
                ticker=row.ticker,
                name=row.name,
                weight=row.weight,
                shares=row.shares,
                price=row.price,
            )
        )

    return [
        RebalanceHoldings(rebalance_date=day, holdings=by_date[day])
        for day in sorted(by_date)
    ]
@router.get("/{backtest_id}/transactions", response_model=List[TransactionItem])
async def get_transactions(
    backtest_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return every trade recorded for a backtest.

    Trades are ordered by date, with ties broken by transaction id. Stock
    names come from ``RebalanceService.get_stock_names`` for the distinct
    tickers involved. Each item's ``name`` is ``names.get(ticker)``, so it is
    ``None`` for tickers missing from that lookup.

    Raises:
        HTTPException(404): no backtest with this id belongs to the caller.
    """
    owned = (
        db.query(Backtest)
        .filter(Backtest.id == backtest_id, Backtest.user_id == current_user.id)
        .first()
    )
    if not owned:
        raise HTTPException(status_code=404, detail="Backtest not found")

    trades = (
        db.query(BacktestTransaction)
        .filter(BacktestTransaction.backtest_id == backtest_id)
        .order_by(BacktestTransaction.date, BacktestTransaction.id)
        .all()
    )

    distinct_tickers = list({trade.ticker for trade in trades})
    ticker_names = RebalanceService(db).get_stock_names(distinct_tickers)

    return [
        TransactionItem(
            id=trade.id,
            date=trade.date,
            ticker=trade.ticker,
            name=ticker_names.get(trade.ticker),
            action=trade.action,
            shares=trade.shares,
            price=trade.price,
            commission=trade.commission,
        )
        for trade in trades
    ]
@router.post("/{backtest_id}/walkforward", response_model=dict)
async def run_walkforward(
    backtest_id: int,
    request: WalkForwardRequest,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Run a walk-forward analysis on one of the caller's completed backtests.

    The request waits until ``WalkForwardEngine.run`` has finished. Results
    are read back through ``GET /{backtest_id}/walkforward``.

    Raises:
        HTTPException(404): the backtest does not exist or belongs to another user.
        HTTPException(400): the backtest is not COMPLETED, or the engine
            rejected the parameters with ``ValueError``.

    Returns:
        ``{"status": "completed", "backtest_id": backtest_id}``.
    """
    # Local import: only this endpoint needs to off-load blocking work.
    from fastapi.concurrency import run_in_threadpool

    backtest = (
        db.query(Backtest)
        .filter(Backtest.id == backtest_id, Backtest.user_id == current_user.id)
        .first()
    )
    if not backtest:
        raise HTTPException(status_code=404, detail="Backtest not found")
    if backtest.status != BacktestStatus.COMPLETED:
        # "Walk-forward analysis is only possible once the backtest has completed."
        raise HTTPException(
            status_code=400,
            detail="백테스트가 완료된 상태에서만 walk-forward 분석이 가능합니다",
        )

    engine = WalkForwardEngine(db)
    try:
        # engine.run is a plain (non-awaited) blocking call driven by a
        # synchronous SQLAlchemy Session. Calling it directly inside this
        # ``async def`` would stall the event loop, and with it every other
        # request, until the analysis finished. Run it in the threadpool instead.
        await run_in_threadpool(
            engine.run,
            backtest_id=backtest_id,
            train_months=request.train_months,
            test_months=request.test_months,
            step_months=request.step_months,
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    return {"status": "completed", "backtest_id": backtest_id}
@router.get("/{backtest_id}/walkforward", response_model=WalkForwardResponse)
async def get_walkforward(
    backtest_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return the stored walk-forward windows of a backtest.

    Windows are ordered by ``window_index``. The list is empty when no
    analysis results have been stored for this backtest.

    Raises:
        HTTPException(404): no backtest with this id belongs to the caller.
    """
    owned = (
        db.query(Backtest)
        .filter(Backtest.id == backtest_id, Backtest.user_id == current_user.id)
        .first()
    )
    if not owned:
        raise HTTPException(status_code=404, detail="Backtest not found")

    stored_windows = (
        db.query(WalkForwardResult)
        .filter(WalkForwardResult.backtest_id == backtest_id)
        .order_by(WalkForwardResult.window_index)
        .all()
    )
    window_results = [
        WalkForwardWindowResult(
            window_index=w.window_index,
            train_start=w.train_start,
            train_end=w.train_end,
            test_start=w.test_start,
            test_end=w.test_end,
            test_return=w.test_return,
            test_sharpe=w.test_sharpe,
            test_mdd=w.test_mdd,
        )
        for w in stored_windows
    ]
    return WalkForwardResponse(backtest_id=backtest_id, windows=window_results)
@router.delete("/{backtest_id}")
async def delete_backtest(
    backtest_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Delete one of the caller's backtests together with all its stored data.

    Rows in the walk-forward, transaction, holding, equity-curve and result
    tables that reference the backtest are removed first, then the backtest
    itself. Everything is committed in a single commit.

    Raises:
        HTTPException(404): no backtest with this id belongs to the caller.
    """
    target = (
        db.query(Backtest)
        .filter(Backtest.id == backtest_id, Backtest.user_id == current_user.id)
        .first()
    )
    if not target:
        raise HTTPException(status_code=404, detail="Backtest not found")

    # Bulk-delete dependent rows in this exact order before the parent row.
    dependent_models = (
        WalkForwardResult,
        BacktestTransaction,
        BacktestHolding,
        BacktestEquityCurve,
        BacktestResult,
    )
    for model in dependent_models:
        db.query(model).filter(model.backtest_id == backtest_id).delete()

    db.delete(target)
    db.commit()
    return {"message": "Backtest deleted"}