improve predictions
This commit is contained in:
@@ -18,6 +18,8 @@ from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any
|
||||
import numpy as np
|
||||
|
||||
from .unified_model_data_interface import UnifiedModelDataInterface, ModelInputData
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -29,6 +31,7 @@ class RealtimePredictionLoop:
|
||||
def __init__(self, orchestrator, data_provider):
|
||||
self.orchestrator = orchestrator
|
||||
self.data_provider = data_provider
|
||||
self.unified_data_interface = UnifiedModelDataInterface(data_provider, orchestrator.config)
|
||||
|
||||
self.running = False
|
||||
self.last_prediction_time = {}
|
||||
@@ -41,7 +44,7 @@ class RealtimePredictionLoop:
|
||||
# Track last candle to detect new ones
|
||||
self.last_candle_close_time = {}
|
||||
|
||||
logger.info("Real-time Prediction Loop initialized")
|
||||
logger.info("Real-time Prediction Loop initialized with unified data interface")
|
||||
|
||||
async def start(self):
|
||||
"""Start the continuous prediction loop"""
|
||||
@@ -178,59 +181,75 @@ class RealtimePredictionLoop:
|
||||
async def _run_all_model_predictions(self, symbol: str, trigger: str = "unknown"):
|
||||
"""
|
||||
CRITICAL: Run predictions from ALL models on current market data
|
||||
This is where model.predict() gets called!
|
||||
This is where model.predict() gets called with correct data format!
|
||||
"""
|
||||
try:
|
||||
# Get current market data
|
||||
market_data = await self._extract_market_features(symbol)
|
||||
|
||||
if not market_data:
|
||||
logger.warning(f"No market data available for {symbol}")
|
||||
return
|
||||
|
||||
predictions = {}
|
||||
|
||||
# 1. CNN Model Prediction
|
||||
if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
|
||||
try:
|
||||
cnn_features = market_data.get('cnn_features')
|
||||
if cnn_features is not None:
|
||||
# ✅ THIS IS WHERE model.predict() SHOULD BE CALLED!
|
||||
cnn_prediction = self.orchestrator.cnn_model.predict(cnn_features)
|
||||
predictions['cnn'] = cnn_prediction
|
||||
logger.info(f"✅ CNN prediction for {symbol}: {cnn_prediction}")
|
||||
# Get standardized CNN input data
|
||||
cnn_input = self.unified_data_interface.prepare_model_input(symbol, 'cnn', window_size=60)
|
||||
if cnn_input and cnn_input.data_quality_score > 0.5:
|
||||
cnn_data = self.unified_data_interface.get_model_specific_input(cnn_input, 'cnn')
|
||||
if cnn_data is not None:
|
||||
# ✅ THIS IS WHERE model.predict() GETS CALLED WITH CORRECT DATA!
|
||||
cnn_prediction = self.orchestrator.cnn_model.predict(cnn_data)
|
||||
predictions['cnn'] = cnn_prediction
|
||||
logger.info(f"✅ CNN prediction for {symbol}: {cnn_prediction} (quality: {cnn_input.data_quality_score:.2f})")
|
||||
except Exception as e:
|
||||
logger.error(f"CNN prediction error for {symbol}: {e}")
|
||||
|
||||
# 2. DQN Model Prediction
|
||||
if hasattr(self.orchestrator, 'dqn_agent') and self.orchestrator.dqn_agent:
|
||||
try:
|
||||
dqn_state = market_data.get('dqn_state')
|
||||
if dqn_state is not None:
|
||||
# DQN uses act() method
|
||||
action = self.orchestrator.dqn_agent.act(dqn_state, explore=False)
|
||||
predictions['dqn'] = {
|
||||
'action': action,
|
||||
'action_name': ['SELL', 'HOLD', 'BUY'][action]
|
||||
}
|
||||
logger.info(f"✅ DQN prediction for {symbol}: {predictions['dqn']['action_name']}")
|
||||
# Get standardized DQN input data
|
||||
dqn_input = self.unified_data_interface.prepare_model_input(symbol, 'dqn', window_size=100)
|
||||
if dqn_input and dqn_input.data_quality_score > 0.5:
|
||||
dqn_data = self.unified_data_interface.get_model_specific_input(dqn_input, 'dqn')
|
||||
if dqn_data is not None:
|
||||
# DQN uses act() method
|
||||
action = self.orchestrator.dqn_agent.act(dqn_data, explore=False)
|
||||
predictions['dqn'] = {
|
||||
'action': action,
|
||||
'action_name': ['SELL', 'HOLD', 'BUY'][action]
|
||||
}
|
||||
logger.info(f"✅ DQN prediction for {symbol}: {predictions['dqn']['action_name']} (quality: {dqn_input.data_quality_score:.2f})")
|
||||
except Exception as e:
|
||||
logger.error(f"DQN prediction error for {symbol}: {e}")
|
||||
|
||||
# 3. COB RL Model Prediction
|
||||
if hasattr(self.orchestrator, 'cob_rl_model') and self.orchestrator.cob_rl_model:
|
||||
try:
|
||||
cob_features = market_data.get('cob_features')
|
||||
if cob_features is not None and hasattr(self.orchestrator.cob_rl_model, 'predict'):
|
||||
cob_prediction = self.orchestrator.cob_rl_model.predict(cob_features)
|
||||
predictions['cob_rl'] = cob_prediction
|
||||
logger.info(f"✅ COB RL prediction for {symbol}: {cob_prediction}")
|
||||
# Get standardized COB RL input data
|
||||
cob_input = self.unified_data_interface.prepare_model_input(symbol, 'cob_rl', window_size=50)
|
||||
if cob_input and cob_input.data_quality_score > 0.5:
|
||||
cob_data = self.unified_data_interface.get_model_specific_input(cob_input, 'cob_rl')
|
||||
if cob_data is not None and hasattr(self.orchestrator.cob_rl_model, 'predict'):
|
||||
cob_prediction = self.orchestrator.cob_rl_model.predict(cob_data)
|
||||
predictions['cob_rl'] = cob_prediction
|
||||
logger.info(f"✅ COB RL prediction for {symbol}: {cob_prediction} (quality: {cob_input.data_quality_score:.2f})")
|
||||
except Exception as e:
|
||||
logger.error(f"COB RL prediction error for {symbol}: {e}")
|
||||
|
||||
# 4. Combine predictions into trading signal
|
||||
# 4. Transformer Model Prediction (if available)
|
||||
if hasattr(self.orchestrator, 'transformer_model') and self.orchestrator.transformer_model:
|
||||
try:
|
||||
# Get standardized Transformer input data
|
||||
transformer_input = self.unified_data_interface.prepare_model_input(symbol, 'transformer', window_size=150)
|
||||
if transformer_input and transformer_input.data_quality_score > 0.5:
|
||||
transformer_data = self.unified_data_interface.get_model_specific_input(transformer_input, 'transformer')
|
||||
if transformer_data is not None and hasattr(self.orchestrator.transformer_model, 'predict'):
|
||||
transformer_prediction = self.orchestrator.transformer_model.predict(transformer_data)
|
||||
predictions['transformer'] = transformer_prediction
|
||||
logger.info(f"✅ Transformer prediction for {symbol}: {transformer_prediction} (quality: {transformer_input.data_quality_score:.2f})")
|
||||
except Exception as e:
|
||||
logger.error(f"Transformer prediction error for {symbol}: {e}")
|
||||
|
||||
# 5. Combine predictions into trading signal
|
||||
if predictions:
|
||||
trading_signal = self._combine_predictions(symbol, predictions, market_data)
|
||||
trading_signal = self._combine_predictions(symbol, predictions)
|
||||
|
||||
# Send signal to trading executor
|
||||
if trading_signal and hasattr(self.orchestrator, 'process_signal'):
|
||||
@@ -341,7 +360,7 @@ class RealtimePredictionLoop:
|
||||
logger.error(f"Error building COB features: {e}")
|
||||
return None
|
||||
|
||||
def _combine_predictions(self, symbol: str, predictions: Dict, market_data: Dict) -> Optional[Dict]:
|
||||
def _combine_predictions(self, symbol: str, predictions: Dict) -> Optional[Dict]:
|
||||
"""Combine predictions from multiple models into a trading signal"""
|
||||
try:
|
||||
# Voting system: each model contributes
|
||||
|
||||
Reference in New Issue
Block a user