orchestrator adaptor
This commit is contained in:
390
COBY/integration/data_provider_replacement.py
Normal file
390
COBY/integration/data_provider_replacement.py
Normal file
@ -0,0 +1,390 @@
|
||||
"""
|
||||
Drop-in replacement for the existing DataProvider class using COBY system.
|
||||
Provides full compatibility with the orchestrator interface.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any, Callable, Union
|
||||
from pathlib import Path
|
||||
|
||||
from .orchestrator_adapter import COBYOrchestratorAdapter, MarketTick, PivotBounds
|
||||
from ..config import Config
|
||||
from ..utils.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class COBYDataProvider:
|
||||
"""
|
||||
Drop-in replacement for DataProvider using COBY system.
|
||||
|
||||
Provides full compatibility with existing orchestrator interface while
|
||||
leveraging COBY's multi-exchange data aggregation capabilities.
|
||||
"""
|
||||
|
||||
def __init__(self, config_path: Optional[str] = None):
|
||||
"""
|
||||
Initialize COBY data provider.
|
||||
|
||||
Args:
|
||||
config_path: Optional path to configuration file
|
||||
"""
|
||||
# Initialize COBY configuration
|
||||
self.config = Config()
|
||||
|
||||
# Initialize COBY adapter
|
||||
self.adapter = COBYOrchestratorAdapter(self.config)
|
||||
|
||||
# Initialize adapter components
|
||||
asyncio.run(self.adapter._initialize_components())
|
||||
|
||||
# Compatibility attributes
|
||||
self.symbols = self.config.exchanges.symbols
|
||||
self.exchanges = self.config.exchanges.exchanges
|
||||
|
||||
logger.info("COBY data provider initialized")
|
||||
|
||||
# === CORE DATA METHODS ===
|
||||
|
||||
def get_historical_data(self, symbol: str, timeframe: str, limit: int = 1000,
|
||||
refresh: bool = False) -> Optional[pd.DataFrame]:
|
||||
"""Get historical OHLCV data."""
|
||||
return self.adapter.get_historical_data(symbol, timeframe, limit, refresh)
|
||||
|
||||
def get_current_price(self, symbol: str) -> Optional[float]:
|
||||
"""Get current price for a symbol."""
|
||||
return self.adapter.get_current_price(symbol)
|
||||
|
||||
def get_live_price_from_api(self, symbol: str) -> Optional[float]:
|
||||
"""Get live price from API (low-latency method)."""
|
||||
return self.adapter.get_live_price_from_api(symbol)
|
||||
|
||||
def build_base_data_input(self, symbol: str) -> Optional[Any]:
|
||||
"""Build base data input for ML models."""
|
||||
return self.adapter.build_base_data_input(symbol)
|
||||
|
||||
# === COB DATA METHODS ===
|
||||
|
||||
def get_cob_raw_ticks(self, symbol: str, count: int = 1000) -> List[Dict]:
|
||||
"""Get raw COB ticks for a symbol."""
|
||||
return self.adapter.get_cob_raw_ticks(symbol, count)
|
||||
|
||||
def get_cob_1s_aggregated(self, symbol: str, count: int = 300) -> List[Dict]:
|
||||
"""Get 1s aggregated COB data with $1 price buckets."""
|
||||
return self.adapter.get_cob_1s_aggregated(symbol, count)
|
||||
|
||||
def get_latest_cob_data(self, symbol: str) -> Optional[Dict]:
|
||||
"""Get latest COB raw tick for a symbol."""
|
||||
return self.adapter.get_latest_cob_data(symbol)
|
||||
|
||||
def get_latest_cob_aggregated(self, symbol: str) -> Optional[Dict]:
|
||||
"""Get latest 1s aggregated COB data for a symbol."""
|
||||
return self.adapter.get_latest_cob_aggregated(symbol)
|
||||
|
||||
def get_current_cob_imbalance(self, symbol: str) -> Dict[str, float]:
|
||||
"""Get current COB imbalance metrics for a symbol."""
|
||||
try:
|
||||
latest_data = self.get_latest_cob_data(symbol)
|
||||
if not latest_data:
|
||||
return {'bid_volume': 0.0, 'ask_volume': 0.0, 'imbalance': 0.0}
|
||||
|
||||
bid_volume = latest_data.get('bid_volume', 0.0)
|
||||
ask_volume = latest_data.get('ask_volume', 0.0)
|
||||
total_volume = bid_volume + ask_volume
|
||||
|
||||
imbalance = 0.0
|
||||
if total_volume > 0:
|
||||
imbalance = (bid_volume - ask_volume) / total_volume
|
||||
|
||||
return {
|
||||
'bid_volume': bid_volume,
|
||||
'ask_volume': ask_volume,
|
||||
'imbalance': imbalance
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting COB imbalance for {symbol}: {e}")
|
||||
return {'bid_volume': 0.0, 'ask_volume': 0.0, 'imbalance': 0.0}
|
||||
|
||||
def get_cob_price_buckets(self, symbol: str, timeframe_seconds: int = 60) -> Dict:
|
||||
"""Get price bucket analysis for a timeframe."""
|
||||
try:
|
||||
# Get aggregated data for the timeframe
|
||||
count = timeframe_seconds # 1 second per data point
|
||||
aggregated_data = self.get_cob_1s_aggregated(symbol, count)
|
||||
|
||||
if not aggregated_data:
|
||||
return {}
|
||||
|
||||
# Combine all buckets
|
||||
combined_bid_buckets = {}
|
||||
combined_ask_buckets = {}
|
||||
|
||||
for data_point in aggregated_data:
|
||||
for price, volume in data_point.get('bid_buckets', {}).items():
|
||||
combined_bid_buckets[price] = combined_bid_buckets.get(price, 0) + volume
|
||||
|
||||
for price, volume in data_point.get('ask_buckets', {}).items():
|
||||
combined_ask_buckets[price] = combined_ask_buckets.get(price, 0) + volume
|
||||
|
||||
return {
|
||||
'symbol': symbol,
|
||||
'timeframe_seconds': timeframe_seconds,
|
||||
'bid_buckets': combined_bid_buckets,
|
||||
'ask_buckets': combined_ask_buckets,
|
||||
'timestamp': datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting price buckets for {symbol}: {e}")
|
||||
return {}
|
||||
|
||||
def get_cob_websocket_status(self) -> Dict[str, Any]:
|
||||
"""Get COB WebSocket status."""
|
||||
try:
|
||||
system_metadata = self.adapter.get_system_metadata()
|
||||
connectors = system_metadata.get('components', {}).get('connectors', {})
|
||||
|
||||
return {
|
||||
'connected': any(connectors.values()),
|
||||
'exchanges': connectors,
|
||||
'last_update': datetime.utcnow().isoformat(),
|
||||
'mode': system_metadata.get('mode', 'unknown')
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting WebSocket status: {e}")
|
||||
return {'connected': False, 'error': str(e)}
|
||||
|
||||
# === SUBSCRIPTION METHODS ===
|
||||
|
||||
def subscribe_to_ticks(self, callback: Callable[[MarketTick], None],
|
||||
symbols: List[str] = None,
|
||||
subscriber_name: str = None) -> str:
|
||||
"""Subscribe to tick data updates."""
|
||||
return self.adapter.subscribe_to_ticks(callback, symbols, subscriber_name)
|
||||
|
||||
def subscribe_to_cob_raw_ticks(self, callback: Callable[[str, Dict], None]) -> str:
|
||||
"""Subscribe to raw COB tick updates."""
|
||||
return self.adapter.subscribe_to_cob_raw_ticks(callback)
|
||||
|
||||
def subscribe_to_cob_aggregated(self, callback: Callable[[str, Dict], None]) -> str:
|
||||
"""Subscribe to 1s aggregated COB updates."""
|
||||
return self.adapter.subscribe_to_cob_aggregated(callback)
|
||||
|
||||
def subscribe_to_training_data(self, callback: Callable[[str, dict], None]) -> str:
|
||||
"""Subscribe to training data updates."""
|
||||
return self.adapter.subscribe_to_training_data(callback)
|
||||
|
||||
def subscribe_to_model_predictions(self, callback: Callable[[str, dict], None]) -> str:
|
||||
"""Subscribe to model prediction updates."""
|
||||
return self.adapter.subscribe_to_model_predictions(callback)
|
||||
|
||||
def unsubscribe(self, subscriber_id: str) -> bool:
|
||||
"""Unsubscribe from data feeds."""
|
||||
return self.adapter.unsubscribe(subscriber_id)
|
||||
|
||||
# === MODE SWITCHING ===
|
||||
|
||||
async def switch_to_live_mode(self) -> bool:
|
||||
"""Switch to live data mode."""
|
||||
return await self.adapter.switch_to_live_mode()
|
||||
|
||||
async def switch_to_replay_mode(self, start_time: datetime, end_time: datetime,
|
||||
speed: float = 1.0, symbols: List[str] = None) -> bool:
|
||||
"""Switch to replay data mode."""
|
||||
return await self.adapter.switch_to_replay_mode(start_time, end_time, speed, symbols)
|
||||
|
||||
def get_current_mode(self) -> str:
|
||||
"""Get current data mode."""
|
||||
return self.adapter.get_current_mode()
|
||||
|
||||
def get_replay_status(self) -> Optional[Dict[str, Any]]:
|
||||
"""Get replay session status."""
|
||||
return self.adapter.get_replay_status()
|
||||
|
||||
# === COMPATIBILITY METHODS ===
|
||||
|
||||
def start_centralized_data_collection(self) -> None:
|
||||
"""Start centralized data collection."""
|
||||
self.adapter.start_centralized_data_collection()
|
||||
|
||||
def start_training_data_collection(self) -> None:
|
||||
"""Start training data collection."""
|
||||
self.adapter.start_training_data_collection()
|
||||
|
||||
def invalidate_ohlcv_cache(self, symbol: str) -> None:
|
||||
"""Invalidate OHLCV cache for a symbol."""
|
||||
self.adapter.invalidate_ohlcv_cache(symbol)
|
||||
|
||||
def get_latest_candles(self, symbol: str, timeframe: str, limit: int = 100) -> pd.DataFrame:
|
||||
"""Get the latest candles from cached data."""
|
||||
return self.get_historical_data(symbol, timeframe, limit) or pd.DataFrame()
|
||||
|
||||
def get_price_at_index(self, symbol: str, index: int, timeframe: str = '1m') -> Optional[float]:
|
||||
"""Get price at specific index for backtesting."""
|
||||
try:
|
||||
df = self.get_historical_data(symbol, timeframe, limit=index + 10)
|
||||
if df is not None and len(df) > index:
|
||||
return float(df.iloc[-(index + 1)]['close'])
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting price at index {index} for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
# === PIVOT AND MARKET STRUCTURE (MOCK IMPLEMENTATIONS) ===
|
||||
|
||||
def get_pivot_bounds(self, symbol: str) -> Optional[PivotBounds]:
|
||||
"""Get pivot bounds for a symbol (mock implementation)."""
|
||||
try:
|
||||
# Get recent price data
|
||||
df = self.get_historical_data(symbol, '1m', limit=1000)
|
||||
if df is None or df.empty:
|
||||
return None
|
||||
|
||||
# Calculate basic pivot levels
|
||||
high_prices = df['high'].values
|
||||
low_prices = df['low'].values
|
||||
volumes = df['volume'].values
|
||||
|
||||
price_max = float(np.max(high_prices))
|
||||
price_min = float(np.min(low_prices))
|
||||
volume_max = float(np.max(volumes))
|
||||
volume_min = float(np.min(volumes))
|
||||
|
||||
# Simple support/resistance calculation
|
||||
price_range = price_max - price_min
|
||||
support_levels = [price_min + i * price_range / 10 for i in range(1, 5)]
|
||||
resistance_levels = [price_max - i * price_range / 10 for i in range(1, 5)]
|
||||
|
||||
return PivotBounds(
|
||||
symbol=symbol,
|
||||
price_max=price_max,
|
||||
price_min=price_min,
|
||||
volume_max=volume_max,
|
||||
volume_min=volume_min,
|
||||
pivot_support_levels=support_levels,
|
||||
pivot_resistance_levels=resistance_levels,
|
||||
pivot_context={'method': 'simple'},
|
||||
created_timestamp=datetime.utcnow(),
|
||||
data_period_start=df.index[0].to_pydatetime(),
|
||||
data_period_end=df.index[-1].to_pydatetime(),
|
||||
total_candles_analyzed=len(df)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting pivot bounds for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def get_pivot_normalized_features(self, symbol: str, df: pd.DataFrame) -> Optional[pd.DataFrame]:
|
||||
"""Get dataframe with pivot-normalized features."""
|
||||
try:
|
||||
pivot_bounds = self.get_pivot_bounds(symbol)
|
||||
if not pivot_bounds:
|
||||
return df
|
||||
|
||||
# Add normalized features
|
||||
df_copy = df.copy()
|
||||
price_range = pivot_bounds.get_price_range()
|
||||
|
||||
if price_range > 0:
|
||||
df_copy['normalized_close'] = (df_copy['close'] - pivot_bounds.price_min) / price_range
|
||||
df_copy['normalized_high'] = (df_copy['high'] - pivot_bounds.price_min) / price_range
|
||||
df_copy['normalized_low'] = (df_copy['low'] - pivot_bounds.price_min) / price_range
|
||||
|
||||
return df_copy
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting pivot normalized features for {symbol}: {e}")
|
||||
return df
|
||||
|
||||
# === FEATURE EXTRACTION METHODS ===
|
||||
|
||||
def get_feature_matrix(self, symbol: str, timeframes: List[str] = None,
|
||||
window_size: int = 20) -> Optional[np.ndarray]:
|
||||
"""Get feature matrix for ML models."""
|
||||
try:
|
||||
if not timeframes:
|
||||
timeframes = ['1m', '5m', '15m']
|
||||
|
||||
features = []
|
||||
|
||||
for timeframe in timeframes:
|
||||
df = self.get_historical_data(symbol, timeframe, limit=window_size + 10)
|
||||
if df is not None and len(df) >= window_size:
|
||||
# Extract basic features
|
||||
closes = df['close'].values[-window_size:]
|
||||
volumes = df['volume'].values[-window_size:]
|
||||
|
||||
# Normalize features
|
||||
close_mean = np.mean(closes)
|
||||
close_std = np.std(closes) + 1e-8
|
||||
normalized_closes = (closes - close_mean) / close_std
|
||||
|
||||
volume_mean = np.mean(volumes)
|
||||
volume_std = np.std(volumes) + 1e-8
|
||||
normalized_volumes = (volumes - volume_mean) / volume_std
|
||||
|
||||
features.extend(normalized_closes)
|
||||
features.extend(normalized_volumes)
|
||||
|
||||
if features:
|
||||
return np.array(features, dtype=np.float32)
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting feature matrix for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
# === SYSTEM STATUS AND STATISTICS ===
|
||||
|
||||
def get_cached_data_summary(self) -> Dict[str, Any]:
|
||||
"""Get summary of cached data."""
|
||||
try:
|
||||
system_metadata = self.adapter.get_system_metadata()
|
||||
return {
|
||||
'system': 'COBY',
|
||||
'mode': system_metadata.get('mode'),
|
||||
'statistics': system_metadata.get('statistics', {}),
|
||||
'components_healthy': system_metadata.get('components', {}),
|
||||
'active_subscribers': system_metadata.get('active_subscribers', 0)
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting cached data summary: {e}")
|
||||
return {'error': str(e)}
|
||||
|
||||
def get_cob_data_quality(self) -> Dict[str, Any]:
|
||||
"""Get COB data quality information."""
|
||||
try:
|
||||
quality_info = {}
|
||||
|
||||
for symbol in self.symbols:
|
||||
quality_info[symbol] = self.adapter.get_data_quality_indicators(symbol)
|
||||
|
||||
return quality_info
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting COB data quality: {e}")
|
||||
return {'error': str(e)}
|
||||
|
||||
def get_subscriber_stats(self) -> Dict[str, Any]:
|
||||
"""Get subscriber statistics."""
|
||||
return self.adapter.get_stats()
|
||||
|
||||
# === CLEANUP ===
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close all connections and cleanup."""
|
||||
await self.adapter.close()
|
||||
|
||||
def __del__(self):
|
||||
"""Cleanup on deletion."""
|
||||
try:
|
||||
asyncio.run(self.close())
|
||||
except:
|
||||
pass
|
Reference in New Issue
Block a user