feat: add backtest API endpoints
- POST /api/backtest (create and start)
- GET /api/backtest (list)
- GET /api/backtest/{id} (detail)
- GET /api/backtest/{id}/equity-curve
- GET /api/backtest/{id}/holdings
- GET /api/backtest/{id}/transactions
- DELETE /api/backtest/{id}
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
c1ee879cb4
commit
99bd08c68a
@ -3,5 +3,6 @@ from app.api.admin import router as admin_router
|
||||
from app.api.portfolio import router as portfolio_router
|
||||
from app.api.strategy import router as strategy_router
|
||||
from app.api.market import router as market_router
|
||||
from app.api.backtest import router as backtest_router
|
||||
|
||||
__all__ = ["auth_router", "admin_router", "portfolio_router", "strategy_router", "market_router"]
|
||||
__all__ = ["auth_router", "admin_router", "portfolio_router", "strategy_router", "market_router", "backtest_router"]
|
||||
|
||||
276
backend/app/api/backtest.py
Normal file
276
backend/app/api/backtest.py
Normal file
@ -0,0 +1,276 @@
|
||||
"""
|
||||
Backtest API endpoints.
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.api.deps import CurrentUser
|
||||
from app.models.backtest import (
|
||||
Backtest, BacktestResult, BacktestEquityCurve,
|
||||
BacktestHolding, BacktestTransaction, BacktestStatus,
|
||||
)
|
||||
from app.schemas.backtest import (
|
||||
BacktestCreate, BacktestResponse, BacktestListItem, BacktestMetrics,
|
||||
EquityCurvePoint, RebalanceHoldings, HoldingItem, TransactionItem,
|
||||
)
|
||||
from app.services.backtest import submit_backtest
|
||||
|
||||
router = APIRouter(prefix="/api/backtest", tags=["backtest"])
|
||||
|
||||
|
||||
@router.post("", response_model=dict)
|
||||
async def create_backtest(
|
||||
request: BacktestCreate,
|
||||
current_user: CurrentUser,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Create and start a new backtest."""
|
||||
# Create backtest record
|
||||
backtest = Backtest(
|
||||
user_id=current_user.id,
|
||||
strategy_type=request.strategy_type,
|
||||
strategy_params=request.strategy_params,
|
||||
start_date=request.start_date,
|
||||
end_date=request.end_date,
|
||||
rebalance_period=request.rebalance_period,
|
||||
initial_capital=request.initial_capital,
|
||||
commission_rate=request.commission_rate,
|
||||
slippage_rate=request.slippage_rate,
|
||||
benchmark=request.benchmark,
|
||||
top_n=request.top_n,
|
||||
status=BacktestStatus.PENDING,
|
||||
)
|
||||
db.add(backtest)
|
||||
db.commit()
|
||||
db.refresh(backtest)
|
||||
|
||||
# Submit for background execution
|
||||
submit_backtest(backtest.id)
|
||||
|
||||
return {"id": backtest.id, "status": "pending"}
|
||||
|
||||
|
||||
@router.get("", response_model=List[BacktestListItem])
|
||||
async def list_backtests(
|
||||
current_user: CurrentUser,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""List all backtests for current user."""
|
||||
backtests = (
|
||||
db.query(Backtest)
|
||||
.filter(Backtest.user_id == current_user.id)
|
||||
.order_by(Backtest.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
result = []
|
||||
for bt in backtests:
|
||||
item = BacktestListItem(
|
||||
id=bt.id,
|
||||
strategy_type=bt.strategy_type,
|
||||
start_date=bt.start_date,
|
||||
end_date=bt.end_date,
|
||||
rebalance_period=bt.rebalance_period.value,
|
||||
status=bt.status.value,
|
||||
created_at=bt.created_at,
|
||||
total_return=bt.result.total_return if bt.result else None,
|
||||
cagr=bt.result.cagr if bt.result else None,
|
||||
mdd=bt.result.mdd if bt.result else None,
|
||||
)
|
||||
result.append(item)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/{backtest_id}", response_model=BacktestResponse)
|
||||
async def get_backtest(
|
||||
backtest_id: int,
|
||||
current_user: CurrentUser,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get backtest details and results."""
|
||||
backtest = db.query(Backtest).filter(Backtest.id == backtest_id).first()
|
||||
|
||||
if not backtest:
|
||||
raise HTTPException(status_code=404, detail="Backtest not found")
|
||||
|
||||
if backtest.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
|
||||
result_metrics = None
|
||||
if backtest.result:
|
||||
result_metrics = BacktestMetrics(
|
||||
total_return=backtest.result.total_return,
|
||||
cagr=backtest.result.cagr,
|
||||
mdd=backtest.result.mdd,
|
||||
sharpe_ratio=backtest.result.sharpe_ratio,
|
||||
volatility=backtest.result.volatility,
|
||||
benchmark_return=backtest.result.benchmark_return,
|
||||
excess_return=backtest.result.excess_return,
|
||||
)
|
||||
|
||||
return BacktestResponse(
|
||||
id=backtest.id,
|
||||
user_id=backtest.user_id,
|
||||
strategy_type=backtest.strategy_type,
|
||||
strategy_params=backtest.strategy_params or {},
|
||||
start_date=backtest.start_date,
|
||||
end_date=backtest.end_date,
|
||||
rebalance_period=backtest.rebalance_period.value,
|
||||
initial_capital=backtest.initial_capital,
|
||||
commission_rate=backtest.commission_rate,
|
||||
slippage_rate=backtest.slippage_rate,
|
||||
benchmark=backtest.benchmark,
|
||||
status=backtest.status.value,
|
||||
created_at=backtest.created_at,
|
||||
completed_at=backtest.completed_at,
|
||||
error_message=backtest.error_message,
|
||||
result=result_metrics,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{backtest_id}/equity-curve", response_model=List[EquityCurvePoint])
|
||||
async def get_equity_curve(
|
||||
backtest_id: int,
|
||||
current_user: CurrentUser,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get equity curve data for chart."""
|
||||
backtest = db.query(Backtest).filter(Backtest.id == backtest_id).first()
|
||||
|
||||
if not backtest:
|
||||
raise HTTPException(status_code=404, detail="Backtest not found")
|
||||
|
||||
if backtest.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
|
||||
curve = (
|
||||
db.query(BacktestEquityCurve)
|
||||
.filter(BacktestEquityCurve.backtest_id == backtest_id)
|
||||
.order_by(BacktestEquityCurve.date)
|
||||
.all()
|
||||
)
|
||||
|
||||
return [
|
||||
EquityCurvePoint(
|
||||
date=point.date,
|
||||
portfolio_value=point.portfolio_value,
|
||||
benchmark_value=point.benchmark_value,
|
||||
drawdown=point.drawdown,
|
||||
)
|
||||
for point in curve
|
||||
]
|
||||
|
||||
|
||||
@router.get("/{backtest_id}/holdings", response_model=List[RebalanceHoldings])
|
||||
async def get_holdings(
|
||||
backtest_id: int,
|
||||
current_user: CurrentUser,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get holdings at each rebalance date."""
|
||||
backtest = db.query(Backtest).filter(Backtest.id == backtest_id).first()
|
||||
|
||||
if not backtest:
|
||||
raise HTTPException(status_code=404, detail="Backtest not found")
|
||||
|
||||
if backtest.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
|
||||
holdings = (
|
||||
db.query(BacktestHolding)
|
||||
.filter(BacktestHolding.backtest_id == backtest_id)
|
||||
.order_by(BacktestHolding.rebalance_date, BacktestHolding.weight.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
# Group by rebalance date
|
||||
grouped = {}
|
||||
for h in holdings:
|
||||
if h.rebalance_date not in grouped:
|
||||
grouped[h.rebalance_date] = []
|
||||
grouped[h.rebalance_date].append(HoldingItem(
|
||||
ticker=h.ticker,
|
||||
name=h.name,
|
||||
weight=h.weight,
|
||||
shares=h.shares,
|
||||
price=h.price,
|
||||
))
|
||||
|
||||
return [
|
||||
RebalanceHoldings(rebalance_date=date, holdings=items)
|
||||
for date, items in sorted(grouped.items())
|
||||
]
|
||||
|
||||
|
||||
@router.get("/{backtest_id}/transactions", response_model=List[TransactionItem])
|
||||
async def get_transactions(
|
||||
backtest_id: int,
|
||||
current_user: CurrentUser,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get all transactions."""
|
||||
backtest = db.query(Backtest).filter(Backtest.id == backtest_id).first()
|
||||
|
||||
if not backtest:
|
||||
raise HTTPException(status_code=404, detail="Backtest not found")
|
||||
|
||||
if backtest.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
|
||||
transactions = (
|
||||
db.query(BacktestTransaction)
|
||||
.filter(BacktestTransaction.backtest_id == backtest_id)
|
||||
.order_by(BacktestTransaction.date, BacktestTransaction.id)
|
||||
.all()
|
||||
)
|
||||
|
||||
return [
|
||||
TransactionItem(
|
||||
id=t.id,
|
||||
date=t.date,
|
||||
ticker=t.ticker,
|
||||
action=t.action,
|
||||
shares=t.shares,
|
||||
price=t.price,
|
||||
commission=t.commission,
|
||||
)
|
||||
for t in transactions
|
||||
]
|
||||
|
||||
|
||||
@router.delete("/{backtest_id}")
|
||||
async def delete_backtest(
|
||||
backtest_id: int,
|
||||
current_user: CurrentUser,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Delete a backtest and all its data."""
|
||||
backtest = db.query(Backtest).filter(Backtest.id == backtest_id).first()
|
||||
|
||||
if not backtest:
|
||||
raise HTTPException(status_code=404, detail="Backtest not found")
|
||||
|
||||
if backtest.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
|
||||
# Delete related data
|
||||
db.query(BacktestTransaction).filter(
|
||||
BacktestTransaction.backtest_id == backtest_id
|
||||
).delete()
|
||||
db.query(BacktestHolding).filter(
|
||||
BacktestHolding.backtest_id == backtest_id
|
||||
).delete()
|
||||
db.query(BacktestEquityCurve).filter(
|
||||
BacktestEquityCurve.backtest_id == backtest_id
|
||||
).delete()
|
||||
db.query(BacktestResult).filter(
|
||||
BacktestResult.backtest_id == backtest_id
|
||||
).delete()
|
||||
db.delete(backtest)
|
||||
db.commit()
|
||||
|
||||
return {"message": "Backtest deleted"}
|
||||
@ -4,7 +4,7 @@ Galaxy-PO Backend API
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.api import auth_router, admin_router, portfolio_router, strategy_router, market_router
|
||||
from app.api import auth_router, admin_router, portfolio_router, strategy_router, market_router, backtest_router
|
||||
|
||||
app = FastAPI(
|
||||
title="Galaxy-PO API",
|
||||
@ -26,6 +26,7 @@ app.include_router(admin_router)
|
||||
app.include_router(portfolio_router)
|
||||
app.include_router(strategy_router)
|
||||
app.include_router(market_router)
|
||||
app.include_router(backtest_router)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user