102 lines
3.5 KiB
Python
102 lines
3.5 KiB
Python
"""
NPU detection and ONNX Runtime provider configuration for AMD Strix Halo.
"""
|
|
import glob
import logging
import os
import re
import subprocess
from typing import Optional, Dict, Any
|
|
|
|
# Module-level logger named after this module; NPUDetector reports its
# detection results and errors through it.
logger = logging.getLogger(__name__)
|
|
|
|
class NPUDetector:
    """Detect the AMD XDNA NPU on Strix Halo systems and report what was found.

    Detection runs once, from ``__init__``; results are read back through
    :meth:`is_available`, :meth:`get_info` and :meth:`get_onnx_providers`.
    """

    # Glob for amdxdna device nodes.
    # NOTE(review): the upstream amdxdna driver registers with the DRM accel
    # subsystem (nodes such as /dev/accel/accel0); confirm that /dev/amdxdna*
    # actually exists on the target systems.
    DEVICE_GLOB = '/dev/amdxdna*'

    # Oldest kernel (major, minor) the NPU driver is expected to work on.
    MIN_KERNEL_VERSION = (6, 11)

    def __init__(self):
        # True once any amdxdna driver/device node has been seen.
        self.npu_available = False
        # Detection details; may hold 'devices' (sorted list of device paths)
        # and 'kernel_version' (the running kernel release string).
        self.npu_info = {}
        self._detect_npu()

    @staticmethod
    def _list_devices(pattern: str) -> list:
        """Return the filesystem paths matching the glob *pattern*, sorted.

        The glob is expanded in-process: a ``*`` pattern handed to an external
        command without a shell (as ``subprocess.run([...])`` does) is passed
        through literally and never matches anything.
        """
        return sorted(glob.glob(pattern))

    @staticmethod
    def _kernel_release() -> Optional[str]:
        """Return the running kernel release (what ``uname -r`` prints).

        Returns None on platforms that have no :func:`os.uname` (e.g. Windows).
        """
        uname = getattr(os, 'uname', None)
        if uname is None:
            return None
        return uname().release

    @staticmethod
    def _parse_kernel_version(release: str) -> Optional[tuple]:
        """Parse the leading ``major.minor`` of a kernel release string.

        ``'6.11.0-19-generic'`` -> ``(6, 11)``.  Returns None when *release*
        does not start with two dot-separated integers.
        """
        match = re.match(r'(\d+)\.(\d+)', release)
        if match is None:
            return None
        return int(match.group(1)), int(match.group(2))

    def _detect_npu(self):
        """Probe for the NPU and record the findings in ``self.npu_info``.

        Marks the NPU available if ``/dev/amdxdna`` exists or any node matches
        ``DEVICE_GLOB``. Also records the kernel release and warns when it is
        older than ``MIN_KERNEL_VERSION`` (availability is not changed by that
        check). Any unexpected error is logged and leaves the NPU marked
        unavailable.
        """
        try:
            # Check for amdxdna driver
            if os.path.exists('/dev/amdxdna'):
                self.npu_available = True
                logger.info("AMD XDNA NPU driver detected")

            # Check for NPU devices
            devices = self._list_devices(self.DEVICE_GLOB)
            if devices:
                self.npu_available = True
                self.npu_info['devices'] = devices
                logger.info("NPU devices found: %s", devices)

            # Check kernel version (need MIN_KERNEL_VERSION or newer)
            kernel_version = self._kernel_release()
            if kernel_version is not None:
                self.npu_info['kernel_version'] = kernel_version
                logger.info("Kernel version: %s", kernel_version)
                parsed = self._parse_kernel_version(kernel_version)
                if parsed is not None and parsed < self.MIN_KERNEL_VERSION:
                    logger.warning(
                        "Kernel %s is older than %d.%d required by the "
                        "amdxdna NPU driver",
                        kernel_version, *self.MIN_KERNEL_VERSION)

        except Exception as e:
            # Best-effort detection: never let probing break module import.
            logger.error("Error detecting NPU: %s", e)
            self.npu_available = False

    def is_available(self) -> bool:
        """Return True if detection found the NPU."""
        return self.npu_available

    def get_info(self) -> Dict[str, Any]:
        """Return ``{'available': bool, 'info': dict}`` describing the NPU."""
        return {
            'available': self.npu_available,
            'info': self.npu_info
        }

    def get_onnx_providers(self) -> list:
        """Return ONNX Runtime execution provider names in priority order.

        ``'CPUExecutionProvider'`` is always included, last. Only when the NPU
        was detected is ``onnxruntime`` consulted: ``'ROCMExecutionProvider'``
        and then ``'DmlExecutionProvider'`` are placed ahead of the CPU
        provider if the installed runtime reports them. If ``onnxruntime``
        cannot be imported, a warning is logged and only the CPU provider is
        returned.
        """
        providers = ['CPUExecutionProvider']  # Always available
        if not self.npu_available:
            return providers

        try:
            import onnxruntime as ort
        except ImportError:
            logger.warning("ONNX Runtime not installed")
            return providers

        available_providers = ort.get_available_providers()

        # Check for DirectML provider (NPU support).
        # NOTE(review): DirectML is a Windows API, while the detection above
        # probes Linux /dev nodes; AMD's Ryzen AI NPU is usually driven via
        # 'VitisAIExecutionProvider' -- confirm which provider is intended.
        if 'DmlExecutionProvider' in available_providers:
            providers.insert(0, 'DmlExecutionProvider')
            logger.info("DirectML provider available for NPU acceleration")

        # Check for ROCm provider; inserted last so it ends up first.
        if 'ROCMExecutionProvider' in available_providers:
            providers.insert(0, 'ROCMExecutionProvider')
            logger.info("ROCm provider available")

        return providers
|
|
|
|
# Global NPU detector instance.
# Constructed at import time, so NPU detection (run from NPUDetector.__init__)
# happens once, when this module is first imported. The module-level helper
# functions below all delegate to this shared instance.
npu_detector = NPUDetector()
|
|
|
|
def get_npu_info() -> Dict[str, Any]:
    """Return the shared detector's report.

    The result has the form ``{'available': bool, 'info': dict}``, exactly as
    produced by ``NPUDetector.get_info``.
    """
    detector = npu_detector
    return detector.get_info()
|
|
|
|
def is_npu_available() -> bool:
    """Report whether the module-wide detector found an NPU."""
    detector = npu_detector
    return detector.is_available()
|
|
|
|
def get_onnx_providers() -> list:
    """Return the ONNX Runtime provider list chosen by the module-wide detector."""
    detector = npu_detector
    return detector.get_onnx_providers()
|