#!/usr/bin/env python3
"""
Quick CNN training test.

Runs a deliberately tiny CNNTrainer job (few samples, two epochs) on
ETH/USDT so the CNN training pipeline can be smoke-tested end to end
without a full training run.
"""

import sys
from pathlib import Path

# Make the project packages (core, training) importable when this script is
# run directly. resolve() keeps the inserted entry absolute even when __file__
# is relative, so a later change of working directory cannot break imports.
sys.path.insert(0, str(Path(__file__).resolve().parent))

# These imports must come after the sys.path tweak above.
from core.config import setup_logging  # noqa: E402
from core.data_provider import DataProvider  # noqa: E402
from training.cnn_trainer import CNNTrainer  # noqa: E402
|
|
|
|
def _format_metric(results, key, spec='', unit=''):
    """Return ``results[key]`` formatted with *spec* plus *unit*, or ``'n/a'``.

    Printing the summary must not turn a finished training run into a
    reported failure. A missing key, a None value, or a result object that
    cannot be indexed therefore falls back to ``'n/a'`` instead of raising.
    """
    try:
        value = results[key]
    except (KeyError, IndexError, TypeError):
        return 'n/a'
    if value is None:
        return 'n/a'
    return f"{format(value, spec)}{unit}"


def main():
    """Run a quick CNN training smoke test and print a summary.

    Trains a CNNTrainer on ETH/USDT with the 1m/5m/1h timeframes, using a
    deliberately tiny configuration (500 samples, 2 epochs, batch size 16).
    The model is saved to ``test_models/quick_cnn.pt``.

    Returns:
        0 if training completed. 1 if preparing the output directory or
        ``trainer.train`` raised; the traceback is printed in that case.
        The value is meant to be used as the process exit code.
    """
    setup_logging()

    print("Setting up CNN training test...")

    # Single source of truth for the run's market configuration.
    symbols = ['ETH/USDT']
    timeframes = ['1m', '5m', '1h']  # Skip 1s for now
    save_path = 'test_models/quick_cnn.pt'

    # Setup. Each consumer gets its own list copy so neither can mutate the
    # other's configuration.
    data_provider = DataProvider(list(symbols), list(timeframes))
    trainer = CNNTrainer(data_provider)

    # Configure for quick test
    trainer.num_samples = 500  # Very small dataset
    trainer.num_epochs = 2  # Just 2 epochs
    trainer.batch_size = 16
    trainer.timeframes = list(timeframes)
    # Derived from the list so the count can never disagree with it.
    trainer.n_timeframes = len(timeframes)

    print("Configuration:")
    print(f" Samples: {trainer.num_samples}")
    print(f" Epochs: {trainer.num_epochs}")
    print(f" Batch size: {trainer.batch_size}")
    print(f" Timeframes: {trainer.timeframes}")

    # Train. Only the work that can genuinely fail is guarded, so a problem
    # while printing the summary is not misreported as a training failure.
    try:
        # Nothing here shows the trainer creating the checkpoint directory,
        # so make sure it exists before the model is saved into it.
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        results = trainer.train(list(symbols), save_path=save_path)
    except Exception as e:  # top-level boundary of a script: report and signal failure
        import traceback

        print(f"\n❌ Training failed: {e}")
        traceback.print_exc()
        return 1

    print("\n✅ CNN Training completed!")
    print(f" Best accuracy: {_format_metric(results, 'best_val_accuracy', '.4f')}")
    print(f" Total epochs: {_format_metric(results, 'total_epochs')}")
    print(f" Training time: {_format_metric(results, 'total_time', '.2f', 's')}")

    return 0
|
|
|
|
if __name__ == "__main__":
    # main() returns 0 on success and 1 on training failure; pass it on as
    # the process exit status.
    sys.exit(main())