fixes around pivot points and BOM matrix
This commit is contained in:
@ -125,8 +125,16 @@ except ImportError:
|
||||
|
||||
outputs = self.model(X)
|
||||
probs = F.softmax(outputs, dim=1)
|
||||
pred_class = torch.argmax(probs, dim=1).numpy()
|
||||
pred_proba = probs.numpy()
|
||||
|
||||
# Ensure proper tensor conversion to avoid scalar conversion errors
|
||||
pred_class = torch.argmax(probs, dim=1).detach().cpu().numpy()
|
||||
pred_proba = probs.detach().cpu().numpy()
|
||||
|
||||
# Handle single batch case - ensure scalars are properly extracted
|
||||
if pred_class.ndim > 0 and pred_class.size == 1:
|
||||
pred_class = pred_class.item() # Convert to Python scalar
|
||||
if pred_proba.ndim > 1 and pred_proba.shape[0] == 1:
|
||||
pred_proba = pred_proba[0] # Remove batch dimension
|
||||
|
||||
logger.debug(f"Fallback CNN prediction: class={pred_class}, proba_shape={pred_proba.shape}")
|
||||
return pred_class, pred_proba
|
||||
|
Reference in New Issue
Block a user