REALTIME candlestick prediction training fixes

Dobromir Popov
2025-12-08 19:57:47 +02:00
parent c8ce314872
commit cc555735e8
4 changed files with 275 additions and 20 deletions

@@ -219,8 +219,8 @@ class MarketRegimeDetector(nn.Module):
         regime_weights = regime_probs.unsqueeze(0).unsqueeze(2).unsqueeze(3)  # (1, batch, 1, 1, n_regimes)
         regime_weights = regime_weights.permute(4, 1, 2, 3, 0).squeeze(-1)  # (n_regimes, batch, 1, 1)
-        # Weighted sum across regimes - clone to avoid inplace errors
-        adapted_output = torch.sum(regime_stack * regime_weights, dim=0).clone()
+        # Weighted sum across regimes
+        adapted_output = torch.sum(regime_stack * regime_weights, dim=0)
         return adapted_output, regime_probs
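
As a cross-check on the hunk above, here is a small standalone sketch of the same regime-weighted sum; the tensor sizes are invented for illustration, only the reshape/permute/broadcast pattern comes from the code:

    import torch

    n_regimes, batch, seq_len, d_model = 4, 2, 8, 16
    regime_stack = torch.randn(n_regimes, batch, seq_len, d_model)          # per-regime outputs
    regime_probs = torch.softmax(torch.randn(batch, n_regimes), dim=-1)     # (batch, n_regimes)

    # Reshape probabilities to (n_regimes, batch, 1, 1) so they broadcast over seq/model dims
    weights = regime_probs.unsqueeze(0).unsqueeze(2).unsqueeze(3)           # (1, batch, 1, 1, n_regimes)
    weights = weights.permute(4, 1, 2, 3, 0).squeeze(-1)                    # (n_regimes, batch, 1, 1)

    adapted = torch.sum(regime_stack * weights, dim=0)                      # (batch, seq_len, d_model)

    # Same result via einsum, as a sanity check
    adapted_einsum = torch.einsum('rbsd,br->bsd', regime_stack, regime_probs)
    assert torch.allclose(adapted, adapted_einsum, atol=1e-6)
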
@@ -294,24 +294,26 @@ class TradingTransformerLayer(nn.Module):
     def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
         # Self-attention with residual connection
+        # Store residual before any operations to avoid version conflicts
         if isinstance(self.attention, DeepMultiScaleAttention):
             attn_output = self.attention(x, mask)
         else:
             attn_output, _ = self.attention(x, x, x, attn_mask=mask)
-        x = self.norm1(x + self.dropout(attn_output))
+        # Create new tensor for residual to avoid inplace modification tracking
+        x_new = self.norm1(x + self.dropout(attn_output))
         # Market regime adaptation
         regime_probs = None
         if hasattr(self, 'regime_detector'):
-            x, regime_probs = self.regime_detector(x)
+            x_new, regime_probs = self.regime_detector(x_new)
         # Feed-forward with residual connection
-        ff_output = self.feed_forward(x)
-        x = self.norm2(x + self.dropout(ff_output))
+        ff_output = self.feed_forward(x_new)
+        x_out = self.norm2(x_new + self.dropout(ff_output))
         return {
-            'output': x,
+            'output': x_out,
             'regime_probs': regime_probs
         }
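
The comments in this hunk point at autograd's in-place version tracking. A minimal toy repro of that error class (standalone, not repo code) and the out-of-place pattern that avoids it:

    import torch

    w = torch.randn(3, requires_grad=True)
    x = torch.randn(3)

    y = w * x          # autograd saves x for the backward of the multiply
    x.add_(1.0)        # in-place edit bumps x's version counter
    try:
        y.sum().backward()
    except RuntimeError as e:
        print("backward failed:", e)   # "... modified by an inplace operation ..."

    # Out-of-place variant bound to a fresh name, mirroring the x_new / x_out pattern above
    w2 = torch.randn(3, requires_grad=True)
    x2 = torch.randn(3)
    y2 = w2 * x2
    x2_new = x2 + 1.0   # new tensor; the saved x2 is left untouched
    y2.sum().backward()  # succeeds
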
@@ -669,8 +671,8 @@ class AdvancedTradingTransformer(nn.Module):
             else:
                 layer_output = layer(x, mask)
-            # Clone to avoid inplace operation errors during backward pass
-            x = layer_output['output'].clone()
+            # Use output directly - no clone needed with proper variable naming
+            x = layer_output['output']
             if layer_output['regime_probs'] is not None:
                 regime_probs_history.append(layer_output['regime_probs'])
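
A quick standalone check (not from the repo) that dropping the defensive clone() between layers cannot change gradients, since clone() is itself differentiable and only costs an extra copy:

    import torch

    x = torch.randn(5, requires_grad=True)

    y_cloned = (x.clone() * 2).sum()
    y_cloned.backward()
    grad_with_clone = x.grad.clone()

    x.grad = None
    y_plain = (x * 2).sum()
    y_plain.backward()

    assert torch.equal(grad_with_clone, x.grad)  # identical gradients either way
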
@@ -1318,7 +1320,7 @@ class TradingTransformerTrainer:
         # Enable anomaly detection temporarily to debug inplace operation issues
         # NOTE: This significantly slows down training (2-3x slower), use only for debugging
         # Set to True to find exact inplace operation causing errors
-        enable_anomaly_detection = True  # TEMPORARILY ENABLED to find inplace operations
+        enable_anomaly_detection = False  # DISABLED - inplace operations fixed
         if enable_anomaly_detection:
             torch.autograd.set_detect_anomaly(True)
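
For context, a hedged sketch of how such a debug flag is typically wired up; torch.autograd.set_detect_anomaly is the real API, but the train_step function and its arguments here are illustrative only, not the trainer's actual code:

    import torch

    def train_step(model, batch, optimizer, debug_inplace=False):
        # Anomaly mode records a forward-pass traceback for every op, so a backward error
        # points at the line that produced the offending tensor; it slows training roughly
        # 2-3x, so keep it off outside of debugging sessions.
        with torch.autograd.set_detect_anomaly(debug_inplace):
            loss = model(batch).sum()
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
        return loss.item()
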
@@ -1339,6 +1341,11 @@ class TradingTransformerTrainer:
         # Use set_to_none=True for better memory efficiency (saves ~5% memory)
         if not is_accumulation_step or self.current_accumulation_step == 1:
             self.optimizer.zero_grad(set_to_none=True)
+            # Also clear any cached gradients in the model
+            for param in self.model.parameters():
+                if param.grad is not None:
+                    param.grad = None
         # OPTIMIZATION: Only move batch to device if not already there
         # Check if first tensor is already on correct device
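
Side note on the added loop: with set_to_none=True the optimizer already drops .grad to None for the parameters it owns, so the explicit loop acts as a belt-and-braces reset covering any parameters outside the optimizer's param groups. A tiny standalone comparison (toy model, not repo code):

    import torch

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    model(torch.randn(8, 4)).sum().backward()

    optimizer.zero_grad(set_to_none=True)        # grads become None (frees memory)
    assert all(p.grad is None for p in model.parameters())

    # Equivalent manual reset over every model parameter
    for param in model.parameters():
        if param.grad is not None:
            param.grad = None
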
@@ -1557,15 +1564,21 @@ class TradingTransformerTrainer:
                 else:
                     total_loss.backward()
             except RuntimeError as e:
-                if "inplace operation" in str(e):
+                error_msg = str(e)
+                if "inplace operation" in error_msg or "modified by an inplace operation" in error_msg:
                     logger.error(f"Inplace operation error during backward pass: {e}")
+                    # Clear gradients to reset state
+                    self.optimizer.zero_grad(set_to_none=True)
+                    for param in self.model.parameters():
+                        if param.grad is not None:
+                            param.grad = None
                     # Return zero loss to continue training
                     return {
                         'total_loss': 0.0,
                         'action_loss': 0.0,
                         'price_loss': 0.0,
                         'accuracy': 0.0,
-                        'learning_rate': self.scheduler.get_last_lr()[0]
+                        'learning_rate': self.scheduler.get_last_lr()[0] if hasattr(self, 'scheduler') else 0.0
                     }
                 else:
                     raise
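
An illustrative condensation of this recovery path (the function name and arguments are assumptions, not the trainer's real signature): match the autograd in-place error message, reset gradients, and return neutral metrics so the training loop moves on to the next batch instead of crashing.

    import logging
    import torch

    logger = logging.getLogger(__name__)

    def safe_backward(total_loss, optimizer, scheduler=None):
        try:
            total_loss.backward()
        except RuntimeError as e:
            error_msg = str(e)
            if "inplace operation" in error_msg or "modified by an inplace operation" in error_msg:
                logger.error(f"Inplace operation error during backward pass: {e}")
                optimizer.zero_grad(set_to_none=True)   # drop any partial gradients
                lr = scheduler.get_last_lr()[0] if scheduler is not None else 0.0
                return {'total_loss': 0.0, 'action_loss': 0.0, 'price_loss': 0.0,
                        'accuracy': 0.0, 'learning_rate': lr}
            raise
        return None  # caller proceeds with optimizer.step() as usual
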