From 29f727970df2b96de07e248b0b269029ccc7023f Mon Sep 17 00:00:00 2001 From: zephyrdark Date: Mon, 2 Feb 2026 23:46:16 +0900 Subject: [PATCH] fix: add transaction safety and type validation to PriceCollector - Implement per-ticker commits to ensure atomic operations per data source - Add rollback on exception to prevent partial data corruption - Add _safe_float() and _safe_int() helper methods for defensive type conversion - Validate column count after DataFrame reset to catch schema issues early - Skip records with missing essential values (close price) with debug logging - Remove final db.commit() since commits now happen per ticker in the loop Co-Authored-By: Claude Opus 4.5 --- .../services/collectors/price_collector.py | 52 ++++++++++++++++--- 1 file changed, 45 insertions(+), 7 deletions(-) diff --git a/backend/app/services/collectors/price_collector.py b/backend/app/services/collectors/price_collector.py index 942878d..517cb53 100644 --- a/backend/app/services/collectors/price_collector.py +++ b/backend/app/services/collectors/price_collector.py @@ -36,6 +36,24 @@ class PriceCollector(BaseCollector): except ValueError: raise ValueError(f"Invalid {name} format. Expected YYYYMMDD, got: {date_str}") + def _safe_float(self, value) -> float | None: + """Safely convert value to float.""" + if pd.isna(value): + return None + try: + return float(value) + except (ValueError, TypeError): + return None + + def _safe_int(self, value) -> int | None: + """Safely convert value to int.""" + if pd.isna(value): + return None + try: + return int(float(value)) + except (ValueError, TypeError): + return None + def collect(self) -> int: """Collect price data for all stocks.""" # Get list of tickers from database @@ -62,16 +80,35 @@ class PriceCollector(BaseCollector): df.columns = ["date", "open", "high", "low", "close", "volume", "value", "change"] + # Validate column count + expected_cols = 8 # date + 7 data columns + if len(df.columns) < expected_cols: + logger.warning(f"Unexpected column count for {ticker}: {len(df.columns)}") + continue + records = [] for _, row in df.iterrows(): + # Safely convert values with type checking + open_val = self._safe_float(row["open"]) + high_val = self._safe_float(row["high"]) + low_val = self._safe_float(row["low"]) + close_val = self._safe_float(row["close"]) + volume_val = self._safe_int(row["volume"]) + + # Skip if essential values are missing + if close_val is None: + logger.debug(f"Skipping record for {ticker}: missing close price") + continue + + date_value = row["date"].date() if hasattr(row["date"], "date") else row["date"] records.append({ "ticker": ticker, - "date": row["date"].date() if hasattr(row["date"], "date") else row["date"], - "open": float(row["open"]), - "high": float(row["high"]), - "low": float(row["low"]), - "close": float(row["close"]), - "volume": int(row["volume"]), + "date": date_value, + "open": open_val, + "high": high_val, + "low": low_val, + "close": close_val, + "volume": volume_val, }) if records: @@ -87,12 +124,13 @@ class PriceCollector(BaseCollector): }, ) self.db.execute(stmt) + self.db.commit() # Commit per ticker total_records += len(records) except Exception as e: + self.db.rollback() # Rollback on failure logger.warning(f"Failed to fetch prices for {ticker}: {e}") continue - self.db.commit() logger.info(f"Collected {total_records} price records") return total_records