390 lines
16 KiB
Python
390 lines
16 KiB
Python
"""
|
|
Drop-in replacement for the existing DataProvider class using COBY system.
|
|
Provides full compatibility with the orchestrator interface.
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import numpy as np
|
|
import pandas as pd
|
|
from datetime import datetime, timedelta
|
|
from typing import Dict, List, Optional, Any, Callable, Union
|
|
from pathlib import Path
|
|
|
|
from .orchestrator_adapter import COBYOrchestratorAdapter, MarketTick, PivotBounds
|
|
from ..config import Config
|
|
from ..utils.logging import get_logger
|
|
|
|
# Module-level logger created by the project's get_logger helper
# (imported above from ..utils.logging) and shared by COBYDataProvider.
logger = get_logger(__name__)
|
|
|
|
|
|
class COBYDataProvider:
    """
    Drop-in replacement for DataProvider using the COBY system.

    Exposes the public methods the orchestrator expects from the existing
    DataProvider. Most calls are delegated to a COBYOrchestratorAdapter,
    which aggregates data from multiple exchanges. A few helpers (COB
    imbalance, price buckets, pivot bounds, feature matrix) are computed
    locally from data returned by the adapter.
    """

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the COBY data provider.

        Builds the COBY configuration and adapter, then synchronously runs the
        adapter's async component initialization with ``asyncio.run``.

        Args:
            config_path: Optional path to a configuration file.
                NOTE(review): currently unused -- ``Config()`` is constructed
                without arguments. Confirm whether it should be passed through.

        Raises:
            RuntimeError: If called while an asyncio event loop is already
                running in the current thread (``asyncio.run`` cannot nest).
        """
        # asyncio.run() refuses to run inside an active event loop. Check up
        # front so we fail with a clear message instead of leaving an
        # un-awaited _initialize_components() coroutine behind.
        if self._has_running_loop():
            raise RuntimeError(
                "COBYDataProvider() cannot be created inside a running event "
                "loop: adapter initialization uses asyncio.run()"
            )

        # Initialize COBY configuration
        self.config = Config()

        # Initialize COBY adapter
        self.adapter = COBYOrchestratorAdapter(self.config)

        # Initialize adapter components (async API driven synchronously)
        asyncio.run(self.adapter._initialize_components())

        # Compatibility attributes mirrored from the COBY exchange config
        self.symbols = self.config.exchanges.symbols
        self.exchanges = self.config.exchanges.exchanges

        logger.info("COBY data provider initialized")

    @staticmethod
    def _has_running_loop() -> bool:
        """Return True if an asyncio event loop is running in this thread."""
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return False
        return True

    # === CORE DATA METHODS ===

    def get_historical_data(self, symbol: str, timeframe: str, limit: int = 1000,
                            refresh: bool = False) -> Optional[pd.DataFrame]:
        """Get historical OHLCV data (delegates to the adapter)."""
        return self.adapter.get_historical_data(symbol, timeframe, limit, refresh)

    def get_current_price(self, symbol: str) -> Optional[float]:
        """Get current price for a symbol (delegates to the adapter)."""
        return self.adapter.get_current_price(symbol)

    def get_live_price_from_api(self, symbol: str) -> Optional[float]:
        """Get live price from API, the low-latency path (delegates to the adapter)."""
        return self.adapter.get_live_price_from_api(symbol)

    def build_base_data_input(self, symbol: str) -> Optional[Any]:
        """Build base data input for ML models (delegates to the adapter)."""
        return self.adapter.build_base_data_input(symbol)

    # === COB DATA METHODS ===

    def get_cob_raw_ticks(self, symbol: str, count: int = 1000) -> List[Dict]:
        """Get up to ``count`` raw COB ticks for a symbol (delegates to the adapter)."""
        return self.adapter.get_cob_raw_ticks(symbol, count)

    def get_cob_1s_aggregated(self, symbol: str, count: int = 300) -> List[Dict]:
        """Get 1s aggregated COB data with $1 price buckets (delegates to the adapter)."""
        return self.adapter.get_cob_1s_aggregated(symbol, count)

    def get_latest_cob_data(self, symbol: str) -> Optional[Dict]:
        """Get the latest raw COB tick for a symbol (delegates to the adapter)."""
        return self.adapter.get_latest_cob_data(symbol)

    def get_latest_cob_aggregated(self, symbol: str) -> Optional[Dict]:
        """Get the latest 1s aggregated COB data for a symbol (delegates to the adapter)."""
        return self.adapter.get_latest_cob_aggregated(symbol)

    def get_current_cob_imbalance(self, symbol: str) -> Dict[str, float]:
        """
        Compute bid/ask imbalance from the latest raw COB tick.

        Returns:
            Dict with 'bid_volume', 'ask_volume' and 'imbalance', where
            imbalance = (bid - ask) / (bid + ask), or 0.0 when total volume
            is not positive. All zeros when no data exists or on error.
        """
        try:
            latest_data = self.get_latest_cob_data(symbol)
            if not latest_data:
                return {'bid_volume': 0.0, 'ask_volume': 0.0, 'imbalance': 0.0}

            bid_volume = latest_data.get('bid_volume', 0.0)
            ask_volume = latest_data.get('ask_volume', 0.0)
            total_volume = bid_volume + ask_volume

            # Avoid dividing by zero on an empty book
            imbalance = 0.0
            if total_volume > 0:
                imbalance = (bid_volume - ask_volume) / total_volume

            return {
                'bid_volume': bid_volume,
                'ask_volume': ask_volume,
                'imbalance': imbalance
            }

        except Exception as e:
            logger.error(f"Error getting COB imbalance for {symbol}: {e}")
            return {'bid_volume': 0.0, 'ask_volume': 0.0, 'imbalance': 0.0}

    def get_cob_price_buckets(self, symbol: str, timeframe_seconds: int = 60) -> Dict:
        """
        Sum the 1s-aggregated COB price buckets over a recent window.

        Args:
            symbol: Trading symbol.
            timeframe_seconds: Window length. One aggregated data point is
                requested per second, so this is also the number of points.

        Returns:
            Dict with 'symbol', 'timeframe_seconds', 'bid_buckets' and
            'ask_buckets' (price -> summed volume), plus an ISO 'timestamp'.
            Returns an empty dict when there is no data, when the window is
            not positive, or on error.
        """
        try:
            # A non-positive window covers no data points; don't pass a zero or
            # negative count down to the adapter.
            if timeframe_seconds <= 0:
                return {}

            # Get aggregated data for the timeframe (1 second per data point)
            count = timeframe_seconds
            aggregated_data = self.get_cob_1s_aggregated(symbol, count)

            if not aggregated_data:
                return {}

            # Combine all buckets by summing volume per price level
            combined_bid_buckets = {}
            combined_ask_buckets = {}

            for data_point in aggregated_data:
                for price, volume in data_point.get('bid_buckets', {}).items():
                    combined_bid_buckets[price] = combined_bid_buckets.get(price, 0) + volume

                for price, volume in data_point.get('ask_buckets', {}).items():
                    combined_ask_buckets[price] = combined_ask_buckets.get(price, 0) + volume

            return {
                'symbol': symbol,
                'timeframe_seconds': timeframe_seconds,
                'bid_buckets': combined_bid_buckets,
                'ask_buckets': combined_ask_buckets,
                # NOTE: naive UTC timestamp (utcnow is deprecated since 3.12);
                # kept to preserve the existing string format.
                'timestamp': datetime.utcnow().isoformat()
            }

        except Exception as e:
            logger.error(f"Error getting price buckets for {symbol}: {e}")
            return {}

    def get_cob_websocket_status(self) -> Dict[str, Any]:
        """
        Summarize exchange connector status from the adapter's system metadata.

        Returns:
            Dict with 'connected' (True if any connector value is truthy),
            'exchanges' (the raw connector mapping), 'last_update' and 'mode'.
            On error, returns {'connected': False, 'error': <message>}.
        """
        try:
            system_metadata = self.adapter.get_system_metadata()
            connectors = system_metadata.get('components', {}).get('connectors', {})

            return {
                'connected': any(connectors.values()),
                'exchanges': connectors,
                'last_update': datetime.utcnow().isoformat(),
                'mode': system_metadata.get('mode', 'unknown')
            }

        except Exception as e:
            logger.error(f"Error getting WebSocket status: {e}")
            return {'connected': False, 'error': str(e)}

    # === SUBSCRIPTION METHODS ===

    def subscribe_to_ticks(self, callback: Callable[[MarketTick], None],
                           symbols: Optional[List[str]] = None,
                           subscriber_name: Optional[str] = None) -> str:
        """Subscribe to tick data updates; returns the adapter's subscriber id."""
        return self.adapter.subscribe_to_ticks(callback, symbols, subscriber_name)

    def subscribe_to_cob_raw_ticks(self, callback: Callable[[str, Dict], None]) -> str:
        """Subscribe to raw COB tick updates; returns the adapter's subscriber id."""
        return self.adapter.subscribe_to_cob_raw_ticks(callback)

    def subscribe_to_cob_aggregated(self, callback: Callable[[str, Dict], None]) -> str:
        """Subscribe to 1s aggregated COB updates; returns the adapter's subscriber id."""
        return self.adapter.subscribe_to_cob_aggregated(callback)

    def subscribe_to_training_data(self, callback: Callable[[str, dict], None]) -> str:
        """Subscribe to training data updates; returns the adapter's subscriber id."""
        return self.adapter.subscribe_to_training_data(callback)

    def subscribe_to_model_predictions(self, callback: Callable[[str, dict], None]) -> str:
        """Subscribe to model prediction updates; returns the adapter's subscriber id."""
        return self.adapter.subscribe_to_model_predictions(callback)

    def unsubscribe(self, subscriber_id: str) -> bool:
        """Unsubscribe a subscriber id from data feeds (delegates to the adapter)."""
        return self.adapter.unsubscribe(subscriber_id)

    # === MODE SWITCHING ===

    async def switch_to_live_mode(self) -> bool:
        """Switch to live data mode (awaits the adapter)."""
        return await self.adapter.switch_to_live_mode()

    async def switch_to_replay_mode(self, start_time: datetime, end_time: datetime,
                                    speed: float = 1.0,
                                    symbols: Optional[List[str]] = None) -> bool:
        """Switch to replay data mode for [start_time, end_time] (awaits the adapter)."""
        return await self.adapter.switch_to_replay_mode(start_time, end_time, speed, symbols)

    def get_current_mode(self) -> str:
        """Get the current data mode (delegates to the adapter)."""
        return self.adapter.get_current_mode()

    def get_replay_status(self) -> Optional[Dict[str, Any]]:
        """Get replay session status (delegates to the adapter)."""
        return self.adapter.get_replay_status()

    # === COMPATIBILITY METHODS ===

    def start_centralized_data_collection(self) -> None:
        """Start centralized data collection (delegates to the adapter)."""
        self.adapter.start_centralized_data_collection()

    def start_training_data_collection(self) -> None:
        """Start training data collection (delegates to the adapter)."""
        self.adapter.start_training_data_collection()

    def invalidate_ohlcv_cache(self, symbol: str) -> None:
        """Invalidate the OHLCV cache for a symbol (delegates to the adapter)."""
        self.adapter.invalidate_ohlcv_cache(symbol)

    def get_latest_candles(self, symbol: str, timeframe: str, limit: int = 100) -> pd.DataFrame:
        """
        Get the latest candles from cached data.

        Returns:
            The DataFrame from get_historical_data, or an empty DataFrame
            when no data is available.
        """
        df = self.get_historical_data(symbol, timeframe, limit)
        # A DataFrame has no unambiguous truth value: `df or pd.DataFrame()`
        # raises ValueError, so test for None explicitly.
        return df if df is not None else pd.DataFrame()

    def get_price_at_index(self, symbol: str, index: int, timeframe: str = '1m') -> Optional[float]:
        """
        Get the close price ``index`` candles back from the most recent one.

        Args:
            symbol: Trading symbol.
            index: 0 for the latest candle, 1 for the one before, and so on.
            timeframe: Candle timeframe.

        Returns:
            The close price, or None if the index is negative, not enough data
            is available, or an error occurs.
        """
        try:
            # Negative offsets have no meaning here. They would also read from
            # the oldest end of the frame and could request a non-positive limit.
            if index < 0:
                return None
            df = self.get_historical_data(symbol, timeframe, limit=index + 10)
            if df is not None and len(df) > index:
                return float(df.iloc[-(index + 1)]['close'])
            return None
        except Exception as e:
            logger.error(f"Error getting price at index {index} for {symbol}: {e}")
            return None

    # === PIVOT AND MARKET STRUCTURE (MOCK IMPLEMENTATIONS) ===

    def get_pivot_bounds(self, symbol: str) -> Optional[PivotBounds]:
        """
        Build simple pivot bounds from the last 1000 one-minute candles (mock).

        Support levels sit at 10%-40% of the high/low range above the minimum,
        and resistance levels at 10%-40% of the range below the maximum.

        Returns:
            A PivotBounds instance, or None when there is no data or on error.
            NOTE(review): the period start/end use ``df.index[...].to_pydatetime()``,
            which assumes a DatetimeIndex -- confirm with the adapter.
        """
        try:
            # Get recent price data
            df = self.get_historical_data(symbol, '1m', limit=1000)
            if df is None or df.empty:
                return None

            # Calculate basic pivot levels
            high_prices = df['high'].values
            low_prices = df['low'].values
            volumes = df['volume'].values

            price_max = float(np.max(high_prices))
            price_min = float(np.min(low_prices))
            volume_max = float(np.max(volumes))
            volume_min = float(np.min(volumes))

            # Simple support/resistance calculation
            price_range = price_max - price_min
            support_levels = [price_min + i * price_range / 10 for i in range(1, 5)]
            resistance_levels = [price_max - i * price_range / 10 for i in range(1, 5)]

            return PivotBounds(
                symbol=symbol,
                price_max=price_max,
                price_min=price_min,
                volume_max=volume_max,
                volume_min=volume_min,
                pivot_support_levels=support_levels,
                pivot_resistance_levels=resistance_levels,
                pivot_context={'method': 'simple'},
                created_timestamp=datetime.utcnow(),
                data_period_start=df.index[0].to_pydatetime(),
                data_period_end=df.index[-1].to_pydatetime(),
                total_candles_analyzed=len(df)
            )

        except Exception as e:
            logger.error(f"Error getting pivot bounds for {symbol}: {e}")
            return None

    def get_pivot_normalized_features(self, symbol: str, df: pd.DataFrame) -> Optional[pd.DataFrame]:
        """
        Return a copy of ``df`` with close/high/low normalized to the pivot range.

        Adds 'normalized_close', 'normalized_high' and 'normalized_low' computed
        as (value - price_min) / price_range. The input ``df`` is returned
        unchanged if no pivot bounds are available or on error. The copy is
        returned without the new columns if the range is not positive.
        """
        try:
            pivot_bounds = self.get_pivot_bounds(symbol)
            if not pivot_bounds:
                return df

            # Work on a copy so the caller's frame is not mutated
            df_copy = df.copy()
            price_range = pivot_bounds.get_price_range()

            if price_range > 0:
                df_copy['normalized_close'] = (df_copy['close'] - pivot_bounds.price_min) / price_range
                df_copy['normalized_high'] = (df_copy['high'] - pivot_bounds.price_min) / price_range
                df_copy['normalized_low'] = (df_copy['low'] - pivot_bounds.price_min) / price_range

            return df_copy

        except Exception as e:
            logger.error(f"Error getting pivot normalized features for {symbol}: {e}")
            return df

    # === FEATURE EXTRACTION METHODS ===

    def get_feature_matrix(self, symbol: str, timeframes: Optional[List[str]] = None,
                           window_size: int = 20) -> Optional[np.ndarray]:
        """
        Build a flat float32 feature vector of z-scored closes and volumes.

        For each timeframe with at least ``window_size`` candles, the last
        ``window_size`` closes and then the last ``window_size`` volumes are
        appended, each z-score normalized within the window. Timeframes with
        too little data are skipped, so the output length can vary.

        Args:
            symbol: Trading symbol.
            timeframes: Timeframes to include. Defaults to ['1m', '5m', '15m'].
            window_size: Number of trailing candles per timeframe. Must be positive.

        Returns:
            A 1-D float32 array, or None if no timeframe had enough data,
            ``window_size`` is not positive, or an error occurs.
        """
        try:
            # values[-0:] would select the whole array, so reject empty windows
            if window_size <= 0:
                return None

            if not timeframes:
                timeframes = ['1m', '5m', '15m']

            features = []

            for timeframe in timeframes:
                df = self.get_historical_data(symbol, timeframe, limit=window_size + 10)
                if df is not None and len(df) >= window_size:
                    # Extract basic features
                    closes = df['close'].values[-window_size:]
                    volumes = df['volume'].values[-window_size:]

                    # Normalize features; epsilon guards against zero std
                    close_mean = np.mean(closes)
                    close_std = np.std(closes) + 1e-8
                    normalized_closes = (closes - close_mean) / close_std

                    volume_mean = np.mean(volumes)
                    volume_std = np.std(volumes) + 1e-8
                    normalized_volumes = (volumes - volume_mean) / volume_std

                    features.extend(normalized_closes)
                    features.extend(normalized_volumes)

            if features:
                return np.array(features, dtype=np.float32)

            return None

        except Exception as e:
            logger.error(f"Error getting feature matrix for {symbol}: {e}")
            return None

    # === SYSTEM STATUS AND STATISTICS ===

    def get_cached_data_summary(self) -> Dict[str, Any]:
        """Summarize mode, statistics, component health and subscribers from system metadata."""
        try:
            system_metadata = self.adapter.get_system_metadata()
            return {
                'system': 'COBY',
                'mode': system_metadata.get('mode'),
                'statistics': system_metadata.get('statistics', {}),
                'components_healthy': system_metadata.get('components', {}),
                'active_subscribers': system_metadata.get('active_subscribers', 0)
            }
        except Exception as e:
            logger.error(f"Error getting cached data summary: {e}")
            return {'error': str(e)}

    def get_cob_data_quality(self) -> Dict[str, Any]:
        """Map each configured symbol to the adapter's data quality indicators."""
        try:
            quality_info = {}

            for symbol in self.symbols:
                quality_info[symbol] = self.adapter.get_data_quality_indicators(symbol)

            return quality_info

        except Exception as e:
            logger.error(f"Error getting COB data quality: {e}")
            return {'error': str(e)}

    def get_subscriber_stats(self) -> Dict[str, Any]:
        """Get subscriber statistics (delegates to the adapter's get_stats)."""
        return self.adapter.get_stats()

    # === CLEANUP ===

    async def close(self) -> None:
        """Close all connections and clean up (awaits the adapter)."""
        await self.adapter.close()

    def __del__(self):
        """
        Best-effort cleanup when the provider is garbage-collected.

        Runs ``close()`` synchronously with ``asyncio.run``. This is skipped
        when ``__init__`` never created the adapter. It is also skipped when
        an event loop is already running in this thread, because
        ``asyncio.run`` would raise there and leave the ``close()`` coroutine
        un-awaited. Async callers should ``await provider.close()`` explicitly.
        """
        try:
            if getattr(self, 'adapter', None) is None or self._has_running_loop():
                return
            asyncio.run(self.close())
        except Exception:
            # Finalizers must not propagate errors; cleanup is best-effort.
            pass