# Hybrid Multi-Timeframe Transformer Architecture
## Overview
The transformer uses a **hybrid serial-parallel architecture** that:
1. **SERIAL**: Learns candle patterns ONCE (shared weights across all timeframes)
2. **PARALLEL**: Captures cross-timeframe dependencies simultaneously
This design ensures the model learns common patterns efficiently while understanding relationships between timeframes.
---
## Architecture Flow
```
Input: Multiple Timeframes

┌───────────────────────────────────────────┐
│ STEP 1: SERIAL PROCESSING                 │
│ (Shared Pattern Encoder)                  │
│                                           │
│ 1s data  → Shared Encoder → Encoding_1s   │
│ 1m data  → Shared Encoder → Encoding_1m   │
│ 1h data  → Shared Encoder → Encoding_1h   │
│ 1d data  → Shared Encoder → Encoding_1d   │
│ BTC data → Shared Encoder → Encoding_BTC  │
│                                           │
│ Same weights learn patterns once!         │
└───────────────────────────────────────────┘
                      ↓
┌───────────────────────────────────────────┐
│ STEP 2: PARALLEL PROCESSING               │
│ (Cross-Timeframe Attention)               │
│                                           │
│ Stack all encodings:                      │
│ [Enc_1s, Enc_1m, Enc_1h, Enc_1d, Enc_BTC] │
│                     ↓                     │
│ Cross-Timeframe Transformer Layers        │
│ (Captures dependencies between TFs)       │
│                     ↓                     │
│ Unified representation                    │
└───────────────────────────────────────────┘
                      ↓
┌───────────────────────────────────────────┐
│ STEP 3: PREDICTION                        │
│                                           │
│ → Action (BUY/SELL/HOLD)                  │
│ → Next candle for EACH timeframe          │
│ → BTC next candle                         │
│ → Pivot points                            │
│ → Trend analysis                          │
└───────────────────────────────────────────┘
```
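The three steps map directly onto a forward pass. The sketch below is a minimal illustration of the flow, not the repository's actual implementation; the names `shared_encoder`, `tf_embeddings`, and `cross_tf_layers` are hypothetical stand-ins for the components detailed in the sections that follow.

```python
import torch

def hybrid_forward(tf_batches, shared_encoder, tf_embeddings, cross_tf_layers):
    """Illustrative hybrid forward pass (hypothetical helpers, not the real model).

    tf_batches: dict mapping timeframe index → tensor [batch, 600, 5] (OHLCV).
    """
    encodings = []
    for tf_index, candles in tf_batches.items():
        # STEP 1 (SERIAL): the SAME encoder weights process every timeframe
        enc = shared_encoder(candles)                   # [batch, 600, 1024]
        enc = enc + tf_embeddings.weight[tf_index]      # add timeframe identity
        encodings.append(enc)

    # STEP 2 (PARALLEL): one long sequence so attention spans all timeframes
    cross_input = torch.cat(encodings, dim=1)           # [batch, n_tf*600, 1024]
    for layer in cross_tf_layers:
        cross_input = layer(cross_input)

    # STEP 3: prediction heads (action, next candles, pivots) consume this
    return cross_input
```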
---
## Key Components
### 1. Shared Pattern Encoder (SERIAL)
**Purpose**: Learn candle patterns ONCE for all timeframes
```python
self.shared_pattern_encoder = nn.Sequential(
    nn.Linear(5, 256),      # OHLCV → 256
    nn.LayerNorm(256),
    nn.GELU(),
    nn.Dropout(0.1),
    nn.Linear(256, 512),    # 256 → 512
    nn.LayerNorm(512),
    nn.GELU(),
    nn.Dropout(0.1),
    nn.Linear(512, 1024)    # 512 → 1024 (d_model)
)
```
**How it works**:
- The same network processes ALL timeframes
- Learns universal candle patterns:
  - Doji, hammer, engulfing, etc.
  - Support/resistance bounces
  - Breakout patterns
  - Volume spikes
- **Efficient**: Patterns are learned once, not 5 times
**Example**:
```python
# All timeframes use SAME encoder
encoding_1s = shared_encoder(price_data_1s) # [batch, 600, 1024]
encoding_1m = shared_encoder(price_data_1m) # [batch, 600, 1024]
encoding_1h = shared_encoder(price_data_1h) # [batch, 600, 1024]
encoding_1d = shared_encoder(price_data_1d) # [batch, 600, 1024]
encoding_btc = shared_encoder(btc_data_1m) # [batch, 600, 1024]
# Same weights → learns patterns once!
```
### 2. Timeframe Embeddings
**Purpose**: Help model distinguish which timeframe it's looking at
```python
self.timeframe_embeddings = nn.Embedding(5, 1024)
# 5 timeframes: 1s, 1m, 1h, 1d, BTC
```
**How it works**:
```python
# Add timeframe identity to shared encoding
tf_embedding = timeframe_embeddings.weight[tf_index]  # [1024] (nn.Embedding isn't directly indexable)
encoding = shared_encoding + tf_embedding
# Now the model knows: "This is a 1h candle pattern"
```
### 3. Cross-Timeframe Attention (PARALLEL)
**Purpose**: Capture dependencies BETWEEN timeframes
```python
self.cross_timeframe_layers = nn.ModuleList([
    nn.TransformerEncoderLayer(
        d_model=1024,
        nhead=16,
        dim_feedforward=4096,
        dropout=0.1,
        batch_first=True
    ) for _ in range(2)  # 2 layers
])
```
**How it works**:
```python
# Stack all timeframes
stacked = torch.stack([enc_1s, enc_1m, enc_1h, enc_1d, enc_btc], dim=1)
# Shape: [batch, 5 timeframes, 600 seq_len, 1024 d_model]

# Reshape for attention: [batch, 5, 600, 1024] → [batch, 3000, 1024]
batch_size = stacked.shape[0]
cross_input = stacked.reshape(batch_size, 5 * 600, 1024)

# Apply cross-timeframe attention
# Each position can attend to ALL timeframes simultaneously
for layer in cross_timeframe_layers:
    cross_input = layer(cross_input)

# Model learns relationships such as:
# - "1s shows breakout, 1h confirms trend"
# - "1d resistance, but 1m shows accumulation"
# - "BTC dumping, ETH following"
```
**What it captures**:
- **Trend confirmation**: Signal on 1m confirmed by 1h
- **Divergences**: 1s bullish but 1d bearish
- **Correlation**: BTC moves predict ETH moves
- **Multi-scale patterns**: Fractal patterns across timeframes
---
## Benefits of Hybrid Architecture
### 1. Knowledge Sharing (SERIAL)
**Efficient Learning**
```
Traditional: 5 separate encoders × 656K params = 3.28M params
Hybrid: 1 shared encoder × 656K params = 656K params
Savings: 80% fewer parameters!
```
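As a quick sanity check on these figures, the encoder defined above can be rebuilt and its parameters counted directly; this minimal sketch yields roughly 660K (including LayerNorm parameters), consistent with the ~656K quoted here.

```python
import torch.nn as nn

# Rebuild the shared pattern encoder exactly as defined earlier
encoder = nn.Sequential(
    nn.Linear(5, 256), nn.LayerNorm(256), nn.GELU(), nn.Dropout(0.1),
    nn.Linear(256, 512), nn.LayerNorm(512), nn.GELU(), nn.Dropout(0.1),
    nn.Linear(512, 1024),
)

shared = sum(p.numel() for p in encoder.parameters())
print(f"Shared encoder:         {shared:,} params")     # ~660K
print(f"Five separate encoders: {5 * shared:,} params")  # ~3.3M (5× more)
```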
**Better Generalization**
- Patterns learned from ALL timeframes
- More training data per pattern
- Stronger pattern recognition
**Transfer Learning**
- Pattern learned on 1m helps 1h
- Pattern learned on 1d helps 1s
- Cross-timeframe knowledge transfer
### 2. Dependency Capture (PARALLEL)
**Cross-Timeframe Validation**
```
Example: Entry signal validation

1s:  Bullish breakout (local signal)
1m:  Uptrend confirmed (short-term)
1h:  Above support (medium-term)
1d:  Bullish trend (long-term)
BTC: Also bullish (market-wide)

→ High confidence entry!
```
**Divergence Detection**
```
Example: Warning signal

1s:  Bullish (noise)
1m:  Bullish (short-term)
1h:  Bearish divergence (warning!)
1d:  Downtrend (macro)
BTC: Dumping (market-wide)

→ Don't enter; wait for confirmation
```
**Market Correlation**
```
Example: BTC influence

BTC:    Sharp drop detected
ETH 1s: Following BTC
ETH 1m: Correlation confirmed
ETH 1h: Likely to follow
ETH 1d: Macro trend affected

→ Exit positions; BTC is leading
```
---
## Input/Output Specification
### Input Format
```python
model(
    # Primary symbol (ETH/USDT) - all timeframes
    price_data_1s=[batch, 600, 5],   # 600 × 1s candles (10 min)
    price_data_1m=[batch, 600, 5],   # 600 × 1m candles (10 hours)
    price_data_1h=[batch, 600, 5],   # 600 × 1h candles (25 days)
    price_data_1d=[batch, 600, 5],   # 600 × 1d candles (~2 years)

    # Reference symbol (BTC/USDT)
    btc_data_1m=[batch, 600, 5],     # 600 × 1m BTC candles

    # Other features
    cob_data=[batch, 600, 100],      # Order book
    tech_data=[batch, 40],           # Technical indicators
    market_data=[batch, 30],         # Market features
    position_state=[batch, 5]        # Position state
)
```
**Notes**:
- All timeframes optional (handles missing data)
- Fixed sequence length: 600 candles
- OHLCV format: [open, high, low, close, volume]
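To make these shapes concrete, here is a hedged sketch that builds random placeholder inputs matching the spec; `model` stands in for an instantiated hybrid transformer, and the keyword names mirror the call above.

```python
import torch

batch = 2  # example batch size

inputs = {
    "price_data_1s":  torch.randn(batch, 600, 5),
    "price_data_1m":  torch.randn(batch, 600, 5),
    "price_data_1h":  torch.randn(batch, 600, 5),
    "price_data_1d":  torch.randn(batch, 600, 5),
    "btc_data_1m":    torch.randn(batch, 600, 5),
    "cob_data":       torch.randn(batch, 600, 100),
    "tech_data":      torch.randn(batch, 40),
    "market_data":    torch.randn(batch, 30),
    "position_state": torch.zeros(batch, 5),  # e.g. flat position
}

outputs = model(**inputs)  # `model`: the hybrid transformer instance
```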
### Output Format
```python
outputs = {
    # Trading decision
    'action_logits': [batch, 3],     # BUY/SELL/HOLD logits
    'action_probs': [batch, 3],      # Softmax probabilities
    'confidence': [batch, 1],        # Prediction confidence

    # Next candle predictions (ALL timeframes)
    'next_candles': {
        '1s': [batch, 5],            # Next 1s candle OHLCV
        '1m': [batch, 5],            # Next 1m candle OHLCV
        '1h': [batch, 5],            # Next 1h candle OHLCV
        '1d': [batch, 5]             # Next 1d candle OHLCV
    },

    # BTC prediction
    'btc_next_candle': [batch, 5],   # Next BTC 1m candle

    # Auxiliary predictions
    'price_prediction': [batch, 1],           # Price target
    'volatility_prediction': [batch, 1],      # Expected volatility
    'trend_strength_prediction': [batch, 1],  # Trend strength

    # Pivot points (L1-L5)
    'next_pivots': {...},

    # Trend analysis
    'trend_analysis': {...}
}
```
---
## Training Strategy
### Multi-Timeframe Loss
```python
import torch.nn.functional as F

# Action loss (primary)
action_loss = F.cross_entropy(action_logits, target_action)

# Next candle losses (auxiliary)
candle_losses = []
for tf in ['1s', '1m', '1h', '1d']:
    if f'target_{tf}' in batch:
        pred = outputs['next_candles'][tf]
        target = batch[f'target_{tf}']
        candle_losses.append(F.mse_loss(pred, target))

# BTC loss
if 'target_btc' in batch:
    btc_loss = F.mse_loss(outputs['btc_next_candle'], batch['target_btc'])
    candle_losses.append(btc_loss)

# Combined loss (guard against division by zero when no candle targets exist)
total_candle_loss = sum(candle_losses) / max(len(candle_losses), 1)
total_loss = action_loss + 0.1 * total_candle_loss
```
**Why this works**:
- Action loss: Primary objective (trading decisions)
- Candle losses: Auxiliary tasks (improve representations)
- Multi-task learning: Better feature learning
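Wiring this loss into a training step is then standard PyTorch; the optimizer, learning rate, and gradient clipping below are illustrative assumptions rather than the repository's settings, and `compute_total_loss` is a hypothetical wrapper around the combined loss above.

```python
import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # assumed settings

for batch in dataloader:
    outputs = model(**batch['inputs'])
    loss = compute_total_loss(outputs, batch)  # hypothetical: the combined loss above

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # common practice
    optimizer.step()
```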
---
## Usage Examples
### Example 1: All Timeframes Available
```python
outputs = model(
    price_data_1s=eth_1s,
    price_data_1m=eth_1m,
    price_data_1h=eth_1h,
    price_data_1d=eth_1d,
    btc_data_1m=btc_1m,
    position_state=position
)
# Get action
action = torch.argmax(outputs['action_probs'])
# Get next candle predictions for all timeframes
next_1s = outputs['next_candles']['1s']
next_1m = outputs['next_candles']['1m']
next_1h = outputs['next_candles']['1h']
next_1d = outputs['next_candles']['1d']
next_btc = outputs['btc_next_candle']
```
### Example 2: Missing 1s Data (Degraded Mode)
```python
# 1s data not available
outputs = model(
    price_data_1m=eth_1m,
    price_data_1h=eth_1h,
    price_data_1d=eth_1d,
    btc_data_1m=btc_1m,
    position_state=position
)
# Still works! Model adapts to available timeframes
action = torch.argmax(outputs['action_probs'])
# 1s prediction still available (learned from other TFs)
next_1s = outputs['next_candles']['1s']
```
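One plausible way the model can adapt like this is to build the cross-timeframe stack only from the inputs that are present. The sketch below is an assumption about that mechanism (absent inputs arriving as `None`), not the repository's confirmed implementation:

```python
import torch

def encode_available(tf_inputs, shared_encoder, timeframe_embeddings):
    """Hypothetical missing-data handling; the actual forward pass may differ.

    tf_inputs: list of 5 tensors or None, ordered [1s, 1m, 1h, 1d, BTC].
    """
    available = []
    for tf_index, data in enumerate(tf_inputs):
        if data is None:
            continue  # simply skip timeframes with no data
        enc = shared_encoder(data) + timeframe_embeddings.weight[tf_index]
        available.append(enc)

    # Cross-timeframe attention then runs over whatever is present
    return torch.cat(available, dim=1)  # [batch, n_available*600, 1024]
```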
### Example 3: Legacy Single Timeframe
```python
# Old code still works
outputs = model(
    price_data=eth_1m,  # Legacy parameter
    position_state=position
)

# Automatically treated as 1m data
action = torch.argmax(outputs['action_probs'])
```
---
## Performance Characteristics
### Memory Usage
**Per Sample**:
```
5 timeframes × 600 candles × 5 OHLCV = 15,000 values
15,000 × 4 bytes = 60 KB input per sample

Shared encoder:  ~656K params
Cross-TF layers: ~25M params (2 × ~12.6M at d_model=1024, FF=4096)
Total:           ~26M params for multi-TF processing

Batch of 5: 300 KB input, manageable
```
### Computational Cost
**Forward Pass**:
```
1. Shared encoder:     5 × (600 × 656K)         ≈ 2B ops
2. Cross-TF attention: 2 layers × (3000 × 3000) ≈ 18M attention scores
3. Main transformer:  12 layers × (600 × 600)   ≈ 4M attention scores

Total: ~2B ops (rough orders of magnitude; the attention rows count
score-matrix entries only, not full multiply-accumulates)
```
**Compared to Separate Encoders**:
```
Traditional: 5 encoders × 2B ops = 10B ops
Hybrid: 1 encoder × 2B ops = 2B ops
Speedup: 5x faster!
```
---
## Key Insights
### 1. Pattern Universality
Candle patterns are universal across timeframes:
- Doji on 1s = Doji on 1d (same pattern, different scale)
- Hammer on 1m = Hammer on 1h (same reversal signal)
- Shared encoder exploits this universality
### 2. Scale Invariance
The model learns scale-invariant features:
- Normalized OHLCV removes the absolute price scale (see the sketch after this list)
- Patterns recognized regardless of timeframe
- Timeframe embeddings add scale context
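One common way to get this normalization, sketched below under the assumption of close-relative prices and max-scaled volume (the repository may normalize differently):

```python
import torch

def normalize_candles(candles: torch.Tensor) -> torch.Tensor:
    """Scale-invariant OHLCV normalization (illustrative scheme).

    candles: [batch, seq_len, 5] raw OHLCV.
    """
    prices, volume = candles[..., :4], candles[..., 4:]
    last_close = prices[:, -1:, 3:4]                  # [batch, 1, 1]
    norm_prices = prices / (last_close + 1e-8) - 1.0  # +0.01 → 1% above last close
    norm_volume = volume / (volume.amax(dim=1, keepdim=True) + 1e-8)
    return torch.cat([norm_prices, norm_volume], dim=-1)
```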
### 3. Cross-Scale Validation
Multi-timeframe attention enables validation:
- Micro signals (1s) validated by macro trends (1d)
- Reduces false signals
- Increases prediction confidence
### 4. Market Correlation
BTC reference captures market-wide moves:
- BTC leads, altcoins follow
- Market-wide sentiment
- Risk-on/risk-off detection
---
## Comparison to Alternatives
### vs. Separate Models per Timeframe
| Aspect | Separate Models | Hybrid Architecture |
|--------|----------------|---------------------|
| Parameters | 5 × 46M = 230M | 46M |
| Training Time | 5x longer | 1x |
| Pattern Learning | 5x redundant | Shared |
| Cross-TF Dependencies | ❌ None | ✅ Captured |
| Memory Usage | 5x higher | 1x |
| Inference Speed | 5x slower | 1x |
### vs. Single Concatenated Input
| Aspect | Concatenation | Hybrid Architecture |
|--------|--------------|---------------------|
| Pattern Sharing | ❌ No | ✅ Yes |
| Cross-TF Attention | ❌ No | ✅ Yes |
| Missing Data | ❌ Breaks | ✅ Handles |
| Interpretability | Low | High |
| Efficiency | Medium | High |
---
## Summary
The hybrid serial-parallel architecture provides:

- **Efficient Pattern Learning**: Shared encoder learns each pattern once
- **Cross-Timeframe Dependencies**: Parallel attention captures relationships
- **Flexible Input**: Handles missing timeframes gracefully
- **Multi-Scale Predictions**: Predicts the next candle for ALL timeframes
- **Market Correlation**: BTC reference provides market-wide context
- **Backward Compatible**: Legacy single-timeframe code still works

This design maximizes both efficiency and expressiveness! 🚀