"""
|
|
Enhanced Trading Orchestrator
|
|
|
|
Central coordination hub for the multi-modal trading system that manages:
|
|
- Data subscription and management
|
|
- Model inference coordination
|
|
- Cross-model data feeding
|
|
- Training pipeline orchestration
|
|
- Decision making using Mixture of Experts
|
|
"""
|
|
|
|
import asyncio
import logging
import numpy as np
from datetime import datetime
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field

from core.data_provider import DataProvider
from core.trading_action import TradingAction
from utils.tensorboard_logger import TensorBoardLogger

logger = logging.getLogger(__name__)

@dataclass
class ModelOutput:
    """Extensible model output format supporting all model types"""
    model_type: str  # 'cnn', 'rl', 'lstm', 'transformer', 'orchestrator'
    model_name: str  # Specific model identifier
    symbol: str
    timestamp: datetime
    confidence: float
    predictions: Dict[str, Any]  # Model-specific predictions
    hidden_states: Optional[Dict[str, Any]] = None  # For cross-model feeding
    metadata: Dict[str, Any] = field(default_factory=dict)  # Additional info

@dataclass
class BaseDataInput:
    """Unified base data input for all models"""
    symbol: str
    timestamp: datetime
    ohlcv_data: Dict[str, Any] = field(default_factory=dict)  # Multi-timeframe OHLCV
    cob_data: Optional[Dict[str, Any]] = None  # COB buckets for 1s timeframe
    technical_indicators: Dict[str, float] = field(default_factory=dict)
    pivot_points: List[Any] = field(default_factory=list)
    last_predictions: Dict[str, ModelOutput] = field(default_factory=dict)  # From all models
    market_microstructure: Dict[str, Any] = field(default_factory=dict)  # Order flow, etc.

@dataclass
class COBData:
    """Cumulative Order Book data for price buckets"""
    symbol: str
    timestamp: datetime
    current_price: float
    bucket_size: float  # $1 buckets for ETH, $10 for BTC
    price_buckets: Dict[float, Dict[str, float]] = field(default_factory=dict)  # price -> {bid_volume, ask_volume, ...}
    bid_ask_imbalance: Dict[float, float] = field(default_factory=dict)  # price -> imbalance ratio
    volume_weighted_prices: Dict[float, float] = field(default_factory=dict)  # price -> VWAP within bucket
    order_flow_metrics: Dict[str, float] = field(default_factory=dict)  # Various order flow indicators

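# Illustration (hypothetical numbers): for ETH with $1 buckets around a $3000
# mid price, price_buckets[3000.0] = {'bid_volume': 12.5, 'ask_volume': 8.0}
# yields bid_ask_imbalance[3000.0] = (12.5 - 8.0) / 20.5 = ~0.22 (bid-heavy).
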
class EnhancedTradingOrchestrator:
    """
    Enhanced Trading Orchestrator implementing the design specification.

    Coordinates data flow, model inference, and decision making for the
    multi-modal trading system.
    """

    def __init__(self, data_provider: DataProvider, symbols: List[str],
                 enhanced_rl_training: bool = False,
                 model_registry: Optional[Dict] = None):
        """Initialize the enhanced orchestrator"""
        self.data_provider = data_provider
        self.symbols = symbols
        self.enhanced_rl_training = enhanced_rl_training
        self.model_registry = model_registry or {}

        # Data management
        self.data_buffers = {symbol: {} for symbol in symbols}
        self.last_update_times = {symbol: {} for symbol in symbols}

        # Model output storage
        self.model_outputs = {symbol: {} for symbol in symbols}
        self.model_output_history = {symbol: {} for symbol in symbols}

        # Training pipeline
        self.training_data = {symbol: [] for symbol in symbols}
        self.tensorboard_logger = TensorBoardLogger(
            "runs", f"orchestrator_{datetime.now().strftime('%Y%m%d_%H%M%S')}")

        # COB integration
        self.cob_data = {symbol: None for symbol in symbols}

        # Performance tracking
        self.performance_metrics = {
            'inference_count': 0,
            'successful_states': 0,
            'total_episodes': 0
        }

        logger.info("Enhanced Trading Orchestrator initialized")

    async def start_cob_integration(self):
        """Start COB data integration for real-time market microstructure"""
        try:
            # Subscribe to COB data updates
            self.data_provider.subscribe_to_cob_data(self._on_cob_data_update)
            logger.info("COB integration started")
        except Exception as e:
            logger.error(f"Error starting COB integration: {e}")

    async def start_realtime_processing(self):
        """Start real-time data processing"""
        try:
            # Subscribe to tick data for real-time processing
            for symbol in self.symbols:
                self.data_provider.subscribe_to_ticks(
                    callback=self._on_tick_data,
                    symbols=[symbol],
                    subscriber_name=f"orchestrator_{symbol}"
                )

            logger.info("Real-time processing started")
        except Exception as e:
            logger.error(f"Error starting real-time processing: {e}")

    def _on_cob_data_update(self, symbol: str, cob_data: dict):
        """Handle COB data updates"""
        try:
            # Process and store COB data
            self.cob_data[symbol] = self._process_cob_data(symbol, cob_data)
            logger.debug(f"COB data updated for {symbol}")
        except Exception as e:
            logger.error(f"Error processing COB data for {symbol}: {e}")

    def _process_cob_data(self, symbol: str, cob_data: dict) -> COBData:
        """Process raw COB data into structured format"""
        # Determine bucket size before the try block so the fallback return
        # below can reference it even if processing fails early
        bucket_size = 1.0 if 'ETH' in symbol else 10.0
        try:
            # Extract current price
            stats = cob_data.get('stats', {})
            current_price = stats.get('mid_price', 0)

            # Create COB data structure
            cob = COBData(
                symbol=symbol,
                timestamp=datetime.now(),
                current_price=current_price,
                bucket_size=bucket_size
            )

            # Process order book data into price buckets
            bids = cob_data.get('bids', [])
            asks = cob_data.get('asks', [])

            # Create price buckets around the current price (+/-20 buckets),
            # snapped to the bucket grid so the rounded keys below can match
            bucket_count = 20
            base_price = round(current_price / bucket_size) * bucket_size
            for i in range(-bucket_count, bucket_count + 1):
                bucket_price = base_price + (i * bucket_size)
                cob.price_buckets[bucket_price] = {
                    'bid_volume': 0.0,
                    'ask_volume': 0.0
                }

            # Aggregate bid volumes into buckets
            for price, volume in bids:
                bucket_price = round(price / bucket_size) * bucket_size
                if bucket_price in cob.price_buckets:
                    cob.price_buckets[bucket_price]['bid_volume'] += volume

            # Aggregate ask volumes into buckets
            for price, volume in asks:
                bucket_price = round(price / bucket_size) * bucket_size
                if bucket_price in cob.price_buckets:
                    cob.price_buckets[bucket_price]['ask_volume'] += volume

            # Calculate bid/ask imbalances and volume-weighted prices
            for price, volumes in cob.price_buckets.items():
                bid_vol = volumes['bid_volume']
                ask_vol = volumes['ask_volume']
                total_vol = bid_vol + ask_vol
                cob.bid_ask_imbalance[price] = (bid_vol - ask_vol) / total_vol if total_vol > 0 else 0.0
                # With a single price per bucket the bucket VWAP degenerates to
                # the bucket price; kept for interface compatibility
                cob.volume_weighted_prices[price] = price

            # Calculate order flow metrics (compute the totals first so the
            # ratio does not reference the dict while it is still being built)
            total_bid_volume = sum(v['bid_volume'] for v in cob.price_buckets.values())
            total_ask_volume = sum(v['ask_volume'] for v in cob.price_buckets.values())
            cob.order_flow_metrics = {
                'total_bid_volume': total_bid_volume,
                'total_ask_volume': total_ask_volume,
                'bid_ask_ratio': total_bid_volume / total_ask_volume if total_ask_volume > 0 else 0.0
            }

            return cob
        except Exception as e:
            logger.error(f"Error processing COB data for {symbol}: {e}")
            return COBData(symbol=symbol, timestamp=datetime.now(), current_price=0, bucket_size=bucket_size)

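    # Example of the bucket math above (hypothetical numbers): with
    # bucket_size=1.0 and mid_price=3000.4, buckets are keyed 2980.0..3020.0;
    # a bid at (2999.6, 5.0) rounds to the 3000.0 bucket and adds 5.0 to its
    # 'bid_volume'.
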
    def _on_tick_data(self, tick):
        """Handle incoming tick data"""
        try:
            # Update data buffers
            symbol = tick.symbol
            if symbol not in self.data_buffers:
                self.data_buffers[symbol] = {}

            # Store tick data
            if 'ticks' not in self.data_buffers[symbol]:
                self.data_buffers[symbol]['ticks'] = []
            self.data_buffers[symbol]['ticks'].append(tick)

            # Keep only the last 1000 ticks
            if len(self.data_buffers[symbol]['ticks']) > 1000:
                self.data_buffers[symbol]['ticks'] = self.data_buffers[symbol]['ticks'][-1000:]

            # Update last update time
            self.last_update_times[symbol]['tick'] = datetime.now()

            logger.debug(f"Tick data updated for {symbol}")
        except Exception as e:
            logger.error(f"Error processing tick data: {e}")

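    # Note: a collections.deque(maxlen=1000) would make the rolling tick buffer
    # O(1) per append with no slicing; the plain list is kept to preserve any
    # downstream code that expects list semantics.
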
    def build_comprehensive_rl_state(self, symbol: str) -> Optional[np.ndarray]:
        """
        Build comprehensive RL state with 13,400 features as specified.

        Feature budget: 4,000 OHLCV + 8,000 COB + 1,000 technical indicators
        (folded into the OHLCV block) + 400 market microstructure = 13,400.

        Returns:
            np.ndarray: State vector with 13,400 features
        """
        try:
            # Initialize state vector
            state_size = 13400
            state = np.zeros(state_size, dtype=np.float32)

            # Get latest data
            ohlcv_data = self.data_provider.get_latest_candles(symbol, '1s', limit=100)
            cob_data = self.cob_data.get(symbol)

            # Feature index tracking
            idx = 0

            # 1. OHLCV features (4,000 features)
            if ohlcv_data is not None and not ohlcv_data.empty:
                # Use the last 100 1s candles, 40 slots each: O,H,L,C,V plus up
                # to 35 indicator slots (10 populated below, the rest zero)
                for i in range(min(100, len(ohlcv_data))):
                    if idx + 40 <= state_size:
                        row = ohlcv_data.iloc[-(i+1)]
                        state[idx] = row.get('open', 0) / 100000  # Normalized
                        state[idx+1] = row.get('high', 0) / 100000
                        state[idx+2] = row.get('low', 0) / 100000
                        state[idx+3] = row.get('close', 0) / 100000
                        state[idx+4] = row.get('volume', 0) / 1000000

                        # Add technical indicators if available; always advance
                        # the slot so each column keeps a fixed position even
                        # when some indicators are missing
                        indicator_idx = 5
                        for col in ['sma_10', 'sma_20', 'ema_12', 'ema_26', 'rsi_14',
                                    'macd', 'bb_upper', 'bb_lower', 'atr', 'adx']:
                            if col in row and idx + indicator_idx < state_size:
                                state[idx + indicator_idx] = row[col] / 100000
                            indicator_idx += 1

                        idx += 40

            # 2. COB features (8,000 features)
            if cob_data and idx + 8000 <= state_size:
                # Use up to 200 price buckets, 40 slots each (the +/-20-bucket
                # grid built in _process_cob_data yields 41 buckets)
                bucket_prices = sorted(cob_data.price_buckets.keys())
                for i, price in enumerate(bucket_prices[:200]):
                    if idx + 40 <= state_size:
                        bucket = cob_data.price_buckets[price]
                        state[idx] = bucket.get('bid_volume', 0) / 1000000  # Normalized
                        state[idx+1] = bucket.get('ask_volume', 0) / 1000000
                        state[idx+2] = cob_data.bid_ask_imbalance.get(price, 0)
                        state[idx+3] = cob_data.volume_weighted_prices.get(price, price) / 100000

                        # Additional COB metrics
                        state[idx+4] = cob_data.order_flow_metrics.get('total_bid_volume', 0) / 10000000
                        state[idx+5] = cob_data.order_flow_metrics.get('total_ask_volume', 0) / 10000000
                        state[idx+6] = cob_data.order_flow_metrics.get('bid_ask_ratio', 0)

                        idx += 40

            # 3. Technical indicator features (1,000 features)
            # Already folded into the OHLCV section above

            # 4. Market microstructure features (400 features)
            if cob_data and idx + 400 <= state_size:
                # Add order flow metrics
                metrics = list(cob_data.order_flow_metrics.values())
                for i, metric in enumerate(metrics[:400]):
                    if idx + i < state_size:
                        state[idx + i] = metric

            # Log state building success
            self.performance_metrics['successful_states'] += 1
            logger.debug(f"Comprehensive RL state built for {symbol}: {len(state)} features")

            # Log to TensorBoard
            self.tensorboard_logger.log_state_metrics(
                symbol=symbol,
                state_info={
                    'size': len(state),
                    'quality': 1.0,
                    'feature_counts': {
                        'total': len(state),
                        'non_zero': np.count_nonzero(state)
                    }
                },
                step=self.performance_metrics['successful_states']
            )

            return state

        except Exception as e:
            logger.error(f"Error building comprehensive RL state for {symbol}: {e}")
            return None

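    # Sanity-check sketch (assumes a wired-up DataProvider; not part of the API):
    #     state = orchestrator.build_comprehensive_rl_state('ETH/USDT')
    #     assert state is None or state.shape == (13400,)
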
    def calculate_enhanced_pivot_reward(self, trade_decision: Dict, market_data: Dict, trade_outcome: Dict) -> float:
        """
        Calculate enhanced pivot-based reward.

        Args:
            trade_decision: Trading decision with action and confidence
            market_data: Market context data
            trade_outcome: Actual trade results

        Returns:
            float: Enhanced reward value
        """
        try:
            # Base reward from PnL
            pnl_reward = trade_outcome.get('net_pnl', 0) / 100  # Normalize

            # Confidence weighting
            confidence = trade_decision.get('confidence', 0.5)
            confidence_reward = confidence * 0.2

            # Volatility adjustment
            volatility = market_data.get('volatility', 0.01)
            volatility_reward = (1.0 - volatility * 10) * 0.1  # Prefer low volatility

            # Order flow alignment
            order_flow = market_data.get('order_flow_strength', 0)
            order_flow_reward = order_flow * 0.2

            # Pivot alignment bonus (if near a pivot in a favorable direction)
            pivot_bonus = 0.0
            if market_data.get('near_pivot', False):
                action = trade_decision.get('action', '').upper()
                pivot_type = market_data.get('pivot_type', '').upper()

                # Bonus for buying near support or selling near resistance
                if (action == 'BUY' and pivot_type == 'LOW') or \
                   (action == 'SELL' and pivot_type == 'HIGH'):
                    pivot_bonus = 0.5

            # Calculate final reward
            enhanced_reward = pnl_reward + confidence_reward + volatility_reward + order_flow_reward + pivot_bonus

            # Log to TensorBoard
            self.tensorboard_logger.log_scalars('Rewards/Components', {
                'pnl_component': pnl_reward,
                'confidence': confidence_reward,
                'volatility': volatility_reward,
                'order_flow': order_flow_reward,
                'pivot_bonus': pivot_bonus
            }, self.performance_metrics['total_episodes'])

            self.tensorboard_logger.log_scalar('Rewards/Enhanced', enhanced_reward, self.performance_metrics['total_episodes'])

            logger.debug(f"Enhanced reward calculated: {enhanced_reward}")
            return enhanced_reward

        except Exception as e:
            logger.error(f"Error calculating enhanced pivot reward: {e}")
            return 0.0

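    # Worked example (hypothetical values): net_pnl=50 -> pnl_reward=0.5;
    # confidence=0.8 -> 0.16; volatility=0.02 -> (1 - 0.2) * 0.1 = 0.08;
    # order_flow_strength=0.5 -> 0.1; BUY near a LOW pivot -> +0.5;
    # enhanced_reward = 0.5 + 0.16 + 0.08 + 0.1 + 0.5 = 1.34
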
    async def make_coordinated_decisions(self) -> Dict[str, TradingAction]:
        """
        Make coordinated trading decisions using all available models.

        Returns:
            Dict[str, TradingAction]: Trading actions for each symbol
        """
        try:
            decisions = {}

            # For each symbol, coordinate model inference
            for symbol in self.symbols:
                # Build comprehensive state for the RL model
                rl_state = self.build_comprehensive_rl_state(symbol)

                if rl_state is not None:
                    # Store state for training
                    self.performance_metrics['total_episodes'] += 1

                    # Create mock RL decision (a real implementation would call
                    # the registered RL model here)
                    action = 'BUY' if np.mean(rl_state[:100]) > 0.5 else 'SELL'
                    confidence = min(1.0, max(0.0, np.std(rl_state) * 10))

                    # Create trading action
                    decisions[symbol] = TradingAction(
                        symbol=symbol,
                        timestamp=datetime.now(),
                        action=action,
                        confidence=confidence,
                        source='rl_orchestrator'
                    )

                    logger.info(f"Coordinated decision for {symbol}: {action} (confidence: {confidence:.3f})")
                else:
                    logger.warning(f"Failed to build state for {symbol}, skipping decision")

            self.performance_metrics['inference_count'] += 1
            return decisions

        except Exception as e:
            logger.error(f"Error making coordinated decisions: {e}")
            return {}

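    # Sketch of replacing the mock decision with a registered model (assumes
    # registry entries expose predict(state) -> (action, confidence); that
    # interface is an assumption, not defined in this module):
    #
    #     model = self.model_registry.get('rl_agent')
    #     if model is not None:
    #         action, confidence = model.predict(rl_state)
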
    def _get_symbol_correlation(self, symbol1: str, symbol2: str) -> float:
        """
        Calculate the correlation between two symbols.

        Args:
            symbol1: First symbol
            symbol2: Second symbol

        Returns:
            float: Correlation coefficient (-1 to 1)
        """
        try:
            # Get recent price data for both symbols
            data1 = self.data_provider.get_latest_candles(symbol1, '1m', limit=50)
            data2 = self.data_provider.get_latest_candles(symbol2, '1m', limit=50)

            if data1 is None or data2 is None or data1.empty or data2.empty:
                return 0.0

            # Align data by timestamp index
            merged = data1[['close']].join(data2[['close']], lsuffix='_1', rsuffix='_2', how='inner')

            if len(merged) < 10:
                return 0.0

            # Calculate correlation (note: this correlates raw prices; returns
            # via pct_change() would be a more robust co-movement measure)
            correlation = merged['close_1'].corr(merged['close_2'])
            return correlation if not np.isnan(correlation) else 0.0

        except Exception as e:
            logger.error(f"Error calculating symbol correlation: {e}")
            return 0.0
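

# Usage sketch: wiring the orchestrator end to end. Assumptions not defined in
# this module: DataProvider() takes no constructor arguments and its
# subscribe_* methods are live; adjust both to your deployment.
if __name__ == "__main__":
    async def _demo():
        provider = DataProvider()  # assumed no-arg constructor
        orchestrator = EnhancedTradingOrchestrator(provider, ['ETH/USDT', 'BTC/USDT'])
        await orchestrator.start_cob_integration()
        await orchestrator.start_realtime_processing()
        decisions = await orchestrator.make_coordinated_decisions()
        for symbol, trade_action in decisions.items():
            print(symbol, trade_action.action, trade_action.confidence)

    asyncio.run(_demo())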
``` |