zephyrdark c1ee879cb4 feat: add backtest services (portfolio, engine, worker)
- VirtualPortfolio for portfolio simulation
- BacktestEngine for strategy backtesting
- Worker for async background execution

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 11:34:48 +09:00

302 lines
10 KiB
Python

"""
Main backtest engine.
"""
from bisect import bisect_left
from datetime import date, timedelta
from decimal import Decimal
from typing import Dict, List, Optional, Set, Tuple

from dateutil.relativedelta import relativedelta
from sqlalchemy.orm import Session

from app.models.backtest import (
    Backtest, BacktestResult, BacktestEquityCurve,
    BacktestHolding, BacktestTransaction, RebalancePeriod,
)
from app.models.stock import Stock, Price
from app.services.backtest.portfolio import VirtualPortfolio, Transaction
from app.services.backtest.metrics import MetricsCalculator
from app.services.strategy import MultiFactorStrategy, QualityStrategy, ValueMomentumStrategy
from app.schemas.strategy import UniverseFilter, FactorWeights
class BacktestEngine:
    """
    Main backtest engine that simulates strategy performance over time.

    For every trading day in the backtest period (days that have rows in the
    price table), the engine rebalances the virtual portfolio on scheduled
    rebalance days, records the portfolio value and a benchmark value
    normalized to the initial capital, then computes metrics and persists
    all results.
    """

    # Benchmark name -> ticker used as its proxy. KODEX 200 (069500) stands
    # in for KOSPI and is also the fallback for any other benchmark name.
    _BENCHMARK_TICKERS: Dict[str, str] = {"KOSPI": "069500"}
    _DEFAULT_BENCHMARK_TICKER = "069500"

    def __init__(self, db: "Session"):
        self.db = db

    def run(self, backtest_id: int) -> None:
        """
        Execute a backtest and save its results.

        Args:
            backtest_id: Primary key of the ``Backtest`` row to simulate.

        Raises:
            ValueError: If the backtest does not exist or the period contains
                no trading days.
        """
        backtest = self.db.query(Backtest).get(backtest_id)
        if not backtest:
            raise ValueError(f"Backtest {backtest_id} not found")

        portfolio = VirtualPortfolio(backtest.initial_capital)

        scheduled_dates = self._generate_rebalance_dates(
            backtest.start_date,
            backtest.end_date,
            backtest.rebalance_period,
        )
        trading_days = self._get_trading_days(
            backtest.start_date,
            backtest.end_date,
        )
        if not trading_days:
            raise ValueError("No trading days found in the specified period")

        # Scheduled dates are calendar dates and may land on weekends or
        # holidays; shift each to the next trading day so no rebalance
        # (including the initial investment) is silently skipped.
        rebalance_days = self._align_rebalance_dates(scheduled_dates, trading_days)

        benchmark_prices = self._load_benchmark_prices(
            backtest.benchmark,
            backtest.start_date,
            backtest.end_date,
        )
        strategy = self._create_strategy(
            backtest.strategy_type,
            backtest.strategy_params,
            backtest.top_n,
        )

        # Stock names do not depend on the trading day; load them once.
        names = self._get_stock_names()

        equity_curve_data: List[Dict] = []
        all_transactions: List[Tuple[date, Transaction]] = []
        holdings_history: List[Dict] = []

        initial_benchmark = self._initial_benchmark_value(benchmark_prices, trading_days)
        last_benchmark = initial_benchmark

        for trading_date in trading_days:
            prices = self._get_prices_for_date(trading_date)

            if trading_date in rebalance_days:
                target_stocks = strategy.run(
                    universe_filter=UniverseFilter(),
                    top_n=backtest.top_n,
                    base_date=trading_date,
                )
                target_tickers = [s.ticker for s in target_stocks.stocks]
                transactions = portfolio.rebalance(
                    target_tickers=target_tickers,
                    prices=prices,
                    names=names,
                    commission_rate=backtest.commission_rate,
                    slippage_rate=backtest.slippage_rate,
                )
                all_transactions.extend((trading_date, txn) for txn in transactions)
                holdings_history.append({
                    'date': trading_date,
                    'holdings': portfolio.get_holdings_with_weights(prices, names),
                })

            portfolio_value = portfolio.get_value(prices)

            # Carry the last known benchmark close forward over days without a
            # quote, rather than snapping back to the starting value.
            benchmark_close = benchmark_prices.get(trading_date)
            if benchmark_close:
                last_benchmark = benchmark_close
            # Normalize the benchmark so it starts at the initial capital.
            normalized_benchmark = (
                last_benchmark / initial_benchmark * backtest.initial_capital
            )

            equity_curve_data.append({
                'date': trading_date,
                'portfolio_value': portfolio_value,
                'benchmark_value': normalized_benchmark,
            })

        portfolio_values = [Decimal(str(e['portfolio_value'])) for e in equity_curve_data]
        benchmark_values = [Decimal(str(e['benchmark_value'])) for e in equity_curve_data]
        metrics = MetricsCalculator.calculate_all(portfolio_values, benchmark_values)
        drawdowns = MetricsCalculator.calculate_drawdown_series(portfolio_values)

        self._save_results(
            backtest_id=backtest_id,
            metrics=metrics,
            equity_curve_data=equity_curve_data,
            drawdowns=drawdowns,
            holdings_history=holdings_history,
            transactions=all_transactions,
        )

    def _generate_rebalance_dates(
        self,
        start_date: date,
        end_date: date,
        period: "RebalancePeriod",
    ) -> List[date]:
        """
        Generate the scheduled (calendar) rebalance dates.

        Each date is computed from ``start_date`` directly (start + i * step)
        so that month-end clamping does not accumulate drift
        (e.g. Jan 31 -> Feb 29 -> Mar 31, not Mar 29).
        Periods other than MONTHLY/QUARTERLY/SEMI_ANNUAL are treated as ANNUAL.
        """
        if period == RebalancePeriod.MONTHLY:
            step_months = 1
        elif period == RebalancePeriod.QUARTERLY:
            step_months = 3
        elif period == RebalancePeriod.SEMI_ANNUAL:
            step_months = 6
        else:  # ANNUAL
            step_months = 12

        dates: List[date] = []
        index = 0
        while True:
            current = start_date + relativedelta(months=step_months * index)
            if current > end_date:
                break
            dates.append(current)
            index += 1
        return dates

    @staticmethod
    def _align_rebalance_dates(
        scheduled_dates: List[date],
        trading_days: List[date],
    ) -> Set[date]:
        """
        Map each scheduled date to the first trading day on or after it.

        ``trading_days`` must be sorted ascending. Scheduled dates after the
        last trading day are dropped, and scheduled dates that map to the
        same trading day collapse into a single rebalance.
        """
        aligned: Set[date] = set()
        for scheduled in scheduled_dates:
            index = bisect_left(trading_days, scheduled)
            if index < len(trading_days):
                aligned.add(trading_days[index])
        return aligned

    @staticmethod
    def _initial_benchmark_value(
        benchmark_prices: Dict[date, Decimal],
        trading_days: List[date],
    ) -> Decimal:
        """
        Return the first positive benchmark close within ``trading_days``.

        Falls back to ``Decimal("1")`` when there is no usable quote, which
        makes the normalized benchmark a flat line at the initial capital.
        """
        for trading_date in trading_days:
            close = benchmark_prices.get(trading_date)
            if close:  # skips missing (None) and zero closes
                return close
        return Decimal("1")

    def _get_trading_days(self, start_date: date, end_date: date) -> List[date]:
        """Return the distinct dates with price data in [start_date, end_date], ascending."""
        rows = (
            self.db.query(Price.date)
            .filter(Price.date >= start_date)
            .filter(Price.date <= end_date)
            .distinct()
            .order_by(Price.date)
            .all()
        )
        return [row[0] for row in rows]

    def _load_benchmark_prices(
        self,
        benchmark: str,
        start_date: date,
        end_date: date,
    ) -> Dict[date, Decimal]:
        """Load closing prices of the benchmark's proxy ticker, keyed by date."""
        benchmark_ticker = self._BENCHMARK_TICKERS.get(
            benchmark, self._DEFAULT_BENCHMARK_TICKER
        )
        rows = (
            self.db.query(Price)
            .filter(Price.ticker == benchmark_ticker)
            .filter(Price.date >= start_date)
            .filter(Price.date <= end_date)
            .all()
        )
        return {p.date: p.close for p in rows}

    def _get_prices_for_date(self, trading_date: date) -> Dict[str, Decimal]:
        """Return ``{ticker: close}`` for every stock priced on ``trading_date``."""
        rows = (
            self.db.query(Price)
            .filter(Price.date == trading_date)
            .all()
        )
        return {p.ticker: p.close for p in rows}

    def _get_stock_names(self) -> Dict[str, str]:
        """Return ``{ticker: name}`` for every stock."""
        stocks = self.db.query(Stock).all()
        return {s.ticker: s.name for s in stocks}

    def _create_strategy(
        self,
        strategy_type: str,
        strategy_params: Optional[dict],
        top_n: int,
    ):
        """
        Create a strategy instance for ``strategy_type`` configured from params.

        ``strategy_params`` may be ``None``, in which case defaults are used.
        ``top_n`` is currently unused here; it is passed to ``strategy.run``
        by the caller instead.

        Raises:
            ValueError: If ``strategy_type`` is not recognised.
        """
        params = strategy_params or {}
        if strategy_type == "multi_factor":
            strategy = MultiFactorStrategy(self.db)
            strategy._weights = FactorWeights(**params.get("weights", {}))
        elif strategy_type == "quality":
            strategy = QualityStrategy(self.db)
            strategy._min_fscore = params.get("min_fscore", 7)
        elif strategy_type == "value_momentum":
            strategy = ValueMomentumStrategy(self.db)
            # Route through str() so float params become exact decimals.
            strategy._value_weight = Decimal(str(params.get("value_weight", 0.5)))
            strategy._momentum_weight = Decimal(str(params.get("momentum_weight", 0.5)))
        else:
            raise ValueError(f"Unknown strategy type: {strategy_type}")
        return strategy

    def _save_results(
        self,
        backtest_id: int,
        metrics,
        equity_curve_data: List[Dict],
        drawdowns: List[Decimal],
        holdings_history: List[Dict],
        transactions: List,
    ) -> None:
        """
        Persist metrics, equity curve, holdings history and transactions,
        then commit.

        ``transactions`` is a list of ``(date, txn)`` pairs. On any failure the
        session is rolled back, so it stays usable, and the error is re-raised.
        """
        try:
            self.db.add(BacktestResult(
                backtest_id=backtest_id,
                total_return=metrics.total_return,
                cagr=metrics.cagr,
                mdd=metrics.mdd,
                sharpe_ratio=metrics.sharpe_ratio,
                volatility=metrics.volatility,
                benchmark_return=metrics.benchmark_return,
                excess_return=metrics.excess_return,
            ))

            for i, point in enumerate(equity_curve_data):
                # Missing drawdown entries are recorded as zero.
                drawdown = drawdowns[i] if i < len(drawdowns) else Decimal("0")
                self.db.add(BacktestEquityCurve(
                    backtest_id=backtest_id,
                    date=point['date'],
                    portfolio_value=point['portfolio_value'],
                    benchmark_value=point['benchmark_value'],
                    drawdown=drawdown,
                ))

            for record in holdings_history:
                for holding in record['holdings']:
                    self.db.add(BacktestHolding(
                        backtest_id=backtest_id,
                        rebalance_date=record['date'],
                        ticker=holding.ticker,
                        name=holding.name,
                        weight=holding.weight,
                        shares=holding.shares,
                        price=holding.price,
                    ))

            for trading_date, txn in transactions:
                self.db.add(BacktestTransaction(
                    backtest_id=backtest_id,
                    date=trading_date,
                    ticker=txn.ticker,
                    action=txn.action,
                    shares=txn.shares,
                    price=txn.price,
                    commission=txn.commission,
                ))

            self.db.commit()
        except Exception:
            self.db.rollback()
            raise