diff --git a/enhanced_realtime_training.py b/enhanced_realtime_training.py
index 88dd23d..5cf281c 100644
--- a/enhanced_realtime_training.py
+++ b/enhanced_realtime_training.py
@@ -71,6 +71,34 @@ class EnhancedRealtimeTrainingSystem:
'validation': 0.0
}
+ # Model prediction tracking - NEW for dashboard visualization
+ self.recent_dqn_predictions = {
+ 'ETH/USDT': deque(maxlen=100),
+ 'BTC/USDT': deque(maxlen=100)
+ }
+ self.recent_cnn_predictions = {
+ 'ETH/USDT': deque(maxlen=50),
+ 'BTC/USDT': deque(maxlen=50)
+ }
+ self.prediction_accuracy_history = {
+ 'ETH/USDT': deque(maxlen=200),
+ 'BTC/USDT': deque(maxlen=200)
+ }
+
+ # FIXED: Forward-looking prediction system
+ self.pending_predictions = {
+ 'ETH/USDT': deque(maxlen=100), # Predictions waiting for validation
+ 'BTC/USDT': deque(maxlen=100)
+ }
+ self.last_prediction_time = {
+ 'ETH/USDT': 0,
+ 'BTC/USDT': 0
+ }
+ self.prediction_intervals = {
+ 'dqn': 30, # Make DQN prediction every 30 seconds
+ 'cnn': 60 # Make CNN prediction every 60 seconds
+ }
+
# Real-time data streams
self.real_time_data = {
'ticks': deque(maxlen=1000),
@@ -146,24 +174,27 @@ class EnhancedRealtimeTrainingSystem:
current_time = time.time()
self.training_iteration += 1
- # 1. DQN Training (every 5 seconds with enough data)
+ # 1. FORWARD-LOOKING PREDICTIONS - Generate real predictions for future validation
+ self.generate_forward_looking_predictions()
+
+ # 2. DQN Training (every 5 seconds with enough data)
if (current_time - self.last_training_times['dqn'] > self.training_config['dqn_training_interval']
and len(self.experience_buffer) >= self.training_config['min_training_samples']):
self._perform_enhanced_dqn_training()
self.last_training_times['dqn'] = current_time
- # 2. CNN Training (every 10 seconds)
+ # 3. CNN Training (every 10 seconds)
if (current_time - self.last_training_times['cnn'] > self.training_config['cnn_training_interval']
and len(self.real_time_data['ohlcv_1m']) >= 20):
self._perform_enhanced_cnn_training()
self.last_training_times['cnn'] = current_time
- # 3. Validation (every minute)
+ # 4. Validation (every minute)
if current_time - self.last_training_times['validation'] > self.training_config['validation_interval']:
self._perform_validation()
self.last_training_times['validation'] = current_time
- # 4. Adaptive learning rate adjustment
+ # 5. Adaptive learning rate adjustment
if self.training_iteration % 100 == 0:
self._adapt_learning_parameters()
@@ -911,6 +942,11 @@ class EnhancedRealtimeTrainingSystem:
'dqn_loss_count': len(self.performance_history['dqn_losses']),
'cnn_loss_count': len(self.performance_history['cnn_losses']),
'validation_count': len(self.performance_history['validation_scores'])
+ },
+ 'prediction_stats': {
+ 'dqn_predictions': {symbol: len(predictions) for symbol, predictions in self.recent_dqn_predictions.items()},
+ 'cnn_predictions': {symbol: len(predictions) for symbol, predictions in self.recent_cnn_predictions.items()},
+ 'accuracy_history': {symbol: len(history) for symbol, history in self.prediction_accuracy_history.items()}
}
}
@@ -927,4 +963,492 @@ class EnhancedRealtimeTrainingSystem:
except Exception as e:
logger.error(f"Error getting training statistics: {e}")
- return {'error': str(e)}
\ No newline at end of file
+ return {'error': str(e)}
+
+ def capture_dqn_prediction(self, symbol: str, state: np.ndarray, q_values: List[float], action: int, confidence: float, price: float):
+ """Capture DQN prediction for dashboard visualization"""
+ try:
+ prediction = {
+ 'timestamp': datetime.now(),
+ 'symbol': symbol,
+ 'state': state.tolist() if hasattr(state, 'tolist') else state,
+ 'q_values': q_values,
+ 'action': action, # 0=BUY, 1=SELL, 2=HOLD
+ 'confidence': confidence,
+ 'price': price
+ }
+
+ if symbol in self.recent_dqn_predictions:
+ self.recent_dqn_predictions[symbol].append(prediction)
+
+ logger.debug(f"DQN prediction captured: {symbol} action={action} confidence={confidence:.2f}")
+
+ except Exception as e:
+ logger.debug(f"Error capturing DQN prediction: {e}")
+
+ def capture_cnn_prediction(self, symbol: str, current_price: float, predicted_price: float, direction: int, confidence: float, features: Optional[np.ndarray] = None):
+ """Capture CNN prediction for dashboard visualization"""
+ try:
+ prediction = {
+ 'timestamp': datetime.now(),
+ 'symbol': symbol,
+ 'current_price': current_price,
+ 'predicted_price': predicted_price,
+ 'direction': direction, # 0=DOWN, 1=SAME, 2=UP
+ 'confidence': confidence,
+ 'features': features.tolist() if features is not None and hasattr(features, 'tolist') else None
+ }
+
+ if symbol in self.recent_cnn_predictions:
+ self.recent_cnn_predictions[symbol].append(prediction)
+
+ logger.debug(f"CNN prediction captured: {symbol} direction={direction} confidence={confidence:.2f}")
+
+ except Exception as e:
+ logger.debug(f"Error capturing CNN prediction: {e}")
+
+ def validate_prediction_accuracy(self, symbol: str, prediction_type: str, predicted_action: int, actual_price_change: float, confidence: float):
+ """Validate prediction accuracy and store results"""
+ try:
+ # Determine if prediction was correct
+ was_correct = False
+
+ if prediction_type == 'DQN':
+ # For DQN: BUY (0) should be followed by price increase, SELL (1) by decrease
+ if predicted_action == 0 and actual_price_change > 0.001: # BUY + price up
+ was_correct = True
+ elif predicted_action == 1 and actual_price_change < -0.001: # SELL + price down
+ was_correct = True
+ elif predicted_action == 2 and abs(actual_price_change) <= 0.001: # HOLD + no change
+ was_correct = True
+
+ elif prediction_type == 'CNN':
+ # For CNN: direction prediction accuracy
+ if predicted_action == 2 and actual_price_change > 0.001: # UP + price up
+ was_correct = True
+ elif predicted_action == 0 and actual_price_change < -0.001: # DOWN + price down
+ was_correct = True
+ elif predicted_action == 1 and abs(actual_price_change) <= 0.001: # SAME + no change
+ was_correct = True
+
+ # Calculate accuracy score based on confidence and correctness
+ accuracy_score = confidence if was_correct else (1.0 - confidence)
+
+ accuracy_data = {
+ 'timestamp': datetime.now(),
+ 'symbol': symbol,
+ 'prediction_type': prediction_type,
+ 'correct': was_correct,
+ 'accuracy_score': accuracy_score,
+ 'confidence': confidence,
+ 'actual_price_change': actual_price_change,
+ 'predicted_action': predicted_action
+ }
+
+ if symbol in self.prediction_accuracy_history:
+ self.prediction_accuracy_history[symbol].append(accuracy_data)
+
+ logger.debug(f"Prediction accuracy validated: {symbol} {prediction_type} correct={was_correct} score={accuracy_score:.2f}")
+
+ except Exception as e:
+ logger.debug(f"Error validating prediction accuracy: {e}")
+
+ def get_prediction_summary(self, symbol: str) -> Dict[str, Any]:
+ """Get prediction summary for a symbol"""
+ try:
+ summary = {
+ 'symbol': symbol,
+ 'dqn_predictions': len(self.recent_dqn_predictions.get(symbol, [])),
+ 'cnn_predictions': len(self.recent_cnn_predictions.get(symbol, [])),
+ 'accuracy_history': len(self.prediction_accuracy_history.get(symbol, [])),
+ 'pending_predictions': len(self.pending_predictions.get(symbol, []))
+ }
+
+ # Calculate accuracy statistics
+ if symbol in self.prediction_accuracy_history and self.prediction_accuracy_history[symbol]:
+ accuracy_data = list(self.prediction_accuracy_history[symbol])
+
+ total_predictions = len(accuracy_data)
+ correct_predictions = sum(1 for acc in accuracy_data if acc['correct'])
+
+ summary['total_predictions'] = total_predictions
+ summary['correct_predictions'] = correct_predictions
+ summary['accuracy_rate'] = correct_predictions / total_predictions if total_predictions > 0 else 0.0
+
+ # Calculate accuracy by prediction type
+ dqn_accuracy_data = [acc for acc in accuracy_data if acc['prediction_type'] == 'DQN']
+ cnn_accuracy_data = [acc for acc in accuracy_data if acc['prediction_type'] == 'CNN']
+
+ if dqn_accuracy_data:
+ dqn_correct = sum(1 for acc in dqn_accuracy_data if acc['correct'])
+ summary['dqn_accuracy_rate'] = dqn_correct / len(dqn_accuracy_data)
+ else:
+ summary['dqn_accuracy_rate'] = 0.0
+
+ if cnn_accuracy_data:
+ cnn_correct = sum(1 for acc in cnn_accuracy_data if acc['correct'])
+ summary['cnn_accuracy_rate'] = cnn_correct / len(cnn_accuracy_data)
+ else:
+ summary['cnn_accuracy_rate'] = 0.0
+
+ return summary
+
+ except Exception as e:
+ logger.error(f"Error getting prediction summary: {e}")
+ return {'error': str(e)}
+
+ def generate_forward_looking_predictions(self):
+ """Generate forward-looking predictions based on current market data"""
+ try:
+ current_time = time.time()
+
+ for symbol in ['ETH/USDT', 'BTC/USDT']:
+ # Check if it's time to make new predictions
+ time_since_last = current_time - self.last_prediction_time.get(symbol, 0)
+
+ # Generate DQN prediction every 30 seconds
+ if time_since_last >= self.prediction_intervals['dqn']:
+ self._generate_forward_dqn_prediction(symbol, current_time)
+
+ # Generate CNN prediction every 60 seconds
+ if time_since_last >= self.prediction_intervals['cnn']:
+ self._generate_forward_cnn_prediction(symbol, current_time)
+
+ # Validate pending predictions
+ self._validate_pending_predictions(symbol, current_time)
+
+ except Exception as e:
+ logger.error(f"Error generating forward-looking predictions: {e}")
+
+ def _generate_forward_dqn_prediction(self, symbol: str, current_time: float):
+ """Generate a DQN prediction for future price movement"""
+ try:
+ # Get current market state (only historical data)
+ current_state = self._build_comprehensive_state()
+ current_price = self._get_current_price_from_data(symbol)
+
+ if current_price is None:
+ return
+
+ # Use DQN model to predict action (if available)
+ if (self.orchestrator and hasattr(self.orchestrator, 'rl_agent')
+ and self.orchestrator.rl_agent):
+
+ # Get Q-values from model
+ q_values = self.orchestrator.rl_agent.act(current_state, return_q_values=True)
+ if isinstance(q_values, tuple):
+ action, q_vals = q_values
+ q_values = q_vals.tolist() if hasattr(q_vals, 'tolist') else [0, 0, 0]
+ else:
+ action = q_values
+ q_values = [0.33, 0.33, 0.34] # Default uniform distribution
+
+ confidence = max(q_values) / sum(q_values) if sum(q_values) > 0 else 0.33
+
+ else:
+ # Fallback to technical analysis-based prediction
+ action, q_values, confidence = self._technical_analysis_prediction(symbol)
+
+ # Create forward-looking prediction
+ prediction_time = datetime.now()
+ target_time = prediction_time + timedelta(minutes=5) # Predict 5 minutes ahead
+
+ prediction = {
+ 'id': f"dqn_{symbol}_{int(current_time)}",
+ 'type': 'DQN',
+ 'symbol': symbol,
+ 'prediction_time': prediction_time,
+ 'target_time': target_time,
+ 'current_price': current_price,
+ 'predicted_action': action,
+ 'q_values': q_values,
+ 'confidence': confidence,
+ 'state': current_state.tolist() if hasattr(current_state, 'tolist') else current_state,
+ 'validated': False
+ }
+
+ # Add to pending predictions for future validation
+ if symbol in self.pending_predictions:
+ self.pending_predictions[symbol].append(prediction)
+
+ # Add to recent predictions for display (only if confident enough)
+ if confidence > 0.4:
+ display_prediction = {
+ 'timestamp': prediction_time,
+ 'price': current_price,
+ 'action': action,
+ 'confidence': confidence,
+ 'q_values': q_values
+ }
+ if symbol in self.recent_dqn_predictions:
+ self.recent_dqn_predictions[symbol].append(display_prediction)
+
+ self.last_prediction_time[symbol] = current_time
+
+ logger.info(f"Forward DQN prediction: {symbol} action={['BUY','SELL','HOLD'][action]} confidence={confidence:.2f} target={target_time.strftime('%H:%M:%S')}")
+
+ except Exception as e:
+ logger.error(f"Error generating forward DQN prediction: {e}")
+
+ def _generate_forward_cnn_prediction(self, symbol: str, current_time: float):
+ """Generate a CNN prediction for future price direction"""
+ try:
+ # Get current price and historical sequence (only past data)
+ current_price = self._get_current_price_from_data(symbol)
+ price_sequence = self._get_historical_price_sequence(symbol, periods=15)
+
+ if current_price is None or len(price_sequence) < 15:
+ return
+
+ # Use CNN model to predict direction (if available)
+ if (self.orchestrator and hasattr(self.orchestrator, 'cnn_model')
+ and self.orchestrator.cnn_model):
+
+ # Prepare features for CNN
+ features = self._prepare_cnn_features(price_sequence)
+
+ try:
+ # Get prediction from CNN model
+ prediction_output = self.orchestrator.cnn_model.predict(features)
+ if hasattr(prediction_output, 'tolist'):
+ pred_probs = prediction_output.tolist()
+ else:
+ pred_probs = [0.33, 0.33, 0.34] # Default
+
+ direction = int(np.argmax(pred_probs)) # 0=DOWN, 1=SAME, 2=UP
+ confidence = max(pred_probs)
+
+ except Exception as e:
+ logger.debug(f"CNN model prediction failed: {e}")
+ direction, confidence = self._technical_direction_prediction(symbol)
+
+ else:
+ # Fallback to technical analysis
+ direction, confidence = self._technical_direction_prediction(symbol)
+
+ # Calculate predicted price based on direction
+ price_change_percent = self._estimate_price_change(direction, confidence)
+ predicted_price = current_price * (1 + price_change_percent)
+
+ # Create forward-looking prediction
+ prediction_time = datetime.now()
+ target_time = prediction_time + timedelta(minutes=10) # Predict 10 minutes ahead
+
+ prediction = {
+ 'id': f"cnn_{symbol}_{int(current_time)}",
+ 'type': 'CNN',
+ 'symbol': symbol,
+ 'prediction_time': prediction_time,
+ 'target_time': target_time,
+ 'current_price': current_price,
+ 'predicted_price': predicted_price,
+ 'direction': direction,
+ 'confidence': confidence,
+ 'features': features.tolist() if hasattr(features, 'tolist') else None,
+ 'validated': False
+ }
+
+ # Add to pending predictions for future validation
+ if symbol in self.pending_predictions:
+ self.pending_predictions[symbol].append(prediction)
+
+ # Add to recent predictions for display (only if confident enough)
+ if confidence > 0.5:
+ display_prediction = {
+ 'timestamp': prediction_time,
+ 'current_price': current_price,
+ 'predicted_price': predicted_price,
+ 'direction': direction,
+ 'confidence': confidence
+ }
+ if symbol in self.recent_cnn_predictions:
+ self.recent_cnn_predictions[symbol].append(display_prediction)
+
+ logger.info(f"Forward CNN prediction: {symbol} direction={['DOWN','SAME','UP'][direction]} confidence={confidence:.2f} target={target_time.strftime('%H:%M:%S')}")
+
+ except Exception as e:
+ logger.error(f"Error generating forward CNN prediction: {e}")
+
+ def _validate_pending_predictions(self, symbol: str, current_time: float):
+ """Validate pending predictions when their target time arrives"""
+ try:
+ if symbol not in self.pending_predictions:
+ return
+
+ current_datetime = datetime.now()
+ validated_predictions = []
+
+ # Check each pending prediction
+ for prediction in list(self.pending_predictions[symbol]):
+ target_time = prediction['target_time']
+
+ # If target time has passed, validate the prediction
+ if current_datetime >= target_time:
+ actual_price = self._get_current_price_from_data(symbol)
+
+ if actual_price is not None:
+ # Calculate actual price change
+ predicted_price = prediction.get('predicted_price', prediction['current_price'])
+ actual_change = (actual_price - prediction['current_price']) / prediction['current_price']
+ predicted_change = (predicted_price - prediction['current_price']) / prediction['current_price']
+
+ # Validate based on prediction type
+ if prediction['type'] == 'DQN':
+ was_correct = self._validate_dqn_prediction(prediction, actual_change)
+ else: # CNN
+ was_correct = self._validate_cnn_prediction(prediction, actual_change)
+
+ # Store accuracy result
+ accuracy_data = {
+ 'timestamp': current_datetime,
+ 'symbol': symbol,
+ 'prediction_type': prediction['type'],
+ 'correct': was_correct,
+ 'accuracy_score': prediction['confidence'] if was_correct else (1.0 - prediction['confidence']),
+ 'confidence': prediction['confidence'],
+ 'actual_price_change': actual_change,
+ 'predicted_action': prediction.get('predicted_action', prediction.get('direction', 0)),
+ 'actual_price': actual_price
+ }
+
+ if symbol in self.prediction_accuracy_history:
+ self.prediction_accuracy_history[symbol].append(accuracy_data)
+
+ validated_predictions.append(prediction['id'])
+
+ logger.info(f"Validated {prediction['type']} prediction: {symbol} correct={was_correct} confidence={prediction['confidence']:.2f}")
+
+ # Remove validated predictions from pending list
+ if validated_predictions:
+ self.pending_predictions[symbol] = deque([
+ p for p in self.pending_predictions[symbol]
+ if p['id'] not in validated_predictions
+ ], maxlen=100)
+
+ except Exception as e:
+ logger.error(f"Error validating pending predictions: {e}")
+
+ def _validate_dqn_prediction(self, prediction: Dict, actual_change: float) -> bool:
+ """Validate DQN action prediction"""
+ predicted_action = prediction['predicted_action']
+ threshold = 0.005 # 0.5% threshold for significant movement
+
+ if predicted_action == 0: # BUY prediction
+ return actual_change > threshold
+ elif predicted_action == 1: # SELL prediction
+ return actual_change < -threshold
+ else: # HOLD prediction
+ return abs(actual_change) <= threshold
+
+ def _validate_cnn_prediction(self, prediction: Dict, actual_change: float) -> bool:
+ """Validate CNN direction prediction"""
+ predicted_direction = prediction['direction']
+ threshold = 0.002 # 0.2% threshold for direction
+
+ if predicted_direction == 2: # UP prediction
+ return actual_change > threshold
+ elif predicted_direction == 0: # DOWN prediction
+ return actual_change < -threshold
+ else: # SAME prediction
+ return abs(actual_change) <= threshold
+
+ def _get_current_price_from_data(self, symbol: str) -> Optional[float]:
+ """Get current price from real-time data streams"""
+ try:
+ if len(self.real_time_data['ohlcv_1m']) > 0:
+ return self.real_time_data['ohlcv_1m'][-1]['close']
+ return None
+ except Exception as e:
+ logger.debug(f"Error getting current price: {e}")
+ return None
+
+ def _get_historical_price_sequence(self, symbol: str, periods: int = 15) -> List[float]:
+ """Get historical price sequence for CNN features"""
+ try:
+ if len(self.real_time_data['ohlcv_1m']) >= periods:
+ return [bar['close'] for bar in list(self.real_time_data['ohlcv_1m'])[-periods:]]
+ return []
+ except Exception as e:
+ logger.debug(f"Error getting price sequence: {e}")
+ return []
+
+ def _technical_analysis_prediction(self, symbol: str) -> Tuple[int, List[float], float]:
+ """Fallback technical analysis prediction for DQN"""
+ try:
+ # Simple momentum-based prediction
+ if len(self.real_time_data['ohlcv_1m']) >= 5:
+ recent_prices = [bar['close'] for bar in list(self.real_time_data['ohlcv_1m'])[-5:]]
+ momentum = (recent_prices[-1] - recent_prices[0]) / recent_prices[0]
+
+ if momentum > 0.01: # 1% upward momentum
+ return 0, [0.6, 0.2, 0.2], 0.6 # BUY
+ elif momentum < -0.01: # 1% downward momentum
+ return 1, [0.2, 0.6, 0.2], 0.6 # SELL
+ else:
+ return 2, [0.2, 0.2, 0.6], 0.6 # HOLD
+
+ return 2, [0.33, 0.33, 0.34], 0.33 # Default HOLD
+
+ except Exception as e:
+ logger.debug(f"Error in technical analysis prediction: {e}")
+ return 2, [0.33, 0.33, 0.34], 0.33
+
+ def _technical_direction_prediction(self, symbol: str) -> Tuple[int, float]:
+ """Fallback technical analysis for CNN direction"""
+ try:
+ if len(self.real_time_data['ohlcv_1m']) >= 3:
+ recent_prices = [bar['close'] for bar in list(self.real_time_data['ohlcv_1m'])[-3:]]
+ short_momentum = (recent_prices[-1] - recent_prices[-2]) / recent_prices[-2]
+
+ if short_momentum > 0.005: # 0.5% short-term up
+ return 2, 0.65 # UP
+ elif short_momentum < -0.005: # 0.5% short-term down
+ return 0, 0.65 # DOWN
+ else:
+ return 1, 0.55 # SAME
+
+ return 1, 0.5 # Default SAME
+
+ except Exception as e:
+ logger.debug(f"Error in technical direction prediction: {e}")
+ return 1, 0.5
+
+ def _prepare_cnn_features(self, price_sequence: List[float]) -> np.ndarray:
+ """Prepare features for CNN model"""
+ try:
+ # Normalize prices relative to first price
+ if len(price_sequence) >= 15:
+ base_price = price_sequence[0]
+ normalized = [(p - base_price) / base_price for p in price_sequence]
+
+ # Create feature matrix (15 x 20, flattened)
+ features = np.zeros((15, 20))
+ for i, norm_price in enumerate(normalized):
+ features[i, 0] = norm_price # Normalized price
+ if i > 0:
+ features[i, 1] = normalized[i] - normalized[i-1] # Price change
+
+ return features.flatten()
+
+ return np.zeros(300) # Default feature vector
+
+ except Exception as e:
+ logger.debug(f"Error preparing CNN features: {e}")
+ return np.zeros(300)
+
+ def _estimate_price_change(self, direction: int, confidence: float) -> float:
+ """Estimate price change percentage based on direction and confidence"""
+ try:
+ # Base change scaled by confidence
+ base_change = 0.01 * confidence # Up to 1% change
+
+ if direction == 2: # UP
+ return base_change
+ elif direction == 0: # DOWN
+ return -base_change
+ else: # SAME
+ return 0.0
+
+ except Exception as e:
+ logger.debug(f"Error estimating price change: {e}")
+ return 0.0
\ No newline at end of file
diff --git a/test_model_predictions_visualization.py b/test_model_predictions_visualization.py
new file mode 100644
index 0000000..40a9a3e
--- /dev/null
+++ b/test_model_predictions_visualization.py
@@ -0,0 +1,309 @@
+#!/usr/bin/env python3
+"""
+Test Model Predictions Visualization
+
+This script demonstrates the enhanced model prediction visualization system
+that shows DQN actions, CNN price predictions, and accuracy feedback on the price chart.
+
+Features tested:
+- DQN action predictions (BUY/SELL/HOLD) as directional arrows with confidence-based sizing
+- CNN price direction predictions as trend lines with target markers
+- Prediction accuracy feedback with color-coded results
+- Real-time prediction tracking and storage
+- Mock prediction generation for demonstration
+"""
+
+import asyncio
+import logging
+import time
+import numpy as np
+from datetime import datetime, timedelta
+from typing import Dict, List, Optional
+
+from core.config import get_config
+from core.data_provider import DataProvider
+from core.orchestrator import TradingOrchestrator
+from core.trading_executor import TradingExecutor
+from web.clean_dashboard import create_clean_dashboard
+from enhanced_realtime_training import EnhancedRealtimeTrainingSystem
+
+# Setup logging
+logging.basicConfig(
+ level=logging.INFO,
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
+
+class ModelPredictionTester:
+ """Test model prediction visualization capabilities"""
+
+ def __init__(self):
+ self.config = get_config()
+ self.data_provider = DataProvider()
+ self.trading_executor = TradingExecutor()
+ self.orchestrator = TradingOrchestrator(
+ data_provider=self.data_provider,
+ enhanced_rl_training=True,
+ model_registry={}
+ )
+
+ # Initialize enhanced training system
+ self.training_system = EnhancedRealtimeTrainingSystem(
+ orchestrator=self.orchestrator,
+ data_provider=self.data_provider,
+ dashboard=None # Will be set after dashboard creation
+ )
+
+ # Create dashboard with enhanced prediction visualization
+ self.dashboard = create_clean_dashboard(
+ data_provider=self.data_provider,
+ orchestrator=self.orchestrator,
+ trading_executor=self.trading_executor
+ )
+
+ # Connect training system to dashboard
+ self.training_system.dashboard = self.dashboard
+ self.dashboard.training_system = self.training_system
+
+ # Test data
+ self.test_symbols = ['ETH/USDT', 'BTC/USDT']
+ self.prediction_count = 0
+
+ logger.info("Model Prediction Tester initialized")
+
+ def generate_mock_dqn_predictions(self, symbol: str, count: int = 10):
+ """Generate mock DQN predictions for testing"""
+ try:
+ current_price = self.data_provider.get_current_price(symbol) or 2400.0
+
+ for i in range(count):
+ # Generate realistic state vector
+ state = np.random.random(100) # 100-dimensional state
+
+ # Generate Q-values with some logic
+ q_values = [np.random.random(), np.random.random(), np.random.random()]
+ action = np.argmax(q_values) # Best action
+ confidence = max(q_values) / sum(q_values) # Confidence based on Q-value distribution
+
+ # Add some price variation
+ pred_price = current_price + np.random.normal(0, 20)
+
+ # Capture prediction
+ self.training_system.capture_dqn_prediction(
+ symbol=symbol,
+ state=state,
+ q_values=q_values,
+ action=action,
+ confidence=confidence,
+ price=pred_price
+ )
+
+ self.prediction_count += 1
+
+ logger.info(f"Generated DQN prediction {i+1}/{count}: {symbol} action={['BUY', 'SELL', 'HOLD'][action]} confidence={confidence:.2f}")
+
+ # Small delay between predictions
+ time.sleep(0.1)
+
+ except Exception as e:
+ logger.error(f"Error generating DQN predictions: {e}")
+
+ def generate_mock_cnn_predictions(self, symbol: str, count: int = 8):
+ """Generate mock CNN predictions for testing"""
+ try:
+ current_price = self.data_provider.get_current_price(symbol) or 2400.0
+
+ for i in range(count):
+ # Generate direction with some logic
+ direction = np.random.choice([0, 1, 2], p=[0.3, 0.2, 0.5]) # Slightly bullish
+ confidence = 0.4 + np.random.random() * 0.5 # 0.4-0.9 confidence
+
+ # Calculate predicted price based on direction
+ if direction == 2: # UP
+ price_change = np.random.uniform(5, 50)
+ elif direction == 0: # DOWN
+ price_change = -np.random.uniform(5, 50)
+ else: # SAME
+ price_change = np.random.uniform(-5, 5)
+
+ predicted_price = current_price + price_change
+
+ # Generate features
+ features = np.random.random((15, 20)).flatten() # Flattened CNN features
+
+ # Capture prediction
+ self.training_system.capture_cnn_prediction(
+ symbol=symbol,
+ current_price=current_price,
+ predicted_price=predicted_price,
+ direction=direction,
+ confidence=confidence,
+ features=features
+ )
+
+ self.prediction_count += 1
+
+ logger.info(f"Generated CNN prediction {i+1}/{count}: {symbol} direction={['DOWN', 'SAME', 'UP'][direction]} confidence={confidence:.2f}")
+
+ # Small delay between predictions
+ time.sleep(0.2)
+
+ except Exception as e:
+ logger.error(f"Error generating CNN predictions: {e}")
+
+ def generate_mock_accuracy_data(self, symbol: str, count: int = 15):
+ """Generate mock prediction accuracy data for testing"""
+ try:
+ current_price = self.data_provider.get_current_price(symbol) or 2400.0
+
+ for i in range(count):
+ # Randomly choose prediction type
+ prediction_type = np.random.choice(['DQN', 'CNN'])
+ predicted_action = np.random.choice([0, 1, 2])
+ confidence = 0.3 + np.random.random() * 0.6
+
+ # Generate realistic price change
+ actual_price_change = np.random.normal(0, 0.01) # ±1% typical change
+
+ # Validate accuracy
+ self.training_system.validate_prediction_accuracy(
+ symbol=symbol,
+ prediction_type=prediction_type,
+ predicted_action=predicted_action,
+ actual_price_change=actual_price_change,
+ confidence=confidence
+ )
+
+ logger.info(f"Generated accuracy data {i+1}/{count}: {symbol} {prediction_type} action={predicted_action}")
+
+ # Small delay
+ time.sleep(0.1)
+
+ except Exception as e:
+ logger.error(f"Error generating accuracy data: {e}")
+
+ def run_prediction_generation_test(self):
+ """Run comprehensive prediction generation test"""
+ try:
+ logger.info("Starting Model Prediction Visualization Test")
+ logger.info("=" * 60)
+
+ # Test for each symbol
+ for symbol in self.test_symbols:
+ logger.info(f"\nGenerating predictions for {symbol}...")
+
+ # Generate DQN predictions
+ logger.info(f"Generating DQN predictions for {symbol}...")
+ self.generate_mock_dqn_predictions(symbol, count=12)
+
+ # Generate CNN predictions
+ logger.info(f"Generating CNN predictions for {symbol}...")
+ self.generate_mock_cnn_predictions(symbol, count=8)
+
+ # Generate accuracy data
+ logger.info(f"Generating accuracy data for {symbol}...")
+ self.generate_mock_accuracy_data(symbol, count=20)
+
+ # Get prediction summary
+ summary = self.training_system.get_prediction_summary(symbol)
+ logger.info(f"Prediction summary for {symbol}: {summary}")
+
+ # Log total statistics
+ training_stats = self.training_system.get_training_statistics()
+ logger.info("\nTraining System Statistics:")
+ logger.info(f"Total predictions generated: {self.prediction_count}")
+ logger.info(f"Prediction stats: {training_stats.get('prediction_stats', {})}")
+
+ logger.info("\n" + "=" * 60)
+ logger.info("Prediction generation test completed successfully!")
+ logger.info("Dashboard should now show enhanced model predictions on the price chart:")
+ logger.info("- Green/Red arrows for DQN BUY/SELL predictions")
+ logger.info("- Gray circles for DQN HOLD predictions")
+ logger.info("- Colored trend lines for CNN price direction predictions")
+ logger.info("- Diamond markers for CNN prediction targets")
+ logger.info("- Green/Red X marks for correct/incorrect prediction feedback")
+ logger.info("- Hover tooltips showing confidence, Q-values, and accuracy scores")
+
+ except Exception as e:
+ logger.error(f"Error in prediction generation test: {e}")
+
+ def start_dashboard_with_predictions(self, host='127.0.0.1', port=8051):
+ """Start dashboard with enhanced prediction visualization"""
+ try:
+ logger.info(f"Starting dashboard with model predictions at http://{host}:{port}")
+
+ # Run prediction generation in background
+ import threading
+ pred_thread = threading.Thread(target=self.run_prediction_generation_test, daemon=True)
+ pred_thread.start()
+
+ # Start training system
+ self.training_system.start_training()
+
+ # Start dashboard
+ self.dashboard.run_server(host=host, port=port, debug=False)
+
+ except Exception as e:
+ logger.error(f"Error starting dashboard with predictions: {e}")
+
+ def test_prediction_accuracy_validation(self):
+ """Test prediction accuracy validation logic"""
+ try:
+ logger.info("Testing prediction accuracy validation...")
+
+ # Test DQN accuracy validation
+ test_cases = [
+ ('DQN', 0, 0.01, 0.8, True), # BUY + price up = correct
+ ('DQN', 1, -0.01, 0.7, True), # SELL + price down = correct
+ ('DQN', 2, 0.0005, 0.6, True), # HOLD + no change = correct
+ ('DQN', 0, -0.01, 0.8, False), # BUY + price down = incorrect
+ ('CNN', 2, 0.01, 0.9, True), # UP + price up = correct
+ ('CNN', 0, -0.01, 0.8, True), # DOWN + price down = correct
+ ('CNN', 1, 0.0005, 0.7, True), # SAME + no change = correct
+ ('CNN', 2, -0.01, 0.9, False), # UP + price down = incorrect
+ ]
+
+ for prediction_type, action, price_change, confidence, expected_correct in test_cases:
+ self.training_system.validate_prediction_accuracy(
+ symbol='ETH/USDT',
+ prediction_type=prediction_type,
+ predicted_action=action,
+ actual_price_change=price_change,
+ confidence=confidence
+ )
+
+ # Check if validation worked correctly
+ if self.training_system.prediction_accuracy_history['ETH/USDT']:
+ latest = list(self.training_system.prediction_accuracy_history['ETH/USDT'])[-1]
+ actual_correct = latest['correct']
+
+ status = "✓" if actual_correct == expected_correct else "✗"
+ logger.info(f"{status} {prediction_type} action={action} change={price_change:.4f} -> correct={actual_correct}")
+
+ logger.info("Prediction accuracy validation test completed")
+
+ except Exception as e:
+ logger.error(f"Error testing prediction accuracy validation: {e}")
+
+def main():
+ """Main test function"""
+ try:
+ # Create tester
+ tester = ModelPredictionTester()
+
+ # Run accuracy validation test first
+ tester.test_prediction_accuracy_validation()
+
+ # Start dashboard with enhanced predictions
+ logger.info("\nStarting dashboard with enhanced model prediction visualization...")
+ logger.info("Visit http://127.0.0.1:8051 to see the enhanced price chart with model predictions")
+
+ tester.start_dashboard_with_predictions()
+
+ except KeyboardInterrupt:
+ logger.info("Test interrupted by user")
+ except Exception as e:
+ logger.error(f"Error in main test: {e}")
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/web/clean_dashboard.py b/web/clean_dashboard.py
index 2c6da58..808e81c 100644
--- a/web/clean_dashboard.py
+++ b/web/clean_dashboard.py
@@ -106,6 +106,10 @@ class CleanTradingDashboard:
else:
self.orchestrator = orchestrator
+ # Initialize enhanced training system for predictions
+ self.training_system = None
+ self._initialize_enhanced_training_system()
+
# Initialize layout and component managers
self.layout_manager = DashboardLayoutManager(
starting_balance=self._get_initial_balance(),
@@ -711,9 +715,9 @@ class CleanTradingDashboard:
x=0.5, y=0.5, showarrow=False)
def _add_model_predictions_to_chart(self, fig: go.Figure, symbol: str, df_main: pd.DataFrame, row: int = 1):
- """Add model predictions to the chart - ONLY EXECUTED TRADES on main chart"""
+ """Add enhanced model predictions to the chart with real-time feedback"""
try:
- # Only show EXECUTED TRADES on the main 1m chart
+ # 1. Add executed trades (existing functionality)
executed_signals = [signal for signal in self.recent_decisions if self._get_signal_attribute(signal, 'executed', False)]
if executed_signals:
@@ -721,8 +725,7 @@ class CleanTradingDashboard:
buy_trades = []
sell_trades = []
- for signal in executed_signals[-50:]: # Last 50 executed trades (increased from 20)
- # Try to get full timestamp first, fall back to string timestamp
+ for signal in executed_signals[-50:]: # Last 50 executed trades
signal_time = self._get_signal_attribute(signal, 'full_timestamp')
if not signal_time:
signal_time = self._get_signal_attribute(signal, 'timestamp')
@@ -732,10 +735,9 @@ class CleanTradingDashboard:
signal_confidence = self._get_signal_attribute(signal, 'confidence', 0)
if signal_time and signal_price and signal_confidence > 0:
- # FIXED: Better timestamp conversion to prevent race conditions
+ # Enhanced timestamp handling
if isinstance(signal_time, str):
try:
- # Handle time-only format with current date
if ':' in signal_time and len(signal_time.split(':')) == 3:
now = datetime.now()
time_parts = signal_time.split(':')
@@ -745,7 +747,6 @@ class CleanTradingDashboard:
second=int(time_parts[2]),
microsecond=0
)
- # Handle day boundary issues - if signal seems from future, subtract a day
if signal_time > now + timedelta(minutes=5):
signal_time -= timedelta(days=1)
else:
@@ -754,7 +755,6 @@ class CleanTradingDashboard:
logger.debug(f"Error parsing timestamp {signal_time}: {e}")
continue
elif not isinstance(signal_time, datetime):
- # Convert other timestamp formats to datetime
try:
signal_time = pd.to_datetime(signal_time)
except Exception as e:
@@ -766,7 +766,7 @@ class CleanTradingDashboard:
elif signal_action == 'SELL':
sell_trades.append({'x': signal_time, 'y': signal_price, 'confidence': signal_confidence})
- # Add EXECUTED BUY trades (large green circles)
+ # Add executed trades with enhanced visualization
if buy_trades:
fig.add_trace(
go.Scatter(
@@ -790,7 +790,6 @@ class CleanTradingDashboard:
row=row, col=1
)
- # Add EXECUTED SELL trades (large red circles)
if sell_trades:
fig.add_trace(
go.Scatter(
@@ -813,9 +812,363 @@ class CleanTradingDashboard:
),
row=row, col=1
)
+
+ # 2. NEW: Add real-time model predictions overlay
+ self._add_dqn_predictions_to_chart(fig, symbol, df_main, row)
+ self._add_cnn_predictions_to_chart(fig, symbol, df_main, row)
+ self._add_prediction_accuracy_feedback(fig, symbol, df_main, row)
except Exception as e:
- logger.warning(f"Error adding executed trades to main chart: {e}")
+ logger.warning(f"Error adding model predictions to chart: {e}")
+
+ def _add_dqn_predictions_to_chart(self, fig: go.Figure, symbol: str, df_main: pd.DataFrame, row: int = 1):
+ """Add DQN action predictions as directional arrows"""
+ try:
+ # Get recent DQN predictions from orchestrator
+ dqn_predictions = self._get_recent_dqn_predictions(symbol)
+
+ if not dqn_predictions:
+ return
+
+ # Separate predictions by action
+ buy_predictions = []
+ sell_predictions = []
+ hold_predictions = []
+
+ for pred in dqn_predictions[-30:]: # Last 30 DQN predictions
+ action = pred.get('action', 2) # 0=BUY, 1=SELL, 2=HOLD
+ confidence = pred.get('confidence', 0)
+ timestamp = pred.get('timestamp', datetime.now())
+ price = pred.get('price', 0)
+
+ if confidence > 0.3: # Only show predictions with reasonable confidence
+ pred_data = {
+ 'x': timestamp,
+ 'y': price,
+ 'confidence': confidence,
+ 'q_values': pred.get('q_values', [0, 0, 0])
+ }
+
+ if action == 0: # BUY
+ buy_predictions.append(pred_data)
+ elif action == 1: # SELL
+ sell_predictions.append(pred_data)
+ else: # HOLD
+ hold_predictions.append(pred_data)
+
+ # Add DQN BUY predictions (green arrows pointing up)
+ if buy_predictions:
+ fig.add_trace(
+ go.Scatter(
+ x=[p['x'] for p in buy_predictions],
+ y=[p['y'] for p in buy_predictions],
+ mode='markers',
+ marker=dict(
+ symbol='triangle-up',
+ size=[8 + p['confidence'] * 12 for p in buy_predictions], # Size based on confidence
+ color=[f'rgba(0, 200, 0, {0.3 + p["confidence"] * 0.7})' for p in buy_predictions], # Opacity based on confidence
+ line=dict(width=1, color='darkgreen')
+ ),
+ name='DQN BUY Prediction',
+ showlegend=True,
+ hovertemplate="DQN BUY PREDICTION
" +
+ "Price: $%{y:.2f}
" +
+ "Time: %{x}
" +
+ "Confidence: %{customdata[0]:.1%}
" +
+ "Q-Values: [%{customdata[1]:.3f}, %{customdata[2]:.3f}, %{customdata[3]:.3f}]",
+ customdata=[[p['confidence']] + p['q_values'] for p in buy_predictions]
+ ),
+ row=row, col=1
+ )
+
+ # Add DQN SELL predictions (red arrows pointing down)
+ if sell_predictions:
+ fig.add_trace(
+ go.Scatter(
+ x=[p['x'] for p in sell_predictions],
+ y=[p['y'] for p in sell_predictions],
+ mode='markers',
+ marker=dict(
+ symbol='triangle-down',
+ size=[8 + p['confidence'] * 12 for p in sell_predictions],
+ color=[f'rgba(200, 0, 0, {0.3 + p["confidence"] * 0.7})' for p in sell_predictions],
+ line=dict(width=1, color='darkred')
+ ),
+ name='DQN SELL Prediction',
+ showlegend=True,
+ hovertemplate="DQN SELL PREDICTION
" +
+ "Price: $%{y:.2f}
" +
+ "Time: %{x}
" +
+ "Confidence: %{customdata[0]:.1%}
" +
+ "Q-Values: [%{customdata[1]:.3f}, %{customdata[2]:.3f}, %{customdata[3]:.3f}]",
+ customdata=[[p['confidence']] + p['q_values'] for p in sell_predictions]
+ ),
+ row=row, col=1
+ )
+
+ # Add DQN HOLD predictions (small gray circles)
+ if hold_predictions:
+ fig.add_trace(
+ go.Scatter(
+ x=[p['x'] for p in hold_predictions],
+ y=[p['y'] for p in hold_predictions],
+ mode='markers',
+ marker=dict(
+ symbol='circle',
+ size=[4 + p['confidence'] * 6 for p in hold_predictions],
+ color=[f'rgba(128, 128, 128, {0.2 + p["confidence"] * 0.5})' for p in hold_predictions],
+ line=dict(width=1, color='gray')
+ ),
+ name='DQN HOLD Prediction',
+ showlegend=True,
+ hovertemplate="DQN HOLD PREDICTION
" +
+ "Price: $%{y:.2f}
" +
+ "Time: %{x}
" +
+ "Confidence: %{customdata[0]:.1%}
" +
+ "Q-Values: [%{customdata[1]:.3f}, %{customdata[2]:.3f}, %{customdata[3]:.3f}]",
+ customdata=[[p['confidence']] + p['q_values'] for p in hold_predictions]
+ ),
+ row=row, col=1
+ )
+
+ except Exception as e:
+ logger.debug(f"Error adding DQN predictions to chart: {e}")
+
+ def _add_cnn_predictions_to_chart(self, fig: go.Figure, symbol: str, df_main: pd.DataFrame, row: int = 1):
+ """Add CNN price direction predictions as trend lines"""
+ try:
+ # Get recent CNN predictions from orchestrator
+ cnn_predictions = self._get_recent_cnn_predictions(symbol)
+
+ if not cnn_predictions:
+ return
+
+ # Create trend prediction lines
+ prediction_lines = []
+
+ for i, pred in enumerate(cnn_predictions[-20:]): # Last 20 CNN predictions
+ direction = pred.get('direction', 1) # 0=DOWN, 1=SAME, 2=UP
+ confidence = pred.get('confidence', 0)
+ timestamp = pred.get('timestamp', datetime.now())
+ current_price = pred.get('current_price', 0)
+ predicted_price = pred.get('predicted_price', current_price)
+
+ if confidence > 0.4 and current_price > 0: # Only show confident predictions
+ # Calculate prediction end point (5 minutes ahead)
+ end_time = timestamp + timedelta(minutes=5)
+
+ # Determine color based on direction
+ if direction == 2: # UP
+ color = f'rgba(0, 255, 0, {0.3 + confidence * 0.4})'
+ line_color = 'green'
+ prediction_name = 'CNN UP'
+ elif direction == 0: # DOWN
+ color = f'rgba(255, 0, 0, {0.3 + confidence * 0.4})'
+ line_color = 'red'
+ prediction_name = 'CNN DOWN'
+ else: # SAME
+ color = f'rgba(128, 128, 128, {0.2 + confidence * 0.3})'
+ line_color = 'gray'
+ prediction_name = 'CNN FLAT'
+
+ # Add prediction line
+ fig.add_trace(
+ go.Scatter(
+ x=[timestamp, end_time],
+ y=[current_price, predicted_price],
+ mode='lines',
+ line=dict(
+ color=line_color,
+ width=2 + confidence * 3, # Line width based on confidence
+ dash='dot' if direction == 1 else 'solid'
+ ),
+ name=f'{prediction_name} Prediction',
+ showlegend=i == 0, # Only show legend for first instance
+ hovertemplate=f"{prediction_name} PREDICTION
" +
+ "From: $%{y[0]:.2f}
" +
+ "To: $%{y[1]:.2f}
" +
+ "Time: %{x[0]} → %{x[1]}
" +
+ f"Confidence: {confidence:.1%}
" +
+ f"Direction: {['DOWN', 'SAME', 'UP'][direction]}"
+ ),
+ row=row, col=1
+ )
+
+ # Add prediction end point marker
+ fig.add_trace(
+ go.Scatter(
+ x=[end_time],
+ y=[predicted_price],
+ mode='markers',
+ marker=dict(
+ symbol='diamond',
+ size=6 + confidence * 8,
+ color=color,
+ line=dict(width=1, color=line_color)
+ ),
+ name=f'{prediction_name} Target',
+ showlegend=False,
+ hovertemplate=f"{prediction_name} TARGET
" +
+ "Target Price: $%{y:.2f}
" +
+ "Target Time: %{x}
" +
+ f"Confidence: {confidence:.1%}"
+ ),
+ row=row, col=1
+ )
+
+ except Exception as e:
+ logger.debug(f"Error adding CNN predictions to chart: {e}")
+
+ def _add_prediction_accuracy_feedback(self, fig: go.Figure, symbol: str, df_main: pd.DataFrame, row: int = 1):
+ """Add prediction accuracy feedback with color-coded results"""
+ try:
+ # Get prediction accuracy history
+ accuracy_data = self._get_prediction_accuracy_history(symbol)
+
+ if not accuracy_data:
+ return
+
+ # Add accuracy feedback markers
+ correct_predictions = []
+ incorrect_predictions = []
+
+ for acc in accuracy_data[-50:]: # Last 50 accuracy points
+ timestamp = acc.get('timestamp', datetime.now())
+ price = acc.get('actual_price', 0)
+ was_correct = acc.get('correct', False)
+ prediction_type = acc.get('prediction_type', 'unknown')
+ accuracy_score = acc.get('accuracy_score', 0)
+
+ if price > 0:
+ acc_data = {
+ 'x': timestamp,
+ 'y': price,
+ 'type': prediction_type,
+ 'score': accuracy_score
+ }
+
+ if was_correct:
+ correct_predictions.append(acc_data)
+ else:
+ incorrect_predictions.append(acc_data)
+
+ # Add correct prediction markers (green checkmarks)
+ if correct_predictions:
+ fig.add_trace(
+ go.Scatter(
+ x=[p['x'] for p in correct_predictions],
+ y=[p['y'] for p in correct_predictions],
+ mode='markers',
+ marker=dict(
+ symbol='x',
+ size=8,
+ color='rgba(0, 255, 0, 0.8)',
+ line=dict(width=2, color='darkgreen')
+ ),
+ name='Correct Predictions',
+ showlegend=True,
+ hovertemplate="CORRECT PREDICTION
" +
+ "Price: $%{y:.2f}
" +
+ "Time: %{x}
" +
+ "Type: %{customdata[0]}
" +
+ "Accuracy: %{customdata[1]:.1%}",
+ customdata=[[p['type'], p['score']] for p in correct_predictions]
+ ),
+ row=row, col=1
+ )
+
+ # Add incorrect prediction markers (red X marks)
+ if incorrect_predictions:
+ fig.add_trace(
+ go.Scatter(
+ x=[p['x'] for p in incorrect_predictions],
+ y=[p['y'] for p in incorrect_predictions],
+ mode='markers',
+ marker=dict(
+ symbol='x',
+ size=8,
+ color='rgba(255, 0, 0, 0.8)',
+ line=dict(width=2, color='darkred')
+ ),
+ name='Incorrect Predictions',
+ showlegend=True,
+ hovertemplate="INCORRECT PREDICTION
" +
+ "Price: $%{y:.2f}
" +
+ "Time: %{x}
" +
+ "Type: %{customdata[0]}
" +
+ "Accuracy: %{customdata[1]:.1%}",
+ customdata=[[p['type'], p['score']] for p in incorrect_predictions]
+ ),
+ row=row, col=1
+ )
+
+ except Exception as e:
+ logger.debug(f"Error adding prediction accuracy feedback to chart: {e}")
+
+ def _get_recent_dqn_predictions(self, symbol: str) -> List[Dict]:
+ """Get recent DQN predictions from enhanced training system (forward-looking only)"""
+ try:
+ predictions = []
+
+ # Get REAL forward-looking predictions from enhanced training system
+ if hasattr(self, 'training_system') and self.training_system:
+ if hasattr(self.training_system, 'recent_dqn_predictions'):
+ predictions.extend(self.training_system.recent_dqn_predictions.get(symbol, []))
+
+ # Get from orchestrator as fallback
+ if hasattr(self.orchestrator, 'recent_dqn_predictions'):
+ predictions.extend(self.orchestrator.recent_dqn_predictions.get(symbol, []))
+
+ # REMOVED: Mock prediction generation - now using REAL predictions only
+ # No more artificial past predictions or random data
+
+ return sorted(predictions, key=lambda x: x.get('timestamp', datetime.now()))
+
+ except Exception as e:
+ logger.debug(f"Error getting DQN predictions: {e}")
+ return []
+
+ def _get_recent_cnn_predictions(self, symbol: str) -> List[Dict]:
+ """Get recent CNN predictions from enhanced training system (forward-looking only)"""
+ try:
+ predictions = []
+
+ # Get REAL forward-looking predictions from enhanced training system
+ if hasattr(self, 'training_system') and self.training_system:
+ if hasattr(self.training_system, 'recent_cnn_predictions'):
+ predictions.extend(self.training_system.recent_cnn_predictions.get(symbol, []))
+
+ # Get from orchestrator as fallback
+ if hasattr(self.orchestrator, 'recent_cnn_predictions'):
+ predictions.extend(self.orchestrator.recent_cnn_predictions.get(symbol, []))
+
+ # REMOVED: Mock prediction generation - now using REAL predictions only
+ # No more artificial past predictions or random data
+
+ return sorted(predictions, key=lambda x: x.get('timestamp', datetime.now()))
+
+ except Exception as e:
+ logger.debug(f"Error getting CNN predictions: {e}")
+ return []
+
+ def _get_prediction_accuracy_history(self, symbol: str) -> List[Dict]:
+ """Get REAL prediction accuracy history from validated forward-looking predictions"""
+ try:
+ accuracy_data = []
+
+ # Get REAL accuracy data from training system validation
+ if hasattr(self, 'training_system') and self.training_system:
+ if hasattr(self.training_system, 'prediction_accuracy_history'):
+ accuracy_data.extend(self.training_system.prediction_accuracy_history.get(symbol, []))
+
+ # REMOVED: Mock accuracy data generation - now using REAL validation results only
+ # Accuracy is now based on actual prediction outcomes, not random data
+
+ return sorted(accuracy_data, key=lambda x: x.get('timestamp', datetime.now()))
+
+ except Exception as e:
+ logger.debug(f"Error getting prediction accuracy history: {e}")
+ return []
def _add_signals_to_mini_chart(self, fig: go.Figure, symbol: str, ws_data_1s: pd.DataFrame, row: int = 2):
"""Add ALL signals (executed and non-executed) to the 1s mini chart"""
@@ -2566,6 +2919,33 @@ class CleanTradingDashboard:
except Exception as e:
logger.warning(f"Error clearing old signals: {e}")
+ def _initialize_enhanced_training_system(self):
+ """Initialize enhanced training system for model predictions"""
+ try:
+ # Try to import and initialize enhanced training system
+ from enhanced_realtime_training import EnhancedRealtimeTrainingSystem
+
+ self.training_system = EnhancedRealtimeTrainingSystem(
+ orchestrator=self.orchestrator,
+ data_provider=self.data_provider,
+ dashboard=self
+ )
+
+ # Initialize prediction storage
+ if not hasattr(self.orchestrator, 'recent_dqn_predictions'):
+ self.orchestrator.recent_dqn_predictions = {}
+ if not hasattr(self.orchestrator, 'recent_cnn_predictions'):
+ self.orchestrator.recent_cnn_predictions = {}
+
+ logger.info("Enhanced training system initialized for model predictions")
+
+ except ImportError:
+ logger.warning("Enhanced training system not available - using mock predictions")
+ self.training_system = None
+ except Exception as e:
+ logger.error(f"Error initializing enhanced training system: {e}")
+ self.training_system = None
+
def _initialize_cob_integration(self):
"""Initialize COB integration with high-frequency data handling"""
try: