#!/usr/bin/env python3
"""
Test Model Predictions Visualization

This script exercises the enhanced model prediction visualization system,
which overlays DQN actions, CNN price predictions, and accuracy feedback on
the dashboard's price chart.

Features tested:
- DQN action predictions (BUY/SELL/HOLD) as directional arrows with confidence-based sizing
- CNN price direction predictions as trend lines with target markers
- Prediction accuracy feedback with color-coded results
- Real-time prediction tracking and storage
- Mock prediction generation for demonstration
"""

import asyncio
import logging
import threading
import time
import numpy as np
from datetime import datetime, timedelta
from typing import Dict, List, Optional

from core.config import get_config
from core.data_provider import DataProvider
from core.orchestrator import TradingOrchestrator
from core.trading_executor import TradingExecutor
from web.clean_dashboard import create_clean_dashboard
from enhanced_realtime_training import EnhancedRealtimeTrainingSystem

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class ModelPredictionTester:
    """Test model prediction visualization capabilities.

    Wires a DataProvider, TradingOrchestrator, TradingExecutor, the
    EnhancedRealtimeTrainingSystem and the clean dashboard together, then
    feeds the training system synthetic DQN/CNN predictions and accuracy
    results so the dashboard has data to render.
    """

    # Price used when the data provider returns a falsy price (None or 0).
    DEFAULT_MOCK_PRICE = 2400.0
    # Index -> label mappings, used only to make log lines readable.
    DQN_ACTION_LABELS = ('BUY', 'SELL', 'HOLD')
    CNN_DIRECTION_LABELS = ('DOWN', 'SAME', 'UP')

    def __init__(self):
        """Create all components and cross-link the training system and dashboard."""
        self.config = get_config()
        self.data_provider = DataProvider()
        self.trading_executor = TradingExecutor()
        self.orchestrator = TradingOrchestrator(
            data_provider=self.data_provider,
            enhanced_rl_training=True,
            model_registry={}
        )

        # Initialize enhanced training system
        self.training_system = EnhancedRealtimeTrainingSystem(
            orchestrator=self.orchestrator,
            data_provider=self.data_provider,
            dashboard=None  # Set below, once the dashboard exists
        )

        # Create dashboard with enhanced prediction visualization
        self.dashboard = create_clean_dashboard(
            data_provider=self.data_provider,
            orchestrator=self.orchestrator,
            trading_executor=self.trading_executor
        )

        # Connect training system and dashboard in both directions
        self.training_system.dashboard = self.dashboard
        self.dashboard.training_system = self.training_system

        # Test data
        self.test_symbols = ['ETH/USDT', 'BTC/USDT']
        self.prediction_count = 0

        logger.info("Model Prediction Tester initialized")

    def _current_price(self, symbol: str) -> float:
        """Return the provider's current price for *symbol*, or DEFAULT_MOCK_PRICE if falsy."""
        return self.data_provider.get_current_price(symbol) or self.DEFAULT_MOCK_PRICE

    def _latest_accuracy_record(self, symbol: str):
        """Return the newest accuracy record stored for *symbol*, or None if there is none.

        NOTE(review): assumes prediction_accuracy_history is a mapping of
        symbol -> indexable sequence (e.g. list/deque) of dict records.
        """
        records = self.training_system.prediction_accuracy_history.get(symbol)
        if not records:
            return None
        return records[-1]

    def generate_mock_dqn_predictions(self, symbol: str, count: int = 10) -> None:
        """Generate mock DQN predictions and hand them to the training system.

        Args:
            symbol: Trading pair the predictions are for, e.g. 'ETH/USDT'.
            count: Number of predictions to generate.

        Errors are logged with a traceback instead of raised, so one failing
        symbol does not abort the rest of the demo.
        """
        try:
            current_price = self._current_price(symbol)

            for i in range(count):
                # Random 100-dimensional state vector
                state = np.random.random(100)

                # One random Q-value per action (BUY/SELL/HOLD)
                q_values = [np.random.random(), np.random.random(), np.random.random()]
                # np.argmax returns a NumPy integer; convert to a plain int so
                # downstream consumers (e.g. JSON serialization) accept it.
                action = int(np.argmax(q_values))
                # Fraction of the total Q-value mass held by the best action
                confidence = max(q_values) / sum(q_values)

                # Add some price variation
                pred_price = current_price + np.random.normal(0, 20)

                self.training_system.capture_dqn_prediction(
                    symbol=symbol,
                    state=state,
                    q_values=q_values,
                    action=action,
                    confidence=confidence,
                    price=pred_price
                )

                self.prediction_count += 1

                logger.info(f"Generated DQN prediction {i+1}/{count}: {symbol} action={self.DQN_ACTION_LABELS[action]} confidence={confidence:.2f}")

                # Small delay between predictions
                time.sleep(0.1)

        except Exception as e:
            logger.exception(f"Error generating DQN predictions: {e}")

    def generate_mock_cnn_predictions(self, symbol: str, count: int = 8) -> None:
        """Generate mock CNN price-direction predictions and hand them to the training system.

        Args:
            symbol: Trading pair the predictions are for.
            count: Number of predictions to generate.

        Directions are encoded 0=DOWN, 1=SAME, 2=UP and drawn with a slight
        bullish bias. Errors are logged with a traceback instead of raised.
        """
        try:
            current_price = self._current_price(symbol)

            for i in range(count):
                # Slightly bullish distribution over DOWN/SAME/UP; cast the
                # NumPy scalar to a plain int.
                direction = int(np.random.choice([0, 1, 2], p=[0.3, 0.2, 0.5]))
                confidence = 0.4 + np.random.random() * 0.5  # 0.4-0.9 confidence

                # Predicted price moves in the predicted direction
                if direction == 2:  # UP
                    price_change = np.random.uniform(5, 50)
                elif direction == 0:  # DOWN
                    price_change = -np.random.uniform(5, 50)
                else:  # SAME
                    price_change = np.random.uniform(-5, 5)

                predicted_price = current_price + price_change

                # Flattened 15x20 mock feature map
                features = np.random.random((15, 20)).flatten()

                self.training_system.capture_cnn_prediction(
                    symbol=symbol,
                    current_price=current_price,
                    predicted_price=predicted_price,
                    direction=direction,
                    confidence=confidence,
                    features=features
                )

                self.prediction_count += 1

                logger.info(f"Generated CNN prediction {i+1}/{count}: {symbol} direction={self.CNN_DIRECTION_LABELS[direction]} confidence={confidence:.2f}")

                # Small delay between predictions
                time.sleep(0.2)

        except Exception as e:
            logger.exception(f"Error generating CNN predictions: {e}")

    def generate_mock_accuracy_data(self, symbol: str, count: int = 15) -> None:
        """Feed random prediction/outcome pairs through validate_prediction_accuracy.

        Args:
            symbol: Trading pair the accuracy records are for.
            count: Number of records to generate.

        Errors are logged with a traceback instead of raised.
        """
        try:
            for i in range(count):
                # Randomly choose prediction type and action; convert the NumPy
                # scalars to native Python types.
                prediction_type = str(np.random.choice(['DQN', 'CNN']))
                predicted_action = int(np.random.choice([0, 1, 2]))
                confidence = 0.3 + np.random.random() * 0.6

                # Relative price change, normal with std-dev 0.01 (~1%)
                actual_price_change = np.random.normal(0, 0.01)

                self.training_system.validate_prediction_accuracy(
                    symbol=symbol,
                    prediction_type=prediction_type,
                    predicted_action=predicted_action,
                    actual_price_change=actual_price_change,
                    confidence=confidence
                )

                logger.info(f"Generated accuracy data {i+1}/{count}: {symbol} {prediction_type} action={predicted_action}")

                # Small delay
                time.sleep(0.1)

        except Exception as e:
            logger.exception(f"Error generating accuracy data: {e}")

    def run_prediction_generation_test(self) -> None:
        """Generate DQN, CNN and accuracy mock data for every test symbol, then log statistics."""
        try:
            logger.info("Starting Model Prediction Visualization Test")
            logger.info("=" * 60)

            for symbol in self.test_symbols:
                logger.info(f"\nGenerating predictions for {symbol}...")

                logger.info(f"Generating DQN predictions for {symbol}...")
                self.generate_mock_dqn_predictions(symbol, count=12)

                logger.info(f"Generating CNN predictions for {symbol}...")
                self.generate_mock_cnn_predictions(symbol, count=8)

                logger.info(f"Generating accuracy data for {symbol}...")
                self.generate_mock_accuracy_data(symbol, count=20)

                summary = self.training_system.get_prediction_summary(symbol)
                logger.info(f"Prediction summary for {symbol}: {summary}")

            # Log total statistics
            training_stats = self.training_system.get_training_statistics()
            logger.info("\nTraining System Statistics:")
            logger.info(f"Total predictions generated: {self.prediction_count}")
            logger.info(f"Prediction stats: {training_stats.get('prediction_stats', {})}")

            logger.info("\n" + "=" * 60)
            logger.info("Prediction generation test completed successfully!")
            logger.info("Dashboard should now show enhanced model predictions on the price chart:")
            logger.info("- Green/Red arrows for DQN BUY/SELL predictions")
            logger.info("- Gray circles for DQN HOLD predictions")
            logger.info("- Colored trend lines for CNN price direction predictions")
            logger.info("- Diamond markers for CNN prediction targets")
            logger.info("- Green/Red X marks for correct/incorrect prediction feedback")
            logger.info("- Hover tooltips showing confidence, Q-values, and accuracy scores")

        except Exception as e:
            logger.exception(f"Error in prediction generation test: {e}")

    def start_dashboard_with_predictions(self, host='127.0.0.1', port=8051) -> None:
        """Start the training system and dashboard while mock predictions are generated.

        Prediction generation runs on a daemon thread so it does not keep the
        process alive after the dashboard server exits. The
        ``self.dashboard.run_server`` call is left to block the calling thread.
        """
        try:
            logger.info(f"Starting dashboard with model predictions at http://{host}:{port}")

            # Run prediction generation in background
            pred_thread = threading.Thread(target=self.run_prediction_generation_test, daemon=True)
            pred_thread.start()

            self.training_system.start_training()

            self.dashboard.run_server(host=host, port=port, debug=False)

        except Exception as e:
            logger.exception(f"Error starting dashboard with predictions: {e}")

    def test_prediction_accuracy_validation(self) -> None:
        """Check validate_prediction_accuracy against hand-labelled cases.

        Each case is (prediction_type, action, price_change, confidence,
        expected_correct). After every validation call, the newest 'ETH/USDT'
        accuracy record is compared with the expectation and a ✓/✗ line is
        logged. A pass-count summary is logged at the end. A call that writes
        no new record is reported explicitly, rather than judged against the
        previous case's record.
        """
        symbol = 'ETH/USDT'
        try:
            logger.info("Testing prediction accuracy validation...")

            test_cases = [
                ('DQN', 0, 0.01, 0.8, True),     # BUY + price up = correct
                ('DQN', 1, -0.01, 0.7, True),    # SELL + price down = correct
                ('DQN', 2, 0.0005, 0.6, True),   # HOLD + no change = correct
                ('DQN', 0, -0.01, 0.8, False),   # BUY + price down = incorrect
                ('CNN', 2, 0.01, 0.9, True),     # UP + price up = correct
                ('CNN', 0, -0.01, 0.8, True),    # DOWN + price down = correct
                ('CNN', 1, 0.0005, 0.7, True),   # SAME + no change = correct
                ('CNN', 2, -0.01, 0.9, False),   # UP + price down = incorrect
            ]

            matched = 0
            for prediction_type, action, price_change, confidence, expected_correct in test_cases:
                previous = self._latest_accuracy_record(symbol)

                self.training_system.validate_prediction_accuracy(
                    symbol=symbol,
                    prediction_type=prediction_type,
                    predicted_action=action,
                    actual_price_change=price_change,
                    confidence=confidence
                )

                latest = self._latest_accuracy_record(symbol)
                # Compare by identity rather than length so this also works if
                # the history is bounded and stops growing once full.
                if latest is None or latest is previous:
                    logger.warning(f"✗ {prediction_type} action={action} change={price_change:.4f} -> no accuracy record written")
                    continue

                actual_correct = latest['correct']
                if actual_correct == expected_correct:
                    matched += 1
                    status = "✓"
                else:
                    status = "✗"
                logger.info(f"{status} {prediction_type} action={action} change={price_change:.4f} -> correct={actual_correct}")

            logger.info(f"Prediction accuracy validation: {matched}/{len(test_cases)} cases matched expectations")
            logger.info("Prediction accuracy validation test completed")

        except Exception as e:
            logger.exception(f"Error testing prediction accuracy validation: {e}")


def main():
    """Run the accuracy-validation check, then serve the dashboard with mock predictions."""
    try:
        tester = ModelPredictionTester()

        # Run accuracy validation test first
        tester.test_prediction_accuracy_validation()

        # Start dashboard with enhanced predictions (blocks while the server runs)
        logger.info("\nStarting dashboard with enhanced model prediction visualization...")
        logger.info("Visit http://127.0.0.1:8051 to see the enhanced price chart with model predictions")

        tester.start_dashboard_with_predictions()

    except KeyboardInterrupt:
        logger.info("Test interrupted by user")
    except Exception as e:
        logger.exception(f"Error in main test: {e}")


if __name__ == "__main__":
    main()