머니페니 f6db08c9bd feat: improve security, performance, and add missing features
- Remove hardcoded database_url/jwt_secret defaults, require env vars
- Add DB indexes for stocks.market, market_cap, backtests.user_id
- Optimize backtest engine: preload all prices, move stock_names out of loop
- Fix backtest API auth: filter by user_id at query level (6 endpoints)
- Add manual transaction entry modal on portfolio detail page
- Replace console.error with toast.error in signals, backtest, data explorer
- Add backtest delete button with confirmation dialog
- Replace simulated sine chart with real snapshot data
- Add strategy-to-portfolio apply flow with dialog
- Add DC pension risk asset ratio >70% warning on rebalance page
- Add backtest comparison page with metrics table and overlay chart
2026-03-20 12:27:05 +09:00

458 lines
15 KiB
Python

"""
Main backtest engine.
"""
import logging
from dataclasses import dataclass, field
from datetime import date, timedelta
from decimal import Decimal
from typing import List, Dict, Optional
from dateutil.relativedelta import relativedelta
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.models.backtest import (
Backtest,
BacktestResult,
BacktestEquityCurve,
BacktestHolding,
BacktestTransaction,
RebalancePeriod,
)
from app.models.stock import Stock, Price
from app.services.backtest.portfolio import VirtualPortfolio, Transaction
from app.services.backtest.metrics import MetricsCalculator
from app.services.strategy import (
MultiFactorStrategy,
QualityStrategy,
ValueMomentumStrategy,
)
from app.schemas.strategy import UniverseFilter, FactorWeights
# Module-level logger for engine diagnostics (data-validation outcome,
# missing-price warnings during simulation).
logger: logging.Logger = logging.getLogger(__name__)
@dataclass
class DataValidationResult:
    """Outcome of the data checks performed before a backtest runs.

    ``errors`` collects problems that should block the run, and
    ``warnings`` collects non-fatal findings that are only logged.
    The validator clears ``is_valid`` whenever it records an error.
    """

    # Overall verdict; starts out True and is flipped to False on any error.
    is_valid: bool = field(default=True)
    # Blocking problems, one human-readable message per entry.
    errors: List[str] = field(default_factory=list)
    # Non-blocking findings, one human-readable message per entry.
    warnings: List[str] = field(default_factory=list)
class BacktestEngine:
    """
    Main backtest engine that simulates strategy performance over time.

    ``run`` walks every trading day (a date that has price rows) in the
    backtest period. On each scheduled rebalance day it rebalances a
    ``VirtualPortfolio`` into the strategy's picks. It records daily
    portfolio and benchmark values and finally persists metrics, the equity
    curve, holdings snapshots and transactions.
    """

    # Price tickers used as benchmark proxies. KODEX 200 (069500) stands in
    # for KOSPI and is also the fallback for any other benchmark name.
    _BENCHMARK_TICKERS: Dict[str, str] = {"KOSPI": "069500"}
    _DEFAULT_BENCHMARK_TICKER = "069500"

    def __init__(self, db: Session):
        self.db = db

    def run(self, backtest_id: int) -> None:
        """Execute a backtest and save its results.

        Args:
            backtest_id: Primary key of the ``Backtest`` row to simulate.

        Raises:
            ValueError: If the backtest is not found, the period contains no
                trading days, pre-run data validation fails, or the strategy
                type is unknown.
        """
        backtest = self.db.get(Backtest, backtest_id)
        if not backtest:
            raise ValueError(f"Backtest {backtest_id} not found")

        portfolio = VirtualPortfolio(backtest.initial_capital)

        trading_days = self._get_trading_days(
            backtest.start_date,
            backtest.end_date,
        )
        if not trading_days:
            raise ValueError("No trading days found in the specified period")

        # The schedule is calendar-based. Snap it onto real trading days so
        # that rebalances landing on weekends or holidays (including the
        # initial allocation) are not silently skipped.
        rebalance_days = self._align_to_trading_days(
            self._generate_rebalance_dates(
                backtest.start_date,
                backtest.end_date,
                backtest.rebalance_period,
            ),
            trading_days,
        )

        benchmark_prices = self._load_benchmark_prices(
            backtest.benchmark,
            backtest.start_date,
            backtest.end_date,
        )

        # Pre-backtest data validation
        validation = self._validate_data(
            trading_days=trading_days,
            benchmark_prices=benchmark_prices,
            benchmark=backtest.benchmark,
            start_date=backtest.start_date,
            end_date=backtest.end_date,
        )
        for warning in validation.warnings:
            logger.warning("Backtest %s: %s", backtest_id, warning)
        if not validation.is_valid:
            raise ValueError("데이터 검증 실패:\n" + "\n".join(validation.errors))

        strategy = self._create_strategy(
            backtest.strategy_type,
            backtest.strategy_params,
            backtest.top_n,
        )

        equity_curve_data: List[Dict] = []
        # (trade date, Transaction) pairs, in execution order.
        all_transactions: List[tuple] = []
        holdings_history: List[Dict] = []

        # Normalize the benchmark against its first usable (non-zero) price in
        # the period rather than assuming one exists on the very first day.
        # Falling back to 1 there would inflate the normalized curve by orders
        # of magnitude.
        initial_benchmark = next(
            (benchmark_prices[d] for d in trading_days if benchmark_prices.get(d)),
            Decimal("1"),
        )
        last_benchmark = initial_benchmark

        names = self._get_stock_names()
        all_date_prices = self._load_all_prices_by_date(
            backtest.start_date,
            backtest.end_date,
        )

        for trading_date in trading_days:
            prices = all_date_prices.get(trading_date, {})

            # Warn about held positions that have no price on this day.
            missing = [
                t
                for t in portfolio.holdings
                if portfolio.holdings[t] > 0 and t not in prices
            ]
            if missing:
                logger.warning(
                    "%s: 보유 종목 가격 누락 %s (0원으로 처리됨)",
                    trading_date,
                    missing,
                )

            if trading_date in rebalance_days:
                target_stocks = strategy.run(
                    universe_filter=UniverseFilter(),
                    top_n=backtest.top_n,
                    base_date=trading_date,
                )
                target_tickers = [s.ticker for s in target_stocks.stocks]

                transactions = portfolio.rebalance(
                    target_tickers=target_tickers,
                    prices=prices,
                    names=names,
                    commission_rate=backtest.commission_rate,
                    slippage_rate=backtest.slippage_rate,
                )
                all_transactions.extend((trading_date, txn) for txn in transactions)

                # Snapshot post-rebalance holdings with their weights.
                holdings_history.append(
                    {
                        "date": trading_date,
                        "holdings": portfolio.get_holdings_with_weights(
                            prices, names
                        ),
                    }
                )

            # Carry the last known benchmark price across days where it is
            # missing (or zero) instead of snapping back to the initial price.
            todays_benchmark = benchmark_prices.get(trading_date)
            if todays_benchmark:
                last_benchmark = todays_benchmark

            equity_curve_data.append(
                {
                    "date": trading_date,
                    "portfolio_value": portfolio.get_value(prices),
                    # Benchmark scaled so it starts at the initial capital.
                    "benchmark_value": (
                        last_benchmark / initial_benchmark * backtest.initial_capital
                    ),
                }
            )

        portfolio_values = [
            Decimal(str(e["portfolio_value"])) for e in equity_curve_data
        ]
        benchmark_values = [
            Decimal(str(e["benchmark_value"])) for e in equity_curve_data
        ]
        metrics = MetricsCalculator.calculate_all(portfolio_values, benchmark_values)
        drawdowns = MetricsCalculator.calculate_drawdown_series(portfolio_values)

        self._save_results(
            backtest_id=backtest_id,
            metrics=metrics,
            equity_curve_data=equity_curve_data,
            drawdowns=drawdowns,
            holdings_history=holdings_history,
            transactions=all_transactions,
        )

    @classmethod
    def _benchmark_ticker(cls, benchmark: str) -> str:
        """Return the price ticker used as a proxy for ``benchmark``."""
        return cls._BENCHMARK_TICKERS.get(benchmark, cls._DEFAULT_BENCHMARK_TICKER)

    @staticmethod
    def _align_to_trading_days(
        scheduled: List[date],
        trading_days: List[date],
    ) -> set:
        """Map each scheduled date to the first trading day on or after it.

        ``trading_days`` must be sorted ascending. Scheduled dates after the
        last trading day are dropped. Several scheduled dates that map to the
        same trading day collapse into a single rebalance.
        """
        aligned = set()
        idx = 0
        for target in sorted(scheduled):
            while idx < len(trading_days) and trading_days[idx] < target:
                idx += 1
            if idx == len(trading_days):
                break
            aligned.add(trading_days[idx])
        return aligned

    def _validate_data(
        self,
        trading_days: List[date],
        benchmark_prices: Dict[date, Decimal],
        benchmark: str,
        start_date: date,
        end_date: date,
    ) -> DataValidationResult:
        """
        Validate price data completeness before running a backtest.

        Checks:
        1. Minimum trading-day count relative to the calendar period
        2. Benchmark price coverage over the trading days
        3. Price data density (tickers per day) on sampled trading days
        4. Unusually large gaps between consecutive trading days

        Returns:
            A ``DataValidationResult``. ``is_valid`` is False when any check
            records an error; warnings never affect validity.
        """
        result = DataValidationResult()
        num_trading_days = len(trading_days)
        calendar_days = (end_date - start_date).days

        # The remaining checks index into trading_days, so stop early if empty.
        if not trading_days:
            result.errors.append(f"거래일 데이터 없음 ({start_date} ~ {end_date})")
            result.is_valid = False
            return result

        # 1. Minimum trading days. About 69% of calendar days are expected to
        # be trading days (weekends ~28%, holidays ~3%). Requiring only 50%
        # tolerates partial data while still rejecting very sparse periods.
        if calendar_days > 30:
            expected_min = int(calendar_days * 0.5)
            if num_trading_days < expected_min:
                result.errors.append(
                    f"거래일 수 부족: {num_trading_days} "
                    f"(기간 {calendar_days}일 중 최소 {expected_min}일 필요)"
                )
                result.is_valid = False

        # 2. Benchmark data coverage
        benchmark_ticker = self._benchmark_ticker(benchmark)
        benchmark_coverage = sum(1 for d in trading_days if d in benchmark_prices)
        benchmark_pct = benchmark_coverage / num_trading_days * 100
        if benchmark_coverage == 0:
            result.errors.append(f"벤치마크({benchmark_ticker}) 가격 데이터 없음")
            result.is_valid = False
        elif benchmark_pct < 90:
            result.warnings.append(
                f"벤치마크({benchmark_ticker}) 데이터 커버리지 낮음: "
                f"{benchmark_coverage}/{num_trading_days}일 ({benchmark_pct:.1f}%)"
            )

        # 3. Price data density on the first, middle and last trading days.
        # dict.fromkeys de-duplicates (short periods) while keeping order.
        sample_dates = list(
            dict.fromkeys(
                [
                    trading_days[0],
                    trading_days[num_trading_days // 2],
                    trading_days[-1],
                ]
            )
        )
        for sample_date in sample_dates:
            ticker_count = (
                self.db.query(func.count(Price.ticker))
                .filter(Price.date == sample_date)
                .scalar()
            )
            if ticker_count == 0:
                result.errors.append(f"{sample_date} 가격 데이터 없음 (종목 0개)")
                result.is_valid = False
            elif ticker_count < 100:
                result.warnings.append(f"{sample_date} 종목 수 적음: {ticker_count}")

        # 4. Large gaps between consecutive trading days (> 7 calendar days,
        # i.e. longer than a normal weekend plus a short holiday).
        for prev_day, next_day in zip(trading_days, trading_days[1:]):
            gap = (next_day - prev_day).days
            if gap > 7:
                result.warnings.append(
                    f"거래일 갭 발견: {prev_day} ~ {next_day} ({gap}일)"
                )

        if result.is_valid and not result.warnings:
            logger.info(
                "데이터 검증 통과: 거래일 %d일, 벤치마크 커버리지 %.1f%%",
                num_trading_days,
                benchmark_pct,
            )
        return result

    def _generate_rebalance_dates(
        self,
        start_date: date,
        end_date: date,
        period: RebalancePeriod,
    ) -> List[date]:
        """Generate calendar rebalance dates from ``start_date`` to ``end_date``.

        Each date is computed as an offset from ``start_date``, not from the
        previous date, so month-end anchors do not drift. For example, a
        monthly schedule starting Jan 31 yields Feb 28/29, Mar 31, Apr 30, ...
        rather than Feb 29, Mar 29, Apr 29, ...

        Any period other than monthly, quarterly or semi-annual is treated as
        annual. The returned dates may fall on non-trading days.
        """
        if period == RebalancePeriod.MONTHLY:
            step_months = 1
        elif period == RebalancePeriod.QUARTERLY:
            step_months = 3
        elif period == RebalancePeriod.SEMI_ANNUAL:
            step_months = 6
        else:  # ANNUAL
            step_months = 12

        dates: List[date] = []
        step = 0
        current = start_date
        while current <= end_date:
            dates.append(current)
            step += 1
            current = start_date + relativedelta(months=step_months * step)
        return dates

    def _get_trading_days(self, start_date: date, end_date: date) -> List[date]:
        """Return the distinct dates with any price row in the range, ascending."""
        rows = (
            self.db.query(Price.date)
            .filter(Price.date >= start_date)
            .filter(Price.date <= end_date)
            .distinct()
            .order_by(Price.date)
            .all()
        )
        return [row[0] for row in rows]

    def _load_benchmark_prices(
        self,
        benchmark: str,
        start_date: date,
        end_date: date,
    ) -> Dict[date, Decimal]:
        """Load close prices of the benchmark proxy ticker, keyed by date."""
        rows = (
            self.db.query(Price.date, Price.close)
            .filter(Price.ticker == self._benchmark_ticker(benchmark))
            .filter(Price.date >= start_date)
            .filter(Price.date <= end_date)
            .all()
        )
        return {price_date: close for price_date, close in rows}

    def _load_all_prices_by_date(
        self,
        start_date: date,
        end_date: date,
    ) -> Dict[date, Dict[str, Decimal]]:
        """Load every close price in the range as ``{date: {ticker: close}}``.

        Only the three needed columns are selected, which avoids building
        (and identity-mapping) a full ORM object for every market-wide row.
        """
        rows = (
            self.db.query(Price.date, Price.ticker, Price.close)
            .filter(Price.date >= start_date, Price.date <= end_date)
            .all()
        )
        by_date: Dict[date, Dict[str, Decimal]] = {}
        for price_date, ticker, close in rows:
            by_date.setdefault(price_date, {})[ticker] = close
        return by_date

    def _get_stock_names(self) -> Dict[str, str]:
        """Return a ``{ticker: name}`` map for every stock row."""
        return {ticker: name for ticker, name in self.db.query(Stock.ticker, Stock.name)}

    def _create_strategy(
        self,
        strategy_type: str,
        strategy_params: dict,
        top_n: int,
    ):
        """Create a strategy instance for ``strategy_type``.

        ``strategy_params`` may be None, which is treated as ``{}``. ``top_n``
        is not used here; ``run`` passes it to ``strategy.run`` directly.

        Raises:
            ValueError: If ``strategy_type`` is not recognized.
        """
        params = strategy_params or {}
        if strategy_type == "multi_factor":
            strategy = MultiFactorStrategy(self.db)
            strategy._weights = FactorWeights(**params.get("weights", {}))
        elif strategy_type == "quality":
            strategy = QualityStrategy(self.db)
            strategy._min_fscore = params.get("min_fscore", 7)
        elif strategy_type == "value_momentum":
            strategy = ValueMomentumStrategy(self.db)
            # Route through str() so float params become exact Decimals.
            strategy._value_weight = Decimal(str(params.get("value_weight", 0.5)))
            strategy._momentum_weight = Decimal(
                str(params.get("momentum_weight", 0.5))
            )
        else:
            raise ValueError(f"Unknown strategy type: {strategy_type}")
        return strategy

    def _save_results(
        self,
        backtest_id: int,
        metrics,
        equity_curve_data: List[Dict],
        drawdowns: List[Decimal],
        holdings_history: List[Dict],
        transactions: List,
    ) -> None:
        """Persist metrics, equity curve, holdings and transactions, then commit.

        ``transactions`` holds ``(date, Transaction)`` pairs. An equity-curve
        point without a matching drawdown entry is stored with drawdown 0.
        """
        self.db.add(
            BacktestResult(
                backtest_id=backtest_id,
                total_return=metrics.total_return,
                cagr=metrics.cagr,
                mdd=metrics.mdd,
                sharpe_ratio=metrics.sharpe_ratio,
                volatility=metrics.volatility,
                benchmark_return=metrics.benchmark_return,
                excess_return=metrics.excess_return,
            )
        )

        for i, point in enumerate(equity_curve_data):
            self.db.add(
                BacktestEquityCurve(
                    backtest_id=backtest_id,
                    date=point["date"],
                    portfolio_value=point["portfolio_value"],
                    benchmark_value=point["benchmark_value"],
                    drawdown=drawdowns[i] if i < len(drawdowns) else Decimal("0"),
                )
            )

        for record in holdings_history:
            for holding in record["holdings"]:
                self.db.add(
                    BacktestHolding(
                        backtest_id=backtest_id,
                        rebalance_date=record["date"],
                        ticker=holding.ticker,
                        name=holding.name,
                        weight=holding.weight,
                        shares=holding.shares,
                        price=holding.price,
                    )
                )

        for trading_date, txn in transactions:
            self.db.add(
                BacktestTransaction(
                    backtest_id=backtest_id,
                    date=trading_date,
                    ticker=txn.ticker,
                    action=txn.action,
                    shares=txn.shares,
                    price=txn.price,
                    commission=txn.commission,
                )
            )

        self.db.commit()