Files
gogo2/test_model_predictions_visualization.py
2025-06-27 01:12:55 +03:00

309 lines
13 KiB
Python

#!/usr/bin/env python3
"""
Test Model Predictions Visualization
This script demonstrates the enhanced model prediction visualization system
that shows DQN actions, CNN price predictions, and accuracy feedback on the price chart.
Features tested:
- DQN action predictions (BUY/SELL/HOLD) as directional arrows with confidence-based sizing
- CNN price direction predictions as trend lines with target markers
- Prediction accuracy feedback with color-coded results
- Real-time prediction tracking and storage
- Mock prediction generation for demonstration
"""
import asyncio
import logging
import time
import numpy as np
from datetime import datetime, timedelta
from typing import Dict, List, Optional
from core.config import get_config
from core.data_provider import DataProvider
from core.orchestrator import TradingOrchestrator
from core.trading_executor import TradingExecutor
from web.clean_dashboard import create_clean_dashboard
from enhanced_realtime_training import EnhancedRealtimeTrainingSystem
# Configure root logging once for the whole script: INFO level, with each
# record rendered as "timestamp - logger name - level - message".
_LOGGING_OPTIONS = {
    'level': logging.INFO,
    'format': '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
}
logging.basicConfig(**_LOGGING_OPTIONS)

# Module-level logger used by everything below.
logger = logging.getLogger(__name__)
class ModelPredictionTester:
    """Feed mock model predictions into the dashboard's training system.

    The tester wires a data provider, orchestrator, trading executor,
    enhanced training system and dashboard together, then pushes randomly
    generated DQN predictions, CNN predictions and accuracy records into the
    training system so they can be inspected on the dashboard.
    """

    # Index -> label tables. They are used for log output only; the integer
    # codes themselves are what gets passed to the training system.
    DQN_ACTION_LABELS = ('BUY', 'SELL', 'HOLD')
    CNN_DIRECTION_LABELS = ('DOWN', 'SAME', 'UP')

    # Price used when the data provider returns nothing (or a falsy value).
    FALLBACK_PRICE = 2400.0

    def __init__(self):
        """Build all components and cross-link training system and dashboard."""
        self.config = get_config()
        self.data_provider = DataProvider()
        self.trading_executor = TradingExecutor()
        self.orchestrator = TradingOrchestrator(
            data_provider=self.data_provider,
            enhanced_rl_training=True,
            model_registry={}
        )

        # The training system is created before the dashboard exists, so the
        # dashboard reference is filled in a few lines further down.
        self.training_system = EnhancedRealtimeTrainingSystem(
            orchestrator=self.orchestrator,
            data_provider=self.data_provider,
            dashboard=None
        )

        # Create dashboard with enhanced prediction visualization
        self.dashboard = create_clean_dashboard(
            data_provider=self.data_provider,
            orchestrator=self.orchestrator,
            trading_executor=self.trading_executor
        )

        # Connect training system and dashboard in both directions
        self.training_system.dashboard = self.dashboard
        self.dashboard.training_system = self.training_system

        # Test data
        self.test_symbols = ['ETH/USDT', 'BTC/USDT']
        self.prediction_count = 0

        logger.info("Model Prediction Tester initialized")

    @staticmethod
    def _status_label(actual_correct, expected_correct) -> str:
        """Return ``"PASS"`` when the recorded outcome matches the expectation, else ``"FAIL"``."""
        return "PASS" if actual_correct == expected_correct else "FAIL"

    def _reference_price(self, symbol: str) -> float:
        """Return the provider's current price for *symbol*, or the fallback if it is falsy."""
        return self.data_provider.get_current_price(symbol) or self.FALLBACK_PRICE

    def generate_mock_dqn_predictions(self, symbol: str, count: int = 10):
        """Generate *count* random DQN predictions and hand them to the training system.

        Each prediction consists of a random 100-element state vector, three
        random Q-values, the arg-max action and a confidence computed as the
        largest Q-value divided by the sum of all Q-values. Errors are logged
        (with traceback) and not re-raised.
        """
        try:
            current_price = self._reference_price(symbol)
            for i in range(count):
                # 100-dimensional state vector of uniform random values
                state = np.random.random(100)

                # Three uniform random Q-values; the best one picks the action
                q_values = [np.random.random(), np.random.random(), np.random.random()]
                action = np.argmax(q_values)
                confidence = max(q_values) / sum(q_values)

                # Jitter the price with zero-mean noise (standard deviation 20)
                pred_price = current_price + np.random.normal(0, 20)

                self.training_system.capture_dqn_prediction(
                    symbol=symbol,
                    state=state,
                    q_values=q_values,
                    action=action,
                    confidence=confidence,
                    price=pred_price
                )
                self.prediction_count += 1

                action_label = self.DQN_ACTION_LABELS[action]
                logger.info(f"Generated DQN prediction {i+1}/{count}: {symbol} action={action_label} confidence={confidence:.2f}")

                # Small delay between predictions
                time.sleep(0.1)
        except Exception as e:
            # logger.exception keeps the traceback, which logger.error dropped
            logger.exception(f"Error generating DQN predictions: {e}")

    def generate_mock_cnn_predictions(self, symbol: str, count: int = 8):
        """Generate *count* random CNN direction predictions and hand them to the training system.

        Direction codes are 0 (DOWN), 1 (SAME) and 2 (UP), drawn with
        probabilities 0.3 / 0.2 / 0.5. Errors are logged (with traceback) and
        not re-raised.
        """
        try:
            current_price = self._reference_price(symbol)
            for i in range(count):
                direction = np.random.choice([0, 1, 2], p=[0.3, 0.2, 0.5])
                confidence = 0.4 + np.random.random() * 0.5  # 0.4-0.9 confidence

                # Price offset whose sign and size follow the drawn direction
                if direction == 2:  # UP
                    price_change = np.random.uniform(5, 50)
                elif direction == 0:  # DOWN
                    price_change = -np.random.uniform(5, 50)
                else:  # SAME
                    price_change = np.random.uniform(-5, 5)
                predicted_price = current_price + price_change

                # 15x20 random matrix flattened into a 300-element vector
                features = np.random.random((15, 20)).flatten()

                self.training_system.capture_cnn_prediction(
                    symbol=symbol,
                    current_price=current_price,
                    predicted_price=predicted_price,
                    direction=direction,
                    confidence=confidence,
                    features=features
                )
                self.prediction_count += 1

                direction_label = self.CNN_DIRECTION_LABELS[direction]
                logger.info(f"Generated CNN prediction {i+1}/{count}: {symbol} direction={direction_label} confidence={confidence:.2f}")

                # Small delay between predictions
                time.sleep(0.2)
        except Exception as e:
            logger.exception(f"Error generating CNN predictions: {e}")

    def generate_mock_accuracy_data(self, symbol: str, count: int = 15):
        """Generate *count* random accuracy-validation records for *symbol*.

        Each record uses a random prediction type ('DQN' or 'CNN'), a random
        action code in {0, 1, 2}, a confidence in [0.3, 0.9) and a normally
        distributed price change (mean 0, standard deviation 0.01). Errors are
        logged (with traceback) and not re-raised.
        """
        try:
            # The price is looked up but not used below; the call is kept so
            # the provider is exercised exactly as before.
            current_price = self._reference_price(symbol)
            for i in range(count):
                prediction_type = np.random.choice(['DQN', 'CNN'])
                predicted_action = np.random.choice([0, 1, 2])
                confidence = 0.3 + np.random.random() * 0.6
                actual_price_change = np.random.normal(0, 0.01)

                self.training_system.validate_prediction_accuracy(
                    symbol=symbol,
                    prediction_type=prediction_type,
                    predicted_action=predicted_action,
                    actual_price_change=actual_price_change,
                    confidence=confidence
                )
                logger.info(f"Generated accuracy data {i+1}/{count}: {symbol} {prediction_type} action={predicted_action}")

                # Small delay
                time.sleep(0.1)
        except Exception as e:
            logger.exception(f"Error generating accuracy data: {e}")

    def run_prediction_generation_test(self):
        """Generate DQN, CNN and accuracy mock data for every test symbol and log a summary."""
        try:
            logger.info("Starting Model Prediction Visualization Test")
            logger.info("=" * 60)

            for symbol in self.test_symbols:
                logger.info(f"\nGenerating predictions for {symbol}...")

                logger.info(f"Generating DQN predictions for {symbol}...")
                self.generate_mock_dqn_predictions(symbol, count=12)

                logger.info(f"Generating CNN predictions for {symbol}...")
                self.generate_mock_cnn_predictions(symbol, count=8)

                logger.info(f"Generating accuracy data for {symbol}...")
                self.generate_mock_accuracy_data(symbol, count=20)

                summary = self.training_system.get_prediction_summary(symbol)
                logger.info(f"Prediction summary for {symbol}: {summary}")

            # Log total statistics
            training_stats = self.training_system.get_training_statistics()
            logger.info("\nTraining System Statistics:")
            logger.info(f"Total predictions generated: {self.prediction_count}")
            logger.info(f"Prediction stats: {training_stats.get('prediction_stats', {})}")

            logger.info("\n" + "=" * 60)
            logger.info("Prediction generation test completed successfully!")
            logger.info("Dashboard should now show enhanced model predictions on the price chart:")
            logger.info("- Green/Red arrows for DQN BUY/SELL predictions")
            logger.info("- Gray circles for DQN HOLD predictions")
            logger.info("- Colored trend lines for CNN price direction predictions")
            logger.info("- Diamond markers for CNN prediction targets")
            logger.info("- Green/Red X marks for correct/incorrect prediction feedback")
            logger.info("- Hover tooltips showing confidence, Q-values, and accuracy scores")
        except Exception as e:
            logger.exception(f"Error in prediction generation test: {e}")

    def start_dashboard_with_predictions(self, host='127.0.0.1', port=8051):
        """Start mock-prediction generation, the training system and the dashboard server.

        Prediction generation runs on a daemon thread; the call to the
        dashboard's ``run_server`` is made on the calling thread. Errors are
        logged (with traceback) and not re-raised.
        """
        try:
            logger.info(f"Starting dashboard with model predictions at http://{host}:{port}")

            # Run prediction generation in background
            import threading
            pred_thread = threading.Thread(target=self.run_prediction_generation_test, daemon=True)
            pred_thread.start()

            # Start training system
            self.training_system.start_training()

            # Start dashboard
            self.dashboard.run_server(host=host, port=port, debug=False)
        except Exception as e:
            logger.exception(f"Error starting dashboard with predictions: {e}")

    def test_prediction_accuracy_validation(self):
        """Run fixed DQN/CNN cases through accuracy validation and report PASS/FAIL per case.

        For every case the most recent entry of the training system's
        ``prediction_accuracy_history['ETH/USDT']`` is compared with the
        expected correctness flag.
        """
        try:
            logger.info("Testing prediction accuracy validation...")

            # (prediction_type, action, price_change, confidence, expected_correct)
            test_cases = [
                ('DQN', 0, 0.01, 0.8, True),     # BUY + price up = correct
                ('DQN', 1, -0.01, 0.7, True),    # SELL + price down = correct
                ('DQN', 2, 0.0005, 0.6, True),   # HOLD + no change = correct
                ('DQN', 0, -0.01, 0.8, False),   # BUY + price down = incorrect
                ('CNN', 2, 0.01, 0.9, True),     # UP + price up = correct
                ('CNN', 0, -0.01, 0.8, True),    # DOWN + price down = correct
                ('CNN', 1, 0.0005, 0.7, True),   # SAME + no change = correct
                ('CNN', 2, -0.01, 0.9, False),   # UP + price down = incorrect
            ]

            mismatches = 0
            for prediction_type, action, price_change, confidence, expected_correct in test_cases:
                self.training_system.validate_prediction_accuracy(
                    symbol='ETH/USDT',
                    prediction_type=prediction_type,
                    predicted_action=action,
                    actual_price_change=price_change,
                    confidence=confidence
                )

                history = self.training_system.prediction_accuracy_history['ETH/USDT']
                if not history:
                    logger.warning(f"No accuracy record available for {prediction_type} action={action}")
                    continue

                # NOTE(review): this assumes the validation call above appended
                # the newest record; if it did not, an older entry is compared.
                latest = list(history)[-1]
                actual_correct = latest['correct']
                status = self._status_label(actual_correct, expected_correct)
                if status == "FAIL":
                    mismatches += 1
                logger.info(f"{status} {prediction_type} action={action} change={price_change:.4f} -> correct={actual_correct}")

            if mismatches:
                logger.warning(f"{mismatches} accuracy validation case(s) did not match the expected outcome")
            logger.info("Prediction accuracy validation test completed")
        except Exception as e:
            logger.exception(f"Error testing prediction accuracy validation: {e}")
def main():
    """Script entry point.

    Builds the tester, runs the accuracy-validation check, then starts the
    dashboard with mock predictions. A keyboard interrupt is logged as a
    normal stop; any other exception is logged as an error.
    """
    startup_messages = (
        "\nStarting dashboard with enhanced model prediction visualization...",
        "Visit http://127.0.0.1:8051 to see the enhanced price chart with model predictions",
    )
    try:
        prediction_tester = ModelPredictionTester()

        # Accuracy validation runs before the dashboard is started.
        prediction_tester.test_prediction_accuracy_validation()

        for message in startup_messages:
            logger.info(message)
        prediction_tester.start_dashboard_with_predictions()
    except KeyboardInterrupt:
        logger.info("Test interrupted by user")
    except Exception as exc:
        logger.error(f"Error in main test: {exc}")


if __name__ == "__main__":
    main()