training fixes
@@ -160,6 +160,12 @@ class RealTrainingAdapter:
         self.data_provider = data_provider
         self.training_sessions: Dict[str, TrainingSession] = {}
 
+        # CRITICAL: Training lock to prevent concurrent model access
+        # Multiple threads (batch training + per-candle training) can corrupt
+        # the computation graph if they access the model simultaneously
+        import threading
+        self._training_lock = threading.Lock()
+
         # Real-time training tracking
         self.realtime_training_metrics = {
             'total_steps': 0,
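The lock above is the core of the fix: every code path that drives autograd on the shared model must acquire it before calling train_step. A minimal, self-contained sketch of the pattern (the class, batch contents, and thread setup here are illustrative, not code from this repository):

import threading

class SharedModelTrainer:
    """Hypothetical trainer whose train_step must never run concurrently."""

    def __init__(self):
        # Single lock shared by every code path that touches the model
        self._training_lock = threading.Lock()

    def train_step(self, batch):
        # Placeholder for forward/backward/optimizer work on a shared model
        return {"total_loss": 0.0}

    def guarded_train_step(self, batch):
        # Both batch training and per-candle training go through this lock,
        # so only one thread can run backward on the model at a time.
        with self._training_lock:
            return self.train_step(batch)

trainer = SharedModelTrainer()
threads = [
    threading.Thread(target=trainer.guarded_train_step, args=({"source": "batch"},)),
    threading.Thread(target=trainer.guarded_train_step, args=({"source": "per_candle"},)),
]
for t in threads:
    t.start()
for t in threads:
    t.join()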
@@ -2614,9 +2620,12 @@ class RealTrainingAdapter:
                 symbol = batch.get('metadata', {}).get('symbol', 'ETH/USDT')
                 self._store_training_prediction(batch, trainer, symbol)
 
-            # Call the trainer's train_step method with mini-batch
-            # Batch is already on GPU and contains multiple samples
-            result = trainer.train_step(batch, accumulate_gradients=False)
+            # CRITICAL: Acquire training lock to prevent concurrent model access
+            # This prevents "inplace operation" errors when per-candle training runs simultaneously
+            with self._training_lock:
+                # Call the trainer's train_step method with mini-batch
+                # Batch is already on GPU and contains multiple samples
+                result = trainer.train_step(batch, accumulate_gradients=False)
 
             if result is not None:
                 # MEMORY FIX: Detach all tensor values to break computation graph
@@ -3574,11 +3583,13 @@ class RealTrainingAdapter:
                 logger.warning(f"Per-candle training skipped: Missing required keys: {missing_keys}")
                 return
 
-            # Train on this batch
+            # CRITICAL: Acquire training lock to prevent concurrent model access
+            # This prevents "inplace operation" errors when batch training runs simultaneously
             import torch
-            with torch.enable_grad():
-                trainer.model.train()
-                result = trainer.train_step(batch, accumulate_gradients=False)
+            with self._training_lock:
+                with torch.enable_grad():
+                    trainer.model.train()
+                    result = trainer.train_step(batch, accumulate_gradients=False)
 
             if result:
                 loss = result.get('total_loss', 0)
@@ -148,8 +148,11 @@ class DeepMultiScaleAttention(nn.Module):
         x_input = x.clone()
 
         for scale_proj in self.scale_projections:
+            # CRITICAL: Clone input for each scale to ensure complete isolation
+            # This prevents any potential view issues when reusing x_input across scales
+            x_scale_input = x_input.clone()
             # Apply enhanced temporal convolution for this scale
-            x_conv = scale_proj['conv'](x_input.transpose(1, 2)).transpose(1, 2)
+            x_conv = scale_proj['conv'](x_scale_input.transpose(1, 2)).transpose(1, 2)
 
             # Enhanced attention computation with deeper projections
             # Use contiguous() before view() to ensure memory layout is correct
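The per-scale clone above (and the clones in the next hunks) guards against PyTorch's saved-tensor version check: if a tensor that autograd saved for backward is later modified in place, backward raises the "inplace operation" error this commit keeps referring to. A tiny generic repro and the clone-based workaround (not code from this repository):

import torch

x = torch.ones(3, requires_grad=True)
y = x * 2
z = (y * y).sum()   # multiplication saves y for its backward pass

y.add_(1)           # inplace edit bumps y's version counter
try:
    z.backward()
except RuntimeError as e:
    print(e)        # "... modified by an inplace operation ..."

# Doing the inplace work on a clone leaves the saved tensor untouched
x = torch.ones(3, requires_grad=True)
y = x * 2
z = (y * y).sum()
y.clone().add_(1)   # the inplace edit happens on a fresh tensor
z.backward()
print(x.grad)       # tensor([8., 8., 8.])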
@@ -315,6 +318,10 @@ class TradingTransformerLayer(nn.Module):
         regime_probs = None
         if hasattr(self, 'regime_detector'):
             x_new, regime_probs = self.regime_detector(x_new)
+            # CRITICAL: Clone after regime detector to ensure fresh tensor
+            # The regime detector may return a tensor that's a view or has been through operations
+            # that create views, which can cause inplace operation errors
+            x_new = x_new.clone()
 
         # Feed-forward with residual connection - clone to avoid version conflicts
         x_ff_residual = x_new.clone()
@@ -609,14 +616,18 @@ class AdvancedTradingTransformer(nn.Module):
             # MEMORY EFFICIENT: Process timeframes with shared weights
             # Reshape to process all timeframes in parallel: [batch*num_tfs, seq_len, d_model]
             # This avoids creating huge concatenated sequences while still processing efficiently
-            batched_tfs = stacked_tfs.reshape(batch_size * num_tfs, seq_len, self.config.d_model)
+            # CRITICAL: Clone and make contiguous to ensure proper memory layout and avoid inplace operation errors
+            # The reshape creates a view, so we must clone to get a fresh tensor for TransformerEncoderLayer
+            batched_tfs = stacked_tfs.reshape(batch_size * num_tfs, seq_len, self.config.d_model).contiguous().clone()
 
             # Apply single cross-timeframe attention layer
-            # Use new variable to avoid inplace modification issues
+            # TransformerEncoderLayer may use inplace operations internally, so we need a fresh tensor
            cross_tf_encoded = self.cross_timeframe_layer(batched_tfs)
 
             # Reshape back: [batch*num_tfs, seq_len, d_model] -> [batch, num_tfs, seq_len, d_model]
-            cross_tf_output = cross_tf_encoded.reshape(batch_size, num_tfs, seq_len, self.config.d_model)
+            # CRITICAL: Clone after reshape to ensure fresh tensor (reshape creates a view)
+            # This prevents inplace operation errors when the tensor is used in subsequent operations
+            cross_tf_output = cross_tf_encoded.reshape(batch_size, num_tfs, seq_len, self.config.d_model).contiguous().clone()
 
             # Average across timeframes to get unified representation
             # [batch, num_tfs, seq_len, d_model] -> [batch, seq_len, d_model]
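For context on why the hunk above appends .contiguous().clone(): reshape on a contiguous tensor returns a view that shares storage with the original, so any later in-place work aliases it, while a clone owns its own memory. A small illustration with made-up shapes, independent of this codebase:

import torch

stacked = torch.zeros(2, 3)

view = stacked.reshape(6)          # view: shares storage with `stacked`
view[0] = 1.0
print(stacked[0, 0])               # tensor(1.) - the original was modified too

fresh = stacked.reshape(6).contiguous().clone()   # fresh tensor with its own storage
fresh[1] = 2.0
print(stacked[0, 1])               # tensor(0.) - original untouched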
@@ -826,29 +837,43 @@ class AdvancedTradingTransformer(nn.Module):
             angle_threshold = math.pi / 4  # 45 degrees
 
             # Determine action from trend angle
-            trend_action_logits = torch.zeros(batch_size, 3, device=pooled.device)  # [BUY, SELL, HOLD]
-
-            # Calculate action probabilities based on trend
-            for i in range(batch_size):
-                # Handle both 0-dim and 1-dim tensors
-                if trend_angle.dim() == 0:
-                    angle = trend_angle.item()
-                    steep = trend_steepness_val.item()
-                else:
-                    angle = trend_angle[i].item()
-                    steep = trend_steepness_val[i].item()
-
-                # Normalize steepness to [0, 1] range (assuming max steepness of 10 units)
-                normalized_steepness = min(steep / 10.0, 1.0) if steep > 0 else 0.0
-
-                if angle > angle_threshold:  # Steep upward trend
-                    trend_action_logits[i, 0] = normalized_steepness * 2.0  # BUY
-                    trend_action_logits[i, 2] = (1.0 - normalized_steepness) * 0.5  # HOLD
-                elif angle < -angle_threshold:  # Steep downward trend
-                    trend_action_logits[i, 1] = normalized_steepness * 2.0  # SELL
-                    trend_action_logits[i, 2] = (1.0 - normalized_steepness) * 0.5  # HOLD
-                else:  # Shallow trend
-                    trend_action_logits[i, 2] = 1.0  # HOLD
+            # CRITICAL FIX: Build tensor using non-inplace operations to avoid gradient computation errors
+            # Instead of creating zeros and modifying inplace, build the tensor element by element
+
+            # Handle both 0-dim and 1-dim tensors for trend_angle and trend_steepness_val
+            if trend_angle.dim() == 0:
+                angles = trend_angle.unsqueeze(0).expand(batch_size)
+                steeps = trend_steepness_val.unsqueeze(0).expand(batch_size)
+            else:
+                angles = trend_angle if trend_angle.dim() > 0 else trend_angle.unsqueeze(0)
+                steeps = trend_steepness_val if trend_steepness_val.dim() > 0 else trend_steepness_val.unsqueeze(0)
+            # Ensure correct shape
+            if angles.shape[0] != batch_size:
+                angles = angles.expand(batch_size) if angles.numel() == 1 else angles[:batch_size]
+            if steeps.shape[0] != batch_size:
+                steeps = steeps.expand(batch_size) if steeps.numel() == 1 else steeps[:batch_size]
+
+            # Normalize steepness to [0, 1] range (assuming max steepness of 10 units)
+            normalized_steepness = torch.clamp(steeps / 10.0, 0.0, 1.0)
+
+            # Build logits using vectorized operations (no inplace modifications)
+            # Initialize with HOLD (index 2)
+            buy_logits = torch.zeros(batch_size, device=pooled.device)
+            sell_logits = torch.zeros(batch_size, device=pooled.device)
+            hold_logits = torch.ones(batch_size, device=pooled.device)
+
+            # Steep upward trend -> BUY
+            upward_mask = angles > angle_threshold
+            buy_logits = torch.where(upward_mask, normalized_steepness * 2.0, buy_logits)
+            hold_logits = torch.where(upward_mask, (1.0 - normalized_steepness) * 0.5, hold_logits)
+
+            # Steep downward trend -> SELL
+            downward_mask = angles < -angle_threshold
+            sell_logits = torch.where(downward_mask, normalized_steepness * 2.0, sell_logits)
+            hold_logits = torch.where(downward_mask, (1.0 - normalized_steepness) * 0.5, hold_logits)
+
+            # Stack into final logits tensor [batch_size, 3]
+            trend_action_logits = torch.stack([buy_logits, sell_logits, hold_logits], dim=1)
 
             # Combine trend-based action with main action prediction
             trend_action_probs = F.softmax(trend_action_logits, dim=-1)
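The rewrite above replaces the per-sample loop, the .item() calls, and the in-place index assignments with a fully vectorized, differentiable construction. A standalone sketch of the same torch.where / torch.stack pattern with hypothetical input values, showing that gradients flow back into the steepness tensor without any inplace-operation error:

import math
import torch

# Hypothetical stand-ins for trend_angle / trend_steepness_val produced by the model
angles = torch.tensor([0.9, -1.1, 0.1])
steeps = torch.tensor([4.0, 8.0, 1.0], requires_grad=True)
angle_threshold = math.pi / 4
batch_size = angles.shape[0]

# Normalize steepness to [0, 1] (same 10-unit assumption as the diff)
normalized_steepness = torch.clamp(steeps / 10.0, 0.0, 1.0)

# Start from HOLD everywhere, then overwrite per sample via torch.where (no inplace indexing)
buy_logits = torch.zeros(batch_size)
sell_logits = torch.zeros(batch_size)
hold_logits = torch.ones(batch_size)

upward_mask = angles > angle_threshold
buy_logits = torch.where(upward_mask, normalized_steepness * 2.0, buy_logits)
hold_logits = torch.where(upward_mask, (1.0 - normalized_steepness) * 0.5, hold_logits)

downward_mask = angles < -angle_threshold
sell_logits = torch.where(downward_mask, normalized_steepness * 2.0, sell_logits)
hold_logits = torch.where(downward_mask, (1.0 - normalized_steepness) * 0.5, hold_logits)

trend_action_logits = torch.stack([buy_logits, sell_logits, hold_logits], dim=1)  # [batch, 3]
trend_action_logits.sum().backward()
print(steeps.grad)  # non-None: the trend head stays differentiable w.r.t. steepness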