LR training wip
@@ -2530,13 +2530,20 @@ class RealTrainingAdapter:
         OPTIMIZATION: Batches are already on GPU and grouped for efficient processing.
         Each mini-batch contains 5 samples for better GPU utilization.

-        IMPORTANT: Creates a shallow copy of batch dict to prevent in-place modifications
-        from affecting subsequent epochs. Tensors themselves are shared (not copied).
+        CRITICAL FIX: Clone tensors for each epoch to avoid autograd version conflicts.
+        When the same tensor is used across multiple forward passes, operations like
+        .contiguous() and .view() modify the tensor's version number, breaking backprop.
         """
         for batch in grouped_batches:
-            # Create shallow copy of batch dict to prevent modifications
-            # Tensors are shared (not cloned) for memory efficiency
-            batch_copy = {k: v for k, v in batch.items()}
+            # CRITICAL: Clone all tensors to avoid version conflicts across epochs
+            # This prevents "modified by an inplace operation" errors during backward pass
+            batch_copy = {}
+            for k, v in batch.items():
+                if isinstance(v, torch.Tensor):
+                    # Clone tensor to create independent copy with fresh version number
+                    batch_copy[k] = v.clone()
+                else:
+                    batch_copy[k] = v
             yield batch_copy

         total_batches = len(grouped_batches)
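
Note: a minimal standalone sketch of the per-epoch cloning pattern this hunk adopts, assuming a hypothetical grouped_batches list of dicts whose tensor values are reused across epochs (the key names and helper are illustrative, not the adapter's real API):

    import torch

    def iter_cloned_batches(grouped_batches):
        """Yield a per-epoch copy of each batch dict, cloning tensor values.

        clone() gives every epoch an independent tensor with its own storage and
        version counter, so edits made during one pass cannot leak into the next.
        """
        for batch in grouped_batches:
            yield {
                k: (v.clone() if isinstance(v, torch.Tensor) else v)
                for k, v in batch.items()
            }

    # Usage sketch: the cached batch survives an in-place edit made on the copy
    grouped_batches = [{"price_data_1m": torch.randn(5, 200, 5), "symbol": "ETH/USDT"}]
    for epoch in range(2):
        for batch in iter_cloned_batches(grouped_batches):
            batch["price_data_1m"] += 1.0  # only touches this epoch's clone
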
@@ -144,19 +144,23 @@ class DeepMultiScaleAttention(nn.Module):
         batch_size, seq_len, _ = x.size()
         scale_outputs = []

+        # Clone input to avoid inplace modification issues
+        x_input = x.clone()
+
         for scale_proj in self.scale_projections:
             # Apply enhanced temporal convolution for this scale
-            x_conv = scale_proj['conv'](x.transpose(1, 2)).transpose(1, 2)
+            x_conv = scale_proj['conv'](x_input.transpose(1, 2)).transpose(1, 2)

             # Enhanced attention computation with deeper projections
-            Q = scale_proj['query'](x_conv).view(batch_size, seq_len, self.n_heads, self.head_dim)
-            K = scale_proj['key'](x_conv).view(batch_size, seq_len, self.n_heads, self.head_dim)
-            V = scale_proj['value'](x_conv).view(batch_size, seq_len, self.n_heads, self.head_dim)
+            # Use contiguous() before view() to ensure memory layout is correct
+            Q = scale_proj['query'](x_conv).contiguous().view(batch_size, seq_len, self.n_heads, self.head_dim)
+            K = scale_proj['key'](x_conv).contiguous().view(batch_size, seq_len, self.n_heads, self.head_dim)
+            V = scale_proj['value'](x_conv).contiguous().view(batch_size, seq_len, self.n_heads, self.head_dim)

             # Transpose for attention computation
-            Q = Q.transpose(1, 2)  # (batch, n_heads, seq_len, head_dim)
-            K = K.transpose(1, 2)
-            V = V.transpose(1, 2)
+            Q = Q.transpose(1, 2).contiguous()  # (batch, n_heads, seq_len, head_dim)
+            K = K.transpose(1, 2).contiguous()
+            V = V.transpose(1, 2).contiguous()

             # Scaled dot-product attention
             scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
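
For reference, why the hunk chains .contiguous() before .view(): view() requires a contiguous memory layout, and the tensors coming out of transpose() are non-contiguous views. A generic PyTorch illustration, not code from this module:

    import torch

    x = torch.randn(2, 8, 64)           # (batch, seq_len, d_model)
    t = x.transpose(1, 2)               # (batch, d_model, seq_len) - non-contiguous view

    try:
        t.view(2, -1)                   # view() refuses non-contiguous layouts
    except RuntimeError as err:
        print("view() failed:", err)

    flat = t.contiguous().view(2, -1)   # copy into contiguous memory, then reshape
    same = t.reshape(2, -1)             # reshape() performs the copy only when needed
    assert torch.equal(flat, same)
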
@@ -293,24 +297,29 @@ class TradingTransformerLayer(nn.Module):
         self.regime_detector = MarketRegimeDetector(config.d_model)

     def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
+        # CRITICAL: Clone input to avoid version conflicts during backpropagation
+        # This prevents "modified by an inplace operation" errors when the same
+        # batch is used across multiple epochs
+        x_residual = x.clone()
+
         # Self-attention with residual connection
-        # Store residual before any operations to avoid version conflicts
         if isinstance(self.attention, DeepMultiScaleAttention):
-            attn_output = self.attention(x, mask)
+            attn_output = self.attention(x_residual, mask)
         else:
-            attn_output, _ = self.attention(x, x, x, attn_mask=mask)
+            attn_output, _ = self.attention(x_residual, x_residual, x_residual, attn_mask=mask)

         # Create new tensor for residual to avoid inplace modification tracking
-        x_new = self.norm1(x + self.dropout(attn_output))
+        x_new = self.norm1(x_residual + self.dropout(attn_output))

         # Market regime adaptation
         regime_probs = None
         if hasattr(self, 'regime_detector'):
             x_new, regime_probs = self.regime_detector(x_new)

-        # Feed-forward with residual connection
-        ff_output = self.feed_forward(x_new)
-        x_out = self.norm2(x_new + self.dropout(ff_output))
+        # Feed-forward with residual connection - clone to avoid version conflicts
+        x_ff_residual = x_new.clone()
+        ff_output = self.feed_forward(x_ff_residual)
+        x_out = self.norm2(x_ff_residual + self.dropout(ff_output))

         return {
             'output': x_out,
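
As a self-contained illustration of the residual-cloning pattern used in this layer, here is a stripped-down post-norm residual block with a plain linear mixer standing in for the project's attention modules (all names below are hypothetical):

    import torch
    import torch.nn as nn

    class ToyResidualBlock(nn.Module):
        """Post-norm residual block that clones its residual inputs."""

        def __init__(self, d_model: int = 64):
            super().__init__()
            self.mixer = nn.Linear(d_model, d_model)      # stand-in for self-attention
            self.ff = nn.Sequential(
                nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
            )
            self.norm1 = nn.LayerNorm(d_model)
            self.norm2 = nn.LayerNorm(d_model)
            self.dropout = nn.Dropout(0.1)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Keep an independent copy of the residual so a caller that later
            # mutates x in place cannot invalidate this node of the graph.
            x_residual = x.clone()
            x_new = self.norm1(x_residual + self.dropout(self.mixer(x_residual)))
            x_ff_residual = x_new.clone()
            return self.norm2(x_ff_residual + self.dropout(self.ff(x_ff_residual)))

    block = ToyResidualBlock()
    out = block(torch.randn(2, 16, 64))
    out.mean().backward()   # gradients flow through the cloned residuals
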
@@ -557,19 +566,23 @@ class AdvancedTradingTransformer(nn.Module):
         timeframe_indices = []

         for idx, (tf_name, tf_data) in enumerate(available_tfs):
+            # CRITICAL: Clone input data to avoid modifying original tensors
+            # This prevents version conflicts when batches are reused across epochs
+            tf_data_processed = tf_data.clone()
+
             # Ensure correct sequence length
-            if tf_data.shape[1] != seq_len:
-                if tf_data.shape[1] < seq_len:
+            if tf_data_processed.shape[1] != seq_len:
+                if tf_data_processed.shape[1] < seq_len:
                     # Pad with last candle
-                    padding = tf_data[:, -1:, :].expand(batch_size, seq_len - tf_data.shape[1], 5)
-                    tf_data = torch.cat([tf_data, padding], dim=1)
+                    padding = tf_data_processed[:, -1:, :].expand(batch_size, seq_len - tf_data_processed.shape[1], 5)
+                    tf_data_processed = torch.cat([tf_data_processed, padding], dim=1)
                 else:
                     # Truncate to seq_len
-                    tf_data = tf_data[:, :seq_len, :]
+                    tf_data_processed = tf_data_processed[:, :seq_len, :]

             # Apply SHARED pattern encoder (learns patterns once for all timeframes)
             # Shape: [batch, seq_len, 5] -> [batch, seq_len, d_model]
-            tf_encoded = self.shared_pattern_encoder(tf_data)
+            tf_encoded = self.shared_pattern_encoder(tf_data_processed)

             # Add timeframe-specific embedding (helps model know which timeframe)
             # Get timeframe index
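
The pad-or-truncate logic above can also be read as a small helper; the sketch below assumes OHLCV-style input of shape [batch, time, 5] and repeats the last candle when padding (the helper name is made up for illustration):

    import torch

    def fit_to_seq_len(tf_data: torch.Tensor, seq_len: int) -> torch.Tensor:
        """Clone, then pad (repeating the last candle) or truncate to seq_len."""
        out = tf_data.clone()                 # never modify the caller's tensor
        batch_size, t, n_features = out.shape
        if t < seq_len:
            padding = out[:, -1:, :].expand(batch_size, seq_len - t, n_features)
            out = torch.cat([out, padding], dim=1)
        elif t > seq_len:
            out = out[:, :seq_len, :]
        return out

    x = torch.randn(4, 180, 5)
    assert fit_to_seq_len(x, 200).shape == (4, 200, 5)   # padded
    assert fit_to_seq_len(x, 150).shape == (4, 150, 5)   # truncated
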
@@ -1321,7 +1334,7 @@ class TradingTransformerTrainer:
         # Enable anomaly detection temporarily to debug inplace operation issues
         # NOTE: This significantly slows down training (2-3x slower), use only for debugging
         # Set to True to find exact inplace operation causing errors
-        enable_anomaly_detection = False  # DISABLED - inplace operations fixed
+        enable_anomaly_detection = False  # DISABLED - inplace operation issues fixed
         if enable_anomaly_detection:
             torch.autograd.set_detect_anomaly(True)

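
For reference, the two standard ways to enable PyTorch anomaly detection while hunting the offending in-place operation; both add large overhead, so they are debugging tools only:

    import torch

    # Global switch, as used by the flag above
    torch.autograd.set_detect_anomaly(True)
    # ... run a few training steps; failures now carry the forward-pass traceback ...
    torch.autograd.set_detect_anomaly(False)

    # Scoped alternative: only the wrapped region pays the overhead
    model = torch.nn.Linear(8, 1)
    x, y = torch.randn(4, 8), torch.randn(4, 1)
    with torch.autograd.detect_anomaly():
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
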
@@ -1374,9 +1387,15 @@ class TradingTransformerTrainer:
                     else:
                         batch_on_device[k] = v
             else:
-                # Batch is already on GPU, but still create a copy of the dict
-                # to avoid modifying the original batch dict
+                # CRITICAL FIX: Batch is already on GPU, but we must clone tensors
+                # to avoid version conflicts when the same batch is reused across epochs.
+                # Without cloning, operations like .contiguous() and .view() modify
+                # the tensor's version number, breaking backpropagation.
                 for k, v in batch.items():
-                    batch_on_device[k] = v
+                    if isinstance(v, torch.Tensor):
+                        # Clone tensor to create independent copy with fresh version number
+                        batch_on_device[k] = v.clone()
+                    else:
+                        batch_on_device[k] = v

             # Ensure all batch tensors are on the same device as the model
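
A condensed sketch of the batch-preparation rule this hunk changes: tensors still on the CPU are moved (which already produces a new tensor), while tensors already resident on the target device are cloned so the cached batch is never handed to the graph directly (the function name is illustrative):

    import torch

    def prepare_batch(batch: dict, device: torch.device) -> dict:
        """Return a per-step dict whose tensors are private copies on `device`."""
        batch_on_device = {}
        for k, v in batch.items():
            if not isinstance(v, torch.Tensor):
                batch_on_device[k] = v
            elif v.device != device:
                batch_on_device[k] = v.to(device, non_blocking=True)
            else:
                batch_on_device[k] = v.clone()   # already resident: take an independent copy
        return batch_on_device

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ready = prepare_batch({"price": torch.randn(5, 200, 5), "symbol": "ETH/USDT"}, device)
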
@@ -613,7 +613,11 @@ class TradingOrchestrator:
         # CRITICAL: Initialize checkpoint manager for saving training progress
         self.checkpoint_manager = None
         self.training_iterations = 0  # Track training iterations for periodic saves
-        self._initialize_checkpoint_manager()
+        try:
+            self._initialize_checkpoint_manager()
+        except Exception as e:
+            logger.error(f"Failed to initialize checkpoint manager in __init__: {e}")
+            self.checkpoint_manager = None

         # Initialize models, COB integration, and training system
         self._initialize_ml_models()
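
The same fail-soft initialization pattern in isolation: the optional subsystem is built inside try/except and the attribute is reset to None on failure, so construction always succeeds and later code can simply test the attribute (class and method names are stand-ins):

    import logging

    logger = logging.getLogger(__name__)

    class Orchestrator:
        def __init__(self):
            self.checkpoint_manager = None
            self.training_iterations = 0
            try:
                self.checkpoint_manager = self._build_checkpoint_manager()
            except Exception as e:
                # An optional subsystem must never break construction
                logger.error(f"Failed to initialize checkpoint manager in __init__: {e}")
                self.checkpoint_manager = None

        def _build_checkpoint_manager(self):
            raise RuntimeError("storage backend unavailable")   # simulate a failure

    orch = Orchestrator()
    assert orch.checkpoint_manager is None   # __init__ still completed
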
@@ -828,7 +832,7 @@ class TradingOrchestrator:
             # Try to load best checkpoint
             checkpoint_loaded = False
             try:
-                if self.checkpoint_manager:
+                if hasattr(self, 'checkpoint_manager') and self.checkpoint_manager:
                     checkpoint_path, checkpoint_metadata = self.checkpoint_manager.load_best_checkpoint("transformer")
                     if checkpoint_path and checkpoint_metadata:
                         # Load the checkpoint
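
And the matching guarded read: check that the attribute both exists and is non-None before calling into it, as the one-line change above does (sketched with a stub manager, not the orchestrator's real checkpoint API):

    class StubCheckpointManager:
        def load_best_checkpoint(self, model_name):
            # Stand-in return values; the real manager resolves an actual file
            return f"checkpoints/{model_name}_best.pt", {"loss": 0.123}

    class Holder:
        pass

    holder = Holder()
    holder.checkpoint_manager = StubCheckpointManager()

    checkpoint_loaded = False
    if hasattr(holder, "checkpoint_manager") and holder.checkpoint_manager:
        checkpoint_path, checkpoint_metadata = holder.checkpoint_manager.load_best_checkpoint("transformer")
        if checkpoint_path and checkpoint_metadata:
            checkpoint_loaded = True

    print(checkpoint_loaded)   # True
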