store and show (wip) predictions
This commit is contained in:
@@ -372,6 +372,9 @@ class TradingOrchestrator:
|
||||
self.recent_cnn_predictions: Dict[str, deque] = (
|
||||
{}
|
||||
) # {symbol: List[Dict]} - Recent CNN predictions
|
||||
self.recent_transformer_predictions: Dict[str, deque] = (
|
||||
{}
|
||||
) # {symbol: List[Dict]} - Recent Transformer predictions
|
||||
self.prediction_accuracy_history: Dict[str, deque] = (
|
||||
{}
|
||||
) # {symbol: List[Dict]} - Prediction accuracy tracking
|
||||
@@ -379,6 +382,7 @@ class TradingOrchestrator:
|
||||
# Initialize prediction tracking for the primary trading symbol only
|
||||
self.recent_dqn_predictions[self.symbol] = deque(maxlen=100)
|
||||
self.recent_cnn_predictions[self.symbol] = deque(maxlen=50)
|
||||
self.recent_transformer_predictions[self.symbol] = deque(maxlen=50)
|
||||
self.prediction_accuracy_history[self.symbol] = deque(maxlen=200)
|
||||
self.signal_accumulator[self.symbol] = []
|
||||
|
||||
@@ -1109,6 +1113,8 @@ class TradingOrchestrator:
|
||||
self.recent_dqn_predictions[symbol].clear()
|
||||
for symbol in self.recent_cnn_predictions:
|
||||
self.recent_cnn_predictions[symbol].clear()
|
||||
for symbol in self.recent_transformer_predictions:
|
||||
self.recent_transformer_predictions[symbol].clear()
|
||||
for symbol in self.prediction_accuracy_history:
|
||||
self.prediction_accuracy_history[symbol].clear()
|
||||
|
||||
@@ -2518,6 +2524,9 @@ class TradingOrchestrator:
|
||||
"cnn_predictions_tracked": sum(
|
||||
len(preds) for preds in self.recent_cnn_predictions.values()
|
||||
),
|
||||
"transformer_predictions_tracked": sum(
|
||||
len(preds) for preds in self.recent_transformer_predictions.values()
|
||||
),
|
||||
"accuracy_history_tracked": sum(
|
||||
len(history)
|
||||
for history in self.prediction_accuracy_history.values()
|
||||
@@ -2527,6 +2536,7 @@ class TradingOrchestrator:
|
||||
for symbol in self.symbols
|
||||
if len(self.recent_dqn_predictions.get(symbol, [])) > 0
|
||||
or len(self.recent_cnn_predictions.get(symbol, [])) > 0
|
||||
or len(self.recent_transformer_predictions.get(symbol, [])) > 0
|
||||
],
|
||||
}
|
||||
|
||||
@@ -2791,4 +2801,19 @@ class TradingOrchestrator:
|
||||
"""Set the trading executor for position tracking"""
|
||||
self.trading_executor = trading_executor
|
||||
logger.info("Trading executor set for position tracking and P&L feedback")
|
||||
|
||||
def store_transformer_prediction(self, symbol: str, prediction: Dict):
|
||||
"""Store a transformer prediction for visualization and tracking"""
|
||||
try:
|
||||
if symbol not in self.recent_transformer_predictions:
|
||||
self.recent_transformer_predictions[symbol] = deque(maxlen=50)
|
||||
|
||||
# Add timestamp if not present
|
||||
if 'timestamp' not in prediction:
|
||||
prediction['timestamp'] = datetime.now()
|
||||
|
||||
self.recent_transformer_predictions[symbol].append(prediction)
|
||||
logger.debug(f"Stored transformer prediction for {symbol}: {prediction.get('action', 'N/A')}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error storing transformer prediction: {e}")
|
||||
|
||||
|
||||
@@ -3284,6 +3284,7 @@ class CleanTradingDashboard:
|
||||
# 2. NEW: Add real-time model predictions overlay
|
||||
self._add_dqn_predictions_to_chart(fig, symbol, df_main, row)
|
||||
self._add_cnn_predictions_to_chart(fig, symbol, df_main, row)
|
||||
self._add_transformer_predictions_to_chart(fig, symbol, df_main, row)
|
||||
self._add_cob_rl_predictions_to_chart(fig, symbol, df_main, row)
|
||||
self._add_prediction_accuracy_feedback(fig, symbol, df_main, row)
|
||||
self._add_williams_pivots_to_chart(fig, symbol, df_main, row)
|
||||
@@ -3602,6 +3603,95 @@ class CleanTradingDashboard:
|
||||
except Exception as e:
|
||||
logger.debug(f"Error adding CNN predictions to chart: {e}")
|
||||
|
||||
def _add_transformer_predictions_to_chart(self, fig: go.Figure, symbol: str, df_main: pd.DataFrame, row: int = 1):
|
||||
"""Add Transformer price predictions as trend lines with confidence bands"""
|
||||
try:
|
||||
# Get recent Transformer predictions from orchestrator
|
||||
transformer_predictions = self._get_recent_transformer_predictions(symbol)
|
||||
|
||||
if not transformer_predictions:
|
||||
return
|
||||
|
||||
for i, pred in enumerate(transformer_predictions[-15:]): # Last 15 Transformer predictions
|
||||
confidence = pred.get('confidence', 0)
|
||||
timestamp = pred.get('timestamp', datetime.now())
|
||||
current_price = pred.get('current_price', 0)
|
||||
predicted_price = pred.get('predicted_price', current_price)
|
||||
price_change = pred.get('price_change', 0)
|
||||
|
||||
# FILTER OUT INVALID PRICES
|
||||
if (current_price is None or current_price <= 0 or
|
||||
predicted_price is None or predicted_price <= 0):
|
||||
continue
|
||||
|
||||
if confidence > 0.3: # Show predictions with reasonable confidence
|
||||
# Calculate prediction horizon (typically 5-15 minutes)
|
||||
horizon_minutes = pred.get('horizon_minutes', 10)
|
||||
end_time = timestamp + timedelta(minutes=horizon_minutes)
|
||||
|
||||
# Determine color based on price change direction
|
||||
if price_change > 0.5: # Significant UP
|
||||
color = f'rgba(0, 200, 255, {0.3 + confidence * 0.5})'
|
||||
line_color = 'cyan'
|
||||
prediction_name = 'Transformer UP'
|
||||
elif price_change < -0.5: # Significant DOWN
|
||||
color = f'rgba(255, 100, 0, {0.3 + confidence * 0.5})'
|
||||
line_color = 'orange'
|
||||
prediction_name = 'Transformer DOWN'
|
||||
else: # Small change
|
||||
color = f'rgba(150, 150, 255, {0.2 + confidence * 0.4})'
|
||||
line_color = 'lightblue'
|
||||
prediction_name = 'Transformer STABLE'
|
||||
|
||||
# Add prediction line
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=[timestamp, end_time],
|
||||
y=[current_price, predicted_price],
|
||||
mode='lines',
|
||||
line=dict(
|
||||
color=line_color,
|
||||
width=2 + confidence * 4,
|
||||
dash='dashdot'
|
||||
),
|
||||
name=f'{prediction_name}',
|
||||
showlegend=i == 0,
|
||||
hovertemplate=f"<b>{prediction_name}</b><br>" +
|
||||
"From: $%{y[0]:.2f}<br>" +
|
||||
"To: $%{y[1]:.2f}<br>" +
|
||||
"Time: %{x[0]} → %{x[1]}<br>" +
|
||||
f"Confidence: {confidence:.1%}<br>" +
|
||||
f"Change: {price_change:+.2f}%<extra></extra>"
|
||||
),
|
||||
row=row, col=1
|
||||
)
|
||||
|
||||
# Add prediction target marker
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=[end_time],
|
||||
y=[predicted_price],
|
||||
mode='markers',
|
||||
marker=dict(
|
||||
symbol='star',
|
||||
size=8 + confidence * 10,
|
||||
color=color,
|
||||
line=dict(width=2, color=line_color)
|
||||
),
|
||||
name=f'{prediction_name} Target',
|
||||
showlegend=False,
|
||||
hovertemplate=f"<b>TRANSFORMER TARGET</b><br>" +
|
||||
"Target: $%{y:.2f}<br>" +
|
||||
"Time: %{x}<br>" +
|
||||
f"Confidence: {confidence:.1%}<br>" +
|
||||
f"Expected Change: {price_change:+.2f}%<extra></extra>"
|
||||
),
|
||||
row=row, col=1
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error adding Transformer predictions to chart: {e}")
|
||||
|
||||
def _add_cob_rl_predictions_to_chart(self, fig: go.Figure, symbol: str, df_main: pd.DataFrame, row: int = 1):
|
||||
"""Add COB_RL microstructure predictions as diamond markers"""
|
||||
try:
|
||||
@@ -3895,6 +3985,35 @@ class CleanTradingDashboard:
|
||||
logger.debug(f"Error getting CNN predictions: {e}")
|
||||
return []
|
||||
|
||||
def _get_recent_transformer_predictions(self, symbol: str) -> List[Dict]:
|
||||
"""Get recent Transformer predictions from orchestrator"""
|
||||
try:
|
||||
predictions = []
|
||||
|
||||
# Get REAL predictions from orchestrator
|
||||
if hasattr(self.orchestrator, 'recent_transformer_predictions'):
|
||||
predictions.extend(list(self.orchestrator.recent_transformer_predictions.get(symbol, [])))
|
||||
|
||||
# Get from training system as additional source
|
||||
if hasattr(self, 'training_system') and self.training_system:
|
||||
if hasattr(self.training_system, 'recent_transformer_predictions'):
|
||||
predictions.extend(self.training_system.recent_transformer_predictions.get(symbol, []))
|
||||
|
||||
# Remove duplicates and sort by timestamp
|
||||
unique_predictions = []
|
||||
seen_timestamps = set()
|
||||
for pred in predictions:
|
||||
timestamp_key = pred.get('timestamp', datetime.now()).isoformat()
|
||||
if timestamp_key not in seen_timestamps:
|
||||
unique_predictions.append(pred)
|
||||
seen_timestamps.add(timestamp_key)
|
||||
|
||||
return sorted(unique_predictions, key=lambda x: x.get('timestamp', datetime.now()))
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting Transformer predictions: {e}")
|
||||
return []
|
||||
|
||||
def _get_prediction_accuracy_history(self, symbol: str) -> List[Dict]:
|
||||
"""Get REAL prediction accuracy history from validated forward-looking predictions"""
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user