#!/usr/bin/env python3
"""
Test script to verify device handling and training sample population fixes.

Run directly (``python <this file>``). The process exit status is 0 when
every check passes and 1 otherwise, so the script can gate a CI step.
"""

import asyncio
import logging
import sys
from datetime import datetime

import torch

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def test_device_handling() -> bool:
    """Exercise the CNN adapter end to end on the locally available device.

    Steps: report CUDA availability, build an ``EnhancedCNNAdapter``, run one
    prediction on an empty ``BaseDataInput``, add three training samples and,
    if at least two samples are stored, run a single training epoch.

    Returns:
        True when every step completed without raising, False otherwise.
        Exceptions are logged with their traceback and never propagated.
    """
    try:
        logger.info("Testing device handling...")

        # Test 1: Check CUDA availability
        cuda_available = torch.cuda.is_available()
        device = torch.device("cuda" if cuda_available else "cpu")
        logger.info(f"CUDA available: {cuda_available}")
        logger.info(f"Using device: {device}")

        # Test 2: Initialize CNN adapter
        # Project imports stay local so a missing/broken package is reported
        # as a failed check instead of crashing the whole script at import.
        from core.enhanced_cnn_adapter import EnhancedCNNAdapter

        logger.info("Initializing CNN adapter...")
        cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")

        logger.info(f"CNN adapter device: {cnn_adapter.device}")
        # NOTE(review): assumes the project's model class exposes `.device`;
        # a plain torch.nn.Module does not - confirm against the model class.
        logger.info(f"CNN model device: {cnn_adapter.model.device}")

        # Test 3: Create test data
        from core.data_models import BaseDataInput

        logger.info("Creating test BaseDataInput...")
        base_data = BaseDataInput(
            symbol="ETH/USDT",
            timestamp=datetime.now(),
            ohlcv_1s=[],
            ohlcv_1m=[],
            ohlcv_1h=[],
            ohlcv_1d=[],
            btc_ohlcv_1s=[],
            cob_data=None,
            technical_indicators={},
            last_predictions={},
        )

        # Test 4: Make prediction (this should not cause device mismatch)
        logger.info("Making prediction...")
        prediction = cnn_adapter.predict(base_data)
        logger.info(f"Prediction successful: {prediction.predictions['action']}")
        logger.info(f"Confidence: {prediction.confidence:.4f}")

        # Test 5: Add training samples
        logger.info("Adding training samples...")
        cnn_adapter.add_training_sample(base_data, "BUY", 0.1)
        cnn_adapter.add_training_sample(base_data, "SELL", -0.05)
        cnn_adapter.add_training_sample(base_data, "HOLD", 0.02)
        logger.info(f"Training samples added: {len(cnn_adapter.training_data)}")

        # Test 6: Try training if we have enough samples
        if len(cnn_adapter.training_data) >= 2:
            logger.info("Attempting training...")
            training_results = cnn_adapter.train(epochs=1)
            logger.info(f"Training results: {training_results}")
        else:
            logger.info("Not enough samples for training")

        logger.info("✅ Device handling test passed!")
        return True

    except Exception as e:
        # logger.exception records the traceback through the configured
        # handler, keeping it together with the rest of the log output.
        logger.exception(f"❌ Device handling test failed: {e}")
        return False


async def test_orchestrator_training() -> bool:
    """Check that the orchestrator exposes a CNN adapter and can make a decision.

    Steps: build a ``StandardizedDataProvider`` and a ``TradingOrchestrator``,
    verify ``orchestrator.cnn_adapter`` is present, request one trading
    decision for ETH/USDT and log the sizes of the inference history.

    Returns:
        False when the CNN adapter is missing or any step raises; True
        otherwise. A missing decision only produces a warning and does not
        fail the check.
    """
    try:
        logger.info("Testing orchestrator training integration...")

        # Test 1: Initialize orchestrator
        from core.orchestrator import TradingOrchestrator
        from core.standardized_data_provider import StandardizedDataProvider

        logger.info("Initializing data provider...")
        data_provider = StandardizedDataProvider()

        logger.info("Initializing orchestrator...")
        orchestrator = TradingOrchestrator(data_provider=data_provider)

        # Test 2: Check if CNN adapter is available
        if hasattr(orchestrator, 'cnn_adapter') and orchestrator.cnn_adapter:
            logger.info("✅ CNN adapter available in orchestrator")
            logger.info(f"Initial training samples: {len(orchestrator.cnn_adapter.training_data)}")
        else:
            logger.warning("⚠️ CNN adapter not available in orchestrator")
            return False

        # Test 3: Make a trading decision (this should add training samples)
        logger.info("Making trading decision...")
        decision = await orchestrator.make_trading_decision("ETH/USDT")

        if decision:
            logger.info(f"Decision: {decision.action} (confidence: {decision.confidence:.4f})")
            logger.info(f"Training samples after decision: {len(orchestrator.cnn_adapter.training_data)}")
        else:
            logger.warning("No decision made")

        # Test 4: Check inference history
        logger.info(f"Inference history keys: {list(orchestrator.inference_history.keys())}")
        for model_name, history in orchestrator.inference_history.items():
            logger.info(f"  {model_name}: {len(history)} records")

        logger.info("✅ Orchestrator training test passed!")
        return True

    except Exception as e:
        logger.exception(f"❌ Orchestrator training test failed: {e}")
        return False


# These checks report success through their return value and swallow every
# exception, so under pytest they could never fail properly (and a non-None
# return is rejected by recent pytest versions). Opt them out of collection;
# they are meant to be driven by main().
test_device_handling.__test__ = False
test_orchestrator_training.__test__ = False


async def main() -> bool:
    """Run all checks, log a summary, and report the overall outcome.

    Both checks always run, even if the first one fails, so the summary
    covers everything.

    Returns:
        True when both checks passed, False otherwise.
    """
    logger.info("Starting device and training fix tests...")

    # Test 1: Device handling
    test1_passed = test_device_handling()

    # Test 2: Orchestrator training
    test2_passed = await test_orchestrator_training()

    # Summary
    logger.info("\n" + "=" * 50)
    logger.info("TEST SUMMARY:")
    logger.info(f"Device handling: {'✅ PASSED' if test1_passed else '❌ FAILED'}")
    logger.info(f"Orchestrator training: {'✅ PASSED' if test2_passed else '❌ FAILED'}")

    all_passed = bool(test1_passed and test2_passed)
    if all_passed:
        logger.info("🎉 All tests passed! Device and training issues should be fixed.")
    else:
        logger.error("❌ Some tests failed. Please check the logs above.")
    return all_passed


if __name__ == "__main__":
    # Propagate the outcome to the shell so failures are visible to CI.
    sys.exit(0 if asyncio.run(main()) else 1)