ehanced training and reward - wip
This commit is contained in:
481
core/enhanced_reward_calculator.py
Normal file
481
core/enhanced_reward_calculator.py
Normal file
@@ -0,0 +1,481 @@
|
||||
"""
|
||||
Enhanced Reward Calculator for Reinforcement Learning Training
|
||||
|
||||
This module implements a comprehensive reward calculation system based on mean squared error
|
||||
between predictions and empirical outcomes. It tracks multiple timeframes separately and
|
||||
maintains prediction history for accurate reward computation.
|
||||
|
||||
Key Features:
|
||||
- MSE-based reward calculation for prediction accuracy
|
||||
- Multi-timeframe prediction tracking (1s, 1m, 1h, 1d)
|
||||
- Separate accuracy tracking for each timeframe
|
||||
- Prediction history tracking (last 6 predictions per timeframe)
|
||||
- Real-time training at each inference
|
||||
- Timeframe-aware inference scheduling
|
||||
"""
|
||||
|
||||
import time
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from collections import deque
|
||||
from datetime import datetime, timedelta
|
||||
import numpy as np
|
||||
import threading
|
||||
from enum import Enum
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TimeFrame(Enum):
|
||||
"""Supported timeframes for prediction"""
|
||||
SECONDS_1 = "1s"
|
||||
MINUTES_1 = "1m"
|
||||
HOURS_1 = "1h"
|
||||
DAYS_1 = "1d"
|
||||
|
||||
|
||||
@dataclass
|
||||
class PredictionRecord:
|
||||
"""Individual prediction record with outcome tracking"""
|
||||
timestamp: datetime
|
||||
symbol: str
|
||||
timeframe: TimeFrame
|
||||
predicted_price: float
|
||||
predicted_direction: int # -1: down, 0: neutral, 1: up
|
||||
confidence: float
|
||||
current_price: float
|
||||
model_name: str
|
||||
|
||||
# Outcome fields (set when outcome is determined)
|
||||
actual_price: Optional[float] = None
|
||||
actual_direction: Optional[int] = None
|
||||
outcome_timestamp: Optional[datetime] = None
|
||||
mse_reward: Optional[float] = None
|
||||
direction_correct: Optional[bool] = None
|
||||
is_evaluated: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimeframeAccuracy:
|
||||
"""Accuracy tracking for a specific timeframe"""
|
||||
timeframe: TimeFrame
|
||||
total_predictions: int = 0
|
||||
correct_directions: int = 0
|
||||
total_mse: float = 0.0
|
||||
prediction_history: deque = field(default_factory=lambda: deque(maxlen=6))
|
||||
|
||||
@property
|
||||
def direction_accuracy(self) -> float:
|
||||
"""Calculate directional accuracy percentage"""
|
||||
if self.total_predictions == 0:
|
||||
return 0.0
|
||||
return (self.correct_directions / self.total_predictions) * 100.0
|
||||
|
||||
@property
|
||||
def average_mse(self) -> float:
|
||||
"""Calculate average MSE"""
|
||||
if self.total_predictions == 0:
|
||||
return 0.0
|
||||
return self.total_mse / self.total_predictions
|
||||
|
||||
|
||||
class EnhancedRewardCalculator:
|
||||
"""
|
||||
Enhanced reward calculator using MSE and multi-timeframe tracking
|
||||
|
||||
This calculator:
|
||||
1. Tracks predictions for multiple timeframes separately
|
||||
2. Calculates MSE-based rewards when outcomes are available
|
||||
3. Maintains prediction history for last 6 predictions per timeframe
|
||||
4. Provides separate accuracy metrics for each timeframe
|
||||
5. Enables real-time training at each inference
|
||||
"""
|
||||
|
||||
def __init__(self, symbols: List[str] = None):
|
||||
"""Initialize the enhanced reward calculator"""
|
||||
self.symbols = symbols or ['ETH/USDT', 'BTC/USDT']
|
||||
self.timeframes = [TimeFrame.SECONDS_1, TimeFrame.MINUTES_1, TimeFrame.HOURS_1, TimeFrame.DAYS_1]
|
||||
|
||||
# Prediction storage: symbol -> timeframe -> deque of PredictionRecord
|
||||
self.predictions: Dict[str, Dict[TimeFrame, deque]] = {}
|
||||
|
||||
# Accuracy tracking: symbol -> timeframe -> TimeframeAccuracy
|
||||
self.accuracy_tracker: Dict[str, Dict[TimeFrame, TimeframeAccuracy]] = {}
|
||||
|
||||
# Evaluation timeouts for each timeframe (in seconds)
|
||||
self.evaluation_timeouts = {
|
||||
TimeFrame.SECONDS_1: 5, # Evaluate 1s predictions after 5 seconds
|
||||
TimeFrame.MINUTES_1: 60, # Evaluate 1m predictions after 1 minute
|
||||
TimeFrame.HOURS_1: 300, # Evaluate 1h predictions after 5 minutes
|
||||
TimeFrame.DAYS_1: 900 # Evaluate 1d predictions after 15 minutes
|
||||
}
|
||||
|
||||
# Price data cache for outcome evaluation
|
||||
self.price_cache: Dict[str, List[Tuple[datetime, float]]] = {}
|
||||
self.price_cache_max_size = 1000
|
||||
|
||||
# Thread safety
|
||||
self.lock = threading.RLock()
|
||||
|
||||
# Initialize data structures
|
||||
self._initialize_data_structures()
|
||||
|
||||
logger.info(f"EnhancedRewardCalculator initialized for symbols: {self.symbols}")
|
||||
logger.info(f"Timeframes: {[tf.value for tf in self.timeframes]}")
|
||||
logger.info(f"Evaluation timeouts: {[(tf.value, timeout) for tf, timeout in self.evaluation_timeouts.items()]}")
|
||||
|
||||
def _initialize_data_structures(self):
|
||||
"""Initialize nested data structures"""
|
||||
for symbol in self.symbols:
|
||||
self.predictions[symbol] = {}
|
||||
self.accuracy_tracker[symbol] = {}
|
||||
self.price_cache[symbol] = []
|
||||
|
||||
for timeframe in self.timeframes:
|
||||
self.predictions[symbol][timeframe] = deque(maxlen=100) # Keep last 100 predictions
|
||||
self.accuracy_tracker[symbol][timeframe] = TimeframeAccuracy(timeframe)
|
||||
|
||||
def add_prediction(self,
|
||||
symbol: str,
|
||||
timeframe: TimeFrame,
|
||||
predicted_price: float,
|
||||
predicted_direction: int,
|
||||
confidence: float,
|
||||
current_price: float,
|
||||
model_name: str) -> str:
|
||||
"""
|
||||
Add a new prediction to track
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol (e.g., 'ETH/USDT')
|
||||
timeframe: Timeframe for this prediction
|
||||
predicted_price: Model's predicted price
|
||||
predicted_direction: Predicted direction (-1, 0, 1)
|
||||
confidence: Model's confidence (0.0 to 1.0)
|
||||
current_price: Current market price
|
||||
model_name: Name of the model making prediction
|
||||
|
||||
Returns:
|
||||
Unique prediction ID for later reference
|
||||
"""
|
||||
with self.lock:
|
||||
prediction = PredictionRecord(
|
||||
timestamp=datetime.now(),
|
||||
symbol=symbol,
|
||||
timeframe=timeframe,
|
||||
predicted_price=predicted_price,
|
||||
predicted_direction=predicted_direction,
|
||||
confidence=confidence,
|
||||
current_price=current_price,
|
||||
model_name=model_name
|
||||
)
|
||||
|
||||
# Store prediction
|
||||
if symbol not in self.predictions:
|
||||
self._initialize_data_structures()
|
||||
|
||||
self.predictions[symbol][timeframe].append(prediction)
|
||||
|
||||
# Add to accuracy tracker history
|
||||
self.accuracy_tracker[symbol][timeframe].prediction_history.append(prediction)
|
||||
|
||||
prediction_id = f"{symbol}_{timeframe.value}_{prediction.timestamp.isoformat()}_{model_name}"
|
||||
|
||||
logger.debug(f"Added prediction: {prediction_id}, predicted_price={predicted_price:.4f}, "
|
||||
f"direction={predicted_direction}, confidence={confidence:.3f}")
|
||||
|
||||
return prediction_id
|
||||
|
||||
def update_price(self, symbol: str, price: float, timestamp: datetime = None):
|
||||
"""
|
||||
Update current price for a symbol
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
price: Current price
|
||||
timestamp: Price timestamp (defaults to now)
|
||||
"""
|
||||
if timestamp is None:
|
||||
timestamp = datetime.now()
|
||||
|
||||
with self.lock:
|
||||
if symbol not in self.price_cache:
|
||||
self.price_cache[symbol] = []
|
||||
|
||||
self.price_cache[symbol].append((timestamp, price))
|
||||
|
||||
# Maintain cache size
|
||||
if len(self.price_cache[symbol]) > self.price_cache_max_size:
|
||||
self.price_cache[symbol] = self.price_cache[symbol][-self.price_cache_max_size:]
|
||||
|
||||
def evaluate_predictions(self, symbol: str = None) -> Dict[str, List[Tuple[PredictionRecord, float]]]:
|
||||
"""
|
||||
Evaluate pending predictions and calculate rewards
|
||||
|
||||
Args:
|
||||
symbol: Specific symbol to evaluate (None for all symbols)
|
||||
|
||||
Returns:
|
||||
Dictionary mapping symbol to list of (prediction, reward) tuples
|
||||
"""
|
||||
results = {}
|
||||
symbols_to_evaluate = [symbol] if symbol else self.symbols
|
||||
|
||||
with self.lock:
|
||||
for sym in symbols_to_evaluate:
|
||||
if sym not in self.predictions:
|
||||
continue
|
||||
|
||||
results[sym] = []
|
||||
current_time = datetime.now()
|
||||
|
||||
for timeframe in self.timeframes:
|
||||
predictions_to_evaluate = []
|
||||
|
||||
# Find predictions ready for evaluation
|
||||
for prediction in self.predictions[sym][timeframe]:
|
||||
if prediction.is_evaluated:
|
||||
continue
|
||||
|
||||
time_elapsed = (current_time - prediction.timestamp).total_seconds()
|
||||
timeout = self.evaluation_timeouts[timeframe]
|
||||
|
||||
if time_elapsed >= timeout:
|
||||
predictions_to_evaluate.append(prediction)
|
||||
|
||||
# Evaluate predictions
|
||||
for prediction in predictions_to_evaluate:
|
||||
reward = self._calculate_prediction_reward(prediction)
|
||||
if reward is not None:
|
||||
results[sym].append((prediction, reward))
|
||||
|
||||
# Update accuracy tracking
|
||||
self._update_accuracy_tracking(sym, timeframe, prediction)
|
||||
|
||||
return results
|
||||
|
||||
def _calculate_prediction_reward(self, prediction: PredictionRecord) -> Optional[float]:
|
||||
"""
|
||||
Calculate MSE-based reward for a prediction
|
||||
|
||||
Args:
|
||||
prediction: Prediction record to evaluate
|
||||
|
||||
Returns:
|
||||
Calculated reward or None if outcome cannot be determined
|
||||
"""
|
||||
# Get actual price at evaluation time
|
||||
actual_price = self._get_price_at_time(
|
||||
prediction.symbol,
|
||||
prediction.timestamp + timedelta(seconds=self.evaluation_timeouts[prediction.timeframe])
|
||||
)
|
||||
|
||||
if actual_price is None:
|
||||
logger.debug(f"Cannot evaluate prediction - no price data available for {prediction.symbol}")
|
||||
return None
|
||||
|
||||
# Calculate price change and direction
|
||||
price_change = actual_price - prediction.current_price
|
||||
actual_direction = 1 if price_change > 0 else (-1 if price_change < 0 else 0)
|
||||
|
||||
# Calculate MSE reward
|
||||
price_error = actual_price - prediction.predicted_price
|
||||
mse = price_error ** 2
|
||||
|
||||
# Normalize MSE to a reasonable reward scale (lower MSE = higher reward)
|
||||
# Use exponential decay to heavily penalize large errors
|
||||
max_mse = (prediction.current_price * 0.1) ** 2 # 10% price change as max expected error
|
||||
normalized_mse = min(mse / max_mse, 1.0)
|
||||
mse_reward = np.exp(-5 * normalized_mse) # Exponential decay, range [exp(-5), 1]
|
||||
|
||||
# Direction accuracy bonus/penalty
|
||||
direction_correct = (prediction.predicted_direction == actual_direction)
|
||||
direction_bonus = 0.5 if direction_correct else -0.5
|
||||
|
||||
# Confidence scaling
|
||||
confidence_weight = prediction.confidence
|
||||
|
||||
# Final reward calculation
|
||||
base_reward = mse_reward + direction_bonus
|
||||
final_reward = base_reward * confidence_weight
|
||||
|
||||
# Update prediction record
|
||||
prediction.actual_price = actual_price
|
||||
prediction.actual_direction = actual_direction
|
||||
prediction.outcome_timestamp = datetime.now()
|
||||
prediction.mse_reward = final_reward
|
||||
prediction.direction_correct = direction_correct
|
||||
prediction.is_evaluated = True
|
||||
|
||||
logger.debug(f"Evaluated prediction: {prediction.symbol} {prediction.timeframe.value}, "
|
||||
f"MSE={mse:.6f}, direction_correct={direction_correct}, "
|
||||
f"confidence={confidence_weight:.3f}, reward={final_reward:.4f}")
|
||||
|
||||
return final_reward
|
||||
|
||||
def _get_price_at_time(self, symbol: str, target_time: datetime) -> Optional[float]:
|
||||
"""
|
||||
Get price for symbol at a specific time
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
target_time: Target timestamp
|
||||
|
||||
Returns:
|
||||
Price at target time or None if not available
|
||||
"""
|
||||
if symbol not in self.price_cache or not self.price_cache[symbol]:
|
||||
return None
|
||||
|
||||
# Find closest price to target time
|
||||
closest_price = None
|
||||
min_time_diff = float('inf')
|
||||
|
||||
for timestamp, price in self.price_cache[symbol]:
|
||||
time_diff = abs((timestamp - target_time).total_seconds())
|
||||
if time_diff < min_time_diff:
|
||||
min_time_diff = time_diff
|
||||
closest_price = price
|
||||
|
||||
# Only return price if it's within reasonable time window (30 seconds)
|
||||
if min_time_diff <= 30:
|
||||
return closest_price
|
||||
|
||||
return None
|
||||
|
||||
def _update_accuracy_tracking(self, symbol: str, timeframe: TimeFrame, prediction: PredictionRecord):
|
||||
"""Update accuracy tracking for a timeframe"""
|
||||
tracker = self.accuracy_tracker[symbol][timeframe]
|
||||
|
||||
tracker.total_predictions += 1
|
||||
if prediction.direction_correct:
|
||||
tracker.correct_directions += 1
|
||||
|
||||
if prediction.mse_reward is not None:
|
||||
# Convert reward back to MSE for tracking
|
||||
# Since reward = exp(-5 * normalized_mse), we can reverse it
|
||||
normalized_mse = -np.log(max(prediction.mse_reward, 0.001)) / 5
|
||||
max_mse = (prediction.current_price * 0.1) ** 2
|
||||
mse = normalized_mse * max_mse
|
||||
tracker.total_mse += mse
|
||||
|
||||
def get_accuracy_summary(self, symbol: str = None) -> Dict[str, Dict[str, Dict[str, float]]]:
|
||||
"""
|
||||
Get accuracy summary for all or specific symbol
|
||||
|
||||
Args:
|
||||
symbol: Specific symbol (None for all)
|
||||
|
||||
Returns:
|
||||
Nested dictionary with accuracy metrics
|
||||
"""
|
||||
summary = {}
|
||||
symbols_to_summarize = [symbol] if symbol else self.symbols
|
||||
|
||||
with self.lock:
|
||||
for sym in symbols_to_summarize:
|
||||
if sym not in self.accuracy_tracker:
|
||||
continue
|
||||
|
||||
summary[sym] = {}
|
||||
|
||||
for timeframe in self.timeframes:
|
||||
tracker = self.accuracy_tracker[sym][timeframe]
|
||||
summary[sym][timeframe.value] = {
|
||||
'total_predictions': tracker.total_predictions,
|
||||
'direction_accuracy': tracker.direction_accuracy,
|
||||
'average_mse': tracker.average_mse,
|
||||
'recent_predictions': len(tracker.prediction_history)
|
||||
}
|
||||
|
||||
return summary
|
||||
|
||||
def get_training_data(self, symbol: str, timeframe: TimeFrame,
|
||||
max_samples: int = 50) -> List[Tuple[PredictionRecord, float]]:
|
||||
"""
|
||||
Get recent evaluated predictions for training
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
timeframe: Specific timeframe
|
||||
max_samples: Maximum number of samples to return
|
||||
|
||||
Returns:
|
||||
List of (prediction, reward) tuples ready for training
|
||||
"""
|
||||
training_data = []
|
||||
|
||||
with self.lock:
|
||||
if symbol not in self.predictions or timeframe not in self.predictions[symbol]:
|
||||
return training_data
|
||||
|
||||
evaluated_predictions = [
|
||||
p for p in self.predictions[symbol][timeframe]
|
||||
if p.is_evaluated and p.mse_reward is not None
|
||||
]
|
||||
|
||||
# Get most recent evaluated predictions
|
||||
recent_predictions = list(evaluated_predictions)[-max_samples:]
|
||||
|
||||
for prediction in recent_predictions:
|
||||
training_data.append((prediction, prediction.mse_reward))
|
||||
|
||||
return training_data
|
||||
|
||||
def cleanup_old_predictions(self, days_to_keep: int = 7):
|
||||
"""
|
||||
Clean up old predictions to manage memory
|
||||
|
||||
Args:
|
||||
days_to_keep: Number of days of predictions to keep
|
||||
"""
|
||||
cutoff_time = datetime.now() - timedelta(days=days_to_keep)
|
||||
|
||||
with self.lock:
|
||||
for symbol in self.predictions:
|
||||
for timeframe in self.timeframes:
|
||||
# Filter out old predictions
|
||||
old_count = len(self.predictions[symbol][timeframe])
|
||||
|
||||
self.predictions[symbol][timeframe] = deque(
|
||||
[p for p in self.predictions[symbol][timeframe]
|
||||
if p.timestamp > cutoff_time],
|
||||
maxlen=100
|
||||
)
|
||||
|
||||
new_count = len(self.predictions[symbol][timeframe])
|
||||
removed_count = old_count - new_count
|
||||
|
||||
if removed_count > 0:
|
||||
logger.info(f"Cleaned up {removed_count} old predictions for "
|
||||
f"{symbol} {timeframe.value}")
|
||||
|
||||
def force_evaluate_timeframe_predictions(self, symbol: str, timeframe: TimeFrame) -> List[Tuple[PredictionRecord, float]]:
|
||||
"""
|
||||
Force evaluation of all pending predictions for a specific timeframe
|
||||
Useful for immediate training needs
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
timeframe: Specific timeframe to evaluate
|
||||
|
||||
Returns:
|
||||
List of (prediction, reward) tuples
|
||||
"""
|
||||
results = []
|
||||
|
||||
with self.lock:
|
||||
if symbol not in self.predictions or timeframe not in self.predictions[symbol]:
|
||||
return results
|
||||
|
||||
# Evaluate all non-evaluated predictions
|
||||
for prediction in self.predictions[symbol][timeframe]:
|
||||
if not prediction.is_evaluated:
|
||||
reward = self._calculate_prediction_reward(prediction)
|
||||
if reward is not None:
|
||||
results.append((prediction, reward))
|
||||
self._update_accuracy_tracking(symbol, timeframe, prediction)
|
||||
|
||||
return results
|
||||
|
337
core/enhanced_reward_system_integration.py
Normal file
337
core/enhanced_reward_system_integration.py
Normal file
@@ -0,0 +1,337 @@
|
||||
"""
|
||||
Enhanced Reward System Integration
|
||||
|
||||
This module provides a simple integration point for the new MSE-based reward system
|
||||
with the existing trading orchestrator and training infrastructure.
|
||||
|
||||
Key Features:
|
||||
- Easy integration with existing TradingOrchestrator
|
||||
- Minimal changes required to existing code
|
||||
- Backward compatibility maintained
|
||||
- Enhanced performance monitoring
|
||||
- Real-time training with MSE rewards
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from core.enhanced_reward_calculator import EnhancedRewardCalculator, TimeFrame
|
||||
from core.timeframe_inference_coordinator import TimeframeInferenceCoordinator
|
||||
from core.enhanced_rl_training_adapter import EnhancedRLTrainingAdapter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EnhancedRewardSystemIntegration:
|
||||
"""
|
||||
Complete integration of the enhanced reward system
|
||||
|
||||
This class provides a single integration point that can be easily added
|
||||
to the existing TradingOrchestrator to enable MSE-based rewards and
|
||||
multi-timeframe training.
|
||||
"""
|
||||
|
||||
def __init__(self, orchestrator: Any, symbols: list = None):
|
||||
"""
|
||||
Initialize the enhanced reward system integration
|
||||
|
||||
Args:
|
||||
orchestrator: TradingOrchestrator instance
|
||||
symbols: List of symbols to track (defaults to ETH/USDT, BTC/USDT)
|
||||
"""
|
||||
self.orchestrator = orchestrator
|
||||
self.symbols = symbols or ['ETH/USDT', 'BTC/USDT']
|
||||
|
||||
# Initialize core components
|
||||
self.reward_calculator = EnhancedRewardCalculator(symbols=self.symbols)
|
||||
|
||||
self.inference_coordinator = TimeframeInferenceCoordinator(
|
||||
reward_calculator=self.reward_calculator,
|
||||
data_provider=getattr(orchestrator, 'data_provider', None),
|
||||
symbols=self.symbols
|
||||
)
|
||||
|
||||
self.training_adapter = EnhancedRLTrainingAdapter(
|
||||
reward_calculator=self.reward_calculator,
|
||||
inference_coordinator=self.inference_coordinator,
|
||||
orchestrator=orchestrator,
|
||||
training_system=getattr(orchestrator, 'enhanced_training_system', None)
|
||||
)
|
||||
|
||||
# Integration state
|
||||
self.is_running = False
|
||||
self.integration_stats = {
|
||||
'start_time': None,
|
||||
'total_predictions_tracked': 0,
|
||||
'total_rewards_calculated': 0,
|
||||
'total_training_batches': 0
|
||||
}
|
||||
|
||||
logger.info(f"EnhancedRewardSystemIntegration initialized for symbols: {self.symbols}")
|
||||
|
||||
async def start_integration(self):
|
||||
"""Start the enhanced reward system integration"""
|
||||
if self.is_running:
|
||||
logger.warning("Enhanced reward system already running")
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info("Starting Enhanced Reward System Integration")
|
||||
|
||||
# Start core components
|
||||
await self.inference_coordinator.start_coordination()
|
||||
await self.training_adapter.start_training_loop()
|
||||
|
||||
# Start price monitoring
|
||||
asyncio.create_task(self._price_monitoring_loop())
|
||||
|
||||
self.is_running = True
|
||||
self.integration_stats['start_time'] = datetime.now().isoformat()
|
||||
|
||||
logger.info("Enhanced Reward System Integration started successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting enhanced reward system integration: {e}")
|
||||
await self.stop_integration()
|
||||
|
||||
async def stop_integration(self):
|
||||
"""Stop the enhanced reward system integration"""
|
||||
if not self.is_running:
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info("Stopping Enhanced Reward System Integration")
|
||||
|
||||
# Stop components
|
||||
await self.inference_coordinator.stop_coordination()
|
||||
await self.training_adapter.stop_training_loop()
|
||||
|
||||
self.is_running = False
|
||||
|
||||
logger.info("Enhanced Reward System Integration stopped")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping enhanced reward system integration: {e}")
|
||||
|
||||
async def _price_monitoring_loop(self):
|
||||
"""Monitor prices and update the reward calculator"""
|
||||
while self.is_running:
|
||||
try:
|
||||
# Update current prices for all symbols
|
||||
for symbol in self.symbols:
|
||||
current_price = await self._get_current_price(symbol)
|
||||
if current_price > 0:
|
||||
self.reward_calculator.update_price(symbol, current_price)
|
||||
|
||||
# Sleep for 1 second between updates
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error in price monitoring loop: {e}")
|
||||
await asyncio.sleep(5.0) # Wait longer on error
|
||||
|
||||
async def _get_current_price(self, symbol: str) -> float:
|
||||
"""Get current price for a symbol"""
|
||||
try:
|
||||
if hasattr(self.orchestrator, 'data_provider'):
|
||||
current_prices = self.orchestrator.data_provider.current_prices
|
||||
return current_prices.get(symbol, 0.0)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting current price for {symbol}: {e}")
|
||||
|
||||
return 0.0
|
||||
|
||||
def add_prediction_manually(self, symbol: str, timeframe_str: str,
|
||||
predicted_price: float, predicted_direction: int,
|
||||
confidence: float, current_price: float,
|
||||
model_name: str) -> str:
|
||||
"""
|
||||
Manually add a prediction to the reward calculator
|
||||
|
||||
This method allows existing code to easily integrate with the new reward system
|
||||
without major changes.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol (e.g., 'ETH/USDT')
|
||||
timeframe_str: Timeframe string ('1s', '1m', '1h', '1d')
|
||||
predicted_price: Model's predicted price
|
||||
predicted_direction: Predicted direction (-1, 0, 1)
|
||||
confidence: Model's confidence (0.0 to 1.0)
|
||||
current_price: Current market price
|
||||
model_name: Name of the model making prediction
|
||||
|
||||
Returns:
|
||||
Unique prediction ID
|
||||
"""
|
||||
try:
|
||||
# Convert timeframe string to enum
|
||||
timeframe = TimeFrame(timeframe_str)
|
||||
|
||||
prediction_id = self.reward_calculator.add_prediction(
|
||||
symbol=symbol,
|
||||
timeframe=timeframe,
|
||||
predicted_price=predicted_price,
|
||||
predicted_direction=predicted_direction,
|
||||
confidence=confidence,
|
||||
current_price=current_price,
|
||||
model_name=model_name
|
||||
)
|
||||
|
||||
self.integration_stats['total_predictions_tracked'] += 1
|
||||
|
||||
return prediction_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding prediction manually: {e}")
|
||||
return ""
|
||||
|
||||
def get_model_accuracy(self, model_name: str = None, symbol: str = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Get accuracy statistics for models
|
||||
|
||||
Args:
|
||||
model_name: Specific model name (None for all)
|
||||
symbol: Specific symbol (None for all)
|
||||
|
||||
Returns:
|
||||
Dictionary with accuracy statistics
|
||||
"""
|
||||
try:
|
||||
accuracy_summary = self.reward_calculator.get_accuracy_summary(symbol)
|
||||
|
||||
if model_name:
|
||||
# Filter by model name in prediction history
|
||||
# This would require enhancing the reward calculator to track by model
|
||||
pass
|
||||
|
||||
return accuracy_summary
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting model accuracy: {e}")
|
||||
return {}
|
||||
|
||||
def force_evaluation_and_training(self, symbol: str = None, timeframe_str: str = None):
|
||||
"""
|
||||
Force immediate evaluation and training for debugging/testing
|
||||
|
||||
Args:
|
||||
symbol: Specific symbol (None for all)
|
||||
timeframe_str: Specific timeframe (None for all)
|
||||
"""
|
||||
try:
|
||||
if timeframe_str:
|
||||
timeframe = TimeFrame(timeframe_str)
|
||||
symbols_to_process = [symbol] if symbol else self.symbols
|
||||
|
||||
for sym in symbols_to_process:
|
||||
# Force evaluation of predictions
|
||||
results = self.reward_calculator.force_evaluate_timeframe_predictions(sym, timeframe)
|
||||
logger.info(f"Force evaluated {len(results)} predictions for {sym} {timeframe.value}")
|
||||
else:
|
||||
# Evaluate all pending predictions
|
||||
for sym in (self.symbols if not symbol else [symbol]):
|
||||
results = self.reward_calculator.evaluate_predictions(sym)
|
||||
if sym in results:
|
||||
logger.info(f"Force evaluated {len(results[sym])} predictions for {sym}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in force evaluation and training: {e}")
|
||||
|
||||
def get_integration_statistics(self) -> Dict[str, Any]:
|
||||
"""Get comprehensive integration statistics"""
|
||||
try:
|
||||
stats = self.integration_stats.copy()
|
||||
|
||||
# Add component statistics
|
||||
stats['inference_coordinator'] = self.inference_coordinator.get_inference_statistics()
|
||||
stats['training_adapter'] = self.training_adapter.get_training_statistics()
|
||||
stats['reward_calculator'] = self.reward_calculator.get_accuracy_summary()
|
||||
|
||||
# Add system status
|
||||
stats['is_running'] = self.is_running
|
||||
stats['components_running'] = {
|
||||
'inference_coordinator': self.inference_coordinator.running,
|
||||
'training_adapter': self.training_adapter.running
|
||||
}
|
||||
|
||||
return stats
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting integration statistics: {e}")
|
||||
return {'error': str(e)}
|
||||
|
||||
def cleanup_old_data(self, days_to_keep: int = 7):
|
||||
"""Clean up old prediction data to manage memory"""
|
||||
try:
|
||||
self.reward_calculator.cleanup_old_predictions(days_to_keep)
|
||||
logger.info(f"Cleaned up prediction data older than {days_to_keep} days")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up old data: {e}")
|
||||
|
||||
|
||||
# Utility functions for easy integration
|
||||
|
||||
def integrate_enhanced_rewards(orchestrator: Any, symbols: list = None) -> EnhancedRewardSystemIntegration:
|
||||
"""
|
||||
Utility function to easily integrate enhanced rewards with an existing orchestrator
|
||||
|
||||
Args:
|
||||
orchestrator: TradingOrchestrator instance
|
||||
symbols: List of symbols to track
|
||||
|
||||
Returns:
|
||||
EnhancedRewardSystemIntegration instance
|
||||
"""
|
||||
integration = EnhancedRewardSystemIntegration(orchestrator, symbols)
|
||||
|
||||
# Add integration as an attribute to the orchestrator for easy access
|
||||
setattr(orchestrator, 'enhanced_reward_system', integration)
|
||||
|
||||
logger.info("Enhanced reward system integrated with orchestrator")
|
||||
return integration
|
||||
|
||||
|
||||
async def start_enhanced_rewards_for_orchestrator(orchestrator: Any, symbols: list = None):
|
||||
"""
|
||||
Start enhanced rewards for an existing orchestrator
|
||||
|
||||
Args:
|
||||
orchestrator: TradingOrchestrator instance
|
||||
symbols: List of symbols to track
|
||||
"""
|
||||
if not hasattr(orchestrator, 'enhanced_reward_system'):
|
||||
integrate_enhanced_rewards(orchestrator, symbols)
|
||||
|
||||
await orchestrator.enhanced_reward_system.start_integration()
|
||||
|
||||
|
||||
def add_prediction_to_enhanced_rewards(orchestrator: Any, symbol: str, timeframe: str,
|
||||
predicted_price: float, direction: int, confidence: float,
|
||||
current_price: float, model_name: str) -> str:
|
||||
"""
|
||||
Helper function to add predictions to enhanced rewards from existing code
|
||||
|
||||
Args:
|
||||
orchestrator: TradingOrchestrator instance with enhanced_reward_system
|
||||
symbol: Trading symbol
|
||||
timeframe: Timeframe string
|
||||
predicted_price: Predicted price
|
||||
direction: Predicted direction (-1, 0, 1)
|
||||
confidence: Model confidence
|
||||
current_price: Current market price
|
||||
model_name: Model name
|
||||
|
||||
Returns:
|
||||
Prediction ID
|
||||
"""
|
||||
if hasattr(orchestrator, 'enhanced_reward_system'):
|
||||
return orchestrator.enhanced_reward_system.add_prediction_manually(
|
||||
symbol, timeframe, predicted_price, direction, confidence, current_price, model_name
|
||||
)
|
||||
|
||||
logger.warning("Enhanced reward system not integrated with orchestrator")
|
||||
return ""
|
||||
|
572
core/enhanced_rl_training_adapter.py
Normal file
572
core/enhanced_rl_training_adapter.py
Normal file
@@ -0,0 +1,572 @@
|
||||
"""
|
||||
Enhanced RL Training Adapter
|
||||
|
||||
This module integrates the new MSE-based reward system with existing RL training pipelines.
|
||||
It provides a bridge between the timeframe-aware inference coordinator and the existing
|
||||
model training infrastructure.
|
||||
|
||||
Key Features:
|
||||
- Integration with EnhancedRewardCalculator
|
||||
- Adaptation of existing RL models to new reward system
|
||||
- Real-time training triggers based on prediction outcomes
|
||||
- Multi-timeframe training coordination
|
||||
- Backward compatibility with existing training infrastructure
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any, Union
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
import threading
|
||||
|
||||
from core.enhanced_reward_calculator import EnhancedRewardCalculator, TimeFrame, PredictionRecord
|
||||
from core.timeframe_inference_coordinator import TimeframeInferenceCoordinator, InferenceContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingBatch:
|
||||
"""Training batch for RL models with enhanced reward data"""
|
||||
model_name: str
|
||||
symbol: str
|
||||
timeframe: TimeFrame
|
||||
states: List[np.ndarray]
|
||||
actions: List[int]
|
||||
rewards: List[float]
|
||||
next_states: List[np.ndarray]
|
||||
dones: List[bool]
|
||||
confidences: List[float]
|
||||
prediction_records: List[PredictionRecord]
|
||||
batch_timestamp: datetime
|
||||
|
||||
|
||||
class EnhancedRLTrainingAdapter:
|
||||
"""
|
||||
Adapter that integrates new reward system with existing RL training infrastructure
|
||||
|
||||
This adapter:
|
||||
1. Bridges new reward calculator with existing RL models
|
||||
2. Converts prediction records to RL training format
|
||||
3. Triggers real-time training based on reward evaluation
|
||||
4. Maintains compatibility with existing training systems
|
||||
5. Coordinates multi-timeframe training
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
reward_calculator: EnhancedRewardCalculator,
|
||||
inference_coordinator: TimeframeInferenceCoordinator,
|
||||
orchestrator: Any = None,
|
||||
training_system: Any = None):
|
||||
"""
|
||||
Initialize the enhanced RL training adapter
|
||||
|
||||
Args:
|
||||
reward_calculator: Enhanced reward calculator instance
|
||||
inference_coordinator: Timeframe inference coordinator
|
||||
orchestrator: Trading orchestrator (optional)
|
||||
training_system: Enhanced realtime training system (optional)
|
||||
"""
|
||||
self.reward_calculator = reward_calculator
|
||||
self.inference_coordinator = inference_coordinator
|
||||
self.orchestrator = orchestrator
|
||||
self.training_system = training_system
|
||||
|
||||
# Model registry for training functions
|
||||
self.model_trainers: Dict[str, Any] = {}
|
||||
|
||||
# Training configuration
|
||||
self.min_batch_size = 8 # Minimum samples for training
|
||||
self.max_batch_size = 64 # Maximum samples per training batch
|
||||
self.training_interval_seconds = 5.0 # How often to check for training opportunities
|
||||
|
||||
# Training statistics
|
||||
self.training_stats = {
|
||||
'total_training_batches': 0,
|
||||
'successful_training_calls': 0,
|
||||
'failed_training_calls': 0,
|
||||
'last_training_time': None,
|
||||
'training_times_per_model': {},
|
||||
'average_batch_sizes': {}
|
||||
}
|
||||
|
||||
# State conversion helpers
|
||||
self.state_builders: Dict[str, Any] = {}
|
||||
|
||||
# Thread safety
|
||||
self.lock = threading.RLock()
|
||||
|
||||
# Running state
|
||||
self.running = False
|
||||
self.training_task: Optional[asyncio.Task] = None
|
||||
|
||||
logger.info("EnhancedRLTrainingAdapter initialized")
|
||||
self._register_default_model_handlers()
|
||||
|
||||
def _register_default_model_handlers(self):
|
||||
"""Register default model handlers for existing models"""
|
||||
# Register inference functions with the coordinator
|
||||
if self.inference_coordinator:
|
||||
self.inference_coordinator.register_model_inference_function(
|
||||
'dqn_agent', self._dqn_inference_wrapper
|
||||
)
|
||||
self.inference_coordinator.register_model_inference_function(
|
||||
'cob_rl', self._cob_rl_inference_wrapper
|
||||
)
|
||||
self.inference_coordinator.register_model_inference_function(
|
||||
'enhanced_cnn', self._cnn_inference_wrapper
|
||||
)
|
||||
|
||||
async def _dqn_inference_wrapper(self, context: InferenceContext) -> Optional[Dict[str, Any]]:
|
||||
"""Wrapper for DQN model inference"""
|
||||
try:
|
||||
if self.orchestrator and hasattr(self.orchestrator, 'rl_agent'):
|
||||
# Get base data for the symbol
|
||||
base_data = await self._get_base_data(context.symbol)
|
||||
if base_data is None:
|
||||
return None
|
||||
|
||||
# Convert to DQN state format
|
||||
state = self._convert_to_dqn_state(base_data, context)
|
||||
|
||||
# Run DQN prediction
|
||||
if hasattr(self.orchestrator.rl_agent, 'act'):
|
||||
action_idx = self.orchestrator.rl_agent.act(state)
|
||||
confidence = 0.7 # Default confidence for DQN
|
||||
|
||||
# Convert action to prediction format
|
||||
action_names = ['SELL', 'HOLD', 'BUY']
|
||||
direction = action_idx - 1 # Convert 0,1,2 to -1,0,1
|
||||
|
||||
current_price = base_data.get('current_price', 0.0)
|
||||
predicted_price = current_price * (1 + (direction * 0.001)) # Small price prediction
|
||||
|
||||
return {
|
||||
'predicted_price': predicted_price,
|
||||
'current_price': current_price,
|
||||
'direction': direction,
|
||||
'confidence': confidence,
|
||||
'action': action_names[action_idx],
|
||||
'model_state': state,
|
||||
'context': context
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error in DQN inference wrapper: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def _cob_rl_inference_wrapper(self, context: InferenceContext) -> Optional[Dict[str, Any]]:
|
||||
"""Wrapper for COB RL model inference"""
|
||||
try:
|
||||
if (self.orchestrator and
|
||||
hasattr(self.orchestrator, 'realtime_rl_trader') and
|
||||
self.orchestrator.realtime_rl_trader):
|
||||
|
||||
# Get COB features
|
||||
features = await self._get_cob_features(context.symbol)
|
||||
if features is None:
|
||||
return None
|
||||
|
||||
# Run COB RL prediction
|
||||
prediction = self.orchestrator.realtime_rl_trader._predict(context.symbol, features)
|
||||
|
||||
if prediction:
|
||||
current_price = await self._get_current_price(context.symbol)
|
||||
predicted_price = current_price * (1 + prediction.get('change', 0))
|
||||
|
||||
return {
|
||||
'predicted_price': predicted_price,
|
||||
'current_price': current_price,
|
||||
'direction': prediction.get('direction', 0),
|
||||
'confidence': prediction.get('confidence', 0.0),
|
||||
'change': prediction.get('change', 0.0),
|
||||
'model_features': features,
|
||||
'context': context
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error in COB RL inference wrapper: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def _cnn_inference_wrapper(self, context: InferenceContext) -> Optional[Dict[str, Any]]:
|
||||
"""Wrapper for CNN model inference"""
|
||||
try:
|
||||
if self.orchestrator and hasattr(self.orchestrator, 'model_registry'):
|
||||
# Find CNN models in registry
|
||||
for model_name, model in self.orchestrator.model_registry.models.items():
|
||||
if 'cnn' in model_name.lower():
|
||||
# Get base data
|
||||
base_data = await self._get_base_data(context.symbol)
|
||||
if base_data is None:
|
||||
continue
|
||||
|
||||
# Run CNN prediction
|
||||
if hasattr(model, 'predict_from_base_input'):
|
||||
model_output = model.predict_from_base_input(base_data)
|
||||
|
||||
current_price = base_data.get('current_price', 0.0)
|
||||
|
||||
# Extract prediction data
|
||||
predictions = model_output.predictions
|
||||
action = predictions.get('action', 'HOLD')
|
||||
confidence = predictions.get('confidence', 0.0)
|
||||
|
||||
# Convert action to direction
|
||||
direction = {'BUY': 1, 'SELL': -1, 'HOLD': 0}.get(action, 0)
|
||||
predicted_price = current_price * (1 + (direction * 0.002))
|
||||
|
||||
return {
|
||||
'predicted_price': predicted_price,
|
||||
'current_price': current_price,
|
||||
'direction': direction,
|
||||
'confidence': confidence,
|
||||
'action': action,
|
||||
'model_output': model_output,
|
||||
'context': context
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error in CNN inference wrapper: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def _get_base_data(self, symbol: str) -> Optional[Any]:
|
||||
"""Get base data for a symbol"""
|
||||
try:
|
||||
if self.orchestrator and hasattr(self.orchestrator, 'data_provider'):
|
||||
# Use orchestrator's data provider
|
||||
return await self.orchestrator._build_base_data(symbol)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting base data for {symbol}: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def _get_cob_features(self, symbol: str) -> Optional[np.ndarray]:
|
||||
"""Get COB features for a symbol"""
|
||||
try:
|
||||
if (self.orchestrator and
|
||||
hasattr(self.orchestrator, 'realtime_rl_trader') and
|
||||
self.orchestrator.realtime_rl_trader):
|
||||
|
||||
# Get latest features from COB trader
|
||||
feature_buffers = self.orchestrator.realtime_rl_trader.feature_buffers
|
||||
if symbol in feature_buffers and feature_buffers[symbol]:
|
||||
latest_data = feature_buffers[symbol][-1]
|
||||
return latest_data.get('features')
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting COB features for {symbol}: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def _get_current_price(self, symbol: str) -> float:
|
||||
"""Get current price for a symbol"""
|
||||
try:
|
||||
if self.orchestrator and hasattr(self.orchestrator, 'data_provider'):
|
||||
current_prices = self.orchestrator.data_provider.current_prices
|
||||
return current_prices.get(symbol, 0.0)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting current price for {symbol}: {e}")
|
||||
|
||||
return 0.0
|
||||
|
||||
def _convert_to_dqn_state(self, base_data: Any, context: InferenceContext) -> np.ndarray:
|
||||
"""Convert base data to DQN state format"""
|
||||
try:
|
||||
# Use existing state building logic if available
|
||||
if (self.orchestrator and
|
||||
hasattr(self.orchestrator, 'enhanced_training_system') and
|
||||
hasattr(self.orchestrator.enhanced_training_system, '_build_dqn_state')):
|
||||
|
||||
return self.orchestrator.enhanced_training_system._build_dqn_state(
|
||||
base_data, context.symbol
|
||||
)
|
||||
|
||||
# Fallback: create simple state representation
|
||||
feature_vector = base_data.get_feature_vector() if hasattr(base_data, 'get_feature_vector') else []
|
||||
if feature_vector:
|
||||
return np.array(feature_vector, dtype=np.float32)
|
||||
|
||||
# Last resort: create minimal state
|
||||
return np.zeros(100, dtype=np.float32)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error converting to DQN state: {e}")
|
||||
return np.zeros(100, dtype=np.float32)
|
||||
|
||||
async def start_training_loop(self):
|
||||
"""Start the enhanced training loop"""
|
||||
if self.running:
|
||||
logger.warning("Training loop already running")
|
||||
return
|
||||
|
||||
self.running = True
|
||||
self.training_task = asyncio.create_task(self._training_loop())
|
||||
logger.info("Enhanced RL training loop started")
|
||||
|
||||
async def stop_training_loop(self):
|
||||
"""Stop the enhanced training loop"""
|
||||
if not self.running:
|
||||
return
|
||||
|
||||
self.running = False
|
||||
if self.training_task:
|
||||
self.training_task.cancel()
|
||||
try:
|
||||
await self.training_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
logger.info("Enhanced RL training loop stopped")
|
||||
|
||||
async def _training_loop(self):
|
||||
"""Main training loop that processes evaluated predictions"""
|
||||
logger.info("Starting enhanced RL training loop")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
# Process training for each symbol and timeframe
|
||||
for symbol in self.reward_calculator.symbols:
|
||||
for timeframe in [TimeFrame.SECONDS_1, TimeFrame.MINUTES_1,
|
||||
TimeFrame.HOURS_1, TimeFrame.DAYS_1]:
|
||||
|
||||
# Get training data for this symbol/timeframe
|
||||
training_data = self.reward_calculator.get_training_data(
|
||||
symbol, timeframe, self.max_batch_size
|
||||
)
|
||||
|
||||
if len(training_data) >= self.min_batch_size:
|
||||
await self._process_training_batch(symbol, timeframe, training_data)
|
||||
|
||||
# Sleep between training checks
|
||||
await asyncio.sleep(self.training_interval_seconds)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in training loop: {e}")
|
||||
await asyncio.sleep(10) # Wait longer on error
|
||||
|
||||
async def _process_training_batch(self, symbol: str, timeframe: TimeFrame,
|
||||
training_data: List[Tuple[PredictionRecord, float]]):
|
||||
"""
|
||||
Process a training batch for a specific symbol/timeframe
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
timeframe: Timeframe for training
|
||||
training_data: List of (prediction_record, reward) tuples
|
||||
"""
|
||||
try:
|
||||
# Group training data by model
|
||||
model_batches = {}
|
||||
|
||||
for prediction_record, reward in training_data:
|
||||
model_name = prediction_record.model_name
|
||||
if model_name not in model_batches:
|
||||
model_batches[model_name] = []
|
||||
model_batches[model_name].append((prediction_record, reward))
|
||||
|
||||
# Process each model's batch
|
||||
for model_name, model_data in model_batches.items():
|
||||
if len(model_data) >= self.min_batch_size:
|
||||
await self._train_model_batch(model_name, symbol, timeframe, model_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing training batch for {symbol} {timeframe.value}: {e}")
|
||||
|
||||
async def _train_model_batch(self, model_name: str, symbol: str, timeframe: TimeFrame,
|
||||
training_data: List[Tuple[PredictionRecord, float]]):
|
||||
"""
|
||||
Train a specific model with a batch of data
|
||||
|
||||
Args:
|
||||
model_name: Name of the model to train
|
||||
symbol: Trading symbol
|
||||
timeframe: Timeframe for training
|
||||
training_data: List of (prediction_record, reward) tuples
|
||||
"""
|
||||
try:
|
||||
training_start = time.time()
|
||||
|
||||
# Convert to training batch format
|
||||
batch = self._create_training_batch(model_name, symbol, timeframe, training_data)
|
||||
|
||||
if batch is None:
|
||||
return
|
||||
|
||||
# Call appropriate training function based on model type
|
||||
success = False
|
||||
|
||||
if 'dqn' in model_name.lower():
|
||||
success = await self._train_dqn_model(batch)
|
||||
elif 'cob' in model_name.lower():
|
||||
success = await self._train_cob_rl_model(batch)
|
||||
elif 'cnn' in model_name.lower():
|
||||
success = await self._train_cnn_model(batch)
|
||||
else:
|
||||
logger.warning(f"Unknown model type for training: {model_name}")
|
||||
|
||||
# Update statistics
|
||||
training_time = time.time() - training_start
|
||||
self._update_training_stats(model_name, batch, success, training_time)
|
||||
|
||||
if success:
|
||||
logger.info(f"Successfully trained {model_name} on {symbol} {timeframe.value} "
|
||||
f"with {len(training_data)} samples in {training_time:.3f}s")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error training model {model_name}: {e}")
|
||||
self._update_training_stats(model_name, None, False, 0)
|
||||
|
||||
def _create_training_batch(self, model_name: str, symbol: str, timeframe: TimeFrame,
|
||||
training_data: List[Tuple[PredictionRecord, float]]) -> Optional[TrainingBatch]:
|
||||
"""Create a training batch from prediction records and rewards"""
|
||||
try:
|
||||
states = []
|
||||
actions = []
|
||||
rewards = []
|
||||
next_states = []
|
||||
dones = []
|
||||
confidences = []
|
||||
prediction_records = []
|
||||
|
||||
for prediction_record, reward in training_data:
|
||||
# Extract state information
|
||||
# This would need to be adapted based on how states are stored
|
||||
state = np.zeros(100) # Placeholder - you'll need to extract actual state
|
||||
next_state = state.copy() # Simplified next state
|
||||
|
||||
# Convert direction to action
|
||||
direction = prediction_record.predicted_direction
|
||||
action = direction + 1 # Convert -1,0,1 to 0,1,2
|
||||
|
||||
states.append(state)
|
||||
actions.append(action)
|
||||
rewards.append(reward)
|
||||
next_states.append(next_state)
|
||||
dones.append(True) # Each prediction is treated as terminal
|
||||
confidences.append(prediction_record.confidence)
|
||||
prediction_records.append(prediction_record)
|
||||
|
||||
return TrainingBatch(
|
||||
model_name=model_name,
|
||||
symbol=symbol,
|
||||
timeframe=timeframe,
|
||||
states=states,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_states=next_states,
|
||||
dones=dones,
|
||||
confidences=confidences,
|
||||
prediction_records=prediction_records,
|
||||
batch_timestamp=datetime.now()
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating training batch: {e}")
|
||||
return None
|
||||
|
||||
async def _train_dqn_model(self, batch: TrainingBatch) -> bool:
|
||||
"""Train DQN model with batch data"""
|
||||
try:
|
||||
if self.orchestrator and hasattr(self.orchestrator, 'rl_agent'):
|
||||
rl_agent = self.orchestrator.rl_agent
|
||||
|
||||
# Add experiences to memory
|
||||
for i in range(len(batch.states)):
|
||||
if hasattr(rl_agent, 'remember'):
|
||||
rl_agent.remember(
|
||||
state=batch.states[i],
|
||||
action=batch.actions[i],
|
||||
reward=batch.rewards[i],
|
||||
next_state=batch.next_states[i],
|
||||
done=batch.dones[i]
|
||||
)
|
||||
|
||||
# Trigger training if enough experiences
|
||||
if hasattr(rl_agent, 'replay') and hasattr(rl_agent, 'memory'):
|
||||
if len(rl_agent.memory) >= getattr(rl_agent, 'batch_size', 32):
|
||||
loss = rl_agent.replay()
|
||||
if loss is not None:
|
||||
logger.debug(f"DQN training loss: {loss:.6f}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error training DQN model: {e}")
|
||||
return False
|
||||
|
||||
async def _train_cob_rl_model(self, batch: TrainingBatch) -> bool:
|
||||
"""Train COB RL model with batch data"""
|
||||
try:
|
||||
if (self.orchestrator and
|
||||
hasattr(self.orchestrator, 'realtime_rl_trader') and
|
||||
self.orchestrator.realtime_rl_trader):
|
||||
|
||||
# Use COB RL trainer if available
|
||||
# This is a placeholder - implement based on actual COB RL training interface
|
||||
logger.debug(f"COB RL training batch: {len(batch.states)} samples")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error training COB RL model: {e}")
|
||||
return False
|
||||
|
||||
async def _train_cnn_model(self, batch: TrainingBatch) -> bool:
|
||||
"""Train CNN model with batch data"""
|
||||
try:
|
||||
if self.orchestrator and hasattr(self.orchestrator, 'enhanced_training_system'):
|
||||
# Use enhanced training system for CNN training
|
||||
# This is a placeholder - implement based on actual CNN training interface
|
||||
logger.debug(f"CNN training batch: {len(batch.states)} samples")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error training CNN model: {e}")
|
||||
return False
|
||||
|
||||
def _update_training_stats(self, model_name: str, batch: Optional[TrainingBatch],
|
||||
success: bool, training_time: float):
|
||||
"""Update training statistics"""
|
||||
with self.lock:
|
||||
self.training_stats['total_training_batches'] += 1
|
||||
|
||||
if success:
|
||||
self.training_stats['successful_training_calls'] += 1
|
||||
else:
|
||||
self.training_stats['failed_training_calls'] += 1
|
||||
|
||||
self.training_stats['last_training_time'] = datetime.now().isoformat()
|
||||
|
||||
# Model-specific stats
|
||||
if model_name not in self.training_stats['training_times_per_model']:
|
||||
self.training_stats['training_times_per_model'][model_name] = []
|
||||
self.training_stats['average_batch_sizes'][model_name] = []
|
||||
|
||||
self.training_stats['training_times_per_model'][model_name].append(training_time)
|
||||
|
||||
if batch:
|
||||
self.training_stats['average_batch_sizes'][model_name].append(len(batch.states))
|
||||
|
||||
def get_training_statistics(self) -> Dict[str, Any]:
|
||||
"""Get training statistics"""
|
||||
with self.lock:
|
||||
stats = self.training_stats.copy()
|
||||
|
||||
# Calculate averages
|
||||
for model_name in stats['training_times_per_model']:
|
||||
times = stats['training_times_per_model'][model_name]
|
||||
if times:
|
||||
stats[f'{model_name}_avg_training_time'] = sum(times) / len(times)
|
||||
|
||||
sizes = stats['average_batch_sizes'][model_name]
|
||||
if sizes:
|
||||
stats[f'{model_name}_avg_batch_size'] = sum(sizes) / len(sizes)
|
||||
|
||||
return stats
|
||||
|
467
core/timeframe_inference_coordinator.py
Normal file
467
core/timeframe_inference_coordinator.py
Normal file
@@ -0,0 +1,467 @@
|
||||
"""
|
||||
Timeframe-Aware Inference Coordinator
|
||||
|
||||
This module coordinates model inference across multiple timeframes with proper scheduling.
|
||||
It ensures that models know which timeframe they are predicting on and handles the
|
||||
complex scheduling requirements for multi-timeframe predictions.
|
||||
|
||||
Key Features:
|
||||
- Timeframe-aware model inference
|
||||
- Hourly multi-timeframe inference (4 predictions per hour)
|
||||
- Frequent inference at 1-5 second intervals
|
||||
- Prediction context management
|
||||
- Integration with enhanced reward calculator
|
||||
"""
|
||||
|
||||
import time
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any, Callable
|
||||
from dataclasses import dataclass
|
||||
import threading
|
||||
from enum import Enum
|
||||
|
||||
from core.enhanced_reward_calculator import EnhancedRewardCalculator, TimeFrame
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferenceContext:
|
||||
"""Context information for a model inference"""
|
||||
symbol: str
|
||||
timeframe: TimeFrame
|
||||
timestamp: datetime
|
||||
target_timeframe: TimeFrame # Which timeframe we're predicting for
|
||||
is_hourly_inference: bool = False
|
||||
inference_type: str = "regular" # "regular", "hourly", "continuous"
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferenceSchedule:
|
||||
"""Schedule configuration for different inference types"""
|
||||
continuous_interval_seconds: float = 5.0 # Continuous inference every 5 seconds
|
||||
hourly_timeframes: List[TimeFrame] = None # Timeframes for hourly inference
|
||||
|
||||
def __post_init__(self):
|
||||
if self.hourly_timeframes is None:
|
||||
self.hourly_timeframes = [TimeFrame.SECONDS_1, TimeFrame.MINUTES_1,
|
||||
TimeFrame.HOURS_1, TimeFrame.DAYS_1]
|
||||
|
||||
|
||||
class TimeframeInferenceCoordinator:
|
||||
"""
|
||||
Coordinates timeframe-aware model inference with proper scheduling
|
||||
|
||||
This coordinator:
|
||||
1. Manages continuous inference every 1-5 seconds on main timeframe
|
||||
2. Schedules hourly multi-timeframe inference (4 predictions per hour)
|
||||
3. Ensures models know which timeframe they're predicting on
|
||||
4. Integrates with enhanced reward calculator for training
|
||||
5. Handles prediction context and metadata
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
reward_calculator: EnhancedRewardCalculator,
|
||||
data_provider: Any = None,
|
||||
symbols: List[str] = None):
|
||||
"""
|
||||
Initialize the timeframe inference coordinator
|
||||
|
||||
Args:
|
||||
reward_calculator: Enhanced reward calculator instance
|
||||
data_provider: Data provider for market data
|
||||
symbols: List of symbols to coordinate inference for
|
||||
"""
|
||||
self.reward_calculator = reward_calculator
|
||||
self.data_provider = data_provider
|
||||
self.symbols = symbols or ['ETH/USDT', 'BTC/USDT']
|
||||
|
||||
# Inference schedule configuration
|
||||
self.schedule = InferenceSchedule()
|
||||
|
||||
# Model registry - stores inference functions for different models
|
||||
self.model_inference_functions: Dict[str, Callable] = {}
|
||||
|
||||
# Tracking inference state
|
||||
self.last_continuous_inference: Dict[str, datetime] = {}
|
||||
self.last_hourly_inference: Dict[str, datetime] = {}
|
||||
self.next_hourly_inference: Dict[str, datetime] = {}
|
||||
|
||||
# Active inference tasks
|
||||
self.inference_tasks: List[asyncio.Task] = []
|
||||
self.running = False
|
||||
|
||||
# Thread safety
|
||||
self.lock = threading.RLock()
|
||||
|
||||
# Performance metrics
|
||||
self.inference_stats = {
|
||||
'continuous_inferences': 0,
|
||||
'hourly_inferences': 0,
|
||||
'failed_inferences': 0,
|
||||
'average_inference_time_ms': 0.0
|
||||
}
|
||||
|
||||
self._initialize_schedules()
|
||||
|
||||
logger.info(f"TimeframeInferenceCoordinator initialized for symbols: {self.symbols}")
|
||||
logger.info(f"Continuous inference interval: {self.schedule.continuous_interval_seconds}s")
|
||||
logger.info(f"Hourly timeframes: {[tf.value for tf in self.schedule.hourly_timeframes]}")
|
||||
|
||||
def _initialize_schedules(self):
|
||||
"""Initialize inference schedules for all symbols"""
|
||||
current_time = datetime.now()
|
||||
|
||||
for symbol in self.symbols:
|
||||
self.last_continuous_inference[symbol] = current_time
|
||||
self.last_hourly_inference[symbol] = current_time
|
||||
|
||||
# Schedule next hourly inference at the top of the next hour
|
||||
next_hour = current_time.replace(minute=0, second=0, microsecond=0) + timedelta(hours=1)
|
||||
self.next_hourly_inference[symbol] = next_hour
|
||||
|
||||
def register_model_inference_function(self, model_name: str, inference_func: Callable):
|
||||
"""
|
||||
Register a model's inference function
|
||||
|
||||
Args:
|
||||
model_name: Name of the model
|
||||
inference_func: Async function that takes InferenceContext and returns prediction
|
||||
"""
|
||||
self.model_inference_functions[model_name] = inference_func
|
||||
logger.info(f"Registered inference function for model: {model_name}")
|
||||
|
||||
async def start_coordination(self):
|
||||
"""Start the inference coordination system"""
|
||||
if self.running:
|
||||
logger.warning("Inference coordination already running")
|
||||
return
|
||||
|
||||
self.running = True
|
||||
logger.info("Starting timeframe inference coordination")
|
||||
|
||||
# Start continuous inference tasks for each symbol
|
||||
for symbol in self.symbols:
|
||||
task = asyncio.create_task(self._continuous_inference_loop(symbol))
|
||||
self.inference_tasks.append(task)
|
||||
|
||||
# Start hourly inference scheduler
|
||||
task = asyncio.create_task(self._hourly_inference_scheduler())
|
||||
self.inference_tasks.append(task)
|
||||
|
||||
# Start reward evaluation loop
|
||||
task = asyncio.create_task(self._reward_evaluation_loop())
|
||||
self.inference_tasks.append(task)
|
||||
|
||||
logger.info(f"Started {len(self.inference_tasks)} inference coordination tasks")
|
||||
|
||||
async def stop_coordination(self):
|
||||
"""Stop the inference coordination system"""
|
||||
if not self.running:
|
||||
return
|
||||
|
||||
self.running = False
|
||||
logger.info("Stopping timeframe inference coordination")
|
||||
|
||||
# Cancel all tasks
|
||||
for task in self.inference_tasks:
|
||||
task.cancel()
|
||||
|
||||
# Wait for tasks to complete
|
||||
await asyncio.gather(*self.inference_tasks, return_exceptions=True)
|
||||
self.inference_tasks.clear()
|
||||
|
||||
logger.info("Inference coordination stopped")
|
||||
|
||||
async def _continuous_inference_loop(self, symbol: str):
|
||||
"""
|
||||
Continuous inference loop for a specific symbol
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol to run inference for
|
||||
"""
|
||||
logger.info(f"Starting continuous inference loop for {symbol}")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
current_time = datetime.now()
|
||||
|
||||
# Check if it's time for continuous inference
|
||||
last_inference = self.last_continuous_inference[symbol]
|
||||
time_since_last = (current_time - last_inference).total_seconds()
|
||||
|
||||
if time_since_last >= self.schedule.continuous_interval_seconds:
|
||||
# Run continuous inference on primary timeframe (1s)
|
||||
context = InferenceContext(
|
||||
symbol=symbol,
|
||||
timeframe=TimeFrame.SECONDS_1,
|
||||
timestamp=current_time,
|
||||
target_timeframe=TimeFrame.SECONDS_1,
|
||||
is_hourly_inference=False,
|
||||
inference_type="continuous"
|
||||
)
|
||||
|
||||
await self._execute_inference(context)
|
||||
self.last_continuous_inference[symbol] = current_time
|
||||
self.inference_stats['continuous_inferences'] += 1
|
||||
|
||||
# Sleep for a short interval to avoid busy waiting
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in continuous inference loop for {symbol}: {e}")
|
||||
await asyncio.sleep(1.0) # Wait longer on error
|
||||
|
||||
async def _hourly_inference_scheduler(self):
|
||||
"""Scheduler for hourly multi-timeframe inference"""
|
||||
logger.info("Starting hourly inference scheduler")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
current_time = datetime.now()
|
||||
|
||||
# Check if any symbol needs hourly inference
|
||||
for symbol in self.symbols:
|
||||
if current_time >= self.next_hourly_inference[symbol]:
|
||||
await self._execute_hourly_inference(symbol, current_time)
|
||||
|
||||
# Schedule next hourly inference
|
||||
next_hour = current_time.replace(minute=0, second=0, microsecond=0) + timedelta(hours=1)
|
||||
self.next_hourly_inference[symbol] = next_hour
|
||||
self.last_hourly_inference[symbol] = current_time
|
||||
|
||||
# Sleep for 30 seconds between checks
|
||||
await asyncio.sleep(30)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in hourly inference scheduler: {e}")
|
||||
await asyncio.sleep(60) # Wait longer on error
|
||||
|
||||
async def _execute_hourly_inference(self, symbol: str, timestamp: datetime):
|
||||
"""
|
||||
Execute hourly multi-timeframe inference for a symbol
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
timestamp: Current timestamp
|
||||
"""
|
||||
logger.info(f"Executing hourly multi-timeframe inference for {symbol}")
|
||||
|
||||
# Run inference for each timeframe
|
||||
for timeframe in self.schedule.hourly_timeframes:
|
||||
context = InferenceContext(
|
||||
symbol=symbol,
|
||||
timeframe=timeframe,
|
||||
timestamp=timestamp,
|
||||
target_timeframe=timeframe,
|
||||
is_hourly_inference=True,
|
||||
inference_type="hourly"
|
||||
)
|
||||
|
||||
await self._execute_inference(context)
|
||||
self.inference_stats['hourly_inferences'] += 1
|
||||
|
||||
# Small delay between timeframe inferences
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
async def _execute_inference(self, context: InferenceContext):
|
||||
"""
|
||||
Execute inference for a specific context
|
||||
|
||||
Args:
|
||||
context: Inference context containing all necessary information
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Run inference for all registered models
|
||||
for model_name, inference_func in self.model_inference_functions.items():
|
||||
try:
|
||||
# Execute model inference
|
||||
prediction = await inference_func(context)
|
||||
|
||||
if prediction is not None:
|
||||
# Add prediction to reward calculator
|
||||
prediction_id = self.reward_calculator.add_prediction(
|
||||
symbol=context.symbol,
|
||||
timeframe=context.target_timeframe,
|
||||
predicted_price=prediction.get('predicted_price', 0.0),
|
||||
predicted_direction=prediction.get('direction', 0),
|
||||
confidence=prediction.get('confidence', 0.0),
|
||||
current_price=prediction.get('current_price', 0.0),
|
||||
model_name=model_name
|
||||
)
|
||||
|
||||
logger.debug(f"Added prediction {prediction_id} from {model_name} "
|
||||
f"for {context.symbol} {context.target_timeframe.value}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error running inference for model {model_name}: {e}")
|
||||
self.inference_stats['failed_inferences'] += 1
|
||||
|
||||
# Update inference timing stats
|
||||
inference_time_ms = (time.time() - start_time) * 1000
|
||||
self._update_inference_timing(inference_time_ms)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing inference for context {context}: {e}")
|
||||
self.inference_stats['failed_inferences'] += 1
|
||||
|
||||
def _update_inference_timing(self, inference_time_ms: float):
|
||||
"""Update inference timing statistics"""
|
||||
total_inferences = (self.inference_stats['continuous_inferences'] +
|
||||
self.inference_stats['hourly_inferences'])
|
||||
|
||||
if total_inferences > 0:
|
||||
current_avg = self.inference_stats['average_inference_time_ms']
|
||||
new_avg = ((current_avg * (total_inferences - 1)) + inference_time_ms) / total_inferences
|
||||
self.inference_stats['average_inference_time_ms'] = new_avg
|
||||
|
||||
async def _reward_evaluation_loop(self):
|
||||
"""Continuous loop for evaluating prediction rewards"""
|
||||
logger.info("Starting reward evaluation loop")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
# Update price cache if data provider available
|
||||
if self.data_provider:
|
||||
await self._update_price_cache()
|
||||
|
||||
# Evaluate predictions and get training data
|
||||
for symbol in self.symbols:
|
||||
evaluation_results = self.reward_calculator.evaluate_predictions(symbol)
|
||||
|
||||
if symbol in evaluation_results and evaluation_results[symbol]:
|
||||
logger.debug(f"Evaluated {len(evaluation_results[symbol])} predictions for {symbol}")
|
||||
|
||||
# Here you could trigger training for models that have new evaluated predictions
|
||||
await self._trigger_model_training(symbol, evaluation_results[symbol])
|
||||
|
||||
# Sleep for evaluation interval
|
||||
await asyncio.sleep(10) # Evaluate every 10 seconds
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in reward evaluation loop: {e}")
|
||||
await asyncio.sleep(30) # Wait longer on error
|
||||
|
||||
async def _update_price_cache(self):
|
||||
"""Update price cache with current market prices"""
|
||||
try:
|
||||
for symbol in self.symbols:
|
||||
# Get current price from data provider
|
||||
if hasattr(self.data_provider, 'get_current_price'):
|
||||
current_price = await self.data_provider.get_current_price(symbol)
|
||||
if current_price:
|
||||
self.reward_calculator.update_price(symbol, current_price)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error updating price cache: {e}")
|
||||
|
||||
async def _trigger_model_training(self, symbol: str, evaluation_results: List[Any]):
|
||||
"""
|
||||
Trigger model training based on evaluation results
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
evaluation_results: List of (prediction, reward) tuples
|
||||
"""
|
||||
try:
|
||||
# Group by model and timeframe for targeted training
|
||||
training_groups = {}
|
||||
|
||||
for prediction_record, reward in evaluation_results:
|
||||
model_name = prediction_record.model_name
|
||||
timeframe = prediction_record.timeframe
|
||||
|
||||
key = f"{model_name}_{timeframe.value}"
|
||||
if key not in training_groups:
|
||||
training_groups[key] = []
|
||||
|
||||
training_groups[key].append((prediction_record, reward))
|
||||
|
||||
# Trigger training for each group
|
||||
for group_key, training_data in training_groups.items():
|
||||
model_name, timeframe_str = group_key.split('_', 1)
|
||||
timeframe = TimeFrame(timeframe_str)
|
||||
|
||||
logger.info(f"Triggering training for {model_name} on {symbol} {timeframe.value} "
|
||||
f"with {len(training_data)} samples")
|
||||
|
||||
# Here you would call the specific model's training function
|
||||
# This is a placeholder - you'll need to implement the actual training calls
|
||||
await self._call_model_training(model_name, symbol, timeframe, training_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error triggering model training: {e}")
|
||||
|
||||
async def _call_model_training(self, model_name: str, symbol: str,
|
||||
timeframe: TimeFrame, training_data: List[Any]):
|
||||
"""
|
||||
Call model-specific training function
|
||||
|
||||
Args:
|
||||
model_name: Name of the model to train
|
||||
symbol: Trading symbol
|
||||
timeframe: Timeframe for training
|
||||
training_data: List of (prediction, reward) tuples
|
||||
"""
|
||||
# This is a placeholder for model-specific training calls
|
||||
# You'll need to implement this based on your specific model interfaces
|
||||
logger.debug(f"Training call for {model_name}: {len(training_data)} samples")
|
||||
|
||||
def get_inference_statistics(self) -> Dict[str, Any]:
|
||||
"""Get inference coordination statistics"""
|
||||
with self.lock:
|
||||
stats = self.inference_stats.copy()
|
||||
|
||||
# Add scheduling information
|
||||
stats['symbols'] = self.symbols
|
||||
stats['continuous_interval_seconds'] = self.schedule.continuous_interval_seconds
|
||||
stats['hourly_timeframes'] = [tf.value for tf in self.schedule.hourly_timeframes]
|
||||
stats['next_hourly_inferences'] = {
|
||||
symbol: timestamp.isoformat()
|
||||
for symbol, timestamp in self.next_hourly_inference.items()
|
||||
}
|
||||
|
||||
# Add accuracy summary from reward calculator
|
||||
stats['accuracy_summary'] = self.reward_calculator.get_accuracy_summary()
|
||||
|
||||
return stats
|
||||
|
||||
def force_hourly_inference(self, symbol: str = None):
|
||||
"""
|
||||
Force immediate hourly inference for symbol(s)
|
||||
|
||||
Args:
|
||||
symbol: Specific symbol (None for all symbols)
|
||||
"""
|
||||
symbols_to_process = [symbol] if symbol else self.symbols
|
||||
|
||||
async def _force_inference():
|
||||
current_time = datetime.now()
|
||||
for sym in symbols_to_process:
|
||||
await self._execute_hourly_inference(sym, current_time)
|
||||
|
||||
# Schedule the inference
|
||||
if self.running:
|
||||
asyncio.create_task(_force_inference())
|
||||
else:
|
||||
logger.warning("Cannot force inference - coordinator not running")
|
||||
|
||||
def get_prediction_history(self, symbol: str, timeframe: TimeFrame,
|
||||
max_samples: int = 50) -> List[Any]:
|
||||
"""
|
||||
Get prediction history for training
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
timeframe: Specific timeframe
|
||||
max_samples: Maximum samples to return
|
||||
|
||||
Returns:
|
||||
List of training samples
|
||||
"""
|
||||
return self.reward_calculator.get_training_data(symbol, timeframe, max_samples)
|
||||
|
349
docs/ENHANCED_REWARD_SYSTEM.md
Normal file
349
docs/ENHANCED_REWARD_SYSTEM.md
Normal file
@@ -0,0 +1,349 @@
|
||||
# Enhanced Reward System for Reinforcement Learning Training
|
||||
|
||||
## Overview
|
||||
|
||||
This document describes the implementation of an enhanced reward system for your reinforcement learning trading models. The system uses **mean squared error (MSE) between predictions and empirical outcomes** as the primary reward mechanism, with support for multiple timeframes and comprehensive accuracy tracking.
|
||||
|
||||
## Key Features
|
||||
|
||||
### ✅ MSE-Based Reward Calculation
|
||||
- Uses mean squared difference between predicted and actual prices
|
||||
- Exponential decay function heavily penalizes large prediction errors
|
||||
- Direction accuracy bonus/penalty system
|
||||
- Confidence-weighted final rewards
|
||||
|
||||
### ✅ Multi-Timeframe Support
|
||||
- Separate tracking for **1s, 1m, 1h, 1d** timeframes
|
||||
- Independent accuracy metrics for each timeframe
|
||||
- Timeframe-specific evaluation timeouts
|
||||
- Models know which timeframe they're predicting on
|
||||
|
||||
### ✅ Prediction History Tracking
|
||||
- Maintains last **6 predictions per timeframe** per symbol
|
||||
- Comprehensive prediction records with outcomes
|
||||
- Historical accuracy analysis
|
||||
- Memory-efficient with automatic cleanup
|
||||
|
||||
### ✅ Real-Time Training
|
||||
- Training triggered at each inference when outcomes are available
|
||||
- Separate training batches for each model and timeframe
|
||||
- Automatic evaluation of predictions after appropriate timeouts
|
||||
- Integration with existing RL training infrastructure
|
||||
|
||||
### ✅ Enhanced Inference Scheduling
|
||||
- **Continuous inference** every 1-5 seconds on primary timeframe
|
||||
- **Hourly multi-timeframe inference** (4 predictions per hour - one for each timeframe)
|
||||
- Timeframe-aware inference context
|
||||
- Proper scheduling and coordination
|
||||
|
||||
## Architecture
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
A[Market Data] --> B[Timeframe Inference Coordinator]
|
||||
B --> C[Model Inference]
|
||||
C --> D[Enhanced Reward Calculator]
|
||||
D --> E[Prediction Tracking]
|
||||
E --> F[Outcome Evaluation]
|
||||
F --> G[MSE Reward Calculation]
|
||||
G --> H[Enhanced RL Training Adapter]
|
||||
H --> I[Model Training]
|
||||
I --> J[Performance Monitoring]
|
||||
```
|
||||
|
||||
## Core Components
|
||||
|
||||
### 1. EnhancedRewardCalculator (`core/enhanced_reward_calculator.py`)
|
||||
|
||||
**Purpose**: Central reward calculation engine using MSE methodology
|
||||
|
||||
**Key Methods**:
|
||||
- `add_prediction()` - Track new predictions
|
||||
- `evaluate_predictions()` - Calculate rewards when outcomes available
|
||||
- `get_accuracy_summary()` - Comprehensive accuracy metrics
|
||||
- `get_training_data()` - Extract training samples for models
|
||||
|
||||
**Reward Formula**:
|
||||
```python
|
||||
# MSE calculation
|
||||
price_error = actual_price - predicted_price
|
||||
mse = price_error ** 2
|
||||
|
||||
# Normalize to reasonable scale
|
||||
max_mse = (current_price * 0.1) ** 2 # 10% as max expected error
|
||||
normalized_mse = min(mse / max_mse, 1.0)
|
||||
|
||||
# Exponential decay (heavily penalize large errors)
|
||||
mse_reward = exp(-5 * normalized_mse) # Range: [exp(-5), 1]
|
||||
|
||||
# Direction bonus/penalty
|
||||
direction_bonus = 0.5 if direction_correct else -0.5
|
||||
|
||||
# Final reward (confidence weighted)
|
||||
final_reward = (mse_reward + direction_bonus) * confidence
|
||||
```
|
||||
|
||||
### 2. TimeframeInferenceCoordinator (`core/timeframe_inference_coordinator.py`)
|
||||
|
||||
**Purpose**: Coordinates timeframe-aware model inference with proper scheduling
|
||||
|
||||
**Key Features**:
|
||||
- **Continuous inference loop** for each symbol (every 5 seconds)
|
||||
- **Hourly multi-timeframe scheduler** (4 predictions per hour)
|
||||
- **Inference context management** (models know target timeframe)
|
||||
- **Automatic reward evaluation** and training triggers
|
||||
|
||||
**Scheduling**:
|
||||
- **Every 5 seconds**: Inference on primary timeframe (1s)
|
||||
- **Every hour**: One inference for each timeframe (1s, 1m, 1h, 1d)
|
||||
- **Evaluation timeouts**: 5s for 1s predictions, 60s for 1m, 300s for 1h, 900s for 1d
|
||||
|
||||
### 3. EnhancedRLTrainingAdapter (`core/enhanced_rl_training_adapter.py`)
|
||||
|
||||
**Purpose**: Bridge between new reward system and existing RL training infrastructure
|
||||
|
||||
**Key Features**:
|
||||
- **Model inference wrappers** for DQN, COB RL, and CNN models
|
||||
- **Training batch creation** from prediction records and rewards
|
||||
- **Real-time training triggers** based on evaluation results
|
||||
- **Backward compatibility** with existing training systems
|
||||
|
||||
### 4. EnhancedRewardSystemIntegration (`core/enhanced_reward_system_integration.py`)
|
||||
|
||||
**Purpose**: Simple integration point for existing systems
|
||||
|
||||
**Key Features**:
|
||||
- **One-line integration** with existing TradingOrchestrator
|
||||
- **Helper functions** for easy prediction tracking
|
||||
- **Comprehensive monitoring** and statistics
|
||||
- **Minimal code changes** required
|
||||
|
||||
## Integration Guide
|
||||
|
||||
### Step 1: Import Required Components
|
||||
|
||||
Add to your `orchestrator.py`:
|
||||
|
||||
```python
|
||||
from core.enhanced_reward_system_integration import (
|
||||
integrate_enhanced_rewards,
|
||||
add_prediction_to_enhanced_rewards
|
||||
)
|
||||
```
|
||||
|
||||
### Step 2: Initialize in TradingOrchestrator
|
||||
|
||||
In your `TradingOrchestrator.__init__()`:
|
||||
|
||||
```python
|
||||
# Add this line after existing initialization
|
||||
integrate_enhanced_rewards(self, symbols=['ETH/USDT', 'BTC/USDT'])
|
||||
```
|
||||
|
||||
### Step 3: Start the System
|
||||
|
||||
In your `TradingOrchestrator.run()` method:
|
||||
|
||||
```python
|
||||
# Add this line after initialization
|
||||
await self.enhanced_reward_system.start_integration()
|
||||
```
|
||||
|
||||
### Step 4: Track Predictions
|
||||
|
||||
In your model inference methods (CNN, DQN, COB RL):
|
||||
|
||||
```python
|
||||
# Example in CNN inference
|
||||
prediction_id = add_prediction_to_enhanced_rewards(
|
||||
self, # orchestrator instance
|
||||
symbol, # 'ETH/USDT'
|
||||
timeframe, # '1s', '1m', '1h', '1d'
|
||||
predicted_price, # model's price prediction
|
||||
direction, # -1 (down), 0 (neutral), 1 (up)
|
||||
confidence, # 0.0 to 1.0
|
||||
current_price, # current market price
|
||||
'enhanced_cnn' # model name
|
||||
)
|
||||
```
|
||||
|
||||
### Step 5: Monitor Performance
|
||||
|
||||
```python
|
||||
# Get comprehensive statistics
|
||||
stats = self.enhanced_reward_system.get_integration_statistics()
|
||||
accuracy = self.enhanced_reward_system.get_model_accuracy()
|
||||
|
||||
# Force evaluation for testing
|
||||
self.enhanced_reward_system.force_evaluation_and_training('ETH/USDT', '1s')
|
||||
```
|
||||
|
||||
## Usage Example
|
||||
|
||||
See `examples/enhanced_reward_system_example.py` for a complete demonstration.
|
||||
|
||||
```bash
|
||||
python examples/enhanced_reward_system_example.py
|
||||
```
|
||||
|
||||
## Performance Benefits
|
||||
|
||||
### 🎯 Better Accuracy Measurement
|
||||
- **MSE rewards** provide nuanced feedback vs. simple directional accuracy
|
||||
- **Price prediction accuracy** measured alongside direction accuracy
|
||||
- **Confidence-weighted rewards** encourage well-calibrated predictions
|
||||
|
||||
### 📊 Multi-Timeframe Intelligence
|
||||
- **Separate tracking** prevents timeframe confusion
|
||||
- **Timeframe-specific evaluation** accounts for different market dynamics
|
||||
- **Comprehensive accuracy picture** across all prediction horizons
|
||||
|
||||
### ⚡ Real-Time Learning
|
||||
- **Immediate training** when prediction outcomes available
|
||||
- **No batch delays** - models learn from every prediction
|
||||
- **Adaptive training frequency** based on prediction evaluation
|
||||
|
||||
### 🔄 Enhanced Inference Scheduling
|
||||
- **Optimal prediction frequency** balances real-time response with computational efficiency
|
||||
- **Hourly multi-timeframe predictions** provide comprehensive market coverage
|
||||
- **Context-aware models** make better predictions knowing their target timeframe
|
||||
|
||||
## Configuration
|
||||
|
||||
### Evaluation Timeouts (Configurable in EnhancedRewardCalculator)
|
||||
|
||||
```python
|
||||
evaluation_timeouts = {
|
||||
TimeFrame.SECONDS_1: 5, # Evaluate 1s predictions after 5 seconds
|
||||
TimeFrame.MINUTES_1: 60, # Evaluate 1m predictions after 1 minute
|
||||
TimeFrame.HOURS_1: 300, # Evaluate 1h predictions after 5 minutes
|
||||
TimeFrame.DAYS_1: 900 # Evaluate 1d predictions after 15 minutes
|
||||
}
|
||||
```
|
||||
|
||||
### Inference Scheduling (Configurable in TimeframeInferenceCoordinator)
|
||||
|
||||
```python
|
||||
schedule = InferenceSchedule(
|
||||
continuous_interval_seconds=5.0, # Continuous inference every 5 seconds
|
||||
hourly_timeframes=[TimeFrame.SECONDS_1, TimeFrame.MINUTES_1,
|
||||
TimeFrame.HOURS_1, TimeFrame.DAYS_1]
|
||||
)
|
||||
```
|
||||
|
||||
### Training Configuration (Configurable in EnhancedRLTrainingAdapter)
|
||||
|
||||
```python
|
||||
min_batch_size = 8 # Minimum samples for training
|
||||
max_batch_size = 64 # Maximum samples per training batch
|
||||
training_interval_seconds = 5.0 # Training check frequency
|
||||
```
|
||||
|
||||
## Monitoring and Statistics
|
||||
|
||||
### Integration Statistics
|
||||
|
||||
```python
|
||||
stats = enhanced_reward_system.get_integration_statistics()
|
||||
```
|
||||
|
||||
Returns:
|
||||
- System running status
|
||||
- Total predictions tracked
|
||||
- Component status
|
||||
- Inference and training statistics
|
||||
- Performance metrics
|
||||
|
||||
### Model Accuracy
|
||||
|
||||
```python
|
||||
accuracy = enhanced_reward_system.get_model_accuracy()
|
||||
```
|
||||
|
||||
Returns for each symbol and timeframe:
|
||||
- Total predictions made
|
||||
- Direction accuracy percentage
|
||||
- Average MSE
|
||||
- Recent prediction count
|
||||
|
||||
### Real-Time Monitoring
|
||||
|
||||
The system provides comprehensive logging at different levels:
|
||||
- `INFO`: Major system events, training results
|
||||
- `DEBUG`: Detailed prediction tracking, reward calculations
|
||||
- `ERROR`: System errors and recovery actions
|
||||
|
||||
## Backward Compatibility
|
||||
|
||||
The enhanced reward system is designed to be **fully backward compatible**:
|
||||
|
||||
✅ **Existing models continue to work** without modification
|
||||
✅ **Existing training systems** remain functional
|
||||
✅ **Existing reward calculations** can run in parallel
|
||||
✅ **Gradual migration** - enable for specific models incrementally
|
||||
|
||||
## Testing and Validation
|
||||
|
||||
### Force Evaluation for Testing
|
||||
|
||||
```python
|
||||
# Force immediate evaluation of all predictions
|
||||
enhanced_reward_system.force_evaluation_and_training()
|
||||
|
||||
# Force evaluation for specific symbol/timeframe
|
||||
enhanced_reward_system.force_evaluation_and_training('ETH/USDT', '1s')
|
||||
```
|
||||
|
||||
### Manual Prediction Addition
|
||||
|
||||
```python
|
||||
# Add predictions manually for testing
|
||||
prediction_id = enhanced_reward_system.add_prediction_manually(
|
||||
symbol='ETH/USDT',
|
||||
timeframe_str='1s',
|
||||
predicted_price=3150.50,
|
||||
predicted_direction=1,
|
||||
confidence=0.85,
|
||||
current_price=3150.00,
|
||||
model_name='test_model'
|
||||
)
|
||||
```
|
||||
|
||||
## Memory Management
|
||||
|
||||
The system includes automatic memory management:
|
||||
|
||||
- **Automatic prediction cleanup** (configurable retention period)
|
||||
- **Circular buffers** for prediction history (max 100 per timeframe)
|
||||
- **Price cache management** (max 1000 price points per symbol)
|
||||
- **Efficient storage** using deques and compressed data structures
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
The architecture supports easy extension for:
|
||||
|
||||
1. **Additional timeframes** (30s, 5m, 15m, etc.)
|
||||
2. **Custom reward functions** (Sharpe ratio, maximum drawdown, etc.)
|
||||
3. **Multi-symbol correlation** rewards
|
||||
4. **Advanced statistical metrics** (Sortino ratio, Calmar ratio)
|
||||
5. **Model ensemble** reward aggregation
|
||||
6. **A/B testing** framework for reward functions
|
||||
|
||||
## Conclusion
|
||||
|
||||
The Enhanced Reward System provides a comprehensive foundation for improving RL model training through:
|
||||
|
||||
- **Precise MSE-based rewards** that accurately measure prediction quality
|
||||
- **Multi-timeframe intelligence** that prevents confusion between different prediction horizons
|
||||
- **Real-time learning** that maximizes training opportunities
|
||||
- **Easy integration** that requires minimal changes to existing code
|
||||
- **Comprehensive monitoring** that provides insights into model performance
|
||||
|
||||
This system addresses the specific requirements you outlined:
|
||||
✅ MSE-based accuracy calculation
|
||||
✅ Training at each inference using last prediction vs. current outcome
|
||||
✅ Separate accuracy tracking for up to 6 last predictions per timeframe
|
||||
✅ Models know which timeframe they're predicting on
|
||||
✅ Hourly multi-timeframe inference (4 predictions per hour)
|
||||
✅ Integration with existing 1-5 second inference frequency
|
||||
|
265
examples/enhanced_reward_system_example.py
Normal file
265
examples/enhanced_reward_system_example.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""
|
||||
Enhanced Reward System Integration Example
|
||||
|
||||
This example demonstrates how to integrate the new MSE-based reward system
|
||||
with the existing trading orchestrator and models.
|
||||
|
||||
Usage:
|
||||
python examples/enhanced_reward_system_example.py
|
||||
|
||||
This example shows:
|
||||
1. How to integrate the enhanced reward system with TradingOrchestrator
|
||||
2. How to add predictions from existing models
|
||||
3. How to monitor accuracy and training statistics
|
||||
4. How the system handles multi-timeframe predictions and training
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
# Import the integration components
|
||||
from core.enhanced_reward_system_integration import (
|
||||
integrate_enhanced_rewards,
|
||||
start_enhanced_rewards_for_orchestrator,
|
||||
add_prediction_to_enhanced_rewards
|
||||
)
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def demonstrate_enhanced_reward_integration():
|
||||
"""Demonstrate the enhanced reward system integration"""
|
||||
|
||||
print("=" * 80)
|
||||
print("ENHANCED REWARD SYSTEM INTEGRATION DEMONSTRATION")
|
||||
print("=" * 80)
|
||||
|
||||
# Note: This is a demonstration - in real usage, you would use your actual orchestrator
|
||||
# For this example, we'll create a mock orchestrator
|
||||
|
||||
print("\n1. Setting up mock orchestrator...")
|
||||
mock_orchestrator = create_mock_orchestrator()
|
||||
|
||||
print("\n2. Integrating enhanced reward system...")
|
||||
# This is the main integration step - just one line!
|
||||
enhanced_rewards = integrate_enhanced_rewards(mock_orchestrator, ['ETH/USDT', 'BTC/USDT'])
|
||||
|
||||
print("\n3. Starting enhanced reward system...")
|
||||
await start_enhanced_rewards_for_orchestrator(mock_orchestrator)
|
||||
|
||||
print("\n4. System is now running with enhanced rewards!")
|
||||
print(" - CNN predictions every 10 seconds (current rate)")
|
||||
print(" - Continuous inference every 5 seconds")
|
||||
print(" - Hourly multi-timeframe inference (4 predictions per hour)")
|
||||
print(" - Real-time MSE-based reward calculation")
|
||||
print(" - Automatic training when predictions are evaluated")
|
||||
|
||||
# Demonstrate adding predictions from existing models
|
||||
await demonstrate_prediction_tracking(mock_orchestrator)
|
||||
|
||||
# Demonstrate monitoring and statistics
|
||||
await demonstrate_monitoring(mock_orchestrator)
|
||||
|
||||
# Demonstrate force evaluation for testing
|
||||
await demonstrate_force_evaluation(mock_orchestrator)
|
||||
|
||||
print("\n8. Stopping enhanced reward system...")
|
||||
await mock_orchestrator.enhanced_reward_system.stop_integration()
|
||||
|
||||
print("\n✅ Enhanced Reward System demonstration completed successfully!")
|
||||
print("\nTo integrate with your actual system:")
|
||||
print("1. Add these imports to your orchestrator file")
|
||||
print("2. Call integrate_enhanced_rewards(your_orchestrator) in __init__")
|
||||
print("3. Call await start_enhanced_rewards_for_orchestrator(your_orchestrator) in run()")
|
||||
print("4. Use add_prediction_to_enhanced_rewards() in your model inference code")
|
||||
|
||||
|
||||
async def demonstrate_prediction_tracking(orchestrator):
|
||||
"""Demonstrate how to track predictions from existing models"""
|
||||
|
||||
print("\n5. Demonstrating prediction tracking...")
|
||||
|
||||
# Simulate predictions from different models and timeframes
|
||||
predictions = [
|
||||
# CNN predictions for multiple timeframes
|
||||
('ETH/USDT', '1s', 3150.50, 1, 0.85, 3150.00, 'enhanced_cnn'),
|
||||
('ETH/USDT', '1m', 3155.00, 1, 0.78, 3150.00, 'enhanced_cnn'),
|
||||
('ETH/USDT', '1h', 3200.00, 1, 0.72, 3150.00, 'enhanced_cnn'),
|
||||
('ETH/USDT', '1d', 3300.00, 1, 0.65, 3150.00, 'enhanced_cnn'),
|
||||
|
||||
# DQN predictions
|
||||
('ETH/USDT', '1s', 3149.00, -1, 0.70, 3150.00, 'dqn_agent'),
|
||||
('BTC/USDT', '1s', 51200.00, 1, 0.75, 51150.00, 'dqn_agent'),
|
||||
|
||||
# COB RL predictions
|
||||
('ETH/USDT', '1s', 3151.20, 1, 0.88, 3150.00, 'cob_rl'),
|
||||
('BTC/USDT', '1s', 51180.00, 1, 0.82, 51150.00, 'cob_rl'),
|
||||
]
|
||||
|
||||
prediction_ids = []
|
||||
for symbol, timeframe, pred_price, direction, confidence, curr_price, model in predictions:
|
||||
prediction_id = add_prediction_to_enhanced_rewards(
|
||||
orchestrator, symbol, timeframe, pred_price, direction, confidence, curr_price, model
|
||||
)
|
||||
prediction_ids.append(prediction_id)
|
||||
print(f" ✓ Added prediction: {model} predicts {symbol} {timeframe} "
|
||||
f"direction={direction} confidence={confidence:.2f}")
|
||||
|
||||
print(f" 📊 Total predictions added: {len(prediction_ids)}")
|
||||
|
||||
|
||||
async def demonstrate_monitoring(orchestrator):
|
||||
"""Demonstrate monitoring and statistics"""
|
||||
|
||||
print("\n6. Demonstrating monitoring and statistics...")
|
||||
|
||||
# Wait a bit for some processing
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# Get integration statistics
|
||||
stats = orchestrator.enhanced_reward_system.get_integration_statistics()
|
||||
|
||||
print(" 📈 Integration Statistics:")
|
||||
print(f" - System running: {stats.get('is_running', False)}")
|
||||
print(f" - Start time: {stats.get('start_time', 'N/A')}")
|
||||
print(f" - Predictions tracked: {stats.get('total_predictions_tracked', 0)}")
|
||||
|
||||
# Get accuracy summary
|
||||
accuracy = orchestrator.enhanced_reward_system.get_model_accuracy()
|
||||
print("\n 🎯 Accuracy Summary by Symbol and Timeframe:")
|
||||
for symbol, timeframes in accuracy.items():
|
||||
print(f" - {symbol}:")
|
||||
for timeframe, metrics in timeframes.items():
|
||||
print(f" - {timeframe}: {metrics['total_predictions']} predictions, "
|
||||
f"{metrics['direction_accuracy']:.1f}% accuracy")
|
||||
|
||||
|
||||
async def demonstrate_force_evaluation(orchestrator):
|
||||
"""Demonstrate force evaluation for testing"""
|
||||
|
||||
print("\n7. Demonstrating force evaluation for testing...")
|
||||
|
||||
# Simulate some price changes by updating prices
|
||||
print(" 💰 Simulating price changes...")
|
||||
orchestrator.enhanced_reward_system.reward_calculator.update_price('ETH/USDT', 3152.50)
|
||||
orchestrator.enhanced_reward_system.reward_calculator.update_price('BTC/USDT', 51175.00)
|
||||
|
||||
# Force evaluation of all predictions
|
||||
print(" ⚡ Force evaluating all predictions...")
|
||||
orchestrator.enhanced_reward_system.force_evaluation_and_training()
|
||||
|
||||
# Get updated statistics
|
||||
await asyncio.sleep(1)
|
||||
stats = orchestrator.enhanced_reward_system.get_integration_statistics()
|
||||
|
||||
print(" 📊 Updated statistics after evaluation:")
|
||||
accuracy = orchestrator.enhanced_reward_system.get_model_accuracy()
|
||||
total_evaluated = sum(
|
||||
sum(tf_data['total_predictions'] for tf_data in symbol_data.values())
|
||||
for symbol_data in accuracy.values()
|
||||
)
|
||||
print(f" - Total predictions evaluated: {total_evaluated}")
|
||||
|
||||
|
||||
def create_mock_orchestrator():
|
||||
"""Create a mock orchestrator for demonstration purposes"""
|
||||
|
||||
class MockDataProvider:
|
||||
def __init__(self):
|
||||
self.current_prices = {
|
||||
'ETH/USDT': 3150.00,
|
||||
'BTC/USDT': 51150.00
|
||||
}
|
||||
|
||||
class MockOrchestrator:
|
||||
def __init__(self):
|
||||
self.data_provider = MockDataProvider()
|
||||
# Add other mock attributes as needed
|
||||
|
||||
return MockOrchestrator()
|
||||
|
||||
|
||||
def show_integration_instructions():
|
||||
"""Show step-by-step integration instructions"""
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("INTEGRATION INSTRUCTIONS FOR YOUR ACTUAL SYSTEM")
|
||||
print("=" * 80)
|
||||
|
||||
print("""
|
||||
To integrate the enhanced reward system with your actual TradingOrchestrator:
|
||||
|
||||
1. ADD IMPORTS to your orchestrator.py:
|
||||
```python
|
||||
from core.enhanced_reward_system_integration import (
|
||||
integrate_enhanced_rewards,
|
||||
add_prediction_to_enhanced_rewards
|
||||
)
|
||||
```
|
||||
|
||||
2. INTEGRATE in TradingOrchestrator.__init__():
|
||||
```python
|
||||
# Add this line in your __init__ method
|
||||
integrate_enhanced_rewards(self, symbols=['ETH/USDT', 'BTC/USDT'])
|
||||
```
|
||||
|
||||
3. START in TradingOrchestrator.run():
|
||||
```python
|
||||
# Add this line in your run() method, after initialization
|
||||
await self.enhanced_reward_system.start_integration()
|
||||
```
|
||||
|
||||
4. ADD PREDICTIONS in your model inference code:
|
||||
```python
|
||||
# In your CNN/DQN/COB model inference methods, add:
|
||||
prediction_id = add_prediction_to_enhanced_rewards(
|
||||
self, # orchestrator instance
|
||||
symbol, # e.g., 'ETH/USDT'
|
||||
timeframe, # e.g., '1s', '1m', '1h', '1d'
|
||||
predicted_price, # model's price prediction
|
||||
direction, # -1 (down), 0 (neutral), 1 (up)
|
||||
confidence, # 0.0 to 1.0
|
||||
current_price, # current market price
|
||||
model_name # e.g., 'enhanced_cnn', 'dqn_agent'
|
||||
)
|
||||
```
|
||||
|
||||
5. MONITOR with:
|
||||
```python
|
||||
# Get statistics anytime
|
||||
stats = self.enhanced_reward_system.get_integration_statistics()
|
||||
accuracy = self.enhanced_reward_system.get_model_accuracy()
|
||||
```
|
||||
|
||||
The system will automatically:
|
||||
- Track predictions for multiple timeframes separately
|
||||
- Calculate MSE-based rewards when outcomes are available
|
||||
- Trigger real-time training with enhanced rewards
|
||||
- Maintain accuracy statistics for each model and timeframe
|
||||
- Handle hourly multi-timeframe inference scheduling
|
||||
|
||||
Key Benefits:
|
||||
✅ MSE-based accuracy measurement (better than simple directional accuracy)
|
||||
✅ Separate tracking for up to 6 last predictions per timeframe
|
||||
✅ Real-time training at each inference when outcomes available
|
||||
✅ Multi-timeframe prediction support (1s, 1m, 1h, 1d)
|
||||
✅ Hourly inference on all timeframes (4 predictions per hour)
|
||||
✅ Models know which timeframe they're predicting on
|
||||
✅ Backward compatible with existing code
|
||||
""")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the demonstration
|
||||
asyncio.run(demonstrate_enhanced_reward_integration())
|
||||
|
||||
# Show integration instructions
|
||||
show_integration_instructions()
|
||||
|
Reference in New Issue
Block a user