""" Comprehensive Training Data Collection System This module implements a robust training data collection system that: 1. Captures all model inputs with validation and completeness checks 2. Stores training data packages with future outcome validation 3. Detects rapid price changes for high-value training examples 4. Enables replay and retraining on most profitable setups 5. Maintains data integrity and traceability Key Features: - Real-time data package creation with all model inputs - Future outcome validation (profitable vs unprofitable predictions) - Rapid price change detection for premium training examples - Comprehensive data validation and completeness verification - Backpropagation data storage for gradient replay - Training episode profitability tracking and ranking """ import asyncio import json import logging import numpy as np import pandas as pd import pickle import torch from datetime import datetime, timedelta from pathlib import Path from typing import Dict, List, Optional, Tuple, Any, Callable from dataclasses import dataclass, field, asdict from collections import deque import hashlib import threading from concurrent.futures import ThreadPoolExecutor logger = logging.getLogger(__name__) @dataclass class ModelInputPackage: """Complete package of all model inputs at a specific timestamp""" timestamp: datetime symbol: str # Market data inputs ohlcv_data: Dict[str, pd.DataFrame] # {timeframe: DataFrame} tick_data: List[Dict[str, Any]] # Raw tick data cob_data: Dict[str, Any] # Consolidated Order Book data technical_indicators: Dict[str, float] # All technical indicators pivot_points: List[Dict[str, Any]] # Detected pivot points # Model-specific inputs cnn_features: np.ndarray # CNN input features rl_state: np.ndarray # RL state representation orchestrator_context: Dict[str, Any] # Orchestrator context # Cross-model inputs (outputs from other models) cnn_predictions: Optional[Dict[str, Any]] = None rl_predictions: Optional[Dict[str, Any]] = None orchestrator_decision: Optional[Dict[str, Any]] = None # Data validation data_hash: str = "" completeness_score: float = 0.0 validation_flags: Dict[str, bool] = field(default_factory=dict) def __post_init__(self): """Calculate data hash and completeness after initialization""" self.data_hash = self._calculate_hash() self.completeness_score = self._calculate_completeness() self.validation_flags = self._validate_data() def _calculate_hash(self) -> str: """Calculate hash for data integrity verification""" try: # Create a string representation of all data data_str = f"{self.timestamp}_{self.symbol}" data_str += f"_{len(self.ohlcv_data)}_{len(self.tick_data)}" data_str += f"_{self.cnn_features.shape if self.cnn_features is not None else 'None'}" data_str += f"_{self.rl_state.shape if self.rl_state is not None else 'None'}" return hashlib.md5(data_str.encode()).hexdigest() except Exception as e: logger.warning(f"Error calculating data hash: {e}") return "invalid_hash" def _calculate_completeness(self) -> float: """Calculate completeness score (0.0 to 1.0)""" try: total_fields = 10 # Total expected data fields complete_fields = 0 # Check each required field if self.ohlcv_data and len(self.ohlcv_data) > 0: complete_fields += 1 if self.tick_data and len(self.tick_data) > 0: complete_fields += 1 if self.cob_data and len(self.cob_data) > 0: complete_fields += 1 if self.technical_indicators and len(self.technical_indicators) > 0: complete_fields += 1 if self.pivot_points and len(self.pivot_points) > 0: complete_fields += 1 if self.cnn_features 
is not None and self.cnn_features.size > 0: complete_fields += 1 if self.rl_state is not None and self.rl_state.size > 0: complete_fields += 1 if self.orchestrator_context and len(self.orchestrator_context) > 0: complete_fields += 1 if self.cnn_predictions is not None: complete_fields += 1 if self.rl_predictions is not None: complete_fields += 1 return complete_fields / total_fields except Exception as e: logger.warning(f"Error calculating completeness: {e}") return 0.0 def _validate_data(self) -> Dict[str, bool]: """Validate data integrity and consistency""" flags = {} try: # Validate timestamp flags['valid_timestamp'] = isinstance(self.timestamp, datetime) # Validate OHLCV data flags['valid_ohlcv'] = ( self.ohlcv_data is not None and len(self.ohlcv_data) > 0 and all(isinstance(df, pd.DataFrame) for df in self.ohlcv_data.values()) ) # Validate feature arrays flags['valid_cnn_features'] = ( self.cnn_features is not None and isinstance(self.cnn_features, np.ndarray) and self.cnn_features.size > 0 ) flags['valid_rl_state'] = ( self.rl_state is not None and isinstance(self.rl_state, np.ndarray) and self.rl_state.size > 0 ) # Validate data consistency flags['data_consistent'] = self.completeness_score > 0.7 except Exception as e: logger.warning(f"Error validating data: {e}") flags['validation_error'] = True return flags @dataclass class TrainingOutcome: """Future outcome validation for training data""" input_package_hash: str timestamp: datetime symbol: str # Price movement outcomes price_change_1m: float price_change_5m: float price_change_15m: float price_change_1h: float # Profitability metrics max_profit_potential: float max_loss_potential: float optimal_entry_price: float optimal_exit_price: float optimal_holding_time: timedelta # Classification labels is_profitable: bool profitability_score: float # 0.0 to 1.0 risk_reward_ratio: float # Rapid price change detection is_rapid_change: bool change_velocity: float # Price change per minute volatility_spike: bool # Validation outcome_validated: bool = False validation_timestamp: datetime = field(default_factory=datetime.now) @dataclass class TrainingEpisode: """Complete training episode with inputs, predictions, and outcomes""" episode_id: str input_package: ModelInputPackage model_predictions: Dict[str, Any] # Predictions from all models actual_outcome: TrainingOutcome # Training metadata episode_type: str # 'normal', 'rapid_change', 'high_profit' profitability_rank: float # Ranking among all episodes training_priority: float # Priority for replay training # Backpropagation data storage gradient_data: Optional[Dict[str, torch.Tensor]] = None loss_components: Optional[Dict[str, float]] = None model_states: Optional[Dict[str, Any]] = None # Episode statistics created_timestamp: datetime = field(default_factory=datetime.now) last_trained_timestamp: Optional[datetime] = None training_count: int = 0 def calculate_training_priority(self) -> float: """Calculate training priority based on profitability and characteristics""" try: priority = 0.0 # Base priority from profitability if self.actual_outcome.is_profitable: priority += self.actual_outcome.profitability_score * 0.4 # Bonus for rapid changes (high learning value) if self.actual_outcome.is_rapid_change: priority += 0.3 # Bonus for high risk-reward ratio if self.actual_outcome.risk_reward_ratio > 2.0: priority += 0.2 # Bonus for data completeness priority += self.input_package.completeness_score * 0.1 # Penalty for frequent training (avoid overfitting) if self.training_count > 5: priority *= 0.8 
return min(priority, 1.0) except Exception as e: logger.warning(f"Error calculating training priority: {e}") return 0.0 class RapidChangeDetector: """Detects rapid price changes for high-value training examples""" def __init__(self, velocity_threshold: float = 0.5, # % per minute volatility_multiplier: float = 3.0, lookback_minutes: int = 5): self.velocity_threshold = velocity_threshold self.volatility_multiplier = volatility_multiplier self.lookback_minutes = lookback_minutes # Price history for change detection self.price_history: Dict[str, deque] = {} self.volatility_baseline: Dict[str, float] = {} def add_price_point(self, symbol: str, timestamp: datetime, price: float): """Add new price point for change detection""" if symbol not in self.price_history: self.price_history[symbol] = deque(maxlen=self.lookback_minutes * 60) # 1 second resolution self.volatility_baseline[symbol] = 0.0 self.price_history[symbol].append((timestamp, price)) self._update_volatility_baseline(symbol) def detect_rapid_change(self, symbol: str) -> Tuple[bool, float, bool]: """ Detect rapid price changes Returns: (is_rapid_change, change_velocity, volatility_spike) """ if symbol not in self.price_history or len(self.price_history[symbol]) < 60: return False, 0.0, False try: prices = list(self.price_history[symbol]) # Calculate recent velocity (last minute) recent_prices = prices[-60:] # Last 60 seconds if len(recent_prices) < 2: return False, 0.0, False start_price = recent_prices[0][1] end_price = recent_prices[-1][1] time_diff = (recent_prices[-1][0] - recent_prices[0][0]).total_seconds() / 60.0 # minutes if time_diff <= 0: return False, 0.0, False # Calculate velocity (% change per minute) velocity = abs((end_price - start_price) / start_price * 100) / time_diff # Check for rapid change is_rapid = velocity > self.velocity_threshold # Check for volatility spike current_volatility = self._calculate_current_volatility(symbol) baseline_volatility = self.volatility_baseline.get(symbol, 0.0) volatility_spike = ( baseline_volatility > 0 and current_volatility > baseline_volatility * self.volatility_multiplier ) return is_rapid, velocity, volatility_spike except Exception as e: logger.warning(f"Error detecting rapid change for {symbol}: {e}") return False, 0.0, False def _update_volatility_baseline(self, symbol: str): """Update volatility baseline for the symbol""" try: if len(self.price_history[symbol]) < 120: # Need at least 2 minutes of data return # Calculate rolling volatility over longer period prices = [p[1] for p in list(self.price_history[symbol])[-300:]] # Last 5 minutes if len(prices) < 2: return # Calculate standard deviation of price changes price_changes = [abs(prices[i] - prices[i-1]) / prices[i-1] for i in range(1, len(prices))] volatility = np.std(price_changes) * 100 # Convert to percentage # Update baseline with exponential moving average alpha = 0.1 if self.volatility_baseline[symbol] == 0: self.volatility_baseline[symbol] = volatility else: self.volatility_baseline[symbol] = ( alpha * volatility + (1 - alpha) * self.volatility_baseline[symbol] ) except Exception as e: logger.warning(f"Error updating volatility baseline for {symbol}: {e}") def _calculate_current_volatility(self, symbol: str) -> float: """Calculate current volatility for the symbol""" try: if len(self.price_history[symbol]) < 60: return 0.0 # Use last minute of data recent_prices = [p[1] for p in list(self.price_history[symbol])[-60:]] if len(recent_prices) < 2: return 0.0 price_changes = [abs(recent_prices[i] - recent_prices[i-1]) / 
recent_prices[i-1] for i in range(1, len(recent_prices))] return np.std(price_changes) * 100 except Exception as e: logger.warning(f"Error calculating current volatility for {symbol}: {e}") return 0.0 class TrainingDataCollector: """Main training data collection system""" def __init__(self, storage_dir: str = "training_data", max_episodes_per_symbol: int = 10000, outcome_validation_delay: timedelta = timedelta(hours=1)): self.storage_dir = Path(storage_dir) self.storage_dir.mkdir(parents=True, exist_ok=True) self.max_episodes_per_symbol = max_episodes_per_symbol self.outcome_validation_delay = outcome_validation_delay # Data storage self.training_episodes: Dict[str, List[TrainingEpisode]] = {} # {symbol: episodes} self.pending_outcomes: Dict[str, List[ModelInputPackage]] = {} # Awaiting outcome validation # Rapid change detection self.rapid_change_detector = RapidChangeDetector() # Data validation and statistics self.collection_stats = { 'total_episodes': 0, 'profitable_episodes': 0, 'rapid_change_episodes': 0, 'validation_errors': 0, 'data_completeness_avg': 0.0 } # Background processing self.is_collecting = False self.collection_thread = None self.outcome_validation_thread = None # Thread safety self.data_lock = threading.Lock() logger.info(f"Training Data Collector initialized") logger.info(f"Storage directory: {self.storage_dir}") logger.info(f"Max episodes per symbol: {self.max_episodes_per_symbol}") def start_collection(self): """Start the training data collection system""" if self.is_collecting: logger.warning("Training data collection already running") return self.is_collecting = True # Start outcome validation thread self.outcome_validation_thread = threading.Thread( target=self._outcome_validation_worker, daemon=True ) self.outcome_validation_thread.start() logger.info("Training data collection started") def stop_collection(self): """Stop the training data collection system""" self.is_collecting = False if self.outcome_validation_thread: self.outcome_validation_thread.join(timeout=5) logger.info("Training data collection stopped") def collect_training_data(self, symbol: str, ohlcv_data: Dict[str, pd.DataFrame], tick_data: List[Dict[str, Any]], cob_data: Dict[str, Any], technical_indicators: Dict[str, float], pivot_points: List[Dict[str, Any]], cnn_features: np.ndarray, rl_state: np.ndarray, orchestrator_context: Dict[str, Any], model_predictions: Dict[str, Any] = None) -> str: """ Collect comprehensive training data package Returns: episode_id for tracking """ try: # Create input package input_package = ModelInputPackage( timestamp=datetime.now(), symbol=symbol, ohlcv_data=ohlcv_data, tick_data=tick_data, cob_data=cob_data, technical_indicators=technical_indicators, pivot_points=pivot_points, cnn_features=cnn_features, rl_state=rl_state, orchestrator_context=orchestrator_context ) # Validate data completeness if input_package.completeness_score < 0.5: logger.warning(f"Low data completeness for {symbol}: {input_package.completeness_score:.2f}") self.collection_stats['validation_errors'] += 1 return None # Check for rapid price changes current_price = self._extract_current_price(ohlcv_data) if current_price: self.rapid_change_detector.add_price_point(symbol, input_package.timestamp, current_price) # Add to pending outcomes for future validation with self.data_lock: if symbol not in self.pending_outcomes: self.pending_outcomes[symbol] = [] self.pending_outcomes[symbol].append(input_package) # Limit pending outcomes to prevent memory issues if len(self.pending_outcomes[symbol]) > 
1000: self.pending_outcomes[symbol] = self.pending_outcomes[symbol][-500:] # Generate episode ID episode_id = f"{symbol}_{input_package.timestamp.strftime('%Y%m%d_%H%M%S')}_{input_package.data_hash[:8]}" # Update statistics self.collection_stats['total_episodes'] += 1 self.collection_stats['data_completeness_avg'] = ( (self.collection_stats['data_completeness_avg'] * (self.collection_stats['total_episodes'] - 1) + input_package.completeness_score) / self.collection_stats['total_episodes'] ) logger.debug(f"Collected training data for {symbol}: {episode_id}") logger.debug(f"Data completeness: {input_package.completeness_score:.2f}") return episode_id except Exception as e: logger.error(f"Error collecting training data for {symbol}: {e}") self.collection_stats['validation_errors'] += 1 return None def _extract_current_price(self, ohlcv_data: Dict[str, pd.DataFrame]) -> Optional[float]: """Extract current price from OHLCV data""" try: # Try to get price from shortest timeframe first for timeframe in ['1s', '1m', '5m', '15m', '1h']: if timeframe in ohlcv_data and not ohlcv_data[timeframe].empty: return float(ohlcv_data[timeframe]['close'].iloc[-1]) return None except Exception as e: logger.warning(f"Error extracting current price: {e}") return None def _outcome_validation_worker(self): """Background worker for validating training outcomes""" logger.info("Outcome validation worker started") while self.is_collecting: try: self._validate_pending_outcomes() threading.Event().wait(60) # Check every minute except Exception as e: logger.error(f"Error in outcome validation worker: {e}") threading.Event().wait(30) # Wait before retrying logger.info("Outcome validation worker stopped") def _validate_pending_outcomes(self): """Validate outcomes for pending training data""" current_time = datetime.now() with self.data_lock: for symbol in list(self.pending_outcomes.keys()): if symbol not in self.pending_outcomes: continue validated_packages = [] remaining_packages = [] for package in self.pending_outcomes[symbol]: # Check if enough time has passed for outcome validation if current_time - package.timestamp >= self.outcome_validation_delay: outcome = self._calculate_training_outcome(package) if outcome: self._create_training_episode(package, outcome) validated_packages.append(package) else: remaining_packages.append(package) else: remaining_packages.append(package) # Update pending outcomes self.pending_outcomes[symbol] = remaining_packages if validated_packages: logger.info(f"Validated {len(validated_packages)} outcomes for {symbol}") def _calculate_training_outcome(self, input_package: ModelInputPackage) -> Optional[TrainingOutcome]: """Calculate training outcome based on future price movements""" try: # This would typically fetch recent price data to calculate outcomes # For now, we'll create a placeholder implementation # Extract base price from input package base_price = self._extract_current_price(input_package.ohlcv_data) if not base_price: return None # Simulate outcome calculation (in real implementation, fetch actual future prices) # This is where you would integrate with your data provider to get actual outcomes # Check for rapid change is_rapid, velocity, volatility_spike = self.rapid_change_detector.detect_rapid_change( input_package.symbol ) # Create outcome (placeholder values - replace with actual calculation) outcome = TrainingOutcome( input_package_hash=input_package.data_hash, timestamp=input_package.timestamp, symbol=input_package.symbol, price_change_1m=0.0, # Calculate from actual future 
data price_change_5m=0.0, price_change_15m=0.0, price_change_1h=0.0, max_profit_potential=0.0, max_loss_potential=0.0, optimal_entry_price=base_price, optimal_exit_price=base_price, optimal_holding_time=timedelta(minutes=5), is_profitable=False, # Determine from actual outcomes profitability_score=0.0, risk_reward_ratio=1.0, is_rapid_change=is_rapid, change_velocity=velocity, volatility_spike=volatility_spike, outcome_validated=True ) return outcome except Exception as e: logger.error(f"Error calculating training outcome: {e}") return None def _create_training_episode(self, input_package: ModelInputPackage, outcome: TrainingOutcome): """Create complete training episode""" try: episode_id = f"{input_package.symbol}_{input_package.timestamp.strftime('%Y%m%d_%H%M%S')}_{input_package.data_hash[:8]}" # Determine episode type episode_type = 'normal' if outcome.is_rapid_change: episode_type = 'rapid_change' self.collection_stats['rapid_change_episodes'] += 1 elif outcome.profitability_score > 0.8: episode_type = 'high_profit' if outcome.is_profitable: self.collection_stats['profitable_episodes'] += 1 # Create training episode episode = TrainingEpisode( episode_id=episode_id, input_package=input_package, model_predictions={}, # Will be filled when models make predictions actual_outcome=outcome, episode_type=episode_type, profitability_rank=0.0, # Will be calculated later training_priority=0.0 ) # Calculate training priority episode.training_priority = episode.calculate_training_priority() # Store episode symbol = input_package.symbol if symbol not in self.training_episodes: self.training_episodes[symbol] = [] self.training_episodes[symbol].append(episode) # Limit episodes per symbol if len(self.training_episodes[symbol]) > self.max_episodes_per_symbol: # Keep highest priority episodes self.training_episodes[symbol].sort(key=lambda x: x.training_priority, reverse=True) self.training_episodes[symbol] = self.training_episodes[symbol][:self.max_episodes_per_symbol] # Save episode to disk self._save_episode_to_disk(episode) logger.debug(f"Created training episode: {episode_id}") logger.debug(f"Episode type: {episode_type}, Priority: {episode.training_priority:.3f}") except Exception as e: logger.error(f"Error creating training episode: {e}") def _save_episode_to_disk(self, episode: TrainingEpisode): """Save training episode to disk for persistence""" try: symbol_dir = self.storage_dir / episode.input_package.symbol symbol_dir.mkdir(parents=True, exist_ok=True) # Save episode data episode_file = symbol_dir / f"{episode.episode_id}.pkl" with open(episode_file, 'wb') as f: pickle.dump(episode, f) # Save episode metadata for quick access metadata = { 'episode_id': episode.episode_id, 'timestamp': episode.input_package.timestamp.isoformat(), 'episode_type': episode.episode_type, 'training_priority': episode.training_priority, 'profitability_score': episode.actual_outcome.profitability_score, 'is_profitable': episode.actual_outcome.is_profitable, 'is_rapid_change': episode.actual_outcome.is_rapid_change, 'data_completeness': episode.input_package.completeness_score } metadata_file = symbol_dir / f"{episode.episode_id}_metadata.json" with open(metadata_file, 'w') as f: json.dump(metadata, f, indent=2) except Exception as e: logger.error(f"Error saving episode to disk: {e}") def get_high_priority_episodes(self, symbol: str, limit: int = 100, min_priority: float = 0.5) -> List[TrainingEpisode]: """Get high-priority training episodes for replay training""" try: if symbol not in self.training_episodes: return [] 
# Filter and sort by priority high_priority = [ ep for ep in self.training_episodes[symbol] if ep.training_priority >= min_priority ] high_priority.sort(key=lambda x: x.training_priority, reverse=True) return high_priority[:limit] except Exception as e: logger.error(f"Error getting high priority episodes for {symbol}: {e}") return [] def get_collection_statistics(self) -> Dict[str, Any]: """Get comprehensive collection statistics""" stats = self.collection_stats.copy() # Add per-symbol statistics stats['episodes_per_symbol'] = { symbol: len(episodes) for symbol, episodes in self.training_episodes.items() } # Add pending outcomes count stats['pending_outcomes'] = { symbol: len(packages) for symbol, packages in self.pending_outcomes.items() } # Calculate profitability rate if stats['total_episodes'] > 0: stats['profitability_rate'] = stats['profitable_episodes'] / stats['total_episodes'] stats['rapid_change_rate'] = stats['rapid_change_episodes'] / stats['total_episodes'] else: stats['profitability_rate'] = 0.0 stats['rapid_change_rate'] = 0.0 return stats def validate_data_integrity(self) -> Dict[str, Any]: """Comprehensive data integrity validation""" validation_results = { 'total_episodes_checked': 0, 'hash_mismatches': 0, 'completeness_issues': 0, 'validation_flag_failures': 0, 'corrupted_episodes': [], 'integrity_score': 1.0 } try: for symbol, episodes in self.training_episodes.items(): for episode in episodes: validation_results['total_episodes_checked'] += 1 # Check data hash expected_hash = episode.input_package._calculate_hash() if expected_hash != episode.input_package.data_hash: validation_results['hash_mismatches'] += 1 validation_results['corrupted_episodes'].append(episode.episode_id) # Check completeness if episode.input_package.completeness_score < 0.7: validation_results['completeness_issues'] += 1 # Check validation flags if not episode.input_package.validation_flags.get('data_consistent', False): validation_results['validation_flag_failures'] += 1 # Calculate integrity score total_issues = ( validation_results['hash_mismatches'] + validation_results['completeness_issues'] + validation_results['validation_flag_failures'] ) if validation_results['total_episodes_checked'] > 0: validation_results['integrity_score'] = 1.0 - ( total_issues / validation_results['total_episodes_checked'] ) logger.info(f"Data integrity validation completed") logger.info(f"Integrity score: {validation_results['integrity_score']:.3f}") except Exception as e: logger.error(f"Error during data integrity validation: {e}") validation_results['validation_error'] = str(e) return validation_results # Global instance for easy access training_data_collector = None def get_training_data_collector() -> TrainingDataCollector: """Get global training data collector instance""" global training_data_collector if training_data_collector is None: training_data_collector = TrainingDataCollector() return training_data_collector
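

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the production API).
# The symbol name, OHLCV columns, and feature-vector sizes below are assumed
# placeholders chosen to exercise the collector end to end; a real caller
# would pass data from its own provider and model pipelines.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    collector = get_training_data_collector()
    collector.start_collection()

    # Tiny synthetic 1-minute OHLCV frame standing in for real market data.
    ohlcv_1m = pd.DataFrame({
        'open': [100.0, 100.5],
        'high': [100.6, 100.9],
        'low': [99.8, 100.2],
        'close': [100.5, 100.8],
        'volume': [1200.0, 950.0],
    })

    episode_id = collector.collect_training_data(
        symbol="ETH/USDT",                                     # assumed symbol
        ohlcv_data={'1m': ohlcv_1m},
        tick_data=[{'price': 100.8, 'size': 0.5}],
        cob_data={'imbalance': 0.12},
        technical_indicators={'rsi_14': 55.3},
        pivot_points=[{'type': 'low', 'price': 99.8}],
        cnn_features=np.random.rand(128).astype(np.float32),   # assumed feature size
        rl_state=np.random.rand(64).astype(np.float32),        # assumed state size
        orchestrator_context={'regime': 'trending'},
    )
    print(f"Collected episode: {episode_id}")
    print(json.dumps(collector.get_collection_statistics(), indent=2, default=str))

    # Outcomes are validated in the background once outcome_validation_delay has
    # elapsed; after that, high-priority episodes can be pulled for replay training.
    replay_batch = collector.get_high_priority_episodes("ETH/USDT", limit=10)
    print(f"High-priority episodes available for replay: {len(replay_batch)}")

    collector.stop_collection()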