"""
CNN Dashboard Integration

This module integrates the EnhancedCNN model with the dashboard, providing
real-time training and visualization of model predictions.
"""

import logging
import threading
import time
from datetime import datetime
from typing import Dict, List, Optional, Any, Tuple
import os
import json

from .enhanced_cnn_adapter import EnhancedCNNAdapter
from .data_models import BaseDataInput, ModelOutput, create_model_output
from utils.training_integration import get_training_integration

logger = logging.getLogger(__name__)


class CNNDashboardIntegration:
    """
    Integrates the EnhancedCNN model with the dashboard.

    This class:
    1. Loads and initializes the CNN model (via EnhancedCNNAdapter)
    2. Processes real-time data for model inference
    3. Manages continuous training of the model in a background thread
    4. Provides visualization data for the dashboard
    """

    def __init__(self, data_provider=None, checkpoint_dir: str = "models/enhanced_cnn"):
        """
        Initialize the CNN dashboard integration.

        Args:
            data_provider: Data provider instance; if set, each prediction is
                forwarded to its ``store_model_output`` method.
            checkpoint_dir: Directory to save checkpoints to (created if missing).
        """
        self.data_provider = data_provider
        self.checkpoint_dir = checkpoint_dir
        self.cnn_adapter = None
        self.training_thread = None
        self.training_active = False
        # Stop signal owned by the currently running training thread (one per thread).
        self._stop_event: Optional[threading.Event] = None
        self.training_interval = 60  # Train every 60 seconds
        # Local (symbol, action, reward) tuples, used for dashboard summaries.
        self.training_samples = []
        self.max_training_samples = 1000
        self.last_training_time = 0
        # symbol -> most recent ModelOutput
        self.last_predictions = {}
        self.performance_metrics = {}
        self.model_name = "enhanced_cnn_v1"

        # Create checkpoint directory if it doesn't exist
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Initialize CNN adapter
        self._initialize_cnn_adapter()

        logger.info(f"CNNDashboardIntegration initialized with checkpoint_dir: {checkpoint_dir}")

    def _initialize_cnn_adapter(self):
        """Create the CNN adapter and load its best checkpoint.

        On any failure the error is logged and ``self.cnn_adapter`` is left as
        ``None``; the other methods check for that and degrade gracefully.
        """
        try:
            # EnhancedCNNAdapter is already imported at module level, so a
            # function-local re-import could not break any import cycle.
            self.cnn_adapter = EnhancedCNNAdapter(checkpoint_dir=self.checkpoint_dir)

            # Load best checkpoint if available
            self.cnn_adapter.load_best_checkpoint()

            logger.info("CNN adapter initialized successfully")
        except Exception as e:
            logger.error(f"Error initializing CNN adapter: {e}")
            self.cnn_adapter = None

    def start_training_thread(self):
        """Start the background training thread (no-op if one is already alive)."""
        if self.training_thread is not None and self.training_thread.is_alive():
            logger.info("Training thread already running")
            return

        self.training_active = True
        # A fresh event per thread: a previous thread that is still finishing a
        # long train() call keeps its own (already set) event and exits, rather
        # than being revived when training_active flips back to True.
        self._stop_event = threading.Event()
        self.training_thread = threading.Thread(
            target=self._training_loop, args=(self._stop_event,), daemon=True
        )
        self.training_thread.start()

        logger.info("CNN training thread started")

    def stop_training_thread(self):
        """Signal the training thread to stop and wait up to 5 seconds for it."""
        self.training_active = False
        if self._stop_event is not None:
            # Wakes the loop immediately instead of waiting out its sleep.
            self._stop_event.set()
            self._stop_event = None
        if self.training_thread is not None:
            self.training_thread.join(timeout=5)
            if self.training_thread.is_alive():
                # Still inside a train() call; it will exit once that returns.
                logger.warning("CNN training thread did not stop within 5s; it will exit after the current step")
            self.training_thread = None

        logger.info("CNN training thread stopped")

    def _training_loop(self, stop_event: Optional[threading.Event] = None):
        """Continuously train the model every ``training_interval`` seconds.

        Training runs only once at least 10 local samples have been collected.

        Args:
            stop_event: Event that terminates this loop when set. When omitted
                (direct call), only ``training_active`` controls the loop.
        """
        if stop_event is None:
            stop_event = threading.Event()

        while self.training_active and not stop_event.is_set():
            try:
                # Check if it's time to train
                current_time = time.time()
                if current_time - self.last_training_time >= self.training_interval and len(self.training_samples) >= 10:
                    if self.cnn_adapter is not None:
                        logger.info(f"Training CNN model with {len(self.training_samples)} samples")

                        metrics = self.cnn_adapter.train(epochs=1)

                        # Update performance metrics
                        self.performance_metrics = {
                            'loss': metrics.get('loss', 0.0),
                            'accuracy': metrics.get('accuracy', 0.0),
                            'samples': metrics.get('samples', 0),
                            'last_training': datetime.now().isoformat()
                        }

                        # Log training metrics
                        logger.info(f"CNN training metrics: loss={metrics.get('loss', 0.0):.4f}, accuracy={metrics.get('accuracy', 0.0):.4f}")

                    # Update last training time (also when no adapter, to avoid busy retries)
                    self.last_training_time = current_time

                # Interruptible sleep to avoid high CPU usage
                stop_event.wait(1)
            except Exception as e:
                logger.error(f"Error in CNN training loop: {e}")
                stop_event.wait(5)  # Sleep longer on error

    def process_data(self, symbol: str, base_data: BaseDataInput) -> Optional[ModelOutput]:
        """
        Process data for model inference.

        Args:
            symbol: Trading symbol
            base_data: Standardized input data

        Returns:
            Optional[ModelOutput]: Model output, or None if processing failed
        """
        try:
            if self.cnn_adapter is None:
                logger.warning("CNN adapter not initialized")
                return None

            # Make prediction
            model_output = self.cnn_adapter.predict(base_data)
            if model_output is None:
                # Don't cache/forward a missing prediction: the metrics and
                # visualization helpers dereference stored predictions.
                return None

            # Store prediction
            self.last_predictions[symbol] = model_output

            # Store model output in data provider
            if self.data_provider is not None:
                self.data_provider.store_model_output(model_output)

            return model_output

        except Exception as e:
            logger.error(f"Error processing data for CNN model: {e}")
            return None

    def add_training_sample(self, base_data: BaseDataInput, actual_action: str, reward: float):
        """
        Add a training sample to the adapter and to the local summary buffer.

        Args:
            base_data: Standardized input data
            actual_action: Actual action taken ('BUY', 'SELL', 'HOLD')
            reward: Reward received for the action
        """
        try:
            if self.cnn_adapter is None:
                logger.warning("CNN adapter not initialized")
                return

            # Add training sample to CNN adapter
            self.cnn_adapter.add_training_sample(base_data, actual_action, reward)

            # Add to local training samples
            self.training_samples.append((base_data.symbol, actual_action, reward))

            # Keep only the most recent max_training_samples entries
            if len(self.training_samples) > self.max_training_samples:
                self.training_samples = self.training_samples[-self.max_training_samples:]

            logger.debug(f"Added training sample for {base_data.symbol}, action: {actual_action}, reward: {reward:.4f}")

        except Exception as e:
            logger.error(f"Error adding training sample: {e}")

    def get_performance_metrics(self) -> Dict[str, Any]:
        """
        Get performance metrics.

        Returns:
            Dict[str, Any]: Last training metrics plus sample count, model name,
            and per-symbol ``<symbol>_last_action`` / ``<symbol>_last_confidence``.
        """
        metrics = self.performance_metrics.copy()

        # Add additional metrics
        metrics['training_samples'] = len(self.training_samples)
        metrics['model_name'] = self.model_name

        # Snapshot the items: process_data may add symbols from another thread,
        # and iterating a dict that changes size raises RuntimeError.
        for symbol, prediction in list(self.last_predictions.items()):
            metrics[f'{symbol}_last_action'] = prediction.predictions.get('action', 'UNKNOWN')
            metrics[f'{symbol}_last_confidence'] = prediction.confidence

        return metrics

    def get_visualization_data(self, symbol: str) -> Dict[str, Any]:
        """
        Get visualization data for the dashboard.

        Args:
            symbol: Trading symbol

        Returns:
            Dict[str, Any]: Model name, symbol, timestamp, performance metrics,
            the last prediction for ``symbol`` (if any) and a summary of the
            local training samples for ``symbol``.
        """
        data = {
            'model_name': self.model_name,
            'symbol': symbol,
            'timestamp': datetime.now().isoformat(),
            'performance_metrics': self.get_performance_metrics()
        }

        # Add last prediction
        prediction = self.last_predictions.get(symbol)
        if prediction is not None:
            data['last_prediction'] = {
                'action': prediction.predictions.get('action', 'UNKNOWN'),
                'confidence': prediction.confidence,
                'timestamp': prediction.timestamp.isoformat(),
                'buy_probability': prediction.predictions.get('buy_probability', 0.0),
                'sell_probability': prediction.predictions.get('sell_probability', 0.0),
                'hold_probability': prediction.predictions.get('hold_probability', 0.0)
            }

        # Add training samples summary (tuples are (symbol, action, reward))
        symbol_samples = [s for s in self.training_samples if s[0] == symbol]
        data['training_samples'] = {
            'total': len(symbol_samples),
            'buy': len([s for s in symbol_samples if s[1] == 'BUY']),
            'sell': len([s for s in symbol_samples if s[1] == 'SELL']),
            'hold': len([s for s in symbol_samples if s[1] == 'HOLD']),
            'avg_reward': sum(s[2] for s in symbol_samples) / len(symbol_samples) if symbol_samples else 0.0
        }

        return data


# Global CNN dashboard integration instance
_cnn_dashboard_integration = None


def get_cnn_dashboard_integration(data_provider=None) -> CNNDashboardIntegration:
    """
    Get the global CNN dashboard integration instance.

    The instance is created on first call. If it was created without a data
    provider and a later call supplies one, that provider is attached rather
    than silently ignored. An already-set provider is never replaced.

    Args:
        data_provider: Data provider instance

    Returns:
        CNNDashboardIntegration: Global CNN dashboard integration instance
    """
    global _cnn_dashboard_integration
    if _cnn_dashboard_integration is None:
        _cnn_dashboard_integration = CNNDashboardIntegration(data_provider=data_provider)
    elif data_provider is not None and _cnn_dashboard_integration.data_provider is None:
        _cnn_dashboard_integration.data_provider = data_provider
    return _cnn_dashboard_integration