wip
This commit is contained in:
@@ -3264,18 +3264,25 @@ class RealTrainingAdapter:
|
||||
else:
|
||||
logger.info(f"Live Signal (NOT executed): {signal['action']} @ {signal['price']:.2f} (conf: {signal['confidence']:.2f}) - {self._get_rejection_reason(session, signal)}")
|
||||
|
||||
# Store prediction for visualization
|
||||
# Store prediction for visualization WITH predicted_candle data for ghost candles
|
||||
if self.orchestrator and hasattr(self.orchestrator, 'store_transformer_prediction'):
|
||||
self.orchestrator.store_transformer_prediction(symbol, {
|
||||
stored_prediction = {
|
||||
'timestamp': datetime.now(timezone.utc).isoformat(),
|
||||
'current_price': current_price,
|
||||
'predicted_price': current_price * (1.01 if prediction['action'] == 'BUY' else 0.99),
|
||||
'predicted_price': prediction.get('predicted_price', current_price * (1.01 if prediction['action'] == 'BUY' else 0.99)),
|
||||
'price_change': 1.0 if prediction['action'] == 'BUY' else -1.0,
|
||||
'confidence': prediction['confidence'],
|
||||
'action': prediction['action'],
|
||||
'horizon_minutes': 10,
|
||||
'source': 'live_inference'
|
||||
})
|
||||
}
|
||||
# Include predicted_candle for ghost candle visualization
|
||||
if 'predicted_candle' in prediction and prediction['predicted_candle']:
|
||||
stored_prediction['predicted_candle'] = prediction['predicted_candle']
|
||||
stored_prediction['next_candles'] = prediction['predicted_candle'] # Alias for compatibility
|
||||
logger.debug(f"Stored prediction with {len(prediction['predicted_candle'])} timeframe candles")
|
||||
|
||||
self.orchestrator.store_transformer_prediction(symbol, stored_prediction)
|
||||
|
||||
# Per-candle training mode
|
||||
if train_every_candle:
|
||||
|
||||
@@ -16,7 +16,7 @@ sys.path.insert(0, str(parent_dir))
|
||||
from flask import Flask, render_template, request, jsonify, send_file
|
||||
from dash import Dash, html
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Optional, Dict, List, Any
|
||||
import json
|
||||
import pandas as pd
|
||||
@@ -2254,11 +2254,19 @@ class AnnotationDashboard:
|
||||
if cnn_preds:
|
||||
predictions['cnn'] = cnn_preds[-1]
|
||||
|
||||
# Transformer predictions
|
||||
# Transformer predictions with next_candles for ghost candles
|
||||
# First check if there are stored predictions from the inference loop
|
||||
if hasattr(self.orchestrator, 'recent_transformer_predictions') and symbol in self.orchestrator.recent_transformer_predictions:
|
||||
transformer_preds = list(self.orchestrator.recent_transformer_predictions[symbol])
|
||||
if transformer_preds:
|
||||
# Use the most recent stored prediction (from inference loop)
|
||||
predictions['transformer'] = transformer_preds[-1]
|
||||
logger.debug(f"Using stored prediction: {list(transformer_preds[-1].keys())}")
|
||||
else:
|
||||
# Fallback: generate new prediction if no stored predictions
|
||||
transformer_pred = self._get_live_transformer_prediction(symbol)
|
||||
if transformer_pred:
|
||||
predictions['transformer'] = transformer_pred
|
||||
|
||||
if predictions:
|
||||
response['prediction'] = predictions
|
||||
@@ -2497,6 +2505,110 @@ class AnnotationDashboard:
|
||||
self._live_update_thread = threading.Thread(target=live_update_worker, daemon=True)
|
||||
self._live_update_thread.start()
|
||||
|
||||
def _get_live_transformer_prediction(self, symbol: str = 'ETH/USDT'):
|
||||
"""
|
||||
Generate live transformer prediction with next_candles for ghost candle display
|
||||
"""
|
||||
try:
|
||||
if not self.orchestrator:
|
||||
logger.debug("No orchestrator - cannot generate predictions")
|
||||
return None
|
||||
|
||||
if not hasattr(self.orchestrator, 'primary_transformer'):
|
||||
logger.debug("Orchestrator has no primary_transformer - enable training first")
|
||||
return None
|
||||
|
||||
transformer = self.orchestrator.primary_transformer
|
||||
if not transformer:
|
||||
logger.debug("primary_transformer is None - model not loaded yet")
|
||||
return None
|
||||
|
||||
transformer.eval()
|
||||
|
||||
# Get recent market data
|
||||
price_data_1s = self.data_provider.get_ohlcv(symbol, '1s', limit=200) if self.data_provider else None
|
||||
price_data_1m = self.data_provider.get_ohlcv(symbol, '1m', limit=150) if self.data_provider else None
|
||||
price_data_1h = self.data_provider.get_ohlcv(symbol, '1h', limit=24) if self.data_provider else None
|
||||
price_data_1d = self.data_provider.get_ohlcv(symbol, '1d', limit=14) if self.data_provider else None
|
||||
btc_data_1m = self.data_provider.get_ohlcv('BTC/USDT', '1m', limit=150) if self.data_provider else None
|
||||
|
||||
if not price_data_1m or len(price_data_1m) < 10:
|
||||
return None
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
device = next(transformer.parameters()).device
|
||||
|
||||
def ohlcv_to_tensor(data, limit=None):
|
||||
if not data:
|
||||
return None
|
||||
data = data[-limit:] if limit and len(data) > limit else data
|
||||
arr = np.array([[d['open'], d['high'], d['low'], d['close'], d['volume']] for d in data], dtype=np.float32)
|
||||
return torch.from_numpy(arr).unsqueeze(0).to(device)
|
||||
|
||||
inputs = {
|
||||
'price_data_1s': ohlcv_to_tensor(price_data_1s, 200),
|
||||
'price_data_1m': ohlcv_to_tensor(price_data_1m, 150),
|
||||
'price_data_1h': ohlcv_to_tensor(price_data_1h, 24),
|
||||
'price_data_1d': ohlcv_to_tensor(price_data_1d, 14),
|
||||
'btc_data_1m': ohlcv_to_tensor(btc_data_1m, 150)
|
||||
}
|
||||
|
||||
# Forward pass
|
||||
with torch.no_grad():
|
||||
outputs = transformer(**inputs)
|
||||
|
||||
# Extract next_candles
|
||||
next_candles = outputs.get('next_candles', {})
|
||||
if not next_candles:
|
||||
return None
|
||||
|
||||
# Convert to JSON-serializable format
|
||||
predicted_candle = {}
|
||||
for tf, candle_tensor in next_candles.items():
|
||||
if candle_tensor is not None:
|
||||
candle_values = candle_tensor.squeeze(0).cpu().numpy().tolist()
|
||||
predicted_candle[tf] = candle_values
|
||||
|
||||
current_price = price_data_1m[-1]['close']
|
||||
predicted_1m_close = predicted_candle.get('1m', [0,0,0,current_price,0])[3]
|
||||
price_change = (predicted_1m_close - current_price) / current_price
|
||||
|
||||
if price_change > 0.001:
|
||||
action = 'BUY'
|
||||
elif price_change < -0.001:
|
||||
action = 'SELL'
|
||||
else:
|
||||
action = 'HOLD'
|
||||
|
||||
confidence = 0.7
|
||||
if 'confidence' in outputs:
|
||||
conf_tensor = outputs['confidence']
|
||||
confidence = float(conf_tensor.squeeze(0).cpu().numpy()[0])
|
||||
|
||||
prediction = {
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'symbol': symbol,
|
||||
'action': action,
|
||||
'confidence': confidence,
|
||||
'predicted_price': predicted_1m_close,
|
||||
'current_price': current_price,
|
||||
'price_change': price_change,
|
||||
'predicted_candle': predicted_candle, # This is what frontend needs!
|
||||
'type': 'transformer_prediction'
|
||||
}
|
||||
|
||||
# Store for tracking
|
||||
self.orchestrator.store_transformer_prediction(symbol, prediction)
|
||||
|
||||
logger.debug(f"Generated transformer prediction with {len(predicted_candle)} timeframes for ghost candles")
|
||||
|
||||
return prediction
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating live transformer prediction: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def _train_on_validated_prediction(self, timeframe: str, timestamp: str, predicted: list,
|
||||
actual: list, errors: dict, direction_correct: bool, accuracy: float):
|
||||
"""
|
||||
|
||||
@@ -1962,6 +1962,16 @@ class CleanTradingDashboard:
|
||||
def update_price_chart(n, pivots_value, relayout_data):
|
||||
"""Update price chart every second, persisting user zoom/pan"""
|
||||
try:
|
||||
# Log transformer status on first update
|
||||
if n == 1:
|
||||
if self.orchestrator and hasattr(self.orchestrator, 'primary_transformer'):
|
||||
if self.orchestrator.primary_transformer:
|
||||
logger.info("TRANSFORMER MODEL LOADED - Ghost candles should be visible")
|
||||
else:
|
||||
logger.warning("TRANSFORMER MODEL IS NONE - Enable training to load model")
|
||||
else:
|
||||
logger.warning("NO TRANSFORMER AVAILABLE - Enable training first")
|
||||
|
||||
# Validate and train on predictions every update (once per second)
|
||||
# This checks if any predictions can be validated against real candles
|
||||
self._validate_and_train_on_predictions('ETH/USDT')
|
||||
@@ -3365,6 +3375,7 @@ class CleanTradingDashboard:
|
||||
self._add_dqn_predictions_to_chart(fig, symbol, df_main, row)
|
||||
self._add_cnn_predictions_to_chart(fig, symbol, df_main, row)
|
||||
self._add_transformer_predictions_to_chart(fig, symbol, df_main, row)
|
||||
self._add_ghost_candles_to_chart(fig, symbol, df_main, row) # Add predicted future candles
|
||||
self._add_cob_rl_predictions_to_chart(fig, symbol, df_main, row)
|
||||
self._add_prediction_accuracy_feedback(fig, symbol, df_main, row)
|
||||
self._add_williams_pivots_to_chart(fig, symbol, df_main, row)
|
||||
@@ -3683,6 +3694,112 @@ class CleanTradingDashboard:
|
||||
except Exception as e:
|
||||
logger.debug(f"Error adding CNN predictions to chart: {e}")
|
||||
|
||||
def _add_ghost_candles_to_chart(self, fig: go.Figure, symbol: str, df_main: pd.DataFrame, row: int = 1):
|
||||
"""Add predicted future candles (ghost candles) to the chart"""
|
||||
try:
|
||||
# Get latest live prediction with next_candles
|
||||
prediction = self._get_live_transformer_prediction_with_next_candles(symbol)
|
||||
|
||||
if not prediction:
|
||||
logger.debug("No transformer prediction available - is training enabled?")
|
||||
return
|
||||
|
||||
if 'next_candles' not in prediction:
|
||||
logger.debug("Prediction exists but has no next_candles data")
|
||||
return
|
||||
|
||||
next_candles = prediction['next_candles']
|
||||
if not next_candles:
|
||||
return
|
||||
|
||||
# Get the chart timeframe from the dataframe (assume 1s or 1m based on df_main)
|
||||
if df_main.empty:
|
||||
return
|
||||
|
||||
# Detect timeframe from dataframe index
|
||||
if len(df_main) >= 2:
|
||||
time_diff = (df_main.index[-1] - df_main.index[-2]).total_seconds()
|
||||
if time_diff <= 1.5:
|
||||
chart_timeframe = '1s'
|
||||
elif time_diff <= 65:
|
||||
chart_timeframe = '1m'
|
||||
elif time_diff <= 3900:
|
||||
chart_timeframe = '1h'
|
||||
else:
|
||||
chart_timeframe = '1d'
|
||||
else:
|
||||
chart_timeframe = '1m' # Default
|
||||
|
||||
# Get prediction for this timeframe
|
||||
if chart_timeframe not in next_candles:
|
||||
logger.debug(f"No prediction for {chart_timeframe} timeframe")
|
||||
return
|
||||
|
||||
predicted_ohlcv = next_candles[chart_timeframe]
|
||||
if not predicted_ohlcv or len(predicted_ohlcv) < 5:
|
||||
return
|
||||
|
||||
# Calculate timestamp for predicted candle
|
||||
last_timestamp = df_main.index[-1]
|
||||
timeframe_seconds = {'1s': 1, '1m': 60, '1h': 3600, '1d': 86400}
|
||||
delta_seconds = timeframe_seconds.get(chart_timeframe, 60)
|
||||
predicted_timestamp = last_timestamp + timedelta(seconds=delta_seconds)
|
||||
|
||||
# Extract OHLCV values
|
||||
pred_open, pred_high, pred_low, pred_close, pred_volume = predicted_ohlcv
|
||||
|
||||
# Determine color based on direction
|
||||
is_bullish = pred_close >= pred_open
|
||||
|
||||
# Add ghost candle as semi-transparent candlestick
|
||||
fig.add_trace(
|
||||
go.Candlestick(
|
||||
x=[predicted_timestamp],
|
||||
open=[pred_open],
|
||||
high=[pred_high],
|
||||
low=[pred_low],
|
||||
close=[pred_close],
|
||||
name='Predicted Candle',
|
||||
increasing=dict(
|
||||
line=dict(color='rgba(0, 255, 0, 0.4)', width=1),
|
||||
fillcolor='rgba(0, 255, 0, 0.2)'
|
||||
),
|
||||
decreasing=dict(
|
||||
line=dict(color='rgba(255, 0, 0, 0.4)', width=1),
|
||||
fillcolor='rgba(255, 0, 0, 0.2)'
|
||||
),
|
||||
showlegend=True,
|
||||
legendgroup='predictions',
|
||||
opacity=0.5
|
||||
),
|
||||
row=row, col=1
|
||||
)
|
||||
|
||||
# Add a marker at the predicted close price
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=[predicted_timestamp],
|
||||
y=[pred_close],
|
||||
mode='markers',
|
||||
marker=dict(
|
||||
size=8,
|
||||
color='rgba(255, 215, 0, 0.8)', # Gold color
|
||||
symbol='star',
|
||||
line=dict(color='white', width=1)
|
||||
),
|
||||
name='Prediction Target',
|
||||
showlegend=True,
|
||||
legendgroup='predictions',
|
||||
hovertemplate=f'Predicted Close: ${pred_close:.2f}<br>Time: {predicted_timestamp.strftime("%H:%M:%S")}<extra></extra>'
|
||||
),
|
||||
row=row, col=1
|
||||
)
|
||||
|
||||
logger.info(f"Added ghost candle for {chart_timeframe}: O={pred_open:.2f}, H={pred_high:.2f}, L={pred_low:.2f}, C={pred_close:.2f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding ghost candles to chart: {e}", exc_info=True)
|
||||
|
||||
def _add_transformer_predictions_to_chart(self, fig: go.Figure, symbol: str, df_main: pd.DataFrame, row: int = 1):
|
||||
"""Add Transformer price predictions as trend lines with confidence bands"""
|
||||
try:
|
||||
@@ -4071,9 +4188,20 @@ class CleanTradingDashboard:
|
||||
This makes a real-time prediction with the transformer model
|
||||
"""
|
||||
try:
|
||||
if not self.orchestrator or not hasattr(self.orchestrator, 'primary_transformer') or not self.orchestrator.primary_transformer:
|
||||
if not self.orchestrator:
|
||||
logger.debug("No orchestrator available for predictions")
|
||||
return None
|
||||
|
||||
if not hasattr(self.orchestrator, 'primary_transformer'):
|
||||
logger.debug("Orchestrator has no primary_transformer attribute")
|
||||
return None
|
||||
|
||||
if not self.orchestrator.primary_transformer:
|
||||
logger.debug("primary_transformer is None - model not loaded")
|
||||
return None
|
||||
|
||||
logger.debug(f"Making live transformer prediction for {symbol}...")
|
||||
|
||||
transformer = self.orchestrator.primary_transformer
|
||||
transformer.eval()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user