improve predictions

This commit is contained in:
Dobromir Popov
2025-10-01 00:30:37 +03:00
parent 0a28cee58d
commit a03b9c5701

View File

@@ -18,6 +18,8 @@ from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
import numpy as np
from .unified_model_data_interface import UnifiedModelDataInterface, ModelInputData
logger = logging.getLogger(__name__)
@@ -29,6 +31,7 @@ class RealtimePredictionLoop:
def __init__(self, orchestrator, data_provider):
self.orchestrator = orchestrator
self.data_provider = data_provider
self.unified_data_interface = UnifiedModelDataInterface(data_provider, orchestrator.config)
self.running = False
self.last_prediction_time = {}
@@ -41,7 +44,7 @@ class RealtimePredictionLoop:
# Track last candle to detect new ones
self.last_candle_close_time = {}
logger.info("Real-time Prediction Loop initialized")
logger.info("Real-time Prediction Loop initialized with unified data interface")
async def start(self):
"""Start the continuous prediction loop"""
@@ -178,59 +181,75 @@ class RealtimePredictionLoop:
async def _run_all_model_predictions(self, symbol: str, trigger: str = "unknown"):
"""
CRITICAL: Run predictions from ALL models on current market data
This is where model.predict() gets called!
This is where model.predict() gets called with correct data format!
"""
try:
# Get current market data
market_data = await self._extract_market_features(symbol)
if not market_data:
logger.warning(f"No market data available for {symbol}")
return
predictions = {}
# 1. CNN Model Prediction
if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
try:
cnn_features = market_data.get('cnn_features')
if cnn_features is not None:
# ✅ THIS IS WHERE model.predict() SHOULD BE CALLED!
cnn_prediction = self.orchestrator.cnn_model.predict(cnn_features)
predictions['cnn'] = cnn_prediction
logger.info(f"✅ CNN prediction for {symbol}: {cnn_prediction}")
# Get standardized CNN input data
cnn_input = self.unified_data_interface.prepare_model_input(symbol, 'cnn', window_size=60)
if cnn_input and cnn_input.data_quality_score > 0.5:
cnn_data = self.unified_data_interface.get_model_specific_input(cnn_input, 'cnn')
if cnn_data is not None:
# ✅ THIS IS WHERE model.predict() GETS CALLED WITH CORRECT DATA!
cnn_prediction = self.orchestrator.cnn_model.predict(cnn_data)
predictions['cnn'] = cnn_prediction
logger.info(f"✅ CNN prediction for {symbol}: {cnn_prediction} (quality: {cnn_input.data_quality_score:.2f})")
except Exception as e:
logger.error(f"CNN prediction error for {symbol}: {e}")
# 2. DQN Model Prediction
if hasattr(self.orchestrator, 'dqn_agent') and self.orchestrator.dqn_agent:
try:
dqn_state = market_data.get('dqn_state')
if dqn_state is not None:
# DQN uses act() method
action = self.orchestrator.dqn_agent.act(dqn_state, explore=False)
predictions['dqn'] = {
'action': action,
'action_name': ['SELL', 'HOLD', 'BUY'][action]
}
logger.info(f"✅ DQN prediction for {symbol}: {predictions['dqn']['action_name']}")
# Get standardized DQN input data
dqn_input = self.unified_data_interface.prepare_model_input(symbol, 'dqn', window_size=100)
if dqn_input and dqn_input.data_quality_score > 0.5:
dqn_data = self.unified_data_interface.get_model_specific_input(dqn_input, 'dqn')
if dqn_data is not None:
# DQN uses act() method
action = self.orchestrator.dqn_agent.act(dqn_data, explore=False)
predictions['dqn'] = {
'action': action,
'action_name': ['SELL', 'HOLD', 'BUY'][action]
}
logger.info(f"✅ DQN prediction for {symbol}: {predictions['dqn']['action_name']} (quality: {dqn_input.data_quality_score:.2f})")
except Exception as e:
logger.error(f"DQN prediction error for {symbol}: {e}")
# 3. COB RL Model Prediction
if hasattr(self.orchestrator, 'cob_rl_model') and self.orchestrator.cob_rl_model:
try:
cob_features = market_data.get('cob_features')
if cob_features is not None and hasattr(self.orchestrator.cob_rl_model, 'predict'):
cob_prediction = self.orchestrator.cob_rl_model.predict(cob_features)
predictions['cob_rl'] = cob_prediction
logger.info(f"✅ COB RL prediction for {symbol}: {cob_prediction}")
# Get standardized COB RL input data
cob_input = self.unified_data_interface.prepare_model_input(symbol, 'cob_rl', window_size=50)
if cob_input and cob_input.data_quality_score > 0.5:
cob_data = self.unified_data_interface.get_model_specific_input(cob_input, 'cob_rl')
if cob_data is not None and hasattr(self.orchestrator.cob_rl_model, 'predict'):
cob_prediction = self.orchestrator.cob_rl_model.predict(cob_data)
predictions['cob_rl'] = cob_prediction
logger.info(f"✅ COB RL prediction for {symbol}: {cob_prediction} (quality: {cob_input.data_quality_score:.2f})")
except Exception as e:
logger.error(f"COB RL prediction error for {symbol}: {e}")
# 4. Combine predictions into trading signal
# 4. Transformer Model Prediction (if available)
if hasattr(self.orchestrator, 'transformer_model') and self.orchestrator.transformer_model:
try:
# Get standardized Transformer input data
transformer_input = self.unified_data_interface.prepare_model_input(symbol, 'transformer', window_size=150)
if transformer_input and transformer_input.data_quality_score > 0.5:
transformer_data = self.unified_data_interface.get_model_specific_input(transformer_input, 'transformer')
if transformer_data is not None and hasattr(self.orchestrator.transformer_model, 'predict'):
transformer_prediction = self.orchestrator.transformer_model.predict(transformer_data)
predictions['transformer'] = transformer_prediction
logger.info(f"✅ Transformer prediction for {symbol}: {transformer_prediction} (quality: {transformer_input.data_quality_score:.2f})")
except Exception as e:
logger.error(f"Transformer prediction error for {symbol}: {e}")
# 5. Combine predictions into trading signal
if predictions:
trading_signal = self._combine_predictions(symbol, predictions, market_data)
trading_signal = self._combine_predictions(symbol, predictions)
# Send signal to trading executor
if trading_signal and hasattr(self.orchestrator, 'process_signal'):
@@ -341,7 +360,7 @@ class RealtimePredictionLoop:
logger.error(f"Error building COB features: {e}")
return None
def _combine_predictions(self, symbol: str, predictions: Dict, market_data: Dict) -> Optional[Dict]:
def _combine_predictions(self, symbol: str, predictions: Dict) -> Optional[Dict]:
"""Combine predictions from multiple models into a trading signal"""
try:
# Voting system: each model contributes