feat: add PriceService and snapshot API endpoints

- PriceService: Mock implementation using DB prices
- Snapshot schemas: SnapshotListItem, ReturnsResponse, ReturnDataPoint
- Snapshot API: list, create, get, delete snapshots
- Returns API: portfolio returns calculation with CAGR

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
zephyrdark 2026-02-03 12:23:56 +09:00
parent 63ffe2439e
commit 8842928363
5 changed files with 515 additions and 2 deletions

View File

@ -4,5 +4,14 @@ from app.api.portfolio import router as portfolio_router
from app.api.strategy import router as strategy_router
from app.api.market import router as market_router
from app.api.backtest import router as backtest_router
from app.api.snapshot import router as snapshot_router
__all__ = ["auth_router", "admin_router", "portfolio_router", "strategy_router", "market_router", "backtest_router"]
__all__ = [
"auth_router",
"admin_router",
"portfolio_router",
"strategy_router",
"market_router",
"backtest_router",
"snapshot_router",
]

287
backend/app/api/snapshot.py Normal file
View File

@ -0,0 +1,287 @@
"""
Snapshot API endpoints for portfolio history tracking.
"""
from datetime import date
from decimal import Decimal
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.api.deps import CurrentUser
from app.models.portfolio import Portfolio, PortfolioSnapshot, SnapshotHolding
from app.schemas.portfolio import (
SnapshotListItem, SnapshotResponse, SnapshotHoldingResponse,
ReturnsResponse, ReturnDataPoint,
)
from app.services.price_service import PriceService
router = APIRouter(prefix="/api/portfolios", tags=["snapshots"])
def _get_portfolio(db: Session, portfolio_id: int, user_id: int) -> Portfolio:
"""Helper to get portfolio with ownership check."""
portfolio = db.query(Portfolio).filter(
Portfolio.id == portfolio_id,
Portfolio.user_id == user_id,
).first()
if not portfolio:
raise HTTPException(status_code=404, detail="Portfolio not found")
return portfolio
@router.get("/{portfolio_id}/snapshots", response_model=List[SnapshotListItem])
async def list_snapshots(
portfolio_id: int,
current_user: CurrentUser,
db: Session = Depends(get_db),
):
"""Get all snapshots for a portfolio."""
_get_portfolio(db, portfolio_id, current_user.id)
snapshots = (
db.query(PortfolioSnapshot)
.filter(PortfolioSnapshot.portfolio_id == portfolio_id)
.order_by(PortfolioSnapshot.snapshot_date.desc())
.all()
)
return snapshots
@router.post("/{portfolio_id}/snapshots", response_model=SnapshotResponse, status_code=status.HTTP_201_CREATED)
async def create_snapshot(
portfolio_id: int,
current_user: CurrentUser,
db: Session = Depends(get_db),
):
"""Create a snapshot of current portfolio state."""
portfolio = _get_portfolio(db, portfolio_id, current_user.id)
if not portfolio.holdings:
raise HTTPException(
status_code=400,
detail="Cannot create snapshot for empty portfolio"
)
# Get current prices
price_service = PriceService(db)
tickers = [h.ticker for h in portfolio.holdings]
prices = price_service.get_current_prices(tickers)
# Calculate total value
total_value = Decimal("0")
holding_values = []
for holding in portfolio.holdings:
price = prices.get(holding.ticker, Decimal("0"))
value = price * holding.quantity
total_value += value
holding_values.append({
"ticker": holding.ticker,
"quantity": holding.quantity,
"price": price,
"value": value,
})
# Create snapshot
snapshot = PortfolioSnapshot(
portfolio_id=portfolio_id,
total_value=total_value,
snapshot_date=date.today(),
)
db.add(snapshot)
db.flush() # Get snapshot ID
# Create snapshot holdings
for hv in holding_values:
ratio = (hv["value"] / total_value * 100) if total_value > 0 else Decimal("0")
snapshot_holding = SnapshotHolding(
snapshot_id=snapshot.id,
ticker=hv["ticker"],
quantity=hv["quantity"],
price=hv["price"],
value=hv["value"],
current_ratio=ratio.quantize(Decimal("0.01")),
)
db.add(snapshot_holding)
db.commit()
db.refresh(snapshot)
return SnapshotResponse(
id=snapshot.id,
portfolio_id=snapshot.portfolio_id,
total_value=snapshot.total_value,
snapshot_date=snapshot.snapshot_date,
holdings=[
SnapshotHoldingResponse(
ticker=h.ticker,
quantity=h.quantity,
price=h.price,
value=h.value,
current_ratio=h.current_ratio,
)
for h in snapshot.holdings
],
)
@router.get("/{portfolio_id}/snapshots/{snapshot_id}", response_model=SnapshotResponse)
async def get_snapshot(
portfolio_id: int,
snapshot_id: int,
current_user: CurrentUser,
db: Session = Depends(get_db),
):
"""Get a specific snapshot with holdings."""
_get_portfolio(db, portfolio_id, current_user.id)
snapshot = (
db.query(PortfolioSnapshot)
.filter(
PortfolioSnapshot.id == snapshot_id,
PortfolioSnapshot.portfolio_id == portfolio_id,
)
.first()
)
if not snapshot:
raise HTTPException(status_code=404, detail="Snapshot not found")
return SnapshotResponse(
id=snapshot.id,
portfolio_id=snapshot.portfolio_id,
total_value=snapshot.total_value,
snapshot_date=snapshot.snapshot_date,
holdings=[
SnapshotHoldingResponse(
ticker=h.ticker,
quantity=h.quantity,
price=h.price,
value=h.value,
current_ratio=h.current_ratio,
)
for h in snapshot.holdings
],
)
@router.delete("/{portfolio_id}/snapshots/{snapshot_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_snapshot(
portfolio_id: int,
snapshot_id: int,
current_user: CurrentUser,
db: Session = Depends(get_db),
):
"""Delete a snapshot."""
_get_portfolio(db, portfolio_id, current_user.id)
snapshot = (
db.query(PortfolioSnapshot)
.filter(
PortfolioSnapshot.id == snapshot_id,
PortfolioSnapshot.portfolio_id == portfolio_id,
)
.first()
)
if not snapshot:
raise HTTPException(status_code=404, detail="Snapshot not found")
# Delete holdings first (cascade should handle this, but being explicit)
db.query(SnapshotHolding).filter(
SnapshotHolding.snapshot_id == snapshot_id
).delete()
db.delete(snapshot)
db.commit()
return None
@router.get("/{portfolio_id}/returns", response_model=ReturnsResponse)
async def get_returns(
portfolio_id: int,
current_user: CurrentUser,
db: Session = Depends(get_db),
):
"""Get portfolio returns over time based on snapshots."""
_get_portfolio(db, portfolio_id, current_user.id)
snapshots = (
db.query(PortfolioSnapshot)
.filter(PortfolioSnapshot.portfolio_id == portfolio_id)
.order_by(PortfolioSnapshot.snapshot_date)
.all()
)
if not snapshots:
return ReturnsResponse(
portfolio_id=portfolio_id,
start_date=None,
end_date=None,
total_return=None,
cagr=None,
data=[],
)
# Calculate returns
data_points = []
first_value = Decimal(str(snapshots[0].total_value))
prev_value = first_value
for snapshot in snapshots:
current_value = Decimal(str(snapshot.total_value))
# Daily return (vs previous snapshot)
if prev_value > 0:
daily_return = ((current_value - prev_value) / prev_value * 100).quantize(Decimal("0.01"))
else:
daily_return = Decimal("0")
# Cumulative return (vs first snapshot)
if first_value > 0:
cumulative_return = ((current_value - first_value) / first_value * 100).quantize(Decimal("0.01"))
else:
cumulative_return = Decimal("0")
data_points.append(ReturnDataPoint(
date=snapshot.snapshot_date,
total_value=current_value,
daily_return=daily_return,
cumulative_return=cumulative_return,
))
prev_value = current_value
# Calculate total return and CAGR
start_date = snapshots[0].snapshot_date
end_date = snapshots[-1].snapshot_date
last_value = Decimal(str(snapshots[-1].total_value))
total_return = None
cagr = None
if first_value > 0:
total_return = ((last_value - first_value) / first_value * 100).quantize(Decimal("0.01"))
# CAGR calculation
days = (end_date - start_date).days
if days > 0:
years = Decimal(str(days)) / Decimal("365")
if years > 0:
ratio = last_value / first_value
# CAGR = (ending/beginning)^(1/years) - 1
cagr_value = (float(ratio) ** (1 / float(years)) - 1) * 100
cagr = Decimal(str(cagr_value)).quantize(Decimal("0.01"))
return ReturnsResponse(
portfolio_id=portfolio_id,
start_date=start_date,
end_date=end_date,
total_return=total_return,
cagr=cagr,
data=data_points,
)

View File

@ -4,7 +4,10 @@ Galaxy-PO Backend API
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.api import auth_router, admin_router, portfolio_router, strategy_router, market_router, backtest_router
from app.api import (
auth_router, admin_router, portfolio_router, strategy_router,
market_router, backtest_router, snapshot_router,
)
app = FastAPI(
title="Galaxy-PO API",
@ -27,6 +30,7 @@ app.include_router(portfolio_router)
app.include_router(strategy_router)
app.include_router(market_router)
app.include_router(backtest_router)
app.include_router(snapshot_router)
@app.get("/health")

View File

@ -115,7 +115,19 @@ class SnapshotHoldingResponse(BaseModel):
from_attributes = True
class SnapshotListItem(BaseModel):
"""Snapshot list item (without holdings)."""
id: int
portfolio_id: int
total_value: Decimal
snapshot_date: date
class Config:
from_attributes = True
class SnapshotResponse(BaseModel):
"""Snapshot detail with holdings."""
id: int
portfolio_id: int
total_value: Decimal
@ -126,6 +138,24 @@ class SnapshotResponse(BaseModel):
from_attributes = True
class ReturnDataPoint(BaseModel):
"""Single data point for returns chart."""
date: date
total_value: Decimal
daily_return: Decimal | None = None
cumulative_return: Decimal | None = None
class ReturnsResponse(BaseModel):
"""Portfolio returns over time."""
portfolio_id: int
start_date: date | None = None
end_date: date | None = None
total_return: Decimal | None = None
cagr: Decimal | None = None
data: List[ReturnDataPoint] = []
# Rebalancing schemas
class RebalanceItem(BaseModel):
ticker: str

View File

@ -0,0 +1,183 @@
"""
Price service for fetching stock prices.
This is a mock implementation that uses DB data.
Can be replaced with real OpenAPI implementation later.
"""
from datetime import date, timedelta
from decimal import Decimal
from typing import Dict, List, Optional
from sqlalchemy.orm import Session
from sqlalchemy import func
from app.models.stock import Price
class PriceData:
"""Price data point."""
def __init__(
self,
ticker: str,
date: date,
open: Decimal,
high: Decimal,
low: Decimal,
close: Decimal,
volume: int,
):
self.ticker = ticker
self.date = date
self.open = open
self.high = high
self.low = low
self.close = close
self.volume = volume
class PriceService:
"""
Service for fetching stock prices.
Current implementation uses DB data (prices table).
Can be extended to use real-time OpenAPI in the future.
"""
def __init__(self, db: Session):
self.db = db
def get_current_price(self, ticker: str) -> Optional[Decimal]:
"""
Get current price for a single ticker.
Returns the most recent closing price from DB.
"""
result = (
self.db.query(Price.close)
.filter(Price.ticker == ticker)
.order_by(Price.date.desc())
.first()
)
return Decimal(str(result[0])) if result else None
def get_current_prices(self, tickers: List[str]) -> Dict[str, Decimal]:
"""
Get current prices for multiple tickers.
Returns a dict mapping ticker to most recent closing price.
"""
if not tickers:
return {}
# Subquery to get max date for each ticker
subquery = (
self.db.query(
Price.ticker,
func.max(Price.date).label('max_date')
)
.filter(Price.ticker.in_(tickers))
.group_by(Price.ticker)
.subquery()
)
# Get prices at max date
results = (
self.db.query(Price.ticker, Price.close)
.join(
subquery,
(Price.ticker == subquery.c.ticker) &
(Price.date == subquery.c.max_date)
)
.all()
)
return {ticker: Decimal(str(close)) for ticker, close in results}
def get_price_history(
self,
ticker: str,
start_date: date,
end_date: date,
) -> List[PriceData]:
"""
Get price history for a ticker within date range.
"""
results = (
self.db.query(Price)
.filter(
Price.ticker == ticker,
Price.date >= start_date,
Price.date <= end_date,
)
.order_by(Price.date)
.all()
)
return [
PriceData(
ticker=p.ticker,
date=p.date,
open=Decimal(str(p.open)),
high=Decimal(str(p.high)),
low=Decimal(str(p.low)),
close=Decimal(str(p.close)),
volume=p.volume,
)
for p in results
]
def get_price_at_date(self, ticker: str, target_date: date) -> Optional[Decimal]:
"""
Get closing price at specific date.
If no price on exact date, returns most recent price before that date.
"""
result = (
self.db.query(Price.close)
.filter(
Price.ticker == ticker,
Price.date <= target_date,
)
.order_by(Price.date.desc())
.first()
)
return Decimal(str(result[0])) if result else None
def get_prices_at_date(
self,
tickers: List[str],
target_date: date,
) -> Dict[str, Decimal]:
"""
Get closing prices for multiple tickers at specific date.
"""
if not tickers:
return {}
# Subquery to get max date <= target_date for each ticker
subquery = (
self.db.query(
Price.ticker,
func.max(Price.date).label('max_date')
)
.filter(
Price.ticker.in_(tickers),
Price.date <= target_date,
)
.group_by(Price.ticker)
.subquery()
)
# Get prices at those dates
results = (
self.db.query(Price.ticker, Price.close)
.join(
subquery,
(Price.ticker == subquery.c.ticker) &
(Price.date == subquery.c.max_date)
)
.all()
)
return {ticker: Decimal(str(close)) for ticker, close in results}