"""
Timeframe-Aware Inference Coordinator

This module coordinates model inference across multiple timeframes with proper
scheduling. It ensures that models know which timeframe they are predicting on
and handles the scheduling requirements for multi-timeframe predictions.

Key Features:
- Timeframe-aware model inference
- Hourly multi-timeframe inference (one prediction per configured timeframe,
  four timeframes by default)
- Frequent continuous inference (every 5 seconds by default)
- Extra inference at each new minute / hour / day boundary
- Prediction context management
- Integration with enhanced reward calculator
"""

import time
import asyncio
import inspect
import logging
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Callable, Tuple
from dataclasses import dataclass
import threading
from enum import Enum

from core.enhanced_reward_calculator import EnhancedRewardCalculator, TimeFrame

logger = logging.getLogger(__name__)


@dataclass
class InferenceContext:
    """Context information for a model inference"""
    symbol: str
    timeframe: TimeFrame
    timestamp: datetime
    target_timeframe: TimeFrame  # Which timeframe we're predicting for
    is_hourly_inference: bool = False
    # One of "regular", "hourly", "continuous" or "boundary"
    inference_type: str = "regular"


@dataclass
class InferenceSchedule:
    """Schedule configuration for different inference types"""
    continuous_interval_seconds: float = 5.0  # Continuous inference every 5 seconds
    hourly_timeframes: Optional[List[TimeFrame]] = None  # Timeframes for hourly inference
    # How often the hourly/boundary scheduler wakes up. It must stay well below
    # one minute; otherwise minute boundaries are only noticed late.
    scheduler_check_interval_seconds: float = 1.0

    def __post_init__(self):
        if self.hourly_timeframes is None:
            self.hourly_timeframes = [TimeFrame.SECONDS_1, TimeFrame.MINUTES_1,
                                      TimeFrame.HOURS_1, TimeFrame.DAYS_1]


class TimeframeInferenceCoordinator:
    """
    Coordinates timeframe-aware model inference with proper scheduling

    This coordinator:
    1. Runs continuous inference on the 1s timeframe at a fixed interval
       (``InferenceSchedule.continuous_interval_seconds``)
    2. Schedules hourly multi-timeframe inference (one per hourly timeframe)
    3. Runs an extra inference whenever a new minute, hour or day starts
    4. Ensures models know which timeframe they're predicting on
    5. Feeds predictions to the enhanced reward calculator for training
    6. Handles prediction context and metadata
    """

    # Timeframes that get an extra inference when a new period starts
    BOUNDARY_TIMEFRAMES = (TimeFrame.MINUTES_1, TimeFrame.HOURS_1, TimeFrame.DAYS_1)

    def __init__(self, reward_calculator: EnhancedRewardCalculator,
                 data_provider: Any = None,
                 symbols: Optional[List[str]] = None):
        """
        Initialize the timeframe inference coordinator

        Args:
            reward_calculator: Enhanced reward calculator instance
            data_provider: Data provider for market data (optional)
            symbols: List of symbols to coordinate inference for
                (defaults to ETH/USDT and BTC/USDT)
        """
        self.reward_calculator = reward_calculator
        self.data_provider = data_provider
        self.symbols = symbols or ['ETH/USDT', 'BTC/USDT']

        # Inference schedule configuration
        self.schedule = InferenceSchedule()

        # Model registry - stores inference functions for different models
        self.model_inference_functions: Dict[str, Callable] = {}

        # Tracking inference state
        self.last_continuous_inference: Dict[str, datetime] = {}
        self.last_hourly_inference: Dict[str, datetime] = {}
        self.next_hourly_inference: Dict[str, datetime] = {}
        # symbol -> boundary timeframe -> start of the latest period already handled
        self.last_boundary: Dict[str, Dict[TimeFrame, datetime]] = {}

        # Active inference tasks
        self.inference_tasks: List[asyncio.Task] = []
        self.running = False

        # Thread safety
        self.lock = threading.RLock()

        # Performance metrics
        self.inference_stats = {
            'continuous_inferences': 0,
            'hourly_inferences': 0,
            'boundary_inferences': 0,
            'failed_inferences': 0,
            'average_inference_time_ms': 0.0
        }
        # Number of timed _execute_inference calls folded into the running average
        self._timing_samples = 0

        self._initialize_schedules()

        logger.info(f"TimeframeInferenceCoordinator initialized for symbols: {self.symbols}")
        logger.info(f"Continuous inference interval: {self.schedule.continuous_interval_seconds}s")
        logger.info(f"Hourly timeframes: {[tf.value for tf in self.schedule.hourly_timeframes]}")

    def _initialize_schedules(self):
        """Initialize inference schedules for all symbols"""
        current_time = datetime.now()

        for symbol in self.symbols:
            self.last_continuous_inference[symbol] = current_time
            self.last_hourly_inference[symbol] = current_time

            # Schedule next hourly inference at the top of the next hour
            next_hour = current_time.replace(minute=0, second=0, microsecond=0) + timedelta(hours=1)
            self.next_hourly_inference[symbol] = next_hour

        self._reset_boundary_tracking(current_time)

    @staticmethod
    def _floor_to_timeframe(timestamp: datetime, timeframe: TimeFrame) -> datetime:
        """
        Return the start of the ``timeframe`` period that contains ``timestamp``.

        Raises:
            ValueError: if ``timeframe`` is not a minute, hour or day timeframe.
        """
        if timeframe == TimeFrame.DAYS_1:
            return timestamp.replace(hour=0, minute=0, second=0, microsecond=0)
        if timeframe == TimeFrame.HOURS_1:
            return timestamp.replace(minute=0, second=0, microsecond=0)
        if timeframe == TimeFrame.MINUTES_1:
            return timestamp.replace(second=0, microsecond=0)
        raise ValueError(f"Unsupported boundary timeframe: {timeframe}")

    def _reset_boundary_tracking(self, current_time: datetime):
        """Mark the periods containing ``current_time`` as already handled.

        Only boundaries crossed after this call will trigger an inference.
        """
        for symbol in self.symbols:
            self.last_boundary[symbol] = {
                timeframe: self._floor_to_timeframe(current_time, timeframe)
                for timeframe in self.BOUNDARY_TIMEFRAMES
            }

    def register_model_inference_function(self, model_name: str, inference_func: Callable):
        """
        Register a model's inference function

        Args:
            model_name: Name of the model
            inference_func: Function taking an InferenceContext and returning a
                prediction dict (or None). Coroutine functions are awaited;
                plain synchronous functions are also accepted.
        """
        self.model_inference_functions[model_name] = inference_func
        logger.info(f"Registered inference function for model: {model_name}")

    async def start_coordination(self):
        """Start the inference coordination system"""
        if self.running:
            logger.warning("Inference coordination already running")
            return

        self.running = True
        logger.info("Starting timeframe inference coordination")

        # Do not fire boundaries that elapsed while the coordinator was stopped
        self._reset_boundary_tracking(datetime.now())

        # Start continuous inference tasks for each symbol
        for symbol in self.symbols:
            task = asyncio.create_task(self._continuous_inference_loop(symbol))
            self.inference_tasks.append(task)

        # Start hourly inference scheduler
        task = asyncio.create_task(self._hourly_inference_scheduler())
        self.inference_tasks.append(task)

        # Start reward evaluation loop
        task = asyncio.create_task(self._reward_evaluation_loop())
        self.inference_tasks.append(task)

        logger.info(f"Started {len(self.inference_tasks)} inference coordination tasks")

    async def stop_coordination(self):
        """Stop the inference coordination system"""
        if not self.running:
            return

        self.running = False
        logger.info("Stopping timeframe inference coordination")

        # Cancel all tasks
        for task in self.inference_tasks:
            task.cancel()

        # Wait for tasks to complete (the argument tuple is a snapshot of the list)
        await asyncio.gather(*self.inference_tasks, return_exceptions=True)

        self.inference_tasks.clear()
        logger.info("Inference coordination stopped")

    def _discard_task(self, task: asyncio.Task):
        """Done-callback: forget a finished ad-hoc task."""
        try:
            self.inference_tasks.remove(task)
        except ValueError:
            pass  # already removed, e.g. by stop_coordination()

    async def _continuous_inference_loop(self, symbol: str):
        """
        Continuous inference loop for a specific symbol

        Args:
            symbol: Trading symbol to run inference for
        """
        logger.info(f"Starting continuous inference loop for {symbol}")

        while self.running:
            try:
                current_time = datetime.now()

                # Check if it's time for continuous inference
                last_inference = self.last_continuous_inference[symbol]
                time_since_last = (current_time - last_inference).total_seconds()

                if time_since_last >= self.schedule.continuous_interval_seconds:
                    # Run continuous inference on primary timeframe (1s)
                    context = InferenceContext(
                        symbol=symbol,
                        timeframe=TimeFrame.SECONDS_1,
                        timestamp=current_time,
                        target_timeframe=TimeFrame.SECONDS_1,
                        is_hourly_inference=False,
                        inference_type="continuous"
                    )

                    await self._execute_inference(context)
                    self.last_continuous_inference[symbol] = current_time
                    self.inference_stats['continuous_inferences'] += 1

                # Sleep for a short interval to avoid busy waiting
                await asyncio.sleep(0.1)

            except Exception as e:
                logger.error(f"Error in continuous inference loop for {symbol}: {e}")
                await asyncio.sleep(1.0)  # Wait longer on error

    async def _hourly_inference_scheduler(self):
        """Scheduler for hourly multi-timeframe inference and timeframe-boundary triggers"""
        logger.info("Starting hourly inference scheduler")

        while self.running:
            try:
                current_time = datetime.now()

                # Check if any symbol needs hourly inference
                for symbol in self.symbols:
                    if current_time >= self.next_hourly_inference[symbol]:
                        await self._execute_hourly_inference(symbol, current_time)

                        # Schedule next hourly inference
                        next_hour = current_time.replace(minute=0, second=0, microsecond=0) + timedelta(hours=1)
                        self.next_hourly_inference[symbol] = next_hour
                        self.last_hourly_inference[symbol] = current_time

                    # Trigger at each new timeframe boundary: 1m, 1h, 1d
                    await self._check_timeframe_boundaries(symbol, current_time)

                # Poll often enough that every minute boundary is noticed promptly.
                # Comparing period starts (not `second == 0`) means a boundary is
                # never missed just because the poll did not land on second 0.
                await asyncio.sleep(self.schedule.scheduler_check_interval_seconds)

            except Exception as e:
                logger.error(f"Error in hourly inference scheduler: {e}")
                await asyncio.sleep(60)  # Wait longer on error

    async def _check_timeframe_boundaries(self, symbol: str, current_time: datetime):
        """Run a boundary inference for every boundary timeframe whose period advanced."""
        handled = self.last_boundary.setdefault(symbol, {})
        for timeframe in self.BOUNDARY_TIMEFRAMES:
            period_start = self._floor_to_timeframe(current_time, timeframe)
            previous = handled.get(timeframe)
            # Record first, so a failing inference is not retried on every poll
            handled[timeframe] = period_start
            if previous is not None and period_start > previous:
                await self._execute_boundary_inference(symbol, current_time, timeframe)

    async def _execute_boundary_inference(self, symbol: str, timestamp: datetime, timeframe: TimeFrame):
        """Execute an inference when a new ``timeframe`` period has started"""
        try:
            context = InferenceContext(
                symbol=symbol,
                timeframe=timeframe,
                timestamp=timestamp,
                target_timeframe=timeframe,
                is_hourly_inference=False,
                inference_type="boundary"
            )
            await self._execute_inference(context)
            self.inference_stats['boundary_inferences'] += 1
        except Exception as e:
            logger.debug(f"Boundary inference error for {symbol} {timeframe.value}: {e}")

    async def _execute_hourly_inference(self, symbol: str, timestamp: datetime):
        """
        Execute hourly multi-timeframe inference for a symbol

        Args:
            symbol: Trading symbol
            timestamp: Current timestamp
        """
        logger.info(f"Executing hourly multi-timeframe inference for {symbol}")

        # Run inference for each timeframe
        for timeframe in self.schedule.hourly_timeframes:
            context = InferenceContext(
                symbol=symbol,
                timeframe=timeframe,
                timestamp=timestamp,
                target_timeframe=timeframe,
                is_hourly_inference=True,
                inference_type="hourly"
            )

            await self._execute_inference(context)
            self.inference_stats['hourly_inferences'] += 1

            # Small delay between timeframe inferences
            await asyncio.sleep(0.5)

    async def _execute_inference(self, context: InferenceContext):
        """
        Execute inference for a specific context on every registered model

        A failure in one model is logged and counted in ``failed_inferences``;
        the remaining models still run.

        Args:
            context: Inference context containing all necessary information
        """
        start_time = time.perf_counter()

        try:
            # Snapshot the registry: models may be registered while we await
            for model_name, inference_func in list(self.model_inference_functions.items()):
                try:
                    prediction = inference_func(context)
                    if inspect.isawaitable(prediction):
                        prediction = await prediction

                    if prediction is not None:
                        self._record_prediction(context, model_name, prediction)

                except Exception as e:
                    logger.error(f"Error running inference for model {model_name}: {e}")
                    self.inference_stats['failed_inferences'] += 1

            # Update inference timing stats
            inference_time_ms = (time.perf_counter() - start_time) * 1000
            self._update_inference_timing(inference_time_ms)

        except Exception as e:
            logger.error(f"Error executing inference for context {context}: {e}")
            self.inference_stats['failed_inferences'] += 1

    def _record_prediction(self, context: InferenceContext, model_name: str, prediction: Dict[str, Any]):
        """Forward one model prediction (a dict) to the reward calculator."""
        # Explicit None check instead of `a or b`: truthiness of an array-like
        # state (e.g. a multi-element numpy array) raises instead of falling back.
        state_vector = prediction.get('model_state')
        if state_vector is None:
            state_vector = prediction.get('model_features')

        prediction_id = self.reward_calculator.add_prediction(
            symbol=context.symbol,
            timeframe=context.target_timeframe,
            predicted_price=prediction.get('predicted_price', 0.0),
            predicted_direction=prediction.get('direction', 0),
            confidence=prediction.get('confidence', 0.0),
            current_price=prediction.get('current_price', 0.0),
            model_name=model_name,
            predicted_return=prediction.get('predicted_return'),
            state_vector=state_vector
        )

        logger.debug(f"Added prediction {prediction_id} from {model_name} "
                     f"for {context.symbol} {context.target_timeframe.value}")

    def _update_inference_timing(self, inference_time_ms: float):
        """Fold one timing sample into the running mean of inference durations"""
        self._timing_samples += 1
        current_avg = self.inference_stats['average_inference_time_ms']
        # Incremental mean: avg_n = avg_{n-1} + (x_n - avg_{n-1}) / n
        self.inference_stats['average_inference_time_ms'] = (
            current_avg + (inference_time_ms - current_avg) / self._timing_samples
        )

    async def _reward_evaluation_loop(self):
        """Continuous loop for evaluating prediction rewards"""
        logger.info("Starting reward evaluation loop")

        while self.running:
            try:
                # Update price cache if data provider available. The coroutine
                # itself calls the provider's get_current_price synchronously.
                if self.data_provider:
                    await self._update_price_cache()

                # Evaluate predictions and get training data
                for symbol in self.symbols:
                    evaluation_results = self.reward_calculator.evaluate_predictions(symbol)

                    if symbol in evaluation_results and evaluation_results[symbol]:
                        logger.debug(f"Evaluated {len(evaluation_results[symbol])} predictions for {symbol}")

                        # Trigger training for models that have new evaluated predictions
                        await self._trigger_model_training(symbol, evaluation_results[symbol])

                # Sleep for evaluation interval
                await asyncio.sleep(10)  # Evaluate every 10 seconds

            except Exception as e:
                logger.error(f"Error in reward evaluation loop: {e}")
                await asyncio.sleep(30)  # Wait longer on error

    async def _update_price_cache(self):
        """Update price cache with current market prices"""
        try:
            for symbol in self.symbols:
                # Get current price from data provider (synchronous call)
                if hasattr(self.data_provider, 'get_current_price'):
                    current_price = self.data_provider.get_current_price(symbol)
                    if current_price:
                        self.reward_calculator.update_price(symbol, current_price)

        except Exception as e:
            logger.debug(f"Error updating price cache: {e}")

    async def _trigger_model_training(self, symbol: str, evaluation_results: List[Any]):
        """
        Trigger model training based on evaluation results

        Args:
            symbol: Trading symbol
            evaluation_results: List of (prediction, reward) tuples; each
                prediction exposes ``model_name`` and ``timeframe``
        """
        try:
            # Group by (model, timeframe). A tuple key avoids re-parsing a joined
            # string, which broke for model names containing underscores.
            training_groups: Dict[Tuple[str, TimeFrame], List[Any]] = {}

            for prediction_record, reward in evaluation_results:
                key = (prediction_record.model_name, prediction_record.timeframe)
                training_groups.setdefault(key, []).append((prediction_record, reward))

            # Trigger training for each group
            for (model_name, timeframe), training_data in training_groups.items():
                logger.info(f"Triggering training for {model_name} on {symbol} {timeframe.value} "
                            f"with {len(training_data)} samples")

                await self._call_model_training(model_name, symbol, timeframe, training_data)

        except Exception as e:
            logger.error(f"Error triggering model training: {e}")

    async def _call_model_training(self, model_name: str, symbol: str,
                                   timeframe: TimeFrame, training_data: List[Any]):
        """
        Call model-specific training function

        Currently a placeholder that only logs; model-specific training calls
        still need to be wired in here.

        Args:
            model_name: Name of the model to train
            symbol: Trading symbol
            timeframe: Timeframe for training
            training_data: List of (prediction, reward) tuples
        """
        logger.debug(f"Training call for {model_name}: {len(training_data)} samples")

    def get_inference_statistics(self) -> Dict[str, Any]:
        """Get inference coordination statistics"""
        with self.lock:
            stats = self.inference_stats.copy()

            # Add scheduling information
            stats['symbols'] = list(self.symbols)
            stats['continuous_interval_seconds'] = self.schedule.continuous_interval_seconds
            stats['hourly_timeframes'] = [tf.value for tf in self.schedule.hourly_timeframes]
            stats['next_hourly_inferences'] = {
                symbol: timestamp.isoformat()
                for symbol, timestamp in self.next_hourly_inference.items()
            }

            # Add accuracy summary from reward calculator
            stats['accuracy_summary'] = self.reward_calculator.get_accuracy_summary()

            return stats

    def force_hourly_inference(self, symbol: Optional[str] = None):
        """
        Force immediate hourly inference for symbol(s)

        Must be called from the event loop thread while coordination is running.

        Args:
            symbol: Specific symbol (None for all symbols)
        """
        symbols_to_process = [symbol] if symbol else self.symbols

        async def _force_inference():
            current_time = datetime.now()
            for sym in symbols_to_process:
                await self._execute_hourly_inference(sym, current_time)

        # Schedule the inference
        if self.running:
            task = asyncio.create_task(_force_inference())
            # Keep a strong reference so the task cannot be garbage-collected
            # mid-run; it is also cancelled by stop_coordination().
            self.inference_tasks.append(task)
            task.add_done_callback(self._discard_task)
        else:
            logger.warning("Cannot force inference - coordinator not running")

    def get_prediction_history(self, symbol: str, timeframe: TimeFrame,
                               max_samples: int = 50) -> List[Any]:
        """
        Get prediction history for training

        Args:
            symbol: Trading symbol
            timeframe: Specific timeframe
            max_samples: Maximum samples to return

        Returns:
            List of training samples, as returned by the reward calculator's
            ``get_training_data``
        """
        return self.reward_calculator.get_training_data(symbol, timeframe, max_samples)