enhanced training and reward - wip
481
core/enhanced_reward_calculator.py
Normal file
@@ -0,0 +1,481 @@
"""
Enhanced Reward Calculator for Reinforcement Learning Training

This module implements a comprehensive reward calculation system based on mean squared error
between predictions and empirical outcomes. It tracks multiple timeframes separately and
maintains prediction history for accurate reward computation.

Key Features:
- MSE-based reward calculation for prediction accuracy
- Multi-timeframe prediction tracking (1s, 1m, 1h, 1d)
- Separate accuracy tracking for each timeframe
- Prediction history tracking (last 6 predictions per timeframe)
- Real-time training at each inference
- Timeframe-aware inference scheduling
"""

import time
import logging
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Any
from collections import deque
from datetime import datetime, timedelta
import numpy as np
import threading
from enum import Enum

logger = logging.getLogger(__name__)


class TimeFrame(Enum):
    """Supported timeframes for prediction"""
    SECONDS_1 = "1s"
    MINUTES_1 = "1m"
    HOURS_1 = "1h"
    DAYS_1 = "1d"


@dataclass
class PredictionRecord:
    """Individual prediction record with outcome tracking"""
    timestamp: datetime
    symbol: str
    timeframe: TimeFrame
    predicted_price: float
    predicted_direction: int  # -1: down, 0: neutral, 1: up
    confidence: float
    current_price: float
    model_name: str

    # Outcome fields (set when outcome is determined)
    actual_price: Optional[float] = None
    actual_direction: Optional[int] = None
    outcome_timestamp: Optional[datetime] = None
    mse_reward: Optional[float] = None
    direction_correct: Optional[bool] = None
    is_evaluated: bool = False


@dataclass
class TimeframeAccuracy:
    """Accuracy tracking for a specific timeframe"""
    timeframe: TimeFrame
    total_predictions: int = 0
    correct_directions: int = 0
    total_mse: float = 0.0
    prediction_history: deque = field(default_factory=lambda: deque(maxlen=6))

    @property
    def direction_accuracy(self) -> float:
        """Calculate directional accuracy percentage"""
        if self.total_predictions == 0:
            return 0.0
        return (self.correct_directions / self.total_predictions) * 100.0

    @property
    def average_mse(self) -> float:
        """Calculate average MSE"""
        if self.total_predictions == 0:
            return 0.0
        return self.total_mse / self.total_predictions


class EnhancedRewardCalculator:
    """
    Enhanced reward calculator using MSE and multi-timeframe tracking

    This calculator:
    1. Tracks predictions for multiple timeframes separately
    2. Calculates MSE-based rewards when outcomes are available
    3. Maintains prediction history for last 6 predictions per timeframe
    4. Provides separate accuracy metrics for each timeframe
    5. Enables real-time training at each inference
    """

    def __init__(self, symbols: List[str] = None):
        """Initialize the enhanced reward calculator"""
        self.symbols = symbols or ['ETH/USDT', 'BTC/USDT']
        self.timeframes = [TimeFrame.SECONDS_1, TimeFrame.MINUTES_1, TimeFrame.HOURS_1, TimeFrame.DAYS_1]

        # Prediction storage: symbol -> timeframe -> deque of PredictionRecord
        self.predictions: Dict[str, Dict[TimeFrame, deque]] = {}

        # Accuracy tracking: symbol -> timeframe -> TimeframeAccuracy
        self.accuracy_tracker: Dict[str, Dict[TimeFrame, TimeframeAccuracy]] = {}

        # Evaluation timeouts for each timeframe (in seconds)
        self.evaluation_timeouts = {
            TimeFrame.SECONDS_1: 5,     # Evaluate 1s predictions after 5 seconds
            TimeFrame.MINUTES_1: 60,    # Evaluate 1m predictions after 1 minute
            TimeFrame.HOURS_1: 300,     # Evaluate 1h predictions after 5 minutes
            TimeFrame.DAYS_1: 900       # Evaluate 1d predictions after 15 minutes
        }

        # Price data cache for outcome evaluation
        self.price_cache: Dict[str, List[Tuple[datetime, float]]] = {}
        self.price_cache_max_size = 1000

        # Thread safety
        self.lock = threading.RLock()

        # Initialize data structures
        self._initialize_data_structures()

        logger.info(f"EnhancedRewardCalculator initialized for symbols: {self.symbols}")
        logger.info(f"Timeframes: {[tf.value for tf in self.timeframes]}")
        logger.info(f"Evaluation timeouts: {[(tf.value, timeout) for tf, timeout in self.evaluation_timeouts.items()]}")

    def _initialize_data_structures(self):
        """Initialize nested data structures"""
        for symbol in self.symbols:
            self.predictions[symbol] = {}
            self.accuracy_tracker[symbol] = {}
            self.price_cache[symbol] = []

            for timeframe in self.timeframes:
                self.predictions[symbol][timeframe] = deque(maxlen=100)  # Keep last 100 predictions
                self.accuracy_tracker[symbol][timeframe] = TimeframeAccuracy(timeframe)

    def add_prediction(self,
                       symbol: str,
                       timeframe: TimeFrame,
                       predicted_price: float,
                       predicted_direction: int,
                       confidence: float,
                       current_price: float,
                       model_name: str) -> str:
        """
        Add a new prediction to track

        Args:
            symbol: Trading symbol (e.g., 'ETH/USDT')
            timeframe: Timeframe for this prediction
            predicted_price: Model's predicted price
            predicted_direction: Predicted direction (-1, 0, 1)
            confidence: Model's confidence (0.0 to 1.0)
            current_price: Current market price
            model_name: Name of the model making prediction

        Returns:
            Unique prediction ID for later reference
        """
        with self.lock:
            prediction = PredictionRecord(
                timestamp=datetime.now(),
                symbol=symbol,
                timeframe=timeframe,
                predicted_price=predicted_price,
                predicted_direction=predicted_direction,
                confidence=confidence,
                current_price=current_price,
                model_name=model_name
            )

            # Store prediction
            if symbol not in self.predictions:
                self._initialize_data_structures()

            self.predictions[symbol][timeframe].append(prediction)

            # Add to accuracy tracker history
            self.accuracy_tracker[symbol][timeframe].prediction_history.append(prediction)

            prediction_id = f"{symbol}_{timeframe.value}_{prediction.timestamp.isoformat()}_{model_name}"

            logger.debug(f"Added prediction: {prediction_id}, predicted_price={predicted_price:.4f}, "
                         f"direction={predicted_direction}, confidence={confidence:.3f}")

            return prediction_id

    def update_price(self, symbol: str, price: float, timestamp: datetime = None):
        """
        Update current price for a symbol

        Args:
            symbol: Trading symbol
            price: Current price
            timestamp: Price timestamp (defaults to now)
        """
        if timestamp is None:
            timestamp = datetime.now()

        with self.lock:
            if symbol not in self.price_cache:
                self.price_cache[symbol] = []

            self.price_cache[symbol].append((timestamp, price))

            # Maintain cache size
            if len(self.price_cache[symbol]) > self.price_cache_max_size:
                self.price_cache[symbol] = self.price_cache[symbol][-self.price_cache_max_size:]

    def evaluate_predictions(self, symbol: str = None) -> Dict[str, List[Tuple[PredictionRecord, float]]]:
        """
        Evaluate pending predictions and calculate rewards

        Args:
            symbol: Specific symbol to evaluate (None for all symbols)

        Returns:
            Dictionary mapping symbol to list of (prediction, reward) tuples
        """
        results = {}
        symbols_to_evaluate = [symbol] if symbol else self.symbols

        with self.lock:
            for sym in symbols_to_evaluate:
                if sym not in self.predictions:
                    continue

                results[sym] = []
                current_time = datetime.now()

                for timeframe in self.timeframes:
                    predictions_to_evaluate = []

                    # Find predictions ready for evaluation
                    for prediction in self.predictions[sym][timeframe]:
                        if prediction.is_evaluated:
                            continue

                        time_elapsed = (current_time - prediction.timestamp).total_seconds()
                        timeout = self.evaluation_timeouts[timeframe]

                        if time_elapsed >= timeout:
                            predictions_to_evaluate.append(prediction)

                    # Evaluate predictions
                    for prediction in predictions_to_evaluate:
                        reward = self._calculate_prediction_reward(prediction)
                        if reward is not None:
                            results[sym].append((prediction, reward))

                            # Update accuracy tracking
                            self._update_accuracy_tracking(sym, timeframe, prediction)

        return results

    def _calculate_prediction_reward(self, prediction: PredictionRecord) -> Optional[float]:
        """
        Calculate MSE-based reward for a prediction

        Args:
            prediction: Prediction record to evaluate

        Returns:
            Calculated reward or None if outcome cannot be determined
        """
        # Get actual price at evaluation time
        actual_price = self._get_price_at_time(
            prediction.symbol,
            prediction.timestamp + timedelta(seconds=self.evaluation_timeouts[prediction.timeframe])
        )

        if actual_price is None:
            logger.debug(f"Cannot evaluate prediction - no price data available for {prediction.symbol}")
            return None

        # Calculate price change and direction
        price_change = actual_price - prediction.current_price
        actual_direction = 1 if price_change > 0 else (-1 if price_change < 0 else 0)

        # Calculate MSE reward
        price_error = actual_price - prediction.predicted_price
        mse = price_error ** 2

        # Normalize MSE to a reasonable reward scale (lower MSE = higher reward)
        # Use exponential decay to heavily penalize large errors
        max_mse = (prediction.current_price * 0.1) ** 2  # 10% price change as max expected error
        normalized_mse = min(mse / max_mse, 1.0)
        mse_reward = np.exp(-5 * normalized_mse)  # Exponential decay, range [exp(-5), 1]

        # Direction accuracy bonus/penalty
        direction_correct = (prediction.predicted_direction == actual_direction)
        direction_bonus = 0.5 if direction_correct else -0.5

        # Confidence scaling
        confidence_weight = prediction.confidence

        # Final reward calculation
        base_reward = mse_reward + direction_bonus
        final_reward = base_reward * confidence_weight

        # Update prediction record
        prediction.actual_price = actual_price
        prediction.actual_direction = actual_direction
        prediction.outcome_timestamp = datetime.now()
        prediction.mse_reward = final_reward
        prediction.direction_correct = direction_correct
        prediction.is_evaluated = True

        logger.debug(f"Evaluated prediction: {prediction.symbol} {prediction.timeframe.value}, "
                     f"MSE={mse:.6f}, direction_correct={direction_correct}, "
                     f"confidence={confidence_weight:.3f}, reward={final_reward:.4f}")

        return final_reward

    def _get_price_at_time(self, symbol: str, target_time: datetime) -> Optional[float]:
        """
        Get price for symbol at a specific time

        Args:
            symbol: Trading symbol
            target_time: Target timestamp

        Returns:
            Price at target time or None if not available
        """
        if symbol not in self.price_cache or not self.price_cache[symbol]:
            return None

        # Find closest price to target time
        closest_price = None
        min_time_diff = float('inf')

        for timestamp, price in self.price_cache[symbol]:
            time_diff = abs((timestamp - target_time).total_seconds())
            if time_diff < min_time_diff:
                min_time_diff = time_diff
                closest_price = price

        # Only return price if it's within reasonable time window (30 seconds)
        if min_time_diff <= 30:
            return closest_price

        return None

    def _update_accuracy_tracking(self, symbol: str, timeframe: TimeFrame, prediction: PredictionRecord):
        """Update accuracy tracking for a timeframe"""
        tracker = self.accuracy_tracker[symbol][timeframe]

        tracker.total_predictions += 1
        if prediction.direction_correct:
            tracker.correct_directions += 1

        if prediction.mse_reward is not None:
            # Convert reward back to MSE for tracking
            # Since reward = exp(-5 * normalized_mse), we can reverse it
            normalized_mse = -np.log(max(prediction.mse_reward, 0.001)) / 5
            max_mse = (prediction.current_price * 0.1) ** 2
            mse = normalized_mse * max_mse
            tracker.total_mse += mse

    def get_accuracy_summary(self, symbol: str = None) -> Dict[str, Dict[str, Dict[str, float]]]:
        """
        Get accuracy summary for all or specific symbol

        Args:
            symbol: Specific symbol (None for all)

        Returns:
            Nested dictionary with accuracy metrics
        """
        summary = {}
        symbols_to_summarize = [symbol] if symbol else self.symbols

        with self.lock:
            for sym in symbols_to_summarize:
                if sym not in self.accuracy_tracker:
                    continue

                summary[sym] = {}

                for timeframe in self.timeframes:
                    tracker = self.accuracy_tracker[sym][timeframe]
                    summary[sym][timeframe.value] = {
                        'total_predictions': tracker.total_predictions,
                        'direction_accuracy': tracker.direction_accuracy,
                        'average_mse': tracker.average_mse,
                        'recent_predictions': len(tracker.prediction_history)
                    }

        return summary

    def get_training_data(self, symbol: str, timeframe: TimeFrame,
                          max_samples: int = 50) -> List[Tuple[PredictionRecord, float]]:
        """
        Get recent evaluated predictions for training

        Args:
            symbol: Trading symbol
            timeframe: Specific timeframe
            max_samples: Maximum number of samples to return

        Returns:
            List of (prediction, reward) tuples ready for training
        """
        training_data = []

        with self.lock:
            if symbol not in self.predictions or timeframe not in self.predictions[symbol]:
                return training_data

            evaluated_predictions = [
                p for p in self.predictions[symbol][timeframe]
                if p.is_evaluated and p.mse_reward is not None
            ]

            # Get most recent evaluated predictions
            recent_predictions = list(evaluated_predictions)[-max_samples:]

            for prediction in recent_predictions:
                training_data.append((prediction, prediction.mse_reward))

        return training_data

    def cleanup_old_predictions(self, days_to_keep: int = 7):
        """
        Clean up old predictions to manage memory

        Args:
            days_to_keep: Number of days of predictions to keep
        """
        cutoff_time = datetime.now() - timedelta(days=days_to_keep)

        with self.lock:
            for symbol in self.predictions:
                for timeframe in self.timeframes:
                    # Filter out old predictions
                    old_count = len(self.predictions[symbol][timeframe])

                    self.predictions[symbol][timeframe] = deque(
                        [p for p in self.predictions[symbol][timeframe]
                         if p.timestamp > cutoff_time],
                        maxlen=100
                    )

                    new_count = len(self.predictions[symbol][timeframe])
                    removed_count = old_count - new_count

                    if removed_count > 0:
                        logger.info(f"Cleaned up {removed_count} old predictions for "
                                    f"{symbol} {timeframe.value}")

    def force_evaluate_timeframe_predictions(self, symbol: str, timeframe: TimeFrame) -> List[Tuple[PredictionRecord, float]]:
        """
        Force evaluation of all pending predictions for a specific timeframe
        Useful for immediate training needs

        Args:
            symbol: Trading symbol
            timeframe: Specific timeframe to evaluate

        Returns:
            List of (prediction, reward) tuples
        """
        results = []

        with self.lock:
            if symbol not in self.predictions or timeframe not in self.predictions[symbol]:
                return results

            # Evaluate all non-evaluated predictions
            for prediction in self.predictions[symbol][timeframe]:
                if not prediction.is_evaluated:
                    reward = self._calculate_prediction_reward(prediction)
                    if reward is not None:
                        results.append((prediction, reward))
                        self._update_accuracy_tracking(symbol, timeframe, prediction)

        return results
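For reference, a minimal sketch (not part of the commit) of how this calculator is meant to be driven end-to-end, assuming the module is importable as core.enhanced_reward_calculator and a live price feed supplies the second call. The prices are invented; the arithmetic in the comment just applies the reward formula above.

# Illustrative usage of EnhancedRewardCalculator. With current_price=3000,
# predicted_price=3010, actual_price=3006: mse = 16, max_mse = (0.1*3000)**2 = 90000,
# normalized_mse ~ 0.000178, mse_reward = exp(-5*0.000178) ~ 0.9991,
# direction correct -> +0.5, confidence 0.6 -> reward ~ 0.90.
import time
from core.enhanced_reward_calculator import EnhancedRewardCalculator, TimeFrame

calc = EnhancedRewardCalculator(symbols=['ETH/USDT'])
pred_id = calc.add_prediction(
    symbol='ETH/USDT', timeframe=TimeFrame.SECONDS_1,
    predicted_price=3010.0, predicted_direction=1,
    confidence=0.6, current_price=3000.0, model_name='dqn_agent')

calc.update_price('ETH/USDT', 3006.0)   # price observed from the feed
time.sleep(5)                           # wait past the 1s evaluation timeout
for record, reward in calc.evaluate_predictions('ETH/USDT').get('ETH/USDT', []):
    print(record.model_name, record.timeframe.value, round(reward, 4))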
337
core/enhanced_reward_system_integration.py
Normal file
@@ -0,0 +1,337 @@
"""
Enhanced Reward System Integration

This module provides a simple integration point for the new MSE-based reward system
with the existing trading orchestrator and training infrastructure.

Key Features:
- Easy integration with existing TradingOrchestrator
- Minimal changes required to existing code
- Backward compatibility maintained
- Enhanced performance monitoring
- Real-time training with MSE rewards
"""

import asyncio
import logging
from typing import Optional, Dict, Any
from datetime import datetime

from core.enhanced_reward_calculator import EnhancedRewardCalculator, TimeFrame
from core.timeframe_inference_coordinator import TimeframeInferenceCoordinator
from core.enhanced_rl_training_adapter import EnhancedRLTrainingAdapter

logger = logging.getLogger(__name__)


class EnhancedRewardSystemIntegration:
    """
    Complete integration of the enhanced reward system

    This class provides a single integration point that can be easily added
    to the existing TradingOrchestrator to enable MSE-based rewards and
    multi-timeframe training.
    """

    def __init__(self, orchestrator: Any, symbols: list = None):
        """
        Initialize the enhanced reward system integration

        Args:
            orchestrator: TradingOrchestrator instance
            symbols: List of symbols to track (defaults to ETH/USDT, BTC/USDT)
        """
        self.orchestrator = orchestrator
        self.symbols = symbols or ['ETH/USDT', 'BTC/USDT']

        # Initialize core components
        self.reward_calculator = EnhancedRewardCalculator(symbols=self.symbols)

        self.inference_coordinator = TimeframeInferenceCoordinator(
            reward_calculator=self.reward_calculator,
            data_provider=getattr(orchestrator, 'data_provider', None),
            symbols=self.symbols
        )

        self.training_adapter = EnhancedRLTrainingAdapter(
            reward_calculator=self.reward_calculator,
            inference_coordinator=self.inference_coordinator,
            orchestrator=orchestrator,
            training_system=getattr(orchestrator, 'enhanced_training_system', None)
        )

        # Integration state
        self.is_running = False
        self.integration_stats = {
            'start_time': None,
            'total_predictions_tracked': 0,
            'total_rewards_calculated': 0,
            'total_training_batches': 0
        }

        logger.info(f"EnhancedRewardSystemIntegration initialized for symbols: {self.symbols}")

    async def start_integration(self):
        """Start the enhanced reward system integration"""
        if self.is_running:
            logger.warning("Enhanced reward system already running")
            return

        try:
            logger.info("Starting Enhanced Reward System Integration")

            # Start core components
            await self.inference_coordinator.start_coordination()
            await self.training_adapter.start_training_loop()

            # Start price monitoring
            asyncio.create_task(self._price_monitoring_loop())

            self.is_running = True
            self.integration_stats['start_time'] = datetime.now().isoformat()

            logger.info("Enhanced Reward System Integration started successfully")

        except Exception as e:
            logger.error(f"Error starting enhanced reward system integration: {e}")
            await self.stop_integration()

    async def stop_integration(self):
        """Stop the enhanced reward system integration"""
        if not self.is_running:
            return

        try:
            logger.info("Stopping Enhanced Reward System Integration")

            # Stop components
            await self.inference_coordinator.stop_coordination()
            await self.training_adapter.stop_training_loop()

            self.is_running = False

            logger.info("Enhanced Reward System Integration stopped")

        except Exception as e:
            logger.error(f"Error stopping enhanced reward system integration: {e}")

    async def _price_monitoring_loop(self):
        """Monitor prices and update the reward calculator"""
        while self.is_running:
            try:
                # Update current prices for all symbols
                for symbol in self.symbols:
                    current_price = await self._get_current_price(symbol)
                    if current_price > 0:
                        self.reward_calculator.update_price(symbol, current_price)

                # Sleep for 1 second between updates
                await asyncio.sleep(1.0)

            except Exception as e:
                logger.debug(f"Error in price monitoring loop: {e}")
                await asyncio.sleep(5.0)  # Wait longer on error

    async def _get_current_price(self, symbol: str) -> float:
        """Get current price for a symbol"""
        try:
            if hasattr(self.orchestrator, 'data_provider'):
                current_prices = self.orchestrator.data_provider.current_prices
                return current_prices.get(symbol, 0.0)
        except Exception as e:
            logger.debug(f"Error getting current price for {symbol}: {e}")

        return 0.0

    def add_prediction_manually(self, symbol: str, timeframe_str: str,
                                predicted_price: float, predicted_direction: int,
                                confidence: float, current_price: float,
                                model_name: str) -> str:
        """
        Manually add a prediction to the reward calculator

        This method allows existing code to easily integrate with the new reward system
        without major changes.

        Args:
            symbol: Trading symbol (e.g., 'ETH/USDT')
            timeframe_str: Timeframe string ('1s', '1m', '1h', '1d')
            predicted_price: Model's predicted price
            predicted_direction: Predicted direction (-1, 0, 1)
            confidence: Model's confidence (0.0 to 1.0)
            current_price: Current market price
            model_name: Name of the model making prediction

        Returns:
            Unique prediction ID
        """
        try:
            # Convert timeframe string to enum
            timeframe = TimeFrame(timeframe_str)

            prediction_id = self.reward_calculator.add_prediction(
                symbol=symbol,
                timeframe=timeframe,
                predicted_price=predicted_price,
                predicted_direction=predicted_direction,
                confidence=confidence,
                current_price=current_price,
                model_name=model_name
            )

            self.integration_stats['total_predictions_tracked'] += 1

            return prediction_id

        except Exception as e:
            logger.error(f"Error adding prediction manually: {e}")
            return ""

    def get_model_accuracy(self, model_name: str = None, symbol: str = None) -> Dict[str, Any]:
        """
        Get accuracy statistics for models

        Args:
            model_name: Specific model name (None for all)
            symbol: Specific symbol (None for all)

        Returns:
            Dictionary with accuracy statistics
        """
        try:
            accuracy_summary = self.reward_calculator.get_accuracy_summary(symbol)

            if model_name:
                # Filter by model name in prediction history
                # This would require enhancing the reward calculator to track by model
                pass

            return accuracy_summary

        except Exception as e:
            logger.error(f"Error getting model accuracy: {e}")
            return {}

    def force_evaluation_and_training(self, symbol: str = None, timeframe_str: str = None):
        """
        Force immediate evaluation and training for debugging/testing

        Args:
            symbol: Specific symbol (None for all)
            timeframe_str: Specific timeframe (None for all)
        """
        try:
            if timeframe_str:
                timeframe = TimeFrame(timeframe_str)
                symbols_to_process = [symbol] if symbol else self.symbols

                for sym in symbols_to_process:
                    # Force evaluation of predictions
                    results = self.reward_calculator.force_evaluate_timeframe_predictions(sym, timeframe)
                    logger.info(f"Force evaluated {len(results)} predictions for {sym} {timeframe.value}")
            else:
                # Evaluate all pending predictions
                for sym in (self.symbols if not symbol else [symbol]):
                    results = self.reward_calculator.evaluate_predictions(sym)
                    if sym in results:
                        logger.info(f"Force evaluated {len(results[sym])} predictions for {sym}")

        except Exception as e:
            logger.error(f"Error in force evaluation and training: {e}")

    def get_integration_statistics(self) -> Dict[str, Any]:
        """Get comprehensive integration statistics"""
        try:
            stats = self.integration_stats.copy()

            # Add component statistics
            stats['inference_coordinator'] = self.inference_coordinator.get_inference_statistics()
            stats['training_adapter'] = self.training_adapter.get_training_statistics()
            stats['reward_calculator'] = self.reward_calculator.get_accuracy_summary()

            # Add system status
            stats['is_running'] = self.is_running
            stats['components_running'] = {
                'inference_coordinator': self.inference_coordinator.running,
                'training_adapter': self.training_adapter.running
            }

            return stats

        except Exception as e:
            logger.error(f"Error getting integration statistics: {e}")
            return {'error': str(e)}

    def cleanup_old_data(self, days_to_keep: int = 7):
        """Clean up old prediction data to manage memory"""
        try:
            self.reward_calculator.cleanup_old_predictions(days_to_keep)
            logger.info(f"Cleaned up prediction data older than {days_to_keep} days")

        except Exception as e:
            logger.error(f"Error cleaning up old data: {e}")


# Utility functions for easy integration

def integrate_enhanced_rewards(orchestrator: Any, symbols: list = None) -> EnhancedRewardSystemIntegration:
    """
    Utility function to easily integrate enhanced rewards with an existing orchestrator

    Args:
        orchestrator: TradingOrchestrator instance
        symbols: List of symbols to track

    Returns:
        EnhancedRewardSystemIntegration instance
    """
    integration = EnhancedRewardSystemIntegration(orchestrator, symbols)

    # Add integration as an attribute to the orchestrator for easy access
    setattr(orchestrator, 'enhanced_reward_system', integration)

    logger.info("Enhanced reward system integrated with orchestrator")
    return integration


async def start_enhanced_rewards_for_orchestrator(orchestrator: Any, symbols: list = None):
    """
    Start enhanced rewards for an existing orchestrator

    Args:
        orchestrator: TradingOrchestrator instance
        symbols: List of symbols to track
    """
    if not hasattr(orchestrator, 'enhanced_reward_system'):
        integrate_enhanced_rewards(orchestrator, symbols)

    await orchestrator.enhanced_reward_system.start_integration()


def add_prediction_to_enhanced_rewards(orchestrator: Any, symbol: str, timeframe: str,
                                       predicted_price: float, direction: int, confidence: float,
                                       current_price: float, model_name: str) -> str:
    """
    Helper function to add predictions to enhanced rewards from existing code

    Args:
        orchestrator: TradingOrchestrator instance with enhanced_reward_system
        symbol: Trading symbol
        timeframe: Timeframe string
        predicted_price: Predicted price
        direction: Predicted direction (-1, 0, 1)
        confidence: Model confidence
        current_price: Current market price
        model_name: Model name

    Returns:
        Prediction ID
    """
    if hasattr(orchestrator, 'enhanced_reward_system'):
        return orchestrator.enhanced_reward_system.add_prediction_manually(
            symbol, timeframe, predicted_price, direction, confidence, current_price, model_name
        )

    logger.warning("Enhanced reward system not integrated with orchestrator")
    return ""
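A short sketch (not part of the commit) of how these helpers are intended to be called from existing code. The TradingOrchestrator instance and its data_provider are assumed to exist elsewhere in the repo; the concrete prices, sleep time, and symbol list here are invented.

# Illustrative wiring of the integration helpers into an existing orchestrator.
import asyncio
from core.enhanced_reward_system_integration import (
    start_enhanced_rewards_for_orchestrator,
    add_prediction_to_enhanced_rewards,
)

async def run_with_enhanced_rewards(orchestrator):
    # Attaches orchestrator.enhanced_reward_system and starts the background loops
    await start_enhanced_rewards_for_orchestrator(orchestrator, symbols=['ETH/USDT'])

    # Existing inference code can then report a prediction in one call:
    add_prediction_to_enhanced_rewards(
        orchestrator, 'ETH/USDT', '1m',
        predicted_price=3015.0, direction=1, confidence=0.55,
        current_price=3000.0, model_name='enhanced_cnn')

    await asyncio.sleep(120)  # let the price monitor and training loops run
    print(orchestrator.enhanced_reward_system.get_integration_statistics())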
572
core/enhanced_rl_training_adapter.py
Normal file
@@ -0,0 +1,572 @@
"""
Enhanced RL Training Adapter

This module integrates the new MSE-based reward system with existing RL training pipelines.
It provides a bridge between the timeframe-aware inference coordinator and the existing
model training infrastructure.

Key Features:
- Integration with EnhancedRewardCalculator
- Adaptation of existing RL models to new reward system
- Real-time training triggers based on prediction outcomes
- Multi-timeframe training coordination
- Backward compatibility with existing training infrastructure
"""

import asyncio
import logging
import time
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any, Union
from dataclasses import dataclass
import numpy as np
import threading

from core.enhanced_reward_calculator import EnhancedRewardCalculator, TimeFrame, PredictionRecord
from core.timeframe_inference_coordinator import TimeframeInferenceCoordinator, InferenceContext

logger = logging.getLogger(__name__)


@dataclass
class TrainingBatch:
    """Training batch for RL models with enhanced reward data"""
    model_name: str
    symbol: str
    timeframe: TimeFrame
    states: List[np.ndarray]
    actions: List[int]
    rewards: List[float]
    next_states: List[np.ndarray]
    dones: List[bool]
    confidences: List[float]
    prediction_records: List[PredictionRecord]
    batch_timestamp: datetime


class EnhancedRLTrainingAdapter:
    """
    Adapter that integrates new reward system with existing RL training infrastructure

    This adapter:
    1. Bridges new reward calculator with existing RL models
    2. Converts prediction records to RL training format
    3. Triggers real-time training based on reward evaluation
    4. Maintains compatibility with existing training systems
    5. Coordinates multi-timeframe training
    """

    def __init__(self,
                 reward_calculator: EnhancedRewardCalculator,
                 inference_coordinator: TimeframeInferenceCoordinator,
                 orchestrator: Any = None,
                 training_system: Any = None):
        """
        Initialize the enhanced RL training adapter

        Args:
            reward_calculator: Enhanced reward calculator instance
            inference_coordinator: Timeframe inference coordinator
            orchestrator: Trading orchestrator (optional)
            training_system: Enhanced realtime training system (optional)
        """
        self.reward_calculator = reward_calculator
        self.inference_coordinator = inference_coordinator
        self.orchestrator = orchestrator
        self.training_system = training_system

        # Model registry for training functions
        self.model_trainers: Dict[str, Any] = {}

        # Training configuration
        self.min_batch_size = 8      # Minimum samples for training
        self.max_batch_size = 64     # Maximum samples per training batch
        self.training_interval_seconds = 5.0  # How often to check for training opportunities

        # Training statistics
        self.training_stats = {
            'total_training_batches': 0,
            'successful_training_calls': 0,
            'failed_training_calls': 0,
            'last_training_time': None,
            'training_times_per_model': {},
            'average_batch_sizes': {}
        }

        # State conversion helpers
        self.state_builders: Dict[str, Any] = {}

        # Thread safety
        self.lock = threading.RLock()

        # Running state
        self.running = False
        self.training_task: Optional[asyncio.Task] = None

        logger.info("EnhancedRLTrainingAdapter initialized")
        self._register_default_model_handlers()

    def _register_default_model_handlers(self):
        """Register default model handlers for existing models"""
        # Register inference functions with the coordinator
        if self.inference_coordinator:
            self.inference_coordinator.register_model_inference_function(
                'dqn_agent', self._dqn_inference_wrapper
            )
            self.inference_coordinator.register_model_inference_function(
                'cob_rl', self._cob_rl_inference_wrapper
            )
            self.inference_coordinator.register_model_inference_function(
                'enhanced_cnn', self._cnn_inference_wrapper
            )

    async def _dqn_inference_wrapper(self, context: InferenceContext) -> Optional[Dict[str, Any]]:
        """Wrapper for DQN model inference"""
        try:
            if self.orchestrator and hasattr(self.orchestrator, 'rl_agent'):
                # Get base data for the symbol
                base_data = await self._get_base_data(context.symbol)
                if base_data is None:
                    return None

                # Convert to DQN state format
                state = self._convert_to_dqn_state(base_data, context)

                # Run DQN prediction
                if hasattr(self.orchestrator.rl_agent, 'act'):
                    action_idx = self.orchestrator.rl_agent.act(state)
                    confidence = 0.7  # Default confidence for DQN

                    # Convert action to prediction format
                    action_names = ['SELL', 'HOLD', 'BUY']
                    direction = action_idx - 1  # Convert 0,1,2 to -1,0,1

                    current_price = base_data.get('current_price', 0.0)
                    predicted_price = current_price * (1 + (direction * 0.001))  # Small price prediction

                    return {
                        'predicted_price': predicted_price,
                        'current_price': current_price,
                        'direction': direction,
                        'confidence': confidence,
                        'action': action_names[action_idx],
                        'model_state': state,
                        'context': context
                    }
        except Exception as e:
            logger.error(f"Error in DQN inference wrapper: {e}")

        return None

    async def _cob_rl_inference_wrapper(self, context: InferenceContext) -> Optional[Dict[str, Any]]:
        """Wrapper for COB RL model inference"""
        try:
            if (self.orchestrator and
                    hasattr(self.orchestrator, 'realtime_rl_trader') and
                    self.orchestrator.realtime_rl_trader):

                # Get COB features
                features = await self._get_cob_features(context.symbol)
                if features is None:
                    return None

                # Run COB RL prediction
                prediction = self.orchestrator.realtime_rl_trader._predict(context.symbol, features)

                if prediction:
                    current_price = await self._get_current_price(context.symbol)
                    predicted_price = current_price * (1 + prediction.get('change', 0))

                    return {
                        'predicted_price': predicted_price,
                        'current_price': current_price,
                        'direction': prediction.get('direction', 0),
                        'confidence': prediction.get('confidence', 0.0),
                        'change': prediction.get('change', 0.0),
                        'model_features': features,
                        'context': context
                    }
        except Exception as e:
            logger.error(f"Error in COB RL inference wrapper: {e}")

        return None

    async def _cnn_inference_wrapper(self, context: InferenceContext) -> Optional[Dict[str, Any]]:
        """Wrapper for CNN model inference"""
        try:
            if self.orchestrator and hasattr(self.orchestrator, 'model_registry'):
                # Find CNN models in registry
                for model_name, model in self.orchestrator.model_registry.models.items():
                    if 'cnn' in model_name.lower():
                        # Get base data
                        base_data = await self._get_base_data(context.symbol)
                        if base_data is None:
                            continue

                        # Run CNN prediction
                        if hasattr(model, 'predict_from_base_input'):
                            model_output = model.predict_from_base_input(base_data)

                            current_price = base_data.get('current_price', 0.0)

                            # Extract prediction data
                            predictions = model_output.predictions
                            action = predictions.get('action', 'HOLD')
                            confidence = predictions.get('confidence', 0.0)

                            # Convert action to direction
                            direction = {'BUY': 1, 'SELL': -1, 'HOLD': 0}.get(action, 0)
                            predicted_price = current_price * (1 + (direction * 0.002))

                            return {
                                'predicted_price': predicted_price,
                                'current_price': current_price,
                                'direction': direction,
                                'confidence': confidence,
                                'action': action,
                                'model_output': model_output,
                                'context': context
                            }
        except Exception as e:
            logger.error(f"Error in CNN inference wrapper: {e}")

        return None

    async def _get_base_data(self, symbol: str) -> Optional[Any]:
        """Get base data for a symbol"""
        try:
            if self.orchestrator and hasattr(self.orchestrator, 'data_provider'):
                # Use orchestrator's data provider
                return await self.orchestrator._build_base_data(symbol)
        except Exception as e:
            logger.debug(f"Error getting base data for {symbol}: {e}")

        return None

    async def _get_cob_features(self, symbol: str) -> Optional[np.ndarray]:
        """Get COB features for a symbol"""
        try:
            if (self.orchestrator and
                    hasattr(self.orchestrator, 'realtime_rl_trader') and
                    self.orchestrator.realtime_rl_trader):

                # Get latest features from COB trader
                feature_buffers = self.orchestrator.realtime_rl_trader.feature_buffers
                if symbol in feature_buffers and feature_buffers[symbol]:
                    latest_data = feature_buffers[symbol][-1]
                    return latest_data.get('features')
        except Exception as e:
            logger.debug(f"Error getting COB features for {symbol}: {e}")

        return None

    async def _get_current_price(self, symbol: str) -> float:
        """Get current price for a symbol"""
        try:
            if self.orchestrator and hasattr(self.orchestrator, 'data_provider'):
                current_prices = self.orchestrator.data_provider.current_prices
                return current_prices.get(symbol, 0.0)
        except Exception as e:
            logger.debug(f"Error getting current price for {symbol}: {e}")

        return 0.0

    def _convert_to_dqn_state(self, base_data: Any, context: InferenceContext) -> np.ndarray:
        """Convert base data to DQN state format"""
        try:
            # Use existing state building logic if available
            if (self.orchestrator and
                    hasattr(self.orchestrator, 'enhanced_training_system') and
                    hasattr(self.orchestrator.enhanced_training_system, '_build_dqn_state')):

                return self.orchestrator.enhanced_training_system._build_dqn_state(
                    base_data, context.symbol
                )

            # Fallback: create simple state representation
            feature_vector = base_data.get_feature_vector() if hasattr(base_data, 'get_feature_vector') else []
            if feature_vector:
                return np.array(feature_vector, dtype=np.float32)

            # Last resort: create minimal state
            return np.zeros(100, dtype=np.float32)

        except Exception as e:
            logger.error(f"Error converting to DQN state: {e}")
            return np.zeros(100, dtype=np.float32)

    async def start_training_loop(self):
        """Start the enhanced training loop"""
        if self.running:
            logger.warning("Training loop already running")
            return

        self.running = True
        self.training_task = asyncio.create_task(self._training_loop())
        logger.info("Enhanced RL training loop started")

    async def stop_training_loop(self):
        """Stop the enhanced training loop"""
        if not self.running:
            return

        self.running = False
        if self.training_task:
            self.training_task.cancel()
            try:
                await self.training_task
            except asyncio.CancelledError:
                pass

        logger.info("Enhanced RL training loop stopped")

    async def _training_loop(self):
        """Main training loop that processes evaluated predictions"""
        logger.info("Starting enhanced RL training loop")

        while self.running:
            try:
                # Process training for each symbol and timeframe
                for symbol in self.reward_calculator.symbols:
                    for timeframe in [TimeFrame.SECONDS_1, TimeFrame.MINUTES_1,
                                      TimeFrame.HOURS_1, TimeFrame.DAYS_1]:

                        # Get training data for this symbol/timeframe
                        training_data = self.reward_calculator.get_training_data(
                            symbol, timeframe, self.max_batch_size
                        )

                        if len(training_data) >= self.min_batch_size:
                            await self._process_training_batch(symbol, timeframe, training_data)

                # Sleep between training checks
                await asyncio.sleep(self.training_interval_seconds)

            except Exception as e:
                logger.error(f"Error in training loop: {e}")
                await asyncio.sleep(10)  # Wait longer on error

    async def _process_training_batch(self, symbol: str, timeframe: TimeFrame,
                                      training_data: List[Tuple[PredictionRecord, float]]):
        """
        Process a training batch for a specific symbol/timeframe

        Args:
            symbol: Trading symbol
            timeframe: Timeframe for training
            training_data: List of (prediction_record, reward) tuples
        """
        try:
            # Group training data by model
            model_batches = {}

            for prediction_record, reward in training_data:
                model_name = prediction_record.model_name
                if model_name not in model_batches:
                    model_batches[model_name] = []
                model_batches[model_name].append((prediction_record, reward))

            # Process each model's batch
            for model_name, model_data in model_batches.items():
                if len(model_data) >= self.min_batch_size:
                    await self._train_model_batch(model_name, symbol, timeframe, model_data)

        except Exception as e:
            logger.error(f"Error processing training batch for {symbol} {timeframe.value}: {e}")

    async def _train_model_batch(self, model_name: str, symbol: str, timeframe: TimeFrame,
                                 training_data: List[Tuple[PredictionRecord, float]]):
        """
        Train a specific model with a batch of data

        Args:
            model_name: Name of the model to train
            symbol: Trading symbol
            timeframe: Timeframe for training
            training_data: List of (prediction_record, reward) tuples
        """
        try:
            training_start = time.time()

            # Convert to training batch format
            batch = self._create_training_batch(model_name, symbol, timeframe, training_data)

            if batch is None:
                return

            # Call appropriate training function based on model type
            success = False

            if 'dqn' in model_name.lower():
                success = await self._train_dqn_model(batch)
            elif 'cob' in model_name.lower():
                success = await self._train_cob_rl_model(batch)
            elif 'cnn' in model_name.lower():
                success = await self._train_cnn_model(batch)
            else:
                logger.warning(f"Unknown model type for training: {model_name}")

            # Update statistics
            training_time = time.time() - training_start
            self._update_training_stats(model_name, batch, success, training_time)

            if success:
                logger.info(f"Successfully trained {model_name} on {symbol} {timeframe.value} "
                            f"with {len(training_data)} samples in {training_time:.3f}s")

        except Exception as e:
            logger.error(f"Error training model {model_name}: {e}")
            self._update_training_stats(model_name, None, False, 0)

    def _create_training_batch(self, model_name: str, symbol: str, timeframe: TimeFrame,
                               training_data: List[Tuple[PredictionRecord, float]]) -> Optional[TrainingBatch]:
        """Create a training batch from prediction records and rewards"""
        try:
            states = []
            actions = []
            rewards = []
            next_states = []
            dones = []
            confidences = []
            prediction_records = []

            for prediction_record, reward in training_data:
                # Extract state information
                # This would need to be adapted based on how states are stored
                state = np.zeros(100)  # Placeholder - you'll need to extract actual state
                next_state = state.copy()  # Simplified next state

                # Convert direction to action
                direction = prediction_record.predicted_direction
                action = direction + 1  # Convert -1,0,1 to 0,1,2

                states.append(state)
                actions.append(action)
                rewards.append(reward)
                next_states.append(next_state)
                dones.append(True)  # Each prediction is treated as terminal
                confidences.append(prediction_record.confidence)
                prediction_records.append(prediction_record)

            return TrainingBatch(
                model_name=model_name,
                symbol=symbol,
                timeframe=timeframe,
                states=states,
                actions=actions,
                rewards=rewards,
                next_states=next_states,
                dones=dones,
                confidences=confidences,
                prediction_records=prediction_records,
                batch_timestamp=datetime.now()
            )

        except Exception as e:
            logger.error(f"Error creating training batch: {e}")
            return None

    async def _train_dqn_model(self, batch: TrainingBatch) -> bool:
        """Train DQN model with batch data"""
        try:
            if self.orchestrator and hasattr(self.orchestrator, 'rl_agent'):
                rl_agent = self.orchestrator.rl_agent

                # Add experiences to memory
                for i in range(len(batch.states)):
                    if hasattr(rl_agent, 'remember'):
                        rl_agent.remember(
                            state=batch.states[i],
                            action=batch.actions[i],
                            reward=batch.rewards[i],
                            next_state=batch.next_states[i],
                            done=batch.dones[i]
                        )

                # Trigger training if enough experiences
                if hasattr(rl_agent, 'replay') and hasattr(rl_agent, 'memory'):
                    if len(rl_agent.memory) >= getattr(rl_agent, 'batch_size', 32):
                        loss = rl_agent.replay()
                        if loss is not None:
                            logger.debug(f"DQN training loss: {loss:.6f}")
                            return True

            return False

        except Exception as e:
            logger.error(f"Error training DQN model: {e}")
            return False

    async def _train_cob_rl_model(self, batch: TrainingBatch) -> bool:
        """Train COB RL model with batch data"""
        try:
            if (self.orchestrator and
                    hasattr(self.orchestrator, 'realtime_rl_trader') and
                    self.orchestrator.realtime_rl_trader):

                # Use COB RL trainer if available
                # This is a placeholder - implement based on actual COB RL training interface
                logger.debug(f"COB RL training batch: {len(batch.states)} samples")
                return True

            return False

        except Exception as e:
            logger.error(f"Error training COB RL model: {e}")
            return False

    async def _train_cnn_model(self, batch: TrainingBatch) -> bool:
        """Train CNN model with batch data"""
        try:
            if self.orchestrator and hasattr(self.orchestrator, 'enhanced_training_system'):
                # Use enhanced training system for CNN training
                # This is a placeholder - implement based on actual CNN training interface
                logger.debug(f"CNN training batch: {len(batch.states)} samples")
                return True

            return False

        except Exception as e:
            logger.error(f"Error training CNN model: {e}")
            return False

    def _update_training_stats(self, model_name: str, batch: Optional[TrainingBatch],
                               success: bool, training_time: float):
        """Update training statistics"""
        with self.lock:
            self.training_stats['total_training_batches'] += 1

            if success:
                self.training_stats['successful_training_calls'] += 1
            else:
                self.training_stats['failed_training_calls'] += 1

            self.training_stats['last_training_time'] = datetime.now().isoformat()

            # Model-specific stats
            if model_name not in self.training_stats['training_times_per_model']:
                self.training_stats['training_times_per_model'][model_name] = []
                self.training_stats['average_batch_sizes'][model_name] = []

            self.training_stats['training_times_per_model'][model_name].append(training_time)

            if batch:
                self.training_stats['average_batch_sizes'][model_name].append(len(batch.states))

    def get_training_statistics(self) -> Dict[str, Any]:
        """Get training statistics"""
        with self.lock:
            stats = self.training_stats.copy()

            # Calculate averages
            for model_name in stats['training_times_per_model']:
                times = stats['training_times_per_model'][model_name]
                if times:
                    stats[f'{model_name}_avg_training_time'] = sum(times) / len(times)

                sizes = stats['average_batch_sizes'][model_name]
                if sizes:
                    stats[f'{model_name}_avg_batch_size'] = sum(sizes) / len(sizes)

            return stats
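A sketch (not part of the commit) of how an additional model could be plugged into the same pipeline: the adapter registers its built-in wrappers under 'dqn_agent', 'cob_rl' and 'enhanced_cnn', and any other coroutine that accepts an InferenceContext and returns the same dict shape can be registered alongside them. The model name 'my_transformer' and the hard-coded numbers are hypothetical.

# Illustrative registration of a custom inference wrapper.
from core.timeframe_inference_coordinator import InferenceContext

async def my_transformer_inference(context: InferenceContext):
    current_price = 3000.0   # would come from the data provider in practice
    direction = 1            # -1 / 0 / 1, same convention as PredictionRecord
    return {
        'predicted_price': current_price * 1.001,
        'current_price': current_price,
        'direction': direction,
        'confidence': 0.6,
        'action': {1: 'BUY', 0: 'HOLD', -1: 'SELL'}[direction],
        'context': context,
    }

def register_my_model(adapter):
    # adapter is an EnhancedRLTrainingAdapter; its coordinator will then call the
    # wrapper on the continuous and hourly schedules like the built-in models.
    adapter.inference_coordinator.register_model_inference_function(
        'my_transformer', my_transformer_inference)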
467
core/timeframe_inference_coordinator.py
Normal file
@@ -0,0 +1,467 @@
|
||||
"""
|
||||
Timeframe-Aware Inference Coordinator
|
||||
|
||||
This module coordinates model inference across multiple timeframes with proper scheduling.
|
||||
It ensures that models know which timeframe they are predicting on and handles the
|
||||
complex scheduling requirements for multi-timeframe predictions.
|
||||
|
||||
Key Features:
|
||||
- Timeframe-aware model inference
|
||||
- Hourly multi-timeframe inference (4 predictions per hour)
|
||||
- Frequent inference at 1-5 second intervals
|
||||
- Prediction context management
|
||||
- Integration with enhanced reward calculator
|
||||
"""
|
||||
|
||||
import time
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any, Callable
|
||||
from dataclasses import dataclass
|
||||
import threading
|
||||
from enum import Enum
|
||||
|
||||
from core.enhanced_reward_calculator import EnhancedRewardCalculator, TimeFrame
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferenceContext:
|
||||
"""Context information for a model inference"""
|
||||
symbol: str
|
||||
timeframe: TimeFrame
|
||||
timestamp: datetime
|
||||
target_timeframe: TimeFrame # Which timeframe we're predicting for
|
||||
is_hourly_inference: bool = False
|
||||
inference_type: str = "regular" # "regular", "hourly", "continuous"
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferenceSchedule:
|
||||
"""Schedule configuration for different inference types"""
|
||||
continuous_interval_seconds: float = 5.0 # Continuous inference every 5 seconds
|
||||
hourly_timeframes: List[TimeFrame] = None # Timeframes for hourly inference
|
||||
|
||||
def __post_init__(self):
|
||||
if self.hourly_timeframes is None:
|
||||
self.hourly_timeframes = [TimeFrame.SECONDS_1, TimeFrame.MINUTES_1,
|
||||
TimeFrame.HOURS_1, TimeFrame.DAYS_1]
|
||||
|
||||
|
||||
class TimeframeInferenceCoordinator:
|
||||
"""
|
||||
Coordinates timeframe-aware model inference with proper scheduling
|
||||
|
||||
This coordinator:
|
||||
1. Manages continuous inference every 1-5 seconds on main timeframe
|
||||
2. Schedules hourly multi-timeframe inference (4 predictions per hour)
|
||||
3. Ensures models know which timeframe they're predicting on
|
||||
4. Integrates with enhanced reward calculator for training
|
||||
5. Handles prediction context and metadata
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
reward_calculator: EnhancedRewardCalculator,
|
||||
data_provider: Any = None,
|
||||
symbols: List[str] = None):
|
||||
"""
|
||||
Initialize the timeframe inference coordinator
|
||||
|
||||
Args:
|
||||
reward_calculator: Enhanced reward calculator instance
|
||||
data_provider: Data provider for market data
|
||||
symbols: List of symbols to coordinate inference for
|
||||
"""
|
||||
self.reward_calculator = reward_calculator
|
||||
self.data_provider = data_provider
|
||||
self.symbols = symbols or ['ETH/USDT', 'BTC/USDT']
|
||||
|
||||
# Inference schedule configuration
|
||||
self.schedule = InferenceSchedule()
|
||||
|
||||
# Model registry - stores inference functions for different models
|
||||
self.model_inference_functions: Dict[str, Callable] = {}
|
||||
|
||||
# Tracking inference state
|
||||
self.last_continuous_inference: Dict[str, datetime] = {}
|
||||
self.last_hourly_inference: Dict[str, datetime] = {}
|
||||
self.next_hourly_inference: Dict[str, datetime] = {}
|
||||
|
||||
# Active inference tasks
|
||||
self.inference_tasks: List[asyncio.Task] = []
|
||||
self.running = False
|
||||
|
||||
# Thread safety
|
||||
self.lock = threading.RLock()
|
||||
|
||||
# Performance metrics
|
||||
self.inference_stats = {
|
||||
'continuous_inferences': 0,
|
||||
'hourly_inferences': 0,
|
||||
'failed_inferences': 0,
|
||||
'average_inference_time_ms': 0.0
|
||||
}
|
||||
|
||||
self._initialize_schedules()
|
||||
|
||||
logger.info(f"TimeframeInferenceCoordinator initialized for symbols: {self.symbols}")
|
||||
logger.info(f"Continuous inference interval: {self.schedule.continuous_interval_seconds}s")
|
||||
logger.info(f"Hourly timeframes: {[tf.value for tf in self.schedule.hourly_timeframes]}")
|
||||
|
||||
def _initialize_schedules(self):
|
||||
"""Initialize inference schedules for all symbols"""
|
||||
current_time = datetime.now()
|
||||
|
||||
for symbol in self.symbols:
|
||||
self.last_continuous_inference[symbol] = current_time
|
||||
self.last_hourly_inference[symbol] = current_time
|
||||
|
||||
# Schedule next hourly inference at the top of the next hour
|
||||
next_hour = current_time.replace(minute=0, second=0, microsecond=0) + timedelta(hours=1)
|
||||
self.next_hourly_inference[symbol] = next_hour
|
||||
|
||||
def register_model_inference_function(self, model_name: str, inference_func: Callable):
|
||||
"""
|
||||
Register a model's inference function
|
||||
|
||||
Args:
|
||||
model_name: Name of the model
|
||||
inference_func: Async function that takes InferenceContext and returns prediction
|
||||
"""
|
||||
self.model_inference_functions[model_name] = inference_func
|
||||
logger.info(f"Registered inference function for model: {model_name}")
|
||||
|
||||
    async def start_coordination(self):
        """Start the inference coordination system"""
        if self.running:
            logger.warning("Inference coordination already running")
            return

        self.running = True
        logger.info("Starting timeframe inference coordination")

        # Start continuous inference tasks for each symbol
        for symbol in self.symbols:
            task = asyncio.create_task(self._continuous_inference_loop(symbol))
            self.inference_tasks.append(task)

        # Start hourly inference scheduler
        task = asyncio.create_task(self._hourly_inference_scheduler())
        self.inference_tasks.append(task)

        # Start reward evaluation loop
        task = asyncio.create_task(self._reward_evaluation_loop())
        self.inference_tasks.append(task)

        logger.info(f"Started {len(self.inference_tasks)} inference coordination tasks")

    async def stop_coordination(self):
        """Stop the inference coordination system"""
        if not self.running:
            return

        self.running = False
        logger.info("Stopping timeframe inference coordination")

        # Cancel all tasks
        for task in self.inference_tasks:
            task.cancel()

        # Wait for tasks to complete
        await asyncio.gather(*self.inference_tasks, return_exceptions=True)
        self.inference_tasks.clear()

        logger.info("Inference coordination stopped")

    async def _continuous_inference_loop(self, symbol: str):
        """
        Continuous inference loop for a specific symbol

        Args:
            symbol: Trading symbol to run inference for
        """
        logger.info(f"Starting continuous inference loop for {symbol}")

        while self.running:
            try:
                current_time = datetime.now()

                # Check if it's time for continuous inference
                last_inference = self.last_continuous_inference[symbol]
                time_since_last = (current_time - last_inference).total_seconds()

                if time_since_last >= self.schedule.continuous_interval_seconds:
                    # Run continuous inference on primary timeframe (1s)
                    context = InferenceContext(
                        symbol=symbol,
                        timeframe=TimeFrame.SECONDS_1,
                        timestamp=current_time,
                        target_timeframe=TimeFrame.SECONDS_1,
                        is_hourly_inference=False,
                        inference_type="continuous"
                    )

                    await self._execute_inference(context)
                    self.last_continuous_inference[symbol] = current_time
                    self.inference_stats['continuous_inferences'] += 1

                # Sleep for a short interval to avoid busy waiting
                await asyncio.sleep(0.1)

            except Exception as e:
                logger.error(f"Error in continuous inference loop for {symbol}: {e}")
                await asyncio.sleep(1.0)  # Wait longer on error

    async def _hourly_inference_scheduler(self):
        """Scheduler for hourly multi-timeframe inference"""
        logger.info("Starting hourly inference scheduler")

        while self.running:
            try:
                current_time = datetime.now()

                # Check if any symbol needs hourly inference
                for symbol in self.symbols:
                    if current_time >= self.next_hourly_inference[symbol]:
                        await self._execute_hourly_inference(symbol, current_time)

                        # Schedule next hourly inference
                        next_hour = current_time.replace(minute=0, second=0, microsecond=0) + timedelta(hours=1)
                        self.next_hourly_inference[symbol] = next_hour
                        self.last_hourly_inference[symbol] = current_time

                # Sleep for 30 seconds between checks
                await asyncio.sleep(30)

            except Exception as e:
                logger.error(f"Error in hourly inference scheduler: {e}")
                await asyncio.sleep(60)  # Wait longer on error

    async def _execute_hourly_inference(self, symbol: str, timestamp: datetime):
        """
        Execute hourly multi-timeframe inference for a symbol

        Args:
            symbol: Trading symbol
            timestamp: Current timestamp
        """
        logger.info(f"Executing hourly multi-timeframe inference for {symbol}")

        # Run inference for each timeframe
        for timeframe in self.schedule.hourly_timeframes:
            context = InferenceContext(
                symbol=symbol,
                timeframe=timeframe,
                timestamp=timestamp,
                target_timeframe=timeframe,
                is_hourly_inference=True,
                inference_type="hourly"
            )

            await self._execute_inference(context)
            self.inference_stats['hourly_inferences'] += 1

            # Small delay between timeframe inferences
            await asyncio.sleep(0.5)

    async def _execute_inference(self, context: InferenceContext):
        """
        Execute inference for a specific context

        Args:
            context: Inference context containing all necessary information
        """
        start_time = time.time()

        try:
            # Run inference for all registered models
            for model_name, inference_func in self.model_inference_functions.items():
                try:
                    # Execute model inference
                    prediction = await inference_func(context)

                    if prediction is not None:
                        # Add prediction to reward calculator
                        prediction_id = self.reward_calculator.add_prediction(
                            symbol=context.symbol,
                            timeframe=context.target_timeframe,
                            predicted_price=prediction.get('predicted_price', 0.0),
                            predicted_direction=prediction.get('direction', 0),
                            confidence=prediction.get('confidence', 0.0),
                            current_price=prediction.get('current_price', 0.0),
                            model_name=model_name
                        )

                        logger.debug(f"Added prediction {prediction_id} from {model_name} "
                                     f"for {context.symbol} {context.target_timeframe.value}")

                except Exception as e:
                    logger.error(f"Error running inference for model {model_name}: {e}")
                    self.inference_stats['failed_inferences'] += 1

            # Update inference timing stats
            inference_time_ms = (time.time() - start_time) * 1000
            self._update_inference_timing(inference_time_ms)

        except Exception as e:
            logger.error(f"Error executing inference for context {context}: {e}")
            self.inference_stats['failed_inferences'] += 1

    def _update_inference_timing(self, inference_time_ms: float):
        """Update inference timing statistics (running average of latency)"""
        # Callers increment the inference counters only after _execute_inference
        # returns, so at this point the counters reflect previously completed
        # inferences and the current sample must be counted explicitly.
        completed_inferences = (self.inference_stats['continuous_inferences'] +
                                self.inference_stats['hourly_inferences'])
        new_count = completed_inferences + 1

        current_avg = self.inference_stats['average_inference_time_ms']
        # Incremental mean, e.g. an average of 10ms over 4 samples plus a new 20ms
        # sample gives (10*4 + 20) / 5 = 12ms
        new_avg = ((current_avg * completed_inferences) + inference_time_ms) / new_count
        self.inference_stats['average_inference_time_ms'] = new_avg

    async def _reward_evaluation_loop(self):
        """Continuous loop for evaluating prediction rewards"""
        logger.info("Starting reward evaluation loop")

        while self.running:
            try:
                # Update price cache if data provider available
                if self.data_provider:
                    await self._update_price_cache()

                # Evaluate predictions and get training data
                for symbol in self.symbols:
                    evaluation_results = self.reward_calculator.evaluate_predictions(symbol)

                    if symbol in evaluation_results and evaluation_results[symbol]:
                        logger.debug(f"Evaluated {len(evaluation_results[symbol])} predictions for {symbol}")

                        # Trigger training for models that have newly evaluated predictions
                        await self._trigger_model_training(symbol, evaluation_results[symbol])

                # Sleep for evaluation interval
                await asyncio.sleep(10)  # Evaluate every 10 seconds

            except Exception as e:
                logger.error(f"Error in reward evaluation loop: {e}")
                await asyncio.sleep(30)  # Wait longer on error

    async def _update_price_cache(self):
        """Update price cache with current market prices"""
        try:
            for symbol in self.symbols:
                # Get current price from data provider
                if hasattr(self.data_provider, 'get_current_price'):
                    current_price = await self.data_provider.get_current_price(symbol)
                    if current_price:
                        self.reward_calculator.update_price(symbol, current_price)

        except Exception as e:
            logger.debug(f"Error updating price cache: {e}")

    async def _trigger_model_training(self, symbol: str, evaluation_results: List[Any]):
        """
        Trigger model training based on evaluation results

        Args:
            symbol: Trading symbol
            evaluation_results: List of (prediction, reward) tuples
        """
        try:
            # Group by model and timeframe for targeted training
            training_groups = {}

            for prediction_record, reward in evaluation_results:
                model_name = prediction_record.model_name
                timeframe = prediction_record.timeframe

                key = f"{model_name}_{timeframe.value}"
                if key not in training_groups:
                    training_groups[key] = []

                training_groups[key].append((prediction_record, reward))

            # Trigger training for each group
            for group_key, training_data in training_groups.items():
                # rsplit on the last underscore so model names that contain
                # underscores are preserved intact
                model_name, timeframe_str = group_key.rsplit('_', 1)
                timeframe = TimeFrame(timeframe_str)

                logger.info(f"Triggering training for {model_name} on {symbol} {timeframe.value} "
                            f"with {len(training_data)} samples")

                # Here you would call the specific model's training function
                # This is a placeholder - you'll need to implement the actual training calls
                await self._call_model_training(model_name, symbol, timeframe, training_data)

        except Exception as e:
            logger.error(f"Error triggering model training: {e}")

    async def _call_model_training(self, model_name: str, symbol: str,
                                   timeframe: TimeFrame, training_data: List[Any]):
        """
        Call model-specific training function

        Args:
            model_name: Name of the model to train
            symbol: Trading symbol
            timeframe: Timeframe for training
            training_data: List of (prediction, reward) tuples
        """
        # This is a placeholder for model-specific training calls
        # You'll need to implement this based on your specific model interfaces
        logger.debug(f"Training call for {model_name}: {len(training_data)} samples")

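    # Illustrative sketch only (the method above is intentionally a placeholder): one
    # possible way to dispatch training, assuming the coordinator is later given a
    # registry of per-model async training callables analogous to
    # model_inference_functions. 'model_training_functions' and 'train_fn' are
    # hypothetical names, not part of the original implementation.
    #
    #     train_fn = self.model_training_functions.get(model_name)
    #     if train_fn is not None:
    #         for prediction_record, reward in training_data:
    #             await train_fn(symbol=symbol,
    #                            timeframe=timeframe,
    #                            prediction=prediction_record,
    #                            reward=reward)
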
    def get_inference_statistics(self) -> Dict[str, Any]:
        """Get inference coordination statistics"""
        with self.lock:
            stats = self.inference_stats.copy()

            # Add scheduling information
            stats['symbols'] = self.symbols
            stats['continuous_interval_seconds'] = self.schedule.continuous_interval_seconds
            stats['hourly_timeframes'] = [tf.value for tf in self.schedule.hourly_timeframes]
            stats['next_hourly_inferences'] = {
                symbol: timestamp.isoformat()
                for symbol, timestamp in self.next_hourly_inference.items()
            }

            # Add accuracy summary from reward calculator
            stats['accuracy_summary'] = self.reward_calculator.get_accuracy_summary()

            return stats

    def force_hourly_inference(self, symbol: str = None):
        """
        Force immediate hourly inference for symbol(s)

        Args:
            symbol: Specific symbol (None for all symbols)
        """
        symbols_to_process = [symbol] if symbol else self.symbols

        async def _force_inference():
            current_time = datetime.now()
            for sym in symbols_to_process:
                await self._execute_hourly_inference(sym, current_time)

        # Schedule the inference
        if self.running:
            asyncio.create_task(_force_inference())
        else:
            logger.warning("Cannot force inference - coordinator not running")

    def get_prediction_history(self, symbol: str, timeframe: TimeFrame,
                               max_samples: int = 50) -> List[Any]:
        """
        Get prediction history for training

        Args:
            symbol: Trading symbol
            timeframe: Specific timeframe
            max_samples: Maximum samples to return

        Returns:
            List of training samples
        """
        return self.reward_calculator.get_training_data(symbol, timeframe, max_samples)

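
# Illustrative usage sketch (not part of the original module): wiring the coordinator
# to an EnhancedRewardCalculator and a hypothetical model, then running it briefly.
# The data provider and 'my_model_infer' are placeholders for real components.
#
#     async def main():
#         calculator = EnhancedRewardCalculator(symbols=['ETH/USDT'])
#         coordinator = TimeframeInferenceCoordinator(reward_calculator=calculator,
#                                                     data_provider=None,
#                                                     symbols=['ETH/USDT'])
#
#         async def my_model_infer(context: InferenceContext) -> dict:
#             return {'predicted_price': 0.0, 'direction': 0,
#                     'confidence': 0.0, 'current_price': 0.0}
#
#         coordinator.register_model_inference_function("my_model", my_model_infer)
#         await coordinator.start_coordination()
#         await asyncio.sleep(60)   # let it run for a minute
#         await coordinator.stop_coordination()
#
#     asyncio.run(main())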