From fd6ec4eb40eddb53b14580216538c6210d64ce26 Mon Sep 17 00:00:00 2001 From: Dobromir Popov Date: Mon, 4 Aug 2025 18:38:51 +0300 Subject: [PATCH] api --- .../multi-exchange-data-aggregation/tasks.md | 3 + COBY/api/__init__.py | 15 + COBY/api/rate_limiter.py | 183 ++++++++ COBY/api/response_formatter.py | 326 ++++++++++++++ COBY/api/rest_api.py | 391 +++++++++++++++++ COBY/api/websocket_server.py | 400 ++++++++++++++++++ 6 files changed, 1318 insertions(+) create mode 100644 COBY/api/__init__.py create mode 100644 COBY/api/rate_limiter.py create mode 100644 COBY/api/response_formatter.py create mode 100644 COBY/api/rest_api.py create mode 100644 COBY/api/websocket_server.py diff --git a/.kiro/specs/multi-exchange-data-aggregation/tasks.md b/.kiro/specs/multi-exchange-data-aggregation/tasks.md index 2d511e5..fc9e5a7 100644 --- a/.kiro/specs/multi-exchange-data-aggregation/tasks.md +++ b/.kiro/specs/multi-exchange-data-aggregation/tasks.md @@ -53,6 +53,9 @@ - Implement data processor for normalizing raw exchange data - Create validation logic for order book and trade data - Implement data quality checks and filtering + + + - Add metrics calculation for order book statistics - Write comprehensive unit tests for data processing logic - _Requirements: 1.4, 6.3, 8.1_ diff --git a/COBY/api/__init__.py b/COBY/api/__init__.py new file mode 100644 index 0000000..8e0d1c7 --- /dev/null +++ b/COBY/api/__init__.py @@ -0,0 +1,15 @@ +""" +API layer for the COBY system. +""" + +from .rest_api import create_app +from .websocket_server import WebSocketServer +from .rate_limiter import RateLimiter +from .response_formatter import ResponseFormatter + +__all__ = [ + 'create_app', + 'WebSocketServer', + 'RateLimiter', + 'ResponseFormatter' +] \ No newline at end of file diff --git a/COBY/api/rate_limiter.py b/COBY/api/rate_limiter.py new file mode 100644 index 0000000..52a629b --- /dev/null +++ b/COBY/api/rate_limiter.py @@ -0,0 +1,183 @@ +""" +Rate limiting for API endpoints. +""" + +import time +from typing import Dict, Optional +from collections import defaultdict, deque +from ..utils.logging import get_logger + +logger = get_logger(__name__) + + +class RateLimiter: + """ + Token bucket rate limiter for API endpoints. + + Provides per-client rate limiting with configurable limits. + """ + + def __init__(self, requests_per_minute: int = 100, burst_size: int = 20): + """ + Initialize rate limiter. + + Args: + requests_per_minute: Maximum requests per minute + burst_size: Maximum burst requests + """ + self.requests_per_minute = requests_per_minute + self.burst_size = burst_size + self.refill_rate = requests_per_minute / 60.0 # tokens per second + + # Client buckets: client_id -> {'tokens': float, 'last_refill': float} + self.buckets: Dict[str, Dict] = defaultdict(lambda: { + 'tokens': float(burst_size), + 'last_refill': time.time() + }) + + # Request history for monitoring + self.request_history: Dict[str, deque] = defaultdict(lambda: deque(maxlen=1000)) + + logger.info(f"Rate limiter initialized: {requests_per_minute} req/min, burst: {burst_size}") + + def is_allowed(self, client_id: str, tokens_requested: int = 1) -> bool: + """ + Check if request is allowed for client. + + Args: + client_id: Client identifier (IP, user ID, etc.) + tokens_requested: Number of tokens requested + + Returns: + bool: True if request is allowed, False otherwise + """ + current_time = time.time() + bucket = self.buckets[client_id] + + # Refill tokens based on time elapsed + time_elapsed = current_time - bucket['last_refill'] + tokens_to_add = time_elapsed * self.refill_rate + + # Update bucket + bucket['tokens'] = min(self.burst_size, bucket['tokens'] + tokens_to_add) + bucket['last_refill'] = current_time + + # Check if enough tokens available + if bucket['tokens'] >= tokens_requested: + bucket['tokens'] -= tokens_requested + + # Record successful request + self.request_history[client_id].append(current_time) + + return True + else: + logger.debug(f"Rate limit exceeded for client {client_id}") + return False + + def get_remaining_tokens(self, client_id: str) -> float: + """ + Get remaining tokens for client. + + Args: + client_id: Client identifier + + Returns: + float: Number of remaining tokens + """ + current_time = time.time() + bucket = self.buckets[client_id] + + # Calculate current tokens (with refill) + time_elapsed = current_time - bucket['last_refill'] + tokens_to_add = time_elapsed * self.refill_rate + current_tokens = min(self.burst_size, bucket['tokens'] + tokens_to_add) + + return current_tokens + + def get_reset_time(self, client_id: str) -> float: + """ + Get time until bucket is fully refilled. + + Args: + client_id: Client identifier + + Returns: + float: Seconds until full refill + """ + remaining_tokens = self.get_remaining_tokens(client_id) + tokens_needed = self.burst_size - remaining_tokens + + if tokens_needed <= 0: + return 0.0 + + return tokens_needed / self.refill_rate + + def get_client_stats(self, client_id: str) -> Dict[str, float]: + """ + Get statistics for a client. + + Args: + client_id: Client identifier + + Returns: + Dict: Client statistics + """ + current_time = time.time() + history = self.request_history[client_id] + + # Count requests in last minute + minute_ago = current_time - 60 + recent_requests = sum(1 for req_time in history if req_time > minute_ago) + + return { + 'remaining_tokens': self.get_remaining_tokens(client_id), + 'reset_time': self.get_reset_time(client_id), + 'requests_last_minute': recent_requests, + 'total_requests': len(history) + } + + def cleanup_old_data(self, max_age_hours: int = 24) -> None: + """ + Clean up old client data. + + Args: + max_age_hours: Maximum age of data to keep + """ + current_time = time.time() + cutoff_time = current_time - (max_age_hours * 3600) + + # Clean up buckets for inactive clients + inactive_clients = [] + for client_id, bucket in self.buckets.items(): + if bucket['last_refill'] < cutoff_time: + inactive_clients.append(client_id) + + for client_id in inactive_clients: + del self.buckets[client_id] + if client_id in self.request_history: + del self.request_history[client_id] + + logger.debug(f"Cleaned up {len(inactive_clients)} inactive clients") + + def get_global_stats(self) -> Dict[str, int]: + """Get global rate limiter statistics""" + current_time = time.time() + minute_ago = current_time - 60 + + total_clients = len(self.buckets) + active_clients = 0 + total_requests_last_minute = 0 + + for client_id, history in self.request_history.items(): + recent_requests = sum(1 for req_time in history if req_time > minute_ago) + if recent_requests > 0: + active_clients += 1 + total_requests_last_minute += recent_requests + + return { + 'total_clients': total_clients, + 'active_clients': active_clients, + 'requests_per_minute_limit': self.requests_per_minute, + 'burst_size': self.burst_size, + 'total_requests_last_minute': total_requests_last_minute + } \ No newline at end of file diff --git a/COBY/api/response_formatter.py b/COBY/api/response_formatter.py new file mode 100644 index 0000000..73374af --- /dev/null +++ b/COBY/api/response_formatter.py @@ -0,0 +1,326 @@ +""" +Response formatting for API endpoints. +""" + +import json +from typing import Any, Dict, Optional, List +from datetime import datetime +from ..utils.logging import get_logger +from ..utils.timing import get_current_timestamp + +logger = get_logger(__name__) + + +class ResponseFormatter: + """ + Formats API responses with consistent structure and metadata. + """ + + def __init__(self): + """Initialize response formatter""" + self.responses_formatted = 0 + logger.info("Response formatter initialized") + + def success(self, data: Any, message: str = "Success", + metadata: Optional[Dict] = None) -> Dict[str, Any]: + """ + Format successful response. + + Args: + data: Response data + message: Success message + metadata: Additional metadata + + Returns: + Dict: Formatted response + """ + response = { + 'success': True, + 'message': message, + 'data': data, + 'timestamp': get_current_timestamp().isoformat(), + 'metadata': metadata or {} + } + + self.responses_formatted += 1 + return response + + def error(self, message: str, error_code: str = "UNKNOWN_ERROR", + details: Optional[Dict] = None, status_code: int = 400) -> Dict[str, Any]: + """ + Format error response. + + Args: + message: Error message + error_code: Error code + details: Error details + status_code: HTTP status code + + Returns: + Dict: Formatted error response + """ + response = { + 'success': False, + 'error': { + 'message': message, + 'code': error_code, + 'details': details or {}, + 'status_code': status_code + }, + 'timestamp': get_current_timestamp().isoformat() + } + + self.responses_formatted += 1 + return response + + def paginated(self, data: List[Any], page: int, page_size: int, + total_count: int, message: str = "Success") -> Dict[str, Any]: + """ + Format paginated response. + + Args: + data: Page data + page: Current page number + page_size: Items per page + total_count: Total number of items + message: Success message + + Returns: + Dict: Formatted paginated response + """ + total_pages = (total_count + page_size - 1) // page_size + + pagination = { + 'page': page, + 'page_size': page_size, + 'total_count': total_count, + 'total_pages': total_pages, + 'has_next': page < total_pages, + 'has_previous': page > 1 + } + + return self.success( + data=data, + message=message, + metadata={'pagination': pagination} + ) + + def heatmap_response(self, heatmap_data, symbol: str, + exchange: Optional[str] = None) -> Dict[str, Any]: + """ + Format heatmap data response. + + Args: + heatmap_data: Heatmap data + symbol: Trading symbol + exchange: Exchange name (None for consolidated) + + Returns: + Dict: Formatted heatmap response + """ + if not heatmap_data: + return self.error("Heatmap data not found", "HEATMAP_NOT_FOUND", status_code=404) + + # Convert heatmap to API format + formatted_data = { + 'symbol': heatmap_data.symbol, + 'timestamp': heatmap_data.timestamp.isoformat(), + 'bucket_size': heatmap_data.bucket_size, + 'exchange': exchange, + 'points': [ + { + 'price': point.price, + 'volume': point.volume, + 'intensity': point.intensity, + 'side': point.side + } + for point in heatmap_data.data + ] + } + + metadata = { + 'total_points': len(heatmap_data.data), + 'bid_points': len([p for p in heatmap_data.data if p.side == 'bid']), + 'ask_points': len([p for p in heatmap_data.data if p.side == 'ask']), + 'data_type': 'consolidated' if not exchange else 'exchange_specific' + } + + return self.success( + data=formatted_data, + message=f"Heatmap data for {symbol}", + metadata=metadata + ) + + def orderbook_response(self, orderbook_data, symbol: str, exchange: str) -> Dict[str, Any]: + """ + Format order book response. + + Args: + orderbook_data: Order book data + symbol: Trading symbol + exchange: Exchange name + + Returns: + Dict: Formatted order book response + """ + if not orderbook_data: + return self.error("Order book not found", "ORDERBOOK_NOT_FOUND", status_code=404) + + # Convert order book to API format + formatted_data = { + 'symbol': orderbook_data.symbol, + 'exchange': orderbook_data.exchange, + 'timestamp': orderbook_data.timestamp.isoformat(), + 'sequence_id': orderbook_data.sequence_id, + 'bids': [ + { + 'price': bid.price, + 'size': bid.size, + 'count': bid.count + } + for bid in orderbook_data.bids + ], + 'asks': [ + { + 'price': ask.price, + 'size': ask.size, + 'count': ask.count + } + for ask in orderbook_data.asks + ], + 'mid_price': orderbook_data.mid_price, + 'spread': orderbook_data.spread, + 'bid_volume': orderbook_data.bid_volume, + 'ask_volume': orderbook_data.ask_volume + } + + metadata = { + 'bid_levels': len(orderbook_data.bids), + 'ask_levels': len(orderbook_data.asks), + 'total_bid_volume': orderbook_data.bid_volume, + 'total_ask_volume': orderbook_data.ask_volume + } + + return self.success( + data=formatted_data, + message=f"Order book for {symbol}@{exchange}", + metadata=metadata + ) + + def metrics_response(self, metrics_data, symbol: str, exchange: str) -> Dict[str, Any]: + """ + Format metrics response. + + Args: + metrics_data: Metrics data + symbol: Trading symbol + exchange: Exchange name + + Returns: + Dict: Formatted metrics response + """ + if not metrics_data: + return self.error("Metrics not found", "METRICS_NOT_FOUND", status_code=404) + + # Convert metrics to API format + formatted_data = { + 'symbol': metrics_data.symbol, + 'exchange': metrics_data.exchange, + 'timestamp': metrics_data.timestamp.isoformat(), + 'mid_price': metrics_data.mid_price, + 'spread': metrics_data.spread, + 'spread_percentage': metrics_data.spread_percentage, + 'bid_volume': metrics_data.bid_volume, + 'ask_volume': metrics_data.ask_volume, + 'volume_imbalance': metrics_data.volume_imbalance, + 'depth_10': metrics_data.depth_10, + 'depth_50': metrics_data.depth_50 + } + + return self.success( + data=formatted_data, + message=f"Metrics for {symbol}@{exchange}" + ) + + def status_response(self, status_data: Dict[str, Any]) -> Dict[str, Any]: + """ + Format system status response. + + Args: + status_data: System status data + + Returns: + Dict: Formatted status response + """ + return self.success( + data=status_data, + message="System status", + metadata={'response_count': self.responses_formatted} + ) + + def rate_limit_error(self, client_stats: Dict[str, float]) -> Dict[str, Any]: + """ + Format rate limit error response. + + Args: + client_stats: Client rate limit statistics + + Returns: + Dict: Formatted rate limit error + """ + return self.error( + message="Rate limit exceeded", + error_code="RATE_LIMIT_EXCEEDED", + details={ + 'remaining_tokens': client_stats['remaining_tokens'], + 'reset_time': client_stats['reset_time'], + 'requests_last_minute': client_stats['requests_last_minute'] + }, + status_code=429 + ) + + def validation_error(self, field: str, message: str) -> Dict[str, Any]: + """ + Format validation error response. + + Args: + field: Field that failed validation + message: Validation error message + + Returns: + Dict: Formatted validation error + """ + return self.error( + message=f"Validation error: {message}", + error_code="VALIDATION_ERROR", + details={'field': field, 'message': message}, + status_code=400 + ) + + def to_json(self, response: Dict[str, Any], indent: Optional[int] = None) -> str: + """ + Convert response to JSON string. + + Args: + response: Response dictionary + indent: JSON indentation (None for compact) + + Returns: + str: JSON string + """ + try: + return json.dumps(response, indent=indent, ensure_ascii=False, default=str) + except Exception as e: + logger.error(f"Error converting response to JSON: {e}") + return json.dumps(self.error("JSON serialization failed", "JSON_ERROR")) + + def get_stats(self) -> Dict[str, int]: + """Get formatter statistics""" + return { + 'responses_formatted': self.responses_formatted + } + + def reset_stats(self) -> None: + """Reset formatter statistics""" + self.responses_formatted = 0 + logger.info("Response formatter statistics reset") \ No newline at end of file diff --git a/COBY/api/rest_api.py b/COBY/api/rest_api.py new file mode 100644 index 0000000..11a9e6d --- /dev/null +++ b/COBY/api/rest_api.py @@ -0,0 +1,391 @@ +""" +REST API server for COBY system. +""" + +from fastapi import FastAPI, HTTPException, Request, Query, Path +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +from typing import Optional, List +import asyncio +from ..config import config +from ..caching.redis_manager import redis_manager +from ..utils.logging import get_logger, set_correlation_id +from ..utils.validation import validate_symbol +from .rate_limiter import RateLimiter +from .response_formatter import ResponseFormatter + +logger = get_logger(__name__) + + +def create_app() -> FastAPI: + """Create and configure FastAPI application""" + + app = FastAPI( + title="COBY Market Data API", + description="Real-time cryptocurrency market data aggregation API", + version="1.0.0", + docs_url="/docs", + redoc_url="/redoc" + ) + + # Add CORS middleware + app.add_middleware( + CORSMiddleware, + allow_origins=config.api.cors_origins, + allow_credentials=True, + allow_methods=["GET", "POST", "PUT", "DELETE"], + allow_headers=["*"], + ) + + # Initialize components + rate_limiter = RateLimiter( + requests_per_minute=config.api.rate_limit, + burst_size=20 + ) + response_formatter = ResponseFormatter() + + @app.middleware("http") + async def rate_limit_middleware(request: Request, call_next): + """Rate limiting middleware""" + client_ip = request.client.host + + if not rate_limiter.is_allowed(client_ip): + client_stats = rate_limiter.get_client_stats(client_ip) + error_response = response_formatter.rate_limit_error(client_stats) + return JSONResponse( + status_code=429, + content=error_response, + headers={ + "X-RateLimit-Remaining": str(int(client_stats['remaining_tokens'])), + "X-RateLimit-Reset": str(int(client_stats['reset_time'])) + } + ) + + response = await call_next(request) + + # Add rate limit headers + client_stats = rate_limiter.get_client_stats(client_ip) + response.headers["X-RateLimit-Remaining"] = str(int(client_stats['remaining_tokens'])) + response.headers["X-RateLimit-Reset"] = str(int(client_stats['reset_time'])) + + return response + + @app.middleware("http") + async def correlation_middleware(request: Request, call_next): + """Add correlation ID to requests""" + set_correlation_id() + response = await call_next(request) + return response + + @app.on_event("startup") + async def startup_event(): + """Initialize services on startup""" + try: + await redis_manager.initialize() + logger.info("API server startup completed") + except Exception as e: + logger.error(f"API server startup failed: {e}") + raise + + @app.on_event("shutdown") + async def shutdown_event(): + """Cleanup on shutdown""" + try: + await redis_manager.close() + logger.info("API server shutdown completed") + except Exception as e: + logger.error(f"API server shutdown error: {e}") + + # Health check endpoint + @app.get("/health") + async def health_check(): + """Health check endpoint""" + try: + # Check Redis connection + redis_healthy = await redis_manager.ping() + + health_data = { + 'status': 'healthy' if redis_healthy else 'degraded', + 'redis': 'connected' if redis_healthy else 'disconnected', + 'version': '1.0.0' + } + + return response_formatter.status_response(health_data) + + except Exception as e: + logger.error(f"Health check failed: {e}") + return JSONResponse( + status_code=503, + content=response_formatter.error("Service unavailable", "HEALTH_CHECK_FAILED") + ) + + # Heatmap endpoints + @app.get("/api/v1/heatmap/{symbol}") + async def get_heatmap( + symbol: str = Path(..., description="Trading symbol (e.g., BTCUSDT)"), + exchange: Optional[str] = Query(None, description="Exchange name (None for consolidated)") + ): + """Get heatmap data for a symbol""" + try: + # Validate symbol + if not validate_symbol(symbol): + return JSONResponse( + status_code=400, + content=response_formatter.validation_error("symbol", "Invalid symbol format") + ) + + # Get heatmap from cache + heatmap_data = await redis_manager.get_heatmap(symbol.upper(), exchange) + + return response_formatter.heatmap_response(heatmap_data, symbol.upper(), exchange) + + except Exception as e: + logger.error(f"Error getting heatmap for {symbol}: {e}") + return JSONResponse( + status_code=500, + content=response_formatter.error("Internal server error", "HEATMAP_ERROR") + ) + + # Order book endpoints + @app.get("/api/v1/orderbook/{symbol}/{exchange}") + async def get_orderbook( + symbol: str = Path(..., description="Trading symbol"), + exchange: str = Path(..., description="Exchange name") + ): + """Get order book data for a symbol on an exchange""" + try: + # Validate symbol + if not validate_symbol(symbol): + return JSONResponse( + status_code=400, + content=response_formatter.validation_error("symbol", "Invalid symbol format") + ) + + # Get order book from cache + orderbook_data = await redis_manager.get_orderbook(symbol.upper(), exchange.lower()) + + return response_formatter.orderbook_response(orderbook_data, symbol.upper(), exchange.lower()) + + except Exception as e: + logger.error(f"Error getting order book for {symbol}@{exchange}: {e}") + return JSONResponse( + status_code=500, + content=response_formatter.error("Internal server error", "ORDERBOOK_ERROR") + ) + + # Metrics endpoints + @app.get("/api/v1/metrics/{symbol}/{exchange}") + async def get_metrics( + symbol: str = Path(..., description="Trading symbol"), + exchange: str = Path(..., description="Exchange name") + ): + """Get metrics data for a symbol on an exchange""" + try: + # Validate symbol + if not validate_symbol(symbol): + return JSONResponse( + status_code=400, + content=response_formatter.validation_error("symbol", "Invalid symbol format") + ) + + # Get metrics from cache + metrics_data = await redis_manager.get_metrics(symbol.upper(), exchange.lower()) + + return response_formatter.metrics_response(metrics_data, symbol.upper(), exchange.lower()) + + except Exception as e: + logger.error(f"Error getting metrics for {symbol}@{exchange}: {e}") + return JSONResponse( + status_code=500, + content=response_formatter.error("Internal server error", "METRICS_ERROR") + ) + + # Exchange status endpoints + @app.get("/api/v1/status/{exchange}") + async def get_exchange_status( + exchange: str = Path(..., description="Exchange name") + ): + """Get status for an exchange""" + try: + # Get status from cache + status_data = await redis_manager.get_exchange_status(exchange.lower()) + + if not status_data: + return JSONResponse( + status_code=404, + content=response_formatter.error("Exchange status not found", "STATUS_NOT_FOUND") + ) + + return response_formatter.success( + data=status_data, + message=f"Status for {exchange}" + ) + + except Exception as e: + logger.error(f"Error getting status for {exchange}: {e}") + return JSONResponse( + status_code=500, + content=response_formatter.error("Internal server error", "STATUS_ERROR") + ) + + # List endpoints + @app.get("/api/v1/symbols") + async def list_symbols(): + """List available trading symbols""" + try: + # Get symbols from cache (this would be populated by exchange connectors) + symbols_pattern = "symbols:*" + symbol_keys = await redis_manager.keys(symbols_pattern) + + all_symbols = set() + for key in symbol_keys: + symbols_data = await redis_manager.get(key) + if symbols_data and isinstance(symbols_data, list): + all_symbols.update(symbols_data) + + return response_formatter.success( + data=sorted(list(all_symbols)), + message="Available trading symbols", + metadata={'total_symbols': len(all_symbols)} + ) + + except Exception as e: + logger.error(f"Error listing symbols: {e}") + return JSONResponse( + status_code=500, + content=response_formatter.error("Internal server error", "SYMBOLS_ERROR") + ) + + @app.get("/api/v1/exchanges") + async def list_exchanges(): + """List available exchanges""" + try: + # Get exchange status keys + status_pattern = "st:*" + status_keys = await redis_manager.keys(status_pattern) + + exchanges = [] + for key in status_keys: + # Extract exchange name from key (st:exchange_name) + exchange_name = key.split(':', 1)[1] if ':' in key else key + exchanges.append(exchange_name) + + return response_formatter.success( + data=sorted(exchanges), + message="Available exchanges", + metadata={'total_exchanges': len(exchanges)} + ) + + except Exception as e: + logger.error(f"Error listing exchanges: {e}") + return JSONResponse( + status_code=500, + content=response_formatter.error("Internal server error", "EXCHANGES_ERROR") + ) + + # Statistics endpoints + @app.get("/api/v1/stats/cache") + async def get_cache_stats(): + """Get cache statistics""" + try: + cache_stats = redis_manager.get_stats() + redis_health = await redis_manager.health_check() + + stats_data = { + 'cache_performance': cache_stats, + 'redis_health': redis_health + } + + return response_formatter.success( + data=stats_data, + message="Cache statistics" + ) + + except Exception as e: + logger.error(f"Error getting cache stats: {e}") + return JSONResponse( + status_code=500, + content=response_formatter.error("Internal server error", "STATS_ERROR") + ) + + @app.get("/api/v1/stats/api") + async def get_api_stats(): + """Get API statistics""" + try: + api_stats = { + 'rate_limiter': rate_limiter.get_global_stats(), + 'response_formatter': response_formatter.get_stats() + } + + return response_formatter.success( + data=api_stats, + message="API statistics" + ) + + except Exception as e: + logger.error(f"Error getting API stats: {e}") + return JSONResponse( + status_code=500, + content=response_formatter.error("Internal server error", "API_STATS_ERROR") + ) + + # Batch endpoints for efficiency + @app.get("/api/v1/batch/heatmaps") + async def get_batch_heatmaps( + symbols: str = Query(..., description="Comma-separated list of symbols"), + exchange: Optional[str] = Query(None, description="Exchange name (None for consolidated)") + ): + """Get heatmaps for multiple symbols""" + try: + symbol_list = [s.strip().upper() for s in symbols.split(',')] + + # Validate all symbols + for symbol in symbol_list: + if not validate_symbol(symbol): + return JSONResponse( + status_code=400, + content=response_formatter.validation_error("symbols", f"Invalid symbol: {symbol}") + ) + + # Get heatmaps in batch + heatmaps = {} + for symbol in symbol_list: + heatmap_data = await redis_manager.get_heatmap(symbol, exchange) + if heatmap_data: + heatmaps[symbol] = { + 'symbol': heatmap_data.symbol, + 'timestamp': heatmap_data.timestamp.isoformat(), + 'bucket_size': heatmap_data.bucket_size, + 'points': [ + { + 'price': point.price, + 'volume': point.volume, + 'intensity': point.intensity, + 'side': point.side + } + for point in heatmap_data.data + ] + } + + return response_formatter.success( + data=heatmaps, + message=f"Batch heatmaps for {len(symbol_list)} symbols", + metadata={ + 'requested_symbols': len(symbol_list), + 'found_heatmaps': len(heatmaps), + 'exchange': exchange or 'consolidated' + } + ) + + except Exception as e: + logger.error(f"Error getting batch heatmaps: {e}") + return JSONResponse( + status_code=500, + content=response_formatter.error("Internal server error", "BATCH_HEATMAPS_ERROR") + ) + + return app + + +# Create the FastAPI app instance +app = create_app() \ No newline at end of file diff --git a/COBY/api/websocket_server.py b/COBY/api/websocket_server.py new file mode 100644 index 0000000..93f09b8 --- /dev/null +++ b/COBY/api/websocket_server.py @@ -0,0 +1,400 @@ +""" +WebSocket server for real-time data streaming. +""" + +import asyncio +import json +from typing import Dict, Set, Optional, Any +from fastapi import WebSocket, WebSocketDisconnect +from ..utils.logging import get_logger, set_correlation_id +from ..utils.validation import validate_symbol +from ..caching.redis_manager import redis_manager +from .response_formatter import ResponseFormatter + +logger = get_logger(__name__) + + +class WebSocketManager: + """ + Manages WebSocket connections and real-time data streaming. + """ + + def __init__(self): + """Initialize WebSocket manager""" + # Active connections: connection_id -> WebSocket + self.connections: Dict[str, WebSocket] = {} + + # Subscriptions: symbol -> set of connection_ids + self.subscriptions: Dict[str, Set[str]] = {} + + # Connection metadata: connection_id -> metadata + self.connection_metadata: Dict[str, Dict[str, Any]] = {} + + self.response_formatter = ResponseFormatter() + self.connection_counter = 0 + + logger.info("WebSocket manager initialized") + + async def connect(self, websocket: WebSocket, client_ip: str) -> str: + """ + Accept new WebSocket connection. + + Args: + websocket: WebSocket connection + client_ip: Client IP address + + Returns: + str: Connection ID + """ + await websocket.accept() + + # Generate connection ID + self.connection_counter += 1 + connection_id = f"ws_{self.connection_counter}_{client_ip}" + + # Store connection + self.connections[connection_id] = websocket + self.connection_metadata[connection_id] = { + 'client_ip': client_ip, + 'connected_at': asyncio.get_event_loop().time(), + 'subscriptions': set(), + 'messages_sent': 0 + } + + logger.info(f"WebSocket connected: {connection_id}") + + # Send welcome message + welcome_msg = self.response_formatter.success( + data={'connection_id': connection_id}, + message="WebSocket connected successfully" + ) + await self._send_to_connection(connection_id, welcome_msg) + + return connection_id + + async def disconnect(self, connection_id: str) -> None: + """ + Handle WebSocket disconnection. + + Args: + connection_id: Connection ID to disconnect + """ + if connection_id in self.connections: + # Remove from all subscriptions + metadata = self.connection_metadata.get(connection_id, {}) + for symbol in metadata.get('subscriptions', set()): + await self._unsubscribe_connection(connection_id, symbol) + + # Remove connection + del self.connections[connection_id] + del self.connection_metadata[connection_id] + + logger.info(f"WebSocket disconnected: {connection_id}") + + async def subscribe(self, connection_id: str, symbol: str, + data_type: str = "heatmap") -> bool: + """ + Subscribe connection to symbol updates. + + Args: + connection_id: Connection ID + symbol: Trading symbol + data_type: Type of data to subscribe to + + Returns: + bool: True if subscribed successfully + """ + try: + # Validate symbol + if not validate_symbol(symbol): + error_msg = self.response_formatter.validation_error("symbol", "Invalid symbol format") + await self._send_to_connection(connection_id, error_msg) + return False + + symbol = symbol.upper() + subscription_key = f"{symbol}:{data_type}" + + # Add to subscriptions + if subscription_key not in self.subscriptions: + self.subscriptions[subscription_key] = set() + + self.subscriptions[subscription_key].add(connection_id) + + # Update connection metadata + if connection_id in self.connection_metadata: + self.connection_metadata[connection_id]['subscriptions'].add(subscription_key) + + logger.info(f"WebSocket {connection_id} subscribed to {subscription_key}") + + # Send confirmation + confirm_msg = self.response_formatter.success( + data={'symbol': symbol, 'data_type': data_type}, + message=f"Subscribed to {symbol} {data_type} updates" + ) + await self._send_to_connection(connection_id, confirm_msg) + + # Send initial data if available + await self._send_initial_data(connection_id, symbol, data_type) + + return True + + except Exception as e: + logger.error(f"Error subscribing {connection_id} to {symbol}: {e}") + error_msg = self.response_formatter.error("Subscription failed", "SUBSCRIBE_ERROR") + await self._send_to_connection(connection_id, error_msg) + return False + + async def unsubscribe(self, connection_id: str, symbol: str, + data_type: str = "heatmap") -> bool: + """ + Unsubscribe connection from symbol updates. + + Args: + connection_id: Connection ID + symbol: Trading symbol + data_type: Type of data to unsubscribe from + + Returns: + bool: True if unsubscribed successfully + """ + try: + symbol = symbol.upper() + subscription_key = f"{symbol}:{data_type}" + + await self._unsubscribe_connection(connection_id, subscription_key) + + # Send confirmation + confirm_msg = self.response_formatter.success( + data={'symbol': symbol, 'data_type': data_type}, + message=f"Unsubscribed from {symbol} {data_type} updates" + ) + await self._send_to_connection(connection_id, confirm_msg) + + return True + + except Exception as e: + logger.error(f"Error unsubscribing {connection_id} from {symbol}: {e}") + return False + + async def broadcast_update(self, symbol: str, data_type: str, data: Any) -> int: + """ + Broadcast data update to all subscribers. + + Args: + symbol: Trading symbol + data_type: Type of data + data: Data to broadcast + + Returns: + int: Number of connections notified + """ + try: + set_correlation_id() + + subscription_key = f"{symbol.upper()}:{data_type}" + subscribers = self.subscriptions.get(subscription_key, set()) + + if not subscribers: + return 0 + + # Format message based on data type + if data_type == "heatmap": + message = self.response_formatter.heatmap_response(data, symbol) + elif data_type == "orderbook": + message = self.response_formatter.orderbook_response(data, symbol, "consolidated") + else: + message = self.response_formatter.success(data, f"{data_type} update for {symbol}") + + # Add update type to message + message['update_type'] = data_type + message['symbol'] = symbol + + # Send to all subscribers + sent_count = 0 + for connection_id in subscribers.copy(): # Copy to avoid modification during iteration + if await self._send_to_connection(connection_id, message): + sent_count += 1 + + logger.debug(f"Broadcasted {data_type} update for {symbol} to {sent_count} connections") + return sent_count + + except Exception as e: + logger.error(f"Error broadcasting update for {symbol}: {e}") + return 0 + + async def _send_to_connection(self, connection_id: str, message: Dict[str, Any]) -> bool: + """ + Send message to specific connection. + + Args: + connection_id: Connection ID + message: Message to send + + Returns: + bool: True if sent successfully + """ + try: + if connection_id not in self.connections: + return False + + websocket = self.connections[connection_id] + message_json = json.dumps(message, default=str) + + await websocket.send_text(message_json) + + # Update statistics + if connection_id in self.connection_metadata: + self.connection_metadata[connection_id]['messages_sent'] += 1 + + return True + + except Exception as e: + logger.warning(f"Error sending message to {connection_id}: {e}") + # Remove broken connection + await self.disconnect(connection_id) + return False + + async def _unsubscribe_connection(self, connection_id: str, subscription_key: str) -> None: + """Remove connection from subscription""" + if subscription_key in self.subscriptions: + self.subscriptions[subscription_key].discard(connection_id) + + # Clean up empty subscriptions + if not self.subscriptions[subscription_key]: + del self.subscriptions[subscription_key] + + # Update connection metadata + if connection_id in self.connection_metadata: + self.connection_metadata[connection_id]['subscriptions'].discard(subscription_key) + + async def _send_initial_data(self, connection_id: str, symbol: str, data_type: str) -> None: + """Send initial data to newly subscribed connection""" + try: + if data_type == "heatmap": + # Get latest heatmap from cache + heatmap_data = await redis_manager.get_heatmap(symbol) + if heatmap_data: + message = self.response_formatter.heatmap_response(heatmap_data, symbol) + message['update_type'] = 'initial_data' + await self._send_to_connection(connection_id, message) + + elif data_type == "orderbook": + # Could get latest order book from cache + # This would require knowing which exchange to get data from + pass + + except Exception as e: + logger.warning(f"Error sending initial data to {connection_id}: {e}") + + def get_stats(self) -> Dict[str, Any]: + """Get WebSocket manager statistics""" + total_subscriptions = sum(len(subs) for subs in self.subscriptions.values()) + + return { + 'active_connections': len(self.connections), + 'total_subscriptions': total_subscriptions, + 'unique_symbols': len(set(key.split(':')[0] for key in self.subscriptions.keys())), + 'connection_counter': self.connection_counter + } + + +# Global WebSocket manager instance +websocket_manager = WebSocketManager() + + +class WebSocketServer: + """ + WebSocket server for real-time data streaming. + """ + + def __init__(self): + """Initialize WebSocket server""" + self.manager = websocket_manager + logger.info("WebSocket server initialized") + + async def handle_connection(self, websocket: WebSocket, client_ip: str) -> None: + """ + Handle WebSocket connection lifecycle. + + Args: + websocket: WebSocket connection + client_ip: Client IP address + """ + connection_id = None + + try: + # Accept connection + connection_id = await self.manager.connect(websocket, client_ip) + + # Handle messages + while True: + try: + # Receive message + message = await websocket.receive_text() + await self._handle_message(connection_id, message) + + except WebSocketDisconnect: + logger.info(f"WebSocket client disconnected: {connection_id}") + break + + except Exception as e: + logger.error(f"WebSocket connection error: {e}") + + finally: + # Clean up connection + if connection_id: + await self.manager.disconnect(connection_id) + + async def _handle_message(self, connection_id: str, message: str) -> None: + """ + Handle incoming WebSocket message. + + Args: + connection_id: Connection ID + message: Received message + """ + try: + # Parse message + data = json.loads(message) + action = data.get('action') + + if action == 'subscribe': + symbol = data.get('symbol') + data_type = data.get('data_type', 'heatmap') + await self.manager.subscribe(connection_id, symbol, data_type) + + elif action == 'unsubscribe': + symbol = data.get('symbol') + data_type = data.get('data_type', 'heatmap') + await self.manager.unsubscribe(connection_id, symbol, data_type) + + elif action == 'ping': + # Send pong response + pong_msg = self.manager.response_formatter.success( + data={'action': 'pong'}, + message="Pong" + ) + await self.manager._send_to_connection(connection_id, pong_msg) + + else: + # Unknown action + error_msg = self.manager.response_formatter.error( + f"Unknown action: {action}", + "UNKNOWN_ACTION" + ) + await self.manager._send_to_connection(connection_id, error_msg) + + except json.JSONDecodeError: + error_msg = self.manager.response_formatter.error( + "Invalid JSON message", + "INVALID_JSON" + ) + await self.manager._send_to_connection(connection_id, error_msg) + + except Exception as e: + logger.error(f"Error handling WebSocket message: {e}") + error_msg = self.manager.response_formatter.error( + "Message processing failed", + "MESSAGE_ERROR" + ) + await self.manager._send_to_connection(connection_id, error_msg) \ No newline at end of file