diff --git a/backend/app/api/portfolio.py b/backend/app/api/portfolio.py index 19770f6..60e5131 100644 --- a/backend/app/api/portfolio.py +++ b/backend/app/api/portfolio.py @@ -8,9 +8,11 @@ from sqlalchemy.orm import Session from app.core.database import get_db from app.api.deps import CurrentUser -from app.models.portfolio import Portfolio, PortfolioType +from app.models.portfolio import Portfolio, PortfolioType, Target, Holding from app.schemas.portfolio import ( PortfolioCreate, PortfolioUpdate, PortfolioResponse, PortfolioDetail, + TargetCreate, TargetResponse, + HoldingCreate, HoldingResponse, ) router = APIRouter(prefix="/api/portfolios", tags=["portfolios"]) @@ -108,3 +110,101 @@ async def delete_portfolio( db.delete(portfolio) db.commit() return None + + +def _get_portfolio(db: Session, portfolio_id: int, user_id: int) -> Portfolio: + """Helper to get portfolio with ownership check.""" + portfolio = db.query(Portfolio).filter( + Portfolio.id == portfolio_id, + Portfolio.user_id == user_id, + ).first() + if not portfolio: + raise HTTPException(status_code=404, detail="Portfolio not found") + return portfolio + + +@router.get("/{portfolio_id}/targets", response_model=List[TargetResponse]) +async def get_targets( + portfolio_id: int, + current_user: CurrentUser, + db: Session = Depends(get_db), +): + """Get target allocations for a portfolio.""" + portfolio = _get_portfolio(db, portfolio_id, current_user.id) + return portfolio.targets + + +@router.put("/{portfolio_id}/targets", response_model=List[TargetResponse]) +async def set_targets( + portfolio_id: int, + targets: List[TargetCreate], + current_user: CurrentUser, + db: Session = Depends(get_db), +): + """Set target allocations for a portfolio (replaces all existing).""" + portfolio = _get_portfolio(db, portfolio_id, current_user.id) + + # Validate total ratio + total_ratio = sum(t.target_ratio for t in targets) + if total_ratio != 100: + raise HTTPException( + status_code=400, + detail=f"Target ratios must sum to 100%, got {total_ratio}%" + ) + + # Delete existing targets + db.query(Target).filter(Target.portfolio_id == portfolio_id).delete() + + # Create new targets + new_targets = [] + for t in targets: + target = Target( + portfolio_id=portfolio_id, + ticker=t.ticker, + target_ratio=t.target_ratio, + ) + db.add(target) + new_targets.append(target) + + db.commit() + return new_targets + + +@router.get("/{portfolio_id}/holdings", response_model=List[HoldingResponse]) +async def get_holdings( + portfolio_id: int, + current_user: CurrentUser, + db: Session = Depends(get_db), +): + """Get holdings for a portfolio.""" + portfolio = _get_portfolio(db, portfolio_id, current_user.id) + return portfolio.holdings + + +@router.put("/{portfolio_id}/holdings", response_model=List[HoldingResponse]) +async def set_holdings( + portfolio_id: int, + holdings: List[HoldingCreate], + current_user: CurrentUser, + db: Session = Depends(get_db), +): + """Set holdings for a portfolio (replaces all existing).""" + portfolio = _get_portfolio(db, portfolio_id, current_user.id) + + # Delete existing holdings + db.query(Holding).filter(Holding.portfolio_id == portfolio_id).delete() + + # Create new holdings + new_holdings = [] + for h in holdings: + holding = Holding( + portfolio_id=portfolio_id, + ticker=h.ticker, + quantity=h.quantity, + avg_price=h.avg_price, + ) + db.add(holding) + new_holdings.append(holding) + + db.commit() + return new_holdings