From 99bd08c68a0db6cc992e0a8af32bb24923cb69b0 Mon Sep 17 00:00:00 2001 From: zephyrdark Date: Tue, 3 Feb 2026 11:51:29 +0900 Subject: [PATCH] feat: add backtest API endpoints - POST /api/backtest (create and start) - GET /api/backtest (list) - GET /api/backtest/{id} (detail) - GET /api/backtest/{id}/equity-curve - GET /api/backtest/{id}/holdings - GET /api/backtest/{id}/transactions - DELETE /api/backtest/{id} Co-Authored-By: Claude Opus 4.5 --- backend/app/api/__init__.py | 3 +- backend/app/api/backtest.py | 276 ++++++++++++++++++++++++++++++++++++ backend/app/main.py | 3 +- 3 files changed, 280 insertions(+), 2 deletions(-) create mode 100644 backend/app/api/backtest.py diff --git a/backend/app/api/__init__.py b/backend/app/api/__init__.py index 90d929b..4d98935 100644 --- a/backend/app/api/__init__.py +++ b/backend/app/api/__init__.py @@ -3,5 +3,6 @@ from app.api.admin import router as admin_router from app.api.portfolio import router as portfolio_router from app.api.strategy import router as strategy_router from app.api.market import router as market_router +from app.api.backtest import router as backtest_router -__all__ = ["auth_router", "admin_router", "portfolio_router", "strategy_router", "market_router"] +__all__ = ["auth_router", "admin_router", "portfolio_router", "strategy_router", "market_router", "backtest_router"] diff --git a/backend/app/api/backtest.py b/backend/app/api/backtest.py new file mode 100644 index 0000000..e636370 --- /dev/null +++ b/backend/app/api/backtest.py @@ -0,0 +1,276 @@ +""" +Backtest API endpoints. +""" +from typing import List + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.orm import Session + +from app.core.database import get_db +from app.api.deps import CurrentUser +from app.models.backtest import ( + Backtest, BacktestResult, BacktestEquityCurve, + BacktestHolding, BacktestTransaction, BacktestStatus, +) +from app.schemas.backtest import ( + BacktestCreate, BacktestResponse, BacktestListItem, BacktestMetrics, + EquityCurvePoint, RebalanceHoldings, HoldingItem, TransactionItem, +) +from app.services.backtest import submit_backtest + +router = APIRouter(prefix="/api/backtest", tags=["backtest"]) + + +@router.post("", response_model=dict) +async def create_backtest( + request: BacktestCreate, + current_user: CurrentUser, + db: Session = Depends(get_db), +): + """Create and start a new backtest.""" + # Create backtest record + backtest = Backtest( + user_id=current_user.id, + strategy_type=request.strategy_type, + strategy_params=request.strategy_params, + start_date=request.start_date, + end_date=request.end_date, + rebalance_period=request.rebalance_period, + initial_capital=request.initial_capital, + commission_rate=request.commission_rate, + slippage_rate=request.slippage_rate, + benchmark=request.benchmark, + top_n=request.top_n, + status=BacktestStatus.PENDING, + ) + db.add(backtest) + db.commit() + db.refresh(backtest) + + # Submit for background execution + submit_backtest(backtest.id) + + return {"id": backtest.id, "status": "pending"} + + +@router.get("", response_model=List[BacktestListItem]) +async def list_backtests( + current_user: CurrentUser, + db: Session = Depends(get_db), +): + """List all backtests for current user.""" + backtests = ( + db.query(Backtest) + .filter(Backtest.user_id == current_user.id) + .order_by(Backtest.created_at.desc()) + .all() + ) + + result = [] + for bt in backtests: + item = BacktestListItem( + id=bt.id, + strategy_type=bt.strategy_type, + start_date=bt.start_date, + end_date=bt.end_date, + rebalance_period=bt.rebalance_period.value, + status=bt.status.value, + created_at=bt.created_at, + total_return=bt.result.total_return if bt.result else None, + cagr=bt.result.cagr if bt.result else None, + mdd=bt.result.mdd if bt.result else None, + ) + result.append(item) + + return result + + +@router.get("/{backtest_id}", response_model=BacktestResponse) +async def get_backtest( + backtest_id: int, + current_user: CurrentUser, + db: Session = Depends(get_db), +): + """Get backtest details and results.""" + backtest = db.query(Backtest).filter(Backtest.id == backtest_id).first() + + if not backtest: + raise HTTPException(status_code=404, detail="Backtest not found") + + if backtest.user_id != current_user.id: + raise HTTPException(status_code=403, detail="Not authorized") + + result_metrics = None + if backtest.result: + result_metrics = BacktestMetrics( + total_return=backtest.result.total_return, + cagr=backtest.result.cagr, + mdd=backtest.result.mdd, + sharpe_ratio=backtest.result.sharpe_ratio, + volatility=backtest.result.volatility, + benchmark_return=backtest.result.benchmark_return, + excess_return=backtest.result.excess_return, + ) + + return BacktestResponse( + id=backtest.id, + user_id=backtest.user_id, + strategy_type=backtest.strategy_type, + strategy_params=backtest.strategy_params or {}, + start_date=backtest.start_date, + end_date=backtest.end_date, + rebalance_period=backtest.rebalance_period.value, + initial_capital=backtest.initial_capital, + commission_rate=backtest.commission_rate, + slippage_rate=backtest.slippage_rate, + benchmark=backtest.benchmark, + status=backtest.status.value, + created_at=backtest.created_at, + completed_at=backtest.completed_at, + error_message=backtest.error_message, + result=result_metrics, + ) + + +@router.get("/{backtest_id}/equity-curve", response_model=List[EquityCurvePoint]) +async def get_equity_curve( + backtest_id: int, + current_user: CurrentUser, + db: Session = Depends(get_db), +): + """Get equity curve data for chart.""" + backtest = db.query(Backtest).filter(Backtest.id == backtest_id).first() + + if not backtest: + raise HTTPException(status_code=404, detail="Backtest not found") + + if backtest.user_id != current_user.id: + raise HTTPException(status_code=403, detail="Not authorized") + + curve = ( + db.query(BacktestEquityCurve) + .filter(BacktestEquityCurve.backtest_id == backtest_id) + .order_by(BacktestEquityCurve.date) + .all() + ) + + return [ + EquityCurvePoint( + date=point.date, + portfolio_value=point.portfolio_value, + benchmark_value=point.benchmark_value, + drawdown=point.drawdown, + ) + for point in curve + ] + + +@router.get("/{backtest_id}/holdings", response_model=List[RebalanceHoldings]) +async def get_holdings( + backtest_id: int, + current_user: CurrentUser, + db: Session = Depends(get_db), +): + """Get holdings at each rebalance date.""" + backtest = db.query(Backtest).filter(Backtest.id == backtest_id).first() + + if not backtest: + raise HTTPException(status_code=404, detail="Backtest not found") + + if backtest.user_id != current_user.id: + raise HTTPException(status_code=403, detail="Not authorized") + + holdings = ( + db.query(BacktestHolding) + .filter(BacktestHolding.backtest_id == backtest_id) + .order_by(BacktestHolding.rebalance_date, BacktestHolding.weight.desc()) + .all() + ) + + # Group by rebalance date + grouped = {} + for h in holdings: + if h.rebalance_date not in grouped: + grouped[h.rebalance_date] = [] + grouped[h.rebalance_date].append(HoldingItem( + ticker=h.ticker, + name=h.name, + weight=h.weight, + shares=h.shares, + price=h.price, + )) + + return [ + RebalanceHoldings(rebalance_date=date, holdings=items) + for date, items in sorted(grouped.items()) + ] + + +@router.get("/{backtest_id}/transactions", response_model=List[TransactionItem]) +async def get_transactions( + backtest_id: int, + current_user: CurrentUser, + db: Session = Depends(get_db), +): + """Get all transactions.""" + backtest = db.query(Backtest).filter(Backtest.id == backtest_id).first() + + if not backtest: + raise HTTPException(status_code=404, detail="Backtest not found") + + if backtest.user_id != current_user.id: + raise HTTPException(status_code=403, detail="Not authorized") + + transactions = ( + db.query(BacktestTransaction) + .filter(BacktestTransaction.backtest_id == backtest_id) + .order_by(BacktestTransaction.date, BacktestTransaction.id) + .all() + ) + + return [ + TransactionItem( + id=t.id, + date=t.date, + ticker=t.ticker, + action=t.action, + shares=t.shares, + price=t.price, + commission=t.commission, + ) + for t in transactions + ] + + +@router.delete("/{backtest_id}") +async def delete_backtest( + backtest_id: int, + current_user: CurrentUser, + db: Session = Depends(get_db), +): + """Delete a backtest and all its data.""" + backtest = db.query(Backtest).filter(Backtest.id == backtest_id).first() + + if not backtest: + raise HTTPException(status_code=404, detail="Backtest not found") + + if backtest.user_id != current_user.id: + raise HTTPException(status_code=403, detail="Not authorized") + + # Delete related data + db.query(BacktestTransaction).filter( + BacktestTransaction.backtest_id == backtest_id + ).delete() + db.query(BacktestHolding).filter( + BacktestHolding.backtest_id == backtest_id + ).delete() + db.query(BacktestEquityCurve).filter( + BacktestEquityCurve.backtest_id == backtest_id + ).delete() + db.query(BacktestResult).filter( + BacktestResult.backtest_id == backtest_id + ).delete() + db.delete(backtest) + db.commit() + + return {"message": "Backtest deleted"} diff --git a/backend/app/main.py b/backend/app/main.py index 3be8cdf..bd1f959 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -4,7 +4,7 @@ Galaxy-PO Backend API from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware -from app.api import auth_router, admin_router, portfolio_router, strategy_router, market_router +from app.api import auth_router, admin_router, portfolio_router, strategy_router, market_router, backtest_router app = FastAPI( title="Galaxy-PO API", @@ -26,6 +26,7 @@ app.include_router(admin_router) app.include_router(portfolio_router) app.include_router(strategy_router) app.include_router(market_router) +app.include_router(backtest_router) @app.get("/health")