feat: Add real-time prediction loop - the MISSING piece
Create RealtimePredictionLoop to continuously call model.predict() on incoming data Features: - Detects new 1s/1m candles and triggers predictions - Detects pivot points for prediction triggers - Calls ALL models (CNN, DQN, COB-RL) continuously - Combines predictions into trading signals - NO SYNTHETIC DATA - only real market data This is why model.predict() was only called once at startup - there was NO continuous prediction loop running on new market data!
This commit is contained in:
392
core/realtime_prediction_loop.py
Normal file
392
core/realtime_prediction_loop.py
Normal file
@@ -0,0 +1,392 @@
|
||||
"""
|
||||
Real-Time Prediction Loop
|
||||
|
||||
CRITICAL: This is the MISSING PIECE - continuous model inference on incoming market data
|
||||
|
||||
This module monitors market data and triggers model predictions on:
|
||||
- New 1s candles
|
||||
- New 1m candles
|
||||
- Pivot points detected
|
||||
- Significant price movements
|
||||
|
||||
NO SYNTHETIC DATA - Only real market data triggers predictions
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RealtimePredictionLoop:
|
||||
"""
|
||||
Continuously monitors market data and triggers model predictions
|
||||
"""
|
||||
|
||||
def __init__(self, orchestrator, data_provider):
|
||||
self.orchestrator = orchestrator
|
||||
self.data_provider = data_provider
|
||||
|
||||
self.running = False
|
||||
self.last_prediction_time = {}
|
||||
self.prediction_interval_seconds = {
|
||||
'1s': 1, # Predict every second
|
||||
'1m': 60, # Predict every minute
|
||||
'pivot': 0 # Predict immediately on pivot
|
||||
}
|
||||
|
||||
# Track last candle to detect new ones
|
||||
self.last_candle_close_time = {}
|
||||
|
||||
logger.info("Real-time Prediction Loop initialized")
|
||||
|
||||
async def start(self):
|
||||
"""Start the continuous prediction loop"""
|
||||
self.running = True
|
||||
logger.info("🔄 Starting Real-Time Prediction Loop")
|
||||
|
||||
# Start prediction tasks for each symbol
|
||||
symbols = self.orchestrator.config.get('symbols', ['ETH/USDT', 'BTC/USDT'])
|
||||
|
||||
tasks = []
|
||||
for symbol in symbols:
|
||||
tasks.append(asyncio.create_task(self._prediction_loop_for_symbol(symbol)))
|
||||
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
def stop(self):
|
||||
"""Stop the prediction loop"""
|
||||
self.running = False
|
||||
logger.info("Real-Time Prediction Loop stopped")
|
||||
|
||||
async def _prediction_loop_for_symbol(self, symbol: str):
|
||||
"""Run prediction loop for a specific symbol"""
|
||||
logger.info(f"🔄 Prediction loop started for {symbol}")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
# 1. Check for new candle (1s or 1m)
|
||||
new_candle_detected, timeframe = await self._detect_new_candle(symbol)
|
||||
|
||||
if new_candle_detected:
|
||||
logger.info(f"📊 New {timeframe} candle detected for {symbol} - running predictions")
|
||||
await self._run_all_model_predictions(symbol, trigger=f"new_{timeframe}_candle")
|
||||
|
||||
# 2. Check for pivot point
|
||||
pivot_detected = await self._detect_pivot_point(symbol)
|
||||
|
||||
if pivot_detected:
|
||||
logger.info(f"📍 Pivot point detected for {symbol} - running predictions")
|
||||
await self._run_all_model_predictions(symbol, trigger="pivot_point")
|
||||
|
||||
# 3. Periodic prediction (every N seconds based on timeframe)
|
||||
if self._should_run_periodic_prediction(symbol):
|
||||
logger.debug(f"⏰ Periodic prediction for {symbol}")
|
||||
await self._run_all_model_predictions(symbol, trigger="periodic")
|
||||
|
||||
# Sleep briefly to avoid CPU overuse
|
||||
await asyncio.sleep(0.1) # Check every 100ms
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in prediction loop for {symbol}: {e}")
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def _detect_new_candle(self, symbol: str) -> tuple:
|
||||
"""Detect if a new candle has closed"""
|
||||
try:
|
||||
# Get latest candles
|
||||
candles_1s = await self.data_provider.get_latest_candles(symbol, '1s', limit=2)
|
||||
candles_1m = await self.data_provider.get_latest_candles(symbol, '1m', limit=2)
|
||||
|
||||
# Check 1s candle
|
||||
if candles_1s and len(candles_1s) >= 2:
|
||||
latest_1s_time = candles_1s[-1].get('timestamp') or candles_1s[-1].get('time')
|
||||
|
||||
if symbol not in self.last_candle_close_time:
|
||||
self.last_candle_close_time[symbol] = {}
|
||||
|
||||
last_1s = self.last_candle_close_time[symbol].get('1s')
|
||||
|
||||
if latest_1s_time and latest_1s_time != last_1s:
|
||||
self.last_candle_close_time[symbol]['1s'] = latest_1s_time
|
||||
return True, '1s'
|
||||
|
||||
# Check 1m candle
|
||||
if candles_1m and len(candles_1m) >= 2:
|
||||
latest_1m_time = candles_1m[-1].get('timestamp') or candles_1m[-1].get('time')
|
||||
|
||||
last_1m = self.last_candle_close_time[symbol].get('1m')
|
||||
|
||||
if latest_1m_time and latest_1m_time != last_1m:
|
||||
self.last_candle_close_time[symbol]['1m'] = latest_1m_time
|
||||
return True, '1m'
|
||||
|
||||
return False, None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error detecting new candle for {symbol}: {e}")
|
||||
return False, None
|
||||
|
||||
async def _detect_pivot_point(self, symbol: str) -> bool:
|
||||
"""Detect if a pivot point has formed"""
|
||||
try:
|
||||
# Use Williams Market Structure or simple pivot detection
|
||||
recent_candles = await self.data_provider.get_latest_candles(symbol, '1m', limit=5)
|
||||
|
||||
if not recent_candles or len(recent_candles) < 5:
|
||||
return False
|
||||
|
||||
# Simple pivot: middle candle is local high or low
|
||||
highs = [c.get('high', 0) for c in recent_candles]
|
||||
lows = [c.get('low', 0) for c in recent_candles]
|
||||
|
||||
# Pivot high: middle is highest
|
||||
if highs[2] == max(highs):
|
||||
logger.debug(f"Pivot HIGH detected for {symbol}")
|
||||
return True
|
||||
|
||||
# Pivot low: middle is lowest
|
||||
if lows[2] == min(lows):
|
||||
logger.debug(f"Pivot LOW detected for {symbol}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error detecting pivot for {symbol}: {e}")
|
||||
return False
|
||||
|
||||
def _should_run_periodic_prediction(self, symbol: str) -> bool:
|
||||
"""Check if enough time has passed for periodic prediction"""
|
||||
current_time = datetime.now()
|
||||
|
||||
last_time = self.last_prediction_time.get(symbol)
|
||||
if not last_time:
|
||||
self.last_prediction_time[symbol] = current_time
|
||||
return True
|
||||
|
||||
# Run periodic prediction every 5 seconds
|
||||
if (current_time - last_time).total_seconds() >= 5:
|
||||
self.last_prediction_time[symbol] = current_time
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _run_all_model_predictions(self, symbol: str, trigger: str = "unknown"):
|
||||
"""
|
||||
CRITICAL: Run predictions from ALL models on current market data
|
||||
This is where model.predict() gets called!
|
||||
"""
|
||||
try:
|
||||
# Get current market data
|
||||
market_data = await self._extract_market_features(symbol)
|
||||
|
||||
if not market_data:
|
||||
logger.warning(f"No market data available for {symbol}")
|
||||
return
|
||||
|
||||
predictions = {}
|
||||
|
||||
# 1. CNN Model Prediction
|
||||
if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
|
||||
try:
|
||||
cnn_features = market_data.get('cnn_features')
|
||||
if cnn_features is not None:
|
||||
# ✅ THIS IS WHERE model.predict() SHOULD BE CALLED!
|
||||
cnn_prediction = self.orchestrator.cnn_model.predict(cnn_features)
|
||||
predictions['cnn'] = cnn_prediction
|
||||
logger.info(f"✅ CNN prediction for {symbol}: {cnn_prediction}")
|
||||
except Exception as e:
|
||||
logger.error(f"CNN prediction error for {symbol}: {e}")
|
||||
|
||||
# 2. DQN Model Prediction
|
||||
if hasattr(self.orchestrator, 'dqn_agent') and self.orchestrator.dqn_agent:
|
||||
try:
|
||||
dqn_state = market_data.get('dqn_state')
|
||||
if dqn_state is not None:
|
||||
# DQN uses act() method
|
||||
action = self.orchestrator.dqn_agent.act(dqn_state, explore=False)
|
||||
predictions['dqn'] = {
|
||||
'action': action,
|
||||
'action_name': ['SELL', 'HOLD', 'BUY'][action]
|
||||
}
|
||||
logger.info(f"✅ DQN prediction for {symbol}: {predictions['dqn']['action_name']}")
|
||||
except Exception as e:
|
||||
logger.error(f"DQN prediction error for {symbol}: {e}")
|
||||
|
||||
# 3. COB RL Model Prediction
|
||||
if hasattr(self.orchestrator, 'cob_rl_model') and self.orchestrator.cob_rl_model:
|
||||
try:
|
||||
cob_features = market_data.get('cob_features')
|
||||
if cob_features is not None and hasattr(self.orchestrator.cob_rl_model, 'predict'):
|
||||
cob_prediction = self.orchestrator.cob_rl_model.predict(cob_features)
|
||||
predictions['cob_rl'] = cob_prediction
|
||||
logger.info(f"✅ COB RL prediction for {symbol}: {cob_prediction}")
|
||||
except Exception as e:
|
||||
logger.error(f"COB RL prediction error for {symbol}: {e}")
|
||||
|
||||
# 4. Combine predictions into trading signal
|
||||
if predictions:
|
||||
trading_signal = self._combine_predictions(symbol, predictions, market_data)
|
||||
|
||||
# Send signal to trading executor
|
||||
if trading_signal and hasattr(self.orchestrator, 'process_signal'):
|
||||
await self.orchestrator.process_signal(trading_signal)
|
||||
logger.info(f"📤 Trading signal sent for {symbol}: {trading_signal['action']} "
|
||||
f"(confidence: {trading_signal['confidence']:.2f}, trigger: {trigger})")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error running model predictions for {symbol}: {e}")
|
||||
|
||||
async def _extract_market_features(self, symbol: str) -> Optional[Dict]:
|
||||
"""Extract features for all models from current market data"""
|
||||
try:
|
||||
# Get recent candles
|
||||
candles_1m = await self.data_provider.get_latest_candles(symbol, '1m', limit=100)
|
||||
candles_1s = await self.data_provider.get_latest_candles(symbol, '1s', limit=100)
|
||||
|
||||
if not candles_1m:
|
||||
return None
|
||||
|
||||
current_price = candles_1m[-1].get('close', 0)
|
||||
|
||||
# Build features for each model
|
||||
features = {
|
||||
'symbol': symbol,
|
||||
'current_price': current_price,
|
||||
'timestamp': datetime.now(),
|
||||
|
||||
# CNN features (100-dim feature vector)
|
||||
'cnn_features': self._build_cnn_features(candles_1m, candles_1s),
|
||||
|
||||
# DQN state (state vector for RL)
|
||||
'dqn_state': self._build_dqn_state(candles_1m),
|
||||
|
||||
# COB features (order book data)
|
||||
'cob_features': await self._build_cob_features(symbol)
|
||||
}
|
||||
|
||||
return features
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting market features for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def _build_cnn_features(self, candles_1m: List, candles_1s: List) -> Optional[np.ndarray]:
|
||||
"""Build feature vector for CNN model"""
|
||||
try:
|
||||
if not candles_1m or len(candles_1m) < 10:
|
||||
return None
|
||||
|
||||
# Extract OHLCV data
|
||||
features = []
|
||||
|
||||
for candle in candles_1m[-20:]: # Last 20 candles
|
||||
features.extend([
|
||||
candle.get('open', 0),
|
||||
candle.get('high', 0),
|
||||
candle.get('low', 0),
|
||||
candle.get('close', 0),
|
||||
candle.get('volume', 0)
|
||||
])
|
||||
|
||||
# Pad or truncate to expected size (100 features)
|
||||
feature_array = np.array(features)
|
||||
if len(feature_array) < 100:
|
||||
feature_array = np.pad(feature_array, (0, 100 - len(feature_array)))
|
||||
else:
|
||||
feature_array = feature_array[:100]
|
||||
|
||||
return feature_array.astype(np.float32)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error building CNN features: {e}")
|
||||
return None
|
||||
|
||||
def _build_dqn_state(self, candles: List) -> Optional[np.ndarray]:
|
||||
"""Build state vector for DQN agent"""
|
||||
try:
|
||||
if not candles or len(candles) < 5:
|
||||
return None
|
||||
|
||||
# Simple state: last 5 close prices normalized
|
||||
closes = [c.get('close', 0) for c in candles[-5:]]
|
||||
|
||||
if max(closes) == 0:
|
||||
return None
|
||||
|
||||
# Normalize
|
||||
state = np.array(closes) / max(closes)
|
||||
|
||||
return state.astype(np.float32)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error building DQN state: {e}")
|
||||
return None
|
||||
|
||||
async def _build_cob_features(self, symbol: str) -> Optional[Dict]:
|
||||
"""Build COB (Change of Bid) features"""
|
||||
try:
|
||||
# Get order book data if available
|
||||
if hasattr(self.orchestrator, 'get_cob_data'):
|
||||
cob_data = await self.orchestrator.get_cob_data(symbol)
|
||||
return cob_data
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error building COB features: {e}")
|
||||
return None
|
||||
|
||||
def _combine_predictions(self, symbol: str, predictions: Dict, market_data: Dict) -> Optional[Dict]:
|
||||
"""Combine predictions from multiple models into a trading signal"""
|
||||
try:
|
||||
# Voting system: each model contributes
|
||||
votes = {'BUY': 0, 'SELL': 0, 'HOLD': 0}
|
||||
confidences = []
|
||||
|
||||
# CNN vote
|
||||
if 'cnn' in predictions:
|
||||
cnn_action = predictions['cnn'].get('action', 'HOLD')
|
||||
cnn_conf = predictions['cnn'].get('confidence', 0.5)
|
||||
votes[cnn_action] += cnn_conf
|
||||
confidences.append(cnn_conf)
|
||||
|
||||
# DQN vote
|
||||
if 'dqn' in predictions:
|
||||
dqn_action = predictions['dqn'].get('action_name', 'HOLD')
|
||||
votes[dqn_action] += 0.7 # Fixed confidence for DQN
|
||||
confidences.append(0.7)
|
||||
|
||||
# COB RL vote
|
||||
if 'cob_rl' in predictions:
|
||||
cob_action = predictions['cob_rl'].get('action', 'HOLD')
|
||||
cob_conf = predictions['cob_rl'].get('confidence', 0.5)
|
||||
votes[cob_action] += cob_conf
|
||||
confidences.append(cob_conf)
|
||||
|
||||
# Determine final action (majority vote)
|
||||
final_action = max(votes, key=votes.get)
|
||||
final_confidence = sum(confidences) / len(confidences) if confidences else 0.5
|
||||
|
||||
# Only signal if confidence is high enough
|
||||
if final_confidence < 0.6:
|
||||
logger.debug(f"Low confidence ({final_confidence:.2f}) - no signal")
|
||||
return None
|
||||
|
||||
return {
|
||||
'symbol': symbol,
|
||||
'action': final_action,
|
||||
'confidence': final_confidence,
|
||||
'price': market_data.get('current_price'),
|
||||
'models_used': list(predictions.keys()),
|
||||
'predictions': predictions,
|
||||
'timestamp': datetime.now()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error combining predictions: {e}")
|
||||
return None
|
||||
Reference in New Issue
Block a user