added transformer model to the mix
667
NN/models/advanced_transformer_trading.py
Normal file
@@ -0,0 +1,667 @@
#!/usr/bin/env python3
"""
Advanced Transformer Models for High-Frequency Trading
Optimized for COB data, technical indicators, and market microstructure
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import math
import logging
from typing import Dict, Any, Optional, Tuple, List
from dataclasses import dataclass
import os
import json
from datetime import datetime

# Configure logging
logger = logging.getLogger(__name__)

@dataclass
class TradingTransformerConfig:
    """Configuration for trading transformer models"""
    # Model architecture
    d_model: int = 512        # Model dimension
    n_heads: int = 8          # Number of attention heads
    n_layers: int = 6         # Number of transformer layers
    d_ff: int = 2048          # Feed-forward dimension
    dropout: float = 0.1      # Dropout rate

    # Input dimensions
    seq_len: int = 100        # Sequence length for time series
    cob_features: int = 50    # COB feature dimension
    tech_features: int = 20   # Technical indicator features
    market_features: int = 15 # Market microstructure features

    # Output configuration
    n_actions: int = 3        # BUY, SELL, HOLD
    confidence_output: bool = True  # Output confidence scores

    # Training configuration
    learning_rate: float = 1e-4
    weight_decay: float = 1e-5
    warmup_steps: int = 4000
    max_grad_norm: float = 1.0

    # Advanced features
    use_relative_position: bool = True
    use_multi_scale_attention: bool = True
    use_market_regime_detection: bool = True
    use_uncertainty_estimation: bool = True

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for transformer"""

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)

        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Expects x of shape (seq_len, batch, d_model)
        return x + self.pe[:x.size(0), :]

class RelativePositionalEncoding(nn.Module):
    """Relative positional encoding for better temporal understanding"""

    def __init__(self, d_model: int, max_relative_position: int = 128):
        super().__init__()
        self.d_model = d_model
        self.max_relative_position = max_relative_position

        # Learnable relative position embeddings
        self.relative_position_embeddings = nn.Embedding(
            2 * max_relative_position + 1, d_model
        )

    def forward(self, seq_len: int) -> torch.Tensor:
        """Generate relative position encoding matrix"""
        range_vec = torch.arange(seq_len)
        range_mat = range_vec.unsqueeze(0).repeat(seq_len, 1)
        distance_mat = range_mat - range_mat.transpose(0, 1)

        # Clip to max relative position
        distance_mat_clipped = torch.clamp(
            distance_mat, -self.max_relative_position, self.max_relative_position
        )

        # Shift to positive indices
        final_mat = distance_mat_clipped + self.max_relative_position

        return self.relative_position_embeddings(final_mat)
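
# Worked example for RelativePositionalEncoding.forward (a sketch, assuming seq_len=3 and
# the default max_relative_position=128): the distance matrix has entry (i, j) = j - i,
# i.e. [[0, 1, 2], [-1, 0, 1], [-2, -1, 0]]; clamping leaves it unchanged and the +128
# shift maps it to [[128, 129, 130], [127, 128, 129], [126, 127, 128]], so the embedding
# lookup returns a (3, 3, d_model) tensor of pairwise relative-position vectors.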

class MultiScaleAttention(nn.Module):
    """Multi-scale attention for capturing different time horizons"""

    def __init__(self, d_model: int, n_heads: int, scales: List[int] = [1, 3, 5, 7]):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.scales = scales
        self.head_dim = d_model // n_heads

        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        # Multi-scale projections
        self.scale_projections = nn.ModuleList([
            nn.ModuleDict({
                'query': nn.Linear(d_model, d_model),
                'key': nn.Linear(d_model, d_model),
                'value': nn.Linear(d_model, d_model),
                'conv': nn.Conv1d(d_model, d_model, kernel_size=scale,
                                  padding=scale//2, groups=d_model)
            }) for scale in scales
        ])

        self.output_projection = nn.Linear(d_model * len(scales), d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch_size, seq_len, _ = x.size()
        scale_outputs = []

        for scale_proj in self.scale_projections:
            # Apply temporal convolution for this scale
            x_conv = scale_proj['conv'](x.transpose(1, 2)).transpose(1, 2)

            # Standard attention computation
            Q = scale_proj['query'](x_conv).view(batch_size, seq_len, self.n_heads, self.head_dim)
            K = scale_proj['key'](x_conv).view(batch_size, seq_len, self.n_heads, self.head_dim)
            V = scale_proj['value'](x_conv).view(batch_size, seq_len, self.n_heads, self.head_dim)

            # Transpose for attention computation
            Q = Q.transpose(1, 2)  # (batch, n_heads, seq_len, head_dim)
            K = K.transpose(1, 2)
            V = V.transpose(1, 2)

            # Scaled dot-product attention
            scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

            if mask is not None:
                scores.masked_fill_(mask == 0, -1e9)

            attention = F.softmax(scores, dim=-1)
            attention = self.dropout(attention)

            output = torch.matmul(attention, V)
            output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

            scale_outputs.append(output)

        # Combine multi-scale outputs
        combined = torch.cat(scale_outputs, dim=-1)
        return self.output_projection(combined)

class MarketRegimeDetector(nn.Module):
    """Market regime detection module for adaptive behavior"""

    def __init__(self, d_model: int, n_regimes: int = 4):
        super().__init__()
        self.d_model = d_model
        self.n_regimes = n_regimes

        # Regime classification layers
        self.regime_classifier = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(d_model // 2, n_regimes)
        )

        # Regime-specific transformations
        self.regime_transforms = nn.ModuleList([
            nn.Linear(d_model, d_model) for _ in range(n_regimes)
        ])

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Global pooling for regime detection
        pooled = torch.mean(x, dim=1)  # (batch, d_model)

        # Classify market regime
        regime_logits = self.regime_classifier(pooled)
        regime_probs = F.softmax(regime_logits, dim=-1)  # (batch, n_regimes)

        # Apply regime-specific transformations
        regime_outputs = []
        for transform in self.regime_transforms:
            regime_outputs.append(transform(x))  # (batch, seq_len, d_model)

        # Weighted combination based on regime probabilities
        regime_stack = torch.stack(regime_outputs, dim=0)  # (n_regimes, batch, seq_len, d_model)
        regime_weights = regime_probs.transpose(0, 1).unsqueeze(-1).unsqueeze(-1)  # (n_regimes, batch, 1, 1)

        # Weighted sum across regimes, broadcasting the weights over seq_len and d_model
        adapted_output = torch.sum(regime_stack * regime_weights, dim=0)  # (batch, seq_len, d_model)

        return adapted_output, regime_probs
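
# Shape note for the regime weighting above (a worked example, assuming batch=2, seq_len=4,
# d_model=8, n_regimes=4): regime_probs is (2, 4), the stacked regime outputs are
# (4, 2, 4, 8), and reshaping the probabilities to (4, 2, 1, 1) lets them broadcast over
# the sequence and feature dimensions before the sum over the regime axis returns the
# adapted (2, 4, 8) tensor.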

class UncertaintyEstimation(nn.Module):
    """Uncertainty estimation using Monte Carlo Dropout"""

    def __init__(self, d_model: int, n_samples: int = 10):
        super().__init__()
        self.d_model = d_model
        self.n_samples = n_samples

        self.uncertainty_head = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Dropout(0.5),  # Higher dropout for uncertainty estimation
            nn.Linear(d_model // 2, 1),
            nn.Sigmoid()
        )

    def forward(self, x: torch.Tensor, training: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        if training or not self.training:
            # Single forward pass during training or when not in MC mode
            uncertainty = self.uncertainty_head(x)
            return uncertainty, uncertainty

        # Monte Carlo sampling during inference
        uncertainties = []
        for _ in range(self.n_samples):
            uncertainty = self.uncertainty_head(x)
            uncertainties.append(uncertainty)

        uncertainties = torch.stack(uncertainties, dim=0)
        mean_uncertainty = torch.mean(uncertainties, dim=0)
        std_uncertainty = torch.std(uncertainties, dim=0)

        return mean_uncertainty, std_uncertainty

class TradingTransformerLayer(nn.Module):
    """Enhanced transformer layer for trading applications"""

    def __init__(self, config: TradingTransformerConfig):
        super().__init__()
        self.config = config

        # Multi-scale attention or standard attention
        if config.use_multi_scale_attention:
            self.attention = MultiScaleAttention(config.d_model, config.n_heads)
        else:
            self.attention = nn.MultiheadAttention(
                config.d_model, config.n_heads, dropout=config.dropout, batch_first=True
            )

        # Feed-forward network
        self.feed_forward = nn.Sequential(
            nn.Linear(config.d_model, config.d_ff),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_ff, config.d_model)
        )

        # Layer normalization
        self.norm1 = nn.LayerNorm(config.d_model)
        self.norm2 = nn.LayerNorm(config.d_model)

        # Dropout
        self.dropout = nn.Dropout(config.dropout)

        # Market regime detection
        if config.use_market_regime_detection:
            self.regime_detector = MarketRegimeDetector(config.d_model)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        # Self-attention with residual connection
        if isinstance(self.attention, MultiScaleAttention):
            attn_output = self.attention(x, mask)
        else:
            attn_output, _ = self.attention(x, x, x, attn_mask=mask)

        x = self.norm1(x + self.dropout(attn_output))

        # Market regime adaptation
        regime_probs = None
        if hasattr(self, 'regime_detector'):
            x, regime_probs = self.regime_detector(x)

        # Feed-forward with residual connection
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))

        return {
            'output': x,
            'regime_probs': regime_probs
        }

class AdvancedTradingTransformer(nn.Module):
    """Advanced transformer model for high-frequency trading"""

    def __init__(self, config: TradingTransformerConfig):
        super().__init__()
        self.config = config

        # Input projections
        self.price_projection = nn.Linear(5, config.d_model)  # OHLCV
        self.cob_projection = nn.Linear(config.cob_features, config.d_model)
        self.tech_projection = nn.Linear(config.tech_features, config.d_model)
        self.market_projection = nn.Linear(config.market_features, config.d_model)

        # Positional encoding
        if config.use_relative_position:
            self.pos_encoding = RelativePositionalEncoding(config.d_model)
        else:
            self.pos_encoding = PositionalEncoding(config.d_model, config.seq_len)

        # Transformer layers
        self.layers = nn.ModuleList([
            TradingTransformerLayer(config) for _ in range(config.n_layers)
        ])

        # Output heads
        self.action_head = nn.Sequential(
            nn.Linear(config.d_model, config.d_model // 2),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_model // 2, config.n_actions)
        )

        if config.confidence_output:
            self.confidence_head = nn.Sequential(
                nn.Linear(config.d_model, config.d_model // 4),
                nn.GELU(),
                nn.Dropout(config.dropout),
                nn.Linear(config.d_model // 4, 1),
                nn.Sigmoid()
            )

        # Uncertainty estimation
        if config.use_uncertainty_estimation:
            self.uncertainty_estimator = UncertaintyEstimation(config.d_model)

        # Price prediction head (auxiliary task)
        self.price_head = nn.Sequential(
            nn.Linear(config.d_model, config.d_model // 4),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_model // 4, 1)
        )

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialize model weights"""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, price_data: torch.Tensor, cob_data: torch.Tensor,
                tech_data: torch.Tensor, market_data: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """
        Forward pass of the trading transformer

        Args:
            price_data: (batch, seq_len, 5) - OHLCV data
            cob_data: (batch, seq_len, cob_features) - COB features
            tech_data: (batch, seq_len, tech_features) - Technical indicators
            market_data: (batch, seq_len, market_features) - Market microstructure
            mask: Optional attention mask

        Returns:
            Dictionary containing model outputs
        """
        batch_size, seq_len = price_data.shape[:2]

        # Project inputs to model dimension
        price_emb = self.price_projection(price_data)
        cob_emb = self.cob_projection(cob_data)
        tech_emb = self.tech_projection(tech_data)
        market_emb = self.market_projection(market_data)

        # Combine embeddings (could also use cross-attention)
        x = price_emb + cob_emb + tech_emb + market_emb

        # Add positional encoding
        if isinstance(self.pos_encoding, RelativePositionalEncoding):
            # Relative position embeddings are meant to be consumed inside attention;
            # the current attention modules do not use them, so nothing is added here
            pass
        else:
            x = self.pos_encoding(x.transpose(0, 1)).transpose(0, 1)

        # Apply transformer layers
        regime_probs_history = []
        for layer in self.layers:
            layer_output = layer(x, mask)
            x = layer_output['output']
            if layer_output['regime_probs'] is not None:
                regime_probs_history.append(layer_output['regime_probs'])

        # Global pooling for final prediction:
        # softmax over per-timestep feature sums gives the pooling weights
        pooling_weights = F.softmax(
            torch.sum(x, dim=-1, keepdim=True), dim=1
        )
        pooled = torch.sum(x * pooling_weights, dim=1)

        # Generate outputs
        outputs = {}

        # Action prediction
        action_logits = self.action_head(pooled)
        outputs['action_logits'] = action_logits
        outputs['action_probs'] = F.softmax(action_logits, dim=-1)

        # Confidence prediction
        if self.config.confidence_output:
            confidence = self.confidence_head(pooled)
            outputs['confidence'] = confidence

        # Uncertainty estimation
        if self.config.use_uncertainty_estimation:
            uncertainty_mean, uncertainty_std = self.uncertainty_estimator(pooled)
            outputs['uncertainty_mean'] = uncertainty_mean
            outputs['uncertainty_std'] = uncertainty_std

        # Price prediction (auxiliary task)
        price_pred = self.price_head(pooled)
        outputs['price_prediction'] = price_pred

        # Market regime information
        if regime_probs_history:
            outputs['regime_probs'] = torch.stack(regime_probs_history, dim=1)

        return outputs

class TradingTransformerTrainer:
    """Trainer for the advanced trading transformer"""

    def __init__(self, model: AdvancedTradingTransformer, config: TradingTransformerConfig):
        self.model = model
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Move model to device
        self.model.to(self.device)

        # Optimizer with warmup
        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )

        # Learning rate scheduler
        self.scheduler = optim.lr_scheduler.OneCycleLR(
            self.optimizer,
            max_lr=config.learning_rate,
            total_steps=10000,  # Will be updated based on training data
            pct_start=0.1
        )

        # Loss functions
        self.action_criterion = nn.CrossEntropyLoss()
        self.price_criterion = nn.MSELoss()
        self.confidence_criterion = nn.BCELoss()

        # Training history
        self.training_history = {
            'train_loss': [],
            'val_loss': [],
            'train_accuracy': [],
            'val_accuracy': [],
            'learning_rates': []
        }

    def train_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """Single training step"""
        self.model.train()
        self.optimizer.zero_grad()

        # Move batch to device
        batch = {k: v.to(self.device) for k, v in batch.items()}

        # Forward pass
        outputs = self.model(
            batch['price_data'],
            batch['cob_data'],
            batch['tech_data'],
            batch['market_data']
        )

        # Calculate losses
        action_loss = self.action_criterion(outputs['action_logits'], batch['actions'])
        price_loss = self.price_criterion(outputs['price_prediction'], batch['future_prices'])

        total_loss = action_loss + 0.1 * price_loss  # Weight auxiliary task

        # Add confidence loss if available
        if 'confidence' in outputs and 'trade_success' in batch:
            confidence_loss = self.confidence_criterion(
                outputs['confidence'].squeeze(),
                batch['trade_success'].float()
            )
            total_loss += 0.1 * confidence_loss

        # Backward pass
        total_loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)

        # Optimizer step
        self.optimizer.step()
        self.scheduler.step()

        # Calculate accuracy
        predictions = torch.argmax(outputs['action_logits'], dim=-1)
        accuracy = (predictions == batch['actions']).float().mean()

        return {
            'total_loss': total_loss.item(),
            'action_loss': action_loss.item(),
            'price_loss': price_loss.item(),
            'accuracy': accuracy.item(),
            'learning_rate': self.scheduler.get_last_lr()[0]
        }

    def validate(self, val_loader: DataLoader) -> Dict[str, float]:
        """Validation step"""
        self.model.eval()
        total_loss = 0
        total_accuracy = 0
        num_batches = 0

        with torch.no_grad():
            for batch in val_loader:
                batch = {k: v.to(self.device) for k, v in batch.items()}

                outputs = self.model(
                    batch['price_data'],
                    batch['cob_data'],
                    batch['tech_data'],
                    batch['market_data']
                )

                # Calculate losses
                action_loss = self.action_criterion(outputs['action_logits'], batch['actions'])
                price_loss = self.price_criterion(outputs['price_prediction'], batch['future_prices'])
                total_loss += action_loss.item() + 0.1 * price_loss.item()

                # Calculate accuracy
                predictions = torch.argmax(outputs['action_logits'], dim=-1)
                accuracy = (predictions == batch['actions']).float().mean()
                total_accuracy += accuracy.item()

                num_batches += 1

        return {
            'val_loss': total_loss / num_batches,
            'val_accuracy': total_accuracy / num_batches
        }

    def train(self, train_loader: DataLoader, val_loader: DataLoader,
              epochs: int, save_path: str = "NN/models/saved/"):
        """Full training loop"""
        best_val_loss = float('inf')

        for epoch in range(epochs):
            # Training
            epoch_losses = []
            epoch_accuracies = []

            for batch in train_loader:
                metrics = self.train_step(batch)
                epoch_losses.append(metrics['total_loss'])
                epoch_accuracies.append(metrics['accuracy'])

            # Validation
            val_metrics = self.validate(val_loader)

            # Update history
            avg_train_loss = np.mean(epoch_losses)
            avg_train_accuracy = np.mean(epoch_accuracies)

            self.training_history['train_loss'].append(avg_train_loss)
            self.training_history['val_loss'].append(val_metrics['val_loss'])
            self.training_history['train_accuracy'].append(avg_train_accuracy)
            self.training_history['val_accuracy'].append(val_metrics['val_accuracy'])
            self.training_history['learning_rates'].append(self.scheduler.get_last_lr()[0])

            # Logging
            logger.info(f"Epoch {epoch+1}/{epochs}")
            logger.info(f"  Train Loss: {avg_train_loss:.4f}, Train Acc: {avg_train_accuracy:.4f}")
            logger.info(f"  Val Loss: {val_metrics['val_loss']:.4f}, Val Acc: {val_metrics['val_accuracy']:.4f}")
            logger.info(f"  LR: {self.scheduler.get_last_lr()[0]:.6f}")

            # Save best model
            if val_metrics['val_loss'] < best_val_loss:
                best_val_loss = val_metrics['val_loss']
                self.save_model(os.path.join(save_path, 'best_transformer_model.pt'))
                logger.info(f"  New best model saved (val_loss: {best_val_loss:.4f})")

    def save_model(self, path: str):
        """Save model and training state"""
        os.makedirs(os.path.dirname(path), exist_ok=True)

        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'config': self.config,
            'training_history': self.training_history
        }, path)

        logger.info(f"Model saved to {path}")

    def load_model(self, path: str):
        """Load model and training state"""
        checkpoint = torch.load(path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.training_history = checkpoint.get('training_history', self.training_history)

        logger.info(f"Model loaded from {path}")

def create_trading_transformer(config: Optional[TradingTransformerConfig] = None) -> Tuple[AdvancedTradingTransformer, TradingTransformerTrainer]:
    """Factory function to create trading transformer and trainer"""
    if config is None:
        config = TradingTransformerConfig()

    model = AdvancedTradingTransformer(config)
    trainer = TradingTransformerTrainer(model, config)

    return model, trainer

# Example usage
if __name__ == "__main__":
    # Configure logging so the example actually prints its progress
    logging.basicConfig(level=logging.INFO)

    # Create configuration
    config = TradingTransformerConfig(
        d_model=256,
        n_heads=8,
        n_layers=4,
        seq_len=50,
        n_actions=3,
        use_multi_scale_attention=True,
        use_market_regime_detection=True,
        use_uncertainty_estimation=True
    )

    # Create model and trainer
    model, trainer = create_trading_transformer(config)

    logger.info(f"Created Advanced Trading Transformer with {sum(p.numel() for p in model.parameters())} parameters")
    logger.info("Model is ready for training on real market data!")
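
    # --- Illustrative smoke test: a minimal sketch on synthetic data ---
    # The tensor shapes below follow the config above (seq_len=50 plus the default
    # cob/tech/market feature sizes); batch_size=4 and all values are random, so the
    # outputs carry no trading meaning. This only demonstrates the expected input
    # layout, the keys returned by forward(), and a single train_step() call.
    device = trainer.device
    batch_size = 4
    price = torch.randn(batch_size, config.seq_len, 5, device=device)                        # OHLCV
    cob = torch.randn(batch_size, config.seq_len, config.cob_features, device=device)        # COB features
    tech = torch.randn(batch_size, config.seq_len, config.tech_features, device=device)      # technical indicators
    market = torch.randn(batch_size, config.seq_len, config.market_features, device=device)  # microstructure

    model.eval()
    with torch.no_grad():
        outputs = model(price, cob, tech, market)
    logger.info(f"Forward pass output keys: {sorted(outputs.keys())}")
    logger.info(f"Action probabilities shape: {tuple(outputs['action_probs'].shape)}")

    # One training step on a synthetic batch; 'actions', 'future_prices' and
    # 'trade_success' are dummy labels used only to exercise the loss terms.
    dummy_batch = {
        'price_data': price,
        'cob_data': cob,
        'tech_data': tech,
        'market_data': market,
        'actions': torch.randint(0, config.n_actions, (batch_size,), device=device),
        'future_prices': torch.randn(batch_size, 1, device=device),
        'trade_success': torch.randint(0, 2, (batch_size,), device=device),
    }
    metrics = trainer.train_step(dummy_batch)
    logger.info(f"Synthetic train_step metrics: {metrics}")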