Shared Pattern Encoder
fix T training
@@ -349,36 +349,57 @@ class AdvancedTradingTransformer(nn.Module):
         super().__init__()
         self.config = config
 
-        # Multi-timeframe input projections
-        # Each timeframe gets its own projection to learn timeframe-specific patterns
+        # Timeframe configuration
         self.timeframes = ['1s', '1m', '1h', '1d']
-        self.price_projections = nn.ModuleDict({
-            tf: nn.Linear(5, config.d_model) for tf in self.timeframes  # OHLCV per timeframe
-        })
+        self.num_timeframes = len(self.timeframes) + 1  # +1 for BTC
 
-        # Reference symbol projection (BTC 1m)
-        self.btc_projection = nn.Linear(5, config.d_model)
+        # SERIAL: Shared pattern encoder (learns candle patterns ONCE for all timeframes)
+        # This is applied to each timeframe independently but uses the SAME weights
+        self.shared_pattern_encoder = nn.Sequential(
+            nn.Linear(5, config.d_model // 4),  # 5 OHLCV -> 256
+            nn.LayerNorm(config.d_model // 4),
+            nn.GELU(),
+            nn.Dropout(config.dropout),
+            nn.Linear(config.d_model // 4, config.d_model // 2),  # 256 -> 512
+            nn.LayerNorm(config.d_model // 2),
+            nn.GELU(),
+            nn.Dropout(config.dropout),
+            nn.Linear(config.d_model // 2, config.d_model)  # 512 -> 1024
+        )
+
+        # Timeframe-specific embeddings (learnable, added to shared encoding)
+        # These help the model distinguish which timeframe it's looking at
+        self.timeframe_embeddings = nn.Embedding(self.num_timeframes, config.d_model)
+
+        # PARALLEL: Cross-timeframe attention layers
+        # These process all timeframes simultaneously to capture dependencies
+        self.cross_timeframe_layers = nn.ModuleList([
+            nn.TransformerEncoderLayer(
+                d_model=config.d_model,
+                nhead=config.n_heads,
+                dim_feedforward=config.d_ff,
+                dropout=config.dropout,
+                activation='gelu',
+                batch_first=True
+            ) for _ in range(2)  # 2 layers for cross-timeframe attention
+        ])
 
         # Other input projections
         self.cob_projection = nn.Linear(config.cob_features, config.d_model)
         self.tech_projection = nn.Linear(config.tech_features, config.d_model)
         self.market_projection = nn.Linear(config.market_features, config.d_model)
 
-        # Position state projection - properly learns to embed position info
-        # Input: [has_position, pnl, size, entry_price_norm, time_in_position] = 5 features
+        # Position state projection
         self.position_projection = nn.Sequential(
-            nn.Linear(5, config.d_model // 4),  # 5 -> 256
+            nn.Linear(5, config.d_model // 4),
             nn.GELU(),
             nn.Dropout(config.dropout),
-            nn.Linear(config.d_model // 4, config.d_model // 2),  # 256 -> 512
+            nn.Linear(config.d_model // 4, config.d_model // 2),
             nn.GELU(),
             nn.Dropout(config.dropout),
-            nn.Linear(config.d_model // 2, config.d_model)  # 512 -> 1024
+            nn.Linear(config.d_model // 2, config.d_model)
         )
 
         # Timeframe importance weights (learnable)
         self.timeframe_weights = nn.Parameter(torch.ones(len(self.timeframes) + 1))  # +1 for BTC
 
         # Positional encoding
         if config.use_relative_position:
             self.pos_encoding = RelativePositionalEncoding(config.d_model)
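
The point of the serial step above is that a single nn.Sequential owns all candle-pattern weights, so every timeframe (and the BTC reference stream) is encoded by the same parameters, and every timeframe's gradients update that one set of weights. A minimal sketch of the weight sharing, with illustrative sizes (d_model=1024 to match the 256 -> 512 -> 1024 comments; the real values come from config):

import torch
import torch.nn as nn

d_model = 1024  # illustrative; the real value comes from config

shared = nn.Sequential(
    nn.Linear(5, d_model // 4),
    nn.LayerNorm(d_model // 4),
    nn.GELU(),
    nn.Linear(d_model // 4, d_model // 2),
    nn.LayerNorm(d_model // 2),
    nn.GELU(),
    nn.Linear(d_model // 2, d_model),
)

# One (batch, seq_len, 5) OHLCV tensor per timeframe; the SAME module encodes
# each of them, so all timeframes train one set of pattern weights.
eth_1m = torch.randn(2, 200, 5)
eth_1h = torch.randn(2, 200, 5)
enc_1m = shared(eth_1m)  # (2, 200, 1024)
enc_1h = shared(eth_1h)  # (2, 200, 1024)

# Timeframe identity is restored afterwards with a learned embedding,
# broadcast over the sequence dimension.
tf_emb = nn.Embedding(5, d_model)  # 4 timeframes + 1 BTC slot
enc_1h = enc_1h + tf_emb(torch.tensor([2]))  # '1h' is index 2 in ['1s', '1m', '1h', '1d']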
@@ -512,41 +533,156 @@ class AdvancedTradingTransformer(nn.Module):
         nn.init.ones_(module.weight)
         nn.init.zeros_(module.bias)
 
-    def forward(self, price_data: torch.Tensor, cob_data: torch.Tensor,
-                tech_data: torch.Tensor, market_data: torch.Tensor,
+    def forward(self,
+                # Multi-timeframe inputs
+                price_data_1s: Optional[torch.Tensor] = None,
+                price_data_1m: Optional[torch.Tensor] = None,
+                price_data_1h: Optional[torch.Tensor] = None,
+                price_data_1d: Optional[torch.Tensor] = None,
+                btc_data_1m: Optional[torch.Tensor] = None,
+                # Other inputs
+                cob_data: Optional[torch.Tensor] = None,
+                tech_data: Optional[torch.Tensor] = None,
+                market_data: Optional[torch.Tensor] = None,
+                position_state: Optional[torch.Tensor] = None,
                 mask: Optional[torch.Tensor] = None,
-                position_state: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
+                # Legacy support
+                price_data: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
         """
-        Forward pass of the trading transformer
+        Forward pass with hybrid serial-parallel multi-timeframe processing
+
+        SERIAL: Shared pattern encoder learns candle patterns once (same weights for all timeframes)
+        PARALLEL: Cross-timeframe attention captures dependencies between timeframes
 
         Args:
-            price_data: (batch, seq_len, 5) - OHLCV data
+            price_data_1s: (batch, seq_len, 5) - 1-second OHLCV (optional)
+            price_data_1m: (batch, seq_len, 5) - 1-minute OHLCV (optional)
+            price_data_1h: (batch, seq_len, 5) - 1-hour OHLCV (optional)
+            price_data_1d: (batch, seq_len, 5) - 1-day OHLCV (optional)
+            btc_data_1m: (batch, seq_len, 5) - BTC 1-minute OHLCV (optional)
             cob_data: (batch, seq_len, cob_features) - COB features
-            tech_data: (batch, seq_len, tech_features) - Technical indicators
-            market_data: (batch, seq_len, market_features) - Market microstructure
+            tech_data: (batch, tech_features) - Technical indicators
+            market_data: (batch, market_features) - Market features
+            position_state: (batch, 5) - Position state
             mask: Optional attention mask
-            position_state: (batch, 5) - Position state [has_position, pnl, size, entry_price, time_in_position]
+            price_data: (batch, seq_len, 5) - Legacy single timeframe (defaults to 1m)
 
         Returns:
-            Dictionary containing model outputs
+            Dictionary with predictions for ALL timeframes
         """
-        batch_size, seq_len = price_data.shape[:2]
+        # Legacy support
+        if price_data is not None and price_data_1m is None:
+            price_data_1m = price_data
 
-        # Handle different input dimensions - expand to sequence if needed
-        if cob_data.dim() == 2:  # (batch, features) -> (batch, seq_len, features)
-            cob_data = cob_data.unsqueeze(1).expand(batch_size, seq_len, -1)
-        if tech_data.dim() == 2:  # (batch, features) -> (batch, seq_len, features)
-            tech_data = tech_data.unsqueeze(1).expand(batch_size, seq_len, -1)
-        if market_data.dim() == 2:  # (batch, features) -> (batch, seq_len, features)
-            market_data = market_data.unsqueeze(1).expand(batch_size, seq_len, -1)
+        # Collect available timeframes
+        timeframe_data = {
+            '1s': price_data_1s,
+            '1m': price_data_1m,
+            '1h': price_data_1h,
+            '1d': price_data_1d,
+            'btc': btc_data_1m
+        }
 
-        # Project inputs to model dimension
-        price_emb = self.price_projection(price_data)
-        cob_emb = self.cob_projection(cob_data)
-        tech_emb = self.tech_projection(tech_data)
-        market_emb = self.market_projection(market_data)
+        # Filter to available timeframes
+        available_tfs = [(tf, data) for tf, data in timeframe_data.items() if data is not None]
 
-        # Combine embeddings (could also use cross-attention)
+        if not available_tfs:
+            raise ValueError("At least one timeframe must be provided")
+
+        # Get dimensions from first available timeframe
+        first_data = available_tfs[0][1]
+        batch_size, seq_len = first_data.shape[:2]
+        device = first_data.device
+
+        # ============================================================
+        # STEP 1: SERIAL - Apply shared pattern encoder to each timeframe
+        # This learns candle patterns ONCE (same weights for all)
+        # ============================================================
+        timeframe_encodings = []
+        timeframe_indices = []
+
+        for idx, (tf_name, tf_data) in enumerate(available_tfs):
+            # Ensure correct sequence length
+            if tf_data.shape[1] != seq_len:
+                if tf_data.shape[1] < seq_len:
+                    # Pad with last candle
+                    padding = tf_data[:, -1:, :].expand(batch_size, seq_len - tf_data.shape[1], 5)
+                    tf_data = torch.cat([tf_data, padding], dim=1)
+                else:
+                    # Truncate to seq_len
+                    tf_data = tf_data[:, :seq_len, :]
+
+            # Apply SHARED pattern encoder (learns patterns once for all timeframes)
+            # Shape: [batch, seq_len, 5] -> [batch, seq_len, d_model]
+            tf_encoded = self.shared_pattern_encoder(tf_data)
+
+            # Add timeframe-specific embedding (helps model know which timeframe)
+            tf_idx = self.timeframes.index(tf_name) if tf_name in self.timeframes else len(self.timeframes)
+            tf_embedding = self.timeframe_embeddings(torch.tensor([tf_idx], device=device))
+            tf_embedding = tf_embedding.unsqueeze(1).expand(batch_size, seq_len, -1)
+
+            # Combine: shared pattern + timeframe identity
+            tf_encoded = tf_encoded + tf_embedding
+
+            timeframe_encodings.append(tf_encoded)
+            timeframe_indices.append(tf_idx)
+
+        # ============================================================
+        # STEP 2: PARALLEL - Cross-timeframe attention
+        # Process all timeframes together to capture dependencies
+        # ============================================================
+
+        # Stack timeframes: [batch, num_timeframes, seq_len, d_model]
+        # Then reshape to: [batch, num_timeframes * seq_len, d_model]
+        stacked_tfs = torch.stack(timeframe_encodings, dim=1)  # [batch, num_tfs, seq_len, d_model]
+        num_tfs = len(timeframe_encodings)
+
+        # Reshape for cross-timeframe attention
+        # [batch, num_tfs, seq_len, d_model] -> [batch, num_tfs * seq_len, d_model]
+        cross_tf_input = stacked_tfs.reshape(batch_size, num_tfs * seq_len, self.config.d_model)
+
+        # Apply cross-timeframe attention layers
+        # This allows the model to see patterns ACROSS timeframes simultaneously
+        for layer in self.cross_timeframe_layers:
+            cross_tf_input = layer(cross_tf_input)
+
+        # Reshape back: [batch, num_tfs * seq_len, d_model] -> [batch, num_tfs, seq_len, d_model]
+        cross_tf_output = cross_tf_input.reshape(batch_size, num_tfs, seq_len, self.config.d_model)
+
+        # Average across timeframes to get unified representation
+        # [batch, num_tfs, seq_len, d_model] -> [batch, seq_len, d_model]
+        price_emb = cross_tf_output.mean(dim=1)
+
+        # ============================================================
+        # STEP 3: Add other features (COB, tech, market, position)
+        # ============================================================
+
+        # COB features
+        if cob_data is not None:
+            if cob_data.dim() == 2:
+                cob_data = cob_data.unsqueeze(1).expand(batch_size, seq_len, -1)
+            cob_emb = self.cob_projection(cob_data)
+        else:
+            cob_emb = torch.zeros(batch_size, seq_len, self.config.d_model, device=device)
+
+        # Technical indicators
+        if tech_data is not None:
+            if tech_data.dim() == 2:
+                tech_data = tech_data.unsqueeze(1).expand(batch_size, seq_len, -1)
+            tech_emb = self.tech_projection(tech_data)
+        else:
+            tech_emb = torch.zeros(batch_size, seq_len, self.config.d_model, device=device)
+
+        # Market features
+        if market_data is not None:
+            if market_data.dim() == 2:
+                market_data = market_data.unsqueeze(1).expand(batch_size, seq_len, -1)
+            market_emb = self.market_projection(market_data)
+        else:
+            market_emb = torch.zeros(batch_size, seq_len, self.config.d_model, device=device)
+
+        # Combine all embeddings
         x = price_emb + cob_emb + tech_emb + market_emb
 
         # Add position state if provided - critical for loss minimization and profit taking
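
As a shape walkthrough of the parallel step in isolation: per-timeframe encodings are stacked, timeframes are flattened into the token axis so self-attention can mix candles across timeframes, and the result is averaged back to a single sequence. Sizes below are illustrative, not the model's real config:

import torch
import torch.nn as nn

batch, seq_len, d_model, n_heads = 2, 50, 64, 4

# Pretend these came out of the shared pattern encoder for three timeframes.
encodings = [torch.randn(batch, seq_len, d_model) for _ in range(3)]

layer = nn.TransformerEncoderLayer(
    d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model,
    activation='gelu', batch_first=True,
)

stacked = torch.stack(encodings, dim=1)                      # (2, 3, 50, 64)
num_tfs = stacked.shape[1]
tokens = stacked.reshape(batch, num_tfs * seq_len, d_model)  # (2, 150, 64)

tokens = layer(tokens)  # every candle attends to every timeframe's candles

unified = tokens.reshape(batch, num_tfs, seq_len, d_model).mean(dim=1)
print(unified.shape)  # torch.Size([2, 50, 64])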
@@ -622,6 +758,10 @@ class AdvancedTradingTransformer(nn.Module):
                 next_candles[tf] = candle_pred
         outputs['next_candles'] = next_candles
 
+        # BTC next candle prediction
+        btc_next_candle = self.btc_next_candle_head(pooled)  # (batch, 5)
+        outputs['btc_next_candle'] = btc_next_candle
+
         # NEW: Next pivot point predictions for L1-L5
         next_pivots = {}
         for level in self.pivot_levels:
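
Both btc_next_candle_head and pooled are defined outside this hunk; from the shape comment, the head presumably maps a pooled (batch, d_model) representation to 5 OHLCV values. A sketch under that assumption (the layer structure here is hypothetical):

import torch
import torch.nn as nn

d_model = 1024  # illustrative

# Assumed structure of the head: a linear map to one predicted candle.
btc_next_candle_head = nn.Linear(d_model, 5)  # [open, high, low, close, volume]

pooled = torch.randn(2, d_model)       # (batch, d_model) pooled sequence state
candle = btc_next_candle_head(pooled)  # (batch, 5)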
@@ -1007,13 +1147,18 @@ class TradingTransformerTrainer:
             batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                      for k, v in batch.items()}
 
-            # Forward pass with position state for loss minimization
+            # Forward pass with multi-timeframe data
             outputs = self.model(
-                batch['price_data'],
-                batch['cob_data'],
-                batch['tech_data'],
-                batch['market_data'],
-                position_state=batch.get('position_state', None)  # Pass position state if available
+                price_data_1s=batch.get('price_data_1s'),
+                price_data_1m=batch.get('price_data_1m'),
+                price_data_1h=batch.get('price_data_1h'),
+                price_data_1d=batch.get('price_data_1d'),
+                btc_data_1m=batch.get('btc_data_1m'),
+                cob_data=batch['cob_data'],
+                tech_data=batch['tech_data'],
+                market_data=batch['market_data'],
+                position_state=batch.get('position_state'),
+                price_data=batch.get('price_data')  # Legacy fallback
             )
 
             # Calculate losses
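
Because every input is now passed by keyword and batch.get(...) returns None for missing keys, a batch only needs the timeframes it actually has. A hypothetical minimal batch (feature widths are illustrative, not the real config values):

import torch

batch = {
    'price_data_1m': torch.randn(2, 200, 5),  # only 1m is available here
    'cob_data': torch.randn(2, 100),          # illustrative cob_features=100
    'tech_data': torch.randn(2, 40),          # illustrative tech_features=40
    'market_data': torch.randn(2, 30),        # illustrative market_features=30
    'actions': torch.randint(0, 3, (2,)),     # targets for the accuracy check
}
# batch.get('price_data_1h') -> None, so forward() simply skips that timeframe.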
@@ -1078,19 +1223,30 @@ class TradingTransformerTrainer:
             self.optimizer.step()
             self.scheduler.step()
 
-            # Calculate accuracy
-            predictions = torch.argmax(outputs['action_logits'], dim=-1)
-            accuracy = (predictions == batch['actions']).float().mean()
+            # Calculate accuracy without gradients
+            with torch.no_grad():
+                predictions = torch.argmax(outputs['action_logits'], dim=-1)
+                accuracy = (predictions == batch['actions']).float().mean()
 
-            return {
+            # Extract values and delete tensors to free memory
+            result = {
                 'total_loss': total_loss.item(),
                 'action_loss': action_loss.item(),
                 'price_loss': price_loss.item(),
                 'accuracy': accuracy.item(),
                 'learning_rate': self.scheduler.get_last_lr()[0]
             }
+
+            # Delete large tensors to free memory immediately
+            del outputs, total_loss, action_loss, price_loss, predictions, accuracy
+
+            return result
 
         except Exception as e:
             logger.error(f"Error in train_step: {e}", exc_info=True)
+            # Clear any partial computations
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
             # Return a zero loss dict to prevent training from crashing
             # but log the error so we can debug
             return {
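
The memory discipline in this hunk — read scalars out with .item(), del the graph-holding tensors, and torch.cuda.empty_cache() on failure — generalizes to any train step. A condensed skeleton, with a hypothetical loss_fn standing in for the real loss computation:

import logging
import torch

logger = logging.getLogger(__name__)

def train_step_skeleton(model, optimizer, loss_fn, batch):
    """Sketch of the pattern above, not the repo's actual trainer."""
    try:
        optimizer.zero_grad()
        loss = loss_fn(model, batch)
        loss.backward()
        optimizer.step()
        result = {'total_loss': loss.item()}  # plain float, no autograd graph
        del loss                              # release the graph immediately
        return result
    except Exception as e:
        logger.error(f"Error in train_step: {e}", exc_info=True)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # reclaim partially allocated GPU memory
        return {'total_loss': 0.0}    # keep the training loop alive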