feat: add rebalancing calculation service
This commit is contained in:
parent
95f97eeef9
commit
a45c44740e
154
backend/app/services/rebalance.py
Normal file
154
backend/app/services/rebalance.py
Normal file
@ -0,0 +1,154 @@
|
||||
"""
|
||||
Rebalancing calculation service.
|
||||
"""
|
||||
from decimal import Decimal
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.portfolio import Portfolio
|
||||
from app.models.stock import Stock, ETF, ETFPrice
|
||||
from app.schemas.portfolio import RebalanceItem, RebalanceResponse, RebalanceSimulationResponse
|
||||
|
||||
|
||||
class RebalanceService:
|
||||
"""Service for calculating portfolio rebalancing."""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def get_current_prices(self, tickers: List[str]) -> Dict[str, Decimal]:
|
||||
"""Get current prices for tickers from database."""
|
||||
prices = {}
|
||||
|
||||
# Check stocks
|
||||
stocks = self.db.query(Stock).filter(Stock.ticker.in_(tickers)).all()
|
||||
for stock in stocks:
|
||||
if stock.close_price:
|
||||
prices[stock.ticker] = Decimal(str(stock.close_price))
|
||||
|
||||
# Check ETFs for missing tickers
|
||||
missing = [t for t in tickers if t not in prices]
|
||||
if missing:
|
||||
# Get latest ETF prices
|
||||
from sqlalchemy import func
|
||||
subq = (
|
||||
self.db.query(
|
||||
ETFPrice.ticker,
|
||||
func.max(ETFPrice.date).label('max_date')
|
||||
)
|
||||
.filter(ETFPrice.ticker.in_(missing))
|
||||
.group_by(ETFPrice.ticker)
|
||||
.subquery()
|
||||
)
|
||||
etf_prices = (
|
||||
self.db.query(ETFPrice)
|
||||
.join(subq, (ETFPrice.ticker == subq.c.ticker) & (ETFPrice.date == subq.c.max_date))
|
||||
.all()
|
||||
)
|
||||
for ep in etf_prices:
|
||||
prices[ep.ticker] = Decimal(str(ep.close))
|
||||
|
||||
return prices
|
||||
|
||||
def get_stock_names(self, tickers: List[str]) -> Dict[str, str]:
|
||||
"""Get stock names for tickers."""
|
||||
names = {}
|
||||
stocks = self.db.query(Stock).filter(Stock.ticker.in_(tickers)).all()
|
||||
for stock in stocks:
|
||||
names[stock.ticker] = stock.name
|
||||
|
||||
# Also check ETFs
|
||||
missing = [t for t in tickers if t not in names]
|
||||
if missing:
|
||||
etfs = self.db.query(ETF).filter(ETF.ticker.in_(missing)).all()
|
||||
for etf in etfs:
|
||||
names[etf.ticker] = etf.name
|
||||
|
||||
return names
|
||||
|
||||
def calculate_rebalance(
|
||||
self,
|
||||
portfolio: Portfolio,
|
||||
additional_amount: Optional[Decimal] = None,
|
||||
) -> RebalanceResponse | RebalanceSimulationResponse:
|
||||
"""Calculate rebalancing for a portfolio."""
|
||||
targets = {t.ticker: Decimal(str(t.target_ratio)) for t in portfolio.targets}
|
||||
holdings = {h.ticker: (h.quantity, Decimal(str(h.avg_price))) for h in portfolio.holdings}
|
||||
|
||||
all_tickers = list(set(targets.keys()) | set(holdings.keys()))
|
||||
current_prices = self.get_current_prices(all_tickers)
|
||||
stock_names = self.get_stock_names(all_tickers)
|
||||
|
||||
# Calculate current values
|
||||
current_values = {}
|
||||
for ticker, (quantity, _) in holdings.items():
|
||||
price = current_prices.get(ticker, Decimal("0"))
|
||||
current_values[ticker] = price * quantity
|
||||
|
||||
current_total = sum(current_values.values())
|
||||
|
||||
if additional_amount:
|
||||
new_total = current_total + additional_amount
|
||||
else:
|
||||
new_total = current_total
|
||||
|
||||
# Calculate rebalance items
|
||||
items = []
|
||||
for ticker in all_tickers:
|
||||
target_ratio = targets.get(ticker, Decimal("0"))
|
||||
current_value = current_values.get(ticker, Decimal("0"))
|
||||
current_quantity = holdings.get(ticker, (0, Decimal("0")))[0]
|
||||
current_price = current_prices.get(ticker, Decimal("0"))
|
||||
|
||||
if new_total > 0:
|
||||
current_ratio = (current_value / new_total * 100).quantize(Decimal("0.01"))
|
||||
else:
|
||||
current_ratio = Decimal("0")
|
||||
|
||||
target_value = (new_total * target_ratio / 100).quantize(Decimal("0.01"))
|
||||
diff_value = target_value - current_value
|
||||
|
||||
if current_price > 0:
|
||||
diff_quantity = int(diff_value / current_price)
|
||||
else:
|
||||
diff_quantity = 0
|
||||
|
||||
if diff_quantity > 0:
|
||||
action = "buy"
|
||||
elif diff_quantity < 0:
|
||||
action = "sell"
|
||||
else:
|
||||
action = "hold"
|
||||
|
||||
items.append(RebalanceItem(
|
||||
ticker=ticker,
|
||||
name=stock_names.get(ticker),
|
||||
target_ratio=target_ratio,
|
||||
current_ratio=current_ratio,
|
||||
current_quantity=current_quantity,
|
||||
current_value=current_value,
|
||||
target_value=target_value,
|
||||
diff_value=diff_value,
|
||||
diff_quantity=diff_quantity,
|
||||
action=action,
|
||||
))
|
||||
|
||||
# Sort by action priority (buy first, then sell, then hold)
|
||||
action_order = {"buy": 0, "sell": 1, "hold": 2}
|
||||
items.sort(key=lambda x: (action_order.get(x.action, 3), -abs(x.diff_quantity)))
|
||||
|
||||
if additional_amount:
|
||||
return RebalanceSimulationResponse(
|
||||
portfolio_id=portfolio.id,
|
||||
current_total=current_total,
|
||||
additional_amount=additional_amount,
|
||||
new_total=new_total,
|
||||
items=items,
|
||||
)
|
||||
else:
|
||||
return RebalanceResponse(
|
||||
portfolio_id=portfolio.id,
|
||||
total_value=current_total,
|
||||
items=items,
|
||||
)
|
||||
Loading…
x
Reference in New Issue
Block a user