REALTIME candlesstick prediction training fixes
This commit is contained in:
@@ -7,21 +7,44 @@
|
||||
**Problem**:
|
||||
```
|
||||
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
|
||||
Error detected in NativeLayerNormBackward0
|
||||
Error detected in MmBackward0
|
||||
```
|
||||
|
||||
**Root Cause**:
|
||||
- Tensor operations like `x = x + position_emb` were modifying tensors that are part of the computation graph
|
||||
- The regime detector's weighted sum was creating shared memory references
|
||||
- Layer outputs were being reused without cloning
|
||||
- Residual connections in transformer layers were reusing variable names (`x = x + something`)
|
||||
- PyTorch tracks tensor versions and detects when tensors in the computation graph are modified
|
||||
- Layer normalization was operating on tensors that had been modified in-place
|
||||
- Gradient accumulation wasn't properly clearing stale gradients
|
||||
|
||||
**Fix Applied**:
|
||||
- Added `.clone()` to create new tensors instead of modifying existing ones:
|
||||
- `x = price_emb.clone() + cob_emb + tech_emb + market_emb`
|
||||
- `x = layer_output['output'].clone()`
|
||||
- `adapted_output = torch.sum(regime_stack * regime_weights, dim=0).clone()`
|
||||
1. **Residual Connections**: Changed to use new variable names instead of reusing `x`:
|
||||
```python
|
||||
# Before: x = self.norm1(x + self.dropout(attn_output))
|
||||
# After: x_new = self.norm1(x + self.dropout(attn_output))
|
||||
```
|
||||
|
||||
2. **Gradient Clearing**: Added explicit gradient clearing before each training step:
|
||||
```python
|
||||
self.optimizer.zero_grad(set_to_none=True)
|
||||
for param in self.model.parameters():
|
||||
if param.grad is not None:
|
||||
param.grad = None
|
||||
```
|
||||
|
||||
3. **Error Recovery**: Enhanced error handling to catch and recover from inplace errors:
|
||||
```python
|
||||
except RuntimeError as e:
|
||||
if "inplace operation" in str(e):
|
||||
# Clear gradients and continue
|
||||
self.optimizer.zero_grad(set_to_none=True)
|
||||
return zero_loss_result
|
||||
```
|
||||
|
||||
4. **Disabled Anomaly Detection**: Turned off PyTorch's anomaly detection (was causing 2-3x slowdown)
|
||||
|
||||
**Files Modified**:
|
||||
- `NN/models/advanced_transformer_trading.py` (lines 638, 668, 223)
|
||||
- `NN/models/advanced_transformer_trading.py` (lines 296-315, 638-653, 1323-1330, 1560-1580)
|
||||
|
||||
---
|
||||
|
||||
|
||||
Reference in New Issue
Block a user