"""Rebalancing calculation service."""
from decimal import Decimal
from typing import Dict, List, Optional

from sqlalchemy import func
from sqlalchemy.orm import Session

from app.models.portfolio import Portfolio
from app.models.stock import Stock, ETF, ETFPrice
from app.schemas.portfolio import (
    RebalanceItem,
    RebalanceResponse,
    RebalanceSimulationResponse,
)

# Sort priority for rebalance actions: buys first, then sells, then holds.
_ACTION_ORDER = {"buy": 0, "sell": 1, "hold": 2}

_ZERO = Decimal("0")
_CENT = Decimal("0.01")


class RebalanceService:
    """Service for calculating portfolio rebalancing."""

    def __init__(self, db: Session):
        self.db = db

    def get_current_prices(self, tickers: List[str]) -> Dict[str, Decimal]:
        """Look up the most recent known price for each ticker.

        Stock prices come from ``Stock.close_price``; tickers without a truthy
        stock price fall back to the newest ``ETFPrice`` row for that ticker.

        Args:
            tickers: Tickers to price.

        Returns:
            Mapping of ticker -> price. Tickers with no price are omitted.
        """
        prices: Dict[str, Decimal] = {}
        if not tickers:
            return prices

        stocks = self.db.query(Stock).filter(Stock.ticker.in_(tickers)).all()
        for stock in stocks:
            # Skips None and 0 so such tickers can still be priced as ETFs.
            if stock.close_price:
                prices[stock.ticker] = Decimal(str(stock.close_price))

        missing = [t for t in tickers if t not in prices]
        if missing:
            # Latest price date per ticker, then join back to fetch that row.
            latest = (
                self.db.query(
                    ETFPrice.ticker,
                    func.max(ETFPrice.date).label("max_date"),
                )
                .filter(ETFPrice.ticker.in_(missing))
                .group_by(ETFPrice.ticker)
                .subquery()
            )
            etf_prices = (
                self.db.query(ETFPrice)
                .join(
                    latest,
                    (ETFPrice.ticker == latest.c.ticker)
                    & (ETFPrice.date == latest.c.max_date),
                )
                .all()
            )
            for ep in etf_prices:
                # Same guard as the stock branch: Decimal(str(None)) raises.
                if ep.close is not None:
                    prices[ep.ticker] = Decimal(str(ep.close))

        return prices

    def get_stock_names(self, tickers: List[str]) -> Dict[str, str]:
        """Look up display names for tickers.

        ``Stock`` rows are checked first; tickers not found there fall back
        to ``ETF`` rows.

        Args:
            tickers: Tickers to name.

        Returns:
            Mapping of ticker -> name. Unknown tickers are omitted.
        """
        names: Dict[str, str] = {}
        if not tickers:
            return names

        stocks = self.db.query(Stock).filter(Stock.ticker.in_(tickers)).all()
        for stock in stocks:
            names[stock.ticker] = stock.name

        missing = [t for t in tickers if t not in names]
        if missing:
            etfs = self.db.query(ETF).filter(ETF.ticker.in_(missing)).all()
            for etf in etfs:
                names[etf.ticker] = etf.name

        return names

    def calculate_rebalance(
        self,
        portfolio: Portfolio,
        additional_amount: Optional[Decimal] = None,
    ) -> RebalanceResponse | RebalanceSimulationResponse:
        """Calculate the trades that move a portfolio toward its targets.

        Args:
            portfolio: Portfolio whose ``targets`` (ticker, target_ratio as a
                percentage) and ``holdings`` (ticker, quantity) are read.
            additional_amount: Extra cash added to the total the targets are
                applied to. When truthy, a simulation response is returned
                (note: an explicit ``Decimal("0")`` is treated like ``None``).

        Returns:
            ``RebalanceSimulationResponse`` if ``additional_amount`` is truthy,
            otherwise ``RebalanceResponse``. Items are ordered buy, sell, hold,
            then by descending ``|diff_quantity|``, then by ticker.
        """
        targets = {
            t.ticker: Decimal(str(t.target_ratio)) for t in portfolio.targets
        }
        # NOTE(review): assumes h.quantity is int/Decimal; a float would
        # raise TypeError when multiplied by a Decimal price.
        holdings = {h.ticker: h.quantity for h in portfolio.holdings}

        # Sorted so the final (stable) sort breaks ties deterministically;
        # iterating a raw set would order items by string hash.
        all_tickers = sorted(set(targets) | set(holdings))
        current_prices = self.get_current_prices(all_tickers)
        stock_names = self.get_stock_names(all_tickers)

        # Holdings with no known price are valued at 0.
        current_values = {
            ticker: current_prices.get(ticker, _ZERO) * quantity
            for ticker, quantity in holdings.items()
        }
        # Decimal start value keeps the total a Decimal even with no holdings.
        current_total = sum(current_values.values(), _ZERO)
        if additional_amount:
            new_total = current_total + additional_amount
        else:
            new_total = current_total

        items = []
        for ticker in all_tickers:
            target_ratio = targets.get(ticker, _ZERO)
            current_value = current_values.get(ticker, _ZERO)
            current_quantity = holdings.get(ticker, 0)
            current_price = current_prices.get(ticker, _ZERO)

            # Weight within the portfolio as it stands today (excluding any
            # additional cash), so priced holdings' ratios sum to ~100.
            if current_total > 0:
                current_ratio = (current_value / current_total * 100).quantize(_CENT)
            else:
                current_ratio = _ZERO

            target_value = (new_total * target_ratio / 100).quantize(_CENT)
            diff_value = target_value - current_value

            # int() truncates toward zero: buys never exceed the value gap and
            # sells never overshoot it. An unpriced ticker yields 0 ("hold").
            if current_price > 0:
                diff_quantity = int(diff_value / current_price)
            else:
                diff_quantity = 0

            if diff_quantity > 0:
                action = "buy"
            elif diff_quantity < 0:
                action = "sell"
            else:
                action = "hold"

            items.append(RebalanceItem(
                ticker=ticker,
                name=stock_names.get(ticker),
                target_ratio=target_ratio,
                current_ratio=current_ratio,
                current_quantity=current_quantity,
                current_value=current_value,
                target_value=target_value,
                diff_value=diff_value,
                diff_quantity=diff_quantity,
                action=action,
            ))

        items.sort(
            key=lambda item: (
                _ACTION_ORDER.get(item.action, 3),
                -abs(item.diff_quantity),
            )
        )

        if additional_amount:
            return RebalanceSimulationResponse(
                portfolio_id=portfolio.id,
                current_total=current_total,
                additional_amount=additional_amount,
                new_total=new_total,
                items=items,
            )
        return RebalanceResponse(
            portfolio_id=portfolio.id,
            total_value=current_total,
            items=items,
        )