#!/usr/bin/env python3
"""
Comprehensive NPU Integration Test for Strix Halo
Tests NPU acceleration with your trading models.

Each ``test_*`` / ``benchmark_*`` function prints its own report and returns
True on success, False on failure; ``main`` runs them all and exits with a
non-zero status if any failed.
"""

import sys
import os
import time
import logging
import tempfile

import numpy as np
import torch
import torch.nn as nn

# Add project root to path
# NOTE(review): machine-specific absolute path; adjust it (or derive it from
# __file__) when running these tests from a different checkout.
sys.path.append('/mnt/shared/DEV/repos/d-popov.com/gogo2')

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Number of input features consumed by the simple test model.
NUM_FEATURES = 50
# Number of timed single-sample inferences per backend in the benchmark.
BENCHMARK_RUNS = 10


def _timed_ms(func, *args):
    """Call ``func(*args)`` and return ``(result, elapsed_milliseconds)``.

    Uses ``time.perf_counter`` (the clock Python documents for measuring short
    durations) rather than the wall-clock ``time.time``, which can be coarse
    or adjusted while the test runs.
    """
    start = time.perf_counter()
    result = func(*args)
    return result, (time.perf_counter() - start) * 1000


def _temp_model_path(filename):
    """Return a path for *filename* inside the platform's temp directory.

    Avoids a hard-coded ``/tmp``, which does not exist on Windows.
    """
    return os.path.join(tempfile.gettempdir(), filename)


def test_npu_detection():
    """Test NPU detection and setup.

    Returns:
        True if ``utils.npu_detector`` reports an available NPU, else False
        (including when the detector cannot be imported or raises).
    """
    print("=== NPU Detection Test ===")

    try:
        from utils.npu_detector import get_npu_info, is_npu_available, get_onnx_providers

        info = get_npu_info()
        print(f"NPU Available: {info['available']}")
        print(f"NPU Info: {info['info']}")

        providers = get_onnx_providers()
        print(f"ONNX Providers: {providers}")

        if is_npu_available():
            print("✅ NPU is available!")
            return True
        print("❌ NPU not available")
        return False

    except Exception as e:
        print(f"❌ NPU detection failed: {e}")
        return False


def test_onnx_runtime():
    """Test ONNX Runtime functionality.

    Returns:
        True if ``onnxruntime`` imports and lists its providers (the DirectML
        check is informational only), False otherwise.
    """
    print("\n=== ONNX Runtime Test ===")

    try:
        import onnxruntime as ort
        print(f"ONNX Runtime version: {ort.__version__}")

        # Test providers
        providers = ort.get_available_providers()
        print(f"Available providers: {providers}")

        # Test DirectML provider (informational; does not fail the test)
        if 'DmlExecutionProvider' in providers:
            print("✅ DirectML provider available")
        else:
            print("❌ DirectML provider not available")

        return True

    except ImportError:
        print("❌ ONNX Runtime not installed")
        return False
    except Exception as e:
        print(f"❌ ONNX Runtime test failed: {e}")
        return False


def create_test_model():
    """Create a simple test model for NPU testing.

    Returns:
        A fresh three-layer MLP mapping ``NUM_FEATURES`` inputs to 3 outputs
        (ReLU activations, dropout 0.1 between layers).
    """
    class SimpleTradingModel(nn.Module):
        def __init__(self, input_size=NUM_FEATURES, hidden_size=128, output_size=3):
            super().__init__()
            self.fc1 = nn.Linear(input_size, hidden_size)
            self.fc2 = nn.Linear(hidden_size, hidden_size)
            self.fc3 = nn.Linear(hidden_size, output_size)
            self.relu = nn.ReLU()
            self.dropout = nn.Dropout(0.1)

        def forward(self, x):
            x = self.relu(self.fc1(x))
            x = self.dropout(x)
            x = self.relu(self.fc2(x))
            x = self.dropout(x)
            x = self.fc3(x)
            return x

    return SimpleTradingModel()


def test_model_conversion():
    """Test PyTorch to ONNX conversion.

    Returns:
        True if the converter both writes and verifies the ONNX file.
    """
    print("\n=== Model Conversion Test ===")

    try:
        from utils.npu_acceleration import PyTorchToONNXConverter

        # Create test model
        model = create_test_model()
        model.eval()

        # Create converter
        converter = PyTorchToONNXConverter(model)

        # Convert to ONNX (portable temp location instead of a hard-coded /tmp)
        onnx_path = _temp_model_path("test_trading_model.onnx")
        input_shape = (NUM_FEATURES,)

        success = converter.convert(
            output_path=onnx_path,
            input_shape=input_shape,
            input_names=['trading_features'],
            output_names=['trading_signals']
        )

        if not success:
            print("❌ Model conversion failed")
            return False

        print("✅ Model conversion successful")

        # Verify the model
        if converter.verify_onnx_model(onnx_path, input_shape):
            print("✅ ONNX model verification successful")
            return True
        print("❌ ONNX model verification failed")
        return False

    except Exception as e:
        print(f"❌ Model conversion test failed: {e}")
        return False


def test_npu_acceleration():
    """Test NPU-accelerated inference on a single random sample.

    Returns:
        True if ``predict`` produced an output, False otherwise.
    """
    print("\n=== NPU Acceleration Test ===")

    try:
        from utils.npu_acceleration import NPUAcceleratedModel

        # Create test model
        model = create_test_model()
        model.eval()

        # Create NPU-accelerated model
        npu_model = NPUAcceleratedModel(
            pytorch_model=model,
            model_name="test_trading_model",
            input_shape=(NUM_FEATURES,)
        )

        # Test inference
        test_input = np.random.randn(1, NUM_FEATURES).astype(np.float32)
        output, inference_time = _timed_ms(npu_model.predict, test_input)

        if output is None:
            print("❌ NPU inference returned no output")
            return False

        print("✅ NPU inference successful")
        print(f"Inference time: {inference_time:.2f} ms")
        print(f"Output shape: {output.shape}")

        # Get performance info
        perf_info = npu_model.get_performance_info()
        print(f"Performance info: {perf_info}")

        return True

    except Exception as e:
        print(f"❌ NPU acceleration test failed: {e}")
        return False


def test_model_interfaces():
    """Test enhanced model interfaces with NPU support.

    Returns:
        True only if both the CNN and the RL interface return a prediction.
    """
    print("\n=== Model Interfaces Test ===")

    try:
        from NN.models.model_interfaces import CNNModelInterface, RLAgentInterface

        # Create test models
        cnn_model = create_test_model()
        rl_model = create_test_model()

        # Test CNN interface
        cnn_interface = CNNModelInterface(
            model=cnn_model,
            name="test_cnn",
            enable_npu=True,
            input_shape=(NUM_FEATURES,)
        )

        # Test RL interface
        rl_interface = RLAgentInterface(
            model=rl_model,
            name="test_rl",
            enable_npu=True,
            input_shape=(NUM_FEATURES,)
        )

        # Test predictions
        test_data = np.random.randn(1, NUM_FEATURES).astype(np.float32)

        cnn_ok = cnn_interface.predict(test_data) is not None
        rl_ok = rl_interface.predict(test_data) is not None

        # Report the real outcome instead of an unconditional success mark.
        print(f"{'✅' if cnn_ok else '❌'} CNN interface prediction: {cnn_ok}")
        print(f"{'✅' if rl_ok else '❌'} RL interface prediction: {rl_ok}")

        # Test acceleration info
        cnn_info = cnn_interface.get_acceleration_info()
        rl_info = rl_interface.get_acceleration_info()

        print(f"CNN acceleration info: {cnn_info}")
        print(f"RL acceleration info: {rl_info}")

        return cnn_ok and rl_ok

    except Exception as e:
        print(f"❌ Model interfaces test failed: {e}")
        return False


def benchmark_performance():
    """Benchmark NPU (ONNX) vs plain PyTorch CPU inference.

    Each backend gets one untimed warm-up call, then ``BENCHMARK_RUNS`` timed
    single-sample inferences. The NPU path runs only when the accelerated
    model exposes an ONNX model.

    Returns:
        True if the benchmark completed, False if anything raised.
    """
    print("\n=== Performance Benchmark ===")

    try:
        from utils.npu_acceleration import NPUAcceleratedModel

        # Create test model
        model = create_test_model()
        model.eval()

        # Create NPU-accelerated model
        npu_model = NPUAcceleratedModel(
            pytorch_model=model,
            model_name="benchmark_model",
            input_shape=(NUM_FEATURES,)
        )

        # Test data: one (1, NUM_FEATURES) slice per timed run
        test_data = np.random.randn(100, NUM_FEATURES).astype(np.float32)
        samples = [test_data[i:i + 1] for i in range(BENCHMARK_RUNS)]

        # Benchmark NPU inference
        avg_npu_time = None
        if npu_model.onnx_model:
            # Warm-up so any one-off first-call cost is not averaged in.
            npu_model.predict(samples[0])
            npu_times = [_timed_ms(npu_model.predict, s)[1] for s in samples]
            avg_npu_time = np.mean(npu_times)
            print(f"Average NPU inference time: {avg_npu_time:.2f} ms")

        # Benchmark CPU inference (numpy->tensor conversion is timed too, to
        # mirror the NPU path, which takes numpy input directly).
        def cpu_infer(sample):
            return model(torch.from_numpy(sample))

        with torch.no_grad():
            cpu_infer(samples[0])  # warm-up
            cpu_times = [_timed_ms(cpu_infer, s)[1] for s in samples]

        avg_cpu_time = np.mean(cpu_times)
        print(f"Average CPU inference time: {avg_cpu_time:.2f} ms")

        if avg_npu_time is not None:
            if avg_npu_time > 0:
                speedup = avg_cpu_time / avg_npu_time
                print(f"NPU speedup: {speedup:.2f}x")
            else:
                print("NPU speedup: n/a (NPU time below timer resolution)")

        return True

    except Exception as e:
        print(f"❌ Performance benchmark failed: {e}")
        return False


def test_integration_with_existing_models():
    """Test integration with existing trading models.

    Wraps ``EnhancedCNNModel`` in ``NPUAcceleratedModel`` and runs one
    inference on a random (1, 60, 50) input.

    Returns:
        True if inference produced an output, False otherwise.
    """
    print("\n=== Integration Test ===")

    try:
        # Test with existing CNN model
        from NN.models.cnn_model import EnhancedCNNModel

        # Create a small CNN model for testing
        cnn_model = EnhancedCNNModel(
            input_size=60,
            feature_dim=50,
            output_size=3
        )

        # Test NPU acceleration
        from utils.npu_acceleration import NPUAcceleratedModel

        npu_cnn = NPUAcceleratedModel(
            pytorch_model=cnn_model,
            model_name="enhanced_cnn_test",
            input_shape=(60, 50)
        )

        # Test inference
        test_input = np.random.randn(1, 60, 50).astype(np.float32)
        output = npu_cnn.predict(test_input)

        if output is None:
            print("❌ Enhanced CNN NPU integration returned no output")
            return False

        print("✅ Enhanced CNN NPU integration successful")
        print(f"Output shape: {output.shape}")

        return True

    except Exception as e:
        print(f"❌ Integration test failed: {e}")
        return False


def main():
    """Run all NPU tests and print a summary.

    Returns:
        True if every test passed, False otherwise.
    """
    print("Starting Strix Halo NPU Integration Tests...")
    print("=" * 50)

    tests = [
        ("NPU Detection", test_npu_detection),
        ("ONNX Runtime", test_onnx_runtime),
        ("Model Conversion", test_model_conversion),
        ("NPU Acceleration", test_npu_acceleration),
        ("Model Interfaces", test_model_interfaces),
        ("Performance Benchmark", benchmark_performance),
        ("Integration Test", test_integration_with_existing_models)
    ]

    results = {}
    for test_name, test_func in tests:
        try:
            results[test_name] = test_func()
        except Exception as e:
            print(f"❌ {test_name} failed with exception: {e}")
            results[test_name] = False

    # Summary
    print("\n" + "=" * 50)
    print("TEST SUMMARY")
    print("=" * 50)

    for test_name, result in results.items():
        status = "✅ PASS" if result else "❌ FAIL"
        print(f"{test_name}: {status}")

    passed = sum(1 for result in results.values() if result)
    total = len(tests)

    print(f"\nOverall: {passed}/{total} tests passed")

    if passed == total:
        print("🎉 All NPU integration tests passed!")
    else:
        print("⚠️ Some tests failed. Check the output above for details.")

    return passed == total


if __name__ == "__main__":
    success = main()
    sys.exit(0 if success else 1)