""" Enhanced Training Integration Module This module provides comprehensive integration between the training data collection system, CNN training pipeline, RL training pipeline, and your existing infrastructure. Key Features: - Real-time integration with existing DataProvider - Coordinated training across CNN and RL models - Automatic outcome validation and profitability tracking - Integration with existing COB RL model - Performance monitoring and optimization - Seamless connection to existing orchestrator and trading executor """ import asyncio import logging import numpy as np import pandas as pd import torch from datetime import datetime, timedelta from typing import Dict, List, Optional, Tuple, Any, Callable from dataclasses import dataclass import threading import time from pathlib import Path # Import existing components from .data_provider import DataProvider from .orchestrator import Orchestrator from .trading_executor import TradingExecutor # Import our training system components from .training_data_collector import ( TrainingDataCollector, get_training_data_collector ) from .cnn_training_pipeline import ( CNNPivotPredictor, CNNTrainer, get_cnn_trainer ) from .rl_training_pipeline import ( RLTradingAgent, RLTrainer, get_rl_trainer ) from .training_integration import TrainingIntegration # Import existing RL model try: from NN.models.cob_rl_model import COBRLModelInterface except ImportError: logger.warning("Could not import COBRLModelInterface - using fallback") COBRLModelInterface = None logger = logging.getLogger(__name__) @dataclass class EnhancedTrainingConfig: """Enhanced configuration for comprehensive training integration""" # Data collection collection_interval: float = 1.0 min_data_completeness: float = 0.8 # Training triggers min_episodes_for_cnn_training: int = 100 min_experiences_for_rl_training: int = 200 training_frequency_minutes: int = 30 # Profitability thresholds min_profitability_for_replay: float = 0.1 high_profitability_threshold: float = 0.5 # Model integration use_existing_cob_rl_model: bool = True enable_cross_model_learning: bool = True # Performance optimization max_concurrent_training_sessions: int = 2 enable_background_validation: bool = True class EnhancedTrainingIntegration: """Enhanced training integration with existing infrastructure""" def __init__(self, data_provider: DataProvider, orchestrator: Orchestrator = None, trading_executor: TradingExecutor = None, config: EnhancedTrainingConfig = None): self.data_provider = data_provider self.orchestrator = orchestrator self.trading_executor = trading_executor self.config = config or EnhancedTrainingConfig() # Initialize training components self.data_collector = get_training_data_collector() # Initialize CNN components self.cnn_model = CNNPivotPredictor() self.cnn_trainer = get_cnn_trainer(self.cnn_model) # Initialize RL components if self.config.use_existing_cob_rl_model and COBRLModelInterface: self.existing_rl_model = COBRLModelInterface() logger.info("Using existing COB RL model") else: self.existing_rl_model = None self.rl_agent = RLTradingAgent() self.rl_trainer = get_rl_trainer(self.rl_agent) # Integration state self.is_running = False self.training_threads = {} self.validation_thread = None # Performance tracking self.integration_stats = { 'total_data_packages': 0, 'cnn_training_sessions': 0, 'rl_training_sessions': 0, 'profitable_predictions': 0, 'total_predictions': 0, 'cross_model_improvements': 0, 'last_update': datetime.now() } # Model prediction tracking self.recent_predictions = {} self.prediction_outcomes = {} # Cross-model learning self.model_performance_history = { 'cnn': [], 'rl': [], 'orchestrator': [] } logger.info("Enhanced Training Integration initialized") logger.info(f"CNN model parameters: {sum(p.numel() for p in self.cnn_model.parameters()):,}") logger.info(f"RL agent parameters: {sum(p.numel() for p in self.rl_agent.parameters()):,}") logger.info(f"Using existing COB RL model: {self.existing_rl_model is not None}") def start_enhanced_integration(self): """Start the enhanced training integration system""" if self.is_running: logger.warning("Enhanced training integration already running") return self.is_running = True # Start data collection self.data_collector.start_collection() # Start CNN training if self.config.min_episodes_for_cnn_training > 0: for symbol in self.data_provider.symbols: self.cnn_trainer.start_real_time_training(symbol) # Start coordinated training thread self.training_threads['coordinator'] = threading.Thread( target=self._training_coordinator_worker, daemon=True ) self.training_threads['coordinator'].start() # Start data collection and validation self.training_threads['data_collector'] = threading.Thread( target=self._enhanced_data_collection_worker, daemon=True ) self.training_threads['data_collector'].start() # Start outcome validation if enabled if self.config.enable_background_validation: self.validation_thread = threading.Thread( target=self._outcome_validation_worker, daemon=True ) self.validation_thread.start() logger.info("Enhanced training integration started") def stop_enhanced_integration(self): """Stop the enhanced training integration system""" self.is_running = False # Stop data collection self.data_collector.stop_collection() # Stop CNN training self.cnn_trainer.stop_training() # Wait for threads to finish for thread_name, thread in self.training_threads.items(): thread.join(timeout=10) logger.info(f"Stopped {thread_name} thread") if self.validation_thread: self.validation_thread.join(timeout=5) logger.info("Enhanced training integration stopped") def _enhanced_data_collection_worker(self): """Enhanced data collection with real-time model integration""" logger.info("Enhanced data collection worker started") while self.is_running: try: for symbol in self.data_provider.symbols: self._collect_enhanced_training_data(symbol) time.sleep(self.config.collection_interval) except Exception as e: logger.error(f"Error in enhanced data collection: {e}") time.sleep(5) logger.info("Enhanced data collection worker stopped") def _collect_enhanced_training_data(self, symbol: str): """Collect enhanced training data with model predictions""" try: # Get comprehensive market data market_data = self._get_comprehensive_market_data(symbol) if not market_data or not self._validate_market_data(market_data): return # Get current model predictions model_predictions = self._get_all_model_predictions(symbol, market_data) # Create enhanced features cnn_features = self._create_enhanced_cnn_features(symbol, market_data) rl_state = self._create_enhanced_rl_state(symbol, market_data, model_predictions) # Collect training data with predictions episode_id = self.data_collector.collect_training_data( symbol=symbol, ohlcv_data=market_data['ohlcv'], tick_data=market_data['ticks'], cob_data=market_data['cob'], technical_indicators=market_data['indicators'], pivot_points=market_data['pivots'], cnn_features=cnn_features, rl_state=rl_state, orchestrator_context=market_data['context'], model_predictions=model_predictions ) if episode_id: # Store predictions for outcome validation self.recent_predictions[episode_id] = { 'timestamp': datetime.now(), 'symbol': symbol, 'predictions': model_predictions, 'market_data': market_data } # Add RL experience if we have action if 'rl_action' in model_predictions: self._add_rl_experience(symbol, market_data, model_predictions, episode_id) self.integration_stats['total_data_packages'] += 1 except Exception as e: logger.error(f"Error collecting enhanced training data for {symbol}: {e}") def _get_comprehensive_market_data(self, symbol: str) -> Dict[str, Any]: """Get comprehensive market data from all sources""" try: market_data = {} # OHLCV data ohlcv_data = {} for timeframe in ['1s', '1m', '5m', '15m', '1h', '1d']: df = self.data_provider.get_historical_data(symbol, timeframe, limit=300, refresh=True) if df is not None and not df.empty: ohlcv_data[timeframe] = df market_data['ohlcv'] = ohlcv_data # Tick data market_data['ticks'] = self._get_recent_tick_data(symbol) # COB data market_data['cob'] = self._get_cob_data(symbol) # Technical indicators market_data['indicators'] = self._get_technical_indicators(symbol) # Pivot points market_data['pivots'] = self._get_pivot_points(symbol) # Market context market_data['context'] = self._get_market_context(symbol) return market_data except Exception as e: logger.error(f"Error getting comprehensive market data: {e}") return {} def _get_all_model_predictions(self, symbol: str, market_data: Dict[str, Any]) -> Dict[str, Any]: """Get predictions from all available models""" predictions = {} try: # CNN predictions if self.cnn_model and market_data.get('ohlcv'): cnn_features = self._create_enhanced_cnn_features(symbol, market_data) if cnn_features is not None: cnn_input = torch.from_numpy(cnn_features).float().unsqueeze(0) # Reshape for CNN (add channel dimension) cnn_input = cnn_input.view(1, 10, -1) # Assuming 10 channels with torch.no_grad(): cnn_outputs = self.cnn_model(cnn_input) predictions['cnn'] = { 'pivot_logits': cnn_outputs['pivot_logits'].cpu().numpy(), 'pivot_price': cnn_outputs['pivot_price'].cpu().numpy(), 'confidence': cnn_outputs['confidence'].cpu().numpy(), 'timestamp': datetime.now() } # RL predictions if self.rl_agent and market_data.get('cob'): rl_state = self._create_enhanced_rl_state(symbol, market_data, predictions) if rl_state is not None: action, confidence = self.rl_agent.select_action(rl_state, epsilon=0.1) predictions['rl'] = { 'action': action, 'confidence': confidence, 'timestamp': datetime.now() } predictions['rl_action'] = action # Existing COB RL model predictions if self.existing_rl_model and market_data.get('cob'): cob_features = market_data['cob'].get('cob_features', []) if cob_features and len(cob_features) >= 2000: cob_array = np.array(cob_features[:2000], dtype=np.float32) cob_prediction = self.existing_rl_model.predict(cob_array) predictions['cob_rl'] = { 'predicted_direction': cob_prediction.get('predicted_direction', 1), 'confidence': cob_prediction.get('confidence', 0.5), 'value': cob_prediction.get('value', 0.0), 'timestamp': datetime.now() } # Orchestrator predictions (if available) if self.orchestrator: try: # This would integrate with your orchestrator's prediction method orchestrator_prediction = self._get_orchestrator_prediction(symbol, market_data, predictions) if orchestrator_prediction: predictions['orchestrator'] = orchestrator_prediction except Exception as e: logger.debug(f"Could not get orchestrator prediction: {e}") return predictions except Exception as e: logger.error(f"Error getting model predictions: {e}") return {} def _add_rl_experience(self, symbol: str, market_data: Dict[str, Any], predictions: Dict[str, Any], episode_id: str): """Add RL experience to the training buffer""" try: # Create RL state state = self._create_enhanced_rl_state(symbol, market_data, predictions) if state is None: return # Get action from predictions action = predictions.get('rl_action', 1) # Default to HOLD # Calculate immediate reward (placeholder - would be updated with actual outcome) reward = 0.0 # Create next state (same as current for now - would be updated) next_state = state.copy() # Market context market_context = { 'symbol': symbol, 'episode_id': episode_id, 'timestamp': datetime.now(), 'market_session': market_data['context'].get('market_session', 'unknown'), 'volatility_regime': market_data['context'].get('volatility_regime', 'unknown') } # Add experience experience_id = self.rl_trainer.add_experience( state=state, action=action, reward=reward, next_state=next_state, done=False, market_context=market_context, cnn_predictions=predictions.get('cnn'), confidence_score=predictions.get('rl', {}).get('confidence', 0.0) ) if experience_id: logger.debug(f"Added RL experience: {experience_id}") except Exception as e: logger.error(f"Error adding RL experience: {e}") def _training_coordinator_worker(self): """Coordinate training across all models""" logger.info("Training coordinator worker started") while self.is_running: try: # Check if we should trigger training for symbol in self.data_provider.symbols: self._check_and_trigger_training(symbol) # Wait before next check time.sleep(self.config.training_frequency_minutes * 60) except Exception as e: logger.error(f"Error in training coordinator: {e}") time.sleep(60) logger.info("Training coordinator worker stopped") def _check_and_trigger_training(self, symbol: str): """Check conditions and trigger training if needed""" try: # Get training episodes and experiences episodes = self.data_collector.get_high_priority_episodes(symbol, limit=1000) # Check CNN training conditions if len(episodes) >= self.config.min_episodes_for_cnn_training: profitable_episodes = [ep for ep in episodes if ep.actual_outcome.is_profitable] if len(profitable_episodes) >= 20: # Minimum profitable episodes logger.info(f"Triggering CNN training for {symbol} with {len(profitable_episodes)} profitable episodes") results = self.cnn_trainer.train_on_profitable_episodes( symbol=symbol, min_profitability=self.config.min_profitability_for_replay, max_episodes=len(profitable_episodes) ) if results.get('status') == 'success': self.integration_stats['cnn_training_sessions'] += 1 logger.info(f"CNN training completed for {symbol}") # Check RL training conditions buffer_stats = self.rl_trainer.experience_buffer.get_buffer_statistics() total_experiences = buffer_stats.get('total_experiences', 0) if total_experiences >= self.config.min_experiences_for_rl_training: profitable_experiences = buffer_stats.get('profitable_experiences', 0) if profitable_experiences >= 50: # Minimum profitable experiences logger.info(f"Triggering RL training with {profitable_experiences} profitable experiences") results = self.rl_trainer.train_on_profitable_experiences( min_profitability=self.config.min_profitability_for_replay, max_experiences=min(profitable_experiences, 500), batch_size=32 ) if results.get('status') == 'success': self.integration_stats['rl_training_sessions'] += 1 logger.info("RL training completed") except Exception as e: logger.error(f"Error checking training conditions for {symbol}: {e}") def _outcome_validation_worker(self): """Background worker for validating prediction outcomes""" logger.info("Outcome validation worker started") while self.is_running: try: self._validate_recent_predictions() time.sleep(300) # Check every 5 minutes except Exception as e: logger.error(f"Error in outcome validation: {e}") time.sleep(60) logger.info("Outcome validation worker stopped") def _validate_recent_predictions(self): """Validate recent predictions against actual outcomes""" try: current_time = datetime.now() validation_delay = timedelta(hours=1) # Wait 1 hour to validate validated_predictions = [] for episode_id, prediction_data in self.recent_predictions.items(): prediction_time = prediction_data['timestamp'] if current_time - prediction_time >= validation_delay: # Validate this prediction outcome = self._calculate_prediction_outcome(prediction_data) if outcome: self.prediction_outcomes[episode_id] = outcome # Update RL experience if exists if 'rl_action' in prediction_data['predictions']: self._update_rl_experience_outcome(episode_id, outcome) # Update statistics if outcome['is_profitable']: self.integration_stats['profitable_predictions'] += 1 self.integration_stats['total_predictions'] += 1 validated_predictions.append(episode_id) # Remove validated predictions for episode_id in validated_predictions: del self.recent_predictions[episode_id] if validated_predictions: logger.info(f"Validated {len(validated_predictions)} predictions") except Exception as e: logger.error(f"Error validating predictions: {e}") def _calculate_prediction_outcome(self, prediction_data: Dict[str, Any]) -> Optional[Dict[str, Any]]: """Calculate actual outcome for a prediction""" try: symbol = prediction_data['symbol'] prediction_time = prediction_data['timestamp'] # Get price data after prediction current_df = self.data_provider.get_historical_data(symbol, '1m', limit=100, refresh=True) if current_df is None or current_df.empty: return None # Find price at prediction time and current price prediction_price = prediction_data['market_data']['ohlcv'].get('1m', pd.DataFrame()) if prediction_price.empty: return None base_price = float(prediction_price['close'].iloc[-1]) current_price = float(current_df['close'].iloc[-1]) # Calculate outcome price_change = (current_price - base_price) / base_price is_profitable = abs(price_change) > 0.005 # 0.5% threshold return { 'episode_id': prediction_data.get('episode_id'), 'base_price': base_price, 'current_price': current_price, 'price_change': price_change, 'is_profitable': is_profitable, 'profitability_score': abs(price_change) * 10, # Scale to 0-1 range 'validation_time': datetime.now() } except Exception as e: logger.error(f"Error calculating prediction outcome: {e}") return None def _update_rl_experience_outcome(self, episode_id: str, outcome: Dict[str, Any]): """Update RL experience with actual outcome""" try: # Find the experience ID associated with this episode # This is a simplified approach - in practice you'd maintain better mapping actual_profit = outcome['price_change'] # Determine optimal action based on outcome if outcome['price_change'] > 0.01: optimal_action = 2 # BUY elif outcome['price_change'] < -0.01: optimal_action = 0 # SELL else: optimal_action = 1 # HOLD # Update experience (this would need proper experience ID mapping) # For now, we'll update the most recent experience # In practice, you'd maintain a mapping between episodes and experiences except Exception as e: logger.error(f"Error updating RL experience outcome: {e}") def get_integration_statistics(self) -> Dict[str, Any]: """Get comprehensive integration statistics""" stats = self.integration_stats.copy() # Add component statistics stats['data_collector'] = self.data_collector.get_collection_statistics() stats['cnn_trainer'] = self.cnn_trainer.get_training_statistics() stats['rl_trainer'] = self.rl_trainer.get_training_statistics() # Add performance metrics stats['is_running'] = self.is_running stats['active_symbols'] = len(self.data_provider.symbols) stats['recent_predictions_count'] = len(self.recent_predictions) stats['validated_outcomes_count'] = len(self.prediction_outcomes) # Calculate profitability rate if stats['total_predictions'] > 0: stats['overall_profitability_rate'] = stats['profitable_predictions'] / stats['total_predictions'] else: stats['overall_profitability_rate'] = 0.0 return stats def trigger_manual_training(self, training_type: str = 'all', symbol: str = None) -> Dict[str, Any]: """Manually trigger training""" results = {} try: if training_type in ['all', 'cnn']: symbols = [symbol] if symbol else self.data_provider.symbols for sym in symbols: cnn_results = self.cnn_trainer.train_on_profitable_episodes( symbol=sym, min_profitability=0.1, max_episodes=200 ) results[f'cnn_{sym}'] = cnn_results if training_type in ['all', 'rl']: rl_results = self.rl_trainer.train_on_profitable_experiences( min_profitability=0.1, max_experiences=500, batch_size=32 ) results['rl'] = rl_results return {'status': 'success', 'results': results} except Exception as e: logger.error(f"Error in manual training trigger: {e}") return {'status': 'error', 'error': str(e)} # Helper methods (simplified implementations) def _get_recent_tick_data(self, symbol: str) -> List[Dict[str, Any]]: """Get recent tick data""" # Implementation would get tick data from data provider return [] def _get_cob_data(self, symbol: str) -> Dict[str, Any]: """Get COB data""" # Implementation would get COB data from data provider return {} def _get_technical_indicators(self, symbol: str) -> Dict[str, float]: """Get technical indicators""" # Implementation would get indicators from data provider return {} def _get_pivot_points(self, symbol: str) -> List[Dict[str, Any]]: """Get pivot points""" # Implementation would get pivot points from data provider return [] def _get_market_context(self, symbol: str) -> Dict[str, Any]: """Get market context""" return { 'symbol': symbol, 'timestamp': datetime.now(), 'market_session': 'unknown', 'volatility_regime': 'unknown' } def _validate_market_data(self, market_data: Dict[str, Any]) -> bool: """Validate market data completeness""" required_fields = ['ohlcv', 'indicators'] return all(field in market_data for field in required_fields) def _create_enhanced_cnn_features(self, symbol: str, market_data: Dict[str, Any]) -> Optional[np.ndarray]: """Create enhanced CNN features""" try: # Simplified feature creation features = [] # Add OHLCV features for timeframe in ['1m', '5m', '15m', '1h']: if timeframe in market_data.get('ohlcv', {}): df = market_data['ohlcv'][timeframe] if not df.empty: ohlcv_values = df[['open', 'high', 'low', 'close', 'volume']].values if len(ohlcv_values) > 0: recent_values = ohlcv_values[-60:].flatten() features.extend(recent_values) # Pad to target size target_size = 3000 # 10 channels * 300 sequence length if len(features) < target_size: features.extend([0.0] * (target_size - len(features))) else: features = features[:target_size] return np.array(features, dtype=np.float32) except Exception as e: logger.warning(f"Error creating CNN features: {e}") return None def _create_enhanced_rl_state(self, symbol: str, market_data: Dict[str, Any], predictions: Dict[str, Any] = None) -> Optional[np.ndarray]: """Create enhanced RL state""" try: state_features = [] # Add market features if '1m' in market_data.get('ohlcv', {}): df = market_data['ohlcv']['1m'] if not df.empty: latest = df.iloc[-1] state_features.extend([ latest['open'], latest['high'], latest['low'], latest['close'], latest['volume'] ]) # Add technical indicators indicators = market_data.get('indicators', {}) for value in indicators.values(): state_features.append(value) # Add model predictions as features if predictions: if 'cnn' in predictions: cnn_pred = predictions['cnn'] state_features.extend(cnn_pred.get('pivot_logits', [0, 0, 0])) state_features.append(cnn_pred.get('confidence', [0.0])[0]) if 'cob_rl' in predictions: cob_pred = predictions['cob_rl'] state_features.append(cob_pred.get('predicted_direction', 1)) state_features.append(cob_pred.get('confidence', 0.5)) # Pad to target size target_size = 2000 if len(state_features) < target_size: state_features.extend([0.0] * (target_size - len(state_features))) else: state_features = state_features[:target_size] return np.array(state_features, dtype=np.float32) except Exception as e: logger.warning(f"Error creating RL state: {e}") return None def _get_orchestrator_prediction(self, symbol: str, market_data: Dict[str, Any], predictions: Dict[str, Any]) -> Optional[Dict[str, Any]]: """Get orchestrator prediction""" # This would integrate with your orchestrator return None # Global instance enhanced_training_integration = None def get_enhanced_training_integration(data_provider: DataProvider = None, orchestrator: Orchestrator = None, trading_executor: TradingExecutor = None) -> EnhancedTrainingIntegration: """Get global enhanced training integration instance""" global enhanced_training_integration if enhanced_training_integration is None: if data_provider is None: raise ValueError("DataProvider required for first initialization") enhanced_training_integration = EnhancedTrainingIntegration( data_provider, orchestrator, trading_executor ) return enhanced_training_integration