wip
This commit is contained in:
@@ -3264,18 +3264,25 @@ class RealTrainingAdapter:
|
|||||||
else:
|
else:
|
||||||
logger.info(f"Live Signal (NOT executed): {signal['action']} @ {signal['price']:.2f} (conf: {signal['confidence']:.2f}) - {self._get_rejection_reason(session, signal)}")
|
logger.info(f"Live Signal (NOT executed): {signal['action']} @ {signal['price']:.2f} (conf: {signal['confidence']:.2f}) - {self._get_rejection_reason(session, signal)}")
|
||||||
|
|
||||||
# Store prediction for visualization
|
# Store prediction for visualization WITH predicted_candle data for ghost candles
|
||||||
if self.orchestrator and hasattr(self.orchestrator, 'store_transformer_prediction'):
|
if self.orchestrator and hasattr(self.orchestrator, 'store_transformer_prediction'):
|
||||||
self.orchestrator.store_transformer_prediction(symbol, {
|
stored_prediction = {
|
||||||
'timestamp': datetime.now(timezone.utc).isoformat(),
|
'timestamp': datetime.now(timezone.utc).isoformat(),
|
||||||
'current_price': current_price,
|
'current_price': current_price,
|
||||||
'predicted_price': current_price * (1.01 if prediction['action'] == 'BUY' else 0.99),
|
'predicted_price': prediction.get('predicted_price', current_price * (1.01 if prediction['action'] == 'BUY' else 0.99)),
|
||||||
'price_change': 1.0 if prediction['action'] == 'BUY' else -1.0,
|
'price_change': 1.0 if prediction['action'] == 'BUY' else -1.0,
|
||||||
'confidence': prediction['confidence'],
|
'confidence': prediction['confidence'],
|
||||||
'action': prediction['action'],
|
'action': prediction['action'],
|
||||||
'horizon_minutes': 10,
|
'horizon_minutes': 10,
|
||||||
'source': 'live_inference'
|
'source': 'live_inference'
|
||||||
})
|
}
|
||||||
|
# Include predicted_candle for ghost candle visualization
|
||||||
|
if 'predicted_candle' in prediction and prediction['predicted_candle']:
|
||||||
|
stored_prediction['predicted_candle'] = prediction['predicted_candle']
|
||||||
|
stored_prediction['next_candles'] = prediction['predicted_candle'] # Alias for compatibility
|
||||||
|
logger.debug(f"Stored prediction with {len(prediction['predicted_candle'])} timeframe candles")
|
||||||
|
|
||||||
|
self.orchestrator.store_transformer_prediction(symbol, stored_prediction)
|
||||||
|
|
||||||
# Per-candle training mode
|
# Per-candle training mode
|
||||||
if train_every_candle:
|
if train_every_candle:
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ sys.path.insert(0, str(parent_dir))
|
|||||||
from flask import Flask, render_template, request, jsonify, send_file
|
from flask import Flask, render_template, request, jsonify, send_file
|
||||||
from dash import Dash, html
|
from dash import Dash, html
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone, timedelta
|
||||||
from typing import Optional, Dict, List, Any
|
from typing import Optional, Dict, List, Any
|
||||||
import json
|
import json
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
@@ -2254,11 +2254,19 @@ class AnnotationDashboard:
|
|||||||
if cnn_preds:
|
if cnn_preds:
|
||||||
predictions['cnn'] = cnn_preds[-1]
|
predictions['cnn'] = cnn_preds[-1]
|
||||||
|
|
||||||
# Transformer predictions
|
# Transformer predictions with next_candles for ghost candles
|
||||||
|
# First check if there are stored predictions from the inference loop
|
||||||
if hasattr(self.orchestrator, 'recent_transformer_predictions') and symbol in self.orchestrator.recent_transformer_predictions:
|
if hasattr(self.orchestrator, 'recent_transformer_predictions') and symbol in self.orchestrator.recent_transformer_predictions:
|
||||||
transformer_preds = list(self.orchestrator.recent_transformer_predictions[symbol])
|
transformer_preds = list(self.orchestrator.recent_transformer_predictions[symbol])
|
||||||
if transformer_preds:
|
if transformer_preds:
|
||||||
|
# Use the most recent stored prediction (from inference loop)
|
||||||
predictions['transformer'] = transformer_preds[-1]
|
predictions['transformer'] = transformer_preds[-1]
|
||||||
|
logger.debug(f"Using stored prediction: {list(transformer_preds[-1].keys())}")
|
||||||
|
else:
|
||||||
|
# Fallback: generate new prediction if no stored predictions
|
||||||
|
transformer_pred = self._get_live_transformer_prediction(symbol)
|
||||||
|
if transformer_pred:
|
||||||
|
predictions['transformer'] = transformer_pred
|
||||||
|
|
||||||
if predictions:
|
if predictions:
|
||||||
response['prediction'] = predictions
|
response['prediction'] = predictions
|
||||||
@@ -2497,6 +2505,110 @@ class AnnotationDashboard:
|
|||||||
self._live_update_thread = threading.Thread(target=live_update_worker, daemon=True)
|
self._live_update_thread = threading.Thread(target=live_update_worker, daemon=True)
|
||||||
self._live_update_thread.start()
|
self._live_update_thread.start()
|
||||||
|
|
||||||
|
def _get_live_transformer_prediction(self, symbol: str = 'ETH/USDT'):
|
||||||
|
"""
|
||||||
|
Generate live transformer prediction with next_candles for ghost candle display
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not self.orchestrator:
|
||||||
|
logger.debug("No orchestrator - cannot generate predictions")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not hasattr(self.orchestrator, 'primary_transformer'):
|
||||||
|
logger.debug("Orchestrator has no primary_transformer - enable training first")
|
||||||
|
return None
|
||||||
|
|
||||||
|
transformer = self.orchestrator.primary_transformer
|
||||||
|
if not transformer:
|
||||||
|
logger.debug("primary_transformer is None - model not loaded yet")
|
||||||
|
return None
|
||||||
|
|
||||||
|
transformer.eval()
|
||||||
|
|
||||||
|
# Get recent market data
|
||||||
|
price_data_1s = self.data_provider.get_ohlcv(symbol, '1s', limit=200) if self.data_provider else None
|
||||||
|
price_data_1m = self.data_provider.get_ohlcv(symbol, '1m', limit=150) if self.data_provider else None
|
||||||
|
price_data_1h = self.data_provider.get_ohlcv(symbol, '1h', limit=24) if self.data_provider else None
|
||||||
|
price_data_1d = self.data_provider.get_ohlcv(symbol, '1d', limit=14) if self.data_provider else None
|
||||||
|
btc_data_1m = self.data_provider.get_ohlcv('BTC/USDT', '1m', limit=150) if self.data_provider else None
|
||||||
|
|
||||||
|
if not price_data_1m or len(price_data_1m) < 10:
|
||||||
|
return None
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
device = next(transformer.parameters()).device
|
||||||
|
|
||||||
|
def ohlcv_to_tensor(data, limit=None):
|
||||||
|
if not data:
|
||||||
|
return None
|
||||||
|
data = data[-limit:] if limit and len(data) > limit else data
|
||||||
|
arr = np.array([[d['open'], d['high'], d['low'], d['close'], d['volume']] for d in data], dtype=np.float32)
|
||||||
|
return torch.from_numpy(arr).unsqueeze(0).to(device)
|
||||||
|
|
||||||
|
inputs = {
|
||||||
|
'price_data_1s': ohlcv_to_tensor(price_data_1s, 200),
|
||||||
|
'price_data_1m': ohlcv_to_tensor(price_data_1m, 150),
|
||||||
|
'price_data_1h': ohlcv_to_tensor(price_data_1h, 24),
|
||||||
|
'price_data_1d': ohlcv_to_tensor(price_data_1d, 14),
|
||||||
|
'btc_data_1m': ohlcv_to_tensor(btc_data_1m, 150)
|
||||||
|
}
|
||||||
|
|
||||||
|
# Forward pass
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = transformer(**inputs)
|
||||||
|
|
||||||
|
# Extract next_candles
|
||||||
|
next_candles = outputs.get('next_candles', {})
|
||||||
|
if not next_candles:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Convert to JSON-serializable format
|
||||||
|
predicted_candle = {}
|
||||||
|
for tf, candle_tensor in next_candles.items():
|
||||||
|
if candle_tensor is not None:
|
||||||
|
candle_values = candle_tensor.squeeze(0).cpu().numpy().tolist()
|
||||||
|
predicted_candle[tf] = candle_values
|
||||||
|
|
||||||
|
current_price = price_data_1m[-1]['close']
|
||||||
|
predicted_1m_close = predicted_candle.get('1m', [0,0,0,current_price,0])[3]
|
||||||
|
price_change = (predicted_1m_close - current_price) / current_price
|
||||||
|
|
||||||
|
if price_change > 0.001:
|
||||||
|
action = 'BUY'
|
||||||
|
elif price_change < -0.001:
|
||||||
|
action = 'SELL'
|
||||||
|
else:
|
||||||
|
action = 'HOLD'
|
||||||
|
|
||||||
|
confidence = 0.7
|
||||||
|
if 'confidence' in outputs:
|
||||||
|
conf_tensor = outputs['confidence']
|
||||||
|
confidence = float(conf_tensor.squeeze(0).cpu().numpy()[0])
|
||||||
|
|
||||||
|
prediction = {
|
||||||
|
'timestamp': datetime.now().isoformat(),
|
||||||
|
'symbol': symbol,
|
||||||
|
'action': action,
|
||||||
|
'confidence': confidence,
|
||||||
|
'predicted_price': predicted_1m_close,
|
||||||
|
'current_price': current_price,
|
||||||
|
'price_change': price_change,
|
||||||
|
'predicted_candle': predicted_candle, # This is what frontend needs!
|
||||||
|
'type': 'transformer_prediction'
|
||||||
|
}
|
||||||
|
|
||||||
|
# Store for tracking
|
||||||
|
self.orchestrator.store_transformer_prediction(symbol, prediction)
|
||||||
|
|
||||||
|
logger.debug(f"Generated transformer prediction with {len(predicted_candle)} timeframes for ghost candles")
|
||||||
|
|
||||||
|
return prediction
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error generating live transformer prediction: {e}", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
def _train_on_validated_prediction(self, timeframe: str, timestamp: str, predicted: list,
|
def _train_on_validated_prediction(self, timeframe: str, timestamp: str, predicted: list,
|
||||||
actual: list, errors: dict, direction_correct: bool, accuracy: float):
|
actual: list, errors: dict, direction_correct: bool, accuracy: float):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1962,6 +1962,16 @@ class CleanTradingDashboard:
|
|||||||
def update_price_chart(n, pivots_value, relayout_data):
|
def update_price_chart(n, pivots_value, relayout_data):
|
||||||
"""Update price chart every second, persisting user zoom/pan"""
|
"""Update price chart every second, persisting user zoom/pan"""
|
||||||
try:
|
try:
|
||||||
|
# Log transformer status on first update
|
||||||
|
if n == 1:
|
||||||
|
if self.orchestrator and hasattr(self.orchestrator, 'primary_transformer'):
|
||||||
|
if self.orchestrator.primary_transformer:
|
||||||
|
logger.info("TRANSFORMER MODEL LOADED - Ghost candles should be visible")
|
||||||
|
else:
|
||||||
|
logger.warning("TRANSFORMER MODEL IS NONE - Enable training to load model")
|
||||||
|
else:
|
||||||
|
logger.warning("NO TRANSFORMER AVAILABLE - Enable training first")
|
||||||
|
|
||||||
# Validate and train on predictions every update (once per second)
|
# Validate and train on predictions every update (once per second)
|
||||||
# This checks if any predictions can be validated against real candles
|
# This checks if any predictions can be validated against real candles
|
||||||
self._validate_and_train_on_predictions('ETH/USDT')
|
self._validate_and_train_on_predictions('ETH/USDT')
|
||||||
@@ -3365,6 +3375,7 @@ class CleanTradingDashboard:
|
|||||||
self._add_dqn_predictions_to_chart(fig, symbol, df_main, row)
|
self._add_dqn_predictions_to_chart(fig, symbol, df_main, row)
|
||||||
self._add_cnn_predictions_to_chart(fig, symbol, df_main, row)
|
self._add_cnn_predictions_to_chart(fig, symbol, df_main, row)
|
||||||
self._add_transformer_predictions_to_chart(fig, symbol, df_main, row)
|
self._add_transformer_predictions_to_chart(fig, symbol, df_main, row)
|
||||||
|
self._add_ghost_candles_to_chart(fig, symbol, df_main, row) # Add predicted future candles
|
||||||
self._add_cob_rl_predictions_to_chart(fig, symbol, df_main, row)
|
self._add_cob_rl_predictions_to_chart(fig, symbol, df_main, row)
|
||||||
self._add_prediction_accuracy_feedback(fig, symbol, df_main, row)
|
self._add_prediction_accuracy_feedback(fig, symbol, df_main, row)
|
||||||
self._add_williams_pivots_to_chart(fig, symbol, df_main, row)
|
self._add_williams_pivots_to_chart(fig, symbol, df_main, row)
|
||||||
@@ -3683,6 +3694,112 @@ class CleanTradingDashboard:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"Error adding CNN predictions to chart: {e}")
|
logger.debug(f"Error adding CNN predictions to chart: {e}")
|
||||||
|
|
||||||
|
def _add_ghost_candles_to_chart(self, fig: go.Figure, symbol: str, df_main: pd.DataFrame, row: int = 1):
|
||||||
|
"""Add predicted future candles (ghost candles) to the chart"""
|
||||||
|
try:
|
||||||
|
# Get latest live prediction with next_candles
|
||||||
|
prediction = self._get_live_transformer_prediction_with_next_candles(symbol)
|
||||||
|
|
||||||
|
if not prediction:
|
||||||
|
logger.debug("No transformer prediction available - is training enabled?")
|
||||||
|
return
|
||||||
|
|
||||||
|
if 'next_candles' not in prediction:
|
||||||
|
logger.debug("Prediction exists but has no next_candles data")
|
||||||
|
return
|
||||||
|
|
||||||
|
next_candles = prediction['next_candles']
|
||||||
|
if not next_candles:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Get the chart timeframe from the dataframe (assume 1s or 1m based on df_main)
|
||||||
|
if df_main.empty:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Detect timeframe from dataframe index
|
||||||
|
if len(df_main) >= 2:
|
||||||
|
time_diff = (df_main.index[-1] - df_main.index[-2]).total_seconds()
|
||||||
|
if time_diff <= 1.5:
|
||||||
|
chart_timeframe = '1s'
|
||||||
|
elif time_diff <= 65:
|
||||||
|
chart_timeframe = '1m'
|
||||||
|
elif time_diff <= 3900:
|
||||||
|
chart_timeframe = '1h'
|
||||||
|
else:
|
||||||
|
chart_timeframe = '1d'
|
||||||
|
else:
|
||||||
|
chart_timeframe = '1m' # Default
|
||||||
|
|
||||||
|
# Get prediction for this timeframe
|
||||||
|
if chart_timeframe not in next_candles:
|
||||||
|
logger.debug(f"No prediction for {chart_timeframe} timeframe")
|
||||||
|
return
|
||||||
|
|
||||||
|
predicted_ohlcv = next_candles[chart_timeframe]
|
||||||
|
if not predicted_ohlcv or len(predicted_ohlcv) < 5:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Calculate timestamp for predicted candle
|
||||||
|
last_timestamp = df_main.index[-1]
|
||||||
|
timeframe_seconds = {'1s': 1, '1m': 60, '1h': 3600, '1d': 86400}
|
||||||
|
delta_seconds = timeframe_seconds.get(chart_timeframe, 60)
|
||||||
|
predicted_timestamp = last_timestamp + timedelta(seconds=delta_seconds)
|
||||||
|
|
||||||
|
# Extract OHLCV values
|
||||||
|
pred_open, pred_high, pred_low, pred_close, pred_volume = predicted_ohlcv
|
||||||
|
|
||||||
|
# Determine color based on direction
|
||||||
|
is_bullish = pred_close >= pred_open
|
||||||
|
|
||||||
|
# Add ghost candle as semi-transparent candlestick
|
||||||
|
fig.add_trace(
|
||||||
|
go.Candlestick(
|
||||||
|
x=[predicted_timestamp],
|
||||||
|
open=[pred_open],
|
||||||
|
high=[pred_high],
|
||||||
|
low=[pred_low],
|
||||||
|
close=[pred_close],
|
||||||
|
name='Predicted Candle',
|
||||||
|
increasing=dict(
|
||||||
|
line=dict(color='rgba(0, 255, 0, 0.4)', width=1),
|
||||||
|
fillcolor='rgba(0, 255, 0, 0.2)'
|
||||||
|
),
|
||||||
|
decreasing=dict(
|
||||||
|
line=dict(color='rgba(255, 0, 0, 0.4)', width=1),
|
||||||
|
fillcolor='rgba(255, 0, 0, 0.2)'
|
||||||
|
),
|
||||||
|
showlegend=True,
|
||||||
|
legendgroup='predictions',
|
||||||
|
opacity=0.5
|
||||||
|
),
|
||||||
|
row=row, col=1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add a marker at the predicted close price
|
||||||
|
fig.add_trace(
|
||||||
|
go.Scatter(
|
||||||
|
x=[predicted_timestamp],
|
||||||
|
y=[pred_close],
|
||||||
|
mode='markers',
|
||||||
|
marker=dict(
|
||||||
|
size=8,
|
||||||
|
color='rgba(255, 215, 0, 0.8)', # Gold color
|
||||||
|
symbol='star',
|
||||||
|
line=dict(color='white', width=1)
|
||||||
|
),
|
||||||
|
name='Prediction Target',
|
||||||
|
showlegend=True,
|
||||||
|
legendgroup='predictions',
|
||||||
|
hovertemplate=f'Predicted Close: ${pred_close:.2f}<br>Time: {predicted_timestamp.strftime("%H:%M:%S")}<extra></extra>'
|
||||||
|
),
|
||||||
|
row=row, col=1
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Added ghost candle for {chart_timeframe}: O={pred_open:.2f}, H={pred_high:.2f}, L={pred_low:.2f}, C={pred_close:.2f}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error adding ghost candles to chart: {e}", exc_info=True)
|
||||||
|
|
||||||
def _add_transformer_predictions_to_chart(self, fig: go.Figure, symbol: str, df_main: pd.DataFrame, row: int = 1):
|
def _add_transformer_predictions_to_chart(self, fig: go.Figure, symbol: str, df_main: pd.DataFrame, row: int = 1):
|
||||||
"""Add Transformer price predictions as trend lines with confidence bands"""
|
"""Add Transformer price predictions as trend lines with confidence bands"""
|
||||||
try:
|
try:
|
||||||
@@ -4071,9 +4188,20 @@ class CleanTradingDashboard:
|
|||||||
This makes a real-time prediction with the transformer model
|
This makes a real-time prediction with the transformer model
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if not self.orchestrator or not hasattr(self.orchestrator, 'primary_transformer') or not self.orchestrator.primary_transformer:
|
if not self.orchestrator:
|
||||||
|
logger.debug("No orchestrator available for predictions")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
if not hasattr(self.orchestrator, 'primary_transformer'):
|
||||||
|
logger.debug("Orchestrator has no primary_transformer attribute")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not self.orchestrator.primary_transformer:
|
||||||
|
logger.debug("primary_transformer is None - model not loaded")
|
||||||
|
return None
|
||||||
|
|
||||||
|
logger.debug(f"Making live transformer prediction for {symbol}...")
|
||||||
|
|
||||||
transformer = self.orchestrator.primary_transformer
|
transformer = self.orchestrator.primary_transformer
|
||||||
transformer.eval()
|
transformer.eval()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user