# Files
# gogo2/NN/models/cnn_model.py
# Dobromir Popov b1ae557843 models overhaul
# 2025-07-29 19:22:04 +03:00
#
# 202 lines
# 7.9 KiB
# Python

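# NOTE(review): every definition below is commented out, so importing this
# module currently defines none of the names in the (also commented-out)
# __all__; the CNNModel wrapper is commented out a second time and omitted
# from __all__. Before re-enabling, fix two latent defects:
#   1. CNNModelTrainer.__init__ replaces any model that is not an
#      EnhancedCNNModel -- including the StandardizedCNN passed in by
#      create_enhanced_cnn_model and CNNModel -- with a brand-new
#      StandardizedCNN, so the returned trainer would not train the returned model.
#   2. CNNModelTrainer.train_step constructs a new Adam optimizer on every call,
#      so optimizer state is not carried over between steps.
# """
# Legacy CNN Model Compatibility Layer
# This module provides compatibility redirects to the unified StandardizedCNN model.
# All legacy models (EnhancedCNNModel, CNNModelTrainer, CNNModel) have been retired
# in favor of the StandardizedCNN architecture.
# """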
# import logging
# import warnings
# from typing import Tuple, Dict, Any, Optional
# import torch
# import numpy as np
# # Import the standardized CNN model
# from .standardized_cnn import StandardizedCNN
# logger = logging.getLogger(__name__)
# # Compatibility aliases and wrappers
# class EnhancedCNNModel:
# """Legacy compatibility wrapper - redirects to StandardizedCNN"""
# def __init__(self, *args, **kwargs):
# warnings.warn(
# "EnhancedCNNModel is deprecated. Use StandardizedCNN instead.",
# DeprecationWarning,
# stacklevel=2
# )
# # Create StandardizedCNN with default parameters
# self.standardized_cnn = StandardizedCNN()
# logger.warning("EnhancedCNNModel compatibility wrapper created - please migrate to StandardizedCNN")
# def __getattr__(self, name):
# """Delegate all method calls to StandardizedCNN"""
# return getattr(self.standardized_cnn, name)
# class CNNModelTrainer:
# """Legacy compatibility wrapper for CNN training"""
# def __init__(self, model=None, *args, **kwargs):
# warnings.warn(
# "CNNModelTrainer is deprecated. Use StandardizedCNN.train_step() instead.",
# DeprecationWarning,
# stacklevel=2
# )
# if isinstance(model, EnhancedCNNModel):
# self.model = model.standardized_cnn
# else:
# self.model = StandardizedCNN()
# logger.warning("CNNModelTrainer compatibility wrapper created - please use StandardizedCNN.train_step()")
# def train_step(self, x, y, *args, **kwargs):
# """Legacy train step wrapper"""
# try:
# # Convert to BaseDataInput format if needed
# if hasattr(x, 'get_feature_vector'):
# # Already BaseDataInput
# base_input = x
# else:
# # Create mock BaseDataInput for legacy compatibility
# from core.data_models import BaseDataInput
# base_input = BaseDataInput()
# # Set mock feature vector
# if isinstance(x, torch.Tensor):
# feature_vector = x.flatten().cpu().numpy()
# else:
# feature_vector = np.array(x).flatten()
# # Pad or truncate to expected size
# expected_size = self.model.expected_feature_dim
# if len(feature_vector) < expected_size:
# padding = np.zeros(expected_size - len(feature_vector))
# feature_vector = np.concatenate([feature_vector, padding])
# else:
# feature_vector = feature_vector[:expected_size]
# base_input._feature_vector = feature_vector
# # Convert target to string format
# if isinstance(y, torch.Tensor):
# y_val = y.item() if y.numel() == 1 else y.argmax().item()
# else:
# y_val = int(y) if np.isscalar(y) else int(np.argmax(y))
# target_map = {0: 'BUY', 1: 'SELL', 2: 'HOLD'}
# target = target_map.get(y_val, 'HOLD')
# # Use StandardizedCNN training
# optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
# loss = self.model.train_step([base_input], [target], optimizer)
# return {'total_loss': loss, 'main_loss': loss, 'accuracy': 0.5}
# except Exception as e:
# logger.error(f"Legacy train_step error: {e}")
# return {'total_loss': 0.0, 'main_loss': 0.0, 'accuracy': 0.5}
# # class CNNModel:
# # """Legacy compatibility wrapper for CNN model interface"""
# # def __init__(self, input_shape=(900, 50), output_size=3, model_path=None):
# # warnings.warn(
# # "CNNModel is deprecated. Use StandardizedCNN directly.",
# # DeprecationWarning,
# # stacklevel=2
# # )
# # self.input_shape = input_shape
# # self.output_size = output_size
# # self.standardized_cnn = StandardizedCNN()
# # self.trainer = CNNModelTrainer(self.standardized_cnn)
# # logger.warning("CNNModel compatibility wrapper created - please migrate to StandardizedCNN")
# # def build_model(self, **kwargs):
# # """Legacy build method - no-op for StandardizedCNN"""
# # return self
# # def predict(self, X):
# # """Legacy predict method"""
# # try:
# # # Convert input to BaseDataInput
# # from core.data_models import BaseDataInput
# # base_input = BaseDataInput()
# # if isinstance(X, np.ndarray):
# # feature_vector = X.flatten()
# # else:
# # feature_vector = np.array(X).flatten()
# # # Pad or truncate to expected size
# # expected_size = self.standardized_cnn.expected_feature_dim
# # if len(feature_vector) < expected_size:
# # padding = np.zeros(expected_size - len(feature_vector))
# # feature_vector = np.concatenate([feature_vector, padding])
# # else:
# # feature_vector = feature_vector[:expected_size]
# # base_input._feature_vector = feature_vector
# # # Get prediction from StandardizedCNN
# # result = self.standardized_cnn.predict_from_base_input(base_input)
# # # Convert to legacy format
# # action_map = {'BUY': 0, 'SELL': 1, 'HOLD': 2}
# # pred_class = np.array([action_map.get(result.predictions['action'], 2)])
# # pred_proba = np.array([result.predictions['action_probabilities']])
# # return pred_class, pred_proba
# # except Exception as e:
# # logger.error(f"Legacy predict error: {e}")
# # # Return safe defaults
# # pred_class = np.array([2]) # HOLD
# # pred_proba = np.array([[0.33, 0.33, 0.34]])
# # return pred_class, pred_proba
# # def fit(self, X, y, **kwargs):
# # """Legacy fit method"""
# # try:
# # return self.trainer.train_step(X, y)
# # except Exception as e:
# # logger.error(f"Legacy fit error: {e}")
# # return self
# # def save(self, filepath: str):
# # """Legacy save method"""
# # try:
# # torch.save(self.standardized_cnn.state_dict(), filepath)
# # logger.info(f"StandardizedCNN saved to {filepath}")
# # except Exception as e:
# # logger.error(f"Error saving model: {e}")
# def create_enhanced_cnn_model(input_size: int = 60,
# feature_dim: int = 50,
# output_size: int = 3,
# base_channels: int = 256,
# device: str = 'cuda') -> Tuple[StandardizedCNN, CNNModelTrainer]:
# """Legacy compatibility function - returns StandardizedCNN"""
# warnings.warn(
# "create_enhanced_cnn_model is deprecated. Use StandardizedCNN() directly.",
# DeprecationWarning,
# stacklevel=2
# )
# model = StandardizedCNN()
# trainer = CNNModelTrainer(model)
# logger.warning("Legacy create_enhanced_cnn_model called - please use StandardizedCNN directly")
# return model, trainer
# # Export compatibility symbols
# __all__ = [
# 'EnhancedCNNModel',
# 'CNNModelTrainer',
# # 'CNNModel',
# 'create_enhanced_cnn_model'
# ]