device tensor fix
This commit is contained in:
@ -70,6 +70,9 @@ class EnhancedCNNAdapter:
|
||||
else:
|
||||
self._load_best_checkpoint()
|
||||
|
||||
# Final device check and move
|
||||
self._ensure_model_on_device()
|
||||
|
||||
logger.info(f"EnhancedCNNAdapter initialized on {self.device}")
|
||||
|
||||
def _initialize_model(self):
|
||||
@ -88,9 +91,10 @@ class EnhancedCNNAdapter:
|
||||
|
||||
# Create model
|
||||
self.model = EnhancedCNN(input_shape=input_shape, n_actions=n_actions)
|
||||
# Ensure model is moved to the correct device
|
||||
self.model.to(self.device)
|
||||
|
||||
logger.info(f"EnhancedCNN model initialized with input_shape={input_shape}, n_actions={n_actions}")
|
||||
logger.info(f"EnhancedCNN model initialized with input_shape={input_shape}, n_actions={n_actions} on device {self.device}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing EnhancedCNN model: {e}")
|
||||
@ -102,7 +106,9 @@ class EnhancedCNNAdapter:
|
||||
if self.model and os.path.exists(checkpoint_path):
|
||||
success = self.model.load(checkpoint_path)
|
||||
if success:
|
||||
logger.info(f"Loaded model from {checkpoint_path}")
|
||||
# Ensure model is moved to the correct device after loading
|
||||
self.model.to(self.device)
|
||||
logger.info(f"Loaded model from {checkpoint_path} and moved to {self.device}")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"Failed to load model from {checkpoint_path}")
|
||||
@ -146,7 +152,9 @@ class EnhancedCNNAdapter:
|
||||
success = self.model.load(best_checkpoint_path)
|
||||
|
||||
if success:
|
||||
logger.info(f"Loaded best checkpoint from {best_checkpoint_path}")
|
||||
# Ensure model is moved to the correct device after loading
|
||||
self.model.to(self.device)
|
||||
logger.info(f"Loaded best checkpoint from {best_checkpoint_path} and moved to {self.device}")
|
||||
|
||||
# Log metrics
|
||||
metrics = best_checkpoint_metadata.get('metrics', {})
|
||||
@ -161,7 +169,17 @@ class EnhancedCNNAdapter:
|
||||
logger.error(f"Error loading best checkpoint: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _ensure_model_on_device(self):
|
||||
"""Ensure model and all its components are on the correct device"""
|
||||
try:
|
||||
if self.model:
|
||||
self.model.to(self.device)
|
||||
# Also ensure the model's internal device is set correctly
|
||||
if hasattr(self.model, 'device'):
|
||||
self.model.device = self.device
|
||||
logger.debug(f"Model ensured on device {self.device}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error ensuring model on device: {e}")
|
||||
|
||||
def _create_default_output(self, symbol: str) -> ModelOutput:
|
||||
"""Create default output when prediction fails"""
|
||||
@ -235,6 +253,9 @@ class EnhancedCNNAdapter:
|
||||
if features.dim() == 1:
|
||||
features = features.unsqueeze(0)
|
||||
|
||||
# Ensure model is on correct device before prediction
|
||||
self._ensure_model_on_device()
|
||||
|
||||
# Set model to evaluation mode
|
||||
self.model.eval()
|
||||
|
||||
@ -399,6 +420,9 @@ class EnhancedCNNAdapter:
|
||||
logger.info(f"Not enough training data: {len(self.training_data)} samples, need at least {self.batch_size}")
|
||||
return {'loss': 0.0, 'accuracy': 0.0, 'samples': len(self.training_data)}
|
||||
|
||||
# Ensure model is on correct device before training
|
||||
self._ensure_model_on_device()
|
||||
|
||||
# Set model to training mode
|
||||
self.model.train()
|
||||
|
||||
@ -423,8 +447,8 @@ class EnhancedCNNAdapter:
|
||||
if len(batch) < 2:
|
||||
continue
|
||||
|
||||
# Prepare batch
|
||||
features = torch.stack([sample[0] for sample in batch])
|
||||
# Prepare batch - ensure all tensors are on the correct device
|
||||
features = torch.stack([sample[0].to(self.device) for sample in batch])
|
||||
actions = torch.tensor([sample[1] for sample in batch], dtype=torch.long, device=self.device)
|
||||
rewards = torch.tensor([sample[2] for sample in batch], dtype=torch.float32, device=self.device)
|
||||
|
||||
|
Reference in New Issue
Block a user