# Python source file: 99 lines, 3.0 KiB.
"""
Model Interfaces Module

Defines abstract base classes and concrete implementations for various model types
to ensure consistent interaction within the trading system.
"""
|
|
|
|
import logging
|
|
from typing import Dict, Any, Optional, List
|
|
from abc import ABC, abstractmethod
|
|
import numpy as np
|
|
|
|
# Module-level logger named after this module; the interface classes in this
# file use it to report prediction errors and misuse warnings.
logger = logging.getLogger(__name__)
|
|
|
|
class ModelInterface(ABC):
    """Abstract contract implemented by every model wrapper.

    Concrete subclasses must provide ``predict`` and ``get_memory_usage``.
    The base class itself only records the name it is given.
    """

    def __init__(self, name: str) -> None:
        """Store ``name`` on the instance as ``self.name``."""
        self.name = name

    @abstractmethod
    def predict(self, data):
        """Produce a prediction for ``data``.

        The base implementation does nothing and returns ``None``.
        """

    @abstractmethod
    def get_memory_usage(self) -> float:
        """Report memory usage in MB.

        The base implementation does nothing and returns ``None``.
        """
|
|
|
|
class CNNModelInterface(ModelInterface):
    """Adapter exposing a CNN model through the ``ModelInterface`` contract.

    The wrapped object is duck-typed: it is only ever used through its
    ``predict`` attribute, and only when it has one.
    """

    # Fixed estimate returned by get_memory_usage(); it is not measured from
    # the wrapped model. Subclasses may override it.
    ESTIMATED_MEMORY_MB = 50.0

    def __init__(self, model, name: str) -> None:
        """Wrap ``model`` and register ``name`` with the base class.

        Args:
            model: Object to delegate predictions to.
            name: Name stored by ``ModelInterface.__init__``.
        """
        super().__init__(name)
        self.model = model

    def predict(self, data):
        """Delegate to ``self.model.predict(data)`` on a best-effort basis.

        Args:
            data: Input forwarded unchanged to the wrapped model.

        Returns:
            Whatever the wrapped model's ``predict`` returns, or ``None`` when
            the model has no ``predict`` attribute or the call raises.
        """
        try:
            if hasattr(self.model, 'predict'):
                return self.model.predict(data)
            return None
        except Exception as e:
            # Best-effort by design: callers receive None instead of an
            # exception. logger.exception keeps the traceback, which a plain
            # logger.error(f"...") call would discard.
            logger.exception("Error in CNN prediction: %s", e)
            return None

    def get_memory_usage(self) -> float:
        """Return the fixed memory estimate for CNN models, in MB."""
        return self.ESTIMATED_MEMORY_MB
|
|
|
|
class RLAgentInterface(ModelInterface):
    """Adapter exposing an RL agent through the ``ModelInterface`` contract.

    The wrapped object is duck-typed: ``act`` is preferred when present,
    otherwise ``predict`` is used.
    """

    # Fixed estimate returned by get_memory_usage(); it is not measured from
    # the wrapped agent. Subclasses may override it.
    ESTIMATED_MEMORY_MB = 25.0

    def __init__(self, model, name: str) -> None:
        """Wrap ``model`` and register ``name`` with the base class.

        Args:
            model: Agent object to delegate predictions to.
            name: Name stored by ``ModelInterface.__init__``.
        """
        super().__init__(name)
        self.model = model

    def predict(self, data):
        """Delegate to the agent's ``act`` or ``predict`` on a best-effort basis.

        Args:
            data: Input forwarded unchanged to the wrapped agent.

        Returns:
            The result of ``self.model.act(data)`` when the agent has an
            ``act`` attribute, else the result of ``self.model.predict(data)``
            when it has ``predict``. ``None`` when it has neither or the call
            raises.
        """
        try:
            # ``act`` takes precedence over ``predict`` when both exist.
            if hasattr(self.model, 'act'):
                return self.model.act(data)
            if hasattr(self.model, 'predict'):
                return self.model.predict(data)
            return None
        except Exception as e:
            # Best-effort by design: callers receive None instead of an
            # exception. logger.exception keeps the traceback, which a plain
            # logger.error(f"...") call would discard.
            logger.exception("Error in RL prediction: %s", e)
            return None

    def get_memory_usage(self) -> float:
        """Return the fixed memory estimate for RL agents, in MB."""
        return self.ESTIMATED_MEMORY_MB
|
|
|
|
class ExtremaTrainerInterface(ModelInterface):
    """Adapter for ExtremaTrainer models, which provide context features.

    ``predict`` is implemented only to satisfy the abstract base class; the
    useful entry point is ``get_context_features_for_model``.
    """

    # Fixed estimate returned by get_memory_usage(); it is not measured from
    # the wrapped trainer. Subclasses may override it.
    ESTIMATED_MEMORY_MB = 30.0

    def __init__(self, model, name: str) -> None:
        """Wrap ``model`` and register ``name`` with the base class.

        Args:
            model: Trainer object to delegate feature requests to.
            name: Name stored by ``ModelInterface.__init__``.
        """
        super().__init__(name)
        self.model = model

    def predict(self, data=None):
        """Log a misuse warning and return ``None``.

        Args:
            data: Accepted for interface compatibility and ignored.

        Returns:
            Always ``None``.
        """
        logger.warning(
            "Predict method called on ExtremaTrainerInterface (%s). "
            "Use get_context_features_for_model instead.",
            self.name,
        )
        return None

    def get_memory_usage(self) -> float:
        """Return the fixed memory estimate for ExtremaTrainer models, in MB."""
        return self.ESTIMATED_MEMORY_MB

    def get_context_features_for_model(self, symbol: str) -> Optional[np.ndarray]:
        """Fetch context features for ``symbol`` from the wrapped trainer.

        Args:
            symbol: Forwarded unchanged to the wrapped trainer.

        Returns:
            The result of ``self.model.get_context_features_for_model(symbol)``,
            or ``None`` when the trainer lacks that attribute or the call
            raises.
        """
        try:
            if hasattr(self.model, 'get_context_features_for_model'):
                return self.model.get_context_features_for_model(symbol)
            return None
        except Exception as e:
            # Best-effort by design: callers receive None instead of an
            # exception. logger.exception keeps the traceback, which a plain
            # logger.error(f"...") call would discard.
            logger.exception("Error getting extrema context features: %s", e)
            return None