From a46b2c74f8f2fd7d4e9e7272f7ecc5bb46852223 Mon Sep 17 00:00:00 2001 From: Dobromir Popov Date: Mon, 31 Mar 2025 14:22:33 +0300 Subject: [PATCH] showing trades on realtime chart - chart broken --- NN/exchanges/__init__.py | 5 + NN/exchanges/binance_interface.py | 276 +++++++ NN/exchanges/exchange_interface.py | 191 +++++ NN/exchanges/mexc_interface.py | 258 +++++++ NN/main.py | 244 ++++++ NN/models/cnn_model_pytorch.py | 172 +++-- NN/neural_network_orchestrator.py | 287 ++++++++ NN/realtime_main.py | 109 ++- NN/trading_agent.py | 310 ++++++++ NN/train_rl.py | 18 + _notes.md | 2 + launch_training.py | 124 ++++ main.py | 1107 ++++++++++++++++++++++++++++ trading_main.py | 155 ++++ 14 files changed, 3182 insertions(+), 76 deletions(-) create mode 100644 NN/exchanges/__init__.py create mode 100644 NN/exchanges/binance_interface.py create mode 100644 NN/exchanges/exchange_interface.py create mode 100644 NN/exchanges/mexc_interface.py create mode 100644 NN/main.py create mode 100644 NN/neural_network_orchestrator.py create mode 100644 NN/trading_agent.py create mode 100644 launch_training.py create mode 100644 trading_main.py diff --git a/NN/exchanges/__init__.py b/NN/exchanges/__init__.py new file mode 100644 index 0000000..56fd7de --- /dev/null +++ b/NN/exchanges/__init__.py @@ -0,0 +1,5 @@ +from .exchange_interface import ExchangeInterface +from .mexc_interface import MEXCInterface +from .binance_interface import BinanceInterface + +__all__ = ['ExchangeInterface', 'MEXCInterface', 'BinanceInterface'] \ No newline at end of file diff --git a/NN/exchanges/binance_interface.py b/NN/exchanges/binance_interface.py new file mode 100644 index 0000000..561613e --- /dev/null +++ b/NN/exchanges/binance_interface.py @@ -0,0 +1,276 @@ +import logging +import time +from typing import Dict, Any, List, Optional +import requests +import hmac +import hashlib +from urllib.parse import urlencode + +from .exchange_interface import ExchangeInterface + +logger = logging.getLogger(__name__) + +class BinanceInterface(ExchangeInterface): + """Binance Exchange API Interface""" + + def __init__(self, api_key: str = None, api_secret: str = None, test_mode: bool = True): + """Initialize Binance exchange interface. + + Args: + api_key: Binance API key + api_secret: Binance API secret + test_mode: If True, use testnet environment + """ + super().__init__(api_key, api_secret, test_mode) + + # Use testnet URLs if in test mode + if test_mode: + self.base_url = "https://testnet.binance.vision" + else: + self.base_url = "https://api.binance.com" + + self.api_version = "v3" + + def connect(self) -> bool: + """Connect to Binance API. This is a no-op for REST API.""" + if not self.api_key or not self.api_secret: + logger.warning("Binance API credentials not provided. Running in read-only mode.") + return False + + try: + # Test connection by pinging server and checking account info + ping_result = self._send_public_request('GET', 'ping') + + if self.api_key and self.api_secret: + # Check account connectivity + self.get_account_info() + + logger.info(f"Successfully connected to Binance API ({'testnet' if self.test_mode else 'live'})") + return True + except Exception as e: + logger.error(f"Failed to connect to Binance API: {str(e)}") + return False + + def _generate_signature(self, params: Dict[str, Any]) -> str: + """Generate signature for authenticated requests.""" + query_string = urlencode(params) + signature = hmac.new( + self.api_secret.encode('utf-8'), + query_string.encode('utf-8'), + hashlib.sha256 + ).hexdigest() + return signature + + def _send_public_request(self, method: str, endpoint: str, params: Dict[str, Any] = None) -> Dict[str, Any]: + """Send public request to Binance API.""" + url = f"{self.base_url}/api/{self.api_version}/{endpoint}" + + try: + if method.upper() == 'GET': + response = requests.get(url, params=params) + else: + response = requests.post(url, json=params) + + response.raise_for_status() + return response.json() + except Exception as e: + logger.error(f"Error in public request to {endpoint}: {str(e)}") + raise + + def _send_private_request(self, method: str, endpoint: str, params: Dict[str, Any] = None) -> Dict[str, Any]: + """Send private/authenticated request to Binance API.""" + if not self.api_key or not self.api_secret: + raise ValueError("API key and secret are required for private requests") + + if params is None: + params = {} + + # Add timestamp + params['timestamp'] = int(time.time() * 1000) + + # Generate signature + signature = self._generate_signature(params) + params['signature'] = signature + + # Set headers + headers = { + 'X-MBX-APIKEY': self.api_key + } + + url = f"{self.base_url}/api/{self.api_version}/{endpoint}" + + try: + if method.upper() == 'GET': + response = requests.get(url, params=params, headers=headers) + elif method.upper() == 'POST': + response = requests.post(url, data=params, headers=headers) + elif method.upper() == 'DELETE': + response = requests.delete(url, params=params, headers=headers) + else: + raise ValueError(f"Unsupported HTTP method: {method}") + + # Log detailed error if available + if response.status_code != 200: + logger.error(f"Binance API error: {response.text}") + + response.raise_for_status() + return response.json() + except Exception as e: + logger.error(f"Error in private request to {endpoint}: {str(e)}") + raise + + def get_account_info(self) -> Dict[str, Any]: + """Get account information.""" + return self._send_private_request('GET', 'account') + + def get_balance(self, asset: str) -> float: + """Get balance of a specific asset. + + Args: + asset: Asset symbol (e.g., 'BTC', 'USDT') + + Returns: + float: Available balance of the asset + """ + try: + account_info = self._send_private_request('GET', 'account') + balances = account_info.get('balances', []) + + for balance in balances: + if balance['asset'] == asset: + return float(balance['free']) + + # Asset not found + return 0.0 + except Exception as e: + logger.error(f"Error getting balance for {asset}: {str(e)}") + return 0.0 + + def get_ticker(self, symbol: str) -> Dict[str, Any]: + """Get current ticker data for a symbol. + + Args: + symbol: Trading symbol (e.g., 'BTC/USDT') + + Returns: + dict: Ticker data including price information + """ + binance_symbol = symbol.replace('/', '') + try: + ticker = self._send_public_request('GET', 'ticker/24hr', {'symbol': binance_symbol}) + + # Convert to a standardized format + result = { + 'symbol': symbol, + 'bid': float(ticker['bidPrice']), + 'ask': float(ticker['askPrice']), + 'last': float(ticker['lastPrice']), + 'volume': float(ticker['volume']), + 'timestamp': int(ticker['closeTime']) + } + return result + except Exception as e: + logger.error(f"Error getting ticker for {symbol}: {str(e)}") + raise + + def place_order(self, symbol: str, side: str, order_type: str, + quantity: float, price: float = None) -> Dict[str, Any]: + """Place an order on the exchange. + + Args: + symbol: Trading symbol (e.g., 'BTC/USDT') + side: Order side ('buy' or 'sell') + order_type: Order type ('market', 'limit', etc.) + quantity: Order quantity + price: Order price (for limit orders) + + Returns: + dict: Order information including order ID + """ + binance_symbol = symbol.replace('/', '') + params = { + 'symbol': binance_symbol, + 'side': side.upper(), + 'type': order_type.upper(), + 'quantity': quantity, + } + + if order_type.lower() == 'limit' and price is not None: + params['price'] = price + params['timeInForce'] = 'GTC' # Good Till Cancelled + + # Use test order endpoint in test mode + endpoint = 'order/test' if self.test_mode else 'order' + + try: + order_result = self._send_private_request('POST', endpoint, params) + return order_result + except Exception as e: + logger.error(f"Error placing {side} {order_type} order for {symbol}: {str(e)}") + raise + + def cancel_order(self, symbol: str, order_id: str) -> bool: + """Cancel an existing order. + + Args: + symbol: Trading symbol (e.g., 'BTC/USDT') + order_id: ID of the order to cancel + + Returns: + bool: True if cancellation successful, False otherwise + """ + binance_symbol = symbol.replace('/', '') + params = { + 'symbol': binance_symbol, + 'orderId': order_id + } + + try: + cancel_result = self._send_private_request('DELETE', 'order', params) + return True + except Exception as e: + logger.error(f"Error cancelling order {order_id} for {symbol}: {str(e)}") + return False + + def get_order_status(self, symbol: str, order_id: str) -> Dict[str, Any]: + """Get status of an existing order. + + Args: + symbol: Trading symbol (e.g., 'BTC/USDT') + order_id: ID of the order + + Returns: + dict: Order status information + """ + binance_symbol = symbol.replace('/', '') + params = { + 'symbol': binance_symbol, + 'orderId': order_id + } + + try: + order_info = self._send_private_request('GET', 'order', params) + return order_info + except Exception as e: + logger.error(f"Error getting order status for {order_id} on {symbol}: {str(e)}") + raise + + def get_open_orders(self, symbol: str = None) -> List[Dict[str, Any]]: + """Get all open orders, optionally filtered by symbol. + + Args: + symbol: Trading symbol (e.g., 'BTC/USDT'), or None for all symbols + + Returns: + list: List of open orders + """ + params = {} + if symbol: + params['symbol'] = symbol.replace('/', '') + + try: + open_orders = self._send_private_request('GET', 'openOrders', params) + return open_orders + except Exception as e: + logger.error(f"Error getting open orders: {str(e)}") + return [] \ No newline at end of file diff --git a/NN/exchanges/exchange_interface.py b/NN/exchanges/exchange_interface.py new file mode 100644 index 0000000..4773d1c --- /dev/null +++ b/NN/exchanges/exchange_interface.py @@ -0,0 +1,191 @@ +import abc +import logging +from typing import Dict, Any, List, Tuple, Optional + +logger = logging.getLogger(__name__) + +class ExchangeInterface(abc.ABC): + """Base class for all exchange interfaces. + + This abstract class defines the required methods that all exchange + implementations must provide to ensure compatibility with the trading system. + """ + + def __init__(self, api_key: str = None, api_secret: str = None, test_mode: bool = True): + """Initialize the exchange interface. + + Args: + api_key: API key for the exchange + api_secret: API secret for the exchange + test_mode: If True, use test/sandbox environment + """ + self.api_key = api_key + self.api_secret = api_secret + self.test_mode = test_mode + self.client = None + self.last_price_cache = {} + + @abc.abstractmethod + def connect(self) -> bool: + """Connect to the exchange API. + + Returns: + bool: True if connection successful, False otherwise + """ + pass + + @abc.abstractmethod + def get_balance(self, asset: str) -> float: + """Get balance of a specific asset. + + Args: + asset: Asset symbol (e.g., 'BTC', 'USDT') + + Returns: + float: Available balance of the asset + """ + pass + + @abc.abstractmethod + def get_ticker(self, symbol: str) -> Dict[str, Any]: + """Get current ticker data for a symbol. + + Args: + symbol: Trading symbol (e.g., 'BTC/USDT') + + Returns: + dict: Ticker data including price information + """ + pass + + @abc.abstractmethod + def place_order(self, symbol: str, side: str, order_type: str, + quantity: float, price: float = None) -> Dict[str, Any]: + """Place an order on the exchange. + + Args: + symbol: Trading symbol (e.g., 'BTC/USDT') + side: Order side ('buy' or 'sell') + order_type: Order type ('market', 'limit', etc.) + quantity: Order quantity + price: Order price (for limit orders) + + Returns: + dict: Order information including order ID + """ + pass + + @abc.abstractmethod + def cancel_order(self, symbol: str, order_id: str) -> bool: + """Cancel an existing order. + + Args: + symbol: Trading symbol (e.g., 'BTC/USDT') + order_id: ID of the order to cancel + + Returns: + bool: True if cancellation successful, False otherwise + """ + pass + + @abc.abstractmethod + def get_order_status(self, symbol: str, order_id: str) -> Dict[str, Any]: + """Get status of an existing order. + + Args: + symbol: Trading symbol (e.g., 'BTC/USDT') + order_id: ID of the order + + Returns: + dict: Order status information + """ + pass + + @abc.abstractmethod + def get_open_orders(self, symbol: str = None) -> List[Dict[str, Any]]: + """Get all open orders, optionally filtered by symbol. + + Args: + symbol: Trading symbol (e.g., 'BTC/USDT'), or None for all symbols + + Returns: + list: List of open orders + """ + pass + + def get_last_price(self, symbol: str) -> float: + """Get last known price for a symbol, may use cached value. + + Args: + symbol: Trading symbol (e.g., 'BTC/USDT') + + Returns: + float: Last price + """ + try: + ticker = self.get_ticker(symbol) + price = float(ticker['last']) + self.last_price_cache[symbol] = price + return price + except Exception as e: + logger.error(f"Error getting price for {symbol}: {str(e)}") + # Return cached price if available + return self.last_price_cache.get(symbol, 0.0) + + def execute_trade(self, symbol: str, action: str, quantity: float = None, + percent_of_balance: float = None) -> Optional[Dict[str, Any]]: + """Execute a trade based on a signal. + + Args: + symbol: Trading symbol (e.g., 'BTC/USDT') + action: Trade action ('BUY', 'SELL') + quantity: Specific quantity to trade + percent_of_balance: Alternative to quantity - percentage of available balance to use + + Returns: + dict: Order information or None if order failed + """ + if action not in ['BUY', 'SELL']: + logger.error(f"Invalid action: {action}. Must be 'BUY' or 'SELL'") + return None + + side = action.lower() + + try: + # Determine base and quote assets from symbol (e.g., BTC/USDT -> BTC, USDT) + base_asset, quote_asset = symbol.split('/') + + # Calculate quantity if percent_of_balance is provided + if quantity is None and percent_of_balance is not None: + if percent_of_balance <= 0 or percent_of_balance > 1: + logger.error(f"Invalid percent_of_balance: {percent_of_balance}. Must be between 0 and 1") + return None + + if side == 'buy': + # For buy, use quote asset (e.g., USDT) + balance = self.get_balance(quote_asset) + price = self.get_last_price(symbol) + quantity = (balance * percent_of_balance) / price + else: + # For sell, use base asset (e.g., BTC) + balance = self.get_balance(base_asset) + quantity = balance * percent_of_balance + + if not quantity or quantity <= 0: + logger.error(f"Invalid quantity: {quantity}") + return None + + # Place market order + order = self.place_order( + symbol=symbol, + side=side, + order_type='market', + quantity=quantity + ) + + logger.info(f"Executed {side.upper()} order for {quantity} {base_asset} at market price") + return order + + except Exception as e: + logger.error(f"Error executing {action} trade for {symbol}: {str(e)}") + return None \ No newline at end of file diff --git a/NN/exchanges/mexc_interface.py b/NN/exchanges/mexc_interface.py new file mode 100644 index 0000000..a9b7b68 --- /dev/null +++ b/NN/exchanges/mexc_interface.py @@ -0,0 +1,258 @@ +import logging +import time +from typing import Dict, Any, List, Optional +import requests +import hmac +import hashlib +from urllib.parse import urlencode + +from .exchange_interface import ExchangeInterface + +logger = logging.getLogger(__name__) + +class MEXCInterface(ExchangeInterface): + """MEXC Exchange API Interface""" + + def __init__(self, api_key: str = None, api_secret: str = None, test_mode: bool = True): + """Initialize MEXC exchange interface. + + Args: + api_key: MEXC API key + api_secret: MEXC API secret + test_mode: If True, use test/sandbox environment (Note: MEXC doesn't have a true sandbox) + """ + super().__init__(api_key, api_secret, test_mode) + self.base_url = "https://api.mexc.com" + self.api_version = "v3" + + def connect(self) -> bool: + """Connect to MEXC API. This is a no-op for REST API.""" + if not self.api_key or not self.api_secret: + logger.warning("MEXC API credentials not provided. Running in read-only mode.") + return False + + try: + # Test connection by getting account info + self.get_account_info() + logger.info("Successfully connected to MEXC API") + return True + except Exception as e: + logger.error(f"Failed to connect to MEXC API: {str(e)}") + return False + + def _generate_signature(self, params: Dict[str, Any]) -> str: + """Generate signature for authenticated requests.""" + query_string = urlencode(params) + signature = hmac.new( + self.api_secret.encode('utf-8'), + query_string.encode('utf-8'), + hashlib.sha256 + ).hexdigest() + return signature + + def _send_public_request(self, method: str, endpoint: str, params: Dict[str, Any] = None) -> Dict[str, Any]: + """Send public request to MEXC API.""" + url = f"{self.base_url}/{self.api_version}/{endpoint}" + + try: + if method.upper() == 'GET': + response = requests.get(url, params=params) + else: + response = requests.post(url, json=params) + + response.raise_for_status() + return response.json() + except Exception as e: + logger.error(f"Error in public request to {endpoint}: {str(e)}") + raise + + def _send_private_request(self, method: str, endpoint: str, params: Dict[str, Any] = None) -> Dict[str, Any]: + """Send private/authenticated request to MEXC API.""" + if not self.api_key or not self.api_secret: + raise ValueError("API key and secret are required for private requests") + + if params is None: + params = {} + + # Add timestamp + params['timestamp'] = int(time.time() * 1000) + + # Generate signature + signature = self._generate_signature(params) + params['signature'] = signature + + # Set headers + headers = { + 'X-MEXC-APIKEY': self.api_key + } + + url = f"{self.base_url}/{self.api_version}/{endpoint}" + + try: + if method.upper() == 'GET': + response = requests.get(url, params=params, headers=headers) + elif method.upper() == 'POST': + response = requests.post(url, json=params, headers=headers) + elif method.upper() == 'DELETE': + response = requests.delete(url, params=params, headers=headers) + else: + raise ValueError(f"Unsupported HTTP method: {method}") + + response.raise_for_status() + return response.json() + except Exception as e: + logger.error(f"Error in private request to {endpoint}: {str(e)}") + raise + + def get_account_info(self) -> Dict[str, Any]: + """Get account information.""" + return self._send_private_request('GET', 'account') + + def get_balance(self, asset: str) -> float: + """Get balance of a specific asset. + + Args: + asset: Asset symbol (e.g., 'BTC', 'USDT') + + Returns: + float: Available balance of the asset + """ + try: + account_info = self._send_private_request('GET', 'account') + balances = account_info.get('balances', []) + + for balance in balances: + if balance['asset'] == asset: + return float(balance['free']) + + # Asset not found + return 0.0 + except Exception as e: + logger.error(f"Error getting balance for {asset}: {str(e)}") + return 0.0 + + def get_ticker(self, symbol: str) -> Dict[str, Any]: + """Get current ticker data for a symbol. + + Args: + symbol: Trading symbol (e.g., 'BTC/USDT') + + Returns: + dict: Ticker data including price information + """ + mexc_symbol = symbol.replace('/', '') + try: + ticker = self._send_public_request('GET', 'ticker/24hr', {'symbol': mexc_symbol}) + + # Convert to a standardized format + result = { + 'symbol': symbol, + 'bid': float(ticker['bidPrice']), + 'ask': float(ticker['askPrice']), + 'last': float(ticker['lastPrice']), + 'volume': float(ticker['volume']), + 'timestamp': int(ticker['closeTime']) + } + return result + except Exception as e: + logger.error(f"Error getting ticker for {symbol}: {str(e)}") + raise + + def place_order(self, symbol: str, side: str, order_type: str, + quantity: float, price: float = None) -> Dict[str, Any]: + """Place an order on the exchange. + + Args: + symbol: Trading symbol (e.g., 'BTC/USDT') + side: Order side ('buy' or 'sell') + order_type: Order type ('market', 'limit', etc.) + quantity: Order quantity + price: Order price (for limit orders) + + Returns: + dict: Order information including order ID + """ + mexc_symbol = symbol.replace('/', '') + params = { + 'symbol': mexc_symbol, + 'side': side.upper(), + 'type': order_type.upper(), + 'quantity': quantity, + } + + if order_type.lower() == 'limit' and price is not None: + params['price'] = price + params['timeInForce'] = 'GTC' # Good Till Cancelled + + try: + order_result = self._send_private_request('POST', 'order', params) + return order_result + except Exception as e: + logger.error(f"Error placing {side} {order_type} order for {symbol}: {str(e)}") + raise + + def cancel_order(self, symbol: str, order_id: str) -> bool: + """Cancel an existing order. + + Args: + symbol: Trading symbol (e.g., 'BTC/USDT') + order_id: ID of the order to cancel + + Returns: + bool: True if cancellation successful, False otherwise + """ + mexc_symbol = symbol.replace('/', '') + params = { + 'symbol': mexc_symbol, + 'orderId': order_id + } + + try: + cancel_result = self._send_private_request('DELETE', 'order', params) + return True + except Exception as e: + logger.error(f"Error cancelling order {order_id} for {symbol}: {str(e)}") + return False + + def get_order_status(self, symbol: str, order_id: str) -> Dict[str, Any]: + """Get status of an existing order. + + Args: + symbol: Trading symbol (e.g., 'BTC/USDT') + order_id: ID of the order + + Returns: + dict: Order status information + """ + mexc_symbol = symbol.replace('/', '') + params = { + 'symbol': mexc_symbol, + 'orderId': order_id + } + + try: + order_info = self._send_private_request('GET', 'order', params) + return order_info + except Exception as e: + logger.error(f"Error getting order status for {order_id} on {symbol}: {str(e)}") + raise + + def get_open_orders(self, symbol: str = None) -> List[Dict[str, Any]]: + """Get all open orders, optionally filtered by symbol. + + Args: + symbol: Trading symbol (e.g., 'BTC/USDT'), or None for all symbols + + Returns: + list: List of open orders + """ + params = {} + if symbol: + params['symbol'] = symbol.replace('/', '') + + try: + open_orders = self._send_private_request('GET', 'openOrders', params) + return open_orders + except Exception as e: + logger.error(f"Error getting open orders: {str(e)}") + return [] \ No newline at end of file diff --git a/NN/main.py b/NN/main.py new file mode 100644 index 0000000..11d9241 --- /dev/null +++ b/NN/main.py @@ -0,0 +1,244 @@ +""" +Neural Network Trading System Main Module (Compatibility Layer) + +This module serves as a compatibility layer for the realtime.py module. +It re-exports the functionality from realtime_main.py that is needed by realtime.py. +""" + +import os +import sys +import logging +from datetime import datetime +import numpy as np + +# Configure logging +logger = logging.getLogger('NN') +logger.setLevel(logging.INFO) + +# Re-export everything from realtime_main.py +from .realtime_main import ( + parse_arguments, + realtime, + train, + predict +) + +# Create a class that realtime.py expects +class NeuralNetworkOrchestrator: + """ + Orchestrates the neural network operations. + """ + + def __init__(self, config): + """ + Initialize the orchestrator with configuration. + + Args: + config (dict): Configuration parameters + """ + self.config = config + self.symbol = config.get('symbol', 'BTC/USDT') + self.timeframes = config.get('timeframes', ['1m', '5m', '1h', '4h']) + self.window_size = config.get('window_size', 20) + self.n_features = config.get('n_features', 5) + self.output_size = config.get('output_size', 3) + self.model_dir = config.get('model_dir', 'NN/models/saved') + self.data_dir = config.get('data_dir', 'NN/data') + self.model = None + self.data_interface = None + + # Initialize with default values in case imports fail + self.model_initialized = False + self.data_initialized = False + + # Import necessary modules dynamically + try: + from .utils.data_interface import DataInterface + + # Initialize data interface + self.data_interface = DataInterface( + symbol=self.symbol, + timeframes=self.timeframes + ) + self.data_initialized = True + logger.info(f"Data interface initialized for {self.symbol}") + + try: + from .models.cnn_model_pytorch import CNNModelPyTorch as Model + + # Initialize model + feature_count = self.data_interface.get_feature_count() if hasattr(self.data_interface, 'get_feature_count') else 5 + try: + # First try with expected parameters + self.model = Model( + window_size=self.window_size, + num_features=feature_count, + output_size=self.output_size, + timeframes=self.timeframes + ) + except TypeError as e: + logger.warning(f"TypeError in model initialization with num_features: {str(e)}") + # Try alternate parameter naming + try: + self.model = Model( + input_shape=(self.window_size, feature_count), + output_size=self.output_size + ) + logger.info("Model initialized with alternate parameters") + except Exception as ex: + logger.error(f"Failed to initialize model with alternate parameters: {str(ex)}") + self.model = DummyModel() + + # Try to load the best model + self._load_model() + self.model_initialized = True + logger.info("Model initialized successfully") + except Exception as e: + logger.error(f"Error initializing model: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + self.model = DummyModel() + + logger.info(f"NeuralNetworkOrchestrator initialized with config: {config}") + except Exception as e: + logger.error(f"Error initializing NeuralNetworkOrchestrator: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + self.model = DummyModel() + + def _load_model(self): + """Load the best trained model from available files""" + try: + model_paths = [ + os.path.join(self.model_dir, "dqn_agent_best_policy.pt"), + os.path.join(self.model_dir, "cnn_model_best.pt"), + os.path.join("models/saved", "dqn_agent_best_policy.pt"), + os.path.join("models/saved", "cnn_model_best.pt") + ] + + for model_path in model_paths: + if os.path.exists(model_path): + try: + self.model.load(model_path) + logger.info(f"Loaded model from {model_path}") + return True + except Exception as e: + logger.warning(f"Failed to load model from {model_path}: {str(e)}") + continue + + logger.warning("No trained model found, using dummy model") + self.model = DummyModel() + return False + except Exception as e: + logger.error(f"Error loading model: {str(e)}") + self.model = DummyModel() + return False + + def run_inference_pipeline(self, model_type='cnn', timeframe='1h'): + """ + Run the inference pipeline using the trained model. + + Args: + model_type (str): Type of model to use (cnn, transformer, etc.) + timeframe (str): Timeframe to use for inference + + Returns: + dict: Inference result + """ + try: + # Check if we have a model + if not hasattr(self, 'model') or self.model is None: + logger.warning("No model available, initializing dummy model") + self.model = DummyModel() + + # Check if we have a data interface + if not hasattr(self, 'data_interface') or self.data_interface is None: + logger.warning("No data interface available") + # Return a dummy prediction + return self._get_dummy_prediction() + + # Prepare input data for the selected timeframe + X, timestamp = self.data_interface.prepare_realtime_input( + timeframe=timeframe, + n_candles=self.window_size + 10, # Extra candles for safety + window_size=self.window_size + ) + + if X is None: + logger.warning(f"No data available for {self.symbol}") + return self._get_dummy_prediction() + + # Get model predictions + action_probs, price_pred = self.model.predict(X) + + # Convert predictions to action + action_idx = np.argmax(action_probs) if hasattr(action_probs, 'argmax') else 1 # Default to HOLD + action_names = ['SELL', 'HOLD', 'BUY'] + action = action_names[action_idx] + + # Format timestamp + if not isinstance(timestamp, str): + try: + if hasattr(timestamp, 'isoformat'): # If it's already a datetime-like object + timestamp = timestamp.isoformat() + else: # If it's a numeric timestamp + timestamp = datetime.fromtimestamp(float(timestamp)/1000).isoformat() + except (TypeError, ValueError): + timestamp = datetime.now().isoformat() + + # Return result + result = { + 'timestamp': timestamp, + 'action': action, + 'action_index': int(action_idx), + 'probability': float(action_probs[action_idx]) if hasattr(action_probs, '__getitem__') else 0.33, + 'probabilities': {name: float(prob) for name, prob in zip(action_names, action_probs)} if hasattr(action_probs, '__iter__') else {'SELL': 0.33, 'HOLD': 0.34, 'BUY': 0.33}, + 'price_prediction': float(price_pred) if price_pred is not None else None + } + + logger.info(f"Inference result: {result}") + return result + + except Exception as e: + logger.error(f"Error in inference pipeline: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + return self._get_dummy_prediction() + + def _get_dummy_prediction(self): + """Return a dummy prediction when model or data is unavailable""" + action_names = ['SELL', 'HOLD', 'BUY'] + action_idx = 1 # Default to HOLD + timestamp = datetime.now().isoformat() + + return { + 'timestamp': timestamp, + 'action': 'HOLD', + 'action_index': action_idx, + 'probability': 0.8, + 'probabilities': {'SELL': 0.1, 'HOLD': 0.8, 'BUY': 0.1}, + 'price_prediction': None, + 'is_dummy': True + } + + +class DummyModel: + """Dummy model that returns random predictions""" + + def __init__(self): + logger.info("Initializing dummy model") + + def predict(self, X): + """Return random predictions""" + # Generate random probabilities for SELL, HOLD, BUY + action_probs = np.array([0.1, 0.8, 0.1]) # Bias towards HOLD + + # Generate a random price prediction (None for now) + price_pred = None + + return action_probs, price_pred + + def load(self, model_path): + """Dummy load method""" + logger.info(f"Dummy model pretending to load from {model_path}") + return True \ No newline at end of file diff --git a/NN/models/cnn_model_pytorch.py b/NN/models/cnn_model_pytorch.py index 13bcb21..1c709dc 100644 --- a/NN/models/cnn_model_pytorch.py +++ b/NN/models/cnn_model_pytorch.py @@ -72,6 +72,9 @@ class CNNPyTorch(nn.Module): """ super(CNNPyTorch, self).__init__() + # Set device + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + window_size, num_features = input_shape self.window_size = window_size @@ -225,74 +228,93 @@ class CNNModelPyTorch: predictions with the CNN model, optimized for short-term trading opportunities. """ - def __init__(self, window_size, num_features, output_size=3, timeframes=None): + def __init__(self, window_size=20, timeframes=None, output_size=3, num_pairs=3): """ Initialize the CNN model. Args: - window_size (int): Size of the input window - num_features (int): Number of features in the input data - output_size (int): Size of the output (default: 3 for BUY/HOLD/SELL) - timeframes (list): List of timeframes used (for logging) + window_size (int): Size of the sliding window + timeframes (list): List of timeframes used + output_size (int): Number of output classes (3 for BUY/HOLD/SELL) + num_pairs (int): Number of trading pairs to analyze in parallel (default 3) """ - # Action tracking - self.action_counts = { - 'BUY': 0, - 'SELL': 0, - 'HOLD': 0 - } self.window_size = window_size - self.num_features = num_features + self.timeframes = timeframes if timeframes else ["1m", "5m", "15m"] self.output_size = output_size - self.timeframes = timeframes or [] + self.num_pairs = num_pairs - # Determine device (GPU or CPU) - self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - logger.info(f"Using device: {self.device}") + # Calculate total features (5 OHLCV features per timeframe per pair) + self.total_features = len(self.timeframes) * 5 * self.num_pairs - # Initialize model - self.model = None - self.build_model() + # Build the model + logger.info(f"Building PyTorch CNN model with window_size={window_size}, " + f"num_features={self.total_features}, output_size={output_size}, " + f"num_pairs={num_pairs}") - # Initialize training history - self.history = { - 'loss': [], - 'val_loss': [], - 'accuracy': [], - 'val_accuracy': [] - } - - # Sensitivity parameters for high-leverage trading - self.confidence_threshold = 0.65 # Minimum confidence for trading actions - self.max_consecutive_same_action = 3 # Limit consecutive identical actions - self.last_actions = [] # Track recent actions - - def build_model(self): - """Build the CNN model architecture""" - logger.info(f"Building PyTorch CNN model with window_size={self.window_size}, " - f"num_features={self.num_features}, output_size={self.output_size}") - - # Ensure window size is not less than the actual input - input_window_size = max(self.window_size, 20) # Use at least 20 as minimum window size - - self.model = CNNPyTorch( - input_shape=(input_window_size, self.num_features), - output_size=self.output_size + # Calculate channel sizes that are divisible by num_pairs + base_channels = 96 # 96 is divisible by 3 + self.model = nn.Sequential( + # First convolutional layer - process each pair's features + nn.Sequential( + nn.Conv1d(self.total_features, base_channels, kernel_size=5, padding=2, groups=num_pairs), + nn.ReLU(), + nn.BatchNorm1d(base_channels), + nn.Dropout(0.2) + ), + + # Second convolutional layer - start mixing pair information + nn.Sequential( + nn.Conv1d(base_channels, base_channels*2, kernel_size=3, padding=1), + nn.ReLU(), + nn.BatchNorm1d(base_channels*2), + nn.Dropout(0.2) + ), + + # Third convolutional layer - deeper feature extraction + nn.Sequential( + nn.Conv1d(base_channels*2, base_channels*4, kernel_size=3, padding=1), + nn.ReLU(), + nn.BatchNorm1d(base_channels*4), + nn.Dropout(0.2) + ), + + # Global average pooling + nn.AdaptiveAvgPool1d(1), + + # Flatten + nn.Flatten(), + + # Dense layers for action prediction with cross-pair attention + nn.Sequential( + nn.Linear(base_channels*4, base_channels*2), + nn.ReLU(), + nn.Dropout(0.2), + nn.Linear(base_channels*2, base_channels), + nn.ReLU(), + nn.Dropout(0.2), + nn.Linear(base_channels, output_size * num_pairs) # Output for each pair + ) ).to(self.device) - # Initialize optimizer with higher learning rate for faster adaptation - self.optimizer = optim.Adam(self.model.parameters(), lr=0.002) - - # Learning rate scheduler with faster decay + # Initialize optimizer and loss function + self.optimizer = optim.Adam(self.model.parameters(), lr=0.0005) self.scheduler = optim.lr_scheduler.ReduceLROnPlateau( - self.optimizer, mode='max', factor=0.6, patience=6, verbose=True + self.optimizer, mode='max', factor=0.5, patience=5, verbose=True ) + self.criterion = nn.CrossEntropyLoss() - # Initialize loss function with higher weights for BUY/SELL - class_weights = torch.tensor([7.0, 1.0, 7.0]).to(self.device) # Even higher weights for BUY/SELL - self.criterion = nn.CrossEntropyLoss(weight=class_weights) + # Initialize metrics tracking + self.train_losses = [] + self.val_losses = [] + self.train_accuracies = [] + self.val_accuracies = [] logger.info(f"Model built successfully with {sum(p.numel() for p in self.model.parameters())} parameters") + + # Sensitivity parameters for high-leverage trading + self.confidence_threshold = 0.65 + self.max_consecutive_same_action = 3 + self.last_actions = [[] for _ in range(num_pairs)] # Track recent actions per pair def compute_trading_loss(self, action_probs, price_pred, targets, future_prices=None): """ @@ -644,12 +666,12 @@ class CNNModelPyTorch: action_probs_np[:, 2] *= 1.3 # Boost BUY probabilities # Implement signal filtering based on previous actions to avoid oscillation - if len(self.last_actions) >= self.max_consecutive_same_action: + if len(self.last_actions[0]) >= self.max_consecutive_same_action: # Check for too many consecutive identical actions - if all(a == 0 for a in self.last_actions[-self.max_consecutive_same_action:]): + if all(a == 0 for a in self.last_actions[0][-self.max_consecutive_same_action:]): # Too many consecutive SELL - reduce sell probability action_probs_np[:, 0] *= 0.7 - elif all(a == 2 for a in self.last_actions[-self.max_consecutive_same_action:]): + elif all(a == 2 for a in self.last_actions[0][-self.max_consecutive_same_action:]): # Too many consecutive BUY - reduce buy probability action_probs_np[:, 2] *= 0.7 @@ -666,9 +688,9 @@ class CNNModelPyTorch: # Store the predicted action for the most recent input if action_probs_np.shape[0] > 0: latest_action = np.argmax(action_probs_np[-1]) - self.last_actions.append(int(latest_action)) + self.last_actions[0].append(int(latest_action)) # Keep only the most recent actions - self.last_actions = self.last_actions[-10:] # Store last 10 actions + self.last_actions[0] = self.last_actions[0][-10:] # Store last 10 actions # Update action counts for stats actions = np.argmax(action_probs_np, axis=1) @@ -676,11 +698,11 @@ class CNNModelPyTorch: action_dict = dict(zip(unique, counts)) if 0 in action_dict: - self.action_counts['SELL'] += action_dict[0] + self.action_counts['SELL'][0] += action_dict[0] if 1 in action_dict: - self.action_counts['HOLD'] += action_dict[1] + self.action_counts['HOLD'][0] += action_dict[1] if 2 in action_dict: - self.action_counts['BUY'] += action_dict[2] + self.action_counts['BUY'][0] += action_dict[2] # Get the current close prices from the input current_prices = X_tensor[:, -1, 3].cpu().numpy() if X_tensor.shape[2] > 3 else np.zeros(X_tensor.shape[0]) @@ -838,20 +860,25 @@ class CNNModelPyTorch: f"val_loss: {val_loss:.4f} - val_acc: {val_acc:.4f}") # Update history - self.history['loss'].append(epoch_loss) - self.history['accuracy'].append(epoch_acc) - self.history['val_loss'].append(val_loss) - self.history['val_accuracy'].append(val_acc) + self.train_losses.append(epoch_loss) + self.train_accuracies.append(epoch_acc) + self.val_losses.append(val_loss) + self.val_accuracies.append(val_acc) else: logger.info(f"Epoch {epoch+1}/{epochs} - " f"loss: {epoch_loss:.4f} - acc: {epoch_acc:.4f}") # Update history without validation - self.history['loss'].append(epoch_loss) - self.history['accuracy'].append(epoch_acc) + self.train_losses.append(epoch_loss) + self.train_accuracies.append(epoch_acc) logger.info("Training completed") - return self.history + return { + 'loss': self.train_losses, + 'accuracy': self.train_accuracies, + 'val_loss': self.val_losses, + 'val_accuracy': self.val_accuracies + } def evaluate_metrics(self, X_test, y_test): """ @@ -892,9 +919,14 @@ class CNNModelPyTorch: model_state = { 'model_state_dict': self.model.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict(), - 'history': self.history, + 'history': { + 'loss': self.train_losses, + 'accuracy': self.train_accuracies, + 'val_loss': self.val_losses, + 'val_accuracy': self.val_accuracies + }, 'window_size': self.window_size, - 'num_features': self.num_features, + 'num_features': self.total_features, 'output_size': self.output_size, 'timeframes': self.timeframes, # Save trading configuration @@ -930,12 +962,12 @@ class CNNModelPyTorch: # Update model parameters self.window_size = model_state['window_size'] - self.num_features = model_state['num_features'] + self.total_features = model_state['num_features'] self.output_size = model_state['output_size'] self.timeframes = model_state.get('timeframes', ["1m"]) # Load model state dict - self.load_state_dict(model_state['model_state_dict']) + self.model.load_state_dict(model_state['model_state_dict']) # Load optimizer state if available if 'optimizer_state_dict' in model_state: diff --git a/NN/neural_network_orchestrator.py b/NN/neural_network_orchestrator.py new file mode 100644 index 0000000..f04dba9 --- /dev/null +++ b/NN/neural_network_orchestrator.py @@ -0,0 +1,287 @@ +import logging +import threading +import time +from typing import Dict, Any, List, Optional, Callable, Tuple +import os +import numpy as np +import pandas as pd + +from .trading_agent import TradingAgent + +logger = logging.getLogger(__name__) + +class NeuralNetworkOrchestrator: + """Orchestrator for neural network models and trading operations. + + This class coordinates between neural network models and trading agents, + ensuring that signals from the models are properly processed and trades + are executed according to the strategy. + """ + + def __init__(self, model, data_interface, chart=None, + symbols: List[str] = None, + timeframes: List[str] = None, + window_size: int = 20, + num_features: int = 5, + output_size: int = 3, + models_dir: str = "NN/models/saved", + data_dir: str = "NN/data", + exchange_config: Dict[str, Any] = None): + """Initialize the neural network orchestrator. + + Args: + model: Neural network model instance + data_interface: Data interface for retrieving market data + chart: Real-time chart for visualization (optional) + symbols: List of trading symbols (e.g., ['BTC/USDT', 'ETH/USDT']) + timeframes: List of timeframes to monitor (e.g., ['1m', '5m', '1h']) + window_size: Window size for model input + num_features: Number of features per datapoint + output_size: Number of output classes (e.g., 3 for BUY/HOLD/SELL) + models_dir: Directory for saved models + data_dir: Directory for data storage + exchange_config: Configuration for trading agent (exchange, API keys, etc.) + """ + self.model = model + self.data_interface = data_interface + self.chart = chart + + self.symbols = symbols or ["BTC/USDT"] + self.timeframes = timeframes or ["1m", "5m", "1h", "4h", "1d"] + self.window_size = window_size + self.num_features = num_features + self.output_size = output_size + self.models_dir = models_dir + self.data_dir = data_dir + + # Initialize trading agent if configuration provided + self.trading_agent = None + if exchange_config: + self.init_trading_agent(exchange_config) + + # Initialize inference state + self.is_running = False + self.inference_thread = None + self.stop_event = threading.Event() + self.last_inference_time = 0 + self.inference_interval = int(os.environ.get("NN_INFERENCE_INTERVAL", "60")) + + logger.info(f"Initializing NeuralNetworkOrchestrator with:") + logger.info(f"- Symbol: {self.symbols[0]}") + logger.info(f"- Timeframes: {', '.join(self.timeframes)}") + logger.info(f"- Window size: {window_size}") + logger.info(f"- Num features: {num_features}") + logger.info(f"- Output size: {output_size}") + logger.info(f"- Models dir: {models_dir}") + logger.info(f"- Data dir: {data_dir}") + logger.info(f"- Inference interval: {self.inference_interval} seconds") + + def init_trading_agent(self, config: Dict[str, Any]): + """Initialize the trading agent with the given configuration. + + Args: + config: Configuration for the trading agent + """ + exchange_name = config.get("exchange", "binance") + api_key = config.get("api_key") + api_secret = config.get("api_secret") + test_mode = config.get("test_mode", True) + trade_symbols = config.get("trade_symbols", self.symbols) + position_size = config.get("position_size", 0.1) + max_trades_per_day = config.get("max_trades_per_day", 5) + trade_cooldown_minutes = config.get("trade_cooldown_minutes", 60) + + self.trading_agent = TradingAgent( + exchange_name=exchange_name, + api_key=api_key, + api_secret=api_secret, + test_mode=test_mode, + trade_symbols=trade_symbols, + position_size=position_size, + max_trades_per_day=max_trades_per_day, + trade_cooldown_minutes=trade_cooldown_minutes + ) + + logger.info(f"Trading agent initialized for {exchange_name} exchange.") + + def start_inference(self): + """Start the inference thread.""" + if self.is_running: + logger.warning("Neural network inference is already running.") + return + + self.is_running = True + self.stop_event.clear() + + # Start inference thread + self.inference_thread = threading.Thread(target=self._inference_loop) + self.inference_thread.daemon = True + self.inference_thread.start() + + logger.info(f"Neural network inference thread started with {self.inference_interval}s interval.") + + # Start trading agent if available + if self.trading_agent: + self.trading_agent.start(signal_callback=self._on_trade_executed) + + def stop_inference(self): + """Stop the inference thread.""" + if not self.is_running: + logger.warning("Neural network inference is not running.") + return + + logger.info("Stopping neural network inference...") + self.is_running = False + self.stop_event.set() + + if self.inference_thread and self.inference_thread.is_alive(): + self.inference_thread.join(timeout=10) + + logger.info("Neural network inference stopped.") + + # Stop trading agent if available + if self.trading_agent: + self.trading_agent.stop() + + def _inference_loop(self): + """Main inference loop that processes data and generates signals.""" + logger.info("Inference loop started.") + + try: + while self.is_running and not self.stop_event.is_set(): + current_time = time.time() + + # Check if we should run inference + if current_time - self.last_inference_time >= self.inference_interval: + try: + # Run inference for all symbols + for symbol in self.symbols: + prediction = self._run_inference(symbol) + if prediction: + self._process_prediction(symbol, prediction) + + self.last_inference_time = current_time + except Exception as e: + logger.error(f"Error during inference: {str(e)}") + + # Sleep for a short time to prevent CPU hogging + time.sleep(1) + except Exception as e: + logger.error(f"Error in inference loop: {str(e)}") + finally: + logger.info("Inference loop stopped.") + + def _run_inference(self, symbol: str) -> Optional[Tuple[np.ndarray, float]]: + """Run inference for a specific symbol. + + Args: + symbol: Trading symbol (e.g., 'BTC/USDT') + + Returns: + tuple: (action probabilities, current price) or None if inference failed + """ + try: + # Get the model timeframe from environment + model_timeframe = os.environ.get("NN_TIMEFRAME", "1h") + if model_timeframe not in self.timeframes: + logger.warning(f"Model timeframe {model_timeframe} not in available timeframes. Using {self.timeframes[0]}.") + model_timeframe = self.timeframes[0] + + # Load candles for the model timeframe + logger.info(f"Loading {1000} candles from cache for {symbol} at {model_timeframe} timeframe") + candles = self.data_interface.get_historical_data( + symbol=symbol, + timeframe=model_timeframe, + n_candles=1000 + ) + + if candles is None or len(candles) < self.window_size: + logger.warning(f"Not enough data for {symbol} at {model_timeframe} timeframe. Need at least {self.window_size} candles.") + return None + + # Prepare input data + X, timestamp = self.data_interface.prepare_model_input( + data=candles, + window_size=self.window_size, + symbol=symbol + ) + + if X is None: + logger.warning(f"Failed to prepare model input for {symbol}.") + return None + + # Get current price + current_price = candles['close'].iloc[-1] + + # Run model inference + action_probs, price_pred = self.model.predict(X) + + return action_probs, current_price + + except Exception as e: + logger.error(f"Error running inference for {symbol}: {str(e)}") + return None + + def _process_prediction(self, symbol: str, prediction: Tuple[np.ndarray, float]): + """Process a prediction and generate signals. + + Args: + symbol: Trading symbol (e.g., 'BTC/USDT') + prediction: Tuple of (action probabilities, current price) + """ + action_probs, current_price = prediction + + # Get the best action (0=SELL, 1=HOLD, 2=BUY) + best_action = np.argmax(action_probs) + best_prob = float(action_probs[best_action]) + + # Convert to action name + action_names = ["SELL", "HOLD", "BUY"] + action_name = action_names[best_action] + + # Log the prediction + logger.info(f"Inference result for {symbol}: Action={action_name}, Probability={best_prob:.2f}, Price={current_price:.2f}") + + # Add signal to chart if available + if self.chart: + self.chart.add_nn_signal(symbol=symbol, signal=action_name, confidence=best_prob, timestamp=int(time.time())) + + # Process signal with trading agent if available + if self.trading_agent: + self.trading_agent.process_signal( + symbol=symbol, + action=action_name, + confidence=best_prob, + timestamp=int(time.time()) + ) + + def _on_trade_executed(self, trade_record: Dict[str, Any]): + """Callback for when a trade is executed. + + Args: + trade_record: Trade information + """ + if self.chart and trade_record: + # Add trade to chart + self.chart.add_trade( + action=trade_record['action'], + price=trade_record.get('price', 0), + timestamp=trade_record['timestamp'], + pnl=trade_record.get('pnl', 0) + ) + + logger.info(f"Trade added to chart: {trade_record['action']} at {trade_record.get('price', 0):.2f}") + + def get_trading_agent_info(self) -> Dict[str, Any]: + """Get information about the trading agent. + + Returns: + dict: Trading agent information or None if no agent is available + """ + if self.trading_agent: + return { + 'exchange_info': self.trading_agent.get_exchange_info(), + 'positions': self.trading_agent.get_current_positions(), + 'trades': len(self.trading_agent.get_trade_history()) + } + return None \ No newline at end of file diff --git a/NN/realtime_main.py b/NN/realtime_main.py index d19b5f4..f1fa0cd 100644 --- a/NN/realtime_main.py +++ b/NN/realtime_main.py @@ -13,6 +13,7 @@ import argparse from datetime import datetime from torch.utils.tensorboard import SummaryWriter import numpy as np +import time # Configure logging logger = logging.getLogger('NN') @@ -100,7 +101,7 @@ def main(): # Verify data interface by fetching initial data logger.info("Verifying data interface...") X_sample, y_sample, _, _, _, _ = data_interface.prepare_training_data(refresh=True) - if X_sample is None or y_sample is None: + if X_sample is None or y_sample is not None: logger.error("Failed to prepare initial training data") return @@ -369,12 +370,12 @@ def predict(data_interface, model, args): except Exception as e: logger.error(f"Error in prediction mode: {str(e)}") -def realtime(data_interface, model, args): - """Run the model in real-time mode""" - logger.info("Starting real-time mode...") +def realtime(data_interface, model, args, chart=None, symbol=None): + """Run real-time inference with the trained model""" + logger.info(f"Starting real-time inference mode for {symbol}...") try: - from NN.utils.realtime_analyzer import RealtimeAnalyzer + from NN.utils.realtime_analyzer import RealtimeAnalyzer # Load the latest model model_dir = os.path.join('models') @@ -403,8 +404,104 @@ def realtime(data_interface, model, args): logger.info("Starting real-time analysis...") realtime_analyzer.start() + + + # Initialize variables for tracking performance + total_pnl = 0.0 + trades = [] + current_position = 0.0 + last_action = None + last_price = None + + # Get the pair index for this symbol + pair_index = args.symbols.index(symbol) + + # Only execute trades if this is the main pair (BTC/USDT) + is_main_pair = symbol == "BTC/USDT" + + while True: + # Get current market data for all pairs + all_pairs_data = [] + for s in args.symbols: + X, timestamp = data_interface.prepare_realtime_input( + timeframe=args.timeframes[0], # Use shortest timeframe + n_candles=args.window_size + 10, # Extra candles for safety + window_size=args.window_size + ) + if X is not None: + all_pairs_data.append(X) + else: + logger.warning(f"No data available for {s}") + time.sleep(1) + continue + + if not all_pairs_data: + logger.warning("No data available for any pair") + time.sleep(1) + continue + + # Stack data from all pairs for model input + X_combined = np.concatenate(all_pairs_data, axis=2) + + # Get model predictions + action_probs, price_pred = model.predict(X_combined) + + # Get predictions for this specific pair + action = np.argmax(action_probs[pair_index]) # 0=SELL, 1=HOLD, 2=BUY + + # Get current price for the main pair + current_price = data_interface.get_historical_data( + timeframe=args.timeframes[0], + n_candles=1 + )['close'].iloc[-1] + + # Calculate PnL if we have a position (only for main pair) + pnl = 0.0 + if is_main_pair and last_action is not None and last_price is not None: + if last_action == 2: # BUY + pnl = (current_price - last_price) / last_price + elif last_action == 0: # SELL + pnl = (last_price - current_price) / last_price + + # Update total PnL (only for main pair) + if is_main_pair and pnl != 0: + total_pnl += pnl + + # Log the prediction + action_name = "SELL" if action == 0 else "HOLD" if action == 1 else "BUY" + log_msg = f"Time: {timestamp}, Symbol: {symbol}, Action: {action_name}, " + if is_main_pair: + log_msg += f"Price: {current_price:.2f}, PnL: {pnl:.2%}, Total PnL: {total_pnl:.2%}" + else: + log_msg += f"Price: {current_price:.2f} (Context Only)" + logger.info(log_msg) + + # Update the chart if provided (only for main pair) + if chart is not None and is_main_pair and action != 1: # Skip HOLD actions + chart.add_trade( + action=action_name, + price=current_price, + timestamp=timestamp, + pnl=pnl + ) + + # Update tracking variables (only for main pair) + if is_main_pair and action != 1: # If not HOLD + last_action = action + last_price = current_price + + # Sleep for a short time + time.sleep(1) + + except KeyboardInterrupt: + if is_main_pair: + logger.info(f"Real-time inference stopped by user for {symbol}") + logger.info(f"Final performance for {symbol} - Total PnL: {total_pnl:.2%}") + else: + logger.info(f"Real-time inference stopped by user for {symbol} (Context Only)") except Exception as e: - logger.error(f"Error in real-time mode: {str(e)}") + logger.error(f"Error in real-time inference for {symbol}: {str(e)}") + raise if __name__ == "__main__": main() diff --git a/NN/trading_agent.py b/NN/trading_agent.py new file mode 100644 index 0000000..a89f994 --- /dev/null +++ b/NN/trading_agent.py @@ -0,0 +1,310 @@ +import logging +import time +import threading +from typing import Dict, Any, List, Optional, Callable, Tuple, Union + +from .exchanges import ExchangeInterface, MEXCInterface, BinanceInterface + +logger = logging.getLogger(__name__) + +class TradingAgent: + """Trading agent that executes trades based on neural network signals. + + This agent interfaces with different exchanges and executes trades + based on the signals received from the neural network. + """ + + def __init__(self, + exchange_name: str = 'binance', + api_key: str = None, + api_secret: str = None, + test_mode: bool = True, + trade_symbols: List[str] = None, + position_size: float = 0.1, + max_trades_per_day: int = 5, + trade_cooldown_minutes: int = 60): + """Initialize the trading agent. + + Args: + exchange_name: Name of the exchange to use ('binance', 'mexc') + api_key: API key for the exchange + api_secret: API secret for the exchange + test_mode: If True, use test/sandbox environment + trade_symbols: List of trading symbols to monitor (e.g., ['BTC/USDT']) + position_size: Size of each position as a fraction of total available balance (0.0-1.0) + max_trades_per_day: Maximum number of trades to execute per day + trade_cooldown_minutes: Minimum time between trades in minutes + """ + self.exchange_name = exchange_name.lower() + self.api_key = api_key + self.api_secret = api_secret + self.test_mode = test_mode + self.trade_symbols = trade_symbols or ['BTC/USDT'] + self.position_size = min(max(position_size, 0.01), 1.0) # Ensure between 0.01 and 1.0 + self.max_trades_per_day = max(1, max_trades_per_day) + self.trade_cooldown_seconds = max(60, trade_cooldown_minutes * 60) + + # Initialize exchange interface + self.exchange = self._create_exchange() + + # Trading state + self.active = False + self.current_positions = {} # Symbol -> quantity + self.trades_today = {} # Symbol -> count + self.last_trade_time = {} # Symbol -> timestamp + self.trade_history = [] # List of trade records + + # Threading + self.trading_thread = None + self.stop_event = threading.Event() + + # Signal callback + self.signal_callback = None + + # Connect to exchange + if not self.exchange.connect(): + logger.error(f"Failed to connect to {self.exchange_name} exchange. Trading agent disabled.") + else: + logger.info(f"Successfully connected to {self.exchange_name} exchange.") + self._load_current_positions() + + def _create_exchange(self) -> ExchangeInterface: + """Create an exchange interface based on the exchange name.""" + if self.exchange_name == 'mexc': + return MEXCInterface( + api_key=self.api_key, + api_secret=self.api_secret, + test_mode=self.test_mode + ) + elif self.exchange_name == 'binance': + return BinanceInterface( + api_key=self.api_key, + api_secret=self.api_secret, + test_mode=self.test_mode + ) + else: + raise ValueError(f"Unsupported exchange: {self.exchange_name}") + + def _load_current_positions(self): + """Load current positions from the exchange.""" + for symbol in self.trade_symbols: + try: + base_asset, quote_asset = symbol.split('/') + balance = self.exchange.get_balance(base_asset) + + if balance > 0: + self.current_positions[symbol] = balance + logger.info(f"Loaded existing position for {symbol}: {balance} {base_asset}") + except Exception as e: + logger.error(f"Error loading position for {symbol}: {str(e)}") + + def start(self, signal_callback: Callable = None): + """Start the trading agent. + + Args: + signal_callback: Optional callback function to receive trade signals + """ + if self.active: + logger.warning("Trading agent is already running.") + return + + self.active = True + self.signal_callback = signal_callback + self.stop_event.clear() + + logger.info(f"Starting trading agent for {self.exchange_name} exchange.") + logger.info(f"Trading symbols: {', '.join(self.trade_symbols)}") + logger.info(f"Position size: {self.position_size * 100:.1f}% of available balance") + logger.info(f"Max trades per day: {self.max_trades_per_day}") + logger.info(f"Trade cooldown: {self.trade_cooldown_seconds // 60} minutes") + + # Reset trading state + self.trades_today = {symbol: 0 for symbol in self.trade_symbols} + self.last_trade_time = {symbol: 0 for symbol in self.trade_symbols} + + # Start trading thread + self.trading_thread = threading.Thread(target=self._trading_loop) + self.trading_thread.daemon = True + self.trading_thread.start() + + def stop(self): + """Stop the trading agent.""" + if not self.active: + logger.warning("Trading agent is not running.") + return + + logger.info("Stopping trading agent...") + self.active = False + self.stop_event.set() + + if self.trading_thread and self.trading_thread.is_alive(): + self.trading_thread.join(timeout=10) + + logger.info("Trading agent stopped.") + + def _trading_loop(self): + """Main trading loop that monitors positions and executes trades.""" + logger.info("Trading loop started.") + + try: + while self.active and not self.stop_event.is_set(): + # Check positions and update state + for symbol in self.trade_symbols: + try: + base_asset, _ = symbol.split('/') + current_balance = self.exchange.get_balance(base_asset) + + # Update position if it has changed + if symbol in self.current_positions: + prev_balance = self.current_positions[symbol] + if abs(current_balance - prev_balance) > 0.001 * prev_balance: + logger.info(f"Position updated for {symbol}: {prev_balance} -> {current_balance} {base_asset}") + + self.current_positions[symbol] = current_balance + except Exception as e: + logger.error(f"Error checking position for {symbol}: {str(e)}") + + # Sleep for a while + time.sleep(10) + except Exception as e: + logger.error(f"Error in trading loop: {str(e)}") + finally: + logger.info("Trading loop stopped.") + + def reset_daily_limits(self): + """Reset daily trading limits. Call this at the start of each trading day.""" + self.trades_today = {symbol: 0 for symbol in self.trade_symbols} + logger.info("Daily trading limits reset.") + + def process_signal(self, symbol: str, action: str, + confidence: float = None, timestamp: int = None) -> Optional[Dict[str, Any]]: + """Process a trading signal and execute a trade if conditions are met. + + Args: + symbol: Trading symbol (e.g., 'BTC/USDT') + action: Trade action ('BUY', 'SELL', 'HOLD') + confidence: Confidence level of the signal (0.0-1.0) + timestamp: Timestamp of the signal (unix time) + + Returns: + dict: Trade information if a trade was executed, None otherwise + """ + if not self.active: + logger.warning("Trading agent is not active. Signal ignored.") + return None + + if symbol not in self.trade_symbols: + logger.warning(f"Symbol {symbol} is not in the trading symbols list. Signal ignored.") + return None + + if action not in ['BUY', 'SELL', 'HOLD']: + logger.warning(f"Invalid action: {action}. Must be 'BUY', 'SELL', or 'HOLD'.") + return None + + # Log the signal + confidence_str = f" (confidence: {confidence:.2f})" if confidence is not None else "" + logger.info(f"Received {action} signal for {symbol}{confidence_str}") + + # Ignore HOLD signals for trading + if action == 'HOLD': + return None + + # Check if we can trade based on limits + current_time = time.time() + + # Check max trades per day + if self.trades_today.get(symbol, 0) >= self.max_trades_per_day: + logger.warning(f"Max trades per day reached for {symbol}. Signal ignored.") + return None + + # Check trade cooldown + last_trade_time = self.last_trade_time.get(symbol, 0) + if current_time - last_trade_time < self.trade_cooldown_seconds: + cooldown_remaining = self.trade_cooldown_seconds - (current_time - last_trade_time) + logger.warning(f"Trade cooldown active for {symbol}. {cooldown_remaining:.1f} seconds remaining. Signal ignored.") + return None + + # Check if the action makes sense based on current position + base_asset, _ = symbol.split('/') + current_position = self.current_positions.get(symbol, 0) + + if action == 'BUY' and current_position > 0: + logger.warning(f"Already have a position in {symbol}. BUY signal ignored.") + return None + + if action == 'SELL' and current_position <= 0: + logger.warning(f"No position in {symbol} to sell. SELL signal ignored.") + return None + + # Execute the trade + try: + trade_result = self.exchange.execute_trade( + symbol=symbol, + action=action, + percent_of_balance=self.position_size + ) + + if trade_result: + # Update trading state + self.trades_today[symbol] = self.trades_today.get(symbol, 0) + 1 + self.last_trade_time[symbol] = current_time + + # Create trade record + trade_record = { + 'symbol': symbol, + 'action': action, + 'timestamp': timestamp or int(current_time), + 'confidence': confidence, + 'order_id': trade_result.get('orderId') if isinstance(trade_result, dict) else None, + 'status': 'executed' + } + + # Add to trade history + self.trade_history.append(trade_record) + + # Call signal callback if provided + if self.signal_callback: + self.signal_callback(trade_record) + + logger.info(f"Successfully executed {action} trade for {symbol}") + return trade_record + else: + logger.error(f"Failed to execute {action} trade for {symbol}") + return None + + except Exception as e: + logger.error(f"Error executing trade for {symbol}: {str(e)}") + return None + + def get_current_positions(self) -> Dict[str, float]: + """Get current positions. + + Returns: + dict: Symbol -> position size + """ + return self.current_positions.copy() + + def get_trade_history(self) -> List[Dict[str, Any]]: + """Get trade history. + + Returns: + list: List of trade records + """ + return self.trade_history.copy() + + def get_exchange_info(self) -> Dict[str, Any]: + """Get exchange information. + + Returns: + dict: Exchange information + """ + return { + 'name': self.exchange_name, + 'test_mode': self.test_mode, + 'active': self.active, + 'trade_symbols': self.trade_symbols, + 'position_size': self.position_size, + 'max_trades_per_day': self.max_trades_per_day, + 'trade_cooldown_seconds': self.trade_cooldown_seconds, + 'trades_today': self.trades_today.copy() + } \ No newline at end of file diff --git a/NN/train_rl.py b/NN/train_rl.py index c052e3e..1c1cf2b 100644 --- a/NN/train_rl.py +++ b/NN/train_rl.py @@ -247,9 +247,27 @@ def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/mo ) # Get training data for each timeframe with more data + logger.info("Loading training data...") features_1m = data_interface.get_training_data("1m", n_candles=5000) + if features_1m is not None: + logger.info(f"Loaded {len(features_1m)} 1m candles") + else: + logger.error("Failed to load 1m data") + return None + features_5m = data_interface.get_training_data("5m", n_candles=2500) + if features_5m is not None: + logger.info(f"Loaded {len(features_5m)} 5m candles") + else: + logger.error("Failed to load 5m data") + return None + features_15m = data_interface.get_training_data("15m", n_candles=2500) + if features_15m is not None: + logger.info(f"Loaded {len(features_15m)} 15m candles") + else: + logger.error("Failed to load 15m data") + return None if features_1m is None or features_5m is None or features_15m is None: logger.error("Failed to load training data") diff --git a/_notes.md b/_notes.md index 9e21d8d..6ea50d7 100644 --- a/_notes.md +++ b/_notes.md @@ -1,5 +1,7 @@ https://github.com/mexcdevelop/mexc-api-sdk/blob/main/README.md#test-new-order +https://mexcdevelop.github.io/apidocs/spot_v3_en/#test-new-order + python mexc_tick_visualizer.py --symbol BTC/USDT --interval 1.0 --candle 60 python main.py --mode live --symbol ETH/USDT --timeframe 1m --use-websocket diff --git a/launch_training.py b/launch_training.py new file mode 100644 index 0000000..961ec1d --- /dev/null +++ b/launch_training.py @@ -0,0 +1,124 @@ +import os +import sys +import subprocess +import time +import logging +from datetime import datetime +import webbrowser +from threading import Thread + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler('training_launch.log'), + logging.StreamHandler() + ] +) +logger = logging.getLogger(__name__) + +def start_tensorboard(port=6007): + """Start TensorBoard on a specified port""" + try: + cmd = f"tensorboard --logdir=runs --port={port}" + process = subprocess.Popen(cmd, shell=True) + logger.info(f"Started TensorBoard on port {port}") + return process + except Exception as e: + logger.error(f"Failed to start TensorBoard: {str(e)}") + return None + +def start_web_chart(): + """Start the web chart server""" + try: + cmd = "python main.py --symbols BTC/USDT ETH/USDT SOL/USDT --timeframes 1m 5m 15m --mode realtime" + process = subprocess.Popen(cmd, shell=True) + logger.info("Started web chart server") + return process + except Exception as e: + logger.error(f"Failed to start web chart server: {str(e)}") + return None + +def start_training(): + """Start the RL training process""" + try: + cmd = "python NN/train_rl.py" + process = subprocess.Popen(cmd, shell=True) + logger.info("Started RL training process") + return process + except Exception as e: + logger.error(f"Failed to start training process: {str(e)}") + return None + +def open_web_interfaces(): + """Open web browsers for TensorBoard and chart after a delay""" + time.sleep(5) # Wait for servers to start + try: + webbrowser.open('http://localhost:6007') # TensorBoard + webbrowser.open('http://localhost:8050') # Web chart + except Exception as e: + logger.error(f"Failed to open web interfaces: {str(e)}") + +def monitor_processes(processes): + """Monitor running processes and log any unexpected terminations""" + while True: + for name, process in processes.items(): + if process and process.poll() is not None: + logger.error(f"{name} process terminated unexpectedly") + return False + time.sleep(1) + +def main(): + """Main function to orchestrate the training environment""" + logger.info("Starting training environment setup...") + + # Start TensorBoard + tensorboard_process = start_tensorboard(port=6007) + if not tensorboard_process: + logger.error("Failed to start TensorBoard") + return + + # Start web chart + web_chart_process = start_web_chart() + if not web_chart_process: + tensorboard_process.terminate() + logger.error("Failed to start web chart") + return + + # Start training + training_process = start_training() + if not training_process: + tensorboard_process.terminate() + web_chart_process.terminate() + logger.error("Failed to start training") + return + + # Open web interfaces in a separate thread + Thread(target=open_web_interfaces).start() + + # Monitor processes + processes = { + 'tensorboard': tensorboard_process, + 'web_chart': web_chart_process, + 'training': training_process + } + + try: + if not monitor_processes(processes): + raise Exception("One or more processes terminated unexpectedly") + except KeyboardInterrupt: + logger.info("Received shutdown signal") + except Exception as e: + logger.error(f"Error in monitoring: {str(e)}") + finally: + # Cleanup + logger.info("Shutting down training environment...") + for name, process in processes.items(): + if process: + process.terminate() + logger.info(f"Terminated {name} process") + logger.info("Training environment shutdown complete") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/main.py b/main.py index d173aa1..143cd48 100644 --- a/main.py +++ b/main.py @@ -302,6 +302,1113 @@ class PricePredictionModel(nn.Module): return total_loss / epochs +import os +import time +import logging +import sys +import argparse +import json + +# Add the NN directory to the Python path +sys.path.append(os.path.abspath("NN")) + +from NN.main import load_model +from NN.neural_network_orchestrator import NeuralNetworkOrchestrator +from NN.realtime_data_interface import RealtimeDataInterface + +# Initialize logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler("trading_bot.log"), + logging.StreamHandler() + ] +) +logger = logging.getLogger(__name__) + +def main(): + """Main function for the trading bot.""" + # Parse command-line arguments + parser = argparse.ArgumentParser(description="Trading Bot with Neural Network Integration") + parser.add_argument('--symbols', nargs='+', default=["BTC/USDT", "ETH/USDT"], + help='Trading symbols to monitor') + parser.add_argument('--timeframes', nargs='+', default=["1m", "5m", "1h", "4h", "1d"], + help='Timeframes to monitor') + parser.add_argument('--window-size', type=int, default=20, + help='Window size for model input') + parser.add_argument('--output-size', type=int, default=3, + help='Output size of the model (3 for BUY/HOLD/SELL)') + parser.add_argument('--model-type', type=str, default="cnn", choices=["cnn", "lstm", "mlp"], + help='Type of neural network model') + parser.add_argument('--mode', type=str, default="realtime", choices=["realtime", "backtest"], + help='Trading mode') + parser.add_argument('--exchange', type=str, default="binance", choices=["binance", "mexc"], + help='Exchange to use for trading') + parser.add_argument('--api-key', type=str, default=None, + help='API key for the exchange') + parser.add_argument('--api-secret', type=str, default=None, + help='API secret for the exchange') + parser.add_argument('--test-mode', action='store_true', + help='Use test/sandbox exchange environment') + parser.add_argument('--position-size', type=float, default=0.1, + help='Position size as a fraction of total balance (0.0-1.0)') + parser.add_argument('--max-trades-per-day', type=int, default=5, + help='Maximum number of trades per day') + parser.add_argument('--trade-cooldown', type=int, default=60, + help='Trade cooldown period in minutes') + parser.add_argument('--config-file', type=str, default=None, + help='Path to configuration file') + + args = parser.parse_args() + + # Load configuration from file if provided + if args.config_file and os.path.exists(args.config_file): + with open(args.config_file, 'r') as f: + config = json.load(f) + # Override config with command-line args + for key, value in vars(args).items(): + if key != 'config_file' and value is not None: + config[key] = value + else: + # Use command-line args as config + config = vars(args) + + # Initialize real-time charts and data interfaces + try: + from realtime import RealTimeChart + + # Create a real-time chart for each symbol + charts = {} + for symbol in config['symbols']: + charts[symbol] = RealTimeChart(symbol=symbol) + + main_chart = charts[config['symbols'][0]] + + # Create a data interface for retrieving market data + data_interface = RealtimeDataInterface(symbols=config['symbols'], chart=main_chart) + + # Load trained model + model_type = os.environ.get("NN_MODEL_TYPE", config['model_type']) + model = load_model( + model_type=model_type, + input_shape=(config['window_size'], len(config['symbols']), 5), # 5 features (OHLCV) + output_size=config['output_size'] + ) + + # Configure trading agent + exchange_config = { + "exchange": config['exchange'], + "api_key": config['api_key'], + "api_secret": config['api_secret'], + "test_mode": config['test_mode'], + "trade_symbols": config['symbols'], + "position_size": config['position_size'], + "max_trades_per_day": config['max_trades_per_day'], + "trade_cooldown_minutes": config['trade_cooldown'] + } + + # Initialize neural network orchestrator + orchestrator = NeuralNetworkOrchestrator( + model=model, + data_interface=data_interface, + chart=main_chart, + symbols=config['symbols'], + timeframes=config['timeframes'], + window_size=config['window_size'], + num_features=5, # OHLCV + output_size=config['output_size'], + exchange_config=exchange_config + ) + + # Start data collection + logger.info("Starting data collection threads...") + for symbol in config['symbols']: + charts[symbol].start() + + # Start neural network inference + if os.environ.get("ENABLE_NN_MODELS", "0") == "1": + logger.info("Starting neural network inference...") + orchestrator.start_inference() + else: + logger.info("Neural network models disabled. Set ENABLE_NN_MODELS=1 to enable.") + + # Start web servers for chart display + logger.info("Starting web servers for chart display...") + main_chart.start_server() + + logger.info("Trading bot initialized successfully. Press Ctrl+C to exit.") + + # Keep the main thread alive + try: + while True: + time.sleep(1) + except KeyboardInterrupt: + logger.info("Keyboard interrupt received. Shutting down...") + # Stop all threads + for symbol in config['symbols']: + charts[symbol].stop() + orchestrator.stop_inference() + logger.info("Trading bot stopped.") + + except Exception as e: + logger.error(f"Error in main function: {str(e)}", exc_info=True) + sys.exit(1) + +if __name__ == "__main__": + main() + + def get_state(self): + """Create state representation for the agent with enhanced features""" + # Ensure we have enough data + if len(self.data) < 30 or self.current_step >= len(self.data) or len(self.features['price']) == 0: + # Return zeros if not enough data + return np.zeros(STATE_SIZE) + + # Create a normalized state vector with recent price action and indicators + state_components = [] + + # Safely get the latest price + try: + latest_price = self.features['price'][-1] + except IndexError: + # If we can't get the latest price, return zeros + return np.zeros(STATE_SIZE) + + # Safely get price features + try: + # Price features (normalize recent prices by the latest price) + price_features = np.array(self.features['price'][-10:]) / latest_price - 1.0 + state_components.append(price_features) + except (IndexError, ZeroDivisionError): + # If we can't get price features, use zeros + state_components.append(np.zeros(10)) + + # Safely get volume features + try: + # Volume features (normalize by max volume) + max_vol = max(self.features['volume'][-20:]) if len(self.features['volume']) >= 20 else 1 + vol_features = np.array(self.features['volume'][-5:]) / max_vol + state_components.append(vol_features) + except (IndexError, ZeroDivisionError): + # If we can't get volume features, use zeros + state_components.append(np.zeros(5)) + + # Technical indicators + rsi = np.array(self.features['rsi'][-3:]) / 100.0 # Scale to 0-1 + state_components.append(rsi) + + # MACD (normalize) + macd_vals = np.array(self.features['macd'][-3:]) + macd_signal = np.array(self.features['macd_signal'][-3:]) + macd_hist = np.array(self.features['macd_hist'][-3:]) + macd_scale = max(abs(np.max(macd_vals)), abs(np.min(macd_vals)), 1e-5) + macd_norm = macd_vals / macd_scale + macd_signal_norm = macd_signal / macd_scale + macd_hist_norm = macd_hist / macd_scale + + state_components.extend([macd_norm, macd_signal_norm, macd_hist_norm]) + + # Bollinger position (where is price relative to bands) + bb_upper = np.array(self.features['bollinger_upper'][-3:]) + bb_lower = np.array(self.features['bollinger_lower'][-3:]) + bb_mid = np.array(self.features['bollinger_mid'][-3:]) + price = np.array(self.features['price'][-3:]) + + # Calculate position of price within Bollinger Bands (0 to 1) + bb_pos = [(p - l) / (u - l) if u != l else 0.5 for p, u, l in zip(price, bb_upper, bb_lower)] + state_components.append(np.array(bb_pos)) + + # Stochastic oscillator + state_components.append(np.array(self.features['stoch_k'][-3:]) / 100.0) + state_components.append(np.array(self.features['stoch_d'][-3:]) / 100.0) + + # Add predicted prices (if available) + if hasattr(self, 'predicted_prices') and len(self.predicted_prices) > 0: + # Normalize predictions relative to current price + pred_norm = np.array(self.predicted_prices[:3]) / latest_price - 1.0 + state_components.append(pred_norm) + else: + # Add zeros if no predictions + state_components.append(np.zeros(3)) + + # Add extrema signals (if available) + if hasattr(self, 'optimal_signals') and len(self.optimal_signals) > 0: + # Get recent signals + idx = len(self.optimal_signals) - 5 + if idx < 0: + idx = 0 + recent_signals = self.optimal_signals[idx:idx+5] + # Pad if needed + if len(recent_signals) < 5: + recent_signals = np.pad(recent_signals, (0, 5 - len(recent_signals)), 'constant') + state_components.append(recent_signals) + else: + # Add zeros if no signals + state_components.append(np.zeros(5)) + + # Position info + position_info = np.zeros(5) + if self.position == 'long': + position_info[0] = 1.0 # Position is long + position_info[1] = (latest_price - self.entry_price) / self.entry_price # Unrealized PnL % + position_info[2] = (self.stop_loss - self.entry_price) / self.entry_price # Stop loss % + position_info[3] = (self.take_profit - self.entry_price) / self.entry_price # Take profit % + position_info[4] = self.position_size / self.balance # Position size relative to balance + elif self.position == 'short': + position_info[0] = -1.0 # Position is short + position_info[1] = (self.entry_price - latest_price) / self.entry_price # Unrealized PnL % + position_info[2] = (self.entry_price - self.stop_loss) / self.entry_price # Stop loss % + position_info[3] = (self.entry_price - self.take_profit) / self.entry_price # Take profit % + position_info[4] = self.position_size / self.balance # Position size relative to balance + + state_components.append(position_info) + + # NEW FEATURES START HERE + + # 1. Price momentum features (rate of change over different periods) + if len(self.features['price']) >= 20: + roc_5 = (latest_price / self.features['price'][-5] - 1.0) if self.features['price'][-5] != 0 else 0 + roc_10 = (latest_price / self.features['price'][-10] - 1.0) if self.features['price'][-10] != 0 else 0 + roc_20 = (latest_price / self.features['price'][-20] - 1.0) if self.features['price'][-20] != 0 else 0 + momentum_features = np.array([roc_5, roc_10, roc_20]) + state_components.append(momentum_features) + else: + state_components.append(np.zeros(3)) + + # 2. Volatility features + if len(self.features['price']) >= 20: + # Calculate price returns + returns = np.diff(self.features['price'][-21:]) / self.features['price'][-21:-1] + # Calculate volatility (standard deviation of returns) + volatility = np.std(returns) + # Calculate normalized high-low range + high_low_range = np.mean([ + (self.data[i]['high'] - self.data[i]['low']) / self.data[i]['close'] + for i in range(max(0, len(self.data)-5), len(self.data)) + ]) if len(self.data) > 0 else 0 + # ATR normalized by price + atr_norm = self.features['atr'][-1] / latest_price if len(self.features['atr']) > 0 else 0 + + volatility_features = np.array([volatility, high_low_range, atr_norm]) + state_components.append(volatility_features) + else: + state_components.append(np.zeros(3)) + + # 3. Market regime features + if len(self.features['price']) >= 50: + # Trend strength (ADX-like measure) + ema9 = self.features['ema_9'][-1] if len(self.features['ema_9']) > 0 else latest_price + ema21 = self.features['ema_21'][-1] if len(self.features['ema_21']) > 0 else latest_price + trend_strength = abs(ema9 - ema21) / ema21 + + # Detect if in range or trending + is_range_bound = 1.0 if self.is_uncertain_market() else 0.0 + is_trending = 1.0 if (self.is_uptrend() or self.is_downtrend()) else 0.0 + + # Detect if near support/resistance + near_support = 1.0 if self.is_near_support() else 0.0 + near_resistance = 1.0 if self.is_near_resistance() else 0.0 + + market_regime = np.array([trend_strength, is_range_bound, is_trending, near_support, near_resistance]) + state_components.append(market_regime) + else: + state_components.append(np.zeros(5)) + + # 4. Trade history features + if len(self.trades) > 0: + # Recent win/loss ratio + recent_trades = self.trades[-min(10, len(self.trades)):] + win_ratio = sum(1 for t in recent_trades if t.get('pnl_dollar', 0) > 0) / len(recent_trades) + + # Average profit/loss + avg_profit = np.mean([t.get('pnl_dollar', 0) for t in recent_trades if t.get('pnl_dollar', 0) > 0]) if any(t.get('pnl_dollar', 0) > 0 for t in recent_trades) else 0 + avg_loss = np.mean([t.get('pnl_dollar', 0) for t in recent_trades if t.get('pnl_dollar', 0) <= 0]) if any(t.get('pnl_dollar', 0) <= 0 for t in recent_trades) else 0 + + # Normalize by balance + avg_profit_norm = avg_profit / self.balance if self.balance > 0 else 0 + avg_loss_norm = avg_loss / self.balance if self.balance > 0 else 0 + + # Last trade result + last_trade_pnl = self.trades[-1].get('pnl_dollar', 0) / self.balance if self.balance > 0 else 0 + + trade_history = np.array([win_ratio, avg_profit_norm, avg_loss_norm, last_trade_pnl]) + state_components.append(trade_history) + else: + state_components.append(np.zeros(4)) + + # Combine all features + state = np.concatenate([comp.flatten() for comp in state_components]) + + # Replace any NaN or infinite values + state = np.nan_to_num(state, nan=0.0, posinf=0.0, neginf=0.0) + + # Ensure the state has the correct size + if len(state) != STATE_SIZE: + logger.warning(f"State size mismatch: expected {STATE_SIZE}, got {len(state)}") + # Pad or truncate to match expected size + if len(state) < STATE_SIZE: + state = np.pad(state, (0, STATE_SIZE - len(state))) + else: + state = state[:STATE_SIZE] + + return state + + def get_expanded_state_size(self): + """Calculate the size of the expanded state representation""" + # Create a dummy state to get its size + state = self.get_state() + return len(state) + + async def expand_model_with_new_features(agent, env): + """Expand the model to handle new features without retraining from scratch""" + # Get the new state size + new_state_size = env.get_expanded_state_size() + + # Only expand if the new state size is larger + if new_state_size > agent.state_size: + logger.info(f"Expanding model to handle {new_state_size} features (was {agent.state_size})") + + # Expand the model + success = agent.expand_model( + new_state_size=new_state_size, + new_hidden_size=512, # Increase hidden size for more capacity + new_lstm_layers=3, # More layers for deeper patterns + new_attention_heads=8 # More attention heads for complex relationships + ) + + if success: + logger.info(f"Model successfully expanded to handle {new_state_size} features") + return True + else: + logger.error("Failed to expand model") + return False + else: + logger.info(f"No need to expand model, current size ({agent.state_size}) is sufficient") + return True + + + def calculate_reward(self, action): + """Calculate reward for the given action with aggressive rewards for profitable trades and volume/price action signals""" + reward = 0 + + # Base reward for actions + if action == 0: # HOLD + reward = -0.05 # Increased penalty for doing nothing to encourage more trading + + elif action == 1: # BUY/LONG + if self.position == 'flat': + # Opening a long position + self.position = 'long' + self.entry_price = self.current_price + self.position_size = self.calculate_position_size() + # Use the adjusted risk parameters + self.stop_loss = self.entry_price * (1 - self.stop_loss_pct/100) + self.take_profit = self.entry_price * (1 + self.take_profit_pct/100) + + # Check if this is an optimal buy point (bottom) + current_idx = len(self.features['price']) - 1 + if hasattr(self, 'optimal_bottoms') and current_idx in self.optimal_bottoms: + reward += 3.0 # Increased bonus for buying at a bottom + + # Check for volume spike (indicating potential big movement) + if len(self.features['volume']) > 5: + avg_volume = np.mean(self.features['volume'][-5:-1]) + current_volume = self.features['volume'][-1] + if current_volume > avg_volume * 1.5: + reward += 2.0 # Bonus for entering during high volume + + # Check for price action signals + if self.features['rsi'][-1] < 30: # Oversold condition + reward += 1.5 # Bonus for buying at oversold levels + + # Check if we're buying in a clear uptrend (good) + if self.is_uptrend(): + reward += 1.0 # Bonus for buying in uptrend + elif self.is_downtrend(): + reward -= 0.25 # Reduced penalty for buying in downtrend + else: + reward += 0.2 # Small reward for opening a position + + logger.info(f"OPENED LONG at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}") + + elif self.position == 'short': + # Close short and open long + pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100 + pnl_dollar = pnl_percent / 100 * self.position_size + + # Apply fees + pnl_dollar -= self.calculate_fees(self.position_size) + + # Update balance + self.balance += pnl_dollar + self.total_pnl += pnl_dollar + + # Record trade + trade_duration = len(self.features['price']) - self.entry_index + self.trades.append({ + 'type': 'short', + 'entry': self.entry_price, + 'exit': self.current_price, + 'pnl_percent': pnl_percent, + 'pnl_dollar': pnl_dollar, + 'duration': trade_duration, + 'market_direction': self.get_market_direction() + }) + + # Reward based on PnL with stronger penalties for losses + if pnl_dollar > 0: + reward += 1.0 + pnl_dollar / 10 # Positive reward for profit + self.win_count += 1 + else: + # Stronger penalty for losses, scaled by the size of the loss + loss_penalty = 1.0 + abs(pnl_dollar) / 5 + reward -= loss_penalty + self.loss_count += 1 + + # Extra penalty for closing a losing trade too quickly + if trade_duration < 5: + reward -= 0.5 # Penalty for very short losing trades + + logger.info(f"CLOSED short at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}") + + # Now open long + self.position = 'long' + self.entry_price = self.current_price + self.entry_index = len(self.features['price']) - 1 + self.position_size = self.calculate_position_size() + self.stop_loss = self.entry_price * (1 - self.stop_loss_pct/100) + self.take_profit = self.entry_price * (1 + self.take_profit_pct/100) + + # Check if this is an optimal buy point + if hasattr(self, 'optimal_bottoms') and self.entry_index in self.optimal_bottoms: + reward += 2.0 # Bonus for buying at a bottom + + logger.info(f"OPENED LONG at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}") + + elif action == 2: # SELL/SHORT + if self.position == 'flat': + # Opening a short position + self.position = 'short' + self.entry_price = self.current_price + self.position_size = self.calculate_position_size() + # Use the adjusted risk parameters + self.stop_loss = self.entry_price * (1 + self.stop_loss_pct/100) + self.take_profit = self.entry_price * (1 - self.take_profit_pct/100) + + # Check if this is an optimal sell point (top) + current_idx = len(self.features['price']) - 1 + if hasattr(self, 'optimal_tops') and current_idx in self.optimal_tops: + reward += 3.0 # Increased bonus for selling at a top + + # Check for volume spike + if len(self.features['volume']) > 5: + avg_volume = np.mean(self.features['volume'][-5:-1]) + current_volume = self.features['volume'][-1] + if current_volume > avg_volume * 1.5: + reward += 2.0 # Bonus for entering during high volume + + # Check for price action signals + if self.features['rsi'][-1] > 70: # Overbought condition + reward += 1.5 # Bonus for selling at overbought levels + + # Check if we're selling in a clear downtrend (good) + if self.is_downtrend(): + reward += 1.0 # Bonus for selling in downtrend + elif self.is_uptrend(): + reward -= 0.25 # Reduced penalty for selling in uptrend + else: + reward += 0.2 # Small reward for opening a position + + logger.info(f"OPENED SHORT at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}") + + elif self.position == 'long': + # Close long and open short + pnl_percent = (self.current_price - self.entry_price) / self.entry_price * 100 + pnl_dollar = pnl_percent / 100 * self.position_size + + # Apply fees + pnl_dollar -= self.calculate_fees(self.position_size) + + # Update balance + self.balance += pnl_dollar + self.total_pnl += pnl_dollar + + # Record trade + self.trades.append({ + 'type': 'long', + 'entry': self.entry_price, + 'exit': self.current_price, + 'pnl_percent': pnl_percent, + 'pnl_dollar': pnl_dollar + }) + + # Reward based on PnL + if pnl_dollar > 0: + reward += 1.0 + pnl_dollar / 10 # Positive reward for profit + self.win_count += 1 + else: + reward -= 1.0 # Negative reward for loss + self.loss_count += 1 + + logger.info(f"CLOSED long at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}") + + # Now open short + self.position = 'short' + self.entry_price = self.current_price + self.position_size = self.calculate_position_size() + self.stop_loss = self.entry_price * (1 + self.stop_loss_pct/100) + self.take_profit = self.entry_price * (1 - self.take_profit_pct/100) + + # Check if this is an optimal sell point + current_idx = len(self.features['price']) - 1 + if hasattr(self, 'optimal_tops') and current_idx in self.optimal_tops: + reward += 2.0 # Bonus for selling at a top + + logger.info(f"OPENED SHORT at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}") + + elif action == 3: # CLOSE + if self.position == 'long': + # Close long position + pnl_percent = (self.current_price - self.entry_price) / self.entry_price * 100 + pnl_dollar = pnl_percent / 100 * self.position_size + + # Apply fees + pnl_dollar -= self.calculate_fees(self.position_size) + + # Update balance + self.balance += pnl_dollar + self.total_pnl += pnl_dollar + self.episode_pnl += pnl_dollar + + # Update max drawdown + if self.balance > self.peak_balance: + self.peak_balance = self.balance + drawdown = (self.peak_balance - self.balance) / self.peak_balance + self.max_drawdown = max(self.max_drawdown, drawdown) + + # Record trade + self.trades.append({ + 'type': 'long', + 'entry': self.entry_price, + 'exit': self.current_price, + 'pnl_percent': pnl_percent, + 'pnl_dollar': pnl_dollar + }) + + # Reward based on PnL + if pnl_dollar > 0: + reward += 1.0 + pnl_dollar / 10 # Positive reward for profit + self.win_count += 1 + else: + reward -= 1.0 # Negative reward for loss + self.loss_count += 1 + + logger.info(f"CLOSED long at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}") + + # Reset position + self.position = 'flat' + self.entry_price = 0 + self.position_size = 0 + self.stop_loss = 0 + self.take_profit = 0 + + elif self.position == 'short': + # Close short position + pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100 + pnl_dollar = pnl_percent / 100 * self.position_size + + # Apply fees + pnl_dollar -= self.calculate_fees(self.position_size) + + # Update balance + self.balance += pnl_dollar + self.total_pnl += pnl_dollar + self.episode_pnl += pnl_dollar + + # Update max drawdown + if self.balance > self.peak_balance: + self.peak_balance = self.balance + drawdown = (self.peak_balance - self.balance) / self.peak_balance + self.max_drawdown = max(self.max_drawdown, drawdown) + + # Record trade + self.trades.append({ + 'type': 'short', + 'entry': self.entry_price, + 'exit': self.current_price, + 'pnl_percent': pnl_percent, + 'pnl_dollar': pnl_dollar + }) + + # Reward based on PnL + if pnl_dollar > 0: + reward += 1.0 + pnl_dollar / 10 # Positive reward for profit + self.win_count += 1 + else: + reward -= 1.0 # Negative reward for loss + self.loss_count += 1 + + logger.info(f"CLOSED short at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}") + + # Reset position + self.position = 'flat' + self.entry_price = 0 + self.position_size = 0 + self.stop_loss = 0 + self.take_profit = 0 + + # Add prediction accuracy component to reward + if hasattr(self, 'predicted_prices') and len(self.predicted_prices) > 0: + # Compare the first prediction with actual price + if len(self.data) > 1: + actual_price = self.data[-1]['close'] + predicted_price = self.predicted_prices[0] + prediction_error = abs(predicted_price - actual_price) / actual_price + + # Reward accurate predictions, penalize bad ones + if prediction_error < 0.005: # Less than 0.5% error + reward += 0.5 + elif prediction_error > 0.02: # More than 2% error + reward -= 0.5 + + return reward + + def is_downtrend(self): + """Check if the market is in a downtrend""" + if len(self.features['price']) < 20: + return False + + # Use EMA to determine trend + short_ema = self.features['ema_9'][-1] + long_ema = self.features['ema_21'][-1] + + # Downtrend if short EMA is below long EMA + return short_ema < long_ema + + def is_uptrend(self): + """Check if the market is in an uptrend""" + if len(self.features['price']) < 20: + return False + + # Use EMA to determine trend + short_ema = self.features['ema_9'][-1] + long_ema = self.features['ema_21'][-1] + + # Uptrend if short EMA is above long EMA + return short_ema > long_ema + + def get_market_direction(self): + """Get the current market direction""" + if self.is_uptrend(): + return "uptrend" + elif self.is_downtrend(): + return "downtrend" + else: + return "sideways" + + def analyze_trades(self): + """Analyze completed trades to identify patterns""" + if not self.trades: + return {} + + analysis = { + 'total_trades': len(self.trades), + 'winning_trades': sum(1 for t in self.trades if t.get('pnl_dollar', 0) > 0), + 'losing_trades': sum(1 for t in self.trades if t.get('pnl_dollar', 0) <= 0), + 'avg_win': 0, + 'avg_loss': 0, + 'avg_duration': 0, + 'uptrend_win_rate': 0, + 'downtrend_win_rate': 0, + 'sideways_win_rate': 0 + } + + # Calculate averages + wins = [t.get('pnl_dollar', 0) for t in self.trades if t.get('pnl_dollar', 0) > 0] + losses = [t.get('pnl_dollar', 0) for t in self.trades if t.get('pnl_dollar', 0) <= 0] + durations = [t.get('duration', 0) for t in self.trades] + + analysis['avg_win'] = sum(wins) / len(wins) if wins else 0 + analysis['avg_loss'] = sum(losses) / len(losses) if losses else 0 + analysis['avg_duration'] = sum(durations) / len(durations) if durations else 0 + + # Calculate win rates by market direction + for direction in ['uptrend', 'downtrend', 'sideways']: + direction_trades = [t for t in self.trades if t.get('market_direction') == direction] + if direction_trades: + wins_in_direction = sum(1 for t in direction_trades if t.get('pnl_dollar', 0) > 0) + analysis[f'{direction}_win_rate'] = wins_in_direction / len(direction_trades) * 100 + + return analysis + + def initialize_price_predictor(self, device="cpu"): + """Initialize the price prediction model""" + self.price_predictor = PricePredictionModel(input_size=30, hidden_size=128, output_size=5) + self.price_predictor.to(device) + self.price_predictor_optimizer = optim.Adam(self.price_predictor.parameters(), lr=1e-3) + self.predicted_prices = np.array([]) + + def train_price_predictor(self): + """Train the price prediction model on recent data""" + if len(self.features['price']) < 35: + return 0.0 + + # Get price history + price_history = self.features['price'] + + # Train the model + loss = self.price_predictor.train_on_new_data( + price_history, + self.price_predictor_optimizer, + epochs=5 + ) + + return loss + + def update_price_predictions(self): + """Update price predictions""" + if len(self.features['price']) < 30 or not hasattr(self, 'price_predictor') or self.price_predictor is None: + self.predicted_prices = np.array([]) + return + + # Get price history + price_history = self.features['price'] + + try: + # Get predictions + self.predicted_prices = self.price_predictor.predict_next_candles(price_history, num_candles=5) + except Exception as e: + logger.warning(f"Error updating predictions: {e}") + self.predicted_prices = np.array([]) + + def identify_optimal_trades(self): + """Identify optimal entry and exit points based on local extrema""" + if len(self.features['price']) < 20: + return + + # Find local bottoms and tops + bottoms, tops = find_local_extrema(self.features['price'], window=5) + + # Store optimal trade points + self.optimal_bottoms = bottoms # Buy points + self.optimal_tops = tops # Sell points + + # Create optimal trade signals + self.optimal_signals = np.zeros(len(self.features['price'])) + for i in bottoms: + if 0 <= i < len(self.optimal_signals): # Ensure index is valid + self.optimal_signals[i] = 1 # Buy signal + for i in tops: + if 0 <= i < len(self.optimal_signals): # Ensure index is valid + self.optimal_signals[i] = -1 # Sell signal + + logger.info(f"Identified {len(bottoms)} optimal buy points and {len(tops)} optimal sell points") + +import os +import time +import json +import numpy as np +import pandas as pd +from datetime import datetime +import random +import logging +import asyncio +import matplotlib.pyplot as plt +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +from collections import deque, namedtuple +from dotenv import load_dotenv +import ccxt +import websockets +from torch.utils.tensorboard import SummaryWriter +import torch.cuda.amp as amp # Add this import at the top +from sklearn.preprocessing import MinMaxScaler +import copy +import argparse +import traceback +import io +import matplotlib.dates as mdates +from matplotlib.figure import Figure +from PIL import Image +import matplotlib.pyplot as mpf +import matplotlib.gridspec as gridspec +import datetime +from realtime import BinanceWebSocket, BinanceHistoricalData +from datetime import datetime as dt +# Add Dash-related imports +import dash +from dash import html, dcc, callback_context +from dash.dependencies import Input, Output, State +import plotly.graph_objects as go +from plotly.subplots import make_subplots +from threading import Thread +import socket + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + handlers=[logging.FileHandler("trading_bot.log"), logging.StreamHandler()] +) +logger = logging.getLogger("trading_bot") + +# Look for WebSocket specific logger +websocket_logger = logging.getLogger('websocket') # or similar name +websocket_logger.setLevel(logging.INFO) # Change this from DEBUG to INFO + +# Add this somewhere after the logger is defined +class WebSocketFilter(logging.Filter): + def filter(self, record): + # Filter out DEBUG messages from WebSocket-related modules + if record.levelno == logging.DEBUG and ('websocket' in record.name or + 'protocol' in record.name or + 'realtime' in record.name): + return False + return True + +logger.addFilter(WebSocketFilter()) + +# Load environment variables +load_dotenv() +MEXC_API_KEY = os.getenv('MEXC_API_KEY') +MEXC_SECRET_KEY = os.getenv('MEXC_SECRET_KEY') + +# Constants +INITIAL_BALANCE = 100 # USD +MAX_LEVERAGE = 100 +STOP_LOSS_PERCENT = 0.5 # Very tight stop loss (0.5%) due to high leverage +TAKE_PROFIT_PERCENT = 1.5 # Take profit at 1.5% +MEMORY_SIZE = 100000 +BATCH_SIZE = 64 +GAMMA = 0.99 # Discount factor +EPSILON_START = 1.0 +EPSILON_END = 0.05 +EPSILON_DECAY = 10000 +STATE_SIZE = 64 # Size of our state representation +LEARNING_RATE = 1e-4 +TARGET_UPDATE = 10 # Update target network every 10 episodes + +# Experience replay tuple +Experience = namedtuple('Experience', ['state', 'action', 'reward', 'next_state', 'done']) + +# Add this function near the top of the file, after the imports but before any classes +def find_local_extrema(prices, window=5): + """Find local minima (bottoms) and maxima (tops) in price data""" + bottoms = [] + tops = [] + + if len(prices) < window * 2 + 1: + return bottoms, tops + + for i in range(window, len(prices) - window): + # Check if this is a local minimum (bottom) + if all(prices[i] <= prices[i-j] for j in range(1, window+1)) and \ + all(prices[i] <= prices[i+j] for j in range(1, window+1)): + bottoms.append(i) + + # Check if this is a local maximum (top) + if all(prices[i] >= prices[i-j] for j in range(1, window+1)) and \ + all(prices[i] >= prices[i+j] for j in range(1, window+1)): + tops.append(i) + + return bottoms, tops + +class ReplayMemory: + def __init__(self, capacity): + self.memory = deque(maxlen=capacity) + + def push(self, state, action, reward, next_state, done): + self.memory.append(Experience(state, action, reward, next_state, done)) + + def sample(self, batch_size): + return random.sample(self.memory, batch_size) + + def __len__(self): + return len(self.memory) + +class DQN(nn.Module): + def __init__(self, state_size, action_size, hidden_size=384, lstm_layers=2, attention_heads=4): + super(DQN, self).__init__() + + self.state_size = state_size + self.hidden_size = hidden_size + self.lstm_layers = lstm_layers + + # Initial feature extraction + self.fc1 = nn.Linear(state_size, hidden_size) + # Use LayerNorm instead of BatchNorm for more stability with varying batch sizes + self.ln1 = nn.LayerNorm(hidden_size) + self.dropout1 = nn.Dropout(0.2) + + # LSTM layer for sequential data + self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers=lstm_layers, batch_first=True, dropout=0.2) + + # Attention mechanism + self.attention = nn.MultiheadAttention(hidden_size, attention_heads) + + # Output layers with increased capacity + self.fc2 = nn.Linear(hidden_size, hidden_size) + self.ln2 = nn.LayerNorm(hidden_size) # LayerNorm instead of BatchNorm + self.dropout2 = nn.Dropout(0.2) + self.fc3 = nn.Linear(hidden_size, hidden_size // 2) + + # Dueling DQN architecture + self.value_stream = nn.Linear(hidden_size // 2, 1) + self.advantage_stream = nn.Linear(hidden_size // 2, action_size) + + # Transformer encoder for more complex pattern recognition + encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_size, nhead=attention_heads, dropout=0.1) + self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2) + + def forward(self, x): + batch_size = x.size(0) if x.dim() > 1 else 1 + + # Ensure input has correct shape + if x.dim() == 1: + x = x.unsqueeze(0) # Add batch dimension + + # Check if state size matches expected input size + if x.size(1) != self.state_size: + # Handle mismatched input by either truncating or padding + if x.size(1) > self.state_size: + x = x[:, :self.state_size] # Truncate + else: + # Pad with zeros + padding = torch.zeros(batch_size, self.state_size - x.size(1), device=x.device) + x = torch.cat([x, padding], dim=1) + + # Initial feature extraction + x = self.fc1(x) + x = F.relu(self.ln1(x)) # LayerNorm works with any batch size + x = self.dropout1(x) + + # Reshape for LSTM + x_lstm = x.unsqueeze(1) if x.dim() == 2 else x + + # Process through LSTM + lstm_out, _ = self.lstm(x_lstm) + lstm_out = lstm_out.squeeze(1) if lstm_out.size(1) == 1 else lstm_out[:, -1] + + # Process through transformer for more complex patterns + transformer_input = x.unsqueeze(1) if x.dim() == 2 else x + transformer_out = self.transformer_encoder(transformer_input.transpose(0, 1)) + transformer_out = transformer_out.transpose(0, 1).mean(dim=1) + + # Combine LSTM and transformer outputs + x = lstm_out + transformer_out + + # Final layers + x = self.fc2(x) + x = F.relu(self.ln2(x)) # LayerNorm works with any batch size + x = self.dropout2(x) + x = F.relu(self.fc3(x)) + + # Dueling architecture + value = self.value_stream(x) + advantages = self.advantage_stream(x) + qvals = value + (advantages - advantages.mean(dim=1, keepdim=True)) + + return qvals + +class PricePredictionModel(nn.Module): + def __init__(self, input_size=30, hidden_size=128, output_size=5, num_layers=2): + super(PricePredictionModel, self).__init__() + self.lstm = nn.LSTM(1, hidden_size, num_layers=num_layers, batch_first=True, dropout=0.2) + self.fc = nn.Linear(hidden_size, output_size) + self.scaler = MinMaxScaler(feature_range=(0, 1)) + self.is_fitted = False + + def forward(self, x): + # x shape: [batch_size, seq_len, 1] + lstm_out, _ = self.lstm(x) + # Use the last time step output + predictions = self.fc(lstm_out[:, -1, :]) + return predictions + + def preprocess(self, data): + # Reshape data for scaler + data_reshaped = np.array(data).reshape(-1, 1) + + # Fit scaler if not already fitted + if not self.is_fitted: + self.scaler.fit(data_reshaped) + self.is_fitted = True + + # Transform data + scaled_data = self.scaler.transform(data_reshaped) + return scaled_data + + def postprocess(self, scaled_predictions): + # Inverse transform to get actual price values + return self.scaler.inverse_transform(scaled_predictions.reshape(-1, 1)).flatten() + + def predict_next_candles(self, price_history, num_candles=5): + if len(price_history) < 30: # Need enough history + return np.zeros(num_candles) + + # Preprocess data + scaled_data = self.preprocess(price_history) + + # Create sequence + sequence = scaled_data[-30:].reshape(1, 30, 1) + sequence_tensor = torch.FloatTensor(sequence).to(next(self.parameters()).device) + + # Get predictions + with torch.no_grad(): + scaled_predictions = self(sequence_tensor).cpu().numpy()[0] + + # Postprocess predictions + predictions = self.postprocess(scaled_predictions) + return predictions + + def train_on_new_data(self, price_history, optimizer, epochs=10): + if len(price_history) < 35: # Need enough history for training + return 0.0 + + # Preprocess data + scaled_data = self.preprocess(price_history) + + # Create sequences and targets + sequences = [] + targets = [] + + for i in range(len(scaled_data) - 35): + # Sequence: 30 time steps + seq = scaled_data[i:i+30] + # Target: next 5 time steps + target = scaled_data[i+30:i+35].flatten() + + sequences.append(seq) + targets.append(target) + + if not sequences: # If no sequences were created + return 0.0 + + # Convert to tensors + sequences_tensor = torch.FloatTensor(np.array(sequences).reshape(-1, 30, 1)).to(next(self.parameters()).device) + targets_tensor = torch.FloatTensor(np.array(targets)).to(next(self.parameters()).device) + + # Training loop + total_loss = 0 + for _ in range(epochs): + # Forward pass + predictions = self(sequences_tensor) + + # Calculate loss + loss = F.mse_loss(predictions, targets_tensor) + + # Backward pass and optimize + optimizer.zero_grad() + loss.backward() + optimizer.step() + + total_loss += loss.item() + + return total_loss / epochs + class TradingEnvironment: """Trading environment for reinforcement learning with enhanced features""" def __init__(self, initial_balance=INITIAL_BALANCE, window_size=30, demo=True): diff --git a/trading_main.py b/trading_main.py new file mode 100644 index 0000000..f8a44ca --- /dev/null +++ b/trading_main.py @@ -0,0 +1,155 @@ +import os +import time +import logging +import sys +import argparse +import json + +# Add the NN directory to the Python path +sys.path.append(os.path.abspath("NN")) + +from NN.main import load_model +from NN.neural_network_orchestrator import NeuralNetworkOrchestrator +from NN.realtime_data_interface import RealtimeDataInterface + +# Initialize logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler("trading_bot.log"), + logging.StreamHandler() + ] +) +logger = logging.getLogger(__name__) + +def main(): + """Main function for the trading bot.""" + # Parse command-line arguments + parser = argparse.ArgumentParser(description="Trading Bot with Neural Network Integration") + parser.add_argument('--symbols', nargs='+', default=["BTC/USDT", "ETH/USDT"], + help='Trading symbols to monitor') + parser.add_argument('--timeframes', nargs='+', default=["1m", "5m", "1h", "4h", "1d"], + help='Timeframes to monitor') + parser.add_argument('--window-size', type=int, default=20, + help='Window size for model input') + parser.add_argument('--output-size', type=int, default=3, + help='Output size of the model (3 for BUY/HOLD/SELL)') + parser.add_argument('--model-type', type=str, default="cnn", choices=["cnn", "lstm", "mlp"], + help='Type of neural network model') + parser.add_argument('--mode', type=str, default="realtime", choices=["realtime", "backtest"], + help='Trading mode') + parser.add_argument('--exchange', type=str, default="binance", choices=["binance", "mexc"], + help='Exchange to use for trading') + parser.add_argument('--api-key', type=str, default=None, + help='API key for the exchange') + parser.add_argument('--api-secret', type=str, default=None, + help='API secret for the exchange') + parser.add_argument('--test-mode', action='store_true', + help='Use test/sandbox exchange environment') + parser.add_argument('--position-size', type=float, default=0.1, + help='Position size as a fraction of total balance (0.0-1.0)') + parser.add_argument('--max-trades-per-day', type=int, default=5, + help='Maximum number of trades per day') + parser.add_argument('--trade-cooldown', type=int, default=60, + help='Trade cooldown period in minutes') + parser.add_argument('--config-file', type=str, default=None, + help='Path to configuration file') + + args = parser.parse_args() + + # Load configuration from file if provided + if args.config_file and os.path.exists(args.config_file): + with open(args.config_file, 'r') as f: + config = json.load(f) + # Override config with command-line args + for key, value in vars(args).items(): + if key != 'config_file' and value is not None: + config[key] = value + else: + # Use command-line args as config + config = vars(args) + + # Initialize real-time charts and data interfaces + try: + from realtime import RealTimeChart + + # Create a real-time chart for each symbol + charts = {} + for symbol in config['symbols']: + charts[symbol] = RealTimeChart(symbol=symbol) + + main_chart = charts[config['symbols'][0]] + + # Create a data interface for retrieving market data + data_interface = RealtimeDataInterface(symbols=config['symbols'], chart=main_chart) + + # Load trained model + model_type = os.environ.get("NN_MODEL_TYPE", config['model_type']) + model = load_model( + model_type=model_type, + input_shape=(config['window_size'], len(config['symbols']), 5), # 5 features (OHLCV) + output_size=config['output_size'] + ) + + # Configure trading agent + exchange_config = { + "exchange": config['exchange'], + "api_key": config['api_key'], + "api_secret": config['api_secret'], + "test_mode": config['test_mode'], + "trade_symbols": config['symbols'], + "position_size": config['position_size'], + "max_trades_per_day": config['max_trades_per_day'], + "trade_cooldown_minutes": config['trade_cooldown'] + } + + # Initialize neural network orchestrator + orchestrator = NeuralNetworkOrchestrator( + model=model, + data_interface=data_interface, + chart=main_chart, + symbols=config['symbols'], + timeframes=config['timeframes'], + window_size=config['window_size'], + num_features=5, # OHLCV + output_size=config['output_size'], + exchange_config=exchange_config + ) + + # Start data collection + logger.info("Starting data collection threads...") + for symbol in config['symbols']: + charts[symbol].start() + + # Start neural network inference + if os.environ.get("ENABLE_NN_MODELS", "0") == "1": + logger.info("Starting neural network inference...") + orchestrator.start_inference() + else: + logger.info("Neural network models disabled. Set ENABLE_NN_MODELS=1 to enable.") + + # Start web servers for chart display + logger.info("Starting web servers for chart display...") + main_chart.start_server() + + logger.info("Trading bot initialized successfully. Press Ctrl+C to exit.") + + # Keep the main thread alive + try: + while True: + time.sleep(1) + except KeyboardInterrupt: + logger.info("Keyboard interrupt received. Shutting down...") + # Stop all threads + for symbol in config['symbols']: + charts[symbol].stop() + orchestrator.stop_inference() + logger.info("Trading bot stopped.") + + except Exception as e: + logger.error(f"Error in main function: {str(e)}", exc_info=True) + sys.exit(1) + +if __name__ == "__main__": + main() \ No newline at end of file