""" Inference Logger Centralized logging system for model inferences with database storage Eliminates scattered logging throughout the codebase """ import time import hashlib import logging import psutil from datetime import datetime, timedelta from typing import Dict, List, Optional, Any, Union from dataclasses import dataclass import numpy as np from .database_manager import get_database_manager, InferenceRecord from .text_logger import get_text_logger logger = logging.getLogger(__name__) class InferenceLogger: """Centralized inference logging system""" def __init__(self): self.db_manager = get_database_manager() self.text_logger = get_text_logger() self._process = psutil.Process() def log_inference(self, model_name: str, symbol: str, action: str, confidence: float, probabilities: Dict[str, float], input_features: Union[np.ndarray, Dict, List], processing_time_ms: float, checkpoint_id: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None) -> bool: """ Log a model inference with all relevant details Args: model_name: Name of the model making the prediction symbol: Trading symbol action: Predicted action (BUY/SELL/HOLD) confidence: Confidence score (0.0 to 1.0) probabilities: Action probabilities dict input_features: Input features used for prediction processing_time_ms: Time taken for inference in milliseconds checkpoint_id: ID of the checkpoint used metadata: Additional metadata Returns: bool: True if logged successfully """ try: # Create feature hash for deduplication feature_hash = self._hash_features(input_features) # Get current memory usage memory_usage_mb = self._get_memory_usage() # Convert input features to numpy array if needed features_array = None if isinstance(input_features, np.ndarray): features_array = input_features.astype(np.float32) elif isinstance(input_features, (list, tuple)): features_array = np.array(input_features, dtype=np.float32) # Create inference record record = InferenceRecord( model_name=model_name, timestamp=datetime.now(), symbol=symbol, action=action, confidence=confidence, probabilities=probabilities, input_features_hash=feature_hash, processing_time_ms=processing_time_ms, memory_usage_mb=memory_usage_mb, input_features=features_array, checkpoint_id=checkpoint_id, metadata=metadata ) # Log to database db_success = self.db_manager.log_inference(record) # Log to text file text_success = self.text_logger.log_inference( model_name=model_name, symbol=symbol, action=action, confidence=confidence, processing_time_ms=processing_time_ms, checkpoint_id=checkpoint_id ) if db_success: # Reduced logging - no more scattered logs at runtime pass # Database logging successful, text file provides human-readable record else: logger.error(f"Failed to log inference for {model_name}") return db_success and text_success except Exception as e: logger.error(f"Error logging inference: {e}") return False def _hash_features(self, features: Union[np.ndarray, Dict, List]) -> str: """Create a hash of input features for deduplication""" try: if isinstance(features, np.ndarray): # Hash numpy array return hashlib.md5(features.tobytes()).hexdigest()[:16] elif isinstance(features, (dict, list)): # Hash dict or list by converting to string feature_str = str(sorted(features.items()) if isinstance(features, dict) else features) return hashlib.md5(feature_str.encode()).hexdigest()[:16] else: # Hash string representation return hashlib.md5(str(features).encode()).hexdigest()[:16] except Exception: # Fallback to timestamp-based hash return hashlib.md5(str(time.time()).encode()).hexdigest()[:16] def _get_memory_usage(self) -> float: """Get current memory usage in MB""" try: return self._process.memory_info().rss / (1024 * 1024) except Exception: return 0.0 def get_model_stats(self, model_name: str, hours: int = 24) -> Dict[str, Any]: """Get inference statistics for a model""" try: # Get recent inferences recent_inferences = self.db_manager.get_recent_inferences(model_name, limit=1000) if not recent_inferences: return { 'total_inferences': 0, 'avg_confidence': 0.0, 'avg_processing_time_ms': 0.0, 'action_distribution': {}, 'symbol_distribution': {} } # Filter by time window cutoff_time = datetime.now() - timedelta(hours=hours) recent_inferences = [r for r in recent_inferences if r.timestamp >= cutoff_time] if not recent_inferences: return { 'total_inferences': 0, 'avg_confidence': 0.0, 'avg_processing_time_ms': 0.0, 'action_distribution': {}, 'symbol_distribution': {} } # Calculate statistics total_inferences = len(recent_inferences) avg_confidence = sum(r.confidence for r in recent_inferences) / total_inferences avg_processing_time = sum(r.processing_time_ms for r in recent_inferences) / total_inferences # Action distribution action_counts = {} for record in recent_inferences: action_counts[record.action] = action_counts.get(record.action, 0) + 1 # Symbol distribution symbol_counts = {} for record in recent_inferences: symbol_counts[record.symbol] = symbol_counts.get(record.symbol, 0) + 1 return { 'total_inferences': total_inferences, 'avg_confidence': avg_confidence, 'avg_processing_time_ms': avg_processing_time, 'action_distribution': action_counts, 'symbol_distribution': symbol_counts, 'latest_inference': recent_inferences[0].timestamp.isoformat() if recent_inferences else None } except Exception as e: logger.error(f"Error getting model stats: {e}") return {} def cleanup_old_logs(self, days_to_keep: int = 30) -> bool: """Clean up old inference logs""" return self.db_manager.cleanup_old_records(days_to_keep) # Global inference logger instance _inference_logger_instance = None def get_inference_logger() -> InferenceLogger: """Get the global inference logger instance""" global _inference_logger_instance if _inference_logger_instance is None: _inference_logger_instance = InferenceLogger() return _inference_logger_instance def log_model_inference(model_name: str, symbol: str, action: str, confidence: float, probabilities: Dict[str, float], input_features: Union[np.ndarray, Dict, List], processing_time_ms: float, checkpoint_id: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None) -> bool: """ Convenience function to log model inference This is the main function that should be called throughout the codebase instead of scattered logger.info() calls """ inference_logger = get_inference_logger() return inference_logger.log_inference( model_name=model_name, symbol=symbol, action=action, confidence=confidence, probabilities=probabilities, input_features=input_features, processing_time_ms=processing_time_ms, checkpoint_id=checkpoint_id, metadata=metadata )