81 lines
2.2 KiB
Python
81 lines
2.2 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Test script for Strix Halo NPU functionality
|
|
"""
|
|
import logging
import os
import sys

# Make the project's ``utils`` package importable when this script is run
# from somewhere other than the repository root. The location can be
# overridden with the GOGO2_PROJECT_ROOT environment variable; the default
# keeps the previously hard-coded developer checkout path.
_PROJECT_ROOT = os.environ.get(
    "GOGO2_PROJECT_ROOT", "/mnt/shared/DEV/repos/d-popov.com/gogo2"
)
if _PROJECT_ROOT not in sys.path:
    sys.path.append(_PROJECT_ROOT)

# Must come after the sys.path adjustment above.
from utils.npu_detector import get_npu_info, is_npu_available, get_onnx_providers  # noqa: E402

logging.basicConfig(level=logging.INFO)

logger = logging.getLogger(__name__)
|
|
|
|
def test_npu_detection():
    """Run NPU detection and print a short report.

    Prints the ``'available'`` and ``'info'`` entries of the dict returned by
    ``get_npu_info()``, then a pass/fail line decided by ``is_npu_available()``.

    Returns:
        The ``'available'`` entry of the ``get_npu_info()`` result.
    """
    print("=== NPU Detection Test ===")

    npu_info = get_npu_info()
    print(f"NPU Available: {npu_info['available']}")
    print(f"NPU Info: {npu_info['info']}")

    # NOTE(review): the status line uses is_npu_available() while the return
    # value uses get_npu_info()['available']; confirm the two always agree.
    verdict = "✅ NPU is available!" if is_npu_available() else "❌ NPU not available"
    print(verdict)

    return npu_info['available']
|
|
|
|
def test_onnx_providers():
    """Report ONNX Runtime status and the providers returned by the NPU helper.

    Prints the provider list from ``get_onnx_providers()``, the installed
    ONNX Runtime version, and whether ``DmlExecutionProvider`` (DirectML)
    appears in that provider list.

    Returns:
        True if onnxruntime imports and ``'DmlExecutionProvider'`` is in the
        provider list, otherwise False.
    """
    print("\n=== ONNX Providers Test ===")

    providers = get_onnx_providers()
    print(f"Available providers: {providers}")

    # Only the import can raise ImportError, so it is the only statement
    # guarded here.
    try:
        import onnxruntime as ort
    except ImportError:
        print("❌ ONNX Runtime not installed")
        return False

    print(f"ONNX Runtime version: {ort.__version__}")

    # This is a membership check on the helper's provider list only; no
    # inference session is created.
    has_directml = 'DmlExecutionProvider' in providers
    if has_directml:
        print("✅ DirectML provider available for NPU")
    else:
        print("❌ DirectML provider not available")
    return has_directml
|
|
|
|
def test_simple_inference():
    """Smoke-test the prerequisites for running inference.

    No model is loaded and no inference is executed yet. This function only
    checks that:

    - numpy and onnxruntime import;
    - ``get_onnx_providers()`` runs;
    - a float32 test tensor can be built.

    Returns:
        True if every step succeeded, False otherwise (a failure message is
        printed in that case).
    """
    print("\n=== Simple Inference Test ===")

    try:
        import numpy as np
        import onnxruntime as ort  # noqa: F401  (imported only to verify it is installed)

        providers = get_onnx_providers()
        print(f"Providers for inference: {providers}")

        # Placeholder input; the (1, 10) shape is arbitrary until a real model
        # defines its expected input.
        test_input = np.random.randn(1, 10).astype(np.float32)
        print(f"Test input shape: {test_input.shape}")

        # TODO: load a real model, e.g.
        # ort.InferenceSession(model_path, providers=providers), and run it
        # on test_input.
        print("✅ Basic inference setup successful")
        return True
    except Exception as e:  # top-level boundary of this check: report, don't crash
        print(f"❌ Inference test failed: {e}")
        return False
|
|
|
|
if __name__ == "__main__":
    print("Testing Strix Halo NPU Setup...")

    npu_available = test_npu_detection()

    test_onnx_providers()

    # Inference is only attempted when an NPU was detected; report the skip
    # explicitly so the output does not imply the step ran.
    if npu_available:
        test_simple_inference()
    else:
        print("\nSkipping inference test: NPU not available")

    print("\n=== Test Complete ===")
|