223 lines
6.0 KiB
Python
223 lines
6.0 KiB
Python
|
|
"""Portfolio management for backtesting."""
|
||
|
|
from dataclasses import dataclass, field
|
||
|
|
from typing import Dict, List
|
||
|
|
from decimal import Decimal
|
||
|
|
from datetime import datetime
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class Position:
    """Holding of a single ticker.

    Attributes:
        ticker: Instrument identifier.
        quantity: Number of units held.
        avg_price: Average acquisition price per unit.
        current_price: Latest mark price per unit; defaults to 0 until set.
    """

    ticker: str
    quantity: Decimal
    avg_price: Decimal
    current_price: Decimal = Decimal("0")

    @property
    def market_value(self) -> Decimal:
        """Return the held quantity valued at ``current_price``."""
        value = self.quantity * self.current_price
        return value

    @property
    def pnl(self) -> Decimal:
        """Return unrealized profit/loss: per-unit price change times quantity."""
        per_unit_change = self.current_price - self.avg_price
        return per_unit_change * self.quantity

    @property
    def pnl_pct(self) -> Decimal:
        """Return the unrealized return in percent relative to ``avg_price``.

        Returns 0 when ``avg_price`` is zero, to avoid dividing by zero.
        """
        base = self.avg_price
        if base != 0:
            return (self.current_price - base) / base * Decimal("100")
        return Decimal("0")
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class Trade:
    """Record of one executed trade.

    Attributes:
        ticker: Instrument identifier.
        trade_date: When the trade was executed.
        action: 'buy' or 'sell'.
        quantity: Units traded.
        price: Execution price per unit.
        commission: Fee charged on the trade; defaults to 0.
    """

    ticker: str
    trade_date: datetime
    action: str  # 'buy' or 'sell'
    quantity: Decimal
    price: Decimal
    commission: Decimal = Decimal("0")

    @property
    def total_amount(self) -> Decimal:
        """Return the gross trade value adjusted by the commission.

        For a 'buy' the commission is added (cash paid). For any other action
        it is subtracted (net proceeds).
        """
        gross = self.quantity * self.price
        if self.action != 'buy':
            return gross - self.commission
        return gross + self.commission
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class PortfolioSnapshot:
    """Point-in-time record of a portfolio's state."""

    # Timestamp the snapshot was taken for.
    date: datetime
    # Cash balance at snapshot time.
    cash: Decimal
    # Combined market value of all positions.
    positions_value: Decimal
    # Cash plus positions value.
    total_value: Decimal
    # Positions keyed by ticker; each instance gets its own fresh dict.
    positions: Dict[str, Position] = field(default_factory=dict)
|
||
|
|
|
||
|
|
|
||
|
|
class BacktestPortfolio:
    """Portfolio ledger used by the backtester.

    Tracks a cash balance, open positions keyed by ticker, an append-only
    trade log, and point-in-time snapshots. Commission is charged as
    ``commission_rate`` times the gross trade value on both buys and sells.
    """

    def __init__(self, initial_capital: Decimal, commission_rate: Decimal = Decimal("0.0015")):
        """Initialize the portfolio.

        Args:
            initial_capital: Starting cash balance.
            commission_rate: Commission as a fraction of gross trade value
                (default 0.0015, i.e. 0.15%).
        """
        self.initial_capital = initial_capital
        self.cash = initial_capital
        self.commission_rate = commission_rate
        self.positions: Dict[str, Position] = {}
        self.trades: List[Trade] = []
        self.snapshots: List[PortfolioSnapshot] = []

    def buy(self, ticker: str, quantity: Decimal, price: Decimal, trade_date: datetime) -> bool:
        """Buy ``quantity`` units of ``ticker`` at ``price``.

        The position's average price is the quantity-weighted average of
        execution prices; commission is deducted from cash but not folded into
        ``avg_price``.

        Args:
            ticker: Instrument identifier.
            quantity: Units to buy; must be positive.
            price: Execution price per unit; must not be negative.
            trade_date: Execution date recorded on the trade.

        Returns:
            True if the trade executed. False if ``quantity`` is not positive,
            ``price`` is negative, or cash cannot cover cost plus commission.
            On False the portfolio is left unchanged.
        """
        # A non-positive quantity would credit cash (negative cost) or create
        # empty positions; a negative price is meaningless.
        if quantity <= 0 or price < 0:
            return False

        gross = quantity * price
        commission = gross * self.commission_rate
        total_cost = gross + commission
        if total_cost > self.cash:
            return False

        existing = self.positions.get(ticker)
        if existing is None:
            self.positions[ticker] = Position(
                ticker=ticker,
                quantity=quantity,
                avg_price=price,
                current_price=price,
            )
        else:
            total_quantity = existing.quantity + quantity
            cost_basis = existing.avg_price * existing.quantity + price * quantity
            existing.avg_price = cost_basis / total_quantity
            existing.quantity = total_quantity
            # Mark to the latest execution price, as the new-position branch does;
            # otherwise the position keeps a stale mark until update_prices().
            existing.current_price = price

        self.cash -= total_cost
        self.trades.append(Trade(
            ticker=ticker,
            trade_date=trade_date,
            action='buy',
            quantity=quantity,
            price=price,
            commission=commission,
        ))
        return True

    def sell(self, ticker: str, quantity: Decimal, price: Decimal, trade_date: datetime) -> bool:
        """Sell ``quantity`` units of ``ticker`` at ``price``.

        Args:
            ticker: Instrument identifier.
            quantity: Units to sell; must be positive and not exceed the holding.
            price: Execution price per unit; must not be negative.
            trade_date: Execution date recorded on the trade.

        Returns:
            True if the trade executed. False if the ticker is not held,
            ``quantity`` is not positive or exceeds the holding, or ``price``
            is negative. On False the portfolio is left unchanged.
        """
        # A non-positive quantity would otherwise grow the position or log a
        # phantom trade.
        if quantity <= 0 or price < 0:
            return False

        position = self.positions.get(ticker)
        if position is None or position.quantity < quantity:
            return False

        gross = quantity * price
        commission = gross * self.commission_rate
        proceeds = gross - commission

        position.quantity -= quantity
        if position.quantity == 0:
            # Fully closed positions are removed rather than kept at zero.
            del self.positions[ticker]
        else:
            # Remaining units are marked at the latest execution price.
            position.current_price = price

        self.cash += proceeds
        self.trades.append(Trade(
            ticker=ticker,
            trade_date=trade_date,
            action='sell',
            quantity=quantity,
            price=price,
            commission=commission,
        ))
        return True

    def update_prices(self, prices: Dict[str, Decimal]) -> None:
        """Set ``current_price`` on held positions from a price map.

        Args:
            prices: Mapping of ticker to latest price. Tickers that are not
                held are ignored; held tickers missing from the map keep their
                previous price.
        """
        for ticker, position in self.positions.items():
            if ticker in prices:
                position.current_price = prices[ticker]

    def get_total_value(self) -> Decimal:
        """Return cash plus the market value of all positions."""
        return self.cash + self.get_positions_value()

    def get_positions_value(self) -> Decimal:
        """Return the combined market value of all positions.

        Always returns a ``Decimal``. The start value keeps an empty portfolio
        from yielding the int ``0``.
        """
        return sum((pos.market_value for pos in self.positions.values()), Decimal("0"))

    def take_snapshot(self, date: datetime) -> PortfolioSnapshot:
        """Record and return the portfolio's current state.

        Each position is copied, so later trades or price updates do not
        alter historical snapshots.

        Args:
            date: Timestamp to attach to the snapshot.

        Returns:
            The snapshot, which is also appended to ``self.snapshots``.
        """
        from dataclasses import replace

        positions_value = self.get_positions_value()
        snapshot = PortfolioSnapshot(
            date=date,
            cash=self.cash,
            positions_value=positions_value,
            total_value=self.cash + positions_value,
            # A shallow dict copy would share the mutable Position objects.
            positions={ticker: replace(pos) for ticker, pos in self.positions.items()},
        )
        self.snapshots.append(snapshot)
        return snapshot
|