LR training wip
@@ -2530,13 +2530,20 @@ class RealTrainingAdapter:
         OPTIMIZATION: Batches are already on GPU and grouped for efficient processing.
         Each mini-batch contains 5 samples for better GPU utilization.

-        IMPORTANT: Creates a shallow copy of batch dict to prevent in-place modifications
-        from affecting subsequent epochs. Tensors themselves are shared (not copied).
+        CRITICAL FIX: Clone tensors for each epoch to avoid autograd version conflicts.
+        When the same tensor is used across multiple forward passes, operations like
+        .contiguous() and .view() modify the tensor's version number, breaking backprop.
         """
         for batch in grouped_batches:
-            # Create shallow copy of batch dict to prevent modifications
-            # Tensors are shared (not cloned) for memory efficiency
-            batch_copy = {k: v for k, v in batch.items()}
+            # CRITICAL: Clone all tensors to avoid version conflicts across epochs
+            # This prevents "modified by an inplace operation" errors during backward pass
+            batch_copy = {}
+            for k, v in batch.items():
+                if isinstance(v, torch.Tensor):
+                    # Clone tensor to create independent copy with fresh version number
+                    batch_copy[k] = v.clone()
+                else:
+                    batch_copy[k] = v
             yield batch_copy

         total_batches = len(grouped_batches)
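For readers unfamiliar with the failure this docstring describes: every PyTorch tensor carries an internal version counter (`_version`), and autograd refuses to backpropagate through a tensor that was saved for backward and then mutated. A minimal standalone sketch of that error, not taken from this repository; the tensors are illustrative.

import torch

x = torch.randn(4, 3, requires_grad=True)
h = x * 2.0
y = h ** 2               # pow saves h for its backward pass
print(h._version)        # 0 -- internal counter, bumped only by in-place ops

h.add_(1.0)              # in-place update invalidates the saved tensor
print(h._version)        # 1

try:
    y.sum().backward()
except RuntimeError as err:
    # "... one of the variables needed for gradient computation has been
    #  modified by an inplace operation ..."
    print(err)

# Cloning per epoch, as in the hunk above, hands each pass an independent
# tensor, so mutations in one epoch can never invalidate the next one.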
@@ -144,19 +144,23 @@ class DeepMultiScaleAttention(nn.Module):
         batch_size, seq_len, _ = x.size()
         scale_outputs = []

+        # Clone input to avoid inplace modification issues
+        x_input = x.clone()
+
         for scale_proj in self.scale_projections:
             # Apply enhanced temporal convolution for this scale
-            x_conv = scale_proj['conv'](x.transpose(1, 2)).transpose(1, 2)
+            x_conv = scale_proj['conv'](x_input.transpose(1, 2)).transpose(1, 2)

             # Enhanced attention computation with deeper projections
-            Q = scale_proj['query'](x_conv).view(batch_size, seq_len, self.n_heads, self.head_dim)
-            K = scale_proj['key'](x_conv).view(batch_size, seq_len, self.n_heads, self.head_dim)
-            V = scale_proj['value'](x_conv).view(batch_size, seq_len, self.n_heads, self.head_dim)
+            # Use contiguous() before view() to ensure memory layout is correct
+            Q = scale_proj['query'](x_conv).contiguous().view(batch_size, seq_len, self.n_heads, self.head_dim)
+            K = scale_proj['key'](x_conv).contiguous().view(batch_size, seq_len, self.n_heads, self.head_dim)
+            V = scale_proj['value'](x_conv).contiguous().view(batch_size, seq_len, self.n_heads, self.head_dim)

             # Transpose for attention computation
-            Q = Q.transpose(1, 2) # (batch, n_heads, seq_len, head_dim)
-            K = K.transpose(1, 2)
-            V = V.transpose(1, 2)
+            Q = Q.transpose(1, 2).contiguous() # (batch, n_heads, seq_len, head_dim)
+            K = K.transpose(1, 2).contiguous()
+            V = V.transpose(1, 2).contiguous()

             # Scaled dot-product attention
             scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
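On the .contiguous()-before-.view() change: .view() only reinterprets the tensor's existing memory, so it raises a RuntimeError on tensors whose strides no longer describe a flat layout (typically after a transpose), while .contiguous() first materialises a compact copy. A standalone illustration with made-up shapes, not code from this module:

import torch

t = torch.randn(2, 8, 64)           # (batch, seq_len, d_model)
tt = t.transpose(1, 2)              # (batch, d_model, seq_len) -- same storage, permuted strides
print(tt.is_contiguous())           # False

try:
    tt.view(2, -1)                  # view cannot merge dims across these strides
except RuntimeError as err:
    print(err)

flat = tt.contiguous().view(2, -1)  # copy into row-major memory, then reinterpret
same = tt.reshape(2, -1)            # reshape() falls back to a copy automatically
print(flat.shape, same.shape)       # torch.Size([2, 512]) torch.Size([2, 512])

If the query/key/value projections end in standard dense layers, their outputs are already contiguous, so the first set of .contiguous() calls acts as a cheap safeguard; the calls after transpose(1, 2) are the ones that actually rearrange memory.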
@@ -293,24 +297,29 @@ class TradingTransformerLayer(nn.Module):
         self.regime_detector = MarketRegimeDetector(config.d_model)

     def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
+        # CRITICAL: Clone input to avoid version conflicts during backpropagation
+        # This prevents "modified by an inplace operation" errors when the same
+        # batch is used across multiple epochs
+        x_residual = x.clone()
+
         # Self-attention with residual connection
+        # Store residual before any operations to avoid version conflicts
         if isinstance(self.attention, DeepMultiScaleAttention):
-            attn_output = self.attention(x, mask)
+            attn_output = self.attention(x_residual, mask)
         else:
-            attn_output, _ = self.attention(x, x, x, attn_mask=mask)
+            attn_output, _ = self.attention(x_residual, x_residual, x_residual, attn_mask=mask)

-        # Create new tensor for residual to avoid inplace modification tracking
-        x_new = self.norm1(x + self.dropout(attn_output))
+        x_new = self.norm1(x_residual + self.dropout(attn_output))

         # Market regime adaptation
         regime_probs = None
         if hasattr(self, 'regime_detector'):
             x_new, regime_probs = self.regime_detector(x_new)

-        # Feed-forward with residual connection
-        ff_output = self.feed_forward(x_new)
-        x_out = self.norm2(x_new + self.dropout(ff_output))
+        # Feed-forward with residual connection - clone to avoid version conflicts
+        x_ff_residual = x_new.clone()
+        ff_output = self.feed_forward(x_ff_residual)
+        x_out = self.norm2(x_ff_residual + self.dropout(ff_output))

         return {
             'output': x_out,
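The clones in this hunk guard the residual path against in-place updates on tensors autograd still needs; the distinction that matters is out-of-place (x = x + y, allocates a new tensor) versus in-place (x += y, mutates a tensor another op may have saved). A standalone sketch of that failure with toy modules, not this repo's layers:

import torch
import torch.nn as nn

ff = nn.Linear(16, 16)
norm = nn.LayerNorm(16)

x = torch.randn(4, 16, requires_grad=True)
h = x * 3.0            # non-leaf activation
y = ff(h)              # the linear layer saves h (needed for its weight gradient)

h += 1.0               # in-place "residual" update mutates the saved tensor
try:
    y.sum().backward()
except RuntimeError as err:
    print(err)         # inplace-modification error at backward time

# Out-of-place residuals, as in the layer above, leave the saved tensor untouched:
h2 = x * 3.0
out = norm(h2 + ff(h2))
out.sum().backward()   # fine: h2 is only read, never mutated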
@@ -557,19 +566,23 @@ class AdvancedTradingTransformer(nn.Module):
         timeframe_indices = []

         for idx, (tf_name, tf_data) in enumerate(available_tfs):
+            # CRITICAL: Clone input data to avoid modifying original tensors
+            # This prevents version conflicts when batches are reused across epochs
+            tf_data_processed = tf_data.clone()
+
             # Ensure correct sequence length
-            if tf_data.shape[1] != seq_len:
-                if tf_data.shape[1] < seq_len:
+            if tf_data_processed.shape[1] != seq_len:
+                if tf_data_processed.shape[1] < seq_len:
                     # Pad with last candle
-                    padding = tf_data[:, -1:, :].expand(batch_size, seq_len - tf_data.shape[1], 5)
-                    tf_data = torch.cat([tf_data, padding], dim=1)
+                    padding = tf_data_processed[:, -1:, :].expand(batch_size, seq_len - tf_data_processed.shape[1], 5)
+                    tf_data_processed = torch.cat([tf_data_processed, padding], dim=1)
                 else:
                     # Truncate to seq_len
-                    tf_data = tf_data[:, :seq_len, :]
+                    tf_data_processed = tf_data_processed[:, :seq_len, :]

             # Apply SHARED pattern encoder (learns patterns once for all timeframes)
             # Shape: [batch, seq_len, 5] -> [batch, seq_len, d_model]
-            tf_encoded = self.shared_pattern_encoder(tf_data)
+            tf_encoded = self.shared_pattern_encoder(tf_data_processed)

             # Add timeframe-specific embedding (helps model know which timeframe)
             # Get timeframe index
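A quick shape check of the expand-then-cat padding used above, with illustrative sizes (5 standing in for the candle features); not code from this model:

import torch

batch_size, seq_len = 2, 6
tf_data = torch.randn(batch_size, 4, 5)    # only 4 candles available for this timeframe

padding = tf_data[:, -1:, :].expand(batch_size, seq_len - tf_data.shape[1], 5)
print(padding.shape)                        # torch.Size([2, 2, 5]) -- a stride-0 view, no copy yet
padded = torch.cat([tf_data, padding], dim=1)
print(padded.shape)                         # torch.Size([2, 6, 5])
print(torch.equal(padded[:, -1], padded[:, -2]))   # True: the last candle is repeated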
@@ -1321,7 +1334,7 @@ class TradingTransformerTrainer:
         # Enable anomaly detection temporarily to debug inplace operation issues
         # NOTE: This significantly slows down training (2-3x slower), use only for debugging
         # Set to True to find exact inplace operation causing errors
-        enable_anomaly_detection = False # DISABLED - inplace operations fixed
+        enable_anomaly_detection = False # DISABLED - inplace operation issues fixed
         if enable_anomaly_detection:
             torch.autograd.set_detect_anomaly(True)
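If the flag above is ever flipped back to True, anomaly detection makes a failing backward() report the forward-pass traceback of the operation that produced the offending tensor, at the 2-3x cost the comment notes. A scoped usage sketch with placeholder model/batch/loss names, not this trainer's code:

import torch

def debug_train_step(model, batch, compute_loss):
    # Context-manager form limits the overhead to a single step instead of
    # enabling torch.autograd.set_detect_anomaly(True) for the whole run.
    with torch.autograd.detect_anomaly():
        outputs = model(batch)
        loss = compute_loss(outputs)
        loss.backward()     # on failure, the error now names the offending forward op
    return loss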
@@ -1374,10 +1387,16 @@ class TradingTransformerTrainer:
                 else:
                     batch_on_device[k] = v
         else:
-            # Batch is already on GPU, but still create a copy of the dict
-            # to avoid modifying the original batch dict
+            # CRITICAL FIX: Batch is already on GPU, but we must clone tensors
+            # to avoid version conflicts when the same batch is reused across epochs.
+            # Without cloning, operations like .contiguous() and .view() modify
+            # the tensor's version number, breaking backpropagation.
             for k, v in batch.items():
-                batch_on_device[k] = v
+                if isinstance(v, torch.Tensor):
+                    # Clone tensor to create independent copy with fresh version number
+                    batch_on_device[k] = v.clone()
+                else:
+                    batch_on_device[k] = v

         # Ensure all batch tensors are on the same device as the model
         # This is critical to avoid device mismatch errors
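Read together, the two branches of this block enforce one rule: every tensor entering the training step should be a fresh, device-local copy. A condensed sketch of that rule with hypothetical names (batch_to_device is not a method of this trainer):

import torch

def batch_to_device(batch: dict, device: torch.device) -> dict:
    """Return a new dict whose tensors are independent copies living on `device`."""
    out = {}
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            # .to() is a no-op when the tensor already lives on `device`, so the
            # trailing clone() guarantees an independent copy (fresh version counter)
            # in both the transfer case and the already-on-GPU case.
            out[k] = v.to(device, non_blocking=True).clone()
        else:
            out[k] = v
    return out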
@@ -613,7 +613,11 @@ class TradingOrchestrator:
         # CRITICAL: Initialize checkpoint manager for saving training progress
         self.checkpoint_manager = None
         self.training_iterations = 0 # Track training iterations for periodic saves
-        self._initialize_checkpoint_manager()
+        try:
+            self._initialize_checkpoint_manager()
+        except Exception as e:
+            logger.error(f"Failed to initialize checkpoint manager in __init__: {e}")
+            self.checkpoint_manager = None

         # Initialize models, COB integration, and training system
         self._initialize_ml_models()
@@ -828,7 +832,7 @@ class TradingOrchestrator:
         # Try to load best checkpoint
         checkpoint_loaded = False
         try:
-            if self.checkpoint_manager:
+            if hasattr(self, 'checkpoint_manager') and self.checkpoint_manager:
                 checkpoint_path, checkpoint_metadata = self.checkpoint_manager.load_best_checkpoint("transformer")
                 if checkpoint_path and checkpoint_metadata:
                     # Load the checkpoint