galaxis-po/backend/scripts/generate_snapshots.py
머니페니 625ffadcab perf: bulk-fetch prices in generate_snapshots to reduce API calls
OpenAPI: date당 1회 호출 (기존 ticker×date회)
pykrx: ticker당 전체 기간 조회 1회 (기존 date×ticker회)
date별 sleep(1) 제거
2026-05-13 22:15:22 +09:00

309 lines
11 KiB
Python

"""
Generate portfolio snapshots from trade history using actual market prices.
Uses KRX Open API when KRX_OPENAPI_KEY is set, falls back to pykrx scraping.
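Snapshot dates: the last calendar day of every month from the month of the first trade up to today (month-ends after today are skipped). When a month-end has no price (weekend/holiday), the closest close up to 4 calendar days earlier is used.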
Usage:
cd backend && python -m scripts.generate_snapshots
Requires:
- DATABASE_URL environment variable
- KRX_OPENAPI_KEY environment variable (preferred)
- KRX_ID / KRX_PW environment variables (pykrx fallback)
"""
import sys
import os
import time
import logging
from datetime import date, timedelta
from decimal import Decimal, ROUND_HALF_UP
from collections import defaultdict
from json import JSONDecodeError
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sqlalchemy.orm import Session
from app.core.database import SessionLocal
from app.services.krx_client import get_krx_client
from app.models.portfolio import (
Portfolio, PortfolioSnapshot, SnapshotHolding,
)
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# ETF display name -> KRX short ticker code. Keep in sync with seed_data.py.
ETF_MAP = {
    "TIGER 200": "102110",
    "KIWOOM 국고채10년": "148070",
    "KODEX 200미국채혼합": "284430",
    "TIGER 미국S&P500": "360750",
    "ACE KRX금현물": "411060",
}
# Reverse lookup (ticker code -> display name), used for readable log output.
TICKER_NAMES = dict(zip(ETF_MAP.values(), ETF_MAP.keys()))
# Trade history (same as seed_data.py).
# Each entry is (trade_date, ETF display name, quantity, unit price):
# - The name must be a key of ETF_MAP; _get_holdings_at_date maps it to a ticker.
# - _get_holdings_at_date sums quantities cumulatively. Every entry below has a
#   positive quantity, i.e. all recorded trades are buys.
# - The price column is discarded by _get_holdings_at_date and is not read
#   anywhere else in this script (presumably KRW per share - confirm against
#   seed_data.py).
# - Entries happen to be chronological, but _generate_snapshot_dates uses
#   min() over the dates, so order is not required.
TRADES = [
    (date(2025, 4, 29), "ACE KRX금현물", 1, Decimal("21620")),
    (date(2025, 4, 29), "TIGER 미국S&P500", 329, Decimal("19770")),
    (date(2025, 4, 29), "KIWOOM 국고채10년", 1, Decimal("118000")),
    (date(2025, 4, 29), "TIGER 200", 16, Decimal("33815")),
    (date(2025, 4, 30), "KODEX 200미국채혼합", 355, Decimal("13235")),
    (date(2025, 5, 13), "ACE KRX금현물", 260, Decimal("20820")),
    (date(2025, 5, 13), "KODEX 200미국채혼합", 14, Decimal("13165")),
    (date(2025, 5, 14), "ACE KRX금현물", 45, Decimal("20760")),
    (date(2025, 5, 14), "TIGER 미국S&P500", 45, Decimal("20690")),
    (date(2025, 5, 14), "KODEX 200미국채혼합", 733, Decimal("13220")),
    (date(2025, 5, 14), "KIWOOM 국고채10년", 90, Decimal("116939")),
    (date(2025, 5, 16), "KODEX 200미국채혼합", 169, Decimal("13125")),
    (date(2025, 6, 12), "ACE KRX금현물", 14, Decimal("20855")),
    (date(2025, 6, 12), "TIGER 미국S&P500", 3, Decimal("20355")),
    (date(2025, 6, 12), "KODEX 200미국채혼합", 88, Decimal("13570")),
    (date(2025, 6, 12), "KIWOOM 국고채10년", 5, Decimal("115945")),
    (date(2025, 7, 30), "KIWOOM 국고채10년", 6, Decimal("116760")),
    (date(2025, 8, 14), "ACE KRX금현물", 6, Decimal("21095")),
    (date(2025, 8, 14), "TIGER 미국S&P500", 3, Decimal("22200")),
    (date(2025, 8, 14), "KODEX 200미국채혼합", 27, Decimal("14465")),
    (date(2025, 8, 14), "KIWOOM 국고채10년", 1, Decimal("117075")),
    (date(2025, 8, 19), "ACE KRX금현물", 2, Decimal("21030")),
    (date(2025, 10, 13), "TIGER 미국S&P500", 3, Decimal("23480")),
    (date(2025, 10, 13), "KIWOOM 국고채10년", 12, Decimal("116465")),
    (date(2025, 12, 5), "KIWOOM 국고채10년", 7, Decimal("112830")),
    (date(2026, 1, 7), "TIGER 미국S&P500", 2, Decimal("25015")),
    (date(2026, 1, 8), "KIWOOM 국고채10년", 11, Decimal("109527")),
    (date(2026, 2, 20), "TIGER 미국S&P500", 20, Decimal("24685")),
    (date(2026, 2, 20), "KIWOOM 국고채10년", 9, Decimal("108500")),
    (date(2026, 3, 23), "ACE KRX금현물", 41, Decimal("30095")),
    (date(2026, 3, 23), "TIGER 미국S&P500", 128, Decimal("24290")),
    (date(2026, 3, 23), "KODEX 200미국채혼합", 188, Decimal("19579")),
    (date(2026, 3, 23), "KIWOOM 국고채10년", 10, Decimal("106780")),
]
def _get_holdings_at_date(target_date: date) -> dict[str, int]:
    """Return ticker -> cumulative quantity over all TRADES dated on or before target_date.

    Trade names are translated to tickers through ETF_MAP (KeyError for an
    unknown name). Tickers with no trade up to target_date are absent.
    """
    totals: dict[str, int] = {}
    for when, etf_name, quantity, _price in TRADES:
        if when > target_date:
            continue
        code = ETF_MAP[etf_name]
        totals[code] = totals.get(code, 0) + quantity
    return totals
def _generate_snapshot_dates() -> list[date]:
"""Generate month-end snapshot dates from first trade to today."""
if not TRADES:
return []
first_date = min(t[0] for t in TRADES)
today = date.today()
dates = []
current = date(first_date.year, first_date.month, 1)
while current <= today:
if current.month == 12:
next_month = date(current.year + 1, 1, 1)
else:
next_month = date(current.year, current.month + 1, 1)
last_day = next_month - timedelta(days=1)
if last_day >= first_date and last_day <= today:
dates.append(last_day)
current = next_month
return dates
def _bulk_fetch_openapi(dates: list[date]) -> dict[date, dict[str, Decimal]]:
    """Collect closing prices for each snapshot date from the KRX Open API.

    ``client.get_etf_daily`` is called with a YYYYMMDD string and its frame is
    searched for every ticker in ETF_MAP, so one successful call covers all
    tickers for that date. When an attempt yields no frame, no positive close
    for any ticker, or raises (the error is logged), the previous calendar day
    is tried, up to 4 days back, to step over weekends and public holidays.

    Returns:
        snapshot date -> {ticker: close}. Dates where no attempt produced a
        price are left out. Returns {} when no KRX client is configured.
    """
    client = get_krx_client()
    if not client:
        return {}
    wanted = set(ETF_MAP.values())

    def closes_on(day: date) -> dict[str, Decimal]:
        # Any exception propagates to the caller, which logs it and moves on.
        frame = client.get_etf_daily(day.strftime("%Y%m%d"))
        if frame is None or frame.empty:
            return {}
        found: dict[str, Decimal] = {}
        for code in wanted:
            rows = frame[frame["ISU_SRT_CD"] == code]
            if rows.empty:
                continue
            raw = rows.iloc[0].get("TDD_CLSPRC")
            if raw and float(raw) > 0:
                found[code] = Decimal(str(int(float(raw))))
        return found

    collected: dict[date, dict[str, Decimal]] = {}
    for snap_date in dates:
        for back in range(5):
            attempt_day = snap_date - timedelta(days=back)
            try:
                closes = closes_on(attempt_day)
            except Exception as e:
                logger.warning(f"Open API {attempt_day}: {e}")
                continue
            if closes:
                collected[snap_date] = closes
                break
    return collected
def _bulk_fetch_pykrx(dates: list[date]) -> dict[date, dict[str, Decimal]]:
    """Fetch snapshot-date closes via pykrx with one range query per ticker.

    Each ETF_MAP ticker's daily OHLCV series covering
    [min(dates) - 7 days, max(dates)] is downloaded once, with up to 3 attempts
    2 seconds apart. An exception or an empty/None result counts as a failed
    attempt. The close for each snapshot date is then looked up locally. When
    that day has no row or a non-positive close, up to 4 earlier calendar days
    are tried (weekends/holidays).

    Returns:
        snapshot date -> {ticker: close}. Tickers whose download never
        succeeded are absent; dates with no price for any ticker are omitted.
        Returns {} for an empty ``dates``.
    """
    import pandas as pd
    from pykrx import stock as pykrx_stock

    if not dates:
        return {}

    # Extra buffer so the 5-day fallback window is always covered
    start_str = (min(dates) - timedelta(days=7)).strftime("%Y%m%d")
    end_str = max(dates).strftime("%Y%m%d")
    max_attempts = 3

    series: dict[str, pd.DataFrame] = {}
    for ticker in ETF_MAP.values():
        for attempt in range(1, max_attempts + 1):
            try:
                df = pykrx_stock.get_etf_ohlcv_by_date(start_str, end_str, ticker)
            except (JSONDecodeError, OSError, KeyError, ValueError) as e:
                # OSError covers the builtin ConnectionError and also requests'
                # exceptions (requests.RequestException derives from IOError),
                # which a bare `ConnectionError` clause would let escape.
                problem = str(e)
            else:
                if df is not None and not df.empty:
                    series[ticker] = df
                    break
                problem = "empty result"
            if attempt < max_attempts:
                logger.warning(f"pykrx retry {attempt}/{max_attempts} for {ticker}: {problem}")
                time.sleep(2)
            else:
                logger.warning(f"pykrx failed for {ticker}: {problem}")

    result: dict[date, dict[str, Decimal]] = {}
    for snap_date in dates:
        prices: dict[str, Decimal] = {}
        for ticker, df in series.items():
            for day_offset in range(5):
                ts = pd.Timestamp(snap_date - timedelta(days=day_offset))
                if ts not in df.index:
                    continue
                close = df.loc[ts, "종가"]  # "종가" = closing price column
                if close and float(close) > 0:
                    prices[ticker] = Decimal(str(int(close)))
                    break
        if prices:
            result[snap_date] = prices
    return result
def _bulk_fetch_prices(dates: list[date]) -> dict[date, dict[str, Decimal]]:
    """Fetch closes for all snapshot dates, preferring the KRX Open API.

    With a configured KRX client, the Open API is queried first. Every date for
    which it returned nothing, or prices for only some ETF_MAP tickers, is then
    queried via pykrx. pykrx values only fill gaps and never overwrite an Open
    API price. Without a client, pykrx is used alone.

    Returns:
        snapshot date -> {ticker: close}; dates with no price from either
        source are absent.
    """
    client = get_krx_client()
    if not client:
        return _bulk_fetch_pykrx(dates)

    result = _bulk_fetch_openapi(dates)
    tickers = set(ETF_MAP.values())
    # Partial dates must be backfilled too: otherwise the snapshot for that
    # date would silently leave out the ticker the Open API lacked.
    incomplete = [d for d in dates if not tickers.issubset(result.get(d, {}))]
    if incomplete:
        logger.warning(f"Open API missing or incomplete for {len(incomplete)} dates, falling back to pykrx")
        for snap_date, fallback in _bulk_fetch_pykrx(incomplete).items():
            merged = result.setdefault(snap_date, {})
            for ticker, price in fallback.items():
                merged.setdefault(ticker, price)
    return result
def generate_snapshots(db: Session):
    """Rebuild the snapshot history of the '연금 포트폴리오' portfolio from TRADES.

    Steps:
      1. Look up the portfolio by name; log an error and return if it is absent.
      2. Compute the month-end snapshot dates; return early, without touching
         the DB, when there are none.
      3. Delete every existing PortfolioSnapshot of that portfolio.
      4. Fetch closing prices for all dates in one bulk pass.
      5. For each date with holdings and at least one price, insert a
         PortfolioSnapshot plus one SnapshotHolding per priced ticker.
      6. Commit once at the end; earlier steps only flush.

    Args:
        db: An open SQLAlchemy session; the caller owns and closes it.
    """
    portfolio = db.query(Portfolio).filter(Portfolio.name == "연금 포트폴리오").first()
    if not portfolio:
        logger.error("Portfolio '연금 포트폴리오' not found. Run seed_data.py first.")
        return

    # Compute dates before deleting anything: with no dates there is nothing to
    # rebuild, and snapshot_dates[0] below would raise IndexError.
    snapshot_dates = _generate_snapshot_dates()
    if not snapshot_dates:
        logger.warning("No snapshot dates to generate; existing snapshots left unchanged")
        return

    existing = db.query(PortfolioSnapshot).filter(
        PortfolioSnapshot.portfolio_id == portfolio.id
    ).all()
    if existing:
        # NOTE(review): whether each snapshot's SnapshotHolding rows are removed
        # with it depends on the cascade/FK setup in app.models.portfolio - confirm.
        for snap in existing:
            db.delete(snap)
        db.flush()
        logger.info(f"Deleted {len(existing)} existing snapshots")

    logger.info(f"Generating {len(snapshot_dates)} snapshots from {snapshot_dates[0]} to {snapshot_dates[-1]}")
    # Bulk-fetch all prices upfront — minimises total API calls
    logger.info("Fetching prices in bulk...")
    all_prices = _bulk_fetch_prices(snapshot_dates)
    logger.info(f"Prices fetched for {len(all_prices)}/{len(snapshot_dates)} dates")

    created = 0
    for snap_date in snapshot_dates:
        holdings = _get_holdings_at_date(snap_date)
        if not holdings:
            continue
        date_prices = all_prices.get(snap_date, {})
        if not date_prices:
            logger.warning(f"Skipping {snap_date}: no prices available")
            continue

        total_value = Decimal("0")
        snapshot_holdings = []
        for ticker, qty in holdings.items():
            price = date_prices.get(ticker)
            if not price:
                # Unpriced tickers are left out, so this snapshot's total_value
                # and ratios cover only the priced holdings.
                logger.warning(f" Missing price for {TICKER_NAMES.get(ticker, ticker)} on {snap_date}")
                continue
            value = qty * price
            total_value += value
            snapshot_holdings.append({
                "ticker": ticker,
                "quantity": qty,
                "price": price,
                "value": value,
            })
        if total_value == 0:
            continue

        snapshot = PortfolioSnapshot(
            portfolio_id=portfolio.id,
            total_value=total_value,
            snapshot_date=snap_date,
        )
        db.add(snapshot)
        db.flush()  # assigns snapshot.id for the holding rows below
        for h in snapshot_holdings:
            # Percentage of this snapshot's total, rounded half-up to 2 decimals.
            ratio = (h["value"] / total_value * 100).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
            db.add(SnapshotHolding(
                snapshot_id=snapshot.id,
                ticker=h["ticker"],
                quantity=h["quantity"],
                price=h["price"],
                value=h["value"],
                current_ratio=ratio,
            ))
        created += 1
        logger.info(f" Snapshot {snap_date}: total={total_value:,.0f}")

    db.commit()
    logger.info(f"Done! Created {created} snapshots.")
if __name__ == "__main__":
    # The script owns this session for the whole run and always releases it,
    # even when generation raises.
    session = SessionLocal()
    try:
        generate_snapshots(session)
    finally:
        session.close()