# Multi-Timeframe Transformer - Implementation Plan

## Current Problem

The transformer currently:
- ❌ Only uses the **1m timeframe** (ignores 1s, 1h, 1d)
- ❌ Doesn't use **BTC reference data**
- ❌ Has `price_projection` (single), but we changed it to `price_projections` (dict)
- ❌ Predicts next candles for all timeframes but only trains on 1m

## Goal

Make the transformer:
- ✅ Accept **ALL timeframes** (1s, 1m, 1h, 1d) as separate inputs
- ✅ Accept **BTC 1m** as a reference symbol
- ✅ Learn **timeframe-specific patterns** with separate projections
- ✅ Predict the **next candle for EACH timeframe**
- ✅ Handle **missing timeframes** gracefully (1s may not be available)
- ✅ Maintain **backward compatibility** with existing code

---

## Architecture Changes

### 1. Input Structure (NEW)

```python
forward(
    price_data_1s=None,   # [batch, 600, 5] - 1-second candles (optional)
    price_data_1m=None,   # [batch, 600, 5] - 1-minute candles (optional)
    price_data_1h=None,   # [batch, 600, 5] - 1-hour candles (optional)
    price_data_1d=None,   # [batch, 600, 5] - 1-day candles (optional)
    btc_data_1m=None,     # [batch, 600, 5] - BTC 1-minute (optional)
    cob_data=None,        # [batch, 600, 100] - COB features
    tech_data=None,       # [batch, 40] - Technical indicators
    market_data=None,     # [batch, 30] - Market features
    position_state=None,  # [batch, 5] - Position state
    # Legacy support
    price_data=None       # [batch, 600, 5] - Falls back to 1m if provided
)
```

### 2. Input Projections (UPDATED)

```python
# Separate projection for each timeframe
self.price_projections = nn.ModuleDict({
    '1s': nn.Linear(5, 1024),  # 1-second patterns
    '1m': nn.Linear(5, 1024),  # 1-minute patterns
    '1h': nn.Linear(5, 1024),  # 1-hour patterns
    '1d': nn.Linear(5, 1024)   # 1-day patterns
})

# BTC reference projection
self.btc_projection = nn.Linear(5, 1024)

# Learnable timeframe importance weights
self.timeframe_weights = nn.Parameter(torch.ones(5))  # 4 timeframes + BTC
```

### 3. Embedding Combination Strategy

**Option A: Weighted Average** (Current Plan)

```python
# Project each timeframe and scale it by its learned weight
# (weight_1s ... weight_btc are entries of self.timeframe_weights)
emb_1s = self.price_projections['1s'](price_data_1s) * weight_1s
emb_1m = self.price_projections['1m'](price_data_1m) * weight_1m
emb_1h = self.price_projections['1h'](price_data_1h) * weight_1h
emb_1d = self.price_projections['1d'](price_data_1d) * weight_1d
emb_btc = self.btc_projection(btc_data_1m) * weight_btc

# Combine: in practice, only the embeddings of inputs that are present are summed
price_emb = (emb_1s + emb_1m + emb_1h + emb_1d + emb_btc) / num_available
```

**Option B: Concatenation + Projection** (Alternative)

```python
# Concatenate all timeframes (requires all five inputs to be present)
combined = torch.cat([emb_1s, emb_1m, emb_1h, emb_1d, emb_btc], dim=-1)
# [batch, seq_len, 5*1024] = [batch, seq_len, 5120]

# Project back to d_model. The layer must be created once in __init__
# (e.g. a hypothetical self.timeframe_fusion = nn.Linear(5120, 1024)),
# not inside forward(), or it would be re-initialized and never trained.
price_emb = self.timeframe_fusion(combined)
```

**Decision**: Use **Option A** (Weighted Average)
- ✅ Simpler
- ✅ Handles missing timeframes naturally
- ✅ Fewer parameters
- ✅ Learned weights show timeframe importance

### 4. Output Structure (ALREADY EXISTS)

```python
outputs = {
    'action_logits': [batch, 3],
    'action_probs': [batch, 3],
    'confidence': [batch, 1],
    'price_prediction': [batch, 1],

    # Next candle predictions for EACH timeframe
    'next_candles': {
        '1s': [batch, 5],  # [open, high, low, close, volume]
        '1m': [batch, 5],
        '1h': [batch, 5],
        '1d': [batch, 5]
    },

    # BTC next candle
    'btc_next_candle': [batch, 5],

    # Pivot predictions
    'next_pivots': {...},

    # Trend analysis
    'trend_analysis': {...}
}
```

---

## Training Data Preparation

### Current (WRONG)

```python
# Only uses 1m timeframe
primary_data = timeframes['1m']
price_data = stack_ohlcv(primary_data)  # [batch, 150, 5]

batch = {
    'price_data': price_data,  # Only 1m!
    'cob_data': cob_data,
    'tech_data': tech_data,
    'market_data': market_data
}
```

### New (CORRECT)

```python
# Use ALL available timeframes
# (stack_ohlcv must return None when given None, i.e. a missing timeframe)
batch = {
    'price_data_1s': stack_ohlcv(timeframes.get('1s')),  # [batch, 600, 5] or None
    'price_data_1m': stack_ohlcv(timeframes['1m']),      # [batch, 600, 5] (required)
    'price_data_1h': stack_ohlcv(timeframes.get('1h')),  # [batch, 600, 5] or None
    'price_data_1d': stack_ohlcv(timeframes.get('1d')),  # [batch, 600, 5] or None
    'btc_data_1m': stack_ohlcv(secondary_timeframes.get('BTC/USDT', {}).get('1m')),
    'cob_data': cob_data,
    'tech_data': tech_data,
    'market_data': market_data,
    'position_state': position_state
}
```

---

## Sequence Length Strategy

### Problem

The same number of candles covers very different time spans on each timeframe:
- 1s: 600 candles = 10 minutes
- 1m: 600 candles = 10 hours
- 1h: 600 candles = 25 days
- 1d: 600 candles = ~1.6 years

### Solution: Fixed Sequence Length with Padding

**Use 600 candles for ALL timeframes**:

```python
seq_len = 600  # Fixed for all timeframes

# Each timeframe provides its most recent 600 candles:
# - 1s: Last 600 seconds (10 min)
# - 1m: Last 600 minutes (10 hours)
# - 1h: Last 600 hours (25 days)
# - 1d: Last 600 days (~1.6 years)

# If fewer than 600 are available, pad by repeating the last candle
if len(candles) < seq_len:
    padding = repeat_last_candle(seq_len - len(candles))
    candles = concat([candles, padding])
```

**Why 600?**
- ✅ Captures sufficient history for all timeframes
- ✅ Not too large (memory efficient)
- ✅ Divisible by many numbers (easy to split into equal-sized chunks)
- ✅ Manageable for the transformer (46M params), although self-attention cost grows with the square of the sequence length, not with parameter count

---

## Normalization Strategy

### Per-Timeframe Normalization

**Problem**: Different timeframes have different price ranges:
- 1s: Tight range (e.g., $2000-$2001)
- 1d: Wide range (e.g., $1800-$2200)

**Solution**: Normalize each timeframe independently

```python
import numpy as np

def normalize_timeframe(ohlcv: np.ndarray, timeframe: str):
    """Normalize OHLCV data for a specific timeframe.

    ohlcv: array of shape [seq_len, 5] with columns [open, high, low, close, volume].
    """
    # Use min/max from this timeframe's own data
    price_min = ohlcv[:, :4].min()  # Min over all OHLC values
    price_max = ohlcv[:, :4].max()  # Max over all OHLC values
    volume_min = ohlcv[:, 4].min()
    volume_max = ohlcv[:, 4].max()

    # Guard against flat windows (max == min) to avoid division by zero
    price_range = max(price_max - price_min, 1e-12)
    volume_range = max(volume_max - volume_min, 1e-12)

    # Scale prices and volume to [0, 1] (astype returns a float copy)
    ohlcv_norm = ohlcv.astype(np.float64)
    ohlcv_norm[:, :4] = (ohlcv[:, :4] - price_min) / price_range
    ohlcv_norm[:, 4] = (ohlcv[:, 4] - volume_min) / volume_range

    # Store bounds for denormalization
    bounds = {
        'price_min': price_min,
        'price_max': price_max,
        'volume_min': volume_min,
        'volume_max': volume_max
    }

    return ohlcv_norm, bounds
```

### Denormalization for Predictions

```python
def denormalize_prediction(pred_norm, bounds):
    """Convert a normalized [..., 5] OHLCV prediction back to actual prices and volume."""
    pred = pred_norm.copy()
    pred[..., :4] = pred_norm[..., :4] * (bounds['price_max'] - bounds['price_min']) + bounds['price_min']
    pred[..., 4] = pred_norm[..., 4] * (bounds['volume_max'] - bounds['volume_min']) + bounds['volume_min']
    return pred
```

---

## Loss Function Updates

### Current (Single Timeframe)

```python
loss = action_loss + 0.1 * price_loss
```

### New (Multi-Timeframe)

```python
import torch.nn.functional as F

# Action loss (same)
action_loss = F.cross_entropy(action_logits, actions)

# Price loss for each timeframe that has a target
price_losses = []
for tf in ['1s', '1m', '1h', '1d']:
    target = batch.get(f'target_{tf}')
    if target is not None:
        pred = outputs['next_candles'][tf]
        price_losses.append(F.mse_loss(pred, target))

# BTC loss
if batch.get('target_btc') is not None:
    btc_loss = F.mse_loss(outputs['btc_next_candle'], batch['target_btc'])
    price_losses.append(btc_loss)

# Combined loss: average over the available targets, then weight against the action loss
total_price_loss = sum(price_losses) / len(price_losses) if price_losses else 0.0
total_loss = action_loss + 0.1 * total_price_loss
```

---

## Backward Compatibility

### Legacy Code Support

```python
# Old code (still works)
outputs = model(
    price_data=price_1m,  # Single timeframe
    cob_data=cob,
    tech_data=tech,
    market_data=market
)

# New code (multi-timeframe)
outputs = model(
    price_data_1m=price_1m,
    price_data_1h=price_1h,
    price_data_1d=price_1d,
    btc_data_1m=btc_1m,
    cob_data=cob,
    tech_data=tech,
    market_data=market,
    position_state=position
)
```

**Implementation**:

```python
def forward(self, price_data=None, price_data_1m=None, ...):
    # Legacy support
    if price_data is not None and price_data_1m is None:
        price_data_1m = price_data  # Use as 1m data

    # Continue with new logic
    ...
```

---

## Implementation Steps

### Step 1: Update Model Architecture ✅
- [x] Add `price_projections` ModuleDict
- [x] Add `btc_projection`
- [x] Add `timeframe_weights` parameter
- [x] Add `btc_next_candle_head`

### Step 2: Update Forward Method (TODO)
- [ ] Accept multi-timeframe inputs
- [ ] Handle missing timeframes
- [ ] Combine embeddings with learned weights
- [ ] Add BTC embedding
- [ ] Maintain backward compatibility

### Step 3: Update Training Adapter (TODO)
- [ ] Extract ALL timeframes from market_state
- [ ] Normalize each timeframe independently
- [ ] Create multi-timeframe batch dictionary
- [ ] Add target next candles for each timeframe

### Step 4: Update Loss Calculation (TODO)
- [ ] Add per-timeframe price losses
- [ ] Add BTC prediction loss
- [ ] Weight losses appropriately

### Step 5: Testing (TODO)
- [ ] Unit test: multi-timeframe forward pass
- [ ] Unit test: missing timeframes handling
- [ ] Integration test: full training loop
- [ ] Validation: check predictions for all timeframes

---

## Expected Benefits

### 1. Better Predictions
- **Multi-scale patterns**: 1s for micro, 1d for macro trends
- **Cross-timeframe validation**: Confirm signals across timeframes
- **Market correlation**: BTC reference for market-wide moves

### 2. Flexible Usage
- **Any timeframe**: Choose the output timeframe in post-processing
- **Scalping**: Use 1s predictions
- **Swing trading**: Use 1h/1d predictions
- **Same model**: No retraining needed

### 3. Robustness
- **Missing data**: Handles missing 1s gracefully
- **Degraded mode**: Works with just 1m if needed
- **Backward compatible**: Old code still works

---

## Next Actions

1. **Finish forward() refactor** - Make it accept multi-timeframe inputs
2. **Update training adapter** - Pass all timeframes to the model
3. **Test thoroughly** - Ensure no regressions
4. **Document usage** - Show how to use multi-timeframe predictions

This is a significant improvement that will make the model much more powerful! 🚀
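The Option A weighted average and the "handle missing timeframes" item in Step 2 can be sketched as a small PyTorch module. This is a minimal sketch under the following assumptions, none of which are confirmed by the codebase:

- `MultiTimeframeFusion` is a hypothetical name.
- Every present input already has the same `[batch, seq_len, 5]` shape, because of the fixed 600-candle length.
- The learned weights are softmax-normalized over the inputs that are present. This swaps in a softmax for the plan's plain division by `num_available`.

With the plan's `torch.ones(5)` initialization, the softmax starts out as exactly that plain average.

```python
import torch
import torch.nn as nn

TIMEFRAMES = ("1s", "1m", "1h", "1d")


class MultiTimeframeFusion(nn.Module):
    """Hypothetical sketch: project each available timeframe (+ BTC) and fuse them.

    Inputs that are None are skipped; the learnable weights are softmax-normalized
    over the inputs that are actually present, so they always sum to 1.
    """

    def __init__(self, n_features: int = 5, d_model: int = 1024):
        super().__init__()
        self.price_projections = nn.ModuleDict(
            {tf: nn.Linear(n_features, d_model) for tf in TIMEFRAMES}
        )
        self.btc_projection = nn.Linear(n_features, d_model)
        # One logit per source: 4 timeframes + BTC, initialized as in the plan.
        self.timeframe_weights = nn.Parameter(torch.ones(len(TIMEFRAMES) + 1))

    def forward(self, price_data_1s=None, price_data_1m=None, price_data_1h=None,
                price_data_1d=None, btc_data_1m=None, price_data=None):
        # Legacy path: a single `price_data` tensor is treated as 1m data.
        if price_data is not None and price_data_1m is None:
            price_data_1m = price_data

        inputs = [price_data_1s, price_data_1m, price_data_1h, price_data_1d, btc_data_1m]
        projections = [self.price_projections[tf] for tf in TIMEFRAMES] + [self.btc_projection]

        embeddings, logits = [], []
        for idx, (x, proj) in enumerate(zip(inputs, projections)):
            if x is None:
                continue  # missing timeframe: contributes nothing
            embeddings.append(proj(x))  # [batch, seq_len, d_model]
            logits.append(self.timeframe_weights[idx])
        if not embeddings:
            raise ValueError("at least one price input is required")

        # Softmax over the available sources only.
        weights = torch.softmax(torch.stack(logits), dim=0)  # [n_available]
        stacked = torch.stack(embeddings, dim=0)  # [n_available, batch, seq_len, d_model]
        return (weights.view(-1, 1, 1, 1) * stacked).sum(dim=0)


fusion = MultiTimeframeFusion(d_model=64)
x = torch.randn(2, 600, 5)
out = fusion(price_data_1m=x, price_data_1h=x)  # 1s, 1d and BTC are missing
print(out.shape)  # torch.Size([2, 600, 64])
```

After training, `torch.softmax(fusion.timeframe_weights, dim=0)` gives a readable view of the relative timeframe importance that the Decision bullets mention.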
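The "New (CORRECT)" batch in Training Data Preparation depends on `stack_ohlcv` returning `None` for a missing timeframe. The plan does not define that helper or the shape of `market_state`. The sketch below therefore assumes candles arrive as dicts with `open`/`high`/`low`/`close`/`volume` keys, and uses hypothetical `stack_ohlcv` and `build_price_inputs` functions that work on one sample as plain lists. Tensor conversion and batching are left out.

```python
from typing import Dict, List, Optional, Sequence


def stack_ohlcv(candles: Optional[Sequence[dict]]) -> Optional[List[List[float]]]:
    """Hypothetical helper: candle dicts -> [[open, high, low, close, volume], ...].

    Returns None for a missing or empty timeframe so the model can skip it.
    """
    if not candles:
        return None
    return [[c["open"], c["high"], c["low"], c["close"], c["volume"]] for c in candles]


def build_price_inputs(timeframes: Dict[str, list], secondary_timeframes: Dict[str, dict]) -> dict:
    """Hypothetical sketch of the price part of the multi-timeframe batch."""
    batch = {f"price_data_{tf}": stack_ohlcv(timeframes.get(tf))
             for tf in ("1s", "1m", "1h", "1d")}
    batch["btc_data_1m"] = stack_ohlcv(secondary_timeframes.get("BTC/USDT", {}).get("1m"))
    # Degraded mode still needs 1m: it is the only required timeframe in the plan.
    if batch["price_data_1m"] is None:
        raise ValueError("1m candles are required")
    return batch


candle = {"open": 1.0, "high": 2.0, "low": 0.5, "close": 1.5, "volume": 10.0}
batch = build_price_inputs({"1m": [candle] * 3}, {})
print(batch["price_data_1s"], batch["btc_data_1m"])  # None None
```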
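The time spans listed under Sequence Length Strategy can be checked with the standard library. A 600-candle daily window is 600 days, which is about 1.64 years. The helper name `window_span` is illustrative.

```python
from datetime import timedelta

SEQ_LEN = 600
CANDLE_DURATION = {
    "1s": timedelta(seconds=1),
    "1m": timedelta(minutes=1),
    "1h": timedelta(hours=1),
    "1d": timedelta(days=1),
}


def window_span(tf: str, seq_len: int = SEQ_LEN) -> timedelta:
    """Wall-clock time covered by `seq_len` consecutive candles of timeframe `tf`."""
    return CANDLE_DURATION[tf] * seq_len


for tf in CANDLE_DURATION:
    print(tf, window_span(tf))
# 1s 0:10:00
# 1m 10:00:00
# 1h 25 days, 0:00:00
# 1d 600 days, 0:00:00

print(f"1d ~ {window_span('1d').days / 365.25:.2f} years")  # 1d ~ 1.64 years
```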
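The fixed-length rule in Sequence Length Strategy (keep the latest 600 candles, pad shorter histories by repeating a candle) can be sketched in plain Python. `to_fixed_length` is a hypothetical helper, and candles are assumed to be chronological lists of `[open, high, low, close, volume]`. The default `pad_side="end"` follows the plan and appends copies of the last candle. `pad_side="start"` is offered as an alternative, not something the plan specifies: it repeats the oldest candle in front, so the real most recent candles always sit at the final positions.

```python
from typing import List

Candle = List[float]  # [open, high, low, close, volume], oldest first


def to_fixed_length(candles: List[Candle], seq_len: int = 600, pad_side: str = "end") -> List[Candle]:
    """Trim or pad a chronological candle list to exactly `seq_len` rows.

    Longer inputs keep the most recent `seq_len` candles. Shorter inputs are padded
    by repeating the last candle after the data (pad_side="end", the plan's scheme)
    or the oldest candle before the data (pad_side="start").
    """
    if not candles:
        raise ValueError("need at least one candle to pad from")
    if pad_side not in ("start", "end"):
        raise ValueError("pad_side must be 'start' or 'end'")
    recent = [list(c) for c in candles[-seq_len:]]  # copies, newest last
    missing = seq_len - len(recent)
    if pad_side == "end":
        return recent + [list(recent[-1]) for _ in range(missing)]
    return [list(recent[0]) for _ in range(missing)] + recent


short = [[float(i)] * 5 for i in range(3)]
print(len(to_fixed_length(short)))  # 600
```

Which padding side trains better for this model is untested.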
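For Step 3, "Add target next candles for each timeframe", one consistent approach is to scale each target candle with the bounds of its input window. The same bounds then convert predictions back to prices. The pure-Python sketch below assumes that approach, and its function names are hypothetical. It also shows why targets should not be clipped to [0, 1]: a breakout above the window high maps to a value greater than 1.

```python
from typing import Dict, List

Candle = List[float]  # [open, high, low, close, volume]
EPS = 1e-12  # guards against flat windows (max == min)


def window_bounds(window: List[Candle]) -> Dict[str, float]:
    """Min/max over the input window, mirroring normalize_timeframe."""
    prices = [v for row in window for v in row[:4]]
    volumes = [row[4] for row in window]
    return {"price_min": min(prices), "price_max": max(prices),
            "volume_min": min(volumes), "volume_max": max(volumes)}


def normalize_target(candle: Candle, b: Dict[str, float]) -> Candle:
    """Scale the next (target) candle with the input window's bounds; no clipping."""
    p_range = max(b["price_max"] - b["price_min"], EPS)
    v_range = max(b["volume_max"] - b["volume_min"], EPS)
    return [(v - b["price_min"]) / p_range for v in candle[:4]] + \
           [(candle[4] - b["volume_min"]) / v_range]


def denormalize_target(candle: Candle, b: Dict[str, float]) -> Candle:
    """Inverse of normalize_target."""
    p_range = max(b["price_max"] - b["price_min"], EPS)
    v_range = max(b["volume_max"] - b["volume_min"], EPS)
    return [v * p_range + b["price_min"] for v in candle[:4]] + \
           [candle[4] * v_range + b["volume_min"]]


window = [[100.0, 102.0, 99.0, 101.0, 10.0],
          [101.0, 103.0, 100.0, 102.0, 20.0]]
bounds = window_bounds(window)
target = normalize_target([102.0, 105.0, 101.0, 104.0, 30.0], bounds)
print(target)  # [0.75, 1.5, 0.5, 1.25, 2.0]
```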