135 lines
4.2 KiB
Python
Raw Normal View History

2026-01-31 23:30:51 +09:00
"""Momentum Strategy (12M Return + K-Ratio)."""
from typing import List, Dict
from decimal import Decimal
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
import pandas as pd
import numpy as np
import statsmodels.api as sm
from app.strategies.base import BaseStrategy
from app.utils.data_helpers import (
get_ticker_list,
get_price_data,
get_prices_on_date
)
class MomentumStrategy(BaseStrategy):
"""
모멘텀 전략.
- 12개월 수익률
- K-Ratio (모멘텀의 꾸준함)
"""
def __init__(self, config: Dict = None):
"""
초기화.
Args:
config: 전략 설정
- count: 선정 종목 (기본 20)
- use_k_ratio: K-Ratio 사용 여부 (기본 True)
"""
super().__init__(config)
self.count = config.get('count', 20)
self.use_k_ratio = config.get('use_k_ratio', True)
def select_stocks(self, rebal_date: datetime, db_session: Session) -> List[str]:
"""
종목 선정.
Args:
rebal_date: 리밸런싱 날짜
db_session: 데이터베이스 세션
Returns:
선정된 종목 코드 리스트
"""
try:
# 1. 종목 리스트 조회
ticker_list = get_ticker_list(db_session)
if ticker_list.empty:
return []
tickers = ticker_list['종목코드'].tolist()
# 2. 12개월 가격 데이터 조회
start_date = rebal_date - timedelta(days=365)
price_list = get_price_data(db_session, tickers, start_date, rebal_date)
if price_list.empty:
return []
price_pivot = price_list.pivot(index='날짜', columns='종목코드', values='종가')
# 3. 12개월 수익률 계산
ret_list = pd.DataFrame(
data=(price_pivot.iloc[-1] / price_pivot.iloc[0]) - 1,
columns=['return']
)
data_bind = ticker_list[['종목코드', '종목명']].merge(
ret_list, how='left', on='종목코드'
)
if self.use_k_ratio:
# 4. K-Ratio 계산
ret = price_pivot.pct_change().iloc[1:]
ret_cum = np.log(1 + ret).cumsum()
x = np.array(range(len(ret)))
k_ratio = {}
for ticker in tickers:
try:
if ticker in price_pivot.columns:
y = ret_cum[ticker]
reg = sm.OLS(y, x).fit()
res = float(reg.params / reg.bse)
k_ratio[ticker] = res
except:
k_ratio[ticker] = np.nan
k_ratio_bind = pd.DataFrame.from_dict(
k_ratio, orient='index'
).reset_index()
k_ratio_bind.columns = ['종목코드', 'K_ratio']
# 5. K-Ratio 병합 및 상위 종목 선정
data_bind = data_bind.merge(k_ratio_bind, how='left', on='종목코드')
k_ratio_rank = data_bind['K_ratio'].rank(axis=0, ascending=False)
momentum_top = data_bind[k_ratio_rank <= self.count]
return momentum_top['종목코드'].tolist()
else:
# 단순 12개월 수익률 기준 상위 종목
momentum_rank = data_bind['return'].rank(axis=0, ascending=False)
momentum_top = data_bind[momentum_rank <= self.count]
return momentum_top['종목코드'].tolist()
except Exception as e:
print(f"Momentum 종목 선정 오류: {e}")
return []
def get_prices(
self,
tickers: List[str],
date: datetime,
db_session: Session
) -> Dict[str, Decimal]:
"""
종목 가격 조회.
Args:
tickers: 종목 코드 리스트
date: 조회 날짜
db_session: 데이터베이스 세션
Returns:
{ticker: price} 딕셔너리
"""
return get_prices_on_date(db_session, tickers, date)