wip
@@ -16,7 +16,7 @@ sys.path.insert(0, str(parent_dir))
 from flask import Flask, render_template, request, jsonify, send_file
 from dash import Dash, html
 import logging
-from datetime import datetime, timezone
+from datetime import datetime, timezone, timedelta
 from typing import Optional, Dict, List, Any
 import json
 import pandas as pd
@@ -2254,11 +2254,19 @@ class AnnotationDashboard:
             if cnn_preds:
                 predictions['cnn'] = cnn_preds[-1]
 
-            # Transformer predictions
+            # Transformer predictions with next_candles for ghost candles
+            # First check if there are stored predictions from the inference loop
+            if hasattr(self.orchestrator, 'recent_transformer_predictions') and symbol in self.orchestrator.recent_transformer_predictions:
+                transformer_preds = list(self.orchestrator.recent_transformer_predictions[symbol])
+                if transformer_preds:
+                    # Use the most recent stored prediction (from inference loop)
+                    predictions['transformer'] = transformer_preds[-1]
+                    logger.debug(f"Using stored prediction: {list(transformer_preds[-1].keys())}")
+            else:
+                # Fallback: generate new prediction if no stored predictions
+                transformer_pred = self._get_live_transformer_prediction(symbol)
+                if transformer_pred:
+                    predictions['transformer'] = transformer_pred
 
             if predictions:
                 response['prediction'] = predictions
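Note: the stored-prediction branch above assumes the orchestrator exposes a per-symbol buffer named recent_transformer_predictions that its inference loop fills via store_transformer_prediction (both names appear in this diff). A minimal sketch of what that assumed structure could look like, with the class name and buffer size chosen purely for illustration:

from collections import defaultdict, deque

class _PredictionBufferSketch:
    """Illustrative only: per-symbol ring buffer of recent transformer predictions."""

    def __init__(self, maxlen: int = 100):
        # symbol -> deque of prediction dicts, newest last, so [-1] is the latest
        self.recent_transformer_predictions = defaultdict(lambda: deque(maxlen=maxlen))

    def store_transformer_prediction(self, symbol: str, prediction: dict) -> None:
        # Called by the inference loop (and by the dashboard fallback path above)
        self.recent_transformer_predictions[symbol].append(prediction)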
@@ -2497,6 +2505,110 @@ class AnnotationDashboard:
         self._live_update_thread = threading.Thread(target=live_update_worker, daemon=True)
         self._live_update_thread.start()
 
+    def _get_live_transformer_prediction(self, symbol: str = 'ETH/USDT'):
+        """
+        Generate live transformer prediction with next_candles for ghost candle display
+        """
+        try:
+            if not self.orchestrator:
+                logger.debug("No orchestrator - cannot generate predictions")
+                return None
+
+            if not hasattr(self.orchestrator, 'primary_transformer'):
+                logger.debug("Orchestrator has no primary_transformer - enable training first")
+                return None
+
+            transformer = self.orchestrator.primary_transformer
+            if not transformer:
+                logger.debug("primary_transformer is None - model not loaded yet")
+                return None
+
+            transformer.eval()
+
+            # Get recent market data
+            price_data_1s = self.data_provider.get_ohlcv(symbol, '1s', limit=200) if self.data_provider else None
+            price_data_1m = self.data_provider.get_ohlcv(symbol, '1m', limit=150) if self.data_provider else None
+            price_data_1h = self.data_provider.get_ohlcv(symbol, '1h', limit=24) if self.data_provider else None
+            price_data_1d = self.data_provider.get_ohlcv(symbol, '1d', limit=14) if self.data_provider else None
+            btc_data_1m = self.data_provider.get_ohlcv('BTC/USDT', '1m', limit=150) if self.data_provider else None
+
+            if not price_data_1m or len(price_data_1m) < 10:
+                return None
+
+            import torch
+            import numpy as np
+            device = next(transformer.parameters()).device
+
+            def ohlcv_to_tensor(data, limit=None):
+                if not data:
+                    return None
+                data = data[-limit:] if limit and len(data) > limit else data
+                arr = np.array([[d['open'], d['high'], d['low'], d['close'], d['volume']] for d in data], dtype=np.float32)
+                return torch.from_numpy(arr).unsqueeze(0).to(device)
+
+            inputs = {
+                'price_data_1s': ohlcv_to_tensor(price_data_1s, 200),
+                'price_data_1m': ohlcv_to_tensor(price_data_1m, 150),
+                'price_data_1h': ohlcv_to_tensor(price_data_1h, 24),
+                'price_data_1d': ohlcv_to_tensor(price_data_1d, 14),
+                'btc_data_1m': ohlcv_to_tensor(btc_data_1m, 150)
+            }
+
+            # Forward pass
+            with torch.no_grad():
+                outputs = transformer(**inputs)
+
+            # Extract next_candles
+            next_candles = outputs.get('next_candles', {})
+            if not next_candles:
+                return None
+
+            # Convert to JSON-serializable format
+            predicted_candle = {}
+            for tf, candle_tensor in next_candles.items():
+                if candle_tensor is not None:
+                    candle_values = candle_tensor.squeeze(0).cpu().numpy().tolist()
+                    predicted_candle[tf] = candle_values
+
+            current_price = price_data_1m[-1]['close']
+            predicted_1m_close = predicted_candle.get('1m', [0, 0, 0, current_price, 0])[3]
+            price_change = (predicted_1m_close - current_price) / current_price
+
+            if price_change > 0.001:
+                action = 'BUY'
+            elif price_change < -0.001:
+                action = 'SELL'
+            else:
+                action = 'HOLD'
+
+            confidence = 0.7
+            if 'confidence' in outputs:
+                conf_tensor = outputs['confidence']
+                confidence = float(conf_tensor.squeeze(0).cpu().numpy()[0])
+
+            prediction = {
+                'timestamp': datetime.now().isoformat(),
+                'symbol': symbol,
+                'action': action,
+                'confidence': confidence,
+                'predicted_price': predicted_1m_close,
+                'current_price': current_price,
+                'price_change': price_change,
+                'predicted_candle': predicted_candle,  # This is what frontend needs!
+                'type': 'transformer_prediction'
+            }
+
+            # Store for tracking
+            self.orchestrator.store_transformer_prediction(symbol, prediction)
+
+            logger.debug(f"Generated transformer prediction with {len(predicted_candle)} timeframes for ghost candles")
+
+            return prediction
+
+        except Exception as e:
+            logger.error(f"Error generating live transformer prediction: {e}", exc_info=True)
+            return None
+
     def _train_on_validated_prediction(self, timeframe: str, timestamp: str, predicted: list,
                                        actual: list, errors: dict, direction_correct: bool, accuracy: float):
         """
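Note: for reference, a minimal sketch of how a caller might unpack the prediction dict returned by _get_live_transformer_prediction above to draw a ghost candle. The [open, high, low, close, volume] ordering matches ohlcv_to_tensor in this diff; the helper name itself is hypothetical:

def extract_ghost_candle(prediction: dict, timeframe: str = '1m'):
    """Hypothetical helper: pull one predicted OHLCV row out of predicted_candle."""
    candle = (prediction or {}).get('predicted_candle', {}).get(timeframe)
    if not candle or len(candle) < 5:
        return None
    open_, high, low, close, volume = candle[:5]
    return {
        'open': open_, 'high': high, 'low': low,
        'close': close, 'volume': volume,
        'timestamp': prediction.get('timestamp'),
    }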