82 lines
2.1 KiB
Python
82 lines
2.1 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Quick smoke-test script for the CNN and RL training pipelines, using a small ETH/USDT dataset.
|
|
"""
|
|
|
|
import sys
|
|
from pathlib import Path
|
|
# Put this script's own directory first on sys.path so the local `core` and
# `training` packages imported below resolve when the script is run directly.
sys.path.insert(0, str(Path(__file__).parent))
|
|
|
|
from core.config import setup_logging
|
|
from core.data_provider import DataProvider
|
|
from training.cnn_trainer import CNNTrainer
|
|
from training.rl_trainer import RLTrainer
|
|
|
|
def test_cnn_training():
    """Run a short CNN training pass on a small ETH/USDT dataset.

    Builds a ``DataProvider`` for ETH/USDT on the 1m/5m/1h timeframes, shrinks
    the ``CNNTrainer`` settings (1000 samples, 5 epochs, batch size 32) so the
    run stays quick, trains, and prints the best validation accuracy and the
    total training time taken from the returned results dict.

    Returns:
        True once training has completed; any failure propagates as an
        exception to the caller.
    """
    print("\n=== Testing CNN Training ===")

    save_path = 'test_models/test_cnn.pt'
    # The checkpoint directory may not exist on a fresh checkout; create it up
    # front so saving cannot fail after training has already run.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)

    # Setup
    data_provider = DataProvider(['ETH/USDT'], ['1m', '5m', '1h'])
    trainer = CNNTrainer(data_provider)

    # Configure for a quick test: small dataset, few epochs.
    trainer.num_samples = 1000
    trainer.num_epochs = 5
    trainer.batch_size = 32

    # Train
    results = trainer.train(['ETH/USDT'], save_path=save_path)

    print("CNN Training completed!")
    print(f" Best accuracy: {results['best_val_accuracy']:.4f}")
    print(f" Training time: {results['total_time']:.2f}s")

    return True
|
|
|
|
def test_rl_training():
    """Run a short RL training pass on a small ETH/USDT dataset.

    Builds a ``DataProvider`` for ETH/USDT on the 1m/5m/1h timeframes, limits
    the ``RLTrainer`` to 10 episodes of at most 100 steps (evaluating every 5
    episodes) so the run stays quick, trains, and prints the best reward and
    best balance taken from the returned results dict.

    Returns:
        True once training has completed; any failure propagates as an
        exception to the caller.
    """
    print("\n=== Testing RL Training ===")

    save_path = 'test_models/test_rl.pt'
    # The checkpoint directory may not exist on a fresh checkout; create it up
    # front so saving cannot fail after training has already run.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)

    # Setup
    data_provider = DataProvider(['ETH/USDT'], ['1m', '5m', '1h'])
    trainer = RLTrainer(data_provider)

    # Configure for a quick test: few, short episodes.
    trainer.num_episodes = 10
    trainer.max_steps_per_episode = 100
    trainer.evaluation_frequency = 5

    # Train
    results = trainer.train(save_path=save_path)

    print("RL Training completed!")
    print(f" Best reward: {results['best_reward']:.4f}")
    # Label matches the key actually reported ('best_balance', not a final value).
    print(f" Best balance: ${results['best_balance']:.2f}")

    return True
|
|
|
|
def main():
    """Run the CNN and RL training smoke tests in order.

    Logging is configured first. The tests run sequentially and stop at the
    first failure; an exception is reported with its traceback.

    Returns:
        Process exit code: 0 if every test reported success, 1 otherwise.
    """
    import traceback

    setup_logging()

    tests = (
        ("CNN", test_cnn_training),
        ("RL", test_rl_training),
    )

    try:
        for label, run_test in tests:
            if not run_test():
                # A falsy return without an exception is also a failure; it
                # must not fall through to the "All tests passed!" message.
                print(f"\n❌ {label} training test failed")
                return 1
            print(f"✅ {label} training test passed!")

        print("\n✅ All tests passed!")

    except Exception as e:
        # Top-level boundary of the script: report and signal failure.
        print(f"\n❌ Test failed: {e}")
        traceback.print_exc()
        return 1

    return 0
|
|
|
|
# Script entry point: propagate main()'s return value as the process exit code.
if __name__ == "__main__":
    sys.exit(main())