gogo2/NN/neural_network_orchestrator.py
2025-03-31 14:22:33 +03:00

287 lines
11 KiB
Python

import logging
import threading
import time
from typing import Dict, Any, List, Optional, Callable, Tuple
import os
import numpy as np
import pandas as pd
from .trading_agent import TradingAgent
# Module-level logger used by NeuralNetworkOrchestrator for status and error reporting.
logger = logging.getLogger(__name__)
class NeuralNetworkOrchestrator:
    """Orchestrator for neural network models and trading operations.

    This class coordinates between neural network models and trading agents,
    ensuring that signals from the models are properly processed and trades
    are executed according to the strategy.
    """

    # Mapping from model output index to action name (0=SELL, 1=HOLD, 2=BUY).
    ACTION_NAMES = ("SELL", "HOLD", "BUY")

    # Number of historical candles requested from the data interface per inference run.
    HISTORY_CANDLES = 1000

    # Fallback inference interval (seconds) when NN_INFERENCE_INTERVAL is unset or invalid.
    DEFAULT_INFERENCE_INTERVAL = 60

    def __init__(self, model, data_interface, chart=None,
                 symbols: Optional[List[str]] = None,
                 timeframes: Optional[List[str]] = None,
                 window_size: int = 20,
                 num_features: int = 5,
                 output_size: int = 3,
                 models_dir: str = "NN/models/saved",
                 data_dir: str = "NN/data",
                 exchange_config: Optional[Dict[str, Any]] = None):
        """Initialize the neural network orchestrator.

        Args:
            model: Neural network model instance; must expose ``predict(X)``
                returning ``(action_probs, price_pred)``.
            data_interface: Data interface for retrieving market data; must expose
                ``get_historical_data`` and ``prepare_model_input``.
            chart: Real-time chart for visualization (optional).
            symbols: List of trading symbols (e.g., ['BTC/USDT', 'ETH/USDT']).
                Defaults to ['BTC/USDT'].
            timeframes: List of timeframes to monitor (e.g., ['1m', '5m', '1h']).
                Defaults to ['1m', '5m', '1h', '4h', '1d'].
            window_size: Window size for model input; also the minimum number of
                candles required before inference runs.
            num_features: Number of features per datapoint.
            output_size: Number of output classes (e.g., 3 for BUY/HOLD/SELL).
            models_dir: Directory for saved models.
            data_dir: Directory for data storage.
            exchange_config: Configuration for trading agent (exchange, API keys, etc.).
                When falsy, no trading agent is created.
        """
        self.model = model
        self.data_interface = data_interface
        self.chart = chart
        self.symbols = symbols or ["BTC/USDT"]
        self.timeframes = timeframes or ["1m", "5m", "1h", "4h", "1d"]
        self.window_size = window_size
        self.num_features = num_features
        self.output_size = output_size
        self.models_dir = models_dir
        self.data_dir = data_dir

        # Initialize trading agent if configuration provided
        self.trading_agent = None
        if exchange_config:
            self.init_trading_agent(exchange_config)

        # Initialize inference state
        self.is_running = False
        self.inference_thread = None
        self.stop_event = threading.Event()
        self.last_inference_time = 0
        self.inference_interval = self._read_inference_interval()

        logger.info("Initializing NeuralNetworkOrchestrator with:")
        logger.info(f"- Symbols: {', '.join(self.symbols)}")
        logger.info(f"- Timeframes: {', '.join(self.timeframes)}")
        logger.info(f"- Window size: {window_size}")
        logger.info(f"- Num features: {num_features}")
        logger.info(f"- Output size: {output_size}")
        logger.info(f"- Models dir: {models_dir}")
        logger.info(f"- Data dir: {data_dir}")
        logger.info(f"- Inference interval: {self.inference_interval} seconds")

    @classmethod
    def _read_inference_interval(cls, default: Optional[int] = None) -> int:
        """Read the inference interval (seconds) from ``NN_INFERENCE_INTERVAL``.

        Args:
            default: Value used when the variable is unset or not an integer.
                Defaults to ``DEFAULT_INFERENCE_INTERVAL``.

        Returns:
            int: The parsed interval, or ``default`` if parsing fails.
        """
        if default is None:
            default = cls.DEFAULT_INFERENCE_INTERVAL
        raw = os.environ.get("NN_INFERENCE_INTERVAL", str(default))
        try:
            return int(raw)
        except ValueError:
            # A malformed env var should not prevent the orchestrator from starting.
            logger.warning(
                f"Invalid NN_INFERENCE_INTERVAL={raw!r}; falling back to {default} seconds."
            )
            return default

    def init_trading_agent(self, config: Dict[str, Any]):
        """Initialize the trading agent with the given configuration.

        Args:
            config: Configuration for the trading agent. Recognized keys (with
                defaults): exchange ("binance"), api_key, api_secret,
                test_mode (True), trade_symbols (self.symbols),
                position_size (0.1), max_trades_per_day (5),
                trade_cooldown_minutes (60).
        """
        exchange_name = config.get("exchange", "binance")
        self.trading_agent = TradingAgent(
            exchange_name=exchange_name,
            api_key=config.get("api_key"),
            api_secret=config.get("api_secret"),
            test_mode=config.get("test_mode", True),
            trade_symbols=config.get("trade_symbols", self.symbols),
            position_size=config.get("position_size", 0.1),
            max_trades_per_day=config.get("max_trades_per_day", 5),
            trade_cooldown_minutes=config.get("trade_cooldown_minutes", 60),
        )
        logger.info(f"Trading agent initialized for {exchange_name} exchange.")

    def start_inference(self):
        """Start the background inference thread (and the trading agent, if any)."""
        if self.is_running:
            logger.warning("Neural network inference is already running.")
            return

        self.is_running = True
        self.stop_event.clear()

        # Start inference thread; daemon so it never blocks interpreter exit.
        self.inference_thread = threading.Thread(
            target=self._inference_loop, name="nn-inference", daemon=True
        )
        self.inference_thread.start()
        logger.info(f"Neural network inference thread started with {self.inference_interval}s interval.")

        # Start trading agent if available
        if self.trading_agent:
            self.trading_agent.start(signal_callback=self._on_trade_executed)

    def stop_inference(self):
        """Stop the inference thread (and the trading agent, if any)."""
        if not self.is_running:
            logger.warning("Neural network inference is not running.")
            return

        logger.info("Stopping neural network inference...")
        self.is_running = False
        # Setting the event also wakes the loop from its between-iteration wait.
        self.stop_event.set()

        if self.inference_thread and self.inference_thread.is_alive():
            self.inference_thread.join(timeout=10)
            if self.inference_thread.is_alive():
                logger.warning("Inference thread did not stop within 10 seconds.")
            else:
                logger.info("Neural network inference stopped.")
        else:
            logger.info("Neural network inference stopped.")

        # Stop trading agent if available
        if self.trading_agent:
            self.trading_agent.stop()

    def _inference_loop(self):
        """Main inference loop that processes data and generates signals.

        Runs until ``is_running`` is cleared or ``stop_event`` is set. Every
        ``inference_interval`` seconds it runs inference for each symbol and
        forwards successful predictions to ``_process_prediction``.
        """
        logger.info("Inference loop started.")
        try:
            while self.is_running and not self.stop_event.is_set():
                current_time = time.time()

                if current_time - self.last_inference_time >= self.inference_interval:
                    for symbol in self.symbols:
                        if self.stop_event.is_set():
                            break
                        # Per-symbol isolation: one failing symbol must not skip the others.
                        try:
                            prediction = self._run_inference(symbol)
                            if prediction is not None:
                                self._process_prediction(symbol, prediction)
                        except Exception:
                            logger.exception(f"Error during inference for {symbol}")
                    # Always advance the timer, even after errors; otherwise the loop
                    # would retry every second and re-emit signals for symbols that
                    # were already processed successfully.
                    self.last_inference_time = current_time

                # Event.wait instead of time.sleep so stop_inference() wakes us immediately.
                self.stop_event.wait(1)
        except Exception:
            logger.exception("Error in inference loop")
        finally:
            logger.info("Inference loop stopped.")

    def _run_inference(self, symbol: str) -> Optional[Tuple[np.ndarray, float]]:
        """Run inference for a specific symbol.

        The timeframe is taken from the ``NN_TIMEFRAME`` environment variable
        (default "1h"), falling back to the first configured timeframe when it
        is not in ``self.timeframes``.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT').

        Returns:
            tuple: (action probabilities, current price) or None if inference failed
            or not enough data was available.
        """
        try:
            model_timeframe = os.environ.get("NN_TIMEFRAME", "1h")
            if model_timeframe not in self.timeframes:
                logger.warning(f"Model timeframe {model_timeframe} not in available timeframes. Using {self.timeframes[0]}.")
                model_timeframe = self.timeframes[0]

            n_candles = self.HISTORY_CANDLES
            logger.info(f"Loading {n_candles} candles from cache for {symbol} at {model_timeframe} timeframe")
            candles = self.data_interface.get_historical_data(
                symbol=symbol,
                timeframe=model_timeframe,
                n_candles=n_candles
            )

            if candles is None or len(candles) < self.window_size:
                logger.warning(f"Not enough data for {symbol} at {model_timeframe} timeframe. Need at least {self.window_size} candles.")
                return None

            # Prepare input data (the returned timestamp is not needed here).
            X, _timestamp = self.data_interface.prepare_model_input(
                data=candles,
                window_size=self.window_size,
                symbol=symbol
            )
            if X is None:
                logger.warning(f"Failed to prepare model input for {symbol}.")
                return None

            # Latest close is used as the reference price for the signal.
            current_price = float(candles['close'].iloc[-1])

            # The model's price prediction is currently unused.
            action_probs, _price_pred = self.model.predict(X)
            return action_probs, current_price
        except Exception:
            logger.exception(f"Error running inference for {symbol}")
            return None

    def _process_prediction(self, symbol: str, prediction: Tuple[np.ndarray, float]):
        """Process a prediction and forward the resulting signal.

        Picks the highest-probability action, logs it, and forwards it to the
        chart and trading agent (whichever are configured), both stamped with
        the same Unix timestamp (seconds).

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT').
            prediction: Tuple of (action probabilities, current price).
        """
        action_probs, current_price = prediction

        probs = np.asarray(action_probs, dtype=float)
        if probs.ndim > 1:
            # NOTE(review): assumes a batched (n_samples, n_actions) output whose
            # last row is the most recent window — confirm against the model.
            probs = probs.reshape(-1, probs.shape[-1])[-1]
        if probs.size == 0:
            logger.warning(f"Empty prediction for {symbol}; skipping.")
            return

        # Get the best action (0=SELL, 1=HOLD, 2=BUY)
        best_action = int(np.argmax(probs))
        if best_action >= len(self.ACTION_NAMES):
            # output_size may exceed the known action set; refuse to guess a name.
            logger.warning(f"Prediction index {best_action} for {symbol} has no known action name; skipping.")
            return
        best_prob = float(probs[best_action])
        action_name = self.ACTION_NAMES[best_action]

        logger.info(f"Inference result for {symbol}: Action={action_name}, Probability={best_prob:.2f}, Price={current_price:.2f}")

        # One timestamp for both consumers so chart and agent see the same signal time.
        timestamp = int(time.time())

        if self.chart:
            self.chart.add_nn_signal(symbol=symbol, signal=action_name, confidence=best_prob, timestamp=timestamp)

        if self.trading_agent:
            self.trading_agent.process_signal(
                symbol=symbol,
                action=action_name,
                confidence=best_prob,
                timestamp=timestamp
            )

    def _on_trade_executed(self, trade_record: Dict[str, Any]):
        """Callback for when a trade is executed; mirrors the trade on the chart.

        Args:
            trade_record: Trade information. 'action' and 'timestamp' are
                required; 'price' and 'pnl' default to 0 when absent.
        """
        if not (self.chart and trade_record):
            return

        try:
            action = trade_record['action']
            timestamp = trade_record['timestamp']
        except KeyError as e:
            # Don't raise into the trading agent's callback path on a malformed record.
            logger.warning(f"Trade record missing required field {e}; not added to chart.")
            return

        price = trade_record.get('price', 0)
        self.chart.add_trade(
            action=action,
            price=price,
            timestamp=timestamp,
            pnl=trade_record.get('pnl', 0)
        )
        logger.info(f"Trade added to chart: {action} at {price:.2f}")

    def get_trading_agent_info(self) -> Optional[Dict[str, Any]]:
        """Get information about the trading agent.

        Returns:
            dict: Keys 'exchange_info', 'positions' and 'trades' (number of
            trades in the agent's history), or None if no agent is available.
        """
        if not self.trading_agent:
            return None
        return {
            'exchange_info': self.trading_agent.get_exchange_info(),
            'positions': self.trading_agent.get_current_positions(),
            'trades': len(self.trading_agent.get_trade_history())
        }