Files
gogo2/utils/npu_acceleration.py
Dobromir Popov 00ae5bd579 NPU (wip); docker
2025-09-25 00:46:08 +03:00

315 lines
11 KiB
Python

"""
ONNX Runtime Integration for Strix Halo NPU Acceleration
Provides ONNX-based inference with NPU acceleration fallback
"""
import os
import logging
import numpy as np
from typing import Dict, Any, Optional, Union, List, Tuple
import torch
import torch.nn as nn
# Try to import ONNX Runtime
try:
import onnxruntime as ort
HAS_ONNX_RUNTIME = True
except ImportError:
ort = None
HAS_ONNX_RUNTIME = False
from utils.npu_detector import get_onnx_providers, is_npu_available
logger = logging.getLogger(__name__)
class ONNXModelWrapper:
    """
    Wrapper for PyTorch models converted to ONNX for NPU acceleration.

    Loads an ``.onnx`` file into an ONNX Runtime ``InferenceSession`` using the
    provider list returned by ``get_onnx_providers()`` and exposes a small
    ``predict`` / ``get_model_info`` API.
    """

    def __init__(self, model_path: str, input_names: Optional[List[str]] = None,
                 output_names: Optional[List[str]] = None, device: str = 'auto'):
        """
        Args:
            model_path: Path to the ``.onnx`` file.
            input_names: Graph input names to feed. If omitted (or empty), the
                input names declared by the loaded graph are used.
            output_names: Graph output names to fetch. If omitted (or empty),
                all outputs declared by the loaded graph are fetched.
            device: Stored on the instance; not otherwise used by this class.

        Raises:
            ImportError: ONNX Runtime is not installed.
            FileNotFoundError: ``model_path`` does not exist.
        """
        self.model_path = model_path
        # Remember what the caller explicitly asked for so _load_model can fall
        # back to the graph's own names: hard-coding 'input'/'output' breaks
        # any model exported with different I/O names.
        self._requested_input_names = input_names
        self._requested_output_names = output_names
        # Pre-load defaults (replaced with resolved names once the session loads).
        self.input_names = input_names or ['input']
        self.output_names = output_names or ['output']
        self.device = device

        # Get available providers
        self.providers = get_onnx_providers()
        logger.info(f"Available ONNX providers: {self.providers}")

        # Initialize session
        self.session = None
        self._load_model()

    @staticmethod
    def _resolve_names(requested: Optional[List[str]], available: List[str],
                       default: str) -> List[str]:
        """Return caller-supplied names, else the graph's names, else ``[default]``."""
        if requested:
            return list(requested)
        if available:
            return list(available)
        return [default]

    def _load_model(self):
        """Create the InferenceSession with the configured providers and resolve I/O names."""
        if not HAS_ONNX_RUNTIME:
            raise ImportError("ONNX Runtime not available")

        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"ONNX model not found: {self.model_path}")

        try:
            session_options = ort.SessionOptions()
            session_options.log_severity_level = 3  # Only errors
            # Enable all graph optimizations
            session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

            self.session = ort.InferenceSession(
                self.model_path,
                sess_options=session_options,
                providers=self.providers
            )

            logger.info(f"ONNX model loaded successfully with providers: {self.session.get_providers()}")

        except Exception as e:
            logger.error(f"Failed to load ONNX model: {e}")
            raise

        # Use the graph's real names unless the caller overrode them.
        self.input_names = self._resolve_names(
            self._requested_input_names,
            [inp.name for inp in self.session.get_inputs()],
            'input',
        )
        self.output_names = self._resolve_names(
            self._requested_output_names,
            [out.name for out in self.session.get_outputs()],
            'output',
        )

    def predict(self, inputs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, List[np.ndarray]]:
        """
        Run inference on the model.

        Args:
            inputs: A single array (fed to the first input name) or a mapping
                of input name -> array.

        Returns:
            The single output array, or a list of arrays when several outputs
            are fetched.

        Raises:
            RuntimeError: No session has been loaded.
        """
        if self.session is None:
            raise RuntimeError("Model not loaded")

        try:
            if isinstance(inputs, np.ndarray):
                # Single input case
                input_dict = {self.input_names[0]: inputs}
            else:
                input_dict = inputs

            outputs = self.session.run(self.output_names, input_dict)

            # Return single output or the list of outputs
            if len(outputs) == 1:
                return outputs[0]
            return outputs

        except Exception as e:
            logger.error(f"Inference failed: {e}")
            raise

    def get_model_info(self) -> Dict[str, Any]:
        """Return providers and I/O names/shapes of the loaded session (``{}`` if not loaded)."""
        if self.session is None:
            return {}

        return {
            'providers': self.session.get_providers(),
            'input_names': [inp.name for inp in self.session.get_inputs()],
            'output_names': [out.name for out in self.session.get_outputs()],
            'input_shapes': [inp.shape for inp in self.session.get_inputs()],
            'output_shapes': [out.shape for out in self.session.get_outputs()]
        }
class PyTorchToONNXConverter:
    """
    Converts PyTorch models to ONNX format for NPU acceleration.
    """

    def __init__(self, model: nn.Module, device: str = 'cpu'):
        """
        Args:
            model: Module to export. It is switched to eval mode here, which
                mutates the caller's module.
            device: Device on which the dummy export input is created;
                presumably this must match the device of the model's weights.
        """
        self.model = model
        self.device = device
        self.model.eval()  # Set to evaluation mode

    def convert(self, output_path: str, input_shape: Tuple[int, ...],
                input_names: Optional[List[str]] = None,
                output_names: Optional[List[str]] = None,
                opset_version: int = 17) -> bool:
        """
        Convert the PyTorch model to ONNX format.

        Args:
            output_path: Path to save the ONNX model.
            input_shape: Per-sample input shape; a batch dimension of 1 is
                prepended for the dummy export input.
            input_names: Names for input tensors (default ``['input']``).
            output_names: Names for output tensors (default ``['output']``).
            opset_version: ONNX opset version.

        Returns:
            True on success, False if export raised (the error is logged).
        """
        try:
            dummy_input = torch.randn(1, *input_shape).to(self.device)

            if input_names is None:
                input_names = ['input']
            if output_names is None:
                output_names = ['output']

            # A dynamic batch axis is only declared for the single-input /
            # single-output case; otherwise the exported batch size is fixed.
            torch.onnx.export(
                self.model,
                dummy_input,
                output_path,
                export_params=True,
                opset_version=opset_version,
                do_constant_folding=True,
                input_names=input_names,
                output_names=output_names,
                dynamic_axes={
                    input_names[0]: {0: 'batch_size'},
                    output_names[0]: {0: 'batch_size'}
                } if len(input_names) == 1 and len(output_names) == 1 else None,
                verbose=False
            )

            logger.info(f"Model converted to ONNX: {output_path}")
            return True

        except Exception as e:
            logger.error(f"ONNX conversion failed: {e}")
            return False

    def verify_onnx_model(self, onnx_path: str, input_shape: Tuple[int, ...]) -> bool:
        """
        Check that the exported model loads and runs on a random float32 input.

        Returns True when ONNX Runtime is not installed (verification is
        skipped with a warning), True when a test inference succeeds, and
        False (logged) when loading or inference raises.
        """
        try:
            if not HAS_ONNX_RUNTIME:
                logger.warning("ONNX Runtime not available for verification")
                return True

            providers = get_onnx_providers()
            session = ort.InferenceSession(onnx_path, providers=providers)

            dummy_input = np.random.randn(1, *input_shape).astype(np.float32)
            input_name = session.get_inputs()[0].name

            # Only success matters here; the output values are not inspected.
            session.run(None, {input_name: dummy_input})

            logger.info(f"ONNX model verification successful: {onnx_path}")
            return True

        except Exception as e:
            logger.error(f"ONNX model verification failed: {e}")
            return False
class NPUAcceleratedModel:
    """
    High-level interface for NPU-accelerated model inference.

    On construction the wrapped PyTorch model is either loaded from an existing
    ``<onnx_dir>/<model_name>.onnx`` file, or exported to that path, verified
    and loaded through ONNX Runtime. If any step fails, ``predict`` falls back
    to running the PyTorch model directly.

    NOTE(review): an existing ONNX file is reused as-is, so if the PyTorch
    weights change the cached file must be deleted to pick them up.
    """

    def __init__(self, pytorch_model: nn.Module, model_name: str,
                 input_shape: Tuple[int, ...], onnx_dir: str = "models/onnx"):
        """
        Args:
            pytorch_model: Model used for export and as the inference fallback.
            model_name: Base file name of the ONNX model inside ``onnx_dir``.
            input_shape: Per-sample input shape (no batch dimension) used for
                the export/verification dummy input.
            onnx_dir: Directory for ONNX files; created if missing.
        """
        self.pytorch_model = pytorch_model
        self.model_name = model_name
        self.input_shape = input_shape
        self.onnx_dir = onnx_dir

        os.makedirs(onnx_dir, exist_ok=True)
        self.onnx_path = os.path.join(onnx_dir, f"{model_name}.onnx")

        self.onnx_model = None
        self.converter = None
        self.use_npu = is_npu_available()

        self._setup_model()

    def _setup_model(self):
        """Load or create the ONNX model; leave ``onnx_model`` as None on any failure."""
        try:
            if os.path.exists(self.onnx_path):
                logger.info(f"Loading existing ONNX model: {self.onnx_path}")
                self.onnx_model = ONNXModelWrapper(self.onnx_path)
            else:
                self.onnx_model = self._convert_and_load()

            if self.onnx_model:
                logger.info(f"NPU-accelerated model ready: {self.model_name}")
                logger.info(f"Using providers: {self.onnx_model.session.get_providers()}")
            else:
                logger.warning(f"Falling back to PyTorch for model: {self.model_name}")

        except Exception as e:
            logger.error(f"Failed to setup NPU model: {e}")
            self.onnx_model = None

    def _convert_and_load(self):
        """Export, verify and load the model; returns the wrapper or None on failure."""
        logger.info(f"Converting PyTorch model to ONNX: {self.model_name}")
        self.converter = PyTorchToONNXConverter(self.pytorch_model)

        if not self.converter.convert(self.onnx_path, self.input_shape):
            logger.error("ONNX conversion failed")
            self._discard_onnx_file()
            return None

        if not self.converter.verify_onnx_model(self.onnx_path, self.input_shape):
            logger.error("ONNX model verification failed")
            # Otherwise the next start-up would find the file and load the
            # unverified model without checking it again.
            self._discard_onnx_file()
            return None

        return ONNXModelWrapper(self.onnx_path)

    def _discard_onnx_file(self):
        """Delete a freshly written ONNX file that failed export/verification."""
        try:
            os.remove(self.onnx_path)
        except FileNotFoundError:
            pass
        except OSError as e:
            logger.warning(f"Could not remove invalid ONNX model {self.onnx_path}: {e}")

    def _predict_pytorch(self, inputs):
        """Run the PyTorch fallback without gradients and return NumPy output(s)."""
        self.pytorch_model.eval()
        with torch.no_grad():
            if isinstance(inputs, np.ndarray):
                inputs = torch.from_numpy(inputs)
            if isinstance(inputs, torch.Tensor):
                # Put the input on the model's device so GPU-resident models do
                # not fail with a device mismatch; parameter-free modules are
                # left untouched.
                first_param = next(self.pytorch_model.parameters(), None)
                if first_param is not None:
                    inputs = inputs.to(first_param.device)
            outputs = self.pytorch_model(inputs)

        # Mirror ONNXModelWrapper.predict: multiple outputs come back as a list.
        if isinstance(outputs, (tuple, list)):
            return [out.cpu().numpy() for out in outputs]
        return outputs.cpu().numpy()

    def predict(self, inputs: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, List[np.ndarray]]:
        """
        Run inference, using ONNX Runtime when available and PyTorch otherwise.

        Args:
            inputs: Batch as a NumPy array or torch tensor.

        Returns:
            Output array, or a list of arrays for multi-output models.
        """
        try:
            if self.onnx_model is not None:
                if isinstance(inputs, torch.Tensor):
                    # detach() so tensors that require grad can be converted.
                    inputs = inputs.detach().cpu().numpy()
                return self.onnx_model.predict(inputs)
            return self._predict_pytorch(inputs)

        except Exception as e:
            logger.error(f"Inference failed: {e}")
            raise

    def get_performance_info(self) -> Dict[str, Any]:
        """Summarize the configuration, plus ONNX session details when loaded."""
        info = {
            'model_name': self.model_name,
            'use_npu': self.use_npu,
            'onnx_available': self.onnx_model is not None,
            'input_shape': self.input_shape
        }

        if self.onnx_model:
            info.update(self.onnx_model.get_model_info())

        return info
# Utility functions
def convert_trading_models_to_onnx(models_dir: str = "models", onnx_dir: str = "models/onnx"):
    """
    Convert all trading models to ONNX format.

    NOTE: placeholder. ``models_dir`` and ``onnx_dir`` are currently unused
    and no conversion is performed; a warning says so instead of reporting
    success. Always returns ``True`` so existing callers keep working.
    """
    logger.info("Converting trading models to ONNX format...")

    # TODO: enumerate the models under models_dir and export each into onnx_dir
    # (e.g. via PyTorchToONNXConverter).
    logger.warning("Trading model ONNX conversion is not implemented yet; no models were converted")
    return True
def benchmark_npu_vs_cpu(model_path: str, test_data: np.ndarray,
                         iterations: int = 100) -> Dict[str, float]:
    """
    Benchmark ONNX Runtime latency on the preferred providers vs. CPU only.

    Two sessions are created for ``model_path``. The "NPU" session uses the
    provider list from ``get_onnx_providers()``; the CPU session is restricted
    to ``CPUExecutionProvider``. Each session gets one untimed warm-up run,
    then ``iterations`` timed runs of ``test_data`` fed to the model's first
    input.

    Args:
        model_path: Path to an ``.onnx`` model file.
        test_data: Input array for the model's first input; its dtype and shape
            must match what the model expects.
        iterations: Number of timed runs per session.

    Returns:
        Dict with keys ``npu_latency_ms`` and ``cpu_latency_ms`` (mean
        per-run latency), ``speedup`` (cpu / npu) and ``iterations``. The
        latencies and speedup are NaN when ONNX Runtime is unavailable, when
        ``iterations < 1``, or when the benchmark raises. In those cases the
        problem is logged rather than raised.
    """
    import time

    logger.info("Benchmarking NPU vs CPU performance...")

    nan = float('nan')
    results = {
        'npu_latency_ms': nan,
        'cpu_latency_ms': nan,
        'speedup': nan,
        'iterations': iterations,
    }

    if not HAS_ONNX_RUNTIME:
        logger.warning("ONNX Runtime not available; cannot benchmark")
        return results
    if iterations < 1:
        logger.warning(f"Benchmark needs iterations >= 1, got {iterations}")
        return results

    def _mean_latency_ms(providers) -> float:
        """Mean wall-clock latency (ms) of one session.run on the given providers."""
        session = ort.InferenceSession(model_path, providers=providers)
        feed = {session.get_inputs()[0].name: test_data}
        session.run(None, feed)  # warm-up: keep one-time initialisation out of the timing
        start = time.perf_counter()
        for _ in range(iterations):
            session.run(None, feed)
        return (time.perf_counter() - start) * 1000.0 / iterations

    try:
        npu_ms = _mean_latency_ms(get_onnx_providers())
        cpu_ms = _mean_latency_ms(['CPUExecutionProvider'])
    except Exception as e:
        logger.error(f"Benchmark failed: {e}")
        return results

    results['npu_latency_ms'] = npu_ms
    results['cpu_latency_ms'] = cpu_ms
    results['speedup'] = cpu_ms / npu_ms if npu_ms > 0 else nan
    logger.info(f"Benchmark results: {results}")
    return results