#!/usr/bin/env python3
"""
Prediction Snapshot Storage

This module handles storing and retrieving prediction snapshots for future
training. It uses efficient storage formats and provides batch access for
training. A minimal usage sketch is included under ``__main__`` at the bottom
of the file.
"""

import logging
import sqlite3
import json
import pickle
import gzip
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional
from pathlib import Path
from dataclasses import asdict

import numpy as np

from .multi_horizon_prediction_manager import PredictionSnapshot

logger = logging.getLogger(__name__)


class PredictionSnapshotStorage:
    """Efficient storage system for prediction snapshots"""

    def __init__(self, storage_dir: str = "data/prediction_snapshots"):
        """Initialize the snapshot storage"""
        self.storage_dir = Path(storage_dir)
        self.storage_dir.mkdir(parents=True, exist_ok=True)

        # Database for metadata
        self.db_path = self.storage_dir / "snapshots.db"
        self._initialize_database()

        # Cache for recent snapshots
        self.cache_size = 1000
        self.snapshot_cache: Dict[str, PredictionSnapshot] = {}

        # Compression settings
        self.compress_snapshots = True

        logger.info(f"PredictionSnapshotStorage initialized: {self.storage_dir}")

    def _initialize_database(self):
        """Initialize SQLite database for snapshot metadata"""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()

            # Snapshots table
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS snapshots (
                    prediction_id TEXT PRIMARY KEY,
                    symbol TEXT NOT NULL,
                    prediction_time TEXT NOT NULL,
                    target_horizon_minutes INTEGER NOT NULL,
                    target_time TEXT NOT NULL,
                    current_price REAL NOT NULL,
                    predicted_min_price REAL NOT NULL,
                    predicted_max_price REAL NOT NULL,
                    confidence REAL NOT NULL,
                    outcome_known INTEGER DEFAULT 0,
                    actual_min_price REAL,
                    actual_max_price REAL,
                    outcome_timestamp TEXT,
                    prediction_basis TEXT,
                    file_path TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)

            # Performance indexes
            cursor.execute("CREATE INDEX IF NOT EXISTS idx_symbol_time ON snapshots(symbol, prediction_time)")
            cursor.execute("CREATE INDEX IF NOT EXISTS idx_horizon_outcome ON snapshots(target_horizon_minutes, outcome_known)")
            cursor.execute("CREATE INDEX IF NOT EXISTS idx_outcome_time ON snapshots(outcome_known, outcome_timestamp)")

            # Training batches table for batch processing
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS training_batches (
                    batch_id TEXT PRIMARY KEY,
                    horizon_minutes INTEGER NOT NULL,
                    symbol TEXT NOT NULL,
                    prediction_ids TEXT NOT NULL,  -- JSON array
                    batch_size INTEGER NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    processed INTEGER DEFAULT 0,
                    training_results TEXT  -- JSON
                )
            """)

            conn.commit()

    def store_snapshot(self, snapshot: PredictionSnapshot) -> bool:
        """Store a prediction snapshot"""
        try:
            # Generate file path: one directory per symbol, one compressed file per prediction
            symbol_dir = self.storage_dir / snapshot.symbol.replace('/', '_')
            symbol_dir.mkdir(exist_ok=True)

            file_path = symbol_dir / f"{snapshot.prediction_id}.pkl.gz"

            # Store snapshot data
            self._store_snapshot_data(snapshot, file_path)

            # Store metadata in database
            self._store_snapshot_metadata(snapshot, str(file_path))

            # Update cache, evicting the oldest entry once the cache is full
            self.snapshot_cache[snapshot.prediction_id] = snapshot
            if len(self.snapshot_cache) > self.cache_size:
                oldest_key = min(self.snapshot_cache.keys(),
                                 key=lambda k: self.snapshot_cache[k].prediction_time)
                del self.snapshot_cache[oldest_key]

            return True

        except Exception as e:
            logger.error(f"Error storing snapshot {snapshot.prediction_id}: {e}")
            return False
    def _store_snapshot_data(self, snapshot: PredictionSnapshot, file_path: Path):
        """Store snapshot data to compressed file"""
        try:
            # Convert dataclass to dict for serialization
            snapshot_dict = asdict(snapshot)

            # Convert numpy arrays to lists so the payload stays portable
            if 'model_inputs' in snapshot_dict:
                model_inputs = snapshot_dict['model_inputs']
                for key, value in model_inputs.items():
                    if isinstance(value, np.ndarray):
                        model_inputs[key] = value.tolist()
                    elif isinstance(value, dict):
                        # Handle nested numpy arrays
                        for nested_key, nested_value in value.items():
                            if isinstance(nested_value, np.ndarray):
                                value[nested_key] = nested_value.tolist()

            if self.compress_snapshots:
                with gzip.open(file_path, 'wb') as f:
                    pickle.dump(snapshot_dict, f)
            else:
                with open(file_path, 'wb') as f:
                    pickle.dump(snapshot_dict, f)

        except Exception as e:
            logger.error(f"Error storing snapshot data to {file_path}: {e}")
            raise

    def _store_snapshot_metadata(self, snapshot: PredictionSnapshot, file_path: str):
        """Store snapshot metadata in database"""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()

            cursor.execute("""
                INSERT OR REPLACE INTO snapshots (
                    prediction_id, symbol, prediction_time, target_horizon_minutes,
                    target_time, current_price, predicted_min_price, predicted_max_price,
                    confidence, outcome_known, actual_min_price, actual_max_price,
                    outcome_timestamp, prediction_basis, file_path
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                snapshot.prediction_id,
                snapshot.symbol,
                snapshot.prediction_time.isoformat(),
                snapshot.target_horizon_minutes,
                snapshot.target_time.isoformat(),
                snapshot.current_price,
                snapshot.predicted_min_price,
                snapshot.predicted_max_price,
                snapshot.confidence,
                1 if snapshot.outcome_known else 0,
                snapshot.actual_min_price,
                snapshot.actual_max_price,
                snapshot.outcome_timestamp.isoformat() if snapshot.outcome_timestamp else None,
                snapshot.prediction_metadata.get('prediction_basis', 'unknown'),
                file_path
            ))

            conn.commit()

    def update_snapshot_outcome(self, prediction_id: str,
                                actual_min_price: float,
                                actual_max_price: float,
                                outcome_timestamp: datetime) -> bool:
        """Update a snapshot with actual outcome data"""
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()

                cursor.execute("""
                    UPDATE snapshots
                    SET outcome_known = 1,
                        actual_min_price = ?,
                        actual_max_price = ?,
                        outcome_timestamp = ?
                    WHERE prediction_id = ?
""", (actual_min_price, actual_max_price, outcome_timestamp.isoformat(), prediction_id)) if cursor.rowcount > 0: # Update cached snapshot if present if prediction_id in self.snapshot_cache: snapshot = self.snapshot_cache[prediction_id] snapshot.outcome_known = True snapshot.actual_min_price = actual_min_price snapshot.actual_max_price = actual_max_price snapshot.outcome_timestamp = outcome_timestamp return True else: logger.warning(f"No snapshot found with prediction_id: {prediction_id}") return False except Exception as e: logger.error(f"Error updating snapshot outcome for {prediction_id}: {e}") return False def get_snapshot(self, prediction_id: str) -> Optional[PredictionSnapshot]: """Retrieve a single snapshot""" try: # Check cache first if prediction_id in self.snapshot_cache: return self.snapshot_cache[prediction_id] # Get metadata from database with sqlite3.connect(self.db_path) as conn: cursor = conn.cursor() cursor.execute("SELECT file_path FROM snapshots WHERE prediction_id = ?", (prediction_id,)) result = cursor.fetchone() if not result: return None file_path = result[0] # Load snapshot data return self._load_snapshot_from_file(file_path) except Exception as e: logger.error(f"Error retrieving snapshot {prediction_id}: {e}") return None def _load_snapshot_from_file(self, file_path: str) -> Optional[PredictionSnapshot]: """Load snapshot from compressed file""" try: path = Path(file_path) if self.compress_snapshots: with gzip.open(path, 'rb') as f: snapshot_dict = pickle.load(f) else: with open(path, 'rb') as f: snapshot_dict = pickle.load(f) # Convert back to PredictionSnapshot return self._dict_to_snapshot(snapshot_dict) except Exception as e: logger.error(f"Error loading snapshot from {file_path}: {e}") return None def _dict_to_snapshot(self, snapshot_dict: Dict[str, Any]) -> PredictionSnapshot: """Convert dictionary back to PredictionSnapshot""" try: # Handle datetime conversion prediction_time = datetime.fromisoformat(snapshot_dict['prediction_time']) target_time = datetime.fromisoformat(snapshot_dict['target_time']) outcome_timestamp = None if snapshot_dict.get('outcome_timestamp'): outcome_timestamp = datetime.fromisoformat(snapshot_dict['outcome_timestamp']) return PredictionSnapshot( prediction_id=snapshot_dict['prediction_id'], symbol=snapshot_dict['symbol'], prediction_time=prediction_time, target_horizon_minutes=snapshot_dict['target_horizon_minutes'], target_time=target_time, current_price=snapshot_dict['current_price'], predicted_min_price=snapshot_dict['predicted_min_price'], predicted_max_price=snapshot_dict['predicted_max_price'], confidence=snapshot_dict['confidence'], model_inputs=snapshot_dict['model_inputs'], market_state=snapshot_dict['market_state'], technical_indicators=snapshot_dict['technical_indicators'], pivot_analysis=snapshot_dict['pivot_analysis'], prediction_metadata=snapshot_dict['prediction_metadata'], actual_min_price=snapshot_dict.get('actual_min_price'), actual_max_price=snapshot_dict.get('actual_max_price'), outcome_known=snapshot_dict['outcome_known'], outcome_timestamp=outcome_timestamp ) except Exception as e: logger.error(f"Error converting dict to snapshot: {e}") return None def get_training_batch(self, horizon_minutes: int, symbol: str, batch_size: int = 32, min_confidence: float = 0.0) -> List[PredictionSnapshot]: """Get a batch of snapshots ready for training""" try: with sqlite3.connect(self.db_path) as conn: cursor = conn.cursor() # Get snapshots that are ready for training (outcome known) cursor.execute(""" SELECT prediction_id FROM 
                    SELECT prediction_id FROM snapshots
                    WHERE target_horizon_minutes = ?
                      AND symbol = ?
                      AND outcome_known = 1
                      AND confidence >= ?
                    ORDER BY outcome_timestamp DESC
                    LIMIT ?
                """, (horizon_minutes, symbol, min_confidence, batch_size))

                prediction_ids = [row[0] for row in cursor.fetchall()]

            # Load the actual snapshots
            snapshots = []
            for pred_id in prediction_ids:
                snapshot = self.get_snapshot(pred_id)
                if snapshot:
                    snapshots.append(snapshot)

            logger.info(f"Retrieved training batch: {len(snapshots)} snapshots "
                        f"for {horizon_minutes}m {symbol}")
            return snapshots

        except Exception as e:
            logger.error(f"Error getting training batch: {e}")
            return []

    def get_pending_validation_snapshots(self, max_age_hours: int = 24) -> List[PredictionSnapshot]:
        """Get snapshots that need outcome validation"""
        try:
            # Only look back max_age_hours so stale predictions are not revalidated forever
            cutoff_time = datetime.now() - timedelta(hours=max_age_hours)

            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()

                cursor.execute("""
                    SELECT prediction_id FROM snapshots
                    WHERE outcome_known = 0
                      AND target_time <= ?
                      AND target_time >= ?
                    ORDER BY target_time ASC
                """, (datetime.now().isoformat(), cutoff_time.isoformat()))

                prediction_ids = [row[0] for row in cursor.fetchall()]

            # Load snapshots
            snapshots = []
            for pred_id in prediction_ids:
                snapshot = self.get_snapshot(pred_id)
                if snapshot:
                    snapshots.append(snapshot)

            return snapshots

        except Exception as e:
            logger.error(f"Error getting pending validation snapshots: {e}")
            return []

    def create_training_batch(self, horizon_minutes: int, symbol: str,
                              batch_size: int = 100) -> Optional[str]:
        """Create a training batch for processing"""
        try:
            batch_id = (f"batch_{horizon_minutes}m_{symbol.replace('/', '_')}_"
                        f"{int(datetime.now().timestamp())}")

            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()

                # Get available snapshots for this batch (random sample of known outcomes)
                cursor.execute("""
                    SELECT prediction_id FROM snapshots
                    WHERE target_horizon_minutes = ?
                      AND symbol = ?
                      AND outcome_known = 1
                    ORDER BY RANDOM()
                    LIMIT ?
                """, (horizon_minutes, symbol, batch_size))

                prediction_ids = [row[0] for row in cursor.fetchall()]

                if not prediction_ids:
                    logger.warning(f"No snapshots available for training batch {batch_id}")
                    return None

                # Store batch metadata
                cursor.execute("""
                    INSERT INTO training_batches (
                        batch_id, horizon_minutes, symbol, prediction_ids, batch_size
                    ) VALUES (?, ?, ?, ?, ?)
                """, (batch_id, horizon_minutes, symbol,
                      json.dumps(prediction_ids), len(prediction_ids)))

                conn.commit()

            logger.info(f"Created training batch {batch_id} with {len(prediction_ids)} snapshots")
            return batch_id

        except Exception as e:
            logger.error(f"Error creating training batch: {e}")
            return None

    def get_training_batch_snapshots(self, batch_id: str) -> List[PredictionSnapshot]:
        """Get all snapshots for a training batch"""
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT prediction_ids FROM training_batches WHERE batch_id = ?",
                               (batch_id,))
                result = cursor.fetchone()

                if not result:
                    return []

                prediction_ids = json.loads(result[0])

            # Load snapshots
            snapshots = []
            for pred_id in prediction_ids:
                snapshot = self.get_snapshot(pred_id)
                if snapshot:
                    snapshots.append(snapshot)

            return snapshots

        except Exception as e:
            logger.error(f"Error getting training batch snapshots: {e}")
            return []

    def update_training_batch_results(self, batch_id: str, training_results: Dict[str, Any]):
        """Update training batch with results"""
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()

                cursor.execute("""
                    UPDATE training_batches
                    SET processed = 1,
                        training_results = ?
                    WHERE batch_id = ?
""", (json.dumps(training_results), batch_id)) conn.commit() logger.info(f"Updated training batch {batch_id} with results") except Exception as e: logger.error(f"Error updating training batch results: {e}") def get_storage_stats(self) -> Dict[str, Any]: """Get storage statistics""" try: with sqlite3.connect(self.db_path) as conn: cursor = conn.cursor() # Total snapshots cursor.execute("SELECT COUNT(*) FROM snapshots") total_snapshots = cursor.fetchone()[0] # Snapshots by horizon cursor.execute(""" SELECT target_horizon_minutes, COUNT(*) FROM snapshots GROUP BY target_horizon_minutes """) horizon_counts = dict(cursor.fetchall()) # Outcome statistics cursor.execute(""" SELECT outcome_known, COUNT(*) FROM snapshots GROUP BY outcome_known """) outcome_counts = dict(cursor.fetchall()) # Storage size total_size = 0 for file_path in Path(self.storage_dir).rglob("*.pkl*"): total_size += file_path.stat().st_size return { 'total_snapshots': total_snapshots, 'snapshots_by_horizon': horizon_counts, 'outcome_stats': outcome_counts, 'total_storage_mb': total_size / (1024 * 1024), 'cache_size': len(self.snapshot_cache) } except Exception as e: logger.error(f"Error getting storage stats: {e}") return {} def cleanup_old_snapshots(self, max_age_days: int = 30): """Clean up old snapshots to save space""" try: cutoff_date = datetime.now() - timedelta(days=max_age_days) with sqlite3.connect(self.db_path) as conn: cursor = conn.cursor() # Get old snapshots cursor.execute(""" SELECT prediction_id, file_path FROM snapshots WHERE prediction_time < ? """, (cutoff_date.isoformat(),)) old_snapshots = cursor.fetchall() # Delete files and database entries deleted_count = 0 for pred_id, file_path in old_snapshots: try: Path(file_path).unlink(missing_ok=True) deleted_count += 1 except Exception as e: logger.debug(f"Error deleting file {file_path}: {e}") # Remove from database cursor.execute(""" DELETE FROM snapshots WHERE prediction_time < ? """, (cutoff_date.isoformat(),)) conn.commit() # Clean up cache to_remove = [] for pred_id, snapshot in self.snapshot_cache.items(): if snapshot.prediction_time < cutoff_date: to_remove.append(pred_id) for pred_id in to_remove: del self.snapshot_cache[pred_id] logger.info(f"Cleaned up {deleted_count} old snapshots") except Exception as e: logger.error(f"Error cleaning up old snapshots: {e}")