"""
Audit Plotter

Create PNG snapshots of model input data at inference time. The figure is a
3x3 grid:
- Top row: ETH 1s, 1m and 1h candlesticks
- Middle row: ETH 1d candlesticks, BTC 1s candlesticks, input data summary
- Bottom row (full width): COB bucket volumes and imbalance near current price

Windows-safe, ASCII-only logging messages.
"""
from __future__ import annotations

import os
import math
import logging
import re
from datetime import datetime, timezone
from typing import List, Tuple

try:
    import matplotlib
    # Use a non-interactive backend suitable for headless servers
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
except Exception:
    # Broad on purpose: plotting is optional, so any failure while importing or
    # selecting the backend just disables audit images.
    matplotlib = None  # type: ignore
    plt = None  # type: ignore

logger = logging.getLogger(__name__)

# Parallel (times, opens, highs, lows, closes) lists.
_OhlcvArrays = Tuple[List[datetime], List[float], List[float], List[float], List[float]]

# Characters outside this set are replaced when building file names, so values
# such as "ETH/USDT" or a model name containing ":" stay valid on Windows.
_UNSAFE_PATH_CHARS = re.compile(r"[^A-Za-z0-9._-]")

# Number of levels per side taken from raw bids/asks when building buckets.
_MAX_RAW_LEVELS = 50


def _safe_float(value, default: float = 0.0) -> float:
    """Return ``float(value)``, or *default* if *value* is None or not numeric."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _safe_path_part(text) -> str:
    """Return ``str(text)`` with every character outside ``[A-Za-z0-9._-]`` replaced by ``-``."""
    return _UNSAFE_PATH_CHARS.sub("-", str(text))


def _tail(series, max_items: int) -> list:
    """Return the last *max_items* items of *series* as a new list.

    Accepts any iterable (list, tuple, deque). None or empty input gives ``[]``.
    A non-positive *max_items* also gives ``[]``; the naive ``series[-0:]``
    would instead return the whole sequence.
    """
    if not series or max_items <= 0:
        return []
    return list(series)[-max_items:]


def _bars_to_arrays(bars) -> _OhlcvArrays:
    """Split bar objects into parallel (time, open, high, low, close) lists.

    Each bar must expose ``timestamp``, ``open``, ``high``, ``low`` and
    ``close`` attributes. Prices are converted with ``float``.
    """
    times = [bar.timestamp for bar in bars]
    opens = [float(bar.open) for bar in bars]
    highs = [float(bar.high) for bar in bars]
    lows = [float(bar.low) for bar in bars]
    closes = [float(bar.close) for bar in bars]
    return times, opens, highs, lows, closes


def _extract_recent_ohlcv(base_data, max_bars: int = 120) -> _OhlcvArrays:
    """Return recent OHLCV arrays (time, open, high, low, close).

    Uses ``base_data.ohlcv_1s`` when it holds at least 5 bars. Otherwise falls
    back to ``base_data.ohlcv_1m``, which may itself be empty. At most
    *max_bars* of the most recent bars are returned.
    """
    series = getattr(base_data, "ohlcv_1s", None) or []
    if len(series) < 5:
        series = getattr(base_data, "ohlcv_1m", None) or []
    return _bars_to_arrays(_tail(series, max_bars))


def _extract_timeframe_ohlcv(base_data, timeframe_attr: str, max_bars: int = 60) -> _OhlcvArrays:
    """Return OHLCV arrays for the bar list stored at ``base_data.<timeframe_attr>``.

    A missing or None attribute yields five empty lists. At most *max_bars* of
    the most recent bars are returned.
    """
    return _bars_to_arrays(_tail(getattr(base_data, timeframe_attr, None), max_bars))


def _plot_candlesticks(ax, times, opens, highs, lows, closes, title):
    """Draw a candlestick chart of the given OHLC lists on *ax*.

    Bars are placed at integer x positions (0..n-1), not at their timestamps,
    so gaps in the data are not visible. When *times* is empty, a
    "No <title> data" placeholder is drawn instead.
    """
    if not times:
        ax.text(0.5, 0.5, f"No {title} data", ha="center", va="center", transform=ax.transAxes)
        ax.set_title(title)
        return

    x = list(range(len(times)))
    # High-low wicks
    ax.vlines(x, lows, highs, color="#444444", linewidth=0.8)

    # Bodies: green when close >= open, red otherwise. A zero-height body
    # (doji) gets a tiny height so that a bar is still drawn.
    bodies = [c - o for o, c in zip(opens, closes)]
    bottoms = [min(o, c) for o, c in zip(opens, closes)]
    colors = ["#00aa55" if body >= 0 else "#cc3333" for body in bodies]
    heights = [max(abs(body), 1e-9) for body in bodies]
    ax.bar(x, heights, bottom=bottoms, color=colors, width=0.6, align="center",
           edgecolor="#222222", linewidth=0.3)

    ax.set_title(title, fontsize=10)
    ax.grid(True, linestyle=":", linewidth=0.4, alpha=0.6)

    # Annotate the most recent close in the top-left corner.
    if closes:
        ax.text(0.02, 0.98, f"${closes[-1]:.2f}", transform=ax.transAxes,
                verticalalignment='top', fontsize=8, fontweight='bold')


def _plot_data_summary(ax, base_data, symbol):
    """Write bar counts, COB size and indicator count of *base_data* as text on *ax*.

    *symbol* is accepted for call-site compatibility but is not used.
    Attributes that are missing or None are counted as empty.
    """
    ax.axis('off')

    stats = []
    for tf, attr in (("1s", "ohlcv_1s"), ("1m", "ohlcv_1m"), ("1h", "ohlcv_1h"), ("1d", "ohlcv_1d")):
        stats.append(f"ETH {tf}: {len(getattr(base_data, attr, None) or [])} bars")

    stats.append(f"BTC 1s: {len(getattr(base_data, 'btc_ohlcv_1s', None) or [])} bars")

    cob = getattr(base_data, "cob_data", None)
    if not cob:
        stats.append("COB: Missing")
    elif getattr(cob, "price_buckets", None):
        stats.append(f"COB buckets: {len(cob.price_buckets)}")
    elif hasattr(cob, "bids") and hasattr(cob, "asks"):
        bids = getattr(cob, "bids", None) or []
        asks = getattr(cob, "asks", None) or []
        stats.append(f"COB levels: {len(bids)}b/{len(asks)}a")
    else:
        stats.append("COB: No data")

    tech_indicators = getattr(base_data, "technical_indicators", None) or {}
    stats.append(f"Tech indicators: {len(tech_indicators)}")

    # Lay the lines out top-down in axes coordinates.
    y_pos = 0.9
    ax.text(0.05, y_pos, "Data Summary:", fontweight='bold', transform=ax.transAxes)
    y_pos -= 0.12
    for stat in stats:
        ax.text(0.05, y_pos, stat, fontsize=9, transform=ax.transAxes)
        y_pos -= 0.1

    ax.set_title("Input Data Stats", fontsize=10)


def _plot_cob_data(ax, prices, bid_v, ask_v, imb, current_price, symbol):
    """Plot bid/ask bucket volumes on *ax* and their imbalance on a twin y-axis.

    When *current_price* is positive, x values are offsets from it and a
    dashed line marks 0. Otherwise raw prices are used. An empty *prices*
    draws a "No COB data" placeholder.
    """
    if not prices:
        ax.text(0.5, 0.5, f"No COB data for {symbol}", ha="center", va="center", transform=ax.transAxes)
        ax.set_title("COB Data - No Data Available")
        return

    if current_price > 0:
        xvals = [p - current_price for p in prices]
        ax.axvline(0.0, color="#666666", linestyle="--", linewidth=1.0, alpha=0.7)
        ax.set_xlabel("Price offset from current")
    else:
        xvals = prices
        ax.set_xlabel("Price")

    ax.plot(xvals, bid_v, label="Bid Volume", color="#2c7fb8", linewidth=1.5)
    ax.plot(xvals, ask_v, label="Ask Volume", color="#d95f0e", linewidth=1.5)

    # Imbalance has a different scale from volume, so it gets its own y-axis.
    ax2 = ax.twinx()
    ax2.plot(xvals, imb, label="Imbalance", color="#6a3d9a", linewidth=2, alpha=0.8)
    ax2.set_ylabel("Imbalance", color="#6a3d9a")
    ax2.tick_params(axis='y', labelcolor="#6a3d9a")

    ax.set_ylabel("Volume")
    ax.grid(True, linestyle=":", linewidth=0.6, alpha=0.6)

    # One legend covering the lines from both axes.
    lines, labels = ax.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax.legend(lines + lines2, labels + labels2, loc="upper right")

    price_info = f" (${current_price:.2f})" if current_price > 0 else ""
    ax.set_title(f"COB Price Buckets - {symbol}{price_info}", fontsize=11)


def _select_bucket_prices(prices, current_price: float, max_buckets: int) -> list:
    """Pick up to *max_buckets* bucket prices and return them in ascending order.

    When *current_price* is positive, the buckets closest to it are kept, so
    the plotted window surrounds the current price. Otherwise the lowest
    prices are kept.
    """
    if max_buckets <= 0:
        return []
    ordered = sorted(prices)
    if current_price > 0 and len(ordered) > max_buckets:
        # Stable sort on distance: equal distances keep ascending price order.
        nearest = sorted(ordered, key=lambda p: abs(p - current_price))[:max_buckets]
        return sorted(nearest)
    return ordered[:max_buckets]


def _bucket_series(buckets, prices):
    """Return (bid_volumes, ask_volumes, imbalances) for *prices* looked up in *buckets*.

    Each bucket is a mapping with optional ``bid_volume``/``ask_volume`` keys;
    missing keys count as 0. Imbalance is ``(bid - ask) / (bid + ask)``. When
    that total is not positive, the raw difference ``bid - ask`` is used.
    """
    bid_vol, ask_vol, imb = [], [], []
    for p in prices:
        bucket = buckets.get(p, {})
        b = float(bucket.get("bid_volume", 0.0))
        a = float(bucket.get("ask_volume", 0.0))
        bid_vol.append(b)
        ask_vol.append(a)
        denom = (b + a) if (b + a) > 0 else 1.0
        imb.append((b - a) / denom)
    return bid_vol, ask_vol, imb


def _parse_level(level) -> Tuple[float, float]:
    """Return (price, size) from a {"price", "size"} dict or a [price, size, ...] list/tuple.

    Unrecognized shapes return (0.0, 0.0), which callers skip.
    """
    if isinstance(level, dict):
        return float(level.get("price", 0)), float(level.get("size", 0))
    if isinstance(level, (list, tuple)) and len(level) >= 2:
        return float(level[0]), float(level[1])
    return 0.0, 0.0


def _bucketize_levels(bids, asks, bucket_size: float) -> dict:
    """Aggregate raw book levels into buckets keyed by rounded price.

    Only the first ``_MAX_RAW_LEVELS`` entries per side are used. Each valid
    level (price > 0 and size > 0) adds ``size * price`` to the
    ``bid_volume`` or ``ask_volume`` of the bucket at
    ``round(price / bucket_size) * bucket_size``.
    """
    buckets = {}
    for side_key, levels in (("bid_volume", bids), ("ask_volume", asks)):
        for level in list(levels)[:_MAX_RAW_LEVELS]:
            price, size = _parse_level(level)
            if not (price > 0 and size > 0):  # also rejects NaN
                continue
            key = round(price / bucket_size) * bucket_size
            entry = buckets.setdefault(key, {"bid_volume": 0.0, "ask_volume": 0.0})
            entry[side_key] += size * price
    return buckets


def _extract_cob(base_data, max_buckets: int = 40):
    """Return ``(prices, bid_volumes, ask_volumes, imbalances)`` from ``base_data.cob_data``.

    Sources, in order of preference:

    1. ``cob_data.price_buckets``: a mapping of price -> {"bid_volume", "ask_volume"}.
    2. Raw ``cob_data.bids`` / ``cob_data.asks``. This path is used only when
       both are non-empty and ``cob_data.current_price`` is positive. The top
       levels of each side are aggregated into buckets of
       ``cob_data.bucket_size`` (default 1.0), with volume measured as
       ``size * price``.

    At most *max_buckets* prices are returned, in ascending order, preferring
    those nearest the current price. Four empty lists are returned when no
    usable COB data exists.
    """
    cob = getattr(base_data, "cob_data", None)
    if cob is None:
        return [], [], [], []

    current_price = _safe_float(getattr(cob, "current_price", 0.0))

    price_buckets = getattr(cob, "price_buckets", None)
    if price_buckets:
        prices = _select_bucket_prices(price_buckets.keys(), current_price, max_buckets)
        return (prices, *_bucket_series(price_buckets, prices))

    bids = getattr(cob, "bids", None) or []
    asks = getattr(cob, "asks", None) or []
    if bids and asks and current_price > 0:
        bucket_size = _safe_float(getattr(cob, "bucket_size", None), 1.0)
        if not bucket_size > 0:  # None, zero, negative or NaN -> default
            bucket_size = 1.0
        buckets = _bucketize_levels(bids, asks, bucket_size)
        if buckets:
            prices = _select_bucket_prices(buckets.keys(), current_price, max_buckets)
            return (prices, *_bucket_series(buckets, prices))

    return [], [], [], []


def _log_input_summary(base_data, model_name: str, symbol: str) -> None:
    """Log (mostly at DEBUG) which inputs are present before rendering.

    This runs on every audit call, so it stays at DEBUG to avoid flooding
    INFO logs. A missing COB is still reported as a warning.
    """
    logger.debug("Creating audit image for %s - %s", model_name, symbol)
    for label, attr in (("ETH 1s", "ohlcv_1s"), ("ETH 1m", "ohlcv_1m"), ("ETH 1h", "ohlcv_1h"),
                        ("ETH 1d", "ohlcv_1d"), ("BTC 1s", "btc_ohlcv_1s")):
        if hasattr(base_data, attr):
            logger.debug("%s data: %d bars", label, len(getattr(base_data, attr) or []))

    cob = getattr(base_data, "cob_data", None)
    if not cob:
        logger.warning("No COB data available for audit image")
        return
    logger.debug("COB data available: current_price=%s", getattr(cob, "current_price", "N/A"))
    if getattr(cob, "price_buckets", None):
        logger.debug("COB price buckets: %d buckets", len(cob.price_buckets))
    elif hasattr(cob, "bids") and hasattr(cob, "asks"):
        logger.debug("COB raw data: %d bids, %d asks",
                     len(getattr(cob, "bids", None) or []), len(getattr(cob, "asks", None) or []))
    else:
        logger.debug("COB data exists but no price_buckets or bids/asks found")


def save_inference_audit_image(base_data, model_name: str, symbol: str, out_root: str = "audit_inputs") -> str:
    """Render a PNG snapshot of the model inputs and return its path.

    The image is written to
    ``<out_root>/<YYYYMMDD>/<HHMMSS_ffffff>_<symbol>_<model>.png``. All parts of
    the path and the title come from a single UTC timestamp. The symbol and
    model name are sanitized to ``[A-Za-z0-9._-]`` for the file name.

    Args:
        base_data: Input bundle. The attributes read are ``ohlcv_1s``,
            ``ohlcv_1m``, ``ohlcv_1h``, ``ohlcv_1d``, ``btc_ohlcv_1s``,
            ``cob_data`` and ``technical_indicators``; each may be missing.
        model_name: Shown in the title and the file name.
        symbol: Shown in the title and the file name.
        out_root: Root directory under which dated folders are created.

    Returns:
        The written file path, or ``""`` if matplotlib is unavailable or
        rendering fails. Failures are logged, not raised.
    """
    if matplotlib is None or plt is None:
        logger.warning("matplotlib not available; skipping audit image")
        return ""

    fig = None
    try:
        _log_input_summary(base_data, model_name, symbol)

        # Take one timestamp so the folder, file name and title cannot disagree,
        # e.g. across a midnight rollover.
        now = datetime.now(timezone.utc)
        out_dir = os.path.join(out_root, now.strftime("%Y%m%d"))
        os.makedirs(out_dir, exist_ok=True)

        # File name: {ts}_{symbol}_{model}.png (ASCII-only, path-safe)
        fname = f"{now.strftime('%H%M%S_%f')}_{_safe_path_part(symbol)}_{_safe_path_part(model_name)}.png"
        out_path = os.path.join(out_dir, fname)

        prices, bid_v, ask_v, imb = _extract_cob(base_data)
        current_price = _safe_float(getattr(getattr(base_data, "cob_data", None), "current_price", 0.0))

        fig = plt.figure(figsize=(16, 12), dpi=110)
        gs = fig.add_gridspec(3, 3, height_ratios=[2, 2, 1.5], width_ratios=[1, 1, 1])

        # (grid cell, bar attribute, bars to show, panel title)
        candle_panels = (
            ((0, 0), "ohlcv_1s", 60, "ETH 1s (last 60)"),
            ((0, 1), "ohlcv_1m", 60, "ETH 1m (last 60)"),
            ((0, 2), "ohlcv_1h", 24, "ETH 1h (last 24)"),
            ((1, 0), "ohlcv_1d", 30, "ETH 1d (last 30)"),
            ((1, 1), "btc_ohlcv_1s", 60, "BTC 1s (last 60)"),
        )
        for (row, col), attr, max_bars, title in candle_panels:
            ax = fig.add_subplot(gs[row, col])
            _plot_candlesticks(ax, *_extract_timeframe_ohlcv(base_data, attr, max_bars), title)

        _plot_data_summary(fig.add_subplot(gs[1, 2]), base_data, symbol)

        # COB panel spans the whole bottom row.
        _plot_cob_data(fig.add_subplot(gs[2, :]), prices, bid_v, ask_v, imb, current_price, symbol)

        display_symbol = symbol.replace("/", "-")
        fig.suptitle(f"{model_name} - {display_symbol} - {now.strftime('%H:%M:%S')}",
                     fontsize=14, fontweight='bold')
        fig.tight_layout()
        fig.savefig(out_path, bbox_inches="tight")

        logger.info(f"Saved comprehensive audit image: {out_path}")
        return out_path
    except Exception as ex:
        # Best-effort auditing: never let plotting break inference.
        logger.error("Failed to save audit image: %s", ex)
        logger.debug("Audit image failure details", exc_info=True)
        return ""
    finally:
        # Close only our own figure; closing "all" would tear down figures
        # owned by other code in the same process.
        if fig is not None:
            try:
                plt.close(fig)
            except Exception:
                logger.debug("Failed to close audit figure", exc_info=True)