"""
Standardized Data Models for Multi-Modal Trading System

This module defines the standardized data structures used across all models:
- BaseDataInput: Unified input format for all models (CNN, RL, LSTM, Transformer)
- ModelOutput: Extensible output format supporting all model types
- COBData: Cumulative Order Book data structure
- Enhanced data structures for cross-model feeding and extensibility
"""

import logging
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional

import numpy as np

logger = logging.getLogger(__name__)

# Pattern names returned by OHLCVBar.get_candle_pattern(), in the order used
# for the one-hot ``pattern_*`` entries of OHLCVBar.get_ta_features().
_CANDLE_PATTERN_TYPES = (
    'doji', 'hammer', 'shooting_star', 'spinning_top',
    'marubozu_bullish', 'marubozu_bearish', 'standard',
)


@dataclass
class OHLCVBar:
    """
    Enhanced OHLCV bar data structure with technical analysis features.

    Includes candle pattern recognition, relative sizing, body/wick analysis,
    and Williams pivot points metadata for improved model feature engineering.

    Derived candle measurements (body, wicks, range, direction) are computed
    on first access and cached in the ``_``-prefixed fields. The cache is never
    invalidated, so do not mutate ``open``/``high``/``low``/``close`` after
    reading a derived property. The cache fields are excluded from equality
    so that two bars with the same data compare equal regardless of which
    properties have already been read.
    """

    symbol: str
    timestamp: datetime
    open: float
    high: float
    low: float
    close: float
    volume: float
    timeframe: str
    indicators: Dict[str, float] = field(default_factory=dict)

    # Pivot points metadata
    pivot_distance_to_support: Optional[float] = None
    pivot_distance_to_resistance: Optional[float] = None
    pivot_level_context: Optional[Dict[str, Any]] = field(default=None)
    near_pivot_support: bool = False
    near_pivot_resistance: bool = False

    # Candle characteristics (computed on demand, then cached).
    # compare=False: cache state must not influence __eq__.
    _body_size: Optional[float] = field(default=None, repr=False, compare=False)
    _upper_wick: Optional[float] = field(default=None, repr=False, compare=False)
    _lower_wick: Optional[float] = field(default=None, repr=False, compare=False)
    _total_range: Optional[float] = field(default=None, repr=False, compare=False)
    _is_bullish: Optional[bool] = field(default=None, repr=False, compare=False)

    @property
    def body_size(self) -> float:
        """Absolute size of the candle body, ``abs(close - open)`` (cached)."""
        if self._body_size is None:
            self._body_size = abs(self.close - self.open)
        return self._body_size

    @property
    def upper_wick(self) -> float:
        """Size of the upper wick/shadow, ``high - max(open, close)`` (cached)."""
        if self._upper_wick is None:
            self._upper_wick = self.high - max(self.open, self.close)
        return self._upper_wick

    @property
    def lower_wick(self) -> float:
        """Size of the lower wick/shadow, ``min(open, close) - low`` (cached)."""
        if self._lower_wick is None:
            self._lower_wick = min(self.open, self.close) - self.low
        return self._lower_wick

    @property
    def total_range(self) -> float:
        """Total high-low range (cached)."""
        if self._total_range is None:
            self._total_range = self.high - self.low
        return self._total_range

    @property
    def is_bullish(self) -> bool:
        """True if close > open (hollow/green candle); cached."""
        if self._is_bullish is None:
            self._is_bullish = self.close > self.open
        return self._is_bullish

    @property
    def is_bearish(self) -> bool:
        """True if close < open (solid/red candle)."""
        return not self.is_bullish and self.close != self.open

    @property
    def is_doji(self) -> bool:
        """True if the body is under 10% of the range, or the bar has zero range."""
        return self.body_size < (self.total_range * 0.1) if self.total_range > 0 else True

    def get_body_to_range_ratio(self) -> float:
        """Body size as a fraction of total range (0.0 when the range is zero)."""
        return self.body_size / self.total_range if self.total_range > 0 else 0.0

    def get_upper_wick_ratio(self) -> float:
        """Upper wick as a fraction of total range (0.0 when the range is zero)."""
        return self.upper_wick / self.total_range if self.total_range > 0 else 0.0

    def get_lower_wick_ratio(self) -> float:
        """Lower wick as a fraction of total range (0.0 when the range is zero)."""
        return self.lower_wick / self.total_range if self.total_range > 0 else 0.0

    def get_relative_size(self, reference_bars: List['OHLCVBar'], method: str = 'avg') -> float:
        """
        Get this bar's range relative to a set of reference bars.

        Args:
            reference_bars: Bars to compare against (zero-range bars are ignored).
            method: 'avg' (mean), 'max', or 'median'. Any other value uses the mean.

        Returns:
            Ratio of this bar's range to the reference range (1.0 = same size,
            >1.0 = larger, <1.0 = smaller). Returns 1.0 when there is no usable
            reference data.
        """
        if not reference_bars:
            return 1.0

        reference_ranges = [bar.total_range for bar in reference_bars if bar.total_range > 0]
        if not reference_ranges:
            return 1.0

        if method == 'max':
            reference_value = np.max(reference_ranges)
        elif method == 'median':
            reference_value = np.median(reference_ranges)
        else:
            # 'avg' and any unrecognized method fall back to the mean.
            reference_value = np.mean(reference_ranges)

        return self.total_range / reference_value if reference_value > 0 else 1.0

    def get_candle_pattern(self) -> str:
        """
        Identify a basic candle pattern from body/wick proportions.

        Returns:
            One of 'doji', 'hammer', 'shooting_star', 'spinning_top',
            'marubozu_bullish', 'marubozu_bearish', 'standard'.
        """
        if self.total_range == 0:
            return 'doji'

        body_ratio = self.get_body_to_range_ratio()
        upper_ratio = self.get_upper_wick_ratio()
        lower_ratio = self.get_lower_wick_ratio()

        # Doji: very small body
        if body_ratio < 0.1:
            return 'doji'

        # Marubozu: very small wicks (>90% body)
        if body_ratio > 0.9:
            return 'marubozu_bullish' if self.is_bullish else 'marubozu_bearish'

        # Hammer: small body at top, long lower wick
        if body_ratio < 0.3 and lower_ratio > 0.6 and upper_ratio < 0.1:
            return 'hammer'

        # Shooting star: small body at bottom, long upper wick
        if body_ratio < 0.3 and upper_ratio > 0.6 and lower_ratio < 0.1:
            return 'shooting_star'

        # Spinning top: small body, both wicks present
        if body_ratio < 0.3 and (upper_ratio + lower_ratio) > 0.6:
            return 'spinning_top'

        return 'standard'

    def get_ta_features(self, reference_bars: Optional[List['OHLCVBar']] = None) -> Dict[str, float]:
        """
        Get all technical analysis features as a dictionary.

        Args:
            reference_bars: Optional previous bars. When non-empty, the
                ``relative_size_*`` keys are added.

        Returns:
            Dictionary of TA features suitable for model input, including a
            one-hot ``pattern_<name>`` entry for every candle pattern type.
        """
        close_ok = self.close > 0
        features = {
            # Basic candle properties
            'is_bullish': 1.0 if self.is_bullish else 0.0,
            'is_bearish': 1.0 if self.is_bearish else 0.0,
            'is_doji': 1.0 if self.is_doji else 0.0,

            # Size ratios
            'body_to_range_ratio': self.get_body_to_range_ratio(),
            'upper_wick_ratio': self.get_upper_wick_ratio(),
            'lower_wick_ratio': self.get_lower_wick_ratio(),

            # Absolute sizes (normalized by close price)
            'body_size_pct': self.body_size / self.close if close_ok else 0.0,
            'upper_wick_pct': self.upper_wick / self.close if close_ok else 0.0,
            'lower_wick_pct': self.lower_wick / self.close if close_ok else 0.0,
            'total_range_pct': self.total_range / self.close if close_ok else 0.0,

            # Volume relative to price movement
            'volume_per_range': self.volume / self.total_range if self.total_range > 0 else 0.0,
        }

        # Add relative sizing if reference bars provided
        if reference_bars:
            features['relative_size_avg'] = self.get_relative_size(reference_bars, 'avg')
            features['relative_size_max'] = self.get_relative_size(reference_bars, 'max')
            features['relative_size_median'] = self.get_relative_size(reference_bars, 'median')

        # Add pattern encoding (one-hot style)
        pattern = self.get_candle_pattern()
        for p in _CANDLE_PATTERN_TYPES:
            features[f'pattern_{p}'] = 1.0 if pattern == p else 0.0

        return features


@dataclass
class PivotPoint:
    """Pivot point data structure."""
    symbol: str
    timestamp: datetime
    price: float
    type: str  # 'high' or 'low'
    level: int  # Pivot level (1, 2, 3, etc.)
    confidence: float = 1.0


@dataclass
class ModelOutput:
    """Extensible model output format supporting all model types."""
    model_type: str  # 'cnn', 'rl', 'lstm', 'transformer', 'orchestrator'
    model_name: str  # Specific model identifier
    symbol: str
    timestamp: datetime
    confidence: float
    predictions: Dict[str, Any]  # Model-specific predictions
    hidden_states: Optional[Dict[str, Any]] = None  # For cross-model feeding
    metadata: Dict[str, Any] = field(default_factory=dict)  # Additional info


@dataclass
class COBData:
    """Cumulative Order Book data for price buckets."""
    symbol: str
    timestamp: datetime
    current_price: float
    bucket_size: float  # $1 for ETH, $10 for BTC
    price_buckets: Dict[float, Dict[str, float]]  # price -> {bid_volume, ask_volume, etc.}
    bid_ask_imbalance: Dict[float, float]  # price -> imbalance ratio
    volume_weighted_prices: Dict[float, float]  # price -> VWAP within bucket
    order_flow_metrics: Dict[str, float]  # Various order flow indicators

    # Moving averages of COB imbalance for ±5 buckets
    ma_1s_imbalance: Dict[float, float] = field(default_factory=dict)  # 1s MA
    ma_5s_imbalance: Dict[float, float] = field(default_factory=dict)  # 5s MA
    ma_15s_imbalance: Dict[float, float] = field(default_factory=dict)  # 15s MA
    ma_60s_imbalance: Dict[float, float] = field(default_factory=dict)  # 60s MA


@dataclass
class NormalizationBounds:
    """Min/max boundaries used to linearly scale price and volume data."""
    price_min: float
    price_max: float
    volume_min: float
    volume_max: float
    symbol: str
    timeframe: str = 'all'  # 'all' means across all timeframes

    def normalize_price(self, price: float) -> float:
        """Scale price so price_min -> 0 and price_max -> 1 (0.5 if the range is degenerate)."""
        if self.price_max == self.price_min:
            return 0.5
        return (price - self.price_min) / (self.price_max - self.price_min)

    def denormalize_price(self, normalized: float) -> float:
        """Map a normalized price back to the original scale."""
        return normalized * (self.price_max - self.price_min) + self.price_min

    def normalize_volume(self, volume: float) -> float:
        """Scale volume so volume_min -> 0 and volume_max -> 1 (0.5 if the range is degenerate)."""
        if self.volume_max == self.volume_min:
            return 0.5
        return (volume - self.volume_min) / (self.volume_max - self.volume_min)

    def denormalize_volume(self, normalized: float) -> float:
        """Map a normalized volume back to the original scale."""
        return normalized * (self.volume_max - self.volume_min) + self.volume_min

    def get_price_range(self) -> float:
        """Return ``price_max - price_min``."""
        return self.price_max - self.price_min

    def get_volume_range(self) -> float:
        """Return ``volume_max - volume_min``."""
        return self.volume_max - self.volume_min


# Number of most-recent bars encoded per OHLCV series in the feature vector.
_OHLCV_WINDOW = 300

# Candle TA values appended after each bar's OHLCV when include_candle_ta=True:
# (key in OHLCVBar.get_ta_features(), default used when the key is absent).
_CANDLE_TA_FEATURE_KEYS = (
    ('is_bullish', 0.0),
    ('body_to_range_ratio', 0.0),
    ('upper_wick_ratio', 0.0),
    ('lower_wick_ratio', 0.0),
    ('body_size_pct', 0.0),
    ('total_range_pct', 0.0),
    ('relative_size_avg', 1.0),  # key only present when reference bars exist
    ('pattern_doji', 0.0),
    ('pattern_hammer', 0.0),
    ('pattern_shooting_star', 0.0),
)


@dataclass
class BaseDataInput:
    """
    Unified base data input for all models.

    Standardized format ensures all models receive identical input structure:
    - OHLCV: 300 frames of (1s, 1m, 1h, 1d) ETH + 300s of 1s BTC
    - COB: ±20 buckets of COB amounts in USD for each 1s OHLCV
    - MA: 1s, 5s, 15s, and 60s MA of COB imbalance counting ±5 COB buckets
      (get_feature_vector currently encodes only the 1s and 5s MAs)
    - Primary-symbol OHLCV is normalized to 0-1 using the min/max over all of
      its timeframes; BTC uses its own independent min/max.

    Normalization bounds are computed on first use and cached for the lifetime
    of the instance; they are not recomputed if the OHLCV lists change later.
    """
    symbol: str  # Primary symbol (ETH/USDT)
    timestamp: datetime

    # Multi-timeframe OHLCV data for primary symbol (ETH)
    ohlcv_1s: List[OHLCVBar] = field(default_factory=list)  # 300 frames of 1s data
    ohlcv_1m: List[OHLCVBar] = field(default_factory=list)  # 300 frames of 1m data
    ohlcv_1h: List[OHLCVBar] = field(default_factory=list)  # 300 frames of 1h data
    ohlcv_1d: List[OHLCVBar] = field(default_factory=list)  # 300 frames of 1d data

    # Reference symbol (BTC) 1s data
    btc_ohlcv_1s: List[OHLCVBar] = field(default_factory=list)  # 300s of 1s BTC data

    # COB data for 1s timeframe (±20 buckets around current price)
    cob_data: Optional[COBData] = None

    # COB heatmap (time-series of bucket metrics at 1s resolution).
    # Each row corresponds to one second, columns to price buckets.
    cob_heatmap_times: List[datetime] = field(default_factory=list)
    cob_heatmap_prices: List[float] = field(default_factory=list)
    cob_heatmap_values: List[List[float]] = field(default_factory=list)  # typically imbalance per bucket

    # Technical indicators
    technical_indicators: Dict[str, float] = field(default_factory=dict)

    # Pivot points from Williams Market Structure
    pivot_points: List[PivotPoint] = field(default_factory=list)

    # Last predictions from all models (for cross-model feeding)
    last_predictions: Dict[str, ModelOutput] = field(default_factory=dict)

    # Market microstructure data
    market_microstructure: Dict[str, Any] = field(default_factory=dict)

    # Position and trading state information
    position_info: Dict[str, Any] = field(default_factory=dict)

    # Normalization boundaries (computed on demand, cached).
    # compare=False: cache state must not influence __eq__.
    _normalization_bounds: Optional[NormalizationBounds] = field(default=None, repr=False, compare=False)
    _btc_normalization_bounds: Optional[NormalizationBounds] = field(default=None, repr=False, compare=False)

    @staticmethod
    def _bounds_from_bars(symbol: str, timeframe: str, *bar_lists: List[OHLCVBar]) -> NormalizationBounds:
        """Build min/max bounds over the OHLC prices and volumes of all given bars."""
        prices: List[float] = []
        volumes: List[float] = []
        for bars in bar_lists:
            for bar in bars:
                prices.extend((bar.open, bar.high, bar.low, bar.close))
                volumes.append(bar.volume)

        if not prices:
            # No data: degenerate bounds, so normalize_* returns 0.5.
            return NormalizationBounds(
                price_min=0.0, price_max=0.0, volume_min=0.0, volume_max=0.0,
                symbol=symbol, timeframe=timeframe,
            )
        return NormalizationBounds(
            price_min=min(prices),
            price_max=max(prices),
            volume_min=min(volumes),
            volume_max=max(volumes),
            symbol=symbol,
            timeframe=timeframe,
        )

    def _compute_normalization_bounds(self) -> NormalizationBounds:
        """
        Compute (and cache) normalization bounds for the primary symbol.

        Uses the min/max across the 1d, 1h, 1m and 1s bars, so every
        timeframe of the primary symbol maps into the 0-1 range.

        Returns:
            NormalizationBounds: Min/max for price and volume.
        """
        if self._normalization_bounds is None:
            self._normalization_bounds = self._bounds_from_bars(
                self.symbol, 'all',
                self.ohlcv_1d, self.ohlcv_1h, self.ohlcv_1m, self.ohlcv_1s,
            )
        return self._normalization_bounds

    def _compute_btc_normalization_bounds(self) -> NormalizationBounds:
        """
        Compute (and cache) normalization bounds for the BTC reference data.

        Returns:
            NormalizationBounds: Min/max for BTC price and volume.
        """
        if self._btc_normalization_bounds is None:
            self._btc_normalization_bounds = self._bounds_from_bars('BTC/USDT', '1s', self.btc_ohlcv_1s)
        return self._btc_normalization_bounds

    def get_normalization_bounds(self) -> NormalizationBounds:
        """Get normalization bounds for primary symbol (cached)."""
        return self._compute_normalization_bounds()

    def get_btc_normalization_bounds(self) -> NormalizationBounds:
        """Get normalization bounds for BTC (cached)."""
        return self._compute_btc_normalization_bounds()

    @staticmethod
    def _ohlcv_series_features(bars: List[OHLCVBar],
                               bounds: Optional[NormalizationBounds],
                               include_candle_ta: bool) -> List[float]:
        """Encode the last 300 bars (5 or 15 values each), zero-padded to 300 bars."""
        frames = bars[-_OHLCV_WINDOW:]
        per_frame = 5 + (len(_CANDLE_TA_FEATURE_KEYS) if include_candle_ta else 0)
        out: List[float] = []

        for i, bar in enumerate(frames):
            if bounds is not None:
                out.extend([
                    bounds.normalize_price(bar.open),
                    bounds.normalize_price(bar.high),
                    bounds.normalize_price(bar.low),
                    bounds.normalize_price(bar.close),
                    bounds.normalize_volume(bar.volume),
                ])
            else:
                out.extend([bar.open, bar.high, bar.low, bar.close, bar.volume])

            if include_candle_ta:
                # Relative size is measured against up to 10 preceding bars in the window.
                reference_bars = frames[max(0, i - 10):i]
                ta = bar.get_ta_features(reference_bars)
                out.extend(ta.get(key, default) for key, default in _CANDLE_TA_FEATURE_KEYS)

        # Missing frames (including a completely empty series) are zero-filled.
        out.extend([0.0] * ((_OHLCV_WINDOW - len(frames)) * per_frame))
        return out

    def _cob_heatmap_features(self, space_left: int) -> List[float]:
        """
        Summarize the COB heatmap into at most ``space_left`` values.

        Emits per-bucket means over the last 300 rows, then (if room remains)
        global mean/std/max/min. Returns [] when there is no usable heatmap.
        Best-effort: malformed data is skipped (and logged), never raised.
        """
        if space_left <= 0 or not self.cob_heatmap_values or not self.cob_heatmap_prices:
            return []
        try:
            z = np.array(self.cob_heatmap_values, dtype=float)
            if z.ndim != 2 or z.size == 0:
                return []
            # NaN/inf -> 0.0 so a single bad cell cannot poison the aggregates.
            window = np.nan_to_num(z[-_OHLCV_WINDOW:], nan=0.0, posinf=0.0, neginf=0.0)

            out = window.mean(axis=0).tolist()[:space_left]
            remaining = space_left - len(out)
            if remaining > 0:
                flat = window.reshape(-1)
                stats = [float(np.mean(flat)), float(np.std(flat)),
                         float(np.max(flat)), float(np.min(flat))]
                out.extend(stats[:remaining])
            return out
        except Exception as exc:  # deliberate best-effort: heatmap is optional
            logger.debug("Skipping COB heatmap features: %s", exc)
            return []

    def _cob_features(self) -> List[float]:
        """Encode order-book data as exactly 200 values."""
        features: List[float] = []
        if self.cob_data is not None:
            buckets = self.cob_data.price_buckets
            # Up to 40 lowest-priced buckets x 4 metrics = 160 values.
            for price in sorted(buckets)[:40]:
                bucket = buckets[price]
                features.extend([
                    bucket.get('bid_volume', 0.0),
                    bucket.get('ask_volume', 0.0),
                    bucket.get('total_volume', 0.0),
                    bucket.get('imbalance', 0.0),
                ])

            # Up to 5 values from each of the 1s and 5s MAs (max 10).
            for ma_dict in (self.cob_data.ma_1s_imbalance, self.cob_data.ma_5s_imbalance):
                # NOTE(review): takes the first 5 keys in insertion order and only
                # then sorts them; kept as-is so the feature layout is unchanged.
                features.extend(ma_dict[price] for price in sorted(list(ma_dict)[:5]))

        # Fill the remaining COB slots with real heatmap aggregates (no synthetic data).
        features.extend(self._cob_heatmap_features(200 - len(features)))

        features = features[:200]
        features.extend([0.0] * (200 - len(features)))
        return features

    def _indicator_features(self) -> List[float]:
        """Return the first 110 technical-indicator values, zero-padded to 110."""
        values = list(self.technical_indicators.values())[:110]
        values.extend([0.0] * (110 - len(values)))
        return values

    def _pivot_features(self) -> List[float]:
        """Return 10 pivot values taken from the most recent 1m bar (zeros if none)."""
        if not self.ohlcv_1m:
            return [0.0] * 10

        latest_bar = self.ohlcv_1m[-1]
        features = [
            latest_bar.pivot_distance_to_support if latest_bar.pivot_distance_to_support is not None else 0.0,
            latest_bar.pivot_distance_to_resistance if latest_bar.pivot_distance_to_resistance is not None else 0.0,
            1.0 if latest_bar.near_pivot_support else 0.0,
            1.0 if latest_bar.near_pivot_resistance else 0.0,
        ]

        ctx = latest_bar.pivot_level_context
        if ctx:
            features.extend([
                ctx.get('trend_strength', 0.0),
                ctx.get('support_count', 0.0),
                ctx.get('resistance_count', 0.0),
                ctx.get('price_position_in_range', 0.5),  # 0=at support, 1=at resistance
                ctx.get('distance_to_nearest_level', 0.0),
                ctx.get('level_strength', 0.0),
            ])
        else:
            features.extend([0.0] * 6)
        return features

    def _prediction_features(self) -> List[float]:
        """Return 5 values per stored model prediction, truncated/zero-padded to 45."""
        values: List[float] = []
        for model_output in self.last_predictions.values():
            preds = model_output.predictions
            values.extend([
                model_output.confidence,
                preds.get('buy_probability', 0.0),
                preds.get('sell_probability', 0.0),
                preds.get('hold_probability', 0.0),
                preds.get('expected_reward', 0.0),
            ])
        values = values[:45]
        values.extend([0.0] * (45 - len(values)))
        return values

    def _position_features(self) -> List[float]:
        """Return exactly 5 position/trading-state values."""
        info = self.position_info
        return [
            1.0 if info.get('has_position', False) else 0.0,
            info.get('position_pnl', 0.0),
            info.get('position_size', 0.0),
            info.get('entry_price', 0.0),
            info.get('time_in_position_minutes', 0.0),
        ]

    def get_feature_vector(self, include_candle_ta: bool = True, normalize: bool = True) -> np.ndarray:
        """
        Convert BaseDataInput to a standardized, fixed-size feature vector.

        Args:
            include_candle_ta: If True, append 10 candle TA values after each
                bar's 5 OHLCV values (15 per bar instead of 5).
            normalize: If True, scale OHLCV to 0-1. The primary symbol uses the
                min/max over all its timeframes; BTC uses its own min/max.
                Bounds are cached (see get_normalization_bounds()).

        Returns:
            np.ndarray (float32) of 22880 values (with candle TA) or 7870 (without).

        Layout, with P = 15 (TA) or 5 values per bar:
            - 4 x 300 x P  primary-symbol 1s, 1m, 1h, 1d bars (each zero-padded)
            - 300 x P      BTC 1s bars (zero-padded)
            - 200          COB buckets, 1s/5s MAs, heatmap aggregates
            - 110          technical indicators
            - 10           pivot metadata from the latest 1m bar
            - 45           last predictions (up to 9 models x 5)
            - 5            position state
        This fills 7870 (P=5) or 22870 (P=15) slots. In TA mode the vector is
        zero-padded to 22880, so its last 10 entries are always 0.0; the size
        is intentionally left at 22880 so the output shape is unchanged.
        """
        # FIXED FEATURE SIZE - this should NEVER change at runtime
        FIXED_FEATURE_SIZE = 22880 if include_candle_ta else 7870

        features: List[float] = []

        eth_bounds = self._compute_normalization_bounds() if normalize else None
        for bars in (self.ohlcv_1s, self.ohlcv_1m, self.ohlcv_1h, self.ohlcv_1d):
            features.extend(self._ohlcv_series_features(bars, eth_bounds, include_candle_ta))

        btc_bounds = self._compute_btc_normalization_bounds() if normalize else None
        features.extend(self._ohlcv_series_features(self.btc_ohlcv_1s, btc_bounds, include_candle_ta))

        features.extend(self._cob_features())
        features.extend(self._indicator_features())
        features.extend(self._pivot_features())
        features.extend(self._prediction_features())
        features.extend(self._position_features())

        # CRITICAL: Ensure EXACTLY the fixed feature size
        if len(features) > FIXED_FEATURE_SIZE:
            del features[FIXED_FEATURE_SIZE:]
        else:
            features.extend([0.0] * (FIXED_FEATURE_SIZE - len(features)))

        assert len(features) == FIXED_FEATURE_SIZE, f"Feature vector size mismatch: {len(features)} != {FIXED_FEATURE_SIZE}"
        return np.array(features, dtype=np.float32)

    def validate(self) -> bool:
        """
        Validate that the BaseDataInput contains required data.

        Returns:
            bool: True if there are at least 100 primary and 100 BTC 1s bars,
            a timestamp, and a symbol containing '/'.
        """
        # Check that we have required OHLCV data
        if len(self.ohlcv_1s) < 100:  # At least 100 frames
            return False

        if len(self.btc_ohlcv_1s) < 100:  # At least 100 frames of BTC data
            return False

        # Check that timestamps are reasonable
        if not self.timestamp:
            return False

        # Check symbol format
        if not self.symbol or '/' not in self.symbol:
            return False

        return True


@dataclass
class TradingAction:
    """Trading action output from models."""
    symbol: str
    timestamp: datetime
    action: str  # 'BUY', 'SELL', 'HOLD'
    confidence: float
    source: str  # 'rl', 'cnn', 'orchestrator'
    price: Optional[float] = None
    quantity: Optional[float] = None
    reason: Optional[str] = None


def create_model_output(model_type: str, model_name: str, symbol: str,
                        action: str, confidence: float,
                        hidden_states: Optional[Dict[str, Any]] = None,
                        metadata: Optional[Dict[str, Any]] = None) -> ModelOutput:
    """
    Helper function to create standardized ModelOutput.

    Args:
        model_type: Type of model ('cnn', 'rl', 'lstm', 'transformer', 'orchestrator')
        model_name: Specific model identifier
        symbol: Trading symbol
        action: Trading action ('BUY', 'SELL', 'HOLD')
        confidence: Confidence score (0.0 to 1.0)
        hidden_states: Optional hidden states for cross-model feeding (None -> {})
        metadata: Optional additional metadata (None -> {})

    Returns:
        ModelOutput: Standardized model output. The probability matching
        ``action`` is set to ``confidence``; the other two are 0.0. The
        timestamp is the local ``datetime.now()``.
    """
    predictions = {
        'action': action,
        'buy_probability': confidence if action == 'BUY' else 0.0,
        'sell_probability': confidence if action == 'SELL' else 0.0,
        'hold_probability': confidence if action == 'HOLD' else 0.0,
    }

    return ModelOutput(
        model_type=model_type,
        model_name=model_name,
        symbol=symbol,
        timestamp=datetime.now(),
        confidence=confidence,
        predictions=predictions,
        hidden_states=hidden_states or {},
        metadata=metadata or {}
    )