ghost T predictions plotting on chart live chart updates
This commit is contained in:
@@ -65,14 +65,26 @@ class HistoricalDataLoader:
|
||||
|
||||
# Check memory cache first (exclude direction from cache key for infinite scroll)
|
||||
cache_key = f"{symbol}_{timeframe}_{start_time}_{end_time}_{limit}"
|
||||
|
||||
# Determine TTL based on timeframe
|
||||
current_ttl = self.cache_ttl
|
||||
if timeframe == '1s':
|
||||
current_ttl = timedelta(seconds=1)
|
||||
elif timeframe == '1m':
|
||||
current_ttl = timedelta(seconds=5)
|
||||
|
||||
if cache_key in self.memory_cache and direction == 'latest':
|
||||
cached_data, cached_time = self.memory_cache[cache_key]
|
||||
if datetime.now() - cached_time < self.cache_ttl:
|
||||
if datetime.now() - cached_time < current_ttl:
|
||||
# For 1s/1m, we want to return immediately if valid
|
||||
if timeframe not in ['1s', '1m']:
|
||||
elapsed_ms = (time.time() - start_time_ms) * 1000
|
||||
logger.debug(f"⚡ Memory cache hit for {symbol} {timeframe} ({elapsed_ms:.1f}ms)")
|
||||
return cached_data
|
||||
|
||||
try:
|
||||
# FORCE refresh for 1s/1m if requesting latest data
|
||||
force_refresh = (timeframe in ['1s', '1m'] and not start_time and not end_time)
|
||||
# Try to get data from DataProvider's cached data first (most efficient)
|
||||
if hasattr(self.data_provider, 'cached_data'):
|
||||
with self.data_provider.data_lock:
|
||||
@@ -215,7 +227,7 @@ class HistoricalDataLoader:
|
||||
return None
|
||||
|
||||
# Fallback: Use DataProvider for latest data (startup mode or no time range)
|
||||
if self.startup_mode and not (start_time or end_time):
|
||||
if self.startup_mode and not (start_time or end_time) and not force_refresh:
|
||||
logger.info(f"Loading data for {symbol} {timeframe} (startup mode: allow stale cache)")
|
||||
df = self.data_provider.get_historical_data(
|
||||
symbol=symbol,
|
||||
@@ -225,7 +237,12 @@ class HistoricalDataLoader:
|
||||
)
|
||||
else:
|
||||
# Fetch from API and store in DuckDB (no time range specified)
|
||||
# For 1s/1m, logging every request is too verbose, use debug
|
||||
if timeframe in ['1s', '1m']:
|
||||
logger.debug(f"Fetching latest data from API for {symbol} {timeframe}")
|
||||
else:
|
||||
logger.info(f"Fetching latest data from API for {symbol} {timeframe}")
|
||||
|
||||
df = self.data_provider.get_historical_data(
|
||||
symbol=symbol,
|
||||
timeframe=timeframe,
|
||||
|
||||
@@ -16,12 +16,13 @@ import uuid
|
||||
import time
|
||||
import threading
|
||||
import os
|
||||
from typing import Dict, List, Optional, Any
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import pytz
|
||||
@@ -2488,7 +2489,7 @@ class RealTrainingAdapter:
|
||||
trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
|
||||
if trainer and trainer.model:
|
||||
# Get recent market data
|
||||
market_data = self._get_realtime_market_data(symbol, data_provider)
|
||||
market_data, norm_params = self._get_realtime_market_data(symbol, data_provider)
|
||||
if not market_data:
|
||||
return None
|
||||
|
||||
@@ -2507,41 +2508,162 @@ class RealTrainingAdapter:
|
||||
actions = ['BUY', 'SELL', 'HOLD']
|
||||
action = actions[action_idx] if action_idx < len(actions) else 'HOLD'
|
||||
|
||||
# Handle predicted candles - DENORMALIZE them
|
||||
predicted_candles_raw = {}
|
||||
if 'next_candles' in outputs:
|
||||
for tf, tensor in outputs['next_candles'].items():
|
||||
predicted_candles_raw[tf] = tensor.detach().cpu().numpy().tolist()
|
||||
|
||||
# Denormalize if we have params
|
||||
predicted_candles_denorm = {}
|
||||
if predicted_candles_raw and norm_params:
|
||||
for tf, raw_candle in predicted_candles_raw.items():
|
||||
# raw_candle is [1, 5] list
|
||||
if tf in norm_params:
|
||||
params = norm_params[tf]
|
||||
price_min = params['price_min']
|
||||
price_max = params['price_max']
|
||||
vol_min = params['volume_min']
|
||||
vol_max = params['volume_max']
|
||||
|
||||
# Denormalize [Open, High, Low, Close, Volume]
|
||||
# Note: raw_candle[0] is the list of 5 values
|
||||
candle_values = raw_candle[0]
|
||||
|
||||
denorm_candle = [
|
||||
candle_values[0] * (price_max - price_min) + price_min, # Open
|
||||
candle_values[1] * (price_max - price_min) + price_min, # High
|
||||
candle_values[2] * (price_max - price_min) + price_min, # Low
|
||||
candle_values[3] * (price_max - price_min) + price_min, # Close
|
||||
candle_values[4] * (vol_max - vol_min) + vol_min # Volume
|
||||
]
|
||||
predicted_candles_denorm[tf] = denorm_candle
|
||||
|
||||
# Calculate predicted price from candle close
|
||||
predicted_price = None
|
||||
if '1m' in predicted_candles_denorm:
|
||||
predicted_price = predicted_candles_denorm['1m'][3] # Close price
|
||||
elif '1s' in predicted_candles_denorm:
|
||||
predicted_price = predicted_candles_denorm['1s'][3]
|
||||
elif outputs.get('price_prediction') is not None:
|
||||
# Fallback to price_prediction head if available (normalized)
|
||||
# This would need separate denormalization based on reference price
|
||||
pass
|
||||
|
||||
return {
|
||||
'action': action,
|
||||
'confidence': confidence
|
||||
'confidence': confidence,
|
||||
'predicted_price': predicted_price,
|
||||
'predicted_candle': predicted_candles_denorm
|
||||
}
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.debug(f"Error making realtime prediction: {e}")
|
||||
import traceback
|
||||
logger.debug(traceback.format_exc())
|
||||
return None
|
||||
|
||||
def _get_realtime_market_data(self, symbol: str, data_provider) -> Dict:
|
||||
"""Get current market data for prediction"""
|
||||
def _get_realtime_market_data(self, symbol: str, data_provider) -> Tuple[Dict, Dict]:
|
||||
"""
|
||||
Get current market data for prediction AND normalization parameters
|
||||
|
||||
Returns:
|
||||
Tuple of (model_inputs_dict, normalization_params_dict)
|
||||
"""
|
||||
try:
|
||||
# Get recent candles for all timeframes
|
||||
data = {}
|
||||
for tf in ['1s', '1m', '1h', '1d']:
|
||||
df = data_provider.get_historical_data(symbol, tf, limit=200)
|
||||
if df is not None and not df.empty:
|
||||
# Convert to tensor format (simplified)
|
||||
import torch
|
||||
import numpy as np
|
||||
norm_params = {}
|
||||
|
||||
candles = df[['open', 'high', 'low', 'close', 'volume']].values
|
||||
candles_tensor = torch.tensor(candles, dtype=torch.float32).unsqueeze(0)
|
||||
for tf in ['1s', '1m', '1h', '1d']:
|
||||
# Get historical data (raw)
|
||||
# Force refresh for 1s/1m to ensure we have the very latest candle for prediction
|
||||
refresh = tf in ['1s', '1m']
|
||||
df = data_provider.get_historical_data(symbol, tf, limit=600, refresh=refresh)
|
||||
if df is not None and not df.empty:
|
||||
# Extract raw arrays
|
||||
opens = df['open'].values.astype(np.float32)
|
||||
highs = df['high'].values.astype(np.float32)
|
||||
lows = df['low'].values.astype(np.float32)
|
||||
closes = df['close'].values.astype(np.float32)
|
||||
volumes = df['volume'].values.astype(np.float32)
|
||||
|
||||
# Need at least 1 candle
|
||||
if len(closes) == 0:
|
||||
continue
|
||||
|
||||
# Prepare OHLCV for normalization logic
|
||||
# Padding if needed (though limit=600 usually suffices)
|
||||
if len(closes) < 600:
|
||||
pad_len = 600 - len(closes)
|
||||
# Pad with first value
|
||||
opens = np.pad(opens, (pad_len, 0), mode='edge')
|
||||
highs = np.pad(highs, (pad_len, 0), mode='edge')
|
||||
lows = np.pad(lows, (pad_len, 0), mode='edge')
|
||||
closes = np.pad(closes, (pad_len, 0), mode='edge')
|
||||
volumes = np.pad(volumes, (pad_len, 0), mode='edge')
|
||||
else:
|
||||
# Take last 600
|
||||
opens = opens[-600:]
|
||||
highs = highs[-600:]
|
||||
lows = lows[-600:]
|
||||
closes = closes[-600:]
|
||||
volumes = volumes[-600:]
|
||||
|
||||
# Stack OHLCV [seq_len, 5]
|
||||
ohlcv = np.stack([opens, highs, lows, closes, volumes], axis=-1)
|
||||
|
||||
# Calculate min/max for normalization
|
||||
price_min = np.min(ohlcv[:, :4])
|
||||
price_max = np.max(ohlcv[:, :4])
|
||||
volume_min = np.min(ohlcv[:, 4])
|
||||
volume_max = np.max(ohlcv[:, 4])
|
||||
|
||||
# Avoid division by zero
|
||||
if price_max == price_min: price_max += 1.0
|
||||
if volume_max == volume_min: volume_max += 1.0
|
||||
|
||||
# Store params for denormalization later
|
||||
norm_params[tf] = {
|
||||
'price_min': float(price_min),
|
||||
'price_max': float(price_max),
|
||||
'volume_min': float(volume_min),
|
||||
'volume_max': float(volume_max)
|
||||
}
|
||||
|
||||
# Normalize in-place
|
||||
ohlcv[:, :4] = (ohlcv[:, :4] - price_min) / (price_max - price_min)
|
||||
ohlcv[:, 4] = (ohlcv[:, 4] - volume_min) / (volume_max - volume_min)
|
||||
|
||||
# Convert to tensor [1, seq_len, 5]
|
||||
import torch
|
||||
candles_tensor = torch.tensor(ohlcv, dtype=torch.float32).unsqueeze(0)
|
||||
data[f'price_data_{tf}'] = candles_tensor
|
||||
|
||||
# Add placeholder data for other inputs
|
||||
# Add placeholder data for other inputs if we have at least one timeframe
|
||||
if data:
|
||||
import torch
|
||||
# Correct shapes based on model expectation
|
||||
# tech_data: [1, 40]
|
||||
# market_data: [1, 30]
|
||||
# cob_data: [1, 600, 100]
|
||||
data['tech_data'] = torch.zeros(1, 40, dtype=torch.float32)
|
||||
data['market_data'] = torch.zeros(1, 30, dtype=torch.float32)
|
||||
data['cob_data'] = torch.zeros(1, 600, 100, dtype=torch.float32)
|
||||
|
||||
return data if data else None
|
||||
# Move to device if available
|
||||
if hasattr(self.orchestrator, 'device'):
|
||||
device = self.orchestrator.device
|
||||
for k, v in data.items():
|
||||
data[k] = v.to(device)
|
||||
|
||||
return data, norm_params if data else (None, None)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting realtime market data: {e}")
|
||||
return None
|
||||
import traceback
|
||||
logger.debug(traceback.format_exc())
|
||||
return None, None
|
||||
|
||||
def _train_on_new_candle(self, session: Dict, symbol: str, timeframe: str, data_provider):
|
||||
"""Train model on new candle when it closes"""
|
||||
@@ -2734,7 +2856,9 @@ class RealTrainingAdapter:
|
||||
'model': model_name,
|
||||
'action': prediction['action'],
|
||||
'confidence': prediction['confidence'],
|
||||
'price': current_price
|
||||
'price': current_price,
|
||||
'predicted_price': prediction.get('predicted_price'),
|
||||
'predicted_candle': prediction.get('predicted_candle')
|
||||
}
|
||||
|
||||
session['signals'].append(signal)
|
||||
|
||||
@@ -27,14 +27,14 @@ class ChartManager {
|
||||
this.autoUpdateEnabled = true;
|
||||
console.log('Starting chart auto-update...');
|
||||
|
||||
// Update 1s chart every 20 seconds
|
||||
// Update 1s chart every 2 seconds (was 20s)
|
||||
if (this.timeframes.includes('1s')) {
|
||||
this.updateTimers['1s'] = setInterval(() => {
|
||||
this.updateChart('1s');
|
||||
}, 20000); // 20 seconds
|
||||
}, 2000); // 2 seconds
|
||||
}
|
||||
|
||||
// Update 1m chart - sync to whole minutes + every 20s
|
||||
// Update 1m chart - sync to whole minutes + every 5s (was 20s)
|
||||
if (this.timeframes.includes('1m')) {
|
||||
// Calculate ms until next whole minute
|
||||
const now = new Date();
|
||||
@@ -44,10 +44,10 @@ class ChartManager {
|
||||
setTimeout(() => {
|
||||
this.updateChart('1m');
|
||||
|
||||
// Then update every 20s
|
||||
// Then update every 5s
|
||||
this.updateTimers['1m'] = setInterval(() => {
|
||||
this.updateChart('1m');
|
||||
}, 20000); // 20 seconds
|
||||
}, 5000); // 5 seconds
|
||||
}, msUntilNextMinute);
|
||||
}
|
||||
|
||||
@@ -1758,6 +1758,7 @@ class ChartManager {
|
||||
// Prepare prediction markers
|
||||
const predictionShapes = [];
|
||||
const predictionAnnotations = [];
|
||||
const predictionTraces = []; // New traces for ghost candles
|
||||
|
||||
// Add DQN predictions (arrows)
|
||||
if (predictions.dqn) {
|
||||
@@ -1769,9 +1770,18 @@ class ChartManager {
|
||||
this._addCNNPrediction(predictions.cnn, predictionShapes, predictionAnnotations);
|
||||
}
|
||||
|
||||
// Add Transformer predictions (star markers with trend lines)
|
||||
// Add Transformer predictions (star markers with trend lines + ghost candles)
|
||||
if (predictions.transformer) {
|
||||
this._addTransformerPrediction(predictions.transformer, predictionShapes, predictionAnnotations);
|
||||
|
||||
// Add ghost candle if available
|
||||
if (predictions.transformer.predicted_candle) {
|
||||
// Check if we have prediction for this timeframe
|
||||
const candleData = predictions.transformer.predicted_candle[timeframe];
|
||||
if (candleData) {
|
||||
this._addGhostCandlePrediction(candleData, timeframe, predictionTraces);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update chart layout with predictions
|
||||
@@ -1782,11 +1792,85 @@ class ChartManager {
|
||||
});
|
||||
}
|
||||
|
||||
// Add prediction traces (ghost candles)
|
||||
if (predictionTraces.length > 0) {
|
||||
// Remove existing ghost traces first (heuristic: traces with name 'Ghost Prediction')
|
||||
const currentTraces = plotElement.data.length;
|
||||
const indicesToRemove = [];
|
||||
for (let i = 0; i < currentTraces; i++) {
|
||||
if (plotElement.data[i].name === 'Ghost Prediction') {
|
||||
indicesToRemove.push(i);
|
||||
}
|
||||
}
|
||||
if (indicesToRemove.length > 0) {
|
||||
Plotly.deleteTraces(plotId, indicesToRemove);
|
||||
}
|
||||
|
||||
// Add new traces
|
||||
Plotly.addTraces(plotId, predictionTraces);
|
||||
}
|
||||
|
||||
} catch (error) {
|
||||
console.debug('Error updating predictions:', error);
|
||||
}
|
||||
}
|
||||
|
||||
_addGhostCandlePrediction(candleData, timeframe, traces) {
|
||||
// candleData is [Open, High, Low, Close, Volume]
|
||||
// We need to determine the timestamp for this ghost candle
|
||||
// It should be the NEXT candle after the last one on chart
|
||||
|
||||
const chart = this.charts[timeframe];
|
||||
if (!chart || !chart.data) return;
|
||||
|
||||
const lastTimestamp = new Date(chart.data.timestamps[chart.data.timestamps.length - 1]);
|
||||
let nextTimestamp;
|
||||
|
||||
// Calculate next timestamp based on timeframe
|
||||
if (timeframe === '1s') {
|
||||
nextTimestamp = new Date(lastTimestamp.getTime() + 1000);
|
||||
} else if (timeframe === '1m') {
|
||||
nextTimestamp = new Date(lastTimestamp.getTime() + 60000);
|
||||
} else if (timeframe === '1h') {
|
||||
nextTimestamp = new Date(lastTimestamp.getTime() + 3600000);
|
||||
} else {
|
||||
nextTimestamp = new Date(lastTimestamp.getTime() + 60000); // Default 1m
|
||||
}
|
||||
|
||||
const open = candleData[0];
|
||||
const high = candleData[1];
|
||||
const low = candleData[2];
|
||||
const close = candleData[3];
|
||||
|
||||
// Determine color
|
||||
const color = close >= open ? '#10b981' : '#ef4444';
|
||||
|
||||
// Create ghost candle trace
|
||||
const ghostTrace = {
|
||||
x: [nextTimestamp],
|
||||
open: [open],
|
||||
high: [high],
|
||||
low: [low],
|
||||
close: [close],
|
||||
type: 'candlestick',
|
||||
name: 'Ghost Prediction',
|
||||
increasing: {
|
||||
line: { color: color, width: 1 },
|
||||
fillcolor: color
|
||||
},
|
||||
decreasing: {
|
||||
line: { color: color, width: 1 },
|
||||
fillcolor: color
|
||||
},
|
||||
opacity: 0.3, // 30% transparent
|
||||
hoverinfo: 'x+y+text',
|
||||
text: ['Predicted Next Candle']
|
||||
};
|
||||
|
||||
traces.push(ghostTrace);
|
||||
console.log('Added ghost candle prediction:', ghostTrace);
|
||||
}
|
||||
|
||||
_addDQNPrediction(prediction, shapes, annotations) {
|
||||
const timestamp = new Date(prediction.timestamp || Date.now());
|
||||
const price = prediction.current_price || 0;
|
||||
|
||||
@@ -576,6 +576,11 @@
|
||||
// Start polling for signals
|
||||
startSignalPolling();
|
||||
|
||||
// Start chart auto-update
|
||||
if (window.appState && window.appState.chartManager) {
|
||||
window.appState.chartManager.startAutoUpdate();
|
||||
}
|
||||
|
||||
const trainingMode = data.training_mode || 'inference-only';
|
||||
const modeText = trainingMode === 'per-candle' ? ' with per-candle training' :
|
||||
(trainingMode === 'pivot-based' ? ' with pivot training' : '');
|
||||
@@ -631,6 +636,11 @@
|
||||
// Stop polling
|
||||
stopSignalPolling();
|
||||
|
||||
// Stop chart auto-update
|
||||
if (window.appState && window.appState.chartManager) {
|
||||
window.appState.chartManager.stopAutoUpdate();
|
||||
}
|
||||
|
||||
currentInferenceId = null;
|
||||
showSuccess('Real-time inference stopped');
|
||||
}
|
||||
@@ -932,9 +942,16 @@
|
||||
}
|
||||
updatePredictionHistory();
|
||||
|
||||
// Update chart with signal markers
|
||||
// Update chart with signal markers and predictions
|
||||
if (window.appState && window.appState.chartManager) {
|
||||
displaySignalOnChart(latest);
|
||||
|
||||
// Update ghost candles and other predictions
|
||||
const predictions = {};
|
||||
const modelKey = latest.model ? latest.model.toLowerCase() : 'transformer';
|
||||
predictions[modelKey] = latest;
|
||||
|
||||
window.appState.chartManager.updatePredictions(predictions);
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user