From 97ea27ea84f186234b6c1d5712f62ff59891a486 Mon Sep 17 00:00:00 2001 From: Dobromir Popov Date: Fri, 27 Jun 2025 01:12:55 +0300 Subject: [PATCH] display predictions --- enhanced_realtime_training.py | 534 +++++++++++++++++++++++- test_model_predictions_visualization.py | 309 ++++++++++++++ web/clean_dashboard.py | 402 +++++++++++++++++- 3 files changed, 1229 insertions(+), 16 deletions(-) create mode 100644 test_model_predictions_visualization.py diff --git a/enhanced_realtime_training.py b/enhanced_realtime_training.py index 88dd23d..5cf281c 100644 --- a/enhanced_realtime_training.py +++ b/enhanced_realtime_training.py @@ -71,6 +71,34 @@ class EnhancedRealtimeTrainingSystem: 'validation': 0.0 } + # Model prediction tracking - NEW for dashboard visualization + self.recent_dqn_predictions = { + 'ETH/USDT': deque(maxlen=100), + 'BTC/USDT': deque(maxlen=100) + } + self.recent_cnn_predictions = { + 'ETH/USDT': deque(maxlen=50), + 'BTC/USDT': deque(maxlen=50) + } + self.prediction_accuracy_history = { + 'ETH/USDT': deque(maxlen=200), + 'BTC/USDT': deque(maxlen=200) + } + + # FIXED: Forward-looking prediction system + self.pending_predictions = { + 'ETH/USDT': deque(maxlen=100), # Predictions waiting for validation + 'BTC/USDT': deque(maxlen=100) + } + self.last_prediction_time = { + 'ETH/USDT': 0, + 'BTC/USDT': 0 + } + self.prediction_intervals = { + 'dqn': 30, # Make DQN prediction every 30 seconds + 'cnn': 60 # Make CNN prediction every 60 seconds + } + # Real-time data streams self.real_time_data = { 'ticks': deque(maxlen=1000), @@ -146,24 +174,27 @@ class EnhancedRealtimeTrainingSystem: current_time = time.time() self.training_iteration += 1 - # 1. DQN Training (every 5 seconds with enough data) + # 1. FORWARD-LOOKING PREDICTIONS - Generate real predictions for future validation + self.generate_forward_looking_predictions() + + # 2. DQN Training (every 5 seconds with enough data) if (current_time - self.last_training_times['dqn'] > self.training_config['dqn_training_interval'] and len(self.experience_buffer) >= self.training_config['min_training_samples']): self._perform_enhanced_dqn_training() self.last_training_times['dqn'] = current_time - # 2. CNN Training (every 10 seconds) + # 3. CNN Training (every 10 seconds) if (current_time - self.last_training_times['cnn'] > self.training_config['cnn_training_interval'] and len(self.real_time_data['ohlcv_1m']) >= 20): self._perform_enhanced_cnn_training() self.last_training_times['cnn'] = current_time - # 3. Validation (every minute) + # 4. Validation (every minute) if current_time - self.last_training_times['validation'] > self.training_config['validation_interval']: self._perform_validation() self.last_training_times['validation'] = current_time - # 4. Adaptive learning rate adjustment + # 5. Adaptive learning rate adjustment if self.training_iteration % 100 == 0: self._adapt_learning_parameters() @@ -911,6 +942,11 @@ class EnhancedRealtimeTrainingSystem: 'dqn_loss_count': len(self.performance_history['dqn_losses']), 'cnn_loss_count': len(self.performance_history['cnn_losses']), 'validation_count': len(self.performance_history['validation_scores']) + }, + 'prediction_stats': { + 'dqn_predictions': {symbol: len(predictions) for symbol, predictions in self.recent_dqn_predictions.items()}, + 'cnn_predictions': {symbol: len(predictions) for symbol, predictions in self.recent_cnn_predictions.items()}, + 'accuracy_history': {symbol: len(history) for symbol, history in self.prediction_accuracy_history.items()} } } @@ -927,4 +963,492 @@ class EnhancedRealtimeTrainingSystem: except Exception as e: logger.error(f"Error getting training statistics: {e}") - return {'error': str(e)} \ No newline at end of file + return {'error': str(e)} + + def capture_dqn_prediction(self, symbol: str, state: np.ndarray, q_values: List[float], action: int, confidence: float, price: float): + """Capture DQN prediction for dashboard visualization""" + try: + prediction = { + 'timestamp': datetime.now(), + 'symbol': symbol, + 'state': state.tolist() if hasattr(state, 'tolist') else state, + 'q_values': q_values, + 'action': action, # 0=BUY, 1=SELL, 2=HOLD + 'confidence': confidence, + 'price': price + } + + if symbol in self.recent_dqn_predictions: + self.recent_dqn_predictions[symbol].append(prediction) + + logger.debug(f"DQN prediction captured: {symbol} action={action} confidence={confidence:.2f}") + + except Exception as e: + logger.debug(f"Error capturing DQN prediction: {e}") + + def capture_cnn_prediction(self, symbol: str, current_price: float, predicted_price: float, direction: int, confidence: float, features: Optional[np.ndarray] = None): + """Capture CNN prediction for dashboard visualization""" + try: + prediction = { + 'timestamp': datetime.now(), + 'symbol': symbol, + 'current_price': current_price, + 'predicted_price': predicted_price, + 'direction': direction, # 0=DOWN, 1=SAME, 2=UP + 'confidence': confidence, + 'features': features.tolist() if features is not None and hasattr(features, 'tolist') else None + } + + if symbol in self.recent_cnn_predictions: + self.recent_cnn_predictions[symbol].append(prediction) + + logger.debug(f"CNN prediction captured: {symbol} direction={direction} confidence={confidence:.2f}") + + except Exception as e: + logger.debug(f"Error capturing CNN prediction: {e}") + + def validate_prediction_accuracy(self, symbol: str, prediction_type: str, predicted_action: int, actual_price_change: float, confidence: float): + """Validate prediction accuracy and store results""" + try: + # Determine if prediction was correct + was_correct = False + + if prediction_type == 'DQN': + # For DQN: BUY (0) should be followed by price increase, SELL (1) by decrease + if predicted_action == 0 and actual_price_change > 0.001: # BUY + price up + was_correct = True + elif predicted_action == 1 and actual_price_change < -0.001: # SELL + price down + was_correct = True + elif predicted_action == 2 and abs(actual_price_change) <= 0.001: # HOLD + no change + was_correct = True + + elif prediction_type == 'CNN': + # For CNN: direction prediction accuracy + if predicted_action == 2 and actual_price_change > 0.001: # UP + price up + was_correct = True + elif predicted_action == 0 and actual_price_change < -0.001: # DOWN + price down + was_correct = True + elif predicted_action == 1 and abs(actual_price_change) <= 0.001: # SAME + no change + was_correct = True + + # Calculate accuracy score based on confidence and correctness + accuracy_score = confidence if was_correct else (1.0 - confidence) + + accuracy_data = { + 'timestamp': datetime.now(), + 'symbol': symbol, + 'prediction_type': prediction_type, + 'correct': was_correct, + 'accuracy_score': accuracy_score, + 'confidence': confidence, + 'actual_price_change': actual_price_change, + 'predicted_action': predicted_action + } + + if symbol in self.prediction_accuracy_history: + self.prediction_accuracy_history[symbol].append(accuracy_data) + + logger.debug(f"Prediction accuracy validated: {symbol} {prediction_type} correct={was_correct} score={accuracy_score:.2f}") + + except Exception as e: + logger.debug(f"Error validating prediction accuracy: {e}") + + def get_prediction_summary(self, symbol: str) -> Dict[str, Any]: + """Get prediction summary for a symbol""" + try: + summary = { + 'symbol': symbol, + 'dqn_predictions': len(self.recent_dqn_predictions.get(symbol, [])), + 'cnn_predictions': len(self.recent_cnn_predictions.get(symbol, [])), + 'accuracy_history': len(self.prediction_accuracy_history.get(symbol, [])), + 'pending_predictions': len(self.pending_predictions.get(symbol, [])) + } + + # Calculate accuracy statistics + if symbol in self.prediction_accuracy_history and self.prediction_accuracy_history[symbol]: + accuracy_data = list(self.prediction_accuracy_history[symbol]) + + total_predictions = len(accuracy_data) + correct_predictions = sum(1 for acc in accuracy_data if acc['correct']) + + summary['total_predictions'] = total_predictions + summary['correct_predictions'] = correct_predictions + summary['accuracy_rate'] = correct_predictions / total_predictions if total_predictions > 0 else 0.0 + + # Calculate accuracy by prediction type + dqn_accuracy_data = [acc for acc in accuracy_data if acc['prediction_type'] == 'DQN'] + cnn_accuracy_data = [acc for acc in accuracy_data if acc['prediction_type'] == 'CNN'] + + if dqn_accuracy_data: + dqn_correct = sum(1 for acc in dqn_accuracy_data if acc['correct']) + summary['dqn_accuracy_rate'] = dqn_correct / len(dqn_accuracy_data) + else: + summary['dqn_accuracy_rate'] = 0.0 + + if cnn_accuracy_data: + cnn_correct = sum(1 for acc in cnn_accuracy_data if acc['correct']) + summary['cnn_accuracy_rate'] = cnn_correct / len(cnn_accuracy_data) + else: + summary['cnn_accuracy_rate'] = 0.0 + + return summary + + except Exception as e: + logger.error(f"Error getting prediction summary: {e}") + return {'error': str(e)} + + def generate_forward_looking_predictions(self): + """Generate forward-looking predictions based on current market data""" + try: + current_time = time.time() + + for symbol in ['ETH/USDT', 'BTC/USDT']: + # Check if it's time to make new predictions + time_since_last = current_time - self.last_prediction_time.get(symbol, 0) + + # Generate DQN prediction every 30 seconds + if time_since_last >= self.prediction_intervals['dqn']: + self._generate_forward_dqn_prediction(symbol, current_time) + + # Generate CNN prediction every 60 seconds + if time_since_last >= self.prediction_intervals['cnn']: + self._generate_forward_cnn_prediction(symbol, current_time) + + # Validate pending predictions + self._validate_pending_predictions(symbol, current_time) + + except Exception as e: + logger.error(f"Error generating forward-looking predictions: {e}") + + def _generate_forward_dqn_prediction(self, symbol: str, current_time: float): + """Generate a DQN prediction for future price movement""" + try: + # Get current market state (only historical data) + current_state = self._build_comprehensive_state() + current_price = self._get_current_price_from_data(symbol) + + if current_price is None: + return + + # Use DQN model to predict action (if available) + if (self.orchestrator and hasattr(self.orchestrator, 'rl_agent') + and self.orchestrator.rl_agent): + + # Get Q-values from model + q_values = self.orchestrator.rl_agent.act(current_state, return_q_values=True) + if isinstance(q_values, tuple): + action, q_vals = q_values + q_values = q_vals.tolist() if hasattr(q_vals, 'tolist') else [0, 0, 0] + else: + action = q_values + q_values = [0.33, 0.33, 0.34] # Default uniform distribution + + confidence = max(q_values) / sum(q_values) if sum(q_values) > 0 else 0.33 + + else: + # Fallback to technical analysis-based prediction + action, q_values, confidence = self._technical_analysis_prediction(symbol) + + # Create forward-looking prediction + prediction_time = datetime.now() + target_time = prediction_time + timedelta(minutes=5) # Predict 5 minutes ahead + + prediction = { + 'id': f"dqn_{symbol}_{int(current_time)}", + 'type': 'DQN', + 'symbol': symbol, + 'prediction_time': prediction_time, + 'target_time': target_time, + 'current_price': current_price, + 'predicted_action': action, + 'q_values': q_values, + 'confidence': confidence, + 'state': current_state.tolist() if hasattr(current_state, 'tolist') else current_state, + 'validated': False + } + + # Add to pending predictions for future validation + if symbol in self.pending_predictions: + self.pending_predictions[symbol].append(prediction) + + # Add to recent predictions for display (only if confident enough) + if confidence > 0.4: + display_prediction = { + 'timestamp': prediction_time, + 'price': current_price, + 'action': action, + 'confidence': confidence, + 'q_values': q_values + } + if symbol in self.recent_dqn_predictions: + self.recent_dqn_predictions[symbol].append(display_prediction) + + self.last_prediction_time[symbol] = current_time + + logger.info(f"Forward DQN prediction: {symbol} action={['BUY','SELL','HOLD'][action]} confidence={confidence:.2f} target={target_time.strftime('%H:%M:%S')}") + + except Exception as e: + logger.error(f"Error generating forward DQN prediction: {e}") + + def _generate_forward_cnn_prediction(self, symbol: str, current_time: float): + """Generate a CNN prediction for future price direction""" + try: + # Get current price and historical sequence (only past data) + current_price = self._get_current_price_from_data(symbol) + price_sequence = self._get_historical_price_sequence(symbol, periods=15) + + if current_price is None or len(price_sequence) < 15: + return + + # Use CNN model to predict direction (if available) + if (self.orchestrator and hasattr(self.orchestrator, 'cnn_model') + and self.orchestrator.cnn_model): + + # Prepare features for CNN + features = self._prepare_cnn_features(price_sequence) + + try: + # Get prediction from CNN model + prediction_output = self.orchestrator.cnn_model.predict(features) + if hasattr(prediction_output, 'tolist'): + pred_probs = prediction_output.tolist() + else: + pred_probs = [0.33, 0.33, 0.34] # Default + + direction = int(np.argmax(pred_probs)) # 0=DOWN, 1=SAME, 2=UP + confidence = max(pred_probs) + + except Exception as e: + logger.debug(f"CNN model prediction failed: {e}") + direction, confidence = self._technical_direction_prediction(symbol) + + else: + # Fallback to technical analysis + direction, confidence = self._technical_direction_prediction(symbol) + + # Calculate predicted price based on direction + price_change_percent = self._estimate_price_change(direction, confidence) + predicted_price = current_price * (1 + price_change_percent) + + # Create forward-looking prediction + prediction_time = datetime.now() + target_time = prediction_time + timedelta(minutes=10) # Predict 10 minutes ahead + + prediction = { + 'id': f"cnn_{symbol}_{int(current_time)}", + 'type': 'CNN', + 'symbol': symbol, + 'prediction_time': prediction_time, + 'target_time': target_time, + 'current_price': current_price, + 'predicted_price': predicted_price, + 'direction': direction, + 'confidence': confidence, + 'features': features.tolist() if hasattr(features, 'tolist') else None, + 'validated': False + } + + # Add to pending predictions for future validation + if symbol in self.pending_predictions: + self.pending_predictions[symbol].append(prediction) + + # Add to recent predictions for display (only if confident enough) + if confidence > 0.5: + display_prediction = { + 'timestamp': prediction_time, + 'current_price': current_price, + 'predicted_price': predicted_price, + 'direction': direction, + 'confidence': confidence + } + if symbol in self.recent_cnn_predictions: + self.recent_cnn_predictions[symbol].append(display_prediction) + + logger.info(f"Forward CNN prediction: {symbol} direction={['DOWN','SAME','UP'][direction]} confidence={confidence:.2f} target={target_time.strftime('%H:%M:%S')}") + + except Exception as e: + logger.error(f"Error generating forward CNN prediction: {e}") + + def _validate_pending_predictions(self, symbol: str, current_time: float): + """Validate pending predictions when their target time arrives""" + try: + if symbol not in self.pending_predictions: + return + + current_datetime = datetime.now() + validated_predictions = [] + + # Check each pending prediction + for prediction in list(self.pending_predictions[symbol]): + target_time = prediction['target_time'] + + # If target time has passed, validate the prediction + if current_datetime >= target_time: + actual_price = self._get_current_price_from_data(symbol) + + if actual_price is not None: + # Calculate actual price change + predicted_price = prediction.get('predicted_price', prediction['current_price']) + actual_change = (actual_price - prediction['current_price']) / prediction['current_price'] + predicted_change = (predicted_price - prediction['current_price']) / prediction['current_price'] + + # Validate based on prediction type + if prediction['type'] == 'DQN': + was_correct = self._validate_dqn_prediction(prediction, actual_change) + else: # CNN + was_correct = self._validate_cnn_prediction(prediction, actual_change) + + # Store accuracy result + accuracy_data = { + 'timestamp': current_datetime, + 'symbol': symbol, + 'prediction_type': prediction['type'], + 'correct': was_correct, + 'accuracy_score': prediction['confidence'] if was_correct else (1.0 - prediction['confidence']), + 'confidence': prediction['confidence'], + 'actual_price_change': actual_change, + 'predicted_action': prediction.get('predicted_action', prediction.get('direction', 0)), + 'actual_price': actual_price + } + + if symbol in self.prediction_accuracy_history: + self.prediction_accuracy_history[symbol].append(accuracy_data) + + validated_predictions.append(prediction['id']) + + logger.info(f"Validated {prediction['type']} prediction: {symbol} correct={was_correct} confidence={prediction['confidence']:.2f}") + + # Remove validated predictions from pending list + if validated_predictions: + self.pending_predictions[symbol] = deque([ + p for p in self.pending_predictions[symbol] + if p['id'] not in validated_predictions + ], maxlen=100) + + except Exception as e: + logger.error(f"Error validating pending predictions: {e}") + + def _validate_dqn_prediction(self, prediction: Dict, actual_change: float) -> bool: + """Validate DQN action prediction""" + predicted_action = prediction['predicted_action'] + threshold = 0.005 # 0.5% threshold for significant movement + + if predicted_action == 0: # BUY prediction + return actual_change > threshold + elif predicted_action == 1: # SELL prediction + return actual_change < -threshold + else: # HOLD prediction + return abs(actual_change) <= threshold + + def _validate_cnn_prediction(self, prediction: Dict, actual_change: float) -> bool: + """Validate CNN direction prediction""" + predicted_direction = prediction['direction'] + threshold = 0.002 # 0.2% threshold for direction + + if predicted_direction == 2: # UP prediction + return actual_change > threshold + elif predicted_direction == 0: # DOWN prediction + return actual_change < -threshold + else: # SAME prediction + return abs(actual_change) <= threshold + + def _get_current_price_from_data(self, symbol: str) -> Optional[float]: + """Get current price from real-time data streams""" + try: + if len(self.real_time_data['ohlcv_1m']) > 0: + return self.real_time_data['ohlcv_1m'][-1]['close'] + return None + except Exception as e: + logger.debug(f"Error getting current price: {e}") + return None + + def _get_historical_price_sequence(self, symbol: str, periods: int = 15) -> List[float]: + """Get historical price sequence for CNN features""" + try: + if len(self.real_time_data['ohlcv_1m']) >= periods: + return [bar['close'] for bar in list(self.real_time_data['ohlcv_1m'])[-periods:]] + return [] + except Exception as e: + logger.debug(f"Error getting price sequence: {e}") + return [] + + def _technical_analysis_prediction(self, symbol: str) -> Tuple[int, List[float], float]: + """Fallback technical analysis prediction for DQN""" + try: + # Simple momentum-based prediction + if len(self.real_time_data['ohlcv_1m']) >= 5: + recent_prices = [bar['close'] for bar in list(self.real_time_data['ohlcv_1m'])[-5:]] + momentum = (recent_prices[-1] - recent_prices[0]) / recent_prices[0] + + if momentum > 0.01: # 1% upward momentum + return 0, [0.6, 0.2, 0.2], 0.6 # BUY + elif momentum < -0.01: # 1% downward momentum + return 1, [0.2, 0.6, 0.2], 0.6 # SELL + else: + return 2, [0.2, 0.2, 0.6], 0.6 # HOLD + + return 2, [0.33, 0.33, 0.34], 0.33 # Default HOLD + + except Exception as e: + logger.debug(f"Error in technical analysis prediction: {e}") + return 2, [0.33, 0.33, 0.34], 0.33 + + def _technical_direction_prediction(self, symbol: str) -> Tuple[int, float]: + """Fallback technical analysis for CNN direction""" + try: + if len(self.real_time_data['ohlcv_1m']) >= 3: + recent_prices = [bar['close'] for bar in list(self.real_time_data['ohlcv_1m'])[-3:]] + short_momentum = (recent_prices[-1] - recent_prices[-2]) / recent_prices[-2] + + if short_momentum > 0.005: # 0.5% short-term up + return 2, 0.65 # UP + elif short_momentum < -0.005: # 0.5% short-term down + return 0, 0.65 # DOWN + else: + return 1, 0.55 # SAME + + return 1, 0.5 # Default SAME + + except Exception as e: + logger.debug(f"Error in technical direction prediction: {e}") + return 1, 0.5 + + def _prepare_cnn_features(self, price_sequence: List[float]) -> np.ndarray: + """Prepare features for CNN model""" + try: + # Normalize prices relative to first price + if len(price_sequence) >= 15: + base_price = price_sequence[0] + normalized = [(p - base_price) / base_price for p in price_sequence] + + # Create feature matrix (15 x 20, flattened) + features = np.zeros((15, 20)) + for i, norm_price in enumerate(normalized): + features[i, 0] = norm_price # Normalized price + if i > 0: + features[i, 1] = normalized[i] - normalized[i-1] # Price change + + return features.flatten() + + return np.zeros(300) # Default feature vector + + except Exception as e: + logger.debug(f"Error preparing CNN features: {e}") + return np.zeros(300) + + def _estimate_price_change(self, direction: int, confidence: float) -> float: + """Estimate price change percentage based on direction and confidence""" + try: + # Base change scaled by confidence + base_change = 0.01 * confidence # Up to 1% change + + if direction == 2: # UP + return base_change + elif direction == 0: # DOWN + return -base_change + else: # SAME + return 0.0 + + except Exception as e: + logger.debug(f"Error estimating price change: {e}") + return 0.0 \ No newline at end of file diff --git a/test_model_predictions_visualization.py b/test_model_predictions_visualization.py new file mode 100644 index 0000000..40a9a3e --- /dev/null +++ b/test_model_predictions_visualization.py @@ -0,0 +1,309 @@ +#!/usr/bin/env python3 +""" +Test Model Predictions Visualization + +This script demonstrates the enhanced model prediction visualization system +that shows DQN actions, CNN price predictions, and accuracy feedback on the price chart. + +Features tested: +- DQN action predictions (BUY/SELL/HOLD) as directional arrows with confidence-based sizing +- CNN price direction predictions as trend lines with target markers +- Prediction accuracy feedback with color-coded results +- Real-time prediction tracking and storage +- Mock prediction generation for demonstration +""" + +import asyncio +import logging +import time +import numpy as np +from datetime import datetime, timedelta +from typing import Dict, List, Optional + +from core.config import get_config +from core.data_provider import DataProvider +from core.orchestrator import TradingOrchestrator +from core.trading_executor import TradingExecutor +from web.clean_dashboard import create_clean_dashboard +from enhanced_realtime_training import EnhancedRealtimeTrainingSystem + +# Setup logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +class ModelPredictionTester: + """Test model prediction visualization capabilities""" + + def __init__(self): + self.config = get_config() + self.data_provider = DataProvider() + self.trading_executor = TradingExecutor() + self.orchestrator = TradingOrchestrator( + data_provider=self.data_provider, + enhanced_rl_training=True, + model_registry={} + ) + + # Initialize enhanced training system + self.training_system = EnhancedRealtimeTrainingSystem( + orchestrator=self.orchestrator, + data_provider=self.data_provider, + dashboard=None # Will be set after dashboard creation + ) + + # Create dashboard with enhanced prediction visualization + self.dashboard = create_clean_dashboard( + data_provider=self.data_provider, + orchestrator=self.orchestrator, + trading_executor=self.trading_executor + ) + + # Connect training system to dashboard + self.training_system.dashboard = self.dashboard + self.dashboard.training_system = self.training_system + + # Test data + self.test_symbols = ['ETH/USDT', 'BTC/USDT'] + self.prediction_count = 0 + + logger.info("Model Prediction Tester initialized") + + def generate_mock_dqn_predictions(self, symbol: str, count: int = 10): + """Generate mock DQN predictions for testing""" + try: + current_price = self.data_provider.get_current_price(symbol) or 2400.0 + + for i in range(count): + # Generate realistic state vector + state = np.random.random(100) # 100-dimensional state + + # Generate Q-values with some logic + q_values = [np.random.random(), np.random.random(), np.random.random()] + action = np.argmax(q_values) # Best action + confidence = max(q_values) / sum(q_values) # Confidence based on Q-value distribution + + # Add some price variation + pred_price = current_price + np.random.normal(0, 20) + + # Capture prediction + self.training_system.capture_dqn_prediction( + symbol=symbol, + state=state, + q_values=q_values, + action=action, + confidence=confidence, + price=pred_price + ) + + self.prediction_count += 1 + + logger.info(f"Generated DQN prediction {i+1}/{count}: {symbol} action={['BUY', 'SELL', 'HOLD'][action]} confidence={confidence:.2f}") + + # Small delay between predictions + time.sleep(0.1) + + except Exception as e: + logger.error(f"Error generating DQN predictions: {e}") + + def generate_mock_cnn_predictions(self, symbol: str, count: int = 8): + """Generate mock CNN predictions for testing""" + try: + current_price = self.data_provider.get_current_price(symbol) or 2400.0 + + for i in range(count): + # Generate direction with some logic + direction = np.random.choice([0, 1, 2], p=[0.3, 0.2, 0.5]) # Slightly bullish + confidence = 0.4 + np.random.random() * 0.5 # 0.4-0.9 confidence + + # Calculate predicted price based on direction + if direction == 2: # UP + price_change = np.random.uniform(5, 50) + elif direction == 0: # DOWN + price_change = -np.random.uniform(5, 50) + else: # SAME + price_change = np.random.uniform(-5, 5) + + predicted_price = current_price + price_change + + # Generate features + features = np.random.random((15, 20)).flatten() # Flattened CNN features + + # Capture prediction + self.training_system.capture_cnn_prediction( + symbol=symbol, + current_price=current_price, + predicted_price=predicted_price, + direction=direction, + confidence=confidence, + features=features + ) + + self.prediction_count += 1 + + logger.info(f"Generated CNN prediction {i+1}/{count}: {symbol} direction={['DOWN', 'SAME', 'UP'][direction]} confidence={confidence:.2f}") + + # Small delay between predictions + time.sleep(0.2) + + except Exception as e: + logger.error(f"Error generating CNN predictions: {e}") + + def generate_mock_accuracy_data(self, symbol: str, count: int = 15): + """Generate mock prediction accuracy data for testing""" + try: + current_price = self.data_provider.get_current_price(symbol) or 2400.0 + + for i in range(count): + # Randomly choose prediction type + prediction_type = np.random.choice(['DQN', 'CNN']) + predicted_action = np.random.choice([0, 1, 2]) + confidence = 0.3 + np.random.random() * 0.6 + + # Generate realistic price change + actual_price_change = np.random.normal(0, 0.01) # ±1% typical change + + # Validate accuracy + self.training_system.validate_prediction_accuracy( + symbol=symbol, + prediction_type=prediction_type, + predicted_action=predicted_action, + actual_price_change=actual_price_change, + confidence=confidence + ) + + logger.info(f"Generated accuracy data {i+1}/{count}: {symbol} {prediction_type} action={predicted_action}") + + # Small delay + time.sleep(0.1) + + except Exception as e: + logger.error(f"Error generating accuracy data: {e}") + + def run_prediction_generation_test(self): + """Run comprehensive prediction generation test""" + try: + logger.info("Starting Model Prediction Visualization Test") + logger.info("=" * 60) + + # Test for each symbol + for symbol in self.test_symbols: + logger.info(f"\nGenerating predictions for {symbol}...") + + # Generate DQN predictions + logger.info(f"Generating DQN predictions for {symbol}...") + self.generate_mock_dqn_predictions(symbol, count=12) + + # Generate CNN predictions + logger.info(f"Generating CNN predictions for {symbol}...") + self.generate_mock_cnn_predictions(symbol, count=8) + + # Generate accuracy data + logger.info(f"Generating accuracy data for {symbol}...") + self.generate_mock_accuracy_data(symbol, count=20) + + # Get prediction summary + summary = self.training_system.get_prediction_summary(symbol) + logger.info(f"Prediction summary for {symbol}: {summary}") + + # Log total statistics + training_stats = self.training_system.get_training_statistics() + logger.info("\nTraining System Statistics:") + logger.info(f"Total predictions generated: {self.prediction_count}") + logger.info(f"Prediction stats: {training_stats.get('prediction_stats', {})}") + + logger.info("\n" + "=" * 60) + logger.info("Prediction generation test completed successfully!") + logger.info("Dashboard should now show enhanced model predictions on the price chart:") + logger.info("- Green/Red arrows for DQN BUY/SELL predictions") + logger.info("- Gray circles for DQN HOLD predictions") + logger.info("- Colored trend lines for CNN price direction predictions") + logger.info("- Diamond markers for CNN prediction targets") + logger.info("- Green/Red X marks for correct/incorrect prediction feedback") + logger.info("- Hover tooltips showing confidence, Q-values, and accuracy scores") + + except Exception as e: + logger.error(f"Error in prediction generation test: {e}") + + def start_dashboard_with_predictions(self, host='127.0.0.1', port=8051): + """Start dashboard with enhanced prediction visualization""" + try: + logger.info(f"Starting dashboard with model predictions at http://{host}:{port}") + + # Run prediction generation in background + import threading + pred_thread = threading.Thread(target=self.run_prediction_generation_test, daemon=True) + pred_thread.start() + + # Start training system + self.training_system.start_training() + + # Start dashboard + self.dashboard.run_server(host=host, port=port, debug=False) + + except Exception as e: + logger.error(f"Error starting dashboard with predictions: {e}") + + def test_prediction_accuracy_validation(self): + """Test prediction accuracy validation logic""" + try: + logger.info("Testing prediction accuracy validation...") + + # Test DQN accuracy validation + test_cases = [ + ('DQN', 0, 0.01, 0.8, True), # BUY + price up = correct + ('DQN', 1, -0.01, 0.7, True), # SELL + price down = correct + ('DQN', 2, 0.0005, 0.6, True), # HOLD + no change = correct + ('DQN', 0, -0.01, 0.8, False), # BUY + price down = incorrect + ('CNN', 2, 0.01, 0.9, True), # UP + price up = correct + ('CNN', 0, -0.01, 0.8, True), # DOWN + price down = correct + ('CNN', 1, 0.0005, 0.7, True), # SAME + no change = correct + ('CNN', 2, -0.01, 0.9, False), # UP + price down = incorrect + ] + + for prediction_type, action, price_change, confidence, expected_correct in test_cases: + self.training_system.validate_prediction_accuracy( + symbol='ETH/USDT', + prediction_type=prediction_type, + predicted_action=action, + actual_price_change=price_change, + confidence=confidence + ) + + # Check if validation worked correctly + if self.training_system.prediction_accuracy_history['ETH/USDT']: + latest = list(self.training_system.prediction_accuracy_history['ETH/USDT'])[-1] + actual_correct = latest['correct'] + + status = "✓" if actual_correct == expected_correct else "✗" + logger.info(f"{status} {prediction_type} action={action} change={price_change:.4f} -> correct={actual_correct}") + + logger.info("Prediction accuracy validation test completed") + + except Exception as e: + logger.error(f"Error testing prediction accuracy validation: {e}") + +def main(): + """Main test function""" + try: + # Create tester + tester = ModelPredictionTester() + + # Run accuracy validation test first + tester.test_prediction_accuracy_validation() + + # Start dashboard with enhanced predictions + logger.info("\nStarting dashboard with enhanced model prediction visualization...") + logger.info("Visit http://127.0.0.1:8051 to see the enhanced price chart with model predictions") + + tester.start_dashboard_with_predictions() + + except KeyboardInterrupt: + logger.info("Test interrupted by user") + except Exception as e: + logger.error(f"Error in main test: {e}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/web/clean_dashboard.py b/web/clean_dashboard.py index 2c6da58..808e81c 100644 --- a/web/clean_dashboard.py +++ b/web/clean_dashboard.py @@ -106,6 +106,10 @@ class CleanTradingDashboard: else: self.orchestrator = orchestrator + # Initialize enhanced training system for predictions + self.training_system = None + self._initialize_enhanced_training_system() + # Initialize layout and component managers self.layout_manager = DashboardLayoutManager( starting_balance=self._get_initial_balance(), @@ -711,9 +715,9 @@ class CleanTradingDashboard: x=0.5, y=0.5, showarrow=False) def _add_model_predictions_to_chart(self, fig: go.Figure, symbol: str, df_main: pd.DataFrame, row: int = 1): - """Add model predictions to the chart - ONLY EXECUTED TRADES on main chart""" + """Add enhanced model predictions to the chart with real-time feedback""" try: - # Only show EXECUTED TRADES on the main 1m chart + # 1. Add executed trades (existing functionality) executed_signals = [signal for signal in self.recent_decisions if self._get_signal_attribute(signal, 'executed', False)] if executed_signals: @@ -721,8 +725,7 @@ class CleanTradingDashboard: buy_trades = [] sell_trades = [] - for signal in executed_signals[-50:]: # Last 50 executed trades (increased from 20) - # Try to get full timestamp first, fall back to string timestamp + for signal in executed_signals[-50:]: # Last 50 executed trades signal_time = self._get_signal_attribute(signal, 'full_timestamp') if not signal_time: signal_time = self._get_signal_attribute(signal, 'timestamp') @@ -732,10 +735,9 @@ class CleanTradingDashboard: signal_confidence = self._get_signal_attribute(signal, 'confidence', 0) if signal_time and signal_price and signal_confidence > 0: - # FIXED: Better timestamp conversion to prevent race conditions + # Enhanced timestamp handling if isinstance(signal_time, str): try: - # Handle time-only format with current date if ':' in signal_time and len(signal_time.split(':')) == 3: now = datetime.now() time_parts = signal_time.split(':') @@ -745,7 +747,6 @@ class CleanTradingDashboard: second=int(time_parts[2]), microsecond=0 ) - # Handle day boundary issues - if signal seems from future, subtract a day if signal_time > now + timedelta(minutes=5): signal_time -= timedelta(days=1) else: @@ -754,7 +755,6 @@ class CleanTradingDashboard: logger.debug(f"Error parsing timestamp {signal_time}: {e}") continue elif not isinstance(signal_time, datetime): - # Convert other timestamp formats to datetime try: signal_time = pd.to_datetime(signal_time) except Exception as e: @@ -766,7 +766,7 @@ class CleanTradingDashboard: elif signal_action == 'SELL': sell_trades.append({'x': signal_time, 'y': signal_price, 'confidence': signal_confidence}) - # Add EXECUTED BUY trades (large green circles) + # Add executed trades with enhanced visualization if buy_trades: fig.add_trace( go.Scatter( @@ -790,7 +790,6 @@ class CleanTradingDashboard: row=row, col=1 ) - # Add EXECUTED SELL trades (large red circles) if sell_trades: fig.add_trace( go.Scatter( @@ -813,9 +812,363 @@ class CleanTradingDashboard: ), row=row, col=1 ) + + # 2. NEW: Add real-time model predictions overlay + self._add_dqn_predictions_to_chart(fig, symbol, df_main, row) + self._add_cnn_predictions_to_chart(fig, symbol, df_main, row) + self._add_prediction_accuracy_feedback(fig, symbol, df_main, row) except Exception as e: - logger.warning(f"Error adding executed trades to main chart: {e}") + logger.warning(f"Error adding model predictions to chart: {e}") + + def _add_dqn_predictions_to_chart(self, fig: go.Figure, symbol: str, df_main: pd.DataFrame, row: int = 1): + """Add DQN action predictions as directional arrows""" + try: + # Get recent DQN predictions from orchestrator + dqn_predictions = self._get_recent_dqn_predictions(symbol) + + if not dqn_predictions: + return + + # Separate predictions by action + buy_predictions = [] + sell_predictions = [] + hold_predictions = [] + + for pred in dqn_predictions[-30:]: # Last 30 DQN predictions + action = pred.get('action', 2) # 0=BUY, 1=SELL, 2=HOLD + confidence = pred.get('confidence', 0) + timestamp = pred.get('timestamp', datetime.now()) + price = pred.get('price', 0) + + if confidence > 0.3: # Only show predictions with reasonable confidence + pred_data = { + 'x': timestamp, + 'y': price, + 'confidence': confidence, + 'q_values': pred.get('q_values', [0, 0, 0]) + } + + if action == 0: # BUY + buy_predictions.append(pred_data) + elif action == 1: # SELL + sell_predictions.append(pred_data) + else: # HOLD + hold_predictions.append(pred_data) + + # Add DQN BUY predictions (green arrows pointing up) + if buy_predictions: + fig.add_trace( + go.Scatter( + x=[p['x'] for p in buy_predictions], + y=[p['y'] for p in buy_predictions], + mode='markers', + marker=dict( + symbol='triangle-up', + size=[8 + p['confidence'] * 12 for p in buy_predictions], # Size based on confidence + color=[f'rgba(0, 200, 0, {0.3 + p["confidence"] * 0.7})' for p in buy_predictions], # Opacity based on confidence + line=dict(width=1, color='darkgreen') + ), + name='DQN BUY Prediction', + showlegend=True, + hovertemplate="DQN BUY PREDICTION
" + + "Price: $%{y:.2f}
" + + "Time: %{x}
" + + "Confidence: %{customdata[0]:.1%}
" + + "Q-Values: [%{customdata[1]:.3f}, %{customdata[2]:.3f}, %{customdata[3]:.3f}]", + customdata=[[p['confidence']] + p['q_values'] for p in buy_predictions] + ), + row=row, col=1 + ) + + # Add DQN SELL predictions (red arrows pointing down) + if sell_predictions: + fig.add_trace( + go.Scatter( + x=[p['x'] for p in sell_predictions], + y=[p['y'] for p in sell_predictions], + mode='markers', + marker=dict( + symbol='triangle-down', + size=[8 + p['confidence'] * 12 for p in sell_predictions], + color=[f'rgba(200, 0, 0, {0.3 + p["confidence"] * 0.7})' for p in sell_predictions], + line=dict(width=1, color='darkred') + ), + name='DQN SELL Prediction', + showlegend=True, + hovertemplate="DQN SELL PREDICTION
" + + "Price: $%{y:.2f}
" + + "Time: %{x}
" + + "Confidence: %{customdata[0]:.1%}
" + + "Q-Values: [%{customdata[1]:.3f}, %{customdata[2]:.3f}, %{customdata[3]:.3f}]", + customdata=[[p['confidence']] + p['q_values'] for p in sell_predictions] + ), + row=row, col=1 + ) + + # Add DQN HOLD predictions (small gray circles) + if hold_predictions: + fig.add_trace( + go.Scatter( + x=[p['x'] for p in hold_predictions], + y=[p['y'] for p in hold_predictions], + mode='markers', + marker=dict( + symbol='circle', + size=[4 + p['confidence'] * 6 for p in hold_predictions], + color=[f'rgba(128, 128, 128, {0.2 + p["confidence"] * 0.5})' for p in hold_predictions], + line=dict(width=1, color='gray') + ), + name='DQN HOLD Prediction', + showlegend=True, + hovertemplate="DQN HOLD PREDICTION
" + + "Price: $%{y:.2f}
" + + "Time: %{x}
" + + "Confidence: %{customdata[0]:.1%}
" + + "Q-Values: [%{customdata[1]:.3f}, %{customdata[2]:.3f}, %{customdata[3]:.3f}]", + customdata=[[p['confidence']] + p['q_values'] for p in hold_predictions] + ), + row=row, col=1 + ) + + except Exception as e: + logger.debug(f"Error adding DQN predictions to chart: {e}") + + def _add_cnn_predictions_to_chart(self, fig: go.Figure, symbol: str, df_main: pd.DataFrame, row: int = 1): + """Add CNN price direction predictions as trend lines""" + try: + # Get recent CNN predictions from orchestrator + cnn_predictions = self._get_recent_cnn_predictions(symbol) + + if not cnn_predictions: + return + + # Create trend prediction lines + prediction_lines = [] + + for i, pred in enumerate(cnn_predictions[-20:]): # Last 20 CNN predictions + direction = pred.get('direction', 1) # 0=DOWN, 1=SAME, 2=UP + confidence = pred.get('confidence', 0) + timestamp = pred.get('timestamp', datetime.now()) + current_price = pred.get('current_price', 0) + predicted_price = pred.get('predicted_price', current_price) + + if confidence > 0.4 and current_price > 0: # Only show confident predictions + # Calculate prediction end point (5 minutes ahead) + end_time = timestamp + timedelta(minutes=5) + + # Determine color based on direction + if direction == 2: # UP + color = f'rgba(0, 255, 0, {0.3 + confidence * 0.4})' + line_color = 'green' + prediction_name = 'CNN UP' + elif direction == 0: # DOWN + color = f'rgba(255, 0, 0, {0.3 + confidence * 0.4})' + line_color = 'red' + prediction_name = 'CNN DOWN' + else: # SAME + color = f'rgba(128, 128, 128, {0.2 + confidence * 0.3})' + line_color = 'gray' + prediction_name = 'CNN FLAT' + + # Add prediction line + fig.add_trace( + go.Scatter( + x=[timestamp, end_time], + y=[current_price, predicted_price], + mode='lines', + line=dict( + color=line_color, + width=2 + confidence * 3, # Line width based on confidence + dash='dot' if direction == 1 else 'solid' + ), + name=f'{prediction_name} Prediction', + showlegend=i == 0, # Only show legend for first instance + hovertemplate=f"{prediction_name} PREDICTION
" + + "From: $%{y[0]:.2f}
" + + "To: $%{y[1]:.2f}
" + + "Time: %{x[0]} → %{x[1]}
" + + f"Confidence: {confidence:.1%}
" + + f"Direction: {['DOWN', 'SAME', 'UP'][direction]}" + ), + row=row, col=1 + ) + + # Add prediction end point marker + fig.add_trace( + go.Scatter( + x=[end_time], + y=[predicted_price], + mode='markers', + marker=dict( + symbol='diamond', + size=6 + confidence * 8, + color=color, + line=dict(width=1, color=line_color) + ), + name=f'{prediction_name} Target', + showlegend=False, + hovertemplate=f"{prediction_name} TARGET
" + + "Target Price: $%{y:.2f}
" + + "Target Time: %{x}
" + + f"Confidence: {confidence:.1%}" + ), + row=row, col=1 + ) + + except Exception as e: + logger.debug(f"Error adding CNN predictions to chart: {e}") + + def _add_prediction_accuracy_feedback(self, fig: go.Figure, symbol: str, df_main: pd.DataFrame, row: int = 1): + """Add prediction accuracy feedback with color-coded results""" + try: + # Get prediction accuracy history + accuracy_data = self._get_prediction_accuracy_history(symbol) + + if not accuracy_data: + return + + # Add accuracy feedback markers + correct_predictions = [] + incorrect_predictions = [] + + for acc in accuracy_data[-50:]: # Last 50 accuracy points + timestamp = acc.get('timestamp', datetime.now()) + price = acc.get('actual_price', 0) + was_correct = acc.get('correct', False) + prediction_type = acc.get('prediction_type', 'unknown') + accuracy_score = acc.get('accuracy_score', 0) + + if price > 0: + acc_data = { + 'x': timestamp, + 'y': price, + 'type': prediction_type, + 'score': accuracy_score + } + + if was_correct: + correct_predictions.append(acc_data) + else: + incorrect_predictions.append(acc_data) + + # Add correct prediction markers (green checkmarks) + if correct_predictions: + fig.add_trace( + go.Scatter( + x=[p['x'] for p in correct_predictions], + y=[p['y'] for p in correct_predictions], + mode='markers', + marker=dict( + symbol='x', + size=8, + color='rgba(0, 255, 0, 0.8)', + line=dict(width=2, color='darkgreen') + ), + name='Correct Predictions', + showlegend=True, + hovertemplate="CORRECT PREDICTION
" + + "Price: $%{y:.2f}
" + + "Time: %{x}
" + + "Type: %{customdata[0]}
" + + "Accuracy: %{customdata[1]:.1%}", + customdata=[[p['type'], p['score']] for p in correct_predictions] + ), + row=row, col=1 + ) + + # Add incorrect prediction markers (red X marks) + if incorrect_predictions: + fig.add_trace( + go.Scatter( + x=[p['x'] for p in incorrect_predictions], + y=[p['y'] for p in incorrect_predictions], + mode='markers', + marker=dict( + symbol='x', + size=8, + color='rgba(255, 0, 0, 0.8)', + line=dict(width=2, color='darkred') + ), + name='Incorrect Predictions', + showlegend=True, + hovertemplate="INCORRECT PREDICTION
" + + "Price: $%{y:.2f}
" + + "Time: %{x}
" + + "Type: %{customdata[0]}
" + + "Accuracy: %{customdata[1]:.1%}", + customdata=[[p['type'], p['score']] for p in incorrect_predictions] + ), + row=row, col=1 + ) + + except Exception as e: + logger.debug(f"Error adding prediction accuracy feedback to chart: {e}") + + def _get_recent_dqn_predictions(self, symbol: str) -> List[Dict]: + """Get recent DQN predictions from enhanced training system (forward-looking only)""" + try: + predictions = [] + + # Get REAL forward-looking predictions from enhanced training system + if hasattr(self, 'training_system') and self.training_system: + if hasattr(self.training_system, 'recent_dqn_predictions'): + predictions.extend(self.training_system.recent_dqn_predictions.get(symbol, [])) + + # Get from orchestrator as fallback + if hasattr(self.orchestrator, 'recent_dqn_predictions'): + predictions.extend(self.orchestrator.recent_dqn_predictions.get(symbol, [])) + + # REMOVED: Mock prediction generation - now using REAL predictions only + # No more artificial past predictions or random data + + return sorted(predictions, key=lambda x: x.get('timestamp', datetime.now())) + + except Exception as e: + logger.debug(f"Error getting DQN predictions: {e}") + return [] + + def _get_recent_cnn_predictions(self, symbol: str) -> List[Dict]: + """Get recent CNN predictions from enhanced training system (forward-looking only)""" + try: + predictions = [] + + # Get REAL forward-looking predictions from enhanced training system + if hasattr(self, 'training_system') and self.training_system: + if hasattr(self.training_system, 'recent_cnn_predictions'): + predictions.extend(self.training_system.recent_cnn_predictions.get(symbol, [])) + + # Get from orchestrator as fallback + if hasattr(self.orchestrator, 'recent_cnn_predictions'): + predictions.extend(self.orchestrator.recent_cnn_predictions.get(symbol, [])) + + # REMOVED: Mock prediction generation - now using REAL predictions only + # No more artificial past predictions or random data + + return sorted(predictions, key=lambda x: x.get('timestamp', datetime.now())) + + except Exception as e: + logger.debug(f"Error getting CNN predictions: {e}") + return [] + + def _get_prediction_accuracy_history(self, symbol: str) -> List[Dict]: + """Get REAL prediction accuracy history from validated forward-looking predictions""" + try: + accuracy_data = [] + + # Get REAL accuracy data from training system validation + if hasattr(self, 'training_system') and self.training_system: + if hasattr(self.training_system, 'prediction_accuracy_history'): + accuracy_data.extend(self.training_system.prediction_accuracy_history.get(symbol, [])) + + # REMOVED: Mock accuracy data generation - now using REAL validation results only + # Accuracy is now based on actual prediction outcomes, not random data + + return sorted(accuracy_data, key=lambda x: x.get('timestamp', datetime.now())) + + except Exception as e: + logger.debug(f"Error getting prediction accuracy history: {e}") + return [] def _add_signals_to_mini_chart(self, fig: go.Figure, symbol: str, ws_data_1s: pd.DataFrame, row: int = 2): """Add ALL signals (executed and non-executed) to the 1s mini chart""" @@ -2566,6 +2919,33 @@ class CleanTradingDashboard: except Exception as e: logger.warning(f"Error clearing old signals: {e}") + def _initialize_enhanced_training_system(self): + """Initialize enhanced training system for model predictions""" + try: + # Try to import and initialize enhanced training system + from enhanced_realtime_training import EnhancedRealtimeTrainingSystem + + self.training_system = EnhancedRealtimeTrainingSystem( + orchestrator=self.orchestrator, + data_provider=self.data_provider, + dashboard=self + ) + + # Initialize prediction storage + if not hasattr(self.orchestrator, 'recent_dqn_predictions'): + self.orchestrator.recent_dqn_predictions = {} + if not hasattr(self.orchestrator, 'recent_cnn_predictions'): + self.orchestrator.recent_cnn_predictions = {} + + logger.info("Enhanced training system initialized for model predictions") + + except ImportError: + logger.warning("Enhanced training system not available - using mock predictions") + self.training_system = None + except Exception as e: + logger.error(f"Error initializing enhanced training system: {e}") + self.training_system = None + def _initialize_cob_integration(self): """Initialize COB integration with high-frequency data handling""" try: