gogo2/core/enhanced_rl_training_adapter.py
"""
Enhanced RL Training Adapter
This module integrates the new MSE-based reward system with existing RL training pipelines.
It provides a bridge between the timeframe-aware inference coordinator and the existing
model training infrastructure.
Key Features:
- Integration with EnhancedRewardCalculator
- Adaptation of existing RL models to new reward system
- Real-time training triggers based on prediction outcomes
- Multi-timeframe training coordination
- Backward compatibility with existing training infrastructure
"""
import asyncio
import logging
import time
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Tuple, Union
from dataclasses import dataclass
import numpy as np
import threading
from core.enhanced_reward_calculator import EnhancedRewardCalculator, TimeFrame, PredictionRecord
from core.timeframe_inference_coordinator import TimeframeInferenceCoordinator, InferenceContext
logger = logging.getLogger(__name__)
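# Minimal wiring sketch (illustrative; the EnhancedRewardCalculator and
# TimeframeInferenceCoordinator constructor arguments are assumed to follow
# their own modules, and `orchestrator` is whatever trading orchestrator the
# deployment provides):
#
#   calculator = EnhancedRewardCalculator(...)
#   coordinator = TimeframeInferenceCoordinator(...)
#   adapter = EnhancedRLTrainingAdapter(calculator, coordinator, orchestrator=orchestrator)
#   await adapter.start_training_loop()   # inside a running event loop
#   ...
#   await adapter.stop_training_loop()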
@dataclass
class TrainingBatch:
"""Training batch for RL models with enhanced reward data"""
model_name: str
symbol: str
timeframe: TimeFrame
states: List[np.ndarray]
actions: List[int]
rewards: List[float]
next_states: List[np.ndarray]
dones: List[bool]
confidences: List[float]
prediction_records: List[PredictionRecord]
batch_timestamp: datetime
class EnhancedRLTrainingAdapter:
"""
Adapter that integrates new reward system with existing RL training infrastructure
This adapter:
1. Bridges new reward calculator with existing RL models
2. Converts prediction records to RL training format
3. Triggers real-time training based on reward evaluation
4. Maintains compatibility with existing training systems
5. Coordinates multi-timeframe training
"""
def __init__(self,
reward_calculator: EnhancedRewardCalculator,
inference_coordinator: TimeframeInferenceCoordinator,
orchestrator: Any = None,
training_system: Any = None):
"""
Initialize the enhanced RL training adapter
Args:
reward_calculator: Enhanced reward calculator instance
inference_coordinator: Timeframe inference coordinator
orchestrator: Trading orchestrator (optional)
training_system: Enhanced realtime training system (optional)
"""
self.reward_calculator = reward_calculator
self.inference_coordinator = inference_coordinator
self.orchestrator = orchestrator
self.training_system = training_system
# Model registry for training functions
self.model_trainers: Dict[str, Any] = {}
# Training configuration
self.min_batch_size = 8 # Minimum samples for training
self.max_batch_size = 64 # Maximum samples per training batch
self.training_interval_seconds = 5.0 # How often to check for training opportunities
# Training statistics
self.training_stats = {
'total_training_batches': 0,
'successful_training_calls': 0,
'failed_training_calls': 0,
'last_training_time': None,
'training_times_per_model': {},
'average_batch_sizes': {}
}
# State conversion helpers
self.state_builders: Dict[str, Any] = {}
# Thread safety
self.lock = threading.RLock()
# Running state
self.running = False
self.training_task: Optional[asyncio.Task] = None
logger.info("EnhancedRLTrainingAdapter initialized")
self._register_default_model_handlers()
def _register_default_model_handlers(self):
"""Register default model handlers for existing models"""
# Register inference functions with the coordinator
if self.inference_coordinator:
self.inference_coordinator.register_model_inference_function(
'dqn_agent', self._dqn_inference_wrapper
)
self.inference_coordinator.register_model_inference_function(
'cob_rl', self._cob_rl_inference_wrapper
)
self.inference_coordinator.register_model_inference_function(
'enhanced_cnn', self._cnn_inference_wrapper
)
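# Note: the keys above ('dqn_agent', 'cob_rl', 'enhanced_cnn') are assumed to be
# the model names the coordinator attaches to its PredictionRecord entries; the
# substring routing in _train_model_batch ('dqn' / 'cob' / 'cnn') depends on
# those names being preserved.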
async def _dqn_inference_wrapper(self, context: InferenceContext) -> Optional[Dict[str, Any]]:
"""Wrapper for DQN model inference"""
try:
if self.orchestrator and hasattr(self.orchestrator, 'rl_agent'):
# Get base data for the symbol
base_data = await self._get_base_data(context.symbol)
if base_data is None:
return None
# Convert to DQN state format
state = self._convert_to_dqn_state(base_data, context)
# Run DQN prediction
if hasattr(self.orchestrator.rl_agent, 'act'):
action_idx = self.orchestrator.rl_agent.act(state)
# Try to extract confidence from agent if available
confidence = getattr(self.orchestrator.rl_agent, 'last_confidence', None)
if confidence is None:
confidence = 0.5
# Convert action to prediction format
action_names = ['SELL', 'HOLD', 'BUY']
direction = action_idx - 1 # Convert 0,1,2 to -1,0,1
# Use real current price
current_price = self._safe_get_current_price(context.symbol)
# The DQN does not emit a numeric price target, so fall back to the current price rather than fabricating one
return {
'predicted_price': current_price, # same as current when no numeric target available
'current_price': current_price,
'direction': direction,
'confidence': float(confidence),
'action': action_names[action_idx],
'model_state': (state.tolist() if hasattr(state, 'tolist') else state),
'context': context
}
except Exception as e:
logger.error(f"Error in DQN inference wrapper: {e}")
return None
async def _cob_rl_inference_wrapper(self, context: InferenceContext) -> Optional[Dict[str, Any]]:
"""Wrapper for COB RL model inference"""
try:
if (self.orchestrator and
hasattr(self.orchestrator, 'realtime_rl_trader') and
self.orchestrator.realtime_rl_trader):
# Get COB features
features = await self._get_cob_features(context.symbol)
if features is None:
return None
# Run COB RL prediction
prediction = self.orchestrator.realtime_rl_trader._predict(context.symbol, features)
if prediction:
current_price = self._safe_get_current_price(context.symbol)
# If 'change' is available, treat it as a fractional return on the current price
change = prediction.get('change', None)
predicted_price = current_price * (1 + change) if (change is not None and current_price) else current_price
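# Worked example: change = +0.002 with current_price = 2500.0 implies a target
# of 2500.0 * 1.002 = 2505.0; if either value is missing, predicted_price falls
# back to current_price.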
return {
'predicted_price': predicted_price,
'current_price': current_price,
'direction': prediction.get('direction', 0),
'confidence': prediction.get('confidence', 0.0),
'change': prediction.get('change', 0.0),
'model_features': features,
'context': context
}
except Exception as e:
logger.error(f"Error in COB RL inference wrapper: {e}")
return None
async def _cnn_inference_wrapper(self, context: InferenceContext) -> Optional[Dict[str, Any]]:
"""Wrapper for CNN model inference"""
try:
if self.orchestrator and hasattr(self.orchestrator, 'model_registry'):
# Find CNN models in registry
for model_name, model in self.orchestrator.model_registry.models.items():
if 'cnn' in model_name.lower():
# Get base data
base_data = await self._get_base_data(context.symbol)
if base_data is None:
continue
# Run CNN prediction
if hasattr(model, 'predict_from_base_input'):
model_output = model.predict_from_base_input(base_data)
# Extract current price from data provider
current_price = self._safe_get_current_price(context.symbol)
# Extract prediction data
predictions = model_output.predictions
action = predictions.get('action', 'HOLD')
confidence = predictions.get('confidence', 0.0)
# Convert action to direction only for classification signal
direction = {'BUY': 1, 'SELL': -1, 'HOLD': 0}.get(action, 0)
# Use numeric predicted return if provided (no synthetic fabrication)
pr_map = {
TimeFrame.SECONDS_1: 'predicted_return_1s',
TimeFrame.MINUTES_1: 'predicted_return_1m',
TimeFrame.HOURS_1: 'predicted_return_1h',
TimeFrame.DAYS_1: 'predicted_return_1d',
}
ret_key = pr_map.get(context.target_timeframe)
predicted_return = None
if ret_key and ret_key in predictions:
predicted_return = float(predictions.get(ret_key))
predicted_price = current_price * (1 + predicted_return) if (predicted_return is not None and current_price) else current_price
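# As in the COB wrapper above, a missing or non-numeric return leaves
# predicted_price equal to current_price rather than fabricating a target; the
# direction and confidence fields still carry the classification signal.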
# Also attach DQN-formatted state if available for training consumption
dqn_state = self._convert_to_dqn_state(base_data, context)
return {
'predicted_price': predicted_price,
'current_price': current_price,
'direction': direction,
'confidence': confidence,
'predicted_return': predicted_return,
'action': action,
'model_output': model_output,
'model_state': (dqn_state.tolist() if hasattr(dqn_state, 'tolist') else dqn_state),
'context': context
}
except Exception as e:
logger.error(f"Error in CNN inference wrapper: {e}")
return None
async def _get_base_data(self, symbol: str) -> Optional[Any]:
"""Get base data for a symbol"""
try:
if self.orchestrator and hasattr(self.orchestrator, 'data_provider'):
# Delegate to the orchestrator's base-data builder
return await self.orchestrator._build_base_data(symbol)
except Exception as e:
logger.debug(f"Error getting base data for {symbol}: {e}")
return None
async def _get_cob_features(self, symbol: str) -> Optional[np.ndarray]:
"""Get COB features for a symbol"""
try:
if (self.orchestrator and
hasattr(self.orchestrator, 'realtime_rl_trader') and
self.orchestrator.realtime_rl_trader):
# Get latest features from COB trader
feature_buffers = self.orchestrator.realtime_rl_trader.feature_buffers
if symbol in feature_buffers and feature_buffers[symbol]:
latest_data = feature_buffers[symbol][-1]
return latest_data.get('features')
except Exception as e:
logger.debug(f"Error getting COB features for {symbol}: {e}")
return None
def _safe_get_current_price(self, symbol: str) -> float:
"""Get current price for a symbol via DataProvider API"""
try:
if self.orchestrator and hasattr(self.orchestrator, 'data_provider'):
price = self.orchestrator.data_provider.get_current_price(symbol)
return float(price) if price is not None else 0.0
except Exception as e:
logger.debug(f"Error getting current price for {symbol}: {e}")
return 0.0
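# Note: a 0.0 return value is treated downstream as "price unavailable"; the
# inference wrappers guard their predicted-price arithmetic with
# `... if (change is not None and current_price)`, so a failed lookup degrades
# to predicted_price == current_price instead of raising.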
def _convert_to_dqn_state(self, base_data: Any, context: InferenceContext) -> np.ndarray:
"""Convert base data to DQN state format"""
try:
# Use existing state building logic if available
if (self.orchestrator and
hasattr(self.orchestrator, 'enhanced_training_system') and
hasattr(self.orchestrator.enhanced_training_system, '_build_dqn_state')):
return self.orchestrator.enhanced_training_system._build_dqn_state(
base_data, context.symbol
)
# Fallback: create simple state representation
feature_vector = base_data.get_feature_vector() if hasattr(base_data, 'get_feature_vector') else []
if feature_vector:
return np.array(feature_vector, dtype=np.float32)
# Last resort: create minimal state
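# (assumption: the 100-dim zero vector matches the DQN's expected observation
# size; if the agent was built with a different input dimension, size this
# fallback accordingly)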
return np.zeros(100, dtype=np.float32)
except Exception as e:
logger.error(f"Error converting to DQN state: {e}")
return np.zeros(100, dtype=np.float32)
async def start_training_loop(self):
"""Start the enhanced training loop"""
if self.running:
logger.warning("Training loop already running")
return
self.running = True
self.training_task = asyncio.create_task(self._training_loop())
logger.info("Enhanced RL training loop started")
async def stop_training_loop(self):
"""Stop the enhanced training loop"""
if not self.running:
return
self.running = False
if self.training_task:
self.training_task.cancel()
try:
await self.training_task
except asyncio.CancelledError:
pass
logger.info("Enhanced RL training loop stopped")
async def _training_loop(self):
"""Main training loop that processes evaluated predictions"""
logger.info("Starting enhanced RL training loop")
while self.running:
try:
# Process training for each symbol and timeframe
for symbol in self.reward_calculator.symbols:
for timeframe in [TimeFrame.SECONDS_1, TimeFrame.MINUTES_1,
TimeFrame.HOURS_1, TimeFrame.DAYS_1]:
# Get training data for this symbol/timeframe
training_data = self.reward_calculator.get_training_data(
symbol, timeframe, self.max_batch_size
)
if len(training_data) >= self.min_batch_size:
await self._process_training_batch(symbol, timeframe, training_data)
# Sleep between training checks
await asyncio.sleep(self.training_interval_seconds)
except Exception as e:
logger.error(f"Error in training loop: {e}")
await asyncio.sleep(10) # Wait longer on error
async def _process_training_batch(self, symbol: str, timeframe: TimeFrame,
training_data: List[Tuple[PredictionRecord, float]]):
"""
Process a training batch for a specific symbol/timeframe
Args:
symbol: Trading symbol
timeframe: Timeframe for training
training_data: List of (prediction_record, reward) tuples
"""
try:
# Group training data by model
model_batches = {}
for prediction_record, reward in training_data:
model_name = prediction_record.model_name
if model_name not in model_batches:
model_batches[model_name] = []
model_batches[model_name].append((prediction_record, reward))
# Process each model's batch
for model_name, model_data in model_batches.items():
if len(model_data) >= self.min_batch_size:
await self._train_model_batch(model_name, symbol, timeframe, model_data)
except Exception as e:
logger.error(f"Error processing training batch for {symbol} {timeframe.value}: {e}")
async def _train_model_batch(self, model_name: str, symbol: str, timeframe: TimeFrame,
training_data: List[Tuple[PredictionRecord, float]]):
"""
Train a specific model with a batch of data
Args:
model_name: Name of the model to train
symbol: Trading symbol
timeframe: Timeframe for training
training_data: List of (prediction_record, reward) tuples
"""
try:
training_start = time.time()
# Convert to training batch format
batch = self._create_training_batch(model_name, symbol, timeframe, training_data)
if batch is None:
return
# Call appropriate training function based on model type
success = False
if 'dqn' in model_name.lower():
success = await self._train_dqn_model(batch)
elif 'cob' in model_name.lower():
success = await self._train_cob_rl_model(batch)
elif 'cnn' in model_name.lower():
success = await self._train_cnn_model(batch)
else:
logger.warning(f"Unknown model type for training: {model_name}")
# Update statistics
training_time = time.time() - training_start
self._update_training_stats(model_name, batch, success, training_time)
if success:
logger.info(f"Successfully trained {model_name} on {symbol} {timeframe.value} "
f"with {len(training_data)} samples in {training_time:.3f}s")
except Exception as e:
logger.error(f"Error training model {model_name}: {e}")
self._update_training_stats(model_name, None, False, 0)
def _create_training_batch(self, model_name: str, symbol: str, timeframe: TimeFrame,
training_data: List[Tuple[PredictionRecord, float]]) -> Optional[TrainingBatch]:
"""Create a training batch from prediction records and rewards"""
try:
states = []
actions = []
rewards = []
next_states = []
dones = []
confidences = []
prediction_records = []
for prediction_record, reward in training_data:
# Placeholder state: model input snapshots are not yet persisted with the
# prediction record, so a zero vector stands in (see the note below)
state = np.zeros(100, dtype=np.float32)
next_state = state.copy() # Simplified next state
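# Sketch of a possible improvement (not wired up here): the inference wrappers
# above already attach a 'model_state' entry "for training consumption", so if
# the reward calculator persists that field on the PredictionRecord, it could
# replace these zero placeholders, e.g.:
#
#   stored = getattr(prediction_record, 'model_state', None)  # hypothetical field
#   if stored is not None:
#       state = np.array(stored, dtype=np.float32)
#       next_state = state.copy()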
# Convert direction to action
direction = prediction_record.predicted_direction
action = direction + 1 # Convert -1,0,1 to 0,1,2
states.append(state)
actions.append(action)
rewards.append(reward)
next_states.append(next_state)
dones.append(True) # Each prediction is treated as terminal
confidences.append(prediction_record.confidence)
prediction_records.append(prediction_record)
return TrainingBatch(
model_name=model_name,
symbol=symbol,
timeframe=timeframe,
states=states,
actions=actions,
rewards=rewards,
next_states=next_states,
dones=dones,
confidences=confidences,
prediction_records=prediction_records,
batch_timestamp=datetime.now()
)
except Exception as e:
logger.error(f"Error creating training batch: {e}")
return None
async def _train_dqn_model(self, batch: TrainingBatch) -> bool:
"""Train DQN model with batch data"""
try:
if self.orchestrator and hasattr(self.orchestrator, 'rl_agent'):
rl_agent = self.orchestrator.rl_agent
# Add experiences to memory
for i in range(len(batch.states)):
if hasattr(rl_agent, 'remember'):
rl_agent.remember(
state=batch.states[i],
action=batch.actions[i],
reward=batch.rewards[i],
next_state=batch.next_states[i],
done=batch.dones[i]
)
# Trigger training if enough experiences
if hasattr(rl_agent, 'replay') and hasattr(rl_agent, 'memory'):
if len(rl_agent.memory) >= getattr(rl_agent, 'batch_size', 32):
loss = rl_agent.replay()
if loss is not None:
logger.debug(f"DQN training loss: {loss:.6f}")
return True
return False
except Exception as e:
logger.error(f"Error training DQN model: {e}")
return False
async def _train_cob_rl_model(self, batch: TrainingBatch) -> bool:
"""Train COB RL model with batch data"""
try:
if (self.orchestrator and
hasattr(self.orchestrator, 'realtime_rl_trader') and
self.orchestrator.realtime_rl_trader):
# Use COB RL trainer if available
# This is a placeholder - implement based on actual COB RL training interface
logger.debug(f"COB RL training batch: {len(batch.states)} samples")
return True
return False
except Exception as e:
logger.error(f"Error training COB RL model: {e}")
return False
async def _train_cnn_model(self, batch: TrainingBatch) -> bool:
"""Train CNN model with batch data"""
try:
if self.orchestrator and hasattr(self.orchestrator, 'enhanced_training_system'):
# Use enhanced training system for CNN training
# This is a placeholder - implement based on actual CNN training interface
logger.debug(f"CNN training batch: {len(batch.states)} samples")
return True
return False
except Exception as e:
logger.error(f"Error training CNN model: {e}")
return False
def _update_training_stats(self, model_name: str, batch: Optional[TrainingBatch],
success: bool, training_time: float):
"""Update training statistics"""
with self.lock:
self.training_stats['total_training_batches'] += 1
if success:
self.training_stats['successful_training_calls'] += 1
else:
self.training_stats['failed_training_calls'] += 1
self.training_stats['last_training_time'] = datetime.now().isoformat()
# Model-specific stats
if model_name not in self.training_stats['training_times_per_model']:
self.training_stats['training_times_per_model'][model_name] = []
self.training_stats['average_batch_sizes'][model_name] = []
self.training_stats['training_times_per_model'][model_name].append(training_time)
if batch:
self.training_stats['average_batch_sizes'][model_name].append(len(batch.states))
def get_training_statistics(self) -> Dict[str, Any]:
"""Get training statistics"""
with self.lock:
stats = self.training_stats.copy()
# Calculate averages
for model_name in stats['training_times_per_model']:
times = stats['training_times_per_model'][model_name]
if times:
stats[f'{model_name}_avg_training_time'] = sum(times) / len(times)
sizes = stats['average_batch_sizes'][model_name]
if sizes:
stats[f'{model_name}_avg_batch_size'] = sum(sizes) / len(sizes)
return stats
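# Example of consuming the statistics (a sketch; per-model keys such as
# 'dqn_agent_avg_training_time' appear only after that model has trained at
# least once, and the exact model names depend on what the coordinator reports):
#
#   stats = adapter.get_training_statistics()
#   print(stats['successful_training_calls'], stats['failed_training_calls'])
#   print(stats.get('dqn_agent_avg_training_time'))  # None until a DQN batch has run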