# data/data_utils.py

import numpy as np
import torch


def calculate_ema(data, period):
    """Calculates the EMA of the close prices for a list of candles."""
    if len(data) < period:
        return [np.nan] * len(data)  # Not enough data for this period

    close_prices = np.array([candle['close'] for candle in data])
    ema = [close_prices[0]]  # Seed the EMA with the first close price

    multiplier = 2 / (period + 1)
    for i in range(1, len(close_prices)):
        ema_value = (close_prices[i] - ema[-1]) * multiplier + ema[-1]
        ema.append(ema_value)
    return ema
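

# Sketch (not used by preprocess_data): many charting tools seed the EMA with
# the simple moving average of the first `period` closes instead of the first
# close. The variant below is illustrative only; the seeding choice affects
# only the earliest values.
def calculate_ema_sma_seed(data, period):
    """Sketch: EMA seeded with the SMA of the first `period` closes."""
    if len(data) < period:
        return [np.nan] * len(data)
    close_prices = np.array([candle['close'] for candle in data])
    # The EMA is undefined before the seed point; NaN-pad the first slots.
    ema = [np.nan] * (period - 1) + [float(close_prices[:period].mean())]
    multiplier = 2 / (period + 1)
    for price in close_prices[period:]:
        ema.append((price - ema[-1]) * multiplier + ema[-1])
    return ema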


def preprocess_data(candles, ticks, ema_periods=(5, 10, 20, 60, 120, 200)):
    """Preprocesses candles and ticks for the transformer.

    Args:
        candles: List of candle dictionaries (timestamps in milliseconds).
        ticks: List of tick dictionaries.
        ema_periods: Periods for EMA calculation.

    Returns:
        Tuple: (candle_features, tick_features, future_candle, future_volume,
        future_ticks), or (None, None, None, None, None) if there is not
        enough data.
    """
    if not candles or len(candles) < 2:  # Need at least 2 candles: current and future
        return None, None, None, None, None

    # --- Calculate EMAs ---
    # Periods longer than the available history yield all-NaN columns, which
    # callers may need to trim or impute.
    emas = {}
    for period in ema_periods:
        emas[period] = calculate_ema(candles, period)

    # --- Prepare Candle Features ---
    candle_features = []
    for i, candle in enumerate(candles[:-1]):  # Exclude the last candle (used as the future target)
        features = [
            candle['open'],
            candle['high'],
            candle['low'],
            candle['close'],
            candle['volume'],
        ]
        for period in ema_periods:
            features.append(emas[period][i])
        candle_features.append(features)

    # --- Prepare Tick Features (last 30 seconds before the next candle) ---
    last_candle_timestamp = candles[-2]['timestamp']
    thirty_sec_ago = last_candle_timestamp - 30 * 1000
    relevant_ticks = [tick for tick in ticks
                      if thirty_sec_ago < tick['timestamp'] <= last_candle_timestamp]

    tick_features = []
    # Pad or truncate tick data to a fixed length (30 ticks, ~1 tick/second)
    for i in range(30):
        if i < len(relevant_ticks):
            tick_features.extend([relevant_ticks[i]['price'], relevant_ticks[i]['quantity']])
        else:
            tick_features.extend([0.0, 0.0])  # Pad with zeros

    # --- Prepare Future Data (Targets) ---
    future_candle = [
        candles[-1]['open'],
        candles[-1]['high'],
        candles[-1]['low'],
        candles[-1]['close'],
        candles[-1]['volume'],
    ]

    # --- Future Volume (5-min) ---
    # Placeholder: the 5-minute future volume is not known at this point.
    # A hedged sketch of the helper referenced here follows this function.
    future_volume = 0.0
    # future_volume = calculate_volume_for_next_n_minutes(candles, n=5)

    # --- Future Ticks (next 30 seconds, for masking) ---
    next_candle_timestamp = candles[-1]['timestamp']
    future_ticks_end_time = next_candle_timestamp + 30 * 1000
    future_ticks_data = [tick for tick in ticks
                         if next_candle_timestamp < tick['timestamp'] <= future_ticks_end_time]

    future_ticks = []
    for i in range(30):
        if i < len(future_ticks_data):
            future_ticks.extend([future_ticks_data[i]['price'], future_ticks_data[i]['quantity']])
        else:
            future_ticks.extend([0.0, 0.0])

    return (np.array(candle_features, dtype=np.float32),
            np.array(tick_features, dtype=np.float32),
            np.array(future_candle, dtype=np.float32),
            np.array(future_volume, dtype=np.float32),
            np.array(future_ticks, dtype=np.float32))
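

# A minimal sketch of the helper referenced in preprocess_data above. It
# assumes 1-minute candles (so "the next 5 minutes" is the last `n` candles of
# the window) and that the caller's window extends `n` candles past the
# prediction point; both are assumptions, not guarantees of this module.
def calculate_volume_for_next_n_minutes(candles, n=5):
    """Sketch: total volume of the last `n` candles in the window."""
    if len(candles) < n:
        return float('nan')  # Not enough candles to cover the horizon
    return float(sum(candle['volume'] for candle in candles[-n:]))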


def create_mask(seq_len, future_mask=True):
    """Creates an additive attention mask for the input sequence.

    Args:
        seq_len: The length of the sequence.
        future_mask: Whether to mask future tokens (causal attention).

    Returns:
        A float mask tensor of shape (seq_len, seq_len): 0.0 where attention
        is allowed and -inf where it is blocked.
    """
    if not future_mask:
        return torch.zeros(seq_len, seq_len)  # Nothing is masked
    mask = torch.tril(torch.ones(seq_len, seq_len))
    mask = mask.masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, 0.0)
    return mask
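
# Usage note (hedged): the additive mask above plugs into PyTorch attention as
# `attn_mask`, e.g. for a (batch, seq_len, embed_dim) tensor `x`:
#   attn = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
#   out, _ = attn(x, x, x, attn_mask=create_mask(x.size(1)))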


def create_padding_mask(seq, pad_token=0):
    """Creates a boolean padding mask.

    Args:
        seq: Sequence tensor of shape (seq_len, feature_dim).
        pad_token: Padding value, default 0.

    Returns:
        Boolean mask of shape (1, seq_len): True where every feature at a
        position equals pad_token.
    """
    return (seq == pad_token).all(dim=-1).unsqueeze(0)
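
# Usage note (hedged): for a batch of one, this matches the (batch, seq_len)
# boolean shape PyTorch expects for `key_padding_mask`, with True marking
# positions to ignore:
#   out, _ = attn(x, x, x, key_padding_mask=create_padding_mask(x.squeeze(0)))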


# Example usage (within a larger training loop):
if __name__ == '__main__':
    # Dummy data for demonstration
    candles_data = [
        {'timestamp': 1678886400000, 'open': 25000.0, 'high': 25050.0, 'low': 24950.0, 'close': 25025.0, 'volume': 100.0},
        {'timestamp': 1678886460000, 'open': 25025.0, 'high': 25100.0, 'low': 25000.0, 'close': 25075.0, 'volume': 120.0},
        {'timestamp': 1678886520000, 'open': 25075.0, 'high': 25150.0, 'low': 25050.0, 'close': 25125.0, 'volume': 150.0},
        {'timestamp': 1678886580000, 'open': 25125.0, 'high': 25200.0, 'low': 25100.0, 'close': 25175.0, 'volume': 180.0},
        {'timestamp': 1678886640000, 'open': 25175.0, 'high': 25250.0, 'low': 25150.0, 'close': 25225.0, 'volume': 200.0},
    ]
    ticks_data = [
        {'timestamp': 1678886455000, 'symbol': 'BTC/USDT', 'price': 25020.0, 'quantity': 0.1},
        {'timestamp': 1678886458000, 'symbol': 'BTC/USDT', 'price': 25022.0, 'quantity': 0.2},
        {'timestamp': 1678886515000, 'symbol': 'BTC/USDT', 'price': 25070.0, 'quantity': 0.3},
        {'timestamp': 1678886518000, 'symbol': 'BTC/USDT', 'price': 25078.0, 'quantity': 0.1},
        {'timestamp': 1678886575000, 'symbol': 'BTC/USDT', 'price': 25120.0, 'quantity': 0.2},
        {'timestamp': 1678886578000, 'symbol': 'BTC/USDT', 'price': 25122.0, 'quantity': 0.1},
        {'timestamp': 1678886635000, 'symbol': 'BTC/USDT', 'price': 25170.0, 'quantity': 0.4},
        {'timestamp': 1678886638000, 'symbol': 'BTC/USDT', 'price': 25172.0, 'quantity': 0.2},
    ]

    candle_features, tick_features, future_candle, future_volume, future_ticks = preprocess_data(candles_data, ticks_data)

    print("Candle Features:\n", candle_features)
    print("\nTick Features:\n", tick_features)
    print("\nFuture Candle:\n", future_candle)
    print("\nFuture Volume:\n", future_volume)
    print("\nFuture Ticks:\n", future_ticks)

    # Example mask creation
    seq_len = len(candle_features)  # Example sequence length
    mask = create_mask(seq_len)
    print("\nMask:\n", mask)
    padding_mask = create_padding_mask(torch.tensor(candle_features))
    print(f"\nPadding mask: {padding_mask}")