galaxis-po/backend/tests/unit/test_pension.py

292 lines
12 KiB
Python
Raw Permalink Normal View History

"""
Tests for pension account models, service, and API endpoints.
"""
import pytest
from decimal import Decimal
from unittest.mock import patch
from app.services.pension_allocation import (
calculate_current_age,
calculate_years_to_retirement,
calculate_glide_path,
calculate_allocation,
get_recommendation,
RISKY_ASSET_LIMIT_PCT,
SAFE_ASSET_MIN_PCT,
)
# --- Service unit tests ---
class TestPensionAllocationService:
    """Unit tests for the pension allocation service functions.

    Each test replaces the ``date`` name inside
    ``app.services.pension_allocation`` with a mock whose ``today().year``
    is 2026, so the age-dependent expectations below stay fixed.
    """

    @staticmethod
    def _pinned_year():
        """Return a patcher that swaps the service's ``date`` for a mock.

        The mock is configured so that ``date.today().year`` evaluates to
        2026. Use the returned object as a context manager.
        """
        return patch(
            "app.services.pension_allocation.date",
            **{"today.return_value.year": 2026},
        )

    def test_calculate_current_age(self):
        """Check the computed age for two birth years."""
        expected_ages = [(1990, 36), (1966, 60)]
        with self._pinned_year():
            for birth_year, age in expected_ages:
                assert calculate_current_age(birth_year) == age

    def test_calculate_years_to_retirement(self):
        """Check the remaining years for a target retirement age of 60."""
        cases = [
            (1990, 24),
            (1966, 0),
            (1960, 0),  # already past
        ]
        with self._pinned_year():
            for birth_year, remaining in cases:
                assert calculate_years_to_retirement(birth_year, 60) == remaining

    def test_glide_path_young_person(self):
        """Young person (30) should have high equity allocation."""
        with self._pinned_year():
            equity, bond = calculate_glide_path(1996, 60)
            assert equity + bond == Decimal("100.00")
            # young = high equity
            assert equity >= Decimal("60")

    def test_glide_path_near_retirement(self):
        """Person near retirement (55) should have low equity allocation."""
        with self._pinned_year():
            equity, bond = calculate_glide_path(1971, 60)
            assert equity + bond == Decimal("100.00")
            # near retirement = low equity
            assert equity <= Decimal("40")

    def test_glide_path_respects_regulatory_limit(self):
        """Equity allocation should never exceed 70% (regulatory limit)."""
        with self._pinned_year():
            equity = calculate_glide_path(2000, 60)[0]
            assert equity <= RISKY_ASSET_LIMIT_PCT

    def test_glide_path_minimum_equity(self):
        """Equity allocation should not go below 20% (minimum)."""
        with self._pinned_year():
            equity = calculate_glide_path(1960, 60)[0]
            assert equity >= Decimal("20")

    def test_calculate_allocation(self):
        """Check the header fields, item count, and totals of an allocation."""
        with self._pinned_year():
            plan = calculate_allocation(
                account_id=1,
                account_type="dc",
                total_amount=Decimal("10000000"),
                birth_year=1990,
                target_retirement_age=60,
            )
            assert plan.account_id == 1
            assert plan.account_type == "dc"
            assert plan.total_amount == 10000000
            assert plan.risky_limit_pct == 70
            assert plan.safe_min_pct == 30
            # 2 risky + 3 safe
            assert len(plan.allocations) == 5

            # Risky assets must stay within the limit.
            risky_share = sum(
                item.ratio
                for item in plan.allocations
                if item.asset_type == "risky"
            )
            assert risky_share <= 70

            # Individual amounts must add up to the total, allowing for
            # small rounding differences.
            allocated = sum(item.amount for item in plan.allocations)
            assert abs(allocated - 10000000) < 1

    def test_calculate_allocation_types(self):
        """Check the risky/safe split and the names of the safe assets."""
        with self._pinned_year():
            plan = calculate_allocation(
                account_id=1,
                account_type="irp",
                total_amount=Decimal("5000000"),
                birth_year=1985,
                target_retirement_age=60,
            )
            risky_items = []
            safe_items = []
            for item in plan.allocations:
                if item.asset_type == "risky":
                    risky_items.append(item)
                if item.asset_type == "safe":
                    safe_items.append(item)
            assert len(risky_items) == 2
            assert len(safe_items) == 3

            # Safe includes TDF, bond ETF, deposit
            safe_names = [item.asset_name for item in safe_items]
            for keyword in ("TDF", "채권", "예금"):
                assert any(keyword in name for name in safe_names)

    def test_get_recommendation(self):
        """Check the summary fields and categories of a recommendation."""
        with self._pinned_year():
            advice = get_recommendation(
                account_id=1,
                birth_year=1990,
                target_retirement_age=60,
            )
            assert advice.account_id == 1
            assert advice.birth_year == 1990
            assert advice.current_age == 36
            assert advice.years_to_retirement == 24
            assert len(advice.recommendations) == 5

            categories = [entry.category for entry in advice.recommendations]
            for expected in ("tdf", "bond_etf", "deposit", "equity_etf"):
                assert expected in categories

            # All recommendations have reasons
            for entry in advice.recommendations:
                assert entry.reason
# --- API endpoint tests ---
class TestPensionAPI:
    """Endpoint tests for the ``/api/pension/accounts`` routes.

    ``client``, ``auth_headers`` and ``db`` are pytest fixtures that are not
    defined in this file (presumably supplied by a conftest -- confirm).

    Setup requests are asserted as well as the request under test, so a
    failing setup step cannot make a later assertion pass vacuously.
    """

    def _create_account(self, client, auth_headers, **overrides):
        """POST a pension account and return the raw response.

        Args:
            client: HTTP test client exposing ``post``.
            auth_headers: Headers sent with the request.
            **overrides: Payload fields that replace the defaults below.

        Returns:
            The response object returned by ``client.post``.
        """
        payload = {
            "account_type": "dc",
            "account_name": "삼성생명 DC",
            "total_amount": 10000000,
            "birth_year": 1990,
            "target_retirement_age": 60,
            **overrides,
        }
        return client.post("/api/pension/accounts", headers=auth_headers, json=payload)

    def _create_account_id(self, client, auth_headers, **overrides):
        """Create an account, require a 201 response, and return its id."""
        response = self._create_account(client, auth_headers, **overrides)
        assert response.status_code == 201
        return response.json()["id"]

    def test_create_account(self, client, auth_headers):
        """Creating an account echoes the payload and starts with no holdings."""
        response = self._create_account(client, auth_headers)
        assert response.status_code == 201
        data = response.json()
        assert data["account_type"] == "dc"
        assert data["account_name"] == "삼성생명 DC"
        assert data["total_amount"] == 10000000
        assert data["birth_year"] == 1990
        assert data["target_retirement_age"] == 60
        assert data["holdings"] == []

    def test_create_account_irp(self, client, auth_headers):
        """An ``irp`` account type is accepted and returned."""
        response = self._create_account(
            client, auth_headers,
            account_type="irp",
            account_name="NH IRP",
        )
        assert response.status_code == 201
        assert response.json()["account_type"] == "irp"

    def test_list_accounts(self, client, auth_headers):
        """Listing returns every account created by the current user."""
        # Both creations must succeed, otherwise the count below is
        # meaningless.
        self._create_account_id(client, auth_headers, account_name="DC 1호")
        self._create_account_id(
            client, auth_headers, account_name="IRP", account_type="irp"
        )
        response = client.get("/api/pension/accounts", headers=auth_headers)
        assert response.status_code == 200
        data = response.json()
        assert len(data) == 2

    def test_get_account(self, client, auth_headers):
        """Fetching by id returns the matching account."""
        account_id = self._create_account_id(client, auth_headers)
        response = client.get(f"/api/pension/accounts/{account_id}", headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["id"] == account_id

    def test_get_account_not_found(self, client, auth_headers):
        """Fetching an id that was never created yields 404."""
        response = client.get("/api/pension/accounts/9999", headers=auth_headers)
        assert response.status_code == 404

    def test_update_account(self, client, auth_headers):
        """PUT changes the submitted fields and returns the updated account."""
        account_id = self._create_account_id(client, auth_headers)
        response = client.put(
            f"/api/pension/accounts/{account_id}",
            headers=auth_headers,
            json={"account_name": "변경된 계좌명", "total_amount": 20000000},
        )
        assert response.status_code == 200
        data = response.json()
        assert data["account_name"] == "변경된 계좌명"
        assert data["total_amount"] == 20000000

    def test_allocate_assets(self, client, auth_headers):
        """Allocation returns five items and stores them as holdings."""
        account_id = self._create_account_id(client, auth_headers)
        response = client.post(
            f"/api/pension/accounts/{account_id}/allocate",
            headers=auth_headers,
        )
        assert response.status_code == 200
        data = response.json()
        assert data["account_id"] == account_id
        assert data["risky_limit_pct"] == 70
        assert data["safe_min_pct"] == 30
        assert len(data["allocations"]) == 5
        # Verify risky ratio <= 70%
        risky_ratio = sum(a["ratio"] for a in data["allocations"] if a["asset_type"] == "risky")
        assert risky_ratio <= 70
        # Verify holdings were saved
        account_resp = client.get(f"/api/pension/accounts/{account_id}", headers=auth_headers)
        assert account_resp.status_code == 200
        assert len(account_resp.json()["holdings"]) == 5

    def test_allocate_replaces_previous_holdings(self, client, auth_headers):
        """Allocating twice leaves five holdings, not ten."""
        account_id = self._create_account_id(client, auth_headers)
        allocate_url = f"/api/pension/accounts/{account_id}/allocate"
        # Allocate twice; each call must succeed, otherwise "still five
        # holdings" would not prove that the second call replaced the first.
        for _ in range(2):
            allocate_resp = client.post(allocate_url, headers=auth_headers)
            assert allocate_resp.status_code == 200
        # Should still have 5 holdings (replaced, not duplicated)
        account_resp = client.get(f"/api/pension/accounts/{account_id}", headers=auth_headers)
        assert account_resp.status_code == 200
        assert len(account_resp.json()["holdings"]) == 5

    def test_get_recommendation(self, client, auth_headers):
        """The recommendation endpoint returns five fully populated entries."""
        account_id = self._create_account_id(client, auth_headers)
        response = client.get(
            f"/api/pension/accounts/{account_id}/recommendation",
            headers=auth_headers,
        )
        assert response.status_code == 200
        data = response.json()
        assert data["account_id"] == account_id
        assert data["birth_year"] == 1990
        assert len(data["recommendations"]) == 5
        # Verify recommendation categories
        categories = [r["category"] for r in data["recommendations"]]
        assert "tdf" in categories
        assert "bond_etf" in categories
        assert "deposit" in categories
        assert "equity_etf" in categories
        # All recommendations have reasons
        for rec in data["recommendations"]:
            assert rec["reason"]
            assert rec["asset_name"]

    def test_unauthenticated_access(self, client):
        """A request without auth headers is rejected with 401."""
        response = client.get("/api/pension/accounts")
        assert response.status_code == 401

    def test_user_isolation(self, client, auth_headers, db):
        """User can only see their own pension accounts."""
        from app.models.user import User
        from app.core.security import get_password_hash, create_access_token

        other_user = User(
            username="otheruser",
            email="other@example.com",
            hashed_password=get_password_hash("password"),
        )
        db.add(other_user)
        db.commit()
        db.refresh(other_user)
        other_token = create_access_token(data={"sub": other_user.username})
        other_headers = {"Authorization": f"Bearer {other_token}"}

        # The other user's account must really exist; if this request were
        # rejected, the empty list below would pass without testing isolation.
        other_resp = self._create_account(
            client, other_headers, account_name="다른 사용자 계좌"
        )
        assert other_resp.status_code == 201

        # Current user should see 0 accounts
        response = client.get("/api/pension/accounts", headers=auth_headers)
        assert response.status_code == 200
        assert len(response.json()) == 0