All checks were successful
Deploy to Production / deploy (push) Successful in 1m10s
pykrx get_market_ohlcv returns 6 data columns (시가/고가/저가/종가/거래량/거래대금), not 7. The 등락률 (change) column does not exist, causing a length mismatch error. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
137 lines
4.9 KiB
Python
137 lines
4.9 KiB
Python
"""
|
|
Price data collector using pykrx.
|
|
"""
|
|
import logging
|
|
from datetime import datetime, timedelta
|
|
|
|
import pandas as pd
|
|
from pykrx import stock as pykrx_stock
|
|
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy.dialects.postgresql import insert
|
|
|
|
from app.services.collectors.base import BaseCollector
|
|
from app.models.stock import Price, Stock
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class PriceCollector(BaseCollector):
    """Collect daily OHLCV price data via pykrx and upsert it into ``Price``.

    For every ticker stored in ``Stock`` the collector downloads the rows
    between ``start_date`` and ``end_date`` and writes them with a PostgreSQL
    ``INSERT ... ON CONFLICT (ticker, date) DO UPDATE``. Each ticker is
    committed on its own so one failing ticker does not discard the others.
    """

    # Names assigned positionally to the frame returned by
    # ``get_market_ohlcv`` after ``reset_index()``: the date index followed by
    # six data columns.
    # NOTE(review): assumes pykrx orders the data columns as
    # open/high/low/close/volume/value - confirm against the installed version.
    _OHLCV_COLUMNS = ("date", "open", "high", "low", "close", "volume", "value")

    def __init__(self, db: Session, start_date: str | None = None, end_date: str | None = None):
        """Store the session and resolve the date range.

        Args:
            db: SQLAlchemy session, forwarded to ``BaseCollector``.
            start_date: First day to collect as ``YYYYMMDD``. Defaults to
                seven days before now.
            end_date: Last day to collect as ``YYYYMMDD``. Defaults to today.

        Raises:
            ValueError: If either date is not in ``YYYYMMDD`` format.
        """
        super().__init__(db)
        # Read the clock once so both defaults derive from the same instant.
        now = datetime.now()
        self.end_date = end_date or now.strftime("%Y%m%d")
        self.start_date = start_date or (now - timedelta(days=7)).strftime("%Y%m%d")
        self._validate_dates()

    def _validate_dates(self) -> None:
        """Check that ``start_date`` and ``end_date`` parse as ``YYYYMMDD``.

        Raises:
            ValueError: If either value does not match the format.
        """
        for date_str, name in [(self.start_date, "start_date"), (self.end_date, "end_date")]:
            try:
                datetime.strptime(date_str, "%Y%m%d")
            except ValueError as exc:
                raise ValueError(
                    f"Invalid {name} format. Expected YYYYMMDD, got: {date_str}"
                ) from exc

    def _safe_float(self, value) -> float | None:
        """Convert ``value`` to ``float``; return ``None`` if missing or unconvertible."""
        try:
            # The NA test sits inside the try because pd.isna() on a
            # multi-element array-like yields an array whose truth value
            # raises ValueError.
            if pd.isna(value):
                return None
            return float(value)
        except (ValueError, TypeError):
            return None

    def _safe_int(self, value) -> int | None:
        """Convert ``value`` to ``int``; return ``None`` if missing or unconvertible.

        The value goes through ``float`` first so inputs such as ``"12.0"``
        or ``12.0`` are accepted; the fractional part is truncated.
        """
        try:
            if pd.isna(value):
                return None
            return int(float(value))
        except (ValueError, TypeError, OverflowError):
            # OverflowError: int(float("inf")) cannot be represented.
            return None

    def _build_records(self, ticker: str, df: pd.DataFrame) -> list[dict]:
        """Turn a normalized OHLCV frame into row dicts for the ``Price`` upsert.

        ``df`` must already carry the ``_OHLCV_COLUMNS`` names. Rows without a
        usable close price are skipped; other missing values become ``None``.
        """
        records = []
        for row in df.itertuples(index=False):
            close_val = self._safe_float(row.close)
            # Skip if the essential value is missing.
            if close_val is None:
                logger.debug("Skipping record for %s: missing close price", ticker)
                continue

            raw_date = row.date
            # Timestamp-like values expose .date(); anything else is stored as-is.
            date_value = raw_date.date() if hasattr(raw_date, "date") else raw_date
            records.append({
                "ticker": ticker,
                "date": date_value,
                "open": self._safe_float(row.open),
                "high": self._safe_float(row.high),
                "low": self._safe_float(row.low),
                "close": close_val,
                "volume": self._safe_int(row.volume),
            })
        return records

    def _upsert_records(self, records: list[dict]) -> None:
        """Insert ``records`` into ``Price``, updating rows that already exist, and commit."""
        stmt = insert(Price).values(records)
        stmt = stmt.on_conflict_do_update(
            index_elements=["ticker", "date"],
            set_={
                "open": stmt.excluded.open,
                "high": stmt.excluded.high,
                "low": stmt.excluded.low,
                "close": stmt.excluded.close,
                "volume": stmt.excluded.volume,
            },
        )
        self.db.execute(stmt)
        self.db.commit()  # Commit per ticker

    def collect(self) -> int:
        """Collect price data for all stocks.

        Returns:
            The number of price records written (inserted or updated).
            ``0`` when the ``Stock`` table has no tickers.
        """
        # Get list of tickers from database
        ticker_list = [row[0] for row in self.db.query(Stock.ticker).all()]

        if not ticker_list:
            logger.warning("No stocks found in database. Run StockCollector first.")
            return 0

        total_records = 0
        logger.info(
            "Collecting prices for %d stocks from %s to %s",
            len(ticker_list), self.start_date, self.end_date,
        )

        expected_cols = len(self._OHLCV_COLUMNS)  # date + 6 data columns

        # Fetch and store prices one ticker at a time
        for ticker in ticker_list:
            try:
                df = pykrx_stock.get_market_ohlcv(
                    self.start_date, self.end_date, ticker
                )
                if df.empty:
                    continue

                df = df.reset_index()

                # Validate the column count BEFORE renaming: assigning a
                # name list of the wrong length raises a length-mismatch
                # error, which would make this check unreachable.
                if len(df.columns) < expected_cols:
                    logger.warning(
                        "Unexpected column count for %s: %d", ticker, len(df.columns)
                    )
                    continue

                # Keep only the leading columns so extra trailing columns
                # do not break the positional rename.
                df = df.iloc[:, :expected_cols]
                df.columns = list(self._OHLCV_COLUMNS)

                records = self._build_records(ticker, df)
                if records:
                    self._upsert_records(records)
                    total_records += len(records)

            except Exception as e:
                # Best-effort per ticker: undo the failed transaction and
                # move on to the next ticker.
                self.db.rollback()  # Rollback on failure
                logger.warning("Failed to fetch prices for %s: %s", ticker, e)
                continue

        logger.info("Collected %d price records", total_records)
        return total_records
|