refactoring, predictions WIP

This commit is contained in:
Dobromir Popov
2025-06-13 11:48:03 +03:00
parent 543b53883e
commit 5bce17a21a
10 changed files with 14187 additions and 174 deletions

View File

@ -2568,6 +2568,241 @@
# except Exception as e:
# logger.error(f"Error handling unified stream data: {e}")
# def _get_cnn_pivot_predictions(self, symbol: str, df: pd.DataFrame) -> List[Dict[str, Any]]:
# """Get CNN model predictions for next pivot points"""
# try:
# predictions = []
# if not hasattr(self, 'orchestrator') or not self.orchestrator:
# return predictions
#
# # Check if orchestrator has CNN capabilities
# if hasattr(self.orchestrator, 'pivot_rl_trainer') and self.orchestrator.pivot_rl_trainer:
# if hasattr(self.orchestrator.pivot_rl_trainer, 'williams') and self.orchestrator.pivot_rl_trainer.williams:
# williams = self.orchestrator.pivot_rl_trainer.williams
#
# if hasattr(williams, 'cnn_model') and williams.cnn_model:
# # Get latest market data for CNN input
# if not df.empty and len(df) >= 900: # CNN needs at least 900 timesteps
# try:
# # Prepare multi-timeframe input for CNN
# current_time = datetime.now()
#
# # Create dummy pivot point for CNN input preparation
# dummy_pivot = type('SwingPoint', (), {
# 'timestamp': current_time,
# 'price': df['close'].iloc[-1],
# 'index': len(df) - 1,
# 'swing_type': 'prediction_point',
# 'strength': 1
# })()
#
# # Prepare CNN input using Williams structure
# cnn_input = williams._prepare_cnn_input(
# dummy_pivot,
# df.values, # OHLCV data context
# None # No previous pivot details
# )
#
# if cnn_input is not None and cnn_input.size > 0:
# # Reshape for batch prediction
# if len(cnn_input.shape) == 2:
# cnn_input = np.expand_dims(cnn_input, axis=0)
#
# # Get CNN prediction
# pred_output = williams.cnn_model.model.predict(cnn_input, verbose=0)
#
# if pred_output is not None and len(pred_output) > 0:
# # Parse CNN output (10 outputs for 5 Williams levels)
# # Each level has [type_probability, predicted_price]
# current_price = df['close'].iloc[-1]
#
# for level_idx in range(min(5, len(pred_output[0]) // 2)):
# type_prob = pred_output[0][level_idx * 2]
# price_offset = pred_output[0][level_idx * 2 + 1]
#
# # Determine prediction type
# is_high = type_prob > 0.5
# confidence = abs(type_prob - 0.5) * 2 # Convert to 0-1 range
#
# # Calculate predicted price
# predicted_price = current_price + (price_offset * current_price * 0.01) # Assume price_offset is percentage
#
# # Only include predictions with reasonable confidence
# if confidence > 0.3:
# prediction = {
# 'level': level_idx + 1,
# 'type': 'HIGH' if is_high else 'LOW',
# 'predicted_price': predicted_price,
# 'confidence': confidence,
# 'timestamp': current_time,
# 'current_price': current_price,
# 'price_offset_pct': price_offset * 100,
# 'model_output': {
# 'type_prob': float(type_prob),
# 'price_offset': float(price_offset)
# }
# }
# predictions.append(prediction)
#
# logger.debug(f"[CNN] Generated {len(predictions)} pivot predictions for {symbol}")
#
# except Exception as e:
# logger.warning(f"Error generating CNN predictions: {e}")
#
# return predictions
#
# except Exception as e:
# logger.error(f"Error getting CNN pivot predictions: {e}")
# return []
# def _add_cnn_predictions_to_chart(self, fig: go.Figure, predictions: List[Dict[str, Any]], row: int = 1):
# """Add CNN predictions as hollow circles to the chart"""
# try:
# if not predictions:
# return
#
# # Separate HIGH and LOW predictions
# high_predictions = [p for p in predictions if p['type'] == 'HIGH']
# low_predictions = [p for p in predictions if p['type'] == 'LOW']
#
# # Add HIGH prediction markers (hollow red circles)
# if high_predictions:
# # Create future timestamps for display (predictions are for future points)
# base_time = high_predictions[0]['timestamp']
#
# fig.add_trace(
# go.Scatter(
# x=[base_time + timedelta(minutes=i*5) for i in range(len(high_predictions))],
# y=[p['predicted_price'] for p in high_predictions],
# mode='markers',
# marker=dict(
# color='rgba(255, 107, 107, 0)', # Transparent fill
# size=[max(8, min(20, p['confidence'] * 20)) for p in high_predictions],
# symbol='circle',
# line=dict(
# color='#ff6b6b', # Red border
# width=2
# )
# ),
# name='CNN HIGH Predictions',
# showlegend=True,
# hovertemplate='<b>CNN HIGH Prediction</b><br>' +
# 'Price: $%{y:.2f}<br>' +
# 'Confidence: %{customdata:.1%}<br>' +
# 'Level: %{text}<extra></extra>',
# customdata=[p['confidence'] for p in high_predictions],
# text=[f"Level {p['level']}" for p in high_predictions]
# ),
# row=row, col=1
# )
#
# # Add LOW prediction markers (hollow green circles)
# if low_predictions:
# base_time = low_predictions[0]['timestamp']
#
# fig.add_trace(
# go.Scatter(
# x=[base_time + timedelta(minutes=i*5) for i in range(len(low_predictions))],
# y=[p['predicted_price'] for p in low_predictions],
# mode='markers',
# marker=dict(
# color='rgba(0, 255, 136, 0)', # Transparent fill
# size=[max(8, min(20, p['confidence'] * 20)) for p in low_predictions],
# symbol='circle',
# line=dict(
# color='#00ff88', # Green border
# width=2
# )
# ),
# name='CNN LOW Predictions',
# showlegend=True,
# hovertemplate='<b>CNN LOW Prediction</b><br>' +
# 'Price: $%{y:.2f}<br>' +
# 'Confidence: %{customdata:.1%}<br>' +
# 'Level: %{text}<extra></extra>',
# customdata=[p['confidence'] for p in low_predictions],
# text=[f"Level {p['level']}" for p in low_predictions]
# ),
# row=row, col=1
# )
#
# logger.debug(f"[CHART] Added {len(high_predictions)} HIGH and {len(low_predictions)} LOW CNN predictions to chart")
#
# except Exception as e:
# logger.error(f"Error adding CNN predictions to chart: {e}")
# def _capture_actual_pivot_data(self, actual_pivot: Dict[str, Any]) -> None:
# """Capture actual pivot data when it occurs for training comparison"""
# try:
# if not hasattr(self, '_pivot_training_data'):
# self._pivot_training_data = []
#
# # Store actual pivot with timestamp for later comparison with predictions
# pivot_data = {
# 'actual_pivot': actual_pivot,
# 'timestamp': datetime.now(),
# 'captured_at': datetime.now().isoformat()
# }
#
# self._pivot_training_data.append(pivot_data)
#
# # Keep only last 1000 actual pivots
# if len(self._pivot_training_data) > 1000:
# self._pivot_training_data = self._pivot_training_data[-1000:]
#
# logger.info(f"[TRAINING] Captured actual pivot: {actual_pivot['type']} at ${actual_pivot['price']:.2f}")
#
# # Save to persistent storage periodically
# if len(self._pivot_training_data) % 10 == 0:
# self._save_pivot_training_data()
#
# except Exception as e:
# logger.error(f"Error capturing actual pivot data: {e}")
# def _save_pivot_training_data(self) -> None:
# """Save pivot training data to JSON file for model improvement"""
# try:
# if not hasattr(self, '_pivot_training_data') or not self._pivot_training_data:
# return
#
# # Create data directory if it doesn't exist
# import os
# os.makedirs('data/cnn_training', exist_ok=True)
#
# # Save to timestamped file
# timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
# filename = f'data/cnn_training/pivot_predictions_vs_actual_{timestamp}.json'
#
# # Prepare data for JSON serialization
# save_data = {
# 'metadata': {
# 'created_at': datetime.now().isoformat(),
# 'total_samples': len(self._pivot_training_data),
# 'description': 'CNN pivot predictions compared with actual market pivots'
# },
# 'training_samples': []
# }
#
# for sample in self._pivot_training_data:
# # Convert datetime objects to ISO strings for JSON
# json_sample = {
# 'actual_pivot': sample['actual_pivot'],
# 'timestamp': sample['timestamp'].isoformat() if isinstance(sample['timestamp'], datetime) else sample['timestamp'],
# 'captured_at': sample['captured_at']
# }
# save_data['training_samples'].append(json_sample)
#
# # Write to file
# import json
# with open(filename, 'w') as f:
# json.dump(save_data, f, indent=2, default=str)
#
# logger.info(f"[TRAINING] Saved {len(self._pivot_training_data)} pivot training samples to {filename}")
#
# except Exception as e:
# logger.error(f"Error saving pivot training data: {e}")
# def create_scalping_dashboard(data_provider=None, orchestrator=None, trading_executor=None):
# """Create real-time dashboard instance with MEXC integration"""
# return RealTimeScalpingDashboard(data_provider, orchestrator, trading_executor)