feat: add backtest services (portfolio, engine, worker)

- VirtualPortfolio for portfolio simulation
- BacktestEngine for strategy backtesting
- Worker for async background execution

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
zephyrdark 2026-02-03 11:34:48 +09:00
parent a78c00ecbb
commit c1ee879cb4
4 changed files with 633 additions and 0 deletions

View File

@ -0,0 +1,15 @@
from app.services.backtest.engine import BacktestEngine
from app.services.backtest.portfolio import VirtualPortfolio, Transaction, HoldingInfo
from app.services.backtest.metrics import MetricsCalculator, BacktestMetrics
from app.services.backtest.worker import submit_backtest, get_executor_status
__all__ = [
"BacktestEngine",
"VirtualPortfolio",
"Transaction",
"HoldingInfo",
"MetricsCalculator",
"BacktestMetrics",
"submit_backtest",
"get_executor_status",
]

View File

@ -0,0 +1,301 @@
"""
Main backtest engine.
"""
from datetime import date, timedelta
from decimal import Decimal
from typing import List, Dict, Optional
from dateutil.relativedelta import relativedelta
from sqlalchemy.orm import Session
from app.models.backtest import (
Backtest, BacktestResult, BacktestEquityCurve,
BacktestHolding, BacktestTransaction, RebalancePeriod,
)
from app.models.stock import Stock, Price
from app.services.backtest.portfolio import VirtualPortfolio, Transaction
from app.services.backtest.metrics import MetricsCalculator
from app.services.strategy import MultiFactorStrategy, QualityStrategy, ValueMomentumStrategy
from app.schemas.strategy import UniverseFilter, FactorWeights
class BacktestEngine:
"""
Main backtest engine that simulates strategy performance over time.
"""
def __init__(self, db: Session):
self.db = db
def run(self, backtest_id: int) -> None:
"""Execute backtest and save results."""
backtest = self.db.query(Backtest).get(backtest_id)
if not backtest:
raise ValueError(f"Backtest {backtest_id} not found")
# Initialize portfolio
portfolio = VirtualPortfolio(backtest.initial_capital)
# Generate rebalance dates
rebalance_dates = self._generate_rebalance_dates(
backtest.start_date,
backtest.end_date,
backtest.rebalance_period,
)
# Get trading days
trading_days = self._get_trading_days(
backtest.start_date,
backtest.end_date,
)
if not trading_days:
raise ValueError("No trading days found in the specified period")
# Load benchmark data
benchmark_prices = self._load_benchmark_prices(
backtest.benchmark,
backtest.start_date,
backtest.end_date,
)
# Create strategy instance
strategy = self._create_strategy(
backtest.strategy_type,
backtest.strategy_params,
backtest.top_n,
)
# Simulation
equity_curve_data: List[Dict] = []
all_transactions: List[Transaction] = []
holdings_history: List[Dict] = []
initial_benchmark = benchmark_prices.get(trading_days[0], Decimal("1"))
if initial_benchmark == 0:
initial_benchmark = Decimal("1")
for trading_date in trading_days:
# Get prices for this date
prices = self._get_prices_for_date(trading_date)
names = self._get_stock_names()
# Rebalance if needed
if trading_date in rebalance_dates:
# Run strategy to get target stocks
target_stocks = strategy.run(
universe_filter=UniverseFilter(),
top_n=backtest.top_n,
base_date=trading_date,
)
target_tickers = [s.ticker for s in target_stocks.stocks]
# Execute rebalance
transactions = portfolio.rebalance(
target_tickers=target_tickers,
prices=prices,
names=names,
commission_rate=backtest.commission_rate,
slippage_rate=backtest.slippage_rate,
)
all_transactions.extend([
(trading_date, txn) for txn in transactions
])
# Record holdings
holdings = portfolio.get_holdings_with_weights(prices, names)
holdings_history.append({
'date': trading_date,
'holdings': holdings,
})
# Record daily value
portfolio_value = portfolio.get_value(prices)
benchmark_value = benchmark_prices.get(trading_date, initial_benchmark)
# Normalize benchmark to initial capital
normalized_benchmark = (
benchmark_value / initial_benchmark * backtest.initial_capital
)
equity_curve_data.append({
'date': trading_date,
'portfolio_value': portfolio_value,
'benchmark_value': normalized_benchmark,
})
# Calculate metrics
portfolio_values = [Decimal(str(e['portfolio_value'])) for e in equity_curve_data]
benchmark_values = [Decimal(str(e['benchmark_value'])) for e in equity_curve_data]
metrics = MetricsCalculator.calculate_all(portfolio_values, benchmark_values)
drawdowns = MetricsCalculator.calculate_drawdown_series(portfolio_values)
# Save results
self._save_results(
backtest_id=backtest_id,
metrics=metrics,
equity_curve_data=equity_curve_data,
drawdowns=drawdowns,
holdings_history=holdings_history,
transactions=all_transactions,
)
def _generate_rebalance_dates(
self,
start_date: date,
end_date: date,
period: RebalancePeriod,
) -> List[date]:
"""Generate list of rebalance dates."""
dates = []
current = start_date
if period == RebalancePeriod.MONTHLY:
delta = relativedelta(months=1)
elif period == RebalancePeriod.QUARTERLY:
delta = relativedelta(months=3)
elif period == RebalancePeriod.SEMI_ANNUAL:
delta = relativedelta(months=6)
else: # ANNUAL
delta = relativedelta(years=1)
while current <= end_date:
dates.append(current)
current = current + delta
return dates
def _get_trading_days(self, start_date: date, end_date: date) -> List[date]:
"""Get list of trading days from price data."""
prices = (
self.db.query(Price.date)
.filter(Price.date >= start_date)
.filter(Price.date <= end_date)
.distinct()
.order_by(Price.date)
.all()
)
return [p[0] for p in prices]
def _load_benchmark_prices(
self,
benchmark: str,
start_date: date,
end_date: date,
) -> Dict[date, Decimal]:
"""Load benchmark index prices."""
# For KOSPI, we'll use a representative ETF or index
# Using KODEX 200 (069500) as KOSPI proxy
benchmark_ticker = "069500" if benchmark == "KOSPI" else "069500"
prices = (
self.db.query(Price)
.filter(Price.ticker == benchmark_ticker)
.filter(Price.date >= start_date)
.filter(Price.date <= end_date)
.all()
)
return {p.date: p.close for p in prices}
def _get_prices_for_date(self, trading_date: date) -> Dict[str, Decimal]:
"""Get all stock prices for a specific date."""
prices = (
self.db.query(Price)
.filter(Price.date == trading_date)
.all()
)
return {p.ticker: p.close for p in prices}
def _get_stock_names(self) -> Dict[str, str]:
"""Get all stock names."""
stocks = self.db.query(Stock).all()
return {s.ticker: s.name for s in stocks}
def _create_strategy(
self,
strategy_type: str,
strategy_params: dict,
top_n: int,
):
"""Create strategy instance based on type."""
if strategy_type == "multi_factor":
strategy = MultiFactorStrategy(self.db)
strategy._weights = FactorWeights(**strategy_params.get("weights", {}))
elif strategy_type == "quality":
strategy = QualityStrategy(self.db)
strategy._min_fscore = strategy_params.get("min_fscore", 7)
elif strategy_type == "value_momentum":
strategy = ValueMomentumStrategy(self.db)
strategy._value_weight = Decimal(str(strategy_params.get("value_weight", 0.5)))
strategy._momentum_weight = Decimal(str(strategy_params.get("momentum_weight", 0.5)))
else:
raise ValueError(f"Unknown strategy type: {strategy_type}")
return strategy
def _save_results(
self,
backtest_id: int,
metrics,
equity_curve_data: List[Dict],
drawdowns: List[Decimal],
holdings_history: List[Dict],
transactions: List,
) -> None:
"""Save all backtest results to database."""
# Save metrics
result = BacktestResult(
backtest_id=backtest_id,
total_return=metrics.total_return,
cagr=metrics.cagr,
mdd=metrics.mdd,
sharpe_ratio=metrics.sharpe_ratio,
volatility=metrics.volatility,
benchmark_return=metrics.benchmark_return,
excess_return=metrics.excess_return,
)
self.db.add(result)
# Save equity curve
for i, point in enumerate(equity_curve_data):
curve_point = BacktestEquityCurve(
backtest_id=backtest_id,
date=point['date'],
portfolio_value=point['portfolio_value'],
benchmark_value=point['benchmark_value'],
drawdown=drawdowns[i] if i < len(drawdowns) else Decimal("0"),
)
self.db.add(curve_point)
# Save holdings history
for record in holdings_history:
for holding in record['holdings']:
h = BacktestHolding(
backtest_id=backtest_id,
rebalance_date=record['date'],
ticker=holding.ticker,
name=holding.name,
weight=holding.weight,
shares=holding.shares,
price=holding.price,
)
self.db.add(h)
# Save transactions
for trading_date, txn in transactions:
t = BacktestTransaction(
backtest_id=backtest_id,
date=trading_date,
ticker=txn.ticker,
action=txn.action,
shares=txn.shares,
price=txn.price,
commission=txn.commission,
)
self.db.add(t)
self.db.commit()

View File

@ -0,0 +1,237 @@
"""
Virtual portfolio simulation for backtesting.
"""
from decimal import Decimal
from typing import Dict, List, Optional
from dataclasses import dataclass, field
@dataclass
class Transaction:
"""A single buy/sell transaction."""
ticker: str
action: str # 'buy' or 'sell'
shares: int
price: Decimal
commission: Decimal
@dataclass
class HoldingInfo:
"""Information about a single holding."""
ticker: str
name: str
shares: int
price: Decimal
weight: Decimal
class VirtualPortfolio:
"""
Simulates a portfolio for backtesting.
Handles cash management, buying/selling with transaction costs.
"""
def __init__(self, initial_capital: Decimal):
self.cash: Decimal = initial_capital
self.holdings: Dict[str, int] = {} # ticker -> shares
self.initial_capital = initial_capital
def get_value(self, prices: Dict[str, Decimal]) -> Decimal:
"""Calculate total portfolio value."""
holdings_value = sum(
Decimal(str(shares)) * prices.get(ticker, Decimal("0"))
for ticker, shares in self.holdings.items()
)
return self.cash + holdings_value
def get_holdings_with_weights(
self,
prices: Dict[str, Decimal],
names: Dict[str, str],
) -> List[HoldingInfo]:
"""Get current holdings with weights."""
total_value = self.get_value(prices)
if total_value == 0:
return []
result = []
for ticker, shares in self.holdings.items():
if shares <= 0:
continue
price = prices.get(ticker, Decimal("0"))
value = Decimal(str(shares)) * price
weight = value / total_value * 100
result.append(HoldingInfo(
ticker=ticker,
name=names.get(ticker, ticker),
shares=shares,
price=price,
weight=weight,
))
# Sort by weight descending
result.sort(key=lambda x: x.weight, reverse=True)
return result
def rebalance(
self,
target_tickers: List[str],
prices: Dict[str, Decimal],
names: Dict[str, str],
commission_rate: Decimal,
slippage_rate: Decimal,
) -> List[Transaction]:
"""
Rebalance portfolio to equal-weight target tickers.
Returns list of transactions executed.
"""
transactions: List[Transaction] = []
target_set = set(target_tickers)
# Step 1: Sell holdings not in target
for ticker in list(self.holdings.keys()):
if ticker not in target_set and self.holdings[ticker] > 0:
txn = self._sell_all(ticker, prices, commission_rate, slippage_rate)
if txn:
transactions.append(txn)
# Step 2: Calculate target value per stock (equal weight)
total_value = self.get_value(prices)
if len(target_tickers) == 0:
return transactions
target_value_per_stock = total_value / Decimal(str(len(target_tickers)))
# Step 3: Sell excess from current holdings in target
for ticker in target_tickers:
if ticker in self.holdings:
current_shares = self.holdings[ticker]
price = prices.get(ticker, Decimal("0"))
if price <= 0:
continue
current_value = Decimal(str(current_shares)) * price
if current_value > target_value_per_stock * Decimal("1.05"):
# Sell excess
excess_value = current_value - target_value_per_stock
sell_price = price * (1 - slippage_rate)
shares_to_sell = int(excess_value / sell_price)
if shares_to_sell > 0:
txn = self._sell(ticker, shares_to_sell, sell_price, commission_rate)
if txn:
transactions.append(txn)
# Step 4: Buy to reach target weight
for ticker in target_tickers:
price = prices.get(ticker, Decimal("0"))
if price <= 0:
continue
current_shares = self.holdings.get(ticker, 0)
current_value = Decimal(str(current_shares)) * price
if current_value < target_value_per_stock * Decimal("0.95"):
# Buy more
buy_value = target_value_per_stock - current_value
buy_price = price * (1 + slippage_rate)
# Account for commission in available cash
max_buy_value = self.cash / (1 + commission_rate)
actual_buy_value = min(buy_value, max_buy_value)
shares_to_buy = int(actual_buy_value / buy_price)
if shares_to_buy > 0:
txn = self._buy(ticker, shares_to_buy, buy_price, commission_rate)
if txn:
transactions.append(txn)
return transactions
def _buy(
self,
ticker: str,
shares: int,
price: Decimal,
commission_rate: Decimal,
) -> Optional[Transaction]:
"""Execute a buy order."""
cost = Decimal(str(shares)) * price
commission = cost * commission_rate
total_cost = cost + commission
if total_cost > self.cash:
# Reduce shares to fit budget
available = self.cash / (1 + commission_rate)
shares = int(available / price)
if shares <= 0:
return None
cost = Decimal(str(shares)) * price
commission = cost * commission_rate
total_cost = cost + commission
self.cash -= total_cost
self.holdings[ticker] = self.holdings.get(ticker, 0) + shares
return Transaction(
ticker=ticker,
action='buy',
shares=shares,
price=price,
commission=commission,
)
def _sell(
self,
ticker: str,
shares: int,
price: Decimal,
commission_rate: Decimal,
) -> Optional[Transaction]:
"""Execute a sell order."""
current_shares = self.holdings.get(ticker, 0)
shares = min(shares, current_shares)
if shares <= 0:
return None
proceeds = Decimal(str(shares)) * price
commission = proceeds * commission_rate
net_proceeds = proceeds - commission
self.cash += net_proceeds
self.holdings[ticker] -= shares
if self.holdings[ticker] <= 0:
del self.holdings[ticker]
return Transaction(
ticker=ticker,
action='sell',
shares=shares,
price=price,
commission=commission,
)
def _sell_all(
self,
ticker: str,
prices: Dict[str, Decimal],
commission_rate: Decimal,
slippage_rate: Decimal,
) -> Optional[Transaction]:
"""Sell all shares of a ticker."""
shares = self.holdings.get(ticker, 0)
if shares <= 0:
return None
price = prices.get(ticker, Decimal("0"))
if price <= 0:
return None
sell_price = price * (1 - slippage_rate)
return self._sell(ticker, shares, sell_price, commission_rate)

View File

@ -0,0 +1,80 @@
"""
Background worker for backtest execution.
"""
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor
import logging
from sqlalchemy.orm import Session
from app.core.database import SessionLocal
from app.models.backtest import Backtest, BacktestStatus
from app.services.backtest.engine import BacktestEngine
logger = logging.getLogger(__name__)
# Thread pool for background execution
executor = ThreadPoolExecutor(max_workers=2)
def submit_backtest(backtest_id: int) -> None:
"""
Submit a backtest job for background execution.
Returns immediately, backtest runs in background thread.
"""
executor.submit(_run_backtest_job, backtest_id)
logger.info(f"Backtest {backtest_id} submitted for background execution")
def _run_backtest_job(backtest_id: int) -> None:
"""
Execute backtest in background thread.
Creates its own database session.
"""
db: Session = SessionLocal()
try:
# Update status to running
backtest = db.query(Backtest).get(backtest_id)
if not backtest:
logger.error(f"Backtest {backtest_id} not found")
return
backtest.status = BacktestStatus.RUNNING
db.commit()
logger.info(f"Backtest {backtest_id} started")
# Run backtest
engine = BacktestEngine(db)
engine.run(backtest_id)
# Update status to completed
backtest.status = BacktestStatus.COMPLETED
backtest.completed_at = datetime.utcnow()
db.commit()
logger.info(f"Backtest {backtest_id} completed successfully")
except Exception as e:
logger.exception(f"Backtest {backtest_id} failed: {e}")
# Update status to failed
try:
backtest = db.query(Backtest).get(backtest_id)
if backtest:
backtest.status = BacktestStatus.FAILED
backtest.error_message = str(e)[:1000] # Limit error message length
backtest.completed_at = datetime.utcnow()
db.commit()
except Exception as commit_error:
logger.exception(f"Failed to update backtest status: {commit_error}")
finally:
db.close()
def get_executor_status() -> dict:
"""Get current executor status for monitoring."""
return {
"max_workers": executor._max_workers,
"pending_tasks": executor._work_queue.qsize() if hasattr(executor, '_work_queue') else 0,
}