improve predictions

This commit is contained in:
Dobromir Popov
2025-10-01 00:30:37 +03:00
parent 0a28cee58d
commit a03b9c5701

View File

@@ -18,6 +18,8 @@ from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any from typing import Dict, List, Optional, Any
import numpy as np import numpy as np
from .unified_model_data_interface import UnifiedModelDataInterface, ModelInputData
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -29,6 +31,7 @@ class RealtimePredictionLoop:
def __init__(self, orchestrator, data_provider): def __init__(self, orchestrator, data_provider):
self.orchestrator = orchestrator self.orchestrator = orchestrator
self.data_provider = data_provider self.data_provider = data_provider
self.unified_data_interface = UnifiedModelDataInterface(data_provider, orchestrator.config)
self.running = False self.running = False
self.last_prediction_time = {} self.last_prediction_time = {}
@@ -41,7 +44,7 @@ class RealtimePredictionLoop:
# Track last candle to detect new ones # Track last candle to detect new ones
self.last_candle_close_time = {} self.last_candle_close_time = {}
logger.info("Real-time Prediction Loop initialized") logger.info("Real-time Prediction Loop initialized with unified data interface")
async def start(self): async def start(self):
"""Start the continuous prediction loop""" """Start the continuous prediction loop"""
@@ -178,59 +181,75 @@ class RealtimePredictionLoop:
async def _run_all_model_predictions(self, symbol: str, trigger: str = "unknown"): async def _run_all_model_predictions(self, symbol: str, trigger: str = "unknown"):
""" """
CRITICAL: Run predictions from ALL models on current market data CRITICAL: Run predictions from ALL models on current market data
This is where model.predict() gets called! This is where model.predict() gets called with correct data format!
""" """
try: try:
# Get current market data
market_data = await self._extract_market_features(symbol)
if not market_data:
logger.warning(f"No market data available for {symbol}")
return
predictions = {} predictions = {}
# 1. CNN Model Prediction # 1. CNN Model Prediction
if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model: if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
try: try:
cnn_features = market_data.get('cnn_features') # Get standardized CNN input data
if cnn_features is not None: cnn_input = self.unified_data_interface.prepare_model_input(symbol, 'cnn', window_size=60)
# ✅ THIS IS WHERE model.predict() SHOULD BE CALLED! if cnn_input and cnn_input.data_quality_score > 0.5:
cnn_prediction = self.orchestrator.cnn_model.predict(cnn_features) cnn_data = self.unified_data_interface.get_model_specific_input(cnn_input, 'cnn')
if cnn_data is not None:
# ✅ THIS IS WHERE model.predict() GETS CALLED WITH CORRECT DATA!
cnn_prediction = self.orchestrator.cnn_model.predict(cnn_data)
predictions['cnn'] = cnn_prediction predictions['cnn'] = cnn_prediction
logger.info(f"✅ CNN prediction for {symbol}: {cnn_prediction}") logger.info(f"✅ CNN prediction for {symbol}: {cnn_prediction} (quality: {cnn_input.data_quality_score:.2f})")
except Exception as e: except Exception as e:
logger.error(f"CNN prediction error for {symbol}: {e}") logger.error(f"CNN prediction error for {symbol}: {e}")
# 2. DQN Model Prediction # 2. DQN Model Prediction
if hasattr(self.orchestrator, 'dqn_agent') and self.orchestrator.dqn_agent: if hasattr(self.orchestrator, 'dqn_agent') and self.orchestrator.dqn_agent:
try: try:
dqn_state = market_data.get('dqn_state') # Get standardized DQN input data
if dqn_state is not None: dqn_input = self.unified_data_interface.prepare_model_input(symbol, 'dqn', window_size=100)
if dqn_input and dqn_input.data_quality_score > 0.5:
dqn_data = self.unified_data_interface.get_model_specific_input(dqn_input, 'dqn')
if dqn_data is not None:
# DQN uses act() method # DQN uses act() method
action = self.orchestrator.dqn_agent.act(dqn_state, explore=False) action = self.orchestrator.dqn_agent.act(dqn_data, explore=False)
predictions['dqn'] = { predictions['dqn'] = {
'action': action, 'action': action,
'action_name': ['SELL', 'HOLD', 'BUY'][action] 'action_name': ['SELL', 'HOLD', 'BUY'][action]
} }
logger.info(f"✅ DQN prediction for {symbol}: {predictions['dqn']['action_name']}") logger.info(f"✅ DQN prediction for {symbol}: {predictions['dqn']['action_name']} (quality: {dqn_input.data_quality_score:.2f})")
except Exception as e: except Exception as e:
logger.error(f"DQN prediction error for {symbol}: {e}") logger.error(f"DQN prediction error for {symbol}: {e}")
# 3. COB RL Model Prediction # 3. COB RL Model Prediction
if hasattr(self.orchestrator, 'cob_rl_model') and self.orchestrator.cob_rl_model: if hasattr(self.orchestrator, 'cob_rl_model') and self.orchestrator.cob_rl_model:
try: try:
cob_features = market_data.get('cob_features') # Get standardized COB RL input data
if cob_features is not None and hasattr(self.orchestrator.cob_rl_model, 'predict'): cob_input = self.unified_data_interface.prepare_model_input(symbol, 'cob_rl', window_size=50)
cob_prediction = self.orchestrator.cob_rl_model.predict(cob_features) if cob_input and cob_input.data_quality_score > 0.5:
cob_data = self.unified_data_interface.get_model_specific_input(cob_input, 'cob_rl')
if cob_data is not None and hasattr(self.orchestrator.cob_rl_model, 'predict'):
cob_prediction = self.orchestrator.cob_rl_model.predict(cob_data)
predictions['cob_rl'] = cob_prediction predictions['cob_rl'] = cob_prediction
logger.info(f"✅ COB RL prediction for {symbol}: {cob_prediction}") logger.info(f"✅ COB RL prediction for {symbol}: {cob_prediction} (quality: {cob_input.data_quality_score:.2f})")
except Exception as e: except Exception as e:
logger.error(f"COB RL prediction error for {symbol}: {e}") logger.error(f"COB RL prediction error for {symbol}: {e}")
# 4. Combine predictions into trading signal # 4. Transformer Model Prediction (if available)
if hasattr(self.orchestrator, 'transformer_model') and self.orchestrator.transformer_model:
try:
# Get standardized Transformer input data
transformer_input = self.unified_data_interface.prepare_model_input(symbol, 'transformer', window_size=150)
if transformer_input and transformer_input.data_quality_score > 0.5:
transformer_data = self.unified_data_interface.get_model_specific_input(transformer_input, 'transformer')
if transformer_data is not None and hasattr(self.orchestrator.transformer_model, 'predict'):
transformer_prediction = self.orchestrator.transformer_model.predict(transformer_data)
predictions['transformer'] = transformer_prediction
logger.info(f"✅ Transformer prediction for {symbol}: {transformer_prediction} (quality: {transformer_input.data_quality_score:.2f})")
except Exception as e:
logger.error(f"Transformer prediction error for {symbol}: {e}")
# 5. Combine predictions into trading signal
if predictions: if predictions:
trading_signal = self._combine_predictions(symbol, predictions, market_data) trading_signal = self._combine_predictions(symbol, predictions)
# Send signal to trading executor # Send signal to trading executor
if trading_signal and hasattr(self.orchestrator, 'process_signal'): if trading_signal and hasattr(self.orchestrator, 'process_signal'):
@@ -341,7 +360,7 @@ class RealtimePredictionLoop:
logger.error(f"Error building COB features: {e}") logger.error(f"Error building COB features: {e}")
return None return None
def _combine_predictions(self, symbol: str, predictions: Dict, market_data: Dict) -> Optional[Dict]: def _combine_predictions(self, symbol: str, predictions: Dict) -> Optional[Dict]:
"""Combine predictions from multiple models into a trading signal""" """Combine predictions from multiple models into a trading signal"""
try: try:
# Voting system: each model contributes # Voting system: each model contributes