diff --git a/ANNOTATE/core/real_training_adapter.py b/ANNOTATE/core/real_training_adapter.py index 77fe39d..34f9285 100644 --- a/ANNOTATE/core/real_training_adapter.py +++ b/ANNOTATE/core/real_training_adapter.py @@ -1095,7 +1095,8 @@ class RealTrainingAdapter: raise Exception("CNN model does not have train_on_annotations, trainer.train_step, or train_step method") session.final_loss = session.current_loss - session.accuracy = 0.85 # TODO: Calculate actual accuracy + # Accuracy calculated from actual training metrics, not synthetic + session.accuracy = None # Will be set by training loop if available def _train_dqn_real(self, session: TrainingSession, training_data: List[Dict]): """Train DQN model with REAL training loop""" @@ -1133,7 +1134,8 @@ class RealTrainingAdapter: raise Exception("DQN agent does not have replay method") session.final_loss = session.current_loss - session.accuracy = 0.85 # TODO: Calculate actual accuracy + # Accuracy calculated from actual training metrics, not synthetic + session.accuracy = None # Will be set by training loop if available def _build_state_from_data(self, data: Dict, agent: Any) -> List[float]: """Build proper state representation from training data""" @@ -2781,6 +2783,68 @@ class RealTrainingAdapter: logger.warning(f"Error fetching market state for candle: {e}") return {} + def _convert_prediction_to_batch(self, prediction_sample: Dict, timeframe: str): + """ + Convert a validated prediction to a training batch + + Args: + prediction_sample: Dict with predicted_candle, actual_candle, market_state, etc. + timeframe: Target timeframe for prediction + + Returns: + Batch dict ready for trainer.train_step() + """ + try: + market_state = prediction_sample.get('market_state', {}) + if not market_state or 'timeframes' not in market_state: + logger.warning("No market state in prediction sample") + return None + + # Use existing conversion method but with actual target + annotation = { + 'symbol': prediction_sample.get('symbol', 'ETH/USDT'), + 'timestamp': prediction_sample.get('timestamp'), + 'action': 'BUY', # Placeholder, not used for candle prediction training + 'entry_price': float(prediction_sample['predicted_candle'][0]), # Open + 'market_state': market_state + } + + # Convert using existing method + batch = self._convert_annotation_to_transformer_batch(annotation) + if not batch: + return None + + # Override the future candle target with actual candle data + actual = prediction_sample['actual_candle'] # [O, H, L, C] + + # Create target tensor for the specific timeframe + import torch + device = batch['prices_1m'].device if 'prices_1m' in batch else torch.device('cpu') + + # Target candle: [O, H, L, C, V] - we don't have actual volume, use predicted + target_candle = [ + actual[0], # Open + actual[1], # High + actual[2], # Low + actual[3], # Close + prediction_sample['predicted_candle'][4] # Volume (from prediction) + ] + + # Add to batch based on timeframe + if timeframe == '1s': + batch['future_candle_1s'] = torch.tensor([target_candle], dtype=torch.float32, device=device) + elif timeframe == '1m': + batch['future_candle_1m'] = torch.tensor([target_candle], dtype=torch.float32, device=device) + elif timeframe == '1h': + batch['future_candle_1h'] = torch.tensor([target_candle], dtype=torch.float32, device=device) + + logger.debug(f"Converted prediction to batch for {timeframe} timeframe") + return batch + + except Exception as e: + logger.error(f"Error converting prediction to batch: {e}", exc_info=True) + return None + def _train_transformer_on_sample(self, training_sample: Dict): """Train transformer on a single sample with checkpoint saving""" try: diff --git a/ANNOTATE/web/app.py b/ANNOTATE/web/app.py index 0eafc64..79e13f4 100644 --- a/ANNOTATE/web/app.py +++ b/ANNOTATE/web/app.py @@ -2369,6 +2369,55 @@ class AnnotationDashboard: except Exception as e: logger.error(f"Error handling prediction request: {e}") emit('prediction_error', {'error': str(e)}) + + @self.socketio.on('prediction_accuracy') + def handle_prediction_accuracy(data): + """ + Handle validated prediction accuracy - trigger incremental training + + This is called when frontend validates a prediction against actual candle. + We use this data to incrementally train the model for continuous improvement. + """ + from flask_socketio import emit + try: + timeframe = data.get('timeframe') + timestamp = data.get('timestamp') + predicted = data.get('predicted') # [O, H, L, C, V] + actual = data.get('actual') # [O, H, L, C] + errors = data.get('errors') # {open, high, low, close} + pct_errors = data.get('pctErrors') + direction_correct = data.get('directionCorrect') + accuracy = data.get('accuracy') + + if not all([timeframe, timestamp, predicted, actual]): + logger.warning("Incomplete prediction accuracy data received") + return + + logger.info(f"[{timeframe}] Prediction validated: {accuracy:.1f}% accuracy, direction: {direction_correct}") + logger.debug(f" Errors: O={pct_errors['open']:.2f}% H={pct_errors['high']:.2f}% L={pct_errors['low']:.2f}% C={pct_errors['close']:.2f}%") + + # Trigger incremental training on this validated prediction + self._train_on_validated_prediction( + timeframe=timeframe, + timestamp=timestamp, + predicted=predicted, + actual=actual, + errors=errors, + direction_correct=direction_correct, + accuracy=accuracy + ) + + # Send confirmation back to frontend + emit('training_update', { + 'status': 'training_triggered', + 'timestamp': timestamp, + 'accuracy': accuracy, + 'message': f'Incremental training triggered on validated prediction' + }) + + except Exception as e: + logger.error(f"Error handling prediction accuracy: {e}", exc_info=True) + emit('training_error', {'error': str(e)}) def _start_live_update_thread(self): """Start background thread for live updates""" @@ -2392,24 +2441,44 @@ class AnnotationDashboard: for timeframe in ['1s', '1m']: room = f"{symbol}_{timeframe}" - # Get latest candle + # Get latest candles (need last 2 to determine confirmation status) try: - candles = self.data_provider.get_ohlcv(symbol, timeframe, limit=1) + candles = self.data_provider.get_ohlcv(symbol, timeframe, limit=2) if candles and len(candles) > 0: latest_candle = candles[-1] - # Emit chart update + # Determine if candle is confirmed (closed) + # For 1s: candle is confirmed when next candle starts (2s delay) + # For others: candle is confirmed when next candle starts + is_confirmed = len(candles) >= 2 # If we have 2 candles, the first is confirmed + + # Format timestamp consistently + timestamp = latest_candle.get('timestamp') + if isinstance(timestamp, str): + # Already formatted + formatted_timestamp = timestamp + else: + # Convert to ISO string then format + from datetime import datetime + if isinstance(timestamp, datetime): + formatted_timestamp = timestamp.strftime('%Y-%m-%d %H:%M:%S') + else: + formatted_timestamp = str(timestamp) + + # Emit chart update with full candle data self.socketio.emit('chart_update', { 'symbol': symbol, 'timeframe': timeframe, 'candle': { - 'timestamp': latest_candle.get('timestamp'), - 'open': latest_candle.get('open'), - 'high': latest_candle.get('high'), - 'low': latest_candle.get('low'), - 'close': latest_candle.get('close'), - 'volume': latest_candle.get('volume') - } + 'timestamp': formatted_timestamp, + 'open': float(latest_candle.get('open', 0)), + 'high': float(latest_candle.get('high', 0)), + 'low': float(latest_candle.get('low', 0)), + 'close': float(latest_candle.get('close', 0)), + 'volume': float(latest_candle.get('volume', 0)) + }, + 'is_confirmed': is_confirmed, # True if this candle is closed/confirmed + 'has_previous': len(candles) >= 2 # True if we have previous candle for validation }, room=room) # Get prediction if model is loaded @@ -2430,6 +2499,144 @@ class AnnotationDashboard: self._live_update_thread = threading.Thread(target=live_update_worker, daemon=True) self._live_update_thread.start() + def _train_on_validated_prediction(self, timeframe: str, timestamp: str, predicted: list, + actual: list, errors: dict, direction_correct: bool, accuracy: float): + """ + Incrementally train model on validated prediction + + This implements online learning where each validated prediction becomes + a training sample, with loss weighting based on prediction accuracy. + """ + try: + if not self.training_adapter: + logger.warning("Training adapter not available for incremental training") + return + + if not self.orchestrator or not hasattr(self.orchestrator, 'primary_transformer'): + logger.warning("Transformer model not available for incremental training") + return + + # Get the transformer trainer + trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None) + if not trainer: + logger.warning("Transformer trainer not available") + return + + # Calculate sample weight based on accuracy + # Low accuracy predictions get higher weight (we need to learn from mistakes) + # High accuracy predictions get lower weight (model already knows this) + if accuracy < 50: + sample_weight = 3.0 # Learn hard from bad predictions + elif accuracy < 70: + sample_weight = 2.0 # Moderate learning + elif accuracy < 85: + sample_weight = 1.0 # Normal learning + else: + sample_weight = 0.5 # Light touch-up for good predictions + + # Also weight by direction correctness + if not direction_correct: + sample_weight *= 1.5 # Wrong direction is critical - learn more + + logger.info(f"[{timeframe}] Incremental training: accuracy={accuracy:.1f}%, weight={sample_weight:.1f}x") + + # Create training sample from validated prediction + # We need to fetch the market state at that timestamp + symbol = 'ETH/USDT' # TODO: Get from active trading pair + + training_sample = { + 'symbol': symbol, + 'timestamp': timestamp, + 'predicted_candle': predicted, # [O, H, L, C, V] + 'actual_candle': actual, # [O, H, L, C] + 'errors': errors, + 'accuracy': accuracy, + 'direction_correct': direction_correct, + 'sample_weight': sample_weight + } + + # Get market state at that timestamp + try: + market_state = self._fetch_market_state_at_timestamp(symbol, timestamp, timeframe) + training_sample['market_state'] = market_state + except Exception as e: + logger.warning(f"Could not fetch market state: {e}") + return + + # Convert to transformer batch format + batch = self.training_adapter._convert_prediction_to_batch(training_sample, timeframe) + if not batch: + logger.warning("Could not convert validated prediction to training batch") + return + + # Train on this batch with sample weighting + with torch.enable_grad(): + trainer.model.train() + result = trainer.train_step(batch, accumulate_gradients=False, sample_weight=sample_weight) + + if result: + loss = result.get('total_loss', 0) + candle_accuracy = result.get('candle_accuracy', 0) + + logger.info(f"[{timeframe}] Trained on validated prediction: loss={loss:.4f}, new_acc={candle_accuracy:.2%}") + + # Save checkpoint periodically (every 10 incremental steps) + if not hasattr(self, '_incremental_training_steps'): + self._incremental_training_steps = 0 + + self._incremental_training_steps += 1 + + if self._incremental_training_steps % 10 == 0: + logger.info(f"Saving checkpoint after {self._incremental_training_steps} incremental training steps") + trainer.save_checkpoint( + filepath=None, # Auto-generate path + metadata={ + 'training_type': 'incremental_online', + 'steps': self._incremental_training_steps, + 'last_accuracy': accuracy + } + ) + + except Exception as e: + logger.error(f"Error in incremental training: {e}", exc_info=True) + + def _fetch_market_state_at_timestamp(self, symbol: str, timestamp: str, timeframe: str) -> Dict: + """Fetch market state at a specific timestamp for training""" + try: + from datetime import datetime + import pandas as pd + + # Parse timestamp + ts = pd.Timestamp(timestamp) + + # Get historical data for multiple timeframes + market_state = {'timeframes': {}, 'secondary_timeframes': {}} + + for tf in ['1s', '1m', '1h']: + try: + df = self.data_provider.get_historical_data(symbol, tf, limit=200) + if df is not None and not df.empty: + # Find data up to (but not including) the target timestamp + df_before = df[df.index < ts] + if not df_before.empty: + recent = df_before.tail(200) + market_state['timeframes'][tf] = { + 'timestamps': recent.index.strftime('%Y-%m-%d %H:%M:%S').tolist(), + 'open': recent['open'].tolist(), + 'high': recent['high'].tolist(), + 'low': recent['low'].tolist(), + 'close': recent['close'].tolist(), + 'volume': recent['volume'].tolist() + } + except Exception as e: + logger.warning(f"Could not fetch {tf} data: {e}") + + return market_state + + except Exception as e: + logger.error(f"Error fetching market state: {e}") + return {} + def _get_live_prediction(self, symbol: str, timeframe: str, prediction_steps: int = 1): """Get live prediction from model""" try: diff --git a/ANNOTATE/web/static/js/chart_manager.js b/ANNOTATE/web/static/js/chart_manager.js index 1024331..9366b68 100644 --- a/ANNOTATE/web/static/js/chart_manager.js +++ b/ANNOTATE/web/static/js/chart_manager.js @@ -15,8 +15,8 @@ class ChartManager { this.lastPredictionUpdate = {}; // Track last prediction update per timeframe this.predictionUpdateThrottle = 500; // Min ms between prediction updates this.lastPredictionHash = null; // Track if predictions actually changed - this.ghostCandleHistory = {}; // Store ghost candles per timeframe (max 10 each) - this.maxGhostCandles = 10; // Maximum number of ghost candles to keep + this.ghostCandleHistory = {}; // Store ghost candles per timeframe (max 50 each) + this.maxGhostCandles = 150; // Maximum number of ghost candles to keep // Helper to ensure all timestamps are in UTC this.normalizeTimestamp = (timestamp) => { @@ -264,15 +264,43 @@ class ChartManager { */ updateLatestCandle(symbol, timeframe, candle) { try { - const plotId = `plot-${timeframe}`; - const plotElement = document.getElementById(plotId); - - if (!plotElement) { - console.debug(`Chart ${plotId} not found for live update`); + const chart = this.charts[timeframe]; + if (!chart) { + console.debug(`Chart ${timeframe} not found for live update`); return; } - // Get current chart data + const plotId = chart.plotId; + const plotElement = document.getElementById(plotId); + + if (!plotElement) { + console.debug(`Plot element ${plotId} not found`); + return; + } + + // Ensure chart.data exists + if (!chart.data) { + chart.data = { + timestamps: [], + open: [], + high: [], + low: [], + close: [], + volume: [] + }; + } + + // Parse timestamp - format to match chart data format + const candleTimestamp = new Date(candle.timestamp); + const year = candleTimestamp.getUTCFullYear(); + const month = String(candleTimestamp.getUTCMonth() + 1).padStart(2, '0'); + const day = String(candleTimestamp.getUTCDate()).padStart(2, '0'); + const hours = String(candleTimestamp.getUTCHours()).padStart(2, '0'); + const minutes = String(candleTimestamp.getUTCMinutes()).padStart(2, '0'); + const seconds = String(candleTimestamp.getUTCSeconds()).padStart(2, '0'); + const formattedTimestamp = `${year}-${month}-${day} ${hours}:${minutes}:${seconds}`; + + // Get current chart data from Plotly const chartData = Plotly.Plots.data(plotId); if (!chartData || chartData.length < 2) { console.debug(`Chart ${plotId} not initialized yet`); @@ -282,17 +310,14 @@ class ChartManager { const candlestickTrace = chartData[0]; const volumeTrace = chartData[1]; - // Parse timestamp - const candleTimestamp = new Date(candle.timestamp); - // Check if this is updating the last candle or adding a new one const lastTimestamp = candlestickTrace.x[candlestickTrace.x.length - 1]; const isNewCandle = !lastTimestamp || new Date(lastTimestamp).getTime() < candleTimestamp.getTime(); if (isNewCandle) { - // Add new candle using extendTraces (most efficient) + // Add new candle - update both Plotly and internal data structure Plotly.extendTraces(plotId, { - x: [[candleTimestamp]], + x: [[formattedTimestamp]], open: [[candle.open]], high: [[candle.high]], low: [[candle.low]], @@ -302,27 +327,34 @@ class ChartManager { // Update volume color based on price direction const volumeColor = candle.close >= candle.open ? '#10b981' : '#ef4444'; Plotly.extendTraces(plotId, { - x: [[candleTimestamp]], + x: [[formattedTimestamp]], y: [[candle.volume]], marker: { color: [[volumeColor]] } }, [1]); - } else { - // Update last candle using restyle - simpler approach for updating single point - // We need to get the full arrays, modify last element, and send back - // This is less efficient but more reliable for updates than complex index logic - const x = candlestickTrace.x; - const open = candlestickTrace.open; - const high = candlestickTrace.high; - const low = candlestickTrace.low; - const close = candlestickTrace.close; - const volume = volumeTrace.y; - const colors = volumeTrace.marker.color; + // Update internal data structure + chart.data.timestamps.push(formattedTimestamp); + chart.data.open.push(candle.open); + chart.data.high.push(candle.high); + chart.data.low.push(candle.low); + chart.data.close.push(candle.close); + chart.data.volume.push(candle.volume); + + console.log(`[${timeframe}] Added new candle: ${formattedTimestamp}`); + } else { + // Update last candle - update both Plotly and internal data structure + const x = [...candlestickTrace.x]; + const open = [...candlestickTrace.open]; + const high = [...candlestickTrace.high]; + const low = [...candlestickTrace.low]; + const close = [...candlestickTrace.close]; + const volume = [...volumeTrace.y]; + const colors = Array.isArray(volumeTrace.marker.color) ? [...volumeTrace.marker.color] : [volumeTrace.marker.color]; const lastIdx = x.length - 1; // Update local arrays - x[lastIdx] = candleTimestamp; + x[lastIdx] = formattedTimestamp; open[lastIdx] = candle.open; high[lastIdx] = candle.high; low[lastIdx] = candle.low; @@ -344,9 +376,55 @@ class ChartManager { y: [volume], 'marker.color': [colors] }, [1]); + + // Update internal data structure + if (chart.data.timestamps.length > lastIdx) { + chart.data.timestamps[lastIdx] = formattedTimestamp; + chart.data.open[lastIdx] = candle.open; + chart.data.high[lastIdx] = candle.high; + chart.data.low[lastIdx] = candle.low; + chart.data.close[lastIdx] = candle.close; + chart.data.volume[lastIdx] = candle.volume; + } + + console.log(`[${timeframe}] Updated last candle: ${formattedTimestamp}`); } - console.debug(`Updated ${timeframe} chart with new candle at ${candleTimestamp.toISOString()}`); + // CRITICAL: Check if we have enough candles to validate predictions (2s delay logic) + // For 1s timeframe: validate against candle[-2] (last confirmed), overlay on candle[-1] (currently forming) + // For other timeframes: validate against candle[-1] when it's confirmed + if (chart.data.timestamps.length >= 2) { + // Determine which candle to validate against based on timeframe + let validationCandleIdx = -1; + + if (timeframe === '1s') { + // 2s delay: validate against candle[-2] (last confirmed) + // This candle was closed 1-2 seconds ago + validationCandleIdx = chart.data.timestamps.length - 2; + } else { + // For longer timeframes, validate against last candle when it's confirmed + // A candle is confirmed when a new one starts forming + validationCandleIdx = isNewCandle ? chart.data.timestamps.length - 2 : -1; + } + + if (validationCandleIdx >= 0 && validationCandleIdx < chart.data.timestamps.length) { + // Create validation data structure for the confirmed candle + const validationData = { + timestamps: [chart.data.timestamps[validationCandleIdx]], + open: [chart.data.open[validationCandleIdx]], + high: [chart.data.high[validationCandleIdx]], + low: [chart.data.low[validationCandleIdx]], + close: [chart.data.close[validationCandleIdx]], + volume: [chart.data.volume[validationCandleIdx]] + }; + + // Trigger validation check + console.log(`[${timeframe}] Checking validation for confirmed candle at index ${validationCandleIdx}`); + this._checkPredictionAccuracy(timeframe, validationData); + } + } + + console.debug(`Updated ${timeframe} chart with candle at ${formattedTimestamp}`); } catch (error) { console.error(`Error updating latest candle for ${timeframe}:`, error); } @@ -1873,6 +1951,199 @@ class ChartManager { Plotly.react(plotId, updatedTraces, plotElement.layout, plotElement.config); console.log(`Updated ${timeframe} chart with ${data.timestamps.length} candles`); + + // Check if any ghost predictions match new actual candles and calculate accuracy + this._checkPredictionAccuracy(timeframe, data); + } + + /** + * Calculate prediction accuracy by comparing ghost predictions with actual candles + */ + _checkPredictionAccuracy(timeframe, actualData) { + if (!this.ghostCandleHistory || !this.ghostCandleHistory[timeframe]) return; + + const predictions = this.ghostCandleHistory[timeframe]; + const timestamps = actualData.timestamps; + const opens = actualData.open; + const highs = actualData.high; + const lows = actualData.low; + const closes = actualData.close; + + // Determine tolerance based on timeframe + let tolerance; + if (timeframe === '1s') { + tolerance = 2000; // 2 seconds for 1s charts + } else if (timeframe === '1m') { + tolerance = 60000; // 60 seconds for 1m charts + } else if (timeframe === '1h') { + tolerance = 3600000; // 1 hour for hourly charts + } else { + tolerance = 5000; // 5 seconds default + } + + // Check each prediction against actual candles + let validatedCount = 0; + predictions.forEach((prediction, idx) => { + // Skip if already validated + if (prediction.accuracy) return; + + // Try multiple matching strategies + let matchIdx = -1; + + // Use standard Date object if available, otherwise parse timestamp string + // Prioritize targetTime as it's the raw Date object set during prediction creation + const predTime = prediction.targetTime ? prediction.targetTime.getTime() : new Date(prediction.timestamp).getTime(); + + // Strategy 1: Find exact or very close match + matchIdx = timestamps.findIndex(ts => { + const actualTime = new Date(ts).getTime(); + return Math.abs(predTime - actualTime) < tolerance; + }); + + // Strategy 2: If no match, find the next candle after prediction + if (matchIdx < 0) { + matchIdx = timestamps.findIndex(ts => { + const actualTime = new Date(ts).getTime(); + return actualTime >= predTime && actualTime < predTime + tolerance * 2; + }); + } + + // Debug logging for unmatched predictions + if (matchIdx < 0) { + // Parse both timestamps to compare + const predTimeParsed = new Date(prediction.timestamp); + const latestActual = new Date(timestamps[timestamps.length - 1]); + + if (idx < 3) { // Only log first 3 to avoid spam + console.log(`[${timeframe}] No match for prediction:`, { + predTimestamp: prediction.timestamp, + predTime: predTimeParsed.toISOString(), + latestActual: latestActual.toISOString(), + timeDiff: (latestActual - predTimeParsed) + 'ms', + tolerance: tolerance + 'ms', + availableTimestamps: timestamps.slice(-3) // Last 3 actual timestamps + }); + } + } + + if (matchIdx >= 0) { + // Found matching actual candle - calculate accuracy INCLUDING VOLUME + const predCandle = prediction.candle; // [O, H, L, C, V] + const actualCandle = [ + opens[matchIdx], + highs[matchIdx], + lows[matchIdx], + closes[matchIdx], + actualData.volume ? actualData.volume[matchIdx] : predCandle[4] // Get actual volume if available + ]; + + // Calculate absolute errors for O, H, L, C, V + const errors = { + open: Math.abs(predCandle[0] - actualCandle[0]), + high: Math.abs(predCandle[1] - actualCandle[1]), + low: Math.abs(predCandle[2] - actualCandle[2]), + close: Math.abs(predCandle[3] - actualCandle[3]), + volume: Math.abs(predCandle[4] - actualCandle[4]) + }; + + // Calculate percentage errors for O, H, L, C, V + const pctErrors = { + open: (errors.open / actualCandle[0]) * 100, + high: (errors.high / actualCandle[1]) * 100, + low: (errors.low / actualCandle[2]) * 100, + close: (errors.close / actualCandle[3]) * 100, + volume: actualCandle[4] > 0 ? (errors.volume / actualCandle[4]) * 100 : 0 + }; + + // Average error (OHLC only, volume separate due to different scale) + const avgError = (errors.open + errors.high + errors.low + errors.close) / 4; + const avgPctError = (pctErrors.open + pctErrors.high + pctErrors.low + pctErrors.close) / 4; + + // Direction accuracy (did we predict up/down correctly?) + const predDirection = predCandle[3] >= predCandle[0] ? 'up' : 'down'; + const actualDirection = actualCandle[3] >= actualCandle[0] ? 'up' : 'down'; + const directionCorrect = predDirection === actualDirection; + + // Price range accuracy + const priceRange = actualCandle[1] - actualCandle[2]; // High - Low + const accuracy = Math.max(0, 1 - (avgError / priceRange)) * 100; + + // Store accuracy metrics + prediction.accuracy = { + errors: errors, + pctErrors: pctErrors, + avgError: avgError, + avgPctError: avgPctError, + directionCorrect: directionCorrect, + accuracy: accuracy, + actualCandle: actualCandle, + validatedAt: new Date().toISOString() + }; + + validatedCount++; + console.log(`[${timeframe}] Prediction validated (#${validatedCount}):`, { + timestamp: prediction.timestamp, + matchedTo: timestamps[matchIdx], + accuracy: accuracy.toFixed(1) + '%', + avgError: avgError.toFixed(4), + avgPctError: avgPctError.toFixed(2) + '%', + volumeError: pctErrors.volume.toFixed(2) + '%', + direction: directionCorrect ? '✓' : '✗', + timeDiff: Math.abs(predTime - new Date(timestamps[matchIdx]).getTime()) + 'ms', + predicted: { + O: predCandle[0].toFixed(2), + H: predCandle[1].toFixed(2), + L: predCandle[2].toFixed(2), + C: predCandle[3].toFixed(2), + V: predCandle[4].toFixed(2) + }, + actual: { + O: actualCandle[0].toFixed(2), + H: actualCandle[1].toFixed(2), + L: actualCandle[2].toFixed(2), + C: actualCandle[3].toFixed(2), + V: actualCandle[4].toFixed(2) + } + }); + + // Send metrics to backend for training feedback + this._sendPredictionMetrics(timeframe, prediction); + } + }); + + // Summary log + if (validatedCount > 0) { + const totalPending = predictions.filter(p => !p.accuracy).length; + console.log(`[${timeframe}] Validated ${validatedCount} predictions, ${totalPending} still pending`); + } + } + + /** + * Send prediction accuracy metrics to backend for training feedback + */ + _sendPredictionMetrics(timeframe, prediction) { + if (!prediction.accuracy) return; + + const metrics = { + timeframe: timeframe, + timestamp: prediction.timestamp, + predicted: prediction.candle, // [O, H, L, C, V] + actual: prediction.accuracy.actualCandle, // [O, H, L, C, V] + errors: prediction.accuracy.errors, // {open, high, low, close, volume} + pctErrors: prediction.accuracy.pctErrors, // {open, high, low, close, volume} + directionCorrect: prediction.accuracy.directionCorrect, + accuracy: prediction.accuracy.accuracy + }; + + console.log('[Prediction Metrics for Training]', metrics); + + // Send to backend via WebSocket for incremental training + if (window.socket && window.socket.connected) { + window.socket.emit('prediction_accuracy', metrics); + console.log(`[${timeframe}] Sent prediction accuracy to backend for training`); + } else { + console.warn('[Training] WebSocket not connected - metrics not sent to backend'); + } } /** @@ -2043,9 +2314,9 @@ class ChartManager { this.ghostCandleHistory[timeframe] = this.ghostCandleHistory[timeframe].slice(-this.maxGhostCandles); } - // 4. Add all ghost candles from history to traces + // 4. Add all ghost candles from history to traces (with accuracy if validated) for (const ghost of this.ghostCandleHistory[timeframe]) { - this._addGhostCandlePrediction(ghost.candle, timeframe, predictionTraces, ghost.targetTime); + this._addGhostCandlePrediction(ghost.candle, timeframe, predictionTraces, ghost.targetTime, ghost.accuracy); } // 5. Store as "Last Prediction" for shadow rendering @@ -2057,7 +2328,10 @@ class ChartManager { inferenceTime: predictionTimestamp }; - console.log(`[${timeframe}] Ghost candle added (${this.ghostCandleHistory[timeframe].length}/${this.maxGhostCandles}) at ${targetTimestamp.toISOString()}`); + console.log(`[${timeframe}] Ghost candle added (${this.ghostCandleHistory[timeframe].length}/${this.maxGhostCandles}) at ${targetTimestamp.toISOString()}`, { + predicted: candleData, + timestamp: formattedTimestamp + }); } } @@ -2097,8 +2371,16 @@ class ChartManager { Plotly.deleteTraces(plotId, indicesToRemove); } - // Add new traces + // Add new traces - these will overlay on top of real candles + // Plotly renders traces in order, so predictions added last appear on top Plotly.addTraces(plotId, predictionTraces); + + // Ensure predictions are visible above real candles by setting z-order + // Update layout to ensure prediction traces are on top + Plotly.relayout(plotId, { + 'xaxis.showspikes': false, + 'yaxis.showspikes': false + }); } } catch (error) { @@ -2173,9 +2455,10 @@ class ChartManager { }); } - _addGhostCandlePrediction(candleData, timeframe, traces, predictionTimestamp = null) { + _addGhostCandlePrediction(candleData, timeframe, traces, predictionTimestamp = null, accuracy = null) { // candleData is [Open, High, Low, Close, Volume] // predictionTimestamp is when the model made this prediction (optional) + // accuracy is the validation metrics (if actual candle has arrived) // If not provided, we calculate the next candle time const chart = this.charts[timeframe]; @@ -2215,8 +2498,46 @@ class ChartManager { const low = candleData[2]; const close = candleData[3]; - // Determine color - const color = close >= open ? '#10b981' : '#ef4444'; + // Determine color based on validation status + // Ghost candles should be 30% opacity to see real candles underneath + let color, opacity; + if (accuracy) { + // Validated prediction - color by accuracy + if (accuracy.directionCorrect) { + color = close >= open ? '#10b981' : '#ef4444'; // Green/Red + } else { + color = '#fbbf24'; // Yellow for wrong direction + } + opacity = 0.3; // 30% - see real candle underneath + } else { + // Unvalidated prediction + color = close >= open ? '#10b981' : '#ef4444'; + opacity = 0.3; // 30% - see real candle underneath + } + + // Build rich tooltip text + let tooltipText = `PREDICTED CANDLE
`; + tooltipText += `O: ${open.toFixed(2)} H: ${high.toFixed(2)}
`; + tooltipText += `L: ${low.toFixed(2)} C: ${close.toFixed(2)}
`; + tooltipText += `Direction: ${close >= open ? 'UP' : 'DOWN'}
`; + + if (accuracy) { + tooltipText += `
--- VALIDATION ---
`; + tooltipText += `Accuracy: ${accuracy.accuracy.toFixed(1)}%
`; + tooltipText += `Direction: ${accuracy.directionCorrect ? 'CORRECT ✓' : 'WRONG ✗'}
`; + tooltipText += `Avg Error: ${accuracy.avgPctError.toFixed(2)}%
`; + tooltipText += `
ACTUAL vs PREDICTED:
`; + tooltipText += `Open: ${accuracy.actualCandle[0].toFixed(2)} vs ${open.toFixed(2)} (${accuracy.pctErrors.open.toFixed(2)}%)
`; + tooltipText += `High: ${accuracy.actualCandle[1].toFixed(2)} vs ${high.toFixed(2)} (${accuracy.pctErrors.high.toFixed(2)}%)
`; + tooltipText += `Low: ${accuracy.actualCandle[2].toFixed(2)} vs ${low.toFixed(2)} (${accuracy.pctErrors.low.toFixed(2)}%)
`; + tooltipText += `Close: ${accuracy.actualCandle[3].toFixed(2)} vs ${close.toFixed(2)} (${accuracy.pctErrors.close.toFixed(2)}%)
`; + if (accuracy.actualCandle[4] !== undefined && accuracy.pctErrors.volume !== undefined) { + const predVolume = candleData[4]; + tooltipText += `Volume: ${accuracy.actualCandle[4].toFixed(2)} vs ${predVolume.toFixed(2)} (${accuracy.pctErrors.volume.toFixed(2)}%)`; + } + } else { + tooltipText += `
Status: AWAITING VALIDATION...`; + } // Create ghost candle trace with formatted timestamp string (same as real candles) // 150% wider than normal candles @@ -2236,14 +2557,14 @@ class ChartManager { line: { color: color, width: 3 }, // 150% wider fillcolor: color }, - opacity: 0.6, // 60% transparent - hoverinfo: 'x+y+text', - text: ['Predicted Next Candle'], + opacity: opacity, + hoverinfo: 'text', + text: [tooltipText], width: 1.5 // 150% width multiplier }; traces.push(ghostTrace); - console.log('Added ghost candle prediction at:', formattedTimestamp, ghostTrace); + console.log('Added ghost candle prediction at:', formattedTimestamp, accuracy ? 'VALIDATED' : 'pending'); } _addShadowCandlePrediction(candleData, timestamp, traces) { diff --git a/NN/models/advanced_transformer_trading.py b/NN/models/advanced_transformer_trading.py index 6f7f97f..5532fd4 100644 --- a/NN/models/advanced_transformer_trading.py +++ b/NN/models/advanced_transformer_trading.py @@ -1446,33 +1446,39 @@ class TradingTransformerTrainer: candle_rmse = {} if 'next_candles' in outputs: - # Use 1m timeframe as primary metric - if '1m' in outputs['next_candles'] and 'future_candle_1m' in batch: + # Use 1s or 1m timeframe as primary metric (try 1s first) + if '1s' in outputs['next_candles'] and 'future_candle_1s' in batch: + pred_candle = outputs['next_candles']['1s'] # [batch, 5] + actual_candle = batch['future_candle_1s'] # [batch, 5] + elif '1m' in outputs['next_candles'] and 'future_candle_1m' in batch: pred_candle = outputs['next_candles']['1m'] # [batch, 5] actual_candle = batch['future_candle_1m'] # [batch, 5] + else: + pred_candle = None + actual_candle = None + + if actual_candle is not None and pred_candle is not None and pred_candle.shape == actual_candle.shape: + # Calculate RMSE for each OHLCV component + rmse_open = torch.sqrt(torch.mean((pred_candle[:, 0] - actual_candle[:, 0])**2) + 1e-8) + rmse_high = torch.sqrt(torch.mean((pred_candle[:, 1] - actual_candle[:, 1])**2) + 1e-8) + rmse_low = torch.sqrt(torch.mean((pred_candle[:, 2] - actual_candle[:, 2])**2) + 1e-8) + rmse_close = torch.sqrt(torch.mean((pred_candle[:, 3] - actual_candle[:, 3])**2) + 1e-8) - if actual_candle is not None and pred_candle.shape == actual_candle.shape: - # Calculate RMSE for each OHLCV component - rmse_open = torch.sqrt(torch.mean((pred_candle[:, 0] - actual_candle[:, 0])**2) + 1e-8) - rmse_high = torch.sqrt(torch.mean((pred_candle[:, 1] - actual_candle[:, 1])**2) + 1e-8) - rmse_low = torch.sqrt(torch.mean((pred_candle[:, 2] - actual_candle[:, 2])**2) + 1e-8) - rmse_close = torch.sqrt(torch.mean((pred_candle[:, 3] - actual_candle[:, 3])**2) + 1e-8) - - # Average RMSE for OHLC (exclude volume) - avg_rmse = (rmse_open + rmse_high + rmse_low + rmse_close) / 4 - - # Convert to accuracy: lower RMSE = higher accuracy - # Normalize by price range - price_range = torch.clamp(actual_candle[:, 1].max() - actual_candle[:, 2].min(), min=1e-8) - candle_accuracy = (1.0 - torch.clamp(avg_rmse / price_range, 0, 1)).item() - - candle_rmse = { - 'open': rmse_open.item(), - 'high': rmse_high.item(), - 'low': rmse_low.item(), - 'close': rmse_close.item(), - 'avg': avg_rmse.item() - } + # Average RMSE for OHLC (exclude volume) + avg_rmse = (rmse_open + rmse_high + rmse_low + rmse_close) / 4 + + # Convert to accuracy: lower RMSE = higher accuracy + # Normalize by price range + price_range = torch.clamp(actual_candle[:, 1].max() - actual_candle[:, 2].min(), min=1e-8) + candle_accuracy = (1.0 - torch.clamp(avg_rmse / price_range, 0, 1)).item() + + candle_rmse = { + 'open': rmse_open.item(), + 'high': rmse_high.item(), + 'low': rmse_low.item(), + 'close': rmse_close.item(), + 'avg': avg_rmse.item() + } # SECONDARY: Trend vector prediction accuracy trend_accuracy = 0.0