"""
Training Integration Module
This module integrates the comprehensive training data collection system
with the existing data provider and model infrastructure. It provides:
1. Real-time data collection from DataProvider
2. Integration with existing CNN and RL models
3. Automatic training data package creation
4. Rapid price change detection and collection
5. Training pipeline coordination
Key Features:
- Seamless integration with existing DataProvider
- Automatic model input package creation
- Real-time training data validation
- Coordinated training across all models
- Performance monitoring and optimization
"""
import asyncio
import logging
import threading
import time
from collections import deque
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any, Callable

import numpy as np
import pandas as pd
import torch

from .training_data_collector import (
    TrainingDataCollector,
    ModelInputPackage,
    get_training_data_collector
)
from .cnn_training_pipeline import (
    CNNTrainer,
    CNNPivotPredictor,
    get_cnn_trainer
)
from .data_provider import DataProvider

logger = logging.getLogger(__name__)


@dataclass
class TrainingIntegrationConfig:
    """Configuration for training integration"""
    # Data collection settings
    collection_interval: float = 1.0  # seconds
    min_data_completeness: float = 0.7

    # Rapid change detection
    enable_rapid_change_detection: bool = True
    price_change_threshold: float = 0.5  # % per minute

    # Training settings
    enable_real_time_training: bool = True
    training_batch_size: int = 32
    min_episodes_for_training: int = 50

    # Performance settings
    max_concurrent_collections: int = 4
    data_validation_enabled: bool = True
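
# Illustrative example (comments only, not executed): tightening the
# data-quality gate while polling more frequently. Field names match the
# dataclass above; the values are hypothetical.
#
#     fast_config = TrainingIntegrationConfig(
#         collection_interval=0.5,        # poll twice per second
#         min_data_completeness=0.8,      # require more complete snapshots
#         min_episodes_for_training=100,  # train less often, on more data
#     )
#     integration = TrainingIntegration(data_provider, config=fast_config)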


class TrainingIntegration:
    """Main integration class for training data collection and model training"""

    def __init__(self,
                 data_provider: DataProvider,
                 config: Optional[TrainingIntegrationConfig] = None):
        self.data_provider = data_provider
        self.config = config or TrainingIntegrationConfig()

        # Get training components
        self.data_collector = get_training_data_collector()

        # Initialize CNN components
        self.cnn_model = CNNPivotPredictor()
        self.cnn_trainer = get_cnn_trainer(self.cnn_model)

        # Integration state
        self.is_running = False
        self.collection_thread = None
        self.training_threads = {}

        # Data buffers for real-time processing
        self.data_buffers = {}
        self.last_collection_time = {}

        # Performance tracking
        self.integration_stats = {
            'data_packages_created': 0,
            'training_sessions_triggered': 0,
            'rapid_changes_detected': 0,
            'validation_failures': 0,
            'average_collection_time': 0.0,
            'last_update': datetime.now()
        }

        # Initialize data buffers for each symbol
        for symbol in self.data_provider.symbols:
            self.data_buffers[symbol] = {
                'ohlcv_data': {},
                'tick_data': deque(maxlen=1000),
                'cob_data': {},
                'technical_indicators': {},
                'pivot_points': []
            }
            self.last_collection_time[symbol] = datetime.now()

        logger.info("Training Integration initialized")
        logger.info(f"Symbols: {self.data_provider.symbols}")
        logger.info(f"Real-time training: {self.config.enable_real_time_training}")
        logger.info(f"Rapid change detection: {self.config.enable_rapid_change_detection}")

    def start_integration(self):
        """Start the training integration system"""
        if self.is_running:
            logger.warning("Training integration already running")
            return

        self.is_running = True

        # Start data collection
        self.data_collector.start_collection()

        # Start real-time data collection thread
        self.collection_thread = threading.Thread(
            target=self._data_collection_worker,
            daemon=True
        )
        self.collection_thread.start()

        # Start training threads for each symbol
        if self.config.enable_real_time_training:
            for symbol in self.data_provider.symbols:
                training_thread = threading.Thread(
                    target=self._training_worker,
                    args=(symbol,),
                    daemon=True
                )
                self.training_threads[symbol] = training_thread
                training_thread.start()

        logger.info("Training integration started")

    def stop_integration(self):
        """Stop the training integration system"""
        self.is_running = False

        # Stop data collection
        self.data_collector.stop_collection()

        # Stop CNN training
        self.cnn_trainer.stop_training()

        # Wait for threads to finish
        if self.collection_thread:
            self.collection_thread.join(timeout=10)
        for thread in self.training_threads.values():
            thread.join(timeout=5)

        logger.info("Training integration stopped")

    def _data_collection_worker(self):
        """Main data collection worker"""
        logger.info("Data collection worker started")

        while self.is_running:
            try:
                start_time = time.time()

                # Collect data for each symbol
                for symbol in self.data_provider.symbols:
                    self._collect_symbol_data(symbol)

                # Update performance stats
                collection_time = time.time() - start_time
                self._update_collection_stats(collection_time)

                # Wait for next collection cycle
                time.sleep(self.config.collection_interval)
            except Exception as e:
                logger.error(f"Error in data collection worker: {e}")
                time.sleep(5)  # Wait before retrying

        logger.info("Data collection worker stopped")

    def _collect_symbol_data(self, symbol: str):
        """Collect comprehensive training data for a symbol"""
        try:
            # Get current market data from the data provider
            ohlcv_data = self._get_ohlcv_data(symbol)
            tick_data = self._get_tick_data(symbol)
            cob_data = self._get_cob_data(symbol)
            technical_indicators = self._get_technical_indicators(symbol)
            pivot_points = self._get_pivot_points(symbol)

            # Validate data availability
            if not self._validate_data_availability(symbol, ohlcv_data, tick_data):
                return

            # Create model input features
            cnn_features = self._create_cnn_features(symbol, ohlcv_data, technical_indicators)
            rl_state = self._create_rl_state(symbol, ohlcv_data, cob_data, technical_indicators)
            orchestrator_context = self._create_orchestrator_context(symbol)

            # Get model predictions if available
            model_predictions = self._get_current_model_predictions(symbol)

            # Collect training data package
            episode_id = self.data_collector.collect_training_data(
                symbol=symbol,
                ohlcv_data=ohlcv_data,
                tick_data=tick_data,
                cob_data=cob_data,
                technical_indicators=technical_indicators,
                pivot_points=pivot_points,
                cnn_features=cnn_features,
                rl_state=rl_state,
                orchestrator_context=orchestrator_context,
                model_predictions=model_predictions
            )

            if episode_id:
                self.integration_stats['data_packages_created'] += 1
                logger.debug(f"Created training data package for {symbol}: {episode_id}")
        except Exception as e:
            logger.error(f"Error collecting data for {symbol}: {e}")
            self.integration_stats['validation_failures'] += 1

    def _get_ohlcv_data(self, symbol: str) -> Dict[str, pd.DataFrame]:
        """Get OHLCV data for all timeframes"""
        ohlcv_data = {}
        try:
            for timeframe in ['1s', '1m', '5m', '15m', '1h', '1d']:
                df = self.data_provider.get_historical_data(
                    symbol=symbol,
                    timeframe=timeframe,
                    limit=300,    # Get 300 bars as specified in requirements
                    refresh=True  # Get fresh data
                )
                if df is not None and not df.empty:
                    ohlcv_data[timeframe] = df
            return ohlcv_data
        except Exception as e:
            logger.warning(f"Error getting OHLCV data for {symbol}: {e}")
            return {}

    def _get_tick_data(self, symbol: str) -> List[Dict[str, Any]]:
        """Get recent tick data"""
        try:
            # Get tick data from the data provider's tick buffers
            binance_symbol = symbol.replace('/', '').upper()
            if binance_symbol in self.data_provider.tick_buffers:
                # Keep only the last 300 seconds of tick data
                current_time = datetime.now()
                cutoff_time = current_time - timedelta(seconds=300)

                tick_buffer = self.data_provider.tick_buffers[binance_symbol]
                recent_ticks = []

                # Convert the deque to a list and filter by time
                for tick in list(tick_buffer):
                    if hasattr(tick, 'timestamp') and tick.timestamp >= cutoff_time:
                        recent_ticks.append({
                            'timestamp': tick.timestamp,
                            'price': tick.price,
                            'volume': tick.volume,
                            'side': tick.side,
                            'trade_id': tick.trade_id
                        })
                return recent_ticks
            return []
        except Exception as e:
            logger.warning(f"Error getting tick data for {symbol}: {e}")
            return []

    def _get_cob_data(self, symbol: str) -> Dict[str, Any]:
        """Get Consolidated Order Book (COB) data"""
        try:
            # Get COB data from the data provider's COB cache
            binance_symbol = symbol.replace('/', '').upper()
            if binance_symbol in self.data_provider.cob_data_cache:
                cob_buffer = self.data_provider.cob_data_cache[binance_symbol]
                if cob_buffer:
                    # Get the most recent COB entry; entries may be stored as
                    # (timestamp, features) tuples or as raw feature arrays
                    latest_cob = list(cob_buffer)[-1] if cob_buffer else None
                    if latest_cob:
                        return {
                            'timestamp': latest_cob[0] if isinstance(latest_cob, tuple) else datetime.now(),
                            'cob_features': latest_cob[1] if isinstance(latest_cob, tuple) else latest_cob,
                            'feature_count': len(latest_cob[1]) if isinstance(latest_cob, tuple) else 0
                        }
            return {}
        except Exception as e:
            logger.warning(f"Error getting COB data for {symbol}: {e}")
            return {}

    def _get_technical_indicators(self, symbol: str) -> Dict[str, float]:
        """Get technical indicators from OHLCV data"""
        try:
            # Get the most recent 1m data with indicators
            df = self.data_provider.get_historical_data(
                symbol=symbol,
                timeframe='1m',
                limit=50,
                refresh=True
            )
            if df is not None and not df.empty:
                # Extract indicators from the latest row
                latest_row = df.iloc[-1]
                indicators = {}

                # Any non-OHLCV column is treated as an indicator
                for col in df.columns:
                    if col not in ['open', 'high', 'low', 'close', 'volume', 'timestamp']:
                        try:
                            value = float(latest_row[col])
                            if not np.isnan(value):
                                indicators[col] = value
                        except (ValueError, TypeError):
                            continue
                return indicators
            return {}
        except Exception as e:
            logger.warning(f"Error getting technical indicators for {symbol}: {e}")
            return {}

    def _get_pivot_points(self, symbol: str) -> List[Dict[str, Any]]:
        """Get recent pivot points"""
        try:
            # Get pivot points from the Williams Market Structure
            if symbol in self.data_provider.williams_structure:
                williams = self.data_provider.williams_structure[symbol]

                # Get recent pivot points.
                # This would integrate with the Williams Market Structure;
                # for now, return an empty list as a placeholder.
                pivot_points = []
                return pivot_points
            return []
        except Exception as e:
            logger.warning(f"Error getting pivot points for {symbol}: {e}")
            return []

    def _create_cnn_features(self,
                             symbol: str,
                             ohlcv_data: Dict[str, pd.DataFrame],
                             technical_indicators: Dict[str, float]) -> np.ndarray:
        """Create CNN input features from market data"""
        try:
            # This is a simplified feature creation; in practice you would
            # build richer multi-timeframe features.
            features = []

            # Add OHLCV features from multiple timeframes
            for timeframe in ['1s', '1m', '5m', '15m', '1h']:
                if timeframe in ohlcv_data:
                    df = ohlcv_data[timeframe]
                    if not df.empty:
                        # Extract raw OHLCV values (note: not normalized here)
                        ohlcv_values = df[['open', 'high', 'low', 'close', 'volume']].values
                        if len(ohlcv_values) > 0:
                            # Take the last 60 bars and flatten (60 x 5 = 300 values)
                            recent_values = ohlcv_values[-60:].flatten()
                            features.extend(recent_values)

            # Add technical indicators
            for indicator_name, value in technical_indicators.items():
                features.append(value)

            # Pad or truncate to a fixed size
            target_size = 2000  # Match CNN input size
            if len(features) < target_size:
                features.extend([0.0] * (target_size - len(features)))
            else:
                features = features[:target_size]

            return np.array(features, dtype=np.float32)
        except Exception as e:
            logger.warning(f"Error creating CNN features for {symbol}: {e}")
            return np.zeros(2000, dtype=np.float32)

    def _create_rl_state(self,
                         symbol: str,
                         ohlcv_data: Dict[str, pd.DataFrame],
                         cob_data: Dict[str, Any],
                         technical_indicators: Dict[str, float]) -> np.ndarray:
        """Create RL state representation"""
        try:
            state_features = []

            # Add market state features from the latest 1m candle
            if '1m' in ohlcv_data and not ohlcv_data['1m'].empty:
                latest_candle = ohlcv_data['1m'].iloc[-1]
                state_features.extend([
                    latest_candle['open'],
                    latest_candle['high'],
                    latest_candle['low'],
                    latest_candle['close'],
                    latest_candle['volume']
                ])

            # Add COB features
            if 'cob_features' in cob_data:
                cob_features = cob_data['cob_features']
                if isinstance(cob_features, (list, np.ndarray)):
                    state_features.extend(cob_features[:100])  # Limit COB features

            # Add technical indicators
            for indicator_name, value in technical_indicators.items():
                state_features.append(value)

            # Pad or truncate to a fixed size
            target_size = 2000  # Match RL input size
            if len(state_features) < target_size:
                state_features.extend([0.0] * (target_size - len(state_features)))
            else:
                state_features = state_features[:target_size]

            return np.array(state_features, dtype=np.float32)
        except Exception as e:
            logger.warning(f"Error creating RL state for {symbol}: {e}")
            return np.zeros(2000, dtype=np.float32)

    def _create_orchestrator_context(self, symbol: str) -> Dict[str, Any]:
        """Create orchestrator context"""
        try:
            return {
                'symbol': symbol,
                'timestamp': datetime.now(),
                'market_session': self._determine_market_session(),
                'volatility_regime': self._determine_volatility_regime(symbol),
                'trend_direction': self._determine_trend_direction(symbol)
            }
        except Exception as e:
            logger.warning(f"Error creating orchestrator context for {symbol}: {e}")
            return {'symbol': symbol, 'timestamp': datetime.now()}

    def _determine_market_session(self) -> str:
        """Determine the current market session"""
        # Simplified session detection from the hour of day
        # (note: datetime.now() uses local server time, not UTC)
        current_hour = datetime.now().hour
        if 0 <= current_hour < 8:
            return 'asian'
        elif 8 <= current_hour < 16:
            return 'european'
        else:
            return 'american'

    def _determine_volatility_regime(self, symbol: str) -> str:
        """Determine the volatility regime for a symbol"""
        try:
            # Get recent volatility data
            df = self.data_provider.get_historical_data(symbol, '1m', limit=100)
            if df is not None and not df.empty:
                returns = df['close'].pct_change().dropna()
                volatility = returns.std()

                if volatility > 0.02:
                    return 'high'
                elif volatility > 0.01:
                    return 'medium'
                else:
                    return 'low'
            return 'unknown'
        except Exception:
            return 'unknown'

    def _determine_trend_direction(self, symbol: str) -> str:
        """Determine the trend direction for a symbol"""
        try:
            # Simple trend detection using moving averages
            df = self.data_provider.get_historical_data(symbol, '1h', limit=50)
            if df is not None and not df.empty:
                if 'sma_20' in df.columns and 'sma_50' in df.columns:
                    latest_sma20 = df['sma_20'].iloc[-1]
                    latest_sma50 = df['sma_50'].iloc[-1]

                    if latest_sma20 > latest_sma50:
                        return 'uptrend'
                    elif latest_sma20 < latest_sma50:
                        return 'downtrend'
                    else:
                        return 'sideways'
            return 'unknown'
        except Exception:
            return 'unknown'

    def _get_current_model_predictions(self, symbol: str) -> Dict[str, Any]:
        """Get current predictions from all models"""
        predictions = {}
        try:
            # This would integrate with existing model predictions;
            # for now, return an empty dict as a placeholder.
            return predictions
        except Exception as e:
            logger.warning(f"Error getting model predictions for {symbol}: {e}")
            return {}

    def _validate_data_availability(self,
                                    symbol: str,
                                    ohlcv_data: Dict[str, pd.DataFrame],
                                    tick_data: List[Dict[str, Any]]) -> bool:
        """Validate that sufficient data is available for training"""
        try:
            # Check OHLCV data availability
            required_timeframes = ['1m', '5m', '1h']
            available_timeframes = 0
            for timeframe in required_timeframes:
                if timeframe in ohlcv_data and not ohlcv_data[timeframe].empty:
                    available_timeframes += 1

            # Check minimum data requirements
            if available_timeframes < 2:  # Need at least 2 timeframes
                return False

            # Check tick data availability (optional but preferred)
            has_tick_data = len(tick_data) > 0

            # Calculate completeness score; e.g. 2 of 3 timeframes plus tick
            # data gives 2/3 + 0.1 = 0.77, passing the default 0.7 threshold
            completeness = available_timeframes / len(required_timeframes)
            if has_tick_data:
                completeness += 0.1  # Bonus for tick data

            return completeness >= self.config.min_data_completeness
        except Exception as e:
            logger.warning(f"Error validating data availability for {symbol}: {e}")
            return False

    def _training_worker(self, symbol: str):
        """Training worker for a specific symbol"""
        logger.info(f"Training worker started for {symbol}")

        while self.is_running:
            try:
                # Check whether we have enough episodes for training
                episodes = self.data_collector.get_high_priority_episodes(
                    symbol=symbol,
                    limit=self.config.training_batch_size * 2,
                    min_priority=0.3
                )

                if len(episodes) >= self.config.min_episodes_for_training:
                    # Trigger CNN training
                    results = self.cnn_trainer.train_on_profitable_episodes(
                        symbol=symbol,
                        min_profitability=0.6,
                        max_episodes=len(episodes)
                    )
                    if results.get('status') == 'success':
                        self.integration_stats['training_sessions_triggered'] += 1
                        logger.info(f"Training session completed for {symbol}")

                # Wait before the next training check
                time.sleep(300)  # Check every 5 minutes
            except Exception as e:
                logger.error(f"Error in training worker for {symbol}: {e}")
                time.sleep(60)  # Wait before retrying

        logger.info(f"Training worker stopped for {symbol}")

    def _update_collection_stats(self, collection_time: float):
        """Update collection performance statistics"""
        try:
            # Update the average collection time as an exponential moving average
            alpha = 0.1  # EMA smoothing factor
            if self.integration_stats['average_collection_time'] == 0:
                self.integration_stats['average_collection_time'] = collection_time
            else:
                self.integration_stats['average_collection_time'] = (
                    alpha * collection_time +
                    (1 - alpha) * self.integration_stats['average_collection_time']
                )
            self.integration_stats['last_update'] = datetime.now()
        except Exception as e:
            logger.warning(f"Error updating collection stats: {e}")

    def get_integration_statistics(self) -> Dict[str, Any]:
        """Get comprehensive integration statistics"""
        stats = self.integration_stats.copy()

        # Add data collector statistics
        collector_stats = self.data_collector.get_collection_statistics()
        stats.update(collector_stats)

        # Add CNN trainer statistics
        trainer_stats = self.cnn_trainer.get_training_statistics()
        stats['cnn_training'] = trainer_stats

        # Add performance metrics
        stats['is_running'] = self.is_running
        stats['active_symbols'] = len(self.data_provider.symbols)
        stats['collection_frequency'] = self.config.collection_interval

        return stats

    def trigger_manual_training(self, symbol: str, training_type: str = 'profitable') -> Dict[str, Any]:
        """Manually trigger training for a symbol"""
        try:
            if training_type == 'profitable':
                results = self.cnn_trainer.train_on_profitable_episodes(
                    symbol=symbol,
                    min_profitability=0.7,
                    max_episodes=200
                )
            elif training_type == 'high_value_replay':
                results = self.cnn_trainer.replay_high_value_sessions(
                    symbol=symbol,
                    min_session_value=0.8,
                    max_sessions=10
                )
            else:
                return {'status': 'error', 'error': f'Unknown training type: {training_type}'}

            if results.get('status') == 'success':
                self.integration_stats['training_sessions_triggered'] += 1
            return results
        except Exception as e:
            logger.error(f"Error in manual training trigger: {e}")
            return {'status': 'error', 'error': str(e)}


# Global instance
training_integration: Optional[TrainingIntegration] = None


def get_training_integration(data_provider: Optional[DataProvider] = None) -> TrainingIntegration:
    """Get the global training integration instance"""
    global training_integration
    if training_integration is None:
        if data_provider is None:
            raise ValueError("DataProvider required for first initialization")
        training_integration = TrainingIntegration(data_provider)
    return training_integration
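

# Minimal usage sketch (illustrative): wiring the integration to a DataProvider.
# Assumption: DataProvider accepts a list of symbols; the real constructor
# signature in this codebase may differ.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    provider = DataProvider(symbols=['ETH/USDT', 'BTC/USDT'])  # assumed signature
    integration = get_training_integration(provider)

    integration.start_integration()
    try:
        # Let the collection/training workers run briefly, then report stats
        time.sleep(60)
        stats = integration.get_integration_statistics()
        logger.info(f"Data packages created: {stats.get('data_packages_created', 0)}")
    finally:
        integration.stop_integration()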