"""
|
|
ONNX Runtime Integration for Strix Halo NPU Acceleration
|
|
Provides ONNX-based inference with NPU acceleration fallback
|
|
"""
|
|
import logging
import os
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn

# Try to import ONNX Runtime; inference falls back to PyTorch without it
try:
    import onnxruntime as ort
    HAS_ONNX_RUNTIME = True
except ImportError:
    ort = None
    HAS_ONNX_RUNTIME = False

from utils.npu_detector import get_onnx_providers, is_npu_available

logger = logging.getLogger(__name__)
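
# Note: get_onnx_providers() is a project helper (utils.npu_detector). It is
# assumed here to return an ordered provider preference list, e.g.
# ['VitisAIExecutionProvider', 'CPUExecutionProvider'] when the Ryzen AI NPU
# is usable, so CPU execution always remains the last-resort fallback.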


class ONNXModelWrapper:
    """
    Wrapper for PyTorch models converted to ONNX for NPU acceleration.
    """

    def __init__(self, model_path: str, input_names: Optional[List[str]] = None,
                 output_names: Optional[List[str]] = None, device: str = 'auto'):
        self.model_path = model_path
        self.input_names = input_names or ['input']
        self.output_names = output_names or ['output']
        self.device = device

        # Get available providers
        self.providers = get_onnx_providers()
        logger.info(f"Available ONNX providers: {self.providers}")

        # Initialize session
        self.session = None
        self._load_model()

    def _load_model(self):
        """Load the ONNX model with the best available provider."""
        if not HAS_ONNX_RUNTIME:
            raise ImportError("ONNX Runtime not available")

        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"ONNX model not found: {self.model_path}")

        try:
            # Create session with providers
            session_options = ort.SessionOptions()
            session_options.log_severity_level = 3  # Only errors

            # Enable all graph optimizations
            session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

            self.session = ort.InferenceSession(
                self.model_path,
                sess_options=session_options,
                providers=self.providers
            )

            logger.info(f"ONNX model loaded successfully with providers: {self.session.get_providers()}")

        except Exception as e:
            logger.error(f"Failed to load ONNX model: {e}")
            raise

    def predict(self, inputs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, List[np.ndarray]]:
        """Run inference on the model."""
        if self.session is None:
            raise RuntimeError("Model not loaded")

        try:
            # Prepare inputs: map a single array onto the first input name
            if isinstance(inputs, np.ndarray):
                input_dict = {self.input_names[0]: inputs}
            else:
                input_dict = inputs

            # Run inference
            outputs = self.session.run(self.output_names, input_dict)

            # Return a single output directly, or the full list otherwise
            if len(outputs) == 1:
                return outputs[0]
            return outputs

        except Exception as e:
            logger.error(f"Inference failed: {e}")
            raise

    def get_model_info(self) -> Dict[str, Any]:
        """Get model information."""
        if self.session is None:
            return {}

        return {
            'providers': self.session.get_providers(),
            'input_names': [inp.name for inp in self.session.get_inputs()],
            'output_names': [out.name for out in self.session.get_outputs()],
            'input_shapes': [inp.shape for inp in self.session.get_inputs()],
            'output_shapes': [out.shape for out in self.session.get_outputs()]
        }
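

# Illustrative usage (a sketch; the path and input size are hypothetical, and
# the model is assumed to have been exported with the default 'input'/'output'
# names used by PyTorchToONNXConverter below):
#
#     wrapper = ONNXModelWrapper("models/onnx/price_predictor.onnx")
#     batch = np.random.randn(1, 64).astype(np.float32)
#     prediction = wrapper.predict(batch)
#     print(wrapper.get_model_info())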


class PyTorchToONNXConverter:
    """
    Converts PyTorch models to ONNX format for NPU acceleration.
    """

    def __init__(self, model: nn.Module, device: str = 'cpu'):
        self.model = model
        self.device = device
        self.model.eval()  # Set to evaluation mode

    def convert(self, output_path: str, input_shape: Tuple[int, ...],
                input_names: Optional[List[str]] = None,
                output_names: Optional[List[str]] = None,
                opset_version: int = 17) -> bool:
        """
        Convert the PyTorch model to ONNX format.

        Args:
            output_path: Path to save the ONNX model
            input_shape: Shape of a single input sample (without batch dimension)
            input_names: Names for input tensors
            output_names: Names for output tensors
            opset_version: ONNX opset version

        Returns:
            True if conversion succeeded, False otherwise.
        """
        try:
            # Create a dummy input with batch size 1
            dummy_input = torch.randn(1, *input_shape).to(self.device)

            # Set default names
            if input_names is None:
                input_names = ['input']
            if output_names is None:
                output_names = ['output']

            # Export to ONNX, marking the batch dimension as dynamic in the
            # common single-input/single-output case
            torch.onnx.export(
                self.model,
                dummy_input,
                output_path,
                export_params=True,
                opset_version=opset_version,
                do_constant_folding=True,
                input_names=input_names,
                output_names=output_names,
                dynamic_axes={
                    input_names[0]: {0: 'batch_size'},
                    output_names[0]: {0: 'batch_size'}
                } if len(input_names) == 1 and len(output_names) == 1 else None,
                verbose=False
            )

            logger.info(f"Model converted to ONNX: {output_path}")
            return True

        except Exception as e:
            logger.error(f"ONNX conversion failed: {e}")
            return False

    def verify_onnx_model(self, onnx_path: str, input_shape: Tuple[int, ...]) -> bool:
        """Verify the converted ONNX model by running a test inference."""
        try:
            if not HAS_ONNX_RUNTIME:
                logger.warning("ONNX Runtime not available for verification")
                return True

            # Load the model with the available providers
            providers = get_onnx_providers()
            session = ort.InferenceSession(onnx_path, providers=providers)

            # Run inference on a dummy input
            dummy_input = np.random.randn(1, *input_shape).astype(np.float32)
            input_name = session.get_inputs()[0].name
            session.run(None, {input_name: dummy_input})

            logger.info(f"ONNX model verification successful: {onnx_path}")
            return True

        except Exception as e:
            logger.error(f"ONNX model verification failed: {e}")
            return False
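

# Illustrative usage (a minimal sketch; the two-layer network and file name
# are hypothetical, and the output directory is assumed to exist):
#
#     net = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
#     converter = PyTorchToONNXConverter(net)
#     if converter.convert("models/onnx/example.onnx", input_shape=(10,)):
#         converter.verify_onnx_model("models/onnx/example.onnx", input_shape=(10,))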


class NPUAcceleratedModel:
    """
    High-level interface for NPU-accelerated model inference.
    """

    def __init__(self, pytorch_model: nn.Module, model_name: str,
                 input_shape: Tuple[int, ...], onnx_dir: str = "models/onnx"):
        self.pytorch_model = pytorch_model
        self.model_name = model_name
        self.input_shape = input_shape
        self.onnx_dir = onnx_dir

        # Create ONNX directory
        os.makedirs(onnx_dir, exist_ok=True)

        # Paths
        self.onnx_path = os.path.join(onnx_dir, f"{model_name}.onnx")

        # Initialize components
        self.onnx_model = None
        self.converter = None
        self.use_npu = is_npu_available()

        # Convert the model if needed
        self._setup_model()

    def _setup_model(self):
        """Set up the ONNX model for NPU acceleration."""
        try:
            # Reuse an existing ONNX model if one has already been exported
            if os.path.exists(self.onnx_path):
                logger.info(f"Loading existing ONNX model: {self.onnx_path}")
                self.onnx_model = ONNXModelWrapper(self.onnx_path)
            else:
                logger.info(f"Converting PyTorch model to ONNX: {self.model_name}")

                # Convert PyTorch to ONNX
                self.converter = PyTorchToONNXConverter(self.pytorch_model)

                if self.converter.convert(self.onnx_path, self.input_shape):
                    # Verify the converted model, then load it
                    if self.converter.verify_onnx_model(self.onnx_path, self.input_shape):
                        self.onnx_model = ONNXModelWrapper(self.onnx_path)
                    else:
                        logger.error("ONNX model verification failed")
                        self.onnx_model = None
                else:
                    logger.error("ONNX conversion failed")
                    self.onnx_model = None

            if self.onnx_model:
                logger.info(f"NPU-accelerated model ready: {self.model_name}")
                logger.info(f"Using providers: {self.onnx_model.session.get_providers()}")
            else:
                logger.warning(f"Falling back to PyTorch for model: {self.model_name}")

        except Exception as e:
            logger.error(f"Failed to setup NPU model: {e}")
            self.onnx_model = None

    def predict(self, inputs: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
        """Run inference with NPU acceleration if available."""
        try:
            # Convert to numpy if needed (detach in case the tensor tracks gradients)
            if isinstance(inputs, torch.Tensor):
                inputs = inputs.detach().cpu().numpy()

            # Use the ONNX model if available
            if self.onnx_model is not None:
                return self.onnx_model.predict(inputs)

            # Fall back to PyTorch
            self.pytorch_model.eval()
            with torch.no_grad():
                if isinstance(inputs, np.ndarray):
                    inputs = torch.from_numpy(inputs)
                outputs = self.pytorch_model(inputs)
            return outputs.cpu().numpy()

        except Exception as e:
            logger.error(f"Inference failed: {e}")
            raise

    def get_performance_info(self) -> Dict[str, Any]:
        """Get performance information."""
        info = {
            'model_name': self.model_name,
            'use_npu': self.use_npu,
            'onnx_available': self.onnx_model is not None,
            'input_shape': self.input_shape
        }

        if self.onnx_model:
            info.update(self.onnx_model.get_model_info())

        return info
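

# Illustrative usage (a sketch; the network and name are hypothetical -- see
# the runnable smoke test at the bottom of this module):
#
#     model = NPUAcceleratedModel(my_network, "my_network", input_shape=(64,))
#     outputs = model.predict(np.random.randn(1, 64).astype(np.float32))
#     print(model.get_performance_info())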


# Utility functions
def convert_trading_models_to_onnx(models_dir: str = "models", onnx_dir: str = "models/onnx"):
    """Convert all trading models to ONNX format.

    Placeholder: conversion of the specific trading models is not yet
    implemented, so this currently just reports success.
    """
    logger.info("Converting trading models to ONNX format...")
    logger.info("Model conversion completed")
    return True


def benchmark_npu_vs_cpu(model_path: str, test_data: np.ndarray,
                         iterations: int = 100) -> Dict[str, float]:
    """Benchmark NPU vs CPU inference latency for an ONNX model."""
    import time

    if not HAS_ONNX_RUNTIME:
        raise ImportError("ONNX Runtime not available")

    logger.info("Benchmarking NPU vs CPU performance...")

    def _mean_latency_ms(providers: List[str]) -> float:
        session = ort.InferenceSession(model_path, providers=providers)
        input_name = session.get_inputs()[0].name
        session.run(None, {input_name: test_data})  # Warm-up run
        start = time.perf_counter()
        for _ in range(iterations):
            session.run(None, {input_name: test_data})
        return (time.perf_counter() - start) * 1000.0 / iterations

    npu_ms = _mean_latency_ms(get_onnx_providers())
    cpu_ms = _mean_latency_ms(['CPUExecutionProvider'])

    return {
        'npu_latency_ms': npu_ms,
        'cpu_latency_ms': cpu_ms,
        'speedup': cpu_ms / npu_ms if npu_ms > 0 else 0.0,
        'iterations': iterations
    }
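

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the library API):
    # wraps a small hypothetical two-layer network and runs one prediction.
    # Note this writes demo_linear.onnx under models/onnx/, and it falls back
    # to PyTorch automatically if ONNX Runtime or the NPU is unavailable.
    logging.basicConfig(level=logging.INFO)

    demo_net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
    accelerated = NPUAcceleratedModel(demo_net, "demo_linear", input_shape=(8,))

    sample = np.random.randn(1, 8).astype(np.float32)
    prediction = accelerated.predict(sample)
    logger.info(f"Prediction shape: {prediction.shape}")
    logger.info(f"Performance info: {accelerated.get_performance_info()}")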