display predictions
This commit is contained in:
@ -71,6 +71,34 @@ class EnhancedRealtimeTrainingSystem:
|
||||
'validation': 0.0
|
||||
}
|
||||
|
||||
# Model prediction tracking - NEW for dashboard visualization
|
||||
self.recent_dqn_predictions = {
|
||||
'ETH/USDT': deque(maxlen=100),
|
||||
'BTC/USDT': deque(maxlen=100)
|
||||
}
|
||||
self.recent_cnn_predictions = {
|
||||
'ETH/USDT': deque(maxlen=50),
|
||||
'BTC/USDT': deque(maxlen=50)
|
||||
}
|
||||
self.prediction_accuracy_history = {
|
||||
'ETH/USDT': deque(maxlen=200),
|
||||
'BTC/USDT': deque(maxlen=200)
|
||||
}
|
||||
|
||||
# FIXED: Forward-looking prediction system
|
||||
self.pending_predictions = {
|
||||
'ETH/USDT': deque(maxlen=100), # Predictions waiting for validation
|
||||
'BTC/USDT': deque(maxlen=100)
|
||||
}
|
||||
self.last_prediction_time = {
|
||||
'ETH/USDT': 0,
|
||||
'BTC/USDT': 0
|
||||
}
|
||||
self.prediction_intervals = {
|
||||
'dqn': 30, # Make DQN prediction every 30 seconds
|
||||
'cnn': 60 # Make CNN prediction every 60 seconds
|
||||
}
|
||||
|
||||
# Real-time data streams
|
||||
self.real_time_data = {
|
||||
'ticks': deque(maxlen=1000),
|
||||
@ -146,24 +174,27 @@ class EnhancedRealtimeTrainingSystem:
|
||||
current_time = time.time()
|
||||
self.training_iteration += 1
|
||||
|
||||
# 1. DQN Training (every 5 seconds with enough data)
|
||||
# 1. FORWARD-LOOKING PREDICTIONS - Generate real predictions for future validation
|
||||
self.generate_forward_looking_predictions()
|
||||
|
||||
# 2. DQN Training (every 5 seconds with enough data)
|
||||
if (current_time - self.last_training_times['dqn'] > self.training_config['dqn_training_interval']
|
||||
and len(self.experience_buffer) >= self.training_config['min_training_samples']):
|
||||
self._perform_enhanced_dqn_training()
|
||||
self.last_training_times['dqn'] = current_time
|
||||
|
||||
# 2. CNN Training (every 10 seconds)
|
||||
# 3. CNN Training (every 10 seconds)
|
||||
if (current_time - self.last_training_times['cnn'] > self.training_config['cnn_training_interval']
|
||||
and len(self.real_time_data['ohlcv_1m']) >= 20):
|
||||
self._perform_enhanced_cnn_training()
|
||||
self.last_training_times['cnn'] = current_time
|
||||
|
||||
# 3. Validation (every minute)
|
||||
# 4. Validation (every minute)
|
||||
if current_time - self.last_training_times['validation'] > self.training_config['validation_interval']:
|
||||
self._perform_validation()
|
||||
self.last_training_times['validation'] = current_time
|
||||
|
||||
# 4. Adaptive learning rate adjustment
|
||||
# 5. Adaptive learning rate adjustment
|
||||
if self.training_iteration % 100 == 0:
|
||||
self._adapt_learning_parameters()
|
||||
|
||||
@ -911,6 +942,11 @@ class EnhancedRealtimeTrainingSystem:
|
||||
'dqn_loss_count': len(self.performance_history['dqn_losses']),
|
||||
'cnn_loss_count': len(self.performance_history['cnn_losses']),
|
||||
'validation_count': len(self.performance_history['validation_scores'])
|
||||
},
|
||||
'prediction_stats': {
|
||||
'dqn_predictions': {symbol: len(predictions) for symbol, predictions in self.recent_dqn_predictions.items()},
|
||||
'cnn_predictions': {symbol: len(predictions) for symbol, predictions in self.recent_cnn_predictions.items()},
|
||||
'accuracy_history': {symbol: len(history) for symbol, history in self.prediction_accuracy_history.items()}
|
||||
}
|
||||
}
|
||||
|
||||
@ -927,4 +963,492 @@ class EnhancedRealtimeTrainingSystem:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting training statistics: {e}")
|
||||
return {'error': str(e)}
|
||||
return {'error': str(e)}
|
||||
|
||||
def capture_dqn_prediction(self, symbol: str, state: np.ndarray, q_values: List[float], action: int, confidence: float, price: float):
|
||||
"""Capture DQN prediction for dashboard visualization"""
|
||||
try:
|
||||
prediction = {
|
||||
'timestamp': datetime.now(),
|
||||
'symbol': symbol,
|
||||
'state': state.tolist() if hasattr(state, 'tolist') else state,
|
||||
'q_values': q_values,
|
||||
'action': action, # 0=BUY, 1=SELL, 2=HOLD
|
||||
'confidence': confidence,
|
||||
'price': price
|
||||
}
|
||||
|
||||
if symbol in self.recent_dqn_predictions:
|
||||
self.recent_dqn_predictions[symbol].append(prediction)
|
||||
|
||||
logger.debug(f"DQN prediction captured: {symbol} action={action} confidence={confidence:.2f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error capturing DQN prediction: {e}")
|
||||
|
||||
def capture_cnn_prediction(self, symbol: str, current_price: float, predicted_price: float, direction: int, confidence: float, features: Optional[np.ndarray] = None):
|
||||
"""Capture CNN prediction for dashboard visualization"""
|
||||
try:
|
||||
prediction = {
|
||||
'timestamp': datetime.now(),
|
||||
'symbol': symbol,
|
||||
'current_price': current_price,
|
||||
'predicted_price': predicted_price,
|
||||
'direction': direction, # 0=DOWN, 1=SAME, 2=UP
|
||||
'confidence': confidence,
|
||||
'features': features.tolist() if features is not None and hasattr(features, 'tolist') else None
|
||||
}
|
||||
|
||||
if symbol in self.recent_cnn_predictions:
|
||||
self.recent_cnn_predictions[symbol].append(prediction)
|
||||
|
||||
logger.debug(f"CNN prediction captured: {symbol} direction={direction} confidence={confidence:.2f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error capturing CNN prediction: {e}")
|
||||
|
||||
def validate_prediction_accuracy(self, symbol: str, prediction_type: str, predicted_action: int, actual_price_change: float, confidence: float):
|
||||
"""Validate prediction accuracy and store results"""
|
||||
try:
|
||||
# Determine if prediction was correct
|
||||
was_correct = False
|
||||
|
||||
if prediction_type == 'DQN':
|
||||
# For DQN: BUY (0) should be followed by price increase, SELL (1) by decrease
|
||||
if predicted_action == 0 and actual_price_change > 0.001: # BUY + price up
|
||||
was_correct = True
|
||||
elif predicted_action == 1 and actual_price_change < -0.001: # SELL + price down
|
||||
was_correct = True
|
||||
elif predicted_action == 2 and abs(actual_price_change) <= 0.001: # HOLD + no change
|
||||
was_correct = True
|
||||
|
||||
elif prediction_type == 'CNN':
|
||||
# For CNN: direction prediction accuracy
|
||||
if predicted_action == 2 and actual_price_change > 0.001: # UP + price up
|
||||
was_correct = True
|
||||
elif predicted_action == 0 and actual_price_change < -0.001: # DOWN + price down
|
||||
was_correct = True
|
||||
elif predicted_action == 1 and abs(actual_price_change) <= 0.001: # SAME + no change
|
||||
was_correct = True
|
||||
|
||||
# Calculate accuracy score based on confidence and correctness
|
||||
accuracy_score = confidence if was_correct else (1.0 - confidence)
|
||||
|
||||
accuracy_data = {
|
||||
'timestamp': datetime.now(),
|
||||
'symbol': symbol,
|
||||
'prediction_type': prediction_type,
|
||||
'correct': was_correct,
|
||||
'accuracy_score': accuracy_score,
|
||||
'confidence': confidence,
|
||||
'actual_price_change': actual_price_change,
|
||||
'predicted_action': predicted_action
|
||||
}
|
||||
|
||||
if symbol in self.prediction_accuracy_history:
|
||||
self.prediction_accuracy_history[symbol].append(accuracy_data)
|
||||
|
||||
logger.debug(f"Prediction accuracy validated: {symbol} {prediction_type} correct={was_correct} score={accuracy_score:.2f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error validating prediction accuracy: {e}")
|
||||
|
||||
def get_prediction_summary(self, symbol: str) -> Dict[str, Any]:
|
||||
"""Get prediction summary for a symbol"""
|
||||
try:
|
||||
summary = {
|
||||
'symbol': symbol,
|
||||
'dqn_predictions': len(self.recent_dqn_predictions.get(symbol, [])),
|
||||
'cnn_predictions': len(self.recent_cnn_predictions.get(symbol, [])),
|
||||
'accuracy_history': len(self.prediction_accuracy_history.get(symbol, [])),
|
||||
'pending_predictions': len(self.pending_predictions.get(symbol, []))
|
||||
}
|
||||
|
||||
# Calculate accuracy statistics
|
||||
if symbol in self.prediction_accuracy_history and self.prediction_accuracy_history[symbol]:
|
||||
accuracy_data = list(self.prediction_accuracy_history[symbol])
|
||||
|
||||
total_predictions = len(accuracy_data)
|
||||
correct_predictions = sum(1 for acc in accuracy_data if acc['correct'])
|
||||
|
||||
summary['total_predictions'] = total_predictions
|
||||
summary['correct_predictions'] = correct_predictions
|
||||
summary['accuracy_rate'] = correct_predictions / total_predictions if total_predictions > 0 else 0.0
|
||||
|
||||
# Calculate accuracy by prediction type
|
||||
dqn_accuracy_data = [acc for acc in accuracy_data if acc['prediction_type'] == 'DQN']
|
||||
cnn_accuracy_data = [acc for acc in accuracy_data if acc['prediction_type'] == 'CNN']
|
||||
|
||||
if dqn_accuracy_data:
|
||||
dqn_correct = sum(1 for acc in dqn_accuracy_data if acc['correct'])
|
||||
summary['dqn_accuracy_rate'] = dqn_correct / len(dqn_accuracy_data)
|
||||
else:
|
||||
summary['dqn_accuracy_rate'] = 0.0
|
||||
|
||||
if cnn_accuracy_data:
|
||||
cnn_correct = sum(1 for acc in cnn_accuracy_data if acc['correct'])
|
||||
summary['cnn_accuracy_rate'] = cnn_correct / len(cnn_accuracy_data)
|
||||
else:
|
||||
summary['cnn_accuracy_rate'] = 0.0
|
||||
|
||||
return summary
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting prediction summary: {e}")
|
||||
return {'error': str(e)}
|
||||
|
||||
def generate_forward_looking_predictions(self):
|
||||
"""Generate forward-looking predictions based on current market data"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
|
||||
for symbol in ['ETH/USDT', 'BTC/USDT']:
|
||||
# Check if it's time to make new predictions
|
||||
time_since_last = current_time - self.last_prediction_time.get(symbol, 0)
|
||||
|
||||
# Generate DQN prediction every 30 seconds
|
||||
if time_since_last >= self.prediction_intervals['dqn']:
|
||||
self._generate_forward_dqn_prediction(symbol, current_time)
|
||||
|
||||
# Generate CNN prediction every 60 seconds
|
||||
if time_since_last >= self.prediction_intervals['cnn']:
|
||||
self._generate_forward_cnn_prediction(symbol, current_time)
|
||||
|
||||
# Validate pending predictions
|
||||
self._validate_pending_predictions(symbol, current_time)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating forward-looking predictions: {e}")
|
||||
|
||||
def _generate_forward_dqn_prediction(self, symbol: str, current_time: float):
|
||||
"""Generate a DQN prediction for future price movement"""
|
||||
try:
|
||||
# Get current market state (only historical data)
|
||||
current_state = self._build_comprehensive_state()
|
||||
current_price = self._get_current_price_from_data(symbol)
|
||||
|
||||
if current_price is None:
|
||||
return
|
||||
|
||||
# Use DQN model to predict action (if available)
|
||||
if (self.orchestrator and hasattr(self.orchestrator, 'rl_agent')
|
||||
and self.orchestrator.rl_agent):
|
||||
|
||||
# Get Q-values from model
|
||||
q_values = self.orchestrator.rl_agent.act(current_state, return_q_values=True)
|
||||
if isinstance(q_values, tuple):
|
||||
action, q_vals = q_values
|
||||
q_values = q_vals.tolist() if hasattr(q_vals, 'tolist') else [0, 0, 0]
|
||||
else:
|
||||
action = q_values
|
||||
q_values = [0.33, 0.33, 0.34] # Default uniform distribution
|
||||
|
||||
confidence = max(q_values) / sum(q_values) if sum(q_values) > 0 else 0.33
|
||||
|
||||
else:
|
||||
# Fallback to technical analysis-based prediction
|
||||
action, q_values, confidence = self._technical_analysis_prediction(symbol)
|
||||
|
||||
# Create forward-looking prediction
|
||||
prediction_time = datetime.now()
|
||||
target_time = prediction_time + timedelta(minutes=5) # Predict 5 minutes ahead
|
||||
|
||||
prediction = {
|
||||
'id': f"dqn_{symbol}_{int(current_time)}",
|
||||
'type': 'DQN',
|
||||
'symbol': symbol,
|
||||
'prediction_time': prediction_time,
|
||||
'target_time': target_time,
|
||||
'current_price': current_price,
|
||||
'predicted_action': action,
|
||||
'q_values': q_values,
|
||||
'confidence': confidence,
|
||||
'state': current_state.tolist() if hasattr(current_state, 'tolist') else current_state,
|
||||
'validated': False
|
||||
}
|
||||
|
||||
# Add to pending predictions for future validation
|
||||
if symbol in self.pending_predictions:
|
||||
self.pending_predictions[symbol].append(prediction)
|
||||
|
||||
# Add to recent predictions for display (only if confident enough)
|
||||
if confidence > 0.4:
|
||||
display_prediction = {
|
||||
'timestamp': prediction_time,
|
||||
'price': current_price,
|
||||
'action': action,
|
||||
'confidence': confidence,
|
||||
'q_values': q_values
|
||||
}
|
||||
if symbol in self.recent_dqn_predictions:
|
||||
self.recent_dqn_predictions[symbol].append(display_prediction)
|
||||
|
||||
self.last_prediction_time[symbol] = current_time
|
||||
|
||||
logger.info(f"Forward DQN prediction: {symbol} action={['BUY','SELL','HOLD'][action]} confidence={confidence:.2f} target={target_time.strftime('%H:%M:%S')}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating forward DQN prediction: {e}")
|
||||
|
||||
def _generate_forward_cnn_prediction(self, symbol: str, current_time: float):
|
||||
"""Generate a CNN prediction for future price direction"""
|
||||
try:
|
||||
# Get current price and historical sequence (only past data)
|
||||
current_price = self._get_current_price_from_data(symbol)
|
||||
price_sequence = self._get_historical_price_sequence(symbol, periods=15)
|
||||
|
||||
if current_price is None or len(price_sequence) < 15:
|
||||
return
|
||||
|
||||
# Use CNN model to predict direction (if available)
|
||||
if (self.orchestrator and hasattr(self.orchestrator, 'cnn_model')
|
||||
and self.orchestrator.cnn_model):
|
||||
|
||||
# Prepare features for CNN
|
||||
features = self._prepare_cnn_features(price_sequence)
|
||||
|
||||
try:
|
||||
# Get prediction from CNN model
|
||||
prediction_output = self.orchestrator.cnn_model.predict(features)
|
||||
if hasattr(prediction_output, 'tolist'):
|
||||
pred_probs = prediction_output.tolist()
|
||||
else:
|
||||
pred_probs = [0.33, 0.33, 0.34] # Default
|
||||
|
||||
direction = int(np.argmax(pred_probs)) # 0=DOWN, 1=SAME, 2=UP
|
||||
confidence = max(pred_probs)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"CNN model prediction failed: {e}")
|
||||
direction, confidence = self._technical_direction_prediction(symbol)
|
||||
|
||||
else:
|
||||
# Fallback to technical analysis
|
||||
direction, confidence = self._technical_direction_prediction(symbol)
|
||||
|
||||
# Calculate predicted price based on direction
|
||||
price_change_percent = self._estimate_price_change(direction, confidence)
|
||||
predicted_price = current_price * (1 + price_change_percent)
|
||||
|
||||
# Create forward-looking prediction
|
||||
prediction_time = datetime.now()
|
||||
target_time = prediction_time + timedelta(minutes=10) # Predict 10 minutes ahead
|
||||
|
||||
prediction = {
|
||||
'id': f"cnn_{symbol}_{int(current_time)}",
|
||||
'type': 'CNN',
|
||||
'symbol': symbol,
|
||||
'prediction_time': prediction_time,
|
||||
'target_time': target_time,
|
||||
'current_price': current_price,
|
||||
'predicted_price': predicted_price,
|
||||
'direction': direction,
|
||||
'confidence': confidence,
|
||||
'features': features.tolist() if hasattr(features, 'tolist') else None,
|
||||
'validated': False
|
||||
}
|
||||
|
||||
# Add to pending predictions for future validation
|
||||
if symbol in self.pending_predictions:
|
||||
self.pending_predictions[symbol].append(prediction)
|
||||
|
||||
# Add to recent predictions for display (only if confident enough)
|
||||
if confidence > 0.5:
|
||||
display_prediction = {
|
||||
'timestamp': prediction_time,
|
||||
'current_price': current_price,
|
||||
'predicted_price': predicted_price,
|
||||
'direction': direction,
|
||||
'confidence': confidence
|
||||
}
|
||||
if symbol in self.recent_cnn_predictions:
|
||||
self.recent_cnn_predictions[symbol].append(display_prediction)
|
||||
|
||||
logger.info(f"Forward CNN prediction: {symbol} direction={['DOWN','SAME','UP'][direction]} confidence={confidence:.2f} target={target_time.strftime('%H:%M:%S')}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating forward CNN prediction: {e}")
|
||||
|
||||
def _validate_pending_predictions(self, symbol: str, current_time: float):
|
||||
"""Validate pending predictions when their target time arrives"""
|
||||
try:
|
||||
if symbol not in self.pending_predictions:
|
||||
return
|
||||
|
||||
current_datetime = datetime.now()
|
||||
validated_predictions = []
|
||||
|
||||
# Check each pending prediction
|
||||
for prediction in list(self.pending_predictions[symbol]):
|
||||
target_time = prediction['target_time']
|
||||
|
||||
# If target time has passed, validate the prediction
|
||||
if current_datetime >= target_time:
|
||||
actual_price = self._get_current_price_from_data(symbol)
|
||||
|
||||
if actual_price is not None:
|
||||
# Calculate actual price change
|
||||
predicted_price = prediction.get('predicted_price', prediction['current_price'])
|
||||
actual_change = (actual_price - prediction['current_price']) / prediction['current_price']
|
||||
predicted_change = (predicted_price - prediction['current_price']) / prediction['current_price']
|
||||
|
||||
# Validate based on prediction type
|
||||
if prediction['type'] == 'DQN':
|
||||
was_correct = self._validate_dqn_prediction(prediction, actual_change)
|
||||
else: # CNN
|
||||
was_correct = self._validate_cnn_prediction(prediction, actual_change)
|
||||
|
||||
# Store accuracy result
|
||||
accuracy_data = {
|
||||
'timestamp': current_datetime,
|
||||
'symbol': symbol,
|
||||
'prediction_type': prediction['type'],
|
||||
'correct': was_correct,
|
||||
'accuracy_score': prediction['confidence'] if was_correct else (1.0 - prediction['confidence']),
|
||||
'confidence': prediction['confidence'],
|
||||
'actual_price_change': actual_change,
|
||||
'predicted_action': prediction.get('predicted_action', prediction.get('direction', 0)),
|
||||
'actual_price': actual_price
|
||||
}
|
||||
|
||||
if symbol in self.prediction_accuracy_history:
|
||||
self.prediction_accuracy_history[symbol].append(accuracy_data)
|
||||
|
||||
validated_predictions.append(prediction['id'])
|
||||
|
||||
logger.info(f"Validated {prediction['type']} prediction: {symbol} correct={was_correct} confidence={prediction['confidence']:.2f}")
|
||||
|
||||
# Remove validated predictions from pending list
|
||||
if validated_predictions:
|
||||
self.pending_predictions[symbol] = deque([
|
||||
p for p in self.pending_predictions[symbol]
|
||||
if p['id'] not in validated_predictions
|
||||
], maxlen=100)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating pending predictions: {e}")
|
||||
|
||||
def _validate_dqn_prediction(self, prediction: Dict, actual_change: float) -> bool:
|
||||
"""Validate DQN action prediction"""
|
||||
predicted_action = prediction['predicted_action']
|
||||
threshold = 0.005 # 0.5% threshold for significant movement
|
||||
|
||||
if predicted_action == 0: # BUY prediction
|
||||
return actual_change > threshold
|
||||
elif predicted_action == 1: # SELL prediction
|
||||
return actual_change < -threshold
|
||||
else: # HOLD prediction
|
||||
return abs(actual_change) <= threshold
|
||||
|
||||
def _validate_cnn_prediction(self, prediction: Dict, actual_change: float) -> bool:
|
||||
"""Validate CNN direction prediction"""
|
||||
predicted_direction = prediction['direction']
|
||||
threshold = 0.002 # 0.2% threshold for direction
|
||||
|
||||
if predicted_direction == 2: # UP prediction
|
||||
return actual_change > threshold
|
||||
elif predicted_direction == 0: # DOWN prediction
|
||||
return actual_change < -threshold
|
||||
else: # SAME prediction
|
||||
return abs(actual_change) <= threshold
|
||||
|
||||
def _get_current_price_from_data(self, symbol: str) -> Optional[float]:
|
||||
"""Get current price from real-time data streams"""
|
||||
try:
|
||||
if len(self.real_time_data['ohlcv_1m']) > 0:
|
||||
return self.real_time_data['ohlcv_1m'][-1]['close']
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting current price: {e}")
|
||||
return None
|
||||
|
||||
def _get_historical_price_sequence(self, symbol: str, periods: int = 15) -> List[float]:
|
||||
"""Get historical price sequence for CNN features"""
|
||||
try:
|
||||
if len(self.real_time_data['ohlcv_1m']) >= periods:
|
||||
return [bar['close'] for bar in list(self.real_time_data['ohlcv_1m'])[-periods:]]
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting price sequence: {e}")
|
||||
return []
|
||||
|
||||
def _technical_analysis_prediction(self, symbol: str) -> Tuple[int, List[float], float]:
|
||||
"""Fallback technical analysis prediction for DQN"""
|
||||
try:
|
||||
# Simple momentum-based prediction
|
||||
if len(self.real_time_data['ohlcv_1m']) >= 5:
|
||||
recent_prices = [bar['close'] for bar in list(self.real_time_data['ohlcv_1m'])[-5:]]
|
||||
momentum = (recent_prices[-1] - recent_prices[0]) / recent_prices[0]
|
||||
|
||||
if momentum > 0.01: # 1% upward momentum
|
||||
return 0, [0.6, 0.2, 0.2], 0.6 # BUY
|
||||
elif momentum < -0.01: # 1% downward momentum
|
||||
return 1, [0.2, 0.6, 0.2], 0.6 # SELL
|
||||
else:
|
||||
return 2, [0.2, 0.2, 0.6], 0.6 # HOLD
|
||||
|
||||
return 2, [0.33, 0.33, 0.34], 0.33 # Default HOLD
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error in technical analysis prediction: {e}")
|
||||
return 2, [0.33, 0.33, 0.34], 0.33
|
||||
|
||||
def _technical_direction_prediction(self, symbol: str) -> Tuple[int, float]:
|
||||
"""Fallback technical analysis for CNN direction"""
|
||||
try:
|
||||
if len(self.real_time_data['ohlcv_1m']) >= 3:
|
||||
recent_prices = [bar['close'] for bar in list(self.real_time_data['ohlcv_1m'])[-3:]]
|
||||
short_momentum = (recent_prices[-1] - recent_prices[-2]) / recent_prices[-2]
|
||||
|
||||
if short_momentum > 0.005: # 0.5% short-term up
|
||||
return 2, 0.65 # UP
|
||||
elif short_momentum < -0.005: # 0.5% short-term down
|
||||
return 0, 0.65 # DOWN
|
||||
else:
|
||||
return 1, 0.55 # SAME
|
||||
|
||||
return 1, 0.5 # Default SAME
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error in technical direction prediction: {e}")
|
||||
return 1, 0.5
|
||||
|
||||
def _prepare_cnn_features(self, price_sequence: List[float]) -> np.ndarray:
|
||||
"""Prepare features for CNN model"""
|
||||
try:
|
||||
# Normalize prices relative to first price
|
||||
if len(price_sequence) >= 15:
|
||||
base_price = price_sequence[0]
|
||||
normalized = [(p - base_price) / base_price for p in price_sequence]
|
||||
|
||||
# Create feature matrix (15 x 20, flattened)
|
||||
features = np.zeros((15, 20))
|
||||
for i, norm_price in enumerate(normalized):
|
||||
features[i, 0] = norm_price # Normalized price
|
||||
if i > 0:
|
||||
features[i, 1] = normalized[i] - normalized[i-1] # Price change
|
||||
|
||||
return features.flatten()
|
||||
|
||||
return np.zeros(300) # Default feature vector
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error preparing CNN features: {e}")
|
||||
return np.zeros(300)
|
||||
|
||||
def _estimate_price_change(self, direction: int, confidence: float) -> float:
|
||||
"""Estimate price change percentage based on direction and confidence"""
|
||||
try:
|
||||
# Base change scaled by confidence
|
||||
base_change = 0.01 * confidence # Up to 1% change
|
||||
|
||||
if direction == 2: # UP
|
||||
return base_change
|
||||
elif direction == 0: # DOWN
|
||||
return -base_change
|
||||
else: # SAME
|
||||
return 0.0
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error estimating price change: {e}")
|
||||
return 0.0
|
Reference in New Issue
Block a user