diff --git a/core/data_models.py b/core/data_models.py index 30af822..f21aafc 100644 --- a/core/data_models.py +++ b/core/data_models.py @@ -90,6 +90,11 @@ class BaseDataInput: # COB data for 1s timeframe (±20 buckets around current price) cob_data: Optional[COBData] = None + # COB heatmap (time-series of bucket metrics at 1s resolution) + # Each row corresponds to one second, columns to price buckets + cob_heatmap_times: List[datetime] = field(default_factory=list) + cob_heatmap_prices: List[float] = field(default_factory=list) + cob_heatmap_values: List[List[float]] = field(default_factory=list) # typically imbalance per bucket # Technical indicators technical_indicators: Dict[str, float] = field(default_factory=dict) diff --git a/core/data_provider.py b/core/data_provider.py index 5b81228..4ebeae7 100644 --- a/core/data_provider.py +++ b/core/data_provider.py @@ -4756,6 +4756,87 @@ class DataProvider: except Exception as e: logger.error(f"Error getting 1s aggregated COB data for {symbol}: {e}") return [] + + def get_cob_heatmap_matrix( + self, + symbol: str, + seconds: int = 300, + bucket_radius: int = 10, + metric: str = 'imbalance' + ) -> Tuple[List[datetime], List[float], List[List[float]]]: + """ + Build a 1s COB heatmap matrix for ±bucket_radius buckets around current price. + + Returns (times, prices, matrix) where matrix is shape [T x B]. + metric: 'imbalance' or 'liquidity' (uses bid_volume+ask_volume) + """ + try: + times: List[datetime] = [] + prices: List[float] = [] + values: List[List[float]] = [] + + latest = self.get_latest_cob_data(symbol) + if not latest or 'stats' not in latest: + return times, prices, values + + mid = float(latest['stats'].get('mid_price', 0) or 0) + if mid <= 0: + return times, prices, values + + bucket_size = 1.0 if 'ETH' in symbol else 10.0 + center = round(mid / bucket_size) * bucket_size + prices = [center + i * bucket_size for i in range(-bucket_radius, bucket_radius + 1)] + + with self.subscriber_lock: + cache_for_symbol = getattr(self, 'cob_data_cache', {}).get(symbol, []) + snapshots = list(cache_for_symbol[-seconds:]) if cache_for_symbol else [] + + for snap in snapshots: + ts_ms = snap.get('timestamp') + if isinstance(ts_ms, (int, float)): + times.append(datetime.fromtimestamp(ts_ms / 1000.0)) + else: + times.append(datetime.utcnow()) + + bids = snap.get('bids') or [] + asks = snap.get('asks') or [] + + bucket_map: Dict[float, Dict[str, float]] = {} + for level in bids[:200]: + try: + price, size = float(level[0]), float(level[1]) + bp = round(price / bucket_size) * bucket_size + if bp not in bucket_map: + bucket_map[bp] = {'bid': 0.0, 'ask': 0.0} + bucket_map[bp]['bid'] += size + except Exception: + continue + for level in asks[:200]: + try: + price, size = float(level[0]), float(level[1]) + bp = round(price / bucket_size) * bucket_size + if bp not in bucket_map: + bucket_map[bp] = {'bid': 0.0, 'ask': 0.0} + bucket_map[bp]['ask'] += size + except Exception: + continue + + row: List[float] = [] + for p in prices: + b = float(bucket_map.get(p, {}).get('bid', 0.0)) + a = float(bucket_map.get(p, {}).get('ask', 0.0)) + if metric == 'liquidity': + val = (b + a) + else: + denom = (b + a) + val = (b - a) / denom if denom > 0 else 0.0 + row.append(val) + values.append(row) + + return times, prices, values + except Exception as e: + logger.error(f"Error building COB heatmap matrix for {symbol}: {e}") + return [], [], [] def get_combined_ohlcv_cob_data(self, symbol: str, timeframe: str = '1s', count: int = 60) -> dict: """ diff --git a/core/standardized_data_provider.py b/core/standardized_data_provider.py index d53ae93..2ed8029 100644 --- a/core/standardized_data_provider.py +++ b/core/standardized_data_provider.py @@ -165,6 +165,20 @@ class StandardizedDataProvider(DataProvider): pivot_points=pivot_points, last_predictions=last_predictions ) + + # Attach COB heatmap (visual+model optional input), fixed scope defaults + try: + times, prices, matrix = self.get_cob_heatmap_matrix( + symbol=symbol, + seconds=300, + bucket_radius=10, + metric='imbalance' + ) + base_input.cob_heatmap_times = times + base_input.cob_heatmap_prices = prices + base_input.cob_heatmap_values = matrix + except Exception as _hm_ex: + logger.debug(f"COB heatmap not attached for {symbol}: {_hm_ex}") # Validate the input if not base_input.validate():