"""
CNN-RL Bridge Module

This module provides the interface between CNN models and RL training,
extracting hidden features and predictions from CNN models for use in RL state building.
"""

import logging
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn

logger = logging.getLogger(__name__)


class CNNRLBridge:
    """Bridge between CNN models and RL training for feature extraction.

    The bridge keeps a registry of CNN models (``cnn_models``) and a small
    time-based cache (``feature_cache``) of the features extracted for each
    symbol. Model loading and feature extraction are still placeholders:
    ``_load_cnn_model`` always returns ``None`` and
    ``_extract_model_features`` returns mock data of the expected shape.
    """

    # Length of the hidden-feature vector produced per model.
    HIDDEN_FEATURE_SIZE = 512
    # Number of predictions produced per model (1s, 1m, 1h, 1d).
    PREDICTIONS_PER_MODEL = 4

    def __init__(self, config: Dict):
        """Initialize the CNN-RL bridge.

        Args:
            config: Configuration object. A ``cnn_models`` entry (dict key or
                attribute) mapping model names to model configs is used to
                pre-load models when present.
        """
        self.config = config
        self.cnn_models = {}
        self.feature_cache = {}
        self.cache_timeout = 60  # Cache features for 60 seconds

        # Initialize CNN model registry if available
        self._initialize_cnn_models()

        logger.info("CNN-RL Bridge initialized")

    def _configured_cnn_models(self) -> Dict[str, Any]:
        """Return the ``cnn_models`` section of the config, or an empty dict.

        Supports both mapping-style configs (``config['cnn_models']``) and
        attribute-style configs (``config.cnn_models``).
        """
        config = self.config
        if isinstance(config, dict):
            models = config.get('cnn_models')
        else:
            models = getattr(config, 'cnn_models', None)
        return models or {}

    def _initialize_cnn_models(self):
        """Initialize CNN models from config or model registry.

        Failures are logged and never raised: the bridge is usable without
        any CNN model, in which case RL trains without CNN features.
        """
        try:
            for model_name, model_config in self._configured_cnn_models().items():
                try:
                    # Load CNN model (implementation would depend on your CNN architecture)
                    model = self._load_cnn_model(model_name, model_config)
                    # Compare against None explicitly: container modules such
                    # as an empty nn.Sequential define __len__ and are falsy.
                    if model is not None:
                        self.cnn_models[model_name] = model
                        logger.info(f"Loaded CNN model: {model_name}")
                except Exception as e:
                    logger.warning(f"Failed to load CNN model {model_name}: {e}")

            if not self.cnn_models:
                logger.info("No CNN models available - RL will train without CNN features")

        except Exception as e:
            logger.warning(f"Error initializing CNN models: {e}")

    def _load_cnn_model(self, model_name: str, model_config: Dict) -> Optional[nn.Module]:
        """Load a CNN model from configuration.

        Placeholder: no loading is implemented yet, so this always returns
        ``None`` to indicate that no model is available.

        Args:
            model_name: Name under which the model would be registered.
            model_config: Model-specific configuration (currently unused).

        Returns:
            The loaded model, or ``None`` when it could not be loaded.
        """
        try:
            # This would implement actual CNN model loading
            # For now, return None to indicate no models available
            # In your implementation, this would load your specific CNN architecture
            logger.info(f"CNN model loading framework ready for {model_name}")
            return None

        except Exception as e:
            logger.error(f"Error loading CNN model {model_name}: {e}")
            return None

    def get_latest_features_for_symbol(self, symbol: str) -> Optional[Dict[str, Any]]:
        """Get latest CNN features and predictions for a symbol.

        Args:
            symbol: Symbol to extract features for.

        Returns:
            A dict with ``'hidden_features'`` and ``'predictions'`` entries
            (each keyed by model name), or ``None`` when no models are
            registered or an error occurred.
        """
        try:
            now = datetime.now()

            # Check cache first. The key is bucketed per wall-clock minute, so
            # an entry can only be reused within the minute it was created in.
            cache_key = f"{symbol}_{now.strftime('%Y%m%d_%H%M')}"
            cached_data = self.feature_cache.get(cache_key)
            if cached_data is not None:
                # total_seconds() rather than .seconds: .seconds is only the
                # seconds component and wraps around every 24 hours.
                age = (now - cached_data['timestamp']).total_seconds()
                if age < self.cache_timeout:
                    return cached_data['features']

            # Generate new features only if models are available
            if not self.cnn_models:
                return None

            features = self._extract_cnn_features_for_symbol(symbol)

            # Cache the features
            self.feature_cache[cache_key] = {
                'timestamp': datetime.now(),
                'features': features,
            }

            # Clean old cache entries
            self._cleanup_cache()

            return features

        except Exception as e:
            logger.warning(f"Error getting CNN features for {symbol}: {e}")
            return None

    def _extract_cnn_features_for_symbol(self, symbol: str) -> Dict[str, Any]:
        """Extract CNN hidden features and predictions for a symbol.

        Runs every registered model; a model that fails is logged and
        skipped so the remaining models still contribute.

        Returns:
            ``{'hidden_features': {...}, 'predictions': {...}}``, each inner
            dict keyed by model name. Both are empty on total failure.
        """
        try:
            extracted_features = {
                'hidden_features': {},
                'predictions': {},
            }

            for model_name, model in self.cnn_models.items():
                try:
                    # Extract features from each CNN model
                    hidden_features, predictions = self._extract_model_features(model, symbol)

                    if hidden_features is not None:
                        extracted_features['hidden_features'][model_name] = hidden_features

                    if predictions is not None:
                        extracted_features['predictions'][model_name] = predictions

                except Exception as e:
                    logger.warning(f"Error extracting features from {model_name}: {e}")

            return extracted_features

        except Exception as e:
            logger.error(f"Error extracting CNN features for {symbol}: {e}")
            return {'hidden_features': {}, 'predictions': {}}

    def _extract_model_features(
        self, model: nn.Module, symbol: str
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """Extract hidden features and predictions from a specific CNN model.

        Placeholder: ``model`` is not invoked yet. Mock data with the expected
        structure is returned instead.

        Returns:
            ``(hidden_features, predictions)`` as float32 arrays of length
            ``HIDDEN_FEATURE_SIZE`` and ``PREDICTIONS_PER_MODEL``, or
            ``(None, None)`` on error.
        """
        try:
            # This would implement the actual feature extraction from your CNN models
            # The implementation depends on your specific CNN architecture

            # For now, return mock data to show the structure
            # In real implementation, this would:
            # 1. Get market data for the model
            # 2. Run forward pass through CNN
            # 3. Extract hidden layer activations
            # 4. Get model predictions

            # Mock hidden features (last hidden layer of CNN)
            hidden_features = np.random.random(self.HIDDEN_FEATURE_SIZE).astype(np.float32)

            # Mock predictions for different timeframes
            # [1s_pred, 1m_pred, 1h_pred, 1d_pred] for each timeframe
            predictions = np.array([
                0.45,  # 1s prediction (probability of up move)
                0.52,  # 1m prediction
                0.38,  # 1h prediction
                0.61,  # 1d prediction
            ]).astype(np.float32)

            logger.debug(f"Extracted CNN features for {symbol}: {len(hidden_features)} hidden, {len(predictions)} predictions")

            return hidden_features, predictions

        except Exception as e:
            logger.warning(f"Error extracting features from model: {e}")
            return None, None

    def _cleanup_cache(self):
        """Clean up old cache entries.

        Removes every entry older than twice ``cache_timeout``. Errors are
        logged and swallowed so cache maintenance never breaks extraction.
        """
        try:
            current_time = datetime.now()
            max_age = self.cache_timeout * 2

            # Collect first, delete afterwards: a dict must not change size
            # while it is being iterated. total_seconds() is used because
            # .seconds wraps every 24 hours and would keep day-old entries.
            expired_keys = [
                key
                for key, data in self.feature_cache.items()
                if (current_time - data['timestamp']).total_seconds() > max_age
            ]

            for key in expired_keys:
                del self.feature_cache[key]

        except Exception as e:
            logger.warning(f"Error cleaning up feature cache: {e}")

    def register_cnn_model(self, model_name: str, model: nn.Module):
        """Register a CNN model for feature extraction.

        An existing model registered under the same name is replaced.
        """
        try:
            self.cnn_models[model_name] = model
            logger.info(f"Registered CNN model: {model_name}")

        except Exception as e:
            logger.error(f"Error registering CNN model {model_name}: {e}")

    def unregister_cnn_model(self, model_name: str):
        """Unregister a CNN model; unknown names are ignored silently."""
        try:
            if model_name in self.cnn_models:
                del self.cnn_models[model_name]
                logger.info(f"Unregistered CNN model: {model_name}")

        except Exception as e:
            logger.error(f"Error unregistering CNN model {model_name}: {e}")

    def get_available_models(self) -> List[str]:
        """Get list of available CNN models"""
        return list(self.cnn_models)

    def is_model_available(self, model_name: str) -> bool:
        """Check if a specific CNN model is available"""
        return model_name in self.cnn_models

    def get_feature_dimensions(self) -> Dict[str, int]:
        """Get the dimensions of features extracted from CNN models"""
        return {
            'hidden_features_per_model': self.HIDDEN_FEATURE_SIZE,
            'predictions_per_model': self.PREDICTIONS_PER_MODEL,  # 1s, 1m, 1h, 1d
            'total_models': len(self.cnn_models),
        }

    def validate_cnn_integration(self) -> Dict[str, Any]:
        """Validate CNN integration status.

        Returns:
            A status dict with the registered model count and names, the
            number of cache entries, whether at least one model is registered,
            and the total feature / prediction sizes expected across models.
        """
        model_count = len(self.cnn_models)
        status = {
            'models_available': model_count,
            'models_list': list(self.cnn_models.keys()),
            'cache_entries': len(self.feature_cache),
            'integration_ready': model_count > 0,
            'expected_feature_size': model_count * self.HIDDEN_FEATURE_SIZE,  # hidden features
            'expected_prediction_size': model_count * self.PREDICTIONS_PER_MODEL,  # predictions
        }

        return status