new__training
main_clean.py
@@ -83,7 +83,7 @@ def run_data_test():
         raise

 def run_cnn_training():
-    """Train CNN models only"""
+    """Train CNN models only with comprehensive pipeline"""
     try:
         logger.info("Starting CNN Training Mode...")

@@ -92,85 +92,185 @@ def run_cnn_training():
             symbols=['ETH/USDT', 'BTC/USDT'],
             timeframes=['1s', '1m', '5m', '1h', '4h']
         )
         orchestrator = TradingOrchestrator(data_provider)

-        logger.info("Creating CNN training data...")
+        # Import and create CNN trainer
+        from training.cnn_trainer import CNNTrainer
+        trainer = CNNTrainer(data_provider)

-        # Prepare multi-timeframe, multi-symbol feature matrices
+        # Configure training
+        trainer.num_samples = 20000  # Training samples
+        trainer.batch_size = 64
+        trainer.num_epochs = 100
+        trainer.patience = 15

+        # Train the model
         symbols = ['ETH/USDT', 'BTC/USDT']
-        timeframes = ['1m', '5m', '1h', '4h']
+        save_path = 'models/cnn/scalping_cnn_trained.pt'

-        for symbol in symbols:
-            logger.info(f"Preparing CNN data for {symbol}...")

-            feature_matrix = data_provider.get_feature_matrix(
-                symbol, timeframes, window_size=50
-            )

-            if feature_matrix is not None:
-                logger.info(f"CNN training data ready for {symbol}: {feature_matrix.shape}")
-                # Here you would integrate with your CNN training module
-                # Example: cnn_model.train(feature_matrix, labels)
-            else:
-                logger.warning(f"Could not prepare CNN data for {symbol}")
+        logger.info(f"Training CNN for symbols: {symbols}")
+        logger.info(f"Will save to: {save_path}")

-        logger.info("CNN training preparation completed!")
-        logger.info("Note: Integrate this with your actual CNN training module")
+        results = trainer.train(symbols, save_path)

+        # Log results
+        logger.info("CNN Training Results:")
+        logger.info(f" Best validation accuracy: {results['best_val_accuracy']:.4f}")
+        logger.info(f" Best validation loss: {results['best_val_loss']:.4f}")
+        logger.info(f" Total epochs: {results['total_epochs']}")
+        logger.info(f" Training time: {results['total_time']:.2f} seconds")

+        # Plot training history
+        try:
+            plot_path = 'models/cnn/training_history.png'
+            trainer.plot_training_history(plot_path)
+            logger.info(f"Training plots saved to: {plot_path}")
+        except Exception as e:
+            logger.warning(f"Could not save training plots: {e}")

+        # Evaluate on test data
+        try:
+            logger.info("Evaluating CNN on test data...")
+            test_symbols = ['ETH/USDT']  # Use subset for testing
+            eval_results = trainer.evaluate_model(test_symbols)

+            logger.info("CNN Evaluation Results:")
+            logger.info(f" Test accuracy: {eval_results['test_accuracy']:.4f}")
+            logger.info(f" Test loss: {eval_results['test_loss']:.4f}")
+            logger.info(f" Average confidence: {eval_results['avg_confidence']:.4f}")

+        except Exception as e:
+            logger.warning(f"Could not run evaluation: {e}")

+        logger.info("CNN training completed successfully!")

     except Exception as e:
         logger.error(f"Error in CNN training: {e}")
         import traceback
         logger.error(traceback.format_exc())
         raise

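For reference, the rewritten run_cnn_training() only makes sense against a particular CNNTrainer surface. The sketch below is a hedged outline of the interface the added lines assume, reconstructed purely from the calls and result-dict keys visible in this diff; the actual training/cnn_trainer.py may expose different members or defaults.

# Sketch of the CNNTrainer interface assumed by the new run_cnn_training() code.
# Inferred from this diff only; attribute defaults and method bodies are placeholders.
class CNNTrainer:
    def __init__(self, data_provider):
        self.data_provider = data_provider
        # Hyperparameters overridden by the caller before train() is invoked
        self.num_samples = None
        self.batch_size = None
        self.num_epochs = None
        self.patience = None

    def train(self, symbols: list, save_path: str) -> dict:
        """Train on the given symbols and save the best checkpoint to save_path.

        The caller reads at least these keys from the result:
        'best_val_accuracy', 'best_val_loss', 'total_epochs', 'total_time'.
        """
        raise NotImplementedError

    def evaluate_model(self, symbols: list) -> dict:
        """Expected keys: 'test_accuracy', 'test_loss', 'avg_confidence'."""
        raise NotImplementedError

    def plot_training_history(self, path: str) -> None:
        """Write training-history plots (e.g. loss/accuracy curves) to path."""
        raise NotImplementedError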
 def run_rl_training():
-    """Train RL agents only"""
+    """Train RL agents only with comprehensive pipeline"""
     try:
         logger.info("Starting RL Training Mode...")

         # Initialize components for RL
         data_provider = DataProvider(
             symbols=['ETH/USDT'],
-            timeframes=['1s', '1m', '5m']  # Focus on short timeframes for RL
+            timeframes=['1s', '1m', '5m', '1h']  # Focus on scalping timeframes
         )
         orchestrator = TradingOrchestrator(data_provider)

-        logger.info("Setting up RL environment...")
+        # Import and create RL trainer
+        from training.rl_trainer import RLTrainer
+        trainer = RLTrainer(data_provider)

-        # Get scalping data for RL training
-        scalping_data = data_provider.get_latest_candles('ETH/USDT', '1s', limit=1000)
+        # Configure training
+        trainer.num_episodes = 1000
+        trainer.max_steps_per_episode = 1000
+        trainer.evaluation_frequency = 50
+        trainer.save_frequency = 100

-        if not scalping_data.empty:
-            logger.info(f"RL training data ready: {len(scalping_data)} 1s candles")
-            logger.info(f"Price range: ${scalping_data['close'].min():.2f} - ${scalping_data['close'].max():.2f}")
+        # Train the agent
+        save_path = 'models/rl/scalping_agent_trained.pt'

+        logger.info(f"Training RL agent for scalping")
+        logger.info(f"Will save to: {save_path}")

+        results = trainer.train(save_path)

+        # Log results
+        logger.info("RL Training Results:")
+        logger.info(f" Best reward: {results['best_reward']:.4f}")
+        logger.info(f" Best balance: ${results['best_balance']:.2f}")
+        logger.info(f" Total episodes: {results['total_episodes']}")
+        logger.info(f" Training time: {results['total_time']:.2f} seconds")
+        logger.info(f" Final epsilon: {results['agent_config']['epsilon_final']:.4f}")

+        # Final evaluation results
+        final_eval = results['final_evaluation']
+        logger.info("Final Evaluation:")
+        logger.info(f" Win rate: {final_eval['win_rate']:.2%}")
+        logger.info(f" Average PnL: {final_eval['avg_pnl_percentage']:.2f}%")
+        logger.info(f" Average trades: {final_eval['avg_trades']:.1f}")

+        # Plot training progress
+        try:
+            plot_path = 'models/rl/training_progress.png'
+            trainer.plot_training_progress(plot_path)
+            logger.info(f"Training plots saved to: {plot_path}")
+        except Exception as e:
+            logger.warning(f"Could not save training plots: {e}")

+        # Backtest the trained agent
+        try:
+            logger.info("Backtesting trained agent...")
+            backtest_results = trainer.backtest_agent(save_path, test_episodes=50)

-            # Here you would integrate with your RL training module
-            # Example: rl_agent.train(environment_data=scalping_data)
-        else:
-            logger.warning("No scalping data available for RL training")
+            analysis = backtest_results['analysis']
+            logger.info("Backtest Results:")
+            logger.info(f" Win rate: {analysis['win_rate']:.2%}")
+            logger.info(f" Average PnL: {analysis['avg_pnl']:.2f}%")
+            logger.info(f" Sharpe ratio: {analysis['sharpe_ratio']:.4f}")
+            logger.info(f" Max drawdown: {analysis['max_drawdown']:.2f}%")

+        except Exception as e:
+            logger.warning(f"Could not run backtest: {e}")

-        logger.info("RL training preparation completed!")
-        logger.info("Note: Integrate this with your actual RL training module")
+        logger.info("RL training completed successfully!")

     except Exception as e:
         logger.error(f"Error in RL training: {e}")
         import traceback
         logger.error(traceback.format_exc())
         raise

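Similarly, the rewritten run_rl_training() relies on the RLTrainer API outlined below. This is a hedged sketch inferred only from the attribute names, method calls, and result-dict keys in the diff; agent internals, epsilon scheduling, and real defaults live in training/rl_trainer.py and are not shown in the commit.

# Sketch of the RLTrainer interface assumed by the new run_rl_training() code.
# Inferred from this diff only; method bodies are stubs.
class RLTrainer:
    def __init__(self, data_provider):
        self.data_provider = data_provider
        # Episode and bookkeeping settings overridden by the caller
        self.num_episodes = None
        self.max_steps_per_episode = None
        self.evaluation_frequency = None
        self.save_frequency = None

    def train(self, save_path: str) -> dict:
        """Run the training loop and save the best agent to save_path.

        The caller reads these keys from the result:
        'best_reward', 'best_balance', 'total_episodes', 'total_time',
        'agent_config' (containing 'epsilon_final') and 'final_evaluation'
        (containing 'win_rate', 'avg_pnl_percentage', 'avg_trades').
        """
        raise NotImplementedError

    def plot_training_progress(self, path: str) -> None:
        """Write reward/balance progress plots to path."""
        raise NotImplementedError

    def backtest_agent(self, model_path: str, test_episodes: int = 50) -> dict:
        """Replay the saved agent; result['analysis'] is expected to hold
        'win_rate', 'avg_pnl', 'sharpe_ratio', 'max_drawdown'."""
        raise NotImplementedError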
 def run_combined_training():
-    """Train both CNN and RL models"""
+    """Train both CNN and RL models with hybrid approach"""
     try:
-        logger.info("Starting Combined Training Mode...")
+        logger.info("Starting Hybrid CNN + RL Training Mode...")

-        # Run CNN training first
-        logger.info("Phase 1: CNN Training")
-        run_cnn_training()
+        # Initialize data provider
+        data_provider = DataProvider(
+            symbols=['ETH/USDT', 'BTC/USDT'],
+            timeframes=['1s', '1m', '5m', '1h', '4h']
+        )

-        # Then RL training
-        logger.info("Phase 2: RL Training")
-        run_rl_training()
+        # Import and create hybrid trainer
+        from training.rl_trainer import HybridTrainer
+        trainer = HybridTrainer(data_provider)

-        logger.info("Combined training completed!")
+        # Define save paths
+        cnn_save_path = 'models/cnn/hybrid_cnn_trained.pt'
+        rl_save_path = 'models/rl/hybrid_rl_trained.pt'

+        # Train hybrid system
+        symbols = ['ETH/USDT', 'BTC/USDT']
+        logger.info(f"Training hybrid system for symbols: {symbols}")

+        results = trainer.train_hybrid(symbols, cnn_save_path, rl_save_path)

+        # Log results
+        cnn_results = results['cnn_results']
+        rl_results = results['rl_results']

+        logger.info("Hybrid Training Results:")
+        logger.info("CNN Phase:")
+        logger.info(f" Best accuracy: {cnn_results['best_val_accuracy']:.4f}")
+        logger.info(f" Training time: {cnn_results['total_time']:.2f}s")

+        logger.info("RL Phase:")
+        logger.info(f" Best reward: {rl_results['best_reward']:.4f}")
+        logger.info(f" Final balance: ${rl_results['best_balance']:.2f}")
+        logger.info(f" Training time: {rl_results['total_time']:.2f}s")

+        logger.info(f"Total training time: {results['total_time']:.2f}s")

+        logger.info("Hybrid training completed successfully!")

     except Exception as e:
-        logger.error(f"Error in combined training: {e}")
+        logger.error(f"Error in hybrid training: {e}")
         import traceback
         logger.error(traceback.format_exc())
         raise

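The hybrid path assumes a HybridTrainer (imported from training.rl_trainer in the diff) that chains the CNN and RL phases. A minimal sketch of the call contract, again inferred only from how run_combined_training() consumes the results; the actual class may differ:

# Sketch of the HybridTrainer contract assumed by run_combined_training().
# Inferred from this diff only; the method body is a stub.
class HybridTrainer:
    def __init__(self, data_provider):
        self.data_provider = data_provider

    def train_hybrid(self, symbols: list, cnn_save_path: str, rl_save_path: str) -> dict:
        """Train the CNN phase, then the RL phase, saving each checkpoint.

        The caller expects:
          results['cnn_results'] -> 'best_val_accuracy', 'total_time'
          results['rl_results']  -> 'best_reward', 'best_balance', 'total_time'
          results['total_time']  -> combined training time in seconds
        """
        raise NotImplementedError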
def run_live_trading():