diff --git a/backend/app/services/rebalance.py b/backend/app/services/rebalance.py new file mode 100644 index 0000000..d79f83a --- /dev/null +++ b/backend/app/services/rebalance.py @@ -0,0 +1,154 @@ +""" +Rebalancing calculation service. +""" +from decimal import Decimal +from typing import List, Dict, Optional + +from sqlalchemy.orm import Session + +from app.models.portfolio import Portfolio +from app.models.stock import Stock, ETF, ETFPrice +from app.schemas.portfolio import RebalanceItem, RebalanceResponse, RebalanceSimulationResponse + + +class RebalanceService: + """Service for calculating portfolio rebalancing.""" + + def __init__(self, db: Session): + self.db = db + + def get_current_prices(self, tickers: List[str]) -> Dict[str, Decimal]: + """Get current prices for tickers from database.""" + prices = {} + + # Check stocks + stocks = self.db.query(Stock).filter(Stock.ticker.in_(tickers)).all() + for stock in stocks: + if stock.close_price: + prices[stock.ticker] = Decimal(str(stock.close_price)) + + # Check ETFs for missing tickers + missing = [t for t in tickers if t not in prices] + if missing: + # Get latest ETF prices + from sqlalchemy import func + subq = ( + self.db.query( + ETFPrice.ticker, + func.max(ETFPrice.date).label('max_date') + ) + .filter(ETFPrice.ticker.in_(missing)) + .group_by(ETFPrice.ticker) + .subquery() + ) + etf_prices = ( + self.db.query(ETFPrice) + .join(subq, (ETFPrice.ticker == subq.c.ticker) & (ETFPrice.date == subq.c.max_date)) + .all() + ) + for ep in etf_prices: + prices[ep.ticker] = Decimal(str(ep.close)) + + return prices + + def get_stock_names(self, tickers: List[str]) -> Dict[str, str]: + """Get stock names for tickers.""" + names = {} + stocks = self.db.query(Stock).filter(Stock.ticker.in_(tickers)).all() + for stock in stocks: + names[stock.ticker] = stock.name + + # Also check ETFs + missing = [t for t in tickers if t not in names] + if missing: + etfs = self.db.query(ETF).filter(ETF.ticker.in_(missing)).all() + for etf in etfs: + names[etf.ticker] = etf.name + + return names + + def calculate_rebalance( + self, + portfolio: Portfolio, + additional_amount: Optional[Decimal] = None, + ) -> RebalanceResponse | RebalanceSimulationResponse: + """Calculate rebalancing for a portfolio.""" + targets = {t.ticker: Decimal(str(t.target_ratio)) for t in portfolio.targets} + holdings = {h.ticker: (h.quantity, Decimal(str(h.avg_price))) for h in portfolio.holdings} + + all_tickers = list(set(targets.keys()) | set(holdings.keys())) + current_prices = self.get_current_prices(all_tickers) + stock_names = self.get_stock_names(all_tickers) + + # Calculate current values + current_values = {} + for ticker, (quantity, _) in holdings.items(): + price = current_prices.get(ticker, Decimal("0")) + current_values[ticker] = price * quantity + + current_total = sum(current_values.values()) + + if additional_amount: + new_total = current_total + additional_amount + else: + new_total = current_total + + # Calculate rebalance items + items = [] + for ticker in all_tickers: + target_ratio = targets.get(ticker, Decimal("0")) + current_value = current_values.get(ticker, Decimal("0")) + current_quantity = holdings.get(ticker, (0, Decimal("0")))[0] + current_price = current_prices.get(ticker, Decimal("0")) + + if new_total > 0: + current_ratio = (current_value / new_total * 100).quantize(Decimal("0.01")) + else: + current_ratio = Decimal("0") + + target_value = (new_total * target_ratio / 100).quantize(Decimal("0.01")) + diff_value = target_value - current_value + + if current_price > 0: + diff_quantity = int(diff_value / current_price) + else: + diff_quantity = 0 + + if diff_quantity > 0: + action = "buy" + elif diff_quantity < 0: + action = "sell" + else: + action = "hold" + + items.append(RebalanceItem( + ticker=ticker, + name=stock_names.get(ticker), + target_ratio=target_ratio, + current_ratio=current_ratio, + current_quantity=current_quantity, + current_value=current_value, + target_value=target_value, + diff_value=diff_value, + diff_quantity=diff_quantity, + action=action, + )) + + # Sort by action priority (buy first, then sell, then hold) + action_order = {"buy": 0, "sell": 1, "hold": 2} + items.sort(key=lambda x: (action_order.get(x.action, 3), -abs(x.diff_quantity))) + + if additional_amount: + return RebalanceSimulationResponse( + portfolio_id=portfolio.id, + current_total=current_total, + additional_amount=additional_amount, + new_total=new_total, + items=items, + ) + else: + return RebalanceResponse( + portfolio_id=portfolio.id, + total_value=current_total, + items=items, + )