import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
import logging

# Configure logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ResidualBlock(nn.Module):
    """
    Residual block with pre-activation (BatchNorm -> ReLU -> Conv)
    """
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.bn1 = nn.BatchNorm1d(in_channels)
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)

        # Shortcut connection to match dimensions
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        out = F.relu(self.bn1(x))
        shortcut = self.shortcut(out)
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))
        out += shortcut
        return out

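# A minimal shape sketch for ResidualBlock (illustrative only, not part of the
# module's API): inputs are [batch, channels, length]; with stride=2 the length
# halves and the 1x1 shortcut conv matches the channel count.
#
#   block = ResidualBlock(64, 128, stride=2)
#   y = block(torch.randn(8, 64, 100))  # -> [8, 128, 50]
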
class SelfAttention(nn.Module):
    """
    Self-attention mechanism for sequential data
    """
    def __init__(self, dim):
        super(SelfAttention, self).__init__()
        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)
        # Plain float scale so it never sits on the wrong device
        self.scale = dim ** 0.5

    def forward(self, x):
        # x shape: [batch_size, seq_len, dim]
        batch_size, seq_len, dim = x.size()

        q = self.query(x)  # [batch_size, seq_len, dim]
        k = self.key(x)    # [batch_size, seq_len, dim]
        v = self.value(x)  # [batch_size, seq_len, dim]

        # Scaled dot-product attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale  # [batch_size, seq_len, seq_len]

        # Softmax over the key dimension gives the attention weights
        attention = F.softmax(scores, dim=-1)  # [batch_size, seq_len, seq_len]

        # Weighted sum of values
        out = torch.matmul(attention, v)  # [batch_size, seq_len, dim]

        return out, attention

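# A minimal usage sketch for SelfAttention (illustrative values): the returned
# weights form a [batch, seq_len, seq_len] matrix whose rows sum to 1.
#
#   attn = SelfAttention(dim=32)
#   out, weights = attn(torch.randn(4, 16, 32))
#   # out: [4, 16, 32]; weights: [4, 16, 16]
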
class EnhancedCNN(nn.Module):
    """
    Enhanced CNN model with residual connections and attention mechanisms
    for improved trading decision making
    """
    def __init__(self, input_shape, n_actions, confidence_threshold=0.5):
        super(EnhancedCNN, self).__init__()

        # Store dimensions
        self.input_shape = input_shape
        self.n_actions = n_actions
        self.confidence_threshold = confidence_threshold

        # Calculate input dimensions
        if isinstance(input_shape, (list, tuple)):
            if len(input_shape) == 3:  # [channels, height, width]
                self.channels, self.height, self.width = input_shape
                self.feature_dim = self.height * self.width
            elif len(input_shape) == 2:  # [timeframes, features]
                self.channels = input_shape[0]
                self.features = input_shape[1]
                self.feature_dim = self.features * self.channels
            elif len(input_shape) == 1:  # [features]
                self.channels = 1
                self.features = input_shape[0]
                self.feature_dim = self.features
            else:
                raise ValueError(f"Unsupported input shape: {input_shape}")
        else:  # single integer
            self.channels = 1
            self.features = input_shape
            self.feature_dim = input_shape

        # Build network
        self._build_network()

        # Initialize device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.to(self.device)

        logger.info(f"EnhancedCNN initialized with input shape: {input_shape}, actions: {n_actions}")

    def _build_network(self):
        """Build the scaled-up network, sized for a ~4GB VRAM budget"""

        # Large architecture targeting ~4GB VRAM (up to ~50M parameters)
        if self.channels > 1:
            # Deep convolutional backbone with residual blocks
            self.conv_layers = nn.Sequential(
                # Initial wide conv block
                nn.Conv1d(self.channels, 256, kernel_size=7, padding=3),
                nn.BatchNorm1d(256),
                nn.ReLU(),
                nn.Dropout(0.1),

                # First residual stage - 256 -> 512 channels
                ResidualBlock(256, 512),
                ResidualBlock(512, 512),
                ResidualBlock(512, 512),
                nn.MaxPool1d(kernel_size=2, stride=2),
                nn.Dropout(0.2),

                # Second residual stage - 512 -> 1024 channels
                ResidualBlock(512, 1024),
                ResidualBlock(1024, 1024),
                ResidualBlock(1024, 1024),
                nn.MaxPool1d(kernel_size=2, stride=2),
                nn.Dropout(0.25),

                # Third residual stage - 1024 -> 1536 channels
                ResidualBlock(1024, 1536),
                ResidualBlock(1536, 1536),
                ResidualBlock(1536, 1536),
                nn.MaxPool1d(kernel_size=2, stride=2),
                nn.Dropout(0.3),

                # Fourth residual stage - 1536 -> 2048 channels
                ResidualBlock(1536, 2048),
                ResidualBlock(2048, 2048),
                ResidualBlock(2048, 2048),
                nn.AdaptiveAvgPool1d(1)  # Global average pooling
            )
            # Feature dimension after the conv backbone
            self.conv_features = 2048
        else:
            # 1D vectors skip the conv backbone and use dense layers only
            self.conv_layers = None
            self.conv_features = 0

        # Fully connected feature extraction layers
        if self.conv_layers is None:
            # For 1D inputs - dense feature extraction
            self.fc1 = nn.Linear(self.feature_dim, 2048)
            self.features_dim = 2048
        else:
            # For data processed by the conv backbone
            self.fc1 = nn.Linear(self.conv_features, 2048)
            self.features_dim = 2048

        # Common feature extraction trunk
        self.fc_layers = nn.Sequential(
            self.fc1,
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(2048, 2048),  # Keep full width
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(2048, 1536),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1536, 1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 768),  # Final feature representation
            nn.ReLU()
        )

        # Multiple attention mechanisms for different aspects of the input
        self.price_attention = SelfAttention(768)
        self.volume_attention = SelfAttention(768)
        self.trend_attention = SelfAttention(768)
        self.volatility_attention = SelfAttention(768)

        # Attention fusion layer
        self.attention_fusion = nn.Sequential(
            nn.Linear(768 * 4, 1024),  # Combine all attention outputs
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 768)
        )

        # Dueling architecture: separate advantage and value streams
        self.advantage_stream = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, self.n_actions)
        )

        self.value_stream = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

        # Extrema detection head
        self.extrema_head = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 3)  # 0=bottom, 1=top, 2=neither
        )

        # Multi-timeframe price direction heads
        self.price_pred_immediate = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 3)  # Up, Down, Sideways
        )

        self.price_pred_midterm = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 3)  # Up, Down, Sideways
        )

        self.price_pred_longterm = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 3)  # Up, Down, Sideways
        )

        # Magnitude head for % change predictions
        self.price_pred_value = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 8)  # Granular % change predictions across timeframes
        )

        # Additional specialized prediction heads for better accuracy
        # Volatility prediction head
        self.volatility_head = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 5)  # Very low, low, medium, high, very high volatility
        )

        # Support/Resistance level detection head
        self.support_resistance_head = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 6)  # Strong support, weak support, neutral, weak resistance, strong resistance, breakout
        )

        # Market regime classification head
        self.market_regime_head = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 7)  # Bull trend, bear trend, sideways, volatile up, volatile down, accumulation, distribution
        )

        # Risk assessment head
        self.risk_head = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 4)  # Low risk, medium risk, high risk, extreme risk
        )

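    # Worked example of the dueling aggregation used in forward() (numbers are
    # illustrative): with value V = [[1.0]] and advantages A = [[0.5, -0.5, 0.0]],
    # mean(A) = 0.0, so Q = V + A - mean(A) = [[1.5, 0.5, 1.0]]. Subtracting the
    # mean advantage keeps V and A identifiable while preserving action ordering.
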
    def _check_rebuild_network(self, features):
        """Check if network needs to be rebuilt for different feature dimensions"""
        if features != self.feature_dim:
            logger.info(f"Rebuilding network for new feature dimension: {features} (was {self.feature_dim})")
            self.feature_dim = features
            # Note: rebuilding creates freshly initialized layers, discarding any trained weights
            self._build_network()
            # Move to device after rebuilding
            self.to(self.device)
            return True
        return False

    def forward(self, x):
        """Forward pass through the network"""
        batch_size = x.size(0)

        # Process different input shapes
        if len(x.shape) > 2:
            # Handle 3D input [batch, timeframes, features]
            if self.conv_layers is not None:
                # Conv1d expects [batch, channels, length]; the timeframe axis
                # already acts as the channel dimension, so no permute is needed
                x_reshaped = x

                # Check if the feature dimension has changed and rebuild if necessary
                if x_reshaped.size(1) * x_reshaped.size(2) != self.feature_dim:
                    total_features = x_reshaped.size(1) * x_reshaped.size(2)
                    self._check_rebuild_network(total_features)

                # Apply the convolutional backbone
                x_conv = self.conv_layers(x_reshaped)
                # Flatten: [batch, channels, 1] -> [batch, channels]
                x_flat = x_conv.view(batch_size, -1)
            else:
                # If no conv layers, just flatten
                x_flat = x.view(batch_size, -1)
        else:
            # For 2D input [batch, features]
            x_flat = x

        # Check if dimensions have changed; skip when the conv backbone was used,
        # since x_flat is then conv_features wide rather than feature_dim, and
        # rebuilding here would reinitialize the network on every forward pass
        if self.conv_layers is None and x_flat.size(1) != self.feature_dim:
            self._check_rebuild_network(x_flat.size(1))

        # Apply FC layers to get base features
        features = self.fc_layers(x_flat)  # [batch, 768]

        # Apply multiple specialized attention mechanisms
        features_3d = features.unsqueeze(1)  # [batch, 1, 768]

        # Get attention-refined features for different aspects
        price_features, _ = self.price_attention(features_3d)
        price_features = price_features.squeeze(1)  # [batch, 768]

        volume_features, _ = self.volume_attention(features_3d)
        volume_features = volume_features.squeeze(1)  # [batch, 768]

        trend_features, _ = self.trend_attention(features_3d)
        trend_features = trend_features.squeeze(1)  # [batch, 768]

        volatility_features, _ = self.volatility_attention(features_3d)
        volatility_features = volatility_features.squeeze(1)  # [batch, 768]

        # Fuse all attention outputs
        combined_attention = torch.cat([
            price_features, volume_features,
            trend_features, volatility_features
        ], dim=1)  # [batch, 768*4]

        # Apply attention fusion to get final refined features
        features_refined = self.attention_fusion(combined_attention)  # [batch, 768]

        # Calculate advantage and value (Dueling DQN architecture)
        advantage = self.advantage_stream(features_refined)
        value = self.value_stream(features_refined)

        # Combine for Q-values (dueling aggregation)
        q_values = value + advantage - advantage.mean(dim=1, keepdim=True)

        # Extrema predictions (bottom/top/neither detection)
        extrema_pred = self.extrema_head(features_refined)

        # Multi-timeframe price movement predictions
        price_immediate = self.price_pred_immediate(features_refined)
        price_midterm = self.price_pred_midterm(features_refined)
        price_longterm = self.price_pred_longterm(features_refined)
        price_values = self.price_pred_value(features_refined)

        # Additional specialized predictions
        volatility_pred = self.volatility_head(features_refined)
        support_resistance_pred = self.support_resistance_head(features_refined)
        market_regime_pred = self.market_regime_head(features_refined)
        risk_pred = self.risk_head(features_refined)

        # Package all price predictions
        price_predictions = {
            'immediate': price_immediate,
            'midterm': price_midterm,
            'longterm': price_longterm,
            'values': price_values
        }

        # Package additional predictions for enhanced decision making
        advanced_predictions = {
            'volatility': volatility_pred,
            'support_resistance': support_resistance_pred,
            'market_regime': market_regime_pred,
            'risk_assessment': risk_pred
        }

        return q_values, extrema_pred, price_predictions, features_refined, advanced_predictions

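    # Forward-pass sketch (illustrative; the [5, 100] input shape is an assumption):
    #
    #   model = EnhancedCNN(input_shape=[5, 100], n_actions=3)
    #   x = torch.randn(2, 5, 100, device=model.device)
    #   q, extrema, price_preds, feats, advanced = model(x)
    #   # q: [2, 3], extrema: [2, 3], feats: [2, 768]
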
    def act(self, state, explore=True):
        """Action selection with optional epsilon-greedy exploration"""
        if explore and np.random.random() < 0.1:  # 10% random exploration
            return np.random.choice(self.n_actions)

        self.eval()
        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)

        with torch.no_grad():
            q_values, extrema_pred, price_predictions, features, advanced_predictions = self(state_tensor)

        # Apply softmax to get action probabilities
        action_probs = torch.softmax(q_values, dim=1)
        action = torch.argmax(action_probs, dim=1).item()

        # Optionally log the auxiliary predictions
        if hasattr(self, '_log_predictions') and self._log_predictions:
            # Volatility prediction
            volatility = torch.softmax(advanced_predictions['volatility'], dim=1)
            volatility_class = torch.argmax(volatility, dim=1).item()
            volatility_labels = ['Very Low', 'Low', 'Medium', 'High', 'Very High']

            # Support/resistance prediction
            sr = torch.softmax(advanced_predictions['support_resistance'], dim=1)
            sr_class = torch.argmax(sr, dim=1).item()
            sr_labels = ['Strong Support', 'Weak Support', 'Neutral', 'Weak Resistance', 'Strong Resistance', 'Breakout']

            # Market regime prediction
            regime = torch.softmax(advanced_predictions['market_regime'], dim=1)
            regime_class = torch.argmax(regime, dim=1).item()
            regime_labels = ['Bull Trend', 'Bear Trend', 'Sideways', 'Volatile Up', 'Volatile Down', 'Accumulation', 'Distribution']

            # Risk assessment
            risk = torch.softmax(advanced_predictions['risk_assessment'], dim=1)
            risk_class = torch.argmax(risk, dim=1).item()
            risk_labels = ['Low Risk', 'Medium Risk', 'High Risk', 'Extreme Risk']

            logger.info("Model predictions:")
            logger.info(f"  Volatility: {volatility_labels[volatility_class]} ({volatility[0, volatility_class]:.3f})")
            logger.info(f"  Support/Resistance: {sr_labels[sr_class]} ({sr[0, sr_class]:.3f})")
            logger.info(f"  Market Regime: {regime_labels[regime_class]} ({regime[0, regime_class]:.3f})")
            logger.info(f"  Risk Level: {risk_labels[risk_class]} ({risk[0, risk_class]:.3f})")

        return action

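    # act() sketch (illustrative; the state shape must match input_shape):
    #
    #   state = np.random.randn(5, 100).astype(np.float32)
    #   action = model.act(state, explore=False)  # greedy action index in [0, n_actions)
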
    def save(self, path):
        """Save model weights and architecture"""
        dir_name = os.path.dirname(path)
        if dir_name:  # os.makedirs('') raises, so only create a directory when one is given
            os.makedirs(dir_name, exist_ok=True)
        torch.save({
            'state_dict': self.state_dict(),
            'input_shape': self.input_shape,
            'n_actions': self.n_actions,
            'feature_dim': self.feature_dim,
            'confidence_threshold': self.confidence_threshold
        }, f"{path}.pt")
        logger.info(f"Enhanced CNN model saved to {path}.pt")

    def load(self, path):
        """Load model weights and architecture"""
        try:
            checkpoint = torch.load(f"{path}.pt", map_location=self.device)
            self.input_shape = checkpoint['input_shape']
            self.n_actions = checkpoint['n_actions']
            self.feature_dim = checkpoint['feature_dim']
            if 'confidence_threshold' in checkpoint:
                self.confidence_threshold = checkpoint['confidence_threshold']
            self._build_network()
            self.load_state_dict(checkpoint['state_dict'])
            self.to(self.device)
            logger.info(f"Enhanced CNN model loaded from {path}.pt")
            return True
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            return False

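# Save/load round-trip sketch (illustrative; "models/enhanced_cnn" is an assumed
# path; ".pt" is appended automatically by save() and load()):
#
#   model.save("models/enhanced_cnn")
#   restored = EnhancedCNN([5, 100], n_actions=3)
#   restored.load("models/enhanced_cnn")
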
# Additional utility for example sifting
class ExampleSiftingDataset:
    """
    Dataset that selectively keeps high-quality examples for training
    to improve model performance
    """
    def __init__(self, max_examples=50000):
        self.examples = []
        self.rewards = []
        self.max_examples = max_examples
        self.min_reward_threshold = -0.05  # Minimum reward to keep an example

    def add_example(self, state, action, reward, next_state, done):
        """Add a new training example with reward-based filtering"""
        # Only keep examples with rewards above the threshold
        if reward > self.min_reward_threshold:
            self.examples.append((state, action, reward, next_state, done))
            self.rewards.append(reward)

            # Sort by reward and keep only the top examples
            if len(self.examples) > self.max_examples:
                # Sort by reward (highest first)
                sorted_indices = np.argsort(self.rewards)[::-1]
                # Keep top examples
                self.examples = [self.examples[i] for i in sorted_indices[:self.max_examples]]
                self.rewards = [self.rewards[i] for i in sorted_indices[:self.max_examples]]

                # Raise the threshold to the lowest reward we kept
                self.min_reward_threshold = min(self.rewards)

    def get_batch(self, batch_size):
        """Get a batch of examples, prioritizing better examples"""
        if not self.examples:
            return None

        # Calculate selection probabilities based on rewards
        all_rewards = np.array(self.rewards)
        # Shift rewards to be positive for probability calculation
        min_reward = all_rewards.min()
        shifted_rewards = all_rewards - min_reward + 0.1  # Small constant so every example keeps nonzero probability
        probs = shifted_rewards / shifted_rewards.sum()

        # Sample batch indices with reward-based probabilities
        indices = np.random.choice(
            len(self.examples),
            size=min(batch_size, len(self.examples)),
            p=probs,
            replace=False
        )

        # Create batch
        batch = [self.examples[i] for i in indices]
        states, actions, rewards, next_states, dones = zip(*batch)

        return {
            'states': np.array(states),
            'actions': np.array(actions),
            'rewards': np.array(rewards),
            'next_states': np.array(next_states),
            'dones': np.array(dones)
        }

    def __len__(self):
        return len(self.examples)
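
# ExampleSiftingDataset sketch (illustrative toy transitions): examples with
# reward <= min_reward_threshold are dropped, and get_batch samples with
# probability proportional to (shifted) reward.
#
#   dataset = ExampleSiftingDataset(max_examples=1000)
#   for _ in range(100):
#       s, s2 = np.random.randn(10), np.random.randn(10)
#       dataset.add_example(s, 0, float(np.random.randn()), s2, False)
#   batch = dataset.get_batch(32)  # dict of arrays, or None if nothing was kept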