diff --git a/backend/app/services/backtest/__init__.py b/backend/app/services/backtest/__init__.py new file mode 100644 index 0000000..3826f3d --- /dev/null +++ b/backend/app/services/backtest/__init__.py @@ -0,0 +1,15 @@ +from app.services.backtest.engine import BacktestEngine +from app.services.backtest.portfolio import VirtualPortfolio, Transaction, HoldingInfo +from app.services.backtest.metrics import MetricsCalculator, BacktestMetrics +from app.services.backtest.worker import submit_backtest, get_executor_status + +__all__ = [ + "BacktestEngine", + "VirtualPortfolio", + "Transaction", + "HoldingInfo", + "MetricsCalculator", + "BacktestMetrics", + "submit_backtest", + "get_executor_status", +] diff --git a/backend/app/services/backtest/engine.py b/backend/app/services/backtest/engine.py new file mode 100644 index 0000000..e9f822c --- /dev/null +++ b/backend/app/services/backtest/engine.py @@ -0,0 +1,301 @@ +""" +Main backtest engine. +""" +from datetime import date, timedelta +from decimal import Decimal +from typing import List, Dict, Optional +from dateutil.relativedelta import relativedelta + +from sqlalchemy.orm import Session + +from app.models.backtest import ( + Backtest, BacktestResult, BacktestEquityCurve, + BacktestHolding, BacktestTransaction, RebalancePeriod, +) +from app.models.stock import Stock, Price +from app.services.backtest.portfolio import VirtualPortfolio, Transaction +from app.services.backtest.metrics import MetricsCalculator +from app.services.strategy import MultiFactorStrategy, QualityStrategy, ValueMomentumStrategy +from app.schemas.strategy import UniverseFilter, FactorWeights + + +class BacktestEngine: + """ + Main backtest engine that simulates strategy performance over time. + """ + + def __init__(self, db: Session): + self.db = db + + def run(self, backtest_id: int) -> None: + """Execute backtest and save results.""" + backtest = self.db.query(Backtest).get(backtest_id) + if not backtest: + raise ValueError(f"Backtest {backtest_id} not found") + + # Initialize portfolio + portfolio = VirtualPortfolio(backtest.initial_capital) + + # Generate rebalance dates + rebalance_dates = self._generate_rebalance_dates( + backtest.start_date, + backtest.end_date, + backtest.rebalance_period, + ) + + # Get trading days + trading_days = self._get_trading_days( + backtest.start_date, + backtest.end_date, + ) + + if not trading_days: + raise ValueError("No trading days found in the specified period") + + # Load benchmark data + benchmark_prices = self._load_benchmark_prices( + backtest.benchmark, + backtest.start_date, + backtest.end_date, + ) + + # Create strategy instance + strategy = self._create_strategy( + backtest.strategy_type, + backtest.strategy_params, + backtest.top_n, + ) + + # Simulation + equity_curve_data: List[Dict] = [] + all_transactions: List[Transaction] = [] + holdings_history: List[Dict] = [] + + initial_benchmark = benchmark_prices.get(trading_days[0], Decimal("1")) + if initial_benchmark == 0: + initial_benchmark = Decimal("1") + + for trading_date in trading_days: + # Get prices for this date + prices = self._get_prices_for_date(trading_date) + names = self._get_stock_names() + + # Rebalance if needed + if trading_date in rebalance_dates: + # Run strategy to get target stocks + target_stocks = strategy.run( + universe_filter=UniverseFilter(), + top_n=backtest.top_n, + base_date=trading_date, + ) + + target_tickers = [s.ticker for s in target_stocks.stocks] + + # Execute rebalance + transactions = portfolio.rebalance( + target_tickers=target_tickers, + prices=prices, + names=names, + commission_rate=backtest.commission_rate, + slippage_rate=backtest.slippage_rate, + ) + + all_transactions.extend([ + (trading_date, txn) for txn in transactions + ]) + + # Record holdings + holdings = portfolio.get_holdings_with_weights(prices, names) + holdings_history.append({ + 'date': trading_date, + 'holdings': holdings, + }) + + # Record daily value + portfolio_value = portfolio.get_value(prices) + benchmark_value = benchmark_prices.get(trading_date, initial_benchmark) + + # Normalize benchmark to initial capital + normalized_benchmark = ( + benchmark_value / initial_benchmark * backtest.initial_capital + ) + + equity_curve_data.append({ + 'date': trading_date, + 'portfolio_value': portfolio_value, + 'benchmark_value': normalized_benchmark, + }) + + # Calculate metrics + portfolio_values = [Decimal(str(e['portfolio_value'])) for e in equity_curve_data] + benchmark_values = [Decimal(str(e['benchmark_value'])) for e in equity_curve_data] + + metrics = MetricsCalculator.calculate_all(portfolio_values, benchmark_values) + drawdowns = MetricsCalculator.calculate_drawdown_series(portfolio_values) + + # Save results + self._save_results( + backtest_id=backtest_id, + metrics=metrics, + equity_curve_data=equity_curve_data, + drawdowns=drawdowns, + holdings_history=holdings_history, + transactions=all_transactions, + ) + + def _generate_rebalance_dates( + self, + start_date: date, + end_date: date, + period: RebalancePeriod, + ) -> List[date]: + """Generate list of rebalance dates.""" + dates = [] + current = start_date + + if period == RebalancePeriod.MONTHLY: + delta = relativedelta(months=1) + elif period == RebalancePeriod.QUARTERLY: + delta = relativedelta(months=3) + elif period == RebalancePeriod.SEMI_ANNUAL: + delta = relativedelta(months=6) + else: # ANNUAL + delta = relativedelta(years=1) + + while current <= end_date: + dates.append(current) + current = current + delta + + return dates + + def _get_trading_days(self, start_date: date, end_date: date) -> List[date]: + """Get list of trading days from price data.""" + prices = ( + self.db.query(Price.date) + .filter(Price.date >= start_date) + .filter(Price.date <= end_date) + .distinct() + .order_by(Price.date) + .all() + ) + return [p[0] for p in prices] + + def _load_benchmark_prices( + self, + benchmark: str, + start_date: date, + end_date: date, + ) -> Dict[date, Decimal]: + """Load benchmark index prices.""" + # For KOSPI, we'll use a representative ETF or index + # Using KODEX 200 (069500) as KOSPI proxy + benchmark_ticker = "069500" if benchmark == "KOSPI" else "069500" + + prices = ( + self.db.query(Price) + .filter(Price.ticker == benchmark_ticker) + .filter(Price.date >= start_date) + .filter(Price.date <= end_date) + .all() + ) + + return {p.date: p.close for p in prices} + + def _get_prices_for_date(self, trading_date: date) -> Dict[str, Decimal]: + """Get all stock prices for a specific date.""" + prices = ( + self.db.query(Price) + .filter(Price.date == trading_date) + .all() + ) + return {p.ticker: p.close for p in prices} + + def _get_stock_names(self) -> Dict[str, str]: + """Get all stock names.""" + stocks = self.db.query(Stock).all() + return {s.ticker: s.name for s in stocks} + + def _create_strategy( + self, + strategy_type: str, + strategy_params: dict, + top_n: int, + ): + """Create strategy instance based on type.""" + if strategy_type == "multi_factor": + strategy = MultiFactorStrategy(self.db) + strategy._weights = FactorWeights(**strategy_params.get("weights", {})) + elif strategy_type == "quality": + strategy = QualityStrategy(self.db) + strategy._min_fscore = strategy_params.get("min_fscore", 7) + elif strategy_type == "value_momentum": + strategy = ValueMomentumStrategy(self.db) + strategy._value_weight = Decimal(str(strategy_params.get("value_weight", 0.5))) + strategy._momentum_weight = Decimal(str(strategy_params.get("momentum_weight", 0.5))) + else: + raise ValueError(f"Unknown strategy type: {strategy_type}") + + return strategy + + def _save_results( + self, + backtest_id: int, + metrics, + equity_curve_data: List[Dict], + drawdowns: List[Decimal], + holdings_history: List[Dict], + transactions: List, + ) -> None: + """Save all backtest results to database.""" + # Save metrics + result = BacktestResult( + backtest_id=backtest_id, + total_return=metrics.total_return, + cagr=metrics.cagr, + mdd=metrics.mdd, + sharpe_ratio=metrics.sharpe_ratio, + volatility=metrics.volatility, + benchmark_return=metrics.benchmark_return, + excess_return=metrics.excess_return, + ) + self.db.add(result) + + # Save equity curve + for i, point in enumerate(equity_curve_data): + curve_point = BacktestEquityCurve( + backtest_id=backtest_id, + date=point['date'], + portfolio_value=point['portfolio_value'], + benchmark_value=point['benchmark_value'], + drawdown=drawdowns[i] if i < len(drawdowns) else Decimal("0"), + ) + self.db.add(curve_point) + + # Save holdings history + for record in holdings_history: + for holding in record['holdings']: + h = BacktestHolding( + backtest_id=backtest_id, + rebalance_date=record['date'], + ticker=holding.ticker, + name=holding.name, + weight=holding.weight, + shares=holding.shares, + price=holding.price, + ) + self.db.add(h) + + # Save transactions + for trading_date, txn in transactions: + t = BacktestTransaction( + backtest_id=backtest_id, + date=trading_date, + ticker=txn.ticker, + action=txn.action, + shares=txn.shares, + price=txn.price, + commission=txn.commission, + ) + self.db.add(t) + + self.db.commit() diff --git a/backend/app/services/backtest/portfolio.py b/backend/app/services/backtest/portfolio.py new file mode 100644 index 0000000..3888d1e --- /dev/null +++ b/backend/app/services/backtest/portfolio.py @@ -0,0 +1,237 @@ +""" +Virtual portfolio simulation for backtesting. +""" +from decimal import Decimal +from typing import Dict, List, Optional +from dataclasses import dataclass, field + + +@dataclass +class Transaction: + """A single buy/sell transaction.""" + ticker: str + action: str # 'buy' or 'sell' + shares: int + price: Decimal + commission: Decimal + + +@dataclass +class HoldingInfo: + """Information about a single holding.""" + ticker: str + name: str + shares: int + price: Decimal + weight: Decimal + + +class VirtualPortfolio: + """ + Simulates a portfolio for backtesting. + Handles cash management, buying/selling with transaction costs. + """ + + def __init__(self, initial_capital: Decimal): + self.cash: Decimal = initial_capital + self.holdings: Dict[str, int] = {} # ticker -> shares + self.initial_capital = initial_capital + + def get_value(self, prices: Dict[str, Decimal]) -> Decimal: + """Calculate total portfolio value.""" + holdings_value = sum( + Decimal(str(shares)) * prices.get(ticker, Decimal("0")) + for ticker, shares in self.holdings.items() + ) + return self.cash + holdings_value + + def get_holdings_with_weights( + self, + prices: Dict[str, Decimal], + names: Dict[str, str], + ) -> List[HoldingInfo]: + """Get current holdings with weights.""" + total_value = self.get_value(prices) + if total_value == 0: + return [] + + result = [] + for ticker, shares in self.holdings.items(): + if shares <= 0: + continue + price = prices.get(ticker, Decimal("0")) + value = Decimal(str(shares)) * price + weight = value / total_value * 100 + + result.append(HoldingInfo( + ticker=ticker, + name=names.get(ticker, ticker), + shares=shares, + price=price, + weight=weight, + )) + + # Sort by weight descending + result.sort(key=lambda x: x.weight, reverse=True) + return result + + def rebalance( + self, + target_tickers: List[str], + prices: Dict[str, Decimal], + names: Dict[str, str], + commission_rate: Decimal, + slippage_rate: Decimal, + ) -> List[Transaction]: + """ + Rebalance portfolio to equal-weight target tickers. + Returns list of transactions executed. + """ + transactions: List[Transaction] = [] + target_set = set(target_tickers) + + # Step 1: Sell holdings not in target + for ticker in list(self.holdings.keys()): + if ticker not in target_set and self.holdings[ticker] > 0: + txn = self._sell_all(ticker, prices, commission_rate, slippage_rate) + if txn: + transactions.append(txn) + + # Step 2: Calculate target value per stock (equal weight) + total_value = self.get_value(prices) + if len(target_tickers) == 0: + return transactions + + target_value_per_stock = total_value / Decimal(str(len(target_tickers))) + + # Step 3: Sell excess from current holdings in target + for ticker in target_tickers: + if ticker in self.holdings: + current_shares = self.holdings[ticker] + price = prices.get(ticker, Decimal("0")) + if price <= 0: + continue + + current_value = Decimal(str(current_shares)) * price + + if current_value > target_value_per_stock * Decimal("1.05"): + # Sell excess + excess_value = current_value - target_value_per_stock + sell_price = price * (1 - slippage_rate) + shares_to_sell = int(excess_value / sell_price) + + if shares_to_sell > 0: + txn = self._sell(ticker, shares_to_sell, sell_price, commission_rate) + if txn: + transactions.append(txn) + + # Step 4: Buy to reach target weight + for ticker in target_tickers: + price = prices.get(ticker, Decimal("0")) + if price <= 0: + continue + + current_shares = self.holdings.get(ticker, 0) + current_value = Decimal(str(current_shares)) * price + + if current_value < target_value_per_stock * Decimal("0.95"): + # Buy more + buy_value = target_value_per_stock - current_value + buy_price = price * (1 + slippage_rate) + + # Account for commission in available cash + max_buy_value = self.cash / (1 + commission_rate) + actual_buy_value = min(buy_value, max_buy_value) + + shares_to_buy = int(actual_buy_value / buy_price) + + if shares_to_buy > 0: + txn = self._buy(ticker, shares_to_buy, buy_price, commission_rate) + if txn: + transactions.append(txn) + + return transactions + + def _buy( + self, + ticker: str, + shares: int, + price: Decimal, + commission_rate: Decimal, + ) -> Optional[Transaction]: + """Execute a buy order.""" + cost = Decimal(str(shares)) * price + commission = cost * commission_rate + total_cost = cost + commission + + if total_cost > self.cash: + # Reduce shares to fit budget + available = self.cash / (1 + commission_rate) + shares = int(available / price) + if shares <= 0: + return None + cost = Decimal(str(shares)) * price + commission = cost * commission_rate + total_cost = cost + commission + + self.cash -= total_cost + self.holdings[ticker] = self.holdings.get(ticker, 0) + shares + + return Transaction( + ticker=ticker, + action='buy', + shares=shares, + price=price, + commission=commission, + ) + + def _sell( + self, + ticker: str, + shares: int, + price: Decimal, + commission_rate: Decimal, + ) -> Optional[Transaction]: + """Execute a sell order.""" + current_shares = self.holdings.get(ticker, 0) + shares = min(shares, current_shares) + + if shares <= 0: + return None + + proceeds = Decimal(str(shares)) * price + commission = proceeds * commission_rate + net_proceeds = proceeds - commission + + self.cash += net_proceeds + self.holdings[ticker] -= shares + + if self.holdings[ticker] <= 0: + del self.holdings[ticker] + + return Transaction( + ticker=ticker, + action='sell', + shares=shares, + price=price, + commission=commission, + ) + + def _sell_all( + self, + ticker: str, + prices: Dict[str, Decimal], + commission_rate: Decimal, + slippage_rate: Decimal, + ) -> Optional[Transaction]: + """Sell all shares of a ticker.""" + shares = self.holdings.get(ticker, 0) + if shares <= 0: + return None + + price = prices.get(ticker, Decimal("0")) + if price <= 0: + return None + + sell_price = price * (1 - slippage_rate) + return self._sell(ticker, shares, sell_price, commission_rate) diff --git a/backend/app/services/backtest/worker.py b/backend/app/services/backtest/worker.py new file mode 100644 index 0000000..d8a228e --- /dev/null +++ b/backend/app/services/backtest/worker.py @@ -0,0 +1,80 @@ +""" +Background worker for backtest execution. +""" +from datetime import datetime +from concurrent.futures import ThreadPoolExecutor +import logging + +from sqlalchemy.orm import Session + +from app.core.database import SessionLocal +from app.models.backtest import Backtest, BacktestStatus +from app.services.backtest.engine import BacktestEngine + +logger = logging.getLogger(__name__) + +# Thread pool for background execution +executor = ThreadPoolExecutor(max_workers=2) + + +def submit_backtest(backtest_id: int) -> None: + """ + Submit a backtest job for background execution. + Returns immediately, backtest runs in background thread. + """ + executor.submit(_run_backtest_job, backtest_id) + logger.info(f"Backtest {backtest_id} submitted for background execution") + + +def _run_backtest_job(backtest_id: int) -> None: + """ + Execute backtest in background thread. + Creates its own database session. + """ + db: Session = SessionLocal() + + try: + # Update status to running + backtest = db.query(Backtest).get(backtest_id) + if not backtest: + logger.error(f"Backtest {backtest_id} not found") + return + + backtest.status = BacktestStatus.RUNNING + db.commit() + logger.info(f"Backtest {backtest_id} started") + + # Run backtest + engine = BacktestEngine(db) + engine.run(backtest_id) + + # Update status to completed + backtest.status = BacktestStatus.COMPLETED + backtest.completed_at = datetime.utcnow() + db.commit() + logger.info(f"Backtest {backtest_id} completed successfully") + + except Exception as e: + logger.exception(f"Backtest {backtest_id} failed: {e}") + + # Update status to failed + try: + backtest = db.query(Backtest).get(backtest_id) + if backtest: + backtest.status = BacktestStatus.FAILED + backtest.error_message = str(e)[:1000] # Limit error message length + backtest.completed_at = datetime.utcnow() + db.commit() + except Exception as commit_error: + logger.exception(f"Failed to update backtest status: {commit_error}") + + finally: + db.close() + + +def get_executor_status() -> dict: + """Get current executor status for monitoring.""" + return { + "max_workers": executor._max_workers, + "pending_tasks": executor._work_queue.qsize() if hasattr(executor, '_work_queue') else 0, + }