- VirtualPortfolio for portfolio simulation - BacktestEngine for strategy backtesting - Worker for async background execution Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
302 lines
10 KiB
Python
302 lines
10 KiB
Python
"""
|
|
Main backtest engine.
|
|
"""
|
|
from datetime import date, timedelta
|
|
from decimal import Decimal
|
|
from typing import List, Dict, Optional
|
|
from dateutil.relativedelta import relativedelta
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.models.backtest import (
|
|
Backtest, BacktestResult, BacktestEquityCurve,
|
|
BacktestHolding, BacktestTransaction, RebalancePeriod,
|
|
)
|
|
from app.models.stock import Stock, Price
|
|
from app.services.backtest.portfolio import VirtualPortfolio, Transaction
|
|
from app.services.backtest.metrics import MetricsCalculator
|
|
from app.services.strategy import MultiFactorStrategy, QualityStrategy, ValueMomentumStrategy
|
|
from app.schemas.strategy import UniverseFilter, FactorWeights
|
|
|
|
|
|
class BacktestEngine:
    """
    Main backtest engine that simulates strategy performance over time.

    ``run`` walks every trading day between a backtest's start and end dates,
    rebalances a ``VirtualPortfolio`` to the strategy's picks on rebalance
    days, records daily portfolio and benchmark values, then computes metrics
    and persists everything through the SQLAlchemy session.
    """

    # Benchmark name -> ticker whose close prices stand in for the benchmark.
    # Using KODEX 200 (069500) as KOSPI proxy. Unrecognised benchmark names
    # fall back to the same proxy (this matches the previous behavior).
    _BENCHMARK_TICKERS: Dict[str, str] = {"KOSPI": "069500"}
    _DEFAULT_BENCHMARK_TICKER = "069500"

    def __init__(self, db: Session):
        """Store the database session used for every query and write."""
        self.db = db

    def run(self, backtest_id: int) -> None:
        """Execute backtest and save results.

        Args:
            backtest_id: Primary key of the ``Backtest`` row to simulate.

        Raises:
            ValueError: If the backtest does not exist, the period contains no
                trading days, or the strategy type is unknown.
        """
        # NOTE(review): Query.get is legacy in SQLAlchemy 2.0; switch to
        # self.db.get(Backtest, backtest_id) if the project is on >= 1.4.
        backtest = self.db.query(Backtest).get(backtest_id)
        if not backtest:
            raise ValueError(f"Backtest {backtest_id} not found")

        portfolio = VirtualPortfolio(backtest.initial_capital)

        scheduled_dates = self._generate_rebalance_dates(
            backtest.start_date,
            backtest.end_date,
            backtest.rebalance_period,
        )

        trading_days = self._get_trading_days(
            backtest.start_date,
            backtest.end_date,
        )
        if not trading_days:
            raise ValueError("No trading days found in the specified period")

        # Calendar rebalance dates often land on weekends/holidays. Requiring an
        # exact match used to skip those rebalances entirely (including the
        # initial buy when start_date is not a trading day), so snap each one to
        # the first trading day on or after it.
        rebalance_days = self._align_to_trading_days(scheduled_dates, trading_days)

        # One benchmark price per trading day, gaps forward-filled.
        benchmark_series = self._forward_fill_benchmark(
            self._load_benchmark_prices(
                backtest.benchmark,
                backtest.start_date,
                backtest.end_date,
            ),
            trading_days,
        )
        initial_benchmark = benchmark_series[0]  # never zero by construction

        strategy = self._create_strategy(
            backtest.strategy_type,
            backtest.strategy_params,
            backtest.top_n,
        )

        # Names do not change during the simulation: load them once instead of
        # re-reading the whole Stock table on every trading day.
        names = self._get_stock_names()

        equity_curve_data: List[Dict] = []
        all_transactions: List[tuple] = []  # (trade date, Transaction) pairs
        holdings_history: List[Dict] = []

        for trading_date, benchmark_value in zip(trading_days, benchmark_series):
            prices = self._get_prices_for_date(trading_date)

            if trading_date in rebalance_days:
                # Run strategy to get target stocks as of this date
                target_stocks = strategy.run(
                    universe_filter=UniverseFilter(),
                    top_n=backtest.top_n,
                    base_date=trading_date,
                )
                target_tickers = [s.ticker for s in target_stocks.stocks]

                transactions = portfolio.rebalance(
                    target_tickers=target_tickers,
                    prices=prices,
                    names=names,
                    commission_rate=backtest.commission_rate,
                    slippage_rate=backtest.slippage_rate,
                )
                all_transactions.extend((trading_date, txn) for txn in transactions)

                # Snapshot post-rebalance holdings (persisted with rebalance_date).
                holdings_history.append({
                    'date': trading_date,
                    'holdings': portfolio.get_holdings_with_weights(prices, names),
                })

            # NOTE(review): `prices` only has tickers with a Price row on this
            # date; how get_value treats a held ticker without one (e.g. a
            # trading halt) is up to VirtualPortfolio — confirm.
            portfolio_value = portfolio.get_value(prices)

            # Rebase the benchmark so it starts at the portfolio's initial capital.
            normalized_benchmark = (
                benchmark_value / initial_benchmark * backtest.initial_capital
            )

            equity_curve_data.append({
                'date': trading_date,
                'portfolio_value': portfolio_value,
                'benchmark_value': normalized_benchmark,
            })

        # Calculate metrics
        portfolio_values = [Decimal(str(e['portfolio_value'])) for e in equity_curve_data]
        benchmark_values = [Decimal(str(e['benchmark_value'])) for e in equity_curve_data]

        metrics = MetricsCalculator.calculate_all(portfolio_values, benchmark_values)
        drawdowns = MetricsCalculator.calculate_drawdown_series(portfolio_values)

        self._save_results(
            backtest_id=backtest_id,
            metrics=metrics,
            equity_curve_data=equity_curve_data,
            drawdowns=drawdowns,
            holdings_history=holdings_history,
            transactions=all_transactions,
        )

    def _generate_rebalance_dates(
        self,
        start_date: date,
        end_date: date,
        period: RebalancePeriod,
    ) -> List[date]:
        """Generate the calendar schedule of rebalance dates.

        Each date is ``start_date`` plus a whole multiple of the period (rather
        than the previous date plus one period), so a schedule anchored on the
        31st returns to the 31st, or the month's last day, instead of drifting
        to the 28th after February.

        Returns:
            Ascending dates from ``start_date`` up to and including ``end_date``.
        """
        if period == RebalancePeriod.MONTHLY:
            months_per_step = 1
        elif period == RebalancePeriod.QUARTERLY:
            months_per_step = 3
        elif period == RebalancePeriod.SEMI_ANNUAL:
            months_per_step = 6
        else:  # ANNUAL
            months_per_step = 12

        dates = []
        step = 0
        current = start_date
        while current <= end_date:
            dates.append(current)
            step += 1
            current = start_date + relativedelta(months=months_per_step * step)

        return dates

    @staticmethod
    def _align_to_trading_days(
        scheduled_dates: List[date],
        trading_days: List[date],
    ) -> set:
        """Map each scheduled date to the first trading day on or after it.

        Both lists must be sorted ascending. Scheduled dates after the last
        trading day are dropped; several scheduled dates that map to the same
        trading day collapse into one rebalance.
        """
        aligned = set()
        idx = 0
        for scheduled in scheduled_dates:
            while idx < len(trading_days) and trading_days[idx] < scheduled:
                idx += 1
            if idx == len(trading_days):
                break
            aligned.add(trading_days[idx])
        return aligned

    @staticmethod
    def _forward_fill_benchmark(
        benchmark_prices: Dict[date, Decimal],
        trading_days: List[date],
    ) -> List[Decimal]:
        """Return one benchmark price per trading day, carrying gaps forward.

        Missing or zero/None prices reuse the last valid price. Days before
        the first valid price use that first price, so the normalized benchmark
        starts flat at initial capital. If there is no valid price at all,
        every day gets ``Decimal("1")`` (a flat benchmark).
        """
        first_valid = next(
            (benchmark_prices[d] for d in trading_days if benchmark_prices.get(d)),
            Decimal("1"),
        )

        series = []
        last = first_valid
        for trading_date in trading_days:
            price = benchmark_prices.get(trading_date)
            if price:
                last = price
            series.append(last)
        return series

    def _get_trading_days(self, start_date: date, end_date: date) -> List[date]:
        """Get list of trading days from price data.

        A trading day is any date with at least one Price row in
        [start_date, end_date]; results are distinct and ascending.
        """
        prices = (
            self.db.query(Price.date)
            .filter(Price.date >= start_date)
            .filter(Price.date <= end_date)
            .distinct()
            .order_by(Price.date)
            .all()
        )
        return [p[0] for p in prices]

    def _load_benchmark_prices(
        self,
        benchmark: str,
        start_date: date,
        end_date: date,
    ) -> Dict[date, Decimal]:
        """Load close prices of the benchmark's proxy ticker, keyed by date."""
        benchmark_ticker = self._BENCHMARK_TICKERS.get(
            benchmark, self._DEFAULT_BENCHMARK_TICKER
        )

        prices = (
            self.db.query(Price)
            .filter(Price.ticker == benchmark_ticker)
            .filter(Price.date >= start_date)
            .filter(Price.date <= end_date)
            .all()
        )
        return {p.date: p.close for p in prices}

    def _get_prices_for_date(self, trading_date: date) -> Dict[str, Decimal]:
        """Get the close price of every ticker with a Price row on ``trading_date``."""
        prices = (
            self.db.query(Price)
            .filter(Price.date == trading_date)
            .all()
        )
        return {p.ticker: p.close for p in prices}

    def _get_stock_names(self) -> Dict[str, str]:
        """Get a ticker -> name map for every Stock row."""
        stocks = self.db.query(Stock).all()
        return {s.ticker: s.name for s in stocks}

    def _create_strategy(
        self,
        strategy_type: str,
        strategy_params: Optional[dict],
        top_n: int,
    ):
        """Create strategy instance based on type.

        Args:
            strategy_type: One of "multi_factor", "quality", "value_momentum".
            strategy_params: Strategy-specific overrides; None is treated as {}.
            top_n: Currently unused here (top_n is passed to ``strategy.run``).

        Raises:
            ValueError: If ``strategy_type`` is not recognised.
        """
        # A nullable params column would otherwise crash on `.get`.
        params = strategy_params or {}

        if strategy_type == "multi_factor":
            strategy = MultiFactorStrategy(self.db)
            strategy._weights = FactorWeights(**params.get("weights", {}))
        elif strategy_type == "quality":
            strategy = QualityStrategy(self.db)
            strategy._min_fscore = params.get("min_fscore", 7)
        elif strategy_type == "value_momentum":
            strategy = ValueMomentumStrategy(self.db)
            # Route through str so float params become exact Decimals.
            strategy._value_weight = Decimal(str(params.get("value_weight", 0.5)))
            strategy._momentum_weight = Decimal(str(params.get("momentum_weight", 0.5)))
        else:
            raise ValueError(f"Unknown strategy type: {strategy_type}")

        return strategy

    def _save_results(
        self,
        backtest_id: int,
        metrics,
        equity_curve_data: List[Dict],
        drawdowns: List[Decimal],
        holdings_history: List[Dict],
        transactions: List,
    ) -> None:
        """Save all backtest results to database in a single commit.

        Args:
            transactions: ``(trade date, Transaction)`` pairs.

        On any failure the session is rolled back (so the caller can keep
        using it, e.g. to mark the backtest as failed) and the error re-raised.
        """
        try:
            self.db.add(BacktestResult(
                backtest_id=backtest_id,
                total_return=metrics.total_return,
                cagr=metrics.cagr,
                mdd=metrics.mdd,
                sharpe_ratio=metrics.sharpe_ratio,
                volatility=metrics.volatility,
                benchmark_return=metrics.benchmark_return,
                excess_return=metrics.excess_return,
            ))

            for i, point in enumerate(equity_curve_data):
                self.db.add(BacktestEquityCurve(
                    backtest_id=backtest_id,
                    date=point['date'],
                    portfolio_value=point['portfolio_value'],
                    benchmark_value=point['benchmark_value'],
                    # Guard against a drawdown series shorter than the curve.
                    drawdown=drawdowns[i] if i < len(drawdowns) else Decimal("0"),
                ))

            for record in holdings_history:
                for holding in record['holdings']:
                    self.db.add(BacktestHolding(
                        backtest_id=backtest_id,
                        rebalance_date=record['date'],
                        ticker=holding.ticker,
                        name=holding.name,
                        weight=holding.weight,
                        shares=holding.shares,
                        price=holding.price,
                    ))

            for trading_date, txn in transactions:
                self.db.add(BacktestTransaction(
                    backtest_id=backtest_id,
                    date=trading_date,
                    ticker=txn.ticker,
                    action=txn.action,
                    shares=txn.shares,
                    price=txn.price,
                    commission=txn.commission,
                ))

            self.db.commit()
        except Exception:
            self.db.rollback()
            raise
|