GPU issues

Dobromir Popov
2025-11-23 02:26:43 +02:00
parent 53ce4a355a
commit 1aded7325f
4 changed files with 69 additions and 31 deletions


@@ -1227,6 +1227,13 @@ class TradingTransformerTrainer:
try:
    self.model.train()

    # Enable anomaly detection temporarily to debug inplace operation issues
    # NOTE: this slows training significantly (2-3x); use only for debugging
    # and set back to False once the issue is resolved
    enable_anomaly_detection = False  # Set to True to debug gradient issues
    if enable_anomaly_detection:
        torch.autograd.set_detect_anomaly(True)

    # GRADIENT ACCUMULATION: determine whether this is an accumulation step.
    # If gradient_accumulation_steps is set, use automatic accumulation;
    # otherwise, fall back to the manual accumulate_gradients flag.
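The accumulation comment above describes the standard PyTorch pattern of scaling each micro-batch loss and stepping the optimizer once per window. A minimal, self-contained sketch of that idea (names such as `accumulation_steps`, `micro_batches`, and `loss_fn` are illustrative, not the trainer's actual attributes):

import torch

def train_accumulated(model, optimizer, loss_fn, micro_batches, accumulation_steps=4):
    """Accumulate gradients over several micro-batches, then take one optimizer step."""
    model.train()
    optimizer.zero_grad(set_to_none=True)
    for i, (inputs, targets) in enumerate(micro_batches):
        # Scale the loss so the accumulated gradient matches a full-batch step
        loss = loss_fn(model(inputs), targets) / accumulation_steps
        loss.backward()
        if (i + 1) % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)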
@@ -1445,8 +1452,25 @@ class TradingTransformerTrainer:
else:
    # Gradient clipping
    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)

    # Optimizer step
    self.optimizer.step()

    # Optimizer step with error handling
    try:
        self.optimizer.step()
    except (KeyError, RuntimeError) as opt_error:
        logger.error(f"Optimizer step failed: {opt_error}. Resetting optimizer state.")
        # Zero gradients first to clear any stale gradients
        self.optimizer.zero_grad(set_to_none=True)
        # Reset optimizer to fix corrupted state
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay
        )
        # Zero gradients again after recreating optimizer
        self.optimizer.zero_grad(set_to_none=True)
        # Rebuild the scheduler so it points at the new optimizer; otherwise
        # scheduler.step() keeps updating the discarded one (mirrors the
        # checkpoint-loading path below)
        self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
            self.optimizer,
            max_lr=self.config.learning_rate,
            total_steps=10000,
            pct_start=0.1
        )
        # Retrying here would require recomputing the loss and backward pass,
        # so this step is skipped and training continues with the next batch
        logger.warning("Skipping optimizer step after reset - gradients need to be recomputed")
        # Don't raise - allow training to continue with next batch

    self.scheduler.step()
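The comment in the handler notes that a genuine retry would require recomputing the loss and backward pass, since the stale gradients are zeroed along with the optimizer state. A rough, self-contained sketch of what such a retry could look like (`optimizer_factory` and `forward_backward` are stand-in callables, not existing methods of the trainer):

import torch

def step_with_recovery(model, optimizer, optimizer_factory, forward_backward, max_grad_norm=1.0):
    """Attempt optimizer.step(); on a corrupted-state error, rebuild the optimizer,
    recompute gradients, and retry once."""
    loss = forward_backward()  # forward pass + loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    try:
        optimizer.step()
    except (KeyError, RuntimeError):
        optimizer = optimizer_factory(model.parameters())  # fresh optimizer over the same parameters
        optimizer.zero_grad(set_to_none=True)               # drop any stale or corrupted gradients
        loss = forward_backward()                           # recompute gradients before stepping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
    return optimizer, loss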
@@ -1697,30 +1721,55 @@ class TradingTransformerTrainer:
logger.warning(f"Error loading model state dict: {e}, continuing with partial load")
# Load optimizer state (handle mismatched states gracefully)
# IMPORTANT: Always recreate optimizer if there's any issue to avoid corrupted state
optimizer_state_loaded = False
try:
optimizer_state = checkpoint.get('optimizer_state_dict')
if optimizer_state:
try:
# Try to load optimizer state
self.optimizer.load_state_dict(optimizer_state)
except (KeyError, ValueError, RuntimeError) as e:
logger.warning(f"Error loading optimizer state: {e}. Resetting optimizer.")
# Recreate optimizer (same pattern as __init__)
self.optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=self.config.learning_rate,
weight_decay=self.config.weight_decay
)
# Validate optimizer state before loading
# Check if state dict has the expected structure
if 'state' in optimizer_state and 'param_groups' in optimizer_state:
# Count parameters in saved state vs current model
saved_param_count = len(optimizer_state.get('state', {}))
current_param_count = sum(1 for _ in self.model.parameters() if _.requires_grad)
if saved_param_count == current_param_count:
try:
# Try to load optimizer state
self.optimizer.load_state_dict(optimizer_state)
optimizer_state_loaded = True
logger.info("Optimizer state loaded successfully")
except (KeyError, ValueError, RuntimeError, TypeError) as e:
logger.warning(f"Error loading optimizer state: {e}. State will be reset.")
optimizer_state_loaded = False
else:
logger.warning(f"Optimizer state mismatch: {saved_param_count} saved params vs {current_param_count} current params. Resetting optimizer.")
optimizer_state_loaded = False
else:
logger.warning("Invalid optimizer state structure in checkpoint. Resetting optimizer.")
optimizer_state_loaded = False
else:
logger.warning("No optimizer state found in checkpoint. Using fresh optimizer.")
logger.info("No optimizer state found in checkpoint. Using fresh optimizer.")
optimizer_state_loaded = False
except Exception as e:
logger.warning(f"Error loading optimizer state: {e}. Resetting optimizer.")
# Recreate optimizer (same pattern as __init__)
optimizer_state_loaded = False
# Always recreate optimizer if state loading failed
if not optimizer_state_loaded:
logger.info("Creating fresh optimizer (checkpoint state was invalid or missing)")
self.optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=self.config.learning_rate,
weight_decay=self.config.weight_decay
)
# Also recreate scheduler to match
self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
self.optimizer,
max_lr=self.config.learning_rate,
total_steps=10000,
pct_start=0.1
)
# Load scheduler state
try:
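The load path above only confirms that the checkpoint carries an 'optimizer_state_dict' entry; a matching save-side sketch is shown here for context. The other key names and the save_checkpoint helper are assumptions for illustration, not the project's actual API:

import torch

def save_checkpoint(path, model, optimizer, scheduler):
    """Write a checkpoint whose keys the load path above can consume."""
    torch.save({
        'model_state_dict': model.state_dict(),          # assumed key name
        'optimizer_state_dict': optimizer.state_dict(),  # key read by the loader above
        'scheduler_state_dict': scheduler.state_dict(),  # assumed key name
    }, path)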