Shared Pattern Encoder

fix T training
This commit is contained in:
Dobromir Popov
2025-11-06 14:27:52 +02:00
parent 07d97100c0
commit 738c7cb854
5 changed files with 1276 additions and 180 deletions

View File

@@ -0,0 +1,478 @@
# Hybrid Multi-Timeframe Transformer Architecture
## Overview
The transformer uses a **hybrid serial-parallel architecture** that:
1. **SERIAL**: Learns candle patterns ONCE (shared weights across all timeframes)
2. **PARALLEL**: Captures cross-timeframe dependencies simultaneously
This design ensures the model learns common patterns efficiently while understanding relationships between timeframes.
---
## Architecture Flow
```
Input: Multiple Timeframes
┌─────────────────────────────────────────┐
│ STEP 1: SERIAL PROCESSING │
│ (Shared Pattern Encoder) │
│ │
│ 1s data → Shared Encoder → Encoding_1s │
│ 1m data → Shared Encoder → Encoding_1m │
│ 1h data → Shared Encoder → Encoding_1h │
│ 1d data → Shared Encoder → Encoding_1d │
│ BTC data → Shared Encoder → Encoding_BTC│
│ │
│ Same weights learn patterns once! │
└─────────────────────────────────────────┘
┌─────────────────────────────────────────┐
│ STEP 2: PARALLEL PROCESSING │
│ (Cross-Timeframe Attention) │
│ │
│ Stack all encodings: │
│ [Enc_1s, Enc_1m, Enc_1h, Enc_1d, Enc_BTC]│
│ ↓ │
│ Cross-Timeframe Transformer Layers │
│ (Captures dependencies between TFs) │
│ ↓ │
│ Unified representation │
└─────────────────────────────────────────┘
┌─────────────────────────────────────────┐
│ STEP 3: PREDICTION │
│ │
│ → Action (BUY/SELL/HOLD) │
│ → Next candle for EACH timeframe │
│ → BTC next candle │
│ → Pivot points │
│ → Trend analysis │
└─────────────────────────────────────────┘
```
---
## Key Components
### 1. Shared Pattern Encoder (SERIAL)
**Purpose**: Learn candle patterns ONCE for all timeframes
```python
self.shared_pattern_encoder = nn.Sequential(
nn.Linear(5, 256), # OHLCV → 256
nn.LayerNorm(256),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(256, 512), # 256 → 512
nn.LayerNorm(512),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(512, 1024) # 512 → 1024 (d_model)
)
```
**How it works**:
- Same network processes ALL timeframes
- Learns universal candle patterns:
- Doji, hammer, engulfing, etc.
- Support/resistance bounces
- Breakout patterns
- Volume spikes
- **Efficient**: Patterns learned once, not 5 times
**Example**:
```python
# All timeframes use SAME encoder
encoding_1s = shared_encoder(price_data_1s) # [batch, 600, 1024]
encoding_1m = shared_encoder(price_data_1m) # [batch, 600, 1024]
encoding_1h = shared_encoder(price_data_1h) # [batch, 600, 1024]
encoding_1d = shared_encoder(price_data_1d) # [batch, 600, 1024]
encoding_btc = shared_encoder(btc_data_1m) # [batch, 600, 1024]
# Same weights → learns patterns once!
```
### 2. Timeframe Embeddings
**Purpose**: Help model distinguish which timeframe it's looking at
```python
self.timeframe_embeddings = nn.Embedding(5, 1024)
# 5 timeframes: 1s, 1m, 1h, 1d, BTC
```
**How it works**:
```python
# Add timeframe identity to shared encoding
tf_embedding = timeframe_embeddings[tf_index] # [1024]
encoding = shared_encoding + tf_embedding
# Now model knows: "This is a 1h candle pattern"
```
### 3. Cross-Timeframe Attention (PARALLEL)
**Purpose**: Capture dependencies BETWEEN timeframes
```python
self.cross_timeframe_layers = nn.ModuleList([
nn.TransformerEncoderLayer(
d_model=1024,
nhead=16,
dim_feedforward=4096,
dropout=0.1,
batch_first=True
) for _ in range(2) # 2 layers
])
```
**How it works**:
```python
# Stack all timeframes
stacked = torch.stack([enc_1s, enc_1m, enc_1h, enc_1d, enc_btc], dim=1)
# Shape: [batch, 5 timeframes, 600 seq_len, 1024 d_model]
# Reshape for attention
# [batch, 5, 600, 1024] → [batch, 3000, 1024]
cross_input = stacked.reshape(batch, 5*600, 1024)
# Apply cross-timeframe attention
# Each position can attend to ALL timeframes simultaneously
for layer in cross_timeframe_layers:
cross_input = layer(cross_input)
# Model learns:
# - "1s shows breakout, 1h confirms trend"
# - "1d resistance, but 1m shows accumulation"
# - "BTC dumping, ETH following"
```
**What it captures**:
- **Trend confirmation**: Signal on 1m confirmed by 1h
- **Divergences**: 1s bullish but 1d bearish
- **Correlation**: BTC moves predict ETH moves
- **Multi-scale patterns**: Fractal patterns across timeframes
---
## Benefits of Hybrid Architecture
### 1. Knowledge Sharing (SERIAL)
**Efficient Learning**
```
Traditional: 5 separate encoders × 656K params = 3.28M params
Hybrid: 1 shared encoder × 656K params = 656K params
Savings: 80% fewer parameters!
```
**Better Generalization**
- Patterns learned from ALL timeframes
- More training data per pattern
- Stronger pattern recognition
**Transfer Learning**
- Pattern learned on 1m helps 1h
- Pattern learned on 1d helps 1s
- Cross-timeframe knowledge transfer
### 2. Dependency Capture (PARALLEL)
**Cross-Timeframe Validation**
```python
# Example: Entry signal validation
1s: Bullish breakout (local signal)
1m: Uptrend confirmed (short-term)
1h: Above support (medium-term)
1d: Bullish trend (long-term)
BTC: Also bullish (market-wide)
High confidence entry!
```
**Divergence Detection**
```python
# Example: Warning signal
1s: Bullish (noise)
1m: Bullish (short-term)
1h: Bearish divergence (warning!)
1d: Downtrend (macro)
BTC: Dumping (market-wide)
Don't enter, wait for confirmation
```
**Market Correlation**
```python
# Example: BTC influence
BTC: Sharp drop detected
ETH 1s: Following BTC
ETH 1m: Correlation confirmed
ETH 1h: Likely to follow
ETH 1d: Macro trend affected
Exit positions, BTC leading
```
---
## Input/Output Specification
### Input Format
```python
model(
# Primary symbol (ETH/USDT) - all timeframes
price_data_1s=[batch, 600, 5], # 600 × 1s candles (10 min)
price_data_1m=[batch, 600, 5], # 600 × 1m candles (10 hours)
price_data_1h=[batch, 600, 5], # 600 × 1h candles (25 days)
price_data_1d=[batch, 600, 5], # 600 × 1d candles (~2 years)
# Reference symbol (BTC/USDT)
btc_data_1m=[batch, 600, 5], # 600 × 1m BTC candles
# Other features
cob_data=[batch, 600, 100], # Order book
tech_data=[batch, 40], # Technical indicators
market_data=[batch, 30], # Market features
position_state=[batch, 5] # Position state
)
```
**Notes**:
- All timeframes optional (handles missing data)
- Fixed sequence length: 600 candles
- OHLCV format: [open, high, low, close, volume]
### Output Format
```python
outputs = {
# Trading decision
'action_logits': [batch, 3], # BUY/SELL/HOLD logits
'action_probs': [batch, 3], # Softmax probabilities
'confidence': [batch, 1], # Prediction confidence
# Next candle predictions (ALL timeframes)
'next_candles': {
'1s': [batch, 5], # Next 1s candle OHLCV
'1m': [batch, 5], # Next 1m candle OHLCV
'1h': [batch, 5], # Next 1h candle OHLCV
'1d': [batch, 5] # Next 1d candle OHLCV
},
# BTC prediction
'btc_next_candle': [batch, 5], # Next BTC 1m candle
# Auxiliary predictions
'price_prediction': [batch, 1], # Price target
'volatility_prediction': [batch, 1], # Expected volatility
'trend_strength_prediction': [batch, 1], # Trend strength
# Pivot points (L1-L5)
'next_pivots': {...},
# Trend analysis
'trend_analysis': {...}
}
```
---
## Training Strategy
### Multi-Timeframe Loss
```python
# Action loss (primary)
action_loss = CrossEntropyLoss(action_logits, target_action)
# Next candle losses (auxiliary)
candle_losses = []
for tf in ['1s', '1m', '1h', '1d']:
if f'target_{tf}' in batch:
pred = outputs['next_candles'][tf]
target = batch[f'target_{tf}']
loss = MSELoss(pred, target)
candle_losses.append(loss)
# BTC loss
if 'target_btc' in batch:
btc_loss = MSELoss(outputs['btc_next_candle'], batch['target_btc'])
candle_losses.append(btc_loss)
# Combined loss
total_candle_loss = sum(candle_losses) / len(candle_losses)
total_loss = action_loss + 0.1 * total_candle_loss
```
**Why this works**:
- Action loss: Primary objective (trading decisions)
- Candle losses: Auxiliary tasks (improve representations)
- Multi-task learning: Better feature learning
---
## Usage Examples
### Example 1: All Timeframes Available
```python
outputs = model(
price_data_1s=eth_1s,
price_data_1m=eth_1m,
price_data_1h=eth_1h,
price_data_1d=eth_1d,
btc_data_1m=btc_1m,
position_state=position
)
# Get action
action = torch.argmax(outputs['action_probs'])
# Get next candle predictions for all timeframes
next_1s = outputs['next_candles']['1s']
next_1m = outputs['next_candles']['1m']
next_1h = outputs['next_candles']['1h']
next_1d = outputs['next_candles']['1d']
next_btc = outputs['btc_next_candle']
```
### Example 2: Missing 1s Data (Degraded Mode)
```python
# 1s data not available
outputs = model(
price_data_1m=eth_1m,
price_data_1h=eth_1h,
price_data_1d=eth_1d,
btc_data_1m=btc_1m,
position_state=position
)
# Still works! Model adapts to available timeframes
action = torch.argmax(outputs['action_probs'])
# 1s prediction still available (learned from other TFs)
next_1s = outputs['next_candles']['1s']
```
### Example 3: Legacy Single Timeframe
```python
# Old code still works
outputs = model(
price_data=eth_1m, # Legacy parameter
position_state=position
)
# Automatically uses as 1m data
action = torch.argmax(outputs['action_probs'])
```
---
## Performance Characteristics
### Memory Usage
**Per Sample**:
```
5 timeframes × 600 candles × 5 OHLCV = 15,000 values
15,000 × 4 bytes = 60 KB input
Shared encoder: 656K params
Cross-TF layers: ~8M params
Total: ~9M params for multi-TF processing
Batch of 5: 300 KB input, manageable
```
### Computational Cost
**Forward Pass**:
```
1. Shared encoder: 5 × (600 × 656K) = ~2B ops
2. Cross-TF attention: 2 layers × (3000 × 3000) = ~18M ops
3. Main transformer: 12 layers × (600 × 600) = ~4M ops
Total: ~2B ops (dominated by shared encoder)
```
**Compared to Separate Encoders**:
```
Traditional: 5 encoders × 2B ops = 10B ops
Hybrid: 1 encoder × 2B ops = 2B ops
Speedup: 5x faster!
```
---
## Key Insights
### 1. Pattern Universality
Candle patterns are universal across timeframes:
- Doji on 1s = Doji on 1d (same pattern, different scale)
- Hammer on 1m = Hammer on 1h (same reversal signal)
- Shared encoder exploits this universality
### 2. Scale Invariance
The model learns scale-invariant features:
- Normalized OHLCV removes absolute price scale
- Patterns recognized regardless of timeframe
- Timeframe embeddings add scale context
### 3. Cross-Scale Validation
Multi-timeframe attention enables validation:
- Micro signals (1s) validated by macro trends (1d)
- Reduces false signals
- Increases prediction confidence
### 4. Market Correlation
BTC reference captures market-wide moves:
- BTC leads, altcoins follow
- Market-wide sentiment
- Risk-on/risk-off detection
---
## Comparison to Alternatives
### vs. Separate Models per Timeframe
| Aspect | Separate Models | Hybrid Architecture |
|--------|----------------|---------------------|
| Parameters | 5 × 46M = 230M | 46M |
| Training Time | 5x longer | 1x |
| Pattern Learning | 5x redundant | Shared |
| Cross-TF Dependencies | ❌ None | ✅ Captured |
| Memory Usage | 5x higher | 1x |
| Inference Speed | 5x slower | 1x |
### vs. Single Concatenated Input
| Aspect | Concatenation | Hybrid Architecture |
|--------|--------------|---------------------|
| Pattern Sharing | ❌ No | ✅ Yes |
| Cross-TF Attention | ❌ No | ✅ Yes |
| Missing Data | ❌ Breaks | ✅ Handles |
| Interpretability | Low | High |
| Efficiency | Medium | High |
---
## Summary
The hybrid serial-parallel architecture provides:
**Efficient Pattern Learning**: Shared encoder learns once
**Cross-Timeframe Dependencies**: Parallel attention captures relationships
**Flexible Input**: Handles missing timeframes gracefully
**Multi-Scale Predictions**: Predicts next candle for ALL timeframes
**Market Correlation**: BTC reference for market-wide context
**Backward Compatible**: Legacy code still works
This design maximizes both efficiency and expressiveness! 🚀