"""
NPU Detection and Configuration for Strix Halo
"""

import glob
import os
import subprocess
import logging
from typing import Optional, Dict, Any, List

logger = logging.getLogger(__name__)


class NPUDetector:
    """Detects and configures AMD Strix Halo NPU.

    Detection runs once, from ``__init__``. Results are exposed through
    :meth:`is_available`, :meth:`get_info` and :meth:`get_onnx_providers`.
    """

    # Device node whose presence is treated as "amdxdna driver loaded".
    DEVICE_NODE = '/dev/amdxdna'
    # Glob pattern used to enumerate NPU device nodes.
    DEVICE_GLOB = '/dev/amdxdna*'

    def __init__(self):
        self.npu_available = False
        self.npu_info: Dict[str, Any] = {}
        self._detect_npu()

    def _detect_npu(self):
        """Detect whether the NPU is available and collect basic info.

        Sets ``self.npu_available`` and fills ``self.npu_info`` with:

        * ``'devices'`` -- sorted paths matching ``DEVICE_GLOB`` (only set
          when at least one path matches);
        * ``'kernel_version'`` -- output of ``uname -r`` (only set when the
          command succeeds).

        Any unexpected error is logged and leaves the NPU marked unavailable.
        """
        try:
            # Check for amdxdna driver
            if os.path.exists(self.DEVICE_NODE):
                self.npu_available = True
                logger.info("AMD XDNA NPU driver detected")

            # Expand the wildcard in-process: handing it to ``ls`` without a
            # shell would pass the literal pattern and never match anything.
            devices = sorted(glob.glob(self.DEVICE_GLOB))
            if devices:
                self.npu_available = True
                self.npu_info['devices'] = devices
                logger.info("NPU devices found: %s", devices)

            # Record the kernel version (the driver is noted to need 6.11+).
            # NOTE(review): the version is only recorded, not enforced.
            try:
                result = subprocess.run(
                    ['uname', '-r'], capture_output=True, text=True, timeout=5
                )
                if result.returncode == 0:
                    kernel_version = result.stdout.strip()
                    self.npu_info['kernel_version'] = kernel_version
                    logger.info("Kernel version: %s", kernel_version)
            except (subprocess.TimeoutExpired, FileNotFoundError):
                # Best effort: `uname` missing (e.g. non-POSIX) or hung.
                pass

        except Exception as e:
            logger.error("Error detecting NPU: %s", e)
            self.npu_available = False

    def is_available(self) -> bool:
        """Return True if detection found the NPU driver or device nodes."""
        return self.npu_available

    def get_info(self) -> Dict[str, Any]:
        """Return NPU availability plus a shallow copy of the collected info.

        The copy keeps callers from mutating the detector's internal state.
        """
        return {
            'available': self.npu_available,
            'info': dict(self.npu_info),
        }

    def get_onnx_providers(self) -> List[str]:
        """Return ONNX Runtime execution providers to use, best first.

        ``'CPUExecutionProvider'`` is always the last entry. Accelerated
        providers are only considered when the NPU was detected and
        ``onnxruntime`` is importable. If both are present, ROCm is placed
        ahead of DirectML.

        NOTE(review): AMD's Ryzen AI NPU is commonly exposed through
        ``VitisAIExecutionProvider``, which is not checked here -- confirm.
        """
        providers = ['CPUExecutionProvider']  # Always available
        if not self.npu_available:
            return providers

        try:
            import onnxruntime as ort
        except ImportError:
            logger.warning("ONNX Runtime not installed")
            return providers

        available_providers = ort.get_available_providers()

        # Check for DirectML provider (NPU support)
        if 'DmlExecutionProvider' in available_providers:
            providers.insert(0, 'DmlExecutionProvider')
            logger.info("DirectML provider available for NPU acceleration")

        # Check for ROCm provider (inserted last, so it takes priority)
        if 'ROCMExecutionProvider' in available_providers:
            providers.insert(0, 'ROCMExecutionProvider')
            logger.info("ROCm provider available")

        return providers


# Global NPU detector instance (detection runs at import time).
npu_detector = NPUDetector()


def get_npu_info() -> Dict[str, Any]:
    """Get NPU information from the global detector."""
    return npu_detector.get_info()


def is_npu_available() -> bool:
    """Check if NPU is available according to the global detector."""
    return npu_detector.is_available()


def get_onnx_providers() -> List[str]:
    """Get available ONNX providers from the global detector."""
    return npu_detector.get_onnx_providers()