""" Test Enhanced CNN Adapter This script tests the EnhancedCNNAdapter with standardized input format. """ import logging import time from datetime import datetime from core.standardized_data_provider import StandardizedDataProvider from core.enhanced_cnn_adapter import EnhancedCNNAdapter from core.data_models import create_model_output # Configure logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) def test_cnn_adapter(): """Test the EnhancedCNNAdapter with standardized input format""" try: # Initialize data provider symbols = ['ETH/USDT', 'BTC/USDT'] timeframes = ['1s', '1m', '1h', '1d'] data_provider = StandardizedDataProvider(symbols=symbols, timeframes=timeframes) # Initialize CNN adapter cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn") # Load best checkpoint if available cnn_adapter.load_best_checkpoint() # Get standardized input data logger.info("Getting standardized input data...") base_data = data_provider.get_base_data_input('ETH/USDT') if base_data is None: logger.error("Failed to get base data input") return # Make prediction logger.info("Making prediction...") model_output = cnn_adapter.predict(base_data) # Log prediction logger.info(f"Prediction: {model_output.predictions['action']} with confidence {model_output.confidence:.4f}") # Store model output data_provider.store_model_output(model_output) # Add training sample (simulated) logger.info("Adding training sample...") cnn_adapter.add_training_sample(base_data, 'BUY', 0.05) # Train model logger.info("Training model...") metrics = cnn_adapter.train(epochs=1) # Log training metrics logger.info(f"Training metrics: {metrics}") # Make another prediction logger.info("Making another prediction...") model_output = cnn_adapter.predict(base_data) # Log prediction logger.info(f"Prediction: {model_output.predictions['action']} with confidence {model_output.confidence:.4f}") # Test model output manager logger.info("Testing model output manager...") output_manager = data_provider.get_model_output_manager() # Get current outputs current_outputs = output_manager.get_all_current_outputs('ETH/USDT') logger.info(f"Current outputs: {len(current_outputs)} models") # Evaluate model performance metrics = output_manager.evaluate_model_performance('ETH/USDT', 'enhanced_cnn_v1') logger.info(f"Performance metrics: {metrics}") logger.info("Test completed successfully") except Exception as e: logger.error(f"Error in test: {e}", exc_info=True) if __name__ == "__main__": test_cnn_adapter()