caching
This commit is contained in:
@ -46,6 +46,9 @@
|
||||
- Write unit tests for Binance connector functionality
|
||||
- _Requirements: 1.1, 1.2, 1.4, 6.2_
|
||||
|
||||
|
||||
|
||||
|
||||
- [ ] 5. Create data processing and normalization engine
|
||||
- Implement data processor for normalizing raw exchange data
|
||||
- Create validation logic for order book and trade data
|
||||
|
13
COBY/caching/__init__.py
Normal file
13
COBY/caching/__init__.py
Normal file
@ -0,0 +1,13 @@
|
||||
"""
|
||||
Caching layer for the COBY system.
|
||||
"""
|
||||
|
||||
from .redis_manager import RedisManager
|
||||
from .cache_keys import CacheKeys
|
||||
from .data_serializer import DataSerializer
|
||||
|
||||
__all__ = [
|
||||
'RedisManager',
|
||||
'CacheKeys',
|
||||
'DataSerializer'
|
||||
]
|
278
COBY/caching/cache_keys.py
Normal file
278
COBY/caching/cache_keys.py
Normal file
@ -0,0 +1,278 @@
|
||||
"""
|
||||
Cache key management for Redis operations.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from ..utils.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class CacheKeys:
|
||||
"""
|
||||
Centralized cache key management for consistent Redis operations.
|
||||
|
||||
Provides standardized key patterns for different data types.
|
||||
"""
|
||||
|
||||
# Key prefixes
|
||||
ORDERBOOK_PREFIX = "ob"
|
||||
HEATMAP_PREFIX = "hm"
|
||||
TRADE_PREFIX = "tr"
|
||||
METRICS_PREFIX = "mt"
|
||||
STATUS_PREFIX = "st"
|
||||
STATS_PREFIX = "stats"
|
||||
|
||||
# TTL values (seconds)
|
||||
ORDERBOOK_TTL = 60 # 1 minute
|
||||
HEATMAP_TTL = 30 # 30 seconds
|
||||
TRADE_TTL = 300 # 5 minutes
|
||||
METRICS_TTL = 120 # 2 minutes
|
||||
STATUS_TTL = 60 # 1 minute
|
||||
STATS_TTL = 300 # 5 minutes
|
||||
|
||||
@classmethod
|
||||
def orderbook_key(cls, symbol: str, exchange: str) -> str:
|
||||
"""
|
||||
Generate cache key for order book data.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
exchange: Exchange name
|
||||
|
||||
Returns:
|
||||
str: Cache key
|
||||
"""
|
||||
return f"{cls.ORDERBOOK_PREFIX}:{exchange}:{symbol}"
|
||||
|
||||
@classmethod
|
||||
def heatmap_key(cls, symbol: str, bucket_size: float = 1.0,
|
||||
exchange: Optional[str] = None) -> str:
|
||||
"""
|
||||
Generate cache key for heatmap data.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
bucket_size: Price bucket size
|
||||
exchange: Exchange name (None for consolidated)
|
||||
|
||||
Returns:
|
||||
str: Cache key
|
||||
"""
|
||||
if exchange:
|
||||
return f"{cls.HEATMAP_PREFIX}:{exchange}:{symbol}:{bucket_size}"
|
||||
else:
|
||||
return f"{cls.HEATMAP_PREFIX}:consolidated:{symbol}:{bucket_size}"
|
||||
|
||||
@classmethod
|
||||
def trade_key(cls, symbol: str, exchange: str, trade_id: str) -> str:
|
||||
"""
|
||||
Generate cache key for trade data.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
exchange: Exchange name
|
||||
trade_id: Trade identifier
|
||||
|
||||
Returns:
|
||||
str: Cache key
|
||||
"""
|
||||
return f"{cls.TRADE_PREFIX}:{exchange}:{symbol}:{trade_id}"
|
||||
|
||||
@classmethod
|
||||
def metrics_key(cls, symbol: str, exchange: str) -> str:
|
||||
"""
|
||||
Generate cache key for metrics data.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
exchange: Exchange name
|
||||
|
||||
Returns:
|
||||
str: Cache key
|
||||
"""
|
||||
return f"{cls.METRICS_PREFIX}:{exchange}:{symbol}"
|
||||
|
||||
@classmethod
|
||||
def status_key(cls, exchange: str) -> str:
|
||||
"""
|
||||
Generate cache key for exchange status.
|
||||
|
||||
Args:
|
||||
exchange: Exchange name
|
||||
|
||||
Returns:
|
||||
str: Cache key
|
||||
"""
|
||||
return f"{cls.STATUS_PREFIX}:{exchange}"
|
||||
|
||||
@classmethod
|
||||
def stats_key(cls, component: str) -> str:
|
||||
"""
|
||||
Generate cache key for component statistics.
|
||||
|
||||
Args:
|
||||
component: Component name
|
||||
|
||||
Returns:
|
||||
str: Cache key
|
||||
"""
|
||||
return f"{cls.STATS_PREFIX}:{component}"
|
||||
|
||||
@classmethod
|
||||
def latest_heatmaps_key(cls, symbol: str) -> str:
|
||||
"""
|
||||
Generate cache key for latest heatmaps list.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
|
||||
Returns:
|
||||
str: Cache key
|
||||
"""
|
||||
return f"{cls.HEATMAP_PREFIX}:latest:{symbol}"
|
||||
|
||||
@classmethod
|
||||
def symbol_list_key(cls, exchange: str) -> str:
|
||||
"""
|
||||
Generate cache key for symbol list.
|
||||
|
||||
Args:
|
||||
exchange: Exchange name
|
||||
|
||||
Returns:
|
||||
str: Cache key
|
||||
"""
|
||||
return f"symbols:{exchange}"
|
||||
|
||||
@classmethod
|
||||
def price_bucket_key(cls, symbol: str, exchange: str) -> str:
|
||||
"""
|
||||
Generate cache key for price buckets.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
exchange: Exchange name
|
||||
|
||||
Returns:
|
||||
str: Cache key
|
||||
"""
|
||||
return f"buckets:{exchange}:{symbol}"
|
||||
|
||||
@classmethod
|
||||
def arbitrage_key(cls, symbol: str) -> str:
|
||||
"""
|
||||
Generate cache key for arbitrage opportunities.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
|
||||
Returns:
|
||||
str: Cache key
|
||||
"""
|
||||
return f"arbitrage:{symbol}"
|
||||
|
||||
@classmethod
|
||||
def get_ttl(cls, key: str) -> int:
|
||||
"""
|
||||
Get appropriate TTL for a cache key.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
|
||||
Returns:
|
||||
int: TTL in seconds
|
||||
"""
|
||||
if key.startswith(cls.ORDERBOOK_PREFIX):
|
||||
return cls.ORDERBOOK_TTL
|
||||
elif key.startswith(cls.HEATMAP_PREFIX):
|
||||
return cls.HEATMAP_TTL
|
||||
elif key.startswith(cls.TRADE_PREFIX):
|
||||
return cls.TRADE_TTL
|
||||
elif key.startswith(cls.METRICS_PREFIX):
|
||||
return cls.METRICS_TTL
|
||||
elif key.startswith(cls.STATUS_PREFIX):
|
||||
return cls.STATUS_TTL
|
||||
elif key.startswith(cls.STATS_PREFIX):
|
||||
return cls.STATS_TTL
|
||||
else:
|
||||
return 300 # Default 5 minutes
|
||||
|
||||
@classmethod
|
||||
def parse_key(cls, key: str) -> dict:
|
||||
"""
|
||||
Parse cache key to extract components.
|
||||
|
||||
Args:
|
||||
key: Cache key to parse
|
||||
|
||||
Returns:
|
||||
dict: Parsed key components
|
||||
"""
|
||||
parts = key.split(':')
|
||||
|
||||
if len(parts) < 2:
|
||||
return {'type': 'unknown', 'key': key}
|
||||
|
||||
key_type = parts[0]
|
||||
|
||||
if key_type == cls.ORDERBOOK_PREFIX and len(parts) >= 3:
|
||||
return {
|
||||
'type': 'orderbook',
|
||||
'exchange': parts[1],
|
||||
'symbol': parts[2]
|
||||
}
|
||||
elif key_type == cls.HEATMAP_PREFIX and len(parts) >= 4:
|
||||
return {
|
||||
'type': 'heatmap',
|
||||
'exchange': parts[1] if parts[1] != 'consolidated' else None,
|
||||
'symbol': parts[2],
|
||||
'bucket_size': float(parts[3]) if len(parts) > 3 else 1.0
|
||||
}
|
||||
elif key_type == cls.TRADE_PREFIX and len(parts) >= 4:
|
||||
return {
|
||||
'type': 'trade',
|
||||
'exchange': parts[1],
|
||||
'symbol': parts[2],
|
||||
'trade_id': parts[3]
|
||||
}
|
||||
elif key_type == cls.METRICS_PREFIX and len(parts) >= 3:
|
||||
return {
|
||||
'type': 'metrics',
|
||||
'exchange': parts[1],
|
||||
'symbol': parts[2]
|
||||
}
|
||||
elif key_type == cls.STATUS_PREFIX and len(parts) >= 2:
|
||||
return {
|
||||
'type': 'status',
|
||||
'exchange': parts[1]
|
||||
}
|
||||
elif key_type == cls.STATS_PREFIX and len(parts) >= 2:
|
||||
return {
|
||||
'type': 'stats',
|
||||
'component': parts[1]
|
||||
}
|
||||
else:
|
||||
return {'type': 'unknown', 'key': key}
|
||||
|
||||
@classmethod
|
||||
def get_pattern(cls, key_type: str) -> str:
|
||||
"""
|
||||
Get Redis pattern for key type.
|
||||
|
||||
Args:
|
||||
key_type: Type of key
|
||||
|
||||
Returns:
|
||||
str: Redis pattern
|
||||
"""
|
||||
patterns = {
|
||||
'orderbook': f"{cls.ORDERBOOK_PREFIX}:*",
|
||||
'heatmap': f"{cls.HEATMAP_PREFIX}:*",
|
||||
'trade': f"{cls.TRADE_PREFIX}:*",
|
||||
'metrics': f"{cls.METRICS_PREFIX}:*",
|
||||
'status': f"{cls.STATUS_PREFIX}:*",
|
||||
'stats': f"{cls.STATS_PREFIX}:*"
|
||||
}
|
||||
|
||||
return patterns.get(key_type, "*")
|
355
COBY/caching/data_serializer.py
Normal file
355
COBY/caching/data_serializer.py
Normal file
@ -0,0 +1,355 @@
|
||||
"""
|
||||
Data serialization for Redis caching.
|
||||
"""
|
||||
|
||||
import json
|
||||
import pickle
|
||||
import gzip
|
||||
from typing import Any, Union, Dict, List
|
||||
from datetime import datetime
|
||||
from ..models.core import (
|
||||
OrderBookSnapshot, TradeEvent, HeatmapData, PriceBuckets,
|
||||
OrderBookMetrics, ImbalanceMetrics, ConsolidatedOrderBook
|
||||
)
|
||||
from ..utils.logging import get_logger
|
||||
from ..utils.exceptions import ProcessingError
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DataSerializer:
|
||||
"""
|
||||
Handles serialization and deserialization of data for Redis storage.
|
||||
|
||||
Supports multiple serialization formats:
|
||||
- JSON for simple data
|
||||
- Pickle for complex objects
|
||||
- Compressed formats for large data
|
||||
"""
|
||||
|
||||
def __init__(self, use_compression: bool = True):
|
||||
"""
|
||||
Initialize data serializer.
|
||||
|
||||
Args:
|
||||
use_compression: Whether to use gzip compression
|
||||
"""
|
||||
self.use_compression = use_compression
|
||||
self.serialization_stats = {
|
||||
'serialized': 0,
|
||||
'deserialized': 0,
|
||||
'compression_ratio': 0.0,
|
||||
'errors': 0
|
||||
}
|
||||
|
||||
logger.info(f"Data serializer initialized (compression: {use_compression})")
|
||||
|
||||
def serialize(self, data: Any, format_type: str = 'auto') -> bytes:
|
||||
"""
|
||||
Serialize data for Redis storage.
|
||||
|
||||
Args:
|
||||
data: Data to serialize
|
||||
format_type: Serialization format ('json', 'pickle', 'auto')
|
||||
|
||||
Returns:
|
||||
bytes: Serialized data
|
||||
"""
|
||||
try:
|
||||
# Determine format
|
||||
if format_type == 'auto':
|
||||
format_type = self._determine_format(data)
|
||||
|
||||
# Serialize based on format
|
||||
if format_type == 'json':
|
||||
serialized = self._serialize_json(data)
|
||||
elif format_type == 'pickle':
|
||||
serialized = self._serialize_pickle(data)
|
||||
else:
|
||||
raise ValueError(f"Unsupported format: {format_type}")
|
||||
|
||||
# Apply compression if enabled
|
||||
if self.use_compression:
|
||||
original_size = len(serialized)
|
||||
serialized = gzip.compress(serialized)
|
||||
compressed_size = len(serialized)
|
||||
|
||||
# Update compression ratio
|
||||
if original_size > 0:
|
||||
ratio = compressed_size / original_size
|
||||
self.serialization_stats['compression_ratio'] = (
|
||||
(self.serialization_stats['compression_ratio'] *
|
||||
self.serialization_stats['serialized'] + ratio) /
|
||||
(self.serialization_stats['serialized'] + 1)
|
||||
)
|
||||
|
||||
self.serialization_stats['serialized'] += 1
|
||||
return serialized
|
||||
|
||||
except Exception as e:
|
||||
self.serialization_stats['errors'] += 1
|
||||
logger.error(f"Serialization error: {e}")
|
||||
raise ProcessingError(f"Serialization failed: {e}", "SERIALIZE_ERROR")
|
||||
|
||||
def deserialize(self, data: bytes, format_type: str = 'auto') -> Any:
|
||||
"""
|
||||
Deserialize data from Redis storage.
|
||||
|
||||
Args:
|
||||
data: Serialized data
|
||||
format_type: Expected format ('json', 'pickle', 'auto')
|
||||
|
||||
Returns:
|
||||
Any: Deserialized data
|
||||
"""
|
||||
try:
|
||||
# Decompress if needed
|
||||
if self.use_compression:
|
||||
try:
|
||||
data = gzip.decompress(data)
|
||||
except gzip.BadGzipFile:
|
||||
# Data might not be compressed
|
||||
pass
|
||||
|
||||
# Determine format if auto
|
||||
if format_type == 'auto':
|
||||
format_type = self._detect_format(data)
|
||||
|
||||
# Deserialize based on format
|
||||
if format_type == 'json':
|
||||
result = self._deserialize_json(data)
|
||||
elif format_type == 'pickle':
|
||||
result = self._deserialize_pickle(data)
|
||||
else:
|
||||
raise ValueError(f"Unsupported format: {format_type}")
|
||||
|
||||
self.serialization_stats['deserialized'] += 1
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
self.serialization_stats['errors'] += 1
|
||||
logger.error(f"Deserialization error: {e}")
|
||||
raise ProcessingError(f"Deserialization failed: {e}", "DESERIALIZE_ERROR")
|
||||
|
||||
def _determine_format(self, data: Any) -> str:
|
||||
"""Determine best serialization format for data"""
|
||||
# Use JSON for simple data types
|
||||
if isinstance(data, (dict, list, str, int, float, bool)) or data is None:
|
||||
return 'json'
|
||||
|
||||
# Use pickle for complex objects
|
||||
return 'pickle'
|
||||
|
||||
def _detect_format(self, data: bytes) -> str:
|
||||
"""Detect serialization format from data"""
|
||||
try:
|
||||
# Try JSON first
|
||||
json.loads(data.decode('utf-8'))
|
||||
return 'json'
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
# Assume pickle
|
||||
return 'pickle'
|
||||
|
||||
def _serialize_json(self, data: Any) -> bytes:
|
||||
"""Serialize data as JSON"""
|
||||
# Convert complex objects to dictionaries
|
||||
if hasattr(data, '__dict__'):
|
||||
data = self._object_to_dict(data)
|
||||
elif isinstance(data, list):
|
||||
data = [self._object_to_dict(item) if hasattr(item, '__dict__') else item
|
||||
for item in data]
|
||||
|
||||
json_str = json.dumps(data, default=self._json_serializer, ensure_ascii=False)
|
||||
return json_str.encode('utf-8')
|
||||
|
||||
def _deserialize_json(self, data: bytes) -> Any:
|
||||
"""Deserialize JSON data"""
|
||||
json_str = data.decode('utf-8')
|
||||
return json.loads(json_str, object_hook=self._json_deserializer)
|
||||
|
||||
def _serialize_pickle(self, data: Any) -> bytes:
|
||||
"""Serialize data as pickle"""
|
||||
return pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
def _deserialize_pickle(self, data: bytes) -> Any:
|
||||
"""Deserialize pickle data"""
|
||||
return pickle.loads(data)
|
||||
|
||||
def _object_to_dict(self, obj: Any) -> Dict:
|
||||
"""Convert object to dictionary for JSON serialization"""
|
||||
if isinstance(obj, (OrderBookSnapshot, TradeEvent, HeatmapData,
|
||||
PriceBuckets, OrderBookMetrics, ImbalanceMetrics,
|
||||
ConsolidatedOrderBook)):
|
||||
result = {
|
||||
'__type__': obj.__class__.__name__,
|
||||
'__data__': {}
|
||||
}
|
||||
|
||||
# Convert object attributes
|
||||
for key, value in obj.__dict__.items():
|
||||
if isinstance(value, datetime):
|
||||
result['__data__'][key] = {
|
||||
'__datetime__': value.isoformat()
|
||||
}
|
||||
elif isinstance(value, list):
|
||||
result['__data__'][key] = [
|
||||
self._object_to_dict(item) if hasattr(item, '__dict__') else item
|
||||
for item in value
|
||||
]
|
||||
elif hasattr(value, '__dict__'):
|
||||
result['__data__'][key] = self._object_to_dict(value)
|
||||
else:
|
||||
result['__data__'][key] = value
|
||||
|
||||
return result
|
||||
else:
|
||||
return obj.__dict__ if hasattr(obj, '__dict__') else obj
|
||||
|
||||
def _json_serializer(self, obj: Any) -> Any:
|
||||
"""Custom JSON serializer for special types"""
|
||||
if isinstance(obj, datetime):
|
||||
return {'__datetime__': obj.isoformat()}
|
||||
elif hasattr(obj, '__dict__'):
|
||||
return self._object_to_dict(obj)
|
||||
else:
|
||||
return str(obj)
|
||||
|
||||
def _json_deserializer(self, obj: Dict) -> Any:
|
||||
"""Custom JSON deserializer for special types"""
|
||||
if '__datetime__' in obj:
|
||||
return datetime.fromisoformat(obj['__datetime__'])
|
||||
elif '__type__' in obj and '__data__' in obj:
|
||||
return self._reconstruct_object(obj['__type__'], obj['__data__'])
|
||||
else:
|
||||
return obj
|
||||
|
||||
def _reconstruct_object(self, type_name: str, data: Dict) -> Any:
|
||||
"""Reconstruct object from serialized data"""
|
||||
# Import required classes
|
||||
from ..models.core import (
|
||||
OrderBookSnapshot, TradeEvent, HeatmapData, PriceBuckets,
|
||||
OrderBookMetrics, ImbalanceMetrics, ConsolidatedOrderBook,
|
||||
PriceLevel, HeatmapPoint
|
||||
)
|
||||
|
||||
# Map type names to classes
|
||||
type_map = {
|
||||
'OrderBookSnapshot': OrderBookSnapshot,
|
||||
'TradeEvent': TradeEvent,
|
||||
'HeatmapData': HeatmapData,
|
||||
'PriceBuckets': PriceBuckets,
|
||||
'OrderBookMetrics': OrderBookMetrics,
|
||||
'ImbalanceMetrics': ImbalanceMetrics,
|
||||
'ConsolidatedOrderBook': ConsolidatedOrderBook,
|
||||
'PriceLevel': PriceLevel,
|
||||
'HeatmapPoint': HeatmapPoint
|
||||
}
|
||||
|
||||
if type_name in type_map:
|
||||
cls = type_map[type_name]
|
||||
|
||||
# Recursively deserialize nested objects
|
||||
processed_data = {}
|
||||
for key, value in data.items():
|
||||
if isinstance(value, dict) and '__datetime__' in value:
|
||||
processed_data[key] = datetime.fromisoformat(value['__datetime__'])
|
||||
elif isinstance(value, dict) and '__type__' in value:
|
||||
processed_data[key] = self._reconstruct_object(
|
||||
value['__type__'], value['__data__']
|
||||
)
|
||||
elif isinstance(value, list):
|
||||
processed_data[key] = [
|
||||
self._reconstruct_object(item['__type__'], item['__data__'])
|
||||
if isinstance(item, dict) and '__type__' in item
|
||||
else item
|
||||
for item in value
|
||||
]
|
||||
else:
|
||||
processed_data[key] = value
|
||||
|
||||
try:
|
||||
return cls(**processed_data)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to reconstruct {type_name}: {e}")
|
||||
return processed_data
|
||||
else:
|
||||
logger.warning(f"Unknown type for reconstruction: {type_name}")
|
||||
return data
|
||||
|
||||
def serialize_heatmap(self, heatmap: HeatmapData) -> bytes:
|
||||
"""Specialized serialization for heatmap data"""
|
||||
try:
|
||||
# Create optimized representation
|
||||
heatmap_dict = {
|
||||
'symbol': heatmap.symbol,
|
||||
'timestamp': heatmap.timestamp.isoformat(),
|
||||
'bucket_size': heatmap.bucket_size,
|
||||
'points': [
|
||||
{
|
||||
'p': point.price, # price
|
||||
'v': point.volume, # volume
|
||||
'i': point.intensity, # intensity
|
||||
's': point.side # side
|
||||
}
|
||||
for point in heatmap.data
|
||||
]
|
||||
}
|
||||
|
||||
return self.serialize(heatmap_dict, 'json')
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Heatmap serialization error: {e}")
|
||||
# Fallback to standard serialization
|
||||
return self.serialize(heatmap, 'pickle')
|
||||
|
||||
def deserialize_heatmap(self, data: bytes) -> HeatmapData:
|
||||
"""Specialized deserialization for heatmap data"""
|
||||
try:
|
||||
# Try optimized format first
|
||||
heatmap_dict = self.deserialize(data, 'json')
|
||||
|
||||
if isinstance(heatmap_dict, dict) and 'points' in heatmap_dict:
|
||||
from ..models.core import HeatmapData, HeatmapPoint
|
||||
|
||||
# Reconstruct heatmap points
|
||||
points = []
|
||||
for point_data in heatmap_dict['points']:
|
||||
point = HeatmapPoint(
|
||||
price=point_data['p'],
|
||||
volume=point_data['v'],
|
||||
intensity=point_data['i'],
|
||||
side=point_data['s']
|
||||
)
|
||||
points.append(point)
|
||||
|
||||
# Create heatmap
|
||||
heatmap = HeatmapData(
|
||||
symbol=heatmap_dict['symbol'],
|
||||
timestamp=datetime.fromisoformat(heatmap_dict['timestamp']),
|
||||
bucket_size=heatmap_dict['bucket_size']
|
||||
)
|
||||
heatmap.data = points
|
||||
|
||||
return heatmap
|
||||
else:
|
||||
# Fallback to standard deserialization
|
||||
return self.deserialize(data, 'pickle')
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Heatmap deserialization error: {e}")
|
||||
# Final fallback
|
||||
return self.deserialize(data, 'pickle')
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get serialization statistics"""
|
||||
return self.serialization_stats.copy()
|
||||
|
||||
def reset_stats(self) -> None:
|
||||
"""Reset serialization statistics"""
|
||||
self.serialization_stats = {
|
||||
'serialized': 0,
|
||||
'deserialized': 0,
|
||||
'compression_ratio': 0.0,
|
||||
'errors': 0
|
||||
}
|
||||
logger.info("Serialization statistics reset")
|
691
COBY/caching/redis_manager.py
Normal file
691
COBY/caching/redis_manager.py
Normal file
@ -0,0 +1,691 @@
|
||||
"""
|
||||
Redis cache manager for high-performance data access.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import redis.asyncio as redis
|
||||
from typing import Any, Optional, List, Dict, Union
|
||||
from datetime import datetime, timedelta
|
||||
from ..config import config
|
||||
from ..utils.logging import get_logger, set_correlation_id
|
||||
from ..utils.exceptions import StorageError
|
||||
from ..utils.timing import get_current_timestamp
|
||||
from .cache_keys import CacheKeys
|
||||
from .data_serializer import DataSerializer
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class RedisManager:
|
||||
"""
|
||||
High-performance Redis cache manager for market data.
|
||||
|
||||
Provides:
|
||||
- Connection pooling and management
|
||||
- Data serialization and compression
|
||||
- TTL management
|
||||
- Batch operations
|
||||
- Performance monitoring
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Redis manager"""
|
||||
self.redis_pool: Optional[redis.ConnectionPool] = None
|
||||
self.redis_client: Optional[redis.Redis] = None
|
||||
self.serializer = DataSerializer(use_compression=True)
|
||||
self.cache_keys = CacheKeys()
|
||||
|
||||
# Performance statistics
|
||||
self.stats = {
|
||||
'gets': 0,
|
||||
'sets': 0,
|
||||
'deletes': 0,
|
||||
'hits': 0,
|
||||
'misses': 0,
|
||||
'errors': 0,
|
||||
'total_data_size': 0,
|
||||
'avg_response_time': 0.0
|
||||
}
|
||||
|
||||
logger.info("Redis manager initialized")
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize Redis connection pool"""
|
||||
try:
|
||||
# Create connection pool
|
||||
self.redis_pool = redis.ConnectionPool(
|
||||
host=config.redis.host,
|
||||
port=config.redis.port,
|
||||
password=config.redis.password,
|
||||
db=config.redis.db,
|
||||
max_connections=config.redis.max_connections,
|
||||
socket_timeout=config.redis.socket_timeout,
|
||||
socket_connect_timeout=config.redis.socket_connect_timeout,
|
||||
decode_responses=False, # We handle bytes directly
|
||||
retry_on_timeout=True,
|
||||
health_check_interval=30
|
||||
)
|
||||
|
||||
# Create Redis client
|
||||
self.redis_client = redis.Redis(connection_pool=self.redis_pool)
|
||||
|
||||
# Test connection
|
||||
await self.redis_client.ping()
|
||||
|
||||
logger.info(f"Redis connection established: {config.redis.host}:{config.redis.port}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize Redis connection: {e}")
|
||||
raise StorageError(f"Redis initialization failed: {e}", "REDIS_INIT_ERROR")
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close Redis connections"""
|
||||
try:
|
||||
if self.redis_client:
|
||||
await self.redis_client.close()
|
||||
if self.redis_pool:
|
||||
await self.redis_pool.disconnect()
|
||||
|
||||
logger.info("Redis connections closed")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing Redis connections: {e}")
|
||||
|
||||
async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
|
||||
"""
|
||||
Set value in cache with optional TTL.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
value: Value to cache
|
||||
ttl: Time to live in seconds (None = use default)
|
||||
|
||||
Returns:
|
||||
bool: True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
set_correlation_id()
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
|
||||
# Serialize value
|
||||
serialized_value = self.serializer.serialize(value)
|
||||
|
||||
# Determine TTL
|
||||
if ttl is None:
|
||||
ttl = self.cache_keys.get_ttl(key)
|
||||
|
||||
# Set in Redis
|
||||
result = await self.redis_client.setex(key, ttl, serialized_value)
|
||||
|
||||
# Update statistics
|
||||
self.stats['sets'] += 1
|
||||
self.stats['total_data_size'] += len(serialized_value)
|
||||
|
||||
# Update response time
|
||||
response_time = asyncio.get_event_loop().time() - start_time
|
||||
self._update_avg_response_time(response_time)
|
||||
|
||||
logger.debug(f"Cached data: {key} (size: {len(serialized_value)} bytes, ttl: {ttl}s)")
|
||||
return bool(result)
|
||||
|
||||
except Exception as e:
|
||||
self.stats['errors'] += 1
|
||||
logger.error(f"Error setting cache key {key}: {e}")
|
||||
return False
|
||||
|
||||
async def get(self, key: str) -> Optional[Any]:
|
||||
"""
|
||||
Get value from cache.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
|
||||
Returns:
|
||||
Any: Cached value or None if not found
|
||||
"""
|
||||
try:
|
||||
set_correlation_id()
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
|
||||
# Get from Redis
|
||||
serialized_value = await self.redis_client.get(key)
|
||||
|
||||
# Update statistics
|
||||
self.stats['gets'] += 1
|
||||
|
||||
if serialized_value is None:
|
||||
self.stats['misses'] += 1
|
||||
logger.debug(f"Cache miss: {key}")
|
||||
return None
|
||||
|
||||
# Deserialize value
|
||||
value = self.serializer.deserialize(serialized_value)
|
||||
|
||||
# Update statistics
|
||||
self.stats['hits'] += 1
|
||||
|
||||
# Update response time
|
||||
response_time = asyncio.get_event_loop().time() - start_time
|
||||
self._update_avg_response_time(response_time)
|
||||
|
||||
logger.debug(f"Cache hit: {key} (size: {len(serialized_value)} bytes)")
|
||||
return value
|
||||
|
||||
except Exception as e:
|
||||
self.stats['errors'] += 1
|
||||
logger.error(f"Error getting cache key {key}: {e}")
|
||||
return None
|
||||
|
||||
async def delete(self, key: str) -> bool:
|
||||
"""
|
||||
Delete key from cache.
|
||||
|
||||
Args:
|
||||
key: Cache key to delete
|
||||
|
||||
Returns:
|
||||
bool: True if deleted, False otherwise
|
||||
"""
|
||||
try:
|
||||
set_correlation_id()
|
||||
|
||||
result = await self.redis_client.delete(key)
|
||||
|
||||
self.stats['deletes'] += 1
|
||||
|
||||
logger.debug(f"Deleted cache key: {key}")
|
||||
return bool(result)
|
||||
|
||||
except Exception as e:
|
||||
self.stats['errors'] += 1
|
||||
logger.error(f"Error deleting cache key {key}: {e}")
|
||||
return False
|
||||
|
||||
async def exists(self, key: str) -> bool:
|
||||
"""
|
||||
Check if key exists in cache.
|
||||
|
||||
Args:
|
||||
key: Cache key to check
|
||||
|
||||
Returns:
|
||||
bool: True if exists, False otherwise
|
||||
"""
|
||||
try:
|
||||
result = await self.redis_client.exists(key)
|
||||
return bool(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking cache key existence {key}: {e}")
|
||||
return False
|
||||
|
||||
async def expire(self, key: str, ttl: int) -> bool:
|
||||
"""
|
||||
Set expiration time for key.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
ttl: Time to live in seconds
|
||||
|
||||
Returns:
|
||||
bool: True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
result = await self.redis_client.expire(key, ttl)
|
||||
return bool(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error setting expiration for key {key}: {e}")
|
||||
return False
|
||||
|
||||
async def mget(self, keys: List[str]) -> List[Optional[Any]]:
|
||||
"""
|
||||
Get multiple values from cache.
|
||||
|
||||
Args:
|
||||
keys: List of cache keys
|
||||
|
||||
Returns:
|
||||
List[Optional[Any]]: List of values (None for missing keys)
|
||||
"""
|
||||
try:
|
||||
set_correlation_id()
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
|
||||
# Get from Redis
|
||||
serialized_values = await self.redis_client.mget(keys)
|
||||
|
||||
# Deserialize values
|
||||
values = []
|
||||
for serialized_value in serialized_values:
|
||||
if serialized_value is None:
|
||||
values.append(None)
|
||||
self.stats['misses'] += 1
|
||||
else:
|
||||
try:
|
||||
value = self.serializer.deserialize(serialized_value)
|
||||
values.append(value)
|
||||
self.stats['hits'] += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Error deserializing value: {e}")
|
||||
values.append(None)
|
||||
self.stats['errors'] += 1
|
||||
|
||||
# Update statistics
|
||||
self.stats['gets'] += len(keys)
|
||||
|
||||
# Update response time
|
||||
response_time = asyncio.get_event_loop().time() - start_time
|
||||
self._update_avg_response_time(response_time)
|
||||
|
||||
logger.debug(f"Multi-get: {len(keys)} keys, {sum(1 for v in values if v is not None)} hits")
|
||||
return values
|
||||
|
||||
except Exception as e:
|
||||
self.stats['errors'] += 1
|
||||
logger.error(f"Error in multi-get: {e}")
|
||||
return [None] * len(keys)
|
||||
|
||||
async def mset(self, key_value_pairs: Dict[str, Any], ttl: Optional[int] = None) -> bool:
|
||||
"""
|
||||
Set multiple key-value pairs.
|
||||
|
||||
Args:
|
||||
key_value_pairs: Dictionary of key-value pairs
|
||||
ttl: Time to live in seconds (None = use default per key)
|
||||
|
||||
Returns:
|
||||
bool: True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
set_correlation_id()
|
||||
|
||||
# Serialize all values
|
||||
serialized_pairs = {}
|
||||
for key, value in key_value_pairs.items():
|
||||
serialized_value = self.serializer.serialize(value)
|
||||
serialized_pairs[key] = serialized_value
|
||||
self.stats['total_data_size'] += len(serialized_value)
|
||||
|
||||
# Set in Redis
|
||||
result = await self.redis_client.mset(serialized_pairs)
|
||||
|
||||
# Set TTL for each key if specified
|
||||
if ttl is not None:
|
||||
for key in key_value_pairs.keys():
|
||||
await self.redis_client.expire(key, ttl)
|
||||
else:
|
||||
# Use individual TTLs
|
||||
for key in key_value_pairs.keys():
|
||||
key_ttl = self.cache_keys.get_ttl(key)
|
||||
await self.redis_client.expire(key, key_ttl)
|
||||
|
||||
self.stats['sets'] += len(key_value_pairs)
|
||||
|
||||
logger.debug(f"Multi-set: {len(key_value_pairs)} keys")
|
||||
return bool(result)
|
||||
|
||||
except Exception as e:
|
||||
self.stats['errors'] += 1
|
||||
logger.error(f"Error in multi-set: {e}")
|
||||
return False
|
||||
|
||||
async def keys(self, pattern: str) -> List[str]:
|
||||
"""
|
||||
Get keys matching pattern.
|
||||
|
||||
Args:
|
||||
pattern: Redis pattern (e.g., "hm:*")
|
||||
|
||||
Returns:
|
||||
List[str]: List of matching keys
|
||||
"""
|
||||
try:
|
||||
keys = await self.redis_client.keys(pattern)
|
||||
return [key.decode('utf-8') if isinstance(key, bytes) else key for key in keys]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting keys with pattern {pattern}: {e}")
|
||||
return []
|
||||
|
||||
async def flushdb(self) -> bool:
|
||||
"""
|
||||
Clear all keys in current database.
|
||||
|
||||
Returns:
|
||||
bool: True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
result = await self.redis_client.flushdb()
|
||||
logger.info("Redis database flushed")
|
||||
return bool(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error flushing Redis database: {e}")
|
||||
return False
|
||||
|
||||
async def info(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get Redis server information.
|
||||
|
||||
Returns:
|
||||
Dict: Redis server info
|
||||
"""
|
||||
try:
|
||||
info = await self.redis_client.info()
|
||||
return info
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting Redis info: {e}")
|
||||
return {}
|
||||
|
||||
async def ping(self) -> bool:
|
||||
"""
|
||||
Ping Redis server.
|
||||
|
||||
Returns:
|
||||
bool: True if server responds, False otherwise
|
||||
"""
|
||||
try:
|
||||
result = await self.redis_client.ping()
|
||||
return bool(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Redis ping failed: {e}")
|
||||
return False
|
||||
|
||||
async def set_heatmap(self, symbol: str, heatmap_data,
|
||||
exchange: Optional[str] = None, ttl: Optional[int] = None) -> bool:
|
||||
"""
|
||||
Cache heatmap data with optimized serialization.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
heatmap_data: Heatmap data to cache
|
||||
exchange: Exchange name (None for consolidated)
|
||||
ttl: Time to live in seconds
|
||||
|
||||
Returns:
|
||||
bool: True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
key = self.cache_keys.heatmap_key(symbol, 1.0, exchange)
|
||||
|
||||
# Use specialized heatmap serialization
|
||||
serialized_value = self.serializer.serialize_heatmap(heatmap_data)
|
||||
|
||||
# Determine TTL
|
||||
if ttl is None:
|
||||
ttl = self.cache_keys.HEATMAP_TTL
|
||||
|
||||
# Set in Redis
|
||||
result = await self.redis_client.setex(key, ttl, serialized_value)
|
||||
|
||||
# Update statistics
|
||||
self.stats['sets'] += 1
|
||||
self.stats['total_data_size'] += len(serialized_value)
|
||||
|
||||
logger.debug(f"Cached heatmap: {key} (size: {len(serialized_value)} bytes)")
|
||||
return bool(result)
|
||||
|
||||
except Exception as e:
|
||||
self.stats['errors'] += 1
|
||||
logger.error(f"Error caching heatmap for {symbol}: {e}")
|
||||
return False
|
||||
|
||||
async def get_heatmap(self, symbol: str, exchange: Optional[str] = None):
|
||||
"""
|
||||
Get cached heatmap data with optimized deserialization.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
exchange: Exchange name (None for consolidated)
|
||||
|
||||
Returns:
|
||||
HeatmapData: Cached heatmap or None if not found
|
||||
"""
|
||||
try:
|
||||
key = self.cache_keys.heatmap_key(symbol, 1.0, exchange)
|
||||
|
||||
# Get from Redis
|
||||
serialized_value = await self.redis_client.get(key)
|
||||
|
||||
self.stats['gets'] += 1
|
||||
|
||||
if serialized_value is None:
|
||||
self.stats['misses'] += 1
|
||||
return None
|
||||
|
||||
# Use specialized heatmap deserialization
|
||||
heatmap_data = self.serializer.deserialize_heatmap(serialized_value)
|
||||
|
||||
self.stats['hits'] += 1
|
||||
logger.debug(f"Retrieved heatmap: {key}")
|
||||
return heatmap_data
|
||||
|
||||
except Exception as e:
|
||||
self.stats['errors'] += 1
|
||||
logger.error(f"Error retrieving heatmap for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
async def cache_orderbook(self, orderbook) -> bool:
|
||||
"""
|
||||
Cache order book data.
|
||||
|
||||
Args:
|
||||
orderbook: OrderBookSnapshot to cache
|
||||
|
||||
Returns:
|
||||
bool: True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
key = self.cache_keys.orderbook_key(orderbook.symbol, orderbook.exchange)
|
||||
return await self.set(key, orderbook)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error caching order book: {e}")
|
||||
return False
|
||||
|
||||
async def get_orderbook(self, symbol: str, exchange: str):
|
||||
"""
|
||||
Get cached order book data.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
exchange: Exchange name
|
||||
|
||||
Returns:
|
||||
OrderBookSnapshot: Cached order book or None if not found
|
||||
"""
|
||||
try:
|
||||
key = self.cache_keys.orderbook_key(symbol, exchange)
|
||||
return await self.get(key)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving order book: {e}")
|
||||
return None
|
||||
|
||||
async def cache_metrics(self, metrics, symbol: str, exchange: str) -> bool:
|
||||
"""
|
||||
Cache metrics data.
|
||||
|
||||
Args:
|
||||
metrics: Metrics data to cache
|
||||
symbol: Trading symbol
|
||||
exchange: Exchange name
|
||||
|
||||
Returns:
|
||||
bool: True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
key = self.cache_keys.metrics_key(symbol, exchange)
|
||||
return await self.set(key, metrics)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error caching metrics: {e}")
|
||||
return False
|
||||
|
||||
async def get_metrics(self, symbol: str, exchange: str):
|
||||
"""
|
||||
Get cached metrics data.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
exchange: Exchange name
|
||||
|
||||
Returns:
|
||||
Metrics data or None if not found
|
||||
"""
|
||||
try:
|
||||
key = self.cache_keys.metrics_key(symbol, exchange)
|
||||
return await self.get(key)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving metrics: {e}")
|
||||
return None
|
||||
|
||||
async def cache_exchange_status(self, exchange: str, status_data) -> bool:
|
||||
"""
|
||||
Cache exchange status.
|
||||
|
||||
Args:
|
||||
exchange: Exchange name
|
||||
status_data: Status data to cache
|
||||
|
||||
Returns:
|
||||
bool: True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
key = self.cache_keys.status_key(exchange)
|
||||
return await self.set(key, status_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error caching exchange status: {e}")
|
||||
return False
|
||||
|
||||
async def get_exchange_status(self, exchange: str):
|
||||
"""
|
||||
Get cached exchange status.
|
||||
|
||||
Args:
|
||||
exchange: Exchange name
|
||||
|
||||
Returns:
|
||||
Status data or None if not found
|
||||
"""
|
||||
try:
|
||||
key = self.cache_keys.status_key(exchange)
|
||||
return await self.get(key)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving exchange status: {e}")
|
||||
return None
|
||||
|
||||
async def cleanup_expired_keys(self) -> int:
|
||||
"""
|
||||
Clean up expired keys (Redis handles this automatically, but we can force it).
|
||||
|
||||
Returns:
|
||||
int: Number of keys cleaned up
|
||||
"""
|
||||
try:
|
||||
# Get all keys
|
||||
all_keys = await self.keys("*")
|
||||
|
||||
# Check which ones are expired
|
||||
expired_count = 0
|
||||
for key in all_keys:
|
||||
ttl = await self.redis_client.ttl(key)
|
||||
if ttl == -2: # Key doesn't exist (expired)
|
||||
expired_count += 1
|
||||
|
||||
logger.debug(f"Found {expired_count} expired keys")
|
||||
return expired_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up expired keys: {e}")
|
||||
return 0
|
||||
|
||||
def _update_avg_response_time(self, response_time: float) -> None:
|
||||
"""Update average response time"""
|
||||
total_operations = self.stats['gets'] + self.stats['sets']
|
||||
if total_operations > 0:
|
||||
self.stats['avg_response_time'] = (
|
||||
(self.stats['avg_response_time'] * (total_operations - 1) + response_time) /
|
||||
total_operations
|
||||
)
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get cache statistics"""
|
||||
total_operations = self.stats['gets'] + self.stats['sets']
|
||||
hit_rate = (self.stats['hits'] / max(self.stats['gets'], 1)) * 100
|
||||
|
||||
return {
|
||||
**self.stats,
|
||||
'total_operations': total_operations,
|
||||
'hit_rate_percentage': hit_rate,
|
||||
'serializer_stats': self.serializer.get_stats()
|
||||
}
|
||||
|
||||
def reset_stats(self) -> None:
|
||||
"""Reset cache statistics"""
|
||||
self.stats = {
|
||||
'gets': 0,
|
||||
'sets': 0,
|
||||
'deletes': 0,
|
||||
'hits': 0,
|
||||
'misses': 0,
|
||||
'errors': 0,
|
||||
'total_data_size': 0,
|
||||
'avg_response_time': 0.0
|
||||
}
|
||||
self.serializer.reset_stats()
|
||||
logger.info("Redis manager statistics reset")
|
||||
|
||||
async def health_check(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Perform comprehensive health check.
|
||||
|
||||
Returns:
|
||||
Dict: Health check results
|
||||
"""
|
||||
health = {
|
||||
'redis_ping': False,
|
||||
'connection_pool_size': 0,
|
||||
'memory_usage': 0,
|
||||
'connected_clients': 0,
|
||||
'total_keys': 0,
|
||||
'hit_rate': 0.0,
|
||||
'avg_response_time': self.stats['avg_response_time']
|
||||
}
|
||||
|
||||
try:
|
||||
# Test ping
|
||||
health['redis_ping'] = await self.ping()
|
||||
|
||||
# Get Redis info
|
||||
info = await self.info()
|
||||
if info:
|
||||
health['memory_usage'] = info.get('used_memory', 0)
|
||||
health['connected_clients'] = info.get('connected_clients', 0)
|
||||
|
||||
# Get key count
|
||||
all_keys = await self.keys("*")
|
||||
health['total_keys'] = len(all_keys)
|
||||
|
||||
# Calculate hit rate
|
||||
if self.stats['gets'] > 0:
|
||||
health['hit_rate'] = (self.stats['hits'] / self.stats['gets']) * 100
|
||||
|
||||
# Connection pool info
|
||||
if self.redis_pool:
|
||||
health['connection_pool_size'] = self.redis_pool.max_connections
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Health check error: {e}")
|
||||
|
||||
return health
|
||||
|
||||
|
||||
# Global Redis manager instance
|
||||
redis_manager = RedisManager()
|
347
COBY/tests/test_redis_manager.py
Normal file
347
COBY/tests/test_redis_manager.py
Normal file
@ -0,0 +1,347 @@
|
||||
"""
|
||||
Tests for Redis caching system.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from ..caching.redis_manager import RedisManager
|
||||
from ..caching.cache_keys import CacheKeys
|
||||
from ..caching.data_serializer import DataSerializer
|
||||
from ..models.core import OrderBookSnapshot, HeatmapData, PriceLevel, HeatmapPoint
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def redis_manager():
|
||||
"""Create and initialize Redis manager for testing"""
|
||||
manager = RedisManager()
|
||||
await manager.initialize()
|
||||
yield manager
|
||||
await manager.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cache_keys():
|
||||
"""Create cache keys helper"""
|
||||
return CacheKeys()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def data_serializer():
|
||||
"""Create data serializer"""
|
||||
return DataSerializer()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_orderbook():
|
||||
"""Create sample order book for testing"""
|
||||
return OrderBookSnapshot(
|
||||
symbol="BTCUSDT",
|
||||
exchange="binance",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
bids=[
|
||||
PriceLevel(price=50000.0, size=1.5),
|
||||
PriceLevel(price=49999.0, size=2.0)
|
||||
],
|
||||
asks=[
|
||||
PriceLevel(price=50001.0, size=1.0),
|
||||
PriceLevel(price=50002.0, size=1.5)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_heatmap():
|
||||
"""Create sample heatmap for testing"""
|
||||
heatmap = HeatmapData(
|
||||
symbol="BTCUSDT",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
bucket_size=1.0
|
||||
)
|
||||
|
||||
# Add some sample points
|
||||
heatmap.data = [
|
||||
HeatmapPoint(price=50000.0, volume=1.5, intensity=0.8, side='bid'),
|
||||
HeatmapPoint(price=50001.0, volume=1.0, intensity=0.6, side='ask'),
|
||||
HeatmapPoint(price=49999.0, volume=2.0, intensity=1.0, side='bid'),
|
||||
HeatmapPoint(price=50002.0, volume=1.5, intensity=0.7, side='ask')
|
||||
]
|
||||
|
||||
return heatmap
|
||||
|
||||
|
||||
class TestCacheKeys:
|
||||
"""Test cases for CacheKeys"""
|
||||
|
||||
def test_orderbook_key_generation(self, cache_keys):
|
||||
"""Test order book key generation"""
|
||||
key = cache_keys.orderbook_key("BTCUSDT", "binance")
|
||||
assert key == "ob:binance:BTCUSDT"
|
||||
|
||||
def test_heatmap_key_generation(self, cache_keys):
|
||||
"""Test heatmap key generation"""
|
||||
# Exchange-specific heatmap
|
||||
key1 = cache_keys.heatmap_key("BTCUSDT", 1.0, "binance")
|
||||
assert key1 == "hm:binance:BTCUSDT:1.0"
|
||||
|
||||
# Consolidated heatmap
|
||||
key2 = cache_keys.heatmap_key("BTCUSDT", 1.0)
|
||||
assert key2 == "hm:consolidated:BTCUSDT:1.0"
|
||||
|
||||
def test_ttl_determination(self, cache_keys):
|
||||
"""Test TTL determination for different key types"""
|
||||
ob_key = cache_keys.orderbook_key("BTCUSDT", "binance")
|
||||
hm_key = cache_keys.heatmap_key("BTCUSDT", 1.0)
|
||||
|
||||
assert cache_keys.get_ttl(ob_key) == cache_keys.ORDERBOOK_TTL
|
||||
assert cache_keys.get_ttl(hm_key) == cache_keys.HEATMAP_TTL
|
||||
|
||||
def test_key_parsing(self, cache_keys):
|
||||
"""Test cache key parsing"""
|
||||
ob_key = cache_keys.orderbook_key("BTCUSDT", "binance")
|
||||
parsed = cache_keys.parse_key(ob_key)
|
||||
|
||||
assert parsed['type'] == 'orderbook'
|
||||
assert parsed['exchange'] == 'binance'
|
||||
assert parsed['symbol'] == 'BTCUSDT'
|
||||
|
||||
|
||||
class TestDataSerializer:
|
||||
"""Test cases for DataSerializer"""
|
||||
|
||||
def test_simple_data_serialization(self, data_serializer):
|
||||
"""Test serialization of simple data types"""
|
||||
test_data = {
|
||||
'string': 'test',
|
||||
'number': 42,
|
||||
'float': 3.14,
|
||||
'boolean': True,
|
||||
'list': [1, 2, 3],
|
||||
'nested': {'key': 'value'}
|
||||
}
|
||||
|
||||
# Serialize and deserialize
|
||||
serialized = data_serializer.serialize(test_data)
|
||||
deserialized = data_serializer.deserialize(serialized)
|
||||
|
||||
assert deserialized == test_data
|
||||
|
||||
def test_orderbook_serialization(self, data_serializer, sample_orderbook):
|
||||
"""Test order book serialization"""
|
||||
# Serialize and deserialize
|
||||
serialized = data_serializer.serialize(sample_orderbook)
|
||||
deserialized = data_serializer.deserialize(serialized)
|
||||
|
||||
assert isinstance(deserialized, OrderBookSnapshot)
|
||||
assert deserialized.symbol == sample_orderbook.symbol
|
||||
assert deserialized.exchange == sample_orderbook.exchange
|
||||
assert len(deserialized.bids) == len(sample_orderbook.bids)
|
||||
assert len(deserialized.asks) == len(sample_orderbook.asks)
|
||||
|
||||
def test_heatmap_serialization(self, data_serializer, sample_heatmap):
|
||||
"""Test heatmap serialization"""
|
||||
# Test specialized heatmap serialization
|
||||
serialized = data_serializer.serialize_heatmap(sample_heatmap)
|
||||
deserialized = data_serializer.deserialize_heatmap(serialized)
|
||||
|
||||
assert isinstance(deserialized, HeatmapData)
|
||||
assert deserialized.symbol == sample_heatmap.symbol
|
||||
assert deserialized.bucket_size == sample_heatmap.bucket_size
|
||||
assert len(deserialized.data) == len(sample_heatmap.data)
|
||||
|
||||
# Check first point
|
||||
original_point = sample_heatmap.data[0]
|
||||
deserialized_point = deserialized.data[0]
|
||||
assert deserialized_point.price == original_point.price
|
||||
assert deserialized_point.volume == original_point.volume
|
||||
assert deserialized_point.side == original_point.side
|
||||
|
||||
|
||||
class TestRedisManager:
|
||||
"""Test cases for RedisManager"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_set_get(self, redis_manager):
|
||||
"""Test basic set and get operations"""
|
||||
# Set a simple value
|
||||
key = "test:basic"
|
||||
value = {"test": "data", "number": 42}
|
||||
|
||||
success = await redis_manager.set(key, value, ttl=60)
|
||||
assert success is True
|
||||
|
||||
# Get the value back
|
||||
retrieved = await redis_manager.get(key)
|
||||
assert retrieved == value
|
||||
|
||||
# Clean up
|
||||
await redis_manager.delete(key)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_orderbook_caching(self, redis_manager, sample_orderbook):
|
||||
"""Test order book caching"""
|
||||
# Cache order book
|
||||
success = await redis_manager.cache_orderbook(sample_orderbook)
|
||||
assert success is True
|
||||
|
||||
# Retrieve order book
|
||||
retrieved = await redis_manager.get_orderbook(
|
||||
sample_orderbook.symbol,
|
||||
sample_orderbook.exchange
|
||||
)
|
||||
|
||||
assert retrieved is not None
|
||||
assert isinstance(retrieved, OrderBookSnapshot)
|
||||
assert retrieved.symbol == sample_orderbook.symbol
|
||||
assert retrieved.exchange == sample_orderbook.exchange
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_heatmap_caching(self, redis_manager, sample_heatmap):
|
||||
"""Test heatmap caching"""
|
||||
# Cache heatmap
|
||||
success = await redis_manager.set_heatmap(
|
||||
sample_heatmap.symbol,
|
||||
sample_heatmap,
|
||||
exchange="binance"
|
||||
)
|
||||
assert success is True
|
||||
|
||||
# Retrieve heatmap
|
||||
retrieved = await redis_manager.get_heatmap(
|
||||
sample_heatmap.symbol,
|
||||
exchange="binance"
|
||||
)
|
||||
|
||||
assert retrieved is not None
|
||||
assert isinstance(retrieved, HeatmapData)
|
||||
assert retrieved.symbol == sample_heatmap.symbol
|
||||
assert len(retrieved.data) == len(sample_heatmap.data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_operations(self, redis_manager):
|
||||
"""Test multi-get and multi-set operations"""
|
||||
# Prepare test data
|
||||
test_data = {
|
||||
"test:multi1": {"value": 1},
|
||||
"test:multi2": {"value": 2},
|
||||
"test:multi3": {"value": 3}
|
||||
}
|
||||
|
||||
# Multi-set
|
||||
success = await redis_manager.mset(test_data, ttl=60)
|
||||
assert success is True
|
||||
|
||||
# Multi-get
|
||||
keys = list(test_data.keys())
|
||||
values = await redis_manager.mget(keys)
|
||||
|
||||
assert len(values) == 3
|
||||
assert all(v is not None for v in values)
|
||||
|
||||
# Verify values
|
||||
for i, key in enumerate(keys):
|
||||
assert values[i] == test_data[key]
|
||||
|
||||
# Clean up
|
||||
for key in keys:
|
||||
await redis_manager.delete(key)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_key_expiration(self, redis_manager):
|
||||
"""Test key expiration"""
|
||||
key = "test:expiration"
|
||||
value = {"expires": "soon"}
|
||||
|
||||
# Set with short TTL
|
||||
success = await redis_manager.set(key, value, ttl=1)
|
||||
assert success is True
|
||||
|
||||
# Should exist immediately
|
||||
exists = await redis_manager.exists(key)
|
||||
assert exists is True
|
||||
|
||||
# Wait for expiration
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# Should not exist after expiration
|
||||
exists = await redis_manager.exists(key)
|
||||
assert exists is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_miss(self, redis_manager):
|
||||
"""Test cache miss behavior"""
|
||||
# Try to get non-existent key
|
||||
value = await redis_manager.get("test:nonexistent")
|
||||
assert value is None
|
||||
|
||||
# Check statistics
|
||||
stats = redis_manager.get_stats()
|
||||
assert stats['misses'] > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check(self, redis_manager):
|
||||
"""Test Redis health check"""
|
||||
health = await redis_manager.health_check()
|
||||
|
||||
assert isinstance(health, dict)
|
||||
assert 'redis_ping' in health
|
||||
assert 'total_keys' in health
|
||||
assert 'hit_rate' in health
|
||||
|
||||
# Should be able to ping
|
||||
assert health['redis_ping'] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_statistics_tracking(self, redis_manager):
|
||||
"""Test statistics tracking"""
|
||||
# Reset stats
|
||||
redis_manager.reset_stats()
|
||||
|
||||
# Perform some operations
|
||||
await redis_manager.set("test:stats1", {"data": 1})
|
||||
await redis_manager.set("test:stats2", {"data": 2})
|
||||
await redis_manager.get("test:stats1")
|
||||
await redis_manager.get("test:nonexistent")
|
||||
|
||||
# Check statistics
|
||||
stats = redis_manager.get_stats()
|
||||
|
||||
assert stats['sets'] >= 2
|
||||
assert stats['gets'] >= 2
|
||||
assert stats['hits'] >= 1
|
||||
assert stats['misses'] >= 1
|
||||
assert stats['total_operations'] >= 4
|
||||
|
||||
# Clean up
|
||||
await redis_manager.delete("test:stats1")
|
||||
await redis_manager.delete("test:stats2")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run simple tests
|
||||
async def simple_test():
|
||||
manager = RedisManager()
|
||||
await manager.initialize()
|
||||
|
||||
# Test basic operations
|
||||
success = await manager.set("test", {"simple": "test"}, ttl=60)
|
||||
print(f"Set operation: {'SUCCESS' if success else 'FAILED'}")
|
||||
|
||||
value = await manager.get("test")
|
||||
print(f"Get operation: {'SUCCESS' if value else 'FAILED'}")
|
||||
|
||||
# Test ping
|
||||
ping_result = await manager.ping()
|
||||
print(f"Ping test: {'SUCCESS' if ping_result else 'FAILED'}")
|
||||
|
||||
# Get statistics
|
||||
stats = manager.get_stats()
|
||||
print(f"Statistics: {stats}")
|
||||
|
||||
# Clean up
|
||||
await manager.delete("test")
|
||||
await manager.close()
|
||||
|
||||
print("Simple Redis test completed")
|
||||
|
||||
asyncio.run(simple_test())
|
Reference in New Issue
Block a user