store and show (wip) predictions
This commit is contained in:
@@ -372,6 +372,9 @@ class TradingOrchestrator:
|
|||||||
self.recent_cnn_predictions: Dict[str, deque] = (
|
self.recent_cnn_predictions: Dict[str, deque] = (
|
||||||
{}
|
{}
|
||||||
) # {symbol: List[Dict]} - Recent CNN predictions
|
) # {symbol: List[Dict]} - Recent CNN predictions
|
||||||
|
self.recent_transformer_predictions: Dict[str, deque] = (
|
||||||
|
{}
|
||||||
|
) # {symbol: List[Dict]} - Recent Transformer predictions
|
||||||
self.prediction_accuracy_history: Dict[str, deque] = (
|
self.prediction_accuracy_history: Dict[str, deque] = (
|
||||||
{}
|
{}
|
||||||
) # {symbol: List[Dict]} - Prediction accuracy tracking
|
) # {symbol: List[Dict]} - Prediction accuracy tracking
|
||||||
@@ -379,6 +382,7 @@ class TradingOrchestrator:
|
|||||||
# Initialize prediction tracking for the primary trading symbol only
|
# Initialize prediction tracking for the primary trading symbol only
|
||||||
self.recent_dqn_predictions[self.symbol] = deque(maxlen=100)
|
self.recent_dqn_predictions[self.symbol] = deque(maxlen=100)
|
||||||
self.recent_cnn_predictions[self.symbol] = deque(maxlen=50)
|
self.recent_cnn_predictions[self.symbol] = deque(maxlen=50)
|
||||||
|
self.recent_transformer_predictions[self.symbol] = deque(maxlen=50)
|
||||||
self.prediction_accuracy_history[self.symbol] = deque(maxlen=200)
|
self.prediction_accuracy_history[self.symbol] = deque(maxlen=200)
|
||||||
self.signal_accumulator[self.symbol] = []
|
self.signal_accumulator[self.symbol] = []
|
||||||
|
|
||||||
@@ -1109,6 +1113,8 @@ class TradingOrchestrator:
|
|||||||
self.recent_dqn_predictions[symbol].clear()
|
self.recent_dqn_predictions[symbol].clear()
|
||||||
for symbol in self.recent_cnn_predictions:
|
for symbol in self.recent_cnn_predictions:
|
||||||
self.recent_cnn_predictions[symbol].clear()
|
self.recent_cnn_predictions[symbol].clear()
|
||||||
|
for symbol in self.recent_transformer_predictions:
|
||||||
|
self.recent_transformer_predictions[symbol].clear()
|
||||||
for symbol in self.prediction_accuracy_history:
|
for symbol in self.prediction_accuracy_history:
|
||||||
self.prediction_accuracy_history[symbol].clear()
|
self.prediction_accuracy_history[symbol].clear()
|
||||||
|
|
||||||
@@ -2518,6 +2524,9 @@ class TradingOrchestrator:
|
|||||||
"cnn_predictions_tracked": sum(
|
"cnn_predictions_tracked": sum(
|
||||||
len(preds) for preds in self.recent_cnn_predictions.values()
|
len(preds) for preds in self.recent_cnn_predictions.values()
|
||||||
),
|
),
|
||||||
|
"transformer_predictions_tracked": sum(
|
||||||
|
len(preds) for preds in self.recent_transformer_predictions.values()
|
||||||
|
),
|
||||||
"accuracy_history_tracked": sum(
|
"accuracy_history_tracked": sum(
|
||||||
len(history)
|
len(history)
|
||||||
for history in self.prediction_accuracy_history.values()
|
for history in self.prediction_accuracy_history.values()
|
||||||
@@ -2527,6 +2536,7 @@ class TradingOrchestrator:
|
|||||||
for symbol in self.symbols
|
for symbol in self.symbols
|
||||||
if len(self.recent_dqn_predictions.get(symbol, [])) > 0
|
if len(self.recent_dqn_predictions.get(symbol, [])) > 0
|
||||||
or len(self.recent_cnn_predictions.get(symbol, [])) > 0
|
or len(self.recent_cnn_predictions.get(symbol, [])) > 0
|
||||||
|
or len(self.recent_transformer_predictions.get(symbol, [])) > 0
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2792,3 +2802,18 @@ class TradingOrchestrator:
|
|||||||
self.trading_executor = trading_executor
|
self.trading_executor = trading_executor
|
||||||
logger.info("Trading executor set for position tracking and P&L feedback")
|
logger.info("Trading executor set for position tracking and P&L feedback")
|
||||||
|
|
||||||
|
def store_transformer_prediction(self, symbol: str, prediction: Dict):
|
||||||
|
"""Store a transformer prediction for visualization and tracking"""
|
||||||
|
try:
|
||||||
|
if symbol not in self.recent_transformer_predictions:
|
||||||
|
self.recent_transformer_predictions[symbol] = deque(maxlen=50)
|
||||||
|
|
||||||
|
# Add timestamp if not present
|
||||||
|
if 'timestamp' not in prediction:
|
||||||
|
prediction['timestamp'] = datetime.now()
|
||||||
|
|
||||||
|
self.recent_transformer_predictions[symbol].append(prediction)
|
||||||
|
logger.debug(f"Stored transformer prediction for {symbol}: {prediction.get('action', 'N/A')}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error storing transformer prediction: {e}")
|
||||||
|
|
||||||
|
|||||||
@@ -3284,6 +3284,7 @@ class CleanTradingDashboard:
|
|||||||
# 2. NEW: Add real-time model predictions overlay
|
# 2. NEW: Add real-time model predictions overlay
|
||||||
self._add_dqn_predictions_to_chart(fig, symbol, df_main, row)
|
self._add_dqn_predictions_to_chart(fig, symbol, df_main, row)
|
||||||
self._add_cnn_predictions_to_chart(fig, symbol, df_main, row)
|
self._add_cnn_predictions_to_chart(fig, symbol, df_main, row)
|
||||||
|
self._add_transformer_predictions_to_chart(fig, symbol, df_main, row)
|
||||||
self._add_cob_rl_predictions_to_chart(fig, symbol, df_main, row)
|
self._add_cob_rl_predictions_to_chart(fig, symbol, df_main, row)
|
||||||
self._add_prediction_accuracy_feedback(fig, symbol, df_main, row)
|
self._add_prediction_accuracy_feedback(fig, symbol, df_main, row)
|
||||||
self._add_williams_pivots_to_chart(fig, symbol, df_main, row)
|
self._add_williams_pivots_to_chart(fig, symbol, df_main, row)
|
||||||
@@ -3602,6 +3603,95 @@ class CleanTradingDashboard:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"Error adding CNN predictions to chart: {e}")
|
logger.debug(f"Error adding CNN predictions to chart: {e}")
|
||||||
|
|
||||||
|
def _add_transformer_predictions_to_chart(self, fig: go.Figure, symbol: str, df_main: pd.DataFrame, row: int = 1):
|
||||||
|
"""Add Transformer price predictions as trend lines with confidence bands"""
|
||||||
|
try:
|
||||||
|
# Get recent Transformer predictions from orchestrator
|
||||||
|
transformer_predictions = self._get_recent_transformer_predictions(symbol)
|
||||||
|
|
||||||
|
if not transformer_predictions:
|
||||||
|
return
|
||||||
|
|
||||||
|
for i, pred in enumerate(transformer_predictions[-15:]): # Last 15 Transformer predictions
|
||||||
|
confidence = pred.get('confidence', 0)
|
||||||
|
timestamp = pred.get('timestamp', datetime.now())
|
||||||
|
current_price = pred.get('current_price', 0)
|
||||||
|
predicted_price = pred.get('predicted_price', current_price)
|
||||||
|
price_change = pred.get('price_change', 0)
|
||||||
|
|
||||||
|
# FILTER OUT INVALID PRICES
|
||||||
|
if (current_price is None or current_price <= 0 or
|
||||||
|
predicted_price is None or predicted_price <= 0):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if confidence > 0.3: # Show predictions with reasonable confidence
|
||||||
|
# Calculate prediction horizon (typically 5-15 minutes)
|
||||||
|
horizon_minutes = pred.get('horizon_minutes', 10)
|
||||||
|
end_time = timestamp + timedelta(minutes=horizon_minutes)
|
||||||
|
|
||||||
|
# Determine color based on price change direction
|
||||||
|
if price_change > 0.5: # Significant UP
|
||||||
|
color = f'rgba(0, 200, 255, {0.3 + confidence * 0.5})'
|
||||||
|
line_color = 'cyan'
|
||||||
|
prediction_name = 'Transformer UP'
|
||||||
|
elif price_change < -0.5: # Significant DOWN
|
||||||
|
color = f'rgba(255, 100, 0, {0.3 + confidence * 0.5})'
|
||||||
|
line_color = 'orange'
|
||||||
|
prediction_name = 'Transformer DOWN'
|
||||||
|
else: # Small change
|
||||||
|
color = f'rgba(150, 150, 255, {0.2 + confidence * 0.4})'
|
||||||
|
line_color = 'lightblue'
|
||||||
|
prediction_name = 'Transformer STABLE'
|
||||||
|
|
||||||
|
# Add prediction line
|
||||||
|
fig.add_trace(
|
||||||
|
go.Scatter(
|
||||||
|
x=[timestamp, end_time],
|
||||||
|
y=[current_price, predicted_price],
|
||||||
|
mode='lines',
|
||||||
|
line=dict(
|
||||||
|
color=line_color,
|
||||||
|
width=2 + confidence * 4,
|
||||||
|
dash='dashdot'
|
||||||
|
),
|
||||||
|
name=f'{prediction_name}',
|
||||||
|
showlegend=i == 0,
|
||||||
|
hovertemplate=f"<b>{prediction_name}</b><br>" +
|
||||||
|
"From: $%{y[0]:.2f}<br>" +
|
||||||
|
"To: $%{y[1]:.2f}<br>" +
|
||||||
|
"Time: %{x[0]} → %{x[1]}<br>" +
|
||||||
|
f"Confidence: {confidence:.1%}<br>" +
|
||||||
|
f"Change: {price_change:+.2f}%<extra></extra>"
|
||||||
|
),
|
||||||
|
row=row, col=1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add prediction target marker
|
||||||
|
fig.add_trace(
|
||||||
|
go.Scatter(
|
||||||
|
x=[end_time],
|
||||||
|
y=[predicted_price],
|
||||||
|
mode='markers',
|
||||||
|
marker=dict(
|
||||||
|
symbol='star',
|
||||||
|
size=8 + confidence * 10,
|
||||||
|
color=color,
|
||||||
|
line=dict(width=2, color=line_color)
|
||||||
|
),
|
||||||
|
name=f'{prediction_name} Target',
|
||||||
|
showlegend=False,
|
||||||
|
hovertemplate=f"<b>TRANSFORMER TARGET</b><br>" +
|
||||||
|
"Target: $%{y:.2f}<br>" +
|
||||||
|
"Time: %{x}<br>" +
|
||||||
|
f"Confidence: {confidence:.1%}<br>" +
|
||||||
|
f"Expected Change: {price_change:+.2f}%<extra></extra>"
|
||||||
|
),
|
||||||
|
row=row, col=1
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error adding Transformer predictions to chart: {e}")
|
||||||
|
|
||||||
def _add_cob_rl_predictions_to_chart(self, fig: go.Figure, symbol: str, df_main: pd.DataFrame, row: int = 1):
|
def _add_cob_rl_predictions_to_chart(self, fig: go.Figure, symbol: str, df_main: pd.DataFrame, row: int = 1):
|
||||||
"""Add COB_RL microstructure predictions as diamond markers"""
|
"""Add COB_RL microstructure predictions as diamond markers"""
|
||||||
try:
|
try:
|
||||||
@@ -3895,6 +3985,35 @@ class CleanTradingDashboard:
|
|||||||
logger.debug(f"Error getting CNN predictions: {e}")
|
logger.debug(f"Error getting CNN predictions: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
def _get_recent_transformer_predictions(self, symbol: str) -> List[Dict]:
|
||||||
|
"""Get recent Transformer predictions from orchestrator"""
|
||||||
|
try:
|
||||||
|
predictions = []
|
||||||
|
|
||||||
|
# Get REAL predictions from orchestrator
|
||||||
|
if hasattr(self.orchestrator, 'recent_transformer_predictions'):
|
||||||
|
predictions.extend(list(self.orchestrator.recent_transformer_predictions.get(symbol, [])))
|
||||||
|
|
||||||
|
# Get from training system as additional source
|
||||||
|
if hasattr(self, 'training_system') and self.training_system:
|
||||||
|
if hasattr(self.training_system, 'recent_transformer_predictions'):
|
||||||
|
predictions.extend(self.training_system.recent_transformer_predictions.get(symbol, []))
|
||||||
|
|
||||||
|
# Remove duplicates and sort by timestamp
|
||||||
|
unique_predictions = []
|
||||||
|
seen_timestamps = set()
|
||||||
|
for pred in predictions:
|
||||||
|
timestamp_key = pred.get('timestamp', datetime.now()).isoformat()
|
||||||
|
if timestamp_key not in seen_timestamps:
|
||||||
|
unique_predictions.append(pred)
|
||||||
|
seen_timestamps.add(timestamp_key)
|
||||||
|
|
||||||
|
return sorted(unique_predictions, key=lambda x: x.get('timestamp', datetime.now()))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error getting Transformer predictions: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
def _get_prediction_accuracy_history(self, symbol: str) -> List[Dict]:
|
def _get_prediction_accuracy_history(self, symbol: str) -> List[Dict]:
|
||||||
"""Get REAL prediction accuracy history from validated forward-looking predictions"""
|
"""Get REAL prediction accuracy history from validated forward-looking predictions"""
|
||||||
try:
|
try:
|
||||||
|
|||||||
Reference in New Issue
Block a user