""" ONNX Runtime Integration for Strix Halo NPU Acceleration Provides ONNX-based inference with NPU acceleration fallback """ import os import logging import numpy as np from typing import Dict, Any, Optional, Union, List, Tuple import torch import torch.nn as nn # Try to import ONNX Runtime try: import onnxruntime as ort HAS_ONNX_RUNTIME = True except ImportError: ort = None HAS_ONNX_RUNTIME = False from utils.npu_detector import get_onnx_providers, is_npu_available logger = logging.getLogger(__name__) class ONNXModelWrapper: """ Wrapper for PyTorch models converted to ONNX for NPU acceleration """ def __init__(self, model_path: str, input_names: List[str] = None, output_names: List[str] = None, device: str = 'auto'): self.model_path = model_path self.input_names = input_names or ['input'] self.output_names = output_names or ['output'] self.device = device # Get available providers self.providers = get_onnx_providers() logger.info(f"Available ONNX providers: {self.providers}") # Initialize session self.session = None self._load_model() def _load_model(self): """Load ONNX model with optimal provider""" if not HAS_ONNX_RUNTIME: raise ImportError("ONNX Runtime not available") if not os.path.exists(self.model_path): raise FileNotFoundError(f"ONNX model not found: {self.model_path}") try: # Create session with providers session_options = ort.SessionOptions() session_options.log_severity_level = 3 # Only errors # Enable optimizations session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL self.session = ort.InferenceSession( self.model_path, sess_options=session_options, providers=self.providers ) logger.info(f"ONNX model loaded successfully with providers: {self.session.get_providers()}") except Exception as e: logger.error(f"Failed to load ONNX model: {e}") raise def predict(self, inputs: Union[np.ndarray, Dict[str, np.ndarray]]) -> np.ndarray: """Run inference on the model""" if self.session is None: raise RuntimeError("Model not loaded") try: # Prepare inputs if isinstance(inputs, np.ndarray): # Single input case input_dict = {self.input_names[0]: inputs} else: input_dict = inputs # Run inference outputs = self.session.run(self.output_names, input_dict) # Return single output or tuple if len(outputs) == 1: return outputs[0] return outputs except Exception as e: logger.error(f"Inference failed: {e}") raise def get_model_info(self) -> Dict[str, Any]: """Get model information""" if self.session is None: return {} return { 'providers': self.session.get_providers(), 'input_names': [inp.name for inp in self.session.get_inputs()], 'output_names': [out.name for out in self.session.get_outputs()], 'input_shapes': [inp.shape for inp in self.session.get_inputs()], 'output_shapes': [out.shape for out in self.session.get_outputs()] } class PyTorchToONNXConverter: """ Converts PyTorch models to ONNX format for NPU acceleration """ def __init__(self, model: nn.Module, device: str = 'cpu'): self.model = model self.device = device self.model.eval() # Set to evaluation mode def convert(self, output_path: str, input_shape: Tuple[int, ...], input_names: List[str] = None, output_names: List[str] = None, opset_version: int = 17) -> bool: """ Convert PyTorch model to ONNX format Args: output_path: Path to save ONNX model input_shape: Shape of input tensor input_names: Names for input tensors output_names: Names for output tensors opset_version: ONNX opset version """ try: # Create dummy input dummy_input = torch.randn(1, *input_shape).to(self.device) # Set default names if input_names 
class PyTorchToONNXConverter:
    """
    Converts PyTorch models to ONNX format for NPU acceleration
    """

    def __init__(self, model: nn.Module, device: str = 'cpu'):
        self.model = model
        self.device = device
        self.model.eval()  # Set to evaluation mode

    def convert(self, output_path: str,
                input_shape: Tuple[int, ...],
                input_names: Optional[List[str]] = None,
                output_names: Optional[List[str]] = None,
                opset_version: int = 17) -> bool:
        """
        Convert a PyTorch model to ONNX format

        Args:
            output_path: Path to save the ONNX model
            input_shape: Shape of a single input sample (without batch dimension)
            input_names: Names for input tensors
            output_names: Names for output tensors
            opset_version: ONNX opset version

        Returns:
            True if the export succeeded, False otherwise
        """
        try:
            # Create a dummy input with a leading batch dimension
            dummy_input = torch.randn(1, *input_shape).to(self.device)

            # Set default names
            if input_names is None:
                input_names = ['input']
            if output_names is None:
                output_names = ['output']

            # Export to ONNX; the batch dimension is kept dynamic for the
            # common single-input/single-output case
            torch.onnx.export(
                self.model,
                dummy_input,
                output_path,
                export_params=True,
                opset_version=opset_version,
                do_constant_folding=True,
                input_names=input_names,
                output_names=output_names,
                dynamic_axes={
                    input_names[0]: {0: 'batch_size'},
                    output_names[0]: {0: 'batch_size'}
                } if len(input_names) == 1 and len(output_names) == 1 else None,
                verbose=False
            )

            logger.info(f"Model converted to ONNX: {output_path}")
            return True

        except Exception as e:
            logger.error(f"ONNX conversion failed: {e}")
            return False

    def verify_onnx_model(self, onnx_path: str, input_shape: Tuple[int, ...]) -> bool:
        """Verify the converted ONNX model by running a dummy inference"""
        try:
            if not HAS_ONNX_RUNTIME:
                logger.warning("ONNX Runtime not available for verification")
                return True

            # Load the exported model
            providers = get_onnx_providers()
            session = ort.InferenceSession(onnx_path, providers=providers)

            # Run inference on a dummy input
            dummy_input = np.random.randn(1, *input_shape).astype(np.float32)
            input_name = session.get_inputs()[0].name
            session.run(None, {input_name: dummy_input})

            logger.info(f"ONNX model verification successful: {onnx_path}")
            return True

        except Exception as e:
            logger.error(f"ONNX model verification failed: {e}")
            return False
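
# Usage sketch for PyTorchToONNXConverter (the nn.Linear module and output
# path are illustrative assumptions, not models from this project):
#
#   converter = PyTorchToONNXConverter(nn.Linear(10, 2))
#   if converter.convert("models/onnx/linear.onnx", input_shape=(10,)):
#       converter.verify_onnx_model("models/onnx/linear.onnx", input_shape=(10,))
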
class NPUAcceleratedModel:
    """
    High-level interface for NPU-accelerated model inference
    """

    def __init__(self, pytorch_model: nn.Module,
                 model_name: str,
                 input_shape: Tuple[int, ...],
                 onnx_dir: str = "models/onnx"):
        self.pytorch_model = pytorch_model
        self.model_name = model_name
        self.input_shape = input_shape
        self.onnx_dir = onnx_dir

        # Create ONNX directory
        os.makedirs(onnx_dir, exist_ok=True)

        # Paths
        self.onnx_path = os.path.join(onnx_dir, f"{model_name}.onnx")

        # Initialize components
        self.onnx_model = None
        self.converter = None
        self.use_npu = is_npu_available()

        # Convert the model if needed
        self._setup_model()

    def _setup_model(self):
        """Set up the ONNX model for NPU acceleration"""
        try:
            # Reuse an existing ONNX export if present
            if os.path.exists(self.onnx_path):
                logger.info(f"Loading existing ONNX model: {self.onnx_path}")
                self.onnx_model = ONNXModelWrapper(self.onnx_path)
            else:
                logger.info(f"Converting PyTorch model to ONNX: {self.model_name}")

                # Convert PyTorch to ONNX
                self.converter = PyTorchToONNXConverter(self.pytorch_model)

                if self.converter.convert(self.onnx_path, self.input_shape):
                    # Verify the export, then load it
                    if self.converter.verify_onnx_model(self.onnx_path, self.input_shape):
                        self.onnx_model = ONNXModelWrapper(self.onnx_path)
                    else:
                        logger.error("ONNX model verification failed")
                        self.onnx_model = None
                else:
                    logger.error("ONNX conversion failed")
                    self.onnx_model = None

            if self.onnx_model:
                logger.info(f"NPU-accelerated model ready: {self.model_name}")
                logger.info(f"Using providers: {self.onnx_model.session.get_providers()}")
            else:
                logger.warning(f"Falling back to PyTorch for model: {self.model_name}")

        except Exception as e:
            logger.error(f"Failed to set up NPU model: {e}")
            self.onnx_model = None

    def predict(self, inputs: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
        """Run inference with NPU acceleration if available"""
        try:
            # Convert to numpy if needed
            if isinstance(inputs, torch.Tensor):
                inputs = inputs.cpu().numpy()

            # Use the ONNX model if available
            if self.onnx_model is not None:
                return self.onnx_model.predict(inputs)

            # Fall back to PyTorch
            self.pytorch_model.eval()
            with torch.no_grad():
                if isinstance(inputs, np.ndarray):
                    inputs = torch.from_numpy(inputs)
                outputs = self.pytorch_model(inputs)
                return outputs.cpu().numpy()

        except Exception as e:
            logger.error(f"Inference failed: {e}")
            raise

    def get_performance_info(self) -> Dict[str, Any]:
        """Get performance information"""
        info = {
            'model_name': self.model_name,
            'use_npu': self.use_npu,
            'onnx_available': self.onnx_model is not None,
            'input_shape': self.input_shape
        }

        if self.onnx_model:
            info.update(self.onnx_model.get_model_info())

        return info


# Utility functions

def convert_trading_models_to_onnx(models_dir: str = "models", onnx_dir: str = "models/onnx") -> bool:
    """Convert all trading models to ONNX format

    Placeholder: conversion of the project's specific models is not
    implemented yet; this currently only reports success.
    """
    logger.info("Converting trading models to ONNX format...")
    logger.info("Model conversion completed")
    return True


def benchmark_npu_vs_cpu(model_path: str, test_data: np.ndarray,
                         iterations: int = 100) -> Dict[str, float]:
    """Benchmark NPU vs CPU latency for an ONNX model

    Times the provider list returned by get_onnx_providers() (NPU-first when
    available) against a CPU-only session on the same model and input.
    """
    logger.info("Benchmarking NPU vs CPU performance...")

    if not HAS_ONNX_RUNTIME:
        raise ImportError("ONNX Runtime not available")

    def _mean_latency_ms(providers: List[str]) -> float:
        session = ort.InferenceSession(model_path, providers=providers)
        input_name = session.get_inputs()[0].name
        session.run(None, {input_name: test_data})  # Warm-up run
        start = time.perf_counter()
        for _ in range(iterations):
            session.run(None, {input_name: test_data})
        return (time.perf_counter() - start) / iterations * 1000.0

    npu_latency_ms = _mean_latency_ms(get_onnx_providers())
    cpu_latency_ms = _mean_latency_ms(['CPUExecutionProvider'])

    return {
        'npu_latency_ms': npu_latency_ms,
        'cpu_latency_ms': cpu_latency_ms,
        'speedup': cpu_latency_ms / npu_latency_ms,
        'iterations': iterations
    }
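
if __name__ == "__main__":
    # Minimal smoke test (a sketch; the two-layer network and shapes are
    # illustrative assumptions, not models from this project). Exercises the
    # full convert -> verify -> load -> predict path, falling back to PyTorch
    # when ONNX Runtime or the NPU is unavailable.
    logging.basicConfig(level=logging.INFO)

    demo_model = nn.Sequential(nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 2))
    accelerated = NPUAcceleratedModel(demo_model, "demo_model", input_shape=(64,))

    batch = np.random.randn(4, 64).astype(np.float32)
    print("Output shape:", accelerated.predict(batch).shape)
    print("Performance info:", accelerated.get_performance_info())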