import logging
import threading
import time
from typing import Dict, Any, List, Optional, Callable, Tuple
import os

import numpy as np
import pandas as pd

from .trading_agent import TradingAgent

logger = logging.getLogger(__name__)


class NeuralNetworkOrchestrator:
    """Orchestrator for neural network models and trading operations.

    This class coordinates between neural network models and trading agents,
    ensuring that signals from the models are properly processed and trades
    are executed according to the strategy.

    Inference runs on a background daemon thread (see ``start_inference``).
    Every ``inference_interval`` seconds, that thread loads recent candles for
    each symbol, runs the model, and forwards the resulting SELL/HOLD/BUY
    signal to the chart and/or trading agent, when those are configured.
    """

    # Maps the index of the model's action-probability vector to an action name.
    ACTION_NAMES = ("SELL", "HOLD", "BUY")

    def __init__(self, model, data_interface, chart=None,
                 symbols: Optional[List[str]] = None,
                 timeframes: Optional[List[str]] = None,
                 window_size: int = 20,
                 num_features: int = 5,
                 output_size: int = 3,
                 models_dir: str = "NN/models/saved",
                 data_dir: str = "NN/data",
                 exchange_config: Optional[Dict[str, Any]] = None,
                 history_candles: int = 1000):
        """Initialize the neural network orchestrator.

        Args:
            model: Neural network model instance; must expose ``predict(X)``
                returning ``(action_probs, price_pred)``.
            data_interface: Data interface for retrieving market data; must expose
                ``get_historical_data`` and ``prepare_model_input``.
            chart: Real-time chart for visualization (optional).
            symbols: List of trading symbols (e.g., ['BTC/USDT', 'ETH/USDT']).
                Defaults to ["BTC/USDT"].
            timeframes: List of timeframes to monitor (e.g., ['1m', '5m', '1h']).
                Defaults to ["1m", "5m", "1h", "4h", "1d"].
            window_size: Window size for model input; also the minimum number of
                candles required to run inference.
            num_features: Number of features per datapoint.
            output_size: Number of output classes (e.g., 3 for BUY/HOLD/SELL).
            models_dir: Directory for saved models.
            data_dir: Directory for data storage.
            exchange_config: Configuration for the trading agent (exchange, API
                keys, etc.). When falsy, no trading agent is created.
            history_candles: Number of candles requested from the data interface
                on every inference pass (default 1000).

        Raises:
            ValueError: If the ``NN_INFERENCE_INTERVAL`` environment variable is
                set to a value that is not an integer.
        """
        self.model = model
        self.data_interface = data_interface
        self.chart = chart
        self.symbols = symbols or ["BTC/USDT"]
        self.timeframes = timeframes or ["1m", "5m", "1h", "4h", "1d"]
        self.window_size = window_size
        self.num_features = num_features
        self.output_size = output_size
        self.models_dir = models_dir
        self.data_dir = data_dir
        self.history_candles = history_candles

        # Initialize trading agent if configuration provided
        self.trading_agent = None
        if exchange_config:
            self.init_trading_agent(exchange_config)

        # Initialize inference state
        self.is_running = False
        self.inference_thread = None
        self.stop_event = threading.Event()
        # 0 makes the first loop iteration run inference immediately.
        self.last_inference_time = 0
        self.inference_interval = int(os.environ.get("NN_INFERENCE_INTERVAL", "60"))

        logger.info("Initializing NeuralNetworkOrchestrator with:")
        logger.info(f"- Symbols: {', '.join(self.symbols)}")
        logger.info(f"- Timeframes: {', '.join(self.timeframes)}")
        logger.info(f"- Window size: {window_size}")
        logger.info(f"- Num features: {num_features}")
        logger.info(f"- Output size: {output_size}")
        logger.info(f"- Models dir: {models_dir}")
        logger.info(f"- Data dir: {data_dir}")
        logger.info(f"- Inference interval: {self.inference_interval} seconds")

    def init_trading_agent(self, config: Dict[str, Any]):
        """Initialize the trading agent with the given configuration.

        Args:
            config: Configuration for the trading agent. Recognized keys (with
                defaults): ``exchange`` ("binance"), ``api_key``, ``api_secret``,
                ``test_mode`` (True), ``trade_symbols`` (this orchestrator's
                symbols), ``position_size`` (0.1), ``max_trades_per_day`` (5),
                ``trade_cooldown_minutes`` (60).
        """
        exchange_name = config.get("exchange", "binance")
        api_key = config.get("api_key")
        api_secret = config.get("api_secret")
        test_mode = config.get("test_mode", True)
        trade_symbols = config.get("trade_symbols", self.symbols)
        position_size = config.get("position_size", 0.1)
        max_trades_per_day = config.get("max_trades_per_day", 5)
        trade_cooldown_minutes = config.get("trade_cooldown_minutes", 60)

        self.trading_agent = TradingAgent(
            exchange_name=exchange_name,
            api_key=api_key,
            api_secret=api_secret,
            test_mode=test_mode,
            trade_symbols=trade_symbols,
            position_size=position_size,
            max_trades_per_day=max_trades_per_day,
            trade_cooldown_minutes=trade_cooldown_minutes,
        )

        logger.info(f"Trading agent initialized for {exchange_name} exchange.")

    def start_inference(self):
        """Start the inference thread and, if configured, the trading agent.

        Does nothing (besides logging a warning) when inference is already
        running, or when the thread from a previous ``stop_inference`` call has
        not exited yet.
        """
        if self.is_running:
            logger.warning("Neural network inference is already running.")
            return

        # stop_inference() only waits 10s for the old thread. Restarting while it is
        # still alive would clear stop_event and set is_running again, reviving the
        # old loop alongside the new one and duplicating every signal and trade.
        if self.inference_thread is not None and self.inference_thread.is_alive():
            logger.warning("Previous inference thread is still shutting down; not starting a new one.")
            return

        self.is_running = True
        self.stop_event.clear()

        # Start inference thread (daemon so it never blocks interpreter exit)
        self.inference_thread = threading.Thread(target=self._inference_loop, daemon=True)
        self.inference_thread.start()

        logger.info(f"Neural network inference thread started with {self.inference_interval}s interval.")

        # Start trading agent if available
        if self.trading_agent:
            self.trading_agent.start(signal_callback=self._on_trade_executed)

    def stop_inference(self):
        """Stop the inference thread (waiting up to 10 seconds) and the trading agent."""
        if not self.is_running:
            logger.warning("Neural network inference is not running.")
            return

        logger.info("Stopping neural network inference...")
        self.is_running = False
        self.stop_event.set()

        if self.inference_thread and self.inference_thread.is_alive():
            self.inference_thread.join(timeout=10)
            if self.inference_thread.is_alive():
                logger.warning("Inference thread did not exit within 10 seconds; it will stop after its current pass.")

        logger.info("Neural network inference stopped.")

        # Stop trading agent if available
        if self.trading_agent:
            self.trading_agent.stop()

    def _inference_loop(self):
        """Run inference for every symbol once per ``inference_interval`` until stopped.

        A failure for one symbol is logged and does not prevent the remaining
        symbols from being processed in the same pass.
        """
        logger.info("Inference loop started.")

        try:
            while self.is_running and not self.stop_event.is_set():
                current_time = time.time()

                # Check if we should run inference
                if current_time - self.last_inference_time >= self.inference_interval:
                    for symbol in self.symbols:
                        self._infer_symbol(symbol)
                    # Advance even if a symbol failed, so a persistent error is retried
                    # once per interval instead of on every 1-second tick.
                    self.last_inference_time = current_time

                # Pause ~1s between checks; wakes immediately once stop_event is set so
                # stop_inference() does not wait out the pause.
                self.stop_event.wait(1)
        except Exception:
            logger.exception("Error in inference loop")
        finally:
            logger.info("Inference loop stopped.")

    def _infer_symbol(self, symbol: str):
        """Run inference for one symbol and dispatch its signal, logging any failure."""
        try:
            prediction = self._run_inference(symbol)
            if prediction is not None:
                self._process_prediction(symbol, prediction)
        except Exception:
            logger.exception(f"Error during inference for {symbol}")

    def _run_inference(self, symbol: str) -> Optional[Tuple[np.ndarray, float]]:
        """Run inference for a specific symbol.

        The timeframe is taken from the ``NN_TIMEFRAME`` environment variable
        (default "1h"). If that value is not in ``self.timeframes``, the first
        configured timeframe is used instead.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT').

        Returns:
            tuple: (action probabilities, current price), or None if there is not
            enough data, input preparation failed, or any step raised.
        """
        try:
            model_timeframe = os.environ.get("NN_TIMEFRAME", "1h")
            if model_timeframe not in self.timeframes:
                logger.warning(f"Model timeframe {model_timeframe} not in available timeframes. Using {self.timeframes[0]}.")
                model_timeframe = self.timeframes[0]

            n_candles = self.history_candles
            logger.info(f"Loading {n_candles} candles from cache for {symbol} at {model_timeframe} timeframe")
            candles = self.data_interface.get_historical_data(
                symbol=symbol,
                timeframe=model_timeframe,
                n_candles=n_candles,
            )

            if candles is None or len(candles) < self.window_size:
                logger.warning(f"Not enough data for {symbol} at {model_timeframe} timeframe. Need at least {self.window_size} candles.")
                return None

            # The timestamp returned alongside X is not needed here.
            X, _timestamp = self.data_interface.prepare_model_input(
                data=candles,
                window_size=self.window_size,
                symbol=symbol,
            )
            if X is None:
                logger.warning(f"Failed to prepare model input for {symbol}.")
                return None

            # Latest close is reported as the current price.
            current_price = float(candles['close'].iloc[-1])

            # The model's price prediction is not used to generate the signal.
            action_probs, _price_pred = self.model.predict(X)

            return action_probs, current_price
        except Exception:
            logger.exception(f"Error running inference for {symbol}")
            return None

    def _process_prediction(self, symbol: str, prediction: Tuple[np.ndarray, float]):
        """Turn a prediction into a signal and send it to the chart and trading agent.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT').
            prediction: Tuple of (action probabilities, current price). The
                probabilities are indexed 0=SELL, 1=HOLD, 2=BUY.
        """
        action_probs, current_price = prediction

        probs = np.asarray(action_probs, dtype=float)
        # NOTE(review): if the model returns a batch (e.g. shape (1, 3)), use the last
        # row, presumably the most recent window — confirm against model.predict().
        if probs.ndim > 1:
            probs = probs.reshape(-1, probs.shape[-1])[-1]
        if probs.size != len(self.ACTION_NAMES):
            logger.warning(
                f"Unexpected prediction for {symbol}: expected {len(self.ACTION_NAMES)} "
                f"action probabilities, got {probs.size}."
            )
            return

        # Get the best action (0=SELL, 1=HOLD, 2=BUY)
        best_action = int(np.argmax(probs))
        best_prob = float(probs[best_action])
        action_name = self.ACTION_NAMES[best_action]

        logger.info(f"Inference result for {symbol}: Action={action_name}, Probability={best_prob:.2f}, Price={current_price:.2f}")

        # One timestamp for both consumers so the chart and the agent record the same time.
        signal_time = int(time.time())

        if self.chart:
            try:
                self.chart.add_nn_signal(symbol=symbol, signal=action_name, confidence=best_prob, timestamp=signal_time)
            except Exception:
                # Visualization is optional; a chart failure must not block the trade signal.
                logger.exception(f"Failed to add signal to chart for {symbol}")

        if self.trading_agent:
            self.trading_agent.process_signal(
                symbol=symbol,
                action=action_name,
                confidence=best_prob,
                timestamp=signal_time,
            )

    def _on_trade_executed(self, trade_record: Dict[str, Any]):
        """Callback registered with the trading agent; plots executed trades on the chart.

        Args:
            trade_record: Trade information. Must contain ``action`` and
                ``timestamp``; ``price`` and ``pnl`` default to 0 when absent.
        """
        if self.chart and trade_record:
            price = trade_record.get('price', 0)
            self.chart.add_trade(
                action=trade_record['action'],
                price=price,
                timestamp=trade_record['timestamp'],
                pnl=trade_record.get('pnl', 0),
            )
            logger.info(f"Trade added to chart: {trade_record['action']} at {price:.2f}")

    def get_trading_agent_info(self) -> Optional[Dict[str, Any]]:
        """Get information about the trading agent.

        Returns:
            dict: ``exchange_info``, ``positions`` and ``trades`` (the number of
            trades in the agent's history), or None if no agent is configured.
        """
        if self.trading_agent:
            return {
                'exchange_info': self.trading_agent.get_exchange_info(),
                'positions': self.trading_agent.get_current_positions(),
                'trades': len(self.trading_agent.get_trade_history()),
            }
        return None