galaxis-po/backend/scripts/generate_snapshots.py
머니페니 9ab232ba12 feat: KRX Open API migration with pykrx fallback
- Add pykrx-openapi dependency
- New krx_client.py wrapper module
- ETFCollector: Open API bulk fetch + pykrx fallback
- ETFPriceCollector: Open API date-based bulk + pykrx fallback
- StockCollector: Open API base_info + daily_trade + pykrx fallback
- PriceCollector: Open API date-based bulk + pykrx fallback
- ValuationCollector: pykrx retained (Open API has no PER/PBR)
- generate_snapshots.py: Open API + pykrx fallback
- Auto-switch based on KRX_OPENAPI_KEY env var
- All 278 tests passing
2026-04-17 23:07:09 +09:00

293 lines
10 KiB
Python

"""
Generate portfolio snapshots from trade history using actual market prices.
Uses KRX Open API when KRX_OPENAPI_KEY is set, falls back to pykrx scraping.
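Snapshot dates: the last calendar day of every month from the first trade's month through today (months without trades are included; the current month is excluded unless today is its last day). For month-ends with no price (weekends, holidays), the nearest earlier day with a price, up to 4 days back, is used.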
Usage:
cd backend && python -m scripts.generate_snapshots
Requires:
- DATABASE_URL environment variable
- KRX_OPENAPI_KEY environment variable (preferred)
- KRX_ID / KRX_PW environment variables (pykrx fallback)
"""
import sys
import os
import time
import logging
from datetime import date, datetime, timedelta
from decimal import Decimal, ROUND_HALF_UP
from collections import defaultdict
from json import JSONDecodeError
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sqlalchemy.orm import Session
from app.core.database import SessionLocal
from app.services.krx_client import get_krx_client
from app.models.portfolio import (
Portfolio, PortfolioSnapshot, SnapshotHolding,
)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)
logger = logging.getLogger(__name__)

# ETF display name -> KRX 6-digit ticker for every fund in the portfolio.
# Must stay in sync with seed_data.py.
ETF_MAP = {
    "TIGER 200": "102110",
    "KIWOOM 국고채10년": "148070",
    "KODEX 200미국채혼합": "284430",
    "TIGER 미국S&P500": "360750",
    "ACE KRX금현물": "411060",
}

# Reverse lookup (ticker -> display name), used to make log lines readable.
TICKER_NAMES = {ticker: name for name, ticker in ETF_MAP.items()}
# Trade history (same as seed_data.py).
# Each entry is (trade_date, ETF display name, quantity, 4th Decimal field).
# The display name must be a key of ETF_MAP (looked up in
# _get_holdings_at_date). Quantities are summed as-is to compute holdings and
# the 4th field is ignored by this script; presumably it is the per-unit
# execution price in KRW (TODO confirm against seed_data.py). Every quantity
# below is positive, i.e. the history contains additions only.
TRADES = [
    (date(2025, 4, 29), "ACE KRX금현물", 1, Decimal("21620")),
    (date(2025, 4, 29), "TIGER 미국S&P500", 329, Decimal("19770")),
    (date(2025, 4, 29), "KIWOOM 국고채10년", 1, Decimal("118000")),
    (date(2025, 4, 29), "TIGER 200", 16, Decimal("33815")),
    (date(2025, 4, 30), "KODEX 200미국채혼합", 355, Decimal("13235")),
    (date(2025, 5, 13), "ACE KRX금현물", 260, Decimal("20820")),
    (date(2025, 5, 13), "KODEX 200미국채혼합", 14, Decimal("13165")),
    (date(2025, 5, 14), "ACE KRX금현물", 45, Decimal("20760")),
    (date(2025, 5, 14), "TIGER 미국S&P500", 45, Decimal("20690")),
    (date(2025, 5, 14), "KODEX 200미국채혼합", 733, Decimal("13220")),
    (date(2025, 5, 14), "KIWOOM 국고채10년", 90, Decimal("116939")),
    (date(2025, 5, 16), "KODEX 200미국채혼합", 169, Decimal("13125")),
    (date(2025, 6, 12), "ACE KRX금현물", 14, Decimal("20855")),
    (date(2025, 6, 12), "TIGER 미국S&P500", 3, Decimal("20355")),
    (date(2025, 6, 12), "KODEX 200미국채혼합", 88, Decimal("13570")),
    (date(2025, 6, 12), "KIWOOM 국고채10년", 5, Decimal("115945")),
    (date(2025, 7, 30), "KIWOOM 국고채10년", 6, Decimal("116760")),
    (date(2025, 8, 14), "ACE KRX금현물", 6, Decimal("21095")),
    (date(2025, 8, 14), "TIGER 미국S&P500", 3, Decimal("22200")),
    (date(2025, 8, 14), "KODEX 200미국채혼합", 27, Decimal("14465")),
    (date(2025, 8, 14), "KIWOOM 국고채10년", 1, Decimal("117075")),
    (date(2025, 8, 19), "ACE KRX금현물", 2, Decimal("21030")),
    (date(2025, 10, 13), "TIGER 미국S&P500", 3, Decimal("23480")),
    (date(2025, 10, 13), "KIWOOM 국고채10년", 12, Decimal("116465")),
    (date(2025, 12, 5), "KIWOOM 국고채10년", 7, Decimal("112830")),
    (date(2026, 1, 7), "TIGER 미국S&P500", 2, Decimal("25015")),
    (date(2026, 1, 8), "KIWOOM 국고채10년", 11, Decimal("109527")),
    (date(2026, 2, 20), "TIGER 미국S&P500", 20, Decimal("24685")),
    (date(2026, 2, 20), "KIWOOM 국고채10년", 9, Decimal("108500")),
    (date(2026, 3, 23), "ACE KRX금현물", 41, Decimal("30095")),
    (date(2026, 3, 23), "TIGER 미국S&P500", 128, Decimal("24290")),
    (date(2026, 3, 23), "KODEX 200미국채혼합", 188, Decimal("19579")),
    (date(2026, 3, 23), "KIWOOM 국고채10년", 10, Decimal("106780")),
]
def _get_holdings_at_date(target_date: date) -> dict[str, int]:
    """Return the cumulative quantity held per ticker as of ``target_date``.

    Every trade dated on or before ``target_date`` (inclusive) adds its
    quantity to the ticker that ``ETF_MAP`` assigns to its ETF name. Tickers
    with no qualifying trade are absent from the result; keys appear in the
    order of each ticker's first qualifying trade.
    """
    totals: defaultdict[str, int] = defaultdict(int)
    qualifying = (
        (ETF_MAP[etf_name], quantity)
        for traded_on, etf_name, quantity, _unused in TRADES
        if traded_on <= target_date
    )
    for ticker, quantity in qualifying:
        totals[ticker] += quantity
    return dict(totals)
def _generate_snapshot_dates() -> list[date]:
    """Return the last calendar day of each month from the first trade to today.

    Walks forward one calendar month at a time starting from the month of the
    earliest trade in ``TRADES``. A month-end is kept only when it lies on or
    after the first trade and on or before today, so the current month is
    excluded unless today is its last day. Returns ``[]`` when there are no
    trades.
    """
    if not TRADES:
        return []
    start = min(trade[0] for trade in TRADES)
    today = date.today()

    month_ends: list[date] = []
    year, month = start.year, start.month
    while date(year, month, 1) <= today:
        # Advance to the first day of the following month; the day before it
        # is the last day of the month being examined.
        year, month = (year + 1, 1) if month == 12 else (year, month + 1)
        month_end = date(year, month, 1) - timedelta(days=1)
        if start <= month_end <= today:
            month_ends.append(month_end)
    return month_ends
# Memo of KRX Open API daily ETF tables keyed by "YYYYMMDD". One table holds
# every ETF for that day, so a single download can serve all tickers priced on
# the same date instead of being re-downloaded once per ticker.
_ETF_DAILY_CACHE: dict[str, object] = {}


def _fetch_price_openapi(ticker: str, date_str: str) -> Decimal | None:
    """Fetch a ticker's closing price via the KRX Open API.

    Tries ``date_str`` first and, when it yields no usable price (for example
    a weekend or holiday), steps back one calendar day at a time for up to four
    earlier days. Each day's ETF table is requested at most once per process
    (unless the client returned ``None``) and reused for other tickers.

    Args:
        ticker: KRX short code, matched against the ``ISU_SRT_CD`` column.
        date_str: Target date formatted as ``YYYYMMDD``.

    Returns:
        The closing price truncated to an integer ``Decimal``, or ``None`` when
        no Open API client is configured or no positive price was found.
    """
    client = get_krx_client()
    if not client:
        return None
    target = datetime.strptime(date_str, "%Y%m%d").date()
    for day_offset in range(5):
        try_date_str = (target - timedelta(days=day_offset)).strftime("%Y%m%d")
        try:
            if try_date_str in _ETF_DAILY_CACHE:
                df = _ETF_DAILY_CACHE[try_date_str]
            else:
                df = client.get_etf_daily(try_date_str)
                # Don't memoize None so another ticker can retry the request.
                if df is not None:
                    _ETF_DAILY_CACHE[try_date_str] = df
            if df is None or df.empty:
                continue
            match = df[df["ISU_SRT_CD"] == ticker]
            if match.empty:
                continue
            close = match.iloc[0].get("TDD_CLSPRC")
            if not close:
                continue
            # Tolerate thousands separators in case the value arrives as a
            # formatted string.
            value = float(str(close).replace(",", ""))
            if value > 0:
                return Decimal(str(int(value)))
        except Exception as e:
            # Best-effort: log and fall through to the previous day.
            logger.warning(f"Open API fetch for {ticker} on {try_date_str}: {e}")
    return None
def _fetch_price_pykrx(ticker: str, date_str: str, max_retries: int = 3) -> Decimal | None:
    """Fetch a closing price by scraping KRX through pykrx.

    Queries the ETF OHLCV endpoint for ``date_str`` and up to four earlier
    calendar days (to step over weekends and holidays). Each day is retried up
    to ``max_retries`` times, but only when the request itself fails. A request
    that succeeds with no data moves straight on to the previous day. If no ETF
    price is found, the same window is tried against the general stock OHLCV
    endpoint, for tickers that are not ETFs.

    Args:
        ticker: KRX 6-digit short code.
        date_str: Target date formatted as ``YYYYMMDD``.
        max_retries: Attempts per day on the ETF endpoint.

    Returns:
        The closing price as an integer ``Decimal``, or ``None`` if none found.
    """
    from pykrx import stock as pykrx_stock

    target = datetime.strptime(date_str, "%Y%m%d").date()
    window = [
        (target - timedelta(days=day_offset)).strftime("%Y%m%d")
        for day_offset in range(5)
    ]

    for try_date_str in window:
        for attempt in range(1, max_retries + 1):
            try:
                df = pykrx_stock.get_etf_ohlcv_by_date(try_date_str, try_date_str, ticker)
                close = df.iloc[0]["종가"] if df is not None and not df.empty else None
            # OSError also covers requests' network errors (RequestException
            # subclasses IOError), which are not the builtin ConnectionError.
            except (JSONDecodeError, OSError, KeyError, ValueError) as e:
                if attempt < max_retries:
                    logger.warning(f"Retry {attempt}/{max_retries} for {ticker} on {try_date_str}: {e}")
                    time.sleep(2)
                else:
                    logger.warning(
                        f"Giving up on {ticker} on {try_date_str} after {max_retries} attempts: {e}"
                    )
                continue
            if close and float(close) > 0:
                return Decimal(str(int(close)))
            # The request succeeded but had no usable price (non-trading day):
            # re-requesting the same day is pointless, move to the previous one.
            break

    # Fallback: the general stock endpoint, for tickers that are not ETFs.
    for try_date_str in window:
        try:
            df = pykrx_stock.get_market_ohlcv(try_date_str, try_date_str, ticker)
            if df is not None and not df.empty:
                close = df.iloc[0]["종가"]
                if close and float(close) > 0:
                    return Decimal(str(int(close)))
        except Exception as e:
            # Best-effort fallback: keep scanning, but leave a trace.
            logger.debug(f"pykrx stock fallback for {ticker} on {try_date_str}: {e}")
    return None
def _fetch_price_with_retry(ticker: str, date_str: str, max_retries: int = 3) -> Decimal | None:
    """Return a closing price, trying the KRX Open API before pykrx scraping.

    When an Open API client is configured it is tried first. If it yields no
    price, a warning is logged and pykrx is used instead. Without a client,
    pykrx is used directly. ``max_retries`` is forwarded to the pykrx path only.
    """
    openapi_available = bool(get_krx_client())
    price = _fetch_price_openapi(ticker, date_str) if openapi_available else None
    if price:
        return price
    if openapi_available:
        logger.warning(f"Open API failed for {ticker} on {date_str}, trying pykrx")
    return _fetch_price_pykrx(ticker, date_str, max_retries)
def _fetch_prices(holdings: dict[str, int], snap_date: date) -> dict[str, Decimal]:
    """Return closing prices for the held tickers as of ``snap_date``.

    Tickers whose price cannot be fetched are logged and left out of the result.
    """
    date_str = snap_date.strftime("%Y%m%d")
    prices: dict[str, Decimal] = {}
    for ticker in holdings:
        price = _fetch_price_with_retry(ticker, date_str)
        if price:
            prices[ticker] = price
        else:
            logger.warning(f" Could not fetch price for {TICKER_NAMES.get(ticker, ticker)} on {snap_date}")
    return prices


def generate_snapshots(db: Session) -> None:
    """Rebuild the "연금 포트폴리오" (pension) portfolio's snapshot history.

    For every month-end returned by ``_generate_snapshot_dates`` this creates a
    ``PortfolioSnapshot`` plus one ``SnapshotHolding`` per priced ticker. Each
    holding records quantity, closing price, value and percentage share of the
    snapshot total. The portfolio's existing snapshots are deleted first, and
    all changes are committed together at the end.

    Returns early without modifying the database when the portfolio does not
    exist or when no month-end snapshot date has been reached yet.

    Args:
        db: An open SQLAlchemy session; committed on success, not closed here.
    """
    portfolio = db.query(Portfolio).filter(Portfolio.name == "연금 포트폴리오").first()
    if not portfolio:
        logger.error("Portfolio '연금 포트폴리오' not found. Run seed_data.py first.")
        return

    # Compute dates before deleting anything: with no dates the original code
    # crashed on snapshot_dates[0]; now existing snapshots are left untouched.
    snapshot_dates = _generate_snapshot_dates()
    if not snapshot_dates:
        logger.warning("No month-end snapshot dates to generate; leaving existing snapshots unchanged")
        return

    existing = db.query(PortfolioSnapshot).filter(
        PortfolioSnapshot.portfolio_id == portfolio.id
    ).all()
    if existing:
        # NOTE(review): whether the deleted snapshots' SnapshotHolding rows go
        # with them depends on the ORM cascade config, not visible here.
        for snap in existing:
            db.delete(snap)
        db.flush()
        logger.info(f"Deleted {len(existing)} existing snapshots")

    logger.info(f"Generating {len(snapshot_dates)} snapshots from {snapshot_dates[0]} to {snapshot_dates[-1]}")
    created = 0
    for snap_date in snapshot_dates:
        holdings = _get_holdings_at_date(snap_date)
        if not holdings:
            continue
        logger.info(f"Processing {snap_date} ({len(holdings)} tickers)...")

        prices = _fetch_prices(holdings, snap_date)
        if not prices:
            logger.warning(f" Skipping {snap_date}: no prices available")
            continue

        # Holdings without a price are omitted, so total_value and the ratios
        # then cover only the priced tickers (a warning was logged for each).
        rows = [
            (ticker, qty, prices[ticker], qty * prices[ticker])
            for ticker, qty in holdings.items()
            if ticker in prices
        ]
        total_value = sum((value for _, _, _, value in rows), Decimal("0"))
        if total_value == 0:
            continue

        snapshot = PortfolioSnapshot(
            portfolio_id=portfolio.id,
            total_value=total_value,
            snapshot_date=snap_date,
        )
        db.add(snapshot)
        db.flush()  # populate snapshot.id for the child rows below
        for ticker, qty, price, value in rows:
            ratio = (value / total_value * 100).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
            db.add(SnapshotHolding(
                snapshot_id=snapshot.id,
                ticker=ticker,
                quantity=qty,
                price=price,
                value=value,
                current_ratio=ratio,
            ))
        created += 1
        logger.info(f" Snapshot {snap_date}: total={total_value:,.0f}")
        time.sleep(1)  # throttle requests between snapshot dates

    db.commit()
    logger.info(f"Done! Created {created} snapshots.")
if __name__ == "__main__":
    # Script entry point: one session for the whole run, always closed.
    session = SessionLocal()
    try:
        generate_snapshots(session)
    finally:
        session.close()