From d349a1bac0a5d1e413155452fbfe6bc5cf14bae3 Mon Sep 17 00:00:00 2001
From: Dobromir Popov
Date: Sat, 22 Nov 2025 01:14:32 +0200
Subject: [PATCH] Plot ghost transformer predictions on chart; speed up live chart updates

---
 ANNOTATE/core/data_loader.py                  |  27 ++-
 ANNOTATE/core/real_training_adapter.py        | 160 ++++++++++++++++--
 ANNOTATE/web/static/js/chart_manager.js       |  96 ++++++++++-
 .../templates/components/training_panel.html  |  19 ++-
 4 files changed, 272 insertions(+), 30 deletions(-)

diff --git a/ANNOTATE/core/data_loader.py b/ANNOTATE/core/data_loader.py
index b983a8d..826d0d0 100644
--- a/ANNOTATE/core/data_loader.py
+++ b/ANNOTATE/core/data_loader.py
@@ -65,14 +65,26 @@ class HistoricalDataLoader:
         # Check memory cache first (exclude direction from cache key for infinite scroll)
         cache_key = f"{symbol}_{timeframe}_{start_time}_{end_time}_{limit}"
+
+        # Determine TTL based on timeframe: fast timeframes need fresher data
+        current_ttl = self.cache_ttl
+        if timeframe == '1s':
+            current_ttl = timedelta(seconds=1)
+        elif timeframe == '1m':
+            current_ttl = timedelta(seconds=5)
+
         if cache_key in self.memory_cache and direction == 'latest':
             cached_data, cached_time = self.memory_cache[cache_key]
-            if datetime.now() - cached_time < self.cache_ttl:
-                elapsed_ms = (time.time() - start_time_ms) * 1000
-                logger.debug(f"⚡ Memory cache hit for {symbol} {timeframe} ({elapsed_ms:.1f}ms)")
+            if datetime.now() - cached_time < current_ttl:
+                # Skip per-hit debug logging for 1s/1m; it is too chatty at these rates
+                if timeframe not in ['1s', '1m']:
+                    elapsed_ms = (time.time() - start_time_ms) * 1000
+                    logger.debug(f"⚡ Memory cache hit for {symbol} {timeframe} ({elapsed_ms:.1f}ms)")
                 return cached_data
 
         try:
+            # Force refresh for 1s/1m when the caller wants the latest data (no explicit time range)
+            force_refresh = (timeframe in ['1s', '1m'] and not start_time and not end_time)
             # Try to get data from DataProvider's cached data first (most efficient)
             if hasattr(self.data_provider, 'cached_data'):
                 with self.data_provider.data_lock:
@@ -215,7 +227,7 @@ class HistoricalDataLoader:
                 return None
 
             # Fallback: Use DataProvider for latest data (startup mode or no time range)
-            if self.startup_mode and not (start_time or end_time):
+            if self.startup_mode and not (start_time or end_time) and not force_refresh:
                 logger.info(f"Loading data for {symbol} {timeframe} (startup mode: allow stale cache)")
                 df = self.data_provider.get_historical_data(
                     symbol=symbol,
@@ -225,7 +237,12 @@ class HistoricalDataLoader:
                 )
             else:
                 # Fetch from API and store in DuckDB (no time range specified)
-                logger.info(f"Fetching latest data from API for {symbol} {timeframe}")
+                # For 1s/1m, logging every request is too verbose; use debug level
+                if timeframe in ['1s', '1m']:
+                    logger.debug(f"Fetching latest data from API for {symbol} {timeframe}")
+                else:
+                    logger.info(f"Fetching latest data from API for {symbol} {timeframe}")
+
                 df = self.data_provider.get_historical_data(
                     symbol=symbol,
                     timeframe=timeframe,
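The data-loader hunk above keys cache freshness off the timeframe: 1s data goes stale after a second, 1m after five. A minimal standalone sketch of that TTL logic, assuming a dict-based memory cache like HistoricalDataLoader's (names here are illustrative, not the module's API):

    from datetime import datetime, timedelta

    DEFAULT_TTL = timedelta(minutes=5)
    TTL_BY_TIMEFRAME = {'1s': timedelta(seconds=1), '1m': timedelta(seconds=5)}

    def get_cached(cache: dict, key: str, timeframe: str):
        """Return cached data only while it is younger than the timeframe's TTL."""
        ttl = TTL_BY_TIMEFRAME.get(timeframe, DEFAULT_TTL)
        entry = cache.get(key)
        if entry is None:
            return None
        data, cached_at = entry
        return data if datetime.now() - cached_at < ttl else None
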
diff --git a/ANNOTATE/core/real_training_adapter.py b/ANNOTATE/core/real_training_adapter.py
index a28f1ed..015464c 100644
--- a/ANNOTATE/core/real_training_adapter.py
+++ b/ANNOTATE/core/real_training_adapter.py
@@ -16,12 +16,13 @@ import uuid
 import time
 import threading
 import os
-from typing import Dict, List, Optional, Any
+from typing import Dict, List, Optional, Any, Tuple
 from dataclasses import dataclass
 from datetime import datetime, timedelta, timezone
 from pathlib import Path
 
 import torch
+import numpy as np
 
 try:
     import pytz
@@ -2488,7 +2489,7 @@ class RealTrainingAdapter:
             trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
             if trainer and trainer.model:
                 # Get recent market data
-                market_data = self._get_realtime_market_data(symbol, data_provider)
+                market_data, norm_params = self._get_realtime_market_data(symbol, data_provider)
                 if not market_data:
                     return None
@@ -2507,41 +2508,162 @@ class RealTrainingAdapter:
                 actions = ['BUY', 'SELL', 'HOLD']
                 action = actions[action_idx] if action_idx < len(actions) else 'HOLD'
 
+                # Handle predicted candles - DENORMALIZE them
+                predicted_candles_raw = {}
+                if 'next_candles' in outputs:
+                    for tf, tensor in outputs['next_candles'].items():
+                        predicted_candles_raw[tf] = tensor.detach().cpu().numpy().tolist()
+
+                # Denormalize if we have params
+                predicted_candles_denorm = {}
+                if predicted_candles_raw and norm_params:
+                    for tf, raw_candle in predicted_candles_raw.items():
+                        # raw_candle is a [1, 5] nested list
+                        if tf in norm_params:
+                            params = norm_params[tf]
+                            price_min = params['price_min']
+                            price_max = params['price_max']
+                            vol_min = params['volume_min']
+                            vol_max = params['volume_max']
+
+                            # Denormalize [Open, High, Low, Close, Volume]
+                            # Note: raw_candle[0] is the list of 5 values
+                            candle_values = raw_candle[0]
+
+                            denorm_candle = [
+                                candle_values[0] * (price_max - price_min) + price_min,  # Open
+                                candle_values[1] * (price_max - price_min) + price_min,  # High
+                                candle_values[2] * (price_max - price_min) + price_min,  # Low
+                                candle_values[3] * (price_max - price_min) + price_min,  # Close
+                                candle_values[4] * (vol_max - vol_min) + vol_min         # Volume
+                            ]
+                            predicted_candles_denorm[tf] = denorm_candle
+
+                # Calculate predicted price from the candle close
+                predicted_price = None
+                if '1m' in predicted_candles_denorm:
+                    predicted_price = predicted_candles_denorm['1m'][3]  # Close price
+                elif '1s' in predicted_candles_denorm:
+                    predicted_price = predicted_candles_denorm['1s'][3]
+                elif outputs.get('price_prediction') is not None:
+                    # Fallback to the price_prediction head if available (normalized);
+                    # this would need separate denormalization based on a reference price
+                    pass
+
                 return {
                     'action': action,
-                    'confidence': confidence
+                    'confidence': confidence,
+                    'predicted_price': predicted_price,
+                    'predicted_candle': predicted_candles_denorm
                 }
 
             return None
 
         except Exception as e:
             logger.debug(f"Error making realtime prediction: {e}")
+            import traceback
+            logger.debug(traceback.format_exc())
             return None
 
-    def _get_realtime_market_data(self, symbol: str, data_provider) -> Dict:
-        """Get current market data for prediction"""
+    def _get_realtime_market_data(self, symbol: str, data_provider) -> Tuple[Dict, Dict]:
+        """
+        Get current market data for prediction AND normalization parameters
+
+        Returns:
+            Tuple of (model_inputs_dict, normalization_params_dict)
+        """
         try:
             # Get recent candles for all timeframes
             data = {}
+            norm_params = {}
+
             for tf in ['1s', '1m', '1h', '1d']:
-                df = data_provider.get_historical_data(symbol, tf, limit=200)
+                # Get historical data (raw)
+                # Force refresh for 1s/1m so we have the very latest candle for prediction
+                refresh = tf in ['1s', '1m']
+                df = data_provider.get_historical_data(symbol, tf, limit=600, refresh=refresh)
                 if df is not None and not df.empty:
-                    # Convert to tensor format (simplified)
-                    import torch
-                    import numpy as np
+                    # Extract raw arrays
+                    opens = df['open'].values.astype(np.float32)
+                    highs = df['high'].values.astype(np.float32)
+                    lows = df['low'].values.astype(np.float32)
+                    closes = df['close'].values.astype(np.float32)
+                    volumes = df['volume'].values.astype(np.float32)
 
-                    candles = df[['open', 'high', 'low', 'close', 'volume']].values
-                    candles_tensor = torch.tensor(candles, dtype=torch.float32).unsqueeze(0)
+                    # Need at least 1 candle
+                    if len(closes) == 0:
+                        continue
+
+                    # Prepare OHLCV for the normalization logic
+                    # Pad if needed (though limit=600 usually suffices)
+                    if len(closes) < 600:
+                        pad_len = 600 - len(closes)
+                        # Pad at the front, repeating the earliest values (edge padding)
+                        opens = np.pad(opens, (pad_len, 0), mode='edge')
+                        highs = np.pad(highs, (pad_len, 0), mode='edge')
+                        lows = np.pad(lows, (pad_len, 0), mode='edge')
+                        closes = np.pad(closes, (pad_len, 0), mode='edge')
+                        volumes = np.pad(volumes, (pad_len, 0), mode='edge')
+                    else:
+                        # Take the last 600
+                        opens = opens[-600:]
+                        highs = highs[-600:]
+                        lows = lows[-600:]
+                        closes = closes[-600:]
+                        volumes = volumes[-600:]
+
+                    # Stack OHLCV [seq_len, 5]
+                    ohlcv = np.stack([opens, highs, lows, closes, volumes], axis=-1)
+
+                    # Calculate min/max for normalization
+                    price_min = np.min(ohlcv[:, :4])
+                    price_max = np.max(ohlcv[:, :4])
+                    volume_min = np.min(ohlcv[:, 4])
+                    volume_max = np.max(ohlcv[:, 4])
+
+                    # Avoid division by zero
+                    if price_max == price_min:
+                        price_max += 1.0
+                    if volume_max == volume_min:
+                        volume_max += 1.0
+
+                    # Store params for denormalization later
+                    norm_params[tf] = {
+                        'price_min': float(price_min),
+                        'price_max': float(price_max),
+                        'volume_min': float(volume_min),
+                        'volume_max': float(volume_max)
+                    }
+
+                    # Normalize in place
+                    ohlcv[:, :4] = (ohlcv[:, :4] - price_min) / (price_max - price_min)
+                    ohlcv[:, 4] = (ohlcv[:, 4] - volume_min) / (volume_max - volume_min)
+
+                    # Convert to tensor [1, seq_len, 5]
+                    candles_tensor = torch.tensor(ohlcv, dtype=torch.float32).unsqueeze(0)
                     data[f'price_data_{tf}'] = candles_tensor
 
-            # Add placeholder data for other inputs
-            import torch
-            data['tech_data'] = torch.zeros(1, 40, dtype=torch.float32)
-            data['market_data'] = torch.zeros(1, 30, dtype=torch.float32)
+            # Add placeholder data for other inputs if we have at least one timeframe
+            if data:
+                # Shapes the model expects:
+                #   tech_data:   [1, 40]
+                #   market_data: [1, 30]
+                #   cob_data:    [1, 600, 100]
+                data['tech_data'] = torch.zeros(1, 40, dtype=torch.float32)
+                data['market_data'] = torch.zeros(1, 30, dtype=torch.float32)
+                data['cob_data'] = torch.zeros(1, 600, 100, dtype=torch.float32)
+
+                # Move to device if available
+                if hasattr(self.orchestrator, 'device'):
+                    device = self.orchestrator.device
+                    for k, v in data.items():
+                        data[k] = v.to(device)
 
-            return data if data else None
+            return (data, norm_params) if data else (None, None)
 
         except Exception as e:
             logger.debug(f"Error getting realtime market data: {e}")
-            return None
+            import traceback
+            logger.debug(traceback.format_exc())
+            return None, None
 
     def _train_on_new_candle(self, session: Dict, symbol: str, timeframe: str, data_provider):
         """Train model on new candle when it closes"""
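The adapter normalizes each timeframe's window with per-window min/max and must invert that exactly when turning model outputs back into prices. A small standalone round-trip check of that math (numpy only; names illustrative):

    import numpy as np

    def normalize(ohlc: np.ndarray, lo: float, hi: float) -> np.ndarray:
        # Map prices in [lo, hi] onto [0, 1]
        return (ohlc - lo) / (hi - lo)

    def denormalize(ohlc_norm: np.ndarray, lo: float, hi: float) -> np.ndarray:
        # Invert the min-max scaling: x_norm * (hi - lo) + lo
        return ohlc_norm * (hi - lo) + lo

    candle = np.array([101.0, 103.5, 100.2, 102.8], dtype=np.float32)  # O, H, L, C
    lo, hi = 95.0, 110.0  # window price_min / price_max
    assert np.allclose(denormalize(normalize(candle, lo, hi), lo, hi), candle, atol=1e-4)
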
@@ -2734,7 +2856,9 @@ class RealTrainingAdapter:
                 'model': model_name,
                 'action': prediction['action'],
                 'confidence': prediction['confidence'],
-                'price': current_price
+                'price': current_price,
+                'predicted_price': prediction.get('predicted_price'),
+                'predicted_candle': prediction.get('predicted_candle')
             }
 
             session['signals'].append(signal)
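For reference, one entry in session['signals'] as the frontend poller would now receive it, per the fields added above (key names from the patch, values made up):

    signal = {
        'model': 'transformer',
        'action': 'BUY',
        'confidence': 0.72,
        'price': 102.8,              # current market price
        'predicted_price': 103.1,    # close of the denormalized predicted candle
        'predicted_candle': {
            '1m': [102.9, 103.4, 102.7, 103.1, 1250.0],  # [O, H, L, C, V]
        },
    }
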
diff --git a/ANNOTATE/web/static/js/chart_manager.js b/ANNOTATE/web/static/js/chart_manager.js
index 9cc99f5..b1203e4 100644
--- a/ANNOTATE/web/static/js/chart_manager.js
+++ b/ANNOTATE/web/static/js/chart_manager.js
@@ -27,14 +27,14 @@ class ChartManager {
         this.autoUpdateEnabled = true;
         console.log('Starting chart auto-update...');
 
-        // Update 1s chart every 20 seconds
+        // Update 1s chart every 2 seconds (was 20s)
         if (this.timeframes.includes('1s')) {
            this.updateTimers['1s'] = setInterval(() => {
                 this.updateChart('1s');
-            }, 20000); // 20 seconds
+            }, 2000); // 2 seconds
        }
 
-        // Update 1m chart - sync to whole minutes + every 20s
+        // Update 1m chart - sync to whole minutes + every 5s (was 20s)
         if (this.timeframes.includes('1m')) {
             // Calculate ms until next whole minute
             const now = new Date();
@@ -44,10 +44,10 @@ class ChartManager {
 
             setTimeout(() => {
                 this.updateChart('1m');
-                // Then update every 20s
+                // Then update every 5s
                 this.updateTimers['1m'] = setInterval(() => {
                     this.updateChart('1m');
-                }, 20000); // 20 seconds
+                }, 5000); // 5 seconds
             }, msUntilNextMinute);
         }
 
@@ -1758,6 +1758,7 @@ class ChartManager {
             // Prepare prediction markers
             const predictionShapes = [];
             const predictionAnnotations = [];
+            const predictionTraces = []; // New traces for ghost candles
 
             // Add DQN predictions (arrows)
             if (predictions.dqn) {
@@ -1769,9 +1770,18 @@ class ChartManager {
                 this._addCNNPrediction(predictions.cnn, predictionShapes, predictionAnnotations);
             }
 
-            // Add Transformer predictions (star markers with trend lines)
+            // Add Transformer predictions (star markers with trend lines + ghost candles)
             if (predictions.transformer) {
                 this._addTransformerPrediction(predictions.transformer, predictionShapes, predictionAnnotations);
+
+                // Add ghost candle if available
+                if (predictions.transformer.predicted_candle) {
+                    // Check if we have a prediction for this timeframe
+                    const candleData = predictions.transformer.predicted_candle[timeframe];
+                    if (candleData) {
+                        this._addGhostCandlePrediction(candleData, timeframe, predictionTraces);
+                    }
+                }
             }
 
             // Update chart layout with predictions
@@ -1782,11 +1792,85 @@ class ChartManager {
                 });
             }
 
+            // Add prediction traces (ghost candles)
+            if (predictionTraces.length > 0) {
+                // Remove existing ghost traces first (heuristic: traces named 'Ghost Prediction')
+                const traceCount = plotElement.data.length;
+                const indicesToRemove = [];
+                for (let i = 0; i < traceCount; i++) {
+                    if (plotElement.data[i].name === 'Ghost Prediction') {
+                        indicesToRemove.push(i);
+                    }
+                }
+                if (indicesToRemove.length > 0) {
+                    Plotly.deleteTraces(plotId, indicesToRemove);
+                }
+
+                // Add the new traces
+                Plotly.addTraces(plotId, predictionTraces);
+            }
+
         } catch (error) {
             console.debug('Error updating predictions:', error);
         }
     }
 
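The ghost-candle trace that _addGhostCandlePrediction (just below) builds has a direct equivalent in Plotly's Python API; a minimal sketch with the same attributes, for anyone prototyping the look server-side (timestamps and prices are illustrative):

    import plotly.graph_objects as go

    # Same 'Ghost Prediction' trace as the JS object below: one semi-transparent bar
    ghost = go.Candlestick(
        x=['2025-11-22 01:15:00'],
        open=[102.9], high=[103.4], low=[102.7], close=[103.1],
        name='Ghost Prediction',
        increasing=dict(line=dict(color='#10b981', width=1), fillcolor='#10b981'),
        decreasing=dict(line=dict(color='#ef4444', width=1), fillcolor='#ef4444'),
        opacity=0.3,  # 30% opacity, i.e. mostly transparent
    )
    fig = go.Figure(data=[ghost])
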
+    _addGhostCandlePrediction(candleData, timeframe, traces) {
+        // candleData is [Open, High, Low, Close, Volume].
+        // The ghost candle's timestamp should be the NEXT candle
+        // after the last one on the chart.
+
+        const chart = this.charts[timeframe];
+        if (!chart || !chart.data) return;
+
+        const lastTimestamp = new Date(chart.data.timestamps[chart.data.timestamps.length - 1]);
+        let nextTimestamp;
+
+        // Calculate next timestamp based on timeframe
+        if (timeframe === '1s') {
+            nextTimestamp = new Date(lastTimestamp.getTime() + 1000);
+        } else if (timeframe === '1m') {
+            nextTimestamp = new Date(lastTimestamp.getTime() + 60000);
+        } else if (timeframe === '1h') {
+            nextTimestamp = new Date(lastTimestamp.getTime() + 3600000);
+        } else {
+            nextTimestamp = new Date(lastTimestamp.getTime() + 60000); // Default to 1m
+        }
+
+        const open = candleData[0];
+        const high = candleData[1];
+        const low = candleData[2];
+        const close = candleData[3];
+
+        // Determine color
+        const color = close >= open ? '#10b981' : '#ef4444';
+
+        // Create ghost candle trace
+        const ghostTrace = {
+            x: [nextTimestamp],
+            open: [open],
+            high: [high],
+            low: [low],
+            close: [close],
+            type: 'candlestick',
+            name: 'Ghost Prediction',
+            increasing: {
+                line: { color: color, width: 1 },
+                fillcolor: color
+            },
+            decreasing: {
+                line: { color: color, width: 1 },
+                fillcolor: color
+            },
+            opacity: 0.3, // 30% opacity (mostly transparent)
+            hoverinfo: 'x+y+text',
+            text: ['Predicted Next Candle']
+        };
+
+        traces.push(ghostTrace);
+        console.log('Added ghost candle prediction:', ghostTrace);
+    }
+
     _addDQNPrediction(prediction, shapes, annotations) {
         const timestamp = new Date(prediction.timestamp || Date.now());
         const price = prediction.current_price || 0;
diff --git a/ANNOTATE/web/templates/components/training_panel.html b/ANNOTATE/web/templates/components/training_panel.html
index 6e4267b..25aac1b 100644
--- a/ANNOTATE/web/templates/components/training_panel.html
+++ b/ANNOTATE/web/templates/components/training_panel.html
@@ -576,6 +576,11 @@
         // Start polling for signals
         startSignalPolling();
 
+        // Start chart auto-update
+        if (window.appState && window.appState.chartManager) {
+            window.appState.chartManager.startAutoUpdate();
+        }
+
         const trainingMode = data.training_mode || 'inference-only';
         const modeText = trainingMode === 'per-candle' ? ' with per-candle training' : (trainingMode === 'pivot-based' ? ' with pivot training' : '');
@@ -630,6 +635,11 @@
 
         // Stop polling
         stopSignalPolling();
+
+        // Stop chart auto-update
+        if (window.appState && window.appState.chartManager) {
+            window.appState.chartManager.stopAutoUpdate();
+        }
 
         currentInferenceId = null;
         showSuccess('Real-time inference stopped');
@@ -932,9 +942,16 @@
         }
         updatePredictionHistory();
 
-        // Update chart with signal markers
+        // Update chart with signal markers and predictions
         if (window.appState && window.appState.chartManager) {
             displaySignalOnChart(latest);
+
+            // Update ghost candles and other predictions
+            const predictions = {};
+            const modelKey = latest.model ? latest.model.toLowerCase() : 'transformer';
+            predictions[modelKey] = latest;
+
+            window.appState.chartManager.updatePredictions(predictions);
         }
     }
 })
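The ghost candle above is stamped one interval past the last real candle on the chart. The same arithmetic as a standalone Python helper, mirroring the JS branches (anything outside 1s/1m/1h falls back to one minute, as in the patch):

    from datetime import datetime, timedelta

    # Interval per timeframe, matching the branches in _addGhostCandlePrediction
    STEP = {'1s': timedelta(seconds=1), '1m': timedelta(minutes=1), '1h': timedelta(hours=1)}

    def next_candle_timestamp(last: datetime, timeframe: str) -> datetime:
        """Timestamp of the predicted (ghost) candle: one interval after the last real one."""
        return last + STEP.get(timeframe, timedelta(minutes=1))

    assert next_candle_timestamp(datetime(2025, 11, 22, 1, 14), '1m') == datetime(2025, 11, 22, 1, 15)
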