diff --git a/core/orchestrator.py b/core/orchestrator.py index d177307..b752383 100644 --- a/core/orchestrator.py +++ b/core/orchestrator.py @@ -372,6 +372,9 @@ class TradingOrchestrator: self.recent_cnn_predictions: Dict[str, deque] = ( {} ) # {symbol: List[Dict]} - Recent CNN predictions + self.recent_transformer_predictions: Dict[str, deque] = ( + {} + ) # {symbol: List[Dict]} - Recent Transformer predictions self.prediction_accuracy_history: Dict[str, deque] = ( {} ) # {symbol: List[Dict]} - Prediction accuracy tracking @@ -379,6 +382,7 @@ class TradingOrchestrator: # Initialize prediction tracking for the primary trading symbol only self.recent_dqn_predictions[self.symbol] = deque(maxlen=100) self.recent_cnn_predictions[self.symbol] = deque(maxlen=50) + self.recent_transformer_predictions[self.symbol] = deque(maxlen=50) self.prediction_accuracy_history[self.symbol] = deque(maxlen=200) self.signal_accumulator[self.symbol] = [] @@ -1109,6 +1113,8 @@ class TradingOrchestrator: self.recent_dqn_predictions[symbol].clear() for symbol in self.recent_cnn_predictions: self.recent_cnn_predictions[symbol].clear() + for symbol in self.recent_transformer_predictions: + self.recent_transformer_predictions[symbol].clear() for symbol in self.prediction_accuracy_history: self.prediction_accuracy_history[symbol].clear() @@ -2518,6 +2524,9 @@ class TradingOrchestrator: "cnn_predictions_tracked": sum( len(preds) for preds in self.recent_cnn_predictions.values() ), + "transformer_predictions_tracked": sum( + len(preds) for preds in self.recent_transformer_predictions.values() + ), "accuracy_history_tracked": sum( len(history) for history in self.prediction_accuracy_history.values() @@ -2527,6 +2536,7 @@ class TradingOrchestrator: for symbol in self.symbols if len(self.recent_dqn_predictions.get(symbol, [])) > 0 or len(self.recent_cnn_predictions.get(symbol, [])) > 0 + or len(self.recent_transformer_predictions.get(symbol, [])) > 0 ], } @@ -2791,4 +2801,19 @@ class TradingOrchestrator: """Set the trading executor for position tracking""" self.trading_executor = trading_executor logger.info("Trading executor set for position tracking and P&L feedback") + + def store_transformer_prediction(self, symbol: str, prediction: Dict): + """Store a transformer prediction for visualization and tracking""" + try: + if symbol not in self.recent_transformer_predictions: + self.recent_transformer_predictions[symbol] = deque(maxlen=50) + + # Add timestamp if not present + if 'timestamp' not in prediction: + prediction['timestamp'] = datetime.now() + + self.recent_transformer_predictions[symbol].append(prediction) + logger.debug(f"Stored transformer prediction for {symbol}: {prediction.get('action', 'N/A')}") + except Exception as e: + logger.error(f"Error storing transformer prediction: {e}") diff --git a/web/clean_dashboard.py b/web/clean_dashboard.py index 71a74bc..a6ee9f9 100644 --- a/web/clean_dashboard.py +++ b/web/clean_dashboard.py @@ -3284,6 +3284,7 @@ class CleanTradingDashboard: # 2. NEW: Add real-time model predictions overlay self._add_dqn_predictions_to_chart(fig, symbol, df_main, row) self._add_cnn_predictions_to_chart(fig, symbol, df_main, row) + self._add_transformer_predictions_to_chart(fig, symbol, df_main, row) self._add_cob_rl_predictions_to_chart(fig, symbol, df_main, row) self._add_prediction_accuracy_feedback(fig, symbol, df_main, row) self._add_williams_pivots_to_chart(fig, symbol, df_main, row) @@ -3602,6 +3603,95 @@ class CleanTradingDashboard: except Exception as e: logger.debug(f"Error adding CNN predictions to chart: {e}") + def _add_transformer_predictions_to_chart(self, fig: go.Figure, symbol: str, df_main: pd.DataFrame, row: int = 1): + """Add Transformer price predictions as trend lines with confidence bands""" + try: + # Get recent Transformer predictions from orchestrator + transformer_predictions = self._get_recent_transformer_predictions(symbol) + + if not transformer_predictions: + return + + for i, pred in enumerate(transformer_predictions[-15:]): # Last 15 Transformer predictions + confidence = pred.get('confidence', 0) + timestamp = pred.get('timestamp', datetime.now()) + current_price = pred.get('current_price', 0) + predicted_price = pred.get('predicted_price', current_price) + price_change = pred.get('price_change', 0) + + # FILTER OUT INVALID PRICES + if (current_price is None or current_price <= 0 or + predicted_price is None or predicted_price <= 0): + continue + + if confidence > 0.3: # Show predictions with reasonable confidence + # Calculate prediction horizon (typically 5-15 minutes) + horizon_minutes = pred.get('horizon_minutes', 10) + end_time = timestamp + timedelta(minutes=horizon_minutes) + + # Determine color based on price change direction + if price_change > 0.5: # Significant UP + color = f'rgba(0, 200, 255, {0.3 + confidence * 0.5})' + line_color = 'cyan' + prediction_name = 'Transformer UP' + elif price_change < -0.5: # Significant DOWN + color = f'rgba(255, 100, 0, {0.3 + confidence * 0.5})' + line_color = 'orange' + prediction_name = 'Transformer DOWN' + else: # Small change + color = f'rgba(150, 150, 255, {0.2 + confidence * 0.4})' + line_color = 'lightblue' + prediction_name = 'Transformer STABLE' + + # Add prediction line + fig.add_trace( + go.Scatter( + x=[timestamp, end_time], + y=[current_price, predicted_price], + mode='lines', + line=dict( + color=line_color, + width=2 + confidence * 4, + dash='dashdot' + ), + name=f'{prediction_name}', + showlegend=i == 0, + hovertemplate=f"{prediction_name}
" + + "From: $%{y[0]:.2f}
" + + "To: $%{y[1]:.2f}
" + + "Time: %{x[0]} → %{x[1]}
" + + f"Confidence: {confidence:.1%}
" + + f"Change: {price_change:+.2f}%" + ), + row=row, col=1 + ) + + # Add prediction target marker + fig.add_trace( + go.Scatter( + x=[end_time], + y=[predicted_price], + mode='markers', + marker=dict( + symbol='star', + size=8 + confidence * 10, + color=color, + line=dict(width=2, color=line_color) + ), + name=f'{prediction_name} Target', + showlegend=False, + hovertemplate=f"TRANSFORMER TARGET
" + + "Target: $%{y:.2f}
" + + "Time: %{x}
" + + f"Confidence: {confidence:.1%}
" + + f"Expected Change: {price_change:+.2f}%" + ), + row=row, col=1 + ) + + except Exception as e: + logger.debug(f"Error adding Transformer predictions to chart: {e}") + def _add_cob_rl_predictions_to_chart(self, fig: go.Figure, symbol: str, df_main: pd.DataFrame, row: int = 1): """Add COB_RL microstructure predictions as diamond markers""" try: @@ -3895,6 +3985,35 @@ class CleanTradingDashboard: logger.debug(f"Error getting CNN predictions: {e}") return [] + def _get_recent_transformer_predictions(self, symbol: str) -> List[Dict]: + """Get recent Transformer predictions from orchestrator""" + try: + predictions = [] + + # Get REAL predictions from orchestrator + if hasattr(self.orchestrator, 'recent_transformer_predictions'): + predictions.extend(list(self.orchestrator.recent_transformer_predictions.get(symbol, []))) + + # Get from training system as additional source + if hasattr(self, 'training_system') and self.training_system: + if hasattr(self.training_system, 'recent_transformer_predictions'): + predictions.extend(self.training_system.recent_transformer_predictions.get(symbol, [])) + + # Remove duplicates and sort by timestamp + unique_predictions = [] + seen_timestamps = set() + for pred in predictions: + timestamp_key = pred.get('timestamp', datetime.now()).isoformat() + if timestamp_key not in seen_timestamps: + unique_predictions.append(pred) + seen_timestamps.add(timestamp_key) + + return sorted(unique_predictions, key=lambda x: x.get('timestamp', datetime.now())) + + except Exception as e: + logger.debug(f"Error getting Transformer predictions: {e}") + return [] + def _get_prediction_accuracy_history(self, symbol: str) -> List[Dict]: """Get REAL prediction accuracy history from validated forward-looking predictions""" try: