diff --git a/ANNOTATE/core/real_training_adapter.py b/ANNOTATE/core/real_training_adapter.py
index 77fe39d..34f9285 100644
--- a/ANNOTATE/core/real_training_adapter.py
+++ b/ANNOTATE/core/real_training_adapter.py
@@ -1095,7 +1095,8 @@ class RealTrainingAdapter:
raise Exception("CNN model does not have train_on_annotations, trainer.train_step, or train_step method")
session.final_loss = session.current_loss
- session.accuracy = 0.85 # TODO: Calculate actual accuracy
+        # Do not report a synthetic accuracy value
+        session.accuracy = None  # Set by the training loop when real metrics are available
def _train_dqn_real(self, session: TrainingSession, training_data: List[Dict]):
"""Train DQN model with REAL training loop"""
@@ -1133,7 +1134,8 @@ class RealTrainingAdapter:
raise Exception("DQN agent does not have replay method")
session.final_loss = session.current_loss
- session.accuracy = 0.85 # TODO: Calculate actual accuracy
+        # Do not report a synthetic accuracy value
+        session.accuracy = None  # Set by the training loop when real metrics are available
def _build_state_from_data(self, data: Dict, agent: Any) -> List[float]:
"""Build proper state representation from training data"""
@@ -2781,6 +2783,68 @@ class RealTrainingAdapter:
logger.warning(f"Error fetching market state for candle: {e}")
return {}
+ def _convert_prediction_to_batch(self, prediction_sample: Dict, timeframe: str):
+ """
+ Convert a validated prediction to a training batch
+
+ Args:
+ prediction_sample: Dict with predicted_candle, actual_candle, market_state, etc.
+ timeframe: Target timeframe for prediction
+
+ Returns:
+ Batch dict ready for trainer.train_step()
+ """
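+        # Illustrative shape of prediction_sample, inferred from the fields read below:
+        #   {
+        #       'symbol': 'ETH/USDT',
+        #       'timestamp': '2024-01-01 12:00:00',
+        #       'predicted_candle': [open, high, low, close, volume],
+        #       'actual_candle': [open, high, low, close],
+        #       'market_state': {'timeframes': {...}},
+        #   }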
+ try:
+ market_state = prediction_sample.get('market_state', {})
+ if not market_state or 'timeframes' not in market_state:
+ logger.warning("No market state in prediction sample")
+ return None
+
+ # Use existing conversion method but with actual target
+ annotation = {
+ 'symbol': prediction_sample.get('symbol', 'ETH/USDT'),
+ 'timestamp': prediction_sample.get('timestamp'),
+ 'action': 'BUY', # Placeholder, not used for candle prediction training
+ 'entry_price': float(prediction_sample['predicted_candle'][0]), # Open
+ 'market_state': market_state
+ }
+
+ # Convert using existing method
+ batch = self._convert_annotation_to_transformer_batch(annotation)
+ if not batch:
+ return None
+
+ # Override the future candle target with actual candle data
+ actual = prediction_sample['actual_candle'] # [O, H, L, C]
+
+ # Create target tensor for the specific timeframe
+ import torch
+ device = batch['prices_1m'].device if 'prices_1m' in batch else torch.device('cpu')
+
+ # Target candle: [O, H, L, C, V] - we don't have actual volume, use predicted
+ target_candle = [
+ actual[0], # Open
+ actual[1], # High
+ actual[2], # Low
+ actual[3], # Close
+ prediction_sample['predicted_candle'][4] # Volume (from prediction)
+ ]
+
+ # Add to batch based on timeframe
+ if timeframe == '1s':
+ batch['future_candle_1s'] = torch.tensor([target_candle], dtype=torch.float32, device=device)
+ elif timeframe == '1m':
+ batch['future_candle_1m'] = torch.tensor([target_candle], dtype=torch.float32, device=device)
+ elif timeframe == '1h':
+ batch['future_candle_1h'] = torch.tensor([target_candle], dtype=torch.float32, device=device)
+
+ logger.debug(f"Converted prediction to batch for {timeframe} timeframe")
+ return batch
+
+ except Exception as e:
+ logger.error(f"Error converting prediction to batch: {e}", exc_info=True)
+ return None
+
def _train_transformer_on_sample(self, training_sample: Dict):
"""Train transformer on a single sample with checkpoint saving"""
try:
diff --git a/ANNOTATE/web/app.py b/ANNOTATE/web/app.py
index 0eafc64..79e13f4 100644
--- a/ANNOTATE/web/app.py
+++ b/ANNOTATE/web/app.py
@@ -2369,6 +2369,55 @@ class AnnotationDashboard:
except Exception as e:
logger.error(f"Error handling prediction request: {e}")
emit('prediction_error', {'error': str(e)})
+
+ @self.socketio.on('prediction_accuracy')
+ def handle_prediction_accuracy(data):
+ """
+ Handle validated prediction accuracy - trigger incremental training
+
+ This is called when frontend validates a prediction against actual candle.
+ We use this data to incrementally train the model for continuous improvement.
+ """
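+            # Illustrative payload, mirroring what _sendPredictionMetrics emits in chart_manager.js:
+            #   {'timeframe': '1s', 'timestamp': '2024-01-01 12:00:00',
+            #    'predicted': [O, H, L, C, V], 'actual': [O, H, L, C, V],
+            #    'errors': {...}, 'pctErrors': {...},
+            #    'directionCorrect': True, 'accuracy': 87.5}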
+ from flask_socketio import emit
+ try:
+ timeframe = data.get('timeframe')
+ timestamp = data.get('timestamp')
+ predicted = data.get('predicted') # [O, H, L, C, V]
+            actual = data.get('actual')  # [O, H, L, C, V] - the volume element may echo the prediction when no actual volume was available
+ errors = data.get('errors') # {open, high, low, close}
+ pct_errors = data.get('pctErrors')
+ direction_correct = data.get('directionCorrect')
+ accuracy = data.get('accuracy')
+
+ if not all([timeframe, timestamp, predicted, actual]):
+ logger.warning("Incomplete prediction accuracy data received")
+ return
+
+ logger.info(f"[{timeframe}] Prediction validated: {accuracy:.1f}% accuracy, direction: {direction_correct}")
+ logger.debug(f" Errors: O={pct_errors['open']:.2f}% H={pct_errors['high']:.2f}% L={pct_errors['low']:.2f}% C={pct_errors['close']:.2f}%")
+
+ # Trigger incremental training on this validated prediction
+ self._train_on_validated_prediction(
+ timeframe=timeframe,
+ timestamp=timestamp,
+ predicted=predicted,
+ actual=actual,
+ errors=errors,
+ direction_correct=direction_correct,
+ accuracy=accuracy
+ )
+
+ # Send confirmation back to frontend
+ emit('training_update', {
+ 'status': 'training_triggered',
+ 'timestamp': timestamp,
+ 'accuracy': accuracy,
+                'message': 'Incremental training triggered on validated prediction'
+ })
+
+ except Exception as e:
+ logger.error(f"Error handling prediction accuracy: {e}", exc_info=True)
+ emit('training_error', {'error': str(e)})
def _start_live_update_thread(self):
"""Start background thread for live updates"""
@@ -2392,24 +2441,44 @@ class AnnotationDashboard:
for timeframe in ['1s', '1m']:
room = f"{symbol}_{timeframe}"
- # Get latest candle
+ # Get latest candles (need last 2 to determine confirmation status)
try:
- candles = self.data_provider.get_ohlcv(symbol, timeframe, limit=1)
+ candles = self.data_provider.get_ohlcv(symbol, timeframe, limit=2)
if candles and len(candles) > 0:
latest_candle = candles[-1]
- # Emit chart update
+                            # A candle is confirmed (closed) once the next candle has started forming
+                            # (for 1s data this means roughly a 2s delay).
+                            # With limit=2, candles[-2] is confirmed while candles[-1] may still be forming.
+                            is_confirmed = len(candles) >= 2  # True when a confirmed previous candle exists
+
+ # Format timestamp consistently
+ timestamp = latest_candle.get('timestamp')
+ if isinstance(timestamp, str):
+ # Already formatted
+ formatted_timestamp = timestamp
+ else:
+                                # Format datetime objects; fall back to str() otherwise
+ from datetime import datetime
+ if isinstance(timestamp, datetime):
+ formatted_timestamp = timestamp.strftime('%Y-%m-%d %H:%M:%S')
+ else:
+ formatted_timestamp = str(timestamp)
+
+ # Emit chart update with full candle data
self.socketio.emit('chart_update', {
'symbol': symbol,
'timeframe': timeframe,
'candle': {
- 'timestamp': latest_candle.get('timestamp'),
- 'open': latest_candle.get('open'),
- 'high': latest_candle.get('high'),
- 'low': latest_candle.get('low'),
- 'close': latest_candle.get('close'),
- 'volume': latest_candle.get('volume')
- }
+ 'timestamp': formatted_timestamp,
+ 'open': float(latest_candle.get('open', 0)),
+ 'high': float(latest_candle.get('high', 0)),
+ 'low': float(latest_candle.get('low', 0)),
+ 'close': float(latest_candle.get('close', 0)),
+ 'volume': float(latest_candle.get('volume', 0))
+ },
+                                'is_confirmed': is_confirmed,  # True when a confirmed (closed) previous candle exists
+ 'has_previous': len(candles) >= 2 # True if we have previous candle for validation
}, room=room)
# Get prediction if model is loaded
@@ -2430,6 +2499,144 @@ class AnnotationDashboard:
self._live_update_thread = threading.Thread(target=live_update_worker, daemon=True)
self._live_update_thread.start()
+ def _train_on_validated_prediction(self, timeframe: str, timestamp: str, predicted: list,
+ actual: list, errors: dict, direction_correct: bool, accuracy: float):
+ """
+ Incrementally train model on validated prediction
+
+ This implements online learning where each validated prediction becomes
+ a training sample, with loss weighting based on prediction accuracy.
+ """
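+        # Flow (summarizing the steps below): weight the sample by its validated
+        # accuracy, rebuild the market state at the prediction timestamp, convert
+        # it to a transformer batch, run one train_step, and checkpoint every 10 steps.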
+ try:
+ if not self.training_adapter:
+ logger.warning("Training adapter not available for incremental training")
+ return
+
+ if not self.orchestrator or not hasattr(self.orchestrator, 'primary_transformer'):
+ logger.warning("Transformer model not available for incremental training")
+ return
+
+ # Get the transformer trainer
+ trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
+ if not trainer:
+ logger.warning("Transformer trainer not available")
+ return
+
+ # Calculate sample weight based on accuracy
+ # Low accuracy predictions get higher weight (we need to learn from mistakes)
+ # High accuracy predictions get lower weight (model already knows this)
+ if accuracy < 50:
+ sample_weight = 3.0 # Learn hard from bad predictions
+ elif accuracy < 70:
+ sample_weight = 2.0 # Moderate learning
+ elif accuracy < 85:
+ sample_weight = 1.0 # Normal learning
+ else:
+ sample_weight = 0.5 # Light touch-up for good predictions
+
+ # Also weight by direction correctness
+ if not direction_correct:
+ sample_weight *= 1.5 # Wrong direction is critical - learn more
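+            # Example: 45% accuracy with the wrong direction gives weight 3.0 * 1.5 = 4.5,
+            # while a 90% accurate, correct-direction prediction gives 0.5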
+
+ logger.info(f"[{timeframe}] Incremental training: accuracy={accuracy:.1f}%, weight={sample_weight:.1f}x")
+
+ # Create training sample from validated prediction
+ # We need to fetch the market state at that timestamp
+ symbol = 'ETH/USDT' # TODO: Get from active trading pair
+
+ training_sample = {
+ 'symbol': symbol,
+ 'timestamp': timestamp,
+ 'predicted_candle': predicted, # [O, H, L, C, V]
+ 'actual_candle': actual, # [O, H, L, C]
+ 'errors': errors,
+ 'accuracy': accuracy,
+ 'direction_correct': direction_correct,
+ 'sample_weight': sample_weight
+ }
+
+ # Get market state at that timestamp
+ try:
+ market_state = self._fetch_market_state_at_timestamp(symbol, timestamp, timeframe)
+ training_sample['market_state'] = market_state
+ except Exception as e:
+ logger.warning(f"Could not fetch market state: {e}")
+ return
+
+ # Convert to transformer batch format
+ batch = self.training_adapter._convert_prediction_to_batch(training_sample, timeframe)
+ if not batch:
+ logger.warning("Could not convert validated prediction to training batch")
+ return
+
+            # Train on this batch with sample weighting
+            import torch  # Local import, consistent with the other lazy imports in this module
+            with torch.enable_grad():
+ trainer.model.train()
+ result = trainer.train_step(batch, accumulate_gradients=False, sample_weight=sample_weight)
+
+ if result:
+ loss = result.get('total_loss', 0)
+ candle_accuracy = result.get('candle_accuracy', 0)
+
+ logger.info(f"[{timeframe}] Trained on validated prediction: loss={loss:.4f}, new_acc={candle_accuracy:.2%}")
+
+ # Save checkpoint periodically (every 10 incremental steps)
+ if not hasattr(self, '_incremental_training_steps'):
+ self._incremental_training_steps = 0
+
+ self._incremental_training_steps += 1
+
+ if self._incremental_training_steps % 10 == 0:
+ logger.info(f"Saving checkpoint after {self._incremental_training_steps} incremental training steps")
+ trainer.save_checkpoint(
+ filepath=None, # Auto-generate path
+ metadata={
+ 'training_type': 'incremental_online',
+ 'steps': self._incremental_training_steps,
+ 'last_accuracy': accuracy
+ }
+ )
+
+ except Exception as e:
+ logger.error(f"Error in incremental training: {e}", exc_info=True)
+
+ def _fetch_market_state_at_timestamp(self, symbol: str, timestamp: str, timeframe: str) -> Dict:
+ """Fetch market state at a specific timestamp for training"""
+ try:
+ from datetime import datetime
+ import pandas as pd
+
+ # Parse timestamp
+ ts = pd.Timestamp(timestamp)
+
+ # Get historical data for multiple timeframes
+ market_state = {'timeframes': {}, 'secondary_timeframes': {}}
+
+ for tf in ['1s', '1m', '1h']:
+ try:
+ df = self.data_provider.get_historical_data(symbol, tf, limit=200)
+ if df is not None and not df.empty:
+ # Find data up to (but not including) the target timestamp
+ df_before = df[df.index < ts]
+ if not df_before.empty:
+ recent = df_before.tail(200)
+ market_state['timeframes'][tf] = {
+ 'timestamps': recent.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
+ 'open': recent['open'].tolist(),
+ 'high': recent['high'].tolist(),
+ 'low': recent['low'].tolist(),
+ 'close': recent['close'].tolist(),
+ 'volume': recent['volume'].tolist()
+ }
+ except Exception as e:
+ logger.warning(f"Could not fetch {tf} data: {e}")
+
+ return market_state
+
+ except Exception as e:
+ logger.error(f"Error fetching market state: {e}")
+ return {}
+
def _get_live_prediction(self, symbol: str, timeframe: str, prediction_steps: int = 1):
"""Get live prediction from model"""
try:
diff --git a/ANNOTATE/web/static/js/chart_manager.js b/ANNOTATE/web/static/js/chart_manager.js
index 1024331..9366b68 100644
--- a/ANNOTATE/web/static/js/chart_manager.js
+++ b/ANNOTATE/web/static/js/chart_manager.js
@@ -15,8 +15,8 @@ class ChartManager {
this.lastPredictionUpdate = {}; // Track last prediction update per timeframe
this.predictionUpdateThrottle = 500; // Min ms between prediction updates
this.lastPredictionHash = null; // Track if predictions actually changed
- this.ghostCandleHistory = {}; // Store ghost candles per timeframe (max 10 each)
- this.maxGhostCandles = 10; // Maximum number of ghost candles to keep
+        this.ghostCandleHistory = {}; // Store ghost candles per timeframe (capped at maxGhostCandles each)
+        this.maxGhostCandles = 150; // Maximum number of ghost candles to keep per timeframe
// Helper to ensure all timestamps are in UTC
this.normalizeTimestamp = (timestamp) => {
@@ -264,15 +264,43 @@ class ChartManager {
*/
updateLatestCandle(symbol, timeframe, candle) {
try {
- const plotId = `plot-${timeframe}`;
- const plotElement = document.getElementById(plotId);
-
- if (!plotElement) {
- console.debug(`Chart ${plotId} not found for live update`);
+ const chart = this.charts[timeframe];
+ if (!chart) {
+ console.debug(`Chart ${timeframe} not found for live update`);
return;
}
- // Get current chart data
+ const plotId = chart.plotId;
+ const plotElement = document.getElementById(plotId);
+
+ if (!plotElement) {
+ console.debug(`Plot element ${plotId} not found`);
+ return;
+ }
+
+ // Ensure chart.data exists
+ if (!chart.data) {
+ chart.data = {
+ timestamps: [],
+ open: [],
+ high: [],
+ low: [],
+ close: [],
+ volume: []
+ };
+ }
+
+ // Parse timestamp - format to match chart data format
+ const candleTimestamp = new Date(candle.timestamp);
+ const year = candleTimestamp.getUTCFullYear();
+ const month = String(candleTimestamp.getUTCMonth() + 1).padStart(2, '0');
+ const day = String(candleTimestamp.getUTCDate()).padStart(2, '0');
+ const hours = String(candleTimestamp.getUTCHours()).padStart(2, '0');
+ const minutes = String(candleTimestamp.getUTCMinutes()).padStart(2, '0');
+ const seconds = String(candleTimestamp.getUTCSeconds()).padStart(2, '0');
+ const formattedTimestamp = `${year}-${month}-${day} ${hours}:${minutes}:${seconds}`;
+
+ // Get current chart data from Plotly
const chartData = Plotly.Plots.data(plotId);
if (!chartData || chartData.length < 2) {
console.debug(`Chart ${plotId} not initialized yet`);
@@ -282,17 +310,14 @@ class ChartManager {
const candlestickTrace = chartData[0];
const volumeTrace = chartData[1];
- // Parse timestamp
- const candleTimestamp = new Date(candle.timestamp);
-
// Check if this is updating the last candle or adding a new one
const lastTimestamp = candlestickTrace.x[candlestickTrace.x.length - 1];
const isNewCandle = !lastTimestamp || new Date(lastTimestamp).getTime() < candleTimestamp.getTime();
if (isNewCandle) {
- // Add new candle using extendTraces (most efficient)
+ // Add new candle - update both Plotly and internal data structure
Plotly.extendTraces(plotId, {
- x: [[candleTimestamp]],
+ x: [[formattedTimestamp]],
open: [[candle.open]],
high: [[candle.high]],
low: [[candle.low]],
@@ -302,27 +327,34 @@ class ChartManager {
// Update volume color based on price direction
const volumeColor = candle.close >= candle.open ? '#10b981' : '#ef4444';
Plotly.extendTraces(plotId, {
- x: [[candleTimestamp]],
+ x: [[formattedTimestamp]],
y: [[candle.volume]],
marker: { color: [[volumeColor]] }
}, [1]);
- } else {
- // Update last candle using restyle - simpler approach for updating single point
- // We need to get the full arrays, modify last element, and send back
- // This is less efficient but more reliable for updates than complex index logic
- const x = candlestickTrace.x;
- const open = candlestickTrace.open;
- const high = candlestickTrace.high;
- const low = candlestickTrace.low;
- const close = candlestickTrace.close;
- const volume = volumeTrace.y;
- const colors = volumeTrace.marker.color;
+ // Update internal data structure
+ chart.data.timestamps.push(formattedTimestamp);
+ chart.data.open.push(candle.open);
+ chart.data.high.push(candle.high);
+ chart.data.low.push(candle.low);
+ chart.data.close.push(candle.close);
+ chart.data.volume.push(candle.volume);
+
+ console.log(`[${timeframe}] Added new candle: ${formattedTimestamp}`);
+ } else {
+ // Update last candle - update both Plotly and internal data structure
+ const x = [...candlestickTrace.x];
+ const open = [...candlestickTrace.open];
+ const high = [...candlestickTrace.high];
+ const low = [...candlestickTrace.low];
+ const close = [...candlestickTrace.close];
+ const volume = [...volumeTrace.y];
+ const colors = Array.isArray(volumeTrace.marker.color) ? [...volumeTrace.marker.color] : [volumeTrace.marker.color];
const lastIdx = x.length - 1;
// Update local arrays
- x[lastIdx] = candleTimestamp;
+ x[lastIdx] = formattedTimestamp;
open[lastIdx] = candle.open;
high[lastIdx] = candle.high;
low[lastIdx] = candle.low;
@@ -344,9 +376,55 @@ class ChartManager {
y: [volume],
'marker.color': [colors]
}, [1]);
+
+ // Update internal data structure
+ if (chart.data.timestamps.length > lastIdx) {
+ chart.data.timestamps[lastIdx] = formattedTimestamp;
+ chart.data.open[lastIdx] = candle.open;
+ chart.data.high[lastIdx] = candle.high;
+ chart.data.low[lastIdx] = candle.low;
+ chart.data.close[lastIdx] = candle.close;
+ chart.data.volume[lastIdx] = candle.volume;
+ }
+
+ console.log(`[${timeframe}] Updated last candle: ${formattedTimestamp}`);
}
- console.debug(`Updated ${timeframe} chart with new candle at ${candleTimestamp.toISOString()}`);
+ // CRITICAL: Check if we have enough candles to validate predictions (2s delay logic)
+ // For 1s timeframe: validate against candle[-2] (last confirmed), overlay on candle[-1] (currently forming)
+ // For other timeframes: validate against candle[-1] when it's confirmed
+ if (chart.data.timestamps.length >= 2) {
+ // Determine which candle to validate against based on timeframe
+ let validationCandleIdx = -1;
+
+ if (timeframe === '1s') {
+ // 2s delay: validate against candle[-2] (last confirmed)
+ // This candle was closed 1-2 seconds ago
+ validationCandleIdx = chart.data.timestamps.length - 2;
+ } else {
+ // For longer timeframes, validate against last candle when it's confirmed
+ // A candle is confirmed when a new one starts forming
+ validationCandleIdx = isNewCandle ? chart.data.timestamps.length - 2 : -1;
+ }
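+            // Example: for 1s candles [..., 12:00:01, 12:00:02] where 12:00:02 is still
+            // forming, the 12:00:01 candle (index length - 2) is the one validated here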
+
+ if (validationCandleIdx >= 0 && validationCandleIdx < chart.data.timestamps.length) {
+ // Create validation data structure for the confirmed candle
+ const validationData = {
+ timestamps: [chart.data.timestamps[validationCandleIdx]],
+ open: [chart.data.open[validationCandleIdx]],
+ high: [chart.data.high[validationCandleIdx]],
+ low: [chart.data.low[validationCandleIdx]],
+ close: [chart.data.close[validationCandleIdx]],
+ volume: [chart.data.volume[validationCandleIdx]]
+ };
+
+ // Trigger validation check
+ console.log(`[${timeframe}] Checking validation for confirmed candle at index ${validationCandleIdx}`);
+ this._checkPredictionAccuracy(timeframe, validationData);
+ }
+ }
+
+ console.debug(`Updated ${timeframe} chart with candle at ${formattedTimestamp}`);
} catch (error) {
console.error(`Error updating latest candle for ${timeframe}:`, error);
}
@@ -1873,6 +1951,199 @@ class ChartManager {
Plotly.react(plotId, updatedTraces, plotElement.layout, plotElement.config);
console.log(`Updated ${timeframe} chart with ${data.timestamps.length} candles`);
+
+ // Check if any ghost predictions match new actual candles and calculate accuracy
+ this._checkPredictionAccuracy(timeframe, data);
+ }
+
+ /**
+ * Calculate prediction accuracy by comparing ghost predictions with actual candles
+ */
+ _checkPredictionAccuracy(timeframe, actualData) {
+ if (!this.ghostCandleHistory || !this.ghostCandleHistory[timeframe]) return;
+
+ const predictions = this.ghostCandleHistory[timeframe];
+ const timestamps = actualData.timestamps;
+ const opens = actualData.open;
+ const highs = actualData.high;
+ const lows = actualData.low;
+ const closes = actualData.close;
+
+ // Determine tolerance based on timeframe
+ let tolerance;
+ if (timeframe === '1s') {
+ tolerance = 2000; // 2 seconds for 1s charts
+ } else if (timeframe === '1m') {
+ tolerance = 60000; // 60 seconds for 1m charts
+ } else if (timeframe === '1h') {
+ tolerance = 3600000; // 1 hour for hourly charts
+ } else {
+ tolerance = 5000; // 5 seconds default
+ }
+
+ // Check each prediction against actual candles
+ let validatedCount = 0;
+ predictions.forEach((prediction, idx) => {
+ // Skip if already validated
+ if (prediction.accuracy) return;
+
+ // Try multiple matching strategies
+ let matchIdx = -1;
+
+ // Use standard Date object if available, otherwise parse timestamp string
+ // Prioritize targetTime as it's the raw Date object set during prediction creation
+ const predTime = prediction.targetTime ? prediction.targetTime.getTime() : new Date(prediction.timestamp).getTime();
+
+ // Strategy 1: Find exact or very close match
+ matchIdx = timestamps.findIndex(ts => {
+ const actualTime = new Date(ts).getTime();
+ return Math.abs(predTime - actualTime) < tolerance;
+ });
+
+ // Strategy 2: If no match, find the next candle after prediction
+ if (matchIdx < 0) {
+ matchIdx = timestamps.findIndex(ts => {
+ const actualTime = new Date(ts).getTime();
+ return actualTime >= predTime && actualTime < predTime + tolerance * 2;
+ });
+ }
+
+ // Debug logging for unmatched predictions
+ if (matchIdx < 0) {
+ // Parse both timestamps to compare
+ const predTimeParsed = new Date(prediction.timestamp);
+ const latestActual = new Date(timestamps[timestamps.length - 1]);
+
+ if (idx < 3) { // Only log first 3 to avoid spam
+ console.log(`[${timeframe}] No match for prediction:`, {
+ predTimestamp: prediction.timestamp,
+ predTime: predTimeParsed.toISOString(),
+ latestActual: latestActual.toISOString(),
+ timeDiff: (latestActual - predTimeParsed) + 'ms',
+ tolerance: tolerance + 'ms',
+ availableTimestamps: timestamps.slice(-3) // Last 3 actual timestamps
+ });
+ }
+ }
+
+ if (matchIdx >= 0) {
+ // Found matching actual candle - calculate accuracy INCLUDING VOLUME
+ const predCandle = prediction.candle; // [O, H, L, C, V]
+ const actualCandle = [
+ opens[matchIdx],
+ highs[matchIdx],
+ lows[matchIdx],
+ closes[matchIdx],
+ actualData.volume ? actualData.volume[matchIdx] : predCandle[4] // Get actual volume if available
+ ];
+
+ // Calculate absolute errors for O, H, L, C, V
+ const errors = {
+ open: Math.abs(predCandle[0] - actualCandle[0]),
+ high: Math.abs(predCandle[1] - actualCandle[1]),
+ low: Math.abs(predCandle[2] - actualCandle[2]),
+ close: Math.abs(predCandle[3] - actualCandle[3]),
+ volume: Math.abs(predCandle[4] - actualCandle[4])
+ };
+
+ // Calculate percentage errors for O, H, L, C, V
+ const pctErrors = {
+ open: (errors.open / actualCandle[0]) * 100,
+ high: (errors.high / actualCandle[1]) * 100,
+ low: (errors.low / actualCandle[2]) * 100,
+ close: (errors.close / actualCandle[3]) * 100,
+ volume: actualCandle[4] > 0 ? (errors.volume / actualCandle[4]) * 100 : 0
+ };
+
+ // Average error (OHLC only, volume separate due to different scale)
+ const avgError = (errors.open + errors.high + errors.low + errors.close) / 4;
+ const avgPctError = (pctErrors.open + pctErrors.high + pctErrors.low + pctErrors.close) / 4;
+
+ // Direction accuracy (did we predict up/down correctly?)
+ const predDirection = predCandle[3] >= predCandle[0] ? 'up' : 'down';
+ const actualDirection = actualCandle[3] >= actualCandle[0] ? 'up' : 'down';
+ const directionCorrect = predDirection === actualDirection;
+
+                // Price range accuracy (clamp the range to avoid division by zero on flat candles)
+                const priceRange = Math.max(actualCandle[1] - actualCandle[2], 1e-8); // High - Low
+                const accuracy = Math.max(0, 1 - (avgError / priceRange)) * 100;
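+                // Example: avgError = 0.5 on a candle with a 2.0 high-low range
+                // gives accuracy = (1 - 0.5 / 2.0) * 100 = 75%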
+
+ // Store accuracy metrics
+ prediction.accuracy = {
+ errors: errors,
+ pctErrors: pctErrors,
+ avgError: avgError,
+ avgPctError: avgPctError,
+ directionCorrect: directionCorrect,
+ accuracy: accuracy,
+ actualCandle: actualCandle,
+ validatedAt: new Date().toISOString()
+ };
+
+ validatedCount++;
+ console.log(`[${timeframe}] Prediction validated (#${validatedCount}):`, {
+ timestamp: prediction.timestamp,
+ matchedTo: timestamps[matchIdx],
+ accuracy: accuracy.toFixed(1) + '%',
+ avgError: avgError.toFixed(4),
+ avgPctError: avgPctError.toFixed(2) + '%',
+ volumeError: pctErrors.volume.toFixed(2) + '%',
+ direction: directionCorrect ? '✓' : '✗',
+ timeDiff: Math.abs(predTime - new Date(timestamps[matchIdx]).getTime()) + 'ms',
+ predicted: {
+ O: predCandle[0].toFixed(2),
+ H: predCandle[1].toFixed(2),
+ L: predCandle[2].toFixed(2),
+ C: predCandle[3].toFixed(2),
+ V: predCandle[4].toFixed(2)
+ },
+ actual: {
+ O: actualCandle[0].toFixed(2),
+ H: actualCandle[1].toFixed(2),
+ L: actualCandle[2].toFixed(2),
+ C: actualCandle[3].toFixed(2),
+ V: actualCandle[4].toFixed(2)
+ }
+ });
+
+ // Send metrics to backend for training feedback
+ this._sendPredictionMetrics(timeframe, prediction);
+ }
+ });
+
+ // Summary log
+ if (validatedCount > 0) {
+ const totalPending = predictions.filter(p => !p.accuracy).length;
+ console.log(`[${timeframe}] Validated ${validatedCount} predictions, ${totalPending} still pending`);
+ }
+ }
+
+ /**
+ * Send prediction accuracy metrics to backend for training feedback
+ */
+ _sendPredictionMetrics(timeframe, prediction) {
+ if (!prediction.accuracy) return;
+
+ const metrics = {
+ timeframe: timeframe,
+ timestamp: prediction.timestamp,
+ predicted: prediction.candle, // [O, H, L, C, V]
+ actual: prediction.accuracy.actualCandle, // [O, H, L, C, V]
+ errors: prediction.accuracy.errors, // {open, high, low, close, volume}
+ pctErrors: prediction.accuracy.pctErrors, // {open, high, low, close, volume}
+ directionCorrect: prediction.accuracy.directionCorrect,
+ accuracy: prediction.accuracy.accuracy
+ };
+
+ console.log('[Prediction Metrics for Training]', metrics);
+
+ // Send to backend via WebSocket for incremental training
+ if (window.socket && window.socket.connected) {
+ window.socket.emit('prediction_accuracy', metrics);
+ console.log(`[${timeframe}] Sent prediction accuracy to backend for training`);
+ } else {
+ console.warn('[Training] WebSocket not connected - metrics not sent to backend');
+ }
}
/**
@@ -2043,9 +2314,9 @@ class ChartManager {
this.ghostCandleHistory[timeframe] = this.ghostCandleHistory[timeframe].slice(-this.maxGhostCandles);
}
- // 4. Add all ghost candles from history to traces
+ // 4. Add all ghost candles from history to traces (with accuracy if validated)
for (const ghost of this.ghostCandleHistory[timeframe]) {
- this._addGhostCandlePrediction(ghost.candle, timeframe, predictionTraces, ghost.targetTime);
+ this._addGhostCandlePrediction(ghost.candle, timeframe, predictionTraces, ghost.targetTime, ghost.accuracy);
}
// 5. Store as "Last Prediction" for shadow rendering
@@ -2057,7 +2328,10 @@ class ChartManager {
inferenceTime: predictionTimestamp
};
- console.log(`[${timeframe}] Ghost candle added (${this.ghostCandleHistory[timeframe].length}/${this.maxGhostCandles}) at ${targetTimestamp.toISOString()}`);
+ console.log(`[${timeframe}] Ghost candle added (${this.ghostCandleHistory[timeframe].length}/${this.maxGhostCandles}) at ${targetTimestamp.toISOString()}`, {
+ predicted: candleData,
+ timestamp: formattedTimestamp
+ });
}
}
@@ -2097,8 +2371,16 @@ class ChartManager {
Plotly.deleteTraces(plotId, indicesToRemove);
}
- // Add new traces
+ // Add new traces - these will overlay on top of real candles
+ // Plotly renders traces in order, so predictions added last appear on top
Plotly.addTraces(plotId, predictionTraces);
+
+                // Note: this relayout only disables spike lines for cleaner hovering;
+                // trace order (predictions added last) is what keeps them drawn on top
+ Plotly.relayout(plotId, {
+ 'xaxis.showspikes': false,
+ 'yaxis.showspikes': false
+ });
}
} catch (error) {
@@ -2173,9 +2455,10 @@ class ChartManager {
});
}
- _addGhostCandlePrediction(candleData, timeframe, traces, predictionTimestamp = null) {
+ _addGhostCandlePrediction(candleData, timeframe, traces, predictionTimestamp = null, accuracy = null) {
// candleData is [Open, High, Low, Close, Volume]
// predictionTimestamp is when the model made this prediction (optional)
+ // accuracy is the validation metrics (if actual candle has arrived)
// If not provided, we calculate the next candle time
const chart = this.charts[timeframe];
@@ -2215,8 +2498,46 @@ class ChartManager {
const low = candleData[2];
const close = candleData[3];
- // Determine color
- const color = close >= open ? '#10b981' : '#ef4444';
+ // Determine color based on validation status
+ // Ghost candles should be 30% opacity to see real candles underneath
+ let color, opacity;
+ if (accuracy) {
+ // Validated prediction - color by accuracy
+ if (accuracy.directionCorrect) {
+ color = close >= open ? '#10b981' : '#ef4444'; // Green/Red
+ } else {
+ color = '#fbbf24'; // Yellow for wrong direction
+ }
+ opacity = 0.3; // 30% - see real candle underneath
+ } else {
+ // Unvalidated prediction
+ color = close >= open ? '#10b981' : '#ef4444';
+ opacity = 0.3; // 30% - see real candle underneath
+ }
+
+        // Build rich tooltip text (<br> is used because Plotly renders it as a
+        // line break in hover labels, whereas literal newlines are not rendered)
+        let tooltipText = `PREDICTED CANDLE<br>`;
+        tooltipText += `O: ${open.toFixed(2)} H: ${high.toFixed(2)}<br>`;
+        tooltipText += `L: ${low.toFixed(2)} C: ${close.toFixed(2)}<br>`;
+        tooltipText += `Direction: ${close >= open ? 'UP' : 'DOWN'}<br>`;
+
+        if (accuracy) {
+            tooltipText += `<br>--- VALIDATION ---<br>`;
+            tooltipText += `Accuracy: ${accuracy.accuracy.toFixed(1)}%<br>`;
+            tooltipText += `Direction: ${accuracy.directionCorrect ? 'CORRECT ✓' : 'WRONG ✗'}<br>`;
+            tooltipText += `Avg Error: ${accuracy.avgPctError.toFixed(2)}%<br>`;
+            tooltipText += `<br>ACTUAL vs PREDICTED:<br>`;
+            tooltipText += `Open: ${accuracy.actualCandle[0].toFixed(2)} vs ${open.toFixed(2)} (${accuracy.pctErrors.open.toFixed(2)}%)<br>`;
+            tooltipText += `High: ${accuracy.actualCandle[1].toFixed(2)} vs ${high.toFixed(2)} (${accuracy.pctErrors.high.toFixed(2)}%)<br>`;
+            tooltipText += `Low: ${accuracy.actualCandle[2].toFixed(2)} vs ${low.toFixed(2)} (${accuracy.pctErrors.low.toFixed(2)}%)<br>`;
+            tooltipText += `Close: ${accuracy.actualCandle[3].toFixed(2)} vs ${close.toFixed(2)} (${accuracy.pctErrors.close.toFixed(2)}%)<br>`;
+            if (accuracy.actualCandle[4] !== undefined && accuracy.pctErrors.volume !== undefined) {
+                const predVolume = candleData[4];
+                tooltipText += `Volume: ${accuracy.actualCandle[4].toFixed(2)} vs ${predVolume.toFixed(2)} (${accuracy.pctErrors.volume.toFixed(2)}%)`;
+            }
+        } else {
+            tooltipText += `<br>Status: AWAITING VALIDATION...`;
+        }
// Create ghost candle trace with formatted timestamp string (same as real candles)
// 150% wider than normal candles
@@ -2236,14 +2557,14 @@ class ChartManager {
line: { color: color, width: 3 }, // 150% wider
fillcolor: color
},
- opacity: 0.6, // 60% transparent
- hoverinfo: 'x+y+text',
- text: ['Predicted Next Candle'],
+ opacity: opacity,
+ hoverinfo: 'text',
+ text: [tooltipText],
width: 1.5 // 150% width multiplier
};
traces.push(ghostTrace);
- console.log('Added ghost candle prediction at:', formattedTimestamp, ghostTrace);
+ console.log('Added ghost candle prediction at:', formattedTimestamp, accuracy ? 'VALIDATED' : 'pending');
}
_addShadowCandlePrediction(candleData, timestamp, traces) {
diff --git a/NN/models/advanced_transformer_trading.py b/NN/models/advanced_transformer_trading.py
index 6f7f97f..5532fd4 100644
--- a/NN/models/advanced_transformer_trading.py
+++ b/NN/models/advanced_transformer_trading.py
@@ -1446,33 +1446,39 @@ class TradingTransformerTrainer:
candle_rmse = {}
if 'next_candles' in outputs:
- # Use 1m timeframe as primary metric
- if '1m' in outputs['next_candles'] and 'future_candle_1m' in batch:
+ # Use 1s or 1m timeframe as primary metric (try 1s first)
+ if '1s' in outputs['next_candles'] and 'future_candle_1s' in batch:
+ pred_candle = outputs['next_candles']['1s'] # [batch, 5]
+ actual_candle = batch['future_candle_1s'] # [batch, 5]
+ elif '1m' in outputs['next_candles'] and 'future_candle_1m' in batch:
pred_candle = outputs['next_candles']['1m'] # [batch, 5]
actual_candle = batch['future_candle_1m'] # [batch, 5]
+ else:
+ pred_candle = None
+ actual_candle = None
+
+ if actual_candle is not None and pred_candle is not None and pred_candle.shape == actual_candle.shape:
+ # Calculate RMSE for each OHLCV component
+ rmse_open = torch.sqrt(torch.mean((pred_candle[:, 0] - actual_candle[:, 0])**2) + 1e-8)
+ rmse_high = torch.sqrt(torch.mean((pred_candle[:, 1] - actual_candle[:, 1])**2) + 1e-8)
+ rmse_low = torch.sqrt(torch.mean((pred_candle[:, 2] - actual_candle[:, 2])**2) + 1e-8)
+ rmse_close = torch.sqrt(torch.mean((pred_candle[:, 3] - actual_candle[:, 3])**2) + 1e-8)
- if actual_candle is not None and pred_candle.shape == actual_candle.shape:
- # Calculate RMSE for each OHLCV component
- rmse_open = torch.sqrt(torch.mean((pred_candle[:, 0] - actual_candle[:, 0])**2) + 1e-8)
- rmse_high = torch.sqrt(torch.mean((pred_candle[:, 1] - actual_candle[:, 1])**2) + 1e-8)
- rmse_low = torch.sqrt(torch.mean((pred_candle[:, 2] - actual_candle[:, 2])**2) + 1e-8)
- rmse_close = torch.sqrt(torch.mean((pred_candle[:, 3] - actual_candle[:, 3])**2) + 1e-8)
-
- # Average RMSE for OHLC (exclude volume)
- avg_rmse = (rmse_open + rmse_high + rmse_low + rmse_close) / 4
-
- # Convert to accuracy: lower RMSE = higher accuracy
- # Normalize by price range
- price_range = torch.clamp(actual_candle[:, 1].max() - actual_candle[:, 2].min(), min=1e-8)
- candle_accuracy = (1.0 - torch.clamp(avg_rmse / price_range, 0, 1)).item()
-
- candle_rmse = {
- 'open': rmse_open.item(),
- 'high': rmse_high.item(),
- 'low': rmse_low.item(),
- 'close': rmse_close.item(),
- 'avg': avg_rmse.item()
- }
+ # Average RMSE for OHLC (exclude volume)
+ avg_rmse = (rmse_open + rmse_high + rmse_low + rmse_close) / 4
+
+ # Convert to accuracy: lower RMSE = higher accuracy
+ # Normalize by price range
+ price_range = torch.clamp(actual_candle[:, 1].max() - actual_candle[:, 2].min(), min=1e-8)
+ candle_accuracy = (1.0 - torch.clamp(avg_rmse / price_range, 0, 1)).item()
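+                # e.g. avg_rmse = 1.2 with a 10.0 high-low span gives
+                # candle_accuracy = 1.0 - 0.12 = 0.88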
+
+ candle_rmse = {
+ 'open': rmse_open.item(),
+ 'high': rmse_high.item(),
+ 'low': rmse_low.item(),
+ 'close': rmse_close.item(),
+ 'avg': avg_rmse.item()
+ }
# SECONDARY: Trend vector prediction accuracy
trend_accuracy = 0.0