"""
Trading Orchestrator - Main Decision Making Module

This is the core orchestrator that:
1. Coordinates CNN and RL modules via model registry
2. Combines their outputs with confidence weighting
3. Makes final trading decisions (BUY/SELL/HOLD)
4. Manages the learning loop between components
5. Ensures memory efficiency (8GB constraint)
"""

import asyncio
import inspect
import logging
import time
import numpy as np
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass

from .config import get_config
from .data_provider import DataProvider
from models import get_model_registry, ModelInterface, CNNModelInterface, RLAgentInterface

logger = logging.getLogger(__name__)

# Action labels in the index order this module assumes model outputs use
# (index 0 = SELL, 1 = HOLD, 2 = BUY).
_ACTION_NAMES = ['SELL', 'HOLD', 'BUY']

# Maximum number of past decisions kept per symbol.
_MAX_RECENT_DECISIONS = 100

# Run model-registry cleanup once every this many decisions (across all symbols).
_CLEANUP_EVERY_N_DECISIONS = 50


@dataclass
class Prediction:
    """Represents a prediction from a model"""
    action: str  # 'BUY', 'SELL', 'HOLD'
    confidence: float  # 0.0 to 1.0
    probabilities: Dict[str, float]  # Probabilities for each action
    timeframe: str  # Timeframe this prediction is for
    timestamp: datetime
    model_name: str  # Name of the model that made this prediction
    metadata: Optional[Dict[str, Any]] = None  # Additional model-specific data


@dataclass
class TradingDecision:
    """Final trading decision from the orchestrator"""
    action: str  # 'BUY', 'SELL', 'HOLD'
    confidence: float  # Combined confidence
    symbol: str
    price: float
    timestamp: datetime
    reasoning: Dict[str, Any]  # Why this decision was made
    memory_usage: Dict[str, int]  # Memory usage of models


class TradingOrchestrator:
    """
    Main orchestrator that coordinates multiple AI models for trading decisions.

    Models are held in the shared model registry. Each model's output is turned
    into one or more ``Prediction`` objects. These are combined using per-model
    weights and per-timeframe weights into a single ``TradingDecision``.
    """

    # Higher timeframes get more weight in decision making; unknown ones get 0.5.
    _TIMEFRAME_WEIGHTS = {
        '1m': 0.1,
        '5m': 0.2,
        '15m': 0.3,
        '30m': 0.4,
        '1h': 0.6,
        '4h': 0.8,
        '1d': 1.0,
    }

    def __init__(self, data_provider: DataProvider = None):
        """Initialize the orchestrator.

        Args:
            data_provider: Source of current prices and feature matrices. A new
                ``DataProvider()`` is created when omitted.
        """
        self.config = get_config()
        self.data_provider = data_provider or DataProvider()
        self.model_registry = get_model_registry()

        # Configuration
        self.confidence_threshold = self.config.orchestrator.get('confidence_threshold', 0.5)
        self.decision_frequency = self.config.orchestrator.get('decision_frequency', 60)

        # Dynamic weights (will be adapted based on performance)
        self.model_weights = {}  # {model_name: weight}
        self._initialize_default_weights()

        # State tracking
        self.last_decision_time = {}  # {symbol: datetime}
        self.recent_decisions = {}  # {symbol: List[TradingDecision]}
        self.model_performance = {}  # {model_name: {'correct': int, 'total': int, 'accuracy': float}}

        # Decisions made since the last registry cleanup. A dedicated counter is
        # needed because recent_decisions is capped and stops growing.
        self._decisions_since_cleanup = 0

        # Decision callbacks (sync or async callables taking a TradingDecision)
        self.decision_callbacks = []

        logger.info("TradingOrchestrator initialized with modular model system")
        logger.info(f"Confidence threshold: {self.confidence_threshold}")
        logger.info(f"Decision frequency: {self.decision_frequency}s")

    def _initialize_default_weights(self):
        """Initialize default model weights from config"""
        self.model_weights = {
            'CNN': self.config.orchestrator.get('cnn_weight', 0.7),
            'RL': self.config.orchestrator.get('rl_weight', 0.3),
        }

    def register_model(self, model: ModelInterface, weight: float = None) -> bool:
        """Register a new model with the orchestrator.

        Args:
            model: Model to add to the registry.
            weight: Explicit weight. When omitted, an existing weight for the
                model's name is kept; otherwise a default of 0.1 is used.

        Returns:
            True on success, False if the registry refused the model or an
            error occurred.
        """
        try:
            if not self.model_registry.register_model(model):
                return False

            if weight is not None:
                self.model_weights[model.name] = weight
            elif model.name not in self.model_weights:
                self.model_weights[model.name] = 0.1  # Default low weight for new models

            if model.name not in self.model_performance:
                self.model_performance[model.name] = {'correct': 0, 'total': 0, 'accuracy': 0.0}

            logger.info(f"Registered {model.name} model with weight {self.model_weights[model.name]}")
            self._normalize_weights()
            return True

        except Exception as e:
            logger.error(f"Error registering model {model.name}: {e}")
            return False

    def unregister_model(self, model_name: str) -> bool:
        """Unregister a model and drop its weight and performance tracking.

        Returns:
            True if the registry removed the model, False otherwise.
        """
        try:
            if self.model_registry.unregister_model(model_name):
                self.model_weights.pop(model_name, None)
                self.model_performance.pop(model_name, None)
                self._normalize_weights()
                logger.info(f"Unregistered {model_name} model")
                return True
            return False

        except Exception as e:
            logger.error(f"Error unregistering model {model_name}: {e}")
            return False

    def _normalize_weights(self):
        """Normalize model weights to sum to 1.0 (no-op when the total is 0)."""
        total_weight = sum(self.model_weights.values())
        if total_weight > 0:
            for model_name in self.model_weights:
                self.model_weights[model_name] /= total_weight

    def add_decision_callback(self, callback):
        """Add a callback to be called with each decision.

        The callback may be a plain function or a coroutine function.
        """
        self.decision_callbacks.append(callback)

    async def make_trading_decision(self, symbol: str) -> Optional[TradingDecision]:
        """
        Make a trading decision for a symbol by combining all registered model outputs.

        Returns None in any of these cases (errors are logged):
        - the previous decision for ``symbol`` is younger than ``decision_frequency`` seconds;
        - no current price is available;
        - no prediction is available;
        - an unexpected error occurs.
        """
        try:
            current_time = datetime.now()

            # Rate-limit decisions per symbol
            last_time = self.last_decision_time.get(symbol)
            if last_time is not None:
                time_since_last = (current_time - last_time).total_seconds()
                if time_since_last < self.decision_frequency:
                    return None

            current_price = self.data_provider.get_current_price(symbol)
            if current_price is None:
                logger.warning(f"No current price available for {symbol}")
                return None

            predictions = await self._get_all_predictions(symbol)
            if not predictions:
                logger.warning(f"No predictions available for {symbol}")
                return None

            decision = self._combine_predictions(
                symbol=symbol,
                price=current_price,
                predictions=predictions,
                timestamp=current_time,
            )

            # Update state, keeping only the most recent decisions
            self.last_decision_time[symbol] = current_time
            history = self.recent_decisions.setdefault(symbol, [])
            history.append(decision)
            if len(history) > _MAX_RECENT_DECISIONS:
                del history[:-_MAX_RECENT_DECISIONS]

            # Callbacks may be sync or async; only await what is awaitable so
            # plain functions are not reported as errors.
            for callback in self.decision_callbacks:
                try:
                    result = callback(decision)
                    if inspect.isawaitable(result):
                        await result
                except Exception as e:
                    logger.error(f"Error in decision callback: {e}")

            # Clean up memory periodically (counter-based: the capped history
            # length cannot be used as a cadence once it is full).
            self._decisions_since_cleanup += 1
            if self._decisions_since_cleanup >= _CLEANUP_EVERY_N_DECISIONS:
                self._decisions_since_cleanup = 0
                self.model_registry.cleanup_all_models()

            return decision

        except Exception as e:
            logger.error(f"Error making trading decision for {symbol}: {e}")
            return None

    async def _get_all_predictions(self, symbol: str) -> List[Prediction]:
        """Get predictions from all registered models.

        CNN models may contribute one prediction per configured timeframe. RL
        agents and other models contribute at most one each. A failure in one
        model is logged and does not stop the others.
        """
        predictions = []

        for model_name, model in self.model_registry.models.items():
            try:
                if isinstance(model, CNNModelInterface):
                    predictions.extend(await self._get_cnn_predictions(model, symbol))
                elif isinstance(model, RLAgentInterface):
                    rl_prediction = await self._get_rl_prediction(model, symbol)
                    if rl_prediction:
                        predictions.append(rl_prediction)
                else:
                    generic_prediction = await self._get_generic_prediction(model, symbol)
                    if generic_prediction:
                        predictions.append(generic_prediction)
            except Exception as e:
                logger.error(f"Error getting prediction from {model_name}: {e}")
                continue

        return predictions

    @staticmethod
    def _prediction_from_probs(action_probs, confidence, timeframe: str,
                               model_name: str, metadata: Dict[str, Any]) -> Prediction:
        """Build a Prediction from a per-action probability vector.

        ``action_probs`` is indexed in ``_ACTION_NAMES`` order. When
        ``confidence`` is None, the probability of the chosen action is used.
        """
        best_action_idx = int(np.argmax(action_probs))
        return Prediction(
            action=_ACTION_NAMES[best_action_idx],
            confidence=float(confidence) if confidence is not None else float(action_probs[best_action_idx]),
            probabilities={name: float(prob) for name, prob in zip(_ACTION_NAMES, action_probs)},
            timeframe=timeframe,
            timestamp=datetime.now(),
            model_name=model_name,
            metadata=metadata,
        )

    async def _get_cnn_predictions(self, model: CNNModelInterface, symbol: str) -> List[Prediction]:
        """Get predictions from a CNN model, one per configured timeframe.

        Uses ``model.predict_timeframe`` when the model has it, otherwise
        ``model.predict``. A failure on one timeframe is logged and the
        remaining timeframes are still tried.
        """
        predictions = []
        # Resolve once via getattr so AttributeErrors raised *inside*
        # predict_timeframe are reported instead of silently triggering the fallback.
        predict_timeframe = getattr(model, 'predict_timeframe', None)

        for timeframe in self.config.timeframes:
            try:
                feature_matrix = self.data_provider.get_feature_matrix(
                    symbol=symbol,
                    timeframes=[timeframe],
                    window_size=model.window_size,
                )
                if feature_matrix is None:
                    continue

                if predict_timeframe is not None:
                    action_probs, confidence = predict_timeframe(feature_matrix, timeframe)
                else:
                    action_probs, confidence = model.predict(feature_matrix)

                if action_probs is None:
                    continue

                predictions.append(self._prediction_from_probs(
                    action_probs, confidence, timeframe, model.name,
                    {'timeframe_specific': True},
                ))
            except Exception as e:
                logger.error(f"Error getting CNN predictions: {e}")

        return predictions

    async def _get_rl_prediction(self, model: RLAgentInterface, symbol: str) -> Optional[Prediction]:
        """Get prediction from RL agent, or None if no state is available or on error."""
        try:
            state = self._get_rl_state(symbol)
            if state is None:
                return None

            action_idx, confidence = model.act_with_confidence(state)
            action = _ACTION_NAMES[action_idx]
            confidence = float(confidence)

            # The agent only reports the chosen action's confidence. Any remainder
            # is attributed to HOLD, but never overwrites HOLD's own confidence when
            # HOLD is the chosen action.
            probabilities = {action: confidence}
            if action != 'HOLD':
                probabilities['HOLD'] = 1.0 - confidence

            return Prediction(
                action=action,
                confidence=confidence,
                probabilities=probabilities,
                timeframe='mixed',  # RL uses mixed timeframes
                timestamp=datetime.now(),
                model_name=model.name,
                metadata={'state_size': len(state)},
            )

        except Exception as e:
            logger.error(f"Error getting RL prediction: {e}")
            return None

    async def _get_generic_prediction(self, model: ModelInterface, symbol: str) -> Optional[Prediction]:
        """Get prediction from a generic model using the first 3 configured timeframes.

        Returns None when no features or probabilities are available, or on error.
        """
        try:
            feature_matrix = self.data_provider.get_feature_matrix(
                symbol=symbol,
                timeframes=self.config.timeframes[:3],  # Use first 3 timeframes
                window_size=20,
            )
            if feature_matrix is None:
                return None

            action_probs, confidence = model.predict(feature_matrix)
            if action_probs is None:
                return None

            # Same None-confidence fallback as the CNN path instead of float(None).
            return self._prediction_from_probs(
                action_probs, confidence, 'mixed', model.name, {'generic_model': True},
            )

        except Exception as e:
            logger.error(f"Error getting generic prediction: {e}")
            return None

    def _get_rl_state(self, symbol: str) -> Optional[np.ndarray]:
        """Build the flat state vector for the RL agent.

        The state is the flattened feature matrix for all configured timeframes
        plus three placeholder entries.
        """
        try:
            feature_matrix = self.data_provider.get_feature_matrix(
                symbol=symbol,
                timeframes=self.config.timeframes,
                window_size=self.config.rl.get('window_size', 20),
            )
            if feature_matrix is None:
                return None

            # NOTE(review): presumably (n_timeframes, window_size, n_features) -> 1-D; confirm
            state = feature_matrix.flatten()

            # Placeholder [position, balance, unrealized_pnl]; a portfolio
            # manager would supply real values in a full implementation.
            additional_state = np.array([0.0, 1.0, 0.0])
            return np.concatenate([state, additional_state])

        except Exception as e:
            logger.error(f"Error creating RL state for {symbol}: {e}")
            return None

    def _combine_predictions(self, symbol: str, price: float,
                             predictions: List[Prediction], timestamp: datetime) -> TradingDecision:
        """Combine all predictions into a final decision.

        Each prediction adds ``confidence * timeframe_weight * model_weight`` to
        its action's score. Scores are then normalized to sum to 1. If the best
        score is below ``confidence_threshold``, the decision falls back to HOLD.
        Errors yield a HOLD decision with confidence 0.0.
        """
        try:
            reasoning = {
                'predictions': len(predictions),
                'weights': self.model_weights.copy(),
                'models_used': [pred.model_name for pred in predictions],
            }

            action_scores = {'BUY': 0.0, 'SELL': 0.0, 'HOLD': 0.0}
            total_weight = 0.0

            for pred in predictions:
                if pred.action not in action_scores:
                    # Skip rather than abort the whole combination on a bad label.
                    logger.warning(f"Ignoring prediction with unknown action {pred.action!r} from {pred.model_name}")
                    continue
                model_weight = self.model_weights.get(pred.model_name, 0.1)
                timeframe_weight = self._get_timeframe_weight(pred.timeframe)
                weighted_confidence = pred.confidence * timeframe_weight * model_weight
                action_scores[pred.action] += weighted_confidence
                total_weight += weighted_confidence

            if total_weight > 0:
                for action in action_scores:
                    action_scores[action] /= total_weight
                best_action = max(action_scores, key=action_scores.get)
                best_confidence = action_scores[best_action]
            else:
                # No evidence at all: do not let dict ordering pick a direction.
                best_action, best_confidence = 'HOLD', 0.0

            if best_confidence < self.confidence_threshold:
                best_action = 'HOLD'
                reasoning['threshold_applied'] = True

            memory_usage = self.model_registry.get_memory_stats()

            decision = TradingDecision(
                action=best_action,
                confidence=best_confidence,
                symbol=symbol,
                price=price,
                timestamp=timestamp,
                reasoning=reasoning,
                memory_usage=memory_usage['models'],
            )

            logger.info(f"Decision for {symbol}: {best_action} (confidence: {best_confidence:.3f})")
            logger.debug(f"Memory usage: {memory_usage['total_used_mb']:.1f}MB / {memory_usage['total_limit_mb']:.1f}MB")

            return decision

        except Exception as e:
            logger.error(f"Error combining predictions for {symbol}: {e}")
            # Return safe default
            return TradingDecision(
                action='HOLD',
                confidence=0.0,
                symbol=symbol,
                price=price,
                timestamp=timestamp,
                reasoning={'error': str(e)},
                memory_usage={},
            )

    def _get_timeframe_weight(self, timeframe: str) -> float:
        """Get importance weight for a timeframe (0.5 for unknown timeframes)."""
        return self._TIMEFRAME_WEIGHTS.get(timeframe, 0.5)

    def update_model_performance(self, model_name: str, was_correct: bool):
        """Record one outcome for a tracked model and refresh its accuracy.

        Models without a performance entry (i.e. not registered) are ignored.
        """
        stats = self.model_performance.get(model_name)
        if stats is None:
            return
        stats['total'] += 1
        if was_correct:
            stats['correct'] += 1
        stats['accuracy'] = stats['correct'] / stats['total'] if stats['total'] > 0 else 0.0

    def adapt_weights(self):
        """Dynamically adapt model weights based on performance.

        Each model with at least one recorded outcome gets its accuracy as its
        raw weight. All weights are then re-normalized to sum to 1.0, matching
        the invariant kept by register_model / unregister_model.
        """
        try:
            adapted = []
            for model_name, performance in self.model_performance.items():
                if performance['total'] > 0:
                    self.model_weights[model_name] = performance['correct'] / performance['total']
                    adapted.append(model_name)

            self._normalize_weights()

            for model_name in adapted:
                logger.info(f"Adapted {model_name} weight: {self.model_weights[model_name]}")

        except Exception as e:
            logger.error(f"Error adapting weights: {e}")

    def get_recent_decisions(self, symbol: str, limit: int = 10) -> List[TradingDecision]:
        """Get up to ``limit`` most recent decisions for a symbol (oldest first)."""
        # Guard: lst[-0:] would return the whole list instead of nothing.
        if limit <= 0:
            return []
        return self.recent_decisions.get(symbol, [])[-limit:]

    def get_performance_metrics(self) -> Dict[str, Any]:
        """Get a snapshot of performance metrics for the orchestrator."""
        return {
            # Copy the nested per-model dicts so callers cannot mutate live state.
            'model_performance': {name: dict(stats) for name, stats in self.model_performance.items()},
            'weights': self.model_weights.copy(),
            'configuration': {
                'confidence_threshold': self.confidence_threshold,
                'decision_frequency': self.decision_frequency,
            },
            'recent_activity': {
                symbol: len(decisions)
                for symbol, decisions in self.recent_decisions.items()
            },
        }

    async def start_continuous_trading(self, symbols: List[str] = None):
        """Run the decision loop forever for the given symbols.

        Defaults to ``config.symbols``. Sleeps ``decision_frequency`` seconds
        between cycles, and 10 seconds after a loop-level error.
        """
        if symbols is None:
            symbols = self.config.symbols

        logger.info(f"Starting continuous trading for symbols: {symbols}")

        while True:
            try:
                for symbol in symbols:
                    decision = await self.make_trading_decision(symbol)
                    if decision and decision.action != 'HOLD':
                        logger.info(f"Trading decision: {decision.action} {symbol} at {decision.price}")

                await asyncio.sleep(self.decision_frequency)

            except Exception as e:
                logger.error(f"Error in continuous trading loop: {e}")
                await asyncio.sleep(10)  # Wait before retrying