#!/usr/bin/env python3
"""
Multi-Horizon Prediction Manager

This module generates predictions for multiple time horizons (1m, 5m, 15m, 60m)
every minute, focusing on predicting min/max prices in the next 60 minutes.
It stores model input snapshots for future training when outcomes are known.
"""

import logging
import threading
import time
from collections import deque
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd

logger = logging.getLogger(__name__)


@dataclass
class PredictionSnapshot:
    """Stores a prediction with model inputs for future training.

    While the outcome is pending, ``actual_min_price`` / ``actual_max_price``
    are widened with every price observed for ``symbol``; once
    ``target_time`` has passed they are frozen and ``outcome_known`` is set.
    """

    prediction_id: str
    symbol: str
    prediction_time: datetime
    target_horizon_minutes: int
    target_time: datetime
    current_price: float  # price at prediction_time
    predicted_min_price: float
    predicted_max_price: float
    confidence: float
    model_inputs: Dict[str, Any]
    market_state: Dict[str, Any]
    technical_indicators: Dict[str, Any]
    pivot_analysis: Dict[str, Any]
    prediction_metadata: Dict[str, Any] = field(default_factory=dict)
    # Observed price extremes between prediction_time and target_time
    # (None until the first observation).
    actual_min_price: Optional[float] = None
    actual_max_price: Optional[float] = None
    outcome_known: bool = False
    outcome_timestamp: Optional[datetime] = None


@dataclass
class HorizonPrediction:
    """Represents a prediction for a specific time horizon."""

    horizon_minutes: int
    predicted_min: float
    predicted_max: float
    confidence: float
    prediction_basis: str  # 'cnn', 'rl', 'technical', 'ensemble'
    # Number of model predictions combined into this one (1 unless ensembled).
    component_count: int = 1


class MultiHorizonPredictionManager:
    """Manages multi-timeframe predictions for trading system.

    A background thread (see :meth:`start`) periodically produces a min/max
    price prediction per symbol and horizon, stores it as a
    :class:`PredictionSnapshot`, and later scores each snapshot against the
    prices observed up to its target time.
    """

    # Symbols predicted when the config does not specify any.
    DEFAULT_SYMBOLS = ('ETH/USDT', 'BTC/USDT')
    # Fallback volatility (2% per hour) used when no indicator is available.
    _DEFAULT_HOURLY_VOLATILITY = 0.02

    def __init__(self, orchestrator=None, data_provider=None, config: Optional[Dict[str, Any]] = None):
        """Initialize the multi-horizon prediction manager.

        Args:
            orchestrator: Optional object that may expose ``cnn_model``,
                ``rl_agent`` and ``williams_structure``.
            data_provider: Optional object exposing ``current_prices`` and
                ``get_historical_data(symbol, timeframe, limit=...)``.
            config: Optional overrides. Recognised keys: ``symbols``,
                ``horizons``, ``prediction_interval_seconds``,
                ``max_snapshots_per_horizon``, ``min_confidence_threshold``.
                Missing keys keep the defaults.
        """
        self.orchestrator = orchestrator
        self.data_provider = data_provider
        self.config = config or {}

        # Symbols and prediction horizons (in minutes)
        self.symbols: List[str] = list(self.config.get('symbols', self.DEFAULT_SYMBOLS))
        self.horizons: List[int] = list(self.config.get('horizons', [1, 5, 15, 60]))

        # Prediction frequency (every minute by default)
        self.prediction_interval_seconds = self.config.get('prediction_interval_seconds', 60)

        # Storage for prediction snapshots: {horizon: bounded deque of PredictionSnapshot}
        self.max_snapshots_per_horizon = self.config.get('max_snapshots_per_horizon', 1000)
        self.prediction_snapshots: Dict[int, deque] = {
            horizon: deque(maxlen=self.max_snapshots_per_horizon) for horizon in self.horizons
        }

        # Threading; the event lets stop() interrupt the loop's waits immediately.
        self.prediction_thread = None
        self.is_running = False
        self._stop_event = threading.Event()
        self.last_prediction_time = 0.0

        # Performance tracking
        self.prediction_stats = {
            'total_predictions': 0,
            'predictions_by_horizon': {h: 0 for h in self.horizons},
            'validated_predictions': 0,
            'accurate_predictions': 0,
            'avg_confidence': 0.0,
            'last_prediction_time': None,
        }
        # Running sum of stored confidences, used to maintain avg_confidence.
        self._confidence_sum = 0.0

        # Minimum confidence threshold for storing predictions
        self.min_confidence_threshold = self.config.get('min_confidence_threshold', 0.3)

        logger.info("MultiHorizonPredictionManager initialized")
        logger.info(f"Prediction horizons: {self.horizons} minutes")
        logger.info(f"Prediction interval: {self.prediction_interval_seconds} seconds")

    def start(self):
        """Start the background prediction thread (no-op if already running)."""
        if self.is_running:
            logger.warning("Prediction manager already running")
            return

        self.is_running = True
        self._stop_event.clear()
        self.prediction_thread = threading.Thread(
            target=self._prediction_loop,
            daemon=True,
            name="MultiHorizonPredictor",
        )
        self.prediction_thread.start()
        logger.info("MultiHorizonPredictionManager started")

    def stop(self):
        """Stop the prediction thread, waiting up to 10 seconds for it to exit."""
        self.is_running = False
        self._stop_event.set()  # wake the loop if it is waiting
        if self.prediction_thread and self.prediction_thread.is_alive():
            self.prediction_thread.join(timeout=10)
        logger.info("MultiHorizonPredictionManager stopped")

    def _prediction_loop(self):
        """Main prediction loop.

        Predicts every ``prediction_interval_seconds``. Validates pending
        snapshots every pass. Waits 10 s between passes, or 30 s after an error.
        """
        while self.is_running and not self._stop_event.is_set():
            try:
                current_time = time.time()

                # Check if it's time for new predictions
                if current_time - self.last_prediction_time >= self.prediction_interval_seconds:
                    self._generate_all_horizon_predictions()
                    self.last_prediction_time = current_time

                # Validate pending predictions (also records observed prices)
                self._validate_pending_predictions()
                wait_seconds = 10
            except Exception as e:
                logger.error(f"Error in prediction loop: {e}")
                wait_seconds = 30  # Longer wait on error

            # Event.wait returns early when stop() is called.
            self._stop_event.wait(wait_seconds)

    def _lookup_current_price(self, symbol: str) -> Optional[float]:
        """Return the data provider's latest price for ``symbol``, or None.

        ``current_prices`` is looked up with the slash removed and the symbol
        upper-cased (``'ETH/USDT'`` -> ``'ETHUSDT'``).
        """
        if not self.data_provider or not hasattr(self.data_provider, 'current_prices'):
            return None
        return self.data_provider.current_prices.get(symbol.replace('/', '').upper())

    def _record_prediction_stats(self, horizon_minutes: int, confidence: float):
        """Count one stored prediction and update the running average confidence."""
        stats = self.prediction_stats
        stats['total_predictions'] += 1
        by_horizon = stats['predictions_by_horizon']
        by_horizon[horizon_minutes] = by_horizon.get(horizon_minutes, 0) + 1
        self._confidence_sum += confidence
        stats['avg_confidence'] = self._confidence_sum / stats['total_predictions']

    def _generate_all_horizon_predictions(self):
        """Generate and store predictions for every configured symbol and horizon.

        Predictions below ``min_confidence_threshold`` are discarded. Errors for
        a single horizon are logged and do not stop the other horizons.
        """
        try:
            prediction_time = datetime.now()

            for symbol in self.symbols:
                market_state = self._get_current_market_state(symbol)
                if not market_state:
                    continue

                current_price = market_state['current_price']

                for horizon_minutes in self.horizons:
                    try:
                        prediction = self._generate_horizon_prediction(
                            symbol, horizon_minutes, prediction_time, market_state
                        )
                        if prediction is None or prediction.confidence < self.min_confidence_threshold:
                            continue

                        snapshot = self._create_prediction_snapshot(
                            symbol, horizon_minutes, prediction_time,
                            current_price, prediction, market_state
                        )
                        self.prediction_snapshots[horizon_minutes].append(snapshot)
                        self._record_prediction_stats(horizon_minutes, prediction.confidence)

                        logger.info(f"Generated {horizon_minutes}m prediction for {symbol}: "
                                    f"min={prediction.predicted_min:.4f}, max={prediction.predicted_max:.4f}, "
                                    f"confidence={prediction.confidence:.2f}")

                    except Exception as e:
                        logger.error(f"Error generating {horizon_minutes}m prediction for {symbol}: {e}")

            self.prediction_stats['last_prediction_time'] = prediction_time

        except Exception as e:
            logger.error(f"Error generating all horizon predictions: {e}")

    @staticmethod
    def _is_usable_prediction(prediction: Optional[HorizonPrediction]) -> bool:
        """Return True if ``prediction`` carries a positive confidence and a finite, positive, ordered range.

        The CNN/RL placeholders return zero prices with zero confidence.
        Without this filter those zeros would dilute the ensemble confidence.
        """
        if prediction is None:
            return False
        values = np.array(
            [prediction.predicted_min, prediction.predicted_max, prediction.confidence], dtype=float
        )
        return bool(
            np.isfinite(values).all()
            and prediction.confidence > 0
            and 0 < prediction.predicted_min <= prediction.predicted_max
        )

    def _generate_horizon_prediction(self, symbol: str, horizon_minutes: int,
                                     prediction_time: datetime, market_state: Dict[str, Any]) -> Optional[HorizonPrediction]:
        """Generate a prediction for one horizon by ensembling CNN, RL and technical predictions.

        Returns None on error. If no component yields a usable prediction,
        falls back to a technical prediction with ``fallback=True``.
        """
        try:
            current_price = market_state['current_price']

            candidates = (
                self._get_cnn_prediction(symbol, horizon_minutes, market_state),
                self._get_rl_prediction(symbol, horizon_minutes, market_state),
                self._get_technical_prediction(symbol, horizon_minutes, market_state),
            )
            predictions = [p for p in candidates if self._is_usable_prediction(p)]

            if not predictions:
                # Fallback to technical analysis only
                return self._get_technical_prediction(symbol, horizon_minutes, market_state, fallback=True)

            return self._ensemble_predictions(predictions, current_price)

        except Exception as e:
            logger.error(f"Error generating horizon prediction: {e}")
            return None

    def _get_cnn_prediction(self, symbol: str, horizon_minutes: int,
                            market_state: Dict[str, Any]) -> Optional[HorizonPrediction]:
        """Get a CNN-based prediction.

        Returns None when the orchestrator has no ``cnn_model`` or the model
        call fails.
        """
        try:
            if not self.orchestrator or not hasattr(self.orchestrator, 'cnn_model'):
                return None

            features = self._prepare_cnn_features_for_horizon(market_state, horizon_minutes)
            prediction_output = self.orchestrator.cnn_model.predict(features)

            predicted_min, predicted_max, confidence = self._interpret_cnn_output(
                prediction_output, market_state['current_price'], horizon_minutes
            )

            return HorizonPrediction(
                horizon_minutes=horizon_minutes,
                predicted_min=predicted_min,
                predicted_max=predicted_max,
                confidence=confidence,
                prediction_basis='cnn',
            )

        except Exception as e:
            logger.debug(f"CNN prediction failed: {e}")
            return None

    def _get_rl_prediction(self, symbol: str, horizon_minutes: int,
                           market_state: Dict[str, Any]) -> Optional[HorizonPrediction]:
        """Get an RL-based prediction.

        Returns None when the orchestrator has no ``rl_agent`` or the agent
        call fails.
        """
        try:
            if not self.orchestrator or not hasattr(self.orchestrator, 'rl_agent'):
                return None

            rl_state = self._prepare_rl_state_for_horizon(market_state, horizon_minutes)
            rl_agent = self.orchestrator.rl_agent
            action = rl_agent.act(rl_state, explore=False)

            current_price = market_state['current_price']
            predicted_min, predicted_max, confidence = self._convert_rl_action_to_price_prediction(
                action, current_price, horizon_minutes, rl_agent
            )

            return HorizonPrediction(
                horizon_minutes=horizon_minutes,
                predicted_min=predicted_min,
                predicted_max=predicted_max,
                confidence=confidence,
                prediction_basis='rl',
            )

        except Exception as e:
            logger.debug(f"RL prediction failed: {e}")
            return None

    def _get_technical_prediction(self, symbol: str, horizon_minutes: int,
                                  market_state: Dict[str, Any], fallback: bool = False) -> Optional[HorizonPrediction]:
        """Get a technical-analysis prediction from volatility and pivot trend.

        The expected range is the per-minute return volatility scaled by
        ``sqrt(horizon_minutes)``. The range is skewed towards the trend
        direction. Confidence grows with trend strength and shrinks with
        horizon length. With ``fallback=True`` the confidence is at least 0.2.
        """
        try:
            current_price = market_state['current_price']

            pivot_analysis = market_state.get('pivot_analysis', {})
            technical_indicators = market_state.get('technical_indicators', {})

            trend_direction = pivot_analysis.get('trend_direction', 'SIDEWAYS')
            trend_strength = pivot_analysis.get('trend_strength', 0.0)

            # 'volatility' is the std-dev of 1-minute returns, so it scales to
            # the horizon by sqrt(minutes). The default is the per-minute
            # equivalent of 2% hourly volatility.
            volatility = technical_indicators.get(
                'volatility', self._DEFAULT_HOURLY_VOLATILITY / np.sqrt(60.0)
            )
            expected_range_percent = volatility * np.sqrt(horizon_minutes)

            if trend_direction == 'UPTREND':
                # Bias toward higher prices
                predicted_min = current_price * (1 - expected_range_percent * 0.3)
                predicted_max = current_price * (1 + expected_range_percent * 1.2)
            elif trend_direction == 'DOWNTREND':
                # Bias toward lower prices
                predicted_min = current_price * (1 - expected_range_percent * 1.2)
                predicted_max = current_price * (1 + expected_range_percent * 0.3)
            else:
                # Symmetric range for sideways
                range_half = expected_range_percent * current_price
                predicted_min = current_price - range_half
                predicted_max = current_price + range_half

            base_confidence = 0.4 + (trend_strength * 0.4)  # 0.4 to 0.8
            # Reduce confidence for longer horizons (floor at 0.3)
            horizon_factor = max(0.3, 1.0 - (horizon_minutes - 1) / 120.0)
            confidence = base_confidence * horizon_factor

            if fallback:
                confidence = max(confidence, 0.2)  # Minimum confidence for fallback

            return HorizonPrediction(
                horizon_minutes=horizon_minutes,
                predicted_min=predicted_min,
                predicted_max=predicted_max,
                confidence=confidence,
                prediction_basis='technical',
            )

        except Exception as e:
            logger.error(f"Technical prediction failed: {e}")
            return None

    def _ensemble_predictions(self, predictions: List[HorizonPrediction],
                              current_price: float) -> Optional[HorizonPrediction]:
        """Combine predictions into one confidence-weighted prediction.

        Unusable inputs (see :meth:`_is_usable_prediction`) are ignored. A
        single remaining input keeps its own basis. Two or more are combined
        with basis ``'ensemble'``. An empty or inverted range falls back to
        ±2% around ``current_price``. Confidence is capped at 0.95.

        Returns:
            The combined prediction, or None if no usable input remains.
        """
        try:
            usable = [p for p in predictions if self._is_usable_prediction(p)]
            if not usable:
                return None

            if len(usable) == 1:
                only = usable[0]
                low, high = only.predicted_min, only.predicted_max
                confidence = only.confidence
                basis = only.prediction_basis
            else:
                # Every usable confidence is > 0, so the weight is positive.
                total_weight = sum(p.confidence for p in usable)
                low = sum(p.predicted_min * p.confidence for p in usable) / total_weight
                high = sum(p.predicted_max * p.confidence for p in usable) / total_weight
                confidence = total_weight / len(usable)
                basis = 'ensemble'

            if low >= high:
                # Fallback to symmetric range
                range_half = abs(current_price * 0.02)  # 2% range
                low = current_price - range_half
                high = current_price + range_half

            return HorizonPrediction(
                horizon_minutes=usable[0].horizon_minutes,
                predicted_min=low,
                predicted_max=high,
                confidence=min(confidence, 0.95),  # Cap at 95%
                prediction_basis=basis,
                component_count=len(usable),
            )

        except Exception as e:
            logger.error(f"Ensemble prediction failed: {e}")
            return None

    def _get_current_market_state(self, symbol: str) -> Optional[Dict[str, Any]]:
        """Collect price, recent 1m OHLCV, indicators and pivot analysis for ``symbol``.

        Returns None when there is no data provider, no current price, or
        fewer than 20 candles.
        """
        try:
            if not self.data_provider:
                return None

            current_price = self._lookup_current_price(symbol)
            if current_price is None:
                logger.debug(f"No current price available for {symbol}")
                return None

            # Get recent OHLCV data (last 100 candles for analysis)
            ohlcv_data = self.data_provider.get_historical_data(symbol, '1m', limit=100)
            if ohlcv_data is None or len(ohlcv_data) < 20:
                logger.debug(f"Insufficient OHLCV data for {symbol}")
                return None

            return {
                'current_price': current_price,
                'ohlcv_data': ohlcv_data,
                'technical_indicators': self._calculate_technical_indicators(ohlcv_data),
                'pivot_analysis': self._get_pivot_analysis(symbol, ohlcv_data),
                'timestamp': datetime.now(),
            }

        except Exception as e:
            logger.error(f"Error getting market state for {symbol}: {e}")
            return None

    @staticmethod
    def _simple_rsi(prices: np.ndarray, period: int = 14) -> float:
        """Return an RSI computed from simple averages of the last ``period`` price changes.

        Returns 50.0 when there are too few prices, and 100.0 when there are
        no losses.
        """
        if len(prices) < period + 1:
            return 50.0
        deltas = np.diff(prices[-(period + 1):])
        avg_gain = float(np.mean(np.clip(deltas, 0.0, None)))
        avg_loss = float(np.mean(np.clip(-deltas, 0.0, None)))
        if avg_loss == 0:
            return 100.0
        rs = avg_gain / avg_loss
        return 100 - (100 / (1 + rs))

    def _calculate_technical_indicators(self, ohlcv_data: np.ndarray) -> Dict[str, Any]:
        """Calculate technical indicators from 1-minute OHLCV candles (oldest first).

        Columns are read by position: 4 = close, 5 = volume.
        NOTE(review): columns 0-3 presumably timestamp/open/high/low —
        confirm against the data provider.

        Returns:
            Dict with sma_5, sma_20, rsi, volatility (std of 1m returns),
            volume_ratio, price_change_5m and price_change_15m. Returns ``{}``
            if fewer than 20 candles are given or on error.
        """
        try:
            if len(ohlcv_data) < 20:
                return {}

            closes = ohlcv_data[:, 4].astype(float)
            volumes = ohlcv_data[:, 5].astype(float)

            sma_5 = np.mean(closes[-5:])
            sma_20 = np.mean(closes[-20:])
            rsi = self._simple_rsi(closes)

            # Volatility (standard deviation of 1-minute returns)
            returns = np.diff(closes) / closes[:-1]
            volatility = np.std(returns) if len(returns) > 0 else 0.02

            # Volume analysis (at least 20 candles are guaranteed above)
            avg_volume = np.mean(volumes[-20:])
            volume_ratio = volumes[-1] / avg_volume if avg_volume > 0 else 1.0

            def change_over(minutes: int) -> float:
                # The close `minutes` candles ago sits at index -(minutes + 1).
                if len(closes) <= minutes:
                    return 0.0
                reference = closes[-(minutes + 1)]
                return float((closes[-1] - reference) / reference)

            return {
                'sma_5': float(sma_5),
                'sma_20': float(sma_20),
                'rsi': float(rsi),
                'volatility': float(volatility),
                'volume_ratio': float(volume_ratio),
                'price_change_5m': change_over(5),
                'price_change_15m': change_over(15),
            }

        except Exception as e:
            logger.error(f"Error calculating technical indicators: {e}")
            return {}

    def _get_pivot_analysis(self, symbol: str, ohlcv_data: np.ndarray) -> Dict[str, Any]:
        """Return trend direction/strength and support/resistance levels.

        Uses the orchestrator's ``williams_structure`` when present. The
        highest ``<name>_<N>`` key is taken as the most recent level. Otherwise
        falls back to the 20-candle high/low with a SIDEWAYS trend.
        """
        try:
            if hasattr(self.orchestrator, 'williams_structure'):
                pivot_levels = self.orchestrator.williams_structure.calculate_recursive_pivot_points(ohlcv_data)
                if pivot_levels:
                    latest_level = max(pivot_levels.keys(), key=lambda x: int(x.split('_')[1]))
                    level_data = pivot_levels[latest_level]
                    return {
                        'trend_direction': level_data.trend_direction,
                        'trend_strength': level_data.trend_strength,
                        'support_levels': level_data.support_levels,
                        'resistance_levels': level_data.resistance_levels,
                    }

            # Fallback to basic pivot analysis
            if len(ohlcv_data) >= 20:
                recent_highs = ohlcv_data[-20:, 2].astype(float)
                recent_lows = ohlcv_data[-20:, 3].astype(float)
                return {
                    'trend_direction': 'SIDEWAYS',
                    'trend_strength': 0.5,
                    'support_levels': [np.min(recent_lows)],
                    'resistance_levels': [np.max(recent_highs)],
                }

            return {
                'trend_direction': 'SIDEWAYS',
                'trend_strength': 0.0,
                'support_levels': [],
                'resistance_levels': [],
            }

        except Exception as e:
            logger.error(f"Error getting pivot analysis: {e}")
            return {}

    def _create_prediction_snapshot(self, symbol: str, horizon_minutes: int,
                                    prediction_time: datetime, current_price: float,
                                    prediction: HorizonPrediction, market_state: Dict[str, Any]) -> PredictionSnapshot:
        """Create a prediction snapshot for future training.

        The ID combines the symbol (without slash), the horizon and the Unix
        second of ``prediction_time``.
        """
        prediction_id = f"{symbol.replace('/', '')}_{horizon_minutes}m_{int(prediction_time.timestamp())}"
        target_time = prediction_time + timedelta(minutes=horizon_minutes)

        return PredictionSnapshot(
            prediction_id=prediction_id,
            symbol=symbol,
            prediction_time=prediction_time,
            target_horizon_minutes=horizon_minutes,
            target_time=target_time,
            current_price=current_price,
            predicted_min_price=prediction.predicted_min,
            predicted_max_price=prediction.predicted_max,
            confidence=prediction.confidence,
            model_inputs=self._extract_model_inputs(market_state),
            market_state=market_state,
            technical_indicators=market_state.get('technical_indicators', {}),
            pivot_analysis=market_state.get('pivot_analysis', {}),
            prediction_metadata={
                'prediction_basis': prediction.prediction_basis,
                'ensemble_components': prediction.component_count,
            },
        )

    def _extract_model_inputs(self, market_state: Dict[str, Any]) -> Dict[str, Any]:
        """Extract model inputs for future training.

        Returns CNN features and RL state (both prepared for the 60m horizon),
        the current price and the last 50 OHLCV rows as lists. Returns ``{}``
        on error.
        """
        try:
            return {
                # 60m horizon used for consistency across snapshots
                'cnn_features': self._prepare_cnn_features_for_horizon(market_state, 60),
                'rl_state': self._prepare_rl_state_for_horizon(market_state, 60),
                'current_price': market_state['current_price'],
                'ohlcv_sequence': market_state['ohlcv_data'][-50:].tolist(),  # Last 50 candles
            }
        except Exception as e:
            logger.error(f"Error extracting model inputs: {e}")
            return {}

    @staticmethod
    def _record_observed_price(snapshot: PredictionSnapshot, price: float):
        """Widen the snapshot's observed [min, max] range to include ``price``.

        The range is seeded with the price at prediction time, so the realised
        range always includes the starting point of the price path.
        """
        low = snapshot.current_price if snapshot.actual_min_price is None else snapshot.actual_min_price
        high = snapshot.current_price if snapshot.actual_max_price is None else snapshot.actual_max_price
        snapshot.actual_min_price = min(low, price)
        snapshot.actual_max_price = max(high, price)

    def _validate_pending_predictions(self):
        """Record current prices into pending snapshots and resolve those past their target time.

        Each snapshot is updated only with the price of its own symbol.
        """
        try:
            current_time = datetime.now()

            for symbol in self.symbols:
                current_price = self._lookup_current_price(symbol)
                if current_price is None:
                    continue

                for snapshots in self.prediction_snapshots.values():
                    resolved = []
                    for snapshot in list(snapshots):
                        if snapshot.outcome_known or snapshot.symbol != symbol:
                            continue
                        self._record_observed_price(snapshot, current_price)
                        if current_time >= snapshot.target_time:
                            snapshot.outcome_known = True
                            snapshot.outcome_timestamp = current_time
                            resolved.append(snapshot)

                    for snapshot in resolved:
                        self._process_validated_prediction(snapshot)

        except Exception as e:
            logger.error(f"Error validating pending predictions: {e}")

    def _is_prediction_accurate(self, snapshot: PredictionSnapshot) -> bool:
        """Return True if the predicted range matches the observed range.

        Non-degenerate observed ranges must overlap the predicted range by
        more than 50% (intersection over union). A zero-width observed range
        (a single price) has IoU 0 by definition, so it counts as accurate
        when that price lies inside the predicted range.
        """
        predicted = (snapshot.predicted_min_price, snapshot.predicted_max_price)
        actual = (snapshot.actual_min_price, snapshot.actual_max_price)
        if actual[1] <= actual[0]:
            return predicted[0] <= actual[0] <= predicted[1]
        return self._calculate_range_overlap(predicted, actual) > 0.5

    def _process_validated_prediction(self, snapshot: PredictionSnapshot):
        """Update accuracy statistics for a resolved snapshot and log the running accuracy."""
        try:
            self.prediction_stats['validated_predictions'] += 1

            if (snapshot.actual_min_price is not None
                    and snapshot.actual_max_price is not None
                    and self._is_prediction_accurate(snapshot)):
                self.prediction_stats['accurate_predictions'] += 1

            # Here we would trigger training with the snapshot data
            accuracy_rate = (self.prediction_stats['accurate_predictions'] /
                             max(1, self.prediction_stats['validated_predictions']))

            logger.info(f"Validated {snapshot.target_horizon_minutes}m prediction for {snapshot.symbol}: "
                        f"confidence={snapshot.confidence:.2f}, accuracy_rate={accuracy_rate:.2f}")

        except Exception as e:
            logger.error(f"Error processing validated prediction: {e}")

    def _calculate_range_overlap(self, range1: Tuple[float, float], range2: Tuple[float, float]) -> float:
        """Return the intersection-over-union of two price ranges (0.0 to 1.0)."""
        try:
            min1, max1 = range1
            min2, max2 = range2

            overlap_min = max(min1, min2)
            overlap_max = min(max1, max2)
            if overlap_max <= overlap_min:
                return 0.0

            overlap_size = overlap_max - overlap_min
            union_size = max(max1, max2) - min(min1, min2)
            return overlap_size / union_size if union_size > 0 else 0.0

        except Exception:
            return 0.0

    def get_prediction_stats(self) -> Dict[str, Any]:
        """Return a copy of the prediction statistics plus ``accuracy_rate``.

        ``avg_confidence`` is the mean confidence of all stored predictions
        (0.0 if none).
        """
        stats = dict(self.prediction_stats)
        # Copy the nested dict so callers cannot mutate internal counters.
        stats['predictions_by_horizon'] = dict(self.prediction_stats['predictions_by_horizon'])

        validated = stats['validated_predictions']
        stats['accuracy_rate'] = stats['accurate_predictions'] / validated if validated > 0 else 0.0
        return stats

    def get_recent_predictions(self, horizon_minutes: int, limit: int = 10) -> List[PredictionSnapshot]:
        """Return up to ``limit`` most recent snapshots for a horizon, oldest first.

        Returns ``[]`` for an unknown horizon or a non-positive ``limit``.
        """
        if horizon_minutes not in self.prediction_snapshots or limit <= 0:
            return []
        return list(self.prediction_snapshots[horizon_minutes])[-limit:]

    # Placeholder methods for CNN and RL feature preparation - to be implemented

    def _prepare_cnn_features_for_horizon(self, market_state: Dict[str, Any], horizon: int) -> np.ndarray:
        """Prepare CNN features for specific horizon (placeholder - not yet implemented)."""
        logger.debug(f"CNN feature preparation for horizon {horizon} not yet implemented")
        return np.array([])  # Return empty array instead of synthetic data

    def _prepare_rl_state_for_horizon(self, market_state: Dict[str, Any], horizon: int) -> np.ndarray:
        """Prepare RL state for specific horizon (placeholder - not yet implemented)."""
        logger.debug(f"RL state preparation for horizon {horizon} not yet implemented")
        return np.array([])  # Return empty array instead of synthetic data

    def _interpret_cnn_output(self, cnn_output, current_price: float, horizon: int) -> Tuple[float, float, float]:
        """Interpret CNN output for min/max prediction (placeholder - not yet implemented).

        Returns zeros, which the ensemble filter discards as unusable.
        """
        logger.debug(f"CNN output interpretation for horizon {horizon} not yet implemented")
        return (0.0, 0.0, 0.0)  # Return zeros instead of synthetic predictions

    def _convert_rl_action_to_price_prediction(self, action: int, current_price: float,
                                               horizon: int, rl_agent) -> Tuple[float, float, float]:
        """Convert RL action to price prediction (placeholder - not yet implemented).

        Returns zeros, which the ensemble filter discards as unusable.
        """
        logger.debug(f"RL action conversion for horizon {horizon} not yet implemented")
        return (0.0, 0.0, 0.0)  # Return zeros instead of synthetic predictions