fixed other CNN references

This commit is contained in:
Dobromir Popov
2025-06-24 21:13:06 +03:00
parent 8685319989
commit e7ea17b626
3 changed files with 49 additions and 17 deletions

View File

@ -431,11 +431,11 @@ class EnhancedCNNModel(nn.Module):
return { return {
'action': action, 'action': action,
'action_name': 'BUY' if action == 0 else 'SELL', 'action_name': 'BUY' if action == 0 else 'SELL',
'confidence': float(confidence), 'confidence': confidence, # Already converted to float above
'action_confidence': action_confidence, 'action_confidence': action_confidence,
'probabilities': probs.tolist(), 'probabilities': probs.tolist(),
'regime_probabilities': regime.tolist(), 'regime_probabilities': regime.tolist(),
'volatility_prediction': float(volatility), 'volatility_prediction': volatility, # Already converted to float above
'raw_logits': outputs['logits'].cpu().numpy()[0].tolist() 'raw_logits': outputs['logits'].cpu().numpy()[0].tolist()
} }

View File

@ -501,7 +501,16 @@ class EnhancedCNNWithOrderBook(nn.Module):
# Get probabilities # Get probabilities
q_values = outputs['q_values'] q_values = outputs['q_values']
probs = F.softmax(q_values, dim=1) probs = F.softmax(q_values, dim=1)
confidence = outputs['confidence'].item()
# Handle confidence shape properly to avoid scalar conversion errors
confidence_tensor = outputs['confidence']
if isinstance(confidence_tensor, torch.Tensor):
if confidence_tensor.numel() == 1:
confidence = confidence_tensor.item()
else:
confidence = confidence_tensor.flatten()[0].item()
else:
confidence = float(confidence_tensor)
# Action selection with confidence thresholding # Action selection with confidence thresholding
if confidence >= self.confidence_threshold: if confidence >= self.confidence_threshold:

View File

@ -1404,19 +1404,8 @@ class EnhancedTradingOrchestrator(TradingOrchestrator):
predictions_array = np.array([float(predictions)], dtype=np.float32) predictions_array = np.array([float(predictions)], dtype=np.float32)
# Create final predictions array with confidence # Create final predictions array with confidence
# Ensure confidence is a scalar value - handle all array shapes safely # Use safe tensor conversion to avoid scalar conversion errors
if isinstance(confidence, np.ndarray): confidence_scalar = self._safe_tensor_to_scalar(confidence, default_value=0.7)
if confidence.ndim == 0:
# 0-dimensional array (scalar)
confidence_scalar = float(confidence.item())
elif confidence.size == 1:
# 1-element array
confidence_scalar = float(confidence.item())
else:
# Multi-element array - take first element or mean
confidence_scalar = float(confidence.flat[0]) # Use flat[0] to safely get first element
else:
confidence_scalar = float(confidence)
# Combine predictions and confidence as separate elements # Combine predictions and confidence as separate elements
predictions = np.concatenate([ predictions = np.concatenate([
@ -4736,4 +4725,38 @@ class EnhancedTradingOrchestrator(TradingOrchestrator):
'price_to_pivot_ratio': 1.0, 'price_to_pivot_ratio': 1.0,
'volume_strength': 1.0, 'volume_strength': 1.0,
'pivot_strength': 0.5 'pivot_strength': 0.5
} }
# Helper function to safely extract scalar values from tensors
def _safe_tensor_to_scalar(self, tensor_value, default_value: float = 0.7) -> float:
"""
Safely convert tensor/array values to Python scalar floats
Args:
tensor_value: Input tensor, array, or scalar value
default_value: Default value to return if conversion fails
Returns:
Python float scalar value
"""
try:
if hasattr(tensor_value, 'item'):
# PyTorch tensor - handle different shapes
if tensor_value.numel() == 1:
return float(tensor_value.item())
else:
return float(tensor_value.flatten()[0].item())
elif isinstance(tensor_value, np.ndarray):
# NumPy array - handle different shapes
if tensor_value.ndim == 0:
return float(tensor_value.item())
elif tensor_value.size == 1:
return float(tensor_value.flatten()[0])
else:
return float(tensor_value.flat[0])
else:
# Already a scalar value
return float(tensor_value)
except Exception as e:
logger.warning(f"Error converting tensor to scalar, using default {default_value}: {e}")
return default_value