listen to all IPs

Author: Dobromir Popov
Date: 2025-12-08 21:36:07 +02:00
parent 81a7f27d2d
commit 1ab1c02889
3 changed files with 82 additions and 8 deletions


@@ -599,10 +599,11 @@ class AdvancedTradingTransformer(nn.Module):
         batched_tfs = stacked_tfs.reshape(batch_size * num_tfs, seq_len, self.config.d_model)
         # Apply single cross-timeframe attention layer
-        batched_tfs = self.cross_timeframe_layer(batched_tfs)
+        # Use new variable to avoid inplace modification issues
+        cross_tf_encoded = self.cross_timeframe_layer(batched_tfs)
         # Reshape back: [batch*num_tfs, seq_len, d_model] -> [batch, num_tfs, seq_len, d_model]
-        cross_tf_output = batched_tfs.reshape(batch_size, num_tfs, seq_len, self.config.d_model)
+        cross_tf_output = cross_tf_encoded.reshape(batch_size, num_tfs, seq_len, self.config.d_model)
         # Average across timeframes to get unified representation
         # [batch, num_tfs, seq_len, d_model] -> [batch, seq_len, d_model]
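
For context, the point of writing the attention output to a fresh name is that the pre-attention tensor stays reachable under its own name, so nothing later in the graph can overwrite a tensor that autograd still needs for backward. Below is a minimal, self-contained sketch of the same fold/attend/unfold pattern; the dimensions and the stand-in nn.TransformerEncoderLayer are assumptions for illustration, not the project's actual cross_timeframe_layer:

```python
import torch
import torch.nn as nn

# Stand-in for cross_timeframe_layer (assumption, not the project's layer)
batch_size, num_tfs, seq_len, d_model = 2, 4, 8, 16
cross_timeframe_layer = nn.TransformerEncoderLayer(
    d_model=d_model, nhead=4, batch_first=True
)

stacked_tfs = torch.randn(batch_size, num_tfs, seq_len, d_model, requires_grad=True)

# Fold the timeframe axis into the batch axis so one attention call covers
# every timeframe: [batch, num_tfs, seq, d] -> [batch*num_tfs, seq, d]
batched_tfs = stacked_tfs.reshape(batch_size * num_tfs, seq_len, d_model)

# Writing the result to a new variable (instead of rebinding batched_tfs)
# keeps the pre-attention tensor alive; autograd's version counter would
# raise if an in-place op later mutated a tensor needed by backward.
cross_tf_encoded = cross_timeframe_layer(batched_tfs)

# Unfold and average across timeframes: [batch, num_tfs, seq, d] -> [batch, seq, d]
cross_tf_output = cross_tf_encoded.reshape(batch_size, num_tfs, seq_len, d_model)
unified = cross_tf_output.mean(dim=1)

unified.sum().backward()  # completes without "modified by an inplace operation" errors
```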
@@ -1346,6 +1347,10 @@ class TradingTransformerTrainer:
         for param in self.model.parameters():
             if param.grad is not None:
                 param.grad = None
+        # Clear CUDA cache to prevent tensor version conflicts
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
         # OPTIMIZATION: Only move batch to device if not already there
         # Check if first tensor is already on correct device
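
The grad-reset idiom in this hunk matches what optimizer.zero_grad(set_to_none=True) does: dropping the gradient tensors outright rather than zeroing them in place. A minimal sketch of the same step, assuming a generic model and optimizer in place of TradingTransformerTrainer's actual setup:

```python
import torch
import torch.nn as nn

# Generic stand-ins (assumptions) for the trainer's model and optimizer
model = nn.Linear(8, 1)
optimizer = torch.optim.Adam(model.parameters())

# One training step so gradients exist before the reset.
loss = model(torch.randn(4, 8)).sum()
loss.backward()
optimizer.step()

# Setting .grad to None frees the old gradient tensors instead of
# mutating them in place; equivalent to zero_grad(set_to_none=True).
for param in model.parameters():
    if param.grad is not None:
        param.grad = None

# empty_cache() returns cached-but-unused allocator blocks to the driver.
# It does not touch live tensors or autograd state, so its concrete effect
# here is reclaiming the memory of the gradients just dropped.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
```

One design note: calling torch.cuda.empty_cache() every step trades allocator reuse for lower peak memory between steps, so it is usually reserved for memory-constrained training loops.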