"""Retired legacy CNN compatibility layer.

This module defines no names. The former compatibility wrappers
(``EnhancedCNNModel``, ``CNNModelTrainer``, ``CNNModel`` and
``create_enhanced_cnn_model``) have been retired in favor of
``StandardizedCNN`` from the sibling ``standardized_cnn`` module; import
that class directly instead.

The old implementation is kept below, fully commented out, for reference
only.
"""

# NOTE(review): everything below is disabled. Every line carries its own
# leading "#" so this module stays importable; keep it that way (or delete
# the block outright) rather than uncommenting it piecemeal. If it is ever
# revived, read the inline NOTE(review) remarks first: they flag defects
# visible in the old code.

# """
# Legacy CNN Model Compatibility Layer
#
# This module provides compatibility redirects to the unified StandardizedCNN model.
# All legacy models (EnhancedCNNModel, CNNModelTrainer, CNNModel) have been retired
# in favor of the StandardizedCNN architecture.
# """
#
# import logging
# import warnings
# from typing import Tuple, Dict, Any, Optional
# import torch
# import numpy as np
#
# # Import the standardized CNN model
# from .standardized_cnn import StandardizedCNN
#
# logger = logging.getLogger(__name__)
#
# # Compatibility aliases and wrappers
# class EnhancedCNNModel:
#     """Legacy compatibility wrapper - redirects to StandardizedCNN"""
#
#     def __init__(self, *args, **kwargs):
#         warnings.warn(
#             "EnhancedCNNModel is deprecated. Use StandardizedCNN instead.",
#             DeprecationWarning,
#             stacklevel=2
#         )
#         # Create StandardizedCNN with default parameters
#         self.standardized_cnn = StandardizedCNN()
#         logger.warning("EnhancedCNNModel compatibility wrapper created - please migrate to StandardizedCNN")
#
#     def __getattr__(self, name):
#         """Delegate all method calls to StandardizedCNN"""
#         # NOTE(review): if `standardized_cnn` was never set (e.g. __init__
#         # did not run), this lookup re-enters __getattr__ and recurses.
#         return getattr(self.standardized_cnn, name)
#
# class CNNModelTrainer:
#     """Legacy compatibility wrapper for CNN training"""
#
#     def __init__(self, model=None, *args, **kwargs):
#         warnings.warn(
#             "CNNModelTrainer is deprecated. Use StandardizedCNN.train_step() instead.",
#             DeprecationWarning,
#             stacklevel=2
#         )
#         # NOTE(review): a StandardizedCNN passed as `model` (which is what
#         # create_enhanced_cnn_model and CNNModel pass) falls into the else
#         # branch and is replaced by a fresh instance, so the trainer would
#         # update a different network than the one the caller holds.
#         if isinstance(model, EnhancedCNNModel):
#             self.model = model.standardized_cnn
#         else:
#             self.model = StandardizedCNN()
#         logger.warning("CNNModelTrainer compatibility wrapper created - please use StandardizedCNN.train_step()")
#
#     def train_step(self, x, y, *args, **kwargs):
#         """Legacy train step wrapper"""
#         try:
#             # Convert to BaseDataInput format if needed
#             if hasattr(x, 'get_feature_vector'):
#                 # Already BaseDataInput
#                 base_input = x
#             else:
#                 # Create mock BaseDataInput for legacy compatibility
#                 from core.data_models import BaseDataInput
#                 base_input = BaseDataInput()
#                 # Set mock feature vector
#                 if isinstance(x, torch.Tensor):
#                     feature_vector = x.flatten().cpu().numpy()
#                 else:
#                     feature_vector = np.array(x).flatten()
#                 # Pad or truncate to expected size
#                 expected_size = self.model.expected_feature_dim
#                 if len(feature_vector) < expected_size:
#                     padding = np.zeros(expected_size - len(feature_vector))
#                     feature_vector = np.concatenate([feature_vector, padding])
#                 else:
#                     feature_vector = feature_vector[:expected_size]
#                 base_input._feature_vector = feature_vector
#
#             # Convert target to string format
#             if isinstance(y, torch.Tensor):
#                 y_val = y.item() if y.numel() == 1 else y.argmax().item()
#             else:
#                 y_val = int(y) if np.isscalar(y) else int(np.argmax(y))
#             target_map = {0: 'BUY', 1: 'SELL', 2: 'HOLD'}
#             target = target_map.get(y_val, 'HOLD')
#
#             # Use StandardizedCNN training
#             # NOTE(review): building a new Adam optimizer on every call throws
#             # away its running moment estimates each step.
#             optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
#             loss = self.model.train_step([base_input], [target], optimizer)
#             # NOTE(review): 'accuracy' is a hard-coded placeholder, not a measurement.
#             return {'total_loss': loss, 'main_loss': loss, 'accuracy': 0.5}
#         except Exception as e:
#             # NOTE(review): any failure is logged and reported as zero loss.
#             logger.error(f"Legacy train_step error: {e}")
#             return {'total_loss': 0.0, 'main_loss': 0.0, 'accuracy': 0.5}
#
# # class CNNModel:
# #     """Legacy compatibility wrapper for CNN model interface"""
# #
# #     def __init__(self, input_shape=(900, 50), output_size=3, model_path=None):
# #         warnings.warn(
# #             "CNNModel is deprecated. Use StandardizedCNN directly.",
# #             DeprecationWarning,
# #             stacklevel=2
# #         )
# #         self.input_shape = input_shape
# #         self.output_size = output_size
# #         self.standardized_cnn = StandardizedCNN()
# #         self.trainer = CNNModelTrainer(self.standardized_cnn)
# #         logger.warning("CNNModel compatibility wrapper created - please migrate to StandardizedCNN")
# #
# #     def build_model(self, **kwargs):
# #         """Legacy build method - no-op for StandardizedCNN"""
# #         return self
# #
# #     def predict(self, X):
# #         """Legacy predict method"""
# #         try:
# #             # Convert input to BaseDataInput
# #             from core.data_models import BaseDataInput
# #             base_input = BaseDataInput()
# #
# #             if isinstance(X, np.ndarray):
# #                 feature_vector = X.flatten()
# #             else:
# #                 feature_vector = np.array(X).flatten()
# #
# #             # Pad or truncate to expected size
# #             expected_size = self.standardized_cnn.expected_feature_dim
# #             if len(feature_vector) < expected_size:
# #                 padding = np.zeros(expected_size - len(feature_vector))
# #                 feature_vector = np.concatenate([feature_vector, padding])
# #             else:
# #                 feature_vector = feature_vector[:expected_size]
# #
# #             base_input._feature_vector = feature_vector
# #
# #             # Get prediction from StandardizedCNN
# #             result = self.standardized_cnn.predict_from_base_input(base_input)
# #
# #             # Convert to legacy format
# #             action_map = {'BUY': 0, 'SELL': 1, 'HOLD': 2}
# #             pred_class = np.array([action_map.get(result.predictions['action'], 2)])
# #             pred_proba = np.array([result.predictions['action_probabilities']])
# #
# #             return pred_class, pred_proba
# #
# #         except Exception as e:
# #             logger.error(f"Legacy predict error: {e}")
# #             # Return safe defaults
# #             pred_class = np.array([2])  # HOLD
# #             pred_proba = np.array([[0.33, 0.33, 0.34]])
# #             return pred_class, pred_proba
# #
# #     def fit(self, X, y, **kwargs):
# #         """Legacy fit method"""
# #         try:
# #             return self.trainer.train_step(X, y)
# #         except Exception as e:
# #             logger.error(f"Legacy fit error: {e}")
# #             return self
# #
# #     def save(self, filepath: str):
# #         """Legacy save method"""
# #         try:
# #             torch.save(self.standardized_cnn.state_dict(), filepath)
# #             logger.info(f"StandardizedCNN saved to {filepath}")
# #         except Exception as e:
# #             logger.error(f"Error saving model: {e}")
#
# def create_enhanced_cnn_model(input_size: int = 60,
#                               feature_dim: int = 50,
#                               output_size: int = 3,
#                               base_channels: int = 256,
#                               device: str = 'cuda') -> Tuple[StandardizedCNN, CNNModelTrainer]:
#     """Legacy compatibility function - returns StandardizedCNN"""
#     # NOTE(review): every parameter above is ignored; StandardizedCNN() is
#     # always built with its defaults.
#     warnings.warn(
#         "create_enhanced_cnn_model is deprecated. Use StandardizedCNN() directly.",
#         DeprecationWarning,
#         stacklevel=2
#     )
#
#     model = StandardizedCNN()
#     trainer = CNNModelTrainer(model)
#
#     logger.warning("Legacy create_enhanced_cnn_model called - please use StandardizedCNN directly")
#     return model, trainer
#
# # Export compatibility symbols
# __all__ = [
#     'EnhancedCNNModel',
#     'CNNModelTrainer',
#     # 'CNNModel',
#     'create_enhanced_cnn_model'
# ]