device tensor fix

This commit is contained in:
Dobromir Popov
2025-07-25 13:59:33 +03:00
parent 78b4bb0f06
commit 1f60c80d67
6 changed files with 495 additions and 45 deletions

View File

@ -70,6 +70,9 @@ class EnhancedCNNAdapter:
else:
self._load_best_checkpoint()
# Final device check and move
self._ensure_model_on_device()
logger.info(f"EnhancedCNNAdapter initialized on {self.device}")
def _initialize_model(self):
@ -88,9 +91,10 @@ class EnhancedCNNAdapter:
# Create model
self.model = EnhancedCNN(input_shape=input_shape, n_actions=n_actions)
# Ensure model is moved to the correct device
self.model.to(self.device)
logger.info(f"EnhancedCNN model initialized with input_shape={input_shape}, n_actions={n_actions}")
logger.info(f"EnhancedCNN model initialized with input_shape={input_shape}, n_actions={n_actions} on device {self.device}")
except Exception as e:
logger.error(f"Error initializing EnhancedCNN model: {e}")
@ -102,7 +106,9 @@ class EnhancedCNNAdapter:
if self.model and os.path.exists(checkpoint_path):
success = self.model.load(checkpoint_path)
if success:
logger.info(f"Loaded model from {checkpoint_path}")
# Ensure model is moved to the correct device after loading
self.model.to(self.device)
logger.info(f"Loaded model from {checkpoint_path} and moved to {self.device}")
return True
else:
logger.warning(f"Failed to load model from {checkpoint_path}")
@ -146,7 +152,9 @@ class EnhancedCNNAdapter:
success = self.model.load(best_checkpoint_path)
if success:
logger.info(f"Loaded best checkpoint from {best_checkpoint_path}")
# Ensure model is moved to the correct device after loading
self.model.to(self.device)
logger.info(f"Loaded best checkpoint from {best_checkpoint_path} and moved to {self.device}")
# Log metrics
metrics = best_checkpoint_metadata.get('metrics', {})
@ -161,7 +169,17 @@ class EnhancedCNNAdapter:
logger.error(f"Error loading best checkpoint: {e}")
return False
def _ensure_model_on_device(self):
"""Move the model to ``self.device`` and sync its ``device`` attribute.

Calls ``self.model.to(self.device)`` and, when the model exposes a
``device`` attribute, overwrites it with ``self.device``. Any exception
raised along the way is reported through ``logger.error`` and is not
re-raised.
"""
try:
# Nothing is moved when ``self.model`` is falsy (e.g. not created yet).
if self.model:
self.model.to(self.device)
# Keep the model's own ``device`` attribute, when it has one, in step
# with the adapter's device.
if hasattr(self.model, 'device'):
self.model.device = self.device
logger.debug(f"Model ensured on device {self.device}")
except Exception as e:
# Logged only; the exception is not propagated to the caller.
logger.error(f"Error ensuring model on device: {e}")
def _create_default_output(self, symbol: str) -> ModelOutput:
"""Create default output when prediction fails"""
@ -235,6 +253,9 @@ class EnhancedCNNAdapter:
if features.dim() == 1:
features = features.unsqueeze(0)
# Ensure model is on correct device before prediction
self._ensure_model_on_device()
# Set model to evaluation mode
self.model.eval()
@ -399,6 +420,9 @@ class EnhancedCNNAdapter:
logger.info(f"Not enough training data: {len(self.training_data)} samples, need at least {self.batch_size}")
return {'loss': 0.0, 'accuracy': 0.0, 'samples': len(self.training_data)}
# Ensure model is on correct device before training
self._ensure_model_on_device()
# Set model to training mode
self.model.train()
@ -423,8 +447,8 @@ class EnhancedCNNAdapter:
if len(batch) < 2:
continue
# Prepare batch
features = torch.stack([sample[0] for sample in batch])
# Prepare batch - ensure all tensors are on the correct device
features = torch.stack([sample[0].to(self.device) for sample in batch])
actions = torch.tensor([sample[1] for sample in batch], dtype=torch.long, device=self.device)
rewards = torch.tensor([sample[2] for sample in batch], dtype=torch.float32, device=self.device)