"""
|
|
Enhanced RL Training Adapter
|
|
|
|
This module integrates the new MSE-based reward system with existing RL training pipelines.
|
|
It provides a bridge between the timeframe-aware inference coordinator and the existing
|
|
model training infrastructure.
|
|
|
|
Key Features:
|
|
- Integration with EnhancedRewardCalculator
|
|
- Adaptation of existing RL models to new reward system
|
|
- Real-time training triggers based on prediction outcomes
|
|
- Multi-timeframe training coordination
|
|
- Backward compatibility with existing training infrastructure
|
|
"""
|
|
|
|
import asyncio
import logging
import time
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Union, Tuple
from dataclasses import dataclass
import numpy as np
import threading

from core.enhanced_reward_calculator import EnhancedRewardCalculator, TimeFrame, PredictionRecord
from core.timeframe_inference_coordinator import TimeframeInferenceCoordinator, InferenceContext

logger = logging.getLogger(__name__)


@dataclass
class TrainingBatch:
    """Training batch for RL models with enhanced reward data"""
    model_name: str
    symbol: str
    timeframe: TimeFrame
    states: List[np.ndarray]
    actions: List[int]
    rewards: List[float]
    next_states: List[np.ndarray]
    dones: List[bool]
    confidences: List[float]
    prediction_records: List[PredictionRecord]
    batch_timestamp: datetime


class EnhancedRLTrainingAdapter:
    """
    Adapter that integrates the new reward system with the existing RL training infrastructure.

    This adapter:
    1. Bridges the new reward calculator with existing RL models
    2. Converts prediction records to RL training format
    3. Triggers real-time training based on reward evaluation
    4. Maintains compatibility with existing training systems
    5. Coordinates multi-timeframe training
    """

    def __init__(self,
                 reward_calculator: EnhancedRewardCalculator,
                 inference_coordinator: TimeframeInferenceCoordinator,
                 orchestrator: Any = None,
                 training_system: Any = None):
        """
        Initialize the enhanced RL training adapter

        Args:
            reward_calculator: Enhanced reward calculator instance
            inference_coordinator: Timeframe inference coordinator
            orchestrator: Trading orchestrator (optional)
            training_system: Enhanced realtime training system (optional)
        """
        self.reward_calculator = reward_calculator
        self.inference_coordinator = inference_coordinator
        self.orchestrator = orchestrator
        self.training_system = training_system

        # Model registry for training functions
        self.model_trainers: Dict[str, Any] = {}

        # Training configuration
        self.min_batch_size = 8  # Minimum samples for training
        self.max_batch_size = 64  # Maximum samples per training batch
        self.training_interval_seconds = 5.0  # How often to check for training opportunities

        # Training statistics
        self.training_stats = {
            'total_training_batches': 0,
            'successful_training_calls': 0,
            'failed_training_calls': 0,
            'last_training_time': None,
            'training_times_per_model': {},
            'average_batch_sizes': {}
        }

        # State conversion helpers
        self.state_builders: Dict[str, Any] = {}

        # Thread safety
        self.lock = threading.RLock()

        # Running state
        self.running = False
        self.training_task: Optional[asyncio.Task] = None

        logger.info("EnhancedRLTrainingAdapter initialized")
        self._register_default_model_handlers()

    def _register_default_model_handlers(self):
        """Register default model handlers for existing models"""
        # Register inference functions with the coordinator
        if self.inference_coordinator:
            self.inference_coordinator.register_model_inference_function(
                'dqn_agent', self._dqn_inference_wrapper
            )
            self.inference_coordinator.register_model_inference_function(
                'cob_rl', self._cob_rl_inference_wrapper
            )
            self.inference_coordinator.register_model_inference_function(
                'enhanced_cnn', self._cnn_inference_wrapper
            )

    async def _dqn_inference_wrapper(self, context: InferenceContext) -> Optional[Dict[str, Any]]:
        """Wrapper for DQN model inference"""
        try:
            if self.orchestrator and hasattr(self.orchestrator, 'rl_agent'):
                # Get base data for the symbol
                base_data = await self._get_base_data(context.symbol)
                if base_data is None:
                    return None

                # Convert to DQN state format
                state = self._convert_to_dqn_state(base_data, context)

                # Run DQN prediction
                if hasattr(self.orchestrator.rl_agent, 'act'):
                    action_idx = self.orchestrator.rl_agent.act(state)
                    # Try to extract confidence from the agent if available
                    confidence = getattr(self.orchestrator.rl_agent, 'last_confidence', None)
                    if confidence is None:
                        confidence = 0.5

                    # Convert action to prediction format
                    action_names = ['SELL', 'HOLD', 'BUY']
                    direction = action_idx - 1  # Convert 0,1,2 to -1,0,1

                    # Use the real current price
                    current_price = self._safe_get_current_price(context.symbol)

                    # Do not fabricate a price target; predicted_price stays at the
                    # current price until the model supplies a numeric target.
                    return {
                        'predicted_price': current_price,  # same as current when no numeric target is available
                        'current_price': current_price,
                        'direction': direction,
                        'confidence': float(confidence),
                        'action': action_names[action_idx],
                        'model_state': state,
                        'context': context
                    }
        except Exception as e:
            logger.error(f"Error in DQN inference wrapper: {e}")

        return None

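    # Note on the discrete-action encoding shared by the wrappers above and the
    # batch builder further down (a restatement of the inline comments, not new
    # behavior): action index 0/1/2 corresponds to SELL/HOLD/BUY and maps to
    # direction -1/0/+1 via `direction = action_idx - 1`; _create_training_batch
    # inverts this with `action = direction + 1`.
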
    async def _cob_rl_inference_wrapper(self, context: InferenceContext) -> Optional[Dict[str, Any]]:
        """Wrapper for COB RL model inference"""
        try:
            if (self.orchestrator and
                    hasattr(self.orchestrator, 'realtime_rl_trader') and
                    self.orchestrator.realtime_rl_trader):

                # Get COB features
                features = await self._get_cob_features(context.symbol)
                if features is None:
                    return None

                # Run COB RL prediction
                prediction = self.orchestrator.realtime_rl_trader._predict(context.symbol, features)

                if prediction:
                    current_price = self._safe_get_current_price(context.symbol)
                    # If 'change' is available, assume it is a fractional return
                    change = prediction.get('change', None)
                    predicted_price = current_price * (1 + change) if (change is not None and current_price) else current_price

                    return {
                        'predicted_price': predicted_price,
                        'current_price': current_price,
                        'direction': prediction.get('direction', 0),
                        'confidence': prediction.get('confidence', 0.0),
                        'change': prediction.get('change', 0.0),
                        'model_features': features,
                        'context': context
                    }
        except Exception as e:
            logger.error(f"Error in COB RL inference wrapper: {e}")

        return None

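    # Worked example of the price reconstruction above, with illustrative numbers
    # only: a fractional 'change' of +0.002 on current_price = 3000.00 gives
    # predicted_price = 3000.00 * (1 + 0.002) = 3006.00. The CNN wrapper below
    # applies the same arithmetic to its per-timeframe predicted_return values.
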
    async def _cnn_inference_wrapper(self, context: InferenceContext) -> Optional[Dict[str, Any]]:
        """Wrapper for CNN model inference"""
        try:
            if self.orchestrator and hasattr(self.orchestrator, 'model_registry'):
                # Find CNN models in registry
                for model_name, model in self.orchestrator.model_registry.models.items():
                    if 'cnn' in model_name.lower():
                        # Get base data
                        base_data = await self._get_base_data(context.symbol)
                        if base_data is None:
                            continue

                        # Run CNN prediction
                        if hasattr(model, 'predict_from_base_input'):
                            model_output = model.predict_from_base_input(base_data)

                            # Extract current price from data provider
                            current_price = self._safe_get_current_price(context.symbol)

                            # Extract prediction data
                            predictions = model_output.predictions
                            action = predictions.get('action', 'HOLD')
                            confidence = predictions.get('confidence', 0.0)

                            # Convert action to direction only for classification signal
                            direction = {'BUY': 1, 'SELL': -1, 'HOLD': 0}.get(action, 0)

                            # Use numeric predicted return if provided (no synthetic fabrication)
                            pr_map = {
                                TimeFrame.SECONDS_1: 'predicted_return_1s',
                                TimeFrame.MINUTES_1: 'predicted_return_1m',
                                TimeFrame.HOURS_1: 'predicted_return_1h',
                                TimeFrame.DAYS_1: 'predicted_return_1d',
                            }
                            ret_key = pr_map.get(context.target_timeframe)
                            predicted_return = None
                            if ret_key and ret_key in predictions:
                                predicted_return = float(predictions.get(ret_key))

                            predicted_price = current_price * (1 + predicted_return) if (predicted_return is not None and current_price) else current_price

                            return {
                                'predicted_price': predicted_price,
                                'current_price': current_price,
                                'direction': direction,
                                'confidence': confidence,
                                'predicted_return': predicted_return,
                                'action': action,
                                'model_output': model_output,
                                'context': context
                            }
        except Exception as e:
            logger.error(f"Error in CNN inference wrapper: {e}")

        return None

    async def _get_base_data(self, symbol: str) -> Optional[Any]:
        """Get base data for a symbol"""
        try:
            if self.orchestrator and hasattr(self.orchestrator, 'data_provider'):
                # Build base data via the orchestrator (backed by its data provider)
                return await self.orchestrator._build_base_data(symbol)
        except Exception as e:
            logger.debug(f"Error getting base data for {symbol}: {e}")

        return None

    async def _get_cob_features(self, symbol: str) -> Optional[np.ndarray]:
        """Get COB features for a symbol"""
        try:
            if (self.orchestrator and
                    hasattr(self.orchestrator, 'realtime_rl_trader') and
                    self.orchestrator.realtime_rl_trader):

                # Get latest features from COB trader
                feature_buffers = self.orchestrator.realtime_rl_trader.feature_buffers
                if symbol in feature_buffers and feature_buffers[symbol]:
                    latest_data = feature_buffers[symbol][-1]
                    return latest_data.get('features')
        except Exception as e:
            logger.debug(f"Error getting COB features for {symbol}: {e}")

        return None

    def _safe_get_current_price(self, symbol: str) -> float:
        """Get current price for a symbol via DataProvider API"""
        try:
            if self.orchestrator and hasattr(self.orchestrator, 'data_provider'):
                price = self.orchestrator.data_provider.get_current_price(symbol)
                return float(price) if price is not None else 0.0
        except Exception as e:
            logger.debug(f"Error getting current price for {symbol}: {e}")
        return 0.0

    def _convert_to_dqn_state(self, base_data: Any, context: InferenceContext) -> np.ndarray:
        """Convert base data to DQN state format"""
        try:
            # Use existing state building logic if available
            if (self.orchestrator and
                    hasattr(self.orchestrator, 'enhanced_training_system') and
                    hasattr(self.orchestrator.enhanced_training_system, '_build_dqn_state')):

                return self.orchestrator.enhanced_training_system._build_dqn_state(
                    base_data, context.symbol
                )

            # Fallback: create simple state representation
            feature_vector = base_data.get_feature_vector() if hasattr(base_data, 'get_feature_vector') else []
            if feature_vector:
                return np.array(feature_vector, dtype=np.float32)

            # Last resort: create minimal state
            return np.zeros(100, dtype=np.float32)

        except Exception as e:
            logger.error(f"Error converting to DQN state: {e}")
            return np.zeros(100, dtype=np.float32)

    async def start_training_loop(self):
        """Start the enhanced training loop"""
        if self.running:
            logger.warning("Training loop already running")
            return

        self.running = True
        self.training_task = asyncio.create_task(self._training_loop())
        logger.info("Enhanced RL training loop started")

    async def stop_training_loop(self):
        """Stop the enhanced training loop"""
        if not self.running:
            return

        self.running = False
        if self.training_task:
            self.training_task.cancel()
            try:
                await self.training_task
            except asyncio.CancelledError:
                pass

        logger.info("Enhanced RL training loop stopped")

    async def _training_loop(self):
        """Main training loop that processes evaluated predictions"""
        logger.info("Starting enhanced RL training loop")

        while self.running:
            try:
                # Process training for each symbol and timeframe
                for symbol in self.reward_calculator.symbols:
                    for timeframe in [TimeFrame.SECONDS_1, TimeFrame.MINUTES_1,
                                      TimeFrame.HOURS_1, TimeFrame.DAYS_1]:

                        # Get training data for this symbol/timeframe
                        training_data = self.reward_calculator.get_training_data(
                            symbol, timeframe, self.max_batch_size
                        )

                        if len(training_data) >= self.min_batch_size:
                            await self._process_training_batch(symbol, timeframe, training_data)

                # Sleep between training checks
                await asyncio.sleep(self.training_interval_seconds)

            except Exception as e:
                logger.error(f"Error in training loop: {e}")
                await asyncio.sleep(10)  # Wait longer on error

    async def _process_training_batch(self, symbol: str, timeframe: TimeFrame,
                                      training_data: List[Tuple[PredictionRecord, float]]):
        """
        Process a training batch for a specific symbol/timeframe

        Args:
            symbol: Trading symbol
            timeframe: Timeframe for training
            training_data: List of (prediction_record, reward) tuples
        """
        try:
            # Group training data by model
            model_batches = {}

            for prediction_record, reward in training_data:
                model_name = prediction_record.model_name
                if model_name not in model_batches:
                    model_batches[model_name] = []
                model_batches[model_name].append((prediction_record, reward))

            # Process each model's batch
            for model_name, model_data in model_batches.items():
                if len(model_data) >= self.min_batch_size:
                    await self._train_model_batch(model_name, symbol, timeframe, model_data)

        except Exception as e:
            logger.error(f"Error processing training batch for {symbol} {timeframe.value}: {e}")

    async def _train_model_batch(self, model_name: str, symbol: str, timeframe: TimeFrame,
                                 training_data: List[Tuple[PredictionRecord, float]]):
        """
        Train a specific model with a batch of data

        Args:
            model_name: Name of the model to train
            symbol: Trading symbol
            timeframe: Timeframe for training
            training_data: List of (prediction_record, reward) tuples
        """
        try:
            training_start = time.time()

            # Convert to training batch format
            batch = self._create_training_batch(model_name, symbol, timeframe, training_data)

            if batch is None:
                return

            # Call appropriate training function based on model type
            success = False

            if 'dqn' in model_name.lower():
                success = await self._train_dqn_model(batch)
            elif 'cob' in model_name.lower():
                success = await self._train_cob_rl_model(batch)
            elif 'cnn' in model_name.lower():
                success = await self._train_cnn_model(batch)
            else:
                logger.warning(f"Unknown model type for training: {model_name}")

            # Update statistics
            training_time = time.time() - training_start
            self._update_training_stats(model_name, batch, success, training_time)

            if success:
                logger.info(f"Successfully trained {model_name} on {symbol} {timeframe.value} "
                            f"with {len(training_data)} samples in {training_time:.3f}s")

        except Exception as e:
            logger.error(f"Error training model {model_name}: {e}")
            self._update_training_stats(model_name, None, False, 0)

    def _create_training_batch(self, model_name: str, symbol: str, timeframe: TimeFrame,
                               training_data: List[Tuple[PredictionRecord, float]]) -> Optional[TrainingBatch]:
        """Create a training batch from prediction records and rewards"""
        try:
            states = []
            actions = []
            rewards = []
            next_states = []
            dones = []
            confidences = []
            prediction_records = []

            for prediction_record, reward in training_data:
                # Placeholder state: adapt this to however model states are stored
                # alongside the prediction record
                state = np.zeros(100)
                next_state = state.copy()  # Simplified next state

                # Convert direction to action
                direction = prediction_record.predicted_direction
                action = direction + 1  # Convert -1,0,1 to 0,1,2

                states.append(state)
                actions.append(action)
                rewards.append(reward)
                next_states.append(next_state)
                dones.append(True)  # Each prediction is treated as terminal
                confidences.append(prediction_record.confidence)
                prediction_records.append(prediction_record)

            return TrainingBatch(
                model_name=model_name,
                symbol=symbol,
                timeframe=timeframe,
                states=states,
                actions=actions,
                rewards=rewards,
                next_states=next_states,
                dones=dones,
                confidences=confidences,
                prediction_records=prediction_records,
                batch_timestamp=datetime.now()
            )

        except Exception as e:
            logger.error(f"Error creating training batch: {e}")
            return None

    async def _train_dqn_model(self, batch: TrainingBatch) -> bool:
        """Train DQN model with batch data"""
        try:
            if self.orchestrator and hasattr(self.orchestrator, 'rl_agent'):
                rl_agent = self.orchestrator.rl_agent

                # Add experiences to memory
                for i in range(len(batch.states)):
                    if hasattr(rl_agent, 'remember'):
                        rl_agent.remember(
                            state=batch.states[i],
                            action=batch.actions[i],
                            reward=batch.rewards[i],
                            next_state=batch.next_states[i],
                            done=batch.dones[i]
                        )

                # Trigger training if enough experiences
                if hasattr(rl_agent, 'replay') and hasattr(rl_agent, 'memory'):
                    if len(rl_agent.memory) >= getattr(rl_agent, 'batch_size', 32):
                        loss = rl_agent.replay()
                        if loss is not None:
                            logger.debug(f"DQN training loss: {loss:.6f}")
                            return True

            return False

        except Exception as e:
            logger.error(f"Error training DQN model: {e}")
            return False

    async def _train_cob_rl_model(self, batch: TrainingBatch) -> bool:
        """Train COB RL model with batch data"""
        try:
            if (self.orchestrator and
                    hasattr(self.orchestrator, 'realtime_rl_trader') and
                    self.orchestrator.realtime_rl_trader):

                # Use COB RL trainer if available
                # This is a placeholder - implement based on actual COB RL training interface
                logger.debug(f"COB RL training batch: {len(batch.states)} samples")
                return True

            return False

        except Exception as e:
            logger.error(f"Error training COB RL model: {e}")
            return False

    async def _train_cnn_model(self, batch: TrainingBatch) -> bool:
        """Train CNN model with batch data"""
        try:
            if self.orchestrator and hasattr(self.orchestrator, 'enhanced_training_system'):
                # Use enhanced training system for CNN training
                # This is a placeholder - implement based on actual CNN training interface
                logger.debug(f"CNN training batch: {len(batch.states)} samples")
                return True

            return False

        except Exception as e:
            logger.error(f"Error training CNN model: {e}")
            return False

    def _update_training_stats(self, model_name: str, batch: Optional[TrainingBatch],
                               success: bool, training_time: float):
        """Update training statistics"""
        with self.lock:
            self.training_stats['total_training_batches'] += 1

            if success:
                self.training_stats['successful_training_calls'] += 1
            else:
                self.training_stats['failed_training_calls'] += 1

            self.training_stats['last_training_time'] = datetime.now().isoformat()

            # Model-specific stats
            if model_name not in self.training_stats['training_times_per_model']:
                self.training_stats['training_times_per_model'][model_name] = []
                self.training_stats['average_batch_sizes'][model_name] = []

            self.training_stats['training_times_per_model'][model_name].append(training_time)

            if batch:
                self.training_stats['average_batch_sizes'][model_name].append(len(batch.states))

    def get_training_statistics(self) -> Dict[str, Any]:
        """Get training statistics"""
        with self.lock:
            stats = self.training_stats.copy()

            # Calculate averages
            for model_name in stats['training_times_per_model']:
                times = stats['training_times_per_model'][model_name]
                if times:
                    stats[f'{model_name}_avg_training_time'] = sum(times) / len(times)

                sizes = stats['average_batch_sizes'][model_name]
                if sizes:
                    stats[f'{model_name}_avg_batch_size'] = sum(sizes) / len(sizes)

            return stats
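

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The argument-free construction of
# EnhancedRewardCalculator and TimeframeInferenceCoordinator below is an
# assumption -- consult their modules for the real signatures. Only the adapter
# calls (start_training_loop, get_training_statistics, stop_training_loop) are
# grounded in this file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    async def _demo():
        # Assumed default construction; replace with the real constructor arguments.
        reward_calculator = EnhancedRewardCalculator()
        inference_coordinator = TimeframeInferenceCoordinator()

        adapter = EnhancedRLTrainingAdapter(
            reward_calculator=reward_calculator,
            inference_coordinator=inference_coordinator,
            orchestrator=None,       # optional: trading orchestrator
            training_system=None,    # optional: enhanced realtime training system
        )

        await adapter.start_training_loop()
        # Let a few training checks run (the default check interval is 5 seconds).
        await asyncio.sleep(30)
        print(adapter.get_training_statistics())
        await adapter.stop_training_loop()

    asyncio.run(_demo())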