improved model loading and training

This commit is contained in:
Dobromir Popov
2025-10-31 01:22:49 +02:00
parent 7ddf98bf18
commit ba91740e4c
7 changed files with 745 additions and 186 deletions

View File

@@ -407,7 +407,7 @@ class DQNAgent:
# Check if mixed precision training should be used
if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and 'DISABLE_MIXED_PRECISION' not in os.environ:
self.use_mixed_precision = True
self.scaler = torch.cuda.amp.GradScaler()
self.scaler = torch.amp.GradScaler('cuda')
logger.info("Mixed precision training enabled")
else:
self.use_mixed_precision = False
@@ -577,7 +577,7 @@ class DQNAgent:
# Check if mixed precision training should be used
if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and 'DISABLE_MIXED_PRECISION' not in os.environ:
self.use_mixed_precision = True
self.scaler = torch.cuda.amp.GradScaler()
self.scaler = torch.amp.GradScaler('cuda')
logger.info("Mixed precision training enabled")
else:
self.use_mixed_precision = False