fix model mappings,dash updates, trading
This commit is contained in:
@ -111,6 +111,9 @@ class SpatialAttentionBlock(nn.Module):
|
||||
# Avoid in-place operation by creating new tensor
|
||||
return torch.mul(x, attention)
|
||||
|
||||
#Todo:
|
||||
#1. Add pivot points array as input
|
||||
#2. change output to be next pivot point (we'll need to adjust training as well)
|
||||
class EnhancedCNNModel(nn.Module):
|
||||
"""
|
||||
Much larger and more sophisticated CNN architecture for trading
|
||||
@ -125,7 +128,7 @@ class EnhancedCNNModel(nn.Module):
|
||||
def __init__(self,
|
||||
input_size: int = 60,
|
||||
feature_dim: int = 50,
|
||||
output_size: int = 2, # BUY/SELL for 2-action system
|
||||
output_size: int = 3, # BUY/SELL/HOLD for 3-action system
|
||||
base_channels: int = 256, # Increased from 128 to 256
|
||||
num_blocks: int = 12, # Increased from 6 to 12
|
||||
num_attention_heads: int = 16, # Increased from 8 to 16
|
||||
@ -479,9 +482,13 @@ class EnhancedCNNModel(nn.Module):
|
||||
action = int(np.argmax(probs))
|
||||
action_confidence = float(probs[action])
|
||||
|
||||
# FIXED ACTION MAPPING: 0=BUY, 1=SELL, 2=HOLD
|
||||
action_names = ['BUY', 'SELL', 'HOLD']
|
||||
action_name = action_names[action] if action < len(action_names) else 'HOLD'
|
||||
|
||||
return {
|
||||
'action': action,
|
||||
'action_name': 'BUY' if action == 0 else 'SELL',
|
||||
'action_name': action_name,
|
||||
'confidence': float(confidence),
|
||||
'action_confidence': action_confidence,
|
||||
'probabilities': probs.tolist(),
|
||||
@ -965,21 +972,21 @@ class CNNModel:
|
||||
if len(trend_data) > 1:
|
||||
trend = (trend_data[-1] - trend_data[0]) / trend_data[0] if trend_data[0] != 0 else 0
|
||||
|
||||
# Map trend to action
|
||||
# Map trend to action - FIXED ACTION MAPPING: 0=BUY, 1=SELL
|
||||
if trend > 0.001: # Upward trend > 0.1%
|
||||
action = 1 # BUY
|
||||
action = 0 # BUY (action 0)
|
||||
confidence = min(0.9, 0.5 + abs(trend) * 10)
|
||||
elif trend < -0.001: # Downward trend < -0.1%
|
||||
action = 0 # SELL
|
||||
action = 1 # SELL (action 1)
|
||||
confidence = min(0.9, 0.5 + abs(trend) * 10)
|
||||
else:
|
||||
action = 0 # Default to SELL for unclear trend
|
||||
action = 2 # Default to HOLD for unclear trend
|
||||
confidence = 0.3
|
||||
else:
|
||||
action = 0
|
||||
action = 2 # HOLD for unknown trend
|
||||
confidence = 0.3
|
||||
else:
|
||||
action = 0
|
||||
action = 2 # HOLD for insufficient data
|
||||
confidence = 0.3
|
||||
|
||||
# Create probabilities
|
||||
@ -1000,7 +1007,7 @@ class CNNModel:
|
||||
except Exception as e:
|
||||
logger.error(f"Error in fallback prediction: {e}")
|
||||
# Final fallback - conservative prediction
|
||||
pred_class = np.array([0]) # SELL
|
||||
pred_class = np.array([2]) # HOLD (safe default)
|
||||
proba = np.ones(self.output_size) / self.output_size # Equal probabilities
|
||||
pred_proba = np.array([proba])
|
||||
return pred_class, pred_proba
|
||||
|
Reference in New Issue
Block a user