""" Enhanced Reward Calculator for Reinforcement Learning Training This module implements a comprehensive reward calculation system based on mean squared error between predictions and empirical outcomes. It tracks multiple timeframes separately and maintains prediction history for accurate reward computation. Key Features: - MSE-based reward calculation for prediction accuracy - Multi-timeframe prediction tracking (1s, 1m, 1h, 1d) - Separate accuracy tracking for each timeframe - Prediction history tracking (last 6 predictions per timeframe) - Real-time training at each inference - Timeframe-aware inference scheduling """ import time import logging from dataclasses import dataclass, field from typing import Dict, List, Optional, Tuple, Any from collections import deque from datetime import datetime, timedelta import numpy as np import threading from enum import Enum logger = logging.getLogger(__name__) class TimeFrame(Enum): """Supported timeframes for prediction""" SECONDS_1 = "1s" MINUTES_1 = "1m" HOURS_1 = "1h" DAYS_1 = "1d" @dataclass class PredictionRecord: """Individual prediction record with outcome tracking""" timestamp: datetime symbol: str timeframe: TimeFrame predicted_price: float predicted_direction: int # -1: down, 0: neutral, 1: up confidence: float current_price: float model_name: str # Optional state vector used for prediction/training (standardized feature/state) state_vector: Optional[list] = None # Outcome fields (set when outcome is determined) actual_price: Optional[float] = None actual_direction: Optional[int] = None outcome_timestamp: Optional[datetime] = None mse_reward: Optional[float] = None direction_correct: Optional[bool] = None is_evaluated: bool = False @dataclass class TimeframeAccuracy: """Accuracy tracking for a specific timeframe""" timeframe: TimeFrame total_predictions: int = 0 correct_directions: int = 0 total_mse: float = 0.0 prediction_history: deque = field(default_factory=lambda: deque(maxlen=6)) @property def direction_accuracy(self) -> float: """Calculate directional accuracy percentage""" if self.total_predictions == 0: return 0.0 return (self.correct_directions / self.total_predictions) * 100.0 @property def average_mse(self) -> float: """Calculate average MSE""" if self.total_predictions == 0: return 0.0 return self.total_mse / self.total_predictions class EnhancedRewardCalculator: """ Enhanced reward calculator using MSE and multi-timeframe tracking This calculator: 1. Tracks predictions for multiple timeframes separately 2. Calculates MSE-based rewards when outcomes are available 3. Maintains prediction history for last 6 predictions per timeframe 4. Provides separate accuracy metrics for each timeframe 5. Enables real-time training at each inference """ def __init__(self, symbols: List[str] = None): """Initialize the enhanced reward calculator""" self.symbols = symbols or ['ETH/USDT', 'BTC/USDT'] self.timeframes = [TimeFrame.SECONDS_1, TimeFrame.MINUTES_1, TimeFrame.HOURS_1, TimeFrame.DAYS_1] # Prediction storage: symbol -> timeframe -> deque of PredictionRecord self.predictions: Dict[str, Dict[TimeFrame, deque]] = {} # Accuracy tracking: symbol -> timeframe -> TimeframeAccuracy self.accuracy_tracker: Dict[str, Dict[TimeFrame, TimeframeAccuracy]] = {} # Evaluation timeouts for each timeframe (in seconds) self.evaluation_timeouts = { TimeFrame.SECONDS_1: 5, # Evaluate 1s predictions after 5 seconds TimeFrame.MINUTES_1: 60, # Evaluate 1m predictions after 1 minute TimeFrame.HOURS_1: 300, # Evaluate 1h predictions after 5 minutes TimeFrame.DAYS_1: 900 # Evaluate 1d predictions after 15 minutes } # Price data cache for outcome evaluation self.price_cache: Dict[str, List[Tuple[datetime, float]]] = {} self.price_cache_max_size = 1000 # Thread safety self.lock = threading.RLock() # Initialize data structures self._initialize_data_structures() logger.info(f"EnhancedRewardCalculator initialized for symbols: {self.symbols}") logger.info(f"Timeframes: {[tf.value for tf in self.timeframes]}") logger.info(f"Evaluation timeouts: {[(tf.value, timeout) for tf, timeout in self.evaluation_timeouts.items()]}") def _initialize_data_structures(self): """Initialize nested data structures""" for symbol in self.symbols: self.predictions[symbol] = {} self.accuracy_tracker[symbol] = {} self.price_cache[symbol] = [] for timeframe in self.timeframes: self.predictions[symbol][timeframe] = deque(maxlen=100) # Keep last 100 predictions self.accuracy_tracker[symbol][timeframe] = TimeframeAccuracy(timeframe) def add_prediction(self, symbol: str, timeframe: TimeFrame, predicted_price: float, predicted_return: Optional[float] = None, predicted_direction: int, confidence: float, current_price: float, model_name: str, state_vector: Optional[list] = None) -> str: """ Add a new prediction to track Args: symbol: Trading symbol (e.g., 'ETH/USDT') timeframe: Timeframe for this prediction predicted_price: Model's predicted price predicted_direction: Predicted direction (-1, 0, 1) confidence: Model's confidence (0.0 to 1.0) current_price: Current market price model_name: Name of the model making prediction Returns: Unique prediction ID for later reference """ with self.lock: prediction = PredictionRecord( timestamp=datetime.now(), symbol=symbol, timeframe=timeframe, predicted_price=predicted_price, predicted_direction=predicted_direction, confidence=confidence, current_price=current_price, model_name=model_name, state_vector=state_vector ) # If predicted_return provided, prefer computing implied predicted_price # to avoid synthetic price fabrication try: if predicted_return is not None and current_price > 0: prediction.predicted_price = current_price * (1.0 + float(predicted_return)) except Exception: pass # Store prediction if symbol not in self.predictions: self._initialize_data_structures() self.predictions[symbol][timeframe].append(prediction) # Add to accuracy tracker history self.accuracy_tracker[symbol][timeframe].prediction_history.append(prediction) prediction_id = f"{symbol}_{timeframe.value}_{prediction.timestamp.isoformat()}_{model_name}" logger.debug(f"Added prediction: {prediction_id}, predicted_price={predicted_price:.4f}, " f"direction={predicted_direction}, confidence={confidence:.3f}") return prediction_id def update_price(self, symbol: str, price: float, timestamp: datetime = None): """ Update current price for a symbol Args: symbol: Trading symbol price: Current price timestamp: Price timestamp (defaults to now) """ if timestamp is None: timestamp = datetime.now() with self.lock: if symbol not in self.price_cache: self.price_cache[symbol] = [] self.price_cache[symbol].append((timestamp, price)) # Maintain cache size if len(self.price_cache[symbol]) > self.price_cache_max_size: self.price_cache[symbol] = self.price_cache[symbol][-self.price_cache_max_size:] def evaluate_predictions(self, symbol: str = None) -> Dict[str, List[Tuple[PredictionRecord, float]]]: """ Evaluate pending predictions and calculate rewards Args: symbol: Specific symbol to evaluate (None for all symbols) Returns: Dictionary mapping symbol to list of (prediction, reward) tuples """ results = {} symbols_to_evaluate = [symbol] if symbol else self.symbols with self.lock: for sym in symbols_to_evaluate: if sym not in self.predictions: continue results[sym] = [] current_time = datetime.now() for timeframe in self.timeframes: predictions_to_evaluate = [] # Find predictions ready for evaluation for prediction in self.predictions[sym][timeframe]: if prediction.is_evaluated: continue time_elapsed = (current_time - prediction.timestamp).total_seconds() timeout = self.evaluation_timeouts[timeframe] if time_elapsed >= timeout: predictions_to_evaluate.append(prediction) # Evaluate predictions for prediction in predictions_to_evaluate: reward = self._calculate_prediction_reward(prediction) if reward is not None: results[sym].append((prediction, reward)) # Update accuracy tracking self._update_accuracy_tracking(sym, timeframe, prediction) return results def _calculate_prediction_reward(self, prediction: PredictionRecord) -> Optional[float]: """ Calculate MSE-based reward for a prediction Args: prediction: Prediction record to evaluate Returns: Calculated reward or None if outcome cannot be determined """ # Get actual price at evaluation time actual_price = self._get_price_at_time( prediction.symbol, prediction.timestamp + timedelta(seconds=self.evaluation_timeouts[prediction.timeframe]) ) if actual_price is None: logger.debug(f"Cannot evaluate prediction - no price data available for {prediction.symbol}") return None # Calculate price change and direction price_change = actual_price - prediction.current_price actual_direction = 1 if price_change > 0 else (-1 if price_change < 0 else 0) # Calculate MSE reward price_error = actual_price - prediction.predicted_price mse = price_error ** 2 # Normalize MSE to a reasonable reward scale (lower MSE = higher reward) # Use exponential decay to heavily penalize large errors max_mse = (prediction.current_price * 0.1) ** 2 # 10% price change as max expected error normalized_mse = min(mse / max_mse, 1.0) mse_reward = np.exp(-5 * normalized_mse) # Exponential decay, range [exp(-5), 1] # Direction accuracy bonus/penalty direction_correct = (prediction.predicted_direction == actual_direction) direction_bonus = 0.5 if direction_correct else -0.5 # Confidence scaling confidence_weight = prediction.confidence # Final reward calculation base_reward = mse_reward + direction_bonus final_reward = base_reward * confidence_weight # Update prediction record prediction.actual_price = actual_price prediction.actual_direction = actual_direction prediction.outcome_timestamp = datetime.now() prediction.mse_reward = final_reward prediction.direction_correct = direction_correct prediction.is_evaluated = True logger.debug(f"Evaluated prediction: {prediction.symbol} {prediction.timeframe.value}, " f"MSE={mse:.6f}, direction_correct={direction_correct}, " f"confidence={confidence_weight:.3f}, reward={final_reward:.4f}") return final_reward def _get_price_at_time(self, symbol: str, target_time: datetime) -> Optional[float]: """ Get price for symbol at a specific time Args: symbol: Trading symbol target_time: Target timestamp Returns: Price at target time or None if not available """ if symbol not in self.price_cache or not self.price_cache[symbol]: return None # Find closest price to target time closest_price = None min_time_diff = float('inf') for timestamp, price in self.price_cache[symbol]: time_diff = abs((timestamp - target_time).total_seconds()) if time_diff < min_time_diff: min_time_diff = time_diff closest_price = price # Only return price if it's within reasonable time window (30 seconds) if min_time_diff <= 30: return closest_price return None def _update_accuracy_tracking(self, symbol: str, timeframe: TimeFrame, prediction: PredictionRecord): """Update accuracy tracking for a timeframe""" tracker = self.accuracy_tracker[symbol][timeframe] tracker.total_predictions += 1 if prediction.direction_correct: tracker.correct_directions += 1 if prediction.mse_reward is not None: # Convert reward back to MSE for tracking # Since reward = exp(-5 * normalized_mse), we can reverse it normalized_mse = -np.log(max(prediction.mse_reward, 0.001)) / 5 max_mse = (prediction.current_price * 0.1) ** 2 mse = normalized_mse * max_mse tracker.total_mse += mse def get_accuracy_summary(self, symbol: str = None) -> Dict[str, Dict[str, Dict[str, float]]]: """ Get accuracy summary for all or specific symbol Args: symbol: Specific symbol (None for all) Returns: Nested dictionary with accuracy metrics """ summary = {} symbols_to_summarize = [symbol] if symbol else self.symbols with self.lock: for sym in symbols_to_summarize: if sym not in self.accuracy_tracker: continue summary[sym] = {} for timeframe in self.timeframes: tracker = self.accuracy_tracker[sym][timeframe] summary[sym][timeframe.value] = { 'total_predictions': tracker.total_predictions, 'direction_accuracy': tracker.direction_accuracy, 'average_mse': tracker.average_mse, 'recent_predictions': len(tracker.prediction_history) } return summary def get_training_data(self, symbol: str, timeframe: TimeFrame, max_samples: int = 50) -> List[Tuple[PredictionRecord, float]]: """ Get recent evaluated predictions for training Args: symbol: Trading symbol timeframe: Specific timeframe max_samples: Maximum number of samples to return Returns: List of (prediction, reward) tuples ready for training """ training_data = [] with self.lock: if symbol not in self.predictions or timeframe not in self.predictions[symbol]: return training_data evaluated_predictions = [ p for p in self.predictions[symbol][timeframe] if p.is_evaluated and p.mse_reward is not None ] # Get most recent evaluated predictions recent_predictions = list(evaluated_predictions)[-max_samples:] for prediction in recent_predictions: training_data.append((prediction, prediction.mse_reward)) return training_data def cleanup_old_predictions(self, days_to_keep: int = 7): """ Clean up old predictions to manage memory Args: days_to_keep: Number of days of predictions to keep """ cutoff_time = datetime.now() - timedelta(days=days_to_keep) with self.lock: for symbol in self.predictions: for timeframe in self.timeframes: # Filter out old predictions old_count = len(self.predictions[symbol][timeframe]) self.predictions[symbol][timeframe] = deque( [p for p in self.predictions[symbol][timeframe] if p.timestamp > cutoff_time], maxlen=100 ) new_count = len(self.predictions[symbol][timeframe]) removed_count = old_count - new_count if removed_count > 0: logger.info(f"Cleaned up {removed_count} old predictions for " f"{symbol} {timeframe.value}") def force_evaluate_timeframe_predictions(self, symbol: str, timeframe: TimeFrame) -> List[Tuple[PredictionRecord, float]]: """ Force evaluation of all pending predictions for a specific timeframe Useful for immediate training needs Args: symbol: Trading symbol timeframe: Specific timeframe to evaluate Returns: List of (prediction, reward) tuples """ results = [] with self.lock: if symbol not in self.predictions or timeframe not in self.predictions[symbol]: return results # Evaluate all non-evaluated predictions for prediction in self.predictions[symbol][timeframe]: if not prediction.is_evaluated: reward = self._calculate_prediction_reward(prediction) if reward is not None: results.append((prediction, reward)) self._update_accuracy_tracking(symbol, timeframe, prediction) return results