integrate CNN, fix COB data
This commit is contained in:
123
test_cob_data_stability.py
Normal file
123
test_cob_data_stability.py
Normal file
@ -0,0 +1,123 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections import deque
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from matplotlib.colors import LogNorm
|
||||
|
||||
from core.data_provider import DataProvider, MarketTick
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class COBStabilityTester:
|
||||
def __init__(self, symbol='ETH/USDT', duration_seconds=15):
|
||||
self.symbol = symbol
|
||||
self.duration = timedelta(seconds=duration_seconds)
|
||||
self.ticks = deque()
|
||||
self.data_provider = DataProvider(symbols=[self.symbol], timeframes=['1s'])
|
||||
self.start_time = None
|
||||
self.subscriber_id = None
|
||||
|
||||
def _tick_callback(self, tick: MarketTick):
|
||||
"""Callback function to receive ticks from the DataProvider."""
|
||||
if self.start_time is None:
|
||||
self.start_time = datetime.now()
|
||||
logger.info(f"Started collecting ticks at {self.start_time}")
|
||||
|
||||
# Store all ticks
|
||||
self.ticks.append(tick)
|
||||
|
||||
async def run_test(self):
|
||||
"""Run the data collection and plotting test."""
|
||||
logger.info(f"Starting COB stability test for {self.symbol} for {self.duration.total_seconds()} seconds...")
|
||||
|
||||
# Subscribe to ticks
|
||||
self.subscriber_id = self.data_provider.subscribe_to_ticks(self._tick_callback, symbols=[self.symbol])
|
||||
|
||||
# Start the data provider's real-time streaming
|
||||
await self.data_provider.start_real_time_streaming()
|
||||
|
||||
# Collect data for the specified duration
|
||||
self.start_time = datetime.now()
|
||||
while datetime.now() - self.start_time < self.duration:
|
||||
await asyncio.sleep(1)
|
||||
logger.info(f"Collected {len(self.ticks)} ticks so far...")
|
||||
|
||||
# Stop streaming and unsubscribe
|
||||
await self.data_provider.stop_real_time_streaming()
|
||||
self.data_provider.unsubscribe_from_ticks(self.subscriber_id)
|
||||
|
||||
logger.info(f"Finished collecting data. Total ticks: {len(self.ticks)}")
|
||||
|
||||
# Plot the results
|
||||
if self.ticks:
|
||||
self.plot_spectrogram()
|
||||
else:
|
||||
logger.warning("No ticks were collected. Cannot generate plot.")
|
||||
|
||||
def plot_spectrogram(self):
|
||||
"""Create a spectrogram-like plot of trade intensity."""
|
||||
if not self.ticks:
|
||||
logger.warning("No ticks to plot.")
|
||||
return
|
||||
|
||||
df = pd.DataFrame([{
|
||||
'timestamp': tick.timestamp,
|
||||
'price': tick.price,
|
||||
'volume': tick.volume,
|
||||
'side': 1 if tick.side == 'buy' else -1
|
||||
} for tick in self.ticks])
|
||||
|
||||
df['timestamp'] = pd.to_datetime(df['timestamp'])
|
||||
df = df.set_index('timestamp')
|
||||
|
||||
# Create the plot
|
||||
fig, ax = plt.subplots(figsize=(15, 8))
|
||||
|
||||
# Define bins for the 2D histogram
|
||||
time_bins = pd.date_range(df.index.min(), df.index.max(), periods=100)
|
||||
price_bins = np.linspace(df['price'].min(), df['price'].max(), 100)
|
||||
|
||||
# Create the 2D histogram
|
||||
# x-axis: time, y-axis: price, weights: volume
|
||||
h, xedges, yedges = np.histogram2d(
|
||||
df.index.astype(np.int64) // 10**9,
|
||||
df['price'],
|
||||
bins=[time_bins.astype(np.int64) // 10**9, price_bins],
|
||||
weights=df['volume']
|
||||
)
|
||||
|
||||
# Use a logarithmic color scale for better visibility of smaller trades
|
||||
pcm = ax.pcolormesh(time_bins, price_bins, h.T, norm=LogNorm(vmin=1e-3, vmax=h.max()), cmap='inferno')
|
||||
|
||||
fig.colorbar(pcm, ax=ax, label='Trade Volume (USDT)')
|
||||
ax.set_title(f'Trade Intensity Spectrogram for {self.symbol}')
|
||||
ax.set_xlabel('Time')
|
||||
ax.set_ylabel('Price (USDT)')
|
||||
|
||||
# Format the x-axis to show time properly
|
||||
fig.autofmt_xdate()
|
||||
|
||||
plot_filename = f"cob_stability_spectrogram_{self.symbol.replace('/', '_')}_{datetime.now():%Y%m%d_%H%M%S}.png"
|
||||
plt.savefig(plot_filename)
|
||||
logger.info(f"Plot saved to {plot_filename}")
|
||||
plt.show()
|
||||
|
||||
|
||||
async def main():
|
||||
tester = COBStabilityTester()
|
||||
await tester.run_test()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Test interrupted by user.")
|
Reference in New Issue
Block a user