#!/usr/bin/env python3
"""
Advanced Transformer Models for High-Frequency Trading
Optimized for COB data, technical indicators, and market microstructure
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.checkpoint import checkpoint
import numpy as np
import math
import logging
from typing import Dict, Any, Optional, Tuple, List, Callable
from dataclasses import dataclass
import os
import json
from datetime import datetime
# Configure logging
from utils.safe_logger import get_logger
logger = get_logger(__name__)
@dataclass
class TradingTransformerConfig:
"""Configuration for trading transformer models - OPTIMIZED FOR GPU (8-12M params)"""
# Model architecture - REDUCED for efficient GPU training
d_model: int = 256 # Model dimension (was 1024)
n_heads: int = 8 # Number of attention heads (was 16)
n_layers: int = 4 # Number of transformer layers (was 12)
d_ff: int = 1024 # Feed-forward dimension (was 4096)
dropout: float = 0.1 # Dropout rate
# Input dimensions - OPTIMIZED
seq_len: int = 200 # Sequence length for time series
cob_features: int = 100 # COB feature dimension
tech_features: int = 40 # Technical indicator features
market_features: int = 30 # Market microstructure features
# Output configuration
n_actions: int = 3 # BUY, SELL, HOLD
confidence_output: bool = True # Output confidence scores
    # Training configuration - conservative settings retained from the larger model
    learning_rate: float = 5e-5         # Low LR for stable convergence
    weight_decay: float = 1e-4          # Regularization
    warmup_steps: int = 8000            # LR warmup steps
    max_grad_norm: float = 0.5          # Tight gradient clipping
# Advanced features - ENHANCED
use_relative_position: bool = True
use_multi_scale_attention: bool = True
use_market_regime_detection: bool = True
use_uncertainty_estimation: bool = True
# NEW: Additional scaling features
use_deep_attention: bool = True # Deeper attention mechanisms
use_residual_connections: bool = True # Enhanced residual connections
use_layer_norm_variants: bool = True # Advanced normalization
# Memory optimization
use_gradient_checkpointing: bool = False # DISABLED: Causes tensor shape mismatches during backward pass
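# A minimal usage sketch for the config above (illustrative; the helper name
# _example_build_model is ours, not part of the trading pipeline). It builds the
# model defined later in this module and checks the parameter count against the
# stated 8-12M target. The field values shown are just the defaults.
def _example_build_model() -> None:
    config = TradingTransformerConfig(d_model=256, n_heads=8, n_layers=4, d_ff=1024)
    model = AdvancedTradingTransformer(config)  # Defined later in this module
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {n_params / 1e6:.2f}M")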
class PositionalEncoding(nn.Module):
"""Sinusoidal positional encoding for transformer"""
def __init__(self, d_model: int, max_len: int = 5000):
super().__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() *
(-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer('pe', pe)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + self.pe[:x.size(0), :]
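# Shape sketch (illustrative): this module expects sequence-first input,
# (seq_len, batch, d_model); AdvancedTradingTransformer.forward transposes to
# this layout before applying it and transposes back afterwards.
def _example_positional_encoding() -> None:
    pe = PositionalEncoding(d_model=256, max_len=512)
    x = torch.zeros(200, 2, 256)  # (seq_len, batch, d_model)
    out = pe(x)                   # Same shape, with sinusoidal offsets added
    assert out.shape == (200, 2, 256)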
class RelativePositionalEncoding(nn.Module):
"""Relative positional encoding for better temporal understanding"""
def __init__(self, d_model: int, max_relative_position: int = 128):
super().__init__()
self.d_model = d_model
self.max_relative_position = max_relative_position
# Learnable relative position embeddings
self.relative_position_embeddings = nn.Embedding(
2 * max_relative_position + 1, d_model
)
def forward(self, seq_len: int) -> torch.Tensor:
"""Generate relative position encoding matrix"""
range_vec = torch.arange(seq_len)
range_mat = range_vec.unsqueeze(0).repeat(seq_len, 1)
distance_mat = range_mat - range_mat.transpose(0, 1)
# Clip to max relative position
distance_mat_clipped = torch.clamp(
distance_mat, -self.max_relative_position, self.max_relative_position
)
# Shift to positive indices
final_mat = distance_mat_clipped + self.max_relative_position
return self.relative_position_embeddings(final_mat)
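# Worked example (illustrative): forward() builds a (seq_len, seq_len) matrix of
# clamped key-minus-query offsets. For seq_len=4 the raw offsets are
#   [[ 0,  1,  2,  3],
#    [-1,  0,  1,  2],
#    [-2, -1,  0,  1],
#    [-3, -2, -1,  0]]
# which are shifted by max_relative_position into non-negative embedding indices.
def _example_relative_positions() -> None:
    rel = RelativePositionalEncoding(d_model=256, max_relative_position=128)
    enc = rel(4)
    assert enc.shape == (4, 4, 256)  # One learned embedding per (query, key) offset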
class DeepMultiScaleAttention(nn.Module):
"""Lightweight multi-scale attention optimized for 8-12M parameter model"""
def __init__(self, d_model: int, n_heads: int, scales: List[int] = [1, 3, 5]):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.scales = scales # Reduced from 6 scales to 3
self.head_dim = d_model // n_heads
assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
# Lightweight multi-scale projections (single layer instead of deep)
self.scale_projections = nn.ModuleList([
nn.ModuleDict({
'query': nn.Linear(d_model, d_model),
'key': nn.Linear(d_model, d_model),
'value': nn.Linear(d_model, d_model),
'conv': nn.Conv1d(d_model, d_model, kernel_size=scale,
padding=scale//2, groups=d_model//4)
}) for scale in scales
])
# Lightweight output projection
self.output_projection = nn.Linear(d_model * len(scales), d_model)
self.dropout = nn.Dropout(0.1)
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
batch_size, seq_len, _ = x.size()
scale_outputs = []
for scale_proj in self.scale_projections:
            # Temporal convolution smooths the sequence at this scale
            x_conv = scale_proj['conv'](x.transpose(1, 2)).transpose(1, 2)
            # Project the smoothed sequence to Q/K/V for scaled dot-product attention
Q = scale_proj['query'](x_conv).view(batch_size, seq_len, self.n_heads, self.head_dim)
K = scale_proj['key'](x_conv).view(batch_size, seq_len, self.n_heads, self.head_dim)
V = scale_proj['value'](x_conv).view(batch_size, seq_len, self.n_heads, self.head_dim)
# Transpose for attention computation
Q = Q.transpose(1, 2) # (batch, n_heads, seq_len, head_dim)
K = K.transpose(1, 2)
V = V.transpose(1, 2)
# Scaled dot-product attention
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
if mask is not None:
# Use non-inplace version to avoid gradient computation issues
scores = scores.masked_fill(mask == 0, -1e9)
attention = F.softmax(scores, dim=-1)
attention = self.dropout(attention)
output = torch.matmul(attention, V)
output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
scale_outputs.append(output)
# Combine multi-scale outputs
combined = torch.cat(scale_outputs, dim=-1)
output = self.output_projection(combined)
return output
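# Shape sketch (illustrative): each scale smooths the sequence with a grouped
# Conv1d (kernel sizes 1/3/5), attends over the smoothed view, and the three
# attention outputs are concatenated and fused back down to d_model.
def _example_multi_scale_attention() -> None:
    attn = DeepMultiScaleAttention(d_model=256, n_heads=8)
    x = torch.randn(2, 200, 256)  # (batch, seq_len, d_model)
    out = attn(x)
    assert out.shape == x.shape   # Multi-scale fusion preserves the input shape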
class MarketRegimeDetector(nn.Module):
"""Market regime detection module for adaptive behavior"""
def __init__(self, d_model: int, n_regimes: int = 4):
super().__init__()
self.d_model = d_model
self.n_regimes = n_regimes
# Regime classification layers
self.regime_classifier = nn.Sequential(
nn.Linear(d_model, d_model // 2),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(d_model // 2, n_regimes)
)
# Regime-specific transformations
self.regime_transforms = nn.ModuleList([
nn.Linear(d_model, d_model) for _ in range(n_regimes)
])
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# Global pooling for regime detection
pooled = torch.mean(x, dim=1) # (batch, d_model)
# Classify market regime
regime_logits = self.regime_classifier(pooled)
regime_probs = F.softmax(regime_logits, dim=-1)
        # Apply regime-specific transformations
        regime_outputs = []
        for transform in self.regime_transforms:
            regime_outputs.append(transform(x))  # (batch, seq_len, d_model)
        # Weighted combination based on regime probabilities
        regime_stack = torch.stack(regime_outputs, dim=0)  # (n_regimes, batch, seq_len, d_model)
        # Reshape probabilities to broadcast over seq_len and d_model
        regime_weights = regime_probs.t().unsqueeze(-1).unsqueeze(-1)  # (n_regimes, batch, 1, 1)
        # Weighted sum across regimes
        adapted_output = torch.sum(regime_stack * regime_weights, dim=0)
return adapted_output, regime_probs
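# Usage sketch (illustrative): the detector behaves like a soft
# mixture-of-experts - one linear expert per regime, weighted by the regime
# probabilities inferred from the pooled sequence.
def _example_regime_detector() -> None:
    detector = MarketRegimeDetector(d_model=256, n_regimes=4)
    x = torch.randn(2, 200, 256)
    adapted, regime_probs = detector(x)
    assert adapted.shape == (2, 200, 256)
    assert regime_probs.shape == (2, 4)  # Rows sum to 1 (softmax over regimes)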
class UncertaintyEstimation(nn.Module):
"""Uncertainty estimation using Monte Carlo Dropout"""
def __init__(self, d_model: int, n_samples: int = 10):
super().__init__()
self.d_model = d_model
self.n_samples = n_samples
self.uncertainty_head = nn.Sequential(
nn.Linear(d_model, d_model // 2),
nn.ReLU(),
nn.Dropout(0.5), # Higher dropout for uncertainty estimation
nn.Linear(d_model // 2, 1),
nn.Sigmoid()
)
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.training:
            # Single forward pass during training; MC sampling is inference-only
            uncertainty = self.uncertainty_head(x)
            return uncertainty, uncertainty
        # Monte Carlo dropout during inference: temporarily enable dropout so
        # repeated passes differ, then restore eval mode
        self.uncertainty_head.train()
        with torch.no_grad():
            uncertainties = torch.stack(
                [self.uncertainty_head(x) for _ in range(self.n_samples)], dim=0
            )
        self.uncertainty_head.eval()
        mean_uncertainty = torch.mean(uncertainties, dim=0)
        std_uncertainty = torch.std(uncertainties, dim=0)
        return mean_uncertainty, std_uncertainty
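# Usage sketch (illustrative): MC dropout only yields a non-zero spread when
# dropout is active, so the module enables it internally for the sampling loop;
# the returned std can be read as a rough epistemic-uncertainty signal.
def _example_uncertainty() -> None:
    estimator = UncertaintyEstimation(d_model=256, n_samples=10)
    estimator.eval()              # Eval mode triggers the MC sampling path
    pooled = torch.randn(2, 256)
    mean_u, std_u = estimator(pooled)
    assert mean_u.shape == (2, 1) and std_u.shape == (2, 1)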
class TradingTransformerLayer(nn.Module):
"""Enhanced transformer layer for trading applications"""
def __init__(self, config: TradingTransformerConfig):
super().__init__()
self.config = config
# Enhanced multi-scale attention or standard attention
if config.use_multi_scale_attention:
self.attention = DeepMultiScaleAttention(config.d_model, config.n_heads)
else:
self.attention = nn.MultiheadAttention(
config.d_model, config.n_heads, dropout=config.dropout, batch_first=True
)
# Feed-forward network
self.feed_forward = nn.Sequential(
nn.Linear(config.d_model, config.d_ff),
nn.GELU(),
nn.Dropout(config.dropout),
nn.Linear(config.d_ff, config.d_model)
)
# Layer normalization
self.norm1 = nn.LayerNorm(config.d_model)
self.norm2 = nn.LayerNorm(config.d_model)
# Dropout
self.dropout = nn.Dropout(config.dropout)
# Market regime detection
if config.use_market_regime_detection:
self.regime_detector = MarketRegimeDetector(config.d_model)
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
# Self-attention with residual connection
# Store residual before any operations to avoid version conflicts
if isinstance(self.attention, DeepMultiScaleAttention):
attn_output = self.attention(x, mask)
else:
attn_output, _ = self.attention(x, x, x, attn_mask=mask)
# Create new tensor for residual to avoid inplace modification tracking
x_new = self.norm1(x + self.dropout(attn_output))
# Market regime adaptation
regime_probs = None
if hasattr(self, 'regime_detector'):
x_new, regime_probs = self.regime_detector(x_new)
# Feed-forward with residual connection
ff_output = self.feed_forward(x_new)
x_out = self.norm2(x_new + self.dropout(ff_output))
return {
'output': x_out,
'regime_probs': regime_probs
}
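# Usage sketch (illustrative): the layer returns a dict rather than a bare
# tensor so per-layer regime probabilities can be collected by the caller.
def _example_layer_forward() -> None:
    config = TradingTransformerConfig()
    layer = TradingTransformerLayer(config)
    x = torch.randn(2, 200, config.d_model)
    result = layer(x)
    assert result['output'].shape == x.shape
    # result['regime_probs'] is (batch, 4) when regime detection is enabled, else None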
class AdvancedTradingTransformer(nn.Module):
"""Advanced transformer model for high-frequency trading"""
def __init__(self, config: TradingTransformerConfig):
super().__init__()
self.config = config
# Timeframe configuration
self.timeframes = ['1s', '1m', '1h', '1d']
self.num_timeframes = len(self.timeframes) + 1 # +1 for BTC
# SERIAL: Shared pattern encoder (learns candle patterns ONCE for all timeframes)
# This is applied to each timeframe independently but uses SAME weights
# LIGHTWEIGHT: 2-layer encoder for efficiency
self.shared_pattern_encoder = nn.Sequential(
nn.Linear(5, config.d_model // 2), # 5 OHLCV -> 128
nn.LayerNorm(config.d_model // 2),
nn.GELU(),
nn.Dropout(config.dropout),
nn.Linear(config.d_model // 2, config.d_model) # 128 -> 256
)
# Timeframe-specific embeddings (learnable, added to shared encoding)
# These help the model distinguish which timeframe it's looking at
self.timeframe_embeddings = nn.Embedding(self.num_timeframes, config.d_model)
# PARALLEL: Cross-timeframe attention layer (single layer for efficiency)
# Processes all timeframes simultaneously to capture dependencies
self.cross_timeframe_layer = nn.TransformerEncoderLayer(
d_model=config.d_model,
nhead=config.n_heads,
dim_feedforward=config.d_ff,
dropout=config.dropout,
activation='gelu',
batch_first=True
)
# Other input projections
self.cob_projection = nn.Linear(config.cob_features, config.d_model)
self.tech_projection = nn.Linear(config.tech_features, config.d_model)
self.market_projection = nn.Linear(config.market_features, config.d_model)
# Position state projection
self.position_projection = nn.Sequential(
nn.Linear(5, config.d_model // 4),
nn.GELU(),
nn.Dropout(config.dropout),
nn.Linear(config.d_model // 4, config.d_model // 2),
nn.GELU(),
nn.Dropout(config.dropout),
nn.Linear(config.d_model // 2, config.d_model)
)
# Positional encoding
if config.use_relative_position:
self.pos_encoding = RelativePositionalEncoding(config.d_model)
else:
self.pos_encoding = PositionalEncoding(config.d_model, config.seq_len)
# Transformer layers
self.layers = nn.ModuleList([
TradingTransformerLayer(config) for _ in range(config.n_layers)
])
# Lightweight output heads for 8-12M parameter model
self.action_head = nn.Sequential(
nn.Linear(config.d_model, config.d_model // 2),
nn.GELU(),
nn.Dropout(config.dropout),
nn.Linear(config.d_model // 2, config.n_actions)
)
if config.confidence_output:
self.confidence_head = nn.Sequential(
nn.Linear(config.d_model, config.d_model // 2),
nn.GELU(),
nn.Dropout(config.dropout),
nn.Linear(config.d_model // 2, 1),
nn.Sigmoid()
)
# Enhanced uncertainty estimation
if config.use_uncertainty_estimation:
self.uncertainty_estimator = UncertaintyEstimation(config.d_model)
# Lightweight price prediction head
self.price_head = nn.Sequential(
nn.Linear(config.d_model, config.d_model // 2),
nn.GELU(),
nn.Dropout(config.dropout),
nn.Linear(config.d_model // 2, 1),
nn.Tanh() # Constrain to [-1, 1] range for price change ratio
)
# Lightweight volatility and trend heads
self.volatility_head = nn.Sequential(
nn.Linear(config.d_model, config.d_model // 4),
nn.GELU(),
nn.Linear(config.d_model // 4, 1),
nn.Softplus()
)
self.trend_strength_head = nn.Sequential(
nn.Linear(config.d_model, config.d_model // 4),
nn.GELU(),
nn.Linear(config.d_model // 4, 1),
nn.Tanh()
)
# Lightweight next candle OHLCV prediction heads
self.next_candle_heads = nn.ModuleDict({
tf: nn.Sequential(
nn.Linear(config.d_model, config.d_model // 2),
nn.GELU(),
nn.Linear(config.d_model // 2, 5), # OHLCV
nn.Sigmoid() # Constrain to [0, 1]
) for tf in self.timeframes
})
# BTC next candle prediction head
self.btc_next_candle_head = nn.Sequential(
nn.Linear(config.d_model, config.d_model // 2),
nn.GELU(),
nn.Linear(config.d_model // 2, 5), # OHLCV for BTC
nn.Sigmoid()
)
# Lightweight pivot point prediction heads (L1-L3 only for efficiency)
self.pivot_levels = [1, 2, 3] # Reduced from L1-L5 to L1-L3
self.pivot_heads = nn.ModuleDict({
f'L{level}': nn.Sequential(
nn.Linear(config.d_model, config.d_model // 2),
nn.GELU(),
nn.Linear(config.d_model // 2, 4) # [price, type_prob_high, type_prob_low, confidence]
) for level in self.pivot_levels
})
# Lightweight trend vector analysis head
self.trend_analysis_head = nn.Sequential(
nn.Linear(config.d_model, config.d_model // 2),
nn.GELU(),
nn.Linear(config.d_model // 2, 3) # [angle_radians, steepness, direction]
)
# Initialize weights
self._init_weights()
def _init_weights(self):
"""Initialize model weights"""
for module in self.modules():
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.LayerNorm):
nn.init.ones_(module.weight)
nn.init.zeros_(module.bias)
def forward(self,
# Multi-timeframe inputs
price_data_1s: Optional[torch.Tensor] = None,
price_data_1m: Optional[torch.Tensor] = None,
price_data_1h: Optional[torch.Tensor] = None,
price_data_1d: Optional[torch.Tensor] = None,
btc_data_1m: Optional[torch.Tensor] = None,
# Other inputs
cob_data: Optional[torch.Tensor] = None,
tech_data: Optional[torch.Tensor] = None,
market_data: Optional[torch.Tensor] = None,
position_state: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
# Legacy support
price_data: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
"""
Forward pass with hybrid serial-parallel multi-timeframe processing
SERIAL: Shared pattern encoder learns candle patterns once (same weights for all timeframes)
PARALLEL: Cross-timeframe attention captures dependencies between timeframes
Args:
price_data_1s: (batch, seq_len, 5) - 1-second OHLCV (optional)
price_data_1m: (batch, seq_len, 5) - 1-minute OHLCV (optional)
price_data_1h: (batch, seq_len, 5) - 1-hour OHLCV (optional)
price_data_1d: (batch, seq_len, 5) - 1-day OHLCV (optional)
btc_data_1m: (batch, seq_len, 5) - BTC 1-minute OHLCV (optional)
cob_data: (batch, seq_len, cob_features) - COB features
tech_data: (batch, tech_features) - Technical indicators
market_data: (batch, market_features) - Market features
position_state: (batch, 5) - Position state
mask: Optional attention mask
price_data: (batch, seq_len, 5) - Legacy single timeframe (defaults to 1m)
Returns:
Dictionary with predictions for ALL timeframes
"""
# Legacy support
if price_data is not None and price_data_1m is None:
price_data_1m = price_data
# Collect available timeframes
timeframe_data = {
'1s': price_data_1s,
'1m': price_data_1m,
'1h': price_data_1h,
'1d': price_data_1d,
'btc': btc_data_1m
}
# Filter to available timeframes
available_tfs = [(tf, data) for tf, data in timeframe_data.items() if data is not None]
if not available_tfs:
# Handle case where no timeframes are available (e.g., missing training data)
# Return a default output that won't break training
logger.warning("No timeframe data available for transformer forward pass")
batch_size = 1 # Default batch size
device = torch.device('cpu') # Default to CPU when no data available
# Return default outputs with appropriate shapes
return {
'action_logits': torch.zeros(batch_size, 3, device=device), # 3 actions: HOLD, BUY, SELL
'trend_logits': torch.zeros(batch_size, 3, device=device), # 3 trends: DOWN, SIDEWAYS, UP
'candle_logits': torch.zeros(batch_size, 2, device=device), # 2 candle types: NORMAL, EXTREMA
'price_prediction': torch.zeros(batch_size, 1, device=device),
'confidence': torch.zeros(batch_size, 1, device=device),
'attention_weights': None
}
# Get dimensions from first available timeframe
first_data = available_tfs[0][1]
batch_size, seq_len = first_data.shape[:2]
device = first_data.device
# ============================================================
# STEP 1: SERIAL - Apply shared pattern encoder to each timeframe
# This learns candle patterns ONCE (same weights for all)
# ============================================================
timeframe_encodings = []
timeframe_indices = []
for idx, (tf_name, tf_data) in enumerate(available_tfs):
# Ensure correct sequence length
if tf_data.shape[1] != seq_len:
if tf_data.shape[1] < seq_len:
# Pad with last candle
padding = tf_data[:, -1:, :].expand(batch_size, seq_len - tf_data.shape[1], 5)
tf_data = torch.cat([tf_data, padding], dim=1)
else:
# Truncate to seq_len
tf_data = tf_data[:, :seq_len, :]
# Apply SHARED pattern encoder (learns patterns once for all timeframes)
# Shape: [batch, seq_len, 5] -> [batch, seq_len, d_model]
tf_encoded = self.shared_pattern_encoder(tf_data)
# Add timeframe-specific embedding (helps model know which timeframe)
# Get timeframe index
tf_idx = self.timeframes.index(tf_name) if tf_name in self.timeframes else len(self.timeframes)
tf_embedding = self.timeframe_embeddings(torch.tensor([tf_idx], device=device))
tf_embedding = tf_embedding.unsqueeze(1).expand(batch_size, seq_len, -1)
# Combine: shared pattern + timeframe identity
tf_encoded = tf_encoded + tf_embedding
timeframe_encodings.append(tf_encoded)
timeframe_indices.append(tf_idx)
# ============================================================
# STEP 2: PARALLEL - Cross-timeframe attention
# Process all timeframes together to capture dependencies
# ============================================================
# Stack timeframes: [batch, num_timeframes, seq_len, d_model]
# Then reshape to: [batch, num_timeframes * seq_len, d_model]
stacked_tfs = torch.stack(timeframe_encodings, dim=1) # [batch, num_tfs, seq_len, d_model]
num_tfs = len(timeframe_encodings)
# MEMORY EFFICIENT: Process timeframes with shared weights
# Reshape to process all timeframes in parallel: [batch*num_tfs, seq_len, d_model]
# This avoids creating huge concatenated sequences while still processing efficiently
batched_tfs = stacked_tfs.reshape(batch_size * num_tfs, seq_len, self.config.d_model)
# Apply single cross-timeframe attention layer
# Use new variable to avoid inplace modification issues
cross_tf_encoded = self.cross_timeframe_layer(batched_tfs)
# Reshape back: [batch*num_tfs, seq_len, d_model] -> [batch, num_tfs, seq_len, d_model]
cross_tf_output = cross_tf_encoded.reshape(batch_size, num_tfs, seq_len, self.config.d_model)
# Average across timeframes to get unified representation
# [batch, num_tfs, seq_len, d_model] -> [batch, seq_len, d_model]
price_emb = cross_tf_output.mean(dim=1)
# ============================================================
# STEP 3: Add other features (COB, tech, market, position)
# ============================================================
# COB features
if cob_data is not None:
if cob_data.dim() == 2:
cob_data = cob_data.unsqueeze(1).expand(batch_size, seq_len, -1)
cob_emb = self.cob_projection(cob_data)
else:
cob_emb = torch.zeros(batch_size, seq_len, self.config.d_model, device=device)
# Technical indicators
if tech_data is not None:
if tech_data.dim() == 2:
tech_data = tech_data.unsqueeze(1).expand(batch_size, seq_len, -1)
tech_emb = self.tech_projection(tech_data)
else:
tech_emb = torch.zeros(batch_size, seq_len, self.config.d_model, device=device)
# Market features
if market_data is not None:
if market_data.dim() == 2:
market_data = market_data.unsqueeze(1).expand(batch_size, seq_len, -1)
market_emb = self.market_projection(market_data)
else:
market_emb = torch.zeros(batch_size, seq_len, self.config.d_model, device=device)
        # Combine all embeddings (elementwise addition already allocates a new tensor)
        x = price_emb + cob_emb + tech_emb + market_emb
# Add position state if provided - critical for loss minimization and profit taking
if position_state is not None:
# Project position state through learned embedding network
# Input: [batch, 5] -> Output: [batch, d_model]
position_emb = self.position_projection(position_state) # [batch, d_model]
# Expand to sequence length and add as bias to all positions
# This conditions the entire sequence on current position state
position_emb = position_emb.unsqueeze(1).expand(batch_size, seq_len, -1) # [batch, seq_len, d_model]
# Add position embedding to the combined embeddings - create new tensor to avoid inplace
x = x + position_emb
# Add positional encoding
if isinstance(self.pos_encoding, RelativePositionalEncoding):
# Relative position encoding is applied in attention
pass
else:
x = self.pos_encoding(x.transpose(0, 1)).transpose(0, 1)
# Apply transformer layers with optional gradient checkpointing
regime_probs_history = []
for layer in self.layers:
if self.training and self.config.use_gradient_checkpointing:
# Use gradient checkpointing to save memory during training
# Trades compute for memory (recomputes activations during backward pass)
layer_output = checkpoint(
layer, x, mask, use_reentrant=False
)
else:
layer_output = layer(x, mask)
# Use output directly - no clone needed with proper variable naming
x = layer_output['output']
if layer_output['regime_probs'] is not None:
regime_probs_history.append(layer_output['regime_probs'])
# Global pooling for final prediction
# Use attention-based pooling
pooling_weights = F.softmax(
torch.sum(x, dim=-1, keepdim=True), dim=1
)
pooled = torch.sum(x * pooling_weights, dim=1)
# Generate outputs
outputs = {}
# Action prediction
action_logits = self.action_head(pooled)
outputs['action_logits'] = action_logits
outputs['action_probs'] = F.softmax(action_logits, dim=-1)
# Confidence prediction
if self.config.confidence_output:
confidence = self.confidence_head(pooled)
outputs['confidence'] = confidence
# Uncertainty estimation
if self.config.use_uncertainty_estimation:
uncertainty_mean, uncertainty_std = self.uncertainty_estimator(pooled)
outputs['uncertainty_mean'] = uncertainty_mean
outputs['uncertainty_std'] = uncertainty_std
# Enhanced price prediction (auxiliary task)
price_pred = self.price_head(pooled)
outputs['price_prediction'] = price_pred
        # Additional specialized prediction heads (volatility, trend strength)
volatility_pred = self.volatility_head(pooled)
outputs['volatility_prediction'] = volatility_pred
trend_strength_pred = self.trend_strength_head(pooled)
outputs['trend_strength_prediction'] = trend_strength_pred
# NEW: Next candle OHLCV predictions for each timeframe
next_candles = {}
for tf in self.timeframes:
candle_pred = self.next_candle_heads[tf](pooled) # (batch, 5)
next_candles[tf] = candle_pred
outputs['next_candles'] = next_candles
# BTC next candle prediction
btc_next_candle = self.btc_next_candle_head(pooled) # (batch, 5)
outputs['btc_next_candle'] = btc_next_candle
        # NEW: Next pivot point predictions for the configured levels (L1-L3 by default)
next_pivots = {}
for level in self.pivot_levels:
pivot_pred = self.pivot_heads[f'L{level}'](pooled) # (batch, 4)
# Extract components: [price, type_logit_high, type_logit_low, confidence]
# Use softmax to ensure type probabilities sum to 1
type_logits = pivot_pred[:, 1:3] # (batch, 2) - [high, low]
type_probs = F.softmax(type_logits, dim=-1) # (batch, 2)
next_pivots[f'L{level}'] = {
'price': pivot_pred[:, 0:1], # Keep as (batch, 1)
'type_prob_high': type_probs[:, 0:1], # Probability of high pivot
'type_prob_low': type_probs[:, 1:2], # Probability of low pivot
'pivot_type': torch.argmax(type_probs, dim=-1, keepdim=True), # 0=high, 1=low
'confidence': torch.sigmoid(pivot_pred[:, 3:4]) # Prediction confidence
}
outputs['next_pivots'] = next_pivots
# NEW: Trend vector analysis from pivot predictions
trend_analysis = self.trend_analysis_head(pooled) # (batch, 3)
outputs['trend_analysis'] = {
'angle_radians': trend_analysis[:, 0:1], # Trend angle in radians
'steepness': F.softplus(trend_analysis[:, 1:2]), # Always positive steepness
'direction': torch.tanh(trend_analysis[:, 2:3]) # -1 to 1 (down to up)
}
        # NEW: Calculate trend vector from pivot predictions
        # Extract pivot prices and create trend vector
        pivot_prices = torch.stack([next_pivots[f'L{level}']['price'] for level in self.pivot_levels], dim=1)  # (batch, n_levels, 1)
        pivot_prices = pivot_prices.squeeze(-1)  # (batch, n_levels)
        # Calculate trend vector: (price_change, time_change)
        # Assume equal time spacing between pivot levels
        time_points = torch.arange(1, len(self.pivot_levels) + 1, dtype=torch.float32, device=pooled.device).unsqueeze(0)  # (1, n_levels)
        # Estimate the trend slope from the first and last pivot prices
        # (an endpoint slope, not a full linear regression)
        if batch_size > 0:
            price_deltas = (pivot_prices[:, -1] - pivot_prices[:, 0]).unsqueeze(-1)  # (batch, 1)
            time_deltas = (time_points[:, -1] - time_points[:, 0]).expand(batch_size, 1)  # (batch, 1)
            # Calculate angle and steepness
            trend_angles = torch.atan2(price_deltas.squeeze(-1), time_deltas.squeeze(-1))  # (batch,)
            trend_steepness = torch.sqrt(price_deltas.squeeze(-1) ** 2 + time_deltas.squeeze(-1) ** 2)  # (batch,)
            trend_direction = torch.sign(price_deltas.squeeze(-1))  # (batch,)
            outputs['trend_vector'] = {
                'pivot_prices': pivot_prices,  # (batch, n_levels) - prices for the configured pivot levels
                'price_delta': price_deltas.squeeze(-1),  # (batch,) - price change from first to last level
                'time_delta': time_deltas.squeeze(-1),  # (batch,) - time change
                'calculated_angle': trend_angles.unsqueeze(-1),  # (batch, 1)
                'calculated_steepness': trend_steepness.unsqueeze(-1),  # (batch, 1)
                'calculated_direction': trend_direction.unsqueeze(-1),  # (batch, 1)
                # torch.cat keeps a per-sample vector; the previous torch.stack on
                # squeezed tensors failed for batch sizes greater than 1
                'vector': torch.cat([price_deltas, time_deltas], dim=-1)  # (batch, 2) - [price_delta, time_delta]
            }
else:
outputs['trend_vector'] = {
'pivot_prices': pivot_prices,
'price_delta': torch.zeros(batch_size, device=pooled.device),
'time_delta': torch.zeros(batch_size, device=pooled.device),
'calculated_angle': torch.zeros(batch_size, 1, device=pooled.device),
'calculated_steepness': torch.zeros(batch_size, 1, device=pooled.device),
'calculated_direction': torch.zeros(batch_size, 1, device=pooled.device),
'vector': torch.zeros(batch_size, 2, device=pooled.device)
}
# NEW: Trade action based on trend steepness and angle
# Combine predicted trend analysis with calculated trend vector
predicted_angle = outputs['trend_analysis']['angle_radians'].squeeze() # (batch,)
predicted_steepness = outputs['trend_analysis']['steepness'].squeeze() # (batch,)
predicted_direction = outputs['trend_analysis']['direction'].squeeze() # (batch,)
# Use calculated trend if available, otherwise use predicted
if 'calculated_angle' in outputs['trend_vector']:
trend_angle = outputs['trend_vector']['calculated_angle'].squeeze() # (batch,)
trend_steepness_val = outputs['trend_vector']['calculated_steepness'].squeeze() # (batch,)
else:
trend_angle = predicted_angle
trend_steepness_val = predicted_steepness
# Trade action logic based on trend steepness and angle
# Steep upward trend (> 45 degrees) -> BUY
# Steep downward trend (< -45 degrees) -> SELL
# Shallow trend -> HOLD
angle_threshold = math.pi / 4 # 45 degrees
# Determine action from trend angle
trend_action_logits = torch.zeros(batch_size, 3, device=pooled.device) # [BUY, SELL, HOLD]
# Calculate action probabilities based on trend
for i in range(batch_size):
# Handle both 0-dim and 1-dim tensors
if trend_angle.dim() == 0:
angle = trend_angle.item()
steep = trend_steepness_val.item()
else:
angle = trend_angle[i].item()
steep = trend_steepness_val[i].item()
# Normalize steepness to [0, 1] range (assuming max steepness of 10 units)
normalized_steepness = min(steep / 10.0, 1.0) if steep > 0 else 0.0
if angle > angle_threshold: # Steep upward trend
trend_action_logits[i, 0] = normalized_steepness * 2.0 # BUY
trend_action_logits[i, 2] = (1.0 - normalized_steepness) * 0.5 # HOLD
elif angle < -angle_threshold: # Steep downward trend
trend_action_logits[i, 1] = normalized_steepness * 2.0 # SELL
trend_action_logits[i, 2] = (1.0 - normalized_steepness) * 0.5 # HOLD
else: # Shallow trend
trend_action_logits[i, 2] = 1.0 # HOLD
# Combine trend-based action with main action prediction
trend_action_probs = F.softmax(trend_action_logits, dim=-1)
outputs['trend_based_action'] = {
'logits': trend_action_logits,
'probabilities': trend_action_probs,
'action_idx': torch.argmax(trend_action_probs, dim=-1),
'trend_angle_degrees': trend_angle * 180.0 / math.pi, # Convert to degrees
'trend_steepness': trend_steepness_val
}
# Market regime information
if regime_probs_history:
outputs['regime_probs'] = torch.stack(regime_probs_history, dim=1)
return outputs
def extract_predictions(self, outputs: Dict[str, torch.Tensor], denormalize_prices: Optional[Callable] = None) -> Dict[str, Any]:
"""
Extract predictions from model outputs in a user-friendly format
Args:
outputs: Raw model outputs from forward() method
denormalize_prices: Optional function to denormalize predicted prices
Returns:
Dictionary with formatted predictions including:
- next_candles: Dict[str, Dict] - OHLCV predictions for each timeframe
            - next_pivots: Dict[str, Dict] - Pivot predictions for the configured levels (L1-L3 by default)
- trend_vector: Dict - Trend vector analysis
- trend_based_action: Dict - Trading action based on trend
"""
self.eval()
device = next(self.parameters()).device
predictions = {}
# Extract next candle predictions for each timeframe
if 'next_candles' in outputs:
next_candles = {}
for tf in self.timeframes:
candle_tensor = outputs['next_candles'][tf]
if candle_tensor.dim() > 1:
candle_tensor = candle_tensor[0] # Take first batch item
candle_values = candle_tensor.cpu().detach().numpy() if hasattr(candle_tensor, 'cpu') else candle_tensor
if isinstance(candle_values, np.ndarray):
candle_values = candle_values.tolist()
next_candles[tf] = {
'open': float(candle_values[0]) if len(candle_values) > 0 else 0.0,
'high': float(candle_values[1]) if len(candle_values) > 1 else 0.0,
'low': float(candle_values[2]) if len(candle_values) > 2 else 0.0,
'close': float(candle_values[3]) if len(candle_values) > 3 else 0.0,
'volume': float(candle_values[4]) if len(candle_values) > 4 else 0.0
}
# Denormalize if function provided
if denormalize_prices and callable(denormalize_prices):
for key in ['open', 'high', 'low', 'close']:
next_candles[tf][key] = denormalize_prices(next_candles[tf][key])
predictions['next_candles'] = next_candles
# Extract pivot point predictions
if 'next_pivots' in outputs:
next_pivots = {}
for level in self.pivot_levels:
pivot_data = outputs['next_pivots'][f'L{level}']
# Extract values
price = pivot_data['price']
if price.dim() > 1:
price = price[0, 0] if price.shape[0] > 0 else torch.tensor(0.0, device=device)
price_val = float(price.cpu().detach().item() if hasattr(price, 'cpu') else price)
type_prob_high = pivot_data['type_prob_high']
if type_prob_high.dim() > 1:
type_prob_high = type_prob_high[0, 0] if type_prob_high.shape[0] > 0 else torch.tensor(0.0, device=device)
prob_high = float(type_prob_high.cpu().detach().item() if hasattr(type_prob_high, 'cpu') else type_prob_high)
type_prob_low = pivot_data['type_prob_low']
if type_prob_low.dim() > 1:
type_prob_low = type_prob_low[0, 0] if type_prob_low.shape[0] > 0 else torch.tensor(0.0, device=device)
prob_low = float(type_prob_low.cpu().detach().item() if hasattr(type_prob_low, 'cpu') else type_prob_low)
confidence = pivot_data['confidence']
if confidence.dim() > 1:
confidence = confidence[0, 0] if confidence.shape[0] > 0 else torch.tensor(0.0, device=device)
conf_val = float(confidence.cpu().detach().item() if hasattr(confidence, 'cpu') else confidence)
pivot_type = pivot_data.get('pivot_type', torch.tensor(0))
if isinstance(pivot_type, torch.Tensor):
if pivot_type.dim() > 1:
pivot_type = pivot_type[0, 0] if pivot_type.shape[0] > 0 else torch.tensor(0, device=device)
pivot_type_val = int(pivot_type.cpu().detach().item() if hasattr(pivot_type, 'cpu') else pivot_type)
else:
pivot_type_val = int(pivot_type)
# Denormalize price if function provided
if denormalize_prices and callable(denormalize_prices):
price_val = denormalize_prices(price_val)
next_pivots[f'L{level}'] = {
'price': price_val,
'type': 'high' if pivot_type_val == 0 else 'low',
'type_prob_high': prob_high,
'type_prob_low': prob_low,
'confidence': conf_val
}
predictions['next_pivots'] = next_pivots
# Extract trend vector
if 'trend_vector' in outputs:
trend_vec = outputs['trend_vector']
# Extract pivot prices
            pivot_prices = trend_vec.get('pivot_prices', torch.zeros(len(self.pivot_levels), device=device))
if isinstance(pivot_prices, torch.Tensor):
if pivot_prices.dim() > 1:
pivot_prices = pivot_prices[0]
pivot_prices_list = pivot_prices.cpu().detach().numpy().tolist() if hasattr(pivot_prices, 'cpu') else pivot_prices.tolist()
else:
pivot_prices_list = pivot_prices
# Denormalize pivot prices if function provided
if denormalize_prices and callable(denormalize_prices):
pivot_prices_list = [denormalize_prices(p) for p in pivot_prices_list]
angle = trend_vec.get('calculated_angle', torch.tensor(0.0, device=device))
if isinstance(angle, torch.Tensor):
if angle.dim() > 1:
angle = angle[0, 0] if angle.shape[0] > 0 else torch.tensor(0.0, device=device)
angle_val = float(angle.cpu().detach().item() if hasattr(angle, 'cpu') else angle)
else:
angle_val = float(angle)
steepness = trend_vec.get('calculated_steepness', torch.tensor(0.0, device=device))
if isinstance(steepness, torch.Tensor):
if steepness.dim() > 1:
steepness = steepness[0, 0] if steepness.shape[0] > 0 else torch.tensor(0.0, device=device)
steepness_val = float(steepness.cpu().detach().item() if hasattr(steepness, 'cpu') else steepness)
else:
steepness_val = float(steepness)
direction = trend_vec.get('calculated_direction', torch.tensor(0.0, device=device))
if isinstance(direction, torch.Tensor):
if direction.dim() > 1:
direction = direction[0, 0] if direction.shape[0] > 0 else torch.tensor(0.0, device=device)
direction_val = float(direction.cpu().detach().item() if hasattr(direction, 'cpu') else direction)
else:
direction_val = float(direction)
price_delta = trend_vec.get('price_delta', torch.tensor(0.0, device=device))
if isinstance(price_delta, torch.Tensor):
if price_delta.dim() > 0:
price_delta = price_delta[0] if price_delta.shape[0] > 0 else torch.tensor(0.0, device=device)
price_delta_val = float(price_delta.cpu().detach().item() if hasattr(price_delta, 'cpu') else price_delta)
else:
price_delta_val = float(price_delta)
predictions['trend_vector'] = {
                'pivot_prices': pivot_prices_list,  # One price per configured pivot level (L1-L3 by default)
'angle_radians': angle_val,
'angle_degrees': angle_val * 180.0 / math.pi,
'steepness': steepness_val,
'direction': 'up' if direction_val > 0 else 'down' if direction_val < 0 else 'sideways',
'price_delta': price_delta_val
}
# Extract trend-based action
if 'trend_based_action' in outputs:
trend_action = outputs['trend_based_action']
action_probs = trend_action.get('probabilities', torch.zeros(3, device=device))
if isinstance(action_probs, torch.Tensor):
if action_probs.dim() > 1:
action_probs = action_probs[0]
action_probs_list = action_probs.cpu().detach().numpy().tolist() if hasattr(action_probs, 'cpu') else action_probs.tolist()
else:
action_probs_list = action_probs
action_idx = trend_action.get('action_idx', torch.tensor(2, device=device))
if isinstance(action_idx, torch.Tensor):
if action_idx.dim() > 0:
action_idx = action_idx[0] if action_idx.shape[0] > 0 else torch.tensor(2, device=device)
action_idx_val = int(action_idx.cpu().detach().item() if hasattr(action_idx, 'cpu') else action_idx)
else:
action_idx_val = int(action_idx)
angle_degrees = trend_action.get('trend_angle_degrees', torch.tensor(0.0, device=device))
if isinstance(angle_degrees, torch.Tensor):
if angle_degrees.dim() > 0:
angle_degrees = angle_degrees[0] if angle_degrees.shape[0] > 0 else torch.tensor(0.0, device=device)
angle_degrees_val = float(angle_degrees.cpu().detach().item() if hasattr(angle_degrees, 'cpu') else angle_degrees)
else:
angle_degrees_val = float(angle_degrees)
steepness = trend_action.get('trend_steepness', torch.tensor(0.0, device=device))
if isinstance(steepness, torch.Tensor):
if steepness.dim() > 0:
steepness = steepness[0] if steepness.shape[0] > 0 else torch.tensor(0.0, device=device)
steepness_val = float(steepness.cpu().detach().item() if hasattr(steepness, 'cpu') else steepness)
else:
steepness_val = float(steepness)
action_names = ['BUY', 'SELL', 'HOLD']
predictions['trend_based_action'] = {
'action': action_names[action_idx_val] if 0 <= action_idx_val < len(action_names) else 'HOLD',
'action_idx': action_idx_val,
'probabilities': {
'BUY': float(action_probs_list[0]) if len(action_probs_list) > 0 else 0.0,
'SELL': float(action_probs_list[1]) if len(action_probs_list) > 1 else 0.0,
'HOLD': float(action_probs_list[2]) if len(action_probs_list) > 2 else 0.0
},
'trend_angle_degrees': angle_degrees_val,
'trend_steepness': steepness_val
}
return predictions
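# End-to-end inference sketch (illustrative; the 50000.0 scale factor is a
# placeholder for whatever denormalizer the data pipeline provides). Input
# shapes follow the forward() docstring.
def _example_inference() -> None:
    config = TradingTransformerConfig()
    model = AdvancedTradingTransformer(config)
    model.eval()
    with torch.no_grad():
        outputs = model(
            price_data_1m=torch.rand(1, config.seq_len, 5),  # Normalized OHLCV
            tech_data=torch.rand(1, config.tech_features),
            position_state=torch.rand(1, 5),
        )
    predictions = model.extract_predictions(outputs, denormalize_prices=lambda p: p * 50000.0)
    print(predictions['trend_based_action']['action'])  # 'BUY' / 'SELL' / 'HOLD'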
class TradingTransformerTrainer:
"""Trainer for the advanced trading transformer"""
def __init__(self, model: AdvancedTradingTransformer, config: TradingTransformerConfig):
self.model = model
self.config = config
# Determine device from config or auto-detect
self.device = self._get_device_from_config()
# Move model to device
self.model.to(self.device)
logger.info(f"Model moved to device: {self.device}")
# Log GPU info if available
if self.device.type == 'cuda' and torch.cuda.is_available():
logger.info(f" GPU: {torch.cuda.get_device_name(0)}")
logger.info(f" GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
elif self.device.type == 'cpu':
logger.info(" Using CPU (GPU disabled or unavailable)")
# MEMORY OPTIMIZATION: Enable gradient checkpointing if configured
# This trades 20% compute for 30-40% memory savings
if self.config.use_gradient_checkpointing:
logger.info("Enabling gradient checkpointing for memory efficiency")
self._enable_gradient_checkpointing()
# Mixed precision training disabled - causes dtype mismatches
# Can be re-enabled if needed, but requires careful dtype management
self.use_amp = False
self.scaler = None
# GRADIENT ACCUMULATION: Track accumulation state
self.gradient_accumulation_steps = 0
self.current_accumulation_step = 0
# Optimizer with warmup
self.optimizer = optim.AdamW(
self.model.parameters(),
lr=self.config.learning_rate,
weight_decay=self.config.weight_decay
)
# Learning rate scheduler
self.scheduler = optim.lr_scheduler.OneCycleLR(
self.optimizer,
max_lr=self.config.learning_rate,
total_steps=10000, # Will be updated based on training data
pct_start=0.1
)
# Loss functions with class weights
# Pivot-based training: BUY at L pivots, SELL at H pivots (naturally balanced)
# Weights: [HOLD=0, BUY=1, SELL=2] - equal weighting for pivot-based trades
class_weights = torch.tensor([0.5, 1.0, 1.0], dtype=torch.float32, device=self.device)
self.action_criterion = nn.CrossEntropyLoss(weight=class_weights)
self.price_criterion = nn.MSELoss()
self.confidence_criterion = nn.BCELoss()
# Training history
self.training_history = {
'train_loss': [],
'val_loss': [],
'train_accuracy': [],
'val_accuracy': [],
'epochs': []
}
def _get_device_from_config(self) -> torch.device:
"""Get device from config.yaml or auto-detect"""
try:
# Try to load config
from core.config import get_config
config = get_config()
gpu_config = config._config.get('gpu', {})
device_setting = gpu_config.get('device', 'auto')
fallback_to_cpu = gpu_config.get('fallback_to_cpu', True)
gpu_enabled = gpu_config.get('enabled', True)
# If GPU is disabled in config, use CPU
if not gpu_enabled:
logger.info("GPU disabled in config.yaml, using CPU")
return torch.device('cpu')
# Handle device selection
if device_setting == 'cpu':
logger.info("Device set to CPU in config.yaml")
return torch.device('cpu')
elif device_setting == 'cuda' or device_setting == 'auto':
# Try GPU first
if torch.cuda.is_available():
logger.info("Using GPU (CUDA available)")
return torch.device('cuda')
else:
if fallback_to_cpu:
logger.warning("CUDA not available, falling back to CPU")
return torch.device('cpu')
else:
raise RuntimeError("CUDA not available and fallback_to_cpu is False")
else:
logger.warning(f"Unknown device setting '{device_setting}', using auto-detection")
return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
except Exception as e:
logger.warning(f"Error reading device config: {e}, using auto-detection")
# Fallback to auto-detection
return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def _enable_gradient_checkpointing(self):
"""Enable gradient checkpointing on transformer layers to save memory"""
try:
            # Use a factory so each wrapper closes over its own function; a bare
            # closure over the loop variable would late-bind and make every
            # layer checkpoint the last layer's forward
            def make_checkpointed(fn):
                def wrapper(*args, **kwargs):
                    return checkpoint(fn, *args, use_reentrant=False, **kwargs)
                return wrapper
            # Apply checkpointing to each transformer layer
            for layer in self.model.layers:
                if hasattr(layer, 'attention'):
                    layer.attention.forward = make_checkpointed(layer.attention.forward)
                if hasattr(layer, 'feed_forward'):
                    layer.feed_forward.forward = make_checkpointed(layer.feed_forward.forward)
logger.info("Gradient checkpointing enabled on all transformer layers")
except Exception as e:
logger.warning(f"Failed to enable gradient checkpointing: {e}")
def set_gradient_accumulation_steps(self, steps: int):
"""
Set the number of gradient accumulation steps
Args:
steps: Number of batches to accumulate gradients over before optimizer step
For example, steps=5 means process 5 batches, then update weights
"""
self.gradient_accumulation_steps = steps
self.current_accumulation_step = 0
logger.info(f"Gradient accumulation enabled: {steps} steps")
def reset_gradient_accumulation(self):
"""Reset gradient accumulation counter"""
self.current_accumulation_step = 0
@staticmethod
def denormalize_prices(normalized_values: torch.Tensor, norm_params: Dict[str, float]) -> torch.Tensor:
"""
Denormalize price predictions back to real price space
Args:
normalized_values: Tensor of normalized values in [0, 1] range
norm_params: Dict with 'price_min' and 'price_max' keys
Returns:
Denormalized tensor in original price space
"""
price_min = norm_params.get('price_min', 0.0)
price_max = norm_params.get('price_max', 1.0)
if price_max > price_min:
return normalized_values * (price_max - price_min) + price_min
else:
return normalized_values
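    # Worked example (illustrative bounds): with price_min=40000 and
    # price_max=60000, a normalized value of 0.5 maps back to
    # 0.5 * (60000 - 40000) + 40000 = 50000.
    #
    #     norm_params = {'price_min': 40000.0, 'price_max': 60000.0}
    #     TradingTransformerTrainer.denormalize_prices(torch.tensor([0.5]), norm_params)
    #     # -> tensor([50000.])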
@staticmethod
def denormalize_candle(normalized_candle: torch.Tensor, norm_params: Dict[str, float]) -> torch.Tensor:
"""
Denormalize a full OHLCV candle back to real values
Args:
normalized_candle: Tensor of shape [..., 5] with normalized OHLCV
norm_params: Dict with normalization parameters
Returns:
Denormalized OHLCV tensor
"""
# Avoid inplace operations by creating new tensors instead of slice assignment
price_min = norm_params.get('price_min', 0.0)
price_max = norm_params.get('price_max', 1.0)
volume_min = norm_params.get('volume_min', 0.0)
volume_max = norm_params.get('volume_max', 1.0)
# Denormalize OHLC (first 4 values) - create new tensor, no inplace operations
if price_max > price_min:
price_scale = (price_max - price_min)
price_offset = price_min
denorm_ohlc = normalized_candle[..., :4] * price_scale + price_offset
else:
denorm_ohlc = normalized_candle[..., :4]
# Denormalize volume (5th value) - create new tensor, no inplace operations
if volume_max > volume_min:
volume_scale = (volume_max - volume_min)
volume_offset = volume_min
denorm_volume = (normalized_candle[..., 4:5] * volume_scale + volume_offset)
else:
denorm_volume = normalized_candle[..., 4:5]
# Concatenate OHLC and Volume to create final tensor (no inplace operations)
denorm = torch.cat([denorm_ohlc, denorm_volume], dim=-1)
return denorm
def train_step(self, batch: Dict[str, torch.Tensor], accumulate_gradients: bool = False) -> Dict[str, float]:
"""Single training step with optional gradient accumulation
Args:
batch: Training batch
accumulate_gradients: If True, don't zero gradients or step optimizer (for gradient accumulation)
This is DEPRECATED - use gradient_accumulation_steps instead
Returns:
Dictionary with loss and accuracy metrics
"""
try:
self.model.train()
# Enable anomaly detection temporarily to debug inplace operation issues
# NOTE: This significantly slows down training (2-3x slower), use only for debugging
# Set to True to find exact inplace operation causing errors
enable_anomaly_detection = False # DISABLED - inplace operations fixed
if enable_anomaly_detection:
torch.autograd.set_detect_anomaly(True)
# GRADIENT ACCUMULATION: Determine if this is an accumulation step
# If gradient_accumulation_steps is set, use automatic accumulation
# Otherwise, fall back to manual accumulate_gradients flag
if self.gradient_accumulation_steps > 0:
is_accumulation_step = (self.current_accumulation_step < self.gradient_accumulation_steps - 1)
self.current_accumulation_step += 1
# Reset counter after full accumulation cycle
if self.current_accumulation_step >= self.gradient_accumulation_steps:
self.current_accumulation_step = 0
else:
is_accumulation_step = accumulate_gradients
# Only zero gradients at the start of accumulation cycle
# Use set_to_none=True for better memory efficiency (saves ~5% memory)
if not is_accumulation_step or self.current_accumulation_step == 1:
self.optimizer.zero_grad(set_to_none=True)
                # set_to_none=True already frees per-parameter gradients, so no
                # extra per-parameter clearing is needed here.
                # Releasing cached allocator blocks every step is slow; retained
                # for memory-constrained GPUs
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            # Keep the model and trainer on the same device before moving data
            model_device = next(self.model.parameters()).device
            if model_device != self.device:
                logger.warning(f"Model device ({model_device}) doesn't match trainer device ({self.device}). Moving model to {self.device}")
                self.model.to(self.device)
                model_device = self.device
            # OPTIMIZATION: Build a new batch dict with every tensor on the model's
            # device; tensors already there are reused without a copy. Never mutate
            # the input batch - it may be reused across epochs of training.
            batch_on_device = {
                k: (v.to(model_device, non_blocking=True)
                    if isinstance(v, torch.Tensor) and v.device != model_device else v)
                for k, v in batch.items()
            }
# Use automatic mixed precision (FP16) for memory efficiency
# Support both CUDA and ROCm (AMD) devices
device_type = 'cuda' if self.device.type == 'cuda' else 'cpu'
with torch.amp.autocast(device_type, enabled=self.use_amp and device_type != 'cpu'):
# Forward pass with multi-timeframe data
outputs = self.model(
price_data_1s=batch_on_device.get('price_data_1s'),
price_data_1m=batch_on_device.get('price_data_1m'),
price_data_1h=batch_on_device.get('price_data_1h'),
price_data_1d=batch_on_device.get('price_data_1d'),
btc_data_1m=batch_on_device.get('btc_data_1m'),
cob_data=batch_on_device.get('cob_data'), # Use .get() to handle missing key
tech_data=batch_on_device.get('tech_data'),
market_data=batch_on_device.get('market_data'),
position_state=batch_on_device.get('position_state'),
price_data=batch_on_device.get('price_data') # Legacy fallback
)
# Calculate losses (use batch_on_device for consistency)
# Handle case where actions key is missing (e.g., when no timeframe data available)
if 'actions' not in batch_on_device:
logger.warning("No 'actions' key in batch - skipping this training step")
return {
'total_loss': 0.0,
'action_loss': 0.0,
'price_loss': 0.0,
'accuracy': 0.0,
'candle_accuracy': 0.0,
'trend_accuracy': 0.0,
'action_accuracy': 0.0
}
action_loss = self.action_criterion(outputs['action_logits'], batch_on_device['actions'])
# FIXED: Ensure shapes match for MSELoss
price_pred = outputs['price_prediction']
# Handle case where future_prices key is missing
if 'future_prices' not in batch_on_device:
logger.warning("No 'future_prices' key in batch - using zero loss for price prediction")
price_loss = torch.tensor(0.0, device=self.device)
else:
price_target = batch_on_device['future_prices']
# Both should be [batch, 1], but ensure they match
if price_pred.shape != price_target.shape:
logger.debug(f"Reshaping price target from {price_target.shape} to {price_pred.shape}")
price_target = price_target.view(price_pred.shape)
price_loss = self.price_criterion(price_pred, price_target)
# NEW: Trend analysis loss (if trend_target provided)
trend_loss = torch.tensor(0.0, device=self.device)
if 'trend_target' in batch_on_device and 'trend_analysis' in outputs:
trend_pred = torch.cat([
outputs['trend_analysis']['angle_radians'],
outputs['trend_analysis']['steepness'],
outputs['trend_analysis']['direction']
], dim=1) # [batch, 3]
trend_target = batch_on_device['trend_target']
if trend_pred.shape == trend_target.shape:
trend_loss = self.price_criterion(trend_pred, trend_target)
logger.debug(f"Trend loss: {trend_loss.item():.6f} (pred={trend_pred[0].tolist()}, target={trend_target[0].tolist()})")
# NEW: Next candle prediction loss for each timeframe
# This trains the model to predict full OHLCV for the next candle on each timeframe
candle_loss = torch.tensor(0.0, device=self.device)
candle_losses_detail = {} # Track per-timeframe losses (normalized space)
candle_losses_denorm = {} # Track per-timeframe losses (denormalized/real space)
if 'next_candles' in outputs:
timeframe_losses = []
# Get normalization parameters if available
# norm_params may be a dict or a list of dicts (one per sample in batch)
norm_params_raw = batch_on_device.get('norm_params', {})
if isinstance(norm_params_raw, list) and len(norm_params_raw) > 0:
# If it's a list, use the first one (batch size is typically 1)
norm_params = norm_params_raw[0]
else:
norm_params = norm_params_raw if isinstance(norm_params_raw, dict) else {}
# Calculate loss for each timeframe that has target data
for tf in ['1s', '1m', '1h', '1d']:
future_key = f'future_candle_{tf}'
if tf in outputs['next_candles'] and future_key in batch_on_device:
pred_candle = outputs['next_candles'][tf] # [batch, 5] - predicted OHLCV (normalized)
target_candle = batch_on_device[future_key] # [batch, 5] - actual OHLCV (normalized)
if target_candle is not None and pred_candle.shape == target_candle.shape:
# MSE loss on normalized values (used for backprop)
tf_loss = self.price_criterion(pred_candle, target_candle)
timeframe_losses.append(tf_loss)
candle_losses_detail[tf] = tf_loss.item()
# ALSO calculate denormalized loss for better interpretability
# Use RMSE (Root Mean Square Error) instead of MSE for realistic values
if tf in norm_params:
with torch.no_grad():
pred_denorm = self.denormalize_candle(pred_candle, norm_params[tf])
target_denorm = self.denormalize_candle(target_candle, norm_params[tf])
# Use RMSE instead of MSE to get interpretable dollar values
mse = torch.mean((pred_denorm - target_denorm) ** 2)
rmse = torch.sqrt(mse + 1e-8) # Add epsilon for numerical stability
candle_losses_denorm[tf] = rmse.item()
# Average loss across available timeframes
if timeframe_losses:
candle_loss = torch.stack(timeframe_losses).mean()
if candle_losses_denorm:
logger.debug(f"Candle losses (normalized): {candle_losses_detail}")
logger.debug(f"Candle losses (real prices): {candle_losses_denorm}")
# Start with base losses - avoid inplace operations on computation graph
# Weight: action=1.0, price=0.1, trend=0.05, candle=0.15
total_loss = action_loss + 0.1 * price_loss + 0.05 * trend_loss + 0.15 * candle_loss
# CRITICAL FIX: Scale loss for gradient accumulation
# This prevents gradient explosion when accumulating over multiple batches
# The loss is averaged over accumulation steps so gradients sum correctly
if self.gradient_accumulation_steps > 0:
total_loss = total_loss / self.gradient_accumulation_steps
elif accumulate_gradients:
# Legacy fallback - assume 5 steps if not specified
total_loss = total_loss / 5.0
# Add confidence loss if available
if 'confidence' in outputs and 'trade_success' in batch_on_device:
# Both tensors should have shape [batch_size, 1] for BCELoss
confidence_pred = outputs['confidence']
trade_target = batch_on_device['trade_success'].float()
# FIXED: Ensure both are 2D tensors [batch_size, 1]
# Handle different input shapes robustly
if confidence_pred.dim() == 0:
# Scalar -> [1, 1]
confidence_pred = confidence_pred.unsqueeze(0).unsqueeze(0)
elif confidence_pred.dim() == 1:
# [batch_size] -> [batch_size, 1]
confidence_pred = confidence_pred.unsqueeze(-1)
elif confidence_pred.dim() == 3:
# [batch_size, seq_len, 1] -> [batch_size, 1] (take last timestep)
confidence_pred = confidence_pred[:, -1, :]
if trade_target.dim() == 0:
# Scalar -> [1, 1]
trade_target = trade_target.unsqueeze(0).unsqueeze(0)
elif trade_target.dim() == 1:
# [batch_size] -> [batch_size, 1]
trade_target = trade_target.unsqueeze(-1)
# Ensure shapes match exactly - BCELoss requires exact match
if confidence_pred.shape != trade_target.shape:
# Reshape trade_target to match confidence_pred; reshape (unlike view) also works on non-contiguous tensors
trade_target = trade_target.reshape(confidence_pred.shape)
confidence_loss = self.confidence_criterion(confidence_pred, trade_target)
# Scale the confidence term to match the base losses: it is added after the accumulation scaling above, so divide by the same factor
conf_scale = self.gradient_accumulation_steps if self.gradient_accumulation_steps > 0 else (5.0 if accumulate_gradients else 1.0)
# Use addition instead of += to avoid an inplace operation
total_loss = total_loss + (0.1 * confidence_loss) / conf_scale
# Backward pass with mixed precision scaling
try:
if self.use_amp:
self.scaler.scale(total_loss).backward()
else:
total_loss.backward()
except RuntimeError as e:
error_msg = str(e)
if "inplace operation" in error_msg or "modified by an inplace operation" in error_msg:
logger.error(f"Inplace operation error during backward pass: {e}")
# Clear gradients to reset state
self.optimizer.zero_grad(set_to_none=True)
for param in self.model.parameters():
if param.grad is not None:
param.grad = None
# Return zero loss to continue training
return {
'total_loss': 0.0,
'action_loss': 0.0,
'price_loss': 0.0,
'accuracy': 0.0,
'learning_rate': self.scheduler.get_last_lr()[0] if hasattr(self, 'scheduler') else 0.0
}
else:
raise
# Only clip gradients and step optimizer at the end of accumulation cycle
if not is_accumulation_step:
if self.use_amp:
# Unscale gradients before clipping
self.scaler.unscale_(self.optimizer)
# Gradient clipping
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
# Optimizer step with scaling
self.scaler.step(self.optimizer)
self.scaler.update()
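# scaler.step() silently skips the optimizer update if unscale_() found inf/NaN gradients, and scaler.update() then lowers the loss scale for the next iteration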
else:
# Gradient clipping
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
# Optimizer step with error handling
try:
self.optimizer.step()
except (KeyError, RuntimeError) as opt_error:
error_msg = str(opt_error).lower()
if "cuda" in error_msg or "gpu" in error_msg:
# GPU-related failure: try to fall back to CPU if configured
logger.error(f"GPU error during optimizer step: {opt_error}")
try:
from core.config import get_config
config = get_config()
fallback_to_cpu = config._config.get('gpu', {}).get('fallback_to_cpu', True)
if fallback_to_cpu and self.device.type == 'cuda':
logger.warning("Falling back to CPU due to GPU errors")
self.device = torch.device('cpu')
self.model.to(self.device)
# Recreate optimizer for the CPU copies of the parameters
self.optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=self.config.learning_rate,
weight_decay=self.config.weight_decay
)
logger.info("Model moved to CPU, training will continue on CPU")
# Skip this step and report a zero-loss result for this batch
# (the full result dict is only built after a successful step)
return {
'total_loss': 0.0,
'action_loss': 0.0,
'price_loss': 0.0,
'accuracy': 0.0,
'candle_accuracy': 0.0,
'learning_rate': self.scheduler.get_last_lr()[0] if hasattr(self, 'scheduler') else 0.0
}
except Exception as fallback_error:
logger.error(f"Failed to fallback to CPU: {fallback_error}")
raise
else:
logger.error(f"Optimizer step failed: {opt_error}. Resetting optimizer state.")
# Zero gradients first to clear any stale gradients
self.optimizer.zero_grad(set_to_none=True)
# Reset optimizer to fix corrupted state
self.optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=self.config.learning_rate,
weight_decay=self.config.weight_decay
)
# Zero gradients again after recreating optimizer
self.optimizer.zero_grad(set_to_none=True)
# The loss/backward would have to be recomputed to retry, so skip this step
logger.warning("Skipping optimizer step after reset - gradients need to be recomputed")
# Don't raise - allow training to continue with the next batch
self.scheduler.step()
# Log gradient accumulation completion
if self.gradient_accumulation_steps > 0:
logger.debug(f"Gradient accumulation cycle complete ({self.gradient_accumulation_steps} steps)")
# Calculate accuracy without gradients
# PRIMARY: next-candle OHLCV prediction accuracy (RMSE on the model's normalized scale, expressed relative to the batch price range)
with torch.no_grad():
candle_accuracy = 0.0
candle_rmse = {}
if 'next_candles' in outputs:
# Use 1s or 1m timeframe as primary metric (try 1s first)
if '1s' in outputs['next_candles'] and 'future_candle_1s' in batch_on_device:
pred_candle = outputs['next_candles']['1s'] # [batch, 5]
actual_candle = batch_on_device['future_candle_1s'] # [batch, 5]
elif '1m' in outputs['next_candles'] and 'future_candle_1m' in batch_on_device:
pred_candle = outputs['next_candles']['1m'] # [batch, 5]
actual_candle = batch_on_device['future_candle_1m'] # [batch, 5]
else:
pred_candle = None
actual_candle = None
if actual_candle is not None and pred_candle is not None and pred_candle.shape == actual_candle.shape:
# Calculate RMSE for each OHLCV component
rmse_open = torch.sqrt(torch.mean((pred_candle[:, 0] - actual_candle[:, 0])**2) + 1e-8)
rmse_high = torch.sqrt(torch.mean((pred_candle[:, 1] - actual_candle[:, 1])**2) + 1e-8)
rmse_low = torch.sqrt(torch.mean((pred_candle[:, 2] - actual_candle[:, 2])**2) + 1e-8)
rmse_close = torch.sqrt(torch.mean((pred_candle[:, 3] - actual_candle[:, 3])**2) + 1e-8)
# Average RMSE for OHLC (exclude volume)
avg_rmse = (rmse_open + rmse_high + rmse_low + rmse_close) / 4
# Convert to accuracy: lower RMSE = higher accuracy
# Normalize by price range
price_range = torch.clamp(actual_candle[:, 1].max() - actual_candle[:, 2].min(), min=1e-8)
candle_accuracy = (1.0 - torch.clamp(avg_rmse / price_range, 0, 1)).item()
candle_rmse = {
'open': rmse_open.item(),
'high': rmse_high.item(),
'low': rmse_low.item(),
'close': rmse_close.item(),
'avg': avg_rmse.item()
}
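# Example: an average OHLC RMSE of 0.5 against a batch high-low range of 10.0 yields candle_accuracy = 1 - 0.05 = 0.95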
# SECONDARY: Trend vector prediction accuracy
trend_accuracy = 0.0
if 'trend_analysis' in outputs and 'trend_target' in batch_on_device:
pred_angle = outputs['trend_analysis']['angle_radians']
pred_steepness = outputs['trend_analysis']['steepness']
actual_angle = batch_on_device['trend_target'][:, 0:1]
actual_steepness = batch_on_device['trend_target'][:, 1:2]
# Angle error (degrees)
angle_error_rad = torch.abs(pred_angle - actual_angle)
angle_error_deg = angle_error_rad * 180.0 / math.pi
angle_accuracy = (1.0 - torch.clamp(angle_error_deg / 180.0, 0, 1)).mean()
# Steepness error (percentage)
steepness_error = torch.abs(pred_steepness - actual_steepness) / (actual_steepness + 1e-8)
steepness_accuracy = (1.0 - torch.clamp(steepness_error, 0, 1)).mean()
trend_accuracy = ((angle_accuracy + steepness_accuracy) / 2).item()
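# Example: an 18 degree angle error scores 0.9, a 20% relative steepness error scores 0.8, so trend_accuracy = 0.85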
# LEGACY: Action accuracy (for comparison)
if 'actions' in batch_on_device:
action_predictions = torch.argmax(outputs['action_logits'], dim=-1)
action_accuracy = (action_predictions == batch_on_device['actions']).float().mean().item()
else:
action_predictions = None # keep defined so the unconditional cleanup below can delete it
action_accuracy = 0.0
# Extract values and delete tensors to free memory
result = {
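# NOTE: when gradient accumulation is active, 'total_loss' is the accumulation-scaled value (divided by the accumulation steps); the per-head losses below are unscaled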
'total_loss': total_loss.item(),
'action_loss': action_loss.item(),
'price_loss': price_loss.item(),
'trend_loss': trend_loss.item() if isinstance(trend_loss, torch.Tensor) else 0.0,
'candle_loss': candle_loss.item() if isinstance(candle_loss, torch.Tensor) else 0.0,
'candle_loss_denorm': candle_losses_denorm, # Dict of denormalized losses per timeframe
# NEW: Realistic accuracy metrics based on next candle prediction
'accuracy': candle_accuracy, # PRIMARY: Next candle prediction accuracy
'candle_accuracy': candle_accuracy, # Same as accuracy
'candle_rmse': candle_rmse, # Detailed RMSE per OHLC component
'trend_accuracy': trend_accuracy, # Trend vector accuracy (angle + steepness)
'action_accuracy': action_accuracy, # Legacy action accuracy
'learning_rate': self.scheduler.get_last_lr()[0]
}
# CRITICAL: Delete large tensors to free memory immediately
# This prevents memory accumulation across batches
del outputs, total_loss, action_loss, price_loss, trend_loss, candle_loss, action_predictions
# Delete the GPU copies of the batch tensors; they live in batch_on_device, not in the caller's original batch dict
for key in list(batch_on_device.keys()):
if isinstance(batch_on_device[key], torch.Tensor):
del batch_on_device[key]
del batch_on_device, batch
# Clear CUDA cache and log GPU memory usage
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Log GPU memory usage periodically (every 10 steps)
if not hasattr(self, '_step_counter'):
self._step_counter = 0
self._step_counter += 1
if self._step_counter % 10 == 0:
allocated = torch.cuda.memory_allocated() / 1024**2
reserved = torch.cuda.memory_reserved() / 1024**2
logger.debug(f"GPU Memory: {allocated:.1f}MB allocated, {reserved:.1f}MB reserved")
return result
except torch.cuda.OutOfMemoryError as oom_error:
logger.error(f"CUDA OOM in train_step: {oom_error}")
# Aggressive cleanup on OOM
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
# Reset optimizer state to prevent corruption
self.optimizer.zero_grad(set_to_none=True)
# Return zero loss to continue training
return {
'total_loss': 0.0,
'action_loss': 0.0,
'price_loss': 0.0,
'accuracy': 0.0,
'candle_accuracy': 0.0,
'learning_rate': self.scheduler.get_last_lr()[0] if hasattr(self, 'scheduler') else 0.0
}
except Exception as e:
logger.error(f"Error in train_step: {e}", exc_info=True)
# Clear any partial computations
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Return a zero loss dict to prevent training from crashing
# but log the error so we can debug
return {
'total_loss': 0.0,
'action_loss': 0.0,
'price_loss': 0.0,
'accuracy': 0.0,
'candle_accuracy': 0.0,
'learning_rate': self.scheduler.get_last_lr()[0] if hasattr(self, 'scheduler') else 0.0
}
def validate(self, val_loader: DataLoader) -> Dict[str, float]:
"""Validation step"""
self.model.eval()
total_loss = 0
total_accuracy = 0
num_batches = 0
with torch.no_grad():
for batch in val_loader:
# Move only tensors to the device; any metadata entries (e.g. norm_params) stay as-is
batch = {k: (v.to(self.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
outputs = self.model(
batch['price_data'],
batch['cob_data'],
batch['tech_data'],
batch['market_data']
)
# Calculate losses
action_loss = self.action_criterion(outputs['action_logits'], batch['actions'])
price_loss = self.price_criterion(outputs['price_prediction'], batch['future_prices'])
total_loss += action_loss.item() + 0.1 * price_loss.item()
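# Note: validation scores only the action and price heads; the auxiliary candle/trend/confidence losses used in train_step are not included here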
# Calculate accuracy
predictions = torch.argmax(outputs['action_logits'], dim=-1)
accuracy = (predictions == batch['actions']).float().mean()
total_accuracy += accuracy.item()
num_batches += 1
num_batches = max(num_batches, 1) # guard against an empty validation loader
return {
'val_loss': total_loss / num_batches,
'val_accuracy': total_accuracy / num_batches
}
def train(self, train_loader: DataLoader, val_loader: DataLoader,
epochs: int, save_path: str = "NN/models/saved/"):
"""Full training loop"""
best_val_loss = float('inf')
for epoch in range(epochs):
# Training
epoch_losses = []
epoch_accuracies = []
for batch in train_loader:
metrics = self.train_step(batch)
epoch_losses.append(metrics['total_loss'])
epoch_accuracies.append(metrics['accuracy'])
# Validation
val_metrics = self.validate(val_loader)
# Update history
avg_train_loss = np.mean(epoch_losses)
avg_train_accuracy = np.mean(epoch_accuracies)
self.training_history['train_loss'].append(avg_train_loss)
self.training_history['val_loss'].append(val_metrics['val_loss'])
self.training_history['train_accuracy'].append(avg_train_accuracy)
self.training_history['val_accuracy'].append(val_metrics['val_accuracy'])
self.training_history['learning_rates'].append(self.scheduler.get_last_lr()[0])
# Logging
logger.info(f"Epoch {epoch+1}/{epochs}")
logger.info(f" Train Loss: {avg_train_loss:.4f}, Train Acc: {avg_train_accuracy:.4f}")
logger.info(f" Val Loss: {val_metrics['val_loss']:.4f}, Val Acc: {val_metrics['val_accuracy']:.4f}")
logger.info(f" LR: {self.scheduler.get_last_lr()[0]:.6f}")
# Save best model
if val_metrics['val_loss'] < best_val_loss:
best_val_loss = val_metrics['val_loss']
self.save_model(os.path.join(save_path, 'best_transformer_model.pt'))
logger.info(f" New best model saved (val_loss: {best_val_loss:.4f})")
def save_model(self, path: str):
"""Save model and training state"""
os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
torch.save({
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'scheduler_state_dict': self.scheduler.state_dict(),
'config': self.config,
'training_history': self.training_history
}, path)
logger.info(f"Model saved to {path}")
def load_model(self, path: str):
"""Load model and training state"""
# The checkpoint stores a dataclass config and training history, so full unpickling is required; only load checkpoints from trusted sources
checkpoint = torch.load(path, map_location=self.device, weights_only=False)
# Load model state (with strict=False to handle architecture changes)
try:
self.model.load_state_dict(checkpoint['model_state_dict'], strict=False)
except Exception as e:
logger.warning(f"Error loading model state dict: {e}, continuing with partial load")
# Load optimizer state (handle mismatched states gracefully)
# IMPORTANT: Always recreate optimizer if there's any issue to avoid corrupted state
optimizer_state_loaded = False
try:
optimizer_state = checkpoint.get('optimizer_state_dict')
if optimizer_state:
# Validate optimizer state before loading
# Check if state dict has the expected structure
if 'state' in optimizer_state and 'param_groups' in optimizer_state:
# Count parameters in saved state vs current model
saved_param_count = len(optimizer_state.get('state', {}))
current_param_count = sum(1 for p in self.model.parameters() if p.requires_grad)
if saved_param_count == current_param_count:
try:
# Try to load optimizer state
self.optimizer.load_state_dict(optimizer_state)
optimizer_state_loaded = True
logger.info("Optimizer state loaded successfully")
except (KeyError, ValueError, RuntimeError, TypeError) as e:
logger.warning(f"Error loading optimizer state: {e}. State will be reset.")
optimizer_state_loaded = False
else:
logger.warning(f"Optimizer state mismatch: {saved_param_count} saved params vs {current_param_count} current params. Resetting optimizer.")
optimizer_state_loaded = False
else:
logger.warning("Invalid optimizer state structure in checkpoint. Resetting optimizer.")
optimizer_state_loaded = False
else:
logger.info("No optimizer state found in checkpoint. Using fresh optimizer.")
optimizer_state_loaded = False
except Exception as e:
logger.warning(f"Error loading optimizer state: {e}. Resetting optimizer.")
optimizer_state_loaded = False
# Always recreate optimizer if state loading failed
if not optimizer_state_loaded:
logger.info("Creating fresh optimizer (checkpoint state was invalid or missing)")
self.optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=self.config.learning_rate,
weight_decay=self.config.weight_decay
)
# Also recreate scheduler to match
self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
self.optimizer,
max_lr=self.config.learning_rate,
total_steps=10000,
pct_start=0.1
)
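# Note: OneCycleLR raises once stepped past total_steps; 10000 is a fresh-schedule assumption here, not a value recovered from the checkpoint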
# Load scheduler state
try:
scheduler_state = checkpoint.get('scheduler_state_dict')
if scheduler_state:
self.scheduler.load_state_dict(scheduler_state)
except Exception as e:
logger.warning(f"Error loading scheduler state: {e}, continuing without scheduler state")
self.training_history = checkpoint.get('training_history', self.training_history)
logger.info(f"Model loaded from {path}")
def create_trading_transformer(config: Optional[TradingTransformerConfig] = None) -> Tuple[AdvancedTradingTransformer, TradingTransformerTrainer]:
"""Factory function to create trading transformer and trainer"""
if config is None:
config = TradingTransformerConfig()
model = AdvancedTradingTransformer(config)
trainer = TradingTransformerTrainer(model, config)
return model, trainer
# Example usage
if __name__ == "__main__":
# Create configuration
config = TradingTransformerConfig(
d_model=256,
n_heads=8,
n_layers=4,
seq_len=50,
n_actions=3,
use_multi_scale_attention=True,
use_market_regime_detection=True,
use_uncertainty_estimation=True
)
# Create model and trainer
model, trainer = create_trading_transformer(config)
logger.info(f"Created Advanced Trading Transformer with {sum(p.numel() for p in model.parameters())} parameters")
logger.info("Model is ready for training on real market data!")