""" Backtest API endpoints. """ from typing import List from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.orm import Session, joinedload from app.core.database import get_db from app.api.deps import CurrentUser from app.models.backtest import ( Backtest, BacktestResult, BacktestEquityCurve, BacktestHolding, BacktestTransaction, BacktestStatus, WalkForwardResult, ) from app.schemas.backtest import ( BacktestCreate, BacktestResponse, BacktestListItem, BacktestMetrics, EquityCurvePoint, RebalanceHoldings, HoldingItem, TransactionItem, WalkForwardRequest, WalkForwardWindowResult, WalkForwardResponse, ) from app.services.backtest import submit_backtest from app.services.backtest.walkforward_engine import WalkForwardEngine from app.services.rebalance import RebalanceService router = APIRouter(prefix="/api/backtest", tags=["backtest"]) @router.post("", response_model=dict) async def create_backtest( request: BacktestCreate, current_user: CurrentUser, db: Session = Depends(get_db), ): """Create and start a new backtest.""" # Create backtest record backtest = Backtest( user_id=current_user.id, strategy_type=request.strategy_type, strategy_params=request.strategy_params, start_date=request.start_date, end_date=request.end_date, rebalance_period=request.rebalance_period, initial_capital=request.initial_capital, commission_rate=request.commission_rate, slippage_rate=request.slippage_rate, benchmark=request.benchmark, top_n=request.top_n, status=BacktestStatus.PENDING, ) db.add(backtest) db.commit() db.refresh(backtest) # Submit for background execution submit_backtest(backtest.id) return {"id": backtest.id, "status": "pending"} @router.get("", response_model=List[BacktestListItem]) async def list_backtests( current_user: CurrentUser, db: Session = Depends(get_db), ): """List all backtests for current user.""" backtests = ( db.query(Backtest) .options(joinedload(Backtest.result)) .filter(Backtest.user_id == current_user.id) .order_by(Backtest.created_at.desc()) .all() ) result = [] for bt in backtests: item = BacktestListItem( id=bt.id, strategy_type=bt.strategy_type, start_date=bt.start_date, end_date=bt.end_date, rebalance_period=bt.rebalance_period.value, status=bt.status.value, created_at=bt.created_at, total_return=bt.result.total_return if bt.result else None, cagr=bt.result.cagr if bt.result else None, mdd=bt.result.mdd if bt.result else None, ) result.append(item) return result @router.get("/{backtest_id}", response_model=BacktestResponse) async def get_backtest( backtest_id: int, current_user: CurrentUser, db: Session = Depends(get_db), ): """Get backtest details and results.""" backtest = ( db.query(Backtest) .filter(Backtest.id == backtest_id, Backtest.user_id == current_user.id) .first() ) if not backtest: raise HTTPException(status_code=404, detail="Backtest not found") result_metrics = None if backtest.result: result_metrics = BacktestMetrics( total_return=backtest.result.total_return, cagr=backtest.result.cagr, mdd=backtest.result.mdd, sharpe_ratio=backtest.result.sharpe_ratio, volatility=backtest.result.volatility, benchmark_return=backtest.result.benchmark_return, excess_return=backtest.result.excess_return, ) return BacktestResponse( id=backtest.id, user_id=backtest.user_id, strategy_type=backtest.strategy_type, strategy_params=backtest.strategy_params or {}, start_date=backtest.start_date, end_date=backtest.end_date, rebalance_period=backtest.rebalance_period.value, initial_capital=backtest.initial_capital, commission_rate=backtest.commission_rate, slippage_rate=backtest.slippage_rate, benchmark=backtest.benchmark, status=backtest.status.value, created_at=backtest.created_at, completed_at=backtest.completed_at, error_message=backtest.error_message, result=result_metrics, ) @router.get("/{backtest_id}/equity-curve", response_model=List[EquityCurvePoint]) async def get_equity_curve( backtest_id: int, current_user: CurrentUser, db: Session = Depends(get_db), ): """Get equity curve data for chart.""" backtest = ( db.query(Backtest) .filter(Backtest.id == backtest_id, Backtest.user_id == current_user.id) .first() ) if not backtest: raise HTTPException(status_code=404, detail="Backtest not found") curve = ( db.query(BacktestEquityCurve) .filter(BacktestEquityCurve.backtest_id == backtest_id) .order_by(BacktestEquityCurve.date) .all() ) return [ EquityCurvePoint( date=point.date, portfolio_value=point.portfolio_value, benchmark_value=point.benchmark_value, drawdown=point.drawdown, ) for point in curve ] @router.get("/{backtest_id}/holdings", response_model=List[RebalanceHoldings]) async def get_holdings( backtest_id: int, current_user: CurrentUser, db: Session = Depends(get_db), ): """Get holdings at each rebalance date.""" backtest = ( db.query(Backtest) .filter(Backtest.id == backtest_id, Backtest.user_id == current_user.id) .first() ) if not backtest: raise HTTPException(status_code=404, detail="Backtest not found") holdings = ( db.query(BacktestHolding) .filter(BacktestHolding.backtest_id == backtest_id) .order_by(BacktestHolding.rebalance_date, BacktestHolding.weight.desc()) .all() ) # Group by rebalance date grouped = {} for h in holdings: if h.rebalance_date not in grouped: grouped[h.rebalance_date] = [] grouped[h.rebalance_date].append( HoldingItem( ticker=h.ticker, name=h.name, weight=h.weight, shares=h.shares, price=h.price, ) ) return [ RebalanceHoldings(rebalance_date=date, holdings=items) for date, items in sorted(grouped.items()) ] @router.get("/{backtest_id}/transactions", response_model=List[TransactionItem]) async def get_transactions( backtest_id: int, current_user: CurrentUser, db: Session = Depends(get_db), ): """Get all transactions.""" backtest = ( db.query(Backtest) .filter(Backtest.id == backtest_id, Backtest.user_id == current_user.id) .first() ) if not backtest: raise HTTPException(status_code=404, detail="Backtest not found") transactions = ( db.query(BacktestTransaction) .filter(BacktestTransaction.backtest_id == backtest_id) .order_by(BacktestTransaction.date, BacktestTransaction.id) .all() ) # Resolve stock names tickers = list({t.ticker for t in transactions}) name_service = RebalanceService(db) names = name_service.get_stock_names(tickers) return [ TransactionItem( id=t.id, date=t.date, ticker=t.ticker, name=names.get(t.ticker), action=t.action, shares=t.shares, price=t.price, commission=t.commission, ) for t in transactions ] @router.post("/{backtest_id}/walkforward", response_model=dict) async def run_walkforward( backtest_id: int, request: WalkForwardRequest, current_user: CurrentUser, db: Session = Depends(get_db), ): """Run walk-forward analysis on a completed backtest.""" backtest = ( db.query(Backtest) .filter(Backtest.id == backtest_id, Backtest.user_id == current_user.id) .first() ) if not backtest: raise HTTPException(status_code=404, detail="Backtest not found") if backtest.status != BacktestStatus.COMPLETED: raise HTTPException( status_code=400, detail="백테스트가 완료된 상태에서만 walk-forward 분석이 가능합니다", ) engine = WalkForwardEngine(db) try: engine.run( backtest_id=backtest_id, train_months=request.train_months, test_months=request.test_months, step_months=request.step_months, ) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) return {"status": "completed", "backtest_id": backtest_id} @router.get("/{backtest_id}/walkforward", response_model=WalkForwardResponse) async def get_walkforward( backtest_id: int, current_user: CurrentUser, db: Session = Depends(get_db), ): """Get walk-forward analysis results.""" backtest = ( db.query(Backtest) .filter(Backtest.id == backtest_id, Backtest.user_id == current_user.id) .first() ) if not backtest: raise HTTPException(status_code=404, detail="Backtest not found") results = ( db.query(WalkForwardResult) .filter(WalkForwardResult.backtest_id == backtest_id) .order_by(WalkForwardResult.window_index) .all() ) return WalkForwardResponse( backtest_id=backtest_id, windows=[ WalkForwardWindowResult( window_index=r.window_index, train_start=r.train_start, train_end=r.train_end, test_start=r.test_start, test_end=r.test_end, test_return=r.test_return, test_sharpe=r.test_sharpe, test_mdd=r.test_mdd, ) for r in results ], ) @router.delete("/{backtest_id}") async def delete_backtest( backtest_id: int, current_user: CurrentUser, db: Session = Depends(get_db), ): """Delete a backtest and all its data.""" backtest = ( db.query(Backtest) .filter(Backtest.id == backtest_id, Backtest.user_id == current_user.id) .first() ) if not backtest: raise HTTPException(status_code=404, detail="Backtest not found") # Delete related data db.query(WalkForwardResult).filter( WalkForwardResult.backtest_id == backtest_id ).delete() db.query(BacktestTransaction).filter( BacktestTransaction.backtest_id == backtest_id ).delete() db.query(BacktestHolding).filter( BacktestHolding.backtest_id == backtest_id ).delete() db.query(BacktestEquityCurve).filter( BacktestEquityCurve.backtest_id == backtest_id ).delete() db.query(BacktestResult).filter(BacktestResult.backtest_id == backtest_id).delete() db.delete(backtest) db.commit() return {"message": "Backtest deleted"}