#!/usr/bin/env python3
"""
Enhanced RL Training Launcher with Real Data Integration

This script launches the comprehensive RL training system that uses:
- Real-time tick data (300s window for momentum detection)
- Multi-timeframe OHLCV data (1s, 1m, 1h, 1d)
- BTC reference data for correlation
- CNN hidden features and predictions
- Williams Market Structure pivot points
- Market microstructure analysis

The RL model will receive ~13,400 features instead of the previous ~100 basic features.
Training metrics are automatically logged to TensorBoard for visualization.
"""

import asyncio
import logging
import time
import signal
import sys
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('enhanced_rl_training.log'),
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)

# Import our enhanced components
from core.config import get_config
from core.data_provider import DataProvider
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
from training.enhanced_rl_trainer import EnhancedRLTrainer
from training.enhanced_rl_state_builder import EnhancedRLStateBuilder
from training.williams_market_structure import WilliamsMarketStructure
from training.cnn_rl_bridge import CNNRLBridge
from utils.tensorboard_logger import TensorBoardLogger


class EnhancedRLTrainingSystem:
    """Comprehensive RL training system with real data integration.

    NOTE: most of the launcher (``initialize``, ``run_training_loop``,
    ``shutdown``, ``main`` and their helpers) is currently disabled and kept
    below as commented-out code. Only ``__init__``, ``_train_rl_agents`` and
    the two helpers it calls are live, so ``data_provider``, ``orchestrator``
    and ``rl_trainer`` stay ``None`` unless the caller assigns them.
    """

    def __init__(self):
        """Initialize the enhanced RL training system.

        Loads the configuration via ``get_config()``, resets the component
        handles and the statistics dictionary, and creates a TensorBoard
        logger writing under ``runs/`` with a timestamped experiment name.
        """
        self.config = get_config()
        self.running = False
        self.data_provider = None
        self.orchestrator = None
        self.rl_trainer = None

        # Performance tracking
        self.training_stats = {
            'training_sessions': 0,
            'total_experiences': 0,
            'avg_state_size': 0,
            'data_quality_score': 0.0,
            'last_training_time': None
        }

        # Initialize TensorBoard logger
        experiment_name = f"enhanced_rl_training_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        self.tb_logger = TensorBoardLogger(
            log_dir="runs",
            experiment_name=experiment_name,
            enabled=True
        )

        logger.info("Enhanced RL Training System initialized")
        logger.info(f"TensorBoard logging enabled for experiment: {experiment_name}")
        logger.info("Features:")
        logger.info("- Real-time tick data processing (300s window)")
        logger.info("- Multi-timeframe OHLCV analysis (1s, 1m, 1h, 1d)")
        logger.info("- BTC correlation analysis")
        logger.info("- CNN feature integration")
        logger.info("- Williams Market Structure pivot points")
        logger.info("- ~13,400 feature state vector (vs previous ~100)")

    # async def initialize(self):
    #     """Initialize all components"""
    #     try:
    #         logger.info("Initializing enhanced RL training components...")
    #
    #         # Initialize data provider with real-time streaming
    #         logger.info("Setting up data provider with real-time streaming...")
    #         self.data_provider = DataProvider(
    #             symbols=self.config.symbols,
    #             timeframes=self.config.timeframes
    #         )
    #
    #         # Start real-time data streaming
    #         await self.data_provider.start_real_time_streaming()
    #         logger.info("Real-time data streaming started")
    #
    #         # Wait for initial data collection
    #         logger.info("Collecting initial market data...")
    #         await asyncio.sleep(30)  # Allow 30 seconds for data collection
    #
    #         # Initialize enhanced orchestrator
    #         logger.info("Initializing enhanced orchestrator...")
    #         self.orchestrator = EnhancedTradingOrchestrator(self.data_provider)
    #
    #         # Initialize enhanced RL trainer with comprehensive state building
    #         logger.info("Initializing enhanced RL trainer...")
    #         self.rl_trainer = EnhancedRLTrainer(
    #             config=self.config,
    #             orchestrator=self.orchestrator
    #         )
    #
    #         # Verify data availability
    #         data_status = await self._verify_data_availability()
    #         if not data_status['has_sufficient_data']:
    #             logger.warning("Insufficient data detected. Continuing with limited training.")
    #             logger.warning(f"Data status: {data_status}")
    #         else:
    #             logger.info("Sufficient data available for comprehensive RL training")
    #             logger.info(f"Tick data: {data_status['tick_count']} ticks")
    #             logger.info(f"OHLCV data: {data_status['ohlcv_bars']} bars")
    #
    #         self.running = True
    #         logger.info("Enhanced RL training system initialized successfully")
    #
    #     except Exception as e:
    #         logger.error(f"Error during initialization: {e}")
    #         raise

    # async def _verify_data_availability(self) -> Dict[str, any]:
    #     """Verify that we have sufficient data for training"""
    #     try:
    #         data_status = {
    #             'has_sufficient_data': False,
    #             'tick_count': 0,
    #             'ohlcv_bars': 0,
    #             'symbols_with_data': [],
    #             'missing_data': []
    #         }
    #
    #         for symbol in self.config.symbols:
    #             # Check tick data
    #             recent_ticks = self.data_provider.get_recent_ticks(symbol, count=100)
    #             tick_count = len(recent_ticks)
    #
    #             # Check OHLCV data
    #             ohlcv_bars = 0
    #             for timeframe in ['1s', '1m', '1h', '1d']:
    #                 try:
    #                     df = self.data_provider.get_historical_data(
    #                         symbol=symbol,
    #                         timeframe=timeframe,
    #                         limit=50,
    #                         refresh=True
    #                     )
    #                     if df is not None and not df.empty:
    #                         ohlcv_bars += len(df)
    #                 except Exception as e:
    #                     logger.warning(f"Error checking {timeframe} data for {symbol}: {e}")
    #
    #             data_status['tick_count'] += tick_count
    #             data_status['ohlcv_bars'] += ohlcv_bars
    #
    #             if tick_count >= 50 and ohlcv_bars >= 100:
    #                 data_status['symbols_with_data'].append(symbol)
    #             else:
    #                 data_status['missing_data'].append(f"{symbol}: {tick_count} ticks, {ohlcv_bars} bars")
    #
    #         # Consider data sufficient if we have at least one symbol with good data
    #         data_status['has_sufficient_data'] = len(data_status['symbols_with_data']) > 0
    #
    #         return data_status
    #
    #     except Exception as e:
    #         logger.error(f"Error verifying data availability: {e}")
    #         return {'has_sufficient_data': False, 'error': str(e)}

    # async def run_training_loop(self):
    #     """Run the main training loop with real data"""
    #     logger.info("Starting enhanced RL training loop...")
    #
    #     training_cycle = 0
    #     last_state_size_log = time.time()
    #
    #     try:
    #         while self.running:
    #             training_cycle += 1
    #             cycle_start_time = time.time()
    #
    #             logger.info(f"Training cycle {training_cycle} started")
    #
    #             # Get comprehensive market states with real data
    #             market_states = await self._get_comprehensive_market_states()
    #
    #             if not market_states:
    #                 logger.warning("No market states available. Waiting for data...")
    #                 await asyncio.sleep(60)
    #                 continue
    #
    #             # Train RL agents with comprehensive states
    #             training_results = await self._train_rl_agents(market_states)
    #
    #             # Update performance tracking
    #             self._update_training_stats(training_results, market_states)
    #
    #             # Log training progress
    #             cycle_duration = time.time() - cycle_start_time
    #             logger.info(f"Training cycle {training_cycle} completed in {cycle_duration:.2f}s")
    #
    #             # Log state size periodically
    #             if time.time() - last_state_size_log > 300:  # Every 5 minutes
    #                 self._log_state_size_info(market_states)
    #                 last_state_size_log = time.time()
    #
    #             # Save models periodically
    #             if training_cycle % 10 == 0:
    #                 await self._save_training_progress()
    #
    #             # Wait before next training cycle
    #             await asyncio.sleep(300)  # Train every 5 minutes
    #
    #     except Exception as e:
    #         logger.error(f"Error in training loop: {e}")
    #         raise

    # async def _get_comprehensive_market_states(self) -> Dict[str, any]:
    #     """Get comprehensive market states with all required data"""
    #     try:
    #         # Get market states from orchestrator
    #         universal_stream = self.orchestrator.universal_adapter.get_universal_stream()
    #         market_states = await self.orchestrator._get_all_market_states_universal(universal_stream)
    #
    #         # Verify data quality
    #         quality_score = self._calculate_data_quality(market_states)
    #         self.training_stats['data_quality_score'] = quality_score
    #
    #         if quality_score < 0.5:
    #             logger.warning(f"Low data quality detected: {quality_score:.2f}")
    #
    #         return market_states
    #
    #     except Exception as e:
    #         logger.error(f"Error getting comprehensive market states: {e}")
    #         return {}

    # def _calculate_data_quality(self, market_states: Dict[str, any]) -> float:
    #     """Calculate data quality score based on available data"""
    #     try:
    #         if not market_states:
    #             return 0.0
    #
    #         total_score = 0.0
    #         total_symbols = len(market_states)
    #
    #         for symbol, state in market_states.items():
    #             symbol_score = 0.0
    #
    #             # Score based on tick data availability
    #             if hasattr(state, 'raw_ticks') and state.raw_ticks:
    #                 tick_score = min(len(state.raw_ticks) / 100, 1.0)  # Max score for 100+ ticks
    #                 symbol_score += tick_score * 0.3
    #
    #             # Score based on OHLCV data availability
    #             if hasattr(state, 'ohlcv_data') and state.ohlcv_data:
    #                 ohlcv_score = len(state.ohlcv_data) / 4.0  # Max score for all 4 timeframes
    #                 symbol_score += min(ohlcv_score, 1.0) * 0.4
    #
    #             # Score based on CNN features
    #             if hasattr(state, 'cnn_hidden_features') and state.cnn_hidden_features:
    #                 symbol_score += 0.15
    #
    #             # Score based on pivot points
    #             if hasattr(state, 'pivot_points') and state.pivot_points:
    #                 symbol_score += 0.15
    #
    #             total_score += symbol_score
    #
    #         return total_score / total_symbols if total_symbols > 0 else 0.0
    #
    #     except Exception as e:
    #         logger.warning(f"Error calculating data quality: {e}")
    #         return 0.5  # Default to medium quality

    async def _train_rl_agents(self, market_states: Dict[str, Any]) -> Dict[str, Any]:
        """Train RL agents with comprehensive market states.

        For every ``symbol -> market_state`` pair the state is converted with
        ``self.rl_trainer._market_state_to_rl_state``. Non-empty states get a
        simulated action and reward, and the resulting experience is stored
        in the symbol's agent from ``self.rl_trainer.agents`` (if one
        exists). The agent is replayed once its buffer holds at least
        ``batch_size`` experiences. State size, action, reward and loss are
        logged to TensorBoard, using the current
        ``training_stats['training_sessions']`` value as the step.

        Args:
            market_states: Mapping of symbol to its market state object.

        Returns:
            A dict with keys ``symbols_trained``, ``total_experiences``,
            ``avg_state_size``, ``training_errors``, ``losses`` and
            ``rewards``. ``avg_state_size`` is the mean length of all
            non-empty RL states seen in this call (0 if there were none).
            If an unexpected error escapes the per-symbol handling, the
            result is ``{'error': <message>}`` instead.

        A failure for one symbol is recorded in ``training_errors`` and does
        not abort the remaining symbols.
        """
        try:
            training_results = {
                'symbols_trained': [],
                'total_experiences': 0,
                'avg_state_size': 0,
                'training_errors': [],
                'losses': {},
                'rewards': {}
            }

            # TensorBoard step shared by every scalar logged in this call.
            step = self.training_stats['training_sessions']

            # Count of states added to 'avg_state_size'. The average must be
            # taken over these, not over 'symbols_trained', because a symbol
            # can contribute a state size without being trained (no agent
            # registered, or an error after the size was recorded).
            states_measured = 0

            for symbol, market_state in market_states.items():
                try:
                    # Convert market state to comprehensive RL state
                    rl_state = self.rl_trainer._market_state_to_rl_state(market_state)

                    if rl_state is not None and len(rl_state) > 0:
                        # Record state size
                        state_size = len(rl_state)
                        training_results['avg_state_size'] += state_size
                        states_measured += 1

                        # Log state size to TensorBoard
                        self.tb_logger.log_scalar(
                            f'State/{symbol}/Size',
                            state_size,
                            step
                        )

                        # Simulate trading action for experience generation
                        # In real implementation, this would be actual trading decisions
                        action = self._simulate_trading_action(symbol, rl_state)

                        # Generate reward based on market outcome
                        reward = self._calculate_training_reward(symbol, market_state, action)

                        # Store reward for TensorBoard logging
                        training_results['rewards'][symbol] = reward

                        # Log action and reward to TensorBoard
                        self.tb_logger.log_scalars(f'Actions/{symbol}', {
                            'action': action,
                            'reward': reward
                        }, step)

                        # Add experience to RL agent
                        agent = self.rl_trainer.agents.get(symbol)
                        if agent:
                            # Create next state (would be actual next market state in real scenario)
                            next_state = rl_state  # Simplified for now

                            agent.remember(
                                state=rl_state,
                                action=action,
                                reward=reward,
                                next_state=next_state,
                                done=False
                            )

                            # Train agent if enough experiences
                            if len(agent.replay_buffer) >= agent.batch_size:
                                loss = agent.replay()
                                if loss is not None:
                                    logger.debug(f"Agent {symbol} training loss: {loss:.4f}")

                                    # Store loss for TensorBoard logging
                                    training_results['losses'][symbol] = loss

                                    # Log loss to TensorBoard
                                    self.tb_logger.log_scalar(
                                        f'Training/{symbol}/Loss',
                                        loss,
                                        step
                                    )

                            training_results['symbols_trained'].append(symbol)
                            training_results['total_experiences'] += 1

                except Exception as e:
                    # Best effort per symbol: record the failure and continue.
                    error_msg = f"Error training {symbol}: {e}"
                    logger.warning(error_msg)
                    training_results['training_errors'].append(error_msg)

            # Calculate average state size over the states actually measured
            if states_measured > 0:
                training_results['avg_state_size'] /= states_measured

            # Log overall training metrics to TensorBoard
            self.tb_logger.log_scalars('Training/Overall', {
                'symbols_trained': len(training_results['symbols_trained']),
                'experiences': training_results['total_experiences'],
                'avg_state_size': training_results['avg_state_size'],
                'errors': len(training_results['training_errors'])
            }, step)

            return training_results

        except Exception as e:
            logger.error(f"Error training RL agents: {e}")
            return {'error': str(e)}

    # NOTE: the two helpers below were previously commented out although
    # _train_rl_agents calls them; they are restored with unchanged logic so
    # that the per-symbol training step no longer fails with AttributeError.

    def _simulate_trading_action(self, symbol: str, rl_state) -> int:
        """Simulate trading action for training (would be real decision in production).

        Args:
            symbol: Symbol the state belongs to (not used by the simulation).
            rl_state: Sequence of numeric features; must support ``len`` and
                slicing.

        Returns:
            ``1`` (BUY) if the mean of the first 100 features is above 0.6,
            ``2`` (SELL) if it is below 0.4, otherwise ``0`` (HOLD). States
            with 100 or fewer features always yield ``0``.
        """
        # Simple simulation based on state features
        if len(rl_state) > 100:
            # Use momentum features to decide action
            momentum_features = rl_state[:100]  # First 100 features assumed to be momentum
            avg_momentum = sum(momentum_features) / len(momentum_features)

            if avg_momentum > 0.6:
                return 1  # BUY
            elif avg_momentum < 0.4:
                return 2  # SELL
            else:
                return 0  # HOLD
        else:
            return 0  # HOLD as default

    def _calculate_training_reward(self, symbol: str, market_state, action: int) -> float:
        """Calculate training reward based on market state and action.

        Args:
            symbol: Symbol being trained; used only in the warning message.
            market_state: Object that may expose ``volatility`` and
                ``trend_strength`` attributes; missing attributes are skipped.
            action: ``0`` HOLD, ``1`` BUY or ``2`` SELL.

        Returns:
            Sum of a 0.1 volatility bonus and a 0.2 trend bonus where the
            conditions below apply, or ``0.0`` if the calculation raises.
        """
        try:
            # Simple reward calculation based on market conditions
            base_reward = 0.0

            # Reward based on volatility alignment
            if hasattr(market_state, 'volatility'):
                if action == 0 and market_state.volatility > 0.02:  # HOLD in high volatility
                    base_reward += 0.1
                elif action != 0 and market_state.volatility < 0.01:  # Trade in low volatility
                    base_reward += 0.1

            # Reward based on trend alignment
            if hasattr(market_state, 'trend_strength'):
                if action == 1 and market_state.trend_strength > 0.6:  # BUY in uptrend
                    base_reward += 0.2
                elif action == 2 and market_state.trend_strength < 0.4:  # SELL in downtrend
                    base_reward += 0.2

            return base_reward

        except Exception as e:
            logger.warning(f"Error calculating reward for {symbol}: {e}")
            return 0.0

    # def _update_training_stats(self, training_results: Dict[str, any], market_states: Dict[str, any]):
    #     """Update training statistics"""
    #     self.training_stats['training_sessions'] += 1
    #     self.training_stats['total_experiences'] += training_results.get('total_experiences', 0)
    #     self.training_stats['avg_state_size'] = training_results.get('avg_state_size', 0)
    #     self.training_stats['last_training_time'] = datetime.now()
    #
    #     # Log statistics periodically
    #     if self.training_stats['training_sessions'] % 10 == 0:
    #         logger.info("Training Statistics:")
    #         logger.info(f" Sessions: {self.training_stats['training_sessions']}")
    #         logger.info(f" Total Experiences: {self.training_stats['total_experiences']}")
    #         logger.info(f" Avg State Size: {self.training_stats['avg_state_size']:.0f}")
    #         logger.info(f" Data Quality: {self.training_stats['data_quality_score']:.2f}")

    # def _log_state_size_info(self, market_states: Dict[str, any]):
    #     """Log information about state sizes for debugging"""
    #     for symbol, state in market_states.items():
    #         info = []
    #
    #         if hasattr(state, 'raw_ticks'):
    #             info.append(f"ticks: {len(state.raw_ticks)}")
    #
    #         if hasattr(state, 'ohlcv_data'):
    #             total_bars = sum(len(bars) for bars in state.ohlcv_data.values())
    #             info.append(f"OHLCV bars: {total_bars}")
    #
    #         if hasattr(state, 'cnn_hidden_features') and state.cnn_hidden_features:
    #             info.append("CNN features: available")
    #
    #         if hasattr(state, 'pivot_points') and state.pivot_points:
    #             info.append("pivot points: available")
    #
    #         logger.info(f"{symbol} state data: {', '.join(info)}")

    # async def _save_training_progress(self):
    #     """Save training progress and models"""
    #     try:
    #         if self.rl_trainer:
    #             self.rl_trainer._save_all_models()
    #             logger.info("Training progress saved")
    #     except Exception as e:
    #         logger.error(f"Error saving training progress: {e}")

    # async def shutdown(self):
    #     """Graceful shutdown"""
    #     logger.info("Shutting down enhanced RL training system...")
    #     self.running = False
    #
    #     # Save final state
    #     await self._save_training_progress()
    #
    #     # Stop data provider
    #     if self.data_provider:
    #         await self.data_provider.stop_real_time_streaming()
    #
    #     logger.info("Enhanced RL training system shutdown complete")


# async def main():
#     """Main function to run enhanced RL training"""
#     system = None
#
#     def signal_handler(signum, frame):
#         logger.info("Received shutdown signal")
#         if system:
#             asyncio.create_task(system.shutdown())
#
#     # Set up signal handlers
#     signal.signal(signal.SIGINT, signal_handler)
#     signal.signal(signal.SIGTERM, signal_handler)
#
#     try:
#         # Create and initialize the training system
#         system = EnhancedRLTrainingSystem()
#         await system.initialize()
#
#         logger.info("Enhanced RL Training System is now running...")
#         logger.info("The RL model now receives ~13,400 features instead of ~100!")
#         logger.info("Press Ctrl+C to stop")
#
#         # Run the training loop
#         await system.run_training_loop()
#
#     except KeyboardInterrupt:
#         logger.info("Training interrupted by user")
#     except Exception as e:
#         logger.error(f"Error in main training loop: {e}")
#         raise
#     finally:
#         if system:
#             await system.shutdown()
#
#
# if __name__ == "__main__":
#     asyncio.run(main())