feat: add PriceService and snapshot API endpoints
- PriceService: Mock implementation using DB prices - Snapshot schemas: SnapshotListItem, ReturnsResponse, ReturnDataPoint - Snapshot API: list, create, get, delete snapshots - Returns API: portfolio returns calculation with CAGR Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
63ffe2439e
commit
8842928363
@ -4,5 +4,14 @@ from app.api.portfolio import router as portfolio_router
|
||||
from app.api.strategy import router as strategy_router
|
||||
from app.api.market import router as market_router
|
||||
from app.api.backtest import router as backtest_router
|
||||
from app.api.snapshot import router as snapshot_router
|
||||
|
||||
__all__ = ["auth_router", "admin_router", "portfolio_router", "strategy_router", "market_router", "backtest_router"]
|
||||
__all__ = [
|
||||
"auth_router",
|
||||
"admin_router",
|
||||
"portfolio_router",
|
||||
"strategy_router",
|
||||
"market_router",
|
||||
"backtest_router",
|
||||
"snapshot_router",
|
||||
]
|
||||
|
||||
287
backend/app/api/snapshot.py
Normal file
287
backend/app/api/snapshot.py
Normal file
@ -0,0 +1,287 @@
|
||||
"""
|
||||
Snapshot API endpoints for portfolio history tracking.
|
||||
"""
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.api.deps import CurrentUser
|
||||
from app.models.portfolio import Portfolio, PortfolioSnapshot, SnapshotHolding
|
||||
from app.schemas.portfolio import (
|
||||
SnapshotListItem, SnapshotResponse, SnapshotHoldingResponse,
|
||||
ReturnsResponse, ReturnDataPoint,
|
||||
)
|
||||
from app.services.price_service import PriceService
|
||||
|
||||
router = APIRouter(prefix="/api/portfolios", tags=["snapshots"])
|
||||
|
||||
|
||||
def _get_portfolio(db: Session, portfolio_id: int, user_id: int) -> Portfolio:
|
||||
"""Helper to get portfolio with ownership check."""
|
||||
portfolio = db.query(Portfolio).filter(
|
||||
Portfolio.id == portfolio_id,
|
||||
Portfolio.user_id == user_id,
|
||||
).first()
|
||||
if not portfolio:
|
||||
raise HTTPException(status_code=404, detail="Portfolio not found")
|
||||
return portfolio
|
||||
|
||||
|
||||
@router.get("/{portfolio_id}/snapshots", response_model=List[SnapshotListItem])
|
||||
async def list_snapshots(
|
||||
portfolio_id: int,
|
||||
current_user: CurrentUser,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get all snapshots for a portfolio."""
|
||||
_get_portfolio(db, portfolio_id, current_user.id)
|
||||
|
||||
snapshots = (
|
||||
db.query(PortfolioSnapshot)
|
||||
.filter(PortfolioSnapshot.portfolio_id == portfolio_id)
|
||||
.order_by(PortfolioSnapshot.snapshot_date.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
return snapshots
|
||||
|
||||
|
||||
@router.post("/{portfolio_id}/snapshots", response_model=SnapshotResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_snapshot(
|
||||
portfolio_id: int,
|
||||
current_user: CurrentUser,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Create a snapshot of current portfolio state."""
|
||||
portfolio = _get_portfolio(db, portfolio_id, current_user.id)
|
||||
|
||||
if not portfolio.holdings:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot create snapshot for empty portfolio"
|
||||
)
|
||||
|
||||
# Get current prices
|
||||
price_service = PriceService(db)
|
||||
tickers = [h.ticker for h in portfolio.holdings]
|
||||
prices = price_service.get_current_prices(tickers)
|
||||
|
||||
# Calculate total value
|
||||
total_value = Decimal("0")
|
||||
holding_values = []
|
||||
|
||||
for holding in portfolio.holdings:
|
||||
price = prices.get(holding.ticker, Decimal("0"))
|
||||
value = price * holding.quantity
|
||||
total_value += value
|
||||
holding_values.append({
|
||||
"ticker": holding.ticker,
|
||||
"quantity": holding.quantity,
|
||||
"price": price,
|
||||
"value": value,
|
||||
})
|
||||
|
||||
# Create snapshot
|
||||
snapshot = PortfolioSnapshot(
|
||||
portfolio_id=portfolio_id,
|
||||
total_value=total_value,
|
||||
snapshot_date=date.today(),
|
||||
)
|
||||
db.add(snapshot)
|
||||
db.flush() # Get snapshot ID
|
||||
|
||||
# Create snapshot holdings
|
||||
for hv in holding_values:
|
||||
ratio = (hv["value"] / total_value * 100) if total_value > 0 else Decimal("0")
|
||||
snapshot_holding = SnapshotHolding(
|
||||
snapshot_id=snapshot.id,
|
||||
ticker=hv["ticker"],
|
||||
quantity=hv["quantity"],
|
||||
price=hv["price"],
|
||||
value=hv["value"],
|
||||
current_ratio=ratio.quantize(Decimal("0.01")),
|
||||
)
|
||||
db.add(snapshot_holding)
|
||||
|
||||
db.commit()
|
||||
db.refresh(snapshot)
|
||||
|
||||
return SnapshotResponse(
|
||||
id=snapshot.id,
|
||||
portfolio_id=snapshot.portfolio_id,
|
||||
total_value=snapshot.total_value,
|
||||
snapshot_date=snapshot.snapshot_date,
|
||||
holdings=[
|
||||
SnapshotHoldingResponse(
|
||||
ticker=h.ticker,
|
||||
quantity=h.quantity,
|
||||
price=h.price,
|
||||
value=h.value,
|
||||
current_ratio=h.current_ratio,
|
||||
)
|
||||
for h in snapshot.holdings
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{portfolio_id}/snapshots/{snapshot_id}", response_model=SnapshotResponse)
|
||||
async def get_snapshot(
|
||||
portfolio_id: int,
|
||||
snapshot_id: int,
|
||||
current_user: CurrentUser,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get a specific snapshot with holdings."""
|
||||
_get_portfolio(db, portfolio_id, current_user.id)
|
||||
|
||||
snapshot = (
|
||||
db.query(PortfolioSnapshot)
|
||||
.filter(
|
||||
PortfolioSnapshot.id == snapshot_id,
|
||||
PortfolioSnapshot.portfolio_id == portfolio_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not snapshot:
|
||||
raise HTTPException(status_code=404, detail="Snapshot not found")
|
||||
|
||||
return SnapshotResponse(
|
||||
id=snapshot.id,
|
||||
portfolio_id=snapshot.portfolio_id,
|
||||
total_value=snapshot.total_value,
|
||||
snapshot_date=snapshot.snapshot_date,
|
||||
holdings=[
|
||||
SnapshotHoldingResponse(
|
||||
ticker=h.ticker,
|
||||
quantity=h.quantity,
|
||||
price=h.price,
|
||||
value=h.value,
|
||||
current_ratio=h.current_ratio,
|
||||
)
|
||||
for h in snapshot.holdings
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{portfolio_id}/snapshots/{snapshot_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_snapshot(
|
||||
portfolio_id: int,
|
||||
snapshot_id: int,
|
||||
current_user: CurrentUser,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Delete a snapshot."""
|
||||
_get_portfolio(db, portfolio_id, current_user.id)
|
||||
|
||||
snapshot = (
|
||||
db.query(PortfolioSnapshot)
|
||||
.filter(
|
||||
PortfolioSnapshot.id == snapshot_id,
|
||||
PortfolioSnapshot.portfolio_id == portfolio_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not snapshot:
|
||||
raise HTTPException(status_code=404, detail="Snapshot not found")
|
||||
|
||||
# Delete holdings first (cascade should handle this, but being explicit)
|
||||
db.query(SnapshotHolding).filter(
|
||||
SnapshotHolding.snapshot_id == snapshot_id
|
||||
).delete()
|
||||
|
||||
db.delete(snapshot)
|
||||
db.commit()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/{portfolio_id}/returns", response_model=ReturnsResponse)
|
||||
async def get_returns(
|
||||
portfolio_id: int,
|
||||
current_user: CurrentUser,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get portfolio returns over time based on snapshots."""
|
||||
_get_portfolio(db, portfolio_id, current_user.id)
|
||||
|
||||
snapshots = (
|
||||
db.query(PortfolioSnapshot)
|
||||
.filter(PortfolioSnapshot.portfolio_id == portfolio_id)
|
||||
.order_by(PortfolioSnapshot.snapshot_date)
|
||||
.all()
|
||||
)
|
||||
|
||||
if not snapshots:
|
||||
return ReturnsResponse(
|
||||
portfolio_id=portfolio_id,
|
||||
start_date=None,
|
||||
end_date=None,
|
||||
total_return=None,
|
||||
cagr=None,
|
||||
data=[],
|
||||
)
|
||||
|
||||
# Calculate returns
|
||||
data_points = []
|
||||
first_value = Decimal(str(snapshots[0].total_value))
|
||||
prev_value = first_value
|
||||
|
||||
for snapshot in snapshots:
|
||||
current_value = Decimal(str(snapshot.total_value))
|
||||
|
||||
# Daily return (vs previous snapshot)
|
||||
if prev_value > 0:
|
||||
daily_return = ((current_value - prev_value) / prev_value * 100).quantize(Decimal("0.01"))
|
||||
else:
|
||||
daily_return = Decimal("0")
|
||||
|
||||
# Cumulative return (vs first snapshot)
|
||||
if first_value > 0:
|
||||
cumulative_return = ((current_value - first_value) / first_value * 100).quantize(Decimal("0.01"))
|
||||
else:
|
||||
cumulative_return = Decimal("0")
|
||||
|
||||
data_points.append(ReturnDataPoint(
|
||||
date=snapshot.snapshot_date,
|
||||
total_value=current_value,
|
||||
daily_return=daily_return,
|
||||
cumulative_return=cumulative_return,
|
||||
))
|
||||
|
||||
prev_value = current_value
|
||||
|
||||
# Calculate total return and CAGR
|
||||
start_date = snapshots[0].snapshot_date
|
||||
end_date = snapshots[-1].snapshot_date
|
||||
last_value = Decimal(str(snapshots[-1].total_value))
|
||||
|
||||
total_return = None
|
||||
cagr = None
|
||||
|
||||
if first_value > 0:
|
||||
total_return = ((last_value - first_value) / first_value * 100).quantize(Decimal("0.01"))
|
||||
|
||||
# CAGR calculation
|
||||
days = (end_date - start_date).days
|
||||
if days > 0:
|
||||
years = Decimal(str(days)) / Decimal("365")
|
||||
if years > 0:
|
||||
ratio = last_value / first_value
|
||||
# CAGR = (ending/beginning)^(1/years) - 1
|
||||
cagr_value = (float(ratio) ** (1 / float(years)) - 1) * 100
|
||||
cagr = Decimal(str(cagr_value)).quantize(Decimal("0.01"))
|
||||
|
||||
return ReturnsResponse(
|
||||
portfolio_id=portfolio_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
total_return=total_return,
|
||||
cagr=cagr,
|
||||
data=data_points,
|
||||
)
|
||||
@ -4,7 +4,10 @@ Galaxy-PO Backend API
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.api import auth_router, admin_router, portfolio_router, strategy_router, market_router, backtest_router
|
||||
from app.api import (
|
||||
auth_router, admin_router, portfolio_router, strategy_router,
|
||||
market_router, backtest_router, snapshot_router,
|
||||
)
|
||||
|
||||
app = FastAPI(
|
||||
title="Galaxy-PO API",
|
||||
@ -27,6 +30,7 @@ app.include_router(portfolio_router)
|
||||
app.include_router(strategy_router)
|
||||
app.include_router(market_router)
|
||||
app.include_router(backtest_router)
|
||||
app.include_router(snapshot_router)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
|
||||
@ -115,7 +115,19 @@ class SnapshotHoldingResponse(BaseModel):
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class SnapshotListItem(BaseModel):
|
||||
"""Snapshot list item (without holdings)."""
|
||||
id: int
|
||||
portfolio_id: int
|
||||
total_value: Decimal
|
||||
snapshot_date: date
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class SnapshotResponse(BaseModel):
|
||||
"""Snapshot detail with holdings."""
|
||||
id: int
|
||||
portfolio_id: int
|
||||
total_value: Decimal
|
||||
@ -126,6 +138,24 @@ class SnapshotResponse(BaseModel):
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ReturnDataPoint(BaseModel):
|
||||
"""Single data point for returns chart."""
|
||||
date: date
|
||||
total_value: Decimal
|
||||
daily_return: Decimal | None = None
|
||||
cumulative_return: Decimal | None = None
|
||||
|
||||
|
||||
class ReturnsResponse(BaseModel):
|
||||
"""Portfolio returns over time."""
|
||||
portfolio_id: int
|
||||
start_date: date | None = None
|
||||
end_date: date | None = None
|
||||
total_return: Decimal | None = None
|
||||
cagr: Decimal | None = None
|
||||
data: List[ReturnDataPoint] = []
|
||||
|
||||
|
||||
# Rebalancing schemas
|
||||
class RebalanceItem(BaseModel):
|
||||
ticker: str
|
||||
|
||||
183
backend/app/services/price_service.py
Normal file
183
backend/app/services/price_service.py
Normal file
@ -0,0 +1,183 @@
|
||||
"""
|
||||
Price service for fetching stock prices.
|
||||
|
||||
This is a mock implementation that uses DB data.
|
||||
Can be replaced with real OpenAPI implementation later.
|
||||
"""
|
||||
from datetime import date, timedelta
|
||||
from decimal import Decimal
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func
|
||||
|
||||
from app.models.stock import Price
|
||||
|
||||
|
||||
class PriceData:
|
||||
"""Price data point."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ticker: str,
|
||||
date: date,
|
||||
open: Decimal,
|
||||
high: Decimal,
|
||||
low: Decimal,
|
||||
close: Decimal,
|
||||
volume: int,
|
||||
):
|
||||
self.ticker = ticker
|
||||
self.date = date
|
||||
self.open = open
|
||||
self.high = high
|
||||
self.low = low
|
||||
self.close = close
|
||||
self.volume = volume
|
||||
|
||||
|
||||
class PriceService:
|
||||
"""
|
||||
Service for fetching stock prices.
|
||||
|
||||
Current implementation uses DB data (prices table).
|
||||
Can be extended to use real-time OpenAPI in the future.
|
||||
"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def get_current_price(self, ticker: str) -> Optional[Decimal]:
|
||||
"""
|
||||
Get current price for a single ticker.
|
||||
|
||||
Returns the most recent closing price from DB.
|
||||
"""
|
||||
result = (
|
||||
self.db.query(Price.close)
|
||||
.filter(Price.ticker == ticker)
|
||||
.order_by(Price.date.desc())
|
||||
.first()
|
||||
)
|
||||
return Decimal(str(result[0])) if result else None
|
||||
|
||||
def get_current_prices(self, tickers: List[str]) -> Dict[str, Decimal]:
|
||||
"""
|
||||
Get current prices for multiple tickers.
|
||||
|
||||
Returns a dict mapping ticker to most recent closing price.
|
||||
"""
|
||||
if not tickers:
|
||||
return {}
|
||||
|
||||
# Subquery to get max date for each ticker
|
||||
subquery = (
|
||||
self.db.query(
|
||||
Price.ticker,
|
||||
func.max(Price.date).label('max_date')
|
||||
)
|
||||
.filter(Price.ticker.in_(tickers))
|
||||
.group_by(Price.ticker)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
# Get prices at max date
|
||||
results = (
|
||||
self.db.query(Price.ticker, Price.close)
|
||||
.join(
|
||||
subquery,
|
||||
(Price.ticker == subquery.c.ticker) &
|
||||
(Price.date == subquery.c.max_date)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
return {ticker: Decimal(str(close)) for ticker, close in results}
|
||||
|
||||
def get_price_history(
|
||||
self,
|
||||
ticker: str,
|
||||
start_date: date,
|
||||
end_date: date,
|
||||
) -> List[PriceData]:
|
||||
"""
|
||||
Get price history for a ticker within date range.
|
||||
"""
|
||||
results = (
|
||||
self.db.query(Price)
|
||||
.filter(
|
||||
Price.ticker == ticker,
|
||||
Price.date >= start_date,
|
||||
Price.date <= end_date,
|
||||
)
|
||||
.order_by(Price.date)
|
||||
.all()
|
||||
)
|
||||
|
||||
return [
|
||||
PriceData(
|
||||
ticker=p.ticker,
|
||||
date=p.date,
|
||||
open=Decimal(str(p.open)),
|
||||
high=Decimal(str(p.high)),
|
||||
low=Decimal(str(p.low)),
|
||||
close=Decimal(str(p.close)),
|
||||
volume=p.volume,
|
||||
)
|
||||
for p in results
|
||||
]
|
||||
|
||||
def get_price_at_date(self, ticker: str, target_date: date) -> Optional[Decimal]:
|
||||
"""
|
||||
Get closing price at specific date.
|
||||
|
||||
If no price on exact date, returns most recent price before that date.
|
||||
"""
|
||||
result = (
|
||||
self.db.query(Price.close)
|
||||
.filter(
|
||||
Price.ticker == ticker,
|
||||
Price.date <= target_date,
|
||||
)
|
||||
.order_by(Price.date.desc())
|
||||
.first()
|
||||
)
|
||||
return Decimal(str(result[0])) if result else None
|
||||
|
||||
def get_prices_at_date(
|
||||
self,
|
||||
tickers: List[str],
|
||||
target_date: date,
|
||||
) -> Dict[str, Decimal]:
|
||||
"""
|
||||
Get closing prices for multiple tickers at specific date.
|
||||
"""
|
||||
if not tickers:
|
||||
return {}
|
||||
|
||||
# Subquery to get max date <= target_date for each ticker
|
||||
subquery = (
|
||||
self.db.query(
|
||||
Price.ticker,
|
||||
func.max(Price.date).label('max_date')
|
||||
)
|
||||
.filter(
|
||||
Price.ticker.in_(tickers),
|
||||
Price.date <= target_date,
|
||||
)
|
||||
.group_by(Price.ticker)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
# Get prices at those dates
|
||||
results = (
|
||||
self.db.query(Price.ticker, Price.close)
|
||||
.join(
|
||||
subquery,
|
||||
(Price.ticker == subquery.c.ticker) &
|
||||
(Price.date == subquery.c.max_date)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
return {ticker: Decimal(str(close)) for ticker, close in results}
|
||||
Loading…
x
Reference in New Issue
Block a user