fetching data from the DB to train
This commit is contained in:
@@ -977,14 +977,17 @@ class TradingTransformerTrainer:
|
||||
confidence_pred = outputs['confidence']
|
||||
trade_target = batch['trade_success'].float()
|
||||
|
||||
# Verify shapes match (should both be [batch_size, 1])
|
||||
# Ensure both have shape [batch_size, 1] for BCELoss
|
||||
# BCELoss requires exact shape match
|
||||
if trade_target.dim() == 1:
|
||||
trade_target = trade_target.unsqueeze(-1)
|
||||
if confidence_pred.dim() == 1:
|
||||
confidence_pred = confidence_pred.unsqueeze(-1)
|
||||
|
||||
# Final shape verification
|
||||
if confidence_pred.shape != trade_target.shape:
|
||||
logger.warning(f"Shape mismatch: confidence {confidence_pred.shape} vs target {trade_target.shape}")
|
||||
# Reshape to match if needed
|
||||
if trade_target.dim() == 1:
|
||||
trade_target = trade_target.unsqueeze(-1)
|
||||
if confidence_pred.dim() == 1:
|
||||
confidence_pred = confidence_pred.unsqueeze(-1)
|
||||
# Force reshape to match
|
||||
trade_target = trade_target.view(confidence_pred.shape)
|
||||
|
||||
confidence_loss = self.confidence_criterion(confidence_pred, trade_target)
|
||||
# Use addition instead of += to avoid inplace operation
|
||||
|
||||
Reference in New Issue
Block a user