diff --git a/backend/app/api/portfolio.py b/backend/app/api/portfolio.py index f024de2..b5f28a8 100644 --- a/backend/app/api/portfolio.py +++ b/backend/app/api/portfolio.py @@ -16,6 +16,7 @@ from app.schemas.portfolio import ( HoldingCreate, HoldingResponse, HoldingWithValue, TransactionCreate, TransactionResponse, RebalanceResponse, RebalanceSimulationRequest, RebalanceSimulationResponse, + RebalanceCalculateRequest, RebalanceCalculateResponse, ) from app.services.rebalance import RebalanceService @@ -319,6 +320,31 @@ async def simulate_rebalance( return service.calculate_rebalance(portfolio, additional_amount=data.additional_amount) +@router.post("/{portfolio_id}/rebalance/calculate", response_model=RebalanceCalculateResponse) +async def calculate_rebalance_manual( + portfolio_id: int, + data: RebalanceCalculateRequest, + current_user: CurrentUser, + db: Session = Depends(get_db), +): + """Calculate rebalancing with manual prices and strategy selection.""" + portfolio = _get_portfolio(db, portfolio_id, current_user.id) + + if data.strategy == "additional_buy" and not data.additional_amount: + raise HTTPException( + status_code=400, + detail="additional_amount is required for additional_buy strategy" + ) + + service = RebalanceService(db) + return service.calculate_with_prices( + portfolio, + strategy=data.strategy, + manual_prices=data.prices, + additional_amount=data.additional_amount, + ) + + @router.get("/{portfolio_id}/detail", response_model=PortfolioDetail) async def get_portfolio_detail( portfolio_id: int, diff --git a/backend/tests/e2e/test_rebalance_flow.py b/backend/tests/e2e/test_rebalance_flow.py index 95ad9de..af1c1aa 100644 --- a/backend/tests/e2e/test_rebalance_flow.py +++ b/backend/tests/e2e/test_rebalance_flow.py @@ -51,7 +51,7 @@ def test_calculate_rebalance_with_manual_prices(client: TestClient, auth_headers assert response.status_code == 200 data = response.json() assert data["portfolio_id"] == pid - assert data["total_assets"] > 0 + assert float(data["total_assets"]) > 0 assert len(data["items"]) == 2 # Verify items have required fields item = data["items"][0] @@ -78,8 +78,8 @@ def test_calculate_additional_buy_strategy(client: TestClient, auth_headers): ) assert response.status_code == 200 data = response.json() - assert data["total_assets"] > 0 - assert data["available_to_buy"] == 1000000 + assert float(data["total_assets"]) > 0 + assert float(data["available_to_buy"]) == 1000000 # Additional buy should never have "sell" actions for item in data["items"]: assert item["action"] in ("buy", "hold")