display predictions

This commit is contained in:
Dobromir Popov
2025-06-27 01:12:55 +03:00
parent 63f26a6749
commit 97ea27ea84
3 changed files with 1229 additions and 16 deletions

View File

@ -71,6 +71,34 @@ class EnhancedRealtimeTrainingSystem:
'validation': 0.0
}
# Model prediction tracking - NEW for dashboard visualization
self.recent_dqn_predictions = {
'ETH/USDT': deque(maxlen=100),
'BTC/USDT': deque(maxlen=100)
}
self.recent_cnn_predictions = {
'ETH/USDT': deque(maxlen=50),
'BTC/USDT': deque(maxlen=50)
}
self.prediction_accuracy_history = {
'ETH/USDT': deque(maxlen=200),
'BTC/USDT': deque(maxlen=200)
}
# FIXED: Forward-looking prediction system
self.pending_predictions = {
'ETH/USDT': deque(maxlen=100), # Predictions waiting for validation
'BTC/USDT': deque(maxlen=100)
}
self.last_prediction_time = {
'ETH/USDT': 0,
'BTC/USDT': 0
}
self.prediction_intervals = {
'dqn': 30, # Make DQN prediction every 30 seconds
'cnn': 60 # Make CNN prediction every 60 seconds
}
# Real-time data streams
self.real_time_data = {
'ticks': deque(maxlen=1000),
@ -146,24 +174,27 @@ class EnhancedRealtimeTrainingSystem:
current_time = time.time()
self.training_iteration += 1
# 1. DQN Training (every 5 seconds with enough data)
# 1. FORWARD-LOOKING PREDICTIONS - Generate real predictions for future validation
self.generate_forward_looking_predictions()
# 2. DQN Training (every 5 seconds with enough data)
if (current_time - self.last_training_times['dqn'] > self.training_config['dqn_training_interval']
and len(self.experience_buffer) >= self.training_config['min_training_samples']):
self._perform_enhanced_dqn_training()
self.last_training_times['dqn'] = current_time
# 2. CNN Training (every 10 seconds)
# 3. CNN Training (every 10 seconds)
if (current_time - self.last_training_times['cnn'] > self.training_config['cnn_training_interval']
and len(self.real_time_data['ohlcv_1m']) >= 20):
self._perform_enhanced_cnn_training()
self.last_training_times['cnn'] = current_time
# 3. Validation (every minute)
# 4. Validation (every minute)
if current_time - self.last_training_times['validation'] > self.training_config['validation_interval']:
self._perform_validation()
self.last_training_times['validation'] = current_time
# 4. Adaptive learning rate adjustment
# 5. Adaptive learning rate adjustment
if self.training_iteration % 100 == 0:
self._adapt_learning_parameters()
@ -911,6 +942,11 @@ class EnhancedRealtimeTrainingSystem:
'dqn_loss_count': len(self.performance_history['dqn_losses']),
'cnn_loss_count': len(self.performance_history['cnn_losses']),
'validation_count': len(self.performance_history['validation_scores'])
},
'prediction_stats': {
'dqn_predictions': {symbol: len(predictions) for symbol, predictions in self.recent_dqn_predictions.items()},
'cnn_predictions': {symbol: len(predictions) for symbol, predictions in self.recent_cnn_predictions.items()},
'accuracy_history': {symbol: len(history) for symbol, history in self.prediction_accuracy_history.items()}
}
}
@ -927,4 +963,492 @@ class EnhancedRealtimeTrainingSystem:
except Exception as e:
logger.error(f"Error getting training statistics: {e}")
return {'error': str(e)}
return {'error': str(e)}
def capture_dqn_prediction(self, symbol: str, state: np.ndarray, q_values: List[float], action: int, confidence: float, price: float):
"""Capture DQN prediction for dashboard visualization"""
try:
prediction = {
'timestamp': datetime.now(),
'symbol': symbol,
'state': state.tolist() if hasattr(state, 'tolist') else state,
'q_values': q_values,
'action': action, # 0=BUY, 1=SELL, 2=HOLD
'confidence': confidence,
'price': price
}
if symbol in self.recent_dqn_predictions:
self.recent_dqn_predictions[symbol].append(prediction)
logger.debug(f"DQN prediction captured: {symbol} action={action} confidence={confidence:.2f}")
except Exception as e:
logger.debug(f"Error capturing DQN prediction: {e}")
def capture_cnn_prediction(self, symbol: str, current_price: float, predicted_price: float, direction: int, confidence: float, features: Optional[np.ndarray] = None):
"""Capture CNN prediction for dashboard visualization"""
try:
prediction = {
'timestamp': datetime.now(),
'symbol': symbol,
'current_price': current_price,
'predicted_price': predicted_price,
'direction': direction, # 0=DOWN, 1=SAME, 2=UP
'confidence': confidence,
'features': features.tolist() if features is not None and hasattr(features, 'tolist') else None
}
if symbol in self.recent_cnn_predictions:
self.recent_cnn_predictions[symbol].append(prediction)
logger.debug(f"CNN prediction captured: {symbol} direction={direction} confidence={confidence:.2f}")
except Exception as e:
logger.debug(f"Error capturing CNN prediction: {e}")
def validate_prediction_accuracy(self, symbol: str, prediction_type: str, predicted_action: int, actual_price_change: float, confidence: float):
"""Validate prediction accuracy and store results"""
try:
# Determine if prediction was correct
was_correct = False
if prediction_type == 'DQN':
# For DQN: BUY (0) should be followed by price increase, SELL (1) by decrease
if predicted_action == 0 and actual_price_change > 0.001: # BUY + price up
was_correct = True
elif predicted_action == 1 and actual_price_change < -0.001: # SELL + price down
was_correct = True
elif predicted_action == 2 and abs(actual_price_change) <= 0.001: # HOLD + no change
was_correct = True
elif prediction_type == 'CNN':
# For CNN: direction prediction accuracy
if predicted_action == 2 and actual_price_change > 0.001: # UP + price up
was_correct = True
elif predicted_action == 0 and actual_price_change < -0.001: # DOWN + price down
was_correct = True
elif predicted_action == 1 and abs(actual_price_change) <= 0.001: # SAME + no change
was_correct = True
# Calculate accuracy score based on confidence and correctness
accuracy_score = confidence if was_correct else (1.0 - confidence)
accuracy_data = {
'timestamp': datetime.now(),
'symbol': symbol,
'prediction_type': prediction_type,
'correct': was_correct,
'accuracy_score': accuracy_score,
'confidence': confidence,
'actual_price_change': actual_price_change,
'predicted_action': predicted_action
}
if symbol in self.prediction_accuracy_history:
self.prediction_accuracy_history[symbol].append(accuracy_data)
logger.debug(f"Prediction accuracy validated: {symbol} {prediction_type} correct={was_correct} score={accuracy_score:.2f}")
except Exception as e:
logger.debug(f"Error validating prediction accuracy: {e}")
def get_prediction_summary(self, symbol: str) -> Dict[str, Any]:
"""Get prediction summary for a symbol"""
try:
summary = {
'symbol': symbol,
'dqn_predictions': len(self.recent_dqn_predictions.get(symbol, [])),
'cnn_predictions': len(self.recent_cnn_predictions.get(symbol, [])),
'accuracy_history': len(self.prediction_accuracy_history.get(symbol, [])),
'pending_predictions': len(self.pending_predictions.get(symbol, []))
}
# Calculate accuracy statistics
if symbol in self.prediction_accuracy_history and self.prediction_accuracy_history[symbol]:
accuracy_data = list(self.prediction_accuracy_history[symbol])
total_predictions = len(accuracy_data)
correct_predictions = sum(1 for acc in accuracy_data if acc['correct'])
summary['total_predictions'] = total_predictions
summary['correct_predictions'] = correct_predictions
summary['accuracy_rate'] = correct_predictions / total_predictions if total_predictions > 0 else 0.0
# Calculate accuracy by prediction type
dqn_accuracy_data = [acc for acc in accuracy_data if acc['prediction_type'] == 'DQN']
cnn_accuracy_data = [acc for acc in accuracy_data if acc['prediction_type'] == 'CNN']
if dqn_accuracy_data:
dqn_correct = sum(1 for acc in dqn_accuracy_data if acc['correct'])
summary['dqn_accuracy_rate'] = dqn_correct / len(dqn_accuracy_data)
else:
summary['dqn_accuracy_rate'] = 0.0
if cnn_accuracy_data:
cnn_correct = sum(1 for acc in cnn_accuracy_data if acc['correct'])
summary['cnn_accuracy_rate'] = cnn_correct / len(cnn_accuracy_data)
else:
summary['cnn_accuracy_rate'] = 0.0
return summary
except Exception as e:
logger.error(f"Error getting prediction summary: {e}")
return {'error': str(e)}
def generate_forward_looking_predictions(self):
"""Generate forward-looking predictions based on current market data"""
try:
current_time = time.time()
for symbol in ['ETH/USDT', 'BTC/USDT']:
# Check if it's time to make new predictions
time_since_last = current_time - self.last_prediction_time.get(symbol, 0)
# Generate DQN prediction every 30 seconds
if time_since_last >= self.prediction_intervals['dqn']:
self._generate_forward_dqn_prediction(symbol, current_time)
# Generate CNN prediction every 60 seconds
if time_since_last >= self.prediction_intervals['cnn']:
self._generate_forward_cnn_prediction(symbol, current_time)
# Validate pending predictions
self._validate_pending_predictions(symbol, current_time)
except Exception as e:
logger.error(f"Error generating forward-looking predictions: {e}")
def _generate_forward_dqn_prediction(self, symbol: str, current_time: float):
"""Generate a DQN prediction for future price movement"""
try:
# Get current market state (only historical data)
current_state = self._build_comprehensive_state()
current_price = self._get_current_price_from_data(symbol)
if current_price is None:
return
# Use DQN model to predict action (if available)
if (self.orchestrator and hasattr(self.orchestrator, 'rl_agent')
and self.orchestrator.rl_agent):
# Get Q-values from model
q_values = self.orchestrator.rl_agent.act(current_state, return_q_values=True)
if isinstance(q_values, tuple):
action, q_vals = q_values
q_values = q_vals.tolist() if hasattr(q_vals, 'tolist') else [0, 0, 0]
else:
action = q_values
q_values = [0.33, 0.33, 0.34] # Default uniform distribution
confidence = max(q_values) / sum(q_values) if sum(q_values) > 0 else 0.33
else:
# Fallback to technical analysis-based prediction
action, q_values, confidence = self._technical_analysis_prediction(symbol)
# Create forward-looking prediction
prediction_time = datetime.now()
target_time = prediction_time + timedelta(minutes=5) # Predict 5 minutes ahead
prediction = {
'id': f"dqn_{symbol}_{int(current_time)}",
'type': 'DQN',
'symbol': symbol,
'prediction_time': prediction_time,
'target_time': target_time,
'current_price': current_price,
'predicted_action': action,
'q_values': q_values,
'confidence': confidence,
'state': current_state.tolist() if hasattr(current_state, 'tolist') else current_state,
'validated': False
}
# Add to pending predictions for future validation
if symbol in self.pending_predictions:
self.pending_predictions[symbol].append(prediction)
# Add to recent predictions for display (only if confident enough)
if confidence > 0.4:
display_prediction = {
'timestamp': prediction_time,
'price': current_price,
'action': action,
'confidence': confidence,
'q_values': q_values
}
if symbol in self.recent_dqn_predictions:
self.recent_dqn_predictions[symbol].append(display_prediction)
self.last_prediction_time[symbol] = current_time
logger.info(f"Forward DQN prediction: {symbol} action={['BUY','SELL','HOLD'][action]} confidence={confidence:.2f} target={target_time.strftime('%H:%M:%S')}")
except Exception as e:
logger.error(f"Error generating forward DQN prediction: {e}")
def _generate_forward_cnn_prediction(self, symbol: str, current_time: float):
"""Generate a CNN prediction for future price direction"""
try:
# Get current price and historical sequence (only past data)
current_price = self._get_current_price_from_data(symbol)
price_sequence = self._get_historical_price_sequence(symbol, periods=15)
if current_price is None or len(price_sequence) < 15:
return
# Use CNN model to predict direction (if available)
if (self.orchestrator and hasattr(self.orchestrator, 'cnn_model')
and self.orchestrator.cnn_model):
# Prepare features for CNN
features = self._prepare_cnn_features(price_sequence)
try:
# Get prediction from CNN model
prediction_output = self.orchestrator.cnn_model.predict(features)
if hasattr(prediction_output, 'tolist'):
pred_probs = prediction_output.tolist()
else:
pred_probs = [0.33, 0.33, 0.34] # Default
direction = int(np.argmax(pred_probs)) # 0=DOWN, 1=SAME, 2=UP
confidence = max(pred_probs)
except Exception as e:
logger.debug(f"CNN model prediction failed: {e}")
direction, confidence = self._technical_direction_prediction(symbol)
else:
# Fallback to technical analysis
direction, confidence = self._technical_direction_prediction(symbol)
# Calculate predicted price based on direction
price_change_percent = self._estimate_price_change(direction, confidence)
predicted_price = current_price * (1 + price_change_percent)
# Create forward-looking prediction
prediction_time = datetime.now()
target_time = prediction_time + timedelta(minutes=10) # Predict 10 minutes ahead
prediction = {
'id': f"cnn_{symbol}_{int(current_time)}",
'type': 'CNN',
'symbol': symbol,
'prediction_time': prediction_time,
'target_time': target_time,
'current_price': current_price,
'predicted_price': predicted_price,
'direction': direction,
'confidence': confidence,
'features': features.tolist() if hasattr(features, 'tolist') else None,
'validated': False
}
# Add to pending predictions for future validation
if symbol in self.pending_predictions:
self.pending_predictions[symbol].append(prediction)
# Add to recent predictions for display (only if confident enough)
if confidence > 0.5:
display_prediction = {
'timestamp': prediction_time,
'current_price': current_price,
'predicted_price': predicted_price,
'direction': direction,
'confidence': confidence
}
if symbol in self.recent_cnn_predictions:
self.recent_cnn_predictions[symbol].append(display_prediction)
logger.info(f"Forward CNN prediction: {symbol} direction={['DOWN','SAME','UP'][direction]} confidence={confidence:.2f} target={target_time.strftime('%H:%M:%S')}")
except Exception as e:
logger.error(f"Error generating forward CNN prediction: {e}")
def _validate_pending_predictions(self, symbol: str, current_time: float):
"""Validate pending predictions when their target time arrives"""
try:
if symbol not in self.pending_predictions:
return
current_datetime = datetime.now()
validated_predictions = []
# Check each pending prediction
for prediction in list(self.pending_predictions[symbol]):
target_time = prediction['target_time']
# If target time has passed, validate the prediction
if current_datetime >= target_time:
actual_price = self._get_current_price_from_data(symbol)
if actual_price is not None:
# Calculate actual price change
predicted_price = prediction.get('predicted_price', prediction['current_price'])
actual_change = (actual_price - prediction['current_price']) / prediction['current_price']
predicted_change = (predicted_price - prediction['current_price']) / prediction['current_price']
# Validate based on prediction type
if prediction['type'] == 'DQN':
was_correct = self._validate_dqn_prediction(prediction, actual_change)
else: # CNN
was_correct = self._validate_cnn_prediction(prediction, actual_change)
# Store accuracy result
accuracy_data = {
'timestamp': current_datetime,
'symbol': symbol,
'prediction_type': prediction['type'],
'correct': was_correct,
'accuracy_score': prediction['confidence'] if was_correct else (1.0 - prediction['confidence']),
'confidence': prediction['confidence'],
'actual_price_change': actual_change,
'predicted_action': prediction.get('predicted_action', prediction.get('direction', 0)),
'actual_price': actual_price
}
if symbol in self.prediction_accuracy_history:
self.prediction_accuracy_history[symbol].append(accuracy_data)
validated_predictions.append(prediction['id'])
logger.info(f"Validated {prediction['type']} prediction: {symbol} correct={was_correct} confidence={prediction['confidence']:.2f}")
# Remove validated predictions from pending list
if validated_predictions:
self.pending_predictions[symbol] = deque([
p for p in self.pending_predictions[symbol]
if p['id'] not in validated_predictions
], maxlen=100)
except Exception as e:
logger.error(f"Error validating pending predictions: {e}")
def _validate_dqn_prediction(self, prediction: Dict, actual_change: float) -> bool:
"""Validate DQN action prediction"""
predicted_action = prediction['predicted_action']
threshold = 0.005 # 0.5% threshold for significant movement
if predicted_action == 0: # BUY prediction
return actual_change > threshold
elif predicted_action == 1: # SELL prediction
return actual_change < -threshold
else: # HOLD prediction
return abs(actual_change) <= threshold
def _validate_cnn_prediction(self, prediction: Dict, actual_change: float) -> bool:
"""Validate CNN direction prediction"""
predicted_direction = prediction['direction']
threshold = 0.002 # 0.2% threshold for direction
if predicted_direction == 2: # UP prediction
return actual_change > threshold
elif predicted_direction == 0: # DOWN prediction
return actual_change < -threshold
else: # SAME prediction
return abs(actual_change) <= threshold
def _get_current_price_from_data(self, symbol: str) -> Optional[float]:
"""Get current price from real-time data streams"""
try:
if len(self.real_time_data['ohlcv_1m']) > 0:
return self.real_time_data['ohlcv_1m'][-1]['close']
return None
except Exception as e:
logger.debug(f"Error getting current price: {e}")
return None
def _get_historical_price_sequence(self, symbol: str, periods: int = 15) -> List[float]:
"""Get historical price sequence for CNN features"""
try:
if len(self.real_time_data['ohlcv_1m']) >= periods:
return [bar['close'] for bar in list(self.real_time_data['ohlcv_1m'])[-periods:]]
return []
except Exception as e:
logger.debug(f"Error getting price sequence: {e}")
return []
def _technical_analysis_prediction(self, symbol: str) -> Tuple[int, List[float], float]:
"""Fallback technical analysis prediction for DQN"""
try:
# Simple momentum-based prediction
if len(self.real_time_data['ohlcv_1m']) >= 5:
recent_prices = [bar['close'] for bar in list(self.real_time_data['ohlcv_1m'])[-5:]]
momentum = (recent_prices[-1] - recent_prices[0]) / recent_prices[0]
if momentum > 0.01: # 1% upward momentum
return 0, [0.6, 0.2, 0.2], 0.6 # BUY
elif momentum < -0.01: # 1% downward momentum
return 1, [0.2, 0.6, 0.2], 0.6 # SELL
else:
return 2, [0.2, 0.2, 0.6], 0.6 # HOLD
return 2, [0.33, 0.33, 0.34], 0.33 # Default HOLD
except Exception as e:
logger.debug(f"Error in technical analysis prediction: {e}")
return 2, [0.33, 0.33, 0.34], 0.33
def _technical_direction_prediction(self, symbol: str) -> Tuple[int, float]:
"""Fallback technical analysis for CNN direction"""
try:
if len(self.real_time_data['ohlcv_1m']) >= 3:
recent_prices = [bar['close'] for bar in list(self.real_time_data['ohlcv_1m'])[-3:]]
short_momentum = (recent_prices[-1] - recent_prices[-2]) / recent_prices[-2]
if short_momentum > 0.005: # 0.5% short-term up
return 2, 0.65 # UP
elif short_momentum < -0.005: # 0.5% short-term down
return 0, 0.65 # DOWN
else:
return 1, 0.55 # SAME
return 1, 0.5 # Default SAME
except Exception as e:
logger.debug(f"Error in technical direction prediction: {e}")
return 1, 0.5
def _prepare_cnn_features(self, price_sequence: List[float]) -> np.ndarray:
"""Prepare features for CNN model"""
try:
# Normalize prices relative to first price
if len(price_sequence) >= 15:
base_price = price_sequence[0]
normalized = [(p - base_price) / base_price for p in price_sequence]
# Create feature matrix (15 x 20, flattened)
features = np.zeros((15, 20))
for i, norm_price in enumerate(normalized):
features[i, 0] = norm_price # Normalized price
if i > 0:
features[i, 1] = normalized[i] - normalized[i-1] # Price change
return features.flatten()
return np.zeros(300) # Default feature vector
except Exception as e:
logger.debug(f"Error preparing CNN features: {e}")
return np.zeros(300)
def _estimate_price_change(self, direction: int, confidence: float) -> float:
"""Estimate price change percentage based on direction and confidence"""
try:
# Base change scaled by confidence
base_change = 0.01 * confidence # Up to 1% change
if direction == 2: # UP
return base_change
elif direction == 0: # DOWN
return -base_change
else: # SAME
return 0.0
except Exception as e:
logger.debug(f"Error estimating price change: {e}")
return 0.0

View File

@ -0,0 +1,309 @@
#!/usr/bin/env python3
"""
Test Model Predictions Visualization
This script demonstrates the enhanced model prediction visualization system
that shows DQN actions, CNN price predictions, and accuracy feedback on the price chart.
Features tested:
- DQN action predictions (BUY/SELL/HOLD) as directional arrows with confidence-based sizing
- CNN price direction predictions as trend lines with target markers
- Prediction accuracy feedback with color-coded results
- Real-time prediction tracking and storage
- Mock prediction generation for demonstration
"""
import asyncio
import logging
import time
import numpy as np
from datetime import datetime, timedelta
from typing import Dict, List, Optional
from core.config import get_config
from core.data_provider import DataProvider
from core.orchestrator import TradingOrchestrator
from core.trading_executor import TradingExecutor
from web.clean_dashboard import create_clean_dashboard
from enhanced_realtime_training import EnhancedRealtimeTrainingSystem
# Setup logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
class ModelPredictionTester:
"""Test model prediction visualization capabilities"""
def __init__(self):
self.config = get_config()
self.data_provider = DataProvider()
self.trading_executor = TradingExecutor()
self.orchestrator = TradingOrchestrator(
data_provider=self.data_provider,
enhanced_rl_training=True,
model_registry={}
)
# Initialize enhanced training system
self.training_system = EnhancedRealtimeTrainingSystem(
orchestrator=self.orchestrator,
data_provider=self.data_provider,
dashboard=None # Will be set after dashboard creation
)
# Create dashboard with enhanced prediction visualization
self.dashboard = create_clean_dashboard(
data_provider=self.data_provider,
orchestrator=self.orchestrator,
trading_executor=self.trading_executor
)
# Connect training system to dashboard
self.training_system.dashboard = self.dashboard
self.dashboard.training_system = self.training_system
# Test data
self.test_symbols = ['ETH/USDT', 'BTC/USDT']
self.prediction_count = 0
logger.info("Model Prediction Tester initialized")
def generate_mock_dqn_predictions(self, symbol: str, count: int = 10):
"""Generate mock DQN predictions for testing"""
try:
current_price = self.data_provider.get_current_price(symbol) or 2400.0
for i in range(count):
# Generate realistic state vector
state = np.random.random(100) # 100-dimensional state
# Generate Q-values with some logic
q_values = [np.random.random(), np.random.random(), np.random.random()]
action = np.argmax(q_values) # Best action
confidence = max(q_values) / sum(q_values) # Confidence based on Q-value distribution
# Add some price variation
pred_price = current_price + np.random.normal(0, 20)
# Capture prediction
self.training_system.capture_dqn_prediction(
symbol=symbol,
state=state,
q_values=q_values,
action=action,
confidence=confidence,
price=pred_price
)
self.prediction_count += 1
logger.info(f"Generated DQN prediction {i+1}/{count}: {symbol} action={['BUY', 'SELL', 'HOLD'][action]} confidence={confidence:.2f}")
# Small delay between predictions
time.sleep(0.1)
except Exception as e:
logger.error(f"Error generating DQN predictions: {e}")
def generate_mock_cnn_predictions(self, symbol: str, count: int = 8):
"""Generate mock CNN predictions for testing"""
try:
current_price = self.data_provider.get_current_price(symbol) or 2400.0
for i in range(count):
# Generate direction with some logic
direction = np.random.choice([0, 1, 2], p=[0.3, 0.2, 0.5]) # Slightly bullish
confidence = 0.4 + np.random.random() * 0.5 # 0.4-0.9 confidence
# Calculate predicted price based on direction
if direction == 2: # UP
price_change = np.random.uniform(5, 50)
elif direction == 0: # DOWN
price_change = -np.random.uniform(5, 50)
else: # SAME
price_change = np.random.uniform(-5, 5)
predicted_price = current_price + price_change
# Generate features
features = np.random.random((15, 20)).flatten() # Flattened CNN features
# Capture prediction
self.training_system.capture_cnn_prediction(
symbol=symbol,
current_price=current_price,
predicted_price=predicted_price,
direction=direction,
confidence=confidence,
features=features
)
self.prediction_count += 1
logger.info(f"Generated CNN prediction {i+1}/{count}: {symbol} direction={['DOWN', 'SAME', 'UP'][direction]} confidence={confidence:.2f}")
# Small delay between predictions
time.sleep(0.2)
except Exception as e:
logger.error(f"Error generating CNN predictions: {e}")
def generate_mock_accuracy_data(self, symbol: str, count: int = 15):
"""Generate mock prediction accuracy data for testing"""
try:
current_price = self.data_provider.get_current_price(symbol) or 2400.0
for i in range(count):
# Randomly choose prediction type
prediction_type = np.random.choice(['DQN', 'CNN'])
predicted_action = np.random.choice([0, 1, 2])
confidence = 0.3 + np.random.random() * 0.6
# Generate realistic price change
actual_price_change = np.random.normal(0, 0.01) # ±1% typical change
# Validate accuracy
self.training_system.validate_prediction_accuracy(
symbol=symbol,
prediction_type=prediction_type,
predicted_action=predicted_action,
actual_price_change=actual_price_change,
confidence=confidence
)
logger.info(f"Generated accuracy data {i+1}/{count}: {symbol} {prediction_type} action={predicted_action}")
# Small delay
time.sleep(0.1)
except Exception as e:
logger.error(f"Error generating accuracy data: {e}")
def run_prediction_generation_test(self):
"""Run comprehensive prediction generation test"""
try:
logger.info("Starting Model Prediction Visualization Test")
logger.info("=" * 60)
# Test for each symbol
for symbol in self.test_symbols:
logger.info(f"\nGenerating predictions for {symbol}...")
# Generate DQN predictions
logger.info(f"Generating DQN predictions for {symbol}...")
self.generate_mock_dqn_predictions(symbol, count=12)
# Generate CNN predictions
logger.info(f"Generating CNN predictions for {symbol}...")
self.generate_mock_cnn_predictions(symbol, count=8)
# Generate accuracy data
logger.info(f"Generating accuracy data for {symbol}...")
self.generate_mock_accuracy_data(symbol, count=20)
# Get prediction summary
summary = self.training_system.get_prediction_summary(symbol)
logger.info(f"Prediction summary for {symbol}: {summary}")
# Log total statistics
training_stats = self.training_system.get_training_statistics()
logger.info("\nTraining System Statistics:")
logger.info(f"Total predictions generated: {self.prediction_count}")
logger.info(f"Prediction stats: {training_stats.get('prediction_stats', {})}")
logger.info("\n" + "=" * 60)
logger.info("Prediction generation test completed successfully!")
logger.info("Dashboard should now show enhanced model predictions on the price chart:")
logger.info("- Green/Red arrows for DQN BUY/SELL predictions")
logger.info("- Gray circles for DQN HOLD predictions")
logger.info("- Colored trend lines for CNN price direction predictions")
logger.info("- Diamond markers for CNN prediction targets")
logger.info("- Green/Red X marks for correct/incorrect prediction feedback")
logger.info("- Hover tooltips showing confidence, Q-values, and accuracy scores")
except Exception as e:
logger.error(f"Error in prediction generation test: {e}")
def start_dashboard_with_predictions(self, host='127.0.0.1', port=8051):
"""Start dashboard with enhanced prediction visualization"""
try:
logger.info(f"Starting dashboard with model predictions at http://{host}:{port}")
# Run prediction generation in background
import threading
pred_thread = threading.Thread(target=self.run_prediction_generation_test, daemon=True)
pred_thread.start()
# Start training system
self.training_system.start_training()
# Start dashboard
self.dashboard.run_server(host=host, port=port, debug=False)
except Exception as e:
logger.error(f"Error starting dashboard with predictions: {e}")
def test_prediction_accuracy_validation(self):
"""Test prediction accuracy validation logic"""
try:
logger.info("Testing prediction accuracy validation...")
# Test DQN accuracy validation
test_cases = [
('DQN', 0, 0.01, 0.8, True), # BUY + price up = correct
('DQN', 1, -0.01, 0.7, True), # SELL + price down = correct
('DQN', 2, 0.0005, 0.6, True), # HOLD + no change = correct
('DQN', 0, -0.01, 0.8, False), # BUY + price down = incorrect
('CNN', 2, 0.01, 0.9, True), # UP + price up = correct
('CNN', 0, -0.01, 0.8, True), # DOWN + price down = correct
('CNN', 1, 0.0005, 0.7, True), # SAME + no change = correct
('CNN', 2, -0.01, 0.9, False), # UP + price down = incorrect
]
for prediction_type, action, price_change, confidence, expected_correct in test_cases:
self.training_system.validate_prediction_accuracy(
symbol='ETH/USDT',
prediction_type=prediction_type,
predicted_action=action,
actual_price_change=price_change,
confidence=confidence
)
# Check if validation worked correctly
if self.training_system.prediction_accuracy_history['ETH/USDT']:
latest = list(self.training_system.prediction_accuracy_history['ETH/USDT'])[-1]
actual_correct = latest['correct']
status = "" if actual_correct == expected_correct else ""
logger.info(f"{status} {prediction_type} action={action} change={price_change:.4f} -> correct={actual_correct}")
logger.info("Prediction accuracy validation test completed")
except Exception as e:
logger.error(f"Error testing prediction accuracy validation: {e}")
def main():
"""Main test function"""
try:
# Create tester
tester = ModelPredictionTester()
# Run accuracy validation test first
tester.test_prediction_accuracy_validation()
# Start dashboard with enhanced predictions
logger.info("\nStarting dashboard with enhanced model prediction visualization...")
logger.info("Visit http://127.0.0.1:8051 to see the enhanced price chart with model predictions")
tester.start_dashboard_with_predictions()
except KeyboardInterrupt:
logger.info("Test interrupted by user")
except Exception as e:
logger.error(f"Error in main test: {e}")
if __name__ == "__main__":
main()

View File

@ -106,6 +106,10 @@ class CleanTradingDashboard:
else:
self.orchestrator = orchestrator
# Initialize enhanced training system for predictions
self.training_system = None
self._initialize_enhanced_training_system()
# Initialize layout and component managers
self.layout_manager = DashboardLayoutManager(
starting_balance=self._get_initial_balance(),
@ -711,9 +715,9 @@ class CleanTradingDashboard:
x=0.5, y=0.5, showarrow=False)
def _add_model_predictions_to_chart(self, fig: go.Figure, symbol: str, df_main: pd.DataFrame, row: int = 1):
"""Add model predictions to the chart - ONLY EXECUTED TRADES on main chart"""
"""Add enhanced model predictions to the chart with real-time feedback"""
try:
# Only show EXECUTED TRADES on the main 1m chart
# 1. Add executed trades (existing functionality)
executed_signals = [signal for signal in self.recent_decisions if self._get_signal_attribute(signal, 'executed', False)]
if executed_signals:
@ -721,8 +725,7 @@ class CleanTradingDashboard:
buy_trades = []
sell_trades = []
for signal in executed_signals[-50:]: # Last 50 executed trades (increased from 20)
# Try to get full timestamp first, fall back to string timestamp
for signal in executed_signals[-50:]: # Last 50 executed trades
signal_time = self._get_signal_attribute(signal, 'full_timestamp')
if not signal_time:
signal_time = self._get_signal_attribute(signal, 'timestamp')
@ -732,10 +735,9 @@ class CleanTradingDashboard:
signal_confidence = self._get_signal_attribute(signal, 'confidence', 0)
if signal_time and signal_price and signal_confidence > 0:
# FIXED: Better timestamp conversion to prevent race conditions
# Enhanced timestamp handling
if isinstance(signal_time, str):
try:
# Handle time-only format with current date
if ':' in signal_time and len(signal_time.split(':')) == 3:
now = datetime.now()
time_parts = signal_time.split(':')
@ -745,7 +747,6 @@ class CleanTradingDashboard:
second=int(time_parts[2]),
microsecond=0
)
# Handle day boundary issues - if signal seems from future, subtract a day
if signal_time > now + timedelta(minutes=5):
signal_time -= timedelta(days=1)
else:
@ -754,7 +755,6 @@ class CleanTradingDashboard:
logger.debug(f"Error parsing timestamp {signal_time}: {e}")
continue
elif not isinstance(signal_time, datetime):
# Convert other timestamp formats to datetime
try:
signal_time = pd.to_datetime(signal_time)
except Exception as e:
@ -766,7 +766,7 @@ class CleanTradingDashboard:
elif signal_action == 'SELL':
sell_trades.append({'x': signal_time, 'y': signal_price, 'confidence': signal_confidence})
# Add EXECUTED BUY trades (large green circles)
# Add executed trades with enhanced visualization
if buy_trades:
fig.add_trace(
go.Scatter(
@ -790,7 +790,6 @@ class CleanTradingDashboard:
row=row, col=1
)
# Add EXECUTED SELL trades (large red circles)
if sell_trades:
fig.add_trace(
go.Scatter(
@ -813,9 +812,363 @@ class CleanTradingDashboard:
),
row=row, col=1
)
# 2. NEW: Add real-time model predictions overlay
self._add_dqn_predictions_to_chart(fig, symbol, df_main, row)
self._add_cnn_predictions_to_chart(fig, symbol, df_main, row)
self._add_prediction_accuracy_feedback(fig, symbol, df_main, row)
except Exception as e:
logger.warning(f"Error adding executed trades to main chart: {e}")
logger.warning(f"Error adding model predictions to chart: {e}")
def _add_dqn_predictions_to_chart(self, fig: go.Figure, symbol: str, df_main: pd.DataFrame, row: int = 1):
"""Add DQN action predictions as directional arrows"""
try:
# Get recent DQN predictions from orchestrator
dqn_predictions = self._get_recent_dqn_predictions(symbol)
if not dqn_predictions:
return
# Separate predictions by action
buy_predictions = []
sell_predictions = []
hold_predictions = []
for pred in dqn_predictions[-30:]: # Last 30 DQN predictions
action = pred.get('action', 2) # 0=BUY, 1=SELL, 2=HOLD
confidence = pred.get('confidence', 0)
timestamp = pred.get('timestamp', datetime.now())
price = pred.get('price', 0)
if confidence > 0.3: # Only show predictions with reasonable confidence
pred_data = {
'x': timestamp,
'y': price,
'confidence': confidence,
'q_values': pred.get('q_values', [0, 0, 0])
}
if action == 0: # BUY
buy_predictions.append(pred_data)
elif action == 1: # SELL
sell_predictions.append(pred_data)
else: # HOLD
hold_predictions.append(pred_data)
# Add DQN BUY predictions (green arrows pointing up)
if buy_predictions:
fig.add_trace(
go.Scatter(
x=[p['x'] for p in buy_predictions],
y=[p['y'] for p in buy_predictions],
mode='markers',
marker=dict(
symbol='triangle-up',
size=[8 + p['confidence'] * 12 for p in buy_predictions], # Size based on confidence
color=[f'rgba(0, 200, 0, {0.3 + p["confidence"] * 0.7})' for p in buy_predictions], # Opacity based on confidence
line=dict(width=1, color='darkgreen')
),
name='DQN BUY Prediction',
showlegend=True,
hovertemplate="<b>DQN BUY PREDICTION</b><br>" +
"Price: $%{y:.2f}<br>" +
"Time: %{x}<br>" +
"Confidence: %{customdata[0]:.1%}<br>" +
"Q-Values: [%{customdata[1]:.3f}, %{customdata[2]:.3f}, %{customdata[3]:.3f}]<extra></extra>",
customdata=[[p['confidence']] + p['q_values'] for p in buy_predictions]
),
row=row, col=1
)
# Add DQN SELL predictions (red arrows pointing down)
if sell_predictions:
fig.add_trace(
go.Scatter(
x=[p['x'] for p in sell_predictions],
y=[p['y'] for p in sell_predictions],
mode='markers',
marker=dict(
symbol='triangle-down',
size=[8 + p['confidence'] * 12 for p in sell_predictions],
color=[f'rgba(200, 0, 0, {0.3 + p["confidence"] * 0.7})' for p in sell_predictions],
line=dict(width=1, color='darkred')
),
name='DQN SELL Prediction',
showlegend=True,
hovertemplate="<b>DQN SELL PREDICTION</b><br>" +
"Price: $%{y:.2f}<br>" +
"Time: %{x}<br>" +
"Confidence: %{customdata[0]:.1%}<br>" +
"Q-Values: [%{customdata[1]:.3f}, %{customdata[2]:.3f}, %{customdata[3]:.3f}]<extra></extra>",
customdata=[[p['confidence']] + p['q_values'] for p in sell_predictions]
),
row=row, col=1
)
# Add DQN HOLD predictions (small gray circles)
if hold_predictions:
fig.add_trace(
go.Scatter(
x=[p['x'] for p in hold_predictions],
y=[p['y'] for p in hold_predictions],
mode='markers',
marker=dict(
symbol='circle',
size=[4 + p['confidence'] * 6 for p in hold_predictions],
color=[f'rgba(128, 128, 128, {0.2 + p["confidence"] * 0.5})' for p in hold_predictions],
line=dict(width=1, color='gray')
),
name='DQN HOLD Prediction',
showlegend=True,
hovertemplate="<b>DQN HOLD PREDICTION</b><br>" +
"Price: $%{y:.2f}<br>" +
"Time: %{x}<br>" +
"Confidence: %{customdata[0]:.1%}<br>" +
"Q-Values: [%{customdata[1]:.3f}, %{customdata[2]:.3f}, %{customdata[3]:.3f}]<extra></extra>",
customdata=[[p['confidence']] + p['q_values'] for p in hold_predictions]
),
row=row, col=1
)
except Exception as e:
logger.debug(f"Error adding DQN predictions to chart: {e}")
def _add_cnn_predictions_to_chart(self, fig: go.Figure, symbol: str, df_main: pd.DataFrame, row: int = 1):
"""Add CNN price direction predictions as trend lines"""
try:
# Get recent CNN predictions from orchestrator
cnn_predictions = self._get_recent_cnn_predictions(symbol)
if not cnn_predictions:
return
# Create trend prediction lines
prediction_lines = []
for i, pred in enumerate(cnn_predictions[-20:]): # Last 20 CNN predictions
direction = pred.get('direction', 1) # 0=DOWN, 1=SAME, 2=UP
confidence = pred.get('confidence', 0)
timestamp = pred.get('timestamp', datetime.now())
current_price = pred.get('current_price', 0)
predicted_price = pred.get('predicted_price', current_price)
if confidence > 0.4 and current_price > 0: # Only show confident predictions
# Calculate prediction end point (5 minutes ahead)
end_time = timestamp + timedelta(minutes=5)
# Determine color based on direction
if direction == 2: # UP
color = f'rgba(0, 255, 0, {0.3 + confidence * 0.4})'
line_color = 'green'
prediction_name = 'CNN UP'
elif direction == 0: # DOWN
color = f'rgba(255, 0, 0, {0.3 + confidence * 0.4})'
line_color = 'red'
prediction_name = 'CNN DOWN'
else: # SAME
color = f'rgba(128, 128, 128, {0.2 + confidence * 0.3})'
line_color = 'gray'
prediction_name = 'CNN FLAT'
# Add prediction line
fig.add_trace(
go.Scatter(
x=[timestamp, end_time],
y=[current_price, predicted_price],
mode='lines',
line=dict(
color=line_color,
width=2 + confidence * 3, # Line width based on confidence
dash='dot' if direction == 1 else 'solid'
),
name=f'{prediction_name} Prediction',
showlegend=i == 0, # Only show legend for first instance
hovertemplate=f"<b>{prediction_name} PREDICTION</b><br>" +
"From: $%{y[0]:.2f}<br>" +
"To: $%{y[1]:.2f}<br>" +
"Time: %{x[0]}%{x[1]}<br>" +
f"Confidence: {confidence:.1%}<br>" +
f"Direction: {['DOWN', 'SAME', 'UP'][direction]}<extra></extra>"
),
row=row, col=1
)
# Add prediction end point marker
fig.add_trace(
go.Scatter(
x=[end_time],
y=[predicted_price],
mode='markers',
marker=dict(
symbol='diamond',
size=6 + confidence * 8,
color=color,
line=dict(width=1, color=line_color)
),
name=f'{prediction_name} Target',
showlegend=False,
hovertemplate=f"<b>{prediction_name} TARGET</b><br>" +
"Target Price: $%{y:.2f}<br>" +
"Target Time: %{x}<br>" +
f"Confidence: {confidence:.1%}<extra></extra>"
),
row=row, col=1
)
except Exception as e:
logger.debug(f"Error adding CNN predictions to chart: {e}")
def _add_prediction_accuracy_feedback(self, fig: go.Figure, symbol: str, df_main: pd.DataFrame, row: int = 1):
"""Add prediction accuracy feedback with color-coded results"""
try:
# Get prediction accuracy history
accuracy_data = self._get_prediction_accuracy_history(symbol)
if not accuracy_data:
return
# Add accuracy feedback markers
correct_predictions = []
incorrect_predictions = []
for acc in accuracy_data[-50:]: # Last 50 accuracy points
timestamp = acc.get('timestamp', datetime.now())
price = acc.get('actual_price', 0)
was_correct = acc.get('correct', False)
prediction_type = acc.get('prediction_type', 'unknown')
accuracy_score = acc.get('accuracy_score', 0)
if price > 0:
acc_data = {
'x': timestamp,
'y': price,
'type': prediction_type,
'score': accuracy_score
}
if was_correct:
correct_predictions.append(acc_data)
else:
incorrect_predictions.append(acc_data)
# Add correct prediction markers (green checkmarks)
if correct_predictions:
fig.add_trace(
go.Scatter(
x=[p['x'] for p in correct_predictions],
y=[p['y'] for p in correct_predictions],
mode='markers',
marker=dict(
symbol='x',
size=8,
color='rgba(0, 255, 0, 0.8)',
line=dict(width=2, color='darkgreen')
),
name='Correct Predictions',
showlegend=True,
hovertemplate="<b>CORRECT PREDICTION</b><br>" +
"Price: $%{y:.2f}<br>" +
"Time: %{x}<br>" +
"Type: %{customdata[0]}<br>" +
"Accuracy: %{customdata[1]:.1%}<extra></extra>",
customdata=[[p['type'], p['score']] for p in correct_predictions]
),
row=row, col=1
)
# Add incorrect prediction markers (red X marks)
if incorrect_predictions:
fig.add_trace(
go.Scatter(
x=[p['x'] for p in incorrect_predictions],
y=[p['y'] for p in incorrect_predictions],
mode='markers',
marker=dict(
symbol='x',
size=8,
color='rgba(255, 0, 0, 0.8)',
line=dict(width=2, color='darkred')
),
name='Incorrect Predictions',
showlegend=True,
hovertemplate="<b>INCORRECT PREDICTION</b><br>" +
"Price: $%{y:.2f}<br>" +
"Time: %{x}<br>" +
"Type: %{customdata[0]}<br>" +
"Accuracy: %{customdata[1]:.1%}<extra></extra>",
customdata=[[p['type'], p['score']] for p in incorrect_predictions]
),
row=row, col=1
)
except Exception as e:
logger.debug(f"Error adding prediction accuracy feedback to chart: {e}")
def _get_recent_dqn_predictions(self, symbol: str) -> List[Dict]:
"""Get recent DQN predictions from enhanced training system (forward-looking only)"""
try:
predictions = []
# Get REAL forward-looking predictions from enhanced training system
if hasattr(self, 'training_system') and self.training_system:
if hasattr(self.training_system, 'recent_dqn_predictions'):
predictions.extend(self.training_system.recent_dqn_predictions.get(symbol, []))
# Get from orchestrator as fallback
if hasattr(self.orchestrator, 'recent_dqn_predictions'):
predictions.extend(self.orchestrator.recent_dqn_predictions.get(symbol, []))
# REMOVED: Mock prediction generation - now using REAL predictions only
# No more artificial past predictions or random data
return sorted(predictions, key=lambda x: x.get('timestamp', datetime.now()))
except Exception as e:
logger.debug(f"Error getting DQN predictions: {e}")
return []
def _get_recent_cnn_predictions(self, symbol: str) -> List[Dict]:
"""Get recent CNN predictions from enhanced training system (forward-looking only)"""
try:
predictions = []
# Get REAL forward-looking predictions from enhanced training system
if hasattr(self, 'training_system') and self.training_system:
if hasattr(self.training_system, 'recent_cnn_predictions'):
predictions.extend(self.training_system.recent_cnn_predictions.get(symbol, []))
# Get from orchestrator as fallback
if hasattr(self.orchestrator, 'recent_cnn_predictions'):
predictions.extend(self.orchestrator.recent_cnn_predictions.get(symbol, []))
# REMOVED: Mock prediction generation - now using REAL predictions only
# No more artificial past predictions or random data
return sorted(predictions, key=lambda x: x.get('timestamp', datetime.now()))
except Exception as e:
logger.debug(f"Error getting CNN predictions: {e}")
return []
def _get_prediction_accuracy_history(self, symbol: str) -> List[Dict]:
"""Get REAL prediction accuracy history from validated forward-looking predictions"""
try:
accuracy_data = []
# Get REAL accuracy data from training system validation
if hasattr(self, 'training_system') and self.training_system:
if hasattr(self.training_system, 'prediction_accuracy_history'):
accuracy_data.extend(self.training_system.prediction_accuracy_history.get(symbol, []))
# REMOVED: Mock accuracy data generation - now using REAL validation results only
# Accuracy is now based on actual prediction outcomes, not random data
return sorted(accuracy_data, key=lambda x: x.get('timestamp', datetime.now()))
except Exception as e:
logger.debug(f"Error getting prediction accuracy history: {e}")
return []
def _add_signals_to_mini_chart(self, fig: go.Figure, symbol: str, ws_data_1s: pd.DataFrame, row: int = 2):
"""Add ALL signals (executed and non-executed) to the 1s mini chart"""
@ -2566,6 +2919,33 @@ class CleanTradingDashboard:
except Exception as e:
logger.warning(f"Error clearing old signals: {e}")
def _initialize_enhanced_training_system(self):
"""Initialize enhanced training system for model predictions"""
try:
# Try to import and initialize enhanced training system
from enhanced_realtime_training import EnhancedRealtimeTrainingSystem
self.training_system = EnhancedRealtimeTrainingSystem(
orchestrator=self.orchestrator,
data_provider=self.data_provider,
dashboard=self
)
# Initialize prediction storage
if not hasattr(self.orchestrator, 'recent_dqn_predictions'):
self.orchestrator.recent_dqn_predictions = {}
if not hasattr(self.orchestrator, 'recent_cnn_predictions'):
self.orchestrator.recent_cnn_predictions = {}
logger.info("Enhanced training system initialized for model predictions")
except ImportError:
logger.warning("Enhanced training system not available - using mock predictions")
self.training_system = None
except Exception as e:
logger.error(f"Error initializing enhanced training system: {e}")
self.training_system = None
def _initialize_cob_integration(self):
"""Initialize COB integration with high-frequency data handling"""
try: