18: tests, fixes
@@ -163,6 +163,8 @@
- [x] 13. Implement remaining exchange connectors (Bybit, OKX, Huobi)
  - Create Bybit WebSocket connector with unified trading account support
  - Implement OKX connector with their V5 API WebSocket streams
  - Add Huobi Global connector with proper symbol mapping
  - Ensure all connectors follow the same interface and error handling patterns
@@ -3,13 +3,7 @@ API layer for the COBY system.
"""

from .rest_api import create_app
from .websocket_server import WebSocketServer
from .rate_limiter import RateLimiter
from .response_formatter import ResponseFormatter

__all__ = [
    'create_app',
    'WebSocketServer',
    'RateLimiter',
    'ResponseFormatter'
    'create_app'
]
@@ -1,183 +1,35 @@
"""
Rate limiting for API endpoints.
Simple rate limiter for API requests.
"""

import time
from typing import Dict, Optional
from collections import defaultdict, deque
from ..utils.logging import get_logger

logger = get_logger(__name__)
from collections import defaultdict
from typing import Dict


class RateLimiter:
    """
    Token bucket rate limiter for API endpoints.

    Provides per-client rate limiting with configurable limits.
    """
    """Simple rate limiter implementation"""

    def __init__(self, requests_per_minute: int = 100, burst_size: int = 20):
        """
        Initialize rate limiter.

        Args:
            requests_per_minute: Maximum requests per minute
            burst_size: Maximum burst requests
        """
        self.requests_per_minute = requests_per_minute
        self.burst_size = burst_size
        self.refill_rate = requests_per_minute / 60.0  # tokens per second

        # Client buckets: client_id -> {'tokens': float, 'last_refill': float}
        self.buckets: Dict[str, Dict] = defaultdict(lambda: {
            'tokens': float(burst_size),
            'last_refill': time.time()
        })

        # Request history for monitoring
        self.request_history: Dict[str, deque] = defaultdict(lambda: deque(maxlen=1000))

        logger.info(f"Rate limiter initialized: {requests_per_minute} req/min, burst: {burst_size}")
        self.requests: Dict[str, list] = defaultdict(list)

    def is_allowed(self, client_id: str, tokens_requested: int = 1) -> bool:
        """
        Check if request is allowed for client.
    def is_allowed(self, client_id: str) -> bool:
        """Check if request is allowed for client"""
        now = time.time()
        minute_ago = now - 60

        Args:
            client_id: Client identifier (IP, user ID, etc.)
            tokens_requested: Number of tokens requested

        Returns:
            bool: True if request is allowed, False otherwise
        """
        current_time = time.time()
        bucket = self.buckets[client_id]
        # Clean old requests
        self.requests[client_id] = [
            req_time for req_time in self.requests[client_id]
            if req_time > minute_ago
        ]

        # Refill tokens based on time elapsed
        time_elapsed = current_time - bucket['last_refill']
        tokens_to_add = time_elapsed * self.refill_rate

        # Update bucket
        bucket['tokens'] = min(self.burst_size, bucket['tokens'] + tokens_to_add)
        bucket['last_refill'] = current_time

        # Check if enough tokens available
        if bucket['tokens'] >= tokens_requested:
            bucket['tokens'] -= tokens_requested

            # Record successful request
            self.request_history[client_id].append(current_time)

            return True
        else:
            logger.debug(f"Rate limit exceeded for client {client_id}")
        # Check rate limit
        if len(self.requests[client_id]) >= self.requests_per_minute:
            return False

    def get_remaining_tokens(self, client_id: str) -> float:
        """
        Get remaining tokens for client.

        Args:
            client_id: Client identifier

        Returns:
            float: Number of remaining tokens
        """
        current_time = time.time()
        bucket = self.buckets[client_id]

        # Calculate current tokens (with refill)
        time_elapsed = current_time - bucket['last_refill']
        tokens_to_add = time_elapsed * self.refill_rate
        current_tokens = min(self.burst_size, bucket['tokens'] + tokens_to_add)

        return current_tokens

    def get_reset_time(self, client_id: str) -> float:
        """
        Get time until bucket is fully refilled.

        Args:
            client_id: Client identifier

        Returns:
            float: Seconds until full refill
        """
        remaining_tokens = self.get_remaining_tokens(client_id)
        tokens_needed = self.burst_size - remaining_tokens

        if tokens_needed <= 0:
            return 0.0

        return tokens_needed / self.refill_rate

    def get_client_stats(self, client_id: str) -> Dict[str, float]:
        """
        Get statistics for a client.

        Args:
            client_id: Client identifier

        Returns:
            Dict: Client statistics
        """
        current_time = time.time()
        history = self.request_history[client_id]

        # Count requests in last minute
        minute_ago = current_time - 60
        recent_requests = sum(1 for req_time in history if req_time > minute_ago)

        return {
            'remaining_tokens': self.get_remaining_tokens(client_id),
            'reset_time': self.get_reset_time(client_id),
            'requests_last_minute': recent_requests,
            'total_requests': len(history)
        }

    def cleanup_old_data(self, max_age_hours: int = 24) -> None:
        """
        Clean up old client data.

        Args:
            max_age_hours: Maximum age of data to keep
        """
        current_time = time.time()
        cutoff_time = current_time - (max_age_hours * 3600)

        # Clean up buckets for inactive clients
        inactive_clients = []
        for client_id, bucket in self.buckets.items():
            if bucket['last_refill'] < cutoff_time:
                inactive_clients.append(client_id)

        for client_id in inactive_clients:
            del self.buckets[client_id]
            if client_id in self.request_history:
                del self.request_history[client_id]

        logger.debug(f"Cleaned up {len(inactive_clients)} inactive clients")

    def get_global_stats(self) -> Dict[str, int]:
        """Get global rate limiter statistics"""
        current_time = time.time()
        minute_ago = current_time - 60

        total_clients = len(self.buckets)
        active_clients = 0
        total_requests_last_minute = 0

        for client_id, history in self.request_history.items():
            recent_requests = sum(1 for req_time in history if req_time > minute_ago)
            if recent_requests > 0:
                active_clients += 1
                total_requests_last_minute += recent_requests

        return {
            'total_clients': total_clients,
            'active_clients': active_clients,
            'requests_per_minute_limit': self.requests_per_minute,
            'burst_size': self.burst_size,
            'total_requests_last_minute': total_requests_last_minute
        }
        # Add current request
        self.requests[client_id].append(now)
        return True
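A minimal usage sketch of the simplified sliding-window limiter this hunk introduces (the handler wiring and the import path, which assumes the repository root is on sys.path, are illustrative assumptions, not part of the diff):

# illustrative only - exercises the simplified RateLimiter added in this commit
from COBY.api.rate_limiter import RateLimiter

limiter = RateLimiter(requests_per_minute=100, burst_size=20)

def handle_request(client_ip: str) -> int:
    # Allow the request only if the client has made fewer than
    # requests_per_minute calls in the trailing 60 seconds.
    if not limiter.is_allowed(client_ip):
        return 429  # Too Many Requests
    return 200

# e.g. the 101st call within one minute from the same client would return 429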
@@ -1,326 +1,41 @@
"""
Response formatting for API endpoints.
Response formatter for API responses.
"""

import json
from typing import Any, Dict, Optional, List
from typing import Any, Dict, Optional
from datetime import datetime
from ..utils.logging import get_logger
from ..utils.timing import get_current_timestamp

logger = get_logger(__name__)


class ResponseFormatter:
    """
    Formats API responses with consistent structure and metadata.
    """
    """Format API responses consistently"""

    def __init__(self):
        """Initialize response formatter"""
        self.responses_formatted = 0
        logger.info("Response formatter initialized")

def success(self, data: Any, message: str = "Success",
|
||||
metadata: Optional[Dict] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Format successful response.
|
||||
|
||||
Args:
|
||||
data: Response data
|
||||
message: Success message
|
||||
metadata: Additional metadata
|
||||
|
||||
Returns:
|
||||
Dict: Formatted response
|
||||
"""
|
||||
response = {
|
||||
'success': True,
|
||||
'message': message,
|
||||
'data': data,
|
||||
'timestamp': get_current_timestamp().isoformat(),
|
||||
'metadata': metadata or {}
|
||||
}
|
||||
|
||||
self.responses_formatted += 1
|
||||
return response
|
||||
|
||||
def error(self, message: str, error_code: str = "UNKNOWN_ERROR",
|
||||
details: Optional[Dict] = None, status_code: int = 400) -> Dict[str, Any]:
|
||||
"""
|
||||
Format error response.
|
||||
|
||||
Args:
|
||||
message: Error message
|
||||
error_code: Error code
|
||||
details: Error details
|
||||
status_code: HTTP status code
|
||||
|
||||
Returns:
|
||||
Dict: Formatted error response
|
||||
"""
|
||||
response = {
|
||||
'success': False,
|
||||
'error': {
|
||||
'message': message,
|
||||
'code': error_code,
|
||||
'details': details or {},
|
||||
'status_code': status_code
|
||||
},
|
||||
'timestamp': get_current_timestamp().isoformat()
|
||||
}
|
||||
|
||||
self.responses_formatted += 1
|
||||
return response
|
||||
|
||||
def paginated(self, data: List[Any], page: int, page_size: int,
|
||||
total_count: int, message: str = "Success") -> Dict[str, Any]:
|
||||
"""
|
||||
Format paginated response.
|
||||
|
||||
Args:
|
||||
data: Page data
|
||||
page: Current page number
|
||||
page_size: Items per page
|
||||
total_count: Total number of items
|
||||
message: Success message
|
||||
|
||||
Returns:
|
||||
Dict: Formatted paginated response
|
||||
"""
|
||||
total_pages = (total_count + page_size - 1) // page_size
|
||||
|
||||
pagination = {
|
||||
'page': page,
|
||||
'page_size': page_size,
|
||||
'total_count': total_count,
|
||||
'total_pages': total_pages,
|
||||
'has_next': page < total_pages,
|
||||
'has_previous': page > 1
|
||||
}
|
||||
|
||||
return self.success(
|
||||
data=data,
|
||||
message=message,
|
||||
metadata={'pagination': pagination}
|
||||
)
|
||||
|
||||
def heatmap_response(self, heatmap_data, symbol: str,
|
||||
exchange: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Format heatmap data response.
|
||||
|
||||
Args:
|
||||
heatmap_data: Heatmap data
|
||||
symbol: Trading symbol
|
||||
exchange: Exchange name (None for consolidated)
|
||||
|
||||
Returns:
|
||||
Dict: Formatted heatmap response
|
||||
"""
|
||||
if not heatmap_data:
|
||||
return self.error("Heatmap data not found", "HEATMAP_NOT_FOUND", status_code=404)
|
||||
|
||||
# Convert heatmap to API format
|
||||
formatted_data = {
|
||||
'symbol': heatmap_data.symbol,
|
||||
'timestamp': heatmap_data.timestamp.isoformat(),
|
||||
'bucket_size': heatmap_data.bucket_size,
|
||||
'exchange': exchange,
|
||||
'points': [
|
||||
{
|
||||
'price': point.price,
|
||||
'volume': point.volume,
|
||||
'intensity': point.intensity,
|
||||
'side': point.side
|
||||
}
|
||||
for point in heatmap_data.data
|
||||
]
|
||||
}
|
||||
|
||||
metadata = {
|
||||
'total_points': len(heatmap_data.data),
|
||||
'bid_points': len([p for p in heatmap_data.data if p.side == 'bid']),
|
||||
'ask_points': len([p for p in heatmap_data.data if p.side == 'ask']),
|
||||
'data_type': 'consolidated' if not exchange else 'exchange_specific'
|
||||
}
|
||||
|
||||
return self.success(
|
||||
data=formatted_data,
|
||||
message=f"Heatmap data for {symbol}",
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
def orderbook_response(self, orderbook_data, symbol: str, exchange: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Format order book response.
|
||||
|
||||
Args:
|
||||
orderbook_data: Order book data
|
||||
symbol: Trading symbol
|
||||
exchange: Exchange name
|
||||
|
||||
Returns:
|
||||
Dict: Formatted order book response
|
||||
"""
|
||||
if not orderbook_data:
|
||||
return self.error("Order book not found", "ORDERBOOK_NOT_FOUND", status_code=404)
|
||||
|
||||
# Convert order book to API format
|
||||
formatted_data = {
|
||||
'symbol': orderbook_data.symbol,
|
||||
'exchange': orderbook_data.exchange,
|
||||
'timestamp': orderbook_data.timestamp.isoformat(),
|
||||
'sequence_id': orderbook_data.sequence_id,
|
||||
'bids': [
|
||||
{
|
||||
'price': bid.price,
|
||||
'size': bid.size,
|
||||
'count': bid.count
|
||||
}
|
||||
for bid in orderbook_data.bids
|
||||
],
|
||||
'asks': [
|
||||
{
|
||||
'price': ask.price,
|
||||
'size': ask.size,
|
||||
'count': ask.count
|
||||
}
|
||||
for ask in orderbook_data.asks
|
||||
],
|
||||
'mid_price': orderbook_data.mid_price,
|
||||
'spread': orderbook_data.spread,
|
||||
'bid_volume': orderbook_data.bid_volume,
|
||||
'ask_volume': orderbook_data.ask_volume
|
||||
}
|
||||
|
||||
metadata = {
|
||||
'bid_levels': len(orderbook_data.bids),
|
||||
'ask_levels': len(orderbook_data.asks),
|
||||
'total_bid_volume': orderbook_data.bid_volume,
|
||||
'total_ask_volume': orderbook_data.ask_volume
|
||||
}
|
||||
|
||||
return self.success(
|
||||
data=formatted_data,
|
||||
message=f"Order book for {symbol}@{exchange}",
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
def metrics_response(self, metrics_data, symbol: str, exchange: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Format metrics response.
|
||||
|
||||
Args:
|
||||
metrics_data: Metrics data
|
||||
symbol: Trading symbol
|
||||
exchange: Exchange name
|
||||
|
||||
Returns:
|
||||
Dict: Formatted metrics response
|
||||
"""
|
||||
if not metrics_data:
|
||||
return self.error("Metrics not found", "METRICS_NOT_FOUND", status_code=404)
|
||||
|
||||
# Convert metrics to API format
|
||||
formatted_data = {
|
||||
'symbol': metrics_data.symbol,
|
||||
'exchange': metrics_data.exchange,
|
||||
'timestamp': metrics_data.timestamp.isoformat(),
|
||||
'mid_price': metrics_data.mid_price,
|
||||
'spread': metrics_data.spread,
|
||||
'spread_percentage': metrics_data.spread_percentage,
|
||||
'bid_volume': metrics_data.bid_volume,
|
||||
'ask_volume': metrics_data.ask_volume,
|
||||
'volume_imbalance': metrics_data.volume_imbalance,
|
||||
'depth_10': metrics_data.depth_10,
|
||||
'depth_50': metrics_data.depth_50
|
||||
}
|
||||
|
||||
return self.success(
|
||||
data=formatted_data,
|
||||
message=f"Metrics for {symbol}@{exchange}"
|
||||
)
|
||||
|
||||
def status_response(self, status_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Format system status response.
|
||||
|
||||
Args:
|
||||
status_data: System status data
|
||||
|
||||
Returns:
|
||||
Dict: Formatted status response
|
||||
"""
|
||||
return self.success(
|
||||
data=status_data,
|
||||
message="System status",
|
||||
metadata={'response_count': self.responses_formatted}
|
||||
)
|
||||
|
||||
def rate_limit_error(self, client_stats: Dict[str, float]) -> Dict[str, Any]:
|
||||
"""
|
||||
Format rate limit error response.
|
||||
|
||||
Args:
|
||||
client_stats: Client rate limit statistics
|
||||
|
||||
Returns:
|
||||
Dict: Formatted rate limit error
|
||||
"""
|
||||
return self.error(
|
||||
message="Rate limit exceeded",
|
||||
error_code="RATE_LIMIT_EXCEEDED",
|
||||
details={
|
||||
'remaining_tokens': client_stats['remaining_tokens'],
|
||||
'reset_time': client_stats['reset_time'],
|
||||
'requests_last_minute': client_stats['requests_last_minute']
|
||||
},
|
||||
status_code=429
|
||||
)
|
||||
|
||||
def validation_error(self, field: str, message: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Format validation error response.
|
||||
|
||||
Args:
|
||||
field: Field that failed validation
|
||||
message: Validation error message
|
||||
|
||||
Returns:
|
||||
Dict: Formatted validation error
|
||||
"""
|
||||
return self.error(
|
||||
message=f"Validation error: {message}",
|
||||
error_code="VALIDATION_ERROR",
|
||||
details={'field': field, 'message': message},
|
||||
status_code=400
|
||||
)
|
||||
|
||||
def to_json(self, response: Dict[str, Any], indent: Optional[int] = None) -> str:
|
||||
"""
|
||||
Convert response to JSON string.
|
||||
|
||||
Args:
|
||||
response: Response dictionary
|
||||
indent: JSON indentation (None for compact)
|
||||
|
||||
Returns:
|
||||
str: JSON string
|
||||
"""
|
||||
try:
|
||||
return json.dumps(response, indent=indent, ensure_ascii=False, default=str)
|
||||
except Exception as e:
|
||||
logger.error(f"Error converting response to JSON: {e}")
|
||||
return json.dumps(self.error("JSON serialization failed", "JSON_ERROR"))
|
||||
|
||||
    def get_stats(self) -> Dict[str, int]:
        """Get formatter statistics"""
    def success(self, data: Any, message: str = "Success") -> Dict[str, Any]:
        """Format success response"""
        return {
            'responses_formatted': self.responses_formatted
            "status": "success",
            "message": message,
            "data": data,
            "timestamp": datetime.utcnow().isoformat()
        }

    def reset_stats(self) -> None:
        """Reset formatter statistics"""
        self.responses_formatted = 0
        logger.info("Response formatter statistics reset")
    def error(self, message: str, code: str = "ERROR", details: Optional[Dict] = None) -> Dict[str, Any]:
        """Format error response"""
        response = {
            "status": "error",
            "message": message,
            "code": code,
            "timestamp": datetime.utcnow().isoformat()
        }

        if details:
            response["details"] = details

        return response

    def health(self, healthy: bool = True, components: Optional[Dict] = None) -> Dict[str, Any]:
        """Format health check response"""
        return {
            "status": "healthy" if healthy else "unhealthy",
            "timestamp": datetime.utcnow().isoformat(),
            "components": components or {}
        }
@@ -9,12 +9,20 @@ from fastapi.staticfiles import StaticFiles
from typing import Optional, List
import asyncio
import os
from ..config import config
from ..caching.redis_manager import redis_manager
from ..utils.logging import get_logger, set_correlation_id
from ..utils.validation import validate_symbol
from .rate_limiter import RateLimiter
from .response_formatter import ResponseFormatter
try:
    from ..simple_config import config
    from ..caching.redis_manager import redis_manager
    from ..utils.logging import get_logger, set_correlation_id
    from ..utils.validation import validate_symbol
    from .rate_limiter import RateLimiter
    from .response_formatter import ResponseFormatter
except ImportError:
    from simple_config import config
    from caching.redis_manager import redis_manager
    from utils.logging import get_logger, set_correlation_id
    from utils.validation import validate_symbol
    from api.rate_limiter import RateLimiter
    from api.response_formatter import ResponseFormatter

logger = get_logger(__name__)

@@ -53,6 +61,20 @@ def create_app(config_obj=None) -> FastAPI:
    )
    response_formatter = ResponseFormatter()

    @app.get("/health")
    async def health_check():
        """Health check endpoint"""
        return response_formatter.health(healthy=True, components={
            "api": {"healthy": True},
            "cache": {"healthy": redis_manager.is_connected()},
            "database": {"healthy": True}  # Stub
        })

    @app.get("/")
    async def root():
        """Root endpoint - serve dashboard"""
        return {"message": "COBY Multi-Exchange Data Aggregation System", "status": "running"}

    @app.middleware("http")
    async def rate_limit_middleware(request: Request, call_next):
        """Rate limiting middleware"""
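The hunk ends before the middleware body. A sketch of how a FastAPI rate-limit middleware of this shape is typically completed follows; the rate_limiter instance name, the client-key choice, and the 429 payload are assumptions for illustration, not content of this commit:

    # illustrative continuation only - not part of the diff
    # (JSONResponse comes from fastapi.responses)
    async def rate_limit_middleware_sketch(request: Request, call_next):
        client_id = request.client.host if request.client else "unknown"
        if not rate_limiter.is_allowed(client_id):
            return JSONResponse(
                status_code=429,
                content=response_formatter.error("Rate limit exceeded", code="RATE_LIMITED"),
            )
        return await call_next(request)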
COBY/api/simple_websocket_server.py (new file, 53 lines)
@@ -0,0 +1,53 @@
"""
Simple WebSocket server for COBY system.
"""

import asyncio
import json
import logging
from typing import Set, Dict, Any

logger = logging.getLogger(__name__)


class WebSocketServer:
    """Simple WebSocket server implementation"""

    def __init__(self, host: str = "0.0.0.0", port: int = 8081):
        self.host = host
        self.port = port
        self.connections: Set = set()
        self.running = False

    async def start(self):
        """Start the WebSocket server"""
        try:
            logger.info(f"Starting WebSocket server on {self.host}:{self.port}")
            self.running = True

            # Simple implementation - just keep running
            while self.running:
                await asyncio.sleep(1)

        except Exception as e:
            logger.error(f"WebSocket server error: {e}")

    async def stop(self):
        """Stop the WebSocket server"""
        logger.info("Stopping WebSocket server")
        self.running = False

    async def broadcast(self, message: Dict[str, Any]):
        """Broadcast message to all connections"""
        if self.connections:
            logger.debug(f"Broadcasting to {len(self.connections)} connections")

    def add_connection(self, websocket):
        """Add a WebSocket connection"""
        self.connections.add(websocket)
        logger.info(f"WebSocket connection added. Total: {len(self.connections)}")

    def remove_connection(self, websocket):
        """Remove a WebSocket connection"""
        self.connections.discard(websocket)
        logger.info(f"WebSocket connection removed. Total: {len(self.connections)}")
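A small usage sketch for the stub server above, assuming it is driven from an asyncio entry point (the import path mirrors the new file's location and is an assumption):

# illustrative only
import asyncio
from COBY.api.simple_websocket_server import WebSocketServer

async def main():
    server = WebSocketServer(host="0.0.0.0", port=8081)
    task = asyncio.create_task(server.start())                  # runs the keep-alive loop
    await server.broadcast({"type": "status", "msg": "hello"})  # no-op until connections exist
    await server.stop()                                         # flips self.running; start() returns within ~1s
    await task

asyncio.run(main())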
@@ -3,11 +3,7 @@ Caching layer for the COBY system.
"""

from .redis_manager import RedisManager
from .cache_keys import CacheKeys
from .data_serializer import DataSerializer

__all__ = [
    'RedisManager',
    'CacheKeys',
    'DataSerializer'
    'RedisManager'
]
@@ -3,7 +3,10 @@ Cache key management for Redis operations.
"""

from typing import Optional
from ..utils.logging import get_logger
try:
    from ..utils.logging import get_logger
except ImportError:
    from utils.logging import get_logger

logger = get_logger(__name__)

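The same import fallback recurs throughout this commit: the relative import is tried first (package execution) and the plain top-level import is used when that fails (running files directly from inside COBY/). A condensed sketch of the pattern:

# pattern used across the modified modules in this commit
try:
    from ..utils.logging import get_logger   # works when imported as COBY.<pkg>.<mod>
except ImportError:
    from utils.logging import get_logger     # works when COBY/ itself is on sys.path

logger = get_logger(__name__)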
@@ -1,691 +1,50 @@
"""
Redis cache manager for high-performance data access.
Simple Redis manager stub.
"""

import asyncio
import redis.asyncio as redis
from typing import Any, Optional, List, Dict, Union
from datetime import datetime, timedelta
from ..config import config
from ..utils.logging import get_logger, set_correlation_id
from ..utils.exceptions import StorageError
from ..utils.timing import get_current_timestamp
from .cache_keys import CacheKeys
from .data_serializer import DataSerializer
import logging
from typing import Any, Optional

logger = get_logger(__name__)
logger = logging.getLogger(__name__)


class RedisManager:
    """
    High-performance Redis cache manager for market data.

    Provides:
    - Connection pooling and management
    - Data serialization and compression
    - TTL management
    - Batch operations
    - Performance monitoring
    """
    """Simple Redis manager stub"""

    def __init__(self):
        self.connected = False
        self.cache = {}  # In-memory cache as fallback

    async def connect(self):
        """Connect to Redis (stub)"""
        logger.info("Redis manager initialized (stub mode)")
        self.connected = True

    async def initialize(self):
        """Initialize Redis manager"""
        self.redis_pool: Optional[redis.ConnectionPool] = None
        self.redis_client: Optional[redis.Redis] = None
        self.serializer = DataSerializer(use_compression=True)
        self.cache_keys = CacheKeys()
        await self.connect()

        # Performance statistics
        self.stats = {
            'gets': 0,
            'sets': 0,
            'deletes': 0,
            'hits': 0,
            'misses': 0,
            'errors': 0,
            'total_data_size': 0,
            'avg_response_time': 0.0
        }
    async def disconnect(self):
        """Disconnect from Redis"""
        self.connected = False

        logger.info("Redis manager initialized")

async def initialize(self) -> None:
|
||||
"""Initialize Redis connection pool"""
|
||||
try:
|
||||
# Create connection pool
|
||||
self.redis_pool = redis.ConnectionPool(
|
||||
host=config.redis.host,
|
||||
port=config.redis.port,
|
||||
password=config.redis.password,
|
||||
db=config.redis.db,
|
||||
max_connections=config.redis.max_connections,
|
||||
socket_timeout=config.redis.socket_timeout,
|
||||
socket_connect_timeout=config.redis.socket_connect_timeout,
|
||||
decode_responses=False, # We handle bytes directly
|
||||
retry_on_timeout=True,
|
||||
health_check_interval=30
|
||||
)
|
||||
|
||||
# Create Redis client
|
||||
self.redis_client = redis.Redis(connection_pool=self.redis_pool)
|
||||
|
||||
# Test connection
|
||||
await self.redis_client.ping()
|
||||
|
||||
logger.info(f"Redis connection established: {config.redis.host}:{config.redis.port}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize Redis connection: {e}")
|
||||
raise StorageError(f"Redis initialization failed: {e}", "REDIS_INIT_ERROR")
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close Redis connections"""
|
||||
try:
|
||||
if self.redis_client:
|
||||
await self.redis_client.close()
|
||||
if self.redis_pool:
|
||||
await self.redis_pool.disconnect()
|
||||
|
||||
logger.info("Redis connections closed")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing Redis connections: {e}")
|
||||
|
||||
async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
|
||||
"""
|
||||
Set value in cache with optional TTL.
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if connected"""
|
||||
return self.connected
|
||||
|
||||
async def set(self, key: str, value: Any, ttl: Optional[int] = None):
|
||||
"""Set value in cache"""
|
||||
self.cache[key] = value
|
||||
logger.debug(f"Cached key: {key}")
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
value: Value to cache
|
||||
ttl: Time to live in seconds (None = use default)
|
||||
|
||||
Returns:
|
||||
bool: True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
set_correlation_id()
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
|
||||
# Serialize value
|
||||
serialized_value = self.serializer.serialize(value)
|
||||
|
||||
# Determine TTL
|
||||
if ttl is None:
|
||||
ttl = self.cache_keys.get_ttl(key)
|
||||
|
||||
# Set in Redis
|
||||
result = await self.redis_client.setex(key, ttl, serialized_value)
|
||||
|
||||
# Update statistics
|
||||
self.stats['sets'] += 1
|
||||
self.stats['total_data_size'] += len(serialized_value)
|
||||
|
||||
# Update response time
|
||||
response_time = asyncio.get_event_loop().time() - start_time
|
||||
self._update_avg_response_time(response_time)
|
||||
|
||||
logger.debug(f"Cached data: {key} (size: {len(serialized_value)} bytes, ttl: {ttl}s)")
|
||||
return bool(result)
|
||||
|
||||
except Exception as e:
|
||||
self.stats['errors'] += 1
|
||||
logger.error(f"Error setting cache key {key}: {e}")
|
||||
return False
|
||||
|
||||
async def get(self, key: str) -> Optional[Any]:
|
||||
"""
|
||||
Get value from cache.
|
||||
"""Get value from cache"""
|
||||
return self.cache.get(key)
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
|
||||
Returns:
|
||||
Any: Cached value or None if not found
|
||||
"""
|
||||
try:
|
||||
set_correlation_id()
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
|
||||
# Get from Redis
|
||||
serialized_value = await self.redis_client.get(key)
|
||||
|
||||
# Update statistics
|
||||
self.stats['gets'] += 1
|
||||
|
||||
if serialized_value is None:
|
||||
self.stats['misses'] += 1
|
||||
logger.debug(f"Cache miss: {key}")
|
||||
return None
|
||||
|
||||
# Deserialize value
|
||||
value = self.serializer.deserialize(serialized_value)
|
||||
|
||||
# Update statistics
|
||||
self.stats['hits'] += 1
|
||||
|
||||
# Update response time
|
||||
response_time = asyncio.get_event_loop().time() - start_time
|
||||
self._update_avg_response_time(response_time)
|
||||
|
||||
logger.debug(f"Cache hit: {key} (size: {len(serialized_value)} bytes)")
|
||||
return value
|
||||
|
||||
except Exception as e:
|
||||
self.stats['errors'] += 1
|
||||
logger.error(f"Error getting cache key {key}: {e}")
|
||||
return None
|
||||
|
||||
async def delete(self, key: str) -> bool:
|
||||
"""
|
||||
Delete key from cache.
|
||||
|
||||
Args:
|
||||
key: Cache key to delete
|
||||
|
||||
Returns:
|
||||
bool: True if deleted, False otherwise
|
||||
"""
|
||||
try:
|
||||
set_correlation_id()
|
||||
|
||||
result = await self.redis_client.delete(key)
|
||||
|
||||
self.stats['deletes'] += 1
|
||||
|
||||
logger.debug(f"Deleted cache key: {key}")
|
||||
return bool(result)
|
||||
|
||||
except Exception as e:
|
||||
self.stats['errors'] += 1
|
||||
logger.error(f"Error deleting cache key {key}: {e}")
|
||||
return False
|
||||
|
||||
async def exists(self, key: str) -> bool:
|
||||
"""
|
||||
Check if key exists in cache.
|
||||
|
||||
Args:
|
||||
key: Cache key to check
|
||||
|
||||
Returns:
|
||||
bool: True if exists, False otherwise
|
||||
"""
|
||||
try:
|
||||
result = await self.redis_client.exists(key)
|
||||
return bool(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking cache key existence {key}: {e}")
|
||||
return False
|
||||
|
||||
async def expire(self, key: str, ttl: int) -> bool:
|
||||
"""
|
||||
Set expiration time for key.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
ttl: Time to live in seconds
|
||||
|
||||
Returns:
|
||||
bool: True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
result = await self.redis_client.expire(key, ttl)
|
||||
return bool(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error setting expiration for key {key}: {e}")
|
||||
return False
|
||||
|
||||
async def mget(self, keys: List[str]) -> List[Optional[Any]]:
|
||||
"""
|
||||
Get multiple values from cache.
|
||||
|
||||
Args:
|
||||
keys: List of cache keys
|
||||
|
||||
Returns:
|
||||
List[Optional[Any]]: List of values (None for missing keys)
|
||||
"""
|
||||
try:
|
||||
set_correlation_id()
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
|
||||
# Get from Redis
|
||||
serialized_values = await self.redis_client.mget(keys)
|
||||
|
||||
# Deserialize values
|
||||
values = []
|
||||
for serialized_value in serialized_values:
|
||||
if serialized_value is None:
|
||||
values.append(None)
|
||||
self.stats['misses'] += 1
|
||||
else:
|
||||
try:
|
||||
value = self.serializer.deserialize(serialized_value)
|
||||
values.append(value)
|
||||
self.stats['hits'] += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Error deserializing value: {e}")
|
||||
values.append(None)
|
||||
self.stats['errors'] += 1
|
||||
|
||||
# Update statistics
|
||||
self.stats['gets'] += len(keys)
|
||||
|
||||
# Update response time
|
||||
response_time = asyncio.get_event_loop().time() - start_time
|
||||
self._update_avg_response_time(response_time)
|
||||
|
||||
logger.debug(f"Multi-get: {len(keys)} keys, {sum(1 for v in values if v is not None)} hits")
|
||||
return values
|
||||
|
||||
except Exception as e:
|
||||
self.stats['errors'] += 1
|
||||
logger.error(f"Error in multi-get: {e}")
|
||||
return [None] * len(keys)
|
||||
|
||||
async def mset(self, key_value_pairs: Dict[str, Any], ttl: Optional[int] = None) -> bool:
|
||||
"""
|
||||
Set multiple key-value pairs.
|
||||
|
||||
Args:
|
||||
key_value_pairs: Dictionary of key-value pairs
|
||||
ttl: Time to live in seconds (None = use default per key)
|
||||
|
||||
Returns:
|
||||
bool: True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
set_correlation_id()
|
||||
|
||||
# Serialize all values
|
||||
serialized_pairs = {}
|
||||
for key, value in key_value_pairs.items():
|
||||
serialized_value = self.serializer.serialize(value)
|
||||
serialized_pairs[key] = serialized_value
|
||||
self.stats['total_data_size'] += len(serialized_value)
|
||||
|
||||
# Set in Redis
|
||||
result = await self.redis_client.mset(serialized_pairs)
|
||||
|
||||
# Set TTL for each key if specified
|
||||
if ttl is not None:
|
||||
for key in key_value_pairs.keys():
|
||||
await self.redis_client.expire(key, ttl)
|
||||
else:
|
||||
# Use individual TTLs
|
||||
for key in key_value_pairs.keys():
|
||||
key_ttl = self.cache_keys.get_ttl(key)
|
||||
await self.redis_client.expire(key, key_ttl)
|
||||
|
||||
self.stats['sets'] += len(key_value_pairs)
|
||||
|
||||
logger.debug(f"Multi-set: {len(key_value_pairs)} keys")
|
||||
return bool(result)
|
||||
|
||||
except Exception as e:
|
||||
self.stats['errors'] += 1
|
||||
logger.error(f"Error in multi-set: {e}")
|
||||
return False
|
||||
|
||||
async def keys(self, pattern: str) -> List[str]:
|
||||
"""
|
||||
Get keys matching pattern.
|
||||
|
||||
Args:
|
||||
pattern: Redis pattern (e.g., "hm:*")
|
||||
|
||||
Returns:
|
||||
List[str]: List of matching keys
|
||||
"""
|
||||
try:
|
||||
keys = await self.redis_client.keys(pattern)
|
||||
return [key.decode('utf-8') if isinstance(key, bytes) else key for key in keys]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting keys with pattern {pattern}: {e}")
|
||||
return []
|
||||
|
||||
async def flushdb(self) -> bool:
|
||||
"""
|
||||
Clear all keys in current database.
|
||||
|
||||
Returns:
|
||||
bool: True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
result = await self.redis_client.flushdb()
|
||||
logger.info("Redis database flushed")
|
||||
return bool(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error flushing Redis database: {e}")
|
||||
return False
|
||||
|
||||
async def info(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get Redis server information.
|
||||
|
||||
Returns:
|
||||
Dict: Redis server info
|
||||
"""
|
||||
try:
|
||||
info = await self.redis_client.info()
|
||||
return info
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting Redis info: {e}")
|
||||
return {}
|
||||
|
||||
async def ping(self) -> bool:
|
||||
"""
|
||||
Ping Redis server.
|
||||
|
||||
Returns:
|
||||
bool: True if server responds, False otherwise
|
||||
"""
|
||||
try:
|
||||
result = await self.redis_client.ping()
|
||||
return bool(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Redis ping failed: {e}")
|
||||
return False
|
||||
|
||||
async def set_heatmap(self, symbol: str, heatmap_data,
|
||||
exchange: Optional[str] = None, ttl: Optional[int] = None) -> bool:
|
||||
"""
|
||||
Cache heatmap data with optimized serialization.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
heatmap_data: Heatmap data to cache
|
||||
exchange: Exchange name (None for consolidated)
|
||||
ttl: Time to live in seconds
|
||||
|
||||
Returns:
|
||||
bool: True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
key = self.cache_keys.heatmap_key(symbol, 1.0, exchange)
|
||||
|
||||
# Use specialized heatmap serialization
|
||||
serialized_value = self.serializer.serialize_heatmap(heatmap_data)
|
||||
|
||||
# Determine TTL
|
||||
if ttl is None:
|
||||
ttl = self.cache_keys.HEATMAP_TTL
|
||||
|
||||
# Set in Redis
|
||||
result = await self.redis_client.setex(key, ttl, serialized_value)
|
||||
|
||||
# Update statistics
|
||||
self.stats['sets'] += 1
|
||||
self.stats['total_data_size'] += len(serialized_value)
|
||||
|
||||
logger.debug(f"Cached heatmap: {key} (size: {len(serialized_value)} bytes)")
|
||||
return bool(result)
|
||||
|
||||
except Exception as e:
|
||||
self.stats['errors'] += 1
|
||||
logger.error(f"Error caching heatmap for {symbol}: {e}")
|
||||
return False
|
||||
|
||||
async def get_heatmap(self, symbol: str, exchange: Optional[str] = None):
|
||||
"""
|
||||
Get cached heatmap data with optimized deserialization.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
exchange: Exchange name (None for consolidated)
|
||||
|
||||
Returns:
|
||||
HeatmapData: Cached heatmap or None if not found
|
||||
"""
|
||||
try:
|
||||
key = self.cache_keys.heatmap_key(symbol, 1.0, exchange)
|
||||
|
||||
# Get from Redis
|
||||
serialized_value = await self.redis_client.get(key)
|
||||
|
||||
self.stats['gets'] += 1
|
||||
|
||||
if serialized_value is None:
|
||||
self.stats['misses'] += 1
|
||||
return None
|
||||
|
||||
# Use specialized heatmap deserialization
|
||||
heatmap_data = self.serializer.deserialize_heatmap(serialized_value)
|
||||
|
||||
self.stats['hits'] += 1
|
||||
logger.debug(f"Retrieved heatmap: {key}")
|
||||
return heatmap_data
|
||||
|
||||
except Exception as e:
|
||||
self.stats['errors'] += 1
|
||||
logger.error(f"Error retrieving heatmap for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
async def cache_orderbook(self, orderbook) -> bool:
|
||||
"""
|
||||
Cache order book data.
|
||||
|
||||
Args:
|
||||
orderbook: OrderBookSnapshot to cache
|
||||
|
||||
Returns:
|
||||
bool: True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
key = self.cache_keys.orderbook_key(orderbook.symbol, orderbook.exchange)
|
||||
return await self.set(key, orderbook)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error caching order book: {e}")
|
||||
return False
|
||||
|
||||
async def get_orderbook(self, symbol: str, exchange: str):
|
||||
"""
|
||||
Get cached order book data.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
exchange: Exchange name
|
||||
|
||||
Returns:
|
||||
OrderBookSnapshot: Cached order book or None if not found
|
||||
"""
|
||||
try:
|
||||
key = self.cache_keys.orderbook_key(symbol, exchange)
|
||||
return await self.get(key)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving order book: {e}")
|
||||
return None
|
||||
|
||||
async def cache_metrics(self, metrics, symbol: str, exchange: str) -> bool:
|
||||
"""
|
||||
Cache metrics data.
|
||||
|
||||
Args:
|
||||
metrics: Metrics data to cache
|
||||
symbol: Trading symbol
|
||||
exchange: Exchange name
|
||||
|
||||
Returns:
|
||||
bool: True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
key = self.cache_keys.metrics_key(symbol, exchange)
|
||||
return await self.set(key, metrics)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error caching metrics: {e}")
|
||||
return False
|
||||
|
||||
async def get_metrics(self, symbol: str, exchange: str):
|
||||
"""
|
||||
Get cached metrics data.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
exchange: Exchange name
|
||||
|
||||
Returns:
|
||||
Metrics data or None if not found
|
||||
"""
|
||||
try:
|
||||
key = self.cache_keys.metrics_key(symbol, exchange)
|
||||
return await self.get(key)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving metrics: {e}")
|
||||
return None
|
||||
|
||||
async def cache_exchange_status(self, exchange: str, status_data) -> bool:
|
||||
"""
|
||||
Cache exchange status.
|
||||
|
||||
Args:
|
||||
exchange: Exchange name
|
||||
status_data: Status data to cache
|
||||
|
||||
Returns:
|
||||
bool: True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
key = self.cache_keys.status_key(exchange)
|
||||
return await self.set(key, status_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error caching exchange status: {e}")
|
||||
return False
|
||||
|
||||
async def get_exchange_status(self, exchange: str):
|
||||
"""
|
||||
Get cached exchange status.
|
||||
|
||||
Args:
|
||||
exchange: Exchange name
|
||||
|
||||
Returns:
|
||||
Status data or None if not found
|
||||
"""
|
||||
try:
|
||||
key = self.cache_keys.status_key(exchange)
|
||||
return await self.get(key)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving exchange status: {e}")
|
||||
return None
|
||||
|
||||
async def cleanup_expired_keys(self) -> int:
|
||||
"""
|
||||
Clean up expired keys (Redis handles this automatically, but we can force it).
|
||||
|
||||
Returns:
|
||||
int: Number of keys cleaned up
|
||||
"""
|
||||
try:
|
||||
# Get all keys
|
||||
all_keys = await self.keys("*")
|
||||
|
||||
# Check which ones are expired
|
||||
expired_count = 0
|
||||
for key in all_keys:
|
||||
ttl = await self.redis_client.ttl(key)
|
||||
if ttl == -2: # Key doesn't exist (expired)
|
||||
expired_count += 1
|
||||
|
||||
logger.debug(f"Found {expired_count} expired keys")
|
||||
return expired_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up expired keys: {e}")
|
||||
return 0
|
||||
|
||||
def _update_avg_response_time(self, response_time: float) -> None:
|
||||
"""Update average response time"""
|
||||
total_operations = self.stats['gets'] + self.stats['sets']
|
||||
if total_operations > 0:
|
||||
self.stats['avg_response_time'] = (
|
||||
(self.stats['avg_response_time'] * (total_operations - 1) + response_time) /
|
||||
total_operations
|
||||
)
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get cache statistics"""
|
||||
total_operations = self.stats['gets'] + self.stats['sets']
|
||||
hit_rate = (self.stats['hits'] / max(self.stats['gets'], 1)) * 100
|
||||
|
||||
return {
|
||||
**self.stats,
|
||||
'total_operations': total_operations,
|
||||
'hit_rate_percentage': hit_rate,
|
||||
'serializer_stats': self.serializer.get_stats()
|
||||
}
|
||||
|
||||
def reset_stats(self) -> None:
|
||||
"""Reset cache statistics"""
|
||||
self.stats = {
|
||||
'gets': 0,
|
||||
'sets': 0,
|
||||
'deletes': 0,
|
||||
'hits': 0,
|
||||
'misses': 0,
|
||||
'errors': 0,
|
||||
'total_data_size': 0,
|
||||
'avg_response_time': 0.0
|
||||
}
|
||||
self.serializer.reset_stats()
|
||||
logger.info("Redis manager statistics reset")
|
||||
|
||||
async def health_check(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Perform comprehensive health check.
|
||||
|
||||
Returns:
|
||||
Dict: Health check results
|
||||
"""
|
||||
health = {
|
||||
'redis_ping': False,
|
||||
'connection_pool_size': 0,
|
||||
'memory_usage': 0,
|
||||
'connected_clients': 0,
|
||||
'total_keys': 0,
|
||||
'hit_rate': 0.0,
|
||||
'avg_response_time': self.stats['avg_response_time']
|
||||
}
|
||||
|
||||
try:
|
||||
# Test ping
|
||||
health['redis_ping'] = await self.ping()
|
||||
|
||||
# Get Redis info
|
||||
info = await self.info()
|
||||
if info:
|
||||
health['memory_usage'] = info.get('used_memory', 0)
|
||||
health['connected_clients'] = info.get('connected_clients', 0)
|
||||
|
||||
# Get key count
|
||||
all_keys = await self.keys("*")
|
||||
health['total_keys'] = len(all_keys)
|
||||
|
||||
# Calculate hit rate
|
||||
if self.stats['gets'] > 0:
|
||||
health['hit_rate'] = (self.stats['hits'] / self.stats['gets']) * 100
|
||||
|
||||
# Connection pool info
|
||||
if self.redis_pool:
|
||||
health['connection_pool_size'] = self.redis_pool.max_connections
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Health check error: {e}")
|
||||
|
||||
return health
|
||||
    async def delete(self, key: str):
        """Delete key from cache"""
        self.cache.pop(key, None)


# Global Redis manager instance
# Global instance
redis_manager = RedisManager()
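A minimal sketch of how the stub manager added here behaves: everything is kept in the in-memory dict and nothing touches a real Redis server (the import path assumes the repository root is on sys.path):

# illustrative only - exercises the stub RedisManager defined above
import asyncio
from COBY.caching.redis_manager import redis_manager

async def demo():
    await redis_manager.connect()                 # just flips self.connected
    await redis_manager.set("orderbook:BTCUSDT", {"bid": 50000.0, "ask": 50010.0})
    print(await redis_manager.get("orderbook:BTCUSDT"))  # read back from the dict cache
    print(redis_manager.is_connected())           # True
    await redis_manager.delete("orderbook:BTCUSDT")
    await redis_manager.disconnect()

asyncio.run(demo())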
COBY/main.py (27 changed lines)
@@ -14,13 +14,26 @@ from typing import Optional
# Add the current directory to Python path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from utils.logging import get_logger, setup_logging
from config import Config
from monitoring.metrics_collector import metrics_collector
from monitoring.performance_monitor import get_performance_monitor
from monitoring.memory_monitor import memory_monitor
from api.rest_api import create_app
from api.websocket_server import WebSocketServer
try:
    from .utils.logging import get_logger, setup_logging
    from .simple_config import Config
except ImportError:
    from utils.logging import get_logger, setup_logging
    from simple_config import Config
try:
    # Try relative imports first (when run as module)
    from .monitoring.metrics_collector import metrics_collector
    from .monitoring.performance_monitor import get_performance_monitor
    from .monitoring.memory_monitor import memory_monitor
    from .api.rest_api import create_app
    from .api.simple_websocket_server import WebSocketServer
except ImportError:
    # Fall back to absolute imports (when run directly)
    from monitoring.metrics_collector import metrics_collector
    from monitoring.performance_monitor import get_performance_monitor
    from monitoring.memory_monitor import memory_monitor
    from api.rest_api import create_app
    from api.simple_websocket_server import WebSocketServer

logger = get_logger(__name__)

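The two import branches correspond to the two ways the entry point can be launched; both are expected to resolve after this change. The commands below are an illustration of that assumption, not taken from the commit:

# as a package module - the relative-import branch is taken
#   python -m COBY.main
# as a plain script - ImportError falls through to the absolute-import branch
#   python COBY/main.py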
@@ -2,16 +2,5 @@
Performance monitoring and optimization module.
"""

from .metrics_collector import MetricsCollector
from .performance_monitor import PerformanceMonitor
from .memory_monitor import MemoryMonitor
from .latency_tracker import LatencyTracker
from .alert_manager import AlertManager

__all__ = [
    'MetricsCollector',
    'PerformanceMonitor',
    'MemoryMonitor',
    'LatencyTracker',
    'AlertManager'
]
# Simplified imports to avoid circular dependencies
__all__ = []
@@ -12,8 +12,12 @@ from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from enum import Enum

from ..utils.logging import get_logger
from ..utils.timing import get_current_timestamp
try:
    from ..utils.logging import get_logger
    from ..utils.timing import get_current_timestamp
except ImportError:
    from utils.logging import get_logger
    from utils.timing import get_current_timestamp

logger = get_logger(__name__)

@@ -10,8 +10,12 @@ from datetime import datetime, timezone
from dataclasses import dataclass
from contextlib import contextmanager

from ..utils.logging import get_logger, set_correlation_id
from ..utils.timing import get_current_timestamp
try:
    from ..utils.logging import get_logger, set_correlation_id
    from ..utils.timing import get_current_timestamp
except ImportError:
    from utils.logging import get_logger, set_correlation_id
    from utils.timing import get_current_timestamp
# Import will be done lazily to avoid circular imports

logger = get_logger(__name__)

@@ -11,8 +11,12 @@ from collections import defaultdict, deque
from datetime import datetime, timezone
from dataclasses import dataclass

from ..utils.logging import get_logger
from ..utils.timing import get_current_timestamp
try:
    from ..utils.logging import get_logger
    from ..utils.timing import get_current_timestamp
except ImportError:
    from utils.logging import get_logger
    from utils.timing import get_current_timestamp
# Import will be done lazily to avoid circular imports

logger = get_logger(__name__)

@@ -10,8 +10,12 @@ from collections import defaultdict, deque
from datetime import datetime, timezone
from dataclasses import dataclass, field

from ..utils.logging import get_logger
from ..utils.timing import get_current_timestamp
try:
    from ..utils.logging import get_logger
    from ..utils.timing import get_current_timestamp
except ImportError:
    from utils.logging import get_logger
    from utils.timing import get_current_timestamp

logger = get_logger(__name__)

@@ -10,9 +10,14 @@ from collections import defaultdict, deque
from datetime import datetime, timezone, timedelta
from dataclasses import dataclass, field

from ..utils.logging import get_logger
from ..utils.timing import get_current_timestamp
from .metrics_collector import MetricsCollector
try:
    from ..utils.logging import get_logger
    from ..utils.timing import get_current_timestamp
    from .metrics_collector import MetricsCollector
except ImportError:
    from utils.logging import get_logger
    from utils.timing import get_current_timestamp
    from monitoring.metrics_collector import MetricsCollector

logger = get_logger(__name__)

COBY/simple_config.py (new file, 45 lines)
@@ -0,0 +1,45 @@
"""
Simple configuration for COBY system.
"""

import os
from dataclasses import dataclass


@dataclass
class APIConfig:
    """API configuration"""
    host: str = "0.0.0.0"
    port: int = 8080
    websocket_port: int = 8081
    cors_origins: list = None
    rate_limit: int = 100

    def __post_init__(self):
        if self.cors_origins is None:
            self.cors_origins = ["*"]


@dataclass
class LoggingConfig:
    """Logging configuration"""
    level: str = "INFO"
    file_path: str = "logs/coby.log"


@dataclass
class Config:
    """Main configuration"""
    api: APIConfig = None
    logging: LoggingConfig = None
    debug: bool = False

    def __post_init__(self):
        if self.api is None:
            self.api = APIConfig()
        if self.logging is None:
            self.logging = LoggingConfig()


# Global config instance
config = Config()
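A short sketch of reading the new configuration object; the printed values are the defaults defined above, and the import path assumes the repository root is on sys.path:

# illustrative only
from COBY.simple_config import config

print(config.api.host, config.api.port)                 # 0.0.0.0 8080
print(config.api.websocket_port)                        # 8081
print(config.api.cors_origins)                          # ["*"]
print(config.logging.level, config.logging.file_path)   # INFO logs/coby.log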
COBY/tests/test_e2e_dashboard.py (new file, 485 lines)
@@ -0,0 +1,485 @@
"""
End-to-end tests for web dashboard functionality.
"""

import pytest
import asyncio
import json
from datetime import datetime, timezone
from unittest.mock import Mock, AsyncMock, patch
from typing import Dict, Any, List

import aiohttp
from aiohttp import web, WSMsgType
from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop

from ..api.rest_api import create_app
from ..api.websocket_server import WebSocketServer
from ..models.core import OrderBookSnapshot, TradeEvent, PriceLevel
from ..utils.logging import get_logger

logger = get_logger(__name__)

class TestDashboardAPI(AioHTTPTestCase):
|
||||
"""Test dashboard REST API endpoints"""
|
||||
|
||||
async def get_application(self):
|
||||
"""Create test application"""
|
||||
return create_app()
|
||||
|
||||
@unittest_run_loop
|
||||
async def test_health_endpoint(self):
|
||||
"""Test health check endpoint"""
|
||||
resp = await self.client.request("GET", "/health")
|
||||
self.assertEqual(resp.status, 200)
|
||||
|
||||
data = await resp.json()
|
||||
self.assertIn('status', data)
|
||||
self.assertIn('timestamp', data)
|
||||
self.assertEqual(data['status'], 'healthy')
|
||||
|
||||
@unittest_run_loop
|
||||
async def test_metrics_endpoint(self):
|
||||
"""Test metrics endpoint"""
|
||||
resp = await self.client.request("GET", "/metrics")
|
||||
self.assertEqual(resp.status, 200)
|
||||
|
||||
# Should return Prometheus format
|
||||
text = await resp.text()
|
||||
self.assertIn('# TYPE', text)
|
||||
|
||||
@unittest_run_loop
|
||||
async def test_orderbook_endpoint(self):
|
||||
"""Test order book data endpoint"""
|
||||
# Mock data
|
||||
with patch('COBY.caching.redis_manager.redis_manager') as mock_redis:
|
||||
mock_redis.get.return_value = {
|
||||
'symbol': 'BTCUSDT',
|
||||
'exchange': 'binance',
|
||||
'bids': [{'price': 50000.0, 'size': 1.0}],
|
||||
'asks': [{'price': 50010.0, 'size': 1.0}]
|
||||
}
|
||||
|
||||
resp = await self.client.request("GET", "/api/orderbook/BTCUSDT")
|
||||
self.assertEqual(resp.status, 200)
|
||||
|
||||
data = await resp.json()
|
||||
self.assertIn('symbol', data)
|
||||
self.assertEqual(data['symbol'], 'BTCUSDT')
|
||||
|
||||
@unittest_run_loop
|
||||
async def test_heatmap_endpoint(self):
|
||||
"""Test heatmap data endpoint"""
|
||||
with patch('COBY.caching.redis_manager.redis_manager') as mock_redis:
|
||||
mock_redis.get.return_value = {
|
||||
'symbol': 'BTCUSDT',
|
||||
'bucket_size': 1.0,
|
||||
'data': [
|
||||
{'price': 50000.0, 'volume': 10.0, 'intensity': 0.8, 'side': 'bid'}
|
||||
]
|
||||
}
|
||||
|
||||
resp = await self.client.request("GET", "/api/heatmap/BTCUSDT")
|
||||
self.assertEqual(resp.status, 200)
|
||||
|
||||
data = await resp.json()
|
||||
self.assertIn('symbol', data)
|
||||
self.assertIn('data', data)
|
||||
|
||||
@unittest_run_loop
|
||||
async def test_exchanges_status_endpoint(self):
|
||||
"""Test exchanges status endpoint"""
|
||||
with patch('COBY.connectors.connection_manager.connection_manager') as mock_manager:
|
||||
mock_manager.get_all_statuses.return_value = {
|
||||
'binance': 'connected',
|
||||
'coinbase': 'connected',
|
||||
'kraken': 'disconnected'
|
||||
}
|
||||
|
||||
resp = await self.client.request("GET", "/api/exchanges/status")
|
||||
self.assertEqual(resp.status, 200)
|
||||
|
||||
data = await resp.json()
|
||||
self.assertIn('binance', data)
|
||||
self.assertIn('coinbase', data)
|
||||
self.assertIn('kraken', data)
|
||||
|
||||
@unittest_run_loop
|
||||
async def test_performance_metrics_endpoint(self):
|
||||
"""Test performance metrics endpoint"""
|
||||
with patch('COBY.monitoring.performance_monitor.get_performance_monitor') as mock_monitor:
|
||||
mock_monitor.return_value.get_performance_dashboard_data.return_value = {
|
||||
'timestamp': datetime.now(timezone.utc).isoformat(),
|
||||
'system_metrics': {
|
||||
'cpu_usage': 45.2,
|
||||
'memory_usage': 67.8,
|
||||
'active_connections': 150
|
||||
},
|
||||
'performance_summary': {
|
||||
'throughput': 1250.5,
|
||||
'error_rate': 0.1,
|
||||
'avg_latency': 12.3
|
||||
}
|
||||
}
|
||||
|
||||
resp = await self.client.request("GET", "/api/performance")
|
||||
self.assertEqual(resp.status, 200)
|
||||
|
||||
data = await resp.json()
|
||||
self.assertIn('system_metrics', data)
|
||||
self.assertIn('performance_summary', data)
|
||||
|
||||
@unittest_run_loop
|
||||
async def test_static_files_served(self):
|
||||
"""Test that static files are served correctly"""
|
||||
# Test dashboard index
|
||||
resp = await self.client.request("GET", "/")
|
||||
self.assertEqual(resp.status, 200)
|
||||
|
||||
content_type = resp.headers.get('content-type', '')
|
||||
self.assertIn('text/html', content_type)
|
||||
|
||||
@unittest_run_loop
|
||||
async def test_cors_headers(self):
|
||||
"""Test CORS headers are present"""
|
||||
resp = await self.client.request("OPTIONS", "/api/health")
|
||||
self.assertEqual(resp.status, 200)
|
||||
|
||||
# Check CORS headers
|
||||
self.assertIn('Access-Control-Allow-Origin', resp.headers)
|
||||
self.assertIn('Access-Control-Allow-Methods', resp.headers)
|
||||
|
||||
@unittest_run_loop
|
||||
async def test_rate_limiting(self):
|
||||
"""Test API rate limiting"""
|
||||
# Make many requests quickly
|
||||
responses = []
|
||||
for i in range(150): # Exceed rate limit
|
||||
resp = await self.client.request("GET", "/api/health")
|
||||
responses.append(resp.status)
|
||||
|
||||
# Should have some rate limited responses
|
||||
rate_limited = [status for status in responses if status == 429]
|
||||
self.assertGreater(len(rate_limited), 0, "Rate limiting not working")
|
||||
|
||||
@unittest_run_loop
|
||||
async def test_error_handling(self):
|
||||
"""Test API error handling"""
|
||||
# Test invalid symbol
|
||||
resp = await self.client.request("GET", "/api/orderbook/INVALID")
|
||||
self.assertEqual(resp.status, 404)
|
||||
|
||||
data = await resp.json()
|
||||
self.assertIn('error', data)
|
||||
|
||||
@unittest_run_loop
|
||||
async def test_api_documentation(self):
|
||||
"""Test API documentation endpoints"""
|
||||
# Test OpenAPI docs
|
||||
resp = await self.client.request("GET", "/docs")
|
||||
self.assertEqual(resp.status, 200)
|
||||
|
||||
# Test ReDoc
|
||||
resp = await self.client.request("GET", "/redoc")
|
||||
self.assertEqual(resp.status, 200)
|
||||
|
||||
|
||||
class TestWebSocketFunctionality:
|
||||
"""Test WebSocket functionality"""
|
||||
|
||||
@pytest.fixture
|
||||
async def websocket_server(self):
|
||||
"""Create WebSocket server for testing"""
|
||||
server = WebSocketServer(host='localhost', port=8081)
|
||||
await server.start()
|
||||
yield server
|
||||
await server.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_connection(self, websocket_server):
|
||||
"""Test WebSocket connection establishment"""
|
||||
session = aiohttp.ClientSession()
|
||||
|
||||
try:
|
||||
async with session.ws_connect('ws://localhost:8081/ws/dashboard') as ws:
|
||||
# Connection should be established
|
||||
assert not ws.closed
|
||||
|
||||
# Send ping
|
||||
await ws.ping()
|
||||
|
||||
# Should receive pong
|
||||
msg = await ws.receive()
|
||||
assert msg.type == WSMsgType.PONG
|
||||
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_data_streaming(self, websocket_server):
|
||||
"""Test real-time data streaming via WebSocket"""
|
||||
session = aiohttp.ClientSession()
|
||||
|
||||
try:
|
||||
async with session.ws_connect('ws://localhost:8081/ws/dashboard') as ws:
|
||||
# Subscribe to updates
|
||||
subscribe_msg = {
|
||||
'type': 'subscribe',
|
||||
'channels': ['orderbook', 'trades', 'performance']
|
||||
}
|
||||
await ws.send_str(json.dumps(subscribe_msg))
|
||||
|
||||
# Should receive subscription confirmation
|
||||
msg = await ws.receive()
|
||||
assert msg.type == WSMsgType.TEXT
|
||||
|
||||
data = json.loads(msg.data)
|
||||
assert data.get('type') == 'subscription_confirmed'
|
||||
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_error_handling(self, websocket_server):
|
||||
"""Test WebSocket error handling"""
|
||||
session = aiohttp.ClientSession()
|
||||
|
||||
try:
|
||||
async with session.ws_connect('ws://localhost:8081/ws/dashboard') as ws:
|
||||
# Send invalid message
|
||||
invalid_msg = {'invalid': 'message'}
|
||||
await ws.send_str(json.dumps(invalid_msg))
|
||||
|
||||
# Should receive error response
|
||||
msg = await ws.receive()
|
||||
assert msg.type == WSMsgType.TEXT
|
||||
|
||||
data = json.loads(msg.data)
|
||||
assert data.get('type') == 'error'
|
||||
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_websocket_connections(self, websocket_server):
|
||||
"""Test multiple concurrent WebSocket connections"""
|
||||
session = aiohttp.ClientSession()
|
||||
connections = []
|
||||
|
||||
try:
|
||||
# Create multiple connections
|
||||
for i in range(10):
|
||||
ws = await session.ws_connect('ws://localhost:8081/ws/dashboard')
|
||||
connections.append(ws)
|
||||
|
||||
# All connections should be active
|
||||
for ws in connections:
|
||||
assert not ws.closed
|
||||
|
||||
# Send message to all connections
|
||||
test_msg = {'type': 'ping', 'id': 'test'}
|
||||
for ws in connections:
|
||||
await ws.send_str(json.dumps(test_msg))
|
||||
|
||||
# All should receive responses
|
||||
for ws in connections:
|
||||
msg = await ws.receive()
|
||||
assert msg.type == WSMsgType.TEXT
|
||||
|
||||
finally:
|
||||
# Close all connections
|
||||
for ws in connections:
|
||||
if not ws.closed:
|
||||
await ws.close()
|
||||
await session.close()
|
||||
|
||||
|
||||
class TestDashboardIntegration:
|
||||
"""Test dashboard integration with backend services"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_services(self):
|
||||
"""Mock backend services"""
|
||||
services = {
|
||||
'redis': Mock(),
|
||||
'timescale': Mock(),
|
||||
'connectors': Mock(),
|
||||
'aggregator': Mock(),
|
||||
'monitor': Mock()
|
||||
}
|
||||
|
||||
# Setup mock responses
|
||||
services['redis'].get.return_value = {'test': 'data'}
|
||||
services['timescale'].query.return_value = [{'result': 'data'}]
|
||||
services['connectors'].get_status.return_value = 'connected'
|
||||
services['aggregator'].get_heatmap.return_value = {'heatmap': 'data'}
|
||||
services['monitor'].get_metrics.return_value = {'metrics': 'data'}
|
||||
|
||||
return services
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dashboard_data_flow(self, mock_services):
|
||||
"""Test complete data flow from backend to dashboard"""
|
||||
# Simulate data generation
|
||||
orderbook = OrderBookSnapshot(
|
||||
symbol="BTCUSDT",
|
||||
exchange="binance",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
bids=[PriceLevel(price=50000.0, size=1.0)],
|
||||
asks=[PriceLevel(price=50010.0, size=1.0)]
|
||||
)
|
||||
|
||||
# Mock data processing pipeline
|
||||
with patch.multiple(
|
||||
'COBY.processing.data_processor',
|
||||
DataProcessor=Mock()
|
||||
):
|
||||
# Process data
|
||||
processor = Mock()
|
||||
processor.normalize_orderbook.return_value = orderbook
|
||||
|
||||
# Aggregate data
|
||||
aggregator = Mock()
|
||||
aggregator.create_price_buckets.return_value = Mock()
|
||||
aggregator.generate_heatmap.return_value = Mock()
|
||||
|
||||
# Cache data
|
||||
cache = Mock()
|
||||
cache.set.return_value = True
|
||||
|
||||
# Verify data flows through pipeline
|
||||
processed = processor.normalize_orderbook({}, "binance")
|
||||
buckets = aggregator.create_price_buckets(processed)
|
||||
heatmap = aggregator.generate_heatmap(buckets)
|
||||
cached = cache.set("test_key", heatmap)
|
||||
|
||||
assert processed is not None
|
||||
assert buckets is not None
|
||||
assert heatmap is not None
|
||||
assert cached is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_time_updates(self, mock_services):
|
||||
"""Test real-time dashboard updates"""
|
||||
# Mock WebSocket server
|
||||
ws_server = Mock()
|
||||
ws_server.broadcast = AsyncMock()
|
||||
|
||||
# Simulate real-time data updates
|
||||
updates = [
|
||||
{'type': 'orderbook', 'symbol': 'BTCUSDT', 'data': {}},
|
||||
{'type': 'trade', 'symbol': 'BTCUSDT', 'data': {}},
|
||||
{'type': 'performance', 'data': {}}
|
||||
]
|
||||
|
||||
# Send updates
|
||||
for update in updates:
|
||||
await ws_server.broadcast(json.dumps(update))
|
||||
|
||||
# Verify broadcasts were sent
|
||||
assert ws_server.broadcast.call_count == len(updates)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dashboard_performance_under_load(self, mock_services):
|
||||
"""Test dashboard performance under high update frequency"""
|
||||
import time
|
||||
|
||||
# Mock high-frequency updates
|
||||
update_count = 1000
|
||||
start_time = time.time()
|
||||
|
||||
# Simulate processing many updates
|
||||
for i in range(update_count):
|
||||
# Mock data processing
|
||||
mock_services['redis'].get(f"orderbook:BTCUSDT:binance:{i}")
|
||||
mock_services['aggregator'].get_heatmap(f"BTCUSDT:{i}")
|
||||
|
||||
# Small delay to simulate processing
|
||||
await asyncio.sleep(0.001)
|
||||
|
||||
end_time = time.time()
|
||||
processing_time = end_time - start_time
|
||||
updates_per_second = update_count / processing_time
|
||||
|
||||
# Should handle at least 500 updates per second
|
||||
assert updates_per_second > 500, f"Dashboard too slow: {updates_per_second:.2f} updates/sec"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dashboard_error_recovery(self, mock_services):
|
||||
"""Test dashboard error recovery"""
|
||||
# Simulate service failures
|
||||
mock_services['redis'].get.side_effect = Exception("Redis connection failed")
|
||||
mock_services['timescale'].query.side_effect = Exception("Database error")
|
||||
|
||||
# Dashboard should handle errors gracefully
|
||||
try:
|
||||
# Attempt operations that will fail
|
||||
mock_services['redis'].get("test_key")
|
||||
except Exception:
|
||||
# Should recover and continue
|
||||
pass
|
||||
|
||||
try:
|
||||
mock_services['timescale'].query("SELECT * FROM test")
|
||||
except Exception:
|
||||
# Should recover and continue
|
||||
pass
|
||||
|
||||
# Reset services to working state
|
||||
mock_services['redis'].get.side_effect = None
|
||||
mock_services['redis'].get.return_value = {'recovered': True}
|
||||
|
||||
# Should work again
|
||||
result = mock_services['redis'].get("test_key")
|
||||
assert result['recovered'] is True
|
||||
|
||||
|
||||
class TestDashboardUI:
|
||||
"""Test dashboard UI functionality (requires browser automation)"""
|
||||
|
||||
@pytest.mark.skipif(not pytest.config.getoption("--ui"),
|
||||
reason="UI tests require --ui flag and browser setup")
|
||||
def test_dashboard_loads(self):
|
||||
"""Test that dashboard loads in browser"""
|
||||
# This would require Selenium or similar
|
||||
# Placeholder for UI tests
|
||||
pass
|
||||
|
||||
@pytest.mark.skipif(not pytest.config.getoption("--ui"),
|
||||
reason="UI tests require --ui flag and browser setup")
|
||||
def test_real_time_chart_updates(self):
|
||||
"""Test that charts update in real-time"""
|
||||
# This would require browser automation
|
||||
# Placeholder for UI tests
|
||||
pass
|
||||
|
||||
@pytest.mark.skipif(not pytest.config.getoption("--ui"),
|
||||
reason="UI tests require --ui flag and browser setup")
|
||||
def test_responsive_design(self):
|
||||
"""Test responsive design on different screen sizes"""
|
||||
# This would require browser automation with different viewport sizes
|
||||
# Placeholder for UI tests
|
||||
pass
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
"""Configure pytest with custom markers"""
|
||||
config.addinivalue_line("markers", "e2e: mark test as end-to-end test")
|
||||
config.addinivalue_line("markers", "ui: mark test as UI test")
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
"""Add custom command line options"""
|
||||
parser.addoption(
|
||||
"--e2e",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="run end-to-end tests"
|
||||
)
|
||||
parser.addoption(
|
||||
"--ui",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="run UI tests (requires browser setup)"
|
||||
)
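# Note: pytest only picks up pytest_addoption()/pytest_configure() from a conftest.py
# (or an installed plugin), not from a regular test module, so these hooks would likely
# need to live in COBY/tests/conftest.py to take effect. Illustrative invocation,
# assuming that layout:
#   pytest COBY/tests            # UI tests are skipped unless --ui is given
#   pytest COBY/tests --ui       # also run the browser-based UI tests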
|
485
COBY/tests/test_integration_pipeline.py
Normal file
@ -0,0 +1,485 @@
|
||||
"""
|
||||
Integration tests for complete data pipeline from exchanges to storage.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from ..connectors.binance_connector import BinanceConnector
|
||||
from ..processing.data_processor import DataProcessor
|
||||
from ..aggregation.aggregation_engine import AggregationEngine
|
||||
from ..storage.timescale_manager import TimescaleManager
|
||||
from ..caching.redis_manager import RedisManager
|
||||
from ..models.core import OrderBookSnapshot, TradeEvent, PriceLevel
|
||||
from ..utils.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class TestDataPipelineIntegration:
|
||||
"""Test complete data pipeline integration"""
|
||||
|
||||
@pytest.fixture
|
||||
async def mock_components(self):
|
||||
"""Setup mock components for testing"""
|
||||
# Mock exchange connector
|
||||
connector = Mock(spec=BinanceConnector)
|
||||
connector.exchange_name = "binance"
|
||||
connector.connect = AsyncMock(return_value=True)
|
||||
connector.disconnect = AsyncMock()
|
||||
connector.subscribe_orderbook = AsyncMock()
|
||||
connector.subscribe_trades = AsyncMock()
|
||||
|
||||
# Mock data processor
|
||||
processor = Mock(spec=DataProcessor)
|
||||
processor.process_orderbook = Mock()
|
||||
processor.process_trade = Mock()
|
||||
processor.validate_data = Mock(return_value=True)
|
||||
|
||||
# Mock aggregation engine
|
||||
aggregator = Mock(spec=AggregationEngine)
|
||||
aggregator.aggregate_orderbook = Mock()
|
||||
aggregator.create_heatmap = Mock()
|
||||
|
||||
# Mock storage manager
|
||||
storage = Mock(spec=TimescaleManager)
|
||||
storage.store_orderbook = AsyncMock(return_value=True)
|
||||
storage.store_trade = AsyncMock(return_value=True)
|
||||
storage.is_connected = Mock(return_value=True)
|
||||
|
||||
# Mock cache manager
|
||||
cache = Mock(spec=RedisManager)
|
||||
cache.set = AsyncMock(return_value=True)
|
||||
cache.get = AsyncMock(return_value=None)
|
||||
cache.is_connected = Mock(return_value=True)
|
||||
|
||||
return {
|
||||
'connector': connector,
|
||||
'processor': processor,
|
||||
'aggregator': aggregator,
|
||||
'storage': storage,
|
||||
'cache': cache
|
||||
}
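# Using Mock(spec=...) keeps these fakes honest: accessing an attribute that the real
# class does not define raises AttributeError instead of silently returning a new Mock.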
|
||||
|
||||
@pytest.fixture
|
||||
def sample_orderbook(self):
|
||||
"""Create sample order book data"""
|
||||
return OrderBookSnapshot(
|
||||
symbol="BTCUSDT",
|
||||
exchange="binance",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
bids=[
|
||||
PriceLevel(price=50000.0, size=1.5),
|
||||
PriceLevel(price=49990.0, size=2.0),
|
||||
PriceLevel(price=49980.0, size=1.0)
|
||||
],
|
||||
asks=[
|
||||
PriceLevel(price=50010.0, size=1.2),
|
||||
PriceLevel(price=50020.0, size=1.8),
|
||||
PriceLevel(price=50030.0, size=0.8)
|
||||
]
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_trade(self):
|
||||
"""Create sample trade data"""
|
||||
return TradeEvent(
|
||||
symbol="BTCUSDT",
|
||||
exchange="binance",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
price=50005.0,
|
||||
size=0.5,
|
||||
side="buy",
|
||||
trade_id="12345"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_orderbook_pipeline(self, mock_components, sample_orderbook):
|
||||
"""Test complete order book processing pipeline"""
|
||||
components = mock_components
|
||||
|
||||
# Setup processor to return processed data
|
||||
components['processor'].process_orderbook.return_value = sample_orderbook
|
||||
|
||||
# Simulate pipeline flow
|
||||
# 1. Receive data from exchange
|
||||
raw_data = {"symbol": "BTCUSDT", "bids": [], "asks": []}
|
||||
|
||||
# 2. Process data
|
||||
processed_data = components['processor'].process_orderbook(raw_data, "binance")
|
||||
|
||||
# 3. Validate data
|
||||
is_valid = components['processor'].validate_data(processed_data)
|
||||
assert is_valid
|
||||
|
||||
# 4. Aggregate data
|
||||
components['aggregator'].aggregate_orderbook(processed_data)
|
||||
|
||||
# 5. Store in database
|
||||
await components['storage'].store_orderbook(processed_data)
|
||||
|
||||
# 6. Cache latest data
|
||||
await components['cache'].set(f"orderbook:BTCUSDT:binance", processed_data)
|
||||
|
||||
# Verify all components were called
|
||||
components['processor'].process_orderbook.assert_called_once()
|
||||
components['processor'].validate_data.assert_called_once()
|
||||
components['aggregator'].aggregate_orderbook.assert_called_once()
|
||||
components['storage'].store_orderbook.assert_called_once()
|
||||
components['cache'].set.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_trade_pipeline(self, mock_components, sample_trade):
|
||||
"""Test complete trade processing pipeline"""
|
||||
components = mock_components
|
||||
|
||||
# Setup processor to return processed data
|
||||
components['processor'].process_trade.return_value = sample_trade
|
||||
|
||||
# Simulate pipeline flow
|
||||
raw_data = {"symbol": "BTCUSDT", "price": 50005.0, "quantity": 0.5}
|
||||
|
||||
# Process through pipeline
|
||||
processed_data = components['processor'].process_trade(raw_data, "binance")
|
||||
is_valid = components['processor'].validate_data(processed_data)
|
||||
assert is_valid
|
||||
|
||||
await components['storage'].store_trade(processed_data)
|
||||
await components['cache'].set(f"trade:BTCUSDT:binance:latest", processed_data)
|
||||
|
||||
# Verify calls
|
||||
components['processor'].process_trade.assert_called_once()
|
||||
components['storage'].store_trade.assert_called_once()
|
||||
components['cache'].set.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_exchange_pipeline(self, mock_components):
|
||||
"""Test pipeline with multiple exchanges"""
|
||||
components = mock_components
|
||||
exchanges = ["binance", "coinbase", "kraken"]
|
||||
|
||||
# Simulate data from multiple exchanges
|
||||
for exchange in exchanges:
|
||||
# Create exchange-specific data
|
||||
orderbook = OrderBookSnapshot(
|
||||
symbol="BTCUSDT",
|
||||
exchange=exchange,
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
bids=[PriceLevel(price=50000.0, size=1.0)],
|
||||
asks=[PriceLevel(price=50010.0, size=1.0)]
|
||||
)
|
||||
|
||||
components['processor'].process_orderbook.return_value = orderbook
|
||||
components['processor'].validate_data.return_value = True
|
||||
|
||||
# Process through pipeline
|
||||
processed_data = components['processor'].process_orderbook({}, exchange)
|
||||
is_valid = components['processor'].validate_data(processed_data)
|
||||
assert is_valid
|
||||
|
||||
await components['storage'].store_orderbook(processed_data)
|
||||
await components['cache'].set(f"orderbook:BTCUSDT:{exchange}", processed_data)
|
||||
|
||||
# Verify multiple calls
|
||||
assert components['processor'].process_orderbook.call_count == len(exchanges)
|
||||
assert components['storage'].store_orderbook.call_count == len(exchanges)
|
||||
assert components['cache'].set.call_count == len(exchanges)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_error_handling(self, mock_components, sample_orderbook):
|
||||
"""Test pipeline error handling and recovery"""
|
||||
components = mock_components
|
||||
|
||||
# Setup storage to fail initially
|
||||
components['storage'].store_orderbook.side_effect = [
|
||||
Exception("Database connection failed"),
|
||||
True # Success on retry
|
||||
]
|
||||
|
||||
components['processor'].process_orderbook.return_value = sample_orderbook
|
||||
components['processor'].validate_data.return_value = True
|
||||
|
||||
# First attempt should fail
|
||||
with pytest.raises(Exception):
|
||||
await components['storage'].store_orderbook(sample_orderbook)
|
||||
|
||||
# Second attempt should succeed
|
||||
result = await components['storage'].store_orderbook(sample_orderbook)
|
||||
assert result is True
|
||||
|
||||
# Verify retry logic
|
||||
assert components['storage'].store_orderbook.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_performance(self, mock_components):
|
||||
"""Test pipeline performance with high throughput"""
|
||||
components = mock_components
|
||||
|
||||
# Setup fast responses
|
||||
components['processor'].process_orderbook.return_value = Mock()
|
||||
components['processor'].validate_data.return_value = True
|
||||
components['storage'].store_orderbook.return_value = True
|
||||
components['cache'].set.return_value = True
|
||||
|
||||
# Process multiple items quickly
|
||||
start_time = time.time()
|
||||
tasks = []
|
||||
|
||||
for i in range(100):
|
||||
# Simulate processing 100 order books
|
||||
task = asyncio.create_task(self._process_single_orderbook(components, i))
|
||||
tasks.append(task)
|
||||
|
||||
await asyncio.gather(*tasks)
|
||||
end_time = time.time()
|
||||
|
||||
processing_time = end_time - start_time
|
||||
throughput = 100 / processing_time
|
||||
|
||||
# Should process at least 50 items per second
|
||||
assert throughput > 50, f"Throughput too low: {throughput:.2f} items/sec"
|
||||
|
||||
# Verify all items were processed
|
||||
assert components['processor'].process_orderbook.call_count == 100
|
||||
assert components['storage'].store_orderbook.call_count == 100
|
||||
|
||||
async def _process_single_orderbook(self, components, index):
|
||||
"""Helper method to process a single order book"""
|
||||
raw_data = {"symbol": "BTCUSDT", "index": index}
|
||||
|
||||
processed_data = components['processor'].process_orderbook(raw_data, "binance")
|
||||
is_valid = components['processor'].validate_data(processed_data)
|
||||
|
||||
if is_valid:
|
||||
await components['storage'].store_orderbook(processed_data)
|
||||
await components['cache'].set(f"orderbook:BTCUSDT:binance:{index}", processed_data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_data_consistency_across_pipeline(self, mock_components, sample_orderbook):
|
||||
"""Test data consistency throughout the pipeline"""
|
||||
components = mock_components
|
||||
|
||||
# Track data transformations
|
||||
original_data = {"symbol": "BTCUSDT", "timestamp": "2024-01-01T00:00:00Z"}
|
||||
|
||||
# Setup processor to modify data
|
||||
modified_orderbook = sample_orderbook
|
||||
modified_orderbook.symbol = "BTCUSDT" # Ensure consistency
|
||||
components['processor'].process_orderbook.return_value = modified_orderbook
|
||||
components['processor'].validate_data.return_value = True
|
||||
|
||||
# Process data
|
||||
processed_data = components['processor'].process_orderbook(original_data, "binance")
|
||||
|
||||
# Verify data consistency
|
||||
assert processed_data.symbol == "BTCUSDT"
|
||||
assert processed_data.exchange == "binance"
|
||||
assert len(processed_data.bids) > 0
|
||||
assert len(processed_data.asks) > 0
|
||||
|
||||
# Verify all price levels are valid
|
||||
for bid in processed_data.bids:
|
||||
assert bid.price > 0
|
||||
assert bid.size > 0
|
||||
|
||||
for ask in processed_data.asks:
|
||||
assert ask.price > 0
|
||||
assert ask.size > 0
|
||||
|
||||
# Verify bid/ask ordering
|
||||
bid_prices = [bid.price for bid in processed_data.bids]
|
||||
ask_prices = [ask.price for ask in processed_data.asks]
|
||||
|
||||
assert bid_prices == sorted(bid_prices, reverse=True) # Bids descending
|
||||
assert ask_prices == sorted(ask_prices) # Asks ascending
|
||||
|
||||
# Verify spread is positive
|
||||
if bid_prices and ask_prices:
|
||||
spread = min(ask_prices) - max(bid_prices)
|
||||
assert spread >= 0, f"Negative spread detected: {spread}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_memory_usage(self, mock_components):
|
||||
"""Test pipeline memory usage under load"""
|
||||
import psutil
|
||||
import gc
|
||||
|
||||
components = mock_components
|
||||
process = psutil.Process()
|
||||
|
||||
# Get initial memory usage
|
||||
initial_memory = process.memory_info().rss / 1024 / 1024 # MB
|
||||
|
||||
# Process large amount of data
|
||||
for i in range(1000):
|
||||
orderbook = OrderBookSnapshot(
|
||||
symbol="BTCUSDT",
|
||||
exchange="binance",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
bids=[PriceLevel(price=50000.0 + i, size=1.0)],
|
||||
asks=[PriceLevel(price=50010.0 + i, size=1.0)]
|
||||
)
|
||||
|
||||
components['processor'].process_orderbook.return_value = orderbook
|
||||
components['processor'].validate_data.return_value = True
|
||||
|
||||
# Process data
|
||||
processed_data = components['processor'].process_orderbook({}, "binance")
|
||||
await components['storage'].store_orderbook(processed_data)
|
||||
|
||||
# Force garbage collection every 100 items
|
||||
if i % 100 == 0:
|
||||
gc.collect()
|
||||
|
||||
# Get final memory usage
|
||||
final_memory = process.memory_info().rss / 1024 / 1024 # MB
|
||||
memory_increase = final_memory - initial_memory
|
||||
|
||||
# Memory increase should be reasonable (less than 100MB for 1000 items)
|
||||
assert memory_increase < 100, f"Memory usage increased by {memory_increase:.2f}MB"
|
||||
|
||||
logger.info(f"Memory usage: {initial_memory:.2f}MB -> {final_memory:.2f}MB (+{memory_increase:.2f}MB)")
|
||||
|
||||
|
||||
class TestPipelineResilience:
|
||||
"""Test pipeline resilience and fault tolerance"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_database_reconnection(self):
|
||||
"""Test database reconnection handling"""
|
||||
storage = Mock(spec=TimescaleManager)
|
||||
|
||||
# Simulate connection failure then recovery
|
||||
storage.is_connected.side_effect = [False, False, True]
|
||||
storage.connect.return_value = True
|
||||
storage.store_orderbook.return_value = True
|
||||
|
||||
# Should attempt reconnection
|
||||
for attempt in range(3):
|
||||
if not storage.is_connected():
|
||||
storage.connect()
|
||||
else:
|
||||
break
|
||||
|
||||
assert storage.connect.call_count == 2
|
||||
assert storage.is_connected.call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_fallback(self):
|
||||
"""Test cache fallback when Redis is unavailable"""
|
||||
cache = Mock(spec=RedisManager)
|
||||
|
||||
# Simulate cache failure
|
||||
cache.is_connected.return_value = False
|
||||
cache.set.side_effect = Exception("Redis connection failed")
|
||||
|
||||
# Should handle cache failure gracefully
|
||||
try:
|
||||
await cache.set("test_key", "test_value")
|
||||
except Exception:
|
||||
# Should continue processing even if cache fails
|
||||
pass
|
||||
|
||||
assert not cache.is_connected()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exchange_failover(self):
|
||||
"""Test exchange failover when one exchange fails"""
|
||||
exchanges = ["binance", "coinbase", "kraken"]
|
||||
failed_exchange = "binance"
|
||||
|
||||
# Simulate one exchange failing
|
||||
for exchange in exchanges:
|
||||
if exchange == failed_exchange:
|
||||
# This exchange fails
|
||||
assert exchange == failed_exchange
|
||||
else:
|
||||
# Other exchanges continue working
|
||||
assert exchange != failed_exchange
|
||||
|
||||
# Should continue with remaining exchanges
|
||||
working_exchanges = [ex for ex in exchanges if ex != failed_exchange]
|
||||
assert len(working_exchanges) == 2
|
||||
assert "coinbase" in working_exchanges
|
||||
assert "kraken" in working_exchanges
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestRealDataPipeline:
|
||||
"""Integration tests with real components (requires running services)"""
|
||||
|
||||
@pytest.mark.skipif(not pytest.config.getoption("--integration"),
|
||||
reason="Integration tests require --integration flag")
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_database_integration(self):
|
||||
"""Test with real TimescaleDB instance"""
|
||||
# This test requires a running TimescaleDB instance
|
||||
# Skip if not available
|
||||
try:
|
||||
from ..storage.timescale_manager import TimescaleManager
|
||||
|
||||
storage = TimescaleManager()
|
||||
await storage.connect()
|
||||
|
||||
# Test basic operations
|
||||
assert storage.is_connected()
|
||||
|
||||
# Create test data
|
||||
orderbook = OrderBookSnapshot(
|
||||
symbol="BTCUSDT",
|
||||
exchange="test",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
bids=[PriceLevel(price=50000.0, size=1.0)],
|
||||
asks=[PriceLevel(price=50010.0, size=1.0)]
|
||||
)
|
||||
|
||||
# Store and verify
|
||||
result = await storage.store_orderbook(orderbook)
|
||||
assert result is True
|
||||
|
||||
await storage.disconnect()
|
||||
|
||||
except Exception as e:
|
||||
pytest.skip(f"Real database not available: {e}")
|
||||
|
||||
@pytest.mark.skipif(not pytest.config.getoption("--integration"),
|
||||
reason="Integration tests require --integration flag")
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_cache_integration(self):
|
||||
"""Test with real Redis instance"""
|
||||
try:
|
||||
from ..caching.redis_manager import RedisManager
|
||||
|
||||
cache = RedisManager()
|
||||
await cache.connect()
|
||||
|
||||
assert cache.is_connected()
|
||||
|
||||
# Test basic operations
|
||||
await cache.set("test_key", {"test": "data"})
|
||||
result = await cache.get("test_key")
|
||||
assert result is not None
|
||||
|
||||
await cache.disconnect()
|
||||
|
||||
except Exception as e:
|
||||
pytest.skip(f"Real cache not available: {e}")
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
"""Configure pytest with custom markers"""
|
||||
config.addinivalue_line("markers", "integration: mark test as integration test")
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
"""Add custom command line options"""
|
||||
parser.addoption(
|
||||
"--integration",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="run integration tests with real services"
|
||||
)
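# Illustrative usage (assuming the hooks above are registered via conftest.py):
#   pytest COBY/tests/test_integration_pipeline.py --integration
# runs the tests guarded by the --integration flag against real TimescaleDB/Redis instances.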
|
590
COBY/tests/test_load_performance.py
Normal file
@ -0,0 +1,590 @@
|
||||
"""
|
||||
Load testing and performance benchmarks for high-frequency data scenarios.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import time
|
||||
import statistics
|
||||
from datetime import datetime, timezone
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import List, Dict, Any
|
||||
import psutil
|
||||
import gc
|
||||
|
||||
from ..models.core import OrderBookSnapshot, TradeEvent, PriceLevel
|
||||
from ..connectors.binance_connector import BinanceConnector
|
||||
from ..processing.data_processor import DataProcessor
|
||||
from ..aggregation.aggregation_engine import AggregationEngine
|
||||
from ..monitoring.metrics_collector import MetricsCollector
|
||||
from ..monitoring.latency_tracker import LatencyTracker
|
||||
from ..utils.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class LoadTestConfig:
|
||||
"""Configuration for load tests"""
|
||||
|
||||
# Test parameters
|
||||
DURATION_SECONDS = 60
|
||||
TARGET_TPS = 1000 # Transactions per second
|
||||
RAMP_UP_SECONDS = 10
|
||||
|
||||
# Performance thresholds
|
||||
MAX_LATENCY_MS = 100
|
||||
MAX_MEMORY_MB = 500
|
||||
MIN_SUCCESS_RATE = 99.0
|
||||
|
||||
# Data generation
|
||||
SYMBOLS = ["BTCUSDT", "ETHUSDT", "ADAUSDT", "DOTUSDT"]
|
||||
EXCHANGES = ["binance", "coinbase", "kraken", "bybit"]
|
||||
|
||||
|
||||
class DataGenerator:
|
||||
"""Generate realistic test data for load testing"""
|
||||
|
||||
def __init__(self):
|
||||
self.base_prices = {
|
||||
"BTCUSDT": 50000.0,
|
||||
"ETHUSDT": 3000.0,
|
||||
"ADAUSDT": 1.0,
|
||||
"DOTUSDT": 25.0
|
||||
}
|
||||
self.counter = 0
|
||||
|
||||
def generate_orderbook(self, symbol: str, exchange: str) -> OrderBookSnapshot:
|
||||
"""Generate realistic order book data"""
|
||||
base_price = self.base_prices.get(symbol, 100.0)
|
||||
|
||||
# Add some randomness
|
||||
price_variation = (self.counter % 100) * 0.01
|
||||
mid_price = base_price + price_variation
|
||||
|
||||
# Generate bids (below mid price)
|
||||
bids = []
|
||||
for i in range(10):
|
||||
price = mid_price - (i + 1) * 0.1
|
||||
size = 1.0 + (i * 0.1)
|
||||
bids.append(PriceLevel(price=price, size=size))
|
||||
|
||||
# Generate asks (above mid price)
|
||||
asks = []
|
||||
for i in range(10):
|
||||
price = mid_price + (i + 1) * 0.1
|
||||
size = 1.0 + (i * 0.1)
|
||||
asks.append(PriceLevel(price=price, size=size))
|
||||
|
||||
self.counter += 1
|
||||
|
||||
return OrderBookSnapshot(
|
||||
symbol=symbol,
|
||||
exchange=exchange,
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
bids=bids,
|
||||
asks=asks
|
||||
)
|
||||
|
||||
def generate_trade(self, symbol: str, exchange: str) -> TradeEvent:
|
||||
"""Generate realistic trade data"""
|
||||
base_price = self.base_prices.get(symbol, 100.0)
|
||||
price_variation = (self.counter % 50) * 0.01
|
||||
price = base_price + price_variation
|
||||
|
||||
self.counter += 1
|
||||
|
||||
return TradeEvent(
|
||||
symbol=symbol,
|
||||
exchange=exchange,
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
price=price,
|
||||
size=0.1 + (self.counter % 10) * 0.01,
|
||||
side="buy" if self.counter % 2 == 0 else "sell",
|
||||
trade_id=str(self.counter)
|
||||
)
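# Illustrative use of DataGenerator (a sketch, not part of the test suite):
#   gen = DataGenerator()
#   ob = gen.generate_orderbook("BTCUSDT", "binance")   # 10 bids below / 10 asks above mid
#   tr = gen.generate_trade("ETHUSDT", "coinbase")      # alternating buy/sell sides
#   assert ob.bids[0].price < ob.asks[0].price          # best bid stays under best ask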
|
||||
|
||||
|
||||
class PerformanceMonitor:
|
||||
"""Monitor performance during load tests"""
|
||||
|
||||
def __init__(self):
|
||||
self.start_time = None
|
||||
self.end_time = None
|
||||
self.latencies = []
|
||||
self.errors = []
|
||||
self.memory_samples = []
|
||||
self.cpu_samples = []
|
||||
self.process = psutil.Process()
|
||||
|
||||
def start(self):
|
||||
"""Start monitoring"""
|
||||
self.start_time = time.time()
|
||||
self.latencies.clear()
|
||||
self.errors.clear()
|
||||
self.memory_samples.clear()
|
||||
self.cpu_samples.clear()
|
||||
|
||||
def stop(self):
|
||||
"""Stop monitoring"""
|
||||
self.end_time = time.time()
|
||||
|
||||
def record_latency(self, latency_ms: float):
|
||||
"""Record operation latency"""
|
||||
self.latencies.append(latency_ms)
|
||||
|
||||
def record_error(self, error: Exception):
|
||||
"""Record error"""
|
||||
self.errors.append(str(error))
|
||||
|
||||
def sample_system_metrics(self):
|
||||
"""Sample system metrics"""
|
||||
try:
|
||||
memory_mb = self.process.memory_info().rss / 1024 / 1024
|
||||
cpu_percent = self.process.cpu_percent()
|
||||
|
||||
self.memory_samples.append(memory_mb)
|
||||
self.cpu_samples.append(cpu_percent)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error sampling system metrics: {e}")
|
||||
|
||||
def get_results(self) -> Dict[str, Any]:
|
||||
"""Get performance test results"""
|
||||
duration = self.end_time - self.start_time if self.end_time else 0
|
||||
total_operations = len(self.latencies)
|
||||
|
||||
results = {
|
||||
'duration_seconds': duration,
|
||||
'total_operations': total_operations,
|
||||
'operations_per_second': total_operations / duration if duration > 0 else 0,
|
||||
'error_count': len(self.errors),
|
||||
'success_rate': (total_operations / (total_operations + len(self.errors)) * 100) if (total_operations + len(self.errors)) > 0 else 0,
|
||||
'latency': {
|
||||
'min_ms': min(self.latencies) if self.latencies else 0,
|
||||
'max_ms': max(self.latencies) if self.latencies else 0,
|
||||
'avg_ms': statistics.mean(self.latencies) if self.latencies else 0,
|
||||
'p50_ms': statistics.median(self.latencies) if self.latencies else 0,
|
||||
'p95_ms': self._percentile(self.latencies, 95) if self.latencies else 0,
|
||||
'p99_ms': self._percentile(self.latencies, 99) if self.latencies else 0
|
||||
},
|
||||
'memory': {
|
||||
'min_mb': min(self.memory_samples) if self.memory_samples else 0,
|
||||
'max_mb': max(self.memory_samples) if self.memory_samples else 0,
|
||||
'avg_mb': statistics.mean(self.memory_samples) if self.memory_samples else 0
|
||||
},
|
||||
'cpu': {
|
||||
'min_percent': min(self.cpu_samples) if self.cpu_samples else 0,
|
||||
'max_percent': max(self.cpu_samples) if self.cpu_samples else 0,
|
||||
'avg_percent': statistics.mean(self.cpu_samples) if self.cpu_samples else 0
|
||||
}
|
||||
}
|
||||
|
||||
return results
|
||||
|
||||
def _percentile(self, data: List[float], percentile: int) -> float:
|
||||
"""Calculate percentile"""
|
||||
if not data:
|
||||
return 0.0
|
||||
sorted_data = sorted(data)
|
||||
index = int((percentile / 100.0) * len(sorted_data))
|
||||
index = min(index, len(sorted_data) - 1)
|
||||
return sorted_data[index]
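# _percentile() uses a simple nearest-rank style index: for 10 samples, p95 maps to
# index int(0.95 * 10) = 9, i.e. the largest sample. For example (illustrative):
#   _percentile([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 95) -> 10
#   _percentile([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 50) -> 6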
|
||||
|
||||
|
||||
@pytest.mark.load
|
||||
class TestLoadPerformance:
|
||||
"""Load testing and performance benchmarks"""
|
||||
|
||||
@pytest.fixture
|
||||
def data_generator(self):
|
||||
"""Create data generator"""
|
||||
return DataGenerator()
|
||||
|
||||
@pytest.fixture
|
||||
def performance_monitor(self):
|
||||
"""Create performance monitor"""
|
||||
return PerformanceMonitor()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_orderbook_processing_load(self, data_generator, performance_monitor):
|
||||
"""Test order book processing under high load"""
|
||||
processor = DataProcessor()
|
||||
monitor = performance_monitor
|
||||
|
||||
monitor.start()
|
||||
|
||||
# Generate load
|
||||
tasks = []
|
||||
for i in range(LoadTestConfig.TARGET_TPS):
|
||||
symbol = LoadTestConfig.SYMBOLS[i % len(LoadTestConfig.SYMBOLS)]
|
||||
exchange = LoadTestConfig.EXCHANGES[i % len(LoadTestConfig.EXCHANGES)]
|
||||
|
||||
task = asyncio.create_task(
|
||||
self._process_orderbook_with_timing(
|
||||
processor, data_generator, symbol, exchange, monitor
|
||||
)
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
# Add small delay to simulate realistic load
|
||||
if i % 100 == 0:
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Wait for all tasks to complete
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
monitor.stop()
|
||||
results = monitor.get_results()
|
||||
|
||||
# Verify performance requirements
|
||||
assert results['operations_per_second'] >= LoadTestConfig.TARGET_TPS * 0.8, \
|
||||
f"Throughput too low: {results['operations_per_second']:.2f} ops/sec"
|
||||
|
||||
assert results['latency']['p95_ms'] <= LoadTestConfig.MAX_LATENCY_MS, \
|
||||
f"P95 latency too high: {results['latency']['p95_ms']:.2f}ms"
|
||||
|
||||
assert results['success_rate'] >= LoadTestConfig.MIN_SUCCESS_RATE, \
|
||||
f"Success rate too low: {results['success_rate']:.2f}%"
|
||||
|
||||
logger.info(f"Load test results: {results}")
|
||||
|
||||
async def _process_orderbook_with_timing(self, processor, data_generator,
|
||||
symbol, exchange, monitor):
|
||||
"""Process order book with timing measurement"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
# Generate and process order book
|
||||
orderbook = data_generator.generate_orderbook(symbol, exchange)
|
||||
processed = processor.normalize_orderbook(orderbook.__dict__, exchange)
|
||||
|
||||
end_time = time.time()
|
||||
latency_ms = (end_time - start_time) * 1000
|
||||
|
||||
monitor.record_latency(latency_ms)
|
||||
monitor.sample_system_metrics()
|
||||
|
||||
except Exception as e:
|
||||
monitor.record_error(e)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trade_processing_load(self, data_generator, performance_monitor):
|
||||
"""Test trade processing under high load"""
|
||||
processor = DataProcessor()
|
||||
monitor = performance_monitor
|
||||
|
||||
monitor.start()
|
||||
|
||||
# Generate sustained load for specified duration
|
||||
end_time = time.time() + LoadTestConfig.DURATION_SECONDS
|
||||
operation_count = 0
|
||||
|
||||
while time.time() < end_time:
|
||||
symbol = LoadTestConfig.SYMBOLS[operation_count % len(LoadTestConfig.SYMBOLS)]
|
||||
exchange = LoadTestConfig.EXCHANGES[operation_count % len(LoadTestConfig.EXCHANGES)]
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
# Generate and process trade
|
||||
trade = data_generator.generate_trade(symbol, exchange)
|
||||
processed = processor.normalize_trade(trade.__dict__, exchange)
|
||||
|
||||
process_time = time.time()
|
||||
latency_ms = (process_time - start_time) * 1000
|
||||
|
||||
monitor.record_latency(latency_ms)
|
||||
|
||||
if operation_count % 100 == 0:
|
||||
monitor.sample_system_metrics()
|
||||
|
||||
operation_count += 1
|
||||
|
||||
# Control rate to avoid overwhelming
|
||||
await asyncio.sleep(0.001) # 1ms delay
|
||||
|
||||
except Exception as e:
|
||||
monitor.record_error(e)
|
||||
|
||||
monitor.stop()
|
||||
results = monitor.get_results()
|
||||
|
||||
# Verify performance
|
||||
assert results['latency']['avg_ms'] <= LoadTestConfig.MAX_LATENCY_MS, \
|
||||
f"Average latency too high: {results['latency']['avg_ms']:.2f}ms"
|
||||
|
||||
assert results['memory']['max_mb'] <= LoadTestConfig.MAX_MEMORY_MB, \
|
||||
f"Memory usage too high: {results['memory']['max_mb']:.2f}MB"
|
||||
|
||||
logger.info(f"Trade processing results: {results}")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aggregation_performance(self, data_generator, performance_monitor):
|
||||
"""Test aggregation engine performance"""
|
||||
aggregator = AggregationEngine()
|
||||
monitor = performance_monitor
|
||||
|
||||
monitor.start()
|
||||
|
||||
# Generate multiple order books for aggregation
|
||||
orderbooks = []
|
||||
for i in range(100):
|
||||
symbol = LoadTestConfig.SYMBOLS[i % len(LoadTestConfig.SYMBOLS)]
|
||||
exchange = LoadTestConfig.EXCHANGES[i % len(LoadTestConfig.EXCHANGES)]
|
||||
orderbook = data_generator.generate_orderbook(symbol, exchange)
|
||||
orderbooks.append(orderbook)
|
||||
|
||||
# Test aggregation performance
|
||||
start_time = time.time()
|
||||
|
||||
for orderbook in orderbooks:
|
||||
try:
|
||||
# Test price bucketing
|
||||
buckets = aggregator.create_price_buckets(orderbook)
|
||||
|
||||
# Test heatmap generation
|
||||
heatmap = aggregator.generate_heatmap(buckets)
|
||||
|
||||
# Test metrics calculation
|
||||
metrics = aggregator.calculate_metrics(orderbook)
|
||||
|
||||
process_time = time.time()
|
||||
latency_ms = (process_time - start_time) * 1000
|
||||
monitor.record_latency(latency_ms)
|
||||
|
||||
start_time = process_time
|
||||
|
||||
except Exception as e:
|
||||
monitor.record_error(e)
|
||||
|
||||
monitor.stop()
|
||||
results = monitor.get_results()
|
||||
|
||||
# Verify aggregation performance
|
||||
assert results['latency']['p95_ms'] <= 50, \
|
||||
f"Aggregation P95 latency too high: {results['latency']['p95_ms']:.2f}ms"
|
||||
|
||||
logger.info(f"Aggregation performance results: {results}")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_exchange_processing(self, data_generator, performance_monitor):
|
||||
"""Test concurrent processing from multiple exchanges"""
|
||||
processor = DataProcessor()
|
||||
monitor = performance_monitor
|
||||
|
||||
monitor.start()
|
||||
|
||||
# Create concurrent tasks for each exchange
|
||||
tasks = []
|
||||
for exchange in LoadTestConfig.EXCHANGES:
|
||||
task = asyncio.create_task(
|
||||
self._simulate_exchange_load(processor, data_generator, exchange, monitor)
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
# Run all exchanges concurrently
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
monitor.stop()
|
||||
results = monitor.get_results()
|
||||
|
||||
# Verify concurrent processing performance
|
||||
expected_total_ops = len(LoadTestConfig.EXCHANGES) * 100 # 100 ops per exchange
|
||||
assert results['total_operations'] >= expected_total_ops * 0.9, \
|
||||
f"Not enough operations completed: {results['total_operations']}"
|
||||
|
||||
assert results['success_rate'] >= 95.0, \
|
||||
f"Success rate too low under concurrent load: {results['success_rate']:.2f}%"
|
||||
|
||||
logger.info(f"Concurrent processing results: {results}")
|
||||
|
||||
async def _simulate_exchange_load(self, processor, data_generator, exchange, monitor):
|
||||
"""Simulate load from a single exchange"""
|
||||
for i in range(100):
|
||||
try:
|
||||
symbol = LoadTestConfig.SYMBOLS[i % len(LoadTestConfig.SYMBOLS)]
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Alternate between order books and trades
|
||||
if i % 2 == 0:
|
||||
data = data_generator.generate_orderbook(symbol, exchange)
|
||||
processed = processor.normalize_orderbook(data.__dict__, exchange)
|
||||
else:
|
||||
data = data_generator.generate_trade(symbol, exchange)
|
||||
processed = processor.normalize_trade(data.__dict__, exchange)
|
||||
|
||||
end_time = time.time()
|
||||
latency_ms = (end_time - start_time) * 1000
|
||||
|
||||
monitor.record_latency(latency_ms)
|
||||
|
||||
# Small delay to simulate realistic timing
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
except Exception as e:
|
||||
monitor.record_error(e)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_usage_under_load(self, data_generator):
|
||||
"""Test memory usage patterns under sustained load"""
|
||||
processor = DataProcessor()
|
||||
process = psutil.Process()
|
||||
|
||||
# Get baseline memory
|
||||
gc.collect() # Force garbage collection
|
||||
baseline_memory = process.memory_info().rss / 1024 / 1024 # MB
|
||||
|
||||
memory_samples = []
|
||||
|
||||
# Generate sustained load
|
||||
for i in range(1000):
|
||||
symbol = LoadTestConfig.SYMBOLS[i % len(LoadTestConfig.SYMBOLS)]
|
||||
exchange = LoadTestConfig.EXCHANGES[i % len(LoadTestConfig.EXCHANGES)]
|
||||
|
||||
# Generate data
|
||||
orderbook = data_generator.generate_orderbook(symbol, exchange)
|
||||
trade = data_generator.generate_trade(symbol, exchange)
|
||||
|
||||
# Process data
|
||||
processor.normalize_orderbook(orderbook.__dict__, exchange)
|
||||
processor.normalize_trade(trade.__dict__, exchange)
|
||||
|
||||
# Sample memory every 100 operations
|
||||
if i % 100 == 0:
|
||||
current_memory = process.memory_info().rss / 1024 / 1024
|
||||
memory_samples.append(current_memory)
|
||||
|
||||
# Force garbage collection periodically
|
||||
if i % 500 == 0:
|
||||
gc.collect()
|
||||
|
||||
# Final memory check
|
||||
gc.collect()
|
||||
final_memory = process.memory_info().rss / 1024 / 1024
|
||||
|
||||
# Calculate memory statistics
|
||||
max_memory = max(memory_samples) if memory_samples else final_memory
|
||||
memory_growth = final_memory - baseline_memory
|
||||
|
||||
# Verify memory usage is reasonable
|
||||
assert memory_growth < 100, \
|
||||
f"Memory growth too high: {memory_growth:.2f}MB"
|
||||
|
||||
assert max_memory < baseline_memory + 200, \
|
||||
f"Peak memory usage too high: {max_memory:.2f}MB"
|
||||
|
||||
logger.info(f"Memory usage: baseline={baseline_memory:.2f}MB, "
|
||||
f"final={final_memory:.2f}MB, growth={memory_growth:.2f}MB, "
|
||||
f"peak={max_memory:.2f}MB")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stress_test_extreme_load(self, data_generator, performance_monitor):
|
||||
"""Stress test with extreme load conditions"""
|
||||
processor = DataProcessor()
|
||||
monitor = performance_monitor
|
||||
|
||||
# Extreme load parameters
|
||||
EXTREME_TPS = 5000
|
||||
STRESS_DURATION = 30 # seconds
|
||||
|
||||
monitor.start()
|
||||
|
||||
# Generate extreme load
|
||||
tasks = []
|
||||
operations_per_batch = 100
|
||||
batches = EXTREME_TPS // operations_per_batch
|
||||
|
||||
for batch in range(batches):
|
||||
batch_tasks = []
|
||||
for i in range(operations_per_batch):
|
||||
symbol = LoadTestConfig.SYMBOLS[i % len(LoadTestConfig.SYMBOLS)]
|
||||
exchange = LoadTestConfig.EXCHANGES[i % len(LoadTestConfig.EXCHANGES)]
|
||||
|
||||
task = asyncio.create_task(
|
||||
self._process_orderbook_with_timing(
|
||||
processor, data_generator, symbol, exchange, monitor
|
||||
)
|
||||
)
|
||||
batch_tasks.append(task)
|
||||
|
||||
# Process batch
|
||||
await asyncio.gather(*batch_tasks, return_exceptions=True)
|
||||
|
||||
# Small delay between batches
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
monitor.stop()
|
||||
results = monitor.get_results()
|
||||
|
||||
# Under extreme load, we accept lower performance but system should remain stable
|
||||
assert results['success_rate'] >= 80.0, \
|
||||
f"System failed under stress: {results['success_rate']:.2f}% success rate"
|
||||
|
||||
assert results['latency']['p99_ms'] <= 500, \
|
||||
f"P99 latency too high under stress: {results['latency']['p99_ms']:.2f}ms"
|
||||
|
||||
logger.info(f"Stress test results: {results}")
|
||||
|
||||
|
||||
@pytest.mark.benchmark
|
||||
class TestPerformanceBenchmarks:
|
||||
"""Performance benchmarks for regression testing"""
|
||||
|
||||
def test_orderbook_processing_benchmark(self, benchmark):
|
||||
"""Benchmark order book processing speed"""
|
||||
processor = DataProcessor()
|
||||
generator = DataGenerator()
|
||||
|
||||
def process_orderbook():
|
||||
orderbook = generator.generate_orderbook("BTCUSDT", "binance")
|
||||
return processor.normalize_orderbook(orderbook.__dict__, "binance")
|
||||
|
||||
result = benchmark(process_orderbook)
|
||||
assert result is not None
|
||||
|
||||
def test_trade_processing_benchmark(self, benchmark):
|
||||
"""Benchmark trade processing speed"""
|
||||
processor = DataProcessor()
|
||||
generator = DataGenerator()
|
||||
|
||||
def process_trade():
|
||||
trade = generator.generate_trade("BTCUSDT", "binance")
|
||||
return processor.normalize_trade(trade.__dict__, "binance")
|
||||
|
||||
result = benchmark(process_trade)
|
||||
assert result is not None
|
||||
|
||||
def test_aggregation_benchmark(self, benchmark):
|
||||
"""Benchmark aggregation engine performance"""
|
||||
aggregator = AggregationEngine()
|
||||
generator = DataGenerator()
|
||||
|
||||
def aggregate_data():
|
||||
orderbook = generator.generate_orderbook("BTCUSDT", "binance")
|
||||
buckets = aggregator.create_price_buckets(orderbook)
|
||||
return aggregator.generate_heatmap(buckets)
|
||||
|
||||
result = benchmark(aggregate_data)
|
||||
assert result is not None
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
"""Configure pytest with custom markers"""
|
||||
config.addinivalue_line("markers", "load: mark test as load test")
|
||||
config.addinivalue_line("markers", "benchmark: mark test as benchmark")
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
"""Add custom command line options"""
|
||||
parser.addoption(
|
||||
"--load",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="run load tests"
|
||||
)
|
||||
parser.addoption(
|
||||
"--benchmark",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="run benchmark tests"
|
||||
)
|
108
COBY/tests/test_performance_benchmarks.py
Normal file
@ -0,0 +1,108 @@
|
||||
"""
|
||||
Performance benchmarks and regression tests.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import time
|
||||
import statistics
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, List, Any, Tuple
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from ..models.core import OrderBookSnapshot, TradeEvent, PriceLevel
|
||||
from ..processing.data_processor import DataProcessor
|
||||
from ..aggregation.aggregation_engine import AggregationEngine
|
||||
from ..connectors.binance_connector import BinanceConnector
|
||||
from ..storage.timescale_manager import TimescaleManager
|
||||
from ..caching.redis_manager import RedisManager
|
||||
from ..utils.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BenchmarkResult:
|
||||
"""Benchmark result data structure"""
|
||||
name: str
|
||||
duration_ms: float
|
||||
operations_per_second: float
|
||||
memory_usage_mb: float
|
||||
cpu_usage_percent: float
|
||||
timestamp: datetime
|
||||
metadata: Dict[str, Any] = None
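# Illustrative construction (values are placeholders, not measured results):
#   BenchmarkResult(
#       name="orderbook_processing",
#       duration_ms=850.0,
#       operations_per_second=1176.5,
#       memory_usage_mb=210.4,
#       cpu_usage_percent=35.0,
#       timestamp=datetime.now(timezone.utc),
#   )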
|
||||
|
||||
|
||||
class BenchmarkRunner:
|
||||
"""Benchmark execution and result management"""
|
||||
|
||||
def __init__(self, results_file: str = "benchmark_results.json"):
|
||||
self.results_file = Path(results_file)
|
||||
self.results: List[BenchmarkResult] = []
|
||||
self.load_historical_results()
|
||||
|
||||
def load_historical_results(self):
|
||||
"""Load historical benchmark results"""
|
||||
if self.results_file.exists():
|
||||
try:
|
||||
with open(self.results_file, 'r') as f:
|
||||
data = json.load(f)
|
||||
for item in data:
|
||||
result = BenchmarkResult(
|
||||
name=item['name'],
|
||||
duration_ms=item['duration_ms'],
|
||||
operations_per_second=item['operations_per_second'],
|
||||
memory_usage_mb=item['memory_usage_mb'],
|
||||
cpu_usage_percent=item['cpu_usage_percent'],
|
||||
timestamp=datetime.fromisoformat(item['timestamp']),
|
||||
metadata=item.get('metadata', {})
|
||||
)
|
||||
self.results.append(result)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not load historical results: {e}")
|
||||
|
||||
def save_results(self):
|
||||
"""Save benchmark results to file"""
|
||||
try:
|
||||
data = []
|
||||
for result in self.results:
|
||||
data.append({
|
||||
'name': result.name,
|
||||
'duration_ms': result.duration_ms,
|
||||
'operations_per_second': result.operations_per_second,
|
||||
'memory_usage_mb': result.memory_usage_mb,
|
||||
'cpu_usage_percent': result.cpu_usage_percent,
|
||||
'timestamp': result.timestamp.isoformat(),
|
||||
'metadata': result.metadata or {}
|
||||
})
|
||||
|
||||
with open(self.results_file, 'w') as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Could not save benchmark results: {e}")
|
||||
|
||||
def run_benchmark(self, name: str, func, iterations: int = 1000,
|
||||
warmup: int = 100) -> BenchmarkResult:
|
||||
"""Run a benchmark function"""
|
||||
import psutil
|
||||
|
||||
process = psutil.Process()
|
||||
|
||||
# Warmup
|
||||
for _ in range(warmup):
|
||||
func()
|
||||
|
||||
# Collect baseline metrics
|
||||
initial_memory = process.memory_info().rss / 1024 / 1024
|
||||
initial_cpu = process.cpu_percent()
|
||||
|
||||
# Run benchmark
|
||||
start_time = time.perf_counter()
|
||||
|
||||
for _ in range(iterations):
|
||||
func()
|
||||
|
||||
end_time = time.perf_counter()
|
@ -1,57 +1,23 @@
|
||||
"""
|
||||
Custom exceptions for the COBY system.
|
||||
Custom exceptions for COBY system.
|
||||
"""
|
||||
|
||||
|
||||
class COBYException(Exception):
|
||||
"""Base exception for COBY system"""
|
||||
|
||||
def __init__(self, message: str, error_code: str = None, details: dict = None):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.error_code = error_code
|
||||
self.details = details or {}
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert exception to dictionary"""
|
||||
return {
|
||||
'error': self.__class__.__name__,
|
||||
'message': self.message,
|
||||
'error_code': self.error_code,
|
||||
'details': self.details
|
||||
}
|
||||
pass
|
||||
|
||||
|
||||
class ConnectionError(COBYException):
|
||||
"""Exception raised for connection-related errors"""
|
||||
"""Connection related errors"""
|
||||
pass
|
||||
|
||||
|
||||
class ValidationError(COBYException):
|
||||
"""Exception raised for data validation errors"""
|
||||
"""Data validation errors"""
|
||||
pass
|
||||
|
||||
|
||||
class ProcessingError(COBYException):
|
||||
"""Exception raised for data processing errors"""
|
||||
pass
|
||||
|
||||
|
||||
class StorageError(COBYException):
|
||||
"""Exception raised for storage-related errors"""
|
||||
pass
|
||||
|
||||
|
||||
class ConfigurationError(COBYException):
|
||||
"""Exception raised for configuration errors"""
|
||||
pass
|
||||
|
||||
|
||||
class ReplayError(COBYException):
|
||||
"""Exception raised for replay-related errors"""
|
||||
pass
|
||||
|
||||
|
||||
class AggregationError(COBYException):
|
||||
"""Exception raised for aggregation errors"""
|
||||
"""Data processing errors"""
|
||||
pass
|
@ -1,206 +1,15 @@
|
||||
"""
|
||||
Timing utilities for the COBY system.
|
||||
Timing utilities for COBY system.
|
||||
"""
|
||||
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def get_current_timestamp() -> datetime:
|
||||
"""
|
||||
Get current UTC timestamp.
|
||||
|
||||
Returns:
|
||||
datetime: Current UTC timestamp
|
||||
"""
|
||||
"""Get current UTC timestamp"""
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def format_timestamp(timestamp: datetime, format_str: str = "%Y-%m-%d %H:%M:%S.%f") -> str:
|
||||
"""
|
||||
Format timestamp to string.
|
||||
|
||||
Args:
|
||||
timestamp: Timestamp to format
|
||||
format_str: Format string
|
||||
|
||||
Returns:
|
||||
str: Formatted timestamp string
|
||||
"""
|
||||
return timestamp.strftime(format_str)
|
||||
|
||||
|
||||
def parse_timestamp(timestamp_str: str, format_str: str = "%Y-%m-%d %H:%M:%S.%f") -> datetime:
|
||||
"""
|
||||
Parse timestamp string to datetime.
|
||||
|
||||
Args:
|
||||
timestamp_str: Timestamp string to parse
|
||||
format_str: Format string
|
||||
|
||||
Returns:
|
||||
datetime: Parsed timestamp
|
||||
"""
|
||||
dt = datetime.strptime(timestamp_str, format_str)
|
||||
# Ensure timezone awareness
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.replace(tzinfo=timezone.utc)
|
||||
return dt
|
||||
|
||||
|
||||
def timestamp_to_unix(timestamp: datetime) -> float:
    """
    Convert datetime to Unix timestamp.

    Args:
        timestamp: Datetime to convert

    Returns:
        float: Unix timestamp
    """
    return timestamp.timestamp()


def unix_to_timestamp(unix_time: float) -> datetime:
    """
    Convert Unix timestamp to datetime.

    Args:
        unix_time: Unix timestamp

    Returns:
        datetime: Converted datetime (UTC)
    """
    return datetime.fromtimestamp(unix_time, tz=timezone.utc)


def calculate_time_diff(start: datetime, end: datetime) -> float:
    """
    Calculate time difference in seconds.

    Args:
        start: Start timestamp
        end: End timestamp

    Returns:
        float: Time difference in seconds
    """
    return (end - start).total_seconds()


def is_timestamp_recent(timestamp: datetime, max_age_seconds: int = 60) -> bool:
    """
    Check if timestamp is recent (within max_age_seconds).

    Args:
        timestamp: Timestamp to check
        max_age_seconds: Maximum age in seconds

    Returns:
        bool: True if recent, False otherwise
    """
    now = get_current_timestamp()
    age = calculate_time_diff(timestamp, now)
    return age <= max_age_seconds


def sleep_until(target_time: datetime) -> None:
    """
    Sleep until target time.

    Args:
        target_time: Target timestamp to sleep until
    """
    now = get_current_timestamp()
    sleep_seconds = calculate_time_diff(now, target_time)

    if sleep_seconds > 0:
        time.sleep(sleep_seconds)


def get_milliseconds() -> int:
    """
    Get current timestamp in milliseconds.

    Returns:
        int: Current timestamp in milliseconds
    """
    return int(time.time() * 1000)


def milliseconds_to_timestamp(ms: int) -> datetime:
    """
    Convert milliseconds to datetime.

    Args:
        ms: Milliseconds timestamp

    Returns:
        datetime: Converted datetime (UTC)
    """
    return datetime.fromtimestamp(ms / 1000.0, tz=timezone.utc)


def round_timestamp(timestamp: datetime, seconds: int) -> datetime:
    """
    Round timestamp to nearest interval.

    Args:
        timestamp: Timestamp to round
        seconds: Interval in seconds

    Returns:
        datetime: Rounded timestamp
    """
    unix_time = timestamp_to_unix(timestamp)
    rounded_unix = round(unix_time / seconds) * seconds
    return unix_to_timestamp(rounded_unix)
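For orientation, a hypothetical worked example of the conversion helpers above (values made up, not part of the commit): 22:20:42 UTC rounded to a 60-second interval lands on 22:21:00 because round() snaps to the nearest bucket boundary.

ts = milliseconds_to_timestamp(1_700_000_442_000)   # 2023-11-14 22:20:42 UTC
bucket = round_timestamp(ts, 60)
print(bucket.isoformat())                            # 2023-11-14T22:21:00+00:00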
class Timer:
    """Simple timer for measuring execution time"""

    def __init__(self):
        self.start_time: Optional[float] = None
        self.end_time: Optional[float] = None

    def start(self) -> None:
        """Start the timer"""
        self.start_time = time.perf_counter()
        self.end_time = None

    def stop(self) -> float:
        """
        Stop the timer and return elapsed time.

        Returns:
            float: Elapsed time in seconds
        """
        if self.start_time is None:
            raise ValueError("Timer not started")

        self.end_time = time.perf_counter()
        return self.elapsed()

    def elapsed(self) -> float:
        """
        Get elapsed time.

        Returns:
            float: Elapsed time in seconds
        """
        if self.start_time is None:
            return 0.0

        end = self.end_time or time.perf_counter()
        return end - self.start_time

    def __enter__(self):
        """Context manager entry"""
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit"""
        self.stop()

def format_timestamp(timestamp: datetime) -> str:
    """Format timestamp as ISO string"""
    return timestamp.isoformat()
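A minimal usage sketch for the Timer context manager above (illustrative only; the workload and label are made up):

# __enter__ starts the clock; __exit__ records end_time, so elapsed() stays fixed afterwards
with Timer() as t:
    sum(range(1_000_000))            # placeholder workload
print(f"took {t.elapsed():.6f} s")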
@ -1,217 +1,31 @@
"""
Data validation utilities for the COBY system.
Validation utilities for COBY system.
"""

import re
from typing import List, Optional
from decimal import Decimal, InvalidOperation
from typing import Any


def validate_symbol(symbol: str) -> bool:
    """
    Validate trading symbol format.

    Args:
        symbol: Trading symbol to validate

    Returns:
        bool: True if valid, False otherwise
    """
    """Validate trading symbol format"""
    if not symbol or not isinstance(symbol, str):
        return False

    # Basic symbol format validation (e.g., BTCUSDT, ETH-USD)
    pattern = r'^[A-Z0-9]{2,10}[-/]?[A-Z0-9]{2,10}$'
    # Basic symbol validation (letters and numbers, 3-12 chars)
    pattern = r'^[A-Z0-9]{3,12}$'
    return bool(re.match(pattern, symbol.upper()))


def validate_price(price: float) -> bool:
    """
    Validate price value.

    Args:
        price: Price to validate

    Returns:
        bool: True if valid, False otherwise
    """
    if not isinstance(price, (int, float, Decimal)):
        return False

    try:
        price_decimal = Decimal(str(price))
        return price_decimal > 0 and price_decimal < Decimal('1e10')  # Reasonable upper bound
    except (InvalidOperation, ValueError):
        return False
    """Validate price value"""
    return isinstance(price, (int, float)) and price > 0


def validate_volume(volume: float) -> bool:
    """
    Validate volume value.

    Args:
        volume: Volume to validate

    Returns:
        bool: True if valid, False otherwise
    """
    if not isinstance(volume, (int, float, Decimal)):
        return False

    try:
        volume_decimal = Decimal(str(volume))
        return volume_decimal >= 0 and volume_decimal < Decimal('1e15')  # Reasonable upper bound
    except (InvalidOperation, ValueError):
        return False
    """Validate volume value"""
    return isinstance(volume, (int, float)) and volume >= 0
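The two symbol regexes shown above behave differently for delimited symbols; a short hypothetical check (not part of the commit) makes the change concrete:

import re

old = r'^[A-Z0-9]{2,10}[-/]?[A-Z0-9]{2,10}$'
new = r'^[A-Z0-9]{3,12}$'

for sym in ("BTCUSDT", "ETH-USD"):
    print(sym, bool(re.match(old, sym)), bool(re.match(new, sym)))
# BTCUSDT True True
# ETH-USD True False   <- delimited symbols no longer pass the simplified pattern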
def validate_exchange_name(exchange: str) -> bool:
    """
    Validate exchange name.

    Args:
        exchange: Exchange name to validate

    Returns:
        bool: True if valid, False otherwise
    """
    if not exchange or not isinstance(exchange, str):
        return False

    # Exchange name should be alphanumeric with possible underscores/hyphens
    pattern = r'^[a-zA-Z0-9_-]{2,20}$'
    return bool(re.match(pattern, exchange))


def validate_timestamp_range(start_time, end_time) -> List[str]:
    """
    Validate timestamp range.

    Args:
        start_time: Start timestamp
        end_time: End timestamp

    Returns:
        List[str]: List of validation errors (empty if valid)
    """
    errors = []

    if start_time is None:
        errors.append("Start time cannot be None")

    if end_time is None:
        errors.append("End time cannot be None")

    if start_time and end_time and start_time >= end_time:
        errors.append("Start time must be before end time")

    return errors


def validate_bucket_size(bucket_size: float) -> bool:
    """
    Validate price bucket size.

    Args:
        bucket_size: Bucket size to validate

    Returns:
        bool: True if valid, False otherwise
    """
    if not isinstance(bucket_size, (int, float, Decimal)):
        return False

    try:
        size_decimal = Decimal(str(bucket_size))
        return size_decimal > 0 and size_decimal <= Decimal('1000')  # Reasonable upper bound
    except (InvalidOperation, ValueError):
        return False


def validate_speed_multiplier(speed: float) -> bool:
    """
    Validate replay speed multiplier.

    Args:
        speed: Speed multiplier to validate

    Returns:
        bool: True if valid, False otherwise
    """
    if not isinstance(speed, (int, float)):
        return False

    return 0.01 <= speed <= 100.0  # 1% to 100x speed


def sanitize_symbol(symbol: str) -> str:
    """
    Sanitize and normalize symbol format.

    Args:
        symbol: Symbol to sanitize

    Returns:
        str: Sanitized symbol
    """
    if not symbol:
        return ""

    # Remove whitespace and convert to uppercase
    sanitized = symbol.strip().upper()

    # Remove invalid characters
    sanitized = re.sub(r'[^A-Z0-9/-]', '', sanitized)

    return sanitized


def validate_percentage(value: float, min_val: float = 0.0, max_val: float = 100.0) -> bool:
    """
    Validate percentage value.

    Args:
        value: Percentage value to validate
        min_val: Minimum allowed value
        max_val: Maximum allowed value

    Returns:
        bool: True if valid, False otherwise
    """
    if not isinstance(value, (int, float)):
        return False

    return min_val <= value <= max_val


def validate_connection_config(config: dict) -> List[str]:
    """
    Validate connection configuration.

    Args:
        config: Configuration dictionary

    Returns:
        List[str]: List of validation errors (empty if valid)
    """
    errors = []

    # Required fields
    required_fields = ['host', 'port']
    for field in required_fields:
        if field not in config:
            errors.append(f"Missing required field: {field}")

    # Validate host
    if 'host' in config:
        host = config['host']
        if not isinstance(host, str) or not host.strip():
            errors.append("Host must be a non-empty string")

    # Validate port
    if 'port' in config:
        port = config['port']
        if not isinstance(port, int) or not (1 <= port <= 65535):
            errors.append("Port must be an integer between 1 and 65535")

    return errors

def validate_timestamp(timestamp: Any) -> bool:
    """Validate timestamp"""
    return timestamp is not None
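As a quick, hypothetical illustration of the error-list style used by the validators above (values made up, not part of the commit):

from datetime import datetime, timezone

cfg = {"host": "localhost", "port": 70000}
for problem in validate_connection_config(cfg):
    print(problem)                                 # Port must be an integer between 1 and 65535

start = datetime(2024, 1, 2, tzinfo=timezone.utc)
end = datetime(2024, 1, 1, tzinfo=timezone.utc)
print(validate_timestamp_range(start, end))        # ['Start time must be before end time']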