190 lines
4.8 KiB
Python
190 lines
4.8 KiB
Python
"""
|
|
Pytest configuration and fixtures
|
|
"""
|
|
import os
|
|
import pytest
|
|
from datetime import date
|
|
from typing import Generator
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.orm import sessionmaker, Session
|
|
from fastapi.testclient import TestClient
|
|
|
|
from app.main import app
|
|
from app.database import Base, get_db
|
|
from app.config import get_settings
|
|
from app.models.asset import Asset
|
|
from app.models.price import PriceData
|
|
from app.models.portfolio import Portfolio, PortfolioAsset
|
|
from app.models.backtest import BacktestRun
|
|
|
|
|
|
# Connection string for the dedicated test database. Set the
# TEST_DATABASE_URL environment variable to override it; otherwise a local
# PostgreSQL database named "pension_quant_test" is used.
TEST_DATABASE_URL = os.environ.get(
    "TEST_DATABASE_URL",
    "postgresql://postgres:postgres@localhost:5432/pension_quant_test",
)

# Engine and session factory shared by all fixtures in this module.
# Sessions are created with autoflush disabled, so pending objects are only
# written when flush()/commit() is called explicitly.
test_engine = create_engine(TEST_DATABASE_URL)
TestingSessionLocal = sessionmaker(
    bind=test_engine,
    autoflush=False,
    autocommit=False,
)
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
def setup_test_database() -> Generator[None, None, None]:
    """Build a fresh schema in the test database for the whole test session.

    Tables are dropped before they are created. ``create_all`` skips tables
    that already exist, so tables left behind by an earlier run (e.g. one
    killed before this fixture's teardown) would otherwise keep an outdated
    schema and stale rows. All tables are dropped again once the test
    session finishes.
    """
    Base.metadata.drop_all(bind=test_engine)
    Base.metadata.create_all(bind=test_engine)
    yield
    Base.metadata.drop_all(bind=test_engine)
|
|
|
|
|
|
@pytest.fixture(scope="function")
def db_session() -> Generator[Session, None, None]:
    """Yield a session bound to a connection with an open outer transaction.

    The outer transaction is rolled back in teardown. The intent is to
    discard everything the test (and fixtures built on this one) wrote.
    NOTE(review): this relies on ``session.commit()`` not committing the
    outer connection-level transaction (SQLAlchemy's "join a Session into an
    external transaction" pattern); confirm against the installed
    SQLAlchemy version.

    Teardown is exception-safe: the ``with`` block always closes the
    connection, even if closing the session or the rollback raises. The
    rollback is skipped when the outer transaction is no longer active,
    e.g. because code under test already rolled it back.
    """
    with test_engine.connect() as connection:
        transaction = connection.begin()
        session = TestingSessionLocal(bind=connection)
        try:
            yield session
        finally:
            session.close()
            # Guard against rolling back a transaction that is already gone.
            if transaction.is_active:
                transaction.rollback()
|
|
|
|
|
|
@pytest.fixture(scope="function")
def client(db_session: Session) -> Generator[TestClient, None, None]:
    """Yield a FastAPI ``TestClient`` whose requests use ``db_session``.

    The app's ``get_db`` dependency is overridden so every request handled
    during the test receives the test's session. The contents of
    ``app.dependency_overrides`` are snapshotted first and restored in
    ``finally``. This means:

    - cleanup also happens when entering or leaving the ``TestClient``
      context raises, and
    - overrides added during the test are discarded, while any that existed
      before this fixture ran are preserved.
    """

    def override_get_db() -> Generator[Session, None, None]:
        # Hand out the shared test session. Its lifecycle (close/rollback)
        # is owned by the db_session fixture, not by the request.
        yield db_session

    saved_overrides = dict(app.dependency_overrides)
    app.dependency_overrides[get_db] = override_get_db
    try:
        with TestClient(app) as test_client:
            yield test_client
    finally:
        # Mutate in place: the app holds a reference to this exact dict.
        app.dependency_overrides.clear()
        app.dependency_overrides.update(saved_overrides)
|
|
|
|
|
|
@pytest.fixture
def sample_assets(db_session: Session):
    """Insert and commit three sample ``Asset`` rows, returned in order.

    Tickers are 005930, 000660 and 035420. All are on the KOSPI market with
    stock_type "보통주" (common stock), a base date of 2023-12-31, and are
    marked active.
    """
    # Fields identical for every sample asset.
    common = dict(
        market="KOSPI",
        stock_type="보통주",
        base_date=date(2023, 12, 31),
        is_active=True,
    )
    # Fields that differ per asset.
    specifics = (
        dict(
            ticker="005930",
            name="삼성전자",
            market_cap=400_000_000_000_000,
            sector="전기전자",
            last_price=70_000,
            eps=5_000,
            bps=45_000,
        ),
        dict(
            ticker="000660",
            name="SK하이닉스",
            market_cap=100_000_000_000_000,
            sector="전기전자",
            last_price=120_000,
            eps=8_000,
            bps=60_000,
        ),
        dict(
            ticker="035420",
            name="NAVER",
            market_cap=30_000_000_000_000,
            sector="서비스업",
            last_price=200_000,
            eps=10_000,
            bps=80_000,
        ),
    )

    created = [Asset(**common, **fields) for fields in specifics]
    db_session.add_all(created)
    db_session.commit()
    return created
|
|
|
|
|
|
@pytest.fixture
def sample_price_data(db_session: Session, sample_assets):
    """Insert 30 consecutive daily ``PriceData`` rows for each sample asset.

    Timestamps start at 2023-01-01 00:00. Within each day, assets follow
    the order of ``sample_assets``. All prices derive from the asset's
    ``last_price``:

    - open is -1%, high is +2% and low is -2%;
    - close cycles through +0%, +1%, +2%, +3%, +4% on a five-day period;
    - volume is a constant 1,000,000.

    NOTE(review): on days where the close is +3% or +4%, the close exceeds
    the high (+2%). Confirm that no test relies on OHLC consistency.

    Returns the created rows in insertion order (day-major, then asset).
    """
    from datetime import datetime, timedelta

    start = datetime(2023, 1, 1)
    rows = [
        PriceData(
            ticker=asset.ticker,
            timestamp=start + timedelta(days=day),
            open=asset.last_price * 0.99,
            high=asset.last_price * 1.02,
            low=asset.last_price * 0.98,
            close=asset.last_price * (1 + (day % 5) * 0.01),
            volume=1_000_000,
        )
        for day in range(30)
        for asset in sample_assets
    ]
    db_session.add_all(rows)
    db_session.commit()
    return rows
|
|
|
|
|
|
@pytest.fixture
def sample_portfolio(db_session: Session, sample_assets):
    """Create, commit and return a portfolio holding the three sample tickers.

    Target ratios are 40.0 / 30.0 / 30.0 for 005930 / 000660 / 035420,
    which sum to 100. ``sample_assets`` is requested, presumably so the
    referenced tickers exist before the portfolio rows are inserted.
    TODO confirm whether a foreign key requires this.

    The portfolio is refreshed after the commit before it is returned.
    """
    portfolio = Portfolio(
        name="테스트 포트폴리오",
        description="통합 테스트용 포트폴리오",
        user_id="test_user",
    )
    db_session.add(portfolio)
    # Flush so portfolio.id is populated before the child rows reference it.
    db_session.flush()

    # Insertion order: 005930, 000660, 035420.
    ratio_by_ticker = {"005930": 40.0, "000660": 30.0, "035420": 30.0}
    holdings = [
        PortfolioAsset(portfolio_id=portfolio.id, ticker=ticker, target_ratio=ratio)
        for ticker, ratio in ratio_by_ticker.items()
    ]
    db_session.add_all(holdings)

    db_session.commit()
    db_session.refresh(portfolio)
    return portfolio
|