From bf55ba5b51d8127521638742cfde69ad818a97f3 Mon Sep 17 00:00:00 2001 From: Dobromir Popov Date: Thu, 19 Jun 2025 16:07:05 +0300 Subject: [PATCH] restart script --- NN/models/cnn_model.py | 313 +++++++++++++++++++++++++------------ restart_main_overnight.ps1 | 90 +++++++++++ restart_main_overnight.py | 188 ++++++++++++++++++++++ 3 files changed, 488 insertions(+), 103 deletions(-) create mode 100644 restart_main_overnight.ps1 create mode 100644 restart_main_overnight.py diff --git a/NN/models/cnn_model.py b/NN/models/cnn_model.py index 83b6f1f..d06c209 100644 --- a/NN/models/cnn_model.py +++ b/NN/models/cnn_model.py @@ -69,20 +69,30 @@ class ResidualBlock(nn.Module): super().__init__() self.conv1 = nn.Conv1d(channels, channels, kernel_size=3, padding=1) self.conv2 = nn.Conv1d(channels, channels, kernel_size=3, padding=1) - self.norm1 = nn.BatchNorm1d(channels) - self.norm2 = nn.BatchNorm1d(channels) + self.norm1 = nn.GroupNorm(1, channels) # Changed from BatchNorm1d to GroupNorm + self.norm2 = nn.GroupNorm(1, channels) # Changed from BatchNorm1d to GroupNorm self.dropout = nn.Dropout(dropout) def forward(self, x: torch.Tensor) -> torch.Tensor: - residual = x + # Create completely independent copy for residual connection + residual = x.detach().clone() - out = F.relu(self.norm1(self.conv1(x))) + # First convolution branch - ensure no memory sharing + out = self.conv1(x) + out = self.norm1(out) + out = F.relu(out) out = self.dropout(out) - out = self.norm2(self.conv2(out)) - # Add residual connection (avoid in-place operation) - out = out + residual - return F.relu(out) + # Second convolution branch + out = self.conv2(out) + out = self.norm2(out) + + # Residual connection - create completely new tensor + # Avoid any potential in-place operations or memory sharing + combined = residual + out + result = F.relu(combined) + + return result class SpatialAttentionBlock(nn.Module): """Spatial attention for feature maps""" @@ -144,11 +154,11 @@ class EnhancedCNNModel(nn.Module): # Feature fusion with more capacity self.feature_fusion = nn.Sequential( nn.Conv1d(base_channels * 4, base_channels * 3, kernel_size=1), # 4 paths now - nn.BatchNorm1d(base_channels * 3), + nn.GroupNorm(1, base_channels * 3), # Changed from BatchNorm1d to GroupNorm nn.ReLU(), nn.Dropout(dropout_rate), nn.Conv1d(base_channels * 3, base_channels * 2, kernel_size=1), - nn.BatchNorm1d(base_channels * 2), + nn.GroupNorm(1, base_channels * 2), # Changed from BatchNorm1d to GroupNorm nn.ReLU(), nn.Dropout(dropout_rate) ) @@ -258,22 +268,22 @@ class EnhancedCNNModel(nn.Module): # Initialize weights self._initialize_weights() - + def _build_conv_path(self, in_channels: int, out_channels: int, kernel_size: int) -> nn.Module: """Build a convolutional path with multiple layers""" return nn.Sequential( nn.Conv1d(in_channels, out_channels, kernel_size, padding=kernel_size//2), - nn.BatchNorm1d(out_channels), + nn.GroupNorm(1, out_channels), # Changed from BatchNorm1d to GroupNorm nn.ReLU(), nn.Dropout(0.1), nn.Conv1d(out_channels, out_channels, kernel_size, padding=kernel_size//2), - nn.BatchNorm1d(out_channels), + nn.GroupNorm(1, out_channels), # Changed from BatchNorm1d to GroupNorm nn.ReLU(), nn.Dropout(0.1), nn.Conv1d(out_channels, out_channels, kernel_size, padding=kernel_size//2), - nn.BatchNorm1d(out_channels), + nn.GroupNorm(1, out_channels), # Changed from BatchNorm1d to GroupNorm nn.ReLU() ) @@ -288,19 +298,28 @@ class EnhancedCNNModel(nn.Module): nn.init.xavier_normal_(m.weight) if m.bias is not None: nn.init.constant_(m.bias, 
0) - elif isinstance(m, nn.BatchNorm1d): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) + elif isinstance(m, (nn.BatchNorm1d, nn.GroupNorm, nn.LayerNorm)): + if hasattr(m, 'weight') and m.weight is not None: + nn.init.constant_(m.weight, 1) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _memory_barrier(self, tensor: torch.Tensor) -> torch.Tensor: + """Create a memory barrier to prevent in-place operation issues""" + return tensor.detach().clone().requires_grad_(tensor.requires_grad) def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: """ - Forward pass with multiple outputs + Forward pass with multiple outputs - completely avoiding in-place operations Args: x: Input tensor of shape [batch_size, sequence_length, features] Returns: Dictionary with predictions, confidence, regime, and volatility """ - # Handle input shapes flexibly + # Apply memory barrier to input + x = self._memory_barrier(x) + + # Handle input shapes flexibly - create new tensors to avoid memory sharing if len(x.shape) == 2: # Input is [seq_len, features] - add batch dimension x = x.unsqueeze(0) @@ -308,76 +327,96 @@ class EnhancedCNNModel(nn.Module): # Input has extra dimensions - flatten to [batch, seq, features] x = x.view(x.shape[0], -1, x.shape[-1]) + x = self._memory_barrier(x) # Apply barrier after shape changes batch_size, seq_len, features = x.shape # Reshape for processing: [batch, seq, features] -> [batch*seq, features] x_reshaped = x.view(-1, features) + x_reshaped = self._memory_barrier(x_reshaped) # Input embedding embedded = self.input_embedding(x_reshaped) # [batch*seq, base_channels] + embedded = self._memory_barrier(embedded) # Reshape back for conv1d: [batch*seq, channels] -> [batch, channels, seq] - embedded = embedded.view(batch_size, seq_len, -1).transpose(1, 2) + embedded = embedded.view(batch_size, seq_len, -1).transpose(1, 2).contiguous() + embedded = self._memory_barrier(embedded) - # Multi-scale feature extraction - path1 = self.conv_path1(embedded) - path2 = self.conv_path2(embedded) - path3 = self.conv_path3(embedded) - path4 = self.conv_path4(embedded) + # Multi-scale feature extraction - ensure each path creates independent tensors + path1 = self._memory_barrier(self.conv_path1(embedded)) + path2 = self._memory_barrier(self.conv_path2(embedded)) + path3 = self._memory_barrier(self.conv_path3(embedded)) + path4 = self._memory_barrier(self.conv_path4(embedded)) - # Feature fusion + # Feature fusion - create new tensor fused_features = torch.cat([path1, path2, path3, path4], dim=1) - fused_features = self.feature_fusion(fused_features) + fused_features = self._memory_barrier(self.feature_fusion(fused_features)) # Apply residual blocks with spatial attention - current_features = fused_features + current_features = self._memory_barrier(fused_features) for i, (res_block, attention) in enumerate(zip(self.residual_blocks, self.spatial_attention)): - current_features = res_block(current_features) + current_features = self._memory_barrier(res_block(current_features)) if i % 2 == 0: # Apply attention every other block - current_features = attention(current_features) + current_features = self._memory_barrier(attention(current_features)) # Apply remaining residual blocks for res_block in self.residual_blocks[len(self.spatial_attention):]: - current_features = res_block(current_features) + current_features = self._memory_barrier(res_block(current_features)) # Temporal attention - apply both attention layers # Reshape for attention: [batch, 
channels, seq] -> [batch, seq, channels] - attention_input = current_features.transpose(1, 2) - attended_features = self.temporal_attention1(attention_input) - attended_features = self.temporal_attention2(attended_features) + attention_input = current_features.transpose(1, 2).contiguous() + attention_input = self._memory_barrier(attention_input) + + attended_features = self._memory_barrier(self.temporal_attention1(attention_input)) + attended_features = self._memory_barrier(self.temporal_attention2(attended_features)) # Back to conv format: [batch, seq, channels] -> [batch, channels, seq] - attended_features = attended_features.transpose(1, 2) + attended_features = attended_features.transpose(1, 2).contiguous() + attended_features = self._memory_barrier(attended_features) - # Global aggregation - avg_pooled = self.global_pool(attended_features).squeeze(-1) # [batch, channels] - max_pooled = self.global_max_pool(attended_features).squeeze(-1) # [batch, channels] + # Global aggregation - create independent tensors + avg_pooled = self.global_pool(attended_features) + avg_pooled = self._memory_barrier(avg_pooled.view(avg_pooled.shape[0], -1)) # Flatten instead of squeeze - # Combine global features + max_pooled = self.global_max_pool(attended_features) + max_pooled = self._memory_barrier(max_pooled.view(max_pooled.shape[0], -1)) # Flatten instead of squeeze + + # Combine global features - create new tensor global_features = torch.cat([avg_pooled, max_pooled], dim=1) + global_features = self._memory_barrier(global_features) # Advanced feature processing - processed_features = self.advanced_features(global_features) + processed_features = self._memory_barrier(self.advanced_features(global_features)) - # Multi-task predictions - regime_probs = self.regime_detector(processed_features) - volatility_pred = self.volatility_predictor(processed_features) - confidence = self.confidence_head(processed_features) + # Multi-task predictions - ensure each creates independent tensors + regime_probs = self._memory_barrier(self.regime_detector(processed_features)) + volatility_pred = self._memory_barrier(self.volatility_predictor(processed_features)) + confidence = self._memory_barrier(self.confidence_head(processed_features)) # Combine all features for final decision (8 regime classes + 1 volatility) - combined_features = torch.cat([processed_features, regime_probs, volatility_pred], dim=1) - trading_logits = self.decision_head(combined_features) + # Create completely independent tensors for concatenation + vol_pred_flat = self._memory_barrier(volatility_pred.view(volatility_pred.shape[0], -1)) # Flatten instead of squeeze + combined_features = torch.cat([processed_features, regime_probs, vol_pred_flat], dim=1) + combined_features = self._memory_barrier(combined_features) - # Apply temperature scaling for better calibration + trading_logits = self._memory_barrier(self.decision_head(combined_features)) + + # Apply temperature scaling for better calibration - create new tensor temperature = 1.5 - trading_probs = F.softmax(trading_logits / temperature, dim=1) + scaled_logits = trading_logits / temperature + trading_probs = self._memory_barrier(F.softmax(scaled_logits, dim=1)) + + # Flatten confidence to ensure consistent shape + confidence_flat = self._memory_barrier(confidence.view(confidence.shape[0], -1)) + volatility_flat = self._memory_barrier(volatility_pred.view(volatility_pred.shape[0], -1)) return { - 'logits': trading_logits, - 'probabilities': trading_probs, - 'confidence': confidence.squeeze(-1), - 
'regime': regime_probs, - 'volatility': volatility_pred.squeeze(-1), - 'features': processed_features + 'logits': self._memory_barrier(trading_logits), + 'probabilities': self._memory_barrier(trading_probs), + 'confidence': confidence_flat[:, 0] if confidence_flat.shape[1] > 0 else confidence_flat.view(-1)[0], + 'regime': self._memory_barrier(regime_probs), + 'volatility': volatility_flat[:, 0] if volatility_flat.shape[1] > 0 else volatility_flat.view(-1)[0], + 'features': self._memory_barrier(processed_features) } def predict(self, feature_matrix: np.ndarray) -> Dict[str, Any]: @@ -478,60 +517,128 @@ class CNNModelTrainer: self.training_history = [] + def reset_computational_graph(self): + """Reset the computational graph to prevent in-place operation issues""" + try: + # Clear all gradients + for param in self.model.parameters(): + param.grad = None + + # Force garbage collection + import gc + gc.collect() + + # Clear CUDA cache if available + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + + # Reset optimizer state if needed + for group in self.optimizer.param_groups: + for param in group['params']: + if param in self.optimizer.state: + # Clear momentum buffers that might have stale references + self.optimizer.state[param] = {} + + except Exception as e: + logger.warning(f"Error during computational graph reset: {e}") + def train_step(self, x: torch.Tensor, y: torch.Tensor, confidence_targets: Optional[torch.Tensor] = None, regime_targets: Optional[torch.Tensor] = None, volatility_targets: Optional[torch.Tensor] = None) -> Dict[str, float]: - """Single training step with multi-task learning""" + """Single training step with multi-task learning and robust error handling""" - self.model.train() - self.optimizer.zero_grad() + # Reset computational graph before each training step + self.reset_computational_graph() - # Forward pass - outputs = self.model(x) - - # Main trading loss - main_loss = self.main_criterion(outputs['logits'], y) - total_loss = main_loss - - losses = {'main_loss': main_loss.item()} - - # Confidence loss (if targets provided) - if confidence_targets is not None: - conf_loss = self.confidence_criterion(outputs['confidence'], confidence_targets) - total_loss += 0.1 * conf_loss - losses['confidence_loss'] = conf_loss.item() - - # Regime classification loss (if targets provided) - if regime_targets is not None: - regime_loss = self.regime_criterion(outputs['regime'], regime_targets) - total_loss += 0.05 * regime_loss - losses['regime_loss'] = regime_loss.item() - - # Volatility prediction loss (if targets provided) - if volatility_targets is not None: - vol_loss = self.volatility_criterion(outputs['volatility'], volatility_targets) - total_loss += 0.05 * vol_loss - losses['volatility_loss'] = vol_loss.item() - - losses['total_loss'] = total_loss.item() - - # Backward pass - total_loss.backward() - - # Gradient clipping - torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) - - self.optimizer.step() - self.scheduler.step() - - # Calculate accuracy - with torch.no_grad(): - predictions = torch.argmax(outputs['probabilities'], dim=1) - accuracy = (predictions == y).float().mean().item() - losses['accuracy'] = accuracy - - return losses + try: + self.model.train() + + # Ensure inputs are completely independent from original tensors + x_train = x.detach().clone().requires_grad_(False).to(self.device) + y_train = y.detach().clone().requires_grad_(False).to(self.device) + + # Forward pass with error handling + try: + outputs = 
self.model(x_train) + except RuntimeError as forward_error: + if "modified by an inplace operation" in str(forward_error): + logger.error(f"In-place operation in forward pass: {forward_error}") + self.reset_computational_graph() + return {'main_loss': 0.0, 'total_loss': 0.0, 'accuracy': 0.5} + else: + raise forward_error + + # Calculate main loss with detached outputs to prevent memory sharing + main_loss = self.main_criterion(outputs['logits'], y_train) + total_loss = main_loss + + losses = {'main_loss': main_loss.item()} + + # Add auxiliary losses if targets provided + if confidence_targets is not None: + conf_targets = confidence_targets.detach().clone().to(self.device) + conf_loss = self.confidence_criterion(outputs['confidence'], conf_targets) + total_loss = total_loss + 0.1 * conf_loss + losses['confidence_loss'] = conf_loss.item() + + if regime_targets is not None: + regime_targets_clean = regime_targets.detach().clone().to(self.device) + regime_loss = self.regime_criterion(outputs['regime'], regime_targets_clean) + total_loss = total_loss + 0.05 * regime_loss + losses['regime_loss'] = regime_loss.item() + + if volatility_targets is not None: + vol_targets = volatility_targets.detach().clone().to(self.device) + vol_loss = self.volatility_criterion(outputs['volatility'], vol_targets) + total_loss = total_loss + 0.05 * vol_loss + losses['volatility_loss'] = vol_loss.item() + + losses['total_loss'] = total_loss.item() + + # Backward pass with comprehensive error handling + try: + total_loss.backward() + + except RuntimeError as backward_error: + if "modified by an inplace operation" in str(backward_error): + logger.error(f"In-place operation during backward pass: {backward_error}") + logger.error("Attempting to continue training with gradient reset...") + + # Comprehensive cleanup + self.reset_computational_graph() + + return {'main_loss': losses.get('main_loss', 0.0), 'total_loss': losses.get('total_loss', 0.0), 'accuracy': 0.5} + else: + raise backward_error + + # Gradient clipping + torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) + + # Optimizer step + self.optimizer.step() + self.scheduler.step() + + # Calculate accuracy with detached tensors + with torch.no_grad(): + predictions = torch.argmax(outputs['probabilities'], dim=1) + accuracy = (predictions == y_train).float().mean().item() + losses['accuracy'] = accuracy + + return losses + + except Exception as e: + logger.error(f"Training step failed with unexpected error: {e}") + logger.error(f"Error type: {type(e).__name__}") + import traceback + logger.error(f"Full traceback: {traceback.format_exc()}") + + # Comprehensive cleanup on any error + self.reset_computational_graph() + + # Return safe dummy values to continue training + return {'main_loss': 0.0, 'total_loss': 0.0, 'accuracy': 0.5} def save_model(self, filepath: str, metadata: Optional[Dict] = None): """Save model with metadata""" @@ -610,7 +717,7 @@ class CNNModel: feature_dim=input_shape[1], output_size=output_size ) - self.trainer = CNNModelTrainer(self.model, device=self.device) + self.trainer = CNNModelTrainer(self.model, device=str(self.device)) logger.info(f"CNN Model wrapper initialized: input_shape={input_shape}, output_size={output_size}") diff --git a/restart_main_overnight.ps1 b/restart_main_overnight.ps1 new file mode 100644 index 0000000..3ebbf86 --- /dev/null +++ b/restart_main_overnight.ps1 @@ -0,0 +1,90 @@ +# Overnight Training Restart Script (PowerShell) +# Keeps main.py running continuously, restarting it if it crashes. 
+# Usage: .\restart_main_overnight.ps1
+
+Write-Host ("=" * 60)
+Write-Host "OVERNIGHT TRAINING RESTART SCRIPT (PowerShell)"
+Write-Host ("=" * 60)
+Write-Host "Press Ctrl+C to stop the restart loop"
+Write-Host "Main script: main.py"
+Write-Host "Restart delay on crash: 10 seconds"
+Write-Host ("=" * 60)
+
+$restartCount = 0
+$startTime = Get-Date
+
+# Create logs directory if it doesn't exist
+if (!(Test-Path "logs")) {
+    New-Item -ItemType Directory -Path "logs"
+}
+
+# Setup log file
+$timestamp = Get-Date -Format "yyyyMMdd_HHmmss"
+$logFile = "logs\restart_main_ps_$timestamp.log"
+
+function Write-Log {
+    param($Message)
+    $timestamp = Get-Date -Format "yyyy-MM-dd HH:mm:ss"
+    $logMessage = "$timestamp - $Message"
+    Write-Host $logMessage
+    Add-Content -Path $logFile -Value $logMessage
+}
+
+Write-Log "Restart script started, logging to: $logFile"
+
+# Kill any existing Python processes
+try {
+    Get-Process python* -ErrorAction SilentlyContinue | Stop-Process -Force -ErrorAction SilentlyContinue
+    Start-Sleep -Seconds 2
+    Write-Log "Killed existing Python processes"
+} catch {
+    Write-Log "Could not kill existing processes: $_"
+}
+
+try {
+    while ($true) {
+        $restartCount++
+        $runStartTime = Get-Date
+
+        Write-Log "[RESTART #$restartCount] Starting main.py at $(Get-Date -Format 'HH:mm:ss')"
+
+        # Start main.py
+        try {
+            $process = Start-Process -FilePath "python" -ArgumentList "main.py" -PassThru -Wait
+            $exitCode = $process.ExitCode
+            $runEndTime = Get-Date
+            $runDuration = ($runEndTime - $runStartTime).TotalSeconds
+
+            Write-Log "[EXIT] main.py exited with code $exitCode"
+            Write-Log "[DURATION] Process ran for $([math]::Round($runDuration, 1)) seconds"
+
+            # Check for fast exits
+            if ($runDuration -lt 30) {
+                Write-Log "[FAST EXIT] Process exited quickly, waiting 30 seconds..."
+                Start-Sleep -Seconds 30
+            } else {
+                Write-Log "[DELAY] Waiting 10 seconds before restart..."
+                Start-Sleep -Seconds 10
+            }
+
+            # Log stats every 10 restarts
+            if ($restartCount % 10 -eq 0) {
+                $totalDuration = (Get-Date) - $startTime
+                Write-Log "[STATS] Session: $restartCount restarts in $([math]::Round($totalDuration.TotalHours, 1)) hours"
+            }
+
+        } catch {
+            Write-Log "[ERROR] Error starting main.py: $_"
+            Start-Sleep -Seconds 10
+        }
+    }
+} catch {
+    Write-Log "[INTERRUPT] Restart loop interrupted: $_"
+} finally {
+    $totalDuration = (Get-Date) - $startTime
+    Write-Log ("=" * 60)
+    Write-Log "OVERNIGHT TRAINING SESSION COMPLETE"
+    Write-Log "Total restarts: $restartCount"
+    Write-Log "Total session time: $([math]::Round($totalDuration.TotalHours, 1)) hours"
+    Write-Log ("=" * 60)
+}
\ No newline at end of file
diff --git a/restart_main_overnight.py b/restart_main_overnight.py
new file mode 100644
index 0000000..a80c6b3
--- /dev/null
+++ b/restart_main_overnight.py
@@ -0,0 +1,188 @@
+#!/usr/bin/env python3
+"""
+Overnight Training Restart Script
+Keeps main.py running continuously, restarting it if it crashes.
+Designed for overnight training sessions with unstable code.
+
+Usage:
+    python restart_main_overnight.py
+
+Press Ctrl+C to stop the restart loop.
+""" + +import subprocess +import sys +import time +import logging +from datetime import datetime +from pathlib import Path +import signal +import os + +# Setup logging for the restart script +def setup_restart_logging(): + """Setup logging for restart events""" + log_dir = Path("logs") + log_dir.mkdir(exist_ok=True) + + # Create restart log file with timestamp + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + log_file = log_dir / f"restart_main_{timestamp}.log" + + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler(log_file, encoding='utf-8'), + logging.StreamHandler(sys.stdout) + ] + ) + + logger = logging.getLogger(__name__) + logger.info(f"Restart script logging to: {log_file}") + return logger + +def kill_existing_processes(logger): + """Kill any existing main.py processes to avoid conflicts""" + try: + if os.name == 'nt': # Windows + # Kill any existing Python processes running main.py + subprocess.run(['taskkill', '/f', '/im', 'python.exe'], + capture_output=True, check=False) + subprocess.run(['taskkill', '/f', '/im', 'pythonw.exe'], + capture_output=True, check=False) + time.sleep(2) + except Exception as e: + logger.warning(f"Could not kill existing processes: {e}") + +def run_main_with_restart(logger): + """Main restart loop""" + restart_count = 0 + consecutive_fast_exits = 0 + start_time = datetime.now() + + logger.info("=" * 60) + logger.info("OVERNIGHT TRAINING RESTART SCRIPT STARTED") + logger.info("=" * 60) + logger.info("Press Ctrl+C to stop the restart loop") + logger.info("Main script: main.py") + logger.info("Restart delay on crash: 10 seconds") + logger.info("Fast exit protection: Enabled") + logger.info("=" * 60) + + # Kill any existing processes + kill_existing_processes(logger) + + while True: + try: + restart_count += 1 + run_start_time = datetime.now() + + logger.info(f"[RESTART #{restart_count}] Starting main.py at {run_start_time.strftime('%H:%M:%S')}") + + # Start main.py as subprocess + process = subprocess.Popen([ + sys.executable, "main.py" + ], stdout=subprocess.PIPE, stderr=subprocess.STDOUT, + universal_newlines=True, bufsize=1) + + logger.info(f"[PROCESS] main.py started with PID: {process.pid}") + + # Stream output from main.py + try: + if process.stdout: + while True: + output = process.stdout.readline() + if output == '' and process.poll() is not None: + break + if output: + # Forward output from main.py (remove extra newlines) + print(f"[MAIN] {output.rstrip()}") + else: + # If no stdout, just wait for process to complete + process.wait() + except KeyboardInterrupt: + logger.info("[INTERRUPT] Ctrl+C received, stopping main.py...") + process.terminate() + try: + process.wait(timeout=10) + except subprocess.TimeoutExpired: + logger.warning("[FORCE KILL] Process didn't terminate, force killing...") + process.kill() + raise + + # Process has exited + exit_code = process.poll() + run_end_time = datetime.now() + run_duration = (run_end_time - run_start_time).total_seconds() + + logger.info(f"[EXIT] main.py exited with code {exit_code}") + logger.info(f"[DURATION] Process ran for {run_duration:.1f} seconds") + + # Check for fast exits (potential configuration issues) + if run_duration < 30: # Less than 30 seconds + consecutive_fast_exits += 1 + logger.warning(f"[FAST EXIT] Process exited quickly ({consecutive_fast_exits} consecutive)") + + if consecutive_fast_exits >= 5: + logger.error("[ABORT] Too many consecutive fast exits (5+)") + logger.error("This indicates a 
configuration or startup problem") + logger.error("Please check the main.py script manually") + break + + # Longer delay for fast exits + delay = min(60, 10 * consecutive_fast_exits) + logger.info(f"[DELAY] Waiting {delay} seconds before restart due to fast exit...") + time.sleep(delay) + else: + consecutive_fast_exits = 0 # Reset counter + logger.info("[DELAY] Waiting 10 seconds before restart...") + time.sleep(10) + + # Log session statistics every 10 restarts + if restart_count % 10 == 0: + total_duration = (datetime.now() - start_time).total_seconds() + logger.info(f"[STATS] Session: {restart_count} restarts in {total_duration/3600:.1f} hours") + + except KeyboardInterrupt: + logger.info("[SHUTDOWN] Restart loop interrupted by user") + break + except Exception as e: + logger.error(f"[ERROR] Unexpected error in restart loop: {e}") + logger.error("Continuing restart loop after 30 second delay...") + time.sleep(30) + + total_duration = (datetime.now() - start_time).total_seconds() + logger.info("=" * 60) + logger.info("OVERNIGHT TRAINING SESSION COMPLETE") + logger.info(f"Total restarts: {restart_count}") + logger.info(f"Total session time: {total_duration/3600:.1f} hours") + logger.info("=" * 60) + +def main(): + """Main entry point""" + # Setup signal handlers for clean shutdown + def signal_handler(signum, frame): + logger.info(f"[SIGNAL] Received signal {signum}, shutting down...") + sys.exit(0) + + signal.signal(signal.SIGINT, signal_handler) + if hasattr(signal, 'SIGTERM'): + signal.signal(signal.SIGTERM, signal_handler) + + # Setup logging + global logger + logger = setup_restart_logging() + + try: + run_main_with_restart(logger) + except Exception as e: + logger.error(f"[FATAL] Fatal error in restart script: {e}") + import traceback + logger.error(traceback.format_exc()) + return 1 + + return 0 + +if __name__ == "__main__": + sys.exit(main()) \ No newline at end of file
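-- 
Notes on the cnn_model.py changes, with minimal sketches (illustrative code, not applied by this patch):

1) BatchNorm1d -> GroupNorm. BatchNorm1d computes statistics across the batch and raises an error in training mode whenever a channel sees only a single value (e.g. batch size 1 with sequence length 1), which makes it fragile for the small, irregular batches this training loop produces. GroupNorm with one group normalizes each sample independently, so it works at any batch size:

    import torch
    import torch.nn as nn

    x = torch.randn(1, 64, 1)  # batch size 1, 64 channels, sequence length 1

    bn = nn.BatchNorm1d(64).train()
    try:
        bn(x)
    except ValueError as e:
        print(f"BatchNorm1d: {e}")  # "Expected more than 1 value per channel when training..."

    gn = nn.GroupNorm(1, 64)  # one group: statistics are computed per sample
    print(gn(x).shape)        # torch.Size([1, 64, 1]); fine at any batch size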
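2) The _memory_barrier() helper and the detach().clone() residual target autograd's "one of the variables needed for gradient computation has been modified by an inplace operation" error, which fires when a tensor saved for the backward pass is edited in place. A minimal reproduction and the out-of-place fix:

    import torch

    a = torch.randn(4, requires_grad=True)
    b = torch.relu(a)   # ReLU saves its output for the backward pass
    b += 1.0            # in-place edit invalidates the saved tensor
    try:
        b.sum().backward()
    except RuntimeError as e:
        print(f"autograd: {e}")  # "... modified by an inplace operation"

    a = torch.randn(4, requires_grad=True)
    b = torch.relu(a)
    c = b + 1.0         # out-of-place: the saved output stays valid
    c.sum().backward()  # succeeds
    print(a.grad)

One caveat worth knowing: detach() severs the autograd graph, so a barrier built from tensor.detach().clone().requires_grad_(...) stops gradients from reaching anything upstream of it. Out-of-place arithmetic alone, as in the sketch above, avoids the error without that cost.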
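3) Temperature scaling. Dividing logits by a temperature T > 1 before softmax softens overconfident probability distributions without changing the predicted class, which is the rationale for the T = 1.5 calibration step in the forward pass. A toy illustration:

    import torch
    import torch.nn.functional as F

    logits = torch.tensor([[2.0, 0.5, -1.0]])
    print(F.softmax(logits, dim=1))        # ~[0.79, 0.18, 0.04]
    print(F.softmax(logits / 1.5, dim=1))  # ~[0.67, 0.24, 0.09]; argmax unchanged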