#!/usr/bin/env python3
"""
Test script to verify device mismatch fixes for GPU training
"""

import torch
import logging
import sys
import os

# Make this file's directory importable so the project packages (NN, core)
# resolve when the script is run directly; must precede the project imports.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from NN.models.enhanced_cnn import EnhancedCNN
from core.data_models import BaseDataInput, OHLCVBar

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def test_device_consistency():
    """Exercise the CNN adapter end to end to surface device-mismatch errors.

    Builds an adapter, logs the adapter and model devices, then runs one
    prediction, queues three training samples and trains for one epoch on
    synthetic data. The devices are only logged, not compared: the test
    relies on any mismatch raising inside predict/train.

    Returns:
        bool: True if every step completed without raising, False otherwise
        (the exception and its traceback are logged).
    """
    logger.info("Testing device consistency for EnhancedCNN...")

    # Check if CUDA is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")

    try:
        # NOTE(review): the file-level imports only bring in EnhancedCNN, which
        # left EnhancedCNNAdapter undefined (NameError on every run). The
        # adapter is presumed to live in core.enhanced_cnn_adapter -- confirm
        # the module path. Importing inside the try keeps a wrong path a
        # reported test failure rather than a crash at import time.
        from core.enhanced_cnn_adapter import EnhancedCNNAdapter

        # Initialize the adapter
        adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")

        # Verify adapter device
        logger.info(f"Adapter device: {adapter.device}")
        logger.info(f"Model device: {next(adapter.model.parameters()).device}")

        # Create sample data: the list holds 300 references to one bar object,
        # and the same list is reused for every timeframe below.
        sample_ohlcv = [
            OHLCVBar(
                symbol="ETH/USDT",
                timeframe="1s",
                timestamp=1640995200.0,  # 2022-01-01
                open=50000.0,
                high=51000.0,
                low=49000.0,
                close=50500.0,
                volume=1000.0
            )
        ] * 300  # 300 frames

        base_data = BaseDataInput(
            symbol="ETH/USDT",
            timestamp=1640995200.0,
            ohlcv_1s=sample_ohlcv,
            ohlcv_1m=sample_ohlcv,
            ohlcv_5m=sample_ohlcv,
            ohlcv_15m=sample_ohlcv,
            btc_ohlcv=sample_ohlcv,
            cob_data={},
            ma_data={},
            technical_indicators={},
            last_predictions={}
        )

        # Test prediction
        logger.info("Testing prediction...")
        prediction = adapter.predict(base_data)
        logger.info(f"Prediction successful: {prediction.predictions['action']} (confidence: {prediction.confidence:.3f})")

        # Test training sample addition
        logger.info("Testing training sample addition...")
        adapter.add_training_sample(base_data, "BUY", 0.1)
        adapter.add_training_sample(base_data, "SELL", -0.05)
        adapter.add_training_sample(base_data, "HOLD", 0.02)

        # Test training
        logger.info("Testing training...")
        training_results = adapter.train(epochs=1)
        logger.info(f"Training results: {training_results}")

        logger.info("✅ All device consistency tests passed!")
        return True

    except Exception as e:
        # Top-level boundary of the test: log message plus traceback, report failure.
        logger.exception(f"❌ Device consistency test failed: {e}")
        return False


def test_orchestrator_inference_history():
    """Check that every registered model has an inference-history entry.

    Constructs a TradingOrchestrator backed by a default DataProvider and
    verifies that each key of ``orchestrator.model_registry.models`` is also
    a key of ``orchestrator.inference_history``.

    Returns:
        bool: True if all registered models have an entry, False if any is
        missing or if setup raised (the exception is logged with traceback).
    """
    logger.info("Testing orchestrator inference history initialization...")

    try:
        from core.orchestrator import TradingOrchestrator
        from core.data_provider import DataProvider

        # Initialize orchestrator
        data_provider = DataProvider()
        orchestrator = TradingOrchestrator(data_provider=data_provider)

        # Check if inference history is initialized
        logger.info(f"Inference history keys: {list(orchestrator.inference_history.keys())}")

        # Check if models are registered
        registered_models = list(orchestrator.model_registry.models.keys())
        logger.info(f"Registered models: {registered_models}")

        # Verify each registered model has inference history
        missing_models = []
        for model_name in registered_models:
            if model_name in orchestrator.inference_history:
                logger.info(f"✅ {model_name} has inference history initialized")
            else:
                logger.warning(f"❌ {model_name} missing inference history")
                missing_models.append(model_name)

        # Previously the function returned True even when models were missing,
        # so the summary could claim success after a failed check.
        if missing_models:
            logger.error(f"❌ Orchestrator test failed: no inference history for {missing_models}")
            return False

        logger.info("✅ Orchestrator inference history test completed!")
        return True

    except Exception as e:
        # Top-level boundary of the test: log message plus traceback, report failure.
        logger.exception(f"❌ Orchestrator test failed: {e}")
        return False


def main():
    """Run both verification tests and return a process exit code.

    Both tests always run (no short-circuit) so that each one's log output
    is produced even when the other fails.

    Returns:
        int: 0 if both tests passed, 1 otherwise.
    """
    logger.info("Starting device fix verification tests...")

    # Test 1: Device consistency
    test1_passed = test_device_consistency()

    # Test 2: Orchestrator inference history
    test2_passed = test_orchestrator_inference_history()

    # Summary
    if test1_passed and test2_passed:
        logger.info("🎉 All tests passed! Device issues should be fixed.")
        return 0

    logger.error("❌ Some tests failed. Please check the logs above.")
    return 1


if __name__ == "__main__":
    sys.exit(main())