# gogo2/test_cnn_only.py
# 2025-05-24 02:42:11 +03:00
# 54 lines, 1.5 KiB, Python

#!/usr/bin/env python3
"""
Quick CNN training test
"""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from core.config import setup_logging
from core.data_provider import DataProvider
from training.cnn_trainer import CNNTrainer
def main():
    """Run a quick, small-scale CNN training smoke test.

    Builds a DataProvider and CNNTrainer for ETH/USDT and shrinks the trainer
    configuration (sample count, epochs, batch size, timeframes) so the run
    finishes quickly. It then trains, saving to
    ``test_models/quick_cnn.pt``, and prints a short summary of the returned
    results.

    Returns:
        int: Process exit code. The value is 0 when training completes and the
        summary is printed, and 1 when ``trainer.train`` raises.
    """
    setup_logging()
    print("Setting up CNN training test...")

    # Single source of truth for the run parameters, so the data provider,
    # the trainer config and the train() call cannot drift apart.
    symbols = ['ETH/USDT']
    timeframes = ['1m', '5m', '1h']  # Skip 1s for now
    save_path = 'test_models/quick_cnn.pt'  # kept as str, as the original passed

    # Setup. Each consumer gets its own list copy, as in the original, in case
    # one of them mutates the list it is given.
    data_provider = DataProvider(list(symbols), list(timeframes))
    trainer = CNNTrainer(data_provider)

    # Configure for quick test
    trainer.num_samples = 500  # Very small dataset
    trainer.num_epochs = 2  # Just 2 epochs
    trainer.batch_size = 16
    trainer.timeframes = list(timeframes)
    trainer.n_timeframes = len(timeframes)  # derived, so it always matches

    print("Configuration:")
    print(f" Samples: {trainer.num_samples}")
    print(f" Epochs: {trainer.num_epochs}")
    print(f" Batch size: {trainer.batch_size}")
    print(f" Timeframes: {trainer.timeframes}")

    # Ensure the output directory exists before training. This is harmless if
    # CNNTrainer also creates it; whether it does is not visible from here.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)

    # Train. The try is kept narrow so that a problem while *reporting* results
    # (e.g. a missing key) is not mislabelled as a training failure.
    try:
        results = trainer.train(list(symbols), save_path=save_path)
    except Exception as e:  # top-level script boundary: report and exit non-zero
        import traceback
        print(f"\n❌ Training failed: {e}")
        traceback.print_exc()
        return 1

    # NOTE(review): assumes train() returns a dict with these keys, as the
    # original script did; confirm against CNNTrainer.train.
    print("\n✅ CNN Training completed!")
    print(f" Best accuracy: {results['best_val_accuracy']:.4f}")
    print(f" Total epochs: {results['total_epochs']}")
    print(f" Training time: {results['total_time']:.2f}s")
    return 0
# Script entry point: main()'s integer return value becomes the exit status.
if __name__ == "__main__":
    raise SystemExit(main())