misc
This commit is contained in:
@ -82,72 +82,50 @@ def run_data_test():
|
||||
logger.error(traceback.format_exc())
|
||||
raise
|
||||
|
||||
def run_cnn_training():
|
||||
"""Train CNN models only with comprehensive pipeline"""
|
||||
def run_cnn_training(config: Config, symbol: str):
|
||||
"""Run CNN training mode with TensorBoard monitoring"""
|
||||
logger.info("Starting CNN Training Mode...")
|
||||
|
||||
# Initialize data provider and trainer
|
||||
data_provider = DataProvider(config)
|
||||
trainer = CNNTrainer(config)
|
||||
|
||||
# Use configured symbols or provided symbol
|
||||
symbols = config.symbols if symbol == "ETH/USDT" else [symbol] + config.symbols
|
||||
save_path = f"models/cnn/scalping_cnn_trained.pt"
|
||||
|
||||
logger.info(f"Training CNN for symbols: {symbols}")
|
||||
logger.info(f"Will save to: {save_path}")
|
||||
logger.info(f"🔗 Monitor training: tensorboard --logdir=runs")
|
||||
|
||||
try:
|
||||
logger.info("Starting CNN Training Mode...")
|
||||
# Train model with TensorBoard logging
|
||||
results = trainer.train(symbols, save_path=save_path)
|
||||
|
||||
# Initialize components
|
||||
data_provider = DataProvider(
|
||||
symbols=['ETH/USDT', 'BTC/USDT'],
|
||||
timeframes=['1s', '1m', '5m', '1h', '4h']
|
||||
)
|
||||
|
||||
# Import and create CNN trainer
|
||||
from training.cnn_trainer import CNNTrainer
|
||||
trainer = CNNTrainer(data_provider)
|
||||
|
||||
# Configure training
|
||||
trainer.num_samples = 20000 # Training samples
|
||||
trainer.batch_size = 64
|
||||
trainer.num_epochs = 100
|
||||
trainer.patience = 15
|
||||
|
||||
# Train the model
|
||||
symbols = ['ETH/USDT', 'BTC/USDT']
|
||||
save_path = 'models/cnn/scalping_cnn_trained.pt'
|
||||
|
||||
logger.info(f"Training CNN for symbols: {symbols}")
|
||||
logger.info(f"Will save to: {save_path}")
|
||||
|
||||
results = trainer.train(symbols, save_path)
|
||||
|
||||
# Log results
|
||||
logger.info("CNN Training Results:")
|
||||
logger.info(f" Best validation accuracy: {results['best_val_accuracy']:.4f}")
|
||||
logger.info(f" Best validation loss: {results['best_val_loss']:.4f}")
|
||||
logger.info(f" Total epochs: {results['total_epochs']}")
|
||||
logger.info(f" Training time: {results['total_time']:.2f} seconds")
|
||||
logger.info(f" Training time: {results['training_time']:.2f} seconds")
|
||||
logger.info(f" TensorBoard logs: {results['tensorboard_dir']}")
|
||||
|
||||
# Plot training history
|
||||
try:
|
||||
plot_path = 'models/cnn/training_history.png'
|
||||
trainer.plot_training_history(plot_path)
|
||||
logger.info(f"Training plots saved to: {plot_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not save training plots: {e}")
|
||||
logger.info(f"📊 View training progress: tensorboard --logdir=runs")
|
||||
logger.info("Evaluating CNN on test data...")
|
||||
|
||||
# Evaluate on test data
|
||||
try:
|
||||
logger.info("Evaluating CNN on test data...")
|
||||
test_symbols = ['ETH/USDT'] # Use subset for testing
|
||||
eval_results = trainer.evaluate_model(test_symbols)
|
||||
|
||||
logger.info("CNN Evaluation Results:")
|
||||
logger.info(f" Test accuracy: {eval_results['test_accuracy']:.4f}")
|
||||
logger.info(f" Test loss: {eval_results['test_loss']:.4f}")
|
||||
logger.info(f" Average confidence: {eval_results['avg_confidence']:.4f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not run evaluation: {e}")
|
||||
# Quick evaluation on same symbols
|
||||
test_results = trainer.evaluate(symbols[:1]) # Use first symbol for quick test
|
||||
logger.info("CNN Evaluation Results:")
|
||||
logger.info(f" Test accuracy: {test_results['test_accuracy']:.4f}")
|
||||
logger.info(f" Test loss: {test_results['test_loss']:.4f}")
|
||||
logger.info(f" Average confidence: {test_results['avg_confidence']:.4f}")
|
||||
|
||||
logger.info("CNN training completed successfully!")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in CNN training: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"CNN training failed: {e}")
|
||||
raise
|
||||
finally:
|
||||
trainer.close_tensorboard()
|
||||
|
||||
def run_rl_training():
|
||||
"""Train RL agents only with comprehensive pipeline"""
|
||||
@ -404,7 +382,7 @@ async def main():
|
||||
if args.mode == 'test':
|
||||
run_data_test()
|
||||
elif args.mode == 'cnn':
|
||||
run_cnn_training()
|
||||
run_cnn_training(get_config(), args.symbol)
|
||||
elif args.mode == 'rl':
|
||||
run_rl_training()
|
||||
elif args.mode == 'train':
|
||||
|
Reference in New Issue
Block a user