implementations
This commit is contained in:
@ -121,6 +121,44 @@ def create_padding_mask(seq, pad_token=0):
|
||||
"""
|
||||
return (seq == pad_token).all(dim=-1).unsqueeze(0)
|
||||
|
||||
def get_aligned_candle_with_index(candles_list, base_ts):
|
||||
"""
|
||||
Find the candle from candles_list that is closest to (and <=) base_ts.
|
||||
Returns: (index, candle)
|
||||
"""
|
||||
aligned_index = None
|
||||
aligned_candle = None
|
||||
for i in range(len(candles_list)):
|
||||
if candles_list[i]["timestamp"] <= base_ts:
|
||||
aligned_index = i
|
||||
aligned_candle = candles_list[i]
|
||||
else:
|
||||
break
|
||||
return aligned_index, aligned_candle
|
||||
|
||||
def get_features_for_tf(candles_list, aligned_index, period=10):
|
||||
"""
|
||||
Extract features from the candle at aligned_index.
|
||||
If aligned_index is None, return a zeroed feature vector.
|
||||
"""
|
||||
if aligned_index is None:
|
||||
return [0.0] * 7 # return zeroed feature vector
|
||||
candle = candles_list[aligned_index]
|
||||
# Simple features: open, high, low, close, volume, and two EMAs.
|
||||
close_prices = [c["close"] for c in candles_list[:aligned_index+1]]
|
||||
ema_short = calculate_ema(candles_list[:aligned_index+1], period=period)[-1]
|
||||
ema_long = calculate_ema(candles_list[:aligned_index+1], period=period*2)[-1]
|
||||
features = [
|
||||
candle["open"],
|
||||
candle["high"],
|
||||
candle["low"],
|
||||
candle["close"],
|
||||
candle["volume"],
|
||||
ema_short,
|
||||
ema_long
|
||||
]
|
||||
return features
|
||||
|
||||
# Example usage (within a larger training loop):
|
||||
if __name__ == '__main__':
|
||||
# Dummy data for demonstration
|
||||
@ -155,4 +193,14 @@ if __name__ == '__main__':
|
||||
mask = create_mask(seq_len)
|
||||
print("\nMask:\n", mask)
|
||||
padding_mask = create_padding_mask(torch.tensor(candle_features))
|
||||
print(f"\nPadding mask: {padding_mask}")
|
||||
print(f"\nPadding mask: {padding_mask}")
|
||||
|
||||
# Example usage of the new functions
|
||||
index, candle = get_aligned_candle_with_index(candles_data, 1678886570000)
|
||||
if candle:
|
||||
print(f"\nAligned candle: {candle}")
|
||||
else:
|
||||
print("\nNo aligned candle found.")
|
||||
|
||||
features = get_features_for_tf(candles_data, index)
|
||||
print(f"\nFeatures for timeframe: {features}")
|
||||
|
@ -7,7 +7,7 @@ from collections import deque
|
||||
|
||||
import ccxt.async_support as ccxt
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import platform
|
||||
|
||||
class LiveDataManager:
|
||||
def __init__(self, symbol, exchange_name='mexc', window_size=120):
|
||||
@ -20,6 +20,7 @@ class LiveDataManager:
|
||||
self.last_candle_time = None
|
||||
self.exchange = self._initialize_exchange()
|
||||
self.lock = asyncio.Lock() # Lock to prevent race conditions
|
||||
self.is_windows = platform.system() == 'Windows'
|
||||
|
||||
def _initialize_exchange(self):
|
||||
exchange_class = getattr(ccxt, self.exchange_name)
|
||||
@ -41,15 +42,23 @@ class LiveDataManager:
|
||||
print(f"Fetching initial candles for {self.symbol}...")
|
||||
now = int(time.time() * 1000)
|
||||
since = now - self.window_size * 60 * 1000
|
||||
try:
|
||||
candles = await self.exchange.fetch_ohlcv(self.symbol, '1m', since=since, limit=self.window_size)
|
||||
for candle in candles:
|
||||
self.candles.append(self._format_candle(candle))
|
||||
if candles:
|
||||
self.last_candle_time = candles[-1][0]
|
||||
print(f"Fetched {len(candles)} initial candles.")
|
||||
except Exception as e:
|
||||
print(f"Error fetching initial candles: {e}")
|
||||
retries = 3
|
||||
for attempt in range(retries):
|
||||
try:
|
||||
candles = await self.exchange.fetch_ohlcv(self.symbol, '1m', since=since, limit=self.window_size)
|
||||
for candle in candles:
|
||||
self.candles.append(self._format_candle(candle))
|
||||
if candles:
|
||||
self.last_candle_time = candles[-1][0]
|
||||
print(f"Fetched {len(candles)} initial candles.")
|
||||
return # Exit the function if successful
|
||||
except Exception as e:
|
||||
print(f"Attempt {attempt + 1} failed: {e}")
|
||||
if self.is_windows and "aiodns needs a SelectorEventLoop" in str(e):
|
||||
print("aiodns issue detected on Windows. This is a known problem with aiodns and ccxt on Windows.")
|
||||
if attempt < retries - 1:
|
||||
await asyncio.sleep(5) # Wait before retrying
|
||||
print("Failed to fetch initial candles after multiple retries.")
|
||||
|
||||
def _format_candle(self, candle_data):
|
||||
return {
|
||||
@ -112,16 +121,23 @@ class LiveDataManager:
|
||||
async def fetch_and_process_ticks(self):
|
||||
async with self.lock:
|
||||
since = None if not self.ticks else self.ticks[-1]['timestamp']
|
||||
try:
|
||||
# Use fetch_trades (or appropriate method for your exchange) for live ticks.
|
||||
ticks = await self.exchange.fetch_trades(self.symbol, since=since)
|
||||
for tick in ticks:
|
||||
formatted_tick = self._format_tick(tick)
|
||||
if formatted_tick: # Add the check here
|
||||
self.ticks.append(formatted_tick)
|
||||
await self._update_candle(formatted_tick)
|
||||
except Exception as e:
|
||||
print(f"Error fetching ticks: {e}")
|
||||
retries = 3
|
||||
for attempt in range(retries):
|
||||
try:
|
||||
# Use fetch_trades (or appropriate method for your exchange) for live ticks.
|
||||
ticks = await self.exchange.fetch_trades(self.symbol, since=since)
|
||||
for tick in ticks:
|
||||
formatted_tick = self._format_tick(tick)
|
||||
if formatted_tick: # Add the check here
|
||||
self.ticks.append(formatted_tick)
|
||||
await self._update_candle(formatted_tick)
|
||||
break # Exit the retry loop if successful
|
||||
except Exception as e:
|
||||
print(f"Error fetching ticks (attempt {attempt + 1}): {e}")
|
||||
if self.is_windows and "aiodns needs a SelectorEventLoop" in str(e):
|
||||
print("aiodns issue detected on Windows. This is a known problem with aiodns and ccxt on Windows.")
|
||||
if attempt < retries - 1:
|
||||
await asyncio.sleep(5) # Wait before retrying
|
||||
|
||||
async def get_data(self):
|
||||
async with self.lock:
|
||||
|
Reference in New Issue
Block a user