280 lines
11 KiB
Python
280 lines
11 KiB
Python
"""
|
|
Standardized Data Models for Multi-Modal Trading System
|
|
|
|
This module defines the standardized data structures used across all models:
|
|
- BaseDataInput: Unified input format for all models (CNN, RL, LSTM, Transformer)
|
|
- ModelOutput: Extensible output format supporting all model types
|
|
- COBData: Cumulative Order Book data structure
|
|
- Enhanced data structures for cross-model feeding and extensibility
|
|
"""
|
|
|
|
import numpy as np
|
|
from datetime import datetime
|
|
from typing import Dict, List, Optional, Any
|
|
from dataclasses import dataclass, field
|
|
|
|
@dataclass
class OHLCVBar:
    """One open/high/low/close/volume bar for a symbol at a given timeframe."""

    symbol: str
    timestamp: datetime
    open: float
    high: float
    low: float
    close: float
    volume: float
    timeframe: str
    # Indicator values attached to this bar, keyed by name; a fresh dict is
    # created per instance so bars never share the same mapping.
    indicators: Dict[str, float] = field(default_factory=dict)
|
|
|
|
@dataclass
class PivotPoint:
    """A pivot (swing) point detected on a symbol's price series."""

    symbol: str
    timestamp: datetime
    price: float
    type: str  # Either 'high' or 'low'
    level: int  # Pivot level: 1, 2, 3, ...
    confidence: float = 1.0  # Defaults to 1.0 when not supplied
|
|
|
|
@dataclass
class ModelOutput:
    """Common, extensible result record produced by every model type."""

    model_type: str  # One of 'cnn', 'rl', 'lstm', 'transformer', 'orchestrator'
    model_name: str  # Identifier of the specific model instance
    symbol: str
    timestamp: datetime
    confidence: float
    predictions: Dict[str, Any]  # Payload whose keys depend on the model
    # Internal states exposed so other models can consume them; None when absent
    hidden_states: Optional[Dict[str, Any]] = None
    # Free-form extra information; each instance gets its own dict
    metadata: Dict[str, Any] = field(default_factory=dict)
|
|
|
|
@dataclass
class COBData:
    """Cumulative Order Book snapshot organised into price buckets."""

    symbol: str
    timestamp: datetime
    current_price: float
    bucket_size: float  # Bucket width: $1 for ETH, $10 for BTC
    # price -> per-bucket metrics such as bid_volume, ask_volume, etc.
    price_buckets: Dict[float, Dict[str, float]]
    # price -> imbalance ratio
    bid_ask_imbalance: Dict[float, float]
    # price -> VWAP within the bucket
    volume_weighted_prices: Dict[float, float]
    # Assorted order-flow indicators keyed by name
    order_flow_metrics: Dict[str, float]

    # Moving averages of the COB imbalance for ±5 buckets, one mapping per
    # averaging window. All default to a fresh, empty dict.
    ma_1s_imbalance: Dict[float, float] = field(default_factory=dict)  # 1-second window
    ma_5s_imbalance: Dict[float, float] = field(default_factory=dict)  # 5-second window
    ma_15s_imbalance: Dict[float, float] = field(default_factory=dict)  # 15-second window
    ma_60s_imbalance: Dict[float, float] = field(default_factory=dict)  # 60-second window
|
|
|
|
@dataclass
class BaseDataInput:
    """
    Unified base data input for all models

    Standardized format ensures all models receive identical input structure:
    - OHLCV: 300 frames of (1s, 1m, 1h, 1d) ETH + 300s of 1s BTC
    - COB: ±20 buckets of COB amounts in USD for each 1s OHLCV
    - MA: 1s, 5s, 15s, and 60s MA of COB imbalance counting ±5 COB buckets

    Element types below are written as forward references so this class does
    not depend on where the referenced dataclasses are defined.
    """
    symbol: str  # Primary symbol (ETH/USDT)
    timestamp: datetime

    # Multi-timeframe OHLCV data for primary symbol (ETH)
    ohlcv_1s: List["OHLCVBar"] = field(default_factory=list)  # 300 frames of 1s data
    ohlcv_1m: List["OHLCVBar"] = field(default_factory=list)  # 300 frames of 1m data
    ohlcv_1h: List["OHLCVBar"] = field(default_factory=list)  # 300 frames of 1h data
    ohlcv_1d: List["OHLCVBar"] = field(default_factory=list)  # 300 frames of 1d data

    # Reference symbol (BTC) 1s data
    btc_ohlcv_1s: List["OHLCVBar"] = field(default_factory=list)  # 300s of 1s BTC data

    # COB data for 1s timeframe (±20 buckets around current price)
    cob_data: Optional["COBData"] = None

    # Technical indicators
    technical_indicators: Dict[str, float] = field(default_factory=dict)

    # Pivot points from Williams Market Structure
    pivot_points: List["PivotPoint"] = field(default_factory=list)

    # Last predictions from all models (for cross-model feeding)
    last_predictions: Dict[str, "ModelOutput"] = field(default_factory=dict)

    # Market microstructure data
    market_microstructure: Dict[str, Any] = field(default_factory=dict)

    @staticmethod
    def _ohlcv_features(bars: List["OHLCVBar"], frame_count: int = 300) -> List[float]:
        """Flatten the newest ``frame_count`` bars to 5 values each, zero-padded at the front."""
        # Slicing always yields a new list, so the caller's list is never
        # modified (the previous implementation inserted dummy bars into it).
        window = bars[-frame_count:]
        flat: List[float] = [0.0] * (5 * (frame_count - len(window)))
        for bar in window:
            flat.extend((bar.open, bar.high, bar.low, bar.close, bar.volume))
        return flat

    def _cob_features(self) -> List[float]:
        """Return exactly 200 COB features; all zeros when no COB data is attached."""
        cob_size = 200
        cob = self.cob_data
        flat: List[float] = []
        if cob:
            # Price bucket features (up to 40 buckets x 4 metrics = 160 features),
            # taken in ascending price order.
            for price in sorted(cob.price_buckets)[:40]:
                bucket = cob.price_buckets[price]
                flat.extend((
                    bucket.get('bid_volume', 0.0),
                    bucket.get('ask_volume', 0.0),
                    bucket.get('total_volume', 0.0),
                    bucket.get('imbalance', 0.0),
                ))

            # Moving averages (up to 10 features): only the 1s and 5s windows
            # are encoded, max 5 buckets each.
            # NOTE(review): this takes the first five keys in insertion order
            # and then sorts them, unlike the price buckets above which are
            # sorted first. Kept as-is to preserve the feature layout; confirm
            # whether "five lowest prices" was intended.
            for ma_dict in (cob.ma_1s_imbalance, cob.ma_5s_imbalance):
                flat.extend(ma_dict[price] for price in sorted(list(ma_dict)[:5]))

        # Pad, then cut, so the section is exactly 200 values wide
        flat.extend([0.0] * (cob_size - len(flat)))
        return flat[:cob_size]

    def get_feature_vector(self) -> np.ndarray:
        """
        Convert BaseDataInput to standardized feature vector for models

        Layout (offsets into the returned vector):
        - [0, 6000):    ETH OHLCV, 4 timeframes (1s, 1m, 1h, 1d) x 300 frames x 5 values
        - [6000, 7500): BTC 1s OHLCV, 300 frames x 5 values
        - [7500, 7700): COB features
        - [7700, 7800): technical indicators (first 100 values, in dict order)
        - [7800, 7850): last model predictions (5 values per model, first 50 values)

        Sections with too little data are zero-padded (OHLCV at the front,
        the others at the end). This method does not modify the instance.

        Returns:
            np.ndarray: FIXED SIZE standardized feature vector (7850 features),
            dtype float32
        """
        # FIXED FEATURE SIZE - this should NEVER change at runtime
        FIXED_FEATURE_SIZE = 7850
        features: List[float] = []

        # OHLCV features for ETH (300 frames x 4 timeframes x 5 features = 6000 features)
        for ohlcv_list in (self.ohlcv_1s, self.ohlcv_1m, self.ohlcv_1h, self.ohlcv_1d):
            features.extend(self._ohlcv_features(ohlcv_list))

        # BTC OHLCV features (300 frames x 5 features = 1500 features)
        features.extend(self._ohlcv_features(self.btc_ohlcv_1s))

        # COB features (FIXED SIZE: 200 features)
        features.extend(self._cob_features())

        # Technical indicators (FIXED SIZE: 100 features)
        indicator_values = list(self.technical_indicators.values())
        features.extend(indicator_values[:100])  # Take first 100 indicators
        features.extend([0.0] * max(0, 100 - len(indicator_values)))  # Pad to exactly 100

        # Last predictions from other models (FIXED SIZE: 50 features)
        prediction_features: List[float] = []
        for model_output in self.last_predictions.values():
            predicted = model_output.predictions
            prediction_features.extend((
                model_output.confidence,
                predicted.get('buy_probability', 0.0),
                predicted.get('sell_probability', 0.0),
                predicted.get('hold_probability', 0.0),
                predicted.get('expected_reward', 0.0),
            ))
        features.extend(prediction_features[:50])  # Take first 50 prediction features
        features.extend([0.0] * max(0, 50 - len(prediction_features)))  # Pad to exactly 50

        # CRITICAL: Ensure EXACTLY the fixed feature size. Every section above
        # is already fixed-width, so this is only a safety net.
        if len(features) > FIXED_FEATURE_SIZE:
            features = features[:FIXED_FEATURE_SIZE]  # Truncate if too long
        elif len(features) < FIXED_FEATURE_SIZE:
            features.extend([0.0] * (FIXED_FEATURE_SIZE - len(features)))  # Pad if too short

        assert len(features) == FIXED_FEATURE_SIZE, f"Feature vector size mismatch: {len(features)} != {FIXED_FEATURE_SIZE}"

        return np.array(features, dtype=np.float32)

    def validate(self) -> bool:
        """
        Validate that the BaseDataInput contains required data

        Checks that both 1s series (primary and BTC) hold at least 100 frames,
        that a timestamp is set, and that the symbol is non-empty and contains
        a '/'.

        Returns:
            bool: True if valid, False otherwise
        """
        # Check that we have required OHLCV data (at least 100 frames each)
        if len(self.ohlcv_1s) < 100 or len(self.btc_ohlcv_1s) < 100:
            return False

        # Check that a timestamp is present
        if not self.timestamp:
            return False

        # Check symbol format
        return bool(self.symbol) and '/' in self.symbol
|
|
|
|
@dataclass
class TradingAction:
    """A trading decision emitted by a model."""

    symbol: str
    timestamp: datetime
    action: str  # One of 'BUY', 'SELL', 'HOLD'
    confidence: float
    source: str  # Originating component, e.g. 'rl', 'cnn', 'orchestrator'
    # The remaining fields are optional and default to None
    price: Optional[float] = None
    quantity: Optional[float] = None
    reason: Optional[str] = None
|
|
|
|
def create_model_output(model_type: str, model_name: str, symbol: str,
                        action: str, confidence: float,
                        hidden_states: Optional[Dict[str, Any]] = None,
                        metadata: Optional[Dict[str, Any]] = None) -> ModelOutput:
    """
    Build a standardized ModelOutput for a single trading decision.

    The predictions payload records the action and assigns ``confidence`` to
    the probability slot matching that action; the other two slots are 0.0.
    The output is stamped with the current local time.

    Args:
        model_type: Type of model ('cnn', 'rl', 'lstm', 'transformer', 'orchestrator')
        model_name: Specific model identifier
        symbol: Trading symbol
        action: Trading action ('BUY', 'SELL', 'HOLD')
        confidence: Confidence score (0.0 to 1.0)
        hidden_states: Optional hidden states for cross-model feeding;
            a missing or empty value becomes a new empty dict
        metadata: Optional additional metadata; a missing or empty value
            becomes a new empty dict

    Returns:
        ModelOutput: Standardized model output
    """
    # (action label, predictions key) pairs, in the order the keys are stored
    probability_slots = (
        ('BUY', 'buy_probability'),
        ('SELL', 'sell_probability'),
        ('HOLD', 'hold_probability'),
    )

    predictions: Dict[str, Any] = {'action': action}
    for label, key in probability_slots:
        if action == label:
            predictions[key] = confidence
        else:
            predictions[key] = 0.0

    states = hidden_states if hidden_states else {}
    extra = metadata if metadata else {}

    return ModelOutput(
        model_type=model_type,
        model_name=model_name,
        symbol=symbol,
        timestamp=datetime.now(),
        confidence=confidence,
        predictions=predictions,
        hidden_states=states,
        metadata=extra,
    )