287 lines
11 KiB
Python
287 lines
11 KiB
Python
import logging
|
|
import threading
|
|
import time
|
|
from typing import Dict, Any, List, Optional, Callable, Tuple
|
|
import os
|
|
import numpy as np
|
|
import pandas as pd
|
|
|
|
from .trading_agent import TradingAgent
|
|
|
|
# Module-level logger, named after this module so its level and handlers can
# be configured independently of other modules.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
class NeuralNetworkOrchestrator:
    """Orchestrator for neural network models and trading operations.

    This class coordinates between neural network models and trading agents,
    ensuring that signals from the models are properly processed and trades
    are executed according to the strategy. Inference runs periodically on a
    background daemon thread started by ``start_inference``.
    """

    def __init__(self, model, data_interface, chart=None,
                 symbols: Optional[List[str]] = None,
                 timeframes: Optional[List[str]] = None,
                 window_size: int = 20,
                 num_features: int = 5,
                 output_size: int = 3,
                 models_dir: str = "NN/models/saved",
                 data_dir: str = "NN/data",
                 exchange_config: Optional[Dict[str, Any]] = None,
                 history_candles: int = 1000):
        """Initialize the neural network orchestrator.

        Args:
            model: Neural network model instance; its ``predict(X)`` must
                return ``(action_probabilities, price_prediction)``.
            data_interface: Data interface for retrieving market data.
            chart: Real-time chart for visualization (optional).
            symbols: List of trading symbols (e.g., ['BTC/USDT', 'ETH/USDT']).
                Defaults to ['BTC/USDT'].
            timeframes: List of timeframes to monitor (e.g., ['1m', '5m', '1h']).
                Defaults to ['1m', '5m', '1h', '4h', '1d'].
            window_size: Window size for model input; also the minimum number
                of candles required to run inference.
            num_features: Number of features per datapoint.
            output_size: Number of output classes (e.g., 3 for SELL/HOLD/BUY).
            models_dir: Directory for saved models.
            data_dir: Directory for data storage.
            exchange_config: Configuration for trading agent (exchange, API
                keys, etc.). When falsy, no trading agent is created.
            history_candles: Number of historical candles requested from the
                data interface on each inference pass (default 1000).
        """
        self.model = model
        self.data_interface = data_interface
        self.chart = chart

        self.symbols = symbols or ["BTC/USDT"]
        self.timeframes = timeframes or ["1m", "5m", "1h", "4h", "1d"]
        self.window_size = window_size
        self.num_features = num_features
        self.output_size = output_size
        self.models_dir = models_dir
        self.data_dir = data_dir
        self.history_candles = history_candles

        # Initialize trading agent if configuration provided
        self.trading_agent = None
        if exchange_config:
            self.init_trading_agent(exchange_config)

        # Initialize inference state
        self.is_running = False
        self.inference_thread = None
        self.stop_event = threading.Event()
        self.last_inference_time = 0
        # Seconds between inference passes; overridable via NN_INFERENCE_INTERVAL.
        self.inference_interval = int(os.environ.get("NN_INFERENCE_INTERVAL", "60"))

        logger.info("Initializing NeuralNetworkOrchestrator with:")
        logger.info(f"- Symbols: {', '.join(self.symbols)}")
        logger.info(f"- Timeframes: {', '.join(self.timeframes)}")
        logger.info(f"- Window size: {window_size}")
        logger.info(f"- Num features: {num_features}")
        logger.info(f"- Output size: {output_size}")
        logger.info(f"- Models dir: {models_dir}")
        logger.info(f"- Data dir: {data_dir}")
        logger.info(f"- Inference interval: {self.inference_interval} seconds")

    def init_trading_agent(self, config: Dict[str, Any]):
        """Initialize the trading agent with the given configuration.

        Missing keys fall back to: exchange 'binance', test_mode True, the
        orchestrator's symbols, position_size 0.1, max_trades_per_day 5 and
        trade_cooldown_minutes 60.

        Args:
            config: Configuration for the trading agent.
        """
        exchange_name = config.get("exchange", "binance")

        self.trading_agent = TradingAgent(
            exchange_name=exchange_name,
            api_key=config.get("api_key"),
            api_secret=config.get("api_secret"),
            test_mode=config.get("test_mode", True),
            trade_symbols=config.get("trade_symbols", self.symbols),
            position_size=config.get("position_size", 0.1),
            max_trades_per_day=config.get("max_trades_per_day", 5),
            trade_cooldown_minutes=config.get("trade_cooldown_minutes", 60),
        )

        logger.info(f"Trading agent initialized for {exchange_name} exchange.")

    def start_inference(self):
        """Start the inference thread (and the trading agent, if configured).

        Does nothing but log a warning if inference is already running.
        """
        if self.is_running:
            logger.warning("Neural network inference is already running.")
            return

        self.is_running = True
        self.stop_event.clear()

        # Daemon thread so a forgotten stop_inference() does not block interpreter exit.
        self.inference_thread = threading.Thread(
            target=self._inference_loop, name="nn-inference", daemon=True
        )
        self.inference_thread.start()

        logger.info(f"Neural network inference thread started with {self.inference_interval}s interval.")

        # Start trading agent if available
        if self.trading_agent is not None:
            self.trading_agent.start(signal_callback=self._on_trade_executed)

    def stop_inference(self):
        """Stop the inference thread (and the trading agent, if configured).

        Waits up to 10 seconds for the inference thread to exit.
        """
        if not self.is_running:
            logger.warning("Neural network inference is not running.")
            return

        logger.info("Stopping neural network inference...")
        self.is_running = False
        # Wakes the loop immediately from its stop_event.wait() pause.
        self.stop_event.set()

        if self.inference_thread is not None and self.inference_thread.is_alive():
            self.inference_thread.join(timeout=10)
            if self.inference_thread.is_alive():
                logger.warning("Inference thread did not exit within 10 seconds.")

        logger.info("Neural network inference stopped.")

        # Stop trading agent if available
        if self.trading_agent is not None:
            self.trading_agent.stop()

    def _inference_loop(self):
        """Main inference loop that processes data and generates signals.

        Runs one inference pass at most every ``inference_interval`` seconds
        until ``stop_inference`` is called.
        """
        logger.info("Inference loop started.")

        try:
            while self.is_running and not self.stop_event.is_set():
                current_time = time.time()

                if current_time - self.last_inference_time >= self.inference_interval:
                    self._run_inference_cycle()
                    # Record the attempt even if some symbols failed, so a
                    # persistent error is retried once per interval rather
                    # than once per second.
                    self.last_inference_time = current_time

                # Wait on the stop event instead of sleeping so a stop request
                # is honoured immediately; still avoids busy-looping.
                self.stop_event.wait(1)
        except Exception as e:
            logger.exception(f"Error in inference loop: {str(e)}")
        finally:
            logger.info("Inference loop stopped.")

    def _run_inference_cycle(self):
        """Run one inference pass over every configured symbol.

        Each symbol is isolated: an error while running or processing one
        symbol is logged and does not prevent the remaining symbols from
        being handled.
        """
        for symbol in self.symbols:
            try:
                prediction = self._run_inference(symbol)
                if prediction is not None:
                    self._process_prediction(symbol, prediction)
            except Exception as e:
                logger.error(f"Error during inference for {symbol}: {str(e)}")

    def _run_inference(self, symbol: str) -> Optional[Tuple[np.ndarray, float]]:
        """Run inference for a specific symbol.

        The timeframe comes from the NN_TIMEFRAME environment variable
        (default '1h'); if it is not one of ``self.timeframes`` the first
        configured timeframe is used instead.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT').

        Returns:
            tuple: (action probabilities, current price), or None if there is
            not enough data, input preparation failed, or any error occurred
            (errors are logged, not raised).
        """
        try:
            model_timeframe = os.environ.get("NN_TIMEFRAME", "1h")
            if model_timeframe not in self.timeframes:
                logger.warning(f"Model timeframe {model_timeframe} not in available timeframes. Using {self.timeframes[0]}.")
                model_timeframe = self.timeframes[0]

            # Load candles for the model timeframe
            logger.info(f"Loading {self.history_candles} candles from cache for {symbol} at {model_timeframe} timeframe")
            candles = self.data_interface.get_historical_data(
                symbol=symbol,
                timeframe=model_timeframe,
                n_candles=self.history_candles,
            )

            if candles is None or len(candles) < self.window_size:
                logger.warning(f"Not enough data for {symbol} at {model_timeframe} timeframe. Need at least {self.window_size} candles.")
                return None

            # prepare_model_input returns (model input, timestamp); only the
            # input is needed here.
            X, _timestamp = self.data_interface.prepare_model_input(
                data=candles,
                window_size=self.window_size,
                symbol=symbol,
            )

            if X is None:
                logger.warning(f"Failed to prepare model input for {symbol}.")
                return None

            # Assumes candles is DataFrame-like with a 'close' column (it is
            # indexed that way here) — last row is the most recent candle.
            current_price = candles['close'].iloc[-1]

            # predict returns (action probabilities, price prediction); the
            # price prediction is currently unused.
            action_probs, _price_pred = self.model.predict(X)

            return action_probs, current_price

        except Exception as e:
            logger.error(f"Error running inference for {symbol}: {str(e)}")
            return None

    def _process_prediction(self, symbol: str, prediction: Tuple[np.ndarray, float]):
        """Process a prediction and generate signals.

        The highest-probability action is sent to the chart (if any) and the
        trading agent (if any), both stamped with the same Unix timestamp.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT').
            prediction: Tuple of (action probabilities, current price). The
                probabilities may be 1-D ``(n_actions,)`` or batched, e.g.
                ``(1, n_actions)``; for batched output the last row is used.
        """
        action_probs, current_price = prediction

        probs = np.asarray(action_probs, dtype=float)
        if probs.ndim > 1:
            # Batched model output: use the most recent sample.
            probs = probs.reshape(-1, probs.shape[-1])[-1]

        # Index order is 0=SELL, 1=HOLD, 2=BUY
        action_names = ["SELL", "HOLD", "BUY"]
        best_action = int(np.argmax(probs))
        best_prob = float(probs[best_action])

        if best_action >= len(action_names):
            # output_size is configurable but only three actions are named.
            logger.warning(f"Model output index {best_action} has no action name; ignoring prediction for {symbol}.")
            return

        action_name = action_names[best_action]

        logger.info(f"Inference result for {symbol}: Action={action_name}, Probability={best_prob:.2f}, Price={current_price:.2f}")

        # One timestamp so the chart and the agent agree on when the signal fired.
        signal_time = int(time.time())

        if self.chart is not None:
            self.chart.add_nn_signal(symbol=symbol, signal=action_name, confidence=best_prob, timestamp=signal_time)

        if self.trading_agent is not None:
            self.trading_agent.process_signal(
                symbol=symbol,
                action=action_name,
                confidence=best_prob,
                timestamp=signal_time,
            )

    def _on_trade_executed(self, trade_record: Dict[str, Any]):
        """Callback for when a trade is executed; plots the trade on the chart.

        Does nothing when there is no chart or the record is empty.

        Args:
            trade_record: Trade information. 'action' and 'timestamp' are
                required; 'price' and 'pnl' default to 0.
        """
        if self.chart is None or not trade_record:
            return

        price = trade_record.get('price', 0)
        self.chart.add_trade(
            action=trade_record['action'],
            price=price,
            timestamp=trade_record['timestamp'],
            pnl=trade_record.get('pnl', 0),
        )

        logger.info(f"Trade added to chart: {trade_record['action']} at {price:.2f}")

    def get_trading_agent_info(self) -> Optional[Dict[str, Any]]:
        """Get information about the trading agent.

        Returns:
            dict: 'exchange_info', 'positions' and 'trades' (number of trades
            in the agent's history), or None if no agent is available.
        """
        if self.trading_agent is None:
            return None

        return {
            'exchange_info': self.trading_agent.get_exchange_info(),
            'positions': self.trading_agent.get_current_positions(),
            'trades': len(self.trading_agent.get_trade_history()),
        }