""" Tests for pension account models, service, and API endpoints. """ import pytest from decimal import Decimal from unittest.mock import patch from app.services.pension_allocation import ( calculate_current_age, calculate_years_to_retirement, calculate_glide_path, calculate_allocation, get_recommendation, RISKY_ASSET_LIMIT_PCT, SAFE_ASSET_MIN_PCT, ) # --- Service unit tests --- class TestPensionAllocationService: def test_calculate_current_age(self): with patch("app.services.pension_allocation.date") as mock_date: mock_date.today.return_value.year = 2026 assert calculate_current_age(1990) == 36 assert calculate_current_age(1966) == 60 def test_calculate_years_to_retirement(self): with patch("app.services.pension_allocation.date") as mock_date: mock_date.today.return_value.year = 2026 assert calculate_years_to_retirement(1990, 60) == 24 assert calculate_years_to_retirement(1966, 60) == 0 assert calculate_years_to_retirement(1960, 60) == 0 # already past def test_glide_path_young_person(self): """Young person (30) should have high equity allocation.""" with patch("app.services.pension_allocation.date") as mock_date: mock_date.today.return_value.year = 2026 equity_pct, bond_pct = calculate_glide_path(1996, 60) assert equity_pct + bond_pct == Decimal("100.00") assert equity_pct >= Decimal("60") # young = high equity def test_glide_path_near_retirement(self): """Person near retirement (55) should have low equity allocation.""" with patch("app.services.pension_allocation.date") as mock_date: mock_date.today.return_value.year = 2026 equity_pct, bond_pct = calculate_glide_path(1971, 60) assert equity_pct + bond_pct == Decimal("100.00") assert equity_pct <= Decimal("40") # near retirement = low equity def test_glide_path_respects_regulatory_limit(self): """Equity allocation should never exceed 70% (regulatory limit).""" with patch("app.services.pension_allocation.date") as mock_date: mock_date.today.return_value.year = 2026 equity_pct, _ = calculate_glide_path(2000, 60) assert equity_pct <= RISKY_ASSET_LIMIT_PCT def test_glide_path_minimum_equity(self): """Equity allocation should not go below 20% (minimum).""" with patch("app.services.pension_allocation.date") as mock_date: mock_date.today.return_value.year = 2026 equity_pct, _ = calculate_glide_path(1960, 60) assert equity_pct >= Decimal("20") def test_calculate_allocation(self): with patch("app.services.pension_allocation.date") as mock_date: mock_date.today.return_value.year = 2026 result = calculate_allocation( account_id=1, account_type="dc", total_amount=Decimal("10000000"), birth_year=1990, target_retirement_age=60, ) assert result.account_id == 1 assert result.account_type == "dc" assert result.total_amount == 10000000 assert result.risky_limit_pct == 70 assert result.safe_min_pct == 30 assert len(result.allocations) == 5 # 2 risky + 3 safe # Verify risky assets don't exceed limit risky_ratio = sum(a.ratio for a in result.allocations if a.asset_type == "risky") assert risky_ratio <= 70 # Verify total amounts sum to total_amount total_allocated = sum(a.amount for a in result.allocations) assert abs(total_allocated - 10000000) < 1 # allow small rounding def test_calculate_allocation_types(self): with patch("app.services.pension_allocation.date") as mock_date: mock_date.today.return_value.year = 2026 result = calculate_allocation( account_id=1, account_type="irp", total_amount=Decimal("5000000"), birth_year=1985, target_retirement_age=60, ) risky = [a for a in result.allocations if a.asset_type == "risky"] safe = [a for a in result.allocations if a.asset_type == "safe"] assert len(risky) == 2 assert len(safe) == 3 # Safe includes TDF, bond ETF, deposit safe_names = [a.asset_name for a in safe] assert any("TDF" in n for n in safe_names) assert any("채권" in n for n in safe_names) assert any("예금" in n for n in safe_names) def test_get_recommendation(self): with patch("app.services.pension_allocation.date") as mock_date: mock_date.today.return_value.year = 2026 result = get_recommendation( account_id=1, birth_year=1990, target_retirement_age=60, ) assert result.account_id == 1 assert result.birth_year == 1990 assert result.current_age == 36 assert result.years_to_retirement == 24 assert len(result.recommendations) == 5 categories = [r.category for r in result.recommendations] assert "tdf" in categories assert "bond_etf" in categories assert "deposit" in categories assert "equity_etf" in categories # All recommendations have reasons for rec in result.recommendations: assert rec.reason # --- API endpoint tests --- class TestPensionAPI: def _create_account(self, client, auth_headers, **overrides): payload = { "account_type": "dc", "account_name": "삼성생명 DC", "total_amount": 10000000, "birth_year": 1990, "target_retirement_age": 60, **overrides, } return client.post("/api/pension/accounts", headers=auth_headers, json=payload) def test_create_account(self, client, auth_headers): response = self._create_account(client, auth_headers) assert response.status_code == 201 data = response.json() assert data["account_type"] == "dc" assert data["account_name"] == "삼성생명 DC" assert data["total_amount"] == 10000000 assert data["birth_year"] == 1990 assert data["target_retirement_age"] == 60 assert data["holdings"] == [] def test_create_account_irp(self, client, auth_headers): response = self._create_account( client, auth_headers, account_type="irp", account_name="NH IRP", ) assert response.status_code == 201 assert response.json()["account_type"] == "irp" def test_list_accounts(self, client, auth_headers): self._create_account(client, auth_headers, account_name="DC 1호") self._create_account(client, auth_headers, account_name="IRP", account_type="irp") response = client.get("/api/pension/accounts", headers=auth_headers) assert response.status_code == 200 data = response.json() assert len(data) == 2 def test_get_account(self, client, auth_headers): create_resp = self._create_account(client, auth_headers) account_id = create_resp.json()["id"] response = client.get(f"/api/pension/accounts/{account_id}", headers=auth_headers) assert response.status_code == 200 assert response.json()["id"] == account_id def test_get_account_not_found(self, client, auth_headers): response = client.get("/api/pension/accounts/9999", headers=auth_headers) assert response.status_code == 404 def test_update_account(self, client, auth_headers): create_resp = self._create_account(client, auth_headers) account_id = create_resp.json()["id"] response = client.put( f"/api/pension/accounts/{account_id}", headers=auth_headers, json={"account_name": "변경된 계좌명", "total_amount": 20000000}, ) assert response.status_code == 200 data = response.json() assert data["account_name"] == "변경된 계좌명" assert data["total_amount"] == 20000000 def test_allocate_assets(self, client, auth_headers): create_resp = self._create_account(client, auth_headers) account_id = create_resp.json()["id"] response = client.post( f"/api/pension/accounts/{account_id}/allocate", headers=auth_headers, ) assert response.status_code == 200 data = response.json() assert data["account_id"] == account_id assert data["risky_limit_pct"] == 70 assert data["safe_min_pct"] == 30 assert len(data["allocations"]) == 5 # Verify risky ratio <= 70% risky_ratio = sum(a["ratio"] for a in data["allocations"] if a["asset_type"] == "risky") assert risky_ratio <= 70 # Verify holdings were saved account_resp = client.get(f"/api/pension/accounts/{account_id}", headers=auth_headers) assert len(account_resp.json()["holdings"]) == 5 def test_allocate_replaces_previous_holdings(self, client, auth_headers): create_resp = self._create_account(client, auth_headers) account_id = create_resp.json()["id"] # Allocate twice client.post(f"/api/pension/accounts/{account_id}/allocate", headers=auth_headers) client.post(f"/api/pension/accounts/{account_id}/allocate", headers=auth_headers) # Should still have 5 holdings (replaced, not duplicated) account_resp = client.get(f"/api/pension/accounts/{account_id}", headers=auth_headers) assert len(account_resp.json()["holdings"]) == 5 def test_get_recommendation(self, client, auth_headers): create_resp = self._create_account(client, auth_headers) account_id = create_resp.json()["id"] response = client.get( f"/api/pension/accounts/{account_id}/recommendation", headers=auth_headers, ) assert response.status_code == 200 data = response.json() assert data["account_id"] == account_id assert data["birth_year"] == 1990 assert len(data["recommendations"]) == 5 # Verify recommendation categories categories = [r["category"] for r in data["recommendations"]] assert "tdf" in categories assert "bond_etf" in categories assert "deposit" in categories assert "equity_etf" in categories # All recommendations have reasons for rec in data["recommendations"]: assert rec["reason"] assert rec["asset_name"] def test_unauthenticated_access(self, client): response = client.get("/api/pension/accounts") assert response.status_code == 401 def test_user_isolation(self, client, auth_headers, db): """User can only see their own pension accounts.""" from app.models.user import User from app.core.security import get_password_hash, create_access_token other_user = User( username="otheruser", email="other@example.com", hashed_password=get_password_hash("password"), ) db.add(other_user) db.commit() db.refresh(other_user) other_token = create_access_token(data={"sub": other_user.username}) other_headers = {"Authorization": f"Bearer {other_token}"} self._create_account(client, other_headers, account_name="다른 사용자 계좌") # Current user should see 0 accounts response = client.get("/api/pension/accounts", headers=auth_headers) assert response.status_code == 200 assert len(response.json()) == 0