diff --git a/core/realtime_prediction_loop.py b/core/realtime_prediction_loop.py new file mode 100644 index 0000000..dfae22e --- /dev/null +++ b/core/realtime_prediction_loop.py @@ -0,0 +1,392 @@ +""" +Real-Time Prediction Loop + +CRITICAL: This is the MISSING PIECE - continuous model inference on incoming market data + +This module monitors market data and triggers model predictions on: +- New 1s candles +- New 1m candles +- Pivot points detected +- Significant price movements + +NO SYNTHETIC DATA - Only real market data triggers predictions +""" + +import logging +import asyncio +from datetime import datetime, timedelta +from typing import Dict, List, Optional, Any +import numpy as np + +logger = logging.getLogger(__name__) + + +class RealtimePredictionLoop: + """ + Continuously monitors market data and triggers model predictions + """ + + def __init__(self, orchestrator, data_provider): + self.orchestrator = orchestrator + self.data_provider = data_provider + + self.running = False + self.last_prediction_time = {} + self.prediction_interval_seconds = { + '1s': 1, # Predict every second + '1m': 60, # Predict every minute + 'pivot': 0 # Predict immediately on pivot + } + + # Track last candle to detect new ones + self.last_candle_close_time = {} + + logger.info("Real-time Prediction Loop initialized") + + async def start(self): + """Start the continuous prediction loop""" + self.running = True + logger.info("🔄 Starting Real-Time Prediction Loop") + + # Start prediction tasks for each symbol + symbols = self.orchestrator.config.get('symbols', ['ETH/USDT', 'BTC/USDT']) + + tasks = [] + for symbol in symbols: + tasks.append(asyncio.create_task(self._prediction_loop_for_symbol(symbol))) + + await asyncio.gather(*tasks) + + def stop(self): + """Stop the prediction loop""" + self.running = False + logger.info("Real-Time Prediction Loop stopped") + + async def _prediction_loop_for_symbol(self, symbol: str): + """Run prediction loop for a specific symbol""" + logger.info(f"🔄 Prediction loop started for {symbol}") + + while self.running: + try: + # 1. Check for new candle (1s or 1m) + new_candle_detected, timeframe = await self._detect_new_candle(symbol) + + if new_candle_detected: + logger.info(f"📊 New {timeframe} candle detected for {symbol} - running predictions") + await self._run_all_model_predictions(symbol, trigger=f"new_{timeframe}_candle") + + # 2. Check for pivot point + pivot_detected = await self._detect_pivot_point(symbol) + + if pivot_detected: + logger.info(f"📍 Pivot point detected for {symbol} - running predictions") + await self._run_all_model_predictions(symbol, trigger="pivot_point") + + # 3. Periodic prediction (every N seconds based on timeframe) + if self._should_run_periodic_prediction(symbol): + logger.debug(f"⏰ Periodic prediction for {symbol}") + await self._run_all_model_predictions(symbol, trigger="periodic") + + # Sleep briefly to avoid CPU overuse + await asyncio.sleep(0.1) # Check every 100ms + + except Exception as e: + logger.error(f"Error in prediction loop for {symbol}: {e}") + await asyncio.sleep(1) + + async def _detect_new_candle(self, symbol: str) -> tuple: + """Detect if a new candle has closed""" + try: + # Get latest candles + candles_1s = await self.data_provider.get_latest_candles(symbol, '1s', limit=2) + candles_1m = await self.data_provider.get_latest_candles(symbol, '1m', limit=2) + + # Check 1s candle + if candles_1s and len(candles_1s) >= 2: + latest_1s_time = candles_1s[-1].get('timestamp') or candles_1s[-1].get('time') + + if symbol not in self.last_candle_close_time: + self.last_candle_close_time[symbol] = {} + + last_1s = self.last_candle_close_time[symbol].get('1s') + + if latest_1s_time and latest_1s_time != last_1s: + self.last_candle_close_time[symbol]['1s'] = latest_1s_time + return True, '1s' + + # Check 1m candle + if candles_1m and len(candles_1m) >= 2: + latest_1m_time = candles_1m[-1].get('timestamp') or candles_1m[-1].get('time') + + last_1m = self.last_candle_close_time[symbol].get('1m') + + if latest_1m_time and latest_1m_time != last_1m: + self.last_candle_close_time[symbol]['1m'] = latest_1m_time + return True, '1m' + + return False, None + + except Exception as e: + logger.error(f"Error detecting new candle for {symbol}: {e}") + return False, None + + async def _detect_pivot_point(self, symbol: str) -> bool: + """Detect if a pivot point has formed""" + try: + # Use Williams Market Structure or simple pivot detection + recent_candles = await self.data_provider.get_latest_candles(symbol, '1m', limit=5) + + if not recent_candles or len(recent_candles) < 5: + return False + + # Simple pivot: middle candle is local high or low + highs = [c.get('high', 0) for c in recent_candles] + lows = [c.get('low', 0) for c in recent_candles] + + # Pivot high: middle is highest + if highs[2] == max(highs): + logger.debug(f"Pivot HIGH detected for {symbol}") + return True + + # Pivot low: middle is lowest + if lows[2] == min(lows): + logger.debug(f"Pivot LOW detected for {symbol}") + return True + + return False + + except Exception as e: + logger.error(f"Error detecting pivot for {symbol}: {e}") + return False + + def _should_run_periodic_prediction(self, symbol: str) -> bool: + """Check if enough time has passed for periodic prediction""" + current_time = datetime.now() + + last_time = self.last_prediction_time.get(symbol) + if not last_time: + self.last_prediction_time[symbol] = current_time + return True + + # Run periodic prediction every 5 seconds + if (current_time - last_time).total_seconds() >= 5: + self.last_prediction_time[symbol] = current_time + return True + + return False + + async def _run_all_model_predictions(self, symbol: str, trigger: str = "unknown"): + """ + CRITICAL: Run predictions from ALL models on current market data + This is where model.predict() gets called! + """ + try: + # Get current market data + market_data = await self._extract_market_features(symbol) + + if not market_data: + logger.warning(f"No market data available for {symbol}") + return + + predictions = {} + + # 1. CNN Model Prediction + if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model: + try: + cnn_features = market_data.get('cnn_features') + if cnn_features is not None: + # ✅ THIS IS WHERE model.predict() SHOULD BE CALLED! + cnn_prediction = self.orchestrator.cnn_model.predict(cnn_features) + predictions['cnn'] = cnn_prediction + logger.info(f"✅ CNN prediction for {symbol}: {cnn_prediction}") + except Exception as e: + logger.error(f"CNN prediction error for {symbol}: {e}") + + # 2. DQN Model Prediction + if hasattr(self.orchestrator, 'dqn_agent') and self.orchestrator.dqn_agent: + try: + dqn_state = market_data.get('dqn_state') + if dqn_state is not None: + # DQN uses act() method + action = self.orchestrator.dqn_agent.act(dqn_state, explore=False) + predictions['dqn'] = { + 'action': action, + 'action_name': ['SELL', 'HOLD', 'BUY'][action] + } + logger.info(f"✅ DQN prediction for {symbol}: {predictions['dqn']['action_name']}") + except Exception as e: + logger.error(f"DQN prediction error for {symbol}: {e}") + + # 3. COB RL Model Prediction + if hasattr(self.orchestrator, 'cob_rl_model') and self.orchestrator.cob_rl_model: + try: + cob_features = market_data.get('cob_features') + if cob_features is not None and hasattr(self.orchestrator.cob_rl_model, 'predict'): + cob_prediction = self.orchestrator.cob_rl_model.predict(cob_features) + predictions['cob_rl'] = cob_prediction + logger.info(f"✅ COB RL prediction for {symbol}: {cob_prediction}") + except Exception as e: + logger.error(f"COB RL prediction error for {symbol}: {e}") + + # 4. Combine predictions into trading signal + if predictions: + trading_signal = self._combine_predictions(symbol, predictions, market_data) + + # Send signal to trading executor + if trading_signal and hasattr(self.orchestrator, 'process_signal'): + await self.orchestrator.process_signal(trading_signal) + logger.info(f"📤 Trading signal sent for {symbol}: {trading_signal['action']} " + f"(confidence: {trading_signal['confidence']:.2f}, trigger: {trigger})") + + except Exception as e: + logger.error(f"Error running model predictions for {symbol}: {e}") + + async def _extract_market_features(self, symbol: str) -> Optional[Dict]: + """Extract features for all models from current market data""" + try: + # Get recent candles + candles_1m = await self.data_provider.get_latest_candles(symbol, '1m', limit=100) + candles_1s = await self.data_provider.get_latest_candles(symbol, '1s', limit=100) + + if not candles_1m: + return None + + current_price = candles_1m[-1].get('close', 0) + + # Build features for each model + features = { + 'symbol': symbol, + 'current_price': current_price, + 'timestamp': datetime.now(), + + # CNN features (100-dim feature vector) + 'cnn_features': self._build_cnn_features(candles_1m, candles_1s), + + # DQN state (state vector for RL) + 'dqn_state': self._build_dqn_state(candles_1m), + + # COB features (order book data) + 'cob_features': await self._build_cob_features(symbol) + } + + return features + + except Exception as e: + logger.error(f"Error extracting market features for {symbol}: {e}") + return None + + def _build_cnn_features(self, candles_1m: List, candles_1s: List) -> Optional[np.ndarray]: + """Build feature vector for CNN model""" + try: + if not candles_1m or len(candles_1m) < 10: + return None + + # Extract OHLCV data + features = [] + + for candle in candles_1m[-20:]: # Last 20 candles + features.extend([ + candle.get('open', 0), + candle.get('high', 0), + candle.get('low', 0), + candle.get('close', 0), + candle.get('volume', 0) + ]) + + # Pad or truncate to expected size (100 features) + feature_array = np.array(features) + if len(feature_array) < 100: + feature_array = np.pad(feature_array, (0, 100 - len(feature_array))) + else: + feature_array = feature_array[:100] + + return feature_array.astype(np.float32) + + except Exception as e: + logger.error(f"Error building CNN features: {e}") + return None + + def _build_dqn_state(self, candles: List) -> Optional[np.ndarray]: + """Build state vector for DQN agent""" + try: + if not candles or len(candles) < 5: + return None + + # Simple state: last 5 close prices normalized + closes = [c.get('close', 0) for c in candles[-5:]] + + if max(closes) == 0: + return None + + # Normalize + state = np.array(closes) / max(closes) + + return state.astype(np.float32) + + except Exception as e: + logger.error(f"Error building DQN state: {e}") + return None + + async def _build_cob_features(self, symbol: str) -> Optional[Dict]: + """Build COB (Change of Bid) features""" + try: + # Get order book data if available + if hasattr(self.orchestrator, 'get_cob_data'): + cob_data = await self.orchestrator.get_cob_data(symbol) + return cob_data + + return None + + except Exception as e: + logger.error(f"Error building COB features: {e}") + return None + + def _combine_predictions(self, symbol: str, predictions: Dict, market_data: Dict) -> Optional[Dict]: + """Combine predictions from multiple models into a trading signal""" + try: + # Voting system: each model contributes + votes = {'BUY': 0, 'SELL': 0, 'HOLD': 0} + confidences = [] + + # CNN vote + if 'cnn' in predictions: + cnn_action = predictions['cnn'].get('action', 'HOLD') + cnn_conf = predictions['cnn'].get('confidence', 0.5) + votes[cnn_action] += cnn_conf + confidences.append(cnn_conf) + + # DQN vote + if 'dqn' in predictions: + dqn_action = predictions['dqn'].get('action_name', 'HOLD') + votes[dqn_action] += 0.7 # Fixed confidence for DQN + confidences.append(0.7) + + # COB RL vote + if 'cob_rl' in predictions: + cob_action = predictions['cob_rl'].get('action', 'HOLD') + cob_conf = predictions['cob_rl'].get('confidence', 0.5) + votes[cob_action] += cob_conf + confidences.append(cob_conf) + + # Determine final action (majority vote) + final_action = max(votes, key=votes.get) + final_confidence = sum(confidences) / len(confidences) if confidences else 0.5 + + # Only signal if confidence is high enough + if final_confidence < 0.6: + logger.debug(f"Low confidence ({final_confidence:.2f}) - no signal") + return None + + return { + 'symbol': symbol, + 'action': final_action, + 'confidence': final_confidence, + 'price': market_data.get('current_price'), + 'models_used': list(predictions.keys()), + 'predictions': predictions, + 'timestamp': datetime.now() + } + + except Exception as e: + logger.error(f"Error combining predictions: {e}") + return None