""" Snapshot API endpoints for portfolio history tracking. """ from datetime import date from decimal import Decimal from typing import List from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session from app.core.database import get_db from app.api.deps import CurrentUser from app.models.portfolio import Portfolio, PortfolioSnapshot, SnapshotHolding from app.models.stock import ETFPrice from app.schemas.portfolio import ( SnapshotListItem, SnapshotResponse, SnapshotHoldingResponse, ReturnsResponse, ReturnDataPoint, ) from app.services.price_service import PriceService from app.services.rebalance import RebalanceService router = APIRouter(prefix="/api/portfolios", tags=["snapshots"]) def _get_portfolio(db: Session, portfolio_id: int, user_id: int) -> Portfolio: """Helper to get portfolio with ownership check.""" portfolio = db.query(Portfolio).filter( Portfolio.id == portfolio_id, Portfolio.user_id == user_id, ).first() if not portfolio: raise HTTPException(status_code=404, detail="Portfolio not found") return portfolio @router.get("/{portfolio_id}/snapshots", response_model=List[SnapshotListItem]) async def list_snapshots( portfolio_id: int, current_user: CurrentUser, db: Session = Depends(get_db), ): """Get all snapshots for a portfolio.""" _get_portfolio(db, portfolio_id, current_user.id) snapshots = ( db.query(PortfolioSnapshot) .filter(PortfolioSnapshot.portfolio_id == portfolio_id) .order_by(PortfolioSnapshot.snapshot_date.desc()) .all() ) return snapshots @router.post("/{portfolio_id}/snapshots", response_model=SnapshotResponse, status_code=status.HTTP_201_CREATED) async def create_snapshot( portfolio_id: int, current_user: CurrentUser, db: Session = Depends(get_db), ): """Create a snapshot of current portfolio state.""" portfolio = _get_portfolio(db, portfolio_id, current_user.id) if not portfolio.holdings: raise HTTPException( status_code=400, detail="Cannot create snapshot for empty portfolio" ) # Get current prices price_service = PriceService(db) tickers = [h.ticker for h in portfolio.holdings] prices = price_service.get_current_prices(tickers) # Calculate total value total_value = Decimal("0") holding_values = [] for holding in portfolio.holdings: price = prices.get(holding.ticker, Decimal("0")) value = price * holding.quantity total_value += value holding_values.append({ "ticker": holding.ticker, "quantity": holding.quantity, "price": price, "value": value, }) # Create snapshot snapshot = PortfolioSnapshot( portfolio_id=portfolio_id, total_value=total_value, snapshot_date=date.today(), ) db.add(snapshot) db.flush() # Get snapshot ID # Create snapshot holdings for hv in holding_values: ratio = (hv["value"] / total_value * 100) if total_value > 0 else Decimal("0") snapshot_holding = SnapshotHolding( snapshot_id=snapshot.id, ticker=hv["ticker"], quantity=hv["quantity"], price=hv["price"], value=hv["value"], current_ratio=ratio.quantize(Decimal("0.01")), ) db.add(snapshot_holding) db.commit() db.refresh(snapshot) # Get stock names name_service = RebalanceService(db) names = name_service.get_stock_names(tickers) return SnapshotResponse( id=snapshot.id, portfolio_id=snapshot.portfolio_id, total_value=snapshot.total_value, snapshot_date=snapshot.snapshot_date, holdings=[ SnapshotHoldingResponse( ticker=h.ticker, name=names.get(h.ticker), quantity=h.quantity, price=h.price, value=h.value, current_ratio=h.current_ratio, ) for h in snapshot.holdings ], ) @router.get("/{portfolio_id}/snapshots/{snapshot_id}", response_model=SnapshotResponse) async def get_snapshot( portfolio_id: int, snapshot_id: int, current_user: CurrentUser, db: Session = Depends(get_db), ): """Get a specific snapshot with holdings.""" _get_portfolio(db, portfolio_id, current_user.id) snapshot = ( db.query(PortfolioSnapshot) .filter( PortfolioSnapshot.id == snapshot_id, PortfolioSnapshot.portfolio_id == portfolio_id, ) .first() ) if not snapshot: raise HTTPException(status_code=404, detail="Snapshot not found") # Get stock names tickers = [h.ticker for h in snapshot.holdings] name_service = RebalanceService(db) names = name_service.get_stock_names(tickers) return SnapshotResponse( id=snapshot.id, portfolio_id=snapshot.portfolio_id, total_value=snapshot.total_value, snapshot_date=snapshot.snapshot_date, holdings=[ SnapshotHoldingResponse( ticker=h.ticker, name=names.get(h.ticker), quantity=h.quantity, price=h.price, value=h.value, current_ratio=h.current_ratio, ) for h in snapshot.holdings ], ) @router.delete("/{portfolio_id}/snapshots/{snapshot_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_snapshot( portfolio_id: int, snapshot_id: int, current_user: CurrentUser, db: Session = Depends(get_db), ): """Delete a snapshot.""" _get_portfolio(db, portfolio_id, current_user.id) snapshot = ( db.query(PortfolioSnapshot) .filter( PortfolioSnapshot.id == snapshot_id, PortfolioSnapshot.portfolio_id == portfolio_id, ) .first() ) if not snapshot: raise HTTPException(status_code=404, detail="Snapshot not found") # Delete holdings first (cascade should handle this, but being explicit) db.query(SnapshotHolding).filter( SnapshotHolding.snapshot_id == snapshot_id ).delete() db.delete(snapshot) db.commit() return None @router.get("/{portfolio_id}/returns", response_model=ReturnsResponse) async def get_returns( portfolio_id: int, current_user: CurrentUser, db: Session = Depends(get_db), ): """Get portfolio returns over time based on snapshots.""" _get_portfolio(db, portfolio_id, current_user.id) snapshots = ( db.query(PortfolioSnapshot) .filter(PortfolioSnapshot.portfolio_id == portfolio_id) .order_by(PortfolioSnapshot.snapshot_date) .all() ) if not snapshots: return ReturnsResponse( portfolio_id=portfolio_id, start_date=None, end_date=None, total_return=None, cagr=None, data=[], ) # Get benchmark (KOSPI ETF 069500) prices for the same date range snapshot_dates = [s.snapshot_date for s in snapshots] benchmark_ticker = "069500" # KODEX 200 (KOSPI benchmark) benchmark_prices = ( db.query(ETFPrice) .filter( ETFPrice.ticker == benchmark_ticker, ETFPrice.date.in_(snapshot_dates), ) .all() ) benchmark_map = {bp.date: Decimal(str(bp.close)) for bp in benchmark_prices} # Get first benchmark price for cumulative calculation first_benchmark = benchmark_map.get(snapshots[0].snapshot_date) # Calculate returns data_points = [] first_value = Decimal(str(snapshots[0].total_value)) prev_value = first_value for snapshot in snapshots: current_value = Decimal(str(snapshot.total_value)) # Daily return (vs previous snapshot) if prev_value > 0: daily_return = ((current_value - prev_value) / prev_value * 100).quantize(Decimal("0.01")) else: daily_return = Decimal("0") # Cumulative return (vs first snapshot) if first_value > 0: cumulative_return = ((current_value - first_value) / first_value * 100).quantize(Decimal("0.01")) else: cumulative_return = Decimal("0") # Benchmark cumulative return benchmark_return = None bench_price = benchmark_map.get(snapshot.snapshot_date) if bench_price and first_benchmark and first_benchmark > 0: benchmark_return = ((bench_price - first_benchmark) / first_benchmark * 100).quantize(Decimal("0.01")) data_points.append(ReturnDataPoint( date=snapshot.snapshot_date, total_value=current_value, daily_return=daily_return, cumulative_return=cumulative_return, benchmark_return=benchmark_return, )) prev_value = current_value # Calculate total return and CAGR start_date = snapshots[0].snapshot_date end_date = snapshots[-1].snapshot_date last_value = Decimal(str(snapshots[-1].total_value)) total_return = None cagr = None benchmark_total_return = None if first_value > 0: total_return = ((last_value - first_value) / first_value * 100).quantize(Decimal("0.01")) # CAGR calculation days = (end_date - start_date).days if days > 0: years = Decimal(str(days)) / Decimal("365") if years > 0: ratio = last_value / first_value # CAGR = (ending/beginning)^(1/years) - 1 cagr_value = (float(ratio) ** (1 / float(years)) - 1) * 100 cagr = Decimal(str(cagr_value)).quantize(Decimal("0.01")) # Benchmark total return last_benchmark = benchmark_map.get(end_date) if first_benchmark and last_benchmark and first_benchmark > 0: benchmark_total_return = ((last_benchmark - first_benchmark) / first_benchmark * 100).quantize(Decimal("0.01")) return ReturnsResponse( portfolio_id=portfolio_id, start_date=start_date, end_date=end_date, total_return=total_return, cagr=cagr, benchmark_total_return=benchmark_total_return, data=data_points, )