#!/usr/bin/env python3
"""
Advanced Transformer Models for High-Frequency Trading
Optimized for COB data, technical indicators, and market microstructure
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import math
import logging
from typing import Dict, Any, Optional, Tuple, List
from dataclasses import dataclass
import os
import json
from datetime import datetime

# Configure logging
logger = logging.getLogger(__name__)

@dataclass
class TradingTransformerConfig:
    """Configuration for trading transformer models - SCALED TO 46M PARAMETERS"""
    # Model architecture - SCALED UP
    d_model: int = 1024      # Model dimension (2x increase)
    n_heads: int = 16        # Number of attention heads (2x increase)
    n_layers: int = 12       # Number of transformer layers (2x increase)
    d_ff: int = 4096         # Feed-forward dimension (2x increase)
    dropout: float = 0.1     # Dropout rate

    # Input dimensions - ENHANCED
    seq_len: int = 150           # Sequence length for time series (1.5x increase)
    cob_features: int = 100      # COB feature dimension (2x increase)
    tech_features: int = 40      # Technical indicator features (2x increase)
    market_features: int = 30    # Market microstructure features (2x increase)

    # Output configuration
    n_actions: int = 3               # BUY, SELL, HOLD
    confidence_output: bool = True   # Output confidence scores

    # Training configuration - OPTIMIZED FOR LARGER MODEL
    learning_rate: float = 5e-5   # Reduced for larger model
    weight_decay: float = 1e-4    # Increased regularization
    warmup_steps: int = 8000      # More warmup steps
    max_grad_norm: float = 0.5    # Tighter gradient clipping

    # Advanced features - ENHANCED
    use_relative_position: bool = True
    use_multi_scale_attention: bool = True
    use_market_regime_detection: bool = True
    use_uncertainty_estimation: bool = True

    # NEW: Additional scaling features
    use_deep_attention: bool = True         # Deeper attention mechanisms
    use_residual_connections: bool = True   # Enhanced residual connections
    use_layer_norm_variants: bool = True    # Advanced normalization

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for transformer"""

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)

        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.pe[:x.size(0), :]

class RelativePositionalEncoding(nn.Module):
    """Relative positional encoding for better temporal understanding"""

    def __init__(self, d_model: int, max_relative_position: int = 128):
        super().__init__()
        self.d_model = d_model
        self.max_relative_position = max_relative_position

        # Learnable relative position embeddings
        self.relative_position_embeddings = nn.Embedding(
            2 * max_relative_position + 1, d_model
        )

    def forward(self, seq_len: int) -> torch.Tensor:
        """Generate relative position encoding matrix"""
        range_vec = torch.arange(seq_len)
        range_mat = range_vec.unsqueeze(0).repeat(seq_len, 1)
        distance_mat = range_mat - range_mat.transpose(0, 1)

        # Clip to max relative position
        distance_mat_clipped = torch.clamp(
            distance_mat, -self.max_relative_position, self.max_relative_position
        )

        # Shift to positive indices
        final_mat = distance_mat_clipped + self.max_relative_position

        return self.relative_position_embeddings(final_mat)

class DeepMultiScaleAttention(nn.Module):
    """Enhanced multi-scale attention with deeper mechanisms for 46M parameter model"""

    def __init__(self, d_model: int, n_heads: int, scales: List[int] = [1, 3, 5, 7, 11, 15]):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.scales = scales
        self.head_dim = d_model // n_heads

        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        # Enhanced multi-scale projections with deeper architecture
        self.scale_projections = nn.ModuleList([
            nn.ModuleDict({
                'query': nn.Sequential(
                    nn.Linear(d_model, d_model * 2),
                    nn.GELU(),
                    nn.Dropout(0.1),
                    nn.Linear(d_model * 2, d_model)
                ),
                'key': nn.Sequential(
                    nn.Linear(d_model, d_model * 2),
                    nn.GELU(),
                    nn.Dropout(0.1),
                    nn.Linear(d_model * 2, d_model)
                ),
                'value': nn.Sequential(
                    nn.Linear(d_model, d_model * 2),
                    nn.GELU(),
                    nn.Dropout(0.1),
                    nn.Linear(d_model * 2, d_model)
                ),
                'conv': nn.Sequential(
                    nn.Conv1d(d_model, d_model * 2, kernel_size=scale,
                              padding=scale // 2, groups=d_model),
                    nn.GELU(),
                    nn.Conv1d(d_model * 2, d_model, kernel_size=1)
                )
            }) for scale in scales
        ])

        # Enhanced output projection with residual connection
        self.output_projection = nn.Sequential(
            nn.Linear(d_model * len(scales), d_model * 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(d_model * 2, d_model)
        )

        # Additional attention mechanisms
        self.cross_scale_attention = nn.MultiheadAttention(
            d_model, n_heads // 2, dropout=0.1, batch_first=True
        )

        self.dropout = nn.Dropout(0.1)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch_size, seq_len, _ = x.size()
        scale_outputs = []

        for scale_proj in self.scale_projections:
            # Apply enhanced temporal convolution for this scale
            x_conv = scale_proj['conv'](x.transpose(1, 2)).transpose(1, 2)

            # Enhanced attention computation with deeper projections
            Q = scale_proj['query'](x_conv).view(batch_size, seq_len, self.n_heads, self.head_dim)
            K = scale_proj['key'](x_conv).view(batch_size, seq_len, self.n_heads, self.head_dim)
            V = scale_proj['value'](x_conv).view(batch_size, seq_len, self.n_heads, self.head_dim)

            # Transpose for attention computation
            Q = Q.transpose(1, 2)  # (batch, n_heads, seq_len, head_dim)
            K = K.transpose(1, 2)
            V = V.transpose(1, 2)

            # Scaled dot-product attention
            scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

            if mask is not None:
                scores.masked_fill_(mask == 0, -1e9)

            attention = F.softmax(scores, dim=-1)
            attention = self.dropout(attention)

            output = torch.matmul(attention, V)
            output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

            scale_outputs.append(output)

        # Combine multi-scale outputs with enhanced projection
        combined = torch.cat(scale_outputs, dim=-1)
        output = self.output_projection(combined)

        # Apply cross-scale attention for better integration
        cross_attended, _ = self.cross_scale_attention(output, output, output, attn_mask=mask)

        # Residual connection
        return output + cross_attended

class MarketRegimeDetector(nn.Module):
    """Market regime detection module for adaptive behavior"""

    def __init__(self, d_model: int, n_regimes: int = 4):
        super().__init__()
        self.d_model = d_model
        self.n_regimes = n_regimes

        # Regime classification layers
        self.regime_classifier = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(d_model // 2, n_regimes)
        )

        # Regime-specific transformations
        self.regime_transforms = nn.ModuleList([
            nn.Linear(d_model, d_model) for _ in range(n_regimes)
        ])

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Global pooling for regime detection
        pooled = torch.mean(x, dim=1)  # (batch, d_model)

        # Classify market regime
        regime_logits = self.regime_classifier(pooled)
        regime_probs = F.softmax(regime_logits, dim=-1)  # (batch, n_regimes)

        # Apply regime-specific transformations
        regime_outputs = []
        for transform in self.regime_transforms:
            regime_outputs.append(transform(x))  # (batch, seq_len, d_model)

        # Weighted combination based on regime probabilities
        regime_stack = torch.stack(regime_outputs, dim=0)  # (n_regimes, batch, seq_len, d_model)
        regime_weights = regime_probs.t().unsqueeze(-1).unsqueeze(-1)  # (n_regimes, batch, 1, 1)

        # Weighted sum across regimes
        adapted_output = torch.sum(regime_stack * regime_weights, dim=0)

        return adapted_output, regime_probs

class UncertaintyEstimation(nn.Module):
    """Uncertainty estimation using Monte Carlo Dropout"""

    def __init__(self, d_model: int, n_samples: int = 10):
        super().__init__()
        self.d_model = d_model
        self.n_samples = n_samples

        self.uncertainty_head = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Dropout(0.5),  # Higher dropout for uncertainty estimation
            nn.Linear(d_model // 2, 1),
            nn.Sigmoid()
        )

    def forward(self, x: torch.Tensor, training: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        if training or self.training:
            # Single forward pass during training
            uncertainty = self.uncertainty_head(x)
            return uncertainty, uncertainty

        # Monte Carlo dropout sampling during inference:
        # temporarily enable dropout so the samples actually differ
        self.uncertainty_head.train()
        uncertainties = []
        for _ in range(self.n_samples):
            uncertainties.append(self.uncertainty_head(x))
        self.uncertainty_head.eval()

        uncertainties = torch.stack(uncertainties, dim=0)
        mean_uncertainty = torch.mean(uncertainties, dim=0)
        std_uncertainty = torch.std(uncertainties, dim=0)

        return mean_uncertainty, std_uncertainty

class TradingTransformerLayer(nn.Module):
    """Enhanced transformer layer for trading applications"""

    def __init__(self, config: TradingTransformerConfig):
        super().__init__()
        self.config = config

        # Enhanced multi-scale attention or standard attention
        if config.use_multi_scale_attention:
            self.attention = DeepMultiScaleAttention(config.d_model, config.n_heads)
        else:
            self.attention = nn.MultiheadAttention(
                config.d_model, config.n_heads, dropout=config.dropout, batch_first=True
            )

        # Feed-forward network
        self.feed_forward = nn.Sequential(
            nn.Linear(config.d_model, config.d_ff),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_ff, config.d_model)
        )

        # Layer normalization
        self.norm1 = nn.LayerNorm(config.d_model)
        self.norm2 = nn.LayerNorm(config.d_model)

        # Dropout
        self.dropout = nn.Dropout(config.dropout)

        # Market regime detection
        if config.use_market_regime_detection:
            self.regime_detector = MarketRegimeDetector(config.d_model)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        # Self-attention with residual connection
        if isinstance(self.attention, DeepMultiScaleAttention):
            attn_output = self.attention(x, mask)
        else:
            attn_output, _ = self.attention(x, x, x, attn_mask=mask)

        x = self.norm1(x + self.dropout(attn_output))

        # Market regime adaptation
        regime_probs = None
        if hasattr(self, 'regime_detector'):
            x, regime_probs = self.regime_detector(x)

        # Feed-forward with residual connection
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))

        return {
            'output': x,
            'regime_probs': regime_probs
        }

class AdvancedTradingTransformer(nn.Module):
    """Advanced transformer model for high-frequency trading"""

    def __init__(self, config: TradingTransformerConfig):
        super().__init__()
        self.config = config

        # Input projections
        self.price_projection = nn.Linear(5, config.d_model)  # OHLCV
        self.cob_projection = nn.Linear(config.cob_features, config.d_model)
        self.tech_projection = nn.Linear(config.tech_features, config.d_model)
        self.market_projection = nn.Linear(config.market_features, config.d_model)

        # Positional encoding
        if config.use_relative_position:
            self.pos_encoding = RelativePositionalEncoding(config.d_model)
        else:
            self.pos_encoding = PositionalEncoding(config.d_model, config.seq_len)

        # Transformer layers
        self.layers = nn.ModuleList([
            TradingTransformerLayer(config) for _ in range(config.n_layers)
        ])

        # Enhanced output heads for 46M parameter model
        self.action_head = nn.Sequential(
            nn.Linear(config.d_model, config.d_model),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_model, config.d_model // 2),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_model // 2, config.n_actions)
        )

        if config.confidence_output:
            self.confidence_head = nn.Sequential(
                nn.Linear(config.d_model, config.d_model // 2),
                nn.GELU(),
                nn.Dropout(config.dropout),
                nn.Linear(config.d_model // 2, config.d_model // 4),
                nn.GELU(),
                nn.Dropout(config.dropout),
                nn.Linear(config.d_model // 4, 1),
                nn.Sigmoid()
            )

        # Enhanced uncertainty estimation
        if config.use_uncertainty_estimation:
            self.uncertainty_estimator = UncertaintyEstimation(config.d_model)

        # Enhanced price prediction head (auxiliary task)
        self.price_head = nn.Sequential(
            nn.Linear(config.d_model, config.d_model // 2),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_model // 2, config.d_model // 4),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_model // 4, 1)
        )

        # Additional specialized heads for 46M model
        self.volatility_head = nn.Sequential(
            nn.Linear(config.d_model, config.d_model // 2),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_model // 2, 1),
            nn.Softplus()
        )

        self.trend_strength_head = nn.Sequential(
            nn.Linear(config.d_model, config.d_model // 2),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_model // 2, 1),
            nn.Tanh()
        )

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialize model weights"""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, price_data: torch.Tensor, cob_data: torch.Tensor,
                tech_data: torch.Tensor, market_data: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """
        Forward pass of the trading transformer

        Args:
            price_data: (batch, seq_len, 5) - OHLCV data
            cob_data: (batch, seq_len, cob_features) - COB features
            tech_data: (batch, seq_len, tech_features) - Technical indicators
            market_data: (batch, seq_len, market_features) - Market microstructure
            mask: Optional attention mask

        Returns:
            Dictionary containing model outputs
        """
        batch_size, seq_len = price_data.shape[:2]

        # Project inputs to model dimension
        price_emb = self.price_projection(price_data)
        cob_emb = self.cob_projection(cob_data)
        tech_emb = self.tech_projection(tech_data)
        market_emb = self.market_projection(market_data)

        # Combine embeddings (could also use cross-attention)
        x = price_emb + cob_emb + tech_emb + market_emb

        # Add positional encoding
        if isinstance(self.pos_encoding, RelativePositionalEncoding):
            # Relative position encoding is applied in attention
            pass
        else:
            x = self.pos_encoding(x.transpose(0, 1)).transpose(0, 1)

        # Apply transformer layers
        regime_probs_history = []
        for layer in self.layers:
            layer_output = layer(x, mask)
            x = layer_output['output']
            if layer_output['regime_probs'] is not None:
                regime_probs_history.append(layer_output['regime_probs'])

        # Global pooling for final prediction
        # Use attention-based pooling
        pooling_weights = F.softmax(
            torch.sum(x, dim=-1, keepdim=True), dim=1
        )
        pooled = torch.sum(x * pooling_weights, dim=1)

        # Generate outputs
        outputs = {}

        # Action prediction
        action_logits = self.action_head(pooled)
        outputs['action_logits'] = action_logits
        outputs['action_probs'] = F.softmax(action_logits, dim=-1)

        # Confidence prediction
        if self.config.confidence_output:
            confidence = self.confidence_head(pooled)
            outputs['confidence'] = confidence

        # Uncertainty estimation
        if self.config.use_uncertainty_estimation:
            uncertainty_mean, uncertainty_std = self.uncertainty_estimator(pooled)
            outputs['uncertainty_mean'] = uncertainty_mean
            outputs['uncertainty_std'] = uncertainty_std

        # Enhanced price prediction (auxiliary task)
        price_pred = self.price_head(pooled)
        outputs['price_prediction'] = price_pred

        # Additional specialized predictions for 46M model
        volatility_pred = self.volatility_head(pooled)
        outputs['volatility_prediction'] = volatility_pred

        trend_strength_pred = self.trend_strength_head(pooled)
        outputs['trend_strength_prediction'] = trend_strength_pred

        # Market regime information
        if regime_probs_history:
            outputs['regime_probs'] = torch.stack(regime_probs_history, dim=1)

        return outputs

class TradingTransformerTrainer:
    """Trainer for the advanced trading transformer"""

    def __init__(self, model: AdvancedTradingTransformer, config: TradingTransformerConfig):
        self.model = model
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Move model to device
        self.model.to(self.device)

        # Optimizer with warmup
        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )

        # Learning rate scheduler
        self.scheduler = optim.lr_scheduler.OneCycleLR(
            self.optimizer,
            max_lr=config.learning_rate,
            total_steps=10000,  # Will be updated based on training data
            pct_start=0.1
        )

        # Loss functions
        self.action_criterion = nn.CrossEntropyLoss()
        self.price_criterion = nn.MSELoss()
        self.confidence_criterion = nn.BCELoss()

        # Training history
        self.training_history = {
            'train_loss': [],
            'val_loss': [],
            'train_accuracy': [],
            'val_accuracy': [],
            'learning_rates': []
        }

    def train_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """Single training step"""
        self.model.train()
        self.optimizer.zero_grad()

        # Move batch to device
        batch = {k: v.to(self.device) for k, v in batch.items()}

        # Forward pass
        outputs = self.model(
            batch['price_data'],
            batch['cob_data'],
            batch['tech_data'],
            batch['market_data']
        )

        # Calculate losses
        action_loss = self.action_criterion(outputs['action_logits'], batch['actions'])
        price_loss = self.price_criterion(outputs['price_prediction'], batch['future_prices'])

        total_loss = action_loss + 0.1 * price_loss  # Weight auxiliary task

        # Add confidence loss if available
        if 'confidence' in outputs and 'trade_success' in batch:
            confidence_loss = self.confidence_criterion(
                outputs['confidence'].squeeze(-1),
                batch['trade_success'].float()
            )
            total_loss = total_loss + 0.1 * confidence_loss

        # Backward pass
        total_loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)

        # Optimizer step
        self.optimizer.step()
        self.scheduler.step()

        # Calculate accuracy
        predictions = torch.argmax(outputs['action_logits'], dim=-1)
        accuracy = (predictions == batch['actions']).float().mean()

        return {
            'total_loss': total_loss.item(),
            'action_loss': action_loss.item(),
            'price_loss': price_loss.item(),
            'accuracy': accuracy.item(),
            'learning_rate': self.scheduler.get_last_lr()[0]
        }

    def validate(self, val_loader: DataLoader) -> Dict[str, float]:
        """Validation step"""
        self.model.eval()
        total_loss = 0
        total_accuracy = 0
        num_batches = 0

        with torch.no_grad():
            for batch in val_loader:
                batch = {k: v.to(self.device) for k, v in batch.items()}

                outputs = self.model(
                    batch['price_data'],
                    batch['cob_data'],
                    batch['tech_data'],
                    batch['market_data']
                )

                # Calculate losses
                action_loss = self.action_criterion(outputs['action_logits'], batch['actions'])
                price_loss = self.price_criterion(outputs['price_prediction'], batch['future_prices'])
                total_loss += action_loss.item() + 0.1 * price_loss.item()

                # Calculate accuracy
                predictions = torch.argmax(outputs['action_logits'], dim=-1)
                accuracy = (predictions == batch['actions']).float().mean()
                total_accuracy += accuracy.item()

                num_batches += 1

        return {
            'val_loss': total_loss / num_batches,
            'val_accuracy': total_accuracy / num_batches
        }

    def train(self, train_loader: DataLoader, val_loader: DataLoader,
              epochs: int, save_path: str = "NN/models/saved/"):
        """Full training loop"""
        best_val_loss = float('inf')

        for epoch in range(epochs):
            # Training
            epoch_losses = []
            epoch_accuracies = []

            for batch in train_loader:
                metrics = self.train_step(batch)
                epoch_losses.append(metrics['total_loss'])
                epoch_accuracies.append(metrics['accuracy'])

            # Validation
            val_metrics = self.validate(val_loader)

            # Update history
            avg_train_loss = np.mean(epoch_losses)
            avg_train_accuracy = np.mean(epoch_accuracies)

            self.training_history['train_loss'].append(avg_train_loss)
            self.training_history['val_loss'].append(val_metrics['val_loss'])
            self.training_history['train_accuracy'].append(avg_train_accuracy)
            self.training_history['val_accuracy'].append(val_metrics['val_accuracy'])
            self.training_history['learning_rates'].append(self.scheduler.get_last_lr()[0])

            # Logging
            logger.info(f"Epoch {epoch+1}/{epochs}")
            logger.info(f"  Train Loss: {avg_train_loss:.4f}, Train Acc: {avg_train_accuracy:.4f}")
            logger.info(f"  Val Loss: {val_metrics['val_loss']:.4f}, Val Acc: {val_metrics['val_accuracy']:.4f}")
            logger.info(f"  LR: {self.scheduler.get_last_lr()[0]:.6f}")

            # Save best model
            if val_metrics['val_loss'] < best_val_loss:
                best_val_loss = val_metrics['val_loss']
                self.save_model(os.path.join(save_path, 'best_transformer_model.pt'))
                logger.info(f"  New best model saved (val_loss: {best_val_loss:.4f})")

    def save_model(self, path: str):
        """Save model and training state"""
        os.makedirs(os.path.dirname(path), exist_ok=True)

        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'config': self.config,
            'training_history': self.training_history
        }, path)

        logger.info(f"Model saved to {path}")

    def load_model(self, path: str):
        """Load model and training state"""
        checkpoint = torch.load(path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.training_history = checkpoint.get('training_history', self.training_history)

        logger.info(f"Model loaded from {path}")

def create_trading_transformer(config: Optional[TradingTransformerConfig] = None) -> Tuple[AdvancedTradingTransformer, TradingTransformerTrainer]:
    """Factory function to create trading transformer and trainer"""
    if config is None:
        config = TradingTransformerConfig()

    model = AdvancedTradingTransformer(config)
    trainer = TradingTransformerTrainer(model, config)

    return model, trainer

# Example usage
if __name__ == "__main__":
    # Create configuration
    config = TradingTransformerConfig(
        d_model=256,
        n_heads=8,
        n_layers=4,
        seq_len=50,
        n_actions=3,
        use_multi_scale_attention=True,
        use_market_regime_detection=True,
        use_uncertainty_estimation=True
    )

    # Create model and trainer
    model, trainer = create_trading_transformer(config)

    logger.info(f"Created Advanced Trading Transformer with {sum(p.numel() for p in model.parameters())} parameters")
    logger.info("Model is ready for training on real market data!")