LR training wip

Dobromir Popov committed 2025-12-08 21:52:26 +02:00
parent 1ab1c02889
commit 08ee2b6a3a

3 changed files with 61 additions and 31 deletions


@@ -2530,13 +2530,20 @@ class RealTrainingAdapter:
             OPTIMIZATION: Batches are already on GPU and grouped for efficient processing.
             Each mini-batch contains 5 samples for better GPU utilization.

-            IMPORTANT: Creates a shallow copy of batch dict to prevent in-place modifications
-            from affecting subsequent epochs. Tensors themselves are shared (not copied).
+            CRITICAL FIX: Clone tensors for each epoch to avoid autograd version conflicts.
+            When the same tensor is used across multiple forward passes, operations like
+            .contiguous() and .view() modify the tensor's version number, breaking backprop.
             """
             for batch in grouped_batches:
-                # Create shallow copy of batch dict to prevent modifications
-                # Tensors are shared (not cloned) for memory efficiency
-                batch_copy = {k: v for k, v in batch.items()}
+                # CRITICAL: Clone all tensors to avoid version conflicts across epochs
+                # This prevents "modified by an inplace operation" errors during backward pass
+                batch_copy = {}
+                for k, v in batch.items():
+                    if isinstance(v, torch.Tensor):
+                        # Clone tensor to create independent copy with fresh version number
+                        batch_copy[k] = v.clone()
+                    else:
+                        batch_copy[k] = v
                 yield batch_copy

         total_batches = len(grouped_batches)
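
For context, a minimal standalone repro (not from this repo) of the error class this hunk guards against: an in-place update bumps a tensor's version counter after autograd has already saved that tensor for backward, and cloning gives each consumer independent storage with its own counter.

import torch

w = torch.randn(3, requires_grad=True)
x = torch.randn(3)

y = x * w          # autograd saves x to compute the gradient w.r.t. w
x.add_(1.0)        # in-place update bumps x's version counter
try:
    y.sum().backward()
except RuntimeError as e:
    print(e)       # "... has been modified by an inplace operation"

w2 = torch.randn(3, requires_grad=True)
x2 = torch.randn(3)
y2 = x2.clone() * w2   # the clone, not x2, is what autograd saves
x2.add_(1.0)           # mutating the original is now harmless
y2.sum().backward()    # succeeds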


@@ -144,19 +144,23 @@ class DeepMultiScaleAttention(nn.Module):
         batch_size, seq_len, _ = x.size()
         scale_outputs = []

+        # Clone input to avoid inplace modification issues
+        x_input = x.clone()
+
         for scale_proj in self.scale_projections:
             # Apply enhanced temporal convolution for this scale
-            x_conv = scale_proj['conv'](x.transpose(1, 2)).transpose(1, 2)
+            x_conv = scale_proj['conv'](x_input.transpose(1, 2)).transpose(1, 2)

             # Enhanced attention computation with deeper projections
-            Q = scale_proj['query'](x_conv).view(batch_size, seq_len, self.n_heads, self.head_dim)
-            K = scale_proj['key'](x_conv).view(batch_size, seq_len, self.n_heads, self.head_dim)
-            V = scale_proj['value'](x_conv).view(batch_size, seq_len, self.n_heads, self.head_dim)
+            # Use contiguous() before view() to ensure memory layout is correct
+            Q = scale_proj['query'](x_conv).contiguous().view(batch_size, seq_len, self.n_heads, self.head_dim)
+            K = scale_proj['key'](x_conv).contiguous().view(batch_size, seq_len, self.n_heads, self.head_dim)
+            V = scale_proj['value'](x_conv).contiguous().view(batch_size, seq_len, self.n_heads, self.head_dim)

             # Transpose for attention computation
-            Q = Q.transpose(1, 2)  # (batch, n_heads, seq_len, head_dim)
-            K = K.transpose(1, 2)
-            V = V.transpose(1, 2)
+            Q = Q.transpose(1, 2).contiguous()  # (batch, n_heads, seq_len, head_dim)
+            K = K.transpose(1, 2).contiguous()
+            V = V.transpose(1, 2).contiguous()

             # Scaled dot-product attention
             scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
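
A standalone illustration of the contiguous()/view() pattern used above, independent of the model code: view() requires contiguous memory, which a transposed tensor does not have, so either call contiguous() first or use reshape().

import torch

x = torch.randn(2, 8, 4)           # (batch, seq, dim)
t = x.transpose(1, 2)              # same storage, non-contiguous strides
try:
    t.view(2, 32)                  # view() needs contiguous memory
except RuntimeError as e:
    print(e)

flat = t.contiguous().view(2, 32)  # copy into contiguous memory, then view
flat2 = t.reshape(2, 32)           # reshape() copies only when it has to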
@@ -293,24 +297,29 @@ class TradingTransformerLayer(nn.Module):
         self.regime_detector = MarketRegimeDetector(config.d_model)

     def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
+        # CRITICAL: Clone input to avoid version conflicts during backpropagation
+        # This prevents "modified by an inplace operation" errors when the same
+        # batch is used across multiple epochs
+        x_residual = x.clone()
+
         # Self-attention with residual connection
+        # Store residual before any operations to avoid version conflicts
         if isinstance(self.attention, DeepMultiScaleAttention):
-            attn_output = self.attention(x, mask)
+            attn_output = self.attention(x_residual, mask)
         else:
-            attn_output, _ = self.attention(x, x, x, attn_mask=mask)
+            attn_output, _ = self.attention(x_residual, x_residual, x_residual, attn_mask=mask)

         # Create new tensor for residual to avoid inplace modification tracking
-        x_new = self.norm1(x + self.dropout(attn_output))
+        x_new = self.norm1(x_residual + self.dropout(attn_output))

         # Market regime adaptation
         regime_probs = None
         if hasattr(self, 'regime_detector'):
             x_new, regime_probs = self.regime_detector(x_new)

-        # Feed-forward with residual connection
-        ff_output = self.feed_forward(x_new)
-        x_out = self.norm2(x_new + self.dropout(ff_output))
+        # Feed-forward with residual connection - clone to avoid version conflicts
+        x_ff_residual = x_new.clone()
+        ff_output = self.feed_forward(x_ff_residual)
+        x_out = self.norm2(x_ff_residual + self.dropout(ff_output))

         return {
             'output': x_out,
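
A small sketch of the residual pattern this layer relies on, using a toy d_model of 16 (the modules here are illustrative, not the repo's): building the residual with an out-of-place `x + ...` creates a new tensor, whereas an in-place `x += ...` would rewrite an activation that an earlier layer may still need for its backward pass.

import torch
import torch.nn as nn

proj = nn.Linear(16, 16)
norm = nn.LayerNorm(16)
drop = nn.Dropout(0.1)

x = torch.randn(4, 10, 16)
h = proj(x)                 # the Linear saves x to compute its weight gradient

out = norm(x + drop(h))     # out-of-place residual: x itself is untouched
out.sum().backward()        # backward sees the original, unmodified x

# By contrast, `x += drop(h)` would rewrite x in place; since proj's backward
# still needs the original x, autograd would reject that at backward time.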
@@ -557,19 +566,23 @@ class AdvancedTradingTransformer(nn.Module):
         timeframe_indices = []

         for idx, (tf_name, tf_data) in enumerate(available_tfs):
+            # CRITICAL: Clone input data to avoid modifying original tensors
+            # This prevents version conflicts when batches are reused across epochs
+            tf_data_processed = tf_data.clone()
+
             # Ensure correct sequence length
-            if tf_data.shape[1] != seq_len:
-                if tf_data.shape[1] < seq_len:
+            if tf_data_processed.shape[1] != seq_len:
+                if tf_data_processed.shape[1] < seq_len:
                     # Pad with last candle
-                    padding = tf_data[:, -1:, :].expand(batch_size, seq_len - tf_data.shape[1], 5)
-                    tf_data = torch.cat([tf_data, padding], dim=1)
+                    padding = tf_data_processed[:, -1:, :].expand(batch_size, seq_len - tf_data_processed.shape[1], 5)
+                    tf_data_processed = torch.cat([tf_data_processed, padding], dim=1)
                 else:
                     # Truncate to seq_len
-                    tf_data = tf_data[:, :seq_len, :]
+                    tf_data_processed = tf_data_processed[:, :seq_len, :]

             # Apply SHARED pattern encoder (learns patterns once for all timeframes)
             # Shape: [batch, seq_len, 5] -> [batch, seq_len, d_model]
-            tf_encoded = self.shared_pattern_encoder(tf_data)
+            tf_encoded = self.shared_pattern_encoder(tf_data_processed)

             # Add timeframe-specific embedding (helps model know which timeframe)
             # Get timeframe index
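
A condensed, hypothetical helper showing the same pad-or-truncate logic applied to a cloned copy (the name fit_seq_len is illustrative, not from the repo):

import torch

def fit_seq_len(tf_data: torch.Tensor, seq_len: int) -> torch.Tensor:
    """Return a [batch, seq_len, features] copy of tf_data, padded with the
    last candle or truncated, without touching the caller's tensor."""
    batch_size, t, features = tf_data.shape
    out = tf_data.clone()
    if t < seq_len:
        padding = out[:, -1:, :].expand(batch_size, seq_len - t, features)
        out = torch.cat([out, padding], dim=1)
    elif t > seq_len:
        out = out[:, :seq_len, :]
    return out

candles = torch.randn(2, 37, 5)          # 37 OHLCV rows
print(fit_seq_len(candles, 60).shape)    # torch.Size([2, 60, 5])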
@@ -1321,7 +1334,7 @@ class TradingTransformerTrainer:
         # Enable anomaly detection temporarily to debug inplace operation issues
         # NOTE: This significantly slows down training (2-3x slower), use only for debugging
         # Set to True to find exact inplace operation causing errors
-        enable_anomaly_detection = False  # DISABLED - inplace operations fixed
+        enable_anomaly_detection = False  # DISABLED - inplace operation issues fixed
         if enable_anomaly_detection:
             torch.autograd.set_detect_anomaly(True)
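
For reference, the context-manager form of anomaly detection, shown on a deliberately broken toy example rather than the trainer itself: it attaches a forward-pass traceback to backward errors and flags NaN gradients, at a large speed cost, so it belongs in one-off debugging runs only.

import torch

x = torch.tensor([1.0, -1.0], requires_grad=True)

with torch.autograd.detect_anomaly():
    y = torch.sqrt(x)        # sqrt(-1) is NaN in the forward pass
    try:
        y.sum().backward()   # anomaly mode names the op whose gradient went NaN
    except RuntimeError as e:
        print(e)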
@@ -1374,9 +1387,15 @@ class TradingTransformerTrainer:
                    else:
                        batch_on_device[k] = v
            else:
-                # Batch is already on GPU, but still create a copy of the dict
-                # to avoid modifying the original batch dict
+                # CRITICAL FIX: Batch is already on GPU, but we must clone tensors
+                # to avoid version conflicts when the same batch is reused across epochs.
+                # Without cloning, operations like .contiguous() and .view() modify
+                # the tensor's version number, breaking backpropagation.
                for k, v in batch.items():
-                    batch_on_device[k] = v
+                    if isinstance(v, torch.Tensor):
+                        # Clone tensor to create independent copy with fresh version number
+                        batch_on_device[k] = v.clone()
+                    else:
+                        batch_on_device[k] = v

        # Ensure all batch tensors are on the same device as the model
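
The two branches above reduce to one rule of thumb; a hypothetical condensed version follows: .to(device) already returns a fresh tensor when it actually copies across devices, but it is a no-op when the tensor is already on the target device, so that case needs an explicit clone().

import torch

def batch_to_device(batch: dict, device: torch.device) -> dict:
    out = {}
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            # Already on the target device: clone to get independent storage.
            # Otherwise .to() performs the copy and returns a new tensor.
            out[k] = v.clone() if v.device == device else v.to(device, non_blocking=True)
        else:
            out[k] = v
    return out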


@@ -613,7 +613,11 @@ class TradingOrchestrator:
         # CRITICAL: Initialize checkpoint manager for saving training progress
         self.checkpoint_manager = None
         self.training_iterations = 0  # Track training iterations for periodic saves
-        self._initialize_checkpoint_manager()
+        try:
+            self._initialize_checkpoint_manager()
+        except Exception as e:
+            logger.error(f"Failed to initialize checkpoint manager in __init__: {e}")
+            self.checkpoint_manager = None

         # Initialize models, COB integration, and training system
         self._initialize_ml_models()
@@ -828,7 +832,7 @@ class TradingOrchestrator:
         # Try to load best checkpoint
         checkpoint_loaded = False
         try:
-            if self.checkpoint_manager:
+            if hasattr(self, 'checkpoint_manager') and self.checkpoint_manager:
                 checkpoint_path, checkpoint_metadata = self.checkpoint_manager.load_best_checkpoint("transformer")
                 if checkpoint_path and checkpoint_metadata:
                     # Load the checkpoint
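
A side note on the guard: getattr with a default collapses the hasattr(...)-and-truthy pair into a single lookup. A toy sketch, where the class is made up and only the attribute and method names mirror the diff:

class _Stub:
    checkpoint_manager = None

obj = _Stub()

manager = getattr(obj, 'checkpoint_manager', None)
if manager is not None:
    checkpoint_path, checkpoint_metadata = manager.load_best_checkpoint("transformer")
else:
    print("no checkpoint manager available; starting without a checkpoint")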