#!/usr/bin/env python3
"""
Advanced Transformer Models for High-Frequency Trading
Optimized for COB data, technical indicators, and market microstructure
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.checkpoint import checkpoint
import numpy as np
import math
import logging
from typing import Dict, Any, Optional, Tuple, List, Callable
from dataclasses import dataclass
import os
import json
from datetime import datetime

# Configure logging
logger = logging.getLogger(__name__)


@dataclass
class TradingTransformerConfig:
    """Configuration for trading transformer models - OPTIMIZED FOR GPU (8-12M params)"""
    # Model architecture - REDUCED for efficient GPU training
    d_model: int = 256    # Model dimension (was 1024)
    n_heads: int = 8      # Number of attention heads (was 16)
    n_layers: int = 4     # Number of transformer layers (was 12)
    d_ff: int = 1024      # Feed-forward dimension (was 4096)
    dropout: float = 0.1  # Dropout rate

    # Input dimensions - OPTIMIZED
    seq_len: int = 200         # Sequence length for time series
    cob_features: int = 100    # COB feature dimension
    tech_features: int = 40    # Technical indicator features
    market_features: int = 30  # Market microstructure features

    # Output configuration
    n_actions: int = 3              # BUY, SELL, HOLD
    confidence_output: bool = True  # Output confidence scores

    # Training configuration
    learning_rate: float = 5e-5  # Conservative learning rate for stable training
    weight_decay: float = 1e-4   # Increased regularization
    warmup_steps: int = 8000     # More warmup steps
    max_grad_norm: float = 0.5   # Tighter gradient clipping

    # Advanced features - ENHANCED
    use_relative_position: bool = True
    use_multi_scale_attention: bool = True
    use_market_regime_detection: bool = True
    use_uncertainty_estimation: bool = True

    # NEW: Additional scaling features
    use_deep_attention: bool = True        # Deeper attention mechanisms
    use_residual_connections: bool = True  # Enhanced residual connections
    use_layer_norm_variants: bool = True   # Advanced normalization

    # Memory optimization
    use_gradient_checkpointing: bool = True  # Trade compute for memory (saves ~30% memory)


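# Illustrative sketch (not part of the model): a rough per-layer parameter
# estimate derived from the config fields above, useful for sanity-checking
# the "8-12M params" target. The 4*d_model^2 term for attention and the
# 2*d_model*d_ff term for the feed-forward block are standard transformer
# approximations; embeddings, input projections, and output heads are ignored.
def _estimate_core_transformer_params(config: TradingTransformerConfig) -> int:
    attn = 4 * config.d_model * config.d_model  # Q, K, V, and output projections
    ff = 2 * config.d_model * config.d_ff       # two feed-forward matrices
    return config.n_layers * (attn + ff)

# Example: the defaults give 4 * (4*256^2 + 2*256*1024) ≈ 3.1M for the core
# layers alone; the projections and prediction heads add the rest.

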
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for transformer

    PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
    PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)  # (max_len, 1, d_model)

        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Expects (seq_len, batch, d_model); broadcasts over the batch axis
        return x + self.pe[:x.size(0), :]


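# Quick illustrative check (not in the original module): the buffer slice
# broadcasts over the batch dimension, so inputs are (seq_len, batch, d_model):
#
#     pe = PositionalEncoding(d_model=256)
#     x = torch.zeros(200, 4, 256)
#     assert pe(x).shape == (200, 4, 256)

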
class RelativePositionalEncoding(nn.Module):
    """Relative positional encoding for better temporal understanding"""

    def __init__(self, d_model: int, max_relative_position: int = 128):
        super().__init__()
        self.d_model = d_model
        self.max_relative_position = max_relative_position

        # Learnable relative position embeddings
        self.relative_position_embeddings = nn.Embedding(
            2 * max_relative_position + 1, d_model
        )

    def forward(self, seq_len: int) -> torch.Tensor:
        """Generate the relative position encoding matrix"""
        range_vec = torch.arange(seq_len)
        range_mat = range_vec.unsqueeze(0).repeat(seq_len, 1)
        distance_mat = range_mat - range_mat.transpose(0, 1)

        # Clip to the maximum relative position
        distance_mat_clipped = torch.clamp(
            distance_mat, -self.max_relative_position, self.max_relative_position
        )

        # Shift to positive indices
        final_mat = distance_mat_clipped + self.max_relative_position

        return self.relative_position_embeddings(final_mat)


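# Illustrative sketch of the index matrix built in forward(), assuming
# max_relative_position=2 and seq_len=4. Each entry is the clipped offset
# j - i shifted into [0, 2*max + 1) so it can index the embedding table:
#
#     rel = RelativePositionalEncoding(d_model=256, max_relative_position=2)
#     emb = rel(4)  # (4, 4, 256)
#     # underlying index matrix:
#     # [[2, 3, 4, 4],
#     #  [1, 2, 3, 4],
#     #  [0, 1, 2, 3],
#     #  [0, 0, 1, 2]]

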
class DeepMultiScaleAttention(nn.Module):
    """Lightweight multi-scale attention optimized for 8-12M parameter model"""

    def __init__(self, d_model: int, n_heads: int, scales: List[int] = [1, 3, 5]):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.scales = scales  # Reduced from 6 scales to 3
        self.head_dim = d_model // n_heads

        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        # Lightweight multi-scale projections (single layer instead of deep)
        self.scale_projections = nn.ModuleList([
            nn.ModuleDict({
                'query': nn.Linear(d_model, d_model),
                'key': nn.Linear(d_model, d_model),
                'value': nn.Linear(d_model, d_model),
                'conv': nn.Conv1d(d_model, d_model, kernel_size=scale,
                                  padding=scale // 2, groups=d_model // 4)
            }) for scale in scales
        ])

        # Lightweight output projection
        self.output_projection = nn.Linear(d_model * len(scales), d_model)

        self.dropout = nn.Dropout(0.1)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch_size, seq_len, _ = x.size()
        scale_outputs = []

        for scale_proj in self.scale_projections:
            # Temporal convolution for this scale (odd kernel sizes preserve seq_len)
            x_conv = scale_proj['conv'](x.transpose(1, 2)).transpose(1, 2)

            # Attention projections on the convolved features
            Q = scale_proj['query'](x_conv).view(batch_size, seq_len, self.n_heads, self.head_dim)
            K = scale_proj['key'](x_conv).view(batch_size, seq_len, self.n_heads, self.head_dim)
            V = scale_proj['value'](x_conv).view(batch_size, seq_len, self.n_heads, self.head_dim)

            # Transpose for attention computation
            Q = Q.transpose(1, 2)  # (batch, n_heads, seq_len, head_dim)
            K = K.transpose(1, 2)
            V = V.transpose(1, 2)

            # Scaled dot-product attention
            scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

            if mask is not None:
                # Out-of-place masking avoids inplace-operation autograd issues
                scores = scores.masked_fill(mask == 0, -1e9)

            attention = F.softmax(scores, dim=-1)
            attention = self.dropout(attention)

            output = torch.matmul(attention, V)
            output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

            scale_outputs.append(output)

        # Combine multi-scale outputs
        combined = torch.cat(scale_outputs, dim=-1)
        output = self.output_projection(combined)

        return output


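# Illustrative usage sketch (shapes only, no training semantics): the module
# consumes (batch, seq_len, d_model) and returns the same shape, internally
# concatenating one attention output per scale before projecting
# len(scales)*d_model back down to d_model:
#
#     attn = DeepMultiScaleAttention(d_model=256, n_heads=8)
#     x = torch.randn(2, 200, 256)
#     assert attn(x).shape == (2, 200, 256)

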
class MarketRegimeDetector(nn.Module):
    """Market regime detection module for adaptive behavior"""

    def __init__(self, d_model: int, n_regimes: int = 4):
        super().__init__()
        self.d_model = d_model
        self.n_regimes = n_regimes

        # Regime classification layers
        self.regime_classifier = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(d_model // 2, n_regimes)
        )

        # Regime-specific transformations
        self.regime_transforms = nn.ModuleList([
            nn.Linear(d_model, d_model) for _ in range(n_regimes)
        ])

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Global pooling for regime detection
        pooled = torch.mean(x, dim=1)  # (batch, d_model)

        # Classify market regime
        regime_logits = self.regime_classifier(pooled)
        regime_probs = F.softmax(regime_logits, dim=-1)

        # Apply regime-specific transformations
        regime_outputs = []
        for transform in self.regime_transforms:
            regime_outputs.append(transform(x))  # (batch, seq_len, d_model)

        # Weighted combination based on regime probabilities
        regime_stack = torch.stack(regime_outputs, dim=0)  # (n_regimes, batch, seq_len, d_model)
        # Reshape probabilities to (n_regimes, batch, 1, 1) for broadcasting
        regime_weights = regime_probs.transpose(0, 1).unsqueeze(-1).unsqueeze(-1)

        # Weighted sum across regimes
        adapted_output = torch.sum(regime_stack * regime_weights, dim=0)

        return adapted_output, regime_probs


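# Equivalent formulation sketch (illustrative only): the weighted combination
# above is a soft mixture over regime experts and can be written as a single
# einsum, which is sometimes easier to audit for shapes:
#
#     # regime_stack: (n_regimes, batch, seq_len, d_model)
#     # regime_probs: (batch, n_regimes)
#     adapted = torch.einsum('rbsd,br->bsd', regime_stack, regime_probs)

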
class UncertaintyEstimation(nn.Module):
    """Uncertainty estimation using Monte Carlo Dropout"""

    def __init__(self, d_model: int, n_samples: int = 10):
        super().__init__()
        self.d_model = d_model
        self.n_samples = n_samples

        self.uncertainty_head = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Dropout(0.5),  # Higher dropout for uncertainty estimation
            nn.Linear(d_model // 2, 1),
            nn.Sigmoid()
        )

    def forward(self, x: torch.Tensor, training: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        if training or self.training:
            # Single forward pass during training
            uncertainty = self.uncertainty_head(x)
            return uncertainty, uncertainty

        # Monte Carlo sampling during inference: dropout must stay active so
        # each pass re-samples a mask and produces a distribution of outputs
        was_training = self.uncertainty_head.training
        self.uncertainty_head.train()
        uncertainties = []
        for _ in range(self.n_samples):
            uncertainties.append(self.uncertainty_head(x))
        self.uncertainty_head.train(was_training)

        uncertainties = torch.stack(uncertainties, dim=0)
        mean_uncertainty = torch.mean(uncertainties, dim=0)
        std_uncertainty = torch.std(uncertainties, dim=0)

        return mean_uncertainty, std_uncertainty


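# Illustrative MC-dropout usage sketch (standalone, eval mode): with the
# module in eval mode, each of the n_samples passes re-samples the dropout
# mask, so the std of the outputs approximates model uncertainty:
#
#     estimator = UncertaintyEstimation(d_model=256, n_samples=10).eval()
#     pooled = torch.randn(4, 256)
#     mean_u, std_u = estimator(pooled)  # both (4, 1)

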
class TradingTransformerLayer(nn.Module):
    """Enhanced transformer layer for trading applications"""

    def __init__(self, config: TradingTransformerConfig):
        super().__init__()
        self.config = config

        # Enhanced multi-scale attention or standard attention
        if config.use_multi_scale_attention:
            self.attention = DeepMultiScaleAttention(config.d_model, config.n_heads)
        else:
            self.attention = nn.MultiheadAttention(
                config.d_model, config.n_heads, dropout=config.dropout, batch_first=True
            )

        # Feed-forward network
        self.feed_forward = nn.Sequential(
            nn.Linear(config.d_model, config.d_ff),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_ff, config.d_model)
        )

        # Layer normalization
        self.norm1 = nn.LayerNorm(config.d_model)
        self.norm2 = nn.LayerNorm(config.d_model)

        # Dropout
        self.dropout = nn.Dropout(config.dropout)

        # Market regime detection
        if config.use_market_regime_detection:
            self.regime_detector = MarketRegimeDetector(config.d_model)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        # Self-attention with residual connection
        if isinstance(self.attention, DeepMultiScaleAttention):
            attn_output = self.attention(x, mask)
        else:
            attn_output, _ = self.attention(x, x, x, attn_mask=mask)

        x = self.norm1(x + self.dropout(attn_output))

        # Market regime adaptation
        regime_probs = None
        if hasattr(self, 'regime_detector'):
            x, regime_probs = self.regime_detector(x)

        # Feed-forward with residual connection
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))

        return {
            'output': x,
            'regime_probs': regime_probs
        }


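# Illustrative shape check (a sketch, not part of the model): a single layer
# preserves (batch, seq_len, d_model) and, with the default config, reports
# 4-way regime probabilities alongside its output:
#
#     layer = TradingTransformerLayer(TradingTransformerConfig())
#     out = layer(torch.randn(2, 200, 256))
#     out['output'].shape        # torch.Size([2, 200, 256])
#     out['regime_probs'].shape  # torch.Size([2, 4])

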
class AdvancedTradingTransformer(nn.Module):
    """Advanced transformer model for high-frequency trading"""

    def __init__(self, config: TradingTransformerConfig):
        super().__init__()
        self.config = config

        # Timeframe configuration
        self.timeframes = ['1s', '1m', '1h', '1d']
        self.num_timeframes = len(self.timeframes) + 1  # +1 for BTC

        # SERIAL: Shared pattern encoder (learns candle patterns ONCE for all timeframes).
        # Applied to each timeframe independently but with the SAME weights.
        # LIGHTWEIGHT: 2-layer encoder for efficiency.
        self.shared_pattern_encoder = nn.Sequential(
            nn.Linear(5, config.d_model // 2),  # 5 OHLCV -> d_model // 2
            nn.LayerNorm(config.d_model // 2),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_model // 2, config.d_model)  # d_model // 2 -> d_model
        )

        # Timeframe-specific embeddings (learnable, added to the shared encoding).
        # These help the model distinguish which timeframe it is looking at.
        self.timeframe_embeddings = nn.Embedding(self.num_timeframes, config.d_model)

        # PARALLEL: Cross-timeframe attention layer (single layer for efficiency).
        # Processes all timeframes simultaneously to capture dependencies.
        self.cross_timeframe_layer = nn.TransformerEncoderLayer(
            d_model=config.d_model,
            nhead=config.n_heads,
            dim_feedforward=config.d_ff,
            dropout=config.dropout,
            activation='gelu',
            batch_first=True
        )

        # Other input projections
        self.cob_projection = nn.Linear(config.cob_features, config.d_model)
        self.tech_projection = nn.Linear(config.tech_features, config.d_model)
        self.market_projection = nn.Linear(config.market_features, config.d_model)

        # Position state projection
        self.position_projection = nn.Sequential(
            nn.Linear(5, config.d_model // 4),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_model // 4, config.d_model // 2),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_model // 2, config.d_model)
        )

        # Positional encoding
        if config.use_relative_position:
            self.pos_encoding = RelativePositionalEncoding(config.d_model)
        else:
            self.pos_encoding = PositionalEncoding(config.d_model, config.seq_len)

        # Transformer layers
        self.layers = nn.ModuleList([
            TradingTransformerLayer(config) for _ in range(config.n_layers)
        ])

        # Lightweight output heads for 8-12M parameter model
        self.action_head = nn.Sequential(
            nn.Linear(config.d_model, config.d_model // 2),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_model // 2, config.n_actions)
        )

        if config.confidence_output:
            self.confidence_head = nn.Sequential(
                nn.Linear(config.d_model, config.d_model // 2),
                nn.GELU(),
                nn.Dropout(config.dropout),
                nn.Linear(config.d_model // 2, 1),
                nn.Sigmoid()
            )

        # Enhanced uncertainty estimation
        if config.use_uncertainty_estimation:
            self.uncertainty_estimator = UncertaintyEstimation(config.d_model)

        # Lightweight price prediction head
        self.price_head = nn.Sequential(
            nn.Linear(config.d_model, config.d_model // 2),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_model // 2, 1),
            nn.Tanh()  # Constrain to [-1, 1] range for price change ratio
        )

        # Lightweight volatility and trend heads
        self.volatility_head = nn.Sequential(
            nn.Linear(config.d_model, config.d_model // 4),
            nn.GELU(),
            nn.Linear(config.d_model // 4, 1),
            nn.Softplus()
        )

        self.trend_strength_head = nn.Sequential(
            nn.Linear(config.d_model, config.d_model // 4),
            nn.GELU(),
            nn.Linear(config.d_model // 4, 1),
            nn.Tanh()
        )

        # Lightweight next candle OHLCV prediction heads
        self.next_candle_heads = nn.ModuleDict({
            tf: nn.Sequential(
                nn.Linear(config.d_model, config.d_model // 2),
                nn.GELU(),
                nn.Linear(config.d_model // 2, 5),  # OHLCV
                nn.Sigmoid()  # Constrain to [0, 1]
            ) for tf in self.timeframes
        })

        # BTC next candle prediction head
        self.btc_next_candle_head = nn.Sequential(
            nn.Linear(config.d_model, config.d_model // 2),
            nn.GELU(),
            nn.Linear(config.d_model // 2, 5),  # OHLCV for BTC
            nn.Sigmoid()
        )

        # Lightweight pivot point prediction heads (L1-L3 only for efficiency)
        self.pivot_levels = [1, 2, 3]  # Reduced from L1-L5 to L1-L3
        self.pivot_heads = nn.ModuleDict({
            f'L{level}': nn.Sequential(
                nn.Linear(config.d_model, config.d_model // 2),
                nn.GELU(),
                nn.Linear(config.d_model // 2, 4)  # [price, type_prob_high, type_prob_low, confidence]
            ) for level in self.pivot_levels
        })

        # Lightweight trend vector analysis head
        self.trend_analysis_head = nn.Sequential(
            nn.Linear(config.d_model, config.d_model // 2),
            nn.GELU(),
            nn.Linear(config.d_model // 2, 3)  # [angle_radians, steepness, direction]
        )

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialize model weights"""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self,
                # Multi-timeframe inputs
                price_data_1s: Optional[torch.Tensor] = None,
                price_data_1m: Optional[torch.Tensor] = None,
                price_data_1h: Optional[torch.Tensor] = None,
                price_data_1d: Optional[torch.Tensor] = None,
                btc_data_1m: Optional[torch.Tensor] = None,
                # Other inputs
                cob_data: Optional[torch.Tensor] = None,
                tech_data: Optional[torch.Tensor] = None,
                market_data: Optional[torch.Tensor] = None,
                position_state: Optional[torch.Tensor] = None,
                mask: Optional[torch.Tensor] = None,
                # Legacy support
                price_data: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """
        Forward pass with hybrid serial-parallel multi-timeframe processing

        SERIAL: Shared pattern encoder learns candle patterns once (same weights for all timeframes)
        PARALLEL: Cross-timeframe attention captures dependencies between timeframes

        Args:
            price_data_1s: (batch, seq_len, 5) - 1-second OHLCV (optional)
            price_data_1m: (batch, seq_len, 5) - 1-minute OHLCV (optional)
            price_data_1h: (batch, seq_len, 5) - 1-hour OHLCV (optional)
            price_data_1d: (batch, seq_len, 5) - 1-day OHLCV (optional)
            btc_data_1m: (batch, seq_len, 5) - BTC 1-minute OHLCV (optional)
            cob_data: (batch, seq_len, cob_features) - COB features
            tech_data: (batch, tech_features) - Technical indicators
            market_data: (batch, market_features) - Market features
            position_state: (batch, 5) - Position state
            mask: Optional attention mask
            price_data: (batch, seq_len, 5) - Legacy single timeframe (defaults to 1m)

        Returns:
            Dictionary with predictions for ALL timeframes
        """
        # Legacy support
        if price_data is not None and price_data_1m is None:
            price_data_1m = price_data

        # Collect available timeframes
        timeframe_data = {
            '1s': price_data_1s,
            '1m': price_data_1m,
            '1h': price_data_1h,
            '1d': price_data_1d,
            'btc': btc_data_1m
        }

        # Filter to available timeframes
        available_tfs = [(tf, data) for tf, data in timeframe_data.items() if data is not None]

        if not available_tfs:
            raise ValueError("At least one timeframe must be provided")

        # Get dimensions from the first available timeframe
        first_data = available_tfs[0][1]
        batch_size, seq_len = first_data.shape[:2]
        device = first_data.device

        # ============================================================
        # STEP 1: SERIAL - Apply shared pattern encoder to each timeframe.
        # This learns candle patterns ONCE (same weights for all).
        # ============================================================
        timeframe_encodings = []
        timeframe_indices = []

        for idx, (tf_name, tf_data) in enumerate(available_tfs):
            # Ensure correct sequence length
            if tf_data.shape[1] != seq_len:
                if tf_data.shape[1] < seq_len:
                    # Pad with the last candle
                    padding = tf_data[:, -1:, :].expand(batch_size, seq_len - tf_data.shape[1], 5)
                    tf_data = torch.cat([tf_data, padding], dim=1)
                else:
                    # Truncate to seq_len
                    tf_data = tf_data[:, :seq_len, :]

            # Apply SHARED pattern encoder (learns patterns once for all timeframes)
            # Shape: [batch, seq_len, 5] -> [batch, seq_len, d_model]
            tf_encoded = self.shared_pattern_encoder(tf_data)

            # Add timeframe-specific embedding (tells the model which timeframe this is)
            tf_idx = self.timeframes.index(tf_name) if tf_name in self.timeframes else len(self.timeframes)
            tf_embedding = self.timeframe_embeddings(torch.tensor([tf_idx], device=device))
            tf_embedding = tf_embedding.unsqueeze(1).expand(batch_size, seq_len, -1)

            # Combine: shared pattern + timeframe identity
            tf_encoded = tf_encoded + tf_embedding

            timeframe_encodings.append(tf_encoded)
            timeframe_indices.append(tf_idx)

        # ============================================================
        # STEP 2: PARALLEL - Cross-timeframe attention.
        # Process all timeframes together to capture dependencies.
        # ============================================================

        # Stack timeframes: [batch, num_timeframes, seq_len, d_model]
        stacked_tfs = torch.stack(timeframe_encodings, dim=1)  # [batch, num_tfs, seq_len, d_model]
        num_tfs = len(timeframe_encodings)

        # MEMORY EFFICIENT: Process timeframes with shared weights.
        # Reshape so all timeframes run through the layer in parallel:
        # [batch*num_tfs, seq_len, d_model]. This avoids creating huge
        # concatenated sequences while still processing efficiently.
        batched_tfs = stacked_tfs.reshape(batch_size * num_tfs, seq_len, self.config.d_model)

        # Apply the single cross-timeframe attention layer
        batched_tfs = self.cross_timeframe_layer(batched_tfs)

        # Reshape back: [batch*num_tfs, seq_len, d_model] -> [batch, num_tfs, seq_len, d_model]
        cross_tf_output = batched_tfs.reshape(batch_size, num_tfs, seq_len, self.config.d_model)

        # Average across timeframes to get a unified representation
        # [batch, num_tfs, seq_len, d_model] -> [batch, seq_len, d_model]
        price_emb = cross_tf_output.mean(dim=1)

        # ============================================================
        # STEP 3: Add other features (COB, tech, market, position)
        # ============================================================

        # COB features
        if cob_data is not None:
            if cob_data.dim() == 2:
                cob_data = cob_data.unsqueeze(1).expand(batch_size, seq_len, -1)
            cob_emb = self.cob_projection(cob_data)
        else:
            cob_emb = torch.zeros(batch_size, seq_len, self.config.d_model, device=device)

        # Technical indicators
        if tech_data is not None:
            if tech_data.dim() == 2:
                tech_data = tech_data.unsqueeze(1).expand(batch_size, seq_len, -1)
            tech_emb = self.tech_projection(tech_data)
        else:
            tech_emb = torch.zeros(batch_size, seq_len, self.config.d_model, device=device)

        # Market features
        if market_data is not None:
            if market_data.dim() == 2:
                market_data = market_data.unsqueeze(1).expand(batch_size, seq_len, -1)
            market_emb = self.market_projection(market_data)
        else:
            market_emb = torch.zeros(batch_size, seq_len, self.config.d_model, device=device)

        # Combine all embeddings
        x = price_emb + cob_emb + tech_emb + market_emb

        # Add position state if provided - critical for loss minimization and profit taking
        if position_state is not None:
            # Project position state through a learned embedding network
            # Input: [batch, 5] -> Output: [batch, d_model]
            position_emb = self.position_projection(position_state)  # [batch, d_model]

            # Expand to sequence length and add as a bias to all positions;
            # this conditions the entire sequence on the current position state
            position_emb = position_emb.unsqueeze(1).expand(batch_size, seq_len, -1)  # [batch, seq_len, d_model]

            # Adding the position embedding lets the model modulate its
            # predictions based on the position state
            x = x + position_emb

        # Add positional encoding
        if isinstance(self.pos_encoding, RelativePositionalEncoding):
            # Relative position encoding is applied inside attention
            pass
        else:
            x = self.pos_encoding(x.transpose(0, 1)).transpose(0, 1)

        # Apply transformer layers with optional gradient checkpointing
        regime_probs_history = []
        for layer in self.layers:
            if self.training and self.config.use_gradient_checkpointing:
                # Use gradient checkpointing to save memory during training.
                # Trades compute for memory (recomputes activations during the backward pass).
                layer_output = checkpoint(layer, x, mask, use_reentrant=False)
            else:
                layer_output = layer(x, mask)

            x = layer_output['output']
            if layer_output['regime_probs'] is not None:
                regime_probs_history.append(layer_output['regime_probs'])

        # Global pooling for the final prediction, using attention-style
        # weights derived from each timestep's feature sum
        pooling_weights = F.softmax(
            torch.sum(x, dim=-1, keepdim=True), dim=1
        )
        pooled = torch.sum(x * pooling_weights, dim=1)

        # Generate outputs
        outputs = {}

        # Action prediction
        action_logits = self.action_head(pooled)
        outputs['action_logits'] = action_logits
        outputs['action_probs'] = F.softmax(action_logits, dim=-1)

        # Confidence prediction
        if self.config.confidence_output:
            confidence = self.confidence_head(pooled)
            outputs['confidence'] = confidence

        # Uncertainty estimation
        if self.config.use_uncertainty_estimation:
            uncertainty_mean, uncertainty_std = self.uncertainty_estimator(pooled)
            outputs['uncertainty_mean'] = uncertainty_mean
            outputs['uncertainty_std'] = uncertainty_std

        # Enhanced price prediction (auxiliary task)
        price_pred = self.price_head(pooled)
        outputs['price_prediction'] = price_pred

        # Additional specialized predictions
        volatility_pred = self.volatility_head(pooled)
        outputs['volatility_prediction'] = volatility_pred

        trend_strength_pred = self.trend_strength_head(pooled)
        outputs['trend_strength_prediction'] = trend_strength_pred

        # NEW: Next candle OHLCV predictions for each timeframe
        next_candles = {}
        for tf in self.timeframes:
            candle_pred = self.next_candle_heads[tf](pooled)  # (batch, 5)
            next_candles[tf] = candle_pred
        outputs['next_candles'] = next_candles

        # BTC next candle prediction
        btc_next_candle = self.btc_next_candle_head(pooled)  # (batch, 5)
        outputs['btc_next_candle'] = btc_next_candle

        # NEW: Next pivot point predictions for L1-L3
        next_pivots = {}
        for level in self.pivot_levels:
            pivot_pred = self.pivot_heads[f'L{level}'](pooled)  # (batch, 4)
            # Extract components: [price, type_logit_high, type_logit_low, confidence].
            # Use softmax so the type probabilities sum to 1.
            type_logits = pivot_pred[:, 1:3]  # (batch, 2) - [high, low]
            type_probs = F.softmax(type_logits, dim=-1)  # (batch, 2)

            next_pivots[f'L{level}'] = {
                'price': pivot_pred[:, 0:1],  # Keep as (batch, 1)
                'type_prob_high': type_probs[:, 0:1],  # Probability of high pivot
                'type_prob_low': type_probs[:, 1:2],  # Probability of low pivot
                'pivot_type': torch.argmax(type_probs, dim=-1, keepdim=True),  # 0=high, 1=low
                'confidence': torch.sigmoid(pivot_pred[:, 3:4])  # Prediction confidence
            }
        outputs['next_pivots'] = next_pivots

        # NEW: Trend vector analysis from the dedicated head
        trend_analysis = self.trend_analysis_head(pooled)  # (batch, 3)
        outputs['trend_analysis'] = {
            'angle_radians': trend_analysis[:, 0:1],  # Trend angle in radians
            'steepness': F.softplus(trend_analysis[:, 1:2]),  # Always positive steepness
            'direction': torch.tanh(trend_analysis[:, 2:3])  # -1 to 1 (down to up)
        }

        # NEW: Calculate a trend vector from the pivot predictions
        pivot_prices = torch.stack([next_pivots[f'L{level}']['price'] for level in self.pivot_levels], dim=1)  # (batch, 3, 1)
        pivot_prices = pivot_prices.squeeze(-1)  # (batch, 3)

        # Trend vector: (price_change, time_change), assuming equal time
        # spacing between pivot levels
        time_points = torch.arange(1, len(self.pivot_levels) + 1, dtype=torch.float32, device=pooled.device).unsqueeze(0)  # (1, 3)

        # Trend slope from the first and last pivot prices:
        # trend vector = (delta_price, delta_time)
        if batch_size > 0:
            # For each sample, calculate the trend from L1 to L3
            price_deltas = pivot_prices[:, -1:] - pivot_prices[:, :1]  # L3 - L1 price change
            time_deltas = time_points[:, -1:] - time_points[:, :1]  # Time change (2 for three levels)

            # Calculate angle and steepness
            trend_angles = torch.atan2(price_deltas.squeeze(), time_deltas.squeeze())  # (batch,)
            trend_steepness = torch.sqrt(price_deltas.squeeze() ** 2 + time_deltas.squeeze() ** 2)  # (batch,)
            trend_direction = torch.sign(price_deltas.squeeze())  # (batch,)

            outputs['trend_vector'] = {
                'pivot_prices': pivot_prices,  # (batch, 3) - prices for L1-L3
                'price_delta': price_deltas.squeeze(),  # (batch,) - price change from L1 to L3
                'time_delta': time_deltas.squeeze(),  # (batch,) - time change
                'calculated_angle': trend_angles.unsqueeze(-1),  # (batch, 1)
                'calculated_steepness': trend_steepness.unsqueeze(-1),  # (batch, 1)
                'calculated_direction': trend_direction.unsqueeze(-1),  # (batch, 1)
                'vector': torch.stack([price_deltas.squeeze(), time_deltas.squeeze()], dim=0).unsqueeze(0) if batch_size == 1 else torch.stack([price_deltas.squeeze(), time_deltas.squeeze()], dim=1)  # (batch, 2) - [price_delta, time_delta]
            }
        else:
            outputs['trend_vector'] = {
                'pivot_prices': pivot_prices,
                'price_delta': torch.zeros(batch_size, device=pooled.device),
                'time_delta': torch.zeros(batch_size, device=pooled.device),
                'calculated_angle': torch.zeros(batch_size, 1, device=pooled.device),
                'calculated_steepness': torch.zeros(batch_size, 1, device=pooled.device),
                'calculated_direction': torch.zeros(batch_size, 1, device=pooled.device),
                'vector': torch.zeros(batch_size, 2, device=pooled.device)
            }

        # NEW: Trade action based on trend steepness and angle.
        # Combine the predicted trend analysis with the calculated trend vector.
        predicted_angle = outputs['trend_analysis']['angle_radians'].squeeze()  # (batch,)
        predicted_steepness = outputs['trend_analysis']['steepness'].squeeze()  # (batch,)
        predicted_direction = outputs['trend_analysis']['direction'].squeeze()  # (batch,)

        # Use the calculated trend if available, otherwise the predicted one
        if 'calculated_angle' in outputs['trend_vector']:
            trend_angle = outputs['trend_vector']['calculated_angle'].squeeze()  # (batch,)
            trend_steepness_val = outputs['trend_vector']['calculated_steepness'].squeeze()  # (batch,)
        else:
            trend_angle = predicted_angle
            trend_steepness_val = predicted_steepness

        # Trade action logic based on trend steepness and angle:
        #   steep upward trend (> 45 degrees)    -> BUY
        #   steep downward trend (< -45 degrees) -> SELL
        #   shallow trend                        -> HOLD
        angle_threshold = math.pi / 4  # 45 degrees

        # Determine action from trend angle
        trend_action_logits = torch.zeros(batch_size, 3, device=pooled.device)  # [BUY, SELL, HOLD]

        # Calculate action logits based on the trend
        for i in range(batch_size):
            # Handle both 0-dim and 1-dim tensors
            if trend_angle.dim() == 0:
                angle = trend_angle.item()
                steep = trend_steepness_val.item()
            else:
                angle = trend_angle[i].item()
                steep = trend_steepness_val[i].item()

            # Normalize steepness to [0, 1] range (assuming max steepness of 10 units)
            normalized_steepness = min(steep / 10.0, 1.0) if steep > 0 else 0.0

            if angle > angle_threshold:  # Steep upward trend
                trend_action_logits[i, 0] = normalized_steepness * 2.0  # BUY
                trend_action_logits[i, 2] = (1.0 - normalized_steepness) * 0.5  # HOLD
            elif angle < -angle_threshold:  # Steep downward trend
                trend_action_logits[i, 1] = normalized_steepness * 2.0  # SELL
                trend_action_logits[i, 2] = (1.0 - normalized_steepness) * 0.5  # HOLD
            else:  # Shallow trend
                trend_action_logits[i, 2] = 1.0  # HOLD

        # Expose the trend-based action alongside the main action prediction
        trend_action_probs = F.softmax(trend_action_logits, dim=-1)
        outputs['trend_based_action'] = {
            'logits': trend_action_logits,
            'probabilities': trend_action_probs,
            'action_idx': torch.argmax(trend_action_probs, dim=-1),
            'trend_angle_degrees': trend_angle * 180.0 / math.pi,  # Convert to degrees
            'trend_steepness': trend_steepness_val
        }

        # Market regime information
        if regime_probs_history:
            outputs['regime_probs'] = torch.stack(regime_probs_history, dim=1)

        return outputs

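    # Worked example of the trend-action rule above (illustrative numbers):
    # an angle of 60° exceeds the 45° threshold; with steepness 5.0 the
    # normalized steepness is min(5.0 / 10.0, 1.0) = 0.5, giving a BUY logit
    # of 0.5 * 2.0 = 1.0 and a HOLD logit of (1.0 - 0.5) * 0.5 = 0.25, so the
    # softmax favors BUY.
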
    def extract_predictions(self, outputs: Dict[str, torch.Tensor], denormalize_prices: Optional[Callable] = None) -> Dict[str, Any]:
        """
        Extract predictions from model outputs in a user-friendly format

        Args:
            outputs: Raw model outputs from the forward() method
            denormalize_prices: Optional function to denormalize predicted prices

        Returns:
            Dictionary with formatted predictions including:
            - next_candles: Dict[str, Dict] - OHLCV predictions for each timeframe
            - next_pivots: Dict[str, Dict] - Pivot predictions for L1-L3
            - trend_vector: Dict - Trend vector analysis
            - trend_based_action: Dict - Trading action based on trend
        """
        self.eval()
        device = next(self.parameters()).device

        predictions = {}

        # Extract next candle predictions for each timeframe
        if 'next_candles' in outputs:
            next_candles = {}
            for tf in self.timeframes:
                candle_tensor = outputs['next_candles'][tf]
                if candle_tensor.dim() > 1:
                    candle_tensor = candle_tensor[0]  # Take the first batch item

                candle_values = candle_tensor.cpu().detach().numpy() if hasattr(candle_tensor, 'cpu') else candle_tensor
                if isinstance(candle_values, np.ndarray):
                    candle_values = candle_values.tolist()

                next_candles[tf] = {
                    'open': float(candle_values[0]) if len(candle_values) > 0 else 0.0,
                    'high': float(candle_values[1]) if len(candle_values) > 1 else 0.0,
                    'low': float(candle_values[2]) if len(candle_values) > 2 else 0.0,
                    'close': float(candle_values[3]) if len(candle_values) > 3 else 0.0,
                    'volume': float(candle_values[4]) if len(candle_values) > 4 else 0.0
                }

                # Denormalize if a function is provided
                if denormalize_prices and callable(denormalize_prices):
                    for key in ['open', 'high', 'low', 'close']:
                        next_candles[tf][key] = denormalize_prices(next_candles[tf][key])

            predictions['next_candles'] = next_candles

        # Extract pivot point predictions
        if 'next_pivots' in outputs:
            next_pivots = {}
            for level in self.pivot_levels:
                pivot_data = outputs['next_pivots'][f'L{level}']

                # Extract values
                price = pivot_data['price']
                if price.dim() > 1:
                    price = price[0, 0] if price.shape[0] > 0 else torch.tensor(0.0, device=device)
                price_val = float(price.cpu().detach().item() if hasattr(price, 'cpu') else price)

                type_prob_high = pivot_data['type_prob_high']
                if type_prob_high.dim() > 1:
                    type_prob_high = type_prob_high[0, 0] if type_prob_high.shape[0] > 0 else torch.tensor(0.0, device=device)
                prob_high = float(type_prob_high.cpu().detach().item() if hasattr(type_prob_high, 'cpu') else type_prob_high)

                type_prob_low = pivot_data['type_prob_low']
                if type_prob_low.dim() > 1:
                    type_prob_low = type_prob_low[0, 0] if type_prob_low.shape[0] > 0 else torch.tensor(0.0, device=device)
                prob_low = float(type_prob_low.cpu().detach().item() if hasattr(type_prob_low, 'cpu') else type_prob_low)

                confidence = pivot_data['confidence']
                if confidence.dim() > 1:
                    confidence = confidence[0, 0] if confidence.shape[0] > 0 else torch.tensor(0.0, device=device)
                conf_val = float(confidence.cpu().detach().item() if hasattr(confidence, 'cpu') else confidence)

                pivot_type = pivot_data.get('pivot_type', torch.tensor(0))
                if isinstance(pivot_type, torch.Tensor):
                    if pivot_type.dim() > 1:
                        pivot_type = pivot_type[0, 0] if pivot_type.shape[0] > 0 else torch.tensor(0, device=device)
                    pivot_type_val = int(pivot_type.cpu().detach().item() if hasattr(pivot_type, 'cpu') else pivot_type)
                else:
                    pivot_type_val = int(pivot_type)

                # Denormalize price if a function is provided
                if denormalize_prices and callable(denormalize_prices):
                    price_val = denormalize_prices(price_val)

                next_pivots[f'L{level}'] = {
                    'price': price_val,
                    'type': 'high' if pivot_type_val == 0 else 'low',
                    'type_prob_high': prob_high,
                    'type_prob_low': prob_low,
                    'confidence': conf_val
                }

            predictions['next_pivots'] = next_pivots

        # Extract trend vector
        if 'trend_vector' in outputs:
            trend_vec = outputs['trend_vector']

            # Extract pivot prices
            pivot_prices = trend_vec.get('pivot_prices', torch.zeros(len(self.pivot_levels), device=device))
            if isinstance(pivot_prices, torch.Tensor):
                if pivot_prices.dim() > 1:
                    pivot_prices = pivot_prices[0]
                pivot_prices_list = pivot_prices.cpu().detach().numpy().tolist() if hasattr(pivot_prices, 'cpu') else pivot_prices.tolist()
            else:
                pivot_prices_list = pivot_prices

            # Denormalize pivot prices if a function is provided
            if denormalize_prices and callable(denormalize_prices):
                pivot_prices_list = [denormalize_prices(p) for p in pivot_prices_list]

            angle = trend_vec.get('calculated_angle', torch.tensor(0.0, device=device))
            if isinstance(angle, torch.Tensor):
                if angle.dim() > 1:
                    angle = angle[0, 0] if angle.shape[0] > 0 else torch.tensor(0.0, device=device)
                angle_val = float(angle.cpu().detach().item() if hasattr(angle, 'cpu') else angle)
            else:
                angle_val = float(angle)

            steepness = trend_vec.get('calculated_steepness', torch.tensor(0.0, device=device))
            if isinstance(steepness, torch.Tensor):
                if steepness.dim() > 1:
                    steepness = steepness[0, 0] if steepness.shape[0] > 0 else torch.tensor(0.0, device=device)
                steepness_val = float(steepness.cpu().detach().item() if hasattr(steepness, 'cpu') else steepness)
            else:
                steepness_val = float(steepness)

            direction = trend_vec.get('calculated_direction', torch.tensor(0.0, device=device))
            if isinstance(direction, torch.Tensor):
                if direction.dim() > 1:
                    direction = direction[0, 0] if direction.shape[0] > 0 else torch.tensor(0.0, device=device)
                direction_val = float(direction.cpu().detach().item() if hasattr(direction, 'cpu') else direction)
            else:
                direction_val = float(direction)

            price_delta = trend_vec.get('price_delta', torch.tensor(0.0, device=device))
            if isinstance(price_delta, torch.Tensor):
                if price_delta.dim() > 0:
                    price_delta = price_delta[0] if price_delta.shape[0] > 0 else torch.tensor(0.0, device=device)
                price_delta_val = float(price_delta.cpu().detach().item() if hasattr(price_delta, 'cpu') else price_delta)
            else:
                price_delta_val = float(price_delta)

            predictions['trend_vector'] = {
                'pivot_prices': pivot_prices_list,  # [L1, L2, L3]
                'angle_radians': angle_val,
                'angle_degrees': angle_val * 180.0 / math.pi,
                'steepness': steepness_val,
                'direction': 'up' if direction_val > 0 else 'down' if direction_val < 0 else 'sideways',
                'price_delta': price_delta_val
            }

        # Extract trend-based action
        if 'trend_based_action' in outputs:
            trend_action = outputs['trend_based_action']

            action_probs = trend_action.get('probabilities', torch.zeros(3, device=device))
            if isinstance(action_probs, torch.Tensor):
                if action_probs.dim() > 1:
                    action_probs = action_probs[0]
                action_probs_list = action_probs.cpu().detach().numpy().tolist() if hasattr(action_probs, 'cpu') else action_probs.tolist()
            else:
                action_probs_list = action_probs

            action_idx = trend_action.get('action_idx', torch.tensor(2, device=device))
            if isinstance(action_idx, torch.Tensor):
                if action_idx.dim() > 0:
                    action_idx = action_idx[0] if action_idx.shape[0] > 0 else torch.tensor(2, device=device)
                action_idx_val = int(action_idx.cpu().detach().item() if hasattr(action_idx, 'cpu') else action_idx)
            else:
                action_idx_val = int(action_idx)

            angle_degrees = trend_action.get('trend_angle_degrees', torch.tensor(0.0, device=device))
            if isinstance(angle_degrees, torch.Tensor):
                if angle_degrees.dim() > 0:
                    angle_degrees = angle_degrees[0] if angle_degrees.shape[0] > 0 else torch.tensor(0.0, device=device)
                angle_degrees_val = float(angle_degrees.cpu().detach().item() if hasattr(angle_degrees, 'cpu') else angle_degrees)
            else:
                angle_degrees_val = float(angle_degrees)

            steepness = trend_action.get('trend_steepness', torch.tensor(0.0, device=device))
            if isinstance(steepness, torch.Tensor):
                if steepness.dim() > 0:
                    steepness = steepness[0] if steepness.shape[0] > 0 else torch.tensor(0.0, device=device)
                steepness_val = float(steepness.cpu().detach().item() if hasattr(steepness, 'cpu') else steepness)
            else:
                steepness_val = float(steepness)

            action_names = ['BUY', 'SELL', 'HOLD']

            predictions['trend_based_action'] = {
                'action': action_names[action_idx_val] if 0 <= action_idx_val < len(action_names) else 'HOLD',
                'action_idx': action_idx_val,
                'probabilities': {
                    'BUY': float(action_probs_list[0]) if len(action_probs_list) > 0 else 0.0,
                    'SELL': float(action_probs_list[1]) if len(action_probs_list) > 1 else 0.0,
                    'HOLD': float(action_probs_list[2]) if len(action_probs_list) > 2 else 0.0
                },
                'trend_angle_degrees': angle_degrees_val,
                'trend_steepness': steepness_val
            }

        return predictions


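# End-to-end usage sketch (a minimal example with random inputs; shapes
# follow the forward() docstring, everything else here is illustrative):
#
#     config = TradingTransformerConfig()
#     model = AdvancedTradingTransformer(config)
#     model.eval()
#     with torch.no_grad():
#         outputs = model(
#             price_data_1m=torch.rand(1, config.seq_len, 5),
#             tech_data=torch.rand(1, config.tech_features),
#         )
#     preds = model.extract_predictions(outputs)
#     preds['trend_based_action']['action']  # 'BUY', 'SELL', or 'HOLD'

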
class TradingTransformerTrainer:
    """Trainer for the advanced trading transformer"""

    def __init__(self, model: AdvancedTradingTransformer, config: TradingTransformerConfig):
        self.model = model
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Move model to device
        self.model.to(self.device)
        logger.info(f"✅ Model moved to device: {self.device}")

        # Log GPU info if available
        if torch.cuda.is_available():
            logger.info(f"  GPU: {torch.cuda.get_device_name(0)}")
            logger.info(f"  GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

        # MEMORY OPTIMIZATION: Enable gradient checkpointing if configured.
        # This trades ~20% compute for 30-40% memory savings.
        if config.use_gradient_checkpointing:
            logger.info("Enabling gradient checkpointing for memory efficiency")
            self._enable_gradient_checkpointing()

        # Mixed precision training disabled - causes dtype mismatches.
        # Can be re-enabled if needed, but requires careful dtype management.
        self.use_amp = False
        self.scaler = None

        # GRADIENT ACCUMULATION: Track accumulation state
        self.gradient_accumulation_steps = 0
        self.current_accumulation_step = 0

        # Optimizer with decoupled weight decay (warmup is handled by the
        # OneCycleLR schedule below via pct_start)
        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )

        # Learning rate scheduler
        self.scheduler = optim.lr_scheduler.OneCycleLR(
            self.optimizer,
            max_lr=config.learning_rate,
            total_steps=10000,  # Placeholder; update once the real step count is known
            pct_start=0.1
        )

        # Loss functions
        self.action_criterion = nn.CrossEntropyLoss()
        self.price_criterion = nn.MSELoss()
        self.confidence_criterion = nn.BCELoss()

        # Training history
        self.training_history = {
            'train_loss': [],
            'val_loss': [],
            'train_accuracy': [],
            'val_accuracy': [],
            'learning_rates': []
        }

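    # Sketch (assumption: `train_loader` and `n_epochs` come from the caller):
    # when the real number of optimizer steps is known, recreate the scheduler
    # so OneCycleLR anneals over the true horizon rather than the 10000-step
    # placeholder above:
    #
    #     trainer.scheduler = optim.lr_scheduler.OneCycleLR(
    #         trainer.optimizer, max_lr=config.learning_rate,
    #         total_steps=len(train_loader) * n_epochs, pct_start=0.1)
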
    def _enable_gradient_checkpointing(self):
        """Enable gradient checkpointing on transformer layers to save memory"""
        try:
            # Apply checkpointing to each transformer layer.
            # NOTE: bind the original forward as a default argument; a plain
            # closure would late-bind and leave every layer calling the LAST
            # layer's forward.
            for layer in self.model.layers:
                if hasattr(layer, 'attention'):
                    original_forward = layer.attention.forward

                    def checkpointed_attention_forward(*args, _orig=original_forward, **kwargs):
                        return checkpoint(_orig, *args, **kwargs, use_reentrant=False)

                    layer.attention.forward = checkpointed_attention_forward

                if hasattr(layer, 'feed_forward'):
                    original_ff_forward = layer.feed_forward.forward

                    def checkpointed_ff_forward(*args, _orig=original_ff_forward, **kwargs):
                        return checkpoint(_orig, *args, **kwargs, use_reentrant=False)

                    layer.feed_forward.forward = checkpointed_ff_forward

            logger.info("Gradient checkpointing enabled on all transformer layers")
        except Exception as e:
            logger.warning(f"Failed to enable gradient checkpointing: {e}")

    def set_gradient_accumulation_steps(self, steps: int):
        """
        Set the number of gradient accumulation steps

        Args:
            steps: Number of batches to accumulate gradients over before each
                optimizer step. For example, steps=5 means process 5 batches,
                then update the weights once.
        """
        self.gradient_accumulation_steps = steps
        self.current_accumulation_step = 0
        logger.info(f"Gradient accumulation enabled: {steps} steps")

    def reset_gradient_accumulation(self):
        """Reset the gradient accumulation counter"""
        self.current_accumulation_step = 0

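    # Illustrative training-loop sketch (assumption: `batches` is any iterable
    # of prepared batch dicts). With 4 accumulation steps the effective batch
    # size is 4x the per-batch size: train_step() scales each loss by 1/4 and
    # only clips, steps the optimizer, and advances the scheduler on every
    # 4th call:
    #
    #     trainer = TradingTransformerTrainer(model, config)
    #     trainer.set_gradient_accumulation_steps(4)
    #     for batch in batches:
    #         metrics = trainer.train_step(batch)
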
    @staticmethod
    def denormalize_prices(normalized_values: torch.Tensor, norm_params: Dict[str, float]) -> torch.Tensor:
        """
        Denormalize price predictions back to real price space

        Args:
            normalized_values: Tensor of normalized values in the [0, 1] range
            norm_params: Dict with 'price_min' and 'price_max' keys

        Returns:
            Denormalized tensor in the original price space
        """
        price_min = norm_params.get('price_min', 0.0)
        price_max = norm_params.get('price_max', 1.0)

        if price_max > price_min:
            return normalized_values * (price_max - price_min) + price_min
        else:
            return normalized_values

    @staticmethod
    def denormalize_candle(normalized_candle: torch.Tensor, norm_params: Dict[str, float]) -> torch.Tensor:
        """
        Denormalize a full OHLCV candle back to real values

        Args:
            normalized_candle: Tensor of shape [..., 5] with normalized OHLCV
            norm_params: Dict with normalization parameters

        Returns:
            Denormalized OHLCV tensor
        """
        denorm = normalized_candle.clone()

        # Denormalize OHLC (first 4 values)
        price_min = norm_params.get('price_min', 0.0)
        price_max = norm_params.get('price_max', 1.0)
        if price_max > price_min:
            denorm[..., :4] = denorm[..., :4] * (price_max - price_min) + price_min

        # Denormalize volume (5th value)
        volume_min = norm_params.get('volume_min', 0.0)
        volume_max = norm_params.get('volume_max', 1.0)
        if volume_max > volume_min:
            denorm[..., 4] = denorm[..., 4] * (volume_max - volume_min) + volume_min

        return denorm

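    # Worked example for denormalize_candle (made-up normalization range):
    # OHLC values map through price_min/price_max and volume through its own
    # range, so [0.25, 0.75, 0.10, 0.50, 0.40] with prices in [100, 200] and
    # volume in [0, 1000] becomes [125, 175, 110, 150, 400]:
    #
    #     candle = torch.tensor([0.25, 0.75, 0.10, 0.50, 0.40])
    #     params = {'price_min': 100.0, 'price_max': 200.0,
    #               'volume_min': 0.0, 'volume_max': 1000.0}
    #     TradingTransformerTrainer.denormalize_candle(candle, params)
    #     # tensor([125., 175., 110., 150., 400.])
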
def train_step(self, batch: Dict[str, torch.Tensor], accumulate_gradients: bool = False) -> Dict[str, float]:
|
|
"""Single training step with optional gradient accumulation
|
|
|
|
Args:
|
|
batch: Training batch
|
|
accumulate_gradients: If True, don't zero gradients or step optimizer (for gradient accumulation)
|
|
This is DEPRECATED - use gradient_accumulation_steps instead
|
|
|
|
Returns:
|
|
Dictionary with loss and accuracy metrics
|
|
"""
|
|
try:
|
|
self.model.train()
|
|
|
|
# GRADIENT ACCUMULATION: Determine if this is an accumulation step
|
|
# If gradient_accumulation_steps is set, use automatic accumulation
|
|
# Otherwise, fall back to manual accumulate_gradients flag
|
|
if self.gradient_accumulation_steps > 0:
|
|
is_accumulation_step = (self.current_accumulation_step < self.gradient_accumulation_steps - 1)
|
|
self.current_accumulation_step += 1
|
|
|
|
# Reset counter after full accumulation cycle
|
|
if self.current_accumulation_step >= self.gradient_accumulation_steps:
|
|
self.current_accumulation_step = 0
|
|
else:
|
|
is_accumulation_step = accumulate_gradients
|
|
|
|
# Only zero gradients at the start of accumulation cycle
|
|
# Use set_to_none=True for better memory efficiency (saves ~5% memory)
|
|
if not is_accumulation_step or self.current_accumulation_step == 1:
|
|
self.optimizer.zero_grad(set_to_none=True)
|
|
|
|
# OPTIMIZATION: Only move batch to device if not already there
|
|
# Check if first tensor is already on correct device
|
|
needs_transfer = False
|
|
for v in batch.values():
|
|
if isinstance(v, torch.Tensor):
|
|
needs_transfer = (v.device != self.device)
|
|
break
|
|
|
|
if needs_transfer:
|
|
# Move batch to device - iterate over copy of keys to avoid modification during iteration
|
|
batch_gpu = {}
|
|
for k in list(batch.keys()): # Create list copy to avoid modification during iteration
|
|
v = batch[k]
|
|
if isinstance(v, torch.Tensor):
|
|
# Move to device (creates GPU copy)
|
|
batch_gpu[k] = v.to(self.device, non_blocking=True)
|
|
else:
|
|
batch_gpu[k] = v
|
|
|
|
# Replace batch with GPU version
|
|
batch = batch_gpu
|
|
del batch_gpu
|
|
# else: batch is already on GPU, use it directly!
|
|
|
|
# Use automatic mixed precision (FP16) for memory efficiency
|
|
# Support both CUDA and ROCm (AMD) devices
|
|
device_type = 'cuda' if self.device.type == 'cuda' else 'cpu'
|
|
with torch.amp.autocast(device_type, enabled=self.use_amp and device_type != 'cpu'):
|
|
# Forward pass with multi-timeframe data
|
|
outputs = self.model(
|
|
price_data_1s=batch.get('price_data_1s'),
|
|
price_data_1m=batch.get('price_data_1m'),
|
|
price_data_1h=batch.get('price_data_1h'),
|
|
price_data_1d=batch.get('price_data_1d'),
|
|
btc_data_1m=batch.get('btc_data_1m'),
|
|
cob_data=batch.get('cob_data'), # Use .get() to handle missing key
|
|
tech_data=batch.get('tech_data'),
|
|
market_data=batch.get('market_data'),
|
|
position_state=batch.get('position_state'),
|
|
price_data=batch.get('price_data') # Legacy fallback
|
|
)
|
|
|
|
# Calculate losses
|
|
action_loss = self.action_criterion(outputs['action_logits'], batch['actions'])
|
|
|
|
# FIXED: Ensure shapes match for MSELoss
|
|
price_pred = outputs['price_prediction']
|
|
price_target = batch['future_prices']
|
|
|
|
# Both should be [batch, 1], but ensure they match
|
|
if price_pred.shape != price_target.shape:
|
|
logger.debug(f"Reshaping price target from {price_target.shape} to {price_pred.shape}")
|
|
price_target = price_target.view(price_pred.shape)
|
|
|
|
price_loss = self.price_criterion(price_pred, price_target)
|
|
|
|
# NEW: Trend analysis loss (if trend_target provided)
|
|
trend_loss = torch.tensor(0.0, device=self.device)
|
|
if 'trend_target' in batch and 'trend_analysis' in outputs:
|
|
trend_pred = torch.cat([
|
|
outputs['trend_analysis']['angle_radians'],
|
|
outputs['trend_analysis']['steepness'],
|
|
outputs['trend_analysis']['direction']
|
|
], dim=1) # [batch, 3]
|
|
|
|
trend_target = batch['trend_target']
|
|
if trend_pred.shape == trend_target.shape:
|
|
trend_loss = self.price_criterion(trend_pred, trend_target)
|
|
logger.debug(f"Trend loss: {trend_loss.item():.6f} (pred={trend_pred[0].tolist()}, target={trend_target[0].tolist()})")
|
|
|
|
# NEW: Next candle prediction loss for each timeframe
|
|
# This trains the model to predict full OHLCV for the next candle on each timeframe
|
|
candle_loss = torch.tensor(0.0, device=self.device)
|
|
candle_losses_detail = {} # Track per-timeframe losses (normalized space)
|
|
candle_losses_denorm = {} # Track per-timeframe losses (denormalized/real space)
|
|
|
|
if 'next_candles' in outputs:
|
|
timeframe_losses = []
|
|
|
|
# Get normalization parameters if available
|
|
# norm_params may be a dict or a list of dicts (one per sample in batch)
|
|
norm_params_raw = batch.get('norm_params', {})
|
|
if isinstance(norm_params_raw, list) and len(norm_params_raw) > 0:
|
|
# If it's a list, use the first one (batch size is typically 1)
|
|
norm_params = norm_params_raw[0]
|
|
else:
|
|
norm_params = norm_params_raw if isinstance(norm_params_raw, dict) else {}
|
|
|
|
# Calculate loss for each timeframe that has target data
|
|
for tf in ['1s', '1m', '1h', '1d']:
|
|
future_key = f'future_candle_{tf}'
|
|
|
|
if tf in outputs['next_candles'] and future_key in batch:
|
|
pred_candle = outputs['next_candles'][tf] # [batch, 5] - predicted OHLCV (normalized)
|
|
target_candle = batch[future_key] # [batch, 5] - actual OHLCV (normalized)
|
|
|
|
if target_candle is not None and pred_candle.shape == target_candle.shape:
|
|
# MSE loss on normalized values (used for backprop)
|
|
tf_loss = self.price_criterion(pred_candle, target_candle)
|
|
timeframe_losses.append(tf_loss)
|
|
candle_losses_detail[tf] = tf_loss.item()
|
|
|
|
# ALSO calculate denormalized loss for better interpretability
|
|
# Use RMSE (Root Mean Square Error) instead of MSE for realistic values
|
|
if tf in norm_params:
|
|
with torch.no_grad():
|
|
pred_denorm = self.denormalize_candle(pred_candle, norm_params[tf])
|
|
target_denorm = self.denormalize_candle(target_candle, norm_params[tf])
|
|
# Use RMSE instead of MSE to get interpretable dollar values
|
|
mse = torch.mean((pred_denorm - target_denorm) ** 2)
|
|
rmse = torch.sqrt(mse + 1e-8) # Add epsilon for numerical stability
|
|
candle_losses_denorm[tf] = rmse.item()
|
|
|
|
# Average loss across available timeframes
|
|
if timeframe_losses:
|
|
candle_loss = torch.stack(timeframe_losses).mean()
|
|
if candle_losses_denorm:
|
|
logger.debug(f"Candle losses (normalized): {candle_losses_detail}")
|
|
logger.debug(f"Candle losses (real prices): {candle_losses_denorm}")
|
|
|
|
# Start with base losses - avoid inplace operations on computation graph
|
|
# Weight: action=1.0, price=0.1, trend=0.05, candle=0.15
|
|
total_loss = action_loss + 0.1 * price_loss + 0.05 * trend_loss + 0.15 * candle_loss
|
|
|
|
# CRITICAL FIX: Scale loss for gradient accumulation
|
|
# This prevents gradient explosion when accumulating over multiple batches
|
|
# The loss is averaged over accumulation steps so gradients sum correctly
|
|
if self.gradient_accumulation_steps > 0:
|
|
total_loss = total_loss / self.gradient_accumulation_steps
|
|
elif accumulate_gradients:
|
|
# Legacy fallback - assume 5 steps if not specified
|
|
total_loss = total_loss / 5.0
|
|
|
|
# Add confidence loss if available
|
|
if 'confidence' in outputs and 'trade_success' in batch:
|
|
# Both tensors should have shape [batch_size, 1] for BCELoss
|
|
confidence_pred = outputs['confidence']
|
|
trade_target = batch['trade_success'].float()
|
|
|
|
# FIXED: Ensure both are 2D tensors [batch_size, 1]
|
|
# Handle different input shapes robustly
|
|
if confidence_pred.dim() == 0:
|
|
# Scalar -> [1, 1]
|
|
confidence_pred = confidence_pred.unsqueeze(0).unsqueeze(0)
|
|
elif confidence_pred.dim() == 1:
|
|
# [batch_size] -> [batch_size, 1]
|
|
confidence_pred = confidence_pred.unsqueeze(-1)
|
|
elif confidence_pred.dim() == 3:
|
|
# [batch_size, seq_len, 1] -> [batch_size, 1] (take last timestep)
|
|
confidence_pred = confidence_pred[:, -1, :]
|
|
|
|
if trade_target.dim() == 0:
|
|
# Scalar -> [1, 1]
|
|
trade_target = trade_target.unsqueeze(0).unsqueeze(0)
|
|
elif trade_target.dim() == 1:
|
|
# [batch_size] -> [batch_size, 1]
|
|
trade_target = trade_target.unsqueeze(-1)
|
|
|
|
# Ensure shapes match exactly - BCELoss requires exact match
|
|
if confidence_pred.shape != trade_target.shape:
|
|
# Reshape trade_target to match confidence_pred shape
|
|
trade_target = trade_target.view(confidence_pred.shape)
|
|
|
|
confidence_loss = self.confidence_criterion(confidence_pred, trade_target)
|
|
# Use addition instead of += to avoid inplace operation
|
|
total_loss = total_loss + 0.1 * confidence_loss
|
|
|
|
# Backward pass with mixed precision scaling
|
|
try:
|
|
if self.use_amp:
|
|
self.scaler.scale(total_loss).backward()
|
|
else:
|
|
total_loss.backward()
|
|
except RuntimeError as e:
|
|
if "inplace operation" in str(e):
|
|
logger.error(f"Inplace operation error during backward pass: {e}")
|
|
# Return zero loss to continue training
|
|
return {
|
|
'total_loss': 0.0,
|
|
'action_loss': 0.0,
|
|
'price_loss': 0.0,
|
|
'accuracy': 0.0,
|
|
'learning_rate': self.scheduler.get_last_lr()[0]
|
|
}
|
|
else:
|
|
raise
|
|
|
|
            # Only clip gradients and step the optimizer at the end of an accumulation cycle
            if not is_accumulation_step:
                if self.use_amp:
                    # Unscale gradients before clipping
                    self.scaler.unscale_(self.optimizer)
                    # Gradient clipping
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
                    # Optimizer step with scaling
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    # Gradient clipping
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
                    # Optimizer step
                    self.optimizer.step()

                # Scheduler advances once per optimizer step, not per micro-batch
                self.scheduler.step()

                # Log gradient accumulation completion
                if self.gradient_accumulation_steps > 0:
                    logger.debug(f"Gradient accumulation cycle complete ({self.gradient_accumulation_steps} steps)")
            # Calculate accuracy without gradients
            # PRIMARY: next-candle OHLCV prediction accuracy (realistic values)
            with torch.no_grad():
                candle_accuracy = 0.0
                candle_rmse = {}

                if 'next_candles' in outputs:
                    # Use the 1m timeframe as the primary metric
                    if '1m' in outputs['next_candles'] and 'future_candle_1m' in batch:
                        pred_candle = outputs['next_candles']['1m']  # [batch, 5]
                        actual_candle = batch['future_candle_1m']    # [batch, 5]

                        if actual_candle is not None and pred_candle.shape == actual_candle.shape:
                            # RMSE per OHLCV component (small epsilon for numerical stability)
                            rmse_open = torch.sqrt(torch.mean((pred_candle[:, 0] - actual_candle[:, 0])**2) + 1e-8)
                            rmse_high = torch.sqrt(torch.mean((pred_candle[:, 1] - actual_candle[:, 1])**2) + 1e-8)
                            rmse_low = torch.sqrt(torch.mean((pred_candle[:, 2] - actual_candle[:, 2])**2) + 1e-8)
                            rmse_close = torch.sqrt(torch.mean((pred_candle[:, 3] - actual_candle[:, 3])**2) + 1e-8)

                            # Average RMSE over OHLC (volume excluded)
                            avg_rmse = (rmse_open + rmse_high + rmse_low + rmse_close) / 4

                            # Convert to accuracy: lower RMSE = higher accuracy,
                            # normalized by the batch's price range (max high - min low)
                            price_range = torch.clamp(actual_candle[:, 1].max() - actual_candle[:, 2].min(), min=1e-8)
                            candle_accuracy = (1.0 - torch.clamp(avg_rmse / price_range, 0, 1)).item()
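                            # Numeric example (illustrative): avg_rmse of 2.5 on a
                            # batch whose high-low span is 50 gives accuracy
                            # 1 - 2.5/50 = 0.95; an RMSE at or above the full span
                            # clamps to accuracy 0.
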
                            candle_rmse = {
                                'open': rmse_open.item(),
                                'high': rmse_high.item(),
                                'low': rmse_low.item(),
                                'close': rmse_close.item(),
                                'avg': avg_rmse.item()
                            }

                # SECONDARY: trend vector prediction accuracy
                trend_accuracy = 0.0
                if 'trend_analysis' in outputs and 'trend_target' in batch:
                    pred_angle = outputs['trend_analysis']['angle_radians']
                    pred_steepness = outputs['trend_analysis']['steepness']

                    actual_angle = batch['trend_target'][:, 0:1]
                    actual_steepness = batch['trend_target'][:, 1:2]

                    # Angle error in degrees
                    angle_error_rad = torch.abs(pred_angle - actual_angle)
                    angle_error_deg = angle_error_rad * 180.0 / math.pi
                    angle_accuracy = (1.0 - torch.clamp(angle_error_deg / 180.0, 0, 1)).mean()

                    # Steepness error as a fraction of the target
                    steepness_error = torch.abs(pred_steepness - actual_steepness) / (actual_steepness + 1e-8)
                    steepness_accuracy = (1.0 - torch.clamp(steepness_error, 0, 1)).mean()

                    trend_accuracy = ((angle_accuracy + steepness_accuracy) / 2).item()
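                # Numeric example (illustrative): a predicted angle off by 18 degrees
                # scores 1 - 18/180 = 0.9; a predicted steepness 10% above the target
                # also scores 0.9, so trend_accuracy would be 0.9.
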
                # LEGACY: action accuracy (kept for comparison)
                action_predictions = torch.argmax(outputs['action_logits'], dim=-1)
                action_accuracy = (action_predictions == batch['actions']).float().mean().item()

            # Extract scalar values so the tensors below can be freed
            result = {
                'total_loss': total_loss.item(),
                'action_loss': action_loss.item(),
                'price_loss': price_loss.item(),
                'trend_loss': trend_loss.item() if isinstance(trend_loss, torch.Tensor) else 0.0,
                'candle_loss': candle_loss.item() if isinstance(candle_loss, torch.Tensor) else 0.0,
                'candle_loss_denorm': candle_losses_denorm,  # Denormalized losses per timeframe

                # NEW: realistic accuracy metrics based on next-candle prediction
                'accuracy': candle_accuracy,         # PRIMARY: next-candle prediction accuracy
                'candle_accuracy': candle_accuracy,  # Same as 'accuracy'
                'candle_rmse': candle_rmse,          # Detailed RMSE per OHLC component
                'trend_accuracy': trend_accuracy,    # Trend vector accuracy (angle + steepness)
                'action_accuracy': action_accuracy,  # Legacy action accuracy

                'learning_rate': self.scheduler.get_last_lr()[0]
            }

            # CRITICAL: delete large tensors immediately to prevent memory
            # accumulation across batches
            del outputs, total_loss, action_loss, price_loss, trend_loss, candle_loss, action_predictions

            # Delete the GPU copies of the batch tensors
            for key in list(batch.keys()):
                if isinstance(batch[key], torch.Tensor):
                    del batch[key]
            del batch

            # Clear the CUDA cache and log GPU memory usage
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

                # Log GPU memory usage every 10 steps
                if not hasattr(self, '_step_counter'):
                    self._step_counter = 0
                self._step_counter += 1

                if self._step_counter % 10 == 0:
                    allocated = torch.cuda.memory_allocated() / 1024**2
                    reserved = torch.cuda.memory_reserved() / 1024**2
                    logger.debug(f"GPU Memory: {allocated:.1f}MB allocated, {reserved:.1f}MB reserved")

            return result

        except torch.cuda.OutOfMemoryError as oom_error:
            logger.error(f"CUDA OOM in train_step: {oom_error}")
            # Aggressive cleanup on OOM
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
            # Drop any partial gradients; set_to_none=True also frees their memory
            self.optimizer.zero_grad(set_to_none=True)
            # Return zero losses so training can continue
            return {
                'total_loss': 0.0,
                'action_loss': 0.0,
                'price_loss': 0.0,
                'accuracy': 0.0,
                'candle_accuracy': 0.0,
                'learning_rate': self.scheduler.get_last_lr()[0] if hasattr(self, 'scheduler') else 0.0
            }
        except Exception as e:
            logger.error(f"Error in train_step: {e}", exc_info=True)
            # Clear any partial computations
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            # Return a zero-loss dict so training does not crash,
            # while the logged traceback remains available for debugging
            return {
                'total_loss': 0.0,
                'action_loss': 0.0,
                'price_loss': 0.0,
                'accuracy': 0.0,
                'candle_accuracy': 0.0,
                'learning_rate': self.scheduler.get_last_lr()[0] if hasattr(self, 'scheduler') else 0.0
            }

    def validate(self, val_loader: DataLoader) -> Dict[str, float]:
        """Validation step"""
        self.model.eval()
        total_loss = 0.0
        total_accuracy = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in val_loader:
                batch = {k: v.to(self.device) for k, v in batch.items()}

                outputs = self.model(
                    batch['price_data'],
                    batch['cob_data'],
                    batch['tech_data'],
                    batch['market_data']
                )

                # Calculate losses
                action_loss = self.action_criterion(outputs['action_logits'], batch['actions'])
                price_loss = self.price_criterion(outputs['price_prediction'], batch['future_prices'])
                total_loss += action_loss.item() + 0.1 * price_loss.item()
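
                # Note: validation tracks only the action and price components, so
                # val_loss is not on the same scale as train_step's total_loss,
                # which also includes trend, candle, and confidence terms.
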
                # Calculate accuracy
                predictions = torch.argmax(outputs['action_logits'], dim=-1)
                accuracy = (predictions == batch['actions']).float().mean()
                total_accuracy += accuracy.item()

                num_batches += 1

        # Guard against an empty validation loader
        if num_batches == 0:
            return {'val_loss': float('inf'), 'val_accuracy': 0.0}

        return {
            'val_loss': total_loss / num_batches,
            'val_accuracy': total_accuracy / num_batches
        }

    def train(self, train_loader: DataLoader, val_loader: DataLoader,
              epochs: int, save_path: str = "NN/models/saved/"):
        """Full training loop"""
        best_val_loss = float('inf')

        for epoch in range(epochs):
            # Training
            epoch_losses = []
            epoch_accuracies = []

            for batch in train_loader:
                metrics = self.train_step(batch)
                # Note: failed steps report 0.0 losses (see train_step's error
                # handling), which can drag the epoch averages down
                epoch_losses.append(metrics['total_loss'])
                epoch_accuracies.append(metrics['accuracy'])

            # Validation
            val_metrics = self.validate(val_loader)

            # Update history
            avg_train_loss = np.mean(epoch_losses)
            avg_train_accuracy = np.mean(epoch_accuracies)

            self.training_history['train_loss'].append(avg_train_loss)
            self.training_history['val_loss'].append(val_metrics['val_loss'])
            self.training_history['train_accuracy'].append(avg_train_accuracy)
            self.training_history['val_accuracy'].append(val_metrics['val_accuracy'])
            self.training_history['learning_rates'].append(self.scheduler.get_last_lr()[0])

            # Logging
            logger.info(f"Epoch {epoch + 1}/{epochs}")
            logger.info(f"  Train Loss: {avg_train_loss:.4f}, Train Acc: {avg_train_accuracy:.4f}")
            logger.info(f"  Val Loss: {val_metrics['val_loss']:.4f}, Val Acc: {val_metrics['val_accuracy']:.4f}")
            logger.info(f"  LR: {self.scheduler.get_last_lr()[0]:.6f}")

            # Save best model
            if val_metrics['val_loss'] < best_val_loss:
                best_val_loss = val_metrics['val_loss']
                self.save_model(os.path.join(save_path, 'best_transformer_model.pt'))
                logger.info(f"  New best model saved (val_loss: {best_val_loss:.4f})")

    def save_model(self, path: str):
        """Save model and training state"""
        os.makedirs(os.path.dirname(path), exist_ok=True)

        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'config': self.config,
            'training_history': self.training_history
        }, path)

        logger.info(f"Model saved to {path}")

    def load_model(self, path: str):
        """Load model and training state"""
        # The checkpoint stores the config dataclass and training history, not just
        # tensors, so full unpickling is needed (newer PyTorch defaults torch.load
        # to weights_only=True)
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.training_history = checkpoint.get('training_history', self.training_history)

        logger.info(f"Model loaded from {path}")
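
# Inference-only restore sketch (illustrative, assuming the checkpoint layout
# written by save_model above):
#   ckpt = torch.load('NN/models/saved/best_transformer_model.pt',
#                     map_location='cpu', weights_only=False)
#   model = AdvancedTradingTransformer(ckpt['config'])
#   model.load_state_dict(ckpt['model_state_dict'])
#   model.eval()  # optimizer/scheduler state can be skipped for inference
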
def create_trading_transformer(config: Optional[TradingTransformerConfig] = None) -> Tuple[AdvancedTradingTransformer, TradingTransformerTrainer]:
    """Factory function to create a trading transformer and its trainer"""
    if config is None:
        config = TradingTransformerConfig()

    model = AdvancedTradingTransformer(config)
    trainer = TradingTransformerTrainer(model, config)

    return model, trainer

# Example usage
if __name__ == "__main__":
    # Emit logger output when run as a script
    logging.basicConfig(level=logging.INFO)

    # Create configuration
    config = TradingTransformerConfig(
        d_model=256,
        n_heads=8,
        n_layers=4,
        seq_len=50,
        n_actions=3,
        use_multi_scale_attention=True,
        use_market_regime_detection=True,
        use_uncertainty_estimation=True
    )

    # Create model and trainer
    model, trainer = create_trading_transformer(config)

    logger.info(f"Created Advanced Trading Transformer with {sum(p.numel() for p in model.parameters()):,} parameters")
    logger.info("Model is ready for training on real market data!")