integrate CNN, fix COB data
This commit is contained in:
365
core/dashboard_cnn_integration.py
Normal file
365
core/dashboard_cnn_integration.py
Normal file
@ -0,0 +1,365 @@
|
|||||||
|
"""
|
||||||
|
Dashboard CNN Integration
|
||||||
|
|
||||||
|
This module integrates the EnhancedCNNAdapter with the dashboard system,
|
||||||
|
providing real-time training, predictions, and performance metrics display.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Dict, List, Optional, Any, Tuple
|
||||||
|
from collections import deque
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .enhanced_cnn_adapter import EnhancedCNNAdapter
|
||||||
|
from .standardized_data_provider import StandardizedDataProvider
|
||||||
|
from .data_models import BaseDataInput, ModelOutput, create_model_output
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class DashboardCNNIntegration:
|
||||||
|
"""
|
||||||
|
CNN integration for the dashboard system
|
||||||
|
|
||||||
|
This class:
|
||||||
|
1. Manages CNN model lifecycle in the dashboard
|
||||||
|
2. Provides real-time training and inference
|
||||||
|
3. Tracks performance metrics for dashboard display
|
||||||
|
4. Handles model predictions for chart overlay
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, data_provider: StandardizedDataProvider, symbols: List[str] = None):
|
||||||
|
"""
|
||||||
|
Initialize the dashboard CNN integration
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_provider: Standardized data provider
|
||||||
|
symbols: List of symbols to process
|
||||||
|
"""
|
||||||
|
self.data_provider = data_provider
|
||||||
|
self.symbols = symbols or ['ETH/USDT', 'BTC/USDT']
|
||||||
|
|
||||||
|
# Initialize CNN adapter
|
||||||
|
self.cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
|
||||||
|
|
||||||
|
# Load best checkpoint if available
|
||||||
|
self.cnn_adapter.load_best_checkpoint()
|
||||||
|
|
||||||
|
# Performance tracking
|
||||||
|
self.performance_metrics = {
|
||||||
|
'total_predictions': 0,
|
||||||
|
'total_training_samples': 0,
|
||||||
|
'last_training_time': None,
|
||||||
|
'last_inference_time': None,
|
||||||
|
'training_loss_history': deque(maxlen=100),
|
||||||
|
'accuracy_history': deque(maxlen=100),
|
||||||
|
'inference_times': deque(maxlen=100),
|
||||||
|
'training_times': deque(maxlen=100),
|
||||||
|
'predictions_per_second': 0.0,
|
||||||
|
'training_per_second': 0.0,
|
||||||
|
'model_status': 'FRESH',
|
||||||
|
'confidence_history': deque(maxlen=100),
|
||||||
|
'action_distribution': {'BUY': 0, 'SELL': 0, 'HOLD': 0}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Prediction cache for dashboard display
|
||||||
|
self.prediction_cache = {}
|
||||||
|
self.prediction_history = {symbol: deque(maxlen=1000) for symbol in self.symbols}
|
||||||
|
|
||||||
|
# Training control
|
||||||
|
self.training_enabled = True
|
||||||
|
self.inference_enabled = True
|
||||||
|
self.training_lock = threading.Lock()
|
||||||
|
|
||||||
|
# Real-time processing
|
||||||
|
self.is_running = False
|
||||||
|
self.processing_thread = None
|
||||||
|
|
||||||
|
logger.info(f"DashboardCNNIntegration initialized for symbols: {self.symbols}")
|
||||||
|
|
||||||
|
def start_real_time_processing(self):
|
||||||
|
"""Start real-time CNN processing"""
|
||||||
|
if self.is_running:
|
||||||
|
logger.warning("Real-time processing already running")
|
||||||
|
return
|
||||||
|
|
||||||
|
self.is_running = True
|
||||||
|
self.processing_thread = threading.Thread(target=self._real_time_processing_loop, daemon=True)
|
||||||
|
self.processing_thread.start()
|
||||||
|
|
||||||
|
logger.info("Started real-time CNN processing")
|
||||||
|
|
||||||
|
def stop_real_time_processing(self):
|
||||||
|
"""Stop real-time CNN processing"""
|
||||||
|
self.is_running = False
|
||||||
|
if self.processing_thread:
|
||||||
|
self.processing_thread.join(timeout=5)
|
||||||
|
|
||||||
|
logger.info("Stopped real-time CNN processing")
|
||||||
|
|
||||||
|
def _real_time_processing_loop(self):
|
||||||
|
"""Main real-time processing loop"""
|
||||||
|
last_prediction_time = {}
|
||||||
|
prediction_interval = 1.0 # Make prediction every 1 second
|
||||||
|
|
||||||
|
while self.is_running:
|
||||||
|
try:
|
||||||
|
current_time = time.time()
|
||||||
|
|
||||||
|
for symbol in self.symbols:
|
||||||
|
# Check if it's time to make a prediction for this symbol
|
||||||
|
if (symbol not in last_prediction_time or
|
||||||
|
current_time - last_prediction_time[symbol] >= prediction_interval):
|
||||||
|
|
||||||
|
# Make prediction if inference is enabled
|
||||||
|
if self.inference_enabled:
|
||||||
|
self._make_prediction(symbol)
|
||||||
|
last_prediction_time[symbol] = current_time
|
||||||
|
|
||||||
|
# Update performance metrics
|
||||||
|
self._update_performance_metrics()
|
||||||
|
|
||||||
|
# Sleep briefly to prevent overwhelming the system
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in real-time processing loop: {e}")
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
def _make_prediction(self, symbol: str):
|
||||||
|
"""Make a prediction for a symbol"""
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Get standardized input data
|
||||||
|
base_data = self.data_provider.get_base_data_input(symbol)
|
||||||
|
|
||||||
|
if base_data is None:
|
||||||
|
logger.debug(f"No base data available for {symbol}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Make prediction
|
||||||
|
model_output = self.cnn_adapter.predict(base_data)
|
||||||
|
|
||||||
|
# Record inference time
|
||||||
|
inference_time = time.time() - start_time
|
||||||
|
self.performance_metrics['inference_times'].append(inference_time)
|
||||||
|
|
||||||
|
# Update performance metrics
|
||||||
|
self.performance_metrics['total_predictions'] += 1
|
||||||
|
self.performance_metrics['last_inference_time'] = datetime.now()
|
||||||
|
self.performance_metrics['confidence_history'].append(model_output.confidence)
|
||||||
|
|
||||||
|
# Update action distribution
|
||||||
|
action = model_output.predictions['action']
|
||||||
|
self.performance_metrics['action_distribution'][action] += 1
|
||||||
|
|
||||||
|
# Cache prediction for dashboard
|
||||||
|
self.prediction_cache[symbol] = model_output
|
||||||
|
self.prediction_history[symbol].append(model_output)
|
||||||
|
|
||||||
|
# Store model output in data provider
|
||||||
|
self.data_provider.store_model_output(model_output)
|
||||||
|
|
||||||
|
logger.debug(f"CNN prediction for {symbol}: {action} ({model_output.confidence:.3f})")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error making prediction for {symbol}: {e}")
|
||||||
|
|
||||||
|
def add_training_sample(self, symbol: str, actual_action: str, reward: float):
|
||||||
|
"""Add a training sample and trigger training if enabled"""
|
||||||
|
try:
|
||||||
|
if not self.training_enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Get base data for the symbol
|
||||||
|
base_data = self.data_provider.get_base_data_input(symbol)
|
||||||
|
|
||||||
|
if base_data is None:
|
||||||
|
logger.debug(f"No base data available for training sample: {symbol}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Add training sample
|
||||||
|
self.cnn_adapter.add_training_sample(base_data, actual_action, reward)
|
||||||
|
|
||||||
|
# Update metrics
|
||||||
|
self.performance_metrics['total_training_samples'] += 1
|
||||||
|
|
||||||
|
# Train model periodically (every 10 samples)
|
||||||
|
if self.performance_metrics['total_training_samples'] % 10 == 0:
|
||||||
|
self._train_model()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error adding training sample: {e}")
|
||||||
|
|
||||||
|
def _train_model(self):
|
||||||
|
"""Train the CNN model"""
|
||||||
|
try:
|
||||||
|
with self.training_lock:
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Train model
|
||||||
|
metrics = self.cnn_adapter.train(epochs=1)
|
||||||
|
|
||||||
|
# Record training time
|
||||||
|
training_time = time.time() - start_time
|
||||||
|
self.performance_metrics['training_times'].append(training_time)
|
||||||
|
|
||||||
|
# Update performance metrics
|
||||||
|
self.performance_metrics['last_training_time'] = datetime.now()
|
||||||
|
|
||||||
|
if 'loss' in metrics:
|
||||||
|
self.performance_metrics['training_loss_history'].append(metrics['loss'])
|
||||||
|
|
||||||
|
if 'accuracy' in metrics:
|
||||||
|
self.performance_metrics['accuracy_history'].append(metrics['accuracy'])
|
||||||
|
|
||||||
|
# Update model status
|
||||||
|
if metrics.get('accuracy', 0) > 0.5:
|
||||||
|
self.performance_metrics['model_status'] = 'TRAINED'
|
||||||
|
else:
|
||||||
|
self.performance_metrics['model_status'] = 'TRAINING'
|
||||||
|
|
||||||
|
logger.info(f"CNN training completed: loss={metrics.get('loss', 0):.4f}, accuracy={metrics.get('accuracy', 0):.4f}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error training CNN model: {e}")
|
||||||
|
|
||||||
|
def _update_performance_metrics(self):
|
||||||
|
"""Update performance metrics for dashboard display"""
|
||||||
|
try:
|
||||||
|
current_time = time.time()
|
||||||
|
|
||||||
|
# Calculate predictions per second (last 60 seconds)
|
||||||
|
recent_inferences = [t for t in self.performance_metrics['inference_times']
|
||||||
|
if current_time - t <= 60]
|
||||||
|
self.performance_metrics['predictions_per_second'] = len(recent_inferences) / 60.0
|
||||||
|
|
||||||
|
# Calculate training per second (last 60 seconds)
|
||||||
|
recent_trainings = [t for t in self.performance_metrics['training_times']
|
||||||
|
if current_time - t <= 60]
|
||||||
|
self.performance_metrics['training_per_second'] = len(recent_trainings) / 60.0
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error updating performance metrics: {e}")
|
||||||
|
|
||||||
|
def get_dashboard_metrics(self) -> Dict[str, Any]:
|
||||||
|
"""Get metrics for dashboard display"""
|
||||||
|
try:
|
||||||
|
# Calculate current loss
|
||||||
|
current_loss = (self.performance_metrics['training_loss_history'][-1]
|
||||||
|
if self.performance_metrics['training_loss_history'] else 0.0)
|
||||||
|
|
||||||
|
# Calculate current accuracy
|
||||||
|
current_accuracy = (self.performance_metrics['accuracy_history'][-1]
|
||||||
|
if self.performance_metrics['accuracy_history'] else 0.0)
|
||||||
|
|
||||||
|
# Calculate average confidence
|
||||||
|
avg_confidence = (np.mean(list(self.performance_metrics['confidence_history']))
|
||||||
|
if self.performance_metrics['confidence_history'] else 0.0)
|
||||||
|
|
||||||
|
# Get latest prediction
|
||||||
|
latest_prediction = None
|
||||||
|
latest_symbol = None
|
||||||
|
for symbol, prediction in self.prediction_cache.items():
|
||||||
|
if latest_prediction is None or prediction.timestamp > latest_prediction.timestamp:
|
||||||
|
latest_prediction = prediction
|
||||||
|
latest_symbol = symbol
|
||||||
|
|
||||||
|
# Format timing information
|
||||||
|
last_inference_str = "None"
|
||||||
|
last_training_str = "None"
|
||||||
|
|
||||||
|
if self.performance_metrics['last_inference_time']:
|
||||||
|
last_inference_str = self.performance_metrics['last_inference_time'].strftime("%H:%M:%S")
|
||||||
|
|
||||||
|
if self.performance_metrics['last_training_time']:
|
||||||
|
last_training_str = self.performance_metrics['last_training_time'].strftime("%H:%M:%S")
|
||||||
|
|
||||||
|
return {
|
||||||
|
'model_name': 'CNN',
|
||||||
|
'model_type': 'cnn',
|
||||||
|
'parameters': '50.0M',
|
||||||
|
'status': self.performance_metrics['model_status'],
|
||||||
|
'current_loss': current_loss,
|
||||||
|
'accuracy': current_accuracy,
|
||||||
|
'confidence': avg_confidence,
|
||||||
|
'total_predictions': self.performance_metrics['total_predictions'],
|
||||||
|
'total_training_samples': self.performance_metrics['total_training_samples'],
|
||||||
|
'predictions_per_second': self.performance_metrics['predictions_per_second'],
|
||||||
|
'training_per_second': self.performance_metrics['training_per_second'],
|
||||||
|
'last_inference': last_inference_str,
|
||||||
|
'last_training': last_training_str,
|
||||||
|
'latest_prediction': {
|
||||||
|
'action': latest_prediction.predictions['action'] if latest_prediction else 'HOLD',
|
||||||
|
'confidence': latest_prediction.confidence if latest_prediction else 0.0,
|
||||||
|
'symbol': latest_symbol or 'ETH/USDT',
|
||||||
|
'timestamp': latest_prediction.timestamp.strftime("%H:%M:%S") if latest_prediction else "None"
|
||||||
|
},
|
||||||
|
'action_distribution': self.performance_metrics['action_distribution'].copy(),
|
||||||
|
'training_enabled': self.training_enabled,
|
||||||
|
'inference_enabled': self.inference_enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting dashboard metrics: {e}")
|
||||||
|
return {
|
||||||
|
'model_name': 'CNN',
|
||||||
|
'model_type': 'cnn',
|
||||||
|
'parameters': '50.0M',
|
||||||
|
'status': 'ERROR',
|
||||||
|
'current_loss': 0.0,
|
||||||
|
'accuracy': 0.0,
|
||||||
|
'confidence': 0.0,
|
||||||
|
'error': str(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_predictions_for_chart(self, symbol: str, timeframe: str = '1s', limit: int = 100) -> List[Dict[str, Any]]:
|
||||||
|
"""Get predictions for chart overlay"""
|
||||||
|
try:
|
||||||
|
if symbol not in self.prediction_history:
|
||||||
|
return []
|
||||||
|
|
||||||
|
predictions = list(self.prediction_history[symbol])[-limit:]
|
||||||
|
|
||||||
|
chart_data = []
|
||||||
|
for prediction in predictions:
|
||||||
|
chart_data.append({
|
||||||
|
'timestamp': prediction.timestamp,
|
||||||
|
'action': prediction.predictions['action'],
|
||||||
|
'confidence': prediction.confidence,
|
||||||
|
'buy_probability': prediction.predictions.get('buy_probability', 0.0),
|
||||||
|
'sell_probability': prediction.predictions.get('sell_probability', 0.0),
|
||||||
|
'hold_probability': prediction.predictions.get('hold_probability', 0.0)
|
||||||
|
})
|
||||||
|
|
||||||
|
return chart_data
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting predictions for chart: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def set_training_enabled(self, enabled: bool):
|
||||||
|
"""Enable or disable training"""
|
||||||
|
self.training_enabled = enabled
|
||||||
|
logger.info(f"CNN training {'enabled' if enabled else 'disabled'}")
|
||||||
|
|
||||||
|
def set_inference_enabled(self, enabled: bool):
|
||||||
|
"""Enable or disable inference"""
|
||||||
|
self.inference_enabled = enabled
|
||||||
|
logger.info(f"CNN inference {'enabled' if enabled else 'disabled'}")
|
||||||
|
|
||||||
|
def get_model_info(self) -> Dict[str, Any]:
|
||||||
|
"""Get model information for dashboard"""
|
||||||
|
return {
|
||||||
|
'name': 'Enhanced CNN',
|
||||||
|
'version': '1.0',
|
||||||
|
'parameters': '50.0M',
|
||||||
|
'input_shape': self.cnn_adapter.model.input_shape if self.cnn_adapter.model else 'Unknown',
|
||||||
|
'device': str(self.cnn_adapter.device),
|
||||||
|
'checkpoint_dir': self.cnn_adapter.checkpoint_dir,
|
||||||
|
'training_samples': len(self.cnn_adapter.training_data),
|
||||||
|
'max_training_samples': self.cnn_adapter.max_training_samples
|
||||||
|
}
|
403
core/enhanced_cnn_integration.py
Normal file
403
core/enhanced_cnn_integration.py
Normal file
@ -0,0 +1,403 @@
|
|||||||
|
"""
|
||||||
|
Enhanced CNN Integration for Dashboard
|
||||||
|
|
||||||
|
This module integrates the EnhancedCNNAdapter with the dashboard, providing real-time
|
||||||
|
training and inference capabilities.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Dict, List, Optional, Any, Union
|
||||||
|
import os
|
||||||
|
|
||||||
|
from .enhanced_cnn_adapter import EnhancedCNNAdapter
|
||||||
|
from .standardized_data_provider import StandardizedDataProvider
|
||||||
|
from .data_models import BaseDataInput, ModelOutput, create_model_output
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class EnhancedCNNIntegration:
|
||||||
|
"""
|
||||||
|
Integration of EnhancedCNNAdapter with the dashboard
|
||||||
|
|
||||||
|
This class:
|
||||||
|
1. Manages the EnhancedCNNAdapter lifecycle
|
||||||
|
2. Provides real-time training and inference
|
||||||
|
3. Collects and reports performance metrics
|
||||||
|
4. Integrates with the dashboard's model visualization
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, data_provider: StandardizedDataProvider, checkpoint_dir: str = "models/enhanced_cnn"):
|
||||||
|
"""
|
||||||
|
Initialize the EnhancedCNNIntegration
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_provider: StandardizedDataProvider instance
|
||||||
|
checkpoint_dir: Directory to store checkpoints
|
||||||
|
"""
|
||||||
|
self.data_provider = data_provider
|
||||||
|
self.checkpoint_dir = checkpoint_dir
|
||||||
|
self.model_name = "enhanced_cnn_v1"
|
||||||
|
|
||||||
|
# Create checkpoint directory if it doesn't exist
|
||||||
|
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# Initialize CNN adapter
|
||||||
|
self.cnn_adapter = EnhancedCNNAdapter(checkpoint_dir=checkpoint_dir)
|
||||||
|
|
||||||
|
# Load best checkpoint if available
|
||||||
|
self.cnn_adapter.load_best_checkpoint()
|
||||||
|
|
||||||
|
# Performance tracking
|
||||||
|
self.inference_times = []
|
||||||
|
self.training_times = []
|
||||||
|
self.total_inferences = 0
|
||||||
|
self.total_training_runs = 0
|
||||||
|
self.last_inference_time = None
|
||||||
|
self.last_training_time = None
|
||||||
|
self.inference_rate = 0.0
|
||||||
|
self.training_rate = 0.0
|
||||||
|
self.daily_inferences = 0
|
||||||
|
self.daily_training_runs = 0
|
||||||
|
|
||||||
|
# Training settings
|
||||||
|
self.training_enabled = True
|
||||||
|
self.inference_enabled = True
|
||||||
|
self.training_frequency = 10 # Train every N inferences
|
||||||
|
self.training_batch_size = 32
|
||||||
|
self.training_epochs = 1
|
||||||
|
|
||||||
|
# Latest prediction
|
||||||
|
self.latest_prediction = None
|
||||||
|
self.latest_prediction_time = None
|
||||||
|
|
||||||
|
# Training metrics
|
||||||
|
self.current_loss = 0.0
|
||||||
|
self.initial_loss = None
|
||||||
|
self.best_loss = None
|
||||||
|
self.current_accuracy = 0.0
|
||||||
|
self.improvement_percentage = 0.0
|
||||||
|
|
||||||
|
# Training thread
|
||||||
|
self.training_thread = None
|
||||||
|
self.training_active = False
|
||||||
|
self.stop_training = False
|
||||||
|
|
||||||
|
logger.info(f"EnhancedCNNIntegration initialized with model: {self.model_name}")
|
||||||
|
|
||||||
|
def start_continuous_training(self):
|
||||||
|
"""Start continuous training in a background thread"""
|
||||||
|
if self.training_thread is not None and self.training_thread.is_alive():
|
||||||
|
logger.info("Continuous training already running")
|
||||||
|
return
|
||||||
|
|
||||||
|
self.stop_training = False
|
||||||
|
self.training_thread = threading.Thread(target=self._continuous_training_loop, daemon=True)
|
||||||
|
self.training_thread.start()
|
||||||
|
logger.info("Started continuous training thread")
|
||||||
|
|
||||||
|
def stop_continuous_training(self):
|
||||||
|
"""Stop continuous training"""
|
||||||
|
self.stop_training = True
|
||||||
|
logger.info("Stopping continuous training thread")
|
||||||
|
|
||||||
|
def _continuous_training_loop(self):
|
||||||
|
"""Continuous training loop"""
|
||||||
|
try:
|
||||||
|
self.training_active = True
|
||||||
|
logger.info("Starting continuous training loop")
|
||||||
|
|
||||||
|
while not self.stop_training:
|
||||||
|
# Check if training is enabled
|
||||||
|
if not self.training_enabled:
|
||||||
|
time.sleep(5)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if we have enough training samples
|
||||||
|
if len(self.cnn_adapter.training_data) < self.training_batch_size:
|
||||||
|
logger.debug(f"Not enough training samples: {len(self.cnn_adapter.training_data)}/{self.training_batch_size}")
|
||||||
|
time.sleep(5)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Train model
|
||||||
|
start_time = time.time()
|
||||||
|
metrics = self.cnn_adapter.train(epochs=self.training_epochs)
|
||||||
|
training_time = time.time() - start_time
|
||||||
|
|
||||||
|
# Update metrics
|
||||||
|
self.training_times.append(training_time)
|
||||||
|
if len(self.training_times) > 100:
|
||||||
|
self.training_times.pop(0)
|
||||||
|
|
||||||
|
self.total_training_runs += 1
|
||||||
|
self.daily_training_runs += 1
|
||||||
|
self.last_training_time = datetime.now()
|
||||||
|
|
||||||
|
# Calculate training rate
|
||||||
|
if self.training_times:
|
||||||
|
avg_training_time = sum(self.training_times) / len(self.training_times)
|
||||||
|
self.training_rate = 1.0 / avg_training_time if avg_training_time > 0 else 0.0
|
||||||
|
|
||||||
|
# Update loss and accuracy
|
||||||
|
self.current_loss = metrics.get('loss', 0.0)
|
||||||
|
self.current_accuracy = metrics.get('accuracy', 0.0)
|
||||||
|
|
||||||
|
# Update initial loss if not set
|
||||||
|
if self.initial_loss is None:
|
||||||
|
self.initial_loss = self.current_loss
|
||||||
|
|
||||||
|
# Update best loss
|
||||||
|
if self.best_loss is None or self.current_loss < self.best_loss:
|
||||||
|
self.best_loss = self.current_loss
|
||||||
|
|
||||||
|
# Calculate improvement percentage
|
||||||
|
if self.initial_loss is not None and self.initial_loss > 0:
|
||||||
|
self.improvement_percentage = ((self.initial_loss - self.current_loss) / self.initial_loss) * 100
|
||||||
|
|
||||||
|
logger.info(f"Training completed: loss={self.current_loss:.4f}, accuracy={self.current_accuracy:.4f}, samples={metrics.get('samples', 0)}")
|
||||||
|
|
||||||
|
# Sleep before next training
|
||||||
|
time.sleep(10)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in continuous training loop: {e}")
|
||||||
|
finally:
|
||||||
|
self.training_active = False
|
||||||
|
|
||||||
|
def predict(self, symbol: str) -> Optional[ModelOutput]:
|
||||||
|
"""
|
||||||
|
Make a prediction using the EnhancedCNN model
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: Trading symbol
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ModelOutput: Standardized model output
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Check if inference is enabled
|
||||||
|
if not self.inference_enabled:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Get standardized input data
|
||||||
|
base_data = self.data_provider.get_base_data_input(symbol)
|
||||||
|
|
||||||
|
if base_data is None:
|
||||||
|
logger.warning(f"Failed to get base data input for {symbol}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Make prediction
|
||||||
|
start_time = time.time()
|
||||||
|
model_output = self.cnn_adapter.predict(base_data)
|
||||||
|
inference_time = time.time() - start_time
|
||||||
|
|
||||||
|
# Update metrics
|
||||||
|
self.inference_times.append(inference_time)
|
||||||
|
if len(self.inference_times) > 100:
|
||||||
|
self.inference_times.pop(0)
|
||||||
|
|
||||||
|
self.total_inferences += 1
|
||||||
|
self.daily_inferences += 1
|
||||||
|
self.last_inference_time = datetime.now()
|
||||||
|
|
||||||
|
# Calculate inference rate
|
||||||
|
if self.inference_times:
|
||||||
|
avg_inference_time = sum(self.inference_times) / len(self.inference_times)
|
||||||
|
self.inference_rate = 1.0 / avg_inference_time if avg_inference_time > 0 else 0.0
|
||||||
|
|
||||||
|
# Store latest prediction
|
||||||
|
self.latest_prediction = model_output
|
||||||
|
self.latest_prediction_time = datetime.now()
|
||||||
|
|
||||||
|
# Store model output in data provider
|
||||||
|
self.data_provider.store_model_output(model_output)
|
||||||
|
|
||||||
|
# Add training sample if we have a price
|
||||||
|
current_price = self._get_current_price(symbol)
|
||||||
|
if current_price and current_price > 0:
|
||||||
|
# Simulate market feedback based on price movement
|
||||||
|
# In a real system, this would be replaced with actual market performance data
|
||||||
|
action = model_output.predictions['action']
|
||||||
|
|
||||||
|
# For demonstration, we'll use a simple heuristic:
|
||||||
|
# - If price is above 3000, BUY is good
|
||||||
|
# - If price is below 3000, SELL is good
|
||||||
|
# - Otherwise, HOLD is good
|
||||||
|
if current_price > 3000:
|
||||||
|
best_action = 'BUY'
|
||||||
|
elif current_price < 3000:
|
||||||
|
best_action = 'SELL'
|
||||||
|
else:
|
||||||
|
best_action = 'HOLD'
|
||||||
|
|
||||||
|
# Calculate reward based on whether the action matched the best action
|
||||||
|
if action == best_action:
|
||||||
|
reward = 0.05 # Positive reward for correct action
|
||||||
|
else:
|
||||||
|
reward = -0.05 # Negative reward for incorrect action
|
||||||
|
|
||||||
|
# Add training sample
|
||||||
|
self.cnn_adapter.add_training_sample(base_data, best_action, reward)
|
||||||
|
|
||||||
|
logger.debug(f"Added training sample for {symbol}, action: {action}, best_action: {best_action}, reward: {reward:.4f}")
|
||||||
|
|
||||||
|
return model_output
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error making prediction: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_current_price(self, symbol: str) -> Optional[float]:
|
||||||
|
"""Get current price for a symbol"""
|
||||||
|
try:
|
||||||
|
# Try to get price from data provider
|
||||||
|
if hasattr(self.data_provider, 'current_prices'):
|
||||||
|
binance_symbol = symbol.replace('/', '').upper()
|
||||||
|
if binance_symbol in self.data_provider.current_prices:
|
||||||
|
return self.data_provider.current_prices[binance_symbol]
|
||||||
|
|
||||||
|
# Try to get price from latest OHLCV data
|
||||||
|
df = self.data_provider.get_historical_data(symbol, '1s', 1)
|
||||||
|
if df is not None and not df.empty:
|
||||||
|
return float(df.iloc[-1]['close'])
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting current price: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_model_state(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get model state for dashboard display
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: Model state
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Format prediction for display
|
||||||
|
prediction_info = "FRESH"
|
||||||
|
confidence = 0.0
|
||||||
|
|
||||||
|
if self.latest_prediction:
|
||||||
|
action = self.latest_prediction.predictions.get('action', 'UNKNOWN')
|
||||||
|
confidence = self.latest_prediction.confidence
|
||||||
|
|
||||||
|
# Map action to display text
|
||||||
|
if action == 'BUY':
|
||||||
|
prediction_info = "BUY_SIGNAL"
|
||||||
|
elif action == 'SELL':
|
||||||
|
prediction_info = "SELL_SIGNAL"
|
||||||
|
elif action == 'HOLD':
|
||||||
|
prediction_info = "HOLD_SIGNAL"
|
||||||
|
else:
|
||||||
|
prediction_info = "PATTERN_ANALYSIS"
|
||||||
|
|
||||||
|
# Format timing information
|
||||||
|
inference_timing = "None"
|
||||||
|
training_timing = "None"
|
||||||
|
|
||||||
|
if self.last_inference_time:
|
||||||
|
inference_timing = self.last_inference_time.strftime('%H:%M:%S')
|
||||||
|
|
||||||
|
if self.last_training_time:
|
||||||
|
training_timing = self.last_training_time.strftime('%H:%M:%S')
|
||||||
|
|
||||||
|
# Calculate improvement percentage
|
||||||
|
improvement = 0.0
|
||||||
|
if self.initial_loss is not None and self.initial_loss > 0 and self.current_loss > 0:
|
||||||
|
improvement = ((self.initial_loss - self.current_loss) / self.initial_loss) * 100
|
||||||
|
|
||||||
|
return {
|
||||||
|
'model_name': self.model_name,
|
||||||
|
'model_type': 'cnn',
|
||||||
|
'parameters': 50000000, # 50M parameters
|
||||||
|
'status': 'ACTIVE' if self.inference_enabled else 'DISABLED',
|
||||||
|
'checkpoint_loaded': True, # Assume checkpoint is loaded
|
||||||
|
'last_prediction': prediction_info,
|
||||||
|
'confidence': confidence * 100, # Convert to percentage
|
||||||
|
'last_inference_time': inference_timing,
|
||||||
|
'last_training_time': training_timing,
|
||||||
|
'inference_rate': self.inference_rate,
|
||||||
|
'training_rate': self.training_rate,
|
||||||
|
'daily_inferences': self.daily_inferences,
|
||||||
|
'daily_training_runs': self.daily_training_runs,
|
||||||
|
'initial_loss': self.initial_loss,
|
||||||
|
'current_loss': self.current_loss,
|
||||||
|
'best_loss': self.best_loss,
|
||||||
|
'current_accuracy': self.current_accuracy,
|
||||||
|
'improvement_percentage': improvement,
|
||||||
|
'training_active': self.training_active,
|
||||||
|
'training_enabled': self.training_enabled,
|
||||||
|
'inference_enabled': self.inference_enabled,
|
||||||
|
'training_samples': len(self.cnn_adapter.training_data)
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting model state: {e}")
|
||||||
|
return {
|
||||||
|
'model_name': self.model_name,
|
||||||
|
'model_type': 'cnn',
|
||||||
|
'parameters': 50000000, # 50M parameters
|
||||||
|
'status': 'ERROR',
|
||||||
|
'error': str(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_pivot_prediction(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get pivot prediction for dashboard display
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: Pivot prediction
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not self.latest_prediction:
|
||||||
|
return {
|
||||||
|
'next_pivot': 0.0,
|
||||||
|
'pivot_type': 'UNKNOWN',
|
||||||
|
'confidence': 0.0,
|
||||||
|
'time_to_pivot': 0
|
||||||
|
}
|
||||||
|
|
||||||
|
# Extract pivot prediction from model output
|
||||||
|
extrema_pred = self.latest_prediction.predictions.get('extrema', [0, 0, 0])
|
||||||
|
|
||||||
|
# Determine pivot type (0=bottom, 1=top, 2=neither)
|
||||||
|
pivot_type_idx = extrema_pred.index(max(extrema_pred))
|
||||||
|
pivot_types = ['BOTTOM', 'TOP', 'RANGE_CONTINUATION']
|
||||||
|
pivot_type = pivot_types[pivot_type_idx]
|
||||||
|
|
||||||
|
# Get current price
|
||||||
|
current_price = self._get_current_price('ETH/USDT') or 0.0
|
||||||
|
|
||||||
|
# Calculate next pivot price (simple heuristic for demonstration)
|
||||||
|
if pivot_type == 'BOTTOM':
|
||||||
|
next_pivot = current_price * 0.95 # 5% below current price
|
||||||
|
elif pivot_type == 'TOP':
|
||||||
|
next_pivot = current_price * 1.05 # 5% above current price
|
||||||
|
else:
|
||||||
|
next_pivot = current_price # Same as current price
|
||||||
|
|
||||||
|
# Calculate confidence
|
||||||
|
confidence = max(extrema_pred) * 100 # Convert to percentage
|
||||||
|
|
||||||
|
# Calculate time to pivot (simple heuristic for demonstration)
|
||||||
|
time_to_pivot = 5 # 5 minutes
|
||||||
|
|
||||||
|
return {
|
||||||
|
'next_pivot': next_pivot,
|
||||||
|
'pivot_type': pivot_type,
|
||||||
|
'confidence': confidence,
|
||||||
|
'time_to_pivot': time_to_pivot
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting pivot prediction: {e}")
|
||||||
|
return {
|
||||||
|
'next_pivot': 0.0,
|
||||||
|
'pivot_type': 'ERROR',
|
||||||
|
'confidence': 0.0,
|
||||||
|
'time_to_pivot': 0
|
||||||
|
}
|
123
test_cob_data_stability.py
Normal file
123
test_cob_data_stability.py
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from collections import deque
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
from matplotlib.colors import LogNorm
|
||||||
|
|
||||||
|
from core.data_provider import DataProvider, MarketTick
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class COBStabilityTester:
|
||||||
|
def __init__(self, symbol='ETH/USDT', duration_seconds=15):
|
||||||
|
self.symbol = symbol
|
||||||
|
self.duration = timedelta(seconds=duration_seconds)
|
||||||
|
self.ticks = deque()
|
||||||
|
self.data_provider = DataProvider(symbols=[self.symbol], timeframes=['1s'])
|
||||||
|
self.start_time = None
|
||||||
|
self.subscriber_id = None
|
||||||
|
|
||||||
|
def _tick_callback(self, tick: MarketTick):
|
||||||
|
"""Callback function to receive ticks from the DataProvider."""
|
||||||
|
if self.start_time is None:
|
||||||
|
self.start_time = datetime.now()
|
||||||
|
logger.info(f"Started collecting ticks at {self.start_time}")
|
||||||
|
|
||||||
|
# Store all ticks
|
||||||
|
self.ticks.append(tick)
|
||||||
|
|
||||||
|
async def run_test(self):
|
||||||
|
"""Run the data collection and plotting test."""
|
||||||
|
logger.info(f"Starting COB stability test for {self.symbol} for {self.duration.total_seconds()} seconds...")
|
||||||
|
|
||||||
|
# Subscribe to ticks
|
||||||
|
self.subscriber_id = self.data_provider.subscribe_to_ticks(self._tick_callback, symbols=[self.symbol])
|
||||||
|
|
||||||
|
# Start the data provider's real-time streaming
|
||||||
|
await self.data_provider.start_real_time_streaming()
|
||||||
|
|
||||||
|
# Collect data for the specified duration
|
||||||
|
self.start_time = datetime.now()
|
||||||
|
while datetime.now() - self.start_time < self.duration:
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
logger.info(f"Collected {len(self.ticks)} ticks so far...")
|
||||||
|
|
||||||
|
# Stop streaming and unsubscribe
|
||||||
|
await self.data_provider.stop_real_time_streaming()
|
||||||
|
self.data_provider.unsubscribe_from_ticks(self.subscriber_id)
|
||||||
|
|
||||||
|
logger.info(f"Finished collecting data. Total ticks: {len(self.ticks)}")
|
||||||
|
|
||||||
|
# Plot the results
|
||||||
|
if self.ticks:
|
||||||
|
self.plot_spectrogram()
|
||||||
|
else:
|
||||||
|
logger.warning("No ticks were collected. Cannot generate plot.")
|
||||||
|
|
||||||
|
def plot_spectrogram(self):
|
||||||
|
"""Create a spectrogram-like plot of trade intensity."""
|
||||||
|
if not self.ticks:
|
||||||
|
logger.warning("No ticks to plot.")
|
||||||
|
return
|
||||||
|
|
||||||
|
df = pd.DataFrame([{
|
||||||
|
'timestamp': tick.timestamp,
|
||||||
|
'price': tick.price,
|
||||||
|
'volume': tick.volume,
|
||||||
|
'side': 1 if tick.side == 'buy' else -1
|
||||||
|
} for tick in self.ticks])
|
||||||
|
|
||||||
|
df['timestamp'] = pd.to_datetime(df['timestamp'])
|
||||||
|
df = df.set_index('timestamp')
|
||||||
|
|
||||||
|
# Create the plot
|
||||||
|
fig, ax = plt.subplots(figsize=(15, 8))
|
||||||
|
|
||||||
|
# Define bins for the 2D histogram
|
||||||
|
time_bins = pd.date_range(df.index.min(), df.index.max(), periods=100)
|
||||||
|
price_bins = np.linspace(df['price'].min(), df['price'].max(), 100)
|
||||||
|
|
||||||
|
# Create the 2D histogram
|
||||||
|
# x-axis: time, y-axis: price, weights: volume
|
||||||
|
h, xedges, yedges = np.histogram2d(
|
||||||
|
df.index.astype(np.int64) // 10**9,
|
||||||
|
df['price'],
|
||||||
|
bins=[time_bins.astype(np.int64) // 10**9, price_bins],
|
||||||
|
weights=df['volume']
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use a logarithmic color scale for better visibility of smaller trades
|
||||||
|
pcm = ax.pcolormesh(time_bins, price_bins, h.T, norm=LogNorm(vmin=1e-3, vmax=h.max()), cmap='inferno')
|
||||||
|
|
||||||
|
fig.colorbar(pcm, ax=ax, label='Trade Volume (USDT)')
|
||||||
|
ax.set_title(f'Trade Intensity Spectrogram for {self.symbol}')
|
||||||
|
ax.set_xlabel('Time')
|
||||||
|
ax.set_ylabel('Price (USDT)')
|
||||||
|
|
||||||
|
# Format the x-axis to show time properly
|
||||||
|
fig.autofmt_xdate()
|
||||||
|
|
||||||
|
plot_filename = f"cob_stability_spectrogram_{self.symbol.replace('/', '_')}_{datetime.now():%Y%m%d_%H%M%S}.png"
|
||||||
|
plt.savefig(plot_filename)
|
||||||
|
logger.info(f"Plot saved to {plot_filename}")
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
tester = COBStabilityTester()
|
||||||
|
await tester.run_test()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
asyncio.run(main())
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("Test interrupted by user.")
|
Reference in New Issue
Block a user