""" Training Integration Module This module integrates the comprehensive training data collection system with the existing data provider and model infrastructure. It provides: 1. Real-time data collection from DataProvider 2. Integration with existing CNN and RL models 3. Automatic training data package creation 4. Rapid price change detection and collection 5. Training pipeline coordination Key Features: - Seamless integration with existing DataProvider - Automatic model input package creation - Real-time training data validation - Coordinated training across all models - Performance monitoring and optimization """ import asyncio import logging import numpy as np import pandas as pd import torch from datetime import datetime, timedelta from typing import Dict, List, Optional, Tuple, Any, Callable from dataclasses import dataclass import threading import time from collections import deque from .training_data_collector import ( TrainingDataCollector, ModelInputPackage, get_training_data_collector ) from .cnn_training_pipeline import ( CNNTrainer, CNNPivotPredictor, get_cnn_trainer ) from .data_provider import DataProvider logger = logging.getLogger(__name__) @dataclass class TrainingIntegrationConfig: """Configuration for training integration""" # Data collection settings collection_interval: float = 1.0 # seconds min_data_completeness: float = 0.7 # Rapid change detection enable_rapid_change_detection: bool = True price_change_threshold: float = 0.5 # % per minute # Training settings enable_real_time_training: bool = True training_batch_size: int = 32 min_episodes_for_training: int = 50 # Performance settings max_concurrent_collections: int = 4 data_validation_enabled: bool = True class TrainingIntegration: """Main integration class for training data collection and model training""" def __init__(self, data_provider: DataProvider, config: TrainingIntegrationConfig = None): self.data_provider = data_provider self.config = config or TrainingIntegrationConfig() # Get training components self.data_collector = get_training_data_collector() # Initialize CNN components self.cnn_model = CNNPivotPredictor() self.cnn_trainer = get_cnn_trainer(self.cnn_model) # Integration state self.is_running = False self.collection_thread = None self.training_threads = {} # Data buffers for real-time processing self.data_buffers = {} self.last_collection_time = {} # Performance tracking self.integration_stats = { 'data_packages_created': 0, 'training_sessions_triggered': 0, 'rapid_changes_detected': 0, 'validation_failures': 0, 'average_collection_time': 0.0, 'last_update': datetime.now() } # Initialize data buffers for each symbol for symbol in self.data_provider.symbols: self.data_buffers[symbol] = { 'ohlcv_data': {}, 'tick_data': deque(maxlen=1000), 'cob_data': {}, 'technical_indicators': {}, 'pivot_points': [] } self.last_collection_time[symbol] = datetime.now() logger.info("Training Integration initialized") logger.info(f"Symbols: {self.data_provider.symbols}") logger.info(f"Real-time training: {self.config.enable_real_time_training}") logger.info(f"Rapid change detection: {self.config.enable_rapid_change_detection}") def start_integration(self): """Start the training integration system""" if self.is_running: logger.warning("Training integration already running") return self.is_running = True # Start data collection self.data_collector.start_collection() # Start real-time data collection thread self.collection_thread = threading.Thread( target=self._data_collection_worker, daemon=True ) 

class TrainingIntegration:
    """Main integration class for training data collection and model training"""

    def __init__(self, data_provider: DataProvider,
                 config: Optional[TrainingIntegrationConfig] = None):
        self.data_provider = data_provider
        self.config = config or TrainingIntegrationConfig()

        # Get training components
        self.data_collector = get_training_data_collector()

        # Initialize CNN components
        self.cnn_model = CNNPivotPredictor()
        self.cnn_trainer = get_cnn_trainer(self.cnn_model)

        # Integration state
        self.is_running = False
        self.collection_thread = None
        self.training_threads = {}

        # Data buffers for real-time processing
        self.data_buffers = {}
        self.last_collection_time = {}

        # Performance tracking
        self.integration_stats = {
            'data_packages_created': 0,
            'training_sessions_triggered': 0,
            'rapid_changes_detected': 0,
            'validation_failures': 0,
            'average_collection_time': 0.0,
            'last_update': datetime.now()
        }

        # Initialize data buffers for each symbol
        for symbol in self.data_provider.symbols:
            self.data_buffers[symbol] = {
                'ohlcv_data': {},
                'tick_data': deque(maxlen=1000),
                'cob_data': {},
                'technical_indicators': {},
                'pivot_points': []
            }
            self.last_collection_time[symbol] = datetime.now()

        logger.info("Training Integration initialized")
        logger.info(f"Symbols: {self.data_provider.symbols}")
        logger.info(f"Real-time training: {self.config.enable_real_time_training}")
        logger.info(f"Rapid change detection: {self.config.enable_rapid_change_detection}")

    def start_integration(self):
        """Start the training integration system"""
        if self.is_running:
            logger.warning("Training integration already running")
            return

        self.is_running = True

        # Start data collection
        self.data_collector.start_collection()

        # Start the real-time data collection thread
        self.collection_thread = threading.Thread(
            target=self._data_collection_worker,
            daemon=True
        )
        self.collection_thread.start()

        # Start one training thread per symbol
        if self.config.enable_real_time_training:
            for symbol in self.data_provider.symbols:
                training_thread = threading.Thread(
                    target=self._training_worker,
                    args=(symbol,),
                    daemon=True
                )
                self.training_threads[symbol] = training_thread
                training_thread.start()

        logger.info("Training integration started")

    def stop_integration(self):
        """Stop the training integration system"""
        self.is_running = False

        # Stop data collection
        self.data_collector.stop_collection()

        # Stop CNN training
        self.cnn_trainer.stop_training()

        # Wait for threads to finish
        if self.collection_thread:
            self.collection_thread.join(timeout=10)
        for thread in self.training_threads.values():
            thread.join(timeout=5)

        logger.info("Training integration stopped")

    def _data_collection_worker(self):
        """Main data collection worker"""
        logger.info("Data collection worker started")

        while self.is_running:
            try:
                start_time = time.time()

                # Collect data for each symbol
                for symbol in self.data_provider.symbols:
                    self._collect_symbol_data(symbol)

                # Update performance stats
                collection_time = time.time() - start_time
                self._update_collection_stats(collection_time)

                # Wait for the next collection cycle
                time.sleep(self.config.collection_interval)

            except Exception as e:
                logger.error(f"Error in data collection worker: {e}")
                time.sleep(5)  # Wait before retrying

        logger.info("Data collection worker stopped")

    def _collect_symbol_data(self, symbol: str):
        """Collect a comprehensive training data package for a symbol"""
        try:
            # Get current market data from the data provider
            ohlcv_data = self._get_ohlcv_data(symbol)
            tick_data = self._get_tick_data(symbol)
            cob_data = self._get_cob_data(symbol)
            technical_indicators = self._get_technical_indicators(symbol)
            pivot_points = self._get_pivot_points(symbol)

            # Validate data availability
            if not self._validate_data_availability(symbol, ohlcv_data, tick_data):
                return

            # Create model input features
            cnn_features = self._create_cnn_features(symbol, ohlcv_data, technical_indicators)
            rl_state = self._create_rl_state(symbol, ohlcv_data, cob_data, technical_indicators)
            orchestrator_context = self._create_orchestrator_context(symbol)

            # Get model predictions if available
            model_predictions = self._get_current_model_predictions(symbol)

            # Collect the training data package
            episode_id = self.data_collector.collect_training_data(
                symbol=symbol,
                ohlcv_data=ohlcv_data,
                tick_data=tick_data,
                cob_data=cob_data,
                technical_indicators=technical_indicators,
                pivot_points=pivot_points,
                cnn_features=cnn_features,
                rl_state=rl_state,
                orchestrator_context=orchestrator_context,
                model_predictions=model_predictions
            )

            if episode_id:
                self.integration_stats['data_packages_created'] += 1
                logger.debug(f"Created training data package for {symbol}: {episode_id}")

        except Exception as e:
            logger.error(f"Error collecting data for {symbol}: {e}")
            self.integration_stats['validation_failures'] += 1

    def _get_ohlcv_data(self, symbol: str) -> Dict[str, pd.DataFrame]:
        """Get OHLCV data for all timeframes"""
        ohlcv_data = {}
        try:
            for timeframe in ['1s', '1m', '5m', '15m', '1h', '1d']:
                df = self.data_provider.get_historical_data(
                    symbol=symbol,
                    timeframe=timeframe,
                    limit=300,    # 300 bars per timeframe, as specified in requirements
                    refresh=True  # Get fresh data
                )
                if df is not None and not df.empty:
                    ohlcv_data[timeframe] = df
            return ohlcv_data
        except Exception as e:
            logger.warning(f"Error getting OHLCV data for {symbol}: {e}")
            return {}
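
    # Shape sketch (an illustration, not an API guarantee) for the dict
    # returned by _get_ohlcv_data above; a timeframe key appears only when the
    # provider has data for it, with up to `limit` rows each:
    #
    #     {'1s': <DataFrame>, '1m': <DataFrame>, ..., '1d': <DataFrame>}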

    def _get_tick_data(self, symbol: str) -> List[Dict[str, Any]]:
        """Get recent tick data from the data provider's tick buffers"""
        try:
            binance_symbol = symbol.replace('/', '').upper()

            if binance_symbol in self.data_provider.tick_buffers:
                # Keep the last 300 seconds of tick data
                cutoff_time = datetime.now() - timedelta(seconds=300)

                tick_buffer = self.data_provider.tick_buffers[binance_symbol]
                recent_ticks = []

                # Convert the deque to a list and filter by time
                for tick in list(tick_buffer):
                    if hasattr(tick, 'timestamp') and tick.timestamp >= cutoff_time:
                        recent_ticks.append({
                            'timestamp': tick.timestamp,
                            'price': tick.price,
                            'volume': tick.volume,
                            'side': tick.side,
                            'trade_id': tick.trade_id
                        })
                return recent_ticks

            return []
        except Exception as e:
            logger.warning(f"Error getting tick data for {symbol}: {e}")
            return []

    def _get_cob_data(self, symbol: str) -> Dict[str, Any]:
        """Get Consolidated Order Book (COB) data"""
        try:
            # Get COB data from the data provider's COB cache
            binance_symbol = symbol.replace('/', '').upper()

            if binance_symbol in self.data_provider.cob_data_cache:
                cob_buffer = self.data_provider.cob_data_cache[binance_symbol]
                if cob_buffer:
                    # Entries may be (timestamp, features) tuples or raw features
                    latest_cob = list(cob_buffer)[-1]
                    if latest_cob:
                        if isinstance(latest_cob, tuple):
                            return {
                                'timestamp': latest_cob[0],
                                'cob_features': latest_cob[1],
                                'feature_count': len(latest_cob[1])
                            }
                        return {
                            'timestamp': datetime.now(),
                            'cob_features': latest_cob,
                            'feature_count': 0
                        }
            return {}
        except Exception as e:
            logger.warning(f"Error getting COB data for {symbol}: {e}")
            return {}

    def _get_technical_indicators(self, symbol: str) -> Dict[str, float]:
        """Extract technical indicators from recent OHLCV data"""
        try:
            # Get the most recent 1m data with indicators
            df = self.data_provider.get_historical_data(
                symbol=symbol,
                timeframe='1m',
                limit=50,
                refresh=True
            )

            if df is not None and not df.empty:
                # Extract indicator values from the latest row
                latest_row = df.iloc[-1]
                indicators = {}

                # Any non-OHLCV column is treated as an indicator
                for col in df.columns:
                    if col not in ['open', 'high', 'low', 'close', 'volume', 'timestamp']:
                        try:
                            value = float(latest_row[col])
                            if not np.isnan(value):
                                indicators[col] = value
                        except (ValueError, TypeError):
                            continue
                return indicators

            return {}
        except Exception as e:
            logger.warning(f"Error getting technical indicators for {symbol}: {e}")
            return {}

    def _get_pivot_points(self, symbol: str) -> List[Dict[str, Any]]:
        """Get recent pivot points"""
        try:
            if symbol in self.data_provider.williams_structure:
                # This would integrate with the Williams Market Structure;
                # for now, return an empty list as a placeholder.
                pivot_points = []
                return pivot_points
            return []
        except Exception as e:
            logger.warning(f"Error getting pivot points for {symbol}: {e}")
            return []
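
    # Feature budget for the CNN input vector built below (a back-of-the-
    # envelope sketch; 2000 is the fixed target size used in
    # _create_cnn_features):
    #
    #     5 timeframes x 60 bars x 5 OHLCV columns = 1500 floats
    #     + one float per technical indicator (variable count)
    #     + zero padding up to 2000 (or truncation past it)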

    def _create_cnn_features(self, symbol: str,
                             ohlcv_data: Dict[str, pd.DataFrame],
                             technical_indicators: Dict[str, float]) -> np.ndarray:
        """Create CNN input features from market data"""
        try:
            # Simplified feature creation; in practice you'd build richer
            # multi-timeframe features here.
            features = []

            # Add OHLCV features from multiple timeframes
            for timeframe in ['1s', '1m', '5m', '15m', '1h']:
                if timeframe in ohlcv_data:
                    df = ohlcv_data[timeframe]
                    if not df.empty:
                        # Raw OHLCV values (no normalization is applied here)
                        ohlcv_values = df[['open', 'high', 'low', 'close', 'volume']].values
                        if len(ohlcv_values) > 0:
                            # Take the last 60 bars and flatten them
                            recent_values = ohlcv_values[-60:].flatten()
                            features.extend(recent_values)

            # Add technical indicators
            for indicator_name, value in technical_indicators.items():
                features.append(value)

            # Pad or truncate to the fixed CNN input size
            target_size = 2000
            if len(features) < target_size:
                features.extend([0.0] * (target_size - len(features)))
            else:
                features = features[:target_size]

            return np.array(features, dtype=np.float32)
        except Exception as e:
            logger.warning(f"Error creating CNN features for {symbol}: {e}")
            return np.zeros(2000, dtype=np.float32)

    def _create_rl_state(self, symbol: str,
                         ohlcv_data: Dict[str, pd.DataFrame],
                         cob_data: Dict[str, Any],
                         technical_indicators: Dict[str, float]) -> np.ndarray:
        """Create the RL state representation"""
        try:
            state_features = []

            # Add market state features from the latest 1m candle
            if '1m' in ohlcv_data and not ohlcv_data['1m'].empty:
                latest_candle = ohlcv_data['1m'].iloc[-1]
                state_features.extend([
                    latest_candle['open'],
                    latest_candle['high'],
                    latest_candle['low'],
                    latest_candle['close'],
                    latest_candle['volume']
                ])

            # Add COB features (capped at 100 values)
            if 'cob_features' in cob_data:
                cob_features = cob_data['cob_features']
                if isinstance(cob_features, (list, np.ndarray)):
                    state_features.extend(cob_features[:100])

            # Add technical indicators
            for indicator_name, value in technical_indicators.items():
                state_features.append(value)

            # Pad or truncate to the fixed RL input size
            target_size = 2000
            if len(state_features) < target_size:
                state_features.extend([0.0] * (target_size - len(state_features)))
            else:
                state_features = state_features[:target_size]

            return np.array(state_features, dtype=np.float32)
        except Exception as e:
            logger.warning(f"Error creating RL state for {symbol}: {e}")
            return np.zeros(2000, dtype=np.float32)

    def _create_orchestrator_context(self, symbol: str) -> Dict[str, Any]:
        """Create orchestrator context"""
        try:
            return {
                'symbol': symbol,
                'timestamp': datetime.now(),
                'market_session': self._determine_market_session(),
                'volatility_regime': self._determine_volatility_regime(symbol),
                'trend_direction': self._determine_trend_direction(symbol)
            }
        except Exception as e:
            logger.warning(f"Error creating orchestrator context for {symbol}: {e}")
            return {'symbol': symbol, 'timestamp': datetime.now()}

    def _determine_market_session(self) -> str:
        """Determine the current market session (simplified, by hour of day)"""
        current_hour = datetime.now().hour
        if 0 <= current_hour < 8:
            return 'asian'
        elif 8 <= current_hour < 16:
            return 'european'
        else:
            return 'american'

    def _determine_volatility_regime(self, symbol: str) -> str:
        """Determine the volatility regime for a symbol"""
        try:
            # Standard deviation of 1m returns over the last 100 bars
            df = self.data_provider.get_historical_data(symbol, '1m', limit=100)
            if df is not None and not df.empty:
                returns = df['close'].pct_change().dropna()
                volatility = returns.std()

                if volatility > 0.02:
                    return 'high'
                elif volatility > 0.01:
                    return 'medium'
                else:
                    return 'low'
            return 'unknown'
        except Exception:
            return 'unknown'

    def _determine_trend_direction(self, symbol: str) -> str:
        """Determine the trend direction for a symbol"""
        try:
            # Simple trend detection using the 20/50 SMA relationship on 1h bars
            df = self.data_provider.get_historical_data(symbol, '1h', limit=50)
            if df is not None and not df.empty:
                if 'sma_20' in df.columns and 'sma_50' in df.columns:
                    latest_sma20 = df['sma_20'].iloc[-1]
                    latest_sma50 = df['sma_50'].iloc[-1]

                    if latest_sma20 > latest_sma50:
                        return 'uptrend'
                    elif latest_sma20 < latest_sma50:
                        return 'downtrend'
                    else:
                        return 'sideways'
            return 'unknown'
        except Exception:
            return 'unknown'
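
    # _get_current_model_predictions below is still a placeholder. When wired
    # up, a dict keyed by model name is the natural shape; the keys and values
    # here are illustrative assumptions, not an existing schema:
    #
    #     {'cnn': {'pivot_probability': 0.62, 'direction': 'up'},
    #      'rl':  {'action': 'hold', 'confidence': 0.55}}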

    def _get_current_model_predictions(self, symbol: str) -> Dict[str, Any]:
        """Get current predictions from all models"""
        predictions = {}
        try:
            # This would integrate with existing model predictions;
            # for now, return an empty dict as a placeholder.
            return predictions
        except Exception as e:
            logger.warning(f"Error getting model predictions for {symbol}: {e}")
            return {}

    def _validate_data_availability(self, symbol: str,
                                    ohlcv_data: Dict[str, pd.DataFrame],
                                    tick_data: List[Dict[str, Any]]) -> bool:
        """Validate that sufficient data is available for training"""
        try:
            # Check OHLCV availability across the required timeframes
            required_timeframes = ['1m', '5m', '1h']
            available_timeframes = 0

            for timeframe in required_timeframes:
                if timeframe in ohlcv_data and not ohlcv_data[timeframe].empty:
                    available_timeframes += 1

            # Require at least 2 of the 3 timeframes
            if available_timeframes < 2:
                return False

            # Tick data is optional but preferred
            has_tick_data = len(tick_data) > 0

            # Completeness score: fraction of required timeframes present,
            # plus a 0.1 bonus when tick data is available
            completeness = available_timeframes / len(required_timeframes)
            if has_tick_data:
                completeness += 0.1

            return completeness >= self.config.min_data_completeness
        except Exception as e:
            logger.warning(f"Error validating data availability for {symbol}: {e}")
            return False

    def _training_worker(self, symbol: str):
        """Training worker for a specific symbol"""
        logger.info(f"Training worker started for {symbol}")

        while self.is_running:
            try:
                # Check whether we have enough episodes for training
                episodes = self.data_collector.get_high_priority_episodes(
                    symbol=symbol,
                    limit=self.config.training_batch_size * 2,
                    min_priority=0.3
                )

                if len(episodes) >= self.config.min_episodes_for_training:
                    # Trigger CNN training
                    results = self.cnn_trainer.train_on_profitable_episodes(
                        symbol=symbol,
                        min_profitability=0.6,
                        max_episodes=len(episodes)
                    )

                    if results.get('status') == 'success':
                        self.integration_stats['training_sessions_triggered'] += 1
                        logger.info(f"Training session completed for {symbol}")

                # Wait before the next training check
                time.sleep(300)  # Check every 5 minutes

            except Exception as e:
                logger.error(f"Error in training worker for {symbol}: {e}")
                time.sleep(60)  # Wait before retrying

        logger.info(f"Training worker stopped for {symbol}")

    def _update_collection_stats(self, collection_time: float):
        """Update collection performance statistics"""
        try:
            # Exponential moving average of collection time:
            # avg = alpha * new + (1 - alpha) * avg
            alpha = 0.1
            if self.integration_stats['average_collection_time'] == 0:
                self.integration_stats['average_collection_time'] = collection_time
            else:
                self.integration_stats['average_collection_time'] = (
                    alpha * collection_time +
                    (1 - alpha) * self.integration_stats['average_collection_time']
                )
            self.integration_stats['last_update'] = datetime.now()
        except Exception as e:
            logger.warning(f"Error updating collection stats: {e}")

    def get_integration_statistics(self) -> Dict[str, Any]:
        """Get comprehensive integration statistics"""
        stats = self.integration_stats.copy()

        # Add data collector statistics
        collector_stats = self.data_collector.get_collection_statistics()
        stats.update(collector_stats)

        # Add CNN trainer statistics
        trainer_stats = self.cnn_trainer.get_training_statistics()
        stats['cnn_training'] = trainer_stats

        # Add performance metrics
        stats['is_running'] = self.is_running
        stats['active_symbols'] = len(self.data_provider.symbols)
        stats['collection_frequency'] = self.config.collection_interval

        return stats
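
    # Sketch of driving the public surface around this point (the methods are
    # the ones defined in this class; the symbol string is illustrative):
    #
    #     stats = integration.get_integration_statistics()
    #     result = integration.trigger_manual_training('ETH/USDT', 'profitable')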

    def trigger_manual_training(self, symbol: str,
                                training_type: str = 'profitable') -> Dict[str, Any]:
        """Manually trigger training for a symbol"""
        try:
            if training_type == 'profitable':
                results = self.cnn_trainer.train_on_profitable_episodes(
                    symbol=symbol,
                    min_profitability=0.7,
                    max_episodes=200
                )
            elif training_type == 'high_value_replay':
                results = self.cnn_trainer.replay_high_value_sessions(
                    symbol=symbol,
                    min_session_value=0.8,
                    max_sessions=10
                )
            else:
                return {'status': 'error', 'error': f'Unknown training type: {training_type}'}

            if results.get('status') == 'success':
                self.integration_stats['training_sessions_triggered'] += 1

            return results
        except Exception as e:
            logger.error(f"Error in manual training trigger: {e}")
            return {'status': 'error', 'error': str(e)}


# Global instance
training_integration = None


def get_training_integration(data_provider: Optional[DataProvider] = None) -> TrainingIntegration:
    """Get the global training integration instance"""
    global training_integration

    if training_integration is None:
        if data_provider is None:
            raise ValueError("DataProvider required for first initialization")
        training_integration = TrainingIntegration(data_provider)

    return training_integration
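
# Minimal smoke-test sketch. It assumes DataProvider can be constructed with
# no arguments and picks up its symbol list from configuration; adjust the
# construction to your setup, and run via `python -m` so the relative imports
# resolve.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)

    provider = DataProvider()  # assumption: default construction is supported
    integration = get_training_integration(provider)
    integration.start_integration()
    try:
        # Let a minute of collection cycles run, then report what was gathered
        time.sleep(60)
        print(integration.get_integration_statistics())
    finally:
        integration.stop_integration()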