65 lines
2.0 KiB
Python
65 lines
2.0 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Quick CNN Training Test - Real Market Data Only
|
|
|
|
This script tests CNN training with a small dataset for quick validation.
|
|
All training metrics are logged to TensorBoard for real-time monitoring.
|
|
"""
|
|
|
|
import logging
|
|
from core.config import setup_logging, get_config
|
|
from core.data_provider import DataProvider
|
|
from training.cnn_trainer import CNNTrainer
|
|
|
|
def main():
    """Run a short CNN training job on real market data and report the results.

    Sets up logging, overrides the configured timeframes, trains a
    ``CNNTrainer`` on a small sample of ETH/USDT data, and prints a summary of
    the metrics returned by ``trainer.train``. ``trainer.close_tensorboard()``
    is always called once the trainer exists, whether training succeeds or not.

    Returns:
        int: process exit status -- 0 if training completed and its summary
        was printed, 1 if training (or reading its results) raised.
    """
    setup_logging()

    print("Setting up CNN training test...")
    print("📊 Monitor training: tensorboard --logdir=runs")

    # Configure test parameters
    config = get_config()

    # Small, fixed test configuration so the run finishes quickly.
    symbols = ['ETH/USDT']
    timeframes = ['1m', '5m', '1h']
    num_samples = 500
    epochs = 2
    batch_size = 16

    # Override config for the quick test.
    # NOTE(review): this writes the private ``_config`` mapping directly;
    # switch to a public setter on the config object if one exists.
    config._config['timeframes'] = timeframes

    trainer = CNNTrainer(config)
    trainer.batch_size = batch_size
    trainer.epochs = epochs

    print("Configuration:")
    print(f" Symbols: {symbols}")
    print(f" Timeframes: {timeframes}")
    print(f" Samples: {num_samples}")
    print(f" Epochs: {epochs}")
    print(f" Batch size: {batch_size}")
    print(" Data source: REAL market data from exchange APIs")

    try:
        # Train model with TensorBoard logging
        results = trainer.train(
            symbols,
            save_path='test_models/quick_cnn.pt',
            num_samples=num_samples,
        )

        print("\n✅ CNN Training completed!")
        print(f" Best accuracy: {results['best_val_accuracy']:.4f}")
        print(f" Total epochs: {results['total_epochs']}")
        print(f" Training time: {results['training_time']:.2f}s")
        print(f" TensorBoard logs: {results['tensorboard_dir']}")
        print("\n📊 View training progress: tensorboard --logdir=runs")
    except Exception as e:
        # Top-level boundary of the script: report the failure with a full
        # traceback and signal it through the return value, instead of
        # letting the process exit with status 0 as if training succeeded.
        print(f"❌ Training failed: {e}")
        import traceback
        traceback.print_exc()
        return 1
    finally:
        # Release the trainer's TensorBoard resources on every path.
        trainer.close_tensorboard()

    return 0
|
|
|
|
if __name__ == "__main__":
|
|
main() |