#!/usr/bin/env python3
"""
Enhanced CNN Model for Trading - PyTorch Implementation
Much larger and more sophisticated architecture for better learning
"""

import gc
import logging
import math
import os
import traceback

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from typing import Dict, Any, Optional, Tuple

# Import checkpoint management
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
from utils.training_integration import get_training_integration

# Configure logging
logger = logging.getLogger(__name__)

class MultiHeadAttention(nn.Module):
    """Multi-head attention mechanism for sequence data"""

    def __init__(self, d_model: int, num_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.d_k)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, _ = x.size()

        # Compute Q, K, V
        Q = self.w_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.w_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.w_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # Attention weights
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Apply attention
        attention_output = torch.matmul(attention_weights, V)
        attention_output = attention_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.d_model
        )

        return self.w_o(attention_output)

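# The block above is standard scaled dot-product attention:
#     Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
# A minimal shape sketch (illustrative, not executed at import time; d_model=64
# and num_heads=8 are assumed values):
#
#   attn = MultiHeadAttention(d_model=64, num_heads=8)
#   x = torch.randn(2, 30, 64)   # [batch, seq, d_model]
#   attn(x).shape                # torch.Size([2, 30, 64]) -- shape preserved
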
class ResidualBlock(nn.Module):
    """Residual block with normalization and dropout"""

    def __init__(self, channels: int, dropout: float = 0.1):
        super().__init__()
        self.conv1 = nn.Conv1d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(channels, channels, kernel_size=3, padding=1)
        self.norm1 = nn.GroupNorm(1, channels)  # GroupNorm instead of BatchNorm1d (batch_size=1 safe)
        self.norm2 = nn.GroupNorm(1, channels)  # GroupNorm instead of BatchNorm1d (batch_size=1 safe)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Keep the input on the autograd graph for the skip connection.
        # Detaching here would stop gradients from flowing through the
        # identity path; the addition below is out-of-place, so there is
        # no in-place hazard to guard against.
        residual = x

        # First convolution branch
        out = self.conv1(x)
        out = self.norm1(out)
        out = F.relu(out)
        out = self.dropout(out)

        # Second convolution branch
        out = self.conv2(out)
        out = self.norm2(out)

        # Residual connection (out-of-place addition creates a new tensor)
        return F.relu(residual + out)

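# Gradient-flow sketch for the block above (illustrative, not executed):
# with `residual = x` the identity path stays on the autograd graph, so
# gradients reach the input through both the skip and the conv branch.
#
#   block = ResidualBlock(channels=8)
#   x = torch.randn(2, 8, 16, requires_grad=True)   # [batch, channels, seq]
#   block(x).sum().backward()
#   assert x.grad is not None
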
class SpatialAttentionBlock(nn.Module):
    """Spatial attention for feature maps"""

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv1d(channels, 1, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute attention weights
        attention = torch.sigmoid(self.conv(x))
        # torch.mul creates a new tensor, avoiding in-place operations
        return torch.mul(x, attention)

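# The spatial gate above squeezes channels down to a single [batch, 1, seq]
# map and broadcasts it back across channels.  A shape sketch (illustrative):
#
#   sab = SpatialAttentionBlock(channels=8)
#   sab(torch.randn(2, 8, 16)).shape   # torch.Size([2, 8, 16])
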
class EnhancedCNNModel(nn.Module):
    """
    Much larger and more sophisticated CNN architecture for trading

    Features:
    - Deep convolutional layers with residual connections
    - Multi-head attention mechanisms
    - Spatial attention blocks
    - Multiple feature extraction paths
    - Large capacity for complex pattern learning
    """

    def __init__(self,
                 input_size: int = 60,
                 feature_dim: int = 50,
                 output_size: int = 2,           # BUY/SELL for 2-action system
                 base_channels: int = 256,       # Increased from 128 to 256
                 num_blocks: int = 12,           # Increased from 6 to 12
                 num_attention_heads: int = 16,  # Increased from 8 to 16
                 dropout_rate: float = 0.2):
        super().__init__()

        self.input_size = input_size
        self.feature_dim = feature_dim
        self.output_size = output_size
        self.base_channels = base_channels

        # Large input embedding - project features to a higher dimension
        # (LayerNorm instead of BatchNorm1d for batch_size=1 compatibility)
        self.input_embedding = nn.Sequential(
            nn.Linear(feature_dim, base_channels // 2),
            nn.LayerNorm(base_channels // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(base_channels // 2, base_channels),
            nn.LayerNorm(base_channels),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )

        # Multi-scale convolutional feature extraction
        self.conv_path1 = self._build_conv_path(base_channels, base_channels, 3)
        self.conv_path2 = self._build_conv_path(base_channels, base_channels, 5)
        self.conv_path3 = self._build_conv_path(base_channels, base_channels, 7)
        self.conv_path4 = self._build_conv_path(base_channels, base_channels, 9)  # Additional path

        # Feature fusion with more capacity
        # (GroupNorm instead of BatchNorm1d for batch_size=1 compatibility)
        self.feature_fusion = nn.Sequential(
            nn.Conv1d(base_channels * 4, base_channels * 3, kernel_size=1),  # 4 paths now
            nn.GroupNorm(1, base_channels * 3),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Conv1d(base_channels * 3, base_channels * 2, kernel_size=1),
            nn.GroupNorm(1, base_channels * 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )

        # Deep residual blocks for complex pattern learning
        self.residual_blocks = nn.ModuleList([
            ResidualBlock(base_channels * 2, dropout_rate) for _ in range(num_blocks)
        ])

        # Spatial attention blocks
        self.spatial_attention = nn.ModuleList([
            SpatialAttentionBlock(base_channels * 2) for _ in range(6)  # Increased from 3 to 6
        ])

        # Multiple temporal attention layers
        self.temporal_attention1 = MultiHeadAttention(
            d_model=base_channels * 2,
            num_heads=num_attention_heads,
            dropout=dropout_rate
        )
        self.temporal_attention2 = MultiHeadAttention(
            d_model=base_channels * 2,
            num_heads=num_attention_heads // 2,
            dropout=dropout_rate
        )

        # Global feature aggregation
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.global_max_pool = nn.AdaptiveMaxPool1d(1)

        # Advanced feature processing (LayerNorm for batch_size=1 compatibility)
        self.advanced_features = nn.Sequential(
            nn.Linear(base_channels * 4, base_channels * 6),  # Increased capacity
            nn.LayerNorm(base_channels * 6),
            nn.ReLU(),
            nn.Dropout(dropout_rate),

            nn.Linear(base_channels * 6, base_channels * 4),
            nn.LayerNorm(base_channels * 4),
            nn.ReLU(),
            nn.Dropout(dropout_rate),

            nn.Linear(base_channels * 4, base_channels * 3),
            nn.LayerNorm(base_channels * 3),
            nn.ReLU(),
            nn.Dropout(dropout_rate),

            nn.Linear(base_channels * 3, base_channels * 2),
            nn.LayerNorm(base_channels * 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),

            nn.Linear(base_channels * 2, base_channels),
            nn.LayerNorm(base_channels),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )

        # Market regime detection branch
        self.regime_detector = nn.Sequential(
            nn.Linear(base_channels, base_channels // 2),
            nn.LayerNorm(base_channels // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(base_channels // 2, base_channels // 4),
            nn.LayerNorm(base_channels // 4),
            nn.ReLU(),
            nn.Linear(base_channels // 4, 8),  # 8 market regimes instead of 4
            nn.Softmax(dim=1)
        )

        # Volatility prediction branch
        self.volatility_predictor = nn.Sequential(
            nn.Linear(base_channels, base_channels // 2),
            nn.LayerNorm(base_channels // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(base_channels // 2, base_channels // 4),
            nn.LayerNorm(base_channels // 4),
            nn.ReLU(),
            nn.Linear(base_channels // 4, 1),
            nn.Sigmoid()
        )

        # Main trading decision head
        self.decision_head = nn.Sequential(
            nn.Linear(base_channels + 8 + 1, base_channels),  # 8 regime classes + 1 volatility
            nn.LayerNorm(base_channels),
            nn.ReLU(),
            nn.Dropout(dropout_rate),

            nn.Linear(base_channels, base_channels // 2),
            nn.LayerNorm(base_channels // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),

            nn.Linear(base_channels // 2, output_size)
        )

        # Confidence estimation head
        self.confidence_head = nn.Sequential(
            nn.Linear(base_channels, base_channels // 2),
            nn.ReLU(),
            nn.Linear(base_channels // 2, 1),
            nn.Sigmoid()
        )

        # Initialize weights
        self._initialize_weights()

    def _build_conv_path(self, in_channels: int, out_channels: int, kernel_size: int) -> nn.Module:
        """Build a convolutional path with multiple layers"""
        return nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size, padding=kernel_size // 2),
            nn.GroupNorm(1, out_channels),  # GroupNorm instead of BatchNorm1d
            nn.ReLU(),
            nn.Dropout(0.1),

            nn.Conv1d(out_channels, out_channels, kernel_size, padding=kernel_size // 2),
            nn.GroupNorm(1, out_channels),
            nn.ReLU(),
            nn.Dropout(0.1),

            nn.Conv1d(out_channels, out_channels, kernel_size, padding=kernel_size // 2),
            nn.GroupNorm(1, out_channels),
            nn.ReLU()
        )

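    # Receptive field of one path: three stacked stride-1 convolutions with
    # kernel k cover 3(k - 1) + 1 time steps, so the four paths above
    # (k = 3, 5, 7, 9) see 7, 13, 19, and 25 steps respectively before fusion.
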
    def _initialize_weights(self):
        """Initialize model weights"""
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm1d, nn.GroupNorm, nn.LayerNorm)):
                if hasattr(m, 'weight') and m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _memory_barrier(self, tensor: torch.Tensor) -> torch.Tensor:
        """Copy a tensor into fresh memory to prevent in-place operation issues.

        clone() is differentiable, so gradients still flow through the barrier.
        A detach() here would cut the autograd graph and leave everything
        upstream of the barrier untrained.
        """
        return tensor.clone()

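    # Sanity sketch for the barrier (illustrative, not executed; `model` is
    # any EnhancedCNNModel instance): gradients must survive it.
    #
    #   t = torch.randn(3, requires_grad=True)
    #   model._memory_barrier(t).sum().backward()
    #   assert t.grad is not None
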
    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward pass with multiple outputs, avoiding in-place operations

        Args:
            x: Input tensor of shape [batch_size, sequence_length, features]

        Returns:
            Dictionary with predictions, confidence, regime, and volatility
        """
        # Apply memory barrier to input
        x = self._memory_barrier(x)

        # Handle input shapes flexibly
        if len(x.shape) == 2:
            # Input is [seq_len, features] - add batch dimension
            x = x.unsqueeze(0)
        elif len(x.shape) > 3:
            # Input has extra dimensions - flatten to [batch, seq, features]
            x = x.view(x.shape[0], -1, x.shape[-1])

        x = self._memory_barrier(x)  # Apply barrier after shape changes
        batch_size, seq_len, features = x.shape

        # Reshape for processing: [batch, seq, features] -> [batch*seq, features]
        x_reshaped = x.view(-1, features)
        x_reshaped = self._memory_barrier(x_reshaped)

        # Input embedding
        embedded = self.input_embedding(x_reshaped)  # [batch*seq, base_channels]
        embedded = self._memory_barrier(embedded)

        # Reshape back for conv1d: [batch*seq, channels] -> [batch, channels, seq]
        embedded = embedded.view(batch_size, seq_len, -1).transpose(1, 2).contiguous()
        embedded = self._memory_barrier(embedded)

        # Multi-scale feature extraction - each path produces an independent tensor
        path1 = self._memory_barrier(self.conv_path1(embedded))
        path2 = self._memory_barrier(self.conv_path2(embedded))
        path3 = self._memory_barrier(self.conv_path3(embedded))
        path4 = self._memory_barrier(self.conv_path4(embedded))

        # Feature fusion
        fused_features = torch.cat([path1, path2, path3, path4], dim=1)
        fused_features = self._memory_barrier(self.feature_fusion(fused_features))

        # Apply residual blocks with spatial attention
        current_features = self._memory_barrier(fused_features)
        for i, (res_block, attention) in enumerate(zip(self.residual_blocks, self.spatial_attention)):
            current_features = self._memory_barrier(res_block(current_features))
            if i % 2 == 0:  # Apply attention every other block
                current_features = self._memory_barrier(attention(current_features))

        # Apply remaining residual blocks
        for res_block in self.residual_blocks[len(self.spatial_attention):]:
            current_features = self._memory_barrier(res_block(current_features))

        # Temporal attention - apply both attention layers
        # Reshape for attention: [batch, channels, seq] -> [batch, seq, channels]
        attention_input = current_features.transpose(1, 2).contiguous()
        attention_input = self._memory_barrier(attention_input)

        attended_features = self._memory_barrier(self.temporal_attention1(attention_input))
        attended_features = self._memory_barrier(self.temporal_attention2(attended_features))
        # Back to conv format: [batch, seq, channels] -> [batch, channels, seq]
        attended_features = attended_features.transpose(1, 2).contiguous()
        attended_features = self._memory_barrier(attended_features)

        # Global aggregation
        avg_pooled = self.global_pool(attended_features)
        avg_pooled = self._memory_barrier(avg_pooled.view(avg_pooled.shape[0], -1))  # Flatten instead of squeeze

        max_pooled = self.global_max_pool(attended_features)
        max_pooled = self._memory_barrier(max_pooled.view(max_pooled.shape[0], -1))  # Flatten instead of squeeze

        # Combine global features
        global_features = torch.cat([avg_pooled, max_pooled], dim=1)
        global_features = self._memory_barrier(global_features)

        # Advanced feature processing
        processed_features = self._memory_barrier(self.advanced_features(global_features))

        # Multi-task predictions
        regime_probs = self._memory_barrier(self.regime_detector(processed_features))
        volatility_pred = self._memory_barrier(self.volatility_predictor(processed_features))
        confidence = self._memory_barrier(self.confidence_head(processed_features))

        # Combine all features for the final decision (8 regime classes + 1 volatility)
        vol_pred_flat = self._memory_barrier(volatility_pred.view(volatility_pred.shape[0], -1))
        combined_features = torch.cat([processed_features, regime_probs, vol_pred_flat], dim=1)
        combined_features = self._memory_barrier(combined_features)

        trading_logits = self._memory_barrier(self.decision_head(combined_features))

        # Apply temperature scaling for better calibration
        temperature = 1.5
        scaled_logits = trading_logits / temperature
        trading_probs = self._memory_barrier(F.softmax(scaled_logits, dim=1))

        # Flatten confidence and volatility to ensure consistent shapes
        confidence_flat = self._memory_barrier(confidence.view(confidence.shape[0], -1))
        volatility_flat = self._memory_barrier(volatility_pred.view(volatility_pred.shape[0], -1))

        return {
            'logits': self._memory_barrier(trading_logits),
            'probabilities': self._memory_barrier(trading_probs),
            'confidence': confidence_flat[:, 0] if confidence_flat.shape[1] > 0 else confidence_flat.view(-1)[0],
            'regime': self._memory_barrier(regime_probs),
            'volatility': volatility_flat[:, 0] if volatility_flat.shape[1] > 0 else volatility_flat.view(-1)[0],
            'features': self._memory_barrier(processed_features)
        }

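    # Output contract of forward() for a [B, 60, 50] input (illustrative):
    #
    #   out = model(torch.randn(B, 60, 50))
    #   out['logits']         # [B, 2]   raw action scores
    #   out['probabilities']  # [B, 2]   temperature-scaled softmax
    #   out['confidence']     # [B]      sigmoid output in [0, 1]
    #   out['regime']         # [B, 8]   softmax over market regimes
    #   out['volatility']     # [B]      sigmoid output in [0, 1]
    #   out['features']       # [B, base_channels]
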
    def predict(self, feature_matrix: np.ndarray) -> Dict[str, Any]:
        """
        Make predictions on a feature matrix

        Args:
            feature_matrix: numpy array of shape [sequence_length, features]

        Returns:
            Dictionary with prediction results
        """
        self.eval()

        with torch.no_grad():
            # Convert to tensor and add batch dimension
            if isinstance(feature_matrix, np.ndarray):
                x = torch.FloatTensor(feature_matrix).unsqueeze(0)
            else:
                x = feature_matrix.unsqueeze(0)

            # Move to device
            device = next(self.parameters()).device
            x = x.to(device)

            # Forward pass
            outputs = self.forward(x)

            # Extract results with proper shape handling
            probs = outputs['probabilities'].cpu().numpy()[0]
            confidence_tensor = outputs['confidence'].cpu().numpy()
            regime = outputs['regime'].cpu().numpy()[0]
            volatility = outputs['volatility'].cpu().numpy()

            # Handle confidence shape properly
            if isinstance(confidence_tensor, np.ndarray):
                if confidence_tensor.ndim == 0:
                    confidence = float(confidence_tensor.item())
                elif confidence_tensor.size == 1:
                    confidence = float(confidence_tensor.flatten()[0])
                else:
                    confidence = float(confidence_tensor[0] if len(confidence_tensor) > 0 else 0.7)
            else:
                confidence = float(confidence_tensor)

            # Handle volatility shape properly
            if isinstance(volatility, np.ndarray):
                if volatility.ndim == 0:
                    volatility = float(volatility.item())
                elif volatility.size == 1:
                    volatility = float(volatility.flatten()[0])
                else:
                    volatility = float(volatility[0] if len(volatility) > 0 else 0.0)
            else:
                volatility = float(volatility)

            # Determine action (0=BUY, 1=SELL for the 2-action system)
            action = int(np.argmax(probs))
            action_confidence = float(probs[action])

            return {
                'action': action,
                'action_name': 'BUY' if action == 0 else 'SELL',
                'confidence': confidence,
                'action_confidence': action_confidence,
                'probabilities': probs.tolist(),
                'regime_probabilities': regime.tolist(),
                'volatility_prediction': float(volatility),
                'raw_logits': outputs['logits'].cpu().numpy()[0].tolist()
            }

    def get_memory_usage(self) -> Dict[str, Any]:
        """Get model memory usage statistics"""
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)

        param_size = sum(p.numel() * p.element_size() for p in self.parameters())
        buffer_size = sum(b.numel() * b.element_size() for b in self.buffers())

        return {
            'total_parameters': total_params,
            'trainable_parameters': trainable_params,
            'parameter_size_mb': param_size / (1024 * 1024),
            'buffer_size_mb': buffer_size / (1024 * 1024),
            'total_size_mb': (param_size + buffer_size) / (1024 * 1024)
        }

    def to_device(self, device: str):
        """Move model to the specified device"""
        return self.to(torch.device(device))

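# Standalone usage of the model (a sketch; shapes follow the defaults above):
#
#   model = EnhancedCNNModel()                  # 60 steps x 50 features -> 2 actions
#   result = model.predict(np.random.randn(60, 50).astype(np.float32))
#   result['action_name']                       # 'BUY' or 'SELL'
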
class CNNModelTrainer:
    """Enhanced CNN trainer with checkpoint management integration"""

    def __init__(self, model: EnhancedCNNModel, learning_rate: float = 0.0001, device: str = 'cuda',
                 model_name: str = "enhanced_cnn", enable_checkpoints: bool = True):
        self.model = model
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

        # Checkpoint management
        self.model_name = model_name
        self.enable_checkpoints = enable_checkpoints
        self.training_integration = get_training_integration() if enable_checkpoints else None
        self.epoch_count = 0
        self.best_val_accuracy = 0.0
        self.best_val_loss = float('inf')
        self.checkpoint_frequency = 10  # Save checkpoint every 10 epochs

        # Optimizer
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=0.01,
            betas=(0.9, 0.999)
        )

        # Note: OneCycleLR raises an error once stepped past total_steps, so
        # total_steps must cover the whole planned training run.
        self.scheduler = optim.lr_scheduler.OneCycleLR(
            self.optimizer,
            max_lr=learning_rate * 10,
            total_steps=1000,
            pct_start=0.1,
            anneal_strategy='cos'
        )

        # Loss functions.  The regime head already applies Softmax, so its
        # criterion is NLLLoss over log-probabilities (CrossEntropyLoss here
        # would apply a second softmax).
        self.main_criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
        self.confidence_criterion = nn.MSELoss()
        self.regime_criterion = nn.NLLLoss()
        self.volatility_criterion = nn.MSELoss()

        # Training history
        self.training_history = {
            'train_loss': [],
            'val_loss': [],
            'train_accuracy': [],
            'val_accuracy': [],
            'learning_rates': []
        }

        # Load best checkpoint if available
        if self.enable_checkpoints:
            self.load_best_checkpoint()

        logger.info(f"CNN Trainer initialized with checkpoint management: {enable_checkpoints}")
        if enable_checkpoints:
            logger.info(f"Model name: {model_name}, Checkpoint frequency: {self.checkpoint_frequency}")

    def load_best_checkpoint(self):
        """Load the best checkpoint for this CNN model"""
        try:
            if not self.enable_checkpoints:
                return

            result = load_best_checkpoint(self.model_name)
            if result:
                file_path, metadata = result
                checkpoint = torch.load(file_path, map_location=self.device)

                # Load model state
                if 'model_state_dict' in checkpoint:
                    self.model.load_state_dict(checkpoint['model_state_dict'])
                if 'optimizer_state_dict' in checkpoint:
                    self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                if 'scheduler_state_dict' in checkpoint:
                    self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

                # Load training state
                if 'epoch_count' in checkpoint:
                    self.epoch_count = checkpoint['epoch_count']
                if 'best_val_accuracy' in checkpoint:
                    self.best_val_accuracy = checkpoint['best_val_accuracy']
                if 'best_val_loss' in checkpoint:
                    self.best_val_loss = checkpoint['best_val_loss']
                if 'training_history' in checkpoint:
                    self.training_history = checkpoint['training_history']

                logger.info(f"Loaded CNN checkpoint: {metadata.checkpoint_id}")
                logger.info(f"Epoch: {self.epoch_count}, Best val accuracy: {self.best_val_accuracy:.4f}")

        except Exception as e:
            logger.warning(f"Failed to load checkpoint for {self.model_name}: {e}")

    def save_checkpoint(self, train_accuracy: float, val_accuracy: float,
                        train_loss: float, val_loss: float, force_save: bool = False):
        """Save checkpoint if performance improved or forced"""
        try:
            if not self.enable_checkpoints:
                return False

            self.epoch_count += 1

            # Update best metrics
            improved = False
            if val_accuracy > self.best_val_accuracy:
                self.best_val_accuracy = val_accuracy
                improved = True
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                improved = True

            # Save checkpoint if improved, forced, or at regular intervals
            should_save = (
                force_save or
                improved or
                self.epoch_count % self.checkpoint_frequency == 0
            )

            if should_save and self.training_integration:
                return self.training_integration.save_cnn_checkpoint(
                    cnn_model=self.model,
                    model_name=self.model_name,
                    epoch=self.epoch_count,
                    train_accuracy=train_accuracy,
                    val_accuracy=val_accuracy,
                    train_loss=train_loss,
                    val_loss=val_loss,
                    training_time_hours=0.0  # Can be calculated by the calling code
                )

            return False

        except Exception as e:
            logger.error(f"Error saving CNN checkpoint: {e}")
            return False

    def reset_computational_graph(self):
        """Reset the computational graph to prevent in-place operation issues"""
        try:
            # Clear all gradients
            for param in self.model.parameters():
                param.grad = None

            # Force garbage collection
            gc.collect()

            # Clear CUDA cache if available
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()

            # Reset optimizer state if needed
            for group in self.optimizer.param_groups:
                for param in group['params']:
                    if param in self.optimizer.state:
                        # Clear momentum buffers that might hold stale references
                        self.optimizer.state[param] = {}

        except Exception as e:
            logger.warning(f"Error during computational graph reset: {e}")

    def train_step(self, x: torch.Tensor, y: torch.Tensor,
                   confidence_targets: Optional[torch.Tensor] = None,
                   regime_targets: Optional[torch.Tensor] = None,
                   volatility_targets: Optional[torch.Tensor] = None) -> Dict[str, float]:
        """Single training step with multi-task learning and robust error handling"""

        # Reset computational graph before each training step
        self.reset_computational_graph()

        try:
            self.model.train()

            # Ensure inputs are independent from the caller's tensors
            x_train = x.detach().clone().requires_grad_(False).to(self.device)
            y_train = y.detach().clone().requires_grad_(False).to(self.device)

            # Forward pass with error handling
            try:
                outputs = self.model(x_train)
            except RuntimeError as forward_error:
                if "modified by an inplace operation" in str(forward_error):
                    logger.error(f"In-place operation in forward pass: {forward_error}")
                    self.reset_computational_graph()
                    return {'main_loss': 0.0, 'total_loss': 0.0, 'accuracy': 0.5}
                else:
                    raise forward_error

            # Calculate main loss
            main_loss = self.main_criterion(outputs['logits'], y_train)
            total_loss = main_loss

            losses = {'main_loss': main_loss.item()}

            # Add auxiliary losses if targets are provided
            if confidence_targets is not None:
                conf_targets = confidence_targets.detach().clone().to(self.device)
                conf_loss = self.confidence_criterion(outputs['confidence'], conf_targets)
                total_loss = total_loss + 0.1 * conf_loss
                losses['confidence_loss'] = conf_loss.item()

            if regime_targets is not None:
                regime_targets_clean = regime_targets.detach().clone().to(self.device)
                # The regime head outputs probabilities, so feed log-probs to NLLLoss
                regime_loss = self.regime_criterion(torch.log(outputs['regime'] + 1e-8), regime_targets_clean)
                total_loss = total_loss + 0.05 * regime_loss
                losses['regime_loss'] = regime_loss.item()

            if volatility_targets is not None:
                vol_targets = volatility_targets.detach().clone().to(self.device)
                vol_loss = self.volatility_criterion(outputs['volatility'], vol_targets)
                total_loss = total_loss + 0.05 * vol_loss
                losses['volatility_loss'] = vol_loss.item()

            losses['total_loss'] = total_loss.item()

            # Backward pass with error handling
            try:
                total_loss.backward()
            except RuntimeError as backward_error:
                if "modified by an inplace operation" in str(backward_error):
                    logger.error(f"In-place operation during backward pass: {backward_error}")
                    logger.error("Attempting to continue training with gradient reset...")
                    self.reset_computational_graph()
                    return {'main_loss': losses.get('main_loss', 0.0),
                            'total_loss': losses.get('total_loss', 0.0),
                            'accuracy': 0.5}
                else:
                    raise backward_error

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            # Optimizer and scheduler steps
            self.optimizer.step()
            self.scheduler.step()

            # Calculate accuracy without tracking gradients
            with torch.no_grad():
                predictions = torch.argmax(outputs['probabilities'], dim=1)
                accuracy = (predictions == y_train).float().mean().item()
                losses['accuracy'] = accuracy

            # Update training history
            if 'train_loss' in self.training_history:
                self.training_history['train_loss'].append(losses['total_loss'])
                self.training_history['train_accuracy'].append(accuracy)
                current_lr = self.optimizer.param_groups[0]['lr']
                self.training_history['learning_rates'].append(current_lr)

            return losses

        except Exception as e:
            logger.error(f"Training step failed with unexpected error: {e}")
            logger.error(f"Error type: {type(e).__name__}")
            logger.error(f"Full traceback: {traceback.format_exc()}")

            # Comprehensive cleanup on any error
            self.reset_computational_graph()

            # Return safe dummy values so training can continue
            return {'main_loss': 0.0, 'total_loss': 0.0, 'accuracy': 0.5}

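    # Minimal training-loop sketch (illustrative; assumes `batches` yields
    # (x, y) tensors shaped [batch, seq, features] and [batch], and that
    # val_accuracy / val_loss come from a separate validation pass):
    #
    #   for x, y in batches:
    #       losses = trainer.train_step(x, y)
    #       trainer.save_checkpoint(losses['accuracy'], val_accuracy,
    #                               losses['total_loss'], val_loss)
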
    def save_model(self, filepath: str, metadata: Optional[Dict] = None):
        """Save model with metadata"""
        save_dict = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'training_history': self.training_history,
            'model_config': {
                'input_size': self.model.input_size,
                'feature_dim': self.model.feature_dim,
                'output_size': self.model.output_size,
                'base_channels': self.model.base_channels
            }
        }

        if metadata:
            save_dict['metadata'] = metadata

        torch.save(save_dict, filepath)
        logger.info(f"Enhanced CNN model saved to {filepath}")

    def load_model(self, filepath: str) -> Dict:
        """Load model from file"""
        checkpoint = torch.load(filepath, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        if 'scheduler_state_dict' in checkpoint:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        if 'training_history' in checkpoint:
            self.training_history = checkpoint['training_history']

        logger.info(f"Enhanced CNN model loaded from {filepath}")
        return checkpoint.get('metadata', {})

def create_enhanced_cnn_model(input_size: int = 60,
                              feature_dim: int = 50,
                              output_size: int = 2,
                              base_channels: int = 256,
                              device: str = 'cuda') -> Tuple[EnhancedCNNModel, CNNModelTrainer]:
    """Create enhanced CNN model and trainer"""
    model = EnhancedCNNModel(
        input_size=input_size,
        feature_dim=feature_dim,
        output_size=output_size,
        base_channels=base_channels,
        num_blocks=12,
        num_attention_heads=16,
        dropout_rate=0.2
    )

    trainer = CNNModelTrainer(model, learning_rate=0.0001, device=device)

    logger.info(f"Created enhanced CNN model with {model.get_memory_usage()['total_parameters']:,} parameters")

    return model, trainer

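# Factory usage (a sketch):
#
#   model, trainer = create_enhanced_cnn_model(device='cpu')
#   model.get_memory_usage()['total_parameters']   # parameter count
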
# Compatibility wrapper for williams_market_structure.py
class CNNModel:
    """
    Compatibility wrapper for the enhanced CNN model
    """

    def __init__(self, input_shape=(900, 50), output_size=10, model_path=None):
        self.input_shape = input_shape
        self.output_size = output_size
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Create the enhanced model
        self.model = EnhancedCNNModel(
            input_size=input_shape[0],
            feature_dim=input_shape[1],
            output_size=output_size
        )
        self.trainer = CNNModelTrainer(self.model, device=str(self.device))

        logger.info(f"CNN Model wrapper initialized: input_shape={input_shape}, output_size={output_size}")

        if model_path and os.path.exists(model_path):
            self.load(model_path)

    def build_model(self, **kwargs):
        """Build/configure the model"""
        logger.info("CNN Model build_model called")
        return self

    def predict(self, X):
        """Make predictions on input data"""
        try:
            if isinstance(X, np.ndarray):
                result = self.model.predict(X)
            else:
                # Handle tensor input
                result = self.model.predict(X.cpu().numpy() if hasattr(X, 'cpu') else X)

            pred_class = np.array([result['action']])
            pred_proba = np.array([result['probabilities']])

            logger.debug(f"CNN prediction: class={pred_class}, proba_shape={pred_proba.shape}")
            return pred_class, pred_proba

        except Exception as e:
            logger.error(f"Error in CNN prediction: {e}")
            logger.error(f"Full traceback: {traceback.format_exc()}")
            # Return a uniform fallback prediction
            pred_class = np.array([0])
            pred_proba = np.full((1, self.output_size), 1.0 / self.output_size)
            return pred_class, pred_proba

    def fit(self, X, y, **kwargs):
        """Train the model on input data"""
        try:
            # Convert to tensors, copying/cloning to avoid in-place modification
            if isinstance(X, np.ndarray):
                X = torch.FloatTensor(X.copy())
            elif isinstance(X, torch.Tensor):
                X = X.clone().detach()

            if isinstance(y, np.ndarray):
                y = torch.LongTensor(y.copy())
            elif isinstance(y, torch.Tensor):
                y = y.clone().detach().long()

            # Ensure proper shapes and consistent batch sizes
            if len(X.shape) == 2:
                X = X.unsqueeze(0)  # [seq, features] -> [1, seq, features]

            # Normalize the target tensor to a [1] class-index tensor
            if len(y.shape) == 0:
                y = y.unsqueeze(0)  # scalar -> [1]
            elif len(y.shape) == 2 and y.shape[0] == 1:
                y = torch.argmax(y, dim=1)  # [1, num_classes] -> [1]
            elif len(y.shape) == 1 and len(y) > 1:
                y = torch.argmax(y).unsqueeze(0)  # [num_classes] -> [1]
            elif len(y.shape) == 1 and len(y) == 1:
                pass  # already [1]
            else:
                # Fallback: take only the first element
                y = y.view(-1)[:1]

            # Move to device
            X = X.to(self.device, non_blocking=True)
            y = y.to(self.device, non_blocking=True)

            # Use the trainer's train_step
            loss_dict = self.trainer.train_step(X, y)
            logger.info(f"CNN training: X_shape={X.shape}, y_shape={y.shape}, loss={loss_dict.get('total_loss', 0):.4f}")

            return self

        except Exception as e:
            logger.error(f"Error in CNN training: {e}")
            return self

    def save(self, filepath: str):
        """Save the model"""
        try:
            self.trainer.save_model(filepath)
            logger.info(f"CNN model saved to {filepath}")
        except Exception as e:
            logger.error(f"Error saving CNN model: {e}")

    def load(self, filepath: str):
        """Load the model"""
        try:
            self.trainer.load_model(filepath)
            logger.info(f"CNN model loaded from {filepath}")
        except Exception as e:
            logger.error(f"Error loading CNN model: {e}")

    def to_device(self, device):
        """Move model to device"""
        self.device = device
        self.model.to(device)
        return self

    def get_memory_usage(self):
        """Get model memory usage"""
        try:
            return self.model.get_memory_usage()
        except Exception as e:
            logger.error(f"Error getting memory usage: {e}")
            return {'total_parameters': 0, 'memory_mb': 0}

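
if __name__ == "__main__":
    # Smoke test (illustrative only): build the model and run one prediction
    # on random data.  Assumes this project's `utils` package is importable.
    logging.basicConfig(level=logging.INFO)
    model, trainer = create_enhanced_cnn_model(device='cpu')
    demo_features = np.random.randn(60, 50).astype(np.float32)
    print(model.predict(demo_features)['action_name'])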