"""
Model Interfaces Module

Defines abstract base classes and concrete implementations for various model types
to ensure consistent interaction within the trading system.
Includes NPU acceleration support for Strix Halo processors.
"""
|
|
|
|
import logging
|
|
import os
|
|
from typing import Dict, Any, Optional, List, Union
|
|
from abc import ABC, abstractmethod
|
|
import numpy as np
|
|
|
|
# Optional NPU acceleration (Strix Halo): degrade gracefully when the
# utils package is not importable. Defaults are bound first so the names
# always exist, then the successful import path upgrades them.
HAS_NPU_SUPPORT = False
NPUAcceleratedModel = None

try:
    from utils.npu_acceleration import NPUAcceleratedModel, is_npu_available
    from utils.npu_detector import get_npu_info
except ImportError:
    pass
else:
    HAS_NPU_SUPPORT = True

logger = logging.getLogger(__name__)
|
|
|
|
class ModelInterface(ABC):
    """Abstract base for all models, with optional NPU acceleration support.

    Subclasses must implement :meth:`predict` and :meth:`get_memory_usage`.
    NPU availability is probed once at construction time and recorded on the
    instance; it is only attempted when the acceleration utilities were
    importable at module load (``HAS_NPU_SUPPORT``).
    """

    def __init__(self, name: str, enable_npu: bool = True):
        self.name = name
        # Effective enablement: caller opt-in AND module-level support.
        self.enable_npu = enable_npu and HAS_NPU_SUPPORT
        self.npu_model = None
        self.npu_available = False

        if self.enable_npu:
            self._setup_npu_acceleration()

    def _setup_npu_acceleration(self):
        """Probe NPU hardware availability and record the result."""
        try:
            found = HAS_NPU_SUPPORT and is_npu_available()
            if found:
                self.npu_available = True
                logger.info(f"NPU acceleration available for model: {self.name}")
            else:
                logger.info(f"NPU acceleration not available for model: {self.name}")
        except Exception as exc:
            # Probing must never break model construction.
            logger.warning(f"Failed to setup NPU acceleration: {exc}")
            self.npu_available = False

    def get_acceleration_info(self) -> Dict[str, Any]:
        """Return a summary dict describing NPU support for this model.

        When NPU support is present, the dict is augmented with the
        platform details reported by ``get_npu_info()``.
        """
        info: Dict[str, Any] = dict(
            model_name=self.name,
            npu_support_available=HAS_NPU_SUPPORT,
            npu_enabled=self.enable_npu,
            npu_available=self.npu_available,
        )

        if HAS_NPU_SUPPORT:
            info.update(get_npu_info())

        return info

    @abstractmethod
    def predict(self, data):
        """Make a prediction for *data*; concrete semantics are subclass-defined."""

    @abstractmethod
    def get_memory_usage(self) -> float:
        """Return an estimate of this model's memory usage in MB."""
|
|
|
|
class CNNModelInterface(ModelInterface):
    """Interface wrapping a CNN model, with optional NPU acceleration."""

    def __init__(self, model, name: str, enable_npu: bool = True, input_shape: tuple = None):
        super().__init__(name, enable_npu)
        self.model = model
        self.input_shape = input_shape

        # NPU wrapping needs a concrete input shape in addition to hardware
        # availability, so it is gated on all three conditions.
        if self.enable_npu and self.npu_available and input_shape:
            self._setup_cnn_npu_acceleration()

    def _setup_cnn_npu_acceleration(self):
        """Wrap the underlying CNN in an NPU-accelerated model, best effort."""
        try:
            if HAS_NPU_SUPPORT and NPUAcceleratedModel:
                self.npu_model = NPUAcceleratedModel(
                    pytorch_model=self.model,
                    model_name=f"{self.name}_cnn",
                    input_shape=self.input_shape,
                )
                logger.info(f"CNN NPU acceleration setup for: {self.name}")
        except Exception as exc:
            logger.warning(f"Failed to setup CNN NPU acceleration: {exc}")
            self.npu_model = None

    def predict(self, data):
        """Run a CNN prediction, preferring the NPU path when it exists.

        Falls back to the wrapped model's ``predict`` method; returns None
        when no prediction path is available or on any error (errors are
        logged, never raised to the caller).
        """
        try:
            use_npu = self.npu_model and self.npu_available
            if use_npu:
                return self.npu_model.predict(data)

            if hasattr(self.model, 'predict'):
                return self.model.predict(data)
            return None
        except Exception as exc:
            logger.error(f"Error in CNN prediction: {exc}")
            return None

    def get_memory_usage(self) -> float:
        """Estimate CNN memory usage in MB (rough heuristic, not measured)."""
        # 50 MB baseline, plus 25 MB overhead when the NPU wrapper exists.
        return 50.0 + (25.0 if self.npu_model else 0.0)
|
|
|
|
class RLAgentInterface(ModelInterface):
    """Interface wrapping an RL agent, with optional NPU acceleration."""

    def __init__(self, model, name: str, enable_npu: bool = True, input_shape: tuple = None):
        super().__init__(name, enable_npu)
        self.model = model
        self.input_shape = input_shape

        # NPU wrapping requires a known input shape on top of availability.
        if self.enable_npu and self.npu_available and input_shape:
            self._setup_rl_npu_acceleration()

    def _setup_rl_npu_acceleration(self):
        """Wrap the underlying RL model in an NPU-accelerated model, best effort."""
        try:
            if HAS_NPU_SUPPORT and NPUAcceleratedModel:
                self.npu_model = NPUAcceleratedModel(
                    pytorch_model=self.model,
                    model_name=f"{self.name}_rl",
                    input_shape=self.input_shape,
                )
                logger.info(f"RL NPU acceleration setup for: {self.name}")
        except Exception as exc:
            logger.warning(f"Failed to setup RL NPU acceleration: {exc}")
            self.npu_model = None

    def predict(self, data):
        """Run an RL prediction, preferring the NPU path when it exists.

        On the fallback path, ``act`` is preferred over ``predict`` when the
        wrapped agent offers both. Returns None when no prediction path is
        available or on any error (errors are logged, never raised).
        """
        try:
            if self.npu_model and self.npu_available:
                return self.npu_model.predict(data)

            # Prefer the RL-style 'act' entry point, then generic 'predict'.
            for method_name in ('act', 'predict'):
                if hasattr(self.model, method_name):
                    return getattr(self.model, method_name)(data)
            return None
        except Exception as exc:
            logger.error(f"Error in RL prediction: {exc}")
            return None

    def get_memory_usage(self) -> float:
        """Estimate RL agent memory usage in MB (rough heuristic, not measured)."""
        # 25 MB baseline, plus 15 MB overhead when the NPU wrapper exists.
        return 25.0 + (15.0 if self.npu_model else 0.0)
|
|
|
|
class ExtremaTrainerInterface(ModelInterface):
    """Interface for ExtremaTrainer models, providing context features.

    Unlike the CNN/RL interfaces, this model is not a predictor: consumers
    should call :meth:`get_context_features_for_model` instead of
    :meth:`predict`.
    """

    def __init__(self, model, name: str, enable_npu: bool = True):
        """Wrap an ExtremaTrainer model.

        ``enable_npu`` is accepted for signature consistency with the other
        model interfaces; it defaults to True, which matches the previous
        implicit behavior of calling ``super().__init__(name)``.
        """
        super().__init__(name, enable_npu)
        self.model = model

    def predict(self, data=None):
        """Not supported: ExtremaTrainer provides features, not predictions.

        Always returns None after logging a warning directing callers to
        :meth:`get_context_features_for_model`.
        """
        logger.warning(f"Predict method called on ExtremaTrainerInterface ({self.name}). Use get_context_features_for_model instead.")
        return None

    def get_memory_usage(self) -> float:
        """Estimate ExtremaTrainer memory usage in MB (rough heuristic)."""
        return 30.0  # MB

    def get_context_features_for_model(self, symbol: str) -> Optional[np.ndarray]:
        """Get context features from the ExtremaTrainer for model consumption.

        Returns None when the wrapped model lacks the hook or when feature
        extraction fails (errors are logged, never raised to the caller).
        """
        try:
            if hasattr(self.model, 'get_context_features_for_model'):
                return self.model.get_context_features_for_model(symbol)
            return None
        except Exception as exc:
            logger.error(f"Error getting extrema context features: {exc}")
            return None