# File: gogo2/core/standardized_data_provider.py
# Last modified: 2025-07-28 08:35:08 +03:00 (671 lines, 28 KiB, Python)
"""
Standardized Data Provider Extension
This module extends the existing DataProvider with standardized BaseDataInput functionality
for all models in the multi-modal trading system.
"""
import logging
from collections import deque
from datetime import datetime, timedelta
from threading import Lock
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

from .data_models import BaseDataInput, OHLCVBar, COBData, ModelOutput, PivotPoint
from .data_provider import DataProvider
from .model_output_manager import ModelOutputManager
from .multi_exchange_cob_provider import MultiExchangeCOBProvider
# Module-level logger, named after this module's dotted import path.
logger = logging.getLogger(__name__)
class StandardizedDataProvider(DataProvider):
    """
    Extended DataProvider with standardized BaseDataInput support.

    Assembles one unified input bundle for all models (see get_base_data_input):
    - OHLCV: up to 300 bars each of 1s, 1m, 1h and 1d data for the requested
      symbol, plus up to 300 bars of 1s BTC/USDT reference data; every series
      must contain at least 100 bars or no input is produced.
    - COB: ±20 price buckets around the current price. NOTE: per-bucket
      volumes are currently mock values, not live order-book amounts.
    - MA: 1s, 5s, 15s and 60s moving averages of COB imbalance for the ±5
      buckets nearest the current price.
    """
def __init__(self, symbols: List[str] = None, timeframes: List[str] = None):
"""Initialize the standardized data provider"""
super().__init__(symbols, timeframes)
# Standardized data storage
self.base_data_cache: Dict[str, BaseDataInput] = {} # {symbol: BaseDataInput}
self.cob_data_cache: Dict[str, COBData] = {} # {symbol: COBData}
# Model output management with extensible storage
self.model_output_manager = ModelOutputManager(
cache_dir=str(self.cache_dir / "model_outputs"),
max_history=1000
)
# COB moving averages calculation
self.cob_imbalance_history: Dict[str, deque] = {} # {symbol: deque of (timestamp, imbalance_data)}
self.ma_calculation_lock = Lock()
# Initialize caches for each symbol
for symbol in self.symbols:
self.base_data_cache[symbol] = None
self.cob_data_cache[symbol] = None
self.cob_imbalance_history[symbol] = deque(maxlen=300) # 5 minutes of 1s data
# Ensure live price cache exists (in case parent didn't initialize it)
if not hasattr(self, 'live_price_cache'):
self.live_price_cache: Dict[str, Tuple[float, datetime]] = {}
if not hasattr(self, 'live_price_cache_ttl'):
from datetime import timedelta
self.live_price_cache_ttl = timedelta(milliseconds=500)
# Initialize WebSocket cache for dashboard compatibility
if not hasattr(self, 'ws_price_cache'):
self.ws_price_cache: Dict[str, float] = {}
# Initialize orchestrator reference (for dashboard compatibility)
self.orchestrator = None
# COB provider integration
self.cob_provider: Optional[MultiExchangeCOBProvider] = None
self._initialize_cob_provider()
logger.info("StandardizedDataProvider initialized with BaseDataInput support")
def _initialize_cob_provider(self):
"""Initialize COB provider for order book data"""
try:
from .multi_exchange_cob_provider import MultiExchangeCOBProvider, ExchangeConfig, ExchangeType
# Configure exchanges (focusing on Binance for now)
exchange_configs = {
'binance': ExchangeConfig(
exchange_type=ExchangeType.BINANCE,
weight=1.0,
enabled=True,
websocket_url="wss://stream.binance.com:9443/ws/",
symbols_mapping={symbol: symbol.replace('/', '').lower() for symbol in self.symbols}
)
}
self.cob_provider = MultiExchangeCOBProvider(self.symbols, exchange_configs)
logger.info("COB provider initialized successfully")
except Exception as e:
logger.warning(f"Failed to initialize COB provider: {e}")
self.cob_provider = None
def get_base_data_input(self, symbol: str, timestamp: Optional[datetime] = None) -> Optional[BaseDataInput]:
"""
Get standardized BaseDataInput for a symbol
Args:
symbol: Trading symbol (e.g., 'ETH/USDT')
timestamp: Optional timestamp, defaults to current time
Returns:
BaseDataInput: Standardized input data for models, or None if insufficient data
"""
if timestamp is None:
timestamp = datetime.now()
try:
# Get OHLCV data for all timeframes
ohlcv_1s = self._get_ohlcv_bars(symbol, '1s', 300)
ohlcv_1m = self._get_ohlcv_bars(symbol, '1m', 300)
ohlcv_1h = self._get_ohlcv_bars(symbol, '1h', 300)
ohlcv_1d = self._get_ohlcv_bars(symbol, '1d', 300)
# Get BTC reference data
btc_symbol = 'BTC/USDT'
btc_ohlcv_1s = self._get_ohlcv_bars(btc_symbol, '1s', 300)
# Check if we have sufficient data
if not all([ohlcv_1s, ohlcv_1m, ohlcv_1h, ohlcv_1d, btc_ohlcv_1s]):
logger.warning(f"Insufficient OHLCV data for {symbol}")
return None
if any(len(data) < 100 for data in [ohlcv_1s, ohlcv_1m, ohlcv_1h, ohlcv_1d, btc_ohlcv_1s]):
logger.warning(f"Insufficient data frames for {symbol}")
return None
# Get COB data
cob_data = self._get_cob_data(symbol, timestamp)
# Get technical indicators
technical_indicators = self._get_technical_indicators(symbol)
# Get pivot points
pivot_points = self._get_pivot_points(symbol)
# Get last predictions from all models
last_predictions = self.model_output_manager.get_all_current_outputs(symbol)
# Create BaseDataInput
base_input = BaseDataInput(
symbol=symbol,
timestamp=timestamp,
ohlcv_1s=ohlcv_1s,
ohlcv_1m=ohlcv_1m,
ohlcv_1h=ohlcv_1h,
ohlcv_1d=ohlcv_1d,
btc_ohlcv_1s=btc_ohlcv_1s,
cob_data=cob_data,
technical_indicators=technical_indicators,
pivot_points=pivot_points,
last_predictions=last_predictions
)
# Validate the input
if not base_input.validate():
logger.warning(f"BaseDataInput validation failed for {symbol}")
return None
# Cache the result
self.base_data_cache[symbol] = base_input
return base_input
except Exception as e:
logger.error(f"Error creating BaseDataInput for {symbol}: {e}")
return None
def _get_ohlcv_bars(self, symbol: str, timeframe: str, count: int) -> List[OHLCVBar]:
"""
Get OHLCV bars for a symbol and timeframe
Args:
symbol: Trading symbol
timeframe: Timeframe ('1s', '1m', '1h', '1d')
count: Number of bars to retrieve
Returns:
List[OHLCVBar]: List of OHLCV bars
"""
try:
# Get historical data from parent class
df = self.get_historical_data(symbol, timeframe, count)
if df is None or df.empty:
return []
# Convert DataFrame to OHLCVBar objects
bars = []
for _, row in df.tail(count).iterrows():
bar = OHLCVBar(
symbol=symbol,
timestamp=row.name if hasattr(row, 'name') else datetime.now(),
open=float(row['open']),
high=float(row['high']),
low=float(row['low']),
close=float(row['close']),
volume=float(row['volume']),
timeframe=timeframe,
indicators={}
)
# Add technical indicators if available
for col in df.columns:
if col not in ['open', 'high', 'low', 'close', 'volume']:
bar.indicators[col] = float(row[col]) if not np.isnan(row[col]) else 0.0
bars.append(bar)
return bars
except Exception as e:
logger.error(f"Error getting OHLCV bars for {symbol} {timeframe}: {e}")
return []
def _get_cob_data(self, symbol: str, timestamp: datetime) -> Optional[COBData]:
"""
Get COB data for a symbol
Args:
symbol: Trading symbol
timestamp: Current timestamp
Returns:
COBData: COB data with price buckets and moving averages
"""
try:
if not self.cob_provider:
return None
# Get current price
current_price = self.current_prices.get(symbol.replace('/', '').upper(), 0.0)
if current_price <= 0:
return None
# Determine bucket size based on symbol
bucket_size = 1.0 if 'ETH' in symbol else 10.0 # $1 for ETH, $10 for BTC
# Calculate price range (±20 buckets)
price_range = 20 * bucket_size
min_price = current_price - price_range
max_price = current_price + price_range
# Create price buckets
price_buckets = {}
bid_ask_imbalance = {}
volume_weighted_prices = {}
# Generate mock COB data for now (will be replaced with real COB provider data)
for i in range(-20, 21):
price = current_price + (i * bucket_size)
if price > 0:
# Mock data - replace with real COB provider data
bid_volume = max(0, 1000 - abs(i) * 50) # More volume near current price
ask_volume = max(0, 1000 - abs(i) * 50)
total_volume = bid_volume + ask_volume
imbalance = (bid_volume - ask_volume) / max(total_volume, 1)
price_buckets[price] = {
'bid_volume': bid_volume,
'ask_volume': ask_volume,
'total_volume': total_volume,
'imbalance': imbalance
}
bid_ask_imbalance[price] = imbalance
volume_weighted_prices[price] = price # Simplified VWAP
# Calculate moving averages of imbalance for ±5 buckets
ma_data = self._calculate_cob_moving_averages(symbol, bid_ask_imbalance, timestamp)
cob_data = COBData(
symbol=symbol,
timestamp=timestamp,
current_price=current_price,
bucket_size=bucket_size,
price_buckets=price_buckets,
bid_ask_imbalance=bid_ask_imbalance,
volume_weighted_prices=volume_weighted_prices,
order_flow_metrics={},
ma_1s_imbalance=ma_data.get('1s', {}),
ma_5s_imbalance=ma_data.get('5s', {}),
ma_15s_imbalance=ma_data.get('15s', {}),
ma_60s_imbalance=ma_data.get('60s', {})
)
# Cache the COB data
self.cob_data_cache[symbol] = cob_data
return cob_data
except Exception as e:
logger.error(f"Error getting COB data for {symbol}: {e}")
return None
def _calculate_cob_moving_averages(self, symbol: str, bid_ask_imbalance: Dict[float, float],
timestamp: datetime) -> Dict[str, Dict[float, float]]:
"""
Calculate moving averages of COB imbalance for ±5 buckets
Args:
symbol: Trading symbol
bid_ask_imbalance: Current bid/ask imbalance data
timestamp: Current timestamp
Returns:
Dict containing MA data for different timeframes
"""
try:
with self.ma_calculation_lock:
# Add current imbalance data to history
self.cob_imbalance_history[symbol].append((timestamp, bid_ask_imbalance))
# Calculate MAs for different timeframes
ma_results = {'1s': {}, '5s': {}, '15s': {}, '60s': {}}
# Get current price for ±5 bucket calculation
current_price = self.current_prices.get(symbol.replace('/', '').upper(), 0.0)
if current_price <= 0:
return ma_results
bucket_size = 1.0 if 'ETH' in symbol else 10.0
# Calculate MAs for ±5 buckets around current price
for i in range(-5, 6):
price = current_price + (i * bucket_size)
if price <= 0:
continue
# Get historical imbalance data for this price bucket
history = self.cob_imbalance_history[symbol]
# Calculate different MA periods
for period, period_name in [(1, '1s'), (5, '5s'), (15, '15s'), (60, '60s')]:
recent_data = []
cutoff_time = timestamp - timedelta(seconds=period)
for hist_timestamp, hist_imbalance in history:
if hist_timestamp >= cutoff_time and price in hist_imbalance:
recent_data.append(hist_imbalance[price])
# Calculate moving average
if recent_data:
ma_results[period_name][price] = sum(recent_data) / len(recent_data)
else:
ma_results[period_name][price] = 0.0
return ma_results
except Exception as e:
logger.error(f"Error calculating COB moving averages for {symbol}: {e}")
return {'1s': {}, '5s': {}, '15s': {}, '60s': {}}
def _get_technical_indicators(self, symbol: str) -> Dict[str, float]:
"""Get technical indicators for a symbol"""
try:
# Get latest OHLCV data
df = self.get_historical_data(symbol, '1h', 100) # Use 1h for indicators
if df is None or df.empty:
return {}
indicators = {}
# Add basic indicators if available in the dataframe
latest_row = df.iloc[-1]
for col in df.columns:
if col not in ['open', 'high', 'low', 'close', 'volume']:
indicators[col] = float(latest_row[col]) if not np.isnan(latest_row[col]) else 0.0
return indicators
except Exception as e:
logger.error(f"Error getting technical indicators for {symbol}: {e}")
return {}
def _get_pivot_points(self, symbol: str) -> List[PivotPoint]:
"""Get pivot points for a symbol"""
try:
pivot_points = []
# Get pivot points from Williams Market Structure if available
if symbol in self.williams_structure:
williams = self.williams_structure[symbol]
# This would need to be implemented based on the actual Williams structure
# For now, return empty list
pass
return pivot_points
except Exception as e:
logger.error(f"Error getting pivot points for {symbol}: {e}")
return []
def store_model_output(self, model_output: ModelOutput):
"""
Store model output for cross-model feeding using ModelOutputManager
Args:
model_output: ModelOutput from any model
"""
try:
success = self.model_output_manager.store_output(model_output)
if success:
logger.debug(f"Stored model output from {model_output.model_name} for {model_output.symbol}")
else:
logger.warning(f"Failed to store model output from {model_output.model_name}")
except Exception as e:
logger.error(f"Error storing model output: {e}")
def get_model_outputs(self, symbol: str) -> Dict[str, ModelOutput]:
"""
Get all model outputs for a symbol using ModelOutputManager
Args:
symbol: Trading symbol
Returns:
Dict[str, ModelOutput]: Dictionary of model outputs by model name
"""
return self.model_output_manager.get_all_current_outputs(symbol)
def get_model_output_manager(self) -> ModelOutputManager:
"""
Get the model output manager for advanced operations
Returns:
ModelOutputManager: The model output manager instance
"""
return self.model_output_manager
def start_real_time_processing(self):
"""Start real-time processing for standardized data"""
try:
# Start parent class real-time processing
if hasattr(super(), 'start_real_time_processing'):
super().start_real_time_processing()
# Start COB provider if available
if self.cob_provider:
import asyncio
asyncio.create_task(self.cob_provider.start_streaming())
logger.info("Started real-time processing for standardized data")
except Exception as e:
logger.error(f"Error starting real-time processing: {e}")
def stop_real_time_processing(self):
"""Stop real-time processing"""
try:
# Stop COB provider if available
if self.cob_provider:
import asyncio
asyncio.create_task(self.cob_provider.stop_streaming())
# Stop parent class processing
if hasattr(super(), 'stop_real_time_processing'):
super().stop_real_time_processing()
logger.info("Stopped real-time processing for standardized data")
except Exception as e:
logger.error(f"Error stopping real-time processing: {e}")
def get_recent_prices(self, symbol: str, limit: int = 10) -> List[float]:
"""
Get recent prices for a symbol
Args:
symbol: Trading symbol
limit: Number of recent prices to return
Returns:
List[float]: List of recent prices
"""
try:
# Get recent OHLCV data using parent class method
df = self.get_historical_data(symbol, '1m', limit)
if df is None or df.empty:
return []
# Extract close prices from DataFrame
if 'close' in df.columns:
prices = df['close'].tolist()
return prices[-limit:] # Return most recent prices
else:
logger.warning(f"No 'close' column found in OHLCV data for {symbol}")
return []
except Exception as e:
logger.error(f"Error getting recent prices for {symbol}: {e}")
return []
def get_live_price_from_api(self, symbol: str) -> Optional[float]:
"""ROBUST live price fetching with comprehensive fallbacks"""
try:
# 1. Check cache first to avoid excessive API calls
if symbol in self.live_price_cache:
price, timestamp = self.live_price_cache[symbol]
if datetime.now() - timestamp < self.live_price_cache_ttl:
logger.debug(f"Using cached price for {symbol}: ${price:.2f}")
return price
# 2. Try direct Binance API call
try:
import requests
binance_symbol = symbol.replace('/', '')
url = f"https://api.binance.com/api/v3/ticker/price?symbol={binance_symbol}"
response = requests.get(url, timeout=0.5) # Use a short timeout for low latency
response.raise_for_status()
data = response.json()
price = float(data['price'])
# Update cache and current prices
self.live_price_cache[symbol] = (price, datetime.now())
self.current_prices[symbol] = price
logger.info(f"LIVE PRICE for {symbol}: ${price:.2f}")
return price
except requests.exceptions.RequestException as e:
logger.warning(f"Failed to get live price for {symbol} from API: {e}")
except Exception as e:
logger.warning(f"Unexpected error in API call for {symbol}: {e}")
# 3. Fallback to current prices from parent
if hasattr(self, 'current_prices') and symbol in self.current_prices:
price = self.current_prices[symbol]
if price and price > 0:
logger.debug(f"Using current price for {symbol}: ${price:.2f}")
return price
# 4. Try parent's get_current_price method
if hasattr(self, 'get_current_price'):
try:
price = self.get_current_price(symbol)
if price and price > 0:
self.current_prices[symbol] = price
logger.debug(f"Got current price for {symbol} from parent: ${price:.2f}")
return price
except Exception as e:
logger.debug(f"Parent get_current_price failed for {symbol}: {e}")
# 5. Try historical data from multiple timeframes
for timeframe in ['1m', '5m', '1h']: # Start with 1m for better reliability
try:
df = self.get_historical_data(symbol, timeframe, limit=1, refresh=True)
if df is not None and not df.empty:
price = float(df['close'].iloc[-1])
if price > 0:
self.current_prices[symbol] = price
logger.debug(f"Got current price for {symbol} from {timeframe}: ${price:.2f}")
return price
except Exception as tf_error:
logger.debug(f"Failed to get {timeframe} data for {symbol}: {tf_error}")
continue
# 6. Try WebSocket cache if available
ws_symbol = symbol.replace('/', '')
if hasattr(self, 'ws_price_cache') and ws_symbol in self.ws_price_cache:
price = self.ws_price_cache[ws_symbol]
if price and price > 0:
logger.debug(f"Using WebSocket cache for {symbol}: ${price:.2f}")
return price
# 7. Try to get from orchestrator if available (for dashboard compatibility)
if hasattr(self, 'orchestrator') and self.orchestrator:
try:
if hasattr(self.orchestrator, 'data_provider'):
price = self.orchestrator.data_provider.get_current_price(symbol)
if price and price > 0:
self.current_prices[symbol] = price
logger.debug(f"Got current price for {symbol} from orchestrator: ${price:.2f}")
return price
except Exception as orch_error:
logger.debug(f"Failed to get price from orchestrator: {orch_error}")
# 8. Last resort: try external API with longer timeout
try:
import requests
binance_symbol = symbol.replace('/', '')
url = f"https://api.binance.com/api/v3/ticker/price?symbol={binance_symbol}"
response = requests.get(url, timeout=2) # Longer timeout for last resort
if response.status_code == 200:
data = response.json()
price = float(data['price'])
if price > 0:
self.current_prices[symbol] = price
logger.warning(f"Got current price for {symbol} from external API (last resort): ${price:.2f}")
return price
except Exception as api_error:
logger.debug(f"External API failed: {api_error}")
logger.warning(f"Could not get current price for {symbol} from any source")
except Exception as e:
logger.error(f"Error getting current price for {symbol}: {e}")
# Return a fallback price if we have any cached data
if hasattr(self, 'current_prices') and symbol in self.current_prices and self.current_prices[symbol] > 0:
return self.current_prices[symbol]
# Return None instead of hardcoded fallbacks - let the caller handle missing data
return None
def get_current_price(self, symbol: str) -> Optional[float]:
"""Get current price with robust fallbacks - enhanced version"""
try:
# 1. Try live price API first (our enhanced method)
price = self.get_live_price_from_api(symbol)
if price and price > 0:
return price
# 2. Try parent's get_current_price method
if hasattr(super(), 'get_current_price'):
try:
price = super().get_current_price(symbol)
if price and price > 0:
return price
except Exception as e:
logger.debug(f"Parent get_current_price failed for {symbol}: {e}")
# 3. Try current prices cache
if hasattr(self, 'current_prices') and symbol in self.current_prices:
price = self.current_prices[symbol]
if price and price > 0:
return price
# 4. Try historical data from multiple timeframes
for timeframe in ['1m', '5m', '1h']:
try:
df = self.get_historical_data(symbol, timeframe, limit=1, refresh=True)
if df is not None and not df.empty:
price = float(df['close'].iloc[-1])
if price > 0:
self.current_prices[symbol] = price
return price
except Exception as tf_error:
logger.debug(f"Failed to get {timeframe} data for {symbol}: {tf_error}")
continue
# 5. Try WebSocket cache if available
ws_symbol = symbol.replace('/', '')
if hasattr(self, 'ws_price_cache') and ws_symbol in self.ws_price_cache:
price = self.ws_price_cache[ws_symbol]
if price and price > 0:
return price
logger.warning(f"Could not get current price for {symbol} from any source")
return None
except Exception as e:
logger.error(f"Error getting current price for {symbol}: {e}")
return None
def update_ws_price_cache(self, symbol: str, price: float):
"""Update WebSocket price cache for dashboard compatibility"""
try:
ws_symbol = symbol.replace('/', '')
self.ws_price_cache[ws_symbol] = price
# Also update current prices for consistency
self.current_prices[symbol] = price
logger.debug(f"Updated WS cache for {symbol}: ${price:.2f}")
except Exception as e:
logger.error(f"Error updating WS cache for {symbol}: {e}")
def set_orchestrator(self, orchestrator):
"""Set orchestrator reference for dashboard compatibility"""
self.orchestrator = orchestrator