""" Dashboard CNN Integration This module integrates the EnhancedCNNAdapter with the dashboard system, providing real-time training, predictions, and performance metrics display. """ import logging import time import threading from datetime import datetime, timedelta from typing import Dict, List, Optional, Any, Tuple from collections import deque import numpy as np from .enhanced_cnn_adapter import EnhancedCNNAdapter from .standardized_data_provider import StandardizedDataProvider from .data_models import BaseDataInput, ModelOutput, create_model_output logger = logging.getLogger(__name__) class DashboardCNNIntegration: """ CNN integration for the dashboard system This class: 1. Manages CNN model lifecycle in the dashboard 2. Provides real-time training and inference 3. Tracks performance metrics for dashboard display 4. Handles model predictions for chart overlay """ def __init__(self, data_provider: StandardizedDataProvider, symbols: List[str] = None): """ Initialize the dashboard CNN integration Args: data_provider: Standardized data provider symbols: List of symbols to process """ self.data_provider = data_provider self.symbols = symbols or ['ETH/USDT', 'BTC/USDT'] # Initialize CNN adapter self.cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn") # Load best checkpoint if available self.cnn_adapter.load_best_checkpoint() # Performance tracking self.performance_metrics = { 'total_predictions': 0, 'total_training_samples': 0, 'last_training_time': None, 'last_inference_time': None, 'training_loss_history': deque(maxlen=100), 'accuracy_history': deque(maxlen=100), 'inference_times': deque(maxlen=100), 'training_times': deque(maxlen=100), 'predictions_per_second': 0.0, 'training_per_second': 0.0, 'model_status': 'FRESH', 'confidence_history': deque(maxlen=100), 'action_distribution': {'BUY': 0, 'SELL': 0, 'HOLD': 0} } # Prediction cache for dashboard display self.prediction_cache = {} self.prediction_history = {symbol: deque(maxlen=1000) for symbol in self.symbols} # Training control self.training_enabled = True self.inference_enabled = True self.training_lock = threading.Lock() # Real-time processing self.is_running = False self.processing_thread = None logger.info(f"DashboardCNNIntegration initialized for symbols: {self.symbols}") def start_real_time_processing(self): """Start real-time CNN processing""" if self.is_running: logger.warning("Real-time processing already running") return self.is_running = True self.processing_thread = threading.Thread(target=self._real_time_processing_loop, daemon=True) self.processing_thread.start() logger.info("Started real-time CNN processing") def stop_real_time_processing(self): """Stop real-time CNN processing""" self.is_running = False if self.processing_thread: self.processing_thread.join(timeout=5) logger.info("Stopped real-time CNN processing") def _real_time_processing_loop(self): """Main real-time processing loop""" last_prediction_time = {} prediction_interval = 1.0 # Make prediction every 1 second while self.is_running: try: current_time = time.time() for symbol in self.symbols: # Check if it's time to make a prediction for this symbol if (symbol not in last_prediction_time or current_time - last_prediction_time[symbol] >= prediction_interval): # Make prediction if inference is enabled if self.inference_enabled: self._make_prediction(symbol) last_prediction_time[symbol] = current_time # Update performance metrics self._update_performance_metrics() # Sleep briefly to prevent overwhelming the system time.sleep(0.1) except Exception as e: logger.error(f"Error in real-time processing loop: {e}") time.sleep(1) def _make_prediction(self, symbol: str): """Make a prediction for a symbol""" try: start_time = time.time() # Get standardized input data base_data = self.data_provider.get_base_data_input(symbol) if base_data is None: logger.debug(f"No base data available for {symbol}") return # Make prediction model_output = self.cnn_adapter.predict(base_data) # Record inference time inference_time = time.time() - start_time self.performance_metrics['inference_times'].append(inference_time) # Update performance metrics self.performance_metrics['total_predictions'] += 1 self.performance_metrics['last_inference_time'] = datetime.now() self.performance_metrics['confidence_history'].append(model_output.confidence) # Update action distribution action = model_output.predictions['action'] self.performance_metrics['action_distribution'][action] += 1 # Cache prediction for dashboard self.prediction_cache[symbol] = model_output self.prediction_history[symbol].append(model_output) # Store model output in data provider self.data_provider.store_model_output(model_output) logger.debug(f"CNN prediction for {symbol}: {action} ({model_output.confidence:.3f})") except Exception as e: logger.error(f"Error making prediction for {symbol}: {e}") def add_training_sample(self, symbol: str, actual_action: str, reward: float): """Add a training sample and trigger training if enabled""" try: if not self.training_enabled: return # Get base data for the symbol base_data = self.data_provider.get_base_data_input(symbol) if base_data is None: logger.debug(f"No base data available for training sample: {symbol}") return # Add training sample self.cnn_adapter.add_training_sample(base_data, actual_action, reward) # Update metrics self.performance_metrics['total_training_samples'] += 1 # Train model periodically (every 10 samples) if self.performance_metrics['total_training_samples'] % 10 == 0: self._train_model() except Exception as e: logger.error(f"Error adding training sample: {e}") def _train_model(self): """Train the CNN model""" try: with self.training_lock: start_time = time.time() # Train model metrics = self.cnn_adapter.train(epochs=1) # Record training time training_time = time.time() - start_time self.performance_metrics['training_times'].append(training_time) # Update performance metrics self.performance_metrics['last_training_time'] = datetime.now() if 'loss' in metrics: self.performance_metrics['training_loss_history'].append(metrics['loss']) if 'accuracy' in metrics: self.performance_metrics['accuracy_history'].append(metrics['accuracy']) # Update model status if metrics.get('accuracy', 0) > 0.5: self.performance_metrics['model_status'] = 'TRAINED' else: self.performance_metrics['model_status'] = 'TRAINING' logger.info(f"CNN training completed: loss={metrics.get('loss', 0):.4f}, accuracy={metrics.get('accuracy', 0):.4f}") except Exception as e: logger.error(f"Error training CNN model: {e}") def _update_performance_metrics(self): """Update performance metrics for dashboard display""" try: current_time = time.time() # Calculate predictions per second (last 60 seconds) recent_inferences = [t for t in self.performance_metrics['inference_times'] if current_time - t <= 60] self.performance_metrics['predictions_per_second'] = len(recent_inferences) / 60.0 # Calculate training per second (last 60 seconds) recent_trainings = [t for t in self.performance_metrics['training_times'] if current_time - t <= 60] self.performance_metrics['training_per_second'] = len(recent_trainings) / 60.0 except Exception as e: logger.error(f"Error updating performance metrics: {e}") def get_dashboard_metrics(self) -> Dict[str, Any]: """Get metrics for dashboard display""" try: # Calculate current loss current_loss = (self.performance_metrics['training_loss_history'][-1] if self.performance_metrics['training_loss_history'] else 0.0) # Calculate current accuracy current_accuracy = (self.performance_metrics['accuracy_history'][-1] if self.performance_metrics['accuracy_history'] else 0.0) # Calculate average confidence avg_confidence = (np.mean(list(self.performance_metrics['confidence_history'])) if self.performance_metrics['confidence_history'] else 0.0) # Get latest prediction latest_prediction = None latest_symbol = None for symbol, prediction in self.prediction_cache.items(): if latest_prediction is None or prediction.timestamp > latest_prediction.timestamp: latest_prediction = prediction latest_symbol = symbol # Format timing information last_inference_str = "None" last_training_str = "None" if self.performance_metrics['last_inference_time']: last_inference_str = self.performance_metrics['last_inference_time'].strftime("%H:%M:%S") if self.performance_metrics['last_training_time']: last_training_str = self.performance_metrics['last_training_time'].strftime("%H:%M:%S") return { 'model_name': 'CNN', 'model_type': 'cnn', 'parameters': '50.0M', 'status': self.performance_metrics['model_status'], 'current_loss': current_loss, 'accuracy': current_accuracy, 'confidence': avg_confidence, 'total_predictions': self.performance_metrics['total_predictions'], 'total_training_samples': self.performance_metrics['total_training_samples'], 'predictions_per_second': self.performance_metrics['predictions_per_second'], 'training_per_second': self.performance_metrics['training_per_second'], 'last_inference': last_inference_str, 'last_training': last_training_str, 'latest_prediction': { 'action': latest_prediction.predictions['action'] if latest_prediction else 'HOLD', 'confidence': latest_prediction.confidence if latest_prediction else 0.0, 'symbol': latest_symbol or 'ETH/USDT', 'timestamp': latest_prediction.timestamp.strftime("%H:%M:%S") if latest_prediction else "None" }, 'action_distribution': self.performance_metrics['action_distribution'].copy(), 'training_enabled': self.training_enabled, 'inference_enabled': self.inference_enabled } except Exception as e: logger.error(f"Error getting dashboard metrics: {e}") return { 'model_name': 'CNN', 'model_type': 'cnn', 'parameters': '50.0M', 'status': 'ERROR', 'current_loss': 0.0, 'accuracy': 0.0, 'confidence': 0.0, 'error': str(e) } def get_predictions_for_chart(self, symbol: str, timeframe: str = '1s', limit: int = 100) -> List[Dict[str, Any]]: """Get predictions for chart overlay""" try: if symbol not in self.prediction_history: return [] predictions = list(self.prediction_history[symbol])[-limit:] chart_data = [] for prediction in predictions: chart_data.append({ 'timestamp': prediction.timestamp, 'action': prediction.predictions['action'], 'confidence': prediction.confidence, 'buy_probability': prediction.predictions.get('buy_probability', 0.0), 'sell_probability': prediction.predictions.get('sell_probability', 0.0), 'hold_probability': prediction.predictions.get('hold_probability', 0.0) }) return chart_data except Exception as e: logger.error(f"Error getting predictions for chart: {e}") return [] def set_training_enabled(self, enabled: bool): """Enable or disable training""" self.training_enabled = enabled logger.info(f"CNN training {'enabled' if enabled else 'disabled'}") def set_inference_enabled(self, enabled: bool): """Enable or disable inference""" self.inference_enabled = enabled logger.info(f"CNN inference {'enabled' if enabled else 'disabled'}") def get_model_info(self) -> Dict[str, Any]: """Get model information for dashboard""" return { 'name': 'Enhanced CNN', 'version': '1.0', 'parameters': '50.0M', 'input_shape': self.cnn_adapter.model.input_shape if self.cnn_adapter.model else 'Unknown', 'device': str(self.cnn_adapter.device), 'checkpoint_dir': self.cnn_adapter.checkpoint_dir, 'training_samples': len(self.cnn_adapter.training_data), 'max_training_samples': self.cnn_adapter.max_training_samples }