wip cnn training and cob

This commit is contained in:
Dobromir Popov
2025-07-23 23:33:36 +03:00
parent 8677c4c01c
commit 5437495003
4 changed files with 599 additions and 210 deletions

View File

@ -18,7 +18,7 @@ logger = logging.getLogger(__name__)
class COBStabilityTester:
def __init__(self, symbol='ETHUSDT', duration_seconds=15):
def __init__(self, symbol='ETHUSDT', duration_seconds=10):
self.symbol = symbol
self.duration = timedelta(seconds=duration_seconds)
self.ticks = deque()
@ -85,8 +85,10 @@ class COBStabilityTester:
if cob_data['asks']:
logger.info(f"DEBUG: First ask: {cob_data['asks'][0]}")
# Use current time for timestamp consistency
current_time = datetime.now()
snapshot = {
'timestamp': cob_data.get('timestamp', datetime.now()),
'timestamp': current_time,
'bids': cob_data['bids'],
'asks': cob_data['asks'],
'stats': cob_data.get('stats', {})
@ -103,16 +105,28 @@ class COBStabilityTester:
if 'stats' in cob_data and 'mid_price' in cob_data['stats']:
mid_price = cob_data['stats']['mid_price']
if mid_price > 0:
# Store price data for line chart
# Filter out extreme price movements (±10% of recent average)
if len(self.price_data) > 5:
recent_prices = [p['price'] for p in self.price_data[-5:]]
avg_recent_price = sum(recent_prices) / len(recent_prices)
price_deviation = abs(mid_price - avg_recent_price) / avg_recent_price
if price_deviation > 0.10: # More than 10% deviation
logger.warning(f"Filtering out extreme price: ${mid_price:.2f} (deviation: {price_deviation:.1%} from avg ${avg_recent_price:.2f})")
return # Skip this data point
# Store price data for line chart with consistent timestamp
current_time = datetime.now()
self.price_data.append({
'timestamp': cob_data.get('timestamp', datetime.now()),
'timestamp': current_time,
'price': mid_price
})
# Create a synthetic tick from COB data
# Create a synthetic tick from COB data with consistent timestamp
current_time = datetime.now()
synthetic_tick = MarketTick(
symbol=symbol,
timestamp=cob_data.get('timestamp', datetime.now()),
timestamp=current_time,
price=mid_price,
volume=cob_data.get('stats', {}).get('total_volume', 0),
quantity=0, # Not available in COB data
@ -240,132 +254,187 @@ class COBStabilityTester:
logger.warning("No data was collected. Cannot generate plot.")
def create_price_heatmap_chart(self):
"""Create a visualization with price chart and order book heatmap."""
"""Create a visualization with price chart and order book scatter plot."""
if not self.price_data or not self.cob_snapshots:
logger.warning("Insufficient data to plot.")
return
logger.info(f"Creating price and order book heatmap chart...")
logger.info(f"Creating price and order book chart...")
logger.info(f"Data summary: {len(self.price_data)} price points, {len(self.cob_snapshots)} COB snapshots")
# Prepare price data with consistent timestamp handling
# Prepare price data
price_df = pd.DataFrame(self.price_data)
price_df['timestamp'] = pd.to_datetime(price_df['timestamp'])
logger.info(f"Price data time range: {price_df['timestamp'].min()} to {price_df['timestamp'].max()}")
logger.info(f"Price range: ${price_df['price'].min():.2f} to ${price_df['price'].max():.2f}")
# Extract order book data for heatmap with consistent timestamp handling
heatmap_data = []
for snapshot in self.cob_snapshots:
timestamp = pd.to_datetime(snapshot['timestamp']) # Ensure datetime
for side in ['bids', 'asks']:
if side not in snapshot or not snapshot[side]:
continue
# Take top 50 levels for better visualization
orders = snapshot[side][:50]
for order in orders:
try:
# Handle both dict and list formats
if isinstance(order, dict):
price = float(order['price'])
size = float(order['size'])
elif isinstance(order, (list, tuple)) and len(order) >= 2:
price = float(order[0])
size = float(order[1])
else:
continue
# Apply granularity bucketing
bucketed_price = round(price / self.price_granularity) * self.price_granularity
heatmap_data.append({
'time': timestamp,
'price': bucketed_price,
'size': size,
'side': side
})
except (ValueError, TypeError, IndexError) as e:
continue
if not heatmap_data:
logger.warning("No valid heatmap data found, creating price chart only")
self._create_simple_price_chart()
return
heatmap_df = pd.DataFrame(heatmap_data)
logger.info(f"Heatmap data: {len(heatmap_df)} order book entries")
logger.info(f"Heatmap time range: {heatmap_df['time'].min()} to {heatmap_df['time'].max()}")
# Create plot with better time handling
fig, ax = plt.subplots(figsize=(16, 10))
# Determine overall time range
all_times = pd.concat([price_df['timestamp'], heatmap_df['time']])
time_min = all_times.min()
time_max = all_times.max()
# Create figure with subplots
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(16, 12), height_ratios=[3, 2])
# Create price range for heatmap
price_min = min(price_df['price'].min(), heatmap_df['price'].min()) - self.price_granularity * 2
price_max = max(price_df['price'].max(), heatmap_df['price'].max()) + self.price_granularity * 2
# Top plot: Price chart with order book levels
ax1.plot(price_df['timestamp'], price_df['price'], 'yellow', linewidth=2, label='Mid Price', zorder=10)
logger.info(f"Chart time range: {time_min} to {time_max}")
logger.info(f"Chart price range: ${price_min:.2f} to ${price_max:.2f}")
# Create heatmap first (background)
for side, cmap, alpha in zip(['bids', 'asks'], ['Greens', 'Reds'], [0.6, 0.6]):
side_df = heatmap_df[heatmap_df['side'] == side]
if not side_df.empty:
# Create more granular bins
time_bins = pd.date_range(time_min, time_max, periods=min(100, len(side_df) // 10 + 10))
price_bins = np.arange(price_min, price_max + self.price_granularity, self.price_granularity)
# Plot order book levels as scatter points
bid_times, bid_prices, bid_sizes = [], [], []
ask_times, ask_prices, ask_sizes = [], [], []
# Calculate average price for filtering
avg_price = price_df['price'].mean() if not price_df.empty else 3500 # Fallback price
price_lower = avg_price * 0.9 # -10%
price_upper = avg_price * 1.1 # +10%
logger.info(f"Filtering order book data to price range: ${price_lower:.2f} - ${price_upper:.2f} (±10% of ${avg_price:.2f})")
for snapshot in list(self.cob_snapshots)[-50:]: # Use last 50 snapshots for clarity
timestamp = pd.to_datetime(snapshot['timestamp'])
# Process bids (top 10)
for order in snapshot.get('bids', [])[:10]:
try:
# Convert to seconds for histogram
time_seconds = (side_df['time'] - time_min).dt.total_seconds()
time_range_seconds = (time_max - time_min).total_seconds()
if isinstance(order, dict):
price = float(order['price'])
size = float(order['size'])
elif isinstance(order, (list, tuple)) and len(order) >= 2:
price = float(order[0])
size = float(order[1])
else:
continue
if time_range_seconds > 0:
hist, xedges, yedges = np.histogram2d(
time_seconds,
side_df['price'],
bins=[np.linspace(0, time_range_seconds, len(time_bins)), price_bins],
weights=side_df['size']
)
# Convert back to datetime for plotting
time_edges = pd.to_datetime(xedges, unit='s', origin=time_min)
if hist.max() > 0: # Only plot if we have data
pcm = ax.pcolormesh(time_edges, yedges, hist.T,
cmap=cmap, alpha=alpha, shading='auto')
logger.info(f"Plotted {side} heatmap: max value = {hist.max():.2f}")
except Exception as e:
logger.warning(f"Error creating {side} heatmap: {e}")
# Plot price line on top
ax.plot(price_df['timestamp'], price_df['price'], 'yellow', linewidth=2,
label='Mid Price', alpha=0.9, zorder=10)
# Enhance plot appearance
ax.set_title(f'Price Chart with Order Book Heatmap - {self.symbol}\n'
f'Granularity: ${self.price_granularity} | Duration: {self.duration.total_seconds()}s\n'
f'Green=Bids, Red=Asks (darker = more volume)', fontsize=14)
ax.set_xlabel('Time')
ax.set_ylabel('Price (USDT)')
ax.legend(loc='upper left')
ax.grid(True, alpha=0.3)
# Filter out prices outside ±10% range
if price < price_lower or price > price_upper:
continue
bid_times.append(timestamp)
bid_prices.append(price)
bid_sizes.append(size)
except (ValueError, TypeError, IndexError):
continue
# Process asks (top 10)
for order in snapshot.get('asks', [])[:10]:
try:
if isinstance(order, dict):
price = float(order['price'])
size = float(order['size'])
elif isinstance(order, (list, tuple)) and len(order) >= 2:
price = float(order[0])
size = float(order[1])
else:
continue
# Filter out prices outside ±10% range
if price < price_lower or price > price_upper:
continue
ask_times.append(timestamp)
ask_prices.append(price)
ask_sizes.append(size)
except (ValueError, TypeError, IndexError):
continue
# Format time axis
ax.set_xlim(time_min, time_max)
# Plot order book data as scatter with size indicating volume
if bid_times:
bid_sizes_normalized = np.array(bid_sizes) * 3 # Scale for visibility
ax1.scatter(bid_times, bid_prices, s=bid_sizes_normalized, c='green', alpha=0.3, label='Bids')
logger.info(f"Plotted {len(bid_times)} bid levels")
if ask_times:
ask_sizes_normalized = np.array(ask_sizes) * 3 # Scale for visibility
ax1.scatter(ask_times, ask_prices, s=ask_sizes_normalized, c='red', alpha=0.3, label='Asks')
logger.info(f"Plotted {len(ask_times)} ask levels")
ax1.set_title(f'Real-time Price and Order Book - {self.symbol}\nGranularity: ${self.price_granularity} | Duration: {self.duration.total_seconds()}s')
ax1.set_ylabel('Price (USDT)')
ax1.legend()
ax1.grid(True, alpha=0.3)
# Set proper time range (X-axis) - use actual data collection period
time_min = price_df['timestamp'].min()
time_max = price_df['timestamp'].max()
actual_duration = (time_max - time_min).total_seconds()
logger.info(f"Actual data collection duration: {actual_duration:.1f} seconds")
ax1.set_xlim(time_min, time_max)
# Set tight price range (Y-axis) - use ±2% of price range for better visibility
price_min = price_df['price'].min()
price_max = price_df['price'].max()
price_center = (price_min + price_max) / 2
price_range = price_max - price_min
# If price range is very small, use a minimum range of $5
if price_range < 5:
price_range = 5
# Add 20% padding to the price range for better visualization
y_padding = price_range * 0.2
y_min = price_min - y_padding
y_max = price_max + y_padding
ax1.set_ylim(y_min, y_max)
logger.info(f"Chart Y-axis range: ${y_min:.2f} - ${y_max:.2f} (center: ${price_center:.2f}, range: ${price_range:.2f})")
# Bottom plot: Order book depth over time (aggregated)
time_buckets = []
bid_depths = []
ask_depths = []
# Create time buckets (every few snapshots)
snapshots_list = list(self.cob_snapshots)
bucket_size = max(1, len(snapshots_list) // 20) # ~20 buckets
for i in range(0, len(snapshots_list), bucket_size):
bucket_snapshots = snapshots_list[i:i+bucket_size]
if not bucket_snapshots:
continue
# Use middle timestamp of bucket
mid_snapshot = bucket_snapshots[len(bucket_snapshots)//2]
time_buckets.append(pd.to_datetime(mid_snapshot['timestamp']))
# Calculate average depths
total_bid_depth = 0
total_ask_depth = 0
snapshot_count = 0
for snapshot in bucket_snapshots:
bid_depth = sum([float(order[1]) if isinstance(order, (list, tuple)) else float(order.get('size', 0))
for order in snapshot.get('bids', [])[:10]])
ask_depth = sum([float(order[1]) if isinstance(order, (list, tuple)) else float(order.get('size', 0))
for order in snapshot.get('asks', [])[:10]])
total_bid_depth += bid_depth
total_ask_depth += ask_depth
snapshot_count += 1
if snapshot_count > 0:
bid_depths.append(total_bid_depth / snapshot_count)
ask_depths.append(total_ask_depth / snapshot_count)
else:
bid_depths.append(0)
ask_depths.append(0)
if time_buckets:
ax2.plot(time_buckets, bid_depths, 'green', linewidth=2, label='Bid Depth', alpha=0.7)
ax2.plot(time_buckets, ask_depths, 'red', linewidth=2, label='Ask Depth', alpha=0.7)
ax2.fill_between(time_buckets, bid_depths, alpha=0.3, color='green')
ax2.fill_between(time_buckets, ask_depths, alpha=0.3, color='red')
ax2.set_title('Order Book Depth Over Time')
ax2.set_xlabel('Time')
ax2.set_ylabel('Depth (Volume)')
ax2.legend()
ax2.grid(True, alpha=0.3)
# Set same time range for bottom chart
ax2.set_xlim(time_min, time_max)
# Format time axes
fig.autofmt_xdate()
plt.tight_layout()
plot_filename = f"price_heatmap_chart_{self.symbol.replace('/', '_')}_{datetime.now():%Y%m%d_%H%M%S}.png"
plt.savefig(plot_filename, dpi=150, bbox_inches='tight')
logger.info(f"Price and heatmap chart saved to {plot_filename}")
logger.info(f"Price and order book chart saved to {plot_filename}")
plt.show()
def _create_simple_price_chart(self):
@ -397,12 +466,12 @@ class COBStabilityTester:
plt.show()
async def main(symbol='ETHUSDT', duration_seconds=15):
async def main(symbol='ETHUSDT', duration_seconds=10):
"""Main function to run the COB test with configurable parameters.
Args:
symbol: Trading symbol (default: ETHUSDT)
duration_seconds: Test duration in seconds (default: 15)
duration_seconds: Test duration in seconds (default: 10)
"""
logger.info(f"Starting COB test with symbol={symbol}, duration={duration_seconds}s")
tester = COBStabilityTester(symbol=symbol, duration_seconds=duration_seconds)
@ -414,7 +483,7 @@ if __name__ == "__main__":
# Parse command line arguments
symbol = 'ETHUSDT' # Default
duration = 15 # Default
duration = 10 # Default
if len(sys.argv) > 1:
symbol = sys.argv[1]
@ -422,7 +491,7 @@ if __name__ == "__main__":
try:
duration = int(sys.argv[2])
except ValueError:
logger.warning(f"Invalid duration '{sys.argv[2]}', using default 15 seconds")
logger.warning(f"Invalid duration '{sys.argv[2]}', using default 10 seconds")
logger.info(f"Configuration: Symbol={symbol}, Duration={duration}s")
logger.info(f"Granularity: {'1 USD for ETH' if 'ETH' in symbol.upper() else '10 USD for BTC' if 'BTC' in symbol.upper() else '1 USD default'}")