"""
|
|
Training Integration Module
|
|
|
|
This module integrates the comprehensive training data collection system
|
|
with the existing data provider and model infrastructure. It provides:
|
|
|
|
1. Real-time data collection from DataProvider
|
|
2. Integration with existing CNN and RL models
|
|
3. Automatic training data package creation
|
|
4. Rapid price change detection and collection
|
|
5. Training pipeline coordination
|
|
|
|
Key Features:
|
|
- Seamless integration with existing DataProvider
|
|
- Automatic model input package creation
|
|
- Real-time training data validation
|
|
- Coordinated training across all models
|
|
- Performance monitoring and optimization
|
|
"""

import logging
import threading
import time
from collections import deque
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional

import numpy as np
import pandas as pd

from .training_data_collector import (
    TrainingDataCollector,
    ModelInputPackage,
    get_training_data_collector
)
from .cnn_training_pipeline import (
    CNNTrainer,
    CNNPivotPredictor,
    get_cnn_trainer
)
from .data_provider import DataProvider

logger = logging.getLogger(__name__)


@dataclass
class TrainingIntegrationConfig:
    """Configuration for training integration"""
    # Data collection settings
    collection_interval: float = 1.0  # seconds
    min_data_completeness: float = 0.7

    # Rapid change detection
    enable_rapid_change_detection: bool = True
    price_change_threshold: float = 0.5  # % change per minute

    # Training settings
    enable_real_time_training: bool = True
    training_batch_size: int = 32
    min_episodes_for_training: int = 50

    # Performance settings
    max_concurrent_collections: int = 4
    data_validation_enabled: bool = True
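

# Example (illustrative only): overriding the defaults for a faster
# collection loop. The values below are assumptions, not tuned settings.
#
#     config = TrainingIntegrationConfig(
#         collection_interval=0.5,       # poll twice per second
#         min_data_completeness=0.8,     # demand more complete data
#         training_batch_size=64,
#     )
#     integration = TrainingIntegration(data_provider, config)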


class TrainingIntegration:
    """Main integration class for training data collection and model training"""

    def __init__(self,
                 data_provider: DataProvider,
                 config: Optional[TrainingIntegrationConfig] = None):
        self.data_provider = data_provider
        self.config = config or TrainingIntegrationConfig()

        # Get training components
        self.data_collector = get_training_data_collector()

        # Initialize CNN components
        self.cnn_model = CNNPivotPredictor()
        self.cnn_trainer = get_cnn_trainer(self.cnn_model)

        # Integration state
        self.is_running = False
        self.collection_thread = None
        self.training_threads = {}

        # Data buffers for real-time processing
        self.data_buffers = {}
        self.last_collection_time = {}

        # Performance tracking
        self.integration_stats = {
            'data_packages_created': 0,
            'training_sessions_triggered': 0,
            'rapid_changes_detected': 0,
            'validation_failures': 0,
            'average_collection_time': 0.0,
            'last_update': datetime.now()
        }

        # Initialize data buffers for each symbol
        for symbol in self.data_provider.symbols:
            self.data_buffers[symbol] = {
                'ohlcv_data': {},
                'tick_data': deque(maxlen=1000),
                'cob_data': {},
                'technical_indicators': {},
                'pivot_points': []
            }
            self.last_collection_time[symbol] = datetime.now()

        logger.info("Training Integration initialized")
        logger.info(f"Symbols: {self.data_provider.symbols}")
        logger.info(f"Real-time training: {self.config.enable_real_time_training}")
        logger.info(f"Rapid change detection: {self.config.enable_rapid_change_detection}")

    def start_integration(self):
        """Start the training integration system"""
        if self.is_running:
            logger.warning("Training integration already running")
            return

        self.is_running = True

        # Start data collection
        self.data_collector.start_collection()

        # Start the real-time data collection thread
        self.collection_thread = threading.Thread(
            target=self._data_collection_worker,
            daemon=True
        )
        self.collection_thread.start()

        # Start a training thread for each symbol
        if self.config.enable_real_time_training:
            for symbol in self.data_provider.symbols:
                training_thread = threading.Thread(
                    target=self._training_worker,
                    args=(symbol,),
                    daemon=True
                )
                self.training_threads[symbol] = training_thread
                training_thread.start()

        logger.info("Training integration started")

    def stop_integration(self):
        """Stop the training integration system"""
        self.is_running = False

        # Stop data collection
        self.data_collector.stop_collection()

        # Stop CNN training
        self.cnn_trainer.stop_training()

        # Wait for worker threads to finish
        if self.collection_thread:
            self.collection_thread.join(timeout=10)

        for thread in self.training_threads.values():
            thread.join(timeout=5)

        logger.info("Training integration stopped")

    def _data_collection_worker(self):
        """Main data collection worker"""
        logger.info("Data collection worker started")

        while self.is_running:
            try:
                start_time = time.time()

                # Collect data for each symbol
                for symbol in self.data_provider.symbols:
                    self._collect_symbol_data(symbol)

                # Update performance stats
                collection_time = time.time() - start_time
                self._update_collection_stats(collection_time)

                # Wait for the next collection cycle
                time.sleep(self.config.collection_interval)

            except Exception as e:
                logger.error(f"Error in data collection worker: {e}")
                time.sleep(5)  # Wait before retrying

        logger.info("Data collection worker stopped")

    def _collect_symbol_data(self, symbol: str):
        """Collect comprehensive training data for a symbol"""
        try:
            # Get current market data from the data provider
            ohlcv_data = self._get_ohlcv_data(symbol)
            tick_data = self._get_tick_data(symbol)
            cob_data = self._get_cob_data(symbol)
            technical_indicators = self._get_technical_indicators(symbol)
            pivot_points = self._get_pivot_points(symbol)

            # Validate data availability
            if not self._validate_data_availability(symbol, ohlcv_data, tick_data):
                return

            # Create model input features
            cnn_features = self._create_cnn_features(symbol, ohlcv_data, technical_indicators)
            rl_state = self._create_rl_state(symbol, ohlcv_data, cob_data, technical_indicators)
            orchestrator_context = self._create_orchestrator_context(symbol)

            # Get model predictions if available
            model_predictions = self._get_current_model_predictions(symbol)

            # Collect the training data package
            episode_id = self.data_collector.collect_training_data(
                symbol=symbol,
                ohlcv_data=ohlcv_data,
                tick_data=tick_data,
                cob_data=cob_data,
                technical_indicators=technical_indicators,
                pivot_points=pivot_points,
                cnn_features=cnn_features,
                rl_state=rl_state,
                orchestrator_context=orchestrator_context,
                model_predictions=model_predictions
            )

            if episode_id:
                self.integration_stats['data_packages_created'] += 1
                logger.debug(f"Created training data package for {symbol}: {episode_id}")

        except Exception as e:
            logger.error(f"Error collecting data for {symbol}: {e}")
            self.integration_stats['validation_failures'] += 1

    def _get_ohlcv_data(self, symbol: str) -> Dict[str, pd.DataFrame]:
        """Get OHLCV data for all timeframes"""
        ohlcv_data = {}

        try:
            for timeframe in ['1s', '1m', '5m', '15m', '1h', '1d']:
                df = self.data_provider.get_historical_data(
                    symbol=symbol,
                    timeframe=timeframe,
                    limit=300,    # 300 bars, as specified in the requirements
                    refresh=True  # Get fresh data
                )

                if df is not None and not df.empty:
                    ohlcv_data[timeframe] = df

            return ohlcv_data

        except Exception as e:
            logger.warning(f"Error getting OHLCV data for {symbol}: {e}")
            return {}

    def _get_tick_data(self, symbol: str) -> List[Dict[str, Any]]:
        """Get recent tick data"""
        try:
            # Get tick data from the data provider's tick buffers
            binance_symbol = symbol.replace('/', '').upper()

            if binance_symbol in self.data_provider.tick_buffers:
                # Keep only the last 300 seconds of tick data
                cutoff_time = datetime.now() - timedelta(seconds=300)

                tick_buffer = self.data_provider.tick_buffers[binance_symbol]
                recent_ticks = []

                # Convert the deque to a list and filter by time
                for tick in list(tick_buffer):
                    if hasattr(tick, 'timestamp') and tick.timestamp >= cutoff_time:
                        recent_ticks.append({
                            'timestamp': tick.timestamp,
                            'price': tick.price,
                            'volume': tick.volume,
                            'side': tick.side,
                            'trade_id': tick.trade_id
                        })

                return recent_ticks

            return []

        except Exception as e:
            logger.warning(f"Error getting tick data for {symbol}: {e}")
            return []

    def _get_cob_data(self, symbol: str) -> Dict[str, Any]:
        """Get Consolidated Order Book (COB) data"""
        try:
            # Get COB data from the data provider's COB cache
            binance_symbol = symbol.replace('/', '').upper()

            if binance_symbol in self.data_provider.cob_data_cache:
                cob_buffer = self.data_provider.cob_data_cache[binance_symbol]

                if cob_buffer:
                    # Get the most recent COB snapshot; entries are stored
                    # either as (timestamp, features) tuples or as bare
                    # feature arrays
                    latest_cob = list(cob_buffer)[-1]
                    return {
                        'timestamp': latest_cob[0] if isinstance(latest_cob, tuple) else datetime.now(),
                        'cob_features': latest_cob[1] if isinstance(latest_cob, tuple) else latest_cob,
                        'feature_count': len(latest_cob[1]) if isinstance(latest_cob, tuple) else 0
                    }

            return {}

        except Exception as e:
            logger.warning(f"Error getting COB data for {symbol}: {e}")
            return {}

    def _get_technical_indicators(self, symbol: str) -> Dict[str, float]:
        """Get technical indicators from OHLCV data"""
        try:
            # Get the most recent 1m data with indicators
            df = self.data_provider.get_historical_data(
                symbol=symbol,
                timeframe='1m',
                limit=50,
                refresh=True
            )

            if df is not None and not df.empty:
                # Extract indicators from the latest row; every non-OHLCV
                # column is treated as an indicator
                latest_row = df.iloc[-1]
                indicators = {}

                for col in df.columns:
                    if col not in ['open', 'high', 'low', 'close', 'volume', 'timestamp']:
                        try:
                            value = float(latest_row[col])
                            if not np.isnan(value):
                                indicators[col] = value
                        except (ValueError, TypeError):
                            continue

                return indicators

            return {}

        except Exception as e:
            logger.warning(f"Error getting technical indicators for {symbol}: {e}")
            return {}

    def _get_pivot_points(self, symbol: str) -> List[Dict[str, Any]]:
        """Get recent pivot points"""
        try:
            # Get pivot points from the Williams Market Structure.
            # NOTE: the integration with the Williams structure is not
            # implemented yet, so this currently returns an empty list
            # as a placeholder.
            if symbol in self.data_provider.williams_structure:
                pivot_points: List[Dict[str, Any]] = []
                return pivot_points

            return []

        except Exception as e:
            logger.warning(f"Error getting pivot points for {symbol}: {e}")
            return []

    def _create_cnn_features(self,
                             symbol: str,
                             ohlcv_data: Dict[str, pd.DataFrame],
                             technical_indicators: Dict[str, float]) -> np.ndarray:
        """Create CNN input features from market data"""
        try:
            # This is a simplified feature creation; in practice you would
            # build richer multi-timeframe features.
            features = []

            # Add OHLCV features from multiple timeframes
            for timeframe in ['1s', '1m', '5m', '15m', '1h']:
                if timeframe in ohlcv_data:
                    df = ohlcv_data[timeframe]
                    if not df.empty:
                        # Raw OHLCV values are used as-is; no normalization
                        # is applied yet
                        ohlcv_values = df[['open', 'high', 'low', 'close', 'volume']].values
                        if len(ohlcv_values) > 0:
                            # Take the last 60 bars and flatten them
                            recent_values = ohlcv_values[-60:].flatten()
                            features.extend(recent_values)

            # Add technical indicators
            features.extend(technical_indicators.values())

            # Pad or truncate to a fixed size
            target_size = 2000  # Match the CNN input size
            if len(features) < target_size:
                features.extend([0.0] * (target_size - len(features)))
            else:
                features = features[:target_size]

            return np.array(features, dtype=np.float32)

        except Exception as e:
            logger.warning(f"Error creating CNN features for {symbol}: {e}")
            return np.zeros(2000, dtype=np.float32)
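
    # Feature budget for _create_cnn_features (worked example, assuming all
    # five timeframes are available): 5 timeframes * 60 bars * 5 OHLCV fields
    # = 1500 values, plus however many indicator columns exist; the vector is
    # then zero-padded (or truncated) to the fixed 2000-element CNN input.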

    def _create_rl_state(self,
                         symbol: str,
                         ohlcv_data: Dict[str, pd.DataFrame],
                         cob_data: Dict[str, Any],
                         technical_indicators: Dict[str, float]) -> np.ndarray:
        """Create the RL state representation"""
        try:
            state_features = []

            # Add market state features from the latest 1m candle
            if '1m' in ohlcv_data and not ohlcv_data['1m'].empty:
                latest_candle = ohlcv_data['1m'].iloc[-1]
                state_features.extend([
                    latest_candle['open'],
                    latest_candle['high'],
                    latest_candle['low'],
                    latest_candle['close'],
                    latest_candle['volume']
                ])

            # Add COB features
            if 'cob_features' in cob_data:
                cob_features = cob_data['cob_features']
                if isinstance(cob_features, (list, np.ndarray)):
                    state_features.extend(cob_features[:100])  # Limit COB features

            # Add technical indicators
            state_features.extend(technical_indicators.values())

            # Pad or truncate to a fixed size
            target_size = 2000  # Match the RL input size
            if len(state_features) < target_size:
                state_features.extend([0.0] * (target_size - len(state_features)))
            else:
                state_features = state_features[:target_size]

            return np.array(state_features, dtype=np.float32)

        except Exception as e:
            logger.warning(f"Error creating RL state for {symbol}: {e}")
            return np.zeros(2000, dtype=np.float32)

    def _create_orchestrator_context(self, symbol: str) -> Dict[str, Any]:
        """Create orchestrator context"""
        try:
            return {
                'symbol': symbol,
                'timestamp': datetime.now(),
                'market_session': self._determine_market_session(),
                'volatility_regime': self._determine_volatility_regime(symbol),
                'trend_direction': self._determine_trend_direction(symbol)
            }
        except Exception as e:
            logger.warning(f"Error creating orchestrator context for {symbol}: {e}")
            return {'symbol': symbol, 'timestamp': datetime.now()}

    def _determine_market_session(self) -> str:
        """Determine the current market session"""
        # Simplified session detection; note this uses the host's local
        # clock, not UTC
        current_hour = datetime.now().hour

        if 0 <= current_hour < 8:
            return 'asian'
        elif 8 <= current_hour < 16:
            return 'european'
        else:
            return 'american'

    def _determine_volatility_regime(self, symbol: str) -> str:
        """Determine the volatility regime for a symbol"""
        try:
            # Standard deviation of 1m returns over the last 100 bars
            df = self.data_provider.get_historical_data(symbol, '1m', limit=100)
            if df is not None and not df.empty:
                returns = df['close'].pct_change().dropna()
                volatility = returns.std()

                if volatility > 0.02:
                    return 'high'
                elif volatility > 0.01:
                    return 'medium'
                else:
                    return 'low'

            return 'unknown'
        except Exception:
            return 'unknown'

    def _determine_trend_direction(self, symbol: str) -> str:
        """Determine the trend direction for a symbol"""
        try:
            # Simple trend detection using a moving-average crossover
            df = self.data_provider.get_historical_data(symbol, '1h', limit=50)
            if df is not None and not df.empty:
                if 'sma_20' in df.columns and 'sma_50' in df.columns:
                    latest_sma20 = df['sma_20'].iloc[-1]
                    latest_sma50 = df['sma_50'].iloc[-1]

                    if latest_sma20 > latest_sma50:
                        return 'uptrend'
                    elif latest_sma20 < latest_sma50:
                        return 'downtrend'
                    else:
                        return 'sideways'

            return 'unknown'
        except Exception:
            return 'unknown'

    def _get_current_model_predictions(self, symbol: str) -> Dict[str, Any]:
        """Get current predictions from all models"""
        try:
            # Placeholder: this will integrate with the existing model
            # predictions; until then it returns an empty dict.
            return {}
        except Exception as e:
            logger.warning(f"Error getting model predictions for {symbol}: {e}")
            return {}

    def _validate_data_availability(self,
                                    symbol: str,
                                    ohlcv_data: Dict[str, pd.DataFrame],
                                    tick_data: List[Dict[str, Any]]) -> bool:
        """Validate that sufficient data is available for training"""
        try:
            # Check OHLCV data availability
            required_timeframes = ['1m', '5m', '1h']
            available_timeframes = 0

            for timeframe in required_timeframes:
                if timeframe in ohlcv_data and not ohlcv_data[timeframe].empty:
                    available_timeframes += 1

            # Require at least two of the three timeframes
            if available_timeframes < 2:
                return False

            # Tick data is optional but preferred
            has_tick_data = len(tick_data) > 0

            # Calculate a completeness score
            completeness = available_timeframes / len(required_timeframes)
            if has_tick_data:
                completeness += 0.1  # Bonus for tick data

            return completeness >= self.config.min_data_completeness

        except Exception as e:
            logger.warning(f"Error validating data availability for {symbol}: {e}")
            return False
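
    # Worked example of the completeness check above, with the default
    # min_data_completeness of 0.7: two of three required timeframes give
    # 2/3 ~ 0.67, which fails on its own, but the 0.1 tick-data bonus lifts
    # it to ~0.77 and the check passes. All three timeframes pass outright.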

    def _training_worker(self, symbol: str):
        """Training worker for a specific symbol"""
        logger.info(f"Training worker started for {symbol}")

        while self.is_running:
            try:
                # Check whether we have enough episodes for training
                episodes = self.data_collector.get_high_priority_episodes(
                    symbol=symbol,
                    limit=self.config.training_batch_size * 2,
                    min_priority=0.3
                )

                if len(episodes) >= self.config.min_episodes_for_training:
                    # Trigger CNN training
                    results = self.cnn_trainer.train_on_profitable_episodes(
                        symbol=symbol,
                        min_profitability=0.6,
                        max_episodes=len(episodes)
                    )

                    if results.get('status') == 'success':
                        self.integration_stats['training_sessions_triggered'] += 1
                        logger.info(f"Training session completed for {symbol}")

                # Wait before the next training check
                time.sleep(300)  # Check every 5 minutes

            except Exception as e:
                logger.error(f"Error in training worker for {symbol}: {e}")
                time.sleep(60)  # Wait before retrying

        logger.info(f"Training worker stopped for {symbol}")

    def _update_collection_stats(self, collection_time: float):
        """Update collection performance statistics"""
        try:
            # Update the average collection time with an exponential moving
            # average: avg = alpha * new + (1 - alpha) * avg
            alpha = 0.1  # EMA smoothing factor
            if self.integration_stats['average_collection_time'] == 0:
                self.integration_stats['average_collection_time'] = collection_time
            else:
                self.integration_stats['average_collection_time'] = (
                    alpha * collection_time +
                    (1 - alpha) * self.integration_stats['average_collection_time']
                )

            self.integration_stats['last_update'] = datetime.now()

        except Exception as e:
            logger.warning(f"Error updating collection stats: {e}")
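
    # Worked example of the EMA update above: with a running average of
    # 0.50s and a new collection time of 1.00s, the new average is
    # 0.1 * 1.00 + 0.9 * 0.50 = 0.55s, so one slow cycle only nudges it.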

    def get_integration_statistics(self) -> Dict[str, Any]:
        """Get comprehensive integration statistics"""
        stats = self.integration_stats.copy()

        # Add data collector statistics
        collector_stats = self.data_collector.get_collection_statistics()
        stats.update(collector_stats)

        # Add CNN trainer statistics
        trainer_stats = self.cnn_trainer.get_training_statistics()
        stats['cnn_training'] = trainer_stats

        # Add performance metrics
        stats['is_running'] = self.is_running
        stats['active_symbols'] = len(self.data_provider.symbols)
        stats['collection_frequency'] = self.config.collection_interval

        return stats

    def trigger_manual_training(self, symbol: str, training_type: str = 'profitable') -> Dict[str, Any]:
        """Manually trigger training for a symbol"""
        try:
            if training_type == 'profitable':
                results = self.cnn_trainer.train_on_profitable_episodes(
                    symbol=symbol,
                    min_profitability=0.7,
                    max_episodes=200
                )
            elif training_type == 'high_value_replay':
                results = self.cnn_trainer.replay_high_value_sessions(
                    symbol=symbol,
                    min_session_value=0.8,
                    max_sessions=10
                )
            else:
                return {'status': 'error', 'error': f'Unknown training type: {training_type}'}

            if results.get('status') == 'success':
                self.integration_stats['training_sessions_triggered'] += 1

            return results

        except Exception as e:
            logger.error(f"Error in manual training trigger: {e}")
            return {'status': 'error', 'error': str(e)}
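
    # Example call (the symbol is illustrative, not a required value):
    #     integration.trigger_manual_training('ETH/USDT', 'high_value_replay')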


# Global instance
training_integration: Optional[TrainingIntegration] = None


def get_training_integration(data_provider: Optional[DataProvider] = None) -> TrainingIntegration:
    """Get the global training integration instance"""
    global training_integration
    if training_integration is None:
        if data_provider is None:
            raise ValueError("DataProvider required for first initialization")
        training_integration = TrainingIntegration(data_provider)
    return training_integration
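

# Minimal usage sketch (illustrative): how this module is expected to be
# wired up. Because the module uses relative imports, run it with
# `python -m <package>.training_integration` from the package root. The
# DataProvider default construction below is an assumption; see
# data_provider.py for the real signature.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    provider = DataProvider()  # assumed default construction
    integration = get_training_integration(provider)
    integration.start_integration()
    try:
        while True:
            time.sleep(60)
            stats = integration.get_integration_statistics()
            logger.info(f"Packages created: {stats['data_packages_created']}")
    except KeyboardInterrupt:
        integration.stop_integration()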