"""
Snapshot API endpoints for portfolio history tracking.
"""
from datetime import date
from decimal import Decimal, InvalidOperation
from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session

from app.core.database import get_db
from app.api.deps import CurrentUser
from app.models.portfolio import Portfolio, PortfolioSnapshot, SnapshotHolding
from app.schemas.portfolio import (
    SnapshotListItem,
    SnapshotResponse,
    SnapshotHoldingResponse,
    ReturnsResponse,
    ReturnDataPoint,
)
from app.services.price_service import PriceService

# NOTE(review): the handlers below are ``async def`` but use a synchronous
# SQLAlchemy ``Session`` (and ``PriceService``, which may do blocking I/O --
# confirm), so each query blocks the event loop. Consider plain ``def``
# handlers so FastAPI runs them in its threadpool; left as-is here because
# callers (e.g. tests) may ``await`` these functions directly.
router = APIRouter(prefix="/api/portfolios", tags=["snapshots"])

# Rounding quantum for every percentage (allocation ratios, returns, CAGR).
_PERCENT_QUANTUM = Decimal("0.01")


def _get_portfolio(db: Session, portfolio_id: int, user_id: int) -> Portfolio:
    """Helper to get portfolio with ownership check.

    Raises:
        HTTPException: 404 if the portfolio does not exist or is not owned
            by ``user_id`` (both cases are reported identically).
    """
    portfolio = db.query(Portfolio).filter(
        Portfolio.id == portfolio_id,
        Portfolio.user_id == user_id,
    ).first()

    if not portfolio:
        raise HTTPException(status_code=404, detail="Portfolio not found")

    return portfolio


def _get_snapshot(db: Session, portfolio_id: int, snapshot_id: int) -> PortfolioSnapshot:
    """Return snapshot ``snapshot_id`` if it belongs to ``portfolio_id``.

    Raises:
        HTTPException: 404 if no such snapshot exists for this portfolio.
    """
    snapshot = (
        db.query(PortfolioSnapshot)
        .filter(
            PortfolioSnapshot.id == snapshot_id,
            PortfolioSnapshot.portfolio_id == portfolio_id,
        )
        .first()
    )

    if not snapshot:
        raise HTTPException(status_code=404, detail="Snapshot not found")

    return snapshot


def _to_snapshot_response(snapshot: PortfolioSnapshot) -> SnapshotResponse:
    """Build the API response for ``snapshot`` including its holdings."""
    return SnapshotResponse(
        id=snapshot.id,
        portfolio_id=snapshot.portfolio_id,
        total_value=snapshot.total_value,
        snapshot_date=snapshot.snapshot_date,
        holdings=[
            SnapshotHoldingResponse(
                ticker=h.ticker,
                quantity=h.quantity,
                price=h.price,
                value=h.value,
                current_ratio=h.current_ratio,
            )
            for h in snapshot.holdings
        ],
    )


def _percent_change(current: Decimal, base: Decimal) -> Decimal:
    """Return the percent change from ``base`` to ``current``, rounded to 0.01.

    Returns ``Decimal("0")`` when ``base`` is not positive, since a relative
    change from zero is undefined.
    """
    if base > 0:
        return ((current - base) / base * 100).quantize(_PERCENT_QUANTUM)
    return Decimal("0")


def _compute_cagr(
    first_value: Decimal,
    last_value: Decimal,
    start_date: date,
    end_date: date,
) -> Optional[Decimal]:
    """Return the compound annual growth rate in percent, or ``None``.

    CAGR = ((last / first) ** (1 / years) - 1) * 100, with years = days / 365.

    Returns ``None`` when the first value is not positive, when the period
    spans zero days, or when the annualized figure cannot be represented.
    Annualizing a short period can explode: doubling in one day is about
    2**365. That overflows float ``**`` (OverflowError) or exceeds the 28-digit
    Decimal context on quantize (InvalidOperation). Returning ``None`` avoids
    failing the whole request.
    """
    if first_value <= 0:
        return None

    days = (end_date - start_date).days
    if days <= 0:
        return None

    years = Decimal(days) / Decimal("365")
    ratio = last_value / first_value
    try:
        cagr_value = (float(ratio) ** (1 / float(years)) - 1) * 100
        return Decimal(str(cagr_value)).quantize(_PERCENT_QUANTUM)
    except (OverflowError, InvalidOperation):
        return None


@router.get("/{portfolio_id}/snapshots", response_model=List[SnapshotListItem])
async def list_snapshots(
    portfolio_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Get all snapshots for a portfolio, newest first.

    Raises:
        HTTPException: 404 if the portfolio is missing or not owned by the user.
    """
    _get_portfolio(db, portfolio_id, current_user.id)

    snapshots = (
        db.query(PortfolioSnapshot)
        .filter(PortfolioSnapshot.portfolio_id == portfolio_id)
        # id breaks ties between snapshots taken on the same date.
        .order_by(PortfolioSnapshot.snapshot_date.desc(), PortfolioSnapshot.id.desc())
        .all()
    )

    return snapshots


@router.post("/{portfolio_id}/snapshots", response_model=SnapshotResponse, status_code=status.HTTP_201_CREATED)
async def create_snapshot(
    portfolio_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Create a snapshot of current portfolio state.

    Prices every holding via ``PriceService.get_current_prices`` and stores a
    ``PortfolioSnapshot`` dated ``date.today()``. It also stores one
    ``SnapshotHolding`` per holding, carrying its value and percentage share
    of the total.

    Raises:
        HTTPException: 404 if the portfolio is missing or not owned by the
            user; 400 if the portfolio has no holdings.
    """
    portfolio = _get_portfolio(db, portfolio_id, current_user.id)

    if not portfolio.holdings:
        raise HTTPException(
            status_code=400,
            detail="Cannot create snapshot for empty portfolio",
        )

    # Get current prices
    price_service = PriceService(db)
    tickers = [h.ticker for h in portfolio.holdings]
    prices = price_service.get_current_prices(tickers)

    # Calculate total value.
    # NOTE(review): a ticker missing from ``prices`` is valued at 0, which
    # silently understates total_value -- confirm this is intended.
    total_value = Decimal("0")
    holding_values = []
    for holding in portfolio.holdings:
        price = prices.get(holding.ticker, Decimal("0"))
        value = price * holding.quantity
        total_value += value
        holding_values.append({
            "ticker": holding.ticker,
            "quantity": holding.quantity,
            "price": price,
            "value": value,
        })

    # NOTE(review): date.today() is the server's local date -- confirm the
    # intended timezone for snapshot dates.
    snapshot = PortfolioSnapshot(
        portfolio_id=portfolio_id,
        total_value=total_value,
        snapshot_date=date.today(),
    )
    db.add(snapshot)
    db.flush()  # Assigns snapshot.id so holdings can reference it.

    for hv in holding_values:
        # Percentage share of the total; 0 when the whole portfolio is worth 0.
        ratio = (hv["value"] / total_value * 100) if total_value > 0 else Decimal("0")
        db.add(SnapshotHolding(
            snapshot_id=snapshot.id,
            ticker=hv["ticker"],
            quantity=hv["quantity"],
            price=hv["price"],
            value=hv["value"],
            current_ratio=ratio.quantize(_PERCENT_QUANTUM),
        ))

    db.commit()
    db.refresh(snapshot)

    return _to_snapshot_response(snapshot)


@router.get("/{portfolio_id}/snapshots/{snapshot_id}", response_model=SnapshotResponse)
async def get_snapshot(
    portfolio_id: int,
    snapshot_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Get a specific snapshot with holdings.

    Raises:
        HTTPException: 404 if the portfolio is missing or not owned by the
            user, or if the snapshot does not belong to that portfolio.
    """
    _get_portfolio(db, portfolio_id, current_user.id)
    snapshot = _get_snapshot(db, portfolio_id, snapshot_id)
    return _to_snapshot_response(snapshot)


@router.delete("/{portfolio_id}/snapshots/{snapshot_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_snapshot(
    portfolio_id: int,
    snapshot_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Delete a snapshot and its holdings.

    Raises:
        HTTPException: 404 if the portfolio is missing or not owned by the
            user, or if the snapshot does not belong to that portfolio.
    """
    _get_portfolio(db, portfolio_id, current_user.id)
    snapshot = _get_snapshot(db, portfolio_id, snapshot_id)

    # Delete holdings first (cascade should handle this, but being explicit)
    db.query(SnapshotHolding).filter(
        SnapshotHolding.snapshot_id == snapshot_id
    ).delete()

    db.delete(snapshot)
    db.commit()

    return None


@router.get("/{portfolio_id}/returns", response_model=ReturnsResponse)
async def get_returns(
    portfolio_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Get portfolio returns over time based on snapshots.

    Each data point reports:
      * ``daily_return``: percent change vs. the previous snapshot. This is
        not necessarily one day earlier.
      * ``cumulative_return``: percent change vs. the first snapshot.

    ``total_return`` compares the last snapshot with the first. It is None
    when the first value is not positive. ``cagr`` is None whenever
    ``_compute_cagr`` cannot produce a finite figure.

    Raises:
        HTTPException: 404 if the portfolio is missing or not owned by the user.
    """
    _get_portfolio(db, portfolio_id, current_user.id)

    snapshots = (
        db.query(PortfolioSnapshot)
        .filter(PortfolioSnapshot.portfolio_id == portfolio_id)
        # id makes the order deterministic for same-day snapshots.
        .order_by(PortfolioSnapshot.snapshot_date, PortfolioSnapshot.id)
        .all()
    )

    if not snapshots:
        return ReturnsResponse(
            portfolio_id=portfolio_id,
            start_date=None,
            end_date=None,
            total_return=None,
            cagr=None,
            data=[],
        )

    # Normalise through str() so float- or Decimal-typed columns both give
    # exact Decimal arithmetic.
    values = [Decimal(str(s.total_value)) for s in snapshots]
    first_value = values[0]
    last_value = values[-1]

    data_points = []
    prev_value = first_value
    for snapshot, current_value in zip(snapshots, values):
        data_points.append(ReturnDataPoint(
            date=snapshot.snapshot_date,
            total_value=current_value,
            daily_return=_percent_change(current_value, prev_value),
            cumulative_return=_percent_change(current_value, first_value),
        ))
        prev_value = current_value

    start_date = snapshots[0].snapshot_date
    end_date = snapshots[-1].snapshot_date

    total_return = (
        _percent_change(last_value, first_value) if first_value > 0 else None
    )
    cagr = _compute_cagr(first_value, last_value, start_date, end_date)

    return ReturnsResponse(
        portfolio_id=portfolio_id,
        start_date=start_date,
        end_date=end_date,
        total_return=total_return,
        cagr=cagr,
        data=data_points,
    )