309 lines
13 KiB
Python
309 lines
13 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Test Model Predictions Visualization
|
|
|
|
This script demonstrates the enhanced model prediction visualization system
|
|
that shows DQN actions, CNN price predictions, and accuracy feedback on the price chart.
|
|
|
|
Features tested:
|
|
- DQN action predictions (BUY/SELL/HOLD) as directional arrows with confidence-based sizing
|
|
- CNN price direction predictions as trend lines with target markers
|
|
- Prediction accuracy feedback with color-coded results
|
|
- Real-time prediction tracking and storage
|
|
- Mock prediction generation for demonstration
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
import numpy as np
|
|
from datetime import datetime, timedelta
|
|
from typing import Dict, List, Optional
|
|
|
|
from core.config import get_config
|
|
from core.data_provider import DataProvider
|
|
from core.orchestrator import TradingOrchestrator
|
|
from core.trading_executor import TradingExecutor
|
|
from web.clean_dashboard import create_clean_dashboard
|
|
from enhanced_realtime_training import EnhancedRealtimeTrainingSystem
|
|
|
|
# Configure root logging for the whole script (INFO and above), then obtain
# this module's named logger. Note that basicConfig is a no-op when the root
# logger already has handlers, so an embedding application's setup wins.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
class ModelPredictionTester:
    """Feed synthetic model predictions into the dashboard for visual testing.

    Builds the data provider, trading executor, orchestrator, enhanced
    realtime training system and clean dashboard, cross-links the training
    system and the dashboard, and exposes helpers that push mock DQN/CNN
    predictions and accuracy samples into the training system.
    """

    # Labels used only for log output, indexed by the integer action /
    # direction sent to the training system (same encoding as the cases in
    # test_prediction_accuracy_validation).
    DQN_ACTION_LABELS = ('BUY', 'SELL', 'HOLD')
    CNN_DIRECTION_LABELS = ('DOWN', 'SAME', 'UP')

    # Price used when the data provider returns no (or a falsy) current price.
    FALLBACK_PRICE = 2400.0

    def __init__(self):
        self.config = get_config()
        self.data_provider = DataProvider()
        self.trading_executor = TradingExecutor()
        self.orchestrator = TradingOrchestrator(
            data_provider=self.data_provider,
            enhanced_rl_training=True,
            model_registry={}
        )

        # The dashboard does not exist yet, so the training system starts with
        # dashboard=None and is wired to it below.
        self.training_system = EnhancedRealtimeTrainingSystem(
            orchestrator=self.orchestrator,
            data_provider=self.data_provider,
            dashboard=None
        )

        # Dashboard with the enhanced prediction visualization.
        self.dashboard = create_clean_dashboard(
            data_provider=self.data_provider,
            orchestrator=self.orchestrator,
            trading_executor=self.trading_executor
        )

        # Two-way link so each side can reach the other.
        self.training_system.dashboard = self.dashboard
        self.dashboard.training_system = self.training_system

        self.test_symbols = ['ETH/USDT', 'BTC/USDT']
        # Running total of DQN + CNN predictions pushed by this tester.
        self.prediction_count = 0

        logger.info("Model Prediction Tester initialized")

    def _current_price(self, symbol: str) -> float:
        """Return the provider's current price for *symbol*, or FALLBACK_PRICE if it is falsy."""
        return self.data_provider.get_current_price(symbol) or self.FALLBACK_PRICE

    def _latest_accuracy_record(self, symbol: str):
        """Return the newest accuracy record for *symbol*, or None if there is none.

        NOTE(review): assumes ``training_system.prediction_accuracy_history``
        is a dict-like mapping of symbol -> ordered iterable of record dicts,
        newest last. The ``list()`` copy keeps this working for any iterable.
        """
        history = self.training_system.prediction_accuracy_history.get(symbol)
        if not history:
            return None
        return list(history)[-1]

    def generate_mock_dqn_predictions(self, symbol: str, count: int = 10,
                                      delay: float = 0.1) -> None:
        """Push *count* random DQN predictions for *symbol* into the training system.

        Args:
            symbol: Trading pair, e.g. 'ETH/USDT'.
            count: Number of predictions to generate.
            delay: Seconds to sleep between predictions.

        Any error is logged with its traceback and ends the batch early; it
        is not re-raised.
        """
        try:
            current_price = self._current_price(symbol)

            for i in range(count):
                # Mock 100-dimensional state vector.
                state = np.random.random(100)

                # One Q-value per action index (0=BUY, 1=SELL, 2=HOLD).
                q_values = [np.random.random(), np.random.random(), np.random.random()]
                # Native int so downstream storage/serialization never sees np.int64.
                action = int(np.argmax(q_values))

                # Share of the total Q-mass held by the best action (1/3 .. 1).
                # Guard the all-zero draw rather than letting ZeroDivisionError
                # abort the whole batch.
                q_total = sum(q_values)
                confidence = max(q_values) / q_total if q_total > 0 else 1.0 / len(q_values)

                # Jitter the price around the current price (std dev 20).
                pred_price = float(current_price + np.random.normal(0, 20))

                self.training_system.capture_dqn_prediction(
                    symbol=symbol,
                    state=state,
                    q_values=q_values,
                    action=action,
                    confidence=confidence,
                    price=pred_price
                )
                self.prediction_count += 1

                logger.info(f"Generated DQN prediction {i+1}/{count}: {symbol} "
                            f"action={self.DQN_ACTION_LABELS[action]} confidence={confidence:.2f}")

                time.sleep(delay)

        except Exception as e:
            logger.exception("Error generating DQN predictions: %s", e)

    def generate_mock_cnn_predictions(self, symbol: str, count: int = 8,
                                      delay: float = 0.2) -> None:
        """Push *count* random CNN direction predictions for *symbol* into the training system.

        Args:
            symbol: Trading pair, e.g. 'ETH/USDT'.
            count: Number of predictions to generate.
            delay: Seconds to sleep between predictions.

        Any error is logged with its traceback and ends the batch early; it
        is not re-raised.
        """
        try:
            current_price = self._current_price(symbol)

            for i in range(count):
                # Direction index 0=DOWN, 1=SAME, 2=UP, weighted slightly bullish.
                direction = int(np.random.choice([0, 1, 2], p=[0.3, 0.2, 0.5]))
                confidence = 0.4 + np.random.random() * 0.5  # in [0.4, 0.9)

                # Move the target price in the predicted direction.
                if direction == 2:  # UP
                    price_change = np.random.uniform(5, 50)
                elif direction == 0:  # DOWN
                    price_change = -np.random.uniform(5, 50)
                else:  # SAME
                    price_change = np.random.uniform(-5, 5)
                predicted_price = float(current_price + price_change)

                # Mock 15x20 feature map flattened to a 300-element vector.
                features = np.random.random((15, 20)).flatten()

                self.training_system.capture_cnn_prediction(
                    symbol=symbol,
                    current_price=current_price,
                    predicted_price=predicted_price,
                    direction=direction,
                    confidence=confidence,
                    features=features
                )
                self.prediction_count += 1

                logger.info(f"Generated CNN prediction {i+1}/{count}: {symbol} "
                            f"direction={self.CNN_DIRECTION_LABELS[direction]} confidence={confidence:.2f}")

                time.sleep(delay)

        except Exception as e:
            logger.exception("Error generating CNN predictions: %s", e)

    def generate_mock_accuracy_data(self, symbol: str, count: int = 15,
                                    delay: float = 0.1) -> None:
        """Feed *count* random accuracy samples for *symbol* to validate_prediction_accuracy.

        Args:
            symbol: Trading pair, e.g. 'ETH/USDT'.
            count: Number of samples to generate.
            delay: Seconds to sleep between samples.

        Any error is logged with its traceback and ends the batch early; it
        is not re-raised.
        """
        try:
            for i in range(count):
                # Native str/int rather than np.str_/np.int64.
                prediction_type = str(np.random.choice(['DQN', 'CNN']))
                predicted_action = int(np.random.choice([0, 1, 2]))
                confidence = 0.3 + np.random.random() * 0.6  # in [0.3, 0.9)

                # Fractional price move, normal with std dev 0.01 (~1%).
                actual_price_change = float(np.random.normal(0, 0.01))

                self.training_system.validate_prediction_accuracy(
                    symbol=symbol,
                    prediction_type=prediction_type,
                    predicted_action=predicted_action,
                    actual_price_change=actual_price_change,
                    confidence=confidence
                )

                logger.info(f"Generated accuracy data {i+1}/{count}: {symbol} {prediction_type} action={predicted_action}")

                time.sleep(delay)

        except Exception as e:
            logger.exception("Error generating accuracy data: %s", e)

    def run_prediction_generation_test(self) -> None:
        """Generate DQN, CNN and accuracy mock data for every test symbol, logging summaries.

        Run on a background thread by start_dashboard_with_predictions. Any
        error is logged with its traceback and not re-raised.
        """
        try:
            logger.info("Starting Model Prediction Visualization Test")
            logger.info("=" * 60)

            for symbol in self.test_symbols:
                logger.info(f"\nGenerating predictions for {symbol}...")

                logger.info(f"Generating DQN predictions for {symbol}...")
                self.generate_mock_dqn_predictions(symbol, count=12)

                logger.info(f"Generating CNN predictions for {symbol}...")
                self.generate_mock_cnn_predictions(symbol, count=8)

                logger.info(f"Generating accuracy data for {symbol}...")
                self.generate_mock_accuracy_data(symbol, count=20)

                summary = self.training_system.get_prediction_summary(symbol)
                logger.info(f"Prediction summary for {symbol}: {summary}")

            training_stats = self.training_system.get_training_statistics()
            logger.info("\nTraining System Statistics:")
            logger.info(f"Total predictions generated: {self.prediction_count}")
            logger.info(f"Prediction stats: {training_stats.get('prediction_stats', {})}")

            logger.info("\n" + "=" * 60)
            logger.info("Prediction generation test completed successfully!")
            logger.info("Dashboard should now show enhanced model predictions on the price chart:")
            logger.info("- Green/Red arrows for DQN BUY/SELL predictions")
            logger.info("- Gray circles for DQN HOLD predictions")
            logger.info("- Colored trend lines for CNN price direction predictions")
            logger.info("- Diamond markers for CNN prediction targets")
            logger.info("- Green/Red X marks for correct/incorrect prediction feedback")
            logger.info("- Hover tooltips showing confidence, Q-values, and accuracy scores")

        except Exception as e:
            logger.exception("Error in prediction generation test: %s", e)

    def start_dashboard_with_predictions(self, host='127.0.0.1', port=8051):
        """Start mock prediction generation and the training system, then serve the dashboard.

        Prediction generation runs on a daemon thread, so it does not keep the
        process alive after the server stops. ``run_server`` is called last;
        presumably it blocks until the server exits. Any error is logged with
        its traceback and not re-raised.
        """
        import threading

        try:
            logger.info(f"Starting dashboard with model predictions at http://{host}:{port}")

            generator = threading.Thread(
                target=self.run_prediction_generation_test,
                name="mock-prediction-generator",
                daemon=True,
            )
            generator.start()

            self.training_system.start_training()

            self.dashboard.run_server(host=host, port=port, debug=False)

        except Exception as e:
            logger.exception("Error starting dashboard with predictions: %s", e)

    def test_prediction_accuracy_validation(self) -> bool:
        """Feed known DQN/CNN cases to validate_prediction_accuracy and check the verdicts.

        Each case is ``(prediction_type, action, price_change, confidence,
        expected_correct)``. After each call, the newest record in the
        'ETH/USDT' accuracy history is compared with the expected ``correct``
        flag and logged with a check mark or cross. If no new record appears,
        the case is reported as undetermined; the previous record is not
        reused, because that would give a stale verdict.

        Returns:
            True if every case produced a new record whose ``correct`` flag
            matched the expectation; False otherwise, including on error.
        """
        symbol = 'ETH/USDT'
        try:
            logger.info("Testing prediction accuracy validation...")

            # DQN actions: 0=BUY, 1=SELL, 2=HOLD. CNN directions: 0=DOWN, 1=SAME, 2=UP.
            # NOTE(review): whether a 0.0005 move counts as "no change" depends
            # on the training system's threshold; confirm.
            test_cases = [
                ('DQN', 0, 0.01, 0.8, True),     # BUY + price up = correct
                ('DQN', 1, -0.01, 0.7, True),    # SELL + price down = correct
                ('DQN', 2, 0.0005, 0.6, True),   # HOLD + no change = correct
                ('DQN', 0, -0.01, 0.8, False),   # BUY + price down = incorrect
                ('CNN', 2, 0.01, 0.9, True),     # UP + price up = correct
                ('CNN', 0, -0.01, 0.8, True),    # DOWN + price down = correct
                ('CNN', 1, 0.0005, 0.7, True),   # SAME + no change = correct
                ('CNN', 2, -0.01, 0.9, False),   # UP + price down = incorrect
            ]

            matched = 0
            for prediction_type, action, price_change, confidence, expected_correct in test_cases:
                previous = self._latest_accuracy_record(symbol)

                self.training_system.validate_prediction_accuracy(
                    symbol=symbol,
                    prediction_type=prediction_type,
                    predicted_action=action,
                    actual_price_change=price_change,
                    confidence=confidence
                )

                latest = self._latest_accuracy_record(symbol)
                if latest is None or latest is previous:
                    # Nothing new was recorded, so the verdict for this case is unknown.
                    logger.warning(f"? {prediction_type} action={action} change={price_change:.4f} -> no new accuracy record")
                    continue

                actual_correct = latest['correct']
                if actual_correct == expected_correct:
                    matched += 1
                    status = "✓"
                else:
                    status = "✗"
                logger.info(f"{status} {prediction_type} action={action} change={price_change:.4f} -> correct={actual_correct}")

            logger.info(f"{matched}/{len(test_cases)} accuracy validation cases matched expectations")
            logger.info("Prediction accuracy validation test completed")
            return matched == len(test_cases)

        except Exception as e:
            logger.exception("Error testing prediction accuracy validation: %s", e)
            return False
|
|
|
|
def main():
    """Run the accuracy-validation self-check, then serve the prediction dashboard.

    Ctrl+C is logged as a user interrupt. Any other exception is logged with
    its traceback, and the process exits with status 1 so that CI or a calling
    shell can see the failure.
    """
    # Single source of truth for the address that is both logged and served.
    host, port = '127.0.0.1', 8051
    try:
        tester = ModelPredictionTester()

        # Run the accuracy validation test first. Only an explicit False is
        # treated as failure, so a None return (no result reported) is not
        # flagged as a mismatch.
        if tester.test_prediction_accuracy_validation() is False:
            logger.warning("Some prediction accuracy validation cases did not match expectations")

        logger.info("\nStarting dashboard with enhanced model prediction visualization...")
        logger.info(f"Visit http://{host}:{port} to see the enhanced price chart with model predictions")

        tester.start_dashboard_with_predictions(host=host, port=port)

    except KeyboardInterrupt:
        logger.info("Test interrupted by user")
    except Exception as e:
        logger.exception("Error in main test: %s", e)
        raise SystemExit(1) from e


if __name__ == "__main__":
    main()