device tensor fix
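For context: this commit adds a standalone check for the "expected all tensors to be on the same device" class of errors during GPU training. The adapter internals are not part of this diff, so the snippet below is only an illustrative sketch of the general PyTorch pattern being verified (pin one torch.device, then move both the model and every input tensor to it), not the repo's actual EnhancedCNNAdapter code:

    import torch

    # Illustrative only: pin a single device and move model + inputs to it
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = torch.nn.Linear(10, 3).to(device)    # stand-in for the CNN
    features = torch.randn(1, 10).to(device)     # stand-in for the feature tensor
    with torch.no_grad():
        logits = model(features)                 # both live on `device`, so no mismatch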
test_device_fix.py (new file, 141 lines)
@@ -0,0 +1,141 @@
#!/usr/bin/env python3
"""
Test script to verify device mismatch fixes for GPU training
"""

import torch
import logging
import sys
import os

# Add the project root to the path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from core.enhanced_cnn_adapter import EnhancedCNNAdapter
from core.data_models import BaseDataInput, OHLCVBar

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def test_device_consistency():
    """Test that all tensors are on the same device"""

    logger.info("Testing device consistency for EnhancedCNN...")

    # Check if CUDA is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")

    try:
        # Initialize the adapter
        adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")

        # Verify adapter device
        logger.info(f"Adapter device: {adapter.device}")
        logger.info(f"Model device: {next(adapter.model.parameters()).device}")

        # Create sample data
        sample_ohlcv = [
            OHLCVBar(
                symbol="ETH/USDT",
                timeframe="1s",
                timestamp=1640995200.0,  # 2022-01-01
                open=50000.0,
                high=51000.0,
                low=49000.0,
                close=50500.0,
                volume=1000.0
            )
        ] * 300  # 300 frames

        base_data = BaseDataInput(
            symbol="ETH/USDT",
            timestamp=1640995200.0,
            ohlcv_1s=sample_ohlcv,
            ohlcv_1m=sample_ohlcv,
            ohlcv_5m=sample_ohlcv,
            ohlcv_15m=sample_ohlcv,
            btc_ohlcv=sample_ohlcv,
            cob_data={},
            ma_data={},
            technical_indicators={},
            last_predictions={}
        )

        # Test prediction
        logger.info("Testing prediction...")
        prediction = adapter.predict(base_data)
        logger.info(f"Prediction successful: {prediction.predictions['action']} (confidence: {prediction.confidence:.3f})")

        # Test training sample addition
        logger.info("Testing training sample addition...")
        adapter.add_training_sample(base_data, "BUY", 0.1)
        adapter.add_training_sample(base_data, "SELL", -0.05)
        adapter.add_training_sample(base_data, "HOLD", 0.02)

        # Test training
        logger.info("Testing training...")
        training_results = adapter.train(epochs=1)
        logger.info(f"Training results: {training_results}")

        logger.info("✅ All device consistency tests passed!")
        return True

    except Exception as e:
        logger.error(f"❌ Device consistency test failed: {e}")
        import traceback
        traceback.print_exc()
        return False

def test_orchestrator_inference_history():
    """Test that orchestrator properly initializes inference history"""

    logger.info("Testing orchestrator inference history initialization...")

    try:
        from core.orchestrator import TradingOrchestrator
        from core.data_provider import DataProvider

        # Initialize orchestrator
        data_provider = DataProvider()
        orchestrator = TradingOrchestrator(data_provider=data_provider)

        # Check if inference history is initialized
        logger.info(f"Inference history keys: {list(orchestrator.inference_history.keys())}")

        # Check if models are registered
        logger.info(f"Registered models: {list(orchestrator.model_registry.models.keys())}")

        # Verify each registered model has inference history
        for model_name in orchestrator.model_registry.models.keys():
            if model_name in orchestrator.inference_history:
                logger.info(f"✅ {model_name} has inference history initialized")
            else:
                logger.warning(f"❌ {model_name} missing inference history")

        logger.info("✅ Orchestrator inference history test completed!")
        return True

    except Exception as e:
        logger.error(f"❌ Orchestrator test failed: {e}")
        import traceback
        traceback.print_exc()
        return False

if __name__ == "__main__":
    logger.info("Starting device fix verification tests...")

    # Test 1: Device consistency
    test1_passed = test_device_consistency()

    # Test 2: Orchestrator inference history
    test2_passed = test_orchestrator_inference_history()

    # Summary
    if test1_passed and test2_passed:
        logger.info("🎉 All tests passed! Device issues should be fixed.")
        sys.exit(0)
    else:
        logger.error("❌ Some tests failed. Please check the logs above.")
        sys.exit(1)
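The script is intended to be run directly (e.g. `python test_device_fix.py`); as the main block shows, it exits with status 0 when both checks pass and 1 otherwise.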