store and show (wip) predictions

This commit is contained in:
Dobromir Popov
2025-11-19 10:38:16 +02:00
parent 2d1d036c07
commit 8ee8558829
2 changed files with 144 additions and 0 deletions

View File

@@ -372,6 +372,9 @@ class TradingOrchestrator:
self.recent_cnn_predictions: Dict[str, deque] = (
{}
) # {symbol: List[Dict]} - Recent CNN predictions
self.recent_transformer_predictions: Dict[str, deque] = (
{}
) # {symbol: List[Dict]} - Recent Transformer predictions
self.prediction_accuracy_history: Dict[str, deque] = (
{}
) # {symbol: List[Dict]} - Prediction accuracy tracking
@@ -379,6 +382,7 @@ class TradingOrchestrator:
# Initialize prediction tracking for the primary trading symbol only
self.recent_dqn_predictions[self.symbol] = deque(maxlen=100)
self.recent_cnn_predictions[self.symbol] = deque(maxlen=50)
self.recent_transformer_predictions[self.symbol] = deque(maxlen=50)
self.prediction_accuracy_history[self.symbol] = deque(maxlen=200)
self.signal_accumulator[self.symbol] = []
@@ -1109,6 +1113,8 @@ class TradingOrchestrator:
self.recent_dqn_predictions[symbol].clear()
for symbol in self.recent_cnn_predictions:
self.recent_cnn_predictions[symbol].clear()
for symbol in self.recent_transformer_predictions:
self.recent_transformer_predictions[symbol].clear()
for symbol in self.prediction_accuracy_history:
self.prediction_accuracy_history[symbol].clear()
@@ -2518,6 +2524,9 @@ class TradingOrchestrator:
"cnn_predictions_tracked": sum(
len(preds) for preds in self.recent_cnn_predictions.values()
),
"transformer_predictions_tracked": sum(
len(preds) for preds in self.recent_transformer_predictions.values()
),
"accuracy_history_tracked": sum(
len(history)
for history in self.prediction_accuracy_history.values()
@@ -2527,6 +2536,7 @@ class TradingOrchestrator:
for symbol in self.symbols
if len(self.recent_dqn_predictions.get(symbol, [])) > 0
or len(self.recent_cnn_predictions.get(symbol, [])) > 0
or len(self.recent_transformer_predictions.get(symbol, [])) > 0
],
}
@@ -2792,3 +2802,18 @@ class TradingOrchestrator:
self.trading_executor = trading_executor
logger.info("Trading executor set for position tracking and P&L feedback")
def store_transformer_prediction(self, symbol: str, prediction: Dict):
"""Store a transformer prediction for visualization and tracking"""
try:
if symbol not in self.recent_transformer_predictions:
self.recent_transformer_predictions[symbol] = deque(maxlen=50)
# Add timestamp if not present
if 'timestamp' not in prediction:
prediction['timestamp'] = datetime.now()
self.recent_transformer_predictions[symbol].append(prediction)
logger.debug(f"Stored transformer prediction for {symbol}: {prediction.get('action', 'N/A')}")
except Exception as e:
logger.error(f"Error storing transformer prediction: {e}")

View File

@@ -3284,6 +3284,7 @@ class CleanTradingDashboard:
# 2. NEW: Add real-time model predictions overlay
self._add_dqn_predictions_to_chart(fig, symbol, df_main, row)
self._add_cnn_predictions_to_chart(fig, symbol, df_main, row)
self._add_transformer_predictions_to_chart(fig, symbol, df_main, row)
self._add_cob_rl_predictions_to_chart(fig, symbol, df_main, row)
self._add_prediction_accuracy_feedback(fig, symbol, df_main, row)
self._add_williams_pivots_to_chart(fig, symbol, df_main, row)
@@ -3602,6 +3603,95 @@ class CleanTradingDashboard:
except Exception as e:
logger.debug(f"Error adding CNN predictions to chart: {e}")
def _add_transformer_predictions_to_chart(self, fig: go.Figure, symbol: str, df_main: pd.DataFrame, row: int = 1):
"""Add Transformer price predictions as trend lines with confidence bands"""
try:
# Get recent Transformer predictions from orchestrator
transformer_predictions = self._get_recent_transformer_predictions(symbol)
if not transformer_predictions:
return
for i, pred in enumerate(transformer_predictions[-15:]): # Last 15 Transformer predictions
confidence = pred.get('confidence', 0)
timestamp = pred.get('timestamp', datetime.now())
current_price = pred.get('current_price', 0)
predicted_price = pred.get('predicted_price', current_price)
price_change = pred.get('price_change', 0)
# FILTER OUT INVALID PRICES
if (current_price is None or current_price <= 0 or
predicted_price is None or predicted_price <= 0):
continue
if confidence > 0.3: # Show predictions with reasonable confidence
# Calculate prediction horizon (typically 5-15 minutes)
horizon_minutes = pred.get('horizon_minutes', 10)
end_time = timestamp + timedelta(minutes=horizon_minutes)
# Determine color based on price change direction
if price_change > 0.5: # Significant UP
color = f'rgba(0, 200, 255, {0.3 + confidence * 0.5})'
line_color = 'cyan'
prediction_name = 'Transformer UP'
elif price_change < -0.5: # Significant DOWN
color = f'rgba(255, 100, 0, {0.3 + confidence * 0.5})'
line_color = 'orange'
prediction_name = 'Transformer DOWN'
else: # Small change
color = f'rgba(150, 150, 255, {0.2 + confidence * 0.4})'
line_color = 'lightblue'
prediction_name = 'Transformer STABLE'
# Add prediction line
fig.add_trace(
go.Scatter(
x=[timestamp, end_time],
y=[current_price, predicted_price],
mode='lines',
line=dict(
color=line_color,
width=2 + confidence * 4,
dash='dashdot'
),
name=f'{prediction_name}',
showlegend=i == 0,
hovertemplate=f"<b>{prediction_name}</b><br>" +
"From: $%{y[0]:.2f}<br>" +
"To: $%{y[1]:.2f}<br>" +
"Time: %{x[0]}%{x[1]}<br>" +
f"Confidence: {confidence:.1%}<br>" +
f"Change: {price_change:+.2f}%<extra></extra>"
),
row=row, col=1
)
# Add prediction target marker
fig.add_trace(
go.Scatter(
x=[end_time],
y=[predicted_price],
mode='markers',
marker=dict(
symbol='star',
size=8 + confidence * 10,
color=color,
line=dict(width=2, color=line_color)
),
name=f'{prediction_name} Target',
showlegend=False,
hovertemplate=f"<b>TRANSFORMER TARGET</b><br>" +
"Target: $%{y:.2f}<br>" +
"Time: %{x}<br>" +
f"Confidence: {confidence:.1%}<br>" +
f"Expected Change: {price_change:+.2f}%<extra></extra>"
),
row=row, col=1
)
except Exception as e:
logger.debug(f"Error adding Transformer predictions to chart: {e}")
def _add_cob_rl_predictions_to_chart(self, fig: go.Figure, symbol: str, df_main: pd.DataFrame, row: int = 1):
"""Add COB_RL microstructure predictions as diamond markers"""
try:
@@ -3895,6 +3985,35 @@ class CleanTradingDashboard:
logger.debug(f"Error getting CNN predictions: {e}")
return []
def _get_recent_transformer_predictions(self, symbol: str) -> List[Dict]:
"""Get recent Transformer predictions from orchestrator"""
try:
predictions = []
# Get REAL predictions from orchestrator
if hasattr(self.orchestrator, 'recent_transformer_predictions'):
predictions.extend(list(self.orchestrator.recent_transformer_predictions.get(symbol, [])))
# Get from training system as additional source
if hasattr(self, 'training_system') and self.training_system:
if hasattr(self.training_system, 'recent_transformer_predictions'):
predictions.extend(self.training_system.recent_transformer_predictions.get(symbol, []))
# Remove duplicates and sort by timestamp
unique_predictions = []
seen_timestamps = set()
for pred in predictions:
timestamp_key = pred.get('timestamp', datetime.now()).isoformat()
if timestamp_key not in seen_timestamps:
unique_predictions.append(pred)
seen_timestamps.add(timestamp_key)
return sorted(unique_predictions, key=lambda x: x.get('timestamp', datetime.now()))
except Exception as e:
logger.debug(f"Error getting Transformer predictions: {e}")
return []
def _get_prediction_accuracy_history(self, symbol: str) -> List[Dict]:
"""Get REAL prediction accuracy history from validated forward-looking predictions"""
try: