fix: add transaction safety and type validation to PriceCollector
- Implement per-ticker commits to ensure atomic operations per data source - Add rollback on exception to prevent partial data corruption - Add _safe_float() and _safe_int() helper methods for defensive type conversion - Validate column count after DataFrame reset to catch schema issues early - Skip records with missing essential values (close price) with debug logging - Remove final db.commit() since commits now happen per ticker in the loop Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
135d55b488
commit
29f727970d
@ -36,6 +36,24 @@ class PriceCollector(BaseCollector):
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid {name} format. Expected YYYYMMDD, got: {date_str}")
|
||||
|
||||
def _safe_float(self, value) -> float | None:
|
||||
"""Safely convert value to float."""
|
||||
if pd.isna(value):
|
||||
return None
|
||||
try:
|
||||
return float(value)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
def _safe_int(self, value) -> int | None:
|
||||
"""Safely convert value to int."""
|
||||
if pd.isna(value):
|
||||
return None
|
||||
try:
|
||||
return int(float(value))
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
def collect(self) -> int:
|
||||
"""Collect price data for all stocks."""
|
||||
# Get list of tickers from database
|
||||
@ -62,16 +80,35 @@ class PriceCollector(BaseCollector):
|
||||
df.columns = ["date", "open", "high", "low", "close", "volume",
|
||||
"value", "change"]
|
||||
|
||||
# Validate column count
|
||||
expected_cols = 8 # date + 7 data columns
|
||||
if len(df.columns) < expected_cols:
|
||||
logger.warning(f"Unexpected column count for {ticker}: {len(df.columns)}")
|
||||
continue
|
||||
|
||||
records = []
|
||||
for _, row in df.iterrows():
|
||||
# Safely convert values with type checking
|
||||
open_val = self._safe_float(row["open"])
|
||||
high_val = self._safe_float(row["high"])
|
||||
low_val = self._safe_float(row["low"])
|
||||
close_val = self._safe_float(row["close"])
|
||||
volume_val = self._safe_int(row["volume"])
|
||||
|
||||
# Skip if essential values are missing
|
||||
if close_val is None:
|
||||
logger.debug(f"Skipping record for {ticker}: missing close price")
|
||||
continue
|
||||
|
||||
date_value = row["date"].date() if hasattr(row["date"], "date") else row["date"]
|
||||
records.append({
|
||||
"ticker": ticker,
|
||||
"date": row["date"].date() if hasattr(row["date"], "date") else row["date"],
|
||||
"open": float(row["open"]),
|
||||
"high": float(row["high"]),
|
||||
"low": float(row["low"]),
|
||||
"close": float(row["close"]),
|
||||
"volume": int(row["volume"]),
|
||||
"date": date_value,
|
||||
"open": open_val,
|
||||
"high": high_val,
|
||||
"low": low_val,
|
||||
"close": close_val,
|
||||
"volume": volume_val,
|
||||
})
|
||||
|
||||
if records:
|
||||
@ -87,12 +124,13 @@ class PriceCollector(BaseCollector):
|
||||
},
|
||||
)
|
||||
self.db.execute(stmt)
|
||||
self.db.commit() # Commit per ticker
|
||||
total_records += len(records)
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback() # Rollback on failure
|
||||
logger.warning(f"Failed to fetch prices for {ticker}: {e}")
|
||||
continue
|
||||
|
||||
self.db.commit()
|
||||
logger.info(f"Collected {total_records} price records")
|
||||
return total_records
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user