T model was only using 1t data so far. fixing WIP
This commit is contained in:
@@ -349,8 +349,17 @@ class AdvancedTradingTransformer(nn.Module):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
# Input projections
|
||||
self.price_projection = nn.Linear(5, config.d_model) # OHLCV
|
||||
# Multi-timeframe input projections
|
||||
# Each timeframe gets its own projection to learn timeframe-specific patterns
|
||||
self.timeframes = ['1s', '1m', '1h', '1d']
|
||||
self.price_projections = nn.ModuleDict({
|
||||
tf: nn.Linear(5, config.d_model) for tf in self.timeframes # OHLCV per timeframe
|
||||
})
|
||||
|
||||
# Reference symbol projection (BTC 1m)
|
||||
self.btc_projection = nn.Linear(5, config.d_model)
|
||||
|
||||
# Other input projections
|
||||
self.cob_projection = nn.Linear(config.cob_features, config.d_model)
|
||||
self.tech_projection = nn.Linear(config.tech_features, config.d_model)
|
||||
self.market_projection = nn.Linear(config.market_features, config.d_model)
|
||||
@@ -367,6 +376,9 @@ class AdvancedTradingTransformer(nn.Module):
|
||||
nn.Linear(config.d_model // 2, config.d_model) # 512 -> 1024
|
||||
)
|
||||
|
||||
# Timeframe importance weights (learnable)
|
||||
self.timeframe_weights = nn.Parameter(torch.ones(len(self.timeframes) + 1)) # +1 for BTC
|
||||
|
||||
# Positional encoding
|
||||
if config.use_relative_position:
|
||||
self.pos_encoding = RelativePositionalEncoding(config.d_model)
|
||||
@@ -435,7 +447,7 @@ class AdvancedTradingTransformer(nn.Module):
|
||||
|
||||
# NEW: Next candle OHLCV prediction heads for each timeframe (1s, 1m, 1h, 1d)
|
||||
# Each timeframe predicts: [open, high, low, close, volume] = 5 values
|
||||
self.timeframes = ['1s', '1m', '1h', '1d']
|
||||
# Note: self.timeframes already defined above in input projections
|
||||
self.next_candle_heads = nn.ModuleDict({
|
||||
tf: nn.Sequential(
|
||||
nn.Linear(config.d_model, config.d_model // 2),
|
||||
@@ -448,6 +460,17 @@ class AdvancedTradingTransformer(nn.Module):
|
||||
) for tf in self.timeframes
|
||||
})
|
||||
|
||||
# BTC next candle prediction head
|
||||
self.btc_next_candle_head = nn.Sequential(
|
||||
nn.Linear(config.d_model, config.d_model // 2),
|
||||
nn.GELU(),
|
||||
nn.Dropout(config.dropout),
|
||||
nn.Linear(config.d_model // 2, config.d_model // 4),
|
||||
nn.GELU(),
|
||||
nn.Dropout(config.dropout),
|
||||
nn.Linear(config.d_model // 4, 5) # OHLCV for BTC
|
||||
)
|
||||
|
||||
# NEW: Next pivot point prediction heads for L1-L5 levels
|
||||
# Each level predicts: [price, type_prob_high, type_prob_low, confidence]
|
||||
# type_prob_high + type_prob_low = 1 (softmax), but we output separately for clarity
|
||||
|
||||
Reference in New Issue
Block a user