training fixes
This commit is contained in:
@@ -950,8 +950,9 @@ class TradingTransformerTrainer:
|
||||
self.model.train()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
# Move batch to device
|
||||
batch = {k: v.to(self.device) for k, v in batch.items()}
|
||||
# Clone and detach batch tensors before moving to device to avoid in-place operation issues
|
||||
# This ensures each batch is independent and prevents gradient computation errors
|
||||
batch = {k: v.detach().clone().to(self.device) for k, v in batch.items()}
|
||||
|
||||
# Forward pass
|
||||
outputs = self.model(
|
||||
@@ -965,24 +966,29 @@ class TradingTransformerTrainer:
|
||||
action_loss = self.action_criterion(outputs['action_logits'], batch['actions'])
|
||||
price_loss = self.price_criterion(outputs['price_prediction'], batch['future_prices'])
|
||||
|
||||
# Start with base losses - avoid inplace operations on computation graph
|
||||
total_loss = action_loss + 0.1 * price_loss # Weight auxiliary task
|
||||
|
||||
# Add confidence loss if available
|
||||
if 'confidence' in outputs and 'trade_success' in batch:
|
||||
# Ensure both tensors have compatible shapes
|
||||
# confidence: [batch_size, 1] -> squeeze last dim to [batch_size]
|
||||
# trade_success: [batch_size] - ensure same shape
|
||||
confidence_pred = outputs['confidence'].squeeze(-1) # Only remove last dimension
|
||||
# Ensure both tensors have compatible shapes for BCELoss
|
||||
# BCELoss requires both inputs to have the same shape
|
||||
confidence_pred = outputs['confidence'] # Keep as [batch_size, 1]
|
||||
trade_target = batch['trade_success'].float()
|
||||
|
||||
# Ensure shapes match (handle edge case where batch_size=1)
|
||||
if confidence_pred.dim() == 0: # scalar case
|
||||
confidence_pred = confidence_pred.unsqueeze(0)
|
||||
if trade_target.dim() == 0: # scalar case
|
||||
trade_target = trade_target.unsqueeze(0)
|
||||
# Reshape target to match prediction shape [batch_size, 1]
|
||||
if trade_target.dim() == 1:
|
||||
trade_target = trade_target.unsqueeze(-1)
|
||||
|
||||
# Ensure both have same shape
|
||||
if confidence_pred.shape != trade_target.shape:
|
||||
# If shapes still don't match, squeeze both to 1D
|
||||
confidence_pred = confidence_pred.view(-1)
|
||||
trade_target = trade_target.view(-1)
|
||||
|
||||
confidence_loss = self.confidence_criterion(confidence_pred, trade_target)
|
||||
total_loss += 0.1 * confidence_loss
|
||||
# Use addition instead of += to avoid inplace operation
|
||||
total_loss = total_loss + 0.1 * confidence_loss
|
||||
|
||||
# Backward pass
|
||||
total_loss.backward()
|
||||
|
||||
Reference in New Issue
Block a user