""" Database Manager for Trading System Manages SQLite database for: 1. Inference records logging 2. Checkpoint metadata storage 3. Model performance tracking """ import sqlite3 import json import logging import os import numpy as np from datetime import datetime, timedelta from typing import Dict, List, Optional, Any, Tuple from contextlib import contextmanager from dataclasses import dataclass, asdict logger = logging.getLogger(__name__) @dataclass class InferenceRecord: """Structure for inference logging""" model_name: str timestamp: datetime symbol: str action: str confidence: float probabilities: Dict[str, float] input_features_hash: str # Hash of input features for deduplication processing_time_ms: float memory_usage_mb: float input_features: Optional[np.ndarray] = None # Full input features for training checkpoint_id: Optional[str] = None metadata: Optional[Dict[str, Any]] = None @dataclass class CheckpointMetadata: """Structure for checkpoint metadata""" checkpoint_id: str model_name: str model_type: str timestamp: datetime performance_metrics: Dict[str, float] training_metadata: Dict[str, Any] file_path: str file_size_mb: float is_active: bool = False # Currently loaded checkpoint class DatabaseManager: """Manages SQLite database for trading system logging and metadata""" def __init__(self, db_path: str = "data/trading_system.db"): self.db_path = db_path self._ensure_db_directory() self._initialize_database() def _ensure_db_directory(self): """Ensure database directory exists""" os.makedirs(os.path.dirname(self.db_path), exist_ok=True) def _initialize_database(self): """Initialize database tables""" with self._get_connection() as conn: # Inference records table conn.execute(""" CREATE TABLE IF NOT EXISTS inference_records ( id INTEGER PRIMARY KEY AUTOINCREMENT, model_name TEXT NOT NULL, timestamp TEXT NOT NULL, symbol TEXT NOT NULL, action TEXT NOT NULL, confidence REAL NOT NULL, probabilities TEXT NOT NULL, -- JSON input_features_hash TEXT NOT NULL, input_features_blob BLOB, -- Store full input features for training processing_time_ms REAL NOT NULL, memory_usage_mb REAL NOT NULL, checkpoint_id TEXT, metadata TEXT, -- JSON created_at TEXT DEFAULT CURRENT_TIMESTAMP ) """) # Checkpoint metadata table conn.execute(""" CREATE TABLE IF NOT EXISTS checkpoint_metadata ( id INTEGER PRIMARY KEY AUTOINCREMENT, checkpoint_id TEXT UNIQUE NOT NULL, model_name TEXT NOT NULL, model_type TEXT NOT NULL, timestamp TEXT NOT NULL, performance_metrics TEXT NOT NULL, -- JSON training_metadata TEXT NOT NULL, -- JSON file_path TEXT NOT NULL, file_size_mb REAL NOT NULL, is_active BOOLEAN DEFAULT FALSE, created_at TEXT DEFAULT CURRENT_TIMESTAMP ) """) # Model performance tracking table conn.execute(""" CREATE TABLE IF NOT EXISTS model_performance ( id INTEGER PRIMARY KEY AUTOINCREMENT, model_name TEXT NOT NULL, date TEXT NOT NULL, total_predictions INTEGER DEFAULT 0, correct_predictions INTEGER DEFAULT 0, accuracy REAL DEFAULT 0.0, avg_confidence REAL DEFAULT 0.0, avg_processing_time_ms REAL DEFAULT 0.0, created_at TEXT DEFAULT CURRENT_TIMESTAMP, UNIQUE(model_name, date) ) """) # Create indexes for better performance conn.execute("CREATE INDEX IF NOT EXISTS idx_inference_model_timestamp ON inference_records(model_name, timestamp)") conn.execute("CREATE INDEX IF NOT EXISTS idx_inference_symbol ON inference_records(symbol)") conn.execute("CREATE INDEX IF NOT EXISTS idx_checkpoint_model ON checkpoint_metadata(model_name)") conn.execute("CREATE INDEX IF NOT EXISTS idx_checkpoint_active ON 
        logger.info(f"Database initialized at {self.db_path}")

        # Run migrations to handle schema updates
        self._run_migrations()

    def _run_migrations(self):
        """Run database migrations to handle schema updates"""
        try:
            with self._get_connection() as conn:
                # Check if input_features_blob column exists
                cursor = conn.execute("PRAGMA table_info(inference_records)")
                columns = [row[1] for row in cursor.fetchall()]

                if 'input_features_blob' not in columns:
                    logger.info("Adding input_features_blob column to inference_records table")
                    conn.execute("ALTER TABLE inference_records ADD COLUMN input_features_blob BLOB")
                    conn.commit()
                    logger.info("Successfully added input_features_blob column")
                else:
                    logger.debug("input_features_blob column already exists")
        except Exception as e:
            logger.error(f"Error running database migrations: {e}")
            # If migration fails, we can still continue without the blob column

    @contextmanager
    def _get_connection(self):
        """Get database connection with proper error handling"""
        conn = None
        try:
            conn = sqlite3.connect(self.db_path, timeout=30.0)
            conn.row_factory = sqlite3.Row  # Enable dict-like access
            yield conn
        except Exception as e:
            if conn:
                conn.rollback()
            logger.error(f"Database error: {e}")
            raise
        finally:
            if conn:
                conn.close()

    def log_inference(self, record: InferenceRecord) -> bool:
        """Log an inference record"""
        try:
            with self._get_connection() as conn:
                # Check if input_features_blob column exists
                cursor = conn.execute("PRAGMA table_info(inference_records)")
                columns = [row[1] for row in cursor.fetchall()]
                has_blob_column = 'input_features_blob' in columns

                # Serialize input features if provided and column exists.
                # Features are stored as raw float32 bytes to match the float32
                # dtype assumed when deserializing; the array shape is not preserved.
                input_features_blob = None
                if record.input_features is not None and has_blob_column:
                    input_features_blob = np.asarray(record.input_features, dtype=np.float32).tobytes()

                if has_blob_column:
                    # Use full query with blob column
                    conn.execute("""
                        INSERT INTO inference_records (
                            model_name, timestamp, symbol, action, confidence, probabilities,
                            input_features_hash, input_features_blob, processing_time_ms,
                            memory_usage_mb, checkpoint_id, metadata
                        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """, (
                        record.model_name,
                        record.timestamp.isoformat(),
                        record.symbol,
                        record.action,
                        record.confidence,
                        json.dumps(record.probabilities),
                        record.input_features_hash,
                        input_features_blob,
                        record.processing_time_ms,
                        record.memory_usage_mb,
                        record.checkpoint_id,
                        json.dumps(record.metadata) if record.metadata else None
                    ))
                else:
                    # Fallback query without blob column
                    logger.warning("input_features_blob column missing, storing without full features")
                    conn.execute("""
                        INSERT INTO inference_records (
                            model_name, timestamp, symbol, action, confidence, probabilities,
                            input_features_hash, processing_time_ms, memory_usage_mb,
                            checkpoint_id, metadata
                        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """, (
                        record.model_name,
                        record.timestamp.isoformat(),
                        record.symbol,
                        record.action,
                        record.confidence,
                        json.dumps(record.probabilities),
                        record.input_features_hash,
                        record.processing_time_ms,
                        record.memory_usage_mb,
                        record.checkpoint_id,
                        json.dumps(record.metadata) if record.metadata else None
                    ))

                conn.commit()
                return True
        except Exception as e:
            logger.error(f"Failed to log inference record: {e}")
            return False

    def save_checkpoint_metadata(self, metadata: CheckpointMetadata) -> bool:
        """Save checkpoint metadata"""
        try:
            with self._get_connection() as conn:
                # First, set all other checkpoints for this model as inactive
                conn.execute("""
                    UPDATE checkpoint_metadata
                    SET is_active = FALSE
                    WHERE model_name = ?
                """, (metadata.model_name,))
""", (metadata.model_name,)) # Insert or replace the new checkpoint metadata conn.execute(""" INSERT OR REPLACE INTO checkpoint_metadata ( checkpoint_id, model_name, model_type, timestamp, performance_metrics, training_metadata, file_path, file_size_mb, is_active ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( metadata.checkpoint_id, metadata.model_name, metadata.model_type, metadata.timestamp.isoformat(), json.dumps(metadata.performance_metrics), json.dumps(metadata.training_metadata), metadata.file_path, metadata.file_size_mb, metadata.is_active )) conn.commit() return True except Exception as e: logger.error(f"Failed to save checkpoint metadata: {e}") return False def get_checkpoint_metadata(self, model_name: str, checkpoint_id: str = None) -> Optional[CheckpointMetadata]: """Get checkpoint metadata without loading the actual model""" try: with self._get_connection() as conn: if checkpoint_id: # Get specific checkpoint cursor = conn.execute(""" SELECT * FROM checkpoint_metadata WHERE model_name = ? AND checkpoint_id = ? """, (model_name, checkpoint_id)) else: # Get active checkpoint cursor = conn.execute(""" SELECT * FROM checkpoint_metadata WHERE model_name = ? AND is_active = TRUE ORDER BY timestamp DESC LIMIT 1 """, (model_name,)) row = cursor.fetchone() if row: return CheckpointMetadata( checkpoint_id=row['checkpoint_id'], model_name=row['model_name'], model_type=row['model_type'], timestamp=datetime.fromisoformat(row['timestamp']), performance_metrics=json.loads(row['performance_metrics']), training_metadata=json.loads(row['training_metadata']), file_path=row['file_path'], file_size_mb=row['file_size_mb'], is_active=bool(row['is_active']) ) return None except Exception as e: logger.error(f"Failed to get checkpoint metadata: {e}") return None def get_best_checkpoint_metadata(self, model_name: str, metric_name: str = "accuracy") -> Optional[CheckpointMetadata]: """Get best checkpoint metadata based on performance metric""" try: with self._get_connection() as conn: cursor = conn.execute(""" SELECT * FROM checkpoint_metadata WHERE model_name = ? ORDER BY json_extract(performance_metrics, '$.' || ?) DESC LIMIT 1 """, (model_name, metric_name)) row = cursor.fetchone() if row: return CheckpointMetadata( checkpoint_id=row['checkpoint_id'], model_name=row['model_name'], model_type=row['model_type'], timestamp=datetime.fromisoformat(row['timestamp']), performance_metrics=json.loads(row['performance_metrics']), training_metadata=json.loads(row['training_metadata']), file_path=row['file_path'], file_size_mb=row['file_size_mb'], is_active=bool(row['is_active']) ) return None except Exception as e: logger.error(f"Failed to get best checkpoint metadata: {e}") return None def list_checkpoints(self, model_name: str) -> List[CheckpointMetadata]: """List all checkpoints for a model""" try: with self._get_connection() as conn: cursor = conn.execute(""" SELECT * FROM checkpoint_metadata WHERE model_name = ? 
                checkpoints = []
                for row in cursor.fetchall():
                    checkpoints.append(CheckpointMetadata(
                        checkpoint_id=row['checkpoint_id'],
                        model_name=row['model_name'],
                        model_type=row['model_type'],
                        timestamp=datetime.fromisoformat(row['timestamp']),
                        performance_metrics=json.loads(row['performance_metrics']),
                        training_metadata=json.loads(row['training_metadata']),
                        file_path=row['file_path'],
                        file_size_mb=row['file_size_mb'],
                        is_active=bool(row['is_active'])
                    ))
                return checkpoints
        except Exception as e:
            logger.error(f"Failed to list checkpoints: {e}")
            return []

    def set_active_checkpoint(self, model_name: str, checkpoint_id: str) -> bool:
        """Set a checkpoint as active for a model"""
        try:
            with self._get_connection() as conn:
                # First, set all checkpoints for this model as inactive
                conn.execute("""
                    UPDATE checkpoint_metadata
                    SET is_active = FALSE
                    WHERE model_name = ?
                """, (model_name,))

                # Set the specified checkpoint as active
                cursor = conn.execute("""
                    UPDATE checkpoint_metadata
                    SET is_active = TRUE
                    WHERE model_name = ? AND checkpoint_id = ?
                """, (model_name, checkpoint_id))

                conn.commit()
                return cursor.rowcount > 0
        except Exception as e:
            logger.error(f"Failed to set active checkpoint: {e}")
            return False

    def get_recent_inferences(self, model_name: str, limit: int = 100) -> List[InferenceRecord]:
        """Get recent inference records for a model"""
        try:
            with self._get_connection() as conn:
                cursor = conn.execute("""
                    SELECT * FROM inference_records
                    WHERE model_name = ?
                    ORDER BY timestamp DESC
                    LIMIT ?
                """, (model_name, limit))

                records = []
                for row in cursor.fetchall():
                    # Deserialize input features if available
                    input_features = None
                    # Check if the column exists in the row (handles missing column gracefully)
                    if 'input_features_blob' in row.keys() and row['input_features_blob']:
                        try:
                            # Reconstruct numpy array from bytes
                            input_features = np.frombuffer(row['input_features_blob'], dtype=np.float32)
                        except Exception as e:
                            logger.warning(f"Failed to deserialize input features: {e}")

                    records.append(InferenceRecord(
                        model_name=row['model_name'],
                        timestamp=datetime.fromisoformat(row['timestamp']),
                        symbol=row['symbol'],
                        action=row['action'],
                        confidence=row['confidence'],
                        probabilities=json.loads(row['probabilities']),
                        input_features_hash=row['input_features_hash'],
                        processing_time_ms=row['processing_time_ms'],
                        memory_usage_mb=row['memory_usage_mb'],
                        input_features=input_features,
                        checkpoint_id=row['checkpoint_id'],
                        metadata=json.loads(row['metadata']) if row['metadata'] else None
                    ))
                return records
        except Exception as e:
            logger.error(f"Failed to get recent inferences: {e}")
            return []

    def update_model_performance(self, model_name: str, date: str,
                                 total_predictions: int, correct_predictions: int,
                                 avg_confidence: float, avg_processing_time: float) -> bool:
        """Update daily model performance statistics"""
        try:
            accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0

            with self._get_connection() as conn:
                conn.execute("""
                    INSERT OR REPLACE INTO model_performance (
                        model_name, date, total_predictions, correct_predictions,
                        accuracy, avg_confidence, avg_processing_time_ms
                    ) VALUES (?, ?, ?, ?, ?, ?, ?)
                """, (
                    model_name, date, total_predictions, correct_predictions,
                    accuracy, avg_confidence, avg_processing_time
                ))
""", ( model_name, date, total_predictions, correct_predictions, accuracy, avg_confidence, avg_processing_time )) conn.commit() return True except Exception as e: logger.error(f"Failed to update model performance: {e}") return False def get_inference_records_for_training(self, model_name: str, symbol: str = None, hours_back: int = 24, limit: int = 1000) -> List[InferenceRecord]: """ Get inference records with input features for training feedback Args: model_name: Name of the model symbol: Optional symbol filter hours_back: How many hours back to look limit: Maximum number of records Returns: List of InferenceRecord with input_features populated """ try: cutoff_time = datetime.now() - timedelta(hours=hours_back) with self._get_connection() as conn: # Check if input_features_blob column exists before querying cursor = conn.execute("PRAGMA table_info(inference_records)") columns = [row[1] for row in cursor.fetchall()] has_blob_column = 'input_features_blob' in columns if not has_blob_column: logger.warning("input_features_blob column not found, returning empty list") return [] if symbol: cursor = conn.execute(""" SELECT * FROM inference_records WHERE model_name = ? AND symbol = ? AND timestamp >= ? AND input_features_blob IS NOT NULL ORDER BY timestamp DESC LIMIT ? """, (model_name, symbol, cutoff_time.isoformat(), limit)) else: cursor = conn.execute(""" SELECT * FROM inference_records WHERE model_name = ? AND timestamp >= ? AND input_features_blob IS NOT NULL ORDER BY timestamp DESC LIMIT ? """, (model_name, cutoff_time.isoformat(), limit)) records = [] for row in cursor.fetchall(): # Deserialize input features input_features = None if row['input_features_blob']: try: input_features = np.frombuffer(row['input_features_blob'], dtype=np.float32) except Exception as e: logger.warning(f"Failed to deserialize input features: {e}") continue # Skip records with corrupted features records.append(InferenceRecord( model_name=row['model_name'], timestamp=datetime.fromisoformat(row['timestamp']), symbol=row['symbol'], action=row['action'], confidence=row['confidence'], probabilities=json.loads(row['probabilities']), input_features_hash=row['input_features_hash'], processing_time_ms=row['processing_time_ms'], memory_usage_mb=row['memory_usage_mb'], input_features=input_features, checkpoint_id=row['checkpoint_id'], metadata=json.loads(row['metadata']) if row['metadata'] else None )) return records except Exception as e: logger.error(f"Failed to get inference records for training: {e}") return [] def cleanup_old_records(self, days_to_keep: int = 30) -> bool: """Clean up old inference records""" try: cutoff_date = datetime.now() - timedelta(days=days_to_keep) with self._get_connection() as conn: cursor = conn.execute(""" DELETE FROM inference_records WHERE timestamp < ? 
""", (cutoff_date.isoformat(),)) deleted_count = cursor.rowcount conn.commit() if deleted_count > 0: logger.info(f"Cleaned up {deleted_count} old inference records") return True except Exception as e: logger.error(f"Failed to cleanup old records: {e}") return False # Global database manager instance _db_manager_instance = None def get_database_manager(db_path: str = "data/trading_system.db") -> DatabaseManager: """Get the global database manager instance""" global _db_manager_instance if _db_manager_instance is None: _db_manager_instance = DatabaseManager(db_path) return _db_manager_instance def reset_database_manager(): """Reset the database manager instance to force re-initialization""" global _db_manager_instance _db_manager_instance = None logger.info("Database manager instance reset - will re-initialize on next access")