"""
|
|
Standardized Data Models for Multi-Modal Trading System
|
|
|
|
This module defines the standardized data structures used across all models:
|
|
- BaseDataInput: Unified input format for all models (CNN, RL, LSTM, Transformer)
|
|
- ModelOutput: Extensible output format supporting all model types
|
|
- COBData: Cumulative Order Book data structure
|
|
- Enhanced data structures for cross-model feeding and extensibility
|
|
"""
|
|
|
|
import numpy as np
|
|
from datetime import datetime
|
|
from typing import Dict, List, Optional, Any
|
|
from dataclasses import dataclass, field
|
|
|
|

@dataclass
class OHLCVBar:
    """
    Enhanced OHLCV bar data structure with technical analysis features

    Includes candle pattern recognition, relative sizing, body/wick analysis,
    and Williams pivot points metadata for improved model feature engineering.
    """
    symbol: str
    timestamp: datetime
    open: float
    high: float
    low: float
    close: float
    volume: float
    timeframe: str
    indicators: Dict[str, float] = field(default_factory=dict)

    # Pivot points metadata
    pivot_distance_to_support: Optional[float] = None
    pivot_distance_to_resistance: Optional[float] = None
    pivot_level_context: Optional[Dict[str, Any]] = None
    near_pivot_support: bool = False
    near_pivot_resistance: bool = False

    # Candle characteristics (computed on demand and cached)
    _body_size: Optional[float] = field(default=None, repr=False)
    _upper_wick: Optional[float] = field(default=None, repr=False)
    _lower_wick: Optional[float] = field(default=None, repr=False)
    _total_range: Optional[float] = field(default=None, repr=False)
    _is_bullish: Optional[bool] = field(default=None, repr=False)

    @property
    def body_size(self) -> float:
        """Absolute size of the candle body"""
        if self._body_size is None:
            self._body_size = abs(self.close - self.open)
        return self._body_size

    @property
    def upper_wick(self) -> float:
        """Size of the upper wick/shadow"""
        if self._upper_wick is None:
            self._upper_wick = self.high - max(self.open, self.close)
        return self._upper_wick

    @property
    def lower_wick(self) -> float:
        """Size of the lower wick/shadow"""
        if self._lower_wick is None:
            self._lower_wick = min(self.open, self.close) - self.low
        return self._lower_wick

    @property
    def total_range(self) -> float:
        """Total high-low range"""
        if self._total_range is None:
            self._total_range = self.high - self.low
        return self._total_range

    @property
    def is_bullish(self) -> bool:
        """True if close > open (hollow/green candle)"""
        if self._is_bullish is None:
            self._is_bullish = self.close > self.open
        return self._is_bullish

    @property
    def is_bearish(self) -> bool:
        """True if close < open (solid/red candle)"""
        return not self.is_bullish and self.close != self.open

    @property
    def is_doji(self) -> bool:
        """True if open ≈ close (body under 10% of the total range)"""
        return self.body_size < (self.total_range * 0.1) if self.total_range > 0 else True

    def get_body_to_range_ratio(self) -> float:
        """Body size as a fraction of the total range (0.0 to 1.0)"""
        return self.body_size / self.total_range if self.total_range > 0 else 0.0

    def get_upper_wick_ratio(self) -> float:
        """Upper wick as a fraction of the total range (0.0 to 1.0)"""
        return self.upper_wick / self.total_range if self.total_range > 0 else 0.0

    def get_lower_wick_ratio(self) -> float:
        """Lower wick as a fraction of the total range (0.0 to 1.0)"""
        return self.lower_wick / self.total_range if self.total_range > 0 else 0.0

    def get_relative_size(self, reference_bars: List['OHLCVBar'], method: str = 'avg') -> float:
        """
        Get relative size compared to reference bars

        Args:
            reference_bars: List of previous bars for comparison
            method: 'avg' (average), 'max' (maximum), or 'median'

        Returns:
            Ratio of current range to reference (1.0 = same size, >1.0 = larger, <1.0 = smaller)
        """
        if not reference_bars:
            return 1.0

        reference_ranges = [bar.total_range for bar in reference_bars if bar.total_range > 0]
        if not reference_ranges:
            return 1.0

        if method == 'avg':
            reference_value = np.mean(reference_ranges)
        elif method == 'max':
            reference_value = np.max(reference_ranges)
        elif method == 'median':
            reference_value = np.median(reference_ranges)
        else:
            reference_value = np.mean(reference_ranges)

        return self.total_range / reference_value if reference_value > 0 else 1.0

    def get_candle_pattern(self) -> str:
        """
        Identify basic candle pattern

        Returns:
            Pattern name: 'doji', 'hammer', 'shooting_star', 'spinning_top',
            'marubozu_bullish', 'marubozu_bearish', 'standard'
        """
        if self.total_range == 0:
            return 'doji'

        body_ratio = self.get_body_to_range_ratio()
        upper_ratio = self.get_upper_wick_ratio()
        lower_ratio = self.get_lower_wick_ratio()

        # Doji: very small body (<10% of range)
        if body_ratio < 0.1:
            return 'doji'

        # Marubozu: body dominates the range (>90%), wicks negligible
        if body_ratio > 0.9:
            return 'marubozu_bullish' if self.is_bullish else 'marubozu_bearish'

        # Hammer: small body at the top, long lower wick
        if body_ratio < 0.3 and lower_ratio > 0.6 and upper_ratio < 0.1:
            return 'hammer'

        # Shooting star: small body at the bottom, long upper wick
        if body_ratio < 0.3 and upper_ratio > 0.6 and lower_ratio < 0.1:
            return 'shooting_star'

        # Spinning top: small body, significant wicks on both sides
        if body_ratio < 0.3 and (upper_ratio + lower_ratio) > 0.6:
            return 'spinning_top'

        return 'standard'

    def get_ta_features(self, reference_bars: Optional[List['OHLCVBar']] = None) -> Dict[str, float]:
        """
        Get all technical analysis features as a dictionary

        Args:
            reference_bars: Optional list of previous bars for relative sizing

        Returns:
            Dictionary of TA features suitable for model input
        """
        features = {
            # Basic candle properties
            'is_bullish': 1.0 if self.is_bullish else 0.0,
            'is_bearish': 1.0 if self.is_bearish else 0.0,
            'is_doji': 1.0 if self.is_doji else 0.0,

            # Size ratios
            'body_to_range_ratio': self.get_body_to_range_ratio(),
            'upper_wick_ratio': self.get_upper_wick_ratio(),
            'lower_wick_ratio': self.get_lower_wick_ratio(),

            # Absolute sizes (normalized by close price)
            'body_size_pct': self.body_size / self.close if self.close > 0 else 0.0,
            'upper_wick_pct': self.upper_wick / self.close if self.close > 0 else 0.0,
            'lower_wick_pct': self.lower_wick / self.close if self.close > 0 else 0.0,
            'total_range_pct': self.total_range / self.close if self.close > 0 else 0.0,

            # Volume relative to price movement
            'volume_per_range': self.volume / self.total_range if self.total_range > 0 else 0.0,
        }

        # Add relative sizing if reference bars are provided
        if reference_bars:
            features['relative_size_avg'] = self.get_relative_size(reference_bars, 'avg')
            features['relative_size_max'] = self.get_relative_size(reference_bars, 'max')
            features['relative_size_median'] = self.get_relative_size(reference_bars, 'median')

        # Add pattern encoding (one-hot style)
        pattern = self.get_candle_pattern()
        pattern_types = ['doji', 'hammer', 'shooting_star', 'spinning_top',
                         'marubozu_bullish', 'marubozu_bearish', 'standard']
        for p in pattern_types:
            features[f'pattern_{p}'] = 1.0 if pattern == p else 0.0

        return features
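
# A minimal usage sketch for OHLCVBar's candle analytics (illustrative values,
# not real market data). Body = |104 - 100| = 4, range = 105 - 99 = 6, so the
# body-to-range ratio is ~0.667 and the bar classifies as a plain 'standard' candle:
#
#     bar = OHLCVBar(symbol='ETH/USDT', timestamp=datetime(2024, 1, 1),
#                    open=100.0, high=105.0, low=99.0, close=104.0,
#                    volume=250.0, timeframe='1m')
#     assert bar.is_bullish
#     assert abs(bar.get_body_to_range_ratio() - 4 / 6) < 1e-9
#     assert bar.get_candle_pattern() == 'standard'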

@dataclass
class PivotPoint:
    """Pivot point data structure"""
    symbol: str
    timestamp: datetime
    price: float
    type: str  # 'high' or 'low'
    level: int  # Pivot level (1, 2, 3, etc.)
    confidence: float = 1.0

@dataclass
class ModelOutput:
    """Extensible model output format supporting all model types"""
    model_type: str  # 'cnn', 'rl', 'lstm', 'transformer', 'orchestrator'
    model_name: str  # Specific model identifier
    symbol: str
    timestamp: datetime
    confidence: float
    predictions: Dict[str, Any]  # Model-specific predictions
    hidden_states: Optional[Dict[str, Any]] = None  # For cross-model feeding
    metadata: Dict[str, Any] = field(default_factory=dict)  # Additional info

@dataclass
class COBData:
    """Cumulative Order Book data for price buckets"""
    symbol: str
    timestamp: datetime
    current_price: float
    bucket_size: float  # $1 for ETH, $10 for BTC
    price_buckets: Dict[float, Dict[str, float]]  # price -> {bid_volume, ask_volume, etc.}
    bid_ask_imbalance: Dict[float, float]  # price -> imbalance ratio
    volume_weighted_prices: Dict[float, float]  # price -> VWAP within bucket
    order_flow_metrics: Dict[str, float]  # Various order flow indicators

    # Moving averages of COB imbalance for ±5 buckets
    ma_1s_imbalance: Dict[float, float] = field(default_factory=dict)  # 1s MA
    ma_5s_imbalance: Dict[float, float] = field(default_factory=dict)  # 5s MA
    ma_15s_imbalance: Dict[float, float] = field(default_factory=dict)  # 15s MA
    ma_60s_imbalance: Dict[float, float] = field(default_factory=dict)  # 60s MA
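
# A sketch of how a COBData instance is expected to be shaped (hypothetical
# numbers; bucket keys are USD prices around current_price, and the per-bucket
# metric names match those read back in BaseDataInput.get_feature_vector):
#
#     cob = COBData(
#         symbol='ETH/USDT', timestamp=datetime.now(), current_price=3000.0,
#         bucket_size=1.0,
#         price_buckets={2999.0: {'bid_volume': 120.0, 'ask_volume': 80.0,
#                                 'total_volume': 200.0, 'imbalance': 0.2}},
#         bid_ask_imbalance={2999.0: 0.2},
#         volume_weighted_prices={2999.0: 2999.4},
#         order_flow_metrics={},
#     )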

@dataclass
class NormalizationBounds:
    """Normalization boundaries for price and volume data"""
    price_min: float
    price_max: float
    volume_min: float
    volume_max: float
    symbol: str
    timeframe: str = 'all'  # 'all' means across all timeframes

    def normalize_price(self, price: float) -> float:
        """Normalize price to the 0-1 range"""
        if self.price_max == self.price_min:
            return 0.5
        return (price - self.price_min) / (self.price_max - self.price_min)

    def denormalize_price(self, normalized: float) -> float:
        """Denormalize price from the 0-1 range back to the original scale"""
        return normalized * (self.price_max - self.price_min) + self.price_min

    def normalize_volume(self, volume: float) -> float:
        """Normalize volume to the 0-1 range"""
        if self.volume_max == self.volume_min:
            return 0.5
        return (volume - self.volume_min) / (self.volume_max - self.volume_min)

    def denormalize_volume(self, normalized: float) -> float:
        """Denormalize volume from the 0-1 range back to the original scale"""
        return normalized * (self.volume_max - self.volume_min) + self.volume_min

    def get_price_range(self) -> float:
        """Get price range"""
        return self.price_max - self.price_min

    def get_volume_range(self) -> float:
        """Get volume range"""
        return self.volume_max - self.volume_min
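
# Min-max normalization round-trip sketch (illustrative bounds). With
# price_min=2000 and price_max=4000, a price of 3000 maps to 0.5, and
# denormalizing recovers the original value:
#
#     bounds = NormalizationBounds(price_min=2000.0, price_max=4000.0,
#                                  volume_min=0.0, volume_max=1000.0,
#                                  symbol='ETH/USDT')
#     assert bounds.normalize_price(3000.0) == 0.5
#     assert bounds.denormalize_price(bounds.normalize_price(3500.0)) == 3500.0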

@dataclass
class BaseDataInput:
    """
    Unified base data input for all models

    Standardized format ensures all models receive an identical input structure:
    - OHLCV: 300 frames of (1s, 1m, 1h, 1d) ETH + 300s of 1s BTC
    - COB: ±20 buckets of COB amounts in USD for each 1s OHLCV
    - MA: 1s, 5s, 15s, and 60s MA of COB imbalance counting ±5 COB buckets
    - All OHLCV data is normalized to the 0-1 range based on the daily (longest timeframe) min/max
    """
    symbol: str  # Primary symbol (ETH/USDT)
    timestamp: datetime

    # Multi-timeframe OHLCV data for primary symbol (ETH)
    ohlcv_1s: List[OHLCVBar] = field(default_factory=list)  # 300 frames of 1s data
    ohlcv_1m: List[OHLCVBar] = field(default_factory=list)  # 300 frames of 1m data
    ohlcv_1h: List[OHLCVBar] = field(default_factory=list)  # 300 frames of 1h data
    ohlcv_1d: List[OHLCVBar] = field(default_factory=list)  # 300 frames of 1d data

    # Reference symbol (BTC) 1s data
    btc_ohlcv_1s: List[OHLCVBar] = field(default_factory=list)  # 300s of 1s BTC data

    # COB data for 1s timeframe (±20 buckets around current price)
    cob_data: Optional[COBData] = None
    # COB heatmap (time series of bucket metrics at 1s resolution)
    # Each row corresponds to one second, columns to price buckets
    cob_heatmap_times: List[datetime] = field(default_factory=list)
    cob_heatmap_prices: List[float] = field(default_factory=list)
    cob_heatmap_values: List[List[float]] = field(default_factory=list)  # typically imbalance per bucket

    # Technical indicators
    technical_indicators: Dict[str, float] = field(default_factory=dict)

    # Pivot points from Williams Market Structure
    pivot_points: List[PivotPoint] = field(default_factory=list)

    # Last predictions from all models (for cross-model feeding)
    last_predictions: Dict[str, ModelOutput] = field(default_factory=dict)

    # Market microstructure data
    market_microstructure: Dict[str, Any] = field(default_factory=dict)

    # Position and trading state information
    position_info: Dict[str, Any] = field(default_factory=dict)

    # Normalization boundaries (computed on demand, cached)
    _normalization_bounds: Optional[NormalizationBounds] = field(default=None, repr=False)
    _btc_normalization_bounds: Optional[NormalizationBounds] = field(default=None, repr=False)

    def _compute_normalization_bounds(self) -> NormalizationBounds:
        """
        Compute normalization bounds from daily (longest timeframe) data

        Uses daily data because it has the widest price range, ensuring all shorter
        timeframes are normalized within the 0-1 range.

        Returns:
            NormalizationBounds: Min/max for price and volume
        """
        if self._normalization_bounds is not None:
            return self._normalization_bounds

        # Collect all OHLCV data, prioritizing daily for the widest range
        all_prices = []
        all_volumes = []

        # Use daily data first (widest range)
        for bar in self.ohlcv_1d:
            all_prices.extend([bar.open, bar.high, bar.low, bar.close])
            all_volumes.append(bar.volume)

        # Add other timeframes to ensure coverage
        for ohlcv_list in [self.ohlcv_1h, self.ohlcv_1m, self.ohlcv_1s]:
            for bar in ohlcv_list:
                all_prices.extend([bar.open, bar.high, bar.low, bar.close])
                all_volumes.append(bar.volume)

        # Compute bounds
        if all_prices and all_volumes:
            price_min = min(all_prices)
            price_max = max(all_prices)
            volume_min = min(all_volumes)
            volume_max = max(all_volumes)
        else:
            # Fallback if no data
            price_min = price_max = 0.0
            volume_min = volume_max = 0.0

        self._normalization_bounds = NormalizationBounds(
            price_min=price_min,
            price_max=price_max,
            volume_min=volume_min,
            volume_max=volume_max,
            symbol=self.symbol,
            timeframe='all'
        )

        return self._normalization_bounds

    def _compute_btc_normalization_bounds(self) -> NormalizationBounds:
        """
        Compute normalization bounds for BTC data

        Returns:
            NormalizationBounds: Min/max for BTC price and volume
        """
        if self._btc_normalization_bounds is not None:
            return self._btc_normalization_bounds

        all_prices = []
        all_volumes = []

        for bar in self.btc_ohlcv_1s:
            all_prices.extend([bar.open, bar.high, bar.low, bar.close])
            all_volumes.append(bar.volume)

        if all_prices and all_volumes:
            price_min = min(all_prices)
            price_max = max(all_prices)
            volume_min = min(all_volumes)
            volume_max = max(all_volumes)
        else:
            price_min = price_max = 0.0
            volume_min = volume_max = 0.0

        self._btc_normalization_bounds = NormalizationBounds(
            price_min=price_min,
            price_max=price_max,
            volume_min=volume_min,
            volume_max=volume_max,
            symbol='BTC/USDT',
            timeframe='1s'
        )

        return self._btc_normalization_bounds

    def get_normalization_bounds(self) -> NormalizationBounds:
        """Get normalization bounds for the primary symbol (cached)"""
        return self._compute_normalization_bounds()

    def get_btc_normalization_bounds(self) -> NormalizationBounds:
        """Get normalization bounds for BTC (cached)"""
        return self._compute_btc_normalization_bounds()

    def get_feature_vector(self, include_candle_ta: bool = True, normalize: bool = True) -> np.ndarray:
        """
        Convert BaseDataInput to a standardized feature vector for models

        Args:
            include_candle_ta: If True, include enhanced candle TA features (default: True)
            normalize: If True, normalize OHLCV data to the 0-1 range (default: True)

        Returns:
            np.ndarray: FIXED SIZE standardized feature vector (7870 or 22880 features)

        Note:
            - Full TA features are enabled by default for better model performance
            - Normalization uses the daily (longest timeframe) min/max for the primary symbol
            - BTC data is normalized independently using its own min/max
            - Normalization bounds are cached and accessible via get_normalization_bounds()
            - Includes pivot points metadata (10 features) for market structure context
        """
        # FIXED FEATURE SIZE - this should NEVER change at runtime
        # Layout: ETH OHLCV (4 timeframes x 300 frames) + BTC OHLCV (300 frames)
        # + 200 COB + 110 indicators + 10 pivot + 45 predictions + 5 position (= 370 fixed).
        # Standard (5 features/frame): 4*300*5 + 300*5 + 370 = 7870 features.
        # With candle TA (15 features/frame): 4*300*15 + 300*15 + 370 = 22870,
        # zero-padded at the end to the fixed size of 22880.
        FIXED_FEATURE_SIZE = 22880 if include_candle_ta else 7870
        features = []

        # Get normalization bounds (cached)
        if normalize:
            norm_bounds = self._compute_normalization_bounds()

        # OHLCV features for ETH (up to 300 frames x 4 timeframes x 5 or 15 features)
        for ohlcv_list in [self.ohlcv_1s, self.ohlcv_1m, self.ohlcv_1h, self.ohlcv_1d]:
            # Use actual data only, up to the most recent 300 frames
            ohlcv_frames = ohlcv_list[-300:]

            # Extract features from actual frames
            for i, bar in enumerate(ohlcv_frames):
                # Basic OHLCV (5 features) - normalized to the 0-1 range
                if normalize:
                    features.extend([
                        norm_bounds.normalize_price(bar.open),
                        norm_bounds.normalize_price(bar.high),
                        norm_bounds.normalize_price(bar.low),
                        norm_bounds.normalize_price(bar.close),
                        norm_bounds.normalize_volume(bar.volume)
                    ])
                else:
                    features.extend([bar.open, bar.high, bar.low, bar.close, bar.volume])

                # Enhanced candle TA features (10 additional features per bar)
                if include_candle_ta:
                    # Reference bars for relative sizing (up to the last 10 bars)
                    ref_start = max(0, i - 10)
                    reference_bars = ohlcv_frames[ref_start:i] if i > 0 else []

                    ta_features = bar.get_ta_features(reference_bars)
                    # Extract key features in a fixed order
                    features.extend([
                        ta_features.get('is_bullish', 0.0),
                        ta_features.get('body_to_range_ratio', 0.0),
                        ta_features.get('upper_wick_ratio', 0.0),
                        ta_features.get('lower_wick_ratio', 0.0),
                        ta_features.get('body_size_pct', 0.0),
                        ta_features.get('total_range_pct', 0.0),
                        ta_features.get('relative_size_avg', 1.0),
                        ta_features.get('pattern_doji', 0.0),
                        ta_features.get('pattern_hammer', 0.0),
                        ta_features.get('pattern_shooting_star', 0.0),
                    ])

            # Pad with zeros if fewer than 300 frames are available
            frames_needed = 300 - len(ohlcv_frames)
            if frames_needed > 0:
                features_per_frame = 15 if include_candle_ta else 5
                features.extend([0.0] * (frames_needed * features_per_frame))

        # BTC OHLCV features (up to 300 frames x 5 or 15 features)
        btc_frames = self.btc_ohlcv_1s[-300:]

        # Get BTC normalization bounds (cached, independent from the primary symbol)
        if normalize:
            btc_norm_bounds = self._compute_btc_normalization_bounds()

        # Extract features from actual BTC frames
        for i, bar in enumerate(btc_frames):
            # Basic OHLCV (5 features) - normalized to the 0-1 range
            if normalize:
                features.extend([
                    btc_norm_bounds.normalize_price(bar.open),
                    btc_norm_bounds.normalize_price(bar.high),
                    btc_norm_bounds.normalize_price(bar.low),
                    btc_norm_bounds.normalize_price(bar.close),
                    btc_norm_bounds.normalize_volume(bar.volume)
                ])
            else:
                features.extend([bar.open, bar.high, bar.low, bar.close, bar.volume])

            # Enhanced candle TA features (10 additional features per bar)
            if include_candle_ta:
                ref_start = max(0, i - 10)
                reference_bars = btc_frames[ref_start:i] if i > 0 else []

                ta_features = bar.get_ta_features(reference_bars)
                features.extend([
                    ta_features.get('is_bullish', 0.0),
                    ta_features.get('body_to_range_ratio', 0.0),
                    ta_features.get('upper_wick_ratio', 0.0),
                    ta_features.get('lower_wick_ratio', 0.0),
                    ta_features.get('body_size_pct', 0.0),
                    ta_features.get('total_range_pct', 0.0),
                    ta_features.get('relative_size_avg', 1.0),
                    ta_features.get('pattern_doji', 0.0),
                    ta_features.get('pattern_hammer', 0.0),
                    ta_features.get('pattern_shooting_star', 0.0),
                ])

        # Pad with zeros if fewer than 300 frames are available
        btc_frames_needed = 300 - len(btc_frames)
        if btc_frames_needed > 0:
            features_per_frame = 15 if include_candle_ta else 5
            features.extend([0.0] * (btc_frames_needed * features_per_frame))

        # COB features (FIXED SIZE: 200 features)
        cob_features = []
        if self.cob_data:
            # Price bucket features (up to 40 buckets x 4 metrics = 160 features)
            price_keys = sorted(self.cob_data.price_buckets.keys())[:40]  # Max 40 buckets
            for price in price_keys:
                bucket_data = self.cob_data.price_buckets[price]
                cob_features.extend([
                    bucket_data.get('bid_volume', 0.0),
                    bucket_data.get('ask_volume', 0.0),
                    bucket_data.get('total_volume', 0.0),
                    bucket_data.get('imbalance', 0.0)
                ])

            # Moving averages of imbalance (up to 10 features)
            ma_features = []
            for ma_dict in [self.cob_data.ma_1s_imbalance, self.cob_data.ma_5s_imbalance]:
                for price in sorted(ma_dict.keys())[:5]:  # Max 5 buckets per MA
                    ma_features.append(ma_dict[price])
                    if len(ma_features) >= 10:
                        break
                if len(ma_features) >= 10:
                    break
            cob_features.extend(ma_features)

        # Add REAL aggregated COB heatmap features to fill remaining COB slots (no synthetic data)
        # Compute per-bucket means over the most recent window (up to 300s) plus a few global stats
        try:
            if self.cob_heatmap_values and self.cob_heatmap_prices:
                z = np.array(self.cob_heatmap_values, dtype=float)
                if z.ndim == 2 and z.size > 0:
                    # Use up to the last 300 seconds (or whatever is available)
                    window_rows = z[-300:]
                    # Replace NaNs with 0.0 to respect the no-synthetic rule while avoiding NaN propagation
                    window_rows = np.nan_to_num(window_rows, nan=0.0, posinf=0.0, neginf=0.0)

                    # Per-bucket mean imbalance/liquidity across time
                    per_bucket_mean = window_rows.mean(axis=0).tolist()
                    space_left = 200 - len(cob_features)
                    if space_left > 0 and len(per_bucket_mean) > 0:
                        cob_features.extend(per_bucket_mean[:space_left])

                    # If there is still space, add compact global stats over the window
                    space_left = 200 - len(cob_features)
                    if space_left > 0:
                        flat = window_rows.reshape(-1)
                        if flat.size > 0:
                            global_mean = float(np.mean(flat))
                            global_std = float(np.std(flat))
                            global_max = float(np.max(flat))
                            global_min = float(np.min(flat))
                            global_stats = [global_mean, global_std, global_max, global_min]
                            cob_features.extend(global_stats[:space_left])
        except Exception:
            # On any error, skip heatmap-derived features (remaining space is zero-padded below)
            pass

        # Pad COB features to exactly 200
        cob_features.extend([0.0] * (200 - len(cob_features)))
        features.extend(cob_features[:200])  # Ensure exactly 200 COB features

        # Technical indicators (FIXED SIZE: 110 features, expanded to accommodate more indicators)
        indicator_values = list(self.technical_indicators.values())
        features.extend(indicator_values[:110])  # Take the first 110 indicators
        features.extend([0.0] * max(0, 110 - len(indicator_values)))  # Pad to exactly 110

        # Pivot points metadata (FIXED SIZE: 10 features)
        # Extract pivot context from the most recent 1m OHLCV bar
        pivot_features = []
        if self.ohlcv_1m:
            latest_bar = self.ohlcv_1m[-1]
            pivot_features.extend([
                latest_bar.pivot_distance_to_support if latest_bar.pivot_distance_to_support is not None else 0.0,
                latest_bar.pivot_distance_to_resistance if latest_bar.pivot_distance_to_resistance is not None else 0.0,
                1.0 if latest_bar.near_pivot_support else 0.0,
                1.0 if latest_bar.near_pivot_resistance else 0.0,
            ])
            # Add pivot level context if available
            if latest_bar.pivot_level_context:
                ctx = latest_bar.pivot_level_context
                pivot_features.extend([
                    ctx.get('trend_strength', 0.0),
                    ctx.get('support_count', 0.0),
                    ctx.get('resistance_count', 0.0),
                    ctx.get('price_position_in_range', 0.5),  # 0 = at support, 1 = at resistance
                    ctx.get('distance_to_nearest_level', 0.0),
                    ctx.get('level_strength', 0.0),
                ])
            else:
                pivot_features.extend([0.0] * 6)
        else:
            pivot_features = [0.0] * 10

        # Ensure exactly 10 pivot features
        pivot_features = pivot_features[:10]
        pivot_features.extend([0.0] * (10 - len(pivot_features)))
        features.extend(pivot_features)

        # Last predictions from other models (FIXED SIZE: 45 features)
        prediction_features = []
        for model_output in self.last_predictions.values():
            prediction_features.extend([
                model_output.confidence,
                model_output.predictions.get('buy_probability', 0.0),
                model_output.predictions.get('sell_probability', 0.0),
                model_output.predictions.get('hold_probability', 0.0),
                model_output.predictions.get('expected_reward', 0.0)
            ])
        features.extend(prediction_features[:45])  # Take the first 45 prediction features
        features.extend([0.0] * max(0, 45 - len(prediction_features)))  # Pad to exactly 45

        # Position and trading state information (FIXED SIZE: 5 features)
        position_features = [
            1.0 if self.position_info.get('has_position', False) else 0.0,
            self.position_info.get('position_pnl', 0.0),
            self.position_info.get('position_size', 0.0),
            self.position_info.get('entry_price', 0.0),
            self.position_info.get('time_in_position_minutes', 0.0)
        ]
        features.extend(position_features)  # Exactly 5 position features

        # CRITICAL: ensure EXACTLY the fixed feature size
        if len(features) > FIXED_FEATURE_SIZE:
            features = features[:FIXED_FEATURE_SIZE]  # Truncate if too long
        elif len(features) < FIXED_FEATURE_SIZE:
            features.extend([0.0] * (FIXED_FEATURE_SIZE - len(features)))  # Pad if too short

        assert len(features) == FIXED_FEATURE_SIZE, f"Feature vector size mismatch: {len(features)} != {FIXED_FEATURE_SIZE}"

        return np.array(features, dtype=np.float32)

    def validate(self) -> bool:
        """
        Validate that the BaseDataInput contains the required data

        Returns:
            bool: True if valid, False otherwise
        """
        # Check that we have the required OHLCV data
        if len(self.ohlcv_1s) < 100:  # At least 100 frames
            return False
        if len(self.btc_ohlcv_1s) < 100:  # At least 100 frames of BTC data
            return False

        # Check that a timestamp is present
        if not self.timestamp:
            return False

        # Check symbol format
        if not self.symbol or '/' not in self.symbol:
            return False

        return True
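
# A usage sketch for BaseDataInput (the helper `load_bars` is hypothetical and
# stands in for whatever data provider fills the OHLCV lists; it is not part
# of this module):
#
#     inp = BaseDataInput(symbol='ETH/USDT', timestamp=datetime.now(),
#                         ohlcv_1s=load_bars('ETH/USDT', '1s', 300),
#                         ohlcv_1m=load_bars('ETH/USDT', '1m', 300),
#                         ohlcv_1h=load_bars('ETH/USDT', '1h', 300),
#                         ohlcv_1d=load_bars('ETH/USDT', '1d', 300),
#                         btc_ohlcv_1s=load_bars('BTC/USDT', '1s', 300))
#     if inp.validate():
#         vec = inp.get_feature_vector()  # shape (22880,), dtype float32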

@dataclass
class TradingAction:
    """Trading action output from models"""
    symbol: str
    timestamp: datetime
    action: str  # 'BUY', 'SELL', 'HOLD'
    confidence: float
    source: str  # 'rl', 'cnn', 'orchestrator'
    price: Optional[float] = None
    quantity: Optional[float] = None
    reason: Optional[str] = None

def create_model_output(model_type: str, model_name: str, symbol: str,
                        action: str, confidence: float,
                        hidden_states: Optional[Dict[str, Any]] = None,
                        metadata: Optional[Dict[str, Any]] = None) -> ModelOutput:
    """
    Helper function to create a standardized ModelOutput

    Args:
        model_type: Type of model ('cnn', 'rl', 'lstm', 'transformer', 'orchestrator')
        model_name: Specific model identifier
        symbol: Trading symbol
        action: Trading action ('BUY', 'SELL', 'HOLD')
        confidence: Confidence score (0.0 to 1.0)
        hidden_states: Optional hidden states for cross-model feeding
        metadata: Optional additional metadata

    Returns:
        ModelOutput: Standardized model output
    """
    predictions = {
        'action': action,
        'buy_probability': confidence if action == 'BUY' else 0.0,
        'sell_probability': confidence if action == 'SELL' else 0.0,
        'hold_probability': confidence if action == 'HOLD' else 0.0,
    }

    return ModelOutput(
        model_type=model_type,
        model_name=model_name,
        symbol=symbol,
        timestamp=datetime.now(),
        confidence=confidence,
        predictions=predictions,
        hidden_states=hidden_states or {},
        metadata=metadata or {}
    )
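
if __name__ == "__main__":
    # Minimal smoke test / usage example (illustrative values only; in the
    # real system these structures are populated by the data provider, and
    # 'example_cnn_v1' is a hypothetical model identifier).
    output = create_model_output(
        model_type='cnn',
        model_name='example_cnn_v1',
        symbol='ETH/USDT',
        action='BUY',
        confidence=0.72,
    )
    assert output.predictions['buy_probability'] == 0.72
    assert output.predictions['sell_probability'] == 0.0
    print(f"{output.model_name}: {output.predictions['action']} "
          f"(confidence={output.confidence:.2f})")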