gogo2/test_training.py
2025-05-24 02:42:11 +03:00

82 lines
2.1 KiB
Python

#!/usr/bin/env python3
"""
Quick test script for CNN and RL training pipelines
"""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from core.config import setup_logging
from core.data_provider import DataProvider
from training.cnn_trainer import CNNTrainer
from training.rl_trainer import RLTrainer
def test_cnn_training():
    """Run a shortened CNN training pass as a smoke test.

    Creates a DataProvider for ETH/USDT on the 1m, 5m and 1h timeframes,
    overrides the trainer's sample count, epoch count and batch size with
    small values, trains, and prints the ``best_val_accuracy`` and
    ``total_time`` entries of the returned results.

    Returns:
        bool: True once training and reporting finish without raising.
    """
    print("\n=== Testing CNN Training ===")

    # Build the provider/trainer pair under test.
    provider = DataProvider(['ETH/USDT'], ['1m', '5m', '1h'])
    cnn = CNNTrainer(provider)

    # Shrink the workload so the run stays quick.
    quick_settings = (
        ("num_samples", 1000),  # Small dataset
        ("num_epochs", 5),      # Few epochs
        ("batch_size", 32),
    )
    for attr_name, attr_value in quick_settings:
        setattr(cnn, attr_name, attr_value)

    outcome = cnn.train(['ETH/USDT'], save_path='test_models/test_cnn.pt')

    # Report line by line so partial output survives a missing result key.
    print("CNN Training completed!")
    best_accuracy = outcome['best_val_accuracy']
    print(" Best accuracy: {:.4f}".format(best_accuracy))
    elapsed = outcome['total_time']
    print(" Training time: {:.2f}s".format(elapsed))

    return True
def test_rl_training():
    """Run a shortened RL training pass as a smoke test.

    Creates a DataProvider for ETH/USDT on the 1m, 5m and 1h timeframes,
    overrides the trainer's episode count, per-episode step limit and
    evaluation frequency with small values, trains, and prints the
    ``best_reward`` and ``best_balance`` entries of the returned results.

    Returns:
        bool: True once training and reporting finish without raising.
    """
    print("\n=== Testing RL Training ===")

    # Build the provider/trainer pair under test.
    provider = DataProvider(['ETH/USDT'], ['1m', '5m', '1h'])
    agent_trainer = RLTrainer(provider)

    # Shrink the workload so the run stays quick.
    quick_settings = (
        ("num_episodes", 10),
        ("max_steps_per_episode", 100),
        ("evaluation_frequency", 5),
    )
    for attr_name, attr_value in quick_settings:
        setattr(agent_trainer, attr_name, attr_value)

    outcome = agent_trainer.train(save_path='test_models/test_rl.pt')

    # Report line by line so partial output survives a missing result key.
    print("RL Training completed!")
    top_reward = outcome['best_reward']
    print(" Best reward: {:.4f}".format(top_reward))
    # NOTE(review): the label says "Final balance" but the value printed is
    # the 'best_balance' entry -- confirm which one is intended.
    balance = outcome['best_balance']
    print(" Final balance: ${:.2f}".format(balance))

    return True
def main():
    """Run the CNN and RL training smoke tests and report the outcome.

    Logging is configured first, then the CNN test runs, followed by the RL
    test. Any exception raised by either test is reported together with its
    traceback.

    Returns:
        int: Process exit status -- 0 when both tests report success, 1 when
        a test raises or returns a falsy result.
    """
    setup_logging()

    try:
        # Test CNN
        cnn_ok = test_cnn_training()
        if cnn_ok:
            print("✅ CNN training test passed!")

        # Test RL
        rl_ok = test_rl_training()
        if rl_ok:
            print("✅ RL training test passed!")

        # Tie the overall verdict to the individual results; previously the
        # success banner and exit code 0 were produced even if a test
        # returned a falsy value.
        if not (cnn_ok and rl_ok):
            print("\n❌ Test failed: a training test did not report success")
            return 1

        print("\n✅ All tests passed!")

    except Exception as e:
        # Top-level boundary: turn any failure into a message, a traceback
        # and a non-zero exit status instead of an unhandled crash.
        print(f"\n❌ Test failed: {e}")
        import traceback
        traceback.print_exc()
        return 1

    return 0
if __name__ == "__main__":
    # Script entry point: main()'s return value becomes the process exit status.
    raise SystemExit(main())