gpu issues
This commit is contained in:
@@ -1227,6 +1227,13 @@ class TradingTransformerTrainer:
|
||||
try:
|
||||
self.model.train()
|
||||
|
||||
# Enable anomaly detection temporarily to debug inplace operation issues
|
||||
# NOTE: This significantly slows down training (2-3x slower), use only for debugging
|
||||
# Set to False once the issue is resolved
|
||||
enable_anomaly_detection = False # Set to True to debug gradient issues
|
||||
if enable_anomaly_detection:
|
||||
torch.autograd.set_detect_anomaly(True)
|
||||
|
||||
# GRADIENT ACCUMULATION: Determine if this is an accumulation step
|
||||
# If gradient_accumulation_steps is set, use automatic accumulation
|
||||
# Otherwise, fall back to manual accumulate_gradients flag
|
||||
@@ -1445,8 +1452,25 @@ class TradingTransformerTrainer:
|
||||
else:
|
||||
# Gradient clipping
|
||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
|
||||
# Optimizer step
|
||||
self.optimizer.step()
|
||||
# Optimizer step with error handling
|
||||
try:
|
||||
self.optimizer.step()
|
||||
except (KeyError, RuntimeError) as opt_error:
|
||||
logger.error(f"Optimizer step failed: {opt_error}. Resetting optimizer state.")
|
||||
# Zero gradients first to clear any stale gradients
|
||||
self.optimizer.zero_grad(set_to_none=True)
|
||||
# Reset optimizer to fix corrupted state
|
||||
self.optimizer = torch.optim.AdamW(
|
||||
self.model.parameters(),
|
||||
lr=self.config.learning_rate,
|
||||
weight_decay=self.config.weight_decay
|
||||
)
|
||||
# Zero gradients again after recreating optimizer
|
||||
self.optimizer.zero_grad(set_to_none=True)
|
||||
# Retry optimizer step with fresh state
|
||||
# Note: We need to recompute loss and backward pass, but for now just skip this step
|
||||
logger.warning("Skipping optimizer step after reset - gradients need to be recomputed")
|
||||
# Don't raise - allow training to continue with next batch
|
||||
|
||||
self.scheduler.step()
|
||||
|
||||
@@ -1697,30 +1721,55 @@ class TradingTransformerTrainer:
|
||||
logger.warning(f"Error loading model state dict: {e}, continuing with partial load")
|
||||
|
||||
# Load optimizer state (handle mismatched states gracefully)
|
||||
# IMPORTANT: Always recreate optimizer if there's any issue to avoid corrupted state
|
||||
optimizer_state_loaded = False
|
||||
try:
|
||||
optimizer_state = checkpoint.get('optimizer_state_dict')
|
||||
if optimizer_state:
|
||||
try:
|
||||
# Try to load optimizer state
|
||||
self.optimizer.load_state_dict(optimizer_state)
|
||||
except (KeyError, ValueError, RuntimeError) as e:
|
||||
logger.warning(f"Error loading optimizer state: {e}. Resetting optimizer.")
|
||||
# Recreate optimizer (same pattern as __init__)
|
||||
self.optimizer = torch.optim.AdamW(
|
||||
self.model.parameters(),
|
||||
lr=self.config.learning_rate,
|
||||
weight_decay=self.config.weight_decay
|
||||
)
|
||||
# Validate optimizer state before loading
|
||||
# Check if state dict has the expected structure
|
||||
if 'state' in optimizer_state and 'param_groups' in optimizer_state:
|
||||
# Count parameters in saved state vs current model
|
||||
saved_param_count = len(optimizer_state.get('state', {}))
|
||||
current_param_count = sum(1 for _ in self.model.parameters() if _.requires_grad)
|
||||
|
||||
if saved_param_count == current_param_count:
|
||||
try:
|
||||
# Try to load optimizer state
|
||||
self.optimizer.load_state_dict(optimizer_state)
|
||||
optimizer_state_loaded = True
|
||||
logger.info("Optimizer state loaded successfully")
|
||||
except (KeyError, ValueError, RuntimeError, TypeError) as e:
|
||||
logger.warning(f"Error loading optimizer state: {e}. State will be reset.")
|
||||
optimizer_state_loaded = False
|
||||
else:
|
||||
logger.warning(f"Optimizer state mismatch: {saved_param_count} saved params vs {current_param_count} current params. Resetting optimizer.")
|
||||
optimizer_state_loaded = False
|
||||
else:
|
||||
logger.warning("Invalid optimizer state structure in checkpoint. Resetting optimizer.")
|
||||
optimizer_state_loaded = False
|
||||
else:
|
||||
logger.warning("No optimizer state found in checkpoint. Using fresh optimizer.")
|
||||
logger.info("No optimizer state found in checkpoint. Using fresh optimizer.")
|
||||
optimizer_state_loaded = False
|
||||
except Exception as e:
|
||||
logger.warning(f"Error loading optimizer state: {e}. Resetting optimizer.")
|
||||
# Recreate optimizer (same pattern as __init__)
|
||||
optimizer_state_loaded = False
|
||||
|
||||
# Always recreate optimizer if state loading failed
|
||||
if not optimizer_state_loaded:
|
||||
logger.info("Creating fresh optimizer (checkpoint state was invalid or missing)")
|
||||
self.optimizer = torch.optim.AdamW(
|
||||
self.model.parameters(),
|
||||
lr=self.config.learning_rate,
|
||||
weight_decay=self.config.weight_decay
|
||||
)
|
||||
# Also recreate scheduler to match
|
||||
self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
|
||||
self.optimizer,
|
||||
max_lr=self.config.learning_rate,
|
||||
total_steps=10000,
|
||||
pct_start=0.1
|
||||
)
|
||||
|
||||
# Load scheduler state
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user