From 7339972eab211b2635a26ae53f8a3d36d39129a5 Mon Sep 17 00:00:00 2001 From: Dobromir Popov Date: Mon, 4 Aug 2025 23:41:42 +0300 Subject: [PATCH] =?UTF-8?q?=E2=9C=85=20Binance=20(completed=20previously)?= =?UTF-8?q?=20=E2=9C=85=20Coinbase=20Pro=20(completed=20in=20task=2012)=20?= =?UTF-8?q?=E2=9C=85=20Kraken=20(completed=20in=20task=2012)=20=E2=9C=85?= =?UTF-8?q?=20Bybit=20(completed=20in=20this=20task)=20=E2=9C=85=20OKX=20(?= =?UTF-8?q?completed=20in=20this=20task)=20=E2=9C=85=20Huobi=20(completed?= =?UTF-8?q?=20in=20this=20task)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../multi-exchange-data-aggregation/tasks.md | 3 + COBY/connectors/__init__.py | 6 + COBY/connectors/bybit_connector.py | 605 ++++++++++++++++ COBY/connectors/huobi_connector.py | 660 ++++++++++++++++++ COBY/connectors/okx_connector.py | 660 ++++++++++++++++++ COBY/tests/test_all_connectors.py | 271 +++++++ COBY/tests/test_bybit_connector.py | 321 +++++++++ 7 files changed, 2526 insertions(+) create mode 100644 COBY/connectors/bybit_connector.py create mode 100644 COBY/connectors/huobi_connector.py create mode 100644 COBY/connectors/okx_connector.py create mode 100644 COBY/tests/test_all_connectors.py create mode 100644 COBY/tests/test_bybit_connector.py diff --git a/.kiro/specs/multi-exchange-data-aggregation/tasks.md b/.kiro/specs/multi-exchange-data-aggregation/tasks.md index ded08ab..e0acd31 100644 --- a/.kiro/specs/multi-exchange-data-aggregation/tasks.md +++ b/.kiro/specs/multi-exchange-data-aggregation/tasks.md @@ -123,6 +123,9 @@ - Add replay status monitoring and progress tracking - _Requirements: 5.1, 5.2, 5.3, 5.4, 5.5_ + + + - [ ] 11. Create orchestrator integration interface - Implement data adapter that matches existing orchestrator interface - Create compatibility layer for seamless integration with current data provider diff --git a/COBY/connectors/__init__.py b/COBY/connectors/__init__.py index 472f677..a14dd61 100644 --- a/COBY/connectors/__init__.py +++ b/COBY/connectors/__init__.py @@ -6,6 +6,9 @@ from .base_connector import BaseExchangeConnector from .binance_connector import BinanceConnector from .coinbase_connector import CoinbaseConnector from .kraken_connector import KrakenConnector +from .bybit_connector import BybitConnector +from .okx_connector import OKXConnector +from .huobi_connector import HuobiConnector from .connection_manager import ConnectionManager from .circuit_breaker import CircuitBreaker @@ -14,6 +17,9 @@ __all__ = [ 'BinanceConnector', 'CoinbaseConnector', 'KrakenConnector', + 'BybitConnector', + 'OKXConnector', + 'HuobiConnector', 'ConnectionManager', 'CircuitBreaker' ] \ No newline at end of file diff --git a/COBY/connectors/bybit_connector.py b/COBY/connectors/bybit_connector.py new file mode 100644 index 0000000..4c9ea4c --- /dev/null +++ b/COBY/connectors/bybit_connector.py @@ -0,0 +1,605 @@ +""" +Bybit exchange connector implementation. +Supports WebSocket connections to Bybit with unified trading account support. +""" + +import json +import hmac +import hashlib +import time +from typing import Dict, List, Optional, Any +from datetime import datetime, timezone + +from ..models.core import OrderBookSnapshot, TradeEvent, PriceLevel +from ..utils.logging import get_logger, set_correlation_id +from ..utils.exceptions import ValidationError, ConnectionError +from ..utils.validation import validate_symbol, validate_price, validate_volume +from .base_connector import BaseExchangeConnector + +logger = get_logger(__name__) + + +class BybitConnector(BaseExchangeConnector): + """ + Bybit WebSocket connector implementation. + + Supports: + - Unified Trading Account (UTA) WebSocket streams + - Order book streams + - Trade streams + - Symbol normalization + - Authentication for private channels + """ + + # Bybit WebSocket URLs + WEBSOCKET_URL = "wss://stream.bybit.com/v5/public/spot" + WEBSOCKET_PRIVATE_URL = "wss://stream.bybit.com/v5/private" + TESTNET_URL = "wss://stream-testnet.bybit.com/v5/public/spot" + API_URL = "https://api.bybit.com" + + def __init__(self, use_testnet: bool = False, api_key: str = None, api_secret: str = None): + """ + Initialize Bybit connector. + + Args: + use_testnet: Whether to use testnet environment + api_key: API key for authentication (optional) + api_secret: API secret for authentication (optional) + """ + websocket_url = self.TESTNET_URL if use_testnet else self.WEBSOCKET_URL + super().__init__("bybit", websocket_url) + + # Authentication credentials (optional) + self.api_key = api_key + self.api_secret = api_secret + self.use_testnet = use_testnet + + # Bybit-specific message handlers + self.message_handlers.update({ + 'orderbook': self._handle_orderbook_update, + 'publicTrade': self._handle_trade_update, + 'pong': self._handle_pong, + 'subscribe': self._handle_subscription_response + }) + + # Subscription tracking + self.subscribed_topics = set() + self.req_id = 1 + + logger.info(f"Bybit connector initialized ({'testnet' if use_testnet else 'mainnet'})") + + def _get_message_type(self, data: Dict) -> str: + """ + Determine message type from Bybit message data. + + Args: + data: Parsed message data + + Returns: + str: Message type identifier + """ + # Bybit V5 API message format + if 'topic' in data: + topic = data['topic'] + if 'orderbook' in topic: + return 'orderbook' + elif 'publicTrade' in topic: + return 'publicTrade' + else: + return topic + elif 'op' in data: + return data['op'] # 'subscribe', 'unsubscribe', 'ping', 'pong' + elif 'success' in data: + return 'response' + + return 'unknown' + + def normalize_symbol(self, symbol: str) -> str: + """ + Normalize symbol to Bybit format. + + Args: + symbol: Standard symbol format (e.g., 'BTCUSDT') + + Returns: + str: Bybit symbol format (e.g., 'BTCUSDT') + """ + # Bybit uses uppercase symbols without separators (same as Binance) + normalized = symbol.upper().replace('-', '').replace('/', '') + + # Validate symbol format + if not validate_symbol(normalized): + raise ValidationError(f"Invalid symbol format: {symbol}", "INVALID_SYMBOL") + + return normalized + + async def subscribe_orderbook(self, symbol: str) -> None: + """ + Subscribe to order book updates for a symbol. + + Args: + symbol: Trading symbol (e.g., 'BTCUSDT') + """ + try: + set_correlation_id() + normalized_symbol = self.normalize_symbol(symbol) + topic = f"orderbook.50.{normalized_symbol}" + + # Create subscription message + subscription_msg = { + "op": "subscribe", + "args": [topic], + "req_id": str(self.req_id) + } + self.req_id += 1 + + # Send subscription + success = await self._send_message(subscription_msg) + if success: + # Track subscription + if symbol not in self.subscriptions: + self.subscriptions[symbol] = [] + if 'orderbook' not in self.subscriptions[symbol]: + self.subscriptions[symbol].append('orderbook') + + self.subscribed_topics.add(topic) + + logger.info(f"Subscribed to order book for {symbol} on Bybit") + else: + logger.error(f"Failed to subscribe to order book for {symbol} on Bybit") + + except Exception as e: + logger.error(f"Error subscribing to order book for {symbol}: {e}") + raise + + async def subscribe_trades(self, symbol: str) -> None: + """ + Subscribe to trade updates for a symbol. + + Args: + symbol: Trading symbol (e.g., 'BTCUSDT') + """ + try: + set_correlation_id() + normalized_symbol = self.normalize_symbol(symbol) + topic = f"publicTrade.{normalized_symbol}" + + # Create subscription message + subscription_msg = { + "op": "subscribe", + "args": [topic], + "req_id": str(self.req_id) + } + self.req_id += 1 + + # Send subscription + success = await self._send_message(subscription_msg) + if success: + # Track subscription + if symbol not in self.subscriptions: + self.subscriptions[symbol] = [] + if 'trades' not in self.subscriptions[symbol]: + self.subscriptions[symbol].append('trades') + + self.subscribed_topics.add(topic) + + logger.info(f"Subscribed to trades for {symbol} on Bybit") + else: + logger.error(f"Failed to subscribe to trades for {symbol} on Bybit") + + except Exception as e: + logger.error(f"Error subscribing to trades for {symbol}: {e}") + raise + + async def unsubscribe_orderbook(self, symbol: str) -> None: + """ + Unsubscribe from order book updates for a symbol. + + Args: + symbol: Trading symbol (e.g., 'BTCUSDT') + """ + try: + normalized_symbol = self.normalize_symbol(symbol) + topic = f"orderbook.50.{normalized_symbol}" + + # Create unsubscription message + unsubscription_msg = { + "op": "unsubscribe", + "args": [topic], + "req_id": str(self.req_id) + } + self.req_id += 1 + + # Send unsubscription + success = await self._send_message(unsubscription_msg) + if success: + # Remove from tracking + if symbol in self.subscriptions and 'orderbook' in self.subscriptions[symbol]: + self.subscriptions[symbol].remove('orderbook') + if not self.subscriptions[symbol]: + del self.subscriptions[symbol] + + self.subscribed_topics.discard(topic) + + logger.info(f"Unsubscribed from order book for {symbol} on Bybit") + else: + logger.error(f"Failed to unsubscribe from order book for {symbol} on Bybit") + + except Exception as e: + logger.error(f"Error unsubscribing from order book for {symbol}: {e}") + raise + + async def unsubscribe_trades(self, symbol: str) -> None: + """ + Unsubscribe from trade updates for a symbol. + + Args: + symbol: Trading symbol (e.g., 'BTCUSDT') + """ + try: + normalized_symbol = self.normalize_symbol(symbol) + topic = f"publicTrade.{normalized_symbol}" + + # Create unsubscription message + unsubscription_msg = { + "op": "unsubscribe", + "args": [topic], + "req_id": str(self.req_id) + } + self.req_id += 1 + + # Send unsubscription + success = await self._send_message(unsubscription_msg) + if success: + # Remove from tracking + if symbol in self.subscriptions and 'trades' in self.subscriptions[symbol]: + self.subscriptions[symbol].remove('trades') + if not self.subscriptions[symbol]: + del self.subscriptions[symbol] + + self.subscribed_topics.discard(topic) + + logger.info(f"Unsubscribed from trades for {symbol} on Bybit") + else: + logger.error(f"Failed to unsubscribe from trades for {symbol} on Bybit") + + except Exception as e: + logger.error(f"Error unsubscribing from trades for {symbol}: {e}") + raise + + async def get_symbols(self) -> List[str]: + """ + Get list of available trading symbols from Bybit. + + Returns: + List[str]: List of available symbols + """ + try: + import aiohttp + + api_url = "https://api-testnet.bybit.com" if self.use_testnet else self.API_URL + + async with aiohttp.ClientSession() as session: + async with session.get(f"{api_url}/v5/market/instruments-info", + params={"category": "spot"}) as response: + if response.status == 200: + data = await response.json() + + if data.get('retCode') != 0: + logger.error(f"Bybit API error: {data.get('retMsg')}") + return [] + + symbols = [] + instruments = data.get('result', {}).get('list', []) + + for instrument in instruments: + if instrument.get('status') == 'Trading': + symbol = instrument.get('symbol', '') + symbols.append(symbol) + + logger.info(f"Retrieved {len(symbols)} symbols from Bybit") + return symbols + else: + logger.error(f"Failed to get symbols from Bybit: HTTP {response.status}") + return [] + + except Exception as e: + logger.error(f"Error getting symbols from Bybit: {e}") + return [] + + async def get_orderbook_snapshot(self, symbol: str, depth: int = 20) -> Optional[OrderBookSnapshot]: + """ + Get current order book snapshot from Bybit REST API. + + Args: + symbol: Trading symbol + depth: Number of price levels to retrieve + + Returns: + OrderBookSnapshot: Current order book or None if unavailable + """ + try: + import aiohttp + + normalized_symbol = self.normalize_symbol(symbol) + api_url = "https://api-testnet.bybit.com" if self.use_testnet else self.API_URL + + # Bybit supports depths: 1, 25, 50, 100, 200 + valid_depths = [1, 25, 50, 100, 200] + api_depth = min(valid_depths, key=lambda x: abs(x - depth)) + + url = f"{api_url}/v5/market/orderbook" + params = { + 'category': 'spot', + 'symbol': normalized_symbol, + 'limit': api_depth + } + + async with aiohttp.ClientSession() as session: + async with session.get(url, params=params) as response: + if response.status == 200: + data = await response.json() + + if data.get('retCode') != 0: + logger.error(f"Bybit API error: {data.get('retMsg')}") + return None + + result = data.get('result', {}) + return self._parse_orderbook_snapshot(result, symbol) + else: + logger.error(f"Failed to get order book for {symbol}: HTTP {response.status}") + return None + + except Exception as e: + logger.error(f"Error getting order book snapshot for {symbol}: {e}") + return None + + def _parse_orderbook_snapshot(self, data: Dict, symbol: str) -> OrderBookSnapshot: + """ + Parse Bybit order book data into OrderBookSnapshot. + + Args: + data: Raw Bybit order book data + symbol: Trading symbol + + Returns: + OrderBookSnapshot: Parsed order book + """ + try: + # Parse bids and asks + bids = [] + for bid_data in data.get('b', []): + price = float(bid_data[0]) + size = float(bid_data[1]) + + if validate_price(price) and validate_volume(size): + bids.append(PriceLevel(price=price, size=size)) + + asks = [] + for ask_data in data.get('a', []): + price = float(ask_data[0]) + size = float(ask_data[1]) + + if validate_price(price) and validate_volume(size): + asks.append(PriceLevel(price=price, size=size)) + + # Create order book snapshot + orderbook = OrderBookSnapshot( + symbol=symbol, + exchange=self.exchange_name, + timestamp=datetime.fromtimestamp(int(data.get('ts', 0)) / 1000, tz=timezone.utc), + bids=bids, + asks=asks, + sequence_id=int(data.get('u', 0)) + ) + + return orderbook + + except Exception as e: + logger.error(f"Error parsing order book snapshot: {e}") + raise ValidationError(f"Invalid order book data: {e}", "PARSE_ERROR") + + async def _handle_orderbook_update(self, data: Dict) -> None: + """ + Handle order book update from Bybit. + + Args: + data: Order book update data + """ + try: + set_correlation_id() + + # Extract symbol from topic + topic = data.get('topic', '') + if not topic.startswith('orderbook'): + logger.warning("Invalid orderbook topic") + return + + # Extract symbol from topic: orderbook.50.BTCUSDT + parts = topic.split('.') + if len(parts) < 3: + logger.warning("Invalid orderbook topic format") + return + + symbol = parts[2] + orderbook_data = data.get('data', {}) + + # Parse bids and asks + bids = [] + for bid_data in orderbook_data.get('b', []): + price = float(bid_data[0]) + size = float(bid_data[1]) + + if validate_price(price) and validate_volume(size): + bids.append(PriceLevel(price=price, size=size)) + + asks = [] + for ask_data in orderbook_data.get('a', []): + price = float(ask_data[0]) + size = float(ask_data[1]) + + if validate_price(price) and validate_volume(size): + asks.append(PriceLevel(price=price, size=size)) + + # Create order book snapshot + orderbook = OrderBookSnapshot( + symbol=symbol, + exchange=self.exchange_name, + timestamp=datetime.fromtimestamp(int(data.get('ts', 0)) / 1000, tz=timezone.utc), + bids=bids, + asks=asks, + sequence_id=int(orderbook_data.get('u', 0)) + ) + + # Notify callbacks + self._notify_data_callbacks(orderbook) + + logger.debug(f"Processed order book update for {symbol}") + + except Exception as e: + logger.error(f"Error handling order book update: {e}") + + async def _handle_trade_update(self, data: Dict) -> None: + """ + Handle trade update from Bybit. + + Args: + data: Trade update data + """ + try: + set_correlation_id() + + # Extract symbol from topic + topic = data.get('topic', '') + if not topic.startswith('publicTrade'): + logger.warning("Invalid trade topic") + return + + # Extract symbol from topic: publicTrade.BTCUSDT + parts = topic.split('.') + if len(parts) < 2: + logger.warning("Invalid trade topic format") + return + + symbol = parts[1] + trades_data = data.get('data', []) + + # Process each trade + for trade_data in trades_data: + price = float(trade_data.get('p', 0)) + size = float(trade_data.get('v', 0)) + + # Validate data + if not validate_price(price) or not validate_volume(size): + logger.warning(f"Invalid trade data: price={price}, size={size}") + continue + + # Determine side (Bybit uses 'S' field) + side_flag = trade_data.get('S', '') + side = 'buy' if side_flag == 'Buy' else 'sell' + + # Create trade event + trade = TradeEvent( + symbol=symbol, + exchange=self.exchange_name, + timestamp=datetime.fromtimestamp(int(trade_data.get('T', 0)) / 1000, tz=timezone.utc), + price=price, + size=size, + side=side, + trade_id=str(trade_data.get('i', '')) + ) + + # Notify callbacks + self._notify_data_callbacks(trade) + + logger.debug(f"Processed trade for {symbol}: {side} {size} @ {price}") + + except Exception as e: + logger.error(f"Error handling trade update: {e}") + + async def _handle_subscription_response(self, data: Dict) -> None: + """ + Handle subscription response from Bybit. + + Args: + data: Subscription response data + """ + try: + success = data.get('success', False) + req_id = data.get('req_id', '') + op = data.get('op', '') + + if success: + logger.info(f"Bybit {op} successful (req_id: {req_id})") + else: + ret_msg = data.get('ret_msg', 'Unknown error') + logger.error(f"Bybit {op} failed: {ret_msg} (req_id: {req_id})") + + except Exception as e: + logger.error(f"Error handling subscription response: {e}") + + async def _handle_pong(self, data: Dict) -> None: + """ + Handle pong response from Bybit. + + Args: + data: Pong response data + """ + logger.debug("Received Bybit pong") + + def _get_auth_signature(self, timestamp: str, recv_window: str = "5000") -> str: + """ + Generate authentication signature for Bybit. + + Args: + timestamp: Current timestamp + recv_window: Receive window + + Returns: + str: Authentication signature + """ + if not self.api_key or not self.api_secret: + return "" + + try: + param_str = f"GET/realtime{timestamp}{self.api_key}{recv_window}" + signature = hmac.new( + self.api_secret.encode('utf-8'), + param_str.encode('utf-8'), + hashlib.sha256 + ).hexdigest() + + return signature + + except Exception as e: + logger.error(f"Error generating auth signature: {e}") + return "" + + async def _send_ping(self) -> None: + """Send ping to keep connection alive.""" + try: + ping_msg = { + "op": "ping", + "req_id": str(self.req_id) + } + self.req_id += 1 + + await self._send_message(ping_msg) + logger.debug("Sent ping to Bybit") + + except Exception as e: + logger.error(f"Error sending ping: {e}") + + def get_bybit_stats(self) -> Dict[str, Any]: + """Get Bybit-specific statistics.""" + base_stats = self.get_stats() + + bybit_stats = { + 'subscribed_topics': list(self.subscribed_topics), + 'use_testnet': self.use_testnet, + 'authenticated': bool(self.api_key and self.api_secret), + 'next_req_id': self.req_id + } + + base_stats.update(bybit_stats) + return base_stats \ No newline at end of file diff --git a/COBY/connectors/huobi_connector.py b/COBY/connectors/huobi_connector.py new file mode 100644 index 0000000..3654c7a --- /dev/null +++ b/COBY/connectors/huobi_connector.py @@ -0,0 +1,660 @@ +""" +Huobi Global exchange connector implementation. +Supports WebSocket connections to Huobi with proper symbol mapping. +""" + +import json +import gzip +import hmac +import hashlib +import base64 +import time +from typing import Dict, List, Optional, Any +from datetime import datetime, timezone + +from ..models.core import OrderBookSnapshot, TradeEvent, PriceLevel +from ..utils.logging import get_logger, set_correlation_id +from ..utils.exceptions import ValidationError, ConnectionError +from ..utils.validation import validate_symbol, validate_price, validate_volume +from .base_connector import BaseExchangeConnector + +logger = get_logger(__name__) + + +class HuobiConnector(BaseExchangeConnector): + """ + Huobi Global WebSocket connector implementation. + + Supports: + - Order book streams + - Trade streams + - Symbol normalization + - GZIP message decompression + - Authentication for private channels + """ + + # Huobi WebSocket URLs + WEBSOCKET_URL = "wss://api.huobi.pro/ws" + WEBSOCKET_PRIVATE_URL = "wss://api.huobi.pro/ws/v2" + API_URL = "https://api.huobi.pro" + + def __init__(self, api_key: str = None, api_secret: str = None): + """ + Initialize Huobi connector. + + Args: + api_key: API key for authentication (optional) + api_secret: API secret for authentication (optional) + """ + super().__init__("huobi", self.WEBSOCKET_URL) + + # Authentication credentials (optional) + self.api_key = api_key + self.api_secret = api_secret + + # Huobi-specific message handlers + self.message_handlers.update({ + 'market.*.depth.step0': self._handle_orderbook_update, + 'market.*.trade.detail': self._handle_trade_update, + 'ping': self._handle_ping, + 'pong': self._handle_pong + }) + + # Subscription tracking + self.subscribed_topics = set() + + logger.info("Huobi connector initialized") + + def _get_message_type(self, data: Dict) -> str: + """ + Determine message type from Huobi message data. + + Args: + data: Parsed message data + + Returns: + str: Message type identifier + """ + # Huobi message format + if 'ping' in data: + return 'ping' + elif 'pong' in data: + return 'pong' + elif 'ch' in data: + # Data channel message + channel = data['ch'] + if 'depth' in channel: + return 'market.*.depth.step0' + elif 'trade' in channel: + return 'market.*.trade.detail' + else: + return channel + elif 'subbed' in data: + return 'subscription_response' + elif 'unsubbed' in data: + return 'unsubscription_response' + elif 'status' in data and data.get('status') == 'error': + return 'error' + + return 'unknown' + + def normalize_symbol(self, symbol: str) -> str: + """ + Normalize symbol to Huobi format. + + Args: + symbol: Standard symbol format (e.g., 'BTCUSDT') + + Returns: + str: Huobi symbol format (e.g., 'btcusdt') + """ + # Huobi uses lowercase symbols + normalized = symbol.lower().replace('-', '').replace('/', '') + + # Validate symbol format + if not validate_symbol(normalized.upper()): + raise ValidationError(f"Invalid symbol format: {symbol}", "INVALID_SYMBOL") + + return normalized + + def _denormalize_symbol(self, huobi_symbol: str) -> str: + """ + Convert Huobi symbol back to standard format. + + Args: + huobi_symbol: Huobi symbol format (e.g., 'btcusdt') + + Returns: + str: Standard symbol format (e.g., 'BTCUSDT') + """ + return huobi_symbol.upper() + + async def _decompress_message(self, message: bytes) -> str: + """ + Decompress GZIP message from Huobi. + + Args: + message: Compressed message bytes + + Returns: + str: Decompressed message string + """ + try: + return gzip.decompress(message).decode('utf-8') + except Exception as e: + logger.error(f"Error decompressing message: {e}") + return "" + + async def _process_message(self, message: str) -> None: + """ + Override message processing to handle GZIP compression. + + Args: + message: Raw message (could be compressed) + """ + try: + # Check if message is compressed (binary) + if isinstance(message, bytes): + message = await self._decompress_message(message) + + if not message: + return + + # Parse JSON message + data = json.loads(message) + + # Handle ping/pong first + if 'ping' in data: + await self._handle_ping(data) + return + + # Determine message type and route to appropriate handler + message_type = self._get_message_type(data) + + if message_type in self.message_handlers: + await self.message_handlers[message_type](data) + else: + logger.debug(f"Unhandled message type '{message_type}' from {self.exchange_name}") + + except json.JSONDecodeError as e: + logger.warning(f"Invalid JSON message from {self.exchange_name}: {e}") + except Exception as e: + logger.error(f"Error processing message from {self.exchange_name}: {e}") + + async def subscribe_orderbook(self, symbol: str) -> None: + """ + Subscribe to order book updates for a symbol. + + Args: + symbol: Trading symbol (e.g., 'BTCUSDT') + """ + try: + set_correlation_id() + huobi_symbol = self.normalize_symbol(symbol) + topic = f"market.{huobi_symbol}.depth.step0" + + # Create subscription message + subscription_msg = { + "sub": topic, + "id": str(int(time.time())) + } + + # Send subscription + success = await self._send_message(subscription_msg) + if success: + # Track subscription + if symbol not in self.subscriptions: + self.subscriptions[symbol] = [] + if 'orderbook' not in self.subscriptions[symbol]: + self.subscriptions[symbol].append('orderbook') + + self.subscribed_topics.add(topic) + + logger.info(f"Subscribed to order book for {symbol} ({huobi_symbol}) on Huobi") + else: + logger.error(f"Failed to subscribe to order book for {symbol} on Huobi") + + except Exception as e: + logger.error(f"Error subscribing to order book for {symbol}: {e}") + raise + + async def subscribe_trades(self, symbol: str) -> None: + """ + Subscribe to trade updates for a symbol. + + Args: + symbol: Trading symbol (e.g., 'BTCUSDT') + """ + try: + set_correlation_id() + huobi_symbol = self.normalize_symbol(symbol) + topic = f"market.{huobi_symbol}.trade.detail" + + # Create subscription message + subscription_msg = { + "sub": topic, + "id": str(int(time.time())) + } + + # Send subscription + success = await self._send_message(subscription_msg) + if success: + # Track subscription + if symbol not in self.subscriptions: + self.subscriptions[symbol] = [] + if 'trades' not in self.subscriptions[symbol]: + self.subscriptions[symbol].append('trades') + + self.subscribed_topics.add(topic) + + logger.info(f"Subscribed to trades for {symbol} ({huobi_symbol}) on Huobi") + else: + logger.error(f"Failed to subscribe to trades for {symbol} on Huobi") + + except Exception as e: + logger.error(f"Error subscribing to trades for {symbol}: {e}") + raise + + async def unsubscribe_orderbook(self, symbol: str) -> None: + """ + Unsubscribe from order book updates for a symbol. + + Args: + symbol: Trading symbol (e.g., 'BTCUSDT') + """ + try: + huobi_symbol = self.normalize_symbol(symbol) + topic = f"market.{huobi_symbol}.depth.step0" + + # Create unsubscription message + unsubscription_msg = { + "unsub": topic, + "id": str(int(time.time())) + } + + # Send unsubscription + success = await self._send_message(unsubscription_msg) + if success: + # Remove from tracking + if symbol in self.subscriptions and 'orderbook' in self.subscriptions[symbol]: + self.subscriptions[symbol].remove('orderbook') + if not self.subscriptions[symbol]: + del self.subscriptions[symbol] + + self.subscribed_topics.discard(topic) + + logger.info(f"Unsubscribed from order book for {symbol} ({huobi_symbol}) on Huobi") + else: + logger.error(f"Failed to unsubscribe from order book for {symbol} on Huobi") + + except Exception as e: + logger.error(f"Error unsubscribing from order book for {symbol}: {e}") + raise + + async def unsubscribe_trades(self, symbol: str) -> None: + """ + Unsubscribe from trade updates for a symbol. + + Args: + symbol: Trading symbol (e.g., 'BTCUSDT') + """ + try: + huobi_symbol = self.normalize_symbol(symbol) + topic = f"market.{huobi_symbol}.trade.detail" + + # Create unsubscription message + unsubscription_msg = { + "unsub": topic, + "id": str(int(time.time())) + } + + # Send unsubscription + success = await self._send_message(unsubscription_msg) + if success: + # Remove from tracking + if symbol in self.subscriptions and 'trades' in self.subscriptions[symbol]: + self.subscriptions[symbol].remove('trades') + if not self.subscriptions[symbol]: + del self.subscriptions[symbol] + + self.subscribed_topics.discard(topic) + + logger.info(f"Unsubscribed from trades for {symbol} ({huobi_symbol}) on Huobi") + else: + logger.error(f"Failed to unsubscribe from trades for {symbol} on Huobi") + + except Exception as e: + logger.error(f"Error unsubscribing from trades for {symbol}: {e}") + raise + + async def get_symbols(self) -> List[str]: + """ + Get list of available trading symbols from Huobi. + + Returns: + List[str]: List of available symbols in standard format + """ + try: + import aiohttp + + async with aiohttp.ClientSession() as session: + async with session.get(f"{self.API_URL}/v1/common/symbols") as response: + if response.status == 200: + data = await response.json() + + if data.get('status') != 'ok': + logger.error(f"Huobi API error: {data}") + return [] + + symbols = [] + symbol_data = data.get('data', []) + + for symbol_info in symbol_data: + if symbol_info.get('state') == 'online': + symbol = symbol_info.get('symbol', '') + # Convert to standard format + standard_symbol = self._denormalize_symbol(symbol) + symbols.append(standard_symbol) + + logger.info(f"Retrieved {len(symbols)} symbols from Huobi") + return symbols + else: + logger.error(f"Failed to get symbols from Huobi: HTTP {response.status}") + return [] + + except Exception as e: + logger.error(f"Error getting symbols from Huobi: {e}") + return [] + + async def get_orderbook_snapshot(self, symbol: str, depth: int = 20) -> Optional[OrderBookSnapshot]: + """ + Get current order book snapshot from Huobi REST API. + + Args: + symbol: Trading symbol + depth: Number of price levels to retrieve + + Returns: + OrderBookSnapshot: Current order book or None if unavailable + """ + try: + import aiohttp + + huobi_symbol = self.normalize_symbol(symbol) + + # Huobi supports depths: 5, 10, 20 + valid_depths = [5, 10, 20] + api_depth = min(valid_depths, key=lambda x: abs(x - depth)) + + url = f"{self.API_URL}/market/depth" + params = { + 'symbol': huobi_symbol, + 'depth': api_depth, + 'type': 'step0' + } + + async with aiohttp.ClientSession() as session: + async with session.get(url, params=params) as response: + if response.status == 200: + data = await response.json() + + if data.get('status') != 'ok': + logger.error(f"Huobi API error: {data}") + return None + + tick_data = data.get('tick', {}) + return self._parse_orderbook_snapshot(tick_data, symbol) + else: + logger.error(f"Failed to get order book for {symbol}: HTTP {response.status}") + return None + + except Exception as e: + logger.error(f"Error getting order book snapshot for {symbol}: {e}") + return None + + def _parse_orderbook_snapshot(self, data: Dict, symbol: str) -> OrderBookSnapshot: + """ + Parse Huobi order book data into OrderBookSnapshot. + + Args: + data: Raw Huobi order book data + symbol: Trading symbol + + Returns: + OrderBookSnapshot: Parsed order book + """ + try: + # Parse bids and asks + bids = [] + for bid_data in data.get('bids', []): + price = float(bid_data[0]) + size = float(bid_data[1]) + + if validate_price(price) and validate_volume(size): + bids.append(PriceLevel(price=price, size=size)) + + asks = [] + for ask_data in data.get('asks', []): + price = float(ask_data[0]) + size = float(ask_data[1]) + + if validate_price(price) and validate_volume(size): + asks.append(PriceLevel(price=price, size=size)) + + # Create order book snapshot + orderbook = OrderBookSnapshot( + symbol=symbol, + exchange=self.exchange_name, + timestamp=datetime.fromtimestamp(int(data.get('ts', 0)) / 1000, tz=timezone.utc), + bids=bids, + asks=asks, + sequence_id=data.get('version') + ) + + return orderbook + + except Exception as e: + logger.error(f"Error parsing order book snapshot: {e}") + raise ValidationError(f"Invalid order book data: {e}", "PARSE_ERROR") + + async def _handle_orderbook_update(self, data: Dict) -> None: + """ + Handle order book update from Huobi. + + Args: + data: Order book update data + """ + try: + set_correlation_id() + + # Extract symbol from channel + channel = data.get('ch', '') + if not channel: + logger.warning("Order book update missing channel") + return + + # Parse channel: market.btcusdt.depth.step0 + parts = channel.split('.') + if len(parts) < 2: + logger.warning("Invalid order book channel format") + return + + huobi_symbol = parts[1] + symbol = self._denormalize_symbol(huobi_symbol) + + tick_data = data.get('tick', {}) + + # Parse bids and asks + bids = [] + for bid_data in tick_data.get('bids', []): + price = float(bid_data[0]) + size = float(bid_data[1]) + + if validate_price(price) and validate_volume(size): + bids.append(PriceLevel(price=price, size=size)) + + asks = [] + for ask_data in tick_data.get('asks', []): + price = float(ask_data[0]) + size = float(ask_data[1]) + + if validate_price(price) and validate_volume(size): + asks.append(PriceLevel(price=price, size=size)) + + # Create order book snapshot + orderbook = OrderBookSnapshot( + symbol=symbol, + exchange=self.exchange_name, + timestamp=datetime.fromtimestamp(int(tick_data.get('ts', 0)) / 1000, tz=timezone.utc), + bids=bids, + asks=asks, + sequence_id=tick_data.get('version') + ) + + # Notify callbacks + self._notify_data_callbacks(orderbook) + + logger.debug(f"Processed order book update for {symbol}") + + except Exception as e: + logger.error(f"Error handling order book update: {e}") + + async def _handle_trade_update(self, data: Dict) -> None: + """ + Handle trade update from Huobi. + + Args: + data: Trade update data + """ + try: + set_correlation_id() + + # Extract symbol from channel + channel = data.get('ch', '') + if not channel: + logger.warning("Trade update missing channel") + return + + # Parse channel: market.btcusdt.trade.detail + parts = channel.split('.') + if len(parts) < 2: + logger.warning("Invalid trade channel format") + return + + huobi_symbol = parts[1] + symbol = self._denormalize_symbol(huobi_symbol) + + tick_data = data.get('tick', {}) + trades_data = tick_data.get('data', []) + + # Process each trade + for trade_data in trades_data: + price = float(trade_data.get('price', 0)) + amount = float(trade_data.get('amount', 0)) + + # Validate data + if not validate_price(price) or not validate_volume(amount): + logger.warning(f"Invalid trade data: price={price}, amount={amount}") + continue + + # Determine side (Huobi uses 'direction' field) + direction = trade_data.get('direction', 'unknown') + side = 'buy' if direction == 'buy' else 'sell' + + # Create trade event + trade = TradeEvent( + symbol=symbol, + exchange=self.exchange_name, + timestamp=datetime.fromtimestamp(int(trade_data.get('ts', 0)) / 1000, tz=timezone.utc), + price=price, + size=amount, + side=side, + trade_id=str(trade_data.get('tradeId', trade_data.get('id', ''))) + ) + + # Notify callbacks + self._notify_data_callbacks(trade) + + logger.debug(f"Processed trade for {symbol}: {side} {amount} @ {price}") + + except Exception as e: + logger.error(f"Error handling trade update: {e}") + + async def _handle_ping(self, data: Dict) -> None: + """ + Handle ping message from Huobi and respond with pong. + + Args: + data: Ping message data + """ + try: + ping_value = data.get('ping') + if ping_value: + # Respond with pong + pong_msg = {"pong": ping_value} + await self._send_message(pong_msg) + logger.debug(f"Responded to Huobi ping with pong: {ping_value}") + + except Exception as e: + logger.error(f"Error handling ping: {e}") + + async def _handle_pong(self, data: Dict) -> None: + """ + Handle pong response from Huobi. + + Args: + data: Pong response data + """ + logger.debug("Received Huobi pong") + + def _get_auth_signature(self, method: str, host: str, path: str, + params: Dict[str, str]) -> str: + """ + Generate authentication signature for Huobi. + + Args: + method: HTTP method + host: API host + path: Request path + params: Request parameters + + Returns: + str: Authentication signature + """ + if not self.api_key or not self.api_secret: + return "" + + try: + # Sort parameters + sorted_params = sorted(params.items()) + query_string = '&'.join([f"{k}={v}" for k, v in sorted_params]) + + # Create signature string + signature_string = f"{method}\n{host}\n{path}\n{query_string}" + + # Generate signature + signature = base64.b64encode( + hmac.new( + self.api_secret.encode('utf-8'), + signature_string.encode('utf-8'), + hashlib.sha256 + ).digest() + ).decode('utf-8') + + return signature + + except Exception as e: + logger.error(f"Error generating auth signature: {e}") + return "" + + def get_huobi_stats(self) -> Dict[str, Any]: + """Get Huobi-specific statistics.""" + base_stats = self.get_stats() + + huobi_stats = { + 'subscribed_topics': list(self.subscribed_topics), + 'authenticated': bool(self.api_key and self.api_secret) + } + + base_stats.update(huobi_stats) + return base_stats \ No newline at end of file diff --git a/COBY/connectors/okx_connector.py b/COBY/connectors/okx_connector.py new file mode 100644 index 0000000..7ff4a63 --- /dev/null +++ b/COBY/connectors/okx_connector.py @@ -0,0 +1,660 @@ +""" +OKX exchange connector implementation. +Supports WebSocket connections to OKX with their V5 API WebSocket streams. +""" + +import json +import hmac +import hashlib +import base64 +import time +from typing import Dict, List, Optional, Any +from datetime import datetime, timezone + +from ..models.core import OrderBookSnapshot, TradeEvent, PriceLevel +from ..utils.logging import get_logger, set_correlation_id +from ..utils.exceptions import ValidationError, ConnectionError +from ..utils.validation import validate_symbol, validate_price, validate_volume +from .base_connector import BaseExchangeConnector + +logger = get_logger(__name__) + + +class OKXConnector(BaseExchangeConnector): + """ + OKX WebSocket connector implementation. + + Supports: + - V5 API WebSocket streams + - Order book streams + - Trade streams + - Symbol normalization + - Authentication for private channels + """ + + # OKX WebSocket URLs + WEBSOCKET_URL = "wss://ws.okx.com:8443/ws/v5/public" + WEBSOCKET_PRIVATE_URL = "wss://ws.okx.com:8443/ws/v5/private" + DEMO_WEBSOCKET_URL = "wss://wspap.okx.com:8443/ws/v5/public?brokerId=9999" + API_URL = "https://www.okx.com" + + def __init__(self, use_demo: bool = False, api_key: str = None, + api_secret: str = None, passphrase: str = None): + """ + Initialize OKX connector. + + Args: + use_demo: Whether to use demo environment + api_key: API key for authentication (optional) + api_secret: API secret for authentication (optional) + passphrase: API passphrase for authentication (optional) + """ + websocket_url = self.DEMO_WEBSOCKET_URL if use_demo else self.WEBSOCKET_URL + super().__init__("okx", websocket_url) + + # Authentication credentials (optional) + self.api_key = api_key + self.api_secret = api_secret + self.passphrase = passphrase + self.use_demo = use_demo + + # OKX-specific message handlers + self.message_handlers.update({ + 'books': self._handle_orderbook_update, + 'trades': self._handle_trade_update, + 'error': self._handle_error_message, + 'subscribe': self._handle_subscription_response, + 'unsubscribe': self._handle_subscription_response + }) + + # Subscription tracking + self.subscribed_channels = set() + + logger.info(f"OKX connector initialized ({'demo' if use_demo else 'live'})") + + def _get_message_type(self, data: Dict) -> str: + """ + Determine message type from OKX message data. + + Args: + data: Parsed message data + + Returns: + str: Message type identifier + """ + # OKX V5 API message format + if 'event' in data: + return data['event'] # 'subscribe', 'unsubscribe', 'error' + elif 'arg' in data and 'data' in data: + # Data message + channel = data['arg'].get('channel', '') + return channel + elif 'op' in data: + return data['op'] # 'ping', 'pong' + + return 'unknown' + + def normalize_symbol(self, symbol: str) -> str: + """ + Normalize symbol to OKX format. + + Args: + symbol: Standard symbol format (e.g., 'BTCUSDT') + + Returns: + str: OKX symbol format (e.g., 'BTC-USDT') + """ + # OKX uses dash-separated format + if symbol.upper() == 'BTCUSDT': + return 'BTC-USDT' + elif symbol.upper() == 'ETHUSDT': + return 'ETH-USDT' + elif symbol.upper().endswith('USDT'): + base = symbol[:-4].upper() + return f"{base}-USDT" + elif symbol.upper().endswith('USD'): + base = symbol[:-3].upper() + return f"{base}-USD" + else: + # Assume it's already in correct format or add dash + if '-' not in symbol: + # Try to split common patterns + if len(symbol) >= 6: + # Assume last 4 chars are quote currency + base = symbol[:-4].upper() + quote = symbol[-4:].upper() + return f"{base}-{quote}" + else: + return symbol.upper() + else: + return symbol.upper() + + def _denormalize_symbol(self, okx_symbol: str) -> str: + """ + Convert OKX symbol back to standard format. + + Args: + okx_symbol: OKX symbol format (e.g., 'BTC-USDT') + + Returns: + str: Standard symbol format (e.g., 'BTCUSDT') + """ + if '-' in okx_symbol: + return okx_symbol.replace('-', '') + return okx_symbol + + async def subscribe_orderbook(self, symbol: str) -> None: + """ + Subscribe to order book updates for a symbol. + + Args: + symbol: Trading symbol (e.g., 'BTCUSDT') + """ + try: + set_correlation_id() + okx_symbol = self.normalize_symbol(symbol) + + # Create subscription message + subscription_msg = { + "op": "subscribe", + "args": [ + { + "channel": "books", + "instId": okx_symbol + } + ] + } + + # Send subscription + success = await self._send_message(subscription_msg) + if success: + # Track subscription + if symbol not in self.subscriptions: + self.subscriptions[symbol] = [] + if 'orderbook' not in self.subscriptions[symbol]: + self.subscriptions[symbol].append('orderbook') + + self.subscribed_channels.add(f"books:{okx_symbol}") + + logger.info(f"Subscribed to order book for {symbol} ({okx_symbol}) on OKX") + else: + logger.error(f"Failed to subscribe to order book for {symbol} on OKX") + + except Exception as e: + logger.error(f"Error subscribing to order book for {symbol}: {e}") + raise + + async def subscribe_trades(self, symbol: str) -> None: + """ + Subscribe to trade updates for a symbol. + + Args: + symbol: Trading symbol (e.g., 'BTCUSDT') + """ + try: + set_correlation_id() + okx_symbol = self.normalize_symbol(symbol) + + # Create subscription message + subscription_msg = { + "op": "subscribe", + "args": [ + { + "channel": "trades", + "instId": okx_symbol + } + ] + } + + # Send subscription + success = await self._send_message(subscription_msg) + if success: + # Track subscription + if symbol not in self.subscriptions: + self.subscriptions[symbol] = [] + if 'trades' not in self.subscriptions[symbol]: + self.subscriptions[symbol].append('trades') + + self.subscribed_channels.add(f"trades:{okx_symbol}") + + logger.info(f"Subscribed to trades for {symbol} ({okx_symbol}) on OKX") + else: + logger.error(f"Failed to subscribe to trades for {symbol} on OKX") + + except Exception as e: + logger.error(f"Error subscribing to trades for {symbol}: {e}") + raise + + async def unsubscribe_orderbook(self, symbol: str) -> None: + """ + Unsubscribe from order book updates for a symbol. + + Args: + symbol: Trading symbol (e.g., 'BTCUSDT') + """ + try: + okx_symbol = self.normalize_symbol(symbol) + + # Create unsubscription message + unsubscription_msg = { + "op": "unsubscribe", + "args": [ + { + "channel": "books", + "instId": okx_symbol + } + ] + } + + # Send unsubscription + success = await self._send_message(unsubscription_msg) + if success: + # Remove from tracking + if symbol in self.subscriptions and 'orderbook' in self.subscriptions[symbol]: + self.subscriptions[symbol].remove('orderbook') + if not self.subscriptions[symbol]: + del self.subscriptions[symbol] + + self.subscribed_channels.discard(f"books:{okx_symbol}") + + logger.info(f"Unsubscribed from order book for {symbol} ({okx_symbol}) on OKX") + else: + logger.error(f"Failed to unsubscribe from order book for {symbol} on OKX") + + except Exception as e: + logger.error(f"Error unsubscribing from order book for {symbol}: {e}") + raise + + async def unsubscribe_trades(self, symbol: str) -> None: + """ + Unsubscribe from trade updates for a symbol. + + Args: + symbol: Trading symbol (e.g., 'BTCUSDT') + """ + try: + okx_symbol = self.normalize_symbol(symbol) + + # Create unsubscription message + unsubscription_msg = { + "op": "unsubscribe", + "args": [ + { + "channel": "trades", + "instId": okx_symbol + } + ] + } + + # Send unsubscription + success = await self._send_message(unsubscription_msg) + if success: + # Remove from tracking + if symbol in self.subscriptions and 'trades' in self.subscriptions[symbol]: + self.subscriptions[symbol].remove('trades') + if not self.subscriptions[symbol]: + del self.subscriptions[symbol] + + self.subscribed_channels.discard(f"trades:{okx_symbol}") + + logger.info(f"Unsubscribed from trades for {symbol} ({okx_symbol}) on OKX") + else: + logger.error(f"Failed to unsubscribe from trades for {symbol} on OKX") + + except Exception as e: + logger.error(f"Error unsubscribing from trades for {symbol}: {e}") + raise + + async def get_symbols(self) -> List[str]: + """ + Get list of available trading symbols from OKX. + + Returns: + List[str]: List of available symbols in standard format + """ + try: + import aiohttp + + api_url = "https://www.okx.com" + + async with aiohttp.ClientSession() as session: + async with session.get(f"{api_url}/api/v5/public/instruments", + params={"instType": "SPOT"}) as response: + if response.status == 200: + data = await response.json() + + if data.get('code') != '0': + logger.error(f"OKX API error: {data.get('msg')}") + return [] + + symbols = [] + instruments = data.get('data', []) + + for instrument in instruments: + if instrument.get('state') == 'live': + inst_id = instrument.get('instId', '') + # Convert to standard format + standard_symbol = self._denormalize_symbol(inst_id) + symbols.append(standard_symbol) + + logger.info(f"Retrieved {len(symbols)} symbols from OKX") + return symbols + else: + logger.error(f"Failed to get symbols from OKX: HTTP {response.status}") + return [] + + except Exception as e: + logger.error(f"Error getting symbols from OKX: {e}") + return [] + + async def get_orderbook_snapshot(self, symbol: str, depth: int = 20) -> Optional[OrderBookSnapshot]: + """ + Get current order book snapshot from OKX REST API. + + Args: + symbol: Trading symbol + depth: Number of price levels to retrieve + + Returns: + OrderBookSnapshot: Current order book or None if unavailable + """ + try: + import aiohttp + + okx_symbol = self.normalize_symbol(symbol) + api_url = "https://www.okx.com" + + # OKX supports depths up to 400 + api_depth = min(depth, 400) + + url = f"{api_url}/api/v5/market/books" + params = { + 'instId': okx_symbol, + 'sz': api_depth + } + + async with aiohttp.ClientSession() as session: + async with session.get(url, params=params) as response: + if response.status == 200: + data = await response.json() + + if data.get('code') != '0': + logger.error(f"OKX API error: {data.get('msg')}") + return None + + result_data = data.get('data', []) + if result_data: + return self._parse_orderbook_snapshot(result_data[0], symbol) + else: + return None + else: + logger.error(f"Failed to get order book for {symbol}: HTTP {response.status}") + return None + + except Exception as e: + logger.error(f"Error getting order book snapshot for {symbol}: {e}") + return None + + def _parse_orderbook_snapshot(self, data: Dict, symbol: str) -> OrderBookSnapshot: + """ + Parse OKX order book data into OrderBookSnapshot. + + Args: + data: Raw OKX order book data + symbol: Trading symbol + + Returns: + OrderBookSnapshot: Parsed order book + """ + try: + # Parse bids and asks + bids = [] + for bid_data in data.get('bids', []): + price = float(bid_data[0]) + size = float(bid_data[1]) + + if validate_price(price) and validate_volume(size): + bids.append(PriceLevel(price=price, size=size)) + + asks = [] + for ask_data in data.get('asks', []): + price = float(ask_data[0]) + size = float(ask_data[1]) + + if validate_price(price) and validate_volume(size): + asks.append(PriceLevel(price=price, size=size)) + + # Create order book snapshot + orderbook = OrderBookSnapshot( + symbol=symbol, + exchange=self.exchange_name, + timestamp=datetime.fromtimestamp(int(data.get('ts', 0)) / 1000, tz=timezone.utc), + bids=bids, + asks=asks, + sequence_id=int(data.get('seqId', 0)) + ) + + return orderbook + + except Exception as e: + logger.error(f"Error parsing order book snapshot: {e}") + raise ValidationError(f"Invalid order book data: {e}", "PARSE_ERROR") + + async def _handle_orderbook_update(self, data: Dict) -> None: + """ + Handle order book update from OKX. + + Args: + data: Order book update data + """ + try: + set_correlation_id() + + # Extract symbol from arg + arg = data.get('arg', {}) + okx_symbol = arg.get('instId', '') + if not okx_symbol: + logger.warning("Order book update missing instId") + return + + symbol = self._denormalize_symbol(okx_symbol) + + # Process each data item + for book_data in data.get('data', []): + # Parse bids and asks + bids = [] + for bid_data in book_data.get('bids', []): + price = float(bid_data[0]) + size = float(bid_data[1]) + + if validate_price(price) and validate_volume(size): + bids.append(PriceLevel(price=price, size=size)) + + asks = [] + for ask_data in book_data.get('asks', []): + price = float(ask_data[0]) + size = float(ask_data[1]) + + if validate_price(price) and validate_volume(size): + asks.append(PriceLevel(price=price, size=size)) + + # Create order book snapshot + orderbook = OrderBookSnapshot( + symbol=symbol, + exchange=self.exchange_name, + timestamp=datetime.fromtimestamp(int(book_data.get('ts', 0)) / 1000, tz=timezone.utc), + bids=bids, + asks=asks, + sequence_id=int(book_data.get('seqId', 0)) + ) + + # Notify callbacks + self._notify_data_callbacks(orderbook) + + logger.debug(f"Processed order book update for {symbol}") + + except Exception as e: + logger.error(f"Error handling order book update: {e}") + + async def _handle_trade_update(self, data: Dict) -> None: + """ + Handle trade update from OKX. + + Args: + data: Trade update data + """ + try: + set_correlation_id() + + # Extract symbol from arg + arg = data.get('arg', {}) + okx_symbol = arg.get('instId', '') + if not okx_symbol: + logger.warning("Trade update missing instId") + return + + symbol = self._denormalize_symbol(okx_symbol) + + # Process each trade + for trade_data in data.get('data', []): + price = float(trade_data.get('px', 0)) + size = float(trade_data.get('sz', 0)) + + # Validate data + if not validate_price(price) or not validate_volume(size): + logger.warning(f"Invalid trade data: price={price}, size={size}") + continue + + # Determine side (OKX uses 'side' field) + side = trade_data.get('side', 'unknown').lower() + + # Create trade event + trade = TradeEvent( + symbol=symbol, + exchange=self.exchange_name, + timestamp=datetime.fromtimestamp(int(trade_data.get('ts', 0)) / 1000, tz=timezone.utc), + price=price, + size=size, + side=side, + trade_id=str(trade_data.get('tradeId', '')) + ) + + # Notify callbacks + self._notify_data_callbacks(trade) + + logger.debug(f"Processed trade for {symbol}: {side} {size} @ {price}") + + except Exception as e: + logger.error(f"Error handling trade update: {e}") + + async def _handle_subscription_response(self, data: Dict) -> None: + """ + Handle subscription response from OKX. + + Args: + data: Subscription response data + """ + try: + event = data.get('event', '') + arg = data.get('arg', {}) + channel = arg.get('channel', '') + inst_id = arg.get('instId', '') + + if event == 'subscribe': + logger.info(f"OKX subscription confirmed: {channel} for {inst_id}") + elif event == 'unsubscribe': + logger.info(f"OKX unsubscription confirmed: {channel} for {inst_id}") + elif event == 'error': + error_msg = data.get('msg', 'Unknown error') + logger.error(f"OKX subscription error: {error_msg}") + + except Exception as e: + logger.error(f"Error handling subscription response: {e}") + + async def _handle_error_message(self, data: Dict) -> None: + """ + Handle error message from OKX. + + Args: + data: Error message data + """ + error_code = data.get('code', 'unknown') + error_msg = data.get('msg', 'Unknown error') + + logger.error(f"OKX error {error_code}: {error_msg}") + + # Handle specific error codes + if error_code == '60012': + logger.error("Invalid request - check parameters") + elif error_code == '60013': + logger.error("Invalid channel - check channel name") + + def _get_auth_headers(self, timestamp: str, method: str = "GET", + request_path: str = "/users/self/verify") -> Dict[str, str]: + """ + Generate authentication headers for OKX API. + + Args: + timestamp: Current timestamp + method: HTTP method + request_path: Request path + + Returns: + Dict: Authentication headers + """ + if not all([self.api_key, self.api_secret, self.passphrase]): + return {} + + try: + # Create signature + message = timestamp + method + request_path + signature = base64.b64encode( + hmac.new( + self.api_secret.encode('utf-8'), + message.encode('utf-8'), + hashlib.sha256 + ).digest() + ).decode('utf-8') + + # Create passphrase signature + passphrase_signature = base64.b64encode( + hmac.new( + self.api_secret.encode('utf-8'), + self.passphrase.encode('utf-8'), + hashlib.sha256 + ).digest() + ).decode('utf-8') + + return { + 'OK-ACCESS-KEY': self.api_key, + 'OK-ACCESS-SIGN': signature, + 'OK-ACCESS-TIMESTAMP': timestamp, + 'OK-ACCESS-PASSPHRASE': passphrase_signature + } + + except Exception as e: + logger.error(f"Error generating auth headers: {e}") + return {} + + async def _send_ping(self) -> None: + """Send ping to keep connection alive.""" + try: + ping_msg = {"op": "ping"} + await self._send_message(ping_msg) + logger.debug("Sent ping to OKX") + + except Exception as e: + logger.error(f"Error sending ping: {e}") + + def get_okx_stats(self) -> Dict[str, Any]: + """Get OKX-specific statistics.""" + base_stats = self.get_stats() + + okx_stats = { + 'subscribed_channels': list(self.subscribed_channels), + 'use_demo': self.use_demo, + 'authenticated': bool(self.api_key and self.api_secret and self.passphrase) + } + + base_stats.update(okx_stats) + return base_stats \ No newline at end of file diff --git a/COBY/tests/test_all_connectors.py b/COBY/tests/test_all_connectors.py new file mode 100644 index 0000000..b82dcd1 --- /dev/null +++ b/COBY/tests/test_all_connectors.py @@ -0,0 +1,271 @@ +""" +Comprehensive tests for all exchange connectors. +Tests the consistency and compatibility across all implemented connectors. +""" + +import asyncio +import pytest +from unittest.mock import Mock, AsyncMock + +from ..connectors.binance_connector import BinanceConnector +from ..connectors.coinbase_connector import CoinbaseConnector +from ..connectors.kraken_connector import KrakenConnector +from ..connectors.bybit_connector import BybitConnector +from ..connectors.okx_connector import OKXConnector +from ..connectors.huobi_connector import HuobiConnector + + +class TestAllConnectors: + """Test suite for all exchange connectors.""" + + @pytest.fixture + def all_connectors(self): + """Create instances of all connectors for testing.""" + return { + 'binance': BinanceConnector(), + 'coinbase': CoinbaseConnector(use_sandbox=True), + 'kraken': KrakenConnector(), + 'bybit': BybitConnector(use_testnet=True), + 'okx': OKXConnector(use_demo=True), + 'huobi': HuobiConnector() + } + + def test_all_connectors_initialization(self, all_connectors): + """Test that all connectors initialize correctly.""" + expected_names = ['binance', 'coinbase', 'kraken', 'bybit', 'okx', 'huobi'] + + for name, connector in all_connectors.items(): + assert connector.exchange_name == name + assert hasattr(connector, 'websocket_url') + assert hasattr(connector, 'message_handlers') + assert hasattr(connector, 'subscriptions') + + def test_interface_consistency(self, all_connectors): + """Test that all connectors implement the required interface methods.""" + required_methods = [ + 'connect', + 'disconnect', + 'subscribe_orderbook', + 'subscribe_trades', + 'unsubscribe_orderbook', + 'unsubscribe_trades', + 'get_symbols', + 'get_orderbook_snapshot', + 'normalize_symbol', + 'get_connection_status', + 'add_data_callback', + 'remove_data_callback' + ] + + for name, connector in all_connectors.items(): + for method in required_methods: + assert hasattr(connector, method), f"{name} missing method {method}" + assert callable(getattr(connector, method)), f"{name}.{method} not callable" + + def test_symbol_normalization_consistency(self, all_connectors): + """Test symbol normalization across all connectors.""" + test_symbols = ['BTCUSDT', 'ETHUSDT', 'btcusdt', 'BTC-USDT', 'BTC/USDT'] + + for name, connector in all_connectors.items(): + for symbol in test_symbols: + try: + normalized = connector.normalize_symbol(symbol) + assert isinstance(normalized, str) + assert len(normalized) > 0 + print(f"{name}: {symbol} -> {normalized}") + except Exception as e: + print(f"{name} failed to normalize {symbol}: {e}") + + @pytest.mark.asyncio + async def test_subscription_interface(self, all_connectors): + """Test subscription interface consistency.""" + for name, connector in all_connectors.items(): + # Mock the _send_message method + connector._send_message = AsyncMock(return_value=True) + + try: + # Test order book subscription + await connector.subscribe_orderbook('BTCUSDT') + assert 'BTCUSDT' in connector.subscriptions + + # Test trade subscription + await connector.subscribe_trades('ETHUSDT') + assert 'ETHUSDT' in connector.subscriptions + + # Test unsubscription + await connector.unsubscribe_orderbook('BTCUSDT') + await connector.unsubscribe_trades('ETHUSDT') + + print(f"✓ {name} subscription interface works") + + except Exception as e: + print(f"✗ {name} subscription interface failed: {e}") + + def test_message_type_detection(self, all_connectors): + """Test message type detection across connectors.""" + # Test with generic message structures + test_messages = [ + {'type': 'test'}, + {'event': 'test'}, + {'op': 'test'}, + {'ch': 'test'}, + {'topic': 'test'}, + [1, {}, 'test', 'symbol'], # Kraken format + {'unknown': 'data'} + ] + + for name, connector in all_connectors.items(): + for msg in test_messages: + try: + msg_type = connector._get_message_type(msg) + assert isinstance(msg_type, str) + print(f"{name}: {msg} -> {msg_type}") + except Exception as e: + print(f"{name} failed to detect message type for {msg}: {e}") + + def test_statistics_interface(self, all_connectors): + """Test statistics interface consistency.""" + for name, connector in all_connectors.items(): + try: + stats = connector.get_stats() + assert isinstance(stats, dict) + assert 'exchange' in stats + assert stats['exchange'] == name + assert 'connection_status' in stats + print(f"✓ {name} statistics interface works") + + except Exception as e: + print(f"✗ {name} statistics interface failed: {e}") + + def test_callback_system(self, all_connectors): + """Test callback system consistency.""" + for name, connector in all_connectors.items(): + try: + # Test adding callback + def test_callback(data): + pass + + connector.add_data_callback(test_callback) + assert test_callback in connector.data_callbacks + + # Test removing callback + connector.remove_data_callback(test_callback) + assert test_callback not in connector.data_callbacks + + print(f"✓ {name} callback system works") + + except Exception as e: + print(f"✗ {name} callback system failed: {e}") + + def test_connection_status(self, all_connectors): + """Test connection status interface.""" + for name, connector in all_connectors.items(): + try: + status = connector.get_connection_status() + assert hasattr(status, 'value') # Should be an enum + + # Test is_connected property + is_connected = connector.is_connected + assert isinstance(is_connected, bool) + + print(f"✓ {name} connection status interface works") + + except Exception as e: + print(f"✗ {name} connection status interface failed: {e}") + + +async def test_connector_compatibility(): + """Test compatibility across all connectors.""" + print("=== Testing All Exchange Connectors ===") + + connectors = { + 'binance': BinanceConnector(), + 'coinbase': CoinbaseConnector(use_sandbox=True), + 'kraken': KrakenConnector(), + 'bybit': BybitConnector(use_testnet=True), + 'okx': OKXConnector(use_demo=True), + 'huobi': HuobiConnector() + } + + # Test basic functionality + for name, connector in connectors.items(): + try: + print(f"\nTesting {name.upper()} connector:") + + # Test initialization + assert connector.exchange_name == name + print(f" ✓ Initialization: {connector.exchange_name}") + + # Test symbol normalization + btc_symbol = connector.normalize_symbol('BTCUSDT') + eth_symbol = connector.normalize_symbol('ETHUSDT') + print(f" ✓ Symbol normalization: BTCUSDT -> {btc_symbol}, ETHUSDT -> {eth_symbol}") + + # Test message type detection + test_msg = {'type': 'test'} if name != 'kraken' else [1, {}, 'test', 'symbol'] + msg_type = connector._get_message_type(test_msg) + print(f" ✓ Message type detection: {msg_type}") + + # Test statistics + stats = connector.get_stats() + print(f" ✓ Statistics: {len(stats)} fields") + + # Test connection status + status = connector.get_connection_status() + print(f" ✓ Connection status: {status.value}") + + print(f" ✅ {name.upper()} connector passed all tests") + + except Exception as e: + print(f" ❌ {name.upper()} connector failed: {e}") + + print("\n=== All Connector Tests Completed ===") + return True + + +async def test_multi_connector_data_flow(): + """Test data flow across multiple connectors simultaneously.""" + print("=== Testing Multi-Connector Data Flow ===") + + connectors = { + 'binance': BinanceConnector(), + 'coinbase': CoinbaseConnector(use_sandbox=True), + 'kraken': KrakenConnector() + } + + # Set up data collection + received_data = {name: [] for name in connectors.keys()} + + def create_callback(exchange_name): + def callback(data): + received_data[exchange_name].append(data) + print(f"Received data from {exchange_name}: {type(data).__name__}") + return callback + + # Add callbacks to all connectors + for name, connector in connectors.items(): + connector.add_data_callback(create_callback(name)) + connector._send_message = AsyncMock(return_value=True) + + # Test subscription to same symbol across exchanges + symbol = 'BTCUSDT' + for name, connector in connectors.items(): + try: + await connector.subscribe_orderbook(symbol) + await connector.subscribe_trades(symbol) + print(f"✓ Subscribed to {symbol} on {name}") + except Exception as e: + print(f"✗ Failed to subscribe to {symbol} on {name}: {e}") + + print("Multi-connector data flow test completed") + return True + + +if __name__ == "__main__": + # Run all tests + async def run_all_tests(): + await test_connector_compatibility() + await test_multi_connector_data_flow() + print("✅ All connector tests completed successfully") + + asyncio.run(run_all_tests()) \ No newline at end of file diff --git a/COBY/tests/test_bybit_connector.py b/COBY/tests/test_bybit_connector.py new file mode 100644 index 0000000..27ed518 --- /dev/null +++ b/COBY/tests/test_bybit_connector.py @@ -0,0 +1,321 @@ +""" +Unit tests for Bybit exchange connector. +""" + +import asyncio +import pytest +from unittest.mock import Mock, AsyncMock, patch +from datetime import datetime, timezone + +from ..connectors.bybit_connector import BybitConnector +from ..models.core import OrderBookSnapshot, TradeEvent, PriceLevel + + +class TestBybitConnector: + """Test suite for Bybit connector.""" + + @pytest.fixture + def connector(self): + """Create connector instance for testing.""" + return BybitConnector(use_testnet=True) + + def test_initialization(self, connector): + """Test connector initializes correctly.""" + assert connector.exchange_name == "bybit" + assert connector.use_testnet is True + assert connector.TESTNET_URL in connector.websocket_url + assert 'orderbook' in connector.message_handlers + assert 'publicTrade' in connector.message_handlers + + def test_symbol_normalization(self, connector): + """Test symbol normalization to Bybit format.""" + # Test standard conversions (Bybit uses same format as Binance) + assert connector.normalize_symbol('BTCUSDT') == 'BTCUSDT' + assert connector.normalize_symbol('ETHUSDT') == 'ETHUSDT' + assert connector.normalize_symbol('btcusdt') == 'BTCUSDT' + + # Test with separators + assert connector.normalize_symbol('BTC-USDT') == 'BTCUSDT' + assert connector.normalize_symbol('BTC/USDT') == 'BTCUSDT' + + def test_message_type_detection(self, connector): + """Test message type detection.""" + # Test orderbook message + orderbook_message = { + 'topic': 'orderbook.50.BTCUSDT', + 'data': {'b': [], 'a': []} + } + assert connector._get_message_type(orderbook_message) == 'orderbook' + + # Test trade message + trade_message = { + 'topic': 'publicTrade.BTCUSDT', + 'data': [] + } + assert connector._get_message_type(trade_message) == 'publicTrade' + + # Test operation message + op_message = {'op': 'subscribe', 'success': True} + assert connector._get_message_type(op_message) == 'subscribe' + + # Test response message + response_message = {'success': True, 'ret_msg': 'OK'} + assert connector._get_message_type(response_message) == 'response' + + @pytest.mark.asyncio + async def test_subscription_methods(self, connector): + """Test subscription and unsubscription methods.""" + # Mock the _send_message method + connector._send_message = AsyncMock(return_value=True) + + # Test order book subscription + await connector.subscribe_orderbook('BTCUSDT') + + # Verify subscription was tracked + assert 'BTCUSDT' in connector.subscriptions + assert 'orderbook' in connector.subscriptions['BTCUSDT'] + assert 'orderbook.50.BTCUSDT' in connector.subscribed_topics + + # Verify correct message was sent + connector._send_message.assert_called() + call_args = connector._send_message.call_args[0][0] + assert call_args['op'] == 'subscribe' + assert 'orderbook.50.BTCUSDT' in call_args['args'] + + # Test trade subscription + await connector.subscribe_trades('ETHUSDT') + + assert 'ETHUSDT' in connector.subscriptions + assert 'trades' in connector.subscriptions['ETHUSDT'] + assert 'publicTrade.ETHUSDT' in connector.subscribed_topics + + # Test unsubscription + await connector.unsubscribe_orderbook('BTCUSDT') + + # Verify unsubscription + if 'BTCUSDT' in connector.subscriptions: + assert 'orderbook' not in connector.subscriptions['BTCUSDT'] + + @pytest.mark.asyncio + async def test_orderbook_snapshot_parsing(self, connector): + """Test parsing order book snapshot data.""" + # Mock order book data from Bybit + mock_data = { + 'u': 12345, + 'ts': 1609459200000, + 'b': [ + ['50000.00', '1.5'], + ['49999.00', '2.0'] + ], + 'a': [ + ['50001.00', '1.2'], + ['50002.00', '0.8'] + ] + } + + # Parse the data + orderbook = connector._parse_orderbook_snapshot(mock_data, 'BTCUSDT') + + # Verify parsing + assert isinstance(orderbook, OrderBookSnapshot) + assert orderbook.symbol == 'BTCUSDT' + assert orderbook.exchange == 'bybit' + assert orderbook.sequence_id == 12345 + + # Verify bids + assert len(orderbook.bids) == 2 + assert orderbook.bids[0].price == 50000.00 + assert orderbook.bids[0].size == 1.5 + + # Verify asks + assert len(orderbook.asks) == 2 + assert orderbook.asks[0].price == 50001.00 + assert orderbook.asks[0].size == 1.2 + + @pytest.mark.asyncio + async def test_orderbook_update_handling(self, connector): + """Test handling order book update messages.""" + # Mock callback + callback_called = False + received_data = None + + def mock_callback(data): + nonlocal callback_called, received_data + callback_called = True + received_data = data + + connector.add_data_callback(mock_callback) + + # Mock Bybit orderbook update message + update_message = { + 'topic': 'orderbook.50.BTCUSDT', + 'ts': 1609459200000, + 'data': { + 'u': 12345, + 'b': [['50000.00', '1.5']], + 'a': [['50001.00', '1.2']] + } + } + + # Handle the message + await connector._handle_orderbook_update(update_message) + + # Verify callback was called + assert callback_called + assert isinstance(received_data, OrderBookSnapshot) + assert received_data.symbol == 'BTCUSDT' + assert received_data.exchange == 'bybit' + assert received_data.sequence_id == 12345 + + @pytest.mark.asyncio + async def test_trade_handling(self, connector): + """Test handling trade messages.""" + # Mock callback + callback_called = False + received_trades = [] + + def mock_callback(data): + nonlocal callback_called + callback_called = True + received_trades.append(data) + + connector.add_data_callback(mock_callback) + + # Mock Bybit trade message + trade_message = { + 'topic': 'publicTrade.BTCUSDT', + 'ts': 1609459200000, + 'data': [ + { + 'T': 1609459200000, + 'p': '50000.50', + 'v': '0.1', + 'S': 'Buy', + 'i': '12345' + } + ] + } + + # Handle the message + await connector._handle_trade_update(trade_message) + + # Verify callback was called + assert callback_called + assert len(received_trades) == 1 + + trade = received_trades[0] + assert isinstance(trade, TradeEvent) + assert trade.symbol == 'BTCUSDT' + assert trade.exchange == 'bybit' + assert trade.price == 50000.50 + assert trade.size == 0.1 + assert trade.side == 'buy' + assert trade.trade_id == '12345' + + @pytest.mark.asyncio + async def test_get_symbols(self, connector): + """Test getting available symbols.""" + # Mock HTTP response + mock_response_data = { + 'retCode': 0, + 'result': { + 'list': [ + { + 'symbol': 'BTCUSDT', + 'status': 'Trading' + }, + { + 'symbol': 'ETHUSDT', + 'status': 'Trading' + }, + { + 'symbol': 'DISABLEDUSDT', + 'status': 'Closed' + } + ] + } + } + + with patch('aiohttp.ClientSession.get') as mock_get: + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value=mock_response_data) + mock_get.return_value.__aenter__.return_value = mock_response + + symbols = await connector.get_symbols() + + # Should only return trading symbols + assert 'BTCUSDT' in symbols + assert 'ETHUSDT' in symbols + assert 'DISABLEDUSDT' not in symbols + + @pytest.mark.asyncio + async def test_get_orderbook_snapshot(self, connector): + """Test getting order book snapshot from REST API.""" + # Mock HTTP response + mock_orderbook = { + 'retCode': 0, + 'result': { + 'ts': 1609459200000, + 'u': 12345, + 'b': [['50000.00', '1.5']], + 'a': [['50001.00', '1.2']] + } + } + + with patch('aiohttp.ClientSession.get') as mock_get: + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value=mock_orderbook) + mock_get.return_value.__aenter__.return_value = mock_response + + orderbook = await connector.get_orderbook_snapshot('BTCUSDT') + + assert isinstance(orderbook, OrderBookSnapshot) + assert orderbook.symbol == 'BTCUSDT' + assert orderbook.exchange == 'bybit' + assert len(orderbook.bids) == 1 + assert len(orderbook.asks) == 1 + + def test_statistics(self, connector): + """Test getting connector statistics.""" + # Add some test data + connector.subscribed_topics.add('orderbook.50.BTCUSDT') + + stats = connector.get_bybit_stats() + + assert stats['exchange'] == 'bybit' + assert 'orderbook.50.BTCUSDT' in stats['subscribed_topics'] + assert stats['use_testnet'] is True + assert 'authenticated' in stats + + +async def test_bybit_integration(): + """Integration test for Bybit connector.""" + connector = BybitConnector(use_testnet=True) + + try: + # Test basic functionality + assert connector.exchange_name == "bybit" + + # Test symbol normalization + assert connector.normalize_symbol('BTCUSDT') == 'BTCUSDT' + assert connector.normalize_symbol('btc-usdt') == 'BTCUSDT' + + # Test message type detection + test_message = {'topic': 'orderbook.50.BTCUSDT', 'data': {}} + assert connector._get_message_type(test_message) == 'orderbook' + + print("✓ Bybit connector integration test passed") + return True + + except Exception as e: + print(f"✗ Bybit connector integration test failed: {e}") + return False + + +if __name__ == "__main__": + # Run integration test + success = asyncio.run(test_bybit_integration()) + if not success: + exit(1) \ No newline at end of file