try to fix live RT updates on ANNOTATE

This commit is contained in:
Dobromir Popov
2025-11-22 00:55:37 +02:00
parent feb6cec275
commit a7def3b788
5 changed files with 328 additions and 57 deletions

View File

@@ -16,6 +16,8 @@ import time
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
from datetime import datetime, timezone from datetime import datetime, timezone
from collections import deque from collections import deque
import numpy as np
import pandas as pd
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -146,20 +148,50 @@ class LivePivotTrainer:
if williams is None: if williams is None:
return return
pivots = williams.calculate_pivots(candles) # Prepare data for Williams Market Structure
# Convert DataFrame to numpy array format
df = candles.copy()
ohlcv_array = df[['open', 'high', 'low', 'close', 'volume']].copy()
if not pivots or 'L2' not in pivots: # Handle timestamp conversion based on index type
if isinstance(df.index, pd.DatetimeIndex):
# Convert ns to ms
timestamps = df.index.astype(np.int64) // 10**6
else:
# Assume it's already timestamp or handle accordingly
timestamps = df.index
ohlcv_array.insert(0, 'timestamp', timestamps)
ohlcv_array = ohlcv_array.to_numpy()
# Calculate pivots
pivot_levels = williams.calculate_recursive_pivot_points(ohlcv_array)
if not pivot_levels or 2 not in pivot_levels:
return return
l2_pivots = pivots['L2'] # Get Level 2 pivots
l2_trend_level = pivot_levels[2]
l2_pivots_objs = l2_trend_level.pivot_points
if not l2_pivots_objs:
return
# Check for new L2 pivots (not in history) # Check for new L2 pivots (not in history)
new_pivots = [] new_pivots = []
for pivot in l2_pivots: for p in l2_pivots_objs:
pivot_id = f"{symbol}_{timeframe}_{pivot['timestamp']}_{pivot['type']}" # Convert pivot object to dict for compatibility
pivot_dict = {
'timestamp': p.timestamp, # Keep as datetime object for compatibility
'price': p.price,
'type': p.pivot_type,
'strength': p.strength
}
pivot_id = f"{symbol}_{timeframe}_{pivot_dict['timestamp']}_{pivot_dict['type']}"
if pivot_id not in self.trained_pivots: if pivot_id not in self.trained_pivots:
new_pivots.append(pivot) new_pivots.append(pivot_dict)
self.trained_pivots.append(pivot_id) self.trained_pivots.append(pivot_id)
if new_pivots: if new_pivots:

View File

@@ -2361,7 +2361,10 @@ class RealTrainingAdapter:
# Real-time inference support # Real-time inference support
def start_realtime_inference(self, model_name: str, symbol: str, data_provider, enable_live_training: bool = True) -> str: def start_realtime_inference(self, model_name: str, symbol: str, data_provider,
enable_live_training: bool = True,
train_every_candle: bool = False,
timeframe: str = '1m') -> str:
""" """
Start real-time inference using orchestrator's REAL prediction methods Start real-time inference using orchestrator's REAL prediction methods
@@ -2370,6 +2373,8 @@ class RealTrainingAdapter:
symbol: Trading symbol symbol: Trading symbol
data_provider: Data provider for market data data_provider: Data provider for market data
enable_live_training: If True, automatically train on L2 pivots enable_live_training: If True, automatically train on L2 pivots
train_every_candle: If True, train on every new candle (computationally expensive)
timeframe: Timeframe for candle-based training (default: 1m)
Returns: Returns:
inference_id: Unique ID for this inference session inference_id: Unique ID for this inference session
@@ -2391,10 +2396,15 @@ class RealTrainingAdapter:
'start_time': time.time(), 'start_time': time.time(),
'signals': [], 'signals': [],
'stop_flag': False, 'stop_flag': False,
'live_training_enabled': enable_live_training 'live_training_enabled': enable_live_training,
'train_every_candle': train_every_candle,
'timeframe': timeframe,
'data_provider': data_provider,
'last_candle_time': None
} }
logger.info(f"Starting REAL-TIME inference: {inference_id} with {model_name} on {symbol}") training_mode = "per-candle" if train_every_candle else ("pivot-based" if enable_live_training else "inference-only")
logger.info(f"Starting REAL-TIME inference: {inference_id} with {model_name} on {symbol} ({training_mode})")
# Start live pivot training if enabled # Start live pivot training if enabled
if enable_live_training: if enable_live_training:
@@ -2462,6 +2472,173 @@ class RealTrainingAdapter:
all_signals.sort(key=lambda x: x.get('timestamp', ''), reverse=True) all_signals.sort(key=lambda x: x.get('timestamp', ''), reverse=True)
return all_signals[:limit] return all_signals[:limit]
def _make_realtime_prediction(self, model_name: str, symbol: str, data_provider) -> Dict:
"""Make a prediction using the specified model"""
try:
if model_name == 'Transformer' and self.orchestrator:
trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
if trainer and trainer.model:
# Get recent market data
market_data = self._get_realtime_market_data(symbol, data_provider)
if not market_data:
return None
# Make prediction
import torch
with torch.no_grad():
trainer.model.eval()
outputs = trainer.model(**market_data)
# Extract action
action_probs = outputs.get('action_probs')
if action_probs is not None:
action_idx = torch.argmax(action_probs, dim=-1).item()
confidence = action_probs[0, action_idx].item()
actions = ['BUY', 'SELL', 'HOLD']
action = actions[action_idx] if action_idx < len(actions) else 'HOLD'
return {
'action': action,
'confidence': confidence
}
return None
except Exception as e:
logger.debug(f"Error making realtime prediction: {e}")
return None
def _get_realtime_market_data(self, symbol: str, data_provider) -> Dict:
"""Get current market data for prediction"""
try:
# Get recent candles for all timeframes
data = {}
for tf in ['1s', '1m', '1h', '1d']:
df = data_provider.get_historical_data(symbol, tf, limit=200)
if df is not None and not df.empty:
# Convert to tensor format (simplified)
import torch
import numpy as np
candles = df[['open', 'high', 'low', 'close', 'volume']].values
candles_tensor = torch.tensor(candles, dtype=torch.float32).unsqueeze(0)
data[f'price_data_{tf}'] = candles_tensor
# Add placeholder data for other inputs
import torch
data['tech_data'] = torch.zeros(1, 40, dtype=torch.float32)
data['market_data'] = torch.zeros(1, 30, dtype=torch.float32)
return data if data else None
except Exception as e:
logger.debug(f"Error getting realtime market data: {e}")
return None
def _train_on_new_candle(self, session: Dict, symbol: str, timeframe: str, data_provider):
"""Train model on new candle when it closes"""
try:
# Get latest candle
df = data_provider.get_historical_data(symbol, timeframe, limit=2)
if df is None or len(df) < 2:
return
# Check if we have a new candle
latest_candle_time = df.index[-1]
if session['last_candle_time'] == latest_candle_time:
return # Same candle, no training needed
session['last_candle_time'] = latest_candle_time
# Get the completed candle (second to last)
completed_candle = df.iloc[-2]
next_candle = df.iloc[-1]
# Calculate if the prediction would have been correct
price_change = (next_candle['close'] - completed_candle['close']) / completed_candle['close']
# Create training sample
training_sample = {
'symbol': symbol,
'timestamp': completed_candle.name,
'market_state': self._fetch_market_state_for_candle(symbol, completed_candle.name, data_provider),
'action': 'BUY' if price_change > 0.001 else ('SELL' if price_change < -0.001 else 'HOLD'),
'entry_price': float(completed_candle['close']),
'exit_price': float(next_candle['close']),
'profit_loss_pct': price_change * 100,
'direction': 'LONG' if price_change > 0 else 'SHORT'
}
# Train on this sample
model_name = session['model_name']
if model_name == 'Transformer':
self._train_transformer_on_sample(training_sample)
logger.info(f"Trained on candle: {symbol} {timeframe} @ {completed_candle.name} (change: {price_change:+.2%})")
except Exception as e:
logger.warning(f"Error training on new candle: {e}")
def _fetch_market_state_for_candle(self, symbol: str, timestamp, data_provider) -> Dict:
"""Fetch market state at a specific candle time"""
try:
# Simplified version - get recent data
market_state = {'timeframes': {}, 'secondary_timeframes': {}}
for tf in ['1s', '1m', '1h', '1d']:
df = data_provider.get_historical_data(symbol, tf, limit=200)
if df is not None and not df.empty:
market_state['timeframes'][tf] = {
'timestamps': df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
'open': df['open'].tolist(),
'high': df['high'].tolist(),
'low': df['low'].tolist(),
'close': df['close'].tolist(),
'volume': df['volume'].tolist()
}
return market_state
except Exception as e:
logger.warning(f"Error fetching market state for candle: {e}")
return {}
def _train_transformer_on_sample(self, training_sample: Dict):
"""Train transformer on a single sample"""
try:
if not self.orchestrator:
return
trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
if not trainer:
return
# Convert to batch format
batch = self._convert_annotation_to_transformer_batch(training_sample)
if not batch:
return
# Train on this batch
import torch
with torch.enable_grad():
trainer.model.train()
result = trainer.train_step(batch, accumulate_gradients=False)
if result:
logger.info(f"Per-candle training: Loss={result.get('total_loss', 0):.4f}")
except Exception as e:
logger.warning(f"Error training transformer on sample: {e}")
def _get_sleep_time_for_timeframe(self, timeframe: str) -> float:
"""Get appropriate sleep time based on timeframe"""
timeframe_seconds = {
'1s': 1,
'1m': 5, # Check every 5 seconds for new 1m candle
'5m': 30,
'15m': 60,
'1h': 300,
'4h': 600,
'1d': 3600
}
return timeframe_seconds.get(timeframe, 5)
def _store_training_prediction(self, batch: Dict, trainer, symbol: str): def _store_training_prediction(self, batch: Dict, trainer, symbol: str):
"""Store a prediction from training batch for visualization""" """Store a prediction from training batch for visualization"""
try: try:
@@ -2517,50 +2694,74 @@ class RealTrainingAdapter:
def _realtime_inference_loop(self, inference_id: str, model_name: str, symbol: str, data_provider): def _realtime_inference_loop(self, inference_id: str, model_name: str, symbol: str, data_provider):
""" """
Real-time inference loop using orchestrator's REAL prediction methods Real-time inference loop with optional per-candle training
This runs in a background thread and continuously makes predictions This runs in a background thread and continuously makes predictions.
using the actual model inference methods from the orchestrator. Can optionally train on every new candle.
""" """
session = self.inference_sessions[inference_id] session = self.inference_sessions[inference_id]
train_every_candle = session.get('train_every_candle', False)
timeframe = session.get('timeframe', '1m')
try: try:
while not session['stop_flag']: while not session['stop_flag']:
try: try:
# Use orchestrator's REAL prediction method # Get current market data
if hasattr(self.orchestrator, 'make_decision'): current_price = data_provider.get_current_price(symbol)
# Get real prediction from orchestrator if not current_price:
decision = self.orchestrator.make_decision(symbol) time.sleep(1)
continue
if decision:
# Store signal
signal = {
'timestamp': datetime.now().isoformat(),
'symbol': symbol,
'model': model_name,
'action': decision.action,
'confidence': decision.confidence,
'price': decision.price
}
session['signals'].append(signal)
# Keep only last 100 signals
if len(session['signals']) > 100:
session['signals'] = session['signals'][-100:]
logger.info(f"REAL Signal: {signal['action']} @ {signal['price']} (confidence: {signal['confidence']:.2f})")
# Sleep for 1 second before next inference # Make prediction using the model
time.sleep(1) prediction = self._make_realtime_prediction(model_name, symbol, data_provider)
if prediction:
# Store signal
signal = {
'timestamp': datetime.now().isoformat(),
'symbol': symbol,
'model': model_name,
'action': prediction['action'],
'confidence': prediction['confidence'],
'price': current_price
}
session['signals'].append(signal)
# Keep only last 100 signals
if len(session['signals']) > 100:
session['signals'] = session['signals'][-100:]
logger.info(f"Live Signal: {signal['action']} @ {signal['price']:.2f} (conf: {signal['confidence']:.2f})")
# Store prediction for visualization
if self.orchestrator and hasattr(self.orchestrator, 'store_transformer_prediction'):
self.orchestrator.store_transformer_prediction(symbol, {
'timestamp': datetime.now(),
'current_price': current_price,
'predicted_price': current_price * (1.01 if prediction['action'] == 'BUY' else 0.99),
'price_change': 1.0 if prediction['action'] == 'BUY' else -1.0,
'confidence': prediction['confidence'],
'action': prediction['action'],
'horizon_minutes': 10,
'source': 'live_inference'
})
# Per-candle training mode
if train_every_candle:
self._train_on_new_candle(session, symbol, timeframe, data_provider)
# Sleep based on timeframe
sleep_time = self._get_sleep_time_for_timeframe(timeframe)
time.sleep(sleep_time)
except Exception as e: except Exception as e:
logger.error(f"Error in REAL inference loop: {e}") logger.error(f"Error in inference loop: {e}")
time.sleep(5) time.sleep(5)
logger.info(f"REAL inference loop stopped: {inference_id}") logger.info(f"Inference loop stopped: {inference_id}")
except Exception as e: except Exception as e:
logger.error(f"Fatal error in REAL inference loop: {e}") logger.error(f"Fatal error in inference loop: {e}")
session['status'] = 'error' session['status'] = 'error'
session['error'] = str(e) session['error'] = str(e)

View File

@@ -224,6 +224,7 @@ class BacktestRunner:
if orchestrator and hasattr(orchestrator, 'store_transformer_prediction'): if orchestrator and hasattr(orchestrator, 'store_transformer_prediction'):
# Determine model type from model class name # Determine model type from model class name
model_type = model.__class__.__name__.lower() model_type = model.__class__.__name__.lower()
logger.debug(f"Backtest: Storing prediction for model type: {model_type}")
# Store in appropriate prediction collection # Store in appropriate prediction collection
if 'transformer' in model_type: if 'transformer' in model_type:
@@ -236,6 +237,7 @@ class BacktestRunner:
'action': prediction['action'], 'action': prediction['action'],
'horizon_minutes': 10 'horizon_minutes': 10
}) })
logger.debug(f"Backtest: Stored transformer prediction: {prediction['action']} @ {current_price}")
elif 'cnn' in model_type: elif 'cnn' in model_type:
if hasattr(orchestrator, 'recent_cnn_predictions'): if hasattr(orchestrator, 'recent_cnn_predictions'):
if symbol not in orchestrator.recent_cnn_predictions: if symbol not in orchestrator.recent_cnn_predictions:
@@ -2006,12 +2008,14 @@ class AnnotationDashboard:
@self.server.route('/api/realtime-inference/start', methods=['POST']) @self.server.route('/api/realtime-inference/start', methods=['POST'])
def start_realtime_inference(): def start_realtime_inference():
"""Start real-time inference mode with optional live training on L2 pivots""" """Start real-time inference mode with optional training modes"""
try: try:
data = request.get_json() data = request.get_json()
model_name = data.get('model_name') model_name = data.get('model_name')
symbol = data.get('symbol', 'ETH/USDT') symbol = data.get('symbol', 'ETH/USDT')
enable_live_training = data.get('enable_live_training', True) # Default: enabled timeframe = data.get('timeframe', '1m')
enable_live_training = data.get('enable_live_training', False) # Pivot-based training
train_every_candle = data.get('train_every_candle', False) # Per-candle training
if not self.training_adapter: if not self.training_adapter:
return jsonify({ return jsonify({
@@ -2022,18 +2026,23 @@ class AnnotationDashboard:
} }
}) })
# Start real-time inference with optional live training # Start real-time inference with optional training modes
inference_id = self.training_adapter.start_realtime_inference( inference_id = self.training_adapter.start_realtime_inference(
model_name=model_name, model_name=model_name,
symbol=symbol, symbol=symbol,
data_provider=self.data_provider, data_provider=self.data_provider,
enable_live_training=enable_live_training enable_live_training=enable_live_training,
train_every_candle=train_every_candle,
timeframe=timeframe
) )
training_mode = "per-candle" if train_every_candle else ("pivot-based" if enable_live_training else "inference-only")
return jsonify({ return jsonify({
'success': True, 'success': True,
'inference_id': inference_id, 'inference_id': inference_id,
'live_training_enabled': enable_live_training 'training_mode': training_mode,
'timeframe': timeframe
}) })
except Exception as e: except Exception as e:

View File

@@ -80,7 +80,15 @@
<button class="btn btn-success btn-sm w-100" id="start-inference-btn"> <button class="btn btn-success btn-sm w-100" id="start-inference-btn">
<i class="fas fa-play"></i> <i class="fas fa-play"></i>
Start Live Inference Start Live Inference (No Training)
</button>
<button class="btn btn-info btn-sm w-100 mt-1" id="start-inference-pivot-btn">
<i class="fas fa-chart-line"></i>
Live Inference + Pivot Training
</button>
<button class="btn btn-primary btn-sm w-100 mt-1" id="start-inference-candle-btn">
<i class="fas fa-graduation-cap"></i>
Live Inference + Per-Candle Training
</button> </button>
<button class="btn btn-danger btn-sm w-100 mt-1" id="stop-inference-btn" style="display: none;"> <button class="btn btn-danger btn-sm w-100 mt-1" id="stop-inference-btn" style="display: none;">
<i class="fas fa-stop"></i> <i class="fas fa-stop"></i>
@@ -511,7 +519,8 @@
document.getElementById('active-steps').textContent = steps; document.getElementById('active-steps').textContent = steps;
}); });
document.getElementById('start-inference-btn').addEventListener('click', function () { // Helper function to start inference with different modes
function startInference(enableLiveTraining, trainEveryCandle) {
const modelName = document.getElementById('model-select').value; const modelName = document.getElementById('model-select').value;
if (!modelName) { if (!modelName) {
@@ -519,9 +528,8 @@
return; return;
} }
// Get primary timeframe and prediction steps // Get timeframe
const primaryTimeframe = document.getElementById('primary-timeframe-select').value; const timeframe = document.getElementById('primary-timeframe-select').value;
const predictionSteps = parseInt(document.getElementById('prediction-steps-slider').value);
// Start real-time inference // Start real-time inference
fetch('/api/realtime-inference/start', { fetch('/api/realtime-inference/start', {
@@ -530,8 +538,9 @@
body: JSON.stringify({ body: JSON.stringify({
model_name: modelName, model_name: modelName,
symbol: appState.currentSymbol, symbol: appState.currentSymbol,
primary_timeframe: primaryTimeframe, timeframe: timeframe,
prediction_steps: predictionSteps enable_live_training: enableLiveTraining,
train_every_candle: trainEveryCandle
}) })
}) })
.then(response => response.json()) .then(response => response.json())
@@ -541,6 +550,8 @@
// Update UI // Update UI
document.getElementById('start-inference-btn').style.display = 'none'; document.getElementById('start-inference-btn').style.display = 'none';
document.getElementById('start-inference-pivot-btn').style.display = 'none';
document.getElementById('start-inference-candle-btn').style.display = 'none';
document.getElementById('stop-inference-btn').style.display = 'block'; document.getElementById('stop-inference-btn').style.display = 'block';
document.getElementById('inference-status').style.display = 'block'; document.getElementById('inference-status').style.display = 'block';
document.getElementById('inference-controls').style.display = 'block'; document.getElementById('inference-controls').style.display = 'block';
@@ -558,7 +569,10 @@
// Start polling for signals // Start polling for signals
startSignalPolling(); startSignalPolling();
showSuccess('Real-time inference started - Charts now updating live'); const trainingMode = data.training_mode || 'inference-only';
const modeText = trainingMode === 'per-candle' ? ' with per-candle training' :
(trainingMode === 'pivot-based' ? ' with pivot training' : '');
showSuccess('Real-time inference started' + modeText);
} else { } else {
showError('Failed to start inference: ' + data.error.message); showError('Failed to start inference: ' + data.error.message);
} }
@@ -566,6 +580,19 @@
.catch(error => { .catch(error => {
showError('Network error: ' + error.message); showError('Network error: ' + error.message);
}); });
}
// Button handlers for different inference modes
document.getElementById('start-inference-btn').addEventListener('click', function () {
startInference(false, false); // No training
});
document.getElementById('start-inference-pivot-btn').addEventListener('click', function () {
startInference(true, false); // Pivot-based training
});
document.getElementById('start-inference-candle-btn').addEventListener('click', function () {
startInference(false, true); // Per-candle training
}); });
document.getElementById('stop-inference-btn').addEventListener('click', function () { document.getElementById('stop-inference-btn').addEventListener('click', function () {
@@ -582,6 +609,8 @@
if (data.success) { if (data.success) {
// Update UI // Update UI
document.getElementById('start-inference-btn').style.display = 'block'; document.getElementById('start-inference-btn').style.display = 'block';
document.getElementById('start-inference-pivot-btn').style.display = 'block';
document.getElementById('start-inference-candle-btn').style.display = 'block';
document.getElementById('stop-inference-btn').style.display = 'none'; document.getElementById('stop-inference-btn').style.display = 'none';
document.getElementById('inference-status').style.display = 'none'; document.getElementById('inference-status').style.display = 'none';
document.getElementById('inference-controls').style.display = 'none'; document.getElementById('inference-controls').style.display = 'none';

View File

@@ -121,9 +121,9 @@ class WilliamsMarketStructure:
# Restore original pivot distance # Restore original pivot distance
self.min_pivot_distance = original_distance self.min_pivot_distance = original_distance
logger.info(f"Calculated {len(self.pivot_levels)} pivot levels from {len(ohlcv_data)} candles") logger.debug(f"Calculated {len(self.pivot_levels)} pivot levels from {len(ohlcv_data)} candles")
for level_num, level_data in self.pivot_levels.items(): for level_num, level_data in self.pivot_levels.items():
logger.info(f" L{level_num}: {len(level_data.pivot_points)} pivots") logger.debug(f" L{level_num}: {len(level_data.pivot_points)} pivots")
return self.pivot_levels return self.pivot_levels
@@ -185,7 +185,7 @@ class WilliamsMarketStructure:
) )
pivots.append(pivot) pivots.append(pivot)
logger.info(f"Level 1: Found {len(pivots)} pivot points from {len(df)} candles") logger.debug(f"Level 1: Found {len(pivots)} pivot points from {len(df)} candles")
return pivots return pivots
except Exception as e: except Exception as e:
@@ -272,7 +272,7 @@ class WilliamsMarketStructure:
# Sort pivots by timestamp # Sort pivots by timestamp
pivots.sort(key=lambda x: x.timestamp) pivots.sort(key=lambda x: x.timestamp)
logger.info(f"Level {level}: Found {len(pivots)} pivot points (from {len(highs)} highs, {len(lows)} lows)") logger.debug(f"Level {level}: Found {len(pivots)} pivot points (from {len(highs)} highs, {len(lows)} lows)")
return pivots return pivots
except Exception as e: except Exception as e: