gogo2/core/enhanced_reward_system_integration.py
"""
Enhanced Reward System Integration
This module provides a simple integration point for the new MSE-based reward system
with the existing trading orchestrator and training infrastructure.
Key Features:
- Easy integration with existing TradingOrchestrator
- Minimal changes required to existing code
- Backward compatibility maintained
- Enhanced performance monitoring
- Real-time training with MSE rewards
"""
import asyncio
import logging
from typing import Optional, Dict, Any
from datetime import datetime
from core.enhanced_reward_calculator import EnhancedRewardCalculator, TimeFrame
from core.timeframe_inference_coordinator import TimeframeInferenceCoordinator
from core.enhanced_rl_training_adapter import EnhancedRLTrainingAdapter
logger = logging.getLogger(__name__)
class EnhancedRewardSystemIntegration:
    """
    Complete integration of the enhanced reward system

    This class provides a single integration point that can be easily added
    to the existing TradingOrchestrator to enable MSE-based rewards and
    multi-timeframe training.
    """

    def __init__(self, orchestrator: Any, symbols: Optional[list] = None):
        """
        Initialize the enhanced reward system integration

        Args:
            orchestrator: TradingOrchestrator instance
            symbols: List of symbols to track (defaults to ETH/USDT, BTC/USDT)
        """
        self.orchestrator = orchestrator
        self.symbols = symbols or ['ETH/USDT', 'BTC/USDT']

        # Initialize core components
        self.reward_calculator = EnhancedRewardCalculator(symbols=self.symbols)
        self.inference_coordinator = TimeframeInferenceCoordinator(
            reward_calculator=self.reward_calculator,
            data_provider=getattr(orchestrator, 'data_provider', None),
            symbols=self.symbols
        )
        self.training_adapter = EnhancedRLTrainingAdapter(
            reward_calculator=self.reward_calculator,
            inference_coordinator=self.inference_coordinator,
            orchestrator=orchestrator,
            training_system=getattr(orchestrator, 'enhanced_training_system', None)
        )

        # Integration state
        self.is_running = False
        self.integration_stats = {
            'start_time': None,
            'total_predictions_tracked': 0,
            'total_rewards_calculated': 0,
            'total_training_batches': 0
        }

        logger.info(f"EnhancedRewardSystemIntegration initialized for symbols: {self.symbols}")
    async def start_integration(self):
        """Start the enhanced reward system integration"""
        if self.is_running:
            logger.warning("Enhanced reward system already running")
            return

        try:
            logger.info("Starting Enhanced Reward System Integration")

            # Start core components
            await self.inference_coordinator.start_coordination()
            await self.training_adapter.start_training_loop()

            # Mark as running before launching the monitoring task so its
            # while-loop condition is already true when the task first executes
            self.is_running = True

            # Start price monitoring
            asyncio.create_task(self._price_monitoring_loop())

            self.integration_stats['start_time'] = datetime.now().isoformat()
            logger.info("Enhanced Reward System Integration started successfully")

        except Exception as e:
            logger.error(f"Error starting enhanced reward system integration: {e}")
            await self.stop_integration()

    async def stop_integration(self):
        """Stop the enhanced reward system integration"""
        if not self.is_running:
            return

        try:
            logger.info("Stopping Enhanced Reward System Integration")

            # Stop components
            await self.inference_coordinator.stop_coordination()
            await self.training_adapter.stop_training_loop()

            self.is_running = False
            logger.info("Enhanced Reward System Integration stopped")

        except Exception as e:
            logger.error(f"Error stopping enhanced reward system integration: {e}")
    async def _price_monitoring_loop(self):
        """Monitor prices and update the reward calculator"""
        while self.is_running:
            try:
                # Update current prices for all symbols
                for symbol in self.symbols:
                    current_price = await self._get_current_price(symbol)
                    if current_price > 0:
                        self.reward_calculator.update_price(symbol, current_price)

                # Sleep for 1 second between updates
                await asyncio.sleep(1.0)

            except Exception as e:
                logger.debug(f"Error in price monitoring loop: {e}")
                await asyncio.sleep(5.0)  # Wait longer on error

    async def _get_current_price(self, symbol: str) -> float:
        """Get the current price for a symbol, or 0.0 if unavailable"""
        try:
            data_provider = getattr(self.orchestrator, 'data_provider', None)
            if data_provider is not None:
                price = data_provider.get_current_price(symbol)
                return float(price) if price is not None else 0.0
        except Exception as e:
            logger.debug(f"Error getting current price for {symbol}: {e}")
        return 0.0
    def add_prediction_manually(self, symbol: str, timeframe_str: str,
                                predicted_price: float, predicted_direction: int,
                                confidence: float, current_price: float,
                                model_name: str) -> str:
        """
        Manually add a prediction to the reward calculator

        This method allows existing code to easily integrate with the new reward system
        without major changes.

        Args:
            symbol: Trading symbol (e.g., 'ETH/USDT')
            timeframe_str: Timeframe string ('1s', '1m', '1h', '1d')
            predicted_price: Model's predicted price
            predicted_direction: Predicted direction (-1, 0, 1)
            confidence: Model's confidence (0.0 to 1.0)
            current_price: Current market price
            model_name: Name of the model making the prediction

        Returns:
            Unique prediction ID (empty string on error)
        """
        try:
            # Convert timeframe string to enum
            timeframe = TimeFrame(timeframe_str)

            prediction_id = self.reward_calculator.add_prediction(
                symbol=symbol,
                timeframe=timeframe,
                predicted_price=predicted_price,
                predicted_direction=predicted_direction,
                confidence=confidence,
                current_price=current_price,
                model_name=model_name
            )

            self.integration_stats['total_predictions_tracked'] += 1
            return prediction_id

        except Exception as e:
            logger.error(f"Error adding prediction manually: {e}")
            return ""
    def get_model_accuracy(self, model_name: Optional[str] = None, symbol: Optional[str] = None) -> Dict[str, Any]:
        """
        Get accuracy statistics for models

        Args:
            model_name: Specific model name (None for all)
            symbol: Specific symbol (None for all)

        Returns:
            Dictionary with accuracy statistics
        """
        try:
            accuracy_summary = self.reward_calculator.get_accuracy_summary(symbol)

            if model_name:
                # Filtering by model name would require the reward calculator
                # to track prediction history per model; not implemented yet.
                pass

            return accuracy_summary

        except Exception as e:
            logger.error(f"Error getting model accuracy: {e}")
            return {}
    def force_evaluation_and_training(self, symbol: Optional[str] = None, timeframe_str: Optional[str] = None):
        """
        Force immediate evaluation and training for debugging/testing

        Args:
            symbol: Specific symbol (None for all)
            timeframe_str: Specific timeframe (None for all)
        """
        try:
            symbols_to_process = [symbol] if symbol else self.symbols

            if timeframe_str:
                timeframe = TimeFrame(timeframe_str)
                for sym in symbols_to_process:
                    # Force evaluation of predictions for the given timeframe
                    results = self.reward_calculator.force_evaluate_timeframe_predictions(sym, timeframe)
                    logger.info(f"Force evaluated {len(results)} predictions for {sym} {timeframe.value}")
            else:
                # Evaluate all pending predictions
                for sym in symbols_to_process:
                    results = self.reward_calculator.evaluate_predictions(sym)
                    if sym in results:
                        logger.info(f"Force evaluated {len(results[sym])} predictions for {sym}")

        except Exception as e:
            logger.error(f"Error in force evaluation and training: {e}")
    def get_integration_statistics(self) -> Dict[str, Any]:
        """Get comprehensive integration statistics"""
        try:
            stats = self.integration_stats.copy()

            # Add component statistics
            stats['inference_coordinator'] = self.inference_coordinator.get_inference_statistics()
            stats['training_adapter'] = self.training_adapter.get_training_statistics()
            stats['reward_calculator'] = self.reward_calculator.get_accuracy_summary()

            # Add system status
            stats['is_running'] = self.is_running
            stats['components_running'] = {
                'inference_coordinator': self.inference_coordinator.running,
                'training_adapter': self.training_adapter.running
            }

            return stats

        except Exception as e:
            logger.error(f"Error getting integration statistics: {e}")
            return {'error': str(e)}

    def cleanup_old_data(self, days_to_keep: int = 7):
        """Clean up old prediction data to manage memory"""
        try:
            self.reward_calculator.cleanup_old_predictions(days_to_keep)
            logger.info(f"Cleaned up prediction data older than {days_to_keep} days")
        except Exception as e:
            logger.error(f"Error cleaning up old data: {e}")
# Utility functions for easy integration

def integrate_enhanced_rewards(orchestrator: Any, symbols: Optional[list] = None) -> EnhancedRewardSystemIntegration:
    """
    Utility function to easily integrate enhanced rewards with an existing orchestrator

    Args:
        orchestrator: TradingOrchestrator instance
        symbols: List of symbols to track

    Returns:
        EnhancedRewardSystemIntegration instance
    """
    integration = EnhancedRewardSystemIntegration(orchestrator, symbols)

    # Add the integration as an attribute on the orchestrator for easy access
    setattr(orchestrator, 'enhanced_reward_system', integration)

    logger.info("Enhanced reward system integrated with orchestrator")
    return integration


async def start_enhanced_rewards_for_orchestrator(orchestrator: Any, symbols: Optional[list] = None):
    """
    Start enhanced rewards for an existing orchestrator

    Args:
        orchestrator: TradingOrchestrator instance
        symbols: List of symbols to track
    """
    if not hasattr(orchestrator, 'enhanced_reward_system'):
        integrate_enhanced_rewards(orchestrator, symbols)

    await orchestrator.enhanced_reward_system.start_integration()
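
# Typical usage inside the orchestrator's async startup path (illustrative;
# assumes an existing TradingOrchestrator instance named `orchestrator`):
#     await start_enhanced_rewards_for_orchestrator(orchestrator, ['ETH/USDT', 'BTC/USDT'])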


def add_prediction_to_enhanced_rewards(orchestrator: Any, symbol: str, timeframe: str,
                                       predicted_price: float, direction: int, confidence: float,
                                       current_price: float, model_name: str) -> str:
    """
    Helper function to add predictions to enhanced rewards from existing code

    Args:
        orchestrator: TradingOrchestrator instance with enhanced_reward_system
        symbol: Trading symbol
        timeframe: Timeframe string
        predicted_price: Predicted price
        direction: Predicted direction (-1, 0, 1)
        confidence: Model confidence
        current_price: Current market price
        model_name: Model name

    Returns:
        Prediction ID (empty string if the enhanced reward system is not integrated)
    """
    if hasattr(orchestrator, 'enhanced_reward_system'):
        return orchestrator.enhanced_reward_system.add_prediction_manually(
            symbol, timeframe, predicted_price, direction, confidence, current_price, model_name
        )

    logger.warning("Enhanced reward system not integrated with orchestrator")
    return ""