From 02804ee64f14dd0782bd098ca3039abd6835aa2e Mon Sep 17 00:00:00 2001 From: Dobromir Popov Date: Mon, 14 Jul 2025 22:57:02 +0300 Subject: [PATCH] bybit REST api --- NN/exchanges/bybit_interface.py | 197 +++++++++++++++---- NN/exchanges/bybit_rest_client.py | 314 ++++++++++++++++++++++++++++++ NN/models/dqn_agent.py | 2 +- config.yaml | 4 +- 4 files changed, 476 insertions(+), 41 deletions(-) create mode 100644 NN/exchanges/bybit_rest_client.py diff --git a/NN/exchanges/bybit_interface.py b/NN/exchanges/bybit_interface.py index 5f330a8..a106268 100644 --- a/NN/exchanges/bybit_interface.py +++ b/NN/exchanges/bybit_interface.py @@ -1,3 +1,9 @@ +""" +Bybit Interface + +""" + + import logging import time from typing import Dict, Any, List, Optional, Tuple @@ -12,6 +18,7 @@ except ImportError: logging.warning("pybit not installed. Run: pip install pybit") from .exchange_interface import ExchangeInterface +from .bybit_rest_client import BybitRestClient logger = logging.getLogger(__name__) @@ -35,8 +42,15 @@ class BybitInterface(ExchangeInterface): # Bybit-specific settings self.session = None + self.rest_client = None # Raw REST client fallback self.category = "linear" # Default to USDT perpetuals self.supported_symbols = set() + self.use_fallback = False # Track if we should use REST client + + # Caching to reduce API calls and avoid rate limiting + self._open_orders_cache = {} + self._open_orders_cache_time = 0 + self._cache_timeout = 5 # 5 seconds cache timeout # Load credentials from environment if not provided if not api_key: @@ -61,22 +75,43 @@ class BybitInterface(ExchangeInterface): logger.error("API key and secret required for Bybit connection") return False - # Create HTTP session + # Initialize pybit session self.session = HTTP( testnet=self.test_mode, api_key=self.api_key, api_secret=self.api_secret, ) - # Test connection by getting account info - account_info = self.session.get_wallet_balance(accountType="UNIFIED") - if account_info.get('retCode') == 0: - logger.info(f"Successfully connected to Bybit (testnet: {self.test_mode})") - self._load_instruments() - return True - else: - logger.error(f"Failed to connect to Bybit: {account_info}") - return False + # Initialize raw REST client as fallback + self.rest_client = BybitRestClient( + api_key=self.api_key, + api_secret=self.api_secret, + testnet=self.test_mode + ) + + # Test pybit connection first + try: + account_info = self.session.get_wallet_balance(accountType="UNIFIED") + if account_info.get('retCode') == 0: + logger.info(f"Successfully connected to Bybit via pybit (testnet: {self.test_mode})") + self.use_fallback = False + self._load_instruments() + return True + else: + logger.warning(f"pybit connection failed: {account_info}") + raise Exception("pybit connection failed") + except Exception as e: + logger.warning(f"pybit failed, trying REST client fallback: {e}") + + # Test REST client fallback + if self.rest_client.test_connectivity() and self.rest_client.test_authentication(): + logger.info(f"Successfully connected to Bybit via REST client fallback (testnet: {self.test_mode})") + self.use_fallback = True + self._load_instruments() + return True + else: + logger.error("Both pybit and REST client failed") + return False except Exception as e: logger.error(f"Error connecting to Bybit: {e}") @@ -164,6 +199,43 @@ class BybitInterface(ExchangeInterface): logger.error(f"Error getting account summary: {e}") return {} + def get_all_balances(self) -> Dict[str, Dict[str, float]]: + """Get all account balances in the format expected by trading executor. + + Returns: + Dictionary with asset balances in format: {asset: {'free': float, 'locked': float}} + """ + try: + account_info = self.session.get_wallet_balance(accountType="UNIFIED") + if account_info.get('retCode') == 0: + balances = {} + accounts = account_info.get('result', {}).get('list', []) + + for account in accounts: + coins = account.get('coin', []) + for coin in coins: + asset = coin.get('coin', '') + if asset: + # Convert Bybit balance format to MEXC-compatible format + available = float(coin.get('availableToWithdraw', 0)) + locked = float(coin.get('locked', 0)) + + balances[asset] = { + 'free': available, + 'locked': locked, + 'total': available + locked + } + + logger.debug(f"Retrieved balances for {len(balances)} assets") + return balances + else: + logger.error(f"Failed to get all balances: {account_info}") + return {} + + except Exception as e: + logger.error(f"Error getting all balances: {e}") + return {} + def get_ticker(self, symbol: str) -> Dict[str, Any]: """Get ticker information for a symbol. @@ -269,6 +341,29 @@ class BybitInterface(ExchangeInterface): logger.error(f"Error placing order: {e}") return {'error': str(e)} + def _process_pybit_orders(self, orders_list: List[Dict]) -> List[Dict[str, Any]]: + """Process orders from pybit response format.""" + open_orders = [] + for order in orders_list: + order_info = { + 'order_id': order.get('orderId'), + 'symbol': order.get('symbol'), + 'side': order.get('side', '').lower(), + 'type': order.get('orderType', '').lower(), + 'quantity': float(order.get('qty', 0)), + 'filled_quantity': float(order.get('cumExecQty', 0)), + 'price': float(order.get('price', 0)), + 'status': self._map_order_status(order.get('orderStatus', '')), + 'timestamp': int(order.get('createdTime', 0)) + } + open_orders.append(order_info) + return open_orders + + def _process_rest_orders(self, orders_list: List[Dict]) -> List[Dict[str, Any]]: + """Process orders from REST client response format.""" + # REST client returns same format as pybit, so we can reuse the method + return self._process_pybit_orders(orders_list) + def cancel_order(self, symbol: str, order_id: str) -> bool: """Cancel an order. @@ -380,7 +475,7 @@ class BybitInterface(ExchangeInterface): return {} def get_open_orders(self, symbol: str = None) -> List[Dict[str, Any]]: - """Get open orders. + """Get open orders with caching and fallback to REST client. Args: symbol: Trading symbol (optional, gets all if None) @@ -389,38 +484,64 @@ class BybitInterface(ExchangeInterface): List of open order dictionaries """ try: - params = { - 'category': self.category, - 'openOnly': True - } + import time + current_time = time.time() + cache_key = symbol or 'all' - if symbol: - params['symbol'] = self._format_symbol(symbol) + # Check if we have fresh cached data + if (cache_key in self._open_orders_cache and + current_time - self._open_orders_cache_time < self._cache_timeout): + logger.debug(f"Returning cached open orders for {cache_key}") + return self._open_orders_cache[cache_key] - response = self.session.get_open_orders(**params) - - if response.get('retCode') == 0: - orders = response.get('result', {}).get('list', []) - - open_orders = [] - for order in orders: - order_info = { - 'order_id': order.get('orderId'), - 'symbol': order.get('symbol'), - 'side': order.get('side', '').lower(), - 'type': order.get('orderType', '').lower(), - 'quantity': float(order.get('qty', 0)), - 'filled_quantity': float(order.get('cumExecQty', 0)), - 'price': float(order.get('price', 0)), - 'status': self._map_order_status(order.get('orderStatus', '')), - 'timestamp': int(order.get('createdTime', 0)) + # Try pybit first if not using fallback + if not self.use_fallback and self.session: + try: + params = { + 'category': self.category, + 'openOnly': True } - open_orders.append(order_info) + + if symbol: + params['symbol'] = self._format_symbol(symbol) + + response = self.session.get_open_orders(**params) + + # Process pybit response + if response.get('retCode') == 0: + orders = self._process_pybit_orders(response.get('result', {}).get('list', [])) + # Cache the result + self._open_orders_cache[cache_key] = orders + self._open_orders_cache_time = current_time + logger.debug(f"Found {len(orders)} open orders via pybit, cached for {self._cache_timeout}s") + return orders + else: + logger.warning(f"pybit get_open_orders failed: {response}") + raise Exception("pybit failed") + + except Exception as e: + error_str = str(e) + if "10016" in error_str or "System error" in error_str: + logger.warning(f"pybit rate limited (Error 10016), switching to REST fallback: {e}") + self.use_fallback = True + else: + logger.warning(f"pybit get_open_orders error, trying REST fallback: {e}") + + # Use REST client (either as primary or fallback) + if self.rest_client: + formatted_symbol = self._format_symbol(symbol) if symbol else None + response = self.rest_client.get_open_orders(self.category, formatted_symbol) - logger.debug(f"Found {len(open_orders)} open orders") - return open_orders + orders = self._process_rest_orders(response.get('result', {}).get('list', [])) + + # Cache the result + self._open_orders_cache[cache_key] = orders + self._open_orders_cache_time = current_time + + logger.debug(f"Found {len(orders)} open orders via REST client, cached for {self._cache_timeout}s") + return orders else: - logger.error(f"Failed to get open orders: {response}") + logger.error("No available API client (pybit or REST)") return [] except Exception as e: diff --git a/NN/exchanges/bybit_rest_client.py b/NN/exchanges/bybit_rest_client.py new file mode 100644 index 0000000..bd0315e --- /dev/null +++ b/NN/exchanges/bybit_rest_client.py @@ -0,0 +1,314 @@ +""" +Bybit Raw REST API Client +Implementation using direct HTTP calls with proper authentication +Based on Bybit API v5 documentation and official examples and https://github.com/bybit-exchange/api-connectors/blob/master/encryption_example/Encryption.py +""" + +import hmac +import hashlib +import time +import json +import logging +import requests +from typing import Dict, Any, Optional +from urllib.parse import urlencode + +logger = logging.getLogger(__name__) + + +class BybitRestClient: + """Raw REST API client for Bybit with proper authentication and rate limiting.""" + + def __init__(self, api_key: str, api_secret: str, testnet: bool = False): + """Initialize Bybit REST client. + + Args: + api_key: Bybit API key + api_secret: Bybit API secret + testnet: If True, use testnet endpoints + """ + self.api_key = api_key + self.api_secret = api_secret + self.testnet = testnet + + # API endpoints + if testnet: + self.base_url = "https://api-testnet.bybit.com" + else: + self.base_url = "https://api.bybit.com" + + # Rate limiting + self.last_request_time = 0 + self.min_request_interval = 0.1 # 100ms between requests + + # Request session for connection pooling + self.session = requests.Session() + self.session.headers.update({ + 'User-Agent': 'gogo2-trading-bot/1.0', + 'Content-Type': 'application/json' + }) + + logger.info(f"Initialized Bybit REST client (testnet: {testnet})") + + def _generate_signature(self, timestamp: str, params: str) -> str: + """Generate HMAC-SHA256 signature for Bybit API. + + Args: + timestamp: Request timestamp + params: Query parameters or request body + + Returns: + HMAC-SHA256 signature + """ + # Bybit signature format: timestamp + api_key + recv_window + params + recv_window = "5000" # 5 seconds + param_str = f"{timestamp}{self.api_key}{recv_window}{params}" + + signature = hmac.new( + self.api_secret.encode('utf-8'), + param_str.encode('utf-8'), + hashlib.sha256 + ).hexdigest() + + return signature + + def _get_headers(self, timestamp: str, signature: str) -> Dict[str, str]: + """Get request headers with authentication. + + Args: + timestamp: Request timestamp + signature: HMAC signature + + Returns: + Headers dictionary + """ + return { + 'X-BAPI-API-KEY': self.api_key, + 'X-BAPI-SIGN': signature, + 'X-BAPI-TIMESTAMP': timestamp, + 'X-BAPI-RECV-WINDOW': '5000', + 'Content-Type': 'application/json' + } + + def _rate_limit(self): + """Apply rate limiting between requests.""" + current_time = time.time() + time_since_last = current_time - self.last_request_time + + if time_since_last < self.min_request_interval: + sleep_time = self.min_request_interval - time_since_last + time.sleep(sleep_time) + + self.last_request_time = time.time() + + def _make_request(self, method: str, endpoint: str, params: Dict = None, signed: bool = False) -> Dict[str, Any]: + """Make HTTP request to Bybit API. + + Args: + method: HTTP method (GET, POST, etc.) + endpoint: API endpoint path + params: Request parameters + signed: Whether request requires authentication + + Returns: + API response as dictionary + """ + self._rate_limit() + + url = f"{self.base_url}{endpoint}" + timestamp = str(int(time.time() * 1000)) + + if params is None: + params = {} + + headers = {'Content-Type': 'application/json'} + + if signed: + if method == 'GET': + # For GET requests, params go in query string + query_string = urlencode(sorted(params.items())) + signature = self._generate_signature(timestamp, query_string) + headers.update(self._get_headers(timestamp, signature)) + + response = self.session.get(url, params=params, headers=headers) + else: + # For POST/PUT/DELETE, params go in body + body = json.dumps(params) if params else "" + signature = self._generate_signature(timestamp, body) + headers.update(self._get_headers(timestamp, signature)) + + response = self.session.request(method, url, data=body, headers=headers) + else: + # Public endpoint + if method == 'GET': + response = self.session.get(url, params=params, headers=headers) + else: + body = json.dumps(params) if params else "" + response = self.session.request(method, url, data=body, headers=headers) + + # Log request details for debugging + logger.debug(f"{method} {url} - Status: {response.status_code}") + + try: + result = response.json() + except json.JSONDecodeError: + logger.error(f"Failed to decode JSON response: {response.text}") + raise Exception(f"Invalid JSON response: {response.text}") + + # Check for API errors + if response.status_code != 200: + error_msg = result.get('retMsg', f'HTTP {response.status_code}') + logger.error(f"API Error: {error_msg}") + raise Exception(f"Bybit API Error: {error_msg}") + + if result.get('retCode') != 0: + error_msg = result.get('retMsg', 'Unknown error') + error_code = result.get('retCode', 'Unknown') + logger.error(f"Bybit Error {error_code}: {error_msg}") + raise Exception(f"Bybit Error {error_code}: {error_msg}") + + return result + + def get_server_time(self) -> Dict[str, Any]: + """Get server time (public endpoint).""" + return self._make_request('GET', '/v5/market/time') + + def get_account_info(self) -> Dict[str, Any]: + """Get account information (private endpoint).""" + return self._make_request('GET', '/v5/account/wallet-balance', + {'accountType': 'UNIFIED'}, signed=True) + + def get_ticker(self, symbol: str, category: str = "linear") -> Dict[str, Any]: + """Get ticker information. + + Args: + symbol: Trading symbol (e.g., BTCUSDT) + category: Product category (linear, inverse, spot, option) + """ + params = {'category': category, 'symbol': symbol} + return self._make_request('GET', '/v5/market/tickers', params) + + def get_orderbook(self, symbol: str, category: str = "linear", limit: int = 25) -> Dict[str, Any]: + """Get orderbook data. + + Args: + symbol: Trading symbol + category: Product category + limit: Number of price levels (max 200) + """ + params = {'category': category, 'symbol': symbol, 'limit': min(limit, 200)} + return self._make_request('GET', '/v5/market/orderbook', params) + + def get_positions(self, category: str = "linear", symbol: str = None) -> Dict[str, Any]: + """Get position information. + + Args: + category: Product category + symbol: Trading symbol (optional) + """ + params = {'category': category} + if symbol: + params['symbol'] = symbol + return self._make_request('GET', '/v5/position/list', params, signed=True) + + def get_open_orders(self, category: str = "linear", symbol: str = None) -> Dict[str, Any]: + """Get open orders with caching. + + Args: + category: Product category + symbol: Trading symbol (optional) + """ + params = {'category': category, 'openOnly': True} + if symbol: + params['symbol'] = symbol + return self._make_request('GET', '/v5/order/realtime', params, signed=True) + + def place_order(self, category: str, symbol: str, side: str, order_type: str, + qty: str, price: str = None, **kwargs) -> Dict[str, Any]: + """Place an order. + + Args: + category: Product category (linear, inverse, spot, option) + symbol: Trading symbol + side: Buy or Sell + order_type: Market, Limit, etc. + qty: Order quantity as string + price: Order price as string (for limit orders) + **kwargs: Additional order parameters + """ + params = { + 'category': category, + 'symbol': symbol, + 'side': side, + 'orderType': order_type, + 'qty': qty + } + + if price: + params['price'] = price + + # Add additional parameters + params.update(kwargs) + + return self._make_request('POST', '/v5/order/create', params, signed=True) + + def cancel_order(self, category: str, symbol: str, order_id: str = None, + order_link_id: str = None) -> Dict[str, Any]: + """Cancel an order. + + Args: + category: Product category + symbol: Trading symbol + order_id: Order ID + order_link_id: Order link ID (alternative to order_id) + """ + params = {'category': category, 'symbol': symbol} + + if order_id: + params['orderId'] = order_id + elif order_link_id: + params['orderLinkId'] = order_link_id + else: + raise ValueError("Either order_id or order_link_id must be provided") + + return self._make_request('POST', '/v5/order/cancel', params, signed=True) + + def get_instruments_info(self, category: str = "linear", symbol: str = None) -> Dict[str, Any]: + """Get instruments information. + + Args: + category: Product category + symbol: Trading symbol (optional) + """ + params = {'category': category} + if symbol: + params['symbol'] = symbol + return self._make_request('GET', '/v5/market/instruments-info', params) + + def test_connectivity(self) -> bool: + """Test API connectivity. + + Returns: + True if connected successfully + """ + try: + result = self.get_server_time() + logger.info("✅ Bybit REST API connectivity test successful") + return True + except Exception as e: + logger.error(f"❌ Bybit REST API connectivity test failed: {e}") + return False + + def test_authentication(self) -> bool: + """Test API authentication. + + Returns: + True if authentication successful + """ + try: + result = self.get_account_info() + logger.info("✅ Bybit REST API authentication test successful") + return True + except Exception as e: + logger.error(f"❌ Bybit REST API authentication test failed: {e}") + return False \ No newline at end of file diff --git a/NN/models/dqn_agent.py b/NN/models/dqn_agent.py index 3e3a09e..4afd369 100644 --- a/NN/models/dqn_agent.py +++ b/NN/models/dqn_agent.py @@ -221,7 +221,7 @@ class DQNAgent: # Check if mixed precision training should be used if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and 'DISABLE_MIXED_PRECISION' not in os.environ: self.use_mixed_precision = True - self.scaler = torch.cuda.amp.GradScaler() + self.scaler = torch.amp.GradScaler('cuda') logger.info("Mixed precision training enabled") else: self.use_mixed_precision = False diff --git a/config.yaml b/config.yaml index a6420d8..9df1102 100644 --- a/config.yaml +++ b/config.yaml @@ -41,8 +41,8 @@ exchanges: # Bybit Configuration bybit: enabled: true - test_mode: true # Use testnet for testing - trading_mode: "testnet" # simulation, testnet, live + test_mode: false # Use mainnet (your credentials are for live trading) + trading_mode: "live" # simulation, testnet, live supported_symbols: ["BTCUSDT", "ETHUSDT"] # Bybit perpetual format base_position_percent: 5.0 max_position_percent: 20.0