device fix , TZ fix
This commit is contained in:
@ -5851,20 +5851,76 @@ class CleanTradingDashboard:
|
||||
|
||||
logger.debug(f"Base data input created successfully for {symbol}")
|
||||
|
||||
# Make prediction using CNN adapter
|
||||
model_output = self.cnn_adapter.predict(base_data_input)
|
||||
|
||||
# Convert to dictionary for dashboard use
|
||||
prediction = {
|
||||
'action': model_output.predictions.get('action', 'HOLD'),
|
||||
'confidence': model_output.confidence,
|
||||
'buy_probability': model_output.predictions.get('buy_probability', 0.0),
|
||||
'sell_probability': model_output.predictions.get('sell_probability', 0.0),
|
||||
'hold_probability': model_output.predictions.get('hold_probability', 0.0),
|
||||
'timestamp': model_output.timestamp,
|
||||
'hidden_states': model_output.hidden_states,
|
||||
'metadata': model_output.metadata
|
||||
}
|
||||
# Make prediction using CNN model directly (EnhancedCNN uses act method)
|
||||
if hasattr(self.cnn_adapter, 'act'):
|
||||
# Use the act method for EnhancedCNN
|
||||
features = base_data_input.get_feature_vector()
|
||||
|
||||
# Convert to tensor and ensure proper device placement
|
||||
import torch
|
||||
device = next(self.cnn_adapter.parameters()).device
|
||||
features_tensor = torch.tensor(features, dtype=torch.float32, device=device)
|
||||
|
||||
# Ensure batch dimension
|
||||
if features_tensor.dim() == 1:
|
||||
features_tensor = features_tensor.unsqueeze(0)
|
||||
|
||||
# Set model to evaluation mode
|
||||
self.cnn_adapter.eval()
|
||||
|
||||
# Get prediction from CNN model
|
||||
with torch.no_grad():
|
||||
q_values, extrema_pred, price_pred, features_refined, advanced_pred = self.cnn_adapter(features_tensor)
|
||||
|
||||
# Convert to probabilities using softmax
|
||||
action_probs = torch.softmax(q_values, dim=1)
|
||||
action_idx = torch.argmax(action_probs, dim=1).item()
|
||||
confidence = float(action_probs[0, action_idx].item())
|
||||
|
||||
# Map action index to action string
|
||||
actions = ['BUY', 'SELL', 'HOLD']
|
||||
action = actions[action_idx]
|
||||
|
||||
# Create probabilities dictionary
|
||||
probabilities = {
|
||||
'BUY': float(action_probs[0, 0].item()),
|
||||
'SELL': float(action_probs[0, 1].item()),
|
||||
'HOLD': float(action_probs[0, 2].item())
|
||||
}
|
||||
|
||||
# Extract price predictions if available
|
||||
price_prediction = None
|
||||
if price_pred is not None:
|
||||
price_prediction = price_pred.squeeze(0).cpu().numpy().tolist()
|
||||
|
||||
prediction = {
|
||||
'action': action,
|
||||
'confidence': confidence,
|
||||
'buy_probability': probabilities['BUY'],
|
||||
'sell_probability': probabilities['SELL'],
|
||||
'hold_probability': probabilities['HOLD'],
|
||||
'timestamp': datetime.now(),
|
||||
'hidden_states': features_refined.squeeze(0).cpu().numpy().tolist() if features_refined is not None else None,
|
||||
'metadata': {
|
||||
'price_prediction': price_prediction,
|
||||
'extrema_prediction': extrema_pred.squeeze(0).cpu().numpy().tolist() if extrema_pred is not None else None
|
||||
}
|
||||
}
|
||||
else:
|
||||
# Fallback for other CNN models that might have predict method
|
||||
model_output = self.cnn_adapter.predict(base_data_input)
|
||||
|
||||
# Convert to dictionary for dashboard use
|
||||
prediction = {
|
||||
'action': model_output.predictions.get('action', 'HOLD'),
|
||||
'confidence': model_output.confidence,
|
||||
'buy_probability': model_output.predictions.get('buy_probability', 0.0),
|
||||
'sell_probability': model_output.predictions.get('sell_probability', 0.0),
|
||||
'hold_probability': model_output.predictions.get('hold_probability', 0.0),
|
||||
'timestamp': model_output.timestamp,
|
||||
'hidden_states': model_output.hidden_states,
|
||||
'metadata': model_output.metadata
|
||||
}
|
||||
|
||||
logger.debug(f"CNN prediction for {symbol}: {prediction['action']} ({prediction['confidence']:.3f})")
|
||||
return prediction
|
||||
|
Reference in New Issue
Block a user