Fix realtime training: clone intermediate tensors to avoid in-place operation errors during the backward pass
This commit is contained in:
@@ -219,8 +219,8 @@ class MarketRegimeDetector(nn.Module):
|
||||
regime_weights = regime_probs.unsqueeze(0).unsqueeze(2).unsqueeze(3) # (1, batch, 1, 1, n_regimes)
|
||||
regime_weights = regime_weights.permute(4, 1, 2, 3, 0).squeeze(-1) # (n_regimes, batch, 1, 1)
|
||||
|
||||
- # Weighted sum across regimes
|
||||
- adapted_output = torch.sum(regime_stack * regime_weights, dim=0)
|
||||
+ # Weighted sum across regimes - clone to avoid inplace errors
|
||||
+ adapted_output = torch.sum(regime_stack * regime_weights, dim=0).clone()
|
||||
|
||||
return adapted_output, regime_probs
|
||||
|
||||
@@ -634,8 +634,8 @@ class AdvancedTradingTransformer(nn.Module):
|
||||
else:
|
||||
market_emb = torch.zeros(batch_size, seq_len, self.config.d_model, device=device)
|
||||
|
||||
- # Combine all embeddings
|
||||
- x = price_emb + cob_emb + tech_emb + market_emb
|
||||
+ # Combine all embeddings - use clone() to avoid inplace operation errors
|
||||
+ x = price_emb.clone() + cob_emb + tech_emb + market_emb
|
||||
|
||||
# Add position state if provided - critical for loss minimization and profit taking
|
||||
if position_state is not None:
|
||||
@@ -647,8 +647,7 @@ class AdvancedTradingTransformer(nn.Module):
|
||||
# This conditions the entire sequence on current position state
|
||||
position_emb = position_emb.unsqueeze(1).expand(batch_size, seq_len, -1) # [batch, seq_len, d_model]
|
||||
|
||||
- # Add position embedding to the combined embeddings
|
||||
- # This allows the model to modulate its predictions based on position state
|
||||
+ # Add position embedding to the combined embeddings - create new tensor to avoid inplace
|
||||
x = x + position_emb
|
||||
|
||||
# Add positional encoding
|
||||
@@ -670,7 +669,8 @@ class AdvancedTradingTransformer(nn.Module):
|
||||
else:
|
||||
layer_output = layer(x, mask)
|
||||
|
||||
- x = layer_output['output']
|
||||
+ # Clone to avoid inplace operation errors during backward pass
|
||||
+ x = layer_output['output'].clone()
|
||||
if layer_output['regime_probs'] is not None:
|
||||
regime_probs_history.append(layer_output['regime_probs'])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user