feat: add price data collector using pykrx
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
aed636f2b3
commit
135d55b488
@ -1,5 +1,6 @@
|
||||
from app.services.collectors.base import BaseCollector
|
||||
from app.services.collectors.stock_collector import StockCollector
|
||||
from app.services.collectors.sector_collector import SectorCollector
|
||||
from app.services.collectors.price_collector import PriceCollector
|
||||
|
||||
__all__ = ["BaseCollector", "StockCollector", "SectorCollector"]
|
||||
__all__ = ["BaseCollector", "StockCollector", "SectorCollector", "PriceCollector"]
|
||||
|
||||
98
backend/app/services/collectors/price_collector.py
Normal file
98
backend/app/services/collectors/price_collector.py
Normal file
@ -0,0 +1,98 @@
|
||||
"""
|
||||
Price data collector using pykrx.
|
||||
"""
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pandas as pd
|
||||
from pykrx import stock as pykrx_stock
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
|
||||
from app.services.collectors.base import BaseCollector
|
||||
from app.models.stock import Price, Stock
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PriceCollector(BaseCollector):
|
||||
"""Collects daily OHLCV price data."""
|
||||
|
||||
def __init__(self, db: Session, start_date: str = None, end_date: str = None):
|
||||
super().__init__(db)
|
||||
self.end_date = end_date or datetime.now().strftime("%Y%m%d")
|
||||
self.start_date = start_date or (
|
||||
datetime.now() - timedelta(days=7)
|
||||
).strftime("%Y%m%d")
|
||||
self._validate_dates()
|
||||
|
||||
def _validate_dates(self) -> None:
|
||||
"""Validate date formats."""
|
||||
for date_str, name in [(self.start_date, "start_date"), (self.end_date, "end_date")]:
|
||||
try:
|
||||
datetime.strptime(date_str, "%Y%m%d")
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid {name} format. Expected YYYYMMDD, got: {date_str}")
|
||||
|
||||
def collect(self) -> int:
|
||||
"""Collect price data for all stocks."""
|
||||
# Get list of tickers from database
|
||||
tickers = self.db.query(Stock.ticker).all()
|
||||
ticker_list = [t[0] for t in tickers]
|
||||
|
||||
if not ticker_list:
|
||||
logger.warning("No stocks found in database. Run StockCollector first.")
|
||||
return 0
|
||||
|
||||
total_records = 0
|
||||
logger.info(f"Collecting prices for {len(ticker_list)} stocks from {self.start_date} to {self.end_date}")
|
||||
|
||||
# Fetch prices in batches
|
||||
for ticker in ticker_list:
|
||||
try:
|
||||
df = pykrx_stock.get_market_ohlcv(
|
||||
self.start_date, self.end_date, ticker
|
||||
)
|
||||
if df.empty:
|
||||
continue
|
||||
|
||||
df = df.reset_index()
|
||||
df.columns = ["date", "open", "high", "low", "close", "volume",
|
||||
"value", "change"]
|
||||
|
||||
records = []
|
||||
for _, row in df.iterrows():
|
||||
records.append({
|
||||
"ticker": ticker,
|
||||
"date": row["date"].date() if hasattr(row["date"], "date") else row["date"],
|
||||
"open": float(row["open"]),
|
||||
"high": float(row["high"]),
|
||||
"low": float(row["low"]),
|
||||
"close": float(row["close"]),
|
||||
"volume": int(row["volume"]),
|
||||
})
|
||||
|
||||
if records:
|
||||
stmt = insert(Price).values(records)
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=["ticker", "date"],
|
||||
set_={
|
||||
"open": stmt.excluded.open,
|
||||
"high": stmt.excluded.high,
|
||||
"low": stmt.excluded.low,
|
||||
"close": stmt.excluded.close,
|
||||
"volume": stmt.excluded.volume,
|
||||
},
|
||||
)
|
||||
self.db.execute(stmt)
|
||||
total_records += len(records)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch prices for {ticker}: {e}")
|
||||
continue
|
||||
|
||||
self.db.commit()
|
||||
logger.info(f"Collected {total_records} price records")
|
||||
return total_records
|
||||
Loading…
x
Reference in New Issue
Block a user