enhanced training and reward - wip
481
core/enhanced_reward_calculator.py
Normal file
@@ -0,0 +1,481 @@
"""
Enhanced Reward Calculator for Reinforcement Learning Training

This module implements a comprehensive reward calculation system based on mean squared error
between predictions and empirical outcomes. It tracks multiple timeframes separately and
maintains prediction history for accurate reward computation.

Key Features:
- MSE-based reward calculation for prediction accuracy
- Multi-timeframe prediction tracking (1s, 1m, 1h, 1d)
- Separate accuracy tracking for each timeframe
- Prediction history tracking (last 6 predictions per timeframe)
- Real-time training at each inference
- Timeframe-aware inference scheduling
"""

import logging
import threading
import time
from collections import deque
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

logger = logging.getLogger(__name__)


class TimeFrame(Enum):
    """Supported timeframes for prediction"""
    SECONDS_1 = "1s"
    MINUTES_1 = "1m"
    HOURS_1 = "1h"
    DAYS_1 = "1d"


@dataclass
class PredictionRecord:
    """Individual prediction record with outcome tracking"""
    timestamp: datetime
    symbol: str
    timeframe: TimeFrame
    predicted_price: float
    predicted_direction: int  # -1: down, 0: neutral, 1: up
    confidence: float
    current_price: float
    model_name: str

    # Outcome fields (set when outcome is determined)
    actual_price: Optional[float] = None
    actual_direction: Optional[int] = None
    outcome_timestamp: Optional[datetime] = None
    mse_reward: Optional[float] = None
    direction_correct: Optional[bool] = None
    is_evaluated: bool = False


@dataclass
class TimeframeAccuracy:
    """Accuracy tracking for a specific timeframe"""
    timeframe: TimeFrame
    total_predictions: int = 0
    correct_directions: int = 0
    total_mse: float = 0.0
    prediction_history: deque = field(default_factory=lambda: deque(maxlen=6))

    @property
    def direction_accuracy(self) -> float:
        """Calculate directional accuracy percentage"""
        if self.total_predictions == 0:
            return 0.0
        return (self.correct_directions / self.total_predictions) * 100.0

    @property
    def average_mse(self) -> float:
        """Calculate average MSE"""
        if self.total_predictions == 0:
            return 0.0
        return self.total_mse / self.total_predictions


class EnhancedRewardCalculator:
    """
    Enhanced reward calculator using MSE and multi-timeframe tracking

    This calculator:
    1. Tracks predictions for multiple timeframes separately
    2. Calculates MSE-based rewards when outcomes are available
    3. Maintains prediction history for the last 6 predictions per timeframe
    4. Provides separate accuracy metrics for each timeframe
    5. Enables real-time training at each inference
    """

    def __init__(self, symbols: Optional[List[str]] = None):
        """Initialize the enhanced reward calculator"""
        self.symbols = symbols or ['ETH/USDT', 'BTC/USDT']
        self.timeframes = [TimeFrame.SECONDS_1, TimeFrame.MINUTES_1, TimeFrame.HOURS_1, TimeFrame.DAYS_1]

        # Prediction storage: symbol -> timeframe -> deque of PredictionRecord
        self.predictions: Dict[str, Dict[TimeFrame, deque]] = {}

        # Accuracy tracking: symbol -> timeframe -> TimeframeAccuracy
        self.accuracy_tracker: Dict[str, Dict[TimeFrame, TimeframeAccuracy]] = {}

        # Evaluation timeouts for each timeframe (in seconds)
        self.evaluation_timeouts = {
            TimeFrame.SECONDS_1: 5,    # Evaluate 1s predictions after 5 seconds
            TimeFrame.MINUTES_1: 60,   # Evaluate 1m predictions after 1 minute
            TimeFrame.HOURS_1: 300,    # Evaluate 1h predictions after 5 minutes
            TimeFrame.DAYS_1: 900      # Evaluate 1d predictions after 15 minutes
        }

        # Price data cache for outcome evaluation
        self.price_cache: Dict[str, List[Tuple[datetime, float]]] = {}
        self.price_cache_max_size = 1000

        # Thread safety
        self.lock = threading.RLock()

        # Initialize data structures
        self._initialize_data_structures()

        logger.info(f"EnhancedRewardCalculator initialized for symbols: {self.symbols}")
        logger.info(f"Timeframes: {[tf.value for tf in self.timeframes]}")
        logger.info(f"Evaluation timeouts: {[(tf.value, timeout) for tf, timeout in self.evaluation_timeouts.items()]}")

    def _initialize_data_structures(self):
        """Initialize nested data structures for any symbols that lack them"""
        for symbol in self.symbols:
            # Skip already-initialized symbols so repeated calls never
            # clobber existing predictions or accuracy history
            if symbol in self.predictions:
                continue

            self.predictions[symbol] = {}
            self.accuracy_tracker[symbol] = {}
            self.price_cache[symbol] = []

            for timeframe in self.timeframes:
                self.predictions[symbol][timeframe] = deque(maxlen=100)  # Keep last 100 predictions
                self.accuracy_tracker[symbol][timeframe] = TimeframeAccuracy(timeframe)

    def add_prediction(self,
                       symbol: str,
                       timeframe: TimeFrame,
                       predicted_price: float,
                       predicted_direction: int,
                       confidence: float,
                       current_price: float,
                       model_name: str) -> str:
        """
        Add a new prediction to track

        Args:
            symbol: Trading symbol (e.g., 'ETH/USDT')
            timeframe: Timeframe for this prediction
            predicted_price: Model's predicted price
            predicted_direction: Predicted direction (-1, 0, 1)
            confidence: Model's confidence (0.0 to 1.0)
            current_price: Current market price
            model_name: Name of the model making the prediction

        Returns:
            Unique prediction ID for later reference
        """
        with self.lock:
            prediction = PredictionRecord(
                timestamp=datetime.now(),
                symbol=symbol,
                timeframe=timeframe,
                predicted_price=predicted_price,
                predicted_direction=predicted_direction,
                confidence=confidence,
                current_price=current_price,
                model_name=model_name
            )

            # Store prediction, registering unseen symbols on the fly
            if symbol not in self.predictions:
                self.symbols.append(symbol)
                self._initialize_data_structures()

            self.predictions[symbol][timeframe].append(prediction)

            # Add to accuracy tracker history
            self.accuracy_tracker[symbol][timeframe].prediction_history.append(prediction)

            prediction_id = f"{symbol}_{timeframe.value}_{prediction.timestamp.isoformat()}_{model_name}"

            logger.debug(f"Added prediction: {prediction_id}, predicted_price={predicted_price:.4f}, "
                         f"direction={predicted_direction}, confidence={confidence:.3f}")

            return prediction_id

    def update_price(self, symbol: str, price: float, timestamp: Optional[datetime] = None):
        """
        Update current price for a symbol

        Args:
            symbol: Trading symbol
            price: Current price
            timestamp: Price timestamp (defaults to now)
        """
        if timestamp is None:
            timestamp = datetime.now()

        with self.lock:
            if symbol not in self.price_cache:
                self.price_cache[symbol] = []

            self.price_cache[symbol].append((timestamp, price))

            # Maintain cache size
            if len(self.price_cache[symbol]) > self.price_cache_max_size:
                self.price_cache[symbol] = self.price_cache[symbol][-self.price_cache_max_size:]

    def evaluate_predictions(self, symbol: Optional[str] = None) -> Dict[str, List[Tuple[PredictionRecord, float]]]:
        """
        Evaluate pending predictions and calculate rewards

        Args:
            symbol: Specific symbol to evaluate (None for all symbols)

        Returns:
            Dictionary mapping symbol to list of (prediction, reward) tuples
        """
        results = {}
        symbols_to_evaluate = [symbol] if symbol else self.symbols

        with self.lock:
            for sym in symbols_to_evaluate:
                if sym not in self.predictions:
                    continue

                results[sym] = []
                current_time = datetime.now()

                for timeframe in self.timeframes:
                    predictions_to_evaluate = []

                    # Find predictions ready for evaluation
                    for prediction in self.predictions[sym][timeframe]:
                        if prediction.is_evaluated:
                            continue

                        time_elapsed = (current_time - prediction.timestamp).total_seconds()
                        timeout = self.evaluation_timeouts[timeframe]

                        if time_elapsed >= timeout:
                            predictions_to_evaluate.append(prediction)

                    # Evaluate predictions
                    for prediction in predictions_to_evaluate:
                        reward = self._calculate_prediction_reward(prediction)
                        if reward is not None:
                            results[sym].append((prediction, reward))

                            # Update accuracy tracking
                            self._update_accuracy_tracking(sym, timeframe, prediction)

        return results

    def _calculate_prediction_reward(self, prediction: PredictionRecord) -> Optional[float]:
        """
        Calculate MSE-based reward for a prediction

        Args:
            prediction: Prediction record to evaluate

        Returns:
            Calculated reward or None if outcome cannot be determined
        """
        # Get actual price at evaluation time
        actual_price = self._get_price_at_time(
            prediction.symbol,
            prediction.timestamp + timedelta(seconds=self.evaluation_timeouts[prediction.timeframe])
        )

        if actual_price is None:
            logger.debug(f"Cannot evaluate prediction - no price data available for {prediction.symbol}")
            return None

        # Calculate price change and direction
        price_change = actual_price - prediction.current_price
        actual_direction = 1 if price_change > 0 else (-1 if price_change < 0 else 0)

        # Calculate MSE reward
        price_error = actual_price - prediction.predicted_price
        mse = price_error ** 2

        # Normalize MSE to a reasonable reward scale (lower MSE = higher reward)
        # Use exponential decay to heavily penalize large errors
        max_mse = (prediction.current_price * 0.1) ** 2  # 10% price change as max expected error
        normalized_mse = min(mse / max_mse, 1.0)
        mse_reward = np.exp(-5 * normalized_mse)  # Exponential decay, range [exp(-5), 1]

        # Direction accuracy bonus/penalty
        direction_correct = (prediction.predicted_direction == actual_direction)
        direction_bonus = 0.5 if direction_correct else -0.5

        # Confidence scaling
        confidence_weight = prediction.confidence

        # Final reward calculation
        base_reward = mse_reward + direction_bonus
        final_reward = base_reward * confidence_weight
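
        # Worked example with hypothetical numbers: current_price=2000.0,
        # predicted_price=2010.0, actual_price=2005.0, direction correct, confidence=0.8:
        #   price_error = -5.0, mse = 25.0, max_mse = 200.0**2 = 40000.0
        #   normalized_mse = 0.000625, mse_reward = exp(-0.003125) ~= 0.9969
        #   final_reward = (0.9969 + 0.5) * 0.8 ~= 1.198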

        # Update prediction record
        prediction.actual_price = actual_price
        prediction.actual_direction = actual_direction
        prediction.outcome_timestamp = datetime.now()
        prediction.mse_reward = final_reward
        prediction.direction_correct = direction_correct
        prediction.is_evaluated = True

        logger.debug(f"Evaluated prediction: {prediction.symbol} {prediction.timeframe.value}, "
                     f"MSE={mse:.6f}, direction_correct={direction_correct}, "
                     f"confidence={confidence_weight:.3f}, reward={final_reward:.4f}")

        return final_reward

    def _get_price_at_time(self, symbol: str, target_time: datetime) -> Optional[float]:
        """
        Get price for symbol at a specific time

        Args:
            symbol: Trading symbol
            target_time: Target timestamp

        Returns:
            Price at target time or None if not available
        """
        if symbol not in self.price_cache or not self.price_cache[symbol]:
            return None

        # Find closest price to target time
        closest_price = None
        min_time_diff = float('inf')

        for timestamp, price in self.price_cache[symbol]:
            time_diff = abs((timestamp - target_time).total_seconds())
            if time_diff < min_time_diff:
                min_time_diff = time_diff
                closest_price = price

        # Only return price if it's within a reasonable time window (30 seconds)
        if min_time_diff <= 30:
            return closest_price

        return None

    def _update_accuracy_tracking(self, symbol: str, timeframe: TimeFrame, prediction: PredictionRecord):
        """Update accuracy tracking for a timeframe"""
        tracker = self.accuracy_tracker[symbol][timeframe]

        tracker.total_predictions += 1
        if prediction.direction_correct:
            tracker.correct_directions += 1

        if prediction.mse_reward is not None:
            # Approximate the MSE from the stored reward by inverting
            # reward = exp(-5 * normalized_mse). Note this is only an
            # approximation: the stored final reward also folds in the
            # direction bonus and confidence scaling.
            normalized_mse = -np.log(max(prediction.mse_reward, 0.001)) / 5
            max_mse = (prediction.current_price * 0.1) ** 2
            mse = normalized_mse * max_mse
            tracker.total_mse += mse

    def get_accuracy_summary(self, symbol: Optional[str] = None) -> Dict[str, Dict[str, Dict[str, float]]]:
        """
        Get accuracy summary for all symbols or a specific one

        Args:
            symbol: Specific symbol (None for all)

        Returns:
            Nested dictionary with accuracy metrics
        """
        summary = {}
        symbols_to_summarize = [symbol] if symbol else self.symbols

        with self.lock:
            for sym in symbols_to_summarize:
                if sym not in self.accuracy_tracker:
                    continue

                summary[sym] = {}

                for timeframe in self.timeframes:
                    tracker = self.accuracy_tracker[sym][timeframe]
                    summary[sym][timeframe.value] = {
                        'total_predictions': tracker.total_predictions,
                        'direction_accuracy': tracker.direction_accuracy,
                        'average_mse': tracker.average_mse,
                        'recent_predictions': len(tracker.prediction_history)
                    }

        return summary
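
    # Example shape of the returned summary (hypothetical values):
    # {'ETH/USDT': {'1s': {'total_predictions': 12, 'direction_accuracy': 58.3,
    #                      'average_mse': 1.4e-05, 'recent_predictions': 6},
    #               '1m': {...}, '1h': {...}, '1d': {...}}}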

    def get_training_data(self, symbol: str, timeframe: TimeFrame,
                          max_samples: int = 50) -> List[Tuple[PredictionRecord, float]]:
        """
        Get recent evaluated predictions for training

        Args:
            symbol: Trading symbol
            timeframe: Specific timeframe
            max_samples: Maximum number of samples to return

        Returns:
            List of (prediction, reward) tuples ready for training
        """
        training_data = []

        with self.lock:
            if symbol not in self.predictions or timeframe not in self.predictions[symbol]:
                return training_data

            evaluated_predictions = [
                p for p in self.predictions[symbol][timeframe]
                if p.is_evaluated and p.mse_reward is not None
            ]

            # Get the most recent evaluated predictions
            recent_predictions = evaluated_predictions[-max_samples:]

            for prediction in recent_predictions:
                training_data.append((prediction, prediction.mse_reward))

        return training_data
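
    # Typical consumption pattern (sketch; `trainer` is a hypothetical object,
    # not part of this module):
    #   for record, reward in calculator.get_training_data('ETH/USDT', TimeFrame.MINUTES_1):
    #       trainer.update(record, reward)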

    def cleanup_old_predictions(self, days_to_keep: int = 7):
        """
        Clean up old predictions to manage memory

        Args:
            days_to_keep: Number of days of predictions to keep
        """
        cutoff_time = datetime.now() - timedelta(days=days_to_keep)

        with self.lock:
            for symbol in self.predictions:
                for timeframe in self.timeframes:
                    # Filter out old predictions
                    old_count = len(self.predictions[symbol][timeframe])

                    self.predictions[symbol][timeframe] = deque(
                        [p for p in self.predictions[symbol][timeframe]
                         if p.timestamp > cutoff_time],
                        maxlen=100
                    )

                    new_count = len(self.predictions[symbol][timeframe])
                    removed_count = old_count - new_count

                    if removed_count > 0:
                        logger.info(f"Cleaned up {removed_count} old predictions for "
                                    f"{symbol} {timeframe.value}")

    def force_evaluate_timeframe_predictions(self, symbol: str, timeframe: TimeFrame) -> List[Tuple[PredictionRecord, float]]:
        """
        Force evaluation of all pending predictions for a specific timeframe.
        Useful for immediate training needs.

        Note: outcomes are read from the cached price closest to
        prediction.timestamp + evaluation_timeout, so forcing evaluation long
        before the timeout may find no price within the 30-second matching
        window and skip those predictions.

        Args:
            symbol: Trading symbol
            timeframe: Specific timeframe to evaluate

        Returns:
            List of (prediction, reward) tuples
        """
        results = []

        with self.lock:
            if symbol not in self.predictions or timeframe not in self.predictions[symbol]:
                return results

            # Evaluate all non-evaluated predictions
            for prediction in self.predictions[symbol][timeframe]:
                if not prediction.is_evaluated:
                    reward = self._calculate_prediction_reward(prediction)
                    if reward is not None:
                        results.append((prediction, reward))
                        self._update_accuracy_tracking(symbol, timeframe, prediction)

        return results
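

# Minimal usage sketch (illustrative only; prices, symbol, and model name are
# hypothetical). It exercises the API defined above: add a 1s-horizon prediction,
# feed a price update, then force evaluation instead of waiting for the normal
# 5-second timeout.
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)

    calculator = EnhancedRewardCalculator(symbols=['ETH/USDT'])

    calculator.add_prediction(
        symbol='ETH/USDT',
        timeframe=TimeFrame.SECONDS_1,
        predicted_price=2001.5,
        predicted_direction=1,
        confidence=0.8,
        current_price=2000.0,
        model_name='demo_model'
    )

    # A price update close enough in time to the 1s evaluation target (within
    # the 30-second matching window) lets the forced evaluation find an outcome
    calculator.update_price('ETH/USDT', 2002.1)

    for prediction, reward in calculator.force_evaluate_timeframe_predictions('ETH/USDT', TimeFrame.SECONDS_1):
        print(f"{prediction.symbol} {prediction.timeframe.value} reward={reward:.4f}")

    print(calculator.get_accuracy_summary('ETH/USDT'))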