gogo2/core/prediction_snapshot_storage.py
#!/usr/bin/env python3
"""
Prediction Snapshot Storage

This module handles storing and retrieving prediction snapshots for future training.
It uses efficient storage formats and provides batch access for training.
"""

import gzip
import json
import logging
import pickle
import sqlite3
from dataclasses import asdict
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Any, Optional, Tuple

import numpy as np
import pandas as pd

from .multi_horizon_prediction_manager import PredictionSnapshot

logger = logging.getLogger(__name__)


class PredictionSnapshotStorage:
    """Efficient storage system for prediction snapshots"""

    def __init__(self, storage_dir: str = "data/prediction_snapshots"):
        """Initialize the snapshot storage"""
        self.storage_dir = Path(storage_dir)
        self.storage_dir.mkdir(parents=True, exist_ok=True)

        # Database for metadata
        self.db_path = self.storage_dir / "snapshots.db"
        self._initialize_database()

        # Cache for recent snapshots
        self.cache_size = 1000
        self.snapshot_cache: Dict[str, PredictionSnapshot] = {}

        # Compression settings
        self.compress_snapshots = True

        logger.info(f"PredictionSnapshotStorage initialized: {self.storage_dir}")
    def _initialize_database(self):
        """Initialize SQLite database for snapshot metadata"""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()

            # Snapshots table
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS snapshots (
                    prediction_id TEXT PRIMARY KEY,
                    symbol TEXT NOT NULL,
                    prediction_time TEXT NOT NULL,
                    target_horizon_minutes INTEGER NOT NULL,
                    target_time TEXT NOT NULL,
                    current_price REAL NOT NULL,
                    predicted_min_price REAL NOT NULL,
                    predicted_max_price REAL NOT NULL,
                    confidence REAL NOT NULL,
                    outcome_known INTEGER DEFAULT 0,
                    actual_min_price REAL,
                    actual_max_price REAL,
                    outcome_timestamp TEXT,
                    prediction_basis TEXT,
                    file_path TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)

            # Performance indexes
            cursor.execute("CREATE INDEX IF NOT EXISTS idx_symbol_time ON snapshots(symbol, prediction_time)")
            cursor.execute("CREATE INDEX IF NOT EXISTS idx_horizon_outcome ON snapshots(target_horizon_minutes, outcome_known)")
            cursor.execute("CREATE INDEX IF NOT EXISTS idx_outcome_time ON snapshots(outcome_known, outcome_timestamp)")

            # Training batches table for batch processing
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS training_batches (
                    batch_id TEXT PRIMARY KEY,
                    horizon_minutes INTEGER NOT NULL,
                    symbol TEXT NOT NULL,
                    prediction_ids TEXT NOT NULL,  -- JSON array
                    batch_size INTEGER NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    processed INTEGER DEFAULT 0,
                    training_results TEXT  -- JSON
                )
            """)

            conn.commit()
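    # Note: a minimal sketch for ad-hoc inspection of the metadata database, assuming the
    # schema created above; the query can be run against <storage_dir>/snapshots.db with
    # the sqlite3 CLI to see which predictions are still awaiting outcomes:
    #
    #   SELECT symbol, target_horizon_minutes, COUNT(*) AS pending
    #   FROM snapshots
    #   WHERE outcome_known = 0
    #   GROUP BY symbol, target_horizon_minutes;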
    def store_snapshot(self, snapshot: PredictionSnapshot) -> bool:
        """Store a prediction snapshot"""
        try:
            # Generate file path (one directory per symbol)
            symbol_dir = self.storage_dir / snapshot.symbol.replace('/', '_')
            symbol_dir.mkdir(exist_ok=True)

            file_path = symbol_dir / f"{snapshot.prediction_id}.pkl.gz"

            # Store snapshot data
            self._store_snapshot_data(snapshot, file_path)

            # Store metadata in database
            self._store_snapshot_metadata(snapshot, str(file_path))

            # Update cache, evicting the oldest entry when the cache is full
            self.snapshot_cache[snapshot.prediction_id] = snapshot
            if len(self.snapshot_cache) > self.cache_size:
                oldest_key = min(self.snapshot_cache.keys(),
                                 key=lambda k: self.snapshot_cache[k].prediction_time)
                del self.snapshot_cache[oldest_key]

            return True

        except Exception as e:
            logger.error(f"Error storing snapshot {snapshot.prediction_id}: {e}")
            return False
    def _store_snapshot_data(self, snapshot: PredictionSnapshot, file_path: Path):
        """Store snapshot data to compressed file"""
        try:
            # Convert dataclass to dict for serialization
            snapshot_dict = asdict(snapshot)

            # Convert numpy arrays to lists so the payload stays portable
            if 'model_inputs' in snapshot_dict:
                model_inputs = snapshot_dict['model_inputs']
                for key, value in model_inputs.items():
                    if isinstance(value, np.ndarray):
                        model_inputs[key] = value.tolist()
                    elif isinstance(value, dict):
                        # Handle nested numpy arrays
                        for nested_key, nested_value in value.items():
                            if isinstance(nested_value, np.ndarray):
                                value[nested_key] = nested_value.tolist()

            if self.compress_snapshots:
                with gzip.open(file_path, 'wb') as f:
                    pickle.dump(snapshot_dict, f)
            else:
                with open(file_path, 'wb') as f:
                    pickle.dump(snapshot_dict, f)

        except Exception as e:
            logger.error(f"Error storing snapshot data to {file_path}: {e}")
            raise
    def _store_snapshot_metadata(self, snapshot: PredictionSnapshot, file_path: str):
        """Store snapshot metadata in database"""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("""
                INSERT OR REPLACE INTO snapshots (
                    prediction_id, symbol, prediction_time, target_horizon_minutes,
                    target_time, current_price, predicted_min_price, predicted_max_price,
                    confidence, outcome_known, actual_min_price, actual_max_price,
                    outcome_timestamp, prediction_basis, file_path
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                snapshot.prediction_id,
                snapshot.symbol,
                snapshot.prediction_time.isoformat(),
                snapshot.target_horizon_minutes,
                snapshot.target_time.isoformat(),
                snapshot.current_price,
                snapshot.predicted_min_price,
                snapshot.predicted_max_price,
                snapshot.confidence,
                1 if snapshot.outcome_known else 0,
                snapshot.actual_min_price,
                snapshot.actual_max_price,
                snapshot.outcome_timestamp.isoformat() if snapshot.outcome_timestamp else None,
                snapshot.prediction_metadata.get('prediction_basis', 'unknown'),
                file_path
            ))
            conn.commit()
    def update_snapshot_outcome(self, prediction_id: str, actual_min_price: float,
                                actual_max_price: float, outcome_timestamp: datetime) -> bool:
        """Update a snapshot with actual outcome data"""
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute("""
                    UPDATE snapshots SET
                        outcome_known = 1,
                        actual_min_price = ?,
                        actual_max_price = ?,
                        outcome_timestamp = ?
                    WHERE prediction_id = ?
                """, (actual_min_price, actual_max_price, outcome_timestamp.isoformat(), prediction_id))

                if cursor.rowcount > 0:
                    # Update cached snapshot if present
                    if prediction_id in self.snapshot_cache:
                        snapshot = self.snapshot_cache[prediction_id]
                        snapshot.outcome_known = True
                        snapshot.actual_min_price = actual_min_price
                        snapshot.actual_max_price = actual_max_price
                        snapshot.outcome_timestamp = outcome_timestamp
                    return True
                else:
                    logger.warning(f"No snapshot found with prediction_id: {prediction_id}")
                    return False

        except Exception as e:
            logger.error(f"Error updating snapshot outcome for {prediction_id}: {e}")
            return False
    def get_snapshot(self, prediction_id: str) -> Optional[PredictionSnapshot]:
        """Retrieve a single snapshot"""
        try:
            # Check cache first
            if prediction_id in self.snapshot_cache:
                return self.snapshot_cache[prediction_id]

            # Get metadata from database
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT file_path FROM snapshots WHERE prediction_id = ?", (prediction_id,))
                result = cursor.fetchone()
                if not result:
                    return None
                file_path = result[0]

            # Load snapshot data from disk
            return self._load_snapshot_from_file(file_path)

        except Exception as e:
            logger.error(f"Error retrieving snapshot {prediction_id}: {e}")
            return None

    def _load_snapshot_from_file(self, file_path: str) -> Optional[PredictionSnapshot]:
        """Load snapshot from compressed file"""
        try:
            path = Path(file_path)

            if self.compress_snapshots:
                with gzip.open(path, 'rb') as f:
                    snapshot_dict = pickle.load(f)
            else:
                with open(path, 'rb') as f:
                    snapshot_dict = pickle.load(f)

            # Convert back to PredictionSnapshot
            return self._dict_to_snapshot(snapshot_dict)

        except Exception as e:
            logger.error(f"Error loading snapshot from {file_path}: {e}")
            return None
    def _dict_to_snapshot(self, snapshot_dict: Dict[str, Any]) -> Optional[PredictionSnapshot]:
        """Convert dictionary back to PredictionSnapshot"""
        try:
            # Handle datetime conversion: pickle preserves datetime objects as-is,
            # but accept ISO strings too for robustness
            def _to_datetime(value):
                return value if isinstance(value, datetime) else datetime.fromisoformat(value)

            prediction_time = _to_datetime(snapshot_dict['prediction_time'])
            target_time = _to_datetime(snapshot_dict['target_time'])

            outcome_timestamp = None
            if snapshot_dict.get('outcome_timestamp'):
                outcome_timestamp = _to_datetime(snapshot_dict['outcome_timestamp'])

            return PredictionSnapshot(
                prediction_id=snapshot_dict['prediction_id'],
                symbol=snapshot_dict['symbol'],
                prediction_time=prediction_time,
                target_horizon_minutes=snapshot_dict['target_horizon_minutes'],
                target_time=target_time,
                current_price=snapshot_dict['current_price'],
                predicted_min_price=snapshot_dict['predicted_min_price'],
                predicted_max_price=snapshot_dict['predicted_max_price'],
                confidence=snapshot_dict['confidence'],
                model_inputs=snapshot_dict['model_inputs'],
                market_state=snapshot_dict['market_state'],
                technical_indicators=snapshot_dict['technical_indicators'],
                pivot_analysis=snapshot_dict['pivot_analysis'],
                prediction_metadata=snapshot_dict['prediction_metadata'],
                actual_min_price=snapshot_dict.get('actual_min_price'),
                actual_max_price=snapshot_dict.get('actual_max_price'),
                outcome_known=snapshot_dict['outcome_known'],
                outcome_timestamp=outcome_timestamp
            )

        except Exception as e:
            logger.error(f"Error converting dict to snapshot: {e}")
            return None
    def get_training_batch(self, horizon_minutes: int, symbol: str,
                           batch_size: int = 32, min_confidence: float = 0.0) -> List[PredictionSnapshot]:
        """Get a batch of snapshots ready for training"""
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()

                # Get snapshots that are ready for training (outcome known)
                cursor.execute("""
                    SELECT prediction_id FROM snapshots
                    WHERE target_horizon_minutes = ?
                    AND symbol = ?
                    AND outcome_known = 1
                    AND confidence >= ?
                    ORDER BY outcome_timestamp DESC
                    LIMIT ?
                """, (horizon_minutes, symbol, min_confidence, batch_size))

                prediction_ids = [row[0] for row in cursor.fetchall()]

            # Load the actual snapshots
            snapshots = []
            for pred_id in prediction_ids:
                snapshot = self.get_snapshot(pred_id)
                if snapshot:
                    snapshots.append(snapshot)

            logger.info(f"Retrieved training batch: {len(snapshots)} snapshots for {horizon_minutes}m {symbol}")
            return snapshots

        except Exception as e:
            logger.error(f"Error getting training batch: {e}")
            return []
    def get_pending_validation_snapshots(self, max_age_hours: int = 24) -> List[PredictionSnapshot]:
        """Get snapshots that need outcome validation (target time passed, outcome still unknown)"""
        try:
            now = datetime.now()
            cutoff_time = now - timedelta(hours=max_age_hours)

            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                # Only consider snapshots whose target time has already passed and
                # is no older than max_age_hours
                cursor.execute("""
                    SELECT prediction_id FROM snapshots
                    WHERE outcome_known = 0
                    AND target_time <= ?
                    AND target_time >= ?
                    ORDER BY target_time ASC
                """, (now.isoformat(), cutoff_time.isoformat()))

                prediction_ids = [row[0] for row in cursor.fetchall()]

            # Load snapshots
            snapshots = []
            for pred_id in prediction_ids:
                snapshot = self.get_snapshot(pred_id)
                if snapshot:
                    snapshots.append(snapshot)

            return snapshots

        except Exception as e:
            logger.error(f"Error getting pending validation snapshots: {e}")
            return []
    def create_training_batch(self, horizon_minutes: int, symbol: str,
                              batch_size: int = 100) -> Optional[str]:
        """Create a training batch for processing"""
        try:
            batch_id = f"batch_{horizon_minutes}m_{symbol.replace('/', '_')}_{int(datetime.now().timestamp())}"

            # Get available snapshots for this batch
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute("""
                    SELECT prediction_id FROM snapshots
                    WHERE target_horizon_minutes = ?
                    AND symbol = ?
                    AND outcome_known = 1
                    ORDER BY RANDOM()
                    LIMIT ?
                """, (horizon_minutes, symbol, batch_size))

                prediction_ids = [row[0] for row in cursor.fetchall()]

            if not prediction_ids:
                logger.warning(f"No snapshots available for training batch {batch_id}")
                return None

            # Store batch metadata
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute("""
                    INSERT INTO training_batches (
                        batch_id, horizon_minutes, symbol, prediction_ids, batch_size
                    ) VALUES (?, ?, ?, ?, ?)
                """, (batch_id, horizon_minutes, symbol, json.dumps(prediction_ids), len(prediction_ids)))
                conn.commit()

            logger.info(f"Created training batch {batch_id} with {len(prediction_ids)} snapshots")
            return batch_id

        except Exception as e:
            logger.error(f"Error creating training batch: {e}")
            return None
    def get_training_batch_snapshots(self, batch_id: str) -> List[PredictionSnapshot]:
        """Get all snapshots for a training batch"""
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT prediction_ids FROM training_batches WHERE batch_id = ?", (batch_id,))
                result = cursor.fetchone()
                if not result:
                    return []
                prediction_ids = json.loads(result[0])

            # Load snapshots
            snapshots = []
            for pred_id in prediction_ids:
                snapshot = self.get_snapshot(pred_id)
                if snapshot:
                    snapshots.append(snapshot)

            return snapshots

        except Exception as e:
            logger.error(f"Error getting training batch snapshots: {e}")
            return []

    def update_training_batch_results(self, batch_id: str, training_results: Dict[str, Any]):
        """Update training batch with results"""
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute("""
                    UPDATE training_batches SET
                        processed = 1,
                        training_results = ?
                    WHERE batch_id = ?
                """, (json.dumps(training_results), batch_id))
                conn.commit()

            logger.info(f"Updated training batch {batch_id} with results")

        except Exception as e:
            logger.error(f"Error updating training batch results: {e}")
    def get_storage_stats(self) -> Dict[str, Any]:
        """Get storage statistics"""
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()

                # Total snapshots
                cursor.execute("SELECT COUNT(*) FROM snapshots")
                total_snapshots = cursor.fetchone()[0]

                # Snapshots by horizon
                cursor.execute("""
                    SELECT target_horizon_minutes, COUNT(*)
                    FROM snapshots
                    GROUP BY target_horizon_minutes
                """)
                horizon_counts = dict(cursor.fetchall())

                # Outcome statistics
                cursor.execute("""
                    SELECT outcome_known, COUNT(*)
                    FROM snapshots
                    GROUP BY outcome_known
                """)
                outcome_counts = dict(cursor.fetchall())

            # Storage size on disk
            total_size = 0
            for file_path in Path(self.storage_dir).rglob("*.pkl*"):
                total_size += file_path.stat().st_size

            return {
                'total_snapshots': total_snapshots,
                'snapshots_by_horizon': horizon_counts,
                'outcome_stats': outcome_counts,
                'total_storage_mb': total_size / (1024 * 1024),
                'cache_size': len(self.snapshot_cache)
            }

        except Exception as e:
            logger.error(f"Error getting storage stats: {e}")
            return {}
    def cleanup_old_snapshots(self, max_age_days: int = 30):
        """Clean up old snapshots to save space"""
        try:
            cutoff_date = datetime.now() - timedelta(days=max_age_days)

            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()

                # Get old snapshots
                cursor.execute("""
                    SELECT prediction_id, file_path FROM snapshots
                    WHERE prediction_time < ?
                """, (cutoff_date.isoformat(),))
                old_snapshots = cursor.fetchall()

                # Delete files and database entries
                deleted_count = 0
                for pred_id, file_path in old_snapshots:
                    try:
                        Path(file_path).unlink(missing_ok=True)
                        deleted_count += 1
                    except Exception as e:
                        logger.debug(f"Error deleting file {file_path}: {e}")

                # Remove from database
                cursor.execute("""
                    DELETE FROM snapshots WHERE prediction_time < ?
                """, (cutoff_date.isoformat(),))
                conn.commit()

            # Clean up cache
            to_remove = []
            for pred_id, snapshot in self.snapshot_cache.items():
                if snapshot.prediction_time < cutoff_date:
                    to_remove.append(pred_id)
            for pred_id in to_remove:
                del self.snapshot_cache[pred_id]

            logger.info(f"Cleaned up {deleted_count} old snapshots")

        except Exception as e:
            logger.error(f"Error cleaning up old snapshots: {e}")