300 lines
8.8 KiB
Python
300 lines
8.8 KiB
Python
"""
Snapshot API endpoints for portfolio history tracking.
"""
|
|
from datetime import date
|
|
from decimal import Decimal
|
|
from typing import List
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.core.database import get_db
|
|
from app.api.deps import CurrentUser
|
|
from app.models.portfolio import Portfolio, PortfolioSnapshot, SnapshotHolding
|
|
from app.schemas.portfolio import (
|
|
SnapshotListItem, SnapshotResponse, SnapshotHoldingResponse,
|
|
ReturnsResponse, ReturnDataPoint,
|
|
)
|
|
from app.services.price_service import PriceService
|
|
from app.services.rebalance import RebalanceService
|
|
|
|
# Router for the endpoints defined in this module; every route registered on
# it is served under the "/api/portfolios" prefix and tagged "snapshots".
router = APIRouter(prefix="/api/portfolios", tags=["snapshots"])
|
|
|
|
|
|
def _get_portfolio(db: Session, portfolio_id: int, user_id: int) -> Portfolio:
    """Fetch a portfolio by id, restricted to the given owner.

    Args:
        db: Session used to run the lookup.
        portfolio_id: Primary key of the portfolio to load.
        user_id: Id of the user who must own the portfolio.

    Returns:
        The first portfolio matching both the id and the owner.

    Raises:
        HTTPException: 404 when no portfolio matches both conditions.
    """
    owned = (
        db.query(Portfolio)
        .filter(Portfolio.id == portfolio_id, Portfolio.user_id == user_id)
        .first()
    )
    if owned:
        return owned
    raise HTTPException(status_code=404, detail="Portfolio not found")
|
|
|
|
|
|
@router.get("/{portfolio_id}/snapshots", response_model=List[SnapshotListItem])
async def list_snapshots(
    portfolio_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return every snapshot of the portfolio, latest snapshot date first.

    The portfolio is looked up through ``_get_portfolio`` first, so a 404 is
    raised when it does not exist or is not owned by the current user.
    """
    # Ownership gate only; the portfolio object itself is not needed here.
    _get_portfolio(db, portfolio_id, current_user.id)

    latest_first = PortfolioSnapshot.snapshot_date.desc()
    query = db.query(PortfolioSnapshot).filter(
        PortfolioSnapshot.portfolio_id == portfolio_id
    )
    return query.order_by(latest_first).all()
|
|
|
|
|
|
@router.post("/{portfolio_id}/snapshots", response_model=SnapshotResponse, status_code=status.HTTP_201_CREATED)
async def create_snapshot(
    portfolio_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Persist a snapshot of the portfolio's holdings valued at current prices.

    Each holding is valued as ``price * quantity`` using the prices returned
    by ``PriceService.get_current_prices``; a ticker missing from that result
    is priced at zero. The snapshot is dated ``date.today()`` and stores, per
    holding, its share of the total value in percent (two decimals).

    Raises:
        HTTPException: 404 (via ``_get_portfolio``) when the portfolio is not
            found for the current user; 400 when it has no holdings.
    """
    portfolio = _get_portfolio(db, portfolio_id, current_user.id)
    positions = portfolio.holdings
    if not positions:
        raise HTTPException(
            status_code=400,
            detail="Cannot create snapshot for empty portfolio",
        )

    tickers = [position.ticker for position in positions]
    prices = PriceService(db).get_current_prices(tickers)

    zero = Decimal("0")

    # Value every position; rows are (ticker, quantity, price, value).
    rows = []
    for position in positions:
        price = prices.get(position.ticker, zero)
        rows.append(
            (position.ticker, position.quantity, price, price * position.quantity)
        )
    total_value = sum((row[3] for row in rows), zero)

    snapshot = PortfolioSnapshot(
        portfolio_id=portfolio_id,
        total_value=total_value,
        snapshot_date=date.today(),
    )
    db.add(snapshot)
    db.flush()  # populates snapshot.id, which the holding rows reference

    two_places = Decimal("0.01")
    has_value = total_value > 0
    for ticker, quantity, price, value in rows:
        # Share of the total in percent; zero when the total is not positive.
        share = value / total_value * 100 if has_value else zero
        db.add(
            SnapshotHolding(
                snapshot_id=snapshot.id,
                ticker=ticker,
                quantity=quantity,
                price=price,
                value=value,
                current_ratio=share.quantize(two_places),
            )
        )

    db.commit()
    db.refresh(snapshot)

    names = RebalanceService(db).get_stock_names(tickers)
    holdings = [
        SnapshotHoldingResponse(
            ticker=row.ticker,
            name=names.get(row.ticker),
            quantity=row.quantity,
            price=row.price,
            value=row.value,
            current_ratio=row.current_ratio,
        )
        for row in snapshot.holdings
    ]
    return SnapshotResponse(
        id=snapshot.id,
        portfolio_id=snapshot.portfolio_id,
        total_value=snapshot.total_value,
        snapshot_date=snapshot.snapshot_date,
        holdings=holdings,
    )
|
|
|
|
|
|
@router.get("/{portfolio_id}/snapshots/{snapshot_id}", response_model=SnapshotResponse)
async def get_snapshot(
    portfolio_id: int,
    snapshot_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return one snapshot of the portfolio together with its holdings.

    Each holding carries the name looked up for its ticker through
    ``RebalanceService.get_stock_names`` (``None`` when the lookup has no
    entry for that ticker).

    Raises:
        HTTPException: 404 (via ``_get_portfolio``) when the portfolio is not
            found for the current user, or 404 when the snapshot does not
            exist within that portfolio.
    """
    _get_portfolio(db, portfolio_id, current_user.id)

    found = (
        db.query(PortfolioSnapshot)
        .filter(PortfolioSnapshot.id == snapshot_id)
        .filter(PortfolioSnapshot.portfolio_id == portfolio_id)
        .first()
    )
    if not found:
        raise HTTPException(status_code=404, detail="Snapshot not found")

    rows = found.holdings
    names = RebalanceService(db).get_stock_names([row.ticker for row in rows])

    holdings = [
        SnapshotHoldingResponse(
            ticker=row.ticker,
            name=names.get(row.ticker),
            quantity=row.quantity,
            price=row.price,
            value=row.value,
            current_ratio=row.current_ratio,
        )
        for row in rows
    ]
    return SnapshotResponse(
        id=found.id,
        portfolio_id=found.portfolio_id,
        total_value=found.total_value,
        snapshot_date=found.snapshot_date,
        holdings=holdings,
    )
|
|
|
|
|
|
@router.delete("/{portfolio_id}/snapshots/{snapshot_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_snapshot(
    portfolio_id: int,
    snapshot_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Delete one snapshot of the portfolio along with its holding rows.

    Raises:
        HTTPException: 404 (via ``_get_portfolio``) when the portfolio is not
            found for the current user, or 404 when the snapshot does not
            exist within that portfolio.
    """
    _get_portfolio(db, portfolio_id, current_user.id)

    target = (
        db.query(PortfolioSnapshot)
        .filter(PortfolioSnapshot.id == snapshot_id)
        .filter(PortfolioSnapshot.portfolio_id == portfolio_id)
        .first()
    )
    if not target:
        raise HTTPException(status_code=404, detail="Snapshot not found")

    # Remove the holding rows explicitly before the snapshot itself; a cascade
    # on the relationship should cover this, but the order is kept explicit.
    holdings_of_snapshot = db.query(SnapshotHolding).filter(
        SnapshotHolding.snapshot_id == snapshot_id
    )
    holdings_of_snapshot.delete()

    db.delete(target)
    db.commit()
|
|
|
|
|
|
def _percent_change(current: Decimal, base: Decimal) -> Decimal:
    """Return the change from ``base`` to ``current`` in percent (2 decimals).

    A non-positive ``base`` yields ``Decimal("0")`` because the relative
    change is undefined in that case.
    """
    if base > 0:
        return ((current - base) / base * 100).quantize(Decimal("0.01"))
    return Decimal("0")


def _compute_cagr(first_value: Decimal, last_value: Decimal, days: int) -> "Decimal | None":
    """Return the compound annual growth rate in percent, or ``None``.

    CAGR = (last / first) ** (1 / years) - 1 with years = days / 365, given
    as a percentage rounded to two decimals.

    ``None`` is returned when the rate is undefined or cannot be represented:
    a non-positive starting value, a negative ending value, a non-positive
    period, or an annualised figure too large to compute or to round.
    """
    if first_value <= 0 or last_value < 0 or days <= 0:
        return None

    years = Decimal(str(days)) / Decimal("365")
    ratio = last_value / first_value
    try:
        # Annualising a short period raises the ratio to a large power
        # (e.g. 365 for snapshots one day apart). That can overflow the float
        # power (OverflowError) or produce a number with more digits than the
        # decimal context allows when rounded (decimal.InvalidOperation).
        # Both are ArithmeticError subclasses.
        cagr_value = (float(ratio) ** (1 / float(years)) - 1) * 100
        return Decimal(str(cagr_value)).quantize(Decimal("0.01"))
    except ArithmeticError:
        return None


@router.get("/{portfolio_id}/returns", response_model=ReturnsResponse)
async def get_returns(
    portfolio_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return the portfolio's return series derived from its snapshots.

    Snapshots are processed oldest first. For every snapshot the response
    holds its total value, ``daily_return`` (percent change versus the
    previous snapshot, which is not necessarily one day earlier; 0 for the
    first snapshot) and ``cumulative_return`` (percent change versus the
    first snapshot).

    ``total_return`` is ``None`` when the first snapshot's value is not
    positive. ``cagr`` is ``None`` in that case as well, and also when the
    first and last snapshot share the same date, when the last value is
    negative, or when the annualised rate is too large to be represented.
    With no snapshots, every summary field is ``None`` and ``data`` is empty.

    Raises:
        HTTPException: 404 (via ``_get_portfolio``) when the portfolio is not
            found for the current user.
    """
    _get_portfolio(db, portfolio_id, current_user.id)

    snapshots = (
        db.query(PortfolioSnapshot)
        .filter(PortfolioSnapshot.portfolio_id == portfolio_id)
        # The id tie-breaker keeps the order of same-date snapshots stable.
        .order_by(PortfolioSnapshot.snapshot_date, PortfolioSnapshot.id)
        .all()
    )

    if not snapshots:
        return ReturnsResponse(
            portfolio_id=portfolio_id,
            start_date=None,
            end_date=None,
            total_return=None,
            cagr=None,
            data=[],
        )

    first_value = Decimal(str(snapshots[0].total_value))
    prev_value = first_value
    data_points = []

    for snapshot in snapshots:
        current_value = Decimal(str(snapshot.total_value))
        data_points.append(
            ReturnDataPoint(
                date=snapshot.snapshot_date,
                total_value=current_value,
                daily_return=_percent_change(current_value, prev_value),
                cumulative_return=_percent_change(current_value, first_value),
            )
        )
        prev_value = current_value

    start_date = snapshots[0].snapshot_date
    end_date = snapshots[-1].snapshot_date
    last_value = Decimal(str(snapshots[-1].total_value))

    total_return = None
    cagr = None
    if first_value > 0:
        total_return = _percent_change(last_value, first_value)
        cagr = _compute_cagr(first_value, last_value, (end_date - start_date).days)

    return ReturnsResponse(
        portfolio_id=portfolio_id,
        start_date=start_date,
        end_date=end_date,
        total_return=total_return,
        cagr=cagr,
        data=data_points,
    )
|