reduce T model size to fit in GPU during training.
test model size
This commit is contained in:
@@ -9,6 +9,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
import numpy as np
|
||||
import math
|
||||
import logging
|
||||
@@ -23,15 +24,15 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class TradingTransformerConfig:
|
||||
"""Configuration for trading transformer models - WITH PROPER MEMORY MANAGEMENT"""
|
||||
# Model architecture - RESTORED to original size (memory leak fixed)
|
||||
d_model: int = 1024 # Model dimension
|
||||
n_heads: int = 16 # Number of attention heads
|
||||
n_layers: int = 12 # Number of transformer layers
|
||||
d_ff: int = 4096 # Feed-forward dimension
|
||||
"""Configuration for trading transformer models - OPTIMIZED FOR GPU (8-12M params)"""
|
||||
# Model architecture - REDUCED for efficient GPU training
|
||||
d_model: int = 256 # Model dimension (was 1024)
|
||||
n_heads: int = 8 # Number of attention heads (was 16)
|
||||
n_layers: int = 4 # Number of transformer layers (was 12)
|
||||
d_ff: int = 1024 # Feed-forward dimension (was 4096)
|
||||
dropout: float = 0.1 # Dropout rate
|
||||
|
||||
# Input dimensions - RESTORED
|
||||
# Input dimensions - OPTIMIZED
|
||||
seq_len: int = 200 # Sequence length for time series
|
||||
cob_features: int = 100 # COB feature dimension
|
||||
tech_features: int = 40 # Technical indicator features
|
||||
@@ -111,59 +112,30 @@ class RelativePositionalEncoding(nn.Module):
|
||||
return self.relative_position_embeddings(final_mat)
|
||||
|
||||
class DeepMultiScaleAttention(nn.Module):
|
||||
"""Enhanced multi-scale attention with deeper mechanisms for 46M parameter model"""
|
||||
"""Lightweight multi-scale attention optimized for 8-12M parameter model"""
|
||||
|
||||
def __init__(self, d_model: int, n_heads: int, scales: List[int] = [1, 3, 5, 7, 11, 15]):
|
||||
def __init__(self, d_model: int, n_heads: int, scales: List[int] = [1, 3, 5]):
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.n_heads = n_heads
|
||||
self.scales = scales
|
||||
self.scales = scales # Reduced from 6 scales to 3
|
||||
self.head_dim = d_model // n_heads
|
||||
|
||||
assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
|
||||
|
||||
# Enhanced multi-scale projections with deeper architecture
|
||||
# Lightweight multi-scale projections (single layer instead of deep)
|
||||
self.scale_projections = nn.ModuleList([
|
||||
nn.ModuleDict({
|
||||
'query': nn.Sequential(
|
||||
nn.Linear(d_model, d_model * 2),
|
||||
nn.GELU(),
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(d_model * 2, d_model)
|
||||
),
|
||||
'key': nn.Sequential(
|
||||
nn.Linear(d_model, d_model * 2),
|
||||
nn.GELU(),
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(d_model * 2, d_model)
|
||||
),
|
||||
'value': nn.Sequential(
|
||||
nn.Linear(d_model, d_model * 2),
|
||||
nn.GELU(),
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(d_model * 2, d_model)
|
||||
),
|
||||
'conv': nn.Sequential(
|
||||
nn.Conv1d(d_model, d_model * 2, kernel_size=scale,
|
||||
padding=scale//2, groups=d_model),
|
||||
nn.GELU(),
|
||||
nn.Conv1d(d_model * 2, d_model, kernel_size=1)
|
||||
)
|
||||
'query': nn.Linear(d_model, d_model),
|
||||
'key': nn.Linear(d_model, d_model),
|
||||
'value': nn.Linear(d_model, d_model),
|
||||
'conv': nn.Conv1d(d_model, d_model, kernel_size=scale,
|
||||
padding=scale//2, groups=d_model//4)
|
||||
}) for scale in scales
|
||||
])
|
||||
|
||||
# Enhanced output projection with residual connection
|
||||
self.output_projection = nn.Sequential(
|
||||
nn.Linear(d_model * len(scales), d_model * 2),
|
||||
nn.GELU(),
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(d_model * 2, d_model)
|
||||
)
|
||||
|
||||
# Additional attention mechanisms
|
||||
self.cross_scale_attention = nn.MultiheadAttention(
|
||||
d_model, n_heads // 2, dropout=0.1, batch_first=True
|
||||
)
|
||||
# Lightweight output projection
|
||||
self.output_projection = nn.Linear(d_model * len(scales), d_model)
|
||||
|
||||
self.dropout = nn.Dropout(0.1)
|
||||
|
||||
@@ -199,15 +171,11 @@ class DeepMultiScaleAttention(nn.Module):
|
||||
|
||||
scale_outputs.append(output)
|
||||
|
||||
# Combine multi-scale outputs with enhanced projection
|
||||
# Combine multi-scale outputs
|
||||
combined = torch.cat(scale_outputs, dim=-1)
|
||||
output = self.output_projection(combined)
|
||||
|
||||
# Apply cross-scale attention for better integration
|
||||
cross_attended, _ = self.cross_scale_attention(output, output, output, attn_mask=mask)
|
||||
|
||||
# Residual connection
|
||||
return output + cross_attended
|
||||
return output
|
||||
|
||||
class MarketRegimeDetector(nn.Module):
|
||||
"""Market regime detection module for adaptive behavior"""
|
||||
@@ -358,35 +326,29 @@ class AdvancedTradingTransformer(nn.Module):
|
||||
|
||||
# SERIAL: Shared pattern encoder (learns candle patterns ONCE for all timeframes)
|
||||
# This is applied to each timeframe independently but uses SAME weights
|
||||
# RESTORED: Original dimensions (memory leak fixed)
|
||||
# LIGHTWEIGHT: 2-layer encoder for efficiency
|
||||
self.shared_pattern_encoder = nn.Sequential(
|
||||
nn.Linear(5, config.d_model // 4), # 5 OHLCV -> 256
|
||||
nn.LayerNorm(config.d_model // 4),
|
||||
nn.GELU(),
|
||||
nn.Dropout(config.dropout),
|
||||
nn.Linear(config.d_model // 4, config.d_model // 2), # 256 -> 512
|
||||
nn.Linear(5, config.d_model // 2), # 5 OHLCV -> 128
|
||||
nn.LayerNorm(config.d_model // 2),
|
||||
nn.GELU(),
|
||||
nn.Dropout(config.dropout),
|
||||
nn.Linear(config.d_model // 2, config.d_model) # 512 -> 1024
|
||||
nn.Linear(config.d_model // 2, config.d_model) # 128 -> 256
|
||||
)
|
||||
|
||||
# Timeframe-specific embeddings (learnable, added to shared encoding)
|
||||
# These help the model distinguish which timeframe it's looking at
|
||||
self.timeframe_embeddings = nn.Embedding(self.num_timeframes, config.d_model)
|
||||
|
||||
# PARALLEL: Cross-timeframe attention layers
|
||||
# These process all timeframes simultaneously to capture dependencies
|
||||
self.cross_timeframe_layers = nn.ModuleList([
|
||||
nn.TransformerEncoderLayer(
|
||||
d_model=config.d_model,
|
||||
nhead=config.n_heads,
|
||||
dim_feedforward=config.d_ff,
|
||||
dropout=config.dropout,
|
||||
activation='gelu',
|
||||
batch_first=True
|
||||
) for _ in range(2) # 2 layers for cross-timeframe attention
|
||||
])
|
||||
# PARALLEL: Cross-timeframe attention layer (single layer for efficiency)
|
||||
# Processes all timeframes simultaneously to capture dependencies
|
||||
self.cross_timeframe_layer = nn.TransformerEncoderLayer(
|
||||
d_model=config.d_model,
|
||||
nhead=config.n_heads,
|
||||
dim_feedforward=config.d_ff,
|
||||
dropout=config.dropout,
|
||||
activation='gelu',
|
||||
batch_first=True
|
||||
)
|
||||
|
||||
# Other input projections
|
||||
self.cob_projection = nn.Linear(config.cob_features, config.d_model)
|
||||
@@ -415,11 +377,8 @@ class AdvancedTradingTransformer(nn.Module):
|
||||
TradingTransformerLayer(config) for _ in range(config.n_layers)
|
||||
])
|
||||
|
||||
# Enhanced output heads for 46M parameter model
|
||||
# Lightweight output heads for 8-12M parameter model
|
||||
self.action_head = nn.Sequential(
|
||||
nn.Linear(config.d_model, config.d_model),
|
||||
nn.GELU(),
|
||||
nn.Dropout(config.dropout),
|
||||
nn.Linear(config.d_model, config.d_model // 2),
|
||||
nn.GELU(),
|
||||
nn.Dropout(config.dropout),
|
||||
@@ -431,10 +390,7 @@ class AdvancedTradingTransformer(nn.Module):
|
||||
nn.Linear(config.d_model, config.d_model // 2),
|
||||
nn.GELU(),
|
||||
nn.Dropout(config.dropout),
|
||||
nn.Linear(config.d_model // 2, config.d_model // 4),
|
||||
nn.GELU(),
|
||||
nn.Dropout(config.dropout),
|
||||
nn.Linear(config.d_model // 4, 1),
|
||||
nn.Linear(config.d_model // 2, 1),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
@@ -442,92 +398,63 @@ class AdvancedTradingTransformer(nn.Module):
|
||||
if config.use_uncertainty_estimation:
|
||||
self.uncertainty_estimator = UncertaintyEstimation(config.d_model)
|
||||
|
||||
# Enhanced price prediction head (auxiliary task)
|
||||
# Predicts price change ratio (future_price - current_price) / current_price
|
||||
# Use Tanh to constrain to [-1, 1] range (max 100% change up/down)
|
||||
# Lightweight price prediction head
|
||||
self.price_head = nn.Sequential(
|
||||
nn.Linear(config.d_model, config.d_model // 2),
|
||||
nn.GELU(),
|
||||
nn.Dropout(config.dropout),
|
||||
nn.Linear(config.d_model // 2, config.d_model // 4),
|
||||
nn.GELU(),
|
||||
nn.Dropout(config.dropout),
|
||||
nn.Linear(config.d_model // 4, 1),
|
||||
nn.Linear(config.d_model // 2, 1),
|
||||
nn.Tanh() # Constrain to [-1, 1] range for price change ratio
|
||||
)
|
||||
|
||||
# Additional specialized heads for 46M model
|
||||
# Lightweight volatility and trend heads
|
||||
self.volatility_head = nn.Sequential(
|
||||
nn.Linear(config.d_model, config.d_model // 2),
|
||||
nn.Linear(config.d_model, config.d_model // 4),
|
||||
nn.GELU(),
|
||||
nn.Dropout(config.dropout),
|
||||
nn.Linear(config.d_model // 2, 1),
|
||||
nn.Linear(config.d_model // 4, 1),
|
||||
nn.Softplus()
|
||||
)
|
||||
|
||||
self.trend_strength_head = nn.Sequential(
|
||||
nn.Linear(config.d_model, config.d_model // 2),
|
||||
nn.Linear(config.d_model, config.d_model // 4),
|
||||
nn.GELU(),
|
||||
nn.Dropout(config.dropout),
|
||||
nn.Linear(config.d_model // 2, 1),
|
||||
nn.Linear(config.d_model // 4, 1),
|
||||
nn.Tanh()
|
||||
)
|
||||
|
||||
# NEW: Next candle OHLCV prediction heads for each timeframe (1s, 1m, 1h, 1d)
|
||||
# Each timeframe predicts: [open, high, low, close, volume] = 5 values
|
||||
# Note: self.timeframes already defined above in input projections
|
||||
# CRITICAL: Outputs are constrained to [0, 1] range using Sigmoid since inputs are normalized
|
||||
# Lightweight next candle OHLCV prediction heads
|
||||
self.next_candle_heads = nn.ModuleDict({
|
||||
tf: nn.Sequential(
|
||||
nn.Linear(config.d_model, config.d_model // 2),
|
||||
nn.GELU(),
|
||||
nn.Dropout(config.dropout),
|
||||
nn.Linear(config.d_model // 2, config.d_model // 4),
|
||||
nn.GELU(),
|
||||
nn.Dropout(config.dropout),
|
||||
nn.Linear(config.d_model // 4, 5), # OHLCV: [open, high, low, close, volume]
|
||||
nn.Sigmoid() # Constrain to [0, 1] to match normalized input range
|
||||
nn.Linear(config.d_model // 2, 5), # OHLCV
|
||||
nn.Sigmoid() # Constrain to [0, 1]
|
||||
) for tf in self.timeframes
|
||||
})
|
||||
|
||||
# BTC next candle prediction head
|
||||
# CRITICAL: Outputs are constrained to [0, 1] range using Sigmoid since inputs are normalized
|
||||
self.btc_next_candle_head = nn.Sequential(
|
||||
nn.Linear(config.d_model, config.d_model // 2),
|
||||
nn.GELU(),
|
||||
nn.Dropout(config.dropout),
|
||||
nn.Linear(config.d_model // 2, config.d_model // 4),
|
||||
nn.GELU(),
|
||||
nn.Dropout(config.dropout),
|
||||
nn.Linear(config.d_model // 4, 5), # OHLCV for BTC
|
||||
nn.Sigmoid() # Constrain to [0, 1] to match normalized input range
|
||||
nn.Linear(config.d_model // 2, 5), # OHLCV for BTC
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
# NEW: Next pivot point prediction heads for L1-L5 levels
|
||||
# Each level predicts: [price, type_prob_high, type_prob_low, confidence]
|
||||
# type_prob_high + type_prob_low = 1 (softmax), but we output separately for clarity
|
||||
self.pivot_levels = [1, 2, 3, 4, 5] # L1 to L5
|
||||
# Lightweight pivot point prediction heads (L1-L3 only for efficiency)
|
||||
self.pivot_levels = [1, 2, 3] # Reduced from L1-L5 to L1-L3
|
||||
self.pivot_heads = nn.ModuleDict({
|
||||
f'L{level}': nn.Sequential(
|
||||
nn.Linear(config.d_model, config.d_model // 2),
|
||||
nn.GELU(),
|
||||
nn.Dropout(config.dropout),
|
||||
nn.Linear(config.d_model // 2, config.d_model // 4),
|
||||
nn.GELU(),
|
||||
nn.Dropout(config.dropout),
|
||||
nn.Linear(config.d_model // 4, 4) # [price, type_prob_high, type_prob_low, confidence]
|
||||
nn.Linear(config.d_model // 2, 4) # [price, type_prob_high, type_prob_low, confidence]
|
||||
) for level in self.pivot_levels
|
||||
})
|
||||
|
||||
# NEW: Trend vector analysis head (calculates trend from pivot predictions)
|
||||
# Lightweight trend vector analysis head
|
||||
self.trend_analysis_head = nn.Sequential(
|
||||
nn.Linear(config.d_model, config.d_model // 2),
|
||||
nn.GELU(),
|
||||
nn.Dropout(config.dropout),
|
||||
nn.Linear(config.d_model // 2, config.d_model // 4),
|
||||
nn.GELU(),
|
||||
nn.Dropout(config.dropout),
|
||||
nn.Linear(config.d_model // 4, 3) # [angle_radians, steepness, direction]
|
||||
nn.Linear(config.d_model // 2, 3) # [angle_radians, steepness, direction]
|
||||
)
|
||||
|
||||
# Initialize weights
|
||||
@@ -654,9 +581,8 @@ class AdvancedTradingTransformer(nn.Module):
|
||||
# This avoids creating huge concatenated sequences while still processing efficiently
|
||||
batched_tfs = stacked_tfs.reshape(batch_size * num_tfs, seq_len, self.config.d_model)
|
||||
|
||||
# Apply attention layers (shared across timeframes)
|
||||
for layer in self.cross_timeframe_layers:
|
||||
batched_tfs = layer(batched_tfs)
|
||||
# Apply single cross-timeframe attention layer
|
||||
batched_tfs = self.cross_timeframe_layer(batched_tfs)
|
||||
|
||||
# Reshape back: [batch*num_tfs, seq_len, d_model] -> [batch, num_tfs, seq_len, d_model]
|
||||
cross_tf_output = batched_tfs.reshape(batch_size, num_tfs, seq_len, self.config.d_model)
|
||||
@@ -723,7 +649,7 @@ class AdvancedTradingTransformer(nn.Module):
|
||||
if self.training and self.config.use_gradient_checkpointing:
|
||||
# Use gradient checkpointing to save memory during training
|
||||
# Trades compute for memory (recomputes activations during backward pass)
|
||||
layer_output = torch.utils.checkpoint.checkpoint(
|
||||
layer_output = checkpoint(
|
||||
layer, x, mask, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
@@ -1180,7 +1106,7 @@ class TradingTransformerTrainer:
|
||||
original_forward = layer.attention.forward
|
||||
|
||||
def checkpointed_attention_forward(*args, **kwargs):
|
||||
return torch.utils.checkpoint.checkpoint(
|
||||
return checkpoint(
|
||||
original_forward, *args, **kwargs, use_reentrant=False
|
||||
)
|
||||
|
||||
@@ -1191,7 +1117,7 @@ class TradingTransformerTrainer:
|
||||
original_ff_forward = layer.feed_forward.forward
|
||||
|
||||
def checkpointed_ff_forward(*args, **kwargs):
|
||||
return torch.utils.checkpoint.checkpoint(
|
||||
return checkpoint(
|
||||
original_ff_forward, *args, **kwargs, use_reentrant=False
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user