From 884292836319289ca75187b940ce0c7d51a6d736 Mon Sep 17 00:00:00 2001 From: zephyrdark Date: Tue, 3 Feb 2026 12:23:56 +0900 Subject: [PATCH] feat: add PriceService and snapshot API endpoints - PriceService: Mock implementation using DB prices - Snapshot schemas: SnapshotListItem, ReturnsResponse, ReturnDataPoint - Snapshot API: list, create, get, delete snapshots - Returns API: portfolio returns calculation with CAGR Co-Authored-By: Claude Opus 4.5 --- backend/app/api/__init__.py | 11 +- backend/app/api/snapshot.py | 287 ++++++++++++++++++++++++++ backend/app/main.py | 6 +- backend/app/schemas/portfolio.py | 30 +++ backend/app/services/price_service.py | 183 ++++++++++++++++ 5 files changed, 515 insertions(+), 2 deletions(-) create mode 100644 backend/app/api/snapshot.py create mode 100644 backend/app/services/price_service.py diff --git a/backend/app/api/__init__.py b/backend/app/api/__init__.py index 4d98935..b2d2274 100644 --- a/backend/app/api/__init__.py +++ b/backend/app/api/__init__.py @@ -4,5 +4,14 @@ from app.api.portfolio import router as portfolio_router from app.api.strategy import router as strategy_router from app.api.market import router as market_router from app.api.backtest import router as backtest_router +from app.api.snapshot import router as snapshot_router -__all__ = ["auth_router", "admin_router", "portfolio_router", "strategy_router", "market_router", "backtest_router"] +__all__ = [ + "auth_router", + "admin_router", + "portfolio_router", + "strategy_router", + "market_router", + "backtest_router", + "snapshot_router", +] diff --git a/backend/app/api/snapshot.py b/backend/app/api/snapshot.py new file mode 100644 index 0000000..a9d57b3 --- /dev/null +++ b/backend/app/api/snapshot.py @@ -0,0 +1,287 @@ +""" +Snapshot API endpoints for portfolio history tracking. +""" +from datetime import date +from decimal import Decimal +from typing import List + +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.orm import Session + +from app.core.database import get_db +from app.api.deps import CurrentUser +from app.models.portfolio import Portfolio, PortfolioSnapshot, SnapshotHolding +from app.schemas.portfolio import ( + SnapshotListItem, SnapshotResponse, SnapshotHoldingResponse, + ReturnsResponse, ReturnDataPoint, +) +from app.services.price_service import PriceService + +router = APIRouter(prefix="/api/portfolios", tags=["snapshots"]) + + +def _get_portfolio(db: Session, portfolio_id: int, user_id: int) -> Portfolio: + """Helper to get portfolio with ownership check.""" + portfolio = db.query(Portfolio).filter( + Portfolio.id == portfolio_id, + Portfolio.user_id == user_id, + ).first() + if not portfolio: + raise HTTPException(status_code=404, detail="Portfolio not found") + return portfolio + + +@router.get("/{portfolio_id}/snapshots", response_model=List[SnapshotListItem]) +async def list_snapshots( + portfolio_id: int, + current_user: CurrentUser, + db: Session = Depends(get_db), +): + """Get all snapshots for a portfolio.""" + _get_portfolio(db, portfolio_id, current_user.id) + + snapshots = ( + db.query(PortfolioSnapshot) + .filter(PortfolioSnapshot.portfolio_id == portfolio_id) + .order_by(PortfolioSnapshot.snapshot_date.desc()) + .all() + ) + + return snapshots + + +@router.post("/{portfolio_id}/snapshots", response_model=SnapshotResponse, status_code=status.HTTP_201_CREATED) +async def create_snapshot( + portfolio_id: int, + current_user: CurrentUser, + db: Session = Depends(get_db), +): + """Create a snapshot of current portfolio state.""" + portfolio = _get_portfolio(db, portfolio_id, current_user.id) + + if not portfolio.holdings: + raise HTTPException( + status_code=400, + detail="Cannot create snapshot for empty portfolio" + ) + + # Get current prices + price_service = PriceService(db) + tickers = [h.ticker for h in portfolio.holdings] + prices = price_service.get_current_prices(tickers) + + # Calculate total value + total_value = Decimal("0") + holding_values = [] + + for holding in portfolio.holdings: + price = prices.get(holding.ticker, Decimal("0")) + value = price * holding.quantity + total_value += value + holding_values.append({ + "ticker": holding.ticker, + "quantity": holding.quantity, + "price": price, + "value": value, + }) + + # Create snapshot + snapshot = PortfolioSnapshot( + portfolio_id=portfolio_id, + total_value=total_value, + snapshot_date=date.today(), + ) + db.add(snapshot) + db.flush() # Get snapshot ID + + # Create snapshot holdings + for hv in holding_values: + ratio = (hv["value"] / total_value * 100) if total_value > 0 else Decimal("0") + snapshot_holding = SnapshotHolding( + snapshot_id=snapshot.id, + ticker=hv["ticker"], + quantity=hv["quantity"], + price=hv["price"], + value=hv["value"], + current_ratio=ratio.quantize(Decimal("0.01")), + ) + db.add(snapshot_holding) + + db.commit() + db.refresh(snapshot) + + return SnapshotResponse( + id=snapshot.id, + portfolio_id=snapshot.portfolio_id, + total_value=snapshot.total_value, + snapshot_date=snapshot.snapshot_date, + holdings=[ + SnapshotHoldingResponse( + ticker=h.ticker, + quantity=h.quantity, + price=h.price, + value=h.value, + current_ratio=h.current_ratio, + ) + for h in snapshot.holdings + ], + ) + + +@router.get("/{portfolio_id}/snapshots/{snapshot_id}", response_model=SnapshotResponse) +async def get_snapshot( + portfolio_id: int, + snapshot_id: int, + current_user: CurrentUser, + db: Session = Depends(get_db), +): + """Get a specific snapshot with holdings.""" + _get_portfolio(db, portfolio_id, current_user.id) + + snapshot = ( + db.query(PortfolioSnapshot) + .filter( + PortfolioSnapshot.id == snapshot_id, + PortfolioSnapshot.portfolio_id == portfolio_id, + ) + .first() + ) + + if not snapshot: + raise HTTPException(status_code=404, detail="Snapshot not found") + + return SnapshotResponse( + id=snapshot.id, + portfolio_id=snapshot.portfolio_id, + total_value=snapshot.total_value, + snapshot_date=snapshot.snapshot_date, + holdings=[ + SnapshotHoldingResponse( + ticker=h.ticker, + quantity=h.quantity, + price=h.price, + value=h.value, + current_ratio=h.current_ratio, + ) + for h in snapshot.holdings + ], + ) + + +@router.delete("/{portfolio_id}/snapshots/{snapshot_id}", status_code=status.HTTP_204_NO_CONTENT) +async def delete_snapshot( + portfolio_id: int, + snapshot_id: int, + current_user: CurrentUser, + db: Session = Depends(get_db), +): + """Delete a snapshot.""" + _get_portfolio(db, portfolio_id, current_user.id) + + snapshot = ( + db.query(PortfolioSnapshot) + .filter( + PortfolioSnapshot.id == snapshot_id, + PortfolioSnapshot.portfolio_id == portfolio_id, + ) + .first() + ) + + if not snapshot: + raise HTTPException(status_code=404, detail="Snapshot not found") + + # Delete holdings first (cascade should handle this, but being explicit) + db.query(SnapshotHolding).filter( + SnapshotHolding.snapshot_id == snapshot_id + ).delete() + + db.delete(snapshot) + db.commit() + + return None + + +@router.get("/{portfolio_id}/returns", response_model=ReturnsResponse) +async def get_returns( + portfolio_id: int, + current_user: CurrentUser, + db: Session = Depends(get_db), +): + """Get portfolio returns over time based on snapshots.""" + _get_portfolio(db, portfolio_id, current_user.id) + + snapshots = ( + db.query(PortfolioSnapshot) + .filter(PortfolioSnapshot.portfolio_id == portfolio_id) + .order_by(PortfolioSnapshot.snapshot_date) + .all() + ) + + if not snapshots: + return ReturnsResponse( + portfolio_id=portfolio_id, + start_date=None, + end_date=None, + total_return=None, + cagr=None, + data=[], + ) + + # Calculate returns + data_points = [] + first_value = Decimal(str(snapshots[0].total_value)) + prev_value = first_value + + for snapshot in snapshots: + current_value = Decimal(str(snapshot.total_value)) + + # Daily return (vs previous snapshot) + if prev_value > 0: + daily_return = ((current_value - prev_value) / prev_value * 100).quantize(Decimal("0.01")) + else: + daily_return = Decimal("0") + + # Cumulative return (vs first snapshot) + if first_value > 0: + cumulative_return = ((current_value - first_value) / first_value * 100).quantize(Decimal("0.01")) + else: + cumulative_return = Decimal("0") + + data_points.append(ReturnDataPoint( + date=snapshot.snapshot_date, + total_value=current_value, + daily_return=daily_return, + cumulative_return=cumulative_return, + )) + + prev_value = current_value + + # Calculate total return and CAGR + start_date = snapshots[0].snapshot_date + end_date = snapshots[-1].snapshot_date + last_value = Decimal(str(snapshots[-1].total_value)) + + total_return = None + cagr = None + + if first_value > 0: + total_return = ((last_value - first_value) / first_value * 100).quantize(Decimal("0.01")) + + # CAGR calculation + days = (end_date - start_date).days + if days > 0: + years = Decimal(str(days)) / Decimal("365") + if years > 0: + ratio = last_value / first_value + # CAGR = (ending/beginning)^(1/years) - 1 + cagr_value = (float(ratio) ** (1 / float(years)) - 1) * 100 + cagr = Decimal(str(cagr_value)).quantize(Decimal("0.01")) + + return ReturnsResponse( + portfolio_id=portfolio_id, + start_date=start_date, + end_date=end_date, + total_return=total_return, + cagr=cagr, + data=data_points, + ) diff --git a/backend/app/main.py b/backend/app/main.py index bd1f959..6983bed 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -4,7 +4,10 @@ Galaxy-PO Backend API from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware -from app.api import auth_router, admin_router, portfolio_router, strategy_router, market_router, backtest_router +from app.api import ( + auth_router, admin_router, portfolio_router, strategy_router, + market_router, backtest_router, snapshot_router, +) app = FastAPI( title="Galaxy-PO API", @@ -27,6 +30,7 @@ app.include_router(portfolio_router) app.include_router(strategy_router) app.include_router(market_router) app.include_router(backtest_router) +app.include_router(snapshot_router) @app.get("/health") diff --git a/backend/app/schemas/portfolio.py b/backend/app/schemas/portfolio.py index d166fee..db6a4d3 100644 --- a/backend/app/schemas/portfolio.py +++ b/backend/app/schemas/portfolio.py @@ -115,7 +115,19 @@ class SnapshotHoldingResponse(BaseModel): from_attributes = True +class SnapshotListItem(BaseModel): + """Snapshot list item (without holdings).""" + id: int + portfolio_id: int + total_value: Decimal + snapshot_date: date + + class Config: + from_attributes = True + + class SnapshotResponse(BaseModel): + """Snapshot detail with holdings.""" id: int portfolio_id: int total_value: Decimal @@ -126,6 +138,24 @@ class SnapshotResponse(BaseModel): from_attributes = True +class ReturnDataPoint(BaseModel): + """Single data point for returns chart.""" + date: date + total_value: Decimal + daily_return: Decimal | None = None + cumulative_return: Decimal | None = None + + +class ReturnsResponse(BaseModel): + """Portfolio returns over time.""" + portfolio_id: int + start_date: date | None = None + end_date: date | None = None + total_return: Decimal | None = None + cagr: Decimal | None = None + data: List[ReturnDataPoint] = [] + + # Rebalancing schemas class RebalanceItem(BaseModel): ticker: str diff --git a/backend/app/services/price_service.py b/backend/app/services/price_service.py new file mode 100644 index 0000000..6270a64 --- /dev/null +++ b/backend/app/services/price_service.py @@ -0,0 +1,183 @@ +""" +Price service for fetching stock prices. + +This is a mock implementation that uses DB data. +Can be replaced with real OpenAPI implementation later. +""" +from datetime import date, timedelta +from decimal import Decimal +from typing import Dict, List, Optional + +from sqlalchemy.orm import Session +from sqlalchemy import func + +from app.models.stock import Price + + +class PriceData: + """Price data point.""" + + def __init__( + self, + ticker: str, + date: date, + open: Decimal, + high: Decimal, + low: Decimal, + close: Decimal, + volume: int, + ): + self.ticker = ticker + self.date = date + self.open = open + self.high = high + self.low = low + self.close = close + self.volume = volume + + +class PriceService: + """ + Service for fetching stock prices. + + Current implementation uses DB data (prices table). + Can be extended to use real-time OpenAPI in the future. + """ + + def __init__(self, db: Session): + self.db = db + + def get_current_price(self, ticker: str) -> Optional[Decimal]: + """ + Get current price for a single ticker. + + Returns the most recent closing price from DB. + """ + result = ( + self.db.query(Price.close) + .filter(Price.ticker == ticker) + .order_by(Price.date.desc()) + .first() + ) + return Decimal(str(result[0])) if result else None + + def get_current_prices(self, tickers: List[str]) -> Dict[str, Decimal]: + """ + Get current prices for multiple tickers. + + Returns a dict mapping ticker to most recent closing price. + """ + if not tickers: + return {} + + # Subquery to get max date for each ticker + subquery = ( + self.db.query( + Price.ticker, + func.max(Price.date).label('max_date') + ) + .filter(Price.ticker.in_(tickers)) + .group_by(Price.ticker) + .subquery() + ) + + # Get prices at max date + results = ( + self.db.query(Price.ticker, Price.close) + .join( + subquery, + (Price.ticker == subquery.c.ticker) & + (Price.date == subquery.c.max_date) + ) + .all() + ) + + return {ticker: Decimal(str(close)) for ticker, close in results} + + def get_price_history( + self, + ticker: str, + start_date: date, + end_date: date, + ) -> List[PriceData]: + """ + Get price history for a ticker within date range. + """ + results = ( + self.db.query(Price) + .filter( + Price.ticker == ticker, + Price.date >= start_date, + Price.date <= end_date, + ) + .order_by(Price.date) + .all() + ) + + return [ + PriceData( + ticker=p.ticker, + date=p.date, + open=Decimal(str(p.open)), + high=Decimal(str(p.high)), + low=Decimal(str(p.low)), + close=Decimal(str(p.close)), + volume=p.volume, + ) + for p in results + ] + + def get_price_at_date(self, ticker: str, target_date: date) -> Optional[Decimal]: + """ + Get closing price at specific date. + + If no price on exact date, returns most recent price before that date. + """ + result = ( + self.db.query(Price.close) + .filter( + Price.ticker == ticker, + Price.date <= target_date, + ) + .order_by(Price.date.desc()) + .first() + ) + return Decimal(str(result[0])) if result else None + + def get_prices_at_date( + self, + tickers: List[str], + target_date: date, + ) -> Dict[str, Decimal]: + """ + Get closing prices for multiple tickers at specific date. + """ + if not tickers: + return {} + + # Subquery to get max date <= target_date for each ticker + subquery = ( + self.db.query( + Price.ticker, + func.max(Price.date).label('max_date') + ) + .filter( + Price.ticker.in_(tickers), + Price.date <= target_date, + ) + .group_by(Price.ticker) + .subquery() + ) + + # Get prices at those dates + results = ( + self.db.query(Price.ticker, Price.close) + .join( + subquery, + (Price.ticker == subquery.c.ticker) & + (Price.date == subquery.c.max_date) + ) + .all() + ) + + return {ticker: Decimal(str(close)) for ticker, close in results}