init gogo

This commit is contained in:
Dobromir Popov 2025-02-12 01:15:44 +02:00
parent 6dfeee18bf
commit 5606ed3cab
11 changed files with 822 additions and 8 deletions

File diff suppressed because one or more lines are too long

0
crypto/brian/pt.py Normal file

crypto/brian/requirements.txt

@@ -1,7 +1,7 @@
-ccxt==4.1.97
-numpy==1.26.3
-torch==2.1.2
-torchaudio==2.1.2
-torchvision==0.16.2
-matplotlib==3.8.2
-python-dotenv==1.0.0
+ccxt
+numpy
+torch
+torchaudio
+torchvision
+matplotlib
+python-dotenv

158
crypto/gogo/data/data_utils.py Normal file

@@ -0,0 +1,158 @@
# data/data_utils.py
import numpy as np
import torch


def calculate_ema(data, period):
    """Calculates the EMA for a series of candles and a given period."""
    if len(data) < period:
        return [np.nan] * len(data)  # Not enough data: return NaNs
    close_prices = np.array([candle['close'] for candle in data])
    ema = [close_prices[0]]  # Seed the EMA with the first close price
    multiplier = 2 / (period + 1)
    # EMA recurrence: ema_t = (price_t - ema_{t-1}) * k + ema_{t-1}, with k = 2 / (period + 1)
    for i in range(1, len(close_prices)):
        ema_value = (close_prices[i] - ema[-1]) * multiplier + ema[-1]
        ema.append(ema_value)
    return ema

def preprocess_data(candles, ticks, ema_periods=[5, 10, 20, 60, 120, 200]):
    """Preprocesses candles and ticks for the transformer.

    Args:
        candles: List of candle dictionaries.
        ticks: List of tick dictionaries.
        ema_periods: List of periods for EMA calculation.

    Returns:
        Tuple: (candle_features, tick_features, future_candle, future_volume, future_ticks)
    """
    if not candles or len(candles) < 2:  # Need at least 2 candles: current and future
        return None, None, None, None, None

    # --- Calculate EMAs ---
    emas = {}
    for period in ema_periods:
        emas[period] = calculate_ema(candles, period)

    # --- Prepare Candle Features ---
    candle_features = []
    for i, candle in enumerate(candles[:-1]):  # Exclude the last candle (used as the future target)
        features = [
            candle['open'],
            candle['high'],
            candle['low'],
            candle['close'],
            candle['volume'],
        ]
        for period in ema_periods:
            features.append(emas[period][i])
        candle_features.append(features)

    # --- Prepare Tick Features (last 30 seconds before the next candle) ---
    last_candle_timestamp = candles[-2]['timestamp']
    thirty_sec_ago = last_candle_timestamp - 30 * 1000
    relevant_ticks = [tick for tick in ticks
                      if thirty_sec_ago < tick['timestamp'] <= last_candle_timestamp]
    tick_features = []
    # Pad or truncate tick data to a fixed length (30 ticks, ~1 tick/second)
    for i in range(30):
        if i < len(relevant_ticks):
            tick_features.extend([relevant_ticks[i]['price'], relevant_ticks[i]['quantity']])
        else:
            tick_features.extend([0.0, 0.0])  # Pad with zeros

    # --- Prepare Future Data (targets) ---
    future_candle = [
        candles[-1]['open'],
        candles[-1]['high'],
        candles[-1]['low'],
        candles[-1]['close'],
        candles[-1]['volume'],
    ]

    # --- Future Volume (5-min) ---
    future_volume = 0.0  # Not known yet at prediction time.
    # future_volume = calculate_volume_for_next_n_minutes(candles, n=5)

    # --- Future Ticks (next 30 seconds, for masking) ---
    next_candle_timestamp = candles[-1]['timestamp']
    future_ticks_end_time = next_candle_timestamp + 30 * 1000
    future_ticks_data = [tick for tick in ticks
                         if next_candle_timestamp < tick['timestamp'] <= future_ticks_end_time]
    future_ticks = []
    for i in range(30):
        if i < len(future_ticks_data):
            future_ticks.extend([future_ticks_data[i]['price'], future_ticks_data[i]['quantity']])
        else:
            future_ticks.extend([0.0, 0.0])

    return (np.array(candle_features, dtype=np.float32),
            np.array(tick_features, dtype=np.float32),
            np.array(future_candle, dtype=np.float32),
            np.array(future_volume, dtype=np.float32),
            np.array(future_ticks, dtype=np.float32))

def create_mask(seq_len, future_mask=True):
    """Creates an attention mask for the input sequence.

    Args:
        seq_len: The length of the sequence.
        future_mask: Whether to mask future positions (causal mask).

    Returns:
        A (seq_len, seq_len) tensor where 1 = may attend and 0 = blocked,
        matching the `masked_fill(mask == 0, -1e9)` convention used in the
        attention implementation.
    """
    if future_mask:
        return torch.tril(torch.ones(seq_len, seq_len))
    return torch.ones(seq_len, seq_len)


def create_padding_mask(seq, pad_token=0):
    """Creates a padding mask.

    Args:
        seq: Sequence tensor of shape (seq_len, features).
        pad_token: Padding token, default 0.

    Returns:
        A (1, seq_len) boolean mask, True where every feature equals pad_token.
    """
    return (seq == pad_token).all(dim=-1).unsqueeze(0)

# Example usage (within a larger training loop):
if __name__ == '__main__':
    # Dummy data for demonstration
    candles_data = [
        {'timestamp': 1678886400000, 'open': 25000.0, 'high': 25050.0, 'low': 24950.0, 'close': 25025.0, 'volume': 100.0},
        {'timestamp': 1678886460000, 'open': 25025.0, 'high': 25100.0, 'low': 25000.0, 'close': 25075.0, 'volume': 120.0},
        {'timestamp': 1678886520000, 'open': 25075.0, 'high': 25150.0, 'low': 25050.0, 'close': 25125.0, 'volume': 150.0},
        {'timestamp': 1678886580000, 'open': 25125.0, 'high': 25200.0, 'low': 25100.0, 'close': 25175.0, 'volume': 180.0},
        {'timestamp': 1678886640000, 'open': 25175.0, 'high': 25250.0, 'low': 25150.0, 'close': 25225.0, 'volume': 200.0},
    ]
    ticks_data = [
        {'timestamp': 1678886455000, 'symbol': 'BTC/USDT', 'price': 25020.0, 'quantity': 0.1},
        {'timestamp': 1678886458000, 'symbol': 'BTC/USDT', 'price': 25022.0, 'quantity': 0.2},
        {'timestamp': 1678886515000, 'symbol': 'BTC/USDT', 'price': 25070.0, 'quantity': 0.3},
        {'timestamp': 1678886518000, 'symbol': 'BTC/USDT', 'price': 25078.0, 'quantity': 0.1},
        {'timestamp': 1678886575000, 'symbol': 'BTC/USDT', 'price': 25120.0, 'quantity': 0.2},
        {'timestamp': 1678886578000, 'symbol': 'BTC/USDT', 'price': 25122.0, 'quantity': 0.1},
        {'timestamp': 1678886635000, 'symbol': 'BTC/USDT', 'price': 25170.0, 'quantity': 0.4},
        {'timestamp': 1678886638000, 'symbol': 'BTC/USDT', 'price': 25172.0, 'quantity': 0.2},
    ]
    candle_features, tick_features, future_candle, future_volume, future_ticks = preprocess_data(candles_data, ticks_data)
    print("Candle Features:\n", candle_features)
    print("\nTick Features:\n", tick_features)
    print("\nFuture Candle:\n", future_candle)
    print("\nFuture Volume:\n", future_volume)
    print("\nFuture Ticks:\n", future_ticks)

    # Example mask creation
    seq_len = len(candle_features)  # Example sequence length
    mask = create_mask(seq_len)
    print("\nMask:\n", mask)
    padding_mask = create_padding_mask(torch.tensor(candle_features))
    print(f"\nPadding mask: {padding_mask}")

158
crypto/gogo/data/live_data.py Normal file

@@ -0,0 +1,158 @@
# data/live_data.py
import asyncio
import os
import time
from collections import deque

import ccxt.async_support as ccxt
from dotenv import load_dotenv


class LiveDataManager:
    def __init__(self, symbol, exchange_name='mexc', window_size=120):
        load_dotenv()  # Load environment variables
        self.symbol = symbol
        self.exchange_name = exchange_name
        self.window_size = window_size
        self.candles = deque(maxlen=window_size)
        self.ticks = deque(maxlen=window_size * 60)  # Assuming at most 60 ticks per minute
        self.last_candle_time = None
        self.exchange = self._initialize_exchange()
        self.lock = asyncio.Lock()  # Lock to prevent race conditions

    def _initialize_exchange(self):
        exchange_class = getattr(ccxt, self.exchange_name)
        mexc_api_key = os.environ.get('MEXC_API_KEY')
        mexc_api_secret = os.environ.get('MEXC_API_SECRET')
        if not mexc_api_key or not mexc_api_secret:
            raise ValueError("API keys not found in environment variables. Please check your .env file.")
        return exchange_class({
            'apiKey': mexc_api_key,
            'secret': mexc_api_secret,
            'enableRateLimit': True,
        })

    async def _fetch_initial_candles(self):
        print(f"Fetching initial candles for {self.symbol}...")
        now = int(time.time() * 1000)
        since = now - self.window_size * 60 * 1000
        try:
            candles = await self.exchange.fetch_ohlcv(self.symbol, '1m', since=since, limit=self.window_size)
            for candle in candles:
                self.candles.append(self._format_candle(candle))
            if candles:
                self.last_candle_time = candles[-1][0]
            print(f"Fetched {len(candles)} initial candles.")
        except Exception as e:
            print(f"Error fetching initial candles: {e}")

    def _format_candle(self, candle_data):
        return {
            'timestamp': candle_data[0],
            'open': float(candle_data[1]),
            'high': float(candle_data[2]),
            'low': float(candle_data[3]),
            'close': float(candle_data[4]),
            'volume': float(candle_data[5])
        }

    def _format_tick(self, tick_data):
        # ccxt's fetch_trades returns unified trade dicts with 'timestamp',
        # 'symbol', 'price' and 'amount' keys; skip malformed entries.
        if tick_data.get('timestamp') is None or tick_data.get('price') is None:
            return None
        return {
            'timestamp': tick_data['timestamp'],
            'symbol': tick_data.get('symbol', self.symbol),
            'price': float(tick_data['price']),
            'quantity': float(tick_data['amount'])
        }
    async def _update_candle(self, tick):
        async with self.lock:
            if self.last_candle_time is None:  # First tick: open the initial candle
                self.last_candle_time = tick['timestamp'] - (tick['timestamp'] % (60 * 1000))
                self.candles.append({
                    'timestamp': self.last_candle_time,
                    'open': tick['price'],
                    'high': tick['price'],
                    'low': tick['price'],
                    'close': tick['price'],
                    'volume': tick['quantity']
                })
                return  # This tick is already counted in the new candle
            if tick['timestamp'] >= self.last_candle_time + 60 * 1000:
                # Start a new candle
                self.last_candle_time += 60 * 1000
                self.candles.append({
                    'timestamp': self.last_candle_time,
                    'open': tick['price'],
                    'high': tick['price'],
                    'low': tick['price'],
                    'close': tick['price'],
                    'volume': tick['quantity']
                })
            else:
                # Update the current candle in place
                current_candle = self.candles[-1]
                current_candle['high'] = max(current_candle['high'], tick['price'])
                current_candle['low'] = min(current_candle['low'], tick['price'])
                current_candle['close'] = tick['price']
                current_candle['volume'] += tick['quantity']
    async def fetch_and_process_ticks(self):
        # Only hold the lock while reading state; _update_candle takes the
        # (non-reentrant) lock itself, so it must not be awaited while we hold it.
        async with self.lock:
            since = None if not self.ticks else self.ticks[-1]['timestamp']
        try:
            # Use fetch_trades (or the appropriate method for your exchange) for live ticks.
            ticks = await self.exchange.fetch_trades(self.symbol, since=since)
            for tick in ticks:
                formatted_tick = self._format_tick(tick)
                if formatted_tick:  # Skip malformed ticks
                    self.ticks.append(formatted_tick)
                    await self._update_candle(formatted_tick)
        except Exception as e:
            print(f"Error fetching ticks: {e}")

    async def get_data(self):
        async with self.lock:
            candles_copy = list(self.candles)
            ticks_copy = list(self.ticks)
        return candles_copy, ticks_copy

    async def close(self):
        await self.exchange.close()


async def main():
    symbol = 'BTC/USDT'
    manager = LiveDataManager(symbol)
    await manager._fetch_initial_candles()

    async def print_data():
        while True:
            await manager.fetch_and_process_ticks()  # Fetch new ticks continuously
            candles, ticks = await manager.get_data()
            if candles:
                print("Last Candle:", candles[-1])
            if ticks:
                print("Last Tick:", ticks[-1])
            await asyncio.sleep(1)  # Print every second

    try:
        await print_data()  # Run the printing task
    except KeyboardInterrupt:
        print("Closing connection...")
    finally:
        await manager.close()


if __name__ == '__main__':
    asyncio.run(main())

47
crypto/gogo/main.py Normal file

@@ -0,0 +1,47 @@
# main.py
import asyncio
import torch
import torch.nn as nn
import torch.optim as optim

from data.live_data import LiveDataManager
from model.transformer import Transformer
from training.train import train


async def main():
    symbol = 'BTC/USDT'
    data_manager = LiveDataManager(symbol)

    # Model parameters (scale d_model/num_layers up to approach ~1B
    # parameters; the defaults below give a much smaller model)
    ema_periods = [5, 10, 20, 60, 120, 200]
    input_dim = 5 + len(ema_periods)  # OHLCV (5 features) + one EMA per period
    d_model = 512
    num_heads = 8
    num_layers = 6
    d_ff = 2048
    dropout = 0.1

    model = Transformer(input_dim, d_model, num_heads, num_layers, d_ff, dropout)
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

    # Define loss functions
    criterion_candle = nn.MSELoss()
    criterion_volume = nn.MSELoss()  # Consider a different loss for volume if needed
    criterion_ticks = nn.MSELoss()

    # Check for CUDA availability and set device
    if torch.cuda.is_available():
        device = torch.device('cuda')
        print("Using CUDA")
    else:
        device = torch.device('cpu')
        print("Using CPU")

    try:
        await train(model, data_manager, optimizer, criterion_candle, criterion_volume,
                    criterion_ticks, num_epochs=10, device=device)
    except KeyboardInterrupt:
        print("Training stopped manually.")
    finally:
        await data_manager.close()


if __name__ == '__main__':
    asyncio.run(main())

222
crypto/gogo/model/transformer.py Normal file

@@ -0,0 +1,222 @@
# model/transformer.py
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # [1, max_len, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: [batch_size, seq_len, d_model]
        x = x + self.pe[:, :x.size(1)]
        return x


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # Mask convention: 1 = may attend, 0 = blocked (see data_utils.create_mask)
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_probs = F.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_probs, V)
        return output, attn_probs

    def split_heads(self, x):
        batch_size = x.size(0)
        x = x.view(batch_size, -1, self.num_heads, self.d_k)
        return x.transpose(1, 2)  # [batch_size, num_heads, seq_len, d_k]

    def combine_heads(self, x):
        batch_size = x.size(0)
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return x

    def forward(self, Q, K, V, mask=None):
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        attn_output, attn_probs = self.scaled_dot_product_attention(Q, K, V, mask)
        output = self.W_o(self.combine_heads(attn_output))
        return output, attn_probs


class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(FeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))

class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        attn_output, _ = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x


class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask, tgt_mask):
        attn_output, _ = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        attn_output, _ = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x


class Encoder(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)


class Decoder(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, enc_output, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, enc_output, src_mask, tgt_mask)
        return self.norm(x)

class Transformer(nn.Module):
    def __init__(self, input_dim, d_model, num_heads, num_layers, d_ff, dropout=0.1):
        super(Transformer, self).__init__()
        self.candle_embedding = nn.Linear(input_dim, d_model)
        self.tick_embedding = nn.Linear(2, d_model)  # Each tick has price and quantity
        self.positional_encoding = PositionalEncoding(d_model)
        self.encoder = Encoder(num_layers, d_model, num_heads, d_ff, dropout)
        # Decoder for the future candle
        self.future_candle_decoder = Decoder(num_layers, d_model, num_heads, d_ff, dropout)
        self.future_candle_projection = nn.Linear(d_model, 5)  # Output 5 values: O, H, L, C, V
        # Decoder for the future volume
        self.future_volume_decoder = Decoder(num_layers, d_model, num_heads, d_ff, dropout)
        self.future_volume_projection = nn.Linear(d_model, 1)
        # Decoder for the future ticks
        self.future_ticks_decoder = Decoder(num_layers, d_model, num_heads, d_ff, dropout)
        self.future_ticks_projection = nn.Linear(d_model, 2)  # (price, quantity) per future tick position

    def forward(self, candle_data, tick_data, future_candle_mask, future_ticks_mask):
        # candle_data: [batch_size, seq_len, input_dim]
        # tick_data: [batch_size, tick_seq_len, 2]
        candle_embedded = self.candle_embedding(candle_data)
        candle_embedded = self.positional_encoding(candle_embedded)  # Add positional info
        tick_embedded = self.tick_embedding(tick_data)
        tick_embedded = self.positional_encoding(tick_embedded)

        # Concatenate candle and tick embeddings along the sequence length dimension
        combined_input = torch.cat((candle_embedded, tick_embedded), dim=1)

        # Build one mask over the combined sequence: causal over the candle
        # block, unrestricted elsewhere (1 = may attend, 0 = blocked)
        total_len = combined_input.size(1)
        candle_len = candle_embedded.size(1)
        combined_mask = torch.ones(total_len, total_len, device=candle_data.device)
        combined_mask[:candle_len, :candle_len] = future_candle_mask

        enc_output = self.encoder(combined_input, combined_mask)

        # --- Future Candle Prediction ---
        future_candle_input = torch.zeros_like(candle_embedded[:, -1:, :])  # Start token: zeros
        # src_mask=None: the decoder queries may attend to the full encoder output
        future_candle_output = self.future_candle_decoder(future_candle_input, enc_output, None, None)
        future_candle_pred = self.future_candle_projection(future_candle_output)

        # --- Future Volume Prediction ---
        future_volume_input = torch.zeros_like(candle_embedded[:, -1:, :])  # Start token: zeros
        future_volume_output = self.future_volume_decoder(future_volume_input, enc_output, None, None)
        future_volume_pred = self.future_volume_projection(future_volume_output)

        # --- Future Ticks Prediction ---
        future_ticks_input = torch.zeros_like(tick_embedded)
        future_ticks_output = self.future_ticks_decoder(future_ticks_input, enc_output, None, future_ticks_mask)
        future_ticks_pred = self.future_ticks_projection(future_ticks_output)

        return future_candle_pred, future_volume_pred, future_ticks_pred

# Example instantiation (scale d_model/num_layers up for a larger model)
if __name__ == '__main__':
    # Run as `python -m model.transformer` from the project root so the
    # data package is importable.
    from data.data_utils import create_mask

    ema_periods = [5, 10, 20, 60, 120, 200]
    input_dim = 5 + len(ema_periods)  # OHLCV (5 features) + one EMA per period
    d_model = 512    # Hidden dimension
    num_heads = 8
    num_layers = 6   # Number of encoder/decoder layers
    d_ff = 2048      # Feedforward dimension
    dropout = 0.1

    # Calculate approximate parameter count
    def count_parameters(model):
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    model = Transformer(input_dim, d_model, num_heads, num_layers, d_ff, dropout)
    num_params = count_parameters(model)
    print(f"Number of parameters: {num_params:,}")  # Formatted with commas

    # --- Dummy Input Data for Testing ---
    batch_size = 2
    candle_seq_len = 119  # We have 119 past candles and predict the 120th
    num_ticks = 30        # 30 ticks, each with (price, quantity)
    candle_data = torch.randn(batch_size, candle_seq_len, input_dim)
    tick_data = torch.randn(batch_size, num_ticks, 2)
    future_candle_mask = create_mask(candle_seq_len)
    future_ticks_mask = create_mask(num_ticks, future_mask=True)

    # --- Forward Pass ---
    future_candle_pred, future_volume_pred, future_ticks_pred = model(
        candle_data, tick_data, future_candle_mask, future_ticks_mask
    )
    print("Future Candle Prediction Shape:", future_candle_pred.shape)  # Expected: [batch_size, 1, 5]
    print("Future Volume Prediction Shape:", future_volume_pred.shape)  # Expected: [batch_size, 1, 1]
    print("Future Ticks Prediction Shape:", future_ticks_pred.shape)    # Expected: [batch_size, 30, 2]

31
crypto/gogo/readme.md Normal file

@@ -0,0 +1,31 @@
To run this code:

1. Install dependencies: `pip install -r requirements.txt`
2. Set up `.env`: create a `.env` file in your project root and add your MEXC API keys:

   MEXC_API_KEY=your_api_key
   MEXC_API_SECRET=your_api_secret
3. Run: `python main.py`

Important Considerations and Next Steps:

- Hyperparameter tuning: The provided hyperparameters are a starting point. Experiment with d_model, num_heads, num_layers, d_ff, learning rate, weight decay, and dropout. Consider a hyperparameter optimization library such as Optuna (see the sketch after this list).
- Loss function choices: MSE is a placeholder. For price movements, consider a loss that also penalizes getting the direction of the change wrong, not just its magnitude (see the directional-loss sketch below). Volume may need a different loss altogether.
- Trading strategy: The included "trading logic" is purely for demonstration. A real system needs a robust strategy with proper risk management, entry/exit criteria, and position sizing.
- Data normalization/scaling: Normalize or scale the input features (candles and ticks) to improve training stability and performance; common techniques are min-max scaling and standardization (see the sketch below). This belongs in data_utils.py.
- Evaluation metrics: Track metrics beyond the raw loss, such as Sharpe ratio, maximum drawdown, and win rate (see the sketch below). These belong in train.py or a separate evaluation.py module.
- Backtesting: Before deploying live, thoroughly backtest the model and strategy on historical data to assess performance and find weaknesses. The current code mixes training and live evaluation; ideally these would be separate.
- Overfitting: Monitor for overfitting (good performance on training data, poor performance on new data). Dropout, weight decay, and early stopping help mitigate it.
- Memory usage: The code stores data in deques, which keeps only the most recent N samples and prevents out-of-memory errors.
- Learned indicators: Train a small auxiliary network to predict the next candle from OHLCV alone; its hidden activations (or weights) can be concatenated to the existing features as learned indicators (see the sketch below).
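
A minimal Optuna sketch for the tuning item above. It assumes a `train_for_trial` helper (not part of this repo) that runs a short offline training job with the given settings and returns a validation loss:

```python
import optuna

def objective(trial):
    d_model = trial.suggest_categorical("d_model", [256, 512, 1024])
    num_layers = trial.suggest_int("num_layers", 2, 12)
    lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
    dropout = trial.suggest_float("dropout", 0.0, 0.3)
    # train_for_trial is a hypothetical helper: a short training run that
    # returns the validation loss for these settings.
    return train_for_trial(d_model=d_model, num_layers=num_layers, lr=lr, dropout=dropout)

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=50)
print(study.best_params)
```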
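
For the loss-function item, one possible direction-aware loss, shown only as an illustration (train.py currently uses plain MSE). It adds a penalty whenever the predicted and actual close-to-close moves disagree in sign:

```python
import torch
import torch.nn.functional as F

def directional_mse_loss(pred_close, target_close, prev_close, alpha=0.5):
    """MSE plus a penalty when the predicted move direction is wrong."""
    mse = F.mse_loss(pred_close, target_close)
    pred_move = pred_close - prev_close
    true_move = target_close - prev_close
    # The product is negative exactly when the two moves have opposite signs,
    # so relu(-product) is positive only for direction errors.
    direction_penalty = torch.relu(-pred_move * true_move).mean()
    return mse + alpha * direction_penalty
```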
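
A sketch of min-max scaling for the normalization item; `minmax_scale` is an illustrative helper, and in practice the feature ranges should be fitted on training data only and then reused for live data:

```python
import numpy as np

def minmax_scale(features, feature_min=None, feature_max=None):
    """Scale each feature column to [0, 1]; fits the ranges when not given."""
    features = np.asarray(features, dtype=np.float32)
    if feature_min is None:
        feature_min = np.nanmin(features, axis=0)
        feature_max = np.nanmax(features, axis=0)
    span = np.where(feature_max - feature_min == 0, 1.0, feature_max - feature_min)
    scaled = (features - feature_min) / span
    return scaled, feature_min, feature_max

# scaled_train, fmin, fmax = minmax_scale(train_candle_features)
# scaled_live, _, _ = minmax_scale(live_candle_features, fmin, fmax)
```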
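
Two of the metrics mentioned above, sketched for per-candle returns and an equity curve (the function names and the annualization factor for 1-minute candles are assumptions):

```python
import numpy as np

def sharpe_ratio(returns, periods_per_year=365 * 24 * 60):
    """Annualized Sharpe ratio of per-period returns (1-minute periods by default)."""
    returns = np.asarray(returns, dtype=np.float64)
    std = returns.std()
    if std == 0:
        return 0.0
    return np.sqrt(periods_per_year) * returns.mean() / std

def max_drawdown(equity_curve):
    """Largest peak-to-trough decline of an equity curve, as a fraction of the peak."""
    equity = np.asarray(equity_curve, dtype=np.float64)
    running_peak = np.maximum.accumulate(equity)
    return ((running_peak - equity) / running_peak).max()
```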
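
And a minimal sketch of the learned-indicator idea: a tiny auxiliary network (hypothetical, not in this repo) trained to predict the next candle from the current OHLCV row, whose hidden activations can be appended to the feature vector:

```python
import torch
import torch.nn as nn

class IndicatorNet(nn.Module):
    """Tiny MLP: predicts the next OHLCV row and exposes its hidden
    activations as learned indicators."""
    def __init__(self, hidden_dim=16):
        super().__init__()
        self.hidden = nn.Sequential(nn.Linear(5, hidden_dim), nn.ReLU())
        self.head = nn.Linear(hidden_dim, 5)

    def forward(self, ohlcv):
        h = self.hidden(ohlcv)   # h: the learned-indicator vector
        return self.head(h), h   # (next-candle prediction, indicators)

# After training, concatenate the indicators to the existing candle features:
# _, indicators = indicator_net(ohlcv_row)
# features = torch.cat([candle_features, indicators], dim=-1)
```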

6
crypto/gogo/requirements.txt Normal file

@@ -0,0 +1,6 @@
# asyncio is part of the Python standard library; no pip install needed
ccxt
python-dotenv
torch
numpy
matplotlib

157
crypto/gogo/training/train.py Normal file

@@ -0,0 +1,157 @@
# training/train.py
import asyncio
import os
import time
from collections import deque
from datetime import datetime

import torch

from data.data_utils import preprocess_data, create_mask
from visualization.plotting import plot_live_data

# --- Directories for saving models ---
LAST_DIR = os.path.join("models", "last")
BEST_DIR = os.path.join("models", "best")
os.makedirs(LAST_DIR, exist_ok=True)
os.makedirs(BEST_DIR, exist_ok=True)

# -------------------------------------
# Checkpoint Functions
# -------------------------------------
def maintain_checkpoint_directory(directory, max_files=10):
    files = os.listdir(directory)
    if len(files) > max_files:
        full_paths = [os.path.join(directory, f) for f in files]
        full_paths.sort(key=lambda x: os.path.getmtime(x))
        for f in full_paths[: len(files) - max_files]:
            os.remove(f)


def get_best_models(directory):
    # Filenames look like best_{loss}_epoch_{n}_{timestamp}.pt
    best_files = []
    for file in os.listdir(directory):
        parts = file.split("_")
        try:
            r = float(parts[1])
            best_files.append((r, file))
        except Exception:
            continue
    return best_files

def save_checkpoint(model, epoch, total_loss, last_dir=LAST_DIR, best_dir=BEST_DIR):
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    last_filename = f"model_last_epoch_{epoch}_{timestamp}.pt"
    last_path = os.path.join(last_dir, last_filename)
    torch.save({
        "epoch": epoch,
        "total_loss": total_loss,
        "model_state_dict": model.state_dict()
    }, last_path)
    maintain_checkpoint_directory(last_dir, max_files=10)

    # Keep the 10 best (lowest-loss) checkpoints: evict the worst one
    # when a better checkpoint arrives.
    best_models = get_best_models(best_dir)
    add_to_best = False
    if len(best_models) < 10:
        add_to_best = True
    else:
        worst_loss, worst_file = max(best_models, key=lambda x: x[0])
        if total_loss < worst_loss:
            add_to_best = True
            os.remove(os.path.join(best_dir, worst_file))
    if add_to_best:
        best_filename = f"best_{total_loss:.4f}_epoch_{epoch}_{timestamp}.pt"
        best_path = os.path.join(best_dir, best_filename)
        torch.save({
            "epoch": epoch,
            "total_loss": total_loss,
            "model_state_dict": model.state_dict()
        }, best_path)
        maintain_checkpoint_directory(best_dir, max_files=10)
    print(f"Saved checkpoint for epoch {epoch} with loss {total_loss:.4f}")


def load_best_checkpoint(model, best_dir=BEST_DIR):
    best_models = get_best_models(best_dir)
    if not best_models:
        return None
    best_loss, best_file = min(best_models, key=lambda x: x[0])  # lowest loss = best
    path = os.path.join(best_dir, best_file)
    print(f"Loading best model from checkpoint: {best_file} with loss {best_loss:.4f}")
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint["model_state_dict"])
    return checkpoint

async def train(model, data_manager, optimizer, criterion_candle, criterion_volume, criterion_ticks, num_epochs=10, device='cpu'):
    model.to(device)
    model.train()
    trade_history = deque(maxlen=100)

    # Load the best checkpoint if one is available.
    load_best_checkpoint(model, BEST_DIR)
    await data_manager._fetch_initial_candles()

    for epoch in range(1, num_epochs + 1):
        start_time = time.time()
        # NOTE: this inner loop trains on live data indefinitely; the epoch
        # counter only advances if the loop is broken out of.
        while True:
            await data_manager.fetch_and_process_ticks()
            candles, ticks = await data_manager.get_data()
            if len(candles) < data_manager.window_size:
                await asyncio.sleep(1)  # Wait for enough data and try again
                continue

            candle_features, tick_features, future_candle, future_volume, future_ticks = preprocess_data(candles, ticks)
            # Skip if preprocessing fails (e.g., not enough data)
            if candle_features is None:
                await asyncio.sleep(1)
                continue

            # Convert to PyTorch tensors, add a batch dimension, move to the device
            candle_features = torch.tensor(candle_features).unsqueeze(0).to(device)  # [1, seq_len, input_dim]
            tick_features = torch.tensor(tick_features).view(1, -1, 2).to(device)    # [1, 30, 2]
            future_candle = torch.tensor(future_candle).unsqueeze(0).to(device)      # [1, 5]
            future_volume = torch.tensor(future_volume).view(1, 1).to(device)        # [1, 1]
            future_ticks = torch.tensor(future_ticks).view(1, -1, 2).to(device)      # [1, 30, 2]

            future_candle_mask = create_mask(candle_features.size(1)).to(device)
            future_ticks_mask = create_mask(tick_features.size(1)).to(device)

            optimizer.zero_grad()
            future_candle_pred, future_volume_pred, future_ticks_pred = model(
                candle_features, tick_features, future_candle_mask, future_ticks_mask)

            # Calculate losses
            loss_candle = criterion_candle(future_candle_pred.squeeze(1), future_candle)
            loss_volume = criterion_volume(future_volume_pred.squeeze(1), future_volume)
            loss_ticks = criterion_ticks(future_ticks_pred, future_ticks)

            # Combine the losses (weights per component can be added here)
            total_loss = loss_candle + loss_volume + loss_ticks
            total_loss.backward()
            optimizer.step()

            print(f"Epoch: {epoch}, Candle Loss: {loss_candle.item():.4f}, Volume Loss: {loss_volume.item():.4f}, Tick Loss: {loss_ticks.item():.4f}, Total: {total_loss.item():.4f}")

            # Save a checkpoint (every iteration, since the inner loop never
            # ends on its own; throttle this in practice)
            save_checkpoint(model, epoch, total_loss.item(), LAST_DIR, BEST_DIR)

            # --- Basic Trading Logic (Illustrative) ---
            # This is a very simplified example. A real system needs much more
            # sophisticated entry/exit logic, risk management, etc.
            predicted_close = future_candle_pred[0, 0, 3].item()  # Predicted close
            current_close = candles[-1]['close']
            if predicted_close > current_close * 1.005:  # Buy if predicted close is >0.5% higher
                trade_history.append({"type": "buy", "price": current_close, "time": time.time()})
                print(f"BUY signal at {current_close}")
            elif predicted_close < current_close * 0.995:  # Sell if predicted close is >0.5% lower
                trade_history.append({"type": "sell", "price": current_close, "time": time.time()})
                print(f"SELL signal at {current_close}")

            # Plot data (only after the first trade)
            if len(trade_history) > 0:
                plot_live_data(candles, list(trade_history))

            await asyncio.sleep(1)  # Adjust sleep time as needed

35
crypto/gogo/visualization/plotting.py Normal file

@@ -0,0 +1,35 @@
# visualization/plotting.py
import matplotlib.pyplot as plt
from IPython.display import clear_output  # Clears the output between plots (notebook use)


def plot_live_data(candles, trade_history):
    clear_output(wait=True)  # Clear the previous plot

    # Extract data for plotting
    times = [candle['timestamp'] for candle in candles]
    close_prices = [candle['close'] for candle in candles]

    # Create the figure and axes
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(times, close_prices, label="Close Price", color='blue')

    # Plot trade signals (trade times are in seconds; candles use milliseconds)
    buy_times = [trade['time'] * 1000 for trade in trade_history if trade['type'] == 'buy']
    buy_prices = [trade['price'] for trade in trade_history if trade['type'] == 'buy']
    sell_times = [trade['time'] * 1000 for trade in trade_history if trade['type'] == 'sell']
    sell_prices = [trade['price'] for trade in trade_history if trade['type'] == 'sell']
    ax.scatter(buy_times, buy_prices, color='green', marker='^', label='Buy')
    ax.scatter(sell_times, sell_prices, color='red', marker='v', label='Sell')

    # Format the plot
    ax.set_xlabel('Time')
    ax.set_ylabel('Price')
    ax.set_title('Live Trading Data with Signals')
    ax.legend()
    ax.grid(True)
    plt.tight_layout()
    plt.show()