diff --git a/backend/app/services/rebalance.py b/backend/app/services/rebalance.py index d79f83a..ae2bc53 100644 --- a/backend/app/services/rebalance.py +++ b/backend/app/services/rebalance.py @@ -152,3 +152,249 @@ class RebalanceService: total_value=current_total, items=items, ) + + def _get_prev_month_prices(self, portfolio_id: int, tickers: List[str]) -> Dict[str, Decimal]: + """Get prices from the most recent snapshot for change calculation.""" + from app.models.portfolio import PortfolioSnapshot, SnapshotHolding + latest_snapshot = ( + self.db.query(PortfolioSnapshot) + .filter(PortfolioSnapshot.portfolio_id == portfolio_id) + .order_by(PortfolioSnapshot.snapshot_date.desc()) + .first() + ) + if not latest_snapshot: + return {} + prices = {} + for sh in latest_snapshot.holdings: + if sh.ticker in tickers: + prices[sh.ticker] = Decimal(str(sh.price)) + return prices + + def _get_start_prices(self, portfolio_id: int, tickers: List[str]) -> Dict[str, Decimal]: + """Get prices from the earliest snapshot for change calculation.""" + from app.models.portfolio import PortfolioSnapshot, SnapshotHolding + earliest_snapshot = ( + self.db.query(PortfolioSnapshot) + .filter(PortfolioSnapshot.portfolio_id == portfolio_id) + .order_by(PortfolioSnapshot.snapshot_date.asc()) + .first() + ) + if not earliest_snapshot: + return {} + prices = {} + for sh in earliest_snapshot.holdings: + if sh.ticker in tickers: + prices[sh.ticker] = Decimal(str(sh.price)) + return prices + + def calculate_with_prices( + self, + portfolio: "Portfolio", + strategy: str, + manual_prices: Optional[Dict[str, Decimal]] = None, + additional_amount: Optional[Decimal] = None, + ): + """Calculate rebalance with optional manual prices and strategy selection.""" + from app.schemas.portfolio import RebalanceCalculateItem, RebalanceCalculateResponse + + targets = {t.ticker: Decimal(str(t.target_ratio)) for t in portfolio.targets} + holdings = {h.ticker: (h.quantity, Decimal(str(h.avg_price))) for h in portfolio.holdings} + all_tickers = list(set(targets.keys()) | set(holdings.keys())) + + # Use manual prices if provided, else fall back to DB + if manual_prices: + current_prices = {t: manual_prices.get(t, Decimal("0")) for t in all_tickers} + else: + current_prices = self.get_current_prices(all_tickers) + + stock_names = self.get_stock_names(all_tickers) + + # Calculate current values + current_values = {} + for ticker in all_tickers: + qty = holdings.get(ticker, (0, Decimal("0")))[0] + price = current_prices.get(ticker, Decimal("0")) + current_values[ticker] = price * qty + + total_assets = sum(current_values.values()) + + # Get snapshot prices for change calculation + prev_prices = self._get_prev_month_prices(portfolio.id, all_tickers) + start_prices = self._get_start_prices(portfolio.id, all_tickers) + + if strategy == "full_rebalance": + items = self._calc_full_rebalance( + all_tickers, targets, holdings, current_prices, + current_values, total_assets, stock_names, + prev_prices, start_prices, + ) + return RebalanceCalculateResponse( + portfolio_id=portfolio.id, + total_assets=total_assets, + items=items, + ) + else: # additional_buy + items = self._calc_additional_buy( + all_tickers, targets, holdings, current_prices, + current_values, total_assets, additional_amount, + stock_names, prev_prices, start_prices, + ) + return RebalanceCalculateResponse( + portfolio_id=portfolio.id, + total_assets=total_assets, + available_to_buy=additional_amount, + items=items, + ) + + def _calc_change_pct( + self, current_price: Decimal, ref_price: Optional[Decimal] + ) -> Optional[Decimal]: + if ref_price and ref_price > 0: + return ((current_price - ref_price) / ref_price * 100).quantize(Decimal("0.01")) + return None + + def _calc_full_rebalance( + self, all_tickers, targets, holdings, current_prices, + current_values, total_assets, stock_names, + prev_prices, start_prices, + ): + from app.schemas.portfolio import RebalanceCalculateItem + + items = [] + for ticker in all_tickers: + target_ratio = targets.get(ticker, Decimal("0")) + current_value = current_values.get(ticker, Decimal("0")) + current_quantity = holdings.get(ticker, (0, Decimal("0")))[0] + current_price = current_prices.get(ticker, Decimal("0")) + + if total_assets > 0: + current_ratio = (current_value / total_assets * 100).quantize(Decimal("0.01")) + else: + current_ratio = Decimal("0") + + target_value = (total_assets * target_ratio / 100).quantize(Decimal("0.01")) + diff_ratio = (target_ratio - current_ratio).quantize(Decimal("0.01")) + + if current_price > 0: + diff_quantity = int((target_value - current_value) / current_price) + else: + diff_quantity = 0 + + if diff_quantity > 0: + action = "buy" + elif diff_quantity < 0: + action = "sell" + else: + action = "hold" + + items.append(RebalanceCalculateItem( + ticker=ticker, + name=stock_names.get(ticker), + target_ratio=target_ratio, + current_ratio=current_ratio, + current_quantity=current_quantity, + current_value=current_value, + current_price=current_price, + target_value=target_value, + diff_ratio=diff_ratio, + diff_quantity=diff_quantity, + action=action, + change_vs_prev_month=self._calc_change_pct( + current_price, prev_prices.get(ticker) + ), + change_vs_start=self._calc_change_pct( + current_price, start_prices.get(ticker) + ), + )) + + action_order = {"buy": 0, "sell": 1, "hold": 2} + items.sort(key=lambda x: (action_order.get(x.action, 3), -abs(x.diff_quantity))) + return items + + def _calc_additional_buy( + self, all_tickers, targets, holdings, current_prices, + current_values, total_assets, additional_amount, + stock_names, prev_prices, start_prices, + ): + from app.schemas.portfolio import RebalanceCalculateItem + + remaining = additional_amount or Decimal("0") + + # Calculate change vs prev month for sorting + ticker_changes = {} + for ticker in all_tickers: + cp = current_prices.get(ticker, Decimal("0")) + pp = prev_prices.get(ticker) + if pp and pp > 0: + ticker_changes[ticker] = ((cp - pp) / pp * 100).quantize(Decimal("0.01")) + else: + ticker_changes[ticker] = Decimal("0") + + # Sort by drop (most negative first) + sorted_tickers = sorted(all_tickers, key=lambda t: ticker_changes.get(t, Decimal("0"))) + + buy_quantities = {t: 0 for t in all_tickers} + + # Allocate additional amount to tickers sorted by drop + for ticker in sorted_tickers: + if remaining <= 0: + break + target_ratio = targets.get(ticker, Decimal("0")) + current_value = current_values.get(ticker, Decimal("0")) + current_price = current_prices.get(ticker, Decimal("0")) + if current_price <= 0: + continue + + target_value = (total_assets * target_ratio / 100).quantize(Decimal("0.01")) + deficit = target_value - current_value + if deficit <= 0: + continue + + buy_amount = min(deficit, remaining) + buy_qty = int(buy_amount / current_price) + if buy_qty <= 0: + continue + + actual_cost = current_price * buy_qty + buy_quantities[ticker] = buy_qty + remaining -= actual_cost + + # Build response items + items = [] + for ticker in all_tickers: + target_ratio = targets.get(ticker, Decimal("0")) + current_value = current_values.get(ticker, Decimal("0")) + current_quantity = holdings.get(ticker, (0, Decimal("0")))[0] + current_price = current_prices.get(ticker, Decimal("0")) + + if total_assets > 0: + current_ratio = (current_value / total_assets * 100).quantize(Decimal("0.01")) + else: + current_ratio = Decimal("0") + + target_value = (total_assets * target_ratio / 100).quantize(Decimal("0.01")) + diff_ratio = (target_ratio - current_ratio).quantize(Decimal("0.01")) + diff_quantity = buy_quantities.get(ticker, 0) + + items.append(RebalanceCalculateItem( + ticker=ticker, + name=stock_names.get(ticker), + target_ratio=target_ratio, + current_ratio=current_ratio, + current_quantity=current_quantity, + current_value=current_value, + current_price=current_price, + target_value=target_value, + diff_ratio=diff_ratio, + diff_quantity=diff_quantity, + action="buy" if diff_quantity > 0 else "hold", + change_vs_prev_month=self._calc_change_pct( + current_price, prev_prices.get(ticker) + ), + change_vs_start=self._calc_change_pct( + current_price, start_prices.get(ticker) + ), + )) + + items.sort(key=lambda x: (-x.diff_quantity, x.ticker)) + return items