From a03b9c57014540b02e2e0719eb415bedeb4e8d94 Mon Sep 17 00:00:00 2001 From: Dobromir Popov Date: Wed, 1 Oct 2025 00:30:37 +0300 Subject: [PATCH] improve predictions --- core/realtime_prediction_loop.py | 83 ++++++++++++++++++++------------ 1 file changed, 51 insertions(+), 32 deletions(-) diff --git a/core/realtime_prediction_loop.py b/core/realtime_prediction_loop.py index dfae22e..29411fe 100644 --- a/core/realtime_prediction_loop.py +++ b/core/realtime_prediction_loop.py @@ -18,6 +18,8 @@ from datetime import datetime, timedelta from typing import Dict, List, Optional, Any import numpy as np +from .unified_model_data_interface import UnifiedModelDataInterface, ModelInputData + logger = logging.getLogger(__name__) @@ -29,6 +31,7 @@ class RealtimePredictionLoop: def __init__(self, orchestrator, data_provider): self.orchestrator = orchestrator self.data_provider = data_provider + self.unified_data_interface = UnifiedModelDataInterface(data_provider, orchestrator.config) self.running = False self.last_prediction_time = {} @@ -41,7 +44,7 @@ class RealtimePredictionLoop: # Track last candle to detect new ones self.last_candle_close_time = {} - logger.info("Real-time Prediction Loop initialized") + logger.info("Real-time Prediction Loop initialized with unified data interface") async def start(self): """Start the continuous prediction loop""" @@ -178,59 +181,75 @@ class RealtimePredictionLoop: async def _run_all_model_predictions(self, symbol: str, trigger: str = "unknown"): """ CRITICAL: Run predictions from ALL models on current market data - This is where model.predict() gets called! + This is where model.predict() gets called with correct data format! """ try: - # Get current market data - market_data = await self._extract_market_features(symbol) - - if not market_data: - logger.warning(f"No market data available for {symbol}") - return - predictions = {} # 1. CNN Model Prediction if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model: try: - cnn_features = market_data.get('cnn_features') - if cnn_features is not None: - # ✅ THIS IS WHERE model.predict() SHOULD BE CALLED! - cnn_prediction = self.orchestrator.cnn_model.predict(cnn_features) - predictions['cnn'] = cnn_prediction - logger.info(f"✅ CNN prediction for {symbol}: {cnn_prediction}") + # Get standardized CNN input data + cnn_input = self.unified_data_interface.prepare_model_input(symbol, 'cnn', window_size=60) + if cnn_input and cnn_input.data_quality_score > 0.5: + cnn_data = self.unified_data_interface.get_model_specific_input(cnn_input, 'cnn') + if cnn_data is not None: + # ✅ THIS IS WHERE model.predict() GETS CALLED WITH CORRECT DATA! + cnn_prediction = self.orchestrator.cnn_model.predict(cnn_data) + predictions['cnn'] = cnn_prediction + logger.info(f"✅ CNN prediction for {symbol}: {cnn_prediction} (quality: {cnn_input.data_quality_score:.2f})") except Exception as e: logger.error(f"CNN prediction error for {symbol}: {e}") # 2. DQN Model Prediction if hasattr(self.orchestrator, 'dqn_agent') and self.orchestrator.dqn_agent: try: - dqn_state = market_data.get('dqn_state') - if dqn_state is not None: - # DQN uses act() method - action = self.orchestrator.dqn_agent.act(dqn_state, explore=False) - predictions['dqn'] = { - 'action': action, - 'action_name': ['SELL', 'HOLD', 'BUY'][action] - } - logger.info(f"✅ DQN prediction for {symbol}: {predictions['dqn']['action_name']}") + # Get standardized DQN input data + dqn_input = self.unified_data_interface.prepare_model_input(symbol, 'dqn', window_size=100) + if dqn_input and dqn_input.data_quality_score > 0.5: + dqn_data = self.unified_data_interface.get_model_specific_input(dqn_input, 'dqn') + if dqn_data is not None: + # DQN uses act() method + action = self.orchestrator.dqn_agent.act(dqn_data, explore=False) + predictions['dqn'] = { + 'action': action, + 'action_name': ['SELL', 'HOLD', 'BUY'][action] + } + logger.info(f"✅ DQN prediction for {symbol}: {predictions['dqn']['action_name']} (quality: {dqn_input.data_quality_score:.2f})") except Exception as e: logger.error(f"DQN prediction error for {symbol}: {e}") # 3. COB RL Model Prediction if hasattr(self.orchestrator, 'cob_rl_model') and self.orchestrator.cob_rl_model: try: - cob_features = market_data.get('cob_features') - if cob_features is not None and hasattr(self.orchestrator.cob_rl_model, 'predict'): - cob_prediction = self.orchestrator.cob_rl_model.predict(cob_features) - predictions['cob_rl'] = cob_prediction - logger.info(f"✅ COB RL prediction for {symbol}: {cob_prediction}") + # Get standardized COB RL input data + cob_input = self.unified_data_interface.prepare_model_input(symbol, 'cob_rl', window_size=50) + if cob_input and cob_input.data_quality_score > 0.5: + cob_data = self.unified_data_interface.get_model_specific_input(cob_input, 'cob_rl') + if cob_data is not None and hasattr(self.orchestrator.cob_rl_model, 'predict'): + cob_prediction = self.orchestrator.cob_rl_model.predict(cob_data) + predictions['cob_rl'] = cob_prediction + logger.info(f"✅ COB RL prediction for {symbol}: {cob_prediction} (quality: {cob_input.data_quality_score:.2f})") except Exception as e: logger.error(f"COB RL prediction error for {symbol}: {e}") - # 4. Combine predictions into trading signal + # 4. Transformer Model Prediction (if available) + if hasattr(self.orchestrator, 'transformer_model') and self.orchestrator.transformer_model: + try: + # Get standardized Transformer input data + transformer_input = self.unified_data_interface.prepare_model_input(symbol, 'transformer', window_size=150) + if transformer_input and transformer_input.data_quality_score > 0.5: + transformer_data = self.unified_data_interface.get_model_specific_input(transformer_input, 'transformer') + if transformer_data is not None and hasattr(self.orchestrator.transformer_model, 'predict'): + transformer_prediction = self.orchestrator.transformer_model.predict(transformer_data) + predictions['transformer'] = transformer_prediction + logger.info(f"✅ Transformer prediction for {symbol}: {transformer_prediction} (quality: {transformer_input.data_quality_score:.2f})") + except Exception as e: + logger.error(f"Transformer prediction error for {symbol}: {e}") + + # 5. Combine predictions into trading signal if predictions: - trading_signal = self._combine_predictions(symbol, predictions, market_data) + trading_signal = self._combine_predictions(symbol, predictions) # Send signal to trading executor if trading_signal and hasattr(self.orchestrator, 'process_signal'): @@ -341,7 +360,7 @@ class RealtimePredictionLoop: logger.error(f"Error building COB features: {e}") return None - def _combine_predictions(self, symbol: str, predictions: Dict, market_data: Dict) -> Optional[Dict]: + def _combine_predictions(self, symbol: str, predictions: Dict) -> Optional[Dict]: """Combine predictions from multiple models into a trading signal""" try: # Voting system: each model contributes