""" Enhanced Reward System Integration This module provides a simple integration point for the new MSE-based reward system with the existing trading orchestrator and training infrastructure. Key Features: - Easy integration with existing TradingOrchestrator - Minimal changes required to existing code - Backward compatibility maintained - Enhanced performance monitoring - Real-time training with MSE rewards """ import asyncio import logging from typing import Optional, Dict, Any from datetime import datetime from core.enhanced_reward_calculator import EnhancedRewardCalculator, TimeFrame from core.timeframe_inference_coordinator import TimeframeInferenceCoordinator from core.enhanced_rl_training_adapter import EnhancedRLTrainingAdapter logger = logging.getLogger(__name__) class EnhancedRewardSystemIntegration: """ Complete integration of the enhanced reward system This class provides a single integration point that can be easily added to the existing TradingOrchestrator to enable MSE-based rewards and multi-timeframe training. """ def __init__(self, orchestrator: Any, symbols: list = None): """ Initialize the enhanced reward system integration Args: orchestrator: TradingOrchestrator instance symbols: List of symbols to track (defaults to ETH/USDT, BTC/USDT) """ self.orchestrator = orchestrator self.symbols = symbols or ['ETH/USDT', 'BTC/USDT'] # Initialize core components self.reward_calculator = EnhancedRewardCalculator(symbols=self.symbols) self.inference_coordinator = TimeframeInferenceCoordinator( reward_calculator=self.reward_calculator, data_provider=getattr(orchestrator, 'data_provider', None), symbols=self.symbols ) self.training_adapter = EnhancedRLTrainingAdapter( reward_calculator=self.reward_calculator, inference_coordinator=self.inference_coordinator, orchestrator=orchestrator, training_system=getattr(orchestrator, 'enhanced_training_system', None) ) # Integration state self.is_running = False self.integration_stats = { 'start_time': None, 'total_predictions_tracked': 0, 'total_rewards_calculated': 0, 'total_training_batches': 0 } logger.info(f"EnhancedRewardSystemIntegration initialized for symbols: {self.symbols}") async def start_integration(self): """Start the enhanced reward system integration""" if self.is_running: logger.warning("Enhanced reward system already running") return try: logger.info("Starting Enhanced Reward System Integration") # Start core components await self.inference_coordinator.start_coordination() await self.training_adapter.start_training_loop() # Start price monitoring asyncio.create_task(self._price_monitoring_loop()) self.is_running = True self.integration_stats['start_time'] = datetime.now().isoformat() logger.info("Enhanced Reward System Integration started successfully") except Exception as e: logger.error(f"Error starting enhanced reward system integration: {e}") await self.stop_integration() async def stop_integration(self): """Stop the enhanced reward system integration""" if not self.is_running: return try: logger.info("Stopping Enhanced Reward System Integration") # Stop components await self.inference_coordinator.stop_coordination() await self.training_adapter.stop_training_loop() self.is_running = False logger.info("Enhanced Reward System Integration stopped") except Exception as e: logger.error(f"Error stopping enhanced reward system integration: {e}") async def _price_monitoring_loop(self): """Monitor prices and update the reward calculator""" while self.is_running: try: # Update current prices for all symbols for symbol in self.symbols: 
current_price = await self._get_current_price(symbol) if current_price > 0: self.reward_calculator.update_price(symbol, current_price) # Sleep for 1 second between updates await asyncio.sleep(1.0) except Exception as e: logger.debug(f"Error in price monitoring loop: {e}") await asyncio.sleep(5.0) # Wait longer on error async def _get_current_price(self, symbol: str) -> float: """Get current price for a symbol""" try: if hasattr(self.orchestrator, 'data_provider'): current_prices = self.orchestrator.data_provider.current_prices return current_prices.get(symbol, 0.0) except Exception as e: logger.debug(f"Error getting current price for {symbol}: {e}") return 0.0 def add_prediction_manually(self, symbol: str, timeframe_str: str, predicted_price: float, predicted_direction: int, confidence: float, current_price: float, model_name: str) -> str: """ Manually add a prediction to the reward calculator This method allows existing code to easily integrate with the new reward system without major changes. Args: symbol: Trading symbol (e.g., 'ETH/USDT') timeframe_str: Timeframe string ('1s', '1m', '1h', '1d') predicted_price: Model's predicted price predicted_direction: Predicted direction (-1, 0, 1) confidence: Model's confidence (0.0 to 1.0) current_price: Current market price model_name: Name of the model making prediction Returns: Unique prediction ID """ try: # Convert timeframe string to enum timeframe = TimeFrame(timeframe_str) prediction_id = self.reward_calculator.add_prediction( symbol=symbol, timeframe=timeframe, predicted_price=predicted_price, predicted_direction=predicted_direction, confidence=confidence, current_price=current_price, model_name=model_name ) self.integration_stats['total_predictions_tracked'] += 1 return prediction_id except Exception as e: logger.error(f"Error adding prediction manually: {e}") return "" def get_model_accuracy(self, model_name: str = None, symbol: str = None) -> Dict[str, Any]: """ Get accuracy statistics for models Args: model_name: Specific model name (None for all) symbol: Specific symbol (None for all) Returns: Dictionary with accuracy statistics """ try: accuracy_summary = self.reward_calculator.get_accuracy_summary(symbol) if model_name: # Filter by model name in prediction history # This would require enhancing the reward calculator to track by model pass return accuracy_summary except Exception as e: logger.error(f"Error getting model accuracy: {e}") return {} def force_evaluation_and_training(self, symbol: str = None, timeframe_str: str = None): """ Force immediate evaluation and training for debugging/testing Args: symbol: Specific symbol (None for all) timeframe_str: Specific timeframe (None for all) """ try: if timeframe_str: timeframe = TimeFrame(timeframe_str) symbols_to_process = [symbol] if symbol else self.symbols for sym in symbols_to_process: # Force evaluation of predictions results = self.reward_calculator.force_evaluate_timeframe_predictions(sym, timeframe) logger.info(f"Force evaluated {len(results)} predictions for {sym} {timeframe.value}") else: # Evaluate all pending predictions for sym in (self.symbols if not symbol else [symbol]): results = self.reward_calculator.evaluate_predictions(sym) if sym in results: logger.info(f"Force evaluated {len(results[sym])} predictions for {sym}") except Exception as e: logger.error(f"Error in force evaluation and training: {e}") def get_integration_statistics(self) -> Dict[str, Any]: """Get comprehensive integration statistics""" try: stats = self.integration_stats.copy() # Add component 
statistics stats['inference_coordinator'] = self.inference_coordinator.get_inference_statistics() stats['training_adapter'] = self.training_adapter.get_training_statistics() stats['reward_calculator'] = self.reward_calculator.get_accuracy_summary() # Add system status stats['is_running'] = self.is_running stats['components_running'] = { 'inference_coordinator': self.inference_coordinator.running, 'training_adapter': self.training_adapter.running } return stats except Exception as e: logger.error(f"Error getting integration statistics: {e}") return {'error': str(e)} def cleanup_old_data(self, days_to_keep: int = 7): """Clean up old prediction data to manage memory""" try: self.reward_calculator.cleanup_old_predictions(days_to_keep) logger.info(f"Cleaned up prediction data older than {days_to_keep} days") except Exception as e: logger.error(f"Error cleaning up old data: {e}") # Utility functions for easy integration def integrate_enhanced_rewards(orchestrator: Any, symbols: list = None) -> EnhancedRewardSystemIntegration: """ Utility function to easily integrate enhanced rewards with an existing orchestrator Args: orchestrator: TradingOrchestrator instance symbols: List of symbols to track Returns: EnhancedRewardSystemIntegration instance """ integration = EnhancedRewardSystemIntegration(orchestrator, symbols) # Add integration as an attribute to the orchestrator for easy access setattr(orchestrator, 'enhanced_reward_system', integration) logger.info("Enhanced reward system integrated with orchestrator") return integration async def start_enhanced_rewards_for_orchestrator(orchestrator: Any, symbols: list = None): """ Start enhanced rewards for an existing orchestrator Args: orchestrator: TradingOrchestrator instance symbols: List of symbols to track """ if not hasattr(orchestrator, 'enhanced_reward_system'): integrate_enhanced_rewards(orchestrator, symbols) await orchestrator.enhanced_reward_system.start_integration() def add_prediction_to_enhanced_rewards(orchestrator: Any, symbol: str, timeframe: str, predicted_price: float, direction: int, confidence: float, current_price: float, model_name: str) -> str: """ Helper function to add predictions to enhanced rewards from existing code Args: orchestrator: TradingOrchestrator instance with enhanced_reward_system symbol: Trading symbol timeframe: Timeframe string predicted_price: Predicted price direction: Predicted direction (-1, 0, 1) confidence: Model confidence current_price: Current market price model_name: Model name Returns: Prediction ID """ if hasattr(orchestrator, 'enhanced_reward_system'): return orchestrator.enhanced_reward_system.add_prediction_manually( symbol, timeframe, predicted_price, direction, confidence, current_price, model_name ) logger.warning("Enhanced reward system not integrated with orchestrator") return ""
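

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The snippet below shows how an existing
# orchestrator might wire in this integration. The `core.orchestrator` import
# path, the `TradingOrchestrator()` construction, and the example prediction
# values are assumptions about the surrounding project, not part of this
# module's API; only the helper functions defined above are real.
#
#   from core.orchestrator import TradingOrchestrator
#   from core.enhanced_reward_system_integration import (
#       integrate_enhanced_rewards,
#       start_enhanced_rewards_for_orchestrator,
#       add_prediction_to_enhanced_rewards,
#   )
#
#   async def main():
#       orchestrator = TradingOrchestrator()  # assumed constructor
#       integrate_enhanced_rewards(orchestrator, symbols=['ETH/USDT'])
#       await start_enhanced_rewards_for_orchestrator(orchestrator)
#
#       # Record a model prediction so it is tracked for MSE-based rewards
#       prediction_id = add_prediction_to_enhanced_rewards(
#           orchestrator, 'ETH/USDT', '1m',
#           predicted_price=3050.0, direction=1, confidence=0.8,
#           current_price=3000.0, model_name='example_cnn',
#       )
#
#   asyncio.run(main())
# ---------------------------------------------------------------------------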