training, local log

This commit is contained in:
Dobromir Popov
2025-10-25 16:21:22 +03:00
parent bd213c44e0
commit 5aa4925cff
4 changed files with 128 additions and 36 deletions

View File

@@ -243,10 +243,11 @@ class MarketRegimeDetector(nn.Module):
# Weighted combination based on regime probabilities
regime_stack = torch.stack(regime_outputs, dim=0) # (n_regimes, batch, seq_len, d_model)
regime_weights = regime_probs.unsqueeze(1).unsqueeze(3) # (batch, 1, 1, n_regimes)
regime_weights = regime_probs.unsqueeze(0).unsqueeze(2).unsqueeze(3) # (1, batch, 1, 1, n_regimes)
regime_weights = regime_weights.permute(4, 1, 2, 3, 0).squeeze(-1) # (n_regimes, batch, 1, 1)
# Weighted sum across regimes
adapted_output = torch.sum(regime_stack * regime_weights.transpose(0, 3), dim=0)
adapted_output = torch.sum(regime_stack * regime_weights, dim=0)
return adapted_output, regime_probs