This commit is contained in:
Dobromir Popov
2025-05-24 09:58:36 +03:00
parent ef71160282
commit 0fe8286787
11 changed files with 1396 additions and 483 deletions

View File

@ -82,72 +82,50 @@ def run_data_test():
logger.error(traceback.format_exc())
raise
def run_cnn_training():
"""Train CNN models only with comprehensive pipeline"""
def run_cnn_training(config: Config, symbol: str):
"""Run CNN training mode with TensorBoard monitoring"""
logger.info("Starting CNN Training Mode...")
# Initialize data provider and trainer
data_provider = DataProvider(config)
trainer = CNNTrainer(config)
# Use configured symbols or provided symbol
symbols = config.symbols if symbol == "ETH/USDT" else [symbol] + config.symbols
save_path = f"models/cnn/scalping_cnn_trained.pt"
logger.info(f"Training CNN for symbols: {symbols}")
logger.info(f"Will save to: {save_path}")
logger.info(f"🔗 Monitor training: tensorboard --logdir=runs")
try:
logger.info("Starting CNN Training Mode...")
# Train model with TensorBoard logging
results = trainer.train(symbols, save_path=save_path)
# Initialize components
data_provider = DataProvider(
symbols=['ETH/USDT', 'BTC/USDT'],
timeframes=['1s', '1m', '5m', '1h', '4h']
)
# Import and create CNN trainer
from training.cnn_trainer import CNNTrainer
trainer = CNNTrainer(data_provider)
# Configure training
trainer.num_samples = 20000 # Training samples
trainer.batch_size = 64
trainer.num_epochs = 100
trainer.patience = 15
# Train the model
symbols = ['ETH/USDT', 'BTC/USDT']
save_path = 'models/cnn/scalping_cnn_trained.pt'
logger.info(f"Training CNN for symbols: {symbols}")
logger.info(f"Will save to: {save_path}")
results = trainer.train(symbols, save_path)
# Log results
logger.info("CNN Training Results:")
logger.info(f" Best validation accuracy: {results['best_val_accuracy']:.4f}")
logger.info(f" Best validation loss: {results['best_val_loss']:.4f}")
logger.info(f" Total epochs: {results['total_epochs']}")
logger.info(f" Training time: {results['total_time']:.2f} seconds")
logger.info(f" Training time: {results['training_time']:.2f} seconds")
logger.info(f" TensorBoard logs: {results['tensorboard_dir']}")
# Plot training history
try:
plot_path = 'models/cnn/training_history.png'
trainer.plot_training_history(plot_path)
logger.info(f"Training plots saved to: {plot_path}")
except Exception as e:
logger.warning(f"Could not save training plots: {e}")
logger.info(f"📊 View training progress: tensorboard --logdir=runs")
logger.info("Evaluating CNN on test data...")
# Evaluate on test data
try:
logger.info("Evaluating CNN on test data...")
test_symbols = ['ETH/USDT'] # Use subset for testing
eval_results = trainer.evaluate_model(test_symbols)
logger.info("CNN Evaluation Results:")
logger.info(f" Test accuracy: {eval_results['test_accuracy']:.4f}")
logger.info(f" Test loss: {eval_results['test_loss']:.4f}")
logger.info(f" Average confidence: {eval_results['avg_confidence']:.4f}")
except Exception as e:
logger.warning(f"Could not run evaluation: {e}")
# Quick evaluation on same symbols
test_results = trainer.evaluate(symbols[:1]) # Use first symbol for quick test
logger.info("CNN Evaluation Results:")
logger.info(f" Test accuracy: {test_results['test_accuracy']:.4f}")
logger.info(f" Test loss: {test_results['test_loss']:.4f}")
logger.info(f" Average confidence: {test_results['avg_confidence']:.4f}")
logger.info("CNN training completed successfully!")
except Exception as e:
logger.error(f"Error in CNN training: {e}")
import traceback
logger.error(traceback.format_exc())
logger.error(f"CNN training failed: {e}")
raise
finally:
trainer.close_tensorboard()
def run_rl_training():
"""Train RL agents only with comprehensive pipeline"""
@ -404,7 +382,7 @@ async def main():
if args.mode == 'test':
run_data_test()
elif args.mode == 'cnn':
run_cnn_training()
run_cnn_training(get_config(), args.symbol)
elif args.mode == 'rl':
run_rl_training()
elif args.mode == 'train':