enhanced training and reward - wip
337
core/enhanced_reward_system_integration.py
Normal file
@@ -0,0 +1,337 @@
"""
Enhanced Reward System Integration

This module provides a simple integration point for the new MSE-based reward system
with the existing trading orchestrator and training infrastructure.

Key Features:
- Easy integration with the existing TradingOrchestrator
- Minimal changes required to existing code
- Backward compatibility maintained
- Enhanced performance monitoring
- Real-time training with MSE rewards
"""

import asyncio
import logging
from typing import Optional, Dict, Any
from datetime import datetime

from core.enhanced_reward_calculator import EnhancedRewardCalculator, TimeFrame
from core.timeframe_inference_coordinator import TimeframeInferenceCoordinator
from core.enhanced_rl_training_adapter import EnhancedRLTrainingAdapter

logger = logging.getLogger(__name__)


class EnhancedRewardSystemIntegration:
    """
    Complete integration of the enhanced reward system

    This class provides a single integration point that can be easily added
    to the existing TradingOrchestrator to enable MSE-based rewards and
    multi-timeframe training.
    """

    def __init__(self, orchestrator: Any, symbols: list = None):
        """
        Initialize the enhanced reward system integration

        Args:
            orchestrator: TradingOrchestrator instance
            symbols: List of symbols to track (defaults to ETH/USDT, BTC/USDT)
        """
        self.orchestrator = orchestrator
        self.symbols = symbols or ['ETH/USDT', 'BTC/USDT']

        # Initialize core components
        self.reward_calculator = EnhancedRewardCalculator(symbols=self.symbols)

        self.inference_coordinator = TimeframeInferenceCoordinator(
            reward_calculator=self.reward_calculator,
            data_provider=getattr(orchestrator, 'data_provider', None),
            symbols=self.symbols
        )

        self.training_adapter = EnhancedRLTrainingAdapter(
            reward_calculator=self.reward_calculator,
            inference_coordinator=self.inference_coordinator,
            orchestrator=orchestrator,
            training_system=getattr(orchestrator, 'enhanced_training_system', None)
        )

        # Integration state
        self.is_running = False
        self.integration_stats = {
            'start_time': None,
            'total_predictions_tracked': 0,
            'total_rewards_calculated': 0,
            'total_training_batches': 0
        }

        logger.info(f"EnhancedRewardSystemIntegration initialized for symbols: {self.symbols}")

    async def start_integration(self):
        """Start the enhanced reward system integration"""
        if self.is_running:
            logger.warning("Enhanced reward system already running")
            return

        try:
            logger.info("Starting Enhanced Reward System Integration")

            # Start core components
            await self.inference_coordinator.start_coordination()
            await self.training_adapter.start_training_loop()

            # Start price monitoring
            asyncio.create_task(self._price_monitoring_loop())

            self.is_running = True
            self.integration_stats['start_time'] = datetime.now().isoformat()

            logger.info("Enhanced Reward System Integration started successfully")

        except Exception as e:
            logger.error(f"Error starting enhanced reward system integration: {e}")
            await self.stop_integration()

    async def stop_integration(self):
        """Stop the enhanced reward system integration"""
        if not self.is_running:
            return

        try:
            logger.info("Stopping Enhanced Reward System Integration")

            # Stop components
            await self.inference_coordinator.stop_coordination()
            await self.training_adapter.stop_training_loop()

            self.is_running = False

            logger.info("Enhanced Reward System Integration stopped")

        except Exception as e:
            logger.error(f"Error stopping enhanced reward system integration: {e}")

    async def _price_monitoring_loop(self):
        """Monitor prices and update the reward calculator"""
        while self.is_running:
            try:
                # Update current prices for all symbols
                for symbol in self.symbols:
                    current_price = await self._get_current_price(symbol)
                    if current_price > 0:
                        self.reward_calculator.update_price(symbol, current_price)

                # Sleep for 1 second between updates
                await asyncio.sleep(1.0)

            except Exception as e:
                logger.debug(f"Error in price monitoring loop: {e}")
                await asyncio.sleep(5.0)  # Wait longer on error

    async def _get_current_price(self, symbol: str) -> float:
        """Get current price for a symbol"""
        try:
            if hasattr(self.orchestrator, 'data_provider'):
                current_prices = self.orchestrator.data_provider.current_prices
                return current_prices.get(symbol, 0.0)
        except Exception as e:
            logger.debug(f"Error getting current price for {symbol}: {e}")

        return 0.0

    def add_prediction_manually(self, symbol: str, timeframe_str: str,
                                predicted_price: float, predicted_direction: int,
                                confidence: float, current_price: float,
                                model_name: str) -> str:
        """
        Manually add a prediction to the reward calculator

        This method allows existing code to integrate with the new reward system
        without major changes.

        Args:
            symbol: Trading symbol (e.g., 'ETH/USDT')
            timeframe_str: Timeframe string ('1s', '1m', '1h', '1d')
            predicted_price: Model's predicted price
            predicted_direction: Predicted direction (-1, 0, 1)
            confidence: Model's confidence (0.0 to 1.0)
            current_price: Current market price
            model_name: Name of the model making the prediction

        Returns:
            Unique prediction ID
        """
        try:
            # Convert timeframe string to enum
            timeframe = TimeFrame(timeframe_str)

            prediction_id = self.reward_calculator.add_prediction(
                symbol=symbol,
                timeframe=timeframe,
                predicted_price=predicted_price,
                predicted_direction=predicted_direction,
                confidence=confidence,
                current_price=current_price,
                model_name=model_name
            )

            self.integration_stats['total_predictions_tracked'] += 1

            return prediction_id

        except Exception as e:
            logger.error(f"Error adding prediction manually: {e}")
            return ""

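    # Usage sketch for add_prediction_manually (hedged): a call site in existing
    # inference code might record a prediction like the following. The
    # `integration` instance, the literal values, and the names `cnn_price`,
    # `last_price`, and 'enhanced_cnn' are illustrative assumptions, not
    # identifiers defined in this module:
    #
    #   pred_id = integration.add_prediction_manually(
    #       symbol='ETH/USDT', timeframe_str='1m',
    #       predicted_price=cnn_price, predicted_direction=1,
    #       confidence=0.72, current_price=last_price,
    #       model_name='enhanced_cnn')
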
    def get_model_accuracy(self, model_name: str = None, symbol: str = None) -> Dict[str, Any]:
        """
        Get accuracy statistics for models

        Args:
            model_name: Specific model name (None for all)
            symbol: Specific symbol (None for all)

        Returns:
            Dictionary with accuracy statistics
        """
        try:
            accuracy_summary = self.reward_calculator.get_accuracy_summary(symbol)

            if model_name:
                # Filter by model name in prediction history
                # This would require enhancing the reward calculator to track by model
                pass

            return accuracy_summary

        except Exception as e:
            logger.error(f"Error getting model accuracy: {e}")
            return {}

    def force_evaluation_and_training(self, symbol: str = None, timeframe_str: str = None):
        """
        Force immediate evaluation and training for debugging/testing

        Args:
            symbol: Specific symbol (None for all)
            timeframe_str: Specific timeframe (None for all)
        """
        try:
            if timeframe_str:
                timeframe = TimeFrame(timeframe_str)
                symbols_to_process = [symbol] if symbol else self.symbols

                for sym in symbols_to_process:
                    # Force evaluation of predictions
                    results = self.reward_calculator.force_evaluate_timeframe_predictions(sym, timeframe)
                    logger.info(f"Force evaluated {len(results)} predictions for {sym} {timeframe.value}")
            else:
                # Evaluate all pending predictions
                for sym in (self.symbols if not symbol else [symbol]):
                    results = self.reward_calculator.evaluate_predictions(sym)
                    if sym in results:
                        logger.info(f"Force evaluated {len(results[sym])} predictions for {sym}")

        except Exception as e:
            logger.error(f"Error in force evaluation and training: {e}")

    def get_integration_statistics(self) -> Dict[str, Any]:
        """Get comprehensive integration statistics"""
        try:
            stats = self.integration_stats.copy()

            # Add component statistics
            stats['inference_coordinator'] = self.inference_coordinator.get_inference_statistics()
            stats['training_adapter'] = self.training_adapter.get_training_statistics()
            stats['reward_calculator'] = self.reward_calculator.get_accuracy_summary()

            # Add system status
            stats['is_running'] = self.is_running
            stats['components_running'] = {
                'inference_coordinator': self.inference_coordinator.running,
                'training_adapter': self.training_adapter.running
            }

            return stats

        except Exception as e:
            logger.error(f"Error getting integration statistics: {e}")
            return {'error': str(e)}

    def cleanup_old_data(self, days_to_keep: int = 7):
        """Clean up old prediction data to manage memory"""
        try:
            self.reward_calculator.cleanup_old_predictions(days_to_keep)
            logger.info(f"Cleaned up prediction data older than {days_to_keep} days")

        except Exception as e:
            logger.error(f"Error cleaning up old data: {e}")


# Utility functions for easy integration

def integrate_enhanced_rewards(orchestrator: Any, symbols: list = None) -> EnhancedRewardSystemIntegration:
    """
    Utility function to easily integrate enhanced rewards with an existing orchestrator

    Args:
        orchestrator: TradingOrchestrator instance
        symbols: List of symbols to track

    Returns:
        EnhancedRewardSystemIntegration instance
    """
    integration = EnhancedRewardSystemIntegration(orchestrator, symbols)

    # Add integration as an attribute to the orchestrator for easy access
    setattr(orchestrator, 'enhanced_reward_system', integration)

    logger.info("Enhanced reward system integrated with orchestrator")
    return integration


async def start_enhanced_rewards_for_orchestrator(orchestrator: Any, symbols: list = None):
    """
    Start enhanced rewards for an existing orchestrator

    Args:
        orchestrator: TradingOrchestrator instance
        symbols: List of symbols to track
    """
    if not hasattr(orchestrator, 'enhanced_reward_system'):
        integrate_enhanced_rewards(orchestrator, symbols)

    await orchestrator.enhanced_reward_system.start_integration()


def add_prediction_to_enhanced_rewards(orchestrator: Any, symbol: str, timeframe: str,
                                       predicted_price: float, direction: int, confidence: float,
                                       current_price: float, model_name: str) -> str:
    """
    Helper function to add predictions to enhanced rewards from existing code

    Args:
        orchestrator: TradingOrchestrator instance with enhanced_reward_system
        symbol: Trading symbol
        timeframe: Timeframe string
        predicted_price: Predicted price
        direction: Predicted direction (-1, 0, 1)
        confidence: Model confidence
        current_price: Current market price
        model_name: Model name

    Returns:
        Prediction ID
    """
    if hasattr(orchestrator, 'enhanced_reward_system'):
        return orchestrator.enhanced_reward_system.add_prediction_manually(
            symbol, timeframe, predicted_price, direction, confidence, current_price, model_name
        )

    logger.warning("Enhanced reward system not integrated with orchestrator")
    return ""
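

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, never called by this module). It assumes an
# object compatible with TradingOrchestrator; the symbol, prices, confidence,
# and model name below are made-up placeholder values, not real configuration.
# ---------------------------------------------------------------------------

async def _example_enhanced_rewards_usage(orchestrator: Any) -> None:
    """Hedged example of wiring the enhanced reward system into an orchestrator."""
    # Attach the integration and start its inference, training, and price loops
    await start_enhanced_rewards_for_orchestrator(orchestrator, ['ETH/USDT'])

    # At each model inference, record the prediction so an MSE-based reward
    # can be calculated once the corresponding timeframe elapses
    prediction_id = add_prediction_to_enhanced_rewards(
        orchestrator,
        symbol='ETH/USDT',
        timeframe='1m',
        predicted_price=3010.5,
        direction=1,
        confidence=0.64,
        current_price=3002.1,
        model_name='example_model',
    )
    logger.debug(f"Recorded example prediction {prediction_id}")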