"""
|
|
Redis cache manager for high-performance data access.
|
|
"""
|
|
|
|
import asyncio
|
|
import redis.asyncio as redis
|
|
from typing import Any, Optional, List, Dict, Union
|
|
from datetime import datetime, timedelta
|
|
from ..config import config
|
|
from ..utils.logging import get_logger, set_correlation_id
|
|
from ..utils.exceptions import StorageError
|
|
from ..utils.timing import get_current_timestamp
|
|
from .cache_keys import CacheKeys
|
|
from .data_serializer import DataSerializer
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
class RedisManager:
    """
    High-performance Redis cache manager for market data.

    Provides:
    - Connection pooling and management
    - Data serialization and compression
    - TTL management
    - Batch operations
    - Performance monitoring
    """

    def __init__(self):
        """Initialize Redis manager"""
        self.redis_pool: Optional[redis.ConnectionPool] = None
        self.redis_client: Optional[redis.Redis] = None
        self.serializer = DataSerializer(use_compression=True)
        self.cache_keys = CacheKeys()

        # Performance statistics
        self.stats = {
            'gets': 0,
            'sets': 0,
            'deletes': 0,
            'hits': 0,
            'misses': 0,
            'errors': 0,
            'total_data_size': 0,
            'avg_response_time': 0.0
        }

        logger.info("Redis manager initialized")

    async def initialize(self) -> None:
        """Initialize Redis connection pool"""
        try:
            # Create connection pool
            self.redis_pool = redis.ConnectionPool(
                host=config.redis.host,
                port=config.redis.port,
                password=config.redis.password,
                db=config.redis.db,
                max_connections=config.redis.max_connections,
                socket_timeout=config.redis.socket_timeout,
                socket_connect_timeout=config.redis.socket_connect_timeout,
                decode_responses=False,  # We handle bytes directly
                retry_on_timeout=True,
                health_check_interval=30
            )

            # Create Redis client
            self.redis_client = redis.Redis(connection_pool=self.redis_pool)

            # Test connection
            await self.redis_client.ping()

            logger.info(f"Redis connection established: {config.redis.host}:{config.redis.port}")

        except Exception as e:
            logger.error(f"Failed to initialize Redis connection: {e}")
            raise StorageError(f"Redis initialization failed: {e}", "REDIS_INIT_ERROR")

    async def close(self) -> None:
        """Close Redis connections"""
        try:
            if self.redis_client:
                await self.redis_client.close()
            if self.redis_pool:
                await self.redis_pool.disconnect()

            logger.info("Redis connections closed")

        except Exception as e:
            logger.warning(f"Error closing Redis connections: {e}")

    async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
        """
        Set value in cache with optional TTL.

        Args:
            key: Cache key
            value: Value to cache
            ttl: Time to live in seconds (None = use default)

        Returns:
            bool: True if successful, False otherwise
        """
        try:
            set_correlation_id()
            start_time = asyncio.get_event_loop().time()

            # Serialize value
            serialized_value = self.serializer.serialize(value)

            # Determine TTL
            if ttl is None:
                ttl = self.cache_keys.get_ttl(key)

            # Set in Redis
            result = await self.redis_client.setex(key, ttl, serialized_value)

            # Update statistics
            self.stats['sets'] += 1
            self.stats['total_data_size'] += len(serialized_value)

            # Update response time
            response_time = asyncio.get_event_loop().time() - start_time
            self._update_avg_response_time(response_time)

            logger.debug(f"Cached data: {key} (size: {len(serialized_value)} bytes, ttl: {ttl}s)")
            return bool(result)

        except Exception as e:
            self.stats['errors'] += 1
            logger.error(f"Error setting cache key {key}: {e}")
            return False

    async def get(self, key: str) -> Optional[Any]:
        """
        Get value from cache.

        Args:
            key: Cache key

        Returns:
            Any: Cached value or None if not found
        """
        try:
            set_correlation_id()
            start_time = asyncio.get_event_loop().time()

            # Get from Redis
            serialized_value = await self.redis_client.get(key)

            # Update statistics
            self.stats['gets'] += 1

            if serialized_value is None:
                self.stats['misses'] += 1
                logger.debug(f"Cache miss: {key}")
                return None

            # Deserialize value
            value = self.serializer.deserialize(serialized_value)

            # Update statistics
            self.stats['hits'] += 1

            # Update response time
            response_time = asyncio.get_event_loop().time() - start_time
            self._update_avg_response_time(response_time)

            logger.debug(f"Cache hit: {key} (size: {len(serialized_value)} bytes)")
            return value

        except Exception as e:
            self.stats['errors'] += 1
            logger.error(f"Error getting cache key {key}: {e}")
            return None

    async def delete(self, key: str) -> bool:
        """
        Delete key from cache.

        Args:
            key: Cache key to delete

        Returns:
            bool: True if deleted, False otherwise
        """
        try:
            set_correlation_id()

            result = await self.redis_client.delete(key)

            self.stats['deletes'] += 1

            logger.debug(f"Deleted cache key: {key}")
            return bool(result)

        except Exception as e:
            self.stats['errors'] += 1
            logger.error(f"Error deleting cache key {key}: {e}")
            return False

    async def exists(self, key: str) -> bool:
        """
        Check if key exists in cache.

        Args:
            key: Cache key to check

        Returns:
            bool: True if exists, False otherwise
        """
        try:
            result = await self.redis_client.exists(key)
            return bool(result)

        except Exception as e:
            logger.error(f"Error checking cache key existence {key}: {e}")
            return False

    async def expire(self, key: str, ttl: int) -> bool:
        """
        Set expiration time for key.

        Args:
            key: Cache key
            ttl: Time to live in seconds

        Returns:
            bool: True if successful, False otherwise
        """
        try:
            result = await self.redis_client.expire(key, ttl)
            return bool(result)

        except Exception as e:
            logger.error(f"Error setting expiration for key {key}: {e}")
            return False

    async def mget(self, keys: List[str]) -> List[Optional[Any]]:
        """
        Get multiple values from cache.

        Args:
            keys: List of cache keys

        Returns:
            List[Optional[Any]]: List of values (None for missing keys)
        """
        try:
            set_correlation_id()
            start_time = asyncio.get_event_loop().time()

            # Get from Redis
            serialized_values = await self.redis_client.mget(keys)

            # Deserialize values
            values = []
            for serialized_value in serialized_values:
                if serialized_value is None:
                    values.append(None)
                    self.stats['misses'] += 1
                else:
                    try:
                        value = self.serializer.deserialize(serialized_value)
                        values.append(value)
                        self.stats['hits'] += 1
                    except Exception as e:
                        logger.warning(f"Error deserializing value: {e}")
                        values.append(None)
                        self.stats['errors'] += 1

            # Update statistics
            self.stats['gets'] += len(keys)

            # Update response time
            response_time = asyncio.get_event_loop().time() - start_time
            self._update_avg_response_time(response_time)

            logger.debug(f"Multi-get: {len(keys)} keys, {sum(1 for v in values if v is not None)} hits")
            return values

        except Exception as e:
            self.stats['errors'] += 1
            logger.error(f"Error in multi-get: {e}")
            return [None] * len(keys)

    async def mset(self, key_value_pairs: Dict[str, Any], ttl: Optional[int] = None) -> bool:
        """
        Set multiple key-value pairs.

        Args:
            key_value_pairs: Dictionary of key-value pairs
            ttl: Time to live in seconds (None = use default per key)

        Returns:
            bool: True if successful, False otherwise
        """
        try:
            set_correlation_id()

            # Serialize all values
            serialized_pairs = {}
            for key, value in key_value_pairs.items():
                serialized_value = self.serializer.serialize(value)
                serialized_pairs[key] = serialized_value
                self.stats['total_data_size'] += len(serialized_value)

            # Set in Redis
            result = await self.redis_client.mset(serialized_pairs)

            # MSET has no TTL support, so expirations are applied with follow-up
            # EXPIRE calls; these are separate commands and not atomic with the MSET.
            if ttl is not None:
                # Apply the same TTL to every key
                for key in key_value_pairs.keys():
                    await self.redis_client.expire(key, ttl)
            else:
                # Use the individual default TTL for each key
                for key in key_value_pairs.keys():
                    key_ttl = self.cache_keys.get_ttl(key)
                    await self.redis_client.expire(key, key_ttl)

            self.stats['sets'] += len(key_value_pairs)

            logger.debug(f"Multi-set: {len(key_value_pairs)} keys")
            return bool(result)

        except Exception as e:
            self.stats['errors'] += 1
            logger.error(f"Error in multi-set: {e}")
            return False

    async def keys(self, pattern: str) -> List[str]:
        """
        Get keys matching pattern.

        Args:
            pattern: Redis pattern (e.g., "hm:*")

        Returns:
            List[str]: List of matching keys
        """
        try:
            # KEYS is O(N) and blocks the server while it scans the keyspace;
            # for large production databases an incremental SCAN is preferable.
            keys = await self.redis_client.keys(pattern)
            return [key.decode('utf-8') if isinstance(key, bytes) else key for key in keys]

        except Exception as e:
            logger.error(f"Error getting keys with pattern {pattern}: {e}")
            return []

    async def flushdb(self) -> bool:
        """
        Clear all keys in current database.

        Returns:
            bool: True if successful, False otherwise
        """
        try:
            result = await self.redis_client.flushdb()
            logger.info("Redis database flushed")
            return bool(result)

        except Exception as e:
            logger.error(f"Error flushing Redis database: {e}")
            return False

    async def info(self) -> Dict[str, Any]:
        """
        Get Redis server information.

        Returns:
            Dict: Redis server info
        """
        try:
            info = await self.redis_client.info()
            return info

        except Exception as e:
            logger.error(f"Error getting Redis info: {e}")
            return {}

    async def ping(self) -> bool:
        """
        Ping Redis server.

        Returns:
            bool: True if server responds, False otherwise
        """
        try:
            result = await self.redis_client.ping()
            return bool(result)

        except Exception as e:
            logger.error(f"Redis ping failed: {e}")
            return False

    async def set_heatmap(self, symbol: str, heatmap_data,
                          exchange: Optional[str] = None, ttl: Optional[int] = None) -> bool:
        """
        Cache heatmap data with optimized serialization.

        Args:
            symbol: Trading symbol
            heatmap_data: Heatmap data to cache
            exchange: Exchange name (None for consolidated)
            ttl: Time to live in seconds

        Returns:
            bool: True if successful, False otherwise
        """
        try:
            key = self.cache_keys.heatmap_key(symbol, 1.0, exchange)

            # Use specialized heatmap serialization
            serialized_value = self.serializer.serialize_heatmap(heatmap_data)

            # Determine TTL
            if ttl is None:
                ttl = self.cache_keys.HEATMAP_TTL

            # Set in Redis
            result = await self.redis_client.setex(key, ttl, serialized_value)

            # Update statistics
            self.stats['sets'] += 1
            self.stats['total_data_size'] += len(serialized_value)

            logger.debug(f"Cached heatmap: {key} (size: {len(serialized_value)} bytes)")
            return bool(result)

        except Exception as e:
            self.stats['errors'] += 1
            logger.error(f"Error caching heatmap for {symbol}: {e}")
            return False

    async def get_heatmap(self, symbol: str, exchange: Optional[str] = None):
        """
        Get cached heatmap data with optimized deserialization.

        Args:
            symbol: Trading symbol
            exchange: Exchange name (None for consolidated)

        Returns:
            HeatmapData: Cached heatmap or None if not found
        """
        try:
            key = self.cache_keys.heatmap_key(symbol, 1.0, exchange)

            # Get from Redis
            serialized_value = await self.redis_client.get(key)

            self.stats['gets'] += 1

            if serialized_value is None:
                self.stats['misses'] += 1
                return None

            # Use specialized heatmap deserialization
            heatmap_data = self.serializer.deserialize_heatmap(serialized_value)

            self.stats['hits'] += 1
            logger.debug(f"Retrieved heatmap: {key}")
            return heatmap_data

        except Exception as e:
            self.stats['errors'] += 1
            logger.error(f"Error retrieving heatmap for {symbol}: {e}")
            return None

    async def cache_orderbook(self, orderbook) -> bool:
        """
        Cache order book data.

        Args:
            orderbook: OrderBookSnapshot to cache

        Returns:
            bool: True if successful, False otherwise
        """
        try:
            key = self.cache_keys.orderbook_key(orderbook.symbol, orderbook.exchange)
            return await self.set(key, orderbook)

        except Exception as e:
            logger.error(f"Error caching order book: {e}")
            return False

    async def get_orderbook(self, symbol: str, exchange: str):
        """
        Get cached order book data.

        Args:
            symbol: Trading symbol
            exchange: Exchange name

        Returns:
            OrderBookSnapshot: Cached order book or None if not found
        """
        try:
            key = self.cache_keys.orderbook_key(symbol, exchange)
            return await self.get(key)

        except Exception as e:
            logger.error(f"Error retrieving order book: {e}")
            return None

    async def cache_metrics(self, metrics, symbol: str, exchange: str) -> bool:
        """
        Cache metrics data.

        Args:
            metrics: Metrics data to cache
            symbol: Trading symbol
            exchange: Exchange name

        Returns:
            bool: True if successful, False otherwise
        """
        try:
            key = self.cache_keys.metrics_key(symbol, exchange)
            return await self.set(key, metrics)

        except Exception as e:
            logger.error(f"Error caching metrics: {e}")
            return False

    async def get_metrics(self, symbol: str, exchange: str):
        """
        Get cached metrics data.

        Args:
            symbol: Trading symbol
            exchange: Exchange name

        Returns:
            Metrics data or None if not found
        """
        try:
            key = self.cache_keys.metrics_key(symbol, exchange)
            return await self.get(key)

        except Exception as e:
            logger.error(f"Error retrieving metrics: {e}")
            return None

    async def cache_exchange_status(self, exchange: str, status_data) -> bool:
        """
        Cache exchange status.

        Args:
            exchange: Exchange name
            status_data: Status data to cache

        Returns:
            bool: True if successful, False otherwise
        """
        try:
            key = self.cache_keys.status_key(exchange)
            return await self.set(key, status_data)

        except Exception as e:
            logger.error(f"Error caching exchange status: {e}")
            return False

    async def get_exchange_status(self, exchange: str):
        """
        Get cached exchange status.

        Args:
            exchange: Exchange name

        Returns:
            Status data or None if not found
        """
        try:
            key = self.cache_keys.status_key(exchange)
            return await self.get(key)

        except Exception as e:
            logger.error(f"Error retrieving exchange status: {e}")
            return None

    async def cleanup_expired_keys(self) -> int:
        """
        Count keys that have already expired.

        Redis removes expired keys automatically; this method only reports keys
        that lapsed between the key listing and the TTL check, as a diagnostic.

        Returns:
            int: Number of expired keys observed
        """
        try:
            # Get all keys
            all_keys = await self.keys("*")

            # Check which ones have expired in the meantime
            expired_count = 0
            for key in all_keys:
                ttl = await self.redis_client.ttl(key)
                if ttl == -2:  # Key no longer exists (expired)
                    expired_count += 1

            logger.debug(f"Found {expired_count} expired keys")
            return expired_count

        except Exception as e:
            logger.error(f"Error cleaning up expired keys: {e}")
            return 0

    def _update_avg_response_time(self, response_time: float) -> None:
        """Update average response time (cumulative moving average)"""
        total_operations = self.stats['gets'] + self.stats['sets']
        if total_operations > 0:
            self.stats['avg_response_time'] = (
                (self.stats['avg_response_time'] * (total_operations - 1) + response_time) /
                total_operations
            )

    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics"""
        total_operations = self.stats['gets'] + self.stats['sets']
        hit_rate = (self.stats['hits'] / max(self.stats['gets'], 1)) * 100

        return {
            **self.stats,
            'total_operations': total_operations,
            'hit_rate_percentage': hit_rate,
            'serializer_stats': self.serializer.get_stats()
        }

    def reset_stats(self) -> None:
        """Reset cache statistics"""
        self.stats = {
            'gets': 0,
            'sets': 0,
            'deletes': 0,
            'hits': 0,
            'misses': 0,
            'errors': 0,
            'total_data_size': 0,
            'avg_response_time': 0.0
        }
        self.serializer.reset_stats()
        logger.info("Redis manager statistics reset")

    async def health_check(self) -> Dict[str, Any]:
        """
        Perform comprehensive health check.

        Returns:
            Dict: Health check results
        """
        health = {
            'redis_ping': False,
            'connection_pool_size': 0,
            'memory_usage': 0,
            'connected_clients': 0,
            'total_keys': 0,
            'hit_rate': 0.0,
            'avg_response_time': self.stats['avg_response_time']
        }

        try:
            # Test ping
            health['redis_ping'] = await self.ping()

            # Get Redis info
            info = await self.info()
            if info:
                health['memory_usage'] = info.get('used_memory', 0)
                health['connected_clients'] = info.get('connected_clients', 0)

            # Get key count (KEYS "*" is expensive on large databases; DBSIZE is cheaper)
            all_keys = await self.keys("*")
            health['total_keys'] = len(all_keys)

            # Calculate hit rate
            if self.stats['gets'] > 0:
                health['hit_rate'] = (self.stats['hits'] / self.stats['gets']) * 100

            # Connection pool info
            if self.redis_pool:
                health['connection_pool_size'] = self.redis_pool.max_connections

        except Exception as e:
            logger.error(f"Health check error: {e}")

        return health

# Global Redis manager instance
redis_manager = RedisManager()
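

# Usage sketch (illustrative only, not part of the manager's public API):
# assumes a reachable Redis server matching `config.redis`, and that
# DataSerializer can round-trip plain dictionaries. The key name below is
# purely a hypothetical example.
async def _example_usage() -> None:
    await redis_manager.initialize()
    try:
        # Explicit TTL so the example does not depend on CacheKeys defaults
        await redis_manager.set("example:demo_key", {"price": 100.0}, ttl=60)
        cached = await redis_manager.get("example:demo_key")
        logger.info(f"Example round-trip value: {cached}")
        logger.info(f"Cache stats: {redis_manager.get_stats()}")
    finally:
        await redis_manager.close()


if __name__ == "__main__":
    asyncio.run(_example_usage())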