178 lines
5.3 KiB
Python
178 lines
5.3 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Quick NPU Integration Test for Orchestrator
|
|
Tests NPU acceleration with the existing orchestrator system
|
|
"""
|
|
import sys
|
|
import os
|
|
import logging
|
|
|
|
# Make project-local packages (utils, NN) importable.
# The directory containing this script is added first so the checkout does not
# have to live at one fixed location (assumes this script sits at the project
# root -- TODO confirm). The original hard-coded checkout path is kept as a
# fallback so existing setups keep working. Entries already on sys.path are not
# appended twice.
for _project_root in (
    os.path.dirname(os.path.abspath(__file__)),
    '/mnt/shared/DEV/repos/d-popov.com/gogo2',
):
    if _project_root not in sys.path:
        sys.path.append(_project_root)
del _project_root

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
def _print_acceleration_info(label, info):
    """Print a heading and each key/value pair of an acceleration-info mapping.

    Args:
        label: Heading prefix, e.g. "CNN" or "RL".
        info: Value returned by an interface's ``get_acceleration_info()``;
            only ``.items()`` is used (presumably a dict -- TODO confirm).
    """
    print(f"\n{label} Acceleration Info:")
    for key, value in info.items():
        print(f" {key}: {value}")


def test_orchestrator_npu_integration():
    """Test NPU integration with the orchestrator's model interfaces.

    Detects the NPU, wraps a tiny 50 -> 3 linear model in both
    ``CNNModelInterface`` and ``RLAgentInterface`` with ``enable_npu=True``,
    runs one prediction through each, and prints their acceleration info.
    When no NPU is available the same path exercises the fallback behavior.

    Returns:
        True if every step ran and both interfaces returned a non-None
        prediction; False if either prediction was None or any step raised
        (the exception is logged with its traceback).
    """
    print("=== Orchestrator NPU Integration Test ===")

    try:
        # Project imports live inside the try so a missing or broken module is
        # reported as a failed check instead of aborting the whole run.
        from utils.npu_detector import is_npu_available, get_npu_info

        npu_available = is_npu_available()
        npu_info = get_npu_info()

        print(f"NPU Available: {npu_available}")
        print(f"NPU Info: {npu_info}")

        if not npu_available:
            print("⚠️ NPU not available, testing fallback behavior")

        # Test model interfaces with NPU support
        from NN.models.model_interfaces import CNNModelInterface, RLAgentInterface

        import torch.nn as nn

        class TestModel(nn.Module):
            """Minimal single-layer model (50 inputs -> 3 outputs)."""

            def __init__(self):
                super().__init__()
                self.fc = nn.Linear(50, 3)

            def forward(self, x):
                return self.fc(x)

        test_model = TestModel()

        print("\nTesting CNN interface with NPU...")
        cnn_interface = CNNModelInterface(
            model=test_model,
            name="test_cnn",
            enable_npu=True,
            input_shape=(50,),
        )

        print("Testing RL interface with NPU...")
        rl_interface = RLAgentInterface(
            model=test_model,
            name="test_rl",
            enable_npu=True,
            input_shape=(50,),
        )

        # One batch of 50 float32 features, matching the model's input width.
        import numpy as np
        test_data = np.random.randn(1, 50).astype(np.float32)

        cnn_output = cnn_interface.predict(test_data)
        rl_output = rl_interface.predict(test_data)

        # A None prediction means the interface did not work, so report it as
        # such and let it decide the overall result.
        cnn_ok = cnn_output is not None
        rl_ok = rl_output is not None
        print(f"{'✅' if cnn_ok else '❌'} CNN interface working: {cnn_ok}")
        print(f"{'✅' if rl_ok else '❌'} RL interface working: {rl_ok}")

        cnn_info = cnn_interface.get_acceleration_info()
        rl_info = rl_interface.get_acceleration_info()
        _print_acceleration_info("CNN", cnn_info)
        _print_acceleration_info("RL", rl_info)

        return cnn_ok and rl_ok

    except Exception as e:
        # Top-level boundary for this check: report, keep the traceback,
        # and signal failure to main() via the return value.
        print(f"❌ Orchestrator NPU integration test failed: {e}")
        logger.exception("Detailed error:")
        return False
|
|
|
|
def test_dashboard_npu_status():
    """Test assembling the NPU status payload the dashboard would display.

    Reads NPU info and the ONNX providers from ``utils.npu_detector`` and
    prints a status dict with the keys ``npu_available``, ``providers`` and
    ``status`` ('active' when the NPU is available, else 'inactive').

    Returns:
        True if the status could be assembled; False if any step raised
        (including a missing ``'available'`` key in the NPU info). On failure,
        the traceback is logged.
    """
    print("\n=== Dashboard NPU Status Test ===")

    try:
        # Test NPU detection for dashboard
        from utils.npu_detector import get_npu_info, get_onnx_providers

        npu_info = get_npu_info()
        providers = get_onnx_providers()

        print("NPU Status for Dashboard:")
        # Indexed rather than .get(): a missing key should fail this check.
        available = npu_info['available']
        print(f" Available: {available}")
        print(f" Providers: {providers}")

        # This would be integrated into the dashboard
        dashboard_status = {
            'npu_available': available,
            'providers': providers,
            'status': 'active' if available else 'inactive',
        }

        print(f"Dashboard Status: {dashboard_status}")

        return True

    except Exception as e:
        print(f"❌ Dashboard NPU status test failed: {e}")
        # Keep the traceback, consistent with the orchestrator check.
        logger.exception("Detailed error:")
        return False
|
|
|
|
def main():
    """Run each orchestrator NPU check, print a summary, and report success.

    Every check is called in order. An exception that escapes a check is
    printed and counted as a failure.

    Returns:
        True only if every check returned a truthy result.
    """
    print("Starting Orchestrator NPU Integration Tests...")
    banner = "=" * 50
    print(banner)

    checks = (
        ("Orchestrator Integration", test_orchestrator_npu_integration),
        ("Dashboard Status", test_dashboard_npu_status),
    )

    outcomes = {}
    for label, check in checks:
        try:
            outcome = check()
        except Exception as exc:
            # The checks catch their own errors; this is a last-resort guard.
            print(f"❌ {label} failed with exception: {exc}")
            outcome = False
        outcomes[label] = outcome

    print("\n" + banner)
    print("ORCHESTRATOR NPU INTEGRATION SUMMARY")
    print(banner)

    for label, outcome in outcomes.items():
        verdict = "✅ PASS" if outcome else "❌ FAIL"
        print(f"{label}: {verdict}")

    passed_count = sum(1 for outcome in outcomes.values() if outcome)
    total_count = len(checks)
    print(f"\nOverall: {passed_count}/{total_count} tests passed")

    all_passed = passed_count == total_count
    if not all_passed:
        print("⚠️ Some integration tests failed. Check the output above.")
    else:
        print("🎉 Orchestrator NPU integration successful!")
        for line in (
            "\nNext steps:",
            "1. Run the full integration test: python3 test_npu_integration.py",
            "2. Start your trading system with NPU acceleration",
            "3. Monitor NPU performance in the dashboard",
        ):
            print(line)

    return all_passed
|
|
|
|
if __name__ == "__main__":
    # Exit status 0 when every check passed, 1 otherwise (for shell/CI use).
    sys.exit(0 if main() else 1)
|
|
|