NPU (wip); docker
This commit is contained in:
17
.vscode/launch.json
vendored
17
.vscode/launch.json
vendored
@@ -79,7 +79,6 @@
|
||||
"TEST_ALL_COMPONENTS": "1"
|
||||
}
|
||||
},
|
||||
|
||||
{
|
||||
"name": "🧪 CNN Live Training with Analysis",
|
||||
"type": "python",
|
||||
@@ -194,8 +193,22 @@
|
||||
"group": "Universal Data Stream",
|
||||
"order": 2
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Containers: Python - General",
|
||||
"type": "docker",
|
||||
"request": "launch",
|
||||
"preLaunchTask": "docker-run: debug",
|
||||
"python": {
|
||||
"pathMappings": [
|
||||
{
|
||||
"localRoot": "${workspaceFolder}",
|
||||
"remoteRoot": "/app"
|
||||
}
|
||||
],
|
||||
"projectType": "general"
|
||||
}
|
||||
}
|
||||
|
||||
],
|
||||
"compounds": [
|
||||
{
|
||||
|
21
.vscode/tasks.json
vendored
21
.vscode/tasks.json
vendored
@@ -136,6 +136,27 @@
|
||||
"endsPattern": ".*Dashboard.*ready.*"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "docker-build",
|
||||
"label": "docker-build",
|
||||
"platform": "python",
|
||||
"dockerBuild": {
|
||||
"tag": "gogo2:latest",
|
||||
"dockerfile": "${workspaceFolder}/Dockerfile",
|
||||
"context": "${workspaceFolder}",
|
||||
"pull": true
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "docker-run",
|
||||
"label": "docker-run: debug",
|
||||
"dependsOn": [
|
||||
"docker-build"
|
||||
],
|
||||
"python": {
|
||||
"file": "run_clean_dashboard.py"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
@@ -3,20 +3,64 @@ Model Interfaces Module
|
||||
|
||||
Defines abstract base classes and concrete implementations for various model types
|
||||
to ensure consistent interaction within the trading system.
|
||||
Includes NPU acceleration support for Strix Halo processors.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, List
|
||||
import os
|
||||
from typing import Dict, Any, Optional, List, Union
|
||||
from abc import ABC, abstractmethod
|
||||
import numpy as np
|
||||
|
||||
# Try to import NPU acceleration utilities
|
||||
try:
|
||||
from utils.npu_acceleration import NPUAcceleratedModel, is_npu_available
|
||||
from utils.npu_detector import get_npu_info
|
||||
HAS_NPU_SUPPORT = True
|
||||
except ImportError:
|
||||
HAS_NPU_SUPPORT = False
|
||||
NPUAcceleratedModel = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ModelInterface(ABC):
|
||||
"""Base interface for all models"""
|
||||
"""Base interface for all models with NPU acceleration support"""
|
||||
|
||||
def __init__(self, name: str):
|
||||
def __init__(self, name: str, enable_npu: bool = True):
|
||||
self.name = name
|
||||
self.enable_npu = enable_npu and HAS_NPU_SUPPORT
|
||||
self.npu_model = None
|
||||
self.npu_available = False
|
||||
|
||||
# Initialize NPU acceleration if available
|
||||
if self.enable_npu:
|
||||
self._setup_npu_acceleration()
|
||||
|
||||
def _setup_npu_acceleration(self):
|
||||
"""Setup NPU acceleration for this model"""
|
||||
try:
|
||||
if HAS_NPU_SUPPORT and is_npu_available():
|
||||
self.npu_available = True
|
||||
logger.info(f"NPU acceleration available for model: {self.name}")
|
||||
else:
|
||||
logger.info(f"NPU acceleration not available for model: {self.name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to setup NPU acceleration: {e}")
|
||||
self.npu_available = False
|
||||
|
||||
def get_acceleration_info(self) -> Dict[str, Any]:
|
||||
"""Get acceleration information"""
|
||||
info = {
|
||||
'model_name': self.name,
|
||||
'npu_support_available': HAS_NPU_SUPPORT,
|
||||
'npu_enabled': self.enable_npu,
|
||||
'npu_available': self.npu_available
|
||||
}
|
||||
|
||||
if HAS_NPU_SUPPORT:
|
||||
info.update(get_npu_info())
|
||||
|
||||
return info
|
||||
|
||||
@abstractmethod
|
||||
def predict(self, data):
|
||||
@@ -29,15 +73,39 @@ class ModelInterface(ABC):
|
||||
pass
|
||||
|
||||
class CNNModelInterface(ModelInterface):
|
||||
"""Interface for CNN models"""
|
||||
"""Interface for CNN models with NPU acceleration support"""
|
||||
|
||||
def __init__(self, model, name: str):
|
||||
super().__init__(name)
|
||||
def __init__(self, model, name: str, enable_npu: bool = True, input_shape: tuple = None):
|
||||
super().__init__(name, enable_npu)
|
||||
self.model = model
|
||||
self.input_shape = input_shape
|
||||
|
||||
# Setup NPU acceleration for CNN model
|
||||
if self.enable_npu and self.npu_available and input_shape:
|
||||
self._setup_cnn_npu_acceleration()
|
||||
|
||||
def _setup_cnn_npu_acceleration(self):
|
||||
"""Setup NPU acceleration for CNN model"""
|
||||
try:
|
||||
if HAS_NPU_SUPPORT and NPUAcceleratedModel:
|
||||
self.npu_model = NPUAcceleratedModel(
|
||||
pytorch_model=self.model,
|
||||
model_name=f"{self.name}_cnn",
|
||||
input_shape=self.input_shape
|
||||
)
|
||||
logger.info(f"CNN NPU acceleration setup for: {self.name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to setup CNN NPU acceleration: {e}")
|
||||
self.npu_model = None
|
||||
|
||||
def predict(self, data):
|
||||
"""Make CNN prediction"""
|
||||
"""Make CNN prediction with NPU acceleration if available"""
|
||||
try:
|
||||
# Use NPU acceleration if available
|
||||
if self.npu_model and self.npu_available:
|
||||
return self.npu_model.predict(data)
|
||||
|
||||
# Fallback to original model
|
||||
if hasattr(self.model, 'predict'):
|
||||
return self.model.predict(data)
|
||||
return None
|
||||
@@ -47,18 +115,48 @@ class CNNModelInterface(ModelInterface):
|
||||
|
||||
def get_memory_usage(self) -> float:
|
||||
"""Estimate CNN memory usage"""
|
||||
return 50.0 # MB
|
||||
base_memory = 50.0 # MB
|
||||
|
||||
# Add NPU memory overhead if using NPU acceleration
|
||||
if self.npu_model:
|
||||
base_memory += 25.0 # Additional NPU memory
|
||||
|
||||
return base_memory
|
||||
|
||||
class RLAgentInterface(ModelInterface):
|
||||
"""Interface for RL agents"""
|
||||
"""Interface for RL agents with NPU acceleration support"""
|
||||
|
||||
def __init__(self, model, name: str):
|
||||
super().__init__(name)
|
||||
def __init__(self, model, name: str, enable_npu: bool = True, input_shape: tuple = None):
|
||||
super().__init__(name, enable_npu)
|
||||
self.model = model
|
||||
self.input_shape = input_shape
|
||||
|
||||
# Setup NPU acceleration for RL model
|
||||
if self.enable_npu and self.npu_available and input_shape:
|
||||
self._setup_rl_npu_acceleration()
|
||||
|
||||
def _setup_rl_npu_acceleration(self):
|
||||
"""Setup NPU acceleration for RL model"""
|
||||
try:
|
||||
if HAS_NPU_SUPPORT and NPUAcceleratedModel:
|
||||
self.npu_model = NPUAcceleratedModel(
|
||||
pytorch_model=self.model,
|
||||
model_name=f"{self.name}_rl",
|
||||
input_shape=self.input_shape
|
||||
)
|
||||
logger.info(f"RL NPU acceleration setup for: {self.name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to setup RL NPU acceleration: {e}")
|
||||
self.npu_model = None
|
||||
|
||||
def predict(self, data):
|
||||
"""Make RL prediction"""
|
||||
"""Make RL prediction with NPU acceleration if available"""
|
||||
try:
|
||||
# Use NPU acceleration if available
|
||||
if self.npu_model and self.npu_available:
|
||||
return self.npu_model.predict(data)
|
||||
|
||||
# Fallback to original model
|
||||
if hasattr(self.model, 'act'):
|
||||
return self.model.act(data)
|
||||
elif hasattr(self.model, 'predict'):
|
||||
@@ -70,7 +168,13 @@ class RLAgentInterface(ModelInterface):
|
||||
|
||||
def get_memory_usage(self) -> float:
|
||||
"""Estimate RL memory usage"""
|
||||
return 25.0 # MB
|
||||
base_memory = 25.0 # MB
|
||||
|
||||
# Add NPU memory overhead if using NPU acceleration
|
||||
if self.npu_model:
|
||||
base_memory += 15.0 # Additional NPU memory
|
||||
|
||||
return base_memory
|
||||
|
||||
class ExtremaTrainerInterface(ModelInterface):
|
||||
"""Interface for ExtremaTrainer models, providing context features"""
|
||||
|
57
test_amd_gpu.sh
Normal file
57
test_amd_gpu.sh
Normal file
@@ -0,0 +1,57 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Test AMD GPU setup for Docker Model Runner
|
||||
echo "=== AMD GPU Setup Test ==="
|
||||
echo ""
|
||||
|
||||
# Check if AMD GPU devices are available
|
||||
echo "Checking AMD GPU devices..."
|
||||
if [[ -e /dev/kfd ]]; then
|
||||
echo "✅ /dev/kfd (AMD GPU compute) is available"
|
||||
else
|
||||
echo "❌ /dev/kfd not found - AMD GPU compute not available"
|
||||
fi
|
||||
|
||||
if [[ -e /dev/dri/renderD128 ]] || [[ -e /dev/dri/card0 ]]; then
|
||||
echo "✅ /dev/dri (AMD GPU graphics) is available"
|
||||
else
|
||||
echo "❌ /dev/dri not found - AMD GPU graphics not available"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "Checking user groups..."
|
||||
if groups | grep -q video; then
|
||||
echo "✅ User is in 'video' group for GPU access"
|
||||
else
|
||||
echo "⚠️ User is not in 'video' group - may need: sudo usermod -aG video $USER"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "Testing Docker with AMD GPU..."
|
||||
# Test if docker can access AMD GPU devices
|
||||
if docker run --rm --device /dev/kfd:/dev/kfd --device /dev/dri:/dev/dri alpine ls /dev/kfd /dev/dri 2>/dev/null | grep -q kfd; then
|
||||
echo "✅ Docker can access AMD GPU devices"
|
||||
else
|
||||
echo "❌ Docker cannot access AMD GPU devices"
|
||||
echo " Try: sudo chmod 666 /dev/kfd /dev/dri/*"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "=== Environment Variables ==="
|
||||
echo "DISPLAY: $DISPLAY"
|
||||
echo "USER: $USER"
|
||||
echo "HSA_OVERRIDE_GFX_VERSION: ${HSA_OVERRIDE_GFX_VERSION:-not set}"
|
||||
|
||||
echo ""
|
||||
echo "=== Next Steps ==="
|
||||
echo "If tests failed, try:"
|
||||
echo "1. sudo usermod -aG video $USER"
|
||||
echo "2. sudo chmod 666 /dev/kfd /dev/dri/*"
|
||||
echo "3. Reboot or logout/login"
|
||||
echo ""
|
||||
echo "Then start the model runner:"
|
||||
echo "docker-compose up -d docker-model-runner"
|
||||
echo ""
|
||||
echo "Test API access:"
|
||||
echo "curl http://localhost:11434/api/tags"
|
||||
echo "curl http://localhost:8083/api/tags"
|
80
test_npu.py
Normal file
80
test_npu.py
Normal file
@@ -0,0 +1,80 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for Strix Halo NPU functionality
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
sys.path.append('/mnt/shared/DEV/repos/d-popov.com/gogo2')
|
||||
|
||||
from utils.npu_detector import get_npu_info, is_npu_available, get_onnx_providers
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_npu_detection():
|
||||
"""Test NPU detection"""
|
||||
print("=== NPU Detection Test ===")
|
||||
|
||||
info = get_npu_info()
|
||||
print(f"NPU Available: {info['available']}")
|
||||
print(f"NPU Info: {info['info']}")
|
||||
|
||||
if is_npu_available():
|
||||
print("✅ NPU is available!")
|
||||
else:
|
||||
print("❌ NPU not available")
|
||||
|
||||
return info['available']
|
||||
|
||||
def test_onnx_providers():
|
||||
"""Test ONNX providers"""
|
||||
print("\n=== ONNX Providers Test ===")
|
||||
|
||||
providers = get_onnx_providers()
|
||||
print(f"Available providers: {providers}")
|
||||
|
||||
try:
|
||||
import onnxruntime as ort
|
||||
print(f"ONNX Runtime version: {ort.__version__}")
|
||||
|
||||
# Test creating a session with NPU provider
|
||||
if 'DmlExecutionProvider' in providers:
|
||||
print("✅ DirectML provider available for NPU")
|
||||
else:
|
||||
print("❌ DirectML provider not available")
|
||||
|
||||
except ImportError:
|
||||
print("❌ ONNX Runtime not installed")
|
||||
|
||||
def test_simple_inference():
|
||||
"""Test simple inference with NPU"""
|
||||
print("\n=== Simple Inference Test ===")
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
# Create a simple model for testing
|
||||
providers = get_onnx_providers()
|
||||
|
||||
# Test with a simple tensor
|
||||
test_input = np.random.randn(1, 10).astype(np.float32)
|
||||
print(f"Test input shape: {test_input.shape}")
|
||||
|
||||
# This would be replaced with actual model loading
|
||||
print("✅ Basic inference setup successful")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Inference test failed: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Testing Strix Halo NPU Setup...")
|
||||
|
||||
npu_available = test_npu_detection()
|
||||
test_onnx_providers()
|
||||
|
||||
if npu_available:
|
||||
test_simple_inference()
|
||||
|
||||
print("\n=== Test Complete ===")
|
370
test_npu_integration.py
Normal file
370
test_npu_integration.py
Normal file
@@ -0,0 +1,370 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Comprehensive NPU Integration Test for Strix Halo
|
||||
Tests NPU acceleration with your trading models
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# Add project root to path
|
||||
sys.path.append('/mnt/shared/DEV/repos/d-popov.com/gogo2')
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_npu_detection():
|
||||
"""Test NPU detection and setup"""
|
||||
print("=== NPU Detection Test ===")
|
||||
|
||||
try:
|
||||
from utils.npu_detector import get_npu_info, is_npu_available, get_onnx_providers
|
||||
|
||||
info = get_npu_info()
|
||||
print(f"NPU Available: {info['available']}")
|
||||
print(f"NPU Info: {info['info']}")
|
||||
|
||||
providers = get_onnx_providers()
|
||||
print(f"ONNX Providers: {providers}")
|
||||
|
||||
if is_npu_available():
|
||||
print("✅ NPU is available!")
|
||||
return True
|
||||
else:
|
||||
print("❌ NPU not available")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ NPU detection failed: {e}")
|
||||
return False
|
||||
|
||||
def test_onnx_runtime():
|
||||
"""Test ONNX Runtime functionality"""
|
||||
print("\n=== ONNX Runtime Test ===")
|
||||
|
||||
try:
|
||||
import onnxruntime as ort
|
||||
print(f"ONNX Runtime version: {ort.__version__}")
|
||||
|
||||
# Test providers
|
||||
providers = ort.get_available_providers()
|
||||
print(f"Available providers: {providers}")
|
||||
|
||||
# Test DirectML provider
|
||||
if 'DmlExecutionProvider' in providers:
|
||||
print("✅ DirectML provider available")
|
||||
else:
|
||||
print("❌ DirectML provider not available")
|
||||
|
||||
return True
|
||||
|
||||
except ImportError:
|
||||
print("❌ ONNX Runtime not installed")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"❌ ONNX Runtime test failed: {e}")
|
||||
return False
|
||||
|
||||
def create_test_model():
|
||||
"""Create a simple test model for NPU testing"""
|
||||
class SimpleTradingModel(nn.Module):
|
||||
def __init__(self, input_size=50, hidden_size=128, output_size=3):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(input_size, hidden_size)
|
||||
self.fc2 = nn.Linear(hidden_size, hidden_size)
|
||||
self.fc3 = nn.Linear(hidden_size, output_size)
|
||||
self.relu = nn.ReLU()
|
||||
self.dropout = nn.Dropout(0.1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.relu(self.fc1(x))
|
||||
x = self.dropout(x)
|
||||
x = self.relu(self.fc2(x))
|
||||
x = self.dropout(x)
|
||||
x = self.fc3(x)
|
||||
return x
|
||||
|
||||
return SimpleTradingModel()
|
||||
|
||||
def test_model_conversion():
|
||||
"""Test PyTorch to ONNX conversion"""
|
||||
print("\n=== Model Conversion Test ===")
|
||||
|
||||
try:
|
||||
from utils.npu_acceleration import PyTorchToONNXConverter
|
||||
|
||||
# Create test model
|
||||
model = create_test_model()
|
||||
model.eval()
|
||||
|
||||
# Create converter
|
||||
converter = PyTorchToONNXConverter(model)
|
||||
|
||||
# Convert to ONNX
|
||||
onnx_path = "/tmp/test_trading_model.onnx"
|
||||
input_shape = (50,) # 50 features
|
||||
|
||||
success = converter.convert(
|
||||
output_path=onnx_path,
|
||||
input_shape=input_shape,
|
||||
input_names=['trading_features'],
|
||||
output_names=['trading_signals']
|
||||
)
|
||||
|
||||
if success:
|
||||
print("✅ Model conversion successful")
|
||||
|
||||
# Verify the model
|
||||
if converter.verify_onnx_model(onnx_path, input_shape):
|
||||
print("✅ ONNX model verification successful")
|
||||
return True
|
||||
else:
|
||||
print("❌ ONNX model verification failed")
|
||||
return False
|
||||
else:
|
||||
print("❌ Model conversion failed")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Model conversion test failed: {e}")
|
||||
return False
|
||||
|
||||
def test_npu_acceleration():
|
||||
"""Test NPU-accelerated inference"""
|
||||
print("\n=== NPU Acceleration Test ===")
|
||||
|
||||
try:
|
||||
from utils.npu_acceleration import NPUAcceleratedModel
|
||||
|
||||
# Create test model
|
||||
model = create_test_model()
|
||||
model.eval()
|
||||
|
||||
# Create NPU-accelerated model
|
||||
npu_model = NPUAcceleratedModel(
|
||||
pytorch_model=model,
|
||||
model_name="test_trading_model",
|
||||
input_shape=(50,)
|
||||
)
|
||||
|
||||
# Test inference
|
||||
test_input = np.random.randn(1, 50).astype(np.float32)
|
||||
|
||||
start_time = time.time()
|
||||
output = npu_model.predict(test_input)
|
||||
inference_time = (time.time() - start_time) * 1000 # ms
|
||||
|
||||
print(f"✅ NPU inference successful")
|
||||
print(f"Inference time: {inference_time:.2f} ms")
|
||||
print(f"Output shape: {output.shape}")
|
||||
|
||||
# Get performance info
|
||||
perf_info = npu_model.get_performance_info()
|
||||
print(f"Performance info: {perf_info}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ NPU acceleration test failed: {e}")
|
||||
return False
|
||||
|
||||
def test_model_interfaces():
|
||||
"""Test enhanced model interfaces with NPU support"""
|
||||
print("\n=== Model Interfaces Test ===")
|
||||
|
||||
try:
|
||||
from NN.models.model_interfaces import CNNModelInterface, RLAgentInterface
|
||||
|
||||
# Create test models
|
||||
cnn_model = create_test_model()
|
||||
rl_model = create_test_model()
|
||||
|
||||
# Test CNN interface
|
||||
cnn_interface = CNNModelInterface(
|
||||
model=cnn_model,
|
||||
name="test_cnn",
|
||||
enable_npu=True,
|
||||
input_shape=(50,)
|
||||
)
|
||||
|
||||
# Test RL interface
|
||||
rl_interface = RLAgentInterface(
|
||||
model=rl_model,
|
||||
name="test_rl",
|
||||
enable_npu=True,
|
||||
input_shape=(50,)
|
||||
)
|
||||
|
||||
# Test predictions
|
||||
test_data = np.random.randn(1, 50).astype(np.float32)
|
||||
|
||||
cnn_output = cnn_interface.predict(test_data)
|
||||
rl_output = rl_interface.predict(test_data)
|
||||
|
||||
print(f"✅ CNN interface prediction: {cnn_output is not None}")
|
||||
print(f"✅ RL interface prediction: {rl_output is not None}")
|
||||
|
||||
# Test acceleration info
|
||||
cnn_info = cnn_interface.get_acceleration_info()
|
||||
rl_info = rl_interface.get_acceleration_info()
|
||||
|
||||
print(f"CNN acceleration info: {cnn_info}")
|
||||
print(f"RL acceleration info: {rl_info}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Model interfaces test failed: {e}")
|
||||
return False
|
||||
|
||||
def benchmark_performance():
|
||||
"""Benchmark NPU vs CPU performance"""
|
||||
print("\n=== Performance Benchmark ===")
|
||||
|
||||
try:
|
||||
from utils.npu_acceleration import NPUAcceleratedModel
|
||||
|
||||
# Create test model
|
||||
model = create_test_model()
|
||||
model.eval()
|
||||
|
||||
# Create NPU-accelerated model
|
||||
npu_model = NPUAcceleratedModel(
|
||||
pytorch_model=model,
|
||||
model_name="benchmark_model",
|
||||
input_shape=(50,)
|
||||
)
|
||||
|
||||
# Test data
|
||||
test_data = np.random.randn(100, 50).astype(np.float32)
|
||||
|
||||
# Benchmark NPU inference
|
||||
if npu_model.onnx_model:
|
||||
npu_times = []
|
||||
for i in range(10):
|
||||
start_time = time.time()
|
||||
npu_model.predict(test_data[i:i+1])
|
||||
npu_times.append((time.time() - start_time) * 1000)
|
||||
|
||||
avg_npu_time = np.mean(npu_times)
|
||||
print(f"Average NPU inference time: {avg_npu_time:.2f} ms")
|
||||
|
||||
# Benchmark CPU inference
|
||||
cpu_times = []
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for i in range(10):
|
||||
start_time = time.time()
|
||||
input_tensor = torch.from_numpy(test_data[i:i+1])
|
||||
model(input_tensor)
|
||||
cpu_times.append((time.time() - start_time) * 1000)
|
||||
|
||||
avg_cpu_time = np.mean(cpu_times)
|
||||
print(f"Average CPU inference time: {avg_cpu_time:.2f} ms")
|
||||
|
||||
if npu_model.onnx_model:
|
||||
speedup = avg_cpu_time / avg_npu_time
|
||||
print(f"NPU speedup: {speedup:.2f}x")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Performance benchmark failed: {e}")
|
||||
return False
|
||||
|
||||
def test_integration_with_existing_models():
|
||||
"""Test integration with existing trading models"""
|
||||
print("\n=== Integration Test ===")
|
||||
|
||||
try:
|
||||
# Test with existing CNN model
|
||||
from NN.models.cnn_model import EnhancedCNNModel
|
||||
|
||||
# Create a small CNN model for testing
|
||||
cnn_model = EnhancedCNNModel(
|
||||
input_size=60,
|
||||
feature_dim=50,
|
||||
output_size=3
|
||||
)
|
||||
|
||||
# Test NPU acceleration
|
||||
from utils.npu_acceleration import NPUAcceleratedModel
|
||||
|
||||
npu_cnn = NPUAcceleratedModel(
|
||||
pytorch_model=cnn_model,
|
||||
model_name="enhanced_cnn_test",
|
||||
input_shape=(60, 50)
|
||||
)
|
||||
|
||||
# Test inference
|
||||
test_input = np.random.randn(1, 60, 50).astype(np.float32)
|
||||
output = npu_cnn.predict(test_input)
|
||||
|
||||
print(f"✅ Enhanced CNN NPU integration successful")
|
||||
print(f"Output shape: {output.shape}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Integration test failed: {e}")
|
||||
return False
|
||||
|
||||
def main():
|
||||
"""Run all NPU tests"""
|
||||
print("Starting Strix Halo NPU Integration Tests...")
|
||||
print("=" * 50)
|
||||
|
||||
tests = [
|
||||
("NPU Detection", test_npu_detection),
|
||||
("ONNX Runtime", test_onnx_runtime),
|
||||
("Model Conversion", test_model_conversion),
|
||||
("NPU Acceleration", test_npu_acceleration),
|
||||
("Model Interfaces", test_model_interfaces),
|
||||
("Performance Benchmark", benchmark_performance),
|
||||
("Integration Test", test_integration_with_existing_models)
|
||||
]
|
||||
|
||||
results = {}
|
||||
|
||||
for test_name, test_func in tests:
|
||||
try:
|
||||
results[test_name] = test_func()
|
||||
except Exception as e:
|
||||
print(f"❌ {test_name} failed with exception: {e}")
|
||||
results[test_name] = False
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 50)
|
||||
print("TEST SUMMARY")
|
||||
print("=" * 50)
|
||||
|
||||
passed = 0
|
||||
total = len(tests)
|
||||
|
||||
for test_name, result in results.items():
|
||||
status = "✅ PASS" if result else "❌ FAIL"
|
||||
print(f"{test_name}: {status}")
|
||||
if result:
|
||||
passed += 1
|
||||
|
||||
print(f"\nOverall: {passed}/{total} tests passed")
|
||||
|
||||
if passed == total:
|
||||
print("🎉 All NPU integration tests passed!")
|
||||
else:
|
||||
print("⚠️ Some tests failed. Check the output above for details.")
|
||||
|
||||
return passed == total
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = main()
|
||||
sys.exit(0 if success else 1)
|
||||
|
177
test_orchestrator_npu.py
Normal file
177
test_orchestrator_npu.py
Normal file
@@ -0,0 +1,177 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Quick NPU Integration Test for Orchestrator
|
||||
Tests NPU acceleration with the existing orchestrator system
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
|
||||
# Add project root to path
|
||||
sys.path.append('/mnt/shared/DEV/repos/d-popov.com/gogo2')
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_orchestrator_npu_integration():
|
||||
"""Test NPU integration with orchestrator"""
|
||||
print("=== Orchestrator NPU Integration Test ===")
|
||||
|
||||
try:
|
||||
# Test NPU detection
|
||||
from utils.npu_detector import is_npu_available, get_npu_info
|
||||
|
||||
npu_available = is_npu_available()
|
||||
npu_info = get_npu_info()
|
||||
|
||||
print(f"NPU Available: {npu_available}")
|
||||
print(f"NPU Info: {npu_info}")
|
||||
|
||||
if not npu_available:
|
||||
print("⚠️ NPU not available, testing fallback behavior")
|
||||
|
||||
# Test model interfaces with NPU support
|
||||
from NN.models.model_interfaces import CNNModelInterface, RLAgentInterface
|
||||
|
||||
# Create a simple test model
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class TestModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.fc = nn.Linear(50, 3)
|
||||
|
||||
def forward(self, x):
|
||||
return self.fc(x)
|
||||
|
||||
test_model = TestModel()
|
||||
|
||||
# Test CNN interface
|
||||
print("\nTesting CNN interface with NPU...")
|
||||
cnn_interface = CNNModelInterface(
|
||||
model=test_model,
|
||||
name="test_cnn",
|
||||
enable_npu=True,
|
||||
input_shape=(50,)
|
||||
)
|
||||
|
||||
# Test RL interface
|
||||
print("Testing RL interface with NPU...")
|
||||
rl_interface = RLAgentInterface(
|
||||
model=test_model,
|
||||
name="test_rl",
|
||||
enable_npu=True,
|
||||
input_shape=(50,)
|
||||
)
|
||||
|
||||
# Test predictions
|
||||
import numpy as np
|
||||
test_data = np.random.randn(1, 50).astype(np.float32)
|
||||
|
||||
cnn_output = cnn_interface.predict(test_data)
|
||||
rl_output = rl_interface.predict(test_data)
|
||||
|
||||
print(f"✅ CNN interface working: {cnn_output is not None}")
|
||||
print(f"✅ RL interface working: {rl_output is not None}")
|
||||
|
||||
# Test acceleration info
|
||||
cnn_info = cnn_interface.get_acceleration_info()
|
||||
rl_info = rl_interface.get_acceleration_info()
|
||||
|
||||
print(f"\nCNN Acceleration Info:")
|
||||
for key, value in cnn_info.items():
|
||||
print(f" {key}: {value}")
|
||||
|
||||
print(f"\nRL Acceleration Info:")
|
||||
for key, value in rl_info.items():
|
||||
print(f" {key}: {value}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Orchestrator NPU integration test failed: {e}")
|
||||
logger.exception("Detailed error:")
|
||||
return False
|
||||
|
||||
def test_dashboard_npu_status():
|
||||
"""Test NPU status display in dashboard"""
|
||||
print("\n=== Dashboard NPU Status Test ===")
|
||||
|
||||
try:
|
||||
# Test NPU detection for dashboard
|
||||
from utils.npu_detector import get_npu_info, get_onnx_providers
|
||||
|
||||
npu_info = get_npu_info()
|
||||
providers = get_onnx_providers()
|
||||
|
||||
print(f"NPU Status for Dashboard:")
|
||||
print(f" Available: {npu_info['available']}")
|
||||
print(f" Providers: {providers}")
|
||||
|
||||
# This would be integrated into the dashboard
|
||||
dashboard_status = {
|
||||
'npu_available': npu_info['available'],
|
||||
'providers': providers,
|
||||
'status': 'active' if npu_info['available'] else 'inactive'
|
||||
}
|
||||
|
||||
print(f"Dashboard Status: {dashboard_status}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Dashboard NPU status test failed: {e}")
|
||||
return False
|
||||
|
||||
def main():
|
||||
"""Run orchestrator NPU integration tests"""
|
||||
print("Starting Orchestrator NPU Integration Tests...")
|
||||
print("=" * 50)
|
||||
|
||||
tests = [
|
||||
("Orchestrator Integration", test_orchestrator_npu_integration),
|
||||
("Dashboard Status", test_dashboard_npu_status)
|
||||
]
|
||||
|
||||
results = {}
|
||||
|
||||
for test_name, test_func in tests:
|
||||
try:
|
||||
results[test_name] = test_func()
|
||||
except Exception as e:
|
||||
print(f"❌ {test_name} failed with exception: {e}")
|
||||
results[test_name] = False
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 50)
|
||||
print("ORCHESTRATOR NPU INTEGRATION SUMMARY")
|
||||
print("=" * 50)
|
||||
|
||||
passed = 0
|
||||
total = len(tests)
|
||||
|
||||
for test_name, result in results.items():
|
||||
status = "✅ PASS" if result else "❌ FAIL"
|
||||
print(f"{test_name}: {status}")
|
||||
if result:
|
||||
passed += 1
|
||||
|
||||
print(f"\nOverall: {passed}/{total} tests passed")
|
||||
|
||||
if passed == total:
|
||||
print("🎉 Orchestrator NPU integration successful!")
|
||||
print("\nNext steps:")
|
||||
print("1. Run the full integration test: python3 test_npu_integration.py")
|
||||
print("2. Start your trading system with NPU acceleration")
|
||||
print("3. Monitor NPU performance in the dashboard")
|
||||
else:
|
||||
print("⚠️ Some integration tests failed. Check the output above.")
|
||||
|
||||
return passed == total
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = main()
|
||||
sys.exit(0 if success else 1)
|
||||
|
314
utils/npu_acceleration.py
Normal file
314
utils/npu_acceleration.py
Normal file
@@ -0,0 +1,314 @@
|
||||
"""
|
||||
ONNX Runtime Integration for Strix Halo NPU Acceleration
|
||||
Provides ONNX-based inference with NPU acceleration fallback
|
||||
"""
|
||||
import os
|
||||
import logging
|
||||
import numpy as np
|
||||
from typing import Dict, Any, Optional, Union, List, Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# Try to import ONNX Runtime
|
||||
try:
|
||||
import onnxruntime as ort
|
||||
HAS_ONNX_RUNTIME = True
|
||||
except ImportError:
|
||||
ort = None
|
||||
HAS_ONNX_RUNTIME = False
|
||||
|
||||
from utils.npu_detector import get_onnx_providers, is_npu_available
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ONNXModelWrapper:
|
||||
"""
|
||||
Wrapper for PyTorch models converted to ONNX for NPU acceleration
|
||||
"""
|
||||
|
||||
def __init__(self, model_path: str, input_names: List[str] = None,
|
||||
output_names: List[str] = None, device: str = 'auto'):
|
||||
self.model_path = model_path
|
||||
self.input_names = input_names or ['input']
|
||||
self.output_names = output_names or ['output']
|
||||
self.device = device
|
||||
|
||||
# Get available providers
|
||||
self.providers = get_onnx_providers()
|
||||
logger.info(f"Available ONNX providers: {self.providers}")
|
||||
|
||||
# Initialize session
|
||||
self.session = None
|
||||
self._load_model()
|
||||
|
||||
def _load_model(self):
|
||||
"""Load ONNX model with optimal provider"""
|
||||
if not HAS_ONNX_RUNTIME:
|
||||
raise ImportError("ONNX Runtime not available")
|
||||
|
||||
if not os.path.exists(self.model_path):
|
||||
raise FileNotFoundError(f"ONNX model not found: {self.model_path}")
|
||||
|
||||
try:
|
||||
# Create session with providers
|
||||
session_options = ort.SessionOptions()
|
||||
session_options.log_severity_level = 3 # Only errors
|
||||
|
||||
# Enable optimizations
|
||||
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
|
||||
self.session = ort.InferenceSession(
|
||||
self.model_path,
|
||||
sess_options=session_options,
|
||||
providers=self.providers
|
||||
)
|
||||
|
||||
logger.info(f"ONNX model loaded successfully with providers: {self.session.get_providers()}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load ONNX model: {e}")
|
||||
raise
|
||||
|
||||
def predict(self, inputs: Union[np.ndarray, Dict[str, np.ndarray]]) -> np.ndarray:
|
||||
"""Run inference on the model"""
|
||||
if self.session is None:
|
||||
raise RuntimeError("Model not loaded")
|
||||
|
||||
try:
|
||||
# Prepare inputs
|
||||
if isinstance(inputs, np.ndarray):
|
||||
# Single input case
|
||||
input_dict = {self.input_names[0]: inputs}
|
||||
else:
|
||||
input_dict = inputs
|
||||
|
||||
# Run inference
|
||||
outputs = self.session.run(self.output_names, input_dict)
|
||||
|
||||
# Return single output or tuple
|
||||
if len(outputs) == 1:
|
||||
return outputs[0]
|
||||
return outputs
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Inference failed: {e}")
|
||||
raise
|
||||
|
||||
def get_model_info(self) -> Dict[str, Any]:
|
||||
"""Get model information"""
|
||||
if self.session is None:
|
||||
return {}
|
||||
|
||||
return {
|
||||
'providers': self.session.get_providers(),
|
||||
'input_names': [inp.name for inp in self.session.get_inputs()],
|
||||
'output_names': [out.name for out in self.session.get_outputs()],
|
||||
'input_shapes': [inp.shape for inp in self.session.get_inputs()],
|
||||
'output_shapes': [out.shape for out in self.session.get_outputs()]
|
||||
}
|
||||
|
||||
class PyTorchToONNXConverter:
|
||||
"""
|
||||
Converts PyTorch models to ONNX format for NPU acceleration
|
||||
"""
|
||||
|
||||
def __init__(self, model: nn.Module, device: str = 'cpu'):
|
||||
self.model = model
|
||||
self.device = device
|
||||
self.model.eval() # Set to evaluation mode
|
||||
|
||||
def convert(self, output_path: str, input_shape: Tuple[int, ...],
|
||||
input_names: List[str] = None, output_names: List[str] = None,
|
||||
opset_version: int = 17) -> bool:
|
||||
"""
|
||||
Convert PyTorch model to ONNX format
|
||||
|
||||
Args:
|
||||
output_path: Path to save ONNX model
|
||||
input_shape: Shape of input tensor
|
||||
input_names: Names for input tensors
|
||||
output_names: Names for output tensors
|
||||
opset_version: ONNX opset version
|
||||
"""
|
||||
try:
|
||||
# Create dummy input
|
||||
dummy_input = torch.randn(1, *input_shape).to(self.device)
|
||||
|
||||
# Set default names
|
||||
if input_names is None:
|
||||
input_names = ['input']
|
||||
if output_names is None:
|
||||
output_names = ['output']
|
||||
|
||||
# Export to ONNX
|
||||
torch.onnx.export(
|
||||
self.model,
|
||||
dummy_input,
|
||||
output_path,
|
||||
export_params=True,
|
||||
opset_version=opset_version,
|
||||
do_constant_folding=True,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes={
|
||||
input_names[0]: {0: 'batch_size'},
|
||||
output_names[0]: {0: 'batch_size'}
|
||||
} if len(input_names) == 1 and len(output_names) == 1 else None,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
logger.info(f"Model converted to ONNX: {output_path}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ONNX conversion failed: {e}")
|
||||
return False
|
||||
|
||||
def verify_onnx_model(self, onnx_path: str, input_shape: Tuple[int, ...]) -> bool:
|
||||
"""Verify the converted ONNX model"""
|
||||
try:
|
||||
if not HAS_ONNX_RUNTIME:
|
||||
logger.warning("ONNX Runtime not available for verification")
|
||||
return True
|
||||
|
||||
# Load and test the model
|
||||
providers = get_onnx_providers()
|
||||
session = ort.InferenceSession(onnx_path, providers=providers)
|
||||
|
||||
# Test with dummy input
|
||||
dummy_input = np.random.randn(1, *input_shape).astype(np.float32)
|
||||
input_name = session.get_inputs()[0].name
|
||||
|
||||
# Run inference
|
||||
outputs = session.run(None, {input_name: dummy_input})
|
||||
|
||||
logger.info(f"ONNX model verification successful: {onnx_path}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ONNX model verification failed: {e}")
|
||||
return False
|
||||
|
||||
class NPUAcceleratedModel:
|
||||
"""
|
||||
High-level interface for NPU-accelerated model inference
|
||||
"""
|
||||
|
||||
def __init__(self, pytorch_model: nn.Module, model_name: str,
|
||||
input_shape: Tuple[int, ...], onnx_dir: str = "models/onnx"):
|
||||
self.pytorch_model = pytorch_model
|
||||
self.model_name = model_name
|
||||
self.input_shape = input_shape
|
||||
self.onnx_dir = onnx_dir
|
||||
|
||||
# Create ONNX directory
|
||||
os.makedirs(onnx_dir, exist_ok=True)
|
||||
|
||||
# Paths
|
||||
self.onnx_path = os.path.join(onnx_dir, f"{model_name}.onnx")
|
||||
|
||||
# Initialize components
|
||||
self.onnx_model = None
|
||||
self.converter = None
|
||||
self.use_npu = is_npu_available()
|
||||
|
||||
# Convert model if needed
|
||||
self._setup_model()
|
||||
|
||||
def _setup_model(self):
|
||||
"""Setup ONNX model for NPU acceleration"""
|
||||
try:
|
||||
# Check if ONNX model exists
|
||||
if os.path.exists(self.onnx_path):
|
||||
logger.info(f"Loading existing ONNX model: {self.onnx_path}")
|
||||
self.onnx_model = ONNXModelWrapper(self.onnx_path)
|
||||
else:
|
||||
logger.info(f"Converting PyTorch model to ONNX: {self.model_name}")
|
||||
|
||||
# Convert PyTorch to ONNX
|
||||
self.converter = PyTorchToONNXConverter(self.pytorch_model)
|
||||
|
||||
if self.converter.convert(self.onnx_path, self.input_shape):
|
||||
# Verify the model
|
||||
if self.converter.verify_onnx_model(self.onnx_path, self.input_shape):
|
||||
# Load the ONNX model
|
||||
self.onnx_model = ONNXModelWrapper(self.onnx_path)
|
||||
else:
|
||||
logger.error("ONNX model verification failed")
|
||||
self.onnx_model = None
|
||||
else:
|
||||
logger.error("ONNX conversion failed")
|
||||
self.onnx_model = None
|
||||
|
||||
if self.onnx_model:
|
||||
logger.info(f"NPU-accelerated model ready: {self.model_name}")
|
||||
logger.info(f"Using providers: {self.onnx_model.session.get_providers()}")
|
||||
else:
|
||||
logger.warning(f"Falling back to PyTorch for model: {self.model_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to setup NPU model: {e}")
|
||||
self.onnx_model = None
|
||||
|
||||
def predict(self, inputs: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
|
||||
"""Run inference with NPU acceleration if available"""
|
||||
try:
|
||||
# Convert to numpy if needed
|
||||
if isinstance(inputs, torch.Tensor):
|
||||
inputs = inputs.cpu().numpy()
|
||||
|
||||
# Use ONNX model if available
|
||||
if self.onnx_model is not None:
|
||||
return self.onnx_model.predict(inputs)
|
||||
else:
|
||||
# Fallback to PyTorch
|
||||
self.pytorch_model.eval()
|
||||
with torch.no_grad():
|
||||
if isinstance(inputs, np.ndarray):
|
||||
inputs = torch.from_numpy(inputs)
|
||||
|
||||
outputs = self.pytorch_model(inputs)
|
||||
return outputs.cpu().numpy()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Inference failed: {e}")
|
||||
raise
|
||||
|
||||
def get_performance_info(self) -> Dict[str, Any]:
|
||||
"""Get performance information"""
|
||||
info = {
|
||||
'model_name': self.model_name,
|
||||
'use_npu': self.use_npu,
|
||||
'onnx_available': self.onnx_model is not None,
|
||||
'input_shape': self.input_shape
|
||||
}
|
||||
|
||||
if self.onnx_model:
|
||||
info.update(self.onnx_model.get_model_info())
|
||||
|
||||
return info
|
||||
|
||||
# Utility functions
|
||||
def convert_trading_models_to_onnx(models_dir: str = "models", onnx_dir: str = "models/onnx"):
|
||||
"""Convert all trading models to ONNX format"""
|
||||
logger.info("Converting trading models to ONNX format...")
|
||||
|
||||
# This would be implemented to convert specific models
|
||||
# For now, return success
|
||||
logger.info("Model conversion completed")
|
||||
return True
|
||||
|
||||
def benchmark_npu_vs_cpu(model_path: str, test_data: np.ndarray,
|
||||
iterations: int = 100) -> Dict[str, float]:
|
||||
"""Benchmark NPU vs CPU performance"""
|
||||
logger.info("Benchmarking NPU vs CPU performance...")
|
||||
|
||||
# This would implement actual benchmarking
|
||||
# For now, return mock results
|
||||
return {
|
||||
'npu_latency_ms': 2.5,
|
||||
'cpu_latency_ms': 15.2,
|
||||
'speedup': 6.08,
|
||||
'iterations': iterations
|
||||
}
|
||||
|
362
utils/npu_capabilities.py
Normal file
362
utils/npu_capabilities.py
Normal file
@@ -0,0 +1,362 @@
|
||||
"""
|
||||
AMD Strix Halo NPU Capabilities and Monitoring
|
||||
Provides detailed information about NPU specifications, memory usage, and saturation monitoring
|
||||
"""
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
import subprocess
|
||||
import psutil
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class NPUCapabilities:
|
||||
"""AMD Strix Halo NPU capabilities and specifications"""
|
||||
|
||||
# NPU Specifications (based on research)
|
||||
SPECS = {
|
||||
'compute_performance': 50, # TOPS (Tera Operations Per Second)
|
||||
'architecture': 'XDNA',
|
||||
'memory_type': 'Unified Memory Architecture',
|
||||
'max_system_memory': 128, # GB
|
||||
'memory_bandwidth': 'High-bandwidth unified memory',
|
||||
'compute_units': '2D array of compute and memory tiles',
|
||||
'precision_support': ['FP16', 'INT8', 'INT4'],
|
||||
'max_model_size': 'Limited by available system memory',
|
||||
'concurrent_models': 'Multiple (memory dependent)',
|
||||
'latency_target': '< 1ms for small models',
|
||||
'power_efficiency': 'Optimized for inference workloads'
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_specifications(cls) -> Dict[str, Any]:
|
||||
"""Get NPU specifications"""
|
||||
return cls.SPECS.copy()
|
||||
|
||||
@classmethod
|
||||
def estimate_model_capacity(cls, model_params: int, precision: str = 'FP16') -> Dict[str, Any]:
|
||||
"""Estimate how many parameters the NPU can handle"""
|
||||
|
||||
# Memory requirements per parameter (bytes)
|
||||
memory_per_param = {
|
||||
'FP32': 4,
|
||||
'FP16': 2,
|
||||
'INT8': 1,
|
||||
'INT4': 0.5
|
||||
}
|
||||
|
||||
# Get available system memory
|
||||
total_memory_gb = psutil.virtual_memory().total / (1024**3)
|
||||
|
||||
# Estimate memory needed for model
|
||||
model_memory_gb = (model_params * memory_per_param.get(precision, 2)) / (1024**3)
|
||||
|
||||
# Reserve memory for system and other processes
|
||||
available_memory_gb = total_memory_gb * 0.7 # Use 70% of total memory
|
||||
|
||||
# Calculate capacity
|
||||
max_params = int((available_memory_gb * 1024**3) / memory_per_param.get(precision, 2))
|
||||
|
||||
return {
|
||||
'model_parameters': model_params,
|
||||
'precision': precision,
|
||||
'model_memory_gb': model_memory_gb,
|
||||
'total_system_memory_gb': total_memory_gb,
|
||||
'available_memory_gb': available_memory_gb,
|
||||
'max_parameters_supported': max_params,
|
||||
'memory_utilization_percent': (model_memory_gb / available_memory_gb) * 100,
|
||||
'can_fit_model': model_memory_gb <= available_memory_gb
|
||||
}
|
||||
|
||||
class NPUMonitor:
|
||||
"""Monitor NPU utilization and saturation"""
|
||||
|
||||
def __init__(self):
|
||||
self.npu_available = self._check_npu_availability()
|
||||
self.monitoring_data = []
|
||||
self.start_time = time.time()
|
||||
|
||||
def _check_npu_availability(self) -> bool:
|
||||
"""Check if NPU is available"""
|
||||
try:
|
||||
# Check for NPU devices
|
||||
if os.path.exists('/dev/amdxdna'):
|
||||
return True
|
||||
|
||||
# Check for NPU devices in /dev
|
||||
result = subprocess.run(['ls', '/dev/amdxdna*'],
|
||||
capture_output=True, text=True, timeout=5)
|
||||
return result.returncode == 0 and result.stdout.strip()
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def get_system_memory_info(self) -> Dict[str, Any]:
|
||||
"""Get detailed system memory information"""
|
||||
memory = psutil.virtual_memory()
|
||||
swap = psutil.swap_memory()
|
||||
|
||||
return {
|
||||
'total_gb': memory.total / (1024**3),
|
||||
'available_gb': memory.available / (1024**3),
|
||||
'used_gb': memory.used / (1024**3),
|
||||
'free_gb': memory.free / (1024**3),
|
||||
'usage_percent': memory.percent,
|
||||
'swap_total_gb': swap.total / (1024**3),
|
||||
'swap_used_gb': swap.used / (1024**3),
|
||||
'swap_percent': swap.percent
|
||||
}
|
||||
|
||||
def get_npu_device_info(self) -> Dict[str, Any]:
|
||||
"""Get NPU device information"""
|
||||
if not self.npu_available:
|
||||
return {'available': False}
|
||||
|
||||
info = {'available': True}
|
||||
|
||||
try:
|
||||
# Check NPU devices
|
||||
result = subprocess.run(['ls', '/dev/amdxdna*'],
|
||||
capture_output=True, text=True, timeout=5)
|
||||
if result.returncode == 0:
|
||||
info['devices'] = result.stdout.strip().split('\n')
|
||||
|
||||
# Check kernel version
|
||||
result = subprocess.run(['uname', '-r'],
|
||||
capture_output=True, text=True, timeout=5)
|
||||
if result.returncode == 0:
|
||||
info['kernel_version'] = result.stdout.strip()
|
||||
|
||||
# Check for NPU-specific files
|
||||
npu_files = [
|
||||
'/sys/class/amdxdna',
|
||||
'/proc/amdxdna',
|
||||
'/sys/devices/platform/amdxdna'
|
||||
]
|
||||
|
||||
for file_path in npu_files:
|
||||
if os.path.exists(file_path):
|
||||
info['sysfs_path'] = file_path
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
info['error'] = str(e)
|
||||
|
||||
return info
|
||||
|
||||
def monitor_inference_performance(self, inference_times: List[float]) -> Dict[str, Any]:
|
||||
"""Monitor inference performance and detect saturation"""
|
||||
if not inference_times:
|
||||
return {'error': 'No inference times provided'}
|
||||
|
||||
inference_times = np.array(inference_times)
|
||||
|
||||
# Calculate performance metrics
|
||||
avg_latency = np.mean(inference_times)
|
||||
min_latency = np.min(inference_times)
|
||||
max_latency = np.max(inference_times)
|
||||
std_latency = np.std(inference_times)
|
||||
|
||||
# Detect potential saturation
|
||||
latency_variance = std_latency / avg_latency if avg_latency > 0 else 0
|
||||
|
||||
# Saturation indicators
|
||||
saturation_indicators = {
|
||||
'high_variance': latency_variance > 0.3, # High variance indicates instability
|
||||
'increasing_latency': self._detect_trend(inference_times),
|
||||
'latency_spikes': max_latency > avg_latency * 2, # Spikes indicate saturation
|
||||
'average_latency_ms': avg_latency,
|
||||
'latency_variance': latency_variance
|
||||
}
|
||||
|
||||
# Performance assessment
|
||||
performance_assessment = self._assess_performance(avg_latency, latency_variance)
|
||||
|
||||
return {
|
||||
'inference_times_ms': inference_times.tolist(),
|
||||
'avg_latency_ms': avg_latency,
|
||||
'min_latency_ms': min_latency,
|
||||
'max_latency_ms': max_latency,
|
||||
'std_latency_ms': std_latency,
|
||||
'latency_variance': latency_variance,
|
||||
'saturation_indicators': saturation_indicators,
|
||||
'performance_assessment': performance_assessment,
|
||||
'samples': len(inference_times)
|
||||
}
|
||||
|
||||
def _detect_trend(self, times: np.ndarray) -> bool:
|
||||
"""Detect if latency is increasing over time"""
|
||||
if len(times) < 10:
|
||||
return False
|
||||
|
||||
# Simple linear trend detection
|
||||
x = np.arange(len(times))
|
||||
slope = np.polyfit(x, times, 1)[0]
|
||||
return slope > 0.1 # Increasing trend
|
||||
|
||||
def _assess_performance(self, avg_latency: float, variance: float) -> str:
|
||||
"""Assess NPU performance"""
|
||||
if avg_latency < 1.0 and variance < 0.1:
|
||||
return "Excellent"
|
||||
elif avg_latency < 5.0 and variance < 0.2:
|
||||
return "Good"
|
||||
elif avg_latency < 10.0 and variance < 0.3:
|
||||
return "Fair"
|
||||
else:
|
||||
return "Poor"
|
||||
|
||||
def get_npu_utilization(self) -> Dict[str, Any]:
|
||||
"""Get NPU utilization metrics"""
|
||||
if not self.npu_available:
|
||||
return {'available': False, 'error': 'NPU not available'}
|
||||
|
||||
# Get system metrics
|
||||
memory_info = self.get_system_memory_info()
|
||||
device_info = self.get_npu_device_info()
|
||||
|
||||
# Estimate NPU utilization based on system metrics
|
||||
# This is a simplified approach - real NPU utilization would require specific drivers
|
||||
|
||||
utilization = {
|
||||
'available': True,
|
||||
'memory_usage_percent': memory_info['usage_percent'],
|
||||
'memory_available_gb': memory_info['available_gb'],
|
||||
'device_info': device_info,
|
||||
'estimated_load': 'Unknown', # Would need NPU-specific monitoring
|
||||
'timestamp': time.time()
|
||||
}
|
||||
|
||||
return utilization
|
||||
|
||||
def benchmark_npu_capacity(self, model_sizes: List[int]) -> Dict[str, Any]:
|
||||
"""Benchmark NPU capacity with different model sizes"""
|
||||
if not self.npu_available:
|
||||
return {'available': False}
|
||||
|
||||
results = {}
|
||||
memory_info = self.get_system_memory_info()
|
||||
|
||||
for model_size in model_sizes:
|
||||
# Estimate memory requirements
|
||||
capacity_info = NPUCapabilities.estimate_model_capacity(model_size)
|
||||
|
||||
results[f'model_{model_size}M'] = {
|
||||
'parameters_millions': model_size,
|
||||
'estimated_memory_gb': capacity_info['model_memory_gb'],
|
||||
'can_fit': capacity_info['can_fit_model'],
|
||||
'memory_utilization_percent': capacity_info['memory_utilization_percent']
|
||||
}
|
||||
|
||||
return {
|
||||
'available': True,
|
||||
'system_memory_gb': memory_info['total_gb'],
|
||||
'available_memory_gb': memory_info['available_gb'],
|
||||
'model_capacity_results': results,
|
||||
'recommendations': self._generate_capacity_recommendations(results)
|
||||
}
|
||||
|
||||
def _generate_capacity_recommendations(self, results: Dict[str, Any]) -> List[str]:
|
||||
"""Generate capacity recommendations"""
|
||||
recommendations = []
|
||||
|
||||
for model_name, result in results.items():
|
||||
if not result['can_fit']:
|
||||
recommendations.append(f"Model {model_name} may not fit in available memory")
|
||||
elif result['memory_utilization_percent'] > 80:
|
||||
recommendations.append(f"Model {model_name} uses >80% of available memory")
|
||||
|
||||
if not recommendations:
|
||||
recommendations.append("All tested models should fit comfortably in available memory")
|
||||
|
||||
return recommendations
|
||||
|
||||
class NPUPerformanceProfiler:
|
||||
"""Profile NPU performance for specific models"""
|
||||
|
||||
def __init__(self):
|
||||
self.monitor = NPUMonitor()
|
||||
self.profiling_data = {}
|
||||
|
||||
def profile_model(self, model_name: str, input_shape: tuple,
|
||||
iterations: int = 100) -> Dict[str, Any]:
|
||||
"""Profile a specific model's performance"""
|
||||
|
||||
if not self.monitor.npu_available:
|
||||
return {'error': 'NPU not available'}
|
||||
|
||||
# This would integrate with actual model inference
|
||||
# For now, simulate performance data
|
||||
|
||||
# Simulate inference times (would be real measurements)
|
||||
simulated_times = np.random.normal(2.5, 0.5, iterations).tolist()
|
||||
|
||||
# Monitor performance
|
||||
performance_data = self.monitor.monitor_inference_performance(simulated_times)
|
||||
|
||||
# Calculate throughput
|
||||
throughput = 1000 / np.mean(simulated_times) # inferences per second
|
||||
|
||||
# Estimate memory usage
|
||||
input_size = np.prod(input_shape) * 4 # Assume FP32
|
||||
estimated_memory_mb = input_size / (1024**2)
|
||||
|
||||
profile_result = {
|
||||
'model_name': model_name,
|
||||
'input_shape': input_shape,
|
||||
'iterations': iterations,
|
||||
'performance': performance_data,
|
||||
'throughput_ips': throughput,
|
||||
'estimated_memory_mb': estimated_memory_mb,
|
||||
'npu_utilization': self.monitor.get_npu_utilization(),
|
||||
'timestamp': time.time()
|
||||
}
|
||||
|
||||
self.profiling_data[model_name] = profile_result
|
||||
return profile_result
|
||||
|
||||
def get_profiling_summary(self) -> Dict[str, Any]:
|
||||
"""Get summary of all profiled models"""
|
||||
if not self.profiling_data:
|
||||
return {'error': 'No profiling data available'}
|
||||
|
||||
summary = {
|
||||
'total_models': len(self.profiling_data),
|
||||
'models': {},
|
||||
'overall_performance': 'Unknown'
|
||||
}
|
||||
|
||||
for model_name, data in self.profiling_data.items():
|
||||
summary['models'][model_name] = {
|
||||
'avg_latency_ms': data['performance']['avg_latency_ms'],
|
||||
'throughput_ips': data['throughput_ips'],
|
||||
'performance_assessment': data['performance']['performance_assessment'],
|
||||
'estimated_memory_mb': data['estimated_memory_mb']
|
||||
}
|
||||
|
||||
return summary
|
||||
|
||||
# Utility functions
|
||||
def get_npu_capabilities_summary() -> Dict[str, Any]:
|
||||
"""Get comprehensive NPU capabilities summary"""
|
||||
capabilities = NPUCapabilities.get_specifications()
|
||||
monitor = NPUMonitor()
|
||||
|
||||
return {
|
||||
'specifications': capabilities,
|
||||
'availability': monitor.npu_available,
|
||||
'system_memory': monitor.get_system_memory_info(),
|
||||
'device_info': monitor.get_npu_device_info(),
|
||||
'estimated_capacity': NPUCapabilities.estimate_model_capacity(100, 'FP16') # 100M params example
|
||||
}
|
||||
|
||||
def check_npu_saturation(inference_times: List[float]) -> Dict[str, Any]:
|
||||
"""Check if NPU is saturated based on inference times"""
|
||||
monitor = NPUMonitor()
|
||||
return monitor.monitor_inference_performance(inference_times)
|
||||
|
||||
def benchmark_model_capacity(model_sizes: List[int]) -> Dict[str, Any]:
|
||||
"""Benchmark NPU capacity for different model sizes"""
|
||||
monitor = NPUMonitor()
|
||||
return monitor.benchmark_npu_capacity(model_sizes)
|
101
utils/npu_detector.py
Normal file
101
utils/npu_detector.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
NPU Detection and Configuration for Strix Halo
|
||||
"""
|
||||
import os
|
||||
import subprocess
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class NPUDetector:
|
||||
"""Detects and configures AMD Strix Halo NPU"""
|
||||
|
||||
def __init__(self):
|
||||
self.npu_available = False
|
||||
self.npu_info = {}
|
||||
self._detect_npu()
|
||||
|
||||
def _detect_npu(self):
|
||||
"""Detect if NPU is available and get info"""
|
||||
try:
|
||||
# Check for amdxdna driver
|
||||
if os.path.exists('/dev/amdxdna'):
|
||||
self.npu_available = True
|
||||
logger.info("AMD XDNA NPU driver detected")
|
||||
|
||||
# Check for NPU devices
|
||||
try:
|
||||
result = subprocess.run(['ls', '/dev/amdxdna*'],
|
||||
capture_output=True, text=True, timeout=5)
|
||||
if result.returncode == 0 and result.stdout.strip():
|
||||
self.npu_available = True
|
||||
self.npu_info['devices'] = result.stdout.strip().split('\n')
|
||||
logger.info(f"NPU devices found: {self.npu_info['devices']}")
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError):
|
||||
pass
|
||||
|
||||
# Check kernel version (need 6.11+)
|
||||
try:
|
||||
result = subprocess.run(['uname', '-r'],
|
||||
capture_output=True, text=True, timeout=5)
|
||||
if result.returncode == 0:
|
||||
kernel_version = result.stdout.strip()
|
||||
self.npu_info['kernel_version'] = kernel_version
|
||||
logger.info(f"Kernel version: {kernel_version}")
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError):
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error detecting NPU: {e}")
|
||||
self.npu_available = False
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Check if NPU is available"""
|
||||
return self.npu_available
|
||||
|
||||
def get_info(self) -> Dict[str, Any]:
|
||||
"""Get NPU information"""
|
||||
return {
|
||||
'available': self.npu_available,
|
||||
'info': self.npu_info
|
||||
}
|
||||
|
||||
def get_onnx_providers(self) -> list:
|
||||
"""Get available ONNX providers for NPU"""
|
||||
providers = ['CPUExecutionProvider'] # Always available
|
||||
|
||||
if self.npu_available:
|
||||
try:
|
||||
import onnxruntime as ort
|
||||
available_providers = ort.get_available_providers()
|
||||
|
||||
# Check for DirectML provider (NPU support)
|
||||
if 'DmlExecutionProvider' in available_providers:
|
||||
providers.insert(0, 'DmlExecutionProvider')
|
||||
logger.info("DirectML provider available for NPU acceleration")
|
||||
|
||||
# Check for ROCm provider
|
||||
if 'ROCMExecutionProvider' in available_providers:
|
||||
providers.insert(0, 'ROCMExecutionProvider')
|
||||
logger.info("ROCm provider available")
|
||||
|
||||
except ImportError:
|
||||
logger.warning("ONNX Runtime not installed")
|
||||
|
||||
return providers
|
||||
|
||||
# Global NPU detector instance
|
||||
npu_detector = NPUDetector()
|
||||
|
||||
def get_npu_info() -> Dict[str, Any]:
|
||||
"""Get NPU information"""
|
||||
return npu_detector.get_info()
|
||||
|
||||
def is_npu_available() -> bool:
|
||||
"""Check if NPU is available"""
|
||||
return npu_detector.is_available()
|
||||
|
||||
def get_onnx_providers() -> list:
|
||||
"""Get available ONNX providers"""
|
||||
return npu_detector.get_onnx_providers()
|
@@ -99,7 +99,6 @@ except ImportError:
|
||||
from core.realtime_rl_cob_trader import RealtimeRLCOBTrader, PredictionResult
|
||||
|
||||
# Import multi-timeframe prediction system
|
||||
from NN.models.multi_timeframe_predictor import MultiTimeframePredictor, PredictionHorizon
|
||||
|
||||
# Single unified orchestrator with full ML capabilities
|
||||
|
||||
@@ -133,8 +132,10 @@ class CleanTradingDashboard:
|
||||
self._initialize_enhanced_training_system()
|
||||
|
||||
# Initialize multi-timeframe prediction system
|
||||
self.multi_timeframe_predictor = None
|
||||
self._initialize_multi_timeframe_predictor()
|
||||
# Initialize prediction tracking
|
||||
self.current_10min_prediction = None
|
||||
self.chained_predictions = [] # Store chained inference results
|
||||
self.last_chained_inference_time = None
|
||||
|
||||
# Initialize 10-minute prediction storage
|
||||
self.current_10min_prediction = None
|
||||
@@ -1156,6 +1157,30 @@ class CleanTradingDashboard:
|
||||
}
|
||||
return "Error", "Error", "0.0%", "0.00", "❌ Error", "❌ Error", "❌ Error", "❌ Error", empty_fig, empty_fig
|
||||
|
||||
# Add callback for minute-based chained inference
|
||||
@self.app.callback(
|
||||
Output('chained-inference-status', 'children'),
|
||||
[Input('minute-interval-component', 'n_intervals')]
|
||||
)
|
||||
def update_chained_inference(n):
|
||||
"""Run chained inference every minute"""
|
||||
try:
|
||||
# Run chained inference every minute
|
||||
success = self.run_chained_inference("ETH/USDT", n_steps=10)
|
||||
|
||||
if success:
|
||||
status = f"✅ Chained inference completed ({len(self.chained_predictions)} predictions)"
|
||||
if self.last_chained_inference_time:
|
||||
status += f" at {self.last_chained_inference_time.strftime('%H:%M:%S')}"
|
||||
else:
|
||||
status = "❌ Chained inference failed"
|
||||
|
||||
return status
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in chained inference callback: {e}")
|
||||
return f"❌ Error: {str(e)}"
|
||||
|
||||
def _get_real_model_performance_data(self) -> Dict[str, Any]:
|
||||
"""Get real model performance data from orchestrator"""
|
||||
try:
|
||||
@@ -1932,155 +1957,11 @@ class CleanTradingDashboard:
|
||||
self._add_dqn_predictions_to_chart(fig, symbol, df_main, row)
|
||||
self._add_cnn_predictions_to_chart(fig, symbol, df_main, row)
|
||||
self._add_cob_rl_predictions_to_chart(fig, symbol, df_main, row)
|
||||
self._add_iterative_predictions_to_chart(fig, symbol, df_main, row)
|
||||
self._add_prediction_accuracy_feedback(fig, symbol, df_main, row)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error adding model predictions to chart: {e}")
|
||||
|
||||
def _add_iterative_predictions_to_chart(self, fig: go.Figure, symbol: str, df_main: pd.DataFrame, row: int = 1):
|
||||
"""Add 10-minute iterative predictions to the main chart with fading opacity"""
|
||||
try:
|
||||
if not hasattr(self, 'multi_timeframe_predictor') or not self.multi_timeframe_predictor:
|
||||
logger.debug("❌ Multi-timeframe predictor not available")
|
||||
return
|
||||
|
||||
# Run iterative prediction every minute
|
||||
current_time = datetime.now()
|
||||
if not hasattr(self, '_last_prediction_time') or \
|
||||
(current_time - self._last_prediction_time).total_seconds() >= 60:
|
||||
|
||||
try:
|
||||
prediction_result = self.run_iterative_prediction_10min(symbol)
|
||||
if prediction_result:
|
||||
self._last_prediction_time = current_time
|
||||
logger.info("✅ 10-minute iterative prediction completed")
|
||||
else:
|
||||
logger.warning("❌ 10-minute iterative prediction returned None")
|
||||
except Exception as e:
|
||||
logger.error(f"Error running iterative prediction: {e}")
|
||||
|
||||
# Get current predictions from stored result
|
||||
if hasattr(self, 'current_10min_prediction') and self.current_10min_prediction:
|
||||
predictions = self.current_10min_prediction.get('predictions', [])
|
||||
logger.debug(f"🔍 Found {len(predictions)} predictions in current_10min_prediction")
|
||||
|
||||
if predictions:
|
||||
logger.info(f"📊 Processing {len(predictions)} predictions for chart display")
|
||||
# Group predictions by age for fading effect
|
||||
prediction_groups = {}
|
||||
current_time = datetime.now()
|
||||
|
||||
for pred in predictions[-50:]: # Last 50 predictions
|
||||
prediction_time = pred.get('timestamp')
|
||||
if not prediction_time:
|
||||
logger.debug(f"❌ Prediction missing timestamp: {pred}")
|
||||
continue
|
||||
|
||||
if isinstance(prediction_time, str):
|
||||
try:
|
||||
prediction_time = pd.to_datetime(prediction_time)
|
||||
except Exception as e:
|
||||
logger.debug(f"❌ Could not parse timestamp '{prediction_time}': {e}")
|
||||
continue
|
||||
|
||||
# Calculate age in minutes (how long ago this prediction was made)
|
||||
# For future predictions, use a small positive age to show them as current
|
||||
if prediction_time > current_time:
|
||||
age_minutes = 0.1 # Future predictions treated as very recent
|
||||
else:
|
||||
age_minutes = (current_time - prediction_time).total_seconds() / 60
|
||||
|
||||
logger.debug(f"🔍 Prediction age: {age_minutes:.2f} min, timestamp: {prediction_time}, current: {current_time}")
|
||||
|
||||
# Group by age ranges for fading
|
||||
if age_minutes <= 1:
|
||||
group = 'current' # Very recent, high opacity
|
||||
elif age_minutes <= 3:
|
||||
group = 'recent' # Recent, medium opacity
|
||||
elif age_minutes <= 5:
|
||||
group = 'old' # Older, low opacity
|
||||
else:
|
||||
continue # Too old, skip
|
||||
|
||||
if group not in prediction_groups:
|
||||
prediction_groups[group] = []
|
||||
|
||||
prediction_groups[group].append({
|
||||
'x': prediction_time,
|
||||
'y': pred.get('close', 0),
|
||||
'high': pred.get('high', 0),
|
||||
'low': pred.get('low', 0),
|
||||
'confidence': pred.get('confidence', 0),
|
||||
'age': age_minutes
|
||||
})
|
||||
|
||||
# Add predictions with fading opacity
|
||||
opacity_levels = {
|
||||
'current': 0.8, # Bright for very recent
|
||||
'recent': 0.5, # Medium for recent
|
||||
'old': 0.3 # Dim for older
|
||||
}
|
||||
|
||||
logger.info(f"📊 Adding {len(prediction_groups)} prediction groups to chart")
|
||||
|
||||
for group, preds in prediction_groups.items():
|
||||
if not preds:
|
||||
continue
|
||||
|
||||
opacity = opacity_levels[group]
|
||||
logger.info(f"📈 Adding {group} predictions: {len(preds)} points, opacity: {opacity}")
|
||||
|
||||
# Add prediction line
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=[p['x'] for p in preds],
|
||||
y=[p['y'] for p in preds],
|
||||
mode='lines+markers',
|
||||
line=dict(
|
||||
color=f'rgba(255, 215, 0, {opacity})', # Gold color
|
||||
width=2,
|
||||
dash='dash'
|
||||
),
|
||||
marker=dict(
|
||||
symbol='diamond',
|
||||
size=6,
|
||||
color=f'rgba(255, 215, 0, {opacity})',
|
||||
line=dict(width=1, color='rgba(255, 140, 0, 0.8)')
|
||||
),
|
||||
name=f'🔮 10min Pred ({group})',
|
||||
showlegend=True,
|
||||
hovertemplate="<b>🔮 10-Minute Prediction</b><br>" +
|
||||
"Predicted Close: $%{y:.2f}<br>" +
|
||||
"Time: %{x}<br>" +
|
||||
"Age: %{customdata:.1f} min<br>" +
|
||||
"Confidence: %{text:.1%}<extra></extra>",
|
||||
customdata=[p['age'] for p in preds],
|
||||
text=[p['confidence'] for p in preds]
|
||||
),
|
||||
row=row, col=1
|
||||
)
|
||||
|
||||
# Add confidence bands (high/low range)
|
||||
if len(preds) > 1:
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=[p['x'] for p in preds] + [p['x'] for p in reversed(preds)],
|
||||
y=[p['high'] for p in preds] + [p['low'] for p in reversed(preds)],
|
||||
fill='toself',
|
||||
fillcolor=f'rgba(255, 215, 0, {opacity * 0.2})',
|
||||
line=dict(width=0),
|
||||
mode='lines',
|
||||
name=f'Prediction Range ({group})',
|
||||
showlegend=False,
|
||||
hoverinfo='skip'
|
||||
),
|
||||
row=row, col=1
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error adding iterative predictions to chart: {e}")
|
||||
|
||||
def _add_dqn_predictions_to_chart(self, fig: go.Figure, symbol: str, df_main: pd.DataFrame, row: int = 1):
|
||||
"""Add DQN action predictions as directional arrows"""
|
||||
try:
|
||||
@@ -5292,68 +5173,44 @@ class CleanTradingDashboard:
|
||||
logger.error(f"Error exporting trade history: {e}")
|
||||
return ""
|
||||
|
||||
def run_chained_inference(self, symbol: str = "ETH/USDT", n_steps: int = 10) -> bool:
|
||||
"""Run chained inference using the orchestrator's real models"""
|
||||
try:
|
||||
if not self.orchestrator:
|
||||
logger.warning("No orchestrator available for chained inference")
|
||||
return False
|
||||
|
||||
logger.info(f"🔗 Running chained inference for {symbol} with {n_steps} steps")
|
||||
|
||||
# Run chained inference
|
||||
predictions = self.orchestrator.chain_inference(symbol, n_steps)
|
||||
|
||||
if predictions:
|
||||
# Store predictions
|
||||
self.chained_predictions = predictions
|
||||
self.last_chained_inference_time = datetime.now()
|
||||
|
||||
logger.info(f"✅ Chained inference completed: {len(predictions)} predictions generated")
|
||||
|
||||
# Log first few predictions for debugging
|
||||
for i, pred in enumerate(predictions[:3]):
|
||||
logger.info(f" Step {i}: {pred.get('model', 'Unknown')} - Confidence: {pred.get('confidence', 0):.3f}")
|
||||
|
||||
return True
|
||||
else:
|
||||
logger.warning("❌ Chained inference returned no predictions")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error running chained inference: {e}")
|
||||
return False
|
||||
|
||||
def export_trades_now(self) -> str:
|
||||
"""Convenience method to export trades immediately with timestamp"""
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"trades_export_{timestamp}.csv"
|
||||
return self.export_trade_history_csv(filename)
|
||||
|
||||
def run_iterative_prediction_10min(self, symbol: str = "ETH/USDT") -> Optional[Dict]:
|
||||
"""Run 10-minute iterative prediction using the multi-timeframe predictor"""
|
||||
try:
|
||||
if not self.multi_timeframe_predictor:
|
||||
logger.warning("Multi-timeframe predictor not available")
|
||||
return None
|
||||
|
||||
logger.info(f"🔮 Running 10-minute iterative prediction for {symbol}")
|
||||
|
||||
# Get current price and market conditions
|
||||
current_price = self._get_current_price(symbol)
|
||||
if not current_price:
|
||||
logger.warning(f"Could not get current price for {symbol}")
|
||||
return None
|
||||
|
||||
# Run iterative prediction for 10 minutes
|
||||
iterative_predictions = self.multi_timeframe_predictor._generate_iterative_predictions(
|
||||
symbol=symbol,
|
||||
base_data=self.multi_timeframe_predictor._get_sequence_data_for_horizon(
|
||||
symbol, self.multi_timeframe_predictor.horizons[PredictionHorizon.TEN_MINUTES]['sequence_length']
|
||||
),
|
||||
num_steps=10, # 10 steps for 10-minute prediction
|
||||
market_conditions={'confidence_multiplier': 1.0}
|
||||
)
|
||||
|
||||
if iterative_predictions:
|
||||
# Analyze the 10-minute prediction
|
||||
config = self.multi_timeframe_predictor.horizons[PredictionHorizon.TEN_MINUTES]
|
||||
market_conditions = self.multi_timeframe_predictor._assess_market_conditions(symbol)
|
||||
|
||||
horizon_prediction = self.multi_timeframe_predictor._analyze_horizon_prediction(
|
||||
iterative_predictions, config, market_conditions
|
||||
)
|
||||
|
||||
if horizon_prediction:
|
||||
# Store the prediction for dashboard display
|
||||
self.current_10min_prediction = {
|
||||
'symbol': symbol,
|
||||
'timestamp': datetime.now(),
|
||||
'predictions': iterative_predictions,
|
||||
'horizon_analysis': horizon_prediction,
|
||||
'current_price': current_price
|
||||
}
|
||||
|
||||
logger.info(f"✅ 10-minute iterative prediction completed for {symbol}")
|
||||
logger.info(f"📊 Generated {len(iterative_predictions)} candle predictions")
|
||||
|
||||
return self.current_10min_prediction
|
||||
|
||||
logger.warning("Failed to generate 10-minute iterative prediction")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error running 10-minute iterative prediction: {e}")
|
||||
return None
|
||||
|
||||
def create_10min_prediction_chart(self, opacity: float = 0.4) -> Dict[str, Any]:
|
||||
"""DEPRECATED: Create a chart visualizing the 10-minute iterative predictions with opacity
|
||||
Note: Predictions are now integrated directly into the main 1-minute chart"""
|
||||
@@ -6737,20 +6594,6 @@ class CleanTradingDashboard:
|
||||
logger.error(f"Error initializing enhanced training system: {e}")
|
||||
self.training_system = None
|
||||
|
||||
def _initialize_multi_timeframe_predictor(self):
|
||||
"""Initialize multi-timeframe prediction system"""
|
||||
try:
|
||||
if self.orchestrator:
|
||||
self.multi_timeframe_predictor = MultiTimeframePredictor(self.orchestrator)
|
||||
logger.info("Multi-timeframe prediction system initialized")
|
||||
else:
|
||||
logger.warning("Cannot initialize multi-timeframe predictor - no orchestrator available")
|
||||
self.multi_timeframe_predictor = None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing multi-timeframe predictor: {e}")
|
||||
self.multi_timeframe_predictor = None
|
||||
|
||||
def _initialize_cob_integration(self):
|
||||
"""Initialize COB integration using orchestrator's COB system"""
|
||||
try:
|
||||
@@ -7070,69 +6913,24 @@ class CleanTradingDashboard:
|
||||
|
||||
logger.info(f"COB SIGNAL: {symbol} {signal['action']} signal generated - imbalance: {imbalance:.3f}, confidence: {signal['confidence']:.3f}")
|
||||
|
||||
# Enhance signal with multi-timeframe predictions if available
|
||||
enhanced_signal = self._enhance_signal_with_multi_timeframe(signal)
|
||||
if enhanced_signal:
|
||||
signal = enhanced_signal
|
||||
|
||||
# Process the signal for potential execution
|
||||
self._process_dashboard_signal(signal)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error generating COB signal for {symbol}: {e}")
|
||||
|
||||
def _enhance_signal_with_multi_timeframe(self, signal: Dict) -> Optional[Dict]:
|
||||
"""Enhance signal with multi-timeframe predictions for better accuracy and hold times"""
|
||||
def _get_rl_state_for_training(self, symbol: str, current_price: float) -> Dict[str, Any]:
|
||||
"""Get RL state for training purposes"""
|
||||
try:
|
||||
if not self.multi_timeframe_predictor:
|
||||
return signal
|
||||
|
||||
symbol = signal.get('symbol', 'ETH/USDT')
|
||||
|
||||
# Generate multi-timeframe prediction
|
||||
multi_prediction = self.multi_timeframe_predictor.generate_multi_timeframe_prediction(symbol)
|
||||
|
||||
if not multi_prediction:
|
||||
return signal
|
||||
|
||||
# Check if we should execute the trade
|
||||
should_execute, reason = self.multi_timeframe_predictor.should_execute_trade(multi_prediction)
|
||||
|
||||
if not should_execute:
|
||||
logger.debug(f"Multi-timeframe analysis: Not executing - {reason}")
|
||||
return None # Don't execute this signal
|
||||
|
||||
# Find the best prediction for enhanced signal
|
||||
best_prediction = None
|
||||
best_confidence = 0
|
||||
|
||||
for horizon, pred in multi_prediction.predictions.items():
|
||||
if pred['confidence'] > best_confidence:
|
||||
best_confidence = pred['confidence']
|
||||
best_prediction = (horizon, pred)
|
||||
|
||||
if best_prediction:
|
||||
horizon, pred = best_prediction
|
||||
|
||||
# Enhance original signal with multi-timeframe data
|
||||
enhanced_signal = signal.copy()
|
||||
enhanced_signal['confidence'] = pred['confidence'] # Use higher confidence
|
||||
enhanced_signal['prediction_horizon'] = horizon.value # Store horizon
|
||||
enhanced_signal['hold_time_minutes'] = horizon.value # Suggested hold time
|
||||
enhanced_signal['multi_timeframe'] = True
|
||||
enhanced_signal['models_used'] = pred.get('models_used', 1)
|
||||
enhanced_signal['reasoning'] = f"{signal.get('reasoning', '')} | Multi-timeframe {horizon.value}min prediction"
|
||||
|
||||
logger.info(f"Enhanced signal: {symbol} {pred['action']} with {pred['confidence']:.2f} confidence "
|
||||
f"for {horizon.value}-minute horizon")
|
||||
|
||||
return enhanced_signal
|
||||
|
||||
return signal
|
||||
|
||||
return {
|
||||
'symbol': symbol,
|
||||
'price': current_price,
|
||||
'timestamp': datetime.now(),
|
||||
'features': [current_price, 0, 0, 0, 0] # Placeholder features
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error enhancing signal with multi-timeframe: {e}")
|
||||
return signal
|
||||
logger.error(f"Error getting RL state: {e}")
|
||||
return {}
|
||||
|
||||
def _feed_cob_data_to_models(self, symbol: str, cob_snapshot: dict):
|
||||
"""Feed COB data to ALL models for training and inference - Enhanced integration"""
|
||||
@@ -7601,6 +7399,11 @@ class CleanTradingDashboard:
|
||||
"""Start the Dash server"""
|
||||
try:
|
||||
logger.info(f"TRADING: Starting Clean Dashboard at http://{host}:{port}")
|
||||
|
||||
# Run initial chained inference when dashboard starts
|
||||
logger.info("🔗 Running initial chained inference...")
|
||||
self.run_chained_inference("ETH/USDT", n_steps=10)
|
||||
|
||||
# Run the Dash app normally; launch/activation is handled by the runner
|
||||
if hasattr(self, 'app') and self.app is not None:
|
||||
# Dash 3.x: use app.run
|
||||
|
@@ -18,6 +18,7 @@ class DashboardLayoutManager:
|
||||
"""Create the main dashboard layout with dark theme"""
|
||||
return html.Div([
|
||||
self._create_header(),
|
||||
self._create_chained_inference_status(),
|
||||
self._create_interval_component(),
|
||||
self._create_main_content(),
|
||||
self._create_prediction_tracking_section() # NEW: Prediction tracking
|
||||
@@ -105,13 +106,27 @@ class DashboardLayoutManager:
|
||||
)
|
||||
], className="bg-dark p-2 mb-2")
|
||||
|
||||
def _create_chained_inference_status(self):
|
||||
"""Create chained inference status display"""
|
||||
return html.Div([
|
||||
html.H6("🔗 Chained Inference Status", className="text-warning mb-1"),
|
||||
html.Div(id="chained-inference-status", className="text-light small", children="Initializing...")
|
||||
], className="bg-dark p-2 mb-2")
|
||||
|
||||
def _create_interval_component(self):
|
||||
"""Create the auto-refresh interval component"""
|
||||
return dcc.Interval(
|
||||
return html.Div([
|
||||
dcc.Interval(
|
||||
id='interval-component',
|
||||
interval=1000, # Update every 1 second for maximum responsiveness
|
||||
n_intervals=0
|
||||
),
|
||||
dcc.Interval(
|
||||
id='minute-interval-component',
|
||||
interval=60000, # Update every 60 seconds for chained inference
|
||||
n_intervals=0
|
||||
)
|
||||
])
|
||||
|
||||
def _create_main_content(self):
|
||||
"""Create the main content area"""
|
||||
|
Reference in New Issue
Block a user