T model was only using 1t data so far. fixing WIP

This commit is contained in:
Dobromir Popov
2025-11-06 00:03:19 +02:00
parent 907a7d6224
commit 07d97100c0
2 changed files with 400 additions and 3 deletions

View File

@@ -349,8 +349,17 @@ class AdvancedTradingTransformer(nn.Module):
super().__init__()
self.config = config
# Input projections
self.price_projection = nn.Linear(5, config.d_model) # OHLCV
# Multi-timeframe input projections
# Each timeframe gets its own projection to learn timeframe-specific patterns
self.timeframes = ['1s', '1m', '1h', '1d']
self.price_projections = nn.ModuleDict({
tf: nn.Linear(5, config.d_model) for tf in self.timeframes # OHLCV per timeframe
})
# Reference symbol projection (BTC 1m)
self.btc_projection = nn.Linear(5, config.d_model)
# Other input projections
self.cob_projection = nn.Linear(config.cob_features, config.d_model)
self.tech_projection = nn.Linear(config.tech_features, config.d_model)
self.market_projection = nn.Linear(config.market_features, config.d_model)
@@ -367,6 +376,9 @@ class AdvancedTradingTransformer(nn.Module):
nn.Linear(config.d_model // 2, config.d_model) # 512 -> 1024
)
# Timeframe importance weights (learnable)
self.timeframe_weights = nn.Parameter(torch.ones(len(self.timeframes) + 1)) # +1 for BTC
# Positional encoding
if config.use_relative_position:
self.pos_encoding = RelativePositionalEncoding(config.d_model)
@@ -435,7 +447,7 @@ class AdvancedTradingTransformer(nn.Module):
# NEW: Next candle OHLCV prediction heads for each timeframe (1s, 1m, 1h, 1d)
# Each timeframe predicts: [open, high, low, close, volume] = 5 values
self.timeframes = ['1s', '1m', '1h', '1d']
# Note: self.timeframes already defined above in input projections
self.next_candle_heads = nn.ModuleDict({
tf: nn.Sequential(
nn.Linear(config.d_model, config.d_model // 2),
@@ -448,6 +460,17 @@ class AdvancedTradingTransformer(nn.Module):
) for tf in self.timeframes
})
# BTC next candle prediction head
self.btc_next_candle_head = nn.Sequential(
nn.Linear(config.d_model, config.d_model // 2),
nn.GELU(),
nn.Dropout(config.dropout),
nn.Linear(config.d_model // 2, config.d_model // 4),
nn.GELU(),
nn.Dropout(config.dropout),
nn.Linear(config.d_model // 4, 5) # OHLCV for BTC
)
# NEW: Next pivot point prediction heads for L1-L5 levels
# Each level predicts: [price, type_prob_high, type_prob_low, confidence]
# type_prob_high + type_prob_low = 1 (softmax), but we output separately for clarity