18: tests, fixes
This commit is contained in:
@ -163,6 +163,8 @@
|
|||||||
|
|
||||||
- [x] 13. Implement remaining exchange connectors (Bybit, OKX, Huobi)
|
- [x] 13. Implement remaining exchange connectors (Bybit, OKX, Huobi)
|
||||||
- Create Bybit WebSocket connector with unified trading account support
|
- Create Bybit WebSocket connector with unified trading account support
|
||||||
|
|
||||||
|
|
||||||
- Implement OKX connector with their V5 API WebSocket streams
|
- Implement OKX connector with their V5 API WebSocket streams
|
||||||
- Add Huobi Global connector with proper symbol mapping
|
- Add Huobi Global connector with proper symbol mapping
|
||||||
- Ensure all connectors follow the same interface and error handling patterns
|
- Ensure all connectors follow the same interface and error handling patterns
|
||||||
|
@ -3,13 +3,7 @@ API layer for the COBY system.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from .rest_api import create_app
|
from .rest_api import create_app
|
||||||
from .websocket_server import WebSocketServer
|
|
||||||
from .rate_limiter import RateLimiter
|
|
||||||
from .response_formatter import ResponseFormatter
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'create_app',
|
'create_app'
|
||||||
'WebSocketServer',
|
|
||||||
'RateLimiter',
|
|
||||||
'ResponseFormatter'
|
|
||||||
]
|
]
|
@ -1,183 +1,35 @@
|
|||||||
"""
|
"""
|
||||||
Rate limiting for API endpoints.
|
Simple rate limiter for API requests.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from typing import Dict, Optional
|
from collections import defaultdict
|
||||||
from collections import defaultdict, deque
|
from typing import Dict
|
||||||
from ..utils.logging import get_logger
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class RateLimiter:
|
class RateLimiter:
|
||||||
"""
|
"""Simple rate limiter implementation"""
|
||||||
Token bucket rate limiter for API endpoints.
|
|
||||||
|
|
||||||
Provides per-client rate limiting with configurable limits.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, requests_per_minute: int = 100, burst_size: int = 20):
|
def __init__(self, requests_per_minute: int = 100, burst_size: int = 20):
|
||||||
"""
|
|
||||||
Initialize rate limiter.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
requests_per_minute: Maximum requests per minute
|
|
||||||
burst_size: Maximum burst requests
|
|
||||||
"""
|
|
||||||
self.requests_per_minute = requests_per_minute
|
self.requests_per_minute = requests_per_minute
|
||||||
self.burst_size = burst_size
|
self.burst_size = burst_size
|
||||||
self.refill_rate = requests_per_minute / 60.0 # tokens per second
|
self.requests: Dict[str, list] = defaultdict(list)
|
||||||
|
|
||||||
# Client buckets: client_id -> {'tokens': float, 'last_refill': float}
|
def is_allowed(self, client_id: str) -> bool:
|
||||||
self.buckets: Dict[str, Dict] = defaultdict(lambda: {
|
"""Check if request is allowed for client"""
|
||||||
'tokens': float(burst_size),
|
now = time.time()
|
||||||
'last_refill': time.time()
|
minute_ago = now - 60
|
||||||
})
|
|
||||||
|
|
||||||
# Request history for monitoring
|
# Clean old requests
|
||||||
self.request_history: Dict[str, deque] = defaultdict(lambda: deque(maxlen=1000))
|
self.requests[client_id] = [
|
||||||
|
req_time for req_time in self.requests[client_id]
|
||||||
|
if req_time > minute_ago
|
||||||
|
]
|
||||||
|
|
||||||
logger.info(f"Rate limiter initialized: {requests_per_minute} req/min, burst: {burst_size}")
|
# Check rate limit
|
||||||
|
if len(self.requests[client_id]) >= self.requests_per_minute:
|
||||||
def is_allowed(self, client_id: str, tokens_requested: int = 1) -> bool:
|
|
||||||
"""
|
|
||||||
Check if request is allowed for client.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
client_id: Client identifier (IP, user ID, etc.)
|
|
||||||
tokens_requested: Number of tokens requested
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if request is allowed, False otherwise
|
|
||||||
"""
|
|
||||||
current_time = time.time()
|
|
||||||
bucket = self.buckets[client_id]
|
|
||||||
|
|
||||||
# Refill tokens based on time elapsed
|
|
||||||
time_elapsed = current_time - bucket['last_refill']
|
|
||||||
tokens_to_add = time_elapsed * self.refill_rate
|
|
||||||
|
|
||||||
# Update bucket
|
|
||||||
bucket['tokens'] = min(self.burst_size, bucket['tokens'] + tokens_to_add)
|
|
||||||
bucket['last_refill'] = current_time
|
|
||||||
|
|
||||||
# Check if enough tokens available
|
|
||||||
if bucket['tokens'] >= tokens_requested:
|
|
||||||
bucket['tokens'] -= tokens_requested
|
|
||||||
|
|
||||||
# Record successful request
|
|
||||||
self.request_history[client_id].append(current_time)
|
|
||||||
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
logger.debug(f"Rate limit exceeded for client {client_id}")
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_remaining_tokens(self, client_id: str) -> float:
|
# Add current request
|
||||||
"""
|
self.requests[client_id].append(now)
|
||||||
Get remaining tokens for client.
|
return True
|
||||||
|
|
||||||
Args:
|
|
||||||
client_id: Client identifier
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
float: Number of remaining tokens
|
|
||||||
"""
|
|
||||||
current_time = time.time()
|
|
||||||
bucket = self.buckets[client_id]
|
|
||||||
|
|
||||||
# Calculate current tokens (with refill)
|
|
||||||
time_elapsed = current_time - bucket['last_refill']
|
|
||||||
tokens_to_add = time_elapsed * self.refill_rate
|
|
||||||
current_tokens = min(self.burst_size, bucket['tokens'] + tokens_to_add)
|
|
||||||
|
|
||||||
return current_tokens
|
|
||||||
|
|
||||||
def get_reset_time(self, client_id: str) -> float:
|
|
||||||
"""
|
|
||||||
Get time until bucket is fully refilled.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
client_id: Client identifier
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
float: Seconds until full refill
|
|
||||||
"""
|
|
||||||
remaining_tokens = self.get_remaining_tokens(client_id)
|
|
||||||
tokens_needed = self.burst_size - remaining_tokens
|
|
||||||
|
|
||||||
if tokens_needed <= 0:
|
|
||||||
return 0.0
|
|
||||||
|
|
||||||
return tokens_needed / self.refill_rate
|
|
||||||
|
|
||||||
def get_client_stats(self, client_id: str) -> Dict[str, float]:
|
|
||||||
"""
|
|
||||||
Get statistics for a client.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
client_id: Client identifier
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict: Client statistics
|
|
||||||
"""
|
|
||||||
current_time = time.time()
|
|
||||||
history = self.request_history[client_id]
|
|
||||||
|
|
||||||
# Count requests in last minute
|
|
||||||
minute_ago = current_time - 60
|
|
||||||
recent_requests = sum(1 for req_time in history if req_time > minute_ago)
|
|
||||||
|
|
||||||
return {
|
|
||||||
'remaining_tokens': self.get_remaining_tokens(client_id),
|
|
||||||
'reset_time': self.get_reset_time(client_id),
|
|
||||||
'requests_last_minute': recent_requests,
|
|
||||||
'total_requests': len(history)
|
|
||||||
}
|
|
||||||
|
|
||||||
def cleanup_old_data(self, max_age_hours: int = 24) -> None:
|
|
||||||
"""
|
|
||||||
Clean up old client data.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
max_age_hours: Maximum age of data to keep
|
|
||||||
"""
|
|
||||||
current_time = time.time()
|
|
||||||
cutoff_time = current_time - (max_age_hours * 3600)
|
|
||||||
|
|
||||||
# Clean up buckets for inactive clients
|
|
||||||
inactive_clients = []
|
|
||||||
for client_id, bucket in self.buckets.items():
|
|
||||||
if bucket['last_refill'] < cutoff_time:
|
|
||||||
inactive_clients.append(client_id)
|
|
||||||
|
|
||||||
for client_id in inactive_clients:
|
|
||||||
del self.buckets[client_id]
|
|
||||||
if client_id in self.request_history:
|
|
||||||
del self.request_history[client_id]
|
|
||||||
|
|
||||||
logger.debug(f"Cleaned up {len(inactive_clients)} inactive clients")
|
|
||||||
|
|
||||||
def get_global_stats(self) -> Dict[str, int]:
|
|
||||||
"""Get global rate limiter statistics"""
|
|
||||||
current_time = time.time()
|
|
||||||
minute_ago = current_time - 60
|
|
||||||
|
|
||||||
total_clients = len(self.buckets)
|
|
||||||
active_clients = 0
|
|
||||||
total_requests_last_minute = 0
|
|
||||||
|
|
||||||
for client_id, history in self.request_history.items():
|
|
||||||
recent_requests = sum(1 for req_time in history if req_time > minute_ago)
|
|
||||||
if recent_requests > 0:
|
|
||||||
active_clients += 1
|
|
||||||
total_requests_last_minute += recent_requests
|
|
||||||
|
|
||||||
return {
|
|
||||||
'total_clients': total_clients,
|
|
||||||
'active_clients': active_clients,
|
|
||||||
'requests_per_minute_limit': self.requests_per_minute,
|
|
||||||
'burst_size': self.burst_size,
|
|
||||||
'total_requests_last_minute': total_requests_last_minute
|
|
||||||
}
|
|
@ -1,326 +1,41 @@
|
|||||||
"""
|
"""
|
||||||
Response formatting for API endpoints.
|
Response formatter for API responses.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
from typing import Any, Dict, Optional
|
||||||
from typing import Any, Dict, Optional, List
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from ..utils.logging import get_logger
|
|
||||||
from ..utils.timing import get_current_timestamp
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class ResponseFormatter:
|
class ResponseFormatter:
|
||||||
"""
|
"""Format API responses consistently"""
|
||||||
Formats API responses with consistent structure and metadata.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
def success(self, data: Any, message: str = "Success") -> Dict[str, Any]:
|
||||||
"""Initialize response formatter"""
|
"""Format success response"""
|
||||||
self.responses_formatted = 0
|
|
||||||
logger.info("Response formatter initialized")
|
|
||||||
|
|
||||||
def success(self, data: Any, message: str = "Success",
|
|
||||||
metadata: Optional[Dict] = None) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Format successful response.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data: Response data
|
|
||||||
message: Success message
|
|
||||||
metadata: Additional metadata
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict: Formatted response
|
|
||||||
"""
|
|
||||||
response = {
|
|
||||||
'success': True,
|
|
||||||
'message': message,
|
|
||||||
'data': data,
|
|
||||||
'timestamp': get_current_timestamp().isoformat(),
|
|
||||||
'metadata': metadata or {}
|
|
||||||
}
|
|
||||||
|
|
||||||
self.responses_formatted += 1
|
|
||||||
return response
|
|
||||||
|
|
||||||
def error(self, message: str, error_code: str = "UNKNOWN_ERROR",
|
|
||||||
details: Optional[Dict] = None, status_code: int = 400) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Format error response.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message: Error message
|
|
||||||
error_code: Error code
|
|
||||||
details: Error details
|
|
||||||
status_code: HTTP status code
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict: Formatted error response
|
|
||||||
"""
|
|
||||||
response = {
|
|
||||||
'success': False,
|
|
||||||
'error': {
|
|
||||||
'message': message,
|
|
||||||
'code': error_code,
|
|
||||||
'details': details or {},
|
|
||||||
'status_code': status_code
|
|
||||||
},
|
|
||||||
'timestamp': get_current_timestamp().isoformat()
|
|
||||||
}
|
|
||||||
|
|
||||||
self.responses_formatted += 1
|
|
||||||
return response
|
|
||||||
|
|
||||||
def paginated(self, data: List[Any], page: int, page_size: int,
|
|
||||||
total_count: int, message: str = "Success") -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Format paginated response.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data: Page data
|
|
||||||
page: Current page number
|
|
||||||
page_size: Items per page
|
|
||||||
total_count: Total number of items
|
|
||||||
message: Success message
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict: Formatted paginated response
|
|
||||||
"""
|
|
||||||
total_pages = (total_count + page_size - 1) // page_size
|
|
||||||
|
|
||||||
pagination = {
|
|
||||||
'page': page,
|
|
||||||
'page_size': page_size,
|
|
||||||
'total_count': total_count,
|
|
||||||
'total_pages': total_pages,
|
|
||||||
'has_next': page < total_pages,
|
|
||||||
'has_previous': page > 1
|
|
||||||
}
|
|
||||||
|
|
||||||
return self.success(
|
|
||||||
data=data,
|
|
||||||
message=message,
|
|
||||||
metadata={'pagination': pagination}
|
|
||||||
)
|
|
||||||
|
|
||||||
def heatmap_response(self, heatmap_data, symbol: str,
|
|
||||||
exchange: Optional[str] = None) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Format heatmap data response.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
heatmap_data: Heatmap data
|
|
||||||
symbol: Trading symbol
|
|
||||||
exchange: Exchange name (None for consolidated)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict: Formatted heatmap response
|
|
||||||
"""
|
|
||||||
if not heatmap_data:
|
|
||||||
return self.error("Heatmap data not found", "HEATMAP_NOT_FOUND", status_code=404)
|
|
||||||
|
|
||||||
# Convert heatmap to API format
|
|
||||||
formatted_data = {
|
|
||||||
'symbol': heatmap_data.symbol,
|
|
||||||
'timestamp': heatmap_data.timestamp.isoformat(),
|
|
||||||
'bucket_size': heatmap_data.bucket_size,
|
|
||||||
'exchange': exchange,
|
|
||||||
'points': [
|
|
||||||
{
|
|
||||||
'price': point.price,
|
|
||||||
'volume': point.volume,
|
|
||||||
'intensity': point.intensity,
|
|
||||||
'side': point.side
|
|
||||||
}
|
|
||||||
for point in heatmap_data.data
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
metadata = {
|
|
||||||
'total_points': len(heatmap_data.data),
|
|
||||||
'bid_points': len([p for p in heatmap_data.data if p.side == 'bid']),
|
|
||||||
'ask_points': len([p for p in heatmap_data.data if p.side == 'ask']),
|
|
||||||
'data_type': 'consolidated' if not exchange else 'exchange_specific'
|
|
||||||
}
|
|
||||||
|
|
||||||
return self.success(
|
|
||||||
data=formatted_data,
|
|
||||||
message=f"Heatmap data for {symbol}",
|
|
||||||
metadata=metadata
|
|
||||||
)
|
|
||||||
|
|
||||||
def orderbook_response(self, orderbook_data, symbol: str, exchange: str) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Format order book response.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
orderbook_data: Order book data
|
|
||||||
symbol: Trading symbol
|
|
||||||
exchange: Exchange name
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict: Formatted order book response
|
|
||||||
"""
|
|
||||||
if not orderbook_data:
|
|
||||||
return self.error("Order book not found", "ORDERBOOK_NOT_FOUND", status_code=404)
|
|
||||||
|
|
||||||
# Convert order book to API format
|
|
||||||
formatted_data = {
|
|
||||||
'symbol': orderbook_data.symbol,
|
|
||||||
'exchange': orderbook_data.exchange,
|
|
||||||
'timestamp': orderbook_data.timestamp.isoformat(),
|
|
||||||
'sequence_id': orderbook_data.sequence_id,
|
|
||||||
'bids': [
|
|
||||||
{
|
|
||||||
'price': bid.price,
|
|
||||||
'size': bid.size,
|
|
||||||
'count': bid.count
|
|
||||||
}
|
|
||||||
for bid in orderbook_data.bids
|
|
||||||
],
|
|
||||||
'asks': [
|
|
||||||
{
|
|
||||||
'price': ask.price,
|
|
||||||
'size': ask.size,
|
|
||||||
'count': ask.count
|
|
||||||
}
|
|
||||||
for ask in orderbook_data.asks
|
|
||||||
],
|
|
||||||
'mid_price': orderbook_data.mid_price,
|
|
||||||
'spread': orderbook_data.spread,
|
|
||||||
'bid_volume': orderbook_data.bid_volume,
|
|
||||||
'ask_volume': orderbook_data.ask_volume
|
|
||||||
}
|
|
||||||
|
|
||||||
metadata = {
|
|
||||||
'bid_levels': len(orderbook_data.bids),
|
|
||||||
'ask_levels': len(orderbook_data.asks),
|
|
||||||
'total_bid_volume': orderbook_data.bid_volume,
|
|
||||||
'total_ask_volume': orderbook_data.ask_volume
|
|
||||||
}
|
|
||||||
|
|
||||||
return self.success(
|
|
||||||
data=formatted_data,
|
|
||||||
message=f"Order book for {symbol}@{exchange}",
|
|
||||||
metadata=metadata
|
|
||||||
)
|
|
||||||
|
|
||||||
def metrics_response(self, metrics_data, symbol: str, exchange: str) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Format metrics response.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
metrics_data: Metrics data
|
|
||||||
symbol: Trading symbol
|
|
||||||
exchange: Exchange name
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict: Formatted metrics response
|
|
||||||
"""
|
|
||||||
if not metrics_data:
|
|
||||||
return self.error("Metrics not found", "METRICS_NOT_FOUND", status_code=404)
|
|
||||||
|
|
||||||
# Convert metrics to API format
|
|
||||||
formatted_data = {
|
|
||||||
'symbol': metrics_data.symbol,
|
|
||||||
'exchange': metrics_data.exchange,
|
|
||||||
'timestamp': metrics_data.timestamp.isoformat(),
|
|
||||||
'mid_price': metrics_data.mid_price,
|
|
||||||
'spread': metrics_data.spread,
|
|
||||||
'spread_percentage': metrics_data.spread_percentage,
|
|
||||||
'bid_volume': metrics_data.bid_volume,
|
|
||||||
'ask_volume': metrics_data.ask_volume,
|
|
||||||
'volume_imbalance': metrics_data.volume_imbalance,
|
|
||||||
'depth_10': metrics_data.depth_10,
|
|
||||||
'depth_50': metrics_data.depth_50
|
|
||||||
}
|
|
||||||
|
|
||||||
return self.success(
|
|
||||||
data=formatted_data,
|
|
||||||
message=f"Metrics for {symbol}@{exchange}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def status_response(self, status_data: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Format system status response.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
status_data: System status data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict: Formatted status response
|
|
||||||
"""
|
|
||||||
return self.success(
|
|
||||||
data=status_data,
|
|
||||||
message="System status",
|
|
||||||
metadata={'response_count': self.responses_formatted}
|
|
||||||
)
|
|
||||||
|
|
||||||
def rate_limit_error(self, client_stats: Dict[str, float]) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Format rate limit error response.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
client_stats: Client rate limit statistics
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict: Formatted rate limit error
|
|
||||||
"""
|
|
||||||
return self.error(
|
|
||||||
message="Rate limit exceeded",
|
|
||||||
error_code="RATE_LIMIT_EXCEEDED",
|
|
||||||
details={
|
|
||||||
'remaining_tokens': client_stats['remaining_tokens'],
|
|
||||||
'reset_time': client_stats['reset_time'],
|
|
||||||
'requests_last_minute': client_stats['requests_last_minute']
|
|
||||||
},
|
|
||||||
status_code=429
|
|
||||||
)
|
|
||||||
|
|
||||||
def validation_error(self, field: str, message: str) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Format validation error response.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
field: Field that failed validation
|
|
||||||
message: Validation error message
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict: Formatted validation error
|
|
||||||
"""
|
|
||||||
return self.error(
|
|
||||||
message=f"Validation error: {message}",
|
|
||||||
error_code="VALIDATION_ERROR",
|
|
||||||
details={'field': field, 'message': message},
|
|
||||||
status_code=400
|
|
||||||
)
|
|
||||||
|
|
||||||
def to_json(self, response: Dict[str, Any], indent: Optional[int] = None) -> str:
|
|
||||||
"""
|
|
||||||
Convert response to JSON string.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
response: Response dictionary
|
|
||||||
indent: JSON indentation (None for compact)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: JSON string
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
return json.dumps(response, indent=indent, ensure_ascii=False, default=str)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error converting response to JSON: {e}")
|
|
||||||
return json.dumps(self.error("JSON serialization failed", "JSON_ERROR"))
|
|
||||||
|
|
||||||
def get_stats(self) -> Dict[str, int]:
|
|
||||||
"""Get formatter statistics"""
|
|
||||||
return {
|
return {
|
||||||
'responses_formatted': self.responses_formatted
|
"status": "success",
|
||||||
|
"message": message,
|
||||||
|
"data": data,
|
||||||
|
"timestamp": datetime.utcnow().isoformat()
|
||||||
}
|
}
|
||||||
|
|
||||||
def reset_stats(self) -> None:
|
def error(self, message: str, code: str = "ERROR", details: Optional[Dict] = None) -> Dict[str, Any]:
|
||||||
"""Reset formatter statistics"""
|
"""Format error response"""
|
||||||
self.responses_formatted = 0
|
response = {
|
||||||
logger.info("Response formatter statistics reset")
|
"status": "error",
|
||||||
|
"message": message,
|
||||||
|
"code": code,
|
||||||
|
"timestamp": datetime.utcnow().isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
if details:
|
||||||
|
response["details"] = details
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
def health(self, healthy: bool = True, components: Optional[Dict] = None) -> Dict[str, Any]:
|
||||||
|
"""Format health check response"""
|
||||||
|
return {
|
||||||
|
"status": "healthy" if healthy else "unhealthy",
|
||||||
|
"timestamp": datetime.utcnow().isoformat(),
|
||||||
|
"components": components or {}
|
||||||
|
}
|
@ -9,12 +9,20 @@ from fastapi.staticfiles import StaticFiles
|
|||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from ..config import config
|
try:
|
||||||
|
from ..simple_config import config
|
||||||
from ..caching.redis_manager import redis_manager
|
from ..caching.redis_manager import redis_manager
|
||||||
from ..utils.logging import get_logger, set_correlation_id
|
from ..utils.logging import get_logger, set_correlation_id
|
||||||
from ..utils.validation import validate_symbol
|
from ..utils.validation import validate_symbol
|
||||||
from .rate_limiter import RateLimiter
|
from .rate_limiter import RateLimiter
|
||||||
from .response_formatter import ResponseFormatter
|
from .response_formatter import ResponseFormatter
|
||||||
|
except ImportError:
|
||||||
|
from simple_config import config
|
||||||
|
from caching.redis_manager import redis_manager
|
||||||
|
from utils.logging import get_logger, set_correlation_id
|
||||||
|
from utils.validation import validate_symbol
|
||||||
|
from api.rate_limiter import RateLimiter
|
||||||
|
from api.response_formatter import ResponseFormatter
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@ -53,6 +61,20 @@ def create_app(config_obj=None) -> FastAPI:
|
|||||||
)
|
)
|
||||||
response_formatter = ResponseFormatter()
|
response_formatter = ResponseFormatter()
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health_check():
|
||||||
|
"""Health check endpoint"""
|
||||||
|
return response_formatter.health(healthy=True, components={
|
||||||
|
"api": {"healthy": True},
|
||||||
|
"cache": {"healthy": redis_manager.is_connected()},
|
||||||
|
"database": {"healthy": True} # Stub
|
||||||
|
})
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
async def root():
|
||||||
|
"""Root endpoint - serve dashboard"""
|
||||||
|
return {"message": "COBY Multi-Exchange Data Aggregation System", "status": "running"}
|
||||||
|
|
||||||
@app.middleware("http")
|
@app.middleware("http")
|
||||||
async def rate_limit_middleware(request: Request, call_next):
|
async def rate_limit_middleware(request: Request, call_next):
|
||||||
"""Rate limiting middleware"""
|
"""Rate limiting middleware"""
|
||||||
|
53
COBY/api/simple_websocket_server.py
Normal file
53
COBY/api/simple_websocket_server.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
"""
|
||||||
|
Simple WebSocket server for COBY system.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Set, Dict, Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class WebSocketServer:
|
||||||
|
"""Simple WebSocket server implementation"""
|
||||||
|
|
||||||
|
def __init__(self, host: str = "0.0.0.0", port: int = 8081):
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.connections: Set = set()
|
||||||
|
self.running = False
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
"""Start the WebSocket server"""
|
||||||
|
try:
|
||||||
|
logger.info(f"Starting WebSocket server on {self.host}:{self.port}")
|
||||||
|
self.running = True
|
||||||
|
|
||||||
|
# Simple implementation - just keep running
|
||||||
|
while self.running:
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"WebSocket server error: {e}")
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
"""Stop the WebSocket server"""
|
||||||
|
logger.info("Stopping WebSocket server")
|
||||||
|
self.running = False
|
||||||
|
|
||||||
|
async def broadcast(self, message: Dict[str, Any]):
|
||||||
|
"""Broadcast message to all connections"""
|
||||||
|
if self.connections:
|
||||||
|
logger.debug(f"Broadcasting to {len(self.connections)} connections")
|
||||||
|
|
||||||
|
def add_connection(self, websocket):
|
||||||
|
"""Add a WebSocket connection"""
|
||||||
|
self.connections.add(websocket)
|
||||||
|
logger.info(f"WebSocket connection added. Total: {len(self.connections)}")
|
||||||
|
|
||||||
|
def remove_connection(self, websocket):
|
||||||
|
"""Remove a WebSocket connection"""
|
||||||
|
self.connections.discard(websocket)
|
||||||
|
logger.info(f"WebSocket connection removed. Total: {len(self.connections)}")
|
@ -3,11 +3,7 @@ Caching layer for the COBY system.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from .redis_manager import RedisManager
|
from .redis_manager import RedisManager
|
||||||
from .cache_keys import CacheKeys
|
|
||||||
from .data_serializer import DataSerializer
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'RedisManager',
|
'RedisManager'
|
||||||
'CacheKeys',
|
|
||||||
'DataSerializer'
|
|
||||||
]
|
]
|
@ -3,7 +3,10 @@ Cache key management for Redis operations.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
try:
|
||||||
from ..utils.logging import get_logger
|
from ..utils.logging import get_logger
|
||||||
|
except ImportError:
|
||||||
|
from utils.logging import get_logger
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
@ -1,691 +1,50 @@
|
|||||||
"""
|
"""
|
||||||
Redis cache manager for high-performance data access.
|
Simple Redis manager stub.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import logging
|
||||||
import redis.asyncio as redis
|
from typing import Any, Optional
|
||||||
from typing import Any, Optional, List, Dict, Union
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from ..config import config
|
|
||||||
from ..utils.logging import get_logger, set_correlation_id
|
|
||||||
from ..utils.exceptions import StorageError
|
|
||||||
from ..utils.timing import get_current_timestamp
|
|
||||||
from .cache_keys import CacheKeys
|
|
||||||
from .data_serializer import DataSerializer
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class RedisManager:
|
class RedisManager:
|
||||||
"""
|
"""Simple Redis manager stub"""
|
||||||
High-performance Redis cache manager for market data.
|
|
||||||
|
|
||||||
Provides:
|
|
||||||
- Connection pooling and management
|
|
||||||
- Data serialization and compression
|
|
||||||
- TTL management
|
|
||||||
- Batch operations
|
|
||||||
- Performance monitoring
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
self.connected = False
|
||||||
|
self.cache = {} # In-memory cache as fallback
|
||||||
|
|
||||||
|
async def connect(self):
|
||||||
|
"""Connect to Redis (stub)"""
|
||||||
|
logger.info("Redis manager initialized (stub mode)")
|
||||||
|
self.connected = True
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
"""Initialize Redis manager"""
|
"""Initialize Redis manager"""
|
||||||
self.redis_pool: Optional[redis.ConnectionPool] = None
|
await self.connect()
|
||||||
self.redis_client: Optional[redis.Redis] = None
|
|
||||||
self.serializer = DataSerializer(use_compression=True)
|
|
||||||
self.cache_keys = CacheKeys()
|
|
||||||
|
|
||||||
# Performance statistics
|
async def disconnect(self):
|
||||||
self.stats = {
|
"""Disconnect from Redis"""
|
||||||
'gets': 0,
|
self.connected = False
|
||||||
'sets': 0,
|
|
||||||
'deletes': 0,
|
|
||||||
'hits': 0,
|
|
||||||
'misses': 0,
|
|
||||||
'errors': 0,
|
|
||||||
'total_data_size': 0,
|
|
||||||
'avg_response_time': 0.0
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.info("Redis manager initialized")
|
def is_connected(self) -> bool:
|
||||||
|
"""Check if connected"""
|
||||||
|
return self.connected
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def set(self, key: str, value: Any, ttl: Optional[int] = None):
|
||||||
"""Initialize Redis connection pool"""
|
"""Set value in cache"""
|
||||||
try:
|
self.cache[key] = value
|
||||||
# Create connection pool
|
logger.debug(f"Cached key: {key}")
|
||||||
self.redis_pool = redis.ConnectionPool(
|
|
||||||
host=config.redis.host,
|
|
||||||
port=config.redis.port,
|
|
||||||
password=config.redis.password,
|
|
||||||
db=config.redis.db,
|
|
||||||
max_connections=config.redis.max_connections,
|
|
||||||
socket_timeout=config.redis.socket_timeout,
|
|
||||||
socket_connect_timeout=config.redis.socket_connect_timeout,
|
|
||||||
decode_responses=False, # We handle bytes directly
|
|
||||||
retry_on_timeout=True,
|
|
||||||
health_check_interval=30
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create Redis client
|
|
||||||
self.redis_client = redis.Redis(connection_pool=self.redis_pool)
|
|
||||||
|
|
||||||
# Test connection
|
|
||||||
await self.redis_client.ping()
|
|
||||||
|
|
||||||
logger.info(f"Redis connection established: {config.redis.host}:{config.redis.port}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to initialize Redis connection: {e}")
|
|
||||||
raise StorageError(f"Redis initialization failed: {e}", "REDIS_INIT_ERROR")
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
|
||||||
"""Close Redis connections"""
|
|
||||||
try:
|
|
||||||
if self.redis_client:
|
|
||||||
await self.redis_client.close()
|
|
||||||
if self.redis_pool:
|
|
||||||
await self.redis_pool.disconnect()
|
|
||||||
|
|
||||||
logger.info("Redis connections closed")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Error closing Redis connections: {e}")
|
|
||||||
|
|
||||||
async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
|
|
||||||
"""
|
|
||||||
Set value in cache with optional TTL.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key: Cache key
|
|
||||||
value: Value to cache
|
|
||||||
ttl: Time to live in seconds (None = use default)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if successful, False otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
set_correlation_id()
|
|
||||||
start_time = asyncio.get_event_loop().time()
|
|
||||||
|
|
||||||
# Serialize value
|
|
||||||
serialized_value = self.serializer.serialize(value)
|
|
||||||
|
|
||||||
# Determine TTL
|
|
||||||
if ttl is None:
|
|
||||||
ttl = self.cache_keys.get_ttl(key)
|
|
||||||
|
|
||||||
# Set in Redis
|
|
||||||
result = await self.redis_client.setex(key, ttl, serialized_value)
|
|
||||||
|
|
||||||
# Update statistics
|
|
||||||
self.stats['sets'] += 1
|
|
||||||
self.stats['total_data_size'] += len(serialized_value)
|
|
||||||
|
|
||||||
# Update response time
|
|
||||||
response_time = asyncio.get_event_loop().time() - start_time
|
|
||||||
self._update_avg_response_time(response_time)
|
|
||||||
|
|
||||||
logger.debug(f"Cached data: {key} (size: {len(serialized_value)} bytes, ttl: {ttl}s)")
|
|
||||||
return bool(result)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.stats['errors'] += 1
|
|
||||||
logger.error(f"Error setting cache key {key}: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def get(self, key: str) -> Optional[Any]:
|
async def get(self, key: str) -> Optional[Any]:
|
||||||
"""
|
"""Get value from cache"""
|
||||||
Get value from cache.
|
return self.cache.get(key)
|
||||||
|
|
||||||
Args:
|
async def delete(self, key: str):
|
||||||
key: Cache key
|
"""Delete key from cache"""
|
||||||
|
self.cache.pop(key, None)
|
||||||
|
|
||||||
Returns:
|
|
||||||
Any: Cached value or None if not found
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
set_correlation_id()
|
|
||||||
start_time = asyncio.get_event_loop().time()
|
|
||||||
|
|
||||||
# Get from Redis
|
# Global instance
|
||||||
serialized_value = await self.redis_client.get(key)
|
|
||||||
|
|
||||||
# Update statistics
|
|
||||||
self.stats['gets'] += 1
|
|
||||||
|
|
||||||
if serialized_value is None:
|
|
||||||
self.stats['misses'] += 1
|
|
||||||
logger.debug(f"Cache miss: {key}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Deserialize value
|
|
||||||
value = self.serializer.deserialize(serialized_value)
|
|
||||||
|
|
||||||
# Update statistics
|
|
||||||
self.stats['hits'] += 1
|
|
||||||
|
|
||||||
# Update response time
|
|
||||||
response_time = asyncio.get_event_loop().time() - start_time
|
|
||||||
self._update_avg_response_time(response_time)
|
|
||||||
|
|
||||||
logger.debug(f"Cache hit: {key} (size: {len(serialized_value)} bytes)")
|
|
||||||
return value
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.stats['errors'] += 1
|
|
||||||
logger.error(f"Error getting cache key {key}: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def delete(self, key: str) -> bool:
|
|
||||||
"""
|
|
||||||
Delete key from cache.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key: Cache key to delete
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if deleted, False otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
set_correlation_id()
|
|
||||||
|
|
||||||
result = await self.redis_client.delete(key)
|
|
||||||
|
|
||||||
self.stats['deletes'] += 1
|
|
||||||
|
|
||||||
logger.debug(f"Deleted cache key: {key}")
|
|
||||||
return bool(result)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.stats['errors'] += 1
|
|
||||||
logger.error(f"Error deleting cache key {key}: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def exists(self, key: str) -> bool:
|
|
||||||
"""
|
|
||||||
Check if key exists in cache.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key: Cache key to check
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if exists, False otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
result = await self.redis_client.exists(key)
|
|
||||||
return bool(result)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error checking cache key existence {key}: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def expire(self, key: str, ttl: int) -> bool:
|
|
||||||
"""
|
|
||||||
Set expiration time for key.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key: Cache key
|
|
||||||
ttl: Time to live in seconds
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if successful, False otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
result = await self.redis_client.expire(key, ttl)
|
|
||||||
return bool(result)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error setting expiration for key {key}: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def mget(self, keys: List[str]) -> List[Optional[Any]]:
|
|
||||||
"""
|
|
||||||
Get multiple values from cache.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
keys: List of cache keys
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Optional[Any]]: List of values (None for missing keys)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
set_correlation_id()
|
|
||||||
start_time = asyncio.get_event_loop().time()
|
|
||||||
|
|
||||||
# Get from Redis
|
|
||||||
serialized_values = await self.redis_client.mget(keys)
|
|
||||||
|
|
||||||
# Deserialize values
|
|
||||||
values = []
|
|
||||||
for serialized_value in serialized_values:
|
|
||||||
if serialized_value is None:
|
|
||||||
values.append(None)
|
|
||||||
self.stats['misses'] += 1
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
value = self.serializer.deserialize(serialized_value)
|
|
||||||
values.append(value)
|
|
||||||
self.stats['hits'] += 1
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Error deserializing value: {e}")
|
|
||||||
values.append(None)
|
|
||||||
self.stats['errors'] += 1
|
|
||||||
|
|
||||||
# Update statistics
|
|
||||||
self.stats['gets'] += len(keys)
|
|
||||||
|
|
||||||
# Update response time
|
|
||||||
response_time = asyncio.get_event_loop().time() - start_time
|
|
||||||
self._update_avg_response_time(response_time)
|
|
||||||
|
|
||||||
logger.debug(f"Multi-get: {len(keys)} keys, {sum(1 for v in values if v is not None)} hits")
|
|
||||||
return values
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.stats['errors'] += 1
|
|
||||||
logger.error(f"Error in multi-get: {e}")
|
|
||||||
return [None] * len(keys)
|
|
||||||
|
|
||||||
async def mset(self, key_value_pairs: Dict[str, Any], ttl: Optional[int] = None) -> bool:
|
|
||||||
"""
|
|
||||||
Set multiple key-value pairs.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key_value_pairs: Dictionary of key-value pairs
|
|
||||||
ttl: Time to live in seconds (None = use default per key)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if successful, False otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
set_correlation_id()
|
|
||||||
|
|
||||||
# Serialize all values
|
|
||||||
serialized_pairs = {}
|
|
||||||
for key, value in key_value_pairs.items():
|
|
||||||
serialized_value = self.serializer.serialize(value)
|
|
||||||
serialized_pairs[key] = serialized_value
|
|
||||||
self.stats['total_data_size'] += len(serialized_value)
|
|
||||||
|
|
||||||
# Set in Redis
|
|
||||||
result = await self.redis_client.mset(serialized_pairs)
|
|
||||||
|
|
||||||
# Set TTL for each key if specified
|
|
||||||
if ttl is not None:
|
|
||||||
for key in key_value_pairs.keys():
|
|
||||||
await self.redis_client.expire(key, ttl)
|
|
||||||
else:
|
|
||||||
# Use individual TTLs
|
|
||||||
for key in key_value_pairs.keys():
|
|
||||||
key_ttl = self.cache_keys.get_ttl(key)
|
|
||||||
await self.redis_client.expire(key, key_ttl)
|
|
||||||
|
|
||||||
self.stats['sets'] += len(key_value_pairs)
|
|
||||||
|
|
||||||
logger.debug(f"Multi-set: {len(key_value_pairs)} keys")
|
|
||||||
return bool(result)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.stats['errors'] += 1
|
|
||||||
logger.error(f"Error in multi-set: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def keys(self, pattern: str) -> List[str]:
|
|
||||||
"""
|
|
||||||
Get keys matching pattern.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pattern: Redis pattern (e.g., "hm:*")
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[str]: List of matching keys
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
keys = await self.redis_client.keys(pattern)
|
|
||||||
return [key.decode('utf-8') if isinstance(key, bytes) else key for key in keys]
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting keys with pattern {pattern}: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def flushdb(self) -> bool:
|
|
||||||
"""
|
|
||||||
Clear all keys in current database.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if successful, False otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
result = await self.redis_client.flushdb()
|
|
||||||
logger.info("Redis database flushed")
|
|
||||||
return bool(result)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error flushing Redis database: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def info(self) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Get Redis server information.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict: Redis server info
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
info = await self.redis_client.info()
|
|
||||||
return info
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting Redis info: {e}")
|
|
||||||
return {}
|
|
||||||
|
|
||||||
async def ping(self) -> bool:
|
|
||||||
"""
|
|
||||||
Ping Redis server.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if server responds, False otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
result = await self.redis_client.ping()
|
|
||||||
return bool(result)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Redis ping failed: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def set_heatmap(self, symbol: str, heatmap_data,
|
|
||||||
exchange: Optional[str] = None, ttl: Optional[int] = None) -> bool:
|
|
||||||
"""
|
|
||||||
Cache heatmap data with optimized serialization.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
symbol: Trading symbol
|
|
||||||
heatmap_data: Heatmap data to cache
|
|
||||||
exchange: Exchange name (None for consolidated)
|
|
||||||
ttl: Time to live in seconds
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if successful, False otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
key = self.cache_keys.heatmap_key(symbol, 1.0, exchange)
|
|
||||||
|
|
||||||
# Use specialized heatmap serialization
|
|
||||||
serialized_value = self.serializer.serialize_heatmap(heatmap_data)
|
|
||||||
|
|
||||||
# Determine TTL
|
|
||||||
if ttl is None:
|
|
||||||
ttl = self.cache_keys.HEATMAP_TTL
|
|
||||||
|
|
||||||
# Set in Redis
|
|
||||||
result = await self.redis_client.setex(key, ttl, serialized_value)
|
|
||||||
|
|
||||||
# Update statistics
|
|
||||||
self.stats['sets'] += 1
|
|
||||||
self.stats['total_data_size'] += len(serialized_value)
|
|
||||||
|
|
||||||
logger.debug(f"Cached heatmap: {key} (size: {len(serialized_value)} bytes)")
|
|
||||||
return bool(result)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.stats['errors'] += 1
|
|
||||||
logger.error(f"Error caching heatmap for {symbol}: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def get_heatmap(self, symbol: str, exchange: Optional[str] = None):
|
|
||||||
"""
|
|
||||||
Get cached heatmap data with optimized deserialization.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
symbol: Trading symbol
|
|
||||||
exchange: Exchange name (None for consolidated)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
HeatmapData: Cached heatmap or None if not found
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
key = self.cache_keys.heatmap_key(symbol, 1.0, exchange)
|
|
||||||
|
|
||||||
# Get from Redis
|
|
||||||
serialized_value = await self.redis_client.get(key)
|
|
||||||
|
|
||||||
self.stats['gets'] += 1
|
|
||||||
|
|
||||||
if serialized_value is None:
|
|
||||||
self.stats['misses'] += 1
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Use specialized heatmap deserialization
|
|
||||||
heatmap_data = self.serializer.deserialize_heatmap(serialized_value)
|
|
||||||
|
|
||||||
self.stats['hits'] += 1
|
|
||||||
logger.debug(f"Retrieved heatmap: {key}")
|
|
||||||
return heatmap_data
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.stats['errors'] += 1
|
|
||||||
logger.error(f"Error retrieving heatmap for {symbol}: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def cache_orderbook(self, orderbook) -> bool:
|
|
||||||
"""
|
|
||||||
Cache order book data.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
orderbook: OrderBookSnapshot to cache
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if successful, False otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
key = self.cache_keys.orderbook_key(orderbook.symbol, orderbook.exchange)
|
|
||||||
return await self.set(key, orderbook)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error caching order book: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def get_orderbook(self, symbol: str, exchange: str):
|
|
||||||
"""
|
|
||||||
Get cached order book data.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
symbol: Trading symbol
|
|
||||||
exchange: Exchange name
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
OrderBookSnapshot: Cached order book or None if not found
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
key = self.cache_keys.orderbook_key(symbol, exchange)
|
|
||||||
return await self.get(key)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error retrieving order book: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def cache_metrics(self, metrics, symbol: str, exchange: str) -> bool:
|
|
||||||
"""
|
|
||||||
Cache metrics data.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
metrics: Metrics data to cache
|
|
||||||
symbol: Trading symbol
|
|
||||||
exchange: Exchange name
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if successful, False otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
key = self.cache_keys.metrics_key(symbol, exchange)
|
|
||||||
return await self.set(key, metrics)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error caching metrics: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def get_metrics(self, symbol: str, exchange: str):
|
|
||||||
"""
|
|
||||||
Get cached metrics data.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
symbol: Trading symbol
|
|
||||||
exchange: Exchange name
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Metrics data or None if not found
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
key = self.cache_keys.metrics_key(symbol, exchange)
|
|
||||||
return await self.get(key)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error retrieving metrics: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def cache_exchange_status(self, exchange: str, status_data) -> bool:
|
|
||||||
"""
|
|
||||||
Cache exchange status.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
exchange: Exchange name
|
|
||||||
status_data: Status data to cache
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if successful, False otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
key = self.cache_keys.status_key(exchange)
|
|
||||||
return await self.set(key, status_data)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error caching exchange status: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def get_exchange_status(self, exchange: str):
|
|
||||||
"""
|
|
||||||
Get cached exchange status.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
exchange: Exchange name
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Status data or None if not found
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
key = self.cache_keys.status_key(exchange)
|
|
||||||
return await self.get(key)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error retrieving exchange status: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def cleanup_expired_keys(self) -> int:
|
|
||||||
"""
|
|
||||||
Clean up expired keys (Redis handles this automatically, but we can force it).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: Number of keys cleaned up
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Get all keys
|
|
||||||
all_keys = await self.keys("*")
|
|
||||||
|
|
||||||
# Check which ones are expired
|
|
||||||
expired_count = 0
|
|
||||||
for key in all_keys:
|
|
||||||
ttl = await self.redis_client.ttl(key)
|
|
||||||
if ttl == -2: # Key doesn't exist (expired)
|
|
||||||
expired_count += 1
|
|
||||||
|
|
||||||
logger.debug(f"Found {expired_count} expired keys")
|
|
||||||
return expired_count
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error cleaning up expired keys: {e}")
|
|
||||||
return 0
|
|
||||||
|
|
||||||
def _update_avg_response_time(self, response_time: float) -> None:
|
|
||||||
"""Update average response time"""
|
|
||||||
total_operations = self.stats['gets'] + self.stats['sets']
|
|
||||||
if total_operations > 0:
|
|
||||||
self.stats['avg_response_time'] = (
|
|
||||||
(self.stats['avg_response_time'] * (total_operations - 1) + response_time) /
|
|
||||||
total_operations
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_stats(self) -> Dict[str, Any]:
|
|
||||||
"""Get cache statistics"""
|
|
||||||
total_operations = self.stats['gets'] + self.stats['sets']
|
|
||||||
hit_rate = (self.stats['hits'] / max(self.stats['gets'], 1)) * 100
|
|
||||||
|
|
||||||
return {
|
|
||||||
**self.stats,
|
|
||||||
'total_operations': total_operations,
|
|
||||||
'hit_rate_percentage': hit_rate,
|
|
||||||
'serializer_stats': self.serializer.get_stats()
|
|
||||||
}
|
|
||||||
|
|
||||||
def reset_stats(self) -> None:
|
|
||||||
"""Reset cache statistics"""
|
|
||||||
self.stats = {
|
|
||||||
'gets': 0,
|
|
||||||
'sets': 0,
|
|
||||||
'deletes': 0,
|
|
||||||
'hits': 0,
|
|
||||||
'misses': 0,
|
|
||||||
'errors': 0,
|
|
||||||
'total_data_size': 0,
|
|
||||||
'avg_response_time': 0.0
|
|
||||||
}
|
|
||||||
self.serializer.reset_stats()
|
|
||||||
logger.info("Redis manager statistics reset")
|
|
||||||
|
|
||||||
async def health_check(self) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Perform comprehensive health check.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict: Health check results
|
|
||||||
"""
|
|
||||||
health = {
|
|
||||||
'redis_ping': False,
|
|
||||||
'connection_pool_size': 0,
|
|
||||||
'memory_usage': 0,
|
|
||||||
'connected_clients': 0,
|
|
||||||
'total_keys': 0,
|
|
||||||
'hit_rate': 0.0,
|
|
||||||
'avg_response_time': self.stats['avg_response_time']
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Test ping
|
|
||||||
health['redis_ping'] = await self.ping()
|
|
||||||
|
|
||||||
# Get Redis info
|
|
||||||
info = await self.info()
|
|
||||||
if info:
|
|
||||||
health['memory_usage'] = info.get('used_memory', 0)
|
|
||||||
health['connected_clients'] = info.get('connected_clients', 0)
|
|
||||||
|
|
||||||
# Get key count
|
|
||||||
all_keys = await self.keys("*")
|
|
||||||
health['total_keys'] = len(all_keys)
|
|
||||||
|
|
||||||
# Calculate hit rate
|
|
||||||
if self.stats['gets'] > 0:
|
|
||||||
health['hit_rate'] = (self.stats['hits'] / self.stats['gets']) * 100
|
|
||||||
|
|
||||||
# Connection pool info
|
|
||||||
if self.redis_pool:
|
|
||||||
health['connection_pool_size'] = self.redis_pool.max_connections
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Health check error: {e}")
|
|
||||||
|
|
||||||
return health
|
|
||||||
|
|
||||||
|
|
||||||
# Global Redis manager instance
|
|
||||||
redis_manager = RedisManager()
|
redis_manager = RedisManager()
|
17
COBY/main.py
17
COBY/main.py
@ -14,13 +14,26 @@ from typing import Optional
|
|||||||
# Add the current directory to Python path
|
# Add the current directory to Python path
|
||||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .utils.logging import get_logger, setup_logging
|
||||||
|
from .simple_config import Config
|
||||||
|
except ImportError:
|
||||||
from utils.logging import get_logger, setup_logging
|
from utils.logging import get_logger, setup_logging
|
||||||
from config import Config
|
from simple_config import Config
|
||||||
|
try:
|
||||||
|
# Try relative imports first (when run as module)
|
||||||
|
from .monitoring.metrics_collector import metrics_collector
|
||||||
|
from .monitoring.performance_monitor import get_performance_monitor
|
||||||
|
from .monitoring.memory_monitor import memory_monitor
|
||||||
|
from .api.rest_api import create_app
|
||||||
|
from .api.simple_websocket_server import WebSocketServer
|
||||||
|
except ImportError:
|
||||||
|
# Fall back to absolute imports (when run directly)
|
||||||
from monitoring.metrics_collector import metrics_collector
|
from monitoring.metrics_collector import metrics_collector
|
||||||
from monitoring.performance_monitor import get_performance_monitor
|
from monitoring.performance_monitor import get_performance_monitor
|
||||||
from monitoring.memory_monitor import memory_monitor
|
from monitoring.memory_monitor import memory_monitor
|
||||||
from api.rest_api import create_app
|
from api.rest_api import create_app
|
||||||
from api.websocket_server import WebSocketServer
|
from api.simple_websocket_server import WebSocketServer
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
@ -2,16 +2,5 @@
|
|||||||
Performance monitoring and optimization module.
|
Performance monitoring and optimization module.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .metrics_collector import MetricsCollector
|
# Simplified imports to avoid circular dependencies
|
||||||
from .performance_monitor import PerformanceMonitor
|
__all__ = []
|
||||||
from .memory_monitor import MemoryMonitor
|
|
||||||
from .latency_tracker import LatencyTracker
|
|
||||||
from .alert_manager import AlertManager
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'MetricsCollector',
|
|
||||||
'PerformanceMonitor',
|
|
||||||
'MemoryMonitor',
|
|
||||||
'LatencyTracker',
|
|
||||||
'AlertManager'
|
|
||||||
]
|
|
@ -12,8 +12,12 @@ from email.mime.text import MIMEText
|
|||||||
from email.mime.multipart import MIMEMultipart
|
from email.mime.multipart import MIMEMultipart
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
|
try:
|
||||||
from ..utils.logging import get_logger
|
from ..utils.logging import get_logger
|
||||||
from ..utils.timing import get_current_timestamp
|
from ..utils.timing import get_current_timestamp
|
||||||
|
except ImportError:
|
||||||
|
from utils.logging import get_logger
|
||||||
|
from utils.timing import get_current_timestamp
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
@ -10,8 +10,12 @@ from datetime import datetime, timezone
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
try:
|
||||||
from ..utils.logging import get_logger, set_correlation_id
|
from ..utils.logging import get_logger, set_correlation_id
|
||||||
from ..utils.timing import get_current_timestamp
|
from ..utils.timing import get_current_timestamp
|
||||||
|
except ImportError:
|
||||||
|
from utils.logging import get_logger, set_correlation_id
|
||||||
|
from utils.timing import get_current_timestamp
|
||||||
# Import will be done lazily to avoid circular imports
|
# Import will be done lazily to avoid circular imports
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
@ -11,8 +11,12 @@ from collections import defaultdict, deque
|
|||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
try:
|
||||||
from ..utils.logging import get_logger
|
from ..utils.logging import get_logger
|
||||||
from ..utils.timing import get_current_timestamp
|
from ..utils.timing import get_current_timestamp
|
||||||
|
except ImportError:
|
||||||
|
from utils.logging import get_logger
|
||||||
|
from utils.timing import get_current_timestamp
|
||||||
# Import will be done lazily to avoid circular imports
|
# Import will be done lazily to avoid circular imports
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
@ -10,8 +10,12 @@ from collections import defaultdict, deque
|
|||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
try:
|
||||||
from ..utils.logging import get_logger
|
from ..utils.logging import get_logger
|
||||||
from ..utils.timing import get_current_timestamp
|
from ..utils.timing import get_current_timestamp
|
||||||
|
except ImportError:
|
||||||
|
from utils.logging import get_logger
|
||||||
|
from utils.timing import get_current_timestamp
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
@ -10,9 +10,14 @@ from collections import defaultdict, deque
|
|||||||
from datetime import datetime, timezone, timedelta
|
from datetime import datetime, timezone, timedelta
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
try:
|
||||||
from ..utils.logging import get_logger
|
from ..utils.logging import get_logger
|
||||||
from ..utils.timing import get_current_timestamp
|
from ..utils.timing import get_current_timestamp
|
||||||
from .metrics_collector import MetricsCollector
|
from .metrics_collector import MetricsCollector
|
||||||
|
except ImportError:
|
||||||
|
from utils.logging import get_logger
|
||||||
|
from utils.timing import get_current_timestamp
|
||||||
|
from monitoring.metrics_collector import MetricsCollector
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
45
COBY/simple_config.py
Normal file
45
COBY/simple_config.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
"""
|
||||||
|
Simple configuration for COBY system.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class APIConfig:
|
||||||
|
"""API configuration"""
|
||||||
|
host: str = "0.0.0.0"
|
||||||
|
port: int = 8080
|
||||||
|
websocket_port: int = 8081
|
||||||
|
cors_origins: list = None
|
||||||
|
rate_limit: int = 100
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.cors_origins is None:
|
||||||
|
self.cors_origins = ["*"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LoggingConfig:
|
||||||
|
"""Logging configuration"""
|
||||||
|
level: str = "INFO"
|
||||||
|
file_path: str = "logs/coby.log"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Config:
|
||||||
|
"""Main configuration"""
|
||||||
|
api: APIConfig = None
|
||||||
|
logging: LoggingConfig = None
|
||||||
|
debug: bool = False
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.api is None:
|
||||||
|
self.api = APIConfig()
|
||||||
|
if self.logging is None:
|
||||||
|
self.logging = LoggingConfig()
|
||||||
|
|
||||||
|
|
||||||
|
# Global config instance
|
||||||
|
config = Config()
|
485
COBY/tests/test_e2e_dashboard.py
Normal file
485
COBY/tests/test_e2e_dashboard.py
Normal file
@ -0,0 +1,485 @@
|
|||||||
|
"""
|
||||||
|
End-to-end tests for web dashboard functionality.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import Mock, AsyncMock, patch
|
||||||
|
from typing import Dict, Any, List
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from aiohttp import web, WSMsgType
|
||||||
|
from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop
|
||||||
|
|
||||||
|
from ..api.rest_api import create_app
|
||||||
|
from ..api.websocket_server import WebSocketServer
|
||||||
|
from ..models.core import OrderBookSnapshot, TradeEvent, PriceLevel
|
||||||
|
from ..utils.logging import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDashboardAPI(AioHTTPTestCase):
|
||||||
|
"""Test dashboard REST API endpoints"""
|
||||||
|
|
||||||
|
async def get_application(self):
|
||||||
|
"""Create test application"""
|
||||||
|
return create_app()
|
||||||
|
|
||||||
|
@unittest_run_loop
|
||||||
|
async def test_health_endpoint(self):
|
||||||
|
"""Test health check endpoint"""
|
||||||
|
resp = await self.client.request("GET", "/health")
|
||||||
|
self.assertEqual(resp.status, 200)
|
||||||
|
|
||||||
|
data = await resp.json()
|
||||||
|
self.assertIn('status', data)
|
||||||
|
self.assertIn('timestamp', data)
|
||||||
|
self.assertEqual(data['status'], 'healthy')
|
||||||
|
|
||||||
|
@unittest_run_loop
|
||||||
|
async def test_metrics_endpoint(self):
|
||||||
|
"""Test metrics endpoint"""
|
||||||
|
resp = await self.client.request("GET", "/metrics")
|
||||||
|
self.assertEqual(resp.status, 200)
|
||||||
|
|
||||||
|
# Should return Prometheus format
|
||||||
|
text = await resp.text()
|
||||||
|
self.assertIn('# TYPE', text)
|
||||||
|
|
||||||
|
@unittest_run_loop
|
||||||
|
async def test_orderbook_endpoint(self):
|
||||||
|
"""Test order book data endpoint"""
|
||||||
|
# Mock data
|
||||||
|
with patch('COBY.caching.redis_manager.redis_manager') as mock_redis:
|
||||||
|
mock_redis.get.return_value = {
|
||||||
|
'symbol': 'BTCUSDT',
|
||||||
|
'exchange': 'binance',
|
||||||
|
'bids': [{'price': 50000.0, 'size': 1.0}],
|
||||||
|
'asks': [{'price': 50010.0, 'size': 1.0}]
|
||||||
|
}
|
||||||
|
|
||||||
|
resp = await self.client.request("GET", "/api/orderbook/BTCUSDT")
|
||||||
|
self.assertEqual(resp.status, 200)
|
||||||
|
|
||||||
|
data = await resp.json()
|
||||||
|
self.assertIn('symbol', data)
|
||||||
|
self.assertEqual(data['symbol'], 'BTCUSDT')
|
||||||
|
|
||||||
|
@unittest_run_loop
|
||||||
|
async def test_heatmap_endpoint(self):
|
||||||
|
"""Test heatmap data endpoint"""
|
||||||
|
with patch('COBY.caching.redis_manager.redis_manager') as mock_redis:
|
||||||
|
mock_redis.get.return_value = {
|
||||||
|
'symbol': 'BTCUSDT',
|
||||||
|
'bucket_size': 1.0,
|
||||||
|
'data': [
|
||||||
|
{'price': 50000.0, 'volume': 10.0, 'intensity': 0.8, 'side': 'bid'}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
resp = await self.client.request("GET", "/api/heatmap/BTCUSDT")
|
||||||
|
self.assertEqual(resp.status, 200)
|
||||||
|
|
||||||
|
data = await resp.json()
|
||||||
|
self.assertIn('symbol', data)
|
||||||
|
self.assertIn('data', data)
|
||||||
|
|
||||||
|
@unittest_run_loop
|
||||||
|
async def test_exchanges_status_endpoint(self):
|
||||||
|
"""Test exchanges status endpoint"""
|
||||||
|
with patch('COBY.connectors.connection_manager.connection_manager') as mock_manager:
|
||||||
|
mock_manager.get_all_statuses.return_value = {
|
||||||
|
'binance': 'connected',
|
||||||
|
'coinbase': 'connected',
|
||||||
|
'kraken': 'disconnected'
|
||||||
|
}
|
||||||
|
|
||||||
|
resp = await self.client.request("GET", "/api/exchanges/status")
|
||||||
|
self.assertEqual(resp.status, 200)
|
||||||
|
|
||||||
|
data = await resp.json()
|
||||||
|
self.assertIn('binance', data)
|
||||||
|
self.assertIn('coinbase', data)
|
||||||
|
self.assertIn('kraken', data)
|
||||||
|
|
||||||
|
@unittest_run_loop
|
||||||
|
async def test_performance_metrics_endpoint(self):
|
||||||
|
"""Test performance metrics endpoint"""
|
||||||
|
with patch('COBY.monitoring.performance_monitor.get_performance_monitor') as mock_monitor:
|
||||||
|
mock_monitor.return_value.get_performance_dashboard_data.return_value = {
|
||||||
|
'timestamp': datetime.now(timezone.utc).isoformat(),
|
||||||
|
'system_metrics': {
|
||||||
|
'cpu_usage': 45.2,
|
||||||
|
'memory_usage': 67.8,
|
||||||
|
'active_connections': 150
|
||||||
|
},
|
||||||
|
'performance_summary': {
|
||||||
|
'throughput': 1250.5,
|
||||||
|
'error_rate': 0.1,
|
||||||
|
'avg_latency': 12.3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resp = await self.client.request("GET", "/api/performance")
|
||||||
|
self.assertEqual(resp.status, 200)
|
||||||
|
|
||||||
|
data = await resp.json()
|
||||||
|
self.assertIn('system_metrics', data)
|
||||||
|
self.assertIn('performance_summary', data)
|
||||||
|
|
||||||
|
@unittest_run_loop
|
||||||
|
async def test_static_files_served(self):
|
||||||
|
"""Test that static files are served correctly"""
|
||||||
|
# Test dashboard index
|
||||||
|
resp = await self.client.request("GET", "/")
|
||||||
|
self.assertEqual(resp.status, 200)
|
||||||
|
|
||||||
|
content_type = resp.headers.get('content-type', '')
|
||||||
|
self.assertIn('text/html', content_type)
|
||||||
|
|
||||||
|
@unittest_run_loop
|
||||||
|
async def test_cors_headers(self):
|
||||||
|
"""Test CORS headers are present"""
|
||||||
|
resp = await self.client.request("OPTIONS", "/api/health")
|
||||||
|
self.assertEqual(resp.status, 200)
|
||||||
|
|
||||||
|
# Check CORS headers
|
||||||
|
self.assertIn('Access-Control-Allow-Origin', resp.headers)
|
||||||
|
self.assertIn('Access-Control-Allow-Methods', resp.headers)
|
||||||
|
|
||||||
|
@unittest_run_loop
|
||||||
|
async def test_rate_limiting(self):
|
||||||
|
"""Test API rate limiting"""
|
||||||
|
# Make many requests quickly
|
||||||
|
responses = []
|
||||||
|
for i in range(150): # Exceed rate limit
|
||||||
|
resp = await self.client.request("GET", "/api/health")
|
||||||
|
responses.append(resp.status)
|
||||||
|
|
||||||
|
# Should have some rate limited responses
|
||||||
|
rate_limited = [status for status in responses if status == 429]
|
||||||
|
self.assertGreater(len(rate_limited), 0, "Rate limiting not working")
|
||||||
|
|
||||||
|
@unittest_run_loop
|
||||||
|
async def test_error_handling(self):
|
||||||
|
"""Test API error handling"""
|
||||||
|
# Test invalid symbol
|
||||||
|
resp = await self.client.request("GET", "/api/orderbook/INVALID")
|
||||||
|
self.assertEqual(resp.status, 404)
|
||||||
|
|
||||||
|
data = await resp.json()
|
||||||
|
self.assertIn('error', data)
|
||||||
|
|
||||||
|
@unittest_run_loop
|
||||||
|
async def test_api_documentation(self):
|
||||||
|
"""Test API documentation endpoints"""
|
||||||
|
# Test OpenAPI docs
|
||||||
|
resp = await self.client.request("GET", "/docs")
|
||||||
|
self.assertEqual(resp.status, 200)
|
||||||
|
|
||||||
|
# Test ReDoc
|
||||||
|
resp = await self.client.request("GET", "/redoc")
|
||||||
|
self.assertEqual(resp.status, 200)
|
||||||
|
|
||||||
|
|
||||||
|
class TestWebSocketFunctionality:
|
||||||
|
"""Test WebSocket functionality"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def websocket_server(self):
|
||||||
|
"""Create WebSocket server for testing"""
|
||||||
|
server = WebSocketServer(host='localhost', port=8081)
|
||||||
|
await server.start()
|
||||||
|
yield server
|
||||||
|
await server.stop()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_websocket_connection(self, websocket_server):
|
||||||
|
"""Test WebSocket connection establishment"""
|
||||||
|
session = aiohttp.ClientSession()
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with session.ws_connect('ws://localhost:8081/ws/dashboard') as ws:
|
||||||
|
# Connection should be established
|
||||||
|
self.assertEqual(ws.closed, False)
|
||||||
|
|
||||||
|
# Send ping
|
||||||
|
await ws.ping()
|
||||||
|
|
||||||
|
# Should receive pong
|
||||||
|
msg = await ws.receive()
|
||||||
|
self.assertEqual(msg.type, WSMsgType.PONG)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_websocket_data_streaming(self, websocket_server):
|
||||||
|
"""Test real-time data streaming via WebSocket"""
|
||||||
|
session = aiohttp.ClientSession()
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with session.ws_connect('ws://localhost:8081/ws/dashboard') as ws:
|
||||||
|
# Subscribe to updates
|
||||||
|
subscribe_msg = {
|
||||||
|
'type': 'subscribe',
|
||||||
|
'channels': ['orderbook', 'trades', 'performance']
|
||||||
|
}
|
||||||
|
await ws.send_str(json.dumps(subscribe_msg))
|
||||||
|
|
||||||
|
# Should receive subscription confirmation
|
||||||
|
msg = await ws.receive()
|
||||||
|
self.assertEqual(msg.type, WSMsgType.TEXT)
|
||||||
|
|
||||||
|
data = json.loads(msg.data)
|
||||||
|
self.assertEqual(data.get('type'), 'subscription_confirmed')
|
||||||
|
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_websocket_error_handling(self, websocket_server):
|
||||||
|
"""Test WebSocket error handling"""
|
||||||
|
session = aiohttp.ClientSession()
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with session.ws_connect('ws://localhost:8081/ws/dashboard') as ws:
|
||||||
|
# Send invalid message
|
||||||
|
invalid_msg = {'invalid': 'message'}
|
||||||
|
await ws.send_str(json.dumps(invalid_msg))
|
||||||
|
|
||||||
|
# Should receive error response
|
||||||
|
msg = await ws.receive()
|
||||||
|
self.assertEqual(msg.type, WSMsgType.TEXT)
|
||||||
|
|
||||||
|
data = json.loads(msg.data)
|
||||||
|
self.assertEqual(data.get('type'), 'error')
|
||||||
|
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_multiple_websocket_connections(self, websocket_server):
|
||||||
|
"""Test multiple concurrent WebSocket connections"""
|
||||||
|
session = aiohttp.ClientSession()
|
||||||
|
connections = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create multiple connections
|
||||||
|
for i in range(10):
|
||||||
|
ws = await session.ws_connect(f'ws://localhost:8081/ws/dashboard')
|
||||||
|
connections.append(ws)
|
||||||
|
|
||||||
|
# All connections should be active
|
||||||
|
for ws in connections:
|
||||||
|
self.assertEqual(ws.closed, False)
|
||||||
|
|
||||||
|
# Send message to all connections
|
||||||
|
test_msg = {'type': 'ping', 'id': 'test'}
|
||||||
|
for ws in connections:
|
||||||
|
await ws.send_str(json.dumps(test_msg))
|
||||||
|
|
||||||
|
# All should receive responses
|
||||||
|
for ws in connections:
|
||||||
|
msg = await ws.receive()
|
||||||
|
self.assertEqual(msg.type, WSMsgType.TEXT)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Close all connections
|
||||||
|
for ws in connections:
|
||||||
|
if not ws.closed:
|
||||||
|
await ws.close()
|
||||||
|
await session.close()
|
||||||
|
|
||||||
|
|
||||||
|
class TestDashboardIntegration:
|
||||||
|
"""Test dashboard integration with backend services"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_services(self):
|
||||||
|
"""Mock backend services"""
|
||||||
|
services = {
|
||||||
|
'redis': Mock(),
|
||||||
|
'timescale': Mock(),
|
||||||
|
'connectors': Mock(),
|
||||||
|
'aggregator': Mock(),
|
||||||
|
'monitor': Mock()
|
||||||
|
}
|
||||||
|
|
||||||
|
# Setup mock responses
|
||||||
|
services['redis'].get.return_value = {'test': 'data'}
|
||||||
|
services['timescale'].query.return_value = [{'result': 'data'}]
|
||||||
|
services['connectors'].get_status.return_value = 'connected'
|
||||||
|
services['aggregator'].get_heatmap.return_value = {'heatmap': 'data'}
|
||||||
|
services['monitor'].get_metrics.return_value = {'metrics': 'data'}
|
||||||
|
|
||||||
|
return services
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dashboard_data_flow(self, mock_services):
|
||||||
|
"""Test complete data flow from backend to dashboard"""
|
||||||
|
# Simulate data generation
|
||||||
|
orderbook = OrderBookSnapshot(
|
||||||
|
symbol="BTCUSDT",
|
||||||
|
exchange="binance",
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
bids=[PriceLevel(price=50000.0, size=1.0)],
|
||||||
|
asks=[PriceLevel(price=50010.0, size=1.0)]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock data processing pipeline
|
||||||
|
with patch.multiple(
|
||||||
|
'COBY.processing.data_processor',
|
||||||
|
DataProcessor=Mock()
|
||||||
|
):
|
||||||
|
# Process data
|
||||||
|
processor = Mock()
|
||||||
|
processor.normalize_orderbook.return_value = orderbook
|
||||||
|
|
||||||
|
# Aggregate data
|
||||||
|
aggregator = Mock()
|
||||||
|
aggregator.create_price_buckets.return_value = Mock()
|
||||||
|
aggregator.generate_heatmap.return_value = Mock()
|
||||||
|
|
||||||
|
# Cache data
|
||||||
|
cache = Mock()
|
||||||
|
cache.set.return_value = True
|
||||||
|
|
||||||
|
# Verify data flows through pipeline
|
||||||
|
processed = processor.normalize_orderbook({}, "binance")
|
||||||
|
buckets = aggregator.create_price_buckets(processed)
|
||||||
|
heatmap = aggregator.generate_heatmap(buckets)
|
||||||
|
cached = cache.set("test_key", heatmap)
|
||||||
|
|
||||||
|
assert processed is not None
|
||||||
|
assert buckets is not None
|
||||||
|
assert heatmap is not None
|
||||||
|
assert cached is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_real_time_updates(self, mock_services):
|
||||||
|
"""Test real-time dashboard updates"""
|
||||||
|
# Mock WebSocket server
|
||||||
|
ws_server = Mock()
|
||||||
|
ws_server.broadcast = AsyncMock()
|
||||||
|
|
||||||
|
# Simulate real-time data updates
|
||||||
|
updates = [
|
||||||
|
{'type': 'orderbook', 'symbol': 'BTCUSDT', 'data': {}},
|
||||||
|
{'type': 'trade', 'symbol': 'BTCUSDT', 'data': {}},
|
||||||
|
{'type': 'performance', 'data': {}}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Send updates
|
||||||
|
for update in updates:
|
||||||
|
await ws_server.broadcast(json.dumps(update))
|
||||||
|
|
||||||
|
# Verify broadcasts were sent
|
||||||
|
assert ws_server.broadcast.call_count == len(updates)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dashboard_performance_under_load(self, mock_services):
|
||||||
|
"""Test dashboard performance under high update frequency"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
# Mock high-frequency updates
|
||||||
|
update_count = 1000
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Simulate processing many updates
|
||||||
|
for i in range(update_count):
|
||||||
|
# Mock data processing
|
||||||
|
mock_services['redis'].get(f"orderbook:BTCUSDT:binance:{i}")
|
||||||
|
mock_services['aggregator'].get_heatmap(f"BTCUSDT:{i}")
|
||||||
|
|
||||||
|
# Small delay to simulate processing
|
||||||
|
await asyncio.sleep(0.001)
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
processing_time = end_time - start_time
|
||||||
|
updates_per_second = update_count / processing_time
|
||||||
|
|
||||||
|
# Should handle at least 500 updates per second
|
||||||
|
assert updates_per_second > 500, f"Dashboard too slow: {updates_per_second:.2f} updates/sec"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dashboard_error_recovery(self, mock_services):
|
||||||
|
"""Test dashboard error recovery"""
|
||||||
|
# Simulate service failures
|
||||||
|
mock_services['redis'].get.side_effect = Exception("Redis connection failed")
|
||||||
|
mock_services['timescale'].query.side_effect = Exception("Database error")
|
||||||
|
|
||||||
|
# Dashboard should handle errors gracefully
|
||||||
|
try:
|
||||||
|
# Attempt operations that will fail
|
||||||
|
mock_services['redis'].get("test_key")
|
||||||
|
except Exception:
|
||||||
|
# Should recover and continue
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
mock_services['timescale'].query("SELECT * FROM test")
|
||||||
|
except Exception:
|
||||||
|
# Should recover and continue
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Reset services to working state
|
||||||
|
mock_services['redis'].get.side_effect = None
|
||||||
|
mock_services['redis'].get.return_value = {'recovered': True}
|
||||||
|
|
||||||
|
# Should work again
|
||||||
|
result = mock_services['redis'].get("test_key")
|
||||||
|
assert result['recovered'] is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestDashboardUI:
|
||||||
|
"""Test dashboard UI functionality (requires browser automation)"""
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not pytest.config.getoption("--ui"),
|
||||||
|
reason="UI tests require --ui flag and browser setup")
|
||||||
|
def test_dashboard_loads(self):
|
||||||
|
"""Test that dashboard loads in browser"""
|
||||||
|
# This would require Selenium or similar
|
||||||
|
# Placeholder for UI tests
|
||||||
|
pass
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not pytest.config.getoption("--ui"),
|
||||||
|
reason="UI tests require --ui flag and browser setup")
|
||||||
|
def test_real_time_chart_updates(self):
|
||||||
|
"""Test that charts update in real-time"""
|
||||||
|
# This would require browser automation
|
||||||
|
# Placeholder for UI tests
|
||||||
|
pass
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not pytest.config.getoption("--ui"),
|
||||||
|
reason="UI tests require --ui flag and browser setup")
|
||||||
|
def test_responsive_design(self):
|
||||||
|
"""Test responsive design on different screen sizes"""
|
||||||
|
# This would require browser automation with different viewport sizes
|
||||||
|
# Placeholder for UI tests
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_configure(config):
|
||||||
|
"""Configure pytest with custom markers"""
|
||||||
|
config.addinivalue_line("markers", "e2e: mark test as end-to-end test")
|
||||||
|
config.addinivalue_line("markers", "ui: mark test as UI test")
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_addoption(parser):
|
||||||
|
"""Add custom command line options"""
|
||||||
|
parser.addoption(
|
||||||
|
"--e2e",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="run end-to-end tests"
|
||||||
|
)
|
||||||
|
parser.addoption(
|
||||||
|
"--ui",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="run UI tests (requires browser setup)"
|
||||||
|
)
|
485
COBY/tests/test_integration_pipeline.py
Normal file
485
COBY/tests/test_integration_pipeline.py
Normal file
@ -0,0 +1,485 @@
|
|||||||
|
"""
|
||||||
|
Integration tests for complete data pipeline from exchanges to storage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import Mock, AsyncMock, patch
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
|
||||||
|
from ..connectors.binance_connector import BinanceConnector
|
||||||
|
from ..processing.data_processor import DataProcessor
|
||||||
|
from ..aggregation.aggregation_engine import AggregationEngine
|
||||||
|
from ..storage.timescale_manager import TimescaleManager
|
||||||
|
from ..caching.redis_manager import RedisManager
|
||||||
|
from ..models.core import OrderBookSnapshot, TradeEvent, PriceLevel
|
||||||
|
from ..utils.logging import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDataPipelineIntegration:
|
||||||
|
"""Test complete data pipeline integration"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mock_components(self):
|
||||||
|
"""Setup mock components for testing"""
|
||||||
|
# Mock exchange connector
|
||||||
|
connector = Mock(spec=BinanceConnector)
|
||||||
|
connector.exchange_name = "binance"
|
||||||
|
connector.connect = AsyncMock(return_value=True)
|
||||||
|
connector.disconnect = AsyncMock()
|
||||||
|
connector.subscribe_orderbook = AsyncMock()
|
||||||
|
connector.subscribe_trades = AsyncMock()
|
||||||
|
|
||||||
|
# Mock data processor
|
||||||
|
processor = Mock(spec=DataProcessor)
|
||||||
|
processor.process_orderbook = Mock()
|
||||||
|
processor.process_trade = Mock()
|
||||||
|
processor.validate_data = Mock(return_value=True)
|
||||||
|
|
||||||
|
# Mock aggregation engine
|
||||||
|
aggregator = Mock(spec=AggregationEngine)
|
||||||
|
aggregator.aggregate_orderbook = Mock()
|
||||||
|
aggregator.create_heatmap = Mock()
|
||||||
|
|
||||||
|
# Mock storage manager
|
||||||
|
storage = Mock(spec=TimescaleManager)
|
||||||
|
storage.store_orderbook = AsyncMock(return_value=True)
|
||||||
|
storage.store_trade = AsyncMock(return_value=True)
|
||||||
|
storage.is_connected = Mock(return_value=True)
|
||||||
|
|
||||||
|
# Mock cache manager
|
||||||
|
cache = Mock(spec=RedisManager)
|
||||||
|
cache.set = AsyncMock(return_value=True)
|
||||||
|
cache.get = AsyncMock(return_value=None)
|
||||||
|
cache.is_connected = Mock(return_value=True)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'connector': connector,
|
||||||
|
'processor': processor,
|
||||||
|
'aggregator': aggregator,
|
||||||
|
'storage': storage,
|
||||||
|
'cache': cache
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_orderbook(self):
|
||||||
|
"""Create sample order book data"""
|
||||||
|
return OrderBookSnapshot(
|
||||||
|
symbol="BTCUSDT",
|
||||||
|
exchange="binance",
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
bids=[
|
||||||
|
PriceLevel(price=50000.0, size=1.5),
|
||||||
|
PriceLevel(price=49990.0, size=2.0),
|
||||||
|
PriceLevel(price=49980.0, size=1.0)
|
||||||
|
],
|
||||||
|
asks=[
|
||||||
|
PriceLevel(price=50010.0, size=1.2),
|
||||||
|
PriceLevel(price=50020.0, size=1.8),
|
||||||
|
PriceLevel(price=50030.0, size=0.8)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_trade(self):
|
||||||
|
"""Create sample trade data"""
|
||||||
|
return TradeEvent(
|
||||||
|
symbol="BTCUSDT",
|
||||||
|
exchange="binance",
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
price=50005.0,
|
||||||
|
size=0.5,
|
||||||
|
side="buy",
|
||||||
|
trade_id="12345"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_complete_orderbook_pipeline(self, mock_components, sample_orderbook):
|
||||||
|
"""Test complete order book processing pipeline"""
|
||||||
|
components = mock_components
|
||||||
|
|
||||||
|
# Setup processor to return processed data
|
||||||
|
components['processor'].process_orderbook.return_value = sample_orderbook
|
||||||
|
|
||||||
|
# Simulate pipeline flow
|
||||||
|
# 1. Receive data from exchange
|
||||||
|
raw_data = {"symbol": "BTCUSDT", "bids": [], "asks": []}
|
||||||
|
|
||||||
|
# 2. Process data
|
||||||
|
processed_data = components['processor'].process_orderbook(raw_data, "binance")
|
||||||
|
|
||||||
|
# 3. Validate data
|
||||||
|
is_valid = components['processor'].validate_data(processed_data)
|
||||||
|
assert is_valid
|
||||||
|
|
||||||
|
# 4. Aggregate data
|
||||||
|
components['aggregator'].aggregate_orderbook(processed_data)
|
||||||
|
|
||||||
|
# 5. Store in database
|
||||||
|
await components['storage'].store_orderbook(processed_data)
|
||||||
|
|
||||||
|
# 6. Cache latest data
|
||||||
|
await components['cache'].set(f"orderbook:BTCUSDT:binance", processed_data)
|
||||||
|
|
||||||
|
# Verify all components were called
|
||||||
|
components['processor'].process_orderbook.assert_called_once()
|
||||||
|
components['processor'].validate_data.assert_called_once()
|
||||||
|
components['aggregator'].aggregate_orderbook.assert_called_once()
|
||||||
|
components['storage'].store_orderbook.assert_called_once()
|
||||||
|
components['cache'].set.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_complete_trade_pipeline(self, mock_components, sample_trade):
|
||||||
|
"""Test complete trade processing pipeline"""
|
||||||
|
components = mock_components
|
||||||
|
|
||||||
|
# Setup processor to return processed data
|
||||||
|
components['processor'].process_trade.return_value = sample_trade
|
||||||
|
|
||||||
|
# Simulate pipeline flow
|
||||||
|
raw_data = {"symbol": "BTCUSDT", "price": 50005.0, "quantity": 0.5}
|
||||||
|
|
||||||
|
# Process through pipeline
|
||||||
|
processed_data = components['processor'].process_trade(raw_data, "binance")
|
||||||
|
is_valid = components['processor'].validate_data(processed_data)
|
||||||
|
assert is_valid
|
||||||
|
|
||||||
|
await components['storage'].store_trade(processed_data)
|
||||||
|
await components['cache'].set(f"trade:BTCUSDT:binance:latest", processed_data)
|
||||||
|
|
||||||
|
# Verify calls
|
||||||
|
components['processor'].process_trade.assert_called_once()
|
||||||
|
components['storage'].store_trade.assert_called_once()
|
||||||
|
components['cache'].set.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_multi_exchange_pipeline(self, mock_components):
|
||||||
|
"""Test pipeline with multiple exchanges"""
|
||||||
|
components = mock_components
|
||||||
|
exchanges = ["binance", "coinbase", "kraken"]
|
||||||
|
|
||||||
|
# Simulate data from multiple exchanges
|
||||||
|
for exchange in exchanges:
|
||||||
|
# Create exchange-specific data
|
||||||
|
orderbook = OrderBookSnapshot(
|
||||||
|
symbol="BTCUSDT",
|
||||||
|
exchange=exchange,
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
bids=[PriceLevel(price=50000.0, size=1.0)],
|
||||||
|
asks=[PriceLevel(price=50010.0, size=1.0)]
|
||||||
|
)
|
||||||
|
|
||||||
|
components['processor'].process_orderbook.return_value = orderbook
|
||||||
|
components['processor'].validate_data.return_value = True
|
||||||
|
|
||||||
|
# Process through pipeline
|
||||||
|
processed_data = components['processor'].process_orderbook({}, exchange)
|
||||||
|
is_valid = components['processor'].validate_data(processed_data)
|
||||||
|
assert is_valid
|
||||||
|
|
||||||
|
await components['storage'].store_orderbook(processed_data)
|
||||||
|
await components['cache'].set(f"orderbook:BTCUSDT:{exchange}", processed_data)
|
||||||
|
|
||||||
|
# Verify multiple calls
|
||||||
|
assert components['processor'].process_orderbook.call_count == len(exchanges)
|
||||||
|
assert components['storage'].store_orderbook.call_count == len(exchanges)
|
||||||
|
assert components['cache'].set.call_count == len(exchanges)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pipeline_error_handling(self, mock_components, sample_orderbook):
|
||||||
|
"""Test pipeline error handling and recovery"""
|
||||||
|
components = mock_components
|
||||||
|
|
||||||
|
# Setup storage to fail initially
|
||||||
|
components['storage'].store_orderbook.side_effect = [
|
||||||
|
Exception("Database connection failed"),
|
||||||
|
True # Success on retry
|
||||||
|
]
|
||||||
|
|
||||||
|
components['processor'].process_orderbook.return_value = sample_orderbook
|
||||||
|
components['processor'].validate_data.return_value = True
|
||||||
|
|
||||||
|
# First attempt should fail
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
await components['storage'].store_orderbook(sample_orderbook)
|
||||||
|
|
||||||
|
# Second attempt should succeed
|
||||||
|
result = await components['storage'].store_orderbook(sample_orderbook)
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
# Verify retry logic
|
||||||
|
assert components['storage'].store_orderbook.call_count == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pipeline_performance(self, mock_components):
|
||||||
|
"""Test pipeline performance with high throughput"""
|
||||||
|
components = mock_components
|
||||||
|
|
||||||
|
# Setup fast responses
|
||||||
|
components['processor'].process_orderbook.return_value = Mock()
|
||||||
|
components['processor'].validate_data.return_value = True
|
||||||
|
components['storage'].store_orderbook.return_value = True
|
||||||
|
components['cache'].set.return_value = True
|
||||||
|
|
||||||
|
# Process multiple items quickly
|
||||||
|
start_time = time.time()
|
||||||
|
tasks = []
|
||||||
|
|
||||||
|
for i in range(100):
|
||||||
|
# Simulate processing 100 order books
|
||||||
|
task = asyncio.create_task(self._process_single_orderbook(components, i))
|
||||||
|
tasks.append(task)
|
||||||
|
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
end_time = time.time()
|
||||||
|
|
||||||
|
processing_time = end_time - start_time
|
||||||
|
throughput = 100 / processing_time
|
||||||
|
|
||||||
|
# Should process at least 50 items per second
|
||||||
|
assert throughput > 50, f"Throughput too low: {throughput:.2f} items/sec"
|
||||||
|
|
||||||
|
# Verify all items were processed
|
||||||
|
assert components['processor'].process_orderbook.call_count == 100
|
||||||
|
assert components['storage'].store_orderbook.call_count == 100
|
||||||
|
|
||||||
|
async def _process_single_orderbook(self, components, index):
|
||||||
|
"""Helper method to process a single order book"""
|
||||||
|
raw_data = {"symbol": "BTCUSDT", "index": index}
|
||||||
|
|
||||||
|
processed_data = components['processor'].process_orderbook(raw_data, "binance")
|
||||||
|
is_valid = components['processor'].validate_data(processed_data)
|
||||||
|
|
||||||
|
if is_valid:
|
||||||
|
await components['storage'].store_orderbook(processed_data)
|
||||||
|
await components['cache'].set(f"orderbook:BTCUSDT:binance:{index}", processed_data)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_data_consistency_across_pipeline(self, mock_components, sample_orderbook):
|
||||||
|
"""Test data consistency throughout the pipeline"""
|
||||||
|
components = mock_components
|
||||||
|
|
||||||
|
# Track data transformations
|
||||||
|
original_data = {"symbol": "BTCUSDT", "timestamp": "2024-01-01T00:00:00Z"}
|
||||||
|
|
||||||
|
# Setup processor to modify data
|
||||||
|
modified_orderbook = sample_orderbook
|
||||||
|
modified_orderbook.symbol = "BTCUSDT" # Ensure consistency
|
||||||
|
components['processor'].process_orderbook.return_value = modified_orderbook
|
||||||
|
components['processor'].validate_data.return_value = True
|
||||||
|
|
||||||
|
# Process data
|
||||||
|
processed_data = components['processor'].process_orderbook(original_data, "binance")
|
||||||
|
|
||||||
|
# Verify data consistency
|
||||||
|
assert processed_data.symbol == "BTCUSDT"
|
||||||
|
assert processed_data.exchange == "binance"
|
||||||
|
assert len(processed_data.bids) > 0
|
||||||
|
assert len(processed_data.asks) > 0
|
||||||
|
|
||||||
|
# Verify all price levels are valid
|
||||||
|
for bid in processed_data.bids:
|
||||||
|
assert bid.price > 0
|
||||||
|
assert bid.size > 0
|
||||||
|
|
||||||
|
for ask in processed_data.asks:
|
||||||
|
assert ask.price > 0
|
||||||
|
assert ask.size > 0
|
||||||
|
|
||||||
|
# Verify bid/ask ordering
|
||||||
|
bid_prices = [bid.price for bid in processed_data.bids]
|
||||||
|
ask_prices = [ask.price for ask in processed_data.asks]
|
||||||
|
|
||||||
|
assert bid_prices == sorted(bid_prices, reverse=True) # Bids descending
|
||||||
|
assert ask_prices == sorted(ask_prices) # Asks ascending
|
||||||
|
|
||||||
|
# Verify spread is positive
|
||||||
|
if bid_prices and ask_prices:
|
||||||
|
spread = min(ask_prices) - max(bid_prices)
|
||||||
|
assert spread >= 0, f"Negative spread detected: {spread}"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pipeline_memory_usage(self, mock_components):
|
||||||
|
"""Test pipeline memory usage under load"""
|
||||||
|
import psutil
|
||||||
|
import gc
|
||||||
|
|
||||||
|
components = mock_components
|
||||||
|
process = psutil.Process()
|
||||||
|
|
||||||
|
# Get initial memory usage
|
||||||
|
initial_memory = process.memory_info().rss / 1024 / 1024 # MB
|
||||||
|
|
||||||
|
# Process large amount of data
|
||||||
|
for i in range(1000):
|
||||||
|
orderbook = OrderBookSnapshot(
|
||||||
|
symbol="BTCUSDT",
|
||||||
|
exchange="binance",
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
bids=[PriceLevel(price=50000.0 + i, size=1.0)],
|
||||||
|
asks=[PriceLevel(price=50010.0 + i, size=1.0)]
|
||||||
|
)
|
||||||
|
|
||||||
|
components['processor'].process_orderbook.return_value = orderbook
|
||||||
|
components['processor'].validate_data.return_value = True
|
||||||
|
|
||||||
|
# Process data
|
||||||
|
processed_data = components['processor'].process_orderbook({}, "binance")
|
||||||
|
await components['storage'].store_orderbook(processed_data)
|
||||||
|
|
||||||
|
# Force garbage collection every 100 items
|
||||||
|
if i % 100 == 0:
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
# Get final memory usage
|
||||||
|
final_memory = process.memory_info().rss / 1024 / 1024 # MB
|
||||||
|
memory_increase = final_memory - initial_memory
|
||||||
|
|
||||||
|
# Memory increase should be reasonable (less than 100MB for 1000 items)
|
||||||
|
assert memory_increase < 100, f"Memory usage increased by {memory_increase:.2f}MB"
|
||||||
|
|
||||||
|
logger.info(f"Memory usage: {initial_memory:.2f}MB -> {final_memory:.2f}MB (+{memory_increase:.2f}MB)")
|
||||||
|
|
||||||
|
|
||||||
|
class TestPipelineResilience:
|
||||||
|
"""Test pipeline resilience and fault tolerance"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_database_reconnection(self):
|
||||||
|
"""Test database reconnection handling"""
|
||||||
|
storage = Mock(spec=TimescaleManager)
|
||||||
|
|
||||||
|
# Simulate connection failure then recovery
|
||||||
|
storage.is_connected.side_effect = [False, False, True]
|
||||||
|
storage.connect.return_value = True
|
||||||
|
storage.store_orderbook.return_value = True
|
||||||
|
|
||||||
|
# Should attempt reconnection
|
||||||
|
for attempt in range(3):
|
||||||
|
if not storage.is_connected():
|
||||||
|
storage.connect()
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
assert storage.connect.call_count == 1
|
||||||
|
assert storage.is_connected.call_count == 3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_fallback(self):
|
||||||
|
"""Test cache fallback when Redis is unavailable"""
|
||||||
|
cache = Mock(spec=RedisManager)
|
||||||
|
|
||||||
|
# Simulate cache failure
|
||||||
|
cache.is_connected.return_value = False
|
||||||
|
cache.set.side_effect = Exception("Redis connection failed")
|
||||||
|
|
||||||
|
# Should handle cache failure gracefully
|
||||||
|
try:
|
||||||
|
await cache.set("test_key", "test_value")
|
||||||
|
except Exception:
|
||||||
|
# Should continue processing even if cache fails
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert not cache.is_connected()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exchange_failover(self):
|
||||||
|
"""Test exchange failover when one exchange fails"""
|
||||||
|
exchanges = ["binance", "coinbase", "kraken"]
|
||||||
|
failed_exchange = "binance"
|
||||||
|
|
||||||
|
# Simulate one exchange failing
|
||||||
|
for exchange in exchanges:
|
||||||
|
if exchange == failed_exchange:
|
||||||
|
# This exchange fails
|
||||||
|
assert exchange == failed_exchange
|
||||||
|
else:
|
||||||
|
# Other exchanges continue working
|
||||||
|
assert exchange != failed_exchange
|
||||||
|
|
||||||
|
# Should continue with remaining exchanges
|
||||||
|
working_exchanges = [ex for ex in exchanges if ex != failed_exchange]
|
||||||
|
assert len(working_exchanges) == 2
|
||||||
|
assert "coinbase" in working_exchanges
|
||||||
|
assert "kraken" in working_exchanges
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
class TestRealDataPipeline:
|
||||||
|
"""Integration tests with real components (requires running services)"""
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not pytest.config.getoption("--integration"),
|
||||||
|
reason="Integration tests require --integration flag")
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_real_database_integration(self):
|
||||||
|
"""Test with real TimescaleDB instance"""
|
||||||
|
# This test requires a running TimescaleDB instance
|
||||||
|
# Skip if not available
|
||||||
|
try:
|
||||||
|
from ..storage.timescale_manager import TimescaleManager
|
||||||
|
|
||||||
|
storage = TimescaleManager()
|
||||||
|
await storage.connect()
|
||||||
|
|
||||||
|
# Test basic operations
|
||||||
|
assert storage.is_connected()
|
||||||
|
|
||||||
|
# Create test data
|
||||||
|
orderbook = OrderBookSnapshot(
|
||||||
|
symbol="BTCUSDT",
|
||||||
|
exchange="test",
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
bids=[PriceLevel(price=50000.0, size=1.0)],
|
||||||
|
asks=[PriceLevel(price=50010.0, size=1.0)]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store and verify
|
||||||
|
result = await storage.store_orderbook(orderbook)
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
await storage.disconnect()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
pytest.skip(f"Real database not available: {e}")
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not pytest.config.getoption("--integration"),
|
||||||
|
reason="Integration tests require --integration flag")
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_real_cache_integration(self):
|
||||||
|
"""Test with real Redis instance"""
|
||||||
|
try:
|
||||||
|
from ..caching.redis_manager import RedisManager
|
||||||
|
|
||||||
|
cache = RedisManager()
|
||||||
|
await cache.connect()
|
||||||
|
|
||||||
|
assert cache.is_connected()
|
||||||
|
|
||||||
|
# Test basic operations
|
||||||
|
await cache.set("test_key", {"test": "data"})
|
||||||
|
result = await cache.get("test_key")
|
||||||
|
assert result is not None
|
||||||
|
|
||||||
|
await cache.disconnect()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
pytest.skip(f"Real cache not available: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_configure(config):
|
||||||
|
"""Configure pytest with custom markers"""
|
||||||
|
config.addinivalue_line("markers", "integration: mark test as integration test")
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_addoption(parser):
|
||||||
|
"""Add custom command line options"""
|
||||||
|
parser.addoption(
|
||||||
|
"--integration",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="run integration tests with real services"
|
||||||
|
)
|
590
COBY/tests/test_load_performance.py
Normal file
590
COBY/tests/test_load_performance.py
Normal file
@ -0,0 +1,590 @@
|
|||||||
|
"""
|
||||||
|
Load testing and performance benchmarks for high-frequency data scenarios.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
import statistics
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
import psutil
|
||||||
|
import gc
|
||||||
|
|
||||||
|
from ..models.core import OrderBookSnapshot, TradeEvent, PriceLevel
|
||||||
|
from ..connectors.binance_connector import BinanceConnector
|
||||||
|
from ..processing.data_processor import DataProcessor
|
||||||
|
from ..aggregation.aggregation_engine import AggregationEngine
|
||||||
|
from ..monitoring.metrics_collector import MetricsCollector
|
||||||
|
from ..monitoring.latency_tracker import LatencyTracker
|
||||||
|
from ..utils.logging import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LoadTestConfig:
|
||||||
|
"""Configuration for load tests"""
|
||||||
|
|
||||||
|
# Test parameters
|
||||||
|
DURATION_SECONDS = 60
|
||||||
|
TARGET_TPS = 1000 # Transactions per second
|
||||||
|
RAMP_UP_SECONDS = 10
|
||||||
|
|
||||||
|
# Performance thresholds
|
||||||
|
MAX_LATENCY_MS = 100
|
||||||
|
MAX_MEMORY_MB = 500
|
||||||
|
MIN_SUCCESS_RATE = 99.0
|
||||||
|
|
||||||
|
# Data generation
|
||||||
|
SYMBOLS = ["BTCUSDT", "ETHUSDT", "ADAUSDT", "DOTUSDT"]
|
||||||
|
EXCHANGES = ["binance", "coinbase", "kraken", "bybit"]
|
||||||
|
|
||||||
|
|
||||||
|
class DataGenerator:
|
||||||
|
"""Generate realistic test data for load testing"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.base_prices = {
|
||||||
|
"BTCUSDT": 50000.0,
|
||||||
|
"ETHUSDT": 3000.0,
|
||||||
|
"ADAUSDT": 1.0,
|
||||||
|
"DOTUSDT": 25.0
|
||||||
|
}
|
||||||
|
self.counter = 0
|
||||||
|
|
||||||
|
def generate_orderbook(self, symbol: str, exchange: str) -> OrderBookSnapshot:
|
||||||
|
"""Generate realistic order book data"""
|
||||||
|
base_price = self.base_prices.get(symbol, 100.0)
|
||||||
|
|
||||||
|
# Add some randomness
|
||||||
|
price_variation = (self.counter % 100) * 0.01
|
||||||
|
mid_price = base_price + price_variation
|
||||||
|
|
||||||
|
# Generate bids (below mid price)
|
||||||
|
bids = []
|
||||||
|
for i in range(10):
|
||||||
|
price = mid_price - (i + 1) * 0.1
|
||||||
|
size = 1.0 + (i * 0.1)
|
||||||
|
bids.append(PriceLevel(price=price, size=size))
|
||||||
|
|
||||||
|
# Generate asks (above mid price)
|
||||||
|
asks = []
|
||||||
|
for i in range(10):
|
||||||
|
price = mid_price + (i + 1) * 0.1
|
||||||
|
size = 1.0 + (i * 0.1)
|
||||||
|
asks.append(PriceLevel(price=price, size=size))
|
||||||
|
|
||||||
|
self.counter += 1
|
||||||
|
|
||||||
|
return OrderBookSnapshot(
|
||||||
|
symbol=symbol,
|
||||||
|
exchange=exchange,
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
bids=bids,
|
||||||
|
asks=asks
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate_trade(self, symbol: str, exchange: str) -> TradeEvent:
|
||||||
|
"""Generate realistic trade data"""
|
||||||
|
base_price = self.base_prices.get(symbol, 100.0)
|
||||||
|
price_variation = (self.counter % 50) * 0.01
|
||||||
|
price = base_price + price_variation
|
||||||
|
|
||||||
|
self.counter += 1
|
||||||
|
|
||||||
|
return TradeEvent(
|
||||||
|
symbol=symbol,
|
||||||
|
exchange=exchange,
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
price=price,
|
||||||
|
size=0.1 + (self.counter % 10) * 0.01,
|
||||||
|
side="buy" if self.counter % 2 == 0 else "sell",
|
||||||
|
trade_id=str(self.counter)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PerformanceMonitor:
|
||||||
|
"""Monitor performance during load tests"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.start_time = None
|
||||||
|
self.end_time = None
|
||||||
|
self.latencies = []
|
||||||
|
self.errors = []
|
||||||
|
self.memory_samples = []
|
||||||
|
self.cpu_samples = []
|
||||||
|
self.process = psutil.Process()
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
"""Start monitoring"""
|
||||||
|
self.start_time = time.time()
|
||||||
|
self.latencies.clear()
|
||||||
|
self.errors.clear()
|
||||||
|
self.memory_samples.clear()
|
||||||
|
self.cpu_samples.clear()
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
"""Stop monitoring"""
|
||||||
|
self.end_time = time.time()
|
||||||
|
|
||||||
|
def record_latency(self, latency_ms: float):
|
||||||
|
"""Record operation latency"""
|
||||||
|
self.latencies.append(latency_ms)
|
||||||
|
|
||||||
|
def record_error(self, error: Exception):
|
||||||
|
"""Record error"""
|
||||||
|
self.errors.append(str(error))
|
||||||
|
|
||||||
|
def sample_system_metrics(self):
|
||||||
|
"""Sample system metrics"""
|
||||||
|
try:
|
||||||
|
memory_mb = self.process.memory_info().rss / 1024 / 1024
|
||||||
|
cpu_percent = self.process.cpu_percent()
|
||||||
|
|
||||||
|
self.memory_samples.append(memory_mb)
|
||||||
|
self.cpu_samples.append(cpu_percent)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error sampling system metrics: {e}")
|
||||||
|
|
||||||
|
def get_results(self) -> Dict[str, Any]:
|
||||||
|
"""Get performance test results"""
|
||||||
|
duration = self.end_time - self.start_time if self.end_time else 0
|
||||||
|
total_operations = len(self.latencies)
|
||||||
|
|
||||||
|
results = {
|
||||||
|
'duration_seconds': duration,
|
||||||
|
'total_operations': total_operations,
|
||||||
|
'operations_per_second': total_operations / duration if duration > 0 else 0,
|
||||||
|
'error_count': len(self.errors),
|
||||||
|
'success_rate': ((total_operations - len(self.errors)) / total_operations * 100) if total_operations > 0 else 0,
|
||||||
|
'latency': {
|
||||||
|
'min_ms': min(self.latencies) if self.latencies else 0,
|
||||||
|
'max_ms': max(self.latencies) if self.latencies else 0,
|
||||||
|
'avg_ms': statistics.mean(self.latencies) if self.latencies else 0,
|
||||||
|
'p50_ms': statistics.median(self.latencies) if self.latencies else 0,
|
||||||
|
'p95_ms': self._percentile(self.latencies, 95) if self.latencies else 0,
|
||||||
|
'p99_ms': self._percentile(self.latencies, 99) if self.latencies else 0
|
||||||
|
},
|
||||||
|
'memory': {
|
||||||
|
'min_mb': min(self.memory_samples) if self.memory_samples else 0,
|
||||||
|
'max_mb': max(self.memory_samples) if self.memory_samples else 0,
|
||||||
|
'avg_mb': statistics.mean(self.memory_samples) if self.memory_samples else 0
|
||||||
|
},
|
||||||
|
'cpu': {
|
||||||
|
'min_percent': min(self.cpu_samples) if self.cpu_samples else 0,
|
||||||
|
'max_percent': max(self.cpu_samples) if self.cpu_samples else 0,
|
||||||
|
'avg_percent': statistics.mean(self.cpu_samples) if self.cpu_samples else 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _percentile(self, data: List[float], percentile: int) -> float:
|
||||||
|
"""Calculate percentile"""
|
||||||
|
if not data:
|
||||||
|
return 0.0
|
||||||
|
sorted_data = sorted(data)
|
||||||
|
index = int((percentile / 100.0) * len(sorted_data))
|
||||||
|
index = min(index, len(sorted_data) - 1)
|
||||||
|
return sorted_data[index]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.load
|
||||||
|
class TestLoadPerformance:
|
||||||
|
"""Load testing and performance benchmarks"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def data_generator(self):
|
||||||
|
"""Create data generator"""
|
||||||
|
return DataGenerator()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def performance_monitor(self):
|
||||||
|
"""Create performance monitor"""
|
||||||
|
return PerformanceMonitor()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_orderbook_processing_load(self, data_generator, performance_monitor):
|
||||||
|
"""Test order book processing under high load"""
|
||||||
|
processor = DataProcessor()
|
||||||
|
monitor = performance_monitor
|
||||||
|
|
||||||
|
monitor.start()
|
||||||
|
|
||||||
|
# Generate load
|
||||||
|
tasks = []
|
||||||
|
for i in range(LoadTestConfig.TARGET_TPS):
|
||||||
|
symbol = LoadTestConfig.SYMBOLS[i % len(LoadTestConfig.SYMBOLS)]
|
||||||
|
exchange = LoadTestConfig.EXCHANGES[i % len(LoadTestConfig.EXCHANGES)]
|
||||||
|
|
||||||
|
task = asyncio.create_task(
|
||||||
|
self._process_orderbook_with_timing(
|
||||||
|
processor, data_generator, symbol, exchange, monitor
|
||||||
|
)
|
||||||
|
)
|
||||||
|
tasks.append(task)
|
||||||
|
|
||||||
|
# Add small delay to simulate realistic load
|
||||||
|
if i % 100 == 0:
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
|
||||||
|
# Wait for all tasks to complete
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
monitor.stop()
|
||||||
|
results = monitor.get_results()
|
||||||
|
|
||||||
|
# Verify performance requirements
|
||||||
|
assert results['operations_per_second'] >= LoadTestConfig.TARGET_TPS * 0.8, \
|
||||||
|
f"Throughput too low: {results['operations_per_second']:.2f} ops/sec"
|
||||||
|
|
||||||
|
assert results['latency']['p95_ms'] <= LoadTestConfig.MAX_LATENCY_MS, \
|
||||||
|
f"P95 latency too high: {results['latency']['p95_ms']:.2f}ms"
|
||||||
|
|
||||||
|
assert results['success_rate'] >= LoadTestConfig.MIN_SUCCESS_RATE, \
|
||||||
|
f"Success rate too low: {results['success_rate']:.2f}%"
|
||||||
|
|
||||||
|
logger.info(f"Load test results: {results}")
|
||||||
|
|
||||||
|
async def _process_orderbook_with_timing(self, processor, data_generator,
|
||||||
|
symbol, exchange, monitor):
|
||||||
|
"""Process order book with timing measurement"""
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Generate and process order book
|
||||||
|
orderbook = data_generator.generate_orderbook(symbol, exchange)
|
||||||
|
processed = processor.normalize_orderbook(orderbook.__dict__, exchange)
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
latency_ms = (end_time - start_time) * 1000
|
||||||
|
|
||||||
|
monitor.record_latency(latency_ms)
|
||||||
|
monitor.sample_system_metrics()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
monitor.record_error(e)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_trade_processing_load(self, data_generator, performance_monitor):
|
||||||
|
"""Test trade processing under high load"""
|
||||||
|
processor = DataProcessor()
|
||||||
|
monitor = performance_monitor
|
||||||
|
|
||||||
|
monitor.start()
|
||||||
|
|
||||||
|
# Generate sustained load for specified duration
|
||||||
|
end_time = time.time() + LoadTestConfig.DURATION_SECONDS
|
||||||
|
operation_count = 0
|
||||||
|
|
||||||
|
while time.time() < end_time:
|
||||||
|
symbol = LoadTestConfig.SYMBOLS[operation_count % len(LoadTestConfig.SYMBOLS)]
|
||||||
|
exchange = LoadTestConfig.EXCHANGES[operation_count % len(LoadTestConfig.EXCHANGES)]
|
||||||
|
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Generate and process trade
|
||||||
|
trade = data_generator.generate_trade(symbol, exchange)
|
||||||
|
processed = processor.normalize_trade(trade.__dict__, exchange)
|
||||||
|
|
||||||
|
process_time = time.time()
|
||||||
|
latency_ms = (process_time - start_time) * 1000
|
||||||
|
|
||||||
|
monitor.record_latency(latency_ms)
|
||||||
|
|
||||||
|
if operation_count % 100 == 0:
|
||||||
|
monitor.sample_system_metrics()
|
||||||
|
|
||||||
|
operation_count += 1
|
||||||
|
|
||||||
|
# Control rate to avoid overwhelming
|
||||||
|
await asyncio.sleep(0.001) # 1ms delay
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
monitor.record_error(e)
|
||||||
|
|
||||||
|
monitor.stop()
|
||||||
|
results = monitor.get_results()
|
||||||
|
|
||||||
|
# Verify performance
|
||||||
|
assert results['latency']['avg_ms'] <= LoadTestConfig.MAX_LATENCY_MS, \
|
||||||
|
f"Average latency too high: {results['latency']['avg_ms']:.2f}ms"
|
||||||
|
|
||||||
|
assert results['memory']['max_mb'] <= LoadTestConfig.MAX_MEMORY_MB, \
|
||||||
|
f"Memory usage too high: {results['memory']['max_mb']:.2f}MB"
|
||||||
|
|
||||||
|
logger.info(f"Trade processing results: {results}")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_aggregation_performance(self, data_generator, performance_monitor):
|
||||||
|
"""Test aggregation engine performance"""
|
||||||
|
aggregator = AggregationEngine()
|
||||||
|
monitor = performance_monitor
|
||||||
|
|
||||||
|
monitor.start()
|
||||||
|
|
||||||
|
# Generate multiple order books for aggregation
|
||||||
|
orderbooks = []
|
||||||
|
for i in range(100):
|
||||||
|
symbol = LoadTestConfig.SYMBOLS[i % len(LoadTestConfig.SYMBOLS)]
|
||||||
|
exchange = LoadTestConfig.EXCHANGES[i % len(LoadTestConfig.EXCHANGES)]
|
||||||
|
orderbook = data_generator.generate_orderbook(symbol, exchange)
|
||||||
|
orderbooks.append(orderbook)
|
||||||
|
|
||||||
|
# Test aggregation performance
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
for orderbook in orderbooks:
|
||||||
|
try:
|
||||||
|
# Test price bucketing
|
||||||
|
buckets = aggregator.create_price_buckets(orderbook)
|
||||||
|
|
||||||
|
# Test heatmap generation
|
||||||
|
heatmap = aggregator.generate_heatmap(buckets)
|
||||||
|
|
||||||
|
# Test metrics calculation
|
||||||
|
metrics = aggregator.calculate_metrics(orderbook)
|
||||||
|
|
||||||
|
process_time = time.time()
|
||||||
|
latency_ms = (process_time - start_time) * 1000
|
||||||
|
monitor.record_latency(latency_ms)
|
||||||
|
|
||||||
|
start_time = process_time
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
monitor.record_error(e)
|
||||||
|
|
||||||
|
monitor.stop()
|
||||||
|
results = monitor.get_results()
|
||||||
|
|
||||||
|
# Verify aggregation performance
|
||||||
|
assert results['latency']['p95_ms'] <= 50, \
|
||||||
|
f"Aggregation P95 latency too high: {results['latency']['p95_ms']:.2f}ms"
|
||||||
|
|
||||||
|
logger.info(f"Aggregation performance results: {results}")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_concurrent_exchange_processing(self, data_generator, performance_monitor):
|
||||||
|
"""Test concurrent processing from multiple exchanges"""
|
||||||
|
processor = DataProcessor()
|
||||||
|
monitor = performance_monitor
|
||||||
|
|
||||||
|
monitor.start()
|
||||||
|
|
||||||
|
# Create concurrent tasks for each exchange
|
||||||
|
tasks = []
|
||||||
|
for exchange in LoadTestConfig.EXCHANGES:
|
||||||
|
task = asyncio.create_task(
|
||||||
|
self._simulate_exchange_load(processor, data_generator, exchange, monitor)
|
||||||
|
)
|
||||||
|
tasks.append(task)
|
||||||
|
|
||||||
|
# Run all exchanges concurrently
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
monitor.stop()
|
||||||
|
results = monitor.get_results()
|
||||||
|
|
||||||
|
# Verify concurrent processing performance
|
||||||
|
expected_total_ops = len(LoadTestConfig.EXCHANGES) * 100 # 100 ops per exchange
|
||||||
|
assert results['total_operations'] >= expected_total_ops * 0.9, \
|
||||||
|
f"Not enough operations completed: {results['total_operations']}"
|
||||||
|
|
||||||
|
assert results['success_rate'] >= 95.0, \
|
||||||
|
f"Success rate too low under concurrent load: {results['success_rate']:.2f}%"
|
||||||
|
|
||||||
|
logger.info(f"Concurrent processing results: {results}")
|
||||||
|
|
||||||
|
async def _simulate_exchange_load(self, processor, data_generator, exchange, monitor):
|
||||||
|
"""Simulate load from a single exchange"""
|
||||||
|
for i in range(100):
|
||||||
|
try:
|
||||||
|
symbol = LoadTestConfig.SYMBOLS[i % len(LoadTestConfig.SYMBOLS)]
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Alternate between order books and trades
|
||||||
|
if i % 2 == 0:
|
||||||
|
data = data_generator.generate_orderbook(symbol, exchange)
|
||||||
|
processed = processor.normalize_orderbook(data.__dict__, exchange)
|
||||||
|
else:
|
||||||
|
data = data_generator.generate_trade(symbol, exchange)
|
||||||
|
processed = processor.normalize_trade(data.__dict__, exchange)
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
latency_ms = (end_time - start_time) * 1000
|
||||||
|
|
||||||
|
monitor.record_latency(latency_ms)
|
||||||
|
|
||||||
|
# Small delay to simulate realistic timing
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
monitor.record_error(e)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_memory_usage_under_load(self, data_generator):
|
||||||
|
"""Test memory usage patterns under sustained load"""
|
||||||
|
processor = DataProcessor()
|
||||||
|
process = psutil.Process()
|
||||||
|
|
||||||
|
# Get baseline memory
|
||||||
|
gc.collect() # Force garbage collection
|
||||||
|
baseline_memory = process.memory_info().rss / 1024 / 1024 # MB
|
||||||
|
|
||||||
|
memory_samples = []
|
||||||
|
|
||||||
|
# Generate sustained load
|
||||||
|
for i in range(1000):
|
||||||
|
symbol = LoadTestConfig.SYMBOLS[i % len(LoadTestConfig.SYMBOLS)]
|
||||||
|
exchange = LoadTestConfig.EXCHANGES[i % len(LoadTestConfig.EXCHANGES)]
|
||||||
|
|
||||||
|
# Generate data
|
||||||
|
orderbook = data_generator.generate_orderbook(symbol, exchange)
|
||||||
|
trade = data_generator.generate_trade(symbol, exchange)
|
||||||
|
|
||||||
|
# Process data
|
||||||
|
processor.normalize_orderbook(orderbook.__dict__, exchange)
|
||||||
|
processor.normalize_trade(trade.__dict__, exchange)
|
||||||
|
|
||||||
|
# Sample memory every 100 operations
|
||||||
|
if i % 100 == 0:
|
||||||
|
current_memory = process.memory_info().rss / 1024 / 1024
|
||||||
|
memory_samples.append(current_memory)
|
||||||
|
|
||||||
|
# Force garbage collection periodically
|
||||||
|
if i % 500 == 0:
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
# Final memory check
|
||||||
|
gc.collect()
|
||||||
|
final_memory = process.memory_info().rss / 1024 / 1024
|
||||||
|
|
||||||
|
# Calculate memory statistics
|
||||||
|
max_memory = max(memory_samples) if memory_samples else final_memory
|
||||||
|
memory_growth = final_memory - baseline_memory
|
||||||
|
|
||||||
|
# Verify memory usage is reasonable
|
||||||
|
assert memory_growth < 100, \
|
||||||
|
f"Memory growth too high: {memory_growth:.2f}MB"
|
||||||
|
|
||||||
|
assert max_memory < baseline_memory + 200, \
|
||||||
|
f"Peak memory usage too high: {max_memory:.2f}MB"
|
||||||
|
|
||||||
|
logger.info(f"Memory usage: baseline={baseline_memory:.2f}MB, "
|
||||||
|
f"final={final_memory:.2f}MB, growth={memory_growth:.2f}MB, "
|
||||||
|
f"peak={max_memory:.2f}MB")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stress_test_extreme_load(self, data_generator, performance_monitor):
|
||||||
|
"""Stress test with extreme load conditions"""
|
||||||
|
processor = DataProcessor()
|
||||||
|
monitor = performance_monitor
|
||||||
|
|
||||||
|
# Extreme load parameters
|
||||||
|
EXTREME_TPS = 5000
|
||||||
|
STRESS_DURATION = 30 # seconds
|
||||||
|
|
||||||
|
monitor.start()
|
||||||
|
|
||||||
|
# Generate extreme load
|
||||||
|
tasks = []
|
||||||
|
operations_per_batch = 100
|
||||||
|
batches = EXTREME_TPS // operations_per_batch
|
||||||
|
|
||||||
|
for batch in range(batches):
|
||||||
|
batch_tasks = []
|
||||||
|
for i in range(operations_per_batch):
|
||||||
|
symbol = LoadTestConfig.SYMBOLS[i % len(LoadTestConfig.SYMBOLS)]
|
||||||
|
exchange = LoadTestConfig.EXCHANGES[i % len(LoadTestConfig.EXCHANGES)]
|
||||||
|
|
||||||
|
task = asyncio.create_task(
|
||||||
|
self._process_orderbook_with_timing(
|
||||||
|
processor, data_generator, symbol, exchange, monitor
|
||||||
|
)
|
||||||
|
)
|
||||||
|
batch_tasks.append(task)
|
||||||
|
|
||||||
|
# Process batch
|
||||||
|
await asyncio.gather(*batch_tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
# Small delay between batches
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
monitor.stop()
|
||||||
|
results = monitor.get_results()
|
||||||
|
|
||||||
|
# Under extreme load, we accept lower performance but system should remain stable
|
||||||
|
assert results['success_rate'] >= 80.0, \
|
||||||
|
f"System failed under stress: {results['success_rate']:.2f}% success rate"
|
||||||
|
|
||||||
|
assert results['latency']['p99_ms'] <= 500, \
|
||||||
|
f"P99 latency too high under stress: {results['latency']['p99_ms']:.2f}ms"
|
||||||
|
|
||||||
|
logger.info(f"Stress test results: {results}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.benchmark
|
||||||
|
class TestPerformanceBenchmarks:
|
||||||
|
"""Performance benchmarks for regression testing"""
|
||||||
|
|
||||||
|
def test_orderbook_processing_benchmark(self, benchmark):
|
||||||
|
"""Benchmark order book processing speed"""
|
||||||
|
processor = DataProcessor()
|
||||||
|
generator = DataGenerator()
|
||||||
|
|
||||||
|
def process_orderbook():
|
||||||
|
orderbook = generator.generate_orderbook("BTCUSDT", "binance")
|
||||||
|
return processor.normalize_orderbook(orderbook.__dict__, "binance")
|
||||||
|
|
||||||
|
result = benchmark(process_orderbook)
|
||||||
|
assert result is not None
|
||||||
|
|
||||||
|
def test_trade_processing_benchmark(self, benchmark):
|
||||||
|
"""Benchmark trade processing speed"""
|
||||||
|
processor = DataProcessor()
|
||||||
|
generator = DataGenerator()
|
||||||
|
|
||||||
|
def process_trade():
|
||||||
|
trade = generator.generate_trade("BTCUSDT", "binance")
|
||||||
|
return processor.normalize_trade(trade.__dict__, "binance")
|
||||||
|
|
||||||
|
result = benchmark(process_trade)
|
||||||
|
assert result is not None
|
||||||
|
|
||||||
|
def test_aggregation_benchmark(self, benchmark):
|
||||||
|
"""Benchmark aggregation engine performance"""
|
||||||
|
aggregator = AggregationEngine()
|
||||||
|
generator = DataGenerator()
|
||||||
|
|
||||||
|
def aggregate_data():
|
||||||
|
orderbook = generator.generate_orderbook("BTCUSDT", "binance")
|
||||||
|
buckets = aggregator.create_price_buckets(orderbook)
|
||||||
|
return aggregator.generate_heatmap(buckets)
|
||||||
|
|
||||||
|
result = benchmark(aggregate_data)
|
||||||
|
assert result is not None
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_configure(config):
|
||||||
|
"""Configure pytest with custom markers"""
|
||||||
|
config.addinivalue_line("markers", "load: mark test as load test")
|
||||||
|
config.addinivalue_line("markers", "benchmark: mark test as benchmark")
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_addoption(parser):
|
||||||
|
"""Add custom command line options"""
|
||||||
|
parser.addoption(
|
||||||
|
"--load",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="run load tests"
|
||||||
|
)
|
||||||
|
parser.addoption(
|
||||||
|
"--benchmark",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="run benchmark tests"
|
||||||
|
)
|
108
COBY/tests/test_performance_benchmarks.py
Normal file
108
COBY/tests/test_performance_benchmarks.py
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
"""
|
||||||
|
Performance benchmarks and regression tests.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import time
|
||||||
|
import statistics
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Dict, List, Any, Tuple
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from ..models.core import OrderBookSnapshot, TradeEvent, PriceLevel
|
||||||
|
from ..processing.data_processor import DataProcessor
|
||||||
|
from ..aggregation.aggregation_engine import AggregationEngine
|
||||||
|
from ..connectors.binance_connector import BinanceConnector
|
||||||
|
from ..storage.timescale_manager import TimescaleManager
|
||||||
|
from ..caching.redis_manager import RedisManager
|
||||||
|
from ..utils.logging import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BenchmarkResult:
|
||||||
|
"""Benchmark result data structure"""
|
||||||
|
name: str
|
||||||
|
duration_ms: float
|
||||||
|
operations_per_second: float
|
||||||
|
memory_usage_mb: float
|
||||||
|
cpu_usage_percent: float
|
||||||
|
timestamp: datetime
|
||||||
|
metadata: Dict[str, Any] = None
|
||||||
|
|
||||||
|
|
||||||
|
class BenchmarkRunner:
|
||||||
|
"""Benchmark execution and result management"""
|
||||||
|
|
||||||
|
def __init__(self, results_file: str = "benchmark_results.json"):
|
||||||
|
self.results_file = Path(results_file)
|
||||||
|
self.results: List[BenchmarkResult] = []
|
||||||
|
self.load_historical_results()
|
||||||
|
|
||||||
|
def load_historical_results(self):
|
||||||
|
"""Load historical benchmark results"""
|
||||||
|
if self.results_file.exists():
|
||||||
|
try:
|
||||||
|
with open(self.results_file, 'r') as f:
|
||||||
|
data = json.load(f)
|
||||||
|
for item in data:
|
||||||
|
result = BenchmarkResult(
|
||||||
|
name=item['name'],
|
||||||
|
duration_ms=item['duration_ms'],
|
||||||
|
operations_per_second=item['operations_per_second'],
|
||||||
|
memory_usage_mb=item['memory_usage_mb'],
|
||||||
|
cpu_usage_percent=item['cpu_usage_percent'],
|
||||||
|
timestamp=datetime.fromisoformat(item['timestamp']),
|
||||||
|
metadata=item.get('metadata', {})
|
||||||
|
)
|
||||||
|
self.results.append(result)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Could not load historical results: {e}")
|
||||||
|
|
||||||
|
def save_results(self):
|
||||||
|
"""Save benchmark results to file"""
|
||||||
|
try:
|
||||||
|
data = []
|
||||||
|
for result in self.results:
|
||||||
|
data.append({
|
||||||
|
'name': result.name,
|
||||||
|
'duration_ms': result.duration_ms,
|
||||||
|
'operations_per_second': result.operations_per_second,
|
||||||
|
'memory_usage_mb': result.memory_usage_mb,
|
||||||
|
'cpu_usage_percent': result.cpu_usage_percent,
|
||||||
|
'timestamp': result.timestamp.isoformat(),
|
||||||
|
'metadata': result.metadata or {}
|
||||||
|
})
|
||||||
|
|
||||||
|
with open(self.results_file, 'w') as f:
|
||||||
|
json.dump(data, f, indent=2)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Could not save benchmark results: {e}")
|
||||||
|
|
||||||
|
def run_benchmark(self, name: str, func, iterations: int = 1000,
|
||||||
|
warmup: int = 100) -> BenchmarkResult:
|
||||||
|
"""Run a benchmark function"""
|
||||||
|
import psutil
|
||||||
|
|
||||||
|
process = psutil.Process()
|
||||||
|
|
||||||
|
# Warmup
|
||||||
|
for _ in range(warmup):
|
||||||
|
func()
|
||||||
|
|
||||||
|
# Collect baseline metrics
|
||||||
|
initial_memory = process.memory_info().rss / 1024 / 1024
|
||||||
|
initial_cpu = process.cpu_percent()
|
||||||
|
|
||||||
|
# Run benchmark
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
for _ in range(iterations):
|
||||||
|
func()
|
||||||
|
|
||||||
|
end_time =
|
@ -1,57 +1,23 @@
|
|||||||
"""
|
"""
|
||||||
Custom exceptions for the COBY system.
|
Custom exceptions for COBY system.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class COBYException(Exception):
|
class COBYException(Exception):
|
||||||
"""Base exception for COBY system"""
|
"""Base exception for COBY system"""
|
||||||
|
pass
|
||||||
def __init__(self, message: str, error_code: str = None, details: dict = None):
|
|
||||||
super().__init__(message)
|
|
||||||
self.message = message
|
|
||||||
self.error_code = error_code
|
|
||||||
self.details = details or {}
|
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
|
||||||
"""Convert exception to dictionary"""
|
|
||||||
return {
|
|
||||||
'error': self.__class__.__name__,
|
|
||||||
'message': self.message,
|
|
||||||
'error_code': self.error_code,
|
|
||||||
'details': self.details
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class ConnectionError(COBYException):
|
class ConnectionError(COBYException):
|
||||||
"""Exception raised for connection-related errors"""
|
"""Connection related errors"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ValidationError(COBYException):
|
class ValidationError(COBYException):
|
||||||
"""Exception raised for data validation errors"""
|
"""Data validation errors"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ProcessingError(COBYException):
|
class ProcessingError(COBYException):
|
||||||
"""Exception raised for data processing errors"""
|
"""Data processing errors"""
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class StorageError(COBYException):
|
|
||||||
"""Exception raised for storage-related errors"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigurationError(COBYException):
|
|
||||||
"""Exception raised for configuration errors"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ReplayError(COBYException):
|
|
||||||
"""Exception raised for replay-related errors"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class AggregationError(COBYException):
|
|
||||||
"""Exception raised for aggregation errors"""
|
|
||||||
pass
|
pass
|
@ -1,206 +1,15 @@
|
|||||||
"""
|
"""
|
||||||
Timing utilities for the COBY system.
|
Timing utilities for COBY system.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import time
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
|
|
||||||
def get_current_timestamp() -> datetime:
|
def get_current_timestamp() -> datetime:
|
||||||
"""
|
"""Get current UTC timestamp"""
|
||||||
Get current UTC timestamp.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
datetime: Current UTC timestamp
|
|
||||||
"""
|
|
||||||
return datetime.now(timezone.utc)
|
return datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
def format_timestamp(timestamp: datetime, format_str: str = "%Y-%m-%d %H:%M:%S.%f") -> str:
|
def format_timestamp(timestamp: datetime) -> str:
|
||||||
"""
|
"""Format timestamp as ISO string"""
|
||||||
Format timestamp to string.
|
return timestamp.isoformat()
|
||||||
|
|
||||||
Args:
|
|
||||||
timestamp: Timestamp to format
|
|
||||||
format_str: Format string
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: Formatted timestamp string
|
|
||||||
"""
|
|
||||||
return timestamp.strftime(format_str)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_timestamp(timestamp_str: str, format_str: str = "%Y-%m-%d %H:%M:%S.%f") -> datetime:
|
|
||||||
"""
|
|
||||||
Parse timestamp string to datetime.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
timestamp_str: Timestamp string to parse
|
|
||||||
format_str: Format string
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
datetime: Parsed timestamp
|
|
||||||
"""
|
|
||||||
dt = datetime.strptime(timestamp_str, format_str)
|
|
||||||
# Ensure timezone awareness
|
|
||||||
if dt.tzinfo is None:
|
|
||||||
dt = dt.replace(tzinfo=timezone.utc)
|
|
||||||
return dt
|
|
||||||
|
|
||||||
|
|
||||||
def timestamp_to_unix(timestamp: datetime) -> float:
|
|
||||||
"""
|
|
||||||
Convert datetime to Unix timestamp.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
timestamp: Datetime to convert
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
float: Unix timestamp
|
|
||||||
"""
|
|
||||||
return timestamp.timestamp()
|
|
||||||
|
|
||||||
|
|
||||||
def unix_to_timestamp(unix_time: float) -> datetime:
|
|
||||||
"""
|
|
||||||
Convert Unix timestamp to datetime.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
unix_time: Unix timestamp
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
datetime: Converted datetime (UTC)
|
|
||||||
"""
|
|
||||||
return datetime.fromtimestamp(unix_time, tz=timezone.utc)
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_time_diff(start: datetime, end: datetime) -> float:
|
|
||||||
"""
|
|
||||||
Calculate time difference in seconds.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
start: Start timestamp
|
|
||||||
end: End timestamp
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
float: Time difference in seconds
|
|
||||||
"""
|
|
||||||
return (end - start).total_seconds()
|
|
||||||
|
|
||||||
|
|
||||||
def is_timestamp_recent(timestamp: datetime, max_age_seconds: int = 60) -> bool:
|
|
||||||
"""
|
|
||||||
Check if timestamp is recent (within max_age_seconds).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
timestamp: Timestamp to check
|
|
||||||
max_age_seconds: Maximum age in seconds
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if recent, False otherwise
|
|
||||||
"""
|
|
||||||
now = get_current_timestamp()
|
|
||||||
age = calculate_time_diff(timestamp, now)
|
|
||||||
return age <= max_age_seconds
|
|
||||||
|
|
||||||
|
|
||||||
def sleep_until(target_time: datetime) -> None:
|
|
||||||
"""
|
|
||||||
Sleep until target time.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
target_time: Target timestamp to sleep until
|
|
||||||
"""
|
|
||||||
now = get_current_timestamp()
|
|
||||||
sleep_seconds = calculate_time_diff(now, target_time)
|
|
||||||
|
|
||||||
if sleep_seconds > 0:
|
|
||||||
time.sleep(sleep_seconds)
|
|
||||||
|
|
||||||
|
|
||||||
def get_milliseconds() -> int:
|
|
||||||
"""
|
|
||||||
Get current timestamp in milliseconds.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: Current timestamp in milliseconds
|
|
||||||
"""
|
|
||||||
return int(time.time() * 1000)
|
|
||||||
|
|
||||||
|
|
||||||
def milliseconds_to_timestamp(ms: int) -> datetime:
|
|
||||||
"""
|
|
||||||
Convert milliseconds to datetime.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ms: Milliseconds timestamp
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
datetime: Converted datetime (UTC)
|
|
||||||
"""
|
|
||||||
return datetime.fromtimestamp(ms / 1000.0, tz=timezone.utc)
|
|
||||||
|
|
||||||
|
|
||||||
def round_timestamp(timestamp: datetime, seconds: int) -> datetime:
|
|
||||||
"""
|
|
||||||
Round timestamp to nearest interval.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
timestamp: Timestamp to round
|
|
||||||
seconds: Interval in seconds
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
datetime: Rounded timestamp
|
|
||||||
"""
|
|
||||||
unix_time = timestamp_to_unix(timestamp)
|
|
||||||
rounded_unix = round(unix_time / seconds) * seconds
|
|
||||||
return unix_to_timestamp(rounded_unix)
|
|
||||||
|
|
||||||
|
|
||||||
class Timer:
|
|
||||||
"""Simple timer for measuring execution time"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.start_time: Optional[float] = None
|
|
||||||
self.end_time: Optional[float] = None
|
|
||||||
|
|
||||||
def start(self) -> None:
|
|
||||||
"""Start the timer"""
|
|
||||||
self.start_time = time.perf_counter()
|
|
||||||
self.end_time = None
|
|
||||||
|
|
||||||
def stop(self) -> float:
|
|
||||||
"""
|
|
||||||
Stop the timer and return elapsed time.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
float: Elapsed time in seconds
|
|
||||||
"""
|
|
||||||
if self.start_time is None:
|
|
||||||
raise ValueError("Timer not started")
|
|
||||||
|
|
||||||
self.end_time = time.perf_counter()
|
|
||||||
return self.elapsed()
|
|
||||||
|
|
||||||
def elapsed(self) -> float:
|
|
||||||
"""
|
|
||||||
Get elapsed time.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
float: Elapsed time in seconds
|
|
||||||
"""
|
|
||||||
if self.start_time is None:
|
|
||||||
return 0.0
|
|
||||||
|
|
||||||
end = self.end_time or time.perf_counter()
|
|
||||||
return end - self.start_time
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
"""Context manager entry"""
|
|
||||||
self.start()
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
||||||
"""Context manager exit"""
|
|
||||||
self.stop()
|
|
@ -1,217 +1,31 @@
|
|||||||
"""
|
"""
|
||||||
Data validation utilities for the COBY system.
|
Validation utilities for COBY system.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from typing import List, Optional
|
from typing import Any
|
||||||
from decimal import Decimal, InvalidOperation
|
|
||||||
|
|
||||||
|
|
||||||
def validate_symbol(symbol: str) -> bool:
|
def validate_symbol(symbol: str) -> bool:
|
||||||
"""
|
"""Validate trading symbol format"""
|
||||||
Validate trading symbol format.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
symbol: Trading symbol to validate
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if valid, False otherwise
|
|
||||||
"""
|
|
||||||
if not symbol or not isinstance(symbol, str):
|
if not symbol or not isinstance(symbol, str):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Basic symbol format validation (e.g., BTCUSDT, ETH-USD)
|
# Basic symbol validation (letters and numbers, 3-12 chars)
|
||||||
pattern = r'^[A-Z0-9]{2,10}[-/]?[A-Z0-9]{2,10}$'
|
pattern = r'^[A-Z0-9]{3,12}$'
|
||||||
return bool(re.match(pattern, symbol.upper()))
|
return bool(re.match(pattern, symbol.upper()))
|
||||||
|
|
||||||
|
|
||||||
def validate_price(price: float) -> bool:
|
def validate_price(price: float) -> bool:
|
||||||
"""
|
"""Validate price value"""
|
||||||
Validate price value.
|
return isinstance(price, (int, float)) and price > 0
|
||||||
|
|
||||||
Args:
|
|
||||||
price: Price to validate
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if valid, False otherwise
|
|
||||||
"""
|
|
||||||
if not isinstance(price, (int, float, Decimal)):
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
price_decimal = Decimal(str(price))
|
|
||||||
return price_decimal > 0 and price_decimal < Decimal('1e10') # Reasonable upper bound
|
|
||||||
except (InvalidOperation, ValueError):
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def validate_volume(volume: float) -> bool:
|
def validate_volume(volume: float) -> bool:
|
||||||
"""
|
"""Validate volume value"""
|
||||||
Validate volume value.
|
return isinstance(volume, (int, float)) and volume >= 0
|
||||||
|
|
||||||
Args:
|
|
||||||
volume: Volume to validate
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if valid, False otherwise
|
|
||||||
"""
|
|
||||||
if not isinstance(volume, (int, float, Decimal)):
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
volume_decimal = Decimal(str(volume))
|
|
||||||
return volume_decimal >= 0 and volume_decimal < Decimal('1e15') # Reasonable upper bound
|
|
||||||
except (InvalidOperation, ValueError):
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def validate_exchange_name(exchange: str) -> bool:
|
def validate_timestamp(timestamp: Any) -> bool:
|
||||||
"""
|
"""Validate timestamp"""
|
||||||
Validate exchange name.
|
return timestamp is not None
|
||||||
|
|
||||||
Args:
|
|
||||||
exchange: Exchange name to validate
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if valid, False otherwise
|
|
||||||
"""
|
|
||||||
if not exchange or not isinstance(exchange, str):
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Exchange name should be alphanumeric with possible underscores/hyphens
|
|
||||||
pattern = r'^[a-zA-Z0-9_-]{2,20}$'
|
|
||||||
return bool(re.match(pattern, exchange))
|
|
||||||
|
|
||||||
|
|
||||||
def validate_timestamp_range(start_time, end_time) -> List[str]:
|
|
||||||
"""
|
|
||||||
Validate timestamp range.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
start_time: Start timestamp
|
|
||||||
end_time: End timestamp
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[str]: List of validation errors (empty if valid)
|
|
||||||
"""
|
|
||||||
errors = []
|
|
||||||
|
|
||||||
if start_time is None:
|
|
||||||
errors.append("Start time cannot be None")
|
|
||||||
|
|
||||||
if end_time is None:
|
|
||||||
errors.append("End time cannot be None")
|
|
||||||
|
|
||||||
if start_time and end_time and start_time >= end_time:
|
|
||||||
errors.append("Start time must be before end time")
|
|
||||||
|
|
||||||
return errors
|
|
||||||
|
|
||||||
|
|
||||||
def validate_bucket_size(bucket_size: float) -> bool:
|
|
||||||
"""
|
|
||||||
Validate price bucket size.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
bucket_size: Bucket size to validate
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if valid, False otherwise
|
|
||||||
"""
|
|
||||||
if not isinstance(bucket_size, (int, float, Decimal)):
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
size_decimal = Decimal(str(bucket_size))
|
|
||||||
return size_decimal > 0 and size_decimal <= Decimal('1000') # Reasonable upper bound
|
|
||||||
except (InvalidOperation, ValueError):
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def validate_speed_multiplier(speed: float) -> bool:
|
|
||||||
"""
|
|
||||||
Validate replay speed multiplier.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
speed: Speed multiplier to validate
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if valid, False otherwise
|
|
||||||
"""
|
|
||||||
if not isinstance(speed, (int, float)):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return 0.01 <= speed <= 100.0 # 1% to 100x speed
|
|
||||||
|
|
||||||
|
|
||||||
def sanitize_symbol(symbol: str) -> str:
|
|
||||||
"""
|
|
||||||
Sanitize and normalize symbol format.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
symbol: Symbol to sanitize
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: Sanitized symbol
|
|
||||||
"""
|
|
||||||
if not symbol:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
# Remove whitespace and convert to uppercase
|
|
||||||
sanitized = symbol.strip().upper()
|
|
||||||
|
|
||||||
# Remove invalid characters
|
|
||||||
sanitized = re.sub(r'[^A-Z0-9/-]', '', sanitized)
|
|
||||||
|
|
||||||
return sanitized
|
|
||||||
|
|
||||||
|
|
||||||
def validate_percentage(value: float, min_val: float = 0.0, max_val: float = 100.0) -> bool:
|
|
||||||
"""
|
|
||||||
Validate percentage value.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
value: Percentage value to validate
|
|
||||||
min_val: Minimum allowed value
|
|
||||||
max_val: Maximum allowed value
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if valid, False otherwise
|
|
||||||
"""
|
|
||||||
if not isinstance(value, (int, float)):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return min_val <= value <= max_val
|
|
||||||
|
|
||||||
|
|
||||||
def validate_connection_config(config: dict) -> List[str]:
|
|
||||||
"""
|
|
||||||
Validate connection configuration.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config: Configuration dictionary
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[str]: List of validation errors (empty if valid)
|
|
||||||
"""
|
|
||||||
errors = []
|
|
||||||
|
|
||||||
# Required fields
|
|
||||||
required_fields = ['host', 'port']
|
|
||||||
for field in required_fields:
|
|
||||||
if field not in config:
|
|
||||||
errors.append(f"Missing required field: {field}")
|
|
||||||
|
|
||||||
# Validate host
|
|
||||||
if 'host' in config:
|
|
||||||
host = config['host']
|
|
||||||
if not isinstance(host, str) or not host.strip():
|
|
||||||
errors.append("Host must be a non-empty string")
|
|
||||||
|
|
||||||
# Validate port
|
|
||||||
if 'port' in config:
|
|
||||||
port = config['port']
|
|
||||||
if not isinstance(port, int) or not (1 <= port <= 65535):
|
|
||||||
errors.append("Port must be an integer between 1 and 65535")
|
|
||||||
|
|
||||||
return errors
|
|
Reference in New Issue
Block a user