All checks were successful
Deploy to Production / deploy (push) Successful in 1m8s
87 lines
2.8 KiB
Python
87 lines
2.8 KiB
Python
"""
|
|
Sector data collector from WISEindex.
|
|
"""
|
|
import logging
|
|
import time
|
|
from datetime import datetime
|
|
|
|
import pandas as pd
|
|
import requests
|
|
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy.dialects.postgresql import insert
|
|
|
|
from app.services.collectors.base import BaseCollector
|
|
from app.models.stock import Sector
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class SectorCollector(BaseCollector):
    """Collects WICS sector classification data from WISEindex.

    For every code in ``SECTOR_CODES`` the collector requests the index
    components for ``biz_day`` and upserts one ``Sector`` row per ticker.
    Fetching is best-effort per sector: a sector that fails to download or
    has an unexpected payload is logged and skipped.
    """

    SECTOR_CODES = ["G25", "G35", "G50", "G40", "G10", "G20", "G55", "G30", "G15", "G45"]

    # Payload columns consumed by collect(), in the order they are unpacked.
    _REQUIRED_COLUMNS = ("IDX_CD", "CMP_CD", "CMP_KOR", "SEC_NM_KOR")
    _REQUEST_TIMEOUT = 10  # seconds
    _REQUEST_INTERVAL = 1  # seconds to wait between consecutive requests

    def __init__(self, db: Session, biz_day: str = None):
        """Create a collector for one business day.

        Args:
            db: SQLAlchemy session used for the upsert.
            biz_day: Business day as ``YYYYMMDD``. When omitted (or empty),
                the value returned by ``self._get_latest_biz_day()`` is used.

        Raises:
            ValueError: If the resulting business day is not ``YYYYMMDD``.
        """
        super().__init__(db)
        self.biz_day = biz_day or self._get_latest_biz_day()
        self._validate_biz_day()

    def _validate_biz_day(self) -> None:
        """Validate that ``self.biz_day`` parses as ``YYYYMMDD``.

        Raises:
            ValueError: If ``self.biz_day`` does not match the format.
        """
        try:
            datetime.strptime(self.biz_day, "%Y%m%d")
        except ValueError as exc:
            # Chain the parser error so the root cause stays in the traceback.
            raise ValueError(
                f"Invalid biz_day format. Expected YYYYMMDD, got: {self.biz_day}"
            ) from exc

    def _fetch_sector(self, sector_code: str) -> "pd.DataFrame | None":
        """Download one sector's components; return ``None`` if unusable.

        Network errors, HTTP error statuses, invalid JSON and payloads that
        lack the required columns are logged as warnings and reported as
        ``None`` so that one bad sector does not abort the whole run.
        """
        url = f"http://www.wiseindex.com/Index/GetIndexComponets?ceil_yn=0&dt={self.biz_day}&sec_cd={sector_code}"
        try:
            response = requests.get(url, timeout=self._REQUEST_TIMEOUT)
            # Surface 4xx/5xx explicitly instead of trying to parse an error page.
            response.raise_for_status()
            data = response.json()
            # Guard against non-object payloads (e.g. JSON ``null`` or a list).
            if not isinstance(data, dict) or not data.get("list"):
                return None
            frame = pd.json_normalize(data["list"])
            # Selecting here keeps a schema mismatch (KeyError) inside the
            # per-sector error handling rather than failing after the loop.
            return frame[list(self._REQUIRED_COLUMNS)]
        except (requests.RequestException, ValueError, KeyError) as e:
            logger.warning("Failed to fetch sector %s: %s", sector_code, e)
            return None

    def _upsert(self, records: list) -> None:
        """Insert ``records`` into ``Sector``, updating rows on ticker conflict."""
        stmt = insert(Sector).values(records)
        stmt = stmt.on_conflict_do_update(
            index_elements=["ticker"],
            set_={
                "sector_code": stmt.excluded.sector_code,
                "company_name": stmt.excluded.company_name,
                "sector_name": stmt.excluded.sector_name,
                "base_date": stmt.excluded.base_date,
            },
        )
        try:
            self.db.execute(stmt)
            self.db.commit()
        except Exception:
            # Leave the session usable for the caller, then propagate.
            self.db.rollback()
            raise

    def collect(self) -> int:
        """Collect sector classification data and upsert it.

        Returns:
            The number of records written (0 when nothing could be fetched).
        """
        frames = []
        for position, sector_code in enumerate(self.SECTOR_CODES):
            if position:
                # Rate limiting: wait between requests, including after a
                # failed one, so an unhealthy server is not hit in a burst.
                time.sleep(self._REQUEST_INTERVAL)
            frame = self._fetch_sector(sector_code)
            if frame is not None:
                frames.append(frame)

        if not frames:
            return 0

        sectors = pd.concat(frames, axis=0, ignore_index=True)
        # PostgreSQL rejects an ON CONFLICT DO UPDATE statement that would
        # touch the same row twice, so keep a single row per ticker.
        sectors = sectors.drop_duplicates(subset="CMP_CD", keep="last")

        base_date = datetime.strptime(self.biz_day, "%Y%m%d").date()
        records = [
            {
                "ticker": ticker,
                "sector_code": idx_code,
                "company_name": company_name,
                "sector_name": sector_name,
                "base_date": base_date,
            }
            for idx_code, ticker, company_name, sector_name in sectors.itertuples(
                index=False, name=None
            )
        ]

        if records:
            self._upsert(records)

        return len(records)
|