141 lines
4.7 KiB
Python
141 lines
4.7 KiB
Python
#!/usr/bin/env python3
"""
Test script to verify device mismatch fixes for GPU training.

Runs two smoke checks against the project and exits with status 0 when both
report success, 1 otherwise:

1. ``test_device_consistency`` drives an ``EnhancedCNNAdapter`` through
   predict / add_training_sample / train on synthetic OHLCV data.
2. ``test_orchestrator_inference_history`` inspects which registered
   ``TradingOrchestrator`` models have an inference-history entry.
"""

import logging
import os
import sys

import torch

# Make the project packages (``NN``, ``core``) importable regardless of the
# current working directory. Insert at the front (rather than append) so the
# local packages take precedence over any same-named installed package, and
# skip the insert when the directory is already on the path.
_PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)

# Project imports must come after the sys.path adjustment above.
from NN.models.enhanced_cnn import EnhancedCNN
from core.data_models import BaseDataInput, OHLCVBar

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def _devices_match(first, second):
    """Return True if two device specs refer to the same device.

    Accepts anything ``torch.device`` accepts (e.g. ``"cuda"``, ``"cuda:0"`` or
    a ``torch.device``). The device types must agree; indices are compared only
    when both sides specify one, so ``"cuda"`` matches ``"cuda:0"``.
    """
    a, b = torch.device(first), torch.device(second)
    if a.type != b.type:
        return False
    return a.index is None or b.index is None or a.index == b.index


def _build_sample_base_data(num_frames=300):
    """Build a synthetic ``BaseDataInput`` for the adapter smoke test.

    Every OHLCV series (1s, 1m, 5m, 15m and the BTC reference) is the same
    list of ``num_frames`` references to one ``OHLCVBar``; the auxiliary dicts
    (COB, MA, indicators, last predictions) are left empty.
    """
    bar = OHLCVBar(
        symbol="ETH/USDT",
        timeframe="1s",
        timestamp=1640995200.0,  # 2022-01-01 00:00:00 UTC
        open=50000.0,
        high=51000.0,
        low=49000.0,
        close=50500.0,
        volume=1000.0,
    )
    # List multiplication repeats the same bar object num_frames times.
    sample_ohlcv = [bar] * num_frames

    return BaseDataInput(
        symbol="ETH/USDT",
        timestamp=1640995200.0,
        ohlcv_1s=sample_ohlcv,
        ohlcv_1m=sample_ohlcv,
        ohlcv_5m=sample_ohlcv,
        ohlcv_15m=sample_ohlcv,
        btc_ohlcv=sample_ohlcv,
        cob_data={},
        ma_data={},
        technical_indicators={},
        last_predictions={},
    )


def test_device_consistency():
    """Test that all tensors are on the same device.

    Builds an ``EnhancedCNNAdapter``, verifies that the adapter's reported
    device matches the device of the model's parameters, then runs
    ``predict``, ``add_training_sample`` (BUY/SELL/HOLD) and ``train`` on
    synthetic data. Any exception, including a device mismatch, is logged
    together with its traceback.

    Returns:
        True if every step completed, False otherwise.
    """
    logger.info("Testing device consistency for EnhancedCNN...")

    # Check if CUDA is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")

    try:
        # EnhancedCNNAdapter is not imported at module level; import it here so
        # a missing/moved module is reported as a test failure rather than a
        # NameError. NOTE(review): module path assumed -- confirm it matches
        # where the adapter lives in the project.
        from core.enhanced_cnn_adapter import EnhancedCNNAdapter

        # Initialize the adapter
        adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")

        # Verify adapter device against where the weights actually live
        model_device = next(adapter.model.parameters()).device
        logger.info(f"Adapter device: {adapter.device}")
        logger.info(f"Model device: {model_device}")
        if not _devices_match(adapter.device, model_device):
            raise RuntimeError(
                f"Adapter device {adapter.device} does not match "
                f"model parameter device {model_device}"
            )

        base_data = _build_sample_base_data()

        # Test prediction
        logger.info("Testing prediction...")
        prediction = adapter.predict(base_data)
        logger.info(f"Prediction successful: {prediction.predictions['action']} (confidence: {prediction.confidence:.3f})")

        # Test training sample addition
        logger.info("Testing training sample addition...")
        adapter.add_training_sample(base_data, "BUY", 0.1)
        adapter.add_training_sample(base_data, "SELL", -0.05)
        adapter.add_training_sample(base_data, "HOLD", 0.02)

        # Test training
        logger.info("Testing training...")
        training_results = adapter.train(epochs=1)
        logger.info(f"Training results: {training_results}")

        logger.info("✅ All device consistency tests passed!")
        return True

    except Exception as e:
        # logger.exception records the traceback through the logging system.
        logger.exception(f"❌ Device consistency test failed: {e}")
        return False


def test_orchestrator_inference_history():
    """Test that orchestrator properly initializes inference history.

    Creates a ``DataProvider`` and a ``TradingOrchestrator``, then checks
    that every model name in ``orchestrator.model_registry.models`` is also a
    key of ``orchestrator.inference_history``.

    Returns:
        True if every registered model has an inference-history entry;
        False if any is missing or if setup raises.
    """
    logger.info("Testing orchestrator inference history initialization...")

    try:
        from core.orchestrator import TradingOrchestrator
        from core.data_provider import DataProvider

        # Initialize orchestrator
        data_provider = DataProvider()
        orchestrator = TradingOrchestrator(data_provider=data_provider)

        history = orchestrator.inference_history
        registered = list(orchestrator.model_registry.models.keys())
        logger.info(f"Inference history keys: {list(history.keys())}")
        logger.info(f"Registered models: {registered}")

        # Verify each registered model has inference history; previously a
        # missing entry was only logged and the test still returned True.
        missing = []
        for model_name in registered:
            if model_name in history:
                logger.info(f"✅ {model_name} has inference history initialized")
            else:
                logger.warning(f"❌ {model_name} missing inference history")
                missing.append(model_name)

        if missing:
            logger.error(
                f"❌ Orchestrator test failed: {len(missing)} model(s) "
                f"missing inference history: {missing}"
            )
            return False

        logger.info("✅ Orchestrator inference history test completed!")
        return True

    except Exception as e:
        # logger.exception records the traceback through the logging system.
        logger.exception(f"❌ Orchestrator test failed: {e}")
        return False


if __name__ == "__main__":
    logger.info("Starting device fix verification tests...")

    # Run both checks unconditionally (a tuple evaluates every element), so a
    # failure in the first does not hide the second one's log output.
    outcomes = (
        test_device_consistency(),              # Test 1: device consistency
        test_orchestrator_inference_history(),  # Test 2: inference history
    )

    # Summary: non-zero exit status if any check reported failure.
    if not all(outcomes):
        logger.error("❌ Some tests failed. Please check the logs above.")
        sys.exit(1)

    logger.info("🎉 All tests passed! Device issues should be fixed.")
    sys.exit(0)