orchestrator uses BaseDataInput
This commit is contained in:
@ -161,6 +161,15 @@ class DataProvider:
|
|||||||
# Enhanced WebSocket integration
|
# Enhanced WebSocket integration
|
||||||
self.enhanced_cob_websocket: Optional[EnhancedCOBWebSocket] = None
|
self.enhanced_cob_websocket: Optional[EnhancedCOBWebSocket] = None
|
||||||
self.websocket_tasks = {}
|
self.websocket_tasks = {}
|
||||||
|
|
||||||
|
# COB collection state guard to prevent duplicate starts
|
||||||
|
self._cob_started: bool = False
|
||||||
|
|
||||||
|
# Ensure COB collection is started so BaseDataInput includes real order book data
|
||||||
|
try:
|
||||||
|
self.start_cob_collection()
|
||||||
|
except Exception as _cob_init_ex:
|
||||||
|
logger.error(f"Failed to start COB collection at init: {_cob_init_ex}")
|
||||||
self.is_streaming = False
|
self.is_streaming = False
|
||||||
self.data_lock = Lock()
|
self.data_lock = Lock()
|
||||||
|
|
||||||
@ -1133,7 +1142,7 @@ class DataProvider:
|
|||||||
recent_ticks = self.get_cob_raw_ticks(symbol, count=limit * 10) # Get more ticks than needed
|
recent_ticks = self.get_cob_raw_ticks(symbol, count=limit * 10) # Get more ticks than needed
|
||||||
|
|
||||||
if not recent_ticks:
|
if not recent_ticks:
|
||||||
logger.warning(f"No tick data available for {symbol}, cannot generate 1s candles")
|
logger.debug(f"No tick data available for {symbol}, cannot generate 1s candles")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Group ticks by second and create OHLCV candles
|
# Group ticks by second and create OHLCV candles
|
||||||
@ -1156,8 +1165,7 @@ class DataProvider:
|
|||||||
bid_vol = stats.get('bid_volume', 0) or 0
|
bid_vol = stats.get('bid_volume', 0) or 0
|
||||||
ask_vol = stats.get('ask_volume', 0) or 0
|
ask_vol = stats.get('ask_volume', 0) or 0
|
||||||
volume = float(bid_vol) + float(ask_vol)
|
volume = float(bid_vol) + float(ask_vol)
|
||||||
if volume == 0:
|
# Do not create synthetic volume; keep zero if not available
|
||||||
volume = 1.0 # Minimal placeholder to avoid zero-volume bars
|
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -1212,7 +1220,7 @@ class DataProvider:
|
|||||||
candles.append(current_candle)
|
candles.append(current_candle)
|
||||||
|
|
||||||
if not candles:
|
if not candles:
|
||||||
logger.warning(f"No valid candles generated for {symbol}")
|
logger.debug(f"No valid candles generated for {symbol}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Convert to DataFrame (timestamps remain UTC tz-aware)
|
# Convert to DataFrame (timestamps remain UTC tz-aware)
|
||||||
@ -1251,7 +1259,7 @@ class DataProvider:
|
|||||||
logger.info(f"Successfully generated 1s candles from WebSocket ticks for {symbol}")
|
logger.info(f"Successfully generated 1s candles from WebSocket ticks for {symbol}")
|
||||||
return generated_df
|
return generated_df
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Failed to generate 1s candles from ticks for {symbol}, trying Binance API")
|
logger.info(f"Could not generate 1s candles from ticks for {symbol}; trying Binance API")
|
||||||
|
|
||||||
# Convert symbol format
|
# Convert symbol format
|
||||||
binance_symbol = symbol.replace('/', '').upper()
|
binance_symbol = symbol.replace('/', '').upper()
|
||||||
@ -2673,7 +2681,16 @@ class DataProvider:
|
|||||||
def get_latest_cob_data(self, symbol: str) -> Optional[Dict]:
|
def get_latest_cob_data(self, symbol: str) -> Optional[Dict]:
|
||||||
"""Get latest COB data from Enhanced WebSocket"""
|
"""Get latest COB data from Enhanced WebSocket"""
|
||||||
try:
|
try:
|
||||||
return self.cob_websocket_data.get(symbol)
|
# First try the websocket data cache
|
||||||
|
if symbol in self.cob_websocket_data and self.cob_websocket_data[symbol]:
|
||||||
|
return self.cob_websocket_data[symbol]
|
||||||
|
|
||||||
|
# Fallback to raw ticks
|
||||||
|
if symbol in self.cob_raw_ticks and len(self.cob_raw_ticks[symbol]) > 0:
|
||||||
|
return self.cob_raw_ticks[symbol][-1] # Get latest raw tick
|
||||||
|
|
||||||
|
# No COB data available
|
||||||
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting latest COB data for {symbol}: {e}")
|
logger.error(f"Error getting latest COB data for {symbol}: {e}")
|
||||||
return None
|
return None
|
||||||
@ -4210,18 +4227,23 @@ class DataProvider:
|
|||||||
Start enhanced COB data collection with WebSocket and raw tick aggregation
|
Start enhanced COB data collection with WebSocket and raw tick aggregation
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
# Guard against duplicate starts
|
||||||
|
if getattr(self, "_cob_started", False):
|
||||||
|
return
|
||||||
# Initialize COB WebSocket system
|
# Initialize COB WebSocket system
|
||||||
self._initialize_enhanced_cob_websocket()
|
self._initialize_enhanced_cob_websocket()
|
||||||
|
|
||||||
# Start aggregation system
|
# Start aggregation system
|
||||||
self._start_cob_tick_aggregation()
|
self._start_cob_tick_aggregation()
|
||||||
|
|
||||||
|
self._cob_started = True
|
||||||
logger.info("Enhanced COB data collection started with WebSocket and tick aggregation")
|
logger.info("Enhanced COB data collection started with WebSocket and tick aggregation")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error starting enhanced COB collection: {e}")
|
logger.error(f"Error starting enhanced COB collection: {e}")
|
||||||
# Fallback to REST-only collection
|
# Fallback to REST-only collection
|
||||||
self._start_rest_only_cob_collection()
|
self._start_rest_only_cob_collection()
|
||||||
|
self._cob_started = True
|
||||||
|
|
||||||
def _initialize_enhanced_cob_websocket(self):
|
def _initialize_enhanced_cob_websocket(self):
|
||||||
"""Initialize the enhanced COB WebSocket system"""
|
"""Initialize the enhanced COB WebSocket system"""
|
||||||
|
@ -2315,7 +2315,11 @@ class TradingOrchestrator:
|
|||||||
current_time = datetime.now()
|
current_time = datetime.now()
|
||||||
|
|
||||||
# Get the standard model input data once for all models
|
# Get the standard model input data once for all models
|
||||||
base_data = self.data_provider.build_base_data_input(symbol)
|
# Prefer standardized input if available; fallback to legacy builder
|
||||||
|
if hasattr(self.data_provider, "get_base_data_input"):
|
||||||
|
base_data = self.data_provider.get_base_data_input(symbol)
|
||||||
|
else:
|
||||||
|
base_data = self.data_provider.build_base_data_input(symbol)
|
||||||
if not base_data:
|
if not base_data:
|
||||||
logger.warning(f"Cannot build BaseDataInput for predictions: {symbol}")
|
logger.warning(f"Cannot build BaseDataInput for predictions: {symbol}")
|
||||||
return predictions
|
return predictions
|
||||||
|
@ -6,7 +6,7 @@ The system identifies swing highs and swing lows, then uses these pivot points t
|
|||||||
higher-level trends recursively.
|
higher-level trends recursively.
|
||||||
|
|
||||||
Key Features:
|
Key Features:
|
||||||
- Recursive pivot point calculation (5 levels)
|
- Recursive pivot point calculation (5 levels). first level is 1m OHLCV data, second level uses the first level as "candles", third level uses the second level as "candles", etc.
|
||||||
- Swing high/low identification
|
- Swing high/low identification
|
||||||
- Trend direction and strength analysis
|
- Trend direction and strength analysis
|
||||||
- Integration with CNN model for pivot prediction
|
- Integration with CNN model for pivot prediction
|
||||||
|
@ -44,35 +44,253 @@ def _extract_recent_ohlcv(base_data, max_bars: int = 120) -> Tuple[List[datetime
|
|||||||
return times, opens, highs, lows, closes
|
return times, opens, highs, lows, closes
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_timeframe_ohlcv(base_data, timeframe_attr: str, max_bars: int = 60) -> Tuple[List[datetime], List[float], List[float], List[float], List[float]]:
|
||||||
|
"""Extract OHLCV data for a specific timeframe attribute."""
|
||||||
|
series = getattr(base_data, timeframe_attr, []) if hasattr(base_data, timeframe_attr) else []
|
||||||
|
series = series[-max_bars:] if series else []
|
||||||
|
|
||||||
|
if not series:
|
||||||
|
return [], [], [], [], []
|
||||||
|
|
||||||
|
times = [b.timestamp for b in series]
|
||||||
|
opens = [float(b.open) for b in series]
|
||||||
|
highs = [float(b.high) for b in series]
|
||||||
|
lows = [float(b.low) for b in series]
|
||||||
|
closes = [float(b.close) for b in series]
|
||||||
|
return times, opens, highs, lows, closes
|
||||||
|
|
||||||
|
|
||||||
|
def _plot_candlesticks(ax, times, opens, highs, lows, closes, title):
|
||||||
|
"""Plot candlestick chart on given axis."""
|
||||||
|
if not times:
|
||||||
|
ax.text(0.5, 0.5, f"No {title} data", ha="center", va="center", transform=ax.transAxes)
|
||||||
|
ax.set_title(title)
|
||||||
|
return
|
||||||
|
|
||||||
|
x = list(range(len(times)))
|
||||||
|
# Plot high-low wicks
|
||||||
|
ax.vlines(x, lows, highs, color="#444444", linewidth=0.8)
|
||||||
|
|
||||||
|
# Plot body as rectangles
|
||||||
|
bodies = [closes[i] - opens[i] for i in range(len(opens))]
|
||||||
|
bottoms = [min(opens[i], closes[i]) for i in range(len(opens))]
|
||||||
|
colors = ["#00aa55" if bodies[i] >= 0 else "#cc3333" for i in range(len(bodies))]
|
||||||
|
heights = [abs(bodies[i]) if abs(bodies[i]) > 1e-9 else 1e-9 for i in range(len(bodies))]
|
||||||
|
|
||||||
|
ax.bar(x, heights, bottom=bottoms, color=colors, width=0.6, align="center",
|
||||||
|
edgecolor="#222222", linewidth=0.3)
|
||||||
|
|
||||||
|
ax.set_title(title, fontsize=10)
|
||||||
|
ax.grid(True, linestyle=":", linewidth=0.4, alpha=0.6)
|
||||||
|
|
||||||
|
# Show recent price
|
||||||
|
if closes:
|
||||||
|
ax.text(0.02, 0.98, f"${closes[-1]:.2f}", transform=ax.transAxes,
|
||||||
|
verticalalignment='top', fontsize=8, fontweight='bold')
|
||||||
|
|
||||||
|
|
||||||
|
def _plot_data_summary(ax, base_data, symbol):
|
||||||
|
"""Plot data summary statistics."""
|
||||||
|
ax.axis('off')
|
||||||
|
|
||||||
|
# Collect data statistics
|
||||||
|
stats = []
|
||||||
|
|
||||||
|
# ETH timeframes
|
||||||
|
for tf, attr in [("1s", "ohlcv_1s"), ("1m", "ohlcv_1m"), ("1h", "ohlcv_1h"), ("1d", "ohlcv_1d")]:
|
||||||
|
data = getattr(base_data, attr, []) if hasattr(base_data, attr) else []
|
||||||
|
stats.append(f"ETH {tf}: {len(data)} bars")
|
||||||
|
|
||||||
|
# BTC data
|
||||||
|
btc_data = getattr(base_data, "btc_ohlcv_1s", []) if hasattr(base_data, "btc_ohlcv_1s") else []
|
||||||
|
stats.append(f"BTC 1s: {len(btc_data)} bars")
|
||||||
|
|
||||||
|
# COB data
|
||||||
|
cob = getattr(base_data, "cob_data", None)
|
||||||
|
if cob:
|
||||||
|
if hasattr(cob, "price_buckets") and cob.price_buckets:
|
||||||
|
stats.append(f"COB buckets: {len(cob.price_buckets)}")
|
||||||
|
elif hasattr(cob, "bids") and hasattr(cob, "asks"):
|
||||||
|
bids = getattr(cob, "bids", [])
|
||||||
|
asks = getattr(cob, "asks", [])
|
||||||
|
stats.append(f"COB levels: {len(bids)}b/{len(asks)}a")
|
||||||
|
else:
|
||||||
|
stats.append("COB: No data")
|
||||||
|
else:
|
||||||
|
stats.append("COB: Missing")
|
||||||
|
|
||||||
|
# Technical indicators
|
||||||
|
tech_indicators = getattr(base_data, "technical_indicators", {}) if hasattr(base_data, "technical_indicators") else {}
|
||||||
|
stats.append(f"Tech indicators: {len(tech_indicators)}")
|
||||||
|
|
||||||
|
# Display stats
|
||||||
|
y_pos = 0.9
|
||||||
|
ax.text(0.05, y_pos, "Data Summary:", fontweight='bold', transform=ax.transAxes)
|
||||||
|
y_pos -= 0.12
|
||||||
|
|
||||||
|
for stat in stats:
|
||||||
|
ax.text(0.05, y_pos, stat, fontsize=9, transform=ax.transAxes)
|
||||||
|
y_pos -= 0.1
|
||||||
|
|
||||||
|
ax.set_title("Input Data Stats", fontsize=10)
|
||||||
|
|
||||||
|
|
||||||
|
def _plot_cob_data(ax, prices, bid_v, ask_v, imb, current_price, symbol):
|
||||||
|
"""Plot COB data with bid/ask volumes and imbalance."""
|
||||||
|
if not prices:
|
||||||
|
ax.text(0.5, 0.5, f"No COB data for {symbol}", ha="center", va="center")
|
||||||
|
ax.set_title("COB Data - No Data Available")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Normalize x as offsets around current price if available
|
||||||
|
if current_price > 0:
|
||||||
|
xvals = [p - current_price for p in prices]
|
||||||
|
ax.axvline(0.0, color="#666666", linestyle="--", linewidth=1.0, alpha=0.7)
|
||||||
|
ax.set_xlabel("Price offset from current")
|
||||||
|
else:
|
||||||
|
xvals = prices
|
||||||
|
ax.set_xlabel("Price")
|
||||||
|
|
||||||
|
# Plot bid/ask volumes
|
||||||
|
ax.plot(xvals, bid_v, label="Bid Volume", color="#2c7fb8", linewidth=1.5)
|
||||||
|
ax.plot(xvals, ask_v, label="Ask Volume", color="#d95f0e", linewidth=1.5)
|
||||||
|
|
||||||
|
# Secondary axis for imbalance
|
||||||
|
ax2 = ax.twinx()
|
||||||
|
ax2.plot(xvals, imb, label="Imbalance", color="#6a3d9a", linewidth=2, alpha=0.8)
|
||||||
|
ax2.set_ylabel("Imbalance", color="#6a3d9a")
|
||||||
|
ax2.tick_params(axis='y', labelcolor="#6a3d9a")
|
||||||
|
|
||||||
|
ax.set_ylabel("Volume")
|
||||||
|
ax.grid(True, linestyle=":", linewidth=0.6, alpha=0.6)
|
||||||
|
|
||||||
|
# Combined legend
|
||||||
|
lines, labels = ax.get_legend_handles_labels()
|
||||||
|
lines2, labels2 = ax2.get_legend_handles_labels()
|
||||||
|
ax.legend(lines + lines2, labels + labels2, loc="upper right")
|
||||||
|
|
||||||
|
# Title with current price info
|
||||||
|
price_info = f" (${current_price:.2f})" if current_price > 0 else ""
|
||||||
|
ax.set_title(f"COB Price Buckets - {symbol}{price_info}", fontsize=11)
|
||||||
|
|
||||||
|
|
||||||
def _extract_cob(base_data, max_buckets: int = 40):
|
def _extract_cob(base_data, max_buckets: int = 40):
|
||||||
"""Return sorted price buckets and metrics from COBData."""
|
"""Return sorted price buckets and metrics from COBData."""
|
||||||
cob = getattr(base_data, "cob_data", None)
|
cob = getattr(base_data, "cob_data", None)
|
||||||
if cob is None or not getattr(cob, "price_buckets", None):
|
|
||||||
return [], [], [], []
|
# Try to get price buckets from COBData object
|
||||||
|
if cob is not None and hasattr(cob, "price_buckets") and cob.price_buckets:
|
||||||
# Sort by price and clip
|
# Sort by price and clip
|
||||||
prices = sorted(list(cob.price_buckets.keys()))[:max_buckets]
|
prices = sorted(list(cob.price_buckets.keys()))[:max_buckets]
|
||||||
bid_vol = []
|
bid_vol = []
|
||||||
ask_vol = []
|
ask_vol = []
|
||||||
imb = []
|
imb = []
|
||||||
for p in prices:
|
for p in prices:
|
||||||
bucket = cob.price_buckets.get(p, {})
|
bucket = cob.price_buckets.get(p, {})
|
||||||
b = float(bucket.get("bid_volume", 0.0))
|
b = float(bucket.get("bid_volume", 0.0))
|
||||||
a = float(bucket.get("ask_volume", 0.0))
|
a = float(bucket.get("ask_volume", 0.0))
|
||||||
bid_vol.append(b)
|
bid_vol.append(b)
|
||||||
ask_vol.append(a)
|
ask_vol.append(a)
|
||||||
denom = (b + a) if (b + a) > 0 else 1.0
|
denom = (b + a) if (b + a) > 0 else 1.0
|
||||||
imb.append((b - a) / denom)
|
imb.append((b - a) / denom)
|
||||||
return prices, bid_vol, ask_vol, imb
|
return prices, bid_vol, ask_vol, imb
|
||||||
|
|
||||||
|
# Fallback: try to extract from raw bids/asks if available
|
||||||
|
if cob is not None:
|
||||||
|
# Check if we have raw bids/asks data
|
||||||
|
bids = getattr(cob, "bids", []) if hasattr(cob, "bids") else []
|
||||||
|
asks = getattr(cob, "asks", []) if hasattr(cob, "asks") else []
|
||||||
|
current_price = getattr(cob, "current_price", 0.0) if hasattr(cob, "current_price") else 0.0
|
||||||
|
|
||||||
|
if bids and asks and current_price > 0:
|
||||||
|
# Create price buckets from raw data
|
||||||
|
bucket_size = 1.0 if hasattr(cob, "bucket_size") and cob.bucket_size else 1.0
|
||||||
|
buckets = {}
|
||||||
|
|
||||||
|
# Process bids
|
||||||
|
for bid in bids[:50]: # Top 50 levels
|
||||||
|
if isinstance(bid, dict):
|
||||||
|
price = float(bid.get("price", 0))
|
||||||
|
size = float(bid.get("size", 0))
|
||||||
|
elif isinstance(bid, list) and len(bid) >= 2:
|
||||||
|
price = float(bid[0])
|
||||||
|
size = float(bid[1])
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if price > 0 and size > 0:
|
||||||
|
bucket_price = round(price / bucket_size) * bucket_size
|
||||||
|
if bucket_price not in buckets:
|
||||||
|
buckets[bucket_price] = {"bid_volume": 0.0, "ask_volume": 0.0}
|
||||||
|
buckets[bucket_price]["bid_volume"] += size * price
|
||||||
|
|
||||||
|
# Process asks
|
||||||
|
for ask in asks[:50]: # Top 50 levels
|
||||||
|
if isinstance(ask, dict):
|
||||||
|
price = float(ask.get("price", 0))
|
||||||
|
size = float(ask.get("size", 0))
|
||||||
|
elif isinstance(ask, list) and len(ask) >= 2:
|
||||||
|
price = float(ask[0])
|
||||||
|
size = float(ask[1])
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if price > 0 and size > 0:
|
||||||
|
bucket_price = round(price / bucket_size) * bucket_size
|
||||||
|
if bucket_price not in buckets:
|
||||||
|
buckets[bucket_price] = {"bid_volume": 0.0, "ask_volume": 0.0}
|
||||||
|
buckets[bucket_price]["ask_volume"] += size * price
|
||||||
|
|
||||||
|
if buckets:
|
||||||
|
# Sort by price and clip
|
||||||
|
prices = sorted(list(buckets.keys()))[:max_buckets]
|
||||||
|
bid_vol = []
|
||||||
|
ask_vol = []
|
||||||
|
imb = []
|
||||||
|
for p in prices:
|
||||||
|
bucket = buckets.get(p, {})
|
||||||
|
b = float(bucket.get("bid_volume", 0.0))
|
||||||
|
a = float(bucket.get("ask_volume", 0.0))
|
||||||
|
bid_vol.append(b)
|
||||||
|
ask_vol.append(a)
|
||||||
|
denom = (b + a) if (b + a) > 0 else 1.0
|
||||||
|
imb.append((b - a) / denom)
|
||||||
|
return prices, bid_vol, ask_vol, imb
|
||||||
|
|
||||||
|
# No COB data available
|
||||||
|
return [], [], [], []
|
||||||
|
|
||||||
|
|
||||||
def save_inference_audit_image(base_data, model_name: str, symbol: str, out_root: str = "audit_inputs") -> str:
|
def save_inference_audit_image(base_data, model_name: str, symbol: str, out_root: str = "audit_inputs") -> str:
|
||||||
"""Save a PNG snapshot of input data. Returns path if saved, else empty string."""
|
"""Save a comprehensive PNG snapshot of input data with all timeframes and COB data."""
|
||||||
if matplotlib is None or plt is None:
|
if matplotlib is None or plt is None:
|
||||||
logger.warning("matplotlib not available; skipping audit image")
|
logger.warning("matplotlib not available; skipping audit image")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# Debug: Log what data we have
|
||||||
|
logger.info(f"Creating audit image for {model_name} - {symbol}")
|
||||||
|
if hasattr(base_data, 'ohlcv_1s'):
|
||||||
|
logger.info(f"ETH 1s data: {len(base_data.ohlcv_1s)} bars")
|
||||||
|
if hasattr(base_data, 'ohlcv_1m'):
|
||||||
|
logger.info(f"ETH 1m data: {len(base_data.ohlcv_1m)} bars")
|
||||||
|
if hasattr(base_data, 'ohlcv_1h'):
|
||||||
|
logger.info(f"ETH 1h data: {len(base_data.ohlcv_1h)} bars")
|
||||||
|
if hasattr(base_data, 'ohlcv_1d'):
|
||||||
|
logger.info(f"ETH 1d data: {len(base_data.ohlcv_1d)} bars")
|
||||||
|
if hasattr(base_data, 'btc_ohlcv_1s'):
|
||||||
|
logger.info(f"BTC 1s data: {len(base_data.btc_ohlcv_1s)} bars")
|
||||||
|
if hasattr(base_data, 'cob_data') and base_data.cob_data:
|
||||||
|
cob = base_data.cob_data
|
||||||
|
logger.info(f"COB data available: current_price={getattr(cob, 'current_price', 'N/A')}")
|
||||||
|
if hasattr(cob, 'price_buckets') and cob.price_buckets:
|
||||||
|
logger.info(f"COB price buckets: {len(cob.price_buckets)} buckets")
|
||||||
|
elif hasattr(cob, 'bids') and hasattr(cob, 'asks'):
|
||||||
|
logger.info(f"COB raw data: {len(getattr(cob, 'bids', []))} bids, {len(getattr(cob, 'asks', []))} asks")
|
||||||
|
else:
|
||||||
|
logger.info("COB data exists but no price_buckets or bids/asks found")
|
||||||
|
else:
|
||||||
|
logger.warning("No COB data available for audit image")
|
||||||
# Ensure output directory structure
|
# Ensure output directory structure
|
||||||
day_dir = datetime.utcnow().strftime("%Y%m%d")
|
day_dir = datetime.utcnow().strftime("%Y%m%d")
|
||||||
out_dir = os.path.join(out_root, day_dir)
|
out_dir = os.path.join(out_root, day_dir)
|
||||||
@ -84,68 +302,58 @@ def save_inference_audit_image(base_data, model_name: str, symbol: str, out_root
|
|||||||
fname = f"{ts_str}_{safe_symbol}_{model_name}.png"
|
fname = f"{ts_str}_{safe_symbol}_{model_name}.png"
|
||||||
out_path = os.path.join(out_dir, fname)
|
out_path = os.path.join(out_dir, fname)
|
||||||
|
|
||||||
# Extract data
|
# Extract all timeframe data
|
||||||
times, o, h, l, c = _extract_recent_ohlcv(base_data)
|
eth_1s_times, eth_1s_o, eth_1s_h, eth_1s_l, eth_1s_c = _extract_timeframe_ohlcv(base_data, "ohlcv_1s", 60)
|
||||||
|
eth_1m_times, eth_1m_o, eth_1m_h, eth_1m_l, eth_1m_c = _extract_timeframe_ohlcv(base_data, "ohlcv_1m", 60)
|
||||||
|
eth_1h_times, eth_1h_o, eth_1h_h, eth_1h_l, eth_1h_c = _extract_timeframe_ohlcv(base_data, "ohlcv_1h", 24)
|
||||||
|
eth_1d_times, eth_1d_o, eth_1d_h, eth_1d_l, eth_1d_c = _extract_timeframe_ohlcv(base_data, "ohlcv_1d", 30)
|
||||||
|
btc_1s_times, btc_1s_o, btc_1s_h, btc_1s_l, btc_1s_c = _extract_timeframe_ohlcv(base_data, "btc_ohlcv_1s", 60)
|
||||||
|
|
||||||
|
# Extract COB data
|
||||||
prices, bid_v, ask_v, imb = _extract_cob(base_data)
|
prices, bid_v, ask_v, imb = _extract_cob(base_data)
|
||||||
current_price = float(getattr(getattr(base_data, "cob_data", None), "current_price", 0.0))
|
current_price = float(getattr(getattr(base_data, "cob_data", None), "current_price", 0.0))
|
||||||
|
|
||||||
# Prepare figure
|
# Create comprehensive figure with multiple subplots
|
||||||
fig = plt.figure(figsize=(12, 7), dpi=110)
|
fig = plt.figure(figsize=(16, 12), dpi=110)
|
||||||
gs = fig.add_gridspec(2, 1, height_ratios=[3, 2])
|
gs = fig.add_gridspec(3, 3, height_ratios=[2, 2, 1.5], width_ratios=[1, 1, 1])
|
||||||
|
|
||||||
# Candlestick subplot
|
# ETH 1s data (top left)
|
||||||
ax1 = fig.add_subplot(gs[0, 0])
|
ax1 = fig.add_subplot(gs[0, 0])
|
||||||
if times:
|
_plot_candlesticks(ax1, eth_1s_times, eth_1s_o, eth_1s_h, eth_1s_l, eth_1s_c, f"ETH 1s (last 60)")
|
||||||
x = list(range(len(times)))
|
|
||||||
# Plot high-low wicks
|
|
||||||
ax1.vlines(x, l, h, color="#444444", linewidth=1)
|
|
||||||
# Plot body as rectangle via bar with bottom=min(open, close) and height=abs(diff)
|
|
||||||
bodies = [c[i] - o[i] for i in range(len(o))]
|
|
||||||
bottoms = [min(o[i], c[i]) for i in range(len(o))]
|
|
||||||
colors = ["#00aa55" if bodies[i] >= 0 else "#cc3333" for i in range(len(bodies))]
|
|
||||||
heights = [abs(bodies[i]) if abs(bodies[i]) > 1e-9 else 1e-9 for i in range(len(bodies))]
|
|
||||||
ax1.bar(x, heights, bottom=bottoms, color=colors, width=0.6, align="center", edgecolor="#222222", linewidth=0.5)
|
|
||||||
# Labels
|
|
||||||
ax1.set_title(f"{safe_symbol} Candles (recent)")
|
|
||||||
ax1.set_ylabel("Price")
|
|
||||||
ax1.grid(True, linestyle=":", linewidth=0.6, alpha=0.6)
|
|
||||||
else:
|
|
||||||
ax1.text(0.5, 0.5, "No OHLCV data", ha="center", va="center")
|
|
||||||
|
|
||||||
# COB subplot
|
# ETH 1m data (top middle)
|
||||||
ax2 = fig.add_subplot(gs[1, 0])
|
ax2 = fig.add_subplot(gs[0, 1])
|
||||||
if prices:
|
_plot_candlesticks(ax2, eth_1m_times, eth_1m_o, eth_1m_h, eth_1m_l, eth_1m_c, f"ETH 1m (last 60)")
|
||||||
# Normalize x as offsets around current price if available
|
|
||||||
if current_price > 0:
|
|
||||||
xvals = [p - current_price for p in prices]
|
|
||||||
ax2.axvline(0.0, color="#666666", linestyle="--", linewidth=1.0)
|
|
||||||
ax2.set_xlabel("Price offset")
|
|
||||||
else:
|
|
||||||
xvals = prices
|
|
||||||
ax2.set_xlabel("Price")
|
|
||||||
|
|
||||||
# Plot bid/ask volumes
|
# ETH 1h data (top right)
|
||||||
ax2.plot(xvals, bid_v, label="bid_vol", color="#2c7fb8")
|
ax3 = fig.add_subplot(gs[0, 2])
|
||||||
ax2.plot(xvals, ask_v, label="ask_vol", color="#d95f0e")
|
_plot_candlesticks(ax3, eth_1h_times, eth_1h_o, eth_1h_h, eth_1h_l, eth_1h_c, f"ETH 1h (last 24)")
|
||||||
# Secondary axis for imbalance
|
|
||||||
ax2b = ax2.twinx()
|
# ETH 1d data (middle left)
|
||||||
ax2b.plot(xvals, imb, label="imbalance", color="#6a3d9a", linewidth=1.2)
|
ax4 = fig.add_subplot(gs[1, 0])
|
||||||
ax2b.set_ylabel("Imbalance")
|
_plot_candlesticks(ax4, eth_1d_times, eth_1d_o, eth_1d_h, eth_1d_l, eth_1d_c, f"ETH 1d (last 30)")
|
||||||
ax2.set_ylabel("Volume")
|
|
||||||
ax2.grid(True, linestyle=":", linewidth=0.6, alpha=0.6)
|
# BTC 1s data (middle middle)
|
||||||
# Build combined legend
|
ax5 = fig.add_subplot(gs[1, 1])
|
||||||
lines, labels = ax2.get_legend_handles_labels()
|
_plot_candlesticks(ax5, btc_1s_times, btc_1s_o, btc_1s_h, btc_1s_l, btc_1s_c, f"BTC 1s (last 60)")
|
||||||
lines2, labels2 = ax2b.get_legend_handles_labels()
|
|
||||||
ax2.legend(lines + lines2, labels + labels2, loc="upper right")
|
# Data summary (middle right)
|
||||||
ax2.set_title("COB Buckets (recent)")
|
ax6 = fig.add_subplot(gs[1, 2])
|
||||||
else:
|
_plot_data_summary(ax6, base_data, symbol)
|
||||||
ax2.text(0.5, 0.5, "No COB data", ha="center", va="center")
|
|
||||||
|
# COB data (bottom, spanning all columns)
|
||||||
|
ax7 = fig.add_subplot(gs[2, :])
|
||||||
|
_plot_cob_data(ax7, prices, bid_v, ask_v, imb, current_price, symbol)
|
||||||
|
|
||||||
|
# Add overall title with model and timestamp info
|
||||||
|
fig.suptitle(f"{model_name} - {safe_symbol} - {datetime.utcnow().strftime('%H:%M:%S')}",
|
||||||
|
fontsize=14, fontweight='bold')
|
||||||
|
|
||||||
fig.tight_layout()
|
fig.tight_layout()
|
||||||
fig.savefig(out_path, bbox_inches="tight")
|
fig.savefig(out_path, bbox_inches="tight")
|
||||||
plt.close(fig)
|
plt.close(fig)
|
||||||
|
|
||||||
logger.info(f"Saved audit image: {out_path}")
|
logger.info(f"Saved comprehensive audit image: {out_path}")
|
||||||
return out_path
|
return out_path
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
logger.error(f"Failed to save audit image: {ex}")
|
logger.error(f"Failed to save audit image: {ex}")
|
||||||
|
Reference in New Issue
Block a user