#!/usr/bin/env python3 """ Enhanced RL Training Launcher with Real Data Integration This script launches the comprehensive RL training system that uses: - Real-time tick data (300s window for momentum detection) - Multi-timeframe OHLCV data (1s, 1m, 1h, 1d) - BTC reference data for correlation - CNN hidden features and predictions - Williams Market Structure pivot points - Market microstructure analysis The RL model will receive ~13,400 features instead of the previous ~100 basic features. """ import asyncio import logging import time import signal import sys from datetime import datetime, timedelta from pathlib import Path from typing import Dict, List, Optional # Configure logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', handlers=[ logging.FileHandler('enhanced_rl_training.log'), logging.StreamHandler(sys.stdout) ] ) logger = logging.getLogger(__name__) # Import our enhanced components from core.config import get_config from core.data_provider import DataProvider from core.enhanced_orchestrator import EnhancedTradingOrchestrator from training.enhanced_rl_trainer import EnhancedRLTrainer from training.enhanced_rl_state_builder import EnhancedRLStateBuilder from training.williams_market_structure import WilliamsMarketStructure from training.cnn_rl_bridge import CNNRLBridge class EnhancedRLTrainingSystem: """Comprehensive RL training system with real data integration""" def __init__(self): """Initialize the enhanced RL training system""" self.config = get_config() self.running = False self.data_provider = None self.orchestrator = None self.rl_trainer = None # Performance tracking self.training_stats = { 'training_sessions': 0, 'total_experiences': 0, 'avg_state_size': 0, 'data_quality_score': 0.0, 'last_training_time': None } logger.info("Enhanced RL Training System initialized") logger.info("Features:") logger.info("- Real-time tick data processing (300s window)") logger.info("- Multi-timeframe OHLCV analysis (1s, 1m, 1h, 1d)") logger.info("- BTC correlation analysis") logger.info("- CNN feature integration") logger.info("- Williams Market Structure pivot points") logger.info("- ~13,400 feature state vector (vs previous ~100)") async def initialize(self): """Initialize all components""" try: logger.info("Initializing enhanced RL training components...") # Initialize data provider with real-time streaming logger.info("Setting up data provider with real-time streaming...") self.data_provider = DataProvider( symbols=self.config.symbols, timeframes=self.config.timeframes ) # Start real-time data streaming await self.data_provider.start_real_time_streaming() logger.info("Real-time data streaming started") # Wait for initial data collection logger.info("Collecting initial market data...") await asyncio.sleep(30) # Allow 30 seconds for data collection # Initialize enhanced orchestrator logger.info("Initializing enhanced orchestrator...") self.orchestrator = EnhancedTradingOrchestrator(self.data_provider) # Initialize enhanced RL trainer with comprehensive state building logger.info("Initializing enhanced RL trainer...") self.rl_trainer = EnhancedRLTrainer( config=self.config, orchestrator=self.orchestrator ) # Verify data availability data_status = await self._verify_data_availability() if not data_status['has_sufficient_data']: logger.warning("Insufficient data detected. Continuing with limited training.") logger.warning(f"Data status: {data_status}") else: logger.info("Sufficient data available for comprehensive RL training") logger.info(f"Tick data: {data_status['tick_count']} ticks") logger.info(f"OHLCV data: {data_status['ohlcv_bars']} bars") self.running = True logger.info("Enhanced RL training system initialized successfully") except Exception as e: logger.error(f"Error during initialization: {e}") raise async def _verify_data_availability(self) -> Dict[str, any]: """Verify that we have sufficient data for training""" try: data_status = { 'has_sufficient_data': False, 'tick_count': 0, 'ohlcv_bars': 0, 'symbols_with_data': [], 'missing_data': [] } for symbol in self.config.symbols: # Check tick data recent_ticks = self.data_provider.get_recent_ticks(symbol, count=100) tick_count = len(recent_ticks) # Check OHLCV data ohlcv_bars = 0 for timeframe in ['1s', '1m', '1h', '1d']: try: df = self.data_provider.get_historical_data( symbol=symbol, timeframe=timeframe, limit=50, refresh=True ) if df is not None and not df.empty: ohlcv_bars += len(df) except Exception as e: logger.warning(f"Error checking {timeframe} data for {symbol}: {e}") data_status['tick_count'] += tick_count data_status['ohlcv_bars'] += ohlcv_bars if tick_count >= 50 and ohlcv_bars >= 100: data_status['symbols_with_data'].append(symbol) else: data_status['missing_data'].append(f"{symbol}: {tick_count} ticks, {ohlcv_bars} bars") # Consider data sufficient if we have at least one symbol with good data data_status['has_sufficient_data'] = len(data_status['symbols_with_data']) > 0 return data_status except Exception as e: logger.error(f"Error verifying data availability: {e}") return {'has_sufficient_data': False, 'error': str(e)} async def run_training_loop(self): """Run the main training loop with real data""" logger.info("Starting enhanced RL training loop...") training_cycle = 0 last_state_size_log = time.time() try: while self.running: training_cycle += 1 cycle_start_time = time.time() logger.info(f"Training cycle {training_cycle} started") # Get comprehensive market states with real data market_states = await self._get_comprehensive_market_states() if not market_states: logger.warning("No market states available. Waiting for data...") await asyncio.sleep(60) continue # Train RL agents with comprehensive states training_results = await self._train_rl_agents(market_states) # Update performance tracking self._update_training_stats(training_results, market_states) # Log training progress cycle_duration = time.time() - cycle_start_time logger.info(f"Training cycle {training_cycle} completed in {cycle_duration:.2f}s") # Log state size periodically if time.time() - last_state_size_log > 300: # Every 5 minutes self._log_state_size_info(market_states) last_state_size_log = time.time() # Save models periodically if training_cycle % 10 == 0: await self._save_training_progress() # Wait before next training cycle await asyncio.sleep(300) # Train every 5 minutes except Exception as e: logger.error(f"Error in training loop: {e}") raise async def _get_comprehensive_market_states(self) -> Dict[str, any]: """Get comprehensive market states with all required data""" try: # Get market states from orchestrator universal_stream = self.orchestrator.universal_adapter.get_universal_stream() market_states = await self.orchestrator._get_all_market_states_universal(universal_stream) # Verify data quality quality_score = self._calculate_data_quality(market_states) self.training_stats['data_quality_score'] = quality_score if quality_score < 0.5: logger.warning(f"Low data quality detected: {quality_score:.2f}") return market_states except Exception as e: logger.error(f"Error getting comprehensive market states: {e}") return {} def _calculate_data_quality(self, market_states: Dict[str, any]) -> float: """Calculate data quality score based on available data""" try: if not market_states: return 0.0 total_score = 0.0 total_symbols = len(market_states) for symbol, state in market_states.items(): symbol_score = 0.0 # Score based on tick data availability if hasattr(state, 'raw_ticks') and state.raw_ticks: tick_score = min(len(state.raw_ticks) / 100, 1.0) # Max score for 100+ ticks symbol_score += tick_score * 0.3 # Score based on OHLCV data availability if hasattr(state, 'ohlcv_data') and state.ohlcv_data: ohlcv_score = len(state.ohlcv_data) / 4.0 # Max score for all 4 timeframes symbol_score += min(ohlcv_score, 1.0) * 0.4 # Score based on CNN features if hasattr(state, 'cnn_hidden_features') and state.cnn_hidden_features: symbol_score += 0.15 # Score based on pivot points if hasattr(state, 'pivot_points') and state.pivot_points: symbol_score += 0.15 total_score += symbol_score return total_score / total_symbols if total_symbols > 0 else 0.0 except Exception as e: logger.warning(f"Error calculating data quality: {e}") return 0.5 # Default to medium quality async def _train_rl_agents(self, market_states: Dict[str, any]) -> Dict[str, any]: """Train RL agents with comprehensive market states""" try: training_results = { 'symbols_trained': [], 'total_experiences': 0, 'avg_state_size': 0, 'training_errors': [] } for symbol, market_state in market_states.items(): try: # Convert market state to comprehensive RL state rl_state = self.rl_trainer._market_state_to_rl_state(market_state) if rl_state is not None and len(rl_state) > 0: # Record state size training_results['avg_state_size'] += len(rl_state) # Simulate trading action for experience generation # In real implementation, this would be actual trading decisions action = self._simulate_trading_action(symbol, rl_state) # Generate reward based on market outcome reward = self._calculate_training_reward(symbol, market_state, action) # Add experience to RL agent agent = self.rl_trainer.agents.get(symbol) if agent: # Create next state (would be actual next market state in real scenario) next_state = rl_state # Simplified for now agent.remember( state=rl_state, action=action, reward=reward, next_state=next_state, done=False ) # Train agent if enough experiences if len(agent.replay_buffer) >= agent.batch_size: loss = agent.replay() if loss is not None: logger.debug(f"Agent {symbol} training loss: {loss:.4f}") training_results['symbols_trained'].append(symbol) training_results['total_experiences'] += 1 except Exception as e: error_msg = f"Error training {symbol}: {e}" logger.warning(error_msg) training_results['training_errors'].append(error_msg) # Calculate average state size if len(training_results['symbols_trained']) > 0: training_results['avg_state_size'] /= len(training_results['symbols_trained']) return training_results except Exception as e: logger.error(f"Error training RL agents: {e}") return {'error': str(e)} def _simulate_trading_action(self, symbol: str, rl_state) -> int: """Simulate trading action for training (would be real decision in production)""" # Simple simulation based on state features if len(rl_state) > 100: # Use momentum features to decide action momentum_features = rl_state[:100] # First 100 features assumed to be momentum avg_momentum = sum(momentum_features) / len(momentum_features) if avg_momentum > 0.6: return 1 # BUY elif avg_momentum < 0.4: return 2 # SELL else: return 0 # HOLD else: return 0 # HOLD as default def _calculate_training_reward(self, symbol: str, market_state, action: int) -> float: """Calculate training reward based on market state and action""" try: # Simple reward calculation based on market conditions base_reward = 0.0 # Reward based on volatility alignment if hasattr(market_state, 'volatility'): if action == 0 and market_state.volatility > 0.02: # HOLD in high volatility base_reward += 0.1 elif action != 0 and market_state.volatility < 0.01: # Trade in low volatility base_reward += 0.1 # Reward based on trend alignment if hasattr(market_state, 'trend_strength'): if action == 1 and market_state.trend_strength > 0.6: # BUY in uptrend base_reward += 0.2 elif action == 2 and market_state.trend_strength < 0.4: # SELL in downtrend base_reward += 0.2 return base_reward except Exception as e: logger.warning(f"Error calculating reward for {symbol}: {e}") return 0.0 def _update_training_stats(self, training_results: Dict[str, any], market_states: Dict[str, any]): """Update training statistics""" self.training_stats['training_sessions'] += 1 self.training_stats['total_experiences'] += training_results.get('total_experiences', 0) self.training_stats['avg_state_size'] = training_results.get('avg_state_size', 0) self.training_stats['last_training_time'] = datetime.now() # Log statistics periodically if self.training_stats['training_sessions'] % 10 == 0: logger.info("Training Statistics:") logger.info(f" Sessions: {self.training_stats['training_sessions']}") logger.info(f" Total Experiences: {self.training_stats['total_experiences']}") logger.info(f" Avg State Size: {self.training_stats['avg_state_size']:.0f}") logger.info(f" Data Quality: {self.training_stats['data_quality_score']:.2f}") def _log_state_size_info(self, market_states: Dict[str, any]): """Log information about state sizes for debugging""" for symbol, state in market_states.items(): info = [] if hasattr(state, 'raw_ticks'): info.append(f"ticks: {len(state.raw_ticks)}") if hasattr(state, 'ohlcv_data'): total_bars = sum(len(bars) for bars in state.ohlcv_data.values()) info.append(f"OHLCV bars: {total_bars}") if hasattr(state, 'cnn_hidden_features') and state.cnn_hidden_features: info.append("CNN features: available") if hasattr(state, 'pivot_points') and state.pivot_points: info.append("pivot points: available") logger.info(f"{symbol} state data: {', '.join(info)}") async def _save_training_progress(self): """Save training progress and models""" try: if self.rl_trainer: self.rl_trainer._save_all_models() logger.info("Training progress saved") except Exception as e: logger.error(f"Error saving training progress: {e}") async def shutdown(self): """Graceful shutdown""" logger.info("Shutting down enhanced RL training system...") self.running = False # Save final state await self._save_training_progress() # Stop data provider if self.data_provider: await self.data_provider.stop_real_time_streaming() logger.info("Enhanced RL training system shutdown complete") async def main(): """Main function to run enhanced RL training""" system = None def signal_handler(signum, frame): logger.info("Received shutdown signal") if system: asyncio.create_task(system.shutdown()) # Set up signal handlers signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) try: # Create and initialize the training system system = EnhancedRLTrainingSystem() await system.initialize() logger.info("Enhanced RL Training System is now running...") logger.info("The RL model now receives ~13,400 features instead of ~100!") logger.info("Press Ctrl+C to stop") # Run the training loop await system.run_training_loop() except KeyboardInterrupt: logger.info("Training interrupted by user") except Exception as e: logger.error(f"Error in main training loop: {e}") raise finally: if system: await system.shutdown() if __name__ == "__main__": asyncio.run(main())