restart script
This commit is contained in:
@ -69,20 +69,30 @@ class ResidualBlock(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv1 = nn.Conv1d(channels, channels, kernel_size=3, padding=1)
|
self.conv1 = nn.Conv1d(channels, channels, kernel_size=3, padding=1)
|
||||||
self.conv2 = nn.Conv1d(channels, channels, kernel_size=3, padding=1)
|
self.conv2 = nn.Conv1d(channels, channels, kernel_size=3, padding=1)
|
||||||
self.norm1 = nn.BatchNorm1d(channels)
|
self.norm1 = nn.GroupNorm(1, channels) # Changed from BatchNorm1d to GroupNorm
|
||||||
self.norm2 = nn.BatchNorm1d(channels)
|
self.norm2 = nn.GroupNorm(1, channels) # Changed from BatchNorm1d to GroupNorm
|
||||||
self.dropout = nn.Dropout(dropout)
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
residual = x
|
# Create completely independent copy for residual connection
|
||||||
|
residual = x.detach().clone()
|
||||||
|
|
||||||
out = F.relu(self.norm1(self.conv1(x)))
|
# First convolution branch - ensure no memory sharing
|
||||||
|
out = self.conv1(x)
|
||||||
|
out = self.norm1(out)
|
||||||
|
out = F.relu(out)
|
||||||
out = self.dropout(out)
|
out = self.dropout(out)
|
||||||
out = self.norm2(self.conv2(out))
|
|
||||||
|
|
||||||
# Add residual connection (avoid in-place operation)
|
# Second convolution branch
|
||||||
out = out + residual
|
out = self.conv2(out)
|
||||||
return F.relu(out)
|
out = self.norm2(out)
|
||||||
|
|
||||||
|
# Residual connection - create completely new tensor
|
||||||
|
# Avoid any potential in-place operations or memory sharing
|
||||||
|
combined = residual + out
|
||||||
|
result = F.relu(combined)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
class SpatialAttentionBlock(nn.Module):
|
class SpatialAttentionBlock(nn.Module):
|
||||||
"""Spatial attention for feature maps"""
|
"""Spatial attention for feature maps"""
|
||||||
@ -144,11 +154,11 @@ class EnhancedCNNModel(nn.Module):
|
|||||||
# Feature fusion with more capacity
|
# Feature fusion with more capacity
|
||||||
self.feature_fusion = nn.Sequential(
|
self.feature_fusion = nn.Sequential(
|
||||||
nn.Conv1d(base_channels * 4, base_channels * 3, kernel_size=1), # 4 paths now
|
nn.Conv1d(base_channels * 4, base_channels * 3, kernel_size=1), # 4 paths now
|
||||||
nn.BatchNorm1d(base_channels * 3),
|
nn.GroupNorm(1, base_channels * 3), # Changed from BatchNorm1d to GroupNorm
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Dropout(dropout_rate),
|
nn.Dropout(dropout_rate),
|
||||||
nn.Conv1d(base_channels * 3, base_channels * 2, kernel_size=1),
|
nn.Conv1d(base_channels * 3, base_channels * 2, kernel_size=1),
|
||||||
nn.BatchNorm1d(base_channels * 2),
|
nn.GroupNorm(1, base_channels * 2), # Changed from BatchNorm1d to GroupNorm
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Dropout(dropout_rate)
|
nn.Dropout(dropout_rate)
|
||||||
)
|
)
|
||||||
@ -258,22 +268,22 @@ class EnhancedCNNModel(nn.Module):
|
|||||||
|
|
||||||
# Initialize weights
|
# Initialize weights
|
||||||
self._initialize_weights()
|
self._initialize_weights()
|
||||||
|
|
||||||
def _build_conv_path(self, in_channels: int, out_channels: int, kernel_size: int) -> nn.Module:
|
def _build_conv_path(self, in_channels: int, out_channels: int, kernel_size: int) -> nn.Module:
|
||||||
"""Build a convolutional path with multiple layers"""
|
"""Build a convolutional path with multiple layers"""
|
||||||
return nn.Sequential(
|
return nn.Sequential(
|
||||||
nn.Conv1d(in_channels, out_channels, kernel_size, padding=kernel_size//2),
|
nn.Conv1d(in_channels, out_channels, kernel_size, padding=kernel_size//2),
|
||||||
nn.BatchNorm1d(out_channels),
|
nn.GroupNorm(1, out_channels), # Changed from BatchNorm1d to GroupNorm
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Dropout(0.1),
|
nn.Dropout(0.1),
|
||||||
|
|
||||||
nn.Conv1d(out_channels, out_channels, kernel_size, padding=kernel_size//2),
|
nn.Conv1d(out_channels, out_channels, kernel_size, padding=kernel_size//2),
|
||||||
nn.BatchNorm1d(out_channels),
|
nn.GroupNorm(1, out_channels), # Changed from BatchNorm1d to GroupNorm
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Dropout(0.1),
|
nn.Dropout(0.1),
|
||||||
|
|
||||||
nn.Conv1d(out_channels, out_channels, kernel_size, padding=kernel_size//2),
|
nn.Conv1d(out_channels, out_channels, kernel_size, padding=kernel_size//2),
|
||||||
nn.BatchNorm1d(out_channels),
|
nn.GroupNorm(1, out_channels), # Changed from BatchNorm1d to GroupNorm
|
||||||
nn.ReLU()
|
nn.ReLU()
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -288,19 +298,28 @@ class EnhancedCNNModel(nn.Module):
|
|||||||
nn.init.xavier_normal_(m.weight)
|
nn.init.xavier_normal_(m.weight)
|
||||||
if m.bias is not None:
|
if m.bias is not None:
|
||||||
nn.init.constant_(m.bias, 0)
|
nn.init.constant_(m.bias, 0)
|
||||||
elif isinstance(m, nn.BatchNorm1d):
|
elif isinstance(m, (nn.BatchNorm1d, nn.GroupNorm, nn.LayerNorm)):
|
||||||
nn.init.constant_(m.weight, 1)
|
if hasattr(m, 'weight') and m.weight is not None:
|
||||||
nn.init.constant_(m.bias, 0)
|
nn.init.constant_(m.weight, 1)
|
||||||
|
if hasattr(m, 'bias') and m.bias is not None:
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
def _memory_barrier(self, tensor: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Create a memory barrier to prevent in-place operation issues"""
|
||||||
|
return tensor.detach().clone().requires_grad_(tensor.requires_grad)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
|
def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Forward pass with multiple outputs
|
Forward pass with multiple outputs - completely avoiding in-place operations
|
||||||
Args:
|
Args:
|
||||||
x: Input tensor of shape [batch_size, sequence_length, features]
|
x: Input tensor of shape [batch_size, sequence_length, features]
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary with predictions, confidence, regime, and volatility
|
Dictionary with predictions, confidence, regime, and volatility
|
||||||
"""
|
"""
|
||||||
# Handle input shapes flexibly
|
# Apply memory barrier to input
|
||||||
|
x = self._memory_barrier(x)
|
||||||
|
|
||||||
|
# Handle input shapes flexibly - create new tensors to avoid memory sharing
|
||||||
if len(x.shape) == 2:
|
if len(x.shape) == 2:
|
||||||
# Input is [seq_len, features] - add batch dimension
|
# Input is [seq_len, features] - add batch dimension
|
||||||
x = x.unsqueeze(0)
|
x = x.unsqueeze(0)
|
||||||
@ -308,76 +327,96 @@ class EnhancedCNNModel(nn.Module):
|
|||||||
# Input has extra dimensions - flatten to [batch, seq, features]
|
# Input has extra dimensions - flatten to [batch, seq, features]
|
||||||
x = x.view(x.shape[0], -1, x.shape[-1])
|
x = x.view(x.shape[0], -1, x.shape[-1])
|
||||||
|
|
||||||
|
x = self._memory_barrier(x) # Apply barrier after shape changes
|
||||||
batch_size, seq_len, features = x.shape
|
batch_size, seq_len, features = x.shape
|
||||||
|
|
||||||
# Reshape for processing: [batch, seq, features] -> [batch*seq, features]
|
# Reshape for processing: [batch, seq, features] -> [batch*seq, features]
|
||||||
x_reshaped = x.view(-1, features)
|
x_reshaped = x.view(-1, features)
|
||||||
|
x_reshaped = self._memory_barrier(x_reshaped)
|
||||||
|
|
||||||
# Input embedding
|
# Input embedding
|
||||||
embedded = self.input_embedding(x_reshaped) # [batch*seq, base_channels]
|
embedded = self.input_embedding(x_reshaped) # [batch*seq, base_channels]
|
||||||
|
embedded = self._memory_barrier(embedded)
|
||||||
|
|
||||||
# Reshape back for conv1d: [batch*seq, channels] -> [batch, channels, seq]
|
# Reshape back for conv1d: [batch*seq, channels] -> [batch, channels, seq]
|
||||||
embedded = embedded.view(batch_size, seq_len, -1).transpose(1, 2)
|
embedded = embedded.view(batch_size, seq_len, -1).transpose(1, 2).contiguous()
|
||||||
|
embedded = self._memory_barrier(embedded)
|
||||||
|
|
||||||
# Multi-scale feature extraction
|
# Multi-scale feature extraction - ensure each path creates independent tensors
|
||||||
path1 = self.conv_path1(embedded)
|
path1 = self._memory_barrier(self.conv_path1(embedded))
|
||||||
path2 = self.conv_path2(embedded)
|
path2 = self._memory_barrier(self.conv_path2(embedded))
|
||||||
path3 = self.conv_path3(embedded)
|
path3 = self._memory_barrier(self.conv_path3(embedded))
|
||||||
path4 = self.conv_path4(embedded)
|
path4 = self._memory_barrier(self.conv_path4(embedded))
|
||||||
|
|
||||||
# Feature fusion
|
# Feature fusion - create new tensor
|
||||||
fused_features = torch.cat([path1, path2, path3, path4], dim=1)
|
fused_features = torch.cat([path1, path2, path3, path4], dim=1)
|
||||||
fused_features = self.feature_fusion(fused_features)
|
fused_features = self._memory_barrier(self.feature_fusion(fused_features))
|
||||||
|
|
||||||
# Apply residual blocks with spatial attention
|
# Apply residual blocks with spatial attention
|
||||||
current_features = fused_features
|
current_features = self._memory_barrier(fused_features)
|
||||||
for i, (res_block, attention) in enumerate(zip(self.residual_blocks, self.spatial_attention)):
|
for i, (res_block, attention) in enumerate(zip(self.residual_blocks, self.spatial_attention)):
|
||||||
current_features = res_block(current_features)
|
current_features = self._memory_barrier(res_block(current_features))
|
||||||
if i % 2 == 0: # Apply attention every other block
|
if i % 2 == 0: # Apply attention every other block
|
||||||
current_features = attention(current_features)
|
current_features = self._memory_barrier(attention(current_features))
|
||||||
|
|
||||||
# Apply remaining residual blocks
|
# Apply remaining residual blocks
|
||||||
for res_block in self.residual_blocks[len(self.spatial_attention):]:
|
for res_block in self.residual_blocks[len(self.spatial_attention):]:
|
||||||
current_features = res_block(current_features)
|
current_features = self._memory_barrier(res_block(current_features))
|
||||||
|
|
||||||
# Temporal attention - apply both attention layers
|
# Temporal attention - apply both attention layers
|
||||||
# Reshape for attention: [batch, channels, seq] -> [batch, seq, channels]
|
# Reshape for attention: [batch, channels, seq] -> [batch, seq, channels]
|
||||||
attention_input = current_features.transpose(1, 2)
|
attention_input = current_features.transpose(1, 2).contiguous()
|
||||||
attended_features = self.temporal_attention1(attention_input)
|
attention_input = self._memory_barrier(attention_input)
|
||||||
attended_features = self.temporal_attention2(attended_features)
|
|
||||||
|
attended_features = self._memory_barrier(self.temporal_attention1(attention_input))
|
||||||
|
attended_features = self._memory_barrier(self.temporal_attention2(attended_features))
|
||||||
# Back to conv format: [batch, seq, channels] -> [batch, channels, seq]
|
# Back to conv format: [batch, seq, channels] -> [batch, channels, seq]
|
||||||
attended_features = attended_features.transpose(1, 2)
|
attended_features = attended_features.transpose(1, 2).contiguous()
|
||||||
|
attended_features = self._memory_barrier(attended_features)
|
||||||
|
|
||||||
# Global aggregation
|
# Global aggregation - create independent tensors
|
||||||
avg_pooled = self.global_pool(attended_features).squeeze(-1) # [batch, channels]
|
avg_pooled = self.global_pool(attended_features)
|
||||||
max_pooled = self.global_max_pool(attended_features).squeeze(-1) # [batch, channels]
|
avg_pooled = self._memory_barrier(avg_pooled.view(avg_pooled.shape[0], -1)) # Flatten instead of squeeze
|
||||||
|
|
||||||
# Combine global features
|
max_pooled = self.global_max_pool(attended_features)
|
||||||
|
max_pooled = self._memory_barrier(max_pooled.view(max_pooled.shape[0], -1)) # Flatten instead of squeeze
|
||||||
|
|
||||||
|
# Combine global features - create new tensor
|
||||||
global_features = torch.cat([avg_pooled, max_pooled], dim=1)
|
global_features = torch.cat([avg_pooled, max_pooled], dim=1)
|
||||||
|
global_features = self._memory_barrier(global_features)
|
||||||
|
|
||||||
# Advanced feature processing
|
# Advanced feature processing
|
||||||
processed_features = self.advanced_features(global_features)
|
processed_features = self._memory_barrier(self.advanced_features(global_features))
|
||||||
|
|
||||||
# Multi-task predictions
|
# Multi-task predictions - ensure each creates independent tensors
|
||||||
regime_probs = self.regime_detector(processed_features)
|
regime_probs = self._memory_barrier(self.regime_detector(processed_features))
|
||||||
volatility_pred = self.volatility_predictor(processed_features)
|
volatility_pred = self._memory_barrier(self.volatility_predictor(processed_features))
|
||||||
confidence = self.confidence_head(processed_features)
|
confidence = self._memory_barrier(self.confidence_head(processed_features))
|
||||||
|
|
||||||
# Combine all features for final decision (8 regime classes + 1 volatility)
|
# Combine all features for final decision (8 regime classes + 1 volatility)
|
||||||
combined_features = torch.cat([processed_features, regime_probs, volatility_pred], dim=1)
|
# Create completely independent tensors for concatenation
|
||||||
trading_logits = self.decision_head(combined_features)
|
vol_pred_flat = self._memory_barrier(volatility_pred.view(volatility_pred.shape[0], -1)) # Flatten instead of squeeze
|
||||||
|
combined_features = torch.cat([processed_features, regime_probs, vol_pred_flat], dim=1)
|
||||||
|
combined_features = self._memory_barrier(combined_features)
|
||||||
|
|
||||||
# Apply temperature scaling for better calibration
|
trading_logits = self._memory_barrier(self.decision_head(combined_features))
|
||||||
|
|
||||||
|
# Apply temperature scaling for better calibration - create new tensor
|
||||||
temperature = 1.5
|
temperature = 1.5
|
||||||
trading_probs = F.softmax(trading_logits / temperature, dim=1)
|
scaled_logits = trading_logits / temperature
|
||||||
|
trading_probs = self._memory_barrier(F.softmax(scaled_logits, dim=1))
|
||||||
|
|
||||||
|
# Flatten confidence to ensure consistent shape
|
||||||
|
confidence_flat = self._memory_barrier(confidence.view(confidence.shape[0], -1))
|
||||||
|
volatility_flat = self._memory_barrier(volatility_pred.view(volatility_pred.shape[0], -1))
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'logits': trading_logits,
|
'logits': self._memory_barrier(trading_logits),
|
||||||
'probabilities': trading_probs,
|
'probabilities': self._memory_barrier(trading_probs),
|
||||||
'confidence': confidence.squeeze(-1),
|
'confidence': confidence_flat[:, 0] if confidence_flat.shape[1] > 0 else confidence_flat.view(-1)[0],
|
||||||
'regime': regime_probs,
|
'regime': self._memory_barrier(regime_probs),
|
||||||
'volatility': volatility_pred.squeeze(-1),
|
'volatility': volatility_flat[:, 0] if volatility_flat.shape[1] > 0 else volatility_flat.view(-1)[0],
|
||||||
'features': processed_features
|
'features': self._memory_barrier(processed_features)
|
||||||
}
|
}
|
||||||
|
|
||||||
def predict(self, feature_matrix: np.ndarray) -> Dict[str, Any]:
|
def predict(self, feature_matrix: np.ndarray) -> Dict[str, Any]:
|
||||||
@ -478,60 +517,128 @@ class CNNModelTrainer:
|
|||||||
|
|
||||||
self.training_history = []
|
self.training_history = []
|
||||||
|
|
||||||
|
def reset_computational_graph(self):
|
||||||
|
"""Reset the computational graph to prevent in-place operation issues"""
|
||||||
|
try:
|
||||||
|
# Clear all gradients
|
||||||
|
for param in self.model.parameters():
|
||||||
|
param.grad = None
|
||||||
|
|
||||||
|
# Force garbage collection
|
||||||
|
import gc
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
# Clear CUDA cache if available
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
# Reset optimizer state if needed
|
||||||
|
for group in self.optimizer.param_groups:
|
||||||
|
for param in group['params']:
|
||||||
|
if param in self.optimizer.state:
|
||||||
|
# Clear momentum buffers that might have stale references
|
||||||
|
self.optimizer.state[param] = {}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error during computational graph reset: {e}")
|
||||||
|
|
||||||
def train_step(self, x: torch.Tensor, y: torch.Tensor,
|
def train_step(self, x: torch.Tensor, y: torch.Tensor,
|
||||||
confidence_targets: Optional[torch.Tensor] = None,
|
confidence_targets: Optional[torch.Tensor] = None,
|
||||||
regime_targets: Optional[torch.Tensor] = None,
|
regime_targets: Optional[torch.Tensor] = None,
|
||||||
volatility_targets: Optional[torch.Tensor] = None) -> Dict[str, float]:
|
volatility_targets: Optional[torch.Tensor] = None) -> Dict[str, float]:
|
||||||
"""Single training step with multi-task learning"""
|
"""Single training step with multi-task learning and robust error handling"""
|
||||||
|
|
||||||
self.model.train()
|
# Reset computational graph before each training step
|
||||||
self.optimizer.zero_grad()
|
self.reset_computational_graph()
|
||||||
|
|
||||||
# Forward pass
|
try:
|
||||||
outputs = self.model(x)
|
self.model.train()
|
||||||
|
|
||||||
# Main trading loss
|
# Ensure inputs are completely independent from original tensors
|
||||||
main_loss = self.main_criterion(outputs['logits'], y)
|
x_train = x.detach().clone().requires_grad_(False).to(self.device)
|
||||||
total_loss = main_loss
|
y_train = y.detach().clone().requires_grad_(False).to(self.device)
|
||||||
|
|
||||||
losses = {'main_loss': main_loss.item()}
|
# Forward pass with error handling
|
||||||
|
try:
|
||||||
# Confidence loss (if targets provided)
|
outputs = self.model(x_train)
|
||||||
if confidence_targets is not None:
|
except RuntimeError as forward_error:
|
||||||
conf_loss = self.confidence_criterion(outputs['confidence'], confidence_targets)
|
if "modified by an inplace operation" in str(forward_error):
|
||||||
total_loss += 0.1 * conf_loss
|
logger.error(f"In-place operation in forward pass: {forward_error}")
|
||||||
losses['confidence_loss'] = conf_loss.item()
|
self.reset_computational_graph()
|
||||||
|
return {'main_loss': 0.0, 'total_loss': 0.0, 'accuracy': 0.5}
|
||||||
# Regime classification loss (if targets provided)
|
else:
|
||||||
if regime_targets is not None:
|
raise forward_error
|
||||||
regime_loss = self.regime_criterion(outputs['regime'], regime_targets)
|
|
||||||
total_loss += 0.05 * regime_loss
|
# Calculate main loss with detached outputs to prevent memory sharing
|
||||||
losses['regime_loss'] = regime_loss.item()
|
main_loss = self.main_criterion(outputs['logits'], y_train)
|
||||||
|
total_loss = main_loss
|
||||||
# Volatility prediction loss (if targets provided)
|
|
||||||
if volatility_targets is not None:
|
losses = {'main_loss': main_loss.item()}
|
||||||
vol_loss = self.volatility_criterion(outputs['volatility'], volatility_targets)
|
|
||||||
total_loss += 0.05 * vol_loss
|
# Add auxiliary losses if targets provided
|
||||||
losses['volatility_loss'] = vol_loss.item()
|
if confidence_targets is not None:
|
||||||
|
conf_targets = confidence_targets.detach().clone().to(self.device)
|
||||||
losses['total_loss'] = total_loss.item()
|
conf_loss = self.confidence_criterion(outputs['confidence'], conf_targets)
|
||||||
|
total_loss = total_loss + 0.1 * conf_loss
|
||||||
# Backward pass
|
losses['confidence_loss'] = conf_loss.item()
|
||||||
total_loss.backward()
|
|
||||||
|
if regime_targets is not None:
|
||||||
# Gradient clipping
|
regime_targets_clean = regime_targets.detach().clone().to(self.device)
|
||||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
|
regime_loss = self.regime_criterion(outputs['regime'], regime_targets_clean)
|
||||||
|
total_loss = total_loss + 0.05 * regime_loss
|
||||||
self.optimizer.step()
|
losses['regime_loss'] = regime_loss.item()
|
||||||
self.scheduler.step()
|
|
||||||
|
if volatility_targets is not None:
|
||||||
# Calculate accuracy
|
vol_targets = volatility_targets.detach().clone().to(self.device)
|
||||||
with torch.no_grad():
|
vol_loss = self.volatility_criterion(outputs['volatility'], vol_targets)
|
||||||
predictions = torch.argmax(outputs['probabilities'], dim=1)
|
total_loss = total_loss + 0.05 * vol_loss
|
||||||
accuracy = (predictions == y).float().mean().item()
|
losses['volatility_loss'] = vol_loss.item()
|
||||||
losses['accuracy'] = accuracy
|
|
||||||
|
losses['total_loss'] = total_loss.item()
|
||||||
return losses
|
|
||||||
|
# Backward pass with comprehensive error handling
|
||||||
|
try:
|
||||||
|
total_loss.backward()
|
||||||
|
|
||||||
|
except RuntimeError as backward_error:
|
||||||
|
if "modified by an inplace operation" in str(backward_error):
|
||||||
|
logger.error(f"In-place operation during backward pass: {backward_error}")
|
||||||
|
logger.error("Attempting to continue training with gradient reset...")
|
||||||
|
|
||||||
|
# Comprehensive cleanup
|
||||||
|
self.reset_computational_graph()
|
||||||
|
|
||||||
|
return {'main_loss': losses.get('main_loss', 0.0), 'total_loss': losses.get('total_loss', 0.0), 'accuracy': 0.5}
|
||||||
|
else:
|
||||||
|
raise backward_error
|
||||||
|
|
||||||
|
# Gradient clipping
|
||||||
|
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
|
||||||
|
|
||||||
|
# Optimizer step
|
||||||
|
self.optimizer.step()
|
||||||
|
self.scheduler.step()
|
||||||
|
|
||||||
|
# Calculate accuracy with detached tensors
|
||||||
|
with torch.no_grad():
|
||||||
|
predictions = torch.argmax(outputs['probabilities'], dim=1)
|
||||||
|
accuracy = (predictions == y_train).float().mean().item()
|
||||||
|
losses['accuracy'] = accuracy
|
||||||
|
|
||||||
|
return losses
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Training step failed with unexpected error: {e}")
|
||||||
|
logger.error(f"Error type: {type(e).__name__}")
|
||||||
|
import traceback
|
||||||
|
logger.error(f"Full traceback: {traceback.format_exc()}")
|
||||||
|
|
||||||
|
# Comprehensive cleanup on any error
|
||||||
|
self.reset_computational_graph()
|
||||||
|
|
||||||
|
# Return safe dummy values to continue training
|
||||||
|
return {'main_loss': 0.0, 'total_loss': 0.0, 'accuracy': 0.5}
|
||||||
|
|
||||||
def save_model(self, filepath: str, metadata: Optional[Dict] = None):
|
def save_model(self, filepath: str, metadata: Optional[Dict] = None):
|
||||||
"""Save model with metadata"""
|
"""Save model with metadata"""
|
||||||
@ -610,7 +717,7 @@ class CNNModel:
|
|||||||
feature_dim=input_shape[1],
|
feature_dim=input_shape[1],
|
||||||
output_size=output_size
|
output_size=output_size
|
||||||
)
|
)
|
||||||
self.trainer = CNNModelTrainer(self.model, device=self.device)
|
self.trainer = CNNModelTrainer(self.model, device=str(self.device))
|
||||||
|
|
||||||
logger.info(f"CNN Model wrapper initialized: input_shape={input_shape}, output_size={output_size}")
|
logger.info(f"CNN Model wrapper initialized: input_shape={input_shape}, output_size={output_size}")
|
||||||
|
|
||||||
|
90
restart_main_overnight.ps1
Normal file
90
restart_main_overnight.ps1
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
# Overnight Training Restart Script (PowerShell)
|
||||||
|
# Keeps main.py running continuously, restarting it if it crashes.
|
||||||
|
# Usage: .\restart_main_overnight.ps1
|
||||||
|
|
||||||
|
Write-Host "=" * 60
|
||||||
|
Write-Host "OVERNIGHT TRAINING RESTART SCRIPT (PowerShell)"
|
||||||
|
Write-Host "=" * 60
|
||||||
|
Write-Host "Press Ctrl+C to stop the restart loop"
|
||||||
|
Write-Host "Main script: main.py"
|
||||||
|
Write-Host "Restart delay on crash: 10 seconds"
|
||||||
|
Write-Host "=" * 60
|
||||||
|
|
||||||
|
$restartCount = 0
|
||||||
|
$startTime = Get-Date
|
||||||
|
|
||||||
|
# Create logs directory if it doesn't exist
|
||||||
|
if (!(Test-Path "logs")) {
|
||||||
|
New-Item -ItemType Directory -Path "logs"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Setup log file
|
||||||
|
$timestamp = Get-Date -Format "yyyyMMdd_HHmmss"
|
||||||
|
$logFile = "logs\restart_main_ps_$timestamp.log"
|
||||||
|
|
||||||
|
function Write-Log {
|
||||||
|
param($Message)
|
||||||
|
$timestamp = Get-Date -Format "yyyy-MM-dd HH:mm:ss"
|
||||||
|
$logMessage = "$timestamp - $Message"
|
||||||
|
Write-Host $logMessage
|
||||||
|
Add-Content -Path $logFile -Value $logMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
Write-Log "Restart script started, logging to: $logFile"
|
||||||
|
|
||||||
|
# Kill any existing Python processes
|
||||||
|
try {
|
||||||
|
Get-Process python* -ErrorAction SilentlyContinue | Stop-Process -Force -ErrorAction SilentlyContinue
|
||||||
|
Start-Sleep -Seconds 2
|
||||||
|
Write-Log "Killed existing Python processes"
|
||||||
|
} catch {
|
||||||
|
Write-Log "Could not kill existing processes: $_"
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
while ($true) {
|
||||||
|
$restartCount++
|
||||||
|
$runStartTime = Get-Date
|
||||||
|
|
||||||
|
Write-Log "[RESTART #$restartCount] Starting main.py at $(Get-Date -Format 'HH:mm:ss')"
|
||||||
|
|
||||||
|
# Start main.py
|
||||||
|
try {
|
||||||
|
$process = Start-Process -FilePath "python" -ArgumentList "main.py" -PassThru -Wait
|
||||||
|
$exitCode = $process.ExitCode
|
||||||
|
$runEndTime = Get-Date
|
||||||
|
$runDuration = ($runEndTime - $runStartTime).TotalSeconds
|
||||||
|
|
||||||
|
Write-Log "[EXIT] main.py exited with code $exitCode"
|
||||||
|
Write-Log "[DURATION] Process ran for $([math]::Round($runDuration, 1)) seconds"
|
||||||
|
|
||||||
|
# Check for fast exits
|
||||||
|
if ($runDuration -lt 30) {
|
||||||
|
Write-Log "[FAST EXIT] Process exited quickly, waiting 30 seconds..."
|
||||||
|
Start-Sleep -Seconds 30
|
||||||
|
} else {
|
||||||
|
Write-Log "[DELAY] Waiting 10 seconds before restart..."
|
||||||
|
Start-Sleep -Seconds 10
|
||||||
|
}
|
||||||
|
|
||||||
|
# Log stats every 10 restarts
|
||||||
|
if ($restartCount % 10 -eq 0) {
|
||||||
|
$totalDuration = (Get-Date) - $startTime
|
||||||
|
Write-Log "[STATS] Session: $restartCount restarts in $([math]::Round($totalDuration.TotalHours, 1)) hours"
|
||||||
|
}
|
||||||
|
|
||||||
|
} catch {
|
||||||
|
Write-Log "[ERROR] Error starting main.py: $_"
|
||||||
|
Start-Sleep -Seconds 10
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
Write-Log "[INTERRUPT] Restart loop interrupted: $_"
|
||||||
|
} finally {
|
||||||
|
$totalDuration = (Get-Date) - $startTime
|
||||||
|
Write-Log "=" * 60
|
||||||
|
Write-Log "OVERNIGHT TRAINING SESSION COMPLETE"
|
||||||
|
Write-Log "Total restarts: $restartCount"
|
||||||
|
Write-Log "Total session time: $([math]::Round($totalDuration.TotalHours, 1)) hours"
|
||||||
|
Write-Log "=" * 60
|
||||||
|
}
|
188
restart_main_overnight.py
Normal file
188
restart_main_overnight.py
Normal file
@ -0,0 +1,188 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Overnight Training Restart Script
|
||||||
|
Keeps main.py running continuously, restarting it if it crashes.
|
||||||
|
Designed for overnight training sessions with unstable code.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python restart_main_overnight.py
|
||||||
|
|
||||||
|
Press Ctrl+C to stop the restart loop.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
import signal
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Setup logging for the restart script
|
||||||
|
def setup_restart_logging():
|
||||||
|
"""Setup logging for restart events"""
|
||||||
|
log_dir = Path("logs")
|
||||||
|
log_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
# Create restart log file with timestamp
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
log_file = log_dir / f"restart_main_{timestamp}.log"
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||||
|
handlers=[
|
||||||
|
logging.FileHandler(log_file, encoding='utf-8'),
|
||||||
|
logging.StreamHandler(sys.stdout)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.info(f"Restart script logging to: {log_file}")
|
||||||
|
return logger
|
||||||
|
|
||||||
|
def kill_existing_processes(logger):
|
||||||
|
"""Kill any existing main.py processes to avoid conflicts"""
|
||||||
|
try:
|
||||||
|
if os.name == 'nt': # Windows
|
||||||
|
# Kill any existing Python processes running main.py
|
||||||
|
subprocess.run(['taskkill', '/f', '/im', 'python.exe'],
|
||||||
|
capture_output=True, check=False)
|
||||||
|
subprocess.run(['taskkill', '/f', '/im', 'pythonw.exe'],
|
||||||
|
capture_output=True, check=False)
|
||||||
|
time.sleep(2)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Could not kill existing processes: {e}")
|
||||||
|
|
||||||
|
def run_main_with_restart(logger):
|
||||||
|
"""Main restart loop"""
|
||||||
|
restart_count = 0
|
||||||
|
consecutive_fast_exits = 0
|
||||||
|
start_time = datetime.now()
|
||||||
|
|
||||||
|
logger.info("=" * 60)
|
||||||
|
logger.info("OVERNIGHT TRAINING RESTART SCRIPT STARTED")
|
||||||
|
logger.info("=" * 60)
|
||||||
|
logger.info("Press Ctrl+C to stop the restart loop")
|
||||||
|
logger.info("Main script: main.py")
|
||||||
|
logger.info("Restart delay on crash: 10 seconds")
|
||||||
|
logger.info("Fast exit protection: Enabled")
|
||||||
|
logger.info("=" * 60)
|
||||||
|
|
||||||
|
# Kill any existing processes
|
||||||
|
kill_existing_processes(logger)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
restart_count += 1
|
||||||
|
run_start_time = datetime.now()
|
||||||
|
|
||||||
|
logger.info(f"[RESTART #{restart_count}] Starting main.py at {run_start_time.strftime('%H:%M:%S')}")
|
||||||
|
|
||||||
|
# Start main.py as subprocess
|
||||||
|
process = subprocess.Popen([
|
||||||
|
sys.executable, "main.py"
|
||||||
|
], stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
|
||||||
|
universal_newlines=True, bufsize=1)
|
||||||
|
|
||||||
|
logger.info(f"[PROCESS] main.py started with PID: {process.pid}")
|
||||||
|
|
||||||
|
# Stream output from main.py
|
||||||
|
try:
|
||||||
|
if process.stdout:
|
||||||
|
while True:
|
||||||
|
output = process.stdout.readline()
|
||||||
|
if output == '' and process.poll() is not None:
|
||||||
|
break
|
||||||
|
if output:
|
||||||
|
# Forward output from main.py (remove extra newlines)
|
||||||
|
print(f"[MAIN] {output.rstrip()}")
|
||||||
|
else:
|
||||||
|
# If no stdout, just wait for process to complete
|
||||||
|
process.wait()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("[INTERRUPT] Ctrl+C received, stopping main.py...")
|
||||||
|
process.terminate()
|
||||||
|
try:
|
||||||
|
process.wait(timeout=10)
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
logger.warning("[FORCE KILL] Process didn't terminate, force killing...")
|
||||||
|
process.kill()
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Process has exited
|
||||||
|
exit_code = process.poll()
|
||||||
|
run_end_time = datetime.now()
|
||||||
|
run_duration = (run_end_time - run_start_time).total_seconds()
|
||||||
|
|
||||||
|
logger.info(f"[EXIT] main.py exited with code {exit_code}")
|
||||||
|
logger.info(f"[DURATION] Process ran for {run_duration:.1f} seconds")
|
||||||
|
|
||||||
|
# Check for fast exits (potential configuration issues)
|
||||||
|
if run_duration < 30: # Less than 30 seconds
|
||||||
|
consecutive_fast_exits += 1
|
||||||
|
logger.warning(f"[FAST EXIT] Process exited quickly ({consecutive_fast_exits} consecutive)")
|
||||||
|
|
||||||
|
if consecutive_fast_exits >= 5:
|
||||||
|
logger.error("[ABORT] Too many consecutive fast exits (5+)")
|
||||||
|
logger.error("This indicates a configuration or startup problem")
|
||||||
|
logger.error("Please check the main.py script manually")
|
||||||
|
break
|
||||||
|
|
||||||
|
# Longer delay for fast exits
|
||||||
|
delay = min(60, 10 * consecutive_fast_exits)
|
||||||
|
logger.info(f"[DELAY] Waiting {delay} seconds before restart due to fast exit...")
|
||||||
|
time.sleep(delay)
|
||||||
|
else:
|
||||||
|
consecutive_fast_exits = 0 # Reset counter
|
||||||
|
logger.info("[DELAY] Waiting 10 seconds before restart...")
|
||||||
|
time.sleep(10)
|
||||||
|
|
||||||
|
# Log session statistics every 10 restarts
|
||||||
|
if restart_count % 10 == 0:
|
||||||
|
total_duration = (datetime.now() - start_time).total_seconds()
|
||||||
|
logger.info(f"[STATS] Session: {restart_count} restarts in {total_duration/3600:.1f} hours")
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("[SHUTDOWN] Restart loop interrupted by user")
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[ERROR] Unexpected error in restart loop: {e}")
|
||||||
|
logger.error("Continuing restart loop after 30 second delay...")
|
||||||
|
time.sleep(30)
|
||||||
|
|
||||||
|
total_duration = (datetime.now() - start_time).total_seconds()
|
||||||
|
logger.info("=" * 60)
|
||||||
|
logger.info("OVERNIGHT TRAINING SESSION COMPLETE")
|
||||||
|
logger.info(f"Total restarts: {restart_count}")
|
||||||
|
logger.info(f"Total session time: {total_duration/3600:.1f} hours")
|
||||||
|
logger.info("=" * 60)
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main entry point"""
|
||||||
|
# Setup signal handlers for clean shutdown
|
||||||
|
def signal_handler(signum, frame):
|
||||||
|
logger.info(f"[SIGNAL] Received signal {signum}, shutting down...")
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
|
if hasattr(signal, 'SIGTERM'):
|
||||||
|
signal.signal(signal.SIGTERM, signal_handler)
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
global logger
|
||||||
|
logger = setup_restart_logging()
|
||||||
|
|
||||||
|
try:
|
||||||
|
run_main_with_restart(logger)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[FATAL] Fatal error in restart script: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return 1
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
Reference in New Issue
Block a user