T model was only using 1t data so far. fixing WIP
This commit is contained in:
@@ -349,8 +349,17 @@ class AdvancedTradingTransformer(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
# Input projections
|
# Multi-timeframe input projections
|
||||||
self.price_projection = nn.Linear(5, config.d_model) # OHLCV
|
# Each timeframe gets its own projection to learn timeframe-specific patterns
|
||||||
|
self.timeframes = ['1s', '1m', '1h', '1d']
|
||||||
|
self.price_projections = nn.ModuleDict({
|
||||||
|
tf: nn.Linear(5, config.d_model) for tf in self.timeframes # OHLCV per timeframe
|
||||||
|
})
|
||||||
|
|
||||||
|
# Reference symbol projection (BTC 1m)
|
||||||
|
self.btc_projection = nn.Linear(5, config.d_model)
|
||||||
|
|
||||||
|
# Other input projections
|
||||||
self.cob_projection = nn.Linear(config.cob_features, config.d_model)
|
self.cob_projection = nn.Linear(config.cob_features, config.d_model)
|
||||||
self.tech_projection = nn.Linear(config.tech_features, config.d_model)
|
self.tech_projection = nn.Linear(config.tech_features, config.d_model)
|
||||||
self.market_projection = nn.Linear(config.market_features, config.d_model)
|
self.market_projection = nn.Linear(config.market_features, config.d_model)
|
||||||
@@ -367,6 +376,9 @@ class AdvancedTradingTransformer(nn.Module):
|
|||||||
nn.Linear(config.d_model // 2, config.d_model) # 512 -> 1024
|
nn.Linear(config.d_model // 2, config.d_model) # 512 -> 1024
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Timeframe importance weights (learnable)
|
||||||
|
self.timeframe_weights = nn.Parameter(torch.ones(len(self.timeframes) + 1)) # +1 for BTC
|
||||||
|
|
||||||
# Positional encoding
|
# Positional encoding
|
||||||
if config.use_relative_position:
|
if config.use_relative_position:
|
||||||
self.pos_encoding = RelativePositionalEncoding(config.d_model)
|
self.pos_encoding = RelativePositionalEncoding(config.d_model)
|
||||||
@@ -435,7 +447,7 @@ class AdvancedTradingTransformer(nn.Module):
|
|||||||
|
|
||||||
# NEW: Next candle OHLCV prediction heads for each timeframe (1s, 1m, 1h, 1d)
|
# NEW: Next candle OHLCV prediction heads for each timeframe (1s, 1m, 1h, 1d)
|
||||||
# Each timeframe predicts: [open, high, low, close, volume] = 5 values
|
# Each timeframe predicts: [open, high, low, close, volume] = 5 values
|
||||||
self.timeframes = ['1s', '1m', '1h', '1d']
|
# Note: self.timeframes already defined above in input projections
|
||||||
self.next_candle_heads = nn.ModuleDict({
|
self.next_candle_heads = nn.ModuleDict({
|
||||||
tf: nn.Sequential(
|
tf: nn.Sequential(
|
||||||
nn.Linear(config.d_model, config.d_model // 2),
|
nn.Linear(config.d_model, config.d_model // 2),
|
||||||
@@ -448,6 +460,17 @@ class AdvancedTradingTransformer(nn.Module):
|
|||||||
) for tf in self.timeframes
|
) for tf in self.timeframes
|
||||||
})
|
})
|
||||||
|
|
||||||
|
# BTC next candle prediction head
|
||||||
|
self.btc_next_candle_head = nn.Sequential(
|
||||||
|
nn.Linear(config.d_model, config.d_model // 2),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(config.dropout),
|
||||||
|
nn.Linear(config.d_model // 2, config.d_model // 4),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(config.dropout),
|
||||||
|
nn.Linear(config.d_model // 4, 5) # OHLCV for BTC
|
||||||
|
)
|
||||||
|
|
||||||
# NEW: Next pivot point prediction heads for L1-L5 levels
|
# NEW: Next pivot point prediction heads for L1-L5 levels
|
||||||
# Each level predicts: [price, type_prob_high, type_prob_low, confidence]
|
# Each level predicts: [price, type_prob_high, type_prob_low, confidence]
|
||||||
# type_prob_high + type_prob_low = 1 (softmax), but we output separately for clarity
|
# type_prob_high + type_prob_low = 1 (softmax), but we output separately for clarity
|
||||||
|
|||||||
374
_dev/multi_timeframe_transformer_plan.md
Normal file
374
_dev/multi_timeframe_transformer_plan.md
Normal file
@@ -0,0 +1,374 @@
|
|||||||
|
# Multi-Timeframe Transformer - Implementation Plan
|
||||||
|
|
||||||
|
## Current Problem
|
||||||
|
|
||||||
|
The transformer currently:
|
||||||
|
- ❌ Only uses **1m timeframe** (ignores 1s, 1h, 1d)
|
||||||
|
- ❌ Doesn't use **BTC reference data**
|
||||||
|
- ❌ Has `price_projection` (single) but we changed it to `price_projections` (dict)
|
||||||
|
- ❌ Predicts next candles for all timeframes but only trains on 1m
|
||||||
|
|
||||||
|
## Goal
|
||||||
|
|
||||||
|
Make the transformer:
|
||||||
|
- ✅ Accept **ALL timeframes** (1s, 1m, 1h, 1d) as separate inputs
|
||||||
|
- ✅ Accept **BTC 1m** as reference symbol
|
||||||
|
- ✅ Learn **timeframe-specific patterns** with separate projections
|
||||||
|
- ✅ Predict **next candle for EACH timeframe**
|
||||||
|
- ✅ Handle **missing timeframes** gracefully (1s may not be available)
|
||||||
|
- ✅ Maintain **backward compatibility** with existing code
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Architecture Changes
|
||||||
|
|
||||||
|
### 1. Input Structure (NEW)
|
||||||
|
|
||||||
|
```python
|
||||||
|
forward(
|
||||||
|
price_data_1s=None, # [batch, 600, 5] - 1-second candles (optional)
|
||||||
|
price_data_1m=None, # [batch, 600, 5] - 1-minute candles (optional)
|
||||||
|
price_data_1h=None, # [batch, 600, 5] - 1-hour candles (optional)
|
||||||
|
price_data_1d=None, # [batch, 600, 5] - 1-day candles (optional)
|
||||||
|
btc_data_1m=None, # [batch, 600, 5] - BTC 1-minute (optional)
|
||||||
|
cob_data=None, # [batch, 600, 100] - COB features
|
||||||
|
tech_data=None, # [batch, 40] - Technical indicators
|
||||||
|
market_data=None, # [batch, 30] - Market features
|
||||||
|
position_state=None, # [batch, 5] - Position state
|
||||||
|
# Legacy support
|
||||||
|
price_data=None # [batch, 600, 5] - Falls back to 1m if provided
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Input Projections (UPDATED)
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Separate projection for each timeframe
|
||||||
|
self.price_projections = nn.ModuleDict({
|
||||||
|
'1s': nn.Linear(5, 1024), # 1-second patterns
|
||||||
|
'1m': nn.Linear(5, 1024), # 1-minute patterns
|
||||||
|
'1h': nn.Linear(5, 1024), # 1-hour patterns
|
||||||
|
'1d': nn.Linear(5, 1024) # 1-day patterns
|
||||||
|
})
|
||||||
|
|
||||||
|
# BTC reference projection
|
||||||
|
self.btc_projection = nn.Linear(5, 1024)
|
||||||
|
|
||||||
|
# Learnable timeframe importance weights
|
||||||
|
self.timeframe_weights = nn.Parameter(torch.ones(5)) # 4 timeframes + BTC
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Embedding Combination Strategy
|
||||||
|
|
||||||
|
**Option A: Weighted Average** (Current Plan)
|
||||||
|
```python
|
||||||
|
# Project each timeframe
|
||||||
|
emb_1s = self.price_projections['1s'](price_data_1s) * weight_1s
|
||||||
|
emb_1m = self.price_projections['1m'](price_data_1m) * weight_1m
|
||||||
|
emb_1h = self.price_projections['1h'](price_data_1h) * weight_1h
|
||||||
|
emb_1d = self.price_projections['1d'](price_data_1d) * weight_1d
|
||||||
|
emb_btc = self.btc_projection(btc_data_1m) * weight_btc
|
||||||
|
|
||||||
|
# Combine
|
||||||
|
price_emb = (emb_1s + emb_1m + emb_1h + emb_1d + emb_btc) / num_available
|
||||||
|
```
|
||||||
|
|
||||||
|
**Option B: Concatenation + Projection** (Alternative)
|
||||||
|
```python
|
||||||
|
# Concatenate all timeframes
|
||||||
|
combined = torch.cat([emb_1s, emb_1m, emb_1h, emb_1d, emb_btc], dim=-1)
|
||||||
|
# [batch, seq_len, 5*1024] = [batch, seq_len, 5120]
|
||||||
|
|
||||||
|
# Project back to d_model
|
||||||
|
price_emb = nn.Linear(5120, 1024)(combined)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Decision**: Use **Option A** (Weighted Average)
|
||||||
|
- ✅ Simpler
|
||||||
|
- ✅ Handles missing timeframes naturally
|
||||||
|
- ✅ Fewer parameters
|
||||||
|
- ✅ Learned weights show timeframe importance
|
||||||
|
|
||||||
|
### 4. Output Structure (ALREADY EXISTS)
|
||||||
|
|
||||||
|
```python
|
||||||
|
outputs = {
|
||||||
|
'action_logits': [batch, 3],
|
||||||
|
'action_probs': [batch, 3],
|
||||||
|
'confidence': [batch, 1],
|
||||||
|
'price_prediction': [batch, 1],
|
||||||
|
|
||||||
|
# Next candle predictions for EACH timeframe
|
||||||
|
'next_candles': {
|
||||||
|
'1s': [batch, 5], # [open, high, low, close, volume]
|
||||||
|
'1m': [batch, 5],
|
||||||
|
'1h': [batch, 5],
|
||||||
|
'1d': [batch, 5]
|
||||||
|
},
|
||||||
|
|
||||||
|
# BTC next candle
|
||||||
|
'btc_next_candle': [batch, 5],
|
||||||
|
|
||||||
|
# Pivot predictions
|
||||||
|
'next_pivots': {...},
|
||||||
|
|
||||||
|
# Trend analysis
|
||||||
|
'trend_analysis': {...}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Training Data Preparation
|
||||||
|
|
||||||
|
### Current (WRONG)
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Only uses 1m timeframe
|
||||||
|
primary_data = timeframes['1m']
|
||||||
|
price_data = stack_ohlcv(primary_data) # [batch, 150, 5]
|
||||||
|
|
||||||
|
batch = {
|
||||||
|
'price_data': price_data, # Only 1m!
|
||||||
|
'cob_data': cob_data,
|
||||||
|
'tech_data': tech_data,
|
||||||
|
'market_data': market_data
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### New (CORRECT)
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Use ALL available timeframes
|
||||||
|
batch = {
|
||||||
|
'price_data_1s': stack_ohlcv(timeframes.get('1s')), # [batch, 600, 5] or None
|
||||||
|
'price_data_1m': stack_ohlcv(timeframes['1m']), # [batch, 600, 5]
|
||||||
|
'price_data_1h': stack_ohlcv(timeframes.get('1h')), # [batch, 600, 5] or None
|
||||||
|
'price_data_1d': stack_ohlcv(timeframes.get('1d')), # [batch, 600, 5] or None
|
||||||
|
'btc_data_1m': stack_ohlcv(secondary_timeframes.get('BTC/USDT', {}).get('1m')),
|
||||||
|
'cob_data': cob_data,
|
||||||
|
'tech_data': tech_data,
|
||||||
|
'market_data': market_data,
|
||||||
|
'position_state': position_state
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Sequence Length Strategy
|
||||||
|
|
||||||
|
### Problem
|
||||||
|
Different timeframes have different natural sequence lengths:
|
||||||
|
- 1s: 600 candles = 10 minutes
|
||||||
|
- 1m: 600 candles = 10 hours
|
||||||
|
- 1h: 600 candles = 25 days
|
||||||
|
- 1d: 600 candles = ~2 years
|
||||||
|
|
||||||
|
### Solution: Fixed Sequence Length with Padding
|
||||||
|
|
||||||
|
**Use 600 candles for ALL timeframes**:
|
||||||
|
```python
|
||||||
|
seq_len = 600 # Fixed for all timeframes
|
||||||
|
|
||||||
|
# Each timeframe provides 600 candles
|
||||||
|
# - 1s: Last 600 seconds (10 min)
|
||||||
|
# - 1m: Last 600 minutes (10 hours)
|
||||||
|
# - 1h: Last 600 hours (25 days)
|
||||||
|
# - 1d: Last 600 days (~2 years)
|
||||||
|
|
||||||
|
# If less than 600 available, pad with last candle
|
||||||
|
if len(candles) < 600:
|
||||||
|
padding = repeat_last_candle(600 - len(candles))
|
||||||
|
candles = concat([candles, padding])
|
||||||
|
```
|
||||||
|
|
||||||
|
**Why 600?**
|
||||||
|
- ✅ Captures sufficient history for all timeframes
|
||||||
|
- ✅ Not too large (memory efficient)
|
||||||
|
- ✅ Divisible by many numbers (batch processing)
|
||||||
|
- ✅ Transformer can handle it (we have 46M params)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Normalization Strategy
|
||||||
|
|
||||||
|
### Per-Timeframe Normalization
|
||||||
|
|
||||||
|
**Problem**: Different timeframes have different price ranges
|
||||||
|
- 1s: Tight range (e.g., $2000-$2001)
|
||||||
|
- 1d: Wide range (e.g., $1800-$2200)
|
||||||
|
|
||||||
|
**Solution**: Normalize each timeframe independently
|
||||||
|
|
||||||
|
```python
|
||||||
|
def normalize_timeframe(ohlcv, timeframe):
|
||||||
|
"""Normalize OHLCV data for a specific timeframe"""
|
||||||
|
# Use min/max from this timeframe's data
|
||||||
|
price_min = min(ohlcv[:, [0,1,2,3]].min()) # Min of OHLC
|
||||||
|
price_max = max(ohlcv[:, [0,1,2,3]].max()) # Max of OHLC
|
||||||
|
volume_min = ohlcv[:, 4].min()
|
||||||
|
volume_max = ohlcv[:, 4].max()
|
||||||
|
|
||||||
|
# Normalize prices to [0, 1]
|
||||||
|
ohlcv_norm = ohlcv.copy()
|
||||||
|
ohlcv_norm[:, :4] = (ohlcv[:, :4] - price_min) / (price_max - price_min)
|
||||||
|
ohlcv_norm[:, 4] = (ohlcv[:, 4] - volume_min) / (volume_max - volume_min)
|
||||||
|
|
||||||
|
# Store bounds for denormalization
|
||||||
|
bounds = {
|
||||||
|
'price_min': price_min,
|
||||||
|
'price_max': price_max,
|
||||||
|
'volume_min': volume_min,
|
||||||
|
'volume_max': volume_max
|
||||||
|
}
|
||||||
|
|
||||||
|
return ohlcv_norm, bounds
|
||||||
|
```
|
||||||
|
|
||||||
|
### Denormalization for Predictions
|
||||||
|
|
||||||
|
```python
|
||||||
|
def denormalize_prediction(pred_norm, bounds):
|
||||||
|
"""Convert normalized prediction back to actual price"""
|
||||||
|
pred = pred_norm * (bounds['price_max'] - bounds['price_min']) + bounds['price_min']
|
||||||
|
return pred
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Loss Function Updates
|
||||||
|
|
||||||
|
### Current (Single Timeframe)
|
||||||
|
|
||||||
|
```python
|
||||||
|
loss = action_loss + 0.1 * price_loss
|
||||||
|
```
|
||||||
|
|
||||||
|
### New (Multi-Timeframe)
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Action loss (same)
|
||||||
|
action_loss = CrossEntropyLoss(action_logits, actions)
|
||||||
|
|
||||||
|
# Price loss for each timeframe
|
||||||
|
price_losses = []
|
||||||
|
for tf in ['1s', '1m', '1h', '1d']:
|
||||||
|
if f'target_{tf}' in batch:
|
||||||
|
pred = outputs['next_candles'][tf]
|
||||||
|
target = batch[f'target_{tf}']
|
||||||
|
loss_tf = MSELoss(pred, target)
|
||||||
|
price_losses.append(loss_tf)
|
||||||
|
|
||||||
|
# BTC loss
|
||||||
|
if 'target_btc' in batch:
|
||||||
|
btc_loss = MSELoss(outputs['btc_next_candle'], batch['target_btc'])
|
||||||
|
price_losses.append(btc_loss)
|
||||||
|
|
||||||
|
# Combined loss
|
||||||
|
total_price_loss = sum(price_losses) / len(price_losses) if price_losses else 0
|
||||||
|
total_loss = action_loss + 0.1 * total_price_loss
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Backward Compatibility
|
||||||
|
|
||||||
|
### Legacy Code Support
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Old code (still works)
|
||||||
|
outputs = model(
|
||||||
|
price_data=price_1m, # Single timeframe
|
||||||
|
cob_data=cob,
|
||||||
|
tech_data=tech,
|
||||||
|
market_data=market
|
||||||
|
)
|
||||||
|
|
||||||
|
# New code (multi-timeframe)
|
||||||
|
outputs = model(
|
||||||
|
price_data_1m=price_1m,
|
||||||
|
price_data_1h=price_1h,
|
||||||
|
price_data_1d=price_1d,
|
||||||
|
btc_data_1m=btc_1m,
|
||||||
|
cob_data=cob,
|
||||||
|
tech_data=tech,
|
||||||
|
market_data=market,
|
||||||
|
position_state=position
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Implementation**:
|
||||||
|
```python
|
||||||
|
def forward(self, price_data=None, price_data_1m=None, ...):
|
||||||
|
# Legacy support
|
||||||
|
if price_data is not None and price_data_1m is None:
|
||||||
|
price_data_1m = price_data # Use as 1m data
|
||||||
|
|
||||||
|
# Continue with new logic
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Implementation Steps
|
||||||
|
|
||||||
|
### Step 1: Update Model Architecture ✅
|
||||||
|
- [x] Add `price_projections` ModuleDict
|
||||||
|
- [x] Add `btc_projection`
|
||||||
|
- [x] Add `timeframe_weights` parameter
|
||||||
|
- [x] Add `btc_next_candle_head`
|
||||||
|
|
||||||
|
### Step 2: Update Forward Method (TODO)
|
||||||
|
- [ ] Accept multi-timeframe inputs
|
||||||
|
- [ ] Handle missing timeframes
|
||||||
|
- [ ] Combine embeddings with learned weights
|
||||||
|
- [ ] Add BTC embedding
|
||||||
|
- [ ] Maintain backward compatibility
|
||||||
|
|
||||||
|
### Step 3: Update Training Adapter (TODO)
|
||||||
|
- [ ] Extract ALL timeframes from market_state
|
||||||
|
- [ ] Normalize each timeframe independently
|
||||||
|
- [ ] Create multi-timeframe batch dictionary
|
||||||
|
- [ ] Add target next candles for each timeframe
|
||||||
|
|
||||||
|
### Step 4: Update Loss Calculation (TODO)
|
||||||
|
- [ ] Add per-timeframe price losses
|
||||||
|
- [ ] Add BTC prediction loss
|
||||||
|
- [ ] Weight losses appropriately
|
||||||
|
|
||||||
|
### Step 5: Testing (TODO)
|
||||||
|
- [ ] Unit test: multi-timeframe forward pass
|
||||||
|
- [ ] Unit test: missing timeframes handling
|
||||||
|
- [ ] Integration test: full training loop
|
||||||
|
- [ ] Validation: check predictions for all timeframes
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Expected Benefits
|
||||||
|
|
||||||
|
### 1. Better Predictions
|
||||||
|
- **Multi-scale patterns**: 1s for micro, 1d for macro trends
|
||||||
|
- **Cross-timeframe validation**: Confirm signals across timeframes
|
||||||
|
- **Market correlation**: BTC reference for market-wide moves
|
||||||
|
|
||||||
|
### 2. Flexible Usage
|
||||||
|
- **Any timeframe**: Choose output timeframe in post-processing
|
||||||
|
- **Scalping**: Use 1s predictions
|
||||||
|
- **Swing trading**: Use 1h/1d predictions
|
||||||
|
- **Same model**: No retraining needed
|
||||||
|
|
||||||
|
### 3. Robustness
|
||||||
|
- **Missing data**: Handles missing 1s gracefully
|
||||||
|
- **Degraded mode**: Works with just 1m if needed
|
||||||
|
- **Backward compatible**: Old code still works
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Next Actions
|
||||||
|
|
||||||
|
1. **Finish forward() refactor** - Make it accept multi-timeframe inputs
|
||||||
|
2. **Update training adapter** - Pass all timeframes to model
|
||||||
|
3. **Test thoroughly** - Ensure no regressions
|
||||||
|
4. **Document usage** - Show how to use multi-timeframe predictions
|
||||||
|
|
||||||
|
This is a significant improvement that will make the model much more powerful! 🚀
|
||||||
Reference in New Issue
Block a user