REALTIME candlesstick prediction training fixes
This commit is contained in:
@@ -219,8 +219,8 @@ class MarketRegimeDetector(nn.Module):
|
||||
regime_weights = regime_probs.unsqueeze(0).unsqueeze(2).unsqueeze(3) # (1, batch, 1, 1, n_regimes)
|
||||
regime_weights = regime_weights.permute(4, 1, 2, 3, 0).squeeze(-1) # (n_regimes, batch, 1, 1)
|
||||
|
||||
# Weighted sum across regimes - clone to avoid inplace errors
|
||||
adapted_output = torch.sum(regime_stack * regime_weights, dim=0).clone()
|
||||
# Weighted sum across regimes
|
||||
adapted_output = torch.sum(regime_stack * regime_weights, dim=0)
|
||||
|
||||
return adapted_output, regime_probs
|
||||
|
||||
@@ -294,24 +294,26 @@ class TradingTransformerLayer(nn.Module):
|
||||
|
||||
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
|
||||
# Self-attention with residual connection
|
||||
# Store residual before any operations to avoid version conflicts
|
||||
if isinstance(self.attention, DeepMultiScaleAttention):
|
||||
attn_output = self.attention(x, mask)
|
||||
else:
|
||||
attn_output, _ = self.attention(x, x, x, attn_mask=mask)
|
||||
|
||||
x = self.norm1(x + self.dropout(attn_output))
|
||||
# Create new tensor for residual to avoid inplace modification tracking
|
||||
x_new = self.norm1(x + self.dropout(attn_output))
|
||||
|
||||
# Market regime adaptation
|
||||
regime_probs = None
|
||||
if hasattr(self, 'regime_detector'):
|
||||
x, regime_probs = self.regime_detector(x)
|
||||
x_new, regime_probs = self.regime_detector(x_new)
|
||||
|
||||
# Feed-forward with residual connection
|
||||
ff_output = self.feed_forward(x)
|
||||
x = self.norm2(x + self.dropout(ff_output))
|
||||
ff_output = self.feed_forward(x_new)
|
||||
x_out = self.norm2(x_new + self.dropout(ff_output))
|
||||
|
||||
return {
|
||||
'output': x,
|
||||
'output': x_out,
|
||||
'regime_probs': regime_probs
|
||||
}
|
||||
|
||||
@@ -669,8 +671,8 @@ class AdvancedTradingTransformer(nn.Module):
|
||||
else:
|
||||
layer_output = layer(x, mask)
|
||||
|
||||
# Clone to avoid inplace operation errors during backward pass
|
||||
x = layer_output['output'].clone()
|
||||
# Use output directly - no clone needed with proper variable naming
|
||||
x = layer_output['output']
|
||||
if layer_output['regime_probs'] is not None:
|
||||
regime_probs_history.append(layer_output['regime_probs'])
|
||||
|
||||
@@ -1318,7 +1320,7 @@ class TradingTransformerTrainer:
|
||||
# Enable anomaly detection temporarily to debug inplace operation issues
|
||||
# NOTE: This significantly slows down training (2-3x slower), use only for debugging
|
||||
# Set to True to find exact inplace operation causing errors
|
||||
enable_anomaly_detection = True # TEMPORARILY ENABLED to find inplace operations
|
||||
enable_anomaly_detection = False # DISABLED - inplace operations fixed
|
||||
if enable_anomaly_detection:
|
||||
torch.autograd.set_detect_anomaly(True)
|
||||
|
||||
@@ -1339,6 +1341,11 @@ class TradingTransformerTrainer:
|
||||
# Use set_to_none=True for better memory efficiency (saves ~5% memory)
|
||||
if not is_accumulation_step or self.current_accumulation_step == 1:
|
||||
self.optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
# Also clear any cached gradients in the model
|
||||
for param in self.model.parameters():
|
||||
if param.grad is not None:
|
||||
param.grad = None
|
||||
|
||||
# OPTIMIZATION: Only move batch to device if not already there
|
||||
# Check if first tensor is already on correct device
|
||||
@@ -1557,15 +1564,21 @@ class TradingTransformerTrainer:
|
||||
else:
|
||||
total_loss.backward()
|
||||
except RuntimeError as e:
|
||||
if "inplace operation" in str(e):
|
||||
error_msg = str(e)
|
||||
if "inplace operation" in error_msg or "modified by an inplace operation" in error_msg:
|
||||
logger.error(f"Inplace operation error during backward pass: {e}")
|
||||
# Clear gradients to reset state
|
||||
self.optimizer.zero_grad(set_to_none=True)
|
||||
for param in self.model.parameters():
|
||||
if param.grad is not None:
|
||||
param.grad = None
|
||||
# Return zero loss to continue training
|
||||
return {
|
||||
'total_loss': 0.0,
|
||||
'action_loss': 0.0,
|
||||
'price_loss': 0.0,
|
||||
'accuracy': 0.0,
|
||||
'learning_rate': self.scheduler.get_last_lr()[0]
|
||||
'learning_rate': self.scheduler.get_last_lr()[0] if hasattr(self, 'scheduler') else 0.0
|
||||
}
|
||||
else:
|
||||
raise
|
||||
|
||||
Reference in New Issue
Block a user