init gogo

This commit is contained in:
Dobromir Popov
2025-02-12 01:15:44 +02:00
parent 6dfeee18bf
commit 5606ed3cab
11 changed files with 822 additions and 8 deletions

View File

@ -0,0 +1,158 @@
# data/data_utils.py
import numpy as np
import torch
from collections import deque
def calculate_ema(data, period):
    """Return the exponential moving average of the candles' close prices.

    Args:
        data: Sequence of candle dictionaries; only the 'close' key is read.
        period: EMA period; the smoothing factor is 2 / (period + 1).

    Returns:
        A list with one entry per candle. The series is seeded with the
        first close price. When fewer than `period` candles are supplied,
        every entry is NaN instead.
    """
    sample_count = len(data)
    if sample_count < period:
        # Not enough history for this period: signal it with NaNs.
        return [np.nan for _ in range(sample_count)]

    closes = np.array([bar['close'] for bar in data])
    smoothed = [closes[0]]  # seed value: the first close
    alpha = 2 / (period + 1)
    for price in closes[1:]:
        previous = smoothed[-1]
        smoothed.append((price - previous) * alpha + previous)
    return smoothed
def _flatten_ticks(ticks, slots=30):
    """Flatten up to `slots` ticks into [price, quantity, ...], zero-padded to 2 * slots values."""
    flat = []
    for tick in ticks[:slots]:
        flat.extend([tick['price'], tick['quantity']])
    flat.extend([0.0] * (2 * slots - len(flat)))
    return flat


def preprocess_data(candles, ticks, ema_periods=(5, 10, 20, 60, 120, 200)):
    """Preprocesses candles and ticks for the transformer.

    The last candle is treated as the prediction target; every earlier
    candle becomes one input row.

    Args:
        candles: Sequence of candle dictionaries with 'timestamp', 'open',
            'high', 'low', 'close' and 'volume' keys.
        ticks: Sequence of tick dictionaries with 'timestamp', 'price' and
            'quantity' keys. None is treated as "no ticks".
        ema_periods: Periods for EMA calculation; one EMA column is appended
            to each candle row per period.

    Returns:
        Tuple (candle_features, tick_features, future_candle, future_volume,
        future_ticks) of float32 arrays, or a tuple of five None values when
        fewer than two candles are supplied.

        - candle_features: one row per input candle,
          [open, high, low, close, volume, ema_1, ..., ema_k].
        - tick_features: 60 values (30 price/quantity pairs, zero-padded) for
          ticks in the 30 * 1000 timestamp units up to and including the
          timestamp of the last input candle (presumably milliseconds --
          confirm against the data source).
        - future_candle: [open, high, low, close, volume] of the last candle.
        - future_volume: placeholder scalar 0.0.
        - future_ticks: 60 values for ticks in the 30 * 1000 timestamp units
          strictly after the target candle's timestamp.
    """
    if not candles or len(candles) < 2:  # Need at least 2 candles for current and future
        return None, None, None, None, None

    # Work on plain lists so slicing also works for deques and other iterables.
    candles = list(candles)
    ticks = [] if ticks is None else list(ticks)
    window = 30 * 1000

    # --- Calculate EMAs (over all candles, including the target one) ---
    emas = {period: calculate_ema(candles, period) for period in ema_periods}

    # --- Prepare Candle Features (the last candle is excluded: it is the target) ---
    candle_features = []
    for i, candle in enumerate(candles[:-1]):
        row = [
            candle['open'],
            candle['high'],
            candle['low'],
            candle['close'],
            candle['volume'],
        ]
        row.extend(emas[period][i] for period in ema_periods)
        candle_features.append(row)

    # --- Prepare Tick Features (window ending at the last input candle) ---
    last_candle_timestamp = candles[-2]['timestamp']
    relevant_ticks = [
        tick for tick in ticks
        if last_candle_timestamp - window < tick['timestamp'] <= last_candle_timestamp
    ]
    tick_features = _flatten_ticks(relevant_ticks)

    # --- Prepare Future Data (Targets) ---
    target = candles[-1]
    future_candle = [
        target['open'],
        target['high'],
        target['low'],
        target['close'],
        target['volume'],
    ]

    # --- Future Volume (5-min) ---
    # TODO: not known at this point; a multi-candle volume target is not implemented yet.
    future_volume = 0.0

    # --- Future Ticks (window starting right after the target candle's timestamp) ---
    next_candle_timestamp = target['timestamp']
    future_ticks_data = [
        tick for tick in ticks
        if next_candle_timestamp < tick['timestamp'] <= next_candle_timestamp + window
    ]
    future_ticks = _flatten_ticks(future_ticks_data)

    return (
        np.array(candle_features, dtype=np.float32),
        np.array(tick_features, dtype=np.float32),
        np.array(future_candle, dtype=np.float32),
        np.array(future_volume, dtype=np.float32),
        np.array(future_ticks, dtype=np.float32),
    )
def create_mask(seq_len, future_mask=True):
    """Build a square causal mask for a sequence.

    Args:
        seq_len: The length of the sequence.
        future_mask: When True, return an additive mask: 0.0 on and below
            the diagonal, -inf above it. When False, return the plain
            lower-triangular matrix of ones and zeros instead.

    Returns:
        A mask tensor of shape (seq_len, seq_len).
    """
    lower = torch.ones(seq_len, seq_len).tril()
    if not future_mask:
        return lower

    # Positions above the diagonal are the "future" ones.
    blocked = lower == 0
    return torch.zeros_like(lower).masked_fill(blocked, float('-inf'))
def create_padding_mask(seq, pad_token=0):
    """
    Creates a padding mask.

    A position counts as padding only when every value along the last axis
    equals `pad_token`.

    Args:
        seq: sequence tensor
        pad_token: padding token, default 0.

    Returns: boolean padding mask with a leading axis of size 1 followed by
        the shape of `seq` without its last axis, e.g. (1, seq_len) for a
        (seq_len, features) input.
    """
    matches_pad = seq.eq(pad_token)
    fully_padded = torch.all(matches_pad, dim=-1)
    return torch.unsqueeze(fully_padded, 0)
# Example usage (within a larger training loop):
if __name__ == '__main__':
    # Five synthetic one-minute candles; the last one acts as the target.
    candles_data = [
        {'timestamp': 1678886400000, 'open': 25000.0, 'high': 25050.0, 'low': 24950.0, 'close': 25025.0, 'volume': 100.0},
        {'timestamp': 1678886460000, 'open': 25025.0, 'high': 25100.0, 'low': 25000.0, 'close': 25075.0, 'volume': 120.0},
        {'timestamp': 1678886520000, 'open': 25075.0, 'high': 25150.0, 'low': 25050.0, 'close': 25125.0, 'volume': 150.0},
        {'timestamp': 1678886580000, 'open': 25125.0, 'high': 25200.0, 'low': 25100.0, 'close': 25175.0, 'volume': 180.0},
        {'timestamp': 1678886640000, 'open': 25175.0, 'high': 25250.0, 'low': 25150.0, 'close': 25225.0, 'volume': 200.0},
    ]
    # Two synthetic ticks shortly before each candle boundary.
    ticks_data = [
        {'timestamp': 1678886455000, 'symbol': 'BTC/USDT', 'price': 25020.0, 'quantity': 0.1},
        {'timestamp': 1678886458000, 'symbol': 'BTC/USDT', 'price': 25022.0, 'quantity': 0.2},
        {'timestamp': 1678886515000, 'symbol': 'BTC/USDT', 'price': 25070.0, 'quantity': 0.3},
        {'timestamp': 1678886518000, 'symbol': 'BTC/USDT', 'price': 25078.0, 'quantity': 0.1},
        {'timestamp': 1678886575000, 'symbol': 'BTC/USDT', 'price': 25120.0, 'quantity': 0.2},
        {'timestamp': 1678886578000, 'symbol': 'BTC/USDT', 'price': 25122.0, 'quantity': 0.1},
        {'timestamp': 1678886635000, 'symbol': 'BTC/USDT', 'price': 25170.0, 'quantity': 0.4},
        {'timestamp': 1678886638000, 'symbol': 'BTC/USDT', 'price': 25172.0, 'quantity': 0.2},
    ]

    candle_features, tick_features, future_candle, future_volume, future_ticks = preprocess_data(
        candles_data, ticks_data
    )

    # Dump every output under its heading, in the order they are returned.
    report = (
        ("Candle Features:\n", candle_features),
        ("\nTick Features:\n", tick_features),
        ("\nFuture Candle:\n", future_candle),
        ("\nFuture Volume:\n", future_volume),
        ("\nFuture Ticks\n", future_ticks),
    )
    for heading, values in report:
        print(heading, values)

    # Example mask creation: one mask row/column per input candle.
    seq_len = len(candle_features)
    mask = create_mask(seq_len)
    print("\nMask:\n", mask)

    padding_mask = create_padding_mask(torch.tensor(candle_features))
    print(f"\nPadding mask: {padding_mask}")

View File

@ -0,0 +1,158 @@
# data/live_data.py
import asyncio
import json
import os
import time
from collections import deque
import ccxt.async_support as ccxt
from dotenv import load_dotenv
class LiveDataManager:
    """Keeps rolling windows of 1-minute candles and recent ticks for one symbol.

    Candles are seeded from the exchange's OHLCV endpoint and then extended
    from polled trades. All mutation of `candles` / `ticks` happens while
    holding `self.lock`.

    Note: credentials are always read from MEXC_API_KEY / MEXC_API_SECRET,
    whatever `exchange_name` is.
    """

    # Width of one candle bucket, in the same unit as the timestamps
    # (milliseconds, matching the `time.time() * 1000` arithmetic below).
    _CANDLE_MS = 60 * 1000

    def __init__(self, symbol, exchange_name='mexc', window_size=120):
        """Create the manager and the underlying ccxt exchange client.

        Args:
            symbol: Market symbol, e.g. 'BTC/USDT'.
            exchange_name: Attribute name of the exchange class in ccxt.
            window_size: Number of 1-minute candles to retain.

        Raises:
            ValueError: If the API credentials are missing from the environment.
        """
        load_dotenv()  # Load environment variables
        self.symbol = symbol
        self.exchange_name = exchange_name
        self.window_size = window_size
        self.candles = deque(maxlen=window_size)
        self.ticks = deque(maxlen=window_size * 60)  # Assuming max 60 ticks per minute
        self.last_candle_time = None
        self.exchange = self._initialize_exchange()
        self.lock = asyncio.Lock()  # Lock to prevent race conditions

    def _initialize_exchange(self):
        """Instantiate the ccxt exchange class named by `exchange_name`."""
        exchange_class = getattr(ccxt, self.exchange_name)
        mexc_api_key = os.environ.get('MEXC_API_KEY')
        mexc_api_secret = os.environ.get('MEXC_API_SECRET')
        if not mexc_api_key or not mexc_api_secret:
            raise ValueError("API keys not found in environment variables. Please check your .env file.")
        return exchange_class({
            'apiKey': mexc_api_key,
            'secret': mexc_api_secret,
            'enableRateLimit': True,
        })

    async def _fetch_initial_candles(self):
        """Seed `candles` with up to `window_size` recent 1-minute candles.

        Errors are reported on stdout and otherwise ignored (best effort).
        """
        print(f"Fetching initial candles for {self.symbol}...")
        now = int(time.time() * 1000)
        since = now - self.window_size * self._CANDLE_MS
        try:
            candles = await self.exchange.fetch_ohlcv(self.symbol, '1m', since=since, limit=self.window_size)
            for candle in candles:
                self.candles.append(self._format_candle(candle))
            if candles:
                self.last_candle_time = candles[-1][0]
            print(f"Fetched {len(candles)} initial candles.")
        except Exception as e:
            print(f"Error fetching initial candles: {e}")

    def _format_candle(self, candle_data):
        """Convert an OHLCV row [timestamp, open, high, low, close, volume] to a dict."""
        return {
            'timestamp': candle_data[0],
            'open': float(candle_data[1]),
            'high': float(candle_data[2]),
            'low': float(candle_data[3]),
            'close': float(candle_data[4]),
            'volume': float(candle_data[5])
        }

    def _format_tick(self, tick_data):
        """Normalise one trade into {'timestamp', 'symbol', 'price', 'quantity'}.

        Two input shapes are accepted:
          * raw stream-style payloads keyed by 'E', 's', 'p', 'q';
          * ccxt unified trades keyed by 'timestamp', 'symbol', 'price',
            'amount' (what `fetch_trades` returns).

        Returns:
            The normalised tick, or None when the payload carries neither shape.
        """
        if 's' in tick_data:
            return {
                'timestamp': tick_data['E'],
                'symbol': tick_data['s'],
                'price': float(tick_data['p']),
                'quantity': float(tick_data['q'])
            }

        timestamp = tick_data.get('timestamp')
        price = tick_data.get('price')
        amount = tick_data.get('amount')
        if timestamp is None or price is None or amount is None:
            return None
        return {
            'timestamp': timestamp,
            'symbol': tick_data.get('symbol') or self.symbol,
            'price': float(price),
            'quantity': float(amount)
        }

    def _apply_tick(self, tick):
        """Fold one normalised tick into the candle window.

        Synchronous and lock-free on purpose: callers must already hold
        `self.lock`. asyncio.Lock is not reentrant, so taking it again from
        a caller that holds it would block forever.
        """
        timestamp = tick['timestamp']
        starts_new_candle = (
            self.last_candle_time is None
            or not self.candles
            or timestamp >= self.last_candle_time + self._CANDLE_MS
        )
        if starts_new_candle:
            # Align to the tick's own minute so a gap of several minutes
            # does not leave the candle timestamp lagging behind.
            self.last_candle_time = timestamp - (timestamp % self._CANDLE_MS)
            self.candles.append({
                'timestamp': self.last_candle_time,
                'open': tick['price'],
                'high': tick['price'],
                'low': tick['price'],
                'close': tick['price'],
                'volume': tick['quantity']
            })
            # The tick is fully accounted for by the new candle; merging it
            # again would double-count its volume.
            return

        # NOTE(review): ticks older than the current candle are merged into
        # it as well -- confirm this is acceptable for late trades.
        current_candle = self.candles[-1]
        current_candle['high'] = max(current_candle['high'], tick['price'])
        current_candle['low'] = min(current_candle['low'], tick['price'])
        current_candle['close'] = tick['price']
        current_candle['volume'] += tick['quantity']

    async def _update_candle(self, tick):
        """Fold one normalised tick into the candle window under the lock."""
        async with self.lock:
            self._apply_tick(tick)

    async def fetch_and_process_ticks(self):
        """Poll the exchange for new trades and fold them into ticks and candles.

        Errors are reported on stdout and otherwise ignored (best effort).
        """
        async with self.lock:
            # Ask only for trades after the newest one we hold, so the last
            # trade is not fetched and counted again on every poll. Trades
            # sharing that exact timestamp but not yet seen would be skipped.
            since = self.ticks[-1]['timestamp'] + 1 if self.ticks else None
            try:
                # Use fetch_trades (or appropriate method for your exchange) for live ticks.
                ticks = await self.exchange.fetch_trades(self.symbol, since=since)
                for tick in ticks:
                    formatted_tick = self._format_tick(tick)
                    if formatted_tick:
                        self.ticks.append(formatted_tick)
                        # The lock is already held here: use the lock-free helper.
                        self._apply_tick(formatted_tick)
            except Exception as e:
                print(f"Error fetching ticks: {e}")

    async def get_data(self):
        """Return a snapshot (candles, ticks) as two lists.

        Candle dicts are copied because the newest candle keeps being
        mutated in place; ticks are never modified after being stored.
        """
        async with self.lock:
            candles_copy = [dict(candle) for candle in self.candles]
            ticks_copy = list(self.ticks)
        return candles_copy, ticks_copy

    async def close(self):
        """Close the underlying exchange client."""
        await self.exchange.close()
async def main():
    """Demo loop: seed candles for BTC/USDT, then poll and print once per second.

    The exchange client is closed when the loop ends, whatever the reason.
    """
    manager = LiveDataManager('BTC/USDT')
    await manager._fetch_initial_candles()

    try:
        while True:
            # Fetch new ticks continuously
            await manager.fetch_and_process_ticks()
            latest_candles, latest_ticks = await manager.get_data()
            if latest_candles:
                print("Last Candle:", latest_candles[-1])
            if latest_ticks:
                print("Last Tick:", latest_ticks[-1])
            # Print every second
            await asyncio.sleep(1)
    except KeyboardInterrupt:
        print("Closing connection...")
    finally:
        await manager.close()


if __name__ == '__main__':
    asyncio.run(main())