202 lines
7.9 KiB
Python
202 lines
7.9 KiB
Python
# """
|
|
# Legacy CNN Model Compatibility Layer
|
|
|
|
# This module provides compatibility redirects to the unified StandardizedCNN model.
|
|
# All legacy models (EnhancedCNNModel, CNNModelTrainer, CNNModel) have been retired
|
|
# in favor of the StandardizedCNN architecture.
|
|
# """
|
|
|
|
# import logging
|
|
# import warnings
|
|
# from typing import Tuple, Dict, Any, Optional
|
|
# import torch
|
|
# import numpy as np
|
|
|
|
# # Import the standardized CNN model
|
|
# from .standardized_cnn import StandardizedCNN
|
|
|
|
# logger = logging.getLogger(__name__)
|
|
|
|
# # Compatibility aliases and wrappers
|
|
# class EnhancedCNNModel:
|
|
# """Legacy compatibility wrapper - ignores its constructor arguments and delegates to a default StandardizedCNN"""
|
|
|
|
# def __init__(self, *args, **kwargs):
|
|
# warnings.warn(
|
|
# "EnhancedCNNModel is deprecated. Use StandardizedCNN instead.",
|
|
# DeprecationWarning,
|
|
# stacklevel=2
|
|
# )
|
|
# # Create StandardizedCNN with default parameters
|
|
# self.standardized_cnn = StandardizedCNN()
|
|
# logger.warning("EnhancedCNNModel compatibility wrapper created - please migrate to StandardizedCNN")
|
|
|
|
# def __getattr__(self, name):
|
|
# """Forward any attribute not found on the wrapper itself to the wrapped StandardizedCNN"""
|
|
# return getattr(self.standardized_cnn, name)
|
|
|
|
|
|
# class CNNModelTrainer:
|
|
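# """Legacy compatibility wrapper for CNN training.
#
# NOTE(review): only an EnhancedCNNModel argument is unwrapped and reused; any other
# `model` value (including a StandardizedCNN instance) is discarded and a fresh
# StandardizedCNN() is trained instead - confirm this is intended.
# """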
|
|
|
|
# def __init__(self, model=None, *args, **kwargs):
|
|
# warnings.warn(
|
|
# "CNNModelTrainer is deprecated. Use StandardizedCNN.train_step() instead.",
|
|
# DeprecationWarning,
|
|
# stacklevel=2
|
|
# )
|
|
# if isinstance(model, EnhancedCNNModel):
|
|
# self.model = model.standardized_cnn
|
|
# else:
|
|
# self.model = StandardizedCNN()
|
|
# logger.warning("CNNModelTrainer compatibility wrapper created - please use StandardizedCNN.train_step()")
|
|
|
|
# def train_step(self, x, y, *args, **kwargs):
|
|
# """Legacy train step wrapper"""
|
|
# try:
|
|
# # Convert to BaseDataInput format if needed
|
|
# if hasattr(x, 'get_feature_vector'):
|
|
# # Already BaseDataInput
|
|
# base_input = x
|
|
# else:
|
|
# # Create mock BaseDataInput for legacy compatibility
|
|
# from core.data_models import BaseDataInput
|
|
# base_input = BaseDataInput()
|
|
# # Flatten the legacy input (tensor or array-like) into a 1-D NumPy feature vector
|
|
# if isinstance(x, torch.Tensor):
|
|
# feature_vector = x.flatten().cpu().numpy()
|
|
# else:
|
|
# feature_vector = np.array(x).flatten()
|
|
|
|
# # Pad or truncate to expected size
|
|
# expected_size = self.model.expected_feature_dim
|
|
# if len(feature_vector) < expected_size:
|
|
# padding = np.zeros(expected_size - len(feature_vector))
|
|
# feature_vector = np.concatenate([feature_vector, padding])
|
|
# else:
|
|
# feature_vector = feature_vector[:expected_size]
|
|
|
|
# base_input._feature_vector = feature_vector
|
|
|
|
# # Reduce the target to a class index (argmax for multi-element targets), then map it to an action label
|
|
# if isinstance(y, torch.Tensor):
|
|
# y_val = y.item() if y.numel() == 1 else y.argmax().item()
|
|
# else:
|
|
# y_val = int(y) if np.isscalar(y) else int(np.argmax(y))
|
|
|
|
# target_map = {0: 'BUY', 1: 'SELL', 2: 'HOLD'}
|
|
# target = target_map.get(y_val, 'HOLD')
|
|
|
|
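# # Run one StandardizedCNN training step. A new Adam optimizer is built on every call,
# # so optimizer state (moment estimates) does not carry over between steps.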
|
|
# optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
|
|
# loss = self.model.train_step([base_input], [target], optimizer)
|
|
|
|
# return {'total_loss': loss, 'main_loss': loss, 'accuracy': 0.5}
|
|
|
|
# except Exception as e:
|
|
# logger.error(f"Legacy train_step error: {e}")
|
|
# return {'total_loss': 0.0, 'main_loss': 0.0, 'accuracy': 0.5}
|
|
|
|
|
|
# # class CNNModel:
|
|
# # """Legacy compatibility wrapper for CNN model interface"""
|
|
|
|
# # def __init__(self, input_shape=(900, 50), output_size=3, model_path=None):
|
|
# # warnings.warn(
|
|
# # "CNNModel is deprecated. Use StandardizedCNN directly.",
|
|
# # DeprecationWarning,
|
|
# # stacklevel=2
|
|
# # )
|
|
# # self.input_shape = input_shape
|
|
# # self.output_size = output_size
|
|
# # self.standardized_cnn = StandardizedCNN()
|
|
# # self.trainer = CNNModelTrainer(self.standardized_cnn)
|
|
# # logger.warning("CNNModel compatibility wrapper created - please migrate to StandardizedCNN")
|
|
|
|
# # def build_model(self, **kwargs):
|
|
# # """Legacy build method - no-op for StandardizedCNN"""
|
|
# # return self
|
|
|
|
# # def predict(self, X):
|
|
# # """Legacy predict method"""
|
|
# # try:
|
|
# # # Convert input to BaseDataInput
|
|
# # from core.data_models import BaseDataInput
|
|
# # base_input = BaseDataInput()
|
|
|
|
# # if isinstance(X, np.ndarray):
|
|
# # feature_vector = X.flatten()
|
|
# # else:
|
|
# # feature_vector = np.array(X).flatten()
|
|
|
|
# # # Pad or truncate to expected size
|
|
# # expected_size = self.standardized_cnn.expected_feature_dim
|
|
# # if len(feature_vector) < expected_size:
|
|
# # padding = np.zeros(expected_size - len(feature_vector))
|
|
# # feature_vector = np.concatenate([feature_vector, padding])
|
|
# # else:
|
|
# # feature_vector = feature_vector[:expected_size]
|
|
|
|
# # base_input._feature_vector = feature_vector
|
|
|
|
# # # Get prediction from StandardizedCNN
|
|
# # result = self.standardized_cnn.predict_from_base_input(base_input)
|
|
|
|
# # # Convert to legacy format
|
|
# # action_map = {'BUY': 0, 'SELL': 1, 'HOLD': 2}
|
|
# # pred_class = np.array([action_map.get(result.predictions['action'], 2)])
|
|
# # pred_proba = np.array([result.predictions['action_probabilities']])
|
|
|
|
# # return pred_class, pred_proba
|
|
|
|
# # except Exception as e:
|
|
# # logger.error(f"Legacy predict error: {e}")
|
|
# # # Return safe defaults
|
|
# # pred_class = np.array([2]) # HOLD
|
|
# # pred_proba = np.array([[0.33, 0.33, 0.34]])
|
|
# # return pred_class, pred_proba
|
|
|
|
# # def fit(self, X, y, **kwargs):
|
|
# # """Legacy fit method"""
|
|
# # try:
|
|
# # return self.trainer.train_step(X, y)
|
|
# # except Exception as e:
|
|
# # logger.error(f"Legacy fit error: {e}")
|
|
# # return self
|
|
|
|
# # def save(self, filepath: str):
|
|
# # """Legacy save method"""
|
|
# # try:
|
|
# # torch.save(self.standardized_cnn.state_dict(), filepath)
|
|
# # logger.info(f"StandardizedCNN saved to {filepath}")
|
|
# # except Exception as e:
|
|
# # logger.error(f"Error saving model: {e}")
|
|
|
|
|
|
# def create_enhanced_cnn_model(input_size: int = 60,
|
|
# feature_dim: int = 50,
|
|
# output_size: int = 3,
|
|
# base_channels: int = 256,
|
|
# device: str = 'cuda') -> Tuple[StandardizedCNN, CNNModelTrainer]:
|
|
# """Legacy compatibility function - ignores its arguments and returns a (StandardizedCNN, CNNModelTrainer) pair"""
|
|
# warnings.warn(
|
|
# "create_enhanced_cnn_model is deprecated. Use StandardizedCNN() directly.",
|
|
# DeprecationWarning,
|
|
# stacklevel=2
|
|
# )
|
|
|
|
# model = StandardizedCNN()
|
|
# trainer = CNNModelTrainer(model)
|
|
|
|
# logger.warning("Legacy create_enhanced_cnn_model called - please use StandardizedCNN directly")
|
|
# return model, trainer
|
|
|
|
|
|
# # Export compatibility symbols
|
|
# __all__ = [
|
|
# 'EnhancedCNNModel',
|
|
# 'CNNModelTrainer',
|
|
# # 'CNNModel',
|
|
# 'create_enhanced_cnn_model'
|
|
# ]
|