Shared Pattern Encoder

fix T training
Dobromir Popov
2025-11-06 14:27:52 +02:00
parent 07d97100c0
commit 738c7cb854
5 changed files with 1276 additions and 180 deletions


@@ -349,36 +349,57 @@ class AdvancedTradingTransformer(nn.Module):
super().__init__()
self.config = config
# Multi-timeframe input projections
# Each timeframe gets its own projection to learn timeframe-specific patterns
# Timeframe configuration
self.timeframes = ['1s', '1m', '1h', '1d']
self.price_projections = nn.ModuleDict({
tf: nn.Linear(5, config.d_model) for tf in self.timeframes # OHLCV per timeframe
})
self.num_timeframes = len(self.timeframes) + 1 # +1 for BTC
# Reference symbol projection (BTC 1m)
self.btc_projection = nn.Linear(5, config.d_model)
# SERIAL: Shared pattern encoder (learns candle patterns ONCE for all timeframes)
# This is applied to each timeframe independently but uses SAME weights
self.shared_pattern_encoder = nn.Sequential(
nn.Linear(5, config.d_model // 4), # 5 OHLCV -> 256
nn.LayerNorm(config.d_model // 4),
nn.GELU(),
nn.Dropout(config.dropout),
nn.Linear(config.d_model // 4, config.d_model // 2), # 256 -> 512
nn.LayerNorm(config.d_model // 2),
nn.GELU(),
nn.Dropout(config.dropout),
nn.Linear(config.d_model // 2, config.d_model) # 512 -> 1024
)
# Timeframe-specific embeddings (learnable, added to shared encoding)
# These help the model distinguish which timeframe it's looking at
self.timeframe_embeddings = nn.Embedding(self.num_timeframes, config.d_model)
# PARALLEL: Cross-timeframe attention layers
# These process all timeframes simultaneously to capture dependencies
self.cross_timeframe_layers = nn.ModuleList([
nn.TransformerEncoderLayer(
d_model=config.d_model,
nhead=config.n_heads,
dim_feedforward=config.d_ff,
dropout=config.dropout,
activation='gelu',
batch_first=True
) for _ in range(2) # 2 layers for cross-timeframe attention
])
# Other input projections
self.cob_projection = nn.Linear(config.cob_features, config.d_model)
self.tech_projection = nn.Linear(config.tech_features, config.d_model)
self.market_projection = nn.Linear(config.market_features, config.d_model)
# Position state projection - properly learns to embed position info
# Input: [has_position, pnl, size, entry_price_norm, time_in_position] = 5 features
# Position state projection
self.position_projection = nn.Sequential(
nn.Linear(5, config.d_model // 4), # 5 -> 256
nn.Linear(5, config.d_model // 4),
nn.GELU(),
nn.Dropout(config.dropout),
nn.Linear(config.d_model // 4, config.d_model // 2), # 256 -> 512
nn.Linear(config.d_model // 4, config.d_model // 2),
nn.GELU(),
nn.Dropout(config.dropout),
nn.Linear(config.d_model // 2, config.d_model) # 512 -> 1024
nn.Linear(config.d_model // 2, config.d_model)
)
# Timeframe importance weights (learnable)
self.timeframe_weights = nn.Parameter(torch.ones(len(self.timeframes) + 1)) # +1 for BTC
# Positional encoding
if config.use_relative_position:
self.pos_encoding = RelativePositionalEncoding(config.d_model)
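The point of this hunk is that one encoder, with one set of weights, now produces the candle embedding for every timeframe, so adding timeframes adds no encoder parameters; timeframe identity comes only from the small embedding table. A minimal standalone sketch of that weight sharing, assuming d_model = 1024 and dropout = 0.1 from the inline comments (batch size and seq_len below are arbitrary):

import torch
import torch.nn as nn

d_model, dropout = 1024, 0.1  # assumed from the "5 OHLCV -> 256 ... 512 -> 1024" comments above
shared_encoder = nn.Sequential(
    nn.Linear(5, d_model // 4), nn.LayerNorm(d_model // 4), nn.GELU(), nn.Dropout(dropout),
    nn.Linear(d_model // 4, d_model // 2), nn.LayerNorm(d_model // 2), nn.GELU(), nn.Dropout(dropout),
    nn.Linear(d_model // 2, d_model),
)
tf_embeddings = nn.Embedding(5, d_model)  # ['1s', '1m', '1h', '1d'] + BTC reference

ohlcv_1m = torch.randn(2, 200, 5)  # (batch, seq_len, OHLCV)
ohlcv_1h = torch.randn(2, 200, 5)

# The same weights encode both timeframes; only the added embedding differs
enc_1m = shared_encoder(ohlcv_1m) + tf_embeddings(torch.tensor(1))  # '1m' -> index 1
enc_1h = shared_encoder(ohlcv_1h) + tf_embeddings(torch.tensor(2))  # '1h' -> index 2
print(enc_1m.shape)  # torch.Size([2, 200, 1024])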
@@ -512,41 +533,156 @@ class AdvancedTradingTransformer(nn.Module):
nn.init.ones_(module.weight)
nn.init.zeros_(module.bias)
def forward(self, price_data: torch.Tensor, cob_data: torch.Tensor,
tech_data: torch.Tensor, market_data: torch.Tensor,
def forward(self,
# Multi-timeframe inputs
price_data_1s: Optional[torch.Tensor] = None,
price_data_1m: Optional[torch.Tensor] = None,
price_data_1h: Optional[torch.Tensor] = None,
price_data_1d: Optional[torch.Tensor] = None,
btc_data_1m: Optional[torch.Tensor] = None,
# Other inputs
cob_data: Optional[torch.Tensor] = None,
tech_data: Optional[torch.Tensor] = None,
market_data: Optional[torch.Tensor] = None,
position_state: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
position_state: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
# Legacy support
price_data: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
"""
Forward pass of the trading transformer
Forward pass with hybrid serial-parallel multi-timeframe processing
SERIAL: Shared pattern encoder learns candle patterns once (same weights for all timeframes)
PARALLEL: Cross-timeframe attention captures dependencies between timeframes
Args:
price_data: (batch, seq_len, 5) - OHLCV data
price_data_1s: (batch, seq_len, 5) - 1-second OHLCV (optional)
price_data_1m: (batch, seq_len, 5) - 1-minute OHLCV (optional)
price_data_1h: (batch, seq_len, 5) - 1-hour OHLCV (optional)
price_data_1d: (batch, seq_len, 5) - 1-day OHLCV (optional)
btc_data_1m: (batch, seq_len, 5) - BTC 1-minute OHLCV (optional)
cob_data: (batch, seq_len, cob_features) - COB features
tech_data: (batch, seq_len, tech_features) - Technical indicators
market_data: (batch, seq_len, market_features) - Market microstructure
tech_data: (batch, tech_features) - Technical indicators
market_data: (batch, market_features) - Market features
position_state: (batch, 5) - Position state
mask: Optional attention mask
position_state: (batch, 5) - Position state [has_position, pnl, size, entry_price, time_in_position]
price_data: (batch, seq_len, 5) - Legacy single timeframe (defaults to 1m)
Returns:
Dictionary containing model outputs
Dictionary with predictions for ALL timeframes
"""
batch_size, seq_len = price_data.shape[:2]
# Legacy support
if price_data is not None and price_data_1m is None:
price_data_1m = price_data
# Handle different input dimensions - expand to sequence if needed
if cob_data.dim() == 2: # (batch, features) -> (batch, seq_len, features)
cob_data = cob_data.unsqueeze(1).expand(batch_size, seq_len, -1)
if tech_data.dim() == 2: # (batch, features) -> (batch, seq_len, features)
tech_data = tech_data.unsqueeze(1).expand(batch_size, seq_len, -1)
if market_data.dim() == 2: # (batch, features) -> (batch, seq_len, features)
market_data = market_data.unsqueeze(1).expand(batch_size, seq_len, -1)
# Collect available timeframes
timeframe_data = {
'1s': price_data_1s,
'1m': price_data_1m,
'1h': price_data_1h,
'1d': price_data_1d,
'btc': btc_data_1m
}
# Project inputs to model dimension
price_emb = self.price_projection(price_data)
cob_emb = self.cob_projection(cob_data)
tech_emb = self.tech_projection(tech_data)
market_emb = self.market_projection(market_data)
# Filter to available timeframes
available_tfs = [(tf, data) for tf, data in timeframe_data.items() if data is not None]
# Combine embeddings (could also use cross-attention)
if not available_tfs:
raise ValueError("At least one timeframe must be provided")
# Get dimensions from first available timeframe
first_data = available_tfs[0][1]
batch_size, seq_len = first_data.shape[:2]
device = first_data.device
# ============================================================
# STEP 1: SERIAL - Apply shared pattern encoder to each timeframe
# This learns candle patterns ONCE (same weights for all)
# ============================================================
timeframe_encodings = []
timeframe_indices = []
for idx, (tf_name, tf_data) in enumerate(available_tfs):
# Ensure correct sequence length
if tf_data.shape[1] != seq_len:
if tf_data.shape[1] < seq_len:
# Pad with last candle
padding = tf_data[:, -1:, :].expand(batch_size, seq_len - tf_data.shape[1], 5)
tf_data = torch.cat([tf_data, padding], dim=1)
else:
# Truncate to seq_len
tf_data = tf_data[:, :seq_len, :]
# Apply SHARED pattern encoder (learns patterns once for all timeframes)
# Shape: [batch, seq_len, 5] -> [batch, seq_len, d_model]
tf_encoded = self.shared_pattern_encoder(tf_data)
# Add timeframe-specific embedding (lets the model know which timeframe it is seeing)
# Get timeframe index
tf_idx = self.timeframes.index(tf_name) if tf_name in self.timeframes else len(self.timeframes)
tf_embedding = self.timeframe_embeddings(torch.tensor([tf_idx], device=device))
tf_embedding = tf_embedding.unsqueeze(1).expand(batch_size, seq_len, -1)
# Combine: shared pattern + timeframe identity
tf_encoded = tf_encoded + tf_embedding
timeframe_encodings.append(tf_encoded)
timeframe_indices.append(tf_idx)
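# Concrete example: if only price_data_1m, price_data_1h and btc_data_1m are provided,
# timeframe_indices ends up as [1, 2, 4] ('1m'=1, '1h'=2, BTC falls back to len(self.timeframes)=4)
# and every entry of timeframe_encodings has shape [batch, seq_len, d_model]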
# ============================================================
# STEP 2: PARALLEL - Cross-timeframe attention
# Process all timeframes together to capture dependencies
# ============================================================
# Stack timeframes: [batch, num_timeframes, seq_len, d_model]
# Then reshape to: [batch, num_timeframes * seq_len, d_model]
stacked_tfs = torch.stack(timeframe_encodings, dim=1) # [batch, num_tfs, seq_len, d_model]
num_tfs = len(timeframe_encodings)
# Reshape for cross-timeframe attention
# [batch, num_tfs, seq_len, d_model] -> [batch, num_tfs * seq_len, d_model]
cross_tf_input = stacked_tfs.reshape(batch_size, num_tfs * seq_len, self.config.d_model)
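# e.g. 3 available timeframes with seq_len=200 gives attention over 3 * 200 = 600 positions per layer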
# Apply cross-timeframe attention layers
# This allows the model to see patterns ACROSS timeframes simultaneously
for layer in self.cross_timeframe_layers:
cross_tf_input = layer(cross_tf_input)
# Reshape back: [batch, num_tfs * seq_len, d_model] -> [batch, num_tfs, seq_len, d_model]
cross_tf_output = cross_tf_input.reshape(batch_size, num_tfs, seq_len, self.config.d_model)
# Average across timeframes to get unified representation
# [batch, num_tfs, seq_len, d_model] -> [batch, seq_len, d_model]
price_emb = cross_tf_output.mean(dim=1)
# ============================================================
# STEP 3: Add other features (COB, tech, market, position)
# ============================================================
# COB features
if cob_data is not None:
if cob_data.dim() == 2:
cob_data = cob_data.unsqueeze(1).expand(batch_size, seq_len, -1)
cob_emb = self.cob_projection(cob_data)
else:
cob_emb = torch.zeros(batch_size, seq_len, self.config.d_model, device=device)
# Technical indicators
if tech_data is not None:
if tech_data.dim() == 2:
tech_data = tech_data.unsqueeze(1).expand(batch_size, seq_len, -1)
tech_emb = self.tech_projection(tech_data)
else:
tech_emb = torch.zeros(batch_size, seq_len, self.config.d_model, device=device)
# Market features
if market_data is not None:
if market_data.dim() == 2:
market_data = market_data.unsqueeze(1).expand(batch_size, seq_len, -1)
market_emb = self.market_projection(market_data)
else:
market_emb = torch.zeros(batch_size, seq_len, self.config.d_model, device=device)
# Combine all embeddings
x = price_emb + cob_emb + tech_emb + market_emb
# Add position state if provided - critical for loss minimization and profit taking
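For orientation, a hedged sketch of how a caller might drive the new multi-timeframe forward pass; model is assumed to be an instantiated AdvancedTradingTransformer, and the sequence length and cob/tech/market feature widths below are illustrative assumptions rather than values taken from the config:

batch, seq_len = 4, 200  # illustrative sizes
outputs = model(
    price_data_1m=torch.randn(batch, seq_len, 5),
    price_data_1h=torch.randn(batch, seq_len, 5),  # any subset of timeframes may be supplied
    btc_data_1m=torch.randn(batch, seq_len, 5),
    cob_data=torch.randn(batch, seq_len, 100),     # cob_features width is config-dependent
    tech_data=torch.randn(batch, 40),              # 2-D feature tensors are expanded across seq_len
    market_data=torch.randn(batch, 30),
    # [has_position, pnl, size, entry_price_norm, time_in_position] - values illustrative
    position_state=torch.tensor([[1.0, 0.02, 0.5, 0.998, 0.25]]).expand(batch, -1),
)
action = torch.argmax(outputs['action_logits'], dim=-1)  # predicted action index per sample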
@@ -622,6 +758,10 @@ class AdvancedTradingTransformer(nn.Module):
next_candles[tf] = candle_pred
outputs['next_candles'] = next_candles
# BTC next candle prediction
btc_next_candle = self.btc_next_candle_head(pooled) # (batch, 5)
outputs['btc_next_candle'] = btc_next_candle
# NEW: Next pivot point predictions for L1-L5
next_pivots = {}
for level in self.pivot_levels:
@@ -1007,13 +1147,18 @@ class TradingTransformerTrainer:
batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
for k, v in batch.items()}
# Forward pass with position state for loss minimization
# Forward pass with multi-timeframe data
outputs = self.model(
batch['price_data'],
batch['cob_data'],
batch['tech_data'],
batch['market_data'],
position_state=batch.get('position_state', None) # Pass position state if available
price_data_1s=batch.get('price_data_1s'),
price_data_1m=batch.get('price_data_1m'),
price_data_1h=batch.get('price_data_1h'),
price_data_1d=batch.get('price_data_1d'),
btc_data_1m=batch.get('btc_data_1m'),
cob_data=batch['cob_data'],
tech_data=batch['tech_data'],
market_data=batch['market_data'],
position_state=batch.get('position_state'),
price_data=batch.get('price_data') # Legacy fallback
)
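# The batch dict is now expected to carry the per-timeframe keys ('price_data_1s',
# 'price_data_1m', 'price_data_1h', 'price_data_1d', 'btc_data_1m') alongside the shared
# feature tensors; missing keys resolve to None via batch.get() and are skipped inside forward()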
# Calculate losses
@@ -1078,19 +1223,30 @@ class TradingTransformerTrainer:
self.optimizer.step()
self.scheduler.step()
# Calculate accuracy
predictions = torch.argmax(outputs['action_logits'], dim=-1)
accuracy = (predictions == batch['actions']).float().mean()
# Calculate accuracy without gradients
with torch.no_grad():
predictions = torch.argmax(outputs['action_logits'], dim=-1)
accuracy = (predictions == batch['actions']).float().mean()
return {
# Extract values and delete tensors to free memory
result = {
'total_loss': total_loss.item(),
'action_loss': action_loss.item(),
'price_loss': price_loss.item(),
'accuracy': accuracy.item(),
'learning_rate': self.scheduler.get_last_lr()[0]
}
# Delete large tensors to free memory immediately
del outputs, total_loss, action_loss, price_loss, predictions, accuracy
return result
except Exception as e:
logger.error(f"Error in train_step: {e}", exc_info=True)
# Clear any partial computations
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Return a zero loss dict to prevent training from crashing
# but log the error so we can debug
return {