integrate CNN, fix COB data

This commit is contained in:
Dobromir Popov
2025-07-23 22:12:10 +03:00
parent 45a62443a0
commit 26e6ba2e1d
3 changed files with 891 additions and 0 deletions

View File

@ -0,0 +1,365 @@
"""
Dashboard CNN Integration
This module integrates the EnhancedCNNAdapter with the dashboard system,
providing real-time training, predictions, and performance metrics display.
"""
import logging
import time
import threading
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Tuple
from collections import deque
import numpy as np
from .enhanced_cnn_adapter import EnhancedCNNAdapter
from .standardized_data_provider import StandardizedDataProvider
from .data_models import BaseDataInput, ModelOutput, create_model_output
logger = logging.getLogger(__name__)
class DashboardCNNIntegration:
"""
CNN integration for the dashboard system
This class:
1. Manages CNN model lifecycle in the dashboard
2. Provides real-time training and inference
3. Tracks performance metrics for dashboard display
4. Handles model predictions for chart overlay
"""
def __init__(self, data_provider: StandardizedDataProvider, symbols: List[str] = None):
"""
Initialize the dashboard CNN integration
Args:
data_provider: Standardized data provider
symbols: List of symbols to process
"""
self.data_provider = data_provider
self.symbols = symbols or ['ETH/USDT', 'BTC/USDT']
# Initialize CNN adapter
self.cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
# Load best checkpoint if available
self.cnn_adapter.load_best_checkpoint()
# Performance tracking
self.performance_metrics = {
'total_predictions': 0,
'total_training_samples': 0,
'last_training_time': None,
'last_inference_time': None,
'training_loss_history': deque(maxlen=100),
'accuracy_history': deque(maxlen=100),
'inference_times': deque(maxlen=100),
'training_times': deque(maxlen=100),
'predictions_per_second': 0.0,
'training_per_second': 0.0,
'model_status': 'FRESH',
'confidence_history': deque(maxlen=100),
'action_distribution': {'BUY': 0, 'SELL': 0, 'HOLD': 0}
}
# Prediction cache for dashboard display
self.prediction_cache = {}
self.prediction_history = {symbol: deque(maxlen=1000) for symbol in self.symbols}
# Training control
self.training_enabled = True
self.inference_enabled = True
self.training_lock = threading.Lock()
# Real-time processing
self.is_running = False
self.processing_thread = None
logger.info(f"DashboardCNNIntegration initialized for symbols: {self.symbols}")
def start_real_time_processing(self):
"""Start real-time CNN processing"""
if self.is_running:
logger.warning("Real-time processing already running")
return
self.is_running = True
self.processing_thread = threading.Thread(target=self._real_time_processing_loop, daemon=True)
self.processing_thread.start()
logger.info("Started real-time CNN processing")
def stop_real_time_processing(self):
"""Stop real-time CNN processing"""
self.is_running = False
if self.processing_thread:
self.processing_thread.join(timeout=5)
logger.info("Stopped real-time CNN processing")
def _real_time_processing_loop(self):
"""Main real-time processing loop"""
last_prediction_time = {}
prediction_interval = 1.0 # Make prediction every 1 second
while self.is_running:
try:
current_time = time.time()
for symbol in self.symbols:
# Check if it's time to make a prediction for this symbol
if (symbol not in last_prediction_time or
current_time - last_prediction_time[symbol] >= prediction_interval):
# Make prediction if inference is enabled
if self.inference_enabled:
self._make_prediction(symbol)
last_prediction_time[symbol] = current_time
# Update performance metrics
self._update_performance_metrics()
# Sleep briefly to prevent overwhelming the system
time.sleep(0.1)
except Exception as e:
logger.error(f"Error in real-time processing loop: {e}")
time.sleep(1)
def _make_prediction(self, symbol: str):
"""Make a prediction for a symbol"""
try:
start_time = time.time()
# Get standardized input data
base_data = self.data_provider.get_base_data_input(symbol)
if base_data is None:
logger.debug(f"No base data available for {symbol}")
return
# Make prediction
model_output = self.cnn_adapter.predict(base_data)
# Record inference time
inference_time = time.time() - start_time
self.performance_metrics['inference_times'].append(inference_time)
# Update performance metrics
self.performance_metrics['total_predictions'] += 1
self.performance_metrics['last_inference_time'] = datetime.now()
self.performance_metrics['confidence_history'].append(model_output.confidence)
# Update action distribution
action = model_output.predictions['action']
self.performance_metrics['action_distribution'][action] += 1
# Cache prediction for dashboard
self.prediction_cache[symbol] = model_output
self.prediction_history[symbol].append(model_output)
# Store model output in data provider
self.data_provider.store_model_output(model_output)
logger.debug(f"CNN prediction for {symbol}: {action} ({model_output.confidence:.3f})")
except Exception as e:
logger.error(f"Error making prediction for {symbol}: {e}")
def add_training_sample(self, symbol: str, actual_action: str, reward: float):
"""Add a training sample and trigger training if enabled"""
try:
if not self.training_enabled:
return
# Get base data for the symbol
base_data = self.data_provider.get_base_data_input(symbol)
if base_data is None:
logger.debug(f"No base data available for training sample: {symbol}")
return
# Add training sample
self.cnn_adapter.add_training_sample(base_data, actual_action, reward)
# Update metrics
self.performance_metrics['total_training_samples'] += 1
# Train model periodically (every 10 samples)
if self.performance_metrics['total_training_samples'] % 10 == 0:
self._train_model()
except Exception as e:
logger.error(f"Error adding training sample: {e}")
def _train_model(self):
"""Train the CNN model"""
try:
with self.training_lock:
start_time = time.time()
# Train model
metrics = self.cnn_adapter.train(epochs=1)
# Record training time
training_time = time.time() - start_time
self.performance_metrics['training_times'].append(training_time)
# Update performance metrics
self.performance_metrics['last_training_time'] = datetime.now()
if 'loss' in metrics:
self.performance_metrics['training_loss_history'].append(metrics['loss'])
if 'accuracy' in metrics:
self.performance_metrics['accuracy_history'].append(metrics['accuracy'])
# Update model status
if metrics.get('accuracy', 0) > 0.5:
self.performance_metrics['model_status'] = 'TRAINED'
else:
self.performance_metrics['model_status'] = 'TRAINING'
logger.info(f"CNN training completed: loss={metrics.get('loss', 0):.4f}, accuracy={metrics.get('accuracy', 0):.4f}")
except Exception as e:
logger.error(f"Error training CNN model: {e}")
def _update_performance_metrics(self):
"""Update performance metrics for dashboard display"""
try:
current_time = time.time()
# Calculate predictions per second (last 60 seconds)
recent_inferences = [t for t in self.performance_metrics['inference_times']
if current_time - t <= 60]
self.performance_metrics['predictions_per_second'] = len(recent_inferences) / 60.0
# Calculate training per second (last 60 seconds)
recent_trainings = [t for t in self.performance_metrics['training_times']
if current_time - t <= 60]
self.performance_metrics['training_per_second'] = len(recent_trainings) / 60.0
except Exception as e:
logger.error(f"Error updating performance metrics: {e}")
def get_dashboard_metrics(self) -> Dict[str, Any]:
"""Get metrics for dashboard display"""
try:
# Calculate current loss
current_loss = (self.performance_metrics['training_loss_history'][-1]
if self.performance_metrics['training_loss_history'] else 0.0)
# Calculate current accuracy
current_accuracy = (self.performance_metrics['accuracy_history'][-1]
if self.performance_metrics['accuracy_history'] else 0.0)
# Calculate average confidence
avg_confidence = (np.mean(list(self.performance_metrics['confidence_history']))
if self.performance_metrics['confidence_history'] else 0.0)
# Get latest prediction
latest_prediction = None
latest_symbol = None
for symbol, prediction in self.prediction_cache.items():
if latest_prediction is None or prediction.timestamp > latest_prediction.timestamp:
latest_prediction = prediction
latest_symbol = symbol
# Format timing information
last_inference_str = "None"
last_training_str = "None"
if self.performance_metrics['last_inference_time']:
last_inference_str = self.performance_metrics['last_inference_time'].strftime("%H:%M:%S")
if self.performance_metrics['last_training_time']:
last_training_str = self.performance_metrics['last_training_time'].strftime("%H:%M:%S")
return {
'model_name': 'CNN',
'model_type': 'cnn',
'parameters': '50.0M',
'status': self.performance_metrics['model_status'],
'current_loss': current_loss,
'accuracy': current_accuracy,
'confidence': avg_confidence,
'total_predictions': self.performance_metrics['total_predictions'],
'total_training_samples': self.performance_metrics['total_training_samples'],
'predictions_per_second': self.performance_metrics['predictions_per_second'],
'training_per_second': self.performance_metrics['training_per_second'],
'last_inference': last_inference_str,
'last_training': last_training_str,
'latest_prediction': {
'action': latest_prediction.predictions['action'] if latest_prediction else 'HOLD',
'confidence': latest_prediction.confidence if latest_prediction else 0.0,
'symbol': latest_symbol or 'ETH/USDT',
'timestamp': latest_prediction.timestamp.strftime("%H:%M:%S") if latest_prediction else "None"
},
'action_distribution': self.performance_metrics['action_distribution'].copy(),
'training_enabled': self.training_enabled,
'inference_enabled': self.inference_enabled
}
except Exception as e:
logger.error(f"Error getting dashboard metrics: {e}")
return {
'model_name': 'CNN',
'model_type': 'cnn',
'parameters': '50.0M',
'status': 'ERROR',
'current_loss': 0.0,
'accuracy': 0.0,
'confidence': 0.0,
'error': str(e)
}
def get_predictions_for_chart(self, symbol: str, timeframe: str = '1s', limit: int = 100) -> List[Dict[str, Any]]:
"""Get predictions for chart overlay"""
try:
if symbol not in self.prediction_history:
return []
predictions = list(self.prediction_history[symbol])[-limit:]
chart_data = []
for prediction in predictions:
chart_data.append({
'timestamp': prediction.timestamp,
'action': prediction.predictions['action'],
'confidence': prediction.confidence,
'buy_probability': prediction.predictions.get('buy_probability', 0.0),
'sell_probability': prediction.predictions.get('sell_probability', 0.0),
'hold_probability': prediction.predictions.get('hold_probability', 0.0)
})
return chart_data
except Exception as e:
logger.error(f"Error getting predictions for chart: {e}")
return []
def set_training_enabled(self, enabled: bool):
"""Enable or disable training"""
self.training_enabled = enabled
logger.info(f"CNN training {'enabled' if enabled else 'disabled'}")
def set_inference_enabled(self, enabled: bool):
"""Enable or disable inference"""
self.inference_enabled = enabled
logger.info(f"CNN inference {'enabled' if enabled else 'disabled'}")
def get_model_info(self) -> Dict[str, Any]:
"""Get model information for dashboard"""
return {
'name': 'Enhanced CNN',
'version': '1.0',
'parameters': '50.0M',
'input_shape': self.cnn_adapter.model.input_shape if self.cnn_adapter.model else 'Unknown',
'device': str(self.cnn_adapter.device),
'checkpoint_dir': self.cnn_adapter.checkpoint_dir,
'training_samples': len(self.cnn_adapter.training_data),
'max_training_samples': self.cnn_adapter.max_training_samples
}

View File

@ -0,0 +1,403 @@
"""
Enhanced CNN Integration for Dashboard
This module integrates the EnhancedCNNAdapter with the dashboard, providing real-time
training and inference capabilities.
"""
import logging
import threading
import time
from datetime import datetime
from typing import Dict, List, Optional, Any, Union
import os
from .enhanced_cnn_adapter import EnhancedCNNAdapter
from .standardized_data_provider import StandardizedDataProvider
from .data_models import BaseDataInput, ModelOutput, create_model_output
logger = logging.getLogger(__name__)
class EnhancedCNNIntegration:
"""
Integration of EnhancedCNNAdapter with the dashboard
This class:
1. Manages the EnhancedCNNAdapter lifecycle
2. Provides real-time training and inference
3. Collects and reports performance metrics
4. Integrates with the dashboard's model visualization
"""
def __init__(self, data_provider: StandardizedDataProvider, checkpoint_dir: str = "models/enhanced_cnn"):
"""
Initialize the EnhancedCNNIntegration
Args:
data_provider: StandardizedDataProvider instance
checkpoint_dir: Directory to store checkpoints
"""
self.data_provider = data_provider
self.checkpoint_dir = checkpoint_dir
self.model_name = "enhanced_cnn_v1"
# Create checkpoint directory if it doesn't exist
os.makedirs(checkpoint_dir, exist_ok=True)
# Initialize CNN adapter
self.cnn_adapter = EnhancedCNNAdapter(checkpoint_dir=checkpoint_dir)
# Load best checkpoint if available
self.cnn_adapter.load_best_checkpoint()
# Performance tracking
self.inference_times = []
self.training_times = []
self.total_inferences = 0
self.total_training_runs = 0
self.last_inference_time = None
self.last_training_time = None
self.inference_rate = 0.0
self.training_rate = 0.0
self.daily_inferences = 0
self.daily_training_runs = 0
# Training settings
self.training_enabled = True
self.inference_enabled = True
self.training_frequency = 10 # Train every N inferences
self.training_batch_size = 32
self.training_epochs = 1
# Latest prediction
self.latest_prediction = None
self.latest_prediction_time = None
# Training metrics
self.current_loss = 0.0
self.initial_loss = None
self.best_loss = None
self.current_accuracy = 0.0
self.improvement_percentage = 0.0
# Training thread
self.training_thread = None
self.training_active = False
self.stop_training = False
logger.info(f"EnhancedCNNIntegration initialized with model: {self.model_name}")
def start_continuous_training(self):
"""Start continuous training in a background thread"""
if self.training_thread is not None and self.training_thread.is_alive():
logger.info("Continuous training already running")
return
self.stop_training = False
self.training_thread = threading.Thread(target=self._continuous_training_loop, daemon=True)
self.training_thread.start()
logger.info("Started continuous training thread")
def stop_continuous_training(self):
"""Stop continuous training"""
self.stop_training = True
logger.info("Stopping continuous training thread")
def _continuous_training_loop(self):
"""Continuous training loop"""
try:
self.training_active = True
logger.info("Starting continuous training loop")
while not self.stop_training:
# Check if training is enabled
if not self.training_enabled:
time.sleep(5)
continue
# Check if we have enough training samples
if len(self.cnn_adapter.training_data) < self.training_batch_size:
logger.debug(f"Not enough training samples: {len(self.cnn_adapter.training_data)}/{self.training_batch_size}")
time.sleep(5)
continue
# Train model
start_time = time.time()
metrics = self.cnn_adapter.train(epochs=self.training_epochs)
training_time = time.time() - start_time
# Update metrics
self.training_times.append(training_time)
if len(self.training_times) > 100:
self.training_times.pop(0)
self.total_training_runs += 1
self.daily_training_runs += 1
self.last_training_time = datetime.now()
# Calculate training rate
if self.training_times:
avg_training_time = sum(self.training_times) / len(self.training_times)
self.training_rate = 1.0 / avg_training_time if avg_training_time > 0 else 0.0
# Update loss and accuracy
self.current_loss = metrics.get('loss', 0.0)
self.current_accuracy = metrics.get('accuracy', 0.0)
# Update initial loss if not set
if self.initial_loss is None:
self.initial_loss = self.current_loss
# Update best loss
if self.best_loss is None or self.current_loss < self.best_loss:
self.best_loss = self.current_loss
# Calculate improvement percentage
if self.initial_loss is not None and self.initial_loss > 0:
self.improvement_percentage = ((self.initial_loss - self.current_loss) / self.initial_loss) * 100
logger.info(f"Training completed: loss={self.current_loss:.4f}, accuracy={self.current_accuracy:.4f}, samples={metrics.get('samples', 0)}")
# Sleep before next training
time.sleep(10)
except Exception as e:
logger.error(f"Error in continuous training loop: {e}")
finally:
self.training_active = False
def predict(self, symbol: str) -> Optional[ModelOutput]:
"""
Make a prediction using the EnhancedCNN model
Args:
symbol: Trading symbol
Returns:
ModelOutput: Standardized model output
"""
try:
# Check if inference is enabled
if not self.inference_enabled:
return None
# Get standardized input data
base_data = self.data_provider.get_base_data_input(symbol)
if base_data is None:
logger.warning(f"Failed to get base data input for {symbol}")
return None
# Make prediction
start_time = time.time()
model_output = self.cnn_adapter.predict(base_data)
inference_time = time.time() - start_time
# Update metrics
self.inference_times.append(inference_time)
if len(self.inference_times) > 100:
self.inference_times.pop(0)
self.total_inferences += 1
self.daily_inferences += 1
self.last_inference_time = datetime.now()
# Calculate inference rate
if self.inference_times:
avg_inference_time = sum(self.inference_times) / len(self.inference_times)
self.inference_rate = 1.0 / avg_inference_time if avg_inference_time > 0 else 0.0
# Store latest prediction
self.latest_prediction = model_output
self.latest_prediction_time = datetime.now()
# Store model output in data provider
self.data_provider.store_model_output(model_output)
# Add training sample if we have a price
current_price = self._get_current_price(symbol)
if current_price and current_price > 0:
# Simulate market feedback based on price movement
# In a real system, this would be replaced with actual market performance data
action = model_output.predictions['action']
# For demonstration, we'll use a simple heuristic:
# - If price is above 3000, BUY is good
# - If price is below 3000, SELL is good
# - Otherwise, HOLD is good
if current_price > 3000:
best_action = 'BUY'
elif current_price < 3000:
best_action = 'SELL'
else:
best_action = 'HOLD'
# Calculate reward based on whether the action matched the best action
if action == best_action:
reward = 0.05 # Positive reward for correct action
else:
reward = -0.05 # Negative reward for incorrect action
# Add training sample
self.cnn_adapter.add_training_sample(base_data, best_action, reward)
logger.debug(f"Added training sample for {symbol}, action: {action}, best_action: {best_action}, reward: {reward:.4f}")
return model_output
except Exception as e:
logger.error(f"Error making prediction: {e}")
return None
def _get_current_price(self, symbol: str) -> Optional[float]:
"""Get current price for a symbol"""
try:
# Try to get price from data provider
if hasattr(self.data_provider, 'current_prices'):
binance_symbol = symbol.replace('/', '').upper()
if binance_symbol in self.data_provider.current_prices:
return self.data_provider.current_prices[binance_symbol]
# Try to get price from latest OHLCV data
df = self.data_provider.get_historical_data(symbol, '1s', 1)
if df is not None and not df.empty:
return float(df.iloc[-1]['close'])
return None
except Exception as e:
logger.error(f"Error getting current price: {e}")
return None
def get_model_state(self) -> Dict[str, Any]:
"""
Get model state for dashboard display
Returns:
Dict[str, Any]: Model state
"""
try:
# Format prediction for display
prediction_info = "FRESH"
confidence = 0.0
if self.latest_prediction:
action = self.latest_prediction.predictions.get('action', 'UNKNOWN')
confidence = self.latest_prediction.confidence
# Map action to display text
if action == 'BUY':
prediction_info = "BUY_SIGNAL"
elif action == 'SELL':
prediction_info = "SELL_SIGNAL"
elif action == 'HOLD':
prediction_info = "HOLD_SIGNAL"
else:
prediction_info = "PATTERN_ANALYSIS"
# Format timing information
inference_timing = "None"
training_timing = "None"
if self.last_inference_time:
inference_timing = self.last_inference_time.strftime('%H:%M:%S')
if self.last_training_time:
training_timing = self.last_training_time.strftime('%H:%M:%S')
# Calculate improvement percentage
improvement = 0.0
if self.initial_loss is not None and self.initial_loss > 0 and self.current_loss > 0:
improvement = ((self.initial_loss - self.current_loss) / self.initial_loss) * 100
return {
'model_name': self.model_name,
'model_type': 'cnn',
'parameters': 50000000, # 50M parameters
'status': 'ACTIVE' if self.inference_enabled else 'DISABLED',
'checkpoint_loaded': True, # Assume checkpoint is loaded
'last_prediction': prediction_info,
'confidence': confidence * 100, # Convert to percentage
'last_inference_time': inference_timing,
'last_training_time': training_timing,
'inference_rate': self.inference_rate,
'training_rate': self.training_rate,
'daily_inferences': self.daily_inferences,
'daily_training_runs': self.daily_training_runs,
'initial_loss': self.initial_loss,
'current_loss': self.current_loss,
'best_loss': self.best_loss,
'current_accuracy': self.current_accuracy,
'improvement_percentage': improvement,
'training_active': self.training_active,
'training_enabled': self.training_enabled,
'inference_enabled': self.inference_enabled,
'training_samples': len(self.cnn_adapter.training_data)
}
except Exception as e:
logger.error(f"Error getting model state: {e}")
return {
'model_name': self.model_name,
'model_type': 'cnn',
'parameters': 50000000, # 50M parameters
'status': 'ERROR',
'error': str(e)
}
def get_pivot_prediction(self) -> Dict[str, Any]:
"""
Get pivot prediction for dashboard display
Returns:
Dict[str, Any]: Pivot prediction
"""
try:
if not self.latest_prediction:
return {
'next_pivot': 0.0,
'pivot_type': 'UNKNOWN',
'confidence': 0.0,
'time_to_pivot': 0
}
# Extract pivot prediction from model output
extrema_pred = self.latest_prediction.predictions.get('extrema', [0, 0, 0])
# Determine pivot type (0=bottom, 1=top, 2=neither)
pivot_type_idx = extrema_pred.index(max(extrema_pred))
pivot_types = ['BOTTOM', 'TOP', 'RANGE_CONTINUATION']
pivot_type = pivot_types[pivot_type_idx]
# Get current price
current_price = self._get_current_price('ETH/USDT') or 0.0
# Calculate next pivot price (simple heuristic for demonstration)
if pivot_type == 'BOTTOM':
next_pivot = current_price * 0.95 # 5% below current price
elif pivot_type == 'TOP':
next_pivot = current_price * 1.05 # 5% above current price
else:
next_pivot = current_price # Same as current price
# Calculate confidence
confidence = max(extrema_pred) * 100 # Convert to percentage
# Calculate time to pivot (simple heuristic for demonstration)
time_to_pivot = 5 # 5 minutes
return {
'next_pivot': next_pivot,
'pivot_type': pivot_type,
'confidence': confidence,
'time_to_pivot': time_to_pivot
}
except Exception as e:
logger.error(f"Error getting pivot prediction: {e}")
return {
'next_pivot': 0.0,
'pivot_type': 'ERROR',
'confidence': 0.0,
'time_to_pivot': 0
}

123
test_cob_data_stability.py Normal file
View File

@ -0,0 +1,123 @@
import asyncio
import logging
import time
from collections import deque
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.colors import LogNorm
from core.data_provider import DataProvider, MarketTick
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class COBStabilityTester:
def __init__(self, symbol='ETH/USDT', duration_seconds=15):
self.symbol = symbol
self.duration = timedelta(seconds=duration_seconds)
self.ticks = deque()
self.data_provider = DataProvider(symbols=[self.symbol], timeframes=['1s'])
self.start_time = None
self.subscriber_id = None
def _tick_callback(self, tick: MarketTick):
"""Callback function to receive ticks from the DataProvider."""
if self.start_time is None:
self.start_time = datetime.now()
logger.info(f"Started collecting ticks at {self.start_time}")
# Store all ticks
self.ticks.append(tick)
async def run_test(self):
"""Run the data collection and plotting test."""
logger.info(f"Starting COB stability test for {self.symbol} for {self.duration.total_seconds()} seconds...")
# Subscribe to ticks
self.subscriber_id = self.data_provider.subscribe_to_ticks(self._tick_callback, symbols=[self.symbol])
# Start the data provider's real-time streaming
await self.data_provider.start_real_time_streaming()
# Collect data for the specified duration
self.start_time = datetime.now()
while datetime.now() - self.start_time < self.duration:
await asyncio.sleep(1)
logger.info(f"Collected {len(self.ticks)} ticks so far...")
# Stop streaming and unsubscribe
await self.data_provider.stop_real_time_streaming()
self.data_provider.unsubscribe_from_ticks(self.subscriber_id)
logger.info(f"Finished collecting data. Total ticks: {len(self.ticks)}")
# Plot the results
if self.ticks:
self.plot_spectrogram()
else:
logger.warning("No ticks were collected. Cannot generate plot.")
def plot_spectrogram(self):
"""Create a spectrogram-like plot of trade intensity."""
if not self.ticks:
logger.warning("No ticks to plot.")
return
df = pd.DataFrame([{
'timestamp': tick.timestamp,
'price': tick.price,
'volume': tick.volume,
'side': 1 if tick.side == 'buy' else -1
} for tick in self.ticks])
df['timestamp'] = pd.to_datetime(df['timestamp'])
df = df.set_index('timestamp')
# Create the plot
fig, ax = plt.subplots(figsize=(15, 8))
# Define bins for the 2D histogram
time_bins = pd.date_range(df.index.min(), df.index.max(), periods=100)
price_bins = np.linspace(df['price'].min(), df['price'].max(), 100)
# Create the 2D histogram
# x-axis: time, y-axis: price, weights: volume
h, xedges, yedges = np.histogram2d(
df.index.astype(np.int64) // 10**9,
df['price'],
bins=[time_bins.astype(np.int64) // 10**9, price_bins],
weights=df['volume']
)
# Use a logarithmic color scale for better visibility of smaller trades
pcm = ax.pcolormesh(time_bins, price_bins, h.T, norm=LogNorm(vmin=1e-3, vmax=h.max()), cmap='inferno')
fig.colorbar(pcm, ax=ax, label='Trade Volume (USDT)')
ax.set_title(f'Trade Intensity Spectrogram for {self.symbol}')
ax.set_xlabel('Time')
ax.set_ylabel('Price (USDT)')
# Format the x-axis to show time properly
fig.autofmt_xdate()
plot_filename = f"cob_stability_spectrogram_{self.symbol.replace('/', '_')}_{datetime.now():%Y%m%d_%H%M%S}.png"
plt.savefig(plot_filename)
logger.info(f"Plot saved to {plot_filename}")
plt.show()
async def main():
tester = COBStabilityTester()
await tester.run_test()
if __name__ == "__main__":
try:
asyncio.run(main())
except KeyboardInterrupt:
logger.info("Test interrupted by user.")