""" Redis cache manager for high-performance data access. """ import asyncio import redis.asyncio as redis from typing import Any, Optional, List, Dict, Union from datetime import datetime, timedelta from ..config import config from ..utils.logging import get_logger, set_correlation_id from ..utils.exceptions import StorageError from ..utils.timing import get_current_timestamp from .cache_keys import CacheKeys from .data_serializer import DataSerializer logger = get_logger(__name__) class RedisManager: """ High-performance Redis cache manager for market data. Provides: - Connection pooling and management - Data serialization and compression - TTL management - Batch operations - Performance monitoring """ def __init__(self): """Initialize Redis manager""" self.redis_pool: Optional[redis.ConnectionPool] = None self.redis_client: Optional[redis.Redis] = None self.serializer = DataSerializer(use_compression=True) self.cache_keys = CacheKeys() # Performance statistics self.stats = { 'gets': 0, 'sets': 0, 'deletes': 0, 'hits': 0, 'misses': 0, 'errors': 0, 'total_data_size': 0, 'avg_response_time': 0.0 } logger.info("Redis manager initialized") async def initialize(self) -> None: """Initialize Redis connection pool""" try: # Create connection pool self.redis_pool = redis.ConnectionPool( host=config.redis.host, port=config.redis.port, password=config.redis.password, db=config.redis.db, max_connections=config.redis.max_connections, socket_timeout=config.redis.socket_timeout, socket_connect_timeout=config.redis.socket_connect_timeout, decode_responses=False, # We handle bytes directly retry_on_timeout=True, health_check_interval=30 ) # Create Redis client self.redis_client = redis.Redis(connection_pool=self.redis_pool) # Test connection await self.redis_client.ping() logger.info(f"Redis connection established: {config.redis.host}:{config.redis.port}") except Exception as e: logger.error(f"Failed to initialize Redis connection: {e}") raise StorageError(f"Redis initialization failed: {e}", "REDIS_INIT_ERROR") async def close(self) -> None: """Close Redis connections""" try: if self.redis_client: await self.redis_client.close() if self.redis_pool: await self.redis_pool.disconnect() logger.info("Redis connections closed") except Exception as e: logger.warning(f"Error closing Redis connections: {e}") async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool: """ Set value in cache with optional TTL. Args: key: Cache key value: Value to cache ttl: Time to live in seconds (None = use default) Returns: bool: True if successful, False otherwise """ try: set_correlation_id() start_time = asyncio.get_event_loop().time() # Serialize value serialized_value = self.serializer.serialize(value) # Determine TTL if ttl is None: ttl = self.cache_keys.get_ttl(key) # Set in Redis result = await self.redis_client.setex(key, ttl, serialized_value) # Update statistics self.stats['sets'] += 1 self.stats['total_data_size'] += len(serialized_value) # Update response time response_time = asyncio.get_event_loop().time() - start_time self._update_avg_response_time(response_time) logger.debug(f"Cached data: {key} (size: {len(serialized_value)} bytes, ttl: {ttl}s)") return bool(result) except Exception as e: self.stats['errors'] += 1 logger.error(f"Error setting cache key {key}: {e}") return False async def get(self, key: str) -> Optional[Any]: """ Get value from cache. Args: key: Cache key Returns: Any: Cached value or None if not found """ try: set_correlation_id() start_time = asyncio.get_event_loop().time() # Get from Redis serialized_value = await self.redis_client.get(key) # Update statistics self.stats['gets'] += 1 if serialized_value is None: self.stats['misses'] += 1 logger.debug(f"Cache miss: {key}") return None # Deserialize value value = self.serializer.deserialize(serialized_value) # Update statistics self.stats['hits'] += 1 # Update response time response_time = asyncio.get_event_loop().time() - start_time self._update_avg_response_time(response_time) logger.debug(f"Cache hit: {key} (size: {len(serialized_value)} bytes)") return value except Exception as e: self.stats['errors'] += 1 logger.error(f"Error getting cache key {key}: {e}") return None async def delete(self, key: str) -> bool: """ Delete key from cache. Args: key: Cache key to delete Returns: bool: True if deleted, False otherwise """ try: set_correlation_id() result = await self.redis_client.delete(key) self.stats['deletes'] += 1 logger.debug(f"Deleted cache key: {key}") return bool(result) except Exception as e: self.stats['errors'] += 1 logger.error(f"Error deleting cache key {key}: {e}") return False async def exists(self, key: str) -> bool: """ Check if key exists in cache. Args: key: Cache key to check Returns: bool: True if exists, False otherwise """ try: result = await self.redis_client.exists(key) return bool(result) except Exception as e: logger.error(f"Error checking cache key existence {key}: {e}") return False async def expire(self, key: str, ttl: int) -> bool: """ Set expiration time for key. Args: key: Cache key ttl: Time to live in seconds Returns: bool: True if successful, False otherwise """ try: result = await self.redis_client.expire(key, ttl) return bool(result) except Exception as e: logger.error(f"Error setting expiration for key {key}: {e}") return False async def mget(self, keys: List[str]) -> List[Optional[Any]]: """ Get multiple values from cache. Args: keys: List of cache keys Returns: List[Optional[Any]]: List of values (None for missing keys) """ try: set_correlation_id() start_time = asyncio.get_event_loop().time() # Get from Redis serialized_values = await self.redis_client.mget(keys) # Deserialize values values = [] for serialized_value in serialized_values: if serialized_value is None: values.append(None) self.stats['misses'] += 1 else: try: value = self.serializer.deserialize(serialized_value) values.append(value) self.stats['hits'] += 1 except Exception as e: logger.warning(f"Error deserializing value: {e}") values.append(None) self.stats['errors'] += 1 # Update statistics self.stats['gets'] += len(keys) # Update response time response_time = asyncio.get_event_loop().time() - start_time self._update_avg_response_time(response_time) logger.debug(f"Multi-get: {len(keys)} keys, {sum(1 for v in values if v is not None)} hits") return values except Exception as e: self.stats['errors'] += 1 logger.error(f"Error in multi-get: {e}") return [None] * len(keys) async def mset(self, key_value_pairs: Dict[str, Any], ttl: Optional[int] = None) -> bool: """ Set multiple key-value pairs. Args: key_value_pairs: Dictionary of key-value pairs ttl: Time to live in seconds (None = use default per key) Returns: bool: True if successful, False otherwise """ try: set_correlation_id() # Serialize all values serialized_pairs = {} for key, value in key_value_pairs.items(): serialized_value = self.serializer.serialize(value) serialized_pairs[key] = serialized_value self.stats['total_data_size'] += len(serialized_value) # Set in Redis result = await self.redis_client.mset(serialized_pairs) # Set TTL for each key if specified if ttl is not None: for key in key_value_pairs.keys(): await self.redis_client.expire(key, ttl) else: # Use individual TTLs for key in key_value_pairs.keys(): key_ttl = self.cache_keys.get_ttl(key) await self.redis_client.expire(key, key_ttl) self.stats['sets'] += len(key_value_pairs) logger.debug(f"Multi-set: {len(key_value_pairs)} keys") return bool(result) except Exception as e: self.stats['errors'] += 1 logger.error(f"Error in multi-set: {e}") return False async def keys(self, pattern: str) -> List[str]: """ Get keys matching pattern. Args: pattern: Redis pattern (e.g., "hm:*") Returns: List[str]: List of matching keys """ try: keys = await self.redis_client.keys(pattern) return [key.decode('utf-8') if isinstance(key, bytes) else key for key in keys] except Exception as e: logger.error(f"Error getting keys with pattern {pattern}: {e}") return [] async def flushdb(self) -> bool: """ Clear all keys in current database. Returns: bool: True if successful, False otherwise """ try: result = await self.redis_client.flushdb() logger.info("Redis database flushed") return bool(result) except Exception as e: logger.error(f"Error flushing Redis database: {e}") return False async def info(self) -> Dict[str, Any]: """ Get Redis server information. Returns: Dict: Redis server info """ try: info = await self.redis_client.info() return info except Exception as e: logger.error(f"Error getting Redis info: {e}") return {} async def ping(self) -> bool: """ Ping Redis server. Returns: bool: True if server responds, False otherwise """ try: result = await self.redis_client.ping() return bool(result) except Exception as e: logger.error(f"Redis ping failed: {e}") return False async def set_heatmap(self, symbol: str, heatmap_data, exchange: Optional[str] = None, ttl: Optional[int] = None) -> bool: """ Cache heatmap data with optimized serialization. Args: symbol: Trading symbol heatmap_data: Heatmap data to cache exchange: Exchange name (None for consolidated) ttl: Time to live in seconds Returns: bool: True if successful, False otherwise """ try: key = self.cache_keys.heatmap_key(symbol, 1.0, exchange) # Use specialized heatmap serialization serialized_value = self.serializer.serialize_heatmap(heatmap_data) # Determine TTL if ttl is None: ttl = self.cache_keys.HEATMAP_TTL # Set in Redis result = await self.redis_client.setex(key, ttl, serialized_value) # Update statistics self.stats['sets'] += 1 self.stats['total_data_size'] += len(serialized_value) logger.debug(f"Cached heatmap: {key} (size: {len(serialized_value)} bytes)") return bool(result) except Exception as e: self.stats['errors'] += 1 logger.error(f"Error caching heatmap for {symbol}: {e}") return False async def get_heatmap(self, symbol: str, exchange: Optional[str] = None): """ Get cached heatmap data with optimized deserialization. Args: symbol: Trading symbol exchange: Exchange name (None for consolidated) Returns: HeatmapData: Cached heatmap or None if not found """ try: key = self.cache_keys.heatmap_key(symbol, 1.0, exchange) # Get from Redis serialized_value = await self.redis_client.get(key) self.stats['gets'] += 1 if serialized_value is None: self.stats['misses'] += 1 return None # Use specialized heatmap deserialization heatmap_data = self.serializer.deserialize_heatmap(serialized_value) self.stats['hits'] += 1 logger.debug(f"Retrieved heatmap: {key}") return heatmap_data except Exception as e: self.stats['errors'] += 1 logger.error(f"Error retrieving heatmap for {symbol}: {e}") return None async def cache_orderbook(self, orderbook) -> bool: """ Cache order book data. Args: orderbook: OrderBookSnapshot to cache Returns: bool: True if successful, False otherwise """ try: key = self.cache_keys.orderbook_key(orderbook.symbol, orderbook.exchange) return await self.set(key, orderbook) except Exception as e: logger.error(f"Error caching order book: {e}") return False async def get_orderbook(self, symbol: str, exchange: str): """ Get cached order book data. Args: symbol: Trading symbol exchange: Exchange name Returns: OrderBookSnapshot: Cached order book or None if not found """ try: key = self.cache_keys.orderbook_key(symbol, exchange) return await self.get(key) except Exception as e: logger.error(f"Error retrieving order book: {e}") return None async def cache_metrics(self, metrics, symbol: str, exchange: str) -> bool: """ Cache metrics data. Args: metrics: Metrics data to cache symbol: Trading symbol exchange: Exchange name Returns: bool: True if successful, False otherwise """ try: key = self.cache_keys.metrics_key(symbol, exchange) return await self.set(key, metrics) except Exception as e: logger.error(f"Error caching metrics: {e}") return False async def get_metrics(self, symbol: str, exchange: str): """ Get cached metrics data. Args: symbol: Trading symbol exchange: Exchange name Returns: Metrics data or None if not found """ try: key = self.cache_keys.metrics_key(symbol, exchange) return await self.get(key) except Exception as e: logger.error(f"Error retrieving metrics: {e}") return None async def cache_exchange_status(self, exchange: str, status_data) -> bool: """ Cache exchange status. Args: exchange: Exchange name status_data: Status data to cache Returns: bool: True if successful, False otherwise """ try: key = self.cache_keys.status_key(exchange) return await self.set(key, status_data) except Exception as e: logger.error(f"Error caching exchange status: {e}") return False async def get_exchange_status(self, exchange: str): """ Get cached exchange status. Args: exchange: Exchange name Returns: Status data or None if not found """ try: key = self.cache_keys.status_key(exchange) return await self.get(key) except Exception as e: logger.error(f"Error retrieving exchange status: {e}") return None async def cleanup_expired_keys(self) -> int: """ Clean up expired keys (Redis handles this automatically, but we can force it). Returns: int: Number of keys cleaned up """ try: # Get all keys all_keys = await self.keys("*") # Check which ones are expired expired_count = 0 for key in all_keys: ttl = await self.redis_client.ttl(key) if ttl == -2: # Key doesn't exist (expired) expired_count += 1 logger.debug(f"Found {expired_count} expired keys") return expired_count except Exception as e: logger.error(f"Error cleaning up expired keys: {e}") return 0 def _update_avg_response_time(self, response_time: float) -> None: """Update average response time""" total_operations = self.stats['gets'] + self.stats['sets'] if total_operations > 0: self.stats['avg_response_time'] = ( (self.stats['avg_response_time'] * (total_operations - 1) + response_time) / total_operations ) def get_stats(self) -> Dict[str, Any]: """Get cache statistics""" total_operations = self.stats['gets'] + self.stats['sets'] hit_rate = (self.stats['hits'] / max(self.stats['gets'], 1)) * 100 return { **self.stats, 'total_operations': total_operations, 'hit_rate_percentage': hit_rate, 'serializer_stats': self.serializer.get_stats() } def reset_stats(self) -> None: """Reset cache statistics""" self.stats = { 'gets': 0, 'sets': 0, 'deletes': 0, 'hits': 0, 'misses': 0, 'errors': 0, 'total_data_size': 0, 'avg_response_time': 0.0 } self.serializer.reset_stats() logger.info("Redis manager statistics reset") async def health_check(self) -> Dict[str, Any]: """ Perform comprehensive health check. Returns: Dict: Health check results """ health = { 'redis_ping': False, 'connection_pool_size': 0, 'memory_usage': 0, 'connected_clients': 0, 'total_keys': 0, 'hit_rate': 0.0, 'avg_response_time': self.stats['avg_response_time'] } try: # Test ping health['redis_ping'] = await self.ping() # Get Redis info info = await self.info() if info: health['memory_usage'] = info.get('used_memory', 0) health['connected_clients'] = info.get('connected_clients', 0) # Get key count all_keys = await self.keys("*") health['total_keys'] = len(all_keys) # Calculate hit rate if self.stats['gets'] > 0: health['hit_rate'] = (self.stats['hits'] / self.stats['gets']) * 100 # Connection pool info if self.redis_pool: health['connection_pool_size'] = self.redis_pool.max_connections except Exception as e: logger.error(f"Health check error: {e}") return health # Global Redis manager instance redis_manager = RedisManager()