gogo2/crypto/gogo/data/data_utils.py
2025-02-12 01:27:38 +02:00

207 lines
7.9 KiB
Python

# data/data_utils.py
import numpy as np
import torch
from collections import deque
def calculate_ema(data, period):
    """Return the exponential moving average of the candles' closing prices.

    The recursion is seeded with the first close and uses the smoothing
    factor ``2 / (period + 1)``. When ``data`` holds fewer than ``period``
    candles, a list of ``NaN`` with the same length is returned instead.

    Args:
        data: Sequence of candle dicts, each with a ``'close'`` key.
        period: EMA period.

    Returns:
        A list with one EMA value per candle.
    """
    count = len(data)
    if count < period:
        # Not enough history for this period: every slot is undefined.
        return [np.nan] * count
    closes = np.array([bar['close'] for bar in data])
    alpha = 2 / (period + 1)
    series = [closes[0]]
    for price in closes[1:]:
        previous = series[-1]
        series.append((price - previous) * alpha + previous)
    return series
def _tick_window_features(ticks, start_ms, end_ms, max_ticks):
    """Flatten ticks with ``start_ms < timestamp <= end_ms`` into a fixed-length list.

    Ticks are taken in input order. The first ``max_ticks`` matches contribute
    ``[price, quantity]`` each, and missing slots are padded with ``0.0``, so
    the result always has ``2 * max_ticks`` entries.
    """
    # NOTE(review): ticks are not sorted here, so when more than max_ticks
    # fall in the window the earliest in *input order* are kept, not
    # necessarily the most recent ones.
    window = [tick for tick in ticks if start_ms < tick['timestamp'] <= end_ms][:max_ticks]
    features = []
    for tick in window:
        features.extend([tick['price'], tick['quantity']])
    features.extend([0.0, 0.0] * (max_ticks - len(window)))
    return features


def preprocess_data(candles, ticks, ema_periods=None, *, tick_window_ms=30 * 1000, max_ticks=30):
    """Preprocess candles and ticks into model inputs and targets.

    All candles except the last become input rows. The last candle is the
    prediction target.

    Args:
        candles: List of candle dicts with 'timestamp', 'open', 'high', 'low',
            'close' and 'volume' keys.
        ticks: List of tick dicts with 'timestamp', 'price' and 'quantity' keys.
        ema_periods: Iterable of EMA periods appended to each candle row.
            Defaults to (5, 10, 20, 60, 120, 200).
        tick_window_ms: Width of each tick window, in the same units as the
            timestamps (default 30 000, i.e. 30 s if timestamps are in ms).
        max_ticks: Number of tick slots per window. Extra ticks are dropped
            and missing ones are zero-padded.

    Returns:
        Tuple of float32 arrays
        ``(candle_features, tick_features, future_candle, future_volume, future_ticks)``:

        - ``candle_features`` has shape ``(len(candles) - 1, 5 + len(ema_periods))``.
        - ``tick_features`` and ``future_ticks`` each have shape ``(2 * max_ticks,)``.
        - ``future_candle`` has shape ``(5,)``.
        - ``future_volume`` is a 0-d placeholder equal to 0.0.

        If ``candles`` is empty/None or has fewer than 2 entries, a tuple of
        five ``None`` values is returned instead.
    """
    if ema_periods is None:
        ema_periods = (5, 10, 20, 60, 120, 200)
    # Need at least one input candle plus one target candle.
    if not candles or len(candles) < 2:
        return None, None, None, None, None

    # The EMA at index i depends only on candles[0..i], so including the
    # target candle in the series does not leak its prices into input rows
    # (calculate_ema's NaN cutoff does count it, though).
    emas = {period: calculate_ema(candles, period) for period in ema_periods}

    candle_features = [
        [candle['open'], candle['high'], candle['low'], candle['close'], candle['volume']]
        + [emas[period][i] for period in ema_periods]
        for i, candle in enumerate(candles[:-1])
    ]

    # NOTE(review): the input window ends at the last *input* candle's
    # timestamp. If timestamps are candle open times, this window precedes
    # that candle rather than the target candle — confirm the convention.
    last_input_ts = candles[-2]['timestamp']
    tick_features = _tick_window_features(
        ticks, last_input_ts - tick_window_ms, last_input_ts, max_ticks
    )

    target = candles[-1]
    future_candle = [
        target['open'],
        target['high'],
        target['low'],
        target['close'],
        target['volume'],
    ]

    # Placeholder: the realized future volume is not known at this point.
    future_volume = 0.0

    # Ticks just after the target candle's timestamp (used for masking).
    target_ts = target['timestamp']
    future_ticks = _tick_window_features(
        ticks, target_ts, target_ts + tick_window_ms, max_ticks
    )

    return (
        np.array(candle_features, dtype=np.float32),
        np.array(tick_features, dtype=np.float32),
        np.array(future_candle, dtype=np.float32),
        np.array(future_volume, dtype=np.float32),
        np.array(future_ticks, dtype=np.float32),
    )
def create_mask(seq_len, future_mask=True):
    """Build a ``(seq_len, seq_len)`` causal mask.

    Args:
        seq_len: The length of the sequence.
        future_mask: Selects the mask *format*:

            - True returns an additive float mask, with 0.0 on and below the
              diagonal and -inf above it.
            - False returns the plain lower-triangular 0/1 float mask. That
              mask is still causal, but it must be applied multiplicatively.

    Returns:
        A float tensor of shape ``(seq_len, seq_len)``.
    """
    if not future_mask:
        return torch.tril(torch.ones(seq_len, seq_len))
    # Strictly-upper entries keep -inf; triu zeroes the diagonal and below.
    return torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)
def create_padding_mask(seq, pad_token=0):
    """Flag positions whose every feature equals ``pad_token``.

    Args:
        seq: Tensor whose last dimension holds per-position features, for
            example ``(seq_len, n_features)``.
        pad_token: Value that marks padding. Default 0.

    Returns:
        A bool tensor shaped like ``seq.shape[:-1]`` with an extra leading
        dimension of size 1. For a 2-D ``(seq_len, n_features)`` input this
        is ``(1, seq_len)``; True marks a fully padded position.
    """
    # NOTE(review): a batched (batch, seq_len, n_features) input yields
    # (1, batch, seq_len) — confirm that is what callers expect.
    is_pad = seq.eq(pad_token)
    fully_padded = is_pad.all(dim=-1)
    return fully_padded[None, ...]
def get_aligned_candle_with_index(candles_list, base_ts):
    """Return ``(index, candle)`` for the latest candle with ``timestamp <= base_ts``.

    The scan walks forward and stops at the first candle newer than
    ``base_ts``, so ``candles_list`` is assumed to be sorted by ascending
    timestamp. If no leading candle qualifies, ``(None, None)`` is returned.
    """
    found = (None, None)
    for idx, candle in enumerate(candles_list):
        if candle["timestamp"] > base_ts:
            break
        found = (idx, candle)
    return found
def get_features_for_tf(candles_list, aligned_index, period=10):
    """Build a 7-element feature vector for the candle at ``aligned_index``.

    The vector is ``[open, high, low, close, volume, ema_short, ema_long]``.
    Both EMAs are computed with ``calculate_ema`` over the candles up to and
    including ``aligned_index``, using periods ``period`` and ``2 * period``.

    Args:
        candles_list: Sequence of candle dicts with 'open', 'high', 'low',
            'close' and 'volume' keys.
        aligned_index: Index into ``candles_list``, or None.
        period: Short EMA period. The long EMA uses twice this value.

    Returns:
        A list of 7 values, all 0.0 when ``aligned_index`` is None. An EMA
        entry is NaN when the history has fewer candles than that EMA's
        period (this follows from calculate_ema's short-series rule).
    """
    if aligned_index is None:
        return [0.0] * 7  # zeroed feature vector, same length as below
    candle = candles_list[aligned_index]
    # Only history up to the aligned candle feeds the EMAs.
    history = candles_list[:aligned_index + 1]
    ema_short = calculate_ema(history, period=period)[-1]
    ema_long = calculate_ema(history, period=period * 2)[-1]
    return [
        candle["open"],
        candle["high"],
        candle["low"],
        candle["close"],
        candle["volume"],
        ema_short,
        ema_long,
    ]
# Example usage (within a larger training loop):
if __name__ == '__main__':
    # Five synthetic candles spaced 60 000 ms apart, one row per candle.
    candle_keys = ('timestamp', 'open', 'high', 'low', 'close', 'volume')
    candle_rows = [
        (1678886400000, 25000.0, 25050.0, 24950.0, 25025.0, 100.0),
        (1678886460000, 25025.0, 25100.0, 25000.0, 25075.0, 120.0),
        (1678886520000, 25075.0, 25150.0, 25050.0, 25125.0, 150.0),
        (1678886580000, 25125.0, 25200.0, 25100.0, 25175.0, 180.0),
        (1678886640000, 25175.0, 25250.0, 25150.0, 25225.0, 200.0),
    ]
    candles_data = [dict(zip(candle_keys, row)) for row in candle_rows]

    # Two ticks shortly before each of the later candle timestamps.
    tick_rows = [
        (1678886455000, 25020.0, 0.1),
        (1678886458000, 25022.0, 0.2),
        (1678886515000, 25070.0, 0.3),
        (1678886518000, 25078.0, 0.1),
        (1678886575000, 25120.0, 0.2),
        (1678886578000, 25122.0, 0.1),
        (1678886635000, 25170.0, 0.4),
        (1678886638000, 25172.0, 0.2),
    ]
    ticks_data = [
        {'timestamp': ts, 'symbol': 'BTC/USDT', 'price': price, 'quantity': qty}
        for ts, price, qty in tick_rows
    ]

    outputs = preprocess_data(candles_data, ticks_data)
    candle_features = outputs[0]
    labels = (
        "Candle Features:\n",
        "\nTick Features:\n",
        "\nFuture Candle:\n",
        "\nFuture Volume:\n",
        "\nFuture Ticks\n",
    )
    for label, value in zip(labels, outputs):
        print(label, value)

    # Causal mask sized to the number of input candle rows.
    mask = create_mask(len(candle_features))
    print("\nMask:\n", mask)
    padding_mask = create_padding_mask(torch.tensor(candle_features))
    print(f"\nPadding mask: {padding_mask}")

    # 1678886570000 lies between the candles stamped ...520000 and ...580000.
    index, candle = get_aligned_candle_with_index(candles_data, 1678886570000)
    print(f"\nAligned candle: {candle}" if candle else "\nNo aligned candle found.")
    features = get_features_for_tf(candles_data, index)
    print(f"\nFeatures for timeframe: {features}")