#!/usr/bin/env python3
"""
Quick test script for CNN and RL training pipelines.

Runs a deliberately tiny CNN training job and RL training job on ETH/USDT so
both end-to-end pipelines can be smoke-tested quickly. Trained models are
written under ``test_models/`` relative to the current working directory.
"""

import sys
import traceback
from pathlib import Path

# Make the project packages (core, training) importable when this script is
# run directly, regardless of the current working directory.
sys.path.insert(0, str(Path(__file__).parent))

from core.config import setup_logging
from core.data_provider import DataProvider
from training.cnn_trainer import CNNTrainer
from training.rl_trainer import RLTrainer

# Symbol and timeframes shared by both smoke tests.
TEST_SYMBOL = 'ETH/USDT'
TEST_TIMEFRAMES = ['1m', '5m', '1h']


def _ensure_parent_dir(save_path):
    """Create the parent directory of ``save_path`` if missing and return ``save_path``.

    The trainers are handed a relative path such as ``test_models/x.pt``; the
    directory is created here in case the trainer does not create it itself.
    """
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    return save_path


def test_cnn_training():
    """Test CNN training with a small dataset.

    Returns:
        True once training finishes and its summary has been printed.
    Raises:
        Any exception raised by the data provider or trainer, or a KeyError
        if the results dict lacks 'best_val_accuracy' / 'total_time'.
    """
    print("\n=== Testing CNN Training ===")

    # Setup
    data_provider = DataProvider([TEST_SYMBOL], TEST_TIMEFRAMES)
    trainer = CNNTrainer(data_provider)

    # Configure for a quick test: small dataset, few epochs.
    trainer.num_samples = 1000
    trainer.num_epochs = 5
    trainer.batch_size = 32

    # Train
    save_path = _ensure_parent_dir('test_models/test_cnn.pt')
    results = trainer.train([TEST_SYMBOL], save_path=save_path)

    print("CNN Training completed!")
    print(f"  Best accuracy: {results['best_val_accuracy']:.4f}")
    print(f"  Training time: {results['total_time']:.2f}s")

    return True


def test_rl_training():
    """Test RL training with a small number of short episodes.

    Returns:
        True once training finishes and its summary has been printed.
    Raises:
        Any exception raised by the data provider or trainer, or a KeyError
        if the results dict lacks 'best_reward' / 'best_balance'.
    """
    print("\n=== Testing RL Training ===")

    # Setup
    data_provider = DataProvider([TEST_SYMBOL], TEST_TIMEFRAMES)
    trainer = RLTrainer(data_provider)

    # Configure for a quick test: few, short episodes.
    trainer.num_episodes = 10
    trainer.max_steps_per_episode = 100
    trainer.evaluation_frequency = 5

    # Train
    save_path = _ensure_parent_dir('test_models/test_rl.pt')
    results = trainer.train(save_path=save_path)

    print("RL Training completed!")
    print(f"  Best reward: {results['best_reward']:.4f}")
    # The value shown is the trainer's best balance, so label it as such.
    print(f"  Best balance: ${results['best_balance']:.2f}")

    return True


def main():
    """Run both smoke tests and return a process exit code.

    Returns:
        0 if every test reported success; 1 if any test returned a falsy
        value or raised an exception (the first exception aborts the run).
    """
    setup_logging()

    failed = []
    try:
        # Test CNN
        if test_cnn_training():
            print("✅ CNN training test passed!")
        else:
            failed.append("CNN")

        # Test RL
        if test_rl_training():
            print("✅ RL training test passed!")
        else:
            failed.append("RL")
    except Exception as e:
        # Top-level boundary: report with a full traceback and signal failure
        # through the exit code rather than crashing.
        print(f"\n❌ Test failed: {e}")
        traceback.print_exc()
        return 1

    # A test that returns a falsy value without raising is still a failure;
    # never report "All tests passed" in that case.
    if failed:
        print(f"\n❌ Test failed: {', '.join(failed)} training test did not pass")
        return 1

    print("\n✅ All tests passed!")
    return 0


if __name__ == "__main__":
    sys.exit(main())