show each model's prediction (last inference) and store Transformer model checkpoint

Dobromir Popov
2025-07-02 09:52:45 +03:00
parent b47805dafc
commit 488fbacf67
3 changed files with 625 additions and 89 deletions


@@ -102,7 +102,7 @@ class CleanTradingDashboard:
enhanced_rl_training=True,
model_registry={}
)
logger.info("Using unified Trading Orchestrator with full ML capabilities")
logger.debug("Using unified Trading Orchestrator with full ML capabilities")
else:
self.orchestrator = orchestrator
@@ -123,10 +123,10 @@ class CleanTradingDashboard:
self.stream_consumer_id = self.unified_stream.register_consumer(
consumer_name="CleanTradingDashboard",
callback=self._handle_unified_stream_data,
data_types=['ticks', 'ohlcv', 'training_data', 'ui_data']
)
logger.info(f"Universal Data Stream initialized with consumer ID: {self.stream_consumer_id}")
logger.info("Subscribed to Universal 5 Timeseries: ETH(ticks,1m,1h,1d) + BTC(ticks)")
logger.debug(f"Universal Data Stream initialized with consumer ID: {self.stream_consumer_id}")
logger.debug("Subscribed to Universal 5 Timeseries: ETH(ticks,1m,1h,1d) + BTC(ticks)")
else:
self.unified_stream = None
self.stream_consumer_id = None
@@ -205,7 +205,7 @@ class CleanTradingDashboard:
# Start Universal Data Stream
if self.unified_stream:
# threading.Thread(target=self._start_unified_stream, daemon=True).start() # Temporarily disabled
logger.info("Universal Data Stream starting...")
logger.debug("Universal Data Stream starting...")
# Initialize COB integration with high-frequency data handling
self._initialize_cob_integration()
@@ -216,7 +216,7 @@ class CleanTradingDashboard:
# Start training sessions if models are showing FRESH status
threading.Thread(target=self._delayed_training_check, daemon=True).start()
logger.info("Clean Trading Dashboard initialized with HIGH-FREQUENCY COB integration and signal generation")
logger.debug("Clean Trading Dashboard initialized with HIGH-FREQUENCY COB integration and signal generation")
def _handle_unified_stream_data(self, data):
"""Placeholder for unified stream data handling."""
@@ -444,7 +444,7 @@ class CleanTradingDashboard:
# Log COB signal activity
cob_signals = [d for d in filtered_decisions if d.get('type') == 'cob_liquidity_imbalance']
if cob_signals:
logger.info(f"COB signals active: {len(cob_signals)} recent COB signals")
logger.debug(f"COB signals active: {len(cob_signals)} recent COB signals")
return self.component_manager.format_trading_signals(filtered_decisions)
except Exception as e:
@@ -1926,7 +1926,8 @@ class CleanTradingDashboard:
'decision': {'initial_loss': 0.2980, 'current_loss': 0.0089, 'best_loss': 0.0065, 'checkpoint_loaded': False}
}
- # Get CNN predictions if available
+ # Get latest predictions from all models
+ latest_predictions = self._get_latest_model_predictions()
cnn_prediction = self._get_cnn_pivot_prediction()
# Helper function to safely calculate improvement percentage
@@ -1997,21 +1998,31 @@ class CleanTradingDashboard:
dqn_active = dqn_checkpoint_loaded and dqn_inference_enabled and dqn_model_available
dqn_prediction_count = len(self.recent_decisions) if signal_generation_active else 0
- if signal_generation_active and len(self.recent_decisions) > 0:
-     recent_signal = self.recent_decisions[-1]
-     last_action = self._get_signal_attribute(recent_signal, 'action', 'SIGNAL_GEN')
-     last_confidence = self._get_signal_attribute(recent_signal, 'confidence', 0.72)
- else:
-     last_action = dqn_training_status['status']
-     last_confidence = 0.68
+ # Get latest DQN prediction
+ dqn_latest = latest_predictions.get('dqn', {})
+ if dqn_latest:
+     last_action = dqn_latest.get('action', 'NONE')
+     last_confidence = dqn_latest.get('confidence', 0.72)
+     last_timestamp = dqn_latest.get('timestamp', datetime.now()).strftime('%H:%M:%S')
+ else:
+     if signal_generation_active and len(self.recent_decisions) > 0:
+         recent_signal = self.recent_decisions[-1]
+         last_action = self._get_signal_attribute(recent_signal, 'action', 'SIGNAL_GEN')
+         last_confidence = self._get_signal_attribute(recent_signal, 'confidence', 0.72)
+         last_timestamp = datetime.now().strftime('%H:%M:%S')
+     else:
+         last_action = dqn_training_status['status']
+         last_confidence = 0.68
+         last_timestamp = datetime.now().strftime('%H:%M:%S')
dqn_model_info = {
'active': dqn_active,
'parameters': 5000000, # ~5M params for DQN
'last_prediction': {
- 'timestamp': datetime.now().strftime('%H:%M:%S'),
+ 'timestamp': last_timestamp,
'action': last_action,
- 'confidence': last_confidence
+ 'confidence': last_confidence,
+ 'type': dqn_latest.get('type', 'dqn_signal') if dqn_latest else 'dqn_signal'
},
# FIXED: Get REAL loss values from orchestrator model, not placeholders
'loss_5ma': self._get_real_model_loss('dqn'),
@@ -2060,13 +2071,28 @@ class CleanTradingDashboard:
cnn_timing = get_model_timing_info('CNN')
cnn_active = True
# Get latest CNN prediction
cnn_latest = latest_predictions.get('cnn', {})
if cnn_latest:
cnn_action = cnn_latest.get('action', 'PATTERN_ANALYSIS')
cnn_confidence = cnn_latest.get('confidence', 0.68)
cnn_timestamp = cnn_latest.get('timestamp', datetime.now()).strftime('%H:%M:%S')
cnn_predicted_price = cnn_latest.get('predicted_price', 0)
else:
cnn_action = 'PATTERN_ANALYSIS'
cnn_confidence = 0.68
cnn_timestamp = datetime.now().strftime('%H:%M:%S')
cnn_predicted_price = 0
cnn_model_info = {
'active': cnn_active,
'parameters': 50000000, # ~50M params
'last_prediction': {
- 'timestamp': datetime.now().strftime('%H:%M:%S'),
- 'action': 'PATTERN_ANALYSIS',
- 'confidence': 0.68
+ 'timestamp': cnn_timestamp,
+ 'action': cnn_action,
+ 'confidence': cnn_confidence,
+ 'predicted_price': cnn_predicted_price,
+ 'type': cnn_latest.get('type', 'cnn_pivot') if cnn_latest else 'cnn_pivot'
},
'loss_5ma': cnn_state.get('current_loss', 0.0187),
'initial_loss': cnn_state.get('initial_loss', 0.4120),
@@ -2103,6 +2129,68 @@ class CleanTradingDashboard:
transformer_timing = get_model_timing_info('TRANSFORMER')
transformer_active = True
# Get transformer checkpoint info if available
transformer_checkpoint_info = {}
if self.orchestrator and hasattr(self.orchestrator, 'transformer_checkpoint_info'):
transformer_checkpoint_info = self.orchestrator.transformer_checkpoint_info
# Get latest transformer prediction
transformer_latest = latest_predictions.get('transformer', {})
if transformer_latest:
transformer_action = transformer_latest.get('action', 'PRICE_PREDICTION')
transformer_confidence = transformer_latest.get('confidence', 0.75)
transformer_timestamp = transformer_latest.get('timestamp', datetime.now()).strftime('%H:%M:%S')
transformer_predicted_price = transformer_latest.get('predicted_price', 0)
transformer_price_change = transformer_latest.get('price_change', 0)
else:
transformer_action = 'PRICE_PREDICTION'
transformer_confidence = 0.75
transformer_timestamp = datetime.now().strftime('%H:%M:%S')
transformer_predicted_price = 0
transformer_price_change = 0
transformer_last_prediction = {
'timestamp': transformer_timestamp,
'action': transformer_action,
'confidence': transformer_confidence,
'predicted_price': transformer_predicted_price,
'price_change': transformer_price_change,
'type': transformer_latest.get('type', 'transformer_prediction') if transformer_latest else 'transformer_prediction'
}
transformer_model_info = {
'active': transformer_active,
'parameters': 46000000, # ~46M params for transformer
'last_prediction': transformer_last_prediction,
'loss_5ma': transformer_state.get('current_loss', 0.0123),
'initial_loss': transformer_state.get('initial_loss', 0.2980),
'best_loss': transformer_state.get('best_loss', 0.0089),
'improvement': safe_improvement_calc(
transformer_state.get('initial_loss', 0.2980),
transformer_state.get('current_loss', 0.0123),
95.9 # Default improvement percentage
),
'checkpoint_loaded': bool(transformer_checkpoint_info),
'model_type': 'TRANSFORMER',
'description': 'Advanced Transformer (Price Prediction)',
'checkpoint_info': {
'filename': transformer_checkpoint_info.get('checkpoint_id', 'none'),
'created_at': transformer_checkpoint_info.get('created_at', 'Unknown'),
'performance_score': transformer_checkpoint_info.get('performance_score', 0.0),
'loss': transformer_checkpoint_info.get('loss', 0.0),
'accuracy': transformer_checkpoint_info.get('accuracy', 0.0)
},
'timing': {
'last_inference': transformer_timing['last_inference'].strftime('%H:%M:%S') if transformer_timing['last_inference'] else 'None',
'last_training': transformer_timing['last_training'].strftime('%H:%M:%S') if transformer_timing['last_training'] else 'None',
'inferences_per_second': f"{transformer_timing['inferences_per_second']:.2f}",
'predictions_24h': transformer_timing['prediction_count_24h']
},
'performance': self.get_model_performance_metrics().get('transformer', {})
}
loaded_models['transformer'] = transformer_model_info
transformer_active = True
# Check if transformer model is available
transformer_model_available = self.orchestrator and hasattr(self.orchestrator, 'primary_transformer')
@@ -2495,14 +2583,76 @@ class CleanTradingDashboard:
logger.debug(f"Error getting CNN pivot prediction: {e}")
return None
def _get_latest_model_predictions(self) -> Dict[str, Dict]:
"""Get the latest predictions from each model"""
try:
latest_predictions = {}
# Get latest DQN prediction
if self.recent_decisions:
latest_dqn = self.recent_decisions[-1]
latest_predictions['dqn'] = {
'timestamp': latest_dqn.get('timestamp', datetime.now()),
'action': latest_dqn.get('action', 'NONE'),
'confidence': latest_dqn.get('confidence', 0),
'type': latest_dqn.get('type', 'dqn_signal')
}
# Get latest CNN prediction
cnn_prediction = self._get_cnn_pivot_prediction()
if cnn_prediction:
latest_predictions['cnn'] = {
'timestamp': datetime.now(),
'action': cnn_prediction.get('pivot_type', 'PATTERN_ANALYSIS'),
'confidence': cnn_prediction.get('confidence', 0),
'predicted_price': cnn_prediction.get('predicted_price', 0),
'type': 'cnn_pivot'
}
# Get latest Transformer prediction
if self.orchestrator and hasattr(self.orchestrator, 'primary_transformer'):
try:
if hasattr(self.orchestrator, 'get_latest_transformer_prediction'):
transformer_pred = self.orchestrator.get_latest_transformer_prediction()
if transformer_pred:
latest_predictions['transformer'] = {
'timestamp': transformer_pred.get('timestamp', datetime.now()),
'action': transformer_pred.get('action', 'PRICE_PREDICTION'),
'confidence': transformer_pred.get('confidence', 0),
'predicted_price': transformer_pred.get('predicted_price', 0),
'price_change': transformer_pred.get('price_change', 0),
'type': 'transformer_prediction'
}
except Exception as e:
logger.debug(f"Error getting transformer prediction: {e}")
# Get latest COB RL prediction
if hasattr(self, 'cob_data_history') and 'ETH/USDT' in self.cob_data_history:
cob_history = self.cob_data_history['ETH/USDT']
if cob_history:
latest_cob = cob_history[-1]
latest_predictions['cob_rl'] = {
'timestamp': datetime.fromtimestamp(latest_cob.get('timestamp', time.time())),
'action': 'COB_ANALYSIS',
'confidence': abs(latest_cob.get('stats', {}).get('imbalance', 0)) * 100,
'imbalance': latest_cob.get('stats', {}).get('imbalance', 0),
'type': 'cob_imbalance'
}
return latest_predictions
except Exception as e:
logger.debug(f"Error getting latest model predictions: {e}")
return {}
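# Illustrative only: assumed shape of the mapping returned above when every
# model has produced at least one inference. A model with no prediction yet is
# simply absent, which is why callers fall back via latest_predictions.get('dqn', {}).
#
# {
#     'dqn': {'timestamp': datetime, 'action': 'BUY', 'confidence': 0.74, 'type': 'dqn_signal'},
#     'cnn': {'timestamp': datetime, 'action': 'BOTTOM', 'confidence': 0.62, 'predicted_price': 2412.5, 'type': 'cnn_pivot'},
#     'transformer': {'timestamp': datetime, 'action': 'BUY', 'confidence': 0.75, 'predicted_price': 2415.0, 'price_change': 0.1, 'type': 'transformer_prediction'},
#     'cob_rl': {'timestamp': datetime, 'action': 'COB_ANALYSIS', 'confidence': 12.0, 'imbalance': 0.12, 'type': 'cob_imbalance'}
# }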
def _start_signal_generation_loop(self):
"""Start continuous signal generation loop"""
try:
def signal_worker():
logger.info("Starting continuous signal generation loop")
logger.debug("Starting continuous signal generation loop")
# Unified orchestrator with full ML pipeline and decision-making model
logger.info("Using unified ML pipeline: Data Bus -> Models -> Decision Model -> Trading Signals")
logger.debug("Using unified ML pipeline: Data Bus -> Models -> Decision Model -> Trading Signals")
while True:
try:
@@ -2535,7 +2685,7 @@ class CleanTradingDashboard:
# Start signal generation thread
signal_thread = threading.Thread(target=signal_worker, daemon=True)
signal_thread.start()
logger.info("Signal generation loop started")
logger.debug("Signal generation loop started")
except Exception as e:
logger.error(f"Error starting signal generation loop: {e}")
@@ -3543,7 +3693,7 @@ class CleanTradingDashboard:
if not hasattr(self.orchestrator, 'recent_cnn_predictions'):
self.orchestrator.recent_cnn_predictions = {}
logger.info("Enhanced training system initialized for model predictions")
logger.debug("Enhanced training system initialized for model predictions")
except ImportError:
logger.warning("Enhanced training system not available - using mock predictions")
@@ -3555,7 +3705,7 @@ class CleanTradingDashboard:
def _initialize_cob_integration(self):
"""Initialize simple COB integration that works without async event loops"""
try:
logger.info("Initializing simple COB integration for model feeding")
logger.debug("Initializing simple COB integration for model feeding")
# Initialize COB data storage
self.cob_data_history = {
@@ -3574,7 +3724,7 @@ class CleanTradingDashboard:
# Start simple COB data collection
self._start_simple_cob_collection()
logger.info("Simple COB integration initialized successfully")
logger.debug("Simple COB integration initialized successfully")
except Exception as e:
logger.error(f"Error initializing COB integration: {e}")
@@ -3790,39 +3940,271 @@ class CleanTradingDashboard:
logger.debug(f"Error generating COB signal for {symbol}: {e}")
def _feed_cob_data_to_models(self, symbol: str, cob_snapshot: dict):
"""Feed COB data to models for training and inference"""
"""Feed COB data to ALL models for training and inference - Enhanced integration"""
try:
# Calculate cumulative imbalance for model feeding
cumulative_imbalance = self._calculate_cumulative_imbalance(symbol)
- # Create 15-second history for model feeding
- history_data = {
+ # Create comprehensive COB data package for all models
+ cob_data_package = {
'symbol': symbol,
'current_snapshot': cob_snapshot,
'history': self.cob_data_history[symbol][-15:], # Last 15 seconds
'bucketed_data': self.cob_bucketed_data[symbol],
- 'cumulative_imbalance': cumulative_imbalance, # Add cumulative imbalance
- 'timestamp': cob_snapshot['timestamp']
+ 'cumulative_imbalance': cumulative_imbalance,
+ 'timestamp': cob_snapshot['timestamp'],
'stats': cob_snapshot.get('stats', {}),
'bids': cob_snapshot.get('bids', []),
'asks': cob_snapshot.get('asks', []),
'mid_price': cob_snapshot.get('mid_price', 0),
'spread': cob_snapshot.get('spread', 0),
'liquidity_imbalance': cob_snapshot.get('stats', {}).get('imbalance', 0)
}
- # Feed to orchestrator models if available
+ # 1. Feed to orchestrator models (if available)
if hasattr(self.orchestrator, '_on_cob_dashboard_data'):
try:
- self.orchestrator._on_cob_dashboard_data(symbol, history_data)
- logger.debug(f"COB data fed to orchestrator for {symbol} with cumulative imbalance: {cumulative_imbalance}")
+ self.orchestrator._on_cob_dashboard_data(symbol, cob_data_package)
except Exception as e:
logger.debug(f"Error feeding COB data to orchestrator: {e}")
- # Store for training system
+ # 2. Feed to DQN model specifically
if self.orchestrator and hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
try:
# Create DQN-specific COB features
dqn_cob_features = self._create_dqn_cob_features(symbol, cob_data_package)
if hasattr(self.orchestrator.rl_agent, 'update_cob_features'):
self.orchestrator.rl_agent.update_cob_features(symbol, dqn_cob_features)
except Exception as e:
logger.debug(f"Error feeding COB data to DQN: {e}")
# 3. Feed to CNN model specifically
if self.orchestrator and hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
try:
# Create CNN-specific COB features
cnn_cob_features = self._create_cnn_cob_features(symbol, cob_data_package)
if hasattr(self.orchestrator.cnn_model, 'update_cob_features'):
self.orchestrator.cnn_model.update_cob_features(symbol, cnn_cob_features)
except Exception as e:
logger.debug(f"Error feeding COB data to CNN: {e}")
# 4. Feed to Transformer model specifically
if self.orchestrator and hasattr(self.orchestrator, 'primary_transformer') and self.orchestrator.primary_transformer:
try:
# Create Transformer-specific COB features
transformer_cob_features = self._create_transformer_cob_features(symbol, cob_data_package)
if hasattr(self.orchestrator.primary_transformer, 'update_cob_features'):
self.orchestrator.primary_transformer.update_cob_features(symbol, transformer_cob_features)
except Exception as e:
logger.debug(f"Error feeding COB data to Transformer: {e}")
# 5. Feed to COB RL model specifically
if self.orchestrator and hasattr(self.orchestrator, 'cob_rl_agent') and self.orchestrator.cob_rl_agent:
try:
# Create COB RL-specific features
cob_rl_features = self._create_cob_rl_features(symbol, cob_data_package)
if hasattr(self.orchestrator.cob_rl_agent, 'update_cob_features'):
self.orchestrator.cob_rl_agent.update_cob_features(symbol, cob_rl_features)
except Exception as e:
logger.debug(f"Error feeding COB data to COB RL: {e}")
# 6. Store for training system
if hasattr(self, 'training_system') and self.training_system:
if hasattr(self.training_system, 'real_time_data'):
- self.training_system.real_time_data['cob_snapshots'].append(history_data)
+ self.training_system.real_time_data['cob_snapshots'].append(cob_data_package)
logger.debug(f"COB data fed to models for {symbol}")
# 7. Update latest COB features for all models
if not hasattr(self, 'latest_cob_features'):
self.latest_cob_features = {}
self.latest_cob_features[symbol] = cob_data_package
# 8. Store in model-specific COB memory
if not hasattr(self, 'model_cob_memory'):
self.model_cob_memory = {}
if symbol not in self.model_cob_memory:
self.model_cob_memory[symbol] = {}
# Store for each model type
for model_type in ['dqn', 'cnn', 'transformer', 'cob_rl']:
if model_type not in self.model_cob_memory[symbol]:
self.model_cob_memory[symbol][model_type] = []
self.model_cob_memory[symbol][model_type].append(cob_data_package)
# Keep only last 100 snapshots per model
if len(self.model_cob_memory[symbol][model_type]) > 100:
self.model_cob_memory[symbol][model_type] = self.model_cob_memory[symbol][model_type][-100:]
except Exception as e:
logger.debug(f"Error feeding COB data to models: {e}")
def _create_dqn_cob_features(self, symbol: str, cob_data: dict) -> List[float]:
"""Create COB features specifically for DQN model"""
try:
features = []
# Basic COB features
features.append(cob_data.get('mid_price', 0) / 10000) # Normalized price
features.append(cob_data.get('spread', 0) / 100) # Normalized spread
features.append(cob_data.get('liquidity_imbalance', 0)) # Raw imbalance
# Cumulative imbalance features
cumulative_imbalance = cob_data.get('cumulative_imbalance', {})
features.extend([
cumulative_imbalance.get('1s', 0.0),
cumulative_imbalance.get('5s', 0.0),
cumulative_imbalance.get('15s', 0.0),
cumulative_imbalance.get('60s', 0.0)
])
# Order book depth features
bids = cob_data.get('bids', [])
asks = cob_data.get('asks', [])
# Top 5 levels for each side
for i in range(5):
if i < len(bids):
features.append(bids[i].get('price', 0) / 10000)
features.append(bids[i].get('size', 0) / 1000000)
else:
features.extend([0.0, 0.0])
for i in range(5):
if i < len(asks):
features.append(asks[i].get('price', 0) / 10000)
features.append(asks[i].get('size', 0) / 1000000)
else:
features.extend([0.0, 0.0])
return features
except Exception as e:
logger.debug(f"Error creating DQN COB features: {e}")
return [0.0] * 27 # Default feature vector sized to match the 27-feature success path
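# Worked example (stub numbers, illustrative only; `dashboard` is a
# hypothetical CleanTradingDashboard instance): the success path above emits
# 3 basic + 4 cumulative-imbalance + 5x2 bid + 5x2 ask = 27 features, matching
# the fallback vector size.
#
# stub = {'mid_price': 2400.0, 'spread': 0.5, 'liquidity_imbalance': 0.12,
#         'cumulative_imbalance': {'1s': 0.1, '5s': 0.05, '15s': 0.02, '60s': 0.0},
#         'bids': [{'price': 2399.9, 'size': 1200000}] * 5,
#         'asks': [{'price': 2400.1, 'size': 900000}] * 5}
# assert len(dashboard._create_dqn_cob_features('ETH/USDT', stub)) == 27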
def _create_cnn_cob_features(self, symbol: str, cob_data: dict) -> List[float]:
"""Create COB features specifically for CNN model"""
try:
features = []
# CNN focuses on pattern recognition - use more granular features
features.append(cob_data.get('mid_price', 0) / 10000)
features.append(cob_data.get('liquidity_imbalance', 0))
# Order book imbalance at different levels
bids = cob_data.get('bids', [])
asks = cob_data.get('asks', [])
# Calculate imbalance at different price levels
for level in [1, 2, 3, 5, 10]:
bid_vol = sum(bid.get('size', 0) for bid in bids[:level])
ask_vol = sum(ask.get('size', 0) for ask in asks[:level])
total_vol = bid_vol + ask_vol
if total_vol > 0:
imbalance = (bid_vol - ask_vol) / total_vol
else:
imbalance = 0.0
features.append(imbalance)
# Cumulative imbalance features
cumulative_imbalance = cob_data.get('cumulative_imbalance', {})
features.extend([
cumulative_imbalance.get('1s', 0.0),
cumulative_imbalance.get('5s', 0.0),
cumulative_imbalance.get('15s', 0.0)
])
return features
except Exception as e:
logger.debug(f"Error creating CNN COB features: {e}")
return [0.0] * 10 # Default feature vector
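# Worked example for the per-level imbalance above: with 100 units of bid
# volume and 60 units of ask volume inside a level, imbalance =
# (100 - 60) / (100 + 60) = 0.25, i.e. bid-heavy; the value is bounded in
# [-1, 1] by construction, and 0.0 is used when both sides are empty.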
def _create_transformer_cob_features(self, symbol: str, cob_data: dict) -> List[float]:
"""Create COB features specifically for Transformer model"""
try:
features = []
# Transformer can handle more complex features
features.append(cob_data.get('mid_price', 0) / 10000)
features.append(cob_data.get('spread', 0) / 100)
features.append(cob_data.get('liquidity_imbalance', 0))
# Order book features
bids = cob_data.get('bids', [])
asks = cob_data.get('asks', [])
# Top 10 levels for each side (more granular for transformer)
for i in range(10):
if i < len(bids):
features.append(bids[i].get('price', 0) / 10000)
features.append(bids[i].get('size', 0) / 1000000)
else:
features.extend([0.0, 0.0])
for i in range(10):
if i < len(asks):
features.append(asks[i].get('price', 0) / 10000)
features.append(asks[i].get('size', 0) / 1000000)
else:
features.extend([0.0, 0.0])
# Cumulative imbalance features
cumulative_imbalance = cob_data.get('cumulative_imbalance', {})
features.extend([
cumulative_imbalance.get('1s', 0.0),
cumulative_imbalance.get('5s', 0.0),
cumulative_imbalance.get('15s', 0.0),
cumulative_imbalance.get('60s', 0.0)
])
return features
except Exception as e:
logger.debug(f"Error creating Transformer COB features: {e}")
return [0.0] * 47 # Default feature vector sized to match the 47-feature success path
def _create_cob_rl_features(self, symbol: str, cob_data: dict) -> List[float]:
"""Create COB features specifically for COB RL model"""
try:
features = []
# COB RL focuses on order book dynamics
features.append(cob_data.get('mid_price', 0) / 10000)
features.append(cob_data.get('liquidity_imbalance', 0))
# Order book pressure indicators
bids = cob_data.get('bids', [])
asks = cob_data.get('asks', [])
# Calculate pressure at different levels
for level in [1, 2, 3, 5]:
bid_pressure = sum(bid.get('size', 0) for bid in bids[:level])
ask_pressure = sum(ask.get('size', 0) for ask in asks[:level])
features.append(bid_pressure / 1000000) # Normalized
features.append(ask_pressure / 1000000) # Normalized
# Pressure ratio
if ask_pressure > 0:
pressure_ratio = bid_pressure / ask_pressure
else:
pressure_ratio = 1.0
features.append(pressure_ratio)
# Cumulative imbalance features
cumulative_imbalance = cob_data.get('cumulative_imbalance', {})
features.extend([
cumulative_imbalance.get('1s', 0.0),
cumulative_imbalance.get('5s', 0.0),
cumulative_imbalance.get('15s', 0.0),
cumulative_imbalance.get('60s', 0.0)
])
return features
except Exception as e:
logger.debug(f"Error creating COB RL features: {e}")
return [0.0] * 18 # Default feature vector sized to match the success path (2 basic + 4x3 per-level pressure + 4 cumulative)
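# Worked example for the pressure features above (illustrative numbers): with
# bid_pressure = 2,000,000 and ask_pressure = 1,000,000 at a level, the
# appended values are 2.0, 1.0 and pressure_ratio = 2.0 (bid-heavy); an empty
# ask side falls back to a neutral ratio of 1.0 rather than dividing by zero.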
def get_cob_data_summary(self) -> dict:
"""Get COB data summary for dashboard display"""
try:
@@ -4693,30 +5075,111 @@ class CleanTradingDashboard:
if hasattr(self.orchestrator, 'primary_transformer_trainer'):
transformer_trainer = self.orchestrator.primary_transformer_trainer
- # Create transformer if not exists
+ # Try to load existing transformer checkpoint first
if transformer_model is None or transformer_trainer is None:
- config = TradingTransformerConfig(
-     d_model=512, # Optimized for 46M parameters
-     n_heads=8, # Optimized
-     n_layers=8, # Optimized
-     seq_len=100, # Optimized
-     n_actions=3,
-     use_multi_scale_attention=True,
-     use_market_regime_detection=True,
-     use_uncertainty_estimation=True,
-     use_deep_attention=True,
-     use_residual_connections=True,
-     use_layer_norm_variants=True
- )
- transformer_model, transformer_trainer = create_trading_transformer(config)
- # Store in orchestrator
- if self.orchestrator:
-     self.orchestrator.primary_transformer = transformer_model
-     self.orchestrator.primary_transformer_trainer = transformer_trainer
- logger.info("Created new advanced transformer model for training")
try:
from utils.checkpoint_manager import load_best_checkpoint
# Try to load the best transformer checkpoint
checkpoint_metadata = load_best_checkpoint("transformer", "transformer")
if checkpoint_metadata and checkpoint_metadata.checkpoint_path:
logger.info(f"Loading existing transformer checkpoint: {checkpoint_metadata.checkpoint_id}")
# Load the checkpoint data
checkpoint_data = torch.load(checkpoint_metadata.checkpoint_path, map_location='cpu')
# Recreate config from checkpoint
config = TradingTransformerConfig(
d_model=checkpoint_data.get('config', {}).get('d_model', 512),
n_heads=checkpoint_data.get('config', {}).get('n_heads', 8),
n_layers=checkpoint_data.get('config', {}).get('n_layers', 8),
seq_len=checkpoint_data.get('config', {}).get('seq_len', 100),
n_actions=3,
use_multi_scale_attention=True,
use_market_regime_detection=True,
use_uncertainty_estimation=True,
use_deep_attention=True,
use_residual_connections=True,
use_layer_norm_variants=True
)
# Create model and trainer
transformer_model, transformer_trainer = create_trading_transformer(config)
# Load state dict
transformer_model.load_state_dict(checkpoint_data['model_state_dict'])
# Restore training history
if 'training_history' in checkpoint_data:
transformer_trainer.training_history = checkpoint_data['training_history']
# Store in orchestrator
if self.orchestrator:
self.orchestrator.primary_transformer = transformer_model
self.orchestrator.primary_transformer_trainer = transformer_trainer
self.orchestrator.transformer_checkpoint_info = {
'checkpoint_id': checkpoint_metadata.checkpoint_id,
'checkpoint_path': checkpoint_metadata.checkpoint_path,
'performance_score': checkpoint_metadata.performance_score,
'created_at': checkpoint_metadata.created_at.isoformat(),
'loss': checkpoint_metadata.performance_metrics.get('loss', 0.0),
'accuracy': checkpoint_metadata.performance_metrics.get('accuracy', 0.0)
}
logger.info(f"TRANSFORMER: Loaded checkpoint successfully - Loss: {checkpoint_metadata.performance_metrics.get('loss', 0.0):.4f}, Accuracy: {checkpoint_metadata.performance_metrics.get('accuracy', 0.0):.4f}")
else:
# Create new transformer if no checkpoint available
logger.info("No transformer checkpoint found, creating new model")
config = TradingTransformerConfig(
d_model=512, # Optimized for 46M parameters
n_heads=8, # Optimized
n_layers=8, # Optimized
seq_len=100, # Optimized
n_actions=3,
use_multi_scale_attention=True,
use_market_regime_detection=True,
use_uncertainty_estimation=True,
use_deep_attention=True,
use_residual_connections=True,
use_layer_norm_variants=True
)
transformer_model, transformer_trainer = create_trading_transformer(config)
# Store in orchestrator
if self.orchestrator:
self.orchestrator.primary_transformer = transformer_model
self.orchestrator.primary_transformer_trainer = transformer_trainer
logger.info("Created new advanced transformer model for training")
except Exception as e:
logger.error(f"Error loading transformer checkpoint: {e}")
# Fallback to creating new model
config = TradingTransformerConfig(
d_model=512, # Optimized for 46M parameters
n_heads=8, # Optimized
n_layers=8, # Optimized
seq_len=100, # Optimized
n_actions=3,
use_multi_scale_attention=True,
use_market_regime_detection=True,
use_uncertainty_estimation=True,
use_deep_attention=True,
use_residual_connections=True,
use_layer_norm_variants=True
)
transformer_model, transformer_trainer = create_trading_transformer(config)
# Store in orchestrator
if self.orchestrator:
self.orchestrator.primary_transformer = transformer_model
self.orchestrator.primary_transformer_trainer = transformer_trainer
logger.info("Created new advanced transformer model for training (fallback)")
# Prepare training data from market data
training_samples = []
@@ -4847,10 +5310,67 @@ class CleanTradingDashboard:
self.training_performance_metrics['transformer']['total_calls'] += 1
self.training_performance_metrics['transformer']['frequency'] = len(training_samples)
- # Save checkpoint periodically
+ # Save checkpoint periodically with proper checkpoint management
if transformer_trainer.training_history['train_loss']:
- checkpoint_path = f"NN/models/saved/transformer_checkpoint_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt"
- transformer_trainer.save_model(checkpoint_path)
try:
from utils.checkpoint_manager import save_checkpoint
# Prepare checkpoint data
checkpoint_data = {
'model_state_dict': transformer_model.state_dict(),
'training_history': transformer_trainer.training_history,
'training_samples': len(training_samples),
'config': {
'd_model': transformer_model.config.d_model,
'n_heads': transformer_model.config.n_heads,
'n_layers': transformer_model.config.n_layers,
'seq_len': transformer_model.config.seq_len
}
}
performance_metrics = {
'loss': training_metrics['total_loss'],
'accuracy': training_metrics['accuracy'],
'training_samples': len(training_samples),
'model_parameters': sum(p.numel() for p in transformer_model.parameters())
}
metadata = save_checkpoint(
model=checkpoint_data,
model_name="transformer",
model_type="transformer",
performance_metrics=performance_metrics,
training_metadata={
'training_iterations': len(transformer_trainer.training_history['train_loss']),
'last_training_time': datetime.now().isoformat()
}
)
if metadata:
logger.info(f"TRANSFORMER: Checkpoint saved successfully: {metadata.checkpoint_id}")
# Update orchestrator with checkpoint info
if self.orchestrator:
if not hasattr(self.orchestrator, 'transformer_checkpoint_info'):
self.orchestrator.transformer_checkpoint_info = {}
self.orchestrator.transformer_checkpoint_info = {
'checkpoint_id': metadata.checkpoint_id,
'checkpoint_path': metadata.checkpoint_path,
'performance_score': metadata.performance_score,
'created_at': metadata.created_at.isoformat(),
'loss': training_metrics['total_loss'],
'accuracy': training_metrics['accuracy']
}
except Exception as e:
logger.error(f"Error saving transformer checkpoint: {e}")
# Fallback to direct save
try:
checkpoint_path = f"NN/models/saved/transformer_checkpoint_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt"
transformer_trainer.save_model(checkpoint_path)
logger.info(f"TRANSFORMER: Fallback checkpoint saved: {checkpoint_path}")
except Exception as fallback_error:
logger.error(f"Fallback checkpoint save also failed: {fallback_error}")
logger.info(f"TRANSFORMER: Trained on {len(training_samples)} samples, loss: {training_metrics['total_loss']:.4f}, accuracy: {training_metrics['accuracy']:.4f}")