orchestrator adaptor
This commit is contained in:
276
COBY/examples/orchestrator_integration_example.py
Normal file
276
COBY/examples/orchestrator_integration_example.py
Normal file
@ -0,0 +1,276 @@
|
||||
"""
|
||||
Example showing how to integrate COBY system with existing orchestrator.
|
||||
Demonstrates drop-in replacement and mode switching capabilities.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# Import the COBY data provider replacement
|
||||
from ..integration.data_provider_replacement import COBYDataProvider
|
||||
from ..integration.orchestrator_adapter import MarketTick
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def demonstrate_basic_usage():
|
||||
"""Demonstrate basic COBY data provider usage."""
|
||||
logger.info("=== Basic COBY Data Provider Usage ===")
|
||||
|
||||
# Initialize COBY data provider (drop-in replacement)
|
||||
data_provider = COBYDataProvider()
|
||||
|
||||
try:
|
||||
# Test basic data access methods
|
||||
logger.info("Testing basic data access...")
|
||||
|
||||
# Get current price
|
||||
current_price = data_provider.get_current_price('BTCUSDT')
|
||||
logger.info(f"Current BTC price: ${current_price}")
|
||||
|
||||
# Get historical data
|
||||
historical_data = data_provider.get_historical_data('BTCUSDT', '1m', limit=10)
|
||||
if historical_data is not None:
|
||||
logger.info(f"Historical data shape: {historical_data.shape}")
|
||||
logger.info(f"Latest close price: ${historical_data['close'].iloc[-1]}")
|
||||
|
||||
# Get COB data
|
||||
cob_data = data_provider.get_latest_cob_data('BTCUSDT')
|
||||
if cob_data:
|
||||
logger.info(f"Latest COB data: {cob_data}")
|
||||
|
||||
# Get data quality indicators
|
||||
quality = data_provider.adapter.get_data_quality_indicators('BTCUSDT')
|
||||
logger.info(f"Data quality score: {quality.get('quality_score', 0)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in basic usage: {e}")
|
||||
|
||||
finally:
|
||||
await data_provider.close()
|
||||
|
||||
|
||||
async def demonstrate_subscription_system():
|
||||
"""Demonstrate the subscription system."""
|
||||
logger.info("=== COBY Subscription System ===")
|
||||
|
||||
data_provider = COBYDataProvider()
|
||||
|
||||
try:
|
||||
# Set up tick subscription
|
||||
tick_count = 0
|
||||
|
||||
def tick_callback(tick: MarketTick):
|
||||
nonlocal tick_count
|
||||
tick_count += 1
|
||||
logger.info(f"Received tick #{tick_count}: {tick.symbol} @ ${tick.price}")
|
||||
|
||||
# Subscribe to ticks
|
||||
subscriber_id = data_provider.subscribe_to_ticks(
|
||||
tick_callback,
|
||||
symbols=['BTCUSDT', 'ETHUSDT'],
|
||||
subscriber_name='example_subscriber'
|
||||
)
|
||||
|
||||
logger.info(f"Subscribed to ticks with ID: {subscriber_id}")
|
||||
|
||||
# Set up COB data subscription
|
||||
cob_count = 0
|
||||
|
||||
def cob_callback(symbol: str, data: dict):
|
||||
nonlocal cob_count
|
||||
cob_count += 1
|
||||
logger.info(f"Received COB data #{cob_count} for {symbol}")
|
||||
|
||||
cob_subscriber_id = data_provider.subscribe_to_cob_raw_ticks(cob_callback)
|
||||
logger.info(f"Subscribed to COB data with ID: {cob_subscriber_id}")
|
||||
|
||||
# Wait for some data
|
||||
logger.info("Waiting for data updates...")
|
||||
await asyncio.sleep(10)
|
||||
|
||||
# Unsubscribe
|
||||
data_provider.unsubscribe(subscriber_id)
|
||||
data_provider.unsubscribe(cob_subscriber_id)
|
||||
logger.info("Unsubscribed from all feeds")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in subscription demo: {e}")
|
||||
|
||||
finally:
|
||||
await data_provider.close()
|
||||
|
||||
|
||||
async def demonstrate_mode_switching():
|
||||
"""Demonstrate switching between live and replay modes."""
|
||||
logger.info("=== COBY Mode Switching ===")
|
||||
|
||||
data_provider = COBYDataProvider()
|
||||
|
||||
try:
|
||||
# Start in live mode
|
||||
current_mode = data_provider.get_current_mode()
|
||||
logger.info(f"Current mode: {current_mode}")
|
||||
|
||||
# Get some live data
|
||||
live_price = data_provider.get_current_price('BTCUSDT')
|
||||
logger.info(f"Live price: ${live_price}")
|
||||
|
||||
# Switch to replay mode
|
||||
logger.info("Switching to replay mode...")
|
||||
start_time = datetime.utcnow() - timedelta(hours=1)
|
||||
end_time = datetime.utcnow() - timedelta(minutes=30)
|
||||
|
||||
success = await data_provider.switch_to_replay_mode(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
speed=10.0, # 10x speed
|
||||
symbols=['BTCUSDT']
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info("Successfully switched to replay mode")
|
||||
|
||||
# Get replay status
|
||||
replay_status = data_provider.get_replay_status()
|
||||
if replay_status:
|
||||
logger.info(f"Replay progress: {replay_status['progress']:.2%}")
|
||||
logger.info(f"Replay speed: {replay_status['speed']}x")
|
||||
|
||||
# Wait for some replay data
|
||||
await asyncio.sleep(5)
|
||||
|
||||
# Get data during replay
|
||||
replay_price = data_provider.get_current_price('BTCUSDT')
|
||||
logger.info(f"Replay price: ${replay_price}")
|
||||
|
||||
# Switch back to live mode
|
||||
logger.info("Switching back to live mode...")
|
||||
success = await data_provider.switch_to_live_mode()
|
||||
|
||||
if success:
|
||||
logger.info("Successfully switched back to live mode")
|
||||
current_mode = data_provider.get_current_mode()
|
||||
logger.info(f"Current mode: {current_mode}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in mode switching demo: {e}")
|
||||
|
||||
finally:
|
||||
await data_provider.close()
|
||||
|
||||
|
||||
async def demonstrate_orchestrator_compatibility():
|
||||
"""Demonstrate compatibility with orchestrator interface."""
|
||||
logger.info("=== Orchestrator Compatibility ===")
|
||||
|
||||
data_provider = COBYDataProvider()
|
||||
|
||||
try:
|
||||
# Test methods that orchestrator uses
|
||||
logger.info("Testing orchestrator-compatible methods...")
|
||||
|
||||
# Build base data input (used by ML models)
|
||||
base_data = data_provider.build_base_data_input('BTCUSDT')
|
||||
if base_data:
|
||||
features = base_data.get_feature_vector()
|
||||
logger.info(f"Feature vector shape: {features.shape}")
|
||||
|
||||
# Get feature matrix (used by ML models)
|
||||
feature_matrix = data_provider.get_feature_matrix(
|
||||
'BTCUSDT',
|
||||
timeframes=['1m', '5m'],
|
||||
window_size=20
|
||||
)
|
||||
if feature_matrix is not None:
|
||||
logger.info(f"Feature matrix shape: {feature_matrix.shape}")
|
||||
|
||||
# Get pivot bounds (used for normalization)
|
||||
pivot_bounds = data_provider.get_pivot_bounds('BTCUSDT')
|
||||
if pivot_bounds:
|
||||
logger.info(f"Price range: ${pivot_bounds.price_min:.2f} - ${pivot_bounds.price_max:.2f}")
|
||||
|
||||
# Get COB imbalance (used for market microstructure analysis)
|
||||
imbalance = data_provider.get_current_cob_imbalance('BTCUSDT')
|
||||
logger.info(f"Order book imbalance: {imbalance['imbalance']:.3f}")
|
||||
|
||||
# Get system status
|
||||
status = data_provider.get_cached_data_summary()
|
||||
logger.info(f"System status: {status}")
|
||||
|
||||
# Test compatibility methods
|
||||
data_provider.start_centralized_data_collection()
|
||||
data_provider.invalidate_ohlcv_cache('BTCUSDT')
|
||||
|
||||
logger.info("All orchestrator compatibility tests passed!")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in compatibility demo: {e}")
|
||||
|
||||
finally:
|
||||
await data_provider.close()
|
||||
|
||||
|
||||
async def demonstrate_performance_monitoring():
|
||||
"""Demonstrate performance monitoring capabilities."""
|
||||
logger.info("=== Performance Monitoring ===")
|
||||
|
||||
data_provider = COBYDataProvider()
|
||||
|
||||
try:
|
||||
# Get initial statistics
|
||||
initial_stats = data_provider.get_subscriber_stats()
|
||||
logger.info(f"Initial stats: {initial_stats}")
|
||||
|
||||
# Get data quality information
|
||||
quality_info = data_provider.get_cob_data_quality()
|
||||
logger.info(f"Data quality info: {quality_info}")
|
||||
|
||||
# Get WebSocket status
|
||||
ws_status = data_provider.get_cob_websocket_status()
|
||||
logger.info(f"WebSocket status: {ws_status}")
|
||||
|
||||
# Monitor system metadata
|
||||
system_metadata = data_provider.adapter.get_system_metadata()
|
||||
logger.info(f"System components health: {system_metadata['components']}")
|
||||
logger.info(f"Active subscribers: {system_metadata['active_subscribers']}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in performance monitoring: {e}")
|
||||
|
||||
finally:
|
||||
await data_provider.close()
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run all demonstration examples."""
|
||||
logger.info("Starting COBY Integration Examples...")
|
||||
|
||||
try:
|
||||
# Run all demonstrations
|
||||
await demonstrate_basic_usage()
|
||||
await asyncio.sleep(1)
|
||||
|
||||
await demonstrate_subscription_system()
|
||||
await asyncio.sleep(1)
|
||||
|
||||
await demonstrate_mode_switching()
|
||||
await asyncio.sleep(1)
|
||||
|
||||
await demonstrate_orchestrator_compatibility()
|
||||
await asyncio.sleep(1)
|
||||
|
||||
await demonstrate_performance_monitoring()
|
||||
|
||||
logger.info("All COBY integration examples completed successfully!")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error running examples: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the examples
|
||||
asyncio.run(main())
|
8
COBY/integration/__init__.py
Normal file
8
COBY/integration/__init__.py
Normal file
@ -0,0 +1,8 @@
|
||||
"""
|
||||
Integration layer for the COBY multi-exchange data aggregation system.
|
||||
Provides compatibility interfaces for seamless integration with existing systems.
|
||||
"""
|
||||
|
||||
from .orchestrator_adapter import COBYOrchestratorAdapter, MarketTick, PivotBounds
|
||||
|
||||
__all__ = ['COBYOrchestratorAdapter', 'MarketTick', 'PivotBounds']
|
390
COBY/integration/data_provider_replacement.py
Normal file
390
COBY/integration/data_provider_replacement.py
Normal file
@ -0,0 +1,390 @@
|
||||
"""
|
||||
Drop-in replacement for the existing DataProvider class using COBY system.
|
||||
Provides full compatibility with the orchestrator interface.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any, Callable, Union
|
||||
from pathlib import Path
|
||||
|
||||
from .orchestrator_adapter import COBYOrchestratorAdapter, MarketTick, PivotBounds
|
||||
from ..config import Config
|
||||
from ..utils.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class COBYDataProvider:
|
||||
"""
|
||||
Drop-in replacement for DataProvider using COBY system.
|
||||
|
||||
Provides full compatibility with existing orchestrator interface while
|
||||
leveraging COBY's multi-exchange data aggregation capabilities.
|
||||
"""
|
||||
|
||||
def __init__(self, config_path: Optional[str] = None):
|
||||
"""
|
||||
Initialize COBY data provider.
|
||||
|
||||
Args:
|
||||
config_path: Optional path to configuration file
|
||||
"""
|
||||
# Initialize COBY configuration
|
||||
self.config = Config()
|
||||
|
||||
# Initialize COBY adapter
|
||||
self.adapter = COBYOrchestratorAdapter(self.config)
|
||||
|
||||
# Initialize adapter components
|
||||
asyncio.run(self.adapter._initialize_components())
|
||||
|
||||
# Compatibility attributes
|
||||
self.symbols = self.config.exchanges.symbols
|
||||
self.exchanges = self.config.exchanges.exchanges
|
||||
|
||||
logger.info("COBY data provider initialized")
|
||||
|
||||
# === CORE DATA METHODS ===
|
||||
|
||||
def get_historical_data(self, symbol: str, timeframe: str, limit: int = 1000,
|
||||
refresh: bool = False) -> Optional[pd.DataFrame]:
|
||||
"""Get historical OHLCV data."""
|
||||
return self.adapter.get_historical_data(symbol, timeframe, limit, refresh)
|
||||
|
||||
def get_current_price(self, symbol: str) -> Optional[float]:
|
||||
"""Get current price for a symbol."""
|
||||
return self.adapter.get_current_price(symbol)
|
||||
|
||||
def get_live_price_from_api(self, symbol: str) -> Optional[float]:
|
||||
"""Get live price from API (low-latency method)."""
|
||||
return self.adapter.get_live_price_from_api(symbol)
|
||||
|
||||
def build_base_data_input(self, symbol: str) -> Optional[Any]:
|
||||
"""Build base data input for ML models."""
|
||||
return self.adapter.build_base_data_input(symbol)
|
||||
|
||||
# === COB DATA METHODS ===
|
||||
|
||||
def get_cob_raw_ticks(self, symbol: str, count: int = 1000) -> List[Dict]:
|
||||
"""Get raw COB ticks for a symbol."""
|
||||
return self.adapter.get_cob_raw_ticks(symbol, count)
|
||||
|
||||
def get_cob_1s_aggregated(self, symbol: str, count: int = 300) -> List[Dict]:
|
||||
"""Get 1s aggregated COB data with $1 price buckets."""
|
||||
return self.adapter.get_cob_1s_aggregated(symbol, count)
|
||||
|
||||
def get_latest_cob_data(self, symbol: str) -> Optional[Dict]:
|
||||
"""Get latest COB raw tick for a symbol."""
|
||||
return self.adapter.get_latest_cob_data(symbol)
|
||||
|
||||
def get_latest_cob_aggregated(self, symbol: str) -> Optional[Dict]:
|
||||
"""Get latest 1s aggregated COB data for a symbol."""
|
||||
return self.adapter.get_latest_cob_aggregated(symbol)
|
||||
|
||||
def get_current_cob_imbalance(self, symbol: str) -> Dict[str, float]:
|
||||
"""Get current COB imbalance metrics for a symbol."""
|
||||
try:
|
||||
latest_data = self.get_latest_cob_data(symbol)
|
||||
if not latest_data:
|
||||
return {'bid_volume': 0.0, 'ask_volume': 0.0, 'imbalance': 0.0}
|
||||
|
||||
bid_volume = latest_data.get('bid_volume', 0.0)
|
||||
ask_volume = latest_data.get('ask_volume', 0.0)
|
||||
total_volume = bid_volume + ask_volume
|
||||
|
||||
imbalance = 0.0
|
||||
if total_volume > 0:
|
||||
imbalance = (bid_volume - ask_volume) / total_volume
|
||||
|
||||
return {
|
||||
'bid_volume': bid_volume,
|
||||
'ask_volume': ask_volume,
|
||||
'imbalance': imbalance
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting COB imbalance for {symbol}: {e}")
|
||||
return {'bid_volume': 0.0, 'ask_volume': 0.0, 'imbalance': 0.0}
|
||||
|
||||
def get_cob_price_buckets(self, symbol: str, timeframe_seconds: int = 60) -> Dict:
|
||||
"""Get price bucket analysis for a timeframe."""
|
||||
try:
|
||||
# Get aggregated data for the timeframe
|
||||
count = timeframe_seconds # 1 second per data point
|
||||
aggregated_data = self.get_cob_1s_aggregated(symbol, count)
|
||||
|
||||
if not aggregated_data:
|
||||
return {}
|
||||
|
||||
# Combine all buckets
|
||||
combined_bid_buckets = {}
|
||||
combined_ask_buckets = {}
|
||||
|
||||
for data_point in aggregated_data:
|
||||
for price, volume in data_point.get('bid_buckets', {}).items():
|
||||
combined_bid_buckets[price] = combined_bid_buckets.get(price, 0) + volume
|
||||
|
||||
for price, volume in data_point.get('ask_buckets', {}).items():
|
||||
combined_ask_buckets[price] = combined_ask_buckets.get(price, 0) + volume
|
||||
|
||||
return {
|
||||
'symbol': symbol,
|
||||
'timeframe_seconds': timeframe_seconds,
|
||||
'bid_buckets': combined_bid_buckets,
|
||||
'ask_buckets': combined_ask_buckets,
|
||||
'timestamp': datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting price buckets for {symbol}: {e}")
|
||||
return {}
|
||||
|
||||
def get_cob_websocket_status(self) -> Dict[str, Any]:
|
||||
"""Get COB WebSocket status."""
|
||||
try:
|
||||
system_metadata = self.adapter.get_system_metadata()
|
||||
connectors = system_metadata.get('components', {}).get('connectors', {})
|
||||
|
||||
return {
|
||||
'connected': any(connectors.values()),
|
||||
'exchanges': connectors,
|
||||
'last_update': datetime.utcnow().isoformat(),
|
||||
'mode': system_metadata.get('mode', 'unknown')
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting WebSocket status: {e}")
|
||||
return {'connected': False, 'error': str(e)}
|
||||
|
||||
# === SUBSCRIPTION METHODS ===
|
||||
|
||||
def subscribe_to_ticks(self, callback: Callable[[MarketTick], None],
|
||||
symbols: List[str] = None,
|
||||
subscriber_name: str = None) -> str:
|
||||
"""Subscribe to tick data updates."""
|
||||
return self.adapter.subscribe_to_ticks(callback, symbols, subscriber_name)
|
||||
|
||||
def subscribe_to_cob_raw_ticks(self, callback: Callable[[str, Dict], None]) -> str:
|
||||
"""Subscribe to raw COB tick updates."""
|
||||
return self.adapter.subscribe_to_cob_raw_ticks(callback)
|
||||
|
||||
def subscribe_to_cob_aggregated(self, callback: Callable[[str, Dict], None]) -> str:
|
||||
"""Subscribe to 1s aggregated COB updates."""
|
||||
return self.adapter.subscribe_to_cob_aggregated(callback)
|
||||
|
||||
def subscribe_to_training_data(self, callback: Callable[[str, dict], None]) -> str:
|
||||
"""Subscribe to training data updates."""
|
||||
return self.adapter.subscribe_to_training_data(callback)
|
||||
|
||||
def subscribe_to_model_predictions(self, callback: Callable[[str, dict], None]) -> str:
|
||||
"""Subscribe to model prediction updates."""
|
||||
return self.adapter.subscribe_to_model_predictions(callback)
|
||||
|
||||
def unsubscribe(self, subscriber_id: str) -> bool:
|
||||
"""Unsubscribe from data feeds."""
|
||||
return self.adapter.unsubscribe(subscriber_id)
|
||||
|
||||
# === MODE SWITCHING ===
|
||||
|
||||
async def switch_to_live_mode(self) -> bool:
|
||||
"""Switch to live data mode."""
|
||||
return await self.adapter.switch_to_live_mode()
|
||||
|
||||
async def switch_to_replay_mode(self, start_time: datetime, end_time: datetime,
|
||||
speed: float = 1.0, symbols: List[str] = None) -> bool:
|
||||
"""Switch to replay data mode."""
|
||||
return await self.adapter.switch_to_replay_mode(start_time, end_time, speed, symbols)
|
||||
|
||||
def get_current_mode(self) -> str:
|
||||
"""Get current data mode."""
|
||||
return self.adapter.get_current_mode()
|
||||
|
||||
def get_replay_status(self) -> Optional[Dict[str, Any]]:
|
||||
"""Get replay session status."""
|
||||
return self.adapter.get_replay_status()
|
||||
|
||||
# === COMPATIBILITY METHODS ===
|
||||
|
||||
def start_centralized_data_collection(self) -> None:
|
||||
"""Start centralized data collection."""
|
||||
self.adapter.start_centralized_data_collection()
|
||||
|
||||
def start_training_data_collection(self) -> None:
|
||||
"""Start training data collection."""
|
||||
self.adapter.start_training_data_collection()
|
||||
|
||||
def invalidate_ohlcv_cache(self, symbol: str) -> None:
|
||||
"""Invalidate OHLCV cache for a symbol."""
|
||||
self.adapter.invalidate_ohlcv_cache(symbol)
|
||||
|
||||
def get_latest_candles(self, symbol: str, timeframe: str, limit: int = 100) -> pd.DataFrame:
|
||||
"""Get the latest candles from cached data."""
|
||||
return self.get_historical_data(symbol, timeframe, limit) or pd.DataFrame()
|
||||
|
||||
def get_price_at_index(self, symbol: str, index: int, timeframe: str = '1m') -> Optional[float]:
|
||||
"""Get price at specific index for backtesting."""
|
||||
try:
|
||||
df = self.get_historical_data(symbol, timeframe, limit=index + 10)
|
||||
if df is not None and len(df) > index:
|
||||
return float(df.iloc[-(index + 1)]['close'])
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting price at index {index} for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
# === PIVOT AND MARKET STRUCTURE (MOCK IMPLEMENTATIONS) ===
|
||||
|
||||
def get_pivot_bounds(self, symbol: str) -> Optional[PivotBounds]:
|
||||
"""Get pivot bounds for a symbol (mock implementation)."""
|
||||
try:
|
||||
# Get recent price data
|
||||
df = self.get_historical_data(symbol, '1m', limit=1000)
|
||||
if df is None or df.empty:
|
||||
return None
|
||||
|
||||
# Calculate basic pivot levels
|
||||
high_prices = df['high'].values
|
||||
low_prices = df['low'].values
|
||||
volumes = df['volume'].values
|
||||
|
||||
price_max = float(np.max(high_prices))
|
||||
price_min = float(np.min(low_prices))
|
||||
volume_max = float(np.max(volumes))
|
||||
volume_min = float(np.min(volumes))
|
||||
|
||||
# Simple support/resistance calculation
|
||||
price_range = price_max - price_min
|
||||
support_levels = [price_min + i * price_range / 10 for i in range(1, 5)]
|
||||
resistance_levels = [price_max - i * price_range / 10 for i in range(1, 5)]
|
||||
|
||||
return PivotBounds(
|
||||
symbol=symbol,
|
||||
price_max=price_max,
|
||||
price_min=price_min,
|
||||
volume_max=volume_max,
|
||||
volume_min=volume_min,
|
||||
pivot_support_levels=support_levels,
|
||||
pivot_resistance_levels=resistance_levels,
|
||||
pivot_context={'method': 'simple'},
|
||||
created_timestamp=datetime.utcnow(),
|
||||
data_period_start=df.index[0].to_pydatetime(),
|
||||
data_period_end=df.index[-1].to_pydatetime(),
|
||||
total_candles_analyzed=len(df)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting pivot bounds for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def get_pivot_normalized_features(self, symbol: str, df: pd.DataFrame) -> Optional[pd.DataFrame]:
|
||||
"""Get dataframe with pivot-normalized features."""
|
||||
try:
|
||||
pivot_bounds = self.get_pivot_bounds(symbol)
|
||||
if not pivot_bounds:
|
||||
return df
|
||||
|
||||
# Add normalized features
|
||||
df_copy = df.copy()
|
||||
price_range = pivot_bounds.get_price_range()
|
||||
|
||||
if price_range > 0:
|
||||
df_copy['normalized_close'] = (df_copy['close'] - pivot_bounds.price_min) / price_range
|
||||
df_copy['normalized_high'] = (df_copy['high'] - pivot_bounds.price_min) / price_range
|
||||
df_copy['normalized_low'] = (df_copy['low'] - pivot_bounds.price_min) / price_range
|
||||
|
||||
return df_copy
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting pivot normalized features for {symbol}: {e}")
|
||||
return df
|
||||
|
||||
# === FEATURE EXTRACTION METHODS ===
|
||||
|
||||
def get_feature_matrix(self, symbol: str, timeframes: List[str] = None,
|
||||
window_size: int = 20) -> Optional[np.ndarray]:
|
||||
"""Get feature matrix for ML models."""
|
||||
try:
|
||||
if not timeframes:
|
||||
timeframes = ['1m', '5m', '15m']
|
||||
|
||||
features = []
|
||||
|
||||
for timeframe in timeframes:
|
||||
df = self.get_historical_data(symbol, timeframe, limit=window_size + 10)
|
||||
if df is not None and len(df) >= window_size:
|
||||
# Extract basic features
|
||||
closes = df['close'].values[-window_size:]
|
||||
volumes = df['volume'].values[-window_size:]
|
||||
|
||||
# Normalize features
|
||||
close_mean = np.mean(closes)
|
||||
close_std = np.std(closes) + 1e-8
|
||||
normalized_closes = (closes - close_mean) / close_std
|
||||
|
||||
volume_mean = np.mean(volumes)
|
||||
volume_std = np.std(volumes) + 1e-8
|
||||
normalized_volumes = (volumes - volume_mean) / volume_std
|
||||
|
||||
features.extend(normalized_closes)
|
||||
features.extend(normalized_volumes)
|
||||
|
||||
if features:
|
||||
return np.array(features, dtype=np.float32)
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting feature matrix for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
# === SYSTEM STATUS AND STATISTICS ===
|
||||
|
||||
def get_cached_data_summary(self) -> Dict[str, Any]:
|
||||
"""Get summary of cached data."""
|
||||
try:
|
||||
system_metadata = self.adapter.get_system_metadata()
|
||||
return {
|
||||
'system': 'COBY',
|
||||
'mode': system_metadata.get('mode'),
|
||||
'statistics': system_metadata.get('statistics', {}),
|
||||
'components_healthy': system_metadata.get('components', {}),
|
||||
'active_subscribers': system_metadata.get('active_subscribers', 0)
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting cached data summary: {e}")
|
||||
return {'error': str(e)}
|
||||
|
||||
def get_cob_data_quality(self) -> Dict[str, Any]:
|
||||
"""Get COB data quality information."""
|
||||
try:
|
||||
quality_info = {}
|
||||
|
||||
for symbol in self.symbols:
|
||||
quality_info[symbol] = self.adapter.get_data_quality_indicators(symbol)
|
||||
|
||||
return quality_info
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting COB data quality: {e}")
|
||||
return {'error': str(e)}
|
||||
|
||||
def get_subscriber_stats(self) -> Dict[str, Any]:
|
||||
"""Get subscriber statistics."""
|
||||
return self.adapter.get_stats()
|
||||
|
||||
# === CLEANUP ===
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close all connections and cleanup."""
|
||||
await self.adapter.close()
|
||||
|
||||
def __del__(self):
|
||||
"""Cleanup on deletion."""
|
||||
try:
|
||||
asyncio.run(self.close())
|
||||
except:
|
||||
pass
|
888
COBY/integration/orchestrator_adapter.py
Normal file
888
COBY/integration/orchestrator_adapter.py
Normal file
@ -0,0 +1,888 @@
|
||||
"""
|
||||
Orchestrator integration adapter for COBY system.
|
||||
Provides compatibility layer for seamless integration with existing orchestrator.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any, Callable, Union
|
||||
from dataclasses import dataclass, field
|
||||
import uuid
|
||||
from collections import deque
|
||||
import threading
|
||||
|
||||
from ..storage.storage_manager import StorageManager
|
||||
from ..replay.replay_manager import HistoricalReplayManager
|
||||
from ..caching.redis_manager import RedisManager
|
||||
from ..aggregation.aggregation_engine import StandardAggregationEngine
|
||||
from ..processing.data_processor import StandardDataProcessor
|
||||
from ..connectors.binance_connector import BinanceConnector
|
||||
from ..models.core import OrderBookSnapshot, TradeEvent, HeatmapData, ReplayStatus
|
||||
from ..utils.logging import get_logger, set_correlation_id
|
||||
from ..utils.exceptions import IntegrationError, ValidationError
|
||||
from ..config import Config
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MarketTick:
|
||||
"""Market tick data structure compatible with orchestrator"""
|
||||
symbol: str
|
||||
price: float
|
||||
volume: float
|
||||
timestamp: datetime
|
||||
side: str = "unknown"
|
||||
exchange: str = "binance"
|
||||
subscriber_name: str = "unknown"
|
||||
|
||||
|
||||
@dataclass
|
||||
class PivotBounds:
|
||||
"""Pivot bounds structure compatible with orchestrator"""
|
||||
symbol: str
|
||||
price_max: float
|
||||
price_min: float
|
||||
volume_max: float
|
||||
volume_min: float
|
||||
pivot_support_levels: List[float]
|
||||
pivot_resistance_levels: List[float]
|
||||
pivot_context: Dict[str, Any]
|
||||
created_timestamp: datetime
|
||||
data_period_start: datetime
|
||||
data_period_end: datetime
|
||||
total_candles_analyzed: int
|
||||
|
||||
def get_price_range(self) -> float:
|
||||
return self.price_max - self.price_min
|
||||
|
||||
def normalize_price(self, price: float) -> float:
|
||||
return (price - self.price_min) / self.get_price_range()
|
||||
|
||||
|
||||
class COBYOrchestratorAdapter:
|
||||
"""
|
||||
Adapter that makes COBY system compatible with existing orchestrator interface.
|
||||
|
||||
Provides:
|
||||
- Data provider interface compatibility
|
||||
- Live/replay mode switching
|
||||
- Data quality indicators
|
||||
- Subscription management
|
||||
- Caching and performance optimization
|
||||
"""
|
||||
|
||||
def __init__(self, config: Config):
|
||||
"""
|
||||
Initialize orchestrator adapter.
|
||||
|
||||
Args:
|
||||
config: COBY system configuration
|
||||
"""
|
||||
self.config = config
|
||||
|
||||
# Core components
|
||||
self.storage_manager = StorageManager(config)
|
||||
self.replay_manager = HistoricalReplayManager(self.storage_manager, config)
|
||||
self.redis_manager = RedisManager()
|
||||
self.aggregation_engine = StandardAggregationEngine()
|
||||
self.data_processor = StandardDataProcessor()
|
||||
|
||||
# Exchange connectors
|
||||
self.connectors = {
|
||||
'binance': BinanceConnector()
|
||||
}
|
||||
|
||||
# Mode management
|
||||
self.mode = 'live' # 'live' or 'replay'
|
||||
self.current_replay_session = None
|
||||
|
||||
# Subscription management
|
||||
self.subscribers = {
|
||||
'ticks': {},
|
||||
'cob_raw': {},
|
||||
'cob_aggregated': {},
|
||||
'training_data': {},
|
||||
'model_predictions': {}
|
||||
}
|
||||
self.subscriber_lock = threading.Lock()
|
||||
|
||||
# Data caching
|
||||
self.tick_cache = {}
|
||||
self.orderbook_cache = {}
|
||||
self.price_cache = {}
|
||||
|
||||
# Statistics
|
||||
self.stats = {
|
||||
'ticks_processed': 0,
|
||||
'orderbooks_processed': 0,
|
||||
'subscribers_active': 0,
|
||||
'cache_hits': 0,
|
||||
'cache_misses': 0
|
||||
}
|
||||
|
||||
# Initialize components
|
||||
self._initialize_components()
|
||||
|
||||
logger.info("COBY orchestrator adapter initialized")
|
||||
|
||||
async def _initialize_components(self):
|
||||
"""Initialize all COBY components."""
|
||||
try:
|
||||
# Initialize storage
|
||||
await self.storage_manager.initialize()
|
||||
|
||||
# Initialize Redis cache
|
||||
await self.redis_manager.initialize()
|
||||
|
||||
# Initialize connectors
|
||||
for name, connector in self.connectors.items():
|
||||
await connector.connect()
|
||||
connector.add_data_callback(self._handle_connector_data)
|
||||
|
||||
logger.info("COBY components initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize COBY components: {e}")
|
||||
raise IntegrationError(f"Component initialization failed: {e}")
|
||||
|
||||
# === ORCHESTRATOR COMPATIBILITY METHODS ===
|
||||
|
||||
def get_historical_data(self, symbol: str, timeframe: str, limit: int = 1000,
|
||||
refresh: bool = False) -> Optional[pd.DataFrame]:
|
||||
"""Get historical OHLCV data compatible with orchestrator interface."""
|
||||
try:
|
||||
set_correlation_id()
|
||||
|
||||
# Convert timeframe to minutes
|
||||
timeframe_minutes = self._parse_timeframe(timeframe)
|
||||
if not timeframe_minutes:
|
||||
logger.warning(f"Unsupported timeframe: {timeframe}")
|
||||
return None
|
||||
|
||||
# Calculate time range
|
||||
end_time = datetime.utcnow()
|
||||
start_time = end_time - timedelta(minutes=timeframe_minutes * limit)
|
||||
|
||||
# Get data from storage
|
||||
if self.mode == 'replay' and self.current_replay_session:
|
||||
# Use replay data
|
||||
data = asyncio.run(self.storage_manager.get_historical_data(
|
||||
symbol, start_time, end_time, 'ohlcv'
|
||||
))
|
||||
else:
|
||||
# Use live data from cache or storage
|
||||
cache_key = f"ohlcv:{symbol}:{timeframe}:{limit}"
|
||||
cached_data = asyncio.run(self.redis_manager.get(cache_key))
|
||||
|
||||
if cached_data and not refresh:
|
||||
self.stats['cache_hits'] += 1
|
||||
return pd.DataFrame(cached_data)
|
||||
|
||||
self.stats['cache_misses'] += 1
|
||||
data = asyncio.run(self.storage_manager.get_historical_data(
|
||||
symbol, start_time, end_time, 'ohlcv'
|
||||
))
|
||||
|
||||
# Cache the result
|
||||
if data:
|
||||
asyncio.run(self.redis_manager.set(cache_key, data, ttl=60))
|
||||
|
||||
if not data:
|
||||
return None
|
||||
|
||||
# Convert to DataFrame compatible with orchestrator
|
||||
df = pd.DataFrame(data)
|
||||
if not df.empty:
|
||||
df['timestamp'] = pd.to_datetime(df['timestamp'])
|
||||
df.set_index('timestamp', inplace=True)
|
||||
df = df.sort_index()
|
||||
|
||||
return df
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting historical data for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def get_current_price(self, symbol: str) -> Optional[float]:
|
||||
"""Get current price for a symbol."""
|
||||
try:
|
||||
# Check cache first
|
||||
if symbol in self.price_cache:
|
||||
cached_price, timestamp = self.price_cache[symbol]
|
||||
if (datetime.utcnow() - timestamp).seconds < 5: # 5 second cache
|
||||
return cached_price
|
||||
|
||||
# Get latest orderbook
|
||||
latest_orderbook = asyncio.run(
|
||||
self.storage_manager.get_latest_orderbook(symbol)
|
||||
)
|
||||
|
||||
if latest_orderbook and latest_orderbook.get('mid_price'):
|
||||
price = float(latest_orderbook['mid_price'])
|
||||
self.price_cache[symbol] = (price, datetime.utcnow())
|
||||
return price
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting current price for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def get_live_price_from_api(self, symbol: str) -> Optional[float]:
|
||||
"""Get live price from API (low-latency method)."""
|
||||
return self.get_current_price(symbol)
|
||||
|
||||
def build_base_data_input(self, symbol: str) -> Optional[Any]:
|
||||
"""Build base data input compatible with orchestrator models."""
|
||||
try:
|
||||
# This would need to be implemented based on the specific
|
||||
# BaseDataInput class used by the orchestrator
|
||||
# For now, return a mock object that provides the interface
|
||||
|
||||
class MockBaseDataInput:
|
||||
def __init__(self, symbol: str, adapter):
|
||||
self.symbol = symbol
|
||||
self.adapter = adapter
|
||||
|
||||
def get_feature_vector(self) -> np.ndarray:
|
||||
# Return feature vector from COBY data
|
||||
return self.adapter._get_feature_vector(self.symbol)
|
||||
|
||||
return MockBaseDataInput(symbol, self)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error building base data input for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def _get_feature_vector(self, symbol: str) -> np.ndarray:
|
||||
"""Get feature vector for ML models."""
|
||||
try:
|
||||
# Get latest market data
|
||||
latest_orderbook = asyncio.run(
|
||||
self.storage_manager.get_latest_orderbook(symbol)
|
||||
)
|
||||
|
||||
if not latest_orderbook:
|
||||
return np.zeros(100, dtype=np.float32) # Default size
|
||||
|
||||
# Extract features from orderbook
|
||||
features = []
|
||||
|
||||
# Price features
|
||||
if latest_orderbook.get('mid_price'):
|
||||
features.append(float(latest_orderbook['mid_price']))
|
||||
if latest_orderbook.get('spread'):
|
||||
features.append(float(latest_orderbook['spread']))
|
||||
|
||||
# Volume features
|
||||
if latest_orderbook.get('bid_volume'):
|
||||
features.append(float(latest_orderbook['bid_volume']))
|
||||
if latest_orderbook.get('ask_volume'):
|
||||
features.append(float(latest_orderbook['ask_volume']))
|
||||
|
||||
# Pad or truncate to expected size
|
||||
target_size = 100
|
||||
if len(features) < target_size:
|
||||
features.extend([0.0] * (target_size - len(features)))
|
||||
elif len(features) > target_size:
|
||||
features = features[:target_size]
|
||||
|
||||
return np.array(features, dtype=np.float32)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting feature vector for {symbol}: {e}")
|
||||
return np.zeros(100, dtype=np.float32)
|
||||
|
||||
# === COB DATA METHODS ===
|
||||
|
||||
def get_cob_raw_ticks(self, symbol: str, count: int = 1000) -> List[Dict]:
|
||||
"""Get raw COB ticks for a symbol."""
|
||||
try:
|
||||
# Get recent orderbook snapshots
|
||||
end_time = datetime.utcnow()
|
||||
start_time = end_time - timedelta(minutes=15) # 15 minutes of data
|
||||
|
||||
data = asyncio.run(self.storage_manager.get_historical_data(
|
||||
symbol, start_time, end_time, 'orderbook'
|
||||
))
|
||||
|
||||
if not data:
|
||||
return []
|
||||
|
||||
# Convert to COB tick format
|
||||
ticks = []
|
||||
for item in data[-count:]:
|
||||
tick = {
|
||||
'symbol': item['symbol'],
|
||||
'timestamp': item['timestamp'].isoformat(),
|
||||
'mid_price': item.get('mid_price'),
|
||||
'spread': item.get('spread'),
|
||||
'bid_volume': item.get('bid_volume'),
|
||||
'ask_volume': item.get('ask_volume'),
|
||||
'exchange': item['exchange']
|
||||
}
|
||||
ticks.append(tick)
|
||||
|
||||
return ticks
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting COB raw ticks for {symbol}: {e}")
|
||||
return []
|
||||
|
||||
def get_cob_1s_aggregated(self, symbol: str, count: int = 300) -> List[Dict]:
|
||||
"""Get 1s aggregated COB data with $1 price buckets."""
|
||||
try:
|
||||
# Get heatmap data
|
||||
bucket_size = self.config.aggregation.bucket_size
|
||||
start_time = datetime.utcnow() - timedelta(seconds=count)
|
||||
|
||||
heatmap_data = asyncio.run(
|
||||
self.storage_manager.get_heatmap_data(symbol, bucket_size, start_time)
|
||||
)
|
||||
|
||||
if not heatmap_data:
|
||||
return []
|
||||
|
||||
# Group by timestamp and aggregate
|
||||
aggregated = {}
|
||||
for item in heatmap_data:
|
||||
timestamp = item['timestamp']
|
||||
if timestamp not in aggregated:
|
||||
aggregated[timestamp] = {
|
||||
'timestamp': timestamp.isoformat(),
|
||||
'symbol': symbol,
|
||||
'bid_buckets': {},
|
||||
'ask_buckets': {},
|
||||
'total_bid_volume': 0,
|
||||
'total_ask_volume': 0
|
||||
}
|
||||
|
||||
price_bucket = float(item['price_bucket'])
|
||||
volume = float(item['volume'])
|
||||
side = item['side']
|
||||
|
||||
if side == 'bid':
|
||||
aggregated[timestamp]['bid_buckets'][price_bucket] = volume
|
||||
aggregated[timestamp]['total_bid_volume'] += volume
|
||||
else:
|
||||
aggregated[timestamp]['ask_buckets'][price_bucket] = volume
|
||||
aggregated[timestamp]['total_ask_volume'] += volume
|
||||
|
||||
# Return sorted by timestamp
|
||||
result = list(aggregated.values())
|
||||
result.sort(key=lambda x: x['timestamp'])
|
||||
return result[-count:]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting COB 1s aggregated data for {symbol}: {e}")
|
||||
return []
|
||||
|
||||
def get_latest_cob_data(self, symbol: str) -> Optional[Dict]:
|
||||
"""Get latest COB raw tick for a symbol."""
|
||||
try:
|
||||
latest_orderbook = asyncio.run(
|
||||
self.storage_manager.get_latest_orderbook(symbol)
|
||||
)
|
||||
|
||||
if not latest_orderbook:
|
||||
return None
|
||||
|
||||
return {
|
||||
'symbol': symbol,
|
||||
'timestamp': latest_orderbook['timestamp'].isoformat(),
|
||||
'mid_price': latest_orderbook.get('mid_price'),
|
||||
'spread': latest_orderbook.get('spread'),
|
||||
'bid_volume': latest_orderbook.get('bid_volume'),
|
||||
'ask_volume': latest_orderbook.get('ask_volume'),
|
||||
'exchange': latest_orderbook['exchange']
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting latest COB data for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def get_latest_cob_aggregated(self, symbol: str) -> Optional[Dict]:
|
||||
"""Get latest 1s aggregated COB data for a symbol."""
|
||||
try:
|
||||
aggregated_data = self.get_cob_1s_aggregated(symbol, count=1)
|
||||
return aggregated_data[0] if aggregated_data else None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting latest COB aggregated data for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
# === SUBSCRIPTION METHODS ===
|
||||
|
||||
def subscribe_to_ticks(self, callback: Callable[[MarketTick], None],
|
||||
symbols: List[str] = None,
|
||||
subscriber_name: str = None) -> str:
|
||||
"""Subscribe to tick data updates."""
|
||||
try:
|
||||
subscriber_id = str(uuid.uuid4())
|
||||
|
||||
with self.subscriber_lock:
|
||||
self.subscribers['ticks'][subscriber_id] = {
|
||||
'callback': callback,
|
||||
'symbols': symbols or [],
|
||||
'subscriber_name': subscriber_name or 'unknown',
|
||||
'created_at': datetime.utcnow()
|
||||
}
|
||||
self.stats['subscribers_active'] += 1
|
||||
|
||||
logger.info(f"Added tick subscriber {subscriber_id} for {subscriber_name}")
|
||||
return subscriber_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding tick subscriber: {e}")
|
||||
return ""
|
||||
|
||||
def subscribe_to_cob_raw_ticks(self, callback: Callable[[str, Dict], None]) -> str:
|
||||
"""Subscribe to raw COB tick updates."""
|
||||
try:
|
||||
subscriber_id = str(uuid.uuid4())
|
||||
|
||||
with self.subscriber_lock:
|
||||
self.subscribers['cob_raw'][subscriber_id] = {
|
||||
'callback': callback,
|
||||
'created_at': datetime.utcnow()
|
||||
}
|
||||
self.stats['subscribers_active'] += 1
|
||||
|
||||
logger.info(f"Added COB raw tick subscriber {subscriber_id}")
|
||||
return subscriber_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding COB raw tick subscriber: {e}")
|
||||
return ""
|
||||
|
||||
def subscribe_to_cob_aggregated(self, callback: Callable[[str, Dict], None]) -> str:
|
||||
"""Subscribe to 1s aggregated COB updates."""
|
||||
try:
|
||||
subscriber_id = str(uuid.uuid4())
|
||||
|
||||
with self.subscriber_lock:
|
||||
self.subscribers['cob_aggregated'][subscriber_id] = {
|
||||
'callback': callback,
|
||||
'created_at': datetime.utcnow()
|
||||
}
|
||||
self.stats['subscribers_active'] += 1
|
||||
|
||||
logger.info(f"Added COB aggregated subscriber {subscriber_id}")
|
||||
return subscriber_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding COB aggregated subscriber: {e}")
|
||||
return ""
|
||||
|
||||
def subscribe_to_training_data(self, callback: Callable[[str, dict], None]) -> str:
|
||||
"""Subscribe to training data updates."""
|
||||
try:
|
||||
subscriber_id = str(uuid.uuid4())
|
||||
|
||||
with self.subscriber_lock:
|
||||
self.subscribers['training_data'][subscriber_id] = {
|
||||
'callback': callback,
|
||||
'created_at': datetime.utcnow()
|
||||
}
|
||||
self.stats['subscribers_active'] += 1
|
||||
|
||||
logger.info(f"Added training data subscriber {subscriber_id}")
|
||||
return subscriber_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding training data subscriber: {e}")
|
||||
return ""
|
||||
|
||||
def subscribe_to_model_predictions(self, callback: Callable[[str, dict], None]) -> str:
|
||||
"""Subscribe to model prediction updates."""
|
||||
try:
|
||||
subscriber_id = str(uuid.uuid4())
|
||||
|
||||
with self.subscriber_lock:
|
||||
self.subscribers['model_predictions'][subscriber_id] = {
|
||||
'callback': callback,
|
||||
'created_at': datetime.utcnow()
|
||||
}
|
||||
self.stats['subscribers_active'] += 1
|
||||
|
||||
logger.info(f"Added model prediction subscriber {subscriber_id}")
|
||||
return subscriber_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding model prediction subscriber: {e}")
|
||||
return ""
|
||||
|
||||
def unsubscribe(self, subscriber_id: str) -> bool:
|
||||
"""Unsubscribe from all data feeds."""
|
||||
try:
|
||||
with self.subscriber_lock:
|
||||
removed = False
|
||||
for category in self.subscribers:
|
||||
if subscriber_id in self.subscribers[category]:
|
||||
del self.subscribers[category][subscriber_id]
|
||||
self.stats['subscribers_active'] -= 1
|
||||
removed = True
|
||||
break
|
||||
|
||||
if removed:
|
||||
logger.info(f"Removed subscriber {subscriber_id}")
|
||||
|
||||
return removed
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing subscriber {subscriber_id}: {e}")
|
||||
return False
|
||||
|
||||
# === MODE SWITCHING ===
|
||||
|
||||
async def switch_to_live_mode(self) -> bool:
|
||||
"""Switch to live data mode."""
|
||||
try:
|
||||
if self.mode == 'live':
|
||||
logger.info("Already in live mode")
|
||||
return True
|
||||
|
||||
# Stop replay session if active
|
||||
if self.current_replay_session:
|
||||
await self.replay_manager.stop_replay(self.current_replay_session)
|
||||
self.current_replay_session = None
|
||||
|
||||
# Start live connectors
|
||||
for name, connector in self.connectors.items():
|
||||
if not connector.is_connected:
|
||||
await connector.connect()
|
||||
|
||||
self.mode = 'live'
|
||||
logger.info("Switched to live data mode")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error switching to live mode: {e}")
|
||||
return False
|
||||
|
||||
async def switch_to_replay_mode(self, start_time: datetime, end_time: datetime,
|
||||
speed: float = 1.0, symbols: List[str] = None) -> bool:
|
||||
"""Switch to replay data mode."""
|
||||
try:
|
||||
if self.mode == 'replay' and self.current_replay_session:
|
||||
await self.replay_manager.stop_replay(self.current_replay_session)
|
||||
|
||||
# Create replay session
|
||||
session_id = self.replay_manager.create_replay_session(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
speed=speed,
|
||||
symbols=symbols or self.config.exchanges.symbols
|
||||
)
|
||||
|
||||
# Add data callback for replay
|
||||
self.replay_manager.add_data_callback(session_id, self._handle_replay_data)
|
||||
|
||||
# Start replay
|
||||
await self.replay_manager.start_replay(session_id)
|
||||
|
||||
self.current_replay_session = session_id
|
||||
self.mode = 'replay'
|
||||
|
||||
logger.info(f"Switched to replay mode: {start_time} to {end_time}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error switching to replay mode: {e}")
|
||||
return False
|
||||
|
||||
def get_current_mode(self) -> str:
|
||||
"""Get current data mode (live or replay)."""
|
||||
return self.mode
|
||||
|
||||
def get_replay_status(self) -> Optional[Dict[str, Any]]:
|
||||
"""Get current replay session status."""
|
||||
if not self.current_replay_session:
|
||||
return None
|
||||
|
||||
session = self.replay_manager.get_replay_status(self.current_replay_session)
|
||||
if not session:
|
||||
return None
|
||||
|
||||
return {
|
||||
'session_id': session.session_id,
|
||||
'status': session.status.value,
|
||||
'progress': session.progress,
|
||||
'current_time': session.current_time.isoformat(),
|
||||
'speed': session.speed,
|
||||
'events_replayed': session.events_replayed,
|
||||
'total_events': session.total_events
|
||||
}
|
||||
|
||||
# === DATA QUALITY AND METADATA ===
|
||||
|
||||
def get_data_quality_indicators(self, symbol: str) -> Dict[str, Any]:
|
||||
"""Get data quality indicators for a symbol."""
|
||||
try:
|
||||
# Get recent data statistics
|
||||
end_time = datetime.utcnow()
|
||||
start_time = end_time - timedelta(minutes=5)
|
||||
|
||||
orderbook_data = asyncio.run(self.storage_manager.get_historical_data(
|
||||
symbol, start_time, end_time, 'orderbook'
|
||||
))
|
||||
|
||||
trade_data = asyncio.run(self.storage_manager.get_historical_data(
|
||||
symbol, start_time, end_time, 'trades'
|
||||
))
|
||||
|
||||
# Calculate quality metrics
|
||||
quality = {
|
||||
'symbol': symbol,
|
||||
'timestamp': datetime.utcnow().isoformat(),
|
||||
'orderbook_updates': len(orderbook_data) if orderbook_data else 0,
|
||||
'trade_events': len(trade_data) if trade_data else 0,
|
||||
'data_freshness_seconds': 0,
|
||||
'exchange_coverage': [],
|
||||
'quality_score': 0.0
|
||||
}
|
||||
|
||||
# Calculate data freshness
|
||||
if orderbook_data:
|
||||
latest_timestamp = max(item['timestamp'] for item in orderbook_data)
|
||||
quality['data_freshness_seconds'] = (
|
||||
datetime.utcnow() - latest_timestamp
|
||||
).total_seconds()
|
||||
|
||||
# Get exchange coverage
|
||||
if orderbook_data:
|
||||
exchanges = set(item['exchange'] for item in orderbook_data)
|
||||
quality['exchange_coverage'] = list(exchanges)
|
||||
|
||||
# Calculate quality score (0-1)
|
||||
score = 0.0
|
||||
if quality['orderbook_updates'] > 0:
|
||||
score += 0.4
|
||||
if quality['trade_events'] > 0:
|
||||
score += 0.3
|
||||
if quality['data_freshness_seconds'] < 10:
|
||||
score += 0.3
|
||||
|
||||
quality['quality_score'] = score
|
||||
|
||||
return quality
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting data quality for {symbol}: {e}")
|
||||
return {
|
||||
'symbol': symbol,
|
||||
'timestamp': datetime.utcnow().isoformat(),
|
||||
'quality_score': 0.0,
|
||||
'error': str(e)
|
||||
}
|
||||
|
||||
def get_system_metadata(self) -> Dict[str, Any]:
|
||||
"""Get system metadata and status."""
|
||||
try:
|
||||
return {
|
||||
'system': 'COBY',
|
||||
'version': '1.0.0',
|
||||
'mode': self.mode,
|
||||
'timestamp': datetime.utcnow().isoformat(),
|
||||
'components': {
|
||||
'storage': self.storage_manager.is_healthy(),
|
||||
'redis': True, # Simplified check
|
||||
'connectors': {
|
||||
name: connector.is_connected
|
||||
for name, connector in self.connectors.items()
|
||||
}
|
||||
},
|
||||
'statistics': self.stats,
|
||||
'replay_session': self.get_replay_status(),
|
||||
'active_subscribers': sum(
|
||||
len(subs) for subs in self.subscribers.values()
|
||||
)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting system metadata: {e}")
|
||||
return {'error': str(e)}
|
||||
|
||||
# === DATA HANDLERS ===
|
||||
|
||||
async def _handle_connector_data(self, data: Union[OrderBookSnapshot, TradeEvent]) -> None:
|
||||
"""Handle data from exchange connectors."""
|
||||
try:
|
||||
# Store data
|
||||
if isinstance(data, OrderBookSnapshot):
|
||||
await self.storage_manager.store_orderbook(data)
|
||||
self.stats['orderbooks_processed'] += 1
|
||||
|
||||
# Create market tick for subscribers
|
||||
if data.bids and data.asks:
|
||||
best_bid = max(data.bids, key=lambda x: x.price)
|
||||
best_ask = min(data.asks, key=lambda x: x.price)
|
||||
mid_price = (best_bid.price + best_ask.price) / 2
|
||||
|
||||
tick = MarketTick(
|
||||
symbol=data.symbol,
|
||||
price=mid_price,
|
||||
volume=best_bid.size + best_ask.size,
|
||||
timestamp=data.timestamp,
|
||||
exchange=data.exchange
|
||||
)
|
||||
|
||||
await self._notify_tick_subscribers(tick)
|
||||
|
||||
# Create COB data for subscribers
|
||||
cob_data = {
|
||||
'symbol': data.symbol,
|
||||
'timestamp': data.timestamp.isoformat(),
|
||||
'bids': [{'price': b.price, 'size': b.size} for b in data.bids[:10]],
|
||||
'asks': [{'price': a.price, 'size': a.size} for a in data.asks[:10]],
|
||||
'exchange': data.exchange
|
||||
}
|
||||
|
||||
await self._notify_cob_raw_subscribers(data.symbol, cob_data)
|
||||
|
||||
elif isinstance(data, TradeEvent):
|
||||
await self.storage_manager.store_trade(data)
|
||||
self.stats['ticks_processed'] += 1
|
||||
|
||||
# Create market tick
|
||||
tick = MarketTick(
|
||||
symbol=data.symbol,
|
||||
price=data.price,
|
||||
volume=data.size,
|
||||
timestamp=data.timestamp,
|
||||
side=data.side,
|
||||
exchange=data.exchange
|
||||
)
|
||||
|
||||
await self._notify_tick_subscribers(tick)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling connector data: {e}")
|
||||
|
||||
async def _handle_replay_data(self, data: Union[OrderBookSnapshot, TradeEvent]) -> None:
|
||||
"""Handle data from replay system."""
|
||||
try:
|
||||
# Process replay data same as live data
|
||||
await self._handle_connector_data(data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling replay data: {e}")
|
||||
|
||||
async def _notify_tick_subscribers(self, tick: MarketTick) -> None:
|
||||
"""Notify tick subscribers."""
|
||||
try:
|
||||
with self.subscriber_lock:
|
||||
subscribers = self.subscribers['ticks'].copy()
|
||||
|
||||
for subscriber_id, sub_info in subscribers.items():
|
||||
try:
|
||||
callback = sub_info['callback']
|
||||
symbols = sub_info['symbols']
|
||||
|
||||
# Check if subscriber wants this symbol
|
||||
if not symbols or tick.symbol in symbols:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback(tick)
|
||||
else:
|
||||
callback(tick)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error notifying tick subscriber {subscriber_id}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error notifying tick subscribers: {e}")
|
||||
|
||||
async def _notify_cob_raw_subscribers(self, symbol: str, data: Dict) -> None:
|
||||
"""Notify COB raw tick subscribers."""
|
||||
try:
|
||||
with self.subscriber_lock:
|
||||
subscribers = self.subscribers['cob_raw'].copy()
|
||||
|
||||
for subscriber_id, sub_info in subscribers.items():
|
||||
try:
|
||||
callback = sub_info['callback']
|
||||
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback(symbol, data)
|
||||
else:
|
||||
callback(symbol, data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error notifying COB raw subscriber {subscriber_id}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error notifying COB raw subscribers: {e}")
|
||||
|
||||
# === UTILITY METHODS ===
|
||||
|
||||
def _parse_timeframe(self, timeframe: str) -> Optional[int]:
|
||||
"""Parse timeframe string to minutes."""
|
||||
try:
|
||||
if timeframe.endswith('m'):
|
||||
return int(timeframe[:-1])
|
||||
elif timeframe.endswith('h'):
|
||||
return int(timeframe[:-1]) * 60
|
||||
elif timeframe.endswith('d'):
|
||||
return int(timeframe[:-1]) * 24 * 60
|
||||
else:
|
||||
return None
|
||||
except:
|
||||
return None
|
||||
|
||||
def start_centralized_data_collection(self) -> None:
|
||||
"""Start centralized data collection (compatibility method)."""
|
||||
logger.info("Centralized data collection started (COBY mode)")
|
||||
|
||||
def start_training_data_collection(self) -> None:
|
||||
"""Start training data collection (compatibility method)."""
|
||||
logger.info("Training data collection started (COBY mode)")
|
||||
|
||||
def invalidate_ohlcv_cache(self, symbol: str) -> None:
|
||||
"""Invalidate OHLCV cache for a symbol."""
|
||||
try:
|
||||
# Clear Redis cache for this symbol
|
||||
cache_pattern = f"ohlcv:{symbol}:*"
|
||||
asyncio.run(self.redis_manager.delete_pattern(cache_pattern))
|
||||
|
||||
# Clear local price cache
|
||||
if symbol in self.price_cache:
|
||||
del self.price_cache[symbol]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error invalidating cache for {symbol}: {e}")
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close all connections and cleanup."""
|
||||
try:
|
||||
# Stop replay session
|
||||
if self.current_replay_session:
|
||||
await self.replay_manager.stop_replay(self.current_replay_session)
|
||||
|
||||
# Close connectors
|
||||
for connector in self.connectors.values():
|
||||
await connector.disconnect()
|
||||
|
||||
# Close storage
|
||||
await self.storage_manager.close()
|
||||
|
||||
# Close Redis
|
||||
await self.redis_manager.close()
|
||||
|
||||
logger.info("COBY orchestrator adapter closed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing adapter: {e}")
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get adapter statistics."""
|
||||
return {
|
||||
**self.stats,
|
||||
'mode': self.mode,
|
||||
'active_subscribers': sum(len(subs) for subs in self.subscribers.values()),
|
||||
'cache_size': len(self.price_cache),
|
||||
'replay_session': self.current_replay_session
|
||||
}
|
385
COBY/tests/test_orchestrator_integration.py
Normal file
385
COBY/tests/test_orchestrator_integration.py
Normal file
@ -0,0 +1,385 @@
|
||||
"""
|
||||
Integration tests for COBY orchestrator compatibility.
|
||||
Tests the adapter's compatibility with the existing orchestrator interface.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import pytest
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import Mock, AsyncMock
|
||||
|
||||
from ..integration.orchestrator_adapter import COBYOrchestratorAdapter, MarketTick
|
||||
from ..integration.data_provider_replacement import COBYDataProvider
|
||||
from ..config import Config
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TestOrchestratorIntegration:
|
||||
"""Test suite for orchestrator integration."""
|
||||
|
||||
@pytest.fixture
|
||||
async def adapter(self):
|
||||
"""Create adapter instance for testing."""
|
||||
config = Config()
|
||||
adapter = COBYOrchestratorAdapter(config)
|
||||
|
||||
# Mock the storage manager for testing
|
||||
adapter.storage_manager = Mock()
|
||||
adapter.storage_manager.initialize = AsyncMock()
|
||||
adapter.storage_manager.is_healthy = Mock(return_value=True)
|
||||
adapter.storage_manager.get_latest_orderbook = AsyncMock(return_value={
|
||||
'symbol': 'BTCUSDT',
|
||||
'timestamp': datetime.utcnow(),
|
||||
'mid_price': 50000.0,
|
||||
'spread': 0.01,
|
||||
'bid_volume': 10.5,
|
||||
'ask_volume': 8.3,
|
||||
'exchange': 'binance'
|
||||
})
|
||||
adapter.storage_manager.get_historical_data = AsyncMock(return_value=[
|
||||
{
|
||||
'timestamp': datetime.utcnow() - timedelta(minutes=i),
|
||||
'open': 50000 + i,
|
||||
'high': 50010 + i,
|
||||
'low': 49990 + i,
|
||||
'close': 50005 + i,
|
||||
'volume': 100 + i,
|
||||
'symbol': 'BTCUSDT',
|
||||
'exchange': 'binance'
|
||||
}
|
||||
for i in range(100)
|
||||
])
|
||||
|
||||
# Mock Redis manager
|
||||
adapter.redis_manager = Mock()
|
||||
adapter.redis_manager.initialize = AsyncMock()
|
||||
adapter.redis_manager.get = AsyncMock(return_value=None)
|
||||
adapter.redis_manager.set = AsyncMock()
|
||||
|
||||
# Mock connectors
|
||||
adapter.connectors = {'binance': Mock()}
|
||||
adapter.connectors['binance'].connect = AsyncMock()
|
||||
adapter.connectors['binance'].is_connected = True
|
||||
|
||||
await adapter._initialize_components()
|
||||
return adapter
|
||||
|
||||
@pytest.fixture
|
||||
async def data_provider(self):
|
||||
"""Create data provider replacement for testing."""
|
||||
# Mock the adapter initialization
|
||||
provider = COBYDataProvider()
|
||||
|
||||
# Use the same mocks as adapter
|
||||
provider.adapter.storage_manager = Mock()
|
||||
provider.adapter.storage_manager.get_latest_orderbook = AsyncMock(return_value={
|
||||
'symbol': 'BTCUSDT',
|
||||
'timestamp': datetime.utcnow(),
|
||||
'mid_price': 50000.0,
|
||||
'spread': 0.01,
|
||||
'bid_volume': 10.5,
|
||||
'ask_volume': 8.3,
|
||||
'exchange': 'binance'
|
||||
})
|
||||
|
||||
return provider
|
||||
|
||||
async def test_adapter_initialization(self, adapter):
|
||||
"""Test adapter initializes correctly."""
|
||||
assert adapter is not None
|
||||
assert adapter.mode == 'live'
|
||||
assert adapter.config is not None
|
||||
assert 'binance' in adapter.connectors
|
||||
|
||||
async def test_get_current_price(self, adapter):
|
||||
"""Test getting current price."""
|
||||
price = adapter.get_current_price('BTCUSDT')
|
||||
assert price == 50000.0
|
||||
|
||||
async def test_get_historical_data(self, adapter):
|
||||
"""Test getting historical data."""
|
||||
df = adapter.get_historical_data('BTCUSDT', '1m', limit=50)
|
||||
|
||||
assert df is not None
|
||||
assert len(df) == 100 # Mock returns 100 records
|
||||
assert 'open' in df.columns
|
||||
assert 'high' in df.columns
|
||||
assert 'low' in df.columns
|
||||
assert 'close' in df.columns
|
||||
assert 'volume' in df.columns
|
||||
|
||||
async def test_build_base_data_input(self, adapter):
|
||||
"""Test building base data input."""
|
||||
base_data = adapter.build_base_data_input('BTCUSDT')
|
||||
|
||||
assert base_data is not None
|
||||
assert hasattr(base_data, 'get_feature_vector')
|
||||
|
||||
features = base_data.get_feature_vector()
|
||||
assert isinstance(features, type(features)) # numpy array
|
||||
assert len(features) == 100 # Expected feature size
|
||||
|
||||
async def test_cob_data_methods(self, adapter):
|
||||
"""Test COB data access methods."""
|
||||
# Mock COB data
|
||||
adapter.storage_manager.get_historical_data = AsyncMock(return_value=[
|
||||
{
|
||||
'symbol': 'BTCUSDT',
|
||||
'timestamp': datetime.utcnow(),
|
||||
'mid_price': 50000.0,
|
||||
'spread': 0.01,
|
||||
'bid_volume': 10.5,
|
||||
'ask_volume': 8.3,
|
||||
'exchange': 'binance'
|
||||
}
|
||||
])
|
||||
|
||||
# Test raw ticks
|
||||
raw_ticks = adapter.get_cob_raw_ticks('BTCUSDT', count=10)
|
||||
assert isinstance(raw_ticks, list)
|
||||
|
||||
# Test latest COB data
|
||||
latest_cob = adapter.get_latest_cob_data('BTCUSDT')
|
||||
assert latest_cob is not None
|
||||
assert latest_cob['symbol'] == 'BTCUSDT'
|
||||
assert 'mid_price' in latest_cob
|
||||
|
||||
async def test_subscription_management(self, adapter):
|
||||
"""Test subscription methods."""
|
||||
callback_called = False
|
||||
received_tick = None
|
||||
|
||||
def tick_callback(tick):
|
||||
nonlocal callback_called, received_tick
|
||||
callback_called = True
|
||||
received_tick = tick
|
||||
|
||||
# Subscribe to ticks
|
||||
subscriber_id = adapter.subscribe_to_ticks(
|
||||
tick_callback,
|
||||
symbols=['BTCUSDT'],
|
||||
subscriber_name='test_subscriber'
|
||||
)
|
||||
|
||||
assert subscriber_id != ""
|
||||
assert len(adapter.subscribers['ticks']) == 1
|
||||
|
||||
# Simulate tick notification
|
||||
test_tick = MarketTick(
|
||||
symbol='BTCUSDT',
|
||||
price=50000.0,
|
||||
volume=1.5,
|
||||
timestamp=datetime.utcnow(),
|
||||
exchange='binance'
|
||||
)
|
||||
|
||||
await adapter._notify_tick_subscribers(test_tick)
|
||||
|
||||
assert callback_called
|
||||
assert received_tick is not None
|
||||
assert received_tick.symbol == 'BTCUSDT'
|
||||
|
||||
# Unsubscribe
|
||||
success = adapter.unsubscribe(subscriber_id)
|
||||
assert success
|
||||
assert len(adapter.subscribers['ticks']) == 0
|
||||
|
||||
async def test_mode_switching(self, adapter):
|
||||
"""Test switching between live and replay modes."""
|
||||
# Initially in live mode
|
||||
assert adapter.get_current_mode() == 'live'
|
||||
|
||||
# Mock replay manager
|
||||
adapter.replay_manager = Mock()
|
||||
adapter.replay_manager.create_replay_session = Mock(return_value='test_session_123')
|
||||
adapter.replay_manager.add_data_callback = Mock()
|
||||
adapter.replay_manager.start_replay = AsyncMock()
|
||||
adapter.replay_manager.stop_replay = AsyncMock()
|
||||
|
||||
# Switch to replay mode
|
||||
start_time = datetime.utcnow() - timedelta(hours=1)
|
||||
end_time = datetime.utcnow()
|
||||
|
||||
success = await adapter.switch_to_replay_mode(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
speed=2.0,
|
||||
symbols=['BTCUSDT']
|
||||
)
|
||||
|
||||
assert success
|
||||
assert adapter.get_current_mode() == 'replay'
|
||||
assert adapter.current_replay_session == 'test_session_123'
|
||||
|
||||
# Switch back to live mode
|
||||
success = await adapter.switch_to_live_mode()
|
||||
assert success
|
||||
assert adapter.get_current_mode() == 'live'
|
||||
assert adapter.current_replay_session is None
|
||||
|
||||
async def test_data_quality_indicators(self, adapter):
|
||||
"""Test data quality indicators."""
|
||||
quality = adapter.get_data_quality_indicators('BTCUSDT')
|
||||
|
||||
assert quality is not None
|
||||
assert quality['symbol'] == 'BTCUSDT'
|
||||
assert 'quality_score' in quality
|
||||
assert 'timestamp' in quality
|
||||
assert isinstance(quality['quality_score'], float)
|
||||
assert 0.0 <= quality['quality_score'] <= 1.0
|
||||
|
||||
async def test_system_metadata(self, adapter):
|
||||
"""Test system metadata retrieval."""
|
||||
metadata = adapter.get_system_metadata()
|
||||
|
||||
assert metadata is not None
|
||||
assert metadata['system'] == 'COBY'
|
||||
assert metadata['version'] == '1.0.0'
|
||||
assert 'mode' in metadata
|
||||
assert 'components' in metadata
|
||||
assert 'statistics' in metadata
|
||||
|
||||
async def test_data_provider_compatibility(self, data_provider):
|
||||
"""Test data provider replacement compatibility."""
|
||||
# Test core methods exist and work
|
||||
assert hasattr(data_provider, 'get_historical_data')
|
||||
assert hasattr(data_provider, 'get_current_price')
|
||||
assert hasattr(data_provider, 'build_base_data_input')
|
||||
assert hasattr(data_provider, 'subscribe_to_ticks')
|
||||
assert hasattr(data_provider, 'get_cob_raw_ticks')
|
||||
|
||||
# Test current price
|
||||
price = data_provider.get_current_price('BTCUSDT')
|
||||
assert price == 50000.0
|
||||
|
||||
# Test COB imbalance
|
||||
imbalance = data_provider.get_current_cob_imbalance('BTCUSDT')
|
||||
assert 'bid_volume' in imbalance
|
||||
assert 'ask_volume' in imbalance
|
||||
assert 'imbalance' in imbalance
|
||||
|
||||
# Test WebSocket status
|
||||
status = data_provider.get_cob_websocket_status()
|
||||
assert 'connected' in status
|
||||
assert 'exchanges' in status
|
||||
|
||||
async def test_error_handling(self, adapter):
|
||||
"""Test error handling in various scenarios."""
|
||||
# Test with invalid symbol
|
||||
price = adapter.get_current_price('INVALID_SYMBOL')
|
||||
# Should not raise exception, may return None
|
||||
|
||||
# Test with storage error
|
||||
adapter.storage_manager.get_latest_orderbook = AsyncMock(side_effect=Exception("Storage error"))
|
||||
|
||||
price = adapter.get_current_price('BTCUSDT')
|
||||
# Should handle error gracefully
|
||||
|
||||
# Test subscription with invalid callback
|
||||
subscriber_id = adapter.subscribe_to_ticks(None, ['BTCUSDT'])
|
||||
# Should handle gracefully
|
||||
|
||||
async def test_performance_metrics(self, adapter):
|
||||
"""Test performance metrics and statistics."""
|
||||
# Get initial stats
|
||||
initial_stats = adapter.get_stats()
|
||||
assert 'ticks_processed' in initial_stats
|
||||
assert 'orderbooks_processed' in initial_stats
|
||||
|
||||
# Simulate some data processing
|
||||
from ..models.core import OrderBookSnapshot, PriceLevel
|
||||
|
||||
test_orderbook = OrderBookSnapshot(
|
||||
symbol='BTCUSDT',
|
||||
exchange='binance',
|
||||
timestamp=datetime.utcnow(),
|
||||
bids=[PriceLevel(price=49999.0, size=1.5)],
|
||||
asks=[PriceLevel(price=50001.0, size=1.2)]
|
||||
)
|
||||
|
||||
await adapter._handle_connector_data(test_orderbook)
|
||||
|
||||
# Check stats updated
|
||||
updated_stats = adapter.get_stats()
|
||||
assert updated_stats['orderbooks_processed'] >= initial_stats['orderbooks_processed']
|
||||
|
||||
|
||||
async def test_integration_suite():
|
||||
"""Run the complete integration test suite."""
|
||||
logger.info("Starting COBY orchestrator integration tests...")
|
||||
|
||||
try:
|
||||
# Create test instances
|
||||
config = Config()
|
||||
adapter = COBYOrchestratorAdapter(config)
|
||||
|
||||
# Mock components for testing
|
||||
adapter.storage_manager = Mock()
|
||||
adapter.storage_manager.initialize = AsyncMock()
|
||||
adapter.storage_manager.is_healthy = Mock(return_value=True)
|
||||
adapter.redis_manager = Mock()
|
||||
adapter.redis_manager.initialize = AsyncMock()
|
||||
adapter.connectors = {'binance': Mock()}
|
||||
adapter.connectors['binance'].connect = AsyncMock()
|
||||
|
||||
await adapter._initialize_components()
|
||||
|
||||
# Run basic functionality tests
|
||||
logger.info("Testing basic functionality...")
|
||||
|
||||
# Test price retrieval
|
||||
adapter.storage_manager.get_latest_orderbook = AsyncMock(return_value={
|
||||
'symbol': 'BTCUSDT',
|
||||
'timestamp': datetime.utcnow(),
|
||||
'mid_price': 50000.0,
|
||||
'spread': 0.01,
|
||||
'bid_volume': 10.5,
|
||||
'ask_volume': 8.3,
|
||||
'exchange': 'binance'
|
||||
})
|
||||
|
||||
price = adapter.get_current_price('BTCUSDT')
|
||||
assert price == 50000.0
|
||||
logger.info(f"✓ Current price retrieval: {price}")
|
||||
|
||||
# Test subscription
|
||||
callback_called = False
|
||||
|
||||
def test_callback(tick):
|
||||
nonlocal callback_called
|
||||
callback_called = True
|
||||
|
||||
subscriber_id = adapter.subscribe_to_ticks(test_callback, ['BTCUSDT'])
|
||||
assert subscriber_id != ""
|
||||
logger.info(f"✓ Subscription created: {subscriber_id}")
|
||||
|
||||
# Test data quality
|
||||
quality = adapter.get_data_quality_indicators('BTCUSDT')
|
||||
assert quality['symbol'] == 'BTCUSDT'
|
||||
logger.info(f"✓ Data quality check: {quality['quality_score']}")
|
||||
|
||||
# Test system metadata
|
||||
metadata = adapter.get_system_metadata()
|
||||
assert metadata['system'] == 'COBY'
|
||||
logger.info(f"✓ System metadata: {metadata['mode']}")
|
||||
|
||||
logger.info("All integration tests passed successfully!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Integration test failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the integration tests
|
||||
success = asyncio.run(test_integration_suite())
|
||||
if success:
|
||||
print("✓ COBY orchestrator integration tests completed successfully")
|
||||
else:
|
||||
print("✗ COBY orchestrator integration tests failed")
|
||||
exit(1)
|
Reference in New Issue
Block a user