"""
Portfolio management API endpoints.

Route handlers are declared with plain ``def`` rather than ``async def``:
every handler uses a synchronous SQLAlchemy ``Session`` (and calls
``RebalanceService`` synchronously). FastAPI runs ``def`` handlers in a
worker threadpool, so these blocking calls no longer stall the event loop.
"""
from decimal import Decimal
from typing import Iterable, List

from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.orm import Session

from app.core.database import get_db
from app.api.deps import CurrentUser
from app.models.portfolio import Portfolio, PortfolioType, Target, Holding, Transaction, TransactionType
from app.schemas.portfolio import (
    PortfolioCreate,
    PortfolioUpdate,
    PortfolioResponse,
    PortfolioDetail,
    TargetCreate,
    TargetResponse,
    HoldingCreate,
    HoldingResponse,
    HoldingWithValue,
    TransactionCreate,
    TransactionResponse,
    RebalanceResponse,
    RebalanceSimulationRequest,
    RebalanceSimulationResponse,
    RebalanceCalculateRequest,
    RebalanceCalculateResponse,
)
from app.services.rebalance import RebalanceService

router = APIRouter(prefix="/api/portfolios", tags=["portfolios"])


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _get_portfolio(db: Session, portfolio_id: int, user_id: int) -> Portfolio:
    """Return portfolio ``portfolio_id`` if it is owned by ``user_id``.

    Raises:
        HTTPException: 404 when no portfolio with that id belongs to the user.
            Portfolios owned by someone else are reported the same way as
            missing ones.
    """
    portfolio = (
        db.query(Portfolio)
        .filter(Portfolio.id == portfolio_id, Portfolio.user_id == user_id)
        .first()
    )
    if not portfolio:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Portfolio not found")
    return portfolio


def _ensure_unique_tickers(tickers: Iterable) -> None:
    """Reject a request body that lists the same ticker more than once.

    Replace-all endpoints would otherwise store duplicate rows for one ticker.
    ``add_transaction`` only ever updates the first of those rows.

    Raises:
        HTTPException: 400 naming the duplicated ticker(s).
    """
    seen = set()
    duplicates = set()
    for ticker in tickers:
        if ticker in seen:
            duplicates.add(ticker)
        seen.add(ticker)
    if duplicates:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Duplicate tickers: {', '.join(sorted(str(t) for t in duplicates))}",
        )


# ---------------------------------------------------------------------------
# Portfolio CRUD
# ---------------------------------------------------------------------------

@router.get("", response_model=List[PortfolioResponse])
def list_portfolios(
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Get all portfolios for current user, most recently created first."""
    return (
        db.query(Portfolio)
        .filter(Portfolio.user_id == current_user.id)
        .order_by(Portfolio.created_at.desc())
        .all()
    )


@router.post("", response_model=PortfolioResponse, status_code=status.HTTP_201_CREATED)
def create_portfolio(
    data: PortfolioCreate,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Create a new portfolio owned by the current user and return it."""
    portfolio = Portfolio(
        user_id=current_user.id,
        name=data.name,
        portfolio_type=PortfolioType(data.portfolio_type),
    )
    db.add(portfolio)
    db.commit()
    # Reload server-populated columns (id, timestamps) for the response.
    db.refresh(portfolio)
    return portfolio


@router.get("/{portfolio_id}", response_model=PortfolioResponse)
def get_portfolio(
    portfolio_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Get a portfolio by ID (404 if it does not belong to the current user)."""
    return _get_portfolio(db, portfolio_id, current_user.id)


@router.put("/{portfolio_id}", response_model=PortfolioResponse)
def update_portfolio(
    portfolio_id: int,
    data: PortfolioUpdate,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Update a portfolio's name and/or type.

    Fields left as ``None`` in the request are not modified.
    Responds with 404 if the portfolio does not belong to the current user.
    """
    portfolio = _get_portfolio(db, portfolio_id, current_user.id)

    if data.name is not None:
        portfolio.name = data.name
    if data.portfolio_type is not None:
        portfolio.portfolio_type = PortfolioType(data.portfolio_type)

    db.commit()
    db.refresh(portfolio)
    return portfolio


@router.delete("/{portfolio_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_portfolio(
    portfolio_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Delete a portfolio (404 if it does not belong to the current user)."""
    portfolio = _get_portfolio(db, portfolio_id, current_user.id)
    db.delete(portfolio)
    db.commit()
    return None


# ---------------------------------------------------------------------------
# Targets and holdings
# ---------------------------------------------------------------------------

@router.get("/{portfolio_id}/targets", response_model=List[TargetResponse])
def get_targets(
    portfolio_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Get target allocations for a portfolio."""
    portfolio = _get_portfolio(db, portfolio_id, current_user.id)
    return portfolio.targets


@router.put("/{portfolio_id}/targets", response_model=List[TargetResponse])
def set_targets(
    portfolio_id: int,
    targets: List[TargetCreate],
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Set target allocations for a portfolio (replaces all existing).

    Raises:
        HTTPException: 404 if the portfolio is not the user's.
            400 if the ratios do not sum to exactly 100.
            400 if a ticker is listed more than once.
    """
    _get_portfolio(db, portfolio_id, current_user.id)

    total_ratio = sum(t.target_ratio for t in targets)
    if total_ratio != 100:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Target ratios must sum to 100%, got {total_ratio}%",
        )
    _ensure_unique_tickers(t.ticker for t in targets)

    # Replace-all semantics: drop every existing target, then insert the new set.
    db.query(Target).filter(Target.portfolio_id == portfolio_id).delete()

    new_targets = [
        Target(portfolio_id=portfolio_id, ticker=t.ticker, target_ratio=t.target_ratio)
        for t in targets
    ]
    db.add_all(new_targets)
    db.commit()
    return new_targets


@router.get("/{portfolio_id}/holdings", response_model=List[HoldingResponse])
def get_holdings(
    portfolio_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Get holdings for a portfolio."""
    portfolio = _get_portfolio(db, portfolio_id, current_user.id)
    return portfolio.holdings


@router.put("/{portfolio_id}/holdings", response_model=List[HoldingResponse])
def set_holdings(
    portfolio_id: int,
    holdings: List[HoldingCreate],
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Set holdings for a portfolio (replaces all existing).

    Raises:
        HTTPException: 404 if the portfolio is not the user's.
            400 if a ticker is listed more than once.
    """
    _get_portfolio(db, portfolio_id, current_user.id)
    _ensure_unique_tickers(h.ticker for h in holdings)

    # Replace-all semantics: drop every existing holding, then insert the new set.
    db.query(Holding).filter(Holding.portfolio_id == portfolio_id).delete()

    new_holdings = [
        Holding(
            portfolio_id=portfolio_id,
            ticker=h.ticker,
            quantity=h.quantity,
            avg_price=h.avg_price,
        )
        for h in holdings
    ]
    db.add_all(new_holdings)
    db.commit()
    return new_holdings


# ---------------------------------------------------------------------------
# Transactions
# ---------------------------------------------------------------------------

@router.get("/{portfolio_id}/transactions", response_model=List[TransactionResponse])
def get_transactions(
    portfolio_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
    # ge=0: a negative value would otherwise be passed straight to SQL LIMIT.
    limit: int = Query(50, ge=0),
):
    """Get up to ``limit`` transactions for a portfolio, most recent first."""
    _get_portfolio(db, portfolio_id, current_user.id)
    return (
        db.query(Transaction)
        .filter(Transaction.portfolio_id == portfolio_id)
        .order_by(Transaction.executed_at.desc())
        .limit(limit)
        .all()
    )


@router.post(
    "/{portfolio_id}/transactions",
    response_model=TransactionResponse,
    status_code=status.HTTP_201_CREATED,
)
def add_transaction(
    portfolio_id: int,
    data: TransactionCreate,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Record a transaction and update the matching holding in one commit.

    BUY adds to the holding, creating it if absent, and recomputes the
    weighted-average price. SELL subtracts from the holding and deletes it
    once its quantity reaches zero.

    Raises:
        HTTPException: 404 if the portfolio is not the user's.
            400 when selling more than is held, or a ticker that is not held.
    """
    _get_portfolio(db, portfolio_id, current_user.id)
    tx_type = TransactionType(data.tx_type)

    holding = (
        db.query(Holding)
        .filter(Holding.portfolio_id == portfolio_id, Holding.ticker == data.ticker)
        .first()
    )

    # Validate a sell before anything is staged in the session, so a rejected
    # request never leaves a pending Transaction behind.
    if tx_type == TransactionType.SELL and (not holding or holding.quantity < data.quantity):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Insufficient quantity for {data.ticker}",
        )

    transaction = Transaction(
        portfolio_id=portfolio_id,
        ticker=data.ticker,
        tx_type=tx_type,
        quantity=data.quantity,
        price=data.price,
        executed_at=data.executed_at,
        memo=data.memo,
    )
    db.add(transaction)

    if tx_type == TransactionType.BUY:
        if holding:
            # Weighted-average cost across the existing position and this buy.
            total_value = (holding.quantity * holding.avg_price) + (data.quantity * data.price)
            new_quantity = holding.quantity + data.quantity
            holding.quantity = new_quantity
            holding.avg_price = total_value / new_quantity if new_quantity > 0 else 0
        else:
            db.add(Holding(
                portfolio_id=portfolio_id,
                ticker=data.ticker,
                quantity=data.quantity,
                avg_price=data.price,
            ))
    elif tx_type == TransactionType.SELL:
        holding.quantity -= data.quantity
        if holding.quantity == 0:
            db.delete(holding)

    db.commit()
    db.refresh(transaction)
    return transaction


# ---------------------------------------------------------------------------
# Rebalancing
# ---------------------------------------------------------------------------

@router.get("/{portfolio_id}/rebalance", response_model=RebalanceResponse)
def calculate_rebalance(
    portfolio_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Calculate rebalancing for a portfolio via ``RebalanceService``."""
    portfolio = _get_portfolio(db, portfolio_id, current_user.id)
    return RebalanceService(db).calculate_rebalance(portfolio)


@router.post("/{portfolio_id}/rebalance/simulate", response_model=RebalanceSimulationResponse)
def simulate_rebalance(
    portfolio_id: int,
    data: RebalanceSimulationRequest,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Simulate rebalancing with an additional investment amount."""
    portfolio = _get_portfolio(db, portfolio_id, current_user.id)
    return RebalanceService(db).calculate_rebalance(
        portfolio, additional_amount=data.additional_amount
    )


@router.post("/{portfolio_id}/rebalance/calculate", response_model=RebalanceCalculateResponse)
def calculate_rebalance_manual(
    portfolio_id: int,
    data: RebalanceCalculateRequest,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Calculate rebalancing with manual prices and strategy selection.

    Raises:
        HTTPException: 404 if the portfolio is not the user's.
            400 if ``strategy`` is ``"additional_buy"`` and
            ``additional_amount`` is missing or falsy (e.g. 0).
    """
    portfolio = _get_portfolio(db, portfolio_id, current_user.id)

    if data.strategy == "additional_buy" and not data.additional_amount:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="additional_amount is required for additional_buy strategy",
        )

    return RebalanceService(db).calculate_with_prices(
        portfolio,
        strategy=data.strategy,
        manual_prices=data.prices,
        additional_amount=data.additional_amount,
    )


# ---------------------------------------------------------------------------
# Valuation
# ---------------------------------------------------------------------------

@router.get("/{portfolio_id}/detail", response_model=PortfolioDetail)
def get_portfolio_detail(
    portfolio_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Get a portfolio together with per-holding and total valuations.

    Holdings whose ticker has no price in ``RebalanceService.get_current_prices``
    are valued at 0. Percentages are quantized to two decimal places.
    ``current_ratio`` stays 0 when the portfolio's total value is 0.
    """
    portfolio = _get_portfolio(db, portfolio_id, current_user.id)

    tickers = [h.ticker for h in portfolio.holdings]
    service = RebalanceService(db)
    prices = service.get_current_prices(tickers)
    names = service.get_stock_names(tickers)

    holdings_with_value = []
    total_value = Decimal("0")
    total_invested = Decimal("0")

    for holding in portfolio.holdings:
        current_price = prices.get(holding.ticker, Decimal("0"))
        # str() round-trip makes the stored avg_price a Decimal without
        # importing binary-float noise if the column yields a float.
        avg_price = Decimal(str(holding.avg_price))
        value = current_price * holding.quantity
        invested = avg_price * holding.quantity
        profit_loss = value - invested
        profit_loss_ratio = (profit_loss / invested * 100) if invested > 0 else Decimal("0")

        total_value += value
        total_invested += invested

        holdings_with_value.append(HoldingWithValue(
            ticker=holding.ticker,
            name=names.get(holding.ticker),
            quantity=holding.quantity,
            avg_price=avg_price,
            current_price=current_price,
            value=value,
            current_ratio=Decimal("0"),  # filled in below once total_value is known
            profit_loss=profit_loss,
            profit_loss_ratio=profit_loss_ratio.quantize(Decimal("0.01")),
        ))

    if total_value > 0:
        for h in holdings_with_value:
            h.current_ratio = (h.value / total_value * 100).quantize(Decimal("0.01"))

    return PortfolioDetail(
        id=portfolio.id,
        user_id=portfolio.user_id,
        name=portfolio.name,
        portfolio_type=portfolio.portfolio_type.value,
        created_at=portfolio.created_at,
        updated_at=portfolio.updated_at,
        targets=portfolio.targets,
        holdings=holdings_with_value,
        total_value=total_value,
        total_invested=total_invested,
        total_profit_loss=total_value - total_invested,
    )