Files
gogo2/core/multi_horizon_prediction_manager.py
Dobromir Popov 270ba2e52b fix broken merge
2025-10-08 20:02:41 +03:00

714 lines
30 KiB
Python

#!/usr/bin/env python3
"""
Multi-Horizon Prediction Manager
This module generates predictions for multiple time horizons (1m, 5m, 15m, 60m)
every minute, focusing on predicting min/max prices in the next 60 minutes.
It stores model input snapshots for future training when outcomes are known.
"""
import logging
import threading
import time
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass, field
import numpy as np
import pandas as pd
from collections import deque
logger = logging.getLogger(__name__)
@dataclass
class PredictionSnapshot:
    """Record of one min/max price forecast together with the inputs behind it.

    The outcome fields keep their defaults when the snapshot is created and
    are filled in later, once the forecast has been validated.
    """

    # --- identity and timing ---
    prediction_id: str
    symbol: str
    prediction_time: datetime
    target_horizon_minutes: int
    target_time: datetime

    # --- forecast ---
    current_price: float
    predicted_min_price: float
    predicted_max_price: float
    confidence: float

    # --- inputs captured for later training ---
    model_inputs: Dict[str, Any]
    market_state: Dict[str, Any]
    technical_indicators: Dict[str, Any]
    pivot_analysis: Dict[str, Any]
    prediction_metadata: Dict[str, Any] = field(default_factory=dict)

    # --- outcome, unknown until validation ---
    actual_min_price: Optional[float] = None
    actual_max_price: Optional[float] = None
    outcome_known: bool = False
    outcome_timestamp: Optional[datetime] = None
@dataclass
class HorizonPrediction:
    """Predicted price range for one look-ahead horizon."""

    horizon_minutes: int
    predicted_min: float
    predicted_max: float
    confidence: float
    # Source of the forecast: 'cnn', 'rl', 'technical', 'ensemble'
    prediction_basis: str
class MultiHorizonPredictionManager:
    """Manages multi-timeframe predictions for trading system"""

    def __init__(self, orchestrator=None, data_provider=None, config: Optional[Dict[str, Any]] = None):
        """Store collaborators and set up per-horizon storage and counters.

        Args:
            orchestrator: Optional object that may expose models used for
                predictions; kept as-is.
            data_provider: Optional market-data source; kept as-is.
            config: Optional settings mapping; a falsy value becomes ``{}``.
        """
        self.orchestrator = orchestrator
        self.data_provider = data_provider
        self.config = config if config else {}

        # Look-ahead horizons (minutes) and how often predictions are produced.
        self.horizons = [1, 5, 15, 60]
        self.prediction_interval_seconds = 60

        # One bounded queue of PredictionSnapshot objects per horizon.
        self.max_snapshots_per_horizon = 1000
        self.prediction_snapshots: Dict[int, deque] = {
            horizon: deque(maxlen=self.max_snapshots_per_horizon)
            for horizon in self.horizons
        }

        # Background worker state.
        self.prediction_thread = None
        self.is_running = False
        self.last_prediction_time = 0.0

        # Running counters.
        self.prediction_stats = {
            'total_predictions': 0,
            'predictions_by_horizon': {h: 0 for h in self.horizons},
            'validated_predictions': 0,
            'accurate_predictions': 0,
            'avg_confidence': 0.0,
            'last_prediction_time': None
        }

        # Predictions below this confidence are not stored.
        self.min_confidence_threshold = 0.3

        logger.info("MultiHorizonPredictionManager initialized")
        logger.info(f"Prediction horizons: {self.horizons} minutes")
        logger.info(f"Prediction interval: {self.prediction_interval_seconds} seconds")
def start(self):
    """Launch the background prediction thread; warn and return if already running."""
    if self.is_running:
        logger.warning("Prediction manager already running")
        return

    self.is_running = True
    worker = threading.Thread(
        name="MultiHorizonPredictor",
        target=self._prediction_loop,
        daemon=True,
    )
    self.prediction_thread = worker
    worker.start()
    logger.info("MultiHorizonPredictionManager started")
def stop(self):
    """Signal the prediction loop to exit and wait up to 10 seconds for the thread."""
    self.is_running = False
    worker = self.prediction_thread
    if worker and worker.is_alive():
        worker.join(timeout=10)
    logger.info("MultiHorizonPredictionManager stopped")
def _prediction_loop(self):
"""Main prediction loop - runs every minute"""
while self.is_running:
try:
current_time = time.time()
# Check if it's time for new predictions
if current_time - self.last_prediction_time >= self.prediction_interval_seconds:
self._generate_all_horizon_predictions()
self.last_prediction_time = current_time
# Validate pending predictions
self._validate_pending_predictions()
# Sleep for 10 seconds before next check
time.sleep(10)
except Exception as e:
logger.error(f"Error in prediction loop: {e}")
time.sleep(30) # Longer sleep on error
def _generate_all_horizon_predictions(self):
"""Generate predictions for all horizons"""
try:
symbols = ['ETH/USDT', 'BTC/USDT'] # Focus on main symbols
prediction_time = datetime.now()
for symbol in symbols:
# Get current market state
market_state = self._get_current_market_state(symbol)
if not market_state:
continue
current_price = market_state['current_price']
# Generate predictions for each horizon
for horizon_minutes in self.horizons:
try:
prediction = self._generate_horizon_prediction(
symbol, horizon_minutes, prediction_time, market_state
)
if prediction and prediction.confidence >= self.min_confidence_threshold:
# Create prediction snapshot
snapshot = self._create_prediction_snapshot(
symbol, horizon_minutes, prediction_time, current_price,
prediction, market_state
)
# Store snapshot
self.prediction_snapshots[horizon_minutes].append(snapshot)
# Update stats
self.prediction_stats['total_predictions'] += 1
self.prediction_stats['predictions_by_horizon'][horizon_minutes] += 1
logger.info(f"Generated {horizon_minutes}m prediction for {symbol}: "
f"min={prediction.predicted_min:.4f}, max={prediction.predicted_max:.4f}, "
f"confidence={prediction.confidence:.2f}")
except Exception as e:
logger.error(f"Error generating {horizon_minutes}m prediction for {symbol}: {e}")
self.prediction_stats['last_prediction_time'] = prediction_time
except Exception as e:
logger.error(f"Error generating all horizon predictions: {e}")
def _generate_horizon_prediction(self, symbol: str, horizon_minutes: int,
prediction_time: datetime, market_state: Dict[str, Any]) -> Optional[HorizonPrediction]:
"""Generate prediction for a specific horizon"""
try:
current_price = market_state['current_price']
# Use ensemble approach: combine CNN, RL, and technical analysis
predictions = []
# CNN-based prediction
cnn_prediction = self._get_cnn_prediction(symbol, horizon_minutes, market_state)
if cnn_prediction:
predictions.append(cnn_prediction)
# RL-based prediction
rl_prediction = self._get_rl_prediction(symbol, horizon_minutes, market_state)
if rl_prediction:
predictions.append(rl_prediction)
# Technical analysis prediction
technical_prediction = self._get_technical_prediction(symbol, horizon_minutes, market_state)
if technical_prediction:
predictions.append(technical_prediction)
if not predictions:
# Fallback to technical analysis only
return self._get_technical_prediction(symbol, horizon_minutes, market_state, fallback=True)
# Ensemble prediction
return self._ensemble_predictions(predictions, current_price)
except Exception as e:
logger.error(f"Error generating horizon prediction: {e}")
return None
def _get_cnn_prediction(self, symbol: str, horizon_minutes: int,
market_state: Dict[str, Any]) -> Optional[HorizonPrediction]:
"""Get CNN-based prediction"""
try:
if not self.orchestrator or not hasattr(self.orchestrator, 'cnn_model'):
return None
# Prepare CNN features based on horizon
features = self._prepare_cnn_features_for_horizon(market_state, horizon_minutes)
# Get CNN prediction
cnn_model = self.orchestrator.cnn_model
prediction_output = cnn_model.predict(features)
# Interpret CNN output for min/max prediction
predicted_min, predicted_max, confidence = self._interpret_cnn_output(
prediction_output, market_state['current_price'], horizon_minutes
)
return HorizonPrediction(
horizon_minutes=horizon_minutes,
predicted_min=predicted_min,
predicted_max=predicted_max,
confidence=confidence,
prediction_basis='cnn'
)
except Exception as e:
logger.debug(f"CNN prediction failed: {e}")
return None
def _get_rl_prediction(self, symbol: str, horizon_minutes: int,
market_state: Dict[str, Any]) -> Optional[HorizonPrediction]:
"""Get RL-based prediction"""
try:
if not self.orchestrator or not hasattr(self.orchestrator, 'rl_agent'):
return None
# Prepare RL state
rl_state = self._prepare_rl_state_for_horizon(market_state, horizon_minutes)
# Get RL prediction
rl_agent = self.orchestrator.rl_agent
action = rl_agent.act(rl_state, explore=False)
# Convert action to min/max prediction
current_price = market_state['current_price']
predicted_min, predicted_max, confidence = self._convert_rl_action_to_price_prediction(
action, current_price, horizon_minutes, rl_agent
)
return HorizonPrediction(
horizon_minutes=horizon_minutes,
predicted_min=predicted_min,
predicted_max=predicted_max,
confidence=confidence,
prediction_basis='rl'
)
except Exception as e:
logger.debug(f"RL prediction failed: {e}")
return None
def _get_technical_prediction(self, symbol: str, horizon_minutes: int,
                              market_state: Dict[str, Any], fallback: bool = False) -> Optional["HorizonPrediction"]:
    """Derive a price range from volatility and the pivot trend.

    The expected move is ``volatility * sqrt(horizon_minutes / 60)``. An
    UPTREND or DOWNTREND skews the range toward the trend; anything else
    gets a symmetric range. Confidence grows with trend strength and
    shrinks with the horizon; with ``fallback=True`` it is floored at 0.2.
    Returns None when an exception occurs.
    """
    try:
        current_price = market_state['current_price']
        pivots = market_state.get('pivot_analysis', {})
        indicators = market_state.get('technical_indicators', {})

        trend_direction = pivots.get('trend_direction', 'SIDEWAYS')
        trend_strength = pivots.get('trend_strength', 0.0)

        # Default volatility is 2%; scaled by sqrt(time).
        volatility = indicators.get('volatility', 0.02)
        expected_move = volatility * np.sqrt(horizon_minutes / 60.0)

        # (downside multiplier, upside multiplier) per trend direction.
        skew = {'UPTREND': (0.3, 1.2), 'DOWNTREND': (1.2, 0.3)}.get(trend_direction)
        if skew is None:
            # Symmetric range for sideways markets.
            half_width = expected_move * current_price
            low = current_price - half_width
            high = current_price + half_width
        else:
            down_mult, up_mult = skew
            low = current_price * (1 - expected_move * down_mult)
            high = current_price * (1 + expected_move * up_mult)

        # 0.4 to 0.8 for trend strength in [0, 1].
        base_confidence = 0.4 + (trend_strength * 0.4)
        # Longer horizons are discounted, never below a factor of 0.3.
        horizon_factor = max(0.3, 1.0 - (horizon_minutes - 1) / 120.0)
        confidence = base_confidence * horizon_factor
        if fallback:
            confidence = max(confidence, 0.2)

        return HorizonPrediction(
            horizon_minutes=horizon_minutes,
            predicted_min=low,
            predicted_max=high,
            confidence=confidence,
            prediction_basis='technical'
        )
    except Exception as e:
        logger.error(f"Technical prediction failed: {e}")
        return None
def _ensemble_predictions(self, predictions: List[HorizonPrediction], current_price: float) -> HorizonPrediction:
"""Combine multiple predictions into ensemble prediction"""
try:
if not predictions:
return None
# Weight predictions by confidence
total_weight = sum(p.confidence for p in predictions)
if total_weight == 0:
total_weight = len(predictions)
# Weighted average of min/max predictions
weighted_min = sum(p.predicted_min * p.confidence for p in predictions) / total_weight
weighted_max = sum(p.predicted_max * p.confidence for p in predictions) / total_weight
# Average confidence
avg_confidence = sum(p.confidence for p in predictions) / len(predictions)
# Ensure min < max and reasonable bounds
if weighted_min >= weighted_max:
# Fallback to symmetric range
range_half = abs(current_price * 0.02) # 2% range
weighted_min = current_price - range_half
weighted_max = current_price + range_half
return HorizonPrediction(
horizon_minutes=predictions[0].horizon_minutes,
predicted_min=weighted_min,
predicted_max=weighted_max,
confidence=min(avg_confidence, 0.95), # Cap at 95%
prediction_basis='ensemble'
)
except Exception as e:
logger.error(f"Ensemble prediction failed: {e}")
return None
def _get_current_market_state(self, symbol: str) -> Optional[Dict[str, Any]]:
"""Get comprehensive market state for prediction"""
try:
if not self.data_provider:
return None
# Get current price
current_price = None
if hasattr(self.data_provider, 'current_prices'):
current_price = self.data_provider.current_prices.get(symbol.replace('/', '').upper())
if current_price is None:
logger.debug(f"No current price available for {symbol}")
return None
# Get recent OHLCV data (last 100 candles for analysis)
ohlcv_data = self.data_provider.get_historical_data(symbol, '1m', limit=100)
if ohlcv_data is None or len(ohlcv_data) < 20:
logger.debug(f"Insufficient OHLCV data for {symbol}")
return None
# Calculate technical indicators
technical_indicators = self._calculate_technical_indicators(ohlcv_data)
# Get pivot analysis
pivot_analysis = self._get_pivot_analysis(symbol, ohlcv_data)
return {
'current_price': current_price,
'ohlcv_data': ohlcv_data,
'technical_indicators': technical_indicators,
'pivot_analysis': pivot_analysis,
'timestamp': datetime.now()
}
except Exception as e:
logger.error(f"Error getting market state for {symbol}: {e}")
return None
def _calculate_technical_indicators(self, ohlcv_data: np.ndarray) -> Dict[str, Any]:
"""Calculate technical indicators from OHLCV data"""
try:
if len(ohlcv_data) < 20:
return {}
closes = ohlcv_data[:, 4].astype(float)
highs = ohlcv_data[:, 2].astype(float)
lows = ohlcv_data[:, 3].astype(float)
volumes = ohlcv_data[:, 5].astype(float)
# Basic indicators
sma_5 = np.mean(closes[-5:])
sma_20 = np.mean(closes[-20:])
# RSI
def calculate_rsi(prices, period=14):
if len(prices) < period + 1:
return 50.0
gains = []
losses = []
for i in range(1, min(len(prices), period + 1)):
change = prices[-i] - prices[-i-1]
if change > 0:
gains.append(change)
losses.append(0)
else:
gains.append(0)
losses.append(abs(change))
avg_gain = np.mean(gains) if gains else 0
avg_loss = np.mean(losses) if losses else 0
if avg_loss == 0:
return 100.0
rs = avg_gain / avg_loss
return 100 - (100 / (1 + rs))
rsi = calculate_rsi(closes)
# Volatility (standard deviation of returns)
returns = np.diff(closes) / closes[:-1]
volatility = np.std(returns) if len(returns) > 0 else 0.02
# Volume analysis
avg_volume = np.mean(volumes[-20:]) if len(volumes) >= 20 else np.mean(volumes)
volume_ratio = volumes[-1] / avg_volume if avg_volume > 0 else 1.0
return {
'sma_5': float(sma_5),
'sma_20': float(sma_20),
'rsi': float(rsi),
'volatility': float(volatility),
'volume_ratio': float(volume_ratio),
'price_change_5m': float((closes[-1] - closes[-5]) / closes[-5]) if len(closes) >= 5 else 0.0,
'price_change_15m': float((closes[-1] - closes[-15]) / closes[-15]) if len(closes) >= 15 else 0.0
}
except Exception as e:
logger.error(f"Error calculating technical indicators: {e}")
return {}
def _get_pivot_analysis(self, symbol: str, ohlcv_data: np.ndarray) -> Dict[str, Any]:
"""Get pivot point analysis"""
try:
# Use Williams Market Structure if available
if hasattr(self.orchestrator, 'williams_structure'):
pivot_levels = self.orchestrator.williams_structure.calculate_recursive_pivot_points(ohlcv_data)
if pivot_levels:
# Get the most recent level
latest_level = max(pivot_levels.keys(), key=lambda x: int(x.split('_')[1]))
level_data = pivot_levels[latest_level]
return {
'trend_direction': level_data.trend_direction,
'trend_strength': level_data.trend_strength,
'support_levels': level_data.support_levels,
'resistance_levels': level_data.resistance_levels
}
# Fallback to basic pivot analysis
if len(ohlcv_data) >= 20:
recent_highs = ohlcv_data[-20:, 2].astype(float)
recent_lows = ohlcv_data[-20:, 3].astype(float)
pivot_high = np.max(recent_highs)
pivot_low = np.min(recent_lows)
return {
'trend_direction': 'SIDEWAYS',
'trend_strength': 0.5,
'support_levels': [pivot_low],
'resistance_levels': [pivot_high]
}
return {
'trend_direction': 'SIDEWAYS',
'trend_strength': 0.0,
'support_levels': [],
'resistance_levels': []
}
except Exception as e:
logger.error(f"Error getting pivot analysis: {e}")
return {}
def _create_prediction_snapshot(self, symbol: str, horizon_minutes: int,
                                prediction_time: datetime, current_price: float,
                                prediction: "HorizonPrediction", market_state: Dict[str, Any]) -> "PredictionSnapshot":
    """Package a prediction and the market state behind it into a snapshot.

    The id has the form ``<symbol without '/'>_<horizon>m_<unix seconds>`` and
    the target time is ``prediction_time`` plus the horizon.
    """
    target_time = prediction_time + timedelta(minutes=horizon_minutes)
    prediction_id = f"{symbol.replace('/', '')}_{horizon_minutes}m_{int(prediction_time.timestamp())}"

    basis = prediction.prediction_basis
    if basis != 'ensemble':
        component_count = 1
    else:
        component_count = 3

    return PredictionSnapshot(
        prediction_id=prediction_id,
        symbol=symbol,
        prediction_time=prediction_time,
        target_horizon_minutes=horizon_minutes,
        target_time=target_time,
        current_price=current_price,
        predicted_min_price=prediction.predicted_min,
        predicted_max_price=prediction.predicted_max,
        confidence=prediction.confidence,
        model_inputs=self._extract_model_inputs(market_state),
        market_state=market_state,
        technical_indicators=market_state.get('technical_indicators', {}),
        pivot_analysis=market_state.get('pivot_analysis', {}),
        prediction_metadata={
            'prediction_basis': basis,
            'ensemble_components': component_count
        }
    )
def _extract_model_inputs(self, market_state: Dict[str, Any]) -> Dict[str, Any]:
"""Extract model inputs for future training"""
try:
model_inputs = {}
# CNN features
if hasattr(self, '_prepare_cnn_features_for_horizon'):
model_inputs['cnn_features'] = self._prepare_cnn_features_for_horizon(
market_state, 60 # Use 60m horizon for consistency
)
# RL state
if hasattr(self, '_prepare_rl_state_for_horizon'):
model_inputs['rl_state'] = self._prepare_rl_state_for_horizon(
market_state, 60
)
# Raw market data
model_inputs['current_price'] = market_state['current_price']
model_inputs['ohlcv_sequence'] = market_state['ohlcv_data'][-50:].tolist() # Last 50 candles
return model_inputs
except Exception as e:
logger.error(f"Error extracting model inputs: {e}")
return {}
def _validate_pending_predictions(self):
"""Validate predictions that have reached their target time"""
try:
current_time = datetime.now()
symbols = ['ETH/USDT', 'BTC/USDT']
for symbol in symbols:
# Get current price for validation
current_price = None
if self.data_provider and hasattr(self.data_provider, 'current_prices'):
current_price = self.data_provider.current_prices.get(symbol.replace('/', '').upper())
if current_price is None:
continue
# Check each horizon for predictions to validate
for horizon_minutes in self.horizons:
snapshots_to_validate = []
for snapshot in list(self.prediction_snapshots[horizon_minutes]):
if (not snapshot.outcome_known and
current_time >= snapshot.target_time):
# Prediction has reached target time - validate it
snapshot.actual_min_price = current_price # Simplified: current price as proxy for min
snapshot.actual_max_price = current_price # In reality, we'd need price range over the period
snapshot.outcome_known = True
snapshot.outcome_timestamp = current_time
snapshots_to_validate.append(snapshot)
# Process validated snapshots
for snapshot in snapshots_to_validate:
self._process_validated_prediction(snapshot)
except Exception as e:
logger.error(f"Error validating pending predictions: {e}")
def _process_validated_prediction(self, snapshot: "PredictionSnapshot"):
    """Update accuracy counters for a resolved prediction and log the result.

    The prediction counts as accurate when the overlap between the predicted
    and the actual price range, as reported by
    ``_calculate_range_overlap``, exceeds 0.5. Snapshots without both actual
    prices count as validated but not accurate. Exceptions are logged.

    Args:
        snapshot: Snapshot whose actual min/max prices have been filled in.
    """
    try:
        stats = self.prediction_stats
        stats['validated_predictions'] += 1

        has_outcome = (snapshot.actual_min_price is not None
                       and snapshot.actual_max_price is not None)
        if has_outcome:
            range_overlap = self._calculate_range_overlap(
                (snapshot.predicted_min_price, snapshot.predicted_max_price),
                (snapshot.actual_min_price, snapshot.actual_max_price)
            )
            if range_overlap > 0.5:  # 50% overlap threshold
                stats['accurate_predictions'] += 1

        # Training on the snapshot would be triggered here; for now the
        # running accuracy is only logged.
        accuracy_rate = (stats['accurate_predictions'] /
                         max(1, stats['validated_predictions']))
        logger.info(f"Validated {snapshot.target_horizon_minutes}m prediction for {snapshot.symbol}: "
                    f"confidence={snapshot.confidence:.2f}, accuracy_rate={accuracy_rate:.2f}")
    except Exception as e:
        logger.error(f"Error processing validated prediction: {e}")
def _calculate_range_overlap(self, range1: Tuple[float, float], range2: Tuple[float, float]) -> float:
"""Calculate overlap between two price ranges (0.0 to 1.0)"""
try:
min1, max1 = range1
min2, max2 = range2
# Find overlap
overlap_min = max(min1, min2)
overlap_max = min(max1, max2)
if overlap_max <= overlap_min:
return 0.0
overlap_size = overlap_max - overlap_min
union_size = max(max1, max2) - min(min1, min2)
return overlap_size / union_size if union_size > 0 else 0.0
except Exception:
return 0.0
def get_prediction_stats(self) -> Dict[str, Any]:
    """Return a copy of the prediction counters plus derived metrics.

    Returns:
        The counters from ``self.prediction_stats`` with two derived values:
        ``accuracy_rate`` (accurate / validated, 0.0 when nothing has been
        validated) and ``avg_confidence`` (mean confidence of the snapshots
        currently retained; left at the stored value when none are retained).
        The per-horizon counter dict is copied too, so changing the result
        does not affect the manager's own counters.
    """
    stats = dict(self.prediction_stats)
    # dict() is shallow: copy the nested per-horizon counters separately.
    stats['predictions_by_horizon'] = dict(stats.get('predictions_by_horizon', {}))

    validated = stats['validated_predictions']
    if validated > 0:
        stats['accuracy_rate'] = stats['accurate_predictions'] / validated
    else:
        stats['accuracy_rate'] = 0.0

    # Only retained snapshots are available; each queue is bounded, so older
    # predictions are not part of this mean.
    confidences = [
        snapshot.confidence
        for queue in self.prediction_snapshots.values()
        for snapshot in list(queue)
    ]
    if confidences:
        stats['avg_confidence'] = sum(confidences) / len(confidences)

    return stats
def get_recent_predictions(self, horizon_minutes: int, limit: int = 10) -> List["PredictionSnapshot"]:
    """Get the most recent stored predictions for a specific horizon.

    Args:
        horizon_minutes: Horizon whose snapshots are requested.
        limit: Maximum number of snapshots to return.

    Returns:
        Up to ``limit`` snapshots, oldest first. Empty when the horizon is
        unknown or ``limit`` is zero or negative.
    """
    # A slice of [-0:] would return everything and a negative limit would
    # drop the oldest items instead, so handle both explicitly.
    if limit <= 0:
        return []
    if horizon_minutes not in self.prediction_snapshots:
        return []
    return list(self.prediction_snapshots[horizon_minutes])[-limit:]
# Placeholder methods for CNN and RL feature preparation - to be implemented
def _prepare_cnn_features_for_horizon(self, market_state: Dict[str, Any], horizon: int) -> np.ndarray:
    """Placeholder: CNN feature extraction per horizon is not implemented yet.

    Logs a debug message and returns an empty array; no synthetic features
    are generated.
    """
    logger.debug(f"CNN feature preparation for horizon {horizon} not yet implemented")
    no_features = np.array([])
    return no_features
def _prepare_rl_state_for_horizon(self, market_state: Dict[str, Any], horizon: int) -> np.ndarray:
    """Placeholder: RL state construction per horizon is not implemented yet.

    Logs a debug message and returns an empty array; no synthetic state is
    generated.
    """
    logger.debug(f"RL state preparation for horizon {horizon} not yet implemented")
    empty_state = np.array([])
    return empty_state
def _interpret_cnn_output(self, cnn_output, current_price: float, horizon: int) -> Tuple[float, float, float]:
    """Placeholder: mapping CNN output to (min, max, confidence) is not implemented yet.

    Logs a debug message and returns all zeros rather than a synthetic
    prediction.
    """
    logger.debug(f"CNN output interpretation for horizon {horizon} not yet implemented")
    predicted_min = predicted_max = confidence = 0.0
    return (predicted_min, predicted_max, confidence)
def _convert_rl_action_to_price_prediction(self, action: int, current_price: float,
                                           horizon: int, rl_agent) -> Tuple[float, float, float]:
    """Placeholder: mapping an RL action to (min, max, confidence) is not implemented yet.

    Logs a debug message and returns all zeros rather than a synthetic
    prediction.
    """
    logger.debug(f"RL action conversion for horizon {horizon} not yet implemented")
    predicted_min = predicted_max = confidence = 0.0
    return (predicted_min, predicted_max, confidence)