pivot points option in UI
This commit is contained in:
@@ -584,7 +584,6 @@ class TradingOrchestrator:
|
|||||||
return alias_to_canonical.get(name, name)
|
return alias_to_canonical.get(name, name)
|
||||||
except Exception:
|
except Exception:
|
||||||
return name
|
return name
|
||||||
|
|
||||||
def _initialize_ml_models(self):
|
def _initialize_ml_models(self):
|
||||||
"""Initialize ML models for enhanced trading"""
|
"""Initialize ML models for enhanced trading"""
|
||||||
try:
|
try:
|
||||||
@@ -738,45 +737,42 @@ class TradingOrchestrator:
|
|||||||
checkpoint_metadata = db_manager.get_best_checkpoint_metadata(
|
checkpoint_metadata = db_manager.get_best_checkpoint_metadata(
|
||||||
"enhanced_cnn"
|
"enhanced_cnn"
|
||||||
)
|
)
|
||||||
if checkpoint_metadata:
|
if checkpoint_metadata and os.path.exists(checkpoint_metadata.file_path):
|
||||||
self.model_states["cnn"]["initial_loss"] = 0.412
|
try:
|
||||||
self.model_states["cnn"]["current_loss"] = (
|
saved = torch.load(checkpoint_metadata.file_path, map_location=self.device)
|
||||||
checkpoint_metadata.performance_metrics.get("loss", 0.0187)
|
if saved and saved.get("model_state_dict"):
|
||||||
)
|
self.cnn_model.load_state_dict(saved["model_state_dict"], strict=False)
|
||||||
self.model_states["cnn"]["best_loss"] = (
|
checkpoint_loaded = True
|
||||||
checkpoint_metadata.performance_metrics.get("loss", 0.0134)
|
except Exception as load_ex:
|
||||||
)
|
logger.warning(f"CNN checkpoint load_state_dict failed: {load_ex}")
|
||||||
self.model_states["cnn"]["checkpoint_loaded"] = True
|
if not checkpoint_loaded:
|
||||||
self.model_states["cnn"][
|
# Filesystem fallback
|
||||||
"checkpoint_filename"
|
from utils.checkpoint_manager import load_best_checkpoint as _load_best_ckpt
|
||||||
] = checkpoint_metadata.checkpoint_id
|
result = _load_best_ckpt("enhanced_cnn")
|
||||||
checkpoint_loaded = True
|
if result:
|
||||||
loss_str = f"{checkpoint_metadata.performance_metrics.get('loss', 0.0):.4f}"
|
ckpt_path, meta = result
|
||||||
logger.info(
|
try:
|
||||||
f"CNN checkpoint loaded: {checkpoint_metadata.checkpoint_id} (loss={loss_str})"
|
saved = torch.load(ckpt_path, map_location=self.device)
|
||||||
)
|
if saved and saved.get("model_state_dict"):
|
||||||
|
self.cnn_model.load_state_dict(saved["model_state_dict"], strict=False)
|
||||||
|
checkpoint_loaded = True
|
||||||
|
self.model_states["cnn"]["checkpoint_filename"] = getattr(meta, "checkpoint_id", os.path.basename(ckpt_path))
|
||||||
|
except Exception as e_load:
|
||||||
|
logger.warning(f"Failed loading CNN weights from {ckpt_path}: {e_load}")
|
||||||
|
# Update model_states flags after attempts
|
||||||
|
self.model_states["cnn"]["checkpoint_loaded"] = checkpoint_loaded
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error loading CNN checkpoint: {e}")
|
logger.warning(f"Error loading CNN checkpoint: {e}")
|
||||||
# Filesystem fallback
|
checkpoint_loaded = False
|
||||||
try:
|
|
||||||
from utils.checkpoint_manager import get_checkpoint_manager
|
|
||||||
cm = get_checkpoint_manager()
|
|
||||||
result = cm.load_best_checkpoint("enhanced_cnn")
|
|
||||||
if result:
|
|
||||||
model_path, meta = result
|
|
||||||
self.model_states["cnn"]["checkpoint_loaded"] = True
|
|
||||||
self.model_states["cnn"]["checkpoint_filename"] = getattr(meta, 'checkpoint_id', None)
|
|
||||||
checkpoint_loaded = True
|
|
||||||
logger.info(f"CNN checkpoint (fs) detected: {getattr(meta, 'checkpoint_id', 'unknown')}")
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if not checkpoint_loaded:
|
if not checkpoint_loaded:
|
||||||
# New model - no synthetic data
|
# New model - no synthetic data
|
||||||
self.model_states["cnn"]["initial_loss"] = None
|
self.model_states["cnn"]["initial_loss"] = None
|
||||||
self.model_states["cnn"]["current_loss"] = None
|
self.model_states["cnn"]["current_loss"] = None
|
||||||
self.model_states["cnn"]["best_loss"] = None
|
self.model_states["cnn"]["best_loss"] = None
|
||||||
logger.info("CNN starting fresh - no checkpoint found")
|
self.model_states["cnn"]["checkpoint_loaded"] = False
|
||||||
|
logger.info("CNN starting fresh - no checkpoint found or failed to load")
|
||||||
|
else:
|
||||||
|
logger.info("CNN weights loaded from checkpoint successfully")
|
||||||
|
|
||||||
logger.info("Enhanced CNN model initialized directly")
|
logger.info("Enhanced CNN model initialized directly")
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -1339,7 +1335,6 @@ class TradingOrchestrator:
|
|||||||
}
|
}
|
||||||
|
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
def clear_session_data(self):
|
def clear_session_data(self):
|
||||||
"""Clear all session-related data for fresh start"""
|
"""Clear all session-related data for fresh start"""
|
||||||
try:
|
try:
|
||||||
@@ -2122,7 +2117,6 @@ class TradingOrchestrator:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error registering model {model.name}: {e}")
|
logger.error(f"Error registering model {model.name}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def unregister_model(self, model_name: str) -> bool:
|
def unregister_model(self, model_name: str) -> bool:
|
||||||
"""Unregister a model"""
|
"""Unregister a model"""
|
||||||
try:
|
try:
|
||||||
@@ -3540,7 +3534,6 @@ class TradingOrchestrator:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in immediate training for {model_name}: {e}")
|
logger.error(f"Error in immediate training for {model_name}: {e}")
|
||||||
|
|
||||||
async def _evaluate_and_train_on_record(self, record: Dict, current_price: float):
|
async def _evaluate_and_train_on_record(self, record: Dict, current_price: float):
|
||||||
"""Evaluate prediction outcome and train model"""
|
"""Evaluate prediction outcome and train model"""
|
||||||
try:
|
try:
|
||||||
@@ -5779,7 +5772,6 @@ class TradingOrchestrator:
|
|||||||
if symbol in self.recent_decisions:
|
if symbol in self.recent_decisions:
|
||||||
return self.recent_decisions[symbol][-limit:]
|
return self.recent_decisions[symbol][-limit:]
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def get_performance_metrics(self) -> Dict[str, Any]:
|
def get_performance_metrics(self) -> Dict[str, Any]:
|
||||||
"""Get performance metrics for the orchestrator"""
|
"""Get performance metrics for the orchestrator"""
|
||||||
return {
|
return {
|
||||||
@@ -6579,7 +6571,6 @@ class TradingOrchestrator:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error adding decision fusion training sample: {e}")
|
logger.error(f"Error adding decision fusion training sample: {e}")
|
||||||
|
|
||||||
def _train_decision_fusion_network(self):
|
def _train_decision_fusion_network(self):
|
||||||
"""Train the decision fusion network on collected data"""
|
"""Train the decision fusion network on collected data"""
|
||||||
try:
|
try:
|
||||||
@@ -8133,7 +8124,6 @@ class TradingOrchestrator:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error initializing checkpoint manager: {e}")
|
logger.error(f"Error initializing checkpoint manager: {e}")
|
||||||
self.checkpoint_manager = None
|
self.checkpoint_manager = None
|
||||||
|
|
||||||
def autosave_models(self):
|
def autosave_models(self):
|
||||||
"""Attempt to autosave best model checkpoints periodically."""
|
"""Attempt to autosave best model checkpoints periodically."""
|
||||||
try:
|
try:
|
||||||
@@ -8990,4 +8980,4 @@ class TradingOrchestrator:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in fallback data strategy: {e}")
|
logger.error(f"Error in fallback data strategy: {e}")
|
||||||
return False
|
return False
|
@@ -1280,13 +1280,15 @@ class CleanTradingDashboard:
|
|||||||
|
|
||||||
@self.app.callback(
|
@self.app.callback(
|
||||||
Output('price-chart', 'figure'),
|
Output('price-chart', 'figure'),
|
||||||
[Input('interval-component', 'n_intervals')],
|
[Input('interval-component', 'n_intervals'),
|
||||||
|
Input('show-pivots-switch', 'value')],
|
||||||
[State('price-chart', 'relayoutData')]
|
[State('price-chart', 'relayoutData')]
|
||||||
)
|
)
|
||||||
def update_price_chart(n, relayout_data):
|
def update_price_chart(n, pivots_value, relayout_data):
|
||||||
"""Update price chart every second, persisting user zoom/pan"""
|
"""Update price chart every second, persisting user zoom/pan"""
|
||||||
try:
|
try:
|
||||||
fig = self._create_price_chart('ETH/USDT')
|
show_pivots = bool(pivots_value and 'enabled' in pivots_value)
|
||||||
|
fig = self._create_price_chart('ETH/USDT', show_pivots=show_pivots)
|
||||||
|
|
||||||
if relayout_data:
|
if relayout_data:
|
||||||
if 'xaxis.range[0]' in relayout_data and 'xaxis.range[1]' in relayout_data:
|
if 'xaxis.range[0]' in relayout_data and 'xaxis.range[1]' in relayout_data:
|
||||||
@@ -1300,6 +1302,15 @@ class CleanTradingDashboard:
|
|||||||
return go.Figure().add_annotation(text=f"Chart Error: {str(e)}",
|
return go.Figure().add_annotation(text=f"Chart Error: {str(e)}",
|
||||||
xref="paper", yref="paper",
|
xref="paper", yref="paper",
|
||||||
x=0.5, y=0.5, showarrow=False)
|
x=0.5, y=0.5, showarrow=False)
|
||||||
|
|
||||||
|
# Display state label for pivots toggle
|
||||||
|
@self.app.callback(
|
||||||
|
Output('pivots-display', 'children'),
|
||||||
|
[Input('show-pivots-switch', 'value')]
|
||||||
|
)
|
||||||
|
def update_pivots_display(value):
|
||||||
|
enabled = bool(value and 'enabled' in value)
|
||||||
|
return "ON" if enabled else "OFF"
|
||||||
|
|
||||||
@self.app.callback(
|
@self.app.callback(
|
||||||
Output('closed-trades-table', 'children'),
|
Output('closed-trades-table', 'children'),
|
||||||
@@ -1651,7 +1662,7 @@ class CleanTradingDashboard:
|
|||||||
elif hasattr(panel, 'render'):
|
elif hasattr(panel, 'render'):
|
||||||
panel_content = panel.render()
|
panel_content = panel.render()
|
||||||
else:
|
else:
|
||||||
panel_content = html.Div([html.Div("Training panel not available", className="text-muted small")])
|
panel_content = [html.Div("Training panel not available", className="text-muted small")]
|
||||||
|
|
||||||
logger.info("Successfully created training metrics panel")
|
logger.info("Successfully created training metrics panel")
|
||||||
return panel_content
|
return panel_content
|
||||||
@@ -1663,10 +1674,10 @@ class CleanTradingDashboard:
|
|||||||
logger.error(f"Error updating training metrics with new panel: {e}")
|
logger.error(f"Error updating training metrics with new panel: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||||
return html.Div([
|
return [
|
||||||
html.P("Error loading training panel", className="text-danger small"),
|
html.P("Error loading training panel", className="text-danger small"),
|
||||||
html.P(f"Details: {str(e)}", className="text-muted small")
|
html.P(f"Details: {str(e)}", className="text-muted small")
|
||||||
], id="training-metrics")
|
]
|
||||||
|
|
||||||
# Universal model toggle callback using pattern matching
|
# Universal model toggle callback using pattern matching
|
||||||
@self.app.callback(
|
@self.app.callback(
|
||||||
@@ -2234,7 +2245,7 @@ class CleanTradingDashboard:
|
|||||||
# Return None if absolutely nothing available
|
# Return None if absolutely nothing available
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _create_price_chart(self, symbol: str) -> go.Figure:
|
def _create_price_chart(self, symbol: str, show_pivots: bool = True) -> go.Figure:
|
||||||
"""Create 1-minute main chart with 1-second mini chart - Updated every second"""
|
"""Create 1-minute main chart with 1-second mini chart - Updated every second"""
|
||||||
try:
|
try:
|
||||||
# FIXED: Always get fresh data on startup to avoid gaps
|
# FIXED: Always get fresh data on startup to avoid gaps
|
||||||
@@ -2404,63 +2415,64 @@ class CleanTradingDashboard:
|
|||||||
self._add_trades_to_chart(fig, symbol, df_main, row=1)
|
self._add_trades_to_chart(fig, symbol, df_main, row=1)
|
||||||
|
|
||||||
# ADD PIVOT POINTS TO MAIN CHART (overlay on 1m)
|
# ADD PIVOT POINTS TO MAIN CHART (overlay on 1m)
|
||||||
try:
|
if show_pivots:
|
||||||
pivots_input = None
|
try:
|
||||||
if hasattr(self.data_provider, 'get_base_data_input'):
|
pivots_input = None
|
||||||
bdi = self.data_provider.get_base_data_input(symbol)
|
if hasattr(self.data_provider, 'get_base_data_input'):
|
||||||
if bdi and getattr(bdi, 'pivot_points', None):
|
bdi = self.data_provider.get_base_data_input(symbol)
|
||||||
pivots_input = bdi.pivot_points
|
if bdi and getattr(bdi, 'pivot_points', None):
|
||||||
if pivots_input:
|
pivots_input = bdi.pivot_points
|
||||||
# Filter pivots within the visible time range of df_main
|
if pivots_input:
|
||||||
start_ts = df_main.index.min()
|
# Filter pivots within the visible time range of df_main
|
||||||
end_ts = df_main.index.max()
|
start_ts = df_main.index.min()
|
||||||
xs_high = []
|
end_ts = df_main.index.max()
|
||||||
ys_high = []
|
xs_high = []
|
||||||
xs_low = []
|
ys_high = []
|
||||||
ys_low = []
|
xs_low = []
|
||||||
for p in pivots_input:
|
ys_low = []
|
||||||
ts = getattr(p, 'timestamp', None)
|
for p in pivots_input:
|
||||||
price = getattr(p, 'price', None)
|
ts = getattr(p, 'timestamp', None)
|
||||||
ptype = getattr(p, 'type', 'low')
|
price = getattr(p, 'price', None)
|
||||||
if ts is None or price is None:
|
ptype = getattr(p, 'type', 'low')
|
||||||
continue
|
if ts is None or price is None:
|
||||||
# Convert pivot timestamp to local tz and make tz-naive to match chart axes
|
continue
|
||||||
try:
|
# Convert pivot timestamp to local tz and make tz-naive to match chart axes
|
||||||
if hasattr(ts, 'tzinfo') and ts.tzinfo is not None:
|
|
||||||
pt = ts.astimezone(_local_tz) if _local_tz else ts
|
|
||||||
else:
|
|
||||||
# Assume UTC then convert
|
|
||||||
pt = ts.replace(tzinfo=timezone.utc)
|
|
||||||
pt = pt.astimezone(_local_tz) if _local_tz else pt
|
|
||||||
# Drop tzinfo for plotting
|
|
||||||
try:
|
try:
|
||||||
pt = pt.replace(tzinfo=None)
|
if hasattr(ts, 'tzinfo') and ts.tzinfo is not None:
|
||||||
|
pt = ts.astimezone(_local_tz) if _local_tz else ts
|
||||||
|
else:
|
||||||
|
# Assume UTC then convert
|
||||||
|
pt = ts.replace(tzinfo=timezone.utc)
|
||||||
|
pt = pt.astimezone(_local_tz) if _local_tz else pt
|
||||||
|
# Drop tzinfo for plotting
|
||||||
|
try:
|
||||||
|
pt = pt.replace(tzinfo=None)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pt = ts
|
||||||
except Exception:
|
if start_ts <= pt <= end_ts:
|
||||||
pt = ts
|
if str(ptype).lower() == 'high':
|
||||||
if start_ts <= pt <= end_ts:
|
xs_high.append(pt)
|
||||||
if str(ptype).lower() == 'high':
|
ys_high.append(price)
|
||||||
xs_high.append(pt)
|
else:
|
||||||
ys_high.append(price)
|
xs_low.append(pt)
|
||||||
else:
|
ys_low.append(price)
|
||||||
xs_low.append(pt)
|
if xs_high or xs_low:
|
||||||
ys_low.append(price)
|
fig.add_trace(
|
||||||
if xs_high or xs_low:
|
go.Scatter(x=xs_high, y=ys_high, mode='markers', name='Pivot High',
|
||||||
fig.add_trace(
|
marker=dict(color='#ff7043', size=7, symbol='triangle-up'),
|
||||||
go.Scatter(x=xs_high, y=ys_high, mode='markers', name='Pivot High',
|
hoverinfo='skip'),
|
||||||
marker=dict(color='#ff7043', size=7, symbol='triangle-up'),
|
row=1, col=1
|
||||||
hoverinfo='skip'),
|
)
|
||||||
row=1, col=1
|
fig.add_trace(
|
||||||
)
|
go.Scatter(x=xs_low, y=ys_low, mode='markers', name='Pivot Low',
|
||||||
fig.add_trace(
|
marker=dict(color='#42a5f5', size=7, symbol='triangle-down'),
|
||||||
go.Scatter(x=xs_low, y=ys_low, mode='markers', name='Pivot Low',
|
hoverinfo='skip'),
|
||||||
marker=dict(color='#42a5f5', size=7, symbol='triangle-down'),
|
row=1, col=1
|
||||||
hoverinfo='skip'),
|
)
|
||||||
row=1, col=1
|
except Exception as e:
|
||||||
)
|
logger.debug(f"Error overlaying pivot points: {e}")
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"Error overlaying pivot points: {e}")
|
|
||||||
|
|
||||||
# Mini 1-second chart (if available)
|
# Mini 1-second chart (if available)
|
||||||
if has_mini_chart and ws_data_1s is not None:
|
if has_mini_chart and ws_data_1s is not None:
|
||||||
|
@@ -226,6 +226,21 @@ class DashboardLayoutManager:
|
|||||||
html.Hr(className="my-2"),
|
html.Hr(className="my-2"),
|
||||||
|
|
||||||
# Leverage Control
|
# Leverage Control
|
||||||
|
html.Div([
|
||||||
|
html.Label([
|
||||||
|
html.I(className="fas fa-compass me-1"),
|
||||||
|
"Show Pivot Points: ",
|
||||||
|
html.Span(id="pivots-display", children="ON", className="fw-bold text-success")
|
||||||
|
], className="form-label small mb-1"),
|
||||||
|
dcc.Checklist(
|
||||||
|
id='show-pivots-switch',
|
||||||
|
options=[{'label': '', 'value': 'enabled'}],
|
||||||
|
value=['enabled'],
|
||||||
|
className="form-check-input"
|
||||||
|
),
|
||||||
|
html.Small("Toggle pivot overlays on the chart", className="text-muted d-block")
|
||||||
|
], className="mb-2"),
|
||||||
|
|
||||||
html.Div([
|
html.Div([
|
||||||
html.Label([
|
html.Label([
|
||||||
html.I(className="fas fa-sliders-h me-1"),
|
html.I(className="fas fa-sliders-h me-1"),
|
||||||
|
@@ -21,7 +21,7 @@ class ModelsTrainingPanel:
|
|||||||
def __init__(self, orchestrator=None):
|
def __init__(self, orchestrator=None):
|
||||||
self.orchestrator = orchestrator
|
self.orchestrator = orchestrator
|
||||||
|
|
||||||
def create_panel(self) -> html.Div:
|
def create_panel(self) -> Any:
|
||||||
try:
|
try:
|
||||||
data = self._gather_data()
|
data = self._gather_data()
|
||||||
|
|
||||||
@@ -34,12 +34,13 @@ class ModelsTrainingPanel:
|
|||||||
if data.get("system_metrics"):
|
if data.get("system_metrics"):
|
||||||
content.append(self._create_system_metrics_section(data["system_metrics"]))
|
content.append(self._create_system_metrics_section(data["system_metrics"]))
|
||||||
|
|
||||||
return html.Div(content, id="training-metrics")
|
# Return children (to be assigned to 'training-metrics' container)
|
||||||
|
return content
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error creating models training panel: {e}")
|
logger.error(f"Error creating models training panel: {e}")
|
||||||
return html.Div([
|
return [
|
||||||
html.P(f"Error loading training panel: {str(e)}", className="text-danger small")
|
html.P(f"Error loading training panel: {str(e)}", className="text-danger small")
|
||||||
], id="training-metrics")
|
]
|
||||||
|
|
||||||
def _gather_data(self) -> Dict[str, Any]:
|
def _gather_data(self) -> Dict[str, Any]:
|
||||||
result: Dict[str, Any] = {"models": {}, "training_status": {}, "system_metrics": {}}
|
result: Dict[str, Any] = {"models": {}, "training_status": {}, "system_metrics": {}}
|
||||||
|
Reference in New Issue
Block a user