galaxis-po/backend/app/api/backtest.py

284 lines
8.6 KiB
Python

"""
Backtest API endpoints.
"""
from typing import List
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.api.deps import CurrentUser
from app.models.backtest import (
Backtest, BacktestResult, BacktestEquityCurve,
BacktestHolding, BacktestTransaction, BacktestStatus,
)
from app.schemas.backtest import (
BacktestCreate, BacktestResponse, BacktestListItem, BacktestMetrics,
EquityCurvePoint, RebalanceHoldings, HoldingItem, TransactionItem,
)
from app.services.backtest import submit_backtest
from app.services.rebalance import RebalanceService
# All backtest endpoints below are mounted under /api/backtest and grouped
# under the "backtest" tag.
router = APIRouter(
    tags=["backtest"],
    prefix="/api/backtest",
)
@router.post("", response_model=dict)
async def create_backtest(
    request: BacktestCreate,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Persist a new PENDING backtest for the caller and submit it for execution.

    Args:
        request: Backtest configuration taken from the request body.
        current_user: Authenticated user who becomes the backtest's owner.
        db: Database session supplied by ``get_db``.

    Returns:
        ``{"id": <new backtest id>, "status": "pending"}``.

    NOTE(review): this is an ``async def`` that performs blocking
    ``Session`` calls (add/commit/refresh), which therefore run on the event
    loop. FastAPI recommends plain ``def`` for synchronous DB access. Confirm
    nothing awaits this function directly before changing it.
    """
    record = Backtest(
        user_id=current_user.id,
        strategy_type=request.strategy_type,
        strategy_params=request.strategy_params,
        start_date=request.start_date,
        end_date=request.end_date,
        rebalance_period=request.rebalance_period,
        initial_capital=request.initial_capital,
        commission_rate=request.commission_rate,
        slippage_rate=request.slippage_rate,
        benchmark=request.benchmark,
        top_n=request.top_n,
        status=BacktestStatus.PENDING,
    )

    # Commit and refresh first so the generated primary key is loaded and the
    # id handed to submit_backtest refers to a persisted row.
    db.add(record)
    db.commit()
    db.refresh(record)

    new_id = record.id
    submit_backtest(new_id)
    return {"id": new_id, "status": "pending"}
@router.get("", response_model=List[BacktestListItem])
async def list_backtests(
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return the caller's backtests, newest first, as summary list items.

    The headline metrics (``total_return``, ``cagr``, ``mdd``) are copied from
    the backtest's result when one exists and are ``None`` otherwise.

    NOTE(review): if ``Backtest.result`` is a lazily loaded relationship, each
    row triggers its own query here (N+1). Consider eager loading; confirm
    the relationship definition first.
    """
    owned = (
        db.query(Backtest)
        .filter(Backtest.user_id == current_user.id)
        .order_by(Backtest.created_at.desc())
        .all()
    )

    def _to_item(bt: Backtest) -> BacktestListItem:
        # Convert one ORM row into its list-view schema.
        res = bt.result
        return BacktestListItem(
            id=bt.id,
            strategy_type=bt.strategy_type,
            start_date=bt.start_date,
            end_date=bt.end_date,
            rebalance_period=bt.rebalance_period.value,
            status=bt.status.value,
            created_at=bt.created_at,
            total_return=res.total_return if res else None,
            cagr=res.cagr if res else None,
            mdd=res.mdd if res else None,
        )

    return [_to_item(bt) for bt in owned]
@router.get("/{backtest_id}", response_model=BacktestResponse)
async def get_backtest(
    backtest_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return the full configuration, status, and result metrics of one backtest.

    Raises:
        HTTPException: 404 if no backtest has ``backtest_id``; 403 if it
            belongs to a different user.

    NOTE(review): the separate 404/403 responses let a caller tell whether
    another user's backtest id exists. Returning 404 in both cases would hide
    that, but would change the API contract.
    """
    bt_row = db.query(Backtest).filter(Backtest.id == backtest_id).first()
    if not bt_row:
        raise HTTPException(status_code=404, detail="Backtest not found")
    if bt_row.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="Not authorized")

    # The metrics block is only present once a result row exists.
    res = bt_row.result
    metrics = (
        BacktestMetrics(
            total_return=res.total_return,
            cagr=res.cagr,
            mdd=res.mdd,
            sharpe_ratio=res.sharpe_ratio,
            volatility=res.volatility,
            benchmark_return=res.benchmark_return,
            excess_return=res.excess_return,
        )
        if res
        else None
    )

    return BacktestResponse(
        id=bt_row.id,
        user_id=bt_row.user_id,
        strategy_type=bt_row.strategy_type,
        # A missing params value is reported as an empty mapping.
        strategy_params=bt_row.strategy_params or {},
        start_date=bt_row.start_date,
        end_date=bt_row.end_date,
        rebalance_period=bt_row.rebalance_period.value,
        initial_capital=bt_row.initial_capital,
        commission_rate=bt_row.commission_rate,
        slippage_rate=bt_row.slippage_rate,
        benchmark=bt_row.benchmark,
        status=bt_row.status.value,
        created_at=bt_row.created_at,
        completed_at=bt_row.completed_at,
        error_message=bt_row.error_message,
        result=metrics,
    )
@router.get("/{backtest_id}/equity-curve", response_model=List[EquityCurvePoint])
async def get_equity_curve(
    backtest_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return the stored equity-curve points of one backtest, ordered by date.

    Raises:
        HTTPException: 404 if the backtest does not exist; 403 if it belongs
            to a different user.
    """
    owner_check = db.query(Backtest).filter(Backtest.id == backtest_id).first()
    if not owner_check:
        raise HTTPException(status_code=404, detail="Backtest not found")
    if owner_check.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="Not authorized")

    curve_query = db.query(BacktestEquityCurve).filter(
        BacktestEquityCurve.backtest_id == backtest_id
    )
    rows = curve_query.order_by(BacktestEquityCurve.date).all()

    def _as_point(row) -> EquityCurvePoint:
        # Map one stored curve row onto the chart schema.
        return EquityCurvePoint(
            date=row.date,
            portfolio_value=row.portfolio_value,
            benchmark_value=row.benchmark_value,
            drawdown=row.drawdown,
        )

    return [_as_point(row) for row in rows]
@router.get("/{backtest_id}/holdings", response_model=List[RebalanceHoldings])
async def get_holdings(
    backtest_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return the backtest's holdings grouped by rebalance date.

    Groups are returned in ascending date order. Within a group, holdings keep
    the query order, which is weight descending.

    Raises:
        HTTPException: 404 if the backtest does not exist; 403 if it belongs
            to a different user.
    """
    parent = db.query(Backtest).filter(Backtest.id == backtest_id).first()
    if not parent:
        raise HTTPException(status_code=404, detail="Backtest not found")
    if parent.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="Not authorized")

    rows = (
        db.query(BacktestHolding)
        .filter(BacktestHolding.backtest_id == backtest_id)
        .order_by(BacktestHolding.rebalance_date, BacktestHolding.weight.desc())
        .all()
    )

    # rebalance_date -> list of HoldingItem, in query order.
    by_date = {}
    for row in rows:
        by_date.setdefault(row.rebalance_date, []).append(
            HoldingItem(
                ticker=row.ticker,
                name=row.name,
                weight=row.weight,
                shares=row.shares,
                price=row.price,
            )
        )

    return [
        RebalanceHoldings(rebalance_date=when, holdings=members)
        for when, members in sorted(by_date.items())
    ]
@router.get("/{backtest_id}/transactions", response_model=List[TransactionItem])
async def get_transactions(
    backtest_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return every transaction of one backtest, ordered by date then id.

    Each item's ``name`` comes from ``RebalanceService.get_stock_names``,
    looked up once for the distinct tickers. It is read with ``.get``, so a
    ticker missing from the returned mapping yields ``None`` (presumably a
    dict; confirm against RebalanceService).

    Raises:
        HTTPException: 404 if the backtest does not exist; 403 if it belongs
            to a different user.
    """
    parent = db.query(Backtest).filter(Backtest.id == backtest_id).first()
    if not parent:
        raise HTTPException(status_code=404, detail="Backtest not found")
    if parent.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="Not authorized")

    tx_rows = (
        db.query(BacktestTransaction)
        .filter(BacktestTransaction.backtest_id == backtest_id)
        .order_by(BacktestTransaction.date, BacktestTransaction.id)
        .all()
    )

    # Resolve display names with one lookup over the distinct tickers.
    distinct_tickers = list({tx.ticker for tx in tx_rows})
    ticker_names = RebalanceService(db).get_stock_names(distinct_tickers)

    return [
        TransactionItem(
            id=tx.id,
            date=tx.date,
            ticker=tx.ticker,
            name=ticker_names.get(tx.ticker),
            action=tx.action,
            shares=tx.shares,
            price=tx.price,
            commission=tx.commission,
        )
        for tx in tx_rows
    ]
@router.delete("/{backtest_id}")
async def delete_backtest(
    backtest_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Delete one of the caller's backtests together with all of its child rows.

    Returns:
        ``{"message": "Backtest deleted"}`` on success.

    Raises:
        HTTPException: 404 if the backtest does not exist; 403 if it belongs
            to a different user.
    """
    target = db.query(Backtest).filter(Backtest.id == backtest_id).first()
    if not target:
        raise HTTPException(status_code=404, detail="Backtest not found")
    if target.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="Not authorized")

    # Child rows are removed with bulk query deletes, in this fixed order,
    # before the parent row. This is presumably to satisfy foreign keys on
    # backtest_id; confirm against the models. Bulk deletes do not apply ORM
    # relationship cascades. Everything is committed once, at the end.
    for child_model in (
        BacktestTransaction,
        BacktestHolding,
        BacktestEquityCurve,
        BacktestResult,
    ):
        db.query(child_model).filter(
            child_model.backtest_id == backtest_id
        ).delete()

    db.delete(target)
    db.commit()
    return {"message": "Backtest deleted"}