"""
Real-Time Prediction Loop

CRITICAL: This is the MISSING PIECE - continuous model inference on incoming
market data.

This module monitors market data and triggers model predictions on:
- New 1s candles
- New 1m candles
- Pivot points detected
- A fixed periodic interval

NOTE(review): "significant price movements" were listed as a trigger in the
original header, but no such trigger is implemented in this module.

NO SYNTHETIC DATA - Only real market data triggers predictions
"""

import logging
import asyncio
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any

import numpy as np

from .unified_model_data_interface import UnifiedModelDataInterface, ModelInputData

logger = logging.getLogger(__name__)


class RealtimePredictionLoop:
    """
    Continuously monitors market data and triggers model predictions.

    One asyncio task per configured symbol polls the data provider about
    every 100 ms. All models available on the orchestrator are run when a new
    1s/1m candle appears, when a pivot forms on the 1m candles, or when the
    periodic interval elapses. The combined result is forwarded to
    ``orchestrator.process_signal`` when the orchestrator exposes it.
    """

    # Seconds between unconditional ("periodic") prediction runs per symbol.
    PERIODIC_INTERVAL_SECONDS = 5
    # Model input is used only when its data_quality_score is above this.
    MIN_DATA_QUALITY = 0.5
    # Combined signals below this average confidence are not emitted.
    MIN_SIGNAL_CONFIDENCE = 0.6
    # The DQN agent returns only an action index, so it votes with this weight.
    DQN_FIXED_CONFIDENCE = 0.7
    # Maps the DQN action index to an action name.
    DQN_ACTION_NAMES = ('SELL', 'HOLD', 'BUY')

    def __init__(self, orchestrator, data_provider):
        """
        Args:
            orchestrator: Object exposing ``config`` and, optionally, the
                model attributes (``cnn_model``, ``dqn_agent``,
                ``cob_rl_model``, ``transformer_model``) and
                ``process_signal``.
            data_provider: Object exposing the awaitable
                ``get_latest_candles(symbol, timeframe, limit=...)``.
        """
        self.orchestrator = orchestrator
        self.data_provider = data_provider
        self.unified_data_interface = UnifiedModelDataInterface(data_provider, orchestrator.config)

        self.running = False
        self.last_prediction_time = {}
        self.prediction_interval_seconds = {
            '1s': 1,  # Predict every second
            '1m': 60,  # Predict every minute
            'pivot': 0  # Predict immediately on pivot
        }

        # Track last candle to detect new ones: {symbol: {timeframe: timestamp}}
        self.last_candle_close_time = {}
        # Track the last pivot reported per symbol so one pivot triggers once
        self.last_pivot_id = {}

        logger.info("Real-time Prediction Loop initialized with unified data interface")

    async def start(self):
        """Start one prediction task per configured symbol and wait on them."""
        self.running = True
        logger.info("🔄 Starting Real-Time Prediction Loop")

        # Start prediction tasks for each symbol
        symbols = self.orchestrator.config.get('symbols', ['ETH/USDT', 'BTC/USDT'])
        tasks = [
            asyncio.create_task(self._prediction_loop_for_symbol(symbol))
            for symbol in symbols
        ]

        await asyncio.gather(*tasks)

    def stop(self):
        """Stop the prediction loop (tasks exit at their next iteration)."""
        self.running = False
        logger.info("Real-Time Prediction Loop stopped")

    async def _prediction_loop_for_symbol(self, symbol: str):
        """Run the polling loop for a single symbol until ``stop()`` is called."""
        logger.info(f"🔄 Prediction loop started for {symbol}")

        while self.running:
            try:
                # 1. Check for new candle (1s or 1m)
                new_candle_detected, timeframe = await self._detect_new_candle(symbol)
                if new_candle_detected:
                    logger.info(f"📊 New {timeframe} candle detected for {symbol} - running predictions")
                    await self._run_all_model_predictions(symbol, trigger=f"new_{timeframe}_candle")

                # 2. Check for pivot point
                pivot_detected = await self._detect_pivot_point(symbol)
                if pivot_detected:
                    logger.info(f"📍 Pivot point detected for {symbol} - running predictions")
                    await self._run_all_model_predictions(symbol, trigger="pivot_point")

                # 3. Periodic prediction (every PERIODIC_INTERVAL_SECONDS)
                if self._should_run_periodic_prediction(symbol):
                    logger.debug(f"⏰ Periodic prediction for {symbol}")
                    await self._run_all_model_predictions(symbol, trigger="periodic")

                # Sleep briefly to avoid CPU overuse
                await asyncio.sleep(0.1)  # Check every 100ms

            except Exception as e:
                logger.error(f"Error in prediction loop for {symbol}: {e}")
                # Back off after a failure instead of retrying immediately
                await asyncio.sleep(1)

    async def _detect_new_candle(self, symbol: str) -> tuple:
        """
        Detect whether the latest 1s or 1m candle changed since the last call.

        Returns:
            ``(True, timeframe)`` for the first timeframe (1s is checked
            before 1m) whose latest candle timestamp differs from the one
            remembered for ``symbol``; ``(False, None)`` otherwise or on error.
        """
        try:
            # Get latest candles
            candles_by_timeframe = {
                '1s': await self.data_provider.get_latest_candles(symbol, '1s', limit=2),
                '1m': await self.data_provider.get_latest_candles(symbol, '1m', limit=2),
            }

            # Create the per-symbol record up front so the 1m check cannot hit
            # a missing key when no 1s candles are available.
            seen = self.last_candle_close_time.setdefault(symbol, {})

            for timeframe in ('1s', '1m'):
                candles = candles_by_timeframe[timeframe]
                if not candles or len(candles) < 2:
                    continue

                latest_time = candles[-1].get('timestamp') or candles[-1].get('time')
                if latest_time and latest_time != seen.get(timeframe):
                    seen[timeframe] = latest_time
                    return True, timeframe

            return False, None

        except Exception as e:
            logger.error(f"Error detecting new candle for {symbol}: {e}")
            return False, None

    async def _detect_pivot_point(self, symbol: str) -> bool:
        """
        Detect a newly formed pivot on the last five 1m candles.

        A pivot is reported when the middle candle of the window has the
        highest high or the lowest low. Each pivot is reported once: calls
        that see the same middle candle and pivot kind return False.

        Returns:
            True when a pivot not yet reported for ``symbol`` is found,
            False otherwise or on error.
        """
        try:
            # Use Williams Market Structure or simple pivot detection
            recent_candles = await self.data_provider.get_latest_candles(symbol, '1m', limit=5)

            if not recent_candles or len(recent_candles) < 5:
                return False

            # Keep index 2 the middle candle even if more than 5 are returned
            window = recent_candles[-5:]

            # Simple pivot: middle candle is local high or low
            highs = [c.get('high', 0) for c in window]
            lows = [c.get('low', 0) for c in window]

            if highs[2] == max(highs):
                kind = 'HIGH'
            elif lows[2] == min(lows):
                kind = 'LOW'
            else:
                return False

            # Identify the pivot by its candle timestamp, falling back to its
            # high/low when the candle carries no timestamp.
            middle = window[2]
            candle_id = middle.get('timestamp') or middle.get('time') or (highs[2], lows[2])
            pivot_id = (candle_id, kind)

            # The loop polls every 100 ms; without this check the same pivot
            # would re-trigger predictions on every poll.
            if self.last_pivot_id.get(symbol) == pivot_id:
                return False
            self.last_pivot_id[symbol] = pivot_id

            logger.debug(f"Pivot {kind} detected for {symbol}")
            return True

        except Exception as e:
            logger.error(f"Error detecting pivot for {symbol}: {e}")
            return False

    def _should_run_periodic_prediction(self, symbol: str) -> bool:
        """
        Check if enough time has passed for a periodic prediction.

        Returns True on the first call for ``symbol`` and whenever at least
        ``PERIODIC_INTERVAL_SECONDS`` elapsed since the last True result; the
        remembered time is updated whenever True is returned.
        """
        current_time = datetime.now()
        last_time = self.last_prediction_time.get(symbol)

        if last_time is None:
            self.last_prediction_time[symbol] = current_time
            return True

        elapsed = (current_time - last_time).total_seconds()
        if elapsed >= self.PERIODIC_INTERVAL_SECONDS:
            self.last_prediction_time[symbol] = current_time
            return True

        return False

    def _prepare_model_data(self, symbol: str, model_type: str, window_size: int):
        """
        Build model-specific input through the unified data interface.

        Returns:
            ``(model_input, model_data)``. Both are None when no input could
            be prepared or its ``data_quality_score`` is not above
            ``MIN_DATA_QUALITY``; ``model_data`` alone may be None when the
            interface returns None for this model type.
        """
        model_input = self.unified_data_interface.prepare_model_input(
            symbol, model_type, window_size=window_size
        )
        if not model_input or not model_input.data_quality_score > self.MIN_DATA_QUALITY:
            return None, None

        model_data = self.unified_data_interface.get_model_specific_input(model_input, model_type)
        return model_input, model_data

    async def _get_current_price(self, symbol: str) -> Optional[float]:
        """
        Return the close of the latest 1m candle, or None if unavailable.

        Uses the same price source as ``_extract_market_features``.
        """
        try:
            candles = await self.data_provider.get_latest_candles(symbol, '1m', limit=1)
        except Exception as e:
            logger.error(f"Error fetching current price for {symbol}: {e}")
            return None

        if not candles:
            return None
        return candles[-1].get('close')

    async def _run_all_model_predictions(self, symbol: str, trigger: str = "unknown"):
        """
        CRITICAL: Run predictions from ALL available models on current data.

        Each model present on the orchestrator is run inside its own
        try/except, so a failure in one model does not prevent the others.
        If any prediction was produced, the predictions are combined and the
        resulting signal (if any) is passed to ``orchestrator.process_signal``.

        Args:
            symbol: Trading symbol to predict for.
            trigger: Label of the event that caused this run (logging only).
        """
        try:
            predictions = {}
            orchestrator = self.orchestrator

            # 1. CNN Model Prediction
            cnn_model = getattr(orchestrator, 'cnn_model', None)
            if cnn_model:
                try:
                    cnn_input, cnn_data = self._prepare_model_data(symbol, 'cnn', 60)
                    if cnn_data is not None:
                        cnn_prediction = cnn_model.predict(cnn_data)
                        predictions['cnn'] = cnn_prediction
                        logger.info(f"✅ CNN prediction for {symbol}: {cnn_prediction} (quality: {cnn_input.data_quality_score:.2f})")
                except Exception as e:
                    logger.error(f"CNN prediction error for {symbol}: {e}")

            # 2. DQN Model Prediction
            dqn_agent = getattr(orchestrator, 'dqn_agent', None)
            if dqn_agent:
                try:
                    dqn_input, dqn_data = self._prepare_model_data(symbol, 'dqn', 100)
                    if dqn_data is not None:
                        # DQN uses act() method
                        action = dqn_agent.act(dqn_data, explore=False)
                        predictions['dqn'] = {
                            'action': action,
                            'action_name': self.DQN_ACTION_NAMES[action]
                        }
                        logger.info(f"✅ DQN prediction for {symbol}: {predictions['dqn']['action_name']} (quality: {dqn_input.data_quality_score:.2f})")
                except Exception as e:
                    logger.error(f"DQN prediction error for {symbol}: {e}")

            # 3. COB RL Model Prediction
            cob_rl_model = getattr(orchestrator, 'cob_rl_model', None)
            if cob_rl_model:
                try:
                    cob_input, cob_data = self._prepare_model_data(symbol, 'cob_rl', 50)
                    if cob_data is not None and hasattr(cob_rl_model, 'predict'):
                        cob_prediction = cob_rl_model.predict(cob_data)
                        predictions['cob_rl'] = cob_prediction
                        logger.info(f"✅ COB RL prediction for {symbol}: {cob_prediction} (quality: {cob_input.data_quality_score:.2f})")
                except Exception as e:
                    logger.error(f"COB RL prediction error for {symbol}: {e}")

            # 4. Transformer Model Prediction (if available)
            transformer_model = getattr(orchestrator, 'transformer_model', None)
            if transformer_model:
                try:
                    transformer_input, transformer_data = self._prepare_model_data(symbol, 'transformer', 150)
                    if transformer_data is not None and hasattr(transformer_model, 'predict'):
                        transformer_prediction = transformer_model.predict(transformer_data)
                        predictions['transformer'] = transformer_prediction
                        logger.info(f"✅ Transformer prediction for {symbol}: {transformer_prediction} (quality: {transformer_input.data_quality_score:.2f})")
                except Exception as e:
                    logger.error(f"Transformer prediction error for {symbol}: {e}")

            # 5. Combine predictions into trading signal
            if predictions:
                current_price = await self._get_current_price(symbol)
                trading_signal = self._combine_predictions(
                    symbol, predictions, current_price=current_price
                )

                # Send signal to trading executor
                if trading_signal and hasattr(orchestrator, 'process_signal'):
                    await orchestrator.process_signal(trading_signal)
                    logger.info(f"📤 Trading signal sent for {symbol}: {trading_signal['action']} "
                                f"(confidence: {trading_signal['confidence']:.2f}, trigger: {trigger})")

        except Exception as e:
            logger.error(f"Error running model predictions for {symbol}: {e}")

    async def _extract_market_features(self, symbol: str) -> Optional[Dict]:
        """
        Extract features for all models from current market data.

        Returns:
            Dict with ``symbol``, ``current_price`` (close of the latest 1m
            candle), ``timestamp``, ``cnn_features``, ``dqn_state`` and
            ``cob_features``; None when no 1m candles are available or on
            error.
        """
        try:
            # Get recent candles
            candles_1m = await self.data_provider.get_latest_candles(symbol, '1m', limit=100)
            candles_1s = await self.data_provider.get_latest_candles(symbol, '1s', limit=100)

            if not candles_1m:
                return None

            current_price = candles_1m[-1].get('close', 0)

            # Build features for each model
            return {
                'symbol': symbol,
                'current_price': current_price,
                'timestamp': datetime.now(),

                # CNN features (100-dim feature vector)
                'cnn_features': self._build_cnn_features(candles_1m, candles_1s),

                # DQN state (state vector for RL)
                'dqn_state': self._build_dqn_state(candles_1m),

                # COB features (order book data)
                'cob_features': await self._build_cob_features(symbol)
            }

        except Exception as e:
            logger.error(f"Error extracting market features for {symbol}: {e}")
            return None

    def _build_cnn_features(self, candles_1m: List, candles_1s: List) -> Optional[np.ndarray]:
        """
        Build a 100-element float32 feature vector for the CNN model.

        The vector is the open/high/low/close/volume of the last 20 1m
        candles, zero-padded at the end when fewer are available.
        ``candles_1s`` is accepted but currently unused.

        Returns:
            The feature vector, or None when fewer than 10 1m candles are
            given or on error.
        """
        try:
            if not candles_1m or len(candles_1m) < 10:
                return None

            # Extract OHLCV data from the last 20 candles
            fields = ('open', 'high', 'low', 'close', 'volume')
            features = [
                candle.get(field, 0)
                for candle in candles_1m[-20:]
                for field in fields
            ]

            # Pad or truncate to expected size (100 features)
            feature_array = np.array(features)
            if len(feature_array) < 100:
                feature_array = np.pad(feature_array, (0, 100 - len(feature_array)))
            else:
                feature_array = feature_array[:100]

            return feature_array.astype(np.float32)

        except Exception as e:
            logger.error(f"Error building CNN features: {e}")
            return None

    def _build_dqn_state(self, candles: List) -> Optional[np.ndarray]:
        """
        Build the state vector for the DQN agent.

        Returns:
            float32 array of the last 5 close prices divided by their
            maximum; None when fewer than 5 candles are given, when the
            maximum close is 0, or on error.
        """
        try:
            if not candles or len(candles) < 5:
                return None

            # Simple state: last 5 close prices normalized
            closes = [c.get('close', 0) for c in candles[-5:]]
            peak = max(closes)

            if peak == 0:
                # Avoid dividing by zero
                return None

            state = np.array(closes) / peak
            return state.astype(np.float32)

        except Exception as e:
            logger.error(f"Error building DQN state: {e}")
            return None

    async def _build_cob_features(self, symbol: str) -> Optional[Dict]:
        """
        Fetch COB (order book) data for ``symbol`` from the orchestrator.

        Returns:
            Whatever ``orchestrator.get_cob_data(symbol)`` returns, or None
            when the orchestrator has no ``get_cob_data`` or on error.
        """
        try:
            # Get order book data if available
            if hasattr(self.orchestrator, 'get_cob_data'):
                return await self.orchestrator.get_cob_data(symbol)

            return None

        except Exception as e:
            logger.error(f"Error building COB features: {e}")
            return None

    def _combine_predictions(self, symbol: str, predictions: Dict,
                             current_price: Optional[float] = None) -> Optional[Dict]:
        """
        Combine predictions from multiple models into a trading signal.

        CNN and COB RL vote with their own ``action``/``confidence`` (defaults
        ``'HOLD'``/0.5); DQN votes ``action_name`` with a fixed confidence.
        Transformer output is carried in the signal but does not vote. The
        action with the largest summed confidence wins; the signal confidence
        is the mean confidence of all votes cast.

        Args:
            symbol: Trading symbol the predictions belong to.
            predictions: Mapping of model name to that model's prediction.
            current_price: Price to attach to the signal (may be None).

        Returns:
            Signal dict with ``symbol``, ``action``, ``confidence``,
            ``price``, ``models_used``, ``predictions`` and ``timestamp``;
            None when the confidence is below ``MIN_SIGNAL_CONFIDENCE`` or on
            error.
        """
        try:
            # Voting system: each model contributes
            votes = {'BUY': 0, 'SELL': 0, 'HOLD': 0}
            confidences = []

            def cast_vote(model_name, action, confidence):
                # Skip unknown actions so one odd model output cannot abort
                # the whole combination.
                if action not in votes:
                    logger.debug(f"Ignoring {model_name} vote with unknown action: {action}")
                    return
                votes[action] += confidence
                confidences.append(confidence)

            # CNN and COB RL votes carry their own confidence
            for model_name in ('cnn', 'cob_rl'):
                prediction = predictions.get(model_name)
                if prediction is None:
                    continue
                if not hasattr(prediction, 'get'):
                    logger.debug(f"Ignoring {model_name} prediction without action/confidence fields")
                    continue
                cast_vote(
                    model_name,
                    prediction.get('action', 'HOLD'),
                    prediction.get('confidence', 0.5),
                )

            # DQN vote
            if 'dqn' in predictions:
                dqn_action = predictions['dqn'].get('action_name', 'HOLD')
                cast_vote('dqn', dqn_action, self.DQN_FIXED_CONFIDENCE)  # Fixed confidence for DQN

            # Determine final action (largest summed confidence)
            final_action = max(votes, key=votes.get)
            final_confidence = sum(confidences) / len(confidences) if confidences else 0.5

            # Only signal if confidence is high enough
            if final_confidence < self.MIN_SIGNAL_CONFIDENCE:
                logger.debug(f"Low confidence ({final_confidence:.2f}) - no signal")
                return None

            return {
                'symbol': symbol,
                'action': final_action,
                'confidence': final_confidence,
                'price': current_price,
                'models_used': list(predictions.keys()),
                'predictions': predictions,
                'timestamp': datetime.now()
            }

        except Exception as e:
            logger.error(f"Error combining predictions: {e}")
            return None