Files
gogo2/core/rl_training_pipeline.py
Dobromir Popov 12865fd3ef replay system
2025-07-20 12:37:02 +03:00

529 lines
20 KiB
Python

"""
RL Training Pipeline with Comprehensive Experience Storage and Replay
This module implements a robust RL training pipeline that:
1. Stores all training experiences with profitability metrics
2. Implements profit-weighted experience replay
3. Tracks gradient information for each training step
4. Enables retraining on most profitable trading sequences
5. Maintains comprehensive trading episode analysis
"""
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, field
import json
import pickle
from collections import deque
import threading
import random
from .training_data_collector import get_training_data_collector
logger = logging.getLogger(__name__)
@dataclass
class RLExperience:
    """One stored transition together with its trading metadata.

    The first nine fields are required and positional; every field from
    ``cnn_predictions`` onward has a default and may be filled in later.
    """

    # --- Identification ---
    experience_id: str
    timestamp: datetime
    episode_id: str

    # --- Core transition (state, action, reward, next state, terminal flag) ---
    state: np.ndarray
    action: int  # 0=SELL, 1=HOLD, 2=BUY
    reward: float
    next_state: np.ndarray
    done: bool

    # --- Additional context recorded with the transition ---
    market_context: Dict[str, Any]
    cnn_predictions: Optional[Dict[str, Any]] = None
    confidence_score: float = 0.0

    # --- Realised trading outcome (None until it is known) ---
    actual_profit: Optional[float] = None
    actual_holding_time: Optional[timedelta] = None
    optimal_action: Optional[int] = None

    # --- Scores available for replay prioritisation ---
    experience_value: float = 0.0
    profitability_score: float = 0.0
    learning_priority: float = 0.0

    # --- Replay bookkeeping ---
    times_trained: int = 0
    last_trained: Optional[datetime] = None
class ProfitWeightedExperienceBuffer:
    """Bounded FIFO experience store with profit-biased sampling for replay.

    Experiences are keyed by ``experience_id``. ``experience_order`` holds the
    ids of the experiences currently stored, oldest first, and
    ``profitable_experiences`` holds the ids whose ``actual_profit`` was
    positive at the time they were added. ``total_experiences`` and
    ``total_profitable`` count every add over the buffer's lifetime and are
    not decremented on eviction.
    """

    def __init__(self, max_size: int = 100000):
        """Create an empty buffer holding at most ``max_size`` experiences."""
        self.max_size = max_size
        self.experiences: Dict[str, RLExperience] = {}
        # No ``maxlen`` here on purpose: a bounded deque silently drops its
        # oldest id on append, which would desynchronise it from
        # ``self.experiences``. Eviction is done explicitly in add_experience.
        self.experience_order: deque = deque()
        self.profitable_experiences: List[str] = []
        self.total_experiences = 0
        self.total_profitable = 0

    def _forget(self, experience_id: str) -> None:
        """Drop ``experience_id`` from the id->experience map and the profitable list."""
        self.experiences.pop(experience_id, None)
        if experience_id in self.profitable_experiences:
            self.profitable_experiences.remove(experience_id)

    def add_experience(self, experience: RLExperience):
        """Store ``experience``, evicting the oldest entries beyond ``max_size``.

        Adding an experience whose id is already stored replaces the old entry
        and moves the id to the newest position. Errors are logged, not raised.
        """
        try:
            exp_id = experience.experience_id

            if exp_id in self.experiences:
                # Replace instead of leaving duplicate ids in the order/profit lists.
                self._forget(exp_id)
                try:
                    self.experience_order.remove(exp_id)
                except ValueError:
                    pass

            self.experiences[exp_id] = experience
            self.experience_order.append(exp_id)

            if experience.actual_profit is not None and experience.actual_profit > 0:
                self.profitable_experiences.append(exp_id)
                self.total_profitable += 1

            # Evict strictly oldest-first until the buffer is back within bounds.
            while self.experience_order and len(self.experience_order) > self.max_size:
                self._forget(self.experience_order.popleft())

            self.total_experiences += 1
        except Exception as e:
            logger.error(f"Error adding experience to buffer: {e}")

    def sample_batch(self, batch_size: int, prioritize_profitable: bool = True) -> List[RLExperience]:
        """Sample ``batch_size`` distinct experiences, favouring profitable ones.

        When ``prioritize_profitable`` is set and more than ``batch_size // 2``
        profitable experiences are stored, half of the batch is drawn from the
        profitable ones and the rest from the remaining experiences. Otherwise
        the batch is drawn uniformly from everything stored.

        If fewer than ``batch_size`` experiences are stored, all of them are
        returned and their training counters are left untouched. A non-positive
        ``batch_size`` yields an empty list. On an unexpected error the problem
        is logged and the first ``batch_size`` stored experiences are returned.
        """
        if batch_size <= 0:
            return []
        try:
            if len(self.experiences) < batch_size:
                return list(self.experiences.values())

            all_ids = list(self.experiences)
            half = batch_size // 2
            sampled_ids = None

            if prioritize_profitable:
                # Ignore ids that no longer have a stored experience.
                live_profitable = [
                    exp_id for exp_id in self.profitable_experiences
                    if exp_id in self.experiences
                ]
                if len(live_profitable) > half:
                    profitable_ids = random.sample(live_profitable, half)
                    already_chosen = set(profitable_ids)
                    # Fill the rest from ids not picked yet so that no
                    # experience appears twice in the same batch.
                    others = [exp_id for exp_id in all_ids if exp_id not in already_chosen]
                    sampled_ids = profitable_ids + random.sample(others, batch_size - half)

            if sampled_ids is None:
                sampled_ids = random.sample(all_ids, batch_size)

            sampled_experiences = [self.experiences[exp_id] for exp_id in sampled_ids]

            # Update training bookkeeping on the sampled experiences.
            trained_at = datetime.now()
            for experience in sampled_experiences:
                experience.times_trained += 1
                experience.last_trained = trained_at
            return sampled_experiences
        except Exception as e:
            logger.error(f"Error sampling batch: {e}")
            return list(self.experiences.values())[:batch_size]

    def get_most_profitable_experiences(self, limit: int = 100) -> List[RLExperience]:
        """Return up to ``limit`` stored profitable experiences, highest profit first.

        Returns an empty list when ``limit`` is not positive or on error
        (the error is logged).
        """
        try:
            profitable_experiences = [
                self.experiences[exp_id] for exp_id in self.profitable_experiences
                if exp_id in self.experiences
            ]
            profitable_experiences.sort(
                key=lambda x: x.actual_profit if x.actual_profit else 0,
                reverse=True
            )
            return profitable_experiences[:max(limit, 0)]
        except Exception as e:
            logger.error(f"Error getting profitable experiences: {e}")
            return []
class RLTradingAgent(nn.Module):
    """RL trading agent: a shared state encoder feeding Q-value, policy and value heads."""

    def __init__(self, state_dim: int = 2000, action_dim: int = 3, hidden_dim: int = 512):
        """Build the encoder and the three heads.

        Args:
            state_dim: Size of the flat state vector.
            action_dim: Number of discrete actions.
            hidden_dim: Width of the first encoder layer; later layers use
                ``hidden_dim // 2`` and ``hidden_dim // 4``.
        """
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.hidden_dim = hidden_dim

        # Shared state encoder: state_dim -> hidden_dim -> hidden_dim // 2
        self.state_processor = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.ReLU()
        )

        # Q-value head: one raw value per action
        self.q_network = nn.Sequential(
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim // 4, action_dim)
        )

        # Policy head: softmax distribution over actions
        self.policy_network = nn.Sequential(
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim // 4, action_dim),
            nn.Softmax(dim=-1)
        )

        # Value head: a single scalar per state
        self.value_network = nn.Sequential(
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim // 4, 1)
        )

    def forward(self, state):
        """Encode ``state`` and evaluate all three heads.

        Returns:
            Dict with ``'q_values'``, ``'policy_probs'``, ``'state_value'`` and
            the encoder output ``'processed_state'``.
        """
        processed_state = self.state_processor(state)
        return {
            'q_values': self.q_network(processed_state),
            'policy_probs': self.policy_network(processed_state),
            'state_value': self.value_network(processed_state),
            'processed_state': processed_state
        }

    def select_action(self, state, epsilon: float = 0.1) -> Tuple[int, float]:
        """Select an action for a single state using an epsilon-greedy policy.

        Args:
            state: One state, as a NumPy array or tensor, either flat
                ``(state_dim,)`` or already batched as ``(1, state_dim)``.
            epsilon: Probability of picking a uniformly random action.

        Returns:
            ``(action, confidence)``. For a greedy choice the confidence is the
            largest softmax probability over the Q-values; for a random choice
            it is the fixed value ``0.33``.

        Note:
            The module is switched to eval mode and left there.
        """
        self.eval()

        if random.random() < epsilon:
            # Exploration does not need the network output, so skip the forward pass.
            return random.randint(0, self.action_dim - 1), 0.33

        with torch.no_grad():
            # Build the input on the same device as the weights; otherwise a
            # CPU state fed to a module living on another device fails.
            device = next(self.parameters()).device
            state = torch.as_tensor(state, dtype=torch.float32, device=device)
            if state.dim() == 1:
                state = state.unsqueeze(0)

            q_values = self.forward(state)['q_values']
            action = torch.argmax(q_values, dim=1).item()
            confidence = torch.max(F.softmax(q_values, dim=1)).item()
        return action, confidence
@dataclass
class RLTrainingStep:
    """Record of one optimisation step: the batch used, its losses and gradients."""

    # --- Identification ---
    step_id: str
    timestamp: datetime
    batch_experiences: List[str]  # ids of the experiences in the batch

    # --- Loss values ---
    total_loss: float
    q_loss: float
    policy_loss: float

    # --- Gradient snapshot, keyed by parameter name ---
    gradients: Dict[str, torch.Tensor]
    gradient_norms: Dict[str, float]

    # --- Hyper-parameters in effect for the step ---
    learning_rate: float = 0.001
    batch_size: int = 32

    # --- Outcome metrics ---
    batch_profitability: float = 0.0
    correct_actions: int = 0
    total_actions: int = 0
    step_value: float = 0.0
@dataclass
class RLTrainingSession:
    """Record of one training session and the steps executed in it."""

    # --- Identification and timing ---
    session_id: str
    start_timestamp: datetime
    end_timestamp: Optional[datetime] = None

    # --- Configuration ---
    training_mode: str = 'experience_replay'
    symbol: str = ''

    # --- Executed steps (a fresh list is created for every session) ---
    training_steps: List[RLTrainingStep] = field(default_factory=list)

    # --- Aggregate loss metrics ---
    total_steps: int = 0
    average_loss: float = 0.0
    best_loss: float = float('inf')

    # --- Aggregate action/profit metrics ---
    profitable_actions: int = 0
    total_actions: int = 0
    profitability_rate: float = 0.0
    session_value: float = 0.0
class RLTrainer:
    """RL trainer with experience storage, replay training and session persistence.

    Each training session is pickled to ``<storage_dir>/sessions`` together
    with a small JSON metadata file, and the most recent sessions are kept in
    ``training_sessions`` for ``get_training_statistics``.
    """

    # Upper bound on the number of finished sessions kept in memory.
    MAX_SESSIONS_IN_MEMORY = 100

    def __init__(self, agent: RLTradingAgent, device: str = 'cuda', storage_dir: str = "rl_training_storage"):
        """Set up the optimizer, experience buffer and storage directory.

        Args:
            agent: The network to train; it is moved to ``device``.
            device: Torch device string. When a CUDA device is requested but
                CUDA is unavailable, the trainer falls back to ``'cpu'``.
            storage_dir: Directory for persisted sessions (created if missing).
        """
        if str(device).startswith('cuda') and not torch.cuda.is_available():
            logger.warning(f"CUDA device '{device}' requested but CUDA is not available; using CPU")
            device = 'cpu'
        self.agent = agent.to(device)
        self.device = device
        self.storage_dir = Path(storage_dir)
        self.storage_dir.mkdir(parents=True, exist_ok=True)
        self.optimizer = torch.optim.AdamW(self.agent.parameters(), lr=0.001)
        self.experience_buffer = ProfitWeightedExperienceBuffer()
        self.data_collector = get_training_data_collector()
        self.training_sessions: List[RLTrainingSession] = []
        self.current_session: Optional[RLTrainingSession] = None
        self.gamma = 0.99
        # Monotonic counter appended to experience ids so that two experiences
        # created within the same clock tick never share an id.
        self._experience_seq = 0
        self.training_stats = {
            'total_sessions': 0,
            'total_steps': 0,
            'total_experiences': 0,
            'profitable_actions': 0,
            'total_actions': 0,
            'average_reward': 0.0
        }
        logger.info(f"RL Trainer initialized with {sum(p.numel() for p in agent.parameters()):,} parameters")

    def add_experience(self, state: np.ndarray, action: int, reward: float,
                       next_state: np.ndarray, done: bool, market_context: Dict[str, Any],
                       cnn_predictions: Optional[Dict[str, Any]] = None,
                       confidence_score: float = 0.0) -> Optional[str]:
        """Wrap a transition in an ``RLExperience`` and add it to the buffer.

        The episode id is read from ``market_context['episode_id']`` and
        defaults to ``'unknown'``.

        Returns:
            The generated experience id, or ``None`` if an error occurred
            (the error is logged).
        """
        try:
            now = datetime.now()
            seq = self._experience_seq
            self._experience_seq = seq + 1
            # Timestamp alone can repeat on coarse clocks; the counter makes the id unique.
            experience_id = f"exp_{now.strftime('%Y%m%d_%H%M%S_%f')}_{seq}"
            experience = RLExperience(
                experience_id=experience_id,
                timestamp=now,
                episode_id=market_context.get('episode_id', 'unknown'),
                state=state,
                action=action,
                reward=reward,
                next_state=next_state,
                done=done,
                market_context=market_context,
                cnn_predictions=cnn_predictions,
                confidence_score=confidence_score
            )
            self.experience_buffer.add_experience(experience)
            self.training_stats['total_experiences'] += 1
            return experience_id
        except Exception as e:
            logger.error(f"Error adding experience: {e}")
            return None

    def train_on_experiences(self, batch_size: int = 32, num_batches: int = 10) -> Dict[str, Any]:
        """Run one replay session on batches sampled from the experience buffer.

        Batches are requested with profit prioritisation enabled. A batch that
        comes back with fewer than ``batch_size`` experiences is skipped.

        Returns:
            ``{'status': 'success', 'session_id', 'average_loss', 'total_steps'}``
            where ``total_steps`` is the number of batches actually trained on,
            or ``{'status': 'error', 'error': ...}`` on failure.
        """
        def sample_from_buffer(size):
            return self.experience_buffer.sample_batch(size, True)

        return self._run_training_session(sample_from_buffer, batch_size, num_batches, 'experience_replay')

    def train_on_profitable_experiences(self, min_profitability: float = 0.1,
                                        max_experiences: int = 1000, batch_size: int = 32) -> Dict[str, Any]:
        """Run one replay session restricted to the most profitable experiences.

        Args:
            min_profitability: Minimum ``actual_profit`` for an experience to be used.
            max_experiences: Number of top profitable experiences to consider.
            batch_size: Experiences per training batch.

        Returns:
            The session result dict extended with ``'training_mode'`` and
            ``'experiences_used'``; ``{'status': 'insufficient_data', ...}``
            when fewer than ``batch_size`` experiences qualify; or
            ``{'status': 'error', ...}`` on failure.
        """
        try:
            profitable_experiences = self.experience_buffer.get_most_profitable_experiences(max_experiences)
            filtered_experiences = [
                exp for exp in profitable_experiences
                if exp.actual_profit is not None and exp.actual_profit >= min_profitability
            ]
            if len(filtered_experiences) < batch_size:
                return {'status': 'insufficient_data', 'experiences_found': len(filtered_experiences)}
            logger.info(f"Training on {len(filtered_experiences)} profitable experiences")
            num_batches = len(filtered_experiences) // batch_size

            # Sample straight from the filtered pool rather than temporarily
            # replacing the buffer's sample_batch method.
            def sample_from_pool(size):
                return random.sample(filtered_experiences, min(size, len(filtered_experiences)))

            results = self._run_training_session(sample_from_pool, batch_size, num_batches, 'profitable_replay')
            results['training_mode'] = 'profitable_replay'
            results['experiences_used'] = len(filtered_experiences)
            return results
        except Exception as e:
            logger.error(f"Error training on profitable experiences: {e}")
            return {'status': 'error', 'error': str(e)}

    def _run_training_session(self, sample_fn, batch_size: int, num_batches: int,
                              training_mode: str) -> Dict[str, Any]:
        """Run, persist and register one session; ``sample_fn(batch_size)`` supplies each batch."""
        try:
            started = datetime.now()
            session = RLTrainingSession(
                # Microseconds keep ids (and file names) distinct for
                # sessions started within the same second.
                session_id=f"rl_training_{started.strftime('%Y%m%d_%H%M%S_%f')}",
                start_timestamp=started,
                training_mode=training_mode
            )
            self.current_session = session
            self.agent.train()
            total_loss = 0.0
            for batch_idx in range(num_batches):
                experiences = sample_fn(batch_size)
                if len(experiences) < batch_size:
                    continue
                step = self._train_step(session.session_id, batch_idx, experiences)
                session.training_steps.append(step)
                total_loss += step.q_loss
                session.best_loss = min(session.best_loss, step.q_loss)

            # Finalize using the batches that actually ran, not the number requested.
            executed_steps = len(session.training_steps)
            session.end_timestamp = datetime.now()
            session.total_steps = executed_steps
            session.average_loss = total_loss / executed_steps if executed_steps > 0 else 0.0
            self._save_training_session(session)
            self._remember_session(session)
            self.training_stats['total_sessions'] += 1
            self.training_stats['total_steps'] += session.total_steps
            logger.info(f"RL training session completed: {session.session_id}")
            logger.info(f"Average loss: {session.average_loss:.4f}")
            return {
                'status': 'success',
                'session_id': session.session_id,
                'average_loss': session.average_loss,
                'total_steps': session.total_steps
            }
        except Exception as e:
            logger.error(f"Error in RL training session: {e}")
            return {'status': 'error', 'error': str(e)}
        finally:
            self.current_session = None

    def _train_step(self, session_id: str, batch_idx: int, experiences) -> RLTrainingStep:
        """Do one Q-learning update on ``experiences`` and return its step record."""
        device = self.device
        # Stack the arrays first; building a tensor from a list of arrays is slow.
        states = torch.as_tensor(
            np.stack([exp.state for exp in experiences]), dtype=torch.float32, device=device)
        next_states = torch.as_tensor(
            np.stack([exp.next_state for exp in experiences]), dtype=torch.float32, device=device)
        actions = torch.as_tensor([exp.action for exp in experiences], dtype=torch.long, device=device)
        rewards = torch.as_tensor([exp.reward for exp in experiences], dtype=torch.float32, device=device)
        dones = torch.as_tensor([exp.done for exp in experiences], dtype=torch.bool, device=device)

        self.optimizer.zero_grad()
        current_q_values = self.agent(states)['q_values']

        # One-step TD target; terminal transitions contribute the reward only.
        with torch.no_grad():
            next_q_values = self.agent(next_states)['q_values']
            max_next_q_values = torch.max(next_q_values, dim=1)[0]
            target_q_values = rewards + self.gamma * max_next_q_values * (~dones).float()

        chosen_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
        q_loss = F.mse_loss(chosen_q_values, target_q_values)
        q_loss.backward()

        # Snapshot gradients before clipping. The copies go to the CPU so that
        # recorded steps do not hold on to training-device memory.
        gradients = {}
        gradient_norms = {}
        for name, param in self.agent.named_parameters():
            if param.grad is not None:
                gradients[name] = param.grad.detach().to('cpu', copy=True)
                gradient_norms[name] = param.grad.norm().item()

        torch.nn.utils.clip_grad_norm_(self.agent.parameters(), max_norm=1.0)
        self.optimizer.step()

        loss_value = q_loss.item()
        return RLTrainingStep(
            step_id=f"{session_id}_step_{batch_idx}",
            timestamp=datetime.now(),
            batch_experiences=[exp.experience_id for exp in experiences],
            total_loss=loss_value,
            q_loss=loss_value,
            policy_loss=0.0,
            gradients=gradients,
            gradient_norms=gradient_norms,
            learning_rate=self.optimizer.param_groups[0]['lr'],
            batch_size=len(experiences)
        )

    def _remember_session(self, session: RLTrainingSession) -> None:
        """Keep a finished session in the bounded in-memory history.

        Call this after the session has been handed to
        ``_save_training_session``: the full gradient tensors are released
        from the in-memory copy (the gradient norms stay) so that the history
        remains small.
        """
        for step in session.training_steps:
            step.gradients = {}
        self.training_sessions.append(session)
        overflow = len(self.training_sessions) - self.MAX_SESSIONS_IN_MEMORY
        if overflow > 0:
            del self.training_sessions[:overflow]

    def _save_training_session(self, session: RLTrainingSession):
        """Write the session as a pickle plus a JSON metadata file.

        Files go to ``<storage_dir>/sessions``. Errors are logged, not raised.
        """
        try:
            session_dir = self.storage_dir / 'sessions'
            session_dir.mkdir(parents=True, exist_ok=True)
            session_file = session_dir / f"{session.session_id}.pkl"
            with open(session_file, 'wb') as f:
                pickle.dump(session, f)
            metadata = {
                'session_id': session.session_id,
                'start_timestamp': session.start_timestamp.isoformat(),
                'end_timestamp': session.end_timestamp.isoformat() if session.end_timestamp else None,
                'training_mode': session.training_mode,
                'total_steps': session.total_steps,
                'average_loss': session.average_loss
            }
            metadata_file = session_dir / f"{session.session_id}_metadata.json"
            with open(metadata_file, 'w', encoding='utf-8') as f:
                json.dump(metadata, f, indent=2)
        except Exception as e:
            logger.error(f"Error saving training session: {e}")

    def get_training_statistics(self) -> Dict[str, Any]:
        """Return a copy of the running counters plus the ten most recent sessions.

        ``'recent_sessions'`` (newest first) is only present when at least one
        session is held in memory.
        """
        stats = self.training_stats.copy()
        if self.training_sessions:
            recent_sessions = sorted(self.training_sessions, key=lambda x: x.start_timestamp, reverse=True)[:10]
            stats['recent_sessions'] = [
                {
                    'session_id': s.session_id,
                    'timestamp': s.start_timestamp.isoformat(),
                    'mode': s.training_mode,
                    'average_loss': s.average_loss
                }
                for s in recent_sessions
            ]
        return stats
# Module-level trainer singleton; created on first use by get_rl_trainer()
rl_trainer = None


def get_rl_trainer(agent: RLTradingAgent = None) -> RLTrainer:
    """Return the shared ``RLTrainer``, creating it on the first call.

    ``agent`` is only used when the trainer does not exist yet; if it is
    omitted then, a default ``RLTradingAgent`` is built. Once a trainer
    exists, the argument is ignored and the existing instance is returned.
    """
    global rl_trainer
    if rl_trainer is not None:
        return rl_trainer
    initial_agent = RLTradingAgent() if agent is None else agent
    rl_trainer = RLTrainer(initial_agent)
    return rl_trainer