# Files
# gogo2/core/training_data_collector.py
# Dobromir Popov 12865fd3ef replay system
# 2025-07-20 12:37:02 +03:00
#
# 795 lines
# 32 KiB
# Python

"""
Comprehensive Training Data Collection System
This module implements a robust training data collection system that:
1. Captures all model inputs with validation and completeness checks
2. Stores training data packages with future outcome validation
3. Detects rapid price changes for high-value training examples
4. Enables replay and retraining on the most profitable setups
5. Maintains data integrity and traceability
Key Features:
- Real-time data package creation with all model inputs
- Future outcome validation (profitable vs unprofitable predictions)
- Rapid price change detection for premium training examples
- Comprehensive data validation and completeness verification
- Backpropagation data storage for gradient replay
- Training episode profitability tracking and ranking
"""
import asyncio
import json
import logging
import numpy as np
import pandas as pd
import pickle
import torch
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any, Callable
from dataclasses import dataclass, field, asdict
from collections import deque
import hashlib
import threading
from concurrent.futures import ThreadPoolExecutor
# Shared module logger; the classes below report warnings and errors through it.
logger = logging.getLogger(__name__)
@dataclass
class ModelInputPackage:
    """Snapshot of every model input for one symbol at one point in time.

    ``data_hash``, ``completeness_score`` and ``validation_flags`` are derived
    in ``__post_init__``; any values passed for them are overwritten.
    """
    timestamp: datetime
    symbol: str
    # --- Market data inputs ---
    ohlcv_data: Dict[str, pd.DataFrame]  # {timeframe: DataFrame}
    tick_data: List[Dict[str, Any]]  # raw tick records
    cob_data: Dict[str, Any]  # Consolidated Order Book data
    technical_indicators: Dict[str, float]  # indicator name -> value
    pivot_points: List[Dict[str, Any]]  # detected pivot points
    # --- Model-specific inputs ---
    cnn_features: np.ndarray  # CNN input features
    rl_state: np.ndarray  # RL state representation
    orchestrator_context: Dict[str, Any]  # orchestrator context
    # --- Cross-model inputs (outputs of other models) ---
    cnn_predictions: Optional[Dict[str, Any]] = None
    rl_predictions: Optional[Dict[str, Any]] = None
    orchestrator_decision: Optional[Dict[str, Any]] = None
    # --- Derived validation data (recomputed in __post_init__) ---
    data_hash: str = ""
    completeness_score: float = 0.0
    validation_flags: Dict[str, bool] = field(default_factory=dict)

    def __post_init__(self):
        """Derive hash, completeness and validation flags from the inputs."""
        # Order matters: _validate_data reads completeness_score.
        self.data_hash = self._calculate_hash()
        self.completeness_score = self._calculate_completeness()
        self.validation_flags = self._validate_data()

    def _calculate_hash(self) -> str:
        """Return an MD5 hex digest fingerprinting the package's structure.

        Only the timestamp, symbol, the lengths of ``ohlcv_data`` and
        ``tick_data``, and the shapes of ``cnn_features`` / ``rl_state`` are
        hashed. Changes to the data values themselves are NOT detected.
        Returns ``"invalid_hash"`` if any of those attributes cannot be read.
        """
        try:
            parts = [
                self.timestamp,
                self.symbol,
                len(self.ohlcv_data),
                len(self.tick_data),
                'None' if self.cnn_features is None else self.cnn_features.shape,
                'None' if self.rl_state is None else self.rl_state.shape,
            ]
            fingerprint = "_".join(str(part) for part in parts)
            return hashlib.md5(fingerprint.encode()).hexdigest()
        except Exception as e:
            logger.warning(f"Error calculating data hash: {e}")
            return "invalid_hash"

    def _calculate_completeness(self) -> float:
        """Return the fraction (0.0-1.0) of the ten tracked inputs that are present.

        Containers count when non-empty, arrays when not None with size > 0,
        and the two prediction fields when not None. ``orchestrator_decision``
        is not counted. Returns 0.0 if any check raises.
        """
        try:
            def has_items(value) -> bool:
                return bool(value) and len(value) > 0

            def has_array(value) -> bool:
                return value is not None and value.size > 0

            checks = (
                has_items(self.ohlcv_data),
                has_items(self.tick_data),
                has_items(self.cob_data),
                has_items(self.technical_indicators),
                has_items(self.pivot_points),
                has_array(self.cnn_features),
                has_array(self.rl_state),
                has_items(self.orchestrator_context),
                self.cnn_predictions is not None,
                self.rl_predictions is not None,
            )
            return sum(checks) / len(checks)
        except Exception as e:
            logger.warning(f"Error calculating completeness: {e}")
            return 0.0

    def _validate_data(self) -> Dict[str, bool]:
        """Return per-check validity flags for the package.

        Keys: ``valid_timestamp``, ``valid_ohlcv``, ``valid_cnn_features``,
        ``valid_rl_state`` and ``data_consistent`` (completeness > 0.7). If a
        check raises, the flags gathered so far are returned together with
        ``validation_error=True``.
        """
        flags: Dict[str, bool] = {}
        try:
            flags['valid_timestamp'] = isinstance(self.timestamp, datetime)

            frames = self.ohlcv_data
            flags['valid_ohlcv'] = (
                frames is not None
                and len(frames) > 0
                and all(isinstance(frame, pd.DataFrame) for frame in frames.values())
            )

            for flag_name, array in (('valid_cnn_features', self.cnn_features),
                                     ('valid_rl_state', self.rl_state)):
                flags[flag_name] = (
                    array is not None
                    and isinstance(array, np.ndarray)
                    and array.size > 0
                )

            flags['data_consistent'] = self.completeness_score > 0.7
        except Exception as e:
            logger.warning(f"Error validating data: {e}")
            flags['validation_error'] = True
        return flags
@dataclass
class TrainingOutcome:
    """Future outcome validation for training data.

    Records what happened after a ModelInputPackage was captured. It is linked
    to that package through ``input_package_hash``. Field order is the
    positional constructor order, so do not reorder fields.
    """
    input_package_hash: str  # expected to be the data_hash of the originating ModelInputPackage
    timestamp: datetime  # capture time of the originating input package
    symbol: str
    # Price movement outcomes per horizon (units not fixed here -- TODO confirm % vs. fraction)
    price_change_1m: float
    price_change_5m: float
    price_change_15m: float
    price_change_1h: float
    # Profitability metrics
    max_profit_potential: float
    max_loss_potential: float
    optimal_entry_price: float
    optimal_exit_price: float
    optimal_holding_time: timedelta
    # Classification labels
    is_profitable: bool
    profitability_score: float  # 0.0 to 1.0
    risk_reward_ratio: float
    # Rapid price change detection
    is_rapid_change: bool
    change_velocity: float  # Price change per minute (RapidChangeDetector reports absolute % per minute)
    volatility_spike: bool
    # Validation
    outcome_validated: bool = False
    validation_timestamp: datetime = field(default_factory=datetime.now)  # defaults to construction time
@dataclass
class TrainingEpisode:
    """One replayable training example: inputs, model predictions and the realised outcome."""
    episode_id: str
    input_package: ModelInputPackage
    model_predictions: Dict[str, Any]  # predictions from all models
    actual_outcome: TrainingOutcome
    # --- Training metadata ---
    episode_type: str  # 'normal', 'rapid_change' or 'high_profit'
    profitability_rank: float  # ranking among all episodes
    training_priority: float  # replay priority, see calculate_training_priority()
    # --- Backpropagation data storage ---
    gradient_data: Optional[Dict[str, torch.Tensor]] = None
    loss_components: Optional[Dict[str, float]] = None
    model_states: Optional[Dict[str, Any]] = None
    # --- Episode statistics ---
    created_timestamp: datetime = field(default_factory=datetime.now)
    last_trained_timestamp: Optional[datetime] = None
    training_count: int = 0

    def calculate_training_priority(self) -> float:
        """Score this episode for replay training, capped at 1.0.

        The score is the sum of:
          * 0.4 x profitability_score, for profitable outcomes only
          * 0.3 for a rapid price change (high learning value)
          * 0.2 when risk_reward_ratio > 2.0
          * 0.1 x the input package's completeness_score
        The total is scaled by 0.8 once the episode has been trained more than
        5 times, to discourage overfitting. Returns 0.0 if scoring fails.
        """
        try:
            outcome = self.actual_outcome
            terms = []
            if outcome.is_profitable:
                terms.append(outcome.profitability_score * 0.4)
            if outcome.is_rapid_change:
                terms.append(0.3)
            if outcome.risk_reward_ratio > 2.0:
                terms.append(0.2)
            terms.append(self.input_package.completeness_score * 0.1)

            # Accumulate left-to-right; a plain loop keeps float rounding
            # identical to sequential `+=` (sum() may use compensated summation).
            score = 0.0
            for term in terms:
                score += term

            if self.training_count > 5:
                score *= 0.8
            return min(score, 1.0)
        except Exception as e:
            logger.warning(f"Error calculating training priority: {e}")
            return 0.0
class RapidChangeDetector:
    """Detect fast price moves and volatility spikes from a rolling price history.

    Each symbol keeps a bounded deque of ``(timestamp, price)`` tuples holding
    ``lookback_minutes * 60`` samples. That sizing, and the 60/120/300-sample
    windows used below, assume roughly one sample per second. The code does
    not enforce any sampling rate.
    """

    def __init__(self,
                 velocity_threshold: float = 0.5,  # % per minute
                 volatility_multiplier: float = 3.0,
                 lookback_minutes: int = 5):
        self.velocity_threshold = velocity_threshold
        self.volatility_multiplier = volatility_multiplier
        self.lookback_minutes = lookback_minutes
        # symbol -> deque of (timestamp, price)
        self.price_history: Dict[str, deque] = {}
        # symbol -> EMA of the long-window volatility (0.0 until enough data)
        self.volatility_baseline: Dict[str, float] = {}

    def add_price_point(self, symbol: str, timestamp: datetime, price: float):
        """Record a price sample for ``symbol`` and refresh its volatility baseline."""
        is_new_symbol = symbol not in self.price_history
        if is_new_symbol:
            self.price_history[symbol] = deque(maxlen=self.lookback_minutes * 60)
            self.volatility_baseline[symbol] = 0.0
        self.price_history[symbol].append((timestamp, price))
        self._update_volatility_baseline(symbol)

    def detect_rapid_change(self, symbol: str) -> Tuple[bool, float, bool]:
        """Classify the most recent minute of prices for ``symbol``.

        Returns:
            (is_rapid_change, change_velocity, volatility_spike). The velocity
            is the absolute % change between the first and last of the latest
            60 samples, divided by their time span in minutes. A result of
            ``(False, 0.0, False)`` is returned when fewer than 60 samples
            exist, the span is not positive, or an error occurs.
        """
        if symbol not in self.price_history or len(self.price_history[symbol]) < 60:
            return False, 0.0, False

        try:
            window = list(self.price_history[symbol])[-60:]
            if len(window) < 2:
                return False, 0.0, False

            first, last = window[0], window[-1]
            open_price = first[1]
            close_price = last[1]
            span_minutes = (last[0] - first[0]).total_seconds() / 60.0
            if span_minutes <= 0:
                return False, 0.0, False

            velocity = abs((close_price - open_price) / open_price * 100) / span_minutes
            rapid = velocity > self.velocity_threshold

            current = self._calculate_current_volatility(symbol)
            baseline = self.volatility_baseline.get(symbol, 0.0)
            spike = baseline > 0 and current > baseline * self.volatility_multiplier

            return rapid, velocity, spike
        except Exception as e:
            logger.warning(f"Error detecting rapid change for {symbol}: {e}")
            return False, 0.0, False

    def _update_volatility_baseline(self, symbol: str):
        """Fold the latest long-window volatility into the symbol's EMA baseline.

        Requires at least 120 samples. The volatility is 100 x the std of the
        absolute relative changes over the last 300 samples. The first non-zero
        value seeds the EMA, and alpha = 0.1 applies thereafter.
        """
        try:
            history = self.price_history[symbol]
            if len(history) < 120:
                return

            closes = [entry[1] for entry in list(history)[-300:]]
            if len(closes) < 2:
                return

            moves = [abs(now - before) / before for before, now in zip(closes, closes[1:])]
            volatility = np.std(moves) * 100

            alpha = 0.1
            previous = self.volatility_baseline[symbol]
            if previous == 0:
                self.volatility_baseline[symbol] = volatility
            else:
                self.volatility_baseline[symbol] = alpha * volatility + (1 - alpha) * previous
        except Exception as e:
            logger.warning(f"Error updating volatility baseline for {symbol}: {e}")

    def _calculate_current_volatility(self, symbol: str) -> float:
        """Return 100 x the std of absolute relative changes over the last 60 samples.

        Returns 0.0 when fewer than 60 samples exist or on error.
        """
        try:
            history = self.price_history[symbol]
            if len(history) < 60:
                return 0.0

            closes = [entry[1] for entry in list(history)[-60:]]
            if len(closes) < 2:
                return 0.0

            moves = [abs(now - before) / before for before, now in zip(closes, closes[1:])]
            return np.std(moves) * 100
        except Exception as e:
            logger.warning(f"Error calculating current volatility for {symbol}: {e}")
            return 0.0
class TrainingDataCollector:
    """Main training data collection system.

    Each ``collect_training_data`` call queues a ``ModelInputPackage`` in
    ``pending_outcomes``. A background worker (``start_collection``) turns
    packages older than ``outcome_validation_delay`` into ``TrainingEpisode``
    objects. It keeps them per symbol in ``training_episodes`` and pickles
    each one under ``storage_dir``. Shared state is guarded by ``data_lock``.
    """

    def __init__(self,
                 storage_dir: str = "training_data",
                 max_episodes_per_symbol: int = 10000,
                 outcome_validation_delay: timedelta = timedelta(hours=1)):
        """Create the storage directory and initialise in-memory state.

        Args:
            storage_dir: Directory where episodes are pickled (created if missing).
            max_episodes_per_symbol: In-memory cap per symbol. Beyond it the
                lowest-priority episodes are dropped from memory; files already
                written to disk are left in place.
            outcome_validation_delay: Minimum age of a package before its
                outcome is computed.
        """
        self.storage_dir = Path(storage_dir)
        self.storage_dir.mkdir(parents=True, exist_ok=True)
        self.max_episodes_per_symbol = max_episodes_per_symbol
        self.outcome_validation_delay = outcome_validation_delay

        # Data storage
        self.training_episodes: Dict[str, List[TrainingEpisode]] = {}  # {symbol: episodes}
        self.pending_outcomes: Dict[str, List[ModelInputPackage]] = {}  # Awaiting outcome validation
        # model_predictions passed to collect_training_data, keyed by id() of the
        # pending package. The value also holds the package itself, so that id
        # cannot be recycled while the entry exists.
        self._pending_predictions: Dict[int, Tuple[ModelInputPackage, Dict[str, Any]]] = {}

        # Rapid change detection
        self.rapid_change_detector = RapidChangeDetector()

        # Data validation and statistics
        self.collection_stats = {
            'total_episodes': 0,
            'profitable_episodes': 0,
            'rapid_change_episodes': 0,
            'validation_errors': 0,
            'data_completeness_avg': 0.0
        }

        # Background processing
        self.is_collecting = False
        self.collection_thread = None
        self.outcome_validation_thread = None
        # Set by stop_collection() so the worker wakes from its sleep immediately.
        self._stop_event = threading.Event()

        # Thread safety
        self.data_lock = threading.Lock()

        logger.info("Training Data Collector initialized")
        logger.info(f"Storage directory: {self.storage_dir}")
        logger.info(f"Max episodes per symbol: {self.max_episodes_per_symbol}")

    def start_collection(self):
        """Start the background outcome-validation worker (no-op if already running)."""
        if self.is_collecting:
            logger.warning("Training data collection already running")
            return

        self.is_collecting = True
        self._stop_event.clear()

        # Start outcome validation thread
        self.outcome_validation_thread = threading.Thread(
            target=self._outcome_validation_worker,
            daemon=True
        )
        self.outcome_validation_thread.start()

        logger.info("Training data collection started")

    def stop_collection(self):
        """Signal the worker to stop and wait up to 5 seconds for it to exit."""
        self.is_collecting = False
        # Interrupt the worker's inter-pass sleep so join() can actually succeed.
        self._stop_event.set()

        if self.outcome_validation_thread:
            self.outcome_validation_thread.join(timeout=5)

        logger.info("Training data collection stopped")

    @staticmethod
    def _make_episode_id(package: ModelInputPackage) -> str:
        """Build an episode id of the form ``<symbol>_<YYYYmmdd_HHMMSS>_<hash[:8]>``."""
        return (f"{package.symbol}_{package.timestamp.strftime('%Y%m%d_%H%M%S')}"
                f"_{package.data_hash[:8]}")

    def collect_training_data(self,
                              symbol: str,
                              ohlcv_data: Dict[str, pd.DataFrame],
                              tick_data: List[Dict[str, Any]],
                              cob_data: Dict[str, Any],
                              technical_indicators: Dict[str, float],
                              pivot_points: List[Dict[str, Any]],
                              cnn_features: np.ndarray,
                              rl_state: np.ndarray,
                              orchestrator_context: Dict[str, Any],
                              model_predictions: Optional[Dict[str, Any]] = None) -> Optional[str]:
        """Package the current model inputs and queue them for outcome validation.

        ``model_predictions``, if given, is attached to the resulting
        ``TrainingEpisode.model_predictions`` once the outcome is validated.

        Returns:
            The episode id for tracking, or None when the package's
            completeness is below 0.5 or an error occurs (both are counted
            in ``collection_stats['validation_errors']``).
        """
        try:
            input_package = ModelInputPackage(
                timestamp=datetime.now(),
                symbol=symbol,
                ohlcv_data=ohlcv_data,
                tick_data=tick_data,
                cob_data=cob_data,
                technical_indicators=technical_indicators,
                pivot_points=pivot_points,
                cnn_features=cnn_features,
                rl_state=rl_state,
                orchestrator_context=orchestrator_context
            )

            # Validate data completeness
            if input_package.completeness_score < 0.5:
                logger.warning(f"Low data completeness for {symbol}: {input_package.completeness_score:.2f}")
                with self.data_lock:
                    self.collection_stats['validation_errors'] += 1
                return None

            current_price = self._extract_current_price(ohlcv_data)
            episode_id = self._make_episode_id(input_package)

            with self.data_lock:
                # The detector is also read by the validation worker under this lock.
                if current_price:
                    self.rapid_change_detector.add_price_point(symbol, input_package.timestamp, current_price)

                pending = self.pending_outcomes.setdefault(symbol, [])
                pending.append(input_package)
                if model_predictions is not None:
                    self._pending_predictions[id(input_package)] = (input_package, model_predictions)

                # Limit pending outcomes to prevent memory issues
                if len(pending) > 1000:
                    for dropped in pending[:-500]:
                        self._pending_predictions.pop(id(dropped), None)
                    self.pending_outcomes[symbol] = pending[-500:]

                # Running mean of completeness over all accepted packages
                stats = self.collection_stats
                stats['total_episodes'] += 1
                total = stats['total_episodes']
                stats['data_completeness_avg'] = (
                    (stats['data_completeness_avg'] * (total - 1) +
                     input_package.completeness_score) / total
                )

            logger.debug(f"Collected training data for {symbol}: {episode_id}")
            logger.debug(f"Data completeness: {input_package.completeness_score:.2f}")

            return episode_id

        except Exception as e:
            logger.error(f"Error collecting training data for {symbol}: {e}")
            with self.data_lock:
                self.collection_stats['validation_errors'] += 1
            return None

    def _extract_current_price(self, ohlcv_data: Dict[str, pd.DataFrame]) -> Optional[float]:
        """Return the last close from the shortest usable timeframe, or None.

        Timeframes are tried in the order 1s, 1m, 5m, 15m, 1h. Frames that are
        missing, empty or lack a ``close`` column are skipped.
        """
        try:
            if not ohlcv_data:
                return None
            for timeframe in ('1s', '1m', '5m', '15m', '1h'):
                frame = ohlcv_data.get(timeframe)
                if frame is None or frame.empty or 'close' not in frame.columns:
                    continue
                return float(frame['close'].iloc[-1])
            return None
        except Exception as e:
            logger.warning(f"Error extracting current price: {e}")
            return None

    def _outcome_validation_worker(self):
        """Background loop: validate pending outcomes every 60 s (30 s after an error)."""
        logger.info("Outcome validation worker started")

        while self.is_collecting:
            try:
                self._validate_pending_outcomes()
                delay = 60  # Check every minute
            except Exception as e:
                logger.error(f"Error in outcome validation worker: {e}")
                delay = 30  # Wait before retrying
            # Sleep on the shared stop event so stop_collection() can cut the wait short.
            if self._stop_event.wait(delay):
                break

        logger.info("Outcome validation worker stopped")

    def _validate_pending_outcomes(self):
        """Turn every pending package older than the validation delay into an episode.

        Packages whose outcome cannot be computed stay pending and are retried
        on the next pass.
        """
        current_time = datetime.now()

        # NOTE(review): episodes are pickled to disk while data_lock is held,
        # which blocks collect_training_data for the duration of the I/O.
        with self.data_lock:
            for symbol, packages in list(self.pending_outcomes.items()):
                validated_count = 0
                remaining_packages = []

                for package in packages:
                    # Check if enough time has passed for outcome validation
                    if current_time - package.timestamp < self.outcome_validation_delay:
                        remaining_packages.append(package)
                        continue

                    outcome = self._calculate_training_outcome(package)
                    if outcome is None:
                        remaining_packages.append(package)
                        continue

                    entry = self._pending_predictions.pop(id(package), None)
                    predictions = entry[1] if entry is not None else None
                    self._create_training_episode(package, outcome, predictions)
                    validated_count += 1

                self.pending_outcomes[symbol] = remaining_packages

                if validated_count:
                    logger.info(f"Validated {validated_count} outcomes for {symbol}")

    def _calculate_training_outcome(self, input_package: ModelInputPackage) -> Optional[TrainingOutcome]:
        """Build the TrainingOutcome for a package (placeholder implementation).

        Price-change and profitability fields are placeholders (0.0 / False /
        base price) until real future prices are wired in. Returns None if no
        base price can be extracted or an error occurs.
        """
        try:
            # This would typically fetch recent price data to calculate outcomes
            base_price = self._extract_current_price(input_package.ohlcv_data)
            if not base_price:
                return None

            # NOTE(review): this reads the detector's state at validation time
            # (outcome_validation_delay after capture), not the window around
            # input_package.timestamp.
            is_rapid, velocity, volatility_spike = self.rapid_change_detector.detect_rapid_change(
                input_package.symbol
            )

            # Create outcome (placeholder values - replace with actual calculation)
            return TrainingOutcome(
                input_package_hash=input_package.data_hash,
                timestamp=input_package.timestamp,
                symbol=input_package.symbol,
                price_change_1m=0.0,  # Calculate from actual future data
                price_change_5m=0.0,
                price_change_15m=0.0,
                price_change_1h=0.0,
                max_profit_potential=0.0,
                max_loss_potential=0.0,
                optimal_entry_price=base_price,
                optimal_exit_price=base_price,
                optimal_holding_time=timedelta(minutes=5),
                is_profitable=False,  # Determine from actual outcomes
                profitability_score=0.0,
                risk_reward_ratio=1.0,
                is_rapid_change=is_rapid,
                change_velocity=velocity,
                volatility_spike=volatility_spike,
                outcome_validated=True
            )

        except Exception as e:
            logger.error(f"Error calculating training outcome: {e}")
            return None

    def _create_training_episode(self,
                                 input_package: ModelInputPackage,
                                 outcome: TrainingOutcome,
                                 model_predictions: Optional[Dict[str, Any]] = None):
        """Create, store and persist a TrainingEpisode; callers hold data_lock.

        Episode type: 'rapid_change' if the outcome is a rapid change, else
        'high_profit' if profitability_score > 0.8, else 'normal'. When a
        symbol exceeds ``max_episodes_per_symbol``, only the highest-priority
        episodes are kept in memory.
        """
        try:
            episode_id = self._make_episode_id(input_package)

            episode_type = 'normal'
            if outcome.is_rapid_change:
                episode_type = 'rapid_change'
                self.collection_stats['rapid_change_episodes'] += 1
            elif outcome.profitability_score > 0.8:
                episode_type = 'high_profit'

            if outcome.is_profitable:
                self.collection_stats['profitable_episodes'] += 1

            episode = TrainingEpisode(
                episode_id=episode_id,
                input_package=input_package,
                model_predictions=model_predictions if model_predictions is not None else {},
                actual_outcome=outcome,
                episode_type=episode_type,
                profitability_rank=0.0,  # Will be calculated later
                training_priority=0.0
            )
            episode.training_priority = episode.calculate_training_priority()

            episodes = self.training_episodes.setdefault(input_package.symbol, [])
            episodes.append(episode)

            # Limit episodes per symbol, keeping the highest-priority ones
            if len(episodes) > self.max_episodes_per_symbol:
                episodes.sort(key=lambda ep: ep.training_priority, reverse=True)
                del episodes[self.max_episodes_per_symbol:]

            self._save_episode_to_disk(episode)

            logger.debug(f"Created training episode: {episode_id}")
            logger.debug(f"Episode type: {episode_type}, Priority: {episode.training_priority:.3f}")

        except Exception as e:
            logger.error(f"Error creating training episode: {e}")

    def _save_episode_to_disk(self, episode: TrainingEpisode):
        """Pickle the episode and write a JSON metadata sidecar under storage_dir/<symbol>/."""
        try:
            symbol_dir = self.storage_dir / episode.input_package.symbol
            symbol_dir.mkdir(parents=True, exist_ok=True)

            # Full episode (only unpickle files from trusted locations)
            episode_file = symbol_dir / f"{episode.episode_id}.pkl"
            with open(episode_file, 'wb') as f:
                pickle.dump(episode, f)

            # Episode metadata for quick access without unpickling
            metadata = {
                'episode_id': episode.episode_id,
                'timestamp': episode.input_package.timestamp.isoformat(),
                'episode_type': episode.episode_type,
                'training_priority': episode.training_priority,
                'profitability_score': episode.actual_outcome.profitability_score,
                'is_profitable': episode.actual_outcome.is_profitable,
                'is_rapid_change': episode.actual_outcome.is_rapid_change,
                'data_completeness': episode.input_package.completeness_score
            }

            metadata_file = symbol_dir / f"{episode.episode_id}_metadata.json"
            with open(metadata_file, 'w', encoding='utf-8') as f:
                json.dump(metadata, f, indent=2)

        except Exception as e:
            logger.error(f"Error saving episode to disk: {e}")

    def get_high_priority_episodes(self,
                                   symbol: str,
                                   limit: int = 100,
                                   min_priority: float = 0.5) -> List[TrainingEpisode]:
        """Return up to ``limit`` episodes with priority >= ``min_priority``, highest first."""
        try:
            # Snapshot under the lock: the worker appends to and re-sorts these lists.
            with self.data_lock:
                episodes = list(self.training_episodes.get(symbol, ()))

            high_priority = sorted(
                (ep for ep in episodes if ep.training_priority >= min_priority),
                key=lambda ep: ep.training_priority,
                reverse=True,
            )
            return high_priority[:limit]

        except Exception as e:
            logger.error(f"Error getting high priority episodes for {symbol}: {e}")
            return []

    def get_collection_statistics(self) -> Dict[str, Any]:
        """Return a copy of collection_stats plus per-symbol counts and rates.

        ``profitability_rate`` and ``rapid_change_rate`` are relative to
        ``total_episodes`` (all accepted packages, including still-pending ones).
        """
        with self.data_lock:
            stats = self.collection_stats.copy()
            stats['episodes_per_symbol'] = {
                symbol: len(episodes)
                for symbol, episodes in self.training_episodes.items()
            }
            stats['pending_outcomes'] = {
                symbol: len(packages)
                for symbol, packages in self.pending_outcomes.items()
            }

        total = stats['total_episodes']
        if total > 0:
            stats['profitability_rate'] = stats['profitable_episodes'] / total
            stats['rapid_change_rate'] = stats['rapid_change_episodes'] / total
        else:
            stats['profitability_rate'] = 0.0
            stats['rapid_change_rate'] = 0.0

        return stats

    def validate_data_integrity(self) -> Dict[str, Any]:
        """Check stored episodes for hash mismatches and completeness/consistency issues.

        ``integrity_score`` is 1 minus the fraction of checked episodes that
        have at least one issue, so it stays within [0.0, 1.0]. The number of
        such episodes is reported as ``episodes_with_issues``.
        """
        validation_results = {
            'total_episodes_checked': 0,
            'hash_mismatches': 0,
            'completeness_issues': 0,
            'validation_flag_failures': 0,
            'corrupted_episodes': [],
            'episodes_with_issues': 0,
            'integrity_score': 1.0
        }

        try:
            with self.data_lock:
                episodes = [ep for eps in self.training_episodes.values() for ep in eps]

            for episode in episodes:
                validation_results['total_episodes_checked'] += 1
                package = episode.input_package
                has_issue = False

                if package._calculate_hash() != package.data_hash:
                    validation_results['hash_mismatches'] += 1
                    validation_results['corrupted_episodes'].append(episode.episode_id)
                    has_issue = True

                if package.completeness_score < 0.7:
                    validation_results['completeness_issues'] += 1
                    has_issue = True

                if not package.validation_flags.get('data_consistent', False):
                    validation_results['validation_flag_failures'] += 1
                    has_issue = True

                if has_issue:
                    validation_results['episodes_with_issues'] += 1

            # A single episode can trip several checks; count it once.
            checked = validation_results['total_episodes_checked']
            if checked > 0:
                validation_results['integrity_score'] = 1.0 - (
                    validation_results['episodes_with_issues'] / checked
                )

            logger.info("Data integrity validation completed")
            logger.info(f"Integrity score: {validation_results['integrity_score']:.3f}")

        except Exception as e:
            logger.error(f"Error during data integrity validation: {e}")
            validation_results['validation_error'] = str(e)

        return validation_results
# Process-wide collector, created lazily by get_training_data_collector().
training_data_collector = None


def get_training_data_collector() -> TrainingDataCollector:
    """Return the shared TrainingDataCollector, constructing it with defaults on first use."""
    global training_data_collector
    if training_data_collector is not None:
        return training_data_collector
    training_data_collector = TrainingDataCollector()
    return training_data_collector