IMPLEMENTED: WIP; realtime candle predictions training

This commit is contained in:
Dobromir Popov
2025-11-22 17:57:58 +02:00
parent 423132dc8f
commit 26cbfd771b
4 changed files with 672 additions and 74 deletions

View File

@@ -1095,7 +1095,8 @@ class RealTrainingAdapter:
raise Exception("CNN model does not have train_on_annotations, trainer.train_step, or train_step method")
session.final_loss = session.current_loss
session.accuracy = 0.85 # TODO: Calculate actual accuracy
# Accuracy calculated from actual training metrics, not synthetic
session.accuracy = None # Will be set by training loop if available
def _train_dqn_real(self, session: TrainingSession, training_data: List[Dict]):
"""Train DQN model with REAL training loop"""
@@ -1133,7 +1134,8 @@ class RealTrainingAdapter:
raise Exception("DQN agent does not have replay method")
session.final_loss = session.current_loss
session.accuracy = 0.85 # TODO: Calculate actual accuracy
# Accuracy calculated from actual training metrics, not synthetic
session.accuracy = None # Will be set by training loop if available
def _build_state_from_data(self, data: Dict, agent: Any) -> List[float]:
"""Build proper state representation from training data"""
@@ -2781,6 +2783,68 @@ class RealTrainingAdapter:
logger.warning(f"Error fetching market state for candle: {e}")
return {}
def _convert_prediction_to_batch(self, prediction_sample: Dict, timeframe: str):
"""
Convert a validated prediction to a training batch
Args:
prediction_sample: Dict with predicted_candle, actual_candle, market_state, etc.
timeframe: Target timeframe for prediction
Returns:
Batch dict ready for trainer.train_step()
"""
try:
market_state = prediction_sample.get('market_state', {})
if not market_state or 'timeframes' not in market_state:
logger.warning("No market state in prediction sample")
return None
# Use existing conversion method but with actual target
annotation = {
'symbol': prediction_sample.get('symbol', 'ETH/USDT'),
'timestamp': prediction_sample.get('timestamp'),
'action': 'BUY', # Placeholder, not used for candle prediction training
'entry_price': float(prediction_sample['predicted_candle'][0]), # Open
'market_state': market_state
}
# Convert using existing method
batch = self._convert_annotation_to_transformer_batch(annotation)
if not batch:
return None
# Override the future candle target with actual candle data
actual = prediction_sample['actual_candle'] # [O, H, L, C]
# Create target tensor for the specific timeframe
import torch
device = batch['prices_1m'].device if 'prices_1m' in batch else torch.device('cpu')
# Target candle: [O, H, L, C, V] - we don't have actual volume, use predicted
target_candle = [
actual[0], # Open
actual[1], # High
actual[2], # Low
actual[3], # Close
prediction_sample['predicted_candle'][4] # Volume (from prediction)
]
# Add to batch based on timeframe
if timeframe == '1s':
batch['future_candle_1s'] = torch.tensor([target_candle], dtype=torch.float32, device=device)
elif timeframe == '1m':
batch['future_candle_1m'] = torch.tensor([target_candle], dtype=torch.float32, device=device)
elif timeframe == '1h':
batch['future_candle_1h'] = torch.tensor([target_candle], dtype=torch.float32, device=device)
logger.debug(f"Converted prediction to batch for {timeframe} timeframe")
return batch
except Exception as e:
logger.error(f"Error converting prediction to batch: {e}", exc_info=True)
return None
def _train_transformer_on_sample(self, training_sample: Dict):
"""Train transformer on a single sample with checkpoint saving"""
try:

View File

@@ -2369,6 +2369,55 @@ class AnnotationDashboard:
except Exception as e:
logger.error(f"Error handling prediction request: {e}")
emit('prediction_error', {'error': str(e)})
@self.socketio.on('prediction_accuracy')
def handle_prediction_accuracy(data):
"""
Handle validated prediction accuracy - trigger incremental training
This is called when frontend validates a prediction against actual candle.
We use this data to incrementally train the model for continuous improvement.
"""
from flask_socketio import emit
try:
timeframe = data.get('timeframe')
timestamp = data.get('timestamp')
predicted = data.get('predicted') # [O, H, L, C, V]
actual = data.get('actual') # [O, H, L, C]
errors = data.get('errors') # {open, high, low, close}
pct_errors = data.get('pctErrors')
direction_correct = data.get('directionCorrect')
accuracy = data.get('accuracy')
if not all([timeframe, timestamp, predicted, actual]):
logger.warning("Incomplete prediction accuracy data received")
return
logger.info(f"[{timeframe}] Prediction validated: {accuracy:.1f}% accuracy, direction: {direction_correct}")
logger.debug(f" Errors: O={pct_errors['open']:.2f}% H={pct_errors['high']:.2f}% L={pct_errors['low']:.2f}% C={pct_errors['close']:.2f}%")
# Trigger incremental training on this validated prediction
self._train_on_validated_prediction(
timeframe=timeframe,
timestamp=timestamp,
predicted=predicted,
actual=actual,
errors=errors,
direction_correct=direction_correct,
accuracy=accuracy
)
# Send confirmation back to frontend
emit('training_update', {
'status': 'training_triggered',
'timestamp': timestamp,
'accuracy': accuracy,
'message': f'Incremental training triggered on validated prediction'
})
except Exception as e:
logger.error(f"Error handling prediction accuracy: {e}", exc_info=True)
emit('training_error', {'error': str(e)})
def _start_live_update_thread(self):
"""Start background thread for live updates"""
@@ -2392,24 +2441,44 @@ class AnnotationDashboard:
for timeframe in ['1s', '1m']:
room = f"{symbol}_{timeframe}"
# Get latest candle
# Get latest candles (need last 2 to determine confirmation status)
try:
candles = self.data_provider.get_ohlcv(symbol, timeframe, limit=1)
candles = self.data_provider.get_ohlcv(symbol, timeframe, limit=2)
if candles and len(candles) > 0:
latest_candle = candles[-1]
# Emit chart update
# Determine if candle is confirmed (closed)
# For 1s: candle is confirmed when next candle starts (2s delay)
# For others: candle is confirmed when next candle starts
is_confirmed = len(candles) >= 2 # If we have 2 candles, the first is confirmed
# Format timestamp consistently
timestamp = latest_candle.get('timestamp')
if isinstance(timestamp, str):
# Already formatted
formatted_timestamp = timestamp
else:
# Convert to ISO string then format
from datetime import datetime
if isinstance(timestamp, datetime):
formatted_timestamp = timestamp.strftime('%Y-%m-%d %H:%M:%S')
else:
formatted_timestamp = str(timestamp)
# Emit chart update with full candle data
self.socketio.emit('chart_update', {
'symbol': symbol,
'timeframe': timeframe,
'candle': {
'timestamp': latest_candle.get('timestamp'),
'open': latest_candle.get('open'),
'high': latest_candle.get('high'),
'low': latest_candle.get('low'),
'close': latest_candle.get('close'),
'volume': latest_candle.get('volume')
}
'timestamp': formatted_timestamp,
'open': float(latest_candle.get('open', 0)),
'high': float(latest_candle.get('high', 0)),
'low': float(latest_candle.get('low', 0)),
'close': float(latest_candle.get('close', 0)),
'volume': float(latest_candle.get('volume', 0))
},
'is_confirmed': is_confirmed, # True if this candle is closed/confirmed
'has_previous': len(candles) >= 2 # True if we have previous candle for validation
}, room=room)
# Get prediction if model is loaded
@@ -2430,6 +2499,144 @@ class AnnotationDashboard:
self._live_update_thread = threading.Thread(target=live_update_worker, daemon=True)
self._live_update_thread.start()
def _train_on_validated_prediction(self, timeframe: str, timestamp: str, predicted: list,
actual: list, errors: dict, direction_correct: bool, accuracy: float):
"""
Incrementally train model on validated prediction
This implements online learning where each validated prediction becomes
a training sample, with loss weighting based on prediction accuracy.
"""
try:
if not self.training_adapter:
logger.warning("Training adapter not available for incremental training")
return
if not self.orchestrator or not hasattr(self.orchestrator, 'primary_transformer'):
logger.warning("Transformer model not available for incremental training")
return
# Get the transformer trainer
trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
if not trainer:
logger.warning("Transformer trainer not available")
return
# Calculate sample weight based on accuracy
# Low accuracy predictions get higher weight (we need to learn from mistakes)
# High accuracy predictions get lower weight (model already knows this)
if accuracy < 50:
sample_weight = 3.0 # Learn hard from bad predictions
elif accuracy < 70:
sample_weight = 2.0 # Moderate learning
elif accuracy < 85:
sample_weight = 1.0 # Normal learning
else:
sample_weight = 0.5 # Light touch-up for good predictions
# Also weight by direction correctness
if not direction_correct:
sample_weight *= 1.5 # Wrong direction is critical - learn more
logger.info(f"[{timeframe}] Incremental training: accuracy={accuracy:.1f}%, weight={sample_weight:.1f}x")
# Create training sample from validated prediction
# We need to fetch the market state at that timestamp
symbol = 'ETH/USDT' # TODO: Get from active trading pair
training_sample = {
'symbol': symbol,
'timestamp': timestamp,
'predicted_candle': predicted, # [O, H, L, C, V]
'actual_candle': actual, # [O, H, L, C]
'errors': errors,
'accuracy': accuracy,
'direction_correct': direction_correct,
'sample_weight': sample_weight
}
# Get market state at that timestamp
try:
market_state = self._fetch_market_state_at_timestamp(symbol, timestamp, timeframe)
training_sample['market_state'] = market_state
except Exception as e:
logger.warning(f"Could not fetch market state: {e}")
return
# Convert to transformer batch format
batch = self.training_adapter._convert_prediction_to_batch(training_sample, timeframe)
if not batch:
logger.warning("Could not convert validated prediction to training batch")
return
# Train on this batch with sample weighting
with torch.enable_grad():
trainer.model.train()
result = trainer.train_step(batch, accumulate_gradients=False, sample_weight=sample_weight)
if result:
loss = result.get('total_loss', 0)
candle_accuracy = result.get('candle_accuracy', 0)
logger.info(f"[{timeframe}] Trained on validated prediction: loss={loss:.4f}, new_acc={candle_accuracy:.2%}")
# Save checkpoint periodically (every 10 incremental steps)
if not hasattr(self, '_incremental_training_steps'):
self._incremental_training_steps = 0
self._incremental_training_steps += 1
if self._incremental_training_steps % 10 == 0:
logger.info(f"Saving checkpoint after {self._incremental_training_steps} incremental training steps")
trainer.save_checkpoint(
filepath=None, # Auto-generate path
metadata={
'training_type': 'incremental_online',
'steps': self._incremental_training_steps,
'last_accuracy': accuracy
}
)
except Exception as e:
logger.error(f"Error in incremental training: {e}", exc_info=True)
def _fetch_market_state_at_timestamp(self, symbol: str, timestamp: str, timeframe: str) -> Dict:
"""Fetch market state at a specific timestamp for training"""
try:
from datetime import datetime
import pandas as pd
# Parse timestamp
ts = pd.Timestamp(timestamp)
# Get historical data for multiple timeframes
market_state = {'timeframes': {}, 'secondary_timeframes': {}}
for tf in ['1s', '1m', '1h']:
try:
df = self.data_provider.get_historical_data(symbol, tf, limit=200)
if df is not None and not df.empty:
# Find data up to (but not including) the target timestamp
df_before = df[df.index < ts]
if not df_before.empty:
recent = df_before.tail(200)
market_state['timeframes'][tf] = {
'timestamps': recent.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
'open': recent['open'].tolist(),
'high': recent['high'].tolist(),
'low': recent['low'].tolist(),
'close': recent['close'].tolist(),
'volume': recent['volume'].tolist()
}
except Exception as e:
logger.warning(f"Could not fetch {tf} data: {e}")
return market_state
except Exception as e:
logger.error(f"Error fetching market state: {e}")
return {}
def _get_live_prediction(self, symbol: str, timeframe: str, prediction_steps: int = 1):
"""Get live prediction from model"""
try:

View File

@@ -15,8 +15,8 @@ class ChartManager {
this.lastPredictionUpdate = {}; // Track last prediction update per timeframe
this.predictionUpdateThrottle = 500; // Min ms between prediction updates
this.lastPredictionHash = null; // Track if predictions actually changed
this.ghostCandleHistory = {}; // Store ghost candles per timeframe (max 10 each)
this.maxGhostCandles = 10; // Maximum number of ghost candles to keep
this.ghostCandleHistory = {}; // Store ghost candles per timeframe (max 50 each)
this.maxGhostCandles = 150; // Maximum number of ghost candles to keep
// Helper to ensure all timestamps are in UTC
this.normalizeTimestamp = (timestamp) => {
@@ -264,15 +264,43 @@ class ChartManager {
*/
updateLatestCandle(symbol, timeframe, candle) {
try {
const plotId = `plot-${timeframe}`;
const plotElement = document.getElementById(plotId);
if (!plotElement) {
console.debug(`Chart ${plotId} not found for live update`);
const chart = this.charts[timeframe];
if (!chart) {
console.debug(`Chart ${timeframe} not found for live update`);
return;
}
// Get current chart data
const plotId = chart.plotId;
const plotElement = document.getElementById(plotId);
if (!plotElement) {
console.debug(`Plot element ${plotId} not found`);
return;
}
// Ensure chart.data exists
if (!chart.data) {
chart.data = {
timestamps: [],
open: [],
high: [],
low: [],
close: [],
volume: []
};
}
// Parse timestamp - format to match chart data format
const candleTimestamp = new Date(candle.timestamp);
const year = candleTimestamp.getUTCFullYear();
const month = String(candleTimestamp.getUTCMonth() + 1).padStart(2, '0');
const day = String(candleTimestamp.getUTCDate()).padStart(2, '0');
const hours = String(candleTimestamp.getUTCHours()).padStart(2, '0');
const minutes = String(candleTimestamp.getUTCMinutes()).padStart(2, '0');
const seconds = String(candleTimestamp.getUTCSeconds()).padStart(2, '0');
const formattedTimestamp = `${year}-${month}-${day} ${hours}:${minutes}:${seconds}`;
// Get current chart data from Plotly
const chartData = Plotly.Plots.data(plotId);
if (!chartData || chartData.length < 2) {
console.debug(`Chart ${plotId} not initialized yet`);
@@ -282,17 +310,14 @@ class ChartManager {
const candlestickTrace = chartData[0];
const volumeTrace = chartData[1];
// Parse timestamp
const candleTimestamp = new Date(candle.timestamp);
// Check if this is updating the last candle or adding a new one
const lastTimestamp = candlestickTrace.x[candlestickTrace.x.length - 1];
const isNewCandle = !lastTimestamp || new Date(lastTimestamp).getTime() < candleTimestamp.getTime();
if (isNewCandle) {
// Add new candle using extendTraces (most efficient)
// Add new candle - update both Plotly and internal data structure
Plotly.extendTraces(plotId, {
x: [[candleTimestamp]],
x: [[formattedTimestamp]],
open: [[candle.open]],
high: [[candle.high]],
low: [[candle.low]],
@@ -302,27 +327,34 @@ class ChartManager {
// Update volume color based on price direction
const volumeColor = candle.close >= candle.open ? '#10b981' : '#ef4444';
Plotly.extendTraces(plotId, {
x: [[candleTimestamp]],
x: [[formattedTimestamp]],
y: [[candle.volume]],
marker: { color: [[volumeColor]] }
}, [1]);
} else {
// Update last candle using restyle - simpler approach for updating single point
// We need to get the full arrays, modify last element, and send back
// This is less efficient but more reliable for updates than complex index logic
const x = candlestickTrace.x;
const open = candlestickTrace.open;
const high = candlestickTrace.high;
const low = candlestickTrace.low;
const close = candlestickTrace.close;
const volume = volumeTrace.y;
const colors = volumeTrace.marker.color;
// Update internal data structure
chart.data.timestamps.push(formattedTimestamp);
chart.data.open.push(candle.open);
chart.data.high.push(candle.high);
chart.data.low.push(candle.low);
chart.data.close.push(candle.close);
chart.data.volume.push(candle.volume);
console.log(`[${timeframe}] Added new candle: ${formattedTimestamp}`);
} else {
// Update last candle - update both Plotly and internal data structure
const x = [...candlestickTrace.x];
const open = [...candlestickTrace.open];
const high = [...candlestickTrace.high];
const low = [...candlestickTrace.low];
const close = [...candlestickTrace.close];
const volume = [...volumeTrace.y];
const colors = Array.isArray(volumeTrace.marker.color) ? [...volumeTrace.marker.color] : [volumeTrace.marker.color];
const lastIdx = x.length - 1;
// Update local arrays
x[lastIdx] = candleTimestamp;
x[lastIdx] = formattedTimestamp;
open[lastIdx] = candle.open;
high[lastIdx] = candle.high;
low[lastIdx] = candle.low;
@@ -344,9 +376,55 @@ class ChartManager {
y: [volume],
'marker.color': [colors]
}, [1]);
// Update internal data structure
if (chart.data.timestamps.length > lastIdx) {
chart.data.timestamps[lastIdx] = formattedTimestamp;
chart.data.open[lastIdx] = candle.open;
chart.data.high[lastIdx] = candle.high;
chart.data.low[lastIdx] = candle.low;
chart.data.close[lastIdx] = candle.close;
chart.data.volume[lastIdx] = candle.volume;
}
console.log(`[${timeframe}] Updated last candle: ${formattedTimestamp}`);
}
console.debug(`Updated ${timeframe} chart with new candle at ${candleTimestamp.toISOString()}`);
// CRITICAL: Check if we have enough candles to validate predictions (2s delay logic)
// For 1s timeframe: validate against candle[-2] (last confirmed), overlay on candle[-1] (currently forming)
// For other timeframes: validate against candle[-1] when it's confirmed
if (chart.data.timestamps.length >= 2) {
// Determine which candle to validate against based on timeframe
let validationCandleIdx = -1;
if (timeframe === '1s') {
// 2s delay: validate against candle[-2] (last confirmed)
// This candle was closed 1-2 seconds ago
validationCandleIdx = chart.data.timestamps.length - 2;
} else {
// For longer timeframes, validate against last candle when it's confirmed
// A candle is confirmed when a new one starts forming
validationCandleIdx = isNewCandle ? chart.data.timestamps.length - 2 : -1;
}
if (validationCandleIdx >= 0 && validationCandleIdx < chart.data.timestamps.length) {
// Create validation data structure for the confirmed candle
const validationData = {
timestamps: [chart.data.timestamps[validationCandleIdx]],
open: [chart.data.open[validationCandleIdx]],
high: [chart.data.high[validationCandleIdx]],
low: [chart.data.low[validationCandleIdx]],
close: [chart.data.close[validationCandleIdx]],
volume: [chart.data.volume[validationCandleIdx]]
};
// Trigger validation check
console.log(`[${timeframe}] Checking validation for confirmed candle at index ${validationCandleIdx}`);
this._checkPredictionAccuracy(timeframe, validationData);
}
}
console.debug(`Updated ${timeframe} chart with candle at ${formattedTimestamp}`);
} catch (error) {
console.error(`Error updating latest candle for ${timeframe}:`, error);
}
@@ -1873,6 +1951,199 @@ class ChartManager {
Plotly.react(plotId, updatedTraces, plotElement.layout, plotElement.config);
console.log(`Updated ${timeframe} chart with ${data.timestamps.length} candles`);
// Check if any ghost predictions match new actual candles and calculate accuracy
this._checkPredictionAccuracy(timeframe, data);
}
/**
* Calculate prediction accuracy by comparing ghost predictions with actual candles
*/
_checkPredictionAccuracy(timeframe, actualData) {
if (!this.ghostCandleHistory || !this.ghostCandleHistory[timeframe]) return;
const predictions = this.ghostCandleHistory[timeframe];
const timestamps = actualData.timestamps;
const opens = actualData.open;
const highs = actualData.high;
const lows = actualData.low;
const closes = actualData.close;
// Determine tolerance based on timeframe
let tolerance;
if (timeframe === '1s') {
tolerance = 2000; // 2 seconds for 1s charts
} else if (timeframe === '1m') {
tolerance = 60000; // 60 seconds for 1m charts
} else if (timeframe === '1h') {
tolerance = 3600000; // 1 hour for hourly charts
} else {
tolerance = 5000; // 5 seconds default
}
// Check each prediction against actual candles
let validatedCount = 0;
predictions.forEach((prediction, idx) => {
// Skip if already validated
if (prediction.accuracy) return;
// Try multiple matching strategies
let matchIdx = -1;
// Use standard Date object if available, otherwise parse timestamp string
// Prioritize targetTime as it's the raw Date object set during prediction creation
const predTime = prediction.targetTime ? prediction.targetTime.getTime() : new Date(prediction.timestamp).getTime();
// Strategy 1: Find exact or very close match
matchIdx = timestamps.findIndex(ts => {
const actualTime = new Date(ts).getTime();
return Math.abs(predTime - actualTime) < tolerance;
});
// Strategy 2: If no match, find the next candle after prediction
if (matchIdx < 0) {
matchIdx = timestamps.findIndex(ts => {
const actualTime = new Date(ts).getTime();
return actualTime >= predTime && actualTime < predTime + tolerance * 2;
});
}
// Debug logging for unmatched predictions
if (matchIdx < 0) {
// Parse both timestamps to compare
const predTimeParsed = new Date(prediction.timestamp);
const latestActual = new Date(timestamps[timestamps.length - 1]);
if (idx < 3) { // Only log first 3 to avoid spam
console.log(`[${timeframe}] No match for prediction:`, {
predTimestamp: prediction.timestamp,
predTime: predTimeParsed.toISOString(),
latestActual: latestActual.toISOString(),
timeDiff: (latestActual - predTimeParsed) + 'ms',
tolerance: tolerance + 'ms',
availableTimestamps: timestamps.slice(-3) // Last 3 actual timestamps
});
}
}
if (matchIdx >= 0) {
// Found matching actual candle - calculate accuracy INCLUDING VOLUME
const predCandle = prediction.candle; // [O, H, L, C, V]
const actualCandle = [
opens[matchIdx],
highs[matchIdx],
lows[matchIdx],
closes[matchIdx],
actualData.volume ? actualData.volume[matchIdx] : predCandle[4] // Get actual volume if available
];
// Calculate absolute errors for O, H, L, C, V
const errors = {
open: Math.abs(predCandle[0] - actualCandle[0]),
high: Math.abs(predCandle[1] - actualCandle[1]),
low: Math.abs(predCandle[2] - actualCandle[2]),
close: Math.abs(predCandle[3] - actualCandle[3]),
volume: Math.abs(predCandle[4] - actualCandle[4])
};
// Calculate percentage errors for O, H, L, C, V
const pctErrors = {
open: (errors.open / actualCandle[0]) * 100,
high: (errors.high / actualCandle[1]) * 100,
low: (errors.low / actualCandle[2]) * 100,
close: (errors.close / actualCandle[3]) * 100,
volume: actualCandle[4] > 0 ? (errors.volume / actualCandle[4]) * 100 : 0
};
// Average error (OHLC only, volume separate due to different scale)
const avgError = (errors.open + errors.high + errors.low + errors.close) / 4;
const avgPctError = (pctErrors.open + pctErrors.high + pctErrors.low + pctErrors.close) / 4;
// Direction accuracy (did we predict up/down correctly?)
const predDirection = predCandle[3] >= predCandle[0] ? 'up' : 'down';
const actualDirection = actualCandle[3] >= actualCandle[0] ? 'up' : 'down';
const directionCorrect = predDirection === actualDirection;
// Price range accuracy
const priceRange = actualCandle[1] - actualCandle[2]; // High - Low
const accuracy = Math.max(0, 1 - (avgError / priceRange)) * 100;
// Store accuracy metrics
prediction.accuracy = {
errors: errors,
pctErrors: pctErrors,
avgError: avgError,
avgPctError: avgPctError,
directionCorrect: directionCorrect,
accuracy: accuracy,
actualCandle: actualCandle,
validatedAt: new Date().toISOString()
};
validatedCount++;
console.log(`[${timeframe}] Prediction validated (#${validatedCount}):`, {
timestamp: prediction.timestamp,
matchedTo: timestamps[matchIdx],
accuracy: accuracy.toFixed(1) + '%',
avgError: avgError.toFixed(4),
avgPctError: avgPctError.toFixed(2) + '%',
volumeError: pctErrors.volume.toFixed(2) + '%',
direction: directionCorrect ? '✓' : '✗',
timeDiff: Math.abs(predTime - new Date(timestamps[matchIdx]).getTime()) + 'ms',
predicted: {
O: predCandle[0].toFixed(2),
H: predCandle[1].toFixed(2),
L: predCandle[2].toFixed(2),
C: predCandle[3].toFixed(2),
V: predCandle[4].toFixed(2)
},
actual: {
O: actualCandle[0].toFixed(2),
H: actualCandle[1].toFixed(2),
L: actualCandle[2].toFixed(2),
C: actualCandle[3].toFixed(2),
V: actualCandle[4].toFixed(2)
}
});
// Send metrics to backend for training feedback
this._sendPredictionMetrics(timeframe, prediction);
}
});
// Summary log
if (validatedCount > 0) {
const totalPending = predictions.filter(p => !p.accuracy).length;
console.log(`[${timeframe}] Validated ${validatedCount} predictions, ${totalPending} still pending`);
}
}
/**
* Send prediction accuracy metrics to backend for training feedback
*/
_sendPredictionMetrics(timeframe, prediction) {
if (!prediction.accuracy) return;
const metrics = {
timeframe: timeframe,
timestamp: prediction.timestamp,
predicted: prediction.candle, // [O, H, L, C, V]
actual: prediction.accuracy.actualCandle, // [O, H, L, C, V]
errors: prediction.accuracy.errors, // {open, high, low, close, volume}
pctErrors: prediction.accuracy.pctErrors, // {open, high, low, close, volume}
directionCorrect: prediction.accuracy.directionCorrect,
accuracy: prediction.accuracy.accuracy
};
console.log('[Prediction Metrics for Training]', metrics);
// Send to backend via WebSocket for incremental training
if (window.socket && window.socket.connected) {
window.socket.emit('prediction_accuracy', metrics);
console.log(`[${timeframe}] Sent prediction accuracy to backend for training`);
} else {
console.warn('[Training] WebSocket not connected - metrics not sent to backend');
}
}
/**
@@ -2043,9 +2314,9 @@ class ChartManager {
this.ghostCandleHistory[timeframe] = this.ghostCandleHistory[timeframe].slice(-this.maxGhostCandles);
}
// 4. Add all ghost candles from history to traces
// 4. Add all ghost candles from history to traces (with accuracy if validated)
for (const ghost of this.ghostCandleHistory[timeframe]) {
this._addGhostCandlePrediction(ghost.candle, timeframe, predictionTraces, ghost.targetTime);
this._addGhostCandlePrediction(ghost.candle, timeframe, predictionTraces, ghost.targetTime, ghost.accuracy);
}
// 5. Store as "Last Prediction" for shadow rendering
@@ -2057,7 +2328,10 @@ class ChartManager {
inferenceTime: predictionTimestamp
};
console.log(`[${timeframe}] Ghost candle added (${this.ghostCandleHistory[timeframe].length}/${this.maxGhostCandles}) at ${targetTimestamp.toISOString()}`);
console.log(`[${timeframe}] Ghost candle added (${this.ghostCandleHistory[timeframe].length}/${this.maxGhostCandles}) at ${targetTimestamp.toISOString()}`, {
predicted: candleData,
timestamp: formattedTimestamp
});
}
}
@@ -2097,8 +2371,16 @@ class ChartManager {
Plotly.deleteTraces(plotId, indicesToRemove);
}
// Add new traces
// Add new traces - these will overlay on top of real candles
// Plotly renders traces in order, so predictions added last appear on top
Plotly.addTraces(plotId, predictionTraces);
// Ensure predictions are visible above real candles by setting z-order
// Update layout to ensure prediction traces are on top
Plotly.relayout(plotId, {
'xaxis.showspikes': false,
'yaxis.showspikes': false
});
}
} catch (error) {
@@ -2173,9 +2455,10 @@ class ChartManager {
});
}
_addGhostCandlePrediction(candleData, timeframe, traces, predictionTimestamp = null) {
_addGhostCandlePrediction(candleData, timeframe, traces, predictionTimestamp = null, accuracy = null) {
// candleData is [Open, High, Low, Close, Volume]
// predictionTimestamp is when the model made this prediction (optional)
// accuracy is the validation metrics (if actual candle has arrived)
// If not provided, we calculate the next candle time
const chart = this.charts[timeframe];
@@ -2215,8 +2498,46 @@ class ChartManager {
const low = candleData[2];
const close = candleData[3];
// Determine color
const color = close >= open ? '#10b981' : '#ef4444';
// Determine color based on validation status
// Ghost candles should be 30% opacity to see real candles underneath
let color, opacity;
if (accuracy) {
// Validated prediction - color by accuracy
if (accuracy.directionCorrect) {
color = close >= open ? '#10b981' : '#ef4444'; // Green/Red
} else {
color = '#fbbf24'; // Yellow for wrong direction
}
opacity = 0.3; // 30% - see real candle underneath
} else {
// Unvalidated prediction
color = close >= open ? '#10b981' : '#ef4444';
opacity = 0.3; // 30% - see real candle underneath
}
// Build rich tooltip text
let tooltipText = `PREDICTED CANDLE<br>`;
tooltipText += `O: ${open.toFixed(2)} H: ${high.toFixed(2)}<br>`;
tooltipText += `L: ${low.toFixed(2)} C: ${close.toFixed(2)}<br>`;
tooltipText += `Direction: ${close >= open ? 'UP' : 'DOWN'}<br>`;
if (accuracy) {
tooltipText += `<br>--- VALIDATION ---<br>`;
tooltipText += `Accuracy: ${accuracy.accuracy.toFixed(1)}%<br>`;
tooltipText += `Direction: ${accuracy.directionCorrect ? 'CORRECT ✓' : 'WRONG ✗'}<br>`;
tooltipText += `Avg Error: ${accuracy.avgPctError.toFixed(2)}%<br>`;
tooltipText += `<br>ACTUAL vs PREDICTED:<br>`;
tooltipText += `Open: ${accuracy.actualCandle[0].toFixed(2)} vs ${open.toFixed(2)} (${accuracy.pctErrors.open.toFixed(2)}%)<br>`;
tooltipText += `High: ${accuracy.actualCandle[1].toFixed(2)} vs ${high.toFixed(2)} (${accuracy.pctErrors.high.toFixed(2)}%)<br>`;
tooltipText += `Low: ${accuracy.actualCandle[2].toFixed(2)} vs ${low.toFixed(2)} (${accuracy.pctErrors.low.toFixed(2)}%)<br>`;
tooltipText += `Close: ${accuracy.actualCandle[3].toFixed(2)} vs ${close.toFixed(2)} (${accuracy.pctErrors.close.toFixed(2)}%)<br>`;
if (accuracy.actualCandle[4] !== undefined && accuracy.pctErrors.volume !== undefined) {
const predVolume = candleData[4];
tooltipText += `Volume: ${accuracy.actualCandle[4].toFixed(2)} vs ${predVolume.toFixed(2)} (${accuracy.pctErrors.volume.toFixed(2)}%)`;
}
} else {
tooltipText += `<br>Status: AWAITING VALIDATION...`;
}
// Create ghost candle trace with formatted timestamp string (same as real candles)
// 150% wider than normal candles
@@ -2236,14 +2557,14 @@ class ChartManager {
line: { color: color, width: 3 }, // 150% wider
fillcolor: color
},
opacity: 0.6, // 60% transparent
hoverinfo: 'x+y+text',
text: ['Predicted Next Candle'],
opacity: opacity,
hoverinfo: 'text',
text: [tooltipText],
width: 1.5 // 150% width multiplier
};
traces.push(ghostTrace);
console.log('Added ghost candle prediction at:', formattedTimestamp, ghostTrace);
console.log('Added ghost candle prediction at:', formattedTimestamp, accuracy ? 'VALIDATED' : 'pending');
}
_addShadowCandlePrediction(candleData, timestamp, traces) {

View File

@@ -1446,33 +1446,39 @@ class TradingTransformerTrainer:
candle_rmse = {}
if 'next_candles' in outputs:
# Use 1m timeframe as primary metric
if '1m' in outputs['next_candles'] and 'future_candle_1m' in batch:
# Use 1s or 1m timeframe as primary metric (try 1s first)
if '1s' in outputs['next_candles'] and 'future_candle_1s' in batch:
pred_candle = outputs['next_candles']['1s'] # [batch, 5]
actual_candle = batch['future_candle_1s'] # [batch, 5]
elif '1m' in outputs['next_candles'] and 'future_candle_1m' in batch:
pred_candle = outputs['next_candles']['1m'] # [batch, 5]
actual_candle = batch['future_candle_1m'] # [batch, 5]
else:
pred_candle = None
actual_candle = None
if actual_candle is not None and pred_candle is not None and pred_candle.shape == actual_candle.shape:
# Calculate RMSE for each OHLCV component
rmse_open = torch.sqrt(torch.mean((pred_candle[:, 0] - actual_candle[:, 0])**2) + 1e-8)
rmse_high = torch.sqrt(torch.mean((pred_candle[:, 1] - actual_candle[:, 1])**2) + 1e-8)
rmse_low = torch.sqrt(torch.mean((pred_candle[:, 2] - actual_candle[:, 2])**2) + 1e-8)
rmse_close = torch.sqrt(torch.mean((pred_candle[:, 3] - actual_candle[:, 3])**2) + 1e-8)
if actual_candle is not None and pred_candle.shape == actual_candle.shape:
# Calculate RMSE for each OHLCV component
rmse_open = torch.sqrt(torch.mean((pred_candle[:, 0] - actual_candle[:, 0])**2) + 1e-8)
rmse_high = torch.sqrt(torch.mean((pred_candle[:, 1] - actual_candle[:, 1])**2) + 1e-8)
rmse_low = torch.sqrt(torch.mean((pred_candle[:, 2] - actual_candle[:, 2])**2) + 1e-8)
rmse_close = torch.sqrt(torch.mean((pred_candle[:, 3] - actual_candle[:, 3])**2) + 1e-8)
# Average RMSE for OHLC (exclude volume)
avg_rmse = (rmse_open + rmse_high + rmse_low + rmse_close) / 4
# Convert to accuracy: lower RMSE = higher accuracy
# Normalize by price range
price_range = torch.clamp(actual_candle[:, 1].max() - actual_candle[:, 2].min(), min=1e-8)
candle_accuracy = (1.0 - torch.clamp(avg_rmse / price_range, 0, 1)).item()
candle_rmse = {
'open': rmse_open.item(),
'high': rmse_high.item(),
'low': rmse_low.item(),
'close': rmse_close.item(),
'avg': avg_rmse.item()
}
# Average RMSE for OHLC (exclude volume)
avg_rmse = (rmse_open + rmse_high + rmse_low + rmse_close) / 4
# Convert to accuracy: lower RMSE = higher accuracy
# Normalize by price range
price_range = torch.clamp(actual_candle[:, 1].max() - actual_candle[:, 2].min(), min=1e-8)
candle_accuracy = (1.0 - torch.clamp(avg_rmse / price_range, 0, 1)).item()
candle_rmse = {
'open': rmse_open.item(),
'high': rmse_high.item(),
'low': rmse_low.item(),
'close': rmse_close.item(),
'avg': avg_rmse.item()
}
# SECONDARY: Trend vector prediction accuracy
trend_accuracy = 0.0