api
This commit is contained in:
@ -53,6 +53,9 @@
|
|||||||
- Implement data processor for normalizing raw exchange data
|
- Implement data processor for normalizing raw exchange data
|
||||||
- Create validation logic for order book and trade data
|
- Create validation logic for order book and trade data
|
||||||
- Implement data quality checks and filtering
|
- Implement data quality checks and filtering
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
- Add metrics calculation for order book statistics
|
- Add metrics calculation for order book statistics
|
||||||
- Write comprehensive unit tests for data processing logic
|
- Write comprehensive unit tests for data processing logic
|
||||||
- _Requirements: 1.4, 6.3, 8.1_
|
- _Requirements: 1.4, 6.3, 8.1_
|
||||||
|
15
COBY/api/__init__.py
Normal file
15
COBY/api/__init__.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
"""
|
||||||
|
API layer for the COBY system.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .rest_api import create_app
|
||||||
|
from .websocket_server import WebSocketServer
|
||||||
|
from .rate_limiter import RateLimiter
|
||||||
|
from .response_formatter import ResponseFormatter
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'create_app',
|
||||||
|
'WebSocketServer',
|
||||||
|
'RateLimiter',
|
||||||
|
'ResponseFormatter'
|
||||||
|
]
|
183
COBY/api/rate_limiter.py
Normal file
183
COBY/api/rate_limiter.py
Normal file
@ -0,0 +1,183 @@
|
|||||||
|
"""
|
||||||
|
Rate limiting for API endpoints.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import Dict, Optional
|
||||||
|
from collections import defaultdict, deque
|
||||||
|
from ..utils.logging import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimiter:
|
||||||
|
"""
|
||||||
|
Token bucket rate limiter for API endpoints.
|
||||||
|
|
||||||
|
Provides per-client rate limiting with configurable limits.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, requests_per_minute: int = 100, burst_size: int = 20):
|
||||||
|
"""
|
||||||
|
Initialize rate limiter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
requests_per_minute: Maximum requests per minute
|
||||||
|
burst_size: Maximum burst requests
|
||||||
|
"""
|
||||||
|
self.requests_per_minute = requests_per_minute
|
||||||
|
self.burst_size = burst_size
|
||||||
|
self.refill_rate = requests_per_minute / 60.0 # tokens per second
|
||||||
|
|
||||||
|
# Client buckets: client_id -> {'tokens': float, 'last_refill': float}
|
||||||
|
self.buckets: Dict[str, Dict] = defaultdict(lambda: {
|
||||||
|
'tokens': float(burst_size),
|
||||||
|
'last_refill': time.time()
|
||||||
|
})
|
||||||
|
|
||||||
|
# Request history for monitoring
|
||||||
|
self.request_history: Dict[str, deque] = defaultdict(lambda: deque(maxlen=1000))
|
||||||
|
|
||||||
|
logger.info(f"Rate limiter initialized: {requests_per_minute} req/min, burst: {burst_size}")
|
||||||
|
|
||||||
|
def is_allowed(self, client_id: str, tokens_requested: int = 1) -> bool:
|
||||||
|
"""
|
||||||
|
Check if request is allowed for client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client_id: Client identifier (IP, user ID, etc.)
|
||||||
|
tokens_requested: Number of tokens requested
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if request is allowed, False otherwise
|
||||||
|
"""
|
||||||
|
current_time = time.time()
|
||||||
|
bucket = self.buckets[client_id]
|
||||||
|
|
||||||
|
# Refill tokens based on time elapsed
|
||||||
|
time_elapsed = current_time - bucket['last_refill']
|
||||||
|
tokens_to_add = time_elapsed * self.refill_rate
|
||||||
|
|
||||||
|
# Update bucket
|
||||||
|
bucket['tokens'] = min(self.burst_size, bucket['tokens'] + tokens_to_add)
|
||||||
|
bucket['last_refill'] = current_time
|
||||||
|
|
||||||
|
# Check if enough tokens available
|
||||||
|
if bucket['tokens'] >= tokens_requested:
|
||||||
|
bucket['tokens'] -= tokens_requested
|
||||||
|
|
||||||
|
# Record successful request
|
||||||
|
self.request_history[client_id].append(current_time)
|
||||||
|
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.debug(f"Rate limit exceeded for client {client_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_remaining_tokens(self, client_id: str) -> float:
|
||||||
|
"""
|
||||||
|
Get remaining tokens for client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client_id: Client identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: Number of remaining tokens
|
||||||
|
"""
|
||||||
|
current_time = time.time()
|
||||||
|
bucket = self.buckets[client_id]
|
||||||
|
|
||||||
|
# Calculate current tokens (with refill)
|
||||||
|
time_elapsed = current_time - bucket['last_refill']
|
||||||
|
tokens_to_add = time_elapsed * self.refill_rate
|
||||||
|
current_tokens = min(self.burst_size, bucket['tokens'] + tokens_to_add)
|
||||||
|
|
||||||
|
return current_tokens
|
||||||
|
|
||||||
|
def get_reset_time(self, client_id: str) -> float:
|
||||||
|
"""
|
||||||
|
Get time until bucket is fully refilled.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client_id: Client identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: Seconds until full refill
|
||||||
|
"""
|
||||||
|
remaining_tokens = self.get_remaining_tokens(client_id)
|
||||||
|
tokens_needed = self.burst_size - remaining_tokens
|
||||||
|
|
||||||
|
if tokens_needed <= 0:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
return tokens_needed / self.refill_rate
|
||||||
|
|
||||||
|
def get_client_stats(self, client_id: str) -> Dict[str, float]:
|
||||||
|
"""
|
||||||
|
Get statistics for a client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client_id: Client identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: Client statistics
|
||||||
|
"""
|
||||||
|
current_time = time.time()
|
||||||
|
history = self.request_history[client_id]
|
||||||
|
|
||||||
|
# Count requests in last minute
|
||||||
|
minute_ago = current_time - 60
|
||||||
|
recent_requests = sum(1 for req_time in history if req_time > minute_ago)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'remaining_tokens': self.get_remaining_tokens(client_id),
|
||||||
|
'reset_time': self.get_reset_time(client_id),
|
||||||
|
'requests_last_minute': recent_requests,
|
||||||
|
'total_requests': len(history)
|
||||||
|
}
|
||||||
|
|
||||||
|
def cleanup_old_data(self, max_age_hours: int = 24) -> None:
|
||||||
|
"""
|
||||||
|
Clean up old client data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_age_hours: Maximum age of data to keep
|
||||||
|
"""
|
||||||
|
current_time = time.time()
|
||||||
|
cutoff_time = current_time - (max_age_hours * 3600)
|
||||||
|
|
||||||
|
# Clean up buckets for inactive clients
|
||||||
|
inactive_clients = []
|
||||||
|
for client_id, bucket in self.buckets.items():
|
||||||
|
if bucket['last_refill'] < cutoff_time:
|
||||||
|
inactive_clients.append(client_id)
|
||||||
|
|
||||||
|
for client_id in inactive_clients:
|
||||||
|
del self.buckets[client_id]
|
||||||
|
if client_id in self.request_history:
|
||||||
|
del self.request_history[client_id]
|
||||||
|
|
||||||
|
logger.debug(f"Cleaned up {len(inactive_clients)} inactive clients")
|
||||||
|
|
||||||
|
def get_global_stats(self) -> Dict[str, int]:
|
||||||
|
"""Get global rate limiter statistics"""
|
||||||
|
current_time = time.time()
|
||||||
|
minute_ago = current_time - 60
|
||||||
|
|
||||||
|
total_clients = len(self.buckets)
|
||||||
|
active_clients = 0
|
||||||
|
total_requests_last_minute = 0
|
||||||
|
|
||||||
|
for client_id, history in self.request_history.items():
|
||||||
|
recent_requests = sum(1 for req_time in history if req_time > minute_ago)
|
||||||
|
if recent_requests > 0:
|
||||||
|
active_clients += 1
|
||||||
|
total_requests_last_minute += recent_requests
|
||||||
|
|
||||||
|
return {
|
||||||
|
'total_clients': total_clients,
|
||||||
|
'active_clients': active_clients,
|
||||||
|
'requests_per_minute_limit': self.requests_per_minute,
|
||||||
|
'burst_size': self.burst_size,
|
||||||
|
'total_requests_last_minute': total_requests_last_minute
|
||||||
|
}
|
326
COBY/api/response_formatter.py
Normal file
326
COBY/api/response_formatter.py
Normal file
@ -0,0 +1,326 @@
|
|||||||
|
"""
|
||||||
|
Response formatting for API endpoints.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any, Dict, Optional, List
|
||||||
|
from datetime import datetime
|
||||||
|
from ..utils.logging import get_logger
|
||||||
|
from ..utils.timing import get_current_timestamp
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseFormatter:
|
||||||
|
"""
|
||||||
|
Formats API responses with consistent structure and metadata.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize response formatter"""
|
||||||
|
self.responses_formatted = 0
|
||||||
|
logger.info("Response formatter initialized")
|
||||||
|
|
||||||
|
def success(self, data: Any, message: str = "Success",
|
||||||
|
metadata: Optional[Dict] = None) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Format successful response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Response data
|
||||||
|
message: Success message
|
||||||
|
metadata: Additional metadata
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: Formatted response
|
||||||
|
"""
|
||||||
|
response = {
|
||||||
|
'success': True,
|
||||||
|
'message': message,
|
||||||
|
'data': data,
|
||||||
|
'timestamp': get_current_timestamp().isoformat(),
|
||||||
|
'metadata': metadata or {}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.responses_formatted += 1
|
||||||
|
return response
|
||||||
|
|
||||||
|
def error(self, message: str, error_code: str = "UNKNOWN_ERROR",
|
||||||
|
details: Optional[Dict] = None, status_code: int = 400) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Format error response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Error message
|
||||||
|
error_code: Error code
|
||||||
|
details: Error details
|
||||||
|
status_code: HTTP status code
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: Formatted error response
|
||||||
|
"""
|
||||||
|
response = {
|
||||||
|
'success': False,
|
||||||
|
'error': {
|
||||||
|
'message': message,
|
||||||
|
'code': error_code,
|
||||||
|
'details': details or {},
|
||||||
|
'status_code': status_code
|
||||||
|
},
|
||||||
|
'timestamp': get_current_timestamp().isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
self.responses_formatted += 1
|
||||||
|
return response
|
||||||
|
|
||||||
|
def paginated(self, data: List[Any], page: int, page_size: int,
|
||||||
|
total_count: int, message: str = "Success") -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Format paginated response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Page data
|
||||||
|
page: Current page number
|
||||||
|
page_size: Items per page
|
||||||
|
total_count: Total number of items
|
||||||
|
message: Success message
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: Formatted paginated response
|
||||||
|
"""
|
||||||
|
total_pages = (total_count + page_size - 1) // page_size
|
||||||
|
|
||||||
|
pagination = {
|
||||||
|
'page': page,
|
||||||
|
'page_size': page_size,
|
||||||
|
'total_count': total_count,
|
||||||
|
'total_pages': total_pages,
|
||||||
|
'has_next': page < total_pages,
|
||||||
|
'has_previous': page > 1
|
||||||
|
}
|
||||||
|
|
||||||
|
return self.success(
|
||||||
|
data=data,
|
||||||
|
message=message,
|
||||||
|
metadata={'pagination': pagination}
|
||||||
|
)
|
||||||
|
|
||||||
|
def heatmap_response(self, heatmap_data, symbol: str,
|
||||||
|
exchange: Optional[str] = None) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Format heatmap data response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
heatmap_data: Heatmap data
|
||||||
|
symbol: Trading symbol
|
||||||
|
exchange: Exchange name (None for consolidated)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: Formatted heatmap response
|
||||||
|
"""
|
||||||
|
if not heatmap_data:
|
||||||
|
return self.error("Heatmap data not found", "HEATMAP_NOT_FOUND", status_code=404)
|
||||||
|
|
||||||
|
# Convert heatmap to API format
|
||||||
|
formatted_data = {
|
||||||
|
'symbol': heatmap_data.symbol,
|
||||||
|
'timestamp': heatmap_data.timestamp.isoformat(),
|
||||||
|
'bucket_size': heatmap_data.bucket_size,
|
||||||
|
'exchange': exchange,
|
||||||
|
'points': [
|
||||||
|
{
|
||||||
|
'price': point.price,
|
||||||
|
'volume': point.volume,
|
||||||
|
'intensity': point.intensity,
|
||||||
|
'side': point.side
|
||||||
|
}
|
||||||
|
for point in heatmap_data.data
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
'total_points': len(heatmap_data.data),
|
||||||
|
'bid_points': len([p for p in heatmap_data.data if p.side == 'bid']),
|
||||||
|
'ask_points': len([p for p in heatmap_data.data if p.side == 'ask']),
|
||||||
|
'data_type': 'consolidated' if not exchange else 'exchange_specific'
|
||||||
|
}
|
||||||
|
|
||||||
|
return self.success(
|
||||||
|
data=formatted_data,
|
||||||
|
message=f"Heatmap data for {symbol}",
|
||||||
|
metadata=metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
def orderbook_response(self, orderbook_data, symbol: str, exchange: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Format order book response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
orderbook_data: Order book data
|
||||||
|
symbol: Trading symbol
|
||||||
|
exchange: Exchange name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: Formatted order book response
|
||||||
|
"""
|
||||||
|
if not orderbook_data:
|
||||||
|
return self.error("Order book not found", "ORDERBOOK_NOT_FOUND", status_code=404)
|
||||||
|
|
||||||
|
# Convert order book to API format
|
||||||
|
formatted_data = {
|
||||||
|
'symbol': orderbook_data.symbol,
|
||||||
|
'exchange': orderbook_data.exchange,
|
||||||
|
'timestamp': orderbook_data.timestamp.isoformat(),
|
||||||
|
'sequence_id': orderbook_data.sequence_id,
|
||||||
|
'bids': [
|
||||||
|
{
|
||||||
|
'price': bid.price,
|
||||||
|
'size': bid.size,
|
||||||
|
'count': bid.count
|
||||||
|
}
|
||||||
|
for bid in orderbook_data.bids
|
||||||
|
],
|
||||||
|
'asks': [
|
||||||
|
{
|
||||||
|
'price': ask.price,
|
||||||
|
'size': ask.size,
|
||||||
|
'count': ask.count
|
||||||
|
}
|
||||||
|
for ask in orderbook_data.asks
|
||||||
|
],
|
||||||
|
'mid_price': orderbook_data.mid_price,
|
||||||
|
'spread': orderbook_data.spread,
|
||||||
|
'bid_volume': orderbook_data.bid_volume,
|
||||||
|
'ask_volume': orderbook_data.ask_volume
|
||||||
|
}
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
'bid_levels': len(orderbook_data.bids),
|
||||||
|
'ask_levels': len(orderbook_data.asks),
|
||||||
|
'total_bid_volume': orderbook_data.bid_volume,
|
||||||
|
'total_ask_volume': orderbook_data.ask_volume
|
||||||
|
}
|
||||||
|
|
||||||
|
return self.success(
|
||||||
|
data=formatted_data,
|
||||||
|
message=f"Order book for {symbol}@{exchange}",
|
||||||
|
metadata=metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
def metrics_response(self, metrics_data, symbol: str, exchange: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Format metrics response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metrics_data: Metrics data
|
||||||
|
symbol: Trading symbol
|
||||||
|
exchange: Exchange name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: Formatted metrics response
|
||||||
|
"""
|
||||||
|
if not metrics_data:
|
||||||
|
return self.error("Metrics not found", "METRICS_NOT_FOUND", status_code=404)
|
||||||
|
|
||||||
|
# Convert metrics to API format
|
||||||
|
formatted_data = {
|
||||||
|
'symbol': metrics_data.symbol,
|
||||||
|
'exchange': metrics_data.exchange,
|
||||||
|
'timestamp': metrics_data.timestamp.isoformat(),
|
||||||
|
'mid_price': metrics_data.mid_price,
|
||||||
|
'spread': metrics_data.spread,
|
||||||
|
'spread_percentage': metrics_data.spread_percentage,
|
||||||
|
'bid_volume': metrics_data.bid_volume,
|
||||||
|
'ask_volume': metrics_data.ask_volume,
|
||||||
|
'volume_imbalance': metrics_data.volume_imbalance,
|
||||||
|
'depth_10': metrics_data.depth_10,
|
||||||
|
'depth_50': metrics_data.depth_50
|
||||||
|
}
|
||||||
|
|
||||||
|
return self.success(
|
||||||
|
data=formatted_data,
|
||||||
|
message=f"Metrics for {symbol}@{exchange}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def status_response(self, status_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Format system status response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
status_data: System status data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: Formatted status response
|
||||||
|
"""
|
||||||
|
return self.success(
|
||||||
|
data=status_data,
|
||||||
|
message="System status",
|
||||||
|
metadata={'response_count': self.responses_formatted}
|
||||||
|
)
|
||||||
|
|
||||||
|
def rate_limit_error(self, client_stats: Dict[str, float]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Format rate limit error response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client_stats: Client rate limit statistics
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: Formatted rate limit error
|
||||||
|
"""
|
||||||
|
return self.error(
|
||||||
|
message="Rate limit exceeded",
|
||||||
|
error_code="RATE_LIMIT_EXCEEDED",
|
||||||
|
details={
|
||||||
|
'remaining_tokens': client_stats['remaining_tokens'],
|
||||||
|
'reset_time': client_stats['reset_time'],
|
||||||
|
'requests_last_minute': client_stats['requests_last_minute']
|
||||||
|
},
|
||||||
|
status_code=429
|
||||||
|
)
|
||||||
|
|
||||||
|
def validation_error(self, field: str, message: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Format validation error response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field: Field that failed validation
|
||||||
|
message: Validation error message
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: Formatted validation error
|
||||||
|
"""
|
||||||
|
return self.error(
|
||||||
|
message=f"Validation error: {message}",
|
||||||
|
error_code="VALIDATION_ERROR",
|
||||||
|
details={'field': field, 'message': message},
|
||||||
|
status_code=400
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_json(self, response: Dict[str, Any], indent: Optional[int] = None) -> str:
|
||||||
|
"""
|
||||||
|
Convert response to JSON string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: Response dictionary
|
||||||
|
indent: JSON indentation (None for compact)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: JSON string
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return json.dumps(response, indent=indent, ensure_ascii=False, default=str)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error converting response to JSON: {e}")
|
||||||
|
return json.dumps(self.error("JSON serialization failed", "JSON_ERROR"))
|
||||||
|
|
||||||
|
def get_stats(self) -> Dict[str, int]:
|
||||||
|
"""Get formatter statistics"""
|
||||||
|
return {
|
||||||
|
'responses_formatted': self.responses_formatted
|
||||||
|
}
|
||||||
|
|
||||||
|
def reset_stats(self) -> None:
|
||||||
|
"""Reset formatter statistics"""
|
||||||
|
self.responses_formatted = 0
|
||||||
|
logger.info("Response formatter statistics reset")
|
391
COBY/api/rest_api.py
Normal file
391
COBY/api/rest_api.py
Normal file
@ -0,0 +1,391 @@
|
|||||||
|
"""
|
||||||
|
REST API server for COBY system.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fastapi import FastAPI, HTTPException, Request, Query, Path
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from typing import Optional, List
|
||||||
|
import asyncio
|
||||||
|
from ..config import config
|
||||||
|
from ..caching.redis_manager import redis_manager
|
||||||
|
from ..utils.logging import get_logger, set_correlation_id
|
||||||
|
from ..utils.validation import validate_symbol
|
||||||
|
from .rate_limiter import RateLimiter
|
||||||
|
from .response_formatter import ResponseFormatter
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def create_app() -> FastAPI:
|
||||||
|
"""Create and configure FastAPI application"""
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="COBY Market Data API",
|
||||||
|
description="Real-time cryptocurrency market data aggregation API",
|
||||||
|
version="1.0.0",
|
||||||
|
docs_url="/docs",
|
||||||
|
redoc_url="/redoc"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add CORS middleware
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=config.api.cors_origins,
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["GET", "POST", "PUT", "DELETE"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize components
|
||||||
|
rate_limiter = RateLimiter(
|
||||||
|
requests_per_minute=config.api.rate_limit,
|
||||||
|
burst_size=20
|
||||||
|
)
|
||||||
|
response_formatter = ResponseFormatter()
|
||||||
|
|
||||||
|
@app.middleware("http")
|
||||||
|
async def rate_limit_middleware(request: Request, call_next):
|
||||||
|
"""Rate limiting middleware"""
|
||||||
|
client_ip = request.client.host
|
||||||
|
|
||||||
|
if not rate_limiter.is_allowed(client_ip):
|
||||||
|
client_stats = rate_limiter.get_client_stats(client_ip)
|
||||||
|
error_response = response_formatter.rate_limit_error(client_stats)
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=429,
|
||||||
|
content=error_response,
|
||||||
|
headers={
|
||||||
|
"X-RateLimit-Remaining": str(int(client_stats['remaining_tokens'])),
|
||||||
|
"X-RateLimit-Reset": str(int(client_stats['reset_time']))
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await call_next(request)
|
||||||
|
|
||||||
|
# Add rate limit headers
|
||||||
|
client_stats = rate_limiter.get_client_stats(client_ip)
|
||||||
|
response.headers["X-RateLimit-Remaining"] = str(int(client_stats['remaining_tokens']))
|
||||||
|
response.headers["X-RateLimit-Reset"] = str(int(client_stats['reset_time']))
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
@app.middleware("http")
|
||||||
|
async def correlation_middleware(request: Request, call_next):
|
||||||
|
"""Add correlation ID to requests"""
|
||||||
|
set_correlation_id()
|
||||||
|
response = await call_next(request)
|
||||||
|
return response
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
async def startup_event():
|
||||||
|
"""Initialize services on startup"""
|
||||||
|
try:
|
||||||
|
await redis_manager.initialize()
|
||||||
|
logger.info("API server startup completed")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"API server startup failed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
@app.on_event("shutdown")
|
||||||
|
async def shutdown_event():
|
||||||
|
"""Cleanup on shutdown"""
|
||||||
|
try:
|
||||||
|
await redis_manager.close()
|
||||||
|
logger.info("API server shutdown completed")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"API server shutdown error: {e}")
|
||||||
|
|
||||||
|
# Health check endpoint
|
||||||
|
@app.get("/health")
|
||||||
|
async def health_check():
|
||||||
|
"""Health check endpoint"""
|
||||||
|
try:
|
||||||
|
# Check Redis connection
|
||||||
|
redis_healthy = await redis_manager.ping()
|
||||||
|
|
||||||
|
health_data = {
|
||||||
|
'status': 'healthy' if redis_healthy else 'degraded',
|
||||||
|
'redis': 'connected' if redis_healthy else 'disconnected',
|
||||||
|
'version': '1.0.0'
|
||||||
|
}
|
||||||
|
|
||||||
|
return response_formatter.status_response(health_data)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Health check failed: {e}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=503,
|
||||||
|
content=response_formatter.error("Service unavailable", "HEALTH_CHECK_FAILED")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Heatmap endpoints
|
||||||
|
@app.get("/api/v1/heatmap/{symbol}")
|
||||||
|
async def get_heatmap(
|
||||||
|
symbol: str = Path(..., description="Trading symbol (e.g., BTCUSDT)"),
|
||||||
|
exchange: Optional[str] = Query(None, description="Exchange name (None for consolidated)")
|
||||||
|
):
|
||||||
|
"""Get heatmap data for a symbol"""
|
||||||
|
try:
|
||||||
|
# Validate symbol
|
||||||
|
if not validate_symbol(symbol):
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=400,
|
||||||
|
content=response_formatter.validation_error("symbol", "Invalid symbol format")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get heatmap from cache
|
||||||
|
heatmap_data = await redis_manager.get_heatmap(symbol.upper(), exchange)
|
||||||
|
|
||||||
|
return response_formatter.heatmap_response(heatmap_data, symbol.upper(), exchange)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting heatmap for {symbol}: {e}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content=response_formatter.error("Internal server error", "HEATMAP_ERROR")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Order book endpoints
|
||||||
|
@app.get("/api/v1/orderbook/{symbol}/{exchange}")
|
||||||
|
async def get_orderbook(
|
||||||
|
symbol: str = Path(..., description="Trading symbol"),
|
||||||
|
exchange: str = Path(..., description="Exchange name")
|
||||||
|
):
|
||||||
|
"""Get order book data for a symbol on an exchange"""
|
||||||
|
try:
|
||||||
|
# Validate symbol
|
||||||
|
if not validate_symbol(symbol):
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=400,
|
||||||
|
content=response_formatter.validation_error("symbol", "Invalid symbol format")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get order book from cache
|
||||||
|
orderbook_data = await redis_manager.get_orderbook(symbol.upper(), exchange.lower())
|
||||||
|
|
||||||
|
return response_formatter.orderbook_response(orderbook_data, symbol.upper(), exchange.lower())
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting order book for {symbol}@{exchange}: {e}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content=response_formatter.error("Internal server error", "ORDERBOOK_ERROR")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Metrics endpoints
|
||||||
|
@app.get("/api/v1/metrics/{symbol}/{exchange}")
|
||||||
|
async def get_metrics(
|
||||||
|
symbol: str = Path(..., description="Trading symbol"),
|
||||||
|
exchange: str = Path(..., description="Exchange name")
|
||||||
|
):
|
||||||
|
"""Get metrics data for a symbol on an exchange"""
|
||||||
|
try:
|
||||||
|
# Validate symbol
|
||||||
|
if not validate_symbol(symbol):
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=400,
|
||||||
|
content=response_formatter.validation_error("symbol", "Invalid symbol format")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get metrics from cache
|
||||||
|
metrics_data = await redis_manager.get_metrics(symbol.upper(), exchange.lower())
|
||||||
|
|
||||||
|
return response_formatter.metrics_response(metrics_data, symbol.upper(), exchange.lower())
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting metrics for {symbol}@{exchange}: {e}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content=response_formatter.error("Internal server error", "METRICS_ERROR")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Exchange status endpoints
|
||||||
|
@app.get("/api/v1/status/{exchange}")
|
||||||
|
async def get_exchange_status(
|
||||||
|
exchange: str = Path(..., description="Exchange name")
|
||||||
|
):
|
||||||
|
"""Get status for an exchange"""
|
||||||
|
try:
|
||||||
|
# Get status from cache
|
||||||
|
status_data = await redis_manager.get_exchange_status(exchange.lower())
|
||||||
|
|
||||||
|
if not status_data:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=404,
|
||||||
|
content=response_formatter.error("Exchange status not found", "STATUS_NOT_FOUND")
|
||||||
|
)
|
||||||
|
|
||||||
|
return response_formatter.success(
|
||||||
|
data=status_data,
|
||||||
|
message=f"Status for {exchange}"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting status for {exchange}: {e}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content=response_formatter.error("Internal server error", "STATUS_ERROR")
|
||||||
|
)
|
||||||
|
|
||||||
|
# List endpoints
|
||||||
|
@app.get("/api/v1/symbols")
|
||||||
|
async def list_symbols():
|
||||||
|
"""List available trading symbols"""
|
||||||
|
try:
|
||||||
|
# Get symbols from cache (this would be populated by exchange connectors)
|
||||||
|
symbols_pattern = "symbols:*"
|
||||||
|
symbol_keys = await redis_manager.keys(symbols_pattern)
|
||||||
|
|
||||||
|
all_symbols = set()
|
||||||
|
for key in symbol_keys:
|
||||||
|
symbols_data = await redis_manager.get(key)
|
||||||
|
if symbols_data and isinstance(symbols_data, list):
|
||||||
|
all_symbols.update(symbols_data)
|
||||||
|
|
||||||
|
return response_formatter.success(
|
||||||
|
data=sorted(list(all_symbols)),
|
||||||
|
message="Available trading symbols",
|
||||||
|
metadata={'total_symbols': len(all_symbols)}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error listing symbols: {e}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content=response_formatter.error("Internal server error", "SYMBOLS_ERROR")
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.get("/api/v1/exchanges")
|
||||||
|
async def list_exchanges():
|
||||||
|
"""List available exchanges"""
|
||||||
|
try:
|
||||||
|
# Get exchange status keys
|
||||||
|
status_pattern = "st:*"
|
||||||
|
status_keys = await redis_manager.keys(status_pattern)
|
||||||
|
|
||||||
|
exchanges = []
|
||||||
|
for key in status_keys:
|
||||||
|
# Extract exchange name from key (st:exchange_name)
|
||||||
|
exchange_name = key.split(':', 1)[1] if ':' in key else key
|
||||||
|
exchanges.append(exchange_name)
|
||||||
|
|
||||||
|
return response_formatter.success(
|
||||||
|
data=sorted(exchanges),
|
||||||
|
message="Available exchanges",
|
||||||
|
metadata={'total_exchanges': len(exchanges)}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error listing exchanges: {e}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content=response_formatter.error("Internal server error", "EXCHANGES_ERROR")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Statistics endpoints
|
||||||
|
@app.get("/api/v1/stats/cache")
|
||||||
|
async def get_cache_stats():
|
||||||
|
"""Get cache statistics"""
|
||||||
|
try:
|
||||||
|
cache_stats = redis_manager.get_stats()
|
||||||
|
redis_health = await redis_manager.health_check()
|
||||||
|
|
||||||
|
stats_data = {
|
||||||
|
'cache_performance': cache_stats,
|
||||||
|
'redis_health': redis_health
|
||||||
|
}
|
||||||
|
|
||||||
|
return response_formatter.success(
|
||||||
|
data=stats_data,
|
||||||
|
message="Cache statistics"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting cache stats: {e}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content=response_formatter.error("Internal server error", "STATS_ERROR")
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.get("/api/v1/stats/api")
|
||||||
|
async def get_api_stats():
|
||||||
|
"""Get API statistics"""
|
||||||
|
try:
|
||||||
|
api_stats = {
|
||||||
|
'rate_limiter': rate_limiter.get_global_stats(),
|
||||||
|
'response_formatter': response_formatter.get_stats()
|
||||||
|
}
|
||||||
|
|
||||||
|
return response_formatter.success(
|
||||||
|
data=api_stats,
|
||||||
|
message="API statistics"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting API stats: {e}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content=response_formatter.error("Internal server error", "API_STATS_ERROR")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Batch endpoints for efficiency
|
||||||
|
@app.get("/api/v1/batch/heatmaps")
|
||||||
|
async def get_batch_heatmaps(
|
||||||
|
symbols: str = Query(..., description="Comma-separated list of symbols"),
|
||||||
|
exchange: Optional[str] = Query(None, description="Exchange name (None for consolidated)")
|
||||||
|
):
|
||||||
|
"""Get heatmaps for multiple symbols"""
|
||||||
|
try:
|
||||||
|
symbol_list = [s.strip().upper() for s in symbols.split(',')]
|
||||||
|
|
||||||
|
# Validate all symbols
|
||||||
|
for symbol in symbol_list:
|
||||||
|
if not validate_symbol(symbol):
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=400,
|
||||||
|
content=response_formatter.validation_error("symbols", f"Invalid symbol: {symbol}")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get heatmaps in batch
|
||||||
|
heatmaps = {}
|
||||||
|
for symbol in symbol_list:
|
||||||
|
heatmap_data = await redis_manager.get_heatmap(symbol, exchange)
|
||||||
|
if heatmap_data:
|
||||||
|
heatmaps[symbol] = {
|
||||||
|
'symbol': heatmap_data.symbol,
|
||||||
|
'timestamp': heatmap_data.timestamp.isoformat(),
|
||||||
|
'bucket_size': heatmap_data.bucket_size,
|
||||||
|
'points': [
|
||||||
|
{
|
||||||
|
'price': point.price,
|
||||||
|
'volume': point.volume,
|
||||||
|
'intensity': point.intensity,
|
||||||
|
'side': point.side
|
||||||
|
}
|
||||||
|
for point in heatmap_data.data
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
return response_formatter.success(
|
||||||
|
data=heatmaps,
|
||||||
|
message=f"Batch heatmaps for {len(symbol_list)} symbols",
|
||||||
|
metadata={
|
||||||
|
'requested_symbols': len(symbol_list),
|
||||||
|
'found_heatmaps': len(heatmaps),
|
||||||
|
'exchange': exchange or 'consolidated'
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting batch heatmaps: {e}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content=response_formatter.error("Internal server error", "BATCH_HEATMAPS_ERROR")
|
||||||
|
)
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
# Create the FastAPI app instance
|
||||||
|
app = create_app()
|
400
COBY/api/websocket_server.py
Normal file
400
COBY/api/websocket_server.py
Normal file
@ -0,0 +1,400 @@
|
|||||||
|
"""
|
||||||
|
WebSocket server for real-time data streaming.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from typing import Dict, Set, Optional, Any
|
||||||
|
from fastapi import WebSocket, WebSocketDisconnect
|
||||||
|
from ..utils.logging import get_logger, set_correlation_id
|
||||||
|
from ..utils.validation import validate_symbol
|
||||||
|
from ..caching.redis_manager import redis_manager
|
||||||
|
from .response_formatter import ResponseFormatter
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class WebSocketManager:
|
||||||
|
"""
|
||||||
|
Manages WebSocket connections and real-time data streaming.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize WebSocket manager"""
|
||||||
|
# Active connections: connection_id -> WebSocket
|
||||||
|
self.connections: Dict[str, WebSocket] = {}
|
||||||
|
|
||||||
|
# Subscriptions: symbol -> set of connection_ids
|
||||||
|
self.subscriptions: Dict[str, Set[str]] = {}
|
||||||
|
|
||||||
|
# Connection metadata: connection_id -> metadata
|
||||||
|
self.connection_metadata: Dict[str, Dict[str, Any]] = {}
|
||||||
|
|
||||||
|
self.response_formatter = ResponseFormatter()
|
||||||
|
self.connection_counter = 0
|
||||||
|
|
||||||
|
logger.info("WebSocket manager initialized")
|
||||||
|
|
||||||
|
async def connect(self, websocket: WebSocket, client_ip: str) -> str:
|
||||||
|
"""
|
||||||
|
Accept new WebSocket connection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
websocket: WebSocket connection
|
||||||
|
client_ip: Client IP address
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Connection ID
|
||||||
|
"""
|
||||||
|
await websocket.accept()
|
||||||
|
|
||||||
|
# Generate connection ID
|
||||||
|
self.connection_counter += 1
|
||||||
|
connection_id = f"ws_{self.connection_counter}_{client_ip}"
|
||||||
|
|
||||||
|
# Store connection
|
||||||
|
self.connections[connection_id] = websocket
|
||||||
|
self.connection_metadata[connection_id] = {
|
||||||
|
'client_ip': client_ip,
|
||||||
|
'connected_at': asyncio.get_event_loop().time(),
|
||||||
|
'subscriptions': set(),
|
||||||
|
'messages_sent': 0
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"WebSocket connected: {connection_id}")
|
||||||
|
|
||||||
|
# Send welcome message
|
||||||
|
welcome_msg = self.response_formatter.success(
|
||||||
|
data={'connection_id': connection_id},
|
||||||
|
message="WebSocket connected successfully"
|
||||||
|
)
|
||||||
|
await self._send_to_connection(connection_id, welcome_msg)
|
||||||
|
|
||||||
|
return connection_id
|
||||||
|
|
||||||
|
async def disconnect(self, connection_id: str) -> None:
|
||||||
|
"""
|
||||||
|
Handle WebSocket disconnection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection_id: Connection ID to disconnect
|
||||||
|
"""
|
||||||
|
if connection_id in self.connections:
|
||||||
|
# Remove from all subscriptions
|
||||||
|
metadata = self.connection_metadata.get(connection_id, {})
|
||||||
|
for symbol in metadata.get('subscriptions', set()):
|
||||||
|
await self._unsubscribe_connection(connection_id, symbol)
|
||||||
|
|
||||||
|
# Remove connection
|
||||||
|
del self.connections[connection_id]
|
||||||
|
del self.connection_metadata[connection_id]
|
||||||
|
|
||||||
|
logger.info(f"WebSocket disconnected: {connection_id}")
|
||||||
|
|
||||||
|
async def subscribe(self, connection_id: str, symbol: str,
|
||||||
|
data_type: str = "heatmap") -> bool:
|
||||||
|
"""
|
||||||
|
Subscribe connection to symbol updates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection_id: Connection ID
|
||||||
|
symbol: Trading symbol
|
||||||
|
data_type: Type of data to subscribe to
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if subscribed successfully
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Validate symbol
|
||||||
|
if not validate_symbol(symbol):
|
||||||
|
error_msg = self.response_formatter.validation_error("symbol", "Invalid symbol format")
|
||||||
|
await self._send_to_connection(connection_id, error_msg)
|
||||||
|
return False
|
||||||
|
|
||||||
|
symbol = symbol.upper()
|
||||||
|
subscription_key = f"{symbol}:{data_type}"
|
||||||
|
|
||||||
|
# Add to subscriptions
|
||||||
|
if subscription_key not in self.subscriptions:
|
||||||
|
self.subscriptions[subscription_key] = set()
|
||||||
|
|
||||||
|
self.subscriptions[subscription_key].add(connection_id)
|
||||||
|
|
||||||
|
# Update connection metadata
|
||||||
|
if connection_id in self.connection_metadata:
|
||||||
|
self.connection_metadata[connection_id]['subscriptions'].add(subscription_key)
|
||||||
|
|
||||||
|
logger.info(f"WebSocket {connection_id} subscribed to {subscription_key}")
|
||||||
|
|
||||||
|
# Send confirmation
|
||||||
|
confirm_msg = self.response_formatter.success(
|
||||||
|
data={'symbol': symbol, 'data_type': data_type},
|
||||||
|
message=f"Subscribed to {symbol} {data_type} updates"
|
||||||
|
)
|
||||||
|
await self._send_to_connection(connection_id, confirm_msg)
|
||||||
|
|
||||||
|
# Send initial data if available
|
||||||
|
await self._send_initial_data(connection_id, symbol, data_type)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error subscribing {connection_id} to {symbol}: {e}")
|
||||||
|
error_msg = self.response_formatter.error("Subscription failed", "SUBSCRIBE_ERROR")
|
||||||
|
await self._send_to_connection(connection_id, error_msg)
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def unsubscribe(self, connection_id: str, symbol: str,
|
||||||
|
data_type: str = "heatmap") -> bool:
|
||||||
|
"""
|
||||||
|
Unsubscribe connection from symbol updates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection_id: Connection ID
|
||||||
|
symbol: Trading symbol
|
||||||
|
data_type: Type of data to unsubscribe from
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if unsubscribed successfully
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
symbol = symbol.upper()
|
||||||
|
subscription_key = f"{symbol}:{data_type}"
|
||||||
|
|
||||||
|
await self._unsubscribe_connection(connection_id, subscription_key)
|
||||||
|
|
||||||
|
# Send confirmation
|
||||||
|
confirm_msg = self.response_formatter.success(
|
||||||
|
data={'symbol': symbol, 'data_type': data_type},
|
||||||
|
message=f"Unsubscribed from {symbol} {data_type} updates"
|
||||||
|
)
|
||||||
|
await self._send_to_connection(connection_id, confirm_msg)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error unsubscribing {connection_id} from {symbol}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def broadcast_update(self, symbol: str, data_type: str, data: Any) -> int:
|
||||||
|
"""
|
||||||
|
Broadcast data update to all subscribers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: Trading symbol
|
||||||
|
data_type: Type of data
|
||||||
|
data: Data to broadcast
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Number of connections notified
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
set_correlation_id()
|
||||||
|
|
||||||
|
subscription_key = f"{symbol.upper()}:{data_type}"
|
||||||
|
subscribers = self.subscriptions.get(subscription_key, set())
|
||||||
|
|
||||||
|
if not subscribers:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# Format message based on data type
|
||||||
|
if data_type == "heatmap":
|
||||||
|
message = self.response_formatter.heatmap_response(data, symbol)
|
||||||
|
elif data_type == "orderbook":
|
||||||
|
message = self.response_formatter.orderbook_response(data, symbol, "consolidated")
|
||||||
|
else:
|
||||||
|
message = self.response_formatter.success(data, f"{data_type} update for {symbol}")
|
||||||
|
|
||||||
|
# Add update type to message
|
||||||
|
message['update_type'] = data_type
|
||||||
|
message['symbol'] = symbol
|
||||||
|
|
||||||
|
# Send to all subscribers
|
||||||
|
sent_count = 0
|
||||||
|
for connection_id in subscribers.copy(): # Copy to avoid modification during iteration
|
||||||
|
if await self._send_to_connection(connection_id, message):
|
||||||
|
sent_count += 1
|
||||||
|
|
||||||
|
logger.debug(f"Broadcasted {data_type} update for {symbol} to {sent_count} connections")
|
||||||
|
return sent_count
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error broadcasting update for {symbol}: {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
async def _send_to_connection(self, connection_id: str, message: Dict[str, Any]) -> bool:
|
||||||
|
"""
|
||||||
|
Send message to specific connection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection_id: Connection ID
|
||||||
|
message: Message to send
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if sent successfully
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if connection_id not in self.connections:
|
||||||
|
return False
|
||||||
|
|
||||||
|
websocket = self.connections[connection_id]
|
||||||
|
message_json = json.dumps(message, default=str)
|
||||||
|
|
||||||
|
await websocket.send_text(message_json)
|
||||||
|
|
||||||
|
# Update statistics
|
||||||
|
if connection_id in self.connection_metadata:
|
||||||
|
self.connection_metadata[connection_id]['messages_sent'] += 1
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error sending message to {connection_id}: {e}")
|
||||||
|
# Remove broken connection
|
||||||
|
await self.disconnect(connection_id)
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _unsubscribe_connection(self, connection_id: str, subscription_key: str) -> None:
|
||||||
|
"""Remove connection from subscription"""
|
||||||
|
if subscription_key in self.subscriptions:
|
||||||
|
self.subscriptions[subscription_key].discard(connection_id)
|
||||||
|
|
||||||
|
# Clean up empty subscriptions
|
||||||
|
if not self.subscriptions[subscription_key]:
|
||||||
|
del self.subscriptions[subscription_key]
|
||||||
|
|
||||||
|
# Update connection metadata
|
||||||
|
if connection_id in self.connection_metadata:
|
||||||
|
self.connection_metadata[connection_id]['subscriptions'].discard(subscription_key)
|
||||||
|
|
||||||
|
async def _send_initial_data(self, connection_id: str, symbol: str, data_type: str) -> None:
|
||||||
|
"""Send initial data to newly subscribed connection"""
|
||||||
|
try:
|
||||||
|
if data_type == "heatmap":
|
||||||
|
# Get latest heatmap from cache
|
||||||
|
heatmap_data = await redis_manager.get_heatmap(symbol)
|
||||||
|
if heatmap_data:
|
||||||
|
message = self.response_formatter.heatmap_response(heatmap_data, symbol)
|
||||||
|
message['update_type'] = 'initial_data'
|
||||||
|
await self._send_to_connection(connection_id, message)
|
||||||
|
|
||||||
|
elif data_type == "orderbook":
|
||||||
|
# Could get latest order book from cache
|
||||||
|
# This would require knowing which exchange to get data from
|
||||||
|
pass
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error sending initial data to {connection_id}: {e}")
|
||||||
|
|
||||||
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
|
"""Get WebSocket manager statistics"""
|
||||||
|
total_subscriptions = sum(len(subs) for subs in self.subscriptions.values())
|
||||||
|
|
||||||
|
return {
|
||||||
|
'active_connections': len(self.connections),
|
||||||
|
'total_subscriptions': total_subscriptions,
|
||||||
|
'unique_symbols': len(set(key.split(':')[0] for key in self.subscriptions.keys())),
|
||||||
|
'connection_counter': self.connection_counter
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Global WebSocket manager instance
|
||||||
|
websocket_manager = WebSocketManager()
|
||||||
|
|
||||||
|
|
||||||
|
class WebSocketServer:
|
||||||
|
"""
|
||||||
|
WebSocket server for real-time data streaming.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize WebSocket server"""
|
||||||
|
self.manager = websocket_manager
|
||||||
|
logger.info("WebSocket server initialized")
|
||||||
|
|
||||||
|
async def handle_connection(self, websocket: WebSocket, client_ip: str) -> None:
|
||||||
|
"""
|
||||||
|
Handle WebSocket connection lifecycle.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
websocket: WebSocket connection
|
||||||
|
client_ip: Client IP address
|
||||||
|
"""
|
||||||
|
connection_id = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Accept connection
|
||||||
|
connection_id = await self.manager.connect(websocket, client_ip)
|
||||||
|
|
||||||
|
# Handle messages
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
# Receive message
|
||||||
|
message = await websocket.receive_text()
|
||||||
|
await self._handle_message(connection_id, message)
|
||||||
|
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
logger.info(f"WebSocket client disconnected: {connection_id}")
|
||||||
|
break
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"WebSocket connection error: {e}")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Clean up connection
|
||||||
|
if connection_id:
|
||||||
|
await self.manager.disconnect(connection_id)
|
||||||
|
|
||||||
|
async def _handle_message(self, connection_id: str, message: str) -> None:
|
||||||
|
"""
|
||||||
|
Handle incoming WebSocket message.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection_id: Connection ID
|
||||||
|
message: Received message
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Parse message
|
||||||
|
data = json.loads(message)
|
||||||
|
action = data.get('action')
|
||||||
|
|
||||||
|
if action == 'subscribe':
|
||||||
|
symbol = data.get('symbol')
|
||||||
|
data_type = data.get('data_type', 'heatmap')
|
||||||
|
await self.manager.subscribe(connection_id, symbol, data_type)
|
||||||
|
|
||||||
|
elif action == 'unsubscribe':
|
||||||
|
symbol = data.get('symbol')
|
||||||
|
data_type = data.get('data_type', 'heatmap')
|
||||||
|
await self.manager.unsubscribe(connection_id, symbol, data_type)
|
||||||
|
|
||||||
|
elif action == 'ping':
|
||||||
|
# Send pong response
|
||||||
|
pong_msg = self.manager.response_formatter.success(
|
||||||
|
data={'action': 'pong'},
|
||||||
|
message="Pong"
|
||||||
|
)
|
||||||
|
await self.manager._send_to_connection(connection_id, pong_msg)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Unknown action
|
||||||
|
error_msg = self.manager.response_formatter.error(
|
||||||
|
f"Unknown action: {action}",
|
||||||
|
"UNKNOWN_ACTION"
|
||||||
|
)
|
||||||
|
await self.manager._send_to_connection(connection_id, error_msg)
|
||||||
|
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
error_msg = self.manager.response_formatter.error(
|
||||||
|
"Invalid JSON message",
|
||||||
|
"INVALID_JSON"
|
||||||
|
)
|
||||||
|
await self.manager._send_to_connection(connection_id, error_msg)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error handling WebSocket message: {e}")
|
||||||
|
error_msg = self.manager.response_formatter.error(
|
||||||
|
"Message processing failed",
|
||||||
|
"MESSAGE_ERROR"
|
||||||
|
)
|
||||||
|
await self.manager._send_to_connection(connection_id, error_msg)
|
Reference in New Issue
Block a user