Files
gogo2/tests/cob/test_cob_data_stability.py
2025-07-23 23:33:36 +03:00

503 lines
22 KiB
Python

import asyncio
import logging
import time
from collections import deque
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.colors import LogNorm
from core.data_provider import DataProvider, MarketTick
from core.config import get_config
# Configure root logging at import time (timestamp - level - message) and
# create the module-level logger used throughout this file.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
class COBStabilityTester:
    """Collect COB (order book) updates for one symbol and chart the result.

    The tester creates a ``DataProvider``, subscribes to its COB updates (and
    to raw ticks as a fallback), records everything it receives for a fixed
    duration, and then saves and shows either a price + order book chart or a
    plain price chart, depending on what was collected.
    """

    def __init__(self, symbol='ETHUSDT', duration_seconds=10):
        """Initialise collection state and create the DataProvider.

        Args:
            symbol: Trading symbol to watch. COB updates are matched with any
                ``/`` removed, so ``ETH/USDT`` and ``ETHUSDT`` are equivalent.
            duration_seconds: How long ``run_test`` collects data.
        """
        self.symbol = symbol
        self.duration = timedelta(seconds=duration_seconds)
        self.ticks = deque()

        # Set granularity (bucket width in USD) based on symbol
        if 'ETH' in symbol.upper():
            self.price_granularity = 1.0  # 1 USD for ETH
        elif 'BTC' in symbol.upper():
            self.price_granularity = 10.0  # 10 USD for BTC
        else:
            self.price_granularity = 1.0  # Default 1 USD
        logger.info(f"Using price granularity: ${self.price_granularity} for {symbol}")

        # Initialize DataProvider the same way as clean_dashboard
        logger.info("Initializing DataProvider like in clean_dashboard...")
        self.data_provider = DataProvider()  # Use default constructor like clean_dashboard

        # COB bookkeeping
        self.cob_data_received = 0  # number of accepted COB updates
        self.latest_cob_data = {}  # last COB payload, keyed by the symbol as received

        # Store all COB snapshots for chart generation
        self.cob_snapshots = deque()
        self.price_data = []  # {'timestamp', 'price'} dicts for the price line chart

        self.start_time = None
        self.subscriber_id = None  # set by run_test when the tick subscription succeeds
        self.last_log_time = None  # throttles the once-per-second bucketed log

    # ------------------------------------------------------------------
    # Order book parsing helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _parse_order(order):
        """Return ``(price, size)`` as floats for one order book level, or ``None``.

        A level may be a dict with ``'price'`` and ``'size'`` keys, or a
        list/tuple whose first two items are price and size. Any other shape,
        a missing key, or a value that cannot be converted to float yields
        ``None`` so callers can simply skip the level.
        """
        try:
            if isinstance(order, dict):
                return float(order['price']), float(order['size'])
            if isinstance(order, (list, tuple)) and len(order) >= 2:
                return float(order[0]), float(order[1])
        except (KeyError, ValueError, TypeError):
            return None
        return None

    def _bucket_levels(self, levels):
        """Sum the sizes of the top 10 ``levels`` per ``price_granularity`` bucket.

        Returns:
            dict mapping bucketed price -> total size.
        """
        buckets = {}
        for level in levels[:10]:
            parsed = self._parse_order(level)
            if parsed is None:
                continue
            price, size = parsed
            try:
                bucket_price = round(price / self.price_granularity) * self.price_granularity
            except (ValueError, OverflowError):
                # round() rejects NaN / infinite prices; skip such levels
                continue
            buckets[bucket_price] = buckets.get(bucket_price, 0) + size
        return buckets

    def _collect_levels(self, snapshots, side, price_lower, price_upper):
        """Flatten the top 10 ``side`` levels of each snapshot into parallel lists.

        Args:
            snapshots: Iterable of snapshot dicts as stored in ``cob_snapshots``.
            side: ``'bids'`` or ``'asks'``.
            price_lower: Lowest price to keep (inclusive).
            price_upper: Highest price to keep (inclusive).

        Returns:
            ``(times, prices, sizes)`` lists of equal length.
        """
        times, prices, sizes = [], [], []
        for snapshot in snapshots:
            timestamp = pd.to_datetime(snapshot['timestamp'])
            for order in snapshot.get(side, [])[:10]:
                parsed = self._parse_order(order)
                if parsed is None:
                    continue
                price, size = parsed
                if price_lower <= price <= price_upper:
                    times.append(timestamp)
                    prices.append(price)
                    sizes.append(size)
        return times, prices, sizes

    def _top_depth(self, levels):
        """Return the total size of the top 10 parseable ``levels``."""
        parsed_levels = (self._parse_order(level) for level in levels[:10])
        return sum(parsed[1] for parsed in parsed_levels if parsed is not None)

    # ------------------------------------------------------------------
    # DataProvider callbacks
    # ------------------------------------------------------------------
    def _tick_callback(self, tick: MarketTick):
        """Callback function to receive ticks from the DataProvider."""
        if self.start_time is None:
            self.start_time = datetime.now()
            logger.info(f"Started collecting ticks at {self.start_time}")

        # Store all ticks
        self.ticks.append(tick)

    def _cob_data_callback(self, symbol: str, cob_data: dict):
        """Callback function to receive COB data from the DataProvider.

        Updates for other symbols are ignored. For a matching update the
        method stores a snapshot (when both ``bids`` and ``asks`` are present),
        logs a bucketed summary at most once per second, and, when a positive
        ``stats['mid_price']`` is available and within 10% of the recent
        average, records a price point and a synthetic tick.

        Args:
            symbol: Symbol the update belongs to, with or without ``/``.
            cob_data: COB payload; the keys read here are ``bids``, ``asks``
                and ``stats`` (``mid_price``, ``total_volume``).
        """
        # Debug: Log first few callbacks to see what symbols we're getting
        if self.cob_data_received < 5:
            logger.info(f"DEBUG: Received COB data for symbol '{symbol}' (target: '{self.symbol}')")

        # Filter to only our requested symbol - handle both formats (ETH/USDT and ETHUSDT)
        normalized_symbol = symbol.replace('/', '')
        normalized_target = self.symbol.replace('/', '')
        if normalized_symbol != normalized_target:
            if self.cob_data_received < 5:
                logger.info(f"DEBUG: Skipping symbol '{symbol}' (normalized: '{normalized_symbol}' vs target: '{normalized_target}')")
            return

        self.cob_data_received += 1
        self.latest_cob_data[symbol] = cob_data

        # Take ONE timestamp per update so the snapshot, the price point and
        # the synthetic tick derived from it line up exactly on the charts.
        current_time = datetime.now()

        # Store the complete COB snapshot for chart generation
        if 'bids' in cob_data and 'asks' in cob_data:
            # Debug: Log structure of first few COB snapshots
            if len(self.cob_snapshots) < 3:
                logger.info(f"DEBUG: COB data structure - bids: {len(cob_data['bids'])} items, asks: {len(cob_data['asks'])} items")
                if cob_data['bids']:
                    logger.info(f"DEBUG: First bid: {cob_data['bids'][0]}")
                if cob_data['asks']:
                    logger.info(f"DEBUG: First ask: {cob_data['asks'][0]}")

            self.cob_snapshots.append({
                'timestamp': current_time,
                'bids': cob_data['bids'],
                'asks': cob_data['asks'],
                'stats': cob_data.get('stats', {}),
            })

        # Log bucketed COB data every second
        if self.last_log_time is None or (current_time - self.last_log_time).total_seconds() >= 1.0:
            self.last_log_time = current_time
            self._log_bucketed_cob_data(cob_data)

        # Convert COB data to tick-like format for analysis
        stats = cob_data.get('stats')
        if not stats or 'mid_price' not in stats:
            return
        mid_price = stats['mid_price']
        if not mid_price > 0:
            return

        # Filter out extreme price movements (±10% of recent average)
        if len(self.price_data) > 5:
            recent_prices = [p['price'] for p in self.price_data[-5:]]
            avg_recent_price = sum(recent_prices) / len(recent_prices)
            price_deviation = abs(mid_price - avg_recent_price) / avg_recent_price
            if price_deviation > 0.10:  # More than 10% deviation
                logger.warning(f"Filtering out extreme price: ${mid_price:.2f} (deviation: {price_deviation:.1%} from avg ${avg_recent_price:.2f})")
                return  # Skip this data point

        # Store price data for line chart
        self.price_data.append({
            'timestamp': current_time,
            'price': mid_price,
        })

        # Create a synthetic tick from COB data
        synthetic_tick = MarketTick(
            symbol=symbol,
            timestamp=current_time,
            price=mid_price,
            volume=stats.get('total_volume', 0),
            quantity=0,  # Not available in COB data
            side='unknown',  # COB data doesn't have side info
            trade_id=f"cob_{self.cob_data_received}",
            is_buyer_maker=False,
            raw_data=cob_data,
        )
        self.ticks.append(synthetic_tick)

        if self.cob_data_received % 10 == 0:  # Log every 10th update
            logger.info(f"COB update #{self.cob_data_received}: {symbol} @ ${mid_price:.2f}")

    def _log_bucketed_cob_data(self, cob_data: dict):
        """Log the top-10 bids and asks summed into price buckets.

        Does nothing beyond an informational message when the payload lacks
        ``bids``/``asks`` or ``stats['mid_price']``. Any unexpected error is
        logged as a warning instead of propagating to the data provider.
        """
        try:
            if 'bids' not in cob_data or 'asks' not in cob_data:
                logger.info("COB-1s: No order book data available")
                return

            if 'stats' not in cob_data or 'mid_price' not in cob_data['stats']:
                logger.info("COB-1s: No mid price available")
                return

            mid_price = cob_data['stats']['mid_price']
            if mid_price <= 0:
                return

            # Bucket the order book data (top 10 levels per side)
            bid_buckets = self._bucket_levels(cob_data['bids'])
            ask_buckets = self._bucket_levels(cob_data['asks'])

            # Format for log output: bids from highest price down, asks from lowest up
            bid_str = ", ".join([f"${p:.0f}:{s:.3f}" for p, s in sorted(bid_buckets.items(), reverse=True)])
            ask_str = ", ".join([f"${p:.0f}:{s:.3f}" for p, s in sorted(ask_buckets.items())])

            logger.info(f"COB-1s @ ${mid_price:.2f} | BIDS: {bid_str} | ASKS: {ask_str}")
        except Exception as e:
            logger.warning(f"Error logging bucketed COB data: {e}")

    # ------------------------------------------------------------------
    # Test driver
    # ------------------------------------------------------------------
    async def run_test(self):
        """Run the data collection and plotting test.

        Each setup and teardown step is attempted independently: a failure is
        logged and the test carries on, so that whatever data was collected
        can still be plotted.
        """
        logger.info(f"Starting COB stability test for {self.symbol} for {self.duration.total_seconds()} seconds...")

        # Initialize COB collection like clean_dashboard does
        try:
            logger.info("Starting COB collection in data provider...")
            self.data_provider.start_cob_collection()
            logger.info("Started COB collection in data provider")

            # Subscribe to COB updates
            logger.info("Subscribing to COB data updates...")
            self.data_provider.subscribe_to_cob(self._cob_data_callback)
            logger.info("Subscribed to COB data updates from data provider")
        except Exception as e:
            logger.error(f"Failed to start COB collection or subscribe: {e}")

        # Subscribe to ticks as fallback
        try:
            self.subscriber_id = self.data_provider.subscribe_to_ticks(self._tick_callback, symbols=[self.symbol])
            logger.info("Subscribed to tick data as fallback")
        except Exception as e:
            logger.warning(f"Failed to subscribe to ticks: {e}")

        # Start the data provider's real-time streaming
        try:
            await self.data_provider.start_real_time_streaming()
            logger.info("Started real-time streaming")
        except Exception as e:
            logger.error(f"Failed to start real-time streaming: {e}")

        # Collect data for the specified duration
        self.start_time = datetime.now()
        while datetime.now() - self.start_time < self.duration:
            await asyncio.sleep(1)
            logger.info(f"Collected {len(self.ticks)} ticks so far...")

        # Stop streaming and unsubscribe. Failures here must not prevent the
        # plot from being generated from the data already collected.
        try:
            await self.data_provider.stop_real_time_streaming()
        except Exception as e:
            logger.error(f"Failed to stop real-time streaming: {e}")

        # Only unsubscribe when the tick subscription actually succeeded
        if self.subscriber_id is not None:
            try:
                self.data_provider.unsubscribe_from_ticks(self.subscriber_id)
            except Exception as e:
                logger.warning(f"Failed to unsubscribe from ticks: {e}")

        logger.info(f"Finished collecting data. Total ticks: {len(self.ticks)}")

        # Plot the results
        if self.price_data and self.cob_snapshots:
            self.create_price_heatmap_chart()
        elif self.ticks:
            self._create_simple_price_chart()
        else:
            logger.warning("No data was collected. Cannot generate plot.")

    # ------------------------------------------------------------------
    # Charts
    # ------------------------------------------------------------------
    def create_price_heatmap_chart(self):
        """Create a visualization with price chart and order book scatter plot.

        Top panel: mid price line plus the top-10 bid/ask levels of the last
        50 snapshots as scatter points (marker area proportional to size).
        Bottom panel: average top-10 bid/ask depth in roughly 20 time buckets.
        The figure is saved as a PNG in the current directory and then shown.
        """
        if not self.price_data or not self.cob_snapshots:
            logger.warning("Insufficient data to plot.")
            return

        logger.info("Creating price and order book chart...")
        logger.info(f"Data summary: {len(self.price_data)} price points, {len(self.cob_snapshots)} COB snapshots")

        # Prepare price data
        price_df = pd.DataFrame(self.price_data)
        price_df['timestamp'] = pd.to_datetime(price_df['timestamp'])

        logger.info(f"Price data time range: {price_df['timestamp'].min()} to {price_df['timestamp'].max()}")
        logger.info(f"Price range: ${price_df['price'].min():.2f} to ${price_df['price'].max():.2f}")

        # Create figure with subplots. The ratios go through gridspec_kw so
        # this also works on matplotlib releases older than 3.6, where
        # subplots() does not accept height_ratios directly.
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(16, 12), gridspec_kw={'height_ratios': [3, 2]})

        # Top plot: Price chart with order book levels
        ax1.plot(price_df['timestamp'], price_df['price'], 'yellow', linewidth=2, label='Mid Price', zorder=10)

        # Only keep order book levels within ±10% of the average mid price
        # (price_df is non-empty here thanks to the guard above)
        avg_price = price_df['price'].mean()
        price_lower = avg_price * 0.9  # -10%
        price_upper = avg_price * 1.1  # +10%

        logger.info(f"Filtering order book data to price range: ${price_lower:.2f} - ${price_upper:.2f} (±10% of ${avg_price:.2f})")

        snapshots_list = list(self.cob_snapshots)
        recent_snapshots = snapshots_list[-50:]  # Use last 50 snapshots for clarity
        bid_times, bid_prices, bid_sizes = self._collect_levels(recent_snapshots, 'bids', price_lower, price_upper)
        ask_times, ask_prices, ask_sizes = self._collect_levels(recent_snapshots, 'asks', price_lower, price_upper)

        # Plot order book data as scatter with size indicating volume
        if bid_times:
            bid_sizes_normalized = np.array(bid_sizes) * 3  # Scale for visibility
            ax1.scatter(bid_times, bid_prices, s=bid_sizes_normalized, c='green', alpha=0.3, label='Bids')
            logger.info(f"Plotted {len(bid_times)} bid levels")

        if ask_times:
            ask_sizes_normalized = np.array(ask_sizes) * 3  # Scale for visibility
            ax1.scatter(ask_times, ask_prices, s=ask_sizes_normalized, c='red', alpha=0.3, label='Asks')
            logger.info(f"Plotted {len(ask_times)} ask levels")

        ax1.set_title(f'Real-time Price and Order Book - {self.symbol}\nGranularity: ${self.price_granularity} | Duration: {self.duration.total_seconds()}s')
        ax1.set_ylabel('Price (USDT)')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Set proper time range (X-axis) - use actual data collection period
        time_min = price_df['timestamp'].min()
        time_max = price_df['timestamp'].max()
        actual_duration = (time_max - time_min).total_seconds()
        logger.info(f"Actual data collection duration: {actual_duration:.1f} seconds")

        ax1.set_xlim(time_min, time_max)

        # Set tight price range (Y-axis) around the observed mid prices
        price_min = price_df['price'].min()
        price_max = price_df['price'].max()
        price_center = (price_min + price_max) / 2
        # If the observed range is very small, use a minimum range of $5
        price_range = max(price_max - price_min, 5)

        # Build the limits from the centre and the (possibly widened) range so
        # the $5 minimum really applies to the axis, then add 20% padding.
        # For ranges >= $5 this equals price_min - padding / price_max + padding.
        y_padding = price_range * 0.2
        y_min = price_center - price_range / 2 - y_padding
        y_max = price_center + price_range / 2 + y_padding

        ax1.set_ylim(y_min, y_max)
        logger.info(f"Chart Y-axis range: ${y_min:.2f} - ${y_max:.2f} (center: ${price_center:.2f}, range: ${price_range:.2f})")

        # Bottom plot: Order book depth over time (aggregated)
        time_buckets = []
        bid_depths = []
        ask_depths = []

        # Create time buckets (every few snapshots)
        bucket_size = max(1, len(snapshots_list) // 20)  # ~20 buckets
        for start in range(0, len(snapshots_list), bucket_size):
            # start < len(snapshots_list), so the slice always holds >= 1 snapshot
            bucket_snapshots = snapshots_list[start:start + bucket_size]

            # Use middle timestamp of bucket
            mid_snapshot = bucket_snapshots[len(bucket_snapshots) // 2]
            time_buckets.append(pd.to_datetime(mid_snapshot['timestamp']))

            # Average top-10 depth per side over the snapshots in this bucket
            total_bid_depth = sum(self._top_depth(s.get('bids', [])) for s in bucket_snapshots)
            total_ask_depth = sum(self._top_depth(s.get('asks', [])) for s in bucket_snapshots)
            bid_depths.append(total_bid_depth / len(bucket_snapshots))
            ask_depths.append(total_ask_depth / len(bucket_snapshots))

        if time_buckets:
            ax2.plot(time_buckets, bid_depths, 'green', linewidth=2, label='Bid Depth', alpha=0.7)
            ax2.plot(time_buckets, ask_depths, 'red', linewidth=2, label='Ask Depth', alpha=0.7)
            ax2.fill_between(time_buckets, bid_depths, alpha=0.3, color='green')
            ax2.fill_between(time_buckets, ask_depths, alpha=0.3, color='red')

        ax2.set_title('Order Book Depth Over Time')
        ax2.set_xlabel('Time')
        ax2.set_ylabel('Depth (Volume)')
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        # Set same time range for bottom chart
        ax2.set_xlim(time_min, time_max)

        # Format time axes
        fig.autofmt_xdate()

        plt.tight_layout()

        plot_filename = f"price_heatmap_chart_{self.symbol.replace('/', '_')}_{datetime.now():%Y%m%d_%H%M%S}.png"
        plt.savefig(plot_filename, dpi=150, bbox_inches='tight')
        logger.info(f"Price and order book chart saved to {plot_filename}")
        plt.show()

    def _create_simple_price_chart(self):
        """Create a simple price chart as fallback.

        Plots the price of every collected tick with a positive price against
        its timestamp, saves the figure as a PNG in the current directory and
        shows it. Logs a warning and returns when no such tick exists.
        """
        logger.info("Creating simple price chart as fallback...")

        prices = []
        times = []

        for tick in self.ticks:
            if tick.price > 0:
                prices.append(tick.price)
                times.append(tick.timestamp)

        if not prices:
            logger.warning("No price data to plot")
            return

        fig, ax = plt.subplots(figsize=(15, 8))
        ax.plot(pd.to_datetime(times), prices, 'cyan', linewidth=1)
        ax.set_title(f'Price Chart - {self.symbol}')
        ax.set_xlabel('Time')
        ax.set_ylabel('Price (USDT)')
        fig.autofmt_xdate()

        plot_filename = f"cob_price_chart_{self.symbol.replace('/', '_')}_{datetime.now():%Y%m%d_%H%M%S}.png"
        plt.savefig(plot_filename)
        logger.info(f"Price chart saved to {plot_filename}")
        plt.show()
async def main(symbol='ETHUSDT', duration_seconds=10):
    """Build a COBStabilityTester for the given settings and run it.

    Args:
        symbol: Trading symbol (default: ETHUSDT)
        duration_seconds: Test duration in seconds (default: 10)
    """
    logger.info(f"Starting COB test with symbol={symbol}, duration={duration_seconds}s")
    # The tester is only needed for this single run, so no name is kept for it.
    await COBStabilityTester(
        symbol=symbol,
        duration_seconds=duration_seconds,
    ).run_test()
if __name__ == "__main__":
    import sys

    # Positional command line arguments: [symbol] [duration_seconds].
    # Both are optional and fall back to the defaults below.
    cli_args = sys.argv[1:]
    symbol = cli_args[0] if cli_args else 'ETHUSDT'
    duration = 10
    if len(cli_args) >= 2:
        try:
            duration = int(cli_args[1])
        except ValueError:
            # Keep the default when the second argument is not an integer
            logger.warning(f"Invalid duration '{sys.argv[2]}', using default 10 seconds")

    logger.info(f"Configuration: Symbol={symbol}, Duration={duration}s")

    # Describe the bucket width the tester will pick for this symbol
    upper_symbol = symbol.upper()
    if 'ETH' in upper_symbol:
        granularity_note = '1 USD for ETH'
    elif 'BTC' in upper_symbol:
        granularity_note = '10 USD for BTC'
    else:
        granularity_note = '1 USD default'
    logger.info(f"Granularity: {granularity_note}")

    try:
        asyncio.run(main(symbol, duration))
    except KeyboardInterrupt:
        logger.info("Test interrupted by user.")