trading works!
@@ -221,7 +221,7 @@ class DQNAgent:
         # Check if mixed precision training should be used
         if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and 'DISABLE_MIXED_PRECISION' not in os.environ:
             self.use_mixed_precision = True
-            self.scaler = torch.amp.GradScaler('cuda')
+            self.scaler = torch.cuda.amp.GradScaler()
             logger.info("Mixed precision training enabled")
         else:
             self.use_mixed_precision = False
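For context on how the scaler configured above is consumed, here is a minimal sketch of a mixed-precision training step (illustrative only; the function name train_step and the MSE loss are assumptions, not the repository's actual replay code):

    # Sketch: typical use of a GradScaler in a DQN-style train step (assumed, not repo code)
    import torch
    import torch.nn.functional as F

    def train_step(model, optimizer, scaler, states, targets, use_mixed_precision=True):
        optimizer.zero_grad()
        if use_mixed_precision and scaler is not None:
            # Forward pass under autocast so eligible ops run in half precision
            with torch.cuda.amp.autocast():
                q_values = model(states)
                loss = F.mse_loss(q_values, targets)
            # Scale the loss so fp16 gradients do not underflow
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            q_values = model(states)
            loss = F.mse_loss(q_values, targets)
            loss.backward()
            optimizer.step()
        return loss.item()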
@@ -1083,6 +1083,11 @@ class DQNAgent:
         # Reset gradients
         self.optimizer.zero_grad()

+        # Ensure loss requires gradients before backward pass
+        if not total_loss.requires_grad:
+            logger.warning("Total loss tensor does not require gradients, skipping backward pass")
+            return 0.0
+
         # Backward pass
         total_loss.backward()
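The guard added here matters because a loss built from detached tensors (or computed under torch.no_grad()) has requires_grad == False, and calling backward() on it raises a RuntimeError. A minimal sketch of that failure mode, using made-up tensors rather than the agent's real data:

    # Sketch: when a loss ends up with requires_grad == False (illustrative)
    import torch

    pred = torch.randn(4, requires_grad=True)
    target = torch.randn(4)

    # Detaching (or computing under torch.no_grad()) cuts the graph,
    # so the resulting loss cannot be backpropagated.
    bad_loss = ((pred.detach() - target) ** 2).mean()
    print(bad_loss.requires_grad)   # False -> backward() would raise RuntimeError

    good_loss = ((pred - target) ** 2).mean()
    print(good_loss.requires_grad)  # True -> safe to call backward()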
@@ -1263,6 +1268,11 @@ class DQNAgent:
             # Just use Q-value loss
             loss = q_loss

+            # Ensure loss requires gradients before backward pass
+            if not loss.requires_grad:
+                logger.warning("Loss tensor does not require gradients, skipping backward pass")
+                return 0.0
+
             # Backward pass with scaled gradients
             self.scaler.scale(loss).backward()
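scaler.scale(loss).backward() is only the first half of the GradScaler pattern; a sketch of the steps that usually follow it (the helper name finish_scaled_step and the clipping threshold are assumptions, not code from this commit):

    # Sketch: steps that typically follow a scaled backward pass (assumed, not repo code)
    import torch

    def finish_scaled_step(scaler, optimizer, model, max_grad_norm=1.0):
        # Unscale first so gradient clipping sees true gradient magnitudes
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        # step() skips the update if any gradients are inf/NaN; update() adjusts the scale
        scaler.step(optimizer)
        scaler.update()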