""" Model Interfaces Module Defines abstract base classes and concrete implementations for various model types to ensure consistent interaction within the trading system. """ import logging from typing import Dict, Any, Optional, List from abc import ABC, abstractmethod import numpy as np logger = logging.getLogger(__name__) class ModelInterface(ABC): """Base interface for all models""" def __init__(self, name: str): self.name = name @abstractmethod def predict(self, data): """Make a prediction""" pass @abstractmethod def get_memory_usage(self) -> float: """Get memory usage in MB""" pass class CNNModelInterface(ModelInterface): """Interface for CNN models""" def __init__(self, model, name: str): super().__init__(name) self.model = model def predict(self, data): """Make CNN prediction""" try: if hasattr(self.model, 'predict'): return self.model.predict(data) return None except Exception as e: logger.error(f"Error in CNN prediction: {e}") return None def get_memory_usage(self) -> float: """Estimate CNN memory usage""" return 50.0 # MB class RLAgentInterface(ModelInterface): """Interface for RL agents""" def __init__(self, model, name: str): super().__init__(name) self.model = model def predict(self, data): """Make RL prediction""" try: if hasattr(self.model, 'act'): return self.model.act(data) elif hasattr(self.model, 'predict'): return self.model.predict(data) return None except Exception as e: logger.error(f"Error in RL prediction: {e}") return None def get_memory_usage(self) -> float: """Estimate RL memory usage""" return 25.0 # MB class ExtremaTrainerInterface(ModelInterface): """Interface for ExtremaTrainer models, providing context features""" def __init__(self, model, name: str): super().__init__(name) self.model = model def predict(self, data=None): """ExtremaTrainer doesn't predict in the traditional sense, it provides features.""" logger.warning(f"Predict method called on ExtremaTrainerInterface ({self.name}). Use get_context_features_for_model instead.") return None def get_memory_usage(self) -> float: """Estimate ExtremaTrainer memory usage""" return 30.0 # MB def get_context_features_for_model(self, symbol: str) -> Optional[np.ndarray]: """Get context features from the ExtremaTrainer for model consumption.""" try: if hasattr(self.model, 'get_context_features_for_model'): return self.model.get_context_features_for_model(symbol) return None except Exception as e: logger.error(f"Error getting extrema context features: {e}") return None