"""
Generate portfolio snapshots from trade history using actual market prices.

Uses KRX Open API when KRX_OPENAPI_KEY is set, falls back to pykrx scraping.
Snapshot dates: the last day of every month from the month of the first trade
up to today (months without trades are included because holdings carry over),
plus today itself -- the latest available prices -- when today is not already
a month end.

Usage:
    cd backend && python -m scripts.generate_snapshots

Requires:
    - DATABASE_URL environment variable
    - KRX_OPENAPI_KEY environment variable (preferred)
    - KRX_ID / KRX_PW environment variables (pykrx fallback)
"""
import calendar
import logging
import math
import os
import sys
import time
from collections import defaultdict
from datetime import date, datetime, timedelta
from decimal import Decimal, ROUND_HALF_UP
from json import JSONDecodeError
from typing import Any

# Make the backend package root importable when run as a script.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from sqlalchemy.orm import Session  # noqa: E402

from app.core.database import SessionLocal  # noqa: E402
from app.services.krx_client import get_krx_client  # noqa: E402
from app.models.portfolio import (  # noqa: E402
    Portfolio,
    PortfolioSnapshot,
    SnapshotHolding,
)

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)

# ETF name -> ticker mapping (must match seed_data.py)
ETF_MAP = {
    "TIGER 200": "102110",
    "KIWOOM 국고채10년": "148070",
    "KODEX 200미국채혼합": "284430",
    "TIGER 미국S&P500": "360750",
    "ACE KRX금현물": "411060",
}
TICKER_NAMES = {v: k for k, v in ETF_MAP.items()}

# Trade history (same as seed_data.py): (trade date, ETF name, quantity, price)
TRADES = [
    (date(2025, 4, 29), "ACE KRX금현물", 1, Decimal("21620")),
    (date(2025, 4, 29), "TIGER 미국S&P500", 329, Decimal("19770")),
    (date(2025, 4, 29), "KIWOOM 국고채10년", 1, Decimal("118000")),
    (date(2025, 4, 29), "TIGER 200", 16, Decimal("33815")),
    (date(2025, 4, 30), "KODEX 200미국채혼합", 355, Decimal("13235")),
    (date(2025, 5, 13), "ACE KRX금현물", 260, Decimal("20820")),
    (date(2025, 5, 13), "KODEX 200미국채혼합", 14, Decimal("13165")),
    (date(2025, 5, 14), "ACE KRX금현물", 45, Decimal("20760")),
    (date(2025, 5, 14), "TIGER 미국S&P500", 45, Decimal("20690")),
    (date(2025, 5, 14), "KODEX 200미국채혼합", 733, Decimal("13220")),
    (date(2025, 5, 14), "KIWOOM 국고채10년", 90, Decimal("116939")),
    (date(2025, 5, 16), "KODEX 200미국채혼합", 169, Decimal("13125")),
    (date(2025, 6, 12), "ACE KRX금현물", 14, Decimal("20855")),
    (date(2025, 6, 12), "TIGER 미국S&P500", 3, Decimal("20355")),
    (date(2025, 6, 12), "KODEX 200미국채혼합", 88, Decimal("13570")),
    (date(2025, 6, 12), "KIWOOM 국고채10년", 5, Decimal("115945")),
    (date(2025, 7, 30), "KIWOOM 국고채10년", 6, Decimal("116760")),
    (date(2025, 8, 14), "ACE KRX금현물", 6, Decimal("21095")),
    (date(2025, 8, 14), "TIGER 미국S&P500", 3, Decimal("22200")),
    (date(2025, 8, 14), "KODEX 200미국채혼합", 27, Decimal("14465")),
    (date(2025, 8, 14), "KIWOOM 국고채10년", 1, Decimal("117075")),
    (date(2025, 8, 19), "ACE KRX금현물", 2, Decimal("21030")),
    (date(2025, 10, 13), "TIGER 미국S&P500", 3, Decimal("23480")),
    (date(2025, 10, 13), "KIWOOM 국고채10년", 12, Decimal("116465")),
    (date(2025, 12, 5), "KIWOOM 국고채10년", 7, Decimal("112830")),
    (date(2026, 1, 7), "TIGER 미국S&P500", 2, Decimal("25015")),
    (date(2026, 1, 8), "KIWOOM 국고채10년", 11, Decimal("109527")),
    (date(2026, 2, 20), "TIGER 미국S&P500", 20, Decimal("24685")),
    (date(2026, 2, 20), "KIWOOM 국고채10년", 9, Decimal("108500")),
    (date(2026, 3, 23), "ACE KRX금현물", 41, Decimal("30095")),
    (date(2026, 3, 23), "TIGER 미국S&P500", 128, Decimal("24290")),
    (date(2026, 3, 23), "KODEX 200미국채혼합", 188, Decimal("19579")),
    (date(2026, 3, 23), "KIWOOM 국고채10년", 10, Decimal("106780")),
]

# How many calendar days (including the target date) to walk back looking for
# a trading day. Weekends plus long market holidays can exceed five days.
PRICE_LOOKBACK_DAYS = 10

# Per-date cache of Open API ETF daily tables. A single get_etf_daily() call
# returns every ETF for that date, so fetching it once per ticker would
# multiply API usage by the number of tickers.
_openapi_daily_cache: dict[str, Any] = {}


def _get_holdings_at_date(target_date: date) -> dict[str, int]:
    """Return cumulative quantity per ticker for all trades on or before *target_date*.

    Positions whose net quantity is zero are omitted, so they trigger neither a
    price lookup nor a zero-value holding row.
    """
    holdings: dict[str, int] = defaultdict(int)
    for trade_date, name, qty, _ in TRADES:
        if trade_date <= target_date:
            holdings[ETF_MAP[name]] += qty
    return {ticker: qty for ticker, qty in holdings.items() if qty != 0}


def _month_end(year: int, month: int) -> date:
    """Return the last calendar day of the given month."""
    return date(year, month, calendar.monthrange(year, month)[1])


def _generate_snapshot_dates(*, today: date | None = None, include_latest: bool = True) -> list[date]:
    """Return snapshot dates from the first trade's month through *today*.

    Args:
        today: Upper bound for the dates; defaults to ``date.today()``.
        include_latest: When true and *today* is on/after the first trade but
            not a month end, append *today* so the latest prices are captured.

    Returns:
        Ascending, duplicate-free list of dates. Empty when there are no trades
        or the first trade lies after *today*.
    """
    if not TRADES:
        return []
    today = today or date.today()
    first_date = min(t[0] for t in TRADES)

    dates: list[date] = []
    year, month = first_date.year, first_date.month
    # The first month end is never before first_date (same month), so only
    # the upper bound needs checking.
    while (month_end := _month_end(year, month)) <= today:
        dates.append(month_end)
        year, month = (year + 1, 1) if month == 12 else (year, month + 1)

    if include_latest and today >= first_date and (not dates or dates[-1] != today):
        dates.append(today)
    return dates


def _lookback_date_strs(date_str: str, days: int = PRICE_LOOKBACK_DAYS) -> list[str]:
    """Return *date_str* and the preceding ``days - 1`` calendar days as YYYYMMDD, newest first.

    Snapshot dates may fall on non-trading days, so price lookups walk
    backwards until some date yields a closing price.
    """
    target = datetime.strptime(date_str, "%Y%m%d").date()
    return [(target - timedelta(days=offset)).strftime("%Y%m%d") for offset in range(days)]


def _to_price(raw: Any) -> Decimal | None:
    """Convert a raw closing-price cell to a positive whole-won Decimal.

    Thousands separators are tolerated. Returns None for missing, non-numeric,
    non-finite or non-positive values.
    """
    if raw is None:
        return None
    try:
        number = float(str(raw).replace(",", ""))
    except ValueError:
        return None
    if not (math.isfinite(number) and number > 0):
        return None
    return Decimal(int(number))


def _get_openapi_etf_daily(client: Any, date_str: str) -> Any:
    """Return the Open API ETF daily table for *date_str*, memoizing non-None results.

    A None result is not cached, so a later lookup for the same date can retry.
    """
    try:
        return _openapi_daily_cache[date_str]
    except KeyError:
        pass
    df = client.get_etf_daily(date_str)
    if df is not None:
        _openapi_daily_cache[date_str] = df
    return df


def _fetch_price_openapi(ticker: str, date_str: str) -> Decimal | None:
    """Fetch the closing price of *ticker* on or shortly before *date_str* via the KRX Open API.

    Returns None when no Open API client is configured or no price is found
    within the lookback window. Per-date errors are logged and skipped.
    """
    client = get_krx_client()
    if not client:
        return None

    for try_date_str in _lookback_date_strs(date_str):
        try:
            df = _get_openapi_etf_daily(client, try_date_str)
            if df is None or df.empty:
                continue
            match = df[df["ISU_SRT_CD"] == ticker]
            if match.empty:
                continue
            price = _to_price(match.iloc[0].get("TDD_CLSPRC"))
        except Exception as e:  # the client is opaque; keep this lookup best-effort
            logger.warning(f"Open API fetch for {ticker} on {try_date_str}: {e}")
            continue
        if price is not None:
            return price
    return None


def _fetch_price_pykrx(ticker: str, date_str: str, max_retries: int = 3) -> Decimal | None:
    """Fetch the closing price of *ticker* on or shortly before *date_str* via pykrx scraping.

    Tries the ETF OHLCV endpoint first, retrying each date up to *max_retries*
    times on transport/parse errors. Then it falls back to the stock OHLCV
    endpoint for non-ETF tickers. Returns None if nothing is found.
    """
    from pykrx import stock as pykrx_stock

    date_strs = _lookback_date_strs(date_str)

    for try_date_str in date_strs:
        for attempt in range(1, max_retries + 1):
            try:
                df = pykrx_stock.get_etf_ohlcv_by_date(try_date_str, try_date_str, ticker)
                close = None if df is None or df.empty else df.iloc[0]["종가"]
            except (JSONDecodeError, ConnectionError, KeyError, ValueError) as e:
                if attempt < max_retries:
                    logger.warning(f"Retry {attempt}/{max_retries} for {ticker} on {try_date_str}: {e}")
                    time.sleep(2)
                else:
                    logger.warning(f"Giving up on {ticker} on {try_date_str} after {max_retries} attempts: {e}")
                continue
            price = _to_price(close)
            if price is not None:
                return price
            # A clean response (e.g. an empty frame for a non-trading day) is
            # final for this date; re-requesting it would return the same data.
            break

    # Fallback: try stock API (for non-ETF tickers)
    for try_date_str in date_strs:
        try:
            df = pykrx_stock.get_market_ohlcv(try_date_str, try_date_str, ticker)
            close = None if df is None or df.empty else df.iloc[0]["종가"]
        except Exception as e:  # best-effort fallback: skip the date but leave a trace
            logger.debug(f"pykrx stock fetch for {ticker} on {try_date_str} failed: {e}")
            continue
        price = _to_price(close)
        if price is not None:
            return price
    return None


def _fetch_price_with_retry(ticker: str, date_str: str, max_retries: int = 3) -> Decimal | None:
    """Fetch a closing price, preferring the Open API and falling back to pykrx."""
    if get_krx_client():
        price = _fetch_price_openapi(ticker, date_str)
        if price is not None:
            return price
        logger.warning(f"Open API failed for {ticker} on {date_str}, trying pykrx")
    return _fetch_price_pykrx(ticker, date_str, max_retries)


def _delete_existing_snapshots(db: Session, portfolio_id: Any) -> None:
    """Delete (in the current session, uncommitted) all snapshots of the portfolio."""
    existing = db.query(PortfolioSnapshot).filter(
        PortfolioSnapshot.portfolio_id == portfolio_id
    ).all()
    if not existing:
        return
    # NOTE(review): assumes deleting a PortfolioSnapshot also removes its
    # SnapshotHolding rows (ORM cascade or DB ON DELETE) -- confirm in
    # app.models.portfolio.
    for snap in existing:
        db.delete(snap)
    db.flush()
    logger.info(f"Deleted {len(existing)} existing snapshots")


def _fetch_prices(holdings: dict[str, int], snap_date: date) -> dict[str, Decimal]:
    """Return closing prices for the held tickers on *snap_date*, logging any that are missing."""
    date_str = snap_date.strftime("%Y%m%d")
    prices: dict[str, Decimal] = {}
    for ticker in holdings:
        price = _fetch_price_with_retry(ticker, date_str)
        if price is not None:
            prices[ticker] = price
        else:
            logger.warning(f"  Could not fetch price for {TICKER_NAMES.get(ticker, ticker)} on {snap_date}")
    return prices


def _add_snapshot(
    db: Session,
    portfolio_id: Any,
    snap_date: date,
    holdings: dict[str, int],
    prices: dict[str, Decimal],
) -> Decimal | None:
    """Stage one PortfolioSnapshot plus its SnapshotHolding rows in *db*.

    Only tickers that have a price are included. Returns the snapshot's total
    value, or None (and stages nothing) when that total is zero.
    """
    values = {ticker: qty * prices[ticker] for ticker, qty in holdings.items() if ticker in prices}
    total_value = sum(values.values(), Decimal("0"))
    if total_value == 0:
        return None

    snapshot = PortfolioSnapshot(
        portfolio_id=portfolio_id,
        total_value=total_value,
        snapshot_date=snap_date,
    )
    db.add(snapshot)
    db.flush()  # populate snapshot.id before the holding rows reference it

    for ticker, value in values.items():
        ratio = (value / total_value * 100).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
        db.add(SnapshotHolding(
            snapshot_id=snapshot.id,
            ticker=ticker,
            quantity=holdings[ticker],
            price=prices[ticker],
            value=value,
            current_ratio=ratio,
        ))
    return total_value


def generate_snapshots(db: Session):
    """Regenerate the pension portfolio's snapshots from trade history and market prices.

    Existing snapshots of the portfolio are deleted and rebuilt. All changes
    are committed once at the end, so nothing is committed if the run fails
    part-way.
    """
    portfolio = db.query(Portfolio).filter(Portfolio.name == "연금 포트폴리오").first()
    if not portfolio:
        logger.error("Portfolio '연금 포트폴리오' not found. Run seed_data.py first.")
        return

    # Compute dates before deleting anything, so an empty schedule leaves the
    # existing snapshots untouched.
    snapshot_dates = _generate_snapshot_dates()
    if not snapshot_dates:
        logger.warning("No snapshot dates to generate (no trades on or before today)")
        return

    _delete_existing_snapshots(db, portfolio.id)

    logger.info(f"Generating {len(snapshot_dates)} snapshots from {snapshot_dates[0]} to {snapshot_dates[-1]}")

    created = 0
    for snap_date in snapshot_dates:
        holdings = _get_holdings_at_date(snap_date)
        if not holdings:
            continue

        logger.info(f"Processing {snap_date} ({len(holdings)} tickers)...")
        prices = _fetch_prices(holdings, snap_date)
        if not prices:
            logger.warning(f"  Skipping {snap_date}: no prices available")
            continue

        total_value = _add_snapshot(db, portfolio.id, snap_date, holdings, prices)
        if total_value is None:
            continue

        created += 1
        logger.info(f"  Snapshot {snap_date}: total={total_value:,.0f}")
        time.sleep(1)  # throttle between dates to go easy on the price sources

    db.commit()
    logger.info(f"Done! Created {created} snapshots.")


if __name__ == "__main__":
    db = SessionLocal()
    try:
        generate_snapshots(db)
    finally:
        db.close()