LR training wip

Dobromir Popov
2025-12-08 21:52:26 +02:00
parent 1ab1c02889
commit 08ee2b6a3a
3 changed files with 61 additions and 31 deletions


@@ -144,19 +144,23 @@ class DeepMultiScaleAttention(nn.Module):
         batch_size, seq_len, _ = x.size()
         scale_outputs = []
 
+        # Clone input to avoid inplace modification issues
+        x_input = x.clone()
         for scale_proj in self.scale_projections:
             # Apply enhanced temporal convolution for this scale
-            x_conv = scale_proj['conv'](x.transpose(1, 2)).transpose(1, 2)
+            x_conv = scale_proj['conv'](x_input.transpose(1, 2)).transpose(1, 2)
 
             # Enhanced attention computation with deeper projections
-            Q = scale_proj['query'](x_conv).view(batch_size, seq_len, self.n_heads, self.head_dim)
-            K = scale_proj['key'](x_conv).view(batch_size, seq_len, self.n_heads, self.head_dim)
-            V = scale_proj['value'](x_conv).view(batch_size, seq_len, self.n_heads, self.head_dim)
+            # Use contiguous() before view() to ensure memory layout is correct
+            Q = scale_proj['query'](x_conv).contiguous().view(batch_size, seq_len, self.n_heads, self.head_dim)
+            K = scale_proj['key'](x_conv).contiguous().view(batch_size, seq_len, self.n_heads, self.head_dim)
+            V = scale_proj['value'](x_conv).contiguous().view(batch_size, seq_len, self.n_heads, self.head_dim)
 
             # Transpose for attention computation
-            Q = Q.transpose(1, 2)  # (batch, n_heads, seq_len, head_dim)
-            K = K.transpose(1, 2)
-            V = V.transpose(1, 2)
+            Q = Q.transpose(1, 2).contiguous()  # (batch, n_heads, seq_len, head_dim)
+            K = K.transpose(1, 2).contiguous()
+            V = V.transpose(1, 2).contiguous()
 
             # Scaled dot-product attention
             scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
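The two idioms in this hunk, cloning the shared input and forcing a contiguous layout before reshaping, can be reproduced outside the repo. A minimal sketch, assuming only standard PyTorch; the projection layer and the dimensions below are illustrative, not taken from the model:

import torch

batch, seq_len, n_heads, head_dim = 2, 16, 4, 8
d_model = n_heads * head_dim

proj = torch.nn.Linear(d_model, d_model)          # stand-in for scale_proj['query']
x = torch.randn(batch, seq_len, d_model, requires_grad=True)

# Work on a clone so per-scale reshapes never alias the caller's tensor.
x_in = x.clone()

# contiguous() before view(): a no-op when the tensor is already dense,
# a defensive copy when an upstream transpose left it strided.
q = proj(x_in).contiguous().view(batch, seq_len, n_heads, head_dim)

# transpose() returns a non-contiguous strided view; materialise it so the
# matmul below runs on packed memory (and a later view() would not fail).
q = q.transpose(1, 2).contiguous()                # (batch, n_heads, seq_len, head_dim)

scores = torch.matmul(q, q.transpose(-2, -1)) / head_dim ** 0.5
scores.sum().backward()                           # gradients still reach x through the clone
print(x.grad.shape)                               # torch.Size([2, 16, 32])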
@@ -293,24 +297,29 @@ class TradingTransformerLayer(nn.Module):
         self.regime_detector = MarketRegimeDetector(config.d_model)
 
     def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
+        # CRITICAL: Clone input to avoid version conflicts during backpropagation
+        # This prevents "modified by an inplace operation" errors when the same
+        # batch is used across multiple epochs
+        x_residual = x.clone()
         # Self-attention with residual connection
         # Store residual before any operations to avoid version conflicts
         if isinstance(self.attention, DeepMultiScaleAttention):
-            attn_output = self.attention(x, mask)
+            attn_output = self.attention(x_residual, mask)
         else:
-            attn_output, _ = self.attention(x, x, x, attn_mask=mask)
+            attn_output, _ = self.attention(x_residual, x_residual, x_residual, attn_mask=mask)
 
         # Create new tensor for residual to avoid inplace modification tracking
-        x_new = self.norm1(x + self.dropout(attn_output))
+        x_new = self.norm1(x_residual + self.dropout(attn_output))
 
         # Market regime adaptation
         regime_probs = None
         if hasattr(self, 'regime_detector'):
             x_new, regime_probs = self.regime_detector(x_new)
 
-        # Feed-forward with residual connection
-        ff_output = self.feed_forward(x_new)
-        x_out = self.norm2(x_new + self.dropout(ff_output))
+        # Feed-forward with residual connection - clone to avoid version conflicts
+        x_ff_residual = x_new.clone()
+        ff_output = self.feed_forward(x_ff_residual)
+        x_out = self.norm2(x_ff_residual + self.dropout(ff_output))
 
         return {
             'output': x_out,
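The error these clones guard against is autograd's saved-tensor version check: if a tensor that some operation saved for its backward pass is later modified in place, backward() raises the "modified by an inplace operation" RuntimeError. A minimal reproduction and the clone-based workaround, using toy shapes rather than anything from the repo:

import torch

w = torch.nn.Parameter(torch.randn(4, 4))
x = torch.randn(3, 4)

h = torch.sigmoid(x @ w)      # sigmoid saves its output for the backward pass
h.add_(1.0)                   # in-place edit bumps the saved tensor's version counter
try:
    h.sum().backward()
except RuntimeError as e:
    print(e)                  # "...has been modified by an inplace operation..."

h = torch.sigmoid(x @ w)      # rebuild the graph
h_res = h.clone()             # independent storage with its own version counter
h_res.add_(1.0)               # safe: the tensor sigmoid saved is untouched
(h + h_res).sum().backward()  # backward now succeeds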
@@ -557,19 +566,23 @@ class AdvancedTradingTransformer(nn.Module):
         timeframe_indices = []
 
         for idx, (tf_name, tf_data) in enumerate(available_tfs):
+            # CRITICAL: Clone input data to avoid modifying original tensors
+            # This prevents version conflicts when batches are reused across epochs
+            tf_data_processed = tf_data.clone()
+
             # Ensure correct sequence length
-            if tf_data.shape[1] != seq_len:
-                if tf_data.shape[1] < seq_len:
+            if tf_data_processed.shape[1] != seq_len:
+                if tf_data_processed.shape[1] < seq_len:
                     # Pad with last candle
-                    padding = tf_data[:, -1:, :].expand(batch_size, seq_len - tf_data.shape[1], 5)
-                    tf_data = torch.cat([tf_data, padding], dim=1)
+                    padding = tf_data_processed[:, -1:, :].expand(batch_size, seq_len - tf_data_processed.shape[1], 5)
+                    tf_data_processed = torch.cat([tf_data_processed, padding], dim=1)
                 else:
                     # Truncate to seq_len
-                    tf_data = tf_data[:, :seq_len, :]
+                    tf_data_processed = tf_data_processed[:, :seq_len, :]
 
             # Apply SHARED pattern encoder (learns patterns once for all timeframes)
             # Shape: [batch, seq_len, 5] -> [batch, seq_len, d_model]
-            tf_encoded = self.shared_pattern_encoder(tf_data)
+            tf_encoded = self.shared_pattern_encoder(tf_data_processed)
 
             # Add timeframe-specific embedding (helps model know which timeframe)
             # Get timeframe index
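The pad-or-truncate logic applied to the cloned per-timeframe tensor can be exercised in isolation. A sketch under the hunk's assumptions (OHLCV data of shape [batch, time, 5]); fit_seq_len is a hypothetical helper, not a name from the repo:

import torch

def fit_seq_len(tf_data: torch.Tensor, seq_len: int) -> torch.Tensor:
    """Pad (by repeating the last candle) or truncate [batch, time, 5] data to seq_len."""
    out = tf_data.clone()                      # never touch the caller's tensor
    batch, t, feat = out.shape
    if t < seq_len:
        last = out[:, -1:, :]                  # (batch, 1, 5)
        pad = last.expand(batch, seq_len - t, feat)
        out = torch.cat([out, pad], dim=1)
    elif t > seq_len:
        out = out[:, :seq_len, :]
    return out

x = torch.randn(4, 37, 5)
print(fit_seq_len(x, 60).shape)   # torch.Size([4, 60, 5])
print(fit_seq_len(x, 20).shape)   # torch.Size([4, 20, 5])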
@@ -1321,7 +1334,7 @@ class TradingTransformerTrainer:
         # Enable anomaly detection temporarily to debug inplace operation issues
         # NOTE: This significantly slows down training (2-3x slower), use only for debugging
         # Set to True to find exact inplace operation causing errors
-        enable_anomaly_detection = False  # DISABLED - inplace operations fixed
+        enable_anomaly_detection = False  # DISABLED - inplace operation issues fixed
 
         if enable_anomaly_detection:
             torch.autograd.set_detect_anomaly(True)
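When this flag is flipped on, anomaly mode makes autograd record a forward-pass stack trace for every node, so the RuntimeError raised in backward() is accompanied by the line that created the offending tensor, at roughly the 2-3x slowdown the comment mentions. A small, self-contained illustration (not repo code), using the scoped context-manager form of the same API:

import torch

with torch.autograd.detect_anomaly():          # scoped alternative to set_detect_anomaly(True)
    x = torch.randn(3, 4, requires_grad=True)
    h = torch.sigmoid(x)
    h[:, 0] = 0.0                              # in-place write into a tensor autograd saved
    try:
        h.sum().backward()                     # anomaly mode prints the traceback of the sigmoid call
    except RuntimeError as e:
        print(e)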
@@ -1374,10 +1387,16 @@ class TradingTransformerTrainer:
                 else:
                     batch_on_device[k] = v
         else:
-            # Batch is already on GPU, but still create a copy of the dict
-            # to avoid modifying the original batch dict
+            # CRITICAL FIX: Batch is already on GPU, but we must clone tensors
+            # to avoid version conflicts when the same batch is reused across epochs.
+            # Without cloning, operations like .contiguous() and .view() modify
+            # the tensor's version number, breaking backpropagation.
             for k, v in batch.items():
-                batch_on_device[k] = v
+                if isinstance(v, torch.Tensor):
+                    # Clone tensor to create independent copy with fresh version number
+                    batch_on_device[k] = v.clone()
+                else:
+                    batch_on_device[k] = v
 
         # Ensure all batch tensors are on the same device as the model
         # This is critical to avoid device mismatch errors
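The same defensive pattern can be factored into a small helper when a pre-built batch dict is replayed across epochs: every tensor gets an independent copy, with its own version counter, before a training step touches it. A sketch under that assumption; fresh_copy and the batch keys shown are illustrative, not names from the trainer:

import torch

def fresh_copy(batch: dict, device: torch.device) -> dict:
    """Return a new dict whose tensors are independent copies placed on `device`."""
    out = {}
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            # .to() returns the same object when the tensor is already on `device`;
            # the trailing clone() forces an independent copy either way.
            out[k] = v.to(device, non_blocking=True).clone()
        else:
            out[k] = v                         # non-tensor metadata is passed through
    return out

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cached_batch = {
    'price_data': torch.randn(8, 60, 5),
    'actions': torch.randint(0, 3, (8,)),
    'symbol': 'ETH/USDT',
}

for epoch in range(3):
    step_batch = fresh_copy(cached_batch, device)   # each epoch trains on its own copies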