"""
|
|
RL Training Pipeline with Comprehensive Experience Storage and Replay
|
|
|
|
This module implements a robust RL training pipeline that:
|
|
1. Stores all training experiences with profitability metrics
|
|
2. Implements profit-weighted experience replay
|
|
3. Tracks gradient information for each training step
|
|
4. Enables retraining on most profitable trading sequences
|
|
5. Maintains comprehensive trading episode analysis
|
|
"""
|
|
|
|
import logging
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from datetime import datetime, timedelta
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional, Tuple, Any
|
|
from dataclasses import dataclass, field
|
|
import json
|
|
import pickle
|
|
from collections import deque
|
|
import threading
|
|
import random
|
|
|
|
from .training_data_collector import get_training_data_collector
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|


@dataclass
class RLExperience:
    """Single RL experience with complete state-action-reward information"""
    experience_id: str
    timestamp: datetime
    episode_id: str

    # Core RL components
    state: np.ndarray
    action: int  # 0=SELL, 1=HOLD, 2=BUY
    reward: float
    next_state: np.ndarray
    done: bool

    # Extended state information
    market_context: Dict[str, Any]
    cnn_predictions: Optional[Dict[str, Any]] = None
    confidence_score: float = 0.0

    # Actual trading outcome
    actual_profit: Optional[float] = None
    actual_holding_time: Optional[timedelta] = None
    optimal_action: Optional[int] = None

    # Experience value for replay
    experience_value: float = 0.0
    profitability_score: float = 0.0
    learning_priority: float = 0.0

    # Training metadata
    times_trained: int = 0
    last_trained: Optional[datetime] = None


class ProfitWeightedExperienceBuffer:
    """Experience buffer with profit-weighted sampling for replay"""

    def __init__(self, max_size: int = 100000):
        self.max_size = max_size
        self.experiences: Dict[str, RLExperience] = {}
        self.experience_order: deque = deque(maxlen=max_size)
        self.profitable_experiences: List[str] = []
        self.total_experiences = 0
        self.total_profitable = 0

    def add_experience(self, experience: RLExperience):
        """Add experience to the buffer, evicting the oldest entry when full"""
        try:
            # Evict the oldest experience before appending: the bounded deque
            # would otherwise drop the oldest ID silently, leaving a stale
            # entry behind in the experiences dict.
            if len(self.experience_order) >= self.max_size:
                oldest_id = self.experience_order[0]
                self.experiences.pop(oldest_id, None)
                if oldest_id in self.profitable_experiences:
                    self.profitable_experiences.remove(oldest_id)

            self.experiences[experience.experience_id] = experience
            self.experience_order.append(experience.experience_id)

            if experience.actual_profit is not None and experience.actual_profit > 0:
                self.profitable_experiences.append(experience.experience_id)
                self.total_profitable += 1

            self.total_experiences += 1

        except Exception as e:
            logger.error(f"Error adding experience to buffer: {e}")

    def sample_batch(self, batch_size: int, prioritize_profitable: bool = True) -> List[RLExperience]:
        """Sample a batch with profit-weighted prioritization"""
        try:
            if len(self.experiences) < batch_size:
                return list(self.experiences.values())

            if prioritize_profitable and len(self.profitable_experiences) > batch_size // 2:
                # Sample a mix of profitable and remaining experiences
                profitable_sample_size = min(batch_size // 2, len(self.profitable_experiences))
                remaining_sample_size = batch_size - profitable_sample_size

                profitable_ids = random.sample(self.profitable_experiences, profitable_sample_size)
                # Exclude already-sampled IDs so the batch contains no duplicates
                chosen = set(profitable_ids)
                remaining_pool = [exp_id for exp_id in self.experiences if exp_id not in chosen]
                remaining_ids = random.sample(remaining_pool, min(remaining_sample_size, len(remaining_pool)))

                sampled_ids = profitable_ids + remaining_ids
            else:
                # Uniform random sampling from all experiences
                all_ids = list(self.experiences.keys())
                sampled_ids = random.sample(all_ids, batch_size)

            sampled_experiences = [self.experiences[exp_id] for exp_id in sampled_ids]

            # Update training counts
            for experience in sampled_experiences:
                experience.times_trained += 1
                experience.last_trained = datetime.now()

            return sampled_experiences

        except Exception as e:
            logger.error(f"Error sampling batch: {e}")
            return list(self.experiences.values())[:batch_size]

    def get_most_profitable_experiences(self, limit: int = 100) -> List[RLExperience]:
        """Get most profitable experiences for targeted training"""
        try:
            profitable_experiences = [
                self.experiences[exp_id] for exp_id in self.profitable_experiences
                if exp_id in self.experiences
            ]

            profitable_experiences.sort(
                key=lambda x: x.actual_profit if x.actual_profit else 0,
                reverse=True
            )

            return profitable_experiences[:limit]

        except Exception as e:
            logger.error(f"Error getting profitable experiences: {e}")
            return []
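

# --- Illustrative usage sketch (not part of the production pipeline) ---------
# Minimal, self-contained example of how the profit-weighted buffer is meant
# to be used: experiences are added with an optional realized profit, and
# sampling then mixes profitable and random experiences. All values (state
# size, rewards, profits) are synthetic, and the helper name
# `_demo_profit_weighted_sampling` is hypothetical, not part of the original API.
def _demo_profit_weighted_sampling(num_experiences: int = 64, state_dim: int = 8) -> None:
    """Populate a small buffer with synthetic experiences and sample one batch."""
    buffer = ProfitWeightedExperienceBuffer(max_size=1000)

    for i in range(num_experiences):
        profit = round(random.uniform(-1.0, 1.0), 3)
        experience = RLExperience(
            experience_id=f"demo_{i}",
            timestamp=datetime.now(),
            episode_id="demo_episode",
            state=np.random.randn(state_dim).astype(np.float32),
            action=random.randint(0, 2),
            reward=profit,
            next_state=np.random.randn(state_dim).astype(np.float32),
            done=(i == num_experiences - 1),
            market_context={"episode_id": "demo_episode"},
            actual_profit=profit,
        )
        buffer.add_experience(experience)

    batch = buffer.sample_batch(batch_size=16, prioritize_profitable=True)
    profitable = sum(1 for e in batch if e.actual_profit and e.actual_profit > 0)
    logger.info("Sampled %d experiences, %d of them profitable", len(batch), profitable)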


class RLTradingAgent(nn.Module):
    """RL Trading Agent with comprehensive state processing"""

    def __init__(self, state_dim: int = 2000, action_dim: int = 3, hidden_dim: int = 512):
        super(RLTradingAgent, self).__init__()

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.hidden_dim = hidden_dim

        # State processing network
        self.state_processor = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.ReLU()
        )

        # Q-value network
        self.q_network = nn.Sequential(
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim // 4, action_dim)
        )

        # Policy network
        self.policy_network = nn.Sequential(
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim // 4, action_dim),
            nn.Softmax(dim=-1)
        )

        # Value network
        self.value_network = nn.Sequential(
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim // 4, 1)
        )

    def forward(self, state):
        """Forward pass through the agent"""
        processed_state = self.state_processor(state)

        q_values = self.q_network(processed_state)
        policy_probs = self.policy_network(processed_state)
        state_value = self.value_network(processed_state)

        return {
            'q_values': q_values,
            'policy_probs': policy_probs,
            'state_value': state_value,
            'processed_state': processed_state
        }

    def select_action(self, state, epsilon: float = 0.1) -> Tuple[int, float]:
        """Select an action using an epsilon-greedy policy"""
        self.eval()
        with torch.no_grad():
            if isinstance(state, np.ndarray):
                state = torch.from_numpy(state).float().unsqueeze(0)
            # Ensure the state tensor lives on the same device as the model
            state = state.to(next(self.parameters()).device)

            outputs = self.forward(state)

            if random.random() < epsilon:
                action = random.randint(0, self.action_dim - 1)
                # Uniform confidence for a random (exploratory) action
                confidence = 1.0 / self.action_dim
            else:
                q_values = outputs['q_values']
                action = torch.argmax(q_values, dim=1).item()
                q_softmax = F.softmax(q_values, dim=1)
                confidence = torch.max(q_softmax).item()

        return action, confidence
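

# --- Illustrative usage sketch (not part of the production pipeline) ---------
# Minimal example of running the agent on a single synthetic state. The state
# dimension and epsilon value below are arbitrary choices, and the helper name
# `_demo_action_selection` is hypothetical, used only for illustration.
def _demo_action_selection(state_dim: int = 64) -> None:
    """Instantiate a CPU agent and perform one epsilon-greedy action selection."""
    agent = RLTradingAgent(state_dim=state_dim)
    state = np.random.randn(state_dim).astype(np.float32)

    # Low epsilon: mostly exploit the Q-network, occasionally explore randomly.
    action, confidence = agent.select_action(state, epsilon=0.05)
    action_name = {0: "SELL", 1: "HOLD", 2: "BUY"}[action]
    logger.info("Selected action %s with confidence %.3f", action_name, confidence)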


@dataclass
class RLTrainingStep:
    """Single RL training step with backpropagation data"""
    step_id: str
    timestamp: datetime
    batch_experiences: List[str]

    # Training data
    total_loss: float
    q_loss: float
    policy_loss: float

    # Gradients
    gradients: Dict[str, torch.Tensor]
    gradient_norms: Dict[str, float]

    # Metadata
    learning_rate: float = 0.001
    batch_size: int = 32

    # Performance
    batch_profitability: float = 0.0
    correct_actions: int = 0
    total_actions: int = 0
    step_value: float = 0.0


@dataclass
class RLTrainingSession:
    """Complete RL training session"""
    session_id: str
    start_timestamp: datetime
    end_timestamp: Optional[datetime] = None

    training_mode: str = 'experience_replay'
    symbol: str = ''

    training_steps: List[RLTrainingStep] = field(default_factory=list)

    total_steps: int = 0
    average_loss: float = 0.0
    best_loss: float = float('inf')

    profitable_actions: int = 0
    total_actions: int = 0
    profitability_rate: float = 0.0
    session_value: float = 0.0


class RLTrainer:
    """RL trainer with comprehensive experience storage and replay"""

    def __init__(self, agent: RLTradingAgent, device: str = 'cuda', storage_dir: str = "rl_training_storage"):
        self.agent = agent.to(device)
        self.device = device
        self.storage_dir = Path(storage_dir)
        self.storage_dir.mkdir(parents=True, exist_ok=True)

        self.optimizer = torch.optim.AdamW(agent.parameters(), lr=0.001)
        self.experience_buffer = ProfitWeightedExperienceBuffer()
        self.data_collector = get_training_data_collector()

        self.training_sessions: List[RLTrainingSession] = []
        self.current_session: Optional[RLTrainingSession] = None

        self.gamma = 0.99

        self.training_stats = {
            'total_sessions': 0,
            'total_steps': 0,
            'total_experiences': 0,
            'profitable_actions': 0,
            'total_actions': 0,
            'average_reward': 0.0
        }

        logger.info(f"RL Trainer initialized with {sum(p.numel() for p in agent.parameters()):,} parameters")

    def add_experience(self, state: np.ndarray, action: int, reward: float,
                       next_state: np.ndarray, done: bool, market_context: Dict[str, Any],
                       cnn_predictions: Optional[Dict[str, Any]] = None,
                       confidence_score: float = 0.0) -> Optional[str]:
        """Add an experience to the buffer and return its ID (None on failure)"""
        try:
            experience_id = f"exp_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}"

            experience = RLExperience(
                experience_id=experience_id,
                timestamp=datetime.now(),
                episode_id=market_context.get('episode_id', 'unknown'),
                state=state,
                action=action,
                reward=reward,
                next_state=next_state,
                done=done,
                market_context=market_context,
                cnn_predictions=cnn_predictions,
                confidence_score=confidence_score
            )

            self.experience_buffer.add_experience(experience)
            self.training_stats['total_experiences'] += 1

            return experience_id

        except Exception as e:
            logger.error(f"Error adding experience: {e}")
            return None

    def train_on_experiences(self, batch_size: int = 32, num_batches: int = 10) -> Dict[str, Any]:
        """Train on replayed experiences with comprehensive data storage"""
        try:
            session = RLTrainingSession(
                session_id=f"rl_training_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
                start_timestamp=datetime.now(),
                training_mode='experience_replay'
            )
            self.current_session = session

            self.agent.train()
            total_loss = 0.0

            for batch_idx in range(num_batches):
                experiences = self.experience_buffer.sample_batch(batch_size, True)

                if len(experiences) < batch_size:
                    continue

                # Prepare batch tensors (stack the numpy arrays first to avoid
                # slow per-element tensor construction)
                states = torch.from_numpy(np.stack([exp.state for exp in experiences])).float().to(self.device)
                actions = torch.LongTensor([exp.action for exp in experiences]).to(self.device)
                rewards = torch.FloatTensor([exp.reward for exp in experiences]).to(self.device)
                next_states = torch.from_numpy(np.stack([exp.next_state for exp in experiences])).float().to(self.device)
                dones = torch.BoolTensor([exp.done for exp in experiences]).to(self.device)

                # Forward pass
                self.optimizer.zero_grad()

                current_outputs = self.agent(states)
                current_q_values = current_outputs['q_values']

                # Calculate target Q-values (one-step TD target; terminal states do not bootstrap)
                with torch.no_grad():
                    next_outputs = self.agent(next_states)
                    next_q_values = next_outputs['q_values']
                    max_next_q_values = torch.max(next_q_values, dim=1)[0]
                    target_q_values = rewards + self.gamma * max_next_q_values * (~dones).float()

                # Calculate loss on the Q-values of the actions actually taken
                current_q_values_for_actions = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
                q_loss = F.mse_loss(current_q_values_for_actions, target_q_values)

                # Backward pass
                q_loss.backward()

                # Store gradients
                gradients = {}
                gradient_norms = {}
                for name, param in self.agent.named_parameters():
                    if param.grad is not None:
                        gradients[name] = param.grad.clone().detach()
                        gradient_norms[name] = param.grad.norm().item()

                torch.nn.utils.clip_grad_norm_(self.agent.parameters(), max_norm=1.0)
                self.optimizer.step()

                # Create training step record
                step = RLTrainingStep(
                    step_id=f"{session.session_id}_step_{batch_idx}",
                    timestamp=datetime.now(),
                    batch_experiences=[exp.experience_id for exp in experiences],
                    total_loss=q_loss.item(),
                    q_loss=q_loss.item(),
                    policy_loss=0.0,
                    gradients=gradients,
                    gradient_norms=gradient_norms,
                    batch_size=len(experiences)
                )

                session.training_steps.append(step)
                total_loss += q_loss.item()

            # Finalize session using the number of steps actually executed
            # (batches are skipped when the buffer cannot fill them)
            session.end_timestamp = datetime.now()
            session.total_steps = len(session.training_steps)
            session.average_loss = total_loss / session.total_steps if session.total_steps > 0 else 0.0

            self._save_training_session(session)
            self.training_sessions.append(session)

            self.training_stats['total_sessions'] += 1
            self.training_stats['total_steps'] += session.total_steps

            logger.info(f"RL training session completed: {session.session_id}")
            logger.info(f"Average loss: {session.average_loss:.4f}")

            return {
                'status': 'success',
                'session_id': session.session_id,
                'average_loss': session.average_loss,
                'total_steps': session.total_steps
            }

        except Exception as e:
            logger.error(f"Error in RL training session: {e}")
            return {'status': 'error', 'error': str(e)}
        finally:
            self.current_session = None

    def train_on_profitable_experiences(self, min_profitability: float = 0.1,
                                        max_experiences: int = 1000, batch_size: int = 32) -> Dict[str, Any]:
        """Train specifically on most profitable experiences"""
        try:
            profitable_experiences = self.experience_buffer.get_most_profitable_experiences(max_experiences)

            filtered_experiences = [
                exp for exp in profitable_experiences
                if exp.actual_profit is not None and exp.actual_profit >= min_profitability
            ]

            if len(filtered_experiences) < batch_size:
                return {'status': 'insufficient_data', 'experiences_found': len(filtered_experiences)}

            logger.info(f"Training on {len(filtered_experiences)} profitable experiences")

            num_batches = len(filtered_experiences) // batch_size

            # Temporarily replace buffer sampling
            original_sample_method = self.experience_buffer.sample_batch

            def profitable_sample_batch(batch_size, prioritize_profitable=True):
                return random.sample(filtered_experiences, min(batch_size, len(filtered_experiences)))

            self.experience_buffer.sample_batch = profitable_sample_batch

            try:
                results = self.train_on_experiences(batch_size=batch_size, num_batches=num_batches)
                results['training_mode'] = 'profitable_replay'
                results['experiences_used'] = len(filtered_experiences)
                return results
            finally:
                self.experience_buffer.sample_batch = original_sample_method

        except Exception as e:
            logger.error(f"Error training on profitable experiences: {e}")
            return {'status': 'error', 'error': str(e)}

    def _save_training_session(self, session: RLTrainingSession):
        """Save training session to disk"""
        try:
            session_dir = self.storage_dir / 'sessions'
            session_dir.mkdir(parents=True, exist_ok=True)

            session_file = session_dir / f"{session.session_id}.pkl"
            with open(session_file, 'wb') as f:
                pickle.dump(session, f)

            metadata = {
                'session_id': session.session_id,
                'start_timestamp': session.start_timestamp.isoformat(),
                'end_timestamp': session.end_timestamp.isoformat() if session.end_timestamp else None,
                'training_mode': session.training_mode,
                'total_steps': session.total_steps,
                'average_loss': session.average_loss
            }

            metadata_file = session_dir / f"{session.session_id}_metadata.json"
            with open(metadata_file, 'w') as f:
                json.dump(metadata, f, indent=2)

        except Exception as e:
            logger.error(f"Error saving training session: {e}")

    def get_training_statistics(self) -> Dict[str, Any]:
        """Get comprehensive training statistics"""
        stats = self.training_stats.copy()

        if self.training_sessions:
            recent_sessions = sorted(self.training_sessions, key=lambda x: x.start_timestamp, reverse=True)[:10]
            stats['recent_sessions'] = [
                {
                    'session_id': s.session_id,
                    'timestamp': s.start_timestamp.isoformat(),
                    'mode': s.training_mode,
                    'average_loss': s.average_loss
                }
                for s in recent_sessions
            ]

        return stats
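

# --- Illustrative usage sketch (not part of the production pipeline) ---------
# The buffer only flags an experience as profitable if `actual_profit` is set
# when it is added, but in live trading the realized profit is usually known
# only after the position closes. The hypothetical helper below shows one way
# the outcome could be attached retroactively so that
# `train_on_profitable_experiences` has data to work with; it writes directly
# to buffer internals and is an assumption about intended usage, not part of
# the original API.
def _demo_record_trade_outcome(trainer: RLTrainer, experience_id: str,
                               realized_profit: float,
                               holding_time: timedelta) -> None:
    """Attach a realized trading outcome to a previously stored experience."""
    buffer = trainer.experience_buffer
    experience = buffer.experiences.get(experience_id)
    if experience is None:
        logger.warning("Experience %s not found in buffer", experience_id)
        return

    experience.actual_profit = realized_profit
    experience.actual_holding_time = holding_time
    experience.profitability_score = max(realized_profit, 0.0)

    # Register the experience for profit-weighted sampling and profitable replay.
    if realized_profit > 0 and experience_id not in buffer.profitable_experiences:
        buffer.profitable_experiences.append(experience_id)
        buffer.total_profitable += 1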


# Global instance
rl_trainer: Optional[RLTrainer] = None


def get_rl_trainer(agent: Optional[RLTradingAgent] = None) -> RLTrainer:
    """Get the global RL trainer instance, creating it on first use"""
    global rl_trainer
    if rl_trainer is None:
        if agent is None:
            agent = RLTradingAgent()
        rl_trainer = RLTrainer(agent)
    return rl_trainer
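

if __name__ == "__main__":
    # Smoke-test sketch: exercises the pipeline end to end on synthetic data.
    # The small state dimension, CPU device, and local storage directory are
    # illustrative choices, not values from the original configuration. Because
    # of the relative import of `get_training_data_collector` at the top of the
    # file, this block must be run as part of its package (e.g. via `python -m`)
    # rather than as a standalone script.
    logging.basicConfig(level=logging.INFO)

    demo_agent = RLTradingAgent(state_dim=64, action_dim=3, hidden_dim=128)
    trainer = RLTrainer(demo_agent, device="cpu", storage_dir="rl_training_storage_demo")

    # Generate synthetic experiences with random rewards standing in for PnL.
    for i in range(256):
        reward = float(np.random.randn() * 0.1)
        exp_id = trainer.add_experience(
            state=np.random.randn(64).astype(np.float32),
            action=random.randint(0, 2),
            reward=reward,
            next_state=np.random.randn(64).astype(np.float32),
            done=(i % 32 == 31),
            market_context={"episode_id": f"episode_{i // 32}"},
            confidence_score=0.5,
        )
        # Pretend the trade closed with the reward as the realized profit.
        _demo_record_trade_outcome(trainer, exp_id, reward, timedelta(minutes=5))

    print(trainer.train_on_experiences(batch_size=32, num_batches=5))
    print(trainer.train_on_profitable_experiences(min_profitability=0.01, batch_size=16))
    print(trainer.get_training_statistics())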