""" Test script for StandardizedCNN This script tests the standardized CNN model with BaseDataInput format """ import sys import os sys.path.append(os.path.dirname(os.path.abspath(__file__))) import logging import torch from datetime import datetime from core.standardized_data_provider import StandardizedDataProvider from NN.models.standardized_cnn import StandardizedCNN # Set up logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) def test_standardized_cnn(): """Test the StandardizedCNN with BaseDataInput""" print("Testing StandardizedCNN with BaseDataInput...") # Initialize data provider symbols = ['ETH/USDT', 'BTC/USDT'] provider = StandardizedDataProvider(symbols=symbols) # Initialize CNN model cnn_model = StandardizedCNN( model_name="test_standardized_cnn_v1", confidence_threshold=0.6 ) print("✅ StandardizedCNN initialized") print(f" Model info: {cnn_model.get_model_info()}") # Test 1: Get BaseDataInput print("\n1. Testing BaseDataInput creation...") # Set mock current price for COB data provider.current_prices['ETHUSDT'] = 3000.0 provider.current_prices['BTCUSDT'] = 50000.0 base_input = provider.get_base_data_input('ETH/USDT') if base_input is None: print("⚠️ BaseDataInput is None - creating mock data for testing") # Create mock BaseDataInput for testing from core.data_models import BaseDataInput, OHLCVBar, COBData # Create mock OHLCV data mock_ohlcv = [] for i in range(300): bar = OHLCVBar( symbol='ETH/USDT', timestamp=datetime.now(), open=3000.0 + i, high=3010.0 + i, low=2990.0 + i, close=3005.0 + i, volume=1000.0, timeframe='1s' ) mock_ohlcv.append(bar) # Create mock COB data mock_cob = COBData( symbol='ETH/USDT', timestamp=datetime.now(), current_price=3000.0, bucket_size=1.0, price_buckets={3000.0 + i: {'bid_volume': 100, 'ask_volume': 100, 'total_volume': 200, 'imbalance': 0.0} for i in range(-20, 21)}, bid_ask_imbalance={3000.0 + i: 0.0 for i in range(-20, 21)}, volume_weighted_prices={3000.0 + i: 3000.0 + i for i in range(-20, 21)}, order_flow_metrics={} ) base_input = BaseDataInput( symbol='ETH/USDT', timestamp=datetime.now(), ohlcv_1s=mock_ohlcv, ohlcv_1m=mock_ohlcv, ohlcv_1h=mock_ohlcv, ohlcv_1d=mock_ohlcv, btc_ohlcv_1s=mock_ohlcv, cob_data=mock_cob ) print(f"✅ BaseDataInput available: {base_input.symbol}") print(f" Feature vector shape: {base_input.get_feature_vector().shape}") print(f" Validation: {'PASSED' if base_input.validate() else 'FAILED'}") # Test 2: CNN Inference print("\n2. Testing CNN inference with BaseDataInput...") try: model_output = cnn_model.predict_from_base_input(base_input) print("✅ CNN inference successful!") print(f" Model: {model_output.model_name} ({model_output.model_type})") print(f" Action: {model_output.predictions['action']}") print(f" Confidence: {model_output.confidence:.3f}") print(f" Probabilities: BUY={model_output.predictions['buy_probability']:.3f}, " f"SELL={model_output.predictions['sell_probability']:.3f}, " f"HOLD={model_output.predictions['hold_probability']:.3f}") print(f" Hidden states: {len(model_output.hidden_states)} layers") print(f" Metadata: {len(model_output.metadata)} fields") # Test hidden states for cross-model feeding if model_output.hidden_states: print(" Hidden state layers:") for key, value in model_output.hidden_states.items(): if isinstance(value, list): print(f" {key}: {len(value)} features") else: print(f" {key}: {type(value)}") except Exception as e: print(f"❌ CNN inference failed: {e}") import traceback traceback.print_exc() # Test 3: Integration with StandardizedDataProvider print("\n3. Testing integration with StandardizedDataProvider...") try: # Store the model output in the provider provider.store_model_output(model_output) # Retrieve it back stored_outputs = provider.get_model_outputs('ETH/USDT') if cnn_model.model_name in stored_outputs: print("✅ Model output storage and retrieval successful!") stored_output = stored_outputs[cnn_model.model_name] print(f" Stored action: {stored_output.predictions['action']}") print(f" Stored confidence: {stored_output.confidence:.3f}") else: print("❌ Model output storage failed") # Test cross-model feeding updated_base_input = provider.get_base_data_input('ETH/USDT') if updated_base_input and cnn_model.model_name in updated_base_input.last_predictions: print("✅ Cross-model feeding working!") print(f" CNN prediction available in BaseDataInput for other models") else: print("⚠️ Cross-model feeding not working as expected") except Exception as e: print(f"❌ Integration test failed: {e}") # Test 4: Training capabilities print("\n4. Testing training capabilities...") try: # Create mock training data training_inputs = [base_input] * 5 # Small batch training_targets = ['BUY', 'SELL', 'HOLD', 'BUY', 'HOLD'] # Create optimizer optimizer = torch.optim.Adam(cnn_model.parameters(), lr=0.001) # Perform training step loss = cnn_model.train_step(training_inputs, training_targets, optimizer) print(f"✅ Training step successful!") print(f" Training loss: {loss:.4f}") # Test evaluation eval_metrics = cnn_model.evaluate(training_inputs, training_targets) print(f" Evaluation metrics: {eval_metrics}") except Exception as e: print(f"❌ Training test failed: {e}") import traceback traceback.print_exc() # Test 5: Checkpoint management print("\n5. Testing checkpoint management...") try: # Save checkpoint checkpoint_path = "test_cache/cnn_checkpoint.pth" os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True) metadata = { 'training_loss': loss if 'loss' in locals() else 0.5, 'accuracy': eval_metrics.get('accuracy', 0.0) if 'eval_metrics' in locals() else 0.0, 'test_run': True } cnn_model.save_checkpoint(checkpoint_path, metadata) print("✅ Checkpoint saved successfully!") # Create new model and load checkpoint new_cnn = StandardizedCNN(model_name="loaded_cnn_v1") success = new_cnn.load_checkpoint(checkpoint_path) if success: print("✅ Checkpoint loaded successfully!") print(f" Loaded model info: {new_cnn.get_model_info()}") else: print("❌ Checkpoint loading failed") except Exception as e: print(f"❌ Checkpoint test failed: {e}") # Test 6: Performance and compatibility print("\n6. Testing performance and compatibility...") try: # Test inference speed import time start_time = time.time() for _ in range(10): _ = cnn_model.predict_from_base_input(base_input) end_time = time.time() avg_inference_time = (end_time - start_time) / 10 * 1000 # ms print(f"✅ Performance test completed!") print(f" Average inference time: {avg_inference_time:.2f} ms") # Test memory usage if torch.cuda.is_available(): memory_used = torch.cuda.memory_allocated() / 1024 / 1024 # MB print(f" GPU memory used: {memory_used:.2f} MB") # Test model size param_count = sum(p.numel() for p in cnn_model.parameters()) model_size_mb = param_count * 4 / 1024 / 1024 # Assuming float32 print(f" Model parameters: {param_count:,}") print(f" Estimated model size: {model_size_mb:.2f} MB") except Exception as e: print(f"❌ Performance test failed: {e}") print("\n✅ StandardizedCNN test completed!") print("\n🎯 Key achievements:") print("✓ Accepts standardized BaseDataInput format") print("✓ Processes COB+OHLCV data (300 frames multi-timeframe)") print("✓ Outputs BUY/SELL/HOLD with confidence scores") print("✓ Provides hidden states for cross-model feeding") print("✓ Integrates with ModelOutputManager") print("✓ Supports training and evaluation") print("✓ Checkpoint management for persistence") print("✓ Real-time inference capabilities") print("\n🚀 Ready for integration:") print("1. Can be used by orchestrator for decision making") print("2. Hidden states available for RL model cross-feeding") print("3. Outputs stored in standardized ModelOutput format") print("4. Compatible with checkpoint management system") print("5. Optimized for real-time trading inference") return cnn_model if __name__ == "__main__": test_standardized_cnn()