From ff75af566c9d674f27c689a4bb4035fa127f9959 Mon Sep 17 00:00:00 2001 From: Dobromir Popov Date: Mon, 4 Aug 2025 17:55:00 +0300 Subject: [PATCH] caching --- .../multi-exchange-data-aggregation/tasks.md | 3 + COBY/caching/__init__.py | 13 + COBY/caching/cache_keys.py | 278 +++++++ COBY/caching/data_serializer.py | 355 +++++++++ COBY/caching/redis_manager.py | 691 ++++++++++++++++++ COBY/tests/test_redis_manager.py | 347 +++++++++ 6 files changed, 1687 insertions(+) create mode 100644 COBY/caching/__init__.py create mode 100644 COBY/caching/cache_keys.py create mode 100644 COBY/caching/data_serializer.py create mode 100644 COBY/caching/redis_manager.py create mode 100644 COBY/tests/test_redis_manager.py diff --git a/.kiro/specs/multi-exchange-data-aggregation/tasks.md b/.kiro/specs/multi-exchange-data-aggregation/tasks.md index 385a3cd..2d511e5 100644 --- a/.kiro/specs/multi-exchange-data-aggregation/tasks.md +++ b/.kiro/specs/multi-exchange-data-aggregation/tasks.md @@ -46,6 +46,9 @@ - Write unit tests for Binance connector functionality - _Requirements: 1.1, 1.2, 1.4, 6.2_ + + + - [ ] 5. Create data processing and normalization engine - Implement data processor for normalizing raw exchange data - Create validation logic for order book and trade data diff --git a/COBY/caching/__init__.py b/COBY/caching/__init__.py new file mode 100644 index 0000000..9a44cfa --- /dev/null +++ b/COBY/caching/__init__.py @@ -0,0 +1,13 @@ +""" +Caching layer for the COBY system. +""" + +from .redis_manager import RedisManager +from .cache_keys import CacheKeys +from .data_serializer import DataSerializer + +__all__ = [ + 'RedisManager', + 'CacheKeys', + 'DataSerializer' +] \ No newline at end of file diff --git a/COBY/caching/cache_keys.py b/COBY/caching/cache_keys.py new file mode 100644 index 0000000..4697742 --- /dev/null +++ b/COBY/caching/cache_keys.py @@ -0,0 +1,278 @@ +""" +Cache key management for Redis operations. +""" + +from typing import Optional +from ..utils.logging import get_logger + +logger = get_logger(__name__) + + +class CacheKeys: + """ + Centralized cache key management for consistent Redis operations. + + Provides standardized key patterns for different data types. + """ + + # Key prefixes + ORDERBOOK_PREFIX = "ob" + HEATMAP_PREFIX = "hm" + TRADE_PREFIX = "tr" + METRICS_PREFIX = "mt" + STATUS_PREFIX = "st" + STATS_PREFIX = "stats" + + # TTL values (seconds) + ORDERBOOK_TTL = 60 # 1 minute + HEATMAP_TTL = 30 # 30 seconds + TRADE_TTL = 300 # 5 minutes + METRICS_TTL = 120 # 2 minutes + STATUS_TTL = 60 # 1 minute + STATS_TTL = 300 # 5 minutes + + @classmethod + def orderbook_key(cls, symbol: str, exchange: str) -> str: + """ + Generate cache key for order book data. + + Args: + symbol: Trading symbol + exchange: Exchange name + + Returns: + str: Cache key + """ + return f"{cls.ORDERBOOK_PREFIX}:{exchange}:{symbol}" + + @classmethod + def heatmap_key(cls, symbol: str, bucket_size: float = 1.0, + exchange: Optional[str] = None) -> str: + """ + Generate cache key for heatmap data. + + Args: + symbol: Trading symbol + bucket_size: Price bucket size + exchange: Exchange name (None for consolidated) + + Returns: + str: Cache key + """ + if exchange: + return f"{cls.HEATMAP_PREFIX}:{exchange}:{symbol}:{bucket_size}" + else: + return f"{cls.HEATMAP_PREFIX}:consolidated:{symbol}:{bucket_size}" + + @classmethod + def trade_key(cls, symbol: str, exchange: str, trade_id: str) -> str: + """ + Generate cache key for trade data. 
+ + Args: + symbol: Trading symbol + exchange: Exchange name + trade_id: Trade identifier + + Returns: + str: Cache key + """ + return f"{cls.TRADE_PREFIX}:{exchange}:{symbol}:{trade_id}" + + @classmethod + def metrics_key(cls, symbol: str, exchange: str) -> str: + """ + Generate cache key for metrics data. + + Args: + symbol: Trading symbol + exchange: Exchange name + + Returns: + str: Cache key + """ + return f"{cls.METRICS_PREFIX}:{exchange}:{symbol}" + + @classmethod + def status_key(cls, exchange: str) -> str: + """ + Generate cache key for exchange status. + + Args: + exchange: Exchange name + + Returns: + str: Cache key + """ + return f"{cls.STATUS_PREFIX}:{exchange}" + + @classmethod + def stats_key(cls, component: str) -> str: + """ + Generate cache key for component statistics. + + Args: + component: Component name + + Returns: + str: Cache key + """ + return f"{cls.STATS_PREFIX}:{component}" + + @classmethod + def latest_heatmaps_key(cls, symbol: str) -> str: + """ + Generate cache key for latest heatmaps list. + + Args: + symbol: Trading symbol + + Returns: + str: Cache key + """ + return f"{cls.HEATMAP_PREFIX}:latest:{symbol}" + + @classmethod + def symbol_list_key(cls, exchange: str) -> str: + """ + Generate cache key for symbol list. + + Args: + exchange: Exchange name + + Returns: + str: Cache key + """ + return f"symbols:{exchange}" + + @classmethod + def price_bucket_key(cls, symbol: str, exchange: str) -> str: + """ + Generate cache key for price buckets. + + Args: + symbol: Trading symbol + exchange: Exchange name + + Returns: + str: Cache key + """ + return f"buckets:{exchange}:{symbol}" + + @classmethod + def arbitrage_key(cls, symbol: str) -> str: + """ + Generate cache key for arbitrage opportunities. + + Args: + symbol: Trading symbol + + Returns: + str: Cache key + """ + return f"arbitrage:{symbol}" + + @classmethod + def get_ttl(cls, key: str) -> int: + """ + Get appropriate TTL for a cache key. + + Args: + key: Cache key + + Returns: + int: TTL in seconds + """ + if key.startswith(cls.ORDERBOOK_PREFIX): + return cls.ORDERBOOK_TTL + elif key.startswith(cls.HEATMAP_PREFIX): + return cls.HEATMAP_TTL + elif key.startswith(cls.TRADE_PREFIX): + return cls.TRADE_TTL + elif key.startswith(cls.METRICS_PREFIX): + return cls.METRICS_TTL + elif key.startswith(cls.STATUS_PREFIX): + return cls.STATUS_TTL + elif key.startswith(cls.STATS_PREFIX): + return cls.STATS_TTL + else: + return 300 # Default 5 minutes + + @classmethod + def parse_key(cls, key: str) -> dict: + """ + Parse cache key to extract components. 
+ + Args: + key: Cache key to parse + + Returns: + dict: Parsed key components + """ + parts = key.split(':') + + if len(parts) < 2: + return {'type': 'unknown', 'key': key} + + key_type = parts[0] + + if key_type == cls.ORDERBOOK_PREFIX and len(parts) >= 3: + return { + 'type': 'orderbook', + 'exchange': parts[1], + 'symbol': parts[2] + } + elif key_type == cls.HEATMAP_PREFIX and len(parts) >= 4: + return { + 'type': 'heatmap', + 'exchange': parts[1] if parts[1] != 'consolidated' else None, + 'symbol': parts[2], + 'bucket_size': float(parts[3]) if len(parts) > 3 else 1.0 + } + elif key_type == cls.TRADE_PREFIX and len(parts) >= 4: + return { + 'type': 'trade', + 'exchange': parts[1], + 'symbol': parts[2], + 'trade_id': parts[3] + } + elif key_type == cls.METRICS_PREFIX and len(parts) >= 3: + return { + 'type': 'metrics', + 'exchange': parts[1], + 'symbol': parts[2] + } + elif key_type == cls.STATUS_PREFIX and len(parts) >= 2: + return { + 'type': 'status', + 'exchange': parts[1] + } + elif key_type == cls.STATS_PREFIX and len(parts) >= 2: + return { + 'type': 'stats', + 'component': parts[1] + } + else: + return {'type': 'unknown', 'key': key} + + @classmethod + def get_pattern(cls, key_type: str) -> str: + """ + Get Redis pattern for key type. + + Args: + key_type: Type of key + + Returns: + str: Redis pattern + """ + patterns = { + 'orderbook': f"{cls.ORDERBOOK_PREFIX}:*", + 'heatmap': f"{cls.HEATMAP_PREFIX}:*", + 'trade': f"{cls.TRADE_PREFIX}:*", + 'metrics': f"{cls.METRICS_PREFIX}:*", + 'status': f"{cls.STATUS_PREFIX}:*", + 'stats': f"{cls.STATS_PREFIX}:*" + } + + return patterns.get(key_type, "*") \ No newline at end of file diff --git a/COBY/caching/data_serializer.py b/COBY/caching/data_serializer.py new file mode 100644 index 0000000..cf22336 --- /dev/null +++ b/COBY/caching/data_serializer.py @@ -0,0 +1,355 @@ +""" +Data serialization for Redis caching. +""" + +import json +import pickle +import gzip +from typing import Any, Union, Dict, List +from datetime import datetime +from ..models.core import ( + OrderBookSnapshot, TradeEvent, HeatmapData, PriceBuckets, + OrderBookMetrics, ImbalanceMetrics, ConsolidatedOrderBook +) +from ..utils.logging import get_logger +from ..utils.exceptions import ProcessingError + +logger = get_logger(__name__) + + +class DataSerializer: + """ + Handles serialization and deserialization of data for Redis storage. + + Supports multiple serialization formats: + - JSON for simple data + - Pickle for complex objects + - Compressed formats for large data + """ + + def __init__(self, use_compression: bool = True): + """ + Initialize data serializer. + + Args: + use_compression: Whether to use gzip compression + """ + self.use_compression = use_compression + self.serialization_stats = { + 'serialized': 0, + 'deserialized': 0, + 'compression_ratio': 0.0, + 'errors': 0 + } + + logger.info(f"Data serializer initialized (compression: {use_compression})") + + def serialize(self, data: Any, format_type: str = 'auto') -> bytes: + """ + Serialize data for Redis storage. 
+ + Args: + data: Data to serialize + format_type: Serialization format ('json', 'pickle', 'auto') + + Returns: + bytes: Serialized data + """ + try: + # Determine format + if format_type == 'auto': + format_type = self._determine_format(data) + + # Serialize based on format + if format_type == 'json': + serialized = self._serialize_json(data) + elif format_type == 'pickle': + serialized = self._serialize_pickle(data) + else: + raise ValueError(f"Unsupported format: {format_type}") + + # Apply compression if enabled + if self.use_compression: + original_size = len(serialized) + serialized = gzip.compress(serialized) + compressed_size = len(serialized) + + # Update compression ratio + if original_size > 0: + ratio = compressed_size / original_size + self.serialization_stats['compression_ratio'] = ( + (self.serialization_stats['compression_ratio'] * + self.serialization_stats['serialized'] + ratio) / + (self.serialization_stats['serialized'] + 1) + ) + + self.serialization_stats['serialized'] += 1 + return serialized + + except Exception as e: + self.serialization_stats['errors'] += 1 + logger.error(f"Serialization error: {e}") + raise ProcessingError(f"Serialization failed: {e}", "SERIALIZE_ERROR") + + def deserialize(self, data: bytes, format_type: str = 'auto') -> Any: + """ + Deserialize data from Redis storage. + + Args: + data: Serialized data + format_type: Expected format ('json', 'pickle', 'auto') + + Returns: + Any: Deserialized data + """ + try: + # Decompress if needed + if self.use_compression: + try: + data = gzip.decompress(data) + except gzip.BadGzipFile: + # Data might not be compressed + pass + + # Determine format if auto + if format_type == 'auto': + format_type = self._detect_format(data) + + # Deserialize based on format + if format_type == 'json': + result = self._deserialize_json(data) + elif format_type == 'pickle': + result = self._deserialize_pickle(data) + else: + raise ValueError(f"Unsupported format: {format_type}") + + self.serialization_stats['deserialized'] += 1 + return result + + except Exception as e: + self.serialization_stats['errors'] += 1 + logger.error(f"Deserialization error: {e}") + raise ProcessingError(f"Deserialization failed: {e}", "DESERIALIZE_ERROR") + + def _determine_format(self, data: Any) -> str: + """Determine best serialization format for data""" + # Use JSON for simple data types + if isinstance(data, (dict, list, str, int, float, bool)) or data is None: + return 'json' + + # Use pickle for complex objects + return 'pickle' + + def _detect_format(self, data: bytes) -> str: + """Detect serialization format from data""" + try: + # Try JSON first + json.loads(data.decode('utf-8')) + return 'json' + except (json.JSONDecodeError, UnicodeDecodeError): + # Assume pickle + return 'pickle' + + def _serialize_json(self, data: Any) -> bytes: + """Serialize data as JSON""" + # Convert complex objects to dictionaries + if hasattr(data, '__dict__'): + data = self._object_to_dict(data) + elif isinstance(data, list): + data = [self._object_to_dict(item) if hasattr(item, '__dict__') else item + for item in data] + + json_str = json.dumps(data, default=self._json_serializer, ensure_ascii=False) + return json_str.encode('utf-8') + + def _deserialize_json(self, data: bytes) -> Any: + """Deserialize JSON data""" + json_str = data.decode('utf-8') + return json.loads(json_str, object_hook=self._json_deserializer) + + def _serialize_pickle(self, data: Any) -> bytes: + """Serialize data as pickle""" + return pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL) 
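+        # Reviewer sketch (comments only, not load-bearing): _determine_format()
+        # routes plain dicts/lists/scalars through JSON and everything else through
+        # pickle, while _detect_format() on the read side treats any payload that is
+        # not valid UTF-8 JSON as pickle. Since pickle.loads() must never be fed
+        # untrusted bytes, this assumes the Redis instance is private to the COBY
+        # deployment and only written by these serializers.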
+ + def _deserialize_pickle(self, data: bytes) -> Any: + """Deserialize pickle data""" + return pickle.loads(data) + + def _object_to_dict(self, obj: Any) -> Dict: + """Convert object to dictionary for JSON serialization""" + if isinstance(obj, (OrderBookSnapshot, TradeEvent, HeatmapData, + PriceBuckets, OrderBookMetrics, ImbalanceMetrics, + ConsolidatedOrderBook)): + result = { + '__type__': obj.__class__.__name__, + '__data__': {} + } + + # Convert object attributes + for key, value in obj.__dict__.items(): + if isinstance(value, datetime): + result['__data__'][key] = { + '__datetime__': value.isoformat() + } + elif isinstance(value, list): + result['__data__'][key] = [ + self._object_to_dict(item) if hasattr(item, '__dict__') else item + for item in value + ] + elif hasattr(value, '__dict__'): + result['__data__'][key] = self._object_to_dict(value) + else: + result['__data__'][key] = value + + return result + else: + return obj.__dict__ if hasattr(obj, '__dict__') else obj + + def _json_serializer(self, obj: Any) -> Any: + """Custom JSON serializer for special types""" + if isinstance(obj, datetime): + return {'__datetime__': obj.isoformat()} + elif hasattr(obj, '__dict__'): + return self._object_to_dict(obj) + else: + return str(obj) + + def _json_deserializer(self, obj: Dict) -> Any: + """Custom JSON deserializer for special types""" + if '__datetime__' in obj: + return datetime.fromisoformat(obj['__datetime__']) + elif '__type__' in obj and '__data__' in obj: + return self._reconstruct_object(obj['__type__'], obj['__data__']) + else: + return obj + + def _reconstruct_object(self, type_name: str, data: Dict) -> Any: + """Reconstruct object from serialized data""" + # Import required classes + from ..models.core import ( + OrderBookSnapshot, TradeEvent, HeatmapData, PriceBuckets, + OrderBookMetrics, ImbalanceMetrics, ConsolidatedOrderBook, + PriceLevel, HeatmapPoint + ) + + # Map type names to classes + type_map = { + 'OrderBookSnapshot': OrderBookSnapshot, + 'TradeEvent': TradeEvent, + 'HeatmapData': HeatmapData, + 'PriceBuckets': PriceBuckets, + 'OrderBookMetrics': OrderBookMetrics, + 'ImbalanceMetrics': ImbalanceMetrics, + 'ConsolidatedOrderBook': ConsolidatedOrderBook, + 'PriceLevel': PriceLevel, + 'HeatmapPoint': HeatmapPoint + } + + if type_name in type_map: + cls = type_map[type_name] + + # Recursively deserialize nested objects + processed_data = {} + for key, value in data.items(): + if isinstance(value, dict) and '__datetime__' in value: + processed_data[key] = datetime.fromisoformat(value['__datetime__']) + elif isinstance(value, dict) and '__type__' in value: + processed_data[key] = self._reconstruct_object( + value['__type__'], value['__data__'] + ) + elif isinstance(value, list): + processed_data[key] = [ + self._reconstruct_object(item['__type__'], item['__data__']) + if isinstance(item, dict) and '__type__' in item + else item + for item in value + ] + else: + processed_data[key] = value + + try: + return cls(**processed_data) + except Exception as e: + logger.warning(f"Failed to reconstruct {type_name}: {e}") + return processed_data + else: + logger.warning(f"Unknown type for reconstruction: {type_name}") + return data + + def serialize_heatmap(self, heatmap: HeatmapData) -> bytes: + """Specialized serialization for heatmap data""" + try: + # Create optimized representation + heatmap_dict = { + 'symbol': heatmap.symbol, + 'timestamp': heatmap.timestamp.isoformat(), + 'bucket_size': heatmap.bucket_size, + 'points': [ + { + 'p': point.price, # price + 'v': point.volume, 
# volume + 'i': point.intensity, # intensity + 's': point.side # side + } + for point in heatmap.data + ] + } + + return self.serialize(heatmap_dict, 'json') + + except Exception as e: + logger.error(f"Heatmap serialization error: {e}") + # Fallback to standard serialization + return self.serialize(heatmap, 'pickle') + + def deserialize_heatmap(self, data: bytes) -> HeatmapData: + """Specialized deserialization for heatmap data""" + try: + # Try optimized format first + heatmap_dict = self.deserialize(data, 'json') + + if isinstance(heatmap_dict, dict) and 'points' in heatmap_dict: + from ..models.core import HeatmapData, HeatmapPoint + + # Reconstruct heatmap points + points = [] + for point_data in heatmap_dict['points']: + point = HeatmapPoint( + price=point_data['p'], + volume=point_data['v'], + intensity=point_data['i'], + side=point_data['s'] + ) + points.append(point) + + # Create heatmap + heatmap = HeatmapData( + symbol=heatmap_dict['symbol'], + timestamp=datetime.fromisoformat(heatmap_dict['timestamp']), + bucket_size=heatmap_dict['bucket_size'] + ) + heatmap.data = points + + return heatmap + else: + # Fallback to standard deserialization + return self.deserialize(data, 'pickle') + + except Exception as e: + logger.error(f"Heatmap deserialization error: {e}") + # Final fallback + return self.deserialize(data, 'pickle') + + def get_stats(self) -> Dict[str, Any]: + """Get serialization statistics""" + return self.serialization_stats.copy() + + def reset_stats(self) -> None: + """Reset serialization statistics""" + self.serialization_stats = { + 'serialized': 0, + 'deserialized': 0, + 'compression_ratio': 0.0, + 'errors': 0 + } + logger.info("Serialization statistics reset") \ No newline at end of file diff --git a/COBY/caching/redis_manager.py b/COBY/caching/redis_manager.py new file mode 100644 index 0000000..0f3206d --- /dev/null +++ b/COBY/caching/redis_manager.py @@ -0,0 +1,691 @@ +""" +Redis cache manager for high-performance data access. +""" + +import asyncio +import redis.asyncio as redis +from typing import Any, Optional, List, Dict, Union +from datetime import datetime, timedelta +from ..config import config +from ..utils.logging import get_logger, set_correlation_id +from ..utils.exceptions import StorageError +from ..utils.timing import get_current_timestamp +from .cache_keys import CacheKeys +from .data_serializer import DataSerializer + +logger = get_logger(__name__) + + +class RedisManager: + """ + High-performance Redis cache manager for market data. 
+ + Provides: + - Connection pooling and management + - Data serialization and compression + - TTL management + - Batch operations + - Performance monitoring + """ + + def __init__(self): + """Initialize Redis manager""" + self.redis_pool: Optional[redis.ConnectionPool] = None + self.redis_client: Optional[redis.Redis] = None + self.serializer = DataSerializer(use_compression=True) + self.cache_keys = CacheKeys() + + # Performance statistics + self.stats = { + 'gets': 0, + 'sets': 0, + 'deletes': 0, + 'hits': 0, + 'misses': 0, + 'errors': 0, + 'total_data_size': 0, + 'avg_response_time': 0.0 + } + + logger.info("Redis manager initialized") + + async def initialize(self) -> None: + """Initialize Redis connection pool""" + try: + # Create connection pool + self.redis_pool = redis.ConnectionPool( + host=config.redis.host, + port=config.redis.port, + password=config.redis.password, + db=config.redis.db, + max_connections=config.redis.max_connections, + socket_timeout=config.redis.socket_timeout, + socket_connect_timeout=config.redis.socket_connect_timeout, + decode_responses=False, # We handle bytes directly + retry_on_timeout=True, + health_check_interval=30 + ) + + # Create Redis client + self.redis_client = redis.Redis(connection_pool=self.redis_pool) + + # Test connection + await self.redis_client.ping() + + logger.info(f"Redis connection established: {config.redis.host}:{config.redis.port}") + + except Exception as e: + logger.error(f"Failed to initialize Redis connection: {e}") + raise StorageError(f"Redis initialization failed: {e}", "REDIS_INIT_ERROR") + + async def close(self) -> None: + """Close Redis connections""" + try: + if self.redis_client: + await self.redis_client.close() + if self.redis_pool: + await self.redis_pool.disconnect() + + logger.info("Redis connections closed") + + except Exception as e: + logger.warning(f"Error closing Redis connections: {e}") + + async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool: + """ + Set value in cache with optional TTL. + + Args: + key: Cache key + value: Value to cache + ttl: Time to live in seconds (None = use default) + + Returns: + bool: True if successful, False otherwise + """ + try: + set_correlation_id() + start_time = asyncio.get_event_loop().time() + + # Serialize value + serialized_value = self.serializer.serialize(value) + + # Determine TTL + if ttl is None: + ttl = self.cache_keys.get_ttl(key) + + # Set in Redis + result = await self.redis_client.setex(key, ttl, serialized_value) + + # Update statistics + self.stats['sets'] += 1 + self.stats['total_data_size'] += len(serialized_value) + + # Update response time + response_time = asyncio.get_event_loop().time() - start_time + self._update_avg_response_time(response_time) + + logger.debug(f"Cached data: {key} (size: {len(serialized_value)} bytes, ttl: {ttl}s)") + return bool(result) + + except Exception as e: + self.stats['errors'] += 1 + logger.error(f"Error setting cache key {key}: {e}") + return False + + async def get(self, key: str) -> Optional[Any]: + """ + Get value from cache. 
+ + Args: + key: Cache key + + Returns: + Any: Cached value or None if not found + """ + try: + set_correlation_id() + start_time = asyncio.get_event_loop().time() + + # Get from Redis + serialized_value = await self.redis_client.get(key) + + # Update statistics + self.stats['gets'] += 1 + + if serialized_value is None: + self.stats['misses'] += 1 + logger.debug(f"Cache miss: {key}") + return None + + # Deserialize value + value = self.serializer.deserialize(serialized_value) + + # Update statistics + self.stats['hits'] += 1 + + # Update response time + response_time = asyncio.get_event_loop().time() - start_time + self._update_avg_response_time(response_time) + + logger.debug(f"Cache hit: {key} (size: {len(serialized_value)} bytes)") + return value + + except Exception as e: + self.stats['errors'] += 1 + logger.error(f"Error getting cache key {key}: {e}") + return None + + async def delete(self, key: str) -> bool: + """ + Delete key from cache. + + Args: + key: Cache key to delete + + Returns: + bool: True if deleted, False otherwise + """ + try: + set_correlation_id() + + result = await self.redis_client.delete(key) + + self.stats['deletes'] += 1 + + logger.debug(f"Deleted cache key: {key}") + return bool(result) + + except Exception as e: + self.stats['errors'] += 1 + logger.error(f"Error deleting cache key {key}: {e}") + return False + + async def exists(self, key: str) -> bool: + """ + Check if key exists in cache. + + Args: + key: Cache key to check + + Returns: + bool: True if exists, False otherwise + """ + try: + result = await self.redis_client.exists(key) + return bool(result) + + except Exception as e: + logger.error(f"Error checking cache key existence {key}: {e}") + return False + + async def expire(self, key: str, ttl: int) -> bool: + """ + Set expiration time for key. + + Args: + key: Cache key + ttl: Time to live in seconds + + Returns: + bool: True if successful, False otherwise + """ + try: + result = await self.redis_client.expire(key, ttl) + return bool(result) + + except Exception as e: + logger.error(f"Error setting expiration for key {key}: {e}") + return False + + async def mget(self, keys: List[str]) -> List[Optional[Any]]: + """ + Get multiple values from cache. + + Args: + keys: List of cache keys + + Returns: + List[Optional[Any]]: List of values (None for missing keys) + """ + try: + set_correlation_id() + start_time = asyncio.get_event_loop().time() + + # Get from Redis + serialized_values = await self.redis_client.mget(keys) + + # Deserialize values + values = [] + for serialized_value in serialized_values: + if serialized_value is None: + values.append(None) + self.stats['misses'] += 1 + else: + try: + value = self.serializer.deserialize(serialized_value) + values.append(value) + self.stats['hits'] += 1 + except Exception as e: + logger.warning(f"Error deserializing value: {e}") + values.append(None) + self.stats['errors'] += 1 + + # Update statistics + self.stats['gets'] += len(keys) + + # Update response time + response_time = asyncio.get_event_loop().time() - start_time + self._update_avg_response_time(response_time) + + logger.debug(f"Multi-get: {len(keys)} keys, {sum(1 for v in values if v is not None)} hits") + return values + + except Exception as e: + self.stats['errors'] += 1 + logger.error(f"Error in multi-get: {e}") + return [None] * len(keys) + + async def mset(self, key_value_pairs: Dict[str, Any], ttl: Optional[int] = None) -> bool: + """ + Set multiple key-value pairs. 
+ + Args: + key_value_pairs: Dictionary of key-value pairs + ttl: Time to live in seconds (None = use default per key) + + Returns: + bool: True if successful, False otherwise + """ + try: + set_correlation_id() + + # Serialize all values + serialized_pairs = {} + for key, value in key_value_pairs.items(): + serialized_value = self.serializer.serialize(value) + serialized_pairs[key] = serialized_value + self.stats['total_data_size'] += len(serialized_value) + + # Set in Redis + result = await self.redis_client.mset(serialized_pairs) + + # Set TTL for each key if specified + if ttl is not None: + for key in key_value_pairs.keys(): + await self.redis_client.expire(key, ttl) + else: + # Use individual TTLs + for key in key_value_pairs.keys(): + key_ttl = self.cache_keys.get_ttl(key) + await self.redis_client.expire(key, key_ttl) + + self.stats['sets'] += len(key_value_pairs) + + logger.debug(f"Multi-set: {len(key_value_pairs)} keys") + return bool(result) + + except Exception as e: + self.stats['errors'] += 1 + logger.error(f"Error in multi-set: {e}") + return False + + async def keys(self, pattern: str) -> List[str]: + """ + Get keys matching pattern. + + Args: + pattern: Redis pattern (e.g., "hm:*") + + Returns: + List[str]: List of matching keys + """ + try: + keys = await self.redis_client.keys(pattern) + return [key.decode('utf-8') if isinstance(key, bytes) else key for key in keys] + + except Exception as e: + logger.error(f"Error getting keys with pattern {pattern}: {e}") + return [] + + async def flushdb(self) -> bool: + """ + Clear all keys in current database. + + Returns: + bool: True if successful, False otherwise + """ + try: + result = await self.redis_client.flushdb() + logger.info("Redis database flushed") + return bool(result) + + except Exception as e: + logger.error(f"Error flushing Redis database: {e}") + return False + + async def info(self) -> Dict[str, Any]: + """ + Get Redis server information. + + Returns: + Dict: Redis server info + """ + try: + info = await self.redis_client.info() + return info + + except Exception as e: + logger.error(f"Error getting Redis info: {e}") + return {} + + async def ping(self) -> bool: + """ + Ping Redis server. + + Returns: + bool: True if server responds, False otherwise + """ + try: + result = await self.redis_client.ping() + return bool(result) + + except Exception as e: + logger.error(f"Redis ping failed: {e}") + return False + + async def set_heatmap(self, symbol: str, heatmap_data, + exchange: Optional[str] = None, ttl: Optional[int] = None) -> bool: + """ + Cache heatmap data with optimized serialization. 
+ + Args: + symbol: Trading symbol + heatmap_data: Heatmap data to cache + exchange: Exchange name (None for consolidated) + ttl: Time to live in seconds + + Returns: + bool: True if successful, False otherwise + """ + try: + key = self.cache_keys.heatmap_key(symbol, 1.0, exchange) + + # Use specialized heatmap serialization + serialized_value = self.serializer.serialize_heatmap(heatmap_data) + + # Determine TTL + if ttl is None: + ttl = self.cache_keys.HEATMAP_TTL + + # Set in Redis + result = await self.redis_client.setex(key, ttl, serialized_value) + + # Update statistics + self.stats['sets'] += 1 + self.stats['total_data_size'] += len(serialized_value) + + logger.debug(f"Cached heatmap: {key} (size: {len(serialized_value)} bytes)") + return bool(result) + + except Exception as e: + self.stats['errors'] += 1 + logger.error(f"Error caching heatmap for {symbol}: {e}") + return False + + async def get_heatmap(self, symbol: str, exchange: Optional[str] = None): + """ + Get cached heatmap data with optimized deserialization. + + Args: + symbol: Trading symbol + exchange: Exchange name (None for consolidated) + + Returns: + HeatmapData: Cached heatmap or None if not found + """ + try: + key = self.cache_keys.heatmap_key(symbol, 1.0, exchange) + + # Get from Redis + serialized_value = await self.redis_client.get(key) + + self.stats['gets'] += 1 + + if serialized_value is None: + self.stats['misses'] += 1 + return None + + # Use specialized heatmap deserialization + heatmap_data = self.serializer.deserialize_heatmap(serialized_value) + + self.stats['hits'] += 1 + logger.debug(f"Retrieved heatmap: {key}") + return heatmap_data + + except Exception as e: + self.stats['errors'] += 1 + logger.error(f"Error retrieving heatmap for {symbol}: {e}") + return None + + async def cache_orderbook(self, orderbook) -> bool: + """ + Cache order book data. + + Args: + orderbook: OrderBookSnapshot to cache + + Returns: + bool: True if successful, False otherwise + """ + try: + key = self.cache_keys.orderbook_key(orderbook.symbol, orderbook.exchange) + return await self.set(key, orderbook) + + except Exception as e: + logger.error(f"Error caching order book: {e}") + return False + + async def get_orderbook(self, symbol: str, exchange: str): + """ + Get cached order book data. + + Args: + symbol: Trading symbol + exchange: Exchange name + + Returns: + OrderBookSnapshot: Cached order book or None if not found + """ + try: + key = self.cache_keys.orderbook_key(symbol, exchange) + return await self.get(key) + + except Exception as e: + logger.error(f"Error retrieving order book: {e}") + return None + + async def cache_metrics(self, metrics, symbol: str, exchange: str) -> bool: + """ + Cache metrics data. + + Args: + metrics: Metrics data to cache + symbol: Trading symbol + exchange: Exchange name + + Returns: + bool: True if successful, False otherwise + """ + try: + key = self.cache_keys.metrics_key(symbol, exchange) + return await self.set(key, metrics) + + except Exception as e: + logger.error(f"Error caching metrics: {e}") + return False + + async def get_metrics(self, symbol: str, exchange: str): + """ + Get cached metrics data. 
+ + Args: + symbol: Trading symbol + exchange: Exchange name + + Returns: + Metrics data or None if not found + """ + try: + key = self.cache_keys.metrics_key(symbol, exchange) + return await self.get(key) + + except Exception as e: + logger.error(f"Error retrieving metrics: {e}") + return None + + async def cache_exchange_status(self, exchange: str, status_data) -> bool: + """ + Cache exchange status. + + Args: + exchange: Exchange name + status_data: Status data to cache + + Returns: + bool: True if successful, False otherwise + """ + try: + key = self.cache_keys.status_key(exchange) + return await self.set(key, status_data) + + except Exception as e: + logger.error(f"Error caching exchange status: {e}") + return False + + async def get_exchange_status(self, exchange: str): + """ + Get cached exchange status. + + Args: + exchange: Exchange name + + Returns: + Status data or None if not found + """ + try: + key = self.cache_keys.status_key(exchange) + return await self.get(key) + + except Exception as e: + logger.error(f"Error retrieving exchange status: {e}") + return None + + async def cleanup_expired_keys(self) -> int: + """ + Clean up expired keys (Redis handles this automatically, but we can force it). + + Returns: + int: Number of keys cleaned up + """ + try: + # Get all keys + all_keys = await self.keys("*") + + # Check which ones are expired + expired_count = 0 + for key in all_keys: + ttl = await self.redis_client.ttl(key) + if ttl == -2: # Key doesn't exist (expired) + expired_count += 1 + + logger.debug(f"Found {expired_count} expired keys") + return expired_count + + except Exception as e: + logger.error(f"Error cleaning up expired keys: {e}") + return 0 + + def _update_avg_response_time(self, response_time: float) -> None: + """Update average response time""" + total_operations = self.stats['gets'] + self.stats['sets'] + if total_operations > 0: + self.stats['avg_response_time'] = ( + (self.stats['avg_response_time'] * (total_operations - 1) + response_time) / + total_operations + ) + + def get_stats(self) -> Dict[str, Any]: + """Get cache statistics""" + total_operations = self.stats['gets'] + self.stats['sets'] + hit_rate = (self.stats['hits'] / max(self.stats['gets'], 1)) * 100 + + return { + **self.stats, + 'total_operations': total_operations, + 'hit_rate_percentage': hit_rate, + 'serializer_stats': self.serializer.get_stats() + } + + def reset_stats(self) -> None: + """Reset cache statistics""" + self.stats = { + 'gets': 0, + 'sets': 0, + 'deletes': 0, + 'hits': 0, + 'misses': 0, + 'errors': 0, + 'total_data_size': 0, + 'avg_response_time': 0.0 + } + self.serializer.reset_stats() + logger.info("Redis manager statistics reset") + + async def health_check(self) -> Dict[str, Any]: + """ + Perform comprehensive health check. 
+ + Returns: + Dict: Health check results + """ + health = { + 'redis_ping': False, + 'connection_pool_size': 0, + 'memory_usage': 0, + 'connected_clients': 0, + 'total_keys': 0, + 'hit_rate': 0.0, + 'avg_response_time': self.stats['avg_response_time'] + } + + try: + # Test ping + health['redis_ping'] = await self.ping() + + # Get Redis info + info = await self.info() + if info: + health['memory_usage'] = info.get('used_memory', 0) + health['connected_clients'] = info.get('connected_clients', 0) + + # Get key count + all_keys = await self.keys("*") + health['total_keys'] = len(all_keys) + + # Calculate hit rate + if self.stats['gets'] > 0: + health['hit_rate'] = (self.stats['hits'] / self.stats['gets']) * 100 + + # Connection pool info + if self.redis_pool: + health['connection_pool_size'] = self.redis_pool.max_connections + + except Exception as e: + logger.error(f"Health check error: {e}") + + return health + + +# Global Redis manager instance +redis_manager = RedisManager() \ No newline at end of file diff --git a/COBY/tests/test_redis_manager.py b/COBY/tests/test_redis_manager.py new file mode 100644 index 0000000..22821be --- /dev/null +++ b/COBY/tests/test_redis_manager.py @@ -0,0 +1,347 @@ +""" +Tests for Redis caching system. +""" + +import pytest +import asyncio +from datetime import datetime, timezone +from ..caching.redis_manager import RedisManager +from ..caching.cache_keys import CacheKeys +from ..caching.data_serializer import DataSerializer +from ..models.core import OrderBookSnapshot, HeatmapData, PriceLevel, HeatmapPoint + + +@pytest.fixture +async def redis_manager(): + """Create and initialize Redis manager for testing""" + manager = RedisManager() + await manager.initialize() + yield manager + await manager.close() + + +@pytest.fixture +def cache_keys(): + """Create cache keys helper""" + return CacheKeys() + + +@pytest.fixture +def data_serializer(): + """Create data serializer""" + return DataSerializer() + + +@pytest.fixture +def sample_orderbook(): + """Create sample order book for testing""" + return OrderBookSnapshot( + symbol="BTCUSDT", + exchange="binance", + timestamp=datetime.now(timezone.utc), + bids=[ + PriceLevel(price=50000.0, size=1.5), + PriceLevel(price=49999.0, size=2.0) + ], + asks=[ + PriceLevel(price=50001.0, size=1.0), + PriceLevel(price=50002.0, size=1.5) + ] + ) + + +@pytest.fixture +def sample_heatmap(): + """Create sample heatmap for testing""" + heatmap = HeatmapData( + symbol="BTCUSDT", + timestamp=datetime.now(timezone.utc), + bucket_size=1.0 + ) + + # Add some sample points + heatmap.data = [ + HeatmapPoint(price=50000.0, volume=1.5, intensity=0.8, side='bid'), + HeatmapPoint(price=50001.0, volume=1.0, intensity=0.6, side='ask'), + HeatmapPoint(price=49999.0, volume=2.0, intensity=1.0, side='bid'), + HeatmapPoint(price=50002.0, volume=1.5, intensity=0.7, side='ask') + ] + + return heatmap + + +class TestCacheKeys: + """Test cases for CacheKeys""" + + def test_orderbook_key_generation(self, cache_keys): + """Test order book key generation""" + key = cache_keys.orderbook_key("BTCUSDT", "binance") + assert key == "ob:binance:BTCUSDT" + + def test_heatmap_key_generation(self, cache_keys): + """Test heatmap key generation""" + # Exchange-specific heatmap + key1 = cache_keys.heatmap_key("BTCUSDT", 1.0, "binance") + assert key1 == "hm:binance:BTCUSDT:1.0" + + # Consolidated heatmap + key2 = cache_keys.heatmap_key("BTCUSDT", 1.0) + assert key2 == "hm:consolidated:BTCUSDT:1.0" + + def test_ttl_determination(self, cache_keys): + """Test TTL 
determination for different key types""" + ob_key = cache_keys.orderbook_key("BTCUSDT", "binance") + hm_key = cache_keys.heatmap_key("BTCUSDT", 1.0) + + assert cache_keys.get_ttl(ob_key) == cache_keys.ORDERBOOK_TTL + assert cache_keys.get_ttl(hm_key) == cache_keys.HEATMAP_TTL + + def test_key_parsing(self, cache_keys): + """Test cache key parsing""" + ob_key = cache_keys.orderbook_key("BTCUSDT", "binance") + parsed = cache_keys.parse_key(ob_key) + + assert parsed['type'] == 'orderbook' + assert parsed['exchange'] == 'binance' + assert parsed['symbol'] == 'BTCUSDT' + + +class TestDataSerializer: + """Test cases for DataSerializer""" + + def test_simple_data_serialization(self, data_serializer): + """Test serialization of simple data types""" + test_data = { + 'string': 'test', + 'number': 42, + 'float': 3.14, + 'boolean': True, + 'list': [1, 2, 3], + 'nested': {'key': 'value'} + } + + # Serialize and deserialize + serialized = data_serializer.serialize(test_data) + deserialized = data_serializer.deserialize(serialized) + + assert deserialized == test_data + + def test_orderbook_serialization(self, data_serializer, sample_orderbook): + """Test order book serialization""" + # Serialize and deserialize + serialized = data_serializer.serialize(sample_orderbook) + deserialized = data_serializer.deserialize(serialized) + + assert isinstance(deserialized, OrderBookSnapshot) + assert deserialized.symbol == sample_orderbook.symbol + assert deserialized.exchange == sample_orderbook.exchange + assert len(deserialized.bids) == len(sample_orderbook.bids) + assert len(deserialized.asks) == len(sample_orderbook.asks) + + def test_heatmap_serialization(self, data_serializer, sample_heatmap): + """Test heatmap serialization""" + # Test specialized heatmap serialization + serialized = data_serializer.serialize_heatmap(sample_heatmap) + deserialized = data_serializer.deserialize_heatmap(serialized) + + assert isinstance(deserialized, HeatmapData) + assert deserialized.symbol == sample_heatmap.symbol + assert deserialized.bucket_size == sample_heatmap.bucket_size + assert len(deserialized.data) == len(sample_heatmap.data) + + # Check first point + original_point = sample_heatmap.data[0] + deserialized_point = deserialized.data[0] + assert deserialized_point.price == original_point.price + assert deserialized_point.volume == original_point.volume + assert deserialized_point.side == original_point.side + + +class TestRedisManager: + """Test cases for RedisManager""" + + @pytest.mark.asyncio + async def test_basic_set_get(self, redis_manager): + """Test basic set and get operations""" + # Set a simple value + key = "test:basic" + value = {"test": "data", "number": 42} + + success = await redis_manager.set(key, value, ttl=60) + assert success is True + + # Get the value back + retrieved = await redis_manager.get(key) + assert retrieved == value + + # Clean up + await redis_manager.delete(key) + + @pytest.mark.asyncio + async def test_orderbook_caching(self, redis_manager, sample_orderbook): + """Test order book caching""" + # Cache order book + success = await redis_manager.cache_orderbook(sample_orderbook) + assert success is True + + # Retrieve order book + retrieved = await redis_manager.get_orderbook( + sample_orderbook.symbol, + sample_orderbook.exchange + ) + + assert retrieved is not None + assert isinstance(retrieved, OrderBookSnapshot) + assert retrieved.symbol == sample_orderbook.symbol + assert retrieved.exchange == sample_orderbook.exchange + + @pytest.mark.asyncio + async def 
test_heatmap_caching(self, redis_manager, sample_heatmap): + """Test heatmap caching""" + # Cache heatmap + success = await redis_manager.set_heatmap( + sample_heatmap.symbol, + sample_heatmap, + exchange="binance" + ) + assert success is True + + # Retrieve heatmap + retrieved = await redis_manager.get_heatmap( + sample_heatmap.symbol, + exchange="binance" + ) + + assert retrieved is not None + assert isinstance(retrieved, HeatmapData) + assert retrieved.symbol == sample_heatmap.symbol + assert len(retrieved.data) == len(sample_heatmap.data) + + @pytest.mark.asyncio + async def test_multi_operations(self, redis_manager): + """Test multi-get and multi-set operations""" + # Prepare test data + test_data = { + "test:multi1": {"value": 1}, + "test:multi2": {"value": 2}, + "test:multi3": {"value": 3} + } + + # Multi-set + success = await redis_manager.mset(test_data, ttl=60) + assert success is True + + # Multi-get + keys = list(test_data.keys()) + values = await redis_manager.mget(keys) + + assert len(values) == 3 + assert all(v is not None for v in values) + + # Verify values + for i, key in enumerate(keys): + assert values[i] == test_data[key] + + # Clean up + for key in keys: + await redis_manager.delete(key) + + @pytest.mark.asyncio + async def test_key_expiration(self, redis_manager): + """Test key expiration""" + key = "test:expiration" + value = {"expires": "soon"} + + # Set with short TTL + success = await redis_manager.set(key, value, ttl=1) + assert success is True + + # Should exist immediately + exists = await redis_manager.exists(key) + assert exists is True + + # Wait for expiration + await asyncio.sleep(2) + + # Should not exist after expiration + exists = await redis_manager.exists(key) + assert exists is False + + @pytest.mark.asyncio + async def test_cache_miss(self, redis_manager): + """Test cache miss behavior""" + # Try to get non-existent key + value = await redis_manager.get("test:nonexistent") + assert value is None + + # Check statistics + stats = redis_manager.get_stats() + assert stats['misses'] > 0 + + @pytest.mark.asyncio + async def test_health_check(self, redis_manager): + """Test Redis health check""" + health = await redis_manager.health_check() + + assert isinstance(health, dict) + assert 'redis_ping' in health + assert 'total_keys' in health + assert 'hit_rate' in health + + # Should be able to ping + assert health['redis_ping'] is True + + @pytest.mark.asyncio + async def test_statistics_tracking(self, redis_manager): + """Test statistics tracking""" + # Reset stats + redis_manager.reset_stats() + + # Perform some operations + await redis_manager.set("test:stats1", {"data": 1}) + await redis_manager.set("test:stats2", {"data": 2}) + await redis_manager.get("test:stats1") + await redis_manager.get("test:nonexistent") + + # Check statistics + stats = redis_manager.get_stats() + + assert stats['sets'] >= 2 + assert stats['gets'] >= 2 + assert stats['hits'] >= 1 + assert stats['misses'] >= 1 + assert stats['total_operations'] >= 4 + + # Clean up + await redis_manager.delete("test:stats1") + await redis_manager.delete("test:stats2") + + +if __name__ == "__main__": + # Run simple tests + async def simple_test(): + manager = RedisManager() + await manager.initialize() + + # Test basic operations + success = await manager.set("test", {"simple": "test"}, ttl=60) + print(f"Set operation: {'SUCCESS' if success else 'FAILED'}") + + value = await manager.get("test") + print(f"Get operation: {'SUCCESS' if value else 'FAILED'}") + + # Test ping + ping_result = await 
manager.ping() + print(f"Ping test: {'SUCCESS' if ping_result else 'FAILED'}") + + # Get statistics + stats = manager.get_stats() + print(f"Statistics: {stats}") + + # Clean up + await manager.delete("test") + await manager.close() + + print("Simple Redis test completed") + + asyncio.run(simple_test()) \ No newline at end of file
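
Usage sketch for reviewers: a minimal, illustrative script showing how the pieces added by this patch fit together. It assumes COBY is importable as a top-level package, that config.redis points at a reachable Redis instance, and that the OrderBookSnapshot/PriceLevel/HeatmapData/HeatmapPoint constructors match the ones exercised in COBY/tests/test_redis_manager.py; it is not part of the patch.

import asyncio
from datetime import datetime, timezone

from COBY.caching import RedisManager                      # exported via COBY/caching/__init__.py
from COBY.models.core import (
    OrderBookSnapshot, PriceLevel, HeatmapData, HeatmapPoint
)


async def main():
    manager = RedisManager()
    await manager.initialize()                              # builds the pool and pings Redis

    # Cache an order book snapshot; the key ("ob:binance:BTCUSDT") and its 60 s
    # TTL come from CacheKeys, and the payload goes through DataSerializer.
    snapshot = OrderBookSnapshot(
        symbol="BTCUSDT",
        exchange="binance",
        timestamp=datetime.now(timezone.utc),
        bids=[PriceLevel(price=50000.0, size=1.5)],
        asks=[PriceLevel(price=50001.0, size=1.0)],
    )
    await manager.cache_orderbook(snapshot)
    cached = await manager.get_orderbook("BTCUSDT", "binance")
    print("order book hit:", cached is not None)

    # Heatmaps use the compact serialize_heatmap()/deserialize_heatmap() path.
    heatmap = HeatmapData(
        symbol="BTCUSDT",
        timestamp=datetime.now(timezone.utc),
        bucket_size=1.0,
    )
    heatmap.data = [HeatmapPoint(price=50000.0, volume=1.5, intensity=0.8, side="bid")]
    await manager.set_heatmap("BTCUSDT", heatmap, exchange="binance")
    print("heatmap hit:", await manager.get_heatmap("BTCUSDT", exchange="binance") is not None)

    # Operational visibility.
    print("hit rate %:", manager.get_stats()["hit_rate_percentage"])
    print("health:", await manager.health_check())

    await manager.close()


if __name__ == "__main__":
    asyncio.run(main())

Note that running COBY/tests/test_redis_manager.py requires a live Redis plus pytest-asyncio for the @pytest.mark.asyncio tests; depending on the pytest-asyncio mode, the async redis_manager fixture may also need to be declared with @pytest_asyncio.fixture rather than @pytest.fixture.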