showing trades on realtime chart - chart broken
This commit is contained in:
287
NN/neural_network_orchestrator.py
Normal file
287
NN/neural_network_orchestrator.py
Normal file
@ -0,0 +1,287 @@
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import Dict, Any, List, Optional, Callable, Tuple
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from .trading_agent import TradingAgent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class NeuralNetworkOrchestrator:
|
||||
"""Orchestrator for neural network models and trading operations.
|
||||
|
||||
This class coordinates between neural network models and trading agents,
|
||||
ensuring that signals from the models are properly processed and trades
|
||||
are executed according to the strategy.
|
||||
"""
|
||||
|
||||
def __init__(self, model, data_interface, chart=None,
|
||||
symbols: List[str] = None,
|
||||
timeframes: List[str] = None,
|
||||
window_size: int = 20,
|
||||
num_features: int = 5,
|
||||
output_size: int = 3,
|
||||
models_dir: str = "NN/models/saved",
|
||||
data_dir: str = "NN/data",
|
||||
exchange_config: Dict[str, Any] = None):
|
||||
"""Initialize the neural network orchestrator.
|
||||
|
||||
Args:
|
||||
model: Neural network model instance
|
||||
data_interface: Data interface for retrieving market data
|
||||
chart: Real-time chart for visualization (optional)
|
||||
symbols: List of trading symbols (e.g., ['BTC/USDT', 'ETH/USDT'])
|
||||
timeframes: List of timeframes to monitor (e.g., ['1m', '5m', '1h'])
|
||||
window_size: Window size for model input
|
||||
num_features: Number of features per datapoint
|
||||
output_size: Number of output classes (e.g., 3 for BUY/HOLD/SELL)
|
||||
models_dir: Directory for saved models
|
||||
data_dir: Directory for data storage
|
||||
exchange_config: Configuration for trading agent (exchange, API keys, etc.)
|
||||
"""
|
||||
self.model = model
|
||||
self.data_interface = data_interface
|
||||
self.chart = chart
|
||||
|
||||
self.symbols = symbols or ["BTC/USDT"]
|
||||
self.timeframes = timeframes or ["1m", "5m", "1h", "4h", "1d"]
|
||||
self.window_size = window_size
|
||||
self.num_features = num_features
|
||||
self.output_size = output_size
|
||||
self.models_dir = models_dir
|
||||
self.data_dir = data_dir
|
||||
|
||||
# Initialize trading agent if configuration provided
|
||||
self.trading_agent = None
|
||||
if exchange_config:
|
||||
self.init_trading_agent(exchange_config)
|
||||
|
||||
# Initialize inference state
|
||||
self.is_running = False
|
||||
self.inference_thread = None
|
||||
self.stop_event = threading.Event()
|
||||
self.last_inference_time = 0
|
||||
self.inference_interval = int(os.environ.get("NN_INFERENCE_INTERVAL", "60"))
|
||||
|
||||
logger.info(f"Initializing NeuralNetworkOrchestrator with:")
|
||||
logger.info(f"- Symbol: {self.symbols[0]}")
|
||||
logger.info(f"- Timeframes: {', '.join(self.timeframes)}")
|
||||
logger.info(f"- Window size: {window_size}")
|
||||
logger.info(f"- Num features: {num_features}")
|
||||
logger.info(f"- Output size: {output_size}")
|
||||
logger.info(f"- Models dir: {models_dir}")
|
||||
logger.info(f"- Data dir: {data_dir}")
|
||||
logger.info(f"- Inference interval: {self.inference_interval} seconds")
|
||||
|
||||
def init_trading_agent(self, config: Dict[str, Any]):
|
||||
"""Initialize the trading agent with the given configuration.
|
||||
|
||||
Args:
|
||||
config: Configuration for the trading agent
|
||||
"""
|
||||
exchange_name = config.get("exchange", "binance")
|
||||
api_key = config.get("api_key")
|
||||
api_secret = config.get("api_secret")
|
||||
test_mode = config.get("test_mode", True)
|
||||
trade_symbols = config.get("trade_symbols", self.symbols)
|
||||
position_size = config.get("position_size", 0.1)
|
||||
max_trades_per_day = config.get("max_trades_per_day", 5)
|
||||
trade_cooldown_minutes = config.get("trade_cooldown_minutes", 60)
|
||||
|
||||
self.trading_agent = TradingAgent(
|
||||
exchange_name=exchange_name,
|
||||
api_key=api_key,
|
||||
api_secret=api_secret,
|
||||
test_mode=test_mode,
|
||||
trade_symbols=trade_symbols,
|
||||
position_size=position_size,
|
||||
max_trades_per_day=max_trades_per_day,
|
||||
trade_cooldown_minutes=trade_cooldown_minutes
|
||||
)
|
||||
|
||||
logger.info(f"Trading agent initialized for {exchange_name} exchange.")
|
||||
|
||||
def start_inference(self):
|
||||
"""Start the inference thread."""
|
||||
if self.is_running:
|
||||
logger.warning("Neural network inference is already running.")
|
||||
return
|
||||
|
||||
self.is_running = True
|
||||
self.stop_event.clear()
|
||||
|
||||
# Start inference thread
|
||||
self.inference_thread = threading.Thread(target=self._inference_loop)
|
||||
self.inference_thread.daemon = True
|
||||
self.inference_thread.start()
|
||||
|
||||
logger.info(f"Neural network inference thread started with {self.inference_interval}s interval.")
|
||||
|
||||
# Start trading agent if available
|
||||
if self.trading_agent:
|
||||
self.trading_agent.start(signal_callback=self._on_trade_executed)
|
||||
|
||||
def stop_inference(self):
|
||||
"""Stop the inference thread."""
|
||||
if not self.is_running:
|
||||
logger.warning("Neural network inference is not running.")
|
||||
return
|
||||
|
||||
logger.info("Stopping neural network inference...")
|
||||
self.is_running = False
|
||||
self.stop_event.set()
|
||||
|
||||
if self.inference_thread and self.inference_thread.is_alive():
|
||||
self.inference_thread.join(timeout=10)
|
||||
|
||||
logger.info("Neural network inference stopped.")
|
||||
|
||||
# Stop trading agent if available
|
||||
if self.trading_agent:
|
||||
self.trading_agent.stop()
|
||||
|
||||
def _inference_loop(self):
|
||||
"""Main inference loop that processes data and generates signals."""
|
||||
logger.info("Inference loop started.")
|
||||
|
||||
try:
|
||||
while self.is_running and not self.stop_event.is_set():
|
||||
current_time = time.time()
|
||||
|
||||
# Check if we should run inference
|
||||
if current_time - self.last_inference_time >= self.inference_interval:
|
||||
try:
|
||||
# Run inference for all symbols
|
||||
for symbol in self.symbols:
|
||||
prediction = self._run_inference(symbol)
|
||||
if prediction:
|
||||
self._process_prediction(symbol, prediction)
|
||||
|
||||
self.last_inference_time = current_time
|
||||
except Exception as e:
|
||||
logger.error(f"Error during inference: {str(e)}")
|
||||
|
||||
# Sleep for a short time to prevent CPU hogging
|
||||
time.sleep(1)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in inference loop: {str(e)}")
|
||||
finally:
|
||||
logger.info("Inference loop stopped.")
|
||||
|
||||
def _run_inference(self, symbol: str) -> Optional[Tuple[np.ndarray, float]]:
|
||||
"""Run inference for a specific symbol.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol (e.g., 'BTC/USDT')
|
||||
|
||||
Returns:
|
||||
tuple: (action probabilities, current price) or None if inference failed
|
||||
"""
|
||||
try:
|
||||
# Get the model timeframe from environment
|
||||
model_timeframe = os.environ.get("NN_TIMEFRAME", "1h")
|
||||
if model_timeframe not in self.timeframes:
|
||||
logger.warning(f"Model timeframe {model_timeframe} not in available timeframes. Using {self.timeframes[0]}.")
|
||||
model_timeframe = self.timeframes[0]
|
||||
|
||||
# Load candles for the model timeframe
|
||||
logger.info(f"Loading {1000} candles from cache for {symbol} at {model_timeframe} timeframe")
|
||||
candles = self.data_interface.get_historical_data(
|
||||
symbol=symbol,
|
||||
timeframe=model_timeframe,
|
||||
n_candles=1000
|
||||
)
|
||||
|
||||
if candles is None or len(candles) < self.window_size:
|
||||
logger.warning(f"Not enough data for {symbol} at {model_timeframe} timeframe. Need at least {self.window_size} candles.")
|
||||
return None
|
||||
|
||||
# Prepare input data
|
||||
X, timestamp = self.data_interface.prepare_model_input(
|
||||
data=candles,
|
||||
window_size=self.window_size,
|
||||
symbol=symbol
|
||||
)
|
||||
|
||||
if X is None:
|
||||
logger.warning(f"Failed to prepare model input for {symbol}.")
|
||||
return None
|
||||
|
||||
# Get current price
|
||||
current_price = candles['close'].iloc[-1]
|
||||
|
||||
# Run model inference
|
||||
action_probs, price_pred = self.model.predict(X)
|
||||
|
||||
return action_probs, current_price
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error running inference for {symbol}: {str(e)}")
|
||||
return None
|
||||
|
||||
def _process_prediction(self, symbol: str, prediction: Tuple[np.ndarray, float]):
|
||||
"""Process a prediction and generate signals.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol (e.g., 'BTC/USDT')
|
||||
prediction: Tuple of (action probabilities, current price)
|
||||
"""
|
||||
action_probs, current_price = prediction
|
||||
|
||||
# Get the best action (0=SELL, 1=HOLD, 2=BUY)
|
||||
best_action = np.argmax(action_probs)
|
||||
best_prob = float(action_probs[best_action])
|
||||
|
||||
# Convert to action name
|
||||
action_names = ["SELL", "HOLD", "BUY"]
|
||||
action_name = action_names[best_action]
|
||||
|
||||
# Log the prediction
|
||||
logger.info(f"Inference result for {symbol}: Action={action_name}, Probability={best_prob:.2f}, Price={current_price:.2f}")
|
||||
|
||||
# Add signal to chart if available
|
||||
if self.chart:
|
||||
self.chart.add_nn_signal(symbol=symbol, signal=action_name, confidence=best_prob, timestamp=int(time.time()))
|
||||
|
||||
# Process signal with trading agent if available
|
||||
if self.trading_agent:
|
||||
self.trading_agent.process_signal(
|
||||
symbol=symbol,
|
||||
action=action_name,
|
||||
confidence=best_prob,
|
||||
timestamp=int(time.time())
|
||||
)
|
||||
|
||||
def _on_trade_executed(self, trade_record: Dict[str, Any]):
|
||||
"""Callback for when a trade is executed.
|
||||
|
||||
Args:
|
||||
trade_record: Trade information
|
||||
"""
|
||||
if self.chart and trade_record:
|
||||
# Add trade to chart
|
||||
self.chart.add_trade(
|
||||
action=trade_record['action'],
|
||||
price=trade_record.get('price', 0),
|
||||
timestamp=trade_record['timestamp'],
|
||||
pnl=trade_record.get('pnl', 0)
|
||||
)
|
||||
|
||||
logger.info(f"Trade added to chart: {trade_record['action']} at {trade_record.get('price', 0):.2f}")
|
||||
|
||||
def get_trading_agent_info(self) -> Dict[str, Any]:
|
||||
"""Get information about the trading agent.
|
||||
|
||||
Returns:
|
||||
dict: Trading agent information or None if no agent is available
|
||||
"""
|
||||
if self.trading_agent:
|
||||
return {
|
||||
'exchange_info': self.trading_agent.get_exchange_info(),
|
||||
'positions': self.trading_agent.get_current_positions(),
|
||||
'trades': len(self.trading_agent.get_trade_history())
|
||||
}
|
||||
return None
|
Reference in New Issue
Block a user