diff --git a/ANNOTATE/web/app.py b/ANNOTATE/web/app.py index cd8aff9..a316499 100644 --- a/ANNOTATE/web/app.py +++ b/ANNOTATE/web/app.py @@ -2791,6 +2791,41 @@ class AnnotationDashboard: } }) + @self.server.route('/api/train-validated-prediction', methods=['POST']) + def train_validated_prediction(): + """Train model on a validated prediction (online learning)""" + try: + data = request.get_json() + + timeframe = data.get('timeframe') + timestamp = data.get('timestamp') + predicted = data.get('predicted') + actual = data.get('actual') + errors = data.get('errors') + direction_correct = data.get('direction_correct') + accuracy = data.get('accuracy') + + logger.info(f"[ONLINE LEARNING] Received validation for {timeframe}: accuracy={accuracy:.1f}%, direction={'✓' if direction_correct else '✗'}") + + # Trigger training and get metrics + metrics = self._train_on_validated_prediction( + timeframe, timestamp, predicted, actual, + errors, direction_correct, accuracy + ) + + return jsonify({ + 'success': True, + 'message': 'Training triggered', + 'metrics': metrics or {} + }) + + except Exception as e: + logger.error(f"Error training on validated prediction: {e}", exc_info=True) + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + @self.server.route('/api/realtime-inference/train-manual', methods=['POST']) def train_manual(): """Manually trigger training on current candle with specified action""" @@ -2998,21 +3033,24 @@ class AnnotationDashboard: This implements online learning where each validated prediction becomes a training sample, with loss weighting based on prediction accuracy. + + Returns: + Dict with training metrics (loss, accuracy, steps) """ try: if not self.training_adapter: logger.warning("Training adapter not available for incremental training") - return + return None if not self.orchestrator or not hasattr(self.orchestrator, 'primary_transformer'): logger.warning("Transformer model not available for incremental training") - return + return None # Get the transformer trainer trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None) if not trainer: logger.warning("Transformer trainer not available") - return + return None # Calculate sample weight based on accuracy # Low accuracy predictions get higher weight (we need to learn from mistakes) @@ -3062,6 +3100,7 @@ class AnnotationDashboard: return # Train on this batch with sample weighting + import torch with torch.enable_grad(): trainer.model.train() result = trainer.train_step(batch, accumulate_gradients=False, sample_weight=sample_weight) @@ -3070,7 +3109,7 @@ class AnnotationDashboard: loss = result.get('total_loss', 0) candle_accuracy = result.get('candle_accuracy', 0) - logger.info(f"[{timeframe}] Trained on validated prediction: loss={loss:.4f}, new_acc={candle_accuracy:.2%}") + logger.info(f"[{timeframe}] ✓ Trained on validated prediction: loss={loss:.4f}, new_acc={candle_accuracy:.2%}") # Save checkpoint periodically (every 10 incremental steps) if not hasattr(self, '_incremental_training_steps'): @@ -3078,6 +3117,22 @@ class AnnotationDashboard: self._incremental_training_steps += 1 + # Track metrics for display + if not hasattr(self, '_training_metrics_history'): + self._training_metrics_history = [] + + self._training_metrics_history.append({ + 'step': self._incremental_training_steps, + 'loss': loss, + 'accuracy': candle_accuracy, + 'timeframe': timeframe, + 'timestamp': timestamp + }) + + # Keep only last 100 metrics + if len(self._training_metrics_history) > 100: + self._training_metrics_history.pop(0) + if self._incremental_training_steps % 10 == 0: logger.info(f"Saving checkpoint after {self._incremental_training_steps} incremental training steps") trainer.save_checkpoint( @@ -3088,9 +3143,18 @@ class AnnotationDashboard: 'last_accuracy': accuracy } ) + + # Return metrics for display + return { + 'loss': loss, + 'accuracy': candle_accuracy, + 'steps': self._incremental_training_steps, + 'sample_weight': sample_weight + } except Exception as e: logger.error(f"Error in incremental training: {e}", exc_info=True) + return None def _fetch_market_state_at_timestamp(self, symbol: str, timestamp: str, timeframe: str) -> Dict: """Fetch market state at a specific timestamp for training""" diff --git a/ANNOTATE/web/static/js/chart_manager.js b/ANNOTATE/web/static/js/chart_manager.js index 86ac474..efdc89d 100644 --- a/ANNOTATE/web/static/js/chart_manager.js +++ b/ANNOTATE/web/static/js/chart_manager.js @@ -2800,20 +2800,34 @@ class ChartManager { predicted: prediction.candle, // [O, H, L, C, V] actual: prediction.accuracy.actualCandle, // [O, H, L, C, V] errors: prediction.accuracy.errors, // {open, high, low, close, volume} - pctErrors: prediction.accuracy.pctErrors, // {open, high, low, close, volume} - directionCorrect: prediction.accuracy.directionCorrect, + direction_correct: prediction.accuracy.directionCorrect, accuracy: prediction.accuracy.accuracy }; - console.log('[Prediction Metrics for Training]', metrics); + console.log('[Prediction Metrics] Triggering online learning:', metrics); - // Send to backend via WebSocket for incremental training - if (window.socket && window.socket.connected) { - window.socket.emit('prediction_accuracy', metrics); - console.log(`[${timeframe}] Sent prediction accuracy to backend for training`); - } else { - console.warn('[Training] WebSocket not connected - metrics not sent to backend'); - } + // Send to backend for incremental training (online learning) + fetch('/api/train-validated-prediction', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(metrics) + }) + .then(response => response.json()) + .then(data => { + if (data.success) { + console.log(`[${timeframe}] ✓ Online learning triggered - model updated from validated prediction`); + + // Update metrics display if available + if (window.updateMetricsDisplay) { + window.updateMetricsDisplay(data.metrics); + } + } else { + console.warn(`[${timeframe}] Training failed:`, data.error); + } + }) + .catch(error => { + console.warn('[Training] Error sending metrics:', error); + }); } /**