training, local log
This commit is contained in:
@@ -243,10 +243,11 @@ class MarketRegimeDetector(nn.Module):
|
||||
|
||||
# Weighted combination based on regime probabilities
|
||||
regime_stack = torch.stack(regime_outputs, dim=0) # (n_regimes, batch, seq_len, d_model)
|
||||
regime_weights = regime_probs.unsqueeze(1).unsqueeze(3) # (batch, 1, 1, n_regimes)
|
||||
regime_weights = regime_probs.unsqueeze(0).unsqueeze(2).unsqueeze(3) # (1, batch, 1, 1, n_regimes)
|
||||
regime_weights = regime_weights.permute(4, 1, 2, 3, 0).squeeze(-1) # (n_regimes, batch, 1, 1)
|
||||
|
||||
# Weighted sum across regimes
|
||||
adapted_output = torch.sum(regime_stack * regime_weights.transpose(0, 3), dim=0)
|
||||
adapted_output = torch.sum(regime_stack * regime_weights, dim=0)
|
||||
|
||||
return adapted_output, regime_probs
|
||||
|
||||
|
||||
Reference in New Issue
Block a user