new training process and changes to the models (wip)
commit 902593b5f3 (parent a78906a888)
@@ -28,14 +28,14 @@ class DQNAgent:
window_size: int,
num_features: int,
timeframes: List[str],
learning_rate: float = 0.001,
gamma: float = 0.99,
learning_rate: float = 0.0005, # Reduced learning rate for more stability
gamma: float = 0.97, # Slightly reduced discount factor
epsilon: float = 1.0,
epsilon_min: float = 0.01,
epsilon_decay: float = 0.995,
memory_size: int = 10000,
batch_size: int = 64,
target_update: int = 10):
epsilon_min: float = 0.05, # Increased minimum epsilon for more exploration
epsilon_decay: float = 0.9975, # Slower decay rate
memory_size: int = 20000, # Increased memory size
batch_size: int = 128, # Larger batch size
target_update: int = 5): # More frequent target updates

self.state_size = state_size
self.action_size = action_size
@@ -70,23 +70,25 @@ class DQNAgent:
).to(self.device)
self.target_net.load_state_dict(self.policy_net.state_dict())

# Initialize optimizer
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)
# Initialize optimizer with gradient clipping
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate, weight_decay=1e-5)

# Initialize memory
# Initialize memories with different priorities
self.memory = deque(maxlen=memory_size)

# Special memory for extrema samples to use for targeted learning
self.extrema_memory = deque(maxlen=memory_size // 5) # Smaller size for extrema examples
self.extrema_memory = deque(maxlen=memory_size // 4) # For extrema points
self.positive_memory = deque(maxlen=memory_size // 4) # For positive rewards

# Training metrics
self.update_count = 0
self.losses = []
self.avg_reward = 0
self.no_improvement_count = 0
self.best_reward = float('-inf')

def remember(self, state: np.ndarray, action: int, reward: float,
next_state: np.ndarray, done: bool, is_extrema: bool = False):
"""
Store experience in memory
Store experience in memory with prioritization

Args:
state: Current state
@@ -97,28 +99,124 @@ class DQNAgent:
|
||||
is_extrema: Whether this is a local extrema sample (for specialized learning)
|
||||
"""
|
||||
experience = (state, action, reward, next_state, done)
|
||||
|
||||
# Always add to main memory
|
||||
self.memory.append(experience)
|
||||
|
||||
# If this is an extrema sample, also add to specialized memory
|
||||
# Add to specialized memories if applicable
|
||||
if is_extrema:
|
||||
self.extrema_memory.append(experience)
|
||||
|
||||
def act(self, state: np.ndarray) -> int:
|
||||
"""Choose action using epsilon-greedy policy"""
|
||||
if random.random() < self.epsilon:
|
||||
# Store positive experiences separately for prioritized replay
|
||||
if reward > 0:
|
||||
self.positive_memory.append(experience)
|
||||
|
||||
def act(self, state: np.ndarray, explore=True) -> int:
|
||||
"""Choose action using epsilon-greedy policy with explore flag"""
|
||||
if explore and random.random() < self.epsilon:
|
||||
return random.randrange(self.action_size)
|
||||
|
||||
with torch.no_grad():
|
||||
state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
|
||||
action_probs, extrema_pred = self.policy_net(state)
|
||||
# Ensure state is normalized before inference
|
||||
state_tensor = self._normalize_state(state)
|
||||
state_tensor = torch.FloatTensor(state_tensor).unsqueeze(0).to(self.device)
|
||||
action_probs, extrema_pred = self.policy_net(state_tensor)
|
||||
return action_probs.argmax().item()
|
||||
|
||||
def replay(self, use_extrema=False) -> float:
|
||||
def _normalize_state(self, state: np.ndarray) -> np.ndarray:
|
||||
"""Normalize the state data to prevent numerical issues"""
|
||||
# Handle NaN and infinite values
|
||||
state = np.nan_to_num(state, nan=0.0, posinf=1.0, neginf=-1.0)
|
||||
|
||||
# Check if state is 1D array (happens in some environments)
|
||||
if len(state.shape) == 1:
|
||||
# If 1D, we need to normalize the whole array
|
||||
normalized_state = state.copy()
|
||||
|
||||
# Convert any timestamp or non-numeric data to float
|
||||
for i in range(len(normalized_state)):
|
||||
# Check for timestamp-like objects
|
||||
if hasattr(normalized_state[i], 'timestamp') and callable(getattr(normalized_state[i], 'timestamp')):
|
||||
# Convert timestamp to float (seconds since epoch)
|
||||
normalized_state[i] = float(normalized_state[i].timestamp())
|
||||
elif not isinstance(normalized_state[i], (int, float, np.number)):
|
||||
# Set non-numeric data to 0
|
||||
normalized_state[i] = 0.0
|
||||
|
||||
# Ensure all values are float
|
||||
normalized_state = normalized_state.astype(np.float32)
|
||||
|
||||
# Simple min-max normalization for 1D state
|
||||
state_min = np.min(normalized_state)
|
||||
state_max = np.max(normalized_state)
|
||||
if state_max > state_min:
|
||||
normalized_state = (normalized_state - state_min) / (state_max - state_min)
|
||||
return normalized_state
|
||||
|
||||
# Handle 2D arrays
|
||||
normalized_state = np.zeros_like(state, dtype=np.float32)
|
||||
|
||||
# Convert any timestamp or non-numeric data to float
|
||||
for i in range(state.shape[0]):
|
||||
for j in range(state.shape[1]):
|
||||
if hasattr(state[i, j], 'timestamp') and callable(getattr(state[i, j], 'timestamp')):
|
||||
# Convert timestamp to float (seconds since epoch)
|
||||
normalized_state[i, j] = float(state[i, j].timestamp())
|
||||
elif isinstance(state[i, j], (int, float, np.number)):
|
||||
normalized_state[i, j] = state[i, j]
|
||||
else:
|
||||
# Set non-numeric data to 0
|
||||
normalized_state[i, j] = 0.0
|
||||
|
||||
# Loop through each timeframe's features in the combined state
|
||||
feature_count = state.shape[1] // len(self.timeframes)
|
||||
|
||||
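# Illustrative note (hypothetical shapes): with timeframes ['1m', '5m', '15m'] and a
# 2D state of 15 columns, feature_count = 15 // 3 = 5, so columns 0-4, 5-9 and 10-14
# are treated as one per-timeframe OHLCV block by the loop below.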
for tf_idx in range(len(self.timeframes)):
|
||||
start_idx = tf_idx * feature_count
|
||||
end_idx = start_idx + feature_count
|
||||
|
||||
# Extract this timeframe's features
|
||||
tf_features = normalized_state[:, start_idx:end_idx]
|
||||
|
||||
# Normalize OHLCV data by the mean close price in the window
|
||||
# This makes price movements relative rather than absolute
|
||||
price_idx = 3 # Assuming close price is at index 3
|
||||
if price_idx < tf_features.shape[1]:
|
||||
reference_price = np.mean(tf_features[:, price_idx])
|
||||
if reference_price != 0:
|
||||
# Normalize price-related columns (OHLC)
|
||||
for i in range(4): # First 4 columns are OHLC
|
||||
if i < tf_features.shape[1]:
|
||||
normalized_state[:, start_idx + i] = tf_features[:, i] / reference_price
|
||||
|
||||
# Normalize volume using mean and std
|
||||
vol_idx = 4 # Assuming volume is at index 4
|
||||
if vol_idx < tf_features.shape[1]:
|
||||
vol_mean = np.mean(tf_features[:, vol_idx])
|
||||
vol_std = np.std(tf_features[:, vol_idx])
|
||||
if vol_std > 0:
|
||||
normalized_state[:, start_idx + vol_idx] = (tf_features[:, vol_idx] - vol_mean) / vol_std
|
||||
else:
|
||||
normalized_state[:, start_idx + vol_idx] = 0
|
||||
|
||||
# Other features (technical indicators) - normalize with min-max scaling
|
||||
for i in range(5, feature_count):
|
||||
if i < tf_features.shape[1]:
|
||||
feature_min = np.min(tf_features[:, i])
|
||||
feature_max = np.max(tf_features[:, i])
|
||||
if feature_max > feature_min:
|
||||
normalized_state[:, start_idx + i] = (tf_features[:, i] - feature_min) / (feature_max - feature_min)
|
||||
else:
|
||||
normalized_state[:, start_idx + i] = 0
|
||||
|
||||
return normalized_state
|
||||
|
||||
def replay(self, use_prioritized=True) -> float:
|
||||
"""
|
||||
Train on a batch of experiences
|
||||
Train on a batch of experiences with prioritized sampling
|
||||
|
||||
Args:
|
||||
use_extrema: Whether to include extrema samples in training
|
||||
use_prioritized: Whether to use prioritized replay
|
||||
|
||||
Returns:
|
||||
float: Loss value
|
||||
@@ -126,55 +224,67 @@ class DQNAgent:
|
||||
if len(self.memory) < self.batch_size:
|
||||
return 0.0
|
||||
|
||||
# Sample batch - mix regular and extrema samples
|
||||
# Sample batch with prioritization
|
||||
batch = []
|
||||
if use_extrema and len(self.extrema_memory) > self.batch_size // 4:
|
||||
# Get some extrema samples
|
||||
extrema_count = min(self.batch_size // 3, len(self.extrema_memory))
|
||||
extrema_samples = random.sample(list(self.extrema_memory), extrema_count)
|
||||
|
||||
# Get regular samples for the rest
|
||||
regular_count = self.batch_size - extrema_count
|
||||
if use_prioritized and len(self.positive_memory) > 0 and len(self.extrema_memory) > 0:
|
||||
# Prioritized sampling from different memory types
|
||||
positive_count = min(self.batch_size // 4, len(self.positive_memory))
|
||||
extrema_count = min(self.batch_size // 4, len(self.extrema_memory))
|
||||
regular_count = self.batch_size - positive_count - extrema_count
|
||||
|
||||
positive_samples = random.sample(list(self.positive_memory), positive_count)
|
||||
extrema_samples = random.sample(list(self.extrema_memory), extrema_count)
|
||||
regular_samples = random.sample(list(self.memory), regular_count)
|
||||
|
||||
# Combine samples
|
||||
batch = extrema_samples + regular_samples
|
||||
batch = positive_samples + extrema_samples + regular_samples
|
||||
else:
|
||||
# Standard sampling
|
||||
batch = random.sample(self.memory, self.batch_size)
|
||||
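# Sketch of the resulting batch mix, assuming the default batch_size of 128:
# up to 32 samples come from positive_memory, up to 32 from extrema_memory,
# and the remaining 64 (or more, if those pools are smaller) from self.memory.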
|
||||
states, actions, rewards, next_states, dones = zip(*batch)
|
||||
|
||||
# Normalize states before training
|
||||
normalized_states = np.array([self._normalize_state(state) for state in states])
|
||||
normalized_next_states = np.array([self._normalize_state(state) for state in next_states])
|
||||
|
||||
# Convert to tensors and move to device
|
||||
states = torch.FloatTensor(np.array(states)).to(self.device)
|
||||
actions = torch.LongTensor(actions).to(self.device)
|
||||
rewards = torch.FloatTensor(rewards).to(self.device)
|
||||
next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
|
||||
dones = torch.FloatTensor(dones).to(self.device)
|
||||
states_tensor = torch.FloatTensor(normalized_states).to(self.device)
|
||||
actions_tensor = torch.LongTensor(actions).to(self.device)
|
||||
rewards_tensor = torch.FloatTensor(rewards).to(self.device)
|
||||
next_states_tensor = torch.FloatTensor(normalized_next_states).to(self.device)
|
||||
dones_tensor = torch.FloatTensor(dones).to(self.device)
|
||||
|
||||
# Get current Q values
|
||||
current_q_values, extrema_pred = self.policy_net(states)
|
||||
current_q_values = current_q_values.gather(1, actions.unsqueeze(1))
|
||||
current_q_values, extrema_pred = self.policy_net(states_tensor)
|
||||
current_q_values = current_q_values.gather(1, actions_tensor.unsqueeze(1))
|
||||
|
||||
# Get next Q values from target network
|
||||
# Get next Q values from target network (Double DQN approach)
|
||||
with torch.no_grad():
|
||||
next_q_values, _ = self.target_net(next_states)
|
||||
next_q_values = next_q_values.max(1)[0]
|
||||
target_q_values = rewards + (1 - dones) * self.gamma * next_q_values
|
||||
# Get actions from policy network
|
||||
next_actions, _ = self.policy_net(next_states_tensor)
|
||||
next_actions = next_actions.max(1)[1].unsqueeze(1)
|
||||
|
||||
# Compute Q-learning loss
|
||||
q_loss = nn.MSELoss()(current_q_values.squeeze(), target_q_values)
|
||||
# Get Q values from target network for those actions
|
||||
next_q_values, _ = self.target_net(next_states_tensor)
|
||||
next_q_values = next_q_values.gather(1, next_actions).squeeze(1)
|
||||
|
||||
# If we have extrema labels (not in this implementation yet),
|
||||
# we could add an additional loss for extrema prediction
|
||||
# This would require labels for whether each state is near an extrema
|
||||
# Compute target Q values
|
||||
target_q_values = rewards_tensor + (1 - dones_tensor) * self.gamma * next_q_values
|
||||
|
||||
# Total loss is just Q-learning loss for now
|
||||
loss = q_loss
|
||||
# Clamp target values to prevent extreme values
|
||||
target_q_values = torch.clamp(target_q_values, -100, 100)
|
||||
|
||||
# Compute Huber loss (more robust to outliers than MSE)
|
||||
loss = nn.SmoothL1Loss()(current_q_values.squeeze(), target_q_values)
|
||||
|
||||
# Optimize
|
||||
self.optimizer.zero_grad()
|
||||
loss.backward()
|
||||
|
||||
# Apply gradient clipping to prevent exploding gradients
|
||||
nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0)
|
||||
|
||||
self.optimizer.step()
|
||||
|
||||
# Update target network if needed
|
||||
@@ -201,36 +311,76 @@ class DQNAgent:
|
||||
if len(states) == 0:
|
||||
return 0.0
|
||||
|
||||
# Normalize states
|
||||
normalized_states = np.array([self._normalize_state(state) for state in states])
|
||||
normalized_next_states = np.array([self._normalize_state(state) for state in next_states])
|
||||
|
||||
# Convert to tensors
|
||||
states = torch.FloatTensor(np.array(states)).to(self.device)
|
||||
actions = torch.LongTensor(actions).to(self.device)
|
||||
rewards = torch.FloatTensor(rewards).to(self.device)
|
||||
next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
|
||||
dones = torch.FloatTensor(dones).to(self.device)
|
||||
states_tensor = torch.FloatTensor(normalized_states).to(self.device)
|
||||
actions_tensor = torch.LongTensor(actions).to(self.device)
|
||||
rewards_tensor = torch.FloatTensor(rewards).to(self.device)
|
||||
next_states_tensor = torch.FloatTensor(normalized_next_states).to(self.device)
|
||||
dones_tensor = torch.FloatTensor(dones).to(self.device)
|
||||
|
||||
# Forward pass
|
||||
current_q_values, extrema_pred = self.policy_net(states)
|
||||
current_q_values = current_q_values.gather(1, actions.unsqueeze(1))
|
||||
current_q_values, extrema_pred = self.policy_net(states_tensor)
|
||||
current_q_values = current_q_values.gather(1, actions_tensor.unsqueeze(1))
|
||||
|
||||
# Get next Q values
|
||||
# Get next Q values (Double DQN approach)
|
||||
with torch.no_grad():
|
||||
next_q_values, _ = self.target_net(next_states)
|
||||
next_q_values = next_q_values.max(1)[0]
|
||||
target_q_values = rewards + (1 - dones) * self.gamma * next_q_values
|
||||
next_actions, _ = self.policy_net(next_states_tensor)
|
||||
next_actions = next_actions.max(1)[1].unsqueeze(1)
|
||||
|
||||
# Higher weight for extrema training
|
||||
q_loss = nn.MSELoss()(current_q_values.squeeze(), target_q_values)
|
||||
next_q_values, _ = self.target_net(next_states_tensor)
|
||||
next_q_values = next_q_values.gather(1, next_actions).squeeze(1)
|
||||
|
||||
# Full loss is just Q-learning loss
|
||||
target_q_values = rewards_tensor + (1 - dones_tensor) * self.gamma * next_q_values
|
||||
|
||||
# Clamp target values
|
||||
target_q_values = torch.clamp(target_q_values, -100, 100)
|
||||
|
||||
# Use Huber loss for extrema training
|
||||
q_loss = nn.SmoothL1Loss()(current_q_values.squeeze(), target_q_values)
|
||||
|
||||
# Full loss
|
||||
loss = q_loss
|
||||
|
||||
# Optimize
|
||||
self.optimizer.zero_grad()
|
||||
loss.backward()
|
||||
nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0)
|
||||
self.optimizer.step()
|
||||
|
||||
return loss.item()
|
||||
|
||||
def update_learning_metrics(self, episode_reward, best_reward_threshold=0.01):
|
||||
"""Update learning metrics and perform learning rate adjustments if needed"""
|
||||
# Update average reward with exponential moving average
|
||||
if self.avg_reward == 0:
|
||||
self.avg_reward = episode_reward
|
||||
else:
|
||||
self.avg_reward = 0.95 * self.avg_reward + 0.05 * episode_reward
|
||||
|
||||
# Check if we're making sufficient progress
|
||||
if episode_reward > (1 + best_reward_threshold) * self.best_reward:
|
||||
self.best_reward = episode_reward
|
||||
self.no_improvement_count = 0
|
||||
return True # Improved
|
||||
else:
|
||||
self.no_improvement_count += 1
|
||||
|
||||
# If no improvement for a while, adjust learning rate
|
||||
if self.no_improvement_count >= 10:
|
||||
current_lr = self.optimizer.param_groups[0]['lr']
|
||||
new_lr = current_lr * 0.5
|
||||
if new_lr >= 1e-6: # Don't reduce below minimum threshold
|
||||
for param_group in self.optimizer.param_groups:
|
||||
param_group['lr'] = new_lr
|
||||
logger.info(f"Reducing learning rate from {current_lr} to {new_lr}")
|
||||
self.no_improvement_count = 0
|
||||
|
||||
return False # No improvement
|
||||
|
||||
def save(self, path: str):
|
||||
"""Save model and agent state"""
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
@@ -246,9 +396,13 @@ class DQNAgent:
'epsilon': self.epsilon,
'update_count': self.update_count,
'losses': self.losses,
'optimizer_state': self.optimizer.state_dict()
'optimizer_state': self.optimizer.state_dict(),
'best_reward': self.best_reward,
'avg_reward': self.avg_reward
}

torch.save(state, f"{path}_agent_state.pt")
logger.info(f"Agent state saved to {path}_agent_state.pt")

def load(self, path: str):
"""Load model and agent state"""
@@ -259,8 +413,19 @@ class DQNAgent:
self.target_net.load(f"{path}_target")

# Load agent state
state = torch.load(f"{path}_agent_state.pt")
self.epsilon = state['epsilon']
self.update_count = state['update_count']
self.losses = state['losses']
self.optimizer.load_state_dict(state['optimizer_state'])
try:
agent_state = torch.load(f"{path}_agent_state.pt", map_location=self.device)
self.epsilon = agent_state['epsilon']
self.update_count = agent_state['update_count']
self.losses = agent_state['losses']
self.optimizer.load_state_dict(agent_state['optimizer_state'])

# Load additional metrics if they exist
if 'best_reward' in agent_state:
self.best_reward = agent_state['best_reward']
if 'avg_reward' in agent_state:
self.avg_reward = agent_state['avg_reward']

logger.info(f"Agent state loaded from {path}_agent_state.pt")
except FileNotFoundError:
logger.warning(f"Agent state file not found at {path}_agent_state.pt, using default values")
@@ -44,13 +44,46 @@ class PricePatternAttention(nn.Module):
|
||||
|
||||
return output, attn_weights
|
||||
|
||||
class AdaptiveNorm(nn.Module):
|
||||
"""
|
||||
Adaptive normalization layer that chooses between different normalization
|
||||
methods based on input dimensions
|
||||
"""
|
||||
def __init__(self, num_features):
|
||||
super(AdaptiveNorm, self).__init__()
|
||||
self.batch_norm = nn.BatchNorm1d(num_features, affine=True)
|
||||
self.group_norm = nn.GroupNorm(min(32, num_features), num_features)
|
||||
self.layer_norm = nn.LayerNorm([num_features, 1])
|
||||
|
||||
def forward(self, x):
|
||||
# Check input dimensions
|
||||
batch_size, channels, seq_len = x.size()
|
||||
|
||||
# Choose normalization method:
|
||||
# - Batch size > 1 and seq_len > 1: BatchNorm
|
||||
# - Batch size == 1 or seq_len == 1: GroupNorm
|
||||
# - Fallback for extreme cases: LayerNorm
|
||||
if batch_size > 1 and seq_len > 1:
|
||||
return self.batch_norm(x)
|
||||
elif seq_len > 1:
|
||||
return self.group_norm(x)
|
||||
else:
|
||||
# For 1D inputs (seq_len=1), we need to adjust the layer norm
|
||||
# to the actual input size
|
||||
if not hasattr(self, 'layer_norm_1d') or self.layer_norm_1d.normalized_shape[0] != channels:
|
||||
self.layer_norm_1d = nn.LayerNorm([channels, seq_len]).to(x.device)
|
||||
return self.layer_norm_1d(x)
|
||||
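# Illustrative behaviour of AdaptiveNorm.forward for a few hypothetical shapes:
#   x of shape [32, 64, 20] -> BatchNorm1d (batch > 1 and seq_len > 1)
#   x of shape [1, 64, 20]  -> GroupNorm   (seq_len > 1, batch of 1)
#   x of shape [8, 64, 1]   -> LayerNorm created on the fly for [64, 1]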
|
||||
class CNNModelPyTorch(nn.Module):
|
||||
"""
|
||||
CNN model for trading with multiple timeframes
|
||||
"""
|
||||
def __init__(self, window_size, num_features, output_size, timeframes):
|
||||
def __init__(self, window_size=20, num_features=5, output_size=3, timeframes=None):
|
||||
super(CNNModelPyTorch, self).__init__()
|
||||
|
||||
if timeframes is None:
|
||||
timeframes = [1]
|
||||
|
||||
self.window_size = window_size
|
||||
self.num_features = num_features
|
||||
self.output_size = output_size
|
||||
@@ -73,27 +106,28 @@ class CNNModelPyTorch(nn.Module):
|
||||
"""Create all model layers with current feature dimensions"""
|
||||
# Convolutional layers - use total_features as input channels
|
||||
self.conv1 = nn.Conv1d(self.total_features, 64, kernel_size=3, padding=1)
|
||||
self.bn1 = nn.BatchNorm1d(64)
|
||||
self.norm1 = AdaptiveNorm(64)
|
||||
self.dropout1 = nn.Dropout(0.2)
|
||||
|
||||
self.conv2 = nn.Conv1d(64, 128, kernel_size=3, padding=1)
|
||||
self.bn2 = nn.BatchNorm1d(128)
|
||||
self.norm2 = AdaptiveNorm(128)
|
||||
self.dropout2 = nn.Dropout(0.3)
|
||||
|
||||
self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
|
||||
self.bn3 = nn.BatchNorm1d(256)
|
||||
self.norm3 = AdaptiveNorm(256)
|
||||
self.dropout3 = nn.Dropout(0.4)
|
||||
|
||||
# Add price pattern attention layer
|
||||
self.attention = PricePatternAttention(256)
|
||||
|
||||
# Extrema detection specialized convolutional layer
|
||||
self.extrema_conv = nn.Conv1d(256, 128, kernel_size=5, padding=2)
|
||||
self.extrema_bn = nn.BatchNorm1d(128)
|
||||
self.extrema_conv = nn.Conv1d(256, 128, kernel_size=3, padding=1) # Smaller kernel for small inputs
|
||||
self.extrema_norm = AdaptiveNorm(128)
|
||||
|
||||
# Calculate size after convolutions - adjusted for attention output
|
||||
conv_output_size = self.window_size * 256
|
||||
|
||||
# Fully connected layers
|
||||
self.fc1 = nn.Linear(conv_output_size, 512)
|
||||
# Fully connected layers - input size will be determined dynamically
|
||||
self.fc1 = None # Will be initialized in forward pass
|
||||
self.fc2 = nn.Linear(512, 256)
|
||||
self.dropout_fc = nn.Dropout(0.5)
|
||||
|
||||
# Advantage and Value streams (Dueling DQN architecture)
|
||||
self.fc3 = nn.Linear(256, self.output_size) # Advantage stream
|
||||
@@ -131,46 +165,96 @@ class CNNModelPyTorch(nn.Module):
|
||||
# Ensure input is on the correct device
|
||||
x = x.to(self.device)
|
||||
|
||||
# Check and handle if input dimensions don't match model expectations
|
||||
batch_size, window_len, feature_dim = x.size()
|
||||
# Check input dimensions and reshape as needed
|
||||
if len(x.size()) == 2:
|
||||
# If input is [batch_size, features], reshape to [batch_size, features, 1]
|
||||
batch_size, feature_dim = x.size()
|
||||
|
||||
# Check and handle if input features don't match model expectations
|
||||
if feature_dim != self.total_features:
|
||||
logger.warning(f"Input features ({feature_dim}) don't match model features ({self.total_features}), rebuilding layers")
|
||||
self.rebuild_conv_layers(feature_dim)
|
||||
|
||||
# Reshape input: [batch, window_size, features] -> [batch, channels, window_size]
|
||||
x = x.permute(0, 2, 1)
|
||||
# For 1D input, use a sequence length of 1
|
||||
seq_len = 1
|
||||
x = x.unsqueeze(2) # Reshape to [batch, features, 1]
|
||||
elif len(x.size()) == 3:
|
||||
# Standard case: [batch_size, window_size, features]
|
||||
batch_size, seq_len, feature_dim = x.size()
|
||||
|
||||
# Convolutional layers
|
||||
x = F.relu(self.bn1(self.conv1(x)))
|
||||
x = F.relu(self.bn2(self.conv2(x)))
|
||||
x = F.relu(self.bn3(self.conv3(x)))
|
||||
# Check and handle if input dimensions don't match model expectations
|
||||
if feature_dim != self.total_features:
|
||||
logger.warning(f"Input features ({feature_dim}) don't match model features ({self.total_features}), rebuilding layers")
|
||||
self.rebuild_conv_layers(feature_dim)
|
||||
|
||||
# Reshape input: [batch, window_size, features] -> [batch, features, window_size]
|
||||
x = x.permute(0, 2, 1)
|
||||
else:
|
||||
raise ValueError(f"Unexpected input shape: {x.size()}, expected 2D or 3D tensor")
|
||||
|
||||
# Convolutional layers with dropout - safely handle small spatial dimensions
|
||||
try:
|
||||
x = self.dropout1(F.relu(self.norm1(self.conv1(x))))
|
||||
x = self.dropout2(F.relu(self.norm2(self.conv2(x))))
|
||||
x = self.dropout3(F.relu(self.norm3(self.conv3(x))))
|
||||
except Exception as e:
|
||||
logger.warning(f"Error in convolutional layers: {str(e)}")
|
||||
# Fallback for very small inputs: skip some convolutions
|
||||
if seq_len < 3:
|
||||
# Apply a simpler convolution for very small inputs
|
||||
x = F.relu(self.conv1(x))
|
||||
x = F.relu(self.conv2(x))
|
||||
# Skip last conv if we get dimension errors
|
||||
try:
|
||||
x = F.relu(self.conv3(x))
|
||||
except:
|
||||
pass
|
||||
|
||||
# Store conv features for extrema detection
|
||||
conv_features = x
|
||||
|
||||
# Reshape for attention: [batch, channels, window_size] -> [batch, window_size, channels]
|
||||
# Get the current shape after convolutions
|
||||
_, channels, conv_seq_len = x.size()
|
||||
|
||||
# Initialize fc1 if not created yet or if the shape has changed
|
||||
if self.fc1 is None:
|
||||
flattened_size = channels * conv_seq_len
|
||||
logger.info(f"Initializing fc1 with input size {flattened_size}")
|
||||
self.fc1 = nn.Linear(flattened_size, 512).to(self.device)
|
||||
|
||||
# Apply extrema detection safely
|
||||
try:
|
||||
extrema_features = F.relu(self.extrema_norm(self.extrema_conv(conv_features)))
|
||||
except Exception as e:
|
||||
logger.warning(f"Error in extrema detection: {str(e)}")
|
||||
extrema_features = conv_features # Fallback
|
||||
|
||||
# Handle attention for small sequence lengths
|
||||
if conv_seq_len > 1:
|
||||
# Reshape for attention: [batch, channels, seq_len] -> [batch, seq_len, channels]
|
||||
x_attention = x.permute(0, 2, 1)
|
||||
|
||||
# Apply attention
|
||||
try:
|
||||
attention_output, attention_weights = self.attention(x_attention)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error in attention layer: {str(e)}")
|
||||
# Fallback: don't use attention
|
||||
|
||||
# We'll use attention directly without the residual connection
|
||||
# to avoid dimension mismatch issues
|
||||
attention_reshaped = attention_output.permute(0, 2, 1) # [batch, channels, window_size]
|
||||
# Flatten - get the actual shape for this batch
|
||||
flattened_size = channels * conv_seq_len
|
||||
x = x.view(batch_size, flattened_size)
|
||||
|
||||
# Apply extrema detection specialized layer
|
||||
extrema_features = F.relu(self.extrema_bn(self.extrema_conv(conv_features)))
|
||||
# Check if we need to recreate fc1 with the correct size
|
||||
if self.fc1.in_features != flattened_size:
|
||||
logger.info(f"Recreating fc1 layer to match input size {flattened_size}")
|
||||
self.fc1 = nn.Linear(flattened_size, 512).to(self.device)
|
||||
# Reinitialize optimizer after changing the model
|
||||
self.optimizer = optim.Adam(self.parameters(), lr=0.001)
|
||||
|
||||
# Use attention features directly instead of residual connection
|
||||
# to avoid dimension mismatches
|
||||
x = conv_features # Just use the convolutional features
|
||||
|
||||
# Flatten
|
||||
x = x.view(batch_size, -1)
|
||||
|
||||
# Fully connected layers
|
||||
# Fully connected layers with dropout
|
||||
x = F.relu(self.fc1(x))
|
||||
x = F.relu(self.fc2(x))
|
||||
x = self.dropout_fc(F.relu(self.fc2(x)))
|
||||
|
||||
# Split into advantage and value streams
|
||||
advantage = self.fc3(x)
|
||||
|
@@ -3,6 +3,10 @@ import gym
from gym import spaces
from typing import Dict, Tuple, List
import pandas as pd
import logging

# Configure logger
logger = logging.getLogger(__name__)

class TradingEnvironment(gym.Env):
"""
@@ -12,97 +16,284 @@ class TradingEnvironment(gym.Env):
|
||||
data: pd.DataFrame,
|
||||
initial_balance: float = 100.0,
|
||||
fee_rate: float = 0.0002,
|
||||
max_steps: int = 1000):
|
||||
max_steps: int = 1000,
|
||||
window_size: int = 20,
|
||||
risk_aversion: float = 0.2, # Controls how much to penalize volatility
|
||||
price_scaling: str = 'zscore', # 'zscore', 'minmax', or 'raw'
|
||||
reward_scaling: float = 10.0, # Scale factor for rewards
|
||||
episode_penalty: float = 0.1): # Penalty for active positions at end of episode
|
||||
super(TradingEnvironment, self).__init__()
|
||||
|
||||
self.data = data
|
||||
self.initial_balance = initial_balance
|
||||
self.fee_rate = fee_rate
|
||||
self.max_steps = max_steps
|
||||
self.window_size = window_size
|
||||
self.risk_aversion = risk_aversion
|
||||
self.price_scaling = price_scaling
|
||||
self.reward_scaling = reward_scaling
|
||||
self.episode_penalty = episode_penalty
|
||||
|
||||
# Action space: 0 (SELL), 1 (HOLD), 2 (BUY)
|
||||
# Preprocess data if needed
|
||||
self._preprocess_data()
|
||||
|
||||
# Action space: 0 (BUY), 1 (SELL), 2 (HOLD)
|
||||
self.action_space = spaces.Discrete(3)
|
||||
|
||||
# Observation space: price data, technical indicators, and account state
|
||||
feature_dim = self.data.shape[1] + 3 # Adding position, equity, unrealized_pnl
|
||||
self.observation_space = spaces.Box(
|
||||
low=-np.inf,
|
||||
high=np.inf,
|
||||
shape=(data.shape[1],), # Number of features
|
||||
shape=(feature_dim,),
|
||||
dtype=np.float32
|
||||
)
|
||||
|
||||
# Initialize state
|
||||
self.reset()
|
||||
|
||||
def _preprocess_data(self):
|
||||
"""Preprocess data - normalize or standardize features"""
|
||||
# Store the original data for reference
|
||||
self.original_data = self.data.copy()
|
||||
|
||||
# Normalize price data based on the selected method
|
||||
if self.price_scaling == 'zscore':
|
||||
# For each feature, apply z-score normalization
|
||||
for col in self.data.columns:
|
||||
if col in ['open', 'high', 'low', 'close']:
|
||||
mean = self.data[col].mean()
|
||||
std = self.data[col].std()
|
||||
if std > 0:
|
||||
self.data[col] = (self.data[col] - mean) / std
|
||||
# Normalize volume separately
|
||||
elif col == 'volume':
|
||||
mean = self.data[col].mean()
|
||||
std = self.data[col].std()
|
||||
if std > 0:
|
||||
self.data[col] = (self.data[col] - mean) / std
|
||||
|
||||
elif self.price_scaling == 'minmax':
|
||||
# For each feature, apply min-max scaling
|
||||
for col in self.data.columns:
|
||||
min_val = self.data[col].min()
|
||||
max_val = self.data[col].max()
|
||||
if max_val > min_val:
|
||||
self.data[col] = (self.data[col] - min_val) / (max_val - min_val)
|
||||
|
||||
def reset(self) -> np.ndarray:
|
||||
"""Reset the environment to initial state"""
|
||||
self.current_step = 0
|
||||
self.current_step = self.window_size
|
||||
self.balance = self.initial_balance
|
||||
self.position = 0 # 0: no position, 1: long position
|
||||
self.position = 0 # 0: no position, 1: long position, -1: short position
|
||||
self.entry_price = 0
|
||||
self.entry_time = 0
|
||||
self.total_trades = 0
|
||||
self.winning_trades = 0
|
||||
self.losing_trades = 0
|
||||
self.total_pnl = 0
|
||||
self.balance_history = [self.initial_balance]
|
||||
self.equity_history = [self.initial_balance]
|
||||
self.max_balance = self.initial_balance
|
||||
self.max_drawdown = 0
|
||||
|
||||
# Trading performance metrics
|
||||
self.trade_durations = [] # Track how long trades are held
|
||||
self.returns = [] # Track returns of each trade
|
||||
|
||||
# For analyzing trade clustering
|
||||
self.last_action_time = 0
|
||||
self.actions_taken = []
|
||||
|
||||
return self._get_observation()
|
||||
|
||||
def _get_observation(self) -> np.ndarray:
|
||||
"""Get current observation state"""
|
||||
return self.data.iloc[self.current_step].values
|
||||
"""Get current observation state with account information"""
|
||||
# Get market data for the current step
|
||||
market_data = self.data.iloc[self.current_step].values
|
||||
|
||||
# Get current price
|
||||
current_price = self.original_data.iloc[self.current_step]['close']
|
||||
|
||||
# Calculate unrealized PnL
|
||||
unrealized_pnl = 0
|
||||
if self.position != 0:
|
||||
price_diff = current_price - self.entry_price
|
||||
unrealized_pnl = self.position * price_diff
|
||||
|
||||
# Calculate total equity (balance + unrealized PnL)
|
||||
equity = self.balance + unrealized_pnl
|
||||
|
||||
# Normalize account state
|
||||
normalized_position = self.position # -1, 0, or 1
|
||||
normalized_equity = equity / self.initial_balance - 1.0 # Percent change from initial
|
||||
normalized_unrealized_pnl = unrealized_pnl / self.initial_balance if self.initial_balance > 0 else 0
|
||||
|
||||
# Combine market data with account state
|
||||
account_state = np.array([normalized_position, normalized_equity, normalized_unrealized_pnl])
|
||||
observation = np.concatenate([market_data, account_state])
|
||||
|
||||
# Handle any NaN values
|
||||
observation = np.nan_to_num(observation, nan=0.0)
|
||||
|
||||
return observation
|
||||
|
||||
def _calculate_reward(self, action: int) -> float:
|
||||
"""Calculate reward based on action and outcome"""
|
||||
current_price = self.data.iloc[self.current_step]['close']
|
||||
"""
|
||||
Calculate reward based on action and outcome with improved risk-adjusted metrics
|
||||
|
||||
# If we have an open position
|
||||
if self.position != 0:
|
||||
# Calculate PnL
|
||||
pnl = self.position * (current_price - self.entry_price) / self.entry_price
|
||||
fees = self.fee_rate * 2 # Entry and exit fees
|
||||
Args:
|
||||
action: The action taken (0=BUY, 1=SELL, 2=HOLD)
|
||||
|
||||
# Close position
|
||||
if (action == 0 and self.position > 0) or (action == 2 and self.position < 0):
|
||||
net_pnl = pnl - fees
|
||||
self.total_pnl += net_pnl
|
||||
self.balance *= (1 + net_pnl)
|
||||
Returns:
|
||||
float: Calculated reward value
|
||||
"""
|
||||
# Get current price and next price
|
||||
current_price = self.original_data.iloc[self.current_step]['close']
|
||||
|
||||
# Default reward is slightly negative to discourage excessive trading
|
||||
reward = -0.0001
|
||||
pnl = 0.0
|
||||
|
||||
# Handle different actions based on current position
|
||||
if self.position == 0: # No position
|
||||
if action == 0: # BUY
|
||||
self.position = 1
|
||||
self.entry_price = current_price
|
||||
self.entry_time = self.current_step
|
||||
reward = -self.fee_rate # Small penalty for trading cost
|
||||
|
||||
elif action == 1: # SELL (start short position)
|
||||
self.position = -1
|
||||
self.entry_price = current_price
|
||||
self.entry_time = self.current_step
|
||||
reward = -self.fee_rate # Small penalty for trading cost
|
||||
|
||||
# else action == 2 (HOLD) - keep the small negative reward
|
||||
|
||||
elif self.position > 0: # Long position
|
||||
if action == 1: # SELL (close long)
|
||||
# Calculate profit/loss
|
||||
price_diff = current_price - self.entry_price
|
||||
pnl = price_diff / self.entry_price - 2 * self.fee_rate # Account for entry and exit fees
|
||||
|
||||
# Adjust reward based on PnL and risk
|
||||
reward = pnl * self.reward_scaling
|
||||
|
||||
# Track trade performance
|
||||
self.total_trades += 1
|
||||
if pnl > 0:
|
||||
self.winning_trades += 1
|
||||
else:
|
||||
self.losing_trades += 1
|
||||
|
||||
# Calculate trade duration
|
||||
trade_duration = self.current_step - self.entry_time
|
||||
self.trade_durations.append(trade_duration)
|
||||
|
||||
# Update returns list
|
||||
self.returns.append(pnl)
|
||||
|
||||
# Update balance and reset position
|
||||
self.balance *= (1 + pnl)
|
||||
self.balance_history.append(self.balance)
|
||||
self.max_balance = max(self.max_balance, self.balance)
|
||||
self.total_pnl += pnl
|
||||
|
||||
self.total_trades += 1
|
||||
if net_pnl > 0:
|
||||
self.winning_trades += 1
|
||||
|
||||
# Reward based on PnL
|
||||
reward = net_pnl * 100 # Scale up for better learning
|
||||
|
||||
# Additional reward for win rate
|
||||
win_rate = self.winning_trades / max(1, self.total_trades)
|
||||
reward += win_rate * 0.1
|
||||
|
||||
# Reset position
|
||||
self.position = 0
|
||||
return reward
|
||||
|
||||
# Hold position
|
||||
return pnl * 0.1 # Small reward for holding profitable positions
|
||||
elif action == 0: # BUY (while already long)
|
||||
# Penalize trying to increase an already active position
|
||||
reward = -0.001
|
||||
|
||||
# No position
|
||||
if action == 1: # HOLD
|
||||
return 0
|
||||
# else action == 2 (HOLD) - calculate unrealized P&L for reward
|
||||
else:
|
||||
price_diff = current_price - self.entry_price
|
||||
unrealized_pnl = price_diff / self.entry_price
|
||||
|
||||
# Open new position
|
||||
if action in [0, 2]: # SELL or BUY
|
||||
self.position = -1 if action == 0 else 1
|
||||
self.entry_price = current_price
|
||||
return -self.fee_rate # Small penalty for trading
|
||||
# Small reward/penalty based on unrealized P&L
|
||||
reward = unrealized_pnl * 0.05 # Scale down to encourage holding good positions
|
||||
|
||||
return 0
|
||||
elif self.position < 0: # Short position
|
||||
if action == 0: # BUY (close short)
|
||||
# Calculate profit/loss
|
||||
price_diff = self.entry_price - current_price
|
||||
pnl = price_diff / self.entry_price - 2 * self.fee_rate # Account for entry and exit fees
|
||||
|
||||
# Adjust reward based on PnL and risk
|
||||
reward = pnl * self.reward_scaling
|
||||
|
||||
# Track trade performance
|
||||
self.total_trades += 1
|
||||
if pnl > 0:
|
||||
self.winning_trades += 1
|
||||
else:
|
||||
self.losing_trades += 1
|
||||
|
||||
# Calculate trade duration
|
||||
trade_duration = self.current_step - self.entry_time
|
||||
self.trade_durations.append(trade_duration)
|
||||
|
||||
# Update returns list
|
||||
self.returns.append(pnl)
|
||||
|
||||
# Update balance and reset position
|
||||
self.balance *= (1 + pnl)
|
||||
self.balance_history.append(self.balance)
|
||||
self.max_balance = max(self.max_balance, self.balance)
|
||||
self.total_pnl += pnl
|
||||
|
||||
# Reset position
|
||||
self.position = 0
|
||||
|
||||
elif action == 1: # SELL (while already short)
|
||||
# Penalize trying to increase an already active position
|
||||
reward = -0.001
|
||||
|
||||
# else action == 2 (HOLD) - calculate unrealized P&L for reward
|
||||
else:
|
||||
price_diff = self.entry_price - current_price
|
||||
unrealized_pnl = price_diff / self.entry_price
|
||||
|
||||
# Small reward/penalty based on unrealized P&L
|
||||
reward = unrealized_pnl * 0.05 # Scale down to encourage holding good positions
|
||||
|
||||
# Record the action
|
||||
self.actions_taken.append(action)
|
||||
self.last_action_time = self.current_step
|
||||
|
||||
# Update equity history (balance + unrealized P&L)
|
||||
current_equity = self.balance
|
||||
if self.position != 0:
|
||||
# Calculate unrealized P&L
|
||||
if self.position > 0: # Long
|
||||
price_diff = current_price - self.entry_price
|
||||
unrealized_pnl = price_diff / self.entry_price * self.balance
|
||||
else: # Short
|
||||
price_diff = self.entry_price - current_price
|
||||
unrealized_pnl = price_diff / self.entry_price * self.balance
|
||||
|
||||
current_equity = self.balance + unrealized_pnl
|
||||
|
||||
self.equity_history.append(current_equity)
|
||||
|
||||
# Calculate current drawdown
|
||||
peak_equity = max(self.equity_history)
|
||||
current_drawdown = (peak_equity - current_equity) / peak_equity if peak_equity > 0 else 0
|
||||
self.max_drawdown = max(self.max_drawdown, current_drawdown)
|
||||
|
||||
# Apply risk aversion factor - penalize volatility
|
||||
if len(self.returns) > 1:
|
||||
returns_std = np.std(self.returns)
|
||||
reward -= returns_std * self.risk_aversion
|
||||
|
||||
return reward, pnl
|
||||
|
||||
def step(self, action: int) -> Tuple[np.ndarray, float, bool, Dict]:
|
||||
"""Execute one step in the environment"""
|
||||
# Calculate reward
|
||||
reward = self._calculate_reward(action)
|
||||
# Calculate reward and update state
|
||||
reward, pnl = self._calculate_reward(action)
|
||||
|
||||
# Move to next step
|
||||
self.current_step += 1
|
||||
@@ -110,26 +301,70 @@ class TradingEnvironment(gym.Env):
|
||||
# Check if episode is done
|
||||
done = self.current_step >= min(self.max_steps - 1, len(self.data) - 1)
|
||||
|
||||
# Apply penalty if episode ends with open position
|
||||
if done and self.position != 0:
|
||||
reward -= self.episode_penalty
|
||||
|
||||
# Force close the position at the end if still open
|
||||
current_price = self.original_data.iloc[self.current_step]['close']
|
||||
if self.position > 0: # Long position
|
||||
price_diff = current_price - self.entry_price
|
||||
pnl = price_diff / self.entry_price - 2 * self.fee_rate
|
||||
else: # Short position
|
||||
price_diff = self.entry_price - current_price
|
||||
pnl = price_diff / self.entry_price - 2 * self.fee_rate
|
||||
|
||||
# Update balance
|
||||
self.balance *= (1 + pnl)
|
||||
self.total_pnl += pnl
|
||||
|
||||
# Track trade
|
||||
self.total_trades += 1
|
||||
if pnl > 0:
|
||||
self.winning_trades += 1
|
||||
else:
|
||||
self.losing_trades += 1
|
||||
|
||||
# Reset position
|
||||
self.position = 0
|
||||
|
||||
# Get next observation
|
||||
observation = self._get_observation()
|
||||
|
||||
# Calculate max drawdown
|
||||
max_drawdown = 0
|
||||
if len(self.balance_history) > 1:
|
||||
peak = self.balance_history[0]
|
||||
for balance in self.balance_history:
|
||||
peak = max(peak, balance)
|
||||
drawdown = (peak - balance) / peak
|
||||
max_drawdown = max(max_drawdown, drawdown)
|
||||
# Calculate sharpe ratio and sortino ratio if possible
|
||||
sharpe_ratio = 0
|
||||
sortino_ratio = 0
|
||||
win_rate = self.winning_trades / max(1, self.total_trades)
|
||||
|
||||
if len(self.returns) > 1:
|
||||
mean_return = np.mean(self.returns)
|
||||
std_return = np.std(self.returns)
|
||||
if std_return > 0:
|
||||
sharpe_ratio = mean_return / std_return
|
||||
|
||||
# For sortino, we only consider downside deviation
|
||||
downside_returns = [r for r in self.returns if r < 0]
|
||||
if downside_returns:
|
||||
downside_deviation = np.std(downside_returns)
|
||||
if downside_deviation > 0:
|
||||
sortino_ratio = mean_return / downside_deviation
|
||||
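# Worked example (illustrative numbers): returns = [0.02, -0.01, 0.03] gives
# mean ~0.0133 and population std ~0.017, so sharpe_ratio ~0.78; the only
# downside return is -0.01, whose std is 0, so sortino_ratio stays 0 here.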
|
||||
# Calculate average trade duration
|
||||
avg_trade_duration = np.mean(self.trade_durations) if self.trade_durations else 0
|
||||
|
||||
# Additional info
|
||||
info = {
|
||||
'balance': self.balance,
|
||||
'position': self.position,
|
||||
'total_trades': self.total_trades,
|
||||
'win_rate': self.winning_trades / max(1, self.total_trades),
|
||||
'win_rate': win_rate,
|
||||
'total_pnl': self.total_pnl,
|
||||
'max_drawdown': max_drawdown
|
||||
'max_drawdown': self.max_drawdown,
|
||||
'sharpe_ratio': sharpe_ratio,
|
||||
'sortino_ratio': sortino_ratio,
|
||||
'avg_trade_duration': avg_trade_duration,
|
||||
'pnl': pnl,
|
||||
'gain': (self.balance - self.initial_balance) / self.initial_balance
|
||||
}
|
||||
|
||||
return observation, reward, done, info
|
||||
@@ -143,20 +378,19 @@ class TradingEnvironment(gym.Env):
print(f"Total Trades: {self.total_trades}")
print(f"Win Rate: {self.winning_trades/max(1, self.total_trades):.2%}")
print(f"Total PnL: ${self.total_pnl:.2f}")
print(f"Max Drawdown: {self._calculate_max_drawdown():.2%}")
print(f"Max Drawdown: {self.max_drawdown:.2%}")
print(f"Sharpe Ratio: {self._calculate_sharpe_ratio():.4f}")
print("-" * 50)

def _calculate_max_drawdown(self):
"""Calculate maximum drawdown from balance history"""
if len(self.balance_history) <= 1:
def _calculate_sharpe_ratio(self):
"""Calculate Sharpe ratio from returns"""
if len(self.returns) < 2:
return 0.0

peak = self.balance_history[0]
max_drawdown = 0.0
mean_return = np.mean(self.returns)
std_return = np.std(self.returns)

for balance in self.balance_history:
peak = max(peak, balance)
drawdown = (peak - balance) / peak
max_drawdown = max(max_drawdown, drawdown)
if std_return == 0:
return 0.0

return max_drawdown
return mean_return / std_return
realtime.py (3255): file diff suppressed because it is too large
realtime_old.py (3083, new file): file diff suppressed because it is too large
train_improved_rl.py (547, new file):
@@ -0,0 +1,547 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Improved RL Trading with Enhanced Training and Monitoring
|
||||
|
||||
This script provides an improved version of the RL training process,
|
||||
implementing better normalization, reward structure, and model training.
|
||||
"""
|
||||
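# Example invocation (a sketch; flags as defined by the argparse section below,
# values shown are the script defaults and are illustrative only):
#   python train_improved_rl.py --episodes 20 --symbol ETH/USDT --visualize --save-path NN/models/saved/improved_dqn_agent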
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import argparse
|
||||
import time
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
import torch
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
from pathlib import Path
|
||||
|
||||
# Add project directory to path if needed
|
||||
project_root = os.path.dirname(os.path.abspath(__file__))
|
||||
if project_root not in sys.path:
|
||||
sys.path.append(project_root)
|
||||
|
||||
# Import our custom modules
|
||||
from NN.models.dqn_agent import DQNAgent
|
||||
from NN.utils.trading_env import TradingEnvironment
|
||||
from NN.utils.data_interface import DataInterface
|
||||
from realtime import BinanceHistoricalData, RealTimeChart
|
||||
|
||||
# Configure logging
|
||||
log_filename = f'improved_rl_training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler(log_filename),
|
||||
logging.StreamHandler()
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger('improved_rl')
|
||||
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser(description='Improved RL Trading with Enhanced Training')
|
||||
parser.add_argument('--episodes', type=int, default=20, help='Number of episodes to train')
|
||||
parser.add_argument('--visualize', action='store_true', help='Visualize trades during training')
|
||||
parser.add_argument('--save-path', type=str, default='NN/models/saved/improved_dqn_agent', help='Path to save trained model')
|
||||
parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading pair symbol')
|
||||
args = parser.parse_args()
|
||||
|
||||
def create_training_environment(symbol, window_size=20):
|
||||
"""Create and prepare the training environment with data"""
|
||||
logger.info(f"Setting up training environment for {symbol}")
|
||||
|
||||
# Fetch historical data from multiple timeframes
|
||||
data_interface = DataInterface(symbol)
|
||||
|
||||
# Use Binance data provider for fetching data
|
||||
historical_data = BinanceHistoricalData()
|
||||
|
||||
# Fetch data for each timeframe
|
||||
df_1m = historical_data.get_historical_candles(symbol, interval_seconds=60, limit=1000)
|
||||
df_5m = historical_data.get_historical_candles(symbol, interval_seconds=300, limit=1000)
|
||||
df_15m = historical_data.get_historical_candles(symbol, interval_seconds=900, limit=500)
|
||||
|
||||
# Ensure all dataframes have index as timestamp type
|
||||
if df_1m is not None and not df_1m.empty:
|
||||
if 'timestamp' in df_1m.columns:
|
||||
df_1m = df_1m.set_index('timestamp')
|
||||
|
||||
if df_5m is not None and not df_5m.empty:
|
||||
if 'timestamp' in df_5m.columns:
|
||||
df_5m = df_5m.set_index('timestamp')
|
||||
|
||||
if df_15m is not None and not df_15m.empty:
|
||||
if 'timestamp' in df_15m.columns:
|
||||
df_15m = df_15m.set_index('timestamp')
|
||||
|
||||
# Preprocess data (add technical indicators)
|
||||
df_1m = preprocess_dataframe(df_1m)
|
||||
df_5m = preprocess_dataframe(df_5m)
|
||||
df_15m = preprocess_dataframe(df_15m)
|
||||
|
||||
# Create environment with all timeframes
|
||||
env = create_multi_timeframe_env(df_1m, df_5m, df_15m, window_size)
|
||||
|
||||
return env, (df_1m, df_5m, df_15m)
|
||||
|
||||
def preprocess_dataframe(df):
|
||||
"""Add technical indicators and preprocess dataframe"""
|
||||
if df is None or df.empty:
|
||||
return None
|
||||
|
||||
# Drop any missing values
|
||||
df = df.dropna()
|
||||
|
||||
# Ensure it has OHLCV columns
|
||||
required_columns = ['open', 'high', 'low', 'close', 'volume']
|
||||
missing_columns = [col for col in required_columns if col not in df.columns]
|
||||
|
||||
if missing_columns:
|
||||
logger.warning(f"Missing required columns: {missing_columns}")
|
||||
for col in missing_columns:
|
||||
# Fill with close price for OHLC if missing
|
||||
if col in ['open', 'high', 'low'] and 'close' in df.columns:
|
||||
df[col] = df['close']
|
||||
# Fill with zeros for volume if missing
|
||||
elif col == 'volume':
|
||||
df[col] = 0
|
||||
|
||||
# Add simple technical indicators
|
||||
# 1. Simple Moving Averages
|
||||
df['sma_5'] = df['close'].rolling(window=5).mean()
|
||||
df['sma_10'] = df['close'].rolling(window=10).mean()
|
||||
|
||||
# 2. Relative Strength Index (RSI)
|
||||
delta = df['close'].diff()
|
||||
gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
|
||||
loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
|
||||
rs = gain / loss
|
||||
df['rsi'] = 100 - (100 / (1 + rs))
|
||||
|
||||
# 3. Bollinger Bands
|
||||
df['bb_middle'] = df['close'].rolling(window=20).mean()
|
||||
df['bb_std'] = df['close'].rolling(window=20).std()
|
||||
df['bb_upper'] = df['bb_middle'] + 2 * df['bb_std']
|
||||
df['bb_lower'] = df['bb_middle'] - 2 * df['bb_std']
|
||||
|
||||
# 4. MACD
|
||||
df['ema_12'] = df['close'].ewm(span=12, adjust=False).mean()
|
||||
df['ema_26'] = df['close'].ewm(span=26, adjust=False).mean()
|
||||
df['macd'] = df['ema_12'] - df['ema_26']
|
||||
df['macd_signal'] = df['macd'].ewm(span=9, adjust=False).mean()
|
||||
|
||||
# 5. Price rate of change
|
||||
df['roc'] = df['close'].pct_change(periods=10) * 100
|
||||
|
||||
# Fill any remaining NaN values with 0
|
||||
df = df.fillna(0)
|
||||
|
||||
return df
|
||||
|
||||
def create_multi_timeframe_env(df_1m, df_5m, df_15m, window_size=20):
|
||||
"""Create a custom environment that handles multiple timeframes"""
|
||||
|
||||
# Ensure we have complete data for all timeframes
|
||||
min_required_length = window_size + 100 # Add buffer for training
|
||||
|
||||
if (df_1m is None or len(df_1m) < min_required_length or
|
||||
df_5m is None or len(df_5m) < min_required_length or
|
||||
df_15m is None or len(df_15m) < min_required_length):
|
||||
raise ValueError(f"Insufficient data for training. Need at least {min_required_length} candles per timeframe.")
|
||||
|
||||
# Ensure we only use the last N valid data points
|
||||
df_1m = df_1m.iloc[-900:].copy() if len(df_1m) > 900 else df_1m.copy()
|
||||
df_5m = df_5m.iloc[-180:].copy() if len(df_5m) > 180 else df_5m.copy()
|
||||
df_15m = df_15m.iloc[-60:].copy() if len(df_15m) > 60 else df_15m.copy()
|
||||
|
||||
# Reset index to make sure we have continuous integers
|
||||
df_1m = df_1m.reset_index(drop=True)
|
||||
df_5m = df_5m.reset_index(drop=True)
|
||||
df_15m = df_15m.reset_index(drop=True)
|
||||
|
||||
# For simplicity, we'll use the 1m data as the base environment
|
||||
# The other timeframes will be incorporated through observation
|
||||
|
||||
env = TradingEnvironment(
|
||||
data=df_1m,
|
||||
initial_balance=100.0,
|
||||
fee_rate=0.0005, # 0.05% fee (typical for crypto exchanges)
|
||||
max_steps=len(df_1m) - window_size - 50, # Leave some room at the end
|
||||
window_size=window_size,
|
||||
risk_aversion=0.2, # Moderately risk-averse
|
||||
price_scaling='zscore', # Use z-score normalization
|
||||
reward_scaling=10.0, # Scale rewards for better learning
|
||||
episode_penalty=0.2 # Penalty for holding positions at end of episode
|
||||
)
|
||||
|
||||
return env
|
||||
|
||||
def initialize_agent(env, window_size=20, num_features=0, timeframes=None):
|
||||
"""Initialize the DQN agent with appropriate parameters"""
|
||||
if timeframes is None:
|
||||
timeframes = ['1m', '5m', '15m']
|
||||
|
||||
# Calculate input dimensions
|
||||
state_dim = env.observation_space.shape[0]
|
||||
action_dim = env.action_space.n
|
||||
|
||||
# If num_features wasn't provided, infer from environment
|
||||
if num_features == 0:
|
||||
# Calculate features per timeframe from state dimension and number of timeframes
|
||||
# Accounting for the 3 additional features (position, equity, unrealized_pnl)
|
||||
num_features = (state_dim - 3) // len(timeframes)
|
||||
|
||||
logger.info(f"Initializing DQN agent: state_dim={state_dim}, action_dim={action_dim}, features={num_features}")
|
||||
|
||||
agent = DQNAgent(
|
||||
state_size=state_dim,
|
||||
action_size=action_dim,
|
||||
window_size=window_size,
|
||||
num_features=num_features,
|
||||
timeframes=timeframes,
|
||||
learning_rate=0.0005, # Start with a moderate learning rate
|
||||
gamma=0.97, # Slightly reduced discount factor for stable learning
|
||||
epsilon=1.0, # Start with full exploration
|
||||
epsilon_min=0.05, # Maintain some exploration even at the end
|
||||
epsilon_decay=0.9975, # Slower decay for more exploration
|
||||
memory_size=20000, # Larger replay buffer
|
||||
batch_size=128, # Larger batch size for more stable gradients
|
||||
target_update=5 # More frequent target network updates
|
||||
)
|
||||
|
||||
return agent
|
||||
|
||||
def train_agent(env, agent, num_episodes=20, visualize=False, chart=None, save_path=None, save_freq=5):
|
||||
"""
|
||||
Train the DQN agent with improved training loop
|
||||
|
||||
Args:
|
||||
env: The trading environment
|
||||
agent: The DQN agent
|
||||
num_episodes: Number of episodes to train
|
||||
visualize: Whether to visualize trades during training
|
||||
chart: The visualization chart (if visualize=True)
|
||||
save_path: Path to save the model
|
||||
save_freq: How often to save checkpoints (in episodes)
|
||||
|
||||
Returns:
|
||||
tuple: (rewards, wins, losses, best_reward)
|
||||
"""
|
||||
logger.info(f"Starting training for {num_episodes} episodes")
|
||||
|
||||
# Initialize metrics tracking
|
||||
rewards = []
|
||||
win_rates = []
|
||||
total_train_time = 0
|
||||
best_reward = float('-inf')
|
||||
best_model_path = None
|
||||
|
||||
# Create directory for checkpoints if needed
|
||||
if save_path:
|
||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||
checkpoint_dir = os.path.join(os.path.dirname(save_path), 'checkpoints')
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
|
||||
# For tracking improvement
|
||||
last_improved_episode = 0
|
||||
patience = 10 # Episodes to wait for improvement before early stopping
|
||||
|
||||
for episode in range(num_episodes):
|
||||
start_time = time.time()
|
||||
|
||||
# Reset environment and get initial state
|
||||
state = env.reset()
|
||||
done = False
|
||||
episode_reward = 0
|
||||
step = 0
|
||||
|
||||
# Action metrics for this episode
|
||||
actions_taken = {0: 0, 1: 0, 2: 0} # Track BUY, SELL, HOLD actions
|
||||
|
||||
while not done:
|
||||
# Select action
|
||||
action = agent.act(state)
|
||||
|
||||
# Execute action
|
||||
next_state, reward, done, info = env.step(action)
|
||||
|
||||
# Store experience in replay buffer
|
||||
is_extrema = False # In a real implementation, detect extrema points
|
||||
agent.remember(state, action, reward, next_state, done, is_extrema)
|
||||
|
||||
# Learn from experience
|
||||
if len(agent.memory) >= agent.batch_size:
|
||||
use_prioritized = episode > 1 # Start using prioritized replay after first episode
|
||||
loss = agent.replay(use_prioritized=use_prioritized)
|
||||
|
||||
# Update state and metrics
|
||||
state = next_state
|
||||
episode_reward += reward
|
||||
actions_taken[action] += 1
|
||||
|
||||
# Every 100 steps, log progress
|
||||
if step % 100 == 0 or step < 10:
|
||||
action_str = "BUY" if action == 0 else "SELL" if action == 1 else "HOLD"
|
||||
current_price = info.get('current_price', 0)
|
||||
pnl = info.get('pnl', 0)
|
||||
balance = info.get('balance', 0)
|
||||
|
||||
logger.info(f"Episode {episode}, Step {step}: Action={action_str}, "
|
||||
f"Reward={reward:.4f}, Balance=${balance:.2f}, PnL={pnl:.4f}")
|
||||
|
||||
# Add trade to visualization if enabled
|
||||
if visualize and chart and action in [0, 1]: # BUY or SELL
|
||||
chart.add_trade(
|
||||
price=current_price,
|
||||
timestamp=datetime.now(),
|
||||
amount=0.1,
|
||||
pnl=pnl,
|
||||
action=action_str
|
||||
)
|
||||
|
||||
step += 1
|
||||
|
||||
        # Episode finished - calculate metrics
        episode_time = time.time() - start_time
        total_train_time += episode_time

        # Get environment info
        win_rate = env.winning_trades / max(1, env.total_trades)
        trades = env.total_trades
        balance = env.balance
        gain = (balance - env.initial_balance) / env.initial_balance
        max_drawdown = env.max_drawdown

        # Record metrics
        rewards.append(episode_reward)
        win_rates.append(win_rate)

        # Update agent's learning metrics
        improved = agent.update_learning_metrics(episode_reward)

        # If this is best performance, save the model
        if episode_reward > best_reward:
            best_reward = episode_reward
            if save_path:
                best_model_path = f"{save_path}_best"
                agent.save(best_model_path)
                logger.info(f"New best model saved to {best_model_path} (reward: {best_reward:.2f})")
            last_improved_episode = episode

        # Regular checkpoint saving
        if save_path and episode % save_freq == 0:
            checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_episode_{episode}")
            agent.save(checkpoint_path)

        # Print episode summary
        actions_summary = ", ".join([f"{k}:{v}" for k, v in actions_taken.items()])
        logger.info(f"Episode {episode} completed in {episode_time:.2f}s")
        logger.info(f" Total reward: {episode_reward:.4f}")
        logger.info(f" Actions taken: {actions_summary}")
        logger.info(f" Trades: {trades}, Win rate: {win_rate:.2%}")
        logger.info(f" Balance: ${balance:.2f}, Gain: {gain:.2%}")
        logger.info(f" Max Drawdown: {max_drawdown:.2%}")

        # Early stopping check
        if episode - last_improved_episode >= patience:
            logger.info(f"No improvement for {patience} episodes. Early stopping.")
            break

    # Training complete
    avg_time_per_episode = total_train_time / max(1, len(rewards))
    logger.info(f"Training completed in {total_train_time:.2f}s ({avg_time_per_episode:.2f}s per episode)")

    # Save final model
    if save_path:
        agent.save(f"{save_path}_final")
        logger.info(f"Final model saved to {save_path}_final")

    # Return training metrics
    return rewards, win_rates, best_reward, best_model_path

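# ---------------------------------------------------------------------------
# The training loop above stores every experience with is_extrema=False and
# notes that a real implementation would detect extrema points. The helper
# below is only a minimal sketch of such a detector over a recent price series;
# detect_local_extrema and its window parameter are illustrative and are not an
# existing helper in this codebase.
# ---------------------------------------------------------------------------
def detect_local_extrema(prices, index, window=5):
    """Illustrative sketch: True if prices[index] is a local min or max within +/- window bars."""
    if index < window or index + window >= len(prices):
        return False  # Not enough surrounding bars to decide
    segment = prices[index - window:index + window + 1]
    center = prices[index]
    return center == max(segment) or center == min(segment)
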
def plot_training_results(rewards, win_rates, save_dir=None):
    """Plot training metrics and save the figure"""
    plt.figure(figsize=(12, 8))

    # Plot rewards
    plt.subplot(2, 1, 1)
    plt.plot(rewards, 'b-')
    plt.title('Training Rewards per Episode')
    plt.xlabel('Episode')
    plt.ylabel('Total Reward')
    plt.grid(True)

    # Plot win rates
    plt.subplot(2, 1, 2)
    plt.plot(win_rates, 'g-')
    plt.title('Win Rate per Episode')
    plt.xlabel('Episode')
    plt.ylabel('Win Rate')
    plt.grid(True)

    plt.tight_layout()

    # Save figure if directory provided
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(os.path.join(save_dir, f'training_results_{datetime.now().strftime("%Y%m%d_%H%M%S")}.png'))

    plt.close()

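# Usage note (illustrative): plot_training_results() can also be called on its
# own with the metrics returned by train_agent(); the save directory below is a
# placeholder, not a path used elsewhere in this script.
#
#   rewards, win_rates, best_reward, best_model_path = train_agent(env, agent, num_episodes=20)
#   plot_training_results(rewards, win_rates, save_dir='runs/plots')
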
def evaluate_agent(env, agent, num_episodes=5, visualize=False, chart=None):
    """
    Evaluate a trained agent on the environment

    Args:
        env: The trading environment
        agent: The trained DQN agent
        num_episodes: Number of evaluation episodes
        visualize: Whether to visualize trades
        chart: The visualization chart (if visualize=True)

    Returns:
        dict: Evaluation metrics
    """
    logger.info(f"Evaluating agent over {num_episodes} episodes")

    # Metrics to track
    total_rewards = []
    total_trades = []
    win_rates = []
    sharpe_ratios = []
    sortino_ratios = []
    max_drawdowns = []
    final_balances = []

    for episode in range(num_episodes):
        # Reset environment
        state = env.reset()
        done = False
        episode_reward = 0

        # Run episode without exploration
        while not done:
            action = agent.act(state, explore=False)  # No exploration during evaluation
            next_state, reward, done, info = env.step(action)

            episode_reward += reward
            state = next_state

            # Add trade to visualization if enabled
            if visualize and chart and action in [0, 1]:  # BUY or SELL
                action_str = "BUY" if action == 0 else "SELL"
                current_price = info.get('current_price', 0)
                pnl = info.get('pnl', 0)

                chart.add_trade(
                    price=current_price,
                    timestamp=datetime.now(),
                    amount=0.1,
                    pnl=pnl,
                    action=action_str
                )

        # Record metrics
        total_rewards.append(episode_reward)
        total_trades.append(env.total_trades)
        win_rates.append(env.winning_trades / max(1, env.total_trades))
        sharpe_ratios.append(info.get('sharpe_ratio', 0))
        sortino_ratios.append(info.get('sortino_ratio', 0))
        max_drawdowns.append(env.max_drawdown)
        final_balances.append(env.balance)

        logger.info(f"Evaluation episode {episode} - Reward: {episode_reward:.4f}, "
                    f"Balance: ${env.balance:.2f}, Win rate: {win_rates[-1]:.2%}")

    # Calculate average metrics
    avg_reward = np.mean(total_rewards)
    avg_trades = np.mean(total_trades)
    avg_win_rate = np.mean(win_rates)
    avg_sharpe = np.mean(sharpe_ratios)
    avg_sortino = np.mean(sortino_ratios)
    avg_max_drawdown = np.mean(max_drawdowns)
    avg_final_balance = np.mean(final_balances)

    # Log evaluation summary
    logger.info("Evaluation completed:")
    logger.info(f" Average reward: {avg_reward:.4f}")
    logger.info(f" Average trades per episode: {avg_trades:.2f}")
    logger.info(f" Average win rate: {avg_win_rate:.2%}")
    logger.info(f" Average Sharpe ratio: {avg_sharpe:.4f}")
    logger.info(f" Average Sortino ratio: {avg_sortino:.4f}")
    logger.info(f" Average max drawdown: {avg_max_drawdown:.2%}")
    logger.info(f" Average final balance: ${avg_final_balance:.2f}")

    # Return evaluation metrics
    return {
        'avg_reward': avg_reward,
        'avg_trades': avg_trades,
        'avg_win_rate': avg_win_rate,
        'avg_sharpe': avg_sharpe,
        'avg_sortino': avg_sortino,
        'avg_max_drawdown': avg_max_drawdown,
        'avg_final_balance': avg_final_balance
    }

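# ---------------------------------------------------------------------------
# evaluate_agent() above reads 'sharpe_ratio' and 'sortino_ratio' from the
# environment's info dict; the environment-side computation is not shown in
# this section. The helpers below sketch the common per-step (non-annualized)
# formulations from a series of returns; the function names and the eps guard
# are illustrative.
# ---------------------------------------------------------------------------
def sharpe_ratio_sketch(returns, risk_free_rate=0.0, eps=1e-9):
    """Illustrative sketch: mean excess return divided by standard deviation of returns."""
    excess = np.asarray(returns, dtype=float) - risk_free_rate
    return float(np.mean(excess) / (np.std(excess) + eps))


def sortino_ratio_sketch(returns, risk_free_rate=0.0, eps=1e-9):
    """Illustrative sketch: like Sharpe, but only downside deviation is penalized."""
    excess = np.asarray(returns, dtype=float) - risk_free_rate
    downside = excess[excess < 0]
    downside_std = np.std(downside) if downside.size > 0 else 0.0
    return float(np.mean(excess) / (downside_std + eps))
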
def main():
    """Main function to run the improved RL training"""
    start_time = time.time()
    logger.info(f"Starting improved RL training for {args.symbol}")

    # Create environment
    env, data_frames = create_training_environment(args.symbol)

    # Initialize visualization if enabled
    chart = None
    if args.visualize:
        logger.info("Initializing visualization chart")
        chart = RealTimeChart(args.symbol)
        time.sleep(2)  # Give time for chart to initialize

    # Initialize agent
    agent = initialize_agent(env)

    # Train agent
    rewards, win_rates, best_reward, best_model_path = train_agent(
        env=env,
        agent=agent,
        num_episodes=args.episodes,
        visualize=args.visualize,
        chart=chart,
        save_path=args.save_path
    )

    # Plot training results
    plot_dir = os.path.join(os.path.dirname(args.save_path), 'plots')
    plot_training_results(rewards, win_rates, save_dir=plot_dir)

    # Evaluate best model
    logger.info("Evaluating best model")

    # Load best model for evaluation
    if best_model_path:
        best_agent = initialize_agent(env)
        best_agent.load(best_model_path)

        # Evaluate the best model
        eval_metrics = evaluate_agent(
            env=env,
            agent=best_agent,
            visualize=args.visualize,
            chart=chart
        )

        # Log evaluation results
        logger.info("Best model evaluation complete:")
        for metric, value in eval_metrics.items():
            logger.info(f" {metric}: {value}")

    # Total run time
    total_time = time.time() - start_time
    logger.info(f"Total run time: {total_time:.2f} seconds")

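# ---------------------------------------------------------------------------
# main() references a module-level `args` object (args.symbol, args.episodes,
# args.visualize, args.save_path) whose parser is defined elsewhere in this
# script. The function below is only a sketch of the kind of CLI the script
# appears to expect; the actual flag names and defaults in the repository may
# differ, and the default values here are placeholders.
# ---------------------------------------------------------------------------
def build_arg_parser_sketch():
    """Illustrative sketch of the CLI flags main() appears to rely on."""
    import argparse
    parser = argparse.ArgumentParser(description='Improved RL training')
    parser.add_argument('--symbol', type=str, default='BTC/USDT',
                        help='Trading pair to train on (placeholder default)')
    parser.add_argument('--episodes', type=int, default=20,
                        help='Number of training episodes')
    parser.add_argument('--visualize', action='store_true',
                        help='Visualize trades during training')
    parser.add_argument('--save-path', dest='save_path', type=str,
                        default='models/improved_rl',
                        help='Base path for saved models (placeholder default)')
    return parser
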
if __name__ == "__main__":
    main()