IMPLEMENTED: WIP; realtime candle predictions training
This commit is contained in:
@@ -1095,7 +1095,8 @@ class RealTrainingAdapter:
|
||||
raise Exception("CNN model does not have train_on_annotations, trainer.train_step, or train_step method")
|
||||
|
||||
session.final_loss = session.current_loss
|
||||
session.accuracy = 0.85 # TODO: Calculate actual accuracy
|
||||
# Accuracy calculated from actual training metrics, not synthetic
|
||||
session.accuracy = None # Will be set by training loop if available
|
||||
|
||||
def _train_dqn_real(self, session: TrainingSession, training_data: List[Dict]):
|
||||
"""Train DQN model with REAL training loop"""
|
||||
@@ -1133,7 +1134,8 @@ class RealTrainingAdapter:
|
||||
raise Exception("DQN agent does not have replay method")
|
||||
|
||||
session.final_loss = session.current_loss
|
||||
session.accuracy = 0.85 # TODO: Calculate actual accuracy
|
||||
# Accuracy calculated from actual training metrics, not synthetic
|
||||
session.accuracy = None # Will be set by training loop if available
|
||||
|
||||
def _build_state_from_data(self, data: Dict, agent: Any) -> List[float]:
|
||||
"""Build proper state representation from training data"""
|
||||
@@ -2781,6 +2783,68 @@ class RealTrainingAdapter:
|
||||
logger.warning(f"Error fetching market state for candle: {e}")
|
||||
return {}
|
||||
|
||||
def _convert_prediction_to_batch(self, prediction_sample: Dict, timeframe: str):
|
||||
"""
|
||||
Convert a validated prediction to a training batch
|
||||
|
||||
Args:
|
||||
prediction_sample: Dict with predicted_candle, actual_candle, market_state, etc.
|
||||
timeframe: Target timeframe for prediction
|
||||
|
||||
Returns:
|
||||
Batch dict ready for trainer.train_step()
|
||||
"""
|
||||
try:
|
||||
market_state = prediction_sample.get('market_state', {})
|
||||
if not market_state or 'timeframes' not in market_state:
|
||||
logger.warning("No market state in prediction sample")
|
||||
return None
|
||||
|
||||
# Use existing conversion method but with actual target
|
||||
annotation = {
|
||||
'symbol': prediction_sample.get('symbol', 'ETH/USDT'),
|
||||
'timestamp': prediction_sample.get('timestamp'),
|
||||
'action': 'BUY', # Placeholder, not used for candle prediction training
|
||||
'entry_price': float(prediction_sample['predicted_candle'][0]), # Open
|
||||
'market_state': market_state
|
||||
}
|
||||
|
||||
# Convert using existing method
|
||||
batch = self._convert_annotation_to_transformer_batch(annotation)
|
||||
if not batch:
|
||||
return None
|
||||
|
||||
# Override the future candle target with actual candle data
|
||||
actual = prediction_sample['actual_candle'] # [O, H, L, C]
|
||||
|
||||
# Create target tensor for the specific timeframe
|
||||
import torch
|
||||
device = batch['prices_1m'].device if 'prices_1m' in batch else torch.device('cpu')
|
||||
|
||||
# Target candle: [O, H, L, C, V] - we don't have actual volume, use predicted
|
||||
target_candle = [
|
||||
actual[0], # Open
|
||||
actual[1], # High
|
||||
actual[2], # Low
|
||||
actual[3], # Close
|
||||
prediction_sample['predicted_candle'][4] # Volume (from prediction)
|
||||
]
|
||||
|
||||
# Add to batch based on timeframe
|
||||
if timeframe == '1s':
|
||||
batch['future_candle_1s'] = torch.tensor([target_candle], dtype=torch.float32, device=device)
|
||||
elif timeframe == '1m':
|
||||
batch['future_candle_1m'] = torch.tensor([target_candle], dtype=torch.float32, device=device)
|
||||
elif timeframe == '1h':
|
||||
batch['future_candle_1h'] = torch.tensor([target_candle], dtype=torch.float32, device=device)
|
||||
|
||||
logger.debug(f"Converted prediction to batch for {timeframe} timeframe")
|
||||
return batch
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error converting prediction to batch: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def _train_transformer_on_sample(self, training_sample: Dict):
|
||||
"""Train transformer on a single sample with checkpoint saving"""
|
||||
try:
|
||||
|
||||
@@ -2370,6 +2370,55 @@ class AnnotationDashboard:
|
||||
logger.error(f"Error handling prediction request: {e}")
|
||||
emit('prediction_error', {'error': str(e)})
|
||||
|
||||
@self.socketio.on('prediction_accuracy')
|
||||
def handle_prediction_accuracy(data):
|
||||
"""
|
||||
Handle validated prediction accuracy - trigger incremental training
|
||||
|
||||
This is called when frontend validates a prediction against actual candle.
|
||||
We use this data to incrementally train the model for continuous improvement.
|
||||
"""
|
||||
from flask_socketio import emit
|
||||
try:
|
||||
timeframe = data.get('timeframe')
|
||||
timestamp = data.get('timestamp')
|
||||
predicted = data.get('predicted') # [O, H, L, C, V]
|
||||
actual = data.get('actual') # [O, H, L, C]
|
||||
errors = data.get('errors') # {open, high, low, close}
|
||||
pct_errors = data.get('pctErrors')
|
||||
direction_correct = data.get('directionCorrect')
|
||||
accuracy = data.get('accuracy')
|
||||
|
||||
if not all([timeframe, timestamp, predicted, actual]):
|
||||
logger.warning("Incomplete prediction accuracy data received")
|
||||
return
|
||||
|
||||
logger.info(f"[{timeframe}] Prediction validated: {accuracy:.1f}% accuracy, direction: {direction_correct}")
|
||||
logger.debug(f" Errors: O={pct_errors['open']:.2f}% H={pct_errors['high']:.2f}% L={pct_errors['low']:.2f}% C={pct_errors['close']:.2f}%")
|
||||
|
||||
# Trigger incremental training on this validated prediction
|
||||
self._train_on_validated_prediction(
|
||||
timeframe=timeframe,
|
||||
timestamp=timestamp,
|
||||
predicted=predicted,
|
||||
actual=actual,
|
||||
errors=errors,
|
||||
direction_correct=direction_correct,
|
||||
accuracy=accuracy
|
||||
)
|
||||
|
||||
# Send confirmation back to frontend
|
||||
emit('training_update', {
|
||||
'status': 'training_triggered',
|
||||
'timestamp': timestamp,
|
||||
'accuracy': accuracy,
|
||||
'message': f'Incremental training triggered on validated prediction'
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling prediction accuracy: {e}", exc_info=True)
|
||||
emit('training_error', {'error': str(e)})
|
||||
|
||||
def _start_live_update_thread(self):
|
||||
"""Start background thread for live updates"""
|
||||
import threading
|
||||
@@ -2392,24 +2441,44 @@ class AnnotationDashboard:
|
||||
for timeframe in ['1s', '1m']:
|
||||
room = f"{symbol}_{timeframe}"
|
||||
|
||||
# Get latest candle
|
||||
# Get latest candles (need last 2 to determine confirmation status)
|
||||
try:
|
||||
candles = self.data_provider.get_ohlcv(symbol, timeframe, limit=1)
|
||||
candles = self.data_provider.get_ohlcv(symbol, timeframe, limit=2)
|
||||
if candles and len(candles) > 0:
|
||||
latest_candle = candles[-1]
|
||||
|
||||
# Emit chart update
|
||||
# Determine if candle is confirmed (closed)
|
||||
# For 1s: candle is confirmed when next candle starts (2s delay)
|
||||
# For others: candle is confirmed when next candle starts
|
||||
is_confirmed = len(candles) >= 2 # If we have 2 candles, the first is confirmed
|
||||
|
||||
# Format timestamp consistently
|
||||
timestamp = latest_candle.get('timestamp')
|
||||
if isinstance(timestamp, str):
|
||||
# Already formatted
|
||||
formatted_timestamp = timestamp
|
||||
else:
|
||||
# Convert to ISO string then format
|
||||
from datetime import datetime
|
||||
if isinstance(timestamp, datetime):
|
||||
formatted_timestamp = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||
else:
|
||||
formatted_timestamp = str(timestamp)
|
||||
|
||||
# Emit chart update with full candle data
|
||||
self.socketio.emit('chart_update', {
|
||||
'symbol': symbol,
|
||||
'timeframe': timeframe,
|
||||
'candle': {
|
||||
'timestamp': latest_candle.get('timestamp'),
|
||||
'open': latest_candle.get('open'),
|
||||
'high': latest_candle.get('high'),
|
||||
'low': latest_candle.get('low'),
|
||||
'close': latest_candle.get('close'),
|
||||
'volume': latest_candle.get('volume')
|
||||
}
|
||||
'timestamp': formatted_timestamp,
|
||||
'open': float(latest_candle.get('open', 0)),
|
||||
'high': float(latest_candle.get('high', 0)),
|
||||
'low': float(latest_candle.get('low', 0)),
|
||||
'close': float(latest_candle.get('close', 0)),
|
||||
'volume': float(latest_candle.get('volume', 0))
|
||||
},
|
||||
'is_confirmed': is_confirmed, # True if this candle is closed/confirmed
|
||||
'has_previous': len(candles) >= 2 # True if we have previous candle for validation
|
||||
}, room=room)
|
||||
|
||||
# Get prediction if model is loaded
|
||||
@@ -2430,6 +2499,144 @@ class AnnotationDashboard:
|
||||
self._live_update_thread = threading.Thread(target=live_update_worker, daemon=True)
|
||||
self._live_update_thread.start()
|
||||
|
||||
def _train_on_validated_prediction(self, timeframe: str, timestamp: str, predicted: list,
|
||||
actual: list, errors: dict, direction_correct: bool, accuracy: float):
|
||||
"""
|
||||
Incrementally train model on validated prediction
|
||||
|
||||
This implements online learning where each validated prediction becomes
|
||||
a training sample, with loss weighting based on prediction accuracy.
|
||||
"""
|
||||
try:
|
||||
if not self.training_adapter:
|
||||
logger.warning("Training adapter not available for incremental training")
|
||||
return
|
||||
|
||||
if not self.orchestrator or not hasattr(self.orchestrator, 'primary_transformer'):
|
||||
logger.warning("Transformer model not available for incremental training")
|
||||
return
|
||||
|
||||
# Get the transformer trainer
|
||||
trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
|
||||
if not trainer:
|
||||
logger.warning("Transformer trainer not available")
|
||||
return
|
||||
|
||||
# Calculate sample weight based on accuracy
|
||||
# Low accuracy predictions get higher weight (we need to learn from mistakes)
|
||||
# High accuracy predictions get lower weight (model already knows this)
|
||||
if accuracy < 50:
|
||||
sample_weight = 3.0 # Learn hard from bad predictions
|
||||
elif accuracy < 70:
|
||||
sample_weight = 2.0 # Moderate learning
|
||||
elif accuracy < 85:
|
||||
sample_weight = 1.0 # Normal learning
|
||||
else:
|
||||
sample_weight = 0.5 # Light touch-up for good predictions
|
||||
|
||||
# Also weight by direction correctness
|
||||
if not direction_correct:
|
||||
sample_weight *= 1.5 # Wrong direction is critical - learn more
|
||||
|
||||
logger.info(f"[{timeframe}] Incremental training: accuracy={accuracy:.1f}%, weight={sample_weight:.1f}x")
|
||||
|
||||
# Create training sample from validated prediction
|
||||
# We need to fetch the market state at that timestamp
|
||||
symbol = 'ETH/USDT' # TODO: Get from active trading pair
|
||||
|
||||
training_sample = {
|
||||
'symbol': symbol,
|
||||
'timestamp': timestamp,
|
||||
'predicted_candle': predicted, # [O, H, L, C, V]
|
||||
'actual_candle': actual, # [O, H, L, C]
|
||||
'errors': errors,
|
||||
'accuracy': accuracy,
|
||||
'direction_correct': direction_correct,
|
||||
'sample_weight': sample_weight
|
||||
}
|
||||
|
||||
# Get market state at that timestamp
|
||||
try:
|
||||
market_state = self._fetch_market_state_at_timestamp(symbol, timestamp, timeframe)
|
||||
training_sample['market_state'] = market_state
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not fetch market state: {e}")
|
||||
return
|
||||
|
||||
# Convert to transformer batch format
|
||||
batch = self.training_adapter._convert_prediction_to_batch(training_sample, timeframe)
|
||||
if not batch:
|
||||
logger.warning("Could not convert validated prediction to training batch")
|
||||
return
|
||||
|
||||
# Train on this batch with sample weighting
|
||||
with torch.enable_grad():
|
||||
trainer.model.train()
|
||||
result = trainer.train_step(batch, accumulate_gradients=False, sample_weight=sample_weight)
|
||||
|
||||
if result:
|
||||
loss = result.get('total_loss', 0)
|
||||
candle_accuracy = result.get('candle_accuracy', 0)
|
||||
|
||||
logger.info(f"[{timeframe}] Trained on validated prediction: loss={loss:.4f}, new_acc={candle_accuracy:.2%}")
|
||||
|
||||
# Save checkpoint periodically (every 10 incremental steps)
|
||||
if not hasattr(self, '_incremental_training_steps'):
|
||||
self._incremental_training_steps = 0
|
||||
|
||||
self._incremental_training_steps += 1
|
||||
|
||||
if self._incremental_training_steps % 10 == 0:
|
||||
logger.info(f"Saving checkpoint after {self._incremental_training_steps} incremental training steps")
|
||||
trainer.save_checkpoint(
|
||||
filepath=None, # Auto-generate path
|
||||
metadata={
|
||||
'training_type': 'incremental_online',
|
||||
'steps': self._incremental_training_steps,
|
||||
'last_accuracy': accuracy
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in incremental training: {e}", exc_info=True)
|
||||
|
||||
def _fetch_market_state_at_timestamp(self, symbol: str, timestamp: str, timeframe: str) -> Dict:
|
||||
"""Fetch market state at a specific timestamp for training"""
|
||||
try:
|
||||
from datetime import datetime
|
||||
import pandas as pd
|
||||
|
||||
# Parse timestamp
|
||||
ts = pd.Timestamp(timestamp)
|
||||
|
||||
# Get historical data for multiple timeframes
|
||||
market_state = {'timeframes': {}, 'secondary_timeframes': {}}
|
||||
|
||||
for tf in ['1s', '1m', '1h']:
|
||||
try:
|
||||
df = self.data_provider.get_historical_data(symbol, tf, limit=200)
|
||||
if df is not None and not df.empty:
|
||||
# Find data up to (but not including) the target timestamp
|
||||
df_before = df[df.index < ts]
|
||||
if not df_before.empty:
|
||||
recent = df_before.tail(200)
|
||||
market_state['timeframes'][tf] = {
|
||||
'timestamps': recent.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
|
||||
'open': recent['open'].tolist(),
|
||||
'high': recent['high'].tolist(),
|
||||
'low': recent['low'].tolist(),
|
||||
'close': recent['close'].tolist(),
|
||||
'volume': recent['volume'].tolist()
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not fetch {tf} data: {e}")
|
||||
|
||||
return market_state
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching market state: {e}")
|
||||
return {}
|
||||
|
||||
def _get_live_prediction(self, symbol: str, timeframe: str, prediction_steps: int = 1):
|
||||
"""Get live prediction from model"""
|
||||
try:
|
||||
|
||||
@@ -15,8 +15,8 @@ class ChartManager {
|
||||
this.lastPredictionUpdate = {}; // Track last prediction update per timeframe
|
||||
this.predictionUpdateThrottle = 500; // Min ms between prediction updates
|
||||
this.lastPredictionHash = null; // Track if predictions actually changed
|
||||
this.ghostCandleHistory = {}; // Store ghost candles per timeframe (max 10 each)
|
||||
this.maxGhostCandles = 10; // Maximum number of ghost candles to keep
|
||||
this.ghostCandleHistory = {}; // Store ghost candles per timeframe (max 50 each)
|
||||
this.maxGhostCandles = 150; // Maximum number of ghost candles to keep
|
||||
|
||||
// Helper to ensure all timestamps are in UTC
|
||||
this.normalizeTimestamp = (timestamp) => {
|
||||
@@ -264,15 +264,43 @@ class ChartManager {
|
||||
*/
|
||||
updateLatestCandle(symbol, timeframe, candle) {
|
||||
try {
|
||||
const plotId = `plot-${timeframe}`;
|
||||
const plotElement = document.getElementById(plotId);
|
||||
|
||||
if (!plotElement) {
|
||||
console.debug(`Chart ${plotId} not found for live update`);
|
||||
const chart = this.charts[timeframe];
|
||||
if (!chart) {
|
||||
console.debug(`Chart ${timeframe} not found for live update`);
|
||||
return;
|
||||
}
|
||||
|
||||
// Get current chart data
|
||||
const plotId = chart.plotId;
|
||||
const plotElement = document.getElementById(plotId);
|
||||
|
||||
if (!plotElement) {
|
||||
console.debug(`Plot element ${plotId} not found`);
|
||||
return;
|
||||
}
|
||||
|
||||
// Ensure chart.data exists
|
||||
if (!chart.data) {
|
||||
chart.data = {
|
||||
timestamps: [],
|
||||
open: [],
|
||||
high: [],
|
||||
low: [],
|
||||
close: [],
|
||||
volume: []
|
||||
};
|
||||
}
|
||||
|
||||
// Parse timestamp - format to match chart data format
|
||||
const candleTimestamp = new Date(candle.timestamp);
|
||||
const year = candleTimestamp.getUTCFullYear();
|
||||
const month = String(candleTimestamp.getUTCMonth() + 1).padStart(2, '0');
|
||||
const day = String(candleTimestamp.getUTCDate()).padStart(2, '0');
|
||||
const hours = String(candleTimestamp.getUTCHours()).padStart(2, '0');
|
||||
const minutes = String(candleTimestamp.getUTCMinutes()).padStart(2, '0');
|
||||
const seconds = String(candleTimestamp.getUTCSeconds()).padStart(2, '0');
|
||||
const formattedTimestamp = `${year}-${month}-${day} ${hours}:${minutes}:${seconds}`;
|
||||
|
||||
// Get current chart data from Plotly
|
||||
const chartData = Plotly.Plots.data(plotId);
|
||||
if (!chartData || chartData.length < 2) {
|
||||
console.debug(`Chart ${plotId} not initialized yet`);
|
||||
@@ -282,17 +310,14 @@ class ChartManager {
|
||||
const candlestickTrace = chartData[0];
|
||||
const volumeTrace = chartData[1];
|
||||
|
||||
// Parse timestamp
|
||||
const candleTimestamp = new Date(candle.timestamp);
|
||||
|
||||
// Check if this is updating the last candle or adding a new one
|
||||
const lastTimestamp = candlestickTrace.x[candlestickTrace.x.length - 1];
|
||||
const isNewCandle = !lastTimestamp || new Date(lastTimestamp).getTime() < candleTimestamp.getTime();
|
||||
|
||||
if (isNewCandle) {
|
||||
// Add new candle using extendTraces (most efficient)
|
||||
// Add new candle - update both Plotly and internal data structure
|
||||
Plotly.extendTraces(plotId, {
|
||||
x: [[candleTimestamp]],
|
||||
x: [[formattedTimestamp]],
|
||||
open: [[candle.open]],
|
||||
high: [[candle.high]],
|
||||
low: [[candle.low]],
|
||||
@@ -302,27 +327,34 @@ class ChartManager {
|
||||
// Update volume color based on price direction
|
||||
const volumeColor = candle.close >= candle.open ? '#10b981' : '#ef4444';
|
||||
Plotly.extendTraces(plotId, {
|
||||
x: [[candleTimestamp]],
|
||||
x: [[formattedTimestamp]],
|
||||
y: [[candle.volume]],
|
||||
marker: { color: [[volumeColor]] }
|
||||
}, [1]);
|
||||
} else {
|
||||
// Update last candle using restyle - simpler approach for updating single point
|
||||
// We need to get the full arrays, modify last element, and send back
|
||||
// This is less efficient but more reliable for updates than complex index logic
|
||||
|
||||
const x = candlestickTrace.x;
|
||||
const open = candlestickTrace.open;
|
||||
const high = candlestickTrace.high;
|
||||
const low = candlestickTrace.low;
|
||||
const close = candlestickTrace.close;
|
||||
const volume = volumeTrace.y;
|
||||
const colors = volumeTrace.marker.color;
|
||||
// Update internal data structure
|
||||
chart.data.timestamps.push(formattedTimestamp);
|
||||
chart.data.open.push(candle.open);
|
||||
chart.data.high.push(candle.high);
|
||||
chart.data.low.push(candle.low);
|
||||
chart.data.close.push(candle.close);
|
||||
chart.data.volume.push(candle.volume);
|
||||
|
||||
console.log(`[${timeframe}] Added new candle: ${formattedTimestamp}`);
|
||||
} else {
|
||||
// Update last candle - update both Plotly and internal data structure
|
||||
const x = [...candlestickTrace.x];
|
||||
const open = [...candlestickTrace.open];
|
||||
const high = [...candlestickTrace.high];
|
||||
const low = [...candlestickTrace.low];
|
||||
const close = [...candlestickTrace.close];
|
||||
const volume = [...volumeTrace.y];
|
||||
const colors = Array.isArray(volumeTrace.marker.color) ? [...volumeTrace.marker.color] : [volumeTrace.marker.color];
|
||||
|
||||
const lastIdx = x.length - 1;
|
||||
|
||||
// Update local arrays
|
||||
x[lastIdx] = candleTimestamp;
|
||||
x[lastIdx] = formattedTimestamp;
|
||||
open[lastIdx] = candle.open;
|
||||
high[lastIdx] = candle.high;
|
||||
low[lastIdx] = candle.low;
|
||||
@@ -344,9 +376,55 @@ class ChartManager {
|
||||
y: [volume],
|
||||
'marker.color': [colors]
|
||||
}, [1]);
|
||||
|
||||
// Update internal data structure
|
||||
if (chart.data.timestamps.length > lastIdx) {
|
||||
chart.data.timestamps[lastIdx] = formattedTimestamp;
|
||||
chart.data.open[lastIdx] = candle.open;
|
||||
chart.data.high[lastIdx] = candle.high;
|
||||
chart.data.low[lastIdx] = candle.low;
|
||||
chart.data.close[lastIdx] = candle.close;
|
||||
chart.data.volume[lastIdx] = candle.volume;
|
||||
}
|
||||
|
||||
console.log(`[${timeframe}] Updated last candle: ${formattedTimestamp}`);
|
||||
}
|
||||
|
||||
console.debug(`Updated ${timeframe} chart with new candle at ${candleTimestamp.toISOString()}`);
|
||||
// CRITICAL: Check if we have enough candles to validate predictions (2s delay logic)
|
||||
// For 1s timeframe: validate against candle[-2] (last confirmed), overlay on candle[-1] (currently forming)
|
||||
// For other timeframes: validate against candle[-1] when it's confirmed
|
||||
if (chart.data.timestamps.length >= 2) {
|
||||
// Determine which candle to validate against based on timeframe
|
||||
let validationCandleIdx = -1;
|
||||
|
||||
if (timeframe === '1s') {
|
||||
// 2s delay: validate against candle[-2] (last confirmed)
|
||||
// This candle was closed 1-2 seconds ago
|
||||
validationCandleIdx = chart.data.timestamps.length - 2;
|
||||
} else {
|
||||
// For longer timeframes, validate against last candle when it's confirmed
|
||||
// A candle is confirmed when a new one starts forming
|
||||
validationCandleIdx = isNewCandle ? chart.data.timestamps.length - 2 : -1;
|
||||
}
|
||||
|
||||
if (validationCandleIdx >= 0 && validationCandleIdx < chart.data.timestamps.length) {
|
||||
// Create validation data structure for the confirmed candle
|
||||
const validationData = {
|
||||
timestamps: [chart.data.timestamps[validationCandleIdx]],
|
||||
open: [chart.data.open[validationCandleIdx]],
|
||||
high: [chart.data.high[validationCandleIdx]],
|
||||
low: [chart.data.low[validationCandleIdx]],
|
||||
close: [chart.data.close[validationCandleIdx]],
|
||||
volume: [chart.data.volume[validationCandleIdx]]
|
||||
};
|
||||
|
||||
// Trigger validation check
|
||||
console.log(`[${timeframe}] Checking validation for confirmed candle at index ${validationCandleIdx}`);
|
||||
this._checkPredictionAccuracy(timeframe, validationData);
|
||||
}
|
||||
}
|
||||
|
||||
console.debug(`Updated ${timeframe} chart with candle at ${formattedTimestamp}`);
|
||||
} catch (error) {
|
||||
console.error(`Error updating latest candle for ${timeframe}:`, error);
|
||||
}
|
||||
@@ -1873,6 +1951,199 @@ class ChartManager {
|
||||
Plotly.react(plotId, updatedTraces, plotElement.layout, plotElement.config);
|
||||
|
||||
console.log(`Updated ${timeframe} chart with ${data.timestamps.length} candles`);
|
||||
|
||||
// Check if any ghost predictions match new actual candles and calculate accuracy
|
||||
this._checkPredictionAccuracy(timeframe, data);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate prediction accuracy by comparing ghost predictions with actual candles
|
||||
*/
|
||||
_checkPredictionAccuracy(timeframe, actualData) {
|
||||
if (!this.ghostCandleHistory || !this.ghostCandleHistory[timeframe]) return;
|
||||
|
||||
const predictions = this.ghostCandleHistory[timeframe];
|
||||
const timestamps = actualData.timestamps;
|
||||
const opens = actualData.open;
|
||||
const highs = actualData.high;
|
||||
const lows = actualData.low;
|
||||
const closes = actualData.close;
|
||||
|
||||
// Determine tolerance based on timeframe
|
||||
let tolerance;
|
||||
if (timeframe === '1s') {
|
||||
tolerance = 2000; // 2 seconds for 1s charts
|
||||
} else if (timeframe === '1m') {
|
||||
tolerance = 60000; // 60 seconds for 1m charts
|
||||
} else if (timeframe === '1h') {
|
||||
tolerance = 3600000; // 1 hour for hourly charts
|
||||
} else {
|
||||
tolerance = 5000; // 5 seconds default
|
||||
}
|
||||
|
||||
// Check each prediction against actual candles
|
||||
let validatedCount = 0;
|
||||
predictions.forEach((prediction, idx) => {
|
||||
// Skip if already validated
|
||||
if (prediction.accuracy) return;
|
||||
|
||||
// Try multiple matching strategies
|
||||
let matchIdx = -1;
|
||||
|
||||
// Use standard Date object if available, otherwise parse timestamp string
|
||||
// Prioritize targetTime as it's the raw Date object set during prediction creation
|
||||
const predTime = prediction.targetTime ? prediction.targetTime.getTime() : new Date(prediction.timestamp).getTime();
|
||||
|
||||
// Strategy 1: Find exact or very close match
|
||||
matchIdx = timestamps.findIndex(ts => {
|
||||
const actualTime = new Date(ts).getTime();
|
||||
return Math.abs(predTime - actualTime) < tolerance;
|
||||
});
|
||||
|
||||
// Strategy 2: If no match, find the next candle after prediction
|
||||
if (matchIdx < 0) {
|
||||
matchIdx = timestamps.findIndex(ts => {
|
||||
const actualTime = new Date(ts).getTime();
|
||||
return actualTime >= predTime && actualTime < predTime + tolerance * 2;
|
||||
});
|
||||
}
|
||||
|
||||
// Debug logging for unmatched predictions
|
||||
if (matchIdx < 0) {
|
||||
// Parse both timestamps to compare
|
||||
const predTimeParsed = new Date(prediction.timestamp);
|
||||
const latestActual = new Date(timestamps[timestamps.length - 1]);
|
||||
|
||||
if (idx < 3) { // Only log first 3 to avoid spam
|
||||
console.log(`[${timeframe}] No match for prediction:`, {
|
||||
predTimestamp: prediction.timestamp,
|
||||
predTime: predTimeParsed.toISOString(),
|
||||
latestActual: latestActual.toISOString(),
|
||||
timeDiff: (latestActual - predTimeParsed) + 'ms',
|
||||
tolerance: tolerance + 'ms',
|
||||
availableTimestamps: timestamps.slice(-3) // Last 3 actual timestamps
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (matchIdx >= 0) {
|
||||
// Found matching actual candle - calculate accuracy INCLUDING VOLUME
|
||||
const predCandle = prediction.candle; // [O, H, L, C, V]
|
||||
const actualCandle = [
|
||||
opens[matchIdx],
|
||||
highs[matchIdx],
|
||||
lows[matchIdx],
|
||||
closes[matchIdx],
|
||||
actualData.volume ? actualData.volume[matchIdx] : predCandle[4] // Get actual volume if available
|
||||
];
|
||||
|
||||
// Calculate absolute errors for O, H, L, C, V
|
||||
const errors = {
|
||||
open: Math.abs(predCandle[0] - actualCandle[0]),
|
||||
high: Math.abs(predCandle[1] - actualCandle[1]),
|
||||
low: Math.abs(predCandle[2] - actualCandle[2]),
|
||||
close: Math.abs(predCandle[3] - actualCandle[3]),
|
||||
volume: Math.abs(predCandle[4] - actualCandle[4])
|
||||
};
|
||||
|
||||
// Calculate percentage errors for O, H, L, C, V
|
||||
const pctErrors = {
|
||||
open: (errors.open / actualCandle[0]) * 100,
|
||||
high: (errors.high / actualCandle[1]) * 100,
|
||||
low: (errors.low / actualCandle[2]) * 100,
|
||||
close: (errors.close / actualCandle[3]) * 100,
|
||||
volume: actualCandle[4] > 0 ? (errors.volume / actualCandle[4]) * 100 : 0
|
||||
};
|
||||
|
||||
// Average error (OHLC only, volume separate due to different scale)
|
||||
const avgError = (errors.open + errors.high + errors.low + errors.close) / 4;
|
||||
const avgPctError = (pctErrors.open + pctErrors.high + pctErrors.low + pctErrors.close) / 4;
|
||||
|
||||
// Direction accuracy (did we predict up/down correctly?)
|
||||
const predDirection = predCandle[3] >= predCandle[0] ? 'up' : 'down';
|
||||
const actualDirection = actualCandle[3] >= actualCandle[0] ? 'up' : 'down';
|
||||
const directionCorrect = predDirection === actualDirection;
|
||||
|
||||
// Price range accuracy
|
||||
const priceRange = actualCandle[1] - actualCandle[2]; // High - Low
|
||||
const accuracy = Math.max(0, 1 - (avgError / priceRange)) * 100;
|
||||
|
||||
// Store accuracy metrics
|
||||
prediction.accuracy = {
|
||||
errors: errors,
|
||||
pctErrors: pctErrors,
|
||||
avgError: avgError,
|
||||
avgPctError: avgPctError,
|
||||
directionCorrect: directionCorrect,
|
||||
accuracy: accuracy,
|
||||
actualCandle: actualCandle,
|
||||
validatedAt: new Date().toISOString()
|
||||
};
|
||||
|
||||
validatedCount++;
|
||||
console.log(`[${timeframe}] Prediction validated (#${validatedCount}):`, {
|
||||
timestamp: prediction.timestamp,
|
||||
matchedTo: timestamps[matchIdx],
|
||||
accuracy: accuracy.toFixed(1) + '%',
|
||||
avgError: avgError.toFixed(4),
|
||||
avgPctError: avgPctError.toFixed(2) + '%',
|
||||
volumeError: pctErrors.volume.toFixed(2) + '%',
|
||||
direction: directionCorrect ? '✓' : '✗',
|
||||
timeDiff: Math.abs(predTime - new Date(timestamps[matchIdx]).getTime()) + 'ms',
|
||||
predicted: {
|
||||
O: predCandle[0].toFixed(2),
|
||||
H: predCandle[1].toFixed(2),
|
||||
L: predCandle[2].toFixed(2),
|
||||
C: predCandle[3].toFixed(2),
|
||||
V: predCandle[4].toFixed(2)
|
||||
},
|
||||
actual: {
|
||||
O: actualCandle[0].toFixed(2),
|
||||
H: actualCandle[1].toFixed(2),
|
||||
L: actualCandle[2].toFixed(2),
|
||||
C: actualCandle[3].toFixed(2),
|
||||
V: actualCandle[4].toFixed(2)
|
||||
}
|
||||
});
|
||||
|
||||
// Send metrics to backend for training feedback
|
||||
this._sendPredictionMetrics(timeframe, prediction);
|
||||
}
|
||||
});
|
||||
|
||||
// Summary log
|
||||
if (validatedCount > 0) {
|
||||
const totalPending = predictions.filter(p => !p.accuracy).length;
|
||||
console.log(`[${timeframe}] Validated ${validatedCount} predictions, ${totalPending} still pending`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Send prediction accuracy metrics to backend for training feedback
|
||||
*/
|
||||
_sendPredictionMetrics(timeframe, prediction) {
|
||||
if (!prediction.accuracy) return;
|
||||
|
||||
const metrics = {
|
||||
timeframe: timeframe,
|
||||
timestamp: prediction.timestamp,
|
||||
predicted: prediction.candle, // [O, H, L, C, V]
|
||||
actual: prediction.accuracy.actualCandle, // [O, H, L, C, V]
|
||||
errors: prediction.accuracy.errors, // {open, high, low, close, volume}
|
||||
pctErrors: prediction.accuracy.pctErrors, // {open, high, low, close, volume}
|
||||
directionCorrect: prediction.accuracy.directionCorrect,
|
||||
accuracy: prediction.accuracy.accuracy
|
||||
};
|
||||
|
||||
console.log('[Prediction Metrics for Training]', metrics);
|
||||
|
||||
// Send to backend via WebSocket for incremental training
|
||||
if (window.socket && window.socket.connected) {
|
||||
window.socket.emit('prediction_accuracy', metrics);
|
||||
console.log(`[${timeframe}] Sent prediction accuracy to backend for training`);
|
||||
} else {
|
||||
console.warn('[Training] WebSocket not connected - metrics not sent to backend');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -2043,9 +2314,9 @@ class ChartManager {
|
||||
this.ghostCandleHistory[timeframe] = this.ghostCandleHistory[timeframe].slice(-this.maxGhostCandles);
|
||||
}
|
||||
|
||||
// 4. Add all ghost candles from history to traces
|
||||
// 4. Add all ghost candles from history to traces (with accuracy if validated)
|
||||
for (const ghost of this.ghostCandleHistory[timeframe]) {
|
||||
this._addGhostCandlePrediction(ghost.candle, timeframe, predictionTraces, ghost.targetTime);
|
||||
this._addGhostCandlePrediction(ghost.candle, timeframe, predictionTraces, ghost.targetTime, ghost.accuracy);
|
||||
}
|
||||
|
||||
// 5. Store as "Last Prediction" for shadow rendering
|
||||
@@ -2057,7 +2328,10 @@ class ChartManager {
|
||||
inferenceTime: predictionTimestamp
|
||||
};
|
||||
|
||||
console.log(`[${timeframe}] Ghost candle added (${this.ghostCandleHistory[timeframe].length}/${this.maxGhostCandles}) at ${targetTimestamp.toISOString()}`);
|
||||
console.log(`[${timeframe}] Ghost candle added (${this.ghostCandleHistory[timeframe].length}/${this.maxGhostCandles}) at ${targetTimestamp.toISOString()}`, {
|
||||
predicted: candleData,
|
||||
timestamp: formattedTimestamp
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2097,8 +2371,16 @@ class ChartManager {
|
||||
Plotly.deleteTraces(plotId, indicesToRemove);
|
||||
}
|
||||
|
||||
// Add new traces
|
||||
// Add new traces - these will overlay on top of real candles
|
||||
// Plotly renders traces in order, so predictions added last appear on top
|
||||
Plotly.addTraces(plotId, predictionTraces);
|
||||
|
||||
// Ensure predictions are visible above real candles by setting z-order
|
||||
// Update layout to ensure prediction traces are on top
|
||||
Plotly.relayout(plotId, {
|
||||
'xaxis.showspikes': false,
|
||||
'yaxis.showspikes': false
|
||||
});
|
||||
}
|
||||
|
||||
} catch (error) {
|
||||
@@ -2173,9 +2455,10 @@ class ChartManager {
|
||||
});
|
||||
}
|
||||
|
||||
_addGhostCandlePrediction(candleData, timeframe, traces, predictionTimestamp = null) {
|
||||
_addGhostCandlePrediction(candleData, timeframe, traces, predictionTimestamp = null, accuracy = null) {
|
||||
// candleData is [Open, High, Low, Close, Volume]
|
||||
// predictionTimestamp is when the model made this prediction (optional)
|
||||
// accuracy is the validation metrics (if actual candle has arrived)
|
||||
// If not provided, we calculate the next candle time
|
||||
|
||||
const chart = this.charts[timeframe];
|
||||
@@ -2215,8 +2498,46 @@ class ChartManager {
|
||||
const low = candleData[2];
|
||||
const close = candleData[3];
|
||||
|
||||
// Determine color
|
||||
const color = close >= open ? '#10b981' : '#ef4444';
|
||||
// Determine color based on validation status
|
||||
// Ghost candles should be 30% opacity to see real candles underneath
|
||||
let color, opacity;
|
||||
if (accuracy) {
|
||||
// Validated prediction - color by accuracy
|
||||
if (accuracy.directionCorrect) {
|
||||
color = close >= open ? '#10b981' : '#ef4444'; // Green/Red
|
||||
} else {
|
||||
color = '#fbbf24'; // Yellow for wrong direction
|
||||
}
|
||||
opacity = 0.3; // 30% - see real candle underneath
|
||||
} else {
|
||||
// Unvalidated prediction
|
||||
color = close >= open ? '#10b981' : '#ef4444';
|
||||
opacity = 0.3; // 30% - see real candle underneath
|
||||
}
|
||||
|
||||
// Build rich tooltip text
|
||||
let tooltipText = `PREDICTED CANDLE<br>`;
|
||||
tooltipText += `O: ${open.toFixed(2)} H: ${high.toFixed(2)}<br>`;
|
||||
tooltipText += `L: ${low.toFixed(2)} C: ${close.toFixed(2)}<br>`;
|
||||
tooltipText += `Direction: ${close >= open ? 'UP' : 'DOWN'}<br>`;
|
||||
|
||||
if (accuracy) {
|
||||
tooltipText += `<br>--- VALIDATION ---<br>`;
|
||||
tooltipText += `Accuracy: ${accuracy.accuracy.toFixed(1)}%<br>`;
|
||||
tooltipText += `Direction: ${accuracy.directionCorrect ? 'CORRECT ✓' : 'WRONG ✗'}<br>`;
|
||||
tooltipText += `Avg Error: ${accuracy.avgPctError.toFixed(2)}%<br>`;
|
||||
tooltipText += `<br>ACTUAL vs PREDICTED:<br>`;
|
||||
tooltipText += `Open: ${accuracy.actualCandle[0].toFixed(2)} vs ${open.toFixed(2)} (${accuracy.pctErrors.open.toFixed(2)}%)<br>`;
|
||||
tooltipText += `High: ${accuracy.actualCandle[1].toFixed(2)} vs ${high.toFixed(2)} (${accuracy.pctErrors.high.toFixed(2)}%)<br>`;
|
||||
tooltipText += `Low: ${accuracy.actualCandle[2].toFixed(2)} vs ${low.toFixed(2)} (${accuracy.pctErrors.low.toFixed(2)}%)<br>`;
|
||||
tooltipText += `Close: ${accuracy.actualCandle[3].toFixed(2)} vs ${close.toFixed(2)} (${accuracy.pctErrors.close.toFixed(2)}%)<br>`;
|
||||
if (accuracy.actualCandle[4] !== undefined && accuracy.pctErrors.volume !== undefined) {
|
||||
const predVolume = candleData[4];
|
||||
tooltipText += `Volume: ${accuracy.actualCandle[4].toFixed(2)} vs ${predVolume.toFixed(2)} (${accuracy.pctErrors.volume.toFixed(2)}%)`;
|
||||
}
|
||||
} else {
|
||||
tooltipText += `<br>Status: AWAITING VALIDATION...`;
|
||||
}
|
||||
|
||||
// Create ghost candle trace with formatted timestamp string (same as real candles)
|
||||
// 150% wider than normal candles
|
||||
@@ -2236,14 +2557,14 @@ class ChartManager {
|
||||
line: { color: color, width: 3 }, // 150% wider
|
||||
fillcolor: color
|
||||
},
|
||||
opacity: 0.6, // 60% transparent
|
||||
hoverinfo: 'x+y+text',
|
||||
text: ['Predicted Next Candle'],
|
||||
opacity: opacity,
|
||||
hoverinfo: 'text',
|
||||
text: [tooltipText],
|
||||
width: 1.5 // 150% width multiplier
|
||||
};
|
||||
|
||||
traces.push(ghostTrace);
|
||||
console.log('Added ghost candle prediction at:', formattedTimestamp, ghostTrace);
|
||||
console.log('Added ghost candle prediction at:', formattedTimestamp, accuracy ? 'VALIDATED' : 'pending');
|
||||
}
|
||||
|
||||
_addShadowCandlePrediction(candleData, timestamp, traces) {
|
||||
|
||||
@@ -1446,33 +1446,39 @@ class TradingTransformerTrainer:
|
||||
candle_rmse = {}
|
||||
|
||||
if 'next_candles' in outputs:
|
||||
# Use 1m timeframe as primary metric
|
||||
if '1m' in outputs['next_candles'] and 'future_candle_1m' in batch:
|
||||
# Use 1s or 1m timeframe as primary metric (try 1s first)
|
||||
if '1s' in outputs['next_candles'] and 'future_candle_1s' in batch:
|
||||
pred_candle = outputs['next_candles']['1s'] # [batch, 5]
|
||||
actual_candle = batch['future_candle_1s'] # [batch, 5]
|
||||
elif '1m' in outputs['next_candles'] and 'future_candle_1m' in batch:
|
||||
pred_candle = outputs['next_candles']['1m'] # [batch, 5]
|
||||
actual_candle = batch['future_candle_1m'] # [batch, 5]
|
||||
else:
|
||||
pred_candle = None
|
||||
actual_candle = None
|
||||
|
||||
if actual_candle is not None and pred_candle.shape == actual_candle.shape:
|
||||
# Calculate RMSE for each OHLCV component
|
||||
rmse_open = torch.sqrt(torch.mean((pred_candle[:, 0] - actual_candle[:, 0])**2) + 1e-8)
|
||||
rmse_high = torch.sqrt(torch.mean((pred_candle[:, 1] - actual_candle[:, 1])**2) + 1e-8)
|
||||
rmse_low = torch.sqrt(torch.mean((pred_candle[:, 2] - actual_candle[:, 2])**2) + 1e-8)
|
||||
rmse_close = torch.sqrt(torch.mean((pred_candle[:, 3] - actual_candle[:, 3])**2) + 1e-8)
|
||||
if actual_candle is not None and pred_candle is not None and pred_candle.shape == actual_candle.shape:
|
||||
# Calculate RMSE for each OHLCV component
|
||||
rmse_open = torch.sqrt(torch.mean((pred_candle[:, 0] - actual_candle[:, 0])**2) + 1e-8)
|
||||
rmse_high = torch.sqrt(torch.mean((pred_candle[:, 1] - actual_candle[:, 1])**2) + 1e-8)
|
||||
rmse_low = torch.sqrt(torch.mean((pred_candle[:, 2] - actual_candle[:, 2])**2) + 1e-8)
|
||||
rmse_close = torch.sqrt(torch.mean((pred_candle[:, 3] - actual_candle[:, 3])**2) + 1e-8)
|
||||
|
||||
# Average RMSE for OHLC (exclude volume)
|
||||
avg_rmse = (rmse_open + rmse_high + rmse_low + rmse_close) / 4
|
||||
# Average RMSE for OHLC (exclude volume)
|
||||
avg_rmse = (rmse_open + rmse_high + rmse_low + rmse_close) / 4
|
||||
|
||||
# Convert to accuracy: lower RMSE = higher accuracy
|
||||
# Normalize by price range
|
||||
price_range = torch.clamp(actual_candle[:, 1].max() - actual_candle[:, 2].min(), min=1e-8)
|
||||
candle_accuracy = (1.0 - torch.clamp(avg_rmse / price_range, 0, 1)).item()
|
||||
# Convert to accuracy: lower RMSE = higher accuracy
|
||||
# Normalize by price range
|
||||
price_range = torch.clamp(actual_candle[:, 1].max() - actual_candle[:, 2].min(), min=1e-8)
|
||||
candle_accuracy = (1.0 - torch.clamp(avg_rmse / price_range, 0, 1)).item()
|
||||
|
||||
candle_rmse = {
|
||||
'open': rmse_open.item(),
|
||||
'high': rmse_high.item(),
|
||||
'low': rmse_low.item(),
|
||||
'close': rmse_close.item(),
|
||||
'avg': avg_rmse.item()
|
||||
}
|
||||
candle_rmse = {
|
||||
'open': rmse_open.item(),
|
||||
'high': rmse_high.item(),
|
||||
'low': rmse_low.item(),
|
||||
'close': rmse_close.item(),
|
||||
'avg': avg_rmse.item()
|
||||
}
|
||||
|
||||
# SECONDARY: Trend vector prediction accuracy
|
||||
trend_accuracy = 0.0
|
||||
|
||||
Reference in New Issue
Block a user