Multi-Timeframe Transformer - Implementation Plan

Current Problem

The transformer currently:

  • Only uses 1m timeframe (ignores 1s, 1h, 1d)
  • Doesn't use BTC reference data
  • Has a single price_projection, although we have already changed the design to a price_projections dict (one per timeframe)
  • Predicts next candles for all timeframes but is only trained on 1m

Goal

Make the transformer:

  • Accept ALL timeframes (1s, 1m, 1h, 1d) as separate inputs
  • Accept BTC 1m as reference symbol
  • Learn timeframe-specific patterns with separate projections
  • Predict next candle for EACH timeframe
  • Handle missing timeframes gracefully (1s may not be available)
  • Maintain backward compatibility with existing code

Architecture Changes

1. Input Structure (NEW)

forward(
    price_data_1s=None,   # [batch, 600, 5] - 1-second candles (optional)
    price_data_1m=None,   # [batch, 600, 5] - 1-minute candles (optional)
    price_data_1h=None,   # [batch, 600, 5] - 1-hour candles (optional)
    price_data_1d=None,   # [batch, 600, 5] - 1-day candles (optional)
    btc_data_1m=None,     # [batch, 600, 5] - BTC 1-minute (optional)
    cob_data=None,        # [batch, 600, 100] - COB features
    tech_data=None,       # [batch, 40] - Technical indicators
    market_data=None,     # [batch, 30] - Market features
    position_state=None,  # [batch, 5] - Position state
    # Legacy support
    price_data=None       # [batch, 600, 5] - Treated as 1m data if provided
)

2. Input Projections (UPDATED)

# Separate projection for each timeframe
self.price_projections = nn.ModuleDict({
    '1s': nn.Linear(5, 1024),  # 1-second patterns
    '1m': nn.Linear(5, 1024),  # 1-minute patterns
    '1h': nn.Linear(5, 1024),  # 1-hour patterns
    '1d': nn.Linear(5, 1024)   # 1-day patterns
})

# BTC reference projection
self.btc_projection = nn.Linear(5, 1024)

# Learnable timeframe importance weights
self.timeframe_weights = nn.Parameter(torch.ones(5))  # 4 timeframes + BTC

3. Embedding Combination Strategy

Option A: Weighted Average (Current Plan)

# Project each timeframe (weight_* are the per-timeframe entries of self.timeframe_weights)
emb_1s = self.price_projections['1s'](price_data_1s) * weight_1s
emb_1m = self.price_projections['1m'](price_data_1m) * weight_1m
emb_1h = self.price_projections['1h'](price_data_1h) * weight_1h
emb_1d = self.price_projections['1d'](price_data_1d) * weight_1d
emb_btc = self.btc_projection(btc_data_1m) * weight_btc

# Combine: average over however many timeframes are actually available
price_emb = (emb_1s + emb_1m + emb_1h + emb_1d + emb_btc) / num_available

Option B: Concatenation + Projection (Alternative)

# Concatenate all timeframes
combined = torch.cat([emb_1s, emb_1m, emb_1h, emb_1d, emb_btc], dim=-1)
# [batch, seq_len, 5*1024] = [batch, seq_len, 5120]

# Project back to d_model (in real code this Linear would be created once in
# __init__; constructing it inside forward would re-initialize it every call)
price_emb = nn.Linear(5120, 1024)(combined)

Decision: Use Option A (Weighted Average)

  • Simpler
  • Handles missing timeframes naturally
  • Fewer parameters
  • Learned weights show timeframe importance (see the sketch below)
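
A minimal, self-contained sketch of this combination (hedged: the softmax over the weights of the available inputs is an assumption for scale stability; the plan itself divides by num_available):

import torch
import torch.nn as nn

class TimeframeCombiner(nn.Module):
    """Sketch of Option A: weighted average over whichever inputs arrive."""
    def __init__(self, d_model: int = 1024):
        super().__init__()
        self.price_projections = nn.ModuleDict({
            tf: nn.Linear(5, d_model) for tf in ('1s', '1m', '1h', '1d')
        })
        self.btc_projection = nn.Linear(5, d_model)
        self.timeframe_weights = nn.Parameter(torch.ones(5))  # 1s, 1m, 1h, 1d, BTC

    def forward(self, inputs):
        # inputs maps '1s'/'1m'/'1h'/'1d'/'btc' -> [batch, 600, 5] or None
        embeddings, weights = [], []
        for i, key in enumerate(('1s', '1m', '1h', '1d', 'btc')):
            x = inputs.get(key)
            if x is None:
                continue  # missing timeframe: simply skipped
            proj = self.btc_projection if key == 'btc' else self.price_projections[key]
            embeddings.append(proj(x))
            weights.append(self.timeframe_weights[i])
        if not embeddings:
            raise ValueError("at least one timeframe must be provided")
        w = torch.softmax(torch.stack(weights), dim=0)  # normalize over available inputs
        stacked = torch.stack(embeddings, dim=0)        # [n, batch, 600, d_model]
        return (w.view(-1, 1, 1, 1) * stacked).sum(dim=0)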

4. Output Structure (ALREADY EXISTS)

outputs = {
    'action_logits': [batch, 3],
    'action_probs': [batch, 3],
    'confidence': [batch, 1],
    'price_prediction': [batch, 1],
    
    # Next candle predictions for EACH timeframe
    'next_candles': {
        '1s': [batch, 5],  # [open, high, low, close, volume]
        '1m': [batch, 5],
        '1h': [batch, 5],
        '1d': [batch, 5]
    },
    
    # BTC next candle
    'btc_next_candle': [batch, 5],
    
    # Pivot predictions
    'next_pivots': {...},
    
    # Trend analysis
    'trend_analysis': {...}
}
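
For example, downstream code can read whichever horizon it trades on without retraining (a usage sketch; key names follow the dict above):

# Post-process: select the prediction for the traded horizon
tf = '1h'                                        # e.g. swing trading
next_candle = outputs['next_candles'][tf]        # [batch, 5], normalized OHLCV
action = outputs['action_probs'].argmax(dim=-1)  # [batch], index into 3 action classes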

Training Data Preparation

Current (WRONG)

# Only uses 1m timeframe
primary_data = timeframes['1m']
price_data = stack_ohlcv(primary_data)  # [batch, 150, 5]

batch = {
    'price_data': price_data,  # Only 1m!
    'cob_data': cob_data,
    'tech_data': tech_data,
    'market_data': market_data
}

New (CORRECT)

# Use ALL available timeframes
batch = {
    'price_data_1s': stack_ohlcv(timeframes.get('1s')),  # [batch, 600, 5] or None
    'price_data_1m': stack_ohlcv(timeframes['1m']),      # [batch, 600, 5]
    'price_data_1h': stack_ohlcv(timeframes.get('1h')),  # [batch, 600, 5] or None
    'price_data_1d': stack_ohlcv(timeframes.get('1d')),  # [batch, 600, 5] or None
    'btc_data_1m': stack_ohlcv(secondary_timeframes.get('BTC/USDT', {}).get('1m')),
    'cob_data': cob_data,
    'tech_data': tech_data,
    'market_data': market_data,
    'position_state': position_state
}
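
stack_ohlcv is referenced but never defined in this plan; a None-tolerant sketch (the list-of-dicts candle format and the helper body are assumptions), producing one [600, 5] sample to be batched later:

import numpy as np

def stack_ohlcv(candles, seq_len=600):
    """Stack OHLCV candles into [seq_len, 5], or return None for a missing timeframe."""
    if not candles:  # None or empty: timeframe unavailable
        return None
    arr = np.asarray(
        [[c['open'], c['high'], c['low'], c['close'], c['volume']] for c in candles],
        dtype=np.float32,
    )
    arr = arr[-seq_len:]    # keep the most recent seq_len candles
    if len(arr) < seq_len:  # pad short histories by repeating the last candle
        pad = np.repeat(arr[-1:], seq_len - len(arr), axis=0)
        arr = np.concatenate([arr, pad], axis=0)
    return arr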

Sequence Length Strategy

Problem

Different timeframes have different natural sequence lengths:

  • 1s: 600 candles = 10 minutes
  • 1m: 600 candles = 10 hours
  • 1h: 600 candles = 25 days
  • 1d: 600 candles = ~1.6 years

Solution: Fixed Sequence Length with Padding

Use 600 candles for ALL timeframes:

seq_len = 600  # Fixed for all timeframes

# Each timeframe provides 600 candles
# - 1s: Last 600 seconds (10 min)
# - 1m: Last 600 minutes (10 hours)
# - 1h: Last 600 hours (25 days)
# - 1d: Last 600 days (~1.6 years)

# If fewer than 600 are available, pad by repeating the last candle (numpy)
if len(candles) < 600:
    padding = np.repeat(candles[-1:], 600 - len(candles), axis=0)
    candles = np.concatenate([candles, padding], axis=0)

Why 600?

  • Captures sufficient history for all timeframes
  • Not too large (memory efficient)
  • Divisible by many numbers (batch processing)
  • Transformer can handle it (we have 46M params)

Normalization Strategy

Per-Timeframe Normalization

Problem: Different timeframes have different price ranges

  • 1s: Tight range (e.g., $2000-$2001)
  • 1d: Wide range (e.g., $1800-$2200)

Solution: Normalize each timeframe independently

def normalize_timeframe(ohlcv, timeframe):
    """Normalize OHLCV data (numpy array [seq, 5]) for a specific timeframe"""
    # Use min/max from this timeframe's own data
    price_min = ohlcv[:, :4].min()  # Min across OHLC
    price_max = ohlcv[:, :4].max()  # Max across OHLC
    volume_min = ohlcv[:, 4].min()
    volume_max = ohlcv[:, 4].max()

    # Guard against flat ranges to avoid division by zero
    price_range = max(price_max - price_min, 1e-8)
    volume_range = max(volume_max - volume_min, 1e-8)

    # Normalize prices and volume to [0, 1]
    ohlcv_norm = ohlcv.copy()
    ohlcv_norm[:, :4] = (ohlcv[:, :4] - price_min) / price_range
    ohlcv_norm[:, 4] = (ohlcv[:, 4] - volume_min) / volume_range

    # Store bounds for denormalization
    bounds = {
        'price_min': price_min,
        'price_max': price_max,
        'volume_min': volume_min,
        'volume_max': volume_max
    }

    return ohlcv_norm, bounds

Denormalization for Predictions

def denormalize_prediction(pred_norm, bounds):
    """Convert normalized prediction back to actual price"""
    pred = pred_norm * (bounds['price_max'] - bounds['price_min']) + bounds['price_min']
    return pred
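
A quick round trip (assuming numpy arrays) shows the two functions agree:

import numpy as np

ohlcv = np.array([[2000.0, 2001.0, 1999.5, 2000.5, 12.0],
                  [2000.5, 2002.0, 2000.0, 2001.5, 15.0]])

ohlcv_norm, bounds = normalize_timeframe(ohlcv, '1m')
close_norm = ohlcv_norm[-1, 3]                      # 0.8 with these numbers
close = denormalize_prediction(close_norm, bounds)  # recovers 2001.5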

Loss Function Updates

Current (Single Timeframe)

loss = action_loss + 0.1 * price_loss

New (Multi-Timeframe)

import torch.nn.functional as F

# Action loss (same as before)
action_loss = F.cross_entropy(action_logits, actions)

# Price loss for each timeframe that has a target
price_losses = []
for tf in ['1s', '1m', '1h', '1d']:
    if f'target_{tf}' in batch:
        pred = outputs['next_candles'][tf]
        target = batch[f'target_{tf}']
        price_losses.append(F.mse_loss(pred, target))

# BTC loss
if 'target_btc' in batch:
    price_losses.append(F.mse_loss(outputs['btc_next_candle'], batch['target_btc']))

# Combined loss (averaged so the scale doesn't depend on how many targets exist)
total_price_loss = sum(price_losses) / len(price_losses) if price_losses else 0.0
total_loss = action_loss + 0.1 * total_price_loss

Backward Compatibility

Legacy Code Support

# Old code (still works)
outputs = model(
    price_data=price_1m,  # Single timeframe
    cob_data=cob,
    tech_data=tech,
    market_data=market
)

# New code (multi-timeframe)
outputs = model(
    price_data_1m=price_1m,
    price_data_1h=price_1h,
    price_data_1d=price_1d,
    btc_data_1m=btc_1m,
    cob_data=cob,
    tech_data=tech,
    market_data=market,
    position_state=position
)

Implementation:

def forward(self, price_data=None, price_data_1m=None, ...):
    # Legacy support
    if price_data is not None and price_data_1m is None:
        price_data_1m = price_data  # Use as 1m data
    
    # Continue with new logic
    ...
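
Filling that out, a skeleton of the new forward (the _combine_timeframes helper is a hypothetical name standing in for the Option A logic sketched earlier):

def forward(self, price_data=None, price_data_1s=None, price_data_1m=None,
            price_data_1h=None, price_data_1d=None, btc_data_1m=None,
            cob_data=None, tech_data=None, market_data=None, position_state=None):
    # Legacy support: a bare price_data is treated as 1m
    if price_data is not None and price_data_1m is None:
        price_data_1m = price_data

    # Combine whichever timeframes were provided
    price_emb = self._combine_timeframes({
        '1s': price_data_1s, '1m': price_data_1m,
        '1h': price_data_1h, '1d': price_data_1d,
        'btc': btc_data_1m,
    })
    # ... encoder and output heads continue unchanged ...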

Implementation Steps

Step 1: Update Model Architecture

  • Add price_projections ModuleDict
  • Add btc_projection
  • Add timeframe_weights parameter
  • Add btc_next_candle_head

Step 2: Update Forward Method (TODO)

  • Accept multi-timeframe inputs
  • Handle missing timeframes
  • Combine embeddings with learned weights
  • Add BTC embedding
  • Maintain backward compatibility

Step 3: Update Training Adapter (TODO)

  • Extract ALL timeframes from market_state
  • Normalize each timeframe independently
  • Create multi-timeframe batch dictionary
  • Add target next candles for each timeframe

Step 4: Update Loss Calculation (TODO)

  • Add per-timeframe price losses
  • Add BTC prediction loss
  • Weight losses appropriately (see the sketch below)
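
One illustrative weighting (the numbers are placeholders, not tuned values):

import torch.nn.functional as F

# Hypothetical per-timeframe loss weights: emphasize the traded horizons
TF_LOSS_WEIGHTS = {'1s': 1.0, '1m': 1.0, '1h': 0.5, '1d': 0.25}

weighted, total_w = 0.0, 0.0
for tf in ['1s', '1m', '1h', '1d']:
    if f'target_{tf}' in batch:
        w = TF_LOSS_WEIGHTS[tf]
        weighted += w * F.mse_loss(outputs['next_candles'][tf], batch[f'target_{tf}'])
        total_w += w
total_price_loss = weighted / total_w if total_w else 0.0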

Step 5: Testing (TODO)

  • Unit test: multi-timeframe forward pass
  • Unit test: missing-timeframe handling (see the sketch below)
  • Integration test: full training loop
  • Validation: check predictions for all timeframes
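
A sketch of the missing-timeframes test (pytest style; MultiTimeframeTransformer is a placeholder for the actual model class):

import torch

def test_forward_handles_missing_timeframes():
    model = MultiTimeframeTransformer()  # placeholder constructor
    price_1m = torch.randn(2, 600, 5)    # only 1m is available

    outputs = model(price_data_1m=price_1m)

    # Predictions should exist for every timeframe even with partial inputs
    for tf in ('1s', '1m', '1h', '1d'):
        assert outputs['next_candles'][tf].shape == (2, 5)
    assert outputs['action_probs'].shape == (2, 3)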

Expected Benefits

1. Better Predictions

  • Multi-scale patterns: 1s for micro, 1d for macro trends
  • Cross-timeframe validation: Confirm signals across timeframes
  • Market correlation: BTC reference for market-wide moves

2. Flexible Usage

  • Any timeframe: Choose output timeframe in post-processing
  • Scalping: Use 1s predictions
  • Swing trading: Use 1h/1d predictions
  • Same model: No retraining needed

3. Robustness

  • Missing data: Handles missing 1s gracefully
  • Degraded mode: Works with just 1m if needed
  • Backward compatible: Old code still works

Next Actions

  1. Finish forward() refactor - Make it accept multi-timeframe inputs
  2. Update training adapter - Pass all timeframes to model
  3. Test thoroughly - Ensure no regressions
  4. Document usage - Show how to use multi-timeframe predictions

This is a significant improvement that will make the model much more powerful! 🚀