penti/backend/app/backtest/portfolio.py

223 lines
6.0 KiB
Python
Raw Normal View History

2026-01-31 23:30:51 +09:00
"""Portfolio management for backtesting."""
from dataclasses import dataclass, field, replace
from datetime import datetime
from decimal import Decimal
from typing import Dict, List
@dataclass
class Position:
    """An open holding in one ticker, valued at its latest mark price."""

    ticker: str
    quantity: Decimal
    avg_price: Decimal
    # Mark price used for valuation; defaults to Decimal("0") when not given.
    current_price: Decimal = Decimal("0")

    @property
    def market_value(self) -> Decimal:
        """Quantity held multiplied by the current mark price."""
        qty, mark = self.quantity, self.current_price
        return qty * mark

    @property
    def pnl(self) -> Decimal:
        """Unrealized profit/loss: per-unit price change times quantity."""
        per_unit_change = self.current_price - self.avg_price
        return per_unit_change * self.quantity

    @property
    def pnl_pct(self) -> Decimal:
        """Unrealized return in percent, or 0 when the average price is 0.

        The zero-cost guard avoids a Decimal division by zero.
        """
        cost = self.avg_price
        if cost != 0:
            return (self.current_price - cost) / cost * Decimal("100")
        return Decimal("0")
@dataclass
class Trade:
    """A single executed fill recorded by the portfolio."""

    ticker: str
    trade_date: datetime
    action: str  # 'buy' or 'sell'
    quantity: Decimal
    price: Decimal
    commission: Decimal = Decimal("0")

    @property
    def total_amount(self) -> Decimal:
        """Gross notional adjusted by the commission.

        For action 'buy' the commission is added (cash paid out). For any
        other action it is subtracted (net cash received).
        """
        notional = self.quantity * self.price
        if self.action != 'buy':
            return notional - self.commission
        return notional + self.commission
@dataclass
class PortfolioSnapshot:
    """Point-in-time record of a portfolio's cash and holdings."""

    date: datetime  # moment the snapshot represents
    cash: Decimal  # cash balance at that moment
    positions_value: Decimal  # summed market value of open positions
    total_value: Decimal  # cash + positions_value, as filled in by take_snapshot
    # Holdings keyed by ticker; default_factory gives each snapshot its own dict.
    positions: Dict[str, Position] = field(default_factory=dict)
class BacktestPortfolio:
    """Cash-and-positions ledger used during a backtest.

    Tracks the cash balance, open positions keyed by ticker, every executed
    trade, and point-in-time snapshots. Every buy and sell is charged a
    commission of ``quantity * price * commission_rate``.
    """

    def __init__(self, initial_capital: Decimal, commission_rate: Decimal = Decimal("0.0015")):
        """
        Initialize an all-cash portfolio.

        Args:
            initial_capital: Starting cash balance.
            commission_rate: Commission as a fraction of trade notional
                (default 0.0015, i.e. 0.15%).
        """
        self.initial_capital = initial_capital
        self.cash = initial_capital
        self.commission_rate = commission_rate
        self.positions: Dict[str, Position] = {}
        self.trades: List[Trade] = []
        self.snapshots: List[PortfolioSnapshot] = []

    def buy(self, ticker: str, quantity: Decimal, price: Decimal, trade_date: datetime) -> bool:
        """
        Buy ``quantity`` units of ``ticker`` at ``price``.

        The order is rejected, leaving the state unchanged, in three cases:
        the quantity is not positive, the price is negative, or cash cannot
        cover notional plus commission. On success, an existing position's
        average price becomes the quantity-weighted average of the old
        average and this fill price. Commission is not folded into the
        average price.

        Args:
            ticker: Ticker symbol.
            quantity: Units to buy; must be > 0.
            price: Fill price per unit; must be >= 0.
            trade_date: Date of the trade.

        Returns:
            True if the buy was executed, False if it was rejected.
        """
        # A zero quantity would create an empty position, which can later
        # cause a 0/0 average-price division. A negative quantity or price
        # would *add* cash without holding anything.
        if quantity <= 0 or price < 0:
            return False

        commission = quantity * price * self.commission_rate
        total_cost = quantity * price + commission
        if total_cost > self.cash:
            return False

        existing = self.positions.get(ticker)
        if existing is not None:
            # Quantity-weighted average of the prior cost basis and this fill.
            total_quantity = existing.quantity + quantity
            cost_basis = (existing.avg_price * existing.quantity) + (price * quantity)
            existing.avg_price = cost_basis / total_quantity
            existing.quantity = total_quantity
        else:
            self.positions[ticker] = Position(
                ticker=ticker,
                quantity=quantity,
                avg_price=price,
                current_price=price,
            )

        self.cash -= total_cost
        self.trades.append(
            Trade(
                ticker=ticker,
                trade_date=trade_date,
                action='buy',
                quantity=quantity,
                price=price,
                commission=commission,
            )
        )
        return True

    def sell(self, ticker: str, quantity: Decimal, price: Decimal, trade_date: datetime) -> bool:
        """
        Sell ``quantity`` units of ``ticker`` at ``price``.

        The order is rejected, leaving the state unchanged, in these cases:
        the quantity is not positive, the price is negative, the ticker is
        not held, or the holding is smaller than ``quantity``. Short selling
        is therefore not possible. A position sold down to exactly zero is
        removed.

        Args:
            ticker: Ticker symbol.
            quantity: Units to sell; must be > 0.
            price: Fill price per unit; must be >= 0.
            trade_date: Date of the trade.

        Returns:
            True if the sell was executed, False if it was rejected.
        """
        # A negative quantity would grow the position and drain cash.
        if quantity <= 0 or price < 0:
            return False

        position = self.positions.get(ticker)
        if position is None or position.quantity < quantity:
            return False

        commission = quantity * price * self.commission_rate
        total_proceeds = quantity * price - commission

        position.quantity -= quantity
        if position.quantity == 0:
            del self.positions[ticker]

        self.cash += total_proceeds
        self.trades.append(
            Trade(
                ticker=ticker,
                trade_date=trade_date,
                action='sell',
                quantity=quantity,
                price=price,
                commission=commission,
            )
        )
        return True

    def update_prices(self, prices: Dict[str, Decimal]) -> None:
        """
        Mark open positions to the supplied prices.

        Tickers in ``prices`` that are not held are ignored. Held tickers
        missing from ``prices`` keep their previous mark.

        Args:
            prices: Mapping of ticker to latest price.
        """
        for ticker, position in self.positions.items():
            if ticker in prices:
                position.current_price = prices[ticker]

    def get_total_value(self) -> Decimal:
        """Return cash plus the market value of all open positions."""
        return self.cash + self.get_positions_value()

    def get_positions_value(self) -> Decimal:
        """Return the summed market value of all open positions."""
        # Start from Decimal("0") so an empty portfolio yields a Decimal,
        # not the int 0 that a bare sum() of an empty iterable returns.
        return sum((pos.market_value for pos in self.positions.values()), Decimal("0"))

    def take_snapshot(self, date: datetime) -> PortfolioSnapshot:
        """
        Record and return the portfolio's state as of ``date``.

        Each position is copied, so later trades and price updates do not
        alter snapshots that have already been taken.

        Args:
            date: Date the snapshot represents.

        Returns:
            The snapshot, which is also appended to ``self.snapshots``.
        """
        positions_value = self.get_positions_value()
        snapshot = PortfolioSnapshot(
            date=date,
            cash=self.cash,
            positions_value=positions_value,
            total_value=self.cash + positions_value,
            # dict.copy() alone would share the mutable Position objects, so
            # a later sell/buy/update_prices would silently rewrite history.
            positions={ticker: replace(pos) for ticker, pos in self.positions.items()},
        )
        self.snapshots.append(snapshot)
        return snapshot