"""Portfolio management for backtesting."""

from dataclasses import dataclass, field, replace
from datetime import datetime
from decimal import Decimal
from typing import Dict, List


@dataclass
class Position:
    """An open position in a single ticker."""

    ticker: str
    quantity: Decimal
    avg_price: Decimal
    current_price: Decimal = Decimal("0")

    @property
    def market_value(self) -> Decimal:
        """Current market value."""
        return self.quantity * self.current_price

    @property
    def pnl(self) -> Decimal:
        """Unrealized profit and loss."""
        return (self.current_price - self.avg_price) * self.quantity

    @property
    def pnl_pct(self) -> Decimal:
        """Unrealized return in percent."""
        if self.avg_price == 0:
            return Decimal("0")
        return (self.current_price - self.avg_price) / self.avg_price * Decimal("100")


@dataclass
class Trade:
    """A single executed trade."""

    ticker: str
    trade_date: datetime
    action: str  # 'buy' or 'sell'
    quantity: Decimal
    price: Decimal
    commission: Decimal = Decimal("0")

    @property
    def total_amount(self) -> Decimal:
        """
        Net cash amount including commission.

        For a buy this is the cash paid (gross amount plus commission); for a
        sell it is the cash received (gross amount minus commission).
        """
        amount = self.quantity * self.price
        if self.action == 'buy':
            return amount + self.commission
        return amount - self.commission


@dataclass
class PortfolioSnapshot:
    """Point-in-time snapshot of the portfolio."""

    date: datetime
    cash: Decimal
    positions_value: Decimal
    total_value: Decimal
    positions: Dict[str, Position] = field(default_factory=dict)


class BacktestPortfolio:
    """Portfolio manager for backtesting."""

    def __init__(self, initial_capital: Decimal, commission_rate: Decimal = Decimal("0.0015")):
        """
        Initialize the portfolio.

        Args:
            initial_capital: Starting cash.
            commission_rate: Commission rate applied to each trade's gross
                amount (default 0.15%).
        """
        self.initial_capital = initial_capital
        self.cash = initial_capital
        self.commission_rate = commission_rate
        self.positions: Dict[str, Position] = {}
        self.trades: List[Trade] = []
        self.snapshots: List[PortfolioSnapshot] = []

    def buy(self, ticker: str, quantity: Decimal, price: Decimal, trade_date: datetime) -> bool:
        """
        Buy shares.

        Args:
            ticker: Ticker symbol.
            quantity: Number of shares to buy.
            price: Execution price per share.
            trade_date: Trade date.

        Returns:
            True if the buy was executed; False if the quantity or price is not
            positive, or if cash is insufficient to cover cost plus commission.
        """
        # Reject non-positive inputs; a negative quantity would otherwise add cash.
        if quantity <= 0 or price <= 0:
            return False

        commission = quantity * price * self.commission_rate
        total_cost = quantity * price + commission

        if total_cost > self.cash:
            return False

        # Update the position. The average price is the quantity-weighted
        # average of fill prices and does not include commission.
        if ticker in self.positions:
            existing = self.positions[ticker]
            total_quantity = existing.quantity + quantity
            total_cost_basis = (existing.avg_price * existing.quantity) + (price * quantity)
            existing.quantity = total_quantity
            existing.avg_price = total_cost_basis / total_quantity
        else:
            self.positions[ticker] = Position(
                ticker=ticker,
                quantity=quantity,
                avg_price=price,
                current_price=price,
            )

        # Deduct the gross amount plus commission from cash.
        self.cash -= total_cost

        # Record the trade.
        trade = Trade(
            ticker=ticker,
            trade_date=trade_date,
            action='buy',
            quantity=quantity,
            price=price,
            commission=commission,
        )
        self.trades.append(trade)

        return True

    def sell(self, ticker: str, quantity: Decimal, price: Decimal, trade_date: datetime) -> bool:
        """
        Sell shares.

        Args:
            ticker: Ticker symbol.
            quantity: Number of shares to sell.
            price: Execution price per share.
            trade_date: Trade date.

        Returns:
            True if the sell was executed; False if the quantity or price is not
            positive, no position is held, or the position is smaller than the
            requested quantity.
        """
        # Reject non-positive inputs; a negative quantity would otherwise grow
        # the position while removing cash.
        if quantity <= 0 or price <= 0:
            return False

        position = self.positions.get(ticker)
        if position is None or position.quantity < quantity:
            return False

        commission = quantity * price * self.commission_rate
        total_proceeds = quantity * price - commission

        # Reduce the position and drop it once fully closed.
        position.quantity -= quantity
        if position.quantity == 0:
            del self.positions[ticker]

        # Credit the net proceeds (gross amount minus commission) to cash.
        self.cash += total_proceeds

        # Record the trade.
        trade = Trade(
            ticker=ticker,
            trade_date=trade_date,
            action='sell',
            quantity=quantity,
            price=price,
            commission=commission,
        )
        self.trades.append(trade)

        return True

    def update_prices(self, prices: Dict[str, Decimal]) -> None:
        """
        Update the current price of each held position.

        Args:
            prices: Mapping of ticker to latest price. Tickers that are not held
                are ignored; held tickers missing from the mapping keep their
                previous price.
        """
        for ticker, position in self.positions.items():
            if ticker in prices:
                position.current_price = prices[ticker]

    def get_total_value(self) -> Decimal:
        """Total portfolio value (cash plus market value of positions)."""
        return self.cash + self.get_positions_value()

    def get_positions_value(self) -> Decimal:
        """Total market value of all positions."""
        # Start from Decimal("0") so an empty portfolio returns a Decimal, not int 0.
        return sum((pos.market_value for pos in self.positions.values()), Decimal("0"))

    def take_snapshot(self, date: datetime) -> PortfolioSnapshot:
        """
        Record a snapshot of the portfolio.

        Args:
            date: Snapshot date.

        Returns:
            The recorded portfolio snapshot.
        """
        positions_value = self.get_positions_value()
        total_value = self.cash + positions_value

        # Copy each Position so later trades or price updates do not mutate the
        # snapshot; a plain dict.copy() would share the same Position objects.
        positions = {ticker: replace(pos) for ticker, pos in self.positions.items()}

        snapshot = PortfolioSnapshot(
            date=date,
            cash=self.cash,
            positions_value=positions_value,
            total_value=total_value,
            positions=positions,
        )
        self.snapshots.append(snapshot)

        return snapshot
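

# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical helpers, not used by BacktestPortfolio):
# the cash arithmetic that buy() and sell() perform, pulled out as pure
# functions so it can be checked in isolation. The function names are
# assumptions introduced for illustration only.
#
# Example: 10 shares held at 100 plus 10 more bought at 120 gives a blended
# average of (100*10 + 120*10) / 20 = 110. A round trip of 10 shares at 1000
# with a 0.15% rate costs 10015 to buy and returns 9985 on sale, so the
# round trip loses two commissions (30) even with no price change.
# ---------------------------------------------------------------------------
from decimal import Decimal
from typing import Tuple


def blended_avg_price(
    held_qty: Decimal, held_avg: Decimal, add_qty: Decimal, add_price: Decimal
) -> Decimal:
    """Quantity-weighted average price after adding to a position (commission excluded)."""
    return (held_avg * held_qty + add_price * add_qty) / (held_qty + add_qty)


def buy_cost_and_sell_proceeds(
    quantity: Decimal, price: Decimal, commission_rate: Decimal
) -> Tuple[Decimal, Decimal]:
    """Return (cash paid on a buy, cash received on a sell) for the same fill."""
    gross = quantity * price
    commission = gross * commission_rate
    return gross + commission, gross - commission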
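

# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical helpers, not part of the original module):
# summary metrics over a sequence of snapshot total values, for example
# [s.total_value for s in portfolio.snapshots]. The return is expressed in
# percent to match Position.pnl_pct, and drawdown is measured from the
# running peak. Names and conventions here are assumptions.
# ---------------------------------------------------------------------------
from decimal import Decimal
from typing import Optional, Sequence


def total_return_pct(initial_capital: Decimal, final_value: Decimal) -> Decimal:
    """Total return in percent relative to the initial capital."""
    if initial_capital == 0:
        return Decimal("0")
    return (final_value - initial_capital) / initial_capital * Decimal("100")


def max_drawdown_pct(values: Sequence[Decimal]) -> Decimal:
    """Largest peak-to-trough decline in percent (0 if the values never fall)."""
    peak: Optional[Decimal] = None
    worst = Decimal("0")
    for value in values:
        # Track the highest value seen so far.
        if peak is None or value > peak:
            peak = value
        # Measure the decline from that peak; skip non-positive peaks.
        if peak > 0:
            drawdown = (peak - value) / peak * Decimal("100")
            if drawdown > worst:
                worst = drawdown
    return worst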