gogo2/docs/HYBRID_MULTI_TIMEFRAME_ARCHITECTURE.md
Dobromir Popov 738c7cb854 Shared Pattern Encoder
fix T training
2025-11-06 14:27:52 +02:00


Hybrid Multi-Timeframe Transformer Architecture

Overview

The transformer uses a hybrid serial-parallel architecture that:

  1. SERIAL: Learns candle patterns ONCE (shared weights across all timeframes)
  2. PARALLEL: Captures cross-timeframe dependencies simultaneously

This design ensures the model learns common patterns efficiently while understanding relationships between timeframes.


Architecture Flow

Input: Multiple Timeframes
    ↓
┌─────────────────────────────────────────┐
│ STEP 1: SERIAL PROCESSING               │
│ (Shared Pattern Encoder)                │
│                                          │
│ 1s data → Shared Encoder → Encoding_1s  │
│ 1m data → Shared Encoder → Encoding_1m  │
│ 1h data → Shared Encoder → Encoding_1h  │
│ 1d data → Shared Encoder → Encoding_1d  │
│ BTC data → Shared Encoder → Encoding_BTC│
│                                          │
│ Same weights learn patterns once!       │
└─────────────────────────────────────────┘
    ↓
┌─────────────────────────────────────────┐
│ STEP 2: PARALLEL PROCESSING             │
│ (Cross-Timeframe Attention)             │
│                                          │
│ Stack all encodings:                    │
│ [Enc_1s, Enc_1m, Enc_1h, Enc_1d, Enc_BTC]│
│         ↓                                │
│ Cross-Timeframe Transformer Layers      │
│ (Captures dependencies between TFs)     │
│         ↓                                │
│ Unified representation                  │
└─────────────────────────────────────────┘
    ↓
┌─────────────────────────────────────────┐
│ STEP 3: PREDICTION                      │
│                                          │
│ → Action (BUY/SELL/HOLD)                │
│ → Next candle for EACH timeframe        │
│ → BTC next candle                       │
│ → Pivot points                          │
│ → Trend analysis                        │
└─────────────────────────────────────────┘

Key Components

1. Shared Pattern Encoder (SERIAL)

Purpose: Learn candle patterns ONCE for all timeframes

self.shared_pattern_encoder = nn.Sequential(
    nn.Linear(5, 256),      # OHLCV → 256
    nn.LayerNorm(256),
    nn.GELU(),
    nn.Dropout(0.1),
    nn.Linear(256, 512),    # 256 → 512
    nn.LayerNorm(512),
    nn.GELU(),
    nn.Dropout(0.1),
    nn.Linear(512, 1024)    # 512 → 1024 (d_model)
)

How it works:

  • Same network processes ALL timeframes
  • Learns universal candle patterns:
    • Doji, hammer, engulfing, etc.
    • Support/resistance bounces
    • Breakout patterns
    • Volume spikes
  • Efficient: Patterns learned once, not 5 times

Example:

# All timeframes use SAME encoder
encoding_1s = shared_encoder(price_data_1s)   # [batch, 600, 1024]
encoding_1m = shared_encoder(price_data_1m)   # [batch, 600, 1024]
encoding_1h = shared_encoder(price_data_1h)   # [batch, 600, 1024]
encoding_1d = shared_encoder(price_data_1d)   # [batch, 600, 1024]
encoding_btc = shared_encoder(btc_data_1m)    # [batch, 600, 1024]

# Same weights → learns patterns once!
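The shared encoder above can be exercised end to end with a short sketch. The helper name `build_shared_encoder` and the random inputs are assumptions made for illustration; only the layer sizes come from the definition above. Note that `nn.Linear` acts on the last dimension, so as written each candle is encoded on its own; the last check demonstrates this by shuffling candles in time.

```python
import torch
import torch.nn as nn

def build_shared_encoder(d_model: int = 1024) -> nn.Sequential:
    """Per-candle MLP with the layer sizes listed above (hypothetical helper)."""
    return nn.Sequential(
        nn.Linear(5, 256), nn.LayerNorm(256), nn.GELU(), nn.Dropout(0.1),
        nn.Linear(256, 512), nn.LayerNorm(512), nn.GELU(), nn.Dropout(0.1),
        nn.Linear(512, d_model),
    )

torch.manual_seed(0)
encoder = build_shared_encoder().eval()  # eval() disables dropout

# Synthetic inputs: batch of 2, 600 candles, 5 OHLCV features per candle.
inputs = {tf: torch.randn(2, 600, 5) for tf in ("1s", "1m", "1h", "1d", "btc")}

with torch.no_grad():
    # One module, called once per timeframe: the weights are shared.
    encodings = {tf: encoder(x) for tf, x in inputs.items()}

# Shuffling candles in time shuffles the outputs the same way, which shows
# that this stage sees no context beyond the single candle.
perm = torch.randperm(600)
with torch.no_grad():
    shuffled = encoder(inputs["1m"][:, perm])
pointwise = torch.allclose(shuffled, encodings["1m"][:, perm], atol=1e-4)
```

Under this reading, single-candle shapes are captured by the shared encoder, while anything that spans several candles has to come from the attention layers that follow.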

2. Timeframe Embeddings

Purpose: Help model distinguish which timeframe it's looking at

self.timeframe_embeddings = nn.Embedding(5, 1024)
# 5 timeframes: 1s, 1m, 1h, 1d, BTC

How it works:

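# Add timeframe identity to shared encoding
# (nn.Embedding is called with an index tensor; it cannot be subscripted)
tf_embedding = self.timeframe_embeddings(torch.tensor(tf_index))  # [1024]
encoding = shared_encoding + tf_embedding  # broadcasts over [batch, 600, 1024]

# Now model knows: "This is a 1h candle pattern"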
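A minimal runnable sketch of the same idea follows. The index order in `TIMEFRAMES` and the helper `add_timeframe_identity` are assumptions; the document only states that there are five embeddings of width 1024.

```python
import torch
import torch.nn as nn

TIMEFRAMES = ("1s", "1m", "1h", "1d", "btc")  # index order is an assumption
d_model = 1024
timeframe_embeddings = nn.Embedding(len(TIMEFRAMES), d_model)

def add_timeframe_identity(shared_encoding: torch.Tensor, tf: str) -> torch.Tensor:
    """Add the learned timeframe vector to every candle position."""
    idx = torch.tensor(TIMEFRAMES.index(tf), device=shared_encoding.device)
    tf_vec = timeframe_embeddings(idx)   # [d_model]
    return shared_encoding + tf_vec      # broadcasts over [batch, seq_len, d_model]

shared = torch.zeros(2, 600, d_model)    # stand-in for the shared encoder output
tagged_1h = add_timeframe_identity(shared, "1h")
tagged_1d = add_timeframe_identity(shared, "1d")
```

Identical candle encodings end up at different points in the model's input space depending on the timeframe they came from, which is what lets later layers tell them apart.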

3. Cross-Timeframe Attention (PARALLEL)

Purpose: Capture dependencies BETWEEN timeframes

self.cross_timeframe_layers = nn.ModuleList([
    nn.TransformerEncoderLayer(
        d_model=1024,
        nhead=16,
        dim_feedforward=4096,
        dropout=0.1,
        batch_first=True
    ) for _ in range(2)  # 2 layers
])

How it works:

# Stack all timeframes
stacked = torch.stack([enc_1s, enc_1m, enc_1h, enc_1d, enc_btc], dim=1)
# Shape: [batch, 5 timeframes, 600 seq_len, 1024 d_model]

# Reshape for attention
# [batch, 5, 600, 1024] → [batch, 3000, 1024]
cross_input = stacked.reshape(stacked.shape[0], 5 * 600, 1024)

# Apply cross-timeframe attention
# Each position can attend to ALL timeframes simultaneously
for layer in cross_timeframe_layers:
    cross_input = layer(cross_input)

# Model learns:
# - "1s shows breakout, 1h confirms trend"
# - "1d resistance, but 1m shows accumulation"
# - "BTC dumping, ETH following"

What it captures:

  • Trend confirmation: Signal on 1m confirmed by 1h
  • Divergences: 1s bullish but 1d bearish
  • Correlation: BTC moves predict ETH moves
  • Multi-scale patterns: Fractal patterns across timeframes
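The stack, flatten, attend sequence can be checked on toy sizes. This is a sketch under assumptions: the dimensions are shrunk so it runs instantly, and `cross_attend` is a hypothetical helper, not a function from the codebase. It confirms the token order produced by the reshape and shows that changing only the BTC block changes the output for the 1s block.

```python
import torch
import torch.nn as nn

# Toy sizes; the document uses seq_len=600 and d_model=1024.
batch, n_tf, seq_len, d_model = 2, 5, 6, 32

torch.manual_seed(0)
layers = nn.ModuleList([
    nn.TransformerEncoderLayer(d_model=d_model, nhead=4, dim_feedforward=64,
                               dropout=0.0, batch_first=True)
    for _ in range(2)
]).eval()

def cross_attend(stacked: torch.Tensor) -> torch.Tensor:
    """[batch, n_tf, seq_len, d_model] -> same shape, after joint attention."""
    b, t, s, d = stacked.shape
    tokens = stacked.reshape(b, t * s, d)   # timeframe-major token order
    with torch.no_grad():
        for layer in layers:
            tokens = layer(tokens)
    return tokens.reshape(b, t, s, d)

encodings = [torch.randn(batch, seq_len, d_model) for _ in range(n_tf)]
stacked = torch.stack(encodings, dim=1)

# Token k of the flattened sequence is timeframe k // seq_len, candle k % seq_len.
flat = stacked.reshape(batch, n_tf * seq_len, d_model)
same_token = torch.equal(flat[:, 3 * seq_len + 2], encodings[3][:, 2])

out = cross_attend(stacked)

# Perturb only the BTC block (index 4) and observe the 1s block (index 0).
perturbed = stacked.clone()
perturbed[:, 4] += 1.0
btc_affects_1s = not torch.allclose(out[:, 0], cross_attend(perturbed)[:, 0])
```

Because every token attends to every other token, the cost of this stage grows with the square of the flattened length (3000 tokens at full size).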

Benefits of Hybrid Architecture

1. Knowledge Sharing (SERIAL)

Efficient Learning

Traditional: 5 separate encoders × 656K params = 3.28M params
Hybrid: 1 shared encoder × 656K params = 656K params
Savings: 80% fewer parameters!
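The parameter figures can be reproduced by hand from the layer sizes of the shared encoder. The helper names below are illustrative. The 656K figure matches the weight matrices alone; including biases and the two LayerNorm layers gives a slightly larger count, and the 80% saving holds either way.

```python
def linear_params(n_in: int, n_out: int, bias: bool = True) -> int:
    """Weights plus optional bias of a fully connected layer."""
    return n_in * n_out + (n_out if bias else 0)

def layernorm_params(n: int) -> int:
    """Scale and shift vectors of a LayerNorm."""
    return 2 * n

def shared_encoder_params(include_bias_and_norm: bool = True) -> int:
    dims = [(5, 256), (256, 512), (512, 1024)]
    total = sum(linear_params(i, o, bias=include_bias_and_norm) for i, o in dims)
    if include_bias_and_norm:
        total += layernorm_params(256) + layernorm_params(512)
    return total

weights_only = shared_encoder_params(include_bias_and_norm=False)  # 656,640
full = shared_encoder_params()                                     # 659,968
separate = 5 * full                    # five independent encoders
savings = 1 - full / separate          # fraction of encoder parameters saved
```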

Better Generalization

  • Patterns learned from ALL timeframes
  • More training data per pattern
  • Stronger pattern recognition

Transfer Learning

  • Pattern learned on 1m helps 1h
  • Pattern learned on 1d helps 1s
  • Cross-timeframe knowledge transfer

2. Dependency Capture (PARALLEL)

Cross-Timeframe Validation

# Example: Entry signal validation
1s: Bullish breakout (local signal)
1m: Uptrend confirmed (short-term)
1h: Above support (medium-term)
1d: Bullish trend (long-term)
BTC: Also bullish (market-wide)

→ High-confidence entry

Divergence Detection

# Example: Warning signal
1s: Bullish (noise)
1m: Bullish (short-term)
1h: Bearish divergence (warning!)
1d: Downtrend (macro)
BTC: Dumping (market-wide)

→ Don't enter, wait for confirmation

Market Correlation

# Example: BTC influence
BTC: Sharp drop detected
ETH 1s: Following BTC
ETH 1m: Correlation confirmed
ETH 1h: Likely to follow
ETH 1d: Macro trend affected

→ Exit positions, BTC leading

Input/Output Specification

Input Format

model(
    # Primary symbol (ETH/USDT) - all timeframes
    price_data_1s=[batch, 600, 5],   # 600 × 1s candles (10 min)
    price_data_1m=[batch, 600, 5],   # 600 × 1m candles (10 hours)
    price_data_1h=[batch, 600, 5],   # 600 × 1h candles (25 days)
    price_data_1d=[batch, 600, 5],   # 600 × 1d candles (~1.6 years)
    
    # Reference symbol (BTC/USDT)
    btc_data_1m=[batch, 600, 5],     # 600 × 1m BTC candles
    
    # Other features
    cob_data=[batch, 600, 100],      # Order book
    tech_data=[batch, 40],           # Technical indicators
    market_data=[batch, 30],         # Market features
    position_state=[batch, 5]        # Position state
)

Notes:

  • All timeframes optional (handles missing data)
  • Fixed sequence length: 600 candles
  • OHLCV format: [open, high, low, close, volume]
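The document does not say how missing timeframes are handled internally. One plausible approach, sketched here purely as an assumption, is to validate each provided tensor against the fixed `[batch, 600, 5]` shape and stack only those that are present, so that later stages see fewer timeframe blocks. `collect_available` is a hypothetical helper.

```python
from typing import Dict, List, Optional, Tuple

import torch

TIMEFRAMES = ("1s", "1m", "1h", "1d", "btc")
SEQ_LEN, N_FEATURES = 600, 5

def collect_available(
    inputs: Dict[str, Optional[torch.Tensor]]
) -> Tuple[List[str], torch.Tensor]:
    """Validate and stack the timeframes that were actually provided."""
    names, tensors = [], []
    for tf in TIMEFRAMES:
        x = inputs.get(tf)
        if x is None:
            continue  # missing timeframe: skip it
        if x.dim() != 3 or x.shape[1:] != (SEQ_LEN, N_FEATURES):
            raise ValueError(
                f"{tf}: expected [batch, {SEQ_LEN}, {N_FEATURES}], got {tuple(x.shape)}"
            )
        names.append(tf)
        tensors.append(x)
    if not tensors:
        raise ValueError("at least one timeframe is required")
    return names, torch.stack(tensors, dim=1)  # [batch, n_available, 600, 5]

names, stacked = collect_available({
    "1s": None,                      # degraded mode: no 1s feed
    "1m": torch.zeros(4, 600, 5),
    "1h": torch.zeros(4, 600, 5),
    "1d": torch.zeros(4, 600, 5),
    "btc": torch.zeros(4, 600, 5),
})
```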

Output Format

outputs = {
    # Trading decision
    'action_logits': [batch, 3],        # BUY/SELL/HOLD logits
    'action_probs': [batch, 3],         # Softmax probabilities
    'confidence': [batch, 1],           # Prediction confidence
    
    # Next candle predictions (ALL timeframes)
    'next_candles': {
        '1s': [batch, 5],  # Next 1s candle OHLCV
        '1m': [batch, 5],  # Next 1m candle OHLCV
        '1h': [batch, 5],  # Next 1h candle OHLCV
        '1d': [batch, 5]   # Next 1d candle OHLCV
    },
    
    # BTC prediction
    'btc_next_candle': [batch, 5],      # Next BTC 1m candle
    
    # Auxiliary predictions
    'price_prediction': [batch, 1],     # Price target
    'volatility_prediction': [batch, 1], # Expected volatility
    'trend_strength_prediction': [batch, 1], # Trend strength
    
    # Pivot points (L1-L5)
    'next_pivots': {...},
    
    # Trend analysis
    'trend_analysis': {...}
}

Training Strategy

Multi-Timeframe Loss

import torch.nn.functional as F

# Action loss (primary)
action_loss = F.cross_entropy(action_logits, target_action)

# Next candle losses (auxiliary), only for timeframes that have a target
candle_losses = []
for tf in ['1s', '1m', '1h', '1d']:
    if f'target_{tf}' in batch:
        pred = outputs['next_candles'][tf]
        target = batch[f'target_{tf}']
        loss = F.mse_loss(pred, target)
        candle_losses.append(loss)

# BTC loss
if 'target_btc' in batch:
    btc_loss = F.mse_loss(outputs['btc_next_candle'], batch['target_btc'])
    candle_losses.append(btc_loss)

# Combined loss (guard against a batch with no candle targets)
total_loss = action_loss
if candle_losses:
    total_candle_loss = sum(candle_losses) / len(candle_losses)
    total_loss = action_loss + 0.1 * total_candle_loss

Why this works:

  • Action loss: Primary objective (trading decisions)
  • Candle losses: Auxiliary tasks (improve representations)
  • Multi-task learning: Better feature learning
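The mechanism behind the auxiliary tasks is that their gradients flow into the same shared weights as the action loss. The toy model below (a hypothetical trunk and two heads, far smaller than the real network) makes that visible: the gradient of the combined loss on the trunk is the action gradient plus 0.1 times the candle gradient.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
trunk = nn.Linear(5, 16)         # stand-in for the shared representation
action_head = nn.Linear(16, 3)   # BUY/SELL/HOLD logits
candle_head = nn.Linear(16, 5)   # next-candle OHLCV

x = torch.randn(8, 5)
target_action = torch.randint(0, 3, (8,))
target_candle = torch.randn(8, 5)

features = torch.relu(trunk(x))
action_loss = F.cross_entropy(action_head(features), target_action)
candle_loss = F.mse_loss(candle_head(features), target_candle)

# Gradient of each loss with respect to the shared trunk weights.
(g_action,) = torch.autograd.grad(action_loss, trunk.weight, retain_graph=True)
(g_candle,) = torch.autograd.grad(candle_loss, trunk.weight, retain_graph=True)

# Same weighting as the combined loss above.
total_loss = action_loss + 0.1 * candle_loss
(g_total,) = torch.autograd.grad(total_loss, trunk.weight)
```

The 0.1 weight therefore scales how strongly candle prediction shapes the shared features relative to the trading objective.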

Usage Examples

Example 1: All Timeframes Available

outputs = model(
    price_data_1s=eth_1s,
    price_data_1m=eth_1m,
    price_data_1h=eth_1h,
    price_data_1d=eth_1d,
    btc_data_1m=btc_1m,
    position_state=position
)

# Get action
action = torch.argmax(outputs['action_probs'], dim=-1)  # one action index per sample

# Get next candle predictions for all timeframes
next_1s = outputs['next_candles']['1s']
next_1m = outputs['next_candles']['1m']
next_1h = outputs['next_candles']['1h']
next_1d = outputs['next_candles']['1d']
next_btc = outputs['btc_next_candle']
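When the output is batched, the reduction has to run over the last dimension, otherwise `torch.argmax` returns an index into the flattened tensor. The sketch below shows the difference. The index-to-label order in `ACTIONS` is an assumption taken from the BUY/SELL/HOLD ordering used in this document; check it against the training labels before relying on it.

```python
import torch

ACTIONS = ("BUY", "SELL", "HOLD")  # assumed index order

def decode_actions(action_probs: torch.Tensor) -> list:
    """Map [batch, 3] probabilities to one label per sample."""
    indices = action_probs.argmax(dim=-1)   # [batch]
    return [ACTIONS[i] for i in indices.tolist()]

probs = torch.tensor([
    [0.7, 0.2, 0.1],
    [0.4, 0.35, 0.25],
    [0.1, 0.8, 0.1],
])

decoded = decode_actions(probs)

# Without dim, argmax flattens the tensor first: this is a position in the
# 9-element flattened array, not an action index.
flat_index = torch.argmax(probs).item()
```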

Example 2: Missing 1s Data (Degraded Mode)

# 1s data not available
outputs = model(
    price_data_1m=eth_1m,
    price_data_1h=eth_1h,
    price_data_1d=eth_1d,
    btc_data_1m=btc_1m,
    position_state=position
)

# Still works: the model runs on the timeframes that are available
action = torch.argmax(outputs['action_probs'], dim=-1)

# A 1s prediction is still returned, but it is inferred from the other
# timeframes only
next_1s = outputs['next_candles']['1s']

Example 3: Legacy Single Timeframe

# Old code still works
outputs = model(
    price_data=eth_1m,  # Legacy parameter
    position_state=position
)

# Automatically uses as 1m data
action = torch.argmax(outputs['action_probs'], dim=-1)

Performance Characteristics

Memory Usage

Per Sample:

5 timeframes × 600 candles × 5 OHLCV = 15,000 values
15,000 × 4 bytes = 60 KB input

Shared encoder: 656K params
Cross-TF layers: ~25M params (2 layers × ~12.6M at d_model=1024, dim_feedforward=4096)
Total: ~26M params for multi-TF processing

Batch of 5: 300 KB input, manageable
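These figures can be recomputed from the shapes and layer sizes given earlier. The sketch assumes float32 inputs (4 bytes per value) and the default layout of `nn.TransformerEncoderLayer` (biases everywhere, two LayerNorms); the function names are illustrative. With d_model=1024 and dim_feedforward=4096, two cross-timeframe layers come to roughly 25M parameters.

```python
def input_bytes(batch: int = 1, n_timeframes: int = 5, seq_len: int = 600,
                n_features: int = 5, bytes_per_value: int = 4) -> int:
    """Raw size of the OHLCV inputs for one forward pass."""
    return batch * n_timeframes * seq_len * n_features * bytes_per_value

def encoder_layer_params(d_model: int = 1024, dim_feedforward: int = 4096) -> int:
    """Parameters of one transformer encoder layer with default biases."""
    attn = 3 * (d_model * d_model + d_model)   # packed Q, K, V projections
    attn += d_model * d_model + d_model        # attention output projection
    ffn = d_model * dim_feedforward + dim_feedforward
    ffn += dim_feedforward * d_model + d_model
    norms = 2 * (2 * d_model)                  # two LayerNorms
    return attn + ffn + norms

per_sample = input_bytes()              # 60,000 bytes = 60 KB
batch_of_5 = input_bytes(batch=5)       # 300,000 bytes = 300 KB
per_layer = encoder_layer_params()      # 12,596,224
cross_tf = 2 * per_layer                # 25,192,448
```

Input size is negligible next to the weights, and both are small next to the activations of attention over 3000 tokens.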

Computational Cost

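Forward Pass (multiply-accumulates per sample, using the layer sizes above):

1. Shared encoder: 5 × (600 × 656K) = ~2B ops
2. Cross-TF attention: 2 layers × ~56B = ~112B ops
   (per layer over 3000 tokens: ~12.6B projections + ~18.4B attention scores and weighted sum + ~25.2B feed-forward)
3. Main transformer: 12 layers over 600 tokens (cost depends on its layer sizes, which are not listed here)

The cross-TF layers dominate. 3000 × 3000 = 9M is only the number of token pairs per layer; each pair costs d_model = 1024 multiply-accumulates for the scores and again for the weighted sum.

Compared to Separate Encoders:

Traditional: 5 encoders × (600 × 656K) = ~2B ops
Hybrid: 1 encoder applied 5 times × (600 × 656K) = ~2B ops
Compute: the same. Weight sharing saves parameters (656K instead of 3.28M), not forward-pass operations, because every timeframe is still encoded.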
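A rough multiply-accumulate (MAC) count for the layer sizes given in this document can be sketched as follows. It ignores LayerNorm, softmax and activation costs, and the function names are illustrative. Note that 3000 × 3000 counts token pairs only; each pair costs d_model MACs for the score and again for the weighted sum, and the projections and feed-forward block add more.

```python
def encoder_macs(seq_len: int = 600, weights: int = 656_640) -> int:
    """Per-candle MLP: every weight is used once per candle."""
    return seq_len * weights

def attention_layer_macs(seq_len: int, d_model: int = 1024,
                         dim_feedforward: int = 4096) -> int:
    """One transformer encoder layer over seq_len tokens."""
    proj = 4 * seq_len * d_model * d_model       # Q, K, V and output projections
    scores = 2 * seq_len * seq_len * d_model     # QK^T, then weights times V
    ffn = 2 * seq_len * d_model * dim_feedforward
    return proj + scores + ffn

shared = 5 * encoder_macs()                 # encoder applied to 5 inputs
cross_tf = 2 * attention_layer_macs(3000)   # 2 layers over 5 * 600 tokens

# Five separate encoders do the same work as one encoder applied five times.
separate = sum(encoder_macs() for _ in range(5))
```

On this estimate the cross-timeframe layers cost tens of times more than the encoder stage, and sharing the encoder changes the parameter count rather than the operation count.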

Key Insights

1. Pattern Universality

Candle patterns are universal across timeframes:

  • Doji on 1s = Doji on 1d (same pattern, different scale)
  • Hammer on 1m = Hammer on 1h (same reversal signal)
  • Shared encoder exploits this universality

2. Scale Invariance

The model learns scale-invariant features:

  • Normalized OHLCV removes absolute price scale
  • Patterns recognized regardless of timeframe
  • Timeframe embeddings add scale context
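The document does not specify the normalization it uses. As one hypothetical scheme, prices can be expressed relative to the last close in the window and volume relative to the window's mean volume. The sketch below (with the assumed helper `normalize_ohlcv`) checks that the result is unchanged when the same candles are replayed at a price level 1000 times higher.

```python
import torch

def normalize_ohlcv(candles: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """candles: [batch, seq_len, 5] as [open, high, low, close, volume]."""
    prices, volume = candles[..., :4], candles[..., 4:]
    last_close = prices[:, -1:, 3:4]                 # [batch, 1, 1]
    mean_volume = volume.mean(dim=1, keepdim=True)   # [batch, 1, 1]
    rel_prices = prices / (last_close + eps) - 1.0   # fractional distance from last close
    rel_volume = volume / (mean_volume + eps)        # multiple of average volume
    return torch.cat([rel_prices, rel_volume], dim=-1)

torch.manual_seed(0)
base = torch.rand(2, 600, 5) + 1.0   # strictly positive synthetic candles
scaled = base.clone()
scaled[..., :4] *= 1000.0            # same shapes at a 1000x higher price level

normalized = normalize_ohlcv(base)
scale_invariant = torch.allclose(normalized, normalize_ohlcv(scaled), atol=1e-4)
```

With inputs like these, the encoder cannot tell a cheap asset from an expensive one, so the timeframe embedding is the only explicit source of scale context.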

3. Cross-Scale Validation

Multi-timeframe attention enables validation:

  • Micro signals (1s) validated by macro trends (1d)
  • Reduces false signals
  • Increases prediction confidence

4. Market Correlation

BTC reference captures market-wide moves:

  • BTC leads, altcoins follow
  • Market-wide sentiment
  • Risk-on/risk-off detection

Comparison to Alternatives

vs. Separate Models per Timeframe

| Aspect | Separate Models | Hybrid Architecture |
|---|---|---|
| Parameters | 5 × 46M = 230M | 46M |
| Training Time | 5x longer | 1x |
| Pattern Learning | 5x redundant | Shared |
| Cross-TF Dependencies | None | Captured |
| Memory Usage | 5x higher | 1x |
| Inference Speed | 5x slower | 1x |

vs. Single Concatenated Input

| Aspect | Concatenation | Hybrid Architecture |
|---|---|---|
| Pattern Sharing | No | Yes |
| Cross-TF Attention | No | Yes |
| Missing Data | Breaks | Handles |
| Interpretability | Low | High |
| Efficiency | Medium | High |

Summary

The hybrid serial-parallel architecture provides:

  • Efficient Pattern Learning: Shared encoder learns once
  • Cross-Timeframe Dependencies: Parallel attention captures relationships
  • Flexible Input: Handles missing timeframes gracefully
  • Multi-Scale Predictions: Predicts next candle for ALL timeframes
  • Market Correlation: BTC reference for market-wide context
  • Backward Compatible: Legacy code still works

This design maximizes both efficiency and expressiveness! 🚀