feat: add backtest services (portfolio, engine, worker)
- VirtualPortfolio for portfolio simulation - BacktestEngine for strategy backtesting - Worker for async background execution Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
a78c00ecbb
commit
c1ee879cb4
15
backend/app/services/backtest/__init__.py
Normal file
15
backend/app/services/backtest/__init__.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
from app.services.backtest.engine import BacktestEngine
|
||||||
|
from app.services.backtest.portfolio import VirtualPortfolio, Transaction, HoldingInfo
|
||||||
|
from app.services.backtest.metrics import MetricsCalculator, BacktestMetrics
|
||||||
|
from app.services.backtest.worker import submit_backtest, get_executor_status
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BacktestEngine",
|
||||||
|
"VirtualPortfolio",
|
||||||
|
"Transaction",
|
||||||
|
"HoldingInfo",
|
||||||
|
"MetricsCalculator",
|
||||||
|
"BacktestMetrics",
|
||||||
|
"submit_backtest",
|
||||||
|
"get_executor_status",
|
||||||
|
]
|
||||||
301
backend/app/services/backtest/engine.py
Normal file
301
backend/app/services/backtest/engine.py
Normal file
@ -0,0 +1,301 @@
|
|||||||
|
"""
|
||||||
|
Main backtest engine.
|
||||||
|
"""
|
||||||
|
from datetime import date, timedelta
|
||||||
|
from decimal import Decimal
|
||||||
|
from typing import List, Dict, Optional
|
||||||
|
from dateutil.relativedelta import relativedelta
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.models.backtest import (
|
||||||
|
Backtest, BacktestResult, BacktestEquityCurve,
|
||||||
|
BacktestHolding, BacktestTransaction, RebalancePeriod,
|
||||||
|
)
|
||||||
|
from app.models.stock import Stock, Price
|
||||||
|
from app.services.backtest.portfolio import VirtualPortfolio, Transaction
|
||||||
|
from app.services.backtest.metrics import MetricsCalculator
|
||||||
|
from app.services.strategy import MultiFactorStrategy, QualityStrategy, ValueMomentumStrategy
|
||||||
|
from app.schemas.strategy import UniverseFilter, FactorWeights
|
||||||
|
|
||||||
|
|
||||||
|
class BacktestEngine:
|
||||||
|
"""
|
||||||
|
Main backtest engine that simulates strategy performance over time.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db: Session):
|
||||||
|
self.db = db
|
||||||
|
|
||||||
|
def run(self, backtest_id: int) -> None:
|
||||||
|
"""Execute backtest and save results."""
|
||||||
|
backtest = self.db.query(Backtest).get(backtest_id)
|
||||||
|
if not backtest:
|
||||||
|
raise ValueError(f"Backtest {backtest_id} not found")
|
||||||
|
|
||||||
|
# Initialize portfolio
|
||||||
|
portfolio = VirtualPortfolio(backtest.initial_capital)
|
||||||
|
|
||||||
|
# Generate rebalance dates
|
||||||
|
rebalance_dates = self._generate_rebalance_dates(
|
||||||
|
backtest.start_date,
|
||||||
|
backtest.end_date,
|
||||||
|
backtest.rebalance_period,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get trading days
|
||||||
|
trading_days = self._get_trading_days(
|
||||||
|
backtest.start_date,
|
||||||
|
backtest.end_date,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not trading_days:
|
||||||
|
raise ValueError("No trading days found in the specified period")
|
||||||
|
|
||||||
|
# Load benchmark data
|
||||||
|
benchmark_prices = self._load_benchmark_prices(
|
||||||
|
backtest.benchmark,
|
||||||
|
backtest.start_date,
|
||||||
|
backtest.end_date,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create strategy instance
|
||||||
|
strategy = self._create_strategy(
|
||||||
|
backtest.strategy_type,
|
||||||
|
backtest.strategy_params,
|
||||||
|
backtest.top_n,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Simulation
|
||||||
|
equity_curve_data: List[Dict] = []
|
||||||
|
all_transactions: List[Transaction] = []
|
||||||
|
holdings_history: List[Dict] = []
|
||||||
|
|
||||||
|
initial_benchmark = benchmark_prices.get(trading_days[0], Decimal("1"))
|
||||||
|
if initial_benchmark == 0:
|
||||||
|
initial_benchmark = Decimal("1")
|
||||||
|
|
||||||
|
for trading_date in trading_days:
|
||||||
|
# Get prices for this date
|
||||||
|
prices = self._get_prices_for_date(trading_date)
|
||||||
|
names = self._get_stock_names()
|
||||||
|
|
||||||
|
# Rebalance if needed
|
||||||
|
if trading_date in rebalance_dates:
|
||||||
|
# Run strategy to get target stocks
|
||||||
|
target_stocks = strategy.run(
|
||||||
|
universe_filter=UniverseFilter(),
|
||||||
|
top_n=backtest.top_n,
|
||||||
|
base_date=trading_date,
|
||||||
|
)
|
||||||
|
|
||||||
|
target_tickers = [s.ticker for s in target_stocks.stocks]
|
||||||
|
|
||||||
|
# Execute rebalance
|
||||||
|
transactions = portfolio.rebalance(
|
||||||
|
target_tickers=target_tickers,
|
||||||
|
prices=prices,
|
||||||
|
names=names,
|
||||||
|
commission_rate=backtest.commission_rate,
|
||||||
|
slippage_rate=backtest.slippage_rate,
|
||||||
|
)
|
||||||
|
|
||||||
|
all_transactions.extend([
|
||||||
|
(trading_date, txn) for txn in transactions
|
||||||
|
])
|
||||||
|
|
||||||
|
# Record holdings
|
||||||
|
holdings = portfolio.get_holdings_with_weights(prices, names)
|
||||||
|
holdings_history.append({
|
||||||
|
'date': trading_date,
|
||||||
|
'holdings': holdings,
|
||||||
|
})
|
||||||
|
|
||||||
|
# Record daily value
|
||||||
|
portfolio_value = portfolio.get_value(prices)
|
||||||
|
benchmark_value = benchmark_prices.get(trading_date, initial_benchmark)
|
||||||
|
|
||||||
|
# Normalize benchmark to initial capital
|
||||||
|
normalized_benchmark = (
|
||||||
|
benchmark_value / initial_benchmark * backtest.initial_capital
|
||||||
|
)
|
||||||
|
|
||||||
|
equity_curve_data.append({
|
||||||
|
'date': trading_date,
|
||||||
|
'portfolio_value': portfolio_value,
|
||||||
|
'benchmark_value': normalized_benchmark,
|
||||||
|
})
|
||||||
|
|
||||||
|
# Calculate metrics
|
||||||
|
portfolio_values = [Decimal(str(e['portfolio_value'])) for e in equity_curve_data]
|
||||||
|
benchmark_values = [Decimal(str(e['benchmark_value'])) for e in equity_curve_data]
|
||||||
|
|
||||||
|
metrics = MetricsCalculator.calculate_all(portfolio_values, benchmark_values)
|
||||||
|
drawdowns = MetricsCalculator.calculate_drawdown_series(portfolio_values)
|
||||||
|
|
||||||
|
# Save results
|
||||||
|
self._save_results(
|
||||||
|
backtest_id=backtest_id,
|
||||||
|
metrics=metrics,
|
||||||
|
equity_curve_data=equity_curve_data,
|
||||||
|
drawdowns=drawdowns,
|
||||||
|
holdings_history=holdings_history,
|
||||||
|
transactions=all_transactions,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _generate_rebalance_dates(
|
||||||
|
self,
|
||||||
|
start_date: date,
|
||||||
|
end_date: date,
|
||||||
|
period: RebalancePeriod,
|
||||||
|
) -> List[date]:
|
||||||
|
"""Generate list of rebalance dates."""
|
||||||
|
dates = []
|
||||||
|
current = start_date
|
||||||
|
|
||||||
|
if period == RebalancePeriod.MONTHLY:
|
||||||
|
delta = relativedelta(months=1)
|
||||||
|
elif period == RebalancePeriod.QUARTERLY:
|
||||||
|
delta = relativedelta(months=3)
|
||||||
|
elif period == RebalancePeriod.SEMI_ANNUAL:
|
||||||
|
delta = relativedelta(months=6)
|
||||||
|
else: # ANNUAL
|
||||||
|
delta = relativedelta(years=1)
|
||||||
|
|
||||||
|
while current <= end_date:
|
||||||
|
dates.append(current)
|
||||||
|
current = current + delta
|
||||||
|
|
||||||
|
return dates
|
||||||
|
|
||||||
|
def _get_trading_days(self, start_date: date, end_date: date) -> List[date]:
|
||||||
|
"""Get list of trading days from price data."""
|
||||||
|
prices = (
|
||||||
|
self.db.query(Price.date)
|
||||||
|
.filter(Price.date >= start_date)
|
||||||
|
.filter(Price.date <= end_date)
|
||||||
|
.distinct()
|
||||||
|
.order_by(Price.date)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
return [p[0] for p in prices]
|
||||||
|
|
||||||
|
def _load_benchmark_prices(
|
||||||
|
self,
|
||||||
|
benchmark: str,
|
||||||
|
start_date: date,
|
||||||
|
end_date: date,
|
||||||
|
) -> Dict[date, Decimal]:
|
||||||
|
"""Load benchmark index prices."""
|
||||||
|
# For KOSPI, we'll use a representative ETF or index
|
||||||
|
# Using KODEX 200 (069500) as KOSPI proxy
|
||||||
|
benchmark_ticker = "069500" if benchmark == "KOSPI" else "069500"
|
||||||
|
|
||||||
|
prices = (
|
||||||
|
self.db.query(Price)
|
||||||
|
.filter(Price.ticker == benchmark_ticker)
|
||||||
|
.filter(Price.date >= start_date)
|
||||||
|
.filter(Price.date <= end_date)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
return {p.date: p.close for p in prices}
|
||||||
|
|
||||||
|
def _get_prices_for_date(self, trading_date: date) -> Dict[str, Decimal]:
|
||||||
|
"""Get all stock prices for a specific date."""
|
||||||
|
prices = (
|
||||||
|
self.db.query(Price)
|
||||||
|
.filter(Price.date == trading_date)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
return {p.ticker: p.close for p in prices}
|
||||||
|
|
||||||
|
def _get_stock_names(self) -> Dict[str, str]:
|
||||||
|
"""Get all stock names."""
|
||||||
|
stocks = self.db.query(Stock).all()
|
||||||
|
return {s.ticker: s.name for s in stocks}
|
||||||
|
|
||||||
|
def _create_strategy(
|
||||||
|
self,
|
||||||
|
strategy_type: str,
|
||||||
|
strategy_params: dict,
|
||||||
|
top_n: int,
|
||||||
|
):
|
||||||
|
"""Create strategy instance based on type."""
|
||||||
|
if strategy_type == "multi_factor":
|
||||||
|
strategy = MultiFactorStrategy(self.db)
|
||||||
|
strategy._weights = FactorWeights(**strategy_params.get("weights", {}))
|
||||||
|
elif strategy_type == "quality":
|
||||||
|
strategy = QualityStrategy(self.db)
|
||||||
|
strategy._min_fscore = strategy_params.get("min_fscore", 7)
|
||||||
|
elif strategy_type == "value_momentum":
|
||||||
|
strategy = ValueMomentumStrategy(self.db)
|
||||||
|
strategy._value_weight = Decimal(str(strategy_params.get("value_weight", 0.5)))
|
||||||
|
strategy._momentum_weight = Decimal(str(strategy_params.get("momentum_weight", 0.5)))
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown strategy type: {strategy_type}")
|
||||||
|
|
||||||
|
return strategy
|
||||||
|
|
||||||
|
def _save_results(
|
||||||
|
self,
|
||||||
|
backtest_id: int,
|
||||||
|
metrics,
|
||||||
|
equity_curve_data: List[Dict],
|
||||||
|
drawdowns: List[Decimal],
|
||||||
|
holdings_history: List[Dict],
|
||||||
|
transactions: List,
|
||||||
|
) -> None:
|
||||||
|
"""Save all backtest results to database."""
|
||||||
|
# Save metrics
|
||||||
|
result = BacktestResult(
|
||||||
|
backtest_id=backtest_id,
|
||||||
|
total_return=metrics.total_return,
|
||||||
|
cagr=metrics.cagr,
|
||||||
|
mdd=metrics.mdd,
|
||||||
|
sharpe_ratio=metrics.sharpe_ratio,
|
||||||
|
volatility=metrics.volatility,
|
||||||
|
benchmark_return=metrics.benchmark_return,
|
||||||
|
excess_return=metrics.excess_return,
|
||||||
|
)
|
||||||
|
self.db.add(result)
|
||||||
|
|
||||||
|
# Save equity curve
|
||||||
|
for i, point in enumerate(equity_curve_data):
|
||||||
|
curve_point = BacktestEquityCurve(
|
||||||
|
backtest_id=backtest_id,
|
||||||
|
date=point['date'],
|
||||||
|
portfolio_value=point['portfolio_value'],
|
||||||
|
benchmark_value=point['benchmark_value'],
|
||||||
|
drawdown=drawdowns[i] if i < len(drawdowns) else Decimal("0"),
|
||||||
|
)
|
||||||
|
self.db.add(curve_point)
|
||||||
|
|
||||||
|
# Save holdings history
|
||||||
|
for record in holdings_history:
|
||||||
|
for holding in record['holdings']:
|
||||||
|
h = BacktestHolding(
|
||||||
|
backtest_id=backtest_id,
|
||||||
|
rebalance_date=record['date'],
|
||||||
|
ticker=holding.ticker,
|
||||||
|
name=holding.name,
|
||||||
|
weight=holding.weight,
|
||||||
|
shares=holding.shares,
|
||||||
|
price=holding.price,
|
||||||
|
)
|
||||||
|
self.db.add(h)
|
||||||
|
|
||||||
|
# Save transactions
|
||||||
|
for trading_date, txn in transactions:
|
||||||
|
t = BacktestTransaction(
|
||||||
|
backtest_id=backtest_id,
|
||||||
|
date=trading_date,
|
||||||
|
ticker=txn.ticker,
|
||||||
|
action=txn.action,
|
||||||
|
shares=txn.shares,
|
||||||
|
price=txn.price,
|
||||||
|
commission=txn.commission,
|
||||||
|
)
|
||||||
|
self.db.add(t)
|
||||||
|
|
||||||
|
self.db.commit()
|
||||||
237
backend/app/services/backtest/portfolio.py
Normal file
237
backend/app/services/backtest/portfolio.py
Normal file
@ -0,0 +1,237 @@
|
|||||||
|
"""
|
||||||
|
Virtual portfolio simulation for backtesting.
|
||||||
|
"""
|
||||||
|
from decimal import Decimal
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Transaction:
|
||||||
|
"""A single buy/sell transaction."""
|
||||||
|
ticker: str
|
||||||
|
action: str # 'buy' or 'sell'
|
||||||
|
shares: int
|
||||||
|
price: Decimal
|
||||||
|
commission: Decimal
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HoldingInfo:
|
||||||
|
"""Information about a single holding."""
|
||||||
|
ticker: str
|
||||||
|
name: str
|
||||||
|
shares: int
|
||||||
|
price: Decimal
|
||||||
|
weight: Decimal
|
||||||
|
|
||||||
|
|
||||||
|
class VirtualPortfolio:
|
||||||
|
"""
|
||||||
|
Simulates a portfolio for backtesting.
|
||||||
|
Handles cash management, buying/selling with transaction costs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, initial_capital: Decimal):
|
||||||
|
self.cash: Decimal = initial_capital
|
||||||
|
self.holdings: Dict[str, int] = {} # ticker -> shares
|
||||||
|
self.initial_capital = initial_capital
|
||||||
|
|
||||||
|
def get_value(self, prices: Dict[str, Decimal]) -> Decimal:
|
||||||
|
"""Calculate total portfolio value."""
|
||||||
|
holdings_value = sum(
|
||||||
|
Decimal(str(shares)) * prices.get(ticker, Decimal("0"))
|
||||||
|
for ticker, shares in self.holdings.items()
|
||||||
|
)
|
||||||
|
return self.cash + holdings_value
|
||||||
|
|
||||||
|
def get_holdings_with_weights(
|
||||||
|
self,
|
||||||
|
prices: Dict[str, Decimal],
|
||||||
|
names: Dict[str, str],
|
||||||
|
) -> List[HoldingInfo]:
|
||||||
|
"""Get current holdings with weights."""
|
||||||
|
total_value = self.get_value(prices)
|
||||||
|
if total_value == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
result = []
|
||||||
|
for ticker, shares in self.holdings.items():
|
||||||
|
if shares <= 0:
|
||||||
|
continue
|
||||||
|
price = prices.get(ticker, Decimal("0"))
|
||||||
|
value = Decimal(str(shares)) * price
|
||||||
|
weight = value / total_value * 100
|
||||||
|
|
||||||
|
result.append(HoldingInfo(
|
||||||
|
ticker=ticker,
|
||||||
|
name=names.get(ticker, ticker),
|
||||||
|
shares=shares,
|
||||||
|
price=price,
|
||||||
|
weight=weight,
|
||||||
|
))
|
||||||
|
|
||||||
|
# Sort by weight descending
|
||||||
|
result.sort(key=lambda x: x.weight, reverse=True)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def rebalance(
|
||||||
|
self,
|
||||||
|
target_tickers: List[str],
|
||||||
|
prices: Dict[str, Decimal],
|
||||||
|
names: Dict[str, str],
|
||||||
|
commission_rate: Decimal,
|
||||||
|
slippage_rate: Decimal,
|
||||||
|
) -> List[Transaction]:
|
||||||
|
"""
|
||||||
|
Rebalance portfolio to equal-weight target tickers.
|
||||||
|
Returns list of transactions executed.
|
||||||
|
"""
|
||||||
|
transactions: List[Transaction] = []
|
||||||
|
target_set = set(target_tickers)
|
||||||
|
|
||||||
|
# Step 1: Sell holdings not in target
|
||||||
|
for ticker in list(self.holdings.keys()):
|
||||||
|
if ticker not in target_set and self.holdings[ticker] > 0:
|
||||||
|
txn = self._sell_all(ticker, prices, commission_rate, slippage_rate)
|
||||||
|
if txn:
|
||||||
|
transactions.append(txn)
|
||||||
|
|
||||||
|
# Step 2: Calculate target value per stock (equal weight)
|
||||||
|
total_value = self.get_value(prices)
|
||||||
|
if len(target_tickers) == 0:
|
||||||
|
return transactions
|
||||||
|
|
||||||
|
target_value_per_stock = total_value / Decimal(str(len(target_tickers)))
|
||||||
|
|
||||||
|
# Step 3: Sell excess from current holdings in target
|
||||||
|
for ticker in target_tickers:
|
||||||
|
if ticker in self.holdings:
|
||||||
|
current_shares = self.holdings[ticker]
|
||||||
|
price = prices.get(ticker, Decimal("0"))
|
||||||
|
if price <= 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
current_value = Decimal(str(current_shares)) * price
|
||||||
|
|
||||||
|
if current_value > target_value_per_stock * Decimal("1.05"):
|
||||||
|
# Sell excess
|
||||||
|
excess_value = current_value - target_value_per_stock
|
||||||
|
sell_price = price * (1 - slippage_rate)
|
||||||
|
shares_to_sell = int(excess_value / sell_price)
|
||||||
|
|
||||||
|
if shares_to_sell > 0:
|
||||||
|
txn = self._sell(ticker, shares_to_sell, sell_price, commission_rate)
|
||||||
|
if txn:
|
||||||
|
transactions.append(txn)
|
||||||
|
|
||||||
|
# Step 4: Buy to reach target weight
|
||||||
|
for ticker in target_tickers:
|
||||||
|
price = prices.get(ticker, Decimal("0"))
|
||||||
|
if price <= 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
current_shares = self.holdings.get(ticker, 0)
|
||||||
|
current_value = Decimal(str(current_shares)) * price
|
||||||
|
|
||||||
|
if current_value < target_value_per_stock * Decimal("0.95"):
|
||||||
|
# Buy more
|
||||||
|
buy_value = target_value_per_stock - current_value
|
||||||
|
buy_price = price * (1 + slippage_rate)
|
||||||
|
|
||||||
|
# Account for commission in available cash
|
||||||
|
max_buy_value = self.cash / (1 + commission_rate)
|
||||||
|
actual_buy_value = min(buy_value, max_buy_value)
|
||||||
|
|
||||||
|
shares_to_buy = int(actual_buy_value / buy_price)
|
||||||
|
|
||||||
|
if shares_to_buy > 0:
|
||||||
|
txn = self._buy(ticker, shares_to_buy, buy_price, commission_rate)
|
||||||
|
if txn:
|
||||||
|
transactions.append(txn)
|
||||||
|
|
||||||
|
return transactions
|
||||||
|
|
||||||
|
def _buy(
|
||||||
|
self,
|
||||||
|
ticker: str,
|
||||||
|
shares: int,
|
||||||
|
price: Decimal,
|
||||||
|
commission_rate: Decimal,
|
||||||
|
) -> Optional[Transaction]:
|
||||||
|
"""Execute a buy order."""
|
||||||
|
cost = Decimal(str(shares)) * price
|
||||||
|
commission = cost * commission_rate
|
||||||
|
total_cost = cost + commission
|
||||||
|
|
||||||
|
if total_cost > self.cash:
|
||||||
|
# Reduce shares to fit budget
|
||||||
|
available = self.cash / (1 + commission_rate)
|
||||||
|
shares = int(available / price)
|
||||||
|
if shares <= 0:
|
||||||
|
return None
|
||||||
|
cost = Decimal(str(shares)) * price
|
||||||
|
commission = cost * commission_rate
|
||||||
|
total_cost = cost + commission
|
||||||
|
|
||||||
|
self.cash -= total_cost
|
||||||
|
self.holdings[ticker] = self.holdings.get(ticker, 0) + shares
|
||||||
|
|
||||||
|
return Transaction(
|
||||||
|
ticker=ticker,
|
||||||
|
action='buy',
|
||||||
|
shares=shares,
|
||||||
|
price=price,
|
||||||
|
commission=commission,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _sell(
|
||||||
|
self,
|
||||||
|
ticker: str,
|
||||||
|
shares: int,
|
||||||
|
price: Decimal,
|
||||||
|
commission_rate: Decimal,
|
||||||
|
) -> Optional[Transaction]:
|
||||||
|
"""Execute a sell order."""
|
||||||
|
current_shares = self.holdings.get(ticker, 0)
|
||||||
|
shares = min(shares, current_shares)
|
||||||
|
|
||||||
|
if shares <= 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
proceeds = Decimal(str(shares)) * price
|
||||||
|
commission = proceeds * commission_rate
|
||||||
|
net_proceeds = proceeds - commission
|
||||||
|
|
||||||
|
self.cash += net_proceeds
|
||||||
|
self.holdings[ticker] -= shares
|
||||||
|
|
||||||
|
if self.holdings[ticker] <= 0:
|
||||||
|
del self.holdings[ticker]
|
||||||
|
|
||||||
|
return Transaction(
|
||||||
|
ticker=ticker,
|
||||||
|
action='sell',
|
||||||
|
shares=shares,
|
||||||
|
price=price,
|
||||||
|
commission=commission,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _sell_all(
|
||||||
|
self,
|
||||||
|
ticker: str,
|
||||||
|
prices: Dict[str, Decimal],
|
||||||
|
commission_rate: Decimal,
|
||||||
|
slippage_rate: Decimal,
|
||||||
|
) -> Optional[Transaction]:
|
||||||
|
"""Sell all shares of a ticker."""
|
||||||
|
shares = self.holdings.get(ticker, 0)
|
||||||
|
if shares <= 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
price = prices.get(ticker, Decimal("0"))
|
||||||
|
if price <= 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
sell_price = price * (1 - slippage_rate)
|
||||||
|
return self._sell(ticker, shares, sell_price, commission_rate)
|
||||||
80
backend/app/services/backtest/worker.py
Normal file
80
backend/app/services/backtest/worker.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
"""
|
||||||
|
Background worker for backtest execution.
|
||||||
|
"""
|
||||||
|
from datetime import datetime
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.database import SessionLocal
|
||||||
|
from app.models.backtest import Backtest, BacktestStatus
|
||||||
|
from app.services.backtest.engine import BacktestEngine
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Thread pool for background execution
|
||||||
|
executor = ThreadPoolExecutor(max_workers=2)
|
||||||
|
|
||||||
|
|
||||||
|
def submit_backtest(backtest_id: int) -> None:
|
||||||
|
"""
|
||||||
|
Submit a backtest job for background execution.
|
||||||
|
Returns immediately, backtest runs in background thread.
|
||||||
|
"""
|
||||||
|
executor.submit(_run_backtest_job, backtest_id)
|
||||||
|
logger.info(f"Backtest {backtest_id} submitted for background execution")
|
||||||
|
|
||||||
|
|
||||||
|
def _run_backtest_job(backtest_id: int) -> None:
|
||||||
|
"""
|
||||||
|
Execute backtest in background thread.
|
||||||
|
Creates its own database session.
|
||||||
|
"""
|
||||||
|
db: Session = SessionLocal()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Update status to running
|
||||||
|
backtest = db.query(Backtest).get(backtest_id)
|
||||||
|
if not backtest:
|
||||||
|
logger.error(f"Backtest {backtest_id} not found")
|
||||||
|
return
|
||||||
|
|
||||||
|
backtest.status = BacktestStatus.RUNNING
|
||||||
|
db.commit()
|
||||||
|
logger.info(f"Backtest {backtest_id} started")
|
||||||
|
|
||||||
|
# Run backtest
|
||||||
|
engine = BacktestEngine(db)
|
||||||
|
engine.run(backtest_id)
|
||||||
|
|
||||||
|
# Update status to completed
|
||||||
|
backtest.status = BacktestStatus.COMPLETED
|
||||||
|
backtest.completed_at = datetime.utcnow()
|
||||||
|
db.commit()
|
||||||
|
logger.info(f"Backtest {backtest_id} completed successfully")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Backtest {backtest_id} failed: {e}")
|
||||||
|
|
||||||
|
# Update status to failed
|
||||||
|
try:
|
||||||
|
backtest = db.query(Backtest).get(backtest_id)
|
||||||
|
if backtest:
|
||||||
|
backtest.status = BacktestStatus.FAILED
|
||||||
|
backtest.error_message = str(e)[:1000] # Limit error message length
|
||||||
|
backtest.completed_at = datetime.utcnow()
|
||||||
|
db.commit()
|
||||||
|
except Exception as commit_error:
|
||||||
|
logger.exception(f"Failed to update backtest status: {commit_error}")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def get_executor_status() -> dict:
|
||||||
|
"""Get current executor status for monitoring."""
|
||||||
|
return {
|
||||||
|
"max_workers": executor._max_workers,
|
||||||
|
"pending_tasks": executor._work_queue.qsize() if hasattr(executor, '_work_queue') else 0,
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user