galaxis-po/backend/app/api/backtest.py
머니페니 f6db08c9bd feat: improve security, performance, and add missing features
- Remove hardcoded database_url/jwt_secret defaults, require env vars
- Add DB indexes for stocks.market, market_cap, backtests.user_id
- Optimize backtest engine: preload all prices, move stock_names out of loop
- Fix backtest API auth: filter by user_id at query level (6 endpoints)
- Add manual transaction entry modal on portfolio detail page
- Replace console.error with toast.error in signals, backtest, data explorer
- Add backtest delete button with confirmation dialog
- Replace simulated sine chart with real snapshot data
- Add strategy-to-portfolio apply flow with dialog
- Add DC pension risk asset ratio >70% warning on rebalance page
- Add backtest comparison page with metrics table and overlay chart
2026-03-20 12:27:05 +09:00

387 lines
11 KiB
Python

"""
Backtest API endpoints.
"""
from typing import List
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session, joinedload
from app.core.database import get_db
from app.api.deps import CurrentUser
from app.models.backtest import (
Backtest,
BacktestResult,
BacktestEquityCurve,
BacktestHolding,
BacktestTransaction,
BacktestStatus,
WalkForwardResult,
)
from app.schemas.backtest import (
BacktestCreate,
BacktestResponse,
BacktestListItem,
BacktestMetrics,
EquityCurvePoint,
RebalanceHoldings,
HoldingItem,
TransactionItem,
WalkForwardRequest,
WalkForwardWindowResult,
WalkForwardResponse,
)
from app.services.backtest import submit_backtest
from app.services.backtest.walkforward_engine import WalkForwardEngine
from app.services.rebalance import RebalanceService
# Router shared by every endpoint in this module; all routes are registered
# under the /api/backtest prefix and tagged "backtest".
router = APIRouter(prefix="/api/backtest", tags=["backtest"])
@router.post("", response_model=dict)
async def create_backtest(
    request: BacktestCreate,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Store a new backtest for the current user and submit it to run.

    The row is populated from the request payload, owned by ``current_user``
    and created with status ``BacktestStatus.PENDING``. It is committed and
    refreshed before its id is handed to ``submit_backtest``.

    Returns:
        A dict holding the new backtest's ``id`` and the status string
        ``"pending"``.
    """
    # Settings copied verbatim from the request onto the model.
    request_fields = (
        "strategy_type",
        "strategy_params",
        "start_date",
        "end_date",
        "rebalance_period",
        "initial_capital",
        "commission_rate",
        "slippage_rate",
        "benchmark",
        "top_n",
    )
    settings = {field: getattr(request, field) for field in request_fields}

    record = Backtest(
        user_id=current_user.id,
        status=BacktestStatus.PENDING,
        **settings,
    )
    db.add(record)
    db.commit()
    # Reload the row so database-generated values are populated.
    db.refresh(record)

    # Hand the persisted id over for background execution.
    submit_backtest(record.id)
    return {"id": record.id, "status": "pending"}
@router.get("", response_model=List[BacktestListItem])
async def list_backtests(
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return the current user's backtests, most recently created first.

    Each item includes the headline metrics (total return, CAGR, MDD) taken
    from the backtest's result; those fields are ``None`` when the backtest
    has no result.
    """
    # The result relationship is loaded together with the backtests because
    # every list item reads it.
    query = db.query(Backtest).options(joinedload(Backtest.result))
    query = query.filter(Backtest.user_id == current_user.id)
    rows = query.order_by(Backtest.created_at.desc()).all()

    return [
        BacktestListItem(
            id=row.id,
            strategy_type=row.strategy_type,
            start_date=row.start_date,
            end_date=row.end_date,
            rebalance_period=row.rebalance_period.value,
            status=row.status.value,
            created_at=row.created_at,
            total_return=row.result.total_return if row.result else None,
            cagr=row.result.cagr if row.result else None,
            mdd=row.result.mdd if row.result else None,
        )
        for row in rows
    ]
@router.get("/{backtest_id}", response_model=BacktestResponse)
async def get_backtest(
    backtest_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return one backtest's configuration, status and result metrics.

    Raises:
        HTTPException: 404 when no backtest with this id belongs to the
            current user.
    """
    # Ownership is enforced in the query itself, so another user's backtest
    # is indistinguishable from a missing one.
    owned = db.query(Backtest).filter(
        Backtest.id == backtest_id,
        Backtest.user_id == current_user.id,
    )
    backtest = owned.first()
    if not backtest:
        raise HTTPException(status_code=404, detail="Backtest not found")

    # Metrics are only present once a result has been stored.
    stored = backtest.result
    metrics = (
        BacktestMetrics(
            total_return=stored.total_return,
            cagr=stored.cagr,
            mdd=stored.mdd,
            sharpe_ratio=stored.sharpe_ratio,
            volatility=stored.volatility,
            benchmark_return=stored.benchmark_return,
            excess_return=stored.excess_return,
        )
        if stored
        else None
    )

    return BacktestResponse(
        id=backtest.id,
        user_id=backtest.user_id,
        strategy_type=backtest.strategy_type,
        # Fall back to an empty dict when no parameters were stored.
        strategy_params=backtest.strategy_params or {},
        start_date=backtest.start_date,
        end_date=backtest.end_date,
        rebalance_period=backtest.rebalance_period.value,
        initial_capital=backtest.initial_capital,
        commission_rate=backtest.commission_rate,
        slippage_rate=backtest.slippage_rate,
        benchmark=backtest.benchmark,
        status=backtest.status.value,
        created_at=backtest.created_at,
        completed_at=backtest.completed_at,
        error_message=backtest.error_message,
        result=metrics,
    )
@router.get("/{backtest_id}/equity-curve", response_model=List[EquityCurvePoint])
async def get_equity_curve(
    backtest_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return the backtest's equity-curve points ordered by date.

    Raises:
        HTTPException: 404 when no backtest with this id belongs to the
            current user.
    """
    # Confirm the backtest exists and is owned by the caller before exposing
    # any of its child rows.
    owned = db.query(Backtest).filter(
        Backtest.id == backtest_id,
        Backtest.user_id == current_user.id,
    )
    if not owned.first():
        raise HTTPException(status_code=404, detail="Backtest not found")

    curve_query = db.query(BacktestEquityCurve).filter(
        BacktestEquityCurve.backtest_id == backtest_id
    )
    rows = curve_query.order_by(BacktestEquityCurve.date).all()

    return [
        EquityCurvePoint(
            date=row.date,
            portfolio_value=row.portfolio_value,
            benchmark_value=row.benchmark_value,
            drawdown=row.drawdown,
        )
        for row in rows
    ]
@router.get("/{backtest_id}/holdings", response_model=List[RebalanceHoldings])
async def get_holdings(
    backtest_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return the backtest's holdings grouped by rebalance date.

    Groups are returned in ascending rebalance-date order; within a group the
    holdings keep the query order (weight descending).

    Raises:
        HTTPException: 404 when no backtest with this id belongs to the
            current user.
    """
    # Confirm the backtest exists and is owned by the caller before exposing
    # any of its child rows.
    owned = db.query(Backtest).filter(
        Backtest.id == backtest_id,
        Backtest.user_id == current_user.id,
    )
    if not owned.first():
        raise HTTPException(status_code=404, detail="Backtest not found")

    holdings_query = db.query(BacktestHolding).filter(
        BacktestHolding.backtest_id == backtest_id
    )
    rows = holdings_query.order_by(
        BacktestHolding.rebalance_date,
        BacktestHolding.weight.desc(),
    ).all()

    # Bucket the rows by rebalance date, preserving row order per bucket.
    by_date = {}
    for row in rows:
        by_date.setdefault(row.rebalance_date, []).append(
            HoldingItem(
                ticker=row.ticker,
                name=row.name,
                weight=row.weight,
                shares=row.shares,
                price=row.price,
            )
        )

    return [
        RebalanceHoldings(rebalance_date=day, holdings=by_date[day])
        for day in sorted(by_date)
    ]
@router.get("/{backtest_id}/transactions", response_model=List[TransactionItem])
async def get_transactions(
    backtest_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return every transaction of the backtest, ordered by date then id.

    Stock names are resolved through ``RebalanceService.get_stock_names`` and
    attached to each transaction.

    Raises:
        HTTPException: 404 when no backtest with this id belongs to the
            current user.
    """
    # Confirm the backtest exists and is owned by the caller before exposing
    # any of its child rows.
    owned = db.query(Backtest).filter(
        Backtest.id == backtest_id,
        Backtest.user_id == current_user.id,
    )
    if not owned.first():
        raise HTTPException(status_code=404, detail="Backtest not found")

    tx_query = db.query(BacktestTransaction).filter(
        BacktestTransaction.backtest_id == backtest_id
    )
    rows = tx_query.order_by(
        BacktestTransaction.date,
        BacktestTransaction.id,
    ).all()

    # Look up names once for the distinct tickers instead of per transaction.
    unique_tickers = {row.ticker for row in rows}
    name_by_ticker = RebalanceService(db).get_stock_names(list(unique_tickers))

    return [
        TransactionItem(
            id=row.id,
            date=row.date,
            ticker=row.ticker,
            # presumably a dict-like mapping, so unknown tickers give None
            # -- confirm against RebalanceService.get_stock_names
            name=name_by_ticker.get(row.ticker),
            action=row.action,
            shares=row.shares,
            price=row.price,
            commission=row.commission,
        )
        for row in rows
    ]
@router.post("/{backtest_id}/walkforward", response_model=dict)
async def run_walkforward(
    backtest_id: int,
    request: WalkForwardRequest,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Run walk-forward analysis on a completed backtest.

    The window sizes (``train_months``, ``test_months``, ``step_months``) are
    taken from the request and passed to ``WalkForwardEngine.run``.

    Returns:
        A dict with status ``"completed"`` and the ``backtest_id``.

    Raises:
        HTTPException: 404 when no backtest with this id belongs to the
            current user; 400 when the backtest is not in the COMPLETED
            state or when the engine rejects the request with a
            ``ValueError``.
    """
    # Ownership is enforced in the query itself, so another user's backtest
    # is indistinguishable from a missing one.
    backtest = (
        db.query(Backtest)
        .filter(Backtest.id == backtest_id, Backtest.user_id == current_user.id)
        .first()
    )
    if not backtest:
        raise HTTPException(status_code=404, detail="Backtest not found")

    if backtest.status != BacktestStatus.COMPLETED:
        raise HTTPException(
            status_code=400,
            detail="백테스트가 완료된 상태에서만 walk-forward 분석이 가능합니다",
        )

    engine = WalkForwardEngine(db)
    # NOTE(review): engine.run is invoked synchronously inside an async
    # handler; if it is long-running it would hold up the event loop --
    # confirm whether it should be offloaded.
    try:
        engine.run(
            backtest_id=backtest_id,
            train_months=request.train_months,
            test_months=request.test_months,
            step_months=request.step_months,
        )
    except ValueError as exc:
        # Chain the original error so tracebacks and logs keep the root
        # cause; the client still receives only the message text.
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    return {"status": "completed", "backtest_id": backtest_id}
@router.get("/{backtest_id}/walkforward", response_model=WalkForwardResponse)
async def get_walkforward(
    backtest_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return the stored walk-forward windows of a backtest, by window index.

    Raises:
        HTTPException: 404 when no backtest with this id belongs to the
            current user.
    """
    # Confirm the backtest exists and is owned by the caller before exposing
    # any of its child rows.
    owned = db.query(Backtest).filter(
        Backtest.id == backtest_id,
        Backtest.user_id == current_user.id,
    )
    if not owned.first():
        raise HTTPException(status_code=404, detail="Backtest not found")

    window_query = db.query(WalkForwardResult).filter(
        WalkForwardResult.backtest_id == backtest_id
    )
    rows = window_query.order_by(WalkForwardResult.window_index).all()

    windows = [
        WalkForwardWindowResult(
            window_index=row.window_index,
            train_start=row.train_start,
            train_end=row.train_end,
            test_start=row.test_start,
            test_end=row.test_end,
            test_return=row.test_return,
            test_sharpe=row.test_sharpe,
            test_mdd=row.test_mdd,
        )
        for row in rows
    ]
    return WalkForwardResponse(backtest_id=backtest_id, windows=windows)
@router.delete("/{backtest_id}")
async def delete_backtest(
    backtest_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Delete a backtest together with all rows that reference it.

    Child rows (walk-forward results, transactions, holdings, equity curve,
    result) are removed first, then the backtest itself, and everything is
    committed in one transaction.

    Returns:
        A dict with a confirmation ``message``.

    Raises:
        HTTPException: 404 when no backtest with this id belongs to the
            current user.
    """
    # Ownership is enforced in the query itself, so another user's backtest
    # is indistinguishable from a missing one.
    owned = db.query(Backtest).filter(
        Backtest.id == backtest_id,
        Backtest.user_id == current_user.id,
    )
    backtest = owned.first()
    if not backtest:
        raise HTTPException(status_code=404, detail="Backtest not found")

    # Child tables, listed in the order their rows are removed.
    child_models = (
        WalkForwardResult,
        BacktestTransaction,
        BacktestHolding,
        BacktestEquityCurve,
        BacktestResult,
    )
    for child_model in child_models:
        db.query(child_model).filter(
            child_model.backtest_id == backtest_id
        ).delete()

    db.delete(backtest)
    db.commit()
    return {"message": "Backtest deleted"}