merge training system
This commit is contained in:
@@ -78,6 +78,11 @@ class StandardizedCNN(nn.Module):
|
||||
# Device management
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
self.to(self.device)
|
||||
try:
|
||||
import torch.backends.cudnn as cudnn
|
||||
cudnn.benchmark = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info(f"StandardizedCNN '{model_name}' initialized")
|
||||
logger.info(f"Expected feature dimension: {self.expected_feature_dim}")
|
||||
|
Reference in New Issue
Block a user