From b2faa9b6caf775111ef9df399bf4a370381ca338 Mon Sep 17 00:00:00 2001 From: Dobromir Popov Date: Sun, 10 Aug 2025 03:20:13 +0300 Subject: [PATCH] extrema trainer WIP --- core/extrema_trainer.py | 177 +++++++++++++++++++++++++++++++++++++++ core/orchestrator.py | 24 ++++-- data/ui_state.json | 25 +++--- web/clean_dashboard.py | 112 ++++++++++++++++++++++++- web/component_manager.py | 12 +-- 5 files changed, 315 insertions(+), 35 deletions(-) diff --git a/core/extrema_trainer.py b/core/extrema_trainer.py index f68777e..07926fc 100644 --- a/core/extrema_trainer.py +++ b/core/extrema_trainer.py @@ -43,6 +43,22 @@ class ExtremaPoint: market_context: Dict[str, Any] outcome: Optional[float] = None # Price change after extrema +@dataclass +class PredictedPivot: + """Represents a prediction of the next pivot point within a capped horizon""" + symbol: str + created_at: datetime + current_price: float + predicted_time: datetime + predicted_price: float + horizon_seconds: int + target_type: str # 'top' or 'bottom' + confidence: float + evaluated: bool = False + success: Optional[bool] = None + error_abs: Optional[float] = None # absolute price error at eval + time_error_s: Optional[int] = None # time offset at eval + @dataclass class ContextData: """200-candle 1m context data for enhanced model performance""" @@ -103,6 +119,10 @@ class ExtremaTrainer: 'successful_predictions': 0, 'failed_predictions': 0, 'detection_accuracy': 0.0, + 'prediction_evaluations': 0, + 'prediction_successes': 0, + 'prediction_mae': 0.0, # mean absolute error for price + 'prediction_mte': 0.0, # mean time error seconds 'last_training_time': None } @@ -114,6 +134,140 @@ class ExtremaTrainer: logger.info(f"Window size: {window_size}, Context update frequency: {self.context_update_frequency}s") logger.info(f"Checkpoint management: {enable_checkpoints}, Model name: {model_name}") + # Next pivot prediction management + self.prediction_window_seconds = 300 # cap at 5 minutes + self.pending_predictions: Dict[str, deque] = {symbol: deque(maxlen=200) for symbol in symbols} + self.last_prediction: Dict[str, Optional[PredictedPivot]] = {symbol: None for symbol in symbols} + + # === Prediction API === + def predict_next_pivot(self, symbol: str, now: Optional[datetime] = None, current_price: Optional[float] = None) -> Optional[PredictedPivot]: + """Predict next pivot point (time, price) within 5 minutes using real context. + + Strategy (baseline, fully real-data-driven): + - Determine last detected extrema type from recent detections; target the opposite type next. + - Estimate horizon by median time gap between recent extrema (capped to 300s, floored at 30s). + - Estimate amplitude by median absolute price change between recent extrema; project from current_price in the direction implied by target type. + - Confidence derived from recent detection confidence averages (bounded). 
+ """ + try: + if symbol not in self.detected_extrema: + return None + now = now or datetime.now() + + # Use current price from provider if not passed + if current_price is None: + try: + if hasattr(self.data_provider, 'get_current_price'): + current_price = self.data_provider.get_current_price(symbol) or 0.0 + except Exception: + current_price = 0.0 + if not current_price or current_price <= 0: + return None + + recent = list(self.detected_extrema[symbol])[-10:] + if not recent: + return None + + # Determine last extrema + last_ext = recent[-1] + target_type = 'top' if last_ext.extrema_type == 'bottom' else 'bottom' + + # Estimate horizon as median delta between last extrema timestamps + gaps = [] + for i in range(1, len(recent)): + gaps.append((recent[i].timestamp - recent[i-1].timestamp).total_seconds()) + median_gap = int(np.median(gaps)) if gaps else 60 + horizon_s = max(30, min(self.prediction_window_seconds, median_gap)) + + # Estimate amplitude as median absolute change between extrema + price_changes = [] + for i in range(1, len(recent)): + price_changes.append(abs(recent[i].price - recent[i-1].price)) + median_amp = float(np.median(price_changes)) if price_changes else current_price * 0.002 # ~0.2% + + predicted_price = current_price + (median_amp if target_type == 'top' else -median_amp) + predicted_time = now + timedelta(seconds=horizon_s) + + # Confidence from average of recent detection confidences + conf_vals = [e.confidence for e in recent] + confidence = float(np.mean(conf_vals)) if conf_vals else 0.5 + confidence = max(0.1, min(0.95, confidence)) + + pred = PredictedPivot( + symbol=symbol, + created_at=now, + current_price=current_price, + predicted_time=predicted_time, + predicted_price=predicted_price, + horizon_seconds=horizon_s, + target_type=target_type, + confidence=confidence + ) + self.pending_predictions[symbol].append(pred) + self.last_prediction[symbol] = pred + return pred + except Exception as e: + logger.error(f"Error predicting next pivot for {symbol}: {e}") + return None + + def get_latest_prediction(self, symbol: str) -> Optional[PredictedPivot]: + return self.last_prediction.get(symbol) + + def evaluate_pending_predictions(self, symbol: str) -> int: + """Evaluate pending predictions within the 5-minute window using detected extrema. + Returns number of evaluations performed. 
+ """ + try: + if symbol not in self.pending_predictions: + return 0 + now = datetime.now() + evaluated = 0 + # Build a quick index of detected extrema within last 10 minutes + recent_extrema = [e for e in self.detected_extrema[symbol] if (now - e.timestamp).total_seconds() <= 600] + for pred in list(self.pending_predictions[symbol]): + if pred.evaluated: + continue + # If evaluation horizon passed, evaluate against nearest extrema in time + if (now - pred.created_at).total_seconds() >= min(self.prediction_window_seconds, pred.horizon_seconds): + # Find extrema closest in time after creation + candidate = None + min_dt = None + for e in recent_extrema: + if e.timestamp >= pred.created_at and e.extrema_type == pred.target_type: + dt = abs((e.timestamp - pred.predicted_time).total_seconds()) + if min_dt is None or dt < min_dt: + min_dt = dt + candidate = e + if candidate is not None: + price_err = abs(candidate.price - pred.predicted_price) + time_err = int(abs((candidate.timestamp - pred.predicted_time).total_seconds())) + # Decide success with simple thresholds + price_tol = max(0.001 * pred.current_price, 0.5) # 0.1% or $0.5 + time_tol = 90 # 1.5 minutes + success = (price_err <= price_tol) and (time_err <= time_tol) + pred.evaluated = True + pred.success = success + pred.error_abs = price_err + pred.time_error_s = time_err + + self.training_stats['prediction_evaluations'] += 1 + if success: + self.training_stats['prediction_successes'] += 1 + # Update running means + n = self.training_stats['prediction_evaluations'] + prev_mae = self.training_stats['prediction_mae'] + prev_mte = self.training_stats['prediction_mte'] + self.training_stats['prediction_mae'] = ((prev_mae * (n - 1)) + price_err) / n + self.training_stats['prediction_mte'] = ((prev_mte * (n - 1)) + time_err) / n + evaluated += 1 + # Optionally checkpoint on batch + if evaluated > 0: + self.save_checkpoint(force_save=False) + return evaluated + except Exception as e: + logger.error(f"Error evaluating predictions for {symbol}: {e}") + return 0 + def load_best_checkpoint(self): """Load the best checkpoint for this extrema trainer""" try: @@ -182,6 +336,10 @@ class ExtremaTrainer: symbol: list(extrema_deque) for symbol, extrema_deque in self.detected_extrema.items() }, + 'last_prediction': { + symbol: (self._serialize_prediction(pred) if pred else None) + for symbol, pred in self.last_prediction.items() + }, 'window_size': self.window_size, 'symbols': self.symbols } @@ -216,6 +374,25 @@ class ExtremaTrainer: except Exception as e: logger.error(f"Error saving ExtremaTrainer checkpoint: {e}") return False + + def _serialize_prediction(self, pred: PredictedPivot) -> Dict[str, Any]: + try: + return { + 'symbol': pred.symbol, + 'created_at': pred.created_at.isoformat(), + 'current_price': pred.current_price, + 'predicted_time': pred.predicted_time.isoformat(), + 'predicted_price': pred.predicted_price, + 'horizon_seconds': pred.horizon_seconds, + 'target_type': pred.target_type, + 'confidence': pred.confidence, + 'evaluated': pred.evaluated, + 'success': pred.success, + 'error_abs': pred.error_abs, + 'time_error_s': pred.time_error_s, + } + except Exception: + return {} def initialize_context_data(self) -> Dict[str, bool]: """Initialize 200-candle 1m context data for all symbols""" diff --git a/core/orchestrator.py b/core/orchestrator.py index 2b2083e..5a34610 100644 --- a/core/orchestrator.py +++ b/core/orchestrator.py @@ -976,15 +976,21 @@ class TradingOrchestrator: # The presence of features indicates a signal. 
We'll return a generic HOLD # with a neutral confidence. This can be refined if ExtremaTrainer provides # more specific BUY/SELL signals directly. - return { - "action": "HOLD", - "confidence": 0.5, - "probabilities": { - "BUY": 0.33, - "SELL": 0.33, - "HOLD": 0.34, - }, - } + # Provide next-pivot prediction vector capped at 5 min + pred = self.model.predict_next_pivot(symbol=symbol) + if pred: + return { + "action": "HOLD", + "confidence": pred.confidence, + "prediction": { + "target_type": pred.target_type, + "predicted_time": pred.predicted_time, + "predicted_price": pred.predicted_price, + "horizon_seconds": pred.horizon_seconds, + }, + } + # Fallback neutral + return {"action": "HOLD", "confidence": 0.5} return None except Exception as e: logger.error( diff --git a/data/ui_state.json b/data/ui_state.json index 70e145f..f21412b 100644 --- a/data/ui_state.json +++ b/data/ui_state.json @@ -1,29 +1,30 @@ { "model_toggle_states": { "dqn": { - "inference_enabled": false, - "training_enabled": true + "inference_enabled": true, + "training_enabled": true, + "routing_enabled": true }, "cnn": { "inference_enabled": true, - "training_enabled": true + "training_enabled": true, + "routing_enabled": true }, "cob_rl": { - "inference_enabled": false, - "training_enabled": true + "inference_enabled": true, + "training_enabled": true, + "routing_enabled": true }, "decision_fusion": { "inference_enabled": false, - "training_enabled": false + "training_enabled": false, + "routing_enabled": true }, "transformer": { "inference_enabled": false, - "training_enabled": true - }, - "dqn_agent": { - "inference_enabled": "inference_enabled", - "training_enabled": false + "training_enabled": true, + "routing_enabled": true } }, - "timestamp": "2025-08-09T00:59:11.537013" + "timestamp": "2025-08-10T03:17:33.694986" } \ No newline at end of file diff --git a/web/clean_dashboard.py b/web/clean_dashboard.py index 7996298..dc1b5e0 100644 --- a/web/clean_dashboard.py +++ b/web/clean_dashboard.py @@ -2625,10 +2625,114 @@ class CleanTradingDashboard: self._add_cnn_predictions_to_chart(fig, symbol, df_main, row) self._add_cob_rl_predictions_to_chart(fig, symbol, df_main, row) self._add_prediction_accuracy_feedback(fig, symbol, df_main, row) + self._add_williams_pivots_to_chart(fig, symbol, df_main, row) + + # 2b. 
Overlay extrema next-pivot prediction vector (within 5 minutes) + try: + latest_predictions = self._get_latest_model_predictions() + extrema_pred = latest_predictions.get('extrema_trainer', {}) + pred_obj = extrema_pred.get('prediction') if isinstance(extrema_pred, dict) else None + if pred_obj and 'predicted_time' in pred_obj and 'predicted_price' in pred_obj: + # Build a short vector from now/current price to predicted pivot + now_ts = df_main.index[-1] if not df_main.empty else datetime.now() + current_price = float(df_main['close'].iloc[-1]) if not df_main.empty else self._get_current_price(symbol) + pt = pred_obj['predicted_time'] + # Ensure datetime + if not isinstance(pt, datetime): + try: + pt = pd.to_datetime(pt) + except Exception: + pt = now_ts + pp = float(pred_obj['predicted_price']) + fig.add_trace( + go.Scatter( + x=[now_ts, pt], + y=[current_price, pp], + mode='lines+markers', + line=dict(color='#cddc39', width=2, dash='dot'), + marker=dict(size=6, color='#cddc39'), + name='Next Pivot (<=5m)' + ), + row=row, col=1 + ) + except Exception as e: + logger.debug(f"Error overlaying extrema pivot vector: {e}") except Exception as e: logger.warning(f"Error adding model predictions to chart: {e}") + def _add_williams_pivots_to_chart(self, fig: go.Figure, symbol: str, df_main: pd.DataFrame, row: int = 1): + """Overlay first two levels of Williams pivot points on the 1m chart.""" + try: + if not hasattr(self.data_provider, 'get_recent_pivot_points'): + return + pivots_lvl1 = self.data_provider.get_recent_pivot_points(symbol, level=1, count=50) or [] + pivots_lvl2 = self.data_provider.get_recent_pivot_points(symbol, level=2, count=50) or [] + + # Normalize timestamps to match chart timezone (avoid 3h offset) + try: + from datetime import timezone as _dt_tz + local_tz = datetime.now().astimezone().tzinfo + + def normalize_ts(ts: datetime) -> datetime: + try: + if ts is None: + return ts + if getattr(df_main.index, 'tz', None) is not None: + # Match DataFrame index timezone + if ts.tzinfo is None: + ts = ts.replace(tzinfo=_dt_tz.utc) + return ts.astimezone(df_main.index.tz) + # Chart is tz-naive: convert to local and drop tz + if ts.tzinfo is None: + # Treat naive timestamp as UTC coming from server calculations + ts = ts.replace(tzinfo=_dt_tz.utc) + ts_local = ts.astimezone(local_tz) + return ts_local.replace(tzinfo=None) + except Exception: + return ts + except Exception: + def normalize_ts(ts: datetime) -> datetime: + return ts + + def to_xy(pivots): + xs_h, ys_h, xs_l, ys_l = [], [], [], [] + for p in pivots: + ts = getattr(p, 'timestamp', None) + price = getattr(p, 'price', None) + ptype = getattr(p, 'type', getattr(p, 'pivot_type', 'low')) + if ts and price: + ts = normalize_ts(ts) + if str(ptype).lower() == 'high': + xs_h.append(ts); ys_h.append(price) + else: + xs_l.append(ts); ys_l.append(price) + return xs_h, ys_h, xs_l, ys_l + + l1xh, l1yh, l1xl, l1yl = to_xy(pivots_lvl1) + l2xh, l2yh, l2xl, l2yl = to_xy(pivots_lvl2) + + if l1xh or l1xl: + if l1xh: + fig.add_trace(go.Scatter(x=l1xh, y=l1yh, mode='markers', name='L1 High', + marker=dict(color='#ff7043', size=7, symbol='triangle-up'), hoverinfo='skip'), + row=row, col=1) + if l1xl: + fig.add_trace(go.Scatter(x=l1xl, y=l1yl, mode='markers', name='L1 Low', + marker=dict(color='#42a5f5', size=7, symbol='triangle-down'), hoverinfo='skip'), + row=row, col=1) + if l2xh or l2xl: + if l2xh: + fig.add_trace(go.Scatter(x=l2xh, y=l2yh, mode='markers', name='L2 High', + marker=dict(color='#ef6c00', size=6, symbol='triangle-up-open'), 
hoverinfo='skip'), + row=row, col=1) + if l2xl: + fig.add_trace(go.Scatter(x=l2xl, y=l2yl, mode='markers', name='L2 Low', + marker=dict(color='#1e88e5', size=6, symbol='triangle-down-open'), hoverinfo='skip'), + row=row, col=1) + except Exception as e: + logger.debug(f"Error overlaying Williams pivots: {e}") + def _add_dqn_predictions_to_chart(self, fig: go.Figure, symbol: str, df_main: pd.DataFrame, row: int = 1): """Add DQN action predictions as directional arrows""" try: @@ -5900,7 +6004,7 @@ class CleanTradingDashboard: state_features = self._get_dqn_state_features(symbol, current_price) if hasattr(self.orchestrator.rl_agent, 'predict'): return self.orchestrator.rl_agent.predict(state_features) - return 0.5 # Default neutral prediction + return 0.0 except: return 0.5 @@ -5914,7 +6018,7 @@ class CleanTradingDashboard: if features: import numpy as np return self.orchestrator.cnn_model.predict(np.array([features])) - return 0.5 # Default neutral prediction + return 0.0 except: return 0.5 @@ -5928,7 +6032,7 @@ class CleanTradingDashboard: if features: import numpy as np return self.orchestrator.primary_transformer.predict(np.array([features])) - return 0.5 # Default neutral prediction + return 0.0 except: return 0.5 @@ -5940,7 +6044,7 @@ class CleanTradingDashboard: if cob_features and hasattr(self.orchestrator.cob_rl_agent, 'predict'): import numpy as np return self.orchestrator.cob_rl_agent.predict(np.array([cob_features])) - return 0.5 # Default neutral prediction + return 0.0 except: return 0.5 diff --git a/web/component_manager.py b/web/component_manager.py index d7f1e68..fa54908 100644 --- a/web/component_manager.py +++ b/web/component_manager.py @@ -442,20 +442,12 @@ class DashboardComponentManager: extras.append(html.Small(f"Recent ticks: {len(recent)}", className="text-muted ms-2")) extras_div = html.Div(extras, className="mb-1") if extras else None - # Insert mini heatmap inside the COB panel (right side) + # Heatmap is rendered in dedicated tiles (avoid duplicate component IDs) heatmap_graph = None - try: - # The dashboard's data provider is accessible through a global reference on the dashboard instance. - # We embed a placeholder Graph here; actual figure is provided by the dashboard callback tied to this id. - graph_id = 'cob-heatmap-eth' if 'ETH' in symbol else 'cob-heatmap-btc' - heatmap_graph = dcc.Graph(id=graph_id, config={'displayModeBar': False}, style={"height": "220px"}) - except Exception: - heatmap_graph = None children = [dbc.Col(overview_panel, width=5, className="pe-1")] right_children = [] - if heatmap_graph: - right_children.append(heatmap_graph) + # Do not append inline heatmap here to prevent duplicate IDs right_children.append(ladder_panel) if extras_div: right_children.insert(0, extras_div)
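
The baseline heuristic that ExtremaTrainer.predict_next_pivot implements above (target the opposite type of the last detected extrema, horizon = median time gap between recent extrema bounded to 30-300s, amplitude = median absolute price swing with a ~0.2% fallback) can be exercised in isolation. A minimal sketch, assuming recent extrema are available as simple (timestamp, price, type) tuples; this tuple format and the function name are illustrative stand-ins, not the project's actual ExtremaPoint structures or API:

    # Standalone sketch of the next-pivot baseline described in the patch.
    # Illustrative only; the real logic lives in ExtremaTrainer.predict_next_pivot.
    from datetime import datetime, timedelta
    from typing import List, Optional, Tuple
    import numpy as np

    def sketch_next_pivot(extrema: List[Tuple[datetime, float, str]],
                          current_price: float,
                          now: Optional[datetime] = None,
                          window_cap_s: int = 300) -> Optional[dict]:
        """extrema: recent (timestamp, price, 'top'|'bottom') tuples, oldest first."""
        if not extrema or current_price <= 0:
            return None
        now = now or datetime.now()
        recent = extrema[-10:]
        # Target the opposite type of the last detected extrema
        target_type = 'top' if recent[-1][2] == 'bottom' else 'bottom'
        # Horizon: median time gap between consecutive extrema, bounded to [30s, cap]
        gaps = [(recent[i][0] - recent[i - 1][0]).total_seconds() for i in range(1, len(recent))]
        horizon_s = max(30, min(window_cap_s, (int(np.median(gaps)) if gaps else 60)))
        # Amplitude: median absolute price swing between extrema (~0.2% fallback)
        swings = [abs(recent[i][1] - recent[i - 1][1]) for i in range(1, len(recent))]
        amp = float(np.median(swings)) if swings else current_price * 0.002
        return {
            'target_type': target_type,
            'predicted_time': now + timedelta(seconds=horizon_s),
            'predicted_price': current_price + (amp if target_type == 'top' else -amp),
            'horizon_seconds': horizon_s,
        }

Using medians rather than means keeps the horizon and amplitude estimates robust to a single outlier extrema, which is presumably why the baseline stays bounded to the 30s-300s prediction window before evaluate_pending_predictions scores it against later detections.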