penti/backend/tests/conftest.py

190 lines
4.8 KiB
Python
Raw Normal View History

2026-01-31 23:30:51 +09:00
"""
Pytest configuration and fixtures
"""
import os
import pytest
from datetime import date
from typing import Generator
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session
from fastapi.testclient import TestClient
from app.main import app
from app.database import Base, get_db
from app.config import get_settings
from app.models.asset import Asset
from app.models.price import PriceData
from app.models.portfolio import Portfolio, PortfolioAsset
from app.models.backtest import BacktestRun
# Connection string for the test database. The TEST_DATABASE_URL environment
# variable wins when set; otherwise the local default below is used.
TEST_DATABASE_URL = os.environ.get(
    "TEST_DATABASE_URL",
    "postgresql://postgres:postgres@localhost:5432/pension_quant_test",
)

# Engine and session factory shared by the fixtures in this module.
test_engine = create_engine(TEST_DATABASE_URL)
TestingSessionLocal = sessionmaker(
    bind=test_engine,
    autoflush=False,
    autocommit=False,
)
@pytest.fixture(scope="session", autouse=True)
def setup_test_database():
    """Create every table on ``Base.metadata`` once per session, then drop them.

    The fixture is session-scoped and autouse, so the tables are created on
    the test engine before the first test and dropped after the last one.
    """
    schema = Base.metadata
    schema.create_all(bind=test_engine)
    yield
    schema.drop_all(bind=test_engine)
@pytest.fixture(scope="function")
def db_session() -> Generator[Session, None, None]:
    """Yield a session bound to a connection whose transaction is rolled back.

    A dedicated connection is opened on the test engine, an outer transaction
    is begun on it, and a session is bound to that connection. At teardown the
    session is closed, the outer transaction is rolled back and the connection
    is closed, in that order. This is intended to discard whatever the test
    wrote through the session.

    Yields:
        The session bound to the per-test connection.
    """
    connection = test_engine.connect()
    try:
        transaction = connection.begin()
        try:
            session = TestingSessionLocal(bind=connection)
            try:
                yield session
            finally:
                session.close()
        finally:
            # Runs even if closing the session failed, so the test's writes
            # are still rolled back.
            transaction.rollback()
    finally:
        # Always release the connection, including when setup itself failed
        # part-way through (e.g. begin() raised).
        connection.close()
@pytest.fixture(scope="function")
def client(db_session: Session) -> Generator[TestClient, None, None]:
    """Yield a ``TestClient`` for ``app`` whose ``get_db`` uses ``db_session``.

    The ``get_db`` dependency is overridden so that request handlers receive
    the same session object as the test. All dependency overrides are cleared
    when the fixture finishes.

    Args:
        db_session: The per-test session handed to the application.

    Yields:
        A ``TestClient`` entered as a context manager.
    """

    def override_get_db():
        # The db_session fixture owns the session's lifecycle, so there is
        # nothing to clean up here.
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    try:
        with TestClient(app) as test_client:
            yield test_client
    finally:
        # Also runs when entering TestClient raises before the yield; pytest
        # skips teardown in that case, which would leave the override
        # installed for later tests.
        app.dependency_overrides.clear()
@pytest.fixture
def sample_assets(db_session: Session):
    """Insert three sample KOSPI assets and return them as a list.

    The assets are added to ``db_session`` and committed before being
    returned, in the order 005930, 000660, 035420.
    """
    # Attributes that are identical for every sample asset.
    shared = {
        "market": "KOSPI",
        "stock_type": "보통주",
        "is_active": True,
    }
    # Columns: ticker, name, market_cap, sector, last_price, eps, bps
    rows = [
        ("005930", "삼성전자", 400000000000000, "전기전자", 70000, 5000, 45000),
        ("000660", "SK하이닉스", 100000000000000, "전기전자", 120000, 8000, 60000),
        ("035420", "NAVER", 30000000000000, "서비스업", 200000, 10000, 80000),
    ]
    assets = [
        Asset(
            ticker=ticker,
            name=name,
            market_cap=market_cap,
            sector=sector,
            last_price=last_price,
            eps=eps,
            bps=bps,
            base_date=date(2023, 12, 31),
            **shared,
        )
        for ticker, name, market_cap, sector, last_price, eps, bps in rows
    ]
    db_session.add_all(assets)
    db_session.commit()
    return assets
@pytest.fixture
def sample_price_data(db_session: Session, sample_assets):
    """Insert 30 consecutive days of price rows for each sample asset.

    Dates start at 2023-01-01. Open, high and low are fixed multiples of the
    asset's ``last_price``; the close multiplier cycles with the day index
    modulo 5. All rows are committed and returned as a list ordered by day,
    then by asset.
    """
    from datetime import datetime, timedelta

    first_day = datetime(2023, 1, 1)
    prices = []
    for day_index in range(30):  # 30 days of data
        stamp = first_day + timedelta(days=day_index)
        # Close multiplier steps through 1.00 .. 1.04 and repeats every 5 days.
        close_factor = 1 + (day_index % 5) * 0.01
        for asset in sample_assets:
            symbol = asset.ticker
            reference = asset.last_price
            prices.append(
                PriceData(
                    ticker=symbol,
                    timestamp=stamp,
                    open=reference * 0.99,
                    high=reference * 1.02,
                    low=reference * 0.98,
                    close=reference * close_factor,
                    volume=1000000,
                )
            )
    db_session.add_all(prices)
    db_session.commit()
    return prices
@pytest.fixture
def sample_portfolio(db_session: Session, sample_assets):
    """Insert a portfolio holding the three sample tickers and return it.

    The portfolio is given three ``PortfolioAsset`` rows with target ratios
    40/30/30 for tickers 005930, 000660 and 035420 (the same tickers that
    ``sample_assets`` creates). Everything is committed and the portfolio is
    refreshed before it is returned.
    """
    portfolio = Portfolio(
        name="테스트 포트폴리오",
        description="통합 테스트용 포트폴리오",
        user_id="test_user",
    )
    db_session.add(portfolio)
    # Flush before reading portfolio.id below (the id is presumably
    # database-generated).
    db_session.flush()

    # (ticker, target_ratio) for each holding in the portfolio.
    allocations = (
        ("005930", 40.0),
        ("000660", 30.0),
        ("035420", 30.0),
    )
    holdings = [
        PortfolioAsset(
            portfolio_id=portfolio.id,
            ticker=ticker,
            target_ratio=ratio,
        )
        for ticker, ratio in allocations
    ]
    db_session.add_all(holdings)
    db_session.commit()
    db_session.refresh(portfolio)
    return portfolio