load market data for training/inference
This commit is contained in:
@@ -971,20 +971,20 @@ class TradingTransformerTrainer:
|
||||
|
||||
# Add confidence loss if available
|
||||
if 'confidence' in outputs and 'trade_success' in batch:
|
||||
# Ensure both tensors have compatible shapes for BCELoss
|
||||
# BCELoss requires both inputs to have the same shape
|
||||
confidence_pred = outputs['confidence'] # Keep as [batch_size, 1]
|
||||
# Both tensors should have shape [batch_size, 1]
|
||||
# confidence: [batch_size, 1] from confidence_head
|
||||
# trade_success: [batch_size, 1] from batch preparation
|
||||
confidence_pred = outputs['confidence']
|
||||
trade_target = batch['trade_success'].float()
|
||||
|
||||
# Reshape target to match prediction shape [batch_size, 1]
|
||||
if trade_target.dim() == 1:
|
||||
trade_target = trade_target.unsqueeze(-1)
|
||||
|
||||
# Ensure both have same shape
|
||||
# Verify shapes match (should both be [batch_size, 1])
|
||||
if confidence_pred.shape != trade_target.shape:
|
||||
# If shapes still don't match, squeeze both to 1D
|
||||
confidence_pred = confidence_pred.view(-1)
|
||||
trade_target = trade_target.view(-1)
|
||||
logger.warning(f"Shape mismatch: confidence {confidence_pred.shape} vs target {trade_target.shape}")
|
||||
# Reshape to match if needed
|
||||
if trade_target.dim() == 1:
|
||||
trade_target = trade_target.unsqueeze(-1)
|
||||
if confidence_pred.dim() == 1:
|
||||
confidence_pred = confidence_pred.unsqueeze(-1)
|
||||
|
||||
confidence_loss = self.confidence_criterion(confidence_pred, trade_target)
|
||||
# Use addition instead of += to avoid inplace operation
|
||||
|
||||
Reference in New Issue
Block a user