# Multi-Timeframe Transformer - Implementation Plan

## Current Problem

The transformer currently:

- ❌ Only uses the **1m timeframe** (ignores 1s, 1h, 1d)
- ❌ Doesn't use **BTC reference data**
- ❌ Has `price_projection` (single) but we changed it to `price_projections` (dict)
- ❌ Predicts next candles for all timeframes but only trains on 1m

## Goal

Make the transformer:

- ✅ Accept **ALL timeframes** (1s, 1m, 1h, 1d) as separate inputs
- ✅ Accept **BTC 1m** as a reference symbol
- ✅ Learn **timeframe-specific patterns** with separate projections
- ✅ Predict the **next candle for EACH timeframe**
- ✅ Handle **missing timeframes** gracefully (1s may not be available)
- ✅ Maintain **backward compatibility** with existing code

---
## Architecture Changes

### 1. Input Structure (NEW)

```python
forward(
    price_data_1s=None,   # [batch, 600, 5] - 1-second candles (optional)
    price_data_1m=None,   # [batch, 600, 5] - 1-minute candles (optional)
    price_data_1h=None,   # [batch, 600, 5] - 1-hour candles (optional)
    price_data_1d=None,   # [batch, 600, 5] - 1-day candles (optional)
    btc_data_1m=None,     # [batch, 600, 5] - BTC 1-minute (optional)
    cob_data=None,        # [batch, 600, 100] - COB features
    tech_data=None,       # [batch, 40] - Technical indicators
    market_data=None,     # [batch, 30] - Market features
    position_state=None,  # [batch, 5] - Position state
    # Legacy support
    price_data=None       # [batch, 600, 5] - Falls back to 1m if provided
)
```
### 2. Input Projections (UPDATED)

```python
# Separate projection for each timeframe
self.price_projections = nn.ModuleDict({
    '1s': nn.Linear(5, 1024),  # 1-second patterns
    '1m': nn.Linear(5, 1024),  # 1-minute patterns
    '1h': nn.Linear(5, 1024),  # 1-hour patterns
    '1d': nn.Linear(5, 1024)   # 1-day patterns
})

# BTC reference projection
self.btc_projection = nn.Linear(5, 1024)

# Learnable timeframe importance weights (4 timeframes + BTC)
self.timeframe_weights = nn.Parameter(torch.ones(5))
```
### 3. Embedding Combination Strategy

**Option A: Weighted Average** (Current Plan)

```python
# Project each timeframe, scaled by its learned weight
emb_1s = self.price_projections['1s'](price_data_1s) * weight_1s
emb_1m = self.price_projections['1m'](price_data_1m) * weight_1m
emb_1h = self.price_projections['1h'](price_data_1h) * weight_1h
emb_1d = self.price_projections['1d'](price_data_1d) * weight_1d
emb_btc = self.btc_projection(btc_data_1m) * weight_btc

# Combine, averaging over however many inputs were actually provided
price_emb = (emb_1s + emb_1m + emb_1h + emb_1d + emb_btc) / num_available
```

**Option B: Concatenation + Projection** (Alternative)

```python
# Concatenate all timeframes
combined = torch.cat([emb_1s, emb_1m, emb_1h, emb_1d, emb_btc], dim=-1)
# [batch, seq_len, 5*1024] = [batch, seq_len, 5120]

# Project back to d_model
# (the nn.Linear(5120, 1024) must be created once in __init__,
# e.g. as self.combine_projection, not instantiated inside forward)
price_emb = self.combine_projection(combined)
```

**Decision**: Use **Option A** (Weighted Average)

- ✅ Simpler
- ✅ Handles missing timeframes naturally
- ✅ Fewer parameters
- ✅ Learned weights show timeframe importance
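The missing-timeframe handling in Option A can be illustrated outside the model. The sketch below uses NumPy for clarity; the function and variable names are hypothetical, and normalizing the learned weights of the *available* inputs with a softmax is one possible way to realize the `num_available` averaging (the real forward pass would operate on torch tensors).

```python
import numpy as np

def combine_timeframe_embeddings(embeddings, weights):
    """Weighted average over whichever timeframe embeddings are present.

    embeddings: dict of timeframe name -> [seq_len, d_model] array, or None if missing
    weights:    dict of timeframe name -> scalar learned weight
    """
    available = [tf for tf, emb in embeddings.items() if emb is not None]
    if not available:
        raise ValueError("At least one timeframe must be provided")

    # Softmax over the weights of the available inputs so they sum to 1,
    # keeping the embedding scale stable when some timeframes are missing
    w = np.array([weights[tf] for tf in available])
    w = np.exp(w - w.max())
    w = w / w.sum()

    return sum(wi * embeddings[tf] for wi, tf in zip(w, available))

# Three of five inputs missing: only the available ones are averaged
embs = {'1s': None, '1m': np.ones((4, 8)), '1h': np.full((4, 8), 3.0),
        '1d': None, 'btc': None}
weights = {'1s': 1.0, '1m': 1.0, '1h': 1.0, '1d': 1.0, 'btc': 1.0}
combined = combine_timeframe_embeddings(embs, weights)  # equal weights -> mean of 1 and 3
```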
### 4. Output Structure (ALREADY EXISTS)

```python
outputs = {
    'action_logits': [batch, 3],
    'action_probs': [batch, 3],
    'confidence': [batch, 1],
    'price_prediction': [batch, 1],

    # Next candle predictions for EACH timeframe
    'next_candles': {
        '1s': [batch, 5],  # [open, high, low, close, volume]
        '1m': [batch, 5],
        '1h': [batch, 5],
        '1d': [batch, 5]
    },

    # BTC next candle
    'btc_next_candle': [batch, 5],

    # Pivot predictions
    'next_pivots': {...},

    # Trend analysis
    'trend_analysis': {...}
}
```

---
## Training Data Preparation

### Current (WRONG)

```python
# Only uses the 1m timeframe
primary_data = timeframes['1m']
price_data = stack_ohlcv(primary_data)  # [batch, 150, 5]

batch = {
    'price_data': price_data,  # Only 1m!
    'cob_data': cob_data,
    'tech_data': tech_data,
    'market_data': market_data
}
```

### New (CORRECT)

```python
# Use ALL available timeframes
batch = {
    'price_data_1s': stack_ohlcv(timeframes.get('1s')),  # [batch, 600, 5] or None
    'price_data_1m': stack_ohlcv(timeframes['1m']),      # [batch, 600, 5]
    'price_data_1h': stack_ohlcv(timeframes.get('1h')),  # [batch, 600, 5] or None
    'price_data_1d': stack_ohlcv(timeframes.get('1d')),  # [batch, 600, 5] or None
    'btc_data_1m': stack_ohlcv(secondary_timeframes.get('BTC/USDT', {}).get('1m')),
    'cob_data': cob_data,
    'tech_data': tech_data,
    'market_data': market_data,
    'position_state': position_state
}
```
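Since the `.get()` calls above yield `None` for missing timeframes, one way to keep the model call clean is to drop `None` entries before forwarding the batch. A minimal sketch (the helper name is hypothetical, and the toy values stand in for real arrays):

```python
def build_model_kwargs(batch):
    """Drop missing (None) timeframes so the model only sees available inputs."""
    return {key: value for key, value in batch.items() if value is not None}

batch = {
    'price_data_1s': None,              # 1s feed unavailable
    'price_data_1m': [[1, 2, 3, 4, 5]],
    'price_data_1h': [[2, 3, 4, 5, 6]],
    'price_data_1d': None,
    'tech_data': [[0.1, 0.2]],
}
kwargs = build_model_kwargs(batch)
# outputs = model(**kwargs)  # only the available timeframes are passed
```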
---
## Sequence Length Strategy

### Problem

Different timeframes have different natural sequence lengths:

- 1s: 600 candles = 10 minutes
- 1m: 600 candles = 10 hours
- 1h: 600 candles = 25 days
- 1d: 600 candles ≈ 1.6 years

### Solution: Fixed Sequence Length with Padding

**Use 600 candles for ALL timeframes**:

```python
seq_len = 600  # Fixed for all timeframes

# Each timeframe provides 600 candles:
# - 1s: last 600 seconds (10 min)
# - 1m: last 600 minutes (10 hours)
# - 1h: last 600 hours (25 days)
# - 1d: last 600 days (~1.6 years)

# If fewer than 600 are available, pad with the last candle
if len(candles) < 600:
    padding = repeat_last_candle(600 - len(candles))
    candles = concat([candles, padding])
```
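The padding pseudocode above can be made concrete with NumPy. This is a sketch under the plan's assumptions (repeat-last-candle padding, most-recent-window truncation); the real data loader may differ:

```python
import numpy as np

def pad_to_length(candles, seq_len=600):
    """Pad an [n, 5] OHLCV array to [seq_len, 5] by repeating the last candle."""
    n = len(candles)
    if n >= seq_len:
        return candles[-seq_len:]  # keep only the most recent seq_len candles
    padding = np.repeat(candles[-1:], seq_len - n, axis=0)  # copy the last row
    return np.concatenate([candles, padding], axis=0)

candles = np.arange(10 * 5, dtype=float).reshape(10, 5)  # only 10 candles available
padded = pad_to_length(candles, seq_len=600)
```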
**Why 600?**

- ✅ Captures sufficient history for all timeframes
- ✅ Not too large (memory efficient)
- ✅ Divisible by many numbers (batch processing)
- ✅ Transformer can handle it (we have 46M params)

---
## Normalization Strategy

### Per-Timeframe Normalization

**Problem**: Different timeframes have different price ranges

- 1s: tight range (e.g., $2000-$2001)
- 1d: wide range (e.g., $1800-$2200)

**Solution**: Normalize each timeframe independently

```python
def normalize_timeframe(ohlcv, timeframe):
    """Normalize OHLCV data for a specific timeframe"""
    # Use min/max from this timeframe's own data
    price_min = ohlcv[:, :4].min()  # Min over the OHLC columns
    price_max = ohlcv[:, :4].max()  # Max over the OHLC columns
    volume_min = ohlcv[:, 4].min()
    volume_max = ohlcv[:, 4].max()

    # Normalize prices and volume to [0, 1],
    # guarding against a zero range (e.g., a flat series)
    price_range = max(price_max - price_min, 1e-8)
    volume_range = max(volume_max - volume_min, 1e-8)
    ohlcv_norm = ohlcv.copy()
    ohlcv_norm[:, :4] = (ohlcv[:, :4] - price_min) / price_range
    ohlcv_norm[:, 4] = (ohlcv[:, 4] - volume_min) / volume_range

    # Store bounds for denormalization
    bounds = {
        'price_min': price_min,
        'price_max': price_max,
        'volume_min': volume_min,
        'volume_max': volume_max
    }

    return ohlcv_norm, bounds
```
### Denormalization for Predictions

```python
def denormalize_prediction(pred_norm, bounds):
    """Convert a normalized prediction back to actual prices"""
    pred = pred_norm * (bounds['price_max'] - bounds['price_min']) + bounds['price_min']
    return pred
```
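A quick round-trip check confirms that the stored bounds recover the original prices. This self-contained sketch mirrors the normalize/denormalize logic above on a tiny made-up OHLCV array:

```python
import numpy as np

ohlcv = np.array([
    [2000.0, 2001.0, 1999.0, 2000.5, 10.0],
    [2000.5, 2002.0, 2000.0, 2001.5, 12.0],
])

# Normalize the close column using the timeframe's own OHLC min/max
price_min = ohlcv[:, :4].min()
price_max = ohlcv[:, :4].max()
close_norm = (ohlcv[:, 3] - price_min) / (price_max - price_min)

# Denormalize with the stored bounds and verify the round trip
bounds = {'price_min': price_min, 'price_max': price_max}
close_restored = close_norm * (bounds['price_max'] - bounds['price_min']) + bounds['price_min']
```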
---
## Loss Function Updates

### Current (Single Timeframe)

```python
loss = action_loss + 0.1 * price_loss
```

### New (Multi-Timeframe)

```python
# Action loss (same as before)
action_loss = F.cross_entropy(action_logits, actions)

# Price loss for each timeframe that has a target
price_losses = []
for tf in ['1s', '1m', '1h', '1d']:
    if f'target_{tf}' in batch:
        pred = outputs['next_candles'][tf]
        target = batch[f'target_{tf}']
        price_losses.append(F.mse_loss(pred, target))

# BTC loss
if 'target_btc' in batch:
    price_losses.append(F.mse_loss(outputs['btc_next_candle'], batch['target_btc']))

# Combined loss (average over the available price losses)
total_price_loss = sum(price_losses) / len(price_losses) if price_losses else 0.0
total_loss = action_loss + 0.1 * total_price_loss
```
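The aggregation logic is easy to sanity-check with plain numbers. In this toy sketch the scalar `mse` helper stands in for the framework's loss call, and the values are made up for illustration:

```python
def mse(pred, target):
    """Toy stand-in for a framework MSE over equal-length lists."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

outputs = {'next_candles': {'1m': [1.0, 2.0], '1h': [3.0, 4.0]}}
batch = {'target_1m': [1.0, 2.0],   # perfect prediction -> loss 0
         'target_1h': [3.0, 6.0]}   # off by 2 on one element -> loss 2

price_losses = []
for tf in ['1s', '1m', '1h', '1d']:
    if f'target_{tf}' in batch:     # '1s' and '1d' have no targets, so they are skipped
        price_losses.append(mse(outputs['next_candles'][tf], batch[f'target_{tf}']))

total_price_loss = sum(price_losses) / len(price_losses) if price_losses else 0.0
action_loss = 0.5                   # dummy value
total_loss = action_loss + 0.1 * total_price_loss
```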
---
## Backward Compatibility

### Legacy Code Support

```python
# Old code (still works)
outputs = model(
    price_data=price_1m,  # Single timeframe
    cob_data=cob,
    tech_data=tech,
    market_data=market
)

# New code (multi-timeframe)
outputs = model(
    price_data_1m=price_1m,
    price_data_1h=price_1h,
    price_data_1d=price_1d,
    btc_data_1m=btc_1m,
    cob_data=cob,
    tech_data=tech,
    market_data=market,
    position_state=position
)
```

**Implementation**:

```python
def forward(self, price_data=None, price_data_1m=None, ...):
    # Legacy support: treat the old single-timeframe input as 1m data
    if price_data is not None and price_data_1m is None:
        price_data_1m = price_data

    # Continue with the new multi-timeframe logic
    ...
```
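The fallback can be exercised without any model weights. A minimal sketch with a hypothetical stub class that reproduces only the legacy-argument remapping:

```python
class MultiTimeframeStub:
    """Stub showing only the legacy-argument remapping from forward()."""

    def forward(self, price_data=None, price_data_1m=None, **kwargs):
        # Legacy support: a bare price_data is treated as 1m data
        if price_data is not None and price_data_1m is None:
            price_data_1m = price_data
        return {'used_1m': price_data_1m}

model = MultiTimeframeStub()
legacy = model.forward(price_data=[1, 2, 3])     # old call style
modern = model.forward(price_data_1m=[1, 2, 3])  # new call style
```

Both call styles should produce identical behavior, which is exactly what a backward-compatibility unit test would assert.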
---
## Implementation Steps

### Step 1: Update Model Architecture ✅

- [x] Add `price_projections` ModuleDict
- [x] Add `btc_projection`
- [x] Add `timeframe_weights` parameter
- [x] Add `btc_next_candle_head`

### Step 2: Update Forward Method (TODO)

- [ ] Accept multi-timeframe inputs
- [ ] Handle missing timeframes
- [ ] Combine embeddings with learned weights
- [ ] Add BTC embedding
- [ ] Maintain backward compatibility

### Step 3: Update Training Adapter (TODO)

- [ ] Extract ALL timeframes from market_state
- [ ] Normalize each timeframe independently
- [ ] Create multi-timeframe batch dictionary
- [ ] Add target next candles for each timeframe

### Step 4: Update Loss Calculation (TODO)

- [ ] Add per-timeframe price losses
- [ ] Add BTC prediction loss
- [ ] Weight losses appropriately

### Step 5: Testing (TODO)

- [ ] Unit test: multi-timeframe forward pass
- [ ] Unit test: missing-timeframe handling
- [ ] Integration test: full training loop
- [ ] Validation: check predictions for all timeframes

---
## Expected Benefits

### 1. Better Predictions

- **Multi-scale patterns**: 1s for micro structure, 1d for macro trends
- **Cross-timeframe validation**: confirm signals across timeframes
- **Market correlation**: BTC reference captures market-wide moves

### 2. Flexible Usage

- **Any timeframe**: choose the output timeframe in post-processing
- **Scalping**: use 1s predictions
- **Swing trading**: use 1h/1d predictions
- **Same model**: no retraining needed

### 3. Robustness

- **Missing data**: handles a missing 1s feed gracefully
- **Degraded mode**: works with just 1m if needed
- **Backward compatible**: old code still works

---
## Next Actions

1. **Finish the forward() refactor** - make it accept multi-timeframe inputs
2. **Update the training adapter** - pass all timeframes to the model
3. **Test thoroughly** - ensure no regressions
4. **Document usage** - show how to use multi-timeframe predictions

This is a significant improvement that will make the model much more powerful! 🚀