gogo2/core/universal_data_adapter.py
"""
Universal Data Adapter for Trading Models
This adapter ensures all models receive data in our universal format:
- ETH/USDT: ticks (1s), 1m, 1h, 1d
- BTC/USDT: ticks (1s) as reference
This is the standard input format that all models must respect.
"""
import logging
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
from .config import get_config
from .data_provider import DataProvider
logger = logging.getLogger(__name__)


@dataclass
class UniversalDataStream:
"""Universal data stream containing the 5 required timeseries"""
eth_ticks: np.ndarray # ETH/USDT 1s/ticks data [timestamp, open, high, low, close, volume]
eth_1m: np.ndarray # ETH/USDT 1m data
eth_1h: np.ndarray # ETH/USDT 1h data
eth_1d: np.ndarray # ETH/USDT 1d data
btc_ticks: np.ndarray # BTC/USDT 1s/ticks reference data
timestamp: datetime # Current timestamp
metadata: Dict[str, Any] # Additional metadata
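
# Each array above is produced by UniversalDataAdapter._get_timeframe_data, so with the
# default window sizes configured below the stream can be expected to look roughly like
# this (a sketch, not a guarantee -- shorter windows are possible and are flagged by
# _assess_data_quality):
#
#   eth_ticks: 2-D array of shape (60, 6) -> [unix_ts, open, high, low, close, volume]
#   eth_1m:    2-D array of shape (60, 6)
#   eth_1h:    2-D array of shape (24, 6)
#   eth_1d:    2-D array of shape (30, 6)
#   btc_ticks: 2-D array of shape (60, 6)
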
class UniversalDataAdapter:
"""
Adapter that converts any data source into our universal 5-timeseries format
"""
    def __init__(self, data_provider: Optional[DataProvider] = None):
        """Initialize the universal data adapter"""
        self.config = get_config()
        self.data_provider = data_provider or DataProvider()

        # Universal format configuration
        self.required_symbols = ['ETH/USDT', 'BTC/USDT']
        self.required_timeframes = {
            'ETH/USDT': ['1s', '1m', '1h', '1d'],  # Primary trading pair
            'BTC/USDT': ['1s']                     # Reference pair
        }

        # Data window sizes for each timeframe
        self.window_sizes = {
            '1s': 60,  # Last 60 seconds of tick data
            '1m': 60,  # Last 60 minutes
            '1h': 24,  # Last 24 hours
            '1d': 30   # Last 30 days
        }

        # Feature columns (OHLCV)
        self.feature_columns = ['open', 'high', 'low', 'close', 'volume']

        logger.info("Universal Data Adapter initialized")
        logger.info(f"Required symbols: {self.required_symbols}")
        logger.info(f"Required timeframes: {self.required_timeframes}")
    def get_universal_data_stream(self, current_time: Optional[datetime] = None) -> Optional[UniversalDataStream]:
        """
        Get data in universal format for all models

        Returns:
            UniversalDataStream with the 5 required timeseries
        """
        try:
            current_time = current_time or datetime.now()

            # Get ETH/USDT data for all required timeframes
            eth_data = {}
            for timeframe in self.required_timeframes['ETH/USDT']:
                data = self._get_timeframe_data('ETH/USDT', timeframe)
                if data is not None:
                    eth_data[timeframe] = data
                else:
                    logger.warning(f"Failed to get ETH/USDT {timeframe} data")
                    return None

            # Get BTC/USDT reference data
            btc_data = self._get_timeframe_data('BTC/USDT', '1s')
            if btc_data is None:
                logger.warning("Failed to get BTC/USDT reference data")
                return None

            # Create universal data stream
            stream = UniversalDataStream(
                eth_ticks=eth_data['1s'],
                eth_1m=eth_data['1m'],
                eth_1h=eth_data['1h'],
                eth_1d=eth_data['1d'],
                btc_ticks=btc_data,
                timestamp=current_time,
                metadata={
                    'data_quality': self._assess_data_quality(eth_data, btc_data),
                    'market_hours': self._is_market_hours(current_time),
                    'data_freshness': self._calculate_data_freshness(eth_data, btc_data, current_time)
                }
            )

            logger.debug(f"Universal data stream created with {len(stream.eth_ticks)} ETH ticks, "
                         f"{len(stream.eth_1m)} ETH 1m candles, {len(stream.btc_ticks)} BTC ticks")
            return stream

        except Exception as e:
            logger.error(f"Error creating universal data stream: {e}")
            return None
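
    # The metadata attached above can be used by callers to gate trading decisions.
    # A minimal sketch (dict keys match what this method builds; threshold values are
    # illustrative, not part of the adapter):
    #
    #   stream = adapter.get_universal_data_stream()
    #   quality = stream.metadata['data_quality']        # {'overall_score': float, 'issues': [...]}
    #   freshness = stream.metadata['data_freshness']    # e.g. {'eth_1s': 2.0, 'btc_1s': 3.5, ...}
    #   if quality['overall_score'] < 0.8 or freshness.get('eth_1s', float('inf')) > 30:
    #       pass  # skip this cycle or fall back to cached signals
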
    def _get_timeframe_data(self, symbol: str, timeframe: str) -> Optional[np.ndarray]:
        """Get data for a specific symbol and timeframe"""
        try:
            window_size = self.window_sizes.get(timeframe, 60)

            # Get historical data from data provider
            df = self.data_provider.get_historical_data(
                symbol=symbol,
                timeframe=timeframe,
                limit=window_size
            )

            if df is None or df.empty:
                logger.warning(f"No data returned for {symbol} {timeframe}")
                return None

            # Ensure we have the required columns
            missing_cols = [col for col in self.feature_columns if col not in df.columns]
            if missing_cols:
                logger.warning(f"Missing columns for {symbol} {timeframe}: {missing_cols}")
                return None

            # Convert to numpy array with timestamp
            data_array = df[self.feature_columns].values.astype(np.float32)

            # Add timestamp column if available
            if 'timestamp' in df.columns:
                timestamps = pd.to_datetime(df['timestamp']).astype(np.int64) // 10**9  # Unix timestamp
                data_with_time = np.column_stack([timestamps, data_array])
            else:
                # Generate timestamps if not available
                end_time = datetime.now()
                if timeframe == '1s':
                    timestamps = [(end_time - timedelta(seconds=i)).timestamp() for i in range(len(data_array) - 1, -1, -1)]
                elif timeframe == '1m':
                    timestamps = [(end_time - timedelta(minutes=i)).timestamp() for i in range(len(data_array) - 1, -1, -1)]
                elif timeframe == '1h':
                    timestamps = [(end_time - timedelta(hours=i)).timestamp() for i in range(len(data_array) - 1, -1, -1)]
                elif timeframe == '1d':
                    timestamps = [(end_time - timedelta(days=i)).timestamp() for i in range(len(data_array) - 1, -1, -1)]
                else:
                    timestamps = [end_time.timestamp()] * len(data_array)
                data_with_time = np.column_stack([timestamps, data_array])

            return data_with_time

        except Exception as e:
            logger.error(f"Error getting {symbol} {timeframe} data: {e}")
            return None
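
    # Returned layout (assuming the provider yields a full window): a 2-D array whose
    # first column is a Unix timestamp in seconds and whose remaining columns follow
    # self.feature_columns. Illustrative sketch for '1h':
    #
    #   arr = adapter._get_timeframe_data('ETH/USDT', '1h')
    #   arr.shape    # (24, 6) -> [timestamp, open, high, low, close, volume]
    #   arr[-1, 4]   # most recent close price
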
    def _assess_data_quality(self, eth_data: Dict[str, np.ndarray], btc_data: np.ndarray) -> Dict[str, Any]:
        """Assess the quality of the data streams"""
        quality = {
            'overall_score': 1.0,
            'issues': []
        }
        try:
            # Check ETH data completeness
            for timeframe, data in eth_data.items():
                expected_size = self.window_sizes.get(timeframe, 60)
                actual_size = len(data) if data is not None else 0
                if actual_size < expected_size * 0.8:  # Less than 80% of expected data
                    quality['issues'].append(f"ETH {timeframe} data incomplete: {actual_size}/{expected_size}")
                    quality['overall_score'] *= 0.9

            # Check BTC reference data
            btc_expected = self.window_sizes.get('1s', 60)
            btc_actual = len(btc_data) if btc_data is not None else 0
            if btc_actual < btc_expected * 0.8:
                quality['issues'].append(f"BTC reference data incomplete: {btc_actual}/{btc_expected}")
                quality['overall_score'] *= 0.9

            # Check for data gaps or anomalies
            for timeframe, data in eth_data.items():
                if data is not None and len(data) > 1:
                    # Check for price anomalies (sudden jumps > 10%)
                    prices = data[:, 4]  # Close prices
                    price_changes = np.abs(np.diff(prices) / prices[:-1])
                    if np.any(price_changes > 0.1):
                        quality['issues'].append(f"ETH {timeframe} has price anomalies")
                        quality['overall_score'] *= 0.95

        except Exception as e:
            logger.error(f"Error assessing data quality: {e}")
            quality['issues'].append(f"Quality assessment error: {e}")
            quality['overall_score'] *= 0.8

        return quality
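
    # The score is multiplicative, so issues compound. Worked example using the factors
    # above: an incomplete ETH 1m window (x0.9) plus a >10% close-to-close jump on the
    # 1h series (x0.95) yields:
    #
    #   {'overall_score': 1.0 * 0.9 * 0.95,   # = 0.855
    #    'issues': ['ETH 1m data incomplete: 40/60', 'ETH 1h has price anomalies']}
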
    def _is_market_hours(self, timestamp: datetime) -> bool:
        """Check if it's market hours (crypto markets are 24/7)"""
        return True  # Crypto markets are always open

    def _calculate_data_freshness(self, eth_data: Dict[str, np.ndarray], btc_data: np.ndarray,
                                  current_time: datetime) -> Dict[str, float]:
        """Calculate how fresh the data is"""
        freshness = {}
        try:
            current_timestamp = current_time.timestamp()

            # Check ETH data freshness
            for timeframe, data in eth_data.items():
                if data is not None and len(data) > 0:
                    latest_timestamp = data[-1, 0]  # First column is timestamp
                    age_seconds = current_timestamp - latest_timestamp

                    # Convert to appropriate units
                    if timeframe == '1s':
                        freshness[f'eth_{timeframe}'] = age_seconds          # Seconds
                    elif timeframe == '1m':
                        freshness[f'eth_{timeframe}'] = age_seconds / 60     # Minutes
                    elif timeframe == '1h':
                        freshness[f'eth_{timeframe}'] = age_seconds / 3600   # Hours
                    elif timeframe == '1d':
                        freshness[f'eth_{timeframe}'] = age_seconds / 86400  # Days
                else:
                    freshness[f'eth_{timeframe}'] = float('inf')

            # Check BTC data freshness
            if btc_data is not None and len(btc_data) > 0:
                btc_latest = btc_data[-1, 0]
                btc_age = current_timestamp - btc_latest
                freshness['btc_1s'] = btc_age  # Seconds
            else:
                freshness['btc_1s'] = float('inf')

        except Exception as e:
            logger.error(f"Error calculating data freshness: {e}")
            freshness['error'] = str(e)

        return freshness
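
    # Note that each value is expressed in the unit of its own timeframe, not uniformly
    # in seconds. An illustrative result for reasonably fresh data:
    #
    #   {'eth_1s': 1.5,   # seconds behind current_time
    #    'eth_1m': 0.4,   # minutes behind
    #    'eth_1h': 0.2,   # hours behind
    #    'eth_1d': 0.3,   # days behind
    #    'btc_1s': 2.0}   # seconds behind
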
    def format_for_model(self, stream: UniversalDataStream, model_type: str = 'cnn') -> Dict[str, np.ndarray]:
        """
        Format universal data stream for specific model types

        Args:
            stream: Universal data stream
            model_type: Type of model ('cnn', 'rl', 'transformer', etc.)

        Returns:
            Dictionary with formatted data for the model
        """
        try:
            if model_type.lower() == 'cnn':
                return self._format_for_cnn(stream)
            elif model_type.lower() == 'rl':
                return self._format_for_rl(stream)
            elif model_type.lower() == 'transformer':
                return self._format_for_transformer(stream)
            else:
                # Default format - return raw arrays
                return {
                    'eth_ticks': stream.eth_ticks,
                    'eth_1m': stream.eth_1m,
                    'eth_1h': stream.eth_1h,
                    'eth_1d': stream.eth_1d,
                    'btc_ticks': stream.btc_ticks,
                    'metadata': stream.metadata
                }
        except Exception as e:
            logger.error(f"Error formatting data for {model_type}: {e}")
            return {}
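
    # Dispatch sketch (illustrative): the same stream can feed several model families.
    #
    #   cnn_inputs = adapter.format_for_model(stream, model_type='cnn')         # OHLCV sequences
    #   rl_inputs = adapter.format_for_model(stream, model_type='rl')           # {'state_vector': ...}
    #   tf_inputs = adapter.format_for_model(stream, model_type='transformer')  # keeps timestamps
    #   raw = adapter.format_for_model(stream, model_type='other')              # raw arrays + metadata
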
    def _format_for_cnn(self, stream: UniversalDataStream) -> Dict[str, np.ndarray]:
        """Format data for CNN models"""
        # CNN expects [batch, sequence, features] format
        formatted = {}

        # Remove timestamp column and keep only OHLCV
        formatted['eth_ticks'] = stream.eth_ticks[:, 1:] if stream.eth_ticks.shape[1] > 5 else stream.eth_ticks
        formatted['eth_1m'] = stream.eth_1m[:, 1:] if stream.eth_1m.shape[1] > 5 else stream.eth_1m
        formatted['eth_1h'] = stream.eth_1h[:, 1:] if stream.eth_1h.shape[1] > 5 else stream.eth_1h
        formatted['eth_1d'] = stream.eth_1d[:, 1:] if stream.eth_1d.shape[1] > 5 else stream.eth_1d
        formatted['btc_ticks'] = stream.btc_ticks[:, 1:] if stream.btc_ticks.shape[1] > 5 else stream.btc_ticks

        return formatted
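
    # The arrays returned here are [sequence, features]; adding the batch dimension is
    # left to the caller. A minimal sketch, assuming a single-sample batch and a full
    # 60-candle 1m window:
    #
    #   cnn_inputs = adapter.format_for_model(stream, model_type='cnn')
    #   x = np.expand_dims(cnn_inputs['eth_1m'], axis=0)   # shape (1, 60, 5)
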
    def _format_for_rl(self, stream: UniversalDataStream) -> Dict[str, np.ndarray]:
        """Format data for RL models"""
        # RL typically expects flattened state vector
        state_components = []

        # Add latest values from each timeframe
        if len(stream.eth_ticks) > 0:
            state_components.extend(stream.eth_ticks[-1, 1:])  # Latest ETH tick (OHLCV)
        if len(stream.eth_1m) > 0:
            state_components.extend(stream.eth_1m[-1, 1:])     # Latest ETH 1m (OHLCV)
        if len(stream.eth_1h) > 0:
            state_components.extend(stream.eth_1h[-1, 1:])     # Latest ETH 1h (OHLCV)
        if len(stream.eth_1d) > 0:
            state_components.extend(stream.eth_1d[-1, 1:])     # Latest ETH 1d (OHLCV)
        if len(stream.btc_ticks) > 0:
            state_components.extend(stream.btc_ticks[-1, 1:])  # Latest BTC tick (OHLCV)

        # Add some derived features
        if len(stream.eth_ticks) > 1:
            # Price momentum
            eth_momentum = (stream.eth_ticks[-1, 4] - stream.eth_ticks[-2, 4]) / stream.eth_ticks[-2, 4]
            state_components.append(eth_momentum)
        if len(stream.btc_ticks) > 1:
            # BTC momentum for correlation
            btc_momentum = (stream.btc_ticks[-1, 4] - stream.btc_ticks[-2, 4]) / stream.btc_ticks[-2, 4]
            state_components.append(btc_momentum)

        return {'state_vector': np.array(state_components, dtype=np.float32)}
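
    # Resulting state size when every array is populated and has at least 2 rows:
    # 5 sources x 5 OHLCV values = 25 features, plus 2 momentum terms = 27.
    #
    #   rl_inputs = adapter.format_for_model(stream, model_type='rl')
    #   rl_inputs['state_vector'].shape   # (27,) in the fully-populated case
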
    def _format_for_transformer(self, stream: UniversalDataStream) -> Dict[str, np.ndarray]:
        """Format data for Transformer models"""
        # Transformers expect sequence data with attention
        formatted = {}

        # Keep timestamp for positional encoding
        formatted['eth_ticks'] = stream.eth_ticks
        formatted['eth_1m'] = stream.eth_1m
        formatted['eth_1h'] = stream.eth_1h
        formatted['eth_1d'] = stream.eth_1d
        formatted['btc_ticks'] = stream.btc_ticks

        # Add sequence length information
        formatted['sequence_lengths'] = {
            'eth_ticks': len(stream.eth_ticks),
            'eth_1m': len(stream.eth_1m),
            'eth_1h': len(stream.eth_1h),
            'eth_1d': len(stream.eth_1d),
            'btc_ticks': len(stream.btc_ticks)
        }
        return formatted
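
    # Unlike the CNN path, the timestamp column is kept so a caller can derive positional
    # encodings from real time deltas. Illustrative sketch (the relative-position scheme
    # is an assumption, not something the adapter prescribes):
    #
    #   tf_inputs = adapter.format_for_model(stream, model_type='transformer')
    #   ts = tf_inputs['eth_1m'][:, 0]            # Unix timestamps, seconds
    #   rel_pos = (ts - ts[0]) / 60.0             # minutes since window start
    #   tf_inputs['sequence_lengths']['eth_1m']   # e.g. 60
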
    def validate_universal_format(self, stream: UniversalDataStream) -> Tuple[bool, List[str]]:
        """
        Validate that the data stream conforms to our universal format

        Returns:
            (is_valid, list_of_issues)
        """
        issues = []
        try:
            # Check that all required arrays are present and not None
            required_arrays = ['eth_ticks', 'eth_1m', 'eth_1h', 'eth_1d', 'btc_ticks']
            for array_name in required_arrays:
                array = getattr(stream, array_name)
                if array is None:
                    issues.append(f"{array_name} is None")
                elif len(array) == 0:
                    issues.append(f"{array_name} is empty")
                elif array.shape[1] < 5:  # Should have at least OHLCV
                    issues.append(f"{array_name} has insufficient columns: {array.shape[1]} < 5")

            # Check timestamp
            if stream.timestamp is None:
                issues.append("timestamp is None")

            # Check data consistency (more tolerant for cached data)
            if stream.eth_ticks is not None and len(stream.eth_ticks) > 0:
                if stream.btc_ticks is not None and len(stream.btc_ticks) > 0:
                    # Check if timestamps are roughly aligned (more tolerant for cached data)
                    eth_latest = stream.eth_ticks[-1, 0] if stream.eth_ticks.shape[1] > 5 else 0
                    btc_latest = stream.btc_ticks[-1, 0] if stream.btc_ticks.shape[1] > 5 else 0

                    # Be more tolerant - allow up to 1 hour difference for cached data
                    max_time_diff = 3600  # 1 hour instead of 5 minutes
                    time_diff = abs(eth_latest - btc_latest)
                    if time_diff > max_time_diff:
                        # This is a warning, not a failure for cached data
                        issues.append(f"ETH and BTC timestamps far apart: {time_diff} seconds (using cached data)")
                        logger.warning(f"Timestamp difference detected: {time_diff} seconds - this is normal for cached data")

            # Check data quality from metadata
            if 'data_quality' in stream.metadata:
                quality_score = stream.metadata['data_quality'].get('overall_score', 0)
                if quality_score < 0.5:  # Very low quality
                    issues.append(f"Data quality too low: {quality_score:.2f}")

        except Exception as e:
            issues.append(f"Validation error: {e}")

        # For cached data, we're more lenient - only fail on critical issues
        critical_issues = [issue for issue in issues if not ('timestamps far apart' in issue and 'cached data' in issue)]
        is_valid = len(critical_issues) == 0
        return is_valid, issues
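
    # End-to-end sketch of the intended call order (illustrative only):
    #
    #   adapter = UniversalDataAdapter(data_provider)
    #   stream = adapter.get_universal_data_stream()
    #   if stream is not None:
    #       is_valid, issues = adapter.validate_universal_format(stream)
    #       if is_valid:
    #           model_inputs = adapter.format_for_model(stream, model_type='rl')
    #       else:
    #           logger.warning(f"Universal data stream rejected: {issues}")
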