#!/usr/bin/env python3
"""
Enhanced CNN Model for Trading - PyTorch Implementation

A much larger and more sophisticated architecture for better learning.
"""

import os
import logging
import math
from datetime import datetime
from typing import Dict, Any, Optional, Tuple

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Configure logging
logger = logging.getLogger(__name__)


class MultiHeadAttention(nn.Module):
    """Multi-head attention mechanism for sequence data"""

    def __init__(self, d_model: int, num_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.d_k)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, _ = x.size()

        # Project to Q, K, V and split into heads: [batch, heads, seq, d_k]
        Q = self.w_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.w_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.w_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention weights
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Apply attention and merge heads back: [batch, seq, d_model]
        attention_output = torch.matmul(attention_weights, V)
        attention_output = attention_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.d_model
        )

        return self.w_o(attention_output)


class ResidualBlock(nn.Module):
    """Residual block with normalization and dropout"""

    def __init__(self, channels: int, dropout: float = 0.1):
        super().__init__()
        self.conv1 = nn.Conv1d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(channels, channels, kernel_size=3, padding=1)
        self.norm1 = nn.BatchNorm1d(channels)
        self.norm2 = nn.BatchNorm1d(channels)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x

        out = F.relu(self.norm1(self.conv1(x)))
        out = self.dropout(out)
        out = self.norm2(self.conv2(out))

        # Add residual connection
        out += residual
        return F.relu(out)


class SpatialAttentionBlock(nn.Module):
    """Spatial attention for feature maps"""

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv1d(channels, 1, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute per-position attention weights and gate the input
        attention = torch.sigmoid(self.conv(x))
        return x * attention


class EnhancedCNNModel(nn.Module):
    """
    Much larger and more sophisticated CNN architecture for trading.

    Features:
    - Deep convolutional layers with residual connections
    - Multi-head attention mechanisms
    - Spatial attention blocks
    - Multiple feature extraction paths
    - Large capacity for complex pattern learning
    """

    def __init__(self,
                 input_size: int = 60,
                 feature_dim: int = 50,
                 output_size: int = 2,           # BUY/SELL for 2-action system
                 base_channels: int = 256,       # Increased from 128 to 256
                 num_blocks: int = 12,           # Increased from 6 to 12
                 num_attention_heads: int = 16,  # Increased from 8 to 16
                 dropout_rate: float = 0.2):
        super().__init__()

        self.input_size = input_size
        self.feature_dim = feature_dim
        self.output_size = output_size
        self.base_channels = base_channels

        # Much larger input embedding - project features to a higher dimension
        self.input_embedding = nn.Sequential(
            nn.Linear(feature_dim, base_channels // 2),
            nn.BatchNorm1d(base_channels // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(base_channels // 2, base_channels),
            nn.BatchNorm1d(base_channels),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )

        # Multi-scale convolutional feature extraction with more channels
        self.conv_path1 = self._build_conv_path(base_channels, base_channels, 3)
        self.conv_path2 = self._build_conv_path(base_channels, base_channels, 5)
        self.conv_path3 = self._build_conv_path(base_channels, base_channels, 7)
        self.conv_path4 = self._build_conv_path(base_channels, base_channels, 9)  # Additional path

        # Feature fusion with more capacity
        self.feature_fusion = nn.Sequential(
            nn.Conv1d(base_channels * 4, base_channels * 3, kernel_size=1),  # 4 paths now
            nn.BatchNorm1d(base_channels * 3),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Conv1d(base_channels * 3, base_channels * 2, kernel_size=1),
            nn.BatchNorm1d(base_channels * 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )

        # Much deeper residual blocks for complex pattern learning
        self.residual_blocks = nn.ModuleList([
            ResidualBlock(base_channels * 2, dropout_rate) for _ in range(num_blocks)
        ])

        # More spatial attention blocks
        self.spatial_attention = nn.ModuleList([
            SpatialAttentionBlock(base_channels * 2) for _ in range(6)  # Increased from 3 to 6
        ])

        # Multiple temporal attention layers
        self.temporal_attention1 = MultiHeadAttention(
            d_model=base_channels * 2,
            num_heads=num_attention_heads,
            dropout=dropout_rate
        )
        self.temporal_attention2 = MultiHeadAttention(
            d_model=base_channels * 2,
            num_heads=num_attention_heads // 2,
            dropout=dropout_rate
        )

        # Global feature aggregation
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.global_max_pool = nn.AdaptiveMaxPool1d(1)

        # Much larger advanced feature processing
        self.advanced_features = nn.Sequential(
            nn.Linear(base_channels * 4, base_channels * 6),  # Increased capacity
            nn.BatchNorm1d(base_channels * 6),
            nn.ReLU(),
            nn.Dropout(dropout_rate),

            nn.Linear(base_channels * 6, base_channels * 4),
            nn.BatchNorm1d(base_channels * 4),
            nn.ReLU(),
            nn.Dropout(dropout_rate),

            nn.Linear(base_channels * 4, base_channels * 3),
            nn.BatchNorm1d(base_channels * 3),
            nn.ReLU(),
            nn.Dropout(dropout_rate),

            nn.Linear(base_channels * 3, base_channels * 2),
            nn.BatchNorm1d(base_channels * 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),

            nn.Linear(base_channels * 2, base_channels),
            nn.BatchNorm1d(base_channels),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )

        # Enhanced market regime detection branch
        self.regime_detector = nn.Sequential(
            nn.Linear(base_channels, base_channels // 2),
            nn.BatchNorm1d(base_channels // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(base_channels // 2, base_channels // 4),
            nn.BatchNorm1d(base_channels // 4),
            nn.ReLU(),
            nn.Linear(base_channels // 4, 8),  # 8 market regimes instead of 4
            nn.Softmax(dim=1)
        )

        # Enhanced volatility prediction branch
        self.volatility_predictor = nn.Sequential(
            nn.Linear(base_channels, base_channels // 2),
            nn.BatchNorm1d(base_channels // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(base_channels // 2, base_channels // 4),
            nn.BatchNorm1d(base_channels // 4),
            nn.ReLU(),
            nn.Linear(base_channels // 4, 1),
            nn.Sigmoid()
        )

        # Main trading decision head
        self.decision_head = nn.Sequential(
            nn.Linear(base_channels + 8 + 1, base_channels),  # 8 regime classes + 1 volatility
            nn.BatchNorm1d(base_channels),
            nn.ReLU(),
            nn.Dropout(dropout_rate),

            nn.Linear(base_channels, base_channels // 2),
            nn.BatchNorm1d(base_channels // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),

            nn.Linear(base_channels // 2, output_size)
        )

        # Confidence estimation head
        self.confidence_head = nn.Sequential(
            nn.Linear(base_channels, base_channels // 2),
            nn.ReLU(),
            nn.Linear(base_channels // 2, 1),
            nn.Sigmoid()
        )

        # Initialize weights
        self._initialize_weights()

    def _build_conv_path(self, in_channels: int, out_channels: int, kernel_size: int) -> nn.Module:
        """Build a convolutional path with multiple layers"""
        return nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size, padding=kernel_size // 2),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.1),

            nn.Conv1d(out_channels, out_channels, kernel_size, padding=kernel_size // 2),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.1),

            nn.Conv1d(out_channels, out_channels, kernel_size, padding=kernel_size // 2),
            nn.BatchNorm1d(out_channels),
            nn.ReLU()
        )

    def _initialize_weights(self):
        """Initialize model weights"""
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward pass with multiple outputs

        Args:
            x: Input tensor of shape [batch_size, sequence_length, features]

        Returns:
            Dictionary with predictions, confidence, regime, and volatility
        """
        batch_size, seq_len, features = x.shape

        # Reshape for processing: [batch, seq, features] -> [batch*seq, features]
        x_reshaped = x.view(-1, features)

        # Input embedding
        embedded = self.input_embedding(x_reshaped)  # [batch*seq, base_channels]

        # Reshape back for conv1d: [batch*seq, channels] -> [batch, channels, seq]
        embedded = embedded.view(batch_size, seq_len, -1).transpose(1, 2)

        # Multi-scale feature extraction
        path1 = self.conv_path1(embedded)
        path2 = self.conv_path2(embedded)
        path3 = self.conv_path3(embedded)
        path4 = self.conv_path4(embedded)

        # Feature fusion
        fused_features = torch.cat([path1, path2, path3, path4], dim=1)
        fused_features = self.feature_fusion(fused_features)

        # Apply residual blocks, inserting spatial attention after every other
        # block so that all 6 attention modules are used across the 12 blocks
        # (the previous zip-based loop silently skipped half of them)
        current_features = fused_features
        for i, res_block in enumerate(self.residual_blocks):
            current_features = res_block(current_features)
            if i % 2 == 0 and i // 2 < len(self.spatial_attention):
                current_features = self.spatial_attention[i // 2](current_features)

        # Temporal attention - apply both attention layers
        # Reshape for attention: [batch, channels, seq] -> [batch, seq, channels]
        attention_input = current_features.transpose(1, 2)
        attended_features = self.temporal_attention1(attention_input)
        attended_features = self.temporal_attention2(attended_features)
        # Back to conv format: [batch, seq, channels] -> [batch, channels, seq]
        attended_features = attended_features.transpose(1, 2)

        # Global aggregation
        avg_pooled = self.global_pool(attended_features).squeeze(-1)      # [batch, channels]
        max_pooled = self.global_max_pool(attended_features).squeeze(-1)  # [batch, channels]

        # Combine global features
        global_features = torch.cat([avg_pooled, max_pooled], dim=1)

        # Advanced feature processing
        processed_features = self.advanced_features(global_features)

        # Multi-task predictions
        regime_probs = self.regime_detector(processed_features)
        volatility_pred = self.volatility_predictor(processed_features)
        confidence = self.confidence_head(processed_features)

        # Combine all features for final decision (8 regime classes + 1 volatility)
        combined_features = torch.cat([processed_features, regime_probs, volatility_pred], dim=1)
        trading_logits = self.decision_head(combined_features)

        # Apply temperature scaling for better calibration
        temperature = 1.5
        trading_probs = F.softmax(trading_logits / temperature, dim=1)

        return {
            'logits': trading_logits,
            'probabilities': trading_probs,
            'confidence': confidence.squeeze(-1),
            'regime': regime_probs,
            'volatility': volatility_pred.squeeze(-1),
            'features': processed_features
        }

    def predict(self, feature_matrix: np.ndarray) -> Dict[str, Any]:
        """
        Make predictions on a feature matrix

        Args:
            feature_matrix: numpy array of shape [sequence_length, features]

        Returns:
            Dictionary with prediction results
        """
        self.eval()

        with torch.no_grad():
            # Convert to tensor and add batch dimension
            if isinstance(feature_matrix, np.ndarray):
                x = torch.FloatTensor(feature_matrix).unsqueeze(0)  # Add batch dim
            else:
                x = feature_matrix.unsqueeze(0)

            # Move to device
            device = next(self.parameters()).device
            x = x.to(device)

            # Forward pass
            outputs = self.forward(x)

            # Extract results
            probs = outputs['probabilities'].cpu().numpy()[0]
            confidence = outputs['confidence'].cpu().numpy()[0]
            regime = outputs['regime'].cpu().numpy()[0]
            volatility = outputs['volatility'].cpu().numpy()[0]

            # Determine action (0=BUY, 1=SELL for 2-action system)
            action = int(np.argmax(probs))
            action_confidence = float(probs[action])

            return {
                'action': action,
                'action_name': 'BUY' if action == 0 else 'SELL',
                'confidence': float(confidence),
                'action_confidence': action_confidence,
                'probabilities': probs.tolist(),
                'regime_probabilities': regime.tolist(),
                'volatility_prediction': float(volatility),
                'raw_logits': outputs['logits'].cpu().numpy()[0].tolist()
            }

    def get_memory_usage(self) -> Dict[str, Any]:
        """Get model memory usage statistics"""
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)

        param_size = sum(p.numel() * p.element_size() for p in self.parameters())
        buffer_size = sum(b.numel() * b.element_size() for b in self.buffers())

        return {
            'total_parameters': total_params,
            'trainable_parameters': trainable_params,
            'parameter_size_mb': param_size / (1024 * 1024),
            'buffer_size_mb': buffer_size / (1024 * 1024),
            'total_size_mb': (param_size + buffer_size) / (1024 * 1024)
        }

    def to_device(self, device: str):
        """Move model to specified device"""
        return self.to(torch.device(device))


class CNNModelTrainer:
    """Enhanced trainer for the beefed-up CNN model"""

    def __init__(self, model: EnhancedCNNModel, learning_rate: float = 0.0001, device: str = 'cuda'):
        self.model = model.to(device)
        self.device = device
        self.learning_rate = learning_rate

        # Use AdamW optimizer with weight decay
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=0.01,
            betas=(0.9, 0.999)
        )

        # Learning rate scheduler
        self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
            self.optimizer,
            max_lr=learning_rate * 10,
            total_steps=10000,  # Will be updated based on actual training
            pct_start=0.1,
            anneal_strategy='cos'
        )

        # Multi-task loss functions. The regime head already applies a softmax,
        # so use NLLLoss on log-probabilities; CrossEntropyLoss would wrongly
        # apply a second, implicit softmax to probabilities instead of logits.
        self.main_criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
        self.confidence_criterion = nn.BCELoss()
        self.regime_criterion = nn.NLLLoss()
        self.volatility_criterion = nn.MSELoss()

        self.training_history = []

    def train_step(self, x: torch.Tensor, y: torch.Tensor,
                   confidence_targets: Optional[torch.Tensor] = None,
                   regime_targets: Optional[torch.Tensor] = None,
                   volatility_targets: Optional[torch.Tensor] = None) -> Dict[str, float]:
        """Single training step with multi-task learning"""

        self.model.train()
        self.optimizer.zero_grad()

        # Forward pass
        outputs = self.model(x)

        # Main trading loss
        main_loss = self.main_criterion(outputs['logits'], y)
        total_loss = main_loss

        losses = {'main_loss': main_loss.item()}

        # Confidence loss (if targets provided)
        if confidence_targets is not None:
            conf_loss = self.confidence_criterion(outputs['confidence'], confidence_targets)
            total_loss += 0.1 * conf_loss
            losses['confidence_loss'] = conf_loss.item()

        # Regime classification loss (if targets provided); NLLLoss expects
        # log-probabilities, so take the log of the softmax output (with a
        # small epsilon for numerical stability)
        if regime_targets is not None:
            regime_loss = self.regime_criterion(torch.log(outputs['regime'] + 1e-8), regime_targets)
            total_loss += 0.05 * regime_loss
            losses['regime_loss'] = regime_loss.item()

        # Volatility prediction loss (if targets provided)
        if volatility_targets is not None:
            vol_loss = self.volatility_criterion(outputs['volatility'], volatility_targets)
            total_loss += 0.05 * vol_loss
            losses['volatility_loss'] = vol_loss.item()

        losses['total_loss'] = total_loss.item()

        # Backward pass
        total_loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

        self.optimizer.step()
        self.scheduler.step()

        # Calculate accuracy
        with torch.no_grad():
            predictions = torch.argmax(outputs['probabilities'], dim=1)
            accuracy = (predictions == y).float().mean().item()
            losses['accuracy'] = accuracy

        return losses

    def save_model(self, filepath: str, metadata: Optional[Dict] = None):
        """Save model with metadata"""
        save_dict = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'training_history': self.training_history,
            'model_config': {
                'input_size': self.model.input_size,
                'feature_dim': self.model.feature_dim,
                'output_size': self.model.output_size,
                'base_channels': self.model.base_channels
            }
        }

        if metadata:
            save_dict['metadata'] = metadata

        torch.save(save_dict, filepath)
        logger.info(f"Enhanced CNN model saved to {filepath}")

    def load_model(self, filepath: str) -> Dict:
        """Load model from file"""
        checkpoint = torch.load(filepath, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        if 'scheduler_state_dict' in checkpoint:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        if 'training_history' in checkpoint:
            self.training_history = checkpoint['training_history']

        logger.info(f"Enhanced CNN model loaded from {filepath}")
        return checkpoint.get('metadata', {})


def create_enhanced_cnn_model(input_size: int = 60,
                              feature_dim: int = 50,
                              output_size: int = 2,
                              base_channels: int = 256,
                              device: str = 'cuda') -> Tuple[EnhancedCNNModel, CNNModelTrainer]:
    """Create enhanced CNN model and trainer"""

    model = EnhancedCNNModel(
        input_size=input_size,
        feature_dim=feature_dim,
        output_size=output_size,
        base_channels=base_channels,
        num_blocks=12,
        num_attention_heads=16,
        dropout_rate=0.2
    )

    trainer = CNNModelTrainer(model, learning_rate=0.0001, device=device)

    logger.info(f"Created enhanced CNN model with {model.get_memory_usage()['total_parameters']:,} parameters")

    return model, trainer
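

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the training pipeline):
# builds the model with its default configuration, runs one prediction on
# synthetic features, and takes one training step on a random batch. The
# shapes assume the defaults above (60-step sequences, 50 features, 2
# actions); substitute real feature matrices and labels in practice.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model, trainer = create_enhanced_cnn_model(device=device)

    # Inference on a single synthetic sample: [sequence_length, features]
    features = np.random.randn(60, 50).astype(np.float32)
    result = model.predict(features)
    logger.info(f"Action: {result['action_name']}, confidence: {result['confidence']:.3f}")

    # One training step on a random batch: [batch, sequence_length, features]
    x = torch.randn(8, 60, 50, device=device)
    y = torch.randint(0, 2, (8,), device=device)
    step_losses = trainer.train_step(x, y)
    logger.info(f"Train step losses: {step_losses}")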