listen to all IPs
@@ -599,10 +599,11 @@ class AdvancedTradingTransformer(nn.Module):
         batched_tfs = stacked_tfs.reshape(batch_size * num_tfs, seq_len, self.config.d_model)

         # Apply single cross-timeframe attention layer
-        batched_tfs = self.cross_timeframe_layer(batched_tfs)
+        # Use new variable to avoid inplace modification issues
+        cross_tf_encoded = self.cross_timeframe_layer(batched_tfs)

         # Reshape back: [batch*num_tfs, seq_len, d_model] -> [batch, num_tfs, seq_len, d_model]
-        cross_tf_output = batched_tfs.reshape(batch_size, num_tfs, seq_len, self.config.d_model)
+        cross_tf_output = cross_tf_encoded.reshape(batch_size, num_tfs, seq_len, self.config.d_model)

         # Average across timeframes to get unified representation
         # [batch, num_tfs, seq_len, d_model] -> [batch, seq_len, d_model]
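For context, the hunk above only gives the cross-timeframe attention output its own name, so the later reshape and averaging read from cross_tf_encoded instead of a rebound batched_tfs. Below is a minimal, self-contained sketch of that pattern; the tensor shapes and the nn.TransformerEncoderLayer stand-in for self.cross_timeframe_layer are assumptions for illustration, not the model's actual configuration.

import torch
import torch.nn as nn

# Stand-ins for the pieces the diff assumes; the real layer and config live in
# AdvancedTradingTransformer and are not shown here.
batch_size, num_tfs, seq_len, d_model = 2, 3, 16, 64
cross_timeframe_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, batch_first=True)

stacked_tfs = torch.randn(batch_size, num_tfs, seq_len, d_model, requires_grad=True)

# Fold timeframes into the batch: [batch, num_tfs, seq, d] -> [batch*num_tfs, seq, d]
batched_tfs = stacked_tfs.reshape(batch_size * num_tfs, seq_len, d_model)

# Write the attention output to a new name instead of rebinding batched_tfs,
# mirroring the diff's change.
cross_tf_encoded = cross_timeframe_layer(batched_tfs)

# Unfold and average across timeframes: -> [batch, seq_len, d_model]
cross_tf_output = cross_tf_encoded.reshape(batch_size, num_tfs, seq_len, d_model)
unified = cross_tf_output.mean(dim=1)

unified.sum().backward()   # gradients flow back to stacked_tfs
print(unified.shape)       # torch.Size([2, 16, 64])

Keeping the attention output under its own name makes it harder to accidentally overwrite a tensor that autograd still needs, which is the failure mode the diff's comment warns about.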
@@ -1346,6 +1347,10 @@ class TradingTransformerTrainer:
         for param in self.model.parameters():
             if param.grad is not None:
                 param.grad = None

+        # Clear CUDA cache to prevent tensor version conflicts
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
         # OPTIMIZATION: Only move batch to device if not already there
         # Check if first tensor is already on correct device
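The trainer hunk bundles three housekeeping steps: gradients are released by setting param.grad to None, the CUDA caching allocator is emptied between steps, and (per the surrounding comments) batch tensors are copied to the device only when they are not already there. A rough sketch of how these steps can fit together is below, assuming a dict-style batch; _prepare_step is a hypothetical helper for illustration, not a method of TradingTransformerTrainer.

import torch

def _prepare_step(model: torch.nn.Module, batch: dict, device: torch.device) -> dict:
    """Hypothetical helper showing the housekeeping steps from this hunk."""
    # Drop gradients instead of zeroing them; a per-parameter version of
    # optimizer.zero_grad(set_to_none=True).
    for param in model.parameters():
        if param.grad is not None:
            param.grad = None

    # Return unused cached blocks to the GPU (the diff's comment ties this to
    # tensor version conflicts; its direct effect is freeing cached memory).
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Copy tensors to the target device only when they are not already there.
    return {
        key: value.to(device) if torch.is_tensor(value) and value.device != device else value
        for key, value in batch.items()
    }

# Toy usage on CPU.
model = torch.nn.Linear(4, 2)
batch = {"features": torch.randn(8, 4), "labels": torch.randint(0, 2, (8,))}
batch = _prepare_step(model, batch, torch.device("cpu"))

Setting gradients to None rather than zeroing them skips a memset and lets the allocator reuse the gradient buffers, which is why the per-parameter loop is preferred here.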