From 600bee98f3d91e00c4019c8fe79b3c58f00264e0 Mon Sep 17 00:00:00 2001
From: Dobromir Popov
Date: Mon, 8 Dec 2025 22:09:43 +0200
Subject: [PATCH] training fixes

---
 ANNOTATE/core/real_training_adapter.py    | 25 +++++---
 NN/models/advanced_transformer_trading.py | 77 +++++++++++++++--------
 2 files changed, 69 insertions(+), 33 deletions(-)

diff --git a/ANNOTATE/core/real_training_adapter.py b/ANNOTATE/core/real_training_adapter.py
index 783b119..eab4cfa 100644
--- a/ANNOTATE/core/real_training_adapter.py
+++ b/ANNOTATE/core/real_training_adapter.py
@@ -160,6 +160,12 @@ class RealTrainingAdapter:
         self.data_provider = data_provider
         self.training_sessions: Dict[str, TrainingSession] = {}

+        # CRITICAL: Training lock to prevent concurrent model access
+        # Multiple threads (batch training + per-candle training) can corrupt
+        # the computation graph if they access the model simultaneously
+        import threading
+        self._training_lock = threading.Lock()
+
         # Real-time training tracking
         self.realtime_training_metrics = {
             'total_steps': 0,
@@ -2614,9 +2620,12 @@ class RealTrainingAdapter:
                     symbol = batch.get('metadata', {}).get('symbol', 'ETH/USDT')
                     self._store_training_prediction(batch, trainer, symbol)

-                # Call the trainer's train_step method with mini-batch
-                # Batch is already on GPU and contains multiple samples
-                result = trainer.train_step(batch, accumulate_gradients=False)
+                # CRITICAL: Acquire training lock to prevent concurrent model access
+                # This prevents "inplace operation" errors when per-candle training runs simultaneously
+                with self._training_lock:
+                    # Call the trainer's train_step method with mini-batch
+                    # Batch is already on GPU and contains multiple samples
+                    result = trainer.train_step(batch, accumulate_gradients=False)

                 if result is not None:
                     # MEMORY FIX: Detach all tensor values to break computation graph
@@ -3574,11 +3583,13 @@ class RealTrainingAdapter:
                 logger.warning(f"Per-candle training skipped: Missing required keys: {missing_keys}")
                 return

-            # Train on this batch
+            # CRITICAL: Acquire training lock to prevent concurrent model access
+            # This prevents "inplace operation" errors when batch training runs simultaneously
             import torch
-            with torch.enable_grad():
-                trainer.model.train()
-                result = trainer.train_step(batch, accumulate_gradients=False)
+            with self._training_lock:
+                with torch.enable_grad():
+                    trainer.model.train()
+                    result = trainer.train_step(batch, accumulate_gradients=False)

             if result:
                 loss = result.get('total_loss', 0)
diff --git a/NN/models/advanced_transformer_trading.py b/NN/models/advanced_transformer_trading.py
index d09f8c5..2df3182 100644
--- a/NN/models/advanced_transformer_trading.py
+++ b/NN/models/advanced_transformer_trading.py
@@ -148,8 +148,11 @@ class DeepMultiScaleAttention(nn.Module):
         x_input = x.clone()

         for scale_proj in self.scale_projections:
+            # CRITICAL: Clone input for each scale to ensure complete isolation
+            # This prevents any potential view issues when reusing x_input across scales
+            x_scale_input = x_input.clone()
             # Apply enhanced temporal convolution for this scale
-            x_conv = scale_proj['conv'](x_input.transpose(1, 2)).transpose(1, 2)
+            x_conv = scale_proj['conv'](x_scale_input.transpose(1, 2)).transpose(1, 2)

             # Enhanced attention computation with deeper projections
             # Use contiguous() before view() to ensure memory layout is correct
@@ -315,6 +318,10 @@ class TradingTransformerLayer(nn.Module):
         regime_probs = None
         if hasattr(self, 'regime_detector'):
             x_new, regime_probs = self.regime_detector(x_new)
+            # CRITICAL: Clone after regime detector to ensure fresh tensor
+            # The regime detector may return a tensor that's a view or has been through operations
+            # that create views, which can cause inplace operation errors
+            x_new = x_new.clone()

         # Feed-forward with residual connection - clone to avoid version conflicts
         x_ff_residual = x_new.clone()
@@ -609,14 +616,18 @@ class AdvancedTradingTransformer(nn.Module):
             # MEMORY EFFICIENT: Process timeframes with shared weights
             # Reshape to process all timeframes in parallel: [batch*num_tfs, seq_len, d_model]
             # This avoids creating huge concatenated sequences while still processing efficiently
-            batched_tfs = stacked_tfs.reshape(batch_size * num_tfs, seq_len, self.config.d_model)
+            # CRITICAL: Clone and make contiguous to ensure proper memory layout and avoid inplace operation errors
+            # The reshape creates a view, so we must clone to get a fresh tensor for TransformerEncoderLayer
+            batched_tfs = stacked_tfs.reshape(batch_size * num_tfs, seq_len, self.config.d_model).contiguous().clone()

             # Apply single cross-timeframe attention layer
-            # Use new variable to avoid inplace modification issues
+            # TransformerEncoderLayer may use inplace operations internally, so we need a fresh tensor
            cross_tf_encoded = self.cross_timeframe_layer(batched_tfs)

             # Reshape back: [batch*num_tfs, seq_len, d_model] -> [batch, num_tfs, seq_len, d_model]
-            cross_tf_output = cross_tf_encoded.reshape(batch_size, num_tfs, seq_len, self.config.d_model)
+            # CRITICAL: Clone after reshape to ensure fresh tensor (reshape creates a view)
+            # This prevents inplace operation errors when the tensor is used in subsequent operations
+            cross_tf_output = cross_tf_encoded.reshape(batch_size, num_tfs, seq_len, self.config.d_model).contiguous().clone()

             # Average across timeframes to get unified representation
             # [batch, num_tfs, seq_len, d_model] -> [batch, seq_len, d_model]
@@ -826,29 +837,43 @@ class AdvancedTradingTransformer(nn.Module):
             angle_threshold = math.pi / 4  # 45 degrees

             # Determine action from trend angle
-            trend_action_logits = torch.zeros(batch_size, 3, device=pooled.device)  # [BUY, SELL, HOLD]
+            # CRITICAL FIX: Build tensor using non-inplace operations to avoid gradient computation errors
+            # Instead of creating zeros and modifying them in place, build per-action tensors and stack them

-            # Calculate action probabilities based on trend
-            for i in range(batch_size):
-                # Handle both 0-dim and 1-dim tensors
-                if trend_angle.dim() == 0:
-                    angle = trend_angle.item()
-                    steep = trend_steepness_val.item()
-                else:
-                    angle = trend_angle[i].item()
-                    steep = trend_steepness_val[i].item()
-
-                # Normalize steepness to [0, 1] range (assuming max steepness of 10 units)
-                normalized_steepness = min(steep / 10.0, 1.0) if steep > 0 else 0.0
-
-                if angle > angle_threshold:  # Steep upward trend
-                    trend_action_logits[i, 0] = normalized_steepness * 2.0  # BUY
-                    trend_action_logits[i, 2] = (1.0 - normalized_steepness) * 0.5  # HOLD
-                elif angle < -angle_threshold:  # Steep downward trend
-                    trend_action_logits[i, 1] = normalized_steepness * 2.0  # SELL
-                    trend_action_logits[i, 2] = (1.0 - normalized_steepness) * 0.5  # HOLD
-                else:  # Shallow trend
-                    trend_action_logits[i, 2] = 1.0  # HOLD
+            # Handle both 0-dim and 1-dim tensors for trend_angle and trend_steepness_val
+            if trend_angle.dim() == 0:
+                angles = trend_angle.unsqueeze(0).expand(batch_size)
+                steeps = trend_steepness_val.unsqueeze(0).expand(batch_size)
+            else:
+                angles = trend_angle if trend_angle.dim() > 0 else trend_angle.unsqueeze(0)
+                steeps = trend_steepness_val if trend_steepness_val.dim() > 0 else trend_steepness_val.unsqueeze(0)
+                # Ensure correct shape
+                if angles.shape[0] != batch_size:
+                    angles = angles.expand(batch_size) if angles.numel() == 1 else angles[:batch_size]
+                if steeps.shape[0] != batch_size:
+                    steeps = steeps.expand(batch_size) if steeps.numel() == 1 else steeps[:batch_size]
+
+            # Normalize steepness to [0, 1] range (assuming max steepness of 10 units)
+            normalized_steepness = torch.clamp(steeps / 10.0, 0.0, 1.0)
+
+            # Build logits using vectorized operations (no inplace modifications)
+            # Initialize with HOLD (index 2)
+            buy_logits = torch.zeros(batch_size, device=pooled.device)
+            sell_logits = torch.zeros(batch_size, device=pooled.device)
+            hold_logits = torch.ones(batch_size, device=pooled.device)
+
+            # Steep upward trend -> BUY
+            upward_mask = angles > angle_threshold
+            buy_logits = torch.where(upward_mask, normalized_steepness * 2.0, buy_logits)
+            hold_logits = torch.where(upward_mask, (1.0 - normalized_steepness) * 0.5, hold_logits)
+
+            # Steep downward trend -> SELL
+            downward_mask = angles < -angle_threshold
+            sell_logits = torch.where(downward_mask, normalized_steepness * 2.0, sell_logits)
+            hold_logits = torch.where(downward_mask, (1.0 - normalized_steepness) * 0.5, hold_logits)
+
+            # Stack into final logits tensor [batch_size, 3]
+            trend_action_logits = torch.stack([buy_logits, sell_logits, hold_logits], dim=1)

             # Combine trend-based action with main action prediction
             trend_action_probs = F.softmax(trend_action_logits, dim=-1)
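
For illustration, a minimal standalone sketch of the locking pattern the first file adopts: one threading.Lock shared by the batch-training and per-candle paths, so only one thread can call train_step (and therefore touch the model's parameters and computation graph) at a time. ToyTrainer and its train_step below are hypothetical stand-ins, not the repository's classes.

    # Sketch only (hypothetical ToyTrainer): serialize train_step calls from two
    # threads with one shared lock, mirroring RealTrainingAdapter._training_lock.
    import threading
    import time

    class ToyTrainer:
        """Stand-in trainer whose train_step must not run concurrently."""
        def __init__(self):
            self.steps = 0

        def train_step(self, batch):
            # Non-atomic read-modify-write of shared state, like a model update
            current = self.steps
            time.sleep(0.001)  # widen the race window to make corruption visible
            self.steps = current + 1
            return {"total_loss": 0.0}

    trainer = ToyTrainer()
    training_lock = threading.Lock()  # shared by both training paths

    def batch_training(batches):
        for batch in batches:
            with training_lock:  # only one thread touches the trainer at a time
                trainer.train_step(batch)

    def per_candle_training(candles):
        for candle in candles:
            with training_lock:
                trainer.train_step(candle)

    t1 = threading.Thread(target=batch_training, args=([{}] * 50,))
    t2 = threading.Thread(target=per_candle_training, args=([{}] * 50,))
    t1.start(); t2.start(); t1.join(); t2.join()
    print(trainer.steps)  # 100 with the lock; typically fewer without it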
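Similarly, a small sketch of the vectorized logits construction in the second file, under the same assumptions the patch makes (45-degree angle threshold, steepness normalized by 10 units): per-action tensors are built with torch.where and combined with torch.stack, so nothing that participates in autograd is written to in place. The trend_logits helper is illustrative only; the real code lives inside AdvancedTradingTransformer.forward.

    # Sketch only: build [batch, 3] trend logits (BUY, SELL, HOLD) without
    # in-place indexed assignment, so backward() does not hit version errors.
    import math
    import torch

    def trend_logits(angles: torch.Tensor, steeps: torch.Tensor) -> torch.Tensor:
        angle_threshold = math.pi / 4                 # 45 degrees
        steep = torch.clamp(steeps / 10.0, 0.0, 1.0)  # normalize steepness to [0, 1]

        buy = torch.zeros_like(angles)
        sell = torch.zeros_like(angles)
        hold = torch.ones_like(angles)                # default action: HOLD

        up = angles > angle_threshold                 # steep upward trend -> BUY
        buy = torch.where(up, steep * 2.0, buy)
        hold = torch.where(up, (1.0 - steep) * 0.5, hold)

        down = angles < -angle_threshold              # steep downward trend -> SELL
        sell = torch.where(down, steep * 2.0, sell)
        hold = torch.where(down, (1.0 - steep) * 0.5, hold)

        return torch.stack([buy, sell, hold], dim=1)  # [batch, 3]

    angles = torch.tensor([1.0, -1.0, 0.1])
    steeps = torch.tensor([8.0, 3.0, 1.0], requires_grad=True)
    logits = trend_logits(angles, steeps)
    logits.sum().backward()  # gradients flow; no in-place modification errors
    print(logits)
    print(steeps.grad)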