extrema trainer WIP
This commit is contained in:
@ -2625,10 +2625,114 @@ class CleanTradingDashboard:
|
||||
self._add_cnn_predictions_to_chart(fig, symbol, df_main, row)
|
||||
self._add_cob_rl_predictions_to_chart(fig, symbol, df_main, row)
|
||||
self._add_prediction_accuracy_feedback(fig, symbol, df_main, row)
|
||||
self._add_williams_pivots_to_chart(fig, symbol, df_main, row)
|
||||
|
||||
# 2b. Overlay extrema next-pivot prediction vector (within 5 minutes)
|
||||
try:
|
||||
latest_predictions = self._get_latest_model_predictions()
|
||||
extrema_pred = latest_predictions.get('extrema_trainer', {})
|
||||
pred_obj = extrema_pred.get('prediction') if isinstance(extrema_pred, dict) else None
|
||||
if pred_obj and 'predicted_time' in pred_obj and 'predicted_price' in pred_obj:
|
||||
# Build a short vector from now/current price to predicted pivot
|
||||
now_ts = df_main.index[-1] if not df_main.empty else datetime.now()
|
||||
current_price = float(df_main['close'].iloc[-1]) if not df_main.empty else self._get_current_price(symbol)
|
||||
pt = pred_obj['predicted_time']
|
||||
# Ensure datetime
|
||||
if not isinstance(pt, datetime):
|
||||
try:
|
||||
pt = pd.to_datetime(pt)
|
||||
except Exception:
|
||||
pt = now_ts
|
||||
pp = float(pred_obj['predicted_price'])
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=[now_ts, pt],
|
||||
y=[current_price, pp],
|
||||
mode='lines+markers',
|
||||
line=dict(color='#cddc39', width=2, dash='dot'),
|
||||
marker=dict(size=6, color='#cddc39'),
|
||||
name='Next Pivot (<=5m)'
|
||||
),
|
||||
row=row, col=1
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error overlaying extrema pivot vector: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error adding model predictions to chart: {e}")
|
||||
|
||||
def _add_williams_pivots_to_chart(self, fig: go.Figure, symbol: str, df_main: pd.DataFrame, row: int = 1):
|
||||
"""Overlay first two levels of Williams pivot points on the 1m chart."""
|
||||
try:
|
||||
if not hasattr(self.data_provider, 'get_recent_pivot_points'):
|
||||
return
|
||||
pivots_lvl1 = self.data_provider.get_recent_pivot_points(symbol, level=1, count=50) or []
|
||||
pivots_lvl2 = self.data_provider.get_recent_pivot_points(symbol, level=2, count=50) or []
|
||||
|
||||
# Normalize timestamps to match chart timezone (avoid 3h offset)
|
||||
try:
|
||||
from datetime import timezone as _dt_tz
|
||||
local_tz = datetime.now().astimezone().tzinfo
|
||||
|
||||
def normalize_ts(ts: datetime) -> datetime:
|
||||
try:
|
||||
if ts is None:
|
||||
return ts
|
||||
if getattr(df_main.index, 'tz', None) is not None:
|
||||
# Match DataFrame index timezone
|
||||
if ts.tzinfo is None:
|
||||
ts = ts.replace(tzinfo=_dt_tz.utc)
|
||||
return ts.astimezone(df_main.index.tz)
|
||||
# Chart is tz-naive: convert to local and drop tz
|
||||
if ts.tzinfo is None:
|
||||
# Treat naive timestamp as UTC coming from server calculations
|
||||
ts = ts.replace(tzinfo=_dt_tz.utc)
|
||||
ts_local = ts.astimezone(local_tz)
|
||||
return ts_local.replace(tzinfo=None)
|
||||
except Exception:
|
||||
return ts
|
||||
except Exception:
|
||||
def normalize_ts(ts: datetime) -> datetime:
|
||||
return ts
|
||||
|
||||
def to_xy(pivots):
|
||||
xs_h, ys_h, xs_l, ys_l = [], [], [], []
|
||||
for p in pivots:
|
||||
ts = getattr(p, 'timestamp', None)
|
||||
price = getattr(p, 'price', None)
|
||||
ptype = getattr(p, 'type', getattr(p, 'pivot_type', 'low'))
|
||||
if ts and price:
|
||||
ts = normalize_ts(ts)
|
||||
if str(ptype).lower() == 'high':
|
||||
xs_h.append(ts); ys_h.append(price)
|
||||
else:
|
||||
xs_l.append(ts); ys_l.append(price)
|
||||
return xs_h, ys_h, xs_l, ys_l
|
||||
|
||||
l1xh, l1yh, l1xl, l1yl = to_xy(pivots_lvl1)
|
||||
l2xh, l2yh, l2xl, l2yl = to_xy(pivots_lvl2)
|
||||
|
||||
if l1xh or l1xl:
|
||||
if l1xh:
|
||||
fig.add_trace(go.Scatter(x=l1xh, y=l1yh, mode='markers', name='L1 High',
|
||||
marker=dict(color='#ff7043', size=7, symbol='triangle-up'), hoverinfo='skip'),
|
||||
row=row, col=1)
|
||||
if l1xl:
|
||||
fig.add_trace(go.Scatter(x=l1xl, y=l1yl, mode='markers', name='L1 Low',
|
||||
marker=dict(color='#42a5f5', size=7, symbol='triangle-down'), hoverinfo='skip'),
|
||||
row=row, col=1)
|
||||
if l2xh or l2xl:
|
||||
if l2xh:
|
||||
fig.add_trace(go.Scatter(x=l2xh, y=l2yh, mode='markers', name='L2 High',
|
||||
marker=dict(color='#ef6c00', size=6, symbol='triangle-up-open'), hoverinfo='skip'),
|
||||
row=row, col=1)
|
||||
if l2xl:
|
||||
fig.add_trace(go.Scatter(x=l2xl, y=l2yl, mode='markers', name='L2 Low',
|
||||
marker=dict(color='#1e88e5', size=6, symbol='triangle-down-open'), hoverinfo='skip'),
|
||||
row=row, col=1)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error overlaying Williams pivots: {e}")
|
||||
|
||||
def _add_dqn_predictions_to_chart(self, fig: go.Figure, symbol: str, df_main: pd.DataFrame, row: int = 1):
|
||||
"""Add DQN action predictions as directional arrows"""
|
||||
try:
|
||||
@ -5900,7 +6004,7 @@ class CleanTradingDashboard:
|
||||
state_features = self._get_dqn_state_features(symbol, current_price)
|
||||
if hasattr(self.orchestrator.rl_agent, 'predict'):
|
||||
return self.orchestrator.rl_agent.predict(state_features)
|
||||
return 0.5 # Default neutral prediction
|
||||
return 0.0
|
||||
except:
|
||||
return 0.5
|
||||
|
||||
@ -5914,7 +6018,7 @@ class CleanTradingDashboard:
|
||||
if features:
|
||||
import numpy as np
|
||||
return self.orchestrator.cnn_model.predict(np.array([features]))
|
||||
return 0.5 # Default neutral prediction
|
||||
return 0.0
|
||||
except:
|
||||
return 0.5
|
||||
|
||||
@ -5928,7 +6032,7 @@ class CleanTradingDashboard:
|
||||
if features:
|
||||
import numpy as np
|
||||
return self.orchestrator.primary_transformer.predict(np.array([features]))
|
||||
return 0.5 # Default neutral prediction
|
||||
return 0.0
|
||||
except:
|
||||
return 0.5
|
||||
|
||||
@ -5940,7 +6044,7 @@ class CleanTradingDashboard:
|
||||
if cob_features and hasattr(self.orchestrator.cob_rl_agent, 'predict'):
|
||||
import numpy as np
|
||||
return self.orchestrator.cob_rl_agent.predict(np.array([cob_features]))
|
||||
return 0.5 # Default neutral prediction
|
||||
return 0.0
|
||||
except:
|
||||
return 0.5
|
||||
|
||||
|
@ -442,20 +442,12 @@ class DashboardComponentManager:
|
||||
extras.append(html.Small(f"Recent ticks: {len(recent)}", className="text-muted ms-2"))
|
||||
extras_div = html.Div(extras, className="mb-1") if extras else None
|
||||
|
||||
# Insert mini heatmap inside the COB panel (right side)
|
||||
# Heatmap is rendered in dedicated tiles (avoid duplicate component IDs)
|
||||
heatmap_graph = None
|
||||
try:
|
||||
# The dashboard's data provider is accessible through a global reference on the dashboard instance.
|
||||
# We embed a placeholder Graph here; actual figure is provided by the dashboard callback tied to this id.
|
||||
graph_id = 'cob-heatmap-eth' if 'ETH' in symbol else 'cob-heatmap-btc'
|
||||
heatmap_graph = dcc.Graph(id=graph_id, config={'displayModeBar': False}, style={"height": "220px"})
|
||||
except Exception:
|
||||
heatmap_graph = None
|
||||
|
||||
children = [dbc.Col(overview_panel, width=5, className="pe-1")]
|
||||
right_children = []
|
||||
if heatmap_graph:
|
||||
right_children.append(heatmap_graph)
|
||||
# Do not append inline heatmap here to prevent duplicate IDs
|
||||
right_children.append(ladder_panel)
|
||||
if extras_div:
|
||||
right_children.insert(0, extras_div)
|
||||
|
Reference in New Issue
Block a user