galaxis-po/backend/app/api/backtest.py

277 lines
8.3 KiB
Python
Raw Normal View History

"""
Backtest API endpoints.
"""
from collections import defaultdict
from typing import List

from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session, joinedload

from app.core.database import get_db
from app.api.deps import CurrentUser
from app.models.backtest import (
    Backtest, BacktestResult, BacktestEquityCurve,
    BacktestHolding, BacktestTransaction, BacktestStatus,
)
from app.schemas.backtest import (
    BacktestCreate, BacktestResponse, BacktestListItem, BacktestMetrics,
    EquityCurvePoint, RebalanceHoldings, HoldingItem, TransactionItem,
)
from app.services.backtest import submit_backtest
# Every endpoint below is registered on this router: paths are relative to
# the "/api/backtest" prefix and all routes carry the "backtest" tag.
router = APIRouter(tags=["backtest"], prefix="/api/backtest")
@router.post("", response_model=dict)
async def create_backtest(
    request: BacktestCreate,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Persist a new backtest for the caller and hand it off for execution.

    The backtest settings are copied field-by-field from ``request`` and the
    record is stored with status ``BacktestStatus.PENDING``.

    Returns:
        A dict with the new record's ``id`` and the literal status
        ``"pending"``.
    """
    # Request fields copied verbatim onto the Backtest row.
    copied_fields = (
        "strategy_type",
        "strategy_params",
        "start_date",
        "end_date",
        "rebalance_period",
        "initial_capital",
        "commission_rate",
        "slippage_rate",
        "benchmark",
        "top_n",
    )
    new_backtest = Backtest(
        user_id=current_user.id,
        status=BacktestStatus.PENDING,
        **{field: getattr(request, field) for field in copied_fields},
    )

    db.add(new_backtest)
    db.commit()
    # Reload the committed row so its id is populated before use.
    db.refresh(new_backtest)

    # Dispatch to background execution only after the row has been committed.
    submit_backtest(new_backtest.id)
    return {"id": new_backtest.id, "status": "pending"}
@router.get("", response_model=List[BacktestListItem])
async def list_backtests(
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """List all backtests owned by the current user, newest first.

    Each item includes the headline metrics (total return, CAGR, MDD) from the
    backtest's result row; those fields are ``None`` when the backtest has no
    result.
    """
    # NOTE(review): this is an ``async def`` but the SQLAlchemy Session calls
    # are synchronous and block the event loop while they run; consider a
    # plain ``def`` endpoint so FastAPI runs it in its threadpool.
    backtests = (
        db.query(Backtest)
        # Load each backtest's result in the same query; otherwise every
        # ``bt.result`` access below triggers a separate SELECT (N+1).
        .options(joinedload(Backtest.result))
        .filter(Backtest.user_id == current_user.id)
        .order_by(Backtest.created_at.desc())
        .all()
    )

    items = []
    for bt in backtests:
        res = bt.result
        items.append(
            BacktestListItem(
                id=bt.id,
                strategy_type=bt.strategy_type,
                start_date=bt.start_date,
                end_date=bt.end_date,
                rebalance_period=bt.rebalance_period.value,
                status=bt.status.value,
                created_at=bt.created_at,
                total_return=res.total_return if res is not None else None,
                cagr=res.cagr if res is not None else None,
                mdd=res.mdd if res is not None else None,
            )
        )
    return items
@router.get("/{backtest_id}", response_model=BacktestResponse)
async def get_backtest(
    backtest_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return a single backtest's settings, status and summary metrics.

    ``result`` in the response is ``None`` when the backtest has no result
    row.

    Raises:
        HTTPException: 404 when no backtest has ``backtest_id``; 403 when it
            belongs to a different user.
    """
    bt = db.query(Backtest).filter(Backtest.id == backtest_id).first()
    if not bt:
        raise HTTPException(status_code=404, detail="Backtest not found")
    if bt.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="Not authorized")

    summary = bt.result
    metrics = (
        BacktestMetrics(
            total_return=summary.total_return,
            cagr=summary.cagr,
            mdd=summary.mdd,
            sharpe_ratio=summary.sharpe_ratio,
            volatility=summary.volatility,
            benchmark_return=summary.benchmark_return,
            excess_return=summary.excess_return,
        )
        if summary
        else None
    )

    return BacktestResponse(
        id=bt.id,
        user_id=bt.user_id,
        strategy_type=bt.strategy_type,
        # A missing params value is reported as an empty dict.
        strategy_params=bt.strategy_params or {},
        start_date=bt.start_date,
        end_date=bt.end_date,
        rebalance_period=bt.rebalance_period.value,
        initial_capital=bt.initial_capital,
        commission_rate=bt.commission_rate,
        slippage_rate=bt.slippage_rate,
        benchmark=bt.benchmark,
        status=bt.status.value,
        created_at=bt.created_at,
        completed_at=bt.completed_at,
        error_message=bt.error_message,
        result=metrics,
    )
@router.get("/{backtest_id}/equity-curve", response_model=List[EquityCurvePoint])
async def get_equity_curve(
    backtest_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return the equity-curve series for charting, ordered by date.

    Each point carries the portfolio value, benchmark value and drawdown for
    one date.

    Raises:
        HTTPException: 404 when the backtest does not exist; 403 when it is
            owned by another user.
    """
    bt = db.query(Backtest).filter(Backtest.id == backtest_id).first()
    if not bt:
        raise HTTPException(status_code=404, detail="Backtest not found")
    if bt.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="Not authorized")

    ordered_rows = (
        db.query(BacktestEquityCurve)
        .filter(BacktestEquityCurve.backtest_id == backtest_id)
        .order_by(BacktestEquityCurve.date)
        .all()
    )
    return [
        EquityCurvePoint(
            date=row.date,
            portfolio_value=row.portfolio_value,
            benchmark_value=row.benchmark_value,
            drawdown=row.drawdown,
        )
        for row in ordered_rows
    ]
@router.get("/{backtest_id}/holdings", response_model=List[RebalanceHoldings])
async def get_holdings(
    backtest_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return the portfolio holdings at each rebalance date.

    Holdings are grouped by ``rebalance_date`` in ascending order. Within a
    date, they are listed by descending weight.

    Raises:
        HTTPException: 404 when the backtest does not exist; 403 when it is
            owned by another user.
    """
    backtest = db.query(Backtest).filter(Backtest.id == backtest_id).first()
    if not backtest:
        raise HTTPException(status_code=404, detail="Backtest not found")
    if backtest.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="Not authorized")

    holdings = (
        db.query(BacktestHolding)
        .filter(BacktestHolding.backtest_id == backtest_id)
        .order_by(BacktestHolding.rebalance_date, BacktestHolding.weight.desc())
        .all()
    )

    # Group rows per rebalance date. Rows arrive weight-descending within
    # each date, so every group's list keeps that order.
    grouped = defaultdict(list)
    for h in holdings:
        grouped[h.rebalance_date].append(
            HoldingItem(
                ticker=h.ticker,
                name=h.name,
                weight=h.weight,
                shares=h.shares,
                price=h.price,
            )
        )

    return [
        RebalanceHoldings(rebalance_date=rebalance_date, holdings=items)
        for rebalance_date, items in sorted(grouped.items())
    ]
@router.get("/{backtest_id}/transactions", response_model=List[TransactionItem])
async def get_transactions(
    backtest_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return every trade recorded for a backtest, ordered by date then id.

    Raises:
        HTTPException: 404 when the backtest does not exist; 403 when it is
            owned by another user.
    """
    bt = db.query(Backtest).filter(Backtest.id == backtest_id).first()
    if not bt:
        raise HTTPException(status_code=404, detail="Backtest not found")
    if bt.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="Not authorized")

    txn_rows = (
        db.query(BacktestTransaction)
        .filter(BacktestTransaction.backtest_id == backtest_id)
        # The secondary sort on id makes the order of same-date rows deterministic.
        .order_by(BacktestTransaction.date, BacktestTransaction.id)
        .all()
    )
    return [
        TransactionItem(
            id=txn.id,
            date=txn.date,
            ticker=txn.ticker,
            action=txn.action,
            shares=txn.shares,
            price=txn.price,
            commission=txn.commission,
        )
        for txn in txn_rows
    ]
@router.delete("/{backtest_id}")
async def delete_backtest(
    backtest_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Remove a caller-owned backtest together with all rows that reference it.

    Raises:
        HTTPException: 404 when the backtest does not exist; 403 when it is
            owned by another user.
    """
    bt = db.query(Backtest).filter(Backtest.id == backtest_id).first()
    if not bt:
        raise HTTPException(status_code=404, detail="Backtest not found")
    if bt.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="Not authorized")

    # Bulk-delete dependent rows table by table, in this order, before the
    # parent row is removed. A single commit follows all of the deletes.
    dependent_models = (
        BacktestTransaction,
        BacktestHolding,
        BacktestEquityCurve,
        BacktestResult,
    )
    for child_model in dependent_models:
        db.query(child_model).filter(
            child_model.backtest_id == backtest_id
        ).delete()

    db.delete(bt)
    db.commit()
    return {"message": "Backtest deleted"}