"""
Daily simulation backtest engine for signal-based strategies (KJB).
"""
import logging
from datetime import date
from decimal import Decimal
from typing import Dict, List

import pandas as pd
from sqlalchemy.orm import Session

from app.models.backtest import (
    Backtest,
    BacktestResult,
    BacktestEquityCurve,
    BacktestTransaction,
)
from app.models.stock import Stock, Price
from app.services.backtest.trading_portfolio import TradingPortfolio, TradingTransaction
from app.services.backtest.metrics import MetricsCalculator
from app.services.strategy.kjb import KJBSignalGenerator

logger = logging.getLogger(__name__)


class DailyBacktestEngine:
    """
    Backtest engine for KJB signal-based strategy.
    Runs daily simulation with individual position management.
    """

    # Ticker used both as the KOSPI market series fed to the signal generator
    # and as the benchmark series.
    # NOTE(review): presumably a KOSPI-tracking ETF -- confirm with data owners.
    KOSPI_PROXY_TICKER = "069500"
    # Number of largest-market-cap stocks that form the tradable universe.
    UNIVERSE_SIZE = 30
    # Minimum number of stock price rows (up to and including the current day)
    # required before signals are evaluated for a ticker.
    MIN_STOCK_HISTORY = 21
    # Minimum number of KOSPI rows required before any signal is evaluated.
    MIN_KOSPI_HISTORY = 11

    def __init__(self, db: Session):
        """Store the database session and create the KJB signal generator."""
        self.db = db
        self.signal_gen = KJBSignalGenerator()

    def run(self, backtest_id: int) -> None:
        """
        Run the daily simulation for one backtest and persist its results.

        For every trading day in the backtest window the engine first lets the
        portfolio process exits, then scans the universe for buy signals, and
        finally records the portfolio value next to the normalized benchmark.

        Args:
            backtest_id: Primary key of the ``Backtest`` row to execute.

        Raises:
            ValueError: If the backtest does not exist or the date range
                contains no trading days.
        """
        # NOTE(review): Query.get() is legacy API as of SQLAlchemy 1.4;
        # Session.get(Backtest, backtest_id) is the replacement if the project
        # runs on 1.4+ -- confirm the pinned version before switching.
        backtest = self.db.query(Backtest).get(backtest_id)
        if not backtest:
            raise ValueError(f"Backtest {backtest_id} not found")

        portfolio = self._create_portfolio(backtest)

        trading_days = self._get_trading_days(backtest.start_date, backtest.end_date)
        if not trading_days:
            raise ValueError("No trading days found")

        universe_tickers = self._get_universe_tickers()

        # Load all data upfront for performance
        all_prices = self._load_all_prices(
            universe_tickers, backtest.start_date, backtest.end_date
        )
        stock_dfs = self._build_stock_dfs(all_prices, universe_tickers)
        kospi_df = self._load_kospi_df(backtest.start_date, backtest.end_date)
        benchmark_prices = self._load_benchmark_prices(
            backtest.benchmark, backtest.start_date, backtest.end_date
        )

        # Build day -> {ticker: close} lookup
        day_prices_map: Dict[date, Dict[str, Decimal]] = {}
        for p in all_prices:
            day_prices_map.setdefault(p.date, {})[p.ticker] = p.close

        equity_curve_data: List[Dict] = []
        all_transactions: List[tuple] = []

        # Base for normalizing the benchmark; also the carry-forward seed for
        # trading days that precede the first available benchmark price.
        initial_benchmark = self._first_benchmark_price(benchmark_prices, trading_days)
        last_benchmark = initial_benchmark

        for trading_date in trading_days:
            day_prices = day_prices_map.get(trading_date, {})

            # 1. Check exits first
            exit_txns = portfolio.check_exits(
                date=trading_date,
                prices=day_prices,
                commission_rate=backtest.commission_rate,
                slippage_rate=backtest.slippage_rate,
            )
            all_transactions.extend((trading_date, txn) for txn in exit_txns)

            # 2. Check entry signals
            # The KOSPI history mask depends only on the date, so it is built
            # at most once per day (lazily, on the first eligible ticker).
            kospi_mask = None
            kospi_rows = 0
            for ticker in universe_tickers:
                if ticker in portfolio.positions:
                    continue
                if ticker not in stock_dfs or ticker not in day_prices:
                    continue

                stock_df = stock_dfs[ticker]
                if trading_date not in stock_df.index:
                    continue

                hist = stock_df.loc[stock_df.index <= trading_date]
                if len(hist) < self.MIN_STOCK_HISTORY:
                    continue

                if kospi_mask is None:
                    kospi_mask = kospi_df.index <= trading_date
                    kospi_rows = int(kospi_mask.sum())
                if kospi_rows < self.MIN_KOSPI_HISTORY:
                    # Too little market history today; this holds for every
                    # remaining ticker as well, so stop scanning.
                    break

                # A fresh slice per call so the signal generator never sees a
                # frame that an earlier call might have modified.
                kospi_hist = kospi_df.loc[kospi_mask]

                signals = self.signal_gen.generate_signals(hist, kospi_hist)
                if trading_date in signals.index and signals.loc[trading_date, "buy"]:
                    txn = portfolio.enter_position(
                        ticker=ticker,
                        price=day_prices[ticker],
                        date=trading_date,
                        commission_rate=backtest.commission_rate,
                        slippage_rate=backtest.slippage_rate,
                    )
                    if txn:
                        all_transactions.append((trading_date, txn))

            # 3. Record daily value
            portfolio_value = portfolio.get_value(day_prices)
            # Carry the last known benchmark price forward over missing days
            # instead of snapping the curve back to its starting value.
            last_benchmark = benchmark_prices.get(trading_date, last_benchmark)
            normalized_benchmark = (
                last_benchmark / initial_benchmark * backtest.initial_capital
            )

            equity_curve_data.append({
                "date": trading_date,
                "portfolio_value": portfolio_value,
                "benchmark_value": normalized_benchmark,
            })

        # Calculate and save
        portfolio_values = [
            Decimal(str(e["portfolio_value"])) for e in equity_curve_data
        ]
        benchmark_values = [
            Decimal(str(e["benchmark_value"])) for e in equity_curve_data
        ]

        metrics = MetricsCalculator.calculate_all(portfolio_values, benchmark_values)
        drawdowns = MetricsCalculator.calculate_drawdown_series(portfolio_values)

        self._save_results(
            backtest_id, metrics, equity_curve_data, drawdowns, all_transactions
        )
        logger.info(
            "Backtest %s completed: %d trading days, %d transactions",
            backtest_id,
            len(trading_days),
            len(all_transactions),
        )

    @staticmethod
    def _create_portfolio(backtest) -> TradingPortfolio:
        """Build the trading portfolio from the backtest's strategy parameters.

        Missing parameters fall back to the defaults listed below; ratio
        parameters are converted to ``Decimal`` via their string form.
        """
        params = backtest.strategy_params or {}
        return TradingPortfolio(
            initial_capital=backtest.initial_capital,
            max_positions=params.get("max_positions", 10),
            cash_reserve_ratio=Decimal(str(params.get("cash_reserve_ratio", 0.3))),
            stop_loss_pct=Decimal(str(params.get("stop_loss_pct", 0.03))),
            target1_pct=Decimal(str(params.get("target1_pct", 0.05))),
            target2_pct=Decimal(str(params.get("target2_pct", 0.10))),
        )

    @staticmethod
    def _first_benchmark_price(
        benchmark_prices: Dict[date, Decimal], trading_days: List[date]
    ) -> Decimal:
        """Return the first non-zero benchmark price found on a trading day.

        Falls back to ``Decimal("1")`` when no trading day has a usable price,
        so the normalization in ``run`` never divides by zero.
        """
        for day in trading_days:
            price = benchmark_prices.get(day)
            if price:
                return price
        return Decimal("1")

    def _get_trading_days(self, start_date: date, end_date: date) -> List[date]:
        """Return the distinct price dates within the range, in ascending order."""
        prices = (
            self.db.query(Price.date)
            .filter(Price.date >= start_date, Price.date <= end_date)
            .distinct()
            .order_by(Price.date)
            .all()
        )
        return [p[0] for p in prices]

    def _get_universe_tickers(self) -> List[str]:
        """Return tickers of the ``UNIVERSE_SIZE`` largest stocks by market cap.

        Stocks without a market cap are excluded.
        """
        stocks = (
            self.db.query(Stock)
            .filter(Stock.market_cap.isnot(None))
            .order_by(Stock.market_cap.desc())
            .limit(self.UNIVERSE_SIZE)
            .all()
        )
        return [s.ticker for s in stocks]

    def _load_all_prices(
        self, tickers: List[str], start_date: date, end_date: date
    ) -> List:
        """Load every ``Price`` row for the given tickers within the date range."""
        return (
            self.db.query(Price)
            .filter(Price.ticker.in_(tickers))
            .filter(Price.date >= start_date, Price.date <= end_date)
            .all()
        )

    def _load_kospi_df(self, start_date: date, end_date: date) -> pd.DataFrame:
        """Load the KOSPI proxy closes as a date-indexed frame with a ``close`` column.

        Returns an empty frame with only the ``close`` column when no rows exist.
        """
        prices = (
            self.db.query(Price)
            .filter(Price.ticker == self.KOSPI_PROXY_TICKER)
            .filter(Price.date >= start_date, Price.date <= end_date)
            .order_by(Price.date)
            .all()
        )
        if not prices:
            return pd.DataFrame(columns=["close"])
        data = [{"date": p.date, "close": float(p.close)} for p in prices]
        return pd.DataFrame(data).set_index("date")

    def _load_benchmark_prices(
        self, benchmark: str, start_date: date, end_date: date
    ) -> Dict[date, Decimal]:
        """Return a ``date -> close`` mapping for the benchmark series.

        NOTE(review): ``benchmark`` is accepted but not used; the series is
        always loaded for ``KOSPI_PROXY_TICKER``. Confirm whether the argument
        should select the ticker before changing this.
        """
        prices = (
            self.db.query(Price)
            .filter(Price.ticker == self.KOSPI_PROXY_TICKER)
            .filter(Price.date >= start_date, Price.date <= end_date)
            .all()
        )
        return {p.date: p.close for p in prices}

    def _build_stock_dfs(
        self, price_data: List, tickers: List[str]
    ) -> Dict[str, pd.DataFrame]:
        """Group price rows by ticker into date-sorted OHLCV DataFrames.

        Rows for tickers outside ``tickers`` are ignored, and tickers without
        any rows are omitted from the result.
        """
        ticker_rows: Dict[str, list] = {t: [] for t in tickers}
        for p in price_data:
            if p.ticker in ticker_rows:
                ticker_rows[p.ticker].append({
                    "date": p.date,
                    "open": float(p.open),
                    "high": float(p.high),
                    "low": float(p.low),
                    "close": float(p.close),
                    "volume": int(p.volume),
                })

        return {
            ticker: pd.DataFrame(rows).set_index("date").sort_index()
            for ticker, rows in ticker_rows.items()
            if rows
        }

    def _save_results(self, backtest_id, metrics, equity_curve_data, drawdowns, transactions):
        """Persist the summary metrics, equity curve and transactions in one commit.

        Equity-curve points without a matching drawdown entry are stored with a
        drawdown of ``Decimal("0")``. On any failure the session is rolled back
        so partially added rows cannot be committed later, and the exception
        is re-raised.
        """
        try:
            self.db.add(
                BacktestResult(
                    backtest_id=backtest_id,
                    total_return=metrics.total_return,
                    cagr=metrics.cagr,
                    mdd=metrics.mdd,
                    sharpe_ratio=metrics.sharpe_ratio,
                    volatility=metrics.volatility,
                    benchmark_return=metrics.benchmark_return,
                    excess_return=metrics.excess_return,
                )
            )

            for i, point in enumerate(equity_curve_data):
                self.db.add(
                    BacktestEquityCurve(
                        backtest_id=backtest_id,
                        date=point["date"],
                        portfolio_value=point["portfolio_value"],
                        benchmark_value=point["benchmark_value"],
                        drawdown=drawdowns[i] if i < len(drawdowns) else Decimal("0"),
                    )
                )

            for trading_date, txn in transactions:
                self.db.add(
                    BacktestTransaction(
                        backtest_id=backtest_id,
                        date=trading_date,
                        ticker=txn.ticker,
                        action=txn.action,
                        shares=txn.shares,
                        price=txn.price,
                        commission=txn.commission,
                    )
                )

            self.db.commit()
        except Exception:
            self.db.rollback()
            raise