Files
gogo2/core/realtime_prediction_loop.py
2025-10-01 00:30:37 +03:00

412 lines
18 KiB
Python

"""
Real-Time Prediction Loop
CRITICAL: This is the MISSING PIECE - continuous model inference on incoming market data
This module monitors market data and triggers model predictions on:
- New 1s candles
- New 1m candles
- Pivot points detected
- Significant price movements
NO SYNTHETIC DATA - Only real market data triggers predictions
"""
import logging
import asyncio
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
import numpy as np
from .unified_model_data_interface import UnifiedModelDataInterface, ModelInputData
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class RealtimePredictionLoop:
    """
    Continuously monitors market data and triggers model predictions.

    One polling task per symbol checks roughly every 100 ms for:
    a new 1s/1m candle, a 5-candle pivot on the 1m timeframe, and a periodic
    timer. Each trigger runs every model the orchestrator exposes and, when the
    combined vote is confident enough, forwards a trading signal to
    ``orchestrator.process_signal`` (if the orchestrator provides it).
    """

    # Seconds between periodic predictions (independent of candle/pivot triggers).
    PERIODIC_PREDICTION_INTERVAL_SECONDS = 5
    # A prepared model input is used only if its data_quality_score exceeds this.
    MIN_DATA_QUALITY = 0.5
    # Average model confidence required before a trading signal is emitted.
    MIN_SIGNAL_CONFIDENCE = 0.6
    # Vote weight given to the DQN, whose prediction here carries no confidence value.
    DQN_VOTE_CONFIDENCE = 0.7
    # DQN action index -> action name.
    DQN_ACTION_NAMES = ('SELL', 'HOLD', 'BUY')

    def __init__(self, orchestrator, data_provider):
        """
        Args:
            orchestrator: Object exposing ``config`` (read with ``.get``) and,
                optionally, ``cnn_model``, ``dqn_agent``, ``cob_rl_model``,
                ``transformer_model`` and an async ``process_signal``.
            data_provider: Object with an async
                ``get_latest_candles(symbol, timeframe, limit=...)``.
        """
        self.orchestrator = orchestrator
        self.data_provider = data_provider
        self.unified_data_interface = UnifiedModelDataInterface(data_provider, orchestrator.config)
        self.running = False
        # symbol -> datetime of the last periodic prediction
        self.last_prediction_time = {}
        # NOTE(review): not read by this class; presumably kept for external callers.
        self.prediction_interval_seconds = {
            '1s': 1,  # Predict every second
            '1m': 60,  # Predict every minute
            'pivot': 0  # Predict immediately on pivot
        }
        # symbol -> {timeframe: timestamp of the newest candle already seen}
        self.last_candle_close_time = {}
        # symbol -> identity of the last pivot candle reported, so one pivot fires once
        self.last_pivot_marker = {}
        logger.info("Real-time Prediction Loop initialized with unified data interface")

    async def start(self):
        """
        Start one prediction task per configured symbol and wait for all of them.

        Returns once every per-symbol loop has exited, i.e. after ``stop()``.
        """
        self.running = True
        logger.info("🔄 Starting Real-Time Prediction Loop")
        symbols = self.orchestrator.config.get('symbols', ['ETH/USDT', 'BTC/USDT'])
        tasks = [asyncio.create_task(self._prediction_loop_for_symbol(symbol)) for symbol in symbols]
        await asyncio.gather(*tasks)

    def stop(self):
        """Ask all per-symbol loops to exit after their current iteration."""
        self.running = False
        logger.info("Real-Time Prediction Loop stopped")

    async def _prediction_loop_for_symbol(self, symbol: str):
        """Poll *symbol* every ~100 ms and run predictions on any trigger."""
        logger.info(f"🔄 Prediction loop started for {symbol}")
        while self.running:
            try:
                # 1. New candle (1s or 1m)
                new_candle_detected, timeframe = await self._detect_new_candle(symbol)
                if new_candle_detected:
                    logger.info(f"📊 New {timeframe} candle detected for {symbol} - running predictions")
                    await self._run_all_model_predictions(symbol, trigger=f"new_{timeframe}_candle")
                # 2. Pivot point (reported once per pivot candle)
                if await self._detect_pivot_point(symbol):
                    logger.info(f"📍 Pivot point detected for {symbol} - running predictions")
                    await self._run_all_model_predictions(symbol, trigger="pivot_point")
                # 3. Periodic prediction
                if self._should_run_periodic_prediction(symbol):
                    logger.debug(f"⏰ Periodic prediction for {symbol}")
                    await self._run_all_model_predictions(symbol, trigger="periodic")
                # Sleep briefly to avoid CPU overuse
                await asyncio.sleep(0.1)
            except Exception as e:
                logger.error(f"Error in prediction loop for {symbol}: {e}")
                # Back off a little longer after an error
                await asyncio.sleep(1)

    async def _detect_new_candle(self, symbol: str) -> tuple:
        """
        Detect whether the newest 1s or 1m candle has changed since the last call.

        The 1s timeframe is checked first; at most one timeframe is reported per
        call (a simultaneous 1m change is reported on the next call).

        Returns:
            ``(True, timeframe)`` when a new candle was seen, else ``(False, None)``.
        """
        try:
            candles_1s = await self.data_provider.get_latest_candles(symbol, '1s', limit=2)
            candles_1m = await self.data_provider.get_latest_candles(symbol, '1m', limit=2)
            # Initialise per-symbol state before either check so the 1m check
            # still works when no 1s data is available.
            seen = self.last_candle_close_time.setdefault(symbol, {})
            for timeframe, candles in (('1s', candles_1s), ('1m', candles_1m)):
                if not candles or len(candles) < 2:
                    continue
                latest_time = candles[-1].get('timestamp') or candles[-1].get('time')
                if latest_time and latest_time != seen.get(timeframe):
                    seen[timeframe] = latest_time
                    return True, timeframe
            return False, None
        except Exception as e:
            logger.error(f"Error detecting new candle for {symbol}: {e}")
            return False, None

    async def _detect_pivot_point(self, symbol: str) -> bool:
        """
        Detect a new simple pivot on the last five 1m candles.

        The middle candle is a pivot high if its high is the maximum of the
        window, or a pivot low if its low is the minimum. Each pivot candle is
        reported only once, because this is polled every ~100 ms and the window
        stays the same until a new 1m candle arrives. Series with missing
        highs/lows are ignored rather than filled with a synthetic value.
        """
        try:
            recent_candles = await self.data_provider.get_latest_candles(symbol, '1m', limit=5)
            if not recent_candles or len(recent_candles) < 5:
                return False
            window = recent_candles[-5:]
            highs = [c.get('high') for c in window]
            lows = [c.get('low') for c in window]
            is_pivot_high = None not in highs and highs[2] == max(highs)
            is_pivot_low = None not in lows and lows[2] == min(lows)
            if not (is_pivot_high or is_pivot_low):
                return False
            middle = window[2]
            # Identify the pivot candle; fall back to its extremes when it has no timestamp.
            marker = middle.get('timestamp') or middle.get('time') or (highs[2], lows[2])
            if self.last_pivot_marker.get(symbol) == marker:
                return False  # already reported this pivot
            self.last_pivot_marker[symbol] = marker
            if is_pivot_high:
                logger.debug(f"Pivot HIGH detected for {symbol}")
            else:
                logger.debug(f"Pivot LOW detected for {symbol}")
            return True
        except Exception as e:
            logger.error(f"Error detecting pivot for {symbol}: {e}")
            return False

    def _should_run_periodic_prediction(self, symbol: str) -> bool:
        """Return True (and restart the timer) if the periodic interval has elapsed."""
        current_time = datetime.now()
        last_time = self.last_prediction_time.get(symbol)
        if last_time is not None and \
                (current_time - last_time).total_seconds() < self.PERIODIC_PREDICTION_INTERVAL_SECONDS:
            return False
        self.last_prediction_time[symbol] = current_time
        return True

    def _prepare_model_data(self, symbol: str, model_type: str, window_size: int):
        """
        Build the model-specific input for *model_type* via the unified data interface.

        Returns:
            ``(data, quality_score)``; ``data`` is None when the prepared input is
            missing, its quality does not exceed ``MIN_DATA_QUALITY``, or no
            model-specific input could be produced.
        """
        model_input = self.unified_data_interface.prepare_model_input(symbol, model_type, window_size=window_size)
        if not model_input or model_input.data_quality_score <= self.MIN_DATA_QUALITY:
            return None, None
        data = self.unified_data_interface.get_model_specific_input(model_input, model_type)
        return data, model_input.data_quality_score

    def _run_predict_model(self, symbol: str, predictions: Dict, *, key: str, model_attr: str,
                           label: str, window_size: int, require_predict_method: bool) -> None:
        """
        Run a ``predict()``-style model and store its raw output in ``predictions[key]``.

        Errors are logged and swallowed so one failing model does not stop the others.
        """
        model = getattr(self.orchestrator, model_attr, None)
        if not model:
            return
        try:
            model_data, quality = self._prepare_model_data(symbol, key, window_size)
            if model_data is None:
                return
            if require_predict_method and not hasattr(model, 'predict'):
                return
            prediction = model.predict(model_data)
            predictions[key] = prediction
            logger.info(f"✅ {label} prediction for {symbol}: {prediction} (quality: {quality:.2f})")
        except Exception as e:
            logger.error(f"{label} prediction error for {symbol}: {e}")

    def _run_dqn_model(self, symbol: str, predictions: Dict) -> None:
        """Run the DQN agent greedily (``explore=False``) and store its action in ``predictions['dqn']``."""
        agent = getattr(self.orchestrator, 'dqn_agent', None)
        if not agent:
            return
        try:
            dqn_data, quality = self._prepare_model_data(symbol, 'dqn', 100)
            if dqn_data is None:
                return
            action = agent.act(dqn_data, explore=False)
            action_index = int(action)
            # Reject out-of-range indices (a negative index would otherwise wrap silently).
            if not 0 <= action_index < len(self.DQN_ACTION_NAMES):
                raise ValueError(f"out-of-range DQN action {action!r}")
            predictions['dqn'] = {
                'action': action,
                'action_name': self.DQN_ACTION_NAMES[action_index]
            }
            logger.info(f"✅ DQN prediction for {symbol}: {predictions['dqn']['action_name']} (quality: {quality:.2f})")
        except Exception as e:
            logger.error(f"DQN prediction error for {symbol}: {e}")

    async def _get_current_price(self, symbol: str) -> Optional[float]:
        """Return the close of the latest 1m candle, or None if it is unavailable (best-effort)."""
        try:
            candles = await self.data_provider.get_latest_candles(symbol, '1m', limit=1)
        except Exception as e:
            logger.warning(f"Could not fetch current price for {symbol}: {e}")
            return None
        if not candles:
            return None
        return candles[-1].get('close')

    async def _run_all_model_predictions(self, symbol: str, trigger: str = "unknown"):
        """
        CRITICAL: Run predictions from ALL models on current market data.

        Models run in order CNN, DQN, COB RL, Transformer. Any that produced a
        prediction are combined into a trading signal, which is awaited on
        ``orchestrator.process_signal`` when available.

        NOTE(review): model ``predict``/``act`` calls are synchronous and run on the
        event-loop thread, so a slow model delays every symbol's loop.

        Args:
            symbol: Trading symbol.
            trigger: Label of what caused this run (only used in logs).
        """
        try:
            predictions = {}
            self._run_predict_model(symbol, predictions, key='cnn', model_attr='cnn_model',
                                    label='CNN', window_size=60, require_predict_method=False)
            self._run_dqn_model(symbol, predictions)
            self._run_predict_model(symbol, predictions, key='cob_rl', model_attr='cob_rl_model',
                                    label='COB RL', window_size=50, require_predict_method=True)
            self._run_predict_model(symbol, predictions, key='transformer', model_attr='transformer_model',
                                    label='Transformer', window_size=150, require_predict_method=True)
            if not predictions:
                return
            market_data = {'current_price': await self._get_current_price(symbol)}
            trading_signal = self._combine_predictions(symbol, predictions, market_data)
            if trading_signal and hasattr(self.orchestrator, 'process_signal'):
                await self.orchestrator.process_signal(trading_signal)
                logger.info(f"📤 Trading signal sent for {symbol}: {trading_signal['action']} "
                            f"(confidence: {trading_signal['confidence']:.2f}, trigger: {trigger})")
        except Exception as e:
            logger.error(f"Error running model predictions for {symbol}: {e}")

    async def _extract_market_features(self, symbol: str) -> Optional[Dict]:
        """
        Extract legacy hand-built features for all models from recent candles.

        NOTE(review): not called by this class; the prediction path uses the
        unified data interface instead.
        """
        try:
            candles_1m = await self.data_provider.get_latest_candles(symbol, '1m', limit=100)
            candles_1s = await self.data_provider.get_latest_candles(symbol, '1s', limit=100)
            if not candles_1m:
                return None
            current_price = candles_1m[-1].get('close', 0)
            return {
                'symbol': symbol,
                'current_price': current_price,
                'timestamp': datetime.now(),
                # CNN features (100-dim feature vector)
                'cnn_features': self._build_cnn_features(candles_1m, candles_1s),
                # DQN state (state vector for RL)
                'dqn_state': self._build_dqn_state(candles_1m),
                # COB features (order book data)
                'cob_features': await self._build_cob_features(symbol)
            }
        except Exception as e:
            logger.error(f"Error extracting market features for {symbol}: {e}")
            return None

    def _build_cnn_features(self, candles_1m: List, candles_1s: List) -> Optional[np.ndarray]:
        """
        Flatten OHLCV of the last 20 1m candles into a 100-element float32 vector.

        Zero-padded at the end when fewer than 20 candles are available; returns
        None when fewer than 10 are. ``candles_1s`` is currently unused.
        """
        try:
            if not candles_1m or len(candles_1m) < 10:
                return None
            features = []
            for candle in candles_1m[-20:]:
                features.extend([
                    candle.get('open', 0),
                    candle.get('high', 0),
                    candle.get('low', 0),
                    candle.get('close', 0),
                    candle.get('volume', 0)
                ])
            feature_array = np.array(features)
            if len(feature_array) < 100:
                feature_array = np.pad(feature_array, (0, 100 - len(feature_array)))
            else:
                feature_array = feature_array[:100]
            return feature_array.astype(np.float32)
        except Exception as e:
            logger.error(f"Error building CNN features: {e}")
            return None

    def _build_dqn_state(self, candles: List) -> Optional[np.ndarray]:
        """Return the last 5 closes divided by their maximum (float32), or None if unavailable."""
        try:
            if not candles or len(candles) < 5:
                return None
            closes = [c.get('close', 0) for c in candles[-5:]]
            if max(closes) == 0:
                return None
            state = np.array(closes) / max(closes)
            return state.astype(np.float32)
        except Exception as e:
            logger.error(f"Error building DQN state: {e}")
            return None

    async def _build_cob_features(self, symbol: str) -> Optional[Dict]:
        """Return order-book (COB) data from ``orchestrator.get_cob_data`` if it exists, else None."""
        try:
            if hasattr(self.orchestrator, 'get_cob_data'):
                return await self.orchestrator.get_cob_data(symbol)
            return None
        except Exception as e:
            logger.error(f"Error building COB features: {e}")
            return None

    def _combine_predictions(self, symbol: str, predictions: Dict,
                             market_data: Optional[Dict] = None) -> Optional[Dict]:
        """
        Combine per-model predictions into one trading signal by weighted vote.

        Each voting model adds its confidence to the action it predicts:

        * CNN and COB RL: dicts with ``action`` (BUY/SELL/HOLD, default HOLD)
          and optional ``confidence`` (default 0.5).
        * DQN: its ``action_name`` with a fixed ``DQN_VOTE_CONFIDENCE``.

        Predictions of any other shape, or with an unknown action, are skipped
        instead of aborting the vote. Transformer predictions are passed through
        in ``predictions`` but do not vote. An exact tie for the top score
        yields HOLD.

        Args:
            symbol: Trading symbol.
            predictions: Mapping of model key -> raw prediction.
            market_data: Optional snapshot; its ``current_price`` becomes the
                signal's ``price`` (None when not supplied).

        Returns:
            Signal dict, or None when the average confidence is below
            ``MIN_SIGNAL_CONFIDENCE`` or combining fails.
        """
        try:
            votes = {'BUY': 0.0, 'SELL': 0.0, 'HOLD': 0.0}
            confidences = []

            def cast_vote(model_key, action, confidence):
                if not isinstance(action, str) or action.upper() not in votes:
                    logger.debug(f"Ignoring {model_key} vote with unknown action {action!r}")
                    return
                try:
                    weight = float(confidence)
                except (TypeError, ValueError):
                    logger.debug(f"Ignoring {model_key} vote with invalid confidence {confidence!r}")
                    return
                votes[action.upper()] += weight
                confidences.append(weight)

            cnn = predictions.get('cnn')
            if isinstance(cnn, dict):
                cast_vote('cnn', cnn.get('action', 'HOLD'), cnn.get('confidence', 0.5))
            elif cnn is not None:
                logger.debug(f"Ignoring non-dict CNN prediction: {cnn!r}")

            dqn = predictions.get('dqn')
            if isinstance(dqn, dict):
                cast_vote('dqn', dqn.get('action_name', 'HOLD'), self.DQN_VOTE_CONFIDENCE)

            cob = predictions.get('cob_rl')
            if isinstance(cob, dict):
                cast_vote('cob_rl', cob.get('action', 'HOLD'), cob.get('confidence', 0.5))
            elif cob is not None:
                logger.debug(f"Ignoring non-dict COB RL prediction: {cob!r}")

            best_score = max(votes.values())
            leaders = [action for action, score in votes.items() if score == best_score]
            # A dead heat is not a majority: HOLD instead of favouring dict order (BUY).
            final_action = leaders[0] if len(leaders) == 1 else 'HOLD'
            final_confidence = sum(confidences) / len(confidences) if confidences else 0.5

            if final_confidence < self.MIN_SIGNAL_CONFIDENCE:
                logger.debug(f"Low confidence ({final_confidence:.2f}) - no signal")
                return None

            return {
                'symbol': symbol,
                'action': final_action,
                'confidence': final_confidence,
                'price': market_data.get('current_price') if market_data else None,
                'models_used': list(predictions.keys()),
                'predictions': predictions,
                'timestamp': datetime.now()
            }
        except Exception as e:
            logger.error(f"Error combining predictions: {e}")
            return None