gogo2/COBY/caching/redis_manager.py
"""
Redis cache manager for high-performance data access.
"""
import asyncio
import redis.asyncio as redis
from typing import Any, Optional, List, Dict, Union
from datetime import datetime, timedelta
from ..config import config
from ..utils.logging import get_logger, set_correlation_id
from ..utils.exceptions import StorageError
from ..utils.timing import get_current_timestamp
from .cache_keys import CacheKeys
from .data_serializer import DataSerializer
logger = get_logger(__name__)
class RedisManager:
"""
High-performance Redis cache manager for market data.
Provides:
- Connection pooling and management
- Data serialization and compression
- TTL management
- Batch operations
- Performance monitoring
"""
def __init__(self):
"""Initialize Redis manager"""
self.redis_pool: Optional[redis.ConnectionPool] = None
self.redis_client: Optional[redis.Redis] = None
self.serializer = DataSerializer(use_compression=True)
self.cache_keys = CacheKeys()
# Performance statistics
self.stats = {
'gets': 0,
'sets': 0,
'deletes': 0,
'hits': 0,
'misses': 0,
'errors': 0,
'total_data_size': 0,
'avg_response_time': 0.0
}
logger.info("Redis manager initialized")
async def initialize(self) -> None:
"""Initialize Redis connection pool"""
try:
# Create connection pool
self.redis_pool = redis.ConnectionPool(
host=config.redis.host,
port=config.redis.port,
password=config.redis.password,
db=config.redis.db,
max_connections=config.redis.max_connections,
socket_timeout=config.redis.socket_timeout,
socket_connect_timeout=config.redis.socket_connect_timeout,
decode_responses=False, # We handle bytes directly
retry_on_timeout=True,
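# Connections left idle longer than this many seconds are health-checked (PING) before reuse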
health_check_interval=30
)
# Create Redis client
self.redis_client = redis.Redis(connection_pool=self.redis_pool)
# Test connection
await self.redis_client.ping()
logger.info(f"Redis connection established: {config.redis.host}:{config.redis.port}")
except Exception as e:
logger.error(f"Failed to initialize Redis connection: {e}")
raise StorageError(f"Redis initialization failed: {e}", "REDIS_INIT_ERROR")
async def close(self) -> None:
"""Close Redis connections"""
try:
if self.redis_client:
await self.redis_client.close()
if self.redis_pool:
await self.redis_pool.disconnect()
logger.info("Redis connections closed")
except Exception as e:
logger.warning(f"Error closing Redis connections: {e}")
async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
"""
Set value in cache with optional TTL.
Args:
key: Cache key
value: Value to cache
ttl: Time to live in seconds (None = use default)
Returns:
bool: True if successful, False otherwise
"""
try:
set_correlation_id()
start_time = asyncio.get_event_loop().time()
# Serialize value
serialized_value = self.serializer.serialize(value)
# Determine TTL
if ttl is None:
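# Fall back to the key-type default; get_ttl() presumably derives it from the key's prefix in CacheKeys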
ttl = self.cache_keys.get_ttl(key)
# Set in Redis
result = await self.redis_client.setex(key, ttl, serialized_value)
# Update statistics
self.stats['sets'] += 1
self.stats['total_data_size'] += len(serialized_value)
# Update response time
response_time = asyncio.get_event_loop().time() - start_time
self._update_avg_response_time(response_time)
logger.debug(f"Cached data: {key} (size: {len(serialized_value)} bytes, ttl: {ttl}s)")
return bool(result)
except Exception as e:
self.stats['errors'] += 1
logger.error(f"Error setting cache key {key}: {e}")
return False
async def get(self, key: str) -> Optional[Any]:
"""
Get value from cache.
Args:
key: Cache key
Returns:
Any: Cached value or None if not found
"""
try:
set_correlation_id()
start_time = asyncio.get_event_loop().time()
# Get from Redis
serialized_value = await self.redis_client.get(key)
# Update statistics
self.stats['gets'] += 1
if serialized_value is None:
self.stats['misses'] += 1
logger.debug(f"Cache miss: {key}")
return None
# Deserialize value
value = self.serializer.deserialize(serialized_value)
# Update statistics
self.stats['hits'] += 1
# Update response time
response_time = asyncio.get_event_loop().time() - start_time
self._update_avg_response_time(response_time)
logger.debug(f"Cache hit: {key} (size: {len(serialized_value)} bytes)")
return value
except Exception as e:
self.stats['errors'] += 1
logger.error(f"Error getting cache key {key}: {e}")
return None
async def delete(self, key: str) -> bool:
"""
Delete key from cache.
Args:
key: Cache key to delete
Returns:
bool: True if deleted, False otherwise
"""
try:
set_correlation_id()
result = await self.redis_client.delete(key)
self.stats['deletes'] += 1
logger.debug(f"Deleted cache key: {key}")
return bool(result)
except Exception as e:
self.stats['errors'] += 1
logger.error(f"Error deleting cache key {key}: {e}")
return False
async def exists(self, key: str) -> bool:
"""
Check if key exists in cache.
Args:
key: Cache key to check
Returns:
bool: True if exists, False otherwise
"""
try:
result = await self.redis_client.exists(key)
return bool(result)
except Exception as e:
logger.error(f"Error checking cache key existence {key}: {e}")
return False
async def expire(self, key: str, ttl: int) -> bool:
"""
Set expiration time for key.
Args:
key: Cache key
ttl: Time to live in seconds
Returns:
bool: True if successful, False otherwise
"""
try:
result = await self.redis_client.expire(key, ttl)
return bool(result)
except Exception as e:
logger.error(f"Error setting expiration for key {key}: {e}")
return False
async def mget(self, keys: List[str]) -> List[Optional[Any]]:
"""
Get multiple values from cache.
Args:
keys: List of cache keys
Returns:
List[Optional[Any]]: List of values (None for missing keys)
"""
try:
set_correlation_id()
start_time = asyncio.get_event_loop().time()
# Get from Redis
serialized_values = await self.redis_client.mget(keys)
# Deserialize values
values = []
for serialized_value in serialized_values:
if serialized_value is None:
values.append(None)
self.stats['misses'] += 1
else:
try:
value = self.serializer.deserialize(serialized_value)
values.append(value)
self.stats['hits'] += 1
except Exception as e:
logger.warning(f"Error deserializing value: {e}")
values.append(None)
self.stats['errors'] += 1
# Update statistics
self.stats['gets'] += len(keys)
# Update response time
response_time = asyncio.get_event_loop().time() - start_time
self._update_avg_response_time(response_time)
logger.debug(f"Multi-get: {len(keys)} keys, {sum(1 for v in values if v is not None)} hits")
return values
except Exception as e:
self.stats['errors'] += 1
logger.error(f"Error in multi-get: {e}")
return [None] * len(keys)
async def mset(self, key_value_pairs: Dict[str, Any], ttl: Optional[int] = None) -> bool:
"""
Set multiple key-value pairs.
Args:
key_value_pairs: Dictionary of key-value pairs
ttl: Time to live in seconds (None = use default per key)
Returns:
bool: True if successful, False otherwise
"""
try:
set_correlation_id()
# Serialize all values
serialized_pairs = {}
for key, value in key_value_pairs.items():
serialized_value = self.serializer.serialize(value)
serialized_pairs[key] = serialized_value
self.stats['total_data_size'] += len(serialized_value)
# Set in Redis
result = await self.redis_client.mset(serialized_pairs)
# Set TTL for each key if specified
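# Note: MSET cannot set TTLs, so expirations are applied one key at a time below;
# this is not atomic with the MSET and could be batched in a pipeline to cut round trips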
if ttl is not None:
for key in key_value_pairs.keys():
await self.redis_client.expire(key, ttl)
else:
# Use individual TTLs
for key in key_value_pairs.keys():
key_ttl = self.cache_keys.get_ttl(key)
await self.redis_client.expire(key, key_ttl)
self.stats['sets'] += len(key_value_pairs)
logger.debug(f"Multi-set: {len(key_value_pairs)} keys")
return bool(result)
except Exception as e:
self.stats['errors'] += 1
logger.error(f"Error in multi-set: {e}")
return False
async def keys(self, pattern: str) -> List[str]:
"""
Get keys matching pattern.
Args:
pattern: Redis pattern (e.g., "hm:*")
Returns:
List[str]: List of matching keys
"""
try:
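# KEYS is O(N) and blocks the server while it scans; SCAN is safer for large keyspaces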
keys = await self.redis_client.keys(pattern)
return [key.decode('utf-8') if isinstance(key, bytes) else key for key in keys]
except Exception as e:
logger.error(f"Error getting keys with pattern {pattern}: {e}")
return []
async def flushdb(self) -> bool:
"""
Clear all keys in current database.
Returns:
bool: True if successful, False otherwise
"""
try:
result = await self.redis_client.flushdb()
logger.info("Redis database flushed")
return bool(result)
except Exception as e:
logger.error(f"Error flushing Redis database: {e}")
return False
async def info(self) -> Dict[str, Any]:
"""
Get Redis server information.
Returns:
Dict: Redis server info
"""
try:
info = await self.redis_client.info()
return info
except Exception as e:
logger.error(f"Error getting Redis info: {e}")
return {}
async def ping(self) -> bool:
"""
Ping Redis server.
Returns:
bool: True if server responds, False otherwise
"""
try:
result = await self.redis_client.ping()
return bool(result)
except Exception as e:
logger.error(f"Redis ping failed: {e}")
return False
async def set_heatmap(self, symbol: str, heatmap_data,
exchange: Optional[str] = None, ttl: Optional[int] = None) -> bool:
"""
Cache heatmap data with optimized serialization.
Args:
symbol: Trading symbol
heatmap_data: Heatmap data to cache
exchange: Exchange name (None for consolidated)
ttl: Time to live in seconds
Returns:
bool: True if successful, False otherwise
"""
try:
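# The 1.0 presumably selects the default price-bucket size encoded into the heatmap key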
key = self.cache_keys.heatmap_key(symbol, 1.0, exchange)
# Use specialized heatmap serialization
serialized_value = self.serializer.serialize_heatmap(heatmap_data)
# Determine TTL
if ttl is None:
ttl = self.cache_keys.HEATMAP_TTL
# Set in Redis
result = await self.redis_client.setex(key, ttl, serialized_value)
# Update statistics
self.stats['sets'] += 1
self.stats['total_data_size'] += len(serialized_value)
logger.debug(f"Cached heatmap: {key} (size: {len(serialized_value)} bytes)")
return bool(result)
except Exception as e:
self.stats['errors'] += 1
logger.error(f"Error caching heatmap for {symbol}: {e}")
return False
async def get_heatmap(self, symbol: str, exchange: Optional[str] = None):
"""
Get cached heatmap data with optimized deserialization.
Args:
symbol: Trading symbol
exchange: Exchange name (None for consolidated)
Returns:
HeatmapData: Cached heatmap or None if not found
"""
try:
key = self.cache_keys.heatmap_key(symbol, 1.0, exchange)
# Get from Redis
serialized_value = await self.redis_client.get(key)
self.stats['gets'] += 1
if serialized_value is None:
self.stats['misses'] += 1
return None
# Use specialized heatmap deserialization
heatmap_data = self.serializer.deserialize_heatmap(serialized_value)
self.stats['hits'] += 1
logger.debug(f"Retrieved heatmap: {key}")
return heatmap_data
except Exception as e:
self.stats['errors'] += 1
logger.error(f"Error retrieving heatmap for {symbol}: {e}")
return None
async def cache_orderbook(self, orderbook) -> bool:
"""
Cache order book data.
Args:
orderbook: OrderBookSnapshot to cache
Returns:
bool: True if successful, False otherwise
"""
try:
key = self.cache_keys.orderbook_key(orderbook.symbol, orderbook.exchange)
return await self.set(key, orderbook)
except Exception as e:
logger.error(f"Error caching order book: {e}")
return False
async def get_orderbook(self, symbol: str, exchange: str):
"""
Get cached order book data.
Args:
symbol: Trading symbol
exchange: Exchange name
Returns:
OrderBookSnapshot: Cached order book or None if not found
"""
try:
key = self.cache_keys.orderbook_key(symbol, exchange)
return await self.get(key)
except Exception as e:
logger.error(f"Error retrieving order book: {e}")
return None
async def cache_metrics(self, metrics, symbol: str, exchange: str) -> bool:
"""
Cache metrics data.
Args:
metrics: Metrics data to cache
symbol: Trading symbol
exchange: Exchange name
Returns:
bool: True if successful, False otherwise
"""
try:
key = self.cache_keys.metrics_key(symbol, exchange)
return await self.set(key, metrics)
except Exception as e:
logger.error(f"Error caching metrics: {e}")
return False
async def get_metrics(self, symbol: str, exchange: str):
"""
Get cached metrics data.
Args:
symbol: Trading symbol
exchange: Exchange name
Returns:
Metrics data or None if not found
"""
try:
key = self.cache_keys.metrics_key(symbol, exchange)
return await self.get(key)
except Exception as e:
logger.error(f"Error retrieving metrics: {e}")
return None
async def cache_exchange_status(self, exchange: str, status_data) -> bool:
"""
Cache exchange status.
Args:
exchange: Exchange name
status_data: Status data to cache
Returns:
bool: True if successful, False otherwise
"""
try:
key = self.cache_keys.status_key(exchange)
return await self.set(key, status_data)
except Exception as e:
logger.error(f"Error caching exchange status: {e}")
return False
async def get_exchange_status(self, exchange: str):
"""
Get cached exchange status.
Args:
exchange: Exchange name
Returns:
Status data or None if not found
"""
try:
key = self.cache_keys.status_key(exchange)
return await self.get(key)
except Exception as e:
logger.error(f"Error retrieving exchange status: {e}")
return None
async def cleanup_expired_keys(self) -> int:
"""
Count keys that have already expired. Redis evicts expired keys automatically,
so this method only reports them rather than deleting anything itself.
Returns:
int: Number of expired keys found
"""
try:
# Get all keys
all_keys = await self.keys("*")
# Check which ones are expired
expired_count = 0
for key in all_keys:
ttl = await self.redis_client.ttl(key)
if ttl == -2:  # -2 means the key no longer exists (already expired or deleted)
expired_count += 1
logger.debug(f"Found {expired_count} expired keys")
return expired_count
except Exception as e:
logger.error(f"Error cleaning up expired keys: {e}")
return 0
def _update_avg_response_time(self, response_time: float) -> None:
"""Update average response time"""
total_operations = self.stats['gets'] + self.stats['sets']
if total_operations > 0:
self.stats['avg_response_time'] = (
(self.stats['avg_response_time'] * (total_operations - 1) + response_time) /
total_operations
)
def get_stats(self) -> Dict[str, Any]:
"""Get cache statistics"""
total_operations = self.stats['gets'] + self.stats['sets']
hit_rate = (self.stats['hits'] / max(self.stats['gets'], 1)) * 100
return {
**self.stats,
'total_operations': total_operations,
'hit_rate_percentage': hit_rate,
'serializer_stats': self.serializer.get_stats()
}
def reset_stats(self) -> None:
"""Reset cache statistics"""
self.stats = {
'gets': 0,
'sets': 0,
'deletes': 0,
'hits': 0,
'misses': 0,
'errors': 0,
'total_data_size': 0,
'avg_response_time': 0.0
}
self.serializer.reset_stats()
logger.info("Redis manager statistics reset")
async def health_check(self) -> Dict[str, Any]:
"""
Perform comprehensive health check.
Returns:
Dict: Health check results
"""
health = {
'redis_ping': False,
'connection_pool_size': 0,
'memory_usage': 0,
'connected_clients': 0,
'total_keys': 0,
'hit_rate': 0.0,
'avg_response_time': self.stats['avg_response_time']
}
try:
# Test ping
health['redis_ping'] = await self.ping()
# Get Redis info
info = await self.info()
if info:
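# Redis reports used_memory in bytes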
health['memory_usage'] = info.get('used_memory', 0)
health['connected_clients'] = info.get('connected_clients', 0)
# Get key count
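# Note: KEYS * walks the entire keyspace; DBSIZE would return the count more cheaply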
all_keys = await self.keys("*")
health['total_keys'] = len(all_keys)
# Calculate hit rate
if self.stats['gets'] > 0:
health['hit_rate'] = (self.stats['hits'] / self.stats['gets']) * 100
# Connection pool info
if self.redis_pool:
health['connection_pool_size'] = self.redis_pool.max_connections
except Exception as e:
logger.error(f"Health check error: {e}")
return health
# Global Redis manager instance
redis_manager = RedisManager()
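# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): shows the typical
# initialize -> set/get -> stats -> close lifecycle of the global manager.
# Assumptions: a Redis instance is reachable with the values in config.redis,
# and the module is run as a package module (e.g. python -m COBY.caching.redis_manager),
# since the relative imports above prevent running this file directly.
# ---------------------------------------------------------------------------
async def _demo() -> None:
    """Illustrative round-trip through the cache manager."""
    await redis_manager.initialize()
    try:
        # Cache an arbitrary payload for 60 seconds and read it back
        await redis_manager.set("demo:ticker:BTCUSDT", {"bid": 50000.0, "ask": 50001.0}, ttl=60)
        value = await redis_manager.get("demo:ticker:BTCUSDT")
        print(f"cached value: {value}")
        print(f"stats: {redis_manager.get_stats()}")
    finally:
        await redis_manager.close()


if __name__ == "__main__":
    asyncio.run(_demo())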