202 lines
7.5 KiB
Python
202 lines
7.5 KiB
Python
"""
|
|
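Legacy CNN Model Compatibility Layer

This module provides compatibility redirects to the unified StandardizedCNN model.
All legacy CNN classes (EnhancedCNNModel, CNNModelTrainer, CNNModel) have been
retired in favor of the StandardizedCNN architecture. Each wrapper defined in this
module emits a DeprecationWarning and forwards to StandardizedCNN, so existing
callers keep working while they migrate.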
|
|
"""
|
|
|
|
import logging
|
|
import warnings
|
|
from typing import Tuple, Dict, Any, Optional
|
|
import torch
|
|
import numpy as np
|
|
|
|
# Import the standardized CNN model
|
|
from .standardized_cnn import StandardizedCNN
|
|
|
|
# Module-level logger shared by every compatibility wrapper in this file.
logger = logging.getLogger(name=__name__)
|
|
|
|
# Compatibility aliases and wrappers
|
|
class EnhancedCNNModel:
    """Deprecated stand-in for the retired EnhancedCNNModel.

    Constructing it emits a ``DeprecationWarning`` and builds a default
    ``StandardizedCNN``. Attribute lookups that fail on the wrapper and
    calls to the wrapper are forwarded to that network. Constructor
    arguments are accepted for signature compatibility but ignored.
    """

    def __init__(self, *args, **kwargs):
        warnings.warn(
            "EnhancedCNNModel is deprecated. Use StandardizedCNN instead.",
            DeprecationWarning,
            stacklevel=2
        )
        # Legacy constructor arguments are intentionally ignored: the
        # replacement network is always built with its own defaults.
        self.standardized_cnn = StandardizedCNN()
        logger.warning("EnhancedCNNModel compatibility wrapper created - please migrate to StandardizedCNN")

    def __getattr__(self, name):
        """Forward attribute lookups that fail on the wrapper to StandardizedCNN.

        Python only calls ``__getattr__`` after normal lookup fails. The
        wrapped model is read straight from ``__dict__`` so that an instance
        lacking it raises ``AttributeError`` rather than re-entering this
        method forever. Such instances arise when ``copy``/``pickle`` build
        them via ``__new__``, or when ``__init__`` raised part-way.

        Raises:
            AttributeError: if the wrapped model is missing or lacks *name*.
        """
        try:
            wrapped = self.__dict__['standardized_cnn']
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(wrapped, name)

    def __call__(self, *args, **kwargs):
        """Forward ``model(...)`` to the wrapped StandardizedCNN.

        Implicit special-method lookups bypass ``__getattr__``, so calling
        the wrapper must be delegated explicitly.
        """
        return self.standardized_cnn(*args, **kwargs)
|
|
|
|
|
|
class CNNModelTrainer:
    """Deprecated stand-in for the retired CNN trainer.

    Adapts the legacy ``train_step(x, y)`` call onto
    ``StandardizedCNN.train_step`` so old training loops keep running while
    they are migrated.
    """

    def __init__(self, model=None, *args, learning_rate: float = 0.001, **kwargs):
        """Create the trainer.

        Args:
            model: Network to train. A ``StandardizedCNN`` is used as-is. An
                object exposing ``standardized_cnn`` (the legacy wrappers
                ``EnhancedCNNModel`` / ``CNNModel``) is unwrapped. Anything
                else, including ``None``, gets a fresh default
                ``StandardizedCNN``.
            learning_rate: Adam learning rate used by ``train_step``.
            *args, **kwargs: Other legacy arguments; accepted and ignored.
        """
        warnings.warn(
            "CNNModelTrainer is deprecated. Use StandardizedCNN.train_step() instead.",
            DeprecationWarning,
            stacklevel=2
        )
        # Train the very network object the caller holds (unwrapping legacy
        # wrappers), not a separate copy the caller never sees.
        candidate = getattr(model, 'standardized_cnn', model)
        if isinstance(candidate, StandardizedCNN):
            self.model = candidate
        else:
            self.model = StandardizedCNN()
        self.learning_rate = learning_rate
        # Built lazily on the first train_step and then reused, so Adam's
        # moment estimates persist across steps.
        self._optimizer = None
        self._optimizer_model = None
        logger.warning("CNNModelTrainer compatibility wrapper created - please use StandardizedCNN.train_step()")

    def train_step(self, x, y, *args, **kwargs):
        """Run one StandardizedCNN training step on legacy-format inputs.

        Args:
            x: Either an object with ``get_feature_vector`` (passed through
                unchanged), or a tensor / array-like that is flattened and
                zero-padded or truncated to ``self.model.expected_feature_dim``.
            y: Class label (0=BUY, 1=SELL, 2=HOLD) given as a scalar, 0-d
                array or one-element container, or a score vector whose
                argmax is the label. Unknown labels map to HOLD.
            *args, **kwargs: Legacy arguments; accepted and ignored.

        Returns:
            Dict whose ``total_loss`` and ``main_loss`` are the value returned
            by ``StandardizedCNN.train_step``, with a constant ``accuracy`` of
            0.5 (not measured). On any exception the error is logged and zero
            losses are returned instead.
        """
        try:
            base_input = self._to_base_input(x)
            target = self._to_target(y)
            loss = self.model.train_step([base_input], [target], self._get_optimizer())
            return {'total_loss': loss, 'main_loss': loss, 'accuracy': 0.5}
        except Exception as e:
            # Best-effort legacy behaviour: record the failure but never crash
            # the caller's training loop.
            logger.error(f"Legacy train_step error: {e}")
            return {'total_loss': 0.0, 'main_loss': 0.0, 'accuracy': 0.5}

    def _get_optimizer(self):
        """Return the Adam optimizer for ``self.model``, creating it once.

        The optimizer is rebuilt only if ``self.model`` has been reassigned
        since it was created.
        """
        if self._optimizer is None or self._optimizer_model is not self.model:
            self._optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
            self._optimizer_model = self.model
        return self._optimizer

    def _to_base_input(self, x):
        """Return *x* if it already provides ``get_feature_vector``, else wrap it in a BaseDataInput."""
        if hasattr(x, 'get_feature_vector'):
            return x
        from core.data_models import BaseDataInput
        base_input = BaseDataInput()
        if isinstance(x, torch.Tensor):
            # detach() so tensors that require grad can be converted to numpy.
            features = x.detach().flatten().cpu().numpy()
        else:
            features = np.array(x).flatten()
        base_input._feature_vector = self._fit_length(features, self.model.expected_feature_dim)
        return base_input

    @staticmethod
    def _fit_length(vector, size):
        """Zero-pad *vector* at the end, or truncate it, to exactly *size* entries."""
        if len(vector) < size:
            return np.concatenate([vector, np.zeros(size - len(vector))])
        return vector[:size]

    @staticmethod
    def _to_target(y):
        """Map a legacy label to the action string passed to StandardizedCNN.train_step."""
        if isinstance(y, torch.Tensor):
            y = y.detach().cpu().numpy()
        labels = np.asarray(y)
        # A single value (scalar, 0-d array, one-element container or tensor)
        # is the class index; anything larger is treated as per-class scores.
        index = int(labels.item()) if labels.size == 1 else int(np.argmax(labels))
        return {0: 'BUY', 1: 'SELL', 2: 'HOLD'}.get(index, 'HOLD')
|
|
|
|
|
|
class CNNModel:
    """Deprecated stand-in for the retired ``CNNModel`` interface.

    Exposes the legacy ``build_model`` / ``predict`` / ``fit`` / ``save``
    methods on top of a default ``StandardizedCNN``.
    """

    def __init__(self, input_shape=(900, 50), output_size=3, model_path=None):
        """Create the wrapper.

        Args:
            input_shape: Stored on the instance for legacy callers; does not
                configure the network.
            output_size: Stored on the instance for legacy callers; does not
                configure the network.
            model_path: Accepted for signature compatibility; ignored (no
                weights are loaded).
        """
        warnings.warn(
            "CNNModel is deprecated. Use StandardizedCNN directly.",
            DeprecationWarning,
            stacklevel=2
        )
        self.input_shape = input_shape
        self.output_size = output_size
        self.standardized_cnn = StandardizedCNN()
        self.trainer = CNNModelTrainer(self.standardized_cnn)
        # Bind the trainer to this wrapper's network explicitly so fit()
        # updates the same weights that predict() and save() use.
        self.trainer.model = self.standardized_cnn
        logger.warning("CNNModel compatibility wrapper created - please migrate to StandardizedCNN")

    def build_model(self, **kwargs):
        """Legacy build hook; the network already exists, so this returns ``self``."""
        return self

    def predict(self, X):
        """Predict the action class for one legacy input.

        Args:
            X: Array-like or tensor. It is flattened, then zero-padded or
                truncated to ``standardized_cnn.expected_feature_dim``.

        Returns:
            ``(pred_class, pred_proba)``. ``pred_class`` is a one-element
            array holding 0=BUY, 1=SELL or 2=HOLD (unknown actions map to
            HOLD). ``pred_proba`` is the model's ``action_probabilities``
            wrapped in an extra leading dimension. On any exception the error
            is logged and HOLD with ``[[0.33, 0.33, 0.34]]`` is returned.
        """
        try:
            from core.data_models import BaseDataInput
            base_input = BaseDataInput()
            base_input._feature_vector = self._prepare_features(
                X, self.standardized_cnn.expected_feature_dim
            )

            result = self.standardized_cnn.predict_from_base_input(base_input)

            # Convert to the legacy (class index, probabilities) format.
            action_map = {'BUY': 0, 'SELL': 1, 'HOLD': 2}
            pred_class = np.array([action_map.get(result.predictions['action'], 2)])
            pred_proba = np.array([result.predictions['action_probabilities']])
            return pred_class, pred_proba

        except Exception as e:
            logger.error(f"Legacy predict error: {e}")
            # Return safe defaults
            pred_class = np.array([2])  # HOLD
            pred_proba = np.array([[0.33, 0.33, 0.34]])
            return pred_class, pred_proba

    @staticmethod
    def _prepare_features(X, size):
        """Flatten *X* to a 1-D numpy array and zero-pad or truncate it to *size*."""
        if isinstance(X, torch.Tensor):
            # detach()/cpu() so grad-tracking or GPU tensors convert cleanly.
            features = X.detach().flatten().cpu().numpy()
        else:
            features = np.array(X).flatten()
        if len(features) < size:
            return np.concatenate([features, np.zeros(size - len(features))])
        return features[:size]

    def fit(self, X, y, **kwargs):
        """Run one training step through the trainer.

        Extra keyword arguments are accepted and ignored.

        Returns:
            The trainer's loss dict. If the trainer raises, the error is
            logged and a zero-loss dict of the same shape is returned, so
            the return type is always a dict.
        """
        try:
            return self.trainer.train_step(X, y)
        except Exception as e:
            logger.error(f"Legacy fit error: {e}")
            return {'total_loss': 0.0, 'main_loss': 0.0, 'accuracy': 0.5}

    def save(self, filepath: str):
        """Save the wrapped network's ``state_dict`` to *filepath* via ``torch.save``.

        Errors are logged rather than raised.
        """
        try:
            torch.save(self.standardized_cnn.state_dict(), filepath)
            logger.info(f"StandardizedCNN saved to {filepath}")
        except Exception as e:
            logger.error(f"Error saving model: {e}")
|
|
|
|
|
|
def create_enhanced_cnn_model(input_size: int = 60,
                              feature_dim: int = 50,
                              output_size: int = 3,
                              base_channels: int = 256,
                              device: str = 'cuda') -> Tuple[StandardizedCNN, CNNModelTrainer]:
    """Deprecated factory kept for legacy callers.

    Builds a default ``StandardizedCNN`` and a ``CNNModelTrainer`` bound to
    that same network. The size, channel and device arguments are accepted
    for signature compatibility but ignored; in particular, the model is not
    moved to *device*.

    Returns:
        ``(model, trainer)`` with ``trainer.model is model``.
    """
    warnings.warn(
        "create_enhanced_cnn_model is deprecated. Use StandardizedCNN() directly.",
        DeprecationWarning,
        stacklevel=2
    )

    model = StandardizedCNN()
    trainer = CNNModelTrainer(model)
    # Bind explicitly so training through the returned trainer updates the
    # returned model, not a separate network the trainer may have built.
    trainer.model = model

    logger.warning("Legacy create_enhanced_cnn_model called - please use StandardizedCNN directly")
    return model, trainer
|
|
|
|
|
|
# Export compatibility symbols
|
|
# Public names re-exported for code that still imports the legacy API.
__all__ = ['EnhancedCNNModel', 'CNNModelTrainer', 'CNNModel', 'create_enhanced_cnn_model']
|