enhanced training and reward - wip
core/enhanced_rl_training_adapter.py (new file, 572 lines added)
@@ -0,0 +1,572 @@
"""
Enhanced RL Training Adapter

This module integrates the new MSE-based reward system with existing RL training pipelines.
It provides a bridge between the timeframe-aware inference coordinator and the existing
model training infrastructure.

Key Features:
- Integration with EnhancedRewardCalculator
- Adaptation of existing RL models to the new reward system
- Real-time training triggers based on prediction outcomes
- Multi-timeframe training coordination
- Backward compatibility with existing training infrastructure
"""

import asyncio
import logging
import time
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Tuple, Union
from dataclasses import dataclass
import numpy as np
import threading

from core.enhanced_reward_calculator import EnhancedRewardCalculator, TimeFrame, PredictionRecord
from core.timeframe_inference_coordinator import TimeframeInferenceCoordinator, InferenceContext

logger = logging.getLogger(__name__)


@dataclass
class TrainingBatch:
    """Training batch for RL models with enhanced reward data"""
    model_name: str
    symbol: str
    timeframe: TimeFrame
    states: List[np.ndarray]
    actions: List[int]
    rewards: List[float]
    next_states: List[np.ndarray]
    dones: List[bool]
    confidences: List[float]
    prediction_records: List[PredictionRecord]
    batch_timestamp: datetime


class EnhancedRLTrainingAdapter:
    """
    Adapter that integrates the new reward system with existing RL training infrastructure

    This adapter:
    1. Bridges the new reward calculator with existing RL models
    2. Converts prediction records to RL training format
    3. Triggers real-time training based on reward evaluation
    4. Maintains compatibility with existing training systems
    5. Coordinates multi-timeframe training
    """

    def __init__(self,
                 reward_calculator: EnhancedRewardCalculator,
                 inference_coordinator: TimeframeInferenceCoordinator,
                 orchestrator: Any = None,
                 training_system: Any = None):
        """
        Initialize the enhanced RL training adapter

        Args:
            reward_calculator: Enhanced reward calculator instance
            inference_coordinator: Timeframe inference coordinator
            orchestrator: Trading orchestrator (optional)
            training_system: Enhanced realtime training system (optional)
        """
        self.reward_calculator = reward_calculator
        self.inference_coordinator = inference_coordinator
        self.orchestrator = orchestrator
        self.training_system = training_system

        # Model registry for training functions
        self.model_trainers: Dict[str, Any] = {}

        # Training configuration
        self.min_batch_size = 8  # Minimum samples for training
        self.max_batch_size = 64  # Maximum samples per training batch
        self.training_interval_seconds = 5.0  # How often to check for training opportunities

        # Training statistics
        self.training_stats = {
            'total_training_batches': 0,
            'successful_training_calls': 0,
            'failed_training_calls': 0,
            'last_training_time': None,
            'training_times_per_model': {},
            'average_batch_sizes': {}
        }

        # State conversion helpers
        self.state_builders: Dict[str, Any] = {}

        # Thread safety
        self.lock = threading.RLock()

        # Running state
        self.running = False
        self.training_task: Optional[asyncio.Task] = None

        logger.info("EnhancedRLTrainingAdapter initialized")
        self._register_default_model_handlers()

    def _register_default_model_handlers(self):
        """Register default model handlers for existing models"""
        # Register inference functions with the coordinator
        if self.inference_coordinator:
            self.inference_coordinator.register_model_inference_function(
                'dqn_agent', self._dqn_inference_wrapper
            )
            self.inference_coordinator.register_model_inference_function(
                'cob_rl', self._cob_rl_inference_wrapper
            )
            self.inference_coordinator.register_model_inference_function(
                'enhanced_cnn', self._cnn_inference_wrapper
            )
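
    # The inference wrappers below all return a dict with a common shape so the
    # coordinator and reward calculator can treat models uniformly:
    #   'predicted_price', 'current_price', 'direction' (-1/0/1), 'confidence',
    #   and the originating 'context', plus model-specific extras such as
    #   'action', 'change', 'model_state', 'model_features', or 'model_output'.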

    async def _dqn_inference_wrapper(self, context: InferenceContext) -> Optional[Dict[str, Any]]:
        """Wrapper for DQN model inference"""
        try:
            if self.orchestrator and hasattr(self.orchestrator, 'rl_agent'):
                # Get base data for the symbol
                base_data = await self._get_base_data(context.symbol)
                if base_data is None:
                    return None

                # Convert to DQN state format
                state = self._convert_to_dqn_state(base_data, context)

                # Run DQN prediction
                if hasattr(self.orchestrator.rl_agent, 'act'):
                    action_idx = self.orchestrator.rl_agent.act(state)
                    confidence = 0.7  # Default confidence for DQN

                    # Convert action to prediction format
                    action_names = ['SELL', 'HOLD', 'BUY']
                    direction = action_idx - 1  # Convert 0,1,2 to -1,0,1

                    current_price = base_data.get('current_price', 0.0)
                    predicted_price = current_price * (1 + (direction * 0.001))  # Small price prediction

                    return {
                        'predicted_price': predicted_price,
                        'current_price': current_price,
                        'direction': direction,
                        'confidence': confidence,
                        'action': action_names[action_idx],
                        'model_state': state,
                        'context': context
                    }
        except Exception as e:
            logger.error(f"Error in DQN inference wrapper: {e}")

        return None

    async def _cob_rl_inference_wrapper(self, context: InferenceContext) -> Optional[Dict[str, Any]]:
        """Wrapper for COB RL model inference"""
        try:
            if (self.orchestrator and
                    hasattr(self.orchestrator, 'realtime_rl_trader') and
                    self.orchestrator.realtime_rl_trader):

                # Get COB features
                features = await self._get_cob_features(context.symbol)
                if features is None:
                    return None

                # Run COB RL prediction
                prediction = self.orchestrator.realtime_rl_trader._predict(context.symbol, features)

                if prediction:
                    current_price = await self._get_current_price(context.symbol)
                    predicted_price = current_price * (1 + prediction.get('change', 0))

                    return {
                        'predicted_price': predicted_price,
                        'current_price': current_price,
                        'direction': prediction.get('direction', 0),
                        'confidence': prediction.get('confidence', 0.0),
                        'change': prediction.get('change', 0.0),
                        'model_features': features,
                        'context': context
                    }
        except Exception as e:
            logger.error(f"Error in COB RL inference wrapper: {e}")

        return None

    async def _cnn_inference_wrapper(self, context: InferenceContext) -> Optional[Dict[str, Any]]:
        """Wrapper for CNN model inference"""
        try:
            if self.orchestrator and hasattr(self.orchestrator, 'model_registry'):
                # Find CNN models in registry
                for model_name, model in self.orchestrator.model_registry.models.items():
                    if 'cnn' in model_name.lower():
                        # Get base data
                        base_data = await self._get_base_data(context.symbol)
                        if base_data is None:
                            continue

                        # Run CNN prediction
                        if hasattr(model, 'predict_from_base_input'):
                            model_output = model.predict_from_base_input(base_data)

                            current_price = base_data.get('current_price', 0.0)

                            # Extract prediction data
                            predictions = model_output.predictions
                            action = predictions.get('action', 'HOLD')
                            confidence = predictions.get('confidence', 0.0)

                            # Convert action to direction
                            direction = {'BUY': 1, 'SELL': -1, 'HOLD': 0}.get(action, 0)
                            predicted_price = current_price * (1 + (direction * 0.002))

                            return {
                                'predicted_price': predicted_price,
                                'current_price': current_price,
                                'direction': direction,
                                'confidence': confidence,
                                'action': action,
                                'model_output': model_output,
                                'context': context
                            }
        except Exception as e:
            logger.error(f"Error in CNN inference wrapper: {e}")

        return None

    async def _get_base_data(self, symbol: str) -> Optional[Any]:
        """Get base data for a symbol"""
        try:
            if self.orchestrator and hasattr(self.orchestrator, 'data_provider'):
                # Use orchestrator's data provider
                return await self.orchestrator._build_base_data(symbol)
        except Exception as e:
            logger.debug(f"Error getting base data for {symbol}: {e}")

        return None

    async def _get_cob_features(self, symbol: str) -> Optional[np.ndarray]:
        """Get COB features for a symbol"""
        try:
            if (self.orchestrator and
                    hasattr(self.orchestrator, 'realtime_rl_trader') and
                    self.orchestrator.realtime_rl_trader):

                # Get latest features from COB trader
                feature_buffers = self.orchestrator.realtime_rl_trader.feature_buffers
                if symbol in feature_buffers and feature_buffers[symbol]:
                    latest_data = feature_buffers[symbol][-1]
                    return latest_data.get('features')
        except Exception as e:
            logger.debug(f"Error getting COB features for {symbol}: {e}")

        return None

    async def _get_current_price(self, symbol: str) -> float:
        """Get current price for a symbol"""
        try:
            if self.orchestrator and hasattr(self.orchestrator, 'data_provider'):
                current_prices = self.orchestrator.data_provider.current_prices
                return current_prices.get(symbol, 0.0)
        except Exception as e:
            logger.debug(f"Error getting current price for {symbol}: {e}")

        return 0.0

    def _convert_to_dqn_state(self, base_data: Any, context: InferenceContext) -> np.ndarray:
        """Convert base data to DQN state format"""
        try:
            # Use existing state building logic if available
            if (self.orchestrator and
                    hasattr(self.orchestrator, 'enhanced_training_system') and
                    hasattr(self.orchestrator.enhanced_training_system, '_build_dqn_state')):

                return self.orchestrator.enhanced_training_system._build_dqn_state(
                    base_data, context.symbol
                )

            # Fallback: create simple state representation
            feature_vector = base_data.get_feature_vector() if hasattr(base_data, 'get_feature_vector') else []
            if feature_vector:
                return np.array(feature_vector, dtype=np.float32)

            # Last resort: create minimal state
            return np.zeros(100, dtype=np.float32)

        except Exception as e:
            logger.error(f"Error converting to DQN state: {e}")
            return np.zeros(100, dtype=np.float32)

    async def start_training_loop(self):
        """Start the enhanced training loop"""
        if self.running:
            logger.warning("Training loop already running")
            return

        self.running = True
        self.training_task = asyncio.create_task(self._training_loop())
        logger.info("Enhanced RL training loop started")

    async def stop_training_loop(self):
        """Stop the enhanced training loop"""
        if not self.running:
            return

        self.running = False
        if self.training_task:
            self.training_task.cancel()
            try:
                await self.training_task
            except asyncio.CancelledError:
                pass

        logger.info("Enhanced RL training loop stopped")

    async def _training_loop(self):
        """Main training loop that processes evaluated predictions"""
        logger.info("Starting enhanced RL training loop")

        while self.running:
            try:
                # Process training for each symbol and timeframe
                for symbol in self.reward_calculator.symbols:
                    for timeframe in [TimeFrame.SECONDS_1, TimeFrame.MINUTES_1,
                                      TimeFrame.HOURS_1, TimeFrame.DAYS_1]:

                        # Get training data for this symbol/timeframe
                        training_data = self.reward_calculator.get_training_data(
                            symbol, timeframe, self.max_batch_size
                        )

                        if len(training_data) >= self.min_batch_size:
                            await self._process_training_batch(symbol, timeframe, training_data)

                # Sleep between training checks
                await asyncio.sleep(self.training_interval_seconds)

            except Exception as e:
                logger.error(f"Error in training loop: {e}")
                await asyncio.sleep(10)  # Wait longer on error

    async def _process_training_batch(self, symbol: str, timeframe: TimeFrame,
                                      training_data: List[Tuple[PredictionRecord, float]]):
        """
        Process a training batch for a specific symbol/timeframe

        Args:
            symbol: Trading symbol
            timeframe: Timeframe for training
            training_data: List of (prediction_record, reward) tuples
        """
        try:
            # Group training data by model
            model_batches = {}

            for prediction_record, reward in training_data:
                model_name = prediction_record.model_name
                if model_name not in model_batches:
                    model_batches[model_name] = []
                model_batches[model_name].append((prediction_record, reward))

            # Process each model's batch
            for model_name, model_data in model_batches.items():
                if len(model_data) >= self.min_batch_size:
                    await self._train_model_batch(model_name, symbol, timeframe, model_data)

        except Exception as e:
            logger.error(f"Error processing training batch for {symbol} {timeframe.value}: {e}")

    async def _train_model_batch(self, model_name: str, symbol: str, timeframe: TimeFrame,
                                 training_data: List[Tuple[PredictionRecord, float]]):
        """
        Train a specific model with a batch of data

        Args:
            model_name: Name of the model to train
            symbol: Trading symbol
            timeframe: Timeframe for training
            training_data: List of (prediction_record, reward) tuples
        """
        try:
            training_start = time.time()

            # Convert to training batch format
            batch = self._create_training_batch(model_name, symbol, timeframe, training_data)

            if batch is None:
                return

            # Call appropriate training function based on model type
            success = False

            if 'dqn' in model_name.lower():
                success = await self._train_dqn_model(batch)
            elif 'cob' in model_name.lower():
                success = await self._train_cob_rl_model(batch)
            elif 'cnn' in model_name.lower():
                success = await self._train_cnn_model(batch)
            else:
                logger.warning(f"Unknown model type for training: {model_name}")

            # Update statistics
            training_time = time.time() - training_start
            self._update_training_stats(model_name, batch, success, training_time)

            if success:
                logger.info(f"Successfully trained {model_name} on {symbol} {timeframe.value} "
                            f"with {len(training_data)} samples in {training_time:.3f}s")

        except Exception as e:
            logger.error(f"Error training model {model_name}: {e}")
            self._update_training_stats(model_name, None, False, 0)

    def _create_training_batch(self, model_name: str, symbol: str, timeframe: TimeFrame,
                               training_data: List[Tuple[PredictionRecord, float]]) -> Optional[TrainingBatch]:
        """Create a training batch from prediction records and rewards"""
        try:
            states = []
            actions = []
            rewards = []
            next_states = []
            dones = []
            confidences = []
            prediction_records = []

            for prediction_record, reward in training_data:
                # Extract state information
                # Placeholder state - the actual state needs to be reconstructed from
                # the data captured with the prediction record
                state = np.zeros(100)
                next_state = state.copy()  # Simplified next state

                # Convert direction to action
                direction = prediction_record.predicted_direction
                action = direction + 1  # Convert -1,0,1 to 0,1,2

                states.append(state)
                actions.append(action)
                rewards.append(reward)
                next_states.append(next_state)
                dones.append(True)  # Each prediction is treated as terminal
                confidences.append(prediction_record.confidence)
                prediction_records.append(prediction_record)

            return TrainingBatch(
                model_name=model_name,
                symbol=symbol,
                timeframe=timeframe,
                states=states,
                actions=actions,
                rewards=rewards,
                next_states=next_states,
                dones=dones,
                confidences=confidences,
                prediction_records=prediction_records,
                batch_timestamp=datetime.now()
            )

        except Exception as e:
            logger.error(f"Error creating training batch: {e}")
            return None

    async def _train_dqn_model(self, batch: TrainingBatch) -> bool:
        """Train DQN model with batch data"""
        try:
            if self.orchestrator and hasattr(self.orchestrator, 'rl_agent'):
                rl_agent = self.orchestrator.rl_agent

                # Add experiences to memory
                for i in range(len(batch.states)):
                    if hasattr(rl_agent, 'remember'):
                        rl_agent.remember(
                            state=batch.states[i],
                            action=batch.actions[i],
                            reward=batch.rewards[i],
                            next_state=batch.next_states[i],
                            done=batch.dones[i]
                        )

                # Trigger training if enough experiences
                if hasattr(rl_agent, 'replay') and hasattr(rl_agent, 'memory'):
                    if len(rl_agent.memory) >= getattr(rl_agent, 'batch_size', 32):
                        loss = rl_agent.replay()
                        if loss is not None:
                            logger.debug(f"DQN training loss: {loss:.6f}")
                            return True

            return False

        except Exception as e:
            logger.error(f"Error training DQN model: {e}")
            return False

    async def _train_cob_rl_model(self, batch: TrainingBatch) -> bool:
        """Train COB RL model with batch data"""
        try:
            if (self.orchestrator and
                    hasattr(self.orchestrator, 'realtime_rl_trader') and
                    self.orchestrator.realtime_rl_trader):

                # Use COB RL trainer if available
                # This is a placeholder - implement based on actual COB RL training interface
                logger.debug(f"COB RL training batch: {len(batch.states)} samples")
                return True

            return False

        except Exception as e:
            logger.error(f"Error training COB RL model: {e}")
            return False

    async def _train_cnn_model(self, batch: TrainingBatch) -> bool:
        """Train CNN model with batch data"""
        try:
            if self.orchestrator and hasattr(self.orchestrator, 'enhanced_training_system'):
                # Use enhanced training system for CNN training
                # This is a placeholder - implement based on actual CNN training interface
                logger.debug(f"CNN training batch: {len(batch.states)} samples")
                return True

            return False

        except Exception as e:
            logger.error(f"Error training CNN model: {e}")
            return False

    def _update_training_stats(self, model_name: str, batch: Optional[TrainingBatch],
                               success: bool, training_time: float):
        """Update training statistics"""
        with self.lock:
            self.training_stats['total_training_batches'] += 1

            if success:
                self.training_stats['successful_training_calls'] += 1
            else:
                self.training_stats['failed_training_calls'] += 1

            self.training_stats['last_training_time'] = datetime.now().isoformat()

            # Model-specific stats
            if model_name not in self.training_stats['training_times_per_model']:
                self.training_stats['training_times_per_model'][model_name] = []
                self.training_stats['average_batch_sizes'][model_name] = []

            self.training_stats['training_times_per_model'][model_name].append(training_time)

            if batch:
                self.training_stats['average_batch_sizes'][model_name].append(len(batch.states))

    def get_training_statistics(self) -> Dict[str, Any]:
        """Get training statistics"""
        with self.lock:
            stats = self.training_stats.copy()

            # Calculate averages
            for model_name in stats['training_times_per_model']:
                times = stats['training_times_per_model'][model_name]
                if times:
                    stats[f'{model_name}_avg_training_time'] = sum(times) / len(times)

                sizes = stats['average_batch_sizes'][model_name]
                if sizes:
                    stats[f'{model_name}_avg_batch_size'] = sum(sizes) / len(sizes)

            return stats
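

# ---------------------------------------------------------------------------
# Example wiring (sketch, not part of the adapter itself). The constructor
# arguments for EnhancedRewardCalculator and TimeframeInferenceCoordinator are
# assumptions here; adapt them to the actual signatures in
# core.enhanced_reward_calculator and core.timeframe_inference_coordinator.
# The adapter calls themselves match the API defined above.
# ---------------------------------------------------------------------------
#
# async def run_enhanced_training(orchestrator):
#     reward_calculator = EnhancedRewardCalculator()
#     inference_coordinator = TimeframeInferenceCoordinator(reward_calculator)
#     adapter = EnhancedRLTrainingAdapter(
#         reward_calculator=reward_calculator,
#         inference_coordinator=inference_coordinator,
#         orchestrator=orchestrator,
#     )
#     await adapter.start_training_loop()
#     try:
#         while True:
#             await asyncio.sleep(60)
#             logger.info(adapter.get_training_statistics())
#     finally:
#         await adapter.stop_training_loop()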