# gogo2/core/enhanced_reward_calculator.py
# Last updated 2025-08-26 22:26:42 +03:00 (496 lines, 19 KiB, Python)

"""
Enhanced Reward Calculator for Reinforcement Learning Training
This module implements a reward calculation system based on the squared error
between predicted and realized prices, combined with a direction-correctness
bonus/penalty and scaled by the model's confidence. It tracks multiple
timeframes separately and maintains prediction history for reward computation.
Key Features:
- MSE-based reward calculation for prediction accuracy
- Multi-timeframe prediction tracking (1s, 1m, 1h, 1d)
- Separate accuracy tracking for each timeframe
- Prediction history tracking (last 6 predictions per timeframe)
- Real-time training at each inference
- Timeframe-aware inference scheduling
"""
import time
import logging
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Any
from collections import deque
from datetime import datetime, timedelta
import numpy as np
import threading
from enum import Enum
# Module-scoped logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(__name__)
class TimeFrame(Enum):
    """Supported timeframes for prediction.

    Each member's value is the interval string (e.g. "1m"). The calculator
    uses it when building prediction IDs and as the per-timeframe key in
    accuracy summaries. Declaration order is also the enum's iteration order.
    """
    SECONDS_1 = "1s"
    MINUTES_1 = "1m"
    HOURS_1 = "1h"
    DAYS_1 = "1d"
@dataclass
class PredictionRecord:
    """Individual prediction record with outcome tracking.

    The leading fields describe the prediction as it was made. The outcome
    fields keep their ``None``/``False`` defaults until the record is
    evaluated against a later observed price.
    """
    timestamp: datetime  # when the prediction was recorded
    symbol: str
    timeframe: TimeFrame
    predicted_price: float
    predicted_direction: int # -1: down, 0: neutral, 1: up
    confidence: float
    current_price: float  # market price at the time of the prediction
    model_name: str
    # Optional state vector used for prediction/training (standardized feature/state)
    state_vector: Optional[list] = None
    # Outcome fields (set when outcome is determined)
    actual_price: Optional[float] = None
    actual_direction: Optional[int] = None
    outcome_timestamp: Optional[datetime] = None
    # Despite its name, EnhancedRewardCalculator stores the final combined
    # reward here (error term plus direction bonus, scaled by confidence),
    # not a raw MSE; it can be negative.
    mse_reward: Optional[float] = None
    direction_correct: Optional[bool] = None
    is_evaluated: bool = False
@dataclass
class TimeframeAccuracy:
    """Accuracy statistics for a single prediction timeframe.

    Holds hit/total counters, the summed MSE, and a bounded history of recent
    predictions. Both derived metrics report 0.0 until at least one
    prediction has been counted.
    """
    timeframe: TimeFrame
    total_predictions: int = 0
    correct_directions: int = 0
    total_mse: float = 0.0
    # Only the six most recent entries are retained; appending a seventh drops the oldest.
    prediction_history: deque = field(default_factory=lambda: deque(maxlen=6))

    @property
    def direction_accuracy(self) -> float:
        """Share of counted predictions with the correct direction, as a percentage."""
        count = self.total_predictions
        return 0.0 if count == 0 else (self.correct_directions / count) * 100.0

    @property
    def average_mse(self) -> float:
        """Accumulated MSE divided by the number of counted predictions."""
        count = self.total_predictions
        return 0.0 if count == 0 else self.total_mse / count
class EnhancedRewardCalculator:
    """
    Enhanced reward calculator using MSE and multi-timeframe tracking

    This calculator:
    1. Tracks predictions for multiple timeframes separately
    2. Calculates MSE-based rewards when outcomes are available
    3. Maintains prediction history for last 6 predictions per timeframe
    4. Provides separate accuracy metrics for each timeframe
    5. Enables real-time training at each inference

    Public methods that read or modify shared state do so while holding
    ``self.lock`` (a ``threading.RLock``).
    """

    # Capacity of each per-(symbol, timeframe) prediction buffer.
    PREDICTION_BUFFER_SIZE = 100
    # A cached price is only accepted as an outcome if its timestamp lies
    # within this many seconds of the evaluation target time.
    PRICE_MATCH_TOLERANCE_SECONDS = 30

    def __init__(self, symbols: Optional[List[str]] = None):
        """
        Initialize the enhanced reward calculator

        Args:
            symbols: Symbols to track up front (defaults to ETH/USDT and
                BTC/USDT). The sequence is copied, so symbols registered later
                by ``add_prediction`` never mutate the caller's list.
        """
        self.symbols = list(symbols) if symbols else ['ETH/USDT', 'BTC/USDT']
        self.timeframes = [TimeFrame.SECONDS_1, TimeFrame.MINUTES_1, TimeFrame.HOURS_1, TimeFrame.DAYS_1]

        # Prediction storage: symbol -> timeframe -> deque of PredictionRecord
        self.predictions: Dict[str, Dict[TimeFrame, deque]] = {}
        # Accuracy tracking: symbol -> timeframe -> TimeframeAccuracy
        self.accuracy_tracker: Dict[str, Dict[TimeFrame, TimeframeAccuracy]] = {}

        # Delay (seconds) after a prediction is made before it becomes due for
        # evaluation. Longer timeframes are evaluated well before a full bar.
        self.evaluation_timeouts = {
            TimeFrame.SECONDS_1: 5,     # Evaluate 1s predictions after 5 seconds
            TimeFrame.MINUTES_1: 60,    # Evaluate 1m predictions after 1 minute
            TimeFrame.HOURS_1: 300,     # Evaluate 1h predictions after 5 minutes
            TimeFrame.DAYS_1: 900       # Evaluate 1d predictions after 15 minutes
        }

        # Price data cache for outcome evaluation: symbol -> [(timestamp, price), ...]
        self.price_cache: Dict[str, List[Tuple[datetime, float]]] = {}
        self.price_cache_max_size = 1000

        # Thread safety (re-entrant lock)
        self.lock = threading.RLock()

        self._initialize_data_structures()

        logger.info(f"EnhancedRewardCalculator initialized for symbols: {self.symbols}")
        logger.info(f"Timeframes: {[tf.value for tf in self.timeframes]}")
        logger.info(f"Evaluation timeouts: {[(tf.value, timeout) for tf, timeout in self.evaluation_timeouts.items()]}")

    def _initialize_data_structures(self):
        """Create fresh, empty storage for every symbol in ``self.symbols``.

        Any existing state for those symbols is replaced.
        """
        for symbol in self.symbols:
            self.price_cache[symbol] = []
            self._initialize_symbol_tracking(symbol)

    def _initialize_symbol_tracking(self, symbol: str):
        """Create empty prediction buffers and accuracy trackers for one symbol."""
        self.predictions[symbol] = {
            timeframe: deque(maxlen=self.PREDICTION_BUFFER_SIZE)
            for timeframe in self.timeframes
        }
        self.accuracy_tracker[symbol] = {
            timeframe: TimeframeAccuracy(timeframe)
            for timeframe in self.timeframes
        }

    def _register_symbol(self, symbol: str):
        """Start tracking a previously unseen symbol without touching other symbols' state.

        Prices already received for the symbol via ``update_price`` are kept.
        """
        if symbol not in self.symbols:
            self.symbols.append(symbol)
        self.price_cache.setdefault(symbol, [])
        self._initialize_symbol_tracking(symbol)

    def add_prediction(self,
                       symbol: str,
                       timeframe: TimeFrame,
                       predicted_price: float,
                       predicted_direction: int,
                       confidence: float,
                       current_price: float,
                       model_name: str,
                       state_vector: Optional[list] = None,
                       predicted_return: Optional[float] = None) -> str:
        """
        Add a new prediction to track

        Args:
            symbol: Trading symbol (e.g., 'ETH/USDT'); unseen symbols are
                registered automatically
            timeframe: Timeframe for this prediction
            predicted_price: Model's predicted price
            predicted_direction: Predicted direction (-1, 0, 1)
            confidence: Model's confidence (0.0 to 1.0)
            current_price: Current market price
            model_name: Name of the model making prediction
            state_vector: Optional state/feature vector stored with the record
            predicted_return: Optional predicted fractional return. When given
                and ``current_price > 0``, the stored predicted price is
                ``current_price * (1 + predicted_return)`` instead of
                ``predicted_price``.

        Returns:
            Unique prediction ID for later reference
        """
        with self.lock:
            effective_price = predicted_price
            # Prefer the price implied by a predicted return, to avoid
            # synthetic price fabrication.
            if predicted_return is not None:
                try:
                    if current_price > 0:
                        effective_price = current_price * (1.0 + float(predicted_return))
                except (TypeError, ValueError, OverflowError) as exc:
                    logger.warning("Ignoring unusable predicted_return=%r for %s %s: %s",
                                   predicted_return, symbol, timeframe.value, exc)

            prediction = PredictionRecord(
                timestamp=datetime.now(),
                symbol=symbol,
                timeframe=timeframe,
                predicted_price=effective_price,
                predicted_direction=predicted_direction,
                confidence=confidence,
                current_price=current_price,
                model_name=model_name,
                state_vector=state_vector
            )

            # Register new symbols individually; re-running the full
            # initialization would wipe every other symbol's state.
            if symbol not in self.predictions:
                self._register_symbol(symbol)

            self.predictions[symbol][timeframe].append(prediction)
            self.accuracy_tracker[symbol][timeframe].prediction_history.append(prediction)

            prediction_id = f"{symbol}_{timeframe.value}_{prediction.timestamp.isoformat()}_{model_name}"
            logger.debug("Added prediction: %s, predicted_price=%.4f, direction=%s, confidence=%.3f",
                         prediction_id, prediction.predicted_price, predicted_direction, confidence)
            return prediction_id

    def update_price(self, symbol: str, price: float, timestamp: Optional[datetime] = None):
        """
        Update current price for a symbol

        Only the most recent ``price_cache_max_size`` entries are kept.

        Args:
            symbol: Trading symbol
            price: Current price
            timestamp: Price timestamp (defaults to now)
        """
        if timestamp is None:
            timestamp = datetime.now()

        with self.lock:
            cache = self.price_cache.setdefault(symbol, [])
            cache.append((timestamp, price))
            overflow = len(cache) - self.price_cache_max_size
            if overflow > 0:
                del cache[:overflow]

    def evaluate_predictions(self, symbol: Optional[str] = None) -> Dict[str, List[Tuple[PredictionRecord, float]]]:
        """
        Evaluate pending predictions and calculate rewards

        A prediction is due once its timeframe's evaluation timeout has
        elapsed. Due predictions without a usable cached price stay pending and
        are retried on a later call.

        Args:
            symbol: Specific symbol to evaluate (None for all symbols)

        Returns:
            Dictionary mapping symbol to list of (prediction, reward) tuples
        """
        results = {}
        with self.lock:
            symbols_to_evaluate = [symbol] if symbol else list(self.symbols)
            for sym in symbols_to_evaluate:
                if sym not in self.predictions:
                    continue
                results[sym] = []
                current_time = datetime.now()
                for timeframe in self.timeframes:
                    timeout = self.evaluation_timeouts[timeframe]
                    due = [
                        p for p in self.predictions[sym][timeframe]
                        if not p.is_evaluated
                        and (current_time - p.timestamp).total_seconds() >= timeout
                    ]
                    results[sym].extend(self._evaluate_and_track(sym, timeframe, due))
        return results

    def _evaluate_and_track(self, symbol: str, timeframe: TimeFrame,
                            candidates: List[PredictionRecord]) -> List[Tuple[PredictionRecord, float]]:
        """Evaluate ``candidates`` and update accuracy tracking for each one that succeeds.

        Returns:
            (prediction, reward) tuples for the successfully evaluated records
        """
        evaluated = []
        for prediction in candidates:
            reward = self._calculate_prediction_reward(prediction)
            if reward is None:
                continue
            evaluated.append((prediction, reward))
            self._update_accuracy_tracking(symbol, timeframe, prediction)
        return evaluated

    def _calculate_prediction_reward(self, prediction: PredictionRecord) -> Optional[float]:
        """
        Calculate the reward for a prediction and record its outcome

        reward = (exp(-5 * normalized_mse) + direction_bonus) * max(confidence, 0.2)

        normalized_mse is the squared price error divided by the square of 10%
        of the price at prediction time, clipped to 1. direction_bonus is +0.25
        for a correct direction and -1.0 otherwise. On success the outcome
        fields of ``prediction`` are filled in and it is marked evaluated.

        Args:
            prediction: Prediction record to evaluate

        Returns:
            Calculated reward, or None if no cached price is close enough to
            the evaluation time (the record is then left unevaluated)
        """
        target_time = prediction.timestamp + timedelta(seconds=self.evaluation_timeouts[prediction.timeframe])
        actual_price = self._get_price_at_time(prediction.symbol, target_time)
        if actual_price is None:
            logger.debug("Cannot evaluate prediction - no price data available for %s", prediction.symbol)
            return None

        # Realized direction relative to the price at prediction time
        price_change = actual_price - prediction.current_price
        actual_direction = 1 if price_change > 0 else (-1 if price_change < 0 else 0)

        # Squared error of the predicted price
        price_error = actual_price - prediction.predicted_price
        mse = price_error ** 2

        # Normalize against a 10% price move as the max expected error; the
        # exponential decay heavily penalizes large errors.
        max_mse = (prediction.current_price * 0.1) ** 2
        if max_mse > 0:
            normalized_mse = min(mse / max_mse, 1.0)
        else:
            # current_price == 0 gives no scale to normalize against; avoid a
            # ZeroDivisionError and treat any non-zero error as maximal.
            normalized_mse = 0.0 if mse == 0 else 1.0
        mse_reward = np.exp(-5 * normalized_mse)  # range [exp(-5), 1]

        # Wrong direction is punished harder than a correct one is rewarded
        direction_correct = (prediction.predicted_direction == actual_direction)
        direction_bonus = 0.25 if direction_correct else -1.0

        # Confidence scaling (floor avoids near-zero scaling)
        confidence_weight = max(prediction.confidence, 0.2)

        final_reward = float((mse_reward + direction_bonus) * confidence_weight)

        prediction.actual_price = actual_price
        prediction.actual_direction = actual_direction
        prediction.outcome_timestamp = datetime.now()
        prediction.mse_reward = final_reward
        prediction.direction_correct = direction_correct
        prediction.is_evaluated = True

        logger.debug("Evaluated prediction: %s %s, MSE=%.6f, direction_correct=%s, confidence=%.3f, reward=%.4f",
                     prediction.symbol, prediction.timeframe.value, mse, direction_correct,
                     confidence_weight, final_reward)
        return final_reward

    def _get_price_at_time(self, symbol: str, target_time: datetime) -> Optional[float]:
        """
        Get the cached price for symbol closest to a specific time

        Args:
            symbol: Trading symbol
            target_time: Target timestamp

        Returns:
            Price of the closest cached entry (earliest entry wins ties), or
            None if the cache is empty or the closest entry is more than
            ``PRICE_MATCH_TOLERANCE_SECONDS`` away from target_time
        """
        prices = self.price_cache.get(symbol)
        if not prices:
            return None

        def distance(entry: Tuple[datetime, float]) -> float:
            return abs((entry[0] - target_time).total_seconds())

        closest = min(prices, key=distance)
        if distance(closest) <= self.PRICE_MATCH_TOLERANCE_SECONDS:
            return closest[1]
        return None

    def _update_accuracy_tracking(self, symbol: str, timeframe: TimeFrame, prediction: PredictionRecord):
        """Add an evaluated prediction to its timeframe's accuracy counters.

        The squared price error is computed from the recorded prices. The
        stored reward cannot be inverted back to an MSE because it also
        contains the direction bonus and confidence scaling (and may be
        negative).
        """
        tracker = self.accuracy_tracker[symbol][timeframe]
        tracker.total_predictions += 1
        if prediction.direction_correct:
            tracker.correct_directions += 1
        if prediction.actual_price is not None:
            tracker.total_mse += (prediction.actual_price - prediction.predicted_price) ** 2

    def get_accuracy_summary(self, symbol: Optional[str] = None) -> Dict[str, Dict[str, Dict[str, float]]]:
        """
        Get accuracy summary for all or specific symbol

        Args:
            symbol: Specific symbol (None for all)

        Returns:
            Nested dictionary: symbol -> timeframe value -> metric name -> value
        """
        summary = {}
        with self.lock:
            symbols_to_summarize = [symbol] if symbol else list(self.symbols)
            for sym in symbols_to_summarize:
                trackers = self.accuracy_tracker.get(sym)
                if trackers is None:
                    continue
                summary[sym] = {
                    timeframe.value: {
                        'total_predictions': trackers[timeframe].total_predictions,
                        'direction_accuracy': trackers[timeframe].direction_accuracy,
                        'average_mse': trackers[timeframe].average_mse,
                        'recent_predictions': len(trackers[timeframe].prediction_history)
                    }
                    for timeframe in self.timeframes
                }
        return summary

    def get_training_data(self, symbol: str, timeframe: TimeFrame,
                          max_samples: int = 50) -> List[Tuple[PredictionRecord, float]]:
        """
        Get recent evaluated predictions for training

        Args:
            symbol: Trading symbol
            timeframe: Specific timeframe
            max_samples: Maximum number of samples to return (<= 0 yields an
                empty list)

        Returns:
            List of (prediction, reward) tuples ready for training, oldest first
        """
        with self.lock:
            by_timeframe = self.predictions.get(symbol)
            # Guard max_samples <= 0 explicitly: a [-0:] slice would return everything.
            if not by_timeframe or timeframe not in by_timeframe or max_samples <= 0:
                return []
            evaluated = [
                (p, p.mse_reward) for p in by_timeframe[timeframe]
                if p.is_evaluated and p.mse_reward is not None
            ]
            return evaluated[-max_samples:]

    def cleanup_old_predictions(self, days_to_keep: int = 7):
        """
        Clean up old predictions to manage memory

        Args:
            days_to_keep: Number of days of predictions to keep
        """
        cutoff_time = datetime.now() - timedelta(days=days_to_keep)

        with self.lock:
            for symbol, by_timeframe in self.predictions.items():
                for timeframe in self.timeframes:
                    old_buffer = by_timeframe[timeframe]
                    # Rebuild with the same capacity as the buffer being replaced
                    kept = deque((p for p in old_buffer if p.timestamp > cutoff_time),
                                 maxlen=old_buffer.maxlen)
                    by_timeframe[timeframe] = kept
                    removed_count = len(old_buffer) - len(kept)
                    if removed_count > 0:
                        logger.info("Cleaned up %d old predictions for %s %s",
                                    removed_count, symbol, timeframe.value)

    def force_evaluate_timeframe_predictions(self, symbol: str, timeframe: TimeFrame) -> List[Tuple[PredictionRecord, float]]:
        """
        Force evaluation of all pending predictions for a specific timeframe

        Useful for immediate training needs. Predictions are evaluated
        regardless of whether their timeout has elapsed, but still need a
        cached price near their evaluation target time.

        Args:
            symbol: Trading symbol
            timeframe: Specific timeframe to evaluate

        Returns:
            List of (prediction, reward) tuples
        """
        with self.lock:
            by_timeframe = self.predictions.get(symbol)
            if not by_timeframe or timeframe not in by_timeframe:
                return []
            pending = [p for p in by_timeframe[timeframe] if not p.is_evaluated]
            return self._evaluate_and_track(symbol, timeframe, pending)