improve predictions
This commit is contained in:
@@ -18,6 +18,8 @@ from datetime import datetime, timedelta
|
|||||||
from typing import Dict, List, Optional, Any
|
from typing import Dict, List, Optional, Any
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from .unified_model_data_interface import UnifiedModelDataInterface, ModelInputData
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -29,6 +31,7 @@ class RealtimePredictionLoop:
|
|||||||
def __init__(self, orchestrator, data_provider):
|
def __init__(self, orchestrator, data_provider):
|
||||||
self.orchestrator = orchestrator
|
self.orchestrator = orchestrator
|
||||||
self.data_provider = data_provider
|
self.data_provider = data_provider
|
||||||
|
self.unified_data_interface = UnifiedModelDataInterface(data_provider, orchestrator.config)
|
||||||
|
|
||||||
self.running = False
|
self.running = False
|
||||||
self.last_prediction_time = {}
|
self.last_prediction_time = {}
|
||||||
@@ -41,7 +44,7 @@ class RealtimePredictionLoop:
|
|||||||
# Track last candle to detect new ones
|
# Track last candle to detect new ones
|
||||||
self.last_candle_close_time = {}
|
self.last_candle_close_time = {}
|
||||||
|
|
||||||
logger.info("Real-time Prediction Loop initialized")
|
logger.info("Real-time Prediction Loop initialized with unified data interface")
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
"""Start the continuous prediction loop"""
|
"""Start the continuous prediction loop"""
|
||||||
@@ -178,59 +181,75 @@ class RealtimePredictionLoop:
|
|||||||
async def _run_all_model_predictions(self, symbol: str, trigger: str = "unknown"):
|
async def _run_all_model_predictions(self, symbol: str, trigger: str = "unknown"):
|
||||||
"""
|
"""
|
||||||
CRITICAL: Run predictions from ALL models on current market data
|
CRITICAL: Run predictions from ALL models on current market data
|
||||||
This is where model.predict() gets called!
|
This is where model.predict() gets called with correct data format!
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Get current market data
|
|
||||||
market_data = await self._extract_market_features(symbol)
|
|
||||||
|
|
||||||
if not market_data:
|
|
||||||
logger.warning(f"No market data available for {symbol}")
|
|
||||||
return
|
|
||||||
|
|
||||||
predictions = {}
|
predictions = {}
|
||||||
|
|
||||||
# 1. CNN Model Prediction
|
# 1. CNN Model Prediction
|
||||||
if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
|
if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
|
||||||
try:
|
try:
|
||||||
cnn_features = market_data.get('cnn_features')
|
# Get standardized CNN input data
|
||||||
if cnn_features is not None:
|
cnn_input = self.unified_data_interface.prepare_model_input(symbol, 'cnn', window_size=60)
|
||||||
# ✅ THIS IS WHERE model.predict() SHOULD BE CALLED!
|
if cnn_input and cnn_input.data_quality_score > 0.5:
|
||||||
cnn_prediction = self.orchestrator.cnn_model.predict(cnn_features)
|
cnn_data = self.unified_data_interface.get_model_specific_input(cnn_input, 'cnn')
|
||||||
|
if cnn_data is not None:
|
||||||
|
# ✅ THIS IS WHERE model.predict() GETS CALLED WITH CORRECT DATA!
|
||||||
|
cnn_prediction = self.orchestrator.cnn_model.predict(cnn_data)
|
||||||
predictions['cnn'] = cnn_prediction
|
predictions['cnn'] = cnn_prediction
|
||||||
logger.info(f"✅ CNN prediction for {symbol}: {cnn_prediction}")
|
logger.info(f"✅ CNN prediction for {symbol}: {cnn_prediction} (quality: {cnn_input.data_quality_score:.2f})")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"CNN prediction error for {symbol}: {e}")
|
logger.error(f"CNN prediction error for {symbol}: {e}")
|
||||||
|
|
||||||
# 2. DQN Model Prediction
|
# 2. DQN Model Prediction
|
||||||
if hasattr(self.orchestrator, 'dqn_agent') and self.orchestrator.dqn_agent:
|
if hasattr(self.orchestrator, 'dqn_agent') and self.orchestrator.dqn_agent:
|
||||||
try:
|
try:
|
||||||
dqn_state = market_data.get('dqn_state')
|
# Get standardized DQN input data
|
||||||
if dqn_state is not None:
|
dqn_input = self.unified_data_interface.prepare_model_input(symbol, 'dqn', window_size=100)
|
||||||
|
if dqn_input and dqn_input.data_quality_score > 0.5:
|
||||||
|
dqn_data = self.unified_data_interface.get_model_specific_input(dqn_input, 'dqn')
|
||||||
|
if dqn_data is not None:
|
||||||
# DQN uses act() method
|
# DQN uses act() method
|
||||||
action = self.orchestrator.dqn_agent.act(dqn_state, explore=False)
|
action = self.orchestrator.dqn_agent.act(dqn_data, explore=False)
|
||||||
predictions['dqn'] = {
|
predictions['dqn'] = {
|
||||||
'action': action,
|
'action': action,
|
||||||
'action_name': ['SELL', 'HOLD', 'BUY'][action]
|
'action_name': ['SELL', 'HOLD', 'BUY'][action]
|
||||||
}
|
}
|
||||||
logger.info(f"✅ DQN prediction for {symbol}: {predictions['dqn']['action_name']}")
|
logger.info(f"✅ DQN prediction for {symbol}: {predictions['dqn']['action_name']} (quality: {dqn_input.data_quality_score:.2f})")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"DQN prediction error for {symbol}: {e}")
|
logger.error(f"DQN prediction error for {symbol}: {e}")
|
||||||
|
|
||||||
# 3. COB RL Model Prediction
|
# 3. COB RL Model Prediction
|
||||||
if hasattr(self.orchestrator, 'cob_rl_model') and self.orchestrator.cob_rl_model:
|
if hasattr(self.orchestrator, 'cob_rl_model') and self.orchestrator.cob_rl_model:
|
||||||
try:
|
try:
|
||||||
cob_features = market_data.get('cob_features')
|
# Get standardized COB RL input data
|
||||||
if cob_features is not None and hasattr(self.orchestrator.cob_rl_model, 'predict'):
|
cob_input = self.unified_data_interface.prepare_model_input(symbol, 'cob_rl', window_size=50)
|
||||||
cob_prediction = self.orchestrator.cob_rl_model.predict(cob_features)
|
if cob_input and cob_input.data_quality_score > 0.5:
|
||||||
|
cob_data = self.unified_data_interface.get_model_specific_input(cob_input, 'cob_rl')
|
||||||
|
if cob_data is not None and hasattr(self.orchestrator.cob_rl_model, 'predict'):
|
||||||
|
cob_prediction = self.orchestrator.cob_rl_model.predict(cob_data)
|
||||||
predictions['cob_rl'] = cob_prediction
|
predictions['cob_rl'] = cob_prediction
|
||||||
logger.info(f"✅ COB RL prediction for {symbol}: {cob_prediction}")
|
logger.info(f"✅ COB RL prediction for {symbol}: {cob_prediction} (quality: {cob_input.data_quality_score:.2f})")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"COB RL prediction error for {symbol}: {e}")
|
logger.error(f"COB RL prediction error for {symbol}: {e}")
|
||||||
|
|
||||||
# 4. Combine predictions into trading signal
|
# 4. Transformer Model Prediction (if available)
|
||||||
|
if hasattr(self.orchestrator, 'transformer_model') and self.orchestrator.transformer_model:
|
||||||
|
try:
|
||||||
|
# Get standardized Transformer input data
|
||||||
|
transformer_input = self.unified_data_interface.prepare_model_input(symbol, 'transformer', window_size=150)
|
||||||
|
if transformer_input and transformer_input.data_quality_score > 0.5:
|
||||||
|
transformer_data = self.unified_data_interface.get_model_specific_input(transformer_input, 'transformer')
|
||||||
|
if transformer_data is not None and hasattr(self.orchestrator.transformer_model, 'predict'):
|
||||||
|
transformer_prediction = self.orchestrator.transformer_model.predict(transformer_data)
|
||||||
|
predictions['transformer'] = transformer_prediction
|
||||||
|
logger.info(f"✅ Transformer prediction for {symbol}: {transformer_prediction} (quality: {transformer_input.data_quality_score:.2f})")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Transformer prediction error for {symbol}: {e}")
|
||||||
|
|
||||||
|
# 5. Combine predictions into trading signal
|
||||||
if predictions:
|
if predictions:
|
||||||
trading_signal = self._combine_predictions(symbol, predictions, market_data)
|
trading_signal = self._combine_predictions(symbol, predictions)
|
||||||
|
|
||||||
# Send signal to trading executor
|
# Send signal to trading executor
|
||||||
if trading_signal and hasattr(self.orchestrator, 'process_signal'):
|
if trading_signal and hasattr(self.orchestrator, 'process_signal'):
|
||||||
@@ -341,7 +360,7 @@ class RealtimePredictionLoop:
|
|||||||
logger.error(f"Error building COB features: {e}")
|
logger.error(f"Error building COB features: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _combine_predictions(self, symbol: str, predictions: Dict, market_data: Dict) -> Optional[Dict]:
|
def _combine_predictions(self, symbol: str, predictions: Dict) -> Optional[Dict]:
|
||||||
"""Combine predictions from multiple models into a trading signal"""
|
"""Combine predictions from multiple models into a trading signal"""
|
||||||
try:
|
try:
|
||||||
# Voting system: each model contributes
|
# Voting system: each model contributes
|
||||||
|
|||||||
Reference in New Issue
Block a user