new overhaul
This commit is contained in:
566
training/enhanced_cnn_trainer.py
Normal file
566
training/enhanced_cnn_trainer.py
Normal file
@ -0,0 +1,566 @@
|
||||
"""
|
||||
Enhanced CNN Trainer with Perfect Move Learning
|
||||
|
||||
This trainer implements:
|
||||
1. Training on marked perfect moves with known outcomes
|
||||
2. Multi-timeframe CNN model training with confidence scoring
|
||||
3. Backpropagation on optimal moves when future outcomes are known
|
||||
4. Progressive learning from real trading experience
|
||||
5. Symbol-specific and timeframe-specific model fine-tuning
|
||||
"""
|
||||
|
||||
import logging
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
from pathlib import Path
|
||||
|
||||
from core.config import get_config
|
||||
from core.data_provider import DataProvider
|
||||
from core.enhanced_orchestrator import PerfectMove, EnhancedTradingOrchestrator
|
||||
from models import CNNModelInterface
|
||||
import models
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class PerfectMoveDataset(Dataset):
|
||||
"""Dataset for training on perfect moves with known outcomes"""
|
||||
|
||||
def __init__(self, perfect_moves: List[PerfectMove], data_provider: DataProvider):
|
||||
"""
|
||||
Initialize dataset from perfect moves
|
||||
|
||||
Args:
|
||||
perfect_moves: List of perfect moves with known outcomes
|
||||
data_provider: Data provider to fetch additional context
|
||||
"""
|
||||
self.perfect_moves = perfect_moves
|
||||
self.data_provider = data_provider
|
||||
self.samples = []
|
||||
self._prepare_samples()
|
||||
|
||||
def _prepare_samples(self):
|
||||
"""Prepare training samples from perfect moves"""
|
||||
logger.info(f"Preparing {len(self.perfect_moves)} perfect move samples")
|
||||
|
||||
for move in self.perfect_moves:
|
||||
try:
|
||||
# Get feature matrix at the time of the decision
|
||||
feature_matrix = self.data_provider.get_feature_matrix(
|
||||
symbol=move.symbol,
|
||||
timeframes=[move.timeframe],
|
||||
window_size=20,
|
||||
end_time=move.timestamp
|
||||
)
|
||||
|
||||
if feature_matrix is not None:
|
||||
# Convert optimal action to label
|
||||
action_to_label = {'SELL': 0, 'HOLD': 1, 'BUY': 2}
|
||||
label = action_to_label.get(move.optimal_action, 1)
|
||||
|
||||
# Create confidence target (what confidence should have been)
|
||||
confidence_target = move.confidence_should_have_been
|
||||
|
||||
sample = {
|
||||
'features': feature_matrix,
|
||||
'action_label': label,
|
||||
'confidence_target': confidence_target,
|
||||
'symbol': move.symbol,
|
||||
'timeframe': move.timeframe,
|
||||
'outcome': move.actual_outcome,
|
||||
'timestamp': move.timestamp
|
||||
}
|
||||
self.samples.append(sample)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error preparing sample for perfect move: {e}")
|
||||
|
||||
logger.info(f"Prepared {len(self.samples)} valid training samples")
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
sample = self.samples[idx]
|
||||
|
||||
# Convert to tensors
|
||||
features = torch.FloatTensor(sample['features'])
|
||||
action_label = torch.LongTensor([sample['action_label']])
|
||||
confidence_target = torch.FloatTensor([sample['confidence_target']])
|
||||
|
||||
return {
|
||||
'features': features,
|
||||
'action_label': action_label,
|
||||
'confidence_target': confidence_target,
|
||||
'metadata': {
|
||||
'symbol': sample['symbol'],
|
||||
'timeframe': sample['timeframe'],
|
||||
'outcome': sample['outcome'],
|
||||
'timestamp': sample['timestamp']
|
||||
}
|
||||
}
|
||||
|
||||
class EnhancedCNNModel(nn.Module, CNNModelInterface):
|
||||
"""Enhanced CNN model with timeframe-specific predictions and confidence scoring"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
nn.Module.__init__(self)
|
||||
CNNModelInterface.__init__(self, config)
|
||||
|
||||
self.timeframes = config.get('timeframes', ['1h', '4h', '1d'])
|
||||
self.n_features = len(config.get('features', ['open', 'high', 'low', 'close', 'volume']))
|
||||
self.window_size = config.get('window_size', 20)
|
||||
|
||||
# Build the neural network
|
||||
self._build_network()
|
||||
|
||||
# Initialize device
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
self.to(self.device)
|
||||
|
||||
# Training components
|
||||
self.optimizer = optim.Adam(self.parameters(), lr=config.get('learning_rate', 0.001))
|
||||
self.action_criterion = nn.CrossEntropyLoss()
|
||||
self.confidence_criterion = nn.MSELoss()
|
||||
|
||||
logger.info(f"Enhanced CNN model initialized for {len(self.timeframes)} timeframes")
|
||||
|
||||
def _build_network(self):
|
||||
"""Build the CNN architecture"""
|
||||
# Convolutional feature extraction
|
||||
self.conv_layers = nn.Sequential(
|
||||
# First conv block
|
||||
nn.Conv1d(self.n_features, 64, kernel_size=3, padding=1),
|
||||
nn.BatchNorm1d(64),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.2),
|
||||
|
||||
# Second conv block
|
||||
nn.Conv1d(64, 128, kernel_size=3, padding=1),
|
||||
nn.BatchNorm1d(128),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.2),
|
||||
|
||||
# Third conv block
|
||||
nn.Conv1d(128, 256, kernel_size=3, padding=1),
|
||||
nn.BatchNorm1d(256),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.2),
|
||||
|
||||
# Global average pooling
|
||||
nn.AdaptiveAvgPool1d(1)
|
||||
)
|
||||
|
||||
# Timeframe-specific heads
|
||||
self.timeframe_heads = nn.ModuleDict()
|
||||
for timeframe in self.timeframes:
|
||||
self.timeframe_heads[timeframe] = nn.Sequential(
|
||||
nn.Linear(256, 128),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.3),
|
||||
nn.Linear(128, 64),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.3)
|
||||
)
|
||||
|
||||
# Action prediction heads (one per timeframe)
|
||||
self.action_heads = nn.ModuleDict()
|
||||
for timeframe in self.timeframes:
|
||||
self.action_heads[timeframe] = nn.Linear(64, 3) # BUY, HOLD, SELL
|
||||
|
||||
# Confidence prediction heads (one per timeframe)
|
||||
self.confidence_heads = nn.ModuleDict()
|
||||
for timeframe in self.timeframes:
|
||||
self.confidence_heads[timeframe] = nn.Sequential(
|
||||
nn.Linear(64, 32),
|
||||
nn.ReLU(),
|
||||
nn.Linear(32, 1),
|
||||
nn.Sigmoid() # Output between 0 and 1
|
||||
)
|
||||
|
||||
def forward(self, x, timeframe: str = None):
|
||||
"""
|
||||
Forward pass through the network
|
||||
|
||||
Args:
|
||||
x: Input tensor [batch_size, window_size, features]
|
||||
timeframe: Specific timeframe to predict for
|
||||
|
||||
Returns:
|
||||
action_probs: Action probabilities
|
||||
confidence: Confidence score
|
||||
"""
|
||||
# Reshape for conv1d: [batch, features, sequence]
|
||||
x = x.transpose(1, 2)
|
||||
|
||||
# Extract features
|
||||
features = self.conv_layers(x) # [batch, 256, 1]
|
||||
features = features.squeeze(-1) # [batch, 256]
|
||||
|
||||
if timeframe and timeframe in self.timeframe_heads:
|
||||
# Timeframe-specific prediction
|
||||
tf_features = self.timeframe_heads[timeframe](features)
|
||||
action_logits = self.action_heads[timeframe](tf_features)
|
||||
confidence = self.confidence_heads[timeframe](tf_features)
|
||||
|
||||
action_probs = torch.softmax(action_logits, dim=1)
|
||||
return action_probs, confidence.squeeze(-1)
|
||||
else:
|
||||
# Multi-timeframe prediction (average across timeframes)
|
||||
all_action_probs = []
|
||||
all_confidences = []
|
||||
|
||||
for tf in self.timeframes:
|
||||
tf_features = self.timeframe_heads[tf](features)
|
||||
action_logits = self.action_heads[tf](tf_features)
|
||||
confidence = self.confidence_heads[tf](tf_features)
|
||||
|
||||
action_probs = torch.softmax(action_logits, dim=1)
|
||||
all_action_probs.append(action_probs)
|
||||
all_confidences.append(confidence.squeeze(-1))
|
||||
|
||||
# Average predictions across timeframes
|
||||
avg_action_probs = torch.stack(all_action_probs).mean(dim=0)
|
||||
avg_confidence = torch.stack(all_confidences).mean(dim=0)
|
||||
|
||||
return avg_action_probs, avg_confidence
|
||||
|
||||
def predict(self, features: np.ndarray) -> Tuple[np.ndarray, float]:
|
||||
"""Predict action probabilities and confidence"""
|
||||
self.eval()
|
||||
with torch.no_grad():
|
||||
x = torch.FloatTensor(features).to(self.device)
|
||||
if len(x.shape) == 2:
|
||||
x = x.unsqueeze(0) # Add batch dimension
|
||||
|
||||
action_probs, confidence = self.forward(x)
|
||||
|
||||
return action_probs[0].cpu().numpy(), confidence[0].cpu().item()
|
||||
|
||||
def predict_timeframe(self, features: np.ndarray, timeframe: str) -> Tuple[np.ndarray, float]:
|
||||
"""Predict for specific timeframe"""
|
||||
self.eval()
|
||||
with torch.no_grad():
|
||||
x = torch.FloatTensor(features).to(self.device)
|
||||
if len(x.shape) == 2:
|
||||
x = x.unsqueeze(0) # Add batch dimension
|
||||
|
||||
action_probs, confidence = self.forward(x, timeframe)
|
||||
|
||||
return action_probs[0].cpu().numpy(), confidence[0].cpu().item()
|
||||
|
||||
def get_memory_usage(self) -> int:
|
||||
"""Get memory usage in MB"""
|
||||
if torch.cuda.is_available():
|
||||
return torch.cuda.memory_allocated(self.device) // (1024 * 1024)
|
||||
else:
|
||||
# Rough estimate for CPU
|
||||
param_count = sum(p.numel() for p in self.parameters())
|
||||
return (param_count * 4) // (1024 * 1024) # 4 bytes per float32
|
||||
|
||||
def train(self, training_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Train the model (placeholder for interface compatibility)"""
|
||||
return {}
|
||||
|
||||
class EnhancedCNNTrainer:
|
||||
"""Enhanced CNN trainer using perfect moves and real market outcomes"""
|
||||
|
||||
def __init__(self, config: Optional[Dict] = None, orchestrator: EnhancedTradingOrchestrator = None):
|
||||
"""Initialize the enhanced trainer"""
|
||||
self.config = config or get_config()
|
||||
self.orchestrator = orchestrator
|
||||
self.data_provider = DataProvider(self.config)
|
||||
|
||||
# Training parameters
|
||||
self.learning_rate = self.config.training.get('learning_rate', 0.001)
|
||||
self.batch_size = self.config.training.get('batch_size', 32)
|
||||
self.epochs = self.config.training.get('epochs', 100)
|
||||
self.patience = self.config.training.get('early_stopping_patience', 10)
|
||||
|
||||
# Model
|
||||
self.model = EnhancedCNNModel(self.config.cnn)
|
||||
|
||||
# Training history
|
||||
self.training_history = {
|
||||
'train_loss': [],
|
||||
'val_loss': [],
|
||||
'train_accuracy': [],
|
||||
'val_accuracy': [],
|
||||
'confidence_accuracy': []
|
||||
} # Create save directory models_path = self.config.cnn.get('model_dir', "models/enhanced_cnn") self.save_dir = Path(models_path) self.save_dir.mkdir(parents=True, exist_ok=True) logger.info("Enhanced CNN trainer initialized")
|
||||
|
||||
def train_on_perfect_moves(self, min_samples: int = 100) -> Dict[str, Any]:
|
||||
"""Train the model on perfect moves from the orchestrator"""
|
||||
if not self.orchestrator:
|
||||
raise ValueError("Orchestrator required for perfect move training")
|
||||
|
||||
# Get perfect moves from orchestrator
|
||||
perfect_moves = []
|
||||
for symbol in self.config.symbols:
|
||||
symbol_moves = self.orchestrator.get_perfect_moves_for_training(symbol=symbol)
|
||||
perfect_moves.extend(symbol_moves)
|
||||
|
||||
if len(perfect_moves) < min_samples:
|
||||
logger.warning(f"Not enough perfect moves for training: {len(perfect_moves)} < {min_samples}")
|
||||
return {'error': 'insufficient_data', 'samples': len(perfect_moves)}
|
||||
|
||||
logger.info(f"Training on {len(perfect_moves)} perfect moves")
|
||||
|
||||
# Create dataset
|
||||
dataset = PerfectMoveDataset(perfect_moves, self.data_provider)
|
||||
|
||||
# Split into train/validation
|
||||
train_size = int(0.8 * len(dataset))
|
||||
val_size = len(dataset) - train_size
|
||||
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
|
||||
|
||||
# Create data loaders
|
||||
train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
|
||||
val_loader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)
|
||||
|
||||
# Training loop
|
||||
best_val_loss = float('inf')
|
||||
patience_counter = 0
|
||||
|
||||
for epoch in range(self.epochs):
|
||||
# Training phase
|
||||
train_loss, train_acc = self._train_epoch(train_loader)
|
||||
|
||||
# Validation phase
|
||||
val_loss, val_acc, conf_acc = self._validate_epoch(val_loader)
|
||||
|
||||
# Update history
|
||||
self.training_history['train_loss'].append(train_loss)
|
||||
self.training_history['val_loss'].append(val_loss)
|
||||
self.training_history['train_accuracy'].append(train_acc)
|
||||
self.training_history['val_accuracy'].append(val_acc)
|
||||
self.training_history['confidence_accuracy'].append(conf_acc)
|
||||
|
||||
# Log progress
|
||||
logger.info(f"Epoch {epoch+1}/{self.epochs}: "
|
||||
f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
|
||||
f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, "
|
||||
f"Conf Acc: {conf_acc:.4f}")
|
||||
|
||||
# Early stopping
|
||||
if val_loss < best_val_loss:
|
||||
best_val_loss = val_loss
|
||||
patience_counter = 0
|
||||
self._save_model('best_model.pt')
|
||||
else:
|
||||
patience_counter += 1
|
||||
if patience_counter >= self.patience:
|
||||
logger.info(f"Early stopping at epoch {epoch+1}")
|
||||
break
|
||||
|
||||
# Save final model
|
||||
self._save_model('final_model.pt')
|
||||
|
||||
# Generate training report
|
||||
return self._generate_training_report()
|
||||
|
||||
def _train_epoch(self, train_loader: DataLoader) -> Tuple[float, float]:
|
||||
"""Train for one epoch"""
|
||||
self.model.train()
|
||||
total_loss = 0.0
|
||||
correct_predictions = 0
|
||||
total_predictions = 0
|
||||
|
||||
for batch in train_loader:
|
||||
features = batch['features'].to(self.model.device)
|
||||
action_labels = batch['action_label'].to(self.model.device).squeeze(-1)
|
||||
confidence_targets = batch['confidence_target'].to(self.model.device).squeeze(-1)
|
||||
|
||||
# Zero gradients
|
||||
self.model.optimizer.zero_grad()
|
||||
|
||||
# Forward pass
|
||||
action_probs, confidence_pred = self.model(features)
|
||||
|
||||
# Calculate losses
|
||||
action_loss = self.model.action_criterion(action_probs, action_labels)
|
||||
confidence_loss = self.model.confidence_criterion(confidence_pred, confidence_targets)
|
||||
|
||||
# Combined loss
|
||||
total_loss_batch = action_loss + 0.5 * confidence_loss
|
||||
|
||||
# Backward pass
|
||||
total_loss_batch.backward()
|
||||
self.model.optimizer.step()
|
||||
|
||||
# Track metrics
|
||||
total_loss += total_loss_batch.item()
|
||||
predicted_actions = torch.argmax(action_probs, dim=1)
|
||||
correct_predictions += (predicted_actions == action_labels).sum().item()
|
||||
total_predictions += action_labels.size(0)
|
||||
|
||||
avg_loss = total_loss / len(train_loader)
|
||||
accuracy = correct_predictions / total_predictions
|
||||
|
||||
return avg_loss, accuracy
|
||||
|
||||
def _validate_epoch(self, val_loader: DataLoader) -> Tuple[float, float, float]:
|
||||
"""Validate for one epoch"""
|
||||
self.model.eval()
|
||||
total_loss = 0.0
|
||||
correct_predictions = 0
|
||||
total_predictions = 0
|
||||
confidence_errors = []
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in val_loader:
|
||||
features = batch['features'].to(self.model.device)
|
||||
action_labels = batch['action_label'].to(self.model.device).squeeze(-1)
|
||||
confidence_targets = batch['confidence_target'].to(self.model.device).squeeze(-1)
|
||||
|
||||
# Forward pass
|
||||
action_probs, confidence_pred = self.model(features)
|
||||
|
||||
# Calculate losses
|
||||
action_loss = self.model.action_criterion(action_probs, action_labels)
|
||||
confidence_loss = self.model.confidence_criterion(confidence_pred, confidence_targets)
|
||||
total_loss_batch = action_loss + 0.5 * confidence_loss
|
||||
|
||||
# Track metrics
|
||||
total_loss += total_loss_batch.item()
|
||||
predicted_actions = torch.argmax(action_probs, dim=1)
|
||||
correct_predictions += (predicted_actions == action_labels).sum().item()
|
||||
total_predictions += action_labels.size(0)
|
||||
|
||||
# Track confidence accuracy
|
||||
conf_errors = torch.abs(confidence_pred - confidence_targets)
|
||||
confidence_errors.extend(conf_errors.cpu().numpy())
|
||||
|
||||
avg_loss = total_loss / len(val_loader)
|
||||
accuracy = correct_predictions / total_predictions
|
||||
confidence_accuracy = 1.0 - np.mean(confidence_errors) # 1 - mean absolute error
|
||||
|
||||
return avg_loss, accuracy, confidence_accuracy
|
||||
|
||||
def _save_model(self, filename: str):
|
||||
"""Save the model"""
|
||||
save_path = self.save_dir / filename
|
||||
torch.save({
|
||||
'model_state_dict': self.model.state_dict(),
|
||||
'optimizer_state_dict': self.model.optimizer.state_dict(),
|
||||
'config': self.config.cnn,
|
||||
'training_history': self.training_history
|
||||
}, save_path)
|
||||
logger.info(f"Model saved to {save_path}")
|
||||
|
||||
def load_model(self, filename: str) -> bool:
|
||||
"""Load a saved model"""
|
||||
load_path = self.save_dir / filename
|
||||
if not load_path.exists():
|
||||
logger.error(f"Model file not found: {load_path}")
|
||||
return False
|
||||
|
||||
try:
|
||||
checkpoint = torch.load(load_path, map_location=self.model.device)
|
||||
self.model.load_state_dict(checkpoint['model_state_dict'])
|
||||
self.model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||
self.training_history = checkpoint.get('training_history', {})
|
||||
logger.info(f"Model loaded from {load_path}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading model: {e}")
|
||||
return False
|
||||
|
||||
def _generate_training_report(self) -> Dict[str, Any]:
|
||||
"""Generate comprehensive training report"""
|
||||
if not self.training_history['train_loss']:
|
||||
return {'error': 'no_training_data'}
|
||||
|
||||
# Calculate final metrics
|
||||
final_train_loss = self.training_history['train_loss'][-1]
|
||||
final_val_loss = self.training_history['val_loss'][-1]
|
||||
final_train_acc = self.training_history['train_accuracy'][-1]
|
||||
final_val_acc = self.training_history['val_accuracy'][-1]
|
||||
final_conf_acc = self.training_history['confidence_accuracy'][-1]
|
||||
|
||||
# Best metrics
|
||||
best_val_loss = min(self.training_history['val_loss'])
|
||||
best_val_acc = max(self.training_history['val_accuracy'])
|
||||
best_conf_acc = max(self.training_history['confidence_accuracy'])
|
||||
|
||||
report = {
|
||||
'training_completed': True,
|
||||
'epochs_trained': len(self.training_history['train_loss']),
|
||||
'final_metrics': {
|
||||
'train_loss': final_train_loss,
|
||||
'val_loss': final_val_loss,
|
||||
'train_accuracy': final_train_acc,
|
||||
'val_accuracy': final_val_acc,
|
||||
'confidence_accuracy': final_conf_acc
|
||||
},
|
||||
'best_metrics': {
|
||||
'val_loss': best_val_loss,
|
||||
'val_accuracy': best_val_acc,
|
||||
'confidence_accuracy': best_conf_acc
|
||||
},
|
||||
'model_info': {
|
||||
'timeframes': self.model.timeframes,
|
||||
'memory_usage_mb': self.model.get_memory_usage(),
|
||||
'device': str(self.model.device)
|
||||
}
|
||||
}
|
||||
|
||||
# Generate plots
|
||||
self._plot_training_history()
|
||||
|
||||
logger.info("Training completed successfully")
|
||||
logger.info(f"Final validation accuracy: {final_val_acc:.4f}")
|
||||
logger.info(f"Final confidence accuracy: {final_conf_acc:.4f}")
|
||||
|
||||
return report
|
||||
|
||||
def _plot_training_history(self):
|
||||
"""Plot training history"""
|
||||
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
|
||||
fig.suptitle('Enhanced CNN Training History')
|
||||
|
||||
# Loss plot
|
||||
axes[0, 0].plot(self.training_history['train_loss'], label='Train Loss')
|
||||
axes[0, 0].plot(self.training_history['val_loss'], label='Val Loss')
|
||||
axes[0, 0].set_title('Loss')
|
||||
axes[0, 0].set_xlabel('Epoch')
|
||||
axes[0, 0].set_ylabel('Loss')
|
||||
axes[0, 0].legend()
|
||||
|
||||
# Accuracy plot
|
||||
axes[0, 1].plot(self.training_history['train_accuracy'], label='Train Accuracy')
|
||||
axes[0, 1].plot(self.training_history['val_accuracy'], label='Val Accuracy')
|
||||
axes[0, 1].set_title('Action Accuracy')
|
||||
axes[0, 1].set_xlabel('Epoch')
|
||||
axes[0, 1].set_ylabel('Accuracy')
|
||||
axes[0, 1].legend()
|
||||
|
||||
# Confidence accuracy plot
|
||||
axes[1, 0].plot(self.training_history['confidence_accuracy'], label='Confidence Accuracy')
|
||||
axes[1, 0].set_title('Confidence Prediction Accuracy')
|
||||
axes[1, 0].set_xlabel('Epoch')
|
||||
axes[1, 0].set_ylabel('Accuracy')
|
||||
axes[1, 0].legend()
|
||||
|
||||
# Learning curves comparison
|
||||
axes[1, 1].plot(self.training_history['val_loss'], label='Validation Loss')
|
||||
axes[1, 1].plot(self.training_history['confidence_accuracy'], label='Confidence Accuracy')
|
||||
axes[1, 1].set_title('Model Performance Overview')
|
||||
axes[1, 1].set_xlabel('Epoch')
|
||||
axes[1, 1].legend()
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(self.save_dir / 'training_history.png', dpi=300, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
logger.info(f"Training plots saved to {self.save_dir / 'training_history.png'}")
|
||||
|
||||
def get_model(self) -> EnhancedCNNModel:
|
||||
"""Get the trained model"""
|
||||
return self.model
|
625
training/enhanced_rl_trainer.py
Normal file
625
training/enhanced_rl_trainer.py
Normal file
@ -0,0 +1,625 @@
|
||||
"""
|
||||
Enhanced RL Trainer with Market Environment Adaptation
|
||||
|
||||
This trainer implements:
|
||||
1. Continuous learning from orchestrator action evaluations
|
||||
2. Environment adaptation based on market regime changes
|
||||
3. Multi-symbol coordinated RL training
|
||||
4. Experience replay with prioritized sampling
|
||||
5. Dynamic reward shaping based on market conditions
|
||||
"""
|
||||
|
||||
import asyncioimport asyncioimport loggingimport numpy as npimport torchimport torch.nn as nnimport torch.optim as optimfrom collections import deque, namedtupleimport randomfrom datetime import datetime, timedeltafrom typing import Dict, List, Optional, Tuple, Anyimport matplotlib.pyplot as pltfrom pathlib import Path
|
||||
|
||||
from core.config import get_config
|
||||
from core.data_provider import DataProvider
|
||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator, MarketState, TradingAction
|
||||
from models import RLAgentInterface
|
||||
import models
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Experience tuple for replay buffer
|
||||
Experience = namedtuple('Experience', ['state', 'action', 'reward', 'next_state', 'done', 'priority'])
|
||||
|
||||
class PrioritizedReplayBuffer:
|
||||
"""Prioritized experience replay buffer for RL training"""
|
||||
|
||||
def __init__(self, capacity: int = 10000, alpha: float = 0.6):
|
||||
"""
|
||||
Initialize prioritized replay buffer
|
||||
|
||||
Args:
|
||||
capacity: Maximum number of experiences to store
|
||||
alpha: Priority exponent (0 = uniform, 1 = fully prioritized)
|
||||
"""
|
||||
self.capacity = capacity
|
||||
self.alpha = alpha
|
||||
self.buffer = []
|
||||
self.priorities = np.zeros(capacity, dtype=np.float32)
|
||||
self.position = 0
|
||||
self.size = 0
|
||||
|
||||
def add(self, experience: Experience):
|
||||
"""Add experience to buffer with priority"""
|
||||
max_priority = self.priorities[:self.size].max() if self.size > 0 else 1.0
|
||||
|
||||
if self.size < self.capacity:
|
||||
self.buffer.append(experience)
|
||||
self.size += 1
|
||||
else:
|
||||
self.buffer[self.position] = experience
|
||||
|
||||
self.priorities[self.position] = max_priority
|
||||
self.position = (self.position + 1) % self.capacity
|
||||
|
||||
def sample(self, batch_size: int, beta: float = 0.4) -> Tuple[List[Experience], np.ndarray, np.ndarray]:
|
||||
"""Sample batch with prioritized sampling"""
|
||||
if self.size == 0:
|
||||
return [], np.array([]), np.array([])
|
||||
|
||||
# Calculate sampling probabilities
|
||||
priorities = self.priorities[:self.size] ** self.alpha
|
||||
probabilities = priorities / priorities.sum()
|
||||
|
||||
# Sample indices
|
||||
indices = np.random.choice(self.size, batch_size, p=probabilities)
|
||||
experiences = [self.buffer[i] for i in indices]
|
||||
|
||||
# Calculate importance sampling weights
|
||||
weights = (self.size * probabilities[indices]) ** (-beta)
|
||||
weights = weights / weights.max() # Normalize
|
||||
|
||||
return experiences, indices, weights
|
||||
|
||||
def update_priorities(self, indices: np.ndarray, priorities: np.ndarray):
|
||||
"""Update priorities for sampled experiences"""
|
||||
for idx, priority in zip(indices, priorities):
|
||||
self.priorities[idx] = priority + 1e-6 # Small epsilon to avoid zero priority
|
||||
|
||||
def __len__(self):
|
||||
return self.size
|
||||
|
||||
class EnhancedDQNAgent(nn.Module, RLAgentInterface):
|
||||
"""Enhanced DQN agent with market environment adaptation"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
nn.Module.__init__(self)
|
||||
RLAgentInterface.__init__(self, config)
|
||||
|
||||
# Network architecture
|
||||
self.state_size = config.get('state_size', 100)
|
||||
self.action_space = config.get('action_space', 3)
|
||||
self.hidden_size = config.get('hidden_size', 256)
|
||||
|
||||
# Build networks
|
||||
self._build_networks()
|
||||
|
||||
# Training parameters
|
||||
self.learning_rate = config.get('learning_rate', 0.0001)
|
||||
self.gamma = config.get('gamma', 0.99)
|
||||
self.epsilon = config.get('epsilon', 1.0)
|
||||
self.epsilon_decay = config.get('epsilon_decay', 0.995)
|
||||
self.epsilon_min = config.get('epsilon_min', 0.01)
|
||||
self.target_update_freq = config.get('target_update_freq', 1000)
|
||||
|
||||
# Initialize device and optimizer
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
self.to(self.device)
|
||||
self.optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
|
||||
|
||||
# Experience replay
|
||||
self.replay_buffer = PrioritizedReplayBuffer(config.get('buffer_size', 10000))
|
||||
self.batch_size = config.get('batch_size', 64)
|
||||
|
||||
# Market adaptation
|
||||
self.market_regime_weights = {
|
||||
'trending': 1.2, # Higher confidence in trending markets
|
||||
'ranging': 0.8, # Lower confidence in ranging markets
|
||||
'volatile': 0.6 # Much lower confidence in volatile markets
|
||||
}
|
||||
|
||||
# Training statistics
|
||||
self.training_steps = 0
|
||||
self.losses = []
|
||||
self.rewards = []
|
||||
self.epsilon_history = []
|
||||
|
||||
logger.info(f"Enhanced DQN agent initialized with state size: {self.state_size}")
|
||||
|
||||
def _build_networks(self):
|
||||
"""Build main and target networks"""
|
||||
# Main network
|
||||
self.main_network = nn.Sequential(
|
||||
nn.Linear(self.state_size, self.hidden_size),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.3),
|
||||
nn.Linear(self.hidden_size, self.hidden_size),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.3),
|
||||
nn.Linear(self.hidden_size, 128),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.2)
|
||||
)
|
||||
|
||||
# Dueling network heads
|
||||
self.value_head = nn.Linear(128, 1)
|
||||
self.advantage_head = nn.Linear(128, self.action_space)
|
||||
|
||||
# Target network (copy of main network)
|
||||
self.target_network = nn.Sequential(
|
||||
nn.Linear(self.state_size, self.hidden_size),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.3),
|
||||
nn.Linear(self.hidden_size, self.hidden_size),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.3),
|
||||
nn.Linear(self.hidden_size, 128),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.2)
|
||||
)
|
||||
|
||||
self.target_value_head = nn.Linear(128, 1)
|
||||
self.target_advantage_head = nn.Linear(128, self.action_space)
|
||||
|
||||
# Initialize target network with same weights
|
||||
self._update_target_network()
|
||||
|
||||
def forward(self, state, target: bool = False):
|
||||
"""Forward pass through the network"""
|
||||
if target:
|
||||
features = self.target_network(state)
|
||||
value = self.target_value_head(features)
|
||||
advantage = self.target_advantage_head(features)
|
||||
else:
|
||||
features = self.main_network(state)
|
||||
value = self.value_head(features)
|
||||
advantage = self.advantage_head(features)
|
||||
|
||||
# Dueling architecture: Q(s,a) = V(s) + A(s,a) - mean(A(s,a))
|
||||
q_values = value + (advantage - advantage.mean(dim=1, keepdim=True))
|
||||
|
||||
return q_values
|
||||
|
||||
def act(self, state: np.ndarray) -> int:
|
||||
"""Choose action using epsilon-greedy policy"""
|
||||
if random.random() < self.epsilon:
|
||||
return random.randint(0, self.action_space - 1)
|
||||
|
||||
with torch.no_grad():
|
||||
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
|
||||
q_values = self.forward(state_tensor)
|
||||
return q_values.argmax().item()
|
||||
|
||||
def act_with_confidence(self, state: np.ndarray, market_regime: str = 'trending') -> Tuple[int, float]:
|
||||
"""Choose action with confidence score adapted to market regime"""
|
||||
with torch.no_grad():
|
||||
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
|
||||
q_values = self.forward(state_tensor)
|
||||
|
||||
# Convert Q-values to probabilities
|
||||
action_probs = torch.softmax(q_values, dim=1)
|
||||
action = q_values.argmax().item()
|
||||
base_confidence = action_probs[0, action].item()
|
||||
|
||||
# Adapt confidence based on market regime
|
||||
regime_weight = self.market_regime_weights.get(market_regime, 1.0)
|
||||
adapted_confidence = min(base_confidence * regime_weight, 1.0)
|
||||
|
||||
return action, adapted_confidence
|
||||
|
||||
def remember(self, state: np.ndarray, action: int, reward: float,
|
||||
next_state: np.ndarray, done: bool):
|
||||
"""Store experience in replay buffer"""
|
||||
# Calculate TD error for priority
|
||||
with torch.no_grad():
|
||||
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
|
||||
next_state_tensor = torch.FloatTensor(next_state).unsqueeze(0).to(self.device)
|
||||
|
||||
current_q = self.forward(state_tensor)[0, action]
|
||||
next_q = self.forward(next_state_tensor, target=True).max(1)[0]
|
||||
target_q = reward + (self.gamma * next_q * (1 - done))
|
||||
|
||||
td_error = abs(current_q.item() - target_q.item())
|
||||
|
||||
experience = Experience(state, action, reward, next_state, done, td_error)
|
||||
self.replay_buffer.add(experience)
|
||||
|
||||
def replay(self) -> Optional[float]:
|
||||
"""Train the network on a batch of experiences"""
|
||||
if len(self.replay_buffer) < self.batch_size:
|
||||
return None
|
||||
|
||||
# Sample batch
|
||||
experiences, indices, weights = self.replay_buffer.sample(self.batch_size)
|
||||
|
||||
if not experiences:
|
||||
return None
|
||||
|
||||
# Convert to tensors
|
||||
states = torch.FloatTensor([e.state for e in experiences]).to(self.device)
|
||||
actions = torch.LongTensor([e.action for e in experiences]).to(self.device)
|
||||
rewards = torch.FloatTensor([e.reward for e in experiences]).to(self.device)
|
||||
next_states = torch.FloatTensor([e.next_state for e in experiences]).to(self.device)
|
||||
dones = torch.BoolTensor([e.done for e in experiences]).to(self.device)
|
||||
weights_tensor = torch.FloatTensor(weights).to(self.device)
|
||||
|
||||
# Current Q-values
|
||||
current_q_values = self.forward(states).gather(1, actions.unsqueeze(1))
|
||||
|
||||
# Target Q-values (Double DQN)
|
||||
with torch.no_grad():
|
||||
# Use main network to select actions
|
||||
next_actions = self.forward(next_states).argmax(1)
|
||||
# Use target network to evaluate actions
|
||||
next_q_values = self.forward(next_states, target=True).gather(1, next_actions.unsqueeze(1))
|
||||
target_q_values = rewards.unsqueeze(1) + (self.gamma * next_q_values * ~dones.unsqueeze(1))
|
||||
|
||||
# Calculate weighted loss
|
||||
td_errors = target_q_values - current_q_values
|
||||
loss = (weights_tensor * (td_errors ** 2)).mean()
|
||||
|
||||
# Optimize
|
||||
self.optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
|
||||
self.optimizer.step()
|
||||
|
||||
# Update priorities
|
||||
new_priorities = torch.abs(td_errors).detach().cpu().numpy().flatten()
|
||||
self.replay_buffer.update_priorities(indices, new_priorities)
|
||||
|
||||
# Update target network
|
||||
self.training_steps += 1
|
||||
if self.training_steps % self.target_update_freq == 0:
|
||||
self._update_target_network()
|
||||
|
||||
# Decay epsilon
|
||||
if self.epsilon > self.epsilon_min:
|
||||
self.epsilon *= self.epsilon_decay
|
||||
|
||||
# Track statistics
|
||||
self.losses.append(loss.item())
|
||||
self.epsilon_history.append(self.epsilon)
|
||||
|
||||
return loss.item()
|
||||
|
||||
def _update_target_network(self):
|
||||
"""Update target network with main network weights"""
|
||||
self.target_network.load_state_dict(self.main_network.state_dict())
|
||||
self.target_value_head.load_state_dict(self.value_head.state_dict())
|
||||
self.target_advantage_head.load_state_dict(self.advantage_head.state_dict())
|
||||
|
||||
def predict(self, features: np.ndarray) -> Tuple[np.ndarray, float]: """Predict action probabilities and confidence (required by ModelInterface)""" action, confidence = self.act_with_confidence(features) # Convert action to probabilities action_probs = np.zeros(self.action_space) action_probs[action] = 1.0 return action_probs, confidence def get_memory_usage(self) -> int: """Get memory usage in MB""" if torch.cuda.is_available(): return torch.cuda.memory_allocated(self.device) // (1024 * 1024) else: param_count = sum(p.numel() for p in self.parameters()) buffer_size = len(self.replay_buffer) * self.state_size * 4 # Rough estimate return (param_count * 4 + buffer_size) // (1024 * 1024)
|
||||
|
||||
class EnhancedRLTrainer:
|
||||
"""Enhanced RL trainer with continuous learning from market feedback"""
|
||||
|
||||
def __init__(self, config: Optional[Dict] = None, orchestrator: EnhancedTradingOrchestrator = None):
|
||||
"""Initialize the enhanced RL trainer"""
|
||||
self.config = config or get_config()
|
||||
self.orchestrator = orchestrator
|
||||
self.data_provider = DataProvider(self.config)
|
||||
|
||||
# Create RL agents for each symbol
|
||||
self.agents = {}
|
||||
for symbol in self.config.symbols:
|
||||
agent_config = self.config.rl.copy()
|
||||
agent_config['name'] = f'RL_{symbol}'
|
||||
self.agents[symbol] = EnhancedDQNAgent(agent_config)
|
||||
|
||||
# Training parameters
|
||||
self.training_interval = 3600 # Train every hour
|
||||
self.evaluation_window = 24 * 3600 # Evaluate actions after 24 hours
|
||||
self.min_experiences = 100 # Minimum experiences before training
|
||||
|
||||
# Performance tracking
|
||||
self.performance_history = {symbol: [] for symbol in self.config.symbols}
|
||||
self.training_metrics = {
|
||||
'total_episodes': 0,
|
||||
'total_rewards': {symbol: [] for symbol in self.config.symbols},
|
||||
'losses': {symbol: [] for symbol in self.config.symbols},
|
||||
'epsilon_values': {symbol: [] for symbol in self.config.symbols}
|
||||
}
|
||||
|
||||
# Create save directory models_path = self.config.rl.get('model_dir', "models/enhanced_rl") self.save_dir = Path(models_path) self.save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logger.info(f"Enhanced RL trainer initialized for symbols: {self.config.symbols}")
|
||||
|
||||
async def continuous_learning_loop(self):
|
||||
"""Main continuous learning loop"""
|
||||
logger.info("Starting continuous RL learning loop")
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Train agents with recent experiences
|
||||
await self._train_all_agents()
|
||||
|
||||
# Evaluate recent actions
|
||||
if self.orchestrator:
|
||||
await self.orchestrator.evaluate_actions_with_rl()
|
||||
|
||||
# Adapt to market regime changes
|
||||
await self._adapt_to_market_changes()
|
||||
|
||||
# Update performance metrics
|
||||
self._update_performance_metrics()
|
||||
|
||||
# Save models periodically
|
||||
if self.training_metrics['total_episodes'] % 100 == 0:
|
||||
self._save_all_models()
|
||||
|
||||
# Wait before next training cycle
|
||||
await asyncio.sleep(self.training_interval)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in continuous learning loop: {e}")
|
||||
await asyncio.sleep(60) # Wait 1 minute on error
|
||||
|
||||
async def _train_all_agents(self):
|
||||
"""Train all RL agents with their experiences"""
|
||||
for symbol, agent in self.agents.items():
|
||||
try:
|
||||
if len(agent.replay_buffer) >= self.min_experiences:
|
||||
# Train for multiple steps
|
||||
losses = []
|
||||
for _ in range(10): # Train 10 steps per cycle
|
||||
loss = agent.replay()
|
||||
if loss is not None:
|
||||
losses.append(loss)
|
||||
|
||||
if losses:
|
||||
avg_loss = np.mean(losses)
|
||||
self.training_metrics['losses'][symbol].append(avg_loss)
|
||||
self.training_metrics['epsilon_values'][symbol].append(agent.epsilon)
|
||||
|
||||
logger.info(f"Trained {symbol} RL agent: Loss={avg_loss:.4f}, Epsilon={agent.epsilon:.4f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error training {symbol} agent: {e}")
|
||||
|
||||
async def _adapt_to_market_changes(self):
|
||||
"""Adapt agents to market regime changes"""
|
||||
if not self.orchestrator:
|
||||
return
|
||||
|
||||
for symbol in self.config.symbols:
|
||||
try:
|
||||
# Get recent market states
|
||||
recent_states = list(self.orchestrator.market_states[symbol])[-10:] # Last 10 states
|
||||
|
||||
if len(recent_states) < 5:
|
||||
continue
|
||||
|
||||
# Analyze regime stability
|
||||
regimes = [state.market_regime for state in recent_states]
|
||||
regime_stability = len(set(regimes)) / len(regimes) # Lower = more stable
|
||||
|
||||
# Adjust learning parameters based on stability
|
||||
agent = self.agents[symbol]
|
||||
if regime_stability < 0.3: # Stable regime
|
||||
agent.epsilon *= 0.99 # Faster epsilon decay
|
||||
elif regime_stability > 0.7: # Unstable regime
|
||||
agent.epsilon = min(agent.epsilon * 1.01, 0.5) # Increase exploration
|
||||
|
||||
logger.debug(f"{symbol} regime stability: {regime_stability:.3f}, epsilon: {agent.epsilon:.3f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adapting {symbol} to market changes: {e}")
|
||||
|
||||
def add_trading_experience(self, symbol: str, action: TradingAction,
|
||||
initial_state: MarketState, final_state: MarketState,
|
||||
reward: float):
|
||||
"""Add trading experience to the appropriate agent"""
|
||||
if symbol not in self.agents:
|
||||
logger.warning(f"No agent for symbol {symbol}")
|
||||
return
|
||||
|
||||
try:
|
||||
# Convert market states to RL state vectors
|
||||
initial_rl_state = self._market_state_to_rl_state(initial_state)
|
||||
final_rl_state = self._market_state_to_rl_state(final_state)
|
||||
|
||||
# Convert action to RL action index
|
||||
action_mapping = {'SELL': 0, 'HOLD': 1, 'BUY': 2}
|
||||
action_idx = action_mapping.get(action.action, 1)
|
||||
|
||||
# Store experience
|
||||
agent = self.agents[symbol]
|
||||
agent.remember(
|
||||
state=initial_rl_state,
|
||||
action=action_idx,
|
||||
reward=reward,
|
||||
next_state=final_rl_state,
|
||||
done=False
|
||||
)
|
||||
|
||||
# Track reward
|
||||
self.training_metrics['total_rewards'][symbol].append(reward)
|
||||
|
||||
logger.debug(f"Added experience for {symbol}: action={action.action}, reward={reward:.4f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding experience for {symbol}: {e}")
|
||||
|
||||
def _market_state_to_rl_state(self, market_state: MarketState) -> np.ndarray:
|
||||
"""Convert market state to RL state vector"""
|
||||
if hasattr(self.orchestrator, '_market_state_to_rl_state'):
|
||||
return self.orchestrator._market_state_to_rl_state(market_state)
|
||||
|
||||
# Fallback implementation
|
||||
state_components = [
|
||||
market_state.volatility,
|
||||
market_state.volume,
|
||||
market_state.trend_strength
|
||||
]
|
||||
|
||||
# Add price features
|
||||
for timeframe in sorted(market_state.prices.keys()):
|
||||
state_components.append(market_state.prices[timeframe])
|
||||
|
||||
# Pad or truncate to expected state size
|
||||
expected_size = self.config.rl.get('state_size', 100)
|
||||
if len(state_components) < expected_size:
|
||||
state_components.extend([0.0] * (expected_size - len(state_components)))
|
||||
else:
|
||||
state_components = state_components[:expected_size]
|
||||
|
||||
return np.array(state_components, dtype=np.float32)
|
||||
|
||||
def _update_performance_metrics(self):
|
||||
"""Update performance tracking metrics"""
|
||||
self.training_metrics['total_episodes'] += 1
|
||||
|
||||
# Calculate recent performance for each agent
|
||||
for symbol, agent in self.agents.items():
|
||||
recent_rewards = self.training_metrics['total_rewards'][symbol][-100:] # Last 100 rewards
|
||||
if recent_rewards:
|
||||
avg_reward = np.mean(recent_rewards)
|
||||
self.performance_history[symbol].append({
|
||||
'timestamp': datetime.now(),
|
||||
'avg_reward': avg_reward,
|
||||
'epsilon': agent.epsilon,
|
||||
'experiences': len(agent.replay_buffer)
|
||||
})
|
||||
|
||||
def _save_all_models(self):
|
||||
"""Save all RL models"""
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
|
||||
for symbol, agent in self.agents.items():
|
||||
filename = f"rl_agent_{symbol}_{timestamp}.pt"
|
||||
filepath = self.save_dir / filename
|
||||
|
||||
torch.save({
|
||||
'model_state_dict': agent.state_dict(),
|
||||
'optimizer_state_dict': agent.optimizer.state_dict(),
|
||||
'config': self.config.rl,
|
||||
'training_metrics': self.training_metrics,
|
||||
'symbol': symbol,
|
||||
'epsilon': agent.epsilon,
|
||||
'training_steps': agent.training_steps
|
||||
}, filepath)
|
||||
|
||||
logger.info(f"Saved {symbol} RL agent to {filepath}")
|
||||
|
||||
def load_models(self, timestamp: str = None):
|
||||
"""Load RL models from files"""
|
||||
if timestamp is None:
|
||||
# Find most recent models
|
||||
model_files = list(self.save_dir.glob("rl_agent_*.pt"))
|
||||
if not model_files:
|
||||
logger.warning("No saved RL models found")
|
||||
return False
|
||||
|
||||
# Group by timestamp and get most recent
|
||||
timestamps = set(f.stem.split('_')[-2] + '_' + f.stem.split('_')[-1] for f in model_files)
|
||||
timestamp = max(timestamps)
|
||||
|
||||
loaded_count = 0
|
||||
for symbol in self.config.symbols:
|
||||
filename = f"rl_agent_{symbol}_{timestamp}.pt"
|
||||
filepath = self.save_dir / filename
|
||||
|
||||
if filepath.exists():
|
||||
try:
|
||||
checkpoint = torch.load(filepath, map_location=self.agents[symbol].device)
|
||||
self.agents[symbol].load_state_dict(checkpoint['model_state_dict'])
|
||||
self.agents[symbol].optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||
self.agents[symbol].epsilon = checkpoint.get('epsilon', 0.1)
|
||||
self.agents[symbol].training_steps = checkpoint.get('training_steps', 0)
|
||||
|
||||
logger.info(f"Loaded {symbol} RL agent from {filepath}")
|
||||
loaded_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading {symbol} RL agent: {e}")
|
||||
|
||||
return loaded_count > 0
|
||||
|
||||
def get_performance_report(self) -> Dict[str, Any]:
|
||||
"""Generate performance report for all agents"""
|
||||
report = {
|
||||
'total_episodes': self.training_metrics['total_episodes'],
|
||||
'agents': {}
|
||||
}
|
||||
|
||||
for symbol, agent in self.agents.items():
|
||||
recent_rewards = self.training_metrics['total_rewards'][symbol][-100:]
|
||||
recent_losses = self.training_metrics['losses'][symbol][-10:]
|
||||
|
||||
agent_report = {
|
||||
'symbol': symbol,
|
||||
'epsilon': agent.epsilon,
|
||||
'training_steps': agent.training_steps,
|
||||
'experiences_stored': len(agent.replay_buffer),
|
||||
'memory_usage_mb': agent.get_memory_usage(),
|
||||
'avg_recent_reward': np.mean(recent_rewards) if recent_rewards else 0.0,
|
||||
'avg_recent_loss': np.mean(recent_losses) if recent_losses else 0.0,
|
||||
'total_rewards': len(self.training_metrics['total_rewards'][symbol])
|
||||
}
|
||||
|
||||
report['agents'][symbol] = agent_report
|
||||
|
||||
return report
|
||||
|
||||
def plot_training_metrics(self):
|
||||
"""Plot training metrics for all agents"""
|
||||
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
|
||||
fig.suptitle('Enhanced RL Training Metrics')
|
||||
|
||||
symbols = list(self.agents.keys())
|
||||
colors = ['blue', 'red', 'green', 'orange'][:len(symbols)]
|
||||
|
||||
# Rewards plot
|
||||
for i, symbol in enumerate(symbols):
|
||||
rewards = self.training_metrics['total_rewards'][symbol]
|
||||
if rewards:
|
||||
# Moving average of rewards
|
||||
window = min(100, len(rewards))
|
||||
if len(rewards) >= window:
|
||||
moving_avg = np.convolve(rewards, np.ones(window)/window, mode='valid')
|
||||
axes[0, 0].plot(moving_avg, label=f'{symbol}', color=colors[i])
|
||||
|
||||
axes[0, 0].set_title('Average Rewards (Moving Average)')
|
||||
axes[0, 0].set_xlabel('Episodes')
|
||||
axes[0, 0].set_ylabel('Reward')
|
||||
axes[0, 0].legend()
|
||||
|
||||
# Losses plot
|
||||
for i, symbol in enumerate(symbols):
|
||||
losses = self.training_metrics['losses'][symbol]
|
||||
if losses:
|
||||
axes[0, 1].plot(losses, label=f'{symbol}', color=colors[i])
|
||||
|
||||
axes[0, 1].set_title('Training Losses')
|
||||
axes[0, 1].set_xlabel('Training Steps')
|
||||
axes[0, 1].set_ylabel('Loss')
|
||||
axes[0, 1].legend()
|
||||
|
||||
# Epsilon values
|
||||
for i, symbol in enumerate(symbols):
|
||||
epsilon_values = self.training_metrics['epsilon_values'][symbol]
|
||||
if epsilon_values:
|
||||
axes[1, 0].plot(epsilon_values, label=f'{symbol}', color=colors[i])
|
||||
|
||||
axes[1, 0].set_title('Exploration Rate (Epsilon)')
|
||||
axes[1, 0].set_xlabel('Training Steps')
|
||||
axes[1, 0].set_ylabel('Epsilon')
|
||||
axes[1, 0].legend()
|
||||
|
||||
# Experience buffer sizes
|
||||
buffer_sizes = [len(agent.replay_buffer) for agent in self.agents.values()]
|
||||
axes[1, 1].bar(symbols, buffer_sizes, color=colors[:len(symbols)])
|
||||
axes[1, 1].set_title('Experience Buffer Sizes')
|
||||
axes[1, 1].set_ylabel('Number of Experiences')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(self.save_dir / 'rl_training_metrics.png', dpi=300, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
logger.info(f"RL training plots saved to {self.save_dir / 'rl_training_metrics.png'}")
|
||||
|
||||
def get_agents(self) -> Dict[str, EnhancedDQNAgent]:
|
||||
"""Get all RL agents"""
|
||||
return self.agents
|
Reference in New Issue
Block a user