# Tensor Operation Fixes Report

*Generated: 2024-12-19*
## 🎯 Issue Summary

The orchestrator was hitting three critical errors that prevented model predictions:

1. **Softmax Error**: `softmax() received an invalid combination of arguments - got (tuple, dim=int)`
2. **View Error**: `view size is not compatible with input tensor's size and stride`
3. **Unpacking Error**: `cannot unpack non-iterable NoneType object`
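Each of the three messages can be reproduced in a few lines, which helps confirm that a fix targets the right failure. The sketch below is a minimal illustration with made-up tensors, not the project's data; `predict_returning_none` is a hypothetical stand-in for a model whose `predict()` returns nothing.

```python
import torch


def predict_returning_none(features):
    """Hypothetical model whose predict() returns nothing."""
    return None


q_values = torch.tensor([[0.1, 0.5, 0.4]])
caught = []

# 1. Softmax: passing a tuple instead of a tensor raises TypeError
try:
    torch.softmax((q_values, None), dim=1)
except TypeError as err:
    caught.append(("softmax", type(err).__name__))

# 2. View: flattening a transposed (non-contiguous) tensor raises RuntimeError
try:
    torch.zeros(2, 3, 4).transpose(1, 2).view(2, -1)
except RuntimeError as err:
    caught.append(("view", type(err).__name__))

# 3. Unpacking: tuple-unpacking None raises TypeError
try:
    action_probs, confidence = predict_returning_none(q_values)
except TypeError as err:
    caught.append(("unpacking", type(err).__name__))

for name, exc_type in caught:
    print(name, exc_type)
```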
## 🔧 Fixes Applied

### 1. DQN Agent Softmax Fix (`NN/models/dqn_agent.py`)

**Problem**: The Q-values passed to `torch.softmax` did not always have the 2D `(batch, num_actions)` shape that `dim=1` and the `[0, 0]` index expect.

**Solution**: Check the number of dimensions and add a batch dimension before calling softmax:
```python
import torch

# Before: fails when q_values is 1D, because a 1D tensor has no dim=1
sell_confidence = torch.softmax(q_values, dim=1)[0, 0].item()

# After: promote a 1D (num_actions,) tensor to shape (1, num_actions) first
if q_values.dim() == 1:
    q_values = q_values.unsqueeze(0)
sell_confidence = torch.softmax(q_values, dim=1)[0, 0].item()
```
**Impact**: Prevents tensor dimension mismatch errors in confidence calculations.
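The logged softmax message reports `got (tuple, dim=int)`, which means a tuple rather than a tensor reached `softmax()`; in that case the `q_values.dim()` check in the DQN fix would itself fail with an `AttributeError`. Below is a minimal sketch of a helper that covers both cases, assuming (hypothetically) that when the network returns a tuple its first element is the Q-value tensor. `as_batched_q_values` and `sell_confidence` are illustrative names, not functions from the codebase.

```python
import torch


def as_batched_q_values(output):
    """Return Q-values as a 2D (batch, num_actions) tensor.

    Assumption (hypothetical): if the network returns a tuple or list,
    its first element is the Q-value tensor.
    """
    if isinstance(output, (tuple, list)):
        output = output[0]
    if not isinstance(output, torch.Tensor):
        raise TypeError(f"expected a tensor of Q-values, got {type(output).__name__}")
    if output.dim() == 1:
        # Single observation: add a leading batch dimension
        output = output.unsqueeze(0)
    if output.dim() != 2:
        raise ValueError(f"expected 1D or 2D Q-values, got shape {tuple(output.shape)}")
    return output


def sell_confidence(output):
    """Softmax probability of action index 0 (the sell action in the original code)."""
    q_values = as_batched_q_values(output)
    return torch.softmax(q_values, dim=1)[0, 0].item()


# Usage: a 1D tensor, a 2D tensor, and a tuple all give the same confidence
q = torch.tensor([1.0, 0.0, 0.0])
print(sell_confidence(q))
print(sell_confidence(q.unsqueeze(0)))
print(sell_confidence((q, None)))
```

Raising on anything that is not a tensor keeps a malformed model output visible instead of converting it into a plausible-looking confidence.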
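### 2. CNN Model View Operations Fix (`NN/models/cnn_model.py`)

**Problem**: `.view()` calls failed because some inputs had a non-contiguous memory layout (the kind produced by operations such as `.transpose()` or `.permute()`), so their strides could not express the requested shape.

**Solution**: Replaced `.view()` with `.reshape()`, which returns a view when the strides allow it and falls back to a copy otherwise: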
```python
# Before: .view() raises if the existing strides cannot express the new shape
x = x.view(x.shape[0], -1, x.shape[-1])
embedded = embedded.view(batch_size, seq_len, -1).transpose(1, 2).contiguous()

# After: .reshape() returns a view when possible and copies otherwise
x = x.reshape(x.shape[0], -1, x.shape[-1])
embedded = embedded.reshape(batch_size, seq_len, -1).transpose(1, 2).contiguous()
```
**Impact**: Eliminates tensor stride incompatibility errors during the CNN forward pass.
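The difference between the two calls is easy to reproduce in isolation. In this minimal sketch (shapes are arbitrary, not taken from `cnn_model.py`), transposing a tensor makes it non-contiguous, `.view()` then refuses to flatten it, and `.reshape()` succeeds by copying.

```python
import torch

# A (2, 3, 4) tensor transposed to (2, 4, 3): same storage, permuted strides
x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4).transpose(1, 2)
print(x.shape, x.is_contiguous())  # torch.Size([2, 4, 3]) False

# .view() cannot merge the last two dims because their strides are incompatible
try:
    x.view(x.shape[0], -1)
except RuntimeError as err:
    print("view failed:", err)

# .reshape() falls back to a copy, so it works on the same tensor
flat = x.reshape(x.shape[0], -1)
print(flat.shape)  # torch.Size([2, 12])

# Equivalent explicit form: make the layout contiguous, then view
assert torch.equal(flat, x.contiguous().view(x.shape[0], -1))
```

The trade-off is that `.reshape()` may allocate silently; on hot paths, checking `.is_contiguous()` shows whether a copy is happening.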
### 3. Generic Prediction Unpacking Fix (`core/orchestrator.py`)

**Problem**: Model `predict()` methods return results in different formats, and the logged error shows that at least one returned `None`, so unconditional tuple unpacking failed.

**Solution**: Added type-based handling of the return value:
```python
# Before: raises "cannot unpack non-iterable NoneType object" if predict() returns None
action_probs, confidence = model.predict(feature_matrix)

# After: accept (probs, confidence) tuples, dicts, or a bare value
prediction_result = model.predict(feature_matrix)
if isinstance(prediction_result, tuple) and len(prediction_result) == 2:
    action_probs, confidence = prediction_result
elif isinstance(prediction_result, dict):
    action_probs = prediction_result.get('probabilities', None)
    confidence = prediction_result.get('confidence', 0.7)
else:
    # Bare value (including None): use it as the probabilities, with a 0.7 fallback confidence
    action_probs = prediction_result
    confidence = 0.7
```
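**Impact**: Prevents unpacking errors when models return different formats. A `None` result is no longer unpacked, but it is still passed on as `action_probs`, so downstream code must tolerate it.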
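The same branching can live in a single helper so every call site behaves identically. This is a minimal sketch, assuming the dict keys `'probabilities'` and `'confidence'` and the 0.7 fallback from the orchestrator fix; `normalize_prediction` and `DEFAULT_CONFIDENCE` are hypothetical names. One deliberate difference from the inline version: `None` maps to a confidence of 0.0 rather than 0.7, so callers can recognize and skip an empty prediction.

```python
from typing import Any, Tuple

# Fallback used by the orchestrator fix when a model does not report confidence
DEFAULT_CONFIDENCE = 0.7


def normalize_prediction(result: Any) -> Tuple[Any, Any]:
    """Coerce a model's predict() output into (action_probs, confidence).

    Hypothetical helper mirroring the inline branches in core/orchestrator.py,
    except that None is mapped to confidence 0.0 (a design choice).
    """
    if result is None:
        # No prediction at all: report zero confidence so callers can skip it
        return None, 0.0
    if isinstance(result, tuple) and len(result) == 2:
        return result[0], result[1]
    if isinstance(result, dict):
        return result.get("probabilities"), result.get("confidence", DEFAULT_CONFIDENCE)
    # Bare value (including tuples of other lengths): treat it as the probabilities
    return result, DEFAULT_CONFIDENCE


print(normalize_prediction(([0.2, 0.3, 0.5], 0.9)))
print(normalize_prediction({"probabilities": [0.1, 0.9]}))
print(normalize_prediction([0.4, 0.6]))
print(normalize_prediction(None))
```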
## 📊 Technical Details

### Root Causes

1. **Tensor Dimension Mismatch**: DQN models sometimes output 1D Q-value tensors where the confidence code expects 2D.
2. **Memory Layout Issues**: `.view()` only works when the requested shape is compatible with the tensor's existing strides, which non-contiguous tensors often are not; `.reshape()` copies when a view is impossible.
3. **API Inconsistency**: Different models return predictions in different formats.
### Best Practices Applied

- **Defensive Programming**: Check tensor dimensions before operations that depend on them.
- **Layout Flexibility**: Prefer `.reshape()` over `.view()` when the input may be non-contiguous, accepting a possible copy.
- **API Robustness**: Handle multiple return formats gracefully at the call site.
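Handling several formats at each call site works, but the underlying root cause is the inconsistency itself. One option, offered as an assumption rather than something this fix implements, is for every model's `predict()` to return the same small type. `Prediction` and `ToyModel` are hypothetical names used only for illustration.

```python
from typing import List, NamedTuple


class Prediction(NamedTuple):
    """Hypothetical common return type for every model's predict()."""
    probabilities: List[float]
    confidence: float


class ToyModel:
    """Stand-in model used only to show the contract."""

    def predict(self, feature_matrix: List[List[float]]) -> Prediction:
        # Placeholder output: uniform over three actions
        probs = [1 / 3, 1 / 3, 1 / 3]
        return Prediction(probabilities=probs, confidence=max(probs))


pred = ToyModel().predict([[0.0, 1.0]])
# A NamedTuple is still a 2-tuple, so existing unpacking code keeps working
action_probs, confidence = pred
print(pred.confidence, len(action_probs))
```

Because a `NamedTuple` is a tuple of length 2, it also passes the `isinstance(prediction_result, tuple) and len(prediction_result) == 2` branch in the orchestrator, so models could migrate one at a time.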
## 🎯 Expected Results

After these fixes:

- ✅ DQN predictions should work without softmax errors
- ✅ CNN predictions should work without view/stride errors
- ✅ Generic model predictions should work without unpacking errors
- ✅ Orchestrator should generate proper trading decisions
## 🔄 Testing Recommendations

1. **Run Dashboard**: Test that predictions are generated successfully.
2. **Monitor Logs**: Check for a reduction in the softmax, view, and unpacking errors listed in the Issue Summary.
3. **Verify Trading Signals**: Ensure BUY/SELL/HOLD decisions are made.
4. **Performance Check**: Confirm no significant performance degradation, since `.reshape()` copies non-contiguous tensors.
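For the log-monitoring step, a small script can count how often each of the three error messages appears in logs captured before and after the fixes. This is a minimal sketch, assuming plain-text logs in which the message text appears verbatim on a single line; the file names in the usage comment are hypothetical.

```python
import sys
from collections import Counter
from typing import Iterable

# Substrings taken from the three error messages in the Issue Summary
ERROR_PATTERNS = {
    "softmax": "softmax() received an invalid combination of arguments",
    "view": "view size is not compatible with input tensor's size and stride",
    "unpacking": "cannot unpack non-iterable NoneType object",
}


def count_tensor_errors(lines: Iterable[str]) -> Counter:
    """Count log lines matching each known error pattern (zero counts included)."""
    counts: Counter = Counter({name: 0 for name in ERROR_PATTERNS})
    for line in lines:
        for name, pattern in ERROR_PATTERNS.items():
            if pattern in line:
                counts[name] += 1
    return counts


if __name__ == "__main__":
    # Usage (hypothetical file names): python count_errors.py before.log after.log
    for path in sys.argv[1:]:
        with open(path, encoding="utf-8", errors="replace") as f:
            print(path, dict(count_tensor_errors(f)))
```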
## 📝 Notes

- Some linter errors remain, but they relate to missing attributes, not tensor operations.
- The core tensor operation issues have been resolved.
- Models should now make predictions without crashing the orchestrator.