more models wireup
This commit is contained in:
@ -58,20 +58,23 @@ class TradingDecision:
|
||||
|
||||
class TradingOrchestrator:
|
||||
"""
|
||||
Main orchestrator that coordinates multiple AI models for trading decisions
|
||||
Enhanced Trading Orchestrator with full ML and COB integration
|
||||
Coordinates CNN, DQN, and COB models for advanced trading decisions
|
||||
Features real-time COB (Change of Bid) integration for market microstructure data
|
||||
"""
|
||||
|
||||
def __init__(self, data_provider: DataProvider = None):
|
||||
"""Initialize the orchestrator with COB integration"""
|
||||
def __init__(self, data_provider: DataProvider = None, enhanced_rl_training: bool = True, model_registry: Dict = None):
|
||||
"""Initialize the enhanced orchestrator with full ML capabilities"""
|
||||
self.config = get_config()
|
||||
self.data_provider = data_provider or DataProvider()
|
||||
self.model_registry = get_model_registry()
|
||||
self.model_registry = model_registry or get_model_registry()
|
||||
self.enhanced_rl_training = enhanced_rl_training
|
||||
|
||||
# Configuration
|
||||
self.confidence_threshold = self.config.orchestrator.get('confidence_threshold', 0.5)
|
||||
self.decision_frequency = self.config.orchestrator.get('decision_frequency', 60)
|
||||
self.symbols = self.config.get('symbols', ['ETH/USDT']) # Default symbols to trade
|
||||
self.confidence_threshold = self.config.orchestrator.get('confidence_threshold', 0.20)
|
||||
self.confidence_threshold_close = self.config.orchestrator.get('confidence_threshold_close', 0.10)
|
||||
self.decision_frequency = self.config.orchestrator.get('decision_frequency', 30)
|
||||
self.symbols = self.config.get('symbols', ['ETH/USDT', 'BTC/USDT']) # Enhanced to support multiple symbols
|
||||
|
||||
# Dynamic weights (will be adapted based on performance)
|
||||
self.model_weights = {} # {model_name: weight}
|
||||
@ -92,22 +95,85 @@ class TradingOrchestrator:
|
||||
self.latest_cob_state: Dict[str, Any] = {} # {symbol: np.ndarray} - DQN state features
|
||||
self.cob_feature_history: Dict[str, List] = {symbol: [] for symbol in self.symbols} # Rolling history for models
|
||||
|
||||
logger.info("TradingOrchestrator initialized with modular model system")
|
||||
# Enhanced ML Models
|
||||
self.rl_agent = None # DQN Agent
|
||||
self.cnn_model = None # CNN Model for pattern recognition
|
||||
self.extrema_trainer = None # Extrema/pivot trainer
|
||||
self.latest_cnn_features: Dict[str, Any] = {} # CNN hidden features
|
||||
self.latest_cnn_predictions: Dict[str, Any] = {} # CNN predictions
|
||||
|
||||
# Enhanced RL features
|
||||
self.sensitivity_learning_queue = [] # For outcome-based learning
|
||||
self.perfect_move_buffer = [] # Buffer for perfect move analysis
|
||||
self.position_status = {} # Current positions
|
||||
|
||||
# Real-time processing
|
||||
self.realtime_processing = False
|
||||
self.realtime_tasks = []
|
||||
|
||||
logger.info("Enhanced TradingOrchestrator initialized with full ML capabilities")
|
||||
logger.info(f"Enhanced RL training: {enhanced_rl_training}")
|
||||
logger.info(f"Confidence threshold: {self.confidence_threshold}")
|
||||
logger.info(f"Decision frequency: {self.decision_frequency}s")
|
||||
logger.info(f"Symbols: {self.symbols}")
|
||||
|
||||
# Initialize COB integration
|
||||
# Initialize models and COB integration
|
||||
self._initialize_ml_models()
|
||||
self._initialize_cob_integration()
|
||||
|
||||
def _initialize_ml_models(self):
|
||||
"""Initialize ML models for enhanced trading"""
|
||||
try:
|
||||
logger.info("Initializing ML models...")
|
||||
|
||||
# Initialize DQN Agent
|
||||
try:
|
||||
from NN.models.dqn_agent import DQNAgent
|
||||
state_size = self.config.rl.get('state_size', 13800) # Enhanced with COB features
|
||||
action_size = self.config.rl.get('action_space', 3)
|
||||
self.rl_agent = DQNAgent(state_size=state_size, action_size=action_size)
|
||||
logger.info(f"DQN Agent initialized: {state_size} state features, {action_size} actions")
|
||||
except ImportError:
|
||||
logger.warning("DQN Agent not available")
|
||||
self.rl_agent = None
|
||||
|
||||
# Initialize CNN Model
|
||||
try:
|
||||
from NN.models.enhanced_cnn import EnhancedCNN
|
||||
self.cnn_model = EnhancedCNN()
|
||||
logger.info("Enhanced CNN model initialized")
|
||||
except ImportError:
|
||||
try:
|
||||
from NN.models.cnn_model import CNNModel
|
||||
self.cnn_model = CNNModel()
|
||||
logger.info("Basic CNN model initialized")
|
||||
except ImportError:
|
||||
logger.warning("CNN model not available")
|
||||
self.cnn_model = None
|
||||
|
||||
# Initialize Extrema Trainer
|
||||
try:
|
||||
from core.extrema_trainer import ExtremaTrainer
|
||||
self.extrema_trainer = ExtremaTrainer(
|
||||
data_provider=self.data_provider,
|
||||
symbols=self.symbols
|
||||
)
|
||||
logger.info("Extrema trainer initialized")
|
||||
except ImportError:
|
||||
logger.warning("Extrema trainer not available")
|
||||
self.extrema_trainer = None
|
||||
|
||||
logger.info("ML models initialization completed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing ML models: {e}")
|
||||
|
||||
def _initialize_cob_integration(self):
|
||||
"""Initialize real-time COB integration for market microstructure data"""
|
||||
try:
|
||||
if COB_INTEGRATION_AVAILABLE:
|
||||
# Initialize COB integration with our symbols
|
||||
self.cob_integration = COBIntegration(
|
||||
data_provider=self.data_provider,
|
||||
symbols=self.symbols
|
||||
)
|
||||
self.cob_integration = COBIntegration(data_provider=self.data_provider, symbols=self.symbols )
|
||||
|
||||
# Register callbacks to receive real-time COB data
|
||||
self.cob_integration.add_cnn_callback(self._on_cob_cnn_features)
|
||||
@ -116,9 +182,8 @@ class TradingOrchestrator:
|
||||
|
||||
logger.info("COB Integration initialized - real-time market microstructure data available")
|
||||
logger.info(f"COB symbols: {self.symbols}")
|
||||
|
||||
# Start COB integration in background
|
||||
asyncio.create_task(self._start_cob_integration())
|
||||
|
||||
# COB integration will be started manually via start_cob_integration()
|
||||
else:
|
||||
logger.warning("COB Integration not available - models will use basic price data only")
|
||||
|
||||
@ -1177,4 +1242,165 @@ class TradingOrchestrator:
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error generating fallback prediction for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
# Enhanced Orchestrator Methods
|
||||
|
||||
async def start_cob_integration(self):
|
||||
"""Start COB integration manually"""
|
||||
try:
|
||||
if self.cob_integration:
|
||||
await self._start_cob_integration()
|
||||
logger.info("COB Integration started successfully")
|
||||
else:
|
||||
logger.warning("COB Integration not available")
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting COB integration: {e}")
|
||||
|
||||
async def stop_cob_integration(self):
|
||||
"""Stop COB integration"""
|
||||
try:
|
||||
if self.cob_integration:
|
||||
await self.cob_integration.stop()
|
||||
logger.info("COB Integration stopped")
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping COB integration: {e}")
|
||||
|
||||
async def start_realtime_processing(self):
|
||||
"""Start real-time processing"""
|
||||
try:
|
||||
self.realtime_processing = True
|
||||
logger.info("Real-time processing started")
|
||||
|
||||
# Start background tasks for real-time processing
|
||||
for symbol in self.symbols:
|
||||
task = asyncio.create_task(self._realtime_processing_loop(symbol))
|
||||
self.realtime_tasks.append(task)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting real-time processing: {e}")
|
||||
|
||||
async def stop_realtime_processing(self):
|
||||
"""Stop real-time processing"""
|
||||
try:
|
||||
self.realtime_processing = False
|
||||
|
||||
# Cancel all background tasks
|
||||
for task in self.realtime_tasks:
|
||||
task.cancel()
|
||||
self.realtime_tasks = []
|
||||
|
||||
logger.info("Real-time processing stopped")
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping real-time processing: {e}")
|
||||
|
||||
async def _realtime_processing_loop(self, symbol: str):
|
||||
"""Real-time processing loop for a symbol"""
|
||||
while self.realtime_processing:
|
||||
try:
|
||||
# Update CNN features
|
||||
await self._update_cnn_features(symbol)
|
||||
|
||||
# Update RL state
|
||||
await self._update_rl_state(symbol)
|
||||
|
||||
# Sleep between updates
|
||||
await asyncio.sleep(1)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Error in real-time processing for {symbol}: {e}")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def _update_cnn_features(self, symbol: str):
|
||||
"""Update CNN features for a symbol"""
|
||||
try:
|
||||
if self.cnn_model and hasattr(self.cnn_model, 'extract_features'):
|
||||
# Get current market data
|
||||
df = self.data_provider.get_historical_data(symbol, '1m', limit=100)
|
||||
if df is not None and not df.empty:
|
||||
# Generate CNN features
|
||||
features = self.cnn_model.extract_features(df)
|
||||
if features is not None:
|
||||
self.latest_cnn_features[symbol] = features
|
||||
|
||||
# Generate CNN predictions
|
||||
if hasattr(self.cnn_model, 'predict'):
|
||||
predictions = self.cnn_model.predict(df)
|
||||
if predictions is not None:
|
||||
self.latest_cnn_predictions[symbol] = predictions
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error updating CNN features for {symbol}: {e}")
|
||||
|
||||
async def _update_rl_state(self, symbol: str):
|
||||
"""Update RL state for a symbol"""
|
||||
try:
|
||||
if self.rl_agent:
|
||||
# Build comprehensive RL state
|
||||
rl_state = self.build_comprehensive_rl_state(symbol)
|
||||
if rl_state and hasattr(self.rl_agent, 'remember'):
|
||||
# Store for training
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error updating RL state for {symbol}: {e}")
|
||||
|
||||
async def make_coordinated_decisions(self) -> Dict[str, Any]:
|
||||
"""Make coordinated trading decisions for all symbols"""
|
||||
decisions = {}
|
||||
|
||||
try:
|
||||
for symbol in self.symbols:
|
||||
decision = await self.make_trading_decision(symbol)
|
||||
decisions[symbol] = decision
|
||||
|
||||
return decisions
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error making coordinated decisions: {e}")
|
||||
return {}
|
||||
|
||||
def get_position_status(self) -> Dict[str, Any]:
|
||||
"""Get current position status"""
|
||||
return self.position_status.copy()
|
||||
|
||||
def cleanup_all_models(self):
|
||||
"""Cleanup all models"""
|
||||
try:
|
||||
if hasattr(self.model_registry, 'cleanup_all_models'):
|
||||
self.model_registry.cleanup_all_models()
|
||||
else:
|
||||
logger.debug("Model registry cleanup not available")
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up models: {e}")
|
||||
|
||||
def _get_cnn_hidden_features_for_rl_enhanced(self, symbol: str) -> Optional[List[float]]:
|
||||
"""Get CNN hidden features for RL (enhanced version)"""
|
||||
try:
|
||||
cnn_features = self.latest_cnn_features.get(symbol)
|
||||
if cnn_features is not None:
|
||||
if hasattr(cnn_features, 'tolist'):
|
||||
return cnn_features.tolist()[:1000] # First 1000 features
|
||||
elif isinstance(cnn_features, list):
|
||||
return cnn_features[:1000]
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting CNN hidden features: {e}")
|
||||
return None
|
||||
|
||||
def _get_pivot_analysis_features_for_rl_enhanced(self, symbol: str) -> Optional[List[float]]:
|
||||
"""Get pivot analysis features for RL (enhanced version)"""
|
||||
try:
|
||||
if self.extrema_trainer and hasattr(self.extrema_trainer, 'get_context_features_for_model'):
|
||||
pivot_features = self.extrema_trainer.get_context_features_for_model(symbol)
|
||||
if pivot_features is not None:
|
||||
if hasattr(pivot_features, 'tolist'):
|
||||
return pivot_features.tolist()[:300] # First 300 features
|
||||
elif isinstance(pivot_features, list):
|
||||
return pivot_features[:300]
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting pivot analysis features: {e}")
|
||||
return None
|
Reference in New Issue
Block a user