"""
Portfolio management API endpoints.

NOTE(review): the handlers below are ``async def`` but use a synchronous
SQLAlchemy ``Session``, so every DB call blocks the event loop while it runs.
Consider plain ``def`` handlers (FastAPI runs them in a threadpool) or an
``AsyncSession``.
"""
from collections import Counter
from datetime import datetime, timezone
from decimal import Decimal
from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import func
from sqlalchemy.orm import Session

from app.core.database import get_db
from app.api.deps import CurrentUser
from app.models.portfolio import Portfolio, PortfolioType, Target, Holding, Transaction, TransactionType
from app.models.stock import ETF
from app.schemas.portfolio import (
    PortfolioCreate,
    PortfolioUpdate,
    PortfolioResponse,
    PortfolioDetail,
    TargetCreate,
    TargetResponse,
    HoldingCreate,
    HoldingResponse,
    HoldingWithValue,
    TransactionCreate,
    TransactionResponse,
    RebalanceResponse,
    RebalanceSimulationRequest,
    RebalanceSimulationResponse,
    RebalanceCalculateRequest,
    RebalanceCalculateResponse,
    RebalanceApplyRequest,
    RebalanceApplyResponse,
    PositionSizeResponse,
)
from app.services.rebalance import RebalanceService

router = APIRouter(prefix="/api/portfolios", tags=["portfolios"])

# Asset classes counted as "safe" when computing a pension portfolio's
# risk-asset ratio. Any other class, or a ticker with no ETF record, counts
# as a risk asset.
SAFE_ASSET_CLASSES = frozenset({"bond", "gold"})

# Position-sizing fallbacks, in percent of total portfolio value, used when
# the requested ticker has no target allocation.
DEFAULT_MAX_POSITION_RATIO = Decimal("20")
DEFAULT_RECOMMENDED_RATIO_CAP = Decimal("10")

# Memo stored on every transaction created by /rebalance/apply
# (Korean for "rebalancing applied").
REBALANCE_MEMO = "리밸런싱 적용"


# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------

def _get_portfolio(db: Session, portfolio_id: int, user_id: int) -> Portfolio:
    """Return the portfolio ``portfolio_id`` if it is owned by ``user_id``.

    Raises:
        HTTPException: 404 when the portfolio does not exist or belongs to
            another user; both cases produce the same response.
    """
    portfolio = db.query(Portfolio).filter(
        Portfolio.id == portfolio_id,
        Portfolio.user_id == user_id,
    ).first()
    if not portfolio:
        raise HTTPException(status_code=404, detail="Portfolio not found")
    return portfolio


def _ensure_unique_tickers(tickers: List[str], kind: str) -> None:
    """Reject a request body that lists the same ticker more than once.

    Raises:
        HTTPException: 400 naming the duplicated tickers.
    """
    duplicates = sorted(t for t, count in Counter(tickers).items() if count > 1)
    if duplicates:
        raise HTTPException(
            status_code=400,
            detail=f"Duplicate tickers in {kind}: {', '.join(duplicates)}",
        )


def _parse_tx_type(value) -> TransactionType:
    """Convert a request value into ``TransactionType``.

    Raises:
        HTTPException: 400 instead of an unhandled ``ValueError`` (a 500).
    """
    try:
        return TransactionType(value)
    except ValueError:
        raise HTTPException(
            status_code=400, detail=f"Unknown transaction type: {value}"
        ) from None


def _apply_trade_to_holding(
    db: Session,
    portfolio_id: int,
    ticker: str,
    tx_type: TransactionType,
    quantity,
    price,
) -> Optional[Decimal]:
    """Update the portfolio's holding of ``ticker`` for one BUY or SELL.

    BUY:
        Adds to an existing holding and recomputes its average price as the
        quantity-weighted mean. If there is no holding, it creates one at
        ``price``.
    SELL:
        Reduces the holding and deletes it when the quantity reaches 0.
    Other types:
        Leave holdings untouched.

    Returns:
        The realized P&L ``(price - avg_price) * quantity`` for a SELL,
        otherwise ``None``.

    Raises:
        HTTPException: 400 when a SELL exceeds the held quantity.
    """
    holding = db.query(Holding).filter(
        Holding.portfolio_id == portfolio_id,
        Holding.ticker == ticker,
    ).first()

    if tx_type == TransactionType.BUY:
        if holding:
            total_cost = (holding.quantity * holding.avg_price) + (quantity * price)
            new_quantity = holding.quantity + quantity
            holding.quantity = new_quantity
            holding.avg_price = total_cost / new_quantity if new_quantity > 0 else 0
        else:
            db.add(Holding(
                portfolio_id=portfolio_id,
                ticker=ticker,
                quantity=quantity,
                avg_price=price,
            ))
        return None

    if tx_type == TransactionType.SELL:
        if not holding or holding.quantity < quantity:
            raise HTTPException(
                status_code=400, detail=f"Insufficient quantity for {ticker}"
            )
        realized_pnl = (price - holding.avg_price) * quantity
        holding.quantity -= quantity
        if holding.quantity == 0:
            db.delete(holding)
        return realized_pnl

    return None


def _to_transaction_responses(
    db: Session, transactions: List[Transaction]
) -> List[TransactionResponse]:
    """Serialize transactions, resolving each ticker's display name.

    Names come from ``RebalanceService.get_stock_names``. A ticker missing
    from that mapping gets ``name=None``.
    """
    names = RebalanceService(db).get_stock_names(list({tx.ticker for tx in transactions}))
    return [
        TransactionResponse(
            id=tx.id,
            ticker=tx.ticker,
            name=names.get(tx.ticker),
            tx_type=tx.tx_type.value,
            quantity=tx.quantity,
            price=tx.price,
            executed_at=tx.executed_at,
            memo=tx.memo,
            realized_pnl=tx.realized_pnl,
        )
        for tx in transactions
    ]


def _risk_asset_ratio(
    db: Session, holdings: List[HoldingWithValue], total_value: Decimal
) -> Decimal:
    """Return the percentage of ``total_value`` held outside SAFE_ASSET_CLASSES.

    Asset classes are looked up in the ETF table. A ticker with no ETF row,
    or with a NULL asset class, counts as a risk asset. Expects
    ``total_value > 0``.
    """
    tickers = [h.ticker for h in holdings]
    etfs = db.query(ETF).filter(ETF.ticker.in_(tickers)).all() if tickers else []
    # NOTE(review): guards a NULL asset_class in case the column is nullable.
    asset_class_by_ticker = {
        e.ticker: e.asset_class.value if e.asset_class is not None else None
        for e in etfs
    }
    risk_value = sum(
        (
            h.value or Decimal("0")
            for h in holdings
            if asset_class_by_ticker.get(h.ticker) not in SAFE_ASSET_CLASSES
        ),
        Decimal("0"),
    )
    return (risk_value / total_value * 100).quantize(Decimal("0.01"))


# ---------------------------------------------------------------------------
# Portfolio CRUD
# ---------------------------------------------------------------------------

@router.get("", response_model=List[PortfolioResponse])
async def list_portfolios(
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Get all portfolios for the current user, newest first."""
    return (
        db.query(Portfolio)
        .filter(Portfolio.user_id == current_user.id)
        .order_by(Portfolio.created_at.desc())
        .all()
    )


@router.post("", response_model=PortfolioResponse, status_code=status.HTTP_201_CREATED)
async def create_portfolio(
    data: PortfolioCreate,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Create a new portfolio owned by the current user."""
    portfolio = Portfolio(
        user_id=current_user.id,
        name=data.name,
        portfolio_type=PortfolioType(data.portfolio_type),
    )
    db.add(portfolio)
    db.commit()
    db.refresh(portfolio)
    return portfolio


@router.get("/{portfolio_id}", response_model=PortfolioResponse)
async def get_portfolio(
    portfolio_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Get a portfolio by ID (404 unless owned by the current user)."""
    return _get_portfolio(db, portfolio_id, current_user.id)


@router.put("/{portfolio_id}", response_model=PortfolioResponse)
async def update_portfolio(
    portfolio_id: int,
    data: PortfolioUpdate,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Update a portfolio's name and/or type. Fields left as None are unchanged."""
    portfolio = _get_portfolio(db, portfolio_id, current_user.id)
    if data.name is not None:
        portfolio.name = data.name
    if data.portfolio_type is not None:
        portfolio.portfolio_type = PortfolioType(data.portfolio_type)
    db.commit()
    db.refresh(portfolio)
    return portfolio


@router.delete("/{portfolio_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_portfolio(
    portfolio_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Delete a portfolio owned by the current user."""
    portfolio = _get_portfolio(db, portfolio_id, current_user.id)
    db.delete(portfolio)
    db.commit()
    return None


# ---------------------------------------------------------------------------
# Targets and holdings
# ---------------------------------------------------------------------------

@router.get("/{portfolio_id}/targets", response_model=List[TargetResponse])
async def get_targets(
    portfolio_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Get target allocations for a portfolio."""
    return _get_portfolio(db, portfolio_id, current_user.id).targets


@router.put("/{portfolio_id}/targets", response_model=List[TargetResponse])
async def set_targets(
    portfolio_id: int,
    targets: List[TargetCreate],
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Replace all target allocations of a portfolio.

    Raises:
        HTTPException: 400 if the ratios do not sum to exactly 100, or if a
            ticker appears more than once.
    """
    _get_portfolio(db, portfolio_id, current_user.id)

    total_ratio = sum(t.target_ratio for t in targets)
    if total_ratio != 100:
        raise HTTPException(
            status_code=400,
            detail=f"Target ratios must sum to 100%, got {total_ratio}%"
        )
    _ensure_unique_tickers([t.ticker for t in targets], "targets")

    db.query(Target).filter(Target.portfolio_id == portfolio_id).delete()

    new_targets = [
        Target(portfolio_id=portfolio_id, ticker=t.ticker, target_ratio=t.target_ratio)
        for t in targets
    ]
    db.add_all(new_targets)
    db.commit()
    return new_targets


@router.get("/{portfolio_id}/holdings", response_model=List[HoldingResponse])
async def get_holdings(
    portfolio_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Get holdings for a portfolio."""
    return _get_portfolio(db, portfolio_id, current_user.id).holdings


@router.put("/{portfolio_id}/holdings", response_model=List[HoldingResponse])
async def set_holdings(
    portfolio_id: int,
    holdings: List[HoldingCreate],
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Replace all holdings of a portfolio.

    This does not create transactions.

    Raises:
        HTTPException: 400 if a ticker appears more than once. Other code
            looks up a single holding per ticker.
    """
    _get_portfolio(db, portfolio_id, current_user.id)
    _ensure_unique_tickers([h.ticker for h in holdings], "holdings")

    db.query(Holding).filter(Holding.portfolio_id == portfolio_id).delete()

    new_holdings = [
        Holding(
            portfolio_id=portfolio_id,
            ticker=h.ticker,
            quantity=h.quantity,
            avg_price=h.avg_price,
        )
        for h in holdings
    ]
    db.add_all(new_holdings)
    db.commit()
    return new_holdings


# ---------------------------------------------------------------------------
# Transactions
# ---------------------------------------------------------------------------

@router.get("/{portfolio_id}/transactions", response_model=List[TransactionResponse])
async def get_transactions(
    portfolio_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
    limit: int = 50,
):
    """Get the portfolio's most recent transactions (newest first, at most ``limit``)."""
    _get_portfolio(db, portfolio_id, current_user.id)
    transactions = (
        db.query(Transaction)
        .filter(Transaction.portfolio_id == portfolio_id)
        .order_by(Transaction.executed_at.desc())
        .limit(limit)
        .all()
    )
    return _to_transaction_responses(db, transactions)


@router.post("/{portfolio_id}/transactions", response_model=TransactionResponse, status_code=status.HTTP_201_CREATED)
async def add_transaction(
    portfolio_id: int,
    data: TransactionCreate,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Record one transaction and update the matching holding.

    A SELL also stores its realized P&L on the transaction.

    Raises:
        HTTPException: 400 for an unknown transaction type or a SELL that
            exceeds the held quantity. The session is rolled back first.
    """
    _get_portfolio(db, portfolio_id, current_user.id)
    tx_type = _parse_tx_type(data.tx_type)

    try:
        realized_pnl = _apply_trade_to_holding(
            db, portfolio_id, data.ticker, tx_type, data.quantity, data.price
        )
    except HTTPException:
        # Discard anything already flushed for this request.
        db.rollback()
        raise

    transaction = Transaction(
        portfolio_id=portfolio_id,
        ticker=data.ticker,
        tx_type=tx_type,
        quantity=data.quantity,
        price=data.price,
        executed_at=data.executed_at,
        memo=data.memo,
    )
    # Only SELLs carry realized P&L; other types keep the column's default.
    if realized_pnl is not None:
        transaction.realized_pnl = realized_pnl
    db.add(transaction)
    db.commit()
    db.refresh(transaction)
    return _to_transaction_responses(db, [transaction])[0]


# ---------------------------------------------------------------------------
# Rebalancing
# ---------------------------------------------------------------------------

@router.get("/{portfolio_id}/rebalance", response_model=RebalanceResponse)
async def calculate_rebalance(
    portfolio_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Calculate rebalancing for a portfolio."""
    portfolio = _get_portfolio(db, portfolio_id, current_user.id)
    return RebalanceService(db).calculate_rebalance(portfolio)


@router.post("/{portfolio_id}/rebalance/simulate", response_model=RebalanceSimulationResponse)
async def simulate_rebalance(
    portfolio_id: int,
    data: RebalanceSimulationRequest,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Simulate rebalancing with an additional investment amount."""
    portfolio = _get_portfolio(db, portfolio_id, current_user.id)
    return RebalanceService(db).calculate_rebalance(
        portfolio, additional_amount=data.additional_amount
    )


@router.post("/{portfolio_id}/rebalance/calculate", response_model=RebalanceCalculateResponse)
async def calculate_rebalance_manual(
    portfolio_id: int,
    data: RebalanceCalculateRequest,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Calculate rebalancing with manual prices and a chosen strategy.

    Raises:
        HTTPException: 400 when strategy is ``additional_buy`` and no
            (or a zero) ``additional_amount`` is given.
    """
    portfolio = _get_portfolio(db, portfolio_id, current_user.id)

    if data.strategy == "additional_buy" and not data.additional_amount:
        raise HTTPException(
            status_code=400,
            detail="additional_amount is required for additional_buy strategy"
        )

    return RebalanceService(db).calculate_with_prices(
        portfolio,
        strategy=data.strategy,
        manual_prices=data.prices,
        additional_amount=data.additional_amount,
        min_trade_amount=data.min_trade_amount,
    )


@router.post("/{portfolio_id}/rebalance/apply", response_model=RebalanceApplyResponse, status_code=status.HTTP_201_CREATED)
async def apply_rebalance(
    portfolio_id: int,
    data: RebalanceApplyRequest,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Apply a rebalancing result by creating its transactions in one batch.

    Items are applied in order, each updating the holdings exactly like
    ``add_transaction``. Each is timestamped with the current UTC time and
    tagged with ``REBALANCE_MEMO``. The whole batch is committed once.

    Raises:
        HTTPException: 400 if any item has an unknown action or sells more
            than is held. The session is rolled back, so earlier items in
            the batch are not persisted.
    """
    _get_portfolio(db, portfolio_id, current_user.id)

    transactions = []
    try:
        for item in data.items:
            tx_type = _parse_tx_type(item.action)
            realized_pnl = _apply_trade_to_holding(
                db, portfolio_id, item.ticker, tx_type, item.quantity, item.price
            )
            transaction = Transaction(
                portfolio_id=portfolio_id,
                ticker=item.ticker,
                tx_type=tx_type,
                quantity=item.quantity,
                price=item.price,
                executed_at=datetime.now(timezone.utc),
                memo=REBALANCE_MEMO,
            )
            if realized_pnl is not None:
                transaction.realized_pnl = realized_pnl
            db.add(transaction)
            transactions.append(transaction)
    except HTTPException:
        # Autoflush may already have written earlier items; undo the batch.
        db.rollback()
        raise

    db.commit()
    for tx in transactions:
        db.refresh(tx)

    return RebalanceApplyResponse(
        transactions=_to_transaction_responses(db, transactions),
        # NOTE(review): counts transactions, not distinct holdings; these
        # differ if a ticker appears more than once in ``data.items``.
        holdings_updated=len(transactions),
    )


# ---------------------------------------------------------------------------
# Analytics
# ---------------------------------------------------------------------------

@router.get("/{portfolio_id}/detail", response_model=PortfolioDetail)
async def get_portfolio_detail(
    portfolio_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return the portfolio with per-holding valuation and P&L totals.

    Valuation:
        Holdings are valued at ``RebalanceService.get_current_prices``. A
        ticker without a price is valued at 0.
    Ratios:
        Percentages are quantized to two decimals.
    Realized P&L:
        The sum of ``Transaction.realized_pnl`` over the portfolio.
    Risk-asset ratio:
        Computed only for pension portfolios with positive value.
    """
    portfolio = _get_portfolio(db, portfolio_id, current_user.id)

    tickers = [h.ticker for h in portfolio.holdings]
    service = RebalanceService(db)
    prices = service.get_current_prices(tickers)
    names = service.get_stock_names(tickers)

    holdings_with_value = []
    total_value = Decimal("0")
    total_invested = Decimal("0")
    for holding in portfolio.holdings:
        current_price = prices.get(holding.ticker, Decimal("0"))
        # avg_price goes through str() so a float column converts exactly.
        avg_price = Decimal(str(holding.avg_price))
        value = current_price * holding.quantity
        invested = avg_price * holding.quantity
        profit_loss = value - invested
        profit_loss_ratio = (profit_loss / invested * 100) if invested > 0 else Decimal("0")

        total_value += value
        total_invested += invested

        holdings_with_value.append(HoldingWithValue(
            ticker=holding.ticker,
            name=names.get(holding.ticker),
            quantity=holding.quantity,
            avg_price=avg_price,
            current_price=current_price,
            value=value,
            current_ratio=Decimal("0"),  # filled in below once total_value is known
            profit_loss=profit_loss,
            profit_loss_ratio=profit_loss_ratio.quantize(Decimal("0.01")),
        ))

    if total_value > 0:
        for h in holdings_with_value:
            h.current_ratio = (h.value / total_value * 100).quantize(Decimal("0.01"))

    total_realized_pnl_result = (
        db.query(func.coalesce(func.sum(Transaction.realized_pnl), 0))
        .filter(
            Transaction.portfolio_id == portfolio_id,
            Transaction.realized_pnl.isnot(None),
        )
        .scalar()
    )
    total_realized_pnl = Decimal(str(total_realized_pnl_result))
    total_unrealized_pnl = total_value - total_invested

    risk_asset_ratio = None
    if portfolio.portfolio_type == PortfolioType.PENSION and total_value > 0:
        risk_asset_ratio = _risk_asset_ratio(db, holdings_with_value, total_value)

    return PortfolioDetail(
        id=portfolio.id,
        user_id=portfolio.user_id,
        name=portfolio.name,
        portfolio_type=portfolio.portfolio_type.value,
        created_at=portfolio.created_at,
        updated_at=portfolio.updated_at,
        targets=portfolio.targets,
        holdings=holdings_with_value,
        total_value=total_value,
        total_invested=total_invested,
        total_profit_loss=total_unrealized_pnl,
        total_realized_pnl=total_realized_pnl,
        total_unrealized_pnl=total_unrealized_pnl,
        risk_asset_ratio=risk_asset_ratio,
    )


@router.get("/{portfolio_id}/position-size", response_model=PositionSizeResponse)
async def get_position_size(
    portfolio_id: int,
    ticker: str,
    price: Decimal,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Position-sizing guide: recommended and maximum additional quantity.

    Portfolio value:
        The total is computed from current market prices.
    Current position:
        The position in ``ticker`` is valued at the supplied ``price``.
    Maximum:
        Up to the ticker's target ratio if one exists, a 0% target
        included; otherwise up to DEFAULT_MAX_POSITION_RATIO.
    Recommended:
        The target ratio if one exists; otherwise an equal weight across
        the portfolio's targets, capped at DEFAULT_RECOMMENDED_RATIO_CAP.
    Quantities:
        Whole units (truncated); 0 when ``price <= 0``.
    """
    portfolio = _get_portfolio(db, portfolio_id, current_user.id)
    service = RebalanceService(db)

    prices = service.get_current_prices([h.ticker for h in portfolio.holdings])
    total_value = sum(
        (prices.get(h.ticker, Decimal("0")) * h.quantity for h in portfolio.holdings),
        Decimal("0"),
    )

    current_holding = db.query(Holding).filter(
        Holding.portfolio_id == portfolio_id,
        Holding.ticker == ticker,
    ).first()
    current_qty = current_holding.quantity if current_holding else 0
    current_value = price * current_qty
    current_ratio = (current_value / total_value * 100) if total_value > 0 else Decimal("0")

    target = db.query(Target).filter(
        Target.portfolio_id == portfolio_id,
        Target.ticker == ticker,
    ).first()
    target_ratio = Decimal(str(target.target_ratio)) if target else None
    # Use an explicit None check so that a 0% target means "hold none of it"
    # instead of falling back to the defaults.
    has_target = target_ratio is not None

    max_ratio = target_ratio if has_target else DEFAULT_MAX_POSITION_RATIO
    max_value = total_value * max_ratio / 100
    max_additional_value = max(max_value - current_value, Decimal("0"))
    max_quantity = int(max_additional_value / price) if price > 0 else 0

    num_targets = len(portfolio.targets) or 1
    equal_ratio = Decimal("100") / num_targets
    rec_ratio = target_ratio if has_target else min(equal_ratio, DEFAULT_RECOMMENDED_RATIO_CAP)
    rec_value = total_value * rec_ratio / 100
    rec_additional_value = max(rec_value - current_value, Decimal("0"))
    recommended_quantity = int(rec_additional_value / price) if price > 0 else 0

    return PositionSizeResponse(
        ticker=ticker,
        price=price,
        total_portfolio_value=total_value,
        current_holding_quantity=current_qty,
        current_holding_value=current_value,
        current_ratio=current_ratio.quantize(Decimal("0.01")),
        target_ratio=target_ratio,
        recommended_quantity=recommended_quantity,
        max_quantity=max_quantity,
        recommended_value=rec_additional_value,
        max_value=max_additional_value,
    )