T model was only using 1t data so far. fixing WIP

This commit is contained in:
Dobromir Popov
2025-11-06 00:03:19 +02:00
parent 907a7d6224
commit 07d97100c0
2 changed files with 400 additions and 3 deletions

View File

@@ -349,8 +349,17 @@ class AdvancedTradingTransformer(nn.Module):
super().__init__()
self.config = config
# Input projections
self.price_projection = nn.Linear(5, config.d_model) # OHLCV
# Multi-timeframe input projections
# Each timeframe gets its own projection to learn timeframe-specific patterns
self.timeframes = ['1s', '1m', '1h', '1d']
self.price_projections = nn.ModuleDict({
tf: nn.Linear(5, config.d_model) for tf in self.timeframes # OHLCV per timeframe
})
# Reference symbol projection (BTC 1m)
self.btc_projection = nn.Linear(5, config.d_model)
# Other input projections
self.cob_projection = nn.Linear(config.cob_features, config.d_model)
self.tech_projection = nn.Linear(config.tech_features, config.d_model)
self.market_projection = nn.Linear(config.market_features, config.d_model)
@@ -367,6 +376,9 @@ class AdvancedTradingTransformer(nn.Module):
nn.Linear(config.d_model // 2, config.d_model) # 512 -> 1024
)
# Timeframe importance weights (learnable)
self.timeframe_weights = nn.Parameter(torch.ones(len(self.timeframes) + 1)) # +1 for BTC
# Positional encoding
if config.use_relative_position:
self.pos_encoding = RelativePositionalEncoding(config.d_model)
@@ -435,7 +447,7 @@ class AdvancedTradingTransformer(nn.Module):
# NEW: Next candle OHLCV prediction heads for each timeframe (1s, 1m, 1h, 1d)
# Each timeframe predicts: [open, high, low, close, volume] = 5 values
self.timeframes = ['1s', '1m', '1h', '1d']
# Note: self.timeframes already defined above in input projections
self.next_candle_heads = nn.ModuleDict({
tf: nn.Sequential(
nn.Linear(config.d_model, config.d_model // 2),
@@ -448,6 +460,17 @@ class AdvancedTradingTransformer(nn.Module):
) for tf in self.timeframes
})
# BTC next candle prediction head
self.btc_next_candle_head = nn.Sequential(
nn.Linear(config.d_model, config.d_model // 2),
nn.GELU(),
nn.Dropout(config.dropout),
nn.Linear(config.d_model // 2, config.d_model // 4),
nn.GELU(),
nn.Dropout(config.dropout),
nn.Linear(config.d_model // 4, 5) # OHLCV for BTC
)
# NEW: Next pivot point prediction heads for L1-L5 levels
# Each level predicts: [price, type_prob_high, type_prob_low, confidence]
# type_prob_high + type_prob_low = 1 (softmax), but we output separately for clarity

View File

@@ -0,0 +1,374 @@
# Multi-Timeframe Transformer - Implementation Plan
## Current Problem
The transformer currently:
- ❌ Only uses **1m timeframe** (ignores 1s, 1h, 1d)
- ❌ Doesn't use **BTC reference data**
- ❌ Has `price_projection` (single) but we changed it to `price_projections` (dict)
- ❌ Predicts next candles for all timeframes but only trains on 1m
## Goal
Make the transformer:
- ✅ Accept **ALL timeframes** (1s, 1m, 1h, 1d) as separate inputs
- ✅ Accept **BTC 1m** as reference symbol
- ✅ Learn **timeframe-specific patterns** with separate projections
- ✅ Predict **next candle for EACH timeframe**
- ✅ Handle **missing timeframes** gracefully (1s may not be available)
- ✅ Maintain **backward compatibility** with existing code
---
## Architecture Changes
### 1. Input Structure (NEW)
```python
forward(
price_data_1s=None, # [batch, 600, 5] - 1-second candles (optional)
price_data_1m=None, # [batch, 600, 5] - 1-minute candles (optional)
price_data_1h=None, # [batch, 600, 5] - 1-hour candles (optional)
price_data_1d=None, # [batch, 600, 5] - 1-day candles (optional)
btc_data_1m=None, # [batch, 600, 5] - BTC 1-minute (optional)
cob_data=None, # [batch, 600, 100] - COB features
tech_data=None, # [batch, 40] - Technical indicators
market_data=None, # [batch, 30] - Market features
position_state=None, # [batch, 5] - Position state
# Legacy support
price_data=None # [batch, 600, 5] - Falls back to 1m if provided
)
```
### 2. Input Projections (UPDATED)
```python
# Separate projection for each timeframe
self.price_projections = nn.ModuleDict({
'1s': nn.Linear(5, 1024), # 1-second patterns
'1m': nn.Linear(5, 1024), # 1-minute patterns
'1h': nn.Linear(5, 1024), # 1-hour patterns
'1d': nn.Linear(5, 1024) # 1-day patterns
})
# BTC reference projection
self.btc_projection = nn.Linear(5, 1024)
# Learnable timeframe importance weights
self.timeframe_weights = nn.Parameter(torch.ones(5)) # 4 timeframes + BTC
```
### 3. Embedding Combination Strategy
**Option A: Weighted Average** (Current Plan)
```python
# Project each timeframe
emb_1s = self.price_projections['1s'](price_data_1s) * weight_1s
emb_1m = self.price_projections['1m'](price_data_1m) * weight_1m
emb_1h = self.price_projections['1h'](price_data_1h) * weight_1h
emb_1d = self.price_projections['1d'](price_data_1d) * weight_1d
emb_btc = self.btc_projection(btc_data_1m) * weight_btc
# Combine
price_emb = (emb_1s + emb_1m + emb_1h + emb_1d + emb_btc) / num_available
```
**Option B: Concatenation + Projection** (Alternative)
```python
# Concatenate all timeframes
combined = torch.cat([emb_1s, emb_1m, emb_1h, emb_1d, emb_btc], dim=-1)
# [batch, seq_len, 5*1024] = [batch, seq_len, 5120]
# Project back to d_model
price_emb = nn.Linear(5120, 1024)(combined)
```
**Decision**: Use **Option A** (Weighted Average)
- ✅ Simpler
- ✅ Handles missing timeframes naturally
- ✅ Fewer parameters
- ✅ Learned weights show timeframe importance
### 4. Output Structure (ALREADY EXISTS)
```python
outputs = {
'action_logits': [batch, 3],
'action_probs': [batch, 3],
'confidence': [batch, 1],
'price_prediction': [batch, 1],
# Next candle predictions for EACH timeframe
'next_candles': {
'1s': [batch, 5], # [open, high, low, close, volume]
'1m': [batch, 5],
'1h': [batch, 5],
'1d': [batch, 5]
},
# BTC next candle
'btc_next_candle': [batch, 5],
# Pivot predictions
'next_pivots': {...},
# Trend analysis
'trend_analysis': {...}
}
```
---
## Training Data Preparation
### Current (WRONG)
```python
# Only uses 1m timeframe
primary_data = timeframes['1m']
price_data = stack_ohlcv(primary_data) # [batch, 150, 5]
batch = {
'price_data': price_data, # Only 1m!
'cob_data': cob_data,
'tech_data': tech_data,
'market_data': market_data
}
```
### New (CORRECT)
```python
# Use ALL available timeframes
batch = {
'price_data_1s': stack_ohlcv(timeframes.get('1s')), # [batch, 600, 5] or None
'price_data_1m': stack_ohlcv(timeframes['1m']), # [batch, 600, 5]
'price_data_1h': stack_ohlcv(timeframes.get('1h')), # [batch, 600, 5] or None
'price_data_1d': stack_ohlcv(timeframes.get('1d')), # [batch, 600, 5] or None
'btc_data_1m': stack_ohlcv(secondary_timeframes.get('BTC/USDT', {}).get('1m')),
'cob_data': cob_data,
'tech_data': tech_data,
'market_data': market_data,
'position_state': position_state
}
```
---
## Sequence Length Strategy
### Problem
Different timeframes have different natural sequence lengths:
- 1s: 600 candles = 10 minutes
- 1m: 600 candles = 10 hours
- 1h: 600 candles = 25 days
- 1d: 600 candles = ~2 years
### Solution: Fixed Sequence Length with Padding
**Use 600 candles for ALL timeframes**:
```python
seq_len = 600 # Fixed for all timeframes
# Each timeframe provides 600 candles
# - 1s: Last 600 seconds (10 min)
# - 1m: Last 600 minutes (10 hours)
# - 1h: Last 600 hours (25 days)
# - 1d: Last 600 days (~2 years)
# If less than 600 available, pad with last candle
if len(candles) < 600:
padding = repeat_last_candle(600 - len(candles))
candles = concat([candles, padding])
```
**Why 600?**
- ✅ Captures sufficient history for all timeframes
- ✅ Not too large (memory efficient)
- ✅ Divisible by many numbers (batch processing)
- ✅ Transformer can handle it (we have 46M params)
---
## Normalization Strategy
### Per-Timeframe Normalization
**Problem**: Different timeframes have different price ranges
- 1s: Tight range (e.g., $2000-$2001)
- 1d: Wide range (e.g., $1800-$2200)
**Solution**: Normalize each timeframe independently
```python
def normalize_timeframe(ohlcv, timeframe):
"""Normalize OHLCV data for a specific timeframe"""
# Use min/max from this timeframe's data
price_min = min(ohlcv[:, [0,1,2,3]].min()) # Min of OHLC
price_max = max(ohlcv[:, [0,1,2,3]].max()) # Max of OHLC
volume_min = ohlcv[:, 4].min()
volume_max = ohlcv[:, 4].max()
# Normalize prices to [0, 1]
ohlcv_norm = ohlcv.copy()
ohlcv_norm[:, :4] = (ohlcv[:, :4] - price_min) / (price_max - price_min)
ohlcv_norm[:, 4] = (ohlcv[:, 4] - volume_min) / (volume_max - volume_min)
# Store bounds for denormalization
bounds = {
'price_min': price_min,
'price_max': price_max,
'volume_min': volume_min,
'volume_max': volume_max
}
return ohlcv_norm, bounds
```
### Denormalization for Predictions
```python
def denormalize_prediction(pred_norm, bounds):
"""Convert normalized prediction back to actual price"""
pred = pred_norm * (bounds['price_max'] - bounds['price_min']) + bounds['price_min']
return pred
```
---
## Loss Function Updates
### Current (Single Timeframe)
```python
loss = action_loss + 0.1 * price_loss
```
### New (Multi-Timeframe)
```python
# Action loss (same)
action_loss = CrossEntropyLoss(action_logits, actions)
# Price loss for each timeframe
price_losses = []
for tf in ['1s', '1m', '1h', '1d']:
if f'target_{tf}' in batch:
pred = outputs['next_candles'][tf]
target = batch[f'target_{tf}']
loss_tf = MSELoss(pred, target)
price_losses.append(loss_tf)
# BTC loss
if 'target_btc' in batch:
btc_loss = MSELoss(outputs['btc_next_candle'], batch['target_btc'])
price_losses.append(btc_loss)
# Combined loss
total_price_loss = sum(price_losses) / len(price_losses) if price_losses else 0
total_loss = action_loss + 0.1 * total_price_loss
```
---
## Backward Compatibility
### Legacy Code Support
```python
# Old code (still works)
outputs = model(
price_data=price_1m, # Single timeframe
cob_data=cob,
tech_data=tech,
market_data=market
)
# New code (multi-timeframe)
outputs = model(
price_data_1m=price_1m,
price_data_1h=price_1h,
price_data_1d=price_1d,
btc_data_1m=btc_1m,
cob_data=cob,
tech_data=tech,
market_data=market,
position_state=position
)
```
**Implementation**:
```python
def forward(self, price_data=None, price_data_1m=None, ...):
# Legacy support
if price_data is not None and price_data_1m is None:
price_data_1m = price_data # Use as 1m data
# Continue with new logic
...
```
---
## Implementation Steps
### Step 1: Update Model Architecture ✅
- [x] Add `price_projections` ModuleDict
- [x] Add `btc_projection`
- [x] Add `timeframe_weights` parameter
- [x] Add `btc_next_candle_head`
### Step 2: Update Forward Method (TODO)
- [ ] Accept multi-timeframe inputs
- [ ] Handle missing timeframes
- [ ] Combine embeddings with learned weights
- [ ] Add BTC embedding
- [ ] Maintain backward compatibility
### Step 3: Update Training Adapter (TODO)
- [ ] Extract ALL timeframes from market_state
- [ ] Normalize each timeframe independently
- [ ] Create multi-timeframe batch dictionary
- [ ] Add target next candles for each timeframe
### Step 4: Update Loss Calculation (TODO)
- [ ] Add per-timeframe price losses
- [ ] Add BTC prediction loss
- [ ] Weight losses appropriately
### Step 5: Testing (TODO)
- [ ] Unit test: multi-timeframe forward pass
- [ ] Unit test: missing timeframes handling
- [ ] Integration test: full training loop
- [ ] Validation: check predictions for all timeframes
---
## Expected Benefits
### 1. Better Predictions
- **Multi-scale patterns**: 1s for micro, 1d for macro trends
- **Cross-timeframe validation**: Confirm signals across timeframes
- **Market correlation**: BTC reference for market-wide moves
### 2. Flexible Usage
- **Any timeframe**: Choose output timeframe in post-processing
- **Scalping**: Use 1s predictions
- **Swing trading**: Use 1h/1d predictions
- **Same model**: No retraining needed
### 3. Robustness
- **Missing data**: Handles missing 1s gracefully
- **Degraded mode**: Works with just 1m if needed
- **Backward compatible**: Old code still works
---
## Next Actions
1. **Finish forward() refactor** - Make it accept multi-timeframe inputs
2. **Update training adapter** - Pass all timeframes to model
3. **Test thoroughly** - Ensure no regressions
4. **Document usage** - Show how to use multi-timeframe predictions
This is a significant improvement that will make the model much more powerful! 🚀