# #!/usr/bin/env python3 # """ # Enhanced RL Training Launcher with Real Data Integration # This script launches the comprehensive RL training system that uses: # - Real-time tick data (300s window for momentum detection) # - Multi-timeframe OHLCV data (1s, 1m, 1h, 1d) # - BTC reference data for correlation # - CNN hidden features and predictions # - Williams Market Structure pivot points # - Market microstructure analysis # The RL model will receive ~13,400 features instead of the previous ~100 basic features. # """ # import asyncio # import logging # import time # import signal # import sys # from datetime import datetime, timedelta # from pathlib import Path # from typing import Dict, List, Optional # # Configure logging # logging.basicConfig( # level=logging.INFO, # format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', # handlers=[ # logging.FileHandler('enhanced_rl_training.log'), # logging.StreamHandler(sys.stdout) # ] # ) # logger = logging.getLogger(__name__) # # Import our enhanced components # from core.config import get_config # from core.data_provider import DataProvider # from core.enhanced_orchestrator import EnhancedTradingOrchestrator # from training.enhanced_rl_trainer import EnhancedRLTrainer # from training.enhanced_rl_state_builder import EnhancedRLStateBuilder # from training.williams_market_structure import WilliamsMarketStructure # from training.cnn_rl_bridge import CNNRLBridge # class EnhancedRLTrainingSystem: # """Comprehensive RL training system with real data integration""" # def __init__(self): # """Initialize the enhanced RL training system""" # self.config = get_config() # self.running = False # self.data_provider = None # self.orchestrator = None # self.rl_trainer = None # # Performance tracking # self.training_stats = { # 'training_sessions': 0, # 'total_experiences': 0, # 'avg_state_size': 0, # 'data_quality_score': 0.0, # 'last_training_time': None # } # logger.info("Enhanced RL Training System initialized") # logger.info("Features:") # logger.info("- Real-time tick data processing (300s window)") # logger.info("- Multi-timeframe OHLCV analysis (1s, 1m, 1h, 1d)") # logger.info("- BTC correlation analysis") # logger.info("- CNN feature integration") # logger.info("- Williams Market Structure pivot points") # logger.info("- ~13,400 feature state vector (vs previous ~100)") # async def initialize(self): # """Initialize all components""" # try: # logger.info("Initializing enhanced RL training components...") # # Initialize data provider with real-time streaming # logger.info("Setting up data provider with real-time streaming...") # self.data_provider = DataProvider( # symbols=self.config.symbols, # timeframes=self.config.timeframes # ) # # Start real-time data streaming # await self.data_provider.start_real_time_streaming() # logger.info("Real-time data streaming started") # # Wait for initial data collection # logger.info("Collecting initial market data...") # await asyncio.sleep(30) # Allow 30 seconds for data collection # # Initialize enhanced orchestrator # logger.info("Initializing enhanced orchestrator...") # self.orchestrator = EnhancedTradingOrchestrator(self.data_provider) # # Initialize enhanced RL trainer with comprehensive state building # logger.info("Initializing enhanced RL trainer...") # self.rl_trainer = EnhancedRLTrainer( # config=self.config, # orchestrator=self.orchestrator # ) # # Verify data availability # data_status = await self._verify_data_availability() # if not data_status['has_sufficient_data']: # logger.warning("Insufficient data detected. Continuing with limited training.") # logger.warning(f"Data status: {data_status}") # else: # logger.info("Sufficient data available for comprehensive RL training") # logger.info(f"Tick data: {data_status['tick_count']} ticks") # logger.info(f"OHLCV data: {data_status['ohlcv_bars']} bars") # self.running = True # logger.info("Enhanced RL training system initialized successfully") # except Exception as e: # logger.error(f"Error during initialization: {e}") # raise # async def _verify_data_availability(self) -> Dict[str, any]: # """Verify that we have sufficient data for training""" # try: # data_status = { # 'has_sufficient_data': False, # 'tick_count': 0, # 'ohlcv_bars': 0, # 'symbols_with_data': [], # 'missing_data': [] # } # for symbol in self.config.symbols: # # Check tick data # recent_ticks = self.data_provider.get_recent_ticks(symbol, count=100) # tick_count = len(recent_ticks) # # Check OHLCV data # ohlcv_bars = 0 # for timeframe in ['1s', '1m', '1h', '1d']: # try: # df = self.data_provider.get_historical_data( # symbol=symbol, # timeframe=timeframe, # limit=50, # refresh=True # ) # if df is not None and not df.empty: # ohlcv_bars += len(df) # except Exception as e: # logger.warning(f"Error checking {timeframe} data for {symbol}: {e}") # data_status['tick_count'] += tick_count # data_status['ohlcv_bars'] += ohlcv_bars # if tick_count >= 50 and ohlcv_bars >= 100: # data_status['symbols_with_data'].append(symbol) # else: # data_status['missing_data'].append(f"{symbol}: {tick_count} ticks, {ohlcv_bars} bars") # # Consider data sufficient if we have at least one symbol with good data # data_status['has_sufficient_data'] = len(data_status['symbols_with_data']) > 0 # return data_status # except Exception as e: # logger.error(f"Error verifying data availability: {e}") # return {'has_sufficient_data': False, 'error': str(e)} # async def run_training_loop(self): # """Run the main training loop with real data""" # logger.info("Starting enhanced RL training loop...") # training_cycle = 0 # last_state_size_log = time.time() # try: # while self.running: # training_cycle += 1 # cycle_start_time = time.time() # logger.info(f"Training cycle {training_cycle} started") # # Get comprehensive market states with real data # market_states = await self._get_comprehensive_market_states() # if not market_states: # logger.warning("No market states available. Waiting for data...") # await asyncio.sleep(60) # continue # # Train RL agents with comprehensive states # training_results = await self._train_rl_agents(market_states) # # Update performance tracking # self._update_training_stats(training_results, market_states) # # Log training progress # cycle_duration = time.time() - cycle_start_time # logger.info(f"Training cycle {training_cycle} completed in {cycle_duration:.2f}s") # # Log state size periodically # if time.time() - last_state_size_log > 300: # Every 5 minutes # self._log_state_size_info(market_states) # last_state_size_log = time.time() # # Save models periodically # if training_cycle % 10 == 0: # await self._save_training_progress() # # Wait before next training cycle # await asyncio.sleep(300) # Train every 5 minutes # except Exception as e: # logger.error(f"Error in training loop: {e}") # raise # async def _get_comprehensive_market_states(self) -> Dict[str, any]: # """Get comprehensive market states with all required data""" # try: # # Get market states from orchestrator # universal_stream = self.orchestrator.universal_adapter.get_universal_stream() # market_states = await self.orchestrator._get_all_market_states_universal(universal_stream) # # Verify data quality # quality_score = self._calculate_data_quality(market_states) # self.training_stats['data_quality_score'] = quality_score # if quality_score < 0.5: # logger.warning(f"Low data quality detected: {quality_score:.2f}") # return market_states # except Exception as e: # logger.error(f"Error getting comprehensive market states: {e}") # return {} # def _calculate_data_quality(self, market_states: Dict[str, any]) -> float: # """Calculate data quality score based on available data""" # try: # if not market_states: # return 0.0 # total_score = 0.0 # total_symbols = len(market_states) # for symbol, state in market_states.items(): # symbol_score = 0.0 # # Score based on tick data availability # if hasattr(state, 'raw_ticks') and state.raw_ticks: # tick_score = min(len(state.raw_ticks) / 100, 1.0) # Max score for 100+ ticks # symbol_score += tick_score * 0.3 # # Score based on OHLCV data availability # if hasattr(state, 'ohlcv_data') and state.ohlcv_data: # ohlcv_score = len(state.ohlcv_data) / 4.0 # Max score for all 4 timeframes # symbol_score += min(ohlcv_score, 1.0) * 0.4 # # Score based on CNN features # if hasattr(state, 'cnn_hidden_features') and state.cnn_hidden_features: # symbol_score += 0.15 # # Score based on pivot points # if hasattr(state, 'pivot_points') and state.pivot_points: # symbol_score += 0.15 # total_score += symbol_score # return total_score / total_symbols if total_symbols > 0 else 0.0 # except Exception as e: # logger.warning(f"Error calculating data quality: {e}") # return 0.5 # Default to medium quality # async def _train_rl_agents(self, market_states: Dict[str, any]) -> Dict[str, any]: # """Train RL agents with comprehensive market states""" # try: # training_results = { # 'symbols_trained': [], # 'total_experiences': 0, # 'avg_state_size': 0, # 'training_errors': [] # } # for symbol, market_state in market_states.items(): # try: # # Convert market state to comprehensive RL state # rl_state = self.rl_trainer._market_state_to_rl_state(market_state) # if rl_state is not None and len(rl_state) > 0: # # Record state size # training_results['avg_state_size'] += len(rl_state) # # Simulate trading action for experience generation # # In real implementation, this would be actual trading decisions # action = self._simulate_trading_action(symbol, rl_state) # # Generate reward based on market outcome # reward = self._calculate_training_reward(symbol, market_state, action) # # Add experience to RL agent # agent = self.rl_trainer.agents.get(symbol) # if agent: # # Create next state (would be actual next market state in real scenario) # next_state = rl_state # Simplified for now # agent.remember( # state=rl_state, # action=action, # reward=reward, # next_state=next_state, # done=False # ) # # Train agent if enough experiences # if len(agent.replay_buffer) >= agent.batch_size: # loss = agent.replay() # if loss is not None: # logger.debug(f"Agent {symbol} training loss: {loss:.4f}") # training_results['symbols_trained'].append(symbol) # training_results['total_experiences'] += 1 # except Exception as e: # error_msg = f"Error training {symbol}: {e}" # logger.warning(error_msg) # training_results['training_errors'].append(error_msg) # # Calculate average state size # if len(training_results['symbols_trained']) > 0: # training_results['avg_state_size'] /= len(training_results['symbols_trained']) # return training_results # except Exception as e: # logger.error(f"Error training RL agents: {e}") # return {'error': str(e)} # def _simulate_trading_action(self, symbol: str, rl_state) -> int: # """Simulate trading action for training (would be real decision in production)""" # # Simple simulation based on state features # if len(rl_state) > 100: # # Use momentum features to decide action # momentum_features = rl_state[:100] # First 100 features assumed to be momentum # avg_momentum = sum(momentum_features) / len(momentum_features) # if avg_momentum > 0.6: # return 1 # BUY # elif avg_momentum < 0.4: # return 2 # SELL # else: # return 0 # HOLD # else: # return 0 # HOLD as default # def _calculate_training_reward(self, symbol: str, market_state, action: int) -> float: # """Calculate training reward based on market state and action""" # try: # # Simple reward calculation based on market conditions # base_reward = 0.0 # # Reward based on volatility alignment # if hasattr(market_state, 'volatility'): # if action == 0 and market_state.volatility > 0.02: # HOLD in high volatility # base_reward += 0.1 # elif action != 0 and market_state.volatility < 0.01: # Trade in low volatility # base_reward += 0.1 # # Reward based on trend alignment # if hasattr(market_state, 'trend_strength'): # if action == 1 and market_state.trend_strength > 0.6: # BUY in uptrend # base_reward += 0.2 # elif action == 2 and market_state.trend_strength < 0.4: # SELL in downtrend # base_reward += 0.2 # return base_reward # except Exception as e: # logger.warning(f"Error calculating reward for {symbol}: {e}") # return 0.0 # def _update_training_stats(self, training_results: Dict[str, any], market_states: Dict[str, any]): # """Update training statistics""" # self.training_stats['training_sessions'] += 1 # self.training_stats['total_experiences'] += training_results.get('total_experiences', 0) # self.training_stats['avg_state_size'] = training_results.get('avg_state_size', 0) # self.training_stats['last_training_time'] = datetime.now() # # Log statistics periodically # if self.training_stats['training_sessions'] % 10 == 0: # logger.info("Training Statistics:") # logger.info(f" Sessions: {self.training_stats['training_sessions']}") # logger.info(f" Total Experiences: {self.training_stats['total_experiences']}") # logger.info(f" Avg State Size: {self.training_stats['avg_state_size']:.0f}") # logger.info(f" Data Quality: {self.training_stats['data_quality_score']:.2f}") # def _log_state_size_info(self, market_states: Dict[str, any]): # """Log information about state sizes for debugging""" # for symbol, state in market_states.items(): # info = [] # if hasattr(state, 'raw_ticks'): # info.append(f"ticks: {len(state.raw_ticks)}") # if hasattr(state, 'ohlcv_data'): # total_bars = sum(len(bars) for bars in state.ohlcv_data.values()) # info.append(f"OHLCV bars: {total_bars}") # if hasattr(state, 'cnn_hidden_features') and state.cnn_hidden_features: # info.append("CNN features: available") # if hasattr(state, 'pivot_points') and state.pivot_points: # info.append("pivot points: available") # logger.info(f"{symbol} state data: {', '.join(info)}") # async def _save_training_progress(self): # """Save training progress and models""" # try: # if self.rl_trainer: # self.rl_trainer._save_all_models() # logger.info("Training progress saved") # except Exception as e: # logger.error(f"Error saving training progress: {e}") # async def shutdown(self): # """Graceful shutdown""" # logger.info("Shutting down enhanced RL training system...") # self.running = False # # Save final state # await self._save_training_progress() # # Stop data provider # if self.data_provider: # await self.data_provider.stop_real_time_streaming() # logger.info("Enhanced RL training system shutdown complete") # async def main(): # """Main function to run enhanced RL training""" # system = None # def signal_handler(signum, frame): # logger.info("Received shutdown signal") # if system: # asyncio.create_task(system.shutdown()) # # Set up signal handlers # signal.signal(signal.SIGINT, signal_handler) # signal.signal(signal.SIGTERM, signal_handler) # try: # # Create and initialize the training system # system = EnhancedRLTrainingSystem() # await system.initialize() # logger.info("Enhanced RL Training System is now running...") # logger.info("The RL model now receives ~13,400 features instead of ~100!") # logger.info("Press Ctrl+C to stop") # # Run the training loop # await system.run_training_loop() # except KeyboardInterrupt: # logger.info("Training interrupted by user") # except Exception as e: # logger.error(f"Error in main training loop: {e}") # raise # finally: # if system: # await system.shutdown() # if __name__ == "__main__": # asyncio.run(main())