scale up transformer

This commit is contained in:
Dobromir Popov
2025-07-02 01:41:20 +03:00
parent 8645f6e8dd
commit 5eda20acc8
4 changed files with 96 additions and 122 deletions

View File

@ -452,6 +452,14 @@ class AdvancedTradingTransformer(nn.Module):
"""
batch_size, seq_len = price_data.shape[:2]
# Handle different input dimensions - expand to sequence if needed
if cob_data.dim() == 2: # (batch, features) -> (batch, seq_len, features)
cob_data = cob_data.unsqueeze(1).expand(batch_size, seq_len, -1)
if tech_data.dim() == 2: # (batch, features) -> (batch, seq_len, features)
tech_data = tech_data.unsqueeze(1).expand(batch_size, seq_len, -1)
if market_data.dim() == 2: # (batch, features) -> (batch, seq_len, features)
market_data = market_data.unsqueeze(1).expand(batch_size, seq_len, -1)
# Project inputs to model dimension
price_emb = self.price_projection(price_data)
cob_emb = self.cob_projection(cob_data)