galaxis-po/backend/app/api/snapshot.py

331 lines
10 KiB
Python
Raw Normal View History

"""
Snapshot API endpoints for portfolio history tracking.
"""
from datetime import date
from decimal import Decimal, InvalidOperation
from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session

from app.core.database import get_db
from app.api.deps import CurrentUser
from app.models.portfolio import Portfolio, PortfolioSnapshot, SnapshotHolding
from app.models.stock import ETFPrice
from app.schemas.portfolio import (
    SnapshotListItem, SnapshotResponse, SnapshotHoldingResponse,
    ReturnsResponse, ReturnDataPoint,
)
from app.services.price_service import PriceService
from app.services.rebalance import RebalanceService
# Every endpoint in this module is mounted under /api/portfolios and grouped
# under the "snapshots" tag.
router = APIRouter(tags=["snapshots"], prefix="/api/portfolios")
def _get_portfolio(db: Session, portfolio_id: int, user_id: int) -> Portfolio:
    """Load the portfolio with id ``portfolio_id`` owned by ``user_id``.

    Raises:
        HTTPException: 404 when no row matches both the id and the owner, so a
            portfolio belonging to another user looks the same as a missing one.
    """
    owned = db.query(Portfolio).filter(
        Portfolio.id == portfolio_id,
        Portfolio.user_id == user_id,
    )
    match = owned.first()
    if not match:
        raise HTTPException(status_code=404, detail="Portfolio not found")
    return match
@router.get("/{portfolio_id}/snapshots", response_model=List[SnapshotListItem])
async def list_snapshots(
    portfolio_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """List every snapshot of the caller's portfolio, most recent date first.

    Raises:
        HTTPException: 404 if the portfolio does not exist or is not owned by
            the current user.
    """
    # Called for its ownership check only; the portfolio object is unused.
    _get_portfolio(db, portfolio_id, current_user.id)
    newest_first = PortfolioSnapshot.snapshot_date.desc()
    query = db.query(PortfolioSnapshot).filter(
        PortfolioSnapshot.portfolio_id == portfolio_id
    )
    return query.order_by(newest_first).all()
@router.post("/{portfolio_id}/snapshots", response_model=SnapshotResponse, status_code=status.HTTP_201_CREATED)
async def create_snapshot(
    portfolio_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Persist today's snapshot of the portfolio's holdings valued at current prices.

    Each holding row stores quantity, unit price, market value and its share
    of the total value as a percentage rounded to two decimals.

    Raises:
        HTTPException: 404 if the portfolio is not the caller's; 400 if the
            portfolio has no holdings.
    """
    portfolio = _get_portfolio(db, portfolio_id, current_user.id)
    holdings = portfolio.holdings
    if not holdings:
        raise HTTPException(
            status_code=400,
            detail="Cannot create snapshot for empty portfolio",
        )

    pricer = PriceService(db)
    tickers = [h.ticker for h in holdings]
    quotes = pricer.get_current_prices(tickers)

    # A ticker with no quote is valued at zero rather than rejected.
    rows = []
    for h in holdings:
        unit_price = quotes.get(h.ticker, Decimal("0"))
        rows.append((h.ticker, h.quantity, unit_price, unit_price * h.quantity))
    total_value = sum((row[3] for row in rows), Decimal("0"))

    snapshot = PortfolioSnapshot(
        portfolio_id=portfolio_id,
        total_value=total_value,
        snapshot_date=date.today(),
    )
    db.add(snapshot)
    db.flush()  # assigns snapshot.id so the holding rows can reference it

    for ticker, quantity, unit_price, value in rows:
        if total_value > 0:
            share = value / total_value * 100
        else:
            share = Decimal("0")
        db.add(SnapshotHolding(
            snapshot_id=snapshot.id,
            ticker=ticker,
            quantity=quantity,
            price=unit_price,
            value=value,
            current_ratio=share.quantize(Decimal("0.01")),
        ))

    db.commit()
    db.refresh(snapshot)

    names = RebalanceService(db).get_stock_names(tickers)
    holding_items = [
        SnapshotHoldingResponse(
            ticker=sh.ticker,
            name=names.get(sh.ticker),
            quantity=sh.quantity,
            price=sh.price,
            value=sh.value,
            current_ratio=sh.current_ratio,
        )
        for sh in snapshot.holdings
    ]
    return SnapshotResponse(
        id=snapshot.id,
        portfolio_id=snapshot.portfolio_id,
        total_value=snapshot.total_value,
        snapshot_date=snapshot.snapshot_date,
        holdings=holding_items,
    )
@router.get("/{portfolio_id}/snapshots/{snapshot_id}", response_model=SnapshotResponse)
async def get_snapshot(
    portfolio_id: int,
    snapshot_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return one snapshot of the caller's portfolio, with stock names attached to its holdings.

    Raises:
        HTTPException: 404 if the portfolio is not the caller's, or if no
            snapshot with ``snapshot_id`` belongs to that portfolio.
    """
    _get_portfolio(db, portfolio_id, current_user.id)
    snapshot = db.query(PortfolioSnapshot).filter(
        PortfolioSnapshot.id == snapshot_id,
        PortfolioSnapshot.portfolio_id == portfolio_id,
    ).first()
    if not snapshot:
        raise HTTPException(status_code=404, detail="Snapshot not found")

    held = snapshot.holdings
    names = RebalanceService(db).get_stock_names([row.ticker for row in held])
    items = [
        SnapshotHoldingResponse(
            ticker=row.ticker,
            name=names.get(row.ticker),
            quantity=row.quantity,
            price=row.price,
            value=row.value,
            current_ratio=row.current_ratio,
        )
        for row in held
    ]
    return SnapshotResponse(
        id=snapshot.id,
        portfolio_id=snapshot.portfolio_id,
        total_value=snapshot.total_value,
        snapshot_date=snapshot.snapshot_date,
        holdings=items,
    )
@router.delete("/{portfolio_id}/snapshots/{snapshot_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_snapshot(
    portfolio_id: int,
    snapshot_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Remove a snapshot, and its holding rows, from the caller's portfolio.

    Raises:
        HTTPException: 404 if the portfolio is not the caller's, or if the
            snapshot does not belong to that portfolio.
    """
    _get_portfolio(db, portfolio_id, current_user.id)
    target = db.query(PortfolioSnapshot).filter(
        PortfolioSnapshot.id == snapshot_id,
        PortfolioSnapshot.portfolio_id == portfolio_id,
    ).first()
    if not target:
        raise HTTPException(status_code=404, detail="Snapshot not found")

    # Child rows are removed with an explicit bulk DELETE before the parent.
    # Relationship cascade may already cover this; the explicit delete keeps it
    # independent of how the relationship is configured.
    child_rows = db.query(SnapshotHolding).filter(
        SnapshotHolding.snapshot_id == snapshot_id
    )
    child_rows.delete()
    db.delete(target)
    db.commit()
    return None
def _pct_change(current: Decimal, base: Decimal) -> Decimal:
    """Return the percentage change from ``base`` to ``current``, rounded to 0.01.

    ``base`` must be non-zero; every caller checks it is positive first.
    """
    return ((current - base) / base * 100).quantize(Decimal("0.01"))


def _compute_cagr(first_value: Decimal, last_value: Decimal, days: int) -> Optional[Decimal]:
    """Return the compound annual growth rate in percent, rounded to 0.01.

    Uses a 365-day year: CAGR = (last / first) ** (1 / years) - 1.

    Returns None when ``days`` or ``first_value`` is not positive, or when the
    annualized figure cannot be represented. Very short periods with large value
    changes annualize to astronomically large numbers.
    """
    if days <= 0 or first_value <= 0:
        return None
    years = Decimal(str(days)) / Decimal("365")
    ratio = last_value / first_value
    try:
        cagr_value = (float(ratio) ** (1 / float(years)) - 1) * 100
        return Decimal(str(cagr_value)).quantize(Decimal("0.01"))
    except (OverflowError, InvalidOperation):
        # OverflowError: the float power exceeded the float range.
        # InvalidOperation: the value needs more digits than the Decimal context
        # precision allows for quantize (or is inf/non-real).
        # Report "no CAGR" instead of failing the whole request with a 500.
        return None


@router.get("/{portfolio_id}/returns", response_model=ReturnsResponse)
async def get_returns(
    portfolio_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Get portfolio returns over time based on snapshots.

    Snapshots are processed in ascending date order. Each data point reports:
      * ``daily_return``: change vs the previous snapshot (which may be more
        than one day earlier);
      * ``cumulative_return``: change vs the first snapshot;
      * ``benchmark_return``: change in the KODEX 200 ETF (069500) close since
        the first snapshot date.
    All returns are percentages rounded to 0.01.

    The summary includes ``total_return``, ``cagr`` (None if it cannot be
    computed) and ``benchmark_total_return``.

    Raises:
        HTTPException: 404 if the portfolio is not the caller's.
    """
    _get_portfolio(db, portfolio_id, current_user.id)
    snapshots = (
        db.query(PortfolioSnapshot)
        .filter(PortfolioSnapshot.portfolio_id == portfolio_id)
        .order_by(PortfolioSnapshot.snapshot_date)
        .all()
    )
    if not snapshots:
        return ReturnsResponse(
            portfolio_id=portfolio_id,
            start_date=None,
            end_date=None,
            total_return=None,
            cagr=None,
            benchmark_total_return=None,
            data=[],
        )

    # Benchmark: KODEX 200 ETF (069500) as a KOSPI proxy. Only closes whose
    # date exactly matches a snapshot date are used. A snapshot taken on a date
    # with no stored close gets no benchmark value.
    snapshot_dates = [s.snapshot_date for s in snapshots]
    benchmark_ticker = "069500"
    benchmark_prices = (
        db.query(ETFPrice)
        .filter(
            ETFPrice.ticker == benchmark_ticker,
            ETFPrice.date.in_(snapshot_dates),
        )
        .all()
    )
    benchmark_map = {bp.date: Decimal(str(bp.close)) for bp in benchmark_prices}
    first_benchmark = benchmark_map.get(snapshots[0].snapshot_date)

    first_value = Decimal(str(snapshots[0].total_value))
    prev_value = first_value
    data_points = []
    for snapshot in snapshots:
        current_value = Decimal(str(snapshot.total_value))

        daily_return = (
            _pct_change(current_value, prev_value) if prev_value > 0 else Decimal("0")
        )
        cumulative_return = (
            _pct_change(current_value, first_value) if first_value > 0 else Decimal("0")
        )

        benchmark_return = None
        bench_price = benchmark_map.get(snapshot.snapshot_date)
        if bench_price and first_benchmark and first_benchmark > 0:
            benchmark_return = _pct_change(bench_price, first_benchmark)

        data_points.append(ReturnDataPoint(
            date=snapshot.snapshot_date,
            total_value=current_value,
            daily_return=daily_return,
            cumulative_return=cumulative_return,
            benchmark_return=benchmark_return,
        ))
        prev_value = current_value

    start_date = snapshots[0].snapshot_date
    end_date = snapshots[-1].snapshot_date
    last_value = Decimal(str(snapshots[-1].total_value))

    total_return = None
    cagr = None
    if first_value > 0:
        total_return = _pct_change(last_value, first_value)
        cagr = _compute_cagr(first_value, last_value, (end_date - start_date).days)

    benchmark_total_return = None
    last_benchmark = benchmark_map.get(end_date)
    if first_benchmark and last_benchmark and first_benchmark > 0:
        benchmark_total_return = _pct_change(last_benchmark, first_benchmark)

    return ReturnsResponse(
        portfolio_id=portfolio_id,
        start_date=start_date,
        end_date=end_date,
        total_return=total_return,
        cagr=cagr,
        benchmark_total_return=benchmark_total_return,
        data=data_points,
    )