Phase 1:
- Real-time signal alerts (Discord/Telegram webhook)
- Trading journal with entry/exit tracking
- Position sizing calculator (Fixed/Kelly/ATR)

Phase 2:
- Pension asset allocation (DC/IRP 70% risk limit)
- Drawdown monitoring with SVG gauge
- Benchmark dashboard (portfolio vs KOSPI vs deposit)

Phase 3:
- Tax benefit simulation (Korean pension tax rules)
- Correlation matrix heatmap
- Parameter optimizer with grid search + overfit detection
447 lines
17 KiB
Python
447 lines
17 KiB
Python
"""
|
|
Grid-search strategy optimizer.
|
|
|
|
Runs backtests across parameter combinations and ranks by selected metric.
|
|
Reuses existing BacktestEngine / DailyBacktestEngine logic without DB persistence.
|
|
"""
|
|
import itertools
|
|
import logging
|
|
from dataclasses import asdict
|
|
from datetime import date
|
|
from decimal import Decimal
|
|
from typing import Any, Dict, List, Tuple
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.schemas.optimizer import (
|
|
DEFAULT_GRIDS,
|
|
OptimizeRequest,
|
|
OptimizeResponse,
|
|
OptimizeResultItem,
|
|
)
|
|
from app.services.backtest.metrics import MetricsCalculator
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _expand_grid(param_grid: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
|
|
"""Expand parameter grid into list of parameter dicts."""
|
|
keys = list(param_grid.keys())
|
|
values = list(param_grid.values())
|
|
combos = []
|
|
for combo in itertools.product(*values):
|
|
combos.append(dict(zip(keys, combo)))
|
|
return combos
|
|
|
|
|
|
def _build_strategy_params(strategy_type: str, flat_params: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""Convert flat grid params to nested strategy_params dict."""
|
|
result: Dict[str, Any] = {}
|
|
for key, value in flat_params.items():
|
|
parts = key.split(".")
|
|
target = result
|
|
for part in parts[:-1]:
|
|
if part not in target:
|
|
target[part] = {}
|
|
target = target[part]
|
|
target[parts[-1]] = value
|
|
return result
|
|
|
|
|
|
class OptimizerService:
    """Grid-search optimizer that runs backtests across parameter combinations.

    Each combination is backtested in memory and reduced to a metrics dict.
    Within this class the database session is only used for read queries.
    """

    def __init__(self, db: Session):
        """Store the SQLAlchemy session used by the data-loading helpers."""
        self.db = db

    def optimize(self, request: OptimizeRequest) -> OptimizeResponse:
        """Backtest every parameter combination and rank the results.

        Args:
            request: Optimization request. ``param_grid`` overrides the
                default grid registered for ``strategy_type``.

        Returns:
            Response holding one ranked item per successful combination and
            the parameters of the top-ranked one (empty when all failed).

        Raises:
            ValueError: If no parameter grid is available for the strategy.
        """
        grid = request.param_grid or DEFAULT_GRIDS.get(request.strategy_type, {})
        if not grid:
            raise ValueError(f"No parameter grid for strategy type: {request.strategy_type}")

        combinations = _expand_grid(grid)
        logger.info(
            "Optimizer: %s, %d combinations",
            request.strategy_type,
            len(combinations),
        )

        results: List[Tuple[Dict[str, Any], Dict[str, float]]] = []
        for combo in combinations:
            # Best effort: one failing combination must not abort the search.
            try:
                metrics = self._run_single(request, combo)
            except Exception as e:
                logger.warning("Optimizer: failed for params %s: %s", combo, e)
            else:
                results.append((combo, metrics))

        if not results:
            return OptimizeResponse(
                strategy_type=request.strategy_type,
                total_combinations=len(combinations),
                results=[],
                best_params={},
            )

        # Sort by rank_by metric (descending, except mdd which is negative so also desc)
        rank_by = request.rank_by
        results.sort(key=lambda x: x[1].get(rank_by, 0), reverse=True)

        items = [
            OptimizeResultItem(
                rank=rank,
                params=combo,
                total_return=Decimal(str(metrics["total_return"])),
                cagr=Decimal(str(metrics["cagr"])),
                mdd=Decimal(str(metrics["mdd"])),
                sharpe_ratio=Decimal(str(metrics["sharpe_ratio"])),
                volatility=Decimal(str(metrics["volatility"])),
                benchmark_return=Decimal(str(metrics["benchmark_return"])),
                excess_return=Decimal(str(metrics["excess_return"])),
            )
            for rank, (combo, metrics) in enumerate(results, 1)
        ]

        return OptimizeResponse(
            strategy_type=request.strategy_type,
            total_combinations=len(combinations),
            results=items,
            best_params=items[0].params if items else {},
        )

    def _run_single(
        self, request: OptimizeRequest, flat_params: Dict[str, Any]
    ) -> Dict[str, float]:
        """Run a single backtest with given params, return metrics dict."""
        strategy_params = _build_strategy_params(request.strategy_type, flat_params)

        if request.strategy_type == "kjb":
            return self._run_kjb(request, strategy_params, flat_params)
        return self._run_factor(request, strategy_params)

    def _run_kjb(
        self,
        request: OptimizeRequest,
        strategy_params: Dict[str, Any],
        flat_params: Dict[str, Any],
    ) -> Dict[str, float]:
        """Run KJB daily backtest in-memory.

        Args:
            request: Optimization request (dates, capital, cost rates, benchmark).
            strategy_params: Nested parameters; supplies ``max_positions`` and
                ``cash_reserve_ratio``.
            flat_params: Flat grid parameters; supplies the stop/target
                percentages and the two signal look-back windows.

        Returns:
            Metrics dict as produced by ``_compute_metrics``.

        Raises:
            ValueError: If the date range contains no trading days.
        """
        from app.services.backtest.trading_portfolio import TradingPortfolio
        from app.services.strategy.kjb import KJBSignalGenerator

        signal_gen = KJBSignalGenerator()

        portfolio = TradingPortfolio(
            initial_capital=request.initial_capital,
            max_positions=strategy_params.get("max_positions", 10),
            cash_reserve_ratio=Decimal(str(strategy_params.get("cash_reserve_ratio", 0.3))),
            stop_loss_pct=Decimal(str(flat_params.get("stop_loss_pct", 0.03))),
            target1_pct=Decimal(str(flat_params.get("target1_pct", 0.05))),
            target2_pct=Decimal(str(flat_params.get("target2_pct", 0.10))),
        )

        rs_lookback = flat_params.get("rs_lookback", 10)
        breakout_lookback = flat_params.get("breakout_lookback", 20)

        trading_days = self._get_trading_days(request.start_date, request.end_date)
        if not trading_days:
            raise ValueError("No trading days found")

        universe_tickers = self._get_universe_tickers()
        all_prices = self._load_all_prices(universe_tickers, request.start_date, request.end_date)
        stock_dfs = self._build_stock_dfs(all_prices, universe_tickers)
        kospi_df = self._load_kospi_df(request.start_date, request.end_date)
        benchmark_prices = self._load_benchmark_prices(request.benchmark, request.start_date, request.end_date)

        # Closing prices grouped by day, then by ticker.
        day_prices_map: Dict[date, Dict[str, Decimal]] = {}
        for p in all_prices:
            day_prices_map.setdefault(p.date, {})[p.ticker] = p.close

        benchmark_curve = self._normalized_benchmark_curve(
            trading_days, benchmark_prices, request.initial_capital
        )

        equity_curve: List[Dict] = []
        for trading_date, normalized_benchmark in zip(trading_days, benchmark_curve):
            day_prices = day_prices_map.get(trading_date, {})

            portfolio.check_exits(
                date=trading_date,
                prices=day_prices,
                commission_rate=request.commission_rate,
                slippage_rate=request.slippage_rate,
            )

            # The index history is the same for every ticker on a given day,
            # so it is sliced lazily and at most once per day.
            kospi_hist = None
            for ticker in universe_tickers:
                if ticker in portfolio.positions:
                    continue
                if ticker not in stock_dfs or ticker not in day_prices:
                    continue
                stock_df = stock_dfs[ticker]
                if trading_date not in stock_df.index:
                    continue
                hist = stock_df.loc[stock_df.index <= trading_date]
                if len(hist) < 21:
                    continue
                if kospi_hist is None:
                    kospi_hist = kospi_df.loc[kospi_df.index <= trading_date]
                if len(kospi_hist) < 11:
                    # Too little index history today for any ticker.
                    break

                signals = signal_gen.generate_signals(
                    hist, kospi_hist,
                    rs_lookback=rs_lookback,
                    breakout_lookback=breakout_lookback,
                )

                if trading_date in signals.index and signals.loc[trading_date, "buy"]:
                    portfolio.enter_position(
                        ticker=ticker,
                        price=day_prices[ticker],
                        date=trading_date,
                        commission_rate=request.commission_rate,
                        slippage_rate=request.slippage_rate,
                    )

            equity_curve.append({
                "portfolio_value": portfolio.get_value(day_prices),
                "benchmark_value": normalized_benchmark,
            })

        return self._compute_metrics(equity_curve)

    def _run_factor(
        self,
        request: OptimizeRequest,
        strategy_params: Dict[str, Any],
    ) -> Dict[str, float]:
        """Run factor-based backtest in-memory (multi_factor, quality, value_momentum).

        The portfolio is rebalanced quarterly, counted from ``request.start_date``.
        A scheduled date that is not a trading day is moved to the first
        trading day on or after it, so no rebalance is skipped because the
        schedule fell on a weekend or holiday.

        Args:
            request: Optimization request (dates, capital, cost rates, top_n).
            strategy_params: Nested parameters passed to ``_create_strategy``.

        Returns:
            Metrics dict as produced by ``_compute_metrics``.

        Raises:
            ValueError: If the date range contains no trading days, or the
                strategy type is unknown.
        """
        from app.models.backtest import RebalancePeriod
        from app.services.backtest.portfolio import VirtualPortfolio
        from app.schemas.strategy import UniverseFilter

        strategy = self._create_strategy(
            request.strategy_type, strategy_params, request.top_n,
        )

        portfolio = VirtualPortfolio(request.initial_capital)

        trading_days = self._get_trading_days(request.start_date, request.end_date)
        if not trading_days:
            raise ValueError("No trading days found")

        scheduled_dates = self._generate_rebalance_dates(
            request.start_date, request.end_date, RebalancePeriod.QUARTERLY,
        )
        rebalance_days = self._snap_to_trading_days(trading_days, scheduled_dates)

        benchmark_prices = self._load_benchmark_prices(
            request.benchmark, request.start_date, request.end_date,
        )
        all_date_prices = self._load_all_prices_by_date(
            request.start_date, request.end_date,
        )
        names = self._get_stock_names()

        benchmark_curve = self._normalized_benchmark_curve(
            trading_days, benchmark_prices, request.initial_capital
        )

        equity_curve: List[Dict] = []
        for trading_date, normalized_benchmark in zip(trading_days, benchmark_curve):
            prices = all_date_prices.get(trading_date, {})

            if trading_date in rebalance_days:
                target_stocks = strategy.run(
                    universe_filter=UniverseFilter(),
                    top_n=request.top_n,
                    base_date=trading_date,
                )
                target_tickers = [s.ticker for s in target_stocks.stocks]

                portfolio.rebalance(
                    target_tickers=target_tickers,
                    prices=prices,
                    names=names,
                    commission_rate=request.commission_rate,
                    slippage_rate=request.slippage_rate,
                )

            equity_curve.append({
                "portfolio_value": portfolio.get_value(prices),
                "benchmark_value": normalized_benchmark,
            })

        return self._compute_metrics(equity_curve)

    @staticmethod
    def _normalized_benchmark_curve(
        trading_days: List[date],
        benchmark_prices: Dict[date, Decimal],
        initial_capital: Decimal,
    ) -> List[Decimal]:
        """Scale the benchmark so that it starts at ``initial_capital``.

        The base is the first non-zero quote found on a trading day. Days
        without a usable quote repeat the previous value instead of falling
        back to the base, which would look like a one-day jump. If there is
        no usable quote at all the curve is flat at ``initial_capital``.

        Returns:
            One value per entry of ``trading_days``, in the same order.
        """
        base = next(
            (benchmark_prices[day] for day in trading_days if benchmark_prices.get(day)),
            Decimal("1"),
        )
        curve: List[Decimal] = []
        last_quote = base
        for day in trading_days:
            quote = benchmark_prices.get(day)
            if quote:  # None (missing) and zero quotes are both skipped
                last_quote = quote
            curve.append(last_quote / base * initial_capital)
        return curve

    @staticmethod
    def _snap_to_trading_days(trading_days: List[date], scheduled_dates: List[date]) -> set:
        """Map each scheduled date to the first trading day on or after it.

        ``trading_days`` must be sorted ascending. Scheduled dates later than
        the last trading day are dropped.
        """
        from bisect import bisect_left

        snapped = set()
        for scheduled in scheduled_dates:
            idx = bisect_left(trading_days, scheduled)
            if idx < len(trading_days):
                snapped.add(trading_days[idx])
        return snapped

    def _compute_metrics(self, equity_curve: List[Dict]) -> Dict[str, float]:
        """Reduce an equity curve to a dict of float performance metrics.

        Args:
            equity_curve: One dict per trading day with ``portfolio_value``
                and ``benchmark_value`` entries.
        """
        portfolio_values = [Decimal(str(e["portfolio_value"])) for e in equity_curve]
        benchmark_values = [Decimal(str(e["benchmark_value"])) for e in equity_curve]
        metrics = MetricsCalculator.calculate_all(portfolio_values, benchmark_values)
        return {
            "total_return": float(metrics.total_return),
            "cagr": float(metrics.cagr),
            "mdd": float(metrics.mdd),
            "sharpe_ratio": float(metrics.sharpe_ratio),
            "volatility": float(metrics.volatility),
            "benchmark_return": float(metrics.benchmark_return),
            "excess_return": float(metrics.excess_return),
        }

    def _create_strategy(self, strategy_type: str, strategy_params: dict, top_n: int):
        """Instantiate the strategy for ``strategy_type`` and apply grid params.

        Parameters are written straight onto the strategy's private
        attributes. ``top_n`` is accepted but currently unused here.

        Raises:
            ValueError: If ``strategy_type`` is not a known factor strategy.
        """
        from app.services.strategy import (
            MultiFactorStrategy,
            QualityStrategy,
            ValueMomentumStrategy,
        )
        from app.schemas.strategy import FactorWeights

        if strategy_type == "multi_factor":
            strategy = MultiFactorStrategy(self.db)
            strategy._weights = FactorWeights(**strategy_params.get("weights", {}))
        elif strategy_type == "quality":
            strategy = QualityStrategy(self.db)
            strategy._min_fscore = strategy_params.get("min_fscore", 7)
        elif strategy_type == "value_momentum":
            strategy = ValueMomentumStrategy(self.db)
            strategy._value_weight = Decimal(
                str(strategy_params.get("value_weight", 0.5))
            )
            strategy._momentum_weight = Decimal(
                str(strategy_params.get("momentum_weight", 0.5))
            )
        else:
            raise ValueError(f"Unknown strategy type: {strategy_type}")
        return strategy

    # --- Data loading helpers (mirrored from engines) ---

    def _get_trading_days(self, start_date: date, end_date: date) -> List[date]:
        """Return the distinct price dates within the range, ascending."""
        from app.models.stock import Price

        prices = (
            self.db.query(Price.date)
            .filter(Price.date >= start_date, Price.date <= end_date)
            .distinct()
            .order_by(Price.date)
            .all()
        )
        return [p[0] for p in prices]

    def _get_universe_tickers(self) -> List[str]:
        """Return the tickers of the 30 stocks with the largest market cap."""
        from app.models.stock import Stock

        stocks = (
            self.db.query(Stock)
            .filter(Stock.market_cap.isnot(None))
            .order_by(Stock.market_cap.desc())
            .limit(30)
            .all()
        )
        return [s.ticker for s in stocks]

    def _load_all_prices(self, tickers, start_date, end_date):
        """Return all price rows for ``tickers`` within the date range."""
        from app.models.stock import Price

        return (
            self.db.query(Price)
            .filter(Price.ticker.in_(tickers))
            .filter(Price.date >= start_date, Price.date <= end_date)
            .all()
        )

    def _load_kospi_df(self, start_date, end_date):
        """Return a date-indexed DataFrame of closes for ticker ``069500``.

        An empty frame with only a ``close`` column is returned when there
        are no rows. NOTE(review): 069500 is presumably a KOSPI-tracking
        ETF used as the index proxy — confirm.
        """
        import pandas as pd
        from app.models.stock import Price

        prices = (
            self.db.query(Price)
            .filter(Price.ticker == "069500")
            .filter(Price.date >= start_date, Price.date <= end_date)
            .order_by(Price.date)
            .all()
        )
        if not prices:
            return pd.DataFrame(columns=["close"])
        data = [{"date": p.date, "close": float(p.close)} for p in prices]
        return pd.DataFrame(data).set_index("date")

    def _load_benchmark_prices(self, benchmark, start_date, end_date):
        """Return ``{date: close}`` for the benchmark within the date range.

        NOTE(review): ``benchmark`` is ignored; ticker ``069500`` is always
        queried. Confirm whether other benchmarks should be supported.
        """
        from app.models.stock import Price

        prices = (
            self.db.query(Price)
            .filter(Price.ticker == "069500")
            .filter(Price.date >= start_date, Price.date <= end_date)
            .all()
        )
        return {p.date: p.close for p in prices}

    def _build_stock_dfs(self, price_data, tickers):
        """Build one date-sorted OHLCV DataFrame per ticker that has rows.

        Rows for tickers outside ``tickers`` are ignored, and tickers with
        no rows are left out of the result.
        """
        import pandas as pd

        ticker_rows = {t: [] for t in tickers}
        for p in price_data:
            if p.ticker in ticker_rows:
                ticker_rows[p.ticker].append({
                    "date": p.date,
                    "open": float(p.open),
                    "high": float(p.high),
                    "low": float(p.low),
                    "close": float(p.close),
                    "volume": int(p.volume),
                })
        return {
            ticker: pd.DataFrame(rows).set_index("date").sort_index()
            for ticker, rows in ticker_rows.items()
            if rows
        }

    def _load_all_prices_by_date(self, start_date, end_date):
        """Return ``{date: {ticker: close}}`` for every price row in range."""
        from app.models.stock import Price

        prices = (
            self.db.query(Price)
            .filter(Price.date >= start_date, Price.date <= end_date)
            .all()
        )
        result = {}
        for p in prices:
            result.setdefault(p.date, {})[p.ticker] = p.close
        return result

    def _get_stock_names(self):
        """Return a ``{ticker: name}`` mapping for all stocks."""
        from app.models.stock import Stock

        stocks = self.db.query(Stock.ticker, Stock.name).all()
        return {s.ticker: s.name for s in stocks}

    def _generate_rebalance_dates(self, start_date, end_date, period):
        """Return the scheduled rebalance dates from start to end, inclusive.

        Each date is computed as ``start_date + i * period`` rather than by
        repeatedly adding to the previous date. This keeps an end-of-month
        start from drifting after passing a shorter month (Jan 31 -> Apr 30
        -> Jul 31, not Jul 30). Any period other than monthly, quarterly or
        semi-annual is treated as yearly.
        """
        from dateutil.relativedelta import relativedelta
        from app.models.backtest import RebalancePeriod

        months_by_period = {
            RebalancePeriod.MONTHLY: 1,
            RebalancePeriod.QUARTERLY: 3,
            RebalancePeriod.SEMI_ANNUAL: 6,
        }
        delta = relativedelta(months=months_by_period.get(period, 12))

        dates = []
        for step in itertools.count():
            current = start_date + delta * step
            if current > end_date:
                break
            dates.append(current)
        return dates
|