#!/usr/bin/env python3
"""
Advanced Transformer Models for High-Frequency Trading
Optimized for COB data, technical indicators, and market microstructure
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.checkpoint import checkpoint
import numpy as np
import math
import logging
from typing import Dict, Any, Optional, Tuple, List, Callable
from dataclasses import dataclass
import os
import json
from datetime import datetime

# Configure logging
logger = logging.getLogger(__name__)


@dataclass
class TradingTransformerConfig:
    """Hyperparameters for the trading transformer models.

    The defaults describe a compact model intended for single-GPU training
    (the author's target was roughly 8-12M parameters).
    """

    # Model architecture (reduced from an earlier, much larger configuration)
    d_model: int = 256  # Model dimension (was 1024)
    n_heads: int = 8  # Number of attention heads (was 16)
    n_layers: int = 4  # Number of transformer layers (was 12)
    d_ff: int = 1024  # Feed-forward dimension (was 4096)
    dropout: float = 0.1  # Dropout rate

    # Input dimensions
    seq_len: int = 200  # Sequence length for time series
    cob_features: int = 100  # COB feature dimension
    tech_features: int = 40  # Technical indicator features
    market_features: int = 30  # Market microstructure features

    # Output configuration
    n_actions: int = 3  # BUY, SELL, HOLD
    confidence_output: bool = True  # Output confidence scores

    # Training configuration
    learning_rate: float = 5e-5
    weight_decay: float = 1e-4
    warmup_steps: int = 8000
    max_grad_norm: float = 0.5  # Gradient clipping threshold

    # Optional model components
    use_relative_position: bool = True
    use_multi_scale_attention: bool = True
    use_market_regime_detection: bool = True
    use_uncertainty_estimation: bool = True

    # NOTE(review): the three flags below are not read by the model layers in
    # this module; confirm whether anything else still consumes them.
    use_deep_attention: bool = True
    use_residual_connections: bool = True
    use_layer_norm_variants: bool = True

    # Memory optimization
    use_gradient_checkpointing: bool = True  # Recompute layer activations in backward to save memory


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for sequence-first inputs.

    The table is stored as a ``(max_len, 1, d_model)`` buffer, so ``forward``
    expects ``x`` shaped ``(seq_len, batch, d_model)`` with ``seq_len <= max_len``.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the first ``x.size(0)`` encodings to ``x``."""
        return x + self.pe[:x.size(0), :]


class RelativePositionalEncoding(nn.Module):
    """Learnable embeddings indexed by clipped relative distance ``j - i``."""

    def __init__(self, d_model: int, max_relative_position: int = 128):
        super().__init__()
        self.d_model = d_model
        self.max_relative_position = max_relative_position

        # One embedding per distance in [-max_relative_position, max_relative_position].
        self.relative_position_embeddings = nn.Embedding(
            2 * max_relative_position + 1, d_model
        )

    def forward(self, seq_len: int) -> torch.Tensor:
        """Return a ``(seq_len, seq_len, d_model)`` tensor whose ``[i, j]`` entry embeds ``j - i``."""
        # Build indices on the embedding's device so GPU models do not hit a device mismatch.
        positions = torch.arange(seq_len, device=self.relative_position_embeddings.weight.device)
        distance = positions.unsqueeze(0) - positions.unsqueeze(1)  # distance[i, j] = j - i
        distance = distance.clamp(-self.max_relative_position, self.max_relative_position)
        # Shift into the non-negative embedding index range.
        return self.relative_position_embeddings(distance + self.max_relative_position)


class DeepMultiScaleAttention(nn.Module):
    """Multi-head self-attention evaluated at several temporal scales.

    For each kernel size in ``scales`` the input is filtered by a grouped 1-D
    convolution along the sequence axis, then passed through that scale's own
    Q/K/V projections and scaled dot-product attention. Per-scale outputs are
    concatenated and projected back to ``d_model``.
    """

    def __init__(self, d_model: int, n_heads: int, scales: Optional[List[int]] = None):
        super().__init__()
        if scales is None:
            scales = [1, 3, 5]  # fresh list per instance instead of a shared mutable default
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.scales = list(scales)
        self.head_dim = d_model // n_heads

        # One convolution + Q/K/V projection set per scale.
        self.scale_projections = nn.ModuleList([
            nn.ModuleDict({
                'query': nn.Linear(d_model, d_model),
                'key': nn.Linear(d_model, d_model),
                'value': nn.Linear(d_model, d_model),
                'conv': nn.Conv1d(d_model, d_model, kernel_size=scale,
                                  padding=scale // 2, groups=d_model // 4),
            })
            for scale in self.scales
        ])

        self.output_projection = nn.Linear(d_model * len(self.scales), d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Attend over ``x`` at every configured scale.

        Args:
            x: ``(batch, seq_len, d_model)`` input.
            mask: optional tensor broadcastable to ``(batch, n_heads, seq_len, seq_len)``;
                score positions where ``mask == 0`` are excluded.

        Returns:
            ``(batch, seq_len, d_model)`` tensor.
        """
        batch_size, seq_len, _ = x.size()
        scale_outputs = []

        for scale_proj in self.scale_projections:
            # Temporal convolution over the sequence axis (Conv1d wants channels first).
            x_conv = scale_proj['conv'](x.transpose(1, 2)).transpose(1, 2)
            # padding=k//2 preserves length only for odd k; even k adds one step, so trim.
            x_conv = x_conv[:, :seq_len]

            # (batch, n_heads, seq_len, head_dim)
            Q = scale_proj['query'](x_conv).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
            K = scale_proj['key'](x_conv).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
            V = scale_proj['value'](x_conv).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

            scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
            if mask is not None:
                # dtype-aware fill value: a literal -1e9 overflows float16 under autocast.
                scores.masked_fill_(mask == 0, torch.finfo(scores.dtype).min)

            attention = self.dropout(F.softmax(scores, dim=-1))
            output = torch.matmul(attention, V)
            output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
            scale_outputs.append(output)

        combined = torch.cat(scale_outputs, dim=-1)
        return self.output_projection(combined)


class MarketRegimeDetector(nn.Module):
    """Soft mixture of per-regime linear transforms gated by a regime classifier."""

    def __init__(self, d_model: int, n_regimes: int = 4):
        super().__init__()
        self.d_model = d_model
        self.n_regimes = n_regimes

        # Regime classification layers (applied to the sequence mean)
        self.regime_classifier = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(d_model // 2, n_regimes)
        )

        # Regime-specific transformations
        self.regime_transforms = nn.ModuleList([
            nn.Linear(d_model, d_model) for _ in range(n_regimes)
        ])

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Mix the per-regime transforms of ``x`` by predicted regime probabilities.

        Args:
            x: ``(batch, seq_len, d_model)`` input.

        Returns:
            ``(adapted_output, regime_probs)`` with shapes ``(batch, seq_len, d_model)``
            and ``(batch, n_regimes)``.
        """
        pooled = x.mean(dim=1)  # (batch, d_model)
        regime_probs = F.softmax(self.regime_classifier(pooled), dim=-1)  # (batch, n_regimes)

        # (n_regimes, batch, seq_len, d_model)
        regime_stack = torch.stack([transform(x) for transform in self.regime_transforms], dim=0)
        # (n_regimes, batch, 1, 1) so each sample's weights broadcast over seq_len and d_model
        regime_weights = regime_probs.transpose(0, 1)[:, :, None, None]

        adapted_output = (regime_stack * regime_weights).sum(dim=0)
        return adapted_output, regime_probs


class UncertaintyEstimation(nn.Module):
    """Scalar uncertainty head with Monte Carlo dropout at inference time.

    The head ends in a sigmoid, so every estimate lies in (0, 1).
    """

    def __init__(self, d_model: int, n_samples: int = 10):
        super().__init__()
        self.d_model = d_model
        self.n_samples = n_samples

        self.uncertainty_head = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Dropout(0.5),  # Higher dropout so MC samples are diverse
            nn.Linear(d_model // 2, 1),
            nn.Sigmoid()
        )

    def _sample_with_dropout(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``uncertainty_head`` once with its dropout layers forced active."""
        out = x
        for layer in self.uncertainty_head:
            if isinstance(layer, nn.Dropout):
                out = F.dropout(out, p=layer.p, training=True)
            else:
                out = layer(out)
        return out

    def forward(self, x: torch.Tensor, training: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """Estimate uncertainty for ``x``.

        In training mode (or when ``training=True``) a single pass is made and
        the point estimate is returned in both slots. In eval mode the head is
        sampled ``n_samples`` times with dropout forced on (MC dropout), and the
        sample mean and standard deviation are returned.

        Returns:
            ``(mean_uncertainty, std_uncertainty)``, each shaped like the head output.
        """
        if training or self.training:
            uncertainty = self.uncertainty_head(x)
            return uncertainty, uncertainty

        samples = torch.stack(
            [self._sample_with_dropout(x) for _ in range(self.n_samples)], dim=0
        )
        mean_uncertainty = samples.mean(dim=0)
        if self.n_samples > 1:
            std_uncertainty = samples.std(dim=0)
        else:
            # The unbiased std of a single sample is NaN; report zero spread instead.
            std_uncertainty = torch.zeros_like(mean_uncertainty)
        return mean_uncertainty, std_uncertainty


class TradingTransformerLayer(nn.Module):
    """Post-norm transformer layer with optional multi-scale attention and regime mixing."""

    def __init__(self, config: TradingTransformerConfig):
        super().__init__()
        self.config = config

        # Multi-scale attention or standard attention
        if config.use_multi_scale_attention:
            self.attention = DeepMultiScaleAttention(config.d_model, config.n_heads)
        else:
            self.attention = nn.MultiheadAttention(
                config.d_model, config.n_heads, dropout=config.dropout, batch_first=True
            )

        # Feed-forward network
        self.feed_forward = nn.Sequential(
            nn.Linear(config.d_model, config.d_ff),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_ff, config.d_model)
        )

        # Layer normalization
        self.norm1 = nn.LayerNorm(config.d_model)
        self.norm2 = nn.LayerNorm(config.d_model)

        # Dropout
        self.dropout = nn.Dropout(config.dropout)

        # Market regime detection
        if config.use_market_regime_detection:
            self.regime_detector = MarketRegimeDetector(config.d_model)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Apply attention, optional regime mixing, and feed-forward to ``x``.

        NOTE(review): ``DeepMultiScaleAttention`` treats ``mask == 0`` as blocked,
        while ``nn.MultiheadAttention`` treats boolean ``True`` as blocked (and
        adds float masks), so the same ``mask`` means different things on the
        two attention paths.

        Returns:
            dict with ``'output'`` (same shape as ``x``) and ``'regime_probs'``
            (``None`` when regime detection is disabled).
        """
        # Self-attention with residual connection
        if isinstance(self.attention, DeepMultiScaleAttention):
            attn_output = self.attention(x, mask)
        else:
            attn_output, _ = self.attention(x, x, x, attn_mask=mask)
        x = self.norm1(x + self.dropout(attn_output))

        # Market regime adaptation (replaces x, no residual)
        regime_probs = None
        if hasattr(self, 'regime_detector'):
            x, regime_probs = self.regime_detector(x)

        # Feed-forward with residual connection
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))

        return {
            'output': x,
            'regime_probs': regime_probs
        }
class AdvancedTradingTransformer(nn.Module): """Advanced transformer model for high-frequency trading""" def __init__(self, config: TradingTransformerConfig): super().__init__() self.config = config # Timeframe configuration self.timeframes = ['1s', '1m', '1h', '1d'] self.num_timeframes = len(self.timeframes) + 1 # +1 for BTC # SERIAL: Shared pattern encoder (learns candle patterns ONCE for all timeframes) # This is applied to each timeframe independently but uses SAME weights # LIGHTWEIGHT: 2-layer encoder for efficiency self.shared_pattern_encoder = nn.Sequential( nn.Linear(5, config.d_model // 2), # 5 OHLCV -> 128 nn.LayerNorm(config.d_model // 2), nn.GELU(), nn.Dropout(config.dropout), nn.Linear(config.d_model // 2, config.d_model) # 128 -> 256 ) # Timeframe-specific embeddings (learnable, added to shared encoding) # These help the model distinguish which timeframe it's looking at self.timeframe_embeddings = nn.Embedding(self.num_timeframes, config.d_model) # PARALLEL: Cross-timeframe attention layer (single layer for efficiency) # Processes all timeframes simultaneously to capture dependencies self.cross_timeframe_layer = nn.TransformerEncoderLayer( d_model=config.d_model, nhead=config.n_heads, dim_feedforward=config.d_ff, dropout=config.dropout, activation='gelu', batch_first=True ) # Other input projections self.cob_projection = nn.Linear(config.cob_features, config.d_model) self.tech_projection = nn.Linear(config.tech_features, config.d_model) self.market_projection = nn.Linear(config.market_features, config.d_model) # Position state projection self.position_projection = nn.Sequential( nn.Linear(5, config.d_model // 4), nn.GELU(), nn.Dropout(config.dropout), nn.Linear(config.d_model // 4, config.d_model // 2), nn.GELU(), nn.Dropout(config.dropout), nn.Linear(config.d_model // 2, config.d_model) ) # Positional encoding if config.use_relative_position: self.pos_encoding = RelativePositionalEncoding(config.d_model) else: self.pos_encoding = 
PositionalEncoding(config.d_model, config.seq_len) # Transformer layers self.layers = nn.ModuleList([ TradingTransformerLayer(config) for _ in range(config.n_layers) ]) # Lightweight output heads for 8-12M parameter model self.action_head = nn.Sequential( nn.Linear(config.d_model, config.d_model // 2), nn.GELU(), nn.Dropout(config.dropout), nn.Linear(config.d_model // 2, config.n_actions) ) if config.confidence_output: self.confidence_head = nn.Sequential( nn.Linear(config.d_model, config.d_model // 2), nn.GELU(), nn.Dropout(config.dropout), nn.Linear(config.d_model // 2, 1), nn.Sigmoid() ) # Enhanced uncertainty estimation if config.use_uncertainty_estimation: self.uncertainty_estimator = UncertaintyEstimation(config.d_model) # Lightweight price prediction head self.price_head = nn.Sequential( nn.Linear(config.d_model, config.d_model // 2), nn.GELU(), nn.Dropout(config.dropout), nn.Linear(config.d_model // 2, 1), nn.Tanh() # Constrain to [-1, 1] range for price change ratio ) # Lightweight volatility and trend heads self.volatility_head = nn.Sequential( nn.Linear(config.d_model, config.d_model // 4), nn.GELU(), nn.Linear(config.d_model // 4, 1), nn.Softplus() ) self.trend_strength_head = nn.Sequential( nn.Linear(config.d_model, config.d_model // 4), nn.GELU(), nn.Linear(config.d_model // 4, 1), nn.Tanh() ) # Lightweight next candle OHLCV prediction heads self.next_candle_heads = nn.ModuleDict({ tf: nn.Sequential( nn.Linear(config.d_model, config.d_model // 2), nn.GELU(), nn.Linear(config.d_model // 2, 5), # OHLCV nn.Sigmoid() # Constrain to [0, 1] ) for tf in self.timeframes }) # BTC next candle prediction head self.btc_next_candle_head = nn.Sequential( nn.Linear(config.d_model, config.d_model // 2), nn.GELU(), nn.Linear(config.d_model // 2, 5), # OHLCV for BTC nn.Sigmoid() ) # Lightweight pivot point prediction heads (L1-L3 only for efficiency) self.pivot_levels = [1, 2, 3] # Reduced from L1-L5 to L1-L3 self.pivot_heads = nn.ModuleDict({ f'L{level}': 
nn.Sequential( nn.Linear(config.d_model, config.d_model // 2), nn.GELU(), nn.Linear(config.d_model // 2, 4) # [price, type_prob_high, type_prob_low, confidence] ) for level in self.pivot_levels }) # Lightweight trend vector analysis head self.trend_analysis_head = nn.Sequential( nn.Linear(config.d_model, config.d_model // 2), nn.GELU(), nn.Linear(config.d_model // 2, 3) # [angle_radians, steepness, direction] ) # Initialize weights self._init_weights() def _init_weights(self): """Initialize model weights""" for module in self.modules(): if isinstance(module, nn.Linear): nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): nn.init.ones_(module.weight) nn.init.zeros_(module.bias) def forward(self, # Multi-timeframe inputs price_data_1s: Optional[torch.Tensor] = None, price_data_1m: Optional[torch.Tensor] = None, price_data_1h: Optional[torch.Tensor] = None, price_data_1d: Optional[torch.Tensor] = None, btc_data_1m: Optional[torch.Tensor] = None, # Other inputs cob_data: Optional[torch.Tensor] = None, tech_data: Optional[torch.Tensor] = None, market_data: Optional[torch.Tensor] = None, position_state: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None, # Legacy support price_data: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]: """ Forward pass with hybrid serial-parallel multi-timeframe processing SERIAL: Shared pattern encoder learns candle patterns once (same weights for all timeframes) PARALLEL: Cross-timeframe attention captures dependencies between timeframes Args: price_data_1s: (batch, seq_len, 5) - 1-second OHLCV (optional) price_data_1m: (batch, seq_len, 5) - 1-minute OHLCV (optional) price_data_1h: (batch, seq_len, 5) - 1-hour OHLCV (optional) price_data_1d: (batch, seq_len, 5) - 1-day OHLCV (optional) btc_data_1m: (batch, seq_len, 5) - BTC 1-minute OHLCV (optional) cob_data: (batch, seq_len, cob_features) - COB features tech_data: (batch, 
tech_features) - Technical indicators market_data: (batch, market_features) - Market features position_state: (batch, 5) - Position state mask: Optional attention mask price_data: (batch, seq_len, 5) - Legacy single timeframe (defaults to 1m) Returns: Dictionary with predictions for ALL timeframes """ # Legacy support if price_data is not None and price_data_1m is None: price_data_1m = price_data # Collect available timeframes timeframe_data = { '1s': price_data_1s, '1m': price_data_1m, '1h': price_data_1h, '1d': price_data_1d, 'btc': btc_data_1m } # Filter to available timeframes available_tfs = [(tf, data) for tf, data in timeframe_data.items() if data is not None] if not available_tfs: raise ValueError("At least one timeframe must be provided") # Get dimensions from first available timeframe first_data = available_tfs[0][1] batch_size, seq_len = first_data.shape[:2] device = first_data.device # ============================================================ # STEP 1: SERIAL - Apply shared pattern encoder to each timeframe # This learns candle patterns ONCE (same weights for all) # ============================================================ timeframe_encodings = [] timeframe_indices = [] for idx, (tf_name, tf_data) in enumerate(available_tfs): # Ensure correct sequence length if tf_data.shape[1] != seq_len: if tf_data.shape[1] < seq_len: # Pad with last candle padding = tf_data[:, -1:, :].expand(batch_size, seq_len - tf_data.shape[1], 5) tf_data = torch.cat([tf_data, padding], dim=1) else: # Truncate to seq_len tf_data = tf_data[:, :seq_len, :] # Apply SHARED pattern encoder (learns patterns once for all timeframes) # Shape: [batch, seq_len, 5] -> [batch, seq_len, d_model] tf_encoded = self.shared_pattern_encoder(tf_data) # Add timeframe-specific embedding (helps model know which timeframe) # Get timeframe index tf_idx = self.timeframes.index(tf_name) if tf_name in self.timeframes else len(self.timeframes) tf_embedding = 
self.timeframe_embeddings(torch.tensor([tf_idx], device=device)) tf_embedding = tf_embedding.unsqueeze(1).expand(batch_size, seq_len, -1) # Combine: shared pattern + timeframe identity tf_encoded = tf_encoded + tf_embedding timeframe_encodings.append(tf_encoded) timeframe_indices.append(tf_idx) # ============================================================ # STEP 2: PARALLEL - Cross-timeframe attention # Process all timeframes together to capture dependencies # ============================================================ # Stack timeframes: [batch, num_timeframes, seq_len, d_model] # Then reshape to: [batch, num_timeframes * seq_len, d_model] stacked_tfs = torch.stack(timeframe_encodings, dim=1) # [batch, num_tfs, seq_len, d_model] num_tfs = len(timeframe_encodings) # MEMORY EFFICIENT: Process timeframes with shared weights # Reshape to process all timeframes in parallel: [batch*num_tfs, seq_len, d_model] # This avoids creating huge concatenated sequences while still processing efficiently batched_tfs = stacked_tfs.reshape(batch_size * num_tfs, seq_len, self.config.d_model) # Apply single cross-timeframe attention layer batched_tfs = self.cross_timeframe_layer(batched_tfs) # Reshape back: [batch*num_tfs, seq_len, d_model] -> [batch, num_tfs, seq_len, d_model] cross_tf_output = batched_tfs.reshape(batch_size, num_tfs, seq_len, self.config.d_model) # Average across timeframes to get unified representation # [batch, num_tfs, seq_len, d_model] -> [batch, seq_len, d_model] price_emb = cross_tf_output.mean(dim=1) # ============================================================ # STEP 3: Add other features (COB, tech, market, position) # ============================================================ # COB features if cob_data is not None: if cob_data.dim() == 2: cob_data = cob_data.unsqueeze(1).expand(batch_size, seq_len, -1) cob_emb = self.cob_projection(cob_data) else: cob_emb = torch.zeros(batch_size, seq_len, self.config.d_model, device=device) # Technical indicators if 
tech_data is not None: if tech_data.dim() == 2: tech_data = tech_data.unsqueeze(1).expand(batch_size, seq_len, -1) tech_emb = self.tech_projection(tech_data) else: tech_emb = torch.zeros(batch_size, seq_len, self.config.d_model, device=device) # Market features if market_data is not None: if market_data.dim() == 2: market_data = market_data.unsqueeze(1).expand(batch_size, seq_len, -1) market_emb = self.market_projection(market_data) else: market_emb = torch.zeros(batch_size, seq_len, self.config.d_model, device=device) # Combine all embeddings x = price_emb + cob_emb + tech_emb + market_emb # Add position state if provided - critical for loss minimization and profit taking if position_state is not None: # Project position state through learned embedding network # Input: [batch, 5] -> Output: [batch, d_model] position_emb = self.position_projection(position_state) # [batch, d_model] # Expand to sequence length and add as bias to all positions # This conditions the entire sequence on current position state position_emb = position_emb.unsqueeze(1).expand(batch_size, seq_len, -1) # [batch, seq_len, d_model] # Add position embedding to the combined embeddings # This allows the model to modulate its predictions based on position state x = x + position_emb # Add positional encoding if isinstance(self.pos_encoding, RelativePositionalEncoding): # Relative position encoding is applied in attention pass else: x = self.pos_encoding(x.transpose(0, 1)).transpose(0, 1) # Apply transformer layers with optional gradient checkpointing regime_probs_history = [] for layer in self.layers: if self.training and self.config.use_gradient_checkpointing: # Use gradient checkpointing to save memory during training # Trades compute for memory (recomputes activations during backward pass) layer_output = checkpoint( layer, x, mask, use_reentrant=False ) else: layer_output = layer(x, mask) x = layer_output['output'] if layer_output['regime_probs'] is not None: 
regime_probs_history.append(layer_output['regime_probs']) # Global pooling for final prediction # Use attention-based pooling pooling_weights = F.softmax( torch.sum(x, dim=-1, keepdim=True), dim=1 ) pooled = torch.sum(x * pooling_weights, dim=1) # Generate outputs outputs = {} # Action prediction action_logits = self.action_head(pooled) outputs['action_logits'] = action_logits outputs['action_probs'] = F.softmax(action_logits, dim=-1) # Confidence prediction if self.config.confidence_output: confidence = self.confidence_head(pooled) outputs['confidence'] = confidence # Uncertainty estimation if self.config.use_uncertainty_estimation: uncertainty_mean, uncertainty_std = self.uncertainty_estimator(pooled) outputs['uncertainty_mean'] = uncertainty_mean outputs['uncertainty_std'] = uncertainty_std # Enhanced price prediction (auxiliary task) price_pred = self.price_head(pooled) outputs['price_prediction'] = price_pred # Additional specialized predictions for 46M model volatility_pred = self.volatility_head(pooled) outputs['volatility_prediction'] = volatility_pred trend_strength_pred = self.trend_strength_head(pooled) outputs['trend_strength_prediction'] = trend_strength_pred # NEW: Next candle OHLCV predictions for each timeframe next_candles = {} for tf in self.timeframes: candle_pred = self.next_candle_heads[tf](pooled) # (batch, 5) next_candles[tf] = candle_pred outputs['next_candles'] = next_candles # BTC next candle prediction btc_next_candle = self.btc_next_candle_head(pooled) # (batch, 5) outputs['btc_next_candle'] = btc_next_candle # NEW: Next pivot point predictions for L1-L5 next_pivots = {} for level in self.pivot_levels: pivot_pred = self.pivot_heads[f'L{level}'](pooled) # (batch, 4) # Extract components: [price, type_logit_high, type_logit_low, confidence] # Use softmax to ensure type probabilities sum to 1 type_logits = pivot_pred[:, 1:3] # (batch, 2) - [high, low] type_probs = F.softmax(type_logits, dim=-1) # (batch, 2) next_pivots[f'L{level}'] = { 
'price': pivot_pred[:, 0:1], # Keep as (batch, 1) 'type_prob_high': type_probs[:, 0:1], # Probability of high pivot 'type_prob_low': type_probs[:, 1:2], # Probability of low pivot 'pivot_type': torch.argmax(type_probs, dim=-1, keepdim=True), # 0=high, 1=low 'confidence': torch.sigmoid(pivot_pred[:, 3:4]) # Prediction confidence } outputs['next_pivots'] = next_pivots # NEW: Trend vector analysis from pivot predictions trend_analysis = self.trend_analysis_head(pooled) # (batch, 3) outputs['trend_analysis'] = { 'angle_radians': trend_analysis[:, 0:1], # Trend angle in radians 'steepness': F.softplus(trend_analysis[:, 1:2]), # Always positive steepness 'direction': torch.tanh(trend_analysis[:, 2:3]) # -1 to 1 (down to up) } # NEW: Calculate trend vector from pivot predictions # Extract pivot prices and create trend vector pivot_prices = torch.stack([next_pivots[f'L{level}']['price'] for level in self.pivot_levels], dim=1) # (batch, 5, 1) pivot_prices = pivot_prices.squeeze(-1) # (batch, 5) # Calculate trend vector: (price_change, time_change) # Assume equal time spacing between pivot levels time_points = torch.arange(1, len(self.pivot_levels) + 1, dtype=torch.float32, device=pooled.device).unsqueeze(0) # (1, 5) # Calculate trend line slope using linear regression on pivot prices # Trend vector = (delta_price, delta_time) normalized if batch_size > 0: # For each sample, calculate trend from L1 to L5 price_deltas = pivot_prices[:, -1:] - pivot_prices[:, :1] # L5 - L1 price change time_deltas = time_points[:, -1:] - time_points[:, :1] # Time change (should be 4) # Calculate angle and steepness trend_angles = torch.atan2(price_deltas.squeeze(), time_deltas.squeeze()) # (batch,) trend_steepness = torch.sqrt(price_deltas.squeeze() ** 2 + time_deltas.squeeze() ** 2) # (batch,) trend_direction = torch.sign(price_deltas.squeeze()) # (batch,) outputs['trend_vector'] = { 'pivot_prices': pivot_prices, # (batch, 5) - prices for L1-L5 'price_delta': price_deltas.squeeze(), # 
(batch,) - price change from L1 to L5 'time_delta': time_deltas.squeeze(), # (batch,) - time change 'calculated_angle': trend_angles.unsqueeze(-1), # (batch, 1) 'calculated_steepness': trend_steepness.unsqueeze(-1), # (batch, 1) 'calculated_direction': trend_direction.unsqueeze(-1), # (batch, 1) 'vector': torch.stack([price_deltas.squeeze(), time_deltas.squeeze()], dim=0).unsqueeze(0) if batch_size == 1 else torch.stack([price_deltas.squeeze(), time_deltas.squeeze()], dim=1) # (batch, 2) - [price_delta, time_delta] } else: outputs['trend_vector'] = { 'pivot_prices': pivot_prices, 'price_delta': torch.zeros(batch_size, device=pooled.device), 'time_delta': torch.zeros(batch_size, device=pooled.device), 'calculated_angle': torch.zeros(batch_size, 1, device=pooled.device), 'calculated_steepness': torch.zeros(batch_size, 1, device=pooled.device), 'calculated_direction': torch.zeros(batch_size, 1, device=pooled.device), 'vector': torch.zeros(batch_size, 2, device=pooled.device) } # NEW: Trade action based on trend steepness and angle # Combine predicted trend analysis with calculated trend vector predicted_angle = outputs['trend_analysis']['angle_radians'].squeeze() # (batch,) predicted_steepness = outputs['trend_analysis']['steepness'].squeeze() # (batch,) predicted_direction = outputs['trend_analysis']['direction'].squeeze() # (batch,) # Use calculated trend if available, otherwise use predicted if 'calculated_angle' in outputs['trend_vector']: trend_angle = outputs['trend_vector']['calculated_angle'].squeeze() # (batch,) trend_steepness_val = outputs['trend_vector']['calculated_steepness'].squeeze() # (batch,) else: trend_angle = predicted_angle trend_steepness_val = predicted_steepness # Trade action logic based on trend steepness and angle # Steep upward trend (> 45 degrees) -> BUY # Steep downward trend (< -45 degrees) -> SELL # Shallow trend -> HOLD angle_threshold = math.pi / 4 # 45 degrees # Determine action from trend angle trend_action_logits = 
torch.zeros(batch_size, 3, device=pooled.device) # [BUY, SELL, HOLD] # Calculate action probabilities based on trend for i in range(batch_size): # Handle both 0-dim and 1-dim tensors if trend_angle.dim() == 0: angle = trend_angle.item() steep = trend_steepness_val.item() else: angle = trend_angle[i].item() steep = trend_steepness_val[i].item() # Normalize steepness to [0, 1] range (assuming max steepness of 10 units) normalized_steepness = min(steep / 10.0, 1.0) if steep > 0 else 0.0 if angle > angle_threshold: # Steep upward trend trend_action_logits[i, 0] = normalized_steepness * 2.0 # BUY trend_action_logits[i, 2] = (1.0 - normalized_steepness) * 0.5 # HOLD elif angle < -angle_threshold: # Steep downward trend trend_action_logits[i, 1] = normalized_steepness * 2.0 # SELL trend_action_logits[i, 2] = (1.0 - normalized_steepness) * 0.5 # HOLD else: # Shallow trend trend_action_logits[i, 2] = 1.0 # HOLD # Combine trend-based action with main action prediction trend_action_probs = F.softmax(trend_action_logits, dim=-1) outputs['trend_based_action'] = { 'logits': trend_action_logits, 'probabilities': trend_action_probs, 'action_idx': torch.argmax(trend_action_probs, dim=-1), 'trend_angle_degrees': trend_angle * 180.0 / math.pi, # Convert to degrees 'trend_steepness': trend_steepness_val } # Market regime information if regime_probs_history: outputs['regime_probs'] = torch.stack(regime_probs_history, dim=1) return outputs def extract_predictions(self, outputs: Dict[str, torch.Tensor], denormalize_prices: Optional[Callable] = None) -> Dict[str, Any]: """ Extract predictions from model outputs in a user-friendly format Args: outputs: Raw model outputs from forward() method denormalize_prices: Optional function to denormalize predicted prices Returns: Dictionary with formatted predictions including: - next_candles: Dict[str, Dict] - OHLCV predictions for each timeframe - next_pivots: Dict[str, Dict] - Pivot predictions for L1-L5 - trend_vector: Dict - Trend vector 
analysis - trend_based_action: Dict - Trading action based on trend """ self.eval() device = next(self.parameters()).device predictions = {} # Extract next candle predictions for each timeframe if 'next_candles' in outputs: next_candles = {} for tf in self.timeframes: candle_tensor = outputs['next_candles'][tf] if candle_tensor.dim() > 1: candle_tensor = candle_tensor[0] # Take first batch item candle_values = candle_tensor.cpu().detach().numpy() if hasattr(candle_tensor, 'cpu') else candle_tensor if isinstance(candle_values, np.ndarray): candle_values = candle_values.tolist() next_candles[tf] = { 'open': float(candle_values[0]) if len(candle_values) > 0 else 0.0, 'high': float(candle_values[1]) if len(candle_values) > 1 else 0.0, 'low': float(candle_values[2]) if len(candle_values) > 2 else 0.0, 'close': float(candle_values[3]) if len(candle_values) > 3 else 0.0, 'volume': float(candle_values[4]) if len(candle_values) > 4 else 0.0 } # Denormalize if function provided if denormalize_prices and callable(denormalize_prices): for key in ['open', 'high', 'low', 'close']: next_candles[tf][key] = denormalize_prices(next_candles[tf][key]) predictions['next_candles'] = next_candles # Extract pivot point predictions if 'next_pivots' in outputs: next_pivots = {} for level in self.pivot_levels: pivot_data = outputs['next_pivots'][f'L{level}'] # Extract values price = pivot_data['price'] if price.dim() > 1: price = price[0, 0] if price.shape[0] > 0 else torch.tensor(0.0, device=device) price_val = float(price.cpu().detach().item() if hasattr(price, 'cpu') else price) type_prob_high = pivot_data['type_prob_high'] if type_prob_high.dim() > 1: type_prob_high = type_prob_high[0, 0] if type_prob_high.shape[0] > 0 else torch.tensor(0.0, device=device) prob_high = float(type_prob_high.cpu().detach().item() if hasattr(type_prob_high, 'cpu') else type_prob_high) type_prob_low = pivot_data['type_prob_low'] if type_prob_low.dim() > 1: type_prob_low = type_prob_low[0, 0] if 
type_prob_low.shape[0] > 0 else torch.tensor(0.0, device=device) prob_low = float(type_prob_low.cpu().detach().item() if hasattr(type_prob_low, 'cpu') else type_prob_low) confidence = pivot_data['confidence'] if confidence.dim() > 1: confidence = confidence[0, 0] if confidence.shape[0] > 0 else torch.tensor(0.0, device=device) conf_val = float(confidence.cpu().detach().item() if hasattr(confidence, 'cpu') else confidence) pivot_type = pivot_data.get('pivot_type', torch.tensor(0)) if isinstance(pivot_type, torch.Tensor): if pivot_type.dim() > 1: pivot_type = pivot_type[0, 0] if pivot_type.shape[0] > 0 else torch.tensor(0, device=device) pivot_type_val = int(pivot_type.cpu().detach().item() if hasattr(pivot_type, 'cpu') else pivot_type) else: pivot_type_val = int(pivot_type) # Denormalize price if function provided if denormalize_prices and callable(denormalize_prices): price_val = denormalize_prices(price_val) next_pivots[f'L{level}'] = { 'price': price_val, 'type': 'high' if pivot_type_val == 0 else 'low', 'type_prob_high': prob_high, 'type_prob_low': prob_low, 'confidence': conf_val } predictions['next_pivots'] = next_pivots # Extract trend vector if 'trend_vector' in outputs: trend_vec = outputs['trend_vector'] # Extract pivot prices pivot_prices = trend_vec.get('pivot_prices', torch.zeros(5, device=device)) if isinstance(pivot_prices, torch.Tensor): if pivot_prices.dim() > 1: pivot_prices = pivot_prices[0] pivot_prices_list = pivot_prices.cpu().detach().numpy().tolist() if hasattr(pivot_prices, 'cpu') else pivot_prices.tolist() else: pivot_prices_list = pivot_prices # Denormalize pivot prices if function provided if denormalize_prices and callable(denormalize_prices): pivot_prices_list = [denormalize_prices(p) for p in pivot_prices_list] angle = trend_vec.get('calculated_angle', torch.tensor(0.0, device=device)) if isinstance(angle, torch.Tensor): if angle.dim() > 1: angle = angle[0, 0] if angle.shape[0] > 0 else torch.tensor(0.0, device=device) angle_val = 
float(angle.cpu().detach().item() if hasattr(angle, 'cpu') else angle) else: angle_val = float(angle) steepness = trend_vec.get('calculated_steepness', torch.tensor(0.0, device=device)) if isinstance(steepness, torch.Tensor): if steepness.dim() > 1: steepness = steepness[0, 0] if steepness.shape[0] > 0 else torch.tensor(0.0, device=device) steepness_val = float(steepness.cpu().detach().item() if hasattr(steepness, 'cpu') else steepness) else: steepness_val = float(steepness) direction = trend_vec.get('calculated_direction', torch.tensor(0.0, device=device)) if isinstance(direction, torch.Tensor): if direction.dim() > 1: direction = direction[0, 0] if direction.shape[0] > 0 else torch.tensor(0.0, device=device) direction_val = float(direction.cpu().detach().item() if hasattr(direction, 'cpu') else direction) else: direction_val = float(direction) price_delta = trend_vec.get('price_delta', torch.tensor(0.0, device=device)) if isinstance(price_delta, torch.Tensor): if price_delta.dim() > 0: price_delta = price_delta[0] if price_delta.shape[0] > 0 else torch.tensor(0.0, device=device) price_delta_val = float(price_delta.cpu().detach().item() if hasattr(price_delta, 'cpu') else price_delta) else: price_delta_val = float(price_delta) predictions['trend_vector'] = { 'pivot_prices': pivot_prices_list, # [L1, L2, L3, L4, L5] 'angle_radians': angle_val, 'angle_degrees': angle_val * 180.0 / math.pi, 'steepness': steepness_val, 'direction': 'up' if direction_val > 0 else 'down' if direction_val < 0 else 'sideways', 'price_delta': price_delta_val } # Extract trend-based action if 'trend_based_action' in outputs: trend_action = outputs['trend_based_action'] action_probs = trend_action.get('probabilities', torch.zeros(3, device=device)) if isinstance(action_probs, torch.Tensor): if action_probs.dim() > 1: action_probs = action_probs[0] action_probs_list = action_probs.cpu().detach().numpy().tolist() if hasattr(action_probs, 'cpu') else action_probs.tolist() else: 
action_probs_list = action_probs action_idx = trend_action.get('action_idx', torch.tensor(2, device=device)) if isinstance(action_idx, torch.Tensor): if action_idx.dim() > 0: action_idx = action_idx[0] if action_idx.shape[0] > 0 else torch.tensor(2, device=device) action_idx_val = int(action_idx.cpu().detach().item() if hasattr(action_idx, 'cpu') else action_idx) else: action_idx_val = int(action_idx) angle_degrees = trend_action.get('trend_angle_degrees', torch.tensor(0.0, device=device)) if isinstance(angle_degrees, torch.Tensor): if angle_degrees.dim() > 0: angle_degrees = angle_degrees[0] if angle_degrees.shape[0] > 0 else torch.tensor(0.0, device=device) angle_degrees_val = float(angle_degrees.cpu().detach().item() if hasattr(angle_degrees, 'cpu') else angle_degrees) else: angle_degrees_val = float(angle_degrees) steepness = trend_action.get('trend_steepness', torch.tensor(0.0, device=device)) if isinstance(steepness, torch.Tensor): if steepness.dim() > 0: steepness = steepness[0] if steepness.shape[0] > 0 else torch.tensor(0.0, device=device) steepness_val = float(steepness.cpu().detach().item() if hasattr(steepness, 'cpu') else steepness) else: steepness_val = float(steepness) action_names = ['BUY', 'SELL', 'HOLD'] predictions['trend_based_action'] = { 'action': action_names[action_idx_val] if 0 <= action_idx_val < len(action_names) else 'HOLD', 'action_idx': action_idx_val, 'probabilities': { 'BUY': float(action_probs_list[0]) if len(action_probs_list) > 0 else 0.0, 'SELL': float(action_probs_list[1]) if len(action_probs_list) > 1 else 0.0, 'HOLD': float(action_probs_list[2]) if len(action_probs_list) > 2 else 0.0 }, 'trend_angle_degrees': angle_degrees_val, 'trend_steepness': steepness_val } return predictions class TradingTransformerTrainer: """Trainer for the advanced trading transformer""" def __init__(self, model: AdvancedTradingTransformer, config: TradingTransformerConfig): self.model = model self.config = config self.device = torch.device('cuda' if 
torch.cuda.is_available() else 'cpu') # Move model to device self.model.to(self.device) logger.info(f"✅ Model moved to device: {self.device}") # Log GPU info if available if torch.cuda.is_available(): logger.info(f" GPU: {torch.cuda.get_device_name(0)}") logger.info(f" GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB") # MEMORY OPTIMIZATION: Enable gradient checkpointing if configured # This trades 20% compute for 30-40% memory savings if config.use_gradient_checkpointing: logger.info("Enabling gradient checkpointing for memory efficiency") self._enable_gradient_checkpointing() # Mixed precision training disabled - causes dtype mismatches # Can be re-enabled if needed, but requires careful dtype management self.use_amp = False self.scaler = None # GRADIENT ACCUMULATION: Track accumulation state self.gradient_accumulation_steps = 0 self.current_accumulation_step = 0 # Optimizer with warmup self.optimizer = optim.AdamW( model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay ) # Learning rate scheduler self.scheduler = optim.lr_scheduler.OneCycleLR( self.optimizer, max_lr=config.learning_rate, total_steps=10000, # Will be updated based on training data pct_start=0.1 ) # Loss functions self.action_criterion = nn.CrossEntropyLoss() self.price_criterion = nn.MSELoss() self.confidence_criterion = nn.BCELoss() # Training history self.training_history = { 'train_loss': [], 'val_loss': [], 'train_accuracy': [], 'val_accuracy': [], 'learning_rates': [] } def _enable_gradient_checkpointing(self): """Enable gradient checkpointing on transformer layers to save memory""" try: # Apply checkpointing to each transformer layer for layer in self.model.layers: if hasattr(layer, 'attention'): # Wrap attention in checkpoint original_forward = layer.attention.forward def checkpointed_attention_forward(*args, **kwargs): return checkpoint( original_forward, *args, **kwargs, use_reentrant=False ) layer.attention.forward = 
checkpointed_attention_forward if hasattr(layer, 'feed_forward'): # Wrap feed-forward in checkpoint original_ff_forward = layer.feed_forward.forward def checkpointed_ff_forward(*args, **kwargs): return checkpoint( original_ff_forward, *args, **kwargs, use_reentrant=False ) layer.feed_forward.forward = checkpointed_ff_forward logger.info("Gradient checkpointing enabled on all transformer layers") except Exception as e: logger.warning(f"Failed to enable gradient checkpointing: {e}") def set_gradient_accumulation_steps(self, steps: int): """ Set the number of gradient accumulation steps Args: steps: Number of batches to accumulate gradients over before optimizer step For example, steps=5 means process 5 batches, then update weights """ self.gradient_accumulation_steps = steps self.current_accumulation_step = 0 logger.info(f"Gradient accumulation enabled: {steps} steps") def reset_gradient_accumulation(self): """Reset gradient accumulation counter""" self.current_accumulation_step = 0 @staticmethod def denormalize_prices(normalized_values: torch.Tensor, norm_params: Dict[str, float]) -> torch.Tensor: """ Denormalize price predictions back to real price space Args: normalized_values: Tensor of normalized values in [0, 1] range norm_params: Dict with 'price_min' and 'price_max' keys Returns: Denormalized tensor in original price space """ price_min = norm_params.get('price_min', 0.0) price_max = norm_params.get('price_max', 1.0) if price_max > price_min: return normalized_values * (price_max - price_min) + price_min else: return normalized_values @staticmethod def denormalize_candle(normalized_candle: torch.Tensor, norm_params: Dict[str, float]) -> torch.Tensor: """ Denormalize a full OHLCV candle back to real values Args: normalized_candle: Tensor of shape [..., 5] with normalized OHLCV norm_params: Dict with normalization parameters Returns: Denormalized OHLCV tensor """ denorm = normalized_candle.clone() # Denormalize OHLC (first 4 values) price_min = 
norm_params.get('price_min', 0.0) price_max = norm_params.get('price_max', 1.0) if price_max > price_min: denorm[..., :4] = denorm[..., :4] * (price_max - price_min) + price_min # Denormalize volume (5th value) volume_min = norm_params.get('volume_min', 0.0) volume_max = norm_params.get('volume_max', 1.0) if volume_max > volume_min: denorm[..., 4] = denorm[..., 4] * (volume_max - volume_min) + volume_min return denorm def train_step(self, batch: Dict[str, torch.Tensor], accumulate_gradients: bool = False) -> Dict[str, float]: """Single training step with optional gradient accumulation Args: batch: Training batch accumulate_gradients: If True, don't zero gradients or step optimizer (for gradient accumulation) This is DEPRECATED - use gradient_accumulation_steps instead Returns: Dictionary with loss and accuracy metrics """ try: self.model.train() # GRADIENT ACCUMULATION: Determine if this is an accumulation step # If gradient_accumulation_steps is set, use automatic accumulation # Otherwise, fall back to manual accumulate_gradients flag if self.gradient_accumulation_steps > 0: is_accumulation_step = (self.current_accumulation_step < self.gradient_accumulation_steps - 1) self.current_accumulation_step += 1 # Reset counter after full accumulation cycle if self.current_accumulation_step >= self.gradient_accumulation_steps: self.current_accumulation_step = 0 else: is_accumulation_step = accumulate_gradients # Only zero gradients at the start of accumulation cycle # Use set_to_none=True for better memory efficiency (saves ~5% memory) if not is_accumulation_step or self.current_accumulation_step == 1: self.optimizer.zero_grad(set_to_none=True) # OPTIMIZATION: Only move batch to device if not already there # Check if first tensor is already on correct device needs_transfer = False for v in batch.values(): if isinstance(v, torch.Tensor): needs_transfer = (v.device != self.device) break if needs_transfer: # Move batch to device - iterate over copy of keys to avoid 
modification during iteration batch_gpu = {} for k in list(batch.keys()): # Create list copy to avoid modification during iteration v = batch[k] if isinstance(v, torch.Tensor): # Move to device (creates GPU copy) batch_gpu[k] = v.to(self.device, non_blocking=True) else: batch_gpu[k] = v # Replace batch with GPU version batch = batch_gpu del batch_gpu # else: batch is already on GPU, use it directly! # Use automatic mixed precision (FP16) for memory efficiency # Support both CUDA and ROCm (AMD) devices device_type = 'cuda' if self.device.type == 'cuda' else 'cpu' with torch.amp.autocast(device_type, enabled=self.use_amp and device_type != 'cpu'): # Forward pass with multi-timeframe data outputs = self.model( price_data_1s=batch.get('price_data_1s'), price_data_1m=batch.get('price_data_1m'), price_data_1h=batch.get('price_data_1h'), price_data_1d=batch.get('price_data_1d'), btc_data_1m=batch.get('btc_data_1m'), cob_data=batch.get('cob_data'), # Use .get() to handle missing key tech_data=batch.get('tech_data'), market_data=batch.get('market_data'), position_state=batch.get('position_state'), price_data=batch.get('price_data') # Legacy fallback ) # Calculate losses action_loss = self.action_criterion(outputs['action_logits'], batch['actions']) # FIXED: Ensure shapes match for MSELoss price_pred = outputs['price_prediction'] price_target = batch['future_prices'] # Both should be [batch, 1], but ensure they match if price_pred.shape != price_target.shape: logger.debug(f"Reshaping price target from {price_target.shape} to {price_pred.shape}") price_target = price_target.view(price_pred.shape) price_loss = self.price_criterion(price_pred, price_target) # NEW: Trend analysis loss (if trend_target provided) trend_loss = torch.tensor(0.0, device=self.device) if 'trend_target' in batch and 'trend_analysis' in outputs: trend_pred = torch.cat([ outputs['trend_analysis']['angle_radians'], outputs['trend_analysis']['steepness'], outputs['trend_analysis']['direction'] ], dim=1) # 
[batch, 3] trend_target = batch['trend_target'] if trend_pred.shape == trend_target.shape: trend_loss = self.price_criterion(trend_pred, trend_target) logger.debug(f"Trend loss: {trend_loss.item():.6f} (pred={trend_pred[0].tolist()}, target={trend_target[0].tolist()})") # NEW: Next candle prediction loss for each timeframe # This trains the model to predict full OHLCV for the next candle on each timeframe candle_loss = torch.tensor(0.0, device=self.device) candle_losses_detail = {} # Track per-timeframe losses (normalized space) candle_losses_denorm = {} # Track per-timeframe losses (denormalized/real space) if 'next_candles' in outputs: timeframe_losses = [] # Get normalization parameters if available # norm_params may be a dict or a list of dicts (one per sample in batch) norm_params_raw = batch.get('norm_params', {}) if isinstance(norm_params_raw, list) and len(norm_params_raw) > 0: # If it's a list, use the first one (batch size is typically 1) norm_params = norm_params_raw[0] else: norm_params = norm_params_raw if isinstance(norm_params_raw, dict) else {} # Calculate loss for each timeframe that has target data for tf in ['1s', '1m', '1h', '1d']: future_key = f'future_candle_{tf}' if tf in outputs['next_candles'] and future_key in batch: pred_candle = outputs['next_candles'][tf] # [batch, 5] - predicted OHLCV (normalized) target_candle = batch[future_key] # [batch, 5] - actual OHLCV (normalized) if target_candle is not None and pred_candle.shape == target_candle.shape: # MSE loss on normalized values (used for backprop) tf_loss = self.price_criterion(pred_candle, target_candle) timeframe_losses.append(tf_loss) candle_losses_detail[tf] = tf_loss.item() # ALSO calculate denormalized loss for better interpretability # Use RMSE (Root Mean Square Error) instead of MSE for realistic values if tf in norm_params: with torch.no_grad(): pred_denorm = self.denormalize_candle(pred_candle, norm_params[tf]) target_denorm = self.denormalize_candle(target_candle, 
norm_params[tf]) # Use RMSE instead of MSE to get interpretable dollar values mse = torch.mean((pred_denorm - target_denorm) ** 2) rmse = torch.sqrt(mse + 1e-8) # Add epsilon for numerical stability candle_losses_denorm[tf] = rmse.item() # Average loss across available timeframes if timeframe_losses: candle_loss = torch.stack(timeframe_losses).mean() if candle_losses_denorm: logger.debug(f"Candle losses (normalized): {candle_losses_detail}") logger.debug(f"Candle losses (real prices): {candle_losses_denorm}") # Start with base losses - avoid inplace operations on computation graph # Weight: action=1.0, price=0.1, trend=0.05, candle=0.15 total_loss = action_loss + 0.1 * price_loss + 0.05 * trend_loss + 0.15 * candle_loss # CRITICAL FIX: Scale loss for gradient accumulation # This prevents gradient explosion when accumulating over multiple batches # The loss is averaged over accumulation steps so gradients sum correctly if self.gradient_accumulation_steps > 0: total_loss = total_loss / self.gradient_accumulation_steps elif accumulate_gradients: # Legacy fallback - assume 5 steps if not specified total_loss = total_loss / 5.0 # Add confidence loss if available if 'confidence' in outputs and 'trade_success' in batch: # Both tensors should have shape [batch_size, 1] for BCELoss confidence_pred = outputs['confidence'] trade_target = batch['trade_success'].float() # FIXED: Ensure both are 2D tensors [batch_size, 1] # Handle different input shapes robustly if confidence_pred.dim() == 0: # Scalar -> [1, 1] confidence_pred = confidence_pred.unsqueeze(0).unsqueeze(0) elif confidence_pred.dim() == 1: # [batch_size] -> [batch_size, 1] confidence_pred = confidence_pred.unsqueeze(-1) elif confidence_pred.dim() == 3: # [batch_size, seq_len, 1] -> [batch_size, 1] (take last timestep) confidence_pred = confidence_pred[:, -1, :] if trade_target.dim() == 0: # Scalar -> [1, 1] trade_target = trade_target.unsqueeze(0).unsqueeze(0) elif trade_target.dim() == 1: # [batch_size] -> 
[batch_size, 1] trade_target = trade_target.unsqueeze(-1) # Ensure shapes match exactly - BCELoss requires exact match if confidence_pred.shape != trade_target.shape: # Reshape trade_target to match confidence_pred shape trade_target = trade_target.view(confidence_pred.shape) confidence_loss = self.confidence_criterion(confidence_pred, trade_target) # Use addition instead of += to avoid inplace operation total_loss = total_loss + 0.1 * confidence_loss # Backward pass with mixed precision scaling try: if self.use_amp: self.scaler.scale(total_loss).backward() else: total_loss.backward() except RuntimeError as e: if "inplace operation" in str(e): logger.error(f"Inplace operation error during backward pass: {e}") # Return zero loss to continue training return { 'total_loss': 0.0, 'action_loss': 0.0, 'price_loss': 0.0, 'accuracy': 0.0, 'learning_rate': self.scheduler.get_last_lr()[0] } else: raise # Only clip gradients and step optimizer at the end of accumulation cycle if not is_accumulation_step: if self.use_amp: # Unscale gradients before clipping self.scaler.unscale_(self.optimizer) # Gradient clipping torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm) # Optimizer step with scaling self.scaler.step(self.optimizer) self.scaler.update() else: # Gradient clipping torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm) # Optimizer step self.optimizer.step() self.scheduler.step() # Log gradient accumulation completion if self.gradient_accumulation_steps > 0: logger.debug(f"Gradient accumulation cycle complete ({self.gradient_accumulation_steps} steps)") # Calculate accuracy without gradients # PRIMARY: Next candle OHLCV prediction accuracy (realistic values) with torch.no_grad(): candle_accuracy = 0.0 candle_rmse = {} if 'next_candles' in outputs: # Use 1m timeframe as primary metric if '1m' in outputs['next_candles'] and 'future_candle_1m' in batch: pred_candle = outputs['next_candles']['1m'] # [batch, 5] 
actual_candle = batch['future_candle_1m'] # [batch, 5] if actual_candle is not None and pred_candle.shape == actual_candle.shape: # Calculate RMSE for each OHLCV component rmse_open = torch.sqrt(torch.mean((pred_candle[:, 0] - actual_candle[:, 0])**2) + 1e-8) rmse_high = torch.sqrt(torch.mean((pred_candle[:, 1] - actual_candle[:, 1])**2) + 1e-8) rmse_low = torch.sqrt(torch.mean((pred_candle[:, 2] - actual_candle[:, 2])**2) + 1e-8) rmse_close = torch.sqrt(torch.mean((pred_candle[:, 3] - actual_candle[:, 3])**2) + 1e-8) # Average RMSE for OHLC (exclude volume) avg_rmse = (rmse_open + rmse_high + rmse_low + rmse_close) / 4 # Convert to accuracy: lower RMSE = higher accuracy # Normalize by price range price_range = torch.clamp(actual_candle[:, 1].max() - actual_candle[:, 2].min(), min=1e-8) candle_accuracy = (1.0 - torch.clamp(avg_rmse / price_range, 0, 1)).item() candle_rmse = { 'open': rmse_open.item(), 'high': rmse_high.item(), 'low': rmse_low.item(), 'close': rmse_close.item(), 'avg': avg_rmse.item() } # SECONDARY: Trend vector prediction accuracy trend_accuracy = 0.0 if 'trend_analysis' in outputs and 'trend_target' in batch: pred_angle = outputs['trend_analysis']['angle_radians'] pred_steepness = outputs['trend_analysis']['steepness'] actual_angle = batch['trend_target'][:, 0:1] actual_steepness = batch['trend_target'][:, 1:2] # Angle error (degrees) angle_error_rad = torch.abs(pred_angle - actual_angle) angle_error_deg = angle_error_rad * 180.0 / 3.14159 angle_accuracy = (1.0 - torch.clamp(angle_error_deg / 180.0, 0, 1)).mean() # Steepness error (percentage) steepness_error = torch.abs(pred_steepness - actual_steepness) / (actual_steepness + 1e-8) steepness_accuracy = (1.0 - torch.clamp(steepness_error, 0, 1)).mean() trend_accuracy = ((angle_accuracy + steepness_accuracy) / 2).item() # LEGACY: Action accuracy (for comparison) action_predictions = torch.argmax(outputs['action_logits'], dim=-1) action_accuracy = (action_predictions == 
batch['actions']).float().mean().item() # Extract values and delete tensors to free memory result = { 'total_loss': total_loss.item(), 'action_loss': action_loss.item(), 'price_loss': price_loss.item(), 'trend_loss': trend_loss.item() if isinstance(trend_loss, torch.Tensor) else 0.0, 'candle_loss': candle_loss.item() if isinstance(candle_loss, torch.Tensor) else 0.0, 'candle_loss_denorm': candle_losses_denorm, # Dict of denormalized losses per timeframe # NEW: Realistic accuracy metrics based on next candle prediction 'accuracy': candle_accuracy, # PRIMARY: Next candle prediction accuracy 'candle_accuracy': candle_accuracy, # Same as accuracy 'candle_rmse': candle_rmse, # Detailed RMSE per OHLC component 'trend_accuracy': trend_accuracy, # Trend vector accuracy (angle + steepness) 'action_accuracy': action_accuracy, # Legacy action accuracy 'learning_rate': self.scheduler.get_last_lr()[0] } # CRITICAL: Delete large tensors to free memory immediately # This prevents memory accumulation across batches del outputs, total_loss, action_loss, price_loss, trend_loss, candle_loss, action_predictions # Delete batch tensors (GPU copies) for key in list(batch.keys()): if isinstance(batch[key], torch.Tensor): del batch[key] del batch # Clear CUDA cache and log GPU memory usage if torch.cuda.is_available(): torch.cuda.empty_cache() # Log GPU memory usage periodically (every 10 steps) if not hasattr(self, '_step_counter'): self._step_counter = 0 self._step_counter += 1 if self._step_counter % 10 == 0: allocated = torch.cuda.memory_allocated() / 1024**2 reserved = torch.cuda.memory_reserved() / 1024**2 logger.debug(f"GPU Memory: {allocated:.1f}MB allocated, {reserved:.1f}MB reserved") return result except torch.cuda.OutOfMemoryError as oom_error: logger.error(f"CUDA OOM in train_step: {oom_error}") # Aggressive cleanup on OOM if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.synchronize() # Reset optimizer state to prevent corruption 
self.optimizer.zero_grad(set_to_none=True) # Return zero loss to continue training return { 'total_loss': 0.0, 'action_loss': 0.0, 'price_loss': 0.0, 'accuracy': 0.0, 'candle_accuracy': 0.0, 'learning_rate': self.scheduler.get_last_lr()[0] if hasattr(self, 'scheduler') else 0.0 } except Exception as e: logger.error(f"Error in train_step: {e}", exc_info=True) # Clear any partial computations if torch.cuda.is_available(): torch.cuda.empty_cache() # Return a zero loss dict to prevent training from crashing # but log the error so we can debug return { 'total_loss': 0.0, 'action_loss': 0.0, 'price_loss': 0.0, 'accuracy': 0.0, 'candle_accuracy': 0.0, 'learning_rate': self.scheduler.get_last_lr()[0] if hasattr(self, 'scheduler') else 0.0 } def validate(self, val_loader: DataLoader) -> Dict[str, float]: """Validation step""" self.model.eval() total_loss = 0 total_accuracy = 0 num_batches = 0 with torch.no_grad(): for batch in val_loader: batch = {k: v.to(self.device) for k, v in batch.items()} outputs = self.model( batch['price_data'], batch['cob_data'], batch['tech_data'], batch['market_data'] ) # Calculate losses action_loss = self.action_criterion(outputs['action_logits'], batch['actions']) price_loss = self.price_criterion(outputs['price_prediction'], batch['future_prices']) total_loss += action_loss.item() + 0.1 * price_loss.item() # Calculate accuracy predictions = torch.argmax(outputs['action_logits'], dim=-1) accuracy = (predictions == batch['actions']).float().mean() total_accuracy += accuracy.item() num_batches += 1 return { 'val_loss': total_loss / num_batches, 'val_accuracy': total_accuracy / num_batches } def train(self, train_loader: DataLoader, val_loader: DataLoader, epochs: int, save_path: str = "NN/models/saved/"): """Full training loop""" best_val_loss = float('inf') for epoch in range(epochs): # Training epoch_losses = [] epoch_accuracies = [] for batch in train_loader: metrics = self.train_step(batch) epoch_losses.append(metrics['total_loss']) 
epoch_accuracies.append(metrics['accuracy']) # Validation val_metrics = self.validate(val_loader) # Update history avg_train_loss = np.mean(epoch_losses) avg_train_accuracy = np.mean(epoch_accuracies) self.training_history['train_loss'].append(avg_train_loss) self.training_history['val_loss'].append(val_metrics['val_loss']) self.training_history['train_accuracy'].append(avg_train_accuracy) self.training_history['val_accuracy'].append(val_metrics['val_accuracy']) self.training_history['learning_rates'].append(self.scheduler.get_last_lr()[0]) # Logging logger.info(f"Epoch {epoch+1}/{epochs}") logger.info(f" Train Loss: {avg_train_loss:.4f}, Train Acc: {avg_train_accuracy:.4f}") logger.info(f" Val Loss: {val_metrics['val_loss']:.4f}, Val Acc: {val_metrics['val_accuracy']:.4f}") logger.info(f" LR: {self.scheduler.get_last_lr()[0]:.6f}") # Save best model if val_metrics['val_loss'] < best_val_loss: best_val_loss = val_metrics['val_loss'] self.save_model(os.path.join(save_path, 'best_transformer_model.pt')) logger.info(f" New best model saved (val_loss: {best_val_loss:.4f})") def save_model(self, path: str): """Save model and training state""" os.makedirs(os.path.dirname(path), exist_ok=True) torch.save({ 'model_state_dict': self.model.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict(), 'scheduler_state_dict': self.scheduler.state_dict(), 'config': self.config, 'training_history': self.training_history }, path) logger.info(f"Model saved to {path}") def load_model(self, path: str): """Load model and training state""" checkpoint = torch.load(path, map_location=self.device) self.model.load_state_dict(checkpoint['model_state_dict']) self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) self.scheduler.load_state_dict(checkpoint['scheduler_state_dict']) self.training_history = checkpoint.get('training_history', self.training_history) logger.info(f"Model loaded from {path}") def create_trading_transformer(config: Optional[TradingTransformerConfig] = 
None) -> Tuple[AdvancedTradingTransformer, TradingTransformerTrainer]: """Factory function to create trading transformer and trainer""" if config is None: config = TradingTransformerConfig() model = AdvancedTradingTransformer(config) trainer = TradingTransformerTrainer(model, config) return model, trainer # Example usage if __name__ == "__main__": # Create configuration config = TradingTransformerConfig( d_model=256, n_heads=8, n_layers=4, seq_len=50, n_actions=3, use_multi_scale_attention=True, use_market_regime_detection=True, use_uncertainty_estimation=True ) # Create model and trainer model, trainer = create_trading_transformer(config) logger.info(f"Created Advanced Trading Transformer with {sum(p.numel() for p in model.parameters())} parameters") logger.info("Model is ready for training on real market data!")