# BaseDataInput Specification

## Overview

`BaseDataInput` is the **unified, standardized data structure** used across all models in the trading system for both inference and training. It ensures consistency, extensibility, and proper feature engineering across CNN, RL, LSTM, Transformer, and Orchestrator models.

**Location:** `core/data_models.py`

---
## Design Principles

1. **Single Source of Truth**: All models receive an identical input structure
2. **Fixed Feature Size**: `get_feature_vector()` always returns a fixed-size vector: exactly 7,850 features by default (22,850 with candle TA enabled)
3. **Extensibility**: New features can be added without breaking existing models
4. **No Synthetic Data**: All features must come from real market data or be zero-padded
5. **Multi-Timeframe**: Supports multiple timeframes for comprehensive market analysis
6. **Cross-Model Feeding**: Includes predictions from other models for ensemble approaches

---
## Data Structure

### Core Fields

```python
@dataclass
class BaseDataInput:
    symbol: str          # Primary trading symbol (e.g., 'ETH/USDT')
    timestamp: datetime  # Current timestamp
```

### Multi-Timeframe OHLCV Data (Primary Symbol - ETH)

```python
ohlcv_1s: List[OHLCVBar]  # 300 frames of 1-second bars
ohlcv_1m: List[OHLCVBar]  # 300 frames of 1-minute bars
ohlcv_1h: List[OHLCVBar]  # 300 frames of 1-hour bars
ohlcv_1d: List[OHLCVBar]  # 300 frames of 1-day bars
```

**OHLCVBar Structure:**
```python
@dataclass
class OHLCVBar:
    symbol: str
    timestamp: datetime
    open: float
    high: float
    low: float
    close: float
    volume: float
    timeframe: str
    indicators: Dict[str, float] = field(default_factory=dict)

    # Enhanced TA properties (computed on-demand)
    @property
    def body_size(self) -> float: ...
    @property
    def upper_wick(self) -> float: ...
    @property
    def lower_wick(self) -> float: ...
    @property
    def total_range(self) -> float: ...
    @property
    def is_bullish(self) -> bool: ...
    @property
    def is_bearish(self) -> bool: ...
    @property
    def is_doji(self) -> bool: ...

    # Enhanced TA methods
    def get_body_to_range_ratio(self) -> float: ...
    def get_upper_wick_ratio(self) -> float: ...
    def get_lower_wick_ratio(self) -> float: ...
    def get_relative_size(self, reference_bars, method='avg') -> float: ...
    def get_candle_pattern(self) -> str: ...
    def get_ta_features(self, reference_bars=None) -> Dict[str, float]: ...
```

**See**: `docs/CANDLE_TA_FEATURES_REFERENCE.md` for complete TA feature documentation
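
For orientation, here is a minimal sketch of what the candle-geometry properties are assumed to compute (standard candlestick definitions; the authoritative implementations live in `core/data_models.py`, and the doji threshold below is an illustrative choice, not taken from source):

```python
def candle_geometry(o: float, h: float, l: float, c: float) -> dict:
    """Standard candlestick geometry, mirroring the OHLCVBar properties."""
    body_size = abs(c - o)
    upper_wick = h - max(o, c)   # distance from body top to high
    lower_wick = min(o, c) - l   # distance from body bottom to low
    total_range = h - l
    return {
        "body_size": body_size,
        "upper_wick": upper_wick,
        "lower_wick": lower_wick,
        "total_range": total_range,
        "is_bullish": c > o,
        "is_bearish": c < o,
        # common doji heuristic: body is under ~10% of the total range
        "is_doji": total_range > 0 and body_size / total_range < 0.1,
    }
```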
### Reference Symbol Data (BTC)

```python
btc_ohlcv_1s: List[OHLCVBar]  # 300 seconds of 1-second BTC bars
```

Used for correlation analysis and market-wide context.

### Consolidated Order Book (COB) Data

```python
cob_data: Optional[COBData]  # Real-time order book snapshot
```

**COBData Structure:**
```python
@dataclass
class COBData:
    symbol: str
    timestamp: datetime
    current_price: float
    bucket_size: float                            # $1 for ETH, $10 for BTC
    price_buckets: Dict[float, Dict[str, float]]  # ±20 buckets around current price
    bid_ask_imbalance: Dict[float, float]         # Imbalance ratio per bucket
    volume_weighted_prices: Dict[float, float]    # VWAP within each bucket
    order_flow_metrics: Dict[str, float]          # Order flow indicators

    # Moving averages of COB imbalance for ±5 buckets
    ma_1s_imbalance: Dict[float, float]   # 1-second MA
    ma_5s_imbalance: Dict[float, float]   # 5-second MA
    ma_15s_imbalance: Dict[float, float]  # 15-second MA
    ma_60s_imbalance: Dict[float, float]  # 60-second MA
```

**Price Bucket Details:**
Each bucket contains:
- `bid_volume`: Total bid volume in USD
- `ask_volume`: Total ask volume in USD
- `total_volume`: Combined volume
- `imbalance`: `(bid_volume - ask_volume) / total_volume`
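
A minimal sketch of the imbalance formula above, with a guard for empty buckets (the function name is illustrative; the production computation lives in the data provider):

```python
def bucket_imbalance(bid_volume: float, ask_volume: float) -> float:
    """Imbalance in [-1, 1]: +1.0 means all bids, -1.0 means all asks."""
    total = bid_volume + ask_volume
    return (bid_volume - ask_volume) / total if total > 0 else 0.0
```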
### COB Heatmap (Time-Series)

```python
cob_heatmap_times: List[datetime]      # Timestamps for each snapshot
cob_heatmap_prices: List[float]        # Price levels tracked
cob_heatmap_values: List[List[float]]  # 2D array: time × price buckets
```

Provides temporal evolution of order book liquidity and imbalance.
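
Since the heatmap is time-major, a per-bucket liquidity profile reduces to one line; this mirrors the aggregation used inside `get_feature_vector()` below (assumes `base_data` is a populated `BaseDataInput`):

```python
import numpy as np

# Shape (T, B): T snapshots over time, B price buckets
heatmap = np.array(base_data.cob_heatmap_values, dtype=float)

# Mean value per price bucket over the window, NaNs treated as empty
per_bucket_mean = np.nan_to_num(heatmap, nan=0.0).mean(axis=0)  # shape (B,)
```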
### Technical Indicators

```python
technical_indicators: Dict[str, float]  # Calculated indicators
```

Common indicators include:
- `sma_5`, `sma_20`, `sma_50`, `sma_200`: Simple moving averages
- `ema_12`, `ema_26`: Exponential moving averages
- `rsi`: Relative Strength Index
- `macd`, `macd_signal`, `macd_hist`: MACD components
- `bb_upper`, `bb_middle`, `bb_lower`: Bollinger Bands
- `atr`: Average True Range
- `volatility`: Historical volatility
- `volume_ratio`: Current volume vs. average
- `price_change_5m`, `price_change_15m`, `price_change_1h`: Price changes
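
For reference, a sketch of how a few of the listed values can be computed with pandas (illustrative only; the production values come from the data provider, and parameter choices such as the RSI period and the 20-bar volume window are assumptions):

```python
import pandas as pd

def basic_indicators(close: pd.Series, volume: pd.Series) -> dict:
    sma_20 = close.rolling(20).mean().iloc[-1]
    ema_12 = close.ewm(span=12, adjust=False).mean().iloc[-1]

    # Wilder-style RSI(14)
    delta = close.diff()
    avg_gain = delta.clip(lower=0).ewm(alpha=1 / 14, adjust=False).mean().iloc[-1]
    avg_loss = (-delta.clip(upper=0)).ewm(alpha=1 / 14, adjust=False).mean().iloc[-1]
    rsi = 100.0 if avg_loss == 0 else 100 - 100 / (1 + avg_gain / avg_loss)

    # Current volume relative to its 20-bar average
    volume_ratio = volume.iloc[-1] / volume.rolling(20).mean().iloc[-1]

    return {"sma_20": sma_20, "ema_12": ema_12, "rsi": rsi, "volume_ratio": volume_ratio}
```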
### Pivot Points

```python
pivot_points: List[PivotPoint]  # Williams Market Structure pivots
```

**PivotPoint Structure:**
```python
@dataclass
class PivotPoint:
    symbol: str
    timestamp: datetime
    price: float
    type: str          # 'high' or 'low'
    level: int         # Pivot level (1, 2, 3, etc.)
    confidence: float  # Confidence score (0.0 to 1.0)
```

### Cross-Model Predictions

```python
last_predictions: Dict[str, ModelOutput]  # Previous predictions from all models
```

Enables ensemble approaches and cross-model feeding. Keys are model names (e.g., 'cnn_v1', 'rl_agent', 'transformer').
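
For example, a prior CNN prediction could be attached like this (all values are made up; the field layout follows the `ModelOutput` dataclass documented under Related Classes, and the prediction keys match those read by `get_feature_vector()`):

```python
from datetime import datetime, timezone

base_data.last_predictions["cnn_v1"] = ModelOutput(
    model_type="cnn",
    model_name="cnn_v1",
    symbol="ETH/USDT",
    timestamp=datetime.now(timezone.utc),
    confidence=0.72,
    predictions={
        "buy_probability": 0.55,
        "sell_probability": 0.20,
        "hold_probability": 0.25,
        "expected_reward": 0.012,
    },
    hidden_states=None,
    metadata={},
)
```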
### Market Microstructure

```python
market_microstructure: Dict[str, Any]  # Additional market state data
```

May include:
- Spread metrics
- Liquidity depth
- Order arrival rates
- Trade flow toxicity
- Market impact estimates

### Position Information

```python
position_info: Dict[str, Any]  # Current trading position state
```

Contains:
- `has_position`: Boolean indicating if a position is open
- `position_pnl`: Current profit/loss
- `position_size`: Size of the position
- `entry_price`: Entry price of the position
- `time_in_position_minutes`: Duration of the position
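
An illustrative `position_info` for an open long position; the keys match those consumed by `get_feature_vector()`, while the values are made up:

```python
base_data.position_info = {
    "has_position": True,
    "position_pnl": 12.5,     # unrealized PnL
    "position_size": 0.8,     # in base-asset units
    "entry_price": 3150.0,
    "time_in_position_minutes": 42.0,
}
```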

---

## Feature Vector Conversion

The `get_feature_vector()` method converts the rich `BaseDataInput` structure into a **fixed-size numpy array** suitable for neural network input.

**Key Features:**
- **Automatic Normalization**: All OHLCV data normalized to 0-1 range by default
- **Independent Normalization**: Primary symbol and BTC normalized separately
- **Daily Range**: Uses daily (longest timeframe) min/max for widest coverage
- **Cached Bounds**: Normalization boundaries cached for performance and denormalization
- **Fixed Size**: 7,850 features (standard) or 22,850 features (with candle TA)

### Feature Vector Breakdown

| Component | Features | Description |
|-----------|----------|-------------|
| **OHLCV ETH (4 timeframes)** | 6,000 | 300 frames × 4 timeframes × 5 values (OHLCV) |
| **OHLCV BTC (1s)** | 1,500 | 300 frames × 5 values (OHLCV) |
| **COB Features** | 200 | Price buckets + MAs + heatmap aggregates |
| **Technical Indicators** | 100 | Calculated indicators |
| **Last Predictions** | 45 | Cross-model predictions (9 models × 5 features) |
| **Position Info** | 5 | Position state |
| **TOTAL** | **7,850** | Fixed size |
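
Because the components are concatenated in the order listed, the slice boundaries of each block follow directly from the table. A sketch for locating blocks inside the standard 7,850-feature vector (the layout matches the implementation shown later in this document; `FEATURE_LAYOUT` itself is not part of the codebase):

```python
FEATURE_LAYOUT = {
    "ohlcv_eth":   slice(0, 6000),     # 300 bars × 4 timeframes × 5 values
    "ohlcv_btc":   slice(6000, 7500),  # 300 bars × 5 values
    "cob":         slice(7500, 7700),  # 200 COB features
    "indicators":  slice(7700, 7800),  # 100 technical indicators
    "predictions": slice(7800, 7845),  # 9 models × 5 features
    "position":    slice(7845, 7850),  # 5 position features
}

features = base_data.get_feature_vector()
cob_block = features[FEATURE_LAYOUT["cob"]]  # e.g., inspect only the COB slice
```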

### Normalization

#### NormalizationBounds Class

```python
@dataclass
class NormalizationBounds:
    """Normalization boundaries for price and volume data"""
    price_min: float
    price_max: float
    volume_min: float
    volume_max: float
    symbol: str
    timeframe: str = 'all'

    def normalize_price(self, price: float) -> float:
        """Normalize price to 0-1 range"""
        price_range = self.price_max - self.price_min
        if price_range == 0:  # guard against a degenerate (constant-price) range
            return 0.0
        return (price - self.price_min) / price_range

    def denormalize_price(self, normalized: float) -> float:
        """Denormalize price from 0-1 range back to original"""
        return normalized * (self.price_max - self.price_min) + self.price_min

    def normalize_volume(self, volume: float) -> float:
        """Normalize volume to 0-1 range"""
        volume_range = self.volume_max - self.volume_min
        if volume_range == 0:  # same guard for volume
            return 0.0
        return (volume - self.volume_min) / volume_range

    def denormalize_volume(self, normalized: float) -> float:
        """Denormalize volume from 0-1 range back to original"""
        return normalized * (self.volume_max - self.volume_min) + self.volume_min
```

#### How Normalization Works

1. **Primary Symbol (ETH)**: Uses daily (1d) timeframe data to compute min/max
   - Ensures all shorter timeframes (1s, 1m, 1h) fit within the 0-1 range
   - Daily has the widest price range, so all intraday prices normalize properly

2. **Reference Symbol (BTC)**: Uses its own 1s data to compute an independent min/max
   - BTC and ETH have different price scales
   - Independent normalization ensures both are in the 0-1 range

3. **Caching**: Bounds are computed once and cached for performance
   - Access via `get_normalization_bounds()` and `get_btc_normalization_bounds()`
   - Useful for denormalizing model predictions back to actual prices

#### Usage Examples

```python
# Get feature vector with normalization (default)
features = base_data.get_feature_vector(normalize=True)
# All OHLCV values are now in 0-1 range

# Get raw features without normalization
features_raw = base_data.get_feature_vector(normalize=False)
# OHLCV values are in original price/volume units

# Access normalization bounds for denormalization
bounds = base_data.get_normalization_bounds()
print(f"Price range: {bounds.price_min:.2f} - {bounds.price_max:.2f}")

# Denormalize a model prediction
predicted_normalized = 0.75  # Model output
predicted_price = bounds.denormalize_price(predicted_normalized)
print(f"Predicted price: ${predicted_price:.2f}")

# BTC bounds (independent)
btc_bounds = base_data.get_btc_normalization_bounds()
print(f"BTC range: {btc_bounds.price_min:.2f} - {btc_bounds.price_max:.2f}")
```

### Feature Vector Implementation

```python
def get_feature_vector(self, include_candle_ta: bool = False, normalize: bool = True) -> np.ndarray:
    """
    Convert BaseDataInput to a standardized feature vector for models

    Args:
        include_candle_ta: If True, include enhanced candle TA features
        normalize: If True, normalize OHLCV to 0-1 range (default: True)

    Returns:
        np.ndarray: FIXED SIZE standardized feature vector (7850 or 22850 features)
    """
    FIXED_FEATURE_SIZE = 22850 if include_candle_ta else 7850
    features = []

    # Get normalization bounds (cached)
    if normalize:
        norm_bounds = self._compute_normalization_bounds()
        btc_norm_bounds = self._compute_btc_normalization_bounds()

    # 1. OHLCV features for ETH (6000 features, normalized to 0-1)
    for ohlcv_list in [self.ohlcv_1s, self.ohlcv_1m, self.ohlcv_1h, self.ohlcv_1d]:
        ohlcv_frames = ohlcv_list[-300:] if len(ohlcv_list) >= 300 else ohlcv_list
        for bar in ohlcv_frames:
            if normalize:
                features.extend([
                    norm_bounds.normalize_price(bar.open),
                    norm_bounds.normalize_price(bar.high),
                    norm_bounds.normalize_price(bar.low),
                    norm_bounds.normalize_price(bar.close),
                    norm_bounds.normalize_volume(bar.volume)
                ])
            else:
                features.extend([bar.open, bar.high, bar.low, bar.close, bar.volume])
        # Zero-pad to exactly 300 frames per timeframe
        frames_needed = 300 - len(ohlcv_frames)
        if frames_needed > 0:
            features.extend([0.0] * (frames_needed * 5))

    # 2. BTC OHLCV features (1500 features, normalized independently)
    btc_frames = self.btc_ohlcv_1s[-300:] if len(self.btc_ohlcv_1s) >= 300 else self.btc_ohlcv_1s
    for bar in btc_frames:
        if normalize:
            features.extend([
                btc_norm_bounds.normalize_price(bar.open),
                btc_norm_bounds.normalize_price(bar.high),
                btc_norm_bounds.normalize_price(bar.low),
                btc_norm_bounds.normalize_price(bar.close),
                btc_norm_bounds.normalize_volume(bar.volume)
            ])
        else:
            features.extend([bar.open, bar.high, bar.low, bar.close, bar.volume])
    btc_frames_needed = 300 - len(btc_frames)
    if btc_frames_needed > 0:
        features.extend([0.0] * (btc_frames_needed * 5))

    # 3. COB features (200 features)
    cob_features = []
    if self.cob_data:
        # Price bucket features (up to 160 features: 40 buckets × 4 metrics)
        price_keys = sorted(self.cob_data.price_buckets.keys())[:40]
        for price in price_keys:
            bucket_data = self.cob_data.price_buckets[price]
            cob_features.extend([
                bucket_data.get('bid_volume', 0.0),
                bucket_data.get('ask_volume', 0.0),
                bucket_data.get('total_volume', 0.0),
                bucket_data.get('imbalance', 0.0)
            ])

        # Moving averages (up to 10 features)
        ma_features = []
        for ma_dict in [self.cob_data.ma_1s_imbalance, self.cob_data.ma_5s_imbalance]:
            # Sort first, then slice, so a deterministic set of price levels is used
            for price in sorted(ma_dict.keys())[:5]:
                ma_features.append(ma_dict[price])
                if len(ma_features) >= 10:
                    break
            if len(ma_features) >= 10:
                break
        cob_features.extend(ma_features)

        # Heatmap aggregates (remaining space)
        if self.cob_heatmap_values and self.cob_heatmap_prices:
            z = np.array(self.cob_heatmap_values, dtype=float)
            if z.ndim == 2 and z.size > 0:
                window_rows = z[-300:] if z.shape[0] >= 300 else z
                window_rows = np.nan_to_num(window_rows, nan=0.0)
                per_bucket_mean = window_rows.mean(axis=0).tolist()
                space_left = 200 - len(cob_features)
                if space_left > 0:
                    cob_features.extend(per_bucket_mean[:space_left])

    # Pad COB features to exactly 200
    cob_features.extend([0.0] * (200 - len(cob_features)))
    features.extend(cob_features[:200])

    # 4. Technical indicators (100 features)
    indicator_values = list(self.technical_indicators.values())
    features.extend(indicator_values[:100])
    features.extend([0.0] * max(0, 100 - len(indicator_values)))

    # 5. Last predictions (45 features)
    prediction_features = []
    for model_output in self.last_predictions.values():
        prediction_features.extend([
            model_output.confidence,
            model_output.predictions.get('buy_probability', 0.0),
            model_output.predictions.get('sell_probability', 0.0),
            model_output.predictions.get('hold_probability', 0.0),
            model_output.predictions.get('expected_reward', 0.0)
        ])
    features.extend(prediction_features[:45])
    features.extend([0.0] * max(0, 45 - len(prediction_features)))

    # 6. Position info (5 features)
    position_features = [
        1.0 if self.position_info.get('has_position', False) else 0.0,
        self.position_info.get('position_pnl', 0.0),
        self.position_info.get('position_size', 0.0),
        self.position_info.get('entry_price', 0.0),
        self.position_info.get('time_in_position_minutes', 0.0)
    ]
    features.extend(position_features)

    # Ensure exactly FIXED_FEATURE_SIZE
    if len(features) > FIXED_FEATURE_SIZE:
        features = features[:FIXED_FEATURE_SIZE]
    elif len(features) < FIXED_FEATURE_SIZE:
        features.extend([0.0] * (FIXED_FEATURE_SIZE - len(features)))

    assert len(features) == FIXED_FEATURE_SIZE
    return np.array(features, dtype=np.float32)
```

---

## Extensibility

### Adding New Features

The `BaseDataInput` structure is designed for extensibility. To add new features:

#### 1. Add New Field to BaseDataInput

```python
@dataclass
class BaseDataInput:
    # ... existing fields ...

    # NEW: Add your new feature
    sentiment_data: Dict[str, float] = field(default_factory=dict)
```

#### 2. Update get_feature_vector()

**Option A: Add to existing feature slots (if space available)**

```python
def get_feature_vector(self) -> np.ndarray:
    # ... existing code ...

    # Add sentiment features to the technical indicators section
    sentiment_features = [
        self.sentiment_data.get('twitter_sentiment', 0.0),
        self.sentiment_data.get('news_sentiment', 0.0),
        self.sentiment_data.get('fear_greed_index', 0.0)
    ]
    indicator_values.extend(sentiment_features)
    # ... rest of code ...
```

**Option B: Increase FIXED_FEATURE_SIZE (requires model retraining)**

```python
def get_feature_vector(self) -> np.ndarray:
    FIXED_FEATURE_SIZE = 7900  # Increased from 7850

    # ... existing features (7850) ...

    # NEW: Sentiment features (50 features)
    sentiment_features = []
    for key in sorted(self.sentiment_data.keys())[:50]:
        sentiment_features.append(self.sentiment_data[key])
    features.extend(sentiment_features[:50])
    features.extend([0.0] * max(0, 50 - len(sentiment_features)))

    # ... ensure FIXED_FEATURE_SIZE ...
```

#### 3. Update Data Provider

Ensure your data provider populates the new field:

```python
def build_base_data_input(self, symbol: str) -> BaseDataInput:
    # ... existing code ...

    # NEW: Add sentiment data
    sentiment_data = self._get_sentiment_data(symbol)

    return BaseDataInput(
        # ... existing fields ...
        sentiment_data=sentiment_data
    )
```

### Best Practices for Extension

1. **Maintain Fixed Size**: If adding features, either:
   - Use existing padding space
   - Increase `FIXED_FEATURE_SIZE` and retrain all models

2. **Zero Padding**: Always pad missing data with zeros, never synthetic data

3. **Validation**: Update the `validate()` method if new fields are required (see the sketch after this list)

4. **Documentation**: Update this document with new feature descriptions

5. **Backward Compatibility**: Consider versioning if making breaking changes
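
Following practice 3, a sketch of how `validate()` might be extended for a newly added field (the `sentiment_data` field and the strictness policy are hypothetical, carried over from the example above):

```python
import math

def validate(self) -> bool:
    # ... existing checks from the Validation section ...

    # Hypothetical: if sentiment_data is populated, reject NaN/inf values
    # before they can leak into the feature vector
    if self.sentiment_data and not all(
        math.isfinite(v) for v in self.sentiment_data.values()
    ):
        return False

    return True
```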

---

## Current Usage Status

### Models Using BaseDataInput

✅ **StandardizedCNN** (`NN/models/standardized_cnn.py`)
- Uses `get_feature_vector()` directly
- Expected input: 7,834 features (close to 7,850)

✅ **Orchestrator** (`core/orchestrator.py`)
- Builds BaseDataInput via `data_provider.build_base_data_input()`
- Passes it to all models

✅ **UnifiedTrainingManager** (`core/unified_training_manager_v2.py`)
- Converts BaseDataInput to a DQN state via `get_feature_vector()`

✅ **Dashboard** (`web/clean_dashboard.py`)
- Creates BaseDataInput for CNN predictions
- Uses `get_feature_vector()` for feature extraction

### Alternative Implementations Found

⚠️ **ModelInputData** (`core/unified_model_data_interface.py`)
- **Status**: Legacy/alternative interface
- **Usage**: Limited, primarily for model-specific preprocessing
- **Recommendation**: Migrate to BaseDataInput for consistency

⚠️ **MockBaseDataInput** (`COBY/integration/orchestrator_adapter.py`)
- **Status**: Temporary adapter for COBY integration
- **Usage**: Provides the BaseDataInput interface for COBY data
- **Recommendation**: Replace with proper BaseDataInput construction

### Models NOT Using BaseDataInput

❌ **RealtimeRLCOBTrader** (`core/realtime_rl_cob_trader.py`)
- Uses a custom `_extract_features()` method
- **Recommendation**: Migrate to BaseDataInput

❌ **Some legacy models** may use direct feature extraction
- **Recommendation**: Audit and migrate to BaseDataInput

---

## Validation

The `validate()` method ensures data quality:

```python
def validate(self) -> bool:
    """
    Validate that the BaseDataInput contains required data

    Returns:
        bool: True if valid, False otherwise
    """
    # Check minimum OHLCV data
    if len(self.ohlcv_1s) < 100:
        return False
    if len(self.btc_ohlcv_1s) < 100:
        return False

    # Check timestamp
    if not self.timestamp:
        return False

    # Check symbol format
    if not self.symbol or '/' not in self.symbol:
        return False

    return True
```

---

## Related Classes

### ModelOutput

Output structure for model predictions:

```python
@dataclass
class ModelOutput:
    model_type: str                          # 'cnn', 'rl', 'lstm', 'transformer'
    model_name: str                          # Specific model identifier
    symbol: str
    timestamp: datetime
    confidence: float
    predictions: Dict[str, Any]              # Model-specific predictions
    hidden_states: Optional[Dict[str, Any]]  # For cross-model feeding
    metadata: Dict[str, Any]                 # Additional info
```

### COBSnapshot

Raw consolidated order book data (transformed into COBData):

```python
@dataclass
class COBSnapshot:
    symbol: str
    timestamp: datetime
    consolidated_bids: List[ConsolidatedOrderBookLevel]
    consolidated_asks: List[ConsolidatedOrderBookLevel]
    exchanges_active: List[str]
    volume_weighted_mid: float
    total_bid_liquidity: float
    total_ask_liquidity: float
    spread_bps: float
    liquidity_imbalance: float
    price_buckets: Dict[str, Dict[str, float]]
```

### PredictionSnapshot

Stores predictions together with their inputs for future training:

```python
@dataclass
class PredictionSnapshot:
    prediction_id: str
    symbol: str
    prediction_time: datetime
    target_horizon_minutes: int
    target_time: datetime
    current_price: float
    predicted_min_price: float
    predicted_max_price: float
    confidence: float
    model_inputs: Dict[str, Any]  # Includes BaseDataInput features
    market_state: Dict[str, Any]
    technical_indicators: Dict[str, Any]
    pivot_analysis: Dict[str, Any]
    actual_min_price: Optional[float]
    actual_max_price: Optional[float]
    outcome_known: bool
```

---

## Migration Guide

### For Models Not Using BaseDataInput

1. **Identify the current input method**
   ```python
   # OLD
   features = self._extract_features(symbol, data)
   ```

2. **Update to use BaseDataInput**
   ```python
   # NEW
   base_data = self.data_provider.build_base_data_input(symbol)
   if base_data and base_data.validate():
       features = base_data.get_feature_vector()
   ```

3. **Update the model interface**
   ```python
   # OLD
   def predict(self, features: np.ndarray) -> Dict:

   # NEW
   def predict(self, base_input: BaseDataInput) -> ModelOutput:
       features = base_input.get_feature_vector()
       # ... prediction logic ...
   ```

4. **Test thoroughly**
   - Verify the feature vector size matches expectations
   - Check for NaN or infinite values
   - Validate that predictions are reasonable

---

## Performance Considerations

### Memory Usage

- **BaseDataInput object**: ~2-5 MB per instance
- **Feature vector**: 7,850 × 4 bytes = 31.4 KB
- **Recommendation**: Cache BaseDataInput for 1-2 seconds and regenerate feature vectors as needed

### Computation Time

- **Building BaseDataInput**: ~5-10 ms
- **get_feature_vector()**: ~1-2 ms
- **Total overhead**: Negligible for real-time trading

### Optimization Tips

1. **Reuse OHLCV data**: Cache OHLCV bars across multiple BaseDataInput instances
2. **Lazy evaluation**: Only compute features when `get_feature_vector()` is called
3. **Batch processing**: Process multiple symbols in parallel
4. **Avoid deep copies**: Use references where possible
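
A sketch of the caching recommendation from Memory Usage combined with tip 1 (the class name and TTL value are illustrative, not part of the codebase):

```python
import time

class BaseDataInputCache:
    """Reuse a recently built BaseDataInput instead of rebuilding it per call."""

    def __init__(self, data_provider, ttl_seconds: float = 1.5):
        self._provider = data_provider
        self._ttl = ttl_seconds
        self._cache = {}  # symbol -> (built_at, BaseDataInput)

    def get(self, symbol: str):
        now = time.monotonic()
        hit = self._cache.get(symbol)
        if hit is not None and now - hit[0] < self._ttl:
            return hit[1]  # still fresh: skip the ~5-10 ms rebuild
        base_data = self._provider.build_base_data_input(symbol)
        self._cache[symbol] = (now, base_data)
        return base_data
```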

---

## Testing

### Unit Tests

```python
def test_base_data_input_feature_vector():
    """Test that the feature vector has the correct size"""
    base_data = create_test_base_data_input()
    features = base_data.get_feature_vector()

    assert len(features) == 7850
    assert features.dtype == np.float32
    assert not np.isnan(features).any()
    assert not np.isinf(features).any()

def test_base_data_input_validation():
    """Test validation logic"""
    base_data = create_test_base_data_input()
    assert base_data.validate()

    # Test with insufficient data
    base_data.ohlcv_1s = []
    assert not base_data.validate()
```
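
An additional sketch test for the normalization round trip, reusing the helpers above (assumes `create_test_base_data_input()` yields non-degenerate price bounds):

```python
def test_normalization_round_trip():
    """denormalize_price(normalize_price(p)) should recover p"""
    base_data = create_test_base_data_input()
    bounds = base_data.get_normalization_bounds()

    price = (bounds.price_min + bounds.price_max) / 2
    round_tripped = bounds.denormalize_price(bounds.normalize_price(price))
    assert abs(round_tripped - price) < 1e-6
```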

### Integration Tests

```python
def test_model_with_base_data_input():
    """Test model prediction with BaseDataInput"""
    orchestrator = create_test_orchestrator()
    base_data = orchestrator.data_provider.build_base_data_input('ETH/USDT')

    assert base_data is not None
    assert base_data.validate()

    # Test CNN prediction
    cnn_output = orchestrator.cnn_model.predict_from_base_input(base_data)
    assert isinstance(cnn_output, ModelOutput)
    assert 0.0 <= cnn_output.confidence <= 1.0
```

---

## Future Enhancements

### Planned Features

1. **Multi-Symbol Support**: Extend to support multiple correlated symbols
2. **Alternative Data**: Add social sentiment, on-chain metrics, macro indicators
3. **Feature Importance**: Track which features contribute most to predictions
4. **Compression**: Implement feature compression for faster transmission
5. **Versioning**: Add a version field for backward compatibility

### Research Directions

1. **Adaptive Feature Selection**: Dynamically select relevant features per market regime
2. **Hierarchical Features**: Group related features for better model interpretability
3. **Temporal Attention**: Weight recent data more heavily than historical data
4. **Cross-Asset Features**: Include correlations with other asset classes

---

## Conclusion

`BaseDataInput` is the cornerstone of the multi-modal trading system, providing:

- ✅ **Consistency**: All models use the same input format
- ✅ **Extensibility**: Easy to add new features without breaking existing code
- ✅ **Performance**: Fixed-size feature vectors enable efficient computation
- ✅ **Quality**: Validation ensures data integrity
- ✅ **Flexibility**: Supports multiple timeframes, order book data, and cross-model feeding

**All new models MUST use BaseDataInput** to ensure system-wide consistency and maintainability.

---

## References

- **Implementation**: `core/data_models.py`
- **Data Provider**: `core/standardized_data_provider.py`
- **Model Example**: `NN/models/standardized_cnn.py`
- **Training**: `core/unified_training_manager_v2.py`
- **FIFO Queue System**: `docs/fifo_queue_system.md`