""" Enhanced Trading Orchestrator Central coordination hub for the multi-modal trading system that manages: - Data subscription and management - Model inference coordination - Cross-model data feeding - Training pipeline orchestration - Decision making using Mixture of Experts """ import asyncio import logging import numpy as np from datetime import datetime from typing import Dict, List, Optional, Any from dataclasses import dataclass, field from core.data_provider import DataProvider from core.trading_action import TradingAction from utils.tensorboard_logger import TensorBoardLogger logger = logging.getLogger(__name__) @dataclass class ModelOutput: """Extensible model output format supporting all model types""" model_type: str # 'cnn', 'rl', 'lstm', 'transformer', 'orchestrator' model_name: str # Specific model identifier symbol: str timestamp: datetime confidence: float predictions: Dict[str, Any] # Model-specific predictions hidden_states: Optional[Dict[str, Any]] = None # For cross-model feeding metadata: Dict[str, Any] = field(default_factory=dict) # Additional info @dataclass class BaseDataInput: """Unified base data input for all models""" symbol: str timestamp: datetime ohlcv_data: Dict[str, Any] = field(default_factory=dict) # Multi-timeframe OHLCV cob_data: Optional[Dict[str, Any]] = None # COB buckets for 1s timeframe technical_indicators: Dict[str, float] = field(default_factory=dict) pivot_points: List[Any] = field(default_factory=list) last_predictions: Dict[str, ModelOutput] = field(default_factory=dict) # From all models market_microstructure: Dict[str, Any] = field(default_factory=dict) # Order flow, etc. @dataclass class COBData: """Cumulative Order Book data for price buckets""" symbol: str timestamp: datetime current_price: float bucket_size: float # $1 for ETH, $10 for BTC price_buckets: Dict[float, Dict[str, float]] = field(default_factory=dict) # price -> {bid_volume, ask_volume, etc.} bid_ask_imbalance: Dict[float, float] = field(default_factory=dict) # price -> imbalance ratio volume_weighted_prices: Dict[float, float] = field(default_factory=dict) # price -> VWAP within bucket order_flow_metrics: Dict[str, float] = field(default_factory=dict) # Various order flow indicators class EnhancedTradingOrchestrator: """ Enhanced Trading Orchestrator implementing the design specification Coordinates data flow, model inference, and decision making for the multi-modal trading system. 
""" def __init__(self, data_provider: DataProvider, symbols: List[str], enhanced_rl_training: bool = False, model_registry: Dict = None): """Initialize the enhanced orchestrator""" self.data_provider = data_provider self.symbols = symbols self.enhanced_rl_training = enhanced_rl_training self.model_registry = model_registry or {} # Data management self.data_buffers = {symbol: {} for symbol in symbols} self.last_update_times = {symbol: {} for symbol in symbols} # Model output storage self.model_outputs = {symbol: {} for symbol in symbols} self.model_output_history = {symbol: {} for symbol in symbols} # Training pipeline self.training_data = {symbol: [] for symbol in symbols} self.tensorboard_logger = TensorBoardLogger("runs", f"orchestrator_{datetime.now().strftime('%Y%m%d_%H%M%S')}") # COB integration self.cob_data = {symbol: None for symbol in symbols} # Performance tracking self.performance_metrics = { 'inference_count': 0, 'successful_states': 0, 'total_episodes': 0 } logger.info("Enhanced Trading Orchestrator initialized") async def start_cob_integration(self): """Start COB data integration for real-time market microstructure""" try: # Subscribe to COB data updates self.data_provider.subscribe_to_cob_data(self._on_cob_data_update) logger.info("COB integration started") except Exception as e: logger.error(f"Error starting COB integration: {e}") async def start_realtime_processing(self): """Start real-time data processing""" try: # Subscribe to tick data for real-time processing for symbol in self.symbols: self.data_provider.subscribe_to_ticks( callback=self._on_tick_data, symbols=[symbol], subscriber_name=f"orchestrator_{symbol}" ) logger.info("Real-time processing started") except Exception as e: logger.error(f"Error starting real-time processing: {e}") def _on_cob_data_update(self, symbol: str, cob_data: dict): """Handle COB data updates""" try: # Process and store COB data self.cob_data[symbol] = self._process_cob_data(symbol, cob_data) logger.debug(f"COB data updated for {symbol}") except Exception as e: logger.error(f"Error processing COB data for {symbol}: {e}") def _process_cob_data(self, symbol: str, cob_data: dict) -> COBData: """Process raw COB data into structured format""" try: # Determine bucket size based on symbol bucket_size = 1.0 if 'ETH' in symbol else 10.0 # Extract current price stats = cob_data.get('stats', {}) current_price = stats.get('mid_price', 0) # Create COB data structure cob = COBData( symbol=symbol, timestamp=datetime.now(), current_price=current_price, bucket_size=bucket_size ) # Process order book data into price buckets bids = cob_data.get('bids', []) asks = cob_data.get('asks', []) # Create price buckets around current price bucket_count = 20 # ±20 buckets for i in range(-bucket_count, bucket_count + 1): bucket_price = current_price + (i * bucket_size) cob.price_buckets[bucket_price] = { 'bid_volume': 0.0, 'ask_volume': 0.0 } # Aggregate bid volumes into buckets for price, volume in bids: bucket_price = round(price / bucket_size) * bucket_size if bucket_price in cob.price_buckets: cob.price_buckets[bucket_price]['bid_volume'] += volume # Aggregate ask volumes into buckets for price, volume in asks: bucket_price = round(price / bucket_size) * bucket_size if bucket_price in cob.price_buckets: cob.price_buckets[bucket_price]['ask_volume'] += volume # Calculate bid/ask imbalances for price, volumes in cob.price_buckets.items(): bid_vol = volumes['bid_volume'] ask_vol = volumes['ask_volume'] total_vol = bid_vol + ask_vol if total_vol > 0: 
    def _process_cob_data(self, symbol: str, cob_data: dict) -> COBData:
        """Process raw COB data into structured format"""
        # Determine bucket size based on symbol ($1 for ETH, $10 for BTC);
        # computed before the try block so the except handler can use it
        bucket_size = 1.0 if 'ETH' in symbol else 10.0
        try:
            # Extract current price
            stats = cob_data.get('stats', {})
            current_price = stats.get('mid_price', 0)

            # Create COB data structure
            cob = COBData(
                symbol=symbol,
                timestamp=datetime.now(),
                current_price=current_price,
                bucket_size=bucket_size
            )

            # Process order book data into price buckets
            bids = cob_data.get('bids', [])
            asks = cob_data.get('asks', [])

            # Create price buckets around the current price. Snap the grid to
            # multiples of bucket_size so the keys match the rounding used for
            # aggregation below; keying the grid off the raw mid price would
            # miss every bucket.
            bucket_count = 20  # +/- 20 buckets
            base_price = round(current_price / bucket_size) * bucket_size
            for i in range(-bucket_count, bucket_count + 1):
                bucket_price = base_price + (i * bucket_size)
                cob.price_buckets[bucket_price] = {
                    'bid_volume': 0.0,
                    'ask_volume': 0.0
                }

            # Aggregate bid volumes into buckets
            for price, volume in bids:
                bucket_price = round(price / bucket_size) * bucket_size
                if bucket_price in cob.price_buckets:
                    cob.price_buckets[bucket_price]['bid_volume'] += volume

            # Aggregate ask volumes into buckets
            for price, volume in asks:
                bucket_price = round(price / bucket_size) * bucket_size
                if bucket_price in cob.price_buckets:
                    cob.price_buckets[bucket_price]['ask_volume'] += volume

            # Calculate bid/ask imbalances and volume-weighted prices per bucket
            for price, volumes in cob.price_buckets.items():
                bid_vol = volumes['bid_volume']
                ask_vol = volumes['ask_volume']
                total_vol = bid_vol + ask_vol
                if total_vol > 0:
                    cob.bid_ask_imbalance[price] = (bid_vol - ask_vol) / total_vol
                    # With only per-bucket aggregates this reduces to the bucket
                    # price itself; kept for a future per-level computation
                    cob.volume_weighted_prices[price] = (
                        (price * bid_vol) + (price * ask_vol)
                    ) / total_vol
                else:
                    cob.bid_ask_imbalance[price] = 0.0
                    cob.volume_weighted_prices[price] = price

            # Calculate order flow metrics; totals are computed first because
            # the ratio depends on them (the dict cannot reference its own
            # keys while it is being built)
            total_bid_volume = sum(v['bid_volume'] for v in cob.price_buckets.values())
            total_ask_volume = sum(v['ask_volume'] for v in cob.price_buckets.values())
            cob.order_flow_metrics = {
                'total_bid_volume': total_bid_volume,
                'total_ask_volume': total_ask_volume,
                'bid_ask_ratio': (
                    total_bid_volume / total_ask_volume if total_ask_volume > 0 else 0.0
                )
            }

            return cob
        except Exception as e:
            logger.error(f"Error processing COB data for {symbol}: {e}")
            return COBData(symbol=symbol, timestamp=datetime.now(),
                           current_price=0, bucket_size=bucket_size)

    def _on_tick_data(self, tick):
        """Handle incoming tick data"""
        try:
            # Update data buffers
            symbol = tick.symbol
            if symbol not in self.data_buffers:
                self.data_buffers[symbol] = {}

            # Store tick data, keeping only the last 1000 ticks
            if 'ticks' not in self.data_buffers[symbol]:
                self.data_buffers[symbol]['ticks'] = []
            self.data_buffers[symbol]['ticks'].append(tick)
            if len(self.data_buffers[symbol]['ticks']) > 1000:
                self.data_buffers[symbol]['ticks'] = self.data_buffers[symbol]['ticks'][-1000:]

            # Update last update time
            self.last_update_times[symbol]['tick'] = datetime.now()
            logger.debug(f"Tick data updated for {symbol}")
        except Exception as e:
            logger.error(f"Error processing tick data: {e}")
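
    # Design layout of the 13,400-feature RL state assembled below. The slot
    # counts are the spec targets; sections are packed sequentially, so indices
    # shift left when an earlier section is short, and unfilled slots stay zero:
    #   OHLCV            4,000 slots (100 x 1s candles, 40 slots each)
    #   COB              8,000 slots (200 price buckets, 40 slots each)
    #   indicators       1,000 slots (folded into the per-candle slots above)
    #   microstructure     400 slots (order-flow metrics)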
    def build_comprehensive_rl_state(self, symbol: str) -> Optional[np.ndarray]:
        """
        Build comprehensive RL state with 13,400 features as specified.

        Returns:
            np.ndarray: State vector with 13,400 features
        """
        try:
            # Initialize state vector
            state_size = 13400
            state = np.zeros(state_size, dtype=np.float32)

            # Get latest data
            ohlcv_data = self.data_provider.get_latest_candles(symbol, '1s', limit=100)
            cob_data = self.cob_data.get(symbol)

            # Feature index tracking
            idx = 0

            # 1. OHLCV features (4,000 slots)
            if ohlcv_data is not None and not ohlcv_data.empty:
                # Last 100 1s candles, 40 slots each: 5 OHLCV values plus up to
                # 35 indicator slots (10 are currently populated)
                for i in range(min(100, len(ohlcv_data))):
                    if idx + 40 <= state_size:
                        row = ohlcv_data.iloc[-(i + 1)]
                        state[idx] = row.get('open', 0) / 100000  # Normalized
                        state[idx + 1] = row.get('high', 0) / 100000
                        state[idx + 2] = row.get('low', 0) / 100000
                        state[idx + 3] = row.get('close', 0) / 100000
                        state[idx + 4] = row.get('volume', 0) / 1000000

                        # Add technical indicators if available
                        indicator_idx = 5
                        for col in ['sma_10', 'sma_20', 'ema_12', 'ema_26', 'rsi_14',
                                    'macd', 'bb_upper', 'bb_lower', 'atr', 'adx']:
                            if col in row and idx + indicator_idx < state_size:
                                state[idx + indicator_idx] = row[col] / 100000
                                indicator_idx += 1
                        idx += 40

            # 2. COB features (8,000 slots)
            if cob_data and idx + 8000 <= state_size:
                # Up to 200 price buckets, 40 slots each (7 currently populated)
                bucket_prices = sorted(cob_data.price_buckets.keys())
                for i, price in enumerate(bucket_prices[:200]):
                    if idx + 40 <= state_size:
                        bucket = cob_data.price_buckets[price]
                        state[idx] = bucket.get('bid_volume', 0) / 1000000  # Normalized
                        state[idx + 1] = bucket.get('ask_volume', 0) / 1000000
                        state[idx + 2] = cob_data.bid_ask_imbalance.get(price, 0)
                        state[idx + 3] = cob_data.volume_weighted_prices.get(price, price) / 100000

                        # Additional COB metrics
                        state[idx + 4] = cob_data.order_flow_metrics.get('total_bid_volume', 0) / 10000000
                        state[idx + 5] = cob_data.order_flow_metrics.get('total_ask_volume', 0) / 10000000
                        state[idx + 6] = cob_data.order_flow_metrics.get('bid_ask_ratio', 0)
                        idx += 40

            # 3. Technical indicator features (1,000 slots)
            # Already folded into the per-candle slots in the OHLCV section above

            # 4. Market microstructure features (400 slots)
            if cob_data and idx + 400 <= state_size:
                # Add order flow metrics
                metrics = list(cob_data.order_flow_metrics.values())
                for i, metric in enumerate(metrics[:400]):
                    if idx + i < state_size:
                        state[idx + i] = metric

            # Log state building success
            self.performance_metrics['successful_states'] += 1
            logger.debug(f"Comprehensive RL state built for {symbol}: {len(state)} features")

            # Log to TensorBoard
            self.tensorboard_logger.log_state_metrics(
                symbol=symbol,
                state_info={
                    'size': len(state),
                    'quality': 1.0,
                    'feature_counts': {
                        'total': len(state),
                        'non_zero': np.count_nonzero(state)
                    }
                },
                step=self.performance_metrics['successful_states']
            )

            return state
        except Exception as e:
            logger.error(f"Error building comprehensive RL state for {symbol}: {e}")
            return None

    def calculate_enhanced_pivot_reward(self, trade_decision: Dict,
                                        market_data: Dict,
                                        trade_outcome: Dict) -> float:
        """
        Calculate enhanced pivot-based reward.

        Args:
            trade_decision: Trading decision with action and confidence
            market_data: Market context data
            trade_outcome: Actual trade results

        Returns:
            float: Enhanced reward value
        """
        try:
            # Base reward from PnL
            pnl_reward = trade_outcome.get('net_pnl', 0) / 100  # Normalize

            # Confidence weighting
            confidence = trade_decision.get('confidence', 0.5)
            confidence_reward = confidence * 0.2

            # Volatility adjustment (prefer low volatility)
            volatility = market_data.get('volatility', 0.01)
            volatility_reward = (1.0 - volatility * 10) * 0.1

            # Order flow alignment
            order_flow = market_data.get('order_flow_strength', 0)
            order_flow_reward = order_flow * 0.2

            # Pivot alignment bonus: reward buying near support or selling
            # near resistance
            pivot_bonus = 0.0
            if market_data.get('near_pivot', False):
                action = trade_decision.get('action', '').upper()
                pivot_type = market_data.get('pivot_type', '').upper()
                if (action == 'BUY' and pivot_type == 'LOW') or \
                   (action == 'SELL' and pivot_type == 'HIGH'):
                    pivot_bonus = 0.5

            # Calculate final reward
            enhanced_reward = (pnl_reward + confidence_reward + volatility_reward
                               + order_flow_reward + pivot_bonus)

            # Log to TensorBoard
            self.tensorboard_logger.log_scalars('Rewards/Components', {
                'pnl_component': pnl_reward,
                'confidence': confidence_reward,
                'volatility': volatility_reward,
                'order_flow': order_flow_reward,
                'pivot_bonus': pivot_bonus
            }, self.performance_metrics['total_episodes'])
            self.tensorboard_logger.log_scalar('Rewards/Enhanced', enhanced_reward,
                                               self.performance_metrics['total_episodes'])

            logger.debug(f"Enhanced reward calculated: {enhanced_reward}")
            return enhanced_reward
        except Exception as e:
            logger.error(f"Error calculating enhanced pivot reward: {e}")
            return 0.0
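
    # Worked reward example (hypothetical numbers): a BUY near a LOW pivot with
    # net_pnl=+50, confidence=0.8, volatility=0.02, order_flow_strength=0.5:
    #   0.50 (pnl) + 0.16 (confidence) + 0.08 (volatility) + 0.10 (order flow)
    #   + 0.50 (pivot bonus) = 1.34
    # The same trade away from a pivot would score 0.84, so the pivot bonus is
    # the single largest shaping term.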
    async def make_coordinated_decisions(self) -> Dict[str, TradingAction]:
        """
        Make coordinated trading decisions using all available models.

        Returns:
            Dict[str, TradingAction]: Trading actions for each symbol
        """
        try:
            decisions = {}

            # For each symbol, coordinate model inference
            for symbol in self.symbols:
                # Build comprehensive state for the RL model
                rl_state = self.build_comprehensive_rl_state(symbol)
                if rl_state is not None:
                    # Track the episode for training
                    self.performance_metrics['total_episodes'] += 1

                    # Create mock RL decision (a real implementation would call
                    # the RL model here)
                    action = 'BUY' if np.mean(rl_state[:100]) > 0.5 else 'SELL'
                    confidence = min(1.0, max(0.0, np.std(rl_state) * 10))

                    # Create trading action
                    decisions[symbol] = TradingAction(
                        symbol=symbol,
                        timestamp=datetime.now(),
                        action=action,
                        confidence=confidence,
                        source='rl_orchestrator'
                    )
                    logger.info(f"Coordinated decision for {symbol}: {action} "
                                f"(confidence: {confidence:.3f})")
                else:
                    logger.warning(f"Failed to build state for {symbol}, skipping decision")

            self.performance_metrics['inference_count'] += 1
            return decisions
        except Exception as e:
            logger.error(f"Error making coordinated decisions: {e}")
            return {}

    def _get_symbol_correlation(self, symbol1: str, symbol2: str) -> float:
        """
        Calculate the price correlation between two symbols.

        Args:
            symbol1: First symbol
            symbol2: Second symbol

        Returns:
            float: Correlation coefficient (-1 to 1)
        """
        try:
            # Get recent price data for both symbols
            data1 = self.data_provider.get_latest_candles(symbol1, '1m', limit=50)
            data2 = self.data_provider.get_latest_candles(symbol2, '1m', limit=50)

            if data1 is None or data2 is None or data1.empty or data2.empty:
                return 0.0

            # Align data by timestamp (index join)
            merged = data1[['close']].join(data2[['close']],
                                           lsuffix='_1', rsuffix='_2', how='inner')
            if len(merged) < 10:
                return 0.0

            # Calculate correlation
            correlation = merged['close_1'].corr(merged['close_2'])
            return correlation if not np.isnan(correlation) else 0.0
        except Exception as e:
            logger.error(f"Error calculating symbol correlation: {e}")
            return 0.0
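
# Minimal usage sketch (hypothetical wiring; 'MyDataProvider' is a placeholder
# for a concrete DataProvider implementation, and the 1-second poll interval is
# an assumption, not part of the spec):
#
#     provider = MyDataProvider()
#     orchestrator = EnhancedTradingOrchestrator(provider, ['ETH/USDT', 'BTC/USDT'])
#
#     async def run():
#         await orchestrator.start_cob_integration()
#         await orchestrator.start_realtime_processing()
#         while True:
#             decisions = await orchestrator.make_coordinated_decisions()
#             await asyncio.sleep(1)
#
#     asyncio.run(run())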