replay system
This commit is contained in:
529
core/rl_training_pipeline.py
Normal file
529
core/rl_training_pipeline.py
Normal file
@ -0,0 +1,529 @@
|
||||
"""
|
||||
RL Training Pipeline with Comprehensive Experience Storage and Replay
|
||||
|
||||
This module implements a robust RL training pipeline that:
|
||||
1. Stores all training experiences with profitability metrics
|
||||
2. Implements profit-weighted experience replay
|
||||
3. Tracks gradient information for each training step
|
||||
4. Enables retraining on most profitable trading sequences
|
||||
5. Maintains comprehensive trading episode analysis
|
||||
"""
|
||||
|
||||
import logging
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from dataclasses import dataclass, field
|
||||
import json
|
||||
import pickle
|
||||
from collections import deque
|
||||
import threading
|
||||
import random
|
||||
|
||||
from .training_data_collector import get_training_data_collector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class RLExperience:
|
||||
"""Single RL experience with complete state-action-reward information"""
|
||||
experience_id: str
|
||||
timestamp: datetime
|
||||
episode_id: str
|
||||
|
||||
# Core RL components
|
||||
state: np.ndarray
|
||||
action: int # 0=SELL, 1=HOLD, 2=BUY
|
||||
reward: float
|
||||
next_state: np.ndarray
|
||||
done: bool
|
||||
|
||||
# Extended state information
|
||||
market_context: Dict[str, Any]
|
||||
cnn_predictions: Optional[Dict[str, Any]] = None
|
||||
confidence_score: float = 0.0
|
||||
|
||||
# Actual trading outcome
|
||||
actual_profit: Optional[float] = None
|
||||
actual_holding_time: Optional[timedelta] = None
|
||||
optimal_action: Optional[int] = None
|
||||
|
||||
# Experience value for replay
|
||||
experience_value: float = 0.0
|
||||
profitability_score: float = 0.0
|
||||
learning_priority: float = 0.0
|
||||
|
||||
# Training metadata
|
||||
times_trained: int = 0
|
||||
last_trained: Optional[datetime] = None
|
||||
|
||||
class ProfitWeightedExperienceBuffer:
|
||||
"""Experience buffer with profit-weighted sampling for replay"""
|
||||
|
||||
def __init__(self, max_size: int = 100000):
|
||||
self.max_size = max_size
|
||||
self.experiences: Dict[str, RLExperience] = {}
|
||||
self.experience_order: deque = deque(maxlen=max_size)
|
||||
self.profitable_experiences: List[str] = []
|
||||
self.total_experiences = 0
|
||||
self.total_profitable = 0
|
||||
|
||||
def add_experience(self, experience: RLExperience):
|
||||
"""Add experience to buffer"""
|
||||
try:
|
||||
self.experiences[experience.experience_id] = experience
|
||||
self.experience_order.append(experience.experience_id)
|
||||
|
||||
if experience.actual_profit is not None and experience.actual_profit > 0:
|
||||
self.profitable_experiences.append(experience.experience_id)
|
||||
self.total_profitable += 1
|
||||
|
||||
# Remove oldest if buffer is full
|
||||
if len(self.experiences) > self.max_size:
|
||||
oldest_id = self.experience_order[0]
|
||||
if oldest_id in self.experiences:
|
||||
del self.experiences[oldest_id]
|
||||
if oldest_id in self.profitable_experiences:
|
||||
self.profitable_experiences.remove(oldest_id)
|
||||
|
||||
self.total_experiences += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding experience to buffer: {e}")
|
||||
|
||||
def sample_batch(self, batch_size: int, prioritize_profitable: bool = True) -> List[RLExperience]:
|
||||
"""Sample batch with profit-weighted prioritization"""
|
||||
try:
|
||||
if len(self.experiences) < batch_size:
|
||||
return list(self.experiences.values())
|
||||
|
||||
if prioritize_profitable and len(self.profitable_experiences) > batch_size // 2:
|
||||
# Sample mix of profitable and all experiences
|
||||
profitable_sample_size = min(batch_size // 2, len(self.profitable_experiences))
|
||||
remaining_sample_size = batch_size - profitable_sample_size
|
||||
|
||||
profitable_ids = random.sample(self.profitable_experiences, profitable_sample_size)
|
||||
all_ids = list(self.experiences.keys())
|
||||
remaining_ids = random.sample(all_ids, remaining_sample_size)
|
||||
|
||||
sampled_ids = profitable_ids + remaining_ids
|
||||
else:
|
||||
# Random sampling from all experiences
|
||||
all_ids = list(self.experiences.keys())
|
||||
sampled_ids = random.sample(all_ids, batch_size)
|
||||
|
||||
sampled_experiences = [self.experiences[exp_id] for exp_id in sampled_ids]
|
||||
|
||||
# Update training counts
|
||||
for experience in sampled_experiences:
|
||||
experience.times_trained += 1
|
||||
experience.last_trained = datetime.now()
|
||||
|
||||
return sampled_experiences
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error sampling batch: {e}")
|
||||
return list(self.experiences.values())[:batch_size]
|
||||
|
||||
def get_most_profitable_experiences(self, limit: int = 100) -> List[RLExperience]:
|
||||
"""Get most profitable experiences for targeted training"""
|
||||
try:
|
||||
profitable_experiences = [
|
||||
self.experiences[exp_id] for exp_id in self.profitable_experiences
|
||||
if exp_id in self.experiences
|
||||
]
|
||||
|
||||
profitable_experiences.sort(
|
||||
key=lambda x: x.actual_profit if x.actual_profit else 0,
|
||||
reverse=True
|
||||
)
|
||||
|
||||
return profitable_experiences[:limit]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting profitable experiences: {e}")
|
||||
return []
|
||||
|
||||
class RLTradingAgent(nn.Module):
|
||||
"""RL Trading Agent with comprehensive state processing"""
|
||||
|
||||
def __init__(self, state_dim: int = 2000, action_dim: int = 3, hidden_dim: int = 512):
|
||||
super(RLTradingAgent, self).__init__()
|
||||
|
||||
self.state_dim = state_dim
|
||||
self.action_dim = action_dim
|
||||
self.hidden_dim = hidden_dim
|
||||
|
||||
# State processing network
|
||||
self.state_processor = nn.Sequential(
|
||||
nn.Linear(state_dim, hidden_dim),
|
||||
nn.LayerNorm(hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.2),
|
||||
nn.Linear(hidden_dim, hidden_dim // 2),
|
||||
nn.LayerNorm(hidden_dim // 2),
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
# Q-value network
|
||||
self.q_network = nn.Sequential(
|
||||
nn.Linear(hidden_dim // 2, hidden_dim // 4),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(hidden_dim // 4, action_dim)
|
||||
)
|
||||
|
||||
# Policy network
|
||||
self.policy_network = nn.Sequential(
|
||||
nn.Linear(hidden_dim // 2, hidden_dim // 4),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(hidden_dim // 4, action_dim),
|
||||
nn.Softmax(dim=-1)
|
||||
)
|
||||
|
||||
# Value network
|
||||
self.value_network = nn.Sequential(
|
||||
nn.Linear(hidden_dim // 2, hidden_dim // 4),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(hidden_dim // 4, 1)
|
||||
)
|
||||
|
||||
def forward(self, state):
|
||||
"""Forward pass through the agent"""
|
||||
processed_state = self.state_processor(state)
|
||||
|
||||
q_values = self.q_network(processed_state)
|
||||
policy_probs = self.policy_network(processed_state)
|
||||
state_value = self.value_network(processed_state)
|
||||
|
||||
return {
|
||||
'q_values': q_values,
|
||||
'policy_probs': policy_probs,
|
||||
'state_value': state_value,
|
||||
'processed_state': processed_state
|
||||
}
|
||||
|
||||
def select_action(self, state, epsilon: float = 0.1) -> Tuple[int, float]:
|
||||
"""Select action using epsilon-greedy policy"""
|
||||
self.eval()
|
||||
with torch.no_grad():
|
||||
if isinstance(state, np.ndarray):
|
||||
state = torch.from_numpy(state).float().unsqueeze(0)
|
||||
|
||||
outputs = self.forward(state)
|
||||
|
||||
if random.random() < epsilon:
|
||||
action = random.randint(0, self.action_dim - 1)
|
||||
confidence = 0.33
|
||||
else:
|
||||
q_values = outputs['q_values']
|
||||
action = torch.argmax(q_values, dim=1).item()
|
||||
q_softmax = F.softmax(q_values, dim=1)
|
||||
confidence = torch.max(q_softmax).item()
|
||||
|
||||
return action, confidence
|
||||
|
||||
@dataclass
|
||||
class RLTrainingStep:
|
||||
"""Single RL training step with backpropagation data"""
|
||||
step_id: str
|
||||
timestamp: datetime
|
||||
batch_experiences: List[str]
|
||||
|
||||
# Training data
|
||||
total_loss: float
|
||||
q_loss: float
|
||||
policy_loss: float
|
||||
|
||||
# Gradients
|
||||
gradients: Dict[str, torch.Tensor]
|
||||
gradient_norms: Dict[str, float]
|
||||
|
||||
# Metadata
|
||||
learning_rate: float = 0.001
|
||||
batch_size: int = 32
|
||||
|
||||
# Performance
|
||||
batch_profitability: float = 0.0
|
||||
correct_actions: int = 0
|
||||
total_actions: int = 0
|
||||
step_value: float = 0.0
|
||||
|
||||
@dataclass
|
||||
class RLTrainingSession:
|
||||
"""Complete RL training session"""
|
||||
session_id: str
|
||||
start_timestamp: datetime
|
||||
end_timestamp: Optional[datetime] = None
|
||||
|
||||
training_mode: str = 'experience_replay'
|
||||
symbol: str = ''
|
||||
|
||||
training_steps: List[RLTrainingStep] = field(default_factory=list)
|
||||
|
||||
total_steps: int = 0
|
||||
average_loss: float = 0.0
|
||||
best_loss: float = float('inf')
|
||||
|
||||
profitable_actions: int = 0
|
||||
total_actions: int = 0
|
||||
profitability_rate: float = 0.0
|
||||
session_value: float = 0.0
|
||||
|
||||
class RLTrainer:
|
||||
"""RL trainer with comprehensive experience storage and replay"""
|
||||
|
||||
def __init__(self, agent: RLTradingAgent, device: str = 'cuda', storage_dir: str = "rl_training_storage"):
|
||||
self.agent = agent.to(device)
|
||||
self.device = device
|
||||
self.storage_dir = Path(storage_dir)
|
||||
self.storage_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.optimizer = torch.optim.AdamW(agent.parameters(), lr=0.001)
|
||||
self.experience_buffer = ProfitWeightedExperienceBuffer()
|
||||
self.data_collector = get_training_data_collector()
|
||||
|
||||
self.training_sessions: List[RLTrainingSession] = []
|
||||
self.current_session: Optional[RLTrainingSession] = None
|
||||
|
||||
self.gamma = 0.99
|
||||
|
||||
self.training_stats = {
|
||||
'total_sessions': 0,
|
||||
'total_steps': 0,
|
||||
'total_experiences': 0,
|
||||
'profitable_actions': 0,
|
||||
'total_actions': 0,
|
||||
'average_reward': 0.0
|
||||
}
|
||||
|
||||
logger.info(f"RL Trainer initialized with {sum(p.numel() for p in agent.parameters()):,} parameters")
|
||||
|
||||
def add_experience(self, state: np.ndarray, action: int, reward: float,
|
||||
next_state: np.ndarray, done: bool, market_context: Dict[str, Any],
|
||||
cnn_predictions: Dict[str, Any] = None, confidence_score: float = 0.0) -> str:
|
||||
"""Add experience to the buffer"""
|
||||
try:
|
||||
experience_id = f"exp_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}"
|
||||
|
||||
experience = RLExperience(
|
||||
experience_id=experience_id,
|
||||
timestamp=datetime.now(),
|
||||
episode_id=market_context.get('episode_id', 'unknown'),
|
||||
state=state,
|
||||
action=action,
|
||||
reward=reward,
|
||||
next_state=next_state,
|
||||
done=done,
|
||||
market_context=market_context,
|
||||
cnn_predictions=cnn_predictions,
|
||||
confidence_score=confidence_score
|
||||
)
|
||||
|
||||
self.experience_buffer.add_experience(experience)
|
||||
self.training_stats['total_experiences'] += 1
|
||||
|
||||
return experience_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding experience: {e}")
|
||||
return None
|
||||
|
||||
def train_on_experiences(self, batch_size: int = 32, num_batches: int = 10) -> Dict[str, Any]:
|
||||
"""Train on experiences with comprehensive data storage"""
|
||||
try:
|
||||
session = RLTrainingSession(
|
||||
session_id=f"rl_training_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
|
||||
start_timestamp=datetime.now(),
|
||||
training_mode='experience_replay'
|
||||
)
|
||||
self.current_session = session
|
||||
|
||||
self.agent.train()
|
||||
total_loss = 0.0
|
||||
|
||||
for batch_idx in range(num_batches):
|
||||
experiences = self.experience_buffer.sample_batch(batch_size, True)
|
||||
|
||||
if len(experiences) < batch_size:
|
||||
continue
|
||||
|
||||
# Prepare batch tensors
|
||||
states = torch.FloatTensor([exp.state for exp in experiences]).to(self.device)
|
||||
actions = torch.LongTensor([exp.action for exp in experiences]).to(self.device)
|
||||
rewards = torch.FloatTensor([exp.reward for exp in experiences]).to(self.device)
|
||||
next_states = torch.FloatTensor([exp.next_state for exp in experiences]).to(self.device)
|
||||
dones = torch.BoolTensor([exp.done for exp in experiences]).to(self.device)
|
||||
|
||||
# Forward pass
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
current_outputs = self.agent(states)
|
||||
current_q_values = current_outputs['q_values']
|
||||
|
||||
# Calculate target Q-values
|
||||
with torch.no_grad():
|
||||
next_outputs = self.agent(next_states)
|
||||
next_q_values = next_outputs['q_values']
|
||||
max_next_q_values = torch.max(next_q_values, dim=1)[0]
|
||||
target_q_values = rewards + (self.gamma * max_next_q_values * ~dones)
|
||||
|
||||
# Calculate loss
|
||||
current_q_values_for_actions = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
|
||||
q_loss = F.mse_loss(current_q_values_for_actions, target_q_values)
|
||||
|
||||
# Backward pass
|
||||
q_loss.backward()
|
||||
|
||||
# Store gradients
|
||||
gradients = {}
|
||||
gradient_norms = {}
|
||||
for name, param in self.agent.named_parameters():
|
||||
if param.grad is not None:
|
||||
gradients[name] = param.grad.clone().detach()
|
||||
gradient_norms[name] = param.grad.norm().item()
|
||||
|
||||
torch.nn.utils.clip_grad_norm_(self.agent.parameters(), max_norm=1.0)
|
||||
self.optimizer.step()
|
||||
|
||||
# Create training step record
|
||||
step = RLTrainingStep(
|
||||
step_id=f"{session.session_id}_step_{batch_idx}",
|
||||
timestamp=datetime.now(),
|
||||
batch_experiences=[exp.experience_id for exp in experiences],
|
||||
total_loss=q_loss.item(),
|
||||
q_loss=q_loss.item(),
|
||||
policy_loss=0.0,
|
||||
gradients=gradients,
|
||||
gradient_norms=gradient_norms,
|
||||
batch_size=len(experiences)
|
||||
)
|
||||
|
||||
session.training_steps.append(step)
|
||||
total_loss += q_loss.item()
|
||||
|
||||
# Finalize session
|
||||
session.end_timestamp = datetime.now()
|
||||
session.total_steps = num_batches
|
||||
session.average_loss = total_loss / num_batches if num_batches > 0 else 0.0
|
||||
|
||||
self._save_training_session(session)
|
||||
|
||||
self.training_stats['total_sessions'] += 1
|
||||
self.training_stats['total_steps'] += session.total_steps
|
||||
|
||||
logger.info(f"RL training session completed: {session.session_id}")
|
||||
logger.info(f"Average loss: {session.average_loss:.4f}")
|
||||
|
||||
return {
|
||||
'status': 'success',
|
||||
'session_id': session.session_id,
|
||||
'average_loss': session.average_loss,
|
||||
'total_steps': session.total_steps
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in RL training session: {e}")
|
||||
return {'status': 'error', 'error': str(e)}
|
||||
finally:
|
||||
self.current_session = None
|
||||
|
||||
def train_on_profitable_experiences(self, min_profitability: float = 0.1,
|
||||
max_experiences: int = 1000, batch_size: int = 32) -> Dict[str, Any]:
|
||||
"""Train specifically on most profitable experiences"""
|
||||
try:
|
||||
profitable_experiences = self.experience_buffer.get_most_profitable_experiences(max_experiences)
|
||||
|
||||
filtered_experiences = [
|
||||
exp for exp in profitable_experiences
|
||||
if exp.actual_profit is not None and exp.actual_profit >= min_profitability
|
||||
]
|
||||
|
||||
if len(filtered_experiences) < batch_size:
|
||||
return {'status': 'insufficient_data', 'experiences_found': len(filtered_experiences)}
|
||||
|
||||
logger.info(f"Training on {len(filtered_experiences)} profitable experiences")
|
||||
|
||||
num_batches = len(filtered_experiences) // batch_size
|
||||
|
||||
# Temporarily replace buffer sampling
|
||||
original_sample_method = self.experience_buffer.sample_batch
|
||||
|
||||
def profitable_sample_batch(batch_size, prioritize_profitable=True):
|
||||
return random.sample(filtered_experiences, min(batch_size, len(filtered_experiences)))
|
||||
|
||||
self.experience_buffer.sample_batch = profitable_sample_batch
|
||||
|
||||
try:
|
||||
results = self.train_on_experiences(batch_size=batch_size, num_batches=num_batches)
|
||||
results['training_mode'] = 'profitable_replay'
|
||||
results['experiences_used'] = len(filtered_experiences)
|
||||
return results
|
||||
finally:
|
||||
self.experience_buffer.sample_batch = original_sample_method
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error training on profitable experiences: {e}")
|
||||
return {'status': 'error', 'error': str(e)}
|
||||
|
||||
def _save_training_session(self, session: RLTrainingSession):
|
||||
"""Save training session to disk"""
|
||||
try:
|
||||
session_dir = self.storage_dir / 'sessions'
|
||||
session_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
session_file = session_dir / f"{session.session_id}.pkl"
|
||||
with open(session_file, 'wb') as f:
|
||||
pickle.dump(session, f)
|
||||
|
||||
metadata = {
|
||||
'session_id': session.session_id,
|
||||
'start_timestamp': session.start_timestamp.isoformat(),
|
||||
'end_timestamp': session.end_timestamp.isoformat() if session.end_timestamp else None,
|
||||
'training_mode': session.training_mode,
|
||||
'total_steps': session.total_steps,
|
||||
'average_loss': session.average_loss
|
||||
}
|
||||
|
||||
metadata_file = session_dir / f"{session.session_id}_metadata.json"
|
||||
with open(metadata_file, 'w') as f:
|
||||
json.dump(metadata, f, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving training session: {e}")
|
||||
|
||||
def get_training_statistics(self) -> Dict[str, Any]:
|
||||
"""Get comprehensive training statistics"""
|
||||
stats = self.training_stats.copy()
|
||||
|
||||
if self.training_sessions:
|
||||
recent_sessions = sorted(self.training_sessions, key=lambda x: x.start_timestamp, reverse=True)[:10]
|
||||
stats['recent_sessions'] = [
|
||||
{
|
||||
'session_id': s.session_id,
|
||||
'timestamp': s.start_timestamp.isoformat(),
|
||||
'mode': s.training_mode,
|
||||
'average_loss': s.average_loss
|
||||
}
|
||||
for s in recent_sessions
|
||||
]
|
||||
|
||||
return stats
|
||||
|
||||
# Global instance
|
||||
rl_trainer = None
|
||||
|
||||
def get_rl_trainer(agent: RLTradingAgent = None) -> RLTrainer:
|
||||
"""Get global RL trainer instance"""
|
||||
global rl_trainer
|
||||
if rl_trainer is None:
|
||||
if agent is None:
|
||||
agent = RLTradingAgent()
|
||||
rl_trainer = RLTrainer(agent)
|
||||
return rl_trainer
|
Reference in New Issue
Block a user