#!/usr/bin/env python3
"""
Test Model Training Implementation

This script tests the improved model training functionality.

It collects predictions from a ``TradingOrchestrator``, feeds simulated
outcomes back through ``_train_model_on_outcome``, and prints the model
statistics before and after training. Training failures are counted; when
run as a script, any failure makes the process exit with status 1.
"""

import asyncio
import sys
import time
import numpy as np
from datetime import datetime

from core.orchestrator import TradingOrchestrator
from core.data_provider import DataProvider

# Symbol used for every prediction request and training record in this script.
SYMBOL = 'ETH/USDT'
# Length of the random vector passed as ``model_input``. The value is the
# literal the script has always used; presumably it matches the orchestrator's
# feature size -- TODO confirm against TradingOrchestrator.
MODEL_INPUT_SIZE = 7850
# Number of prediction batches collected in step 3.
PREDICTION_BATCHES = 3
# Maximum number of collected predictions trained on in step 5.
MAX_SIMULATED_TRAININGS = 6


def _make_training_record(model_name, prediction, symbol, timestamp):
    """Build a training record with a random (simulated) model input.

    Args:
        model_name: Name of the model the record belongs to.
        prediction: Dict with the ``'action'`` and ``'confidence'`` keys.
        symbol: Trading symbol the prediction was made for.
        timestamp: Timestamp of the prediction.

    Returns:
        dict: Record in the shape passed to ``_train_model_on_outcome``.
    """
    return {
        'model_name': model_name,
        'model_input': np.random.randn(MODEL_INPUT_SIZE),  # Simulated model input
        'prediction': prediction,
        'symbol': symbol,
        'timestamp': timestamp,
    }


async def _train_and_report(orchestrator, record, was_correct, price_change_pct,
                            reward, ok_message, fail_message):
    """Run one ``_train_model_on_outcome`` call and print its result.

    The exception is printed rather than propagated, so the remaining
    cases still run; the caller uses the return value to count failures.

    Returns:
        bool: True if training completed, False if it raised.
    """
    try:
        await orchestrator._train_model_on_outcome(
            record, was_correct, price_change_pct, reward
        )
    except Exception as e:  # Report boundary: any training error is a failed case.
        print(f"{fail_message}: {e}")
        return False
    print(ok_message)
    return True


async def test_model_training():
    """Test the improved model training system.

    Returns:
        int: Number of training calls that raised (0 means every call passed).
    """
    print("=== Testing Model Training System ===")
    failures = 0

    # Initialize orchestrator
    print("1. Initializing orchestrator...")
    data_provider = DataProvider()
    orchestrator = TradingOrchestrator(data_provider=data_provider)

    # Wait for initialization
    await asyncio.sleep(3)

    # Show initial model statistics
    print("\n2. Initial model statistics:")
    orchestrator.log_model_statistics()

    # Run predictions to generate training data
    print("\n3. Running predictions to generate training data...")
    predictions_data = []
    for i in range(PREDICTION_BATCHES):
        print(f" Running prediction batch {i+1}/{PREDICTION_BATCHES}...")
        predictions = await orchestrator._get_all_predictions(SYMBOL)
        print(f" Got {len(predictions)} predictions")

        # Store prediction data for training simulation
        predictions_data.extend(
            {
                'model_name': pred.model_name,
                'prediction': {
                    'action': pred.action,
                    'confidence': pred.confidence,
                },
                'timestamp': pred.timestamp,
                'symbol': SYMBOL,
            }
            for pred in predictions
        )
        await asyncio.sleep(1)

    print(f"\n4. Collected {len(predictions_data)} predictions for training")

    # Simulate training with different outcomes
    print("\n5. Testing training with simulated outcomes...")
    for i, pred_data in enumerate(predictions_data[:MAX_SIMULATED_TRAININGS]):
        # Simulate market outcome: alternate between correct and incorrect
        was_correct = i % 2 == 0
        price_change_pct = 0.5 if was_correct else -0.3
        sophisticated_reward = 1.0 if was_correct else -0.5

        model_name = pred_data['model_name']
        training_record = _make_training_record(
            model_name,
            pred_data['prediction'],
            pred_data['symbol'],
            pred_data['timestamp'],
        )

        print(f" Training {model_name}: "
              f"action={pred_data['prediction']['action']}, "
              f"correct={was_correct}, reward={sophisticated_reward}")

        if not await _train_and_report(
            orchestrator, training_record, was_correct, price_change_pct,
            sophisticated_reward,
            f" ✅ Training completed for {model_name}",
            f" ❌ Training failed for {model_name}",
        ):
            failures += 1

    # Show updated statistics
    print("\n6. Updated model statistics after training:")
    orchestrator.log_model_statistics(detailed=True)

    # Test specific model training methods (only for models the orchestrator tracks)
    print("\n7. Testing specific model training methods...")
    specific_cases = (
        # (model name, intro line, short label, prediction, was_correct, price change %, reward)
        ('dqn_agent', " Testing DQN agent training...", "DQN",
         {'action': 'BUY', 'confidence': 0.8}, True, 0.5, 1.0),
        ('enhanced_cnn', " Testing CNN model training...", "CNN",
         {'action': 'SELL', 'confidence': 0.6}, False, -0.3, -0.5),
    )
    for (model_name, intro, label, prediction,
         was_correct, price_change_pct, reward) in specific_cases:
        if model_name not in orchestrator.model_statistics:
            continue
        print(intro)
        record = _make_training_record(model_name, prediction, SYMBOL, datetime.now())
        if not await _train_and_report(
            orchestrator, record, was_correct, price_change_pct, reward,
            f" ✅ {label} training test passed",
            f" ❌ {label} training test failed",
        ):
            failures += 1

    # Show final statistics
    print("\n8. Final model statistics:")
    summary = orchestrator.get_model_statistics_summary()
    for model_name, stats in summary.items():
        print(f" {model_name}:")
        print(f" Inferences: {stats['total_inferences']}")
        print(f" Rate: {stats['inference_rate_per_minute']:.1f}/min")
        print(f" Current loss: {stats['current_loss']}")
        print(f" Last prediction: {stats['last_prediction']} ({stats['last_confidence']})")

    print("\n✅ Model training test completed!")
    if failures:
        # Previously failures were only printed and the script still exited 0.
        print(f"❌ {failures} training call(s) failed")
    return failures


if __name__ == "__main__":
    # Non-zero exit status when any training call failed, so CI can detect it.
    sys.exit(1 if asyncio.run(test_model_training()) else 0)