penti/scripts/migrate_mysql_to_postgres.py

411 lines
15 KiB
Python

"""MySQL to PostgreSQL data migration script."""
import sys
import os
from datetime import datetime
# Add parent directory to path
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
import pandas as pd
import pymysql
from sqlalchemy import create_engine, Column, String, BigInteger, Numeric, Date, Boolean, DateTime, PrimaryKeyConstraint
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.dialects.postgresql import UUID, insert
import uuid
from tqdm import tqdm
from dotenv import load_dotenv
# Load environment variables
load_dotenv()
# Create base
Base = declarative_base()
# Define models directly
class Asset(Base):
"""Asset model."""
__tablename__ = "assets"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
ticker = Column(String(20), unique=True, nullable=False, index=True)
name = Column(String(100), nullable=False)
market = Column(String(20))
market_cap = Column(BigInteger)
stock_type = Column(String(20))
sector = Column(String(100))
last_price = Column(Numeric(15, 2))
eps = Column(Numeric(15, 2))
bps = Column(Numeric(15, 2))
dividend_per_share = Column(Numeric(15, 2))
base_date = Column(Date)
is_active = Column(Boolean, default=True)
class PriceData(Base):
"""Price data model."""
__tablename__ = "price_data"
ticker = Column(String(20), nullable=False, index=True)
timestamp = Column(DateTime, nullable=False, index=True)
open = Column(Numeric(15, 2))
high = Column(Numeric(15, 2))
low = Column(Numeric(15, 2))
close = Column(Numeric(15, 2), nullable=False)
volume = Column(BigInteger)
__table_args__ = (
PrimaryKeyConstraint('ticker', 'timestamp'),
)
class FinancialStatement(Base):
"""Financial statement model."""
__tablename__ = "financial_statements"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
ticker = Column(String(20), nullable=False, index=True)
account = Column(String(100), nullable=False)
base_date = Column(Date, nullable=False, index=True)
value = Column(Numeric(20, 2))
disclosure_type = Column(String(1))
__table_args__ = (
# Unique constraint for upsert
{'extend_existing': True}
)
# Get PostgreSQL connection from environment
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://pension_user:pension_password@localhost:5432/pension_quant")
# Create PostgreSQL engine and session
pg_engine = create_engine(DATABASE_URL, pool_pre_ping=True)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=pg_engine)
class MySQLToPostgreSQLMigrator:
"""MySQL to PostgreSQL 마이그레이터."""
def __init__(
self,
mysql_host: str,
mysql_user: str,
mysql_password: str,
mysql_database: str,
mysql_port: int = 3306
):
"""
초기화.
Args:
mysql_host: MySQL 호스트
mysql_user: MySQL 사용자
mysql_password: MySQL 비밀번호
mysql_database: MySQL 데이터베이스
mysql_port: MySQL 포트 (기본값: 3306)
"""
self.mysql_conn = pymysql.connect(
host=mysql_host,
port=mysql_port,
user=mysql_user,
password=mysql_password,
database=mysql_database
)
self.pg_session = SessionLocal()
# PostgreSQL 테이블 생성 (없는 경우)
print("PostgreSQL 테이블 확인 및 생성 중...")
Base.metadata.create_all(bind=pg_engine)
print("테이블 준비 완료")
def migrate_ticker_data(self):
"""
kor_ticker → assets 마이그레이션.
"""
print("\n=== 종목 데이터 마이그레이션 시작 ===")
# MySQL에서 데이터 읽기 (종목별 가장 최신 기준일 데이터만)
query = """
SELECT t1.*
FROM kor_ticker t1
INNER JOIN (
SELECT 종목코드, MAX(기준일) as max_date
FROM kor_ticker
GROUP BY 종목코드
) t2 ON t1.종목코드 = t2.종목코드 AND t1.기준일 = t2.max_date
"""
df = pd.read_sql(query, self.mysql_conn)
# DataFrame에서도 중복 제거 (혹시 모를 중복 방지)
df = df.drop_duplicates(subset=['종목코드'], keep='last')
print(f"MySQL에서 {len(df)}개 종목 데이터 읽기 완료 (중복 제거됨)")
# PostgreSQL에 저장 (UPSERT 사용)
success_count = 0
for _, row in tqdm(df.iterrows(), total=len(df), desc="종목 데이터 저장"):
try:
# UPSERT statement 생성
stmt = insert(Asset).values(
id=uuid.uuid4(),
ticker=row['종목코드'],
name=row['종목명'],
market=row['시장구분'],
last_price=row['종가'] if pd.notna(row['종가']) else None,
market_cap=row['시가총액'] if pd.notna(row['시가총액']) else None,
eps=row['EPS'] if pd.notna(row['EPS']) else None,
bps=row['BPS'] if pd.notna(row['BPS']) else None,
dividend_per_share=row['주당배당금'] if pd.notna(row['주당배당금']) else None,
stock_type=row['종목구분'] if pd.notna(row['종목구분']) else None,
base_date=row['기준일'] if pd.notna(row['기준일']) else None,
is_active=True
)
# ON CONFLICT DO UPDATE
stmt = stmt.on_conflict_do_update(
index_elements=['ticker'],
set_={
'name': row['종목명'],
'market': row['시장구분'],
'last_price': row['종가'] if pd.notna(row['종가']) else None,
'market_cap': row['시가총액'] if pd.notna(row['시가총액']) else None,
'eps': row['EPS'] if pd.notna(row['EPS']) else None,
'bps': row['BPS'] if pd.notna(row['BPS']) else None,
'dividend_per_share': row['주당배당금'] if pd.notna(row['주당배당금']) else None,
'stock_type': row['종목구분'] if pd.notna(row['종목구분']) else None,
'base_date': row['기준일'] if pd.notna(row['기준일']) else None,
'is_active': True
}
)
self.pg_session.execute(stmt)
success_count += 1
# 100개마다 커밋
if success_count % 100 == 0:
self.pg_session.commit()
except Exception as e:
print(f"\n종목 {row['종목코드']} 저장 오류: {e}")
self.pg_session.rollback()
continue
# 최종 커밋
self.pg_session.commit()
print(f"\n종목 데이터 마이그레이션 완료: {success_count}")
def migrate_price_data(self, limit: int = None):
"""
kor_price → price_data 마이그레이션.
Args:
limit: 제한 레코드 수 (테스트용, None이면 전체)
"""
print("\n=== 주가 데이터 마이그레이션 시작 ===")
# 전체 레코드 수 조회
count_query = "SELECT COUNT(*) as count FROM kor_price"
total_count = pd.read_sql(count_query, self.mysql_conn)['count'][0]
print(f"전체 주가 레코드 수: {total_count:,}")
if limit:
print(f"제한: {limit:,}개만 마이그레이션")
total_count = min(total_count, limit)
# 배치 처리 (메모리 절약)
batch_size = 10000
success_count = 0
for offset in range(0, total_count, batch_size):
query = f"SELECT * FROM kor_price LIMIT {batch_size} OFFSET {offset}"
df = pd.read_sql(query, self.mysql_conn)
print(f"\n배치 {offset//batch_size + 1}: {len(df)}개 레코드 처리 중...")
for _, row in tqdm(df.iterrows(), total=len(df), desc="주가 데이터 저장"):
try:
# UPSERT statement 생성
stmt = insert(PriceData).values(
ticker=row['종목코드'],
timestamp=row['날짜'],
open=row['시가'] if pd.notna(row['시가']) else None,
high=row['고가'] if pd.notna(row['고가']) else None,
low=row['저가'] if pd.notna(row['저가']) else None,
close=row['종가'],
volume=int(row['거래량']) if pd.notna(row['거래량']) else None
)
# ON CONFLICT DO UPDATE (복합 키: ticker, timestamp)
stmt = stmt.on_conflict_do_update(
index_elements=['ticker', 'timestamp'],
set_={
'open': row['시가'] if pd.notna(row['시가']) else None,
'high': row['고가'] if pd.notna(row['고가']) else None,
'low': row['저가'] if pd.notna(row['저가']) else None,
'close': row['종가'],
'volume': int(row['거래량']) if pd.notna(row['거래량']) else None
}
)
self.pg_session.execute(stmt)
success_count += 1
# 1000개마다 커밋
if success_count % 1000 == 0:
self.pg_session.commit()
except Exception as e:
print(f"\n주가 데이터 저장 오류: {e}")
self.pg_session.rollback()
continue
# 배치 커밋
self.pg_session.commit()
print(f"\n주가 데이터 마이그레이션 완료: {success_count:,}")
def migrate_financial_data(self, limit: int = None):
"""
kor_fs → financial_statements 마이그레이션.
Args:
limit: 제한 레코드 수 (테스트용, None이면 전체)
"""
print("\n=== 재무제표 데이터 마이그레이션 시작 ===")
# 전체 레코드 수 조회
count_query = "SELECT COUNT(*) as count FROM kor_fs"
total_count = pd.read_sql(count_query, self.mysql_conn)['count'][0]
print(f"전체 재무제표 레코드 수: {total_count:,}")
if limit:
print(f"제한: {limit:,}개만 마이그레이션")
total_count = min(total_count, limit)
# 배치 처리
batch_size = 10000
success_count = 0
for offset in range(0, total_count, batch_size):
query = f"SELECT * FROM kor_fs LIMIT {batch_size} OFFSET {offset}"
df = pd.read_sql(query, self.mysql_conn)
print(f"\n배치 {offset//batch_size + 1}: {len(df)}개 레코드 처리 중...")
for _, row in tqdm(df.iterrows(), total=len(df), desc="재무제표 데이터 저장"):
try:
# 기존 레코드 확인
existing = self.pg_session.query(FinancialStatement).filter(
FinancialStatement.ticker == row['종목코드'],
FinancialStatement.account == row['계정'],
FinancialStatement.base_date == row['기준일'],
FinancialStatement.disclosure_type == row['공시구분']
).first()
if existing:
# 업데이트
existing.value = row[''] if pd.notna(row['']) else None
else:
# 신규 삽입
fs = FinancialStatement(
ticker=row['종목코드'],
account=row['계정'],
base_date=row['기준일'],
value=row[''] if pd.notna(row['']) else None,
disclosure_type=row['공시구분']
)
self.pg_session.add(fs)
success_count += 1
# 1000개마다 커밋
if success_count % 1000 == 0:
self.pg_session.commit()
except Exception as e:
print(f"\n재무제표 데이터 저장 오류: {e}")
self.pg_session.rollback()
continue
# 배치 커밋
self.pg_session.commit()
print(f"\n재무제표 데이터 마이그레이션 완료: {success_count:,}")
def migrate_all(self, price_limit: int = None, fs_limit: int = None):
"""
전체 데이터 마이그레이션.
Args:
price_limit: 주가 데이터 제한
fs_limit: 재무제표 데이터 제한
"""
start_time = datetime.now()
print(f"\n{'='*60}")
print(f"MySQL → PostgreSQL 데이터 마이그레이션 시작")
print(f"시작 시간: {start_time}")
print(f"{'='*60}")
try:
# 1. 종목 데이터
self.migrate_ticker_data()
# 2. 주가 데이터
self.migrate_price_data(limit=price_limit)
# 3. 재무제표 데이터
self.migrate_financial_data(limit=fs_limit)
end_time = datetime.now()
duration = end_time - start_time
print(f"\n{'='*60}")
print(f"마이그레이션 완료!")
print(f"종료 시간: {end_time}")
print(f"소요 시간: {duration}")
print(f"{'='*60}")
except Exception as e:
print(f"\n마이그레이션 오류: {e}")
raise
finally:
self.close()
def close(self):
"""연결 종료."""
self.mysql_conn.close()
self.pg_session.close()
def main():
"""메인 함수."""
import argparse
parser = argparse.ArgumentParser(description='MySQL to PostgreSQL 데이터 마이그레이션')
parser.add_argument('--mysql-host', required=True, help='MySQL 호스트')
parser.add_argument('--mysql-port', type=int, default=3306, help='MySQL 포트 (기본값: 3306)')
parser.add_argument('--mysql-user', required=True, help='MySQL 사용자')
parser.add_argument('--mysql-password', required=True, help='MySQL 비밀번호')
parser.add_argument('--mysql-database', required=True, help='MySQL 데이터베이스')
parser.add_argument('--price-limit', type=int, help='주가 데이터 제한 (테스트용)')
parser.add_argument('--fs-limit', type=int, help='재무제표 데이터 제한 (테스트용)')
args = parser.parse_args()
migrator = MySQLToPostgreSQLMigrator(
mysql_host=args.mysql_host,
mysql_port=args.mysql_port,
mysql_user=args.mysql_user,
mysql_password=args.mysql_password,
mysql_database=args.mysql_database
)
migrator.migrate_all(
price_limit=args.price_limit,
fs_limit=args.fs_limit
)
if __name__ == '__main__':
main()