""" Enhanced RL Training Adapter This module integrates the new MSE-based reward system with existing RL training pipelines. It provides a bridge between the timeframe-aware inference coordinator and the existing model training infrastructure. Key Features: - Integration with EnhancedRewardCalculator - Adaptation of existing RL models to new reward system - Real-time training triggers based on prediction outcomes - Multi-timeframe training coordination - Backward compatibility with existing training infrastructure """ import asyncio import logging import time from datetime import datetime, timedelta from typing import Dict, List, Optional, Any, Union from dataclasses import dataclass import numpy as np import threading from core.enhanced_reward_calculator import EnhancedRewardCalculator, TimeFrame, PredictionRecord from core.timeframe_inference_coordinator import TimeframeInferenceCoordinator, InferenceContext logger = logging.getLogger(__name__) @dataclass class TrainingBatch: """Training batch for RL models with enhanced reward data""" model_name: str symbol: str timeframe: TimeFrame states: List[np.ndarray] actions: List[int] rewards: List[float] next_states: List[np.ndarray] dones: List[bool] confidences: List[float] prediction_records: List[PredictionRecord] batch_timestamp: datetime class EnhancedRLTrainingAdapter: """ Adapter that integrates new reward system with existing RL training infrastructure This adapter: 1. Bridges new reward calculator with existing RL models 2. Converts prediction records to RL training format 3. Triggers real-time training based on reward evaluation 4. Maintains compatibility with existing training systems 5. Coordinates multi-timeframe training """ def __init__(self, reward_calculator: EnhancedRewardCalculator, inference_coordinator: TimeframeInferenceCoordinator, orchestrator: Any = None, training_system: Any = None): """ Initialize the enhanced RL training adapter Args: reward_calculator: Enhanced reward calculator instance inference_coordinator: Timeframe inference coordinator orchestrator: Trading orchestrator (optional) training_system: Enhanced realtime training system (optional) """ self.reward_calculator = reward_calculator self.inference_coordinator = inference_coordinator self.orchestrator = orchestrator self.training_system = training_system # Model registry for training functions self.model_trainers: Dict[str, Any] = {} # Training configuration self.min_batch_size = 8 # Minimum samples for training self.max_batch_size = 64 # Maximum samples per training batch self.training_interval_seconds = 5.0 # How often to check for training opportunities # Training statistics self.training_stats = { 'total_training_batches': 0, 'successful_training_calls': 0, 'failed_training_calls': 0, 'last_training_time': None, 'training_times_per_model': {}, 'average_batch_sizes': {} } # State conversion helpers self.state_builders: Dict[str, Any] = {} # Thread safety self.lock = threading.RLock() # Running state self.running = False self.training_task: Optional[asyncio.Task] = None logger.info("EnhancedRLTrainingAdapter initialized") self._register_default_model_handlers() def _register_default_model_handlers(self): """Register default model handlers for existing models""" # Register inference functions with the coordinator if self.inference_coordinator: self.inference_coordinator.register_model_inference_function( 'dqn_agent', self._dqn_inference_wrapper ) self.inference_coordinator.register_model_inference_function( 'cob_rl', self._cob_rl_inference_wrapper ) 
            self.inference_coordinator.register_model_inference_function(
                'enhanced_cnn', self._cnn_inference_wrapper
            )

    async def _dqn_inference_wrapper(self, context: InferenceContext) -> Optional[Dict[str, Any]]:
        """Wrapper for DQN model inference"""
        try:
            if self.orchestrator and hasattr(self.orchestrator, 'rl_agent'):
                # Get base data for the symbol
                base_data = await self._get_base_data(context.symbol)
                if base_data is None:
                    return None

                # Convert to DQN state format
                state = self._convert_to_dqn_state(base_data, context)

                # Run DQN prediction
                if hasattr(self.orchestrator.rl_agent, 'act'):
                    action_idx = self.orchestrator.rl_agent.act(state)
                    # Try to extract confidence from agent if available
                    confidence = getattr(self.orchestrator.rl_agent, 'last_confidence', None)
                    if confidence is None:
                        confidence = 0.5

                    # Convert action to prediction format
                    action_names = ['SELL', 'HOLD', 'BUY']
                    direction = action_idx - 1  # Convert 0,1,2 to -1,0,1

                    # Use real current price
                    current_price = self._safe_get_current_price(context.symbol)

                    # Do not fabricate price; set predicted_price only if model provides numeric target later
                    return {
                        'predicted_price': current_price,  # same as current when no numeric target available
                        'current_price': current_price,
                        'direction': direction,
                        'confidence': float(confidence),
                        'action': action_names[action_idx],
                        'model_state': (state.tolist() if hasattr(state, 'tolist') else state),
                        'context': context
                    }

        except Exception as e:
            logger.error(f"Error in DQN inference wrapper: {e}")

        return None

    async def _cob_rl_inference_wrapper(self, context: InferenceContext) -> Optional[Dict[str, Any]]:
        """Wrapper for COB RL model inference"""
        try:
            if (self.orchestrator and hasattr(self.orchestrator, 'realtime_rl_trader')
                    and self.orchestrator.realtime_rl_trader):
                # Get COB features
                features = await self._get_cob_features(context.symbol)
                if features is None:
                    return None

                # Run COB RL prediction
                prediction = self.orchestrator.realtime_rl_trader._predict(context.symbol, features)

                if prediction:
                    current_price = self._safe_get_current_price(context.symbol)
                    # If 'change' is available assume it is a fractional return
                    change = prediction.get('change', None)
                    predicted_price = (
                        current_price * (1 + change)
                        if (change is not None and current_price)
                        else current_price
                    )
                    return {
                        'predicted_price': predicted_price,
                        'current_price': current_price,
                        'direction': prediction.get('direction', 0),
                        'confidence': prediction.get('confidence', 0.0),
                        'change': prediction.get('change', 0.0),
                        'model_features': features,
                        'context': context
                    }

        except Exception as e:
            logger.error(f"Error in COB RL inference wrapper: {e}")

        return None
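
    # NOTE: The inference wrappers in this class (_dqn_inference_wrapper,
    # _cob_rl_inference_wrapper and _cnn_inference_wrapper below) all return the
    # same dictionary shape, which the inference coordinator and reward
    # calculator consume. A minimal sketch of that contract, with field names
    # taken from the wrappers themselves and purely illustrative values:
    #
    #     {
    #         'predicted_price': 3105.2,  # falls back to current_price when the
    #                                     # model provides no numeric target
    #         'current_price': 3100.0,    # from DataProvider via _safe_get_current_price
    #         'direction': 1,             # -1 = SELL, 0 = HOLD, 1 = BUY
    #         'confidence': 0.62,         # float in [0, 1]
    #         'context': context,         # the InferenceContext passed in
    #     }
    #
    # Model-specific extras ('action', 'change', 'predicted_return',
    # 'model_state', 'model_features', 'model_output') may be attached as well.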

    async def _cnn_inference_wrapper(self, context: InferenceContext) -> Optional[Dict[str, Any]]:
        """Wrapper for CNN model inference"""
        try:
            if self.orchestrator and hasattr(self.orchestrator, 'model_registry'):
                # Find CNN models in registry
                for model_name, model in self.orchestrator.model_registry.models.items():
                    if 'cnn' in model_name.lower():
                        # Get base data
                        base_data = await self._get_base_data(context.symbol)
                        if base_data is None:
                            continue

                        # Run CNN prediction
                        if hasattr(model, 'predict_from_base_input'):
                            model_output = model.predict_from_base_input(base_data)

                            # Extract current price from data provider
                            current_price = self._safe_get_current_price(context.symbol)

                            # Extract prediction data
                            predictions = model_output.predictions
                            action = predictions.get('action', 'HOLD')
                            confidence = predictions.get('confidence', 0.0)

                            # Convert action to direction only for classification signal
                            direction = {'BUY': 1, 'SELL': -1, 'HOLD': 0}.get(action, 0)

                            # Use numeric predicted return if provided (no synthetic fabrication)
                            pr_map = {
                                TimeFrame.SECONDS_1: 'predicted_return_1s',
                                TimeFrame.MINUTES_1: 'predicted_return_1m',
                                TimeFrame.HOURS_1: 'predicted_return_1h',
                                TimeFrame.DAYS_1: 'predicted_return_1d',
                            }
                            ret_key = pr_map.get(context.target_timeframe)
                            predicted_return = None
                            if ret_key and ret_key in predictions:
                                predicted_return = float(predictions.get(ret_key))
                            predicted_price = (
                                current_price * (1 + predicted_return)
                                if (predicted_return is not None and current_price)
                                else current_price
                            )

                            # Also attach DQN-formatted state if available for training consumption
                            dqn_state = self._convert_to_dqn_state(base_data, context)

                            return {
                                'predicted_price': predicted_price,
                                'current_price': current_price,
                                'direction': direction,
                                'confidence': confidence,
                                'predicted_return': predicted_return,
                                'action': action,
                                'model_output': model_output,
                                'model_state': (dqn_state.tolist() if hasattr(dqn_state, 'tolist') else dqn_state),
                                'context': context
                            }

        except Exception as e:
            logger.error(f"Error in CNN inference wrapper: {e}")

        return None

    async def _get_base_data(self, symbol: str) -> Optional[Any]:
        """Get base data for a symbol"""
        try:
            if self.orchestrator and hasattr(self.orchestrator, 'data_provider'):
                # Use orchestrator's data provider
                return await self.orchestrator._build_base_data(symbol)
        except Exception as e:
            logger.debug(f"Error getting base data for {symbol}: {e}")

        return None

    async def _get_cob_features(self, symbol: str) -> Optional[np.ndarray]:
        """Get COB features for a symbol"""
        try:
            if (self.orchestrator and hasattr(self.orchestrator, 'realtime_rl_trader')
                    and self.orchestrator.realtime_rl_trader):
                # Get latest features from COB trader
                feature_buffers = self.orchestrator.realtime_rl_trader.feature_buffers
                if symbol in feature_buffers and feature_buffers[symbol]:
                    latest_data = feature_buffers[symbol][-1]
                    return latest_data.get('features')
        except Exception as e:
            logger.debug(f"Error getting COB features for {symbol}: {e}")

        return None

    def _safe_get_current_price(self, symbol: str) -> float:
        """Get current price for a symbol via DataProvider API"""
        try:
            if self.orchestrator and hasattr(self.orchestrator, 'data_provider'):
                price = self.orchestrator.data_provider.get_current_price(symbol)
                return float(price) if price is not None else 0.0
        except Exception as e:
            logger.debug(f"Error getting current price for {symbol}: {e}")

        return 0.0

    def _convert_to_dqn_state(self, base_data: Any, context: InferenceContext) -> np.ndarray:
        """Convert base data to DQN state format"""
        try:
            # Use existing state building logic if available
            if (self.orchestrator and hasattr(self.orchestrator, 'enhanced_training_system')
                    and hasattr(self.orchestrator.enhanced_training_system, '_build_dqn_state')):
                return self.orchestrator.enhanced_training_system._build_dqn_state(
                    base_data, context.symbol
                )

            # Fallback: create simple state representation
            feature_vector = base_data.get_feature_vector() if hasattr(base_data, 'get_feature_vector') else []
            if feature_vector:
                return np.array(feature_vector, dtype=np.float32)

            # Last resort: create minimal state
            return np.zeros(100, dtype=np.float32)

        except Exception as e:
            logger.error(f"Error converting to DQN state: {e}")
            return np.zeros(100, dtype=np.float32)

    async def start_training_loop(self):
        """Start the enhanced training loop"""
        if self.running:
            logger.warning("Training loop already running")
            return

        self.running = True
        self.training_task = asyncio.create_task(self._training_loop())
        logger.info("Enhanced RL training loop started")
started") async def stop_training_loop(self): """Stop the enhanced training loop""" if not self.running: return self.running = False if self.training_task: self.training_task.cancel() try: await self.training_task except asyncio.CancelledError: pass logger.info("Enhanced RL training loop stopped") async def _training_loop(self): """Main training loop that processes evaluated predictions""" logger.info("Starting enhanced RL training loop") while self.running: try: # Process training for each symbol and timeframe for symbol in self.reward_calculator.symbols: for timeframe in [TimeFrame.SECONDS_1, TimeFrame.MINUTES_1, TimeFrame.HOURS_1, TimeFrame.DAYS_1]: # Get training data for this symbol/timeframe training_data = self.reward_calculator.get_training_data( symbol, timeframe, self.max_batch_size ) if len(training_data) >= self.min_batch_size: await self._process_training_batch(symbol, timeframe, training_data) # Sleep between training checks await asyncio.sleep(self.training_interval_seconds) except Exception as e: logger.error(f"Error in training loop: {e}") await asyncio.sleep(10) # Wait longer on error async def _process_training_batch(self, symbol: str, timeframe: TimeFrame, training_data: List[Tuple[PredictionRecord, float]]): """ Process a training batch for a specific symbol/timeframe Args: symbol: Trading symbol timeframe: Timeframe for training training_data: List of (prediction_record, reward) tuples """ try: # Group training data by model model_batches = {} for prediction_record, reward in training_data: model_name = prediction_record.model_name if model_name not in model_batches: model_batches[model_name] = [] model_batches[model_name].append((prediction_record, reward)) # Process each model's batch for model_name, model_data in model_batches.items(): if len(model_data) >= self.min_batch_size: await self._train_model_batch(model_name, symbol, timeframe, model_data) except Exception as e: logger.error(f"Error processing training batch for {symbol} {timeframe.value}: {e}") async def _train_model_batch(self, model_name: str, symbol: str, timeframe: TimeFrame, training_data: List[Tuple[PredictionRecord, float]]): """ Train a specific model with a batch of data Args: model_name: Name of the model to train symbol: Trading symbol timeframe: Timeframe for training training_data: List of (prediction_record, reward) tuples """ try: training_start = time.time() # Convert to training batch format batch = self._create_training_batch(model_name, symbol, timeframe, training_data) if batch is None: return # Call appropriate training function based on model type success = False if 'dqn' in model_name.lower(): success = await self._train_dqn_model(batch) elif 'cob' in model_name.lower(): success = await self._train_cob_rl_model(batch) elif 'cnn' in model_name.lower(): success = await self._train_cnn_model(batch) else: logger.warning(f"Unknown model type for training: {model_name}") # Update statistics training_time = time.time() - training_start self._update_training_stats(model_name, batch, success, training_time) if success: logger.info(f"Successfully trained {model_name} on {symbol} {timeframe.value} " f"with {len(training_data)} samples in {training_time:.3f}s") except Exception as e: logger.error(f"Error training model {model_name}: {e}") self._update_training_stats(model_name, None, False, 0) def _create_training_batch(self, model_name: str, symbol: str, timeframe: TimeFrame, training_data: List[Tuple[PredictionRecord, float]]) -> Optional[TrainingBatch]: """Create a training batch from 
prediction records and rewards""" try: states = [] actions = [] rewards = [] next_states = [] dones = [] confidences = [] prediction_records = [] for prediction_record, reward in training_data: # Extract state information # This would need to be adapted based on how states are stored state = np.zeros(100) next_state = state.copy() # Simplified next state # Convert direction to action direction = prediction_record.predicted_direction action = direction + 1 # Convert -1,0,1 to 0,1,2 states.append(state) actions.append(action) rewards.append(reward) next_states.append(next_state) dones.append(True) # Each prediction is treated as terminal confidences.append(prediction_record.confidence) prediction_records.append(prediction_record) return TrainingBatch( model_name=model_name, symbol=symbol, timeframe=timeframe, states=states, actions=actions, rewards=rewards, next_states=next_states, dones=dones, confidences=confidences, prediction_records=prediction_records, batch_timestamp=datetime.now() ) except Exception as e: logger.error(f"Error creating training batch: {e}") return None async def _train_dqn_model(self, batch: TrainingBatch) -> bool: """Train DQN model with batch data""" try: if self.orchestrator and hasattr(self.orchestrator, 'rl_agent'): rl_agent = self.orchestrator.rl_agent # Add experiences to memory for i in range(len(batch.states)): if hasattr(rl_agent, 'remember'): rl_agent.remember( state=batch.states[i], action=batch.actions[i], reward=batch.rewards[i], next_state=batch.next_states[i], done=batch.dones[i] ) # Trigger training if enough experiences if hasattr(rl_agent, 'replay') and hasattr(rl_agent, 'memory'): if len(rl_agent.memory) >= getattr(rl_agent, 'batch_size', 32): loss = rl_agent.replay() if loss is not None: logger.debug(f"DQN training loss: {loss:.6f}") return True return False except Exception as e: logger.error(f"Error training DQN model: {e}") return False async def _train_cob_rl_model(self, batch: TrainingBatch) -> bool: """Train COB RL model with batch data""" try: if (self.orchestrator and hasattr(self.orchestrator, 'realtime_rl_trader') and self.orchestrator.realtime_rl_trader): # Use COB RL trainer if available # This is a placeholder - implement based on actual COB RL training interface logger.debug(f"COB RL training batch: {len(batch.states)} samples") return True return False except Exception as e: logger.error(f"Error training COB RL model: {e}") return False async def _train_cnn_model(self, batch: TrainingBatch) -> bool: """Train CNN model with batch data""" try: if self.orchestrator and hasattr(self.orchestrator, 'enhanced_training_system'): # Use enhanced training system for CNN training # This is a placeholder - implement based on actual CNN training interface logger.debug(f"CNN training batch: {len(batch.states)} samples") return True return False except Exception as e: logger.error(f"Error training CNN model: {e}") return False def _update_training_stats(self, model_name: str, batch: Optional[TrainingBatch], success: bool, training_time: float): """Update training statistics""" with self.lock: self.training_stats['total_training_batches'] += 1 if success: self.training_stats['successful_training_calls'] += 1 else: self.training_stats['failed_training_calls'] += 1 self.training_stats['last_training_time'] = datetime.now().isoformat() # Model-specific stats if model_name not in self.training_stats['training_times_per_model']: self.training_stats['training_times_per_model'][model_name] = [] self.training_stats['average_batch_sizes'][model_name] = [] 

            self.training_stats['training_times_per_model'][model_name].append(training_time)

            if batch:
                self.training_stats['average_batch_sizes'][model_name].append(len(batch.states))

    def get_training_statistics(self) -> Dict[str, Any]:
        """Get training statistics"""
        with self.lock:
            stats = self.training_stats.copy()

            # Calculate averages
            for model_name in stats['training_times_per_model']:
                times = stats['training_times_per_model'][model_name]
                if times:
                    stats[f'{model_name}_avg_training_time'] = sum(times) / len(times)

                sizes = stats['average_batch_sizes'][model_name]
                if sizes:
                    stats[f'{model_name}_avg_batch_size'] = sum(sizes) / len(sizes)

            return stats
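

# Usage sketch (illustrative only, not part of the adapter API): how the adapter
# can be wired into an asyncio application. The reward calculator, inference
# coordinator and orchestrator are assumed to be constructed elsewhere with their
# own project-specific arguments; only public methods defined in this module are
# exercised, and the function name `_example_run_adapter` is a placeholder.
async def _example_run_adapter(reward_calculator: EnhancedRewardCalculator,
                               inference_coordinator: TimeframeInferenceCoordinator,
                               orchestrator: Any = None) -> None:
    """Start the adapter's training loop and periodically log its statistics."""
    adapter = EnhancedRLTrainingAdapter(
        reward_calculator=reward_calculator,
        inference_coordinator=inference_coordinator,
        orchestrator=orchestrator,
    )

    # The constructor has already registered the default DQN / COB RL / CNN
    # inference wrappers with the coordinator, so no extra setup is required.
    await adapter.start_training_loop()
    try:
        while True:
            await asyncio.sleep(60)
            logger.info(f"Training statistics: {adapter.get_training_statistics()}")
    finally:
        await adapter.stop_training_loop()


# Example entry point (commented out; supply real instances to run):
# asyncio.run(_example_run_adapter(reward_calculator, inference_coordinator))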