adding model predictions to dash (wip)
This commit is contained in:
@ -339,9 +339,15 @@ class CleanTradingDashboard:
|
||||
def _create_price_chart(self, symbol: str) -> go.Figure:
|
||||
"""Create 1-minute main chart with 1-second mini chart - Updated every second"""
|
||||
try:
|
||||
# FIXED: Merge historical + live data instead of replacing
|
||||
# 1. Get historical 1-minute data as base (180 candles = 3 hours)
|
||||
df_historical = self.data_provider.get_historical_data(symbol, '1m', limit=180)
|
||||
# FIXED: Always get fresh data on startup to avoid gaps
|
||||
# 1. Get historical 1-minute data as base (180 candles = 3 hours) - FORCE REFRESH on first load
|
||||
is_startup = not hasattr(self, '_chart_initialized') or not self._chart_initialized
|
||||
df_historical = self.data_provider.get_historical_data(symbol, '1m', limit=180, refresh=is_startup)
|
||||
|
||||
# Mark chart as initialized to use cache on subsequent loads
|
||||
if is_startup:
|
||||
self._chart_initialized = True
|
||||
logger.info(f"[STARTUP] Fetched fresh {symbol} 1m data to avoid gaps")
|
||||
|
||||
# 2. Get WebSocket 1s data and convert to 1m bars
|
||||
ws_data_raw = self._get_websocket_chart_data(symbol, 'raw')
|
||||
@ -433,6 +439,12 @@ class CleanTradingDashboard:
|
||||
row=1, col=1
|
||||
)
|
||||
|
||||
# ADD MODEL PREDICTIONS TO MAIN CHART
|
||||
self._add_model_predictions_to_chart(fig, symbol, df_main, row=1)
|
||||
|
||||
# ADD TRADES TO MAIN CHART
|
||||
self._add_trades_to_chart(fig, symbol, df_main, row=1)
|
||||
|
||||
# Mini 1-second chart (if available)
|
||||
if has_mini_chart:
|
||||
fig.add_trace(
|
||||
@ -465,7 +477,7 @@ class CleanTradingDashboard:
|
||||
fig.update_layout(
|
||||
title=f'{symbol} Live Chart - {main_source} (Updated Every Second)',
|
||||
template='plotly_dark',
|
||||
showlegend=False,
|
||||
showlegend=True, # Show legend for model predictions
|
||||
height=chart_height,
|
||||
margin=dict(l=50, r=50, t=60, b=50),
|
||||
xaxis_rangeslider_visible=False
|
||||
@ -502,6 +514,246 @@ class CleanTradingDashboard:
|
||||
xref="paper", yref="paper",
|
||||
x=0.5, y=0.5, showarrow=False)
|
||||
|
||||
def _add_model_predictions_to_chart(self, fig: go.Figure, symbol: str, df_main: pd.DataFrame, row: int = 1):
|
||||
"""Add model predictions to the chart"""
|
||||
try:
|
||||
# Get CNN predictions from orchestrator
|
||||
if self.orchestrator and hasattr(self.orchestrator, 'get_recent_predictions'):
|
||||
try:
|
||||
cnn_predictions = self.orchestrator.get_recent_predictions(symbol)
|
||||
if cnn_predictions:
|
||||
# Separate by prediction type
|
||||
buy_predictions = []
|
||||
sell_predictions = []
|
||||
|
||||
for pred in cnn_predictions[-20:]: # Last 20 predictions
|
||||
pred_time = pred.get('timestamp')
|
||||
pred_price = pred.get('price', 0)
|
||||
pred_action = pred.get('action', 'HOLD')
|
||||
pred_confidence = pred.get('confidence', 0)
|
||||
|
||||
if pred_time and pred_price and pred_confidence > 0.5: # Only confident predictions
|
||||
if pred_action == 'BUY':
|
||||
buy_predictions.append({'x': pred_time, 'y': pred_price, 'confidence': pred_confidence})
|
||||
elif pred_action == 'SELL':
|
||||
sell_predictions.append({'x': pred_time, 'y': pred_price, 'confidence': pred_confidence})
|
||||
|
||||
# Add BUY predictions (green triangles)
|
||||
if buy_predictions:
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=[p['x'] for p in buy_predictions],
|
||||
y=[p['y'] for p in buy_predictions],
|
||||
mode='markers',
|
||||
marker=dict(
|
||||
symbol='triangle-up',
|
||||
size=12,
|
||||
color='rgba(0, 255, 100, 0.8)',
|
||||
line=dict(width=2, color='green')
|
||||
),
|
||||
name='CNN BUY Predictions',
|
||||
showlegend=True,
|
||||
hovertemplate="<b>CNN BUY Prediction</b><br>" +
|
||||
"Price: $%{y:.2f}<br>" +
|
||||
"Time: %{x}<br>" +
|
||||
"Confidence: %{customdata:.1%}<extra></extra>",
|
||||
customdata=[p['confidence'] for p in buy_predictions]
|
||||
),
|
||||
row=row, col=1
|
||||
)
|
||||
|
||||
# Add SELL predictions (red triangles)
|
||||
if sell_predictions:
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=[p['x'] for p in sell_predictions],
|
||||
y=[p['y'] for p in sell_predictions],
|
||||
mode='markers',
|
||||
marker=dict(
|
||||
symbol='triangle-down',
|
||||
size=12,
|
||||
color='rgba(255, 100, 100, 0.8)',
|
||||
line=dict(width=2, color='red')
|
||||
),
|
||||
name='CNN SELL Predictions',
|
||||
showlegend=True,
|
||||
hovertemplate="<b>CNN SELL Prediction</b><br>" +
|
||||
"Price: $%{y:.2f}<br>" +
|
||||
"Time: %{x}<br>" +
|
||||
"Confidence: %{customdata:.1%}<extra></extra>",
|
||||
customdata=[p['confidence'] for p in sell_predictions]
|
||||
),
|
||||
row=row, col=1
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not get CNN predictions: {e}")
|
||||
|
||||
# Get COB RL predictions
|
||||
if hasattr(self, 'cob_predictions') and symbol in self.cob_predictions:
|
||||
try:
|
||||
cob_preds = self.cob_predictions[symbol][-10:] # Last 10 COB predictions
|
||||
|
||||
up_predictions = []
|
||||
down_predictions = []
|
||||
|
||||
for pred in cob_preds:
|
||||
pred_time = pred.get('timestamp')
|
||||
pred_direction = pred.get('direction', 1) # 0=DOWN, 1=SIDEWAYS, 2=UP
|
||||
pred_confidence = pred.get('confidence', 0)
|
||||
|
||||
if pred_time and pred_confidence > 0.7: # Only high confidence COB predictions
|
||||
# Get price from main chart at that time
|
||||
pred_price = self._get_price_at_time(df_main, pred_time)
|
||||
if pred_price:
|
||||
if pred_direction == 2: # UP
|
||||
up_predictions.append({'x': pred_time, 'y': pred_price, 'confidence': pred_confidence})
|
||||
elif pred_direction == 0: # DOWN
|
||||
down_predictions.append({'x': pred_time, 'y': pred_price, 'confidence': pred_confidence})
|
||||
|
||||
# Add COB UP predictions (cyan diamonds)
|
||||
if up_predictions:
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=[p['x'] for p in up_predictions],
|
||||
y=[p['y'] for p in up_predictions],
|
||||
mode='markers',
|
||||
marker=dict(
|
||||
symbol='diamond',
|
||||
size=10,
|
||||
color='rgba(0, 255, 255, 0.9)',
|
||||
line=dict(width=2, color='cyan')
|
||||
),
|
||||
name='COB RL UP (1B)',
|
||||
showlegend=True,
|
||||
hovertemplate="<b>COB RL UP Prediction</b><br>" +
|
||||
"Price: $%{y:.2f}<br>" +
|
||||
"Time: %{x}<br>" +
|
||||
"Confidence: %{customdata:.1%}<br>" +
|
||||
"Model: 1B Parameters<extra></extra>",
|
||||
customdata=[p['confidence'] for p in up_predictions]
|
||||
),
|
||||
row=row, col=1
|
||||
)
|
||||
|
||||
# Add COB DOWN predictions (magenta diamonds)
|
||||
if down_predictions:
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=[p['x'] for p in down_predictions],
|
||||
y=[p['y'] for p in down_predictions],
|
||||
mode='markers',
|
||||
marker=dict(
|
||||
symbol='diamond',
|
||||
size=10,
|
||||
color='rgba(255, 0, 255, 0.9)',
|
||||
line=dict(width=2, color='magenta')
|
||||
),
|
||||
name='COB RL DOWN (1B)',
|
||||
showlegend=True,
|
||||
hovertemplate="<b>COB RL DOWN Prediction</b><br>" +
|
||||
"Price: $%{y:.2f}<br>" +
|
||||
"Time: %{x}<br>" +
|
||||
"Confidence: %{customdata:.1%}<br>" +
|
||||
"Model: 1B Parameters<extra></extra>",
|
||||
customdata=[p['confidence'] for p in down_predictions]
|
||||
),
|
||||
row=row, col=1
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not get COB predictions: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error adding model predictions to chart: {e}")
|
||||
|
||||
def _add_trades_to_chart(self, fig: go.Figure, symbol: str, df_main: pd.DataFrame, row: int = 1):
|
||||
"""Add executed trades to the chart"""
|
||||
try:
|
||||
if not self.closed_trades:
|
||||
return
|
||||
|
||||
buy_trades = []
|
||||
sell_trades = []
|
||||
|
||||
for trade in self.closed_trades[-20:]: # Last 20 trades
|
||||
entry_time = trade.get('entry_time')
|
||||
side = trade.get('side', 'UNKNOWN')
|
||||
entry_price = trade.get('entry_price', 0)
|
||||
pnl = trade.get('pnl', 0)
|
||||
|
||||
if entry_time and entry_price:
|
||||
trade_data = {'x': entry_time, 'y': entry_price, 'pnl': pnl}
|
||||
|
||||
if side == 'BUY':
|
||||
buy_trades.append(trade_data)
|
||||
elif side == 'SELL':
|
||||
sell_trades.append(trade_data)
|
||||
|
||||
# Add BUY trades (green circles)
|
||||
if buy_trades:
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=[t['x'] for t in buy_trades],
|
||||
y=[t['y'] for t in buy_trades],
|
||||
mode='markers',
|
||||
marker=dict(
|
||||
symbol='circle',
|
||||
size=8,
|
||||
color='rgba(0, 255, 0, 0.7)',
|
||||
line=dict(width=2, color='green')
|
||||
),
|
||||
name='BUY Trades',
|
||||
showlegend=True,
|
||||
hovertemplate="<b>BUY Trade Executed</b><br>" +
|
||||
"Price: $%{y:.2f}<br>" +
|
||||
"Time: %{x}<br>" +
|
||||
"P&L: $%{customdata:.2f}<extra></extra>",
|
||||
customdata=[t['pnl'] for t in buy_trades]
|
||||
),
|
||||
row=row, col=1
|
||||
)
|
||||
|
||||
# Add SELL trades (red circles)
|
||||
if sell_trades:
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=[t['x'] for t in sell_trades],
|
||||
y=[t['y'] for t in sell_trades],
|
||||
mode='markers',
|
||||
marker=dict(
|
||||
symbol='circle',
|
||||
size=8,
|
||||
color='rgba(255, 0, 0, 0.7)',
|
||||
line=dict(width=2, color='red')
|
||||
),
|
||||
name='SELL Trades',
|
||||
showlegend=True,
|
||||
hovertemplate="<b>SELL Trade Executed</b><br>" +
|
||||
"Price: $%{y:.2f}<br>" +
|
||||
"Time: %{x}<br>" +
|
||||
"P&L: $%{customdata:.2f}<extra></extra>",
|
||||
customdata=[t['pnl'] for t in sell_trades]
|
||||
),
|
||||
row=row, col=1
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error adding trades to chart: {e}")
|
||||
|
||||
def _get_price_at_time(self, df: pd.DataFrame, timestamp) -> Optional[float]:
|
||||
"""Get price from dataframe at specific timestamp"""
|
||||
try:
|
||||
if isinstance(timestamp, str):
|
||||
timestamp = pd.to_datetime(timestamp)
|
||||
|
||||
# Find closest timestamp in dataframe
|
||||
closest_idx = df.index.get_indexer([timestamp], method='nearest')[0]
|
||||
if closest_idx >= 0 and closest_idx < len(df):
|
||||
return float(df.iloc[closest_idx]['close'])
|
||||
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _get_websocket_chart_data(self, symbol: str, timeframe: str = '1m') -> Optional[pd.DataFrame]:
|
||||
"""Get WebSocket chart data - supports both 1m and 1s timeframes"""
|
||||
try:
|
||||
|
Reference in New Issue
Block a user