show dummy references
This commit is contained in:
@@ -143,7 +143,7 @@ class EnhancedCNNModel(nn.Module):
|
||||
def __init__(self,
|
||||
input_size: int = 60,
|
||||
feature_dim: int = 50,
|
||||
output_size: int = 2, # BUY/SELL for 2-action system
|
||||
output_size: int = 5, # OHLCV prediction (Open, High, Low, Close, Volume)
|
||||
base_channels: int = 256, # Increased from 128 to 256
|
||||
num_blocks: int = 12, # Increased from 6 to 12
|
||||
num_attention_heads: int = 16, # Increased from 8 to 16
|
||||
@@ -416,39 +416,40 @@ class EnhancedCNNModel(nn.Module):
|
||||
volatility_pred = self._memory_barrier(self.volatility_predictor(processed_features))
|
||||
confidence = self._memory_barrier(self.confidence_head(processed_features))
|
||||
|
||||
# Combine all features for final decision (8 regime classes + 1 volatility)
|
||||
# Combine all features for OHLCV prediction
|
||||
# Create completely independent tensors for concatenation
|
||||
vol_pred_flat = self._memory_barrier(volatility_pred.reshape(volatility_pred.shape[0], -1)) # Flatten instead of squeeze
|
||||
combined_features = torch.cat([processed_features, regime_probs, vol_pred_flat], dim=1)
|
||||
combined_features = self._memory_barrier(combined_features)
|
||||
|
||||
trading_logits = self._memory_barrier(self.decision_head(combined_features))
|
||||
|
||||
# Apply temperature scaling for better calibration - create new tensor
|
||||
temperature = 1.5
|
||||
scaled_logits = trading_logits / temperature
|
||||
trading_probs = self._memory_barrier(F.softmax(scaled_logits, dim=1))
|
||||
|
||||
# Flatten confidence to ensure consistent shape
|
||||
|
||||
# OHLCV prediction (Open, High, Low, Close, Volume)
|
||||
ohlcv_pred = self._memory_barrier(self.decision_head(combined_features))
|
||||
|
||||
# Generate confidence based on prediction stability
|
||||
confidence_flat = self._memory_barrier(confidence.reshape(confidence.shape[0], -1))
|
||||
volatility_flat = self._memory_barrier(volatility_pred.reshape(volatility_pred.shape[0], -1))
|
||||
|
||||
|
||||
# Calculate prediction confidence based on volatility and regime stability
|
||||
regime_stability = torch.std(regime_probs, dim=1, keepdim=True)
|
||||
prediction_confidence = 1.0 / (1.0 + regime_stability + volatility_flat * 0.1)
|
||||
prediction_confidence = self._memory_barrier(prediction_confidence.squeeze(-1))
|
||||
|
||||
return {
|
||||
'logits': self._memory_barrier(trading_logits),
|
||||
'probabilities': self._memory_barrier(trading_probs),
|
||||
'confidence': confidence_flat[:, 0] if confidence_flat.shape[1] > 0 else confidence_flat.reshape(-1)[0],
|
||||
'ohlcv': self._memory_barrier(ohlcv_pred), # [batch_size, 5] - OHLCV predictions
|
||||
'confidence': prediction_confidence,
|
||||
'regime': self._memory_barrier(regime_probs),
|
||||
'volatility': volatility_flat[:, 0] if volatility_flat.shape[1] > 0 else volatility_flat.reshape(-1)[0],
|
||||
'features': self._memory_barrier(processed_features)
|
||||
'features': self._memory_barrier(processed_features),
|
||||
'regime_stability': self._memory_barrier(regime_stability.squeeze(-1))
|
||||
}
|
||||
|
||||
def predict(self, feature_matrix) -> Dict[str, Any]:
|
||||
"""
|
||||
Make predictions on feature matrix
|
||||
Make OHLCV predictions on feature matrix
|
||||
Args:
|
||||
feature_matrix: tensor or numpy array of shape [sequence_length, features]
|
||||
Returns:
|
||||
Dictionary with prediction results
|
||||
Dictionary with OHLCV prediction results and trading signals
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
@@ -468,17 +469,13 @@ class EnhancedCNNModel(nn.Module):
|
||||
# Forward pass
|
||||
outputs = self.forward(x)
|
||||
|
||||
# Extract results with proper shape handling
|
||||
if HAS_NUMPY:
|
||||
probs = outputs['probabilities'].cpu().numpy()[0]
|
||||
confidence_tensor = outputs['confidence'].cpu().numpy()
|
||||
regime = outputs['regime'].cpu().numpy()[0]
|
||||
volatility = outputs['volatility'].cpu().numpy()
|
||||
else:
|
||||
probs = outputs['probabilities'].cpu().tolist()[0]
|
||||
confidence_tensor = outputs['confidence'].cpu().tolist()
|
||||
regime = outputs['regime'].cpu().tolist()[0]
|
||||
volatility = outputs['volatility'].cpu().tolist()
|
||||
# Extract OHLCV predictions
|
||||
ohlcv_pred = outputs['ohlcv'].cpu().numpy()[0] if HAS_NUMPY else outputs['ohlcv'].cpu().tolist()[0]
|
||||
|
||||
# Extract other outputs
|
||||
confidence_tensor = outputs['confidence'].cpu().numpy() if HAS_NUMPY else outputs['confidence'].cpu().tolist()
|
||||
regime = outputs['regime'].cpu().numpy()[0] if HAS_NUMPY else outputs['regime'].cpu().tolist()[0]
|
||||
volatility = outputs['volatility'].cpu().numpy() if HAS_NUMPY else outputs['volatility'].cpu().tolist()
|
||||
|
||||
# Handle confidence shape properly
|
||||
if HAS_NUMPY and isinstance(confidence_tensor, np.ndarray):
|
||||
@@ -490,7 +487,7 @@ class EnhancedCNNModel(nn.Module):
|
||||
confidence = float(confidence_tensor[0] if len(confidence_tensor) > 0 else 0.7)
|
||||
else:
|
||||
confidence = float(confidence_tensor)
|
||||
|
||||
|
||||
# Handle volatility shape properly
|
||||
if HAS_NUMPY and isinstance(volatility, np.ndarray):
|
||||
if volatility.ndim == 0:
|
||||
@@ -502,28 +499,68 @@ class EnhancedCNNModel(nn.Module):
|
||||
else:
|
||||
volatility = float(volatility)
|
||||
|
||||
# Determine action (0=BUY, 1=SELL for 2-action system)
|
||||
if HAS_NUMPY:
|
||||
action = int(np.argmax(probs))
|
||||
else:
|
||||
action = int(torch.argmax(torch.tensor(probs)).item())
|
||||
action_confidence = float(probs[action])
|
||||
# Extract OHLCV values
|
||||
open_price, high_price, low_price, close_price, volume = ohlcv_pred
|
||||
|
||||
# Convert logits to list
|
||||
if HAS_NUMPY:
|
||||
raw_logits = outputs['logits'].cpu().numpy()[0].tolist()
|
||||
else:
|
||||
raw_logits = outputs['logits'].cpu().tolist()[0]
|
||||
# Calculate price movement and direction
|
||||
price_change = close_price - open_price
|
||||
price_change_pct = (price_change / open_price) * 100 if open_price != 0 else 0
|
||||
|
||||
# Calculate candle characteristics
|
||||
body_size = abs(close_price - open_price)
|
||||
upper_wick = high_price - max(open_price, close_price)
|
||||
lower_wick = min(open_price, close_price) - low_price
|
||||
total_range = high_price - low_price
|
||||
|
||||
# Determine trading action based on predicted candle
|
||||
if price_change_pct > 0.1: # Bullish candle (>0.1% gain)
|
||||
action = 0 # BUY
|
||||
action_name = 'BUY'
|
||||
action_confidence = min(0.95, confidence * (1 + abs(price_change_pct) * 10))
|
||||
elif price_change_pct < -0.1: # Bearish candle (<-0.1% loss)
|
||||
action = 1 # SELL
|
||||
action_name = 'SELL'
|
||||
action_confidence = min(0.95, confidence * (1 + abs(price_change_pct) * 10))
|
||||
else: # Sideways/neutral candle
|
||||
# Use body vs wick analysis for weak signals
|
||||
if body_size / total_range > 0.7: # Strong directional body
|
||||
action = 0 if price_change > 0 else 1
|
||||
action_name = 'BUY' if action == 0 else 'SELL'
|
||||
action_confidence = confidence * 0.6 # Reduce confidence for weak signals
|
||||
else:
|
||||
action = 2 # HOLD
|
||||
action_name = 'HOLD'
|
||||
action_confidence = confidence * 0.3 # Very low confidence
|
||||
|
||||
# Adjust confidence based on volatility
|
||||
if volatility > 0.5: # High volatility
|
||||
action_confidence *= 0.8 # Reduce confidence in volatile conditions
|
||||
elif volatility < 0.2: # Low volatility
|
||||
action_confidence *= 1.2 # Increase confidence in stable conditions
|
||||
action_confidence = min(0.95, action_confidence) # Cap at 95%
|
||||
|
||||
return {
|
||||
'action': action,
|
||||
'action_name': 'BUY' if action == 0 else 'SELL',
|
||||
'action_name': action_name,
|
||||
'confidence': float(confidence),
|
||||
'action_confidence': action_confidence,
|
||||
'probabilities': probs if isinstance(probs, list) else probs.tolist(),
|
||||
'ohlcv_prediction': {
|
||||
'open': float(open_price),
|
||||
'high': float(high_price),
|
||||
'low': float(low_price),
|
||||
'close': float(close_price),
|
||||
'volume': float(volume)
|
||||
},
|
||||
'price_change_pct': price_change_pct,
|
||||
'candle_characteristics': {
|
||||
'body_size': body_size,
|
||||
'upper_wick': upper_wick,
|
||||
'lower_wick': lower_wick,
|
||||
'total_range': total_range
|
||||
},
|
||||
'regime_probabilities': regime if isinstance(regime, list) else regime.tolist(),
|
||||
'volatility_prediction': float(volatility),
|
||||
'raw_logits': raw_logits
|
||||
'prediction_quality': 'high' if action_confidence > 0.8 else 'medium' if action_confidence > 0.6 else 'low'
|
||||
}
|
||||
|
||||
def get_memory_usage(self) -> Dict[str, Any]:
|
||||
|
Reference in New Issue
Block a user