# model/transformer.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # [1, max_len, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: [batch_size, seq_len, d_model]
        x = x + self.pe[:, :x.size(1)]
        return x

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # mask must broadcast against [batch_size, num_heads, q_len, k_len];
            # positions where mask == 0 are excluded from attention
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_probs = F.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_probs, V)
        return output, attn_probs

    def split_heads(self, x):
        batch_size = x.size(0)
        x = x.view(batch_size, -1, self.num_heads, self.d_k)
        return x.transpose(1, 2)  # [batch_size, num_heads, seq_len, d_k]

    def combine_heads(self, x):
        batch_size = x.size(0)
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return x

    def forward(self, Q, K, V, mask=None):
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        attn_output, attn_probs = self.scaled_dot_product_attention(Q, K, V, mask)
        output = self.W_o(self.combine_heads(attn_output))
        return output, attn_probs

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(FeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))

class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        attn_output, _ = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x

class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask, tgt_mask):
        attn_output, _ = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        attn_output, _ = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x

class Encoder(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

class Decoder(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, enc_output, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, enc_output, src_mask, tgt_mask)
        return self.norm(x)

class Transformer(nn.Module):
    def __init__(self, input_dim, d_model, num_heads, num_layers, d_ff, dropout=0.1):
        super(Transformer, self).__init__()
        self.candle_embedding = nn.Linear(input_dim, d_model)
        self.tick_embedding = nn.Linear(2, d_model)  # Each tick has price and quantity
        self.positional_encoding = PositionalEncoding(d_model)
        self.encoder = Encoder(num_layers, d_model, num_heads, d_ff, dropout)
        # Decoder for future candle
        self.future_candle_decoder = Decoder(num_layers, d_model, num_heads, d_ff, dropout)
        self.future_candle_projection = nn.Linear(d_model, 5)  # Output 5 values: O, H, L, C, V
        # Decoder for future volume
        self.future_volume_decoder = Decoder(num_layers, d_model, num_heads, d_ff, dropout)
        self.future_volume_projection = nn.Linear(d_model, 1)
        # Decoder for future ticks
        self.future_ticks_decoder = Decoder(num_layers, d_model, num_heads, d_ff, dropout)
        self.future_ticks_projection = nn.Linear(d_model, 60)  # 30 ticks * (price, quantity) = 60

    def forward(self, candle_data, tick_data, future_candle_mask, future_ticks_mask):
        # candle_data: [batch_size, candle_seq_len, input_dim]
        # tick_data: [batch_size, tick_seq_len, 2]
        # future_candle_mask: broadcastable to [batch_size, 1, candle_seq_len, candle_seq_len]
        # future_ticks_mask: broadcastable to [batch_size, num_heads, tick_seq_len, tick_seq_len]
        candle_embedded = self.candle_embedding(candle_data)
        candle_embedded = self.positional_encoding(candle_embedded)  # Add positional info
        tick_embedded = self.tick_embedding(tick_data)
        tick_embedded = self.positional_encoding(tick_embedded)
        # Concatenate candle and tick embeddings along the sequence length dimension
        combined_input = torch.cat((candle_embedded, tick_embedded), dim=1)
        # Build a square self-attention mask over the combined (candle + tick) sequence.
        # Concatenating the candle mask with an all-ones tick mask along the last
        # dimension is not shape-compatible, so the candle-to-candle block uses
        # future_candle_mask and every other position attends freely (there is no
        # padding in this setup).
        batch_size = candle_data.size(0)
        candle_len = candle_embedded.size(1)
        combined_len = combined_input.size(1)
        combined_mask = torch.ones(batch_size, 1, combined_len, combined_len, device=candle_data.device)
        combined_mask[:, :, :candle_len, :candle_len] = future_candle_mask.to(candle_data.device)
        enc_output = self.encoder(combined_input, combined_mask)
        # Cross-attention over the encoder memory: there is no padding, so every
        # decoder position may attend to every encoder position.
        memory_mask = torch.ones(batch_size, 1, 1, combined_len, device=candle_data.device)
        # --- Future Candle Prediction ---
        future_candle_input = torch.zeros_like(candle_embedded[:, -1:, :])  # Start with zeros
        future_candle_output = self.future_candle_decoder(future_candle_input, enc_output, memory_mask, None)  # No target mask for prediction
        future_candle_pred = self.future_candle_projection(future_candle_output)
        # --- Future Volume Prediction ---
        future_volume_input = torch.zeros_like(candle_embedded[:, -1:, :])  # Start with zeros
        future_volume_output = self.future_volume_decoder(future_volume_input, enc_output, memory_mask, None)
        future_volume_pred = self.future_volume_projection(future_volume_output)
        # --- Future Ticks Prediction ---
        future_ticks_input = torch.zeros_like(tick_embedded)
        future_ticks_output = self.future_ticks_decoder(future_ticks_input, enc_output, memory_mask, future_ticks_mask)
        future_ticks_pred = self.future_ticks_projection(future_ticks_output)
        return future_candle_pred, future_volume_pred, future_ticks_pred
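
# The example below calls create_mask(), which is not defined in this file.
# A minimal sketch of such a helper, assuming an all-ones mask by default and a
# causal (lower-triangular) mask when future_mask=True, shaped so it broadcasts
# against attention scores of shape [batch_size, num_heads, q_len, k_len].
def create_mask(seq_len, future_mask=False):
    if future_mask:
        mask = torch.tril(torch.ones(seq_len, seq_len))  # block attention to future positions
    else:
        mask = torch.ones(seq_len, seq_len)  # allow attention everywhere
    return mask.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, seq_len]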

# Example instantiation (adjust parameters for ~1B parameters)
if __name__ == '__main__':
    input_dim = 6 + len([5, 10, 20, 60, 120, 200])  # OHLCV + EMAs
    d_model = 512   # Hidden dimension
    num_heads = 8
    num_layers = 6  # Number of encoder/decoder layers
    d_ff = 2048     # Feedforward dimension
    dropout = 0.1

    # Calculate approximate parameter count
    def count_parameters(model):
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    model = Transformer(input_dim, d_model, num_heads, num_layers, d_ff, dropout)
    num_params = count_parameters(model)
    print(f"Number of parameters: {num_params:,}")  # Formatted with commas

    # --- Dummy Input Data for Testing ---
    batch_size = 2
    candle_seq_len = 119   # We have 119 past candles and predict the 120th
    tick_seq_len = 30 * 2  # 30 ticks, each with price and quantity
    candle_data = torch.randn(batch_size, candle_seq_len, input_dim)
    tick_data = torch.randn(batch_size, tick_seq_len // 2, 2)  # tick sequence
    future_candle_mask = create_mask(candle_seq_len)
    future_ticks_mask = create_mask(tick_seq_len // 2, future_mask=True)

    # --- Forward Pass ---
    future_candle_pred, future_volume_pred, future_ticks_pred = model(
        candle_data, tick_data, future_candle_mask, future_ticks_mask
    )
    print("Future Candle Prediction Shape:", future_candle_pred.shape)  # Expected: [batch_size, 1, 5]
    print("Future Volume Prediction Shape:", future_volume_pred.shape)  # Expected: [batch_size, 1, 1]
    print("Future Ticks Prediction Shape:", future_ticks_pred.shape)    # Expected: [batch_size, 30, 60] (60 outputs per tick position)