#!/usr/bin/env python3 """ Test Training Fixes This script tests the fixes for CNN adapter and DQN training issues. """ import asyncio import time import numpy as np from datetime import datetime from core.orchestrator import TradingOrchestrator from core.data_provider import DataProvider async def test_training_fixes(): """Test the training fixes""" print("=== Testing Training Fixes ===") # Initialize orchestrator print("1. Initializing orchestrator...") data_provider = DataProvider() orchestrator = TradingOrchestrator(data_provider=data_provider) # Wait for initialization await asyncio.sleep(3) # Check CNN adapter initialization print("\n2. Checking CNN adapter initialization:") if hasattr(orchestrator, 'cnn_adapter') and orchestrator.cnn_adapter: print(" ✅ CNN adapter is properly initialized") print(f" CNN adapter type: {type(orchestrator.cnn_adapter)}") else: print(" ❌ CNN adapter is None or missing") # Check DQN agent initialization print("\n3. Checking DQN agent initialization:") if hasattr(orchestrator, 'rl_agent') and orchestrator.rl_agent: print(" ✅ DQN agent is properly initialized") print(f" DQN agent type: {type(orchestrator.rl_agent)}") if hasattr(orchestrator.rl_agent, 'policy_net'): print(" ✅ DQN policy network is available") else: print(" ❌ DQN policy network is missing") else: print(" ❌ DQN agent is None or missing") # Test CNN predictions print("\n4. Testing CNN predictions:") try: predictions = await orchestrator._get_all_predictions('ETH/USDT') cnn_predictions = [p for p in predictions if 'cnn' in p.model_name.lower()] if cnn_predictions: print(f" ✅ Got {len(cnn_predictions)} CNN predictions") for pred in cnn_predictions: print(f" CNN prediction: {pred.action} (confidence: {pred.confidence:.3f})") else: print(" ❌ No CNN predictions received") except Exception as e: print(f" ❌ CNN prediction failed: {e}") # Test training with validation print("\n5. Testing training with validation:") for i in range(3): print(f" Training iteration {i+1}/3...") # Create training records for different models training_records = [ { 'model_name': 'enhanced_cnn', 'model_input': np.random.randn(7850), 'prediction': {'action': 'BUY', 'confidence': 0.7}, 'symbol': 'ETH/USDT', 'timestamp': datetime.now() }, { 'model_name': 'dqn_agent', 'model_input': np.random.randn(7850), 'prediction': {'action': 'SELL', 'confidence': 0.8}, 'symbol': 'ETH/USDT', 'timestamp': datetime.now() } ] for record in training_records: try: success = await orchestrator._train_model_on_outcome( record, True, 0.5, 1.0 ) if success: print(f" ✅ Training succeeded for {record['model_name']}") else: print(f" ⚠️ Training failed for {record['model_name']}") except Exception as e: print(f" ❌ Training error for {record['model_name']}: {e}") await asyncio.sleep(1) # Show final statistics print("\n6. Final model statistics:") orchestrator.log_model_statistics(detailed=True) # Check for overfitting warnings print("\n7. Checking for training quality:") summary = orchestrator.get_model_statistics_summary() for model_name, stats in summary.items(): if stats['total_trainings'] > 0: print(f" {model_name}: {stats['total_trainings']} trainings, " f"avg time: {stats['average_training_time_ms']:.1f}ms") if stats['current_loss'] is not None: if stats['current_loss'] < 0.001: print(f" ⚠️ {model_name} has very low loss ({stats['current_loss']:.6f}) - check for overfitting") else: print(f" ✅ {model_name} has reasonable loss ({stats['current_loss']:.6f})") print("\n✅ Training fixes test completed!") if __name__ == "__main__": asyncio.run(test_training_fixes())