training wip

This commit is contained in:
Dobromir Popov
2025-09-02 19:25:13 +03:00
parent 6dcb82c184
commit 226a6aa047
8 changed files with 241 additions and 771 deletions

View File

@@ -1757,16 +1757,27 @@ class TradingOrchestrator:
logger.warning("EnhancedRealtimeTrainingSystem not available - training disabled")
self.training_enabled = False
return
# Initialize unified training manager
from utils.training_integration import get_unified_training_manager
self.training_manager = get_unified_training_manager(
orchestrator=self,
data_provider=self.data_provider,
dashboard=None
)
self.training_manager.initialize()
# Keep backward-compatible attribute
self.enhanced_training_system = getattr(self.training_manager, 'training_system', None)
# Initialize enhanced training system directly (no external training_integration module needed)
try:
from NN.training.enhanced_realtime_training import EnhancedRealtimeTrainingSystem
self.enhanced_training_system = EnhancedRealtimeTrainingSystem(
orchestrator=self,
data_provider=self.data_provider,
dashboard=None
)
logger.info("✅ Enhanced training system initialized successfully")
# Auto-start training by default
logger.info("🚀 Auto-starting enhanced real-time training...")
self.start_enhanced_training()
except ImportError as e:
logger.error(f"Failed to import EnhancedRealtimeTrainingSystem: {e}")
self.training_enabled = False
return
logger.info("Enhanced real-time training system initialized")
logger.info(" - Real-time model training: ENABLED")
@@ -2365,8 +2376,8 @@ class TradingOrchestrator:
logger.info("Initializing ExtremaTrainer with historical context...")
self.extrema_trainer.initialize_context_data()
# CRITICAL: Initialize ALL models with historical data
self._initialize_models_with_historical_data(loaded_data)
# CRITICAL: Initialize ALL models with historical data (using data provider's normalized methods)
self._initialize_models_with_historical_data(symbols_timeframes)
logger.info(f"🎯 Historical data loading complete: {total_candles} total candles loaded")
logger.info(f"📊 Available datasets: {list(loaded_data.keys())}")
@@ -2374,144 +2385,60 @@ class TradingOrchestrator:
except Exception as e:
logger.error(f"Error in historical data loading: {e}")
def _initialize_models_with_historical_data(self, symbols_timeframes: List[Tuple[str, str]]):
    """Seed every available model with normalized history from the data provider.

    A model is initialized only when its attribute (``cnn_model``,
    ``rl_agent``, ``transformer_model``, ``decision_fusion``) exists on the
    orchestrator and is truthy. Errors are logged and never propagated.

    Args:
        symbols_timeframes: ``(symbol, timeframe)`` pairs forwarded to the
            data provider and to the DQN / Transformer initializers.
    """
    try:
        logger.info("Initializing models with normalized historical data from data provider...")

        # Use data provider's multi-symbol feature preparation.
        # Only the Decision Fusion step below consumes this result.
        symbol_features = self.data_provider.get_multi_symbol_features_for_inference(
            symbols_timeframes, limit=300
        )

        # Initialize CNN with multi-symbol data
        if getattr(self, 'cnn_model', None):
            logger.info("Initializing CNN with multi-symbol historical features...")
            self._initialize_cnn_with_provider_data()

        # Initialize DQN with multi-symbol states
        if getattr(self, 'rl_agent', None):
            logger.info("Initializing DQN with multi-symbol state vectors...")
            self._initialize_dqn_with_provider_data(symbols_timeframes)

        # Initialize Transformer with sequence data
        if getattr(self, 'transformer_model', None):
            logger.info("Initializing Transformer with multi-symbol sequences...")
            self._initialize_transformer_with_provider_data(symbols_timeframes)

        # Initialize Decision Fusion with comprehensive features
        if getattr(self, 'decision_fusion', None):
            logger.info("Initializing Decision Fusion with multi-symbol features...")
            self._initialize_decision_with_provider_data(symbol_features)

        logger.info("✅ All models initialized with data provider's normalized historical data")

    except Exception as e:
        # Best-effort: historical seeding must not abort the caller.
        logger.error(f"Error initializing models with historical data: {e}")
def _prepare_multi_symbol_features(self, loaded_data: Dict[str, Any]) -> Dict[str, Any]:
    """Build a ``{symbol: {timeframe: DataFrame}}`` map of normalized candles.

    Args:
        loaded_data: Mapping of ``"<symbol>_<timeframe>"`` keys to DataFrames.
            Entries that are ``None`` or empty, or whose key does not end in
            ``_1m``, ``_1h`` or ``_1d``, are ignored.

    Returns:
        Nested dict keyed by symbol, then timeframe. The ``ETH/USDT``
        (1m/1h/1d) and ``BTC/USDT`` (1m) slots are always present and remain
        ``None`` when no data was supplied for them. Any other symbol or
        timeframe found in ``loaded_data`` is added on demand. Returns ``{}``
        if an unexpected error occurs.
    """
    supported_timeframes = ('1m', '1h', '1d')
    try:
        symbol_features = {
            'ETH/USDT': {'1m': None, '1h': None, '1d': None},
            'BTC/USDT': {'1m': None},
        }

        # Process each symbol's data with symbol-specific normalization
        for data_key, df in loaded_data.items():
            if df is None or df.empty:
                continue

            # Split on the LAST underscore so only the trailing suffix is
            # interpreted as the timeframe.
            symbol, _, timeframe = data_key.rpartition('_')
            if not symbol or timeframe not in supported_timeframes:
                continue

            # Apply symbol-grouped normalization
            normalized_df = self._apply_symbol_grouped_normalization(df, symbol)
            if normalized_df is not None:
                # setdefault: a symbol/timeframe that was not pre-seeded must
                # not raise KeyError and wipe out everything prepared so far.
                symbol_features.setdefault(symbol, {})[timeframe] = normalized_df
                logger.debug(f"Prepared normalized features for {symbol} {timeframe}")

        return symbol_features

    except Exception as e:
        logger.error(f"Error preparing multi-symbol features: {e}")
        return {}
def _apply_symbol_grouped_normalization(self, df: pd.DataFrame, symbol: str) -> pd.DataFrame:
    """Scale one symbol's OHLCV frame using fixed per-symbol price bounds.

    For symbols with a configured range, the ``open``/``high``/``low``/``close``
    columns are min-max scaled against that range and clipped to [0, 1], and
    ``volume`` is ``log1p``-transformed and divided by its own maximum (when
    that maximum is positive). Remaining NaN values are replaced with 0.
    The input frame is copied, not modified.

    Args:
        df: Candle data; columns that are absent are skipped.
        symbol: Key into the built-in price-range table.

    Returns:
        The normalized copy, or the original ``df`` if any error occurs.
    """
    # NOTE(review): nesting was reconstructed from a flattened source; the
    # volume scaling and the debug log are assumed to live inside the
    # known-symbol branch — confirm against the real file.
    try:
        normalized = df.copy()

        # Symbol-specific price ranges keep scaling consistent across timeframes
        bounds_by_symbol = {
            'ETH/USDT': {'min': 1000, 'max': 5000},      # ETH price range
            'BTC/USDT': {'min': 90000, 'max': 120000},   # BTC price range
        }

        bounds = bounds_by_symbol.get(symbol)
        if bounds is not None:
            floor = bounds['min']
            span = bounds['max'] - floor

            # Prices -> [0, 1] relative to the symbol's range
            for column in ('open', 'high', 'low', 'close'):
                if column not in normalized.columns:
                    continue
                scaled = (normalized[column] - floor) / span
                normalized[column] = np.clip(scaled, 0, 1)

            # Volume -> [0, 1] on a log scale
            if 'volume' in normalized.columns:
                normalized['volume'] = np.log1p(normalized['volume'])
                peak = normalized['volume'].max()
                if peak > 0:
                    normalized['volume'] = normalized['volume'] / peak

            logger.debug(f"Applied symbol-grouped normalization for {symbol}")

        # Fill any NaN values
        return normalized.fillna(0)

    except Exception as e:
        logger.error(f"Error in symbol-grouped normalization for {symbol}: {e}")
        return df
def _initialize_cnn_with_data(self, symbol_features: Dict[str, Any]):
"""Initialize CNN with multi-symbol feature matrix"""
def _initialize_cnn_with_provider_data(self):
"""Initialize CNN using data provider's normalized feature extraction"""
try:
# Create combined feature matrix: [ETH_1m, ETH_1h, ETH_1d, BTC_1m]
combined_features = []
# ETH features (1m, 1h, 1d)
for timeframe in ['1m', '1h', '1d']:
eth_data = symbol_features.get('ETH/USDT', {}).get(timeframe)
if eth_data is not None and not eth_data.empty:
# Use last 60 candles for CNN input
recent_data = eth_data.tail(60)
features = recent_data[['open', 'high', 'low', 'close', 'volume']].values
combined_features.append(features.flatten())
features = self.data_provider.get_cnn_features_for_inference('ETH/USDT', timeframe, window_size=60)
if features is not None:
combined_features.append(features)
# BTC features (1m)
btc_data = symbol_features.get('BTC/USDT', {}).get('1m')
if btc_data is not None and not btc_data.empty:
recent_data = btc_data.tail(60)
features = recent_data[['open', 'high', 'low', 'close', 'volume']].values
combined_features.append(features.flatten())
btc_features = self.data_provider.get_cnn_features_for_inference('BTC/USDT', '1m', window_size=60)
if btc_features is not None:
combined_features.append(btc_features)
if combined_features:
# Concatenate all features
full_features = np.concatenate(combined_features)
logger.info(f"CNN initialized with {len(full_features)} multi-symbol features")
logger.info(f"CNN initialized with {len(full_features)} multi-symbol normalized features")
# Store for model access
if not hasattr(self, 'model_historical_features'):
@@ -2519,39 +2446,16 @@ class TradingOrchestrator:
self.model_historical_features['cnn'] = full_features
except Exception as e:
logger.error(f"Error initializing CNN with historical data: {e}")
logger.error(f"Error initializing CNN with provider data: {e}")
def _initialize_dqn_with_data(self, symbol_features: Dict[str, Any]):
"""Initialize DQN with multi-symbol state vectors"""
def _initialize_dqn_with_provider_data(self, symbols_timeframes: List[Tuple[str, str]]):
"""Initialize DQN using data provider's normalized state vector creation"""
try:
# Create comprehensive state vector combining all symbols and timeframes
state_components = []
# Use data provider's DQN state creation
state_vector = self.data_provider.get_dqn_state_for_inference(symbols_timeframes, target_size=100)
for symbol in ['ETH/USDT', 'BTC/USDT']:
timeframes = ['1m', '1h', '1d'] if symbol == 'ETH/USDT' else ['1m']
for timeframe in timeframes:
data = symbol_features.get(symbol, {}).get(timeframe)
if data is not None and not data.empty:
# Extract key features for state
latest = data.iloc[-1]
state_features = [
latest['close'], # Current price
latest['volume'], # Current volume
data['close'].pct_change().iloc[-1] if len(data) > 1 else 0, # Price change
]
state_components.extend(state_features)
if state_components:
# Pad or truncate to expected DQN state size
target_size = 100 # DQN expects 100-dimensional state
if len(state_components) < target_size:
state_components.extend([0] * (target_size - len(state_components)))
else:
state_components = state_components[:target_size]
state_vector = np.array(state_components, dtype=np.float32)
logger.info(f"DQN initialized with {len(state_vector)} dimensional multi-symbol state")
if state_vector is not None:
logger.info(f"DQN initialized with {len(state_vector)} dimensional normalized multi-symbol state")
# Store for model access
if not hasattr(self, 'model_historical_features'):
@@ -2559,30 +2463,16 @@ class TradingOrchestrator:
self.model_historical_features['dqn'] = state_vector
except Exception as e:
logger.error(f"Error initializing DQN with historical data: {e}")
logger.error(f"Error initializing DQN with provider data: {e}")
def _initialize_transformer_with_data(self, symbol_features: Dict[str, Any]):
"""Initialize Transformer with multi-symbol sequence data"""
def _initialize_transformer_with_provider_data(self, symbols_timeframes: List[Tuple[str, str]]):
"""Initialize Transformer using data provider's normalized sequence creation"""
try:
# Prepare sequence data for transformer
sequences = []
# ETH sequences
for timeframe in ['1m', '1h', '1d']:
eth_data = symbol_features.get('ETH/USDT', {}).get(timeframe)
if eth_data is not None and not eth_data.empty:
# Use last 150 points as sequence
sequence = eth_data.tail(150)[['open', 'high', 'low', 'close', 'volume']].values
sequences.append(sequence)
# BTC sequence
btc_data = symbol_features.get('BTC/USDT', {}).get('1m')
if btc_data is not None and not btc_data.empty:
sequence = btc_data.tail(150)[['open', 'high', 'low', 'close', 'volume']].values
sequences.append(sequence)
# Use data provider's transformer sequence creation
sequences = self.data_provider.get_transformer_sequences_for_inference(symbols_timeframes, seq_length=150)
if sequences:
logger.info(f"Transformer initialized with {len(sequences)} multi-symbol sequences")
logger.info(f"Transformer initialized with {len(sequences)} normalized multi-symbol sequences")
# Store for model access
if not hasattr(self, 'model_historical_features'):
@@ -2590,10 +2480,10 @@ class TradingOrchestrator:
self.model_historical_features['transformer'] = sequences
except Exception as e:
logger.error(f"Error initializing Transformer with historical data: {e}")
logger.error(f"Error initializing Transformer with provider data: {e}")
def _initialize_decision_with_data(self, symbol_features: Dict[str, Any]):
"""Initialize Decision Fusion with comprehensive multi-symbol features"""
def _initialize_decision_with_provider_data(self, symbol_features: Dict[str, Dict[str, pd.DataFrame]]):
"""Initialize Decision Fusion using data provider's feature aggregation"""
try:
# Aggregate all available features for decision fusion
all_features = {}
@@ -2611,7 +2501,7 @@ class TradingOrchestrator:
}
if all_features:
logger.info(f"Decision Fusion initialized with {len(all_features)} symbol-timeframe combinations")
logger.info(f"Decision Fusion initialized with {len(all_features)} normalized symbol-timeframe combinations")
# Store for model access
if not hasattr(self, 'model_historical_features'):
@@ -2619,7 +2509,7 @@ class TradingOrchestrator:
self.model_historical_features['decision'] = all_features
except Exception as e:
logger.error(f"Error initializing Decision Fusion with historical data: {e}")
logger.error(f"Error initializing Decision Fusion with provider data: {e}")
def get_ohlcv_data(self, symbol: str, timeframe: str, limit: int = 300) -> List:
"""Get OHLCV data for a symbol with specified timeframe and limit."""