data input audit, cleanup
This commit is contained in:
158
utils/audit_plotter.py
Normal file
158
utils/audit_plotter.py
Normal file
@ -0,0 +1,158 @@
|
||||
"""
|
||||
Audit Plotter
|
||||
|
||||
Create PNG snapshots of model input data at inference time:
|
||||
- Subplot 1: 1s candlesticks for recent window
|
||||
- Subplot 2: COB bucket volumes and imbalance near current price
|
||||
|
||||
Windows-safe, ASCII-only logging messages.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import math
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import List, Tuple
|
||||
|
||||
try:
|
||||
import matplotlib
|
||||
# Use a non-interactive backend suitable for headless servers
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
except Exception:
|
||||
matplotlib = None # type: ignore
|
||||
plt = None # type: ignore
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _extract_recent_ohlcv(base_data, max_bars: int = 120) -> Tuple[List[datetime], List[float], List[float], List[float], List[float]]:
|
||||
"""Return recent 1s OHLCV arrays (time, open, high, low, close). Falls back to 1m if needed."""
|
||||
series = base_data.ohlcv_1s if getattr(base_data, "ohlcv_1s", None) else []
|
||||
if not series or len(series) < 5:
|
||||
series = base_data.ohlcv_1m if getattr(base_data, "ohlcv_1m", None) else []
|
||||
|
||||
series = series[-max_bars:] if series else []
|
||||
times = [b.timestamp for b in series]
|
||||
opens = [float(b.open) for b in series]
|
||||
highs = [float(b.high) for b in series]
|
||||
lows = [float(b.low) for b in series]
|
||||
closes = [float(b.close) for b in series]
|
||||
return times, opens, highs, lows, closes
|
||||
|
||||
|
||||
def _extract_cob(base_data, max_buckets: int = 40):
|
||||
"""Return sorted price buckets and metrics from COBData."""
|
||||
cob = getattr(base_data, "cob_data", None)
|
||||
if cob is None or not getattr(cob, "price_buckets", None):
|
||||
return [], [], [], []
|
||||
|
||||
# Sort by price and clip
|
||||
prices = sorted(list(cob.price_buckets.keys()))[:max_buckets]
|
||||
bid_vol = []
|
||||
ask_vol = []
|
||||
imb = []
|
||||
for p in prices:
|
||||
bucket = cob.price_buckets.get(p, {})
|
||||
b = float(bucket.get("bid_volume", 0.0))
|
||||
a = float(bucket.get("ask_volume", 0.0))
|
||||
bid_vol.append(b)
|
||||
ask_vol.append(a)
|
||||
denom = (b + a) if (b + a) > 0 else 1.0
|
||||
imb.append((b - a) / denom)
|
||||
return prices, bid_vol, ask_vol, imb
|
||||
|
||||
|
||||
def save_inference_audit_image(base_data, model_name: str, symbol: str, out_root: str = "audit_inputs") -> str:
    """Save a PNG snapshot of the model input data.

    The image has two stacked panels: recent candles on top and COB bucket
    volumes with imbalance below. It is written to
    ``{out_root}/{YYYYMMDD}/{HHMMSS_ffffff}_{symbol}_{model_name}.png`` using
    UTC time. Characters that are path separators or invalid in Windows file
    names are replaced with ``-`` in the symbol and model name.

    Args:
        base_data: Input data object passed to ``_extract_recent_ohlcv`` and
            ``_extract_cob``; ``base_data.cob_data.current_price`` is read
            when present.
        model_name: Model identifier used in the file name.
        symbol: Trading symbol used in the file name and the candle title.
        out_root: Root directory for audit images; created if missing.

    Returns:
        The path of the saved image, or an empty string when matplotlib is
        unavailable or saving fails. Failures are logged, never raised.
    """
    if matplotlib is None or plt is None:
        logger.warning("matplotlib not available; skipping audit image")
        return ""

    # Local import: the file-level import only brings in ``datetime``.
    from datetime import timezone

    fig = None
    try:
        # One timestamp for both directory and file name, so the two can
        # never disagree around midnight.
        now = datetime.now(timezone.utc)
        out_dir = os.path.join(out_root, now.strftime("%Y%m%d"))
        os.makedirs(out_dir, exist_ok=True)

        # File name: {ts}_{symbol}_{model}.png
        ts_str = now.strftime("%H%M%S_%f")
        safe_symbol = _safe_filename_part(symbol)
        safe_model = _safe_filename_part(model_name)
        out_path = os.path.join(out_dir, f"{ts_str}_{safe_symbol}_{safe_model}.png")

        # Extract data
        times, opens, highs, lows, closes = _extract_recent_ohlcv(base_data)
        prices, bid_v, ask_v, imb = _extract_cob(base_data)
        cob = getattr(base_data, "cob_data", None)
        # "or 0.0" keeps a None current_price from aborting the whole image.
        current_price = float(getattr(cob, "current_price", 0.0) or 0.0)

        # Prepare figure
        fig = plt.figure(figsize=(12, 7), dpi=110)
        gs = fig.add_gridspec(2, 1, height_ratios=[3, 2])

        _draw_candle_panel(fig.add_subplot(gs[0, 0]), safe_symbol, times, opens, highs, lows, closes)
        _draw_cob_panel(fig.add_subplot(gs[1, 0]), prices, bid_v, ask_v, imb, current_price)

        fig.tight_layout()
        fig.savefig(out_path, bbox_inches="tight")
    except Exception as ex:
        logger.error("Failed to save audit image: %s", ex)
        return ""
    finally:
        # Close only the figure created here; closing "all" would also
        # destroy figures owned by other code in the process.
        if fig is not None:
            try:
                plt.close(fig)
            except Exception as close_ex:
                logger.debug("Failed to close audit figure: %s", close_ex)

    logger.info("Saved audit image: %s", out_path)
    return out_path


def _safe_filename_part(text: str) -> str:
    """Replace path separators, Windows-reserved and control characters with '-'."""
    unsafe = '<>:"/\\|?*'
    return "".join("-" if (ch in unsafe or ord(ch) < 32) else ch for ch in str(text))


def _draw_candle_panel(ax, title_symbol: str, times, opens, highs, lows, closes) -> None:
    """Draw candlesticks (wicks plus coloured bodies) on ``ax``, or a placeholder when empty."""
    if not times:
        ax.text(0.5, 0.5, "No OHLCV data", ha="center", va="center", transform=ax.transAxes)
        return

    x = list(range(len(times)))
    # High-low wicks
    ax.vlines(x, lows, highs, color="#444444", linewidth=1)

    # Bodies are bars anchored at min(open, close) with height |close - open|.
    bodies = [c - o for o, c in zip(opens, closes)]
    bottoms = [min(o, c) for o, c in zip(opens, closes)]
    colors = ["#00aa55" if body >= 0 else "#cc3333" for body in bodies]
    # Use a tiny minimum height in place of a zero-height body.
    heights = [abs(body) if abs(body) > 1e-9 else 1e-9 for body in bodies]
    ax.bar(x, heights, bottom=bottoms, color=colors, width=0.6, align="center", edgecolor="#222222", linewidth=0.5)

    ax.set_title(f"{title_symbol} Candles (recent)")
    ax.set_ylabel("Price")
    ax.grid(True, linestyle=":", linewidth=0.6, alpha=0.6)


def _draw_cob_panel(ax, prices, bid_vol, ask_vol, imbalance, current_price: float) -> None:
    """Draw bid/ask volumes and imbalance (secondary axis) on ``ax``, or a placeholder when empty."""
    if not prices:
        ax.text(0.5, 0.5, "No COB data", ha="center", va="center", transform=ax.transAxes)
        return

    # Show x as offsets around the current price when it is known.
    if current_price > 0:
        xvals = [p - current_price for p in prices]
        ax.axvline(0.0, color="#666666", linestyle="--", linewidth=1.0)
        ax.set_xlabel("Price offset")
    else:
        xvals = prices
        ax.set_xlabel("Price")

    ax.plot(xvals, bid_vol, label="bid_vol", color="#2c7fb8")
    ax.plot(xvals, ask_vol, label="ask_vol", color="#d95f0e")

    # Imbalance gets its own y-axis.
    ax_imb = ax.twinx()
    ax_imb.plot(xvals, imbalance, label="imbalance", color="#6a3d9a", linewidth=1.2)
    ax_imb.set_ylabel("Imbalance")
    ax.set_ylabel("Volume")
    ax.grid(True, linestyle=":", linewidth=0.6, alpha=0.6)

    # Merge the handles of both axes into a single legend.
    lines, labels = ax.get_legend_handles_labels()
    lines2, labels2 = ax_imb.get_legend_handles_labels()
    ax.legend(lines + lines2, labels + labels2, loc="upper right")
    ax.set_title("COB Buckets (recent)")
|
||||
|
||||
|
Reference in New Issue
Block a user