"""
|
|
Real-Time Prediction Loop
|
|
|
|
CRITICAL: This is the MISSING PIECE - continuous model inference on incoming market data
|
|
|
|
This module monitors market data and triggers model predictions on:
|
|
- New 1s candles
|
|
- New 1m candles
|
|
- Pivot points detected
|
|
- Significant price movements
|
|
|
|
NO SYNTHETIC DATA - Only real market data triggers predictions
|
|
"""
|
|
|
|
import logging
|
|
import asyncio
|
|
from datetime import datetime, timedelta
|
|
from typing import Dict, List, Optional, Any
|
|
import numpy as np
|
|
|
|
from .unified_model_data_interface import UnifiedModelDataInterface, ModelInputData
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
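
# NOTE (descriptive, inferred from the calls in this module rather than a formal interface):
# the loop is duck-typed against its collaborators and expects, at minimum:
#   - orchestrator.config                                     -> dict-like with .get('symbols', ...)
#   - orchestrator.process_signal(signal)                     -> async sink for combined trading signals
#   - optional orchestrator.cnn_model / dqn_agent / cob_rl_model / transformer_model
#   - data_provider.get_latest_candles(symbol, timeframe, limit) -> awaitable list of OHLCV dicts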
class RealtimePredictionLoop:
    """
    Continuously monitors market data and triggers model predictions
    """

    def __init__(self, orchestrator, data_provider):
        self.orchestrator = orchestrator
        self.data_provider = data_provider
        self.unified_data_interface = UnifiedModelDataInterface(data_provider, orchestrator.config)

        self.running = False
        self.last_prediction_time = {}
        self.prediction_interval_seconds = {
            '1s': 1,    # Predict every second
            '1m': 60,   # Predict every minute
            'pivot': 0  # Predict immediately on pivot
        }

        # Track last candle to detect new ones
        self.last_candle_close_time = {}

        logger.info("Real-time Prediction Loop initialized with unified data interface")

    async def start(self):
        """Start the continuous prediction loop"""
        self.running = True
        logger.info("Starting Real-Time Prediction Loop")

        # Start prediction tasks for each symbol
        symbols = self.orchestrator.config.get('symbols', ['ETH/USDT', 'BTC/USDT'])

        tasks = []
        for symbol in symbols:
            tasks.append(asyncio.create_task(self._prediction_loop_for_symbol(symbol)))

        await asyncio.gather(*tasks)

    def stop(self):
        """Stop the prediction loop"""
        self.running = False
        logger.info("Real-Time Prediction Loop stopped")

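    # Usage sketch for start()/stop() above (illustrative only - assumes an orchestrator and
    # data provider satisfying the duck-typed contract noted at the top of this module):
    #
    #     loop = RealtimePredictionLoop(orchestrator, data_provider)
    #     prediction_task = asyncio.create_task(loop.start())  # one loop per configured symbol
    #     ...
    #     loop.stop()  # clears self.running; per-symbol loops exit on their next iteration
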
    async def _prediction_loop_for_symbol(self, symbol: str):
        """Run prediction loop for a specific symbol"""
        logger.info(f"Prediction loop started for {symbol}")

        while self.running:
            try:
                # 1. Check for new candle (1s or 1m)
                new_candle_detected, timeframe = await self._detect_new_candle(symbol)

                if new_candle_detected:
                    logger.info(f"📊 New {timeframe} candle detected for {symbol} - running predictions")
                    await self._run_all_model_predictions(symbol, trigger=f"new_{timeframe}_candle")

                # 2. Check for pivot point
                pivot_detected = await self._detect_pivot_point(symbol)

                if pivot_detected:
                    logger.info(f"📍 Pivot point detected for {symbol} - running predictions")
                    await self._run_all_model_predictions(symbol, trigger="pivot_point")

                # 3. Periodic prediction (every N seconds based on timeframe)
                if self._should_run_periodic_prediction(symbol):
                    logger.debug(f"⏰ Periodic prediction for {symbol}")
                    await self._run_all_model_predictions(symbol, trigger="periodic")

                # Sleep briefly to avoid CPU overuse
                await asyncio.sleep(0.1)  # Check every 100ms

            except Exception as e:
                logger.error(f"Error in prediction loop for {symbol}: {e}")
                await asyncio.sleep(1)

    async def _detect_new_candle(self, symbol: str) -> tuple:
        """Detect if a new candle has closed"""
        try:
            # Get latest candles
            candles_1s = await self.data_provider.get_latest_candles(symbol, '1s', limit=2)
            candles_1m = await self.data_provider.get_latest_candles(symbol, '1m', limit=2)

            # Initialize tracking for this symbol before either timeframe check,
            # so the 1m branch cannot hit a missing key when no 1s candles arrive
            if symbol not in self.last_candle_close_time:
                self.last_candle_close_time[symbol] = {}

            # Check 1s candle
            if candles_1s and len(candles_1s) >= 2:
                latest_1s_time = candles_1s[-1].get('timestamp') or candles_1s[-1].get('time')

                last_1s = self.last_candle_close_time[symbol].get('1s')

                if latest_1s_time and latest_1s_time != last_1s:
                    self.last_candle_close_time[symbol]['1s'] = latest_1s_time
                    return True, '1s'

            # Check 1m candle
            if candles_1m and len(candles_1m) >= 2:
                latest_1m_time = candles_1m[-1].get('timestamp') or candles_1m[-1].get('time')

                last_1m = self.last_candle_close_time[symbol].get('1m')

                if latest_1m_time and latest_1m_time != last_1m:
                    self.last_candle_close_time[symbol]['1m'] = latest_1m_time
                    return True, '1m'

            return False, None

        except Exception as e:
            logger.error(f"Error detecting new candle for {symbol}: {e}")
            return False, None

    async def _detect_pivot_point(self, symbol: str) -> bool:
        """Detect if a pivot point has formed"""
        try:
            # Use Williams Market Structure or simple pivot detection
            recent_candles = await self.data_provider.get_latest_candles(symbol, '1m', limit=5)

            if not recent_candles or len(recent_candles) < 5:
                return False

            # Simple pivot: middle candle is local high or low
            # e.g. highs = [10, 12, 15, 13, 11] -> highs[2] == max(highs) -> pivot HIGH
            highs = [c.get('high', 0) for c in recent_candles]
            lows = [c.get('low', 0) for c in recent_candles]

            # Pivot high: middle is highest
            if highs[2] == max(highs):
                logger.debug(f"Pivot HIGH detected for {symbol}")
                return True

            # Pivot low: middle is lowest
            if lows[2] == min(lows):
                logger.debug(f"Pivot LOW detected for {symbol}")
                return True

            return False

        except Exception as e:
            logger.error(f"Error detecting pivot for {symbol}: {e}")
            return False

    def _should_run_periodic_prediction(self, symbol: str) -> bool:
        """Check if enough time has passed for periodic prediction"""
        current_time = datetime.now()

        last_time = self.last_prediction_time.get(symbol)
        if not last_time:
            self.last_prediction_time[symbol] = current_time
            return True

        # Run periodic prediction every 5 seconds
        if (current_time - last_time).total_seconds() >= 5:
            self.last_prediction_time[symbol] = current_time
            return True

        return False

    async def _run_all_model_predictions(self, symbol: str, trigger: str = "unknown"):
        """
        CRITICAL: Run predictions from ALL models on current market data
        This is where model.predict() gets called with correct data format!
        """
        try:
            predictions = {}

            # 1. CNN Model Prediction
            if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
                try:
                    # Get standardized CNN input data
                    cnn_input = self.unified_data_interface.prepare_model_input(symbol, 'cnn', window_size=60)
                    if cnn_input and cnn_input.data_quality_score > 0.5:
                        cnn_data = self.unified_data_interface.get_model_specific_input(cnn_input, 'cnn')
                        if cnn_data is not None:
                            # THIS IS WHERE model.predict() GETS CALLED WITH CORRECT DATA!
                            cnn_prediction = self.orchestrator.cnn_model.predict(cnn_data)
                            predictions['cnn'] = cnn_prediction
                            logger.info(f"CNN prediction for {symbol}: {cnn_prediction} (quality: {cnn_input.data_quality_score:.2f})")
                except Exception as e:
                    logger.error(f"CNN prediction error for {symbol}: {e}")

            # 2. DQN Model Prediction
            if hasattr(self.orchestrator, 'dqn_agent') and self.orchestrator.dqn_agent:
                try:
                    # Get standardized DQN input data
                    dqn_input = self.unified_data_interface.prepare_model_input(symbol, 'dqn', window_size=100)
                    if dqn_input and dqn_input.data_quality_score > 0.5:
                        dqn_data = self.unified_data_interface.get_model_specific_input(dqn_input, 'dqn')
                        if dqn_data is not None:
                            # DQN uses act() method
                            action = self.orchestrator.dqn_agent.act(dqn_data, explore=False)
                            predictions['dqn'] = {
                                'action': action,
                                'action_name': ['SELL', 'HOLD', 'BUY'][action]
                            }
                            logger.info(f"DQN prediction for {symbol}: {predictions['dqn']['action_name']} (quality: {dqn_input.data_quality_score:.2f})")
                except Exception as e:
                    logger.error(f"DQN prediction error for {symbol}: {e}")

            # 3. COB RL Model Prediction
            if hasattr(self.orchestrator, 'cob_rl_model') and self.orchestrator.cob_rl_model:
                try:
                    # Get standardized COB RL input data
                    cob_input = self.unified_data_interface.prepare_model_input(symbol, 'cob_rl', window_size=50)
                    if cob_input and cob_input.data_quality_score > 0.5:
                        cob_data = self.unified_data_interface.get_model_specific_input(cob_input, 'cob_rl')
                        if cob_data is not None and hasattr(self.orchestrator.cob_rl_model, 'predict'):
                            cob_prediction = self.orchestrator.cob_rl_model.predict(cob_data)
                            predictions['cob_rl'] = cob_prediction
                            logger.info(f"COB RL prediction for {symbol}: {cob_prediction} (quality: {cob_input.data_quality_score:.2f})")
                except Exception as e:
                    logger.error(f"COB RL prediction error for {symbol}: {e}")

            # 4. Transformer Model Prediction (if available)
            if hasattr(self.orchestrator, 'transformer_model') and self.orchestrator.transformer_model:
                try:
                    # Get standardized Transformer input data
                    transformer_input = self.unified_data_interface.prepare_model_input(symbol, 'transformer', window_size=150)
                    if transformer_input and transformer_input.data_quality_score > 0.5:
                        transformer_data = self.unified_data_interface.get_model_specific_input(transformer_input, 'transformer')
                        if transformer_data is not None and hasattr(self.orchestrator.transformer_model, 'predict'):
                            transformer_prediction = self.orchestrator.transformer_model.predict(transformer_data)
                            predictions['transformer'] = transformer_prediction
                            logger.info(f"Transformer prediction for {symbol}: {transformer_prediction} (quality: {transformer_input.data_quality_score:.2f})")
                except Exception as e:
                    logger.error(f"Transformer prediction error for {symbol}: {e}")

            # 5. Combine predictions into trading signal
            if predictions:
                trading_signal = self._combine_predictions(symbol, predictions)

                # Send signal to trading executor
                if trading_signal and hasattr(self.orchestrator, 'process_signal'):
                    await self.orchestrator.process_signal(trading_signal)
                    logger.info(f"📤 Trading signal sent for {symbol}: {trading_signal['action']} "
                                f"(confidence: {trading_signal['confidence']:.2f}, trigger: {trigger})")

        except Exception as e:
            logger.error(f"Error running model predictions for {symbol}: {e}")

    async def _extract_market_features(self, symbol: str) -> Optional[Dict]:
        """Extract features for all models from current market data"""
        try:
            # Get recent candles
            candles_1m = await self.data_provider.get_latest_candles(symbol, '1m', limit=100)
            candles_1s = await self.data_provider.get_latest_candles(symbol, '1s', limit=100)

            if not candles_1m:
                return None

            current_price = candles_1m[-1].get('close', 0)

            # Build features for each model
            features = {
                'symbol': symbol,
                'current_price': current_price,
                'timestamp': datetime.now(),

                # CNN features (100-dim feature vector)
                'cnn_features': self._build_cnn_features(candles_1m, candles_1s),

                # DQN state (state vector for RL)
                'dqn_state': self._build_dqn_state(candles_1m),

                # COB features (order book data)
                'cob_features': await self._build_cob_features(symbol)
            }

            return features

        except Exception as e:
            logger.error(f"Error extracting market features for {symbol}: {e}")
            return None

    def _build_cnn_features(self, candles_1m: List, candles_1s: List) -> Optional[np.ndarray]:
        """Build feature vector for CNN model"""
        try:
            if not candles_1m or len(candles_1m) < 10:
                return None

            # Extract OHLCV data: 20 candles x 5 fields = 100 values, matching the padded length below
            features = []

            for candle in candles_1m[-20:]:  # Last 20 candles
                features.extend([
                    candle.get('open', 0),
                    candle.get('high', 0),
                    candle.get('low', 0),
                    candle.get('close', 0),
                    candle.get('volume', 0)
                ])

            # Pad or truncate to expected size (100 features)
            feature_array = np.array(features)
            if len(feature_array) < 100:
                feature_array = np.pad(feature_array, (0, 100 - len(feature_array)))
            else:
                feature_array = feature_array[:100]

            return feature_array.astype(np.float32)

        except Exception as e:
            logger.error(f"Error building CNN features: {e}")
            return None

    def _build_dqn_state(self, candles: List) -> Optional[np.ndarray]:
        """Build state vector for DQN agent"""
        try:
            if not candles or len(candles) < 5:
                return None

            # Simple state: last 5 close prices normalized
            closes = [c.get('close', 0) for c in candles[-5:]]

            if max(closes) == 0:
                return None

            # Normalize
            state = np.array(closes) / max(closes)

            return state.astype(np.float32)

        except Exception as e:
            logger.error(f"Error building DQN state: {e}")
            return None

    async def _build_cob_features(self, symbol: str) -> Optional[Dict]:
        """Build COB (order book) features"""
        try:
            # Get order book data if available
            if hasattr(self.orchestrator, 'get_cob_data'):
                cob_data = await self.orchestrator.get_cob_data(symbol)
                return cob_data

            return None

        except Exception as e:
            logger.error(f"Error building COB features: {e}")
            return None

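    # Expected prediction payloads, as consumed by _combine_predictions() below
    # (inferred from the .get() calls there, not a formally enforced schema):
    #   predictions['cnn']     -> {'action': 'BUY'|'SELL'|'HOLD', 'confidence': float}
    #   predictions['dqn']     -> {'action': int, 'action_name': 'SELL'|'HOLD'|'BUY'} (built above)
    #   predictions['cob_rl']  -> {'action': 'BUY'|'SELL'|'HOLD', 'confidence': float}
    #   predictions['transformer'] is carried in the signal but currently gets no vote.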
    def _combine_predictions(self, symbol: str, predictions: Dict) -> Optional[Dict]:
        """Combine predictions from multiple models into a trading signal"""
        try:
            # Voting system: each model contributes
            votes = {'BUY': 0, 'SELL': 0, 'HOLD': 0}
            confidences = []

            # CNN vote
            if 'cnn' in predictions:
                cnn_action = predictions['cnn'].get('action', 'HOLD')
                cnn_conf = predictions['cnn'].get('confidence', 0.5)
                votes[cnn_action] += cnn_conf
                confidences.append(cnn_conf)

            # DQN vote
            if 'dqn' in predictions:
                dqn_action = predictions['dqn'].get('action_name', 'HOLD')
                votes[dqn_action] += 0.7  # Fixed confidence for DQN
                confidences.append(0.7)

            # COB RL vote
            if 'cob_rl' in predictions:
                cob_action = predictions['cob_rl'].get('action', 'HOLD')
                cob_conf = predictions['cob_rl'].get('confidence', 0.5)
                votes[cob_action] += cob_conf
                confidences.append(cob_conf)

            # Determine final action (highest confidence-weighted vote)
            final_action = max(votes, key=votes.get)
            final_confidence = sum(confidences) / len(confidences) if confidences else 0.5

            # Only signal if confidence is high enough
            if final_confidence < 0.6:
                logger.debug(f"Low confidence ({final_confidence:.2f}) - no signal")
                return None

            return {
                'symbol': symbol,
                'action': final_action,
                'confidence': final_confidence,
                # Current price is not available in this scope; signal consumers
                # should fill it from live market data
                'price': None,
                'models_used': list(predictions.keys()),
                'predictions': predictions,
                'timestamp': datetime.now()
            }

        except Exception as e:
            logger.error(f"Error combining predictions: {e}")
            return None
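

# Worked example of the voting logic in _combine_predictions (illustrative numbers only):
#   CNN    -> {'action': 'BUY',  'confidence': 0.8}  => votes['BUY']  += 0.8
#   DQN    -> action_name 'BUY' (fixed weight)       => votes['BUY']  += 0.7
#   COB RL -> {'action': 'HOLD', 'confidence': 0.5}  => votes['HOLD'] += 0.5
#   final_action = 'BUY' (1.5 vs 0.5), final_confidence = (0.8 + 0.7 + 0.5) / 3 ≈ 0.67
#   0.67 >= 0.6, so a BUY signal is emitted to orchestrator.process_signal().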