"""Rebalancing calculation service."""
from decimal import Decimal
from typing import Dict, List, Optional

from sqlalchemy import func
from sqlalchemy.orm import Session

from app.models.portfolio import Portfolio, PortfolioSnapshot
from app.models.stock import ETF, ETFPrice, Stock
from app.schemas.portfolio import (
    RebalanceCalculateItem,
    RebalanceCalculateResponse,
    RebalanceItem,
    RebalanceResponse,
    RebalanceSimulationResponse,
)

_ZERO = Decimal("0")
_CENT = Decimal("0.01")
# Listing order for trade items: buys first, then sells, then holds.
_ACTION_ORDER = {"buy": 0, "sell": 1, "hold": 2}


class RebalanceService:
    """Service for calculating portfolio rebalancing."""

    def __init__(self, db: Session):
        self.db = db

    # ------------------------------------------------------------------
    # Data lookups
    # ------------------------------------------------------------------

    def get_current_prices(self, tickers: List[str]) -> Dict[str, Decimal]:
        """Return the latest known price for each ticker.

        A truthy ``Stock.close_price`` is used when present; remaining tickers
        fall back to the ``close`` of their most recent ``ETFPrice`` row.
        Tickers with no price in either table are absent from the result.
        """
        if not tickers:
            return {}

        prices: Dict[str, Decimal] = {}
        stocks = self.db.query(Stock).filter(Stock.ticker.in_(tickers)).all()
        for stock in stocks:
            if stock.close_price:
                prices[stock.ticker] = Decimal(str(stock.close_price))

        missing = [t for t in tickers if t not in prices]
        if missing:
            # Latest date per ticker, joined back to fetch that day's row.
            latest = (
                self.db.query(
                    ETFPrice.ticker,
                    func.max(ETFPrice.date).label("max_date"),
                )
                .filter(ETFPrice.ticker.in_(missing))
                .group_by(ETFPrice.ticker)
                .subquery()
            )
            etf_prices = (
                self.db.query(ETFPrice)
                .join(
                    latest,
                    (ETFPrice.ticker == latest.c.ticker)
                    & (ETFPrice.date == latest.c.max_date),
                )
                .all()
            )
            for ep in etf_prices:
                if ep.close is not None:
                    prices[ep.ticker] = Decimal(str(ep.close))
        return prices

    def get_stock_names(self, tickers: List[str]) -> Dict[str, str]:
        """Return display names for tickers, checking stocks first, then ETFs.

        Tickers found in neither table are absent from the result.
        """
        if not tickers:
            return {}

        stocks = self.db.query(Stock).filter(Stock.ticker.in_(tickers)).all()
        names: Dict[str, str] = {stock.ticker: stock.name for stock in stocks}

        missing = [t for t in tickers if t not in names]
        if missing:
            for etf in self.db.query(ETF).filter(ETF.ticker.in_(missing)).all():
                names[etf.ticker] = etf.name
        return names

    def _get_snapshot_prices(
        self, portfolio_id: int, tickers: List[str], *, latest: bool
    ) -> Dict[str, Decimal]:
        """Return per-ticker prices recorded in one of the portfolio's snapshots.

        Args:
            portfolio_id: Portfolio whose snapshots are searched.
            tickers: Only these tickers are included in the result.
            latest: ``True`` for the most recent snapshot (by
                ``snapshot_date``), ``False`` for the earliest.

        Returns:
            ``{ticker: price}``; empty when the portfolio has no snapshot.
        """
        date_col = PortfolioSnapshot.snapshot_date
        snapshot = (
            self.db.query(PortfolioSnapshot)
            .filter(PortfolioSnapshot.portfolio_id == portfolio_id)
            .order_by(date_col.desc() if latest else date_col.asc())
            .first()
        )
        if snapshot is None:
            return {}

        wanted = set(tickers)
        return {
            sh.ticker: Decimal(str(sh.price))
            for sh in snapshot.holdings
            if sh.ticker in wanted and sh.price is not None
        }

    def _get_prev_month_prices(self, portfolio_id: int, tickers: List[str]) -> Dict[str, Decimal]:
        """Get prices from the most recent snapshot for change calculation."""
        return self._get_snapshot_prices(portfolio_id, tickers, latest=True)

    def _get_start_prices(self, portfolio_id: int, tickers: List[str]) -> Dict[str, Decimal]:
        """Get prices from the earliest snapshot for change calculation."""
        return self._get_snapshot_prices(portfolio_id, tickers, latest=False)

    def _resolve_prices(
        self, tickers: List[str], manual_prices: Optional[Dict[str, Decimal]]
    ) -> Dict[str, Decimal]:
        """Prefer user-supplied prices, falling back to the DB per ticker.

        A ticker missing from ``manual_prices`` (or mapped to ``None``) is
        priced from the database rather than treated as 0, which would
        silently zero out that holding's value.
        """
        if not manual_prices:
            return self.get_current_prices(tickers)

        prices = {
            t: manual_prices[t]
            for t in tickers
            if manual_prices.get(t) is not None
        }
        missing = [t for t in tickers if t not in prices]
        if missing:
            prices.update(self.get_current_prices(missing))
        return prices

    # ------------------------------------------------------------------
    # Small shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _target_map(portfolio) -> Dict[str, Decimal]:
        """Map ticker -> target ratio (percent) from ``portfolio.targets``."""
        return {t.ticker: Decimal(str(t.target_ratio)) for t in portfolio.targets}

    @staticmethod
    def _holdings_map(portfolio) -> Dict[str, tuple]:
        """Map ticker -> ``(quantity, avg_price)`` from ``portfolio.holdings``."""
        return {
            h.ticker: (h.quantity, Decimal(str(h.avg_price)))
            for h in portfolio.holdings
        }

    @staticmethod
    def _sorted_tickers(*groups) -> List[str]:
        """Sorted union of the tickers in ``groups`` (dicts or iterables).

        Sorting, rather than ``list(set(...))``, makes item order and
        tie-breaking identical across processes; string hashing (and thus
        set iteration order) is randomized per process.
        """
        return sorted(set().union(*groups))

    @staticmethod
    def _quantity(holdings: Dict[str, tuple], ticker: str):
        """Held quantity of ``ticker`` (0 when not held)."""
        return holdings.get(ticker, (0, _ZERO))[0]

    @staticmethod
    def _pct(part: Decimal, total: Decimal) -> Decimal:
        """``part`` as a percentage of ``total`` (2 dp); 0 when total <= 0."""
        if total > 0:
            return (part / total * 100).quantize(_CENT)
        return _ZERO

    @staticmethod
    def _whole_shares(value: Decimal, price: Decimal) -> int:
        """Whole shares worth ``value`` at ``price``, truncated toward zero.

        Truncation never trades more than the value difference; a
        non-positive (unknown) price yields 0.
        """
        return int(value / price) if price > 0 else 0

    @staticmethod
    def _action_for(quantity: int) -> str:
        """'buy' / 'sell' / 'hold' for a signed share delta."""
        if quantity > 0:
            return "buy"
        if quantity < 0:
            return "sell"
        return "hold"

    @staticmethod
    def _trade_sort_key(item):
        """Buys, then sells, then holds; larger trades first; then ticker."""
        return (_ACTION_ORDER.get(item.action, 3), -abs(item.diff_quantity), item.ticker)

    @staticmethod
    def _buy_sort_key(item):
        """Largest purchase first, then ticker."""
        return (-item.diff_quantity, item.ticker)

    # ------------------------------------------------------------------
    # Public calculations
    # ------------------------------------------------------------------

    def calculate_rebalance(
        self,
        portfolio: Portfolio,
        additional_amount: Optional[Decimal] = None,
    ) -> RebalanceResponse | RebalanceSimulationResponse:
        """Calculate rebalancing for a portfolio using DB prices.

        Args:
            portfolio: Portfolio whose ``targets`` and ``holdings`` are used.
            additional_amount: Optional cash to add. When truthy, targets are
                computed against ``current_total + additional_amount`` and a
                ``RebalanceSimulationResponse`` is returned.

        Returns:
            ``RebalanceSimulationResponse`` when ``additional_amount`` is
            truthy, otherwise ``RebalanceResponse``. Items are ordered buys,
            sells, holds; larger share deltas first; then by ticker.
        """
        targets = self._target_map(portfolio)
        holdings = self._holdings_map(portfolio)
        all_tickers = self._sorted_tickers(targets, holdings)

        current_prices = self.get_current_prices(all_tickers)
        stock_names = self.get_stock_names(all_tickers)

        current_values = {
            ticker: current_prices.get(ticker, _ZERO) * quantity
            for ticker, (quantity, _) in holdings.items()
        }
        current_total = sum(current_values.values(), _ZERO)
        new_total = current_total + additional_amount if additional_amount else current_total

        items = []
        for ticker in all_tickers:
            target_ratio = targets.get(ticker, _ZERO)
            current_value = current_values.get(ticker, _ZERO)
            current_price = current_prices.get(ticker, _ZERO)

            target_value = (new_total * target_ratio / 100).quantize(_CENT)
            diff_value = target_value - current_value
            diff_quantity = self._whole_shares(diff_value, current_price)

            items.append(RebalanceItem(
                ticker=ticker,
                name=stock_names.get(ticker),
                target_ratio=target_ratio,
                # Measured against new_total: in a simulation this is the
                # holding's share of the post-deposit portfolio before trades.
                current_ratio=self._pct(current_value, new_total),
                current_quantity=self._quantity(holdings, ticker),
                current_value=current_value,
                target_value=target_value,
                diff_value=diff_value,
                diff_quantity=diff_quantity,
                action=self._action_for(diff_quantity),
            ))
        items.sort(key=self._trade_sort_key)

        if additional_amount:
            return RebalanceSimulationResponse(
                portfolio_id=portfolio.id,
                current_total=current_total,
                additional_amount=additional_amount,
                new_total=new_total,
                items=items,
            )
        return RebalanceResponse(
            portfolio_id=portfolio.id,
            total_value=current_total,
            items=items,
        )

    def calculate_with_prices(
        self,
        portfolio: "Portfolio",
        strategy: str,
        manual_prices: Optional[Dict[str, Decimal]] = None,
        additional_amount: Optional[Decimal] = None,
        min_trade_amount: Optional[Decimal] = None,
    ):
        """Calculate rebalance with optional manual prices and strategy selection.

        Args:
            portfolio: Portfolio whose ``targets`` and ``holdings`` are used.
            strategy: ``"full_rebalance"`` (buy and sell toward targets);
                any other value runs the additional-buy strategy.
            manual_prices: Optional per-ticker price overrides. Tickers not
                covered (or mapped to ``None``) use the DB price.
            additional_amount: Cash available for the additional-buy strategy.
            min_trade_amount: When positive, trades whose value
                (``|diff_quantity| * current_price``) is below it become "hold".

        Returns:
            ``RebalanceCalculateResponse``; ``available_to_buy`` is set to
            ``additional_amount`` for the additional-buy strategy.
        """
        targets = self._target_map(portfolio)
        holdings = self._holdings_map(portfolio)
        all_tickers = self._sorted_tickers(targets, holdings)

        current_prices = self._resolve_prices(all_tickers, manual_prices)
        stock_names = self.get_stock_names(all_tickers)

        current_values = {
            ticker: current_prices.get(ticker, _ZERO) * self._quantity(holdings, ticker)
            for ticker in all_tickers
        }
        total_assets = sum(current_values.values(), _ZERO)

        # Reference prices for the change-percentage columns.
        prev_prices = self._get_prev_month_prices(portfolio.id, all_tickers)
        start_prices = self._get_start_prices(portfolio.id, all_tickers)

        full = strategy == "full_rebalance"
        if full:
            items = self._calc_full_rebalance(
                all_tickers, targets, holdings, current_prices, current_values,
                total_assets, stock_names, prev_prices, start_prices,
            )
        else:  # additional_buy
            items = self._calc_additional_buy(
                all_tickers, targets, holdings, current_prices, current_values,
                total_assets, additional_amount, stock_names, prev_prices, start_prices,
            )

        if min_trade_amount and min_trade_amount > 0:
            for item in items:
                if item.action == "hold":
                    continue
                if abs(item.diff_quantity) * item.current_price < min_trade_amount:
                    item.diff_quantity = 0
                    item.action = "hold"
            # Re-sort so demoted items do not stay interleaved with real trades.
            # NOTE: cash freed by dropped additional-buy trades is not reallocated.
            items.sort(key=self._trade_sort_key if full else self._buy_sort_key)

        if full:
            return RebalanceCalculateResponse(
                portfolio_id=portfolio.id,
                total_assets=total_assets,
                items=items,
            )
        return RebalanceCalculateResponse(
            portfolio_id=portfolio.id,
            total_assets=total_assets,
            available_to_buy=additional_amount,
            items=items,
        )

    def _calc_change_pct(
        self, current_price: Decimal, ref_price: Optional[Decimal]
    ) -> Optional[Decimal]:
        """Percent change from ``ref_price`` to ``current_price`` (2 dp).

        Returns ``None`` when there is no positive reference price.
        """
        if ref_price and ref_price > 0:
            return ((current_price - ref_price) / ref_price * 100).quantize(_CENT)
        return None

    def _calc_full_rebalance(
        self, all_tickers, targets, holdings, current_prices, current_values,
        total_assets, stock_names, prev_prices, start_prices,
    ):
        """Build buy/sell/hold items that move every ticker to its target value.

        Target values are ``total_assets * target_ratio / 100``; share deltas
        are truncated toward zero. Items are sorted buys, sells, holds.
        """
        items = []
        for ticker in all_tickers:
            target_ratio = targets.get(ticker, _ZERO)
            current_value = current_values.get(ticker, _ZERO)
            current_price = current_prices.get(ticker, _ZERO)

            current_ratio = self._pct(current_value, total_assets)
            target_value = (total_assets * target_ratio / 100).quantize(_CENT)
            diff_quantity = self._whole_shares(target_value - current_value, current_price)

            items.append(RebalanceCalculateItem(
                ticker=ticker,
                name=stock_names.get(ticker),
                target_ratio=target_ratio,
                current_ratio=current_ratio,
                current_quantity=self._quantity(holdings, ticker),
                current_value=current_value,
                current_price=current_price,
                target_value=target_value,
                diff_ratio=(target_ratio - current_ratio).quantize(_CENT),
                diff_quantity=diff_quantity,
                action=self._action_for(diff_quantity),
                change_vs_prev_month=self._calc_change_pct(current_price, prev_prices.get(ticker)),
                change_vs_start=self._calc_change_pct(current_price, start_prices.get(ticker)),
            ))
        items.sort(key=self._trade_sort_key)
        return items

    def _calc_additional_buy(
        self, all_tickers, targets, holdings, current_prices, current_values,
        total_assets, additional_amount, stock_names, prev_prices, start_prices,
    ):
        """Spend ``additional_amount`` on under-weight tickers; never sells.

        Targets are measured against the post-deposit total
        (``total_assets + additional_amount``). Measured against
        ``total_assets`` alone, a balanced or empty portfolio would show no
        deficits and the new cash would never be invested. Tickers are
        funded greedily, largest drop versus the latest snapshot first
        (tickers with no reference price count as 0%; ties keep ticker
        order), each up to its deficit or the remaining cash, whichever is
        smaller, in whole shares.
        """
        budget = additional_amount or _ZERO
        goal_total = total_assets + budget
        target_values = {
            ticker: (goal_total * targets.get(ticker, _ZERO) / 100).quantize(_CENT)
            for ticker in all_tickers
        }

        changes = {
            ticker: self._calc_change_pct(
                current_prices.get(ticker, _ZERO), prev_prices.get(ticker)
            ) or _ZERO
            for ticker in all_tickers
        }
        buy_order = sorted(all_tickers, key=changes.__getitem__)

        buy_quantities = dict.fromkeys(all_tickers, 0)
        remaining = budget
        for ticker in buy_order:
            if remaining <= 0:
                break
            price = current_prices.get(ticker, _ZERO)
            if price <= 0:
                continue
            deficit = target_values[ticker] - current_values.get(ticker, _ZERO)
            if deficit <= 0:
                continue
            qty = int(min(deficit, remaining) / price)
            if qty <= 0:
                continue
            buy_quantities[ticker] = qty
            remaining -= price * qty

        items = []
        for ticker in all_tickers:
            target_ratio = targets.get(ticker, _ZERO)
            current_value = current_values.get(ticker, _ZERO)
            current_price = current_prices.get(ticker, _ZERO)
            current_ratio = self._pct(current_value, total_assets)
            diff_quantity = buy_quantities[ticker]

            items.append(RebalanceCalculateItem(
                ticker=ticker,
                name=stock_names.get(ticker),
                target_ratio=target_ratio,
                current_ratio=current_ratio,
                current_quantity=self._quantity(holdings, ticker),
                current_value=current_value,
                current_price=current_price,
                target_value=target_values[ticker],
                diff_ratio=(target_ratio - current_ratio).quantize(_CENT),
                diff_quantity=diff_quantity,
                action="buy" if diff_quantity > 0 else "hold",
                change_vs_prev_month=self._calc_change_pct(current_price, prev_prices.get(ticker)),
                change_vs_start=self._calc_change_pct(current_price, start_prices.get(ticker)),
            ))
        items.sort(key=self._buy_sort_key)
        return items