# Source listing metadata: 795 lines, 32 KiB, Python
"""
Comprehensive Training Data Collection System

This module implements a robust training data collection system that:
1. Captures all model inputs with validation and completeness checks
2. Stores training data packages with future outcome validation
3. Detects rapid price changes for high-value training examples
4. Enables replay and retraining on the most profitable setups
5. Maintains data integrity and traceability

Key Features:
- Real-time data package creation with all model inputs
- Future outcome validation (profitable vs unprofitable predictions)
- Rapid price change detection for premium training examples
- Comprehensive data validation and completeness verification
- Backpropagation data storage for gradient replay
- Training episode profitability tracking and ranking
"""
|
|
|
|
import asyncio
import hashlib
import json
import logging
import pickle
import re
import threading
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field, asdict
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any, Callable

import numpy as np
import pandas as pd
import torch
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
@dataclass
class ModelInputPackage:
    """Snapshot of every model input gathered for one symbol at one moment.

    The last three fields (``data_hash``, ``completeness_score`` and
    ``validation_flags``) are derived: ``__post_init__`` always recomputes
    them, so any values passed to the constructor for them are overwritten.
    """
    timestamp: datetime
    symbol: str

    # --- market data ---
    ohlcv_data: Dict[str, pd.DataFrame]      # {timeframe: DataFrame}
    tick_data: List[Dict[str, Any]]          # raw tick data
    cob_data: Dict[str, Any]                 # consolidated order book data
    technical_indicators: Dict[str, float]   # all technical indicators
    pivot_points: List[Dict[str, Any]]       # detected pivot points

    # --- model-specific inputs ---
    cnn_features: np.ndarray                 # CNN input features
    rl_state: np.ndarray                     # RL state representation
    orchestrator_context: Dict[str, Any]     # orchestrator context

    # --- cross-model inputs (outputs of other models) ---
    cnn_predictions: Optional[Dict[str, Any]] = None
    rl_predictions: Optional[Dict[str, Any]] = None
    orchestrator_decision: Optional[Dict[str, Any]] = None

    # --- derived integrity data, filled in by __post_init__ ---
    data_hash: str = ""
    completeness_score: float = 0.0
    validation_flags: Dict[str, bool] = field(default_factory=dict)

    def __post_init__(self):
        """Derive hash, completeness and validation flags from the inputs.

        The order matters: the validation flags read ``completeness_score``.
        """
        self.data_hash = self._calculate_hash()
        self.completeness_score = self._calculate_completeness()
        self.validation_flags = self._validate_data()

    def _calculate_hash(self) -> str:
        """Return an MD5 fingerprint of the package's coarse structure.

        Only the timestamp, symbol, container lengths and array shapes feed
        the digest -- not the data values themselves.

        Returns:
            The hex digest, or ``"invalid_hash"`` if any input could not be
            inspected.
        """
        try:
            components = (
                self.timestamp,
                self.symbol,
                len(self.ohlcv_data),
                len(self.tick_data),
                self.cnn_features.shape if self.cnn_features is not None else 'None',
                self.rl_state.shape if self.rl_state is not None else 'None',
            )
            fingerprint = "_".join(f"{component}" for component in components)
            return hashlib.md5(fingerprint.encode()).hexdigest()
        except Exception as e:
            logger.warning(f"Error calculating data hash: {e}")
            return "invalid_hash"

    def _calculate_completeness(self) -> float:
        """Return the fraction (0.0 to 1.0) of the ten tracked inputs present.

        Containers and arrays count when non-empty; the two prediction
        fields count when they are not ``None``. Any error while inspecting
        the inputs yields ``0.0``.
        """
        try:
            presence = (
                self.ohlcv_data and len(self.ohlcv_data) > 0,
                self.tick_data and len(self.tick_data) > 0,
                self.cob_data and len(self.cob_data) > 0,
                self.technical_indicators and len(self.technical_indicators) > 0,
                self.pivot_points and len(self.pivot_points) > 0,
                self.cnn_features is not None and self.cnn_features.size > 0,
                self.rl_state is not None and self.rl_state.size > 0,
                self.orchestrator_context and len(self.orchestrator_context) > 0,
                self.cnn_predictions is not None,
                self.rl_predictions is not None,
            )
            filled = sum(1 for present in presence if present)
            return filled / len(presence)
        except Exception as e:
            logger.warning(f"Error calculating completeness: {e}")
            return 0.0

    def _validate_data(self) -> Dict[str, bool]:
        """Return named pass/fail checks for the package.

        If a check raises, the flags gathered so far are kept and a
        ``'validation_error'`` entry set to ``True`` is added.
        """
        def _usable_array(candidate) -> bool:
            # A feature array is usable when it is a non-empty ndarray.
            return (
                candidate is not None
                and isinstance(candidate, np.ndarray)
                and candidate.size > 0
            )

        results = {}
        try:
            results['valid_timestamp'] = isinstance(self.timestamp, datetime)

            frames = self.ohlcv_data
            results['valid_ohlcv'] = (
                frames is not None
                and len(frames) > 0
                and all(isinstance(frame, pd.DataFrame) for frame in frames.values())
            )

            results['valid_cnn_features'] = _usable_array(self.cnn_features)
            results['valid_rl_state'] = _usable_array(self.rl_state)

            # "Consistent" here means: more than 70% of the inputs are present.
            results['data_consistent'] = self.completeness_score > 0.7
        except Exception as e:
            logger.warning(f"Error validating data: {e}")
            results['validation_error'] = True

        return results
|
|
|
|
@dataclass
class TrainingOutcome:
    """What happened after an input package was captured.

    Links back to its ``ModelInputPackage`` through ``input_package_hash``
    and carries the labels later used to rank and classify the episode.
    """
    input_package_hash: str
    timestamp: datetime
    symbol: str

    # --- price movement over fixed horizons ---
    price_change_1m: float
    price_change_5m: float
    price_change_15m: float
    price_change_1h: float

    # --- profitability metrics ---
    max_profit_potential: float
    max_loss_potential: float
    optimal_entry_price: float
    optimal_exit_price: float
    optimal_holding_time: timedelta

    # --- classification labels ---
    is_profitable: bool
    profitability_score: float  # 0.0 to 1.0
    risk_reward_ratio: float

    # --- rapid price change detection ---
    is_rapid_change: bool
    change_velocity: float  # price change per minute
    volatility_spike: bool

    # --- validation bookkeeping ---
    outcome_validated: bool = False
    # Defaults to the moment the outcome object is created.
    validation_timestamp: datetime = field(default_factory=datetime.now)
|
|
|
|
@dataclass
class TrainingEpisode:
    """One replayable training sample: inputs, model predictions and outcome."""
    episode_id: str
    input_package: ModelInputPackage
    model_predictions: Dict[str, Any]  # predictions from all models
    actual_outcome: TrainingOutcome

    # --- training metadata ---
    episode_type: str          # 'normal', 'rapid_change', 'high_profit'
    profitability_rank: float  # ranking among all episodes
    training_priority: float   # priority for replay training

    # --- backpropagation data storage ---
    gradient_data: Optional[Dict[str, torch.Tensor]] = None
    loss_components: Optional[Dict[str, float]] = None
    model_states: Optional[Dict[str, Any]] = None

    # --- episode statistics ---
    created_timestamp: datetime = field(default_factory=datetime.now)
    last_trained_timestamp: Optional[datetime] = None
    training_count: int = 0

    def calculate_training_priority(self) -> float:
        """Score how valuable this episode is for replay training.

        Weights: up to 0.4 from the profitability score (profitable episodes
        only), 0.3 for a rapid change, 0.2 for a risk/reward ratio above 2.0
        and up to 0.1 from input completeness. Episodes already trained on
        more than five times are scaled by 0.8 to limit overfitting.

        Returns:
            The priority, capped at 1.0, or 0.0 if the calculation fails.
        """
        try:
            outcome = self.actual_outcome
            score = 0.0

            if outcome.is_profitable:
                score += outcome.profitability_score * 0.4

            if outcome.is_rapid_change:
                score += 0.3

            if outcome.risk_reward_ratio > 2.0:
                score += 0.2

            score += self.input_package.completeness_score * 0.1

            if self.training_count > 5:
                score *= 0.8

            return min(score, 1.0)
        except Exception as e:
            logger.warning(f"Error calculating training priority: {e}")
            return 0.0
|
|
|
|
class RapidChangeDetector:
    """Flags fast price moves and volatility spikes from a rolling price history.

    Windows are expressed in numbers of samples (60, 120, 300); the inline
    comments of the original describe them as seconds, i.e. they assume one
    price point per second -- confirm against the caller's sampling rate.
    """

    def __init__(self,
                 velocity_threshold: float = 0.5,  # % per minute
                 volatility_multiplier: float = 3.0,
                 lookback_minutes: int = 5):
        """Store the thresholds and start with empty per-symbol state.

        Args:
            velocity_threshold: Velocity (% per minute) above which a move
                counts as rapid.
            volatility_multiplier: Factor by which current volatility must
                exceed the baseline to count as a spike.
            lookback_minutes: Sizes each symbol's history buffer at
                ``lookback_minutes * 60`` points.
        """
        self.velocity_threshold = velocity_threshold
        self.volatility_multiplier = volatility_multiplier
        self.lookback_minutes = lookback_minutes

        # Per-symbol rolling (timestamp, price) history and volatility baseline
        self.price_history: Dict[str, deque] = {}
        self.volatility_baseline: Dict[str, float] = {}

    def add_price_point(self, symbol: str, timestamp: datetime, price: float):
        """Append one observation for ``symbol`` and refresh its baseline."""
        history = self.price_history.get(symbol)
        if history is None:
            # First sighting of this symbol: bounded buffer, zero baseline.
            history = deque(maxlen=self.lookback_minutes * 60)
            self.price_history[symbol] = history
            self.volatility_baseline[symbol] = 0.0

        history.append((timestamp, price))
        self._update_volatility_baseline(symbol)

    def detect_rapid_change(self, symbol: str) -> Tuple[bool, float, bool]:
        """Evaluate the most recent 60 points of ``symbol``.

        Returns:
            ``(is_rapid_change, change_velocity, volatility_spike)``.
            ``(False, 0.0, False)`` is returned when fewer than 60 points
            are stored, when the window spans no time, or on any error.
        """
        nothing_detected = (False, 0.0, False)

        history = self.price_history.get(symbol)
        if history is None or len(history) < 60:
            return nothing_detected

        try:
            window = list(history)[-60:]
            if len(window) < 2:
                return nothing_detected

            (first_time, first_price), (last_time, last_price) = window[0], window[-1]
            elapsed_minutes = (last_time - first_time).total_seconds() / 60.0
            if elapsed_minutes <= 0:
                return nothing_detected

            # Absolute % change across the window, per minute
            velocity = abs((last_price - first_price) / first_price * 100) / elapsed_minutes
            is_rapid = velocity > self.velocity_threshold

            # A spike needs an established (non-zero) baseline to compare with
            current_volatility = self._calculate_current_volatility(symbol)
            baseline_volatility = self.volatility_baseline.get(symbol, 0.0)
            volatility_spike = (
                baseline_volatility > 0 and
                current_volatility > baseline_volatility * self.volatility_multiplier
            )

            return is_rapid, velocity, volatility_spike
        except Exception as e:
            logger.warning(f"Error detecting rapid change for {symbol}: {e}")
            return nothing_detected

    @staticmethod
    def _relative_move_std(prices):
        """Std-dev of absolute point-to-point relative changes, in percent."""
        moves = [abs(current - previous) / previous
                 for previous, current in zip(prices, prices[1:])]
        return np.std(moves) * 100

    def _update_volatility_baseline(self, symbol: str):
        """Blend the latest volatility into the symbol's baseline.

        Does nothing until 120 points are stored. The first measurement
        seeds the baseline; later ones are mixed in with weight 0.1.
        """
        try:
            history = self.price_history[symbol]
            if len(history) < 120:
                return

            prices = [point[1] for point in list(history)[-300:]]
            if len(prices) < 2:
                return

            volatility = self._relative_move_std(prices)

            alpha = 0.1
            previous_baseline = self.volatility_baseline[symbol]
            if previous_baseline == 0:
                self.volatility_baseline[symbol] = volatility
            else:
                self.volatility_baseline[symbol] = (
                    alpha * volatility + (1 - alpha) * previous_baseline
                )
        except Exception as e:
            logger.warning(f"Error updating volatility baseline for {symbol}: {e}")

    def _calculate_current_volatility(self, symbol: str) -> float:
        """Volatility over the latest 60 points; 0.0 if unavailable."""
        try:
            history = self.price_history[symbol]
            if len(history) < 60:
                return 0.0

            latest = [point[1] for point in list(history)[-60:]]
            if len(latest) < 2:
                return 0.0

            return self._relative_move_std(latest)
        except Exception as e:
            logger.warning(f"Error calculating current volatility for {symbol}: {e}")
            return 0.0
|
|
|
|
class TrainingDataCollector:
    """Collects model-input snapshots and turns them into ranked training episodes.

    Lifecycle of one sample:

    1. ``collect_training_data`` wraps the raw inputs in a
       ``ModelInputPackage`` and queues it in ``pending_outcomes``.
    2. Once ``outcome_validation_delay`` has elapsed, the background worker
       (started by ``start_collection``) builds a ``TrainingOutcome`` for it.
    3. Package and outcome are combined into a ``TrainingEpisode`` that is
       kept in ``training_episodes`` and written to ``storage_dir``.

    ``data_lock`` guards the shared containers, the statistics and the
    rapid-change detector. It is re-entrant so helpers that are invoked
    while the lock is already held can safely take it again.
    """

    # Pending packages per symbol: trim to the newest _PENDING_KEEP once the
    # queue grows beyond _PENDING_LIMIT.
    _PENDING_LIMIT = 1000
    _PENDING_KEEP = 500

    # Timeframes probed for the current price, shortest first.
    _PRICE_TIMEFRAMES = ('1s', '1m', '5m', '15m', '1h')

    def __init__(self,
                 storage_dir: str = "training_data",
                 max_episodes_per_symbol: int = 10000,
                 outcome_validation_delay: timedelta = timedelta(hours=1)):
        """Create the storage directory and initialise empty state.

        Args:
            storage_dir: Directory episodes are written to (created if missing).
            max_episodes_per_symbol: In-memory cap per symbol; when exceeded,
                only the highest-priority episodes are kept.
            outcome_validation_delay: How long a package waits before its
                outcome is evaluated.
        """
        self.storage_dir = Path(storage_dir)
        self.storage_dir.mkdir(parents=True, exist_ok=True)

        self.max_episodes_per_symbol = max_episodes_per_symbol
        self.outcome_validation_delay = outcome_validation_delay

        # Data storage
        self.training_episodes: Dict[str, List["TrainingEpisode"]] = {}  # {symbol: episodes}
        self.pending_outcomes: Dict[str, List["ModelInputPackage"]] = {}  # awaiting outcome validation
        # Predictions supplied at collection time, keyed by package data_hash,
        # held until the matching episode is created.
        self._pending_predictions: Dict[str, Dict[str, Any]] = {}

        # Rapid change detection
        self.rapid_change_detector = RapidChangeDetector()

        # Data validation and statistics
        self.collection_stats = {
            'total_episodes': 0,
            'profitable_episodes': 0,
            'rapid_change_episodes': 0,
            'validation_errors': 0,
            'data_completeness_avg': 0.0
        }

        # Background processing
        self.is_collecting = False
        self.collection_thread = None
        self.outcome_validation_thread = None
        # Set by stop_collection so the worker's sleep can be interrupted.
        self._stop_event = threading.Event()

        # Thread safety (re-entrant: see class docstring)
        self.data_lock = threading.RLock()

        logger.info("Training Data Collector initialized")
        logger.info(f"Storage directory: {self.storage_dir}")
        logger.info(f"Max episodes per symbol: {self.max_episodes_per_symbol}")

    def start_collection(self):
        """Start the background outcome-validation worker (no-op if running)."""
        if self.is_collecting:
            logger.warning("Training data collection already running")
            return

        self.is_collecting = True
        self._stop_event.clear()

        self.outcome_validation_thread = threading.Thread(
            target=self._outcome_validation_worker,
            daemon=True
        )
        self.outcome_validation_thread.start()

        logger.info("Training data collection started")

    def stop_collection(self):
        """Signal the worker to stop and wait up to 5 seconds for it to exit."""
        self.is_collecting = False
        # Wake the worker if it is sleeping between validation passes.
        self._stop_event.set()

        if self.outcome_validation_thread:
            self.outcome_validation_thread.join(timeout=5)

        logger.info("Training data collection stopped")

    @staticmethod
    def _make_episode_id(input_package: "ModelInputPackage") -> str:
        """Build the episode id: symbol, second-resolution time, hash prefix."""
        stamp = input_package.timestamp.strftime('%Y%m%d_%H%M%S')
        return f"{input_package.symbol}_{stamp}_{input_package.data_hash[:8]}"

    def collect_training_data(self,
                              symbol: str,
                              ohlcv_data: Dict[str, pd.DataFrame],
                              tick_data: List[Dict[str, Any]],
                              cob_data: Dict[str, Any],
                              technical_indicators: Dict[str, float],
                              pivot_points: List[Dict[str, Any]],
                              cnn_features: np.ndarray,
                              rl_state: np.ndarray,
                              orchestrator_context: Dict[str, Any],
                              model_predictions: Optional[Dict[str, Any]] = None) -> Optional[str]:
        """Package the given inputs and queue them for outcome validation.

        Args:
            symbol: Instrument the data belongs to.
            ohlcv_data, tick_data, cob_data, technical_indicators,
            pivot_points, cnn_features, rl_state, orchestrator_context:
                Passed unchanged into the ``ModelInputPackage``.
            model_predictions: Optional predictions made on these inputs;
                they are attached to the episode created for this package.

        Returns:
            The episode id for tracking, or ``None`` when the package is less
            than 50% complete or an error occurred (both are counted in
            ``collection_stats['validation_errors']``).
        """
        try:
            input_package = ModelInputPackage(
                timestamp=datetime.now(),
                symbol=symbol,
                ohlcv_data=ohlcv_data,
                tick_data=tick_data,
                cob_data=cob_data,
                technical_indicators=technical_indicators,
                pivot_points=pivot_points,
                cnn_features=cnn_features,
                rl_state=rl_state,
                orchestrator_context=orchestrator_context
            )

            # Reject packages that are missing too many inputs
            if input_package.completeness_score < 0.5:
                logger.warning(f"Low data completeness for {symbol}: {input_package.completeness_score:.2f}")
                with self.data_lock:
                    self.collection_stats['validation_errors'] += 1
                return None

            current_price = self._extract_current_price(ohlcv_data)
            episode_id = self._make_episode_id(input_package)

            # The worker thread reads the detector, the pending queue and the
            # statistics under this lock, so all updates happen under it too.
            with self.data_lock:
                if current_price:
                    self.rapid_change_detector.add_price_point(
                        symbol, input_package.timestamp, current_price
                    )

                pending = self.pending_outcomes.setdefault(symbol, [])
                pending.append(input_package)
                if model_predictions:
                    self._pending_predictions[input_package.data_hash] = dict(model_predictions)

                # Limit pending outcomes to prevent memory issues
                if len(pending) > self._PENDING_LIMIT:
                    for dropped in pending[:-self._PENDING_KEEP]:
                        self._pending_predictions.pop(dropped.data_hash, None)
                    self.pending_outcomes[symbol] = pending[-self._PENDING_KEEP:]

                # Running mean of completeness over all collected packages
                stats = self.collection_stats
                stats['total_episodes'] += 1
                collected = stats['total_episodes']
                stats['data_completeness_avg'] = (
                    (stats['data_completeness_avg'] * (collected - 1) +
                     input_package.completeness_score) / collected
                )

            logger.debug(f"Collected training data for {symbol}: {episode_id}")
            logger.debug(f"Data completeness: {input_package.completeness_score:.2f}")

            return episode_id

        except Exception as e:
            logger.error(f"Error collecting training data for {symbol}: {e}")
            with self.data_lock:
                self.collection_stats['validation_errors'] += 1
            return None

    def _extract_current_price(self, ohlcv_data: Dict[str, pd.DataFrame]) -> Optional[float]:
        """Return the latest close from the shortest available timeframe.

        Returns ``None`` when no known timeframe holds data or on error.
        """
        try:
            for timeframe in self._PRICE_TIMEFRAMES:
                if timeframe in ohlcv_data and not ohlcv_data[timeframe].empty:
                    return float(ohlcv_data[timeframe]['close'].iloc[-1])
            return None
        except Exception as e:
            logger.warning(f"Error extracting current price: {e}")
            return None

    def _outcome_validation_worker(self):
        """Background loop: validate pending outcomes roughly once a minute.

        After a failed pass the loop retries after 30 seconds instead of 60.
        The sleep is taken on the shared stop event so ``stop_collection``
        can end the loop immediately.
        """
        logger.info("Outcome validation worker started")

        while self.is_collecting:
            try:
                self._validate_pending_outcomes()
                pause_seconds = 60  # Check every minute
            except Exception as e:
                logger.error(f"Error in outcome validation worker: {e}")
                pause_seconds = 30  # Wait before retrying

            if self._stop_event.wait(pause_seconds):
                break

        logger.info("Outcome validation worker stopped")

    def _validate_pending_outcomes(self):
        """Create episodes for every pending package that is old enough.

        Packages that are not yet due, or whose outcome could not be
        computed, stay queued for the next pass.
        """
        current_time = datetime.now()

        with self.data_lock:
            for symbol in list(self.pending_outcomes):
                validated_count = 0
                still_pending = []

                for package in self.pending_outcomes[symbol]:
                    outcome = None
                    if current_time - package.timestamp >= self.outcome_validation_delay:
                        outcome = self._calculate_training_outcome(package)

                    if outcome:
                        self._create_training_episode(package, outcome)
                        validated_count += 1
                    else:
                        still_pending.append(package)

                self.pending_outcomes[symbol] = still_pending

                if validated_count:
                    logger.info(f"Validated {validated_count} outcomes for {symbol}")

    def _calculate_training_outcome(self, input_package: "ModelInputPackage") -> Optional["TrainingOutcome"]:
        """Build the outcome for a package (placeholder implementation).

        All price-change and profitability fields are placeholders; only the
        rapid-change fields carry computed values.

        NOTE(review): the rapid-change values come from the detector's state
        at validation time, not from the moment the package was captured --
        confirm that this is intended.

        Returns:
            The outcome, or ``None`` if no base price could be extracted or
            an error occurred.
        """
        try:
            base_price = self._extract_current_price(input_package.ohlcv_data)
            if not base_price:
                return None

            # A real implementation would fetch the actual future prices from
            # the data provider here.
            is_rapid, velocity, volatility_spike = self.rapid_change_detector.detect_rapid_change(
                input_package.symbol
            )

            # Placeholder values - replace with actual calculation
            return TrainingOutcome(
                input_package_hash=input_package.data_hash,
                timestamp=input_package.timestamp,
                symbol=input_package.symbol,
                price_change_1m=0.0,  # Calculate from actual future data
                price_change_5m=0.0,
                price_change_15m=0.0,
                price_change_1h=0.0,
                max_profit_potential=0.0,
                max_loss_potential=0.0,
                optimal_entry_price=base_price,
                optimal_exit_price=base_price,
                optimal_holding_time=timedelta(minutes=5),
                is_profitable=False,  # Determine from actual outcomes
                profitability_score=0.0,
                risk_reward_ratio=1.0,
                is_rapid_change=is_rapid,
                change_velocity=velocity,
                volatility_spike=volatility_spike,
                outcome_validated=True
            )

        except Exception as e:
            logger.error(f"Error calculating training outcome: {e}")
            return None

    def _create_training_episode(self, input_package: "ModelInputPackage", outcome: "TrainingOutcome"):
        """Combine package and outcome into an episode, store it and save it.

        The episode receives any predictions that were supplied when the
        package was collected (an empty dict otherwise). When the per-symbol
        cap is exceeded only the highest-priority episodes stay in memory;
        files already written to disk are not removed. Errors are logged,
        not raised.
        """
        try:
            episode_id = self._make_episode_id(input_package)

            with self.data_lock:
                predictions = self._pending_predictions.pop(input_package.data_hash, None) or {}

                # Determine episode type
                episode_type = 'normal'
                if outcome.is_rapid_change:
                    episode_type = 'rapid_change'
                    self.collection_stats['rapid_change_episodes'] += 1
                elif outcome.profitability_score > 0.8:
                    episode_type = 'high_profit'

                if outcome.is_profitable:
                    self.collection_stats['profitable_episodes'] += 1

                episode = TrainingEpisode(
                    episode_id=episode_id,
                    input_package=input_package,
                    model_predictions=predictions,
                    actual_outcome=outcome,
                    episode_type=episode_type,
                    profitability_rank=0.0,  # Will be calculated later
                    training_priority=0.0
                )
                episode.training_priority = episode.calculate_training_priority()

                symbol = input_package.symbol
                episodes = self.training_episodes.setdefault(symbol, [])
                episodes.append(episode)

                # Limit episodes per symbol, keeping the highest priorities
                if len(episodes) > self.max_episodes_per_symbol:
                    episodes.sort(key=lambda ep: ep.training_priority, reverse=True)
                    self.training_episodes[symbol] = episodes[:self.max_episodes_per_symbol]

            self._save_episode_to_disk(episode)

            logger.debug(f"Created training episode: {episode_id}")
            logger.debug(f"Episode type: {episode_type}, Priority: {episode.training_priority:.3f}")

        except Exception as e:
            logger.error(f"Error creating training episode: {e}")

    def _save_episode_to_disk(self, episode: "TrainingEpisode"):
        """Write the pickled episode and a small JSON metadata file.

        Files go to ``storage_dir / symbol``. The file stem is the episode id
        with every run of characters other than letters, digits, ``_``, ``.``
        and ``-`` replaced by ``_``; without that, a symbol containing a path
        separator (e.g. ``ETH/USDT``) would point into a non-existent
        sub-directory and the save would always fail. The metadata keeps the
        original, unsanitised episode id. Errors are logged, not raised.

        NOTE: the episode is stored with ``pickle``; only load such files
        from a trusted storage directory.
        """
        try:
            symbol_dir = self.storage_dir / episode.input_package.symbol
            symbol_dir.mkdir(parents=True, exist_ok=True)

            file_stem = re.sub(r'[^\w.\-]+', '_', episode.episode_id)

            # Save episode data
            episode_file = symbol_dir / f"{file_stem}.pkl"
            with open(episode_file, 'wb') as f:
                pickle.dump(episode, f)

            # Save episode metadata for quick access
            metadata = {
                'episode_id': episode.episode_id,
                'timestamp': episode.input_package.timestamp.isoformat(),
                'episode_type': episode.episode_type,
                'training_priority': episode.training_priority,
                'profitability_score': episode.actual_outcome.profitability_score,
                'is_profitable': episode.actual_outcome.is_profitable,
                'is_rapid_change': episode.actual_outcome.is_rapid_change,
                'data_completeness': episode.input_package.completeness_score
            }

            metadata_file = symbol_dir / f"{file_stem}_metadata.json"
            with open(metadata_file, 'w') as f:
                json.dump(metadata, f, indent=2)

        except Exception as e:
            logger.error(f"Error saving episode to disk: {e}")

    def get_high_priority_episodes(self,
                                   symbol: str,
                                   limit: int = 100,
                                   min_priority: float = 0.5) -> List["TrainingEpisode"]:
        """Return up to ``limit`` episodes of ``symbol``, highest priority first.

        Only episodes with ``training_priority >= min_priority`` qualify.
        Returns an empty list for unknown symbols or on error.
        """
        try:
            # Snapshot under the lock; the worker may be replacing the list.
            with self.data_lock:
                episodes = list(self.training_episodes.get(symbol, ()))

            selected = [ep for ep in episodes if ep.training_priority >= min_priority]
            selected.sort(key=lambda ep: ep.training_priority, reverse=True)
            return selected[:limit]

        except Exception as e:
            logger.error(f"Error getting high priority episodes for {symbol}: {e}")
            return []

    def get_collection_statistics(self) -> Dict[str, Any]:
        """Return a copy of the counters plus per-symbol sizes and rates.

        ``profitability_rate`` and ``rapid_change_rate`` are the respective
        episode counters divided by ``total_episodes`` (0.0 when that is 0).
        """
        with self.data_lock:
            stats = self.collection_stats.copy()
            stats['episodes_per_symbol'] = {
                symbol: len(episodes)
                for symbol, episodes in self.training_episodes.items()
            }
            stats['pending_outcomes'] = {
                symbol: len(packages)
                for symbol, packages in self.pending_outcomes.items()
            }

        total = stats['total_episodes']
        if total > 0:
            stats['profitability_rate'] = stats['profitable_episodes'] / total
            stats['rapid_change_rate'] = stats['rapid_change_episodes'] / total
        else:
            stats['profitability_rate'] = 0.0
            stats['rapid_change_rate'] = 0.0

        return stats

    def validate_data_integrity(self) -> Dict[str, Any]:
        """Re-check every in-memory episode and summarise the findings.

        For each episode the input hash is recomputed and compared, the
        completeness score is checked against 0.7 and the
        ``'data_consistent'`` flag is inspected.

        Returns:
            A dict with the individual counters, the ids of episodes whose
            hash no longer matches, and ``integrity_score`` =
            ``1 - issues / episodes_checked`` (1.0 when nothing was checked).
            On error a ``'validation_error'`` message is added.
        """
        validation_results = {
            'total_episodes_checked': 0,
            'hash_mismatches': 0,
            'completeness_issues': 0,
            'validation_flag_failures': 0,
            'corrupted_episodes': [],
            'integrity_score': 1.0
        }

        try:
            # Snapshot under the lock so the worker cannot resize the dict
            # while it is being iterated.
            with self.data_lock:
                snapshot = [list(episodes) for episodes in self.training_episodes.values()]

            for episodes in snapshot:
                for episode in episodes:
                    validation_results['total_episodes_checked'] += 1
                    package = episode.input_package

                    # Check data hash
                    if package._calculate_hash() != package.data_hash:
                        validation_results['hash_mismatches'] += 1
                        validation_results['corrupted_episodes'].append(episode.episode_id)

                    # Check completeness
                    if package.completeness_score < 0.7:
                        validation_results['completeness_issues'] += 1

                    # Check validation flags
                    if not package.validation_flags.get('data_consistent', False):
                        validation_results['validation_flag_failures'] += 1

            total_issues = (
                validation_results['hash_mismatches'] +
                validation_results['completeness_issues'] +
                validation_results['validation_flag_failures']
            )

            if validation_results['total_episodes_checked'] > 0:
                validation_results['integrity_score'] = 1.0 - (
                    total_issues / validation_results['total_episodes_checked']
                )

            logger.info("Data integrity validation completed")
            logger.info(f"Integrity score: {validation_results['integrity_score']:.3f}")

        except Exception as e:
            logger.error(f"Error during data integrity validation: {e}")
            validation_results['validation_error'] = str(e)

        return validation_results
|
|
|
|
# Module-wide collector instance, created lazily on first access
training_data_collector = None


def get_training_data_collector() -> TrainingDataCollector:
    """Return the shared ``TrainingDataCollector``, creating it on first use.

    The instance is built with the constructor's default arguments and then
    reused by every later call.
    """
    global training_data_collector
    if training_data_collector is not None:
        return training_data_collector

    training_data_collector = TrainingDataCollector()
    return training_data_collector