diff --git a/ANNOTATE/core/live_pivot_trainer.py b/ANNOTATE/core/live_pivot_trainer.py
index e9e70ad..79ce19a 100644
--- a/ANNOTATE/core/live_pivot_trainer.py
+++ b/ANNOTATE/core/live_pivot_trainer.py
@@ -16,6 +16,8 @@ import time
from typing import Dict, List, Optional, Tuple
from datetime import datetime, timezone
from collections import deque
+import numpy as np
+import pandas as pd
logger = logging.getLogger(__name__)
@@ -146,20 +148,50 @@ class LivePivotTrainer:
if williams is None:
return
- pivots = williams.calculate_pivots(candles)
+ # Prepare data for Williams Market Structure
+ # Convert DataFrame to numpy array format
+ df = candles.copy()
+ ohlcv_array = df[['open', 'high', 'low', 'close', 'volume']].copy()
- if not pivots or 'L2' not in pivots:
+ # Handle timestamp conversion based on index type
+ if isinstance(df.index, pd.DatetimeIndex):
+ # Convert ns to ms
+ timestamps = df.index.astype(np.int64) // 10**6
+ else:
+ # Assume it's already timestamp or handle accordingly
+ timestamps = df.index
+
+ ohlcv_array.insert(0, 'timestamp', timestamps)
+ ohlcv_array = ohlcv_array.to_numpy()
+
+ # Calculate pivots
+ pivot_levels = williams.calculate_recursive_pivot_points(ohlcv_array)
+
+ if not pivot_levels or 2 not in pivot_levels:
return
- l2_pivots = pivots['L2']
+ # Get Level 2 pivots
+ l2_trend_level = pivot_levels[2]
+ l2_pivots_objs = l2_trend_level.pivot_points
+ if not l2_pivots_objs:
+ return
+
# Check for new L2 pivots (not in history)
new_pivots = []
- for pivot in l2_pivots:
- pivot_id = f"{symbol}_{timeframe}_{pivot['timestamp']}_{pivot['type']}"
+ for p in l2_pivots_objs:
+ # Convert pivot object to dict for compatibility
+ pivot_dict = {
+ 'timestamp': p.timestamp, # Keep as datetime object for compatibility
+ 'price': p.price,
+ 'type': p.pivot_type,
+ 'strength': p.strength
+ }
+
+ pivot_id = f"{symbol}_{timeframe}_{pivot_dict['timestamp']}_{pivot_dict['type']}"
if pivot_id not in self.trained_pivots:
- new_pivots.append(pivot)
+ new_pivots.append(pivot_dict)
self.trained_pivots.append(pivot_id)
if new_pivots:
diff --git a/ANNOTATE/core/real_training_adapter.py b/ANNOTATE/core/real_training_adapter.py
index b24d66d..971c505 100644
--- a/ANNOTATE/core/real_training_adapter.py
+++ b/ANNOTATE/core/real_training_adapter.py
@@ -2361,7 +2361,10 @@ class RealTrainingAdapter:
# Real-time inference support
- def start_realtime_inference(self, model_name: str, symbol: str, data_provider, enable_live_training: bool = True) -> str:
+ def start_realtime_inference(self, model_name: str, symbol: str, data_provider,
+ enable_live_training: bool = True,
+ train_every_candle: bool = False,
+ timeframe: str = '1m') -> str:
"""
Start real-time inference using orchestrator's REAL prediction methods
@@ -2370,6 +2373,8 @@ class RealTrainingAdapter:
symbol: Trading symbol
data_provider: Data provider for market data
enable_live_training: If True, automatically train on L2 pivots
+ train_every_candle: If True, train on every new candle (computationally expensive)
+ timeframe: Timeframe for candle-based training (default: 1m)
Returns:
inference_id: Unique ID for this inference session
@@ -2391,10 +2396,15 @@ class RealTrainingAdapter:
'start_time': time.time(),
'signals': [],
'stop_flag': False,
- 'live_training_enabled': enable_live_training
+ 'live_training_enabled': enable_live_training,
+ 'train_every_candle': train_every_candle,
+ 'timeframe': timeframe,
+ 'data_provider': data_provider,
+ 'last_candle_time': None
}
- logger.info(f"Starting REAL-TIME inference: {inference_id} with {model_name} on {symbol}")
+ training_mode = "per-candle" if train_every_candle else ("pivot-based" if enable_live_training else "inference-only")
+ logger.info(f"Starting REAL-TIME inference: {inference_id} with {model_name} on {symbol} ({training_mode})")
# Start live pivot training if enabled
if enable_live_training:
@@ -2462,6 +2472,173 @@ class RealTrainingAdapter:
all_signals.sort(key=lambda x: x.get('timestamp', ''), reverse=True)
return all_signals[:limit]
+ def _make_realtime_prediction(self, model_name: str, symbol: str, data_provider) -> Dict:
+ """Make a prediction using the specified model"""
+ try:
+ if model_name == 'Transformer' and self.orchestrator:
+ trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
+ if trainer and trainer.model:
+ # Get recent market data
+ market_data = self._get_realtime_market_data(symbol, data_provider)
+ if not market_data:
+ return None
+
+ # Make prediction
+ import torch
+ with torch.no_grad():
+ trainer.model.eval()
+ outputs = trainer.model(**market_data)
+
+ # Extract action
+ action_probs = outputs.get('action_probs')
+ if action_probs is not None:
+ action_idx = torch.argmax(action_probs, dim=-1).item()
+ confidence = action_probs[0, action_idx].item()
+
+ actions = ['BUY', 'SELL', 'HOLD']
+ action = actions[action_idx] if action_idx < len(actions) else 'HOLD'
+
+ return {
+ 'action': action,
+ 'confidence': confidence
+ }
+
+ return None
+ except Exception as e:
+ logger.debug(f"Error making realtime prediction: {e}")
+ return None
+
+ def _get_realtime_market_data(self, symbol: str, data_provider) -> Dict:
+ """Get current market data for prediction"""
+ try:
+ # Get recent candles for all timeframes
+ data = {}
+ for tf in ['1s', '1m', '1h', '1d']:
+ df = data_provider.get_historical_data(symbol, tf, limit=200)
+ if df is not None and not df.empty:
+ # Convert to tensor format (simplified)
+ import torch
+ import numpy as np
+
+ candles = df[['open', 'high', 'low', 'close', 'volume']].values
+ candles_tensor = torch.tensor(candles, dtype=torch.float32).unsqueeze(0)
+ data[f'price_data_{tf}'] = candles_tensor
+
+ # Add placeholder data for other inputs
+ import torch
+ data['tech_data'] = torch.zeros(1, 40, dtype=torch.float32)
+ data['market_data'] = torch.zeros(1, 30, dtype=torch.float32)
+
+ return data if data else None
+ except Exception as e:
+ logger.debug(f"Error getting realtime market data: {e}")
+ return None
+
+ def _train_on_new_candle(self, session: Dict, symbol: str, timeframe: str, data_provider):
+ """Train model on new candle when it closes"""
+ try:
+ # Get latest candle
+ df = data_provider.get_historical_data(symbol, timeframe, limit=2)
+ if df is None or len(df) < 2:
+ return
+
+ # Check if we have a new candle
+ latest_candle_time = df.index[-1]
+ if session['last_candle_time'] == latest_candle_time:
+ return # Same candle, no training needed
+
+ session['last_candle_time'] = latest_candle_time
+
+ # Get the completed candle (second to last)
+ completed_candle = df.iloc[-2]
+ next_candle = df.iloc[-1]
+
+ # Calculate if the prediction would have been correct
+ price_change = (next_candle['close'] - completed_candle['close']) / completed_candle['close']
+
+ # Create training sample
+ training_sample = {
+ 'symbol': symbol,
+ 'timestamp': completed_candle.name,
+ 'market_state': self._fetch_market_state_for_candle(symbol, completed_candle.name, data_provider),
+ 'action': 'BUY' if price_change > 0.001 else ('SELL' if price_change < -0.001 else 'HOLD'),
+ 'entry_price': float(completed_candle['close']),
+ 'exit_price': float(next_candle['close']),
+ 'profit_loss_pct': price_change * 100,
+ 'direction': 'LONG' if price_change > 0 else 'SHORT'
+ }
+
+ # Train on this sample
+ model_name = session['model_name']
+ if model_name == 'Transformer':
+ self._train_transformer_on_sample(training_sample)
+ logger.info(f"Trained on candle: {symbol} {timeframe} @ {completed_candle.name} (change: {price_change:+.2%})")
+
+ except Exception as e:
+ logger.warning(f"Error training on new candle: {e}")
+
+ def _fetch_market_state_for_candle(self, symbol: str, timestamp, data_provider) -> Dict:
+ """Fetch market state at a specific candle time"""
+ try:
+ # Simplified version - get recent data
+ market_state = {'timeframes': {}, 'secondary_timeframes': {}}
+
+ for tf in ['1s', '1m', '1h', '1d']:
+ df = data_provider.get_historical_data(symbol, tf, limit=200)
+ if df is not None and not df.empty:
+ market_state['timeframes'][tf] = {
+ 'timestamps': df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
+ 'open': df['open'].tolist(),
+ 'high': df['high'].tolist(),
+ 'low': df['low'].tolist(),
+ 'close': df['close'].tolist(),
+ 'volume': df['volume'].tolist()
+ }
+
+ return market_state
+ except Exception as e:
+ logger.warning(f"Error fetching market state for candle: {e}")
+ return {}
+
+ def _train_transformer_on_sample(self, training_sample: Dict):
+ """Train transformer on a single sample"""
+ try:
+ if not self.orchestrator:
+ return
+
+ trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
+ if not trainer:
+ return
+
+ # Convert to batch format
+ batch = self._convert_annotation_to_transformer_batch(training_sample)
+ if not batch:
+ return
+
+ # Train on this batch
+ import torch
+ with torch.enable_grad():
+ trainer.model.train()
+ result = trainer.train_step(batch, accumulate_gradients=False)
+ if result:
+ logger.info(f"Per-candle training: Loss={result.get('total_loss', 0):.4f}")
+
+ except Exception as e:
+ logger.warning(f"Error training transformer on sample: {e}")
+
+ def _get_sleep_time_for_timeframe(self, timeframe: str) -> float:
+ """Get appropriate sleep time based on timeframe"""
+ timeframe_seconds = {
+ '1s': 1,
+ '1m': 5, # Check every 5 seconds for new 1m candle
+ '5m': 30,
+ '15m': 60,
+ '1h': 300,
+ '4h': 600,
+ '1d': 3600
+ }
+ return timeframe_seconds.get(timeframe, 5)
+
def _store_training_prediction(self, batch: Dict, trainer, symbol: str):
"""Store a prediction from training batch for visualization"""
try:
@@ -2517,50 +2694,74 @@ class RealTrainingAdapter:
def _realtime_inference_loop(self, inference_id: str, model_name: str, symbol: str, data_provider):
"""
- Real-time inference loop using orchestrator's REAL prediction methods
+ Real-time inference loop with optional per-candle training
- This runs in a background thread and continuously makes predictions
- using the actual model inference methods from the orchestrator.
+ This runs in a background thread and continuously makes predictions.
+ Can optionally train on every new candle.
"""
session = self.inference_sessions[inference_id]
+ train_every_candle = session.get('train_every_candle', False)
+ timeframe = session.get('timeframe', '1m')
try:
while not session['stop_flag']:
try:
- # Use orchestrator's REAL prediction method
- if hasattr(self.orchestrator, 'make_decision'):
- # Get real prediction from orchestrator
- decision = self.orchestrator.make_decision(symbol)
-
- if decision:
- # Store signal
- signal = {
- 'timestamp': datetime.now().isoformat(),
- 'symbol': symbol,
- 'model': model_name,
- 'action': decision.action,
- 'confidence': decision.confidence,
- 'price': decision.price
- }
-
- session['signals'].append(signal)
-
- # Keep only last 100 signals
- if len(session['signals']) > 100:
- session['signals'] = session['signals'][-100:]
-
- logger.info(f"REAL Signal: {signal['action']} @ {signal['price']} (confidence: {signal['confidence']:.2f})")
+ # Get current market data
+ current_price = data_provider.get_current_price(symbol)
+ if not current_price:
+ time.sleep(1)
+ continue
- # Sleep for 1 second before next inference
- time.sleep(1)
+ # Make prediction using the model
+ prediction = self._make_realtime_prediction(model_name, symbol, data_provider)
+
+ if prediction:
+ # Store signal
+ signal = {
+ 'timestamp': datetime.now().isoformat(),
+ 'symbol': symbol,
+ 'model': model_name,
+ 'action': prediction['action'],
+ 'confidence': prediction['confidence'],
+ 'price': current_price
+ }
+
+ session['signals'].append(signal)
+
+ # Keep only last 100 signals
+ if len(session['signals']) > 100:
+ session['signals'] = session['signals'][-100:]
+
+ logger.info(f"Live Signal: {signal['action']} @ {signal['price']:.2f} (conf: {signal['confidence']:.2f})")
+
+ # Store prediction for visualization
+ if self.orchestrator and hasattr(self.orchestrator, 'store_transformer_prediction'):
+ self.orchestrator.store_transformer_prediction(symbol, {
+ 'timestamp': datetime.now(),
+ 'current_price': current_price,
+ 'predicted_price': current_price * (1.01 if prediction['action'] == 'BUY' else 0.99),
+ 'price_change': 1.0 if prediction['action'] == 'BUY' else -1.0,
+ 'confidence': prediction['confidence'],
+ 'action': prediction['action'],
+ 'horizon_minutes': 10,
+ 'source': 'live_inference'
+ })
+
+ # Per-candle training mode
+ if train_every_candle:
+ self._train_on_new_candle(session, symbol, timeframe, data_provider)
+
+ # Sleep based on timeframe
+ sleep_time = self._get_sleep_time_for_timeframe(timeframe)
+ time.sleep(sleep_time)
except Exception as e:
- logger.error(f"Error in REAL inference loop: {e}")
+ logger.error(f"Error in inference loop: {e}")
time.sleep(5)
- logger.info(f"REAL inference loop stopped: {inference_id}")
+ logger.info(f"Inference loop stopped: {inference_id}")
except Exception as e:
- logger.error(f"Fatal error in REAL inference loop: {e}")
+ logger.error(f"Fatal error in inference loop: {e}")
session['status'] = 'error'
session['error'] = str(e)
diff --git a/ANNOTATE/web/app.py b/ANNOTATE/web/app.py
index 6745c29..67b7ed3 100644
--- a/ANNOTATE/web/app.py
+++ b/ANNOTATE/web/app.py
@@ -224,6 +224,7 @@ class BacktestRunner:
if orchestrator and hasattr(orchestrator, 'store_transformer_prediction'):
# Determine model type from model class name
model_type = model.__class__.__name__.lower()
+ logger.debug(f"Backtest: Storing prediction for model type: {model_type}")
# Store in appropriate prediction collection
if 'transformer' in model_type:
@@ -236,6 +237,7 @@ class BacktestRunner:
'action': prediction['action'],
'horizon_minutes': 10
})
+ logger.debug(f"Backtest: Stored transformer prediction: {prediction['action']} @ {current_price}")
elif 'cnn' in model_type:
if hasattr(orchestrator, 'recent_cnn_predictions'):
if symbol not in orchestrator.recent_cnn_predictions:
@@ -2006,12 +2008,14 @@ class AnnotationDashboard:
@self.server.route('/api/realtime-inference/start', methods=['POST'])
def start_realtime_inference():
- """Start real-time inference mode with optional live training on L2 pivots"""
+ """Start real-time inference mode with optional training modes"""
try:
data = request.get_json()
model_name = data.get('model_name')
symbol = data.get('symbol', 'ETH/USDT')
- enable_live_training = data.get('enable_live_training', True) # Default: enabled
+ timeframe = data.get('timeframe', '1m')
+ enable_live_training = data.get('enable_live_training', False) # Pivot-based training
+ train_every_candle = data.get('train_every_candle', False) # Per-candle training
if not self.training_adapter:
return jsonify({
@@ -2022,18 +2026,23 @@ class AnnotationDashboard:
}
})
- # Start real-time inference with optional live training
+ # Start real-time inference with optional training modes
inference_id = self.training_adapter.start_realtime_inference(
model_name=model_name,
symbol=symbol,
data_provider=self.data_provider,
- enable_live_training=enable_live_training
+ enable_live_training=enable_live_training,
+ train_every_candle=train_every_candle,
+ timeframe=timeframe
)
+ training_mode = "per-candle" if train_every_candle else ("pivot-based" if enable_live_training else "inference-only")
+
return jsonify({
'success': True,
'inference_id': inference_id,
- 'live_training_enabled': enable_live_training
+ 'training_mode': training_mode,
+ 'timeframe': timeframe
})
except Exception as e:
diff --git a/ANNOTATE/web/templates/components/training_panel.html b/ANNOTATE/web/templates/components/training_panel.html
index 6dc09cb..6eb0d61 100644
--- a/ANNOTATE/web/templates/components/training_panel.html
+++ b/ANNOTATE/web/templates/components/training_panel.html
@@ -80,7 +80,15 @@
+
+