galaxis-po/backend/app/api/portfolio.py

426 lines
13 KiB
Python

"""
Portfolio management API endpoints.
"""
from decimal import Decimal
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.api.deps import CurrentUser
from app.models.portfolio import Portfolio, PortfolioType, Target, Holding, Transaction, TransactionType
from app.schemas.portfolio import (
PortfolioCreate, PortfolioUpdate, PortfolioResponse, PortfolioDetail,
TargetCreate, TargetResponse,
HoldingCreate, HoldingResponse, HoldingWithValue,
TransactionCreate, TransactionResponse,
RebalanceResponse, RebalanceSimulationRequest, RebalanceSimulationResponse,
RebalanceCalculateRequest, RebalanceCalculateResponse,
)
from app.services.rebalance import RebalanceService
# Router that every endpoint in this module registers on; all routes share
# the "/api/portfolios" prefix and the "portfolios" tag.
router = APIRouter(prefix="/api/portfolios", tags=["portfolios"])
@router.get("", response_model=List[PortfolioResponse])
async def list_portfolios(
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return the current user's portfolios, most recently created first."""
    owned = db.query(Portfolio).filter(Portfolio.user_id == current_user.id)
    newest_first = owned.order_by(Portfolio.created_at.desc())
    return newest_first.all()
@router.post("", response_model=PortfolioResponse, status_code=status.HTTP_201_CREATED)
async def create_portfolio(
    data: PortfolioCreate,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Persist a new portfolio owned by the current user and return it."""
    created = Portfolio(
        user_id=current_user.id,
        name=data.name,
        # The payload carries the type as a raw value; convert it to the enum.
        portfolio_type=PortfolioType(data.portfolio_type),
    )
    db.add(created)
    db.commit()
    # Reload the instance from the database after the commit.
    db.refresh(created)
    return created
@router.get("/{portfolio_id}", response_model=PortfolioResponse)
async def get_portfolio(
    portfolio_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return one portfolio by ID, provided it belongs to the current user.

    Raises:
        HTTPException: 404 when no portfolio with this ID is owned by the
            current user (raised by ``_get_portfolio``).
    """
    # Delegate to the shared helper so the ownership check and the 404
    # response are defined in a single place for every endpoint.
    return _get_portfolio(db, portfolio_id, current_user.id)
@router.put("/{portfolio_id}", response_model=PortfolioResponse)
async def update_portfolio(
    portfolio_id: int,
    data: PortfolioUpdate,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Apply a partial update to a portfolio owned by the current user.

    Only the fields of ``data`` that are not ``None`` are written; the other
    attributes keep their stored values.

    Raises:
        HTTPException: 404 when no portfolio with this ID is owned by the
            current user (raised by ``_get_portfolio``).
    """
    # Shared helper performs the ownership-scoped lookup and the 404.
    portfolio = _get_portfolio(db, portfolio_id, current_user.id)

    if data.name is not None:
        portfolio.name = data.name
    if data.portfolio_type is not None:
        portfolio.portfolio_type = PortfolioType(data.portfolio_type)

    db.commit()
    # Reload the instance from the database after the commit.
    db.refresh(portfolio)
    return portfolio
@router.delete("/{portfolio_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_portfolio(
    portfolio_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Delete a portfolio owned by the current user.

    Returns:
        ``None``; the route is declared with a 204 status code.

    Raises:
        HTTPException: 404 when no portfolio with this ID is owned by the
            current user (raised by ``_get_portfolio``).
    """
    # Shared helper performs the ownership-scoped lookup and the 404.
    portfolio = _get_portfolio(db, portfolio_id, current_user.id)
    db.delete(portfolio)
    db.commit()
    return None
def _get_portfolio(db: Session, portfolio_id: int, user_id: int) -> Portfolio:
    """Fetch a portfolio by ID, restricted to the given owner.

    Args:
        db: Session used to run the lookup.
        portfolio_id: ID of the portfolio to load.
        user_id: ID of the user who must own the portfolio.

    Returns:
        The matching portfolio.

    Raises:
        HTTPException: 404 when no row matches both the ID and the owner.
    """
    found = (
        db.query(Portfolio)
        .filter(Portfolio.id == portfolio_id, Portfolio.user_id == user_id)
        .first()
    )
    if found:
        return found
    raise HTTPException(status_code=404, detail="Portfolio not found")
@router.get("/{portfolio_id}/targets", response_model=List[TargetResponse])
async def get_targets(
    portfolio_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return the target allocations of a portfolio owned by the current user."""
    # The helper raises 404 unless the portfolio belongs to the caller.
    return _get_portfolio(db, portfolio_id, current_user.id).targets
@router.put("/{portfolio_id}/targets", response_model=List[TargetResponse])
async def set_targets(
    portfolio_id: int,
    targets: List[TargetCreate],
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Replace every target allocation of a portfolio.

    The submitted ratios must add up to exactly 100. When they do not, the
    request is rejected before any existing target is deleted.

    Returns:
        The newly created targets, in the order they were submitted.

    Raises:
        HTTPException: 404 when the portfolio is not owned by the current
            user (raised by ``_get_portfolio``); 400 when the ratios do not
            sum to 100.
    """
    # Ownership check only; the portfolio row itself is not used below.
    _get_portfolio(db, portfolio_id, current_user.id)

    # Sum via Decimal(str(...)) so that ratios arriving as binary floats
    # cannot yield a total such as 99.99999999999999 and fail the check
    # spuriously. Values that are already int or Decimal pass through intact.
    total_ratio = sum(
        (Decimal(str(t.target_ratio)) for t in targets),
        Decimal("0"),
    )
    if total_ratio != 100:
        raise HTTPException(
            status_code=400,
            detail=f"Target ratios must sum to 100%, got {total_ratio}%"
        )

    # Replace semantics: drop the old rows, then insert the submitted set.
    db.query(Target).filter(Target.portfolio_id == portfolio_id).delete()
    new_targets = [
        Target(
            portfolio_id=portfolio_id,
            ticker=t.ticker,
            target_ratio=t.target_ratio,
        )
        for t in targets
    ]
    db.add_all(new_targets)
    db.commit()
    return new_targets
@router.get("/{portfolio_id}/holdings", response_model=List[HoldingResponse])
async def get_holdings(
    portfolio_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return the holdings of a portfolio owned by the current user."""
    # The helper raises 404 unless the portfolio belongs to the caller.
    return _get_portfolio(db, portfolio_id, current_user.id).holdings
@router.put("/{portfolio_id}/holdings", response_model=List[HoldingResponse])
async def set_holdings(
    portfolio_id: int,
    holdings: List[HoldingCreate],
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Replace every holding of a portfolio with the submitted list.

    Returns:
        The newly created holdings, in the order they were submitted.
    """
    # Called for its ownership check; raises 404 for a foreign portfolio.
    _get_portfolio(db, portfolio_id, current_user.id)

    # Replace semantics: remove the current rows before inserting the new set.
    db.query(Holding).filter(Holding.portfolio_id == portfolio_id).delete()

    replacements = [
        Holding(
            portfolio_id=portfolio_id,
            ticker=item.ticker,
            quantity=item.quantity,
            avg_price=item.avg_price,
        )
        for item in holdings
    ]
    db.add_all(replacements)
    db.commit()
    return replacements
@router.get("/{portfolio_id}/transactions", response_model=List[TransactionResponse])
async def get_transactions(
    portfolio_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
    limit: int = 50,
):
    """Return the most recent transactions of a portfolio, newest first.

    Args:
        portfolio_id: ID of the portfolio whose history is requested.
        limit: Maximum number of transactions to return (default 50).

    Raises:
        HTTPException: 404 when the portfolio is not owned by the current
            user (raised by ``_get_portfolio``); 400 when ``limit`` is
            negative.
    """
    _get_portfolio(db, portfolio_id, current_user.id)

    # A negative value would otherwise be handed straight to SQL LIMIT, where
    # it is either an error or "no limit" depending on the database backend.
    if limit < 0:
        raise HTTPException(status_code=400, detail="limit must not be negative")

    transactions = (
        db.query(Transaction)
        .filter(Transaction.portfolio_id == portfolio_id)
        .order_by(Transaction.executed_at.desc())
        .limit(limit)
        .all()
    )

    # Look up each distinct ticker once, then attach the name to every row.
    tickers = list({tx.ticker for tx in transactions})
    service = RebalanceService(db)
    names = service.get_stock_names(tickers)

    return [
        TransactionResponse(
            id=tx.id,
            ticker=tx.ticker,
            name=names.get(tx.ticker),
            tx_type=tx.tx_type.value,
            quantity=tx.quantity,
            price=tx.price,
            executed_at=tx.executed_at,
            memo=tx.memo,
        )
        for tx in transactions
    ]
@router.post("/{portfolio_id}/transactions", response_model=TransactionResponse, status_code=status.HTTP_201_CREATED)
async def add_transaction(
    portfolio_id: int,
    data: TransactionCreate,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Record a transaction and update the matching holding.

    A buy adds to the holding for the ticker (creating it when absent) and
    recomputes the quantity-weighted average price. A sell reduces the holding
    and removes it once its quantity reaches zero.

    Returns:
        The stored transaction, built the same way as the entries returned by
        the transaction-history endpoint (string ``tx_type``, resolved
        ``name``).

    Raises:
        HTTPException: 404 when the portfolio is not owned by the current
            user (raised by ``_get_portfolio``); 400 when the quantity is not
            positive or a sell exceeds the held quantity.
    """
    _get_portfolio(db, portfolio_id, current_user.id)
    tx_type = TransactionType(data.tx_type)

    # NOTE(review): the request schema may already enforce this. Kept as a
    # defensive check because a negative quantity would invert the buy/sell
    # arithmetic below (a negative "sell" would increase the holding).
    if data.quantity <= 0:
        raise HTTPException(
            status_code=400,
            detail="Transaction quantity must be positive"
        )

    holding = db.query(Holding).filter(
        Holding.portfolio_id == portfolio_id,
        Holding.ticker == data.ticker,
    ).first()

    # Reject an impossible sell before anything is added to the session, so a
    # refused request leaves no pending transaction row behind.
    if tx_type == TransactionType.SELL and (
        not holding or holding.quantity < data.quantity
    ):
        raise HTTPException(
            status_code=400,
            detail=f"Insufficient quantity for {data.ticker}"
        )

    transaction = Transaction(
        portfolio_id=portfolio_id,
        ticker=data.ticker,
        tx_type=tx_type,
        quantity=data.quantity,
        price=data.price,
        executed_at=data.executed_at,
        memo=data.memo,
    )
    db.add(transaction)

    if tx_type == TransactionType.BUY:
        if holding:
            # Weighted average: total cost of old plus new lots over the
            # combined quantity.
            total_value = (holding.quantity * holding.avg_price) + (data.quantity * data.price)
            new_quantity = holding.quantity + data.quantity
            holding.quantity = new_quantity
            holding.avg_price = total_value / new_quantity if new_quantity > 0 else 0
        else:
            db.add(Holding(
                portfolio_id=portfolio_id,
                ticker=data.ticker,
                quantity=data.quantity,
                avg_price=data.price,
            ))
    elif tx_type == TransactionType.SELL:
        # Sufficiency was verified above, so the holding exists here.
        holding.quantity -= data.quantity
        if holding.quantity == 0:
            db.delete(holding)

    # Resolve the display name before committing, so a failure in the lookup
    # cannot leave a committed transaction behind an error response.
    names = RebalanceService(db).get_stock_names([data.ticker])

    db.commit()
    db.refresh(transaction)

    return TransactionResponse(
        id=transaction.id,
        ticker=transaction.ticker,
        name=names.get(transaction.ticker),
        tx_type=transaction.tx_type.value,
        quantity=transaction.quantity,
        price=transaction.price,
        executed_at=transaction.executed_at,
        memo=transaction.memo,
    )
@router.get("/{portfolio_id}/rebalance", response_model=RebalanceResponse)
async def calculate_rebalance(
    portfolio_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return the rebalancing result computed by ``RebalanceService``."""
    # The helper raises 404 unless the portfolio belongs to the caller.
    owned = _get_portfolio(db, portfolio_id, current_user.id)
    return RebalanceService(db).calculate_rebalance(owned)
@router.post("/{portfolio_id}/rebalance/simulate", response_model=RebalanceSimulationResponse)
async def simulate_rebalance(
    portfolio_id: int,
    data: RebalanceSimulationRequest,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Run the rebalancing calculation with an extra investment amount.

    The amount from the request body is forwarded to
    ``RebalanceService.calculate_rebalance`` as ``additional_amount``.
    """
    # The helper raises 404 unless the portfolio belongs to the caller.
    owned = _get_portfolio(db, portfolio_id, current_user.id)
    extra = data.additional_amount
    return RebalanceService(db).calculate_rebalance(owned, additional_amount=extra)
@router.post("/{portfolio_id}/rebalance/calculate", response_model=RebalanceCalculateResponse)
async def calculate_rebalance_manual(
    portfolio_id: int,
    data: RebalanceCalculateRequest,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Run a rebalancing calculation with caller-supplied prices and strategy.

    Raises:
        HTTPException: 404 when the portfolio is not owned by the current
            user (raised by ``_get_portfolio``); 400 when the strategy is
            ``"additional_buy"`` and ``additional_amount`` is missing or
            falsy.
    """
    owned = _get_portfolio(db, portfolio_id, current_user.id)

    # Only this strategy needs an amount; a falsy value (None or 0) is
    # treated as not provided.
    amount_required = data.strategy == "additional_buy"
    if amount_required and not data.additional_amount:
        raise HTTPException(
            status_code=400,
            detail="additional_amount is required for additional_buy strategy"
        )

    calculator = RebalanceService(db)
    return calculator.calculate_with_prices(
        owned,
        strategy=data.strategy,
        manual_prices=data.prices,
        additional_amount=data.additional_amount,
    )
@router.get("/{portfolio_id}/detail", response_model=PortfolioDetail)
async def get_portfolio_detail(
    portfolio_id: int,
    current_user: CurrentUser,
    db: Session = Depends(get_db),
):
    """Return a portfolio together with per-holding and total valuations.

    For every holding the current value, invested amount, profit/loss and its
    percentage are computed; each holding's share of the total value is filled
    in once the total is known. A ticker without a price is valued at zero.
    """
    portfolio = _get_portfolio(db, portfolio_id, current_user.id)

    # Fetch prices and display names for all held tickers.
    tickers = [position.ticker for position in portfolio.holdings]
    service = RebalanceService(db)
    prices = service.get_current_prices(tickers)
    names = service.get_stock_names(tickers)

    zero = Decimal("0")
    two_places = Decimal("0.01")

    total_value = zero
    total_invested = zero
    rows = []
    for position in portfolio.holdings:
        ticker = position.ticker
        quantity = position.quantity
        avg_price = Decimal(str(position.avg_price))

        current_price = prices.get(ticker, Decimal("0"))
        value = current_price * quantity
        invested = avg_price * quantity
        profit_loss = value - invested
        # Guard the percentage against a zero (or negative) cost basis.
        if invested > 0:
            profit_loss_ratio = profit_loss / invested * 100
        else:
            profit_loss_ratio = Decimal("0")

        total_value += value
        total_invested += invested
        rows.append(HoldingWithValue(
            ticker=ticker,
            name=names.get(ticker),
            quantity=quantity,
            avg_price=Decimal(str(position.avg_price)),
            current_price=current_price,
            value=value,
            current_ratio=Decimal("0"),  # Placeholder until the total is known
            profit_loss=profit_loss,
            profit_loss_ratio=profit_loss_ratio.quantize(two_places),
        ))

    # Shares of the total can only be computed after the loop above; with a
    # zero total every ratio keeps its placeholder of 0.
    if total_value > 0:
        for row in rows:
            row.current_ratio = (row.value / total_value * 100).quantize(two_places)

    return PortfolioDetail(
        id=portfolio.id,
        user_id=portfolio.user_id,
        name=portfolio.name,
        portfolio_type=portfolio.portfolio_type.value,
        created_at=portfolio.created_at,
        updated_at=portfolio.updated_at,
        targets=portfolio.targets,
        holdings=rows,
        total_value=total_value,
        total_invested=total_invested,
        total_profit_loss=total_value - total_invested,
    )