"""
Backtest API endpoints.
"""
from collections import defaultdict
from typing import List

from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session, joinedload

from app.core.database import get_db
from app.api.deps import CurrentUser
from app.models.backtest import (
    Backtest,
    BacktestResult,
    BacktestEquityCurve,
    BacktestHolding,
    BacktestTransaction,
    BacktestStatus,
)
from app.schemas.backtest import (
    BacktestCreate,
    BacktestResponse,
    BacktestListItem,
    BacktestMetrics,
    EquityCurvePoint,
    RebalanceHoldings,
    HoldingItem,
    TransactionItem,
)
from app.services.backtest import submit_backtest
from app.services.rebalance import RebalanceService

router = APIRouter(prefix="/api/backtest", tags=["backtest"])

# NOTE(review): the handlers below are ``async def`` but make blocking,
# synchronous SQLAlchemy ``Session`` calls, so those calls run on the event
# loop. Declaring them as plain ``def`` would let FastAPI run them in its
# threadpool instead. Left unchanged because whether ``submit_backtest`` needs
# to be called from the event loop is not visible from this file -- confirm
# before switching.


def _get_owned_backtest(db: Session, backtest_id: int, user_id) -> Backtest:
    """Load the backtest with ``backtest_id`` and check that ``user_id`` owns it.

    Shared by every per-backtest endpoint so the existence and ownership
    checks (and their error responses) stay identical everywhere.

    Args:
        db: Active database session.
        backtest_id: Primary key of the backtest to load.
        user_id: Id of the requesting user (compared to ``Backtest.user_id``).

    Returns:
        The loaded ``Backtest`` row.

    Raises:
        HTTPException: 404 if no backtest has this id; 403 if it belongs to
            another user. The 404 check runs first.
    """
    backtest = db.query(Backtest).filter(Backtest.id == backtest_id).first()
    if backtest is None:
        raise HTTPException(status_code=404, detail="Backtest not found")
    if backtest.user_id != user_id:
        raise HTTPException(status_code=403, detail="Not authorized")
    return backtest


@router.post("", response_model=dict)
async def create_backtest(
    request: BacktestCreate,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Create a backtest owned by the current user and start it.

    The row is persisted with status ``PENDING`` and committed before its id
    is handed to ``submit_backtest``, so the id already exists when execution
    is submitted.

    Returns:
        ``{"id": <new backtest id>, "status": "pending"}``.
    """
    backtest = Backtest(
        user_id=current_user.id,
        strategy_type=request.strategy_type,
        strategy_params=request.strategy_params,
        start_date=request.start_date,
        end_date=request.end_date,
        rebalance_period=request.rebalance_period,
        initial_capital=request.initial_capital,
        commission_rate=request.commission_rate,
        slippage_rate=request.slippage_rate,
        benchmark=request.benchmark,
        top_n=request.top_n,
        status=BacktestStatus.PENDING,
    )
    db.add(backtest)
    db.commit()
    # Refresh so database-generated fields (e.g. the id) are populated.
    db.refresh(backtest)

    # Submit for background execution.
    submit_backtest(backtest.id)

    return {"id": backtest.id, "status": "pending"}


@router.get("", response_model=List[BacktestListItem])
async def list_backtests(
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """List all backtests for the current user, newest first.

    Summary metrics (total return, CAGR, MDD) come from the backtest's result
    when one exists and are ``None`` otherwise.
    """
    backtests = (
        db.query(Backtest)
        # Eager-load results so the loop below does not lazy-load per row.
        .options(joinedload(Backtest.result))
        .filter(Backtest.user_id == current_user.id)
        .order_by(Backtest.created_at.desc())
        .all()
    )

    items = []
    for bt in backtests:
        res = bt.result
        items.append(
            BacktestListItem(
                id=bt.id,
                strategy_type=bt.strategy_type,
                start_date=bt.start_date,
                end_date=bt.end_date,
                rebalance_period=bt.rebalance_period.value,
                status=bt.status.value,
                created_at=bt.created_at,
                total_return=res.total_return if res is not None else None,
                cagr=res.cagr if res is not None else None,
                mdd=res.mdd if res is not None else None,
            )
        )
    return items


@router.get("/{backtest_id}", response_model=BacktestResponse)
async def get_backtest(
    backtest_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return a backtest's configuration, status and (if available) metrics.

    Raises:
        HTTPException: 404 if the backtest does not exist, 403 if it belongs
            to another user.
    """
    backtest = _get_owned_backtest(db, backtest_id, current_user.id)

    res = backtest.result
    result_metrics = None
    if res is not None:
        result_metrics = BacktestMetrics(
            total_return=res.total_return,
            cagr=res.cagr,
            mdd=res.mdd,
            sharpe_ratio=res.sharpe_ratio,
            volatility=res.volatility,
            benchmark_return=res.benchmark_return,
            excess_return=res.excess_return,
        )

    return BacktestResponse(
        id=backtest.id,
        user_id=backtest.user_id,
        strategy_type=backtest.strategy_type,
        # Stored params may be NULL; the response always carries a dict.
        strategy_params=backtest.strategy_params or {},
        start_date=backtest.start_date,
        end_date=backtest.end_date,
        rebalance_period=backtest.rebalance_period.value,
        initial_capital=backtest.initial_capital,
        commission_rate=backtest.commission_rate,
        slippage_rate=backtest.slippage_rate,
        benchmark=backtest.benchmark,
        status=backtest.status.value,
        created_at=backtest.created_at,
        completed_at=backtest.completed_at,
        error_message=backtest.error_message,
        result=result_metrics,
    )


@router.get("/{backtest_id}/equity-curve", response_model=List[EquityCurvePoint])
async def get_equity_curve(
    backtest_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return the backtest's equity-curve points in ascending date order.

    Raises:
        HTTPException: 404 if the backtest does not exist, 403 if it belongs
            to another user.
    """
    _get_owned_backtest(db, backtest_id, current_user.id)

    curve = (
        db.query(BacktestEquityCurve)
        .filter(BacktestEquityCurve.backtest_id == backtest_id)
        .order_by(BacktestEquityCurve.date)
        .all()
    )

    return [
        EquityCurvePoint(
            date=point.date,
            portfolio_value=point.portfolio_value,
            benchmark_value=point.benchmark_value,
            drawdown=point.drawdown,
        )
        for point in curve
    ]


@router.get("/{backtest_id}/holdings", response_model=List[RebalanceHoldings])
async def get_holdings(
    backtest_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return holdings grouped by rebalance date.

    Groups are ordered by ascending rebalance date. Within a group, holdings
    keep the query order, i.e. descending weight.

    Raises:
        HTTPException: 404 if the backtest does not exist, 403 if it belongs
            to another user.
    """
    _get_owned_backtest(db, backtest_id, current_user.id)

    holdings = (
        db.query(BacktestHolding)
        .filter(BacktestHolding.backtest_id == backtest_id)
        .order_by(BacktestHolding.rebalance_date, BacktestHolding.weight.desc())
        .all()
    )

    grouped = defaultdict(list)
    for h in holdings:
        grouped[h.rebalance_date].append(
            HoldingItem(
                ticker=h.ticker,
                name=h.name,
                weight=h.weight,
                shares=h.shares,
                price=h.price,
            )
        )

    # Keys are unique dates, so sorting the items compares dates only.
    return [
        RebalanceHoldings(rebalance_date=date, holdings=items)
        for date, items in sorted(grouped.items())
    ]


@router.get("/{backtest_id}/transactions", response_model=List[TransactionItem])
async def get_transactions(
    backtest_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return all transactions ordered by date, then id, with stock names.

    Each item's ``name`` is looked up via ``RebalanceService.get_stock_names``
    and is ``None`` when the lookup has no entry for the ticker.

    Raises:
        HTTPException: 404 if the backtest does not exist, 403 if it belongs
            to another user.
    """
    _get_owned_backtest(db, backtest_id, current_user.id)

    transactions = (
        db.query(BacktestTransaction)
        .filter(BacktestTransaction.backtest_id == backtest_id)
        .order_by(BacktestTransaction.date, BacktestTransaction.id)
        .all()
    )

    # Resolve names once for the distinct tickers. Skip the lookup entirely
    # when there are no transactions, since nothing would read the result.
    tickers = list({t.ticker for t in transactions})
    names = RebalanceService(db).get_stock_names(tickers) if tickers else {}

    return [
        TransactionItem(
            id=t.id,
            date=t.date,
            ticker=t.ticker,
            name=names.get(t.ticker),
            action=t.action,
            shares=t.shares,
            price=t.price,
            commission=t.commission,
        )
        for t in transactions
    ]


@router.delete("/{backtest_id}")
async def delete_backtest(
    backtest_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Delete a backtest together with its transactions, holdings, equity curve and result.

    Raises:
        HTTPException: 404 if the backtest does not exist, 403 if it belongs
            to another user.
    """
    backtest = _get_owned_backtest(db, backtest_id, current_user.id)

    # Child rows are bulk-deleted explicitly before the parent. Presumably
    # this is because ORM/DB cascades are not configured for these
    # relationships -- TODO confirm against the model definitions.
    # All deletes are committed together below.
    for model in (
        BacktestTransaction,
        BacktestHolding,
        BacktestEquityCurve,
        BacktestResult,
    ):
        db.query(model).filter(model.backtest_id == backtest_id).delete()

    db.delete(backtest)
    db.commit()

    return {"message": "Backtest deleted"}