"""
|
|
Universal Data Adapter for Trading Models
|
|
|
|
This adapter ensures all models receive data in our universal format:
|
|
- ETH/USDT: ticks (1s), 1m, 1h, 1d
|
|
- BTC/USDT: ticks (1s) as reference
|
|
|
|
This is the standard input format that all models must respect.
|
|
"""
|
|
|
|
import logging
|
|
import numpy as np
|
|
import pandas as pd
|
|
from datetime import datetime, timedelta
|
|
from typing import Dict, List, Optional, Tuple, Any
|
|
from dataclasses import dataclass
|
|
|
|
from .config import get_config
|
|
from .data_provider import DataProvider
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
@dataclass
|
|
class UniversalDataStream:
|
|
"""Universal data stream containing the 5 required timeseries"""
|
|
eth_ticks: np.ndarray # ETH/USDT 1s/ticks data [timestamp, open, high, low, close, volume]
|
|
eth_1m: np.ndarray # ETH/USDT 1m data
|
|
eth_1h: np.ndarray # ETH/USDT 1h data
|
|
eth_1d: np.ndarray # ETH/USDT 1d data
|
|
btc_ticks: np.ndarray # BTC/USDT 1s/ticks reference data
|
|
timestamp: datetime # Current timestamp
|
|
metadata: Dict[str, Any] # Additional metadata
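
# Illustrative layout note: _get_timeframe_data() below prepends a unix
# timestamp to the OHLCV columns, so each row in the arrays above is
# [timestamp, open, high, low, close, volume]; for example, eth_1m built from
# a 60-minute window has shape (60, 6).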


class UniversalDataAdapter:
    """
    Adapter that converts any data source into our universal 5-timeseries format
    """

    def __init__(self, data_provider: DataProvider = None):
        """Initialize the universal data adapter"""
        self.config = get_config()
        self.data_provider = data_provider or DataProvider()

        # Universal format configuration
        self.required_symbols = ['ETH/USDT', 'BTC/USDT']
        self.required_timeframes = {
            'ETH/USDT': ['1s', '1m', '1h', '1d'],  # Primary trading pair
            'BTC/USDT': ['1s']                     # Reference pair
        }

        # Data window sizes for each timeframe
        self.window_sizes = {
            '1s': 60,   # Last 60 seconds of tick data
            '1m': 60,   # Last 60 minutes
            '1h': 24,   # Last 24 hours
            '1d': 30    # Last 30 days
        }

        # Feature columns (OHLCV)
        self.feature_columns = ['open', 'high', 'low', 'close', 'volume']

        logger.info("Universal Data Adapter initialized")
        logger.info(f"Required symbols: {self.required_symbols}")
        logger.info(f"Required timeframes: {self.required_timeframes}")

    def get_universal_data_stream(self, current_time: datetime = None) -> Optional[UniversalDataStream]:
        """
        Get data in universal format for all models

        Returns:
            UniversalDataStream with the 5 required timeseries
        """
        try:
            current_time = current_time or datetime.now()

            # Get ETH/USDT data for all required timeframes
            eth_data = {}
            for timeframe in self.required_timeframes['ETH/USDT']:
                data = self._get_timeframe_data('ETH/USDT', timeframe)
                if data is not None:
                    eth_data[timeframe] = data
                else:
                    logger.warning(f"Failed to get ETH/USDT {timeframe} data")
                    return None

            # Get BTC/USDT reference data
            btc_data = self._get_timeframe_data('BTC/USDT', '1s')
            if btc_data is None:
                logger.warning("Failed to get BTC/USDT reference data")
                return None

            # Create universal data stream
            stream = UniversalDataStream(
                eth_ticks=eth_data['1s'],
                eth_1m=eth_data['1m'],
                eth_1h=eth_data['1h'],
                eth_1d=eth_data['1d'],
                btc_ticks=btc_data,
                timestamp=current_time,
                metadata={
                    'data_quality': self._assess_data_quality(eth_data, btc_data),
                    'market_hours': self._is_market_hours(current_time),
                    'data_freshness': self._calculate_data_freshness(eth_data, btc_data, current_time)
                }
            )

            logger.debug(f"Universal data stream created with {len(stream.eth_ticks)} ETH ticks, "
                         f"{len(stream.eth_1m)} ETH 1m candles, {len(stream.btc_ticks)} BTC ticks")

            return stream

        except Exception as e:
            logger.error(f"Error creating universal data stream: {e}")
            return None

    def _get_timeframe_data(self, symbol: str, timeframe: str) -> Optional[np.ndarray]:
        """Get data for a specific symbol and timeframe"""
        try:
            window_size = self.window_sizes.get(timeframe, 60)

            # Get historical data from data provider
            df = self.data_provider.get_historical_data(
                symbol=symbol,
                timeframe=timeframe,
                limit=window_size
            )

            if df is None or df.empty:
                logger.warning(f"No data returned for {symbol} {timeframe}")
                return None

            # Ensure we have the required columns
            missing_cols = [col for col in self.feature_columns if col not in df.columns]
            if missing_cols:
                logger.warning(f"Missing columns for {symbol} {timeframe}: {missing_cols}")
                return None

            # Convert to numpy array with timestamp
            data_array = df[self.feature_columns].values.astype(np.float32)

            # Add timestamp column if available
            if 'timestamp' in df.columns:
                timestamps = pd.to_datetime(df['timestamp']).astype(np.int64) // 10**9  # Unix timestamp (seconds)
                data_with_time = np.column_stack([timestamps, data_array])
            else:
                # Generate timestamps if not available
                end_time = datetime.now()
                if timeframe == '1s':
                    timestamps = [(end_time - timedelta(seconds=i)).timestamp() for i in range(len(data_array) - 1, -1, -1)]
                elif timeframe == '1m':
                    timestamps = [(end_time - timedelta(minutes=i)).timestamp() for i in range(len(data_array) - 1, -1, -1)]
                elif timeframe == '1h':
                    timestamps = [(end_time - timedelta(hours=i)).timestamp() for i in range(len(data_array) - 1, -1, -1)]
                elif timeframe == '1d':
                    timestamps = [(end_time - timedelta(days=i)).timestamp() for i in range(len(data_array) - 1, -1, -1)]
                else:
                    timestamps = [end_time.timestamp()] * len(data_array)

                data_with_time = np.column_stack([timestamps, data_array])

            return data_with_time

        except Exception as e:
            logger.error(f"Error getting {symbol} {timeframe} data: {e}")
            return None

    def _assess_data_quality(self, eth_data: Dict[str, np.ndarray], btc_data: np.ndarray) -> Dict[str, Any]:
        """Assess the quality of the data streams"""
        quality = {
            'overall_score': 1.0,
            'issues': []
        }

        try:
            # Check ETH data completeness
            for timeframe, data in eth_data.items():
                expected_size = self.window_sizes.get(timeframe, 60)
                actual_size = len(data) if data is not None else 0

                if actual_size < expected_size * 0.8:  # Less than 80% of expected data
                    quality['issues'].append(f"ETH {timeframe} data incomplete: {actual_size}/{expected_size}")
                    quality['overall_score'] *= 0.9

            # Check BTC reference data
            btc_expected = self.window_sizes.get('1s', 60)
            btc_actual = len(btc_data) if btc_data is not None else 0

            if btc_actual < btc_expected * 0.8:
                quality['issues'].append(f"BTC reference data incomplete: {btc_actual}/{btc_expected}")
                quality['overall_score'] *= 0.9

            # Check for data gaps or anomalies
            for timeframe, data in eth_data.items():
                if data is not None and len(data) > 1:
                    # Check for price anomalies (sudden jumps > 10%)
                    prices = data[:, 4]  # Close prices (column 0 is the timestamp)
                    price_changes = np.abs(np.diff(prices) / prices[:-1])
                    if np.any(price_changes > 0.1):
                        quality['issues'].append(f"ETH {timeframe} has price anomalies")
                        quality['overall_score'] *= 0.95

        except Exception as e:
            logger.error(f"Error assessing data quality: {e}")
            quality['issues'].append(f"Quality assessment error: {e}")
            quality['overall_score'] *= 0.8

        return quality

    def _is_market_hours(self, timestamp: datetime) -> bool:
        """Check if it's market hours (crypto markets are 24/7)"""
        return True  # Crypto markets are always open

    def _calculate_data_freshness(self, eth_data: Dict[str, np.ndarray], btc_data: np.ndarray,
                                  current_time: datetime) -> Dict[str, float]:
        """Calculate how fresh the data is"""
        freshness = {}

        try:
            current_timestamp = current_time.timestamp()

            # Check ETH data freshness
            for timeframe, data in eth_data.items():
                if data is not None and len(data) > 0:
                    latest_timestamp = data[-1, 0]  # First column is timestamp
                    age_seconds = current_timestamp - latest_timestamp

                    # Convert to appropriate units
                    if timeframe == '1s':
                        freshness[f'eth_{timeframe}'] = age_seconds          # Seconds
                    elif timeframe == '1m':
                        freshness[f'eth_{timeframe}'] = age_seconds / 60     # Minutes
                    elif timeframe == '1h':
                        freshness[f'eth_{timeframe}'] = age_seconds / 3600   # Hours
                    elif timeframe == '1d':
                        freshness[f'eth_{timeframe}'] = age_seconds / 86400  # Days
                else:
                    freshness[f'eth_{timeframe}'] = float('inf')

            # Check BTC data freshness
            if btc_data is not None and len(btc_data) > 0:
                btc_latest = btc_data[-1, 0]
                btc_age = current_timestamp - btc_latest
                freshness['btc_1s'] = btc_age  # Seconds
            else:
                freshness['btc_1s'] = float('inf')

        except Exception as e:
            logger.error(f"Error calculating data freshness: {e}")
            freshness['error'] = str(e)

        return freshness

    def format_for_model(self, stream: UniversalDataStream, model_type: str = 'cnn') -> Dict[str, np.ndarray]:
        """
        Format universal data stream for specific model types

        Args:
            stream: Universal data stream
            model_type: Type of model ('cnn', 'rl', 'transformer', etc.)

        Returns:
            Dictionary with formatted data for the model
        """
        try:
            if model_type.lower() == 'cnn':
                return self._format_for_cnn(stream)
            elif model_type.lower() == 'rl':
                return self._format_for_rl(stream)
            elif model_type.lower() == 'transformer':
                return self._format_for_transformer(stream)
            else:
                # Default format - return raw arrays
                return {
                    'eth_ticks': stream.eth_ticks,
                    'eth_1m': stream.eth_1m,
                    'eth_1h': stream.eth_1h,
                    'eth_1d': stream.eth_1d,
                    'btc_ticks': stream.btc_ticks,
                    'metadata': stream.metadata
                }

        except Exception as e:
            logger.error(f"Error formatting data for {model_type}: {e}")
            return {}

    def _format_for_cnn(self, stream: UniversalDataStream) -> Dict[str, np.ndarray]:
        """Format data for CNN models"""
        # CNN expects [batch, sequence, features] format
        formatted = {}

        # Remove timestamp column and keep only OHLCV
        formatted['eth_ticks'] = stream.eth_ticks[:, 1:] if stream.eth_ticks.shape[1] > 5 else stream.eth_ticks
        formatted['eth_1m'] = stream.eth_1m[:, 1:] if stream.eth_1m.shape[1] > 5 else stream.eth_1m
        formatted['eth_1h'] = stream.eth_1h[:, 1:] if stream.eth_1h.shape[1] > 5 else stream.eth_1h
        formatted['eth_1d'] = stream.eth_1d[:, 1:] if stream.eth_1d.shape[1] > 5 else stream.eth_1d
        formatted['btc_ticks'] = stream.btc_ticks[:, 1:] if stream.btc_ticks.shape[1] > 5 else stream.btc_ticks

        return formatted

    def _format_for_rl(self, stream: UniversalDataStream) -> Dict[str, np.ndarray]:
        """Format data for RL models"""
        # RL typically expects a flattened state vector
        state_components = []

        # Add latest values from each timeframe
        if len(stream.eth_ticks) > 0:
            state_components.extend(stream.eth_ticks[-1, 1:])  # Latest ETH tick (OHLCV)

        if len(stream.eth_1m) > 0:
            state_components.extend(stream.eth_1m[-1, 1:])  # Latest ETH 1m (OHLCV)

        if len(stream.eth_1h) > 0:
            state_components.extend(stream.eth_1h[-1, 1:])  # Latest ETH 1h (OHLCV)

        if len(stream.eth_1d) > 0:
            state_components.extend(stream.eth_1d[-1, 1:])  # Latest ETH 1d (OHLCV)

        if len(stream.btc_ticks) > 0:
            state_components.extend(stream.btc_ticks[-1, 1:])  # Latest BTC tick (OHLCV)

        # Add some derived features
        if len(stream.eth_ticks) > 1:
            # Price momentum
            eth_momentum = (stream.eth_ticks[-1, 4] - stream.eth_ticks[-2, 4]) / stream.eth_ticks[-2, 4]
            state_components.append(eth_momentum)

        if len(stream.btc_ticks) > 1:
            # BTC momentum for correlation
            btc_momentum = (stream.btc_ticks[-1, 4] - stream.btc_ticks[-2, 4]) / stream.btc_ticks[-2, 4]
            state_components.append(btc_momentum)
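
        # Note: with all five streams populated (each row carrying a leading
        # timestamp column), the state vector holds 5 * 5 OHLCV values plus
        # up to 2 momentum features, i.e. 27 entries in total.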
        return {'state_vector': np.array(state_components, dtype=np.float32)}

    def _format_for_transformer(self, stream: UniversalDataStream) -> Dict[str, np.ndarray]:
        """Format data for Transformer models"""
        # Transformers expect sequence data with attention
        formatted = {}

        # Keep timestamp for positional encoding
        formatted['eth_ticks'] = stream.eth_ticks
        formatted['eth_1m'] = stream.eth_1m
        formatted['eth_1h'] = stream.eth_1h
        formatted['eth_1d'] = stream.eth_1d
        formatted['btc_ticks'] = stream.btc_ticks

        # Add sequence length information
        formatted['sequence_lengths'] = {
            'eth_ticks': len(stream.eth_ticks),
            'eth_1m': len(stream.eth_1m),
            'eth_1h': len(stream.eth_1h),
            'eth_1d': len(stream.eth_1d),
            'btc_ticks': len(stream.btc_ticks)
        }

        return formatted

    def validate_universal_format(self, stream: UniversalDataStream) -> Tuple[bool, List[str]]:
        """
        Validate that the data stream conforms to our universal format

        Returns:
            (is_valid, list_of_issues)
        """
        issues = []

        try:
            # Check that all required arrays are present and not None
            required_arrays = ['eth_ticks', 'eth_1m', 'eth_1h', 'eth_1d', 'btc_ticks']
            for array_name in required_arrays:
                array = getattr(stream, array_name)
                if array is None:
                    issues.append(f"{array_name} is None")
                elif len(array) == 0:
                    issues.append(f"{array_name} is empty")
                elif array.shape[1] < 5:  # Should have at least OHLCV
                    issues.append(f"{array_name} has insufficient columns: {array.shape[1]} < 5")

            # Check timestamp
            if stream.timestamp is None:
                issues.append("timestamp is None")

            # Check data consistency (more tolerant for cached data)
            if stream.eth_ticks is not None and len(stream.eth_ticks) > 0:
                if stream.btc_ticks is not None and len(stream.btc_ticks) > 0:
                    # Check if timestamps are roughly aligned (more tolerant for cached data)
                    eth_latest = stream.eth_ticks[-1, 0] if stream.eth_ticks.shape[1] > 5 else 0
                    btc_latest = stream.btc_ticks[-1, 0] if stream.btc_ticks.shape[1] > 5 else 0

                    # Be more tolerant - allow up to 1 hour difference for cached data
                    max_time_diff = 3600  # 1 hour instead of 5 minutes
                    time_diff = abs(eth_latest - btc_latest)

                    if time_diff > max_time_diff:
                        # This is a warning, not a failure, for cached data
                        issues.append(f"ETH and BTC timestamps far apart: {time_diff} seconds (using cached data)")
                        logger.warning(f"Timestamp difference detected: {time_diff} seconds - this is normal for cached data")

            # Check data quality from metadata
            if 'data_quality' in stream.metadata:
                quality_score = stream.metadata['data_quality'].get('overall_score', 0)
                if quality_score < 0.5:  # Very low quality
                    issues.append(f"Data quality too low: {quality_score:.2f}")

        except Exception as e:
            issues.append(f"Validation error: {e}")

        # For cached data we are more lenient - only fail on critical issues
        critical_issues = [issue for issue in issues if not ('timestamps far apart' in issue and 'cached data' in issue)]
        is_valid = len(critical_issues) == 0

        return is_valid, issues
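

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only). It assumes DataProvider can be
    # constructed without arguments and can serve the symbols/timeframes
    # configured above; because this module uses relative imports, run it as
    # part of its package (python -m <package>.<module>).
    logging.basicConfig(level=logging.INFO)

    adapter = UniversalDataAdapter()
    stream = adapter.get_universal_data_stream()

    if stream is None:
        logger.error("Could not build a universal data stream")
    else:
        is_valid, issues = adapter.validate_universal_format(stream)
        logger.info(f"Stream valid: {is_valid}, issues: {issues}")

        # Format the same stream for two model families
        cnn_inputs = adapter.format_for_model(stream, model_type='cnn')
        rl_inputs = adapter.format_for_model(stream, model_type='rl')
        logger.info(f"CNN input keys: {list(cnn_inputs.keys())}")
        logger.info(f"RL state vector shape: {rl_inputs.get('state_vector', np.array([])).shape}")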