data normalizations
@@ -16,6 +16,7 @@ import logging
import time
import threading
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Tuple, Union
from dataclasses import dataclass, field
@@ -2364,12 +2365,262 @@ class TradingOrchestrator:
            logger.info("Initializing ExtremaTrainer with historical context...")
            self.extrema_trainer.initialize_context_data()

            # CRITICAL: Initialize ALL models with historical data
            self._initialize_models_with_historical_data(loaded_data)

            logger.info(f"🎯 Historical data loading complete: {total_candles} total candles loaded")
            logger.info(f"📊 Available datasets: {list(loaded_data.keys())}")

        except Exception as e:
            logger.error(f"Error in historical data loading: {e}")

    def _initialize_models_with_historical_data(self, loaded_data: Dict[str, Any]):
        """Initialize all NN models with historical data and multi-symbol support"""
        try:
            logger.info("Initializing models with historical data and multi-symbol support...")

            # Prepare multi-symbol feature matrices
            symbol_features = self._prepare_multi_symbol_features(loaded_data)

            # Initialize CNN with multi-symbol data
            if hasattr(self, 'cnn_model') and self.cnn_model:
                logger.info("Initializing CNN with multi-symbol historical features...")
                self._initialize_cnn_with_data(symbol_features)

            # Initialize DQN with multi-symbol states
            if hasattr(self, 'rl_agent') and self.rl_agent:
                logger.info("Initializing DQN with multi-symbol state vectors...")
                self._initialize_dqn_with_data(symbol_features)

            # Initialize Transformer with sequence data
            if hasattr(self, 'transformer_model') and self.transformer_model:
                logger.info("Initializing Transformer with multi-symbol sequences...")
                self._initialize_transformer_with_data(symbol_features)

            # Initialize Decision Fusion with comprehensive features
            if hasattr(self, 'decision_fusion') and self.decision_fusion:
                logger.info("Initializing Decision Fusion with multi-symbol features...")
                self._initialize_decision_with_data(symbol_features)

            logger.info("✅ All models initialized with historical multi-symbol data")

        except Exception as e:
            logger.error(f"Error initializing models with historical data: {e}")

    def _prepare_multi_symbol_features(self, loaded_data: Dict[str, Any]) -> Dict[str, Any]:
        """Prepare normalized multi-symbol feature matrices"""
        try:
            symbol_features = {
                'ETH/USDT': {'1m': None, '1h': None, '1d': None},
                'BTC/USDT': {'1m': None}
            }

            # Process each symbol's data with symbol-specific normalization
            for data_key, df in loaded_data.items():
                if df is None or df.empty:
                    continue

                # Extract symbol and timeframe
                if '_1m' in data_key:
                    symbol = data_key.replace('_1m', '')
                    timeframe = '1m'
                elif '_1h' in data_key:
                    symbol = data_key.replace('_1h', '')
                    timeframe = '1h'
                elif '_1d' in data_key:
                    symbol = data_key.replace('_1d', '')
                    timeframe = '1d'
                else:
                    continue

                # Apply symbol-grouped normalization
                normalized_df = self._apply_symbol_grouped_normalization(df, symbol)

                if normalized_df is not None:
                    symbol_features[symbol][timeframe] = normalized_df
                    logger.debug(f"Prepared normalized features for {symbol} {timeframe}")

            return symbol_features

        except Exception as e:
            logger.error(f"Error preparing multi-symbol features: {e}")
            return {}

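    # Illustrative sketch (assumed data contract, not from the original source): the
    # loader is taken to key each dataset as "<symbol>_<timeframe>", e.g. "ETH/USDT_1m",
    # which is what the suffix matching above relies on:
    #
    #     loaded_data = {
    #         'ETH/USDT_1m': eth_1m_df, 'ETH/USDT_1h': eth_1h_df,
    #         'ETH/USDT_1d': eth_1d_df, 'BTC/USDT_1m': btc_1m_df,
    #     }
    #     symbol_features = self._prepare_multi_symbol_features(loaded_data)
    #     # -> {'ETH/USDT': {'1m': ..., '1h': ..., '1d': ...}, 'BTC/USDT': {'1m': ...}}
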
    def _apply_symbol_grouped_normalization(self, df: pd.DataFrame, symbol: str) -> pd.DataFrame:
        """Apply symbol-grouped normalization with consistent ranges across timeframes"""
        try:
            df_norm = df.copy()

            # Get symbol-specific price ranges for consistent normalization
            symbol_price_ranges = {
                'ETH/USDT': {'min': 1000, 'max': 5000},      # ETH price range
                'BTC/USDT': {'min': 90000, 'max': 120000}    # BTC price range
            }

            if symbol in symbol_price_ranges:
                price_range = symbol_price_ranges[symbol]
                range_size = price_range['max'] - price_range['min']

                # Normalize price columns to [0, 1] range specific to symbol
                price_cols = ['open', 'high', 'low', 'close']
                for col in price_cols:
                    if col in df_norm.columns:
                        df_norm[col] = (df_norm[col] - price_range['min']) / range_size
                        df_norm[col] = np.clip(df_norm[col], 0, 1)  # Ensure [0, 1] range

                # Normalize volume to [0, 1] using log scale
                if 'volume' in df_norm.columns:
                    df_norm['volume'] = np.log1p(df_norm['volume'])
                    vol_max = df_norm['volume'].max()
                    if vol_max > 0:
                        df_norm['volume'] = df_norm['volume'] / vol_max

                logger.debug(f"Applied symbol-grouped normalization for {symbol}")

            # Fill any NaN values
            df_norm = df_norm.fillna(0)

            return df_norm

        except Exception as e:
            logger.error(f"Error in symbol-grouped normalization for {symbol}: {e}")
            return df

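    # Illustrative, self-contained sketch of the normalization above (sample candle
    # values are hypothetical; [1000, 5000] is the ETH/USDT range used here):
    #
    #     import numpy as np
    #     import pandas as pd
    #
    #     candles = pd.DataFrame({
    #         'open': [3000.0, 3100.0], 'high':  [3050.0, 3150.0],
    #         'low':  [2950.0, 3080.0], 'close': [3020.0, 3120.0],
    #         'volume': [125.8, 98.4],
    #     })
    #     lo, hi = 1000, 5000
    #     prices = ((candles[['open', 'high', 'low', 'close']] - lo) / (hi - lo)).clip(0, 1)
    #     # e.g. close 3020 -> (3020 - 1000) / 4000 = 0.505
    #     vol = np.log1p(candles['volume'])
    #     vol = vol / vol.max()            # log-scaled volume in [0, 1]
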
    def _initialize_cnn_with_data(self, symbol_features: Dict[str, Any]):
        """Initialize CNN with multi-symbol feature matrix"""
        try:
            # Create combined feature matrix: [ETH_1m, ETH_1h, ETH_1d, BTC_1m]
            combined_features = []

            # ETH features (1m, 1h, 1d)
            for timeframe in ['1m', '1h', '1d']:
                eth_data = symbol_features.get('ETH/USDT', {}).get(timeframe)
                if eth_data is not None and not eth_data.empty:
                    # Use last 60 candles for CNN input
                    recent_data = eth_data.tail(60)
                    features = recent_data[['open', 'high', 'low', 'close', 'volume']].values
                    combined_features.append(features.flatten())

            # BTC features (1m)
            btc_data = symbol_features.get('BTC/USDT', {}).get('1m')
            if btc_data is not None and not btc_data.empty:
                recent_data = btc_data.tail(60)
                features = recent_data[['open', 'high', 'low', 'close', 'volume']].values
                combined_features.append(features.flatten())

            if combined_features:
                # Concatenate all features
                full_features = np.concatenate(combined_features)
                logger.info(f"CNN initialized with {len(full_features)} multi-symbol features")

                # Store for model access
                if not hasattr(self, 'model_historical_features'):
                    self.model_historical_features = {}
                self.model_historical_features['cnn'] = full_features

        except Exception as e:
            logger.error(f"Error initializing CNN with historical data: {e}")

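    # Sizing note (assumes all four datasets are present): each dataset contributes
    # tail(60) candles x 5 OHLCV columns = 300 values, so the concatenated CNN input
    # holds 4 * 300 = 1200 floats (ETH 1m/1h/1d + BTC 1m).
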
    def _initialize_dqn_with_data(self, symbol_features: Dict[str, Any]):
        """Initialize DQN with multi-symbol state vectors"""
        try:
            # Create comprehensive state vector combining all symbols and timeframes
            state_components = []

            for symbol in ['ETH/USDT', 'BTC/USDT']:
                timeframes = ['1m', '1h', '1d'] if symbol == 'ETH/USDT' else ['1m']

                for timeframe in timeframes:
                    data = symbol_features.get(symbol, {}).get(timeframe)
                    if data is not None and not data.empty:
                        # Extract key features for state
                        latest = data.iloc[-1]
                        state_features = [
                            latest['close'],   # Current price
                            latest['volume'],  # Current volume
                            data['close'].pct_change().iloc[-1] if len(data) > 1 else 0,  # Price change
                        ]
                        state_components.extend(state_features)

            if state_components:
                # Pad or truncate to expected DQN state size
                target_size = 100  # DQN expects 100-dimensional state
                if len(state_components) < target_size:
                    state_components.extend([0] * (target_size - len(state_components)))
                else:
                    state_components = state_components[:target_size]

                state_vector = np.array(state_components, dtype=np.float32)
                logger.info(f"DQN initialized with {len(state_vector)}-dimensional multi-symbol state")

                # Store for model access
                if not hasattr(self, 'model_historical_features'):
                    self.model_historical_features = {}
                self.model_historical_features['dqn'] = state_vector

        except Exception as e:
            logger.error(f"Error initializing DQN with historical data: {e}")

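    # Illustrative sketch of the pad/truncate step above: four datasets x 3 features
    # yield 12 raw components, which are zero-padded up to the fixed 100-dim state
    # (values below are hypothetical):
    #
    #     import numpy as np
    #     components = [0.505, 0.42, 0.0012] * 4
    #     components += [0] * (100 - len(components))
    #     state = np.array(components[:100], dtype=np.float32)
    #     assert state.shape == (100,)
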
    def _initialize_transformer_with_data(self, symbol_features: Dict[str, Any]):
        """Initialize Transformer with multi-symbol sequence data"""
        try:
            # Prepare sequence data for transformer
            sequences = []

            # ETH sequences
            for timeframe in ['1m', '1h', '1d']:
                eth_data = symbol_features.get('ETH/USDT', {}).get(timeframe)
                if eth_data is not None and not eth_data.empty:
                    # Use last 150 points as sequence
                    sequence = eth_data.tail(150)[['open', 'high', 'low', 'close', 'volume']].values
                    sequences.append(sequence)

            # BTC sequence
            btc_data = symbol_features.get('BTC/USDT', {}).get('1m')
            if btc_data is not None and not btc_data.empty:
                sequence = btc_data.tail(150)[['open', 'high', 'low', 'close', 'volume']].values
                sequences.append(sequence)

            if sequences:
                logger.info(f"Transformer initialized with {len(sequences)} multi-symbol sequences")

                # Store for model access
                if not hasattr(self, 'model_historical_features'):
                    self.model_historical_features = {}
                self.model_historical_features['transformer'] = sequences

        except Exception as e:
            logger.error(f"Error initializing Transformer with historical data: {e}")

    def _initialize_decision_with_data(self, symbol_features: Dict[str, Any]):
        """Initialize Decision Fusion with comprehensive multi-symbol features"""
        try:
            # Aggregate all available features for decision fusion
            all_features = {}

            for symbol in symbol_features:
                for timeframe in symbol_features[symbol]:
                    data = symbol_features[symbol][timeframe]
                    if data is not None and not data.empty:
                        key = f"{symbol}_{timeframe}"
                        all_features[key] = {
                            'latest_price': data['close'].iloc[-1],
                            'volume': data['volume'].iloc[-1],
                            'price_change': data['close'].pct_change().iloc[-1] if len(data) > 1 else 0,
                            'volatility': data['close'].std() if len(data) > 1 else 0
                        }

            if all_features:
                logger.info(f"Decision Fusion initialized with {len(all_features)} symbol-timeframe combinations")

                # Store for model access
                if not hasattr(self, 'model_historical_features'):
                    self.model_historical_features = {}
                self.model_historical_features['decision'] = all_features

        except Exception as e:
            logger.error(f"Error initializing Decision Fusion with historical data: {e}")

    def get_ohlcv_data(self, symbol: str, timeframe: str, limit: int = 300) -> List:
        """Get OHLCV data for a symbol with specified timeframe and limit."""
        try:
@@ -1,89 +0,0 @@
#!/usr/bin/env python3
"""
Demo: Data Stream Monitor for Model Input Capture

This script demonstrates how to use the DataStreamMonitor to capture
and stream all model input data in console-friendly text format.

Run this while the dashboard is running to see real-time data streaming.
"""

import sys
import time
from pathlib import Path

# Add project root to path
project_root = Path(__file__).resolve().parent
sys.path.insert(0, str(project_root))

def main():
    print("=" * 80)
    print("DATA STREAM MONITOR DEMO")
    print("=" * 80)
    print()

    print("This demo shows how to control the data streaming system.")
    print("Make sure the dashboard is running first with:")
    print("  source venv/bin/activate && python run_clean_dashboard.py")
    print()

    print("Available commands:")
    print("1. Start streaming:    python data_stream_control.py start")
    print("2. Stop streaming:     python data_stream_control.py stop")
    print("3. Save snapshot:      python data_stream_control.py snapshot")
    print("4. Switch to compact:  python data_stream_control.py compact")
    print("5. Switch to detailed: python data_stream_control.py detailed")
    print("6. Check status:       python data_stream_control.py status")
    print()

    print("Data streams captured:")
    print("• OHLCV data (1m, 5m, 15m timeframes)")
    print("• Real-time tick data")
    print("• COB (Consolidated Order Book) data")
    print("• Technical indicators")
    print("• Model state vectors for each model")
    print("• Recent predictions from all models")
    print("• Training experiences and rewards")
    print()

    print("Output formats:")
    print("• Detailed: human-readable format with sections")
    print("• Compact: JSON format for programmatic processing")
    print()

    print("""
================================================================================
DATA STREAM DEMO
================================================================================

The data stream is now managed by the TradingOrchestrator and starts
automatically when you run the dashboard:

    python run_clean_dashboard.py

You should see periodic data samples in the dashboard console.

================================================================================
DATA STREAM SAMPLE - 14:30:15
================================================================================
OHLCV (1m): ETH/USDT | O:4335.67 H:4338.92 L:4334.21 C:4336.67 V:125.8
TICK: ETH/USDT | Price:4336.67 Vol:0.0456 Side:buy
COB: ETH/USDT | Imbalance:0.234 Spread:2.3bps Mid:4336.67
DQN State: 15 features | Price:4336.67
DQN Prediction: BUY (conf:0.78)
Training Exp: Action:1 Reward:0.0234 Done:False
================================================================================
""")

    print("Example console output (compact format):")
    print('DATA_STREAM: {"timestamp":"2024-01-15T14:30:15","ohlcv_count":5,"ticks_count":12,"cob_count":8,"predictions_count":3,"experiences_count":7,"price":4336.67,"volume":125.8,"imbalance":0.234,"spread_bps":2.3}')
    print()

    print("To start streaming, run:")
    print("  python data_stream_control.py start")
    print()
    print("The streaming will continue until you stop it with:")
    print("  python data_stream_control.py stop")

if __name__ == "__main__":
    main()
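# Illustrative consumer-side sketch (not part of the original script): the compact
# format shown above is a "DATA_STREAM: " prefix followed by a JSON object, so a
# log consumer can recover the fields with json.loads:
#
#     import json
#     line = 'DATA_STREAM: {"timestamp":"2024-01-15T14:30:15","price":4336.67,"spread_bps":2.3}'
#     payload = json.loads(line.split("DATA_STREAM: ", 1)[1])
#     print(payload["price"], payload["spread_bps"])   # 4336.67 2.3
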
@@ -1,233 +0,0 @@
#!/usr/bin/env python3
"""
Training Integration for Checkpoint Management
"""

import logging
import torch
from datetime import datetime
from typing import Dict, Any, Optional
from pathlib import Path

from .checkpoint_manager import get_checkpoint_manager, save_checkpoint, load_best_checkpoint

logger = logging.getLogger(__name__)

class TrainingIntegration:
    def __init__(self, enable_wandb: bool = False):
        self.checkpoint_manager = get_checkpoint_manager()
        self.enable_wandb = enable_wandb

        if self.enable_wandb:
            self._init_wandb()

    def _init_wandb(self):
        # Disabled by default to avoid CLI prompts
        pass

    def save_cnn_checkpoint(self,
                            cnn_model,
                            model_name: str,
                            epoch: int,
                            train_accuracy: float,
                            val_accuracy: float,
                            train_loss: float,
                            val_loss: float,
                            training_time_hours: float = None) -> bool:
        try:
            performance_metrics = {
                'accuracy': train_accuracy,
                'val_accuracy': val_accuracy,
                'loss': train_loss,
                'val_loss': val_loss
            }

            training_metadata = {
                'epoch': epoch,
                'training_time_hours': training_time_hours,
                'total_parameters': self._count_parameters(cnn_model)
            }

            # W&B disabled

            metadata = save_checkpoint(
                model=cnn_model,
                model_name=model_name,
                model_type='cnn',
                performance_metrics=performance_metrics,
                training_metadata=training_metadata
            )

            if metadata:
                logger.info(f"CNN checkpoint saved: {metadata.checkpoint_id}")
                return True
            else:
                logger.info("CNN checkpoint not saved (performance not improved)")
                return False

        except Exception as e:
            logger.error(f"Error saving CNN checkpoint: {e}")
            return False

    def save_rl_checkpoint(self,
                           rl_agent,
                           model_name: str,
                           episode: int,
                           avg_reward: float,
                           best_reward: float,
                           epsilon: float,
                           total_pnl: float = None) -> bool:
        try:
            performance_metrics = {
                'reward': avg_reward,
                'best_reward': best_reward
            }

            if total_pnl is not None:
                performance_metrics['pnl'] = total_pnl

            training_metadata = {
                'episode': episode,
                'epsilon': epsilon,
                'total_parameters': self._count_parameters(rl_agent)
            }

            # W&B disabled

            metadata = save_checkpoint(
                model=rl_agent,
                model_name=model_name,
                model_type='rl',
                performance_metrics=performance_metrics,
                training_metadata=training_metadata
            )

            if metadata:
                logger.info(f"RL checkpoint saved: {metadata.checkpoint_id}")
                return True
            else:
                logger.info("RL checkpoint not saved (performance not improved)")
                return False

        except Exception as e:
            logger.error(f"Error saving RL checkpoint: {e}")
            return False

    def load_best_model(self, model_name: str, model_class=None):
        try:
            result = load_best_checkpoint(model_name)
            if not result:
                logger.warning(f"No checkpoint found for model: {model_name}")
                return None

            file_path, metadata = result

            checkpoint = torch.load(file_path, map_location='cpu')

            logger.info(f"Loaded best checkpoint for {model_name}:")
            logger.info(f"  Performance score: {metadata.performance_score:.4f}")
            logger.info(f"  Created: {metadata.created_at}")

            if model_class and 'model_state_dict' in checkpoint:
                model = model_class()
                model.load_state_dict(checkpoint['model_state_dict'])
                return model

            return checkpoint

        except Exception as e:
            logger.error(f"Error loading best model {model_name}: {e}")
            return None

    def _count_parameters(self, model) -> int:
        try:
            if hasattr(model, 'parameters'):
                return sum(p.numel() for p in model.parameters())
            elif hasattr(model, 'policy_net'):
                policy_params = sum(p.numel() for p in model.policy_net.parameters())
                target_params = sum(p.numel() for p in model.target_net.parameters()) if hasattr(model, 'target_net') else 0
                return policy_params + target_params
            else:
                return 0
        except Exception:
            return 0

_training_integration = None

def get_training_integration() -> TrainingIntegration:
    global _training_integration
    if _training_integration is None:
        _training_integration = TrainingIntegration()
    return _training_integration

# ---------------- Unified Training Manager ----------------

class UnifiedTrainingManager:
    """Single entry point to manage all training in the system.

    Coordinates EnhancedRealtimeTrainingSystem and provides start/stop/status.
    """

    def __init__(self, orchestrator, data_provider, dashboard=None):
        self.orchestrator = orchestrator
        self.data_provider = data_provider
        self.dashboard = dashboard
        self.training_system = None
        self.started = False

    def initialize(self) -> bool:
        try:
            # Import via project root shim to avoid path issues
            from enhanced_realtime_training import EnhancedRealtimeTrainingSystem
            self.training_system = EnhancedRealtimeTrainingSystem(
                orchestrator=self.orchestrator,
                data_provider=self.data_provider,
                dashboard=self.dashboard
            )
            return True
        except Exception as e:
            logger.error(f"UnifiedTrainingManager: failed to initialize training system: {e}")
            self.training_system = None
            return False

    def start(self) -> bool:
        try:
            if self.training_system is None:
                if not self.initialize():
                    return False
            self.training_system.start_training()
            self.started = True
            logger.info("UnifiedTrainingManager: training started")
            return True
        except Exception as e:
            logger.error(f"UnifiedTrainingManager: error starting training: {e}")
            return False

    def stop(self) -> bool:
        try:
            if self.training_system and self.started:
                self.training_system.stop_training()
                self.started = False
                logger.info("UnifiedTrainingManager: training stopped")
            return True
        except Exception as e:
            logger.error(f"UnifiedTrainingManager: error stopping training: {e}")
            return False

    def get_stats(self) -> Dict[str, Any]:
        try:
            if self.training_system and hasattr(self.training_system, 'get_training_stats'):
                return self.training_system.get_training_stats()
            return {}
        except Exception:
            return {}

_unified_training_manager = None

def get_unified_training_manager(orchestrator=None, data_provider=None, dashboard=None) -> UnifiedTrainingManager:
    global _unified_training_manager
    if _unified_training_manager is None:
        if orchestrator is None or data_provider is None:
            raise ValueError("orchestrator and data_provider are required for first-time initialization")
        _unified_training_manager = UnifiedTrainingManager(orchestrator, data_provider, dashboard)
    return _unified_training_manager
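# Illustrative usage sketch of the accessors above (orchestrator, data_provider and
# the cnn_model object are hypothetical stand-ins):
#
#     integration = get_training_integration()
#     integration.save_cnn_checkpoint(cnn_model, model_name="cnn_v1", epoch=10,
#                                     train_accuracy=0.62, val_accuracy=0.58,
#                                     train_loss=0.91, val_loss=1.02)
#
#     manager = get_unified_training_manager(orchestrator=orchestrator,
#                                            data_provider=data_provider)
#     manager.start()
#     print(manager.get_stats())
#     manager.stop()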