"""
|
|
Enhanced Reward Calculator for Reinforcement Learning Training
|
|
|
|
This module implements a comprehensive reward calculation system based on mean squared error
|
|
between predictions and empirical outcomes. It tracks multiple timeframes separately and
|
|
maintains prediction history for accurate reward computation.
|
|
|
|
Key Features:
|
|
- MSE-based reward calculation for prediction accuracy
|
|
- Multi-timeframe prediction tracking (1s, 1m, 1h, 1d)
|
|
- Separate accuracy tracking for each timeframe
|
|
- Prediction history tracking (last 6 predictions per timeframe)
|
|
- Real-time training at each inference
|
|
- Timeframe-aware inference scheduling
|
|
"""
|
|
|
|
import time
|
|
import logging
|
|
from dataclasses import dataclass, field
|
|
from typing import Dict, List, Optional, Tuple, Any
|
|
from collections import deque
|
|
from datetime import datetime, timedelta
|
|
import numpy as np
|
|
import threading
|
|
from enum import Enum
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|


class TimeFrame(Enum):
    """Supported timeframes for prediction"""
    SECONDS_1 = "1s"
    MINUTES_1 = "1m"
    HOURS_1 = "1h"
    DAYS_1 = "1d"


@dataclass
class PredictionRecord:
    """Individual prediction record with outcome tracking"""
    timestamp: datetime
    symbol: str
    timeframe: TimeFrame
    predicted_price: float
    predicted_direction: int  # -1: down, 0: neutral, 1: up
    confidence: float
    current_price: float
    model_name: str
    # Optional state vector used for prediction/training (standardized feature/state)
    state_vector: Optional[list] = None

    # Outcome fields (set when the outcome is determined)
    actual_price: Optional[float] = None
    actual_direction: Optional[int] = None
    outcome_timestamp: Optional[datetime] = None
    mse_reward: Optional[float] = None
    direction_correct: Optional[bool] = None
    is_evaluated: bool = False


@dataclass
class TimeframeAccuracy:
    """Accuracy tracking for a specific timeframe"""
    timeframe: TimeFrame
    total_predictions: int = 0
    correct_directions: int = 0
    total_mse: float = 0.0
    prediction_history: deque = field(default_factory=lambda: deque(maxlen=6))

    @property
    def direction_accuracy(self) -> float:
        """Calculate directional accuracy as a percentage"""
        if self.total_predictions == 0:
            return 0.0
        return (self.correct_directions / self.total_predictions) * 100.0

    @property
    def average_mse(self) -> float:
        """Calculate the average MSE across evaluated predictions"""
        if self.total_predictions == 0:
            return 0.0
        return self.total_mse / self.total_predictions


class EnhancedRewardCalculator:
    """
    Enhanced reward calculator using MSE and multi-timeframe tracking

    This calculator:
    1. Tracks predictions for multiple timeframes separately
    2. Calculates MSE-based rewards when outcomes become available
    3. Maintains a history of the last 6 predictions per timeframe
    4. Provides separate accuracy metrics for each timeframe
    5. Enables real-time training at each inference
    """

    def __init__(self, symbols: Optional[List[str]] = None):
        """Initialize the enhanced reward calculator"""
        self.symbols = symbols or ['ETH/USDT', 'BTC/USDT']
        self.timeframes = [TimeFrame.SECONDS_1, TimeFrame.MINUTES_1, TimeFrame.HOURS_1, TimeFrame.DAYS_1]

        # Prediction storage: symbol -> timeframe -> deque of PredictionRecord
        self.predictions: Dict[str, Dict[TimeFrame, deque]] = {}

        # Accuracy tracking: symbol -> timeframe -> TimeframeAccuracy
        self.accuracy_tracker: Dict[str, Dict[TimeFrame, TimeframeAccuracy]] = {}

        # Evaluation timeouts for each timeframe (in seconds)
        self.evaluation_timeouts = {
            TimeFrame.SECONDS_1: 5,    # Evaluate 1s predictions after 5 seconds
            TimeFrame.MINUTES_1: 60,   # Evaluate 1m predictions after 1 minute
            TimeFrame.HOURS_1: 300,    # Evaluate 1h predictions after 5 minutes
            TimeFrame.DAYS_1: 900      # Evaluate 1d predictions after 15 minutes
        }
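        # Note: the 1h and 1d horizons above are much shorter than their nominal
        # timeframes, presumably so rewards arrive quickly enough for real-time
        # training; they grade early progress toward the predicted move rather
        # than the fully realized outcome.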

        # Price data cache for outcome evaluation
        self.price_cache: Dict[str, List[Tuple[datetime, float]]] = {}
        self.price_cache_max_size = 1000

        # Thread safety
        self.lock = threading.RLock()

        # Initialize data structures
        self._initialize_data_structures()

        logger.info(f"EnhancedRewardCalculator initialized for symbols: {self.symbols}")
        logger.info(f"Timeframes: {[tf.value for tf in self.timeframes]}")
        logger.info(f"Evaluation timeouts: {[(tf.value, timeout) for tf, timeout in self.evaluation_timeouts.items()]}")

    def _initialize_data_structures(self):
        """Initialize nested data structures"""
        for symbol in self.symbols:
            self.predictions[symbol] = {}
            self.accuracy_tracker[symbol] = {}
            self.price_cache[symbol] = []

            for timeframe in self.timeframes:
                self.predictions[symbol][timeframe] = deque(maxlen=100)  # Keep last 100 predictions
                self.accuracy_tracker[symbol][timeframe] = TimeframeAccuracy(timeframe)

    def add_prediction(self,
                       symbol: str,
                       timeframe: TimeFrame,
                       predicted_price: float,
                       predicted_direction: int,
                       confidence: float,
                       current_price: float,
                       model_name: str,
                       predicted_return: Optional[float] = None,
                       state_vector: Optional[list] = None) -> str:
        """
        Add a new prediction to track

        Args:
            symbol: Trading symbol (e.g., 'ETH/USDT')
            timeframe: Timeframe for this prediction
            predicted_price: Model's predicted price
            predicted_direction: Predicted direction (-1, 0, 1)
            confidence: Model's confidence (0.0 to 1.0)
            current_price: Current market price
            model_name: Name of the model making the prediction
            predicted_return: Optional predicted fractional return; when provided,
                it overrides predicted_price with the implied price
            state_vector: Optional standardized feature/state vector used for
                prediction/training

        Returns:
            Unique prediction ID for later reference
        """
        with self.lock:
            prediction = PredictionRecord(
                timestamp=datetime.now(),
                symbol=symbol,
                timeframe=timeframe,
                predicted_price=predicted_price,
                predicted_direction=predicted_direction,
                confidence=confidence,
                current_price=current_price,
                model_name=model_name,
                state_vector=state_vector
            )

            # If predicted_return is provided, prefer computing the implied predicted
            # price from it, to avoid fabricating a synthetic price
            try:
                if predicted_return is not None and current_price > 0:
                    prediction.predicted_price = current_price * (1.0 + float(predicted_return))
            except Exception:
                pass
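            # Illustrative conversion (made-up numbers): current_price=3000.0 with
            # predicted_return=0.004 implies predicted_price = 3000.0 * 1.004 = 3012.0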

            # Store the prediction, lazily creating structures for unseen symbols
            # (re-running _initialize_data_structures here would wipe existing history
            # and still miss symbols outside self.symbols)
            if symbol not in self.predictions:
                self.predictions[symbol] = {tf: deque(maxlen=100) for tf in self.timeframes}
                self.accuracy_tracker[symbol] = {tf: TimeframeAccuracy(tf) for tf in self.timeframes}
                self.price_cache[symbol] = []
                if symbol not in self.symbols:
                    self.symbols.append(symbol)

            self.predictions[symbol][timeframe].append(prediction)

            # Add to the accuracy tracker history
            self.accuracy_tracker[symbol][timeframe].prediction_history.append(prediction)

            prediction_id = f"{symbol}_{timeframe.value}_{prediction.timestamp.isoformat()}_{model_name}"

            logger.debug(f"Added prediction: {prediction_id}, predicted_price={prediction.predicted_price:.4f}, "
                         f"direction={predicted_direction}, confidence={confidence:.3f}")

            return prediction_id

    def update_price(self, symbol: str, price: float, timestamp: Optional[datetime] = None):
        """
        Update the current price for a symbol

        Args:
            symbol: Trading symbol
            price: Current price
            timestamp: Price timestamp (defaults to now)
        """
        if timestamp is None:
            timestamp = datetime.now()

        with self.lock:
            if symbol not in self.price_cache:
                self.price_cache[symbol] = []

            self.price_cache[symbol].append((timestamp, price))

            # Maintain the cache size
            if len(self.price_cache[symbol]) > self.price_cache_max_size:
                self.price_cache[symbol] = self.price_cache[symbol][-self.price_cache_max_size:]

    def evaluate_predictions(self, symbol: Optional[str] = None) -> Dict[str, List[Tuple[PredictionRecord, float]]]:
        """
        Evaluate pending predictions and calculate rewards

        Args:
            symbol: Specific symbol to evaluate (None for all symbols)

        Returns:
            Dictionary mapping symbol to list of (prediction, reward) tuples
        """
        results = {}
        symbols_to_evaluate = [symbol] if symbol else self.symbols

        with self.lock:
            for sym in symbols_to_evaluate:
                if sym not in self.predictions:
                    continue

                results[sym] = []
                current_time = datetime.now()

                for timeframe in self.timeframes:
                    predictions_to_evaluate = []

                    # Find predictions whose evaluation timeout has elapsed
                    for prediction in self.predictions[sym][timeframe]:
                        if prediction.is_evaluated:
                            continue

                        time_elapsed = (current_time - prediction.timestamp).total_seconds()
                        timeout = self.evaluation_timeouts[timeframe]

                        if time_elapsed >= timeout:
                            predictions_to_evaluate.append(prediction)

                    # Evaluate the matured predictions
                    for prediction in predictions_to_evaluate:
                        reward = self._calculate_prediction_reward(prediction)
                        if reward is not None:
                            results[sym].append((prediction, reward))

                            # Update accuracy tracking
                            self._update_accuracy_tracking(sym, timeframe, prediction)

        return results

    def _calculate_prediction_reward(self, prediction: PredictionRecord) -> Optional[float]:
        """
        Calculate the MSE-based reward for a prediction

        Args:
            prediction: Prediction record to evaluate

        Returns:
            Calculated reward, or None if the outcome cannot be determined
        """
        # Get the actual price at evaluation time
        actual_price = self._get_price_at_time(
            prediction.symbol,
            prediction.timestamp + timedelta(seconds=self.evaluation_timeouts[prediction.timeframe])
        )

        if actual_price is None:
            logger.debug(f"Cannot evaluate prediction - no price data available for {prediction.symbol}")
            return None

        # Calculate the price change and its direction
        price_change = actual_price - prediction.current_price
        actual_direction = 1 if price_change > 0 else (-1 if price_change < 0 else 0)

        # Calculate the MSE reward
        price_error = actual_price - prediction.predicted_price
        mse = price_error ** 2

        # Normalize the MSE to a reasonable reward scale (lower MSE = higher reward),
        # using exponential decay to heavily penalize large errors
        max_mse = (prediction.current_price * 0.1) ** 2  # 10% price change as the max expected error
        normalized_mse = min(mse / max_mse, 1.0)
        mse_reward = np.exp(-5 * normalized_mse)  # Exponential decay, range [exp(-5), 1]

        # Direction accuracy bonus/penalty (stronger punishment for a wrong direction)
        direction_correct = (prediction.predicted_direction == actual_direction)
        # Larger wrong-direction penalty; slightly reduced correct-direction bonus
        direction_bonus = 0.25 if direction_correct else -1.0

        # Confidence scaling (with a floor to avoid near-zero scaling)
        confidence_weight = max(prediction.confidence, 0.2)

        # Final reward calculation
        base_reward = mse_reward + direction_bonus
        final_reward = base_reward * confidence_weight
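
        # Worked example with made-up numbers: current_price=3000.0, predicted_price=3010.0,
        # actual_price=3005.0 gives mse=25.0 and max_mse=300.0**2=90000.0, so
        # normalized_mse ~= 0.000278 and mse_reward ~= exp(-0.00139) ~= 0.9986; a correct
        # direction call (+0.25) at confidence 0.6 yields (0.9986 + 0.25) * 0.6 ~= 0.749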

        # Update the prediction record with the outcome
        prediction.actual_price = actual_price
        prediction.actual_direction = actual_direction
        prediction.outcome_timestamp = datetime.now()
        prediction.mse_reward = final_reward
        prediction.direction_correct = direction_correct
        prediction.is_evaluated = True

        logger.debug(f"Evaluated prediction: {prediction.symbol} {prediction.timeframe.value}, "
                     f"MSE={mse:.6f}, direction_correct={direction_correct}, "
                     f"confidence={confidence_weight:.3f}, reward={final_reward:.4f}")

        return final_reward

    def _get_price_at_time(self, symbol: str, target_time: datetime) -> Optional[float]:
        """
        Get the price for a symbol at a specific time

        Args:
            symbol: Trading symbol
            target_time: Target timestamp

        Returns:
            Price at the target time, or None if not available
        """
        if symbol not in self.price_cache or not self.price_cache[symbol]:
            return None

        # Find the cached price closest to the target time
        closest_price = None
        min_time_diff = float('inf')

        for timestamp, price in self.price_cache[symbol]:
            time_diff = abs((timestamp - target_time).total_seconds())
            if time_diff < min_time_diff:
                min_time_diff = time_diff
                closest_price = price

        # Only return a price if it falls within a reasonable time window (30 seconds)
        if min_time_diff <= 30:
            return closest_price

        return None

    def _update_accuracy_tracking(self, symbol: str, timeframe: TimeFrame, prediction: PredictionRecord):
        """Update accuracy tracking for a timeframe"""
        tracker = self.accuracy_tracker[symbol][timeframe]

        tracker.total_predictions += 1
        if prediction.direction_correct:
            tracker.correct_directions += 1

        if prediction.mse_reward is not None:
            # Approximately recover the MSE by inverting reward = exp(-5 * normalized_mse).
            # This is only an estimate: mse_reward stores the final shaped reward, which
            # also includes the direction bonus/penalty and confidence scaling, and
            # non-positive rewards are clamped to 0.001 before taking the log.
            normalized_mse = -np.log(max(prediction.mse_reward, 0.001)) / 5
            max_mse = (prediction.current_price * 0.1) ** 2
            mse = normalized_mse * max_mse
            tracker.total_mse += mse

    def get_accuracy_summary(self, symbol: Optional[str] = None) -> Dict[str, Dict[str, Dict[str, float]]]:
        """
        Get an accuracy summary for all symbols or a specific symbol

        Args:
            symbol: Specific symbol (None for all)

        Returns:
            Nested dictionary with accuracy metrics
        """
        summary = {}
        symbols_to_summarize = [symbol] if symbol else self.symbols

        with self.lock:
            for sym in symbols_to_summarize:
                if sym not in self.accuracy_tracker:
                    continue

                summary[sym] = {}

                for timeframe in self.timeframes:
                    tracker = self.accuracy_tracker[sym][timeframe]
                    summary[sym][timeframe.value] = {
                        'total_predictions': tracker.total_predictions,
                        'direction_accuracy': tracker.direction_accuracy,
                        'average_mse': tracker.average_mse,
                        'recent_predictions': len(tracker.prediction_history)
                    }

        return summary
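
    # Illustrative shape of get_accuracy_summary() output (values are made up):
    # {'ETH/USDT': {'1s': {'total_predictions': 42, 'direction_accuracy': 54.8,
    #                      'average_mse': 3.1, 'recent_predictions': 6},
    #               '1m': {...}, '1h': {...}, '1d': {...}},
    #  'BTC/USDT': {...}}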

    def get_training_data(self, symbol: str, timeframe: TimeFrame,
                          max_samples: int = 50) -> List[Tuple[PredictionRecord, float]]:
        """
        Get recent evaluated predictions for training

        Args:
            symbol: Trading symbol
            timeframe: Specific timeframe
            max_samples: Maximum number of samples to return

        Returns:
            List of (prediction, reward) tuples ready for training
        """
        training_data = []

        with self.lock:
            if symbol not in self.predictions or timeframe not in self.predictions[symbol]:
                return training_data

            evaluated_predictions = [
                p for p in self.predictions[symbol][timeframe]
                if p.is_evaluated and p.mse_reward is not None
            ]

            # Keep only the most recent evaluated predictions
            recent_predictions = evaluated_predictions[-max_samples:]

            for prediction in recent_predictions:
                training_data.append((prediction, prediction.mse_reward))

        return training_data

    def cleanup_old_predictions(self, days_to_keep: int = 7):
        """
        Clean up old predictions to manage memory

        Args:
            days_to_keep: Number of days of predictions to keep
        """
        cutoff_time = datetime.now() - timedelta(days=days_to_keep)

        with self.lock:
            for symbol in self.predictions:
                for timeframe in self.timeframes:
                    # Filter out old predictions
                    old_count = len(self.predictions[symbol][timeframe])

                    self.predictions[symbol][timeframe] = deque(
                        [p for p in self.predictions[symbol][timeframe]
                         if p.timestamp > cutoff_time],
                        maxlen=100
                    )

                    new_count = len(self.predictions[symbol][timeframe])
                    removed_count = old_count - new_count

                    if removed_count > 0:
                        logger.info(f"Cleaned up {removed_count} old predictions for "
                                    f"{symbol} {timeframe.value}")

    def force_evaluate_timeframe_predictions(self, symbol: str, timeframe: TimeFrame) -> List[Tuple[PredictionRecord, float]]:
        """
        Force evaluation of all pending predictions for a specific timeframe,
        regardless of whether their evaluation timeout has elapsed.
        Useful for immediate training needs.

        Args:
            symbol: Trading symbol
            timeframe: Specific timeframe to evaluate

        Returns:
            List of (prediction, reward) tuples
        """
        results = []

        with self.lock:
            if symbol not in self.predictions or timeframe not in self.predictions[symbol]:
                return results

            # Evaluate every prediction that has not been evaluated yet
            for prediction in self.predictions[symbol][timeframe]:
                if not prediction.is_evaluated:
                    reward = self._calculate_prediction_reward(prediction)
                    if reward is not None:
                        results.append((prediction, reward))
                        self._update_accuracy_tracking(symbol, timeframe, prediction)

        return results
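

if __name__ == "__main__":
    # Minimal usage sketch with synthetic prices (illustrative only): record one
    # prediction, feed the calculator prices, wait out the 5-second evaluation
    # timeout configured for 1s predictions above, then evaluate and inspect the
    # accuracy summary.
    logging.basicConfig(level=logging.INFO)

    calc = EnhancedRewardCalculator(symbols=['ETH/USDT'])
    calc.update_price('ETH/USDT', 3000.0)

    calc.add_prediction(
        symbol='ETH/USDT',
        timeframe=TimeFrame.SECONDS_1,
        predicted_price=3001.0,
        predicted_direction=1,
        confidence=0.6,
        current_price=3000.0,
        model_name='demo_model',
    )

    time.sleep(6)  # Exceed the 5-second evaluation timeout for 1s predictions
    calc.update_price('ETH/USDT', 3000.5)  # This price serves as the observed outcome

    for record, reward in calc.evaluate_predictions('ETH/USDT').get('ETH/USDT', []):
        print(f"{record.model_name} {record.timeframe.value}: reward={reward:.4f}")

    print(calc.get_accuracy_summary('ETH/USDT'))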