diff --git a/core/dashboard_cnn_integration.py b/core/dashboard_cnn_integration.py new file mode 100644 index 0000000..7919f59 --- /dev/null +++ b/core/dashboard_cnn_integration.py @@ -0,0 +1,365 @@ +""" +Dashboard CNN Integration + +This module integrates the EnhancedCNNAdapter with the dashboard system, +providing real-time training, predictions, and performance metrics display. +""" + +import logging +import time +import threading +from datetime import datetime, timedelta +from typing import Dict, List, Optional, Any, Tuple +from collections import deque +import numpy as np + +from .enhanced_cnn_adapter import EnhancedCNNAdapter +from .standardized_data_provider import StandardizedDataProvider +from .data_models import BaseDataInput, ModelOutput, create_model_output + +logger = logging.getLogger(__name__) + +class DashboardCNNIntegration: + """ + CNN integration for the dashboard system + + This class: + 1. Manages CNN model lifecycle in the dashboard + 2. Provides real-time training and inference + 3. Tracks performance metrics for dashboard display + 4. Handles model predictions for chart overlay + """ + + def __init__(self, data_provider: StandardizedDataProvider, symbols: List[str] = None): + """ + Initialize the dashboard CNN integration + + Args: + data_provider: Standardized data provider + symbols: List of symbols to process + """ + self.data_provider = data_provider + self.symbols = symbols or ['ETH/USDT', 'BTC/USDT'] + + # Initialize CNN adapter + self.cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn") + + # Load best checkpoint if available + self.cnn_adapter.load_best_checkpoint() + + # Performance tracking + self.performance_metrics = { + 'total_predictions': 0, + 'total_training_samples': 0, + 'last_training_time': None, + 'last_inference_time': None, + 'training_loss_history': deque(maxlen=100), + 'accuracy_history': deque(maxlen=100), + 'inference_times': deque(maxlen=100), + 'training_times': deque(maxlen=100), + 'predictions_per_second': 0.0, + 'training_per_second': 0.0, + 'model_status': 'FRESH', + 'confidence_history': deque(maxlen=100), + 'action_distribution': {'BUY': 0, 'SELL': 0, 'HOLD': 0} + } + + # Prediction cache for dashboard display + self.prediction_cache = {} + self.prediction_history = {symbol: deque(maxlen=1000) for symbol in self.symbols} + + # Training control + self.training_enabled = True + self.inference_enabled = True + self.training_lock = threading.Lock() + + # Real-time processing + self.is_running = False + self.processing_thread = None + + logger.info(f"DashboardCNNIntegration initialized for symbols: {self.symbols}") + + def start_real_time_processing(self): + """Start real-time CNN processing""" + if self.is_running: + logger.warning("Real-time processing already running") + return + + self.is_running = True + self.processing_thread = threading.Thread(target=self._real_time_processing_loop, daemon=True) + self.processing_thread.start() + + logger.info("Started real-time CNN processing") + + def stop_real_time_processing(self): + """Stop real-time CNN processing""" + self.is_running = False + if self.processing_thread: + self.processing_thread.join(timeout=5) + + logger.info("Stopped real-time CNN processing") + + def _real_time_processing_loop(self): + """Main real-time processing loop""" + last_prediction_time = {} + prediction_interval = 1.0 # Make prediction every 1 second + + while self.is_running: + try: + current_time = time.time() + + for symbol in self.symbols: + # Check if it's time to make a prediction for this symbol + if (symbol not in last_prediction_time or + current_time - last_prediction_time[symbol] >= prediction_interval): + + # Make prediction if inference is enabled + if self.inference_enabled: + self._make_prediction(symbol) + last_prediction_time[symbol] = current_time + + # Update performance metrics + self._update_performance_metrics() + + # Sleep briefly to prevent overwhelming the system + time.sleep(0.1) + + except Exception as e: + logger.error(f"Error in real-time processing loop: {e}") + time.sleep(1) + + def _make_prediction(self, symbol: str): + """Make a prediction for a symbol""" + try: + start_time = time.time() + + # Get standardized input data + base_data = self.data_provider.get_base_data_input(symbol) + + if base_data is None: + logger.debug(f"No base data available for {symbol}") + return + + # Make prediction + model_output = self.cnn_adapter.predict(base_data) + + # Record inference time + inference_time = time.time() - start_time + self.performance_metrics['inference_times'].append(inference_time) + + # Update performance metrics + self.performance_metrics['total_predictions'] += 1 + self.performance_metrics['last_inference_time'] = datetime.now() + self.performance_metrics['confidence_history'].append(model_output.confidence) + + # Update action distribution + action = model_output.predictions['action'] + self.performance_metrics['action_distribution'][action] += 1 + + # Cache prediction for dashboard + self.prediction_cache[symbol] = model_output + self.prediction_history[symbol].append(model_output) + + # Store model output in data provider + self.data_provider.store_model_output(model_output) + + logger.debug(f"CNN prediction for {symbol}: {action} ({model_output.confidence:.3f})") + + except Exception as e: + logger.error(f"Error making prediction for {symbol}: {e}") + + def add_training_sample(self, symbol: str, actual_action: str, reward: float): + """Add a training sample and trigger training if enabled""" + try: + if not self.training_enabled: + return + + # Get base data for the symbol + base_data = self.data_provider.get_base_data_input(symbol) + + if base_data is None: + logger.debug(f"No base data available for training sample: {symbol}") + return + + # Add training sample + self.cnn_adapter.add_training_sample(base_data, actual_action, reward) + + # Update metrics + self.performance_metrics['total_training_samples'] += 1 + + # Train model periodically (every 10 samples) + if self.performance_metrics['total_training_samples'] % 10 == 0: + self._train_model() + + except Exception as e: + logger.error(f"Error adding training sample: {e}") + + def _train_model(self): + """Train the CNN model""" + try: + with self.training_lock: + start_time = time.time() + + # Train model + metrics = self.cnn_adapter.train(epochs=1) + + # Record training time + training_time = time.time() - start_time + self.performance_metrics['training_times'].append(training_time) + + # Update performance metrics + self.performance_metrics['last_training_time'] = datetime.now() + + if 'loss' in metrics: + self.performance_metrics['training_loss_history'].append(metrics['loss']) + + if 'accuracy' in metrics: + self.performance_metrics['accuracy_history'].append(metrics['accuracy']) + + # Update model status + if metrics.get('accuracy', 0) > 0.5: + self.performance_metrics['model_status'] = 'TRAINED' + else: + self.performance_metrics['model_status'] = 'TRAINING' + + logger.info(f"CNN training completed: loss={metrics.get('loss', 0):.4f}, accuracy={metrics.get('accuracy', 0):.4f}") + + except Exception as e: + logger.error(f"Error training CNN model: {e}") + + def _update_performance_metrics(self): + """Update performance metrics for dashboard display""" + try: + current_time = time.time() + + # Calculate predictions per second (last 60 seconds) + recent_inferences = [t for t in self.performance_metrics['inference_times'] + if current_time - t <= 60] + self.performance_metrics['predictions_per_second'] = len(recent_inferences) / 60.0 + + # Calculate training per second (last 60 seconds) + recent_trainings = [t for t in self.performance_metrics['training_times'] + if current_time - t <= 60] + self.performance_metrics['training_per_second'] = len(recent_trainings) / 60.0 + + except Exception as e: + logger.error(f"Error updating performance metrics: {e}") + + def get_dashboard_metrics(self) -> Dict[str, Any]: + """Get metrics for dashboard display""" + try: + # Calculate current loss + current_loss = (self.performance_metrics['training_loss_history'][-1] + if self.performance_metrics['training_loss_history'] else 0.0) + + # Calculate current accuracy + current_accuracy = (self.performance_metrics['accuracy_history'][-1] + if self.performance_metrics['accuracy_history'] else 0.0) + + # Calculate average confidence + avg_confidence = (np.mean(list(self.performance_metrics['confidence_history'])) + if self.performance_metrics['confidence_history'] else 0.0) + + # Get latest prediction + latest_prediction = None + latest_symbol = None + for symbol, prediction in self.prediction_cache.items(): + if latest_prediction is None or prediction.timestamp > latest_prediction.timestamp: + latest_prediction = prediction + latest_symbol = symbol + + # Format timing information + last_inference_str = "None" + last_training_str = "None" + + if self.performance_metrics['last_inference_time']: + last_inference_str = self.performance_metrics['last_inference_time'].strftime("%H:%M:%S") + + if self.performance_metrics['last_training_time']: + last_training_str = self.performance_metrics['last_training_time'].strftime("%H:%M:%S") + + return { + 'model_name': 'CNN', + 'model_type': 'cnn', + 'parameters': '50.0M', + 'status': self.performance_metrics['model_status'], + 'current_loss': current_loss, + 'accuracy': current_accuracy, + 'confidence': avg_confidence, + 'total_predictions': self.performance_metrics['total_predictions'], + 'total_training_samples': self.performance_metrics['total_training_samples'], + 'predictions_per_second': self.performance_metrics['predictions_per_second'], + 'training_per_second': self.performance_metrics['training_per_second'], + 'last_inference': last_inference_str, + 'last_training': last_training_str, + 'latest_prediction': { + 'action': latest_prediction.predictions['action'] if latest_prediction else 'HOLD', + 'confidence': latest_prediction.confidence if latest_prediction else 0.0, + 'symbol': latest_symbol or 'ETH/USDT', + 'timestamp': latest_prediction.timestamp.strftime("%H:%M:%S") if latest_prediction else "None" + }, + 'action_distribution': self.performance_metrics['action_distribution'].copy(), + 'training_enabled': self.training_enabled, + 'inference_enabled': self.inference_enabled + } + + except Exception as e: + logger.error(f"Error getting dashboard metrics: {e}") + return { + 'model_name': 'CNN', + 'model_type': 'cnn', + 'parameters': '50.0M', + 'status': 'ERROR', + 'current_loss': 0.0, + 'accuracy': 0.0, + 'confidence': 0.0, + 'error': str(e) + } + + def get_predictions_for_chart(self, symbol: str, timeframe: str = '1s', limit: int = 100) -> List[Dict[str, Any]]: + """Get predictions for chart overlay""" + try: + if symbol not in self.prediction_history: + return [] + + predictions = list(self.prediction_history[symbol])[-limit:] + + chart_data = [] + for prediction in predictions: + chart_data.append({ + 'timestamp': prediction.timestamp, + 'action': prediction.predictions['action'], + 'confidence': prediction.confidence, + 'buy_probability': prediction.predictions.get('buy_probability', 0.0), + 'sell_probability': prediction.predictions.get('sell_probability', 0.0), + 'hold_probability': prediction.predictions.get('hold_probability', 0.0) + }) + + return chart_data + + except Exception as e: + logger.error(f"Error getting predictions for chart: {e}") + return [] + + def set_training_enabled(self, enabled: bool): + """Enable or disable training""" + self.training_enabled = enabled + logger.info(f"CNN training {'enabled' if enabled else 'disabled'}") + + def set_inference_enabled(self, enabled: bool): + """Enable or disable inference""" + self.inference_enabled = enabled + logger.info(f"CNN inference {'enabled' if enabled else 'disabled'}") + + def get_model_info(self) -> Dict[str, Any]: + """Get model information for dashboard""" + return { + 'name': 'Enhanced CNN', + 'version': '1.0', + 'parameters': '50.0M', + 'input_shape': self.cnn_adapter.model.input_shape if self.cnn_adapter.model else 'Unknown', + 'device': str(self.cnn_adapter.device), + 'checkpoint_dir': self.cnn_adapter.checkpoint_dir, + 'training_samples': len(self.cnn_adapter.training_data), + 'max_training_samples': self.cnn_adapter.max_training_samples + } \ No newline at end of file diff --git a/core/enhanced_cnn_integration.py b/core/enhanced_cnn_integration.py new file mode 100644 index 0000000..78bef98 --- /dev/null +++ b/core/enhanced_cnn_integration.py @@ -0,0 +1,403 @@ +""" +Enhanced CNN Integration for Dashboard + +This module integrates the EnhancedCNNAdapter with the dashboard, providing real-time +training and inference capabilities. +""" + +import logging +import threading +import time +from datetime import datetime +from typing import Dict, List, Optional, Any, Union +import os + +from .enhanced_cnn_adapter import EnhancedCNNAdapter +from .standardized_data_provider import StandardizedDataProvider +from .data_models import BaseDataInput, ModelOutput, create_model_output + +logger = logging.getLogger(__name__) + +class EnhancedCNNIntegration: + """ + Integration of EnhancedCNNAdapter with the dashboard + + This class: + 1. Manages the EnhancedCNNAdapter lifecycle + 2. Provides real-time training and inference + 3. Collects and reports performance metrics + 4. Integrates with the dashboard's model visualization + """ + + def __init__(self, data_provider: StandardizedDataProvider, checkpoint_dir: str = "models/enhanced_cnn"): + """ + Initialize the EnhancedCNNIntegration + + Args: + data_provider: StandardizedDataProvider instance + checkpoint_dir: Directory to store checkpoints + """ + self.data_provider = data_provider + self.checkpoint_dir = checkpoint_dir + self.model_name = "enhanced_cnn_v1" + + # Create checkpoint directory if it doesn't exist + os.makedirs(checkpoint_dir, exist_ok=True) + + # Initialize CNN adapter + self.cnn_adapter = EnhancedCNNAdapter(checkpoint_dir=checkpoint_dir) + + # Load best checkpoint if available + self.cnn_adapter.load_best_checkpoint() + + # Performance tracking + self.inference_times = [] + self.training_times = [] + self.total_inferences = 0 + self.total_training_runs = 0 + self.last_inference_time = None + self.last_training_time = None + self.inference_rate = 0.0 + self.training_rate = 0.0 + self.daily_inferences = 0 + self.daily_training_runs = 0 + + # Training settings + self.training_enabled = True + self.inference_enabled = True + self.training_frequency = 10 # Train every N inferences + self.training_batch_size = 32 + self.training_epochs = 1 + + # Latest prediction + self.latest_prediction = None + self.latest_prediction_time = None + + # Training metrics + self.current_loss = 0.0 + self.initial_loss = None + self.best_loss = None + self.current_accuracy = 0.0 + self.improvement_percentage = 0.0 + + # Training thread + self.training_thread = None + self.training_active = False + self.stop_training = False + + logger.info(f"EnhancedCNNIntegration initialized with model: {self.model_name}") + + def start_continuous_training(self): + """Start continuous training in a background thread""" + if self.training_thread is not None and self.training_thread.is_alive(): + logger.info("Continuous training already running") + return + + self.stop_training = False + self.training_thread = threading.Thread(target=self._continuous_training_loop, daemon=True) + self.training_thread.start() + logger.info("Started continuous training thread") + + def stop_continuous_training(self): + """Stop continuous training""" + self.stop_training = True + logger.info("Stopping continuous training thread") + + def _continuous_training_loop(self): + """Continuous training loop""" + try: + self.training_active = True + logger.info("Starting continuous training loop") + + while not self.stop_training: + # Check if training is enabled + if not self.training_enabled: + time.sleep(5) + continue + + # Check if we have enough training samples + if len(self.cnn_adapter.training_data) < self.training_batch_size: + logger.debug(f"Not enough training samples: {len(self.cnn_adapter.training_data)}/{self.training_batch_size}") + time.sleep(5) + continue + + # Train model + start_time = time.time() + metrics = self.cnn_adapter.train(epochs=self.training_epochs) + training_time = time.time() - start_time + + # Update metrics + self.training_times.append(training_time) + if len(self.training_times) > 100: + self.training_times.pop(0) + + self.total_training_runs += 1 + self.daily_training_runs += 1 + self.last_training_time = datetime.now() + + # Calculate training rate + if self.training_times: + avg_training_time = sum(self.training_times) / len(self.training_times) + self.training_rate = 1.0 / avg_training_time if avg_training_time > 0 else 0.0 + + # Update loss and accuracy + self.current_loss = metrics.get('loss', 0.0) + self.current_accuracy = metrics.get('accuracy', 0.0) + + # Update initial loss if not set + if self.initial_loss is None: + self.initial_loss = self.current_loss + + # Update best loss + if self.best_loss is None or self.current_loss < self.best_loss: + self.best_loss = self.current_loss + + # Calculate improvement percentage + if self.initial_loss is not None and self.initial_loss > 0: + self.improvement_percentage = ((self.initial_loss - self.current_loss) / self.initial_loss) * 100 + + logger.info(f"Training completed: loss={self.current_loss:.4f}, accuracy={self.current_accuracy:.4f}, samples={metrics.get('samples', 0)}") + + # Sleep before next training + time.sleep(10) + + except Exception as e: + logger.error(f"Error in continuous training loop: {e}") + finally: + self.training_active = False + + def predict(self, symbol: str) -> Optional[ModelOutput]: + """ + Make a prediction using the EnhancedCNN model + + Args: + symbol: Trading symbol + + Returns: + ModelOutput: Standardized model output + """ + try: + # Check if inference is enabled + if not self.inference_enabled: + return None + + # Get standardized input data + base_data = self.data_provider.get_base_data_input(symbol) + + if base_data is None: + logger.warning(f"Failed to get base data input for {symbol}") + return None + + # Make prediction + start_time = time.time() + model_output = self.cnn_adapter.predict(base_data) + inference_time = time.time() - start_time + + # Update metrics + self.inference_times.append(inference_time) + if len(self.inference_times) > 100: + self.inference_times.pop(0) + + self.total_inferences += 1 + self.daily_inferences += 1 + self.last_inference_time = datetime.now() + + # Calculate inference rate + if self.inference_times: + avg_inference_time = sum(self.inference_times) / len(self.inference_times) + self.inference_rate = 1.0 / avg_inference_time if avg_inference_time > 0 else 0.0 + + # Store latest prediction + self.latest_prediction = model_output + self.latest_prediction_time = datetime.now() + + # Store model output in data provider + self.data_provider.store_model_output(model_output) + + # Add training sample if we have a price + current_price = self._get_current_price(symbol) + if current_price and current_price > 0: + # Simulate market feedback based on price movement + # In a real system, this would be replaced with actual market performance data + action = model_output.predictions['action'] + + # For demonstration, we'll use a simple heuristic: + # - If price is above 3000, BUY is good + # - If price is below 3000, SELL is good + # - Otherwise, HOLD is good + if current_price > 3000: + best_action = 'BUY' + elif current_price < 3000: + best_action = 'SELL' + else: + best_action = 'HOLD' + + # Calculate reward based on whether the action matched the best action + if action == best_action: + reward = 0.05 # Positive reward for correct action + else: + reward = -0.05 # Negative reward for incorrect action + + # Add training sample + self.cnn_adapter.add_training_sample(base_data, best_action, reward) + + logger.debug(f"Added training sample for {symbol}, action: {action}, best_action: {best_action}, reward: {reward:.4f}") + + return model_output + + except Exception as e: + logger.error(f"Error making prediction: {e}") + return None + + def _get_current_price(self, symbol: str) -> Optional[float]: + """Get current price for a symbol""" + try: + # Try to get price from data provider + if hasattr(self.data_provider, 'current_prices'): + binance_symbol = symbol.replace('/', '').upper() + if binance_symbol in self.data_provider.current_prices: + return self.data_provider.current_prices[binance_symbol] + + # Try to get price from latest OHLCV data + df = self.data_provider.get_historical_data(symbol, '1s', 1) + if df is not None and not df.empty: + return float(df.iloc[-1]['close']) + + return None + + except Exception as e: + logger.error(f"Error getting current price: {e}") + return None + + def get_model_state(self) -> Dict[str, Any]: + """ + Get model state for dashboard display + + Returns: + Dict[str, Any]: Model state + """ + try: + # Format prediction for display + prediction_info = "FRESH" + confidence = 0.0 + + if self.latest_prediction: + action = self.latest_prediction.predictions.get('action', 'UNKNOWN') + confidence = self.latest_prediction.confidence + + # Map action to display text + if action == 'BUY': + prediction_info = "BUY_SIGNAL" + elif action == 'SELL': + prediction_info = "SELL_SIGNAL" + elif action == 'HOLD': + prediction_info = "HOLD_SIGNAL" + else: + prediction_info = "PATTERN_ANALYSIS" + + # Format timing information + inference_timing = "None" + training_timing = "None" + + if self.last_inference_time: + inference_timing = self.last_inference_time.strftime('%H:%M:%S') + + if self.last_training_time: + training_timing = self.last_training_time.strftime('%H:%M:%S') + + # Calculate improvement percentage + improvement = 0.0 + if self.initial_loss is not None and self.initial_loss > 0 and self.current_loss > 0: + improvement = ((self.initial_loss - self.current_loss) / self.initial_loss) * 100 + + return { + 'model_name': self.model_name, + 'model_type': 'cnn', + 'parameters': 50000000, # 50M parameters + 'status': 'ACTIVE' if self.inference_enabled else 'DISABLED', + 'checkpoint_loaded': True, # Assume checkpoint is loaded + 'last_prediction': prediction_info, + 'confidence': confidence * 100, # Convert to percentage + 'last_inference_time': inference_timing, + 'last_training_time': training_timing, + 'inference_rate': self.inference_rate, + 'training_rate': self.training_rate, + 'daily_inferences': self.daily_inferences, + 'daily_training_runs': self.daily_training_runs, + 'initial_loss': self.initial_loss, + 'current_loss': self.current_loss, + 'best_loss': self.best_loss, + 'current_accuracy': self.current_accuracy, + 'improvement_percentage': improvement, + 'training_active': self.training_active, + 'training_enabled': self.training_enabled, + 'inference_enabled': self.inference_enabled, + 'training_samples': len(self.cnn_adapter.training_data) + } + + except Exception as e: + logger.error(f"Error getting model state: {e}") + return { + 'model_name': self.model_name, + 'model_type': 'cnn', + 'parameters': 50000000, # 50M parameters + 'status': 'ERROR', + 'error': str(e) + } + + def get_pivot_prediction(self) -> Dict[str, Any]: + """ + Get pivot prediction for dashboard display + + Returns: + Dict[str, Any]: Pivot prediction + """ + try: + if not self.latest_prediction: + return { + 'next_pivot': 0.0, + 'pivot_type': 'UNKNOWN', + 'confidence': 0.0, + 'time_to_pivot': 0 + } + + # Extract pivot prediction from model output + extrema_pred = self.latest_prediction.predictions.get('extrema', [0, 0, 0]) + + # Determine pivot type (0=bottom, 1=top, 2=neither) + pivot_type_idx = extrema_pred.index(max(extrema_pred)) + pivot_types = ['BOTTOM', 'TOP', 'RANGE_CONTINUATION'] + pivot_type = pivot_types[pivot_type_idx] + + # Get current price + current_price = self._get_current_price('ETH/USDT') or 0.0 + + # Calculate next pivot price (simple heuristic for demonstration) + if pivot_type == 'BOTTOM': + next_pivot = current_price * 0.95 # 5% below current price + elif pivot_type == 'TOP': + next_pivot = current_price * 1.05 # 5% above current price + else: + next_pivot = current_price # Same as current price + + # Calculate confidence + confidence = max(extrema_pred) * 100 # Convert to percentage + + # Calculate time to pivot (simple heuristic for demonstration) + time_to_pivot = 5 # 5 minutes + + return { + 'next_pivot': next_pivot, + 'pivot_type': pivot_type, + 'confidence': confidence, + 'time_to_pivot': time_to_pivot + } + + except Exception as e: + logger.error(f"Error getting pivot prediction: {e}") + return { + 'next_pivot': 0.0, + 'pivot_type': 'ERROR', + 'confidence': 0.0, + 'time_to_pivot': 0 + } \ No newline at end of file diff --git a/test_cob_data_stability.py b/test_cob_data_stability.py new file mode 100644 index 0000000..aec4dfa --- /dev/null +++ b/test_cob_data_stability.py @@ -0,0 +1,123 @@ +import asyncio +import logging +import time +from collections import deque +from datetime import datetime, timedelta + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from matplotlib.colors import LogNorm + +from core.data_provider import DataProvider, MarketTick + +# Configure logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + + +class COBStabilityTester: + def __init__(self, symbol='ETH/USDT', duration_seconds=15): + self.symbol = symbol + self.duration = timedelta(seconds=duration_seconds) + self.ticks = deque() + self.data_provider = DataProvider(symbols=[self.symbol], timeframes=['1s']) + self.start_time = None + self.subscriber_id = None + + def _tick_callback(self, tick: MarketTick): + """Callback function to receive ticks from the DataProvider.""" + if self.start_time is None: + self.start_time = datetime.now() + logger.info(f"Started collecting ticks at {self.start_time}") + + # Store all ticks + self.ticks.append(tick) + + async def run_test(self): + """Run the data collection and plotting test.""" + logger.info(f"Starting COB stability test for {self.symbol} for {self.duration.total_seconds()} seconds...") + + # Subscribe to ticks + self.subscriber_id = self.data_provider.subscribe_to_ticks(self._tick_callback, symbols=[self.symbol]) + + # Start the data provider's real-time streaming + await self.data_provider.start_real_time_streaming() + + # Collect data for the specified duration + self.start_time = datetime.now() + while datetime.now() - self.start_time < self.duration: + await asyncio.sleep(1) + logger.info(f"Collected {len(self.ticks)} ticks so far...") + + # Stop streaming and unsubscribe + await self.data_provider.stop_real_time_streaming() + self.data_provider.unsubscribe_from_ticks(self.subscriber_id) + + logger.info(f"Finished collecting data. Total ticks: {len(self.ticks)}") + + # Plot the results + if self.ticks: + self.plot_spectrogram() + else: + logger.warning("No ticks were collected. Cannot generate plot.") + + def plot_spectrogram(self): + """Create a spectrogram-like plot of trade intensity.""" + if not self.ticks: + logger.warning("No ticks to plot.") + return + + df = pd.DataFrame([{ + 'timestamp': tick.timestamp, + 'price': tick.price, + 'volume': tick.volume, + 'side': 1 if tick.side == 'buy' else -1 + } for tick in self.ticks]) + + df['timestamp'] = pd.to_datetime(df['timestamp']) + df = df.set_index('timestamp') + + # Create the plot + fig, ax = plt.subplots(figsize=(15, 8)) + + # Define bins for the 2D histogram + time_bins = pd.date_range(df.index.min(), df.index.max(), periods=100) + price_bins = np.linspace(df['price'].min(), df['price'].max(), 100) + + # Create the 2D histogram + # x-axis: time, y-axis: price, weights: volume + h, xedges, yedges = np.histogram2d( + df.index.astype(np.int64) // 10**9, + df['price'], + bins=[time_bins.astype(np.int64) // 10**9, price_bins], + weights=df['volume'] + ) + + # Use a logarithmic color scale for better visibility of smaller trades + pcm = ax.pcolormesh(time_bins, price_bins, h.T, norm=LogNorm(vmin=1e-3, vmax=h.max()), cmap='inferno') + + fig.colorbar(pcm, ax=ax, label='Trade Volume (USDT)') + ax.set_title(f'Trade Intensity Spectrogram for {self.symbol}') + ax.set_xlabel('Time') + ax.set_ylabel('Price (USDT)') + + # Format the x-axis to show time properly + fig.autofmt_xdate() + + plot_filename = f"cob_stability_spectrogram_{self.symbol.replace('/', '_')}_{datetime.now():%Y%m%d_%H%M%S}.png" + plt.savefig(plot_filename) + logger.info(f"Plot saved to {plot_filename}") + plt.show() + + +async def main(): + tester = COBStabilityTester() + await tester.run_test() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + logger.info("Test interrupted by user.") \ No newline at end of file