import asyncio
import logging
import time
from collections import deque
from datetime import datetime, timedelta

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.colors import LogNorm

from core.data_provider import DataProvider, MarketTick
from core.config import get_config

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def _default_price_granularity(symbol):
    """Return the default price-bucket width (USD) for ``symbol``.

    ETH symbols use 1 USD buckets, BTC symbols 10 USD buckets, anything else
    1 USD. The ETH check runs first, so a symbol containing both gets 1 USD.
    """
    upper = symbol.upper()
    if 'ETH' in upper:
        return 1.0  # 1 USD for ETH
    if 'BTC' in upper:
        return 10.0  # 10 USD for BTC
    return 1.0  # Default 1 USD


class COBStabilityTester:
    """Collect order-book (COB) updates for one symbol from ``DataProvider`` and chart them.

    Subscribes to COB updates (and to ticks as a fallback) for a fixed duration,
    logs a bucketed top-of-book summary at most once per second, then saves and
    shows either a mid-price line over bid/ask heatmaps or, if no COB snapshots
    arrived, a plain price chart built from the collected ticks.
    """

    def __init__(self, symbol='ETHUSDT', duration_seconds=15, price_granularity=None):
        """
        Args:
            symbol: Symbol to test; '/' separators are ignored when matching COB updates.
            duration_seconds: How long to collect data, in seconds.
            price_granularity: Optional price-bucket width used for the 1s log and the
                heatmap. Defaults to 1.0 for ETH symbols, 10.0 for BTC symbols, 1.0 otherwise.

        Raises:
            ValueError: if an explicit ``price_granularity`` is not positive.
        """
        self.symbol = symbol
        self.duration = timedelta(seconds=duration_seconds)
        self.ticks = deque()

        # Set granularity (buckets) based on symbol unless the caller overrides it
        if price_granularity is None:
            self.price_granularity = _default_price_granularity(symbol)
        else:
            if price_granularity <= 0:
                raise ValueError(f"price_granularity must be positive, got {price_granularity}")
            self.price_granularity = float(price_granularity)

        logger.info(f"Using price granularity: ${self.price_granularity} for {symbol}")

        # Initialize DataProvider the same way as clean_dashboard
        logger.info("Initializing DataProvider like in clean_dashboard...")
        self.data_provider = DataProvider()  # Use default constructor like clean_dashboard

        # Initialize COB data collection like clean_dashboard does
        self.cob_data_received = 0
        self.latest_cob_data = {}

        # Store all COB snapshots for heatmap generation
        self.cob_snapshots = deque()
        self.price_data = []  # For price line chart
        self.start_time = None
        self.subscriber_id = None  # Stays None if the tick subscription fails
        self.last_log_time = None

    def _tick_callback(self, tick: MarketTick):
        """Callback function to receive ticks from the DataProvider."""
        if self.start_time is None:
            self.start_time = datetime.now()
            logger.info(f"Started collecting ticks at {self.start_time}")

        # Store all ticks
        self.ticks.append(tick)

    def _cob_data_callback(self, symbol: str, cob_data: dict):
        """Callback function to receive COB data from the DataProvider.

        Updates for other symbols are ignored. For the target symbol this stores
        an order-book snapshot (when 'bids' and 'asks' are present), logs a
        bucketed summary at most once per second, and records a positive
        ``stats['mid_price']`` both as a price point and as a synthetic MarketTick.
        """
        # Debug: Log first few callbacks to see what symbols we're getting
        if self.cob_data_received < 5:
            logger.info(f"DEBUG: Received COB data for symbol '{symbol}' (target: '{self.symbol}')")

        # Filter to only our requested symbol - handle both formats (ETH/USDT and ETHUSDT)
        normalized_symbol = symbol.replace('/', '')
        normalized_target = self.symbol.replace('/', '')
        if normalized_symbol != normalized_target:
            if self.cob_data_received < 5:
                logger.info(f"DEBUG: Skipping symbol '{symbol}' (normalized: '{normalized_symbol}' vs target: '{normalized_target}')")
            return

        self.cob_data_received += 1
        self.latest_cob_data[symbol] = cob_data

        # Store the complete COB snapshot for heatmap generation
        if 'bids' in cob_data and 'asks' in cob_data:
            # Debug: Log structure of first few COB snapshots
            if len(self.cob_snapshots) < 3:
                logger.info(f"DEBUG: COB data structure - bids: {len(cob_data['bids'])} items, asks: {len(cob_data['asks'])} items")
                if cob_data['bids']:
                    logger.info(f"DEBUG: First bid: {cob_data['bids'][0]}")
                if cob_data['asks']:
                    logger.info(f"DEBUG: First ask: {cob_data['asks'][0]}")

            snapshot = {
                'timestamp': cob_data.get('timestamp', datetime.now()),
                'bids': cob_data['bids'],
                'asks': cob_data['asks'],
                'stats': cob_data.get('stats', {})
            }
            self.cob_snapshots.append(snapshot)

        # Log bucketed COB data every second
        now = datetime.now()
        if self.last_log_time is None or (now - self.last_log_time).total_seconds() >= 1.0:
            self.last_log_time = now
            self._log_bucketed_cob_data(cob_data)

        # Convert COB data to tick-like format for analysis
        if 'stats' in cob_data and 'mid_price' in cob_data['stats']:
            mid_price = cob_data['stats']['mid_price']
            if mid_price > 0:
                # Store price data for line chart
                self.price_data.append({
                    'timestamp': cob_data.get('timestamp', datetime.now()),
                    'price': mid_price
                })

                # Create a synthetic tick from COB data
                synthetic_tick = MarketTick(
                    symbol=symbol,
                    timestamp=cob_data.get('timestamp', datetime.now()),
                    price=mid_price,
                    volume=cob_data.get('stats', {}).get('total_volume', 0),
                    quantity=0,  # Not available in COB data
                    side='unknown',  # COB data doesn't have side info
                    trade_id=f"cob_{self.cob_data_received}",
                    is_buyer_maker=False,
                    raw_data=cob_data
                )
                self.ticks.append(synthetic_tick)

                if self.cob_data_received % 10 == 0:  # Log every 10th update
                    logger.info(f"COB update #{self.cob_data_received}: {symbol} @ ${mid_price:.2f}")

    def _bucketed_level(self, level):
        """Parse one order-book level into ``(bucketed_price, size)``, or None if malformed.

        A level is either a dict with 'price'/'size' keys or a sequence
        ``[price, size, ...]``. The price is rounded to the nearest multiple of
        ``self.price_granularity`` using Python's ``round`` (exact halves go to
        the even multiple).
        """
        try:
            if isinstance(level, dict):
                price = float(level['price'])
                size = float(level['size'])
            elif isinstance(level, (list, tuple)) and len(level) >= 2:
                price = float(level[0])
                size = float(level[1])
            else:
                return None
            return round(price / self.price_granularity) * self.price_granularity, size
        except (KeyError, ValueError, TypeError, IndexError):
            # KeyError covers dict levels missing 'price'/'size'; NaN prices fail in round()
            return None

    def _bucket_levels(self, levels, depth=10):
        """Sum sizes per price bucket over the first ``depth`` levels; malformed levels are skipped."""
        buckets = {}
        for level in levels[:depth]:
            parsed = self._bucketed_level(level)
            if parsed is None:
                continue
            bucketed_price, size = parsed
            buckets[bucketed_price] = buckets.get(bucketed_price, 0) + size
        return buckets

    def _log_bucketed_cob_data(self, cob_data: dict):
        """Log the top 10 bid/ask levels aggregated into price buckets.

        Best-effort: any unexpected error is logged as a warning, never raised.
        """
        try:
            if 'bids' not in cob_data or 'asks' not in cob_data:
                logger.info("COB-1s: No order book data available")
                return

            if 'stats' not in cob_data or 'mid_price' not in cob_data['stats']:
                logger.info("COB-1s: No mid price available")
                return

            mid_price = cob_data['stats']['mid_price']
            if mid_price <= 0:
                return

            bid_buckets = self._bucket_levels(cob_data['bids'])
            ask_buckets = self._bucket_levels(cob_data['asks'])

            # Bids from best (highest) down, asks from best (lowest) up
            bid_str = ", ".join(f"${p:.0f}:{s:.3f}" for p, s in sorted(bid_buckets.items(), reverse=True))
            ask_str = ", ".join(f"${p:.0f}:{s:.3f}" for p, s in sorted(ask_buckets.items()))

            logger.info(f"COB-1s @ ${mid_price:.2f} | BIDS: {bid_str} | ASKS: {ask_str}")

        except Exception as e:
            logger.warning(f"Error logging bucketed COB data: {e}")

    async def run_test(self):
        """Run the data collection and plotting test.

        Setup failures are logged and collection continues. Streaming is stopped
        and the tick subscription dropped even if collection is interrupted.
        """
        logger.info(f"Starting COB stability test for {self.symbol} for {self.duration.total_seconds()} seconds...")

        # Initialize COB collection like clean_dashboard does
        try:
            logger.info("Starting COB collection in data provider...")
            self.data_provider.start_cob_collection()
            logger.info("Started COB collection in data provider")

            # Subscribe to COB updates
            logger.info("Subscribing to COB data updates...")
            self.data_provider.subscribe_to_cob(self._cob_data_callback)
            logger.info("Subscribed to COB data updates from data provider")
        except Exception as e:
            logger.error(f"Failed to start COB collection or subscribe: {e}")

        # Subscribe to ticks as fallback
        try:
            self.subscriber_id = self.data_provider.subscribe_to_ticks(self._tick_callback, symbols=[self.symbol])
            logger.info("Subscribed to tick data as fallback")
        except Exception as e:
            logger.warning(f"Failed to subscribe to ticks: {e}")

        # Start the data provider's real-time streaming
        try:
            await self.data_provider.start_real_time_streaming()
            logger.info("Started real-time streaming")
        except Exception as e:
            logger.error(f"Failed to start real-time streaming: {e}")

        # Collect data for the specified duration; always clean up afterwards
        try:
            self.start_time = datetime.now()
            while datetime.now() - self.start_time < self.duration:
                await asyncio.sleep(1)
                logger.info(f"Collected {len(self.ticks)} ticks so far...")
        finally:
            await self._shutdown_streams()

        logger.info(f"Finished collecting data. Total ticks: {len(self.ticks)}")

        # Plot the results
        if self.price_data and self.cob_snapshots:
            self.create_price_heatmap_chart()
        elif self.ticks:
            self._create_simple_price_chart()
        else:
            logger.warning("No data was collected. Cannot generate plot.")

    async def _shutdown_streams(self):
        """Stop streaming and drop the tick subscription, logging (not raising) failures."""
        try:
            await self.data_provider.stop_real_time_streaming()
        except Exception as e:
            logger.warning(f"Failed to stop real-time streaming: {e}")

        # Only unsubscribe if the fallback subscription actually succeeded
        if self.subscriber_id is not None:
            try:
                self.data_provider.unsubscribe_from_ticks(self.subscriber_id)
            except Exception as e:
                logger.warning(f"Failed to unsubscribe from ticks: {e}")

    def create_price_heatmap_chart(self):
        """Create a visualization with price chart and order book heatmap.

        Bid sizes (greens) and ask sizes (reds) are summed per (time bin, price
        bucket) and drawn beneath a yellow mid-price line; the figure is saved
        as a PNG and shown. Falls back to ``_create_simple_price_chart`` when
        no order-book level could be parsed.
        """
        if not self.price_data or not self.cob_snapshots:
            logger.warning("Insufficient data to plot.")
            return

        logger.info("Creating price and order book heatmap chart...")
        logger.info(f"Data summary: {len(self.price_data)} price points, {len(self.cob_snapshots)} COB snapshots")

        # Prepare price data with consistent timestamp handling
        price_df = pd.DataFrame(self.price_data)
        price_df['timestamp'] = pd.to_datetime(price_df['timestamp'])

        logger.info(f"Price data time range: {price_df['timestamp'].min()} to {price_df['timestamp'].max()}")
        logger.info(f"Price range: ${price_df['price'].min():.2f} to ${price_df['price'].max():.2f}")

        heatmap_rows = self._collect_heatmap_rows()
        if not heatmap_rows:
            logger.warning("No valid heatmap data found, creating price chart only")
            self._create_simple_price_chart()
            return

        heatmap_df = pd.DataFrame(heatmap_rows)
        logger.info(f"Heatmap data: {len(heatmap_df)} order book entries")
        logger.info(f"Heatmap time range: {heatmap_df['time'].min()} to {heatmap_df['time'].max()}")

        fig, ax = plt.subplots(figsize=(16, 10))

        # Overall time range covering both the price line and the heatmap
        all_times = pd.concat([price_df['timestamp'], heatmap_df['time']])
        time_min = all_times.min()
        time_max = all_times.max()

        # Price range for the heatmap, padded by two buckets on each side
        padding = self.price_granularity * 2
        price_min = min(price_df['price'].min(), heatmap_df['price'].min()) - padding
        price_max = max(price_df['price'].max(), heatmap_df['price'].max()) + padding

        logger.info(f"Chart time range: {time_min} to {time_max}")
        logger.info(f"Chart price range: ${price_min:.2f} to ${price_max:.2f}")

        # Create heatmap first (background), one layer per book side
        price_bins = np.arange(price_min, price_max + self.price_granularity, self.price_granularity)
        for side, cmap, alpha in (('bids', 'Greens', 0.6), ('asks', 'Reds', 0.6)):
            side_df = heatmap_df[heatmap_df['side'] == side]
            if side_df.empty:
                continue
            try:
                self._plot_side_heatmap(ax, side, side_df, cmap, alpha, time_min, time_max, price_bins)
            except Exception as e:
                logger.warning(f"Error creating {side} heatmap: {e}")

        # Plot price line on top
        ax.plot(price_df['timestamp'], price_df['price'], 'yellow', linewidth=2,
                label='Mid Price', alpha=0.9, zorder=10)

        # Enhance plot appearance
        ax.set_title(f'Price Chart with Order Book Heatmap - {self.symbol}\n'
                     f'Granularity: ${self.price_granularity} | Duration: {self.duration.total_seconds()}s\n'
                     f'Green=Bids, Red=Asks (darker = more volume)', fontsize=14)
        ax.set_xlabel('Time')
        ax.set_ylabel('Price (USDT)')
        ax.legend(loc='upper left')
        ax.grid(True, alpha=0.3)

        # Format time axis; identical limits (single timestamp) would only trigger a warning
        if time_max > time_min:
            ax.set_xlim(time_min, time_max)
        fig.autofmt_xdate()
        fig.tight_layout()

        plot_filename = f"price_heatmap_chart_{self.symbol.replace('/', '_')}_{datetime.now():%Y%m%d_%H%M%S}.png"
        fig.savefig(plot_filename, dpi=150, bbox_inches='tight')
        logger.info(f"Price and heatmap chart saved to {plot_filename}")
        plt.show()
        plt.close(fig)  # Release the figure once shown/saved

    def _collect_heatmap_rows(self, max_levels=50):
        """Flatten stored COB snapshots into heatmap rows.

        Takes the top ``max_levels`` levels of each side of every snapshot and
        returns a list of ``{'time', 'price', 'size', 'side'}`` dicts with the
        price bucketed to ``self.price_granularity``. Malformed levels are skipped.
        """
        rows = []
        for snapshot in self.cob_snapshots:
            timestamp = pd.to_datetime(snapshot['timestamp'])  # Ensure datetime
            for side in ('bids', 'asks'):
                if side not in snapshot or not snapshot[side]:
                    continue
                for level in snapshot[side][:max_levels]:
                    parsed = self._bucketed_level(level)
                    if parsed is None:
                        continue
                    bucketed_price, size = parsed
                    rows.append({
                        'time': timestamp,
                        'price': bucketed_price,
                        'size': size,
                        'side': side
                    })
        return rows

    def _plot_side_heatmap(self, ax, side, side_df, cmap, alpha, time_min, time_max, price_bins):
        """Draw one book side as a size-weighted (time x price) 2-D histogram on ``ax``.

        Draws nothing when all data shares one timestamp (zero-width time range)
        or when the histogram holds no positive weight.
        """
        # Time resolution grows with the amount of data, capped at 100 bin edges
        n_time_edges = min(100, len(side_df) // 10 + 10)

        # Histogram in seconds since time_min
        time_seconds = (side_df['time'] - time_min).dt.total_seconds()
        time_range_seconds = (time_max - time_min).total_seconds()
        if time_range_seconds <= 0:
            return

        hist, xedges, yedges = np.histogram2d(
            time_seconds,
            side_df['price'],
            bins=[np.linspace(0, time_range_seconds, n_time_edges), price_bins],
            weights=side_df['size']
        )
        if not hist.max() > 0:  # Only plot if we have data
            return

        # Convert second offsets back to timestamps; adding a timedelta also works
        # for tz-aware times, unlike pd.to_datetime(..., origin=<tz-aware>)
        time_edges = time_min + pd.to_timedelta(xedges, unit='s')
        ax.pcolormesh(time_edges, yedges, hist.T, cmap=cmap, alpha=alpha, shading='auto')
        logger.info(f"Plotted {side} heatmap: max value = {hist.max():.2f}")

    def _create_simple_price_chart(self):
        """Create a simple price chart as fallback from ticks with a positive price."""
        logger.info("Creating simple price chart as fallback...")

        prices = []
        times = []
        for tick in self.ticks:
            if tick.price > 0:
                prices.append(tick.price)
                times.append(tick.timestamp)

        if not prices:
            logger.warning("No price data to plot")
            return

        fig, ax = plt.subplots(figsize=(15, 8))
        ax.plot(pd.to_datetime(times), prices, 'cyan', linewidth=1)
        ax.set_title(f'Price Chart - {self.symbol}')
        ax.set_xlabel('Time')
        ax.set_ylabel('Price (USDT)')
        fig.autofmt_xdate()

        plot_filename = f"cob_price_chart_{self.symbol.replace('/', '_')}_{datetime.now():%Y%m%d_%H%M%S}.png"
        fig.savefig(plot_filename)
        logger.info(f"Price chart saved to {plot_filename}")
        plt.show()
        plt.close(fig)  # Release the figure once shown/saved


async def main(symbol='ETHUSDT', duration_seconds=15):
    """Main function to run the COB test with configurable parameters.

    Args:
        symbol: Trading symbol (default: ETHUSDT)
        duration_seconds: Test duration in seconds (default: 15)
    """
    logger.info(f"Starting COB test with symbol={symbol}, duration={duration_seconds}s")
    tester = COBStabilityTester(symbol=symbol, duration_seconds=duration_seconds)
    await tester.run_test()


if __name__ == "__main__":
    import sys

    # Parse command line arguments: [symbol] [duration_seconds]
    symbol = 'ETHUSDT'  # Default
    duration = 15  # Default

    if len(sys.argv) > 1:
        symbol = sys.argv[1]
    if len(sys.argv) > 2:
        try:
            duration = int(sys.argv[2])
        except ValueError:
            logger.warning(f"Invalid duration '{sys.argv[2]}', using default 15 seconds")

    logger.info(f"Configuration: Symbol={symbol}, Duration={duration}s")
    logger.info(f"Granularity: {'1 USD for ETH' if 'ETH' in symbol.upper() else '10 USD for BTC' if 'BTC' in symbol.upper() else '1 USD default'}")

    try:
        asyncio.run(main(symbol, duration))
    except KeyboardInterrupt:
        logger.info("Test interrupted by user.")