137 lines
4.9 KiB
Python
Raw Normal View History

"""
Price data collector using pykrx.
"""
import logging
from datetime import datetime, timedelta
import pandas as pd
from pykrx import stock as pykrx_stock
from sqlalchemy.orm import Session
from sqlalchemy.dialects.postgresql import insert
from app.services.collectors.base import BaseCollector
from app.models.stock import Price, Stock
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class PriceCollector(BaseCollector):
    """Collect daily OHLCV price data via pykrx and upsert it into ``Price``.

    For every ticker returned by ``Stock.ticker`` the collector requests the
    OHLCV frame for ``[start_date, end_date]``, converts each row into a plain
    dict and upserts the rows keyed on ``(ticker, date)``.
    """

    # Names given, by position, to the leading columns of the pykrx frame
    # after ``reset_index()``. Only these columns are stored; any trailing
    # columns are ignored.
    # NOTE(review): assumes pykrx orders the leading columns as
    # date, open, high, low, close, volume -- the same positional assumption
    # the previous rename made; confirm against the installed pykrx version.
    _OHLCV_COLUMNS = ("date", "open", "high", "low", "close", "volume")

    def __init__(
        self,
        db: Session,
        start_date: str | None = None,
        end_date: str | None = None,
    ):
        """Set up the collection window.

        Args:
            db: Session used to read tickers and write prices.
            start_date: First day to collect, as ``YYYYMMDD``. Defaults to
                seven days before now.
            end_date: Last day to collect, as ``YYYYMMDD``. Defaults to today.

        Raises:
            ValueError: If either date is not a valid ``YYYYMMDD`` string.
        """
        super().__init__(db)
        # Read the clock once so both defaults derive from the same instant.
        now = datetime.now()
        self.end_date = end_date or now.strftime("%Y%m%d")
        self.start_date = start_date or (now - timedelta(days=7)).strftime("%Y%m%d")
        self._validate_dates()

    def _validate_dates(self) -> None:
        """Check that ``start_date`` and ``end_date`` parse as ``YYYYMMDD``.

        Raises:
            ValueError: Naming the offending attribute and its value.
        """
        for date_str, name in (
            (self.start_date, "start_date"),
            (self.end_date, "end_date"),
        ):
            try:
                datetime.strptime(date_str, "%Y%m%d")
            except ValueError as err:
                # Chain the parser's error so the root cause stays visible.
                raise ValueError(
                    f"Invalid {name} format. Expected YYYYMMDD, got: {date_str}"
                ) from err

    def _safe_float(self, value) -> float | None:
        """Convert ``value`` to ``float``; return ``None`` if missing or invalid."""
        try:
            # Kept inside the ``try``: for a non-scalar, ``pd.isna`` returns
            # an array whose truth value raises ``ValueError``.
            if pd.isna(value):
                return None
            return float(value)
        except (ValueError, TypeError):
            return None

    def _safe_int(self, value) -> int | None:
        """Convert ``value`` to ``int``; return ``None`` if missing or invalid.

        The value goes through ``float`` first, so strings such as ``"12.0"``
        are accepted and the fractional part is truncated.
        """
        try:
            if pd.isna(value):
                return None
            return int(float(value))
        except (ValueError, TypeError, OverflowError):
            # ``int(float("inf"))`` raises OverflowError.
            return None

    def _build_records(self, ticker: str, df: pd.DataFrame) -> list[dict]:
        """Turn one ticker's pykrx frame into a list of ``Price`` row dicts.

        Rows whose close price is missing are dropped. An empty frame, or one
        with fewer columns than ``_OHLCV_COLUMNS``, yields an empty list.
        """
        if df.empty:
            return []

        df = df.reset_index()

        # Check the shape BEFORE relying on column positions, and select the
        # needed columns by position instead of renaming all of them: a frame
        # with extra trailing columns must not fail on a length mismatch.
        needed = len(self._OHLCV_COLUMNS)
        column_count = df.shape[1]
        if column_count < needed:
            logger.warning("Unexpected column count for %s: %d", ticker, column_count)
            return []

        records = []
        rows = df.iloc[:, :needed].itertuples(index=False, name=None)
        for date_raw, open_raw, high_raw, low_raw, close_raw, volume_raw in rows:
            close_val = self._safe_float(close_raw)
            # The close price is the one mandatory value.
            if close_val is None:
                logger.debug("Skipping record for %s: missing close price", ticker)
                continue

            # Timestamp-like values are reduced to a plain date; anything
            # without a ``date()`` method is stored as-is.
            date_value = date_raw.date() if hasattr(date_raw, "date") else date_raw
            records.append({
                "ticker": ticker,
                "date": date_value,
                "open": self._safe_float(open_raw),
                "high": self._safe_float(high_raw),
                "low": self._safe_float(low_raw),
                "close": close_val,
                "volume": self._safe_int(volume_raw),
            })
        return records

    def _upsert(self, records: list[dict]) -> None:
        """Insert ``records`` into ``Price``, updating rows on conflict, and commit."""
        stmt = insert(Price).values(records)
        stmt = stmt.on_conflict_do_update(
            index_elements=["ticker", "date"],
            set_={
                "open": stmt.excluded.open,
                "high": stmt.excluded.high,
                "low": stmt.excluded.low,
                "close": stmt.excluded.close,
                "volume": stmt.excluded.volume,
            },
        )
        self.db.execute(stmt)
        self.db.commit()

    def collect(self) -> int:
        """Collect price data for every ticker stored in ``Stock``.

        Each ticker is fetched, converted and committed on its own. A failure
        for one ticker is logged and rolled back without stopping the run.

        Returns:
            Total number of price records written across all tickers, or
            ``0`` when the stock table holds no tickers.
        """
        ticker_list = [row[0] for row in self.db.query(Stock.ticker).all()]
        if not ticker_list:
            logger.warning("No stocks found in database. Run StockCollector first.")
            return 0

        logger.info(
            "Collecting prices for %d stocks from %s to %s",
            len(ticker_list),
            self.start_date,
            self.end_date,
        )

        total_records = 0
        # One request and one transaction per ticker.
        for ticker in ticker_list:
            try:
                df = pykrx_stock.get_market_ohlcv(
                    self.start_date, self.end_date, ticker
                )
                records = self._build_records(ticker, df)
                if records:
                    self._upsert(records)  # commits per ticker
                    total_records += len(records)
            except Exception as e:
                # Deliberate best-effort boundary: discard this ticker's
                # pending work and move on to the next one.
                self.db.rollback()
                logger.warning("Failed to fetch prices for %s: %s", ticker, e)

        logger.info("Collected %d price records", total_records)
        return total_records