# Source file: gogo2/core/enhanced_training_integration.py
# Last commit: 12865fd3ef "replay system" (Dobromir Popov, 2025-07-20 12:37:02 +03:00)
# Original size: 775 lines, 32 KiB, Python

"""
Enhanced Training Integration Module
This module provides comprehensive integration between the training data collection system,
CNN training pipeline, RL training pipeline, and your existing infrastructure.
Key Features:
- Real-time integration with existing DataProvider
- Coordinated training across CNN and RL models
- Automatic outcome validation and profitability tracking
- Integration with existing COB RL model
- Performance monitoring and optimization
- Seamless connection to existing orchestrator and trading executor
"""
import asyncio
import logging
import numpy as np
import pandas as pd
import torch
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any, Callable
from dataclasses import dataclass
import threading
import time
from pathlib import Path
# Import existing components
from .data_provider import DataProvider
from .orchestrator import Orchestrator
from .trading_executor import TradingExecutor
# Import our training system components
from .training_data_collector import (
TrainingDataCollector,
get_training_data_collector
)
from .cnn_training_pipeline import (
CNNPivotPredictor,
CNNTrainer,
get_cnn_trainer
)
from .rl_training_pipeline import (
RLTradingAgent,
RLTrainer,
get_rl_trainer
)
from .training_integration import TrainingIntegration
# Module logger. It is created before the optional import below because the
# ImportError fallback reports through it; defining it afterwards would turn a
# missing optional dependency into a NameError at import time.
logger = logging.getLogger(__name__)

# Import existing RL model (optional dependency)
try:
    from NN.models.cob_rl_model import COBRLModelInterface
except ImportError:
    logger.warning("Could not import COBRLModelInterface - using fallback")
    # Sentinel checked by EnhancedTrainingIntegration before instantiating the model.
    COBRLModelInterface = None
@dataclass
class EnhancedTrainingConfig:
    """Tunable settings consumed by the enhanced training integration.

    Every field has a default, so ``EnhancedTrainingConfig()`` is a complete
    configuration. Field order is part of the positional constructor.
    """

    # --- Data collection -------------------------------------------------
    # Seconds the collection worker waits between collection passes.
    collection_interval: float = 1.0
    # NOTE(review): not read by the code visible in this file - confirm usage.
    min_data_completeness: float = 0.8

    # --- Training triggers -----------------------------------------------
    # Episode count required before CNN training is considered for a symbol.
    min_episodes_for_cnn_training: int = 100
    # Buffered experience count required before RL training is considered.
    min_experiences_for_rl_training: int = 200
    # Minutes the training coordinator waits between trigger checks.
    training_frequency_minutes: int = 30

    # --- Profitability thresholds ----------------------------------------
    # Passed to the trainers as ``min_profitability``.
    min_profitability_for_replay: float = 0.1
    # NOTE(review): not read by the code visible in this file - confirm usage.
    high_profitability_threshold: float = 0.5

    # --- Model integration -----------------------------------------------
    # Instantiate the pre-existing COB RL model when its import succeeded.
    use_existing_cob_rl_model: bool = True
    # NOTE(review): not read by the code visible in this file - confirm usage.
    enable_cross_model_learning: bool = True

    # --- Performance optimization ----------------------------------------
    # NOTE(review): not read by the code visible in this file - confirm usage.
    max_concurrent_training_sessions: int = 2
    # Start the background outcome-validation worker on startup.
    enable_background_validation: bool = True
class EnhancedTrainingIntegration:
    """Coordinate data collection, CNN/RL training and outcome validation.

    Once started, the instance runs up to three daemon threads:

    * a data-collection worker that gathers market data and model
      predictions for every symbol of the data provider,
    * a training coordinator that periodically checks whether CNN and RL
      training should be triggered,
    * an optional outcome-validation worker that scores predictions once
      they are at least one hour old.

    ``recent_predictions`` is written by the collection worker and pruned by
    the validation worker, so readers must iterate over a snapshot of it.
    """

    def __init__(self,
                 data_provider: "DataProvider",
                 orchestrator: Optional["Orchestrator"] = None,
                 trading_executor: Optional["TradingExecutor"] = None,
                 config: Optional["EnhancedTrainingConfig"] = None):
        """Create the trainers and bookkeeping state; no thread is started here.

        Args:
            data_provider: Supplies ``symbols`` and ``get_historical_data``.
            orchestrator: Optional orchestrator; when set, an orchestrator
                prediction is requested alongside the model predictions.
            trading_executor: Optional executor; stored on the instance.
            config: Settings; ``EnhancedTrainingConfig()`` when omitted.
        """
        self.data_provider = data_provider
        self.orchestrator = orchestrator
        self.trading_executor = trading_executor
        self.config = config or EnhancedTrainingConfig()

        # Initialize training components
        self.data_collector = get_training_data_collector()

        # Initialize CNN components
        self.cnn_model = CNNPivotPredictor()
        self.cnn_trainer = get_cnn_trainer(self.cnn_model)

        # Initialize RL components. COBRLModelInterface is None when the
        # optional import at module level failed.
        if self.config.use_existing_cob_rl_model and COBRLModelInterface:
            self.existing_rl_model = COBRLModelInterface()
            logger.info("Using existing COB RL model")
        else:
            self.existing_rl_model = None
        self.rl_agent = RLTradingAgent()
        self.rl_trainer = get_rl_trainer(self.rl_agent)

        # Integration state
        self.is_running = False
        # Set on stop so workers blocked in a long wait wake up immediately
        # instead of sleeping out their full interval.
        self._stop_event = threading.Event()
        self.training_threads = {}
        self.validation_thread = None

        # Performance tracking
        self.integration_stats = {
            'total_data_packages': 0,
            'cnn_training_sessions': 0,
            'rl_training_sessions': 0,
            'profitable_predictions': 0,
            'total_predictions': 0,
            'cross_model_improvements': 0,
            'last_update': datetime.now()
        }

        # Model prediction tracking: episode_id -> prediction record / outcome
        self.recent_predictions = {}
        self.prediction_outcomes = {}

        # Cross-model learning
        self.model_performance_history = {
            'cnn': [],
            'rl': [],
            'orchestrator': []
        }

        logger.info("Enhanced Training Integration initialized")
        logger.info(f"CNN model parameters: {sum(p.numel() for p in self.cnn_model.parameters()):,}")
        logger.info(f"RL agent parameters: {sum(p.numel() for p in self.rl_agent.parameters()):,}")
        logger.info(f"Using existing COB RL model: {self.existing_rl_model is not None}")

    def start_enhanced_integration(self):
        """Start data collection, real-time CNN training and the worker threads.

        Calling this while the integration is already running only logs a
        warning.
        """
        if self.is_running:
            logger.warning("Enhanced training integration already running")
            return

        # Re-arm the stop signal so a stop/start cycle works.
        self._stop_event.clear()
        self.is_running = True

        # Start data collection
        self.data_collector.start_collection()

        # Start CNN training
        if self.config.min_episodes_for_cnn_training > 0:
            for symbol in self.data_provider.symbols:
                self.cnn_trainer.start_real_time_training(symbol)

        # Start the coordinator first, then the data collection worker
        workers = {
            'coordinator': self._training_coordinator_worker,
            'data_collector': self._enhanced_data_collection_worker,
        }
        for name, target in workers.items():
            thread = threading.Thread(target=target, daemon=True)
            self.training_threads[name] = thread
            thread.start()

        # Start outcome validation if enabled
        if self.config.enable_background_validation:
            self.validation_thread = threading.Thread(
                target=self._outcome_validation_worker,
                daemon=True
            )
            self.validation_thread.start()

        logger.info("Enhanced training integration started")

    def stop_enhanced_integration(self):
        """Signal all workers to stop and wait briefly for them to finish."""
        self.is_running = False
        # Wake workers that are waiting between iterations.
        self._stop_event.set()

        # Stop data collection
        self.data_collector.stop_collection()

        # Stop CNN training
        self.cnn_trainer.stop_training()

        # Wait for threads to finish; a worker busy inside a long call may
        # outlive the timeout, so report what actually happened.
        for thread_name, thread in self.training_threads.items():
            thread.join(timeout=10)
            if thread.is_alive():
                logger.warning(f"{thread_name} thread did not stop within 10 seconds")
            else:
                logger.info(f"Stopped {thread_name} thread")

        if self.validation_thread:
            self.validation_thread.join(timeout=5)

        logger.info("Enhanced training integration stopped")

    def _wait_or_stop(self, seconds: float) -> None:
        """Block for up to ``seconds``, returning early when a stop is requested."""
        self._stop_event.wait(timeout=max(0.0, seconds))

    def _enhanced_data_collection_worker(self):
        """Thread body: collect training data for every symbol until stopped."""
        logger.info("Enhanced data collection worker started")
        while self.is_running:
            try:
                for symbol in self.data_provider.symbols:
                    self._collect_enhanced_training_data(symbol)
                self._wait_or_stop(self.config.collection_interval)
            except Exception as e:
                logger.error(f"Error in enhanced data collection: {e}")
                self._wait_or_stop(5)
        logger.info("Enhanced data collection worker stopped")

    def _collect_enhanced_training_data(self, symbol: str):
        """Collect one training-data package for ``symbol``.

        Gathers market data, queries all models, hands everything to the
        data collector and, when an episode id is returned, remembers the
        predictions for later outcome validation. Errors are logged, not
        raised.
        """
        try:
            # Get comprehensive market data
            market_data = self._get_comprehensive_market_data(symbol)
            if not market_data or not self._validate_market_data(market_data):
                return

            # Get current model predictions
            model_predictions = self._get_all_model_predictions(symbol, market_data)

            # Create enhanced features
            cnn_features = self._create_enhanced_cnn_features(symbol, market_data)
            rl_state = self._create_enhanced_rl_state(symbol, market_data, model_predictions)

            # Collect training data with predictions
            episode_id = self.data_collector.collect_training_data(
                symbol=symbol,
                ohlcv_data=market_data['ohlcv'],
                tick_data=market_data['ticks'],
                cob_data=market_data['cob'],
                technical_indicators=market_data['indicators'],
                pivot_points=market_data['pivots'],
                cnn_features=cnn_features,
                rl_state=rl_state,
                orchestrator_context=market_data['context'],
                model_predictions=model_predictions
            )
            if not episode_id:
                return

            # Store predictions for outcome validation. The episode id is
            # kept in the record because the outcome calculation reads it
            # from there.
            self.recent_predictions[episode_id] = {
                'episode_id': episode_id,
                'timestamp': datetime.now(),
                'symbol': symbol,
                'predictions': model_predictions,
                'market_data': market_data
            }

            # Add RL experience if we have action
            if 'rl_action' in model_predictions:
                self._add_rl_experience(symbol, market_data, model_predictions, episode_id)

            self.integration_stats['total_data_packages'] += 1
        except Exception as e:
            logger.error(f"Error collecting enhanced training data for {symbol}: {e}")

    def _get_comprehensive_market_data(self, symbol: str) -> Dict[str, Any]:
        """Assemble OHLCV, tick, COB, indicator, pivot and context data.

        Returns:
            A dict with keys ``ohlcv``, ``ticks``, ``cob``, ``indicators``,
            ``pivots`` and ``context``; an empty dict when anything fails.
        """
        try:
            market_data = {}

            # OHLCV data: keep only timeframes that returned rows
            ohlcv_data = {}
            for timeframe in ['1s', '1m', '5m', '15m', '1h', '1d']:
                df = self.data_provider.get_historical_data(symbol, timeframe, limit=300, refresh=True)
                if df is not None and not df.empty:
                    ohlcv_data[timeframe] = df
            market_data['ohlcv'] = ohlcv_data

            # Tick data
            market_data['ticks'] = self._get_recent_tick_data(symbol)
            # COB data
            market_data['cob'] = self._get_cob_data(symbol)
            # Technical indicators
            market_data['indicators'] = self._get_technical_indicators(symbol)
            # Pivot points
            market_data['pivots'] = self._get_pivot_points(symbol)
            # Market context
            market_data['context'] = self._get_market_context(symbol)

            return market_data
        except Exception as e:
            logger.error(f"Error getting comprehensive market data: {e}")
            return {}

    def _get_all_model_predictions(self, symbol: str, market_data: Dict[str, Any]) -> Dict[str, Any]:
        """Query every available model and return their predictions.

        Returns:
            A dict that may contain ``cnn``, ``rl``, ``rl_action``,
            ``cob_rl`` and ``orchestrator`` entries; an empty dict when an
            unexpected error occurs.
        """
        predictions = {}
        try:
            # CNN predictions
            if self.cnn_model and market_data.get('ohlcv'):
                cnn_features = self._create_enhanced_cnn_features(symbol, market_data)
                if cnn_features is not None:
                    cnn_input = torch.from_numpy(cnn_features).float().unsqueeze(0)
                    # Reshape for CNN (add channel dimension)
                    cnn_input = cnn_input.view(1, 10, -1)  # Assuming 10 channels
                    with torch.no_grad():
                        cnn_outputs = self.cnn_model(cnn_input)
                    predictions['cnn'] = {
                        'pivot_logits': cnn_outputs['pivot_logits'].cpu().numpy(),
                        'pivot_price': cnn_outputs['pivot_price'].cpu().numpy(),
                        'confidence': cnn_outputs['confidence'].cpu().numpy(),
                        'timestamp': datetime.now()
                    }

            # RL predictions
            if self.rl_agent and market_data.get('cob'):
                rl_state = self._create_enhanced_rl_state(symbol, market_data, predictions)
                if rl_state is not None:
                    action, confidence = self.rl_agent.select_action(rl_state, epsilon=0.1)
                    predictions['rl'] = {
                        'action': action,
                        'confidence': confidence,
                        'timestamp': datetime.now()
                    }
                    predictions['rl_action'] = action

            # Existing COB RL model predictions
            if self.existing_rl_model and market_data.get('cob'):
                cob_features = market_data['cob'].get('cob_features', [])
                # Explicit None/length check: truth-testing a NumPy array
                # with more than one element raises ValueError.
                if cob_features is not None and len(cob_features) >= 2000:
                    cob_array = np.array(cob_features[:2000], dtype=np.float32)
                    cob_prediction = self.existing_rl_model.predict(cob_array)
                    predictions['cob_rl'] = {
                        'predicted_direction': cob_prediction.get('predicted_direction', 1),
                        'confidence': cob_prediction.get('confidence', 0.5),
                        'value': cob_prediction.get('value', 0.0),
                        'timestamp': datetime.now()
                    }

            # Orchestrator predictions (if available)
            if self.orchestrator:
                try:
                    # This would integrate with your orchestrator's prediction method
                    orchestrator_prediction = self._get_orchestrator_prediction(symbol, market_data, predictions)
                    if orchestrator_prediction:
                        predictions['orchestrator'] = orchestrator_prediction
                except Exception as e:
                    logger.debug(f"Could not get orchestrator prediction: {e}")

            return predictions
        except Exception as e:
            logger.error(f"Error getting model predictions: {e}")
            return {}

    def _add_rl_experience(self, symbol: str, market_data: Dict[str, Any],
                           predictions: Dict[str, Any], episode_id: str):
        """Push a provisional RL experience for this episode into the trainer.

        Reward is a 0.0 placeholder and ``next_state`` is a copy of the
        current state; both are meant to be replaced once the real outcome
        is known.
        """
        try:
            # Create RL state
            state = self._create_enhanced_rl_state(symbol, market_data, predictions)
            if state is None:
                return

            # Get action from predictions
            action = predictions.get('rl_action', 1)  # Default to HOLD
            # Calculate immediate reward (placeholder - would be updated with actual outcome)
            reward = 0.0
            # Create next state (same as current for now - would be updated)
            next_state = state.copy()

            # Market context
            context = market_data['context']
            market_context = {
                'symbol': symbol,
                'episode_id': episode_id,
                'timestamp': datetime.now(),
                'market_session': context.get('market_session', 'unknown'),
                'volatility_regime': context.get('volatility_regime', 'unknown')
            }

            # Add experience
            experience_id = self.rl_trainer.add_experience(
                state=state,
                action=action,
                reward=reward,
                next_state=next_state,
                done=False,
                market_context=market_context,
                cnn_predictions=predictions.get('cnn'),
                confidence_score=predictions.get('rl', {}).get('confidence', 0.0)
            )
            if experience_id:
                logger.debug(f"Added RL experience: {experience_id}")
        except Exception as e:
            logger.error(f"Error adding RL experience: {e}")

    def _training_coordinator_worker(self):
        """Thread body: periodically check whether training should run."""
        logger.info("Training coordinator worker started")
        while self.is_running:
            try:
                # Check if we should trigger training
                for symbol in self.data_provider.symbols:
                    self._check_and_trigger_training(symbol)
                # Wait before next check (interruptible so stop is prompt)
                self._wait_or_stop(self.config.training_frequency_minutes * 60)
            except Exception as e:
                logger.error(f"Error in training coordinator: {e}")
                self._wait_or_stop(60)
        logger.info("Training coordinator worker stopped")

    def _check_and_trigger_training(self, symbol: str):
        """Run CNN and/or RL training when enough profitable data exists.

        CNN training needs at least ``min_episodes_for_cnn_training``
        episodes of which 20 are profitable; RL training needs at least
        ``min_experiences_for_rl_training`` buffered experiences of which 50
        are profitable. Errors are logged, not raised.
        """
        try:
            # Get training episodes and experiences
            episodes = self.data_collector.get_high_priority_episodes(symbol, limit=1000)

            # Check CNN training conditions
            if len(episodes) >= self.config.min_episodes_for_cnn_training:
                profitable_episodes = [ep for ep in episodes if ep.actual_outcome.is_profitable]
                if len(profitable_episodes) >= 20:  # Minimum profitable episodes
                    logger.info(f"Triggering CNN training for {symbol} with {len(profitable_episodes)} profitable episodes")
                    results = self.cnn_trainer.train_on_profitable_episodes(
                        symbol=symbol,
                        min_profitability=self.config.min_profitability_for_replay,
                        max_episodes=len(profitable_episodes)
                    )
                    if results.get('status') == 'success':
                        self.integration_stats['cnn_training_sessions'] += 1
                        logger.info(f"CNN training completed for {symbol}")

            # Check RL training conditions
            buffer_stats = self.rl_trainer.experience_buffer.get_buffer_statistics()
            total_experiences = buffer_stats.get('total_experiences', 0)
            if total_experiences >= self.config.min_experiences_for_rl_training:
                profitable_experiences = buffer_stats.get('profitable_experiences', 0)
                if profitable_experiences >= 50:  # Minimum profitable experiences
                    logger.info(f"Triggering RL training with {profitable_experiences} profitable experiences")
                    results = self.rl_trainer.train_on_profitable_experiences(
                        min_profitability=self.config.min_profitability_for_replay,
                        max_experiences=min(profitable_experiences, 500),
                        batch_size=32
                    )
                    if results.get('status') == 'success':
                        self.integration_stats['rl_training_sessions'] += 1
                        logger.info("RL training completed")
        except Exception as e:
            logger.error(f"Error checking training conditions for {symbol}: {e}")

    def _outcome_validation_worker(self):
        """Thread body: validate matured predictions every five minutes."""
        logger.info("Outcome validation worker started")
        while self.is_running:
            try:
                self._validate_recent_predictions()
                self._wait_or_stop(300)  # Check every 5 minutes
            except Exception as e:
                logger.error(f"Error in outcome validation: {e}")
                self._wait_or_stop(60)
        logger.info("Outcome validation worker stopped")

    def _validate_recent_predictions(self):
        """Score predictions that are at least one hour old and drop them.

        Predictions whose outcome cannot be computed stay in
        ``recent_predictions`` and are retried on the next pass.
        """
        try:
            current_time = datetime.now()
            validation_delay = timedelta(hours=1)  # Wait 1 hour to validate
            validated_predictions = []

            # Iterate over a snapshot: the collection worker inserts new
            # entries concurrently, and iterating the live dict would raise
            # "dictionary changed size during iteration".
            for episode_id, prediction_data in list(self.recent_predictions.items()):
                if current_time - prediction_data['timestamp'] < validation_delay:
                    continue

                outcome = self._calculate_prediction_outcome(prediction_data)
                if not outcome:
                    continue

                self.prediction_outcomes[episode_id] = outcome

                # Update RL experience if exists
                if 'rl_action' in prediction_data['predictions']:
                    self._update_rl_experience_outcome(episode_id, outcome)

                # Update statistics
                if outcome['is_profitable']:
                    self.integration_stats['profitable_predictions'] += 1
                self.integration_stats['total_predictions'] += 1

                validated_predictions.append(episode_id)

            # Remove validated predictions
            for episode_id in validated_predictions:
                self.recent_predictions.pop(episode_id, None)

            if validated_predictions:
                logger.info(f"Validated {len(validated_predictions)} predictions")
        except Exception as e:
            logger.error(f"Error validating predictions: {e}")

    def _calculate_prediction_outcome(self, prediction_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Compare the price at prediction time with the latest 1m close.

        Returns:
            A dict with ``episode_id``, ``base_price``, ``current_price``,
            ``price_change`` (relative), ``is_profitable`` (absolute move
            above 0.5%), ``profitability_score`` (0-1) and
            ``validation_time``; ``None`` when prices are unavailable.
        """
        try:
            symbol = prediction_data['symbol']

            # Get price data after prediction
            current_df = self.data_provider.get_historical_data(symbol, '1m', limit=100, refresh=True)
            if current_df is None or current_df.empty:
                return None

            # Find price at prediction time and current price
            prediction_price = prediction_data['market_data']['ohlcv'].get('1m', pd.DataFrame())
            if prediction_price.empty:
                return None

            base_price = float(prediction_price['close'].iloc[-1])
            current_price = float(current_df['close'].iloc[-1])
            if base_price == 0:
                # A relative change is undefined for a zero base price.
                logger.debug(f"Skipping outcome for {symbol}: base price is zero")
                return None

            # Calculate outcome
            price_change = (current_price - base_price) / base_price
            magnitude = abs(price_change)
            return {
                'episode_id': prediction_data.get('episode_id'),
                'base_price': base_price,
                'current_price': current_price,
                'price_change': price_change,
                'is_profitable': magnitude > 0.005,  # 0.5% threshold
                # Scale to 0-1 range: a 10% move or larger saturates at 1.0
                'profitability_score': min(magnitude * 10, 1.0),
                'validation_time': datetime.now()
            }
        except Exception as e:
            logger.error(f"Error calculating prediction outcome: {e}")
            return None

    def _update_rl_experience_outcome(self, episode_id: str, outcome: Dict[str, Any]):
        """Derive the realized profit and hindsight action for an episode.

        Placeholder: the values are computed and logged, but no stored
        experience is modified because there is no episode-to-experience
        mapping yet.
        """
        try:
            actual_profit = outcome['price_change']

            # Determine optimal action based on outcome (1% threshold)
            if actual_profit > 0.01:
                optimal_action = 2  # BUY
            elif actual_profit < -0.01:
                optimal_action = 0  # SELL
            else:
                optimal_action = 1  # HOLD

            # TODO: maintain a mapping between episodes and experiences and
            # write these values back to the experience buffer.
            logger.debug(
                f"RL outcome for {episode_id}: profit={actual_profit}, "
                f"optimal_action={optimal_action} (not applied)"
            )
        except Exception as e:
            logger.error(f"Error updating RL experience outcome: {e}")

    def get_integration_statistics(self) -> Dict[str, Any]:
        """Return a snapshot of counters plus component statistics.

        The result extends a copy of ``integration_stats`` with collector and
        trainer statistics, running state, symbol and prediction counts, and
        ``overall_profitability_rate`` (0.0 when nothing was validated yet).
        """
        stats = self.integration_stats.copy()

        # Add component statistics
        stats['data_collector'] = self.data_collector.get_collection_statistics()
        stats['cnn_trainer'] = self.cnn_trainer.get_training_statistics()
        stats['rl_trainer'] = self.rl_trainer.get_training_statistics()

        # Add performance metrics
        stats['is_running'] = self.is_running
        stats['active_symbols'] = len(self.data_provider.symbols)
        stats['recent_predictions_count'] = len(self.recent_predictions)
        stats['validated_outcomes_count'] = len(self.prediction_outcomes)

        # Calculate profitability rate
        total = stats['total_predictions']
        stats['overall_profitability_rate'] = (
            stats['profitable_predictions'] / total if total > 0 else 0.0
        )
        return stats

    def trigger_manual_training(self, training_type: str = 'all', symbol: str = None) -> Dict[str, Any]:
        """Run CNN and/or RL training immediately.

        Args:
            training_type: ``'all'``, ``'cnn'`` or ``'rl'``.
            symbol: Restrict CNN training to this symbol; all symbols of the
                data provider when omitted.

        Returns:
            ``{'status': 'success', 'results': {...}}`` or
            ``{'status': 'error', 'error': message}``.
        """
        results = {}
        # Use the configured threshold so manual and automatic training agree.
        min_profitability = self.config.min_profitability_for_replay
        try:
            if training_type in ['all', 'cnn']:
                symbols = [symbol] if symbol else self.data_provider.symbols
                for sym in symbols:
                    results[f'cnn_{sym}'] = self.cnn_trainer.train_on_profitable_episodes(
                        symbol=sym,
                        min_profitability=min_profitability,
                        max_episodes=200
                    )

            if training_type in ['all', 'rl']:
                results['rl'] = self.rl_trainer.train_on_profitable_experiences(
                    min_profitability=min_profitability,
                    max_experiences=500,
                    batch_size=32
                )

            return {'status': 'success', 'results': results}
        except Exception as e:
            logger.error(f"Error in manual training trigger: {e}")
            return {'status': 'error', 'error': str(e)}

    # Helper methods (simplified implementations)

    def _get_recent_tick_data(self, symbol: str) -> List[Dict[str, Any]]:
        """Return recent tick data (stub: always an empty list)."""
        # Implementation would get tick data from data provider
        return []

    def _get_cob_data(self, symbol: str) -> Dict[str, Any]:
        """Return COB data (stub: always an empty dict)."""
        # Implementation would get COB data from data provider
        return {}

    def _get_technical_indicators(self, symbol: str) -> Dict[str, float]:
        """Return technical indicators (stub: always an empty dict)."""
        # Implementation would get indicators from data provider
        return {}

    def _get_pivot_points(self, symbol: str) -> List[Dict[str, Any]]:
        """Return pivot points (stub: always an empty list)."""
        # Implementation would get pivot points from data provider
        return []

    def _get_market_context(self, symbol: str) -> Dict[str, Any]:
        """Return a context dict with placeholder session/volatility values."""
        return {
            'symbol': symbol,
            'timestamp': datetime.now(),
            'market_session': 'unknown',
            'volatility_regime': 'unknown'
        }

    def _validate_market_data(self, market_data: Dict[str, Any]) -> bool:
        """Return True when the ``ohlcv`` and ``indicators`` keys are present."""
        required_fields = ['ohlcv', 'indicators']
        return all(field in market_data for field in required_fields)

    def _create_enhanced_cnn_features(self, symbol: str, market_data: Dict[str, Any]) -> Optional[np.ndarray]:
        """Build a flat float32 vector of length 3000 from recent OHLCV rows.

        For each of the 1m/5m/15m/1h timeframes that is present, the last 60
        rows of open/high/low/close/volume are appended row by row; the
        result is zero-padded or truncated to 3000 values. Returns ``None``
        when feature creation fails.
        """
        try:
            # Simplified feature creation
            features = []

            # Add OHLCV features
            ohlcv = market_data.get('ohlcv', {})
            for timeframe in ['1m', '5m', '15m', '1h']:
                if timeframe not in ohlcv:
                    continue
                df = ohlcv[timeframe]
                if df.empty:
                    continue
                ohlcv_values = df[['open', 'high', 'low', 'close', 'volume']].values
                if len(ohlcv_values) > 0:
                    features.extend(ohlcv_values[-60:].flatten())

            # Pad to target size
            target_size = 3000  # 10 channels * 300 sequence length
            if len(features) < target_size:
                features.extend([0.0] * (target_size - len(features)))
            else:
                features = features[:target_size]

            return np.array(features, dtype=np.float32)
        except Exception as e:
            logger.warning(f"Error creating CNN features: {e}")
            return None

    def _create_enhanced_rl_state(self, symbol: str, market_data: Dict[str, Any],
                                  predictions: Dict[str, Any] = None) -> Optional[np.ndarray]:
        """Build a flat float32 state vector of length 2000.

        The vector holds, in order: the latest 1m open/high/low/close/volume
        (when available), all indicator values, the CNN pivot logits and
        confidence, and the COB RL direction and confidence; it is then
        zero-padded or truncated to 2000 values. Returns ``None`` when state
        creation fails.
        """
        try:
            state_features = []

            # Add market features
            if '1m' in market_data.get('ohlcv', {}):
                df = market_data['ohlcv']['1m']
                if not df.empty:
                    latest = df.iloc[-1]
                    state_features.extend([
                        latest['open'], latest['high'],
                        latest['low'], latest['close'], latest['volume']
                    ])

            # Add technical indicators
            state_features.extend(market_data.get('indicators', {}).values())

            # Add model predictions as features
            if predictions:
                if 'cnn' in predictions:
                    cnn_pred = predictions['cnn']
                    # CNN outputs are stored as arrays that still carry a
                    # batch dimension; flatten them so every feature is a
                    # scalar, otherwise the final array conversion fails.
                    logits = np.ravel(cnn_pred.get('pivot_logits', [0, 0, 0]))
                    state_features.extend(float(v) for v in logits)
                    confidence = np.ravel(cnn_pred.get('confidence', [0.0]))
                    state_features.append(float(confidence[0]) if confidence.size else 0.0)
                if 'cob_rl' in predictions:
                    cob_pred = predictions['cob_rl']
                    state_features.append(cob_pred.get('predicted_direction', 1))
                    state_features.append(cob_pred.get('confidence', 0.5))

            # Pad to target size
            target_size = 2000
            if len(state_features) < target_size:
                state_features.extend([0.0] * (target_size - len(state_features)))
            else:
                state_features = state_features[:target_size]

            return np.array(state_features, dtype=np.float32)
        except Exception as e:
            logger.warning(f"Error creating RL state: {e}")
            return None

    def _get_orchestrator_prediction(self, symbol: str, market_data: Dict[str, Any],
                                     predictions: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Return the orchestrator's prediction (stub: always ``None``)."""
        # This would integrate with your orchestrator
        return None
# Global instance (created lazily by get_enhanced_training_integration)
enhanced_training_integration = None


def get_enhanced_training_integration(data_provider: "DataProvider" = None,
                                      orchestrator: "Orchestrator" = None,
                                      trading_executor: "TradingExecutor" = None,
                                      config: "EnhancedTrainingConfig" = None) -> "EnhancedTrainingIntegration":
    """Return the process-wide EnhancedTrainingIntegration, creating it once.

    The arguments are only used by the call that creates the instance; later
    calls return the existing instance and ignore them.

    Args:
        data_provider: Required for the first call.
        orchestrator: Optional orchestrator passed to the new instance.
        trading_executor: Optional executor passed to the new instance.
        config: Optional configuration passed to the new instance; the
            class falls back to its default configuration when omitted.

    Raises:
        ValueError: If no instance exists yet and ``data_provider`` is None.
    """
    global enhanced_training_integration
    if enhanced_training_integration is None:
        if data_provider is None:
            raise ValueError("DataProvider required for first initialization")
        enhanced_training_integration = EnhancedTrainingIntegration(
            data_provider, orchestrator, trading_executor, config
        )
    return enhanced_training_integration