fetching data from the DB to train

This commit is contained in:
Dobromir Popov
2025-10-31 03:14:35 +02:00
parent 07150fd019
commit 6ac324289c
6 changed files with 1113 additions and 46 deletions

View File

@@ -977,14 +977,17 @@ class TradingTransformerTrainer:
confidence_pred = outputs['confidence']
trade_target = batch['trade_success'].float()
# Verify shapes match (should both be [batch_size, 1])
# Ensure both have shape [batch_size, 1] for BCELoss
# BCELoss requires exact shape match
if trade_target.dim() == 1:
trade_target = trade_target.unsqueeze(-1)
if confidence_pred.dim() == 1:
confidence_pred = confidence_pred.unsqueeze(-1)
# Final shape verification
if confidence_pred.shape != trade_target.shape:
logger.warning(f"Shape mismatch: confidence {confidence_pred.shape} vs target {trade_target.shape}")
# Reshape to match if needed
if trade_target.dim() == 1:
trade_target = trade_target.unsqueeze(-1)
if confidence_pred.dim() == 1:
confidence_pred = confidence_pred.unsqueeze(-1)
# Force reshape to match
trade_target = trade_target.view(confidence_pred.shape)
confidence_loss = self.confidence_criterion(confidence_pred, trade_target)
# Use addition instead of += to avoid inplace operation