"""
Dashboard CNN Integration

This module integrates the EnhancedCNNAdapter with the dashboard system,
providing real-time training, predictions, and performance metrics display.
"""
|
|
|
|
import logging
|
|
import time
|
|
import threading
|
|
from datetime import datetime, timedelta
|
|
from typing import Dict, List, Optional, Any, Tuple
|
|
from collections import deque
|
|
import numpy as np
|
|
|
|
from .enhanced_cnn_adapter import EnhancedCNNAdapter
|
|
from .standardized_data_provider import StandardizedDataProvider
|
|
from .data_models import BaseDataInput, ModelOutput, create_model_output
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
class DashboardCNNIntegration:
    """
    CNN integration for the dashboard system.

    This class:
    1. Manages the CNN model lifecycle in the dashboard
    2. Runs periodic inference on a background daemon thread and trains the
       model on demand as labelled samples arrive
    3. Tracks performance metrics for dashboard display
    4. Serves model predictions for chart overlays
    """

    # Display label for model size; it is not computed from the model itself.
    _PARAMETER_COUNT_LABEL = '50.0M'
    # Sliding window (seconds) over which the *_per_second rates are measured.
    _RATE_WINDOW_SECONDS = 60.0
    # Safety cap on the rate-event deques in case nothing prunes them
    # (e.g. training events recorded while the processing loop is stopped).
    _MAX_RATE_EVENTS = 100_000
    # Run a training pass after this many newly added samples.
    _TRAIN_EVERY_N_SAMPLES = 10
    # Minimum seconds between two predictions for the same symbol.
    _PREDICTION_INTERVAL_SECONDS = 1.0

    def __init__(self, data_provider: StandardizedDataProvider, symbols: Optional[List[str]] = None):
        """
        Initialize the dashboard CNN integration.

        Creates an EnhancedCNNAdapter (checkpoints under "models/enhanced_cnn")
        and loads its best checkpoint.

        Args:
            data_provider: Standardized data provider used to fetch model
                inputs and to store model outputs.
            symbols: Symbols to process. Falsy (None or empty) falls back to
                ['ETH/USDT', 'BTC/USDT']. The sequence is copied.
        """
        self.data_provider = data_provider
        self.symbols = list(symbols) if symbols else ['ETH/USDT', 'BTC/USDT']

        # Initialize CNN adapter and restore the best checkpoint if available
        self.cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
        self.cnn_adapter.load_best_checkpoint()

        # Performance tracking. 'inference_times' / 'training_times' hold
        # durations in seconds (the most recent 100 of each).
        self.performance_metrics = {
            'total_predictions': 0,
            'total_training_samples': 0,
            'last_training_time': None,
            'last_inference_time': None,
            'training_loss_history': deque(maxlen=100),
            'accuracy_history': deque(maxlen=100),
            'inference_times': deque(maxlen=100),
            'training_times': deque(maxlen=100),
            'predictions_per_second': 0.0,
            'training_per_second': 0.0,
            'model_status': 'FRESH',
            'confidence_history': deque(maxlen=100),
            'action_distribution': {'BUY': 0, 'SELL': 0, 'HOLD': 0}
        }

        # Wall-clock completion timestamps (time.time()) used only to compute
        # the *_per_second rates. Kept separate from the duration deques above.
        self._inference_timestamps = deque(maxlen=self._MAX_RATE_EVENTS)
        self._training_timestamps = deque(maxlen=self._MAX_RATE_EVENTS)

        # Prediction cache for dashboard display: latest output per symbol,
        # plus a bounded per-symbol history for chart overlays.
        self.prediction_cache = {}
        self.prediction_history = {symbol: deque(maxlen=1000) for symbol in self.symbols}

        # Training / inference switches
        self.training_enabled = True
        self.inference_enabled = True
        self.training_lock = threading.Lock()

        # Real-time processing state
        self.is_running = False
        self.processing_thread = None

        logger.info(f"DashboardCNNIntegration initialized for symbols: {self.symbols}")

    def start_real_time_processing(self):
        """Start the background prediction loop on a daemon thread.

        Does nothing (with a warning) if the loop is already running, or if a
        previous loop thread has not yet exited after a timed-out stop.
        """
        if self.is_running:
            logger.warning("Real-time processing already running")
            return

        previous = self.processing_thread
        if previous is not None and previous.is_alive():
            # Starting now would flip is_running back to True and leave the
            # old loop running alongside the new one.
            logger.warning("Previous CNN processing thread is still shutting down; not starting a new one")
            return

        self.is_running = True
        self.processing_thread = threading.Thread(target=self._real_time_processing_loop, daemon=True)
        self.processing_thread.start()

        logger.info("Started real-time CNN processing")

    def stop_real_time_processing(self):
        """Signal the background loop to stop and wait up to 5 s for it to exit."""
        self.is_running = False

        thread = self.processing_thread
        # Thread.join() on the current thread raises RuntimeError, so skip the
        # wait if stop is requested from inside the processing loop itself.
        if thread is not None and thread is not threading.current_thread():
            thread.join(timeout=5)
            if thread.is_alive():
                logger.warning("CNN processing thread did not stop within 5 seconds")
            else:
                self.processing_thread = None

        logger.info("Stopped real-time CNN processing")

    def _real_time_processing_loop(self):
        """Main loop: predict per symbol at most once per interval, refresh rates.

        Runs until ``is_running`` becomes False. Errors are logged and the loop
        backs off for one second before continuing.
        """
        last_prediction_time = {}
        interval = self._PREDICTION_INTERVAL_SECONDS

        while self.is_running:
            try:
                # Monotonic clock so wall-clock adjustments don't stall or
                # burst the per-symbol schedule.
                now = time.monotonic()

                for symbol in self.symbols:
                    last = last_prediction_time.get(symbol)
                    if last is None or now - last >= interval:
                        if self.inference_enabled:
                            self._make_prediction(symbol)
                        last_prediction_time[symbol] = now

                self._update_performance_metrics()

                # Sleep briefly to prevent overwhelming the system
                time.sleep(0.1)

            except Exception as e:
                logger.error(f"Error in real-time processing loop: {e}")
                time.sleep(1)

    def _make_prediction(self, symbol: str):
        """Run one CNN prediction for ``symbol`` and record it.

        On success this updates metrics, caches the output for the dashboard,
        appends it to the symbol's history and hands it to the data provider.
        Missing input data or a None prediction is logged at debug level and
        skipped; other errors are logged.
        """
        try:
            # Timing includes fetching the input, as before.
            start_time = time.time()

            base_data = self.data_provider.get_base_data_input(symbol)
            if base_data is None:
                logger.debug("No base data available for %s", symbol)
                return

            model_output = self.cnn_adapter.predict(base_data)
            if model_output is None:
                logger.debug("CNN adapter returned no prediction for %s", symbol)
                return

            metrics = self.performance_metrics
            metrics['inference_times'].append(time.time() - start_time)
            metrics['total_predictions'] += 1
            metrics['last_inference_time'] = datetime.now()
            metrics['confidence_history'].append(model_output.confidence)

            # Count the action; actions outside BUY/SELL/HOLD get their own key
            # instead of raising KeyError and dropping the prediction.
            action = model_output.predictions['action']
            distribution = metrics['action_distribution']
            distribution[action] = distribution.get(action, 0) + 1

            self._inference_timestamps.append(time.time())

            # Cache prediction for dashboard
            self.prediction_cache[symbol] = model_output
            self.prediction_history[symbol].append(model_output)

            # Store model output in data provider
            self.data_provider.store_model_output(model_output)

            logger.debug("CNN prediction for %s: %s (%.3f)", symbol, action, model_output.confidence)

        except Exception as e:
            logger.error(f"Error making prediction for {symbol}: {e}")

    def add_training_sample(self, symbol: str, actual_action: str, reward: float):
        """Add a labelled training sample for ``symbol``.

        Every ``_TRAIN_EVERY_N_SAMPLES`` samples a training pass runs
        synchronously in the caller's thread. No-op while training is disabled
        or when no input data is available for the symbol.

        Args:
            symbol: Symbol whose current input data forms the sample.
            actual_action: Target action label passed to the adapter.
            reward: Reward value passed to the adapter.
        """
        try:
            if not self.training_enabled:
                return

            base_data = self.data_provider.get_base_data_input(symbol)
            if base_data is None:
                logger.debug("No base data available for training sample: %s", symbol)
                return

            self.cnn_adapter.add_training_sample(base_data, actual_action, reward)

            self.performance_metrics['total_training_samples'] += 1

            if self.performance_metrics['total_training_samples'] % self._TRAIN_EVERY_N_SAMPLES == 0:
                self._train_model()

        except Exception as e:
            logger.error(f"Error adding training sample: {e}")

    def _train_model(self):
        """Run one training epoch under ``training_lock`` and record its metrics.

        'loss' and 'accuracy' are read from the adapter's returned mapping when
        present. Status becomes 'TRAINED' once accuracy exceeds 0.5, otherwise
        'TRAINING'.
        """
        try:
            with self.training_lock:
                start_time = time.time()

                # Guard against the adapter returning None instead of a mapping.
                metrics = self.cnn_adapter.train(epochs=1) or {}

                perf = self.performance_metrics
                perf['training_times'].append(time.time() - start_time)
                perf['last_training_time'] = datetime.now()
                self._training_timestamps.append(time.time())

                if 'loss' in metrics:
                    perf['training_loss_history'].append(metrics['loss'])

                if 'accuracy' in metrics:
                    perf['accuracy_history'].append(metrics['accuracy'])

                perf['model_status'] = 'TRAINED' if metrics.get('accuracy', 0) > 0.5 else 'TRAINING'

                logger.info(f"CNN training completed: loss={metrics.get('loss', 0):.4f}, accuracy={metrics.get('accuracy', 0):.4f}")

        except Exception as e:
            logger.error(f"Error training CNN model: {e}")

    def _update_performance_metrics(self):
        """Recompute predictions/training per second over the rate window.

        Counts completion timestamps newer than ``_RATE_WINDOW_SECONDS`` and
        divides by the window length. Within this class only the processing
        loop calls this, so it is the only place that pops from the timestamp
        deques; other code paths only append to them.
        """
        try:
            window = self._RATE_WINDOW_SECONDS
            cutoff = time.time() - window

            for stamps in (self._inference_timestamps, self._training_timestamps):
                while stamps and stamps[0] < cutoff:
                    stamps.popleft()

            self.performance_metrics['predictions_per_second'] = len(self._inference_timestamps) / window
            self.performance_metrics['training_per_second'] = len(self._training_timestamps) / window

        except Exception as e:
            logger.error(f"Error updating performance metrics: {e}")

    def get_dashboard_metrics(self) -> Dict[str, Any]:
        """Return a snapshot of model metrics for dashboard display.

        On any error, returns a reduced dict with status 'ERROR' and the error
        message under 'error'.
        """
        try:
            perf = self.performance_metrics

            loss_history = perf['training_loss_history']
            accuracy_history = perf['accuracy_history']
            confidence_history = perf['confidence_history']

            current_loss = loss_history[-1] if loss_history else 0.0
            current_accuracy = accuracy_history[-1] if accuracy_history else 0.0
            avg_confidence = float(np.mean(list(confidence_history))) if confidence_history else 0.0

            # Snapshot first: the processing thread may insert into the cache
            # while this runs, and iterating a dict that changes size raises.
            cached = list(self.prediction_cache.items())
            if cached:
                latest_symbol, latest_prediction = max(cached, key=lambda item: item[1].timestamp)
            else:
                latest_symbol, latest_prediction = None, None

            last_inference = perf['last_inference_time']
            last_training = perf['last_training_time']
            default_symbol = self.symbols[0] if self.symbols else 'ETH/USDT'

            return {
                'model_name': 'CNN',
                'model_type': 'cnn',
                'parameters': self._PARAMETER_COUNT_LABEL,
                'status': perf['model_status'],
                'current_loss': current_loss,
                'accuracy': current_accuracy,
                'confidence': avg_confidence,
                'total_predictions': perf['total_predictions'],
                'total_training_samples': perf['total_training_samples'],
                'predictions_per_second': perf['predictions_per_second'],
                'training_per_second': perf['training_per_second'],
                'last_inference': last_inference.strftime("%H:%M:%S") if last_inference else "None",
                'last_training': last_training.strftime("%H:%M:%S") if last_training else "None",
                'latest_prediction': {
                    'action': latest_prediction.predictions['action'] if latest_prediction is not None else 'HOLD',
                    'confidence': latest_prediction.confidence if latest_prediction is not None else 0.0,
                    'symbol': latest_symbol or default_symbol,
                    'timestamp': latest_prediction.timestamp.strftime("%H:%M:%S") if latest_prediction is not None else "None"
                },
                'action_distribution': dict(perf['action_distribution']),
                'training_enabled': self.training_enabled,
                'inference_enabled': self.inference_enabled
            }

        except Exception as e:
            logger.error(f"Error getting dashboard metrics: {e}")
            return {
                'model_name': 'CNN',
                'model_type': 'cnn',
                'parameters': self._PARAMETER_COUNT_LABEL,
                'status': 'ERROR',
                'current_loss': 0.0,
                'accuracy': 0.0,
                'confidence': 0.0,
                'error': str(e)
            }

    def get_predictions_for_chart(self, symbol: str, timeframe: str = '1s', limit: int = 100) -> List[Dict[str, Any]]:
        """Return up to ``limit`` most recent predictions for ``symbol``, oldest first.

        Args:
            symbol: Symbol whose prediction history to read.
            timeframe: Currently unused; accepted for API compatibility.
            limit: Maximum number of predictions; values <= 0 yield [].

        Returns:
            One dict per prediction with timestamp, action, confidence and the
            buy/sell/hold probabilities (0.0 when absent). Unknown symbols or
            errors yield [].
        """
        try:
            history = self.prediction_history.get(symbol)
            # Explicit guard: list(...)[-0:] would return the entire history.
            if not history or limit <= 0:
                return []

            recent = list(history)[-limit:]

            return [
                {
                    'timestamp': prediction.timestamp,
                    'action': prediction.predictions['action'],
                    'confidence': prediction.confidence,
                    'buy_probability': prediction.predictions.get('buy_probability', 0.0),
                    'sell_probability': prediction.predictions.get('sell_probability', 0.0),
                    'hold_probability': prediction.predictions.get('hold_probability', 0.0)
                }
                for prediction in recent
            ]

        except Exception as e:
            logger.error(f"Error getting predictions for chart: {e}")
            return []

    def set_training_enabled(self, enabled: bool):
        """Enable or disable training via add_training_sample."""
        self.training_enabled = enabled
        logger.info(f"CNN training {'enabled' if enabled else 'disabled'}")

    def set_inference_enabled(self, enabled: bool):
        """Enable or disable predictions in the real-time loop."""
        self.inference_enabled = enabled
        logger.info(f"CNN inference {'enabled' if enabled else 'disabled'}")

    def get_model_info(self) -> Dict[str, Any]:
        """Return static model information for the dashboard."""
        adapter = self.cnn_adapter
        model = adapter.model
        # NOTE(review): input_shape is not guaranteed to exist on the model
        # object; fall back to 'Unknown' rather than raising AttributeError.
        input_shape = getattr(model, 'input_shape', 'Unknown') if model is not None else 'Unknown'
        return {
            'name': 'Enhanced CNN',
            'version': '1.0',
            'parameters': self._PARAMETER_COUNT_LABEL,
            'input_shape': input_shape,
            'device': str(adapter.device),
            'checkpoint_dir': adapter.checkpoint_dir,
            'training_samples': len(adapter.training_data),
            'max_training_samples': adapter.max_training_samples
        }