diff --git a/backend/app/services/collectors/__init__.py b/backend/app/services/collectors/__init__.py index c332ee9..82606ce 100644 --- a/backend/app/services/collectors/__init__.py +++ b/backend/app/services/collectors/__init__.py @@ -1,5 +1,6 @@ from app.services.collectors.base import BaseCollector from app.services.collectors.stock_collector import StockCollector from app.services.collectors.sector_collector import SectorCollector +from app.services.collectors.price_collector import PriceCollector -__all__ = ["BaseCollector", "StockCollector", "SectorCollector"] +__all__ = ["BaseCollector", "StockCollector", "SectorCollector", "PriceCollector"] diff --git a/backend/app/services/collectors/price_collector.py b/backend/app/services/collectors/price_collector.py new file mode 100644 index 0000000..942878d --- /dev/null +++ b/backend/app/services/collectors/price_collector.py @@ -0,0 +1,98 @@ +""" +Price data collector using pykrx. +""" +import logging +from datetime import datetime, timedelta + +import pandas as pd +from pykrx import stock as pykrx_stock + +from sqlalchemy.orm import Session +from sqlalchemy.dialects.postgresql import insert + +from app.services.collectors.base import BaseCollector +from app.models.stock import Price, Stock + + +logger = logging.getLogger(__name__) + + +class PriceCollector(BaseCollector): + """Collects daily OHLCV price data.""" + + def __init__(self, db: Session, start_date: str = None, end_date: str = None): + super().__init__(db) + self.end_date = end_date or datetime.now().strftime("%Y%m%d") + self.start_date = start_date or ( + datetime.now() - timedelta(days=7) + ).strftime("%Y%m%d") + self._validate_dates() + + def _validate_dates(self) -> None: + """Validate date formats.""" + for date_str, name in [(self.start_date, "start_date"), (self.end_date, "end_date")]: + try: + datetime.strptime(date_str, "%Y%m%d") + except ValueError: + raise ValueError(f"Invalid {name} format. Expected YYYYMMDD, got: {date_str}") + + def collect(self) -> int: + """Collect price data for all stocks.""" + # Get list of tickers from database + tickers = self.db.query(Stock.ticker).all() + ticker_list = [t[0] for t in tickers] + + if not ticker_list: + logger.warning("No stocks found in database. Run StockCollector first.") + return 0 + + total_records = 0 + logger.info(f"Collecting prices for {len(ticker_list)} stocks from {self.start_date} to {self.end_date}") + + # Fetch prices in batches + for ticker in ticker_list: + try: + df = pykrx_stock.get_market_ohlcv( + self.start_date, self.end_date, ticker + ) + if df.empty: + continue + + df = df.reset_index() + df.columns = ["date", "open", "high", "low", "close", "volume", + "value", "change"] + + records = [] + for _, row in df.iterrows(): + records.append({ + "ticker": ticker, + "date": row["date"].date() if hasattr(row["date"], "date") else row["date"], + "open": float(row["open"]), + "high": float(row["high"]), + "low": float(row["low"]), + "close": float(row["close"]), + "volume": int(row["volume"]), + }) + + if records: + stmt = insert(Price).values(records) + stmt = stmt.on_conflict_do_update( + index_elements=["ticker", "date"], + set_={ + "open": stmt.excluded.open, + "high": stmt.excluded.high, + "low": stmt.excluded.low, + "close": stmt.excluded.close, + "volume": stmt.excluded.volume, + }, + ) + self.db.execute(stmt) + total_records += len(records) + + except Exception as e: + logger.warning(f"Failed to fetch prices for {ticker}: {e}") + continue + + self.db.commit() + logger.info(f"Collected {total_records} price records") + return total_records