fixes around pivot points and BOM matrix

This commit is contained in:
Dobromir Popov
2025-06-24 21:09:35 +03:00
parent 6a4a73ff0b
commit 8685319989
7 changed files with 136 additions and 23 deletions

View File

@ -125,8 +125,16 @@ except ImportError:
outputs = self.model(X)
probs = F.softmax(outputs, dim=1)
pred_class = torch.argmax(probs, dim=1).numpy()
pred_proba = probs.numpy()
# Ensure proper tensor conversion to avoid scalar conversion errors
pred_class = torch.argmax(probs, dim=1).detach().cpu().numpy()
pred_proba = probs.detach().cpu().numpy()
# Handle single batch case - ensure scalars are properly extracted
if pred_class.ndim > 0 and pred_class.size == 1:
pred_class = pred_class.item() # Convert to Python scalar
if pred_proba.ndim > 1 and pred_proba.shape[0] == 1:
pred_proba = pred_proba[0] # Remove batch dimension
logger.debug(f"Fallback CNN prediction: class={pred_class}, proba_shape={pred_proba.shape}")
return pred_class, pred_proba