init gogo
This commit is contained in:
158
crypto/gogo/data/data_utils.py
Normal file
158
crypto/gogo/data/data_utils.py
Normal file
@ -0,0 +1,158 @@
|
||||
# data/data_utils.py
|
||||
import numpy as np
|
||||
import torch
|
||||
from collections import deque
|
||||
|
||||
def calculate_ema(data, period):
    """Compute the exponential moving average of candle close prices.

    Args:
        data: Sequence of candle dicts; each must provide a 'close' value.
        period: EMA period; the smoothing factor is 2 / (period + 1).

    Returns:
        A list with one EMA value per candle, seeded with the first close
        price. When fewer than `period` candles are supplied, a list of NaN
        with the same length as `data` is returned instead.
    """
    n_candles = len(data)
    if n_candles < period:
        # Not enough history for this period.
        return [np.nan] * n_candles

    closes = np.array([candle['close'] for candle in data])
    alpha = 2 / (period + 1)

    # Seed with the first close, then fold in the remaining closes in order.
    smoothed = [closes[0]]
    for close in closes[1:]:
        previous = smoothed[-1]
        smoothed.append((close - previous) * alpha + previous)
    return smoothed
|
||||
|
||||
def _window_tick_features(ticks, start_ms, end_ms, length=30):
    """Flatten ticks with start_ms < timestamp <= end_ms into a fixed-size list.

    The first `length` matching ticks (in input order) contribute
    [price, quantity] each; missing slots are zero-padded, so the result
    always has 2 * length entries.
    """
    selected = [t for t in ticks if start_ms < t['timestamp'] <= end_ms][:length]
    flat = []
    for tick in selected:
        flat.extend((tick['price'], tick['quantity']))
    flat.extend([0.0] * (2 * (length - len(selected))))
    return flat


def preprocess_data(candles, ticks, ema_periods=(5, 10, 20, 60, 120, 200)):
    """Preprocesses candles and ticks for the transformer.

    The last candle is held out as the prediction target; all earlier
    candles become input rows.

    Args:
        candles: List of candle dictionaries with 'timestamp', 'open',
            'high', 'low', 'close' and 'volume' keys.
        ticks: List of tick dictionaries with 'timestamp', 'price' and
            'quantity' keys. None is treated as "no ticks".
        ema_periods: Iterable of periods for EMA calculation.

    Returns:
        Tuple (candle_features, tick_features, future_candle, future_volume,
        future_ticks) of float32 numpy arrays:
            candle_features: (len(candles) - 1, 5 + number of EMA periods).
            tick_features:   (60,) price/quantity pairs for up to 30 ticks.
            future_candle:   (5,) OHLCV of the last candle.
            future_volume:   0-d placeholder, currently always 0.0.
            future_ticks:    (60,) price/quantity pairs for up to 30 ticks.
        Five Nones are returned when fewer than two candles are supplied.
    """
    if not candles or len(candles) < 2:  # Need at least 2 candles for current and future
        return None, None, None, None, None

    # Tolerate a missing tick feed the same way a missing candle list is tolerated.
    if ticks is None:
        ticks = []

    # --- Calculate EMAs ---
    # Materialise once so a one-shot iterable of periods is not exhausted.
    ema_series = [calculate_ema(candles, period) for period in ema_periods]

    # --- Prepare Candle Features ---
    # Exclude the last candle (used for future).
    candle_features = [
        [
            candle['open'],
            candle['high'],
            candle['low'],
            candle['close'],
            candle['volume'],
            *(series[i] for series in ema_series),
        ]
        for i, candle in enumerate(candles[:-1])
    ]

    # --- Prepare Tick Features (30 seconds up to the last input candle's timestamp) ---
    # NOTE(review): the window ends at candles[-2]['timestamp']. If candle
    # timestamps are open times, this covers the 30 s *before* that candle
    # opens rather than the 30 s before the target candle - confirm intent.
    last_candle_timestamp = candles[-2]['timestamp']
    tick_features = _window_tick_features(
        ticks, last_candle_timestamp - 30 * 1000, last_candle_timestamp
    )

    # --- Prepare Future Data (Targets) ---
    target = candles[-1]
    future_candle = [
        target['open'],
        target['high'],
        target['low'],
        target['close'],
        target['volume'],
    ]

    # --- Future Volume (5-min) ---
    # TODO: placeholder until a 5-minute forward volume target is implemented.
    future_volume = 0.0

    # --- Future Ticks (Next 30 seconds, for masking) ---
    next_candle_timestamp = target['timestamp']
    future_ticks = _window_tick_features(
        ticks, next_candle_timestamp, next_candle_timestamp + 30 * 1000
    )

    return (
        np.array(candle_features, dtype=np.float32),
        np.array(tick_features, dtype=np.float32),
        np.array(future_candle, dtype=np.float32),
        np.array(future_volume, dtype=np.float32),
        np.array(future_ticks, dtype=np.float32),
    )
|
||||
|
||||
def create_mask(seq_len, future_mask=True):
    """Build a square lower-triangular mask for a sequence.

    Args:
        seq_len: The length of the sequence.
        future_mask: When True, return an additive mask holding 0.0 on and
            below the diagonal and -inf above it. When False, return the
            plain lower-triangular matrix of ones and zeros.

    Returns:
        A float tensor of shape (seq_len, seq_len).
    """
    lower_triangle = torch.tril(torch.ones(seq_len, seq_len))
    if not future_mask:
        return lower_triangle

    # Positions above the diagonal are the ones to block.
    blocked = lower_triangle == 0
    return torch.zeros_like(lower_triangle).masked_fill(blocked, float('-inf'))
|
||||
|
||||
def create_padding_mask(seq, pad_token=0):
    """Flag the positions of a sequence that consist entirely of padding.

    Args:
        seq: Sequence tensor; the last dimension is reduced.
        pad_token: Value that marks padding, default 0.

    Returns:
        A boolean tensor that is True where every element along the last
        dimension equals `pad_token`, with a leading axis of size 1
        prepended (shape (1, seq_len) for a 2-D `seq`).
    """
    is_pad = seq == pad_token
    fully_padded = is_pad.all(dim=-1)
    return fully_padded.unsqueeze(0)
|
||||
|
||||
# Example usage (within a larger training loop):
|
||||
if __name__ == '__main__':
    # Dummy data for demonstration: five candles spaced 60 000 ms apart.
    candles_data = [
        {'timestamp': 1678886400000, 'open': 25000.0, 'high': 25050.0, 'low': 24950.0, 'close': 25025.0, 'volume': 100.0},
        {'timestamp': 1678886460000, 'open': 25025.0, 'high': 25100.0, 'low': 25000.0, 'close': 25075.0, 'volume': 120.0},
        {'timestamp': 1678886520000, 'open': 25075.0, 'high': 25150.0, 'low': 25050.0, 'close': 25125.0, 'volume': 150.0},
        {'timestamp': 1678886580000, 'open': 25125.0, 'high': 25200.0, 'low': 25100.0, 'close': 25175.0, 'volume': 180.0},
        {'timestamp': 1678886640000, 'open': 25175.0, 'high': 25250.0, 'low': 25150.0, 'close': 25225.0, 'volume': 200.0},
    ]
    # Two ticks shortly before each candle timestamp after the first.
    ticks_data = [
        {'timestamp': 1678886455000, 'symbol': 'BTC/USDT', 'price': 25020.0, 'quantity': 0.1},
        {'timestamp': 1678886458000, 'symbol': 'BTC/USDT', 'price': 25022.0, 'quantity': 0.2},
        {'timestamp': 1678886515000, 'symbol': 'BTC/USDT', 'price': 25070.0, 'quantity': 0.3},
        {'timestamp': 1678886518000, 'symbol': 'BTC/USDT', 'price': 25078.0, 'quantity': 0.1},
        {'timestamp': 1678886575000, 'symbol': 'BTC/USDT', 'price': 25120.0, 'quantity': 0.2},
        {'timestamp': 1678886578000, 'symbol': 'BTC/USDT', 'price': 25122.0, 'quantity': 0.1},
        {'timestamp': 1678886635000, 'symbol': 'BTC/USDT', 'price': 25170.0, 'quantity': 0.4},
        {'timestamp': 1678886638000, 'symbol': 'BTC/USDT', 'price': 25172.0, 'quantity': 0.2},
    ]

    outputs = preprocess_data(candles_data, ticks_data)
    candle_features, tick_features, future_candle, future_volume, future_ticks = outputs

    # Print each output under its heading, in the order they are returned.
    for heading, value in (
        ("Candle Features:\n", candle_features),
        ("\nTick Features:\n", tick_features),
        ("\nFuture Candle:\n", future_candle),
        ("\nFuture Volume:\n", future_volume),
        ("\nFuture Ticks\n", future_ticks),
    ):
        print(heading, value)

    # Example mask creation: one mask row/column per input candle.
    seq_len = len(candle_features)
    mask = create_mask(seq_len)
    print("\nMask:\n", mask)

    padding_mask = create_padding_mask(torch.tensor(candle_features))
    print(f"\nPadding mask: {padding_mask}")
|
158
crypto/gogo/data/live_data.py
Normal file
158
crypto/gogo/data/live_data.py
Normal file
@ -0,0 +1,158 @@
|
||||
# data/live_data.py
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from collections import deque
|
||||
|
||||
import ccxt.async_support as ccxt
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
||||
class LiveDataManager:
    """Maintain rolling windows of one-minute candles and recent trades.

    Candles are seeded from the exchange's OHLCV endpoint and then kept up
    to date by folding polled trades into the newest candle. All mutation
    of `candles`, `ticks` and `last_candle_time` during polling happens
    while holding `self.lock`.
    """

    # Width of one candle in milliseconds (candles are fetched as '1m').
    _CANDLE_INTERVAL_MS = 60 * 1000

    def __init__(self, symbol, exchange_name='mexc', window_size=120):
        """Create the manager and its exchange client.

        Args:
            symbol: Market symbol passed to the exchange, e.g. 'BTC/USDT'.
            exchange_name: Name of the ccxt exchange class to instantiate.
            window_size: Number of candles to keep; the tick buffer holds
                window_size * 60 entries.

        Raises:
            ValueError: If the API credentials are missing from the environment.
        """
        load_dotenv()  # Load environment variables
        self.symbol = symbol
        self.exchange_name = exchange_name
        self.window_size = window_size
        self.candles = deque(maxlen=window_size)
        self.ticks = deque(maxlen=window_size * 60)  # Assuming max 60 ticks per minute
        self.last_candle_time = None  # Timestamp (ms) of the newest candle, once known
        self.exchange = self._initialize_exchange()
        self.lock = asyncio.Lock()  # Lock to prevent race conditions

    def _initialize_exchange(self):
        """Instantiate the ccxt client using MEXC_API_KEY / MEXC_API_SECRET."""
        exchange_class = getattr(ccxt, self.exchange_name)
        mexc_api_key = os.environ.get('MEXC_API_KEY')
        mexc_api_secret = os.environ.get('MEXC_API_SECRET')

        if not mexc_api_key or not mexc_api_secret:
            raise ValueError("API keys not found in environment variables. Please check your .env file.")

        return exchange_class({
            'apiKey': mexc_api_key,
            'secret': mexc_api_secret,
            'enableRateLimit': True,
        })

    async def _fetch_initial_candles(self):
        """Seed the candle window with the most recent `window_size` 1m candles.

        Errors are reported on stdout and otherwise ignored (best effort).
        """
        print(f"Fetching initial candles for {self.symbol}...")
        now = int(time.time() * 1000)
        since = now - self.window_size * self._CANDLE_INTERVAL_MS
        try:
            candles = await self.exchange.fetch_ohlcv(self.symbol, '1m', since=since, limit=self.window_size)
            for candle in candles:
                self.candles.append(self._format_candle(candle))
            if candles:
                self.last_candle_time = candles[-1][0]
            print(f"Fetched {len(candles)} initial candles.")
        except Exception as e:
            print(f"Error fetching initial candles: {e}")

    def _format_candle(self, candle_data):
        """Convert an OHLCV row [timestamp, open, high, low, close, volume] to a dict."""
        return {
            'timestamp': candle_data[0],
            'open': float(candle_data[1]),
            'high': float(candle_data[2]),
            'low': float(candle_data[3]),
            'close': float(candle_data[4]),
            'volume': float(candle_data[5])
        }

    def _format_tick(self, tick_data):
        """Normalise a trade into {'timestamp', 'symbol', 'price', 'quantity'}.

        Two input shapes are accepted:
          * raw stream payloads keyed by 'E' (time), 's', 'p' and 'q';
          * ccxt unified trades keyed by 'timestamp', 'symbol', 'price' and
            'amount', which is what `fetch_trades` returns.

        Returns:
            The normalised tick, or None when the payload has neither shape
            or is missing a required value.
        """
        if 's' in tick_data:
            return {
                'timestamp': tick_data['E'],
                'symbol': tick_data['s'],
                'price': float(tick_data['p']),
                'quantity': float(tick_data['q'])
            }

        timestamp = tick_data.get('timestamp')
        price = tick_data.get('price')
        amount = tick_data.get('amount')
        if timestamp is None or price is None or amount is None:
            return None

        return {
            'timestamp': timestamp,
            'symbol': tick_data.get('symbol') or self.symbol,
            'price': float(price),
            'quantity': float(amount)
        }

    @staticmethod
    def _new_candle(timestamp, price, quantity):
        """Return a candle opened by a single trade."""
        return {
            'timestamp': timestamp,
            'open': price,
            'high': price,
            'low': price,
            'close': price,
            'volume': quantity
        }

    def _apply_tick(self, tick):
        """Fold one normalised tick into the candle window.

        The caller must already hold `self.lock`; this method never awaits.
        """
        interval = self._CANDLE_INTERVAL_MS
        timestamp = tick['timestamp']
        price = tick['price']
        quantity = tick['quantity']
        bucket_start = timestamp - (timestamp % interval)

        if self.last_candle_time is None or not self.candles:
            # First tick ever: it opens the first candle and must not also be
            # applied as an update, or its quantity would be counted twice.
            self.last_candle_time = bucket_start
            self.candles.append(self._new_candle(bucket_start, price, quantity))
            return

        if timestamp >= self.last_candle_time + interval:
            # Start a new candle aligned to the tick's own minute, so a gap
            # without trades does not leave the candle clock lagging.
            self.last_candle_time = bucket_start
            self.candles.append(self._new_candle(bucket_start, price, quantity))
            return

        if timestamp < self.last_candle_time:
            # Trade predates the current candle; it cannot belong to it.
            return

        # Update the current candle in place.
        current_candle = self.candles[-1]
        current_candle['high'] = max(current_candle['high'], price)
        current_candle['low'] = min(current_candle['low'], price)
        current_candle['close'] = price
        current_candle['volume'] += quantity

    async def _update_candle(self, tick):
        """Acquire the lock and fold `tick` into the candle window."""
        async with self.lock:
            self._apply_tick(tick)

    async def fetch_and_process_ticks(self):
        """Poll recent trades and merge them into the tick and candle windows.

        Errors are reported on stdout and otherwise ignored (best effort).
        """
        async with self.lock:
            # Ask only for trades newer than the last one stored, so the most
            # recent trade is not fetched (and counted) again on every poll.
            since = self.ticks[-1]['timestamp'] + 1 if self.ticks else None
            try:
                # Use fetch_trades (or appropriate method for your exchange) for live ticks.
                ticks = await self.exchange.fetch_trades(self.symbol, since=since)
                for tick in ticks:
                    formatted_tick = self._format_tick(tick)
                    if formatted_tick:
                        self.ticks.append(formatted_tick)
                        # asyncio.Lock is not reentrant and is already held
                        # here, so use the lock-free helper instead of
                        # _update_candle (which would wait forever).
                        self._apply_tick(formatted_tick)
            except Exception as e:
                print(f"Error fetching ticks: {e}")

    async def get_data(self):
        """Return snapshot copies of the candle and tick windows.

        Returns:
            Tuple (candles, ticks) of lists of dicts. The dicts are copies,
            so later updates to the live candle do not alter the snapshot.
        """
        async with self.lock:
            candles_copy = [dict(candle) for candle in self.candles]
            ticks_copy = [dict(tick) for tick in self.ticks]
        return candles_copy, ticks_copy

    async def close(self):
        """Close the underlying exchange client."""
        await self.exchange.close()
|
||||
|
||||
|
||||
|
||||
async def main():
    """Poll BTC/USDT once per second and print the newest candle and tick.

    The exchange client is always closed on the way out, including when
    the initial candle fetch or the polling loop is interrupted.
    """
    symbol = 'BTC/USDT'
    manager = LiveDataManager(symbol)

    async def print_data():
        while True:
            await manager.fetch_and_process_ticks()  # Fetch new ticks continuously
            candles, ticks = await manager.get_data()
            if candles:
                print("Last Candle:", candles[-1])
            if ticks:
                print("Last Tick:", ticks[-1])
            await asyncio.sleep(1)  # Print every second

    try:
        # Inside the try so the client is closed even if seeding is interrupted.
        await manager._fetch_initial_candles()
        await print_data()  # Run the printing task
    except KeyboardInterrupt:
        print("Closing connection...")
    except asyncio.CancelledError:
        # Under asyncio.run(), Ctrl-C reaches the main task as a cancellation;
        # report it and let the cancellation propagate.
        print("Closing connection...")
        raise
    finally:
        await manager.close()
|
||||
|
||||
if __name__ == '__main__':
    # Script entry point: drive the polling coroutine on a fresh event loop.
    asyncio.run(main())
|
Reference in New Issue
Block a user