added more predictions (multi-timeframe price direction and magnitude heads, plus price-prediction pre-training)
@@ -99,6 +99,28 @@ class DQNAgent:
        }
        self.extrema_memory = []  # Special memory for storing extrema points

        # Price prediction tracking
        self.last_price_pred = {
            'immediate': {
                'direction': 1,  # Default to "sideways"
                'confidence': 0.0,
                'change': 0.0
            },
            'midterm': {
                'direction': 1,  # Default to "sideways"
                'confidence': 0.0,
                'change': 0.0
            },
            'longterm': {
                'direction': 1,  # Default to "sideways"
                'confidence': 0.0,
                'change': 0.0
            }
        }

        # Store separate memory for price direction examples
        self.price_movement_memory = []  # For storing examples of clear price movements

        # Performance tracking
        self.losses = []
        self.avg_reward = 0.0
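# A minimal sketch (hypothetical helper, not part of the class above): the three
# identical default entries in last_price_pred could come from one factory so the
# timeframes cannot drift apart.
def _default_price_pred():
    """Neutral prediction entry: direction 1 = sideways."""
    return {'direction': 1, 'confidence': 0.0, 'change': 0.0}

# Usage sketch:
# self.last_price_pred = {tf: _default_price_pred() for tf in ('immediate', 'midterm', 'longterm')}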
@@ -157,6 +179,36 @@ class DQNAgent:
|
||||
# Always add to main memory
|
||||
self.memory.append(experience)
|
||||
|
||||
# Try to extract price change to analyze the experience
|
||||
try:
|
||||
# Extract price feature from sequence data (if available)
|
||||
if len(state.shape) > 1: # 2D state [timeframes, features]
|
||||
current_price = state[-1, -1] # Last timeframe, last feature
|
||||
next_price = next_state[-1, -1]
|
||||
else: # 1D state
|
||||
current_price = state[-1] # Last feature
|
||||
next_price = next_state[-1]
|
||||
|
||||
# Calculate price change
|
||||
price_change = (next_price - current_price) / current_price
|
||||
|
||||
# Check if this is a significant price movement
|
||||
if abs(price_change) > 0.002: # Significant price change
|
||||
# Store in price movement memory
|
||||
self.price_movement_memory.append(experience)
|
||||
|
||||
# Log significant price movements
|
||||
direction = "UP" if price_change > 0 else "DOWN"
|
||||
logger.info(f"Stored significant {direction} price movement: {price_change:.4f}")
|
||||
|
||||
# For clear price movements, also duplicate in main memory to learn more
|
||||
if abs(price_change) > 0.005: # Very significant movement
|
||||
for _ in range(2): # Add 2 extra copies
|
||||
self.memory.append(experience)
|
||||
except Exception as e:
|
||||
# Skip price movement analysis if it fails
|
||||
pass
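# A standalone sketch of the price-movement check above, assuming (as the try-block
# does) that the close price is the last feature of the state. Names are illustrative.
import numpy as np

def relative_price_change(state: np.ndarray, next_state: np.ndarray) -> float:
    """Relative change of the last feature between consecutive states."""
    if state.ndim > 1:  # 2D state [timeframes, features]
        current_price, next_price = state[-1, -1], next_state[-1, -1]
    else:               # 1D state
        current_price, next_price = state[-1], next_state[-1]
    return (next_price - current_price) / current_price

# Example: a 0.3% move clears the 0.002 "significant" threshold used above.
s, s_next = np.array([1.0, 2.0, 100.0]), np.array([1.0, 2.0, 100.3])
assert abs(relative_price_change(s, s_next)) > 0.002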
|
||||
|
||||
# Check if this is an extrema point based on our extrema detection head
|
||||
if hasattr(self, 'last_extrema_pred') and self.last_extrema_pred['class'] != 2:
|
||||
# Class 0 = bottom, 1 = top, 2 = neither
|
||||
@@ -196,6 +248,9 @@ class DQNAgent:
|
||||
|
||||
if len(self.extrema_memory) > self.buffer_size // 4:
|
||||
self.extrema_memory = self.extrema_memory[-(self.buffer_size // 4):]
|
||||
|
||||
if len(self.price_movement_memory) > self.buffer_size // 4:
|
||||
self.price_movement_memory = self.price_movement_memory[-(self.buffer_size // 4):]
|
||||
|
||||
def act(self, state: np.ndarray, explore=True) -> int:
|
||||
"""Choose action using epsilon-greedy policy with explore flag"""
|
||||
@@ -209,7 +264,7 @@ class DQNAgent:
|
||||
|
||||
# Get predictions using the policy network
|
||||
self.policy_net.eval() # Set to evaluation mode for inference
|
||||
action_probs, extrema_pred = self.policy_net(state_tensor)
|
||||
action_probs, extrema_pred, price_predictions = self.policy_net(state_tensor)
|
||||
self.policy_net.train() # Back to training mode
|
||||
|
||||
# Get the predicted extrema class (0=bottom, 1=top, 2=neither)
|
||||
@@ -221,17 +276,81 @@ class DQNAgent:
|
||||
extrema_type = "BOTTOM" if extrema_class == 0 else "TOP" if extrema_class == 1 else "NEITHER"
|
||||
logger.info(f"High confidence {extrema_type} detected! Confidence: {extrema_confidence:.4f}")
|
||||
|
||||
# Store extrema prediction for the environment to use
|
||||
# Process price predictions
|
||||
price_immediate = torch.softmax(price_predictions['immediate'], dim=1)
|
||||
price_midterm = torch.softmax(price_predictions['midterm'], dim=1)
|
||||
price_longterm = torch.softmax(price_predictions['longterm'], dim=1)
|
||||
price_values = price_predictions['values']
|
||||
|
||||
# Get predicted direction for each timeframe (0=down, 1=sideways, 2=up)
|
||||
immediate_direction = price_immediate.argmax(dim=1).item()
|
||||
midterm_direction = price_midterm.argmax(dim=1).item()
|
||||
longterm_direction = price_longterm.argmax(dim=1).item()
|
||||
|
||||
# Get confidence levels
|
||||
immediate_conf = price_immediate[0, immediate_direction].item()
|
||||
midterm_conf = price_midterm[0, midterm_direction].item()
|
||||
longterm_conf = price_longterm[0, longterm_direction].item()
|
||||
|
||||
# Get predicted price change percentages
|
||||
price_changes = price_values[0].tolist()
|
||||
|
||||
# Log significant price movement predictions
|
||||
timeframes = ["1s/1m", "1h", "1d", "1w"]
|
||||
directions = ["DOWN", "SIDEWAYS", "UP"]
|
||||
|
||||
for i, (direction, conf) in enumerate([
|
||||
(immediate_direction, immediate_conf),
|
||||
(midterm_direction, midterm_conf),
|
||||
(longterm_direction, longterm_conf)
|
||||
]):
|
||||
if conf > 0.7 and direction != 1: # Only log high confidence non-sideways predictions
|
||||
logger.info(f"Price prediction: {timeframes[i]} -> {directions[direction]}, "
|
||||
f"Confidence: {conf:.4f}, Expected change: {price_changes[i]:.2f}%")
|
||||
|
||||
# Store predictions for environment to use
|
||||
self.last_extrema_pred = {
|
||||
'class': extrema_class,
|
||||
'confidence': extrema_confidence,
|
||||
'raw': extrema_pred.cpu().numpy()
|
||||
}
|
||||
|
||||
self.last_price_pred = {
|
||||
'immediate': {
|
||||
'direction': immediate_direction,
|
||||
'confidence': immediate_conf,
|
||||
'change': price_changes[0]
|
||||
},
|
||||
'midterm': {
|
||||
'direction': midterm_direction,
|
||||
'confidence': midterm_conf,
|
||||
'change': price_changes[1]
|
||||
},
|
||||
'longterm': {
|
||||
'direction': longterm_direction,
|
||||
'confidence': longterm_conf,
|
||||
'change': price_changes[2]
|
||||
}
|
||||
}
|
||||
|
||||
# Get the action with highest Q-value
|
||||
action = action_probs.argmax().item()
|
||||
|
||||
# Adjust action based on extrema prediction (with some probability)
|
||||
# Adjust action based on extrema and price predictions
|
||||
# Prioritize short-term movement for trading decisions
|
||||
if immediate_conf > 0.8: # Only adjust for strong signals
|
||||
if immediate_direction == 2: # UP prediction
|
||||
# Bias toward BUY for strong up predictions
|
||||
if action != 0 and random.random() < 0.3 * immediate_conf:
|
||||
logger.info(f"Adjusting action to BUY based on immediate UP prediction")
|
||||
action = 0 # BUY
|
||||
elif immediate_direction == 0: # DOWN prediction
|
||||
# Bias toward SELL for strong down predictions
|
||||
if action != 1 and random.random() < 0.3 * immediate_conf:
|
||||
logger.info(f"Adjusting action to SELL based on immediate DOWN prediction")
|
||||
action = 1 # SELL
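# A pure-function sketch (illustrative, not the agent's API) of the probabilistic
# biasing applied above; action encoding 0=BUY, 1=SELL follows the surrounding code.
import random

def bias_action(action: int, direction: int, confidence: float,
                threshold: float = 0.8, base_prob: float = 0.3) -> int:
    """Nudge the chosen action toward a strongly predicted immediate move."""
    if confidence <= threshold:
        return action
    if direction == 2 and action != 0 and random.random() < base_prob * confidence:
        return 0  # BUY on a confident UP prediction
    if direction == 0 and action != 1 and random.random() < base_prob * confidence:
        return 1  # SELL on a confident DOWN prediction
    return action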
|
||||
|
||||
# Also consider extrema detection for action adjustment
|
||||
if extrema_confidence > 0.8: # Only adjust for strong signals
|
||||
if extrema_class == 0: # Bottom detected
|
||||
# Bias toward BUY at bottoms
|
||||
@@ -307,53 +426,102 @@ class DQNAgent:
|
||||
next_states_tensor, dones_tensor
|
||||
)
|
||||
|
||||
# Occasionally train specifically on extrema points, if we have enough
|
||||
if hasattr(self, 'extrema_memory') and len(self.extrema_memory) >= self.batch_size // 2:
|
||||
if random.random() < 0.3: # 30% chance to do extra extrema training
|
||||
# Sample from extrema memory
|
||||
extrema_batch_size = min(self.batch_size // 2, len(self.extrema_memory))
|
||||
extrema_batch = random.sample(self.extrema_memory, extrema_batch_size)
|
||||
# Training focus selector - randomly focus on one of the specialized training types
|
||||
training_focus = random.random()
|
||||
|
||||
# Occasionally train specifically on extrema points
|
||||
if training_focus < 0.3 and hasattr(self, 'extrema_memory') and len(self.extrema_memory) >= self.batch_size // 2:
|
||||
# Sample from extrema memory
|
||||
extrema_batch_size = min(self.batch_size // 2, len(self.extrema_memory))
|
||||
extrema_batch = random.sample(self.extrema_memory, extrema_batch_size)
|
||||
|
||||
# Extract batches with proper tensor conversion
|
||||
extrema_states = np.vstack([self._normalize_state(x[0]) for x in extrema_batch])
|
||||
extrema_actions = np.array([x[1] for x in extrema_batch])
|
||||
extrema_rewards = np.array([x[2] for x in extrema_batch])
|
||||
extrema_next_states = np.vstack([self._normalize_state(x[3]) for x in extrema_batch])
|
||||
extrema_dones = np.array([x[4] for x in extrema_batch], dtype=np.float32)
|
||||
|
||||
# Convert to torch tensors and move to device
|
||||
extrema_states_tensor = torch.FloatTensor(extrema_states).to(self.device)
|
||||
extrema_actions_tensor = torch.LongTensor(extrema_actions).to(self.device)
|
||||
extrema_rewards_tensor = torch.FloatTensor(extrema_rewards).to(self.device)
|
||||
extrema_next_states_tensor = torch.FloatTensor(extrema_next_states).to(self.device)
|
||||
extrema_dones_tensor = torch.FloatTensor(extrema_dones).to(self.device)
|
||||
|
||||
# Additional training step focused on extrema points (with smaller learning rate)
|
||||
original_lr = self.optimizer.param_groups[0]['lr']
|
||||
# Temporarily reduce learning rate for fine-tuning on extrema
|
||||
for param_group in self.optimizer.param_groups:
|
||||
param_group['lr'] = original_lr * 0.5
|
||||
|
||||
# Extract batches with proper tensor conversion
|
||||
extrema_states = np.vstack([self._normalize_state(x[0]) for x in extrema_batch])
|
||||
extrema_actions = np.array([x[1] for x in extrema_batch])
|
||||
extrema_rewards = np.array([x[2] for x in extrema_batch])
|
||||
extrema_next_states = np.vstack([self._normalize_state(x[3]) for x in extrema_batch])
|
||||
extrema_dones = np.array([x[4] for x in extrema_batch], dtype=np.float32)
|
||||
# Train on extrema
|
||||
if self.use_mixed_precision:
|
||||
extrema_loss = self._replay_mixed_precision(
|
||||
extrema_states_tensor, extrema_actions_tensor, extrema_rewards_tensor,
|
||||
extrema_next_states_tensor, extrema_dones_tensor
|
||||
)
|
||||
else:
|
||||
extrema_loss = self._replay_standard(
|
||||
extrema_states_tensor, extrema_actions_tensor, extrema_rewards_tensor,
|
||||
extrema_next_states_tensor, extrema_dones_tensor
|
||||
)
|
||||
|
||||
# Restore original learning rate
|
||||
for param_group in self.optimizer.param_groups:
|
||||
param_group['lr'] = original_lr
|
||||
|
||||
logger.info(f"Extra training on extrema points: loss={extrema_loss:.4f}")
|
||||
|
||||
# Average the loss
|
||||
loss = (loss + extrema_loss) / 2
|
||||
|
||||
# Occasionally train specifically on price movement data
|
||||
elif training_focus >= 0.3 and training_focus < 0.6 and hasattr(self, 'price_movement_memory') and len(self.price_movement_memory) >= self.batch_size // 2:
|
||||
# Sample from price movement memory
|
||||
price_batch_size = min(self.batch_size // 2, len(self.price_movement_memory))
|
||||
price_batch = random.sample(self.price_movement_memory, price_batch_size)
|
||||
|
||||
# Extract batches with proper tensor conversion
|
||||
price_states = np.vstack([self._normalize_state(x[0]) for x in price_batch])
|
||||
price_actions = np.array([x[1] for x in price_batch])
|
||||
price_rewards = np.array([x[2] for x in price_batch])
|
||||
price_next_states = np.vstack([self._normalize_state(x[3]) for x in price_batch])
|
||||
price_dones = np.array([x[4] for x in price_batch], dtype=np.float32)
|
||||
|
||||
# Convert to torch tensors and move to device
|
||||
price_states_tensor = torch.FloatTensor(price_states).to(self.device)
|
||||
price_actions_tensor = torch.LongTensor(price_actions).to(self.device)
|
||||
price_rewards_tensor = torch.FloatTensor(price_rewards).to(self.device)
|
||||
price_next_states_tensor = torch.FloatTensor(price_next_states).to(self.device)
|
||||
price_dones_tensor = torch.FloatTensor(price_dones).to(self.device)
|
||||
|
||||
# Additional training step focused on price movements (with smaller learning rate)
|
||||
original_lr = self.optimizer.param_groups[0]['lr']
|
||||
# Temporarily reduce learning rate
|
||||
for param_group in self.optimizer.param_groups:
|
||||
param_group['lr'] = original_lr * 0.5
|
||||
|
||||
# Convert to torch tensors and move to device
|
||||
extrema_states_tensor = torch.FloatTensor(extrema_states).to(self.device)
|
||||
extrema_actions_tensor = torch.LongTensor(extrema_actions).to(self.device)
|
||||
extrema_rewards_tensor = torch.FloatTensor(extrema_rewards).to(self.device)
|
||||
extrema_next_states_tensor = torch.FloatTensor(extrema_next_states).to(self.device)
|
||||
extrema_dones_tensor = torch.FloatTensor(extrema_dones).to(self.device)
|
||||
|
||||
# Additional training step focused on extrema points (with smaller learning rate)
|
||||
original_lr = self.optimizer.param_groups[0]['lr']
|
||||
# Temporarily reduce learning rate for fine-tuning on extrema
|
||||
for param_group in self.optimizer.param_groups:
|
||||
param_group['lr'] = original_lr * 0.5
|
||||
|
||||
# Train on extrema
|
||||
if self.use_mixed_precision:
|
||||
extrema_loss = self._replay_mixed_precision(
|
||||
extrema_states_tensor, extrema_actions_tensor, extrema_rewards_tensor,
|
||||
extrema_next_states_tensor, extrema_dones_tensor
|
||||
)
|
||||
else:
|
||||
extrema_loss = self._replay_standard(
|
||||
extrema_states_tensor, extrema_actions_tensor, extrema_rewards_tensor,
|
||||
extrema_next_states_tensor, extrema_dones_tensor
|
||||
)
|
||||
|
||||
# Restore original learning rate
|
||||
for param_group in self.optimizer.param_groups:
|
||||
param_group['lr'] = original_lr
|
||||
|
||||
logger.info(f"Extra training on extrema points: loss={extrema_loss:.4f}")
|
||||
|
||||
# Average the loss
|
||||
loss = (loss + extrema_loss) / 2
|
||||
# Train on price movement data
|
||||
if self.use_mixed_precision:
|
||||
price_loss = self._replay_mixed_precision(
|
||||
price_states_tensor, price_actions_tensor, price_rewards_tensor,
|
||||
price_next_states_tensor, price_dones_tensor
|
||||
)
|
||||
else:
|
||||
price_loss = self._replay_standard(
|
||||
price_states_tensor, price_actions_tensor, price_rewards_tensor,
|
||||
price_next_states_tensor, price_dones_tensor
|
||||
)
|
||||
|
||||
# Restore original learning rate
|
||||
for param_group in self.optimizer.param_groups:
|
||||
param_group['lr'] = original_lr
|
||||
|
||||
logger.info(f"Extra training on price movement data: loss={price_loss:.4f}")
|
||||
|
||||
# Average the loss
|
||||
loss = (loss + price_loss) / 2
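# The extrema and price-movement branches above repeat the same "sample, normalize,
# stack, move to device" steps; a factored-out sketch assuming the
# (state, action, reward, next_state, done) tuple layout used above. Names are illustrative.
import random
import numpy as np
import torch

def batch_from_memory(memory, batch_size, normalize, device):
    batch = random.sample(memory, min(batch_size, len(memory)))
    states = torch.FloatTensor(np.vstack([normalize(x[0]) for x in batch])).to(device)
    actions = torch.LongTensor(np.array([x[1] for x in batch])).to(device)
    rewards = torch.FloatTensor(np.array([x[2] for x in batch])).to(device)
    next_states = torch.FloatTensor(np.vstack([normalize(x[3]) for x in batch])).to(device)
    dones = torch.FloatTensor(np.array([x[4] for x in batch], dtype=np.float32)).to(device)
    return states, actions, rewards, next_states, dones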
|
||||
|
||||
# Store and return loss
|
||||
self.losses.append(loss)
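# The "halve the learning rate, run the extra step, restore it" pattern used in both
# specialized branches above can be wrapped in a context manager; a minimal sketch
# (hypothetical helper), assuming a standard torch.optim optimizer.
import contextlib

@contextlib.contextmanager
def scaled_lr(optimizer, factor=0.5):
    """Temporarily scale the learning rate of every param group."""
    original = [group['lr'] for group in optimizer.param_groups]
    for group in optimizer.param_groups:
        group['lr'] *= factor
    try:
        yield optimizer
    finally:
        for group, lr in zip(optimizer.param_groups, original):
            group['lr'] = lr

# Usage sketch:
# with scaled_lr(self.optimizer, 0.5):
#     extrema_loss = self._replay_standard(...)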
|
||||
@@ -365,12 +533,12 @@ class DQNAgent:
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
# Get current Q values and extrema predictions
|
||||
current_q_values, current_extrema_pred = self.policy_net(states)
|
||||
current_q_values, current_extrema_pred, current_price_pred = self.policy_net(states)
|
||||
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
|
||||
|
||||
# Get next Q values from target network
|
||||
with torch.no_grad():
|
||||
next_q_values, next_extrema_pred = self.target_net(next_states)
|
||||
next_q_values, next_extrema_pred, next_price_pred = self.target_net(next_states)
|
||||
next_q_values = next_q_values.max(1)[0]
|
||||
|
||||
# Check for dimension mismatch and fix it
|
||||
@@ -389,13 +557,10 @@ class DQNAgent:
|
||||
# Compute Q-value loss (primary task)
|
||||
q_loss = nn.MSELoss()(current_q_values, target_q_values)
|
||||
|
||||
# Create extrema labels from price movements (crude approximation)
|
||||
# If the next state price is higher than current, we might be in an uptrend (not a bottom)
|
||||
# If the next state price is lower than current, we might be in a downtrend (not a top)
|
||||
# This is a simplified approximation; in real scenarios we'd want to use actual extrema detection
|
||||
# Initialize combined loss with Q-value loss
|
||||
loss = q_loss
|
||||
|
||||
# Try to extract price from current and next states
|
||||
# Assuming price is in the last feature
|
||||
try:
|
||||
# Extract price feature from sequence data (if available)
|
||||
if len(states.shape) == 3: # [batch, seq, features]
|
||||
@@ -405,43 +570,99 @@ class DQNAgent:
|
||||
current_prices = states[:, -1] # Last feature
|
||||
next_prices = next_states[:, -1]
|
||||
|
||||
# Compute price changes
|
||||
price_changes = (next_prices - current_prices) / current_prices
|
||||
# Compute price changes for different timeframes
|
||||
immediate_changes = (next_prices - current_prices) / current_prices
|
||||
|
||||
# Create crude extrema labels:
|
||||
# 0 = bottom: Large negative price change followed by positive change
|
||||
# 1 = top: Large positive price change followed by negative change
|
||||
# 2 = neither: Small or inconsistent changes
|
||||
# Create price direction labels - simplified for training
|
||||
# 0 = down, 1 = sideways, 2 = up
|
||||
immediate_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 1 # Default: sideways
|
||||
midterm_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 1
|
||||
longterm_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 1
|
||||
|
||||
# Classify based on price change magnitude
|
||||
extrema_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 2 # Default: neither
|
||||
# Immediate term direction (1s, 1m)
|
||||
immediate_up = (immediate_changes > 0.0005)
|
||||
immediate_down = (immediate_changes < -0.0005)
|
||||
immediate_labels[immediate_up] = 2 # Up
|
||||
immediate_labels[immediate_down] = 0 # Down
|
||||
|
||||
# Identify potential bottoms (significant negative change)
|
||||
bottoms = (price_changes < -0.003)
|
||||
extrema_labels[bottoms] = 0
|
||||
# For mid and long term, we can only approximate during training
|
||||
# In a real system, we'd need historical data to validate these
|
||||
# Here we'll use the immediate term with increasing thresholds as approximation
|
||||
|
||||
# Identify potential tops (significant positive change)
|
||||
tops = (price_changes > 0.003)
|
||||
extrema_labels[tops] = 1
|
||||
# Mid-term (1h) - use slightly higher threshold
|
||||
midterm_up = (immediate_changes > 0.001)
|
||||
midterm_down = (immediate_changes < -0.001)
|
||||
midterm_labels[midterm_up] = 2 # Up
|
||||
midterm_labels[midterm_down] = 0 # Down
|
||||
|
||||
# Calculate extrema prediction loss (auxiliary task)
|
||||
if len(current_extrema_pred.shape) > 1 and current_extrema_pred.shape[0] >= min_size:
|
||||
current_extrema_pred = current_extrema_pred[:min_size]
|
||||
extrema_loss = nn.CrossEntropyLoss()(current_extrema_pred, extrema_labels)
|
||||
# Long-term (1d) - use even higher threshold
|
||||
longterm_up = (immediate_changes > 0.002)
|
||||
longterm_down = (immediate_changes < -0.002)
|
||||
longterm_labels[longterm_up] = 2 # Up
|
||||
longterm_labels[longterm_down] = 0 # Down
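# A vectorized sketch of the threshold labelling above (0=down, 1=sideways, 2=up);
# the helper name and signature are illustrative.
import torch

def direction_labels(changes: torch.Tensor, threshold: float) -> torch.Tensor:
    labels = torch.ones_like(changes, dtype=torch.long)  # default: sideways
    labels[changes > threshold] = 2   # up
    labels[changes < -threshold] = 0  # down
    return labels

# Usage sketch with the thresholds used above:
# immediate_labels = direction_labels(immediate_changes, 0.0005)
# midterm_labels   = direction_labels(immediate_changes, 0.001)
# longterm_labels  = direction_labels(immediate_changes, 0.002)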
|
||||
|
||||
# Generate target values for price change regression
|
||||
# For simplicity, we'll use the immediate change and scaled versions for longer timeframes
|
||||
price_value_targets = torch.zeros((min_size, 4), device=self.device)
|
||||
price_value_targets[:, 0] = immediate_changes
|
||||
price_value_targets[:, 1] = immediate_changes * 2.0 # Approximate 1h change
|
||||
price_value_targets[:, 2] = immediate_changes * 4.0 # Approximate 1d change
|
||||
price_value_targets[:, 3] = immediate_changes * 6.0 # Approximate 1w change
|
||||
|
||||
# Calculate loss for price direction prediction (classification)
|
||||
if len(current_price_pred['immediate'].shape) > 1 and current_price_pred['immediate'].shape[0] >= min_size:
|
||||
# Slice predictions to match the adjusted batch size
|
||||
immediate_pred = current_price_pred['immediate'][:min_size]
|
||||
midterm_pred = current_price_pred['midterm'][:min_size]
|
||||
longterm_pred = current_price_pred['longterm'][:min_size]
|
||||
price_values_pred = current_price_pred['values'][:min_size]
|
||||
|
||||
# Combined loss (primary + auxiliary with lower weight)
|
||||
# Typically auxiliary tasks should have lower weight to not dominate the primary task
|
||||
loss = q_loss + 0.3 * extrema_loss
|
||||
# Compute losses for each task
|
||||
immediate_loss = nn.CrossEntropyLoss()(immediate_pred, immediate_labels)
|
||||
midterm_loss = nn.CrossEntropyLoss()(midterm_pred, midterm_labels)
|
||||
longterm_loss = nn.CrossEntropyLoss()(longterm_pred, longterm_labels)
|
||||
|
||||
# Log separate loss components occasionally
|
||||
if random.random() < 0.01: # Log 1% of the time to avoid flood
|
||||
logger.info(f"Training losses: Q-loss={q_loss.item():.4f}, Extrema-loss={extrema_loss.item():.4f}")
|
||||
else:
|
||||
# Fall back to just Q-value loss if extrema predictions aren't available
|
||||
loss = q_loss
|
||||
# MSE loss for price value regression
|
||||
price_value_loss = nn.MSELoss()(price_values_pred, price_value_targets)
|
||||
|
||||
# Combine all price prediction losses
|
||||
price_loss = immediate_loss + 0.7 * midterm_loss + 0.5 * longterm_loss + 0.3 * price_value_loss
|
||||
|
||||
# Create extrema labels (same as before)
|
||||
extrema_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 2 # Default: neither
|
||||
|
||||
# Identify potential bottoms (significant negative change)
|
||||
bottoms = (immediate_changes < -0.003)
|
||||
extrema_labels[bottoms] = 0
|
||||
|
||||
# Identify potential tops (significant positive change)
|
||||
tops = (immediate_changes > 0.003)
|
||||
extrema_labels[tops] = 1
|
||||
|
||||
# Calculate extrema prediction loss
|
||||
if len(current_extrema_pred.shape) > 1 and current_extrema_pred.shape[0] >= min_size:
|
||||
current_extrema_pred = current_extrema_pred[:min_size]
|
||||
extrema_loss = nn.CrossEntropyLoss()(current_extrema_pred, extrema_labels)
|
||||
|
||||
# Combined loss with all components
|
||||
# Primary task: Q-value learning (RL objective)
|
||||
# Secondary tasks: extrema detection and price prediction (supervised objectives)
|
||||
loss = q_loss + 0.3 * extrema_loss + 0.3 * price_loss
|
||||
|
||||
# Log loss components occasionally
|
||||
if random.random() < 0.01: # Log 1% of the time
|
||||
logger.info(
|
||||
f"Training losses: Q-loss={q_loss.item():.4f}, "
|
||||
f"Extrema-loss={extrema_loss.item():.4f}, "
|
||||
f"Price-loss={price_loss.item():.4f}, "
|
||||
f"Imm-loss={immediate_loss.item():.4f}, "
|
||||
f"Mid-loss={midterm_loss.item():.4f}, "
|
||||
f"Long-loss={longterm_loss.item():.4f}"
|
||||
)
|
||||
except Exception as e:
|
||||
# Fallback if price extraction fails
|
||||
logger.warning(f"Failed to calculate extrema loss: {str(e)}. Using only Q-value loss.")
|
||||
logger.warning(f"Failed to calculate price prediction loss: {str(e)}. Using only Q-value loss.")
|
||||
# Just use Q-value loss
|
||||
loss = q_loss
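# The weighting above (Q-loss plus 0.3x extrema and 0.3x price, falling back to the
# Q-loss alone when the auxiliary terms are unavailable) as a tiny pure function;
# a sketch only, using the same default weights as this commit.
def combined_loss(q_loss, extrema_loss=None, price_loss=None,
                  extrema_weight=0.3, price_weight=0.3):
    """Primary RL loss plus whichever auxiliary losses are available."""
    loss = q_loss
    if extrema_loss is not None:
        loss = loss + extrema_weight * extrema_loss
    if price_loss is not None:
        loss = loss + price_weight * price_loss
    return loss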
|
||||
|
||||
# Backward pass and optimize
|
||||
@@ -475,12 +696,12 @@ class DQNAgent:
|
||||
# Forward pass with amp autocasting
|
||||
with torch.cuda.amp.autocast():
|
||||
# Get current Q values and extrema predictions
|
||||
current_q_values, current_extrema_pred = self.policy_net(states)
|
||||
current_q_values, current_extrema_pred, current_price_pred = self.policy_net(states)
|
||||
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
|
||||
|
||||
# Get next Q values from target network
|
||||
with torch.no_grad():
|
||||
next_q_values, next_extrema_pred = self.target_net(next_states)
|
||||
next_q_values, next_extrema_pred, next_price_pred = self.target_net(next_states)
|
||||
next_q_values = next_q_values.max(1)[0]
|
||||
|
||||
# Check for dimension mismatch and fix it
|
||||
@@ -499,7 +720,9 @@ class DQNAgent:
|
||||
# Compute Q-value loss (primary task)
|
||||
q_loss = nn.MSELoss()(current_q_values, target_q_values)
|
||||
|
||||
# Create extrema labels from price movements (crude approximation)
|
||||
# Initialize loss with q_loss
|
||||
loss = q_loss
|
||||
|
||||
# Try to extract price from current and next states
|
||||
try:
|
||||
# Extract price feature from sequence data (if available)
|
||||
@@ -510,42 +733,96 @@ class DQNAgent:
|
||||
current_prices = states[:, -1] # Last feature
|
||||
next_prices = next_states[:, -1]
|
||||
|
||||
# Compute price changes
|
||||
price_changes = (next_prices - current_prices) / current_prices
|
||||
# Compute price changes for different timeframes
|
||||
immediate_changes = (next_prices - current_prices) / current_prices
|
||||
|
||||
# Create crude extrema labels:
|
||||
# 0 = bottom: Large negative price change followed by positive change
|
||||
# 1 = top: Large positive price change followed by negative change
|
||||
# 2 = neither: Small or inconsistent changes
|
||||
# Create price direction labels - simplified for training
|
||||
# 0 = down, 1 = sideways, 2 = up
|
||||
immediate_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 1 # Default: sideways
|
||||
midterm_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 1
|
||||
longterm_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 1
|
||||
|
||||
# Classify based on price change magnitude
|
||||
extrema_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 2 # Default: neither
|
||||
# Immediate term direction (1s, 1m)
|
||||
immediate_up = (immediate_changes > 0.0005)
|
||||
immediate_down = (immediate_changes < -0.0005)
|
||||
immediate_labels[immediate_up] = 2 # Up
|
||||
immediate_labels[immediate_down] = 0 # Down
|
||||
|
||||
# Identify potential bottoms (significant negative change)
|
||||
bottoms = (price_changes < -0.003)
|
||||
extrema_labels[bottoms] = 0
|
||||
# For mid and long term, we can only approximate during training
|
||||
# In a real system, we'd need historical data to validate these
|
||||
# Here we'll use the immediate term with increasing thresholds as approximation
|
||||
|
||||
# Identify potential tops (significant positive change)
|
||||
tops = (price_changes > 0.003)
|
||||
extrema_labels[tops] = 1
|
||||
# Mid-term (1h) - use slightly higher threshold
|
||||
midterm_up = (immediate_changes > 0.001)
|
||||
midterm_down = (immediate_changes < -0.001)
|
||||
midterm_labels[midterm_up] = 2 # Up
|
||||
midterm_labels[midterm_down] = 0 # Down
|
||||
|
||||
# Calculate extrema prediction loss (auxiliary task)
|
||||
if len(current_extrema_pred.shape) > 1 and current_extrema_pred.shape[0] >= min_size:
|
||||
current_extrema_pred = current_extrema_pred[:min_size]
|
||||
extrema_loss = nn.CrossEntropyLoss()(current_extrema_pred, extrema_labels)
|
||||
# Long-term (1d) - use even higher threshold
|
||||
longterm_up = (immediate_changes > 0.002)
|
||||
longterm_down = (immediate_changes < -0.002)
|
||||
longterm_labels[longterm_up] = 2 # Up
|
||||
longterm_labels[longterm_down] = 0 # Down
|
||||
|
||||
# Generate target values for price change regression
|
||||
# For simplicity, we'll use the immediate change and scaled versions for longer timeframes
|
||||
price_value_targets = torch.zeros((min_size, 4), device=self.device)
|
||||
price_value_targets[:, 0] = immediate_changes
|
||||
price_value_targets[:, 1] = immediate_changes * 2.0 # Approximate 1h change
|
||||
price_value_targets[:, 2] = immediate_changes * 4.0 # Approximate 1d change
|
||||
price_value_targets[:, 3] = immediate_changes * 6.0 # Approximate 1w change
|
||||
|
||||
# Calculate loss for price direction prediction (classification)
|
||||
if len(current_price_pred['immediate'].shape) > 1 and current_price_pred['immediate'].shape[0] >= min_size:
|
||||
# Slice predictions to match the adjusted batch size
|
||||
immediate_pred = current_price_pred['immediate'][:min_size]
|
||||
midterm_pred = current_price_pred['midterm'][:min_size]
|
||||
longterm_pred = current_price_pred['longterm'][:min_size]
|
||||
price_values_pred = current_price_pred['values'][:min_size]
|
||||
|
||||
# Combined loss (primary + auxiliary with lower weight)
|
||||
loss = q_loss + 0.3 * extrema_loss
|
||||
# Compute losses for each task
|
||||
immediate_loss = nn.CrossEntropyLoss()(immediate_pred, immediate_labels)
|
||||
midterm_loss = nn.CrossEntropyLoss()(midterm_pred, midterm_labels)
|
||||
longterm_loss = nn.CrossEntropyLoss()(longterm_pred, longterm_labels)
|
||||
|
||||
# Log separate loss components occasionally
|
||||
if random.random() < 0.01: # Log 1% of the time to avoid flood
|
||||
logger.info(f"Mixed precision training losses: Q-loss={q_loss.item():.4f}, Extrema-loss={extrema_loss.item():.4f}")
|
||||
else:
|
||||
# Fall back to just Q-value loss
|
||||
loss = q_loss
|
||||
# MSE loss for price value regression
|
||||
price_value_loss = nn.MSELoss()(price_values_pred, price_value_targets)
|
||||
|
||||
# Combine all price prediction losses
|
||||
price_loss = immediate_loss + 0.7 * midterm_loss + 0.5 * longterm_loss + 0.3 * price_value_loss
|
||||
|
||||
# Create extrema labels (same as before)
|
||||
extrema_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 2 # Default: neither
|
||||
|
||||
# Identify potential bottoms (significant negative change)
|
||||
bottoms = (immediate_changes < -0.003)
|
||||
extrema_labels[bottoms] = 0
|
||||
|
||||
# Identify potential tops (significant positive change)
|
||||
tops = (immediate_changes > 0.003)
|
||||
extrema_labels[tops] = 1
|
||||
|
||||
# Calculate extrema prediction loss
|
||||
if len(current_extrema_pred.shape) > 1 and current_extrema_pred.shape[0] >= min_size:
|
||||
current_extrema_pred = current_extrema_pred[:min_size]
|
||||
extrema_loss = nn.CrossEntropyLoss()(current_extrema_pred, extrema_labels)
|
||||
|
||||
# Combined loss with all components
|
||||
# Primary task: Q-value learning (RL objective)
|
||||
# Secondary tasks: extrema detection and price prediction (supervised objectives)
|
||||
loss = q_loss + 0.3 * extrema_loss + 0.3 * price_loss
|
||||
|
||||
# Log loss components occasionally
|
||||
if random.random() < 0.01: # Log 1% of the time
|
||||
logger.info(
|
||||
f"Mixed precision losses: Q-loss={q_loss.item():.4f}, "
|
||||
f"Extrema-loss={extrema_loss.item():.4f}, "
|
||||
f"Price-loss={price_loss.item():.4f}"
|
||||
)
|
||||
except Exception as e:
|
||||
# Fallback if price extraction fails
|
||||
logger.warning(f"Failed to calculate extrema loss: {str(e)}. Using only Q-value loss.")
|
||||
logger.warning(f"Failed to calculate price prediction loss: {str(e)}. Using only Q-value loss.")
|
||||
# Just use Q-value loss
|
||||
loss = q_loss
|
||||
|
||||
# Backward pass with scaled gradients
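# A standalone sketch of the scaled-gradient step that follows an autocast forward
# pass like the one above, using torch.cuda.amp.GradScaler (requires a CUDA device at
# run time); the function name and signature are illustrative, not this agent's API.
import torch

def amp_step(model, optimizer, scaler, states, targets, loss_fn):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = loss_fn(model(states), targets)
    scaler.scale(loss).backward()      # backward on the scaled loss
    scaler.unscale_(optimizer)         # unscale before gradient clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()
    return loss.item()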
|
||||
|
@@ -125,6 +125,14 @@ class SimpleCNN(nn.Module):
|
||||
|
||||
# Extrema detection head
|
||||
self.extrema_head = nn.Linear(256, 3) # 0=bottom, 1=top, 2=neither
|
||||
|
||||
# Price prediction heads for different timeframes
|
||||
self.price_pred_immediate = nn.Linear(256, 3) # Up, Down, Sideways for immediate term (1s, 1m)
|
||||
self.price_pred_midterm = nn.Linear(256, 3) # Up, Down, Sideways for mid-term (1h)
|
||||
self.price_pred_longterm = nn.Linear(256, 3) # Up, Down, Sideways for long-term (1d)
|
||||
|
||||
# Regression heads for exact price prediction
|
||||
self.price_pred_value = nn.Linear(256, 4) # Predicts % change for each timeframe (1s, 1m, 1h, 1d)
|
||||
|
||||
def _check_rebuild_network(self, features):
|
||||
"""Check if network needs to be rebuilt for different feature dimensions"""
|
||||
@@ -140,7 +148,7 @@ class SimpleCNN(nn.Module):
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass through the network
|
||||
Returns both action values and extrema predictions
|
||||
Returns action values, extrema predictions, and price movement predictions for multiple timeframes
|
||||
"""
|
||||
# Handle different input shapes
|
||||
if len(x.shape) == 2: # [batch_size, features]
|
||||
@@ -173,7 +181,50 @@ class SimpleCNN(nn.Module):
|
||||
# Extrema predictions
|
||||
extrema_pred = self.extrema_head(fc_out)
|
||||
|
||||
return action_values, extrema_pred
|
||||
# Price movement predictions for different timeframes
|
||||
price_immediate = self.price_pred_immediate(fc_out) # 1s, 1m
|
||||
price_midterm = self.price_pred_midterm(fc_out) # 1h
|
||||
price_longterm = self.price_pred_longterm(fc_out) # 1d
|
||||
|
||||
# Regression values for exact price predictions (percentage changes)
|
||||
price_values = self.price_pred_value(fc_out)
|
||||
|
||||
# Return all predictions in a structured dictionary
|
||||
price_predictions = {
|
||||
'immediate': price_immediate,
|
||||
'midterm': price_midterm,
|
||||
'longterm': price_longterm,
|
||||
'values': price_values
|
||||
}
|
||||
|
||||
return action_values, extrema_pred, price_predictions
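# A self-contained miniature (not the commit's model) of the multi-head pattern above:
# a shared trunk, one 3-class head per timeframe, and a regression head, returning the
# same (action values, extrema logits, price dict) structure. Sizes are illustrative.
import torch
import torch.nn as nn

class TinyMultiHead(nn.Module):
    def __init__(self, in_dim=32, hidden=64, n_actions=3):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU())
        self.action_head = nn.Linear(hidden, n_actions)
        self.extrema_head = nn.Linear(hidden, 3)   # bottom / top / neither
        self.direction_heads = nn.ModuleDict({
            'immediate': nn.Linear(hidden, 3),
            'midterm': nn.Linear(hidden, 3),
            'longterm': nn.Linear(hidden, 3),
        })
        self.values_head = nn.Linear(hidden, 4)    # % change per timeframe

    def forward(self, x):
        h = self.trunk(x)
        price = {name: head(h) for name, head in self.direction_heads.items()}
        price['values'] = self.values_head(h)
        return self.action_head(h), self.extrema_head(h), price

# Usage sketch:
# q_values, extrema_logits, price_preds = TinyMultiHead()(torch.randn(8, 32))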
|
||||
|
||||
def save(self, path):
|
||||
"""Save model weights and architecture"""
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
torch.save({
|
||||
'state_dict': self.state_dict(),
|
||||
'input_shape': self.input_shape,
|
||||
'n_actions': self.n_actions,
|
||||
'feature_dim': self.feature_dim
|
||||
}, f"{path}.pt")
|
||||
logger.info(f"Model saved to {path}.pt")
|
||||
|
||||
def load(self, path):
|
||||
"""Load model weights and architecture"""
|
||||
try:
|
||||
checkpoint = torch.load(f"{path}.pt", map_location=self.device)
|
||||
self.input_shape = checkpoint['input_shape']
|
||||
self.n_actions = checkpoint['n_actions']
|
||||
self.feature_dim = checkpoint['feature_dim']
|
||||
self._build_network()
|
||||
self.load_state_dict(checkpoint['state_dict'])
|
||||
self.to(self.device)
|
||||
logger.info(f"Model loaded from {path}.pt")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading model: {str(e)}")
|
||||
return False
|
||||
|
||||
class CNNModelPyTorch(nn.Module):
|
||||
"""
|
||||
|
NN/train_rl.py (552 changes)
@@ -9,6 +9,9 @@ import sys
|
||||
import pandas as pd
|
||||
import gym
|
||||
import json
|
||||
import random
|
||||
import torch.nn as nn
|
||||
import contextlib
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
@@ -49,19 +52,36 @@ class RLTradingEnvironment(gym.Env):
|
||||
Reinforcement Learning environment for trading with technical indicators
|
||||
from multiple timeframes
|
||||
"""
|
||||
def __init__(self, features_1m, features_5m, features_15m, window_size=20, trading_fee=0.0025, min_trade_interval=15):
|
||||
def __init__(self, features_1m, features_1h=None, features_1d=None, window_size=20, trading_fee=0.0025, min_trade_interval=15):
|
||||
super().__init__()
|
||||
|
||||
# Initialize attributes before parent class
|
||||
self.window_size = window_size
|
||||
self.num_features = features_1m.shape[1] - 1 # Exclude close price
|
||||
self.num_timeframes = 3 # 1m, 5m, 15m
|
||||
|
||||
# Count available timeframes
|
||||
self.num_timeframes = 1 # Always have 1m
|
||||
if features_1h is not None:
|
||||
self.num_timeframes += 1
|
||||
if features_1d is not None:
|
||||
self.num_timeframes += 1
|
||||
|
||||
self.feature_dim = self.num_features * self.num_timeframes
|
||||
|
||||
# Store features from different timeframes
|
||||
self.features_1m = features_1m
|
||||
self.features_5m = features_5m
|
||||
self.features_15m = features_15m
|
||||
self.features_1h = features_1h
|
||||
self.features_1d = features_1d
|
||||
|
||||
# Create synthetic 1s data from 1m (for demo purposes)
|
||||
self.features_1s = self._create_synthetic_1s_data(features_1m)
|
||||
|
||||
# If higher timeframes are missing, create synthetic data
|
||||
if self.features_1h is None:
|
||||
self.features_1h = self._create_synthetic_hourly_data(features_1m)
|
||||
|
||||
if self.features_1d is None:
|
||||
self.features_1d = self._create_synthetic_daily_data(features_1h)
|
||||
|
||||
# Trading parameters
|
||||
self.initial_balance = 1.0
|
||||
@@ -83,6 +103,45 @@ class RLTradingEnvironment(gym.Env):
|
||||
# Callback for visualization or external monitoring
|
||||
self.action_callback = None
|
||||
|
||||
def _create_synthetic_1s_data(self, features_1m):
|
||||
"""Create synthetic 1-second data from 1-minute data"""
|
||||
# Simple approach: duplicate each 1m candle for 60 seconds with some noise
|
||||
num_samples = features_1m.shape[0]
|
||||
synthetic_1s = np.zeros((num_samples * 60, features_1m.shape[1]))
|
||||
|
||||
for i in range(num_samples):
|
||||
for j in range(60):
|
||||
idx = i * 60 + j
|
||||
if idx < synthetic_1s.shape[0]:
|
||||
# Copy the 1m data with small random noise
|
||||
synthetic_1s[idx] = features_1m[i] * (1 + np.random.normal(0, 0.0001, features_1m.shape[1]))
|
||||
|
||||
return synthetic_1s
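# The nested loop above can be done in one vectorized step; a sketch using numpy.repeat
# with the same idea (repeat each 1m row 60 times and add small multiplicative noise).
# The function name is illustrative.
import numpy as np

def synthetic_1s_from_1m(features_1m: np.ndarray, noise_std: float = 0.0001) -> np.ndarray:
    repeated = np.repeat(features_1m, 60, axis=0)
    noise = np.random.normal(0.0, noise_std, repeated.shape)
    return repeated * (1.0 + noise)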
|
||||
|
||||
def _create_synthetic_hourly_data(self, features_1m):
|
||||
"""Create synthetic hourly data from minute data"""
|
||||
# Group by hour, taking every 60th candle
|
||||
num_samples = features_1m.shape[0] // 60
|
||||
synthetic_1h = np.zeros((num_samples, features_1m.shape[1]))
|
||||
|
||||
for i in range(num_samples):
|
||||
if i * 60 < features_1m.shape[0]:
|
||||
synthetic_1h[i] = features_1m[i * 60]
|
||||
|
||||
return synthetic_1h
|
||||
|
||||
def _create_synthetic_daily_data(self, features_1h):
|
||||
"""Create synthetic daily data from hourly data"""
|
||||
# Group by day, taking every 24th candle
|
||||
num_samples = features_1h.shape[0] // 24
|
||||
synthetic_1d = np.zeros((num_samples, features_1h.shape[1]))
|
||||
|
||||
for i in range(num_samples):
|
||||
if i * 24 < features_1h.shape[0]:
|
||||
synthetic_1d[i] = features_1h[i * 24]
|
||||
|
||||
return synthetic_1d
|
||||
|
||||
def reset(self):
|
||||
"""Reset the environment to initial state"""
|
||||
self.balance = self.initial_balance
|
||||
@@ -104,35 +163,43 @@ class RLTradingEnvironment(gym.Env):
|
||||
Combine features from multiple timeframes, reshaped for the CNN.
|
||||
"""
|
||||
# Calculate indices for each timeframe
|
||||
idx_1m = self.current_step
|
||||
idx_5m = idx_1m // 5
|
||||
idx_15m = idx_1m // 15
|
||||
idx_1m = min(self.current_step, self.features_1m.shape[0] - 1)
|
||||
idx_1h = idx_1m // 60 # 60 minutes in an hour
|
||||
idx_1d = idx_1h // 24 # 24 hours in a day
|
||||
|
||||
# Cap indices to prevent out of bounds
|
||||
idx_1h = min(idx_1h, self.features_1h.shape[0] - 1)
|
||||
idx_1d = min(idx_1d, self.features_1d.shape[0] - 1)
|
||||
|
||||
# Extract feature windows from each timeframe
|
||||
window_1m = self.features_1m[idx_1m - self.window_size:idx_1m]
|
||||
window_1m = self.features_1m[max(0, idx_1m - self.window_size):idx_1m]
|
||||
|
||||
# Handle 5m timeframe
|
||||
start_5m = max(0, idx_5m - self.window_size)
|
||||
window_5m = self.features_5m[start_5m:idx_5m]
|
||||
# Handle hourly timeframe
|
||||
start_1h = max(0, idx_1h - self.window_size)
|
||||
window_1h = self.features_1h[start_1h:idx_1h]
|
||||
|
||||
# Handle 15m timeframe
|
||||
start_15m = max(0, idx_15m - self.window_size)
|
||||
window_15m = self.features_15m[start_15m:idx_15m]
|
||||
# Handle daily timeframe
|
||||
start_1d = max(0, idx_1d - self.window_size)
|
||||
window_1d = self.features_1d[start_1d:idx_1d]
|
||||
|
||||
# Pad if needed (for 5m and 15m)
|
||||
if len(window_5m) < self.window_size:
|
||||
padding = np.zeros((self.window_size - len(window_5m), window_5m.shape[1]))
|
||||
window_5m = np.vstack([padding, window_5m])
|
||||
# Pad if needed (for higher timeframes)
|
||||
if len(window_1m) < self.window_size:
|
||||
padding = np.zeros((self.window_size - len(window_1m), window_1m.shape[1]))
|
||||
window_1m = np.vstack([padding, window_1m])
|
||||
|
||||
if len(window_15m) < self.window_size:
|
||||
padding = np.zeros((self.window_size - len(window_15m), window_15m.shape[1]))
|
||||
window_15m = np.vstack([padding, window_15m])
|
||||
if len(window_1h) < self.window_size:
|
||||
padding = np.zeros((self.window_size - len(window_1h), window_1h.shape[1]))
|
||||
window_1h = np.vstack([padding, window_1h])
|
||||
|
||||
if len(window_1d) < self.window_size:
|
||||
padding = np.zeros((self.window_size - len(window_1d), window_1d.shape[1]))
|
||||
window_1d = np.vstack([padding, window_1d])
|
||||
|
||||
# Combine features from all timeframes
|
||||
combined_features = np.hstack([
|
||||
window_1m.reshape(self.window_size, -1),
|
||||
window_5m.reshape(self.window_size, -1),
|
||||
window_15m.reshape(self.window_size, -1)
|
||||
window_1h.reshape(self.window_size, -1),
|
||||
window_1d.reshape(self.window_size, -1)
|
||||
])
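# The per-timeframe "clamp index, slice window, left-pad with zeros" steps above,
# collected into one reusable helper; a sketch with illustrative names.
import numpy as np

def padded_window(features: np.ndarray, idx: int, window_size: int) -> np.ndarray:
    idx = min(idx, features.shape[0] - 1)            # cap to prevent out-of-bounds
    window = features[max(0, idx - window_size):idx]
    if len(window) < window_size:                    # left-pad with zeros if short
        pad = np.zeros((window_size - len(window), features.shape[1]))
        window = np.vstack([pad, window])
    return window

# Usage sketch (indices derived as above: hourly = minutes // 60, daily = hours // 24):
# window_1h = padded_window(self.features_1h, idx_1m // 60, self.window_size)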
|
||||
|
||||
# Convert to float32 and handle any NaN values
|
||||
@@ -152,7 +219,14 @@ class RLTradingEnvironment(gym.Env):
|
||||
"""
|
||||
# Get current and next price
|
||||
current_price = self.features_1m[self.current_step, -1] # Close price is last column
|
||||
next_price = self.features_1m[self.current_step + 1, -1]
|
||||
|
||||
# Check if we're at the end of the data
|
||||
if self.current_step + 1 >= len(self.features_1m):
|
||||
next_price = current_price # Use current price if at the end
|
||||
done = True
|
||||
else:
|
||||
next_price = self.features_1m[self.current_step + 1, -1]
|
||||
done = False
|
||||
|
||||
# Handle zero or negative prices
|
||||
if current_price <= 0:
|
||||
@@ -164,7 +238,6 @@ class RLTradingEnvironment(gym.Env):
|
||||
|
||||
# Default reward is slightly negative to discourage inaction
|
||||
reward = -0.0001
|
||||
done = False
|
||||
profit_pct = None # Initialize profit_pct variable
|
||||
|
||||
# Check if enough time has passed since last trade
|
||||
@@ -225,7 +298,7 @@ class RLTradingEnvironment(gym.Env):
|
||||
# Move to next step
|
||||
self.current_step += 1
|
||||
|
||||
# Check if done
|
||||
# Check if done (reached end of data)
|
||||
if self.current_step >= len(self.features_1m) - 1:
|
||||
done = True
|
||||
|
||||
@@ -251,6 +324,26 @@ class RLTradingEnvironment(gym.Env):
|
||||
gain = (total_value - self.initial_balance) / self.initial_balance
|
||||
self.win_rate = self.wins / max(1, self.trades)
|
||||
|
||||
# Check if we have prediction data for future timeframes
|
||||
future_price_1h = None
|
||||
future_price_1d = None
|
||||
|
||||
# Get hourly index
|
||||
idx_1h = self.current_step // 60
|
||||
if idx_1h + 1 < len(self.features_1h):
|
||||
hourly_close_idx = self.features_1h.shape[1] - 1 # Assuming close is last column
|
||||
current_1h_price = self.features_1h[idx_1h, hourly_close_idx]
|
||||
next_1h_price = self.features_1h[idx_1h + 1, hourly_close_idx]
|
||||
future_price_1h = (next_1h_price - current_1h_price) / current_1h_price
|
||||
|
||||
# Get daily index
|
||||
idx_1d = idx_1h // 24
|
||||
if idx_1d + 1 < len(self.features_1d):
|
||||
daily_close_idx = self.features_1d.shape[1] - 1 # Assuming close is last column
|
||||
current_1d_price = self.features_1d[idx_1d, daily_close_idx]
|
||||
next_1d_price = self.features_1d[idx_1d + 1, daily_close_idx]
|
||||
future_price_1d = (next_1d_price - current_1d_price) / current_1d_price
|
||||
|
||||
info = {
|
||||
'balance': self.balance,
|
||||
'position': self.position,
|
||||
@@ -260,7 +353,9 @@ class RLTradingEnvironment(gym.Env):
|
||||
'win_rate': self.win_rate,
|
||||
'profit_pct': profit_pct if action == 1 and self.position == 0 else None,
|
||||
'current_price': current_price,
|
||||
'next_price': next_price
|
||||
'next_price': next_price,
|
||||
'future_price_1h': future_price_1h, # Actual future hourly price change
|
||||
'future_price_1d': future_price_1d # Actual future daily price change
|
||||
}
|
||||
|
||||
# Call the callback if it exists
|
||||
@@ -279,7 +374,8 @@ class RLTradingEnvironment(gym.Env):
|
||||
self.action_callback = callback
|
||||
|
||||
def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/models/saved/dqn_agent",
|
||||
action_callback=None, episode_callback=None, symbol="BTC/USDT"):
|
||||
action_callback=None, episode_callback=None, symbol="BTC/USDT",
|
||||
pretrain_price_prediction_enabled=True, pretrain_epochs=10):
|
||||
"""
|
||||
Train a reinforcement learning agent for trading
|
||||
|
||||
@@ -291,6 +387,8 @@ def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/mo
|
||||
action_callback: Callback function for monitoring actions
|
||||
episode_callback: Callback function for monitoring episodes
|
||||
symbol: Trading symbol to use
|
||||
pretrain_price_prediction_enabled: Whether to pre-train price prediction
|
||||
pretrain_epochs: Number of epochs for pre-training
|
||||
|
||||
Returns:
|
||||
tuple: (trained agent, environment)
|
||||
@@ -396,6 +494,30 @@ def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/mo
|
||||
else:
|
||||
logger.info("No existing model found. Starting with a new model.")
|
||||
|
||||
# Pre-train price prediction if enabled and we have a new model
|
||||
if pretrain_price_prediction_enabled:
|
||||
if not os.path.exists(model_file) or input("Pre-train price prediction? (y/n): ").lower() == 'y':
|
||||
logger.info("Pre-training price prediction capability...")
|
||||
# Attempt to load hourly and daily data for pre-training
|
||||
try:
|
||||
data_interface.add_timeframe('1h')
|
||||
data_interface.add_timeframe('1d')
|
||||
|
||||
# Run pre-training
|
||||
agent = pretrain_price_prediction(
|
||||
agent=agent,
|
||||
data_interface=data_interface,
|
||||
n_epochs=pretrain_epochs,
|
||||
batch_size=128
|
||||
)
|
||||
|
||||
# Save the pre-trained model
|
||||
agent.save(f"{save_path}_pretrained")
|
||||
logger.info("Pre-trained model saved.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during pre-training: {e}")
|
||||
logger.warning("Continuing with RL training without pre-training.")
|
||||
|
||||
# Create TensorBoard writer
|
||||
writer = SummaryWriter(log_dir=f'runs/dqn_{int(time.time())}')
|
||||
|
||||
@@ -499,5 +621,379 @@ def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/mo
|
||||
|
||||
return agent, env
|
||||
|
||||
def generate_price_prediction_training_data(data_1m, data_1h, data_1d, window_size=20):
|
||||
"""
|
||||
Generate labeled training data for price prediction at different timeframes
|
||||
|
||||
Args:
|
||||
data_1m: DataFrame with 1-minute data
|
||||
data_1h: DataFrame with 1-hour data
|
||||
data_1d: DataFrame with 1-day data
|
||||
window_size: Size of the input window
|
||||
|
||||
Returns:
|
||||
tuple: (X, y_immediate, y_midterm, y_longterm, y_values)
|
||||
- X: input features (window sequences)
|
||||
- y_immediate: immediate direction labels (0=down, 1=sideways, 2=up)
|
||||
- y_midterm: mid-term direction labels
|
||||
- y_longterm: long-term direction labels
|
||||
- y_values: actual percentage changes for each timeframe
|
||||
"""
|
||||
logger.info("Generating price prediction training data from historical prices")
|
||||
|
||||
# Prepare data structures
|
||||
X = []
|
||||
y_immediate = [] # 1m
|
||||
y_midterm = [] # 1h
|
||||
y_longterm = [] # 1d
|
||||
y_values = [] # Actual percentage changes
|
||||
|
||||
# Calculate future returns for labeling
|
||||
data_1m['future_return_1m'] = data_1m['close'].pct_change(1).shift(-1) # Next candle
|
||||
data_1m['future_return_10m'] = data_1m['close'].pct_change(10).shift(-10) # Next 10 candles
|
||||
|
||||
# Add indices to align data
|
||||
data_1m['index'] = range(len(data_1m))
|
||||
data_1h['index'] = range(len(data_1h))
|
||||
data_1d['index'] = range(len(data_1d))
|
||||
|
||||
# Define thresholds for direction labels
|
||||
immediate_threshold = 0.0005
|
||||
midterm_threshold = 0.001
|
||||
longterm_threshold = 0.002
|
||||
|
||||
# Loop through 1m data to create training samples
|
||||
max_idx = len(data_1m) - window_size - 10 # Ensure we have future data for labels
|
||||
sample_indices = random.sample(range(window_size, max_idx), min(10000, max_idx - window_size))
|
||||
|
||||
for idx in sample_indices:
|
||||
# Get window of 1m data
|
||||
window_1m = data_1m.iloc[idx-window_size:idx].drop(['timestamp', 'future_return_1m', 'future_return_10m', 'index'], axis=1, errors='ignore')
|
||||
|
||||
# Skip if window contains NaN
|
||||
if window_1m.isnull().values.any():
|
||||
continue
|
||||
|
||||
# Get future returns for labeling
|
||||
future_return_1m = data_1m.iloc[idx]['future_return_1m']
|
||||
future_return_10m = data_1m.iloc[idx]['future_return_10m']
|
||||
|
||||
# Find corresponding row in 1h data (closest timestamp)
|
||||
current_timestamp = data_1m.iloc[idx]['timestamp']
|
||||
|
||||
# Find 1h candle for mid-term prediction
|
||||
if 'timestamp' in data_1h.columns:
|
||||
# Find closest 1h candle
|
||||
closest_1h_idx = data_1h['timestamp'].searchsorted(current_timestamp)
|
||||
if closest_1h_idx >= len(data_1h):
|
||||
closest_1h_idx = len(data_1h) - 1
|
||||
|
||||
# Get future 1h return (next candle)
|
||||
if closest_1h_idx < len(data_1h) - 1:
|
||||
future_return_1h = (data_1h.iloc[closest_1h_idx + 1]['close'] - data_1h.iloc[closest_1h_idx]['close']) / data_1h.iloc[closest_1h_idx]['close']
|
||||
else:
|
||||
future_return_1h = 0
|
||||
else:
|
||||
future_return_1h = future_return_10m # Fallback
|
||||
|
||||
# Find 1d candle for long-term prediction
|
||||
if 'timestamp' in data_1d.columns:
|
||||
# Find closest 1d candle
|
||||
closest_1d_idx = data_1d['timestamp'].searchsorted(current_timestamp)
|
||||
if closest_1d_idx >= len(data_1d):
|
||||
closest_1d_idx = len(data_1d) - 1
|
||||
|
||||
# Get future 1d return (next candle)
|
||||
if closest_1d_idx < len(data_1d) - 1:
|
||||
future_return_1d = (data_1d.iloc[closest_1d_idx + 1]['close'] - data_1d.iloc[closest_1d_idx]['close']) / data_1d.iloc[closest_1d_idx]['close']
|
||||
else:
|
||||
future_return_1d = 0
|
||||
else:
|
||||
future_return_1d = future_return_1h * 2 # Fallback
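# A condensed sketch of the alignment logic above: map a 1m timestamp to the closest
# higher-timeframe candle with searchsorted and read the next candle's return. Assumes
# a DataFrame with sorted 'timestamp' and 'close' columns, as in the surrounding code.
import pandas as pd

def next_candle_return(df: pd.DataFrame, timestamp) -> float:
    idx = min(int(df['timestamp'].searchsorted(timestamp)), len(df) - 1)
    if idx >= len(df) - 1:
        return 0.0
    current_close = df.iloc[idx]['close']
    return (df.iloc[idx + 1]['close'] - current_close) / current_close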
|
||||
|
||||
# Create direction labels
|
||||
# 0=down, 1=sideways, 2=up
|
||||
|
||||
# Immediate (1m)
|
||||
if future_return_1m > immediate_threshold:
|
||||
immediate_label = 2 # UP
|
||||
elif future_return_1m < -immediate_threshold:
|
||||
immediate_label = 0 # DOWN
|
||||
else:
|
||||
immediate_label = 1 # SIDEWAYS
|
||||
|
||||
# Mid-term (1h)
|
||||
if future_return_1h > midterm_threshold:
|
||||
midterm_label = 2 # UP
|
||||
elif future_return_1h < -midterm_threshold:
|
||||
midterm_label = 0 # DOWN
|
||||
else:
|
||||
midterm_label = 1 # SIDEWAYS
|
||||
|
||||
# Long-term (1d)
|
||||
if future_return_1d > longterm_threshold:
|
||||
longterm_label = 2 # UP
|
||||
elif future_return_1d < -longterm_threshold:
|
||||
longterm_label = 0 # DOWN
|
||||
else:
|
||||
longterm_label = 1 # SIDEWAYS
|
||||
|
||||
# Store data
|
||||
X.append(window_1m.values)
|
||||
y_immediate.append(immediate_label)
|
||||
y_midterm.append(midterm_label)
|
||||
y_longterm.append(longterm_label)
|
||||
y_values.append([future_return_1m, future_return_1h, future_return_1d, future_return_1d * 1.5]) # Add weekly estimate
|
||||
|
||||
# Convert to numpy arrays
|
||||
X = np.array(X)
|
||||
y_immediate = np.array(y_immediate)
|
||||
y_midterm = np.array(y_midterm)
|
||||
y_longterm = np.array(y_longterm)
|
||||
y_values = np.array(y_values)
|
||||
|
||||
logger.info(f"Generated {len(X)} price prediction training samples")
|
||||
|
||||
# Log class distribution
|
||||
for name, y in [("Immediate", y_immediate), ("Mid-term", y_midterm), ("Long-term", y_longterm)]:
|
||||
down = (y == 0).sum()
|
||||
sideways = (y == 1).sum()
|
||||
up = (y == 2).sum()
|
||||
logger.info(f"{name} direction distribution: DOWN={down} ({down/len(y)*100:.1f}%), "
|
||||
f"SIDEWAYS={sideways} ({sideways/len(y)*100:.1f}%), "
|
||||
f"UP={up} ({up/len(y)*100:.1f}%)")
|
||||
|
||||
return X, y_immediate, y_midterm, y_longterm, y_values
|
||||
|
||||
def pretrain_price_prediction(agent, data_interface, n_epochs=10, batch_size=128):
|
||||
"""
|
||||
Pre-train the agent's price prediction capability on historical data
|
||||
|
||||
Args:
|
||||
agent: DQNAgent instance to train
|
||||
data_interface: DataInterface instance for accessing data
|
||||
n_epochs: Number of epochs for training
|
||||
batch_size: Batch size for training
|
||||
|
||||
Returns:
|
||||
The agent with pre-trained price prediction capabilities
|
||||
"""
|
||||
logger.info("Starting supervised pre-training of price prediction")
|
||||
|
||||
try:
|
||||
# Load data for all required timeframes
|
||||
data_1m = data_interface.get_historical_data(timeframe='1m', n_candles=10000)
|
||||
data_1h = data_interface.get_historical_data(timeframe='1h', n_candles=1000)
|
||||
data_1d = data_interface.get_historical_data(timeframe='1d', n_candles=500)
|
||||
|
||||
# Check if data is available
|
||||
if data_1m is None:
|
||||
logger.warning("1m data not available for pre-training")
|
||||
return agent
|
||||
|
||||
if data_1h is None:
|
||||
logger.warning("1h data not available, using synthesized data")
|
||||
# Create synthetic 1h data from 1m data
|
||||
data_1h = data_1m.iloc[::60].reset_index(drop=True).copy() # Take every 60th record
|
||||
|
||||
if data_1d is None:
|
||||
logger.warning("1d data not available, using synthesized data")
|
||||
# Create synthetic 1d data from 1h data
|
||||
data_1d = data_1h.iloc[::24].reset_index(drop=True).copy() # Take every 24th record
|
||||
|
||||
# Add technical indicators to all data
|
||||
data_1m = data_interface.add_technical_indicators(data_1m)
|
||||
data_1h = data_interface.add_technical_indicators(data_1h)
|
||||
data_1d = data_interface.add_technical_indicators(data_1d)
|
||||
|
||||
# Generate labeled training data
|
||||
X, y_immediate, y_midterm, y_longterm, y_values = generate_price_prediction_training_data(
|
||||
data_1m, data_1h, data_1d, window_size=20
|
||||
)
|
||||
|
||||
# Split data into training and validation sets
|
||||
from sklearn.model_selection import train_test_split
|
||||
X_train, X_val, y_imm_train, y_imm_val, y_mid_train, y_mid_val, y_long_train, y_long_val, y_val_train, y_val_val = train_test_split(
|
||||
X, y_immediate, y_midterm, y_longterm, y_values, test_size=0.2, random_state=42
|
||||
)
|
||||
|
||||
# Convert to torch tensors
|
||||
X_train_tensor = torch.FloatTensor(X_train).to(agent.device)
|
||||
y_imm_train_tensor = torch.LongTensor(y_imm_train).to(agent.device)
|
||||
y_mid_train_tensor = torch.LongTensor(y_mid_train).to(agent.device)
|
||||
y_long_train_tensor = torch.LongTensor(y_long_train).to(agent.device)
|
||||
y_val_train_tensor = torch.FloatTensor(y_val_train).to(agent.device)
|
||||
|
||||
X_val_tensor = torch.FloatTensor(X_val).to(agent.device)
|
||||
y_imm_val_tensor = torch.LongTensor(y_imm_val).to(agent.device)
|
||||
y_mid_val_tensor = torch.LongTensor(y_mid_val).to(agent.device)
|
||||
y_long_val_tensor = torch.LongTensor(y_long_val).to(agent.device)
|
||||
y_val_val_tensor = torch.FloatTensor(y_val_val).to(agent.device)
|
||||
|
||||
# Calculate class weights for imbalanced data
|
||||
from torch.nn.functional import one_hot
|
||||
|
||||
# Function to calculate class weights
|
||||
def get_class_weights(labels):
|
||||
counts = np.bincount(labels)
|
||||
if len(counts) < 3: # Ensure we have 3 classes
|
||||
counts = np.append(counts, [0] * (3 - len(counts)))
|
||||
weights = 1.0 / np.array(counts)
|
||||
weights = weights / np.sum(weights) # Normalize
|
||||
return weights
|
||||
|
||||
imm_weights = torch.FloatTensor(get_class_weights(y_imm_train)).to(agent.device)
|
||||
mid_weights = torch.FloatTensor(get_class_weights(y_mid_train)).to(agent.device)
|
||||
long_weights = torch.FloatTensor(get_class_weights(y_long_train)).to(agent.device)
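# An alternative sketch of the class-weight computation above that also guards against
# classes with zero samples (a plain 1.0/counts would divide by zero there); the helper
# name is illustrative.
import numpy as np

def inverse_frequency_weights(labels: np.ndarray, n_classes: int = 3) -> np.ndarray:
    counts = np.bincount(labels, minlength=n_classes).astype(np.float64)
    counts = np.maximum(counts, 1.0)   # avoid division by zero for missing classes
    weights = 1.0 / counts
    return weights / weights.sum()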
|
||||
|
||||
# Create DataLoader for batch training
|
||||
from torch.utils.data import TensorDataset, DataLoader
|
||||
|
||||
train_dataset = TensorDataset(
|
||||
X_train_tensor, y_imm_train_tensor, y_mid_train_tensor,
|
||||
y_long_train_tensor, y_val_train_tensor
|
||||
)
|
||||
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
||||
|
||||
# Set up loss functions with class weights
|
||||
imm_criterion = nn.CrossEntropyLoss(weight=imm_weights)
|
||||
mid_criterion = nn.CrossEntropyLoss(weight=mid_weights)
|
||||
long_criterion = nn.CrossEntropyLoss(weight=long_weights)
|
||||
value_criterion = nn.MSELoss()
|
||||
|
||||
# Set up optimizer (separate from agent's optimizer)
|
||||
pretrain_optimizer = torch.optim.Adam(agent.policy_net.parameters(), lr=0.0002)
|
||||
pretrain_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||
pretrain_optimizer, mode='min', factor=0.5, patience=3, verbose=True
|
||||
)
|
||||
|
||||
# Set model to training mode
|
||||
agent.policy_net.train()
|
||||
|
||||
# Training loop
|
||||
best_val_loss = float('inf')
|
||||
patience = 5
|
||||
patience_counter = 0
|
||||
|
||||
for epoch in range(n_epochs):
|
||||
# Training phase
|
||||
train_loss = 0.0
|
||||
imm_correct, mid_correct, long_correct = 0, 0, 0
|
||||
total = 0
|
||||
|
||||
for X_batch, y_imm_batch, y_mid_batch, y_long_batch, y_val_batch in train_loader:
|
||||
# Zero gradients
|
||||
pretrain_optimizer.zero_grad()
|
||||
|
||||
# Forward pass - we only need the price predictions
|
||||
with torch.cuda.amp.autocast() if agent.use_mixed_precision else contextlib.nullcontext():
|
||||
_, _, price_preds = agent.policy_net(X_batch)
|
||||
|
||||
# Calculate losses for each prediction head
|
||||
imm_loss = imm_criterion(price_preds['immediate'], y_imm_batch)
|
||||
mid_loss = mid_criterion(price_preds['midterm'], y_mid_batch)
|
||||
long_loss = long_criterion(price_preds['longterm'], y_long_batch)
|
||||
value_loss = value_criterion(price_preds['values'], y_val_batch)
|
||||
|
||||
# Combined loss (weighted by importance)
|
||||
total_loss = imm_loss + 0.7 * mid_loss + 0.5 * long_loss + 0.3 * value_loss
|
||||
|
||||
# Backward pass and optimize
|
||||
if agent.use_mixed_precision:
|
||||
agent.scaler.scale(total_loss).backward()
|
||||
agent.scaler.unscale_(pretrain_optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(agent.policy_net.parameters(), 1.0)
|
||||
agent.scaler.step(pretrain_optimizer)
|
||||
agent.scaler.update()
|
||||
else:
|
||||
total_loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(agent.policy_net.parameters(), 1.0)
|
||||
pretrain_optimizer.step()
|
||||
|
||||
# Accumulate metrics
|
||||
train_loss += total_loss.item()
|
||||
total += X_batch.size(0)
|
||||
|
||||
# Calculate accuracy
|
||||
_, imm_pred = torch.max(price_preds['immediate'], 1)
|
||||
_, mid_pred = torch.max(price_preds['midterm'], 1)
|
||||
_, long_pred = torch.max(price_preds['longterm'], 1)
|
||||
|
||||
imm_correct += (imm_pred == y_imm_batch).sum().item()
|
||||
mid_correct += (mid_pred == y_mid_batch).sum().item()
|
||||
long_correct += (long_pred == y_long_batch).sum().item()
|
||||
|
||||
# Calculate epoch metrics
|
||||
train_loss /= len(train_loader)
|
||||
imm_acc = imm_correct / total
|
||||
mid_acc = mid_correct / total
|
||||
long_acc = long_correct / total
|
||||
|
||||
# Validation phase
|
||||
agent.policy_net.eval()
|
||||
val_loss = 0.0
|
||||
imm_val_correct, mid_val_correct, long_val_correct = 0, 0, 0
|
||||
|
||||
with torch.no_grad():
|
||||
# Forward pass on validation data
|
||||
_, _, val_price_preds = agent.policy_net(X_val_tensor)
|
||||
|
||||
# Calculate validation losses
|
||||
val_imm_loss = imm_criterion(val_price_preds['immediate'], y_imm_val_tensor)
|
||||
val_mid_loss = mid_criterion(val_price_preds['midterm'], y_mid_val_tensor)
|
||||
val_long_loss = long_criterion(val_price_preds['longterm'], y_long_val_tensor)
|
||||
val_value_loss = value_criterion(val_price_preds['values'], y_val_val_tensor)
|
||||
|
||||
val_total_loss = val_imm_loss + 0.7 * val_mid_loss + 0.5 * val_long_loss + 0.3 * val_value_loss
|
||||
val_loss = val_total_loss.item()
|
||||
|
||||
# Calculate validation accuracy
|
||||
_, imm_val_pred = torch.max(val_price_preds['immediate'], 1)
|
||||
_, mid_val_pred = torch.max(val_price_preds['midterm'], 1)
|
||||
_, long_val_pred = torch.max(val_price_preds['longterm'], 1)
|
||||
|
||||
imm_val_correct = (imm_val_pred == y_imm_val_tensor).sum().item()
|
||||
mid_val_correct = (mid_val_pred == y_mid_val_tensor).sum().item()
|
||||
long_val_correct = (long_val_pred == y_long_val_tensor).sum().item()
|
||||
|
||||
imm_val_acc = imm_val_correct / len(X_val_tensor)
|
||||
mid_val_acc = mid_val_correct / len(X_val_tensor)
|
||||
long_val_acc = long_val_correct / len(X_val_tensor)
|
||||
|
||||
# Learning rate scheduling
|
||||
pretrain_scheduler.step(val_loss)
|
||||
|
||||
# Early stopping check
|
||||
if val_loss < best_val_loss:
|
||||
best_val_loss = val_loss
|
||||
patience_counter = 0
|
||||
# Copy policy_net weights to target_net
|
||||
agent.target_net.load_state_dict(agent.policy_net.state_dict())
|
||||
logger.info(f"Saved best model with validation loss: {val_loss:.4f}")
|
||||
else:
|
||||
patience_counter += 1
|
||||
if patience_counter >= patience:
|
||||
logger.info(f"Early stopping triggered after {epoch+1} epochs")
|
||||
break
|
||||
|
||||
# Log progress
|
||||
logger.info(f"Epoch {epoch+1}/{n_epochs}: "
|
||||
f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
|
||||
f"Imm Acc: {imm_acc:.4f}/{imm_val_acc:.4f}, "
|
||||
f"Mid Acc: {mid_acc:.4f}/{mid_val_acc:.4f}, "
|
||||
f"Long Acc: {long_acc:.4f}/{long_val_acc:.4f}")
|
||||
|
||||
# Set model back to training mode for next epoch
|
||||
agent.policy_net.train()
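# The best-loss / patience bookkeeping used in the epoch loop above, as a tiny reusable
# class; a sketch only, not part of this commit.
class EarlyStopping:
    def __init__(self, patience: int = 5):
        self.patience = patience
        self.best = float('inf')
        self.counter = 0

    def step(self, val_loss: float) -> bool:
        """Return True when training should stop."""
        if val_loss < self.best:
            self.best = val_loss
            self.counter = 0
            return False
        self.counter += 1
        return self.counter >= self.patience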
|
||||
|
||||
logger.info("Price prediction pre-training complete")
|
||||
return agent
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during price prediction pre-training: {str(e)}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return agent
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_rl()
|