"""
Standardized Data Models for Multi-Modal Trading System
This module defines the standardized data structures used across all models:
- BaseDataInput: Unified input format for all models (CNN, RL, LSTM, Transformer)
- ModelOutput: Extensible output format supporting all model types
- COBData: Cumulative Order Book data structure
- Enhanced data structures for cross-model feeding and extensibility
"""
import numpy as np
from datetime import datetime
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
@dataclass
class OHLCVBar:
"""
Enhanced OHLCV bar data structure with technical analysis features
Includes candle pattern recognition, relative sizing, body/wick analysis,
and Williams pivot points metadata for improved model feature engineering.
"""
symbol: str
timestamp: datetime
open: float
high: float
low: float
close: float
volume: float
timeframe: str
indicators: Dict[str, float] = field(default_factory=dict)
# Pivot points metadata
pivot_distance_to_support: Optional[float] = None
pivot_distance_to_resistance: Optional[float] = None
pivot_level_context: Optional[Dict[str, Any]] = field(default=None)
near_pivot_support: bool = False
near_pivot_resistance: bool = False
# Candle characteristics (computed on-demand or cached)
_body_size: Optional[float] = field(default=None, repr=False)
_upper_wick: Optional[float] = field(default=None, repr=False)
_lower_wick: Optional[float] = field(default=None, repr=False)
_total_range: Optional[float] = field(default=None, repr=False)
_is_bullish: Optional[bool] = field(default=None, repr=False)
@property
def body_size(self) -> float:
"""Absolute size of candle body"""
if self._body_size is None:
self._body_size = abs(self.close - self.open)
return self._body_size
@property
def upper_wick(self) -> float:
"""Size of upper wick/shadow"""
if self._upper_wick is None:
self._upper_wick = self.high - max(self.open, self.close)
return self._upper_wick
@property
def lower_wick(self) -> float:
"""Size of lower wick/shadow"""
if self._lower_wick is None:
self._lower_wick = min(self.open, self.close) - self.low
return self._lower_wick
@property
def total_range(self) -> float:
"""Total high-low range"""
if self._total_range is None:
self._total_range = self.high - self.low
return self._total_range
@property
def is_bullish(self) -> bool:
"""True if close > open (hollow/green candle)"""
if self._is_bullish is None:
self._is_bullish = self.close > self.open
return self._is_bullish
@property
def is_bearish(self) -> bool:
"""True if close < open (solid/red candle)"""
return not self.is_bullish and self.close != self.open
@property
def is_doji(self) -> bool:
"""True if open ≈ close (doji pattern)"""
return self.body_size < (self.total_range * 0.1) if self.total_range > 0 else True
def get_body_to_range_ratio(self) -> float:
"""Body size as percentage of total range (0.0 to 1.0)"""
return self.body_size / self.total_range if self.total_range > 0 else 0.0
def get_upper_wick_ratio(self) -> float:
"""Upper wick as percentage of total range (0.0 to 1.0)"""
return self.upper_wick / self.total_range if self.total_range > 0 else 0.0
def get_lower_wick_ratio(self) -> float:
"""Lower wick as percentage of total range (0.0 to 1.0)"""
return self.lower_wick / self.total_range if self.total_range > 0 else 0.0
def get_relative_size(self, reference_bars: List['OHLCVBar'], method: str = 'avg') -> float:
"""
Get relative size compared to reference bars
Args:
reference_bars: List of previous bars for comparison
method: 'avg' (average), 'max' (maximum), or 'median'
Returns:
Ratio of current range to reference (1.0 = same size, >1.0 = larger, <1.0 = smaller)
"""
if not reference_bars:
return 1.0
reference_ranges = [bar.total_range for bar in reference_bars if bar.total_range > 0]
if not reference_ranges:
return 1.0
if method == 'avg':
reference_value = np.mean(reference_ranges)
elif method == 'max':
reference_value = np.max(reference_ranges)
elif method == 'median':
reference_value = np.median(reference_ranges)
else:
reference_value = np.mean(reference_ranges)
return self.total_range / reference_value if reference_value > 0 else 1.0
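    # Worked example (illustrative numbers, not real market data): if this
    # bar's high-low range is 12.0 and the three reference bars have ranges
    # [8.0, 10.0, 12.0], then method='avg' compares against their mean 10.0
    # and returns 12.0 / 10.0 = 1.2, i.e. this bar is 20% larger than recent
    # bars; method='max' would return 12.0 / 12.0 = 1.0.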
def get_candle_pattern(self) -> str:
"""
Identify basic candle pattern
Returns:
Pattern name: 'doji', 'hammer', 'shooting_star', 'spinning_top',
'marubozu_bullish', 'marubozu_bearish', 'standard'
"""
if self.total_range == 0:
return 'doji'
body_ratio = self.get_body_to_range_ratio()
upper_ratio = self.get_upper_wick_ratio()
lower_ratio = self.get_lower_wick_ratio()
# Doji: very small body
if body_ratio < 0.1:
return 'doji'
        # Marubozu: body fills >90% of the range (negligible wicks)
if body_ratio > 0.9:
return 'marubozu_bullish' if self.is_bullish else 'marubozu_bearish'
# Hammer: small body at top, long lower wick
if body_ratio < 0.3 and lower_ratio > 0.6 and upper_ratio < 0.1:
return 'hammer'
# Shooting star: small body at bottom, long upper wick
if body_ratio < 0.3 and upper_ratio > 0.6 and lower_ratio < 0.1:
return 'shooting_star'
# Spinning top: small body, both wicks present
if body_ratio < 0.3 and (upper_ratio + lower_ratio) > 0.6:
return 'spinning_top'
return 'standard'
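    # Worked example (illustrative numbers): a bar with open=99.0, close=100.0,
    # high=100.2, low=94.0 has body=1.0 and range=6.2, so body_to_range_ratio
    # ~= 0.16 (< 0.3), lower_wick_ratio = 5.0 / 6.2 ~= 0.81 (> 0.6) and
    # upper_wick_ratio ~= 0.03 (< 0.1) -- it classifies as 'hammer'.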
def get_ta_features(self, reference_bars: Optional[List['OHLCVBar']] = None) -> Dict[str, float]:
"""
Get all technical analysis features as a dictionary
Args:
reference_bars: Optional list of previous bars for relative sizing
Returns:
Dictionary of TA features suitable for model input
"""
features = {
# Basic candle properties
'is_bullish': 1.0 if self.is_bullish else 0.0,
'is_bearish': 1.0 if self.is_bearish else 0.0,
'is_doji': 1.0 if self.is_doji else 0.0,
# Size ratios
'body_to_range_ratio': self.get_body_to_range_ratio(),
'upper_wick_ratio': self.get_upper_wick_ratio(),
'lower_wick_ratio': self.get_lower_wick_ratio(),
# Absolute sizes (normalized by close price)
'body_size_pct': self.body_size / self.close if self.close > 0 else 0.0,
'upper_wick_pct': self.upper_wick / self.close if self.close > 0 else 0.0,
'lower_wick_pct': self.lower_wick / self.close if self.close > 0 else 0.0,
'total_range_pct': self.total_range / self.close if self.close > 0 else 0.0,
# Volume relative to price movement
'volume_per_range': self.volume / self.total_range if self.total_range > 0 else 0.0,
}
# Add relative sizing if reference bars provided
if reference_bars:
features['relative_size_avg'] = self.get_relative_size(reference_bars, 'avg')
features['relative_size_max'] = self.get_relative_size(reference_bars, 'max')
features['relative_size_median'] = self.get_relative_size(reference_bars, 'median')
# Add pattern encoding (one-hot style)
pattern = self.get_candle_pattern()
pattern_types = ['doji', 'hammer', 'shooting_star', 'spinning_top',
'marubozu_bullish', 'marubozu_bearish', 'standard']
for p in pattern_types:
features[f'pattern_{p}'] = 1.0 if pattern == p else 0.0
return features
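# Usage sketch for OHLCVBar TA features (all values are illustrative, not
# real market data):
#
#     bar = OHLCVBar(symbol='ETH/USDT', timestamp=datetime.now(),
#                    open=100.0, high=110.0, low=95.0, close=108.0,
#                    volume=1234.0, timeframe='1m')
#     bar.get_candle_pattern()        # -> 'standard' (body is 8/15 of range)
#     feats = bar.get_ta_features()   # dict of floats incl. pattern one-hots
#     feats['is_bullish']             # -> 1.0 (close > open)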
@dataclass
class PivotPoint:
"""Pivot point data structure"""
symbol: str
timestamp: datetime
price: float
type: str # 'high' or 'low'
level: int # Pivot level (1, 2, 3, etc.)
confidence: float = 1.0
@dataclass
class ModelOutput:
"""Extensible model output format supporting all model types"""
model_type: str # 'cnn', 'rl', 'lstm', 'transformer', 'orchestrator'
model_name: str # Specific model identifier
symbol: str
timestamp: datetime
confidence: float
predictions: Dict[str, Any] # Model-specific predictions
hidden_states: Optional[Dict[str, Any]] = None # For cross-model feeding
metadata: Dict[str, Any] = field(default_factory=dict) # Additional info
@dataclass
class COBData:
"""Cumulative Order Book data for price buckets"""
symbol: str
timestamp: datetime
current_price: float
bucket_size: float # $1 for ETH, $10 for BTC
price_buckets: Dict[float, Dict[str, float]] # price -> {bid_volume, ask_volume, etc.}
bid_ask_imbalance: Dict[float, float] # price -> imbalance ratio
volume_weighted_prices: Dict[float, float] # price -> VWAP within bucket
order_flow_metrics: Dict[str, float] # Various order flow indicators
# Moving averages of COB imbalance for ±5 buckets
ma_1s_imbalance: Dict[float, float] = field(default_factory=dict) # 1s MA
ma_5s_imbalance: Dict[float, float] = field(default_factory=dict) # 5s MA
ma_15s_imbalance: Dict[float, float] = field(default_factory=dict) # 15s MA
ma_60s_imbalance: Dict[float, float] = field(default_factory=dict) # 60s MA
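# Layout sketch for COBData.price_buckets (illustrative values; with ETH at
# ~3000 and bucket_size=1.0 the keys are whole-dollar bucket prices):
#
#     price_buckets = {
#         2999.0: {'bid_volume': 12000.0, 'ask_volume': 3000.0,
#                  'total_volume': 15000.0, 'imbalance': 0.6},
#         3000.0: {'bid_volume': 8000.0, 'ask_volume': 9000.0,
#                  'total_volume': 17000.0, 'imbalance': -0.06},
#         ...
#     }
#
# imbalance here is assumed to be (bid - ask) / (bid + ask) in [-1, 1];
# the exact convention is set by the producer of this structure.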
@dataclass
class NormalizationBounds:
"""Normalization boundaries for price and volume data"""
price_min: float
price_max: float
volume_min: float
volume_max: float
symbol: str
timeframe: str = 'all' # 'all' means across all timeframes
def normalize_price(self, price: float) -> float:
"""Normalize price to 0-1 range"""
if self.price_max == self.price_min:
return 0.5
return (price - self.price_min) / (self.price_max - self.price_min)
def denormalize_price(self, normalized: float) -> float:
"""Denormalize price from 0-1 range back to original"""
return normalized * (self.price_max - self.price_min) + self.price_min
def normalize_volume(self, volume: float) -> float:
"""Normalize volume to 0-1 range"""
if self.volume_max == self.volume_min:
return 0.5
return (volume - self.volume_min) / (self.volume_max - self.volume_min)
def denormalize_volume(self, normalized: float) -> float:
"""Denormalize volume from 0-1 range back to original"""
return normalized * (self.volume_max - self.volume_min) + self.volume_min
def get_price_range(self) -> float:
"""Get price range"""
return self.price_max - self.price_min
def get_volume_range(self) -> float:
"""Get volume range"""
return self.volume_max - self.volume_min
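# Round-trip example (illustrative bounds): min-max scaling maps a price p
# to (p - price_min) / (price_max - price_min), so with price_min=2000 and
# price_max=4000:
#
#     bounds = NormalizationBounds(price_min=2000.0, price_max=4000.0,
#                                  volume_min=0.0, volume_max=1e6,
#                                  symbol='ETH/USDT')
#     bounds.normalize_price(3000.0)   # -> 0.5
#     bounds.denormalize_price(0.5)    # -> 3000.0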
@dataclass
class BaseDataInput:
"""
Unified base data input for all models
Standardized format ensures all models receive identical input structure:
- OHLCV: 300 frames of (1s, 1m, 1h, 1d) ETH + 300s of 1s BTC
- COB: ±20 buckets of COB amounts in USD for each 1s OHLCV
    - MA: 1s, 5s, 15s, and 60s moving averages of COB imbalance over the ±5 buckets nearest the current price
- All OHLCV data is normalized to 0-1 range based on daily (longest timeframe) min/max
"""
symbol: str # Primary symbol (ETH/USDT)
timestamp: datetime
# Multi-timeframe OHLCV data for primary symbol (ETH)
ohlcv_1s: List[OHLCVBar] = field(default_factory=list) # 300 frames of 1s data
ohlcv_1m: List[OHLCVBar] = field(default_factory=list) # 300 frames of 1m data
ohlcv_1h: List[OHLCVBar] = field(default_factory=list) # 300 frames of 1h data
ohlcv_1d: List[OHLCVBar] = field(default_factory=list) # 300 frames of 1d data
# Reference symbol (BTC) 1s data
btc_ohlcv_1s: List[OHLCVBar] = field(default_factory=list) # 300s of 1s BTC data
# COB data for 1s timeframe (±20 buckets around current price)
cob_data: Optional[COBData] = None
# COB heatmap (time-series of bucket metrics at 1s resolution)
# Each row corresponds to one second, columns to price buckets
cob_heatmap_times: List[datetime] = field(default_factory=list)
cob_heatmap_prices: List[float] = field(default_factory=list)
cob_heatmap_values: List[List[float]] = field(default_factory=list) # typically imbalance per bucket
# Technical indicators
technical_indicators: Dict[str, float] = field(default_factory=dict)
# Pivot points from Williams Market Structure
pivot_points: List[PivotPoint] = field(default_factory=list)
# Last predictions from all models (for cross-model feeding)
last_predictions: Dict[str, ModelOutput] = field(default_factory=dict)
# Market microstructure data
market_microstructure: Dict[str, Any] = field(default_factory=dict)
# Position and trading state information
position_info: Dict[str, Any] = field(default_factory=dict)
# Normalization boundaries (computed on-demand, cached)
_normalization_bounds: Optional[NormalizationBounds] = field(default=None, repr=False)
_btc_normalization_bounds: Optional[NormalizationBounds] = field(default=None, repr=False)
def _compute_normalization_bounds(self) -> NormalizationBounds:
"""
Compute normalization bounds from daily (longest timeframe) data
Uses daily data as it has the widest price range, ensuring all shorter
timeframes are normalized within 0-1 range.
Returns:
NormalizationBounds: Min/max for price and volume
"""
if self._normalization_bounds is not None:
return self._normalization_bounds
# Collect all OHLCV data, prioritizing daily for widest range
all_prices = []
all_volumes = []
# Use daily data first (widest range)
for bar in self.ohlcv_1d:
all_prices.extend([bar.open, bar.high, bar.low, bar.close])
all_volumes.append(bar.volume)
# Add other timeframes to ensure coverage
for ohlcv_list in [self.ohlcv_1h, self.ohlcv_1m, self.ohlcv_1s]:
for bar in ohlcv_list:
all_prices.extend([bar.open, bar.high, bar.low, bar.close])
all_volumes.append(bar.volume)
# Compute bounds
if all_prices and all_volumes:
price_min = min(all_prices)
price_max = max(all_prices)
volume_min = min(all_volumes)
volume_max = max(all_volumes)
else:
# Fallback if no data
price_min = price_max = 0.0
volume_min = volume_max = 0.0
self._normalization_bounds = NormalizationBounds(
price_min=price_min,
price_max=price_max,
volume_min=volume_min,
volume_max=volume_max,
symbol=self.symbol,
timeframe='all'
)
return self._normalization_bounds
def _compute_btc_normalization_bounds(self) -> NormalizationBounds:
"""
Compute normalization bounds for BTC data
Returns:
NormalizationBounds: Min/max for BTC price and volume
"""
if self._btc_normalization_bounds is not None:
return self._btc_normalization_bounds
all_prices = []
all_volumes = []
for bar in self.btc_ohlcv_1s:
all_prices.extend([bar.open, bar.high, bar.low, bar.close])
all_volumes.append(bar.volume)
if all_prices and all_volumes:
price_min = min(all_prices)
price_max = max(all_prices)
volume_min = min(all_volumes)
volume_max = max(all_volumes)
else:
price_min = price_max = 0.0
volume_min = volume_max = 0.0
self._btc_normalization_bounds = NormalizationBounds(
price_min=price_min,
price_max=price_max,
volume_min=volume_min,
volume_max=volume_max,
symbol='BTC/USDT',
timeframe='1s'
)
return self._btc_normalization_bounds
def get_normalization_bounds(self) -> NormalizationBounds:
"""Get normalization bounds for primary symbol (cached)"""
return self._compute_normalization_bounds()
def get_btc_normalization_bounds(self) -> NormalizationBounds:
"""Get normalization bounds for BTC (cached)"""
return self._compute_btc_normalization_bounds()
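    # Sketch of a typical consumer: a model trained on normalized features
    # emits prices in the 0-1 range, and the cached bounds recover real
    # prices (the names base_input and normalized_pred are illustrative):
    #
    #     bounds = base_input.get_normalization_bounds()
    #     predicted_price = bounds.denormalize_price(normalized_pred)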
def get_feature_vector(self, include_candle_ta: bool = True, normalize: bool = True) -> np.ndarray:
"""
Convert BaseDataInput to standardized feature vector for models
Args:
include_candle_ta: If True, include enhanced candle TA features (default: True)
normalize: If True, normalize OHLCV data to 0-1 range (default: True)
Returns:
np.ndarray: FIXED SIZE standardized feature vector (7870 or 22880 features)
Note:
- Full TA features are enabled by default for better model performance
- Normalization uses daily (longest timeframe) min/max for primary symbol
- BTC data is normalized independently using its own min/max
- Normalization bounds are cached and accessible via get_normalization_bounds()
- Includes pivot points metadata (10 features) for market structure context
"""
        # FIXED FEATURE SIZE - this must NEVER change at runtime
        # Standard: 7870 = 7500 OHLCV (5 streams x 300 frames x 5 features)
        #   + 200 COB + 110 indicators + 10 pivot + 45 predictions + 5 position
        # With candle TA: 22880 = 22500 OHLCV (5 streams x 300 x 15 features)
        #   + the same 370 non-OHLCV features + 10 zero padding
        FIXED_FEATURE_SIZE = 22880 if include_candle_ta else 7870
features = []
# Get normalization bounds (cached)
if normalize:
norm_bounds = self._compute_normalization_bounds()
# OHLCV features for ETH (up to 300 frames x 4 timeframes x 5 or 15 features)
for ohlcv_list in [self.ohlcv_1s, self.ohlcv_1m, self.ohlcv_1h, self.ohlcv_1d]:
            # Use actual data only, up to the last 300 frames
            # (a plain slice already handles lists shorter than 300)
            ohlcv_frames = ohlcv_list[-300:]
# Extract features from actual frames
for i, bar in enumerate(ohlcv_frames):
# Basic OHLCV (5 features) - normalized to 0-1 range
if normalize:
features.extend([
norm_bounds.normalize_price(bar.open),
norm_bounds.normalize_price(bar.high),
norm_bounds.normalize_price(bar.low),
norm_bounds.normalize_price(bar.close),
norm_bounds.normalize_volume(bar.volume)
])
else:
features.extend([bar.open, bar.high, bar.low, bar.close, bar.volume])
# Enhanced candle TA features (10 additional features per bar)
if include_candle_ta:
# Get reference bars for relative sizing (last 10 bars)
ref_start = max(0, i - 10)
reference_bars = ohlcv_frames[ref_start:i] if i > 0 else []
ta_features = bar.get_ta_features(reference_bars)
# Extract key features in fixed order
features.extend([
ta_features.get('is_bullish', 0.0),
ta_features.get('body_to_range_ratio', 0.0),
ta_features.get('upper_wick_ratio', 0.0),
ta_features.get('lower_wick_ratio', 0.0),
ta_features.get('body_size_pct', 0.0),
ta_features.get('total_range_pct', 0.0),
ta_features.get('relative_size_avg', 1.0),
ta_features.get('pattern_doji', 0.0),
ta_features.get('pattern_hammer', 0.0),
ta_features.get('pattern_shooting_star', 0.0),
])
# Pad with zeros only if we have some data but less than 300 frames
frames_needed = 300 - len(ohlcv_frames)
if frames_needed > 0:
features_per_frame = 15 if include_candle_ta else 5
features.extend([0.0] * (frames_needed * features_per_frame))
# BTC OHLCV features (up to 300 frames x 5 or 15 features)
        btc_frames = self.btc_ohlcv_1s[-300:]  # last 300 frames; slice handles shorter lists
# Get BTC normalization bounds (cached, independent from primary symbol)
if normalize:
btc_norm_bounds = self._compute_btc_normalization_bounds()
# Extract features from actual BTC frames
for i, bar in enumerate(btc_frames):
# Basic OHLCV (5 features) - normalized to 0-1 range
if normalize:
features.extend([
btc_norm_bounds.normalize_price(bar.open),
btc_norm_bounds.normalize_price(bar.high),
btc_norm_bounds.normalize_price(bar.low),
btc_norm_bounds.normalize_price(bar.close),
btc_norm_bounds.normalize_volume(bar.volume)
])
else:
features.extend([bar.open, bar.high, bar.low, bar.close, bar.volume])
# Enhanced candle TA features (10 additional features per bar)
if include_candle_ta:
ref_start = max(0, i - 10)
reference_bars = btc_frames[ref_start:i] if i > 0 else []
ta_features = bar.get_ta_features(reference_bars)
features.extend([
ta_features.get('is_bullish', 0.0),
ta_features.get('body_to_range_ratio', 0.0),
ta_features.get('upper_wick_ratio', 0.0),
ta_features.get('lower_wick_ratio', 0.0),
ta_features.get('body_size_pct', 0.0),
ta_features.get('total_range_pct', 0.0),
ta_features.get('relative_size_avg', 1.0),
ta_features.get('pattern_doji', 0.0),
ta_features.get('pattern_hammer', 0.0),
ta_features.get('pattern_shooting_star', 0.0),
])
# Pad with zeros only if we have some data but less than 300 frames
btc_frames_needed = 300 - len(btc_frames)
if btc_frames_needed > 0:
features_per_frame = 15 if include_candle_ta else 5
features.extend([0.0] * (btc_frames_needed * features_per_frame))
# COB features (FIXED SIZE: 200 features)
cob_features = []
if self.cob_data:
# Price bucket features (up to 40 buckets x 4 metrics = 160 features)
price_keys = sorted(self.cob_data.price_buckets.keys())[:40] # Max 40 buckets
for price in price_keys:
bucket_data = self.cob_data.price_buckets[price]
cob_features.extend([
bucket_data.get('bid_volume', 0.0),
bucket_data.get('ask_volume', 0.0),
bucket_data.get('total_volume', 0.0),
bucket_data.get('imbalance', 0.0)
])
# Moving averages (up to 10 features)
ma_features = []
for ma_dict in [self.cob_data.ma_1s_imbalance, self.cob_data.ma_5s_imbalance]:
                for price in sorted(ma_dict.keys())[:5]:  # lowest 5 buckets per MA (sort before slicing for determinism)
ma_features.append(ma_dict[price])
if len(ma_features) >= 10:
break
if len(ma_features) >= 10:
break
cob_features.extend(ma_features)
# Add REAL aggregated COB heatmap features to fill remaining COB slots (no synthetic data)
# We compute per-bucket means over the most recent window (up to 300s) and a few global stats
try:
if self.cob_heatmap_values and self.cob_heatmap_prices:
z = np.array(self.cob_heatmap_values, dtype=float)
if z.ndim == 2 and z.size > 0:
                    # Use up to the last 300 seconds (slice handles shorter windows)
                    window_rows = z[-300:]
# Replace NaNs with 0.0 to respect the no-synthetic rule but avoid NaN propagation
window_rows = np.nan_to_num(window_rows, nan=0.0, posinf=0.0, neginf=0.0)
# Per-bucket mean imbalance/liquidity across time
per_bucket_mean = window_rows.mean(axis=0).tolist()
space_left = 200 - len(cob_features)
if space_left > 0 and len(per_bucket_mean) > 0:
cob_features.extend(per_bucket_mean[:space_left])
# If there is still space, add compact global stats over the window
space_left = 200 - len(cob_features)
if space_left > 0:
flat = window_rows.reshape(-1)
if flat.size > 0:
global_mean = float(np.mean(flat))
global_std = float(np.std(flat))
global_max = float(np.max(flat))
global_min = float(np.min(flat))
global_stats = [global_mean, global_std, global_max, global_min]
cob_features.extend(global_stats[:space_left])
except Exception:
# On any error, skip heatmap-derived features (remaining space will be zero-padded below)
pass
# Pad COB features to exactly 200
cob_features.extend([0.0] * (200 - len(cob_features)))
features.extend(cob_features[:200]) # Ensure exactly 200 COB features
# Technical indicators (FIXED SIZE: 110 features - expanded to accommodate more indicators)
indicator_values = list(self.technical_indicators.values())
features.extend(indicator_values[:110]) # Take first 110 indicators
features.extend([0.0] * max(0, 110 - len(indicator_values))) # Pad to exactly 110
# Pivot points metadata (FIXED SIZE: 10 features)
# Extract pivot context from most recent OHLCV bars
pivot_features = []
        if self.ohlcv_1m:
latest_bar = self.ohlcv_1m[-1]
pivot_features.extend([
latest_bar.pivot_distance_to_support if latest_bar.pivot_distance_to_support is not None else 0.0,
latest_bar.pivot_distance_to_resistance if latest_bar.pivot_distance_to_resistance is not None else 0.0,
1.0 if latest_bar.near_pivot_support else 0.0,
1.0 if latest_bar.near_pivot_resistance else 0.0,
])
# Add pivot level context if available
if latest_bar.pivot_level_context:
ctx = latest_bar.pivot_level_context
pivot_features.extend([
ctx.get('trend_strength', 0.0),
ctx.get('support_count', 0.0),
ctx.get('resistance_count', 0.0),
ctx.get('price_position_in_range', 0.5), # 0=at support, 1=at resistance
ctx.get('distance_to_nearest_level', 0.0),
ctx.get('level_strength', 0.0),
])
else:
pivot_features.extend([0.0] * 6)
else:
pivot_features = [0.0] * 10
# Ensure exactly 10 pivot features
pivot_features = pivot_features[:10]
pivot_features.extend([0.0] * (10 - len(pivot_features)))
features.extend(pivot_features)
# Last predictions from other models (FIXED SIZE: 45 features)
prediction_features = []
for model_output in self.last_predictions.values():
prediction_features.extend([
model_output.confidence,
model_output.predictions.get('buy_probability', 0.0),
model_output.predictions.get('sell_probability', 0.0),
model_output.predictions.get('hold_probability', 0.0),
model_output.predictions.get('expected_reward', 0.0)
])
features.extend(prediction_features[:45]) # Take first 45 prediction features
features.extend([0.0] * max(0, 45 - len(prediction_features))) # Pad to exactly 45
# Position and trading state information (FIXED SIZE: 5 features)
position_features = [
1.0 if self.position_info.get('has_position', False) else 0.0,
self.position_info.get('position_pnl', 0.0),
self.position_info.get('position_size', 0.0),
self.position_info.get('entry_price', 0.0),
self.position_info.get('time_in_position_minutes', 0.0)
]
features.extend(position_features) # Exactly 5 position features
# CRITICAL: Ensure EXACTLY the fixed feature size
if len(features) > FIXED_FEATURE_SIZE:
features = features[:FIXED_FEATURE_SIZE] # Truncate if too long
elif len(features) < FIXED_FEATURE_SIZE:
features.extend([0.0] * (FIXED_FEATURE_SIZE - len(features))) # Pad if too short
assert len(features) == FIXED_FEATURE_SIZE, f"Feature vector size mismatch: {len(features)} != {FIXED_FEATURE_SIZE}"
return np.array(features, dtype=np.float32)
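    # Feature vector layout in standard mode (7870 features), derived from
    # the fixed sizes above:
    #   [0:7500)     OHLCV  - 5 streams (ETH 1s/1m/1h/1d + BTC 1s) x 300 x 5
    #   [7500:7700)  COB    - 200 features
    #   [7700:7810)  indicators - 110 features
    #   [7810:7820)  pivot points - 10 features
    #   [7820:7865)  model predictions - 45 features
    #   [7865:7870)  position state - 5 features
    # In candle-TA mode each frame carries 15 features (22500 OHLCV total)
    # and the final padding step fills up to 22880.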
def validate(self) -> bool:
"""
Validate that the BaseDataInput contains required data
Returns:
bool: True if valid, False otherwise
"""
# Check that we have required OHLCV data
if len(self.ohlcv_1s) < 100: # At least 100 frames
return False
if len(self.btc_ohlcv_1s) < 100: # At least 100 frames of BTC data
return False
# Check that timestamps are reasonable
if not self.timestamp:
return False
# Check symbol format
if not self.symbol or '/' not in self.symbol:
return False
return True
@dataclass
class TradingAction:
"""Trading action output from models"""
symbol: str
timestamp: datetime
action: str # 'BUY', 'SELL', 'HOLD'
confidence: float
source: str # 'rl', 'cnn', 'orchestrator'
price: Optional[float] = None
quantity: Optional[float] = None
reason: Optional[str] = None
def create_model_output(model_type: str, model_name: str, symbol: str,
action: str, confidence: float,
hidden_states: Optional[Dict[str, Any]] = None,
metadata: Optional[Dict[str, Any]] = None) -> ModelOutput:
"""
Helper function to create standardized ModelOutput
Args:
model_type: Type of model ('cnn', 'rl', 'lstm', 'transformer', 'orchestrator')
model_name: Specific model identifier
symbol: Trading symbol
action: Trading action ('BUY', 'SELL', 'HOLD')
confidence: Confidence score (0.0 to 1.0)
hidden_states: Optional hidden states for cross-model feeding
metadata: Optional additional metadata
Returns:
ModelOutput: Standardized model output
"""
predictions = {
'action': action,
'buy_probability': confidence if action == 'BUY' else 0.0,
'sell_probability': confidence if action == 'SELL' else 0.0,
'hold_probability': confidence if action == 'HOLD' else 0.0,
}
return ModelOutput(
model_type=model_type,
model_name=model_name,
symbol=symbol,
timestamp=datetime.now(),
confidence=confidence,
predictions=predictions,
hidden_states=hidden_states or {},
metadata=metadata or {}
)
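if __name__ == "__main__":
    # Minimal smoke-test sketch (synthetic values, illustration only):
    # build a ModelOutput via the helper and inspect its probability encoding.
    output = create_model_output(
        model_type='cnn',
        model_name='example_cnn_v1',  # hypothetical model identifier
        symbol='ETH/USDT',
        action='BUY',
        confidence=0.72,
    )
    print(output.predictions)
    # -> {'action': 'BUY', 'buy_probability': 0.72,
    #     'sell_probability': 0.0, 'hold_probability': 0.0}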