#!/usr/bin/env python3
"""
Prediction Snapshot Storage

This module handles storing and retrieving prediction snapshots for future training.
It uses efficient storage formats and provides batch access for training.
"""
import logging
import sqlite3
import json
import pickle
import gzip
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional, Tuple
from pathlib import Path
import numpy as np
import pandas as pd
from dataclasses import asdict

from .multi_horizon_prediction_manager import PredictionSnapshot

logger = logging.getLogger(__name__)


class PredictionSnapshotStorage:
    """Efficient storage system for prediction snapshots"""

    def __init__(self, storage_dir: str = "data/prediction_snapshots"):
        """Initialize the snapshot storage"""
        self.storage_dir = Path(storage_dir)
        self.storage_dir.mkdir(parents=True, exist_ok=True)

        # Database for metadata
        self.db_path = self.storage_dir / "snapshots.db"
        self._initialize_database()

        # Cache for recent snapshots
        self.cache_size = 1000
        self.snapshot_cache: Dict[str, PredictionSnapshot] = {}

        # Compression settings
        self.compress_snapshots = True

        logger.info(f"PredictionSnapshotStorage initialized: {self.storage_dir}")

    def _initialize_database(self):
        """Initialize SQLite database for snapshot metadata"""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()

            # Snapshots table
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS snapshots (
                    prediction_id TEXT PRIMARY KEY,
                    symbol TEXT NOT NULL,
                    prediction_time TEXT NOT NULL,
                    target_horizon_minutes INTEGER NOT NULL,
                    target_time TEXT NOT NULL,
                    current_price REAL NOT NULL,
                    predicted_min_price REAL NOT NULL,
                    predicted_max_price REAL NOT NULL,
                    confidence REAL NOT NULL,
                    outcome_known INTEGER DEFAULT 0,
                    actual_min_price REAL,
                    actual_max_price REAL,
                    outcome_timestamp TEXT,
                    prediction_basis TEXT,
                    file_path TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)

            # Performance indexes
            cursor.execute("CREATE INDEX IF NOT EXISTS idx_symbol_time ON snapshots(symbol, prediction_time)")
            cursor.execute("CREATE INDEX IF NOT EXISTS idx_horizon_outcome ON snapshots(target_horizon_minutes, outcome_known)")
            cursor.execute("CREATE INDEX IF NOT EXISTS idx_outcome_time ON snapshots(outcome_known, outcome_timestamp)")

            # Training batches table for batch processing
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS training_batches (
                    batch_id TEXT PRIMARY KEY,
                    horizon_minutes INTEGER NOT NULL,
                    symbol TEXT NOT NULL,
                    prediction_ids TEXT NOT NULL,  -- JSON array
                    batch_size INTEGER NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    processed INTEGER DEFAULT 0,
                    training_results TEXT  -- JSON
                )
            """)

            conn.commit()

    def store_snapshot(self, snapshot: PredictionSnapshot) -> bool:
        """Store a prediction snapshot"""
        try:
            # Generate file path
            symbol_dir = self.storage_dir / snapshot.symbol.replace('/', '_')
            symbol_dir.mkdir(exist_ok=True)

            file_path = symbol_dir / f"{snapshot.prediction_id}.pkl.gz"

            # Store snapshot data
            self._store_snapshot_data(snapshot, file_path)

            # Store metadata in database
            self._store_snapshot_metadata(snapshot, str(file_path))

            # Update cache, evicting the oldest entry when the cache is full
            self.snapshot_cache[snapshot.prediction_id] = snapshot
            if len(self.snapshot_cache) > self.cache_size:
                oldest_key = min(self.snapshot_cache.keys(),
                                 key=lambda k: self.snapshot_cache[k].prediction_time)
                del self.snapshot_cache[oldest_key]

            return True

        except Exception as e:
            logger.error(f"Error storing snapshot {snapshot.prediction_id}: {e}")
            return False

    def _store_snapshot_data(self, snapshot: PredictionSnapshot, file_path: Path):
        """Store snapshot data to a compressed file"""
        try:
            # Convert dataclass to dict for serialization
            snapshot_dict = asdict(snapshot)

            # Convert numpy arrays to plain lists so the pickled payload stays portable
            if 'model_inputs' in snapshot_dict:
                model_inputs = snapshot_dict['model_inputs']
                for key, value in model_inputs.items():
                    if isinstance(value, np.ndarray):
                        model_inputs[key] = value.tolist()
                    elif isinstance(value, dict):
                        # Handle nested numpy arrays
                        for nested_key, nested_value in value.items():
                            if isinstance(nested_value, np.ndarray):
                                value[nested_key] = nested_value.tolist()

            if self.compress_snapshots:
                with gzip.open(file_path, 'wb') as f:
                    pickle.dump(snapshot_dict, f)
            else:
                with open(file_path, 'wb') as f:
                    pickle.dump(snapshot_dict, f)

        except Exception as e:
            logger.error(f"Error storing snapshot data to {file_path}: {e}")
            raise

    def _store_snapshot_metadata(self, snapshot: PredictionSnapshot, file_path: str):
        """Store snapshot metadata in database"""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()

            cursor.execute("""
                INSERT OR REPLACE INTO snapshots (
                    prediction_id, symbol, prediction_time, target_horizon_minutes,
                    target_time, current_price, predicted_min_price, predicted_max_price,
                    confidence, outcome_known, actual_min_price, actual_max_price,
                    outcome_timestamp, prediction_basis, file_path
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                snapshot.prediction_id,
                snapshot.symbol,
                snapshot.prediction_time.isoformat(),
                snapshot.target_horizon_minutes,
                snapshot.target_time.isoformat(),
                snapshot.current_price,
                snapshot.predicted_min_price,
                snapshot.predicted_max_price,
                snapshot.confidence,
                1 if snapshot.outcome_known else 0,
                snapshot.actual_min_price,
                snapshot.actual_max_price,
                snapshot.outcome_timestamp.isoformat() if snapshot.outcome_timestamp else None,
                snapshot.prediction_metadata.get('prediction_basis', 'unknown'),
                file_path
            ))

            conn.commit()

    def update_snapshot_outcome(self, prediction_id: str, actual_min_price: float,
                                actual_max_price: float, outcome_timestamp: datetime) -> bool:
        """Update a snapshot with actual outcome data"""
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()

                cursor.execute("""
                    UPDATE snapshots SET
                        outcome_known = 1,
                        actual_min_price = ?,
                        actual_max_price = ?,
                        outcome_timestamp = ?
                    WHERE prediction_id = ?
                """, (actual_min_price, actual_max_price, outcome_timestamp.isoformat(), prediction_id))

                if cursor.rowcount > 0:
                    # Update cached snapshot if present
                    if prediction_id in self.snapshot_cache:
                        snapshot = self.snapshot_cache[prediction_id]
                        snapshot.outcome_known = True
                        snapshot.actual_min_price = actual_min_price
                        snapshot.actual_max_price = actual_max_price
                        snapshot.outcome_timestamp = outcome_timestamp

                    return True
                else:
                    logger.warning(f"No snapshot found with prediction_id: {prediction_id}")
                    return False

        except Exception as e:
            logger.error(f"Error updating snapshot outcome for {prediction_id}: {e}")
            return False

    def get_snapshot(self, prediction_id: str) -> Optional[PredictionSnapshot]:
        """Retrieve a single snapshot"""
        try:
            # Check cache first
            if prediction_id in self.snapshot_cache:
                return self.snapshot_cache[prediction_id]

            # Get metadata from database
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT file_path FROM snapshots WHERE prediction_id = ?", (prediction_id,))
                result = cursor.fetchone()

                if not result:
                    return None

                file_path = result[0]

            # Load snapshot data
            return self._load_snapshot_from_file(file_path)

        except Exception as e:
            logger.error(f"Error retrieving snapshot {prediction_id}: {e}")
            return None

    def _load_snapshot_from_file(self, file_path: str) -> Optional[PredictionSnapshot]:
        """Load snapshot from compressed file"""
        try:
            path = Path(file_path)

            if self.compress_snapshots:
                with gzip.open(path, 'rb') as f:
                    snapshot_dict = pickle.load(f)
            else:
                with open(path, 'rb') as f:
                    snapshot_dict = pickle.load(f)

            # Convert back to PredictionSnapshot
            return self._dict_to_snapshot(snapshot_dict)

        except Exception as e:
            logger.error(f"Error loading snapshot from {file_path}: {e}")
            return None

    def _dict_to_snapshot(self, snapshot_dict: Dict[str, Any]) -> Optional[PredictionSnapshot]:
        """Convert dictionary back to PredictionSnapshot"""
        try:
            # asdict()/pickle preserve datetime objects as-is; also tolerate ISO-format strings
            prediction_time = snapshot_dict['prediction_time']
            if isinstance(prediction_time, str):
                prediction_time = datetime.fromisoformat(prediction_time)
            target_time = snapshot_dict['target_time']
            if isinstance(target_time, str):
                target_time = datetime.fromisoformat(target_time)
            outcome_timestamp = snapshot_dict.get('outcome_timestamp')
            if isinstance(outcome_timestamp, str):
                outcome_timestamp = datetime.fromisoformat(outcome_timestamp)

            return PredictionSnapshot(
                prediction_id=snapshot_dict['prediction_id'],
                symbol=snapshot_dict['symbol'],
                prediction_time=prediction_time,
                target_horizon_minutes=snapshot_dict['target_horizon_minutes'],
                target_time=target_time,
                current_price=snapshot_dict['current_price'],
                predicted_min_price=snapshot_dict['predicted_min_price'],
                predicted_max_price=snapshot_dict['predicted_max_price'],
                confidence=snapshot_dict['confidence'],
                model_inputs=snapshot_dict['model_inputs'],
                market_state=snapshot_dict['market_state'],
                technical_indicators=snapshot_dict['technical_indicators'],
                pivot_analysis=snapshot_dict['pivot_analysis'],
                prediction_metadata=snapshot_dict['prediction_metadata'],
                actual_min_price=snapshot_dict.get('actual_min_price'),
                actual_max_price=snapshot_dict.get('actual_max_price'),
                outcome_known=snapshot_dict['outcome_known'],
                outcome_timestamp=outcome_timestamp
            )

        except Exception as e:
            logger.error(f"Error converting dict to snapshot: {e}")
            return None

    def get_training_batch(self, horizon_minutes: int, symbol: str,
                           batch_size: int = 32, min_confidence: float = 0.0) -> List[PredictionSnapshot]:
        """Get a batch of snapshots ready for training"""
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()

                # Get snapshots that are ready for training (outcome known)
                cursor.execute("""
                    SELECT prediction_id FROM snapshots
                    WHERE target_horizon_minutes = ?
                      AND symbol = ?
                      AND outcome_known = 1
                      AND confidence >= ?
                    ORDER BY outcome_timestamp DESC
                    LIMIT ?
                """, (horizon_minutes, symbol, min_confidence, batch_size))

                prediction_ids = [row[0] for row in cursor.fetchall()]

            # Load the actual snapshots
            snapshots = []
            for pred_id in prediction_ids:
                snapshot = self.get_snapshot(pred_id)
                if snapshot:
                    snapshots.append(snapshot)

            logger.info(f"Retrieved training batch: {len(snapshots)} snapshots for {horizon_minutes}m {symbol}")
            return snapshots

        except Exception as e:
            logger.error(f"Error getting training batch: {e}")
            return []

    def get_pending_validation_snapshots(self, max_age_hours: int = 24) -> List[PredictionSnapshot]:
        """Get snapshots that need outcome validation"""
        try:
            now = datetime.now()
            cutoff_time = now - timedelta(hours=max_age_hours)

            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()

                # Pending = outcome unknown, target time already reached, and not older
                # than max_age_hours (stale predictions are skipped)
                cursor.execute("""
                    SELECT prediction_id FROM snapshots
                    WHERE outcome_known = 0
                      AND target_time <= ?
                      AND target_time >= ?
                    ORDER BY target_time ASC
                """, (now.isoformat(), cutoff_time.isoformat()))

                prediction_ids = [row[0] for row in cursor.fetchall()]

            # Load snapshots
            snapshots = []
            for pred_id in prediction_ids:
                snapshot = self.get_snapshot(pred_id)
                if snapshot:
                    snapshots.append(snapshot)

            return snapshots

        except Exception as e:
            logger.error(f"Error getting pending validation snapshots: {e}")
            return []

    def create_training_batch(self, horizon_minutes: int, symbol: str,
                              batch_size: int = 100) -> Optional[str]:
        """Create a training batch for processing"""
        try:
            batch_id = f"batch_{horizon_minutes}m_{symbol.replace('/', '_')}_{int(datetime.now().timestamp())}"

            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()

                # Get available snapshots for this batch
                cursor.execute("""
                    SELECT prediction_id FROM snapshots
                    WHERE target_horizon_minutes = ?
                      AND symbol = ?
                      AND outcome_known = 1
                    ORDER BY RANDOM()
                    LIMIT ?
                """, (horizon_minutes, symbol, batch_size))

                prediction_ids = [row[0] for row in cursor.fetchall()]

                if not prediction_ids:
                    logger.warning(f"No snapshots available for training batch {batch_id}")
                    return None

                # Store batch metadata
                cursor.execute("""
                    INSERT INTO training_batches (
                        batch_id, horizon_minutes, symbol, prediction_ids, batch_size
                    ) VALUES (?, ?, ?, ?, ?)
                """, (batch_id, horizon_minutes, symbol, json.dumps(prediction_ids), len(prediction_ids)))

                conn.commit()

            logger.info(f"Created training batch {batch_id} with {len(prediction_ids)} snapshots")
            return batch_id

        except Exception as e:
            logger.error(f"Error creating training batch: {e}")
            return None

    def get_training_batch_snapshots(self, batch_id: str) -> List[PredictionSnapshot]:
        """Get all snapshots for a training batch"""
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()

                cursor.execute("SELECT prediction_ids FROM training_batches WHERE batch_id = ?", (batch_id,))
                result = cursor.fetchone()

                if not result:
                    return []

                prediction_ids = json.loads(result[0])

            # Load snapshots
            snapshots = []
            for pred_id in prediction_ids:
                snapshot = self.get_snapshot(pred_id)
                if snapshot:
                    snapshots.append(snapshot)

            return snapshots

        except Exception as e:
            logger.error(f"Error getting training batch snapshots: {e}")
            return []

    def update_training_batch_results(self, batch_id: str, training_results: Dict[str, Any]):
        """Update training batch with results"""
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()

                cursor.execute("""
                    UPDATE training_batches SET
                        processed = 1,
                        training_results = ?
                    WHERE batch_id = ?
                """, (json.dumps(training_results), batch_id))

                conn.commit()

            logger.info(f"Updated training batch {batch_id} with results")

        except Exception as e:
            logger.error(f"Error updating training batch results: {e}")

    def get_storage_stats(self) -> Dict[str, Any]:
        """Get storage statistics"""
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()

                # Total snapshots
                cursor.execute("SELECT COUNT(*) FROM snapshots")
                total_snapshots = cursor.fetchone()[0]

                # Snapshots by horizon
                cursor.execute("""
                    SELECT target_horizon_minutes, COUNT(*)
                    FROM snapshots
                    GROUP BY target_horizon_minutes
                """)
                horizon_counts = dict(cursor.fetchall())

                # Outcome statistics
                cursor.execute("""
                    SELECT outcome_known, COUNT(*)
                    FROM snapshots
                    GROUP BY outcome_known
                """)
                outcome_counts = dict(cursor.fetchall())

            # Storage size on disk
            total_size = 0
            for file_path in Path(self.storage_dir).rglob("*.pkl*"):
                total_size += file_path.stat().st_size

            return {
                'total_snapshots': total_snapshots,
                'snapshots_by_horizon': horizon_counts,
                'outcome_stats': outcome_counts,
                'total_storage_mb': total_size / (1024 * 1024),
                'cache_size': len(self.snapshot_cache)
            }

        except Exception as e:
            logger.error(f"Error getting storage stats: {e}")
            return {}

    def cleanup_old_snapshots(self, max_age_days: int = 30):
        """Clean up old snapshots to save space"""
        try:
            cutoff_date = datetime.now() - timedelta(days=max_age_days)

            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()

                # Get old snapshots
                cursor.execute("""
                    SELECT prediction_id, file_path FROM snapshots
                    WHERE prediction_time < ?
                """, (cutoff_date.isoformat(),))

                old_snapshots = cursor.fetchall()

                # Delete snapshot files
                deleted_count = 0
                for pred_id, file_path in old_snapshots:
                    try:
                        Path(file_path).unlink(missing_ok=True)
                        deleted_count += 1
                    except Exception as e:
                        logger.debug(f"Error deleting file {file_path}: {e}")

                # Remove database entries
                cursor.execute("""
                    DELETE FROM snapshots WHERE prediction_time < ?
                """, (cutoff_date.isoformat(),))

                conn.commit()

            # Evict stale entries from the in-memory cache
            to_remove = [pred_id for pred_id, snapshot in self.snapshot_cache.items()
                         if snapshot.prediction_time < cutoff_date]
            for pred_id in to_remove:
                del self.snapshot_cache[pred_id]

            logger.info(f"Cleaned up {deleted_count} old snapshots")

        except Exception as e:
            logger.error(f"Error cleaning up old snapshots: {e}")