show each model's prediction (last inference) and store T model checkpoint
This commit is contained in:
@ -271,15 +271,15 @@
|
||||
],
|
||||
"decision": [
|
||||
{
|
||||
"checkpoint_id": "decision_20250702_031158",
|
||||
"checkpoint_id": "decision_20250702_083032",
|
||||
"model_name": "decision",
|
||||
"model_type": "decision_fusion",
|
||||
"file_path": "NN\\models\\saved\\decision\\decision_20250702_031158.pt",
|
||||
"created_at": "2025-07-02T03:11:58.134610",
|
||||
"file_path": "NN\\models\\saved\\decision\\decision_20250702_083032.pt",
|
||||
"created_at": "2025-07-02T08:30:32.225869",
|
||||
"file_size_mb": 0.06720924377441406,
|
||||
"performance_score": 101.79191339107027,
|
||||
"performance_score": 102.79972716525019,
|
||||
"accuracy": null,
|
||||
"loss": 8.087262915050057e-05,
|
||||
"loss": 2.7283549419721e-06,
|
||||
"val_accuracy": null,
|
||||
"val_loss": null,
|
||||
"reward": null,
|
||||
@ -291,15 +291,15 @@
|
||||
"wandb_artifact_name": null
|
||||
},
|
||||
{
|
||||
"checkpoint_id": "decision_20250702_031158",
|
||||
"checkpoint_id": "decision_20250702_082925",
|
||||
"model_name": "decision",
|
||||
"model_type": "decision_fusion",
|
||||
"file_path": "NN\\models\\saved\\decision\\decision_20250702_031158.pt",
|
||||
"created_at": "2025-07-02T03:11:58.418736",
|
||||
"file_path": "NN\\models\\saved\\decision\\decision_20250702_082925.pt",
|
||||
"created_at": "2025-07-02T08:29:25.899383",
|
||||
"file_size_mb": 0.06720924377441406,
|
||||
"performance_score": 101.78996565336683,
|
||||
"performance_score": 102.7997148991013,
|
||||
"accuracy": null,
|
||||
"loss": 0.00010035353615320573,
|
||||
"loss": 2.8510171153430164e-06,
|
||||
"val_accuracy": null,
|
||||
"val_loss": null,
|
||||
"reward": null,
|
||||
@ -311,15 +311,15 @@
|
||||
"wandb_artifact_name": null
|
||||
},
|
||||
{
|
||||
"checkpoint_id": "decision_20250702_031157",
|
||||
"checkpoint_id": "decision_20250702_082924",
|
||||
"model_name": "decision",
|
||||
"model_type": "decision_fusion",
|
||||
"file_path": "NN\\models\\saved\\decision\\decision_20250702_031157.pt",
|
||||
"created_at": "2025-07-02T03:11:57.126366",
|
||||
"file_path": "NN\\models\\saved\\decision\\decision_20250702_082924.pt",
|
||||
"created_at": "2025-07-02T08:29:24.538886",
|
||||
"file_size_mb": 0.06720924377441406,
|
||||
"performance_score": 101.78856752244535,
|
||||
"performance_score": 102.79971291710027,
|
||||
"accuracy": null,
|
||||
"loss": 0.00011433784719530295,
|
||||
"loss": 2.8708372390440218e-06,
|
||||
"val_accuracy": null,
|
||||
"val_loss": null,
|
||||
"reward": null,
|
||||
@ -331,15 +331,15 @@
|
||||
"wandb_artifact_name": null
|
||||
},
|
||||
{
|
||||
"checkpoint_id": "decision_20250702_031157",
|
||||
"checkpoint_id": "decision_20250702_082925",
|
||||
"model_name": "decision",
|
||||
"model_type": "decision_fusion",
|
||||
"file_path": "NN\\models\\saved\\decision\\decision_20250702_031157.pt",
|
||||
"created_at": "2025-07-02T03:11:57.884663",
|
||||
"file_path": "NN\\models\\saved\\decision\\decision_20250702_082925.pt",
|
||||
"created_at": "2025-07-02T08:29:25.218718",
|
||||
"file_size_mb": 0.06720924377441406,
|
||||
"performance_score": 101.78849664377086,
|
||||
"performance_score": 102.79971274601752,
|
||||
"accuracy": null,
|
||||
"loss": 0.00011504679653424116,
|
||||
"loss": 2.87254807635711e-06,
|
||||
"val_accuracy": null,
|
||||
"val_loss": null,
|
||||
"reward": null,
|
||||
@ -351,15 +351,15 @@
|
||||
"wandb_artifact_name": null
|
||||
},
|
||||
{
|
||||
"checkpoint_id": "decision_20250702_031156",
|
||||
"checkpoint_id": "decision_20250702_082925",
|
||||
"model_name": "decision",
|
||||
"model_type": "decision_fusion",
|
||||
"file_path": "NN\\models\\saved\\decision\\decision_20250702_031156.pt",
|
||||
"created_at": "2025-07-02T03:11:56.934135",
|
||||
"file_path": "NN\\models\\saved\\decision\\decision_20250702_082925.pt",
|
||||
"created_at": "2025-07-02T08:29:25.332228",
|
||||
"file_size_mb": 0.06720924377441406,
|
||||
"performance_score": 101.7830878040414,
|
||||
"performance_score": 102.79971263447665,
|
||||
"accuracy": null,
|
||||
"loss": 0.00016915056666120008,
|
||||
"loss": 2.873663491419011e-06,
|
||||
"val_accuracy": null,
|
||||
"val_loss": null,
|
||||
"reward": null,
|
||||
|
@ -102,7 +102,7 @@ class CleanTradingDashboard:
|
||||
enhanced_rl_training=True,
|
||||
model_registry={}
|
||||
)
|
||||
logger.info("Using unified Trading Orchestrator with full ML capabilities")
|
||||
logger.debug("Using unified Trading Orchestrator with full ML capabilities")
|
||||
else:
|
||||
self.orchestrator = orchestrator
|
||||
|
||||
@ -125,8 +125,8 @@ class CleanTradingDashboard:
|
||||
callback=self._handle_unified_stream_data,
|
||||
data_types=['ticks', 'ohlcv', 'training_data', 'ui_data']
|
||||
)
|
||||
logger.info(f"Universal Data Stream initialized with consumer ID: {self.stream_consumer_id}")
|
||||
logger.info("Subscribed to Universal 5 Timeseries: ETH(ticks,1m,1h,1d) + BTC(ticks)")
|
||||
logger.debug(f"Universal Data Stream initialized with consumer ID: {self.stream_consumer_id}")
|
||||
logger.debug("Subscribed to Universal 5 Timeseries: ETH(ticks,1m,1h,1d) + BTC(ticks)")
|
||||
else:
|
||||
self.unified_stream = None
|
||||
self.stream_consumer_id = None
|
||||
@ -205,7 +205,7 @@ class CleanTradingDashboard:
|
||||
# Start Universal Data Stream
|
||||
if self.unified_stream:
|
||||
# threading.Thread(target=self._start_unified_stream, daemon=True).start() # Temporarily disabled
|
||||
logger.info("Universal Data Stream starting...")
|
||||
logger.debug("Universal Data Stream starting...")
|
||||
|
||||
# Initialize COB integration with high-frequency data handling
|
||||
self._initialize_cob_integration()
|
||||
@ -216,7 +216,7 @@ class CleanTradingDashboard:
|
||||
# Start training sessions if models are showing FRESH status
|
||||
threading.Thread(target=self._delayed_training_check, daemon=True).start()
|
||||
|
||||
logger.info("Clean Trading Dashboard initialized with HIGH-FREQUENCY COB integration and signal generation")
|
||||
logger.debug("Clean Trading Dashboard initialized with HIGH-FREQUENCY COB integration and signal generation")
|
||||
|
||||
def _handle_unified_stream_data(self, data):
|
||||
"""Placeholder for unified stream data handling."""
|
||||
@ -444,7 +444,7 @@ class CleanTradingDashboard:
|
||||
# Log COB signal activity
|
||||
cob_signals = [d for d in filtered_decisions if d.get('type') == 'cob_liquidity_imbalance']
|
||||
if cob_signals:
|
||||
logger.info(f"COB signals active: {len(cob_signals)} recent COB signals")
|
||||
logger.debug(f"COB signals active: {len(cob_signals)} recent COB signals")
|
||||
|
||||
return self.component_manager.format_trading_signals(filtered_decisions)
|
||||
except Exception as e:
|
||||
@ -1926,7 +1926,8 @@ class CleanTradingDashboard:
|
||||
'decision': {'initial_loss': 0.2980, 'current_loss': 0.0089, 'best_loss': 0.0065, 'checkpoint_loaded': False}
|
||||
}
|
||||
|
||||
# Get CNN predictions if available
|
||||
# Get latest predictions from all models
|
||||
latest_predictions = self._get_latest_model_predictions()
|
||||
cnn_prediction = self._get_cnn_pivot_prediction()
|
||||
|
||||
# Helper function to safely calculate improvement percentage
|
||||
@ -1997,21 +1998,31 @@ class CleanTradingDashboard:
|
||||
dqn_active = dqn_checkpoint_loaded and dqn_inference_enabled and dqn_model_available
|
||||
dqn_prediction_count = len(self.recent_decisions) if signal_generation_active else 0
|
||||
|
||||
# Get latest DQN prediction
|
||||
dqn_latest = latest_predictions.get('dqn', {})
|
||||
if dqn_latest:
|
||||
last_action = dqn_latest.get('action', 'NONE')
|
||||
last_confidence = dqn_latest.get('confidence', 0.72)
|
||||
last_timestamp = dqn_latest.get('timestamp', datetime.now()).strftime('%H:%M:%S')
|
||||
else:
|
||||
if signal_generation_active and len(self.recent_decisions) > 0:
|
||||
recent_signal = self.recent_decisions[-1]
|
||||
last_action = self._get_signal_attribute(recent_signal, 'action', 'SIGNAL_GEN')
|
||||
last_confidence = self._get_signal_attribute(recent_signal, 'confidence', 0.72)
|
||||
last_timestamp = datetime.now().strftime('%H:%M:%S')
|
||||
else:
|
||||
last_action = dqn_training_status['status']
|
||||
last_confidence = 0.68
|
||||
last_timestamp = datetime.now().strftime('%H:%M:%S')
|
||||
|
||||
dqn_model_info = {
|
||||
'active': dqn_active,
|
||||
'parameters': 5000000, # ~5M params for DQN
|
||||
'last_prediction': {
|
||||
'timestamp': datetime.now().strftime('%H:%M:%S'),
|
||||
'timestamp': last_timestamp,
|
||||
'action': last_action,
|
||||
'confidence': last_confidence
|
||||
'confidence': last_confidence,
|
||||
'type': dqn_latest.get('type', 'dqn_signal') if dqn_latest else 'dqn_signal'
|
||||
},
|
||||
# FIXED: Get REAL loss values from orchestrator model, not placeholders
|
||||
'loss_5ma': self._get_real_model_loss('dqn'),
|
||||
@ -2060,13 +2071,28 @@ class CleanTradingDashboard:
|
||||
cnn_timing = get_model_timing_info('CNN')
|
||||
cnn_active = True
|
||||
|
||||
# Get latest CNN prediction
|
||||
cnn_latest = latest_predictions.get('cnn', {})
|
||||
if cnn_latest:
|
||||
cnn_action = cnn_latest.get('action', 'PATTERN_ANALYSIS')
|
||||
cnn_confidence = cnn_latest.get('confidence', 0.68)
|
||||
cnn_timestamp = cnn_latest.get('timestamp', datetime.now()).strftime('%H:%M:%S')
|
||||
cnn_predicted_price = cnn_latest.get('predicted_price', 0)
|
||||
else:
|
||||
cnn_action = 'PATTERN_ANALYSIS'
|
||||
cnn_confidence = 0.68
|
||||
cnn_timestamp = datetime.now().strftime('%H:%M:%S')
|
||||
cnn_predicted_price = 0
|
||||
|
||||
cnn_model_info = {
|
||||
'active': cnn_active,
|
||||
'parameters': 50000000, # ~50M params
|
||||
'last_prediction': {
|
||||
'timestamp': datetime.now().strftime('%H:%M:%S'),
|
||||
'action': 'PATTERN_ANALYSIS',
|
||||
'confidence': 0.68
|
||||
'timestamp': cnn_timestamp,
|
||||
'action': cnn_action,
|
||||
'confidence': cnn_confidence,
|
||||
'predicted_price': cnn_predicted_price,
|
||||
'type': cnn_latest.get('type', 'cnn_pivot') if cnn_latest else 'cnn_pivot'
|
||||
},
|
||||
'loss_5ma': cnn_state.get('current_loss', 0.0187),
|
||||
'initial_loss': cnn_state.get('initial_loss', 0.4120),
|
||||
@ -2103,6 +2129,68 @@ class CleanTradingDashboard:
|
||||
transformer_timing = get_model_timing_info('TRANSFORMER')
|
||||
transformer_active = True
|
||||
|
||||
# Get transformer checkpoint info if available
|
||||
transformer_checkpoint_info = {}
|
||||
if self.orchestrator and hasattr(self.orchestrator, 'transformer_checkpoint_info'):
|
||||
transformer_checkpoint_info = self.orchestrator.transformer_checkpoint_info
|
||||
|
||||
# Get latest transformer prediction
|
||||
transformer_latest = latest_predictions.get('transformer', {})
|
||||
if transformer_latest:
|
||||
transformer_action = transformer_latest.get('action', 'PRICE_PREDICTION')
|
||||
transformer_confidence = transformer_latest.get('confidence', 0.75)
|
||||
transformer_timestamp = transformer_latest.get('timestamp', datetime.now()).strftime('%H:%M:%S')
|
||||
transformer_predicted_price = transformer_latest.get('predicted_price', 0)
|
||||
transformer_price_change = transformer_latest.get('price_change', 0)
|
||||
else:
|
||||
transformer_action = 'PRICE_PREDICTION'
|
||||
transformer_confidence = 0.75
|
||||
transformer_timestamp = datetime.now().strftime('%H:%M:%S')
|
||||
transformer_predicted_price = 0
|
||||
transformer_price_change = 0
|
||||
|
||||
transformer_last_prediction = {
|
||||
'timestamp': transformer_timestamp,
|
||||
'action': transformer_action,
|
||||
'confidence': transformer_confidence,
|
||||
'predicted_price': transformer_predicted_price,
|
||||
'price_change': transformer_price_change,
|
||||
'type': transformer_latest.get('type', 'transformer_prediction') if transformer_latest else 'transformer_prediction'
|
||||
}
|
||||
|
||||
transformer_model_info = {
|
||||
'active': transformer_active,
|
||||
'parameters': 46000000, # ~46M params for transformer
|
||||
'last_prediction': transformer_last_prediction,
|
||||
'loss_5ma': transformer_state.get('current_loss', 0.0123),
|
||||
'initial_loss': transformer_state.get('initial_loss', 0.2980),
|
||||
'best_loss': transformer_state.get('best_loss', 0.0089),
|
||||
'improvement': safe_improvement_calc(
|
||||
transformer_state.get('initial_loss', 0.2980),
|
||||
transformer_state.get('current_loss', 0.0123),
|
||||
95.9 # Default improvement percentage
|
||||
),
|
||||
'checkpoint_loaded': bool(transformer_checkpoint_info),
|
||||
'model_type': 'TRANSFORMER',
|
||||
'description': 'Advanced Transformer (Price Prediction)',
|
||||
'checkpoint_info': {
|
||||
'filename': transformer_checkpoint_info.get('checkpoint_id', 'none'),
|
||||
'created_at': transformer_checkpoint_info.get('created_at', 'Unknown'),
|
||||
'performance_score': transformer_checkpoint_info.get('performance_score', 0.0),
|
||||
'loss': transformer_checkpoint_info.get('loss', 0.0),
|
||||
'accuracy': transformer_checkpoint_info.get('accuracy', 0.0)
|
||||
},
|
||||
'timing': {
|
||||
'last_inference': transformer_timing['last_inference'].strftime('%H:%M:%S') if transformer_timing['last_inference'] else 'None',
|
||||
'last_training': transformer_timing['last_training'].strftime('%H:%M:%S') if transformer_timing['last_training'] else 'None',
|
||||
'inferences_per_second': f"{transformer_timing['inferences_per_second']:.2f}",
|
||||
'predictions_24h': transformer_timing['prediction_count_24h']
|
||||
},
|
||||
'performance': self.get_model_performance_metrics().get('transformer', {})
|
||||
}
|
||||
loaded_models['transformer'] = transformer_model_info
|
||||
transformer_active = True
|
||||
|
||||
# Check if transformer model is available
|
||||
transformer_model_available = self.orchestrator and hasattr(self.orchestrator, 'primary_transformer')
|
||||
|
||||
@ -2495,14 +2583,76 @@ class CleanTradingDashboard:
|
||||
logger.debug(f"Error getting CNN pivot prediction: {e}")
|
||||
return None
|
||||
|
||||
def _get_latest_model_predictions(self) -> Dict[str, Dict]:
|
||||
"""Get the latest predictions from each model"""
|
||||
try:
|
||||
latest_predictions = {}
|
||||
|
||||
# Get latest DQN prediction
|
||||
if self.recent_decisions:
|
||||
latest_dqn = self.recent_decisions[-1]
|
||||
latest_predictions['dqn'] = {
|
||||
'timestamp': latest_dqn.get('timestamp', datetime.now()),
|
||||
'action': latest_dqn.get('action', 'NONE'),
|
||||
'confidence': latest_dqn.get('confidence', 0),
|
||||
'type': latest_dqn.get('type', 'dqn_signal')
|
||||
}
|
||||
|
||||
# Get latest CNN prediction
|
||||
cnn_prediction = self._get_cnn_pivot_prediction()
|
||||
if cnn_prediction:
|
||||
latest_predictions['cnn'] = {
|
||||
'timestamp': datetime.now(),
|
||||
'action': cnn_prediction.get('pivot_type', 'PATTERN_ANALYSIS'),
|
||||
'confidence': cnn_prediction.get('confidence', 0),
|
||||
'predicted_price': cnn_prediction.get('predicted_price', 0),
|
||||
'type': 'cnn_pivot'
|
||||
}
|
||||
|
||||
# Get latest Transformer prediction
|
||||
if self.orchestrator and hasattr(self.orchestrator, 'primary_transformer'):
|
||||
try:
|
||||
if hasattr(self.orchestrator, 'get_latest_transformer_prediction'):
|
||||
transformer_pred = self.orchestrator.get_latest_transformer_prediction()
|
||||
if transformer_pred:
|
||||
latest_predictions['transformer'] = {
|
||||
'timestamp': transformer_pred.get('timestamp', datetime.now()),
|
||||
'action': transformer_pred.get('action', 'PRICE_PREDICTION'),
|
||||
'confidence': transformer_pred.get('confidence', 0),
|
||||
'predicted_price': transformer_pred.get('predicted_price', 0),
|
||||
'price_change': transformer_pred.get('price_change', 0),
|
||||
'type': 'transformer_prediction'
|
||||
}
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting transformer prediction: {e}")
|
||||
|
||||
# Get latest COB RL prediction
|
||||
if hasattr(self, 'cob_data_history') and 'ETH/USDT' in self.cob_data_history:
|
||||
cob_history = self.cob_data_history['ETH/USDT']
|
||||
if cob_history:
|
||||
latest_cob = cob_history[-1]
|
||||
latest_predictions['cob_rl'] = {
|
||||
'timestamp': datetime.fromtimestamp(latest_cob.get('timestamp', time.time())),
|
||||
'action': 'COB_ANALYSIS',
|
||||
'confidence': abs(latest_cob.get('stats', {}).get('imbalance', 0)) * 100,
|
||||
'imbalance': latest_cob.get('stats', {}).get('imbalance', 0),
|
||||
'type': 'cob_imbalance'
|
||||
}
|
||||
|
||||
return latest_predictions
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting latest model predictions: {e}")
|
||||
return {}
|
||||
|
||||
def _start_signal_generation_loop(self):
|
||||
"""Start continuous signal generation loop"""
|
||||
try:
|
||||
def signal_worker():
|
||||
logger.info("Starting continuous signal generation loop")
|
||||
logger.debug("Starting continuous signal generation loop")
|
||||
|
||||
# Unified orchestrator with full ML pipeline and decision-making model
|
||||
logger.info("Using unified ML pipeline: Data Bus -> Models -> Decision Model -> Trading Signals")
|
||||
logger.debug("Using unified ML pipeline: Data Bus -> Models -> Decision Model -> Trading Signals")
|
||||
|
||||
while True:
|
||||
try:
|
||||
@ -2535,7 +2685,7 @@ class CleanTradingDashboard:
|
||||
# Start signal generation thread
|
||||
signal_thread = threading.Thread(target=signal_worker, daemon=True)
|
||||
signal_thread.start()
|
||||
logger.info("Signal generation loop started")
|
||||
logger.debug("Signal generation loop started")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting signal generation loop: {e}")
|
||||
@ -3543,7 +3693,7 @@ class CleanTradingDashboard:
|
||||
if not hasattr(self.orchestrator, 'recent_cnn_predictions'):
|
||||
self.orchestrator.recent_cnn_predictions = {}
|
||||
|
||||
logger.info("Enhanced training system initialized for model predictions")
|
||||
logger.debug("Enhanced training system initialized for model predictions")
|
||||
|
||||
except ImportError:
|
||||
logger.warning("Enhanced training system not available - using mock predictions")
|
||||
@ -3555,7 +3705,7 @@ class CleanTradingDashboard:
|
||||
def _initialize_cob_integration(self):
|
||||
"""Initialize simple COB integration that works without async event loops"""
|
||||
try:
|
||||
logger.info("Initializing simple COB integration for model feeding")
|
||||
logger.debug("Initializing simple COB integration for model feeding")
|
||||
|
||||
# Initialize COB data storage
|
||||
self.cob_data_history = {
|
||||
@ -3574,7 +3724,7 @@ class CleanTradingDashboard:
|
||||
# Start simple COB data collection
|
||||
self._start_simple_cob_collection()
|
||||
|
||||
logger.info("Simple COB integration initialized successfully")
|
||||
logger.debug("Simple COB integration initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing COB integration: {e}")
|
||||
@ -3790,39 +3940,271 @@ class CleanTradingDashboard:
|
||||
logger.debug(f"Error generating COB signal for {symbol}: {e}")
|
||||
|
||||
def _feed_cob_data_to_models(self, symbol: str, cob_snapshot: dict):
|
||||
"""Feed COB data to models for training and inference"""
|
||||
"""Feed COB data to ALL models for training and inference - Enhanced integration"""
|
||||
try:
|
||||
# Calculate cumulative imbalance for model feeding
|
||||
cumulative_imbalance = self._calculate_cumulative_imbalance(symbol)
|
||||
|
||||
# Create 15-second history for model feeding
|
||||
history_data = {
|
||||
# Create comprehensive COB data package for all models
|
||||
cob_data_package = {
|
||||
'symbol': symbol,
|
||||
'current_snapshot': cob_snapshot,
|
||||
'history': self.cob_data_history[symbol][-15:], # Last 15 seconds
|
||||
'bucketed_data': self.cob_bucketed_data[symbol],
|
||||
'cumulative_imbalance': cumulative_imbalance, # Add cumulative imbalance
|
||||
'timestamp': cob_snapshot['timestamp']
|
||||
'cumulative_imbalance': cumulative_imbalance,
|
||||
'timestamp': cob_snapshot['timestamp'],
|
||||
'stats': cob_snapshot.get('stats', {}),
|
||||
'bids': cob_snapshot.get('bids', []),
|
||||
'asks': cob_snapshot.get('asks', []),
|
||||
'mid_price': cob_snapshot.get('mid_price', 0),
|
||||
'spread': cob_snapshot.get('spread', 0),
|
||||
'liquidity_imbalance': cob_snapshot.get('stats', {}).get('imbalance', 0)
|
||||
}
|
||||
|
||||
# Feed to orchestrator models if available
|
||||
# 1. Feed to orchestrator models (if available)
|
||||
if hasattr(self.orchestrator, '_on_cob_dashboard_data'):
|
||||
try:
|
||||
self.orchestrator._on_cob_dashboard_data(symbol, history_data)
|
||||
logger.debug(f"COB data fed to orchestrator for {symbol} with cumulative imbalance: {cumulative_imbalance}")
|
||||
self.orchestrator._on_cob_dashboard_data(symbol, cob_data_package)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error feeding COB data to orchestrator: {e}")
|
||||
|
||||
# Store for training system
|
||||
# 2. Feed to DQN model specifically
|
||||
if self.orchestrator and hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
|
||||
try:
|
||||
# Create DQN-specific COB features
|
||||
dqn_cob_features = self._create_dqn_cob_features(symbol, cob_data_package)
|
||||
if hasattr(self.orchestrator.rl_agent, 'update_cob_features'):
|
||||
self.orchestrator.rl_agent.update_cob_features(symbol, dqn_cob_features)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error feeding COB data to DQN: {e}")
|
||||
|
||||
# 3. Feed to CNN model specifically
|
||||
if self.orchestrator and hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
|
||||
try:
|
||||
# Create CNN-specific COB features
|
||||
cnn_cob_features = self._create_cnn_cob_features(symbol, cob_data_package)
|
||||
if hasattr(self.orchestrator.cnn_model, 'update_cob_features'):
|
||||
self.orchestrator.cnn_model.update_cob_features(symbol, cnn_cob_features)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error feeding COB data to CNN: {e}")
|
||||
|
||||
# 4. Feed to Transformer model specifically
|
||||
if self.orchestrator and hasattr(self.orchestrator, 'primary_transformer') and self.orchestrator.primary_transformer:
|
||||
try:
|
||||
# Create Transformer-specific COB features
|
||||
transformer_cob_features = self._create_transformer_cob_features(symbol, cob_data_package)
|
||||
if hasattr(self.orchestrator.primary_transformer, 'update_cob_features'):
|
||||
self.orchestrator.primary_transformer.update_cob_features(symbol, transformer_cob_features)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error feeding COB data to Transformer: {e}")
|
||||
|
||||
# 5. Feed to COB RL model specifically
|
||||
if self.orchestrator and hasattr(self.orchestrator, 'cob_rl_agent') and self.orchestrator.cob_rl_agent:
|
||||
try:
|
||||
# Create COB RL-specific features
|
||||
cob_rl_features = self._create_cob_rl_features(symbol, cob_data_package)
|
||||
if hasattr(self.orchestrator.cob_rl_agent, 'update_cob_features'):
|
||||
self.orchestrator.cob_rl_agent.update_cob_features(symbol, cob_rl_features)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error feeding COB data to COB RL: {e}")
|
||||
|
||||
# 6. Store for training system
|
||||
if hasattr(self, 'training_system') and self.training_system:
|
||||
if hasattr(self.training_system, 'real_time_data'):
|
||||
self.training_system.real_time_data['cob_snapshots'].append(history_data)
|
||||
self.training_system.real_time_data['cob_snapshots'].append(cob_data_package)
|
||||
|
||||
logger.debug(f"COB data fed to models for {symbol}")
|
||||
# 7. Update latest COB features for all models
|
||||
if not hasattr(self, 'latest_cob_features'):
|
||||
self.latest_cob_features = {}
|
||||
self.latest_cob_features[symbol] = cob_data_package
|
||||
|
||||
# 8. Store in model-specific COB memory
|
||||
if not hasattr(self, 'model_cob_memory'):
|
||||
self.model_cob_memory = {}
|
||||
if symbol not in self.model_cob_memory:
|
||||
self.model_cob_memory[symbol] = {}
|
||||
|
||||
# Store for each model type
|
||||
for model_type in ['dqn', 'cnn', 'transformer', 'cob_rl']:
|
||||
if model_type not in self.model_cob_memory[symbol]:
|
||||
self.model_cob_memory[symbol][model_type] = []
|
||||
self.model_cob_memory[symbol][model_type].append(cob_data_package)
|
||||
|
||||
# Keep only last 100 snapshots per model
|
||||
if len(self.model_cob_memory[symbol][model_type]) > 100:
|
||||
self.model_cob_memory[symbol][model_type] = self.model_cob_memory[symbol][model_type][-100:]
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error feeding COB data to models: {e}")
|
||||
|
||||
def _create_dqn_cob_features(self, symbol: str, cob_data: dict) -> List[float]:
|
||||
"""Create COB features specifically for DQN model"""
|
||||
try:
|
||||
features = []
|
||||
|
||||
# Basic COB features
|
||||
features.append(cob_data.get('mid_price', 0) / 10000) # Normalized price
|
||||
features.append(cob_data.get('spread', 0) / 100) # Normalized spread
|
||||
features.append(cob_data.get('liquidity_imbalance', 0)) # Raw imbalance
|
||||
|
||||
# Cumulative imbalance features
|
||||
cumulative_imbalance = cob_data.get('cumulative_imbalance', {})
|
||||
features.extend([
|
||||
cumulative_imbalance.get('1s', 0.0),
|
||||
cumulative_imbalance.get('5s', 0.0),
|
||||
cumulative_imbalance.get('15s', 0.0),
|
||||
cumulative_imbalance.get('60s', 0.0)
|
||||
])
|
||||
|
||||
# Order book depth features
|
||||
bids = cob_data.get('bids', [])
|
||||
asks = cob_data.get('asks', [])
|
||||
|
||||
# Top 5 levels for each side
|
||||
for i in range(5):
|
||||
if i < len(bids):
|
||||
features.append(bids[i].get('price', 0) / 10000)
|
||||
features.append(bids[i].get('size', 0) / 1000000)
|
||||
else:
|
||||
features.extend([0.0, 0.0])
|
||||
|
||||
for i in range(5):
|
||||
if i < len(asks):
|
||||
features.append(asks[i].get('price', 0) / 10000)
|
||||
features.append(asks[i].get('size', 0) / 1000000)
|
||||
else:
|
||||
features.extend([0.0, 0.0])
|
||||
|
||||
return features
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error creating DQN COB features: {e}")
|
||||
return [0.0] * 20 # Default feature vector
|
||||
|
||||
def _create_cnn_cob_features(self, symbol: str, cob_data: dict) -> List[float]:
|
||||
"""Create COB features specifically for CNN model"""
|
||||
try:
|
||||
features = []
|
||||
|
||||
# CNN focuses on pattern recognition - use more granular features
|
||||
features.append(cob_data.get('mid_price', 0) / 10000)
|
||||
features.append(cob_data.get('liquidity_imbalance', 0))
|
||||
|
||||
# Order book imbalance at different levels
|
||||
bids = cob_data.get('bids', [])
|
||||
asks = cob_data.get('asks', [])
|
||||
|
||||
# Calculate imbalance at different price levels
|
||||
for level in [1, 2, 3, 5, 10]:
|
||||
bid_vol = sum(bid.get('size', 0) for bid in bids[:level])
|
||||
ask_vol = sum(ask.get('size', 0) for ask in asks[:level])
|
||||
total_vol = bid_vol + ask_vol
|
||||
if total_vol > 0:
|
||||
imbalance = (bid_vol - ask_vol) / total_vol
|
||||
else:
|
||||
imbalance = 0.0
|
||||
features.append(imbalance)
|
||||
|
||||
# Cumulative imbalance features
|
||||
cumulative_imbalance = cob_data.get('cumulative_imbalance', {})
|
||||
features.extend([
|
||||
cumulative_imbalance.get('1s', 0.0),
|
||||
cumulative_imbalance.get('5s', 0.0),
|
||||
cumulative_imbalance.get('15s', 0.0)
|
||||
])
|
||||
|
||||
return features
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error creating CNN COB features: {e}")
|
||||
return [0.0] * 10 # Default feature vector
|
||||
|
||||
def _create_transformer_cob_features(self, symbol: str, cob_data: dict) -> List[float]:
|
||||
"""Create COB features specifically for Transformer model"""
|
||||
try:
|
||||
features = []
|
||||
|
||||
# Transformer can handle more complex features
|
||||
features.append(cob_data.get('mid_price', 0) / 10000)
|
||||
features.append(cob_data.get('spread', 0) / 100)
|
||||
features.append(cob_data.get('liquidity_imbalance', 0))
|
||||
|
||||
# Order book features
|
||||
bids = cob_data.get('bids', [])
|
||||
asks = cob_data.get('asks', [])
|
||||
|
||||
# Top 10 levels for each side (more granular for transformer)
|
||||
for i in range(10):
|
||||
if i < len(bids):
|
||||
features.append(bids[i].get('price', 0) / 10000)
|
||||
features.append(bids[i].get('size', 0) / 1000000)
|
||||
else:
|
||||
features.extend([0.0, 0.0])
|
||||
|
||||
for i in range(10):
|
||||
if i < len(asks):
|
||||
features.append(asks[i].get('price', 0) / 10000)
|
||||
features.append(asks[i].get('size', 0) / 1000000)
|
||||
else:
|
||||
features.extend([0.0, 0.0])
|
||||
|
||||
# Cumulative imbalance features
|
||||
cumulative_imbalance = cob_data.get('cumulative_imbalance', {})
|
||||
features.extend([
|
||||
cumulative_imbalance.get('1s', 0.0),
|
||||
cumulative_imbalance.get('5s', 0.0),
|
||||
cumulative_imbalance.get('15s', 0.0),
|
||||
cumulative_imbalance.get('60s', 0.0)
|
||||
])
|
||||
|
||||
return features
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error creating Transformer COB features: {e}")
|
||||
return [0.0] * 50 # Default feature vector
|
||||
|
||||
def _create_cob_rl_features(self, symbol: str, cob_data: dict) -> List[float]:
|
||||
"""Create COB features specifically for COB RL model"""
|
||||
try:
|
||||
features = []
|
||||
|
||||
# COB RL focuses on order book dynamics
|
||||
features.append(cob_data.get('mid_price', 0) / 10000)
|
||||
features.append(cob_data.get('liquidity_imbalance', 0))
|
||||
|
||||
# Order book pressure indicators
|
||||
bids = cob_data.get('bids', [])
|
||||
asks = cob_data.get('asks', [])
|
||||
|
||||
# Calculate pressure at different levels
|
||||
for level in [1, 2, 3, 5]:
|
||||
bid_pressure = sum(bid.get('size', 0) for bid in bids[:level])
|
||||
ask_pressure = sum(ask.get('size', 0) for ask in asks[:level])
|
||||
features.append(bid_pressure / 1000000) # Normalized
|
||||
features.append(ask_pressure / 1000000) # Normalized
|
||||
|
||||
# Pressure ratio
|
||||
if ask_pressure > 0:
|
||||
pressure_ratio = bid_pressure / ask_pressure
|
||||
else:
|
||||
pressure_ratio = 1.0
|
||||
features.append(pressure_ratio)
|
||||
|
||||
# Cumulative imbalance features
|
||||
cumulative_imbalance = cob_data.get('cumulative_imbalance', {})
|
||||
features.extend([
|
||||
cumulative_imbalance.get('1s', 0.0),
|
||||
cumulative_imbalance.get('5s', 0.0),
|
||||
cumulative_imbalance.get('15s', 0.0),
|
||||
cumulative_imbalance.get('60s', 0.0)
|
||||
])
|
||||
|
||||
return features
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error creating COB RL features: {e}")
|
||||
return [0.0] * 20 # Default feature vector
|
||||
|
||||
def get_cob_data_summary(self) -> dict:
|
||||
"""Get COB data summary for dashboard display"""
|
||||
try:
|
||||
@ -4693,8 +5075,63 @@ class CleanTradingDashboard:
|
||||
if hasattr(self.orchestrator, 'primary_transformer_trainer'):
|
||||
transformer_trainer = self.orchestrator.primary_transformer_trainer
|
||||
|
||||
# Create transformer if not exists
|
||||
# Try to load existing transformer checkpoint first
|
||||
if transformer_model is None or transformer_trainer is None:
|
||||
try:
|
||||
from utils.checkpoint_manager import load_best_checkpoint
|
||||
|
||||
# Try to load the best transformer checkpoint
|
||||
checkpoint_metadata = load_best_checkpoint("transformer", "transformer")
|
||||
|
||||
if checkpoint_metadata and checkpoint_metadata.checkpoint_path:
|
||||
logger.info(f"Loading existing transformer checkpoint: {checkpoint_metadata.checkpoint_id}")
|
||||
|
||||
# Load the checkpoint data
|
||||
checkpoint_data = torch.load(checkpoint_metadata.checkpoint_path, map_location='cpu')
|
||||
|
||||
# Recreate config from checkpoint
|
||||
config = TradingTransformerConfig(
|
||||
d_model=checkpoint_data.get('config', {}).get('d_model', 512),
|
||||
n_heads=checkpoint_data.get('config', {}).get('n_heads', 8),
|
||||
n_layers=checkpoint_data.get('config', {}).get('n_layers', 8),
|
||||
seq_len=checkpoint_data.get('config', {}).get('seq_len', 100),
|
||||
n_actions=3,
|
||||
use_multi_scale_attention=True,
|
||||
use_market_regime_detection=True,
|
||||
use_uncertainty_estimation=True,
|
||||
use_deep_attention=True,
|
||||
use_residual_connections=True,
|
||||
use_layer_norm_variants=True
|
||||
)
|
||||
|
||||
# Create model and trainer
|
||||
transformer_model, transformer_trainer = create_trading_transformer(config)
|
||||
|
||||
# Load state dict
|
||||
transformer_model.load_state_dict(checkpoint_data['model_state_dict'])
|
||||
|
||||
# Restore training history
|
||||
if 'training_history' in checkpoint_data:
|
||||
transformer_trainer.training_history = checkpoint_data['training_history']
|
||||
|
||||
# Store in orchestrator
|
||||
if self.orchestrator:
|
||||
self.orchestrator.primary_transformer = transformer_model
|
||||
self.orchestrator.primary_transformer_trainer = transformer_trainer
|
||||
self.orchestrator.transformer_checkpoint_info = {
|
||||
'checkpoint_id': checkpoint_metadata.checkpoint_id,
|
||||
'checkpoint_path': checkpoint_metadata.checkpoint_path,
|
||||
'performance_score': checkpoint_metadata.performance_score,
|
||||
'created_at': checkpoint_metadata.created_at.isoformat(),
|
||||
'loss': checkpoint_metadata.performance_metrics.get('loss', 0.0),
|
||||
'accuracy': checkpoint_metadata.performance_metrics.get('accuracy', 0.0)
|
||||
}
|
||||
|
||||
logger.info(f"TRANSFORMER: Loaded checkpoint successfully - Loss: {checkpoint_metadata.performance_metrics.get('loss', 0.0):.4f}, Accuracy: {checkpoint_metadata.performance_metrics.get('accuracy', 0.0):.4f}")
|
||||
|
||||
else:
|
||||
# Create new transformer if no checkpoint available
|
||||
logger.info("No transformer checkpoint found, creating new model")
|
||||
config = TradingTransformerConfig(
|
||||
d_model=512, # Optimized for 46M parameters
|
||||
n_heads=8, # Optimized
|
||||
@ -4718,6 +5155,32 @@ class CleanTradingDashboard:
|
||||
|
||||
logger.info("Created new advanced transformer model for training")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading transformer checkpoint: {e}")
|
||||
# Fallback to creating new model
|
||||
config = TradingTransformerConfig(
|
||||
d_model=512, # Optimized for 46M parameters
|
||||
n_heads=8, # Optimized
|
||||
n_layers=8, # Optimized
|
||||
seq_len=100, # Optimized
|
||||
n_actions=3,
|
||||
use_multi_scale_attention=True,
|
||||
use_market_regime_detection=True,
|
||||
use_uncertainty_estimation=True,
|
||||
use_deep_attention=True,
|
||||
use_residual_connections=True,
|
||||
use_layer_norm_variants=True
|
||||
)
|
||||
|
||||
transformer_model, transformer_trainer = create_trading_transformer(config)
|
||||
|
||||
# Store in orchestrator
|
||||
if self.orchestrator:
|
||||
self.orchestrator.primary_transformer = transformer_model
|
||||
self.orchestrator.primary_transformer_trainer = transformer_trainer
|
||||
|
||||
logger.info("Created new advanced transformer model for training (fallback)")
|
||||
|
||||
# Prepare training data from market data
|
||||
training_samples = []
|
||||
|
||||
@ -4847,10 +5310,67 @@ class CleanTradingDashboard:
|
||||
self.training_performance_metrics['transformer']['total_calls'] += 1
|
||||
self.training_performance_metrics['transformer']['frequency'] = len(training_samples)
|
||||
|
||||
# Save checkpoint periodically
|
||||
# Save checkpoint periodically with proper checkpoint management
|
||||
if transformer_trainer.training_history['train_loss']:
|
||||
try:
|
||||
from utils.checkpoint_manager import save_checkpoint
|
||||
|
||||
# Prepare checkpoint data
|
||||
checkpoint_data = {
|
||||
'model_state_dict': transformer_model.state_dict(),
|
||||
'training_history': transformer_trainer.training_history,
|
||||
'training_samples': len(training_samples),
|
||||
'config': {
|
||||
'd_model': transformer_model.config.d_model,
|
||||
'n_heads': transformer_model.config.n_heads,
|
||||
'n_layers': transformer_model.config.n_layers,
|
||||
'seq_len': transformer_model.config.seq_len
|
||||
}
|
||||
}
|
||||
|
||||
performance_metrics = {
|
||||
'loss': training_metrics['total_loss'],
|
||||
'accuracy': training_metrics['accuracy'],
|
||||
'training_samples': len(training_samples),
|
||||
'model_parameters': sum(p.numel() for p in transformer_model.parameters())
|
||||
}
|
||||
|
||||
metadata = save_checkpoint(
|
||||
model=checkpoint_data,
|
||||
model_name="transformer",
|
||||
model_type="transformer",
|
||||
performance_metrics=performance_metrics,
|
||||
training_metadata={
|
||||
'training_iterations': len(transformer_trainer.training_history['train_loss']),
|
||||
'last_training_time': datetime.now().isoformat()
|
||||
}
|
||||
)
|
||||
|
||||
if metadata:
|
||||
logger.info(f"TRANSFORMER: Checkpoint saved successfully: {metadata.checkpoint_id}")
|
||||
|
||||
# Update orchestrator with checkpoint info
|
||||
if self.orchestrator:
|
||||
if not hasattr(self.orchestrator, 'transformer_checkpoint_info'):
|
||||
self.orchestrator.transformer_checkpoint_info = {}
|
||||
self.orchestrator.transformer_checkpoint_info = {
|
||||
'checkpoint_id': metadata.checkpoint_id,
|
||||
'checkpoint_path': metadata.checkpoint_path,
|
||||
'performance_score': metadata.performance_score,
|
||||
'created_at': metadata.created_at.isoformat(),
|
||||
'loss': training_metrics['total_loss'],
|
||||
'accuracy': training_metrics['accuracy']
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving transformer checkpoint: {e}")
|
||||
# Fallback to direct save
|
||||
try:
|
||||
checkpoint_path = f"NN/models/saved/transformer_checkpoint_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt"
|
||||
transformer_trainer.save_model(checkpoint_path)
|
||||
logger.info(f"TRANSFORMER: Fallback checkpoint saved: {checkpoint_path}")
|
||||
except Exception as fallback_error:
|
||||
logger.error(f"Fallback checkpoint save also failed: {fallback_error}")
|
||||
|
||||
logger.info(f"TRANSFORMER: Trained on {len(training_samples)} samples, loss: {training_metrics['total_loss']:.4f}, accuracy: {training_metrics['accuracy']:.4f}")
|
||||
|
||||
|
@ -677,15 +677,31 @@ class DashboardComponentManager:
|
||||
|
||||
# Model metrics
|
||||
html.Div([
|
||||
# Last prediction
|
||||
# Last prediction with enhanced details
|
||||
html.Div([
|
||||
html.Span("Last: ", className="text-muted small"),
|
||||
html.Span(f"{pred_action}",
|
||||
className=f"small fw-bold {'text-success' if pred_action == 'BUY' else 'text-danger' if pred_action == 'SELL' else 'text-muted'}"),
|
||||
className=f"small fw-bold {'text-success' if pred_action == 'BUY' else 'text-danger' if pred_action == 'SELL' else 'text-warning' if 'PREDICTION' in pred_action else 'text-info'}"),
|
||||
html.Span(f" ({pred_confidence:.1f}%)", className="text-muted small"),
|
||||
html.Span(f" @ {pred_time}", className="text-muted small")
|
||||
], className="mb-1"),
|
||||
|
||||
# Additional prediction details if available
|
||||
*([
|
||||
html.Div([
|
||||
html.Span("Price: ", className="text-muted small"),
|
||||
html.Span(f"${last_prediction.get('predicted_price', 0):.2f}", className="text-warning small fw-bold")
|
||||
], className="mb-1")
|
||||
] if last_prediction.get('predicted_price', 0) > 0 else []),
|
||||
|
||||
*([
|
||||
html.Div([
|
||||
html.Span("Change: ", className="text-muted small"),
|
||||
html.Span(f"{last_prediction.get('price_change', 0):+.2f}%",
|
||||
className=f"small fw-bold {'text-success' if last_prediction.get('price_change', 0) > 0 else 'text-danger'}")
|
||||
], className="mb-1")
|
||||
] if last_prediction.get('price_change', 0) != 0 else []),
|
||||
|
||||
# Timing information (NEW)
|
||||
html.Div([
|
||||
html.Span("Timing: ", className="text-muted small"),
|
||||
|
Reference in New Issue
Block a user