""" Enhanced CNN Adapter for Standardized Input Format This module provides an adapter for the EnhancedCNN model to work with the standardized BaseDataInput format, enabling seamless integration with the multi-modal trading system. """ import torch import numpy as np import logging import os from datetime import datetime from typing import Dict, List, Optional, Tuple, Any, Union from threading import Lock from .data_models import BaseDataInput, ModelOutput, create_model_output from NN.models.enhanced_cnn import EnhancedCNN logger = logging.getLogger(__name__) class EnhancedCNNAdapter: """ Adapter for EnhancedCNN model to work with standardized BaseDataInput format This adapter: 1. Converts BaseDataInput to the format expected by EnhancedCNN 2. Processes model outputs to create standardized ModelOutput 3. Manages model training with collected data 4. Handles checkpoint management """ def __init__(self, model_path: str = None, checkpoint_dir: str = "models/enhanced_cnn"): """ Initialize the EnhancedCNN adapter Args: model_path: Path to load model from, if None a new model is created checkpoint_dir: Directory to save checkpoints to """ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.model = None self.model_path = model_path self.checkpoint_dir = checkpoint_dir self.training_lock = Lock() self.training_data = [] self.max_training_samples = 10000 self.batch_size = 32 self.learning_rate = 0.0001 self.model_name = "enhanced_cnn" # Enhanced metrics tracking self.last_inference_time = None self.last_inference_duration = 0.0 self.last_prediction_output = None self.last_training_time = None self.last_training_duration = 0.0 self.last_training_loss = 0.0 self.inference_count = 0 self.training_count = 0 # Create checkpoint directory if it doesn't exist os.makedirs(checkpoint_dir, exist_ok=True) # Initialize the model self._initialize_model() # Load checkpoint if available if model_path and os.path.exists(model_path): self._load_checkpoint(model_path) else: self._load_best_checkpoint() logger.info(f"EnhancedCNNAdapter initialized on {self.device}") def _initialize_model(self): """Initialize the EnhancedCNN model""" try: # Calculate input shape based on BaseDataInput structure # OHLCV: 300 frames x 4 timeframes x 5 features = 6000 features # BTC OHLCV: 300 frames x 5 features = 1500 features # COB: ±20 buckets x 4 metrics = 160 features # MA: 4 timeframes x 10 buckets = 40 features # Technical indicators: 100 features # Last predictions: 50 features # Total: 7850 features input_shape = 7850 n_actions = 3 # BUY, SELL, HOLD # Create model self.model = EnhancedCNN(input_shape=input_shape, n_actions=n_actions) self.model.to(self.device) logger.info(f"EnhancedCNN model initialized with input_shape={input_shape}, n_actions={n_actions}") except Exception as e: logger.error(f"Error initializing EnhancedCNN model: {e}") raise def _load_checkpoint(self, checkpoint_path: str) -> bool: """Load model from checkpoint path""" try: if self.model and os.path.exists(checkpoint_path): success = self.model.load(checkpoint_path) if success: logger.info(f"Loaded model from {checkpoint_path}") return True else: logger.warning(f"Failed to load model from {checkpoint_path}") return False else: logger.warning(f"Checkpoint path does not exist: {checkpoint_path}") return False except Exception as e: logger.error(f"Error loading checkpoint: {e}") return False def _load_best_checkpoint(self) -> bool: """Load the best available checkpoint""" try: return self.load_best_checkpoint() except Exception as e: 
logger.error(f"Error loading best checkpoint: {e}") return False def load_best_checkpoint(self) -> bool: """Load the best checkpoint based on accuracy""" try: # Import checkpoint manager from utils.checkpoint_manager import CheckpointManager # Create checkpoint manager checkpoint_manager = CheckpointManager( checkpoint_dir=self.checkpoint_dir, max_checkpoints=10, metric_name="accuracy" ) # Load best checkpoint best_checkpoint_path, best_checkpoint_metadata = checkpoint_manager.load_best_checkpoint(self.model_name) if not best_checkpoint_path: logger.info(f"No checkpoints found for {self.model_name} - starting in COLD START mode") return False # Load model success = self.model.load(best_checkpoint_path) if success: logger.info(f"Loaded best checkpoint from {best_checkpoint_path}") # Log metrics metrics = best_checkpoint_metadata.get('metrics', {}) logger.info(f"Checkpoint metrics: accuracy={metrics.get('accuracy', 0.0):.4f}, loss={metrics.get('loss', 0.0):.4f}") return True else: logger.warning(f"Failed to load best checkpoint from {best_checkpoint_path}") return False except Exception as e: logger.error(f"Error loading best checkpoint: {e}") return False def _create_default_output(self, symbol: str) -> ModelOutput: """Create default output when prediction fails""" return create_model_output( model_type='cnn', model_name=self.model_name, symbol=symbol, action='HOLD', confidence=0.0, metadata={'error': 'Prediction failed, using default output'} ) def _process_hidden_states(self, hidden_states: Dict[str, Any]) -> Dict[str, Any]: """Process hidden states for cross-model feeding""" processed_states = {} for key, value in hidden_states.items(): if isinstance(value, torch.Tensor): # Convert tensor to numpy array processed_states[key] = value.cpu().numpy().tolist() else: processed_states[key] = value return processed_states def _convert_base_data_to_features(self, base_data: BaseDataInput) -> torch.Tensor: """ Convert BaseDataInput to feature vector for EnhancedCNN Args: base_data: Standardized input data Returns: torch.Tensor: Feature vector for EnhancedCNN """ try: # Use the get_feature_vector method from BaseDataInput features = base_data.get_feature_vector() # Convert to torch tensor features_tensor = torch.tensor(features, dtype=torch.float32, device=self.device) return features_tensor except Exception as e: logger.error(f"Error converting BaseDataInput to features: {e}") # Return empty tensor with correct shape return torch.zeros(7850, dtype=torch.float32, device=self.device) def predict(self, base_data: BaseDataInput) -> ModelOutput: """ Make a prediction using the EnhancedCNN model Args: base_data: Standardized input data Returns: ModelOutput: Standardized model output """ try: # Track inference timing start_time = datetime.now() inference_start = start_time.timestamp() # Convert BaseDataInput to features features = self._convert_base_data_to_features(base_data) # Ensure features has batch dimension if features.dim() == 1: features = features.unsqueeze(0) # Set model to evaluation mode self.model.eval() # Make prediction with torch.no_grad(): q_values, extrema_pred, price_pred, features_refined, advanced_pred = self.model(features) # Get action and confidence action_probs = torch.softmax(q_values, dim=1) action_idx = torch.argmax(action_probs, dim=1).item() confidence = float(action_probs[0, action_idx].item()) # Map action index to action string actions = ['BUY', 'SELL', 'HOLD'] action = actions[action_idx] # Extract pivot price prediction (simplified - take first value from price_pred) 
    def predict(self, base_data: BaseDataInput) -> ModelOutput:
        """
        Make a prediction using the EnhancedCNN model

        Args:
            base_data: Standardized input data

        Returns:
            ModelOutput: Standardized model output
        """
        try:
            # Track inference timing
            start_time = datetime.now()
            inference_start = start_time.timestamp()

            # Convert BaseDataInput to features
            features = self._convert_base_data_to_features(base_data)

            # Ensure features has a batch dimension
            if features.dim() == 1:
                features = features.unsqueeze(0)

            # Set the model to evaluation mode
            self.model.eval()

            # Make a prediction
            with torch.no_grad():
                q_values, extrema_pred, price_pred, features_refined, advanced_pred = self.model(features)

                # Get the action and confidence
                action_probs = torch.softmax(q_values, dim=1)
                action_idx = torch.argmax(action_probs, dim=1).item()
                confidence = float(action_probs[0, action_idx].item())

                # Map the action index to an action string
                actions = ['BUY', 'SELL', 'HOLD']
                action = actions[action_idx]

                # Extract the pivot price prediction (simplified: take the first value from price_pred)
                pivot_price = None
                if price_pred is not None and len(price_pred.squeeze()) > 0:
                    # Get the current price from base_data for context
                    current_price = 0.0
                    if base_data.ohlcv_1s and len(base_data.ohlcv_1s) > 0:
                        current_price = base_data.ohlcv_1s[-1].close

                    # Calculate the pivot price as the current price plus the predicted change,
                    # treating the first prediction value as a percentage change
                    price_change_pct = float(price_pred.squeeze()[0].item())
                    pivot_price = current_price * (1 + price_change_pct * 0.01)

                # Create the predictions dictionary
                predictions = {
                    'action': action,
                    'buy_probability': float(action_probs[0, 0].item()),
                    'sell_probability': float(action_probs[0, 1].item()),
                    'hold_probability': float(action_probs[0, 2].item()),
                    'extrema': extrema_pred.squeeze(0).cpu().numpy().tolist(),
                    'price_prediction': price_pred.squeeze(0).cpu().numpy().tolist(),
                    'pivot_price': pivot_price
                }

                # Create the hidden states dictionary
                hidden_states = {
                    'features': features_refined.squeeze(0).cpu().numpy().tolist()
                }

                # Calculate the inference duration in milliseconds
                end_time = datetime.now()
                inference_duration = (end_time.timestamp() - inference_start) * 1000

                # Update metrics
                self.last_inference_time = start_time
                self.last_inference_duration = inference_duration
                self.inference_count += 1

                # Store the last prediction output for the dashboard
                self.last_prediction_output = {
                    'action': action,
                    'confidence': confidence,
                    'pivot_price': pivot_price,
                    'timestamp': start_time,
                    'symbol': base_data.symbol
                }

                # Create the metadata dictionary
                metadata = {
                    'model_version': '1.0',
                    'timestamp': start_time.isoformat(),
                    'input_shape': features.shape,
                    'inference_duration_ms': inference_duration,
                    'inference_count': self.inference_count
                }

                # Create the ModelOutput
                model_output = ModelOutput(
                    model_type='cnn',
                    model_name=self.model_name,
                    symbol=base_data.symbol,
                    timestamp=start_time,
                    confidence=confidence,
                    predictions=predictions,
                    hidden_states=hidden_states,
                    metadata=metadata
                )

                return model_output

        except Exception as e:
            logger.error(f"Error making prediction with EnhancedCNN: {e}")
            # Return a default ModelOutput
            return create_model_output(
                model_type='cnn',
                model_name=self.model_name,
                symbol=base_data.symbol,
                action='HOLD',
                confidence=0.0
            )

    def add_training_sample(self, symbol_or_base_data, actual_action: str, reward: float):
        """
        Add a training sample to the training data

        Args:
            symbol_or_base_data: Either a symbol string or a BaseDataInput object
            actual_action: Actual action taken ('BUY', 'SELL', 'HOLD')
            reward: Reward received for the action
        """
        try:
            # Handle both a symbol string and a BaseDataInput object
            if isinstance(symbol_or_base_data, str):
                # Cold-start mode: create a simplified training sample for rapid training.
                # For now, a random feature vector stands in as a placeholder; this could
                # be enhanced with actual market data.
                symbol = symbol_or_base_data
                features = torch.randn(7850, dtype=torch.float32, device=self.device)

                logger.debug(f"Added simplified training sample for {symbol}, action: {actual_action}, reward: {reward:.4f}")
            else:
                # Full BaseDataInput object
                base_data = symbol_or_base_data
                features = self._convert_base_data_to_features(base_data)
                symbol = base_data.symbol

                logger.debug(f"Added full training sample for {symbol}, action: {actual_action}, reward: {reward:.4f}")

            # Convert the action to an index
            actions = ['BUY', 'SELL', 'HOLD']
            action_idx = actions.index(actual_action)

            # Add to the training data
            with self.training_lock:
                self.training_data.append((features, action_idx, reward))

                # Limit the training data size: sort by reward (highest first)
                # and keep only the top samples
                if len(self.training_data) > self.max_training_samples:
                    self.training_data.sort(key=lambda x: x[2], reverse=True)
                    self.training_data = self.training_data[:self.max_training_samples]

        except Exception as e:
            logger.error(f"Error adding training sample: {e}")
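    # Hypothetical helper for illustration only; nothing in this module calls it.
    # It restates the pivot-price mapping used in predict() as a pure function so
    # the formula can be checked with a worked example: current_price=3000.0 with
    # price_change_pct=0.5 yields 3000.0 * (1 + 0.5 * 0.01) = 3015.0.
    @staticmethod
    def _pivot_from_prediction(current_price: float, price_change_pct: float) -> float:
        """Convert a predicted percentage change into an absolute pivot price."""
        return current_price * (1 + price_change_pct * 0.01)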
    def train(self, epochs: int = 1) -> Dict[str, float]:
        """
        Train the model with the collected data

        Args:
            epochs: Number of epochs to train for

        Returns:
            Dict[str, float]: Training metrics
        """
        try:
            # Track training timing
            training_start_time = datetime.now()
            training_start = training_start_time.timestamp()

            with self.training_lock:
                # Check that we have enough data
                if len(self.training_data) < self.batch_size:
                    logger.info(f"Not enough training data: {len(self.training_data)} samples, need at least {self.batch_size}")
                    return {'loss': 0.0, 'accuracy': 0.0, 'samples': len(self.training_data)}

                # Set the model to training mode
                self.model.train()

                # Create the optimizer
                optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

                # Training metrics
                total_loss = 0.0
                total_batches = 0
                correct_predictions = 0
                total_predictions = 0

                # Train for the specified number of epochs
                for epoch in range(epochs):
                    # Shuffle the training data
                    np.random.shuffle(self.training_data)

                    # Process in batches
                    for i in range(0, len(self.training_data), self.batch_size):
                        batch = self.training_data[i:i+self.batch_size]

                        # Skip batches that are too small
                        if len(batch) < 2:
                            continue

                        # Prepare the batch
                        features = torch.stack([sample[0] for sample in batch])
                        actions = torch.tensor([sample[1] for sample in batch], dtype=torch.long, device=self.device)
                        rewards = torch.tensor([sample[2] for sample in batch], dtype=torch.float32, device=self.device)

                        # Zero the gradients
                        optimizer.zero_grad()

                        # Forward pass
                        q_values, _, _, _, _ = self.model(features)

                        # Calculate the loss: a reward-weighted negative log-likelihood.
                        # First, apply softmax to get probabilities
                        probs = torch.softmax(q_values, dim=1)

                        # Get the probability of the chosen action
                        chosen_probs = probs[torch.arange(len(actions)), actions]

                        # Calculate the negative log-likelihood loss
                        nll_loss = -torch.log(chosen_probs + 1e-10)

                        # Weight by reward (higher reward = higher weight),
                        # normalizing rewards to the [0, 1] range first
                        min_reward = rewards.min()
                        max_reward = rewards.max()
                        if max_reward > min_reward:
                            normalized_rewards = (rewards - min_reward) / (max_reward - min_reward)
                        else:
                            normalized_rewards = torch.ones_like(rewards)

                        # Apply the reward weighting, adding a small constant to avoid zero weights
                        weighted_loss = nll_loss * (normalized_rewards + 0.1)

                        # Mean loss
                        loss = weighted_loss.mean()

                        # Backward pass
                        loss.backward()

                        # Update the weights
                        optimizer.step()

                        # Update metrics
                        total_loss += loss.item()
                        total_batches += 1

                        # Calculate accuracy
                        predicted_actions = torch.argmax(q_values, dim=1)
                        correct_predictions += (predicted_actions == actions).sum().item()
                        total_predictions += len(actions)

                # Calculate the final metrics, averaging over the number of
                # batches actually processed (across all epochs)
                avg_loss = total_loss / total_batches if total_batches > 0 else 0.0
                accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0

                # Calculate the training duration in milliseconds
                training_end_time = datetime.now()
                training_duration = (training_end_time.timestamp() - training_start) * 1000

                # Update the training metrics
                self.last_training_time = training_start_time
                self.last_training_duration = training_duration
                self.last_training_loss = avg_loss
                self.training_count += 1

                # Save a checkpoint
                self._save_checkpoint(avg_loss, accuracy)

                logger.info(f"Training completed: loss={avg_loss:.4f}, accuracy={accuracy:.4f}, samples={len(self.training_data)}, duration={training_duration:.1f}ms")

                return {
                    'loss': avg_loss,
                    'accuracy': accuracy,
                    'samples': len(self.training_data),
                    'duration_ms': training_duration,
                    'training_count': self.training_count
                }

        except Exception as e:
            logger.error(f"Error training model: {e}")
            return {'loss': 0.0, 'accuracy': 0.0, 'samples': 0, 'error': str(e)}
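    # Hypothetical helper for illustration only; nothing in this module calls it.
    # It restates the reward-weighted negative log-likelihood used inside train()
    # as a pure function: rewards are min-max normalized to [0, 1] and offset by
    # 0.1 so that no sample receives zero weight.
    @staticmethod
    def _reward_weighted_nll(q_values: torch.Tensor,
                             actions: torch.Tensor,
                             rewards: torch.Tensor) -> torch.Tensor:
        """Mean NLL of the chosen actions, weighted by normalized rewards."""
        probs = torch.softmax(q_values, dim=1)
        chosen_probs = probs[torch.arange(len(actions)), actions]
        nll = -torch.log(chosen_probs + 1e-10)
        lo, hi = rewards.min(), rewards.max()
        if hi > lo:
            weights = (rewards - lo) / (hi - lo)
        else:
            weights = torch.ones_like(rewards)
        return (nll * (weights + 0.1)).mean()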
    def _save_checkpoint(self, loss: float, accuracy: float):
        """
        Save a model checkpoint

        Args:
            loss: Training loss
            accuracy: Training accuracy
        """
        try:
            # Import the checkpoint manager
            from utils.checkpoint_manager import CheckpointManager

            # Create a checkpoint manager
            checkpoint_manager = CheckpointManager(
                checkpoint_dir=self.checkpoint_dir,
                max_checkpoints=10,
                metric_name="accuracy"
            )

            # Save the model to a temporary file
            temp_path = os.path.join(self.checkpoint_dir, f"{self.model_name}_temp")
            self.model.save(temp_path)

            # Create the metrics
            metrics = {
                'loss': loss,
                'accuracy': accuracy,
                'samples': len(self.training_data)
            }

            # Create the metadata
            metadata = {
                'timestamp': datetime.now().isoformat(),
                'model_name': self.model_name,
                'input_shape': self.model.input_shape,
                'n_actions': self.model.n_actions
            }

            # Save the checkpoint
            checkpoint_path = checkpoint_manager.save_checkpoint(
                model_name=self.model_name,
                model_path=f"{temp_path}.pt",
                metrics=metrics,
                metadata=metadata
            )

            # Delete the temporary model file
            if os.path.exists(f"{temp_path}.pt"):
                os.remove(f"{temp_path}.pt")

            logger.info(f"Model checkpoint saved to {checkpoint_path}")
        except Exception as e:
            logger.error(f"Error saving checkpoint: {e}")
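
if __name__ == "__main__":
    # Minimal smoke-test sketch, assuming this module's imports resolve in the
    # surrounding project. It exercises only the cold-start path (symbol string
    # plus random placeholder features), since constructing a fully populated
    # BaseDataInput is outside this module's scope. 'ETH/USDT' is a hypothetical
    # example symbol.
    logging.basicConfig(level=logging.INFO)
    adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
    for i in range(64):  # at least batch_size (32) samples are required to train
        action = ['BUY', 'SELL', 'HOLD'][i % 3]
        adapter.add_training_sample('ETH/USDT', actual_action=action, reward=float(i % 5))
    print(adapter.train(epochs=1))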