"""
|
|
Enhanced Reward System Integration
|
|
|
|
This module provides a simple integration point for the new MSE-based reward system
|
|
with the existing trading orchestrator and training infrastructure.
|
|
|
|
Key Features:
|
|
- Easy integration with existing TradingOrchestrator
|
|
- Minimal changes required to existing code
|
|
- Backward compatibility maintained
|
|
- Enhanced performance monitoring
|
|
- Real-time training with MSE rewards
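
Example usage (illustrative sketch; assumes an existing TradingOrchestrator
instance is available as `orchestrator` and that the calls run in an async context):

    integration = integrate_enhanced_rewards(orchestrator, symbols=['ETH/USDT', 'BTC/USDT'])
    await start_enhanced_rewards_for_orchestrator(orchestrator)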
"""

import asyncio
import logging
from typing import Any, Dict, List, Optional
from datetime import datetime

from core.enhanced_reward_calculator import EnhancedRewardCalculator, TimeFrame
from core.timeframe_inference_coordinator import TimeframeInferenceCoordinator
from core.enhanced_rl_training_adapter import EnhancedRLTrainingAdapter
from core.unified_training_manager import UnifiedTrainingManager

logger = logging.getLogger(__name__)


class EnhancedRewardSystemIntegration:
    """
    Complete integration of the enhanced reward system

    This class provides a single integration point that can be easily added
    to the existing TradingOrchestrator to enable MSE-based rewards and
    multi-timeframe training.
    """

    def __init__(self, orchestrator: Any, symbols: Optional[List[str]] = None):
        """
        Initialize the enhanced reward system integration

        Args:
            orchestrator: TradingOrchestrator instance
            symbols: List of symbols to track (defaults to ETH/USDT, BTC/USDT)
        """
        self.orchestrator = orchestrator
        self.symbols = symbols or ['ETH/USDT', 'BTC/USDT']

        # Initialize core components
        self.reward_calculator = EnhancedRewardCalculator(symbols=self.symbols)

        self.inference_coordinator = TimeframeInferenceCoordinator(
            reward_calculator=self.reward_calculator,
            data_provider=getattr(orchestrator, 'data_provider', None),
            symbols=self.symbols
        )

        self.training_adapter = EnhancedRLTrainingAdapter(
            reward_calculator=self.reward_calculator,
            inference_coordinator=self.inference_coordinator,
            orchestrator=orchestrator,
            training_system=getattr(orchestrator, 'enhanced_training_system', None)
        )

        # Unified Training Manager (always available)
        self.unified_training = UnifiedTrainingManager(
            orchestrator=orchestrator,
            reward_system=self,
        )

        # Integration state
        self.is_running = False
        self.integration_stats = {
            'start_time': None,
            'total_predictions_tracked': 0,
            'total_rewards_calculated': 0,
            'total_training_batches': 0
        }

        logger.info(f"EnhancedRewardSystemIntegration initialized for symbols: {self.symbols}")

    async def start_integration(self):
        """Start the enhanced reward system integration"""
        if self.is_running:
            logger.warning("Enhanced reward system already running")
            return

        try:
            logger.info("Starting Enhanced Reward System Integration")

            # Start core components
            await self.inference_coordinator.start_coordination()
            await self.training_adapter.start_training_loop()
            await self.unified_training.start()

            # Start price monitoring
            asyncio.create_task(self._price_monitoring_loop())

            self.is_running = True
            self.integration_stats['start_time'] = datetime.now().isoformat()

            logger.info("Enhanced Reward System Integration started successfully")

        except Exception as e:
            logger.error(f"Error starting enhanced reward system integration: {e}")
            await self.stop_integration()

    async def stop_integration(self):
        """Stop the enhanced reward system integration"""
        if not self.is_running:
            return

        try:
            logger.info("Stopping Enhanced Reward System Integration")

            # Stop components
            await self.inference_coordinator.stop_coordination()
            await self.training_adapter.stop_training_loop()
            await self.unified_training.stop()

            self.is_running = False

            logger.info("Enhanced Reward System Integration stopped")

        except Exception as e:
            logger.error(f"Error stopping enhanced reward system integration: {e}")

    async def _price_monitoring_loop(self):
        """Monitor prices and update the reward calculator"""
        while self.is_running:
            try:
                # Update current prices for all symbols
                for symbol in self.symbols:
                    current_price = await self._get_current_price(symbol)
                    if current_price > 0:
                        self.reward_calculator.update_price(symbol, current_price)

                # Sleep for 1 second between updates
                await asyncio.sleep(1.0)

            except Exception as e:
                logger.debug(f"Error in price monitoring loop: {e}")
                await asyncio.sleep(5.0)  # Wait longer on error

    async def _get_current_price(self, symbol: str) -> float:
        """Get current price for a symbol"""
        try:
            data_provider = getattr(self.orchestrator, 'data_provider', None)
            if data_provider is not None:
                price = data_provider.get_current_price(symbol)
                return float(price) if price is not None else 0.0
        except Exception as e:
            logger.debug(f"Error getting current price for {symbol}: {e}")

        return 0.0

    def add_prediction_manually(self, symbol: str, timeframe_str: str,
                                predicted_price: float, predicted_direction: int,
                                confidence: float, current_price: float,
                                model_name: str) -> str:
        """
        Manually add a prediction to the reward calculator

        This method allows existing code to easily integrate with the new reward system
        without major changes.

        Args:
            symbol: Trading symbol (e.g., 'ETH/USDT')
            timeframe_str: Timeframe string ('1s', '1m', '1h', '1d')
            predicted_price: Model's predicted price
            predicted_direction: Predicted direction (-1, 0, 1)
            confidence: Model's confidence (0.0 to 1.0)
            current_price: Current market price
            model_name: Name of the model making the prediction

        Returns:
            Unique prediction ID
        """
        try:
            # Convert timeframe string to enum
            timeframe = TimeFrame(timeframe_str)

            prediction_id = self.reward_calculator.add_prediction(
                symbol=symbol,
                timeframe=timeframe,
                predicted_price=predicted_price,
                predicted_direction=predicted_direction,
                confidence=confidence,
                current_price=current_price,
                model_name=model_name
            )

            self.integration_stats['total_predictions_tracked'] += 1

            return prediction_id

        except Exception as e:
            logger.error(f"Error adding prediction manually: {e}")
            return ""

    def get_model_accuracy(self, model_name: Optional[str] = None, symbol: Optional[str] = None) -> Dict[str, Any]:
        """
        Get accuracy statistics for models

        Args:
            model_name: Specific model name (None for all)
            symbol: Specific symbol (None for all)

        Returns:
            Dictionary with accuracy statistics
        """
        try:
            accuracy_summary = self.reward_calculator.get_accuracy_summary(symbol)

            if model_name:
                # Filtering by model name in the prediction history would require
                # enhancing the reward calculator to track accuracy per model.
                pass

            return accuracy_summary

        except Exception as e:
            logger.error(f"Error getting model accuracy: {e}")
            return {}

    def force_evaluation_and_training(self, symbol: Optional[str] = None, timeframe_str: Optional[str] = None):
        """
        Force immediate evaluation and training for debugging/testing

        Args:
            symbol: Specific symbol (None for all)
            timeframe_str: Specific timeframe (None for all)
        """
        try:
            if timeframe_str:
                timeframe = TimeFrame(timeframe_str)
                symbols_to_process = [symbol] if symbol else self.symbols

                for sym in symbols_to_process:
                    # Force evaluation of predictions
                    results = self.reward_calculator.force_evaluate_timeframe_predictions(sym, timeframe)
                    logger.info(f"Force evaluated {len(results)} predictions for {sym} {timeframe.value}")
            else:
                # Evaluate all pending predictions
                for sym in ([symbol] if symbol else self.symbols):
                    results = self.reward_calculator.evaluate_predictions(sym)
                    if sym in results:
                        logger.info(f"Force evaluated {len(results[sym])} predictions for {sym}")

        except Exception as e:
            logger.error(f"Error in force evaluation and training: {e}")

    def get_integration_statistics(self) -> Dict[str, Any]:
        """Get comprehensive integration statistics"""
        try:
            stats = self.integration_stats.copy()

            # Add component statistics
            stats['inference_coordinator'] = self.inference_coordinator.get_inference_statistics()
            stats['training_adapter'] = self.training_adapter.get_training_statistics()
            stats['reward_calculator'] = self.reward_calculator.get_accuracy_summary()

            # Add system status
            stats['is_running'] = self.is_running
            stats['components_running'] = {
                'inference_coordinator': self.inference_coordinator.running,
                'training_adapter': self.training_adapter.running
            }

            return stats

        except Exception as e:
            logger.error(f"Error getting integration statistics: {e}")
            return {'error': str(e)}
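
    # Example (illustrative): monitoring or dashboard code might read the returned
    # statistics like this. `integration` is a hypothetical EnhancedRewardSystemIntegration
    # instance; the keys shown are the ones populated by this class.
    #
    #     stats = integration.get_integration_statistics()
    #     if stats.get('is_running'):
    #         print(f"Predictions tracked: {stats.get('total_predictions_tracked', 0)}")
    #         print(f"Training batches:    {stats.get('total_training_batches', 0)}")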

    def cleanup_old_data(self, days_to_keep: int = 7):
        """Clean up old prediction data to manage memory"""
        try:
            self.reward_calculator.cleanup_old_predictions(days_to_keep)
            logger.info(f"Cleaned up prediction data older than {days_to_keep} days")

        except Exception as e:
            logger.error(f"Error cleaning up old data: {e}")


# Utility functions for easy integration

def integrate_enhanced_rewards(orchestrator: Any, symbols: Optional[List[str]] = None) -> EnhancedRewardSystemIntegration:
    """
    Utility function to easily integrate enhanced rewards with an existing orchestrator

    Args:
        orchestrator: TradingOrchestrator instance
        symbols: List of symbols to track

    Returns:
        EnhancedRewardSystemIntegration instance
    """
    integration = EnhancedRewardSystemIntegration(orchestrator, symbols)

    # Add integration as an attribute to the orchestrator for easy access
    setattr(orchestrator, 'enhanced_reward_system', integration)

    logger.info("Enhanced reward system integrated with orchestrator")
    return integration


async def start_enhanced_rewards_for_orchestrator(orchestrator: Any, symbols: Optional[List[str]] = None):
    """
    Start enhanced rewards for an existing orchestrator

    Args:
        orchestrator: TradingOrchestrator instance
        symbols: List of symbols to track
    """
    if not hasattr(orchestrator, 'enhanced_reward_system'):
        integrate_enhanced_rewards(orchestrator, symbols)

    await orchestrator.enhanced_reward_system.start_integration()


def add_prediction_to_enhanced_rewards(orchestrator: Any, symbol: str, timeframe: str,
                                       predicted_price: float, direction: int, confidence: float,
                                       current_price: float, model_name: str) -> str:
    """
    Helper function to add predictions to enhanced rewards from existing code

    Args:
        orchestrator: TradingOrchestrator instance with enhanced_reward_system
        symbol: Trading symbol
        timeframe: Timeframe string
        predicted_price: Predicted price
        direction: Predicted direction (-1, 0, 1)
        confidence: Model confidence
        current_price: Current market price
        model_name: Model name

    Returns:
        Prediction ID
    """
    if hasattr(orchestrator, 'enhanced_reward_system'):
        return orchestrator.enhanced_reward_system.add_prediction_manually(
            symbol, timeframe, predicted_price, direction, confidence, current_price, model_name
        )

    logger.warning("Enhanced reward system not integrated with orchestrator")
    return ""