From cefd30d2bdda84f284f15cae73339192f22bec4e Mon Sep 17 00:00:00 2001 From: Dobromir Popov Date: Fri, 31 Oct 2025 01:29:05 +0200 Subject: [PATCH] training fixes --- ANNOTATE/core/real_training_adapter.py | 61 +++++++++++++++++++---- NN/models/advanced_transformer_trading.py | 30 ++++++----- 2 files changed, 68 insertions(+), 23 deletions(-) diff --git a/ANNOTATE/core/real_training_adapter.py b/ANNOTATE/core/real_training_adapter.py index 825d240..8003a98 100644 --- a/ANNOTATE/core/real_training_adapter.py +++ b/ANNOTATE/core/real_training_adapter.py @@ -441,12 +441,19 @@ class RealTrainingAdapter: logger.debug(" No holding period, skipping HOLD samples") return hold_samples - # Parse entry timestamp + # Parse entry timestamp - handle multiple formats try: if 'T' in entry_timestamp: entry_time = datetime.fromisoformat(entry_timestamp.replace('Z', '+00:00')) else: - entry_time = datetime.strptime(entry_timestamp, '%Y-%m-%d %H:%M:%S') + # Try with seconds first, then without + try: + entry_time = datetime.strptime(entry_timestamp, '%Y-%m-%d %H:%M:%S') + except ValueError: + # Try without seconds + entry_time = datetime.strptime(entry_timestamp, '%Y-%m-%d %H:%M') + + # Make timezone-aware if pytz: entry_time = entry_time.replace(tzinfo=pytz.UTC) else: @@ -466,7 +473,21 @@ class RealTrainingAdapter: # Find all candles between entry and exit for idx, ts_str in enumerate(timestamps): - ts = datetime.fromisoformat(ts_str.replace(' ', 'T')) + # Parse timestamp and ensure it's timezone-aware + try: + if 'T' in ts_str: + ts = datetime.fromisoformat(ts_str.replace('Z', '+00:00')) + else: + ts = datetime.fromisoformat(ts_str.replace(' ', 'T')) + # Make timezone-aware if it's naive + if ts.tzinfo is None: + if pytz: + ts = ts.replace(tzinfo=pytz.UTC) + else: + ts = ts.replace(tzinfo=timezone.utc) + except Exception as e: + logger.debug(f"Could not parse timestamp '{ts_str}': {e}") + continue # If this candle is between entry and exit (exclusive) if entry_time < ts < exit_time: @@ -534,7 +555,14 @@ class RealTrainingAdapter: if 'T' in signal_timestamp: signal_time = datetime.fromisoformat(signal_timestamp.replace('Z', '+00:00')) else: - signal_time = datetime.strptime(signal_timestamp, '%Y-%m-%d %H:%M:%S') + # Try with seconds first, then without + try: + signal_time = datetime.strptime(signal_timestamp, '%Y-%m-%d %H:%M:%S') + except ValueError: + # Try without seconds + signal_time = datetime.strptime(signal_timestamp, '%Y-%m-%d %H:%M') + + # Make timezone-aware if pytz: signal_time = signal_time.replace(tzinfo=pytz.UTC) else: @@ -546,15 +574,22 @@ class RealTrainingAdapter: signal_index = None for idx, ts_str in enumerate(timestamps): try: - # Parse timestamp from market data + # Parse timestamp from market data - handle multiple formats if 'T' in ts_str: ts = datetime.fromisoformat(ts_str.replace('Z', '+00:00')) else: - ts = datetime.strptime(ts_str, '%Y-%m-%d %H:%M:%S') - if pytz: - ts = ts.replace(tzinfo=pytz.UTC) - else: - ts = ts.replace(tzinfo=timezone.utc) + # Try with seconds first, then without + try: + ts = datetime.strptime(ts_str, '%Y-%m-%d %H:%M:%S') + except ValueError: + ts = datetime.strptime(ts_str, '%Y-%m-%d %H:%M') + + # Make timezone-aware if naive + if ts.tzinfo is None: + if pytz: + ts = ts.replace(tzinfo=pytz.UTC) + else: + ts = ts.replace(tzinfo=timezone.utc) # Match within 1 minute if abs((ts - signal_time).total_seconds()) < 60: @@ -1166,9 +1201,13 @@ class RealTrainingAdapter: batch = self._convert_annotation_to_transformer_batch(data) if batch is not None: # Repeat based on repetitions parameter + # IMPORTANT: Clone each batch to avoid in-place operation issues when reusing tensors repetitions = data.get('repetitions', 1) for _ in range(repetitions): - converted_batches.append(batch) + # Clone all tensors in the batch to ensure independence + cloned_batch = {k: v.clone() if isinstance(v, torch.Tensor) else v + for k, v in batch.items()} + converted_batches.append(cloned_batch) else: logger.warning(f" Failed to convert sample {i+1}") diff --git a/NN/models/advanced_transformer_trading.py b/NN/models/advanced_transformer_trading.py index c1d3b64..4bdb97d 100644 --- a/NN/models/advanced_transformer_trading.py +++ b/NN/models/advanced_transformer_trading.py @@ -950,8 +950,9 @@ class TradingTransformerTrainer: self.model.train() self.optimizer.zero_grad() - # Move batch to device - batch = {k: v.to(self.device) for k, v in batch.items()} + # Clone and detach batch tensors before moving to device to avoid in-place operation issues + # This ensures each batch is independent and prevents gradient computation errors + batch = {k: v.detach().clone().to(self.device) for k, v in batch.items()} # Forward pass outputs = self.model( @@ -965,24 +966,29 @@ class TradingTransformerTrainer: action_loss = self.action_criterion(outputs['action_logits'], batch['actions']) price_loss = self.price_criterion(outputs['price_prediction'], batch['future_prices']) + # Start with base losses - avoid inplace operations on computation graph total_loss = action_loss + 0.1 * price_loss # Weight auxiliary task # Add confidence loss if available if 'confidence' in outputs and 'trade_success' in batch: - # Ensure both tensors have compatible shapes - # confidence: [batch_size, 1] -> squeeze last dim to [batch_size] - # trade_success: [batch_size] - ensure same shape - confidence_pred = outputs['confidence'].squeeze(-1) # Only remove last dimension + # Ensure both tensors have compatible shapes for BCELoss + # BCELoss requires both inputs to have the same shape + confidence_pred = outputs['confidence'] # Keep as [batch_size, 1] trade_target = batch['trade_success'].float() - # Ensure shapes match (handle edge case where batch_size=1) - if confidence_pred.dim() == 0: # scalar case - confidence_pred = confidence_pred.unsqueeze(0) - if trade_target.dim() == 0: # scalar case - trade_target = trade_target.unsqueeze(0) + # Reshape target to match prediction shape [batch_size, 1] + if trade_target.dim() == 1: + trade_target = trade_target.unsqueeze(-1) + + # Ensure both have same shape + if confidence_pred.shape != trade_target.shape: + # If shapes still don't match, squeeze both to 1D + confidence_pred = confidence_pred.view(-1) + trade_target = trade_target.view(-1) confidence_loss = self.confidence_criterion(confidence_pred, trade_target) - total_loss += 0.1 * confidence_loss + # Use addition instead of += to avoid inplace operation + total_loss = total_loss + 0.1 * confidence_loss # Backward pass total_loss.backward()