added more predictions

This commit is contained in:
Dobromir Popov
2025-04-02 14:20:39 +03:00
parent 70eb7bba9b
commit 7dda00b64a
3 changed files with 967 additions and 143 deletions

View File

@ -99,6 +99,28 @@ class DQNAgent:
}
self.extrema_memory = [] # Special memory for storing extrema points
# Price prediction tracking
self.last_price_pred = {
'immediate': {
'direction': 1, # Default to "sideways"
'confidence': 0.0,
'change': 0.0
},
'midterm': {
'direction': 1, # Default to "sideways"
'confidence': 0.0,
'change': 0.0
},
'longterm': {
'direction': 1, # Default to "sideways"
'confidence': 0.0,
'change': 0.0
}
}
# Store separate memory for price direction examples
self.price_movement_memory = [] # For storing examples of clear price movements
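A minimal sketch of how downstream code might read this tracking structure once act() has populated it; the agent variable, the 0.7 threshold, and the printout are illustrative, not part of this commit:
pred = agent.last_price_pred['immediate']  # hypothetical consumer code
if pred['confidence'] > 0.7 and pred['direction'] == 2:  # 2 = up
    print(f"expecting an immediate up move of {pred['change']:.2f}%")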
# Performance tracking
self.losses = []
self.avg_reward = 0.0
@ -157,6 +179,36 @@ class DQNAgent:
# Always add to main memory
self.memory.append(experience)
# Try to extract price change to analyze the experience
try:
# Extract price feature from sequence data (if available)
if len(state.shape) > 1: # 2D state [timeframes, features]
current_price = state[-1, -1] # Last timeframe, last feature
next_price = next_state[-1, -1]
else: # 1D state
current_price = state[-1] # Last feature
next_price = next_state[-1]
# Calculate price change
price_change = (next_price - current_price) / current_price
# Check if this is a significant price movement
if abs(price_change) > 0.002: # Significant price change
# Store in price movement memory
self.price_movement_memory.append(experience)
# Log significant price movements
direction = "UP" if price_change > 0 else "DOWN"
logger.info(f"Stored significant {direction} price movement: {price_change:.4f}")
# For clear price movements, also duplicate in main memory to learn more
if abs(price_change) > 0.005: # Very significant movement
for _ in range(2): # Add 2 extra copies
self.memory.append(experience)
except Exception as e:
# Skip price movement analysis if it fails
pass
# Check if this is an extrema point based on our extrema detection head
if hasattr(self, 'last_extrema_pred') and self.last_extrema_pred['class'] != 2:
# Class 0 = bottom, 1 = top, 2 = neither
@ -196,6 +248,9 @@ class DQNAgent:
if len(self.extrema_memory) > self.buffer_size // 4:
self.extrema_memory = self.extrema_memory[-(self.buffer_size // 4):]
if len(self.price_movement_memory) > self.buffer_size // 4:
self.price_movement_memory = self.price_movement_memory[-(self.buffer_size // 4):]
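For reference, a standalone sketch of the significance check used in the experience-storage logic above; the array values and shapes are made up, only the thresholds and the last-timeframe, last-feature convention come from the diff:
import numpy as np

state = np.array([[101.0, 0.5, 100.0], [102.0, 0.4, 100.5]])      # [timeframes, features], close price last
next_state = np.array([[102.0, 0.4, 100.5], [103.0, 0.6, 101.2]])
current_price = state[-1, -1]
next_price = next_state[-1, -1]
price_change = (next_price - current_price) / current_price       # about +0.7%
print(abs(price_change) > 0.002)  # True -> stored in price_movement_memory
print(abs(price_change) > 0.005)  # True -> two extra copies appended to main memory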
def act(self, state: np.ndarray, explore=True) -> int:
"""Choose action using epsilon-greedy policy with explore flag"""
@ -209,7 +264,7 @@ class DQNAgent:
# Get predictions using the policy network
self.policy_net.eval() # Set to evaluation mode for inference
action_probs, extrema_pred, price_predictions = self.policy_net(state_tensor)
self.policy_net.train() # Back to training mode
# Get the predicted extrema class (0=bottom, 1=top, 2=neither)
@ -221,17 +276,81 @@ class DQNAgent:
extrema_type = "BOTTOM" if extrema_class == 0 else "TOP" if extrema_class == 1 else "NEITHER"
logger.info(f"High confidence {extrema_type} detected! Confidence: {extrema_confidence:.4f}")
# Process price predictions
price_immediate = torch.softmax(price_predictions['immediate'], dim=1)
price_midterm = torch.softmax(price_predictions['midterm'], dim=1)
price_longterm = torch.softmax(price_predictions['longterm'], dim=1)
price_values = price_predictions['values']
# Get predicted direction for each timeframe (0=down, 1=sideways, 2=up)
immediate_direction = price_immediate.argmax(dim=1).item()
midterm_direction = price_midterm.argmax(dim=1).item()
longterm_direction = price_longterm.argmax(dim=1).item()
# Get confidence levels
immediate_conf = price_immediate[0, immediate_direction].item()
midterm_conf = price_midterm[0, midterm_direction].item()
longterm_conf = price_longterm[0, longterm_direction].item()
# Get predicted price change percentages
price_changes = price_values[0].tolist()
# Log significant price movement predictions
timeframes = ["1s/1m", "1h", "1d", "1w"]
directions = ["DOWN", "SIDEWAYS", "UP"]
for i, (direction, conf) in enumerate([
(immediate_direction, immediate_conf),
(midterm_direction, midterm_conf),
(longterm_direction, longterm_conf)
]):
if conf > 0.7 and direction != 1: # Only log high confidence non-sideways predictions
logger.info(f"Price prediction: {timeframes[i]} -> {directions[direction]}, "
f"Confidence: {conf:.4f}, Expected change: {price_changes[i]:.2f}%")
# Store predictions for environment to use
self.last_extrema_pred = {
'class': extrema_class,
'confidence': extrema_confidence,
'raw': extrema_pred.cpu().numpy()
}
self.last_price_pred = {
'immediate': {
'direction': immediate_direction,
'confidence': immediate_conf,
'change': price_changes[0]
},
'midterm': {
'direction': midterm_direction,
'confidence': midterm_conf,
'change': price_changes[1]
},
'longterm': {
'direction': longterm_direction,
'confidence': longterm_conf,
'change': price_changes[2]
}
}
# Get the action with highest Q-value
action = action_probs.argmax().item()
# Adjust action based on extrema and price predictions
# Prioritize short-term movement for trading decisions
if immediate_conf > 0.8: # Only adjust for strong signals
if immediate_direction == 2: # UP prediction
# Bias toward BUY for strong up predictions
if action != 0 and random.random() < 0.3 * immediate_conf:
logger.info(f"Adjusting action to BUY based on immediate UP prediction")
action = 0 # BUY
elif immediate_direction == 0: # DOWN prediction
# Bias toward SELL for strong down predictions
if action != 1 and random.random() < 0.3 * immediate_conf:
logger.info(f"Adjusting action to SELL based on immediate DOWN prediction")
action = 1 # SELL
# Also consider extrema detection for action adjustment
if extrema_confidence > 0.8: # Only adjust for strong signals
if extrema_class == 0: # Bottom detected
# Bias toward BUY at bottoms
@ -307,53 +426,102 @@ class DQNAgent:
next_states_tensor, dones_tensor
)
# Training focus selector - randomly focus on one of the specialized training types
training_focus = random.random()
# Occasionally train specifically on extrema points
if training_focus < 0.3 and hasattr(self, 'extrema_memory') and len(self.extrema_memory) >= self.batch_size // 2:
# Sample from extrema memory
extrema_batch_size = min(self.batch_size // 2, len(self.extrema_memory))
extrema_batch = random.sample(self.extrema_memory, extrema_batch_size)
# Extract batches with proper tensor conversion
extrema_states = np.vstack([self._normalize_state(x[0]) for x in extrema_batch])
extrema_actions = np.array([x[1] for x in extrema_batch])
extrema_rewards = np.array([x[2] for x in extrema_batch])
extrema_next_states = np.vstack([self._normalize_state(x[3]) for x in extrema_batch])
extrema_dones = np.array([x[4] for x in extrema_batch], dtype=np.float32)
# Convert to torch tensors and move to device
extrema_states_tensor = torch.FloatTensor(extrema_states).to(self.device)
extrema_actions_tensor = torch.LongTensor(extrema_actions).to(self.device)
extrema_rewards_tensor = torch.FloatTensor(extrema_rewards).to(self.device)
extrema_next_states_tensor = torch.FloatTensor(extrema_next_states).to(self.device)
extrema_dones_tensor = torch.FloatTensor(extrema_dones).to(self.device)
# Additional training step focused on extrema points (with smaller learning rate)
original_lr = self.optimizer.param_groups[0]['lr']
# Temporarily reduce learning rate for fine-tuning on extrema
for param_group in self.optimizer.param_groups:
param_group['lr'] = original_lr * 0.5
# Train on extrema
if self.use_mixed_precision:
extrema_loss = self._replay_mixed_precision(
extrema_states_tensor, extrema_actions_tensor, extrema_rewards_tensor,
extrema_next_states_tensor, extrema_dones_tensor
)
else:
extrema_loss = self._replay_standard(
extrema_states_tensor, extrema_actions_tensor, extrema_rewards_tensor,
extrema_next_states_tensor, extrema_dones_tensor
)
# Restore original learning rate
for param_group in self.optimizer.param_groups:
param_group['lr'] = original_lr
logger.info(f"Extra training on extrema points: loss={extrema_loss:.4f}")
# Average the loss
loss = (loss + extrema_loss) / 2
# Occasionally train specifically on price movement data
elif training_focus >= 0.3 and training_focus < 0.6 and hasattr(self, 'price_movement_memory') and len(self.price_movement_memory) >= self.batch_size // 2:
# Sample from price movement memory
price_batch_size = min(self.batch_size // 2, len(self.price_movement_memory))
price_batch = random.sample(self.price_movement_memory, price_batch_size)
# Extract batches with proper tensor conversion
price_states = np.vstack([self._normalize_state(x[0]) for x in price_batch])
price_actions = np.array([x[1] for x in price_batch])
price_rewards = np.array([x[2] for x in price_batch])
price_next_states = np.vstack([self._normalize_state(x[3]) for x in price_batch])
price_dones = np.array([x[4] for x in price_batch], dtype=np.float32)
# Convert to torch tensors and move to device
price_states_tensor = torch.FloatTensor(price_states).to(self.device)
price_actions_tensor = torch.LongTensor(price_actions).to(self.device)
price_rewards_tensor = torch.FloatTensor(price_rewards).to(self.device)
price_next_states_tensor = torch.FloatTensor(price_next_states).to(self.device)
price_dones_tensor = torch.FloatTensor(price_dones).to(self.device)
# Additional training step focused on price movements (with smaller learning rate)
original_lr = self.optimizer.param_groups[0]['lr']
# Temporarily reduce learning rate
for param_group in self.optimizer.param_groups:
param_group['lr'] = original_lr * 0.5
# Train on price movement data
if self.use_mixed_precision:
price_loss = self._replay_mixed_precision(
price_states_tensor, price_actions_tensor, price_rewards_tensor,
price_next_states_tensor, price_dones_tensor
)
else:
price_loss = self._replay_standard(
price_states_tensor, price_actions_tensor, price_rewards_tensor,
price_next_states_tensor, price_dones_tensor
)
# Restore original learning rate
for param_group in self.optimizer.param_groups:
param_group['lr'] = original_lr
logger.info(f"Extra training on price movement data: loss={price_loss:.4f}")
# Average the loss
loss = (loss + price_loss) / 2
# Store and return loss
self.losses.append(loss)
@ -365,12 +533,12 @@ class DQNAgent:
self.optimizer.zero_grad()
# Get current Q values and extrema predictions
current_q_values, current_extrema_pred, current_price_pred = self.policy_net(states)
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
# Get next Q values from target network
with torch.no_grad():
next_q_values, next_extrema_pred, next_price_pred = self.target_net(next_states)
next_q_values = next_q_values.max(1)[0]
# Check for dimension mismatch and fix it
@ -389,13 +557,10 @@ class DQNAgent:
# Compute Q-value loss (primary task)
q_loss = nn.MSELoss()(current_q_values, target_q_values)
# Initialize combined loss with Q-value loss
loss = q_loss
# Try to extract price from current and next states
try:
# Extract price feature from sequence data (if available)
if len(states.shape) == 3: # [batch, seq, features]
@ -405,43 +570,99 @@ class DQNAgent:
current_prices = states[:, -1] # Last feature
next_prices = next_states[:, -1]
# Compute price changes for different timeframes
immediate_changes = (next_prices - current_prices) / current_prices
# Create price direction labels - simplified for training
# 0 = down, 1 = sideways, 2 = up
immediate_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 1 # Default: sideways
midterm_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 1
longterm_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 1
# Immediate term direction (1s, 1m)
immediate_up = (immediate_changes > 0.0005)
immediate_down = (immediate_changes < -0.0005)
immediate_labels[immediate_up] = 2 # Up
immediate_labels[immediate_down] = 0 # Down
# For mid and long term, we can only approximate during training
# In a real system, we'd need historical data to validate these
# Here we'll use the immediate term with increasing thresholds as approximation
# Mid-term (1h) - use slightly higher threshold
midterm_up = (immediate_changes > 0.001)
midterm_down = (immediate_changes < -0.001)
midterm_labels[midterm_up] = 2 # Up
midterm_labels[midterm_down] = 0 # Down
# Long-term (1d) - use even higher threshold
longterm_up = (immediate_changes > 0.002)
longterm_down = (immediate_changes < -0.002)
longterm_labels[longterm_up] = 2 # Up
longterm_labels[longterm_down] = 0 # Down
# Generate target values for price change regression
# For simplicity, we'll use the immediate change and scaled versions for longer timeframes
price_value_targets = torch.zeros((min_size, 4), device=self.device)
price_value_targets[:, 0] = immediate_changes
price_value_targets[:, 1] = immediate_changes * 2.0 # Approximate 1h change
price_value_targets[:, 2] = immediate_changes * 4.0 # Approximate 1d change
price_value_targets[:, 3] = immediate_changes * 6.0 # Approximate 1w change
# Calculate loss for price direction prediction (classification)
if len(current_price_pred['immediate'].shape) > 1 and current_price_pred['immediate'].shape[0] >= min_size:
# Slice predictions to match the adjusted batch size
immediate_pred = current_price_pred['immediate'][:min_size]
midterm_pred = current_price_pred['midterm'][:min_size]
longterm_pred = current_price_pred['longterm'][:min_size]
price_values_pred = current_price_pred['values'][:min_size]
# Compute losses for each task
immediate_loss = nn.CrossEntropyLoss()(immediate_pred, immediate_labels)
midterm_loss = nn.CrossEntropyLoss()(midterm_pred, midterm_labels)
longterm_loss = nn.CrossEntropyLoss()(longterm_pred, longterm_labels)
# MSE loss for price value regression
price_value_loss = nn.MSELoss()(price_values_pred, price_value_targets)
# Combine all price prediction losses
price_loss = immediate_loss + 0.7 * midterm_loss + 0.5 * longterm_loss + 0.3 * price_value_loss
# Create extrema labels (same as before)
extrema_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 2 # Default: neither
# Identify potential bottoms (significant negative change)
bottoms = (immediate_changes < -0.003)
extrema_labels[bottoms] = 0
# Identify potential tops (significant positive change)
tops = (immediate_changes > 0.003)
extrema_labels[tops] = 1
# Calculate extrema prediction loss
if len(current_extrema_pred.shape) > 1 and current_extrema_pred.shape[0] >= min_size:
current_extrema_pred = current_extrema_pred[:min_size]
extrema_loss = nn.CrossEntropyLoss()(current_extrema_pred, extrema_labels)
# Combined loss with all components
# Primary task: Q-value learning (RL objective)
# Secondary tasks: extrema detection and price prediction (supervised objectives)
loss = q_loss + 0.3 * extrema_loss + 0.3 * price_loss
# Log loss components occasionally
if random.random() < 0.01: # Log 1% of the time
logger.info(
f"Training losses: Q-loss={q_loss.item():.4f}, "
f"Extrema-loss={extrema_loss.item():.4f}, "
f"Price-loss={price_loss.item():.4f}, "
f"Imm-loss={immediate_loss.item():.4f}, "
f"Mid-loss={midterm_loss.item():.4f}, "
f"Long-loss={longterm_loss.item():.4f}"
)
except Exception as e:
# Fallback if price extraction fails
logger.warning(f"Failed to calculate price prediction loss: {str(e)}. Using only Q-value loss.")
loss = q_loss
# Backward pass and optimize
@ -475,12 +696,12 @@ class DQNAgent:
# Forward pass with amp autocasting
with torch.cuda.amp.autocast():
# Get current Q values and extrema predictions
current_q_values, current_extrema_pred, current_price_pred = self.policy_net(states)
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
# Get next Q values from target network
with torch.no_grad():
next_q_values, next_extrema_pred, next_price_pred = self.target_net(next_states)
next_q_values = next_q_values.max(1)[0]
# Check for dimension mismatch and fix it
@ -499,7 +720,9 @@ class DQNAgent:
# Compute Q-value loss (primary task)
q_loss = nn.MSELoss()(current_q_values, target_q_values)
# Initialize loss with q_loss
loss = q_loss
# Try to extract price from current and next states
try:
# Extract price feature from sequence data (if available)
@ -510,42 +733,96 @@ class DQNAgent:
current_prices = states[:, -1] # Last feature
next_prices = next_states[:, -1]
# Compute price changes for different timeframes
immediate_changes = (next_prices - current_prices) / current_prices
# Create price direction labels - simplified for training
# 0 = down, 1 = sideways, 2 = up
immediate_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 1 # Default: sideways
midterm_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 1
longterm_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 1
# Immediate term direction (1s, 1m)
immediate_up = (immediate_changes > 0.0005)
immediate_down = (immediate_changes < -0.0005)
immediate_labels[immediate_up] = 2 # Up
immediate_labels[immediate_down] = 0 # Down
# For mid and long term, we can only approximate during training
# In a real system, we'd need historical data to validate these
# Here we'll use the immediate term with increasing thresholds as approximation
# Mid-term (1h) - use slightly higher threshold
midterm_up = (immediate_changes > 0.001)
midterm_down = (immediate_changes < -0.001)
midterm_labels[midterm_up] = 2 # Up
midterm_labels[midterm_down] = 0 # Down
# Long-term (1d) - use even higher threshold
longterm_up = (immediate_changes > 0.002)
longterm_down = (immediate_changes < -0.002)
longterm_labels[longterm_up] = 2 # Up
longterm_labels[longterm_down] = 0 # Down
# Generate target values for price change regression
# For simplicity, we'll use the immediate change and scaled versions for longer timeframes
price_value_targets = torch.zeros((min_size, 4), device=self.device)
price_value_targets[:, 0] = immediate_changes
price_value_targets[:, 1] = immediate_changes * 2.0 # Approximate 1h change
price_value_targets[:, 2] = immediate_changes * 4.0 # Approximate 1d change
price_value_targets[:, 3] = immediate_changes * 6.0 # Approximate 1w change
# Calculate loss for price direction prediction (classification)
if len(current_price_pred['immediate'].shape) > 1 and current_price_pred['immediate'].shape[0] >= min_size:
# Slice predictions to match the adjusted batch size
immediate_pred = current_price_pred['immediate'][:min_size]
midterm_pred = current_price_pred['midterm'][:min_size]
longterm_pred = current_price_pred['longterm'][:min_size]
price_values_pred = current_price_pred['values'][:min_size]
# Compute losses for each task
immediate_loss = nn.CrossEntropyLoss()(immediate_pred, immediate_labels)
midterm_loss = nn.CrossEntropyLoss()(midterm_pred, midterm_labels)
longterm_loss = nn.CrossEntropyLoss()(longterm_pred, longterm_labels)
# MSE loss for price value regression
price_value_loss = nn.MSELoss()(price_values_pred, price_value_targets)
# Combine all price prediction losses
price_loss = immediate_loss + 0.7 * midterm_loss + 0.5 * longterm_loss + 0.3 * price_value_loss
# Create extrema labels (same as before)
extrema_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 2 # Default: neither
# Identify potential bottoms (significant negative change)
bottoms = (immediate_changes < -0.003)
extrema_labels[bottoms] = 0
# Identify potential tops (significant positive change)
tops = (immediate_changes > 0.003)
extrema_labels[tops] = 1
# Calculate extrema prediction loss
if len(current_extrema_pred.shape) > 1 and current_extrema_pred.shape[0] >= min_size:
current_extrema_pred = current_extrema_pred[:min_size]
extrema_loss = nn.CrossEntropyLoss()(current_extrema_pred, extrema_labels)
# Combined loss with all components
# Primary task: Q-value learning (RL objective)
# Secondary tasks: extrema detection and price prediction (supervised objectives)
loss = q_loss + 0.3 * extrema_loss + 0.3 * price_loss
# Log loss components occasionally
if random.random() < 0.01: # Log 1% of the time
logger.info(
f"Mixed precision losses: Q-loss={q_loss.item():.4f}, "
f"Extrema-loss={extrema_loss.item():.4f}, "
f"Price-loss={price_loss.item():.4f}"
)
except Exception as e:
# Fallback if price extraction fails
logger.warning(f"Failed to calculate price prediction loss: {str(e)}. Using only Q-value loss.")
loss = q_loss
# Backward pass with scaled gradients

View File

@ -125,6 +125,14 @@ class SimpleCNN(nn.Module):
# Extrema detection head
self.extrema_head = nn.Linear(256, 3) # 0=bottom, 1=top, 2=neither
# Price prediction heads for different timeframes
self.price_pred_immediate = nn.Linear(256, 3) # Up, Down, Sideways for immediate term (1s, 1m)
self.price_pred_midterm = nn.Linear(256, 3) # Up, Down, Sideways for mid-term (1h)
self.price_pred_longterm = nn.Linear(256, 3) # Up, Down, Sideways for long-term (1d)
# Regression heads for exact price prediction
self.price_pred_value = nn.Linear(256, 4) # Predicts % change for each timeframe (1s, 1m, 1h, 1d)
def _check_rebuild_network(self, features):
"""Check if network needs to be rebuilt for different feature dimensions"""
@ -140,7 +148,7 @@ class SimpleCNN(nn.Module):
def forward(self, x):
"""
Forward pass through the network
Returns action values, extrema predictions, and price movement predictions for multiple timeframes
"""
# Handle different input shapes
if len(x.shape) == 2: # [batch_size, features]
@ -173,7 +181,50 @@ class SimpleCNN(nn.Module):
# Extrema predictions
extrema_pred = self.extrema_head(fc_out)
# Price movement predictions for different timeframes
price_immediate = self.price_pred_immediate(fc_out) # 1s, 1m
price_midterm = self.price_pred_midterm(fc_out) # 1h
price_longterm = self.price_pred_longterm(fc_out) # 1d
# Regression values for exact price predictions (percentage changes)
price_values = self.price_pred_value(fc_out)
# Return all predictions in a structured dictionary
price_predictions = {
'immediate': price_immediate,
'midterm': price_midterm,
'longterm': price_longterm,
'values': price_values
}
return action_values, extrema_pred, price_predictions
def save(self, path):
"""Save model weights and architecture"""
os.makedirs(os.path.dirname(path), exist_ok=True)
torch.save({
'state_dict': self.state_dict(),
'input_shape': self.input_shape,
'n_actions': self.n_actions,
'feature_dim': self.feature_dim
}, f"{path}.pt")
logger.info(f"Model saved to {path}.pt")
def load(self, path):
"""Load model weights and architecture"""
try:
checkpoint = torch.load(f"{path}.pt", map_location=self.device)
self.input_shape = checkpoint['input_shape']
self.n_actions = checkpoint['n_actions']
self.feature_dim = checkpoint['feature_dim']
self._build_network()
self.load_state_dict(checkpoint['state_dict'])
self.to(self.device)
logger.info(f"Model loaded from {path}.pt")
return True
except Exception as e:
logger.error(f"Error loading model: {str(e)}")
return False
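To make the new return contract of forward() concrete, here is a hypothetical caller-side sketch; the constructor arguments and input shape are assumptions for illustration, only the three-part return value and the dictionary keys come from the code above:
import torch

model = SimpleCNN(input_shape=(20, 5), n_actions=3)  # assumed constructor signature
x = torch.randn(8, 20, 5)                            # [batch, window, features]
action_values, extrema_pred, price_predictions = model(x)
print(action_values.shape)                   # [8, n_actions]
print(extrema_pred.shape)                    # [8, 3] bottom / top / neither logits
print(price_predictions['immediate'].shape)  # [8, 3] down / sideways / up logits
print(price_predictions['values'].shape)     # [8, 4] predicted % change per timeframe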
class CNNModelPyTorch(nn.Module):
"""

View File

@ -9,6 +9,9 @@ import sys
import pandas as pd
import gym
import json
import random
import torch.nn as nn
import contextlib
# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@ -49,19 +52,36 @@ class RLTradingEnvironment(gym.Env):
Reinforcement Learning environment for trading with technical indicators
from multiple timeframes
"""
def __init__(self, features_1m, features_1h=None, features_1d=None, window_size=20, trading_fee=0.0025, min_trade_interval=15):
super().__init__()
# Initialize attributes before parent class
self.window_size = window_size
self.num_features = features_1m.shape[1] - 1 # Exclude close price
# Count available timeframes
self.num_timeframes = 1 # Always have 1m
if features_1h is not None:
self.num_timeframes += 1
if features_1d is not None:
self.num_timeframes += 1
self.feature_dim = self.num_features * self.num_timeframes
# Store features from different timeframes
self.features_1m = features_1m
self.features_1h = features_1h
self.features_1d = features_1d
# Create synthetic 1s data from 1m (for demo purposes)
self.features_1s = self._create_synthetic_1s_data(features_1m)
# If higher timeframes are missing, create synthetic data
if self.features_1h is None:
self.features_1h = self._create_synthetic_hourly_data(features_1m)
if self.features_1d is None:
self.features_1d = self._create_synthetic_daily_data(features_1h)
# Trading parameters
self.initial_balance = 1.0
@ -83,6 +103,45 @@ class RLTradingEnvironment(gym.Env):
# Callback for visualization or external monitoring
self.action_callback = None
def _create_synthetic_1s_data(self, features_1m):
"""Create synthetic 1-second data from 1-minute data"""
# Simple approach: duplicate each 1m candle for 60 seconds with some noise
num_samples = features_1m.shape[0]
synthetic_1s = np.zeros((num_samples * 60, features_1m.shape[1]))
for i in range(num_samples):
for j in range(60):
idx = i * 60 + j
if idx < synthetic_1s.shape[0]:
# Copy the 1m data with small random noise
synthetic_1s[idx] = features_1m[i] * (1 + np.random.normal(0, 0.0001, features_1m.shape[1]))
return synthetic_1s
def _create_synthetic_hourly_data(self, features_1m):
"""Create synthetic hourly data from minute data"""
# Group by hour, taking every 60th candle
num_samples = features_1m.shape[0] // 60
synthetic_1h = np.zeros((num_samples, features_1m.shape[1]))
for i in range(num_samples):
if i * 60 < features_1m.shape[0]:
synthetic_1h[i] = features_1m[i * 60]
return synthetic_1h
def _create_synthetic_daily_data(self, features_1h):
"""Create synthetic daily data from hourly data"""
# Group by day, taking every 24th candle
num_samples = features_1h.shape[0] // 24
synthetic_1d = np.zeros((num_samples, features_1h.shape[1]))
for i in range(num_samples):
if i * 24 < features_1h.shape[0]:
synthetic_1d[i] = features_1h[i * 24]
return synthetic_1d
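A rough sketch of what the synthetic fallbacks produce; the row count is made up, and it assumes the environment can be constructed from 1m features alone, as the new __init__ signature suggests:
import numpy as np

features_1m = np.random.rand(6000, 5)    # 6000 minutes, 5 features (illustrative)
env = RLTradingEnvironment(features_1m)  # 1h/1d omitted -> synthesized internally
print(env.features_1s.shape)             # (360000, 5)  each minute expanded to 60 noisy seconds
print(env.features_1h.shape)             # (100, 5)     every 60th minute row
print(env.features_1d.shape)             # (4, 5)       every 24th hourly row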
def reset(self):
"""Reset the environment to initial state"""
self.balance = self.initial_balance
@ -104,35 +163,43 @@ class RLTradingEnvironment(gym.Env):
Combine features from multiple timeframes, reshaped for the CNN.
"""
# Calculate indices for each timeframe
idx_1m = min(self.current_step, self.features_1m.shape[0] - 1)
idx_1h = idx_1m // 60 # 60 minutes in an hour
idx_1d = idx_1h // 24 # 24 hours in a day
# Cap indices to prevent out of bounds
idx_1h = min(idx_1h, self.features_1h.shape[0] - 1)
idx_1d = min(idx_1d, self.features_1d.shape[0] - 1)
# Extract feature windows from each timeframe
window_1m = self.features_1m[max(0, idx_1m - self.window_size):idx_1m]
# Handle hourly timeframe
start_1h = max(0, idx_1h - self.window_size)
window_1h = self.features_1h[start_1h:idx_1h]
# Handle daily timeframe
start_1d = max(0, idx_1d - self.window_size)
window_1d = self.features_1d[start_1d:idx_1d]
# Pad if needed (for higher timeframes)
if len(window_1m) < self.window_size:
padding = np.zeros((self.window_size - len(window_1m), window_1m.shape[1]))
window_1m = np.vstack([padding, window_1m])
if len(window_1h) < self.window_size:
padding = np.zeros((self.window_size - len(window_1h), window_1h.shape[1]))
window_1h = np.vstack([padding, window_1h])
if len(window_1d) < self.window_size:
padding = np.zeros((self.window_size - len(window_1d), window_1d.shape[1]))
window_1d = np.vstack([padding, window_1d])
# Combine features from all timeframes
combined_features = np.hstack([
window_1m.reshape(self.window_size, -1),
window_1h.reshape(self.window_size, -1),
window_1d.reshape(self.window_size, -1)
])
# Convert to float32 and handle any NaN values
@ -152,7 +219,14 @@ class RLTradingEnvironment(gym.Env):
""" """
# Get current and next price # Get current and next price
current_price = self.features_1m[self.current_step, -1] # Close price is last column current_price = self.features_1m[self.current_step, -1] # Close price is last column
next_price = self.features_1m[self.current_step + 1, -1]
# Check if we're at the end of the data
if self.current_step + 1 >= len(self.features_1m):
next_price = current_price # Use current price if at the end
done = True
else:
next_price = self.features_1m[self.current_step + 1, -1]
done = False
# Handle zero or negative prices
if current_price <= 0:
@ -164,7 +238,6 @@ class RLTradingEnvironment(gym.Env):
# Default reward is slightly negative to discourage inaction
reward = -0.0001
profit_pct = None # Initialize profit_pct variable
# Check if enough time has passed since last trade
@ -225,7 +298,7 @@ class RLTradingEnvironment(gym.Env):
# Move to next step
self.current_step += 1
# Check if done (reached end of data)
if self.current_step >= len(self.features_1m) - 1:
done = True
@ -251,6 +324,26 @@ class RLTradingEnvironment(gym.Env):
gain = (total_value - self.initial_balance) / self.initial_balance
self.win_rate = self.wins / max(1, self.trades)
# Check if we have prediction data for future timeframes
future_price_1h = None
future_price_1d = None
# Get hourly index
idx_1h = self.current_step // 60
if idx_1h + 1 < len(self.features_1h):
hourly_close_idx = self.features_1h.shape[1] - 1 # Assuming close is last column
current_1h_price = self.features_1h[idx_1h, hourly_close_idx]
next_1h_price = self.features_1h[idx_1h + 1, hourly_close_idx]
future_price_1h = (next_1h_price - current_1h_price) / current_1h_price
# Get daily index
idx_1d = idx_1h // 24
if idx_1d + 1 < len(self.features_1d):
daily_close_idx = self.features_1d.shape[1] - 1 # Assuming close is last column
current_1d_price = self.features_1d[idx_1d, daily_close_idx]
next_1d_price = self.features_1d[idx_1d + 1, daily_close_idx]
future_price_1d = (next_1d_price - current_1d_price) / current_1d_price
info = {
'balance': self.balance,
'position': self.position,
@ -260,7 +353,9 @@ class RLTradingEnvironment(gym.Env):
'win_rate': self.win_rate,
'profit_pct': profit_pct if action == 1 and self.position == 0 else None,
'current_price': current_price,
'next_price': next_price,
'future_price_1h': future_price_1h, # Actual future hourly price change
'future_price_1d': future_price_1d # Actual future daily price change
}
# Call the callback if it exists
@ -279,7 +374,8 @@ class RLTradingEnvironment(gym.Env):
self.action_callback = callback
def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/models/saved/dqn_agent",
action_callback=None, episode_callback=None, symbol="BTC/USDT",
pretrain_price_prediction_enabled=True, pretrain_epochs=10):
""" """
Train a reinforcement learning agent for trading Train a reinforcement learning agent for trading
@ -291,6 +387,8 @@ def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/mo
action_callback: Callback function for monitoring actions
episode_callback: Callback function for monitoring episodes
symbol: Trading symbol to use
pretrain_price_prediction_enabled: Whether to pre-train price prediction
pretrain_epochs: Number of epochs for pre-training
Returns:
tuple: (trained agent, environment)
@ -396,6 +494,30 @@ def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/mo
else:
logger.info("No existing model found. Starting with a new model.")
# Pre-train price prediction if enabled and we have a new model
if pretrain_price_prediction_enabled:
if not os.path.exists(model_file) or input("Pre-train price prediction? (y/n): ").lower() == 'y':
logger.info("Pre-training price prediction capability...")
# Attempt to load hourly and daily data for pre-training
try:
data_interface.add_timeframe('1h')
data_interface.add_timeframe('1d')
# Run pre-training
agent = pretrain_price_prediction(
agent=agent,
data_interface=data_interface,
n_epochs=pretrain_epochs,
batch_size=128
)
# Save the pre-trained model
agent.save(f"{save_path}_pretrained")
logger.info("Pre-trained model saved.")
except Exception as e:
logger.error(f"Error during pre-training: {e}")
logger.warning("Continuing with RL training without pre-training.")
# Create TensorBoard writer
writer = SummaryWriter(log_dir=f'runs/dqn_{int(time.time())}')
@ -499,5 +621,379 @@ def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/mo
return agent, env
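Taken together, the new keyword arguments can be exercised with a call along these lines; this is a hypothetical sketch, and only the two pre-training parameters are new in this commit:
agent, env = train_rl(
    num_episodes=1000,
    symbol="BTC/USDT",
    pretrain_price_prediction_enabled=True,  # run supervised pre-training before RL
    pretrain_epochs=5
)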
def generate_price_prediction_training_data(data_1m, data_1h, data_1d, window_size=20):
"""
Generate labeled training data for price prediction at different timeframes
Args:
data_1m: DataFrame with 1-minute data
data_1h: DataFrame with 1-hour data
data_1d: DataFrame with 1-day data
window_size: Size of the input window
Returns:
tuple: (X, y_immediate, y_midterm, y_longterm, y_values)
- X: input features (window sequences)
- y_immediate: immediate direction labels (0=down, 1=sideways, 2=up)
- y_midterm: mid-term direction labels
- y_longterm: long-term direction labels
- y_values: actual percentage changes for each timeframe
"""
logger.info("Generating price prediction training data from historical prices")
# Prepare data structures
X = []
y_immediate = [] # 1m
y_midterm = [] # 1h
y_longterm = [] # 1d
y_values = [] # Actual percentage changes
# Calculate future returns for labeling
data_1m['future_return_1m'] = data_1m['close'].pct_change(1).shift(-1) # Next candle
data_1m['future_return_10m'] = data_1m['close'].pct_change(10).shift(-10) # Next 10 candles
# Add indices to align data
data_1m['index'] = range(len(data_1m))
data_1h['index'] = range(len(data_1h))
data_1d['index'] = range(len(data_1d))
# Define thresholds for direction labels
immediate_threshold = 0.0005
midterm_threshold = 0.001
longterm_threshold = 0.002
# Loop through 1m data to create training samples
max_idx = len(data_1m) - window_size - 10 # Ensure we have future data for labels
sample_indices = random.sample(range(window_size, max_idx), min(10000, max_idx - window_size))
for idx in sample_indices:
# Get window of 1m data
window_1m = data_1m.iloc[idx-window_size:idx].drop(['timestamp', 'future_return_1m', 'future_return_10m', 'index'], axis=1, errors='ignore')
# Skip if window contains NaN
if window_1m.isnull().values.any():
continue
# Get future returns for labeling
future_return_1m = data_1m.iloc[idx]['future_return_1m']
future_return_10m = data_1m.iloc[idx]['future_return_10m']
# Find corresponding row in 1h data (closest timestamp)
current_timestamp = data_1m.iloc[idx]['timestamp']
# Find 1h candle for mid-term prediction
if 'timestamp' in data_1h.columns:
# Find closest 1h candle
closest_1h_idx = data_1h['timestamp'].searchsorted(current_timestamp)
if closest_1h_idx >= len(data_1h):
closest_1h_idx = len(data_1h) - 1
# Get future 1h return (next candle)
if closest_1h_idx < len(data_1h) - 1:
future_return_1h = (data_1h.iloc[closest_1h_idx + 1]['close'] - data_1h.iloc[closest_1h_idx]['close']) / data_1h.iloc[closest_1h_idx]['close']
else:
future_return_1h = 0
else:
future_return_1h = future_return_10m # Fallback
# Find 1d candle for long-term prediction
if 'timestamp' in data_1d.columns:
# Find closest 1d candle
closest_1d_idx = data_1d['timestamp'].searchsorted(current_timestamp)
if closest_1d_idx >= len(data_1d):
closest_1d_idx = len(data_1d) - 1
# Get future 1d return (next candle)
if closest_1d_idx < len(data_1d) - 1:
future_return_1d = (data_1d.iloc[closest_1d_idx + 1]['close'] - data_1d.iloc[closest_1d_idx]['close']) / data_1d.iloc[closest_1d_idx]['close']
else:
future_return_1d = 0
else:
future_return_1d = future_return_1h * 2 # Fallback
# Create direction labels
# 0=down, 1=sideways, 2=up
# Immediate (1m)
if future_return_1m > immediate_threshold:
immediate_label = 2 # UP
elif future_return_1m < -immediate_threshold:
immediate_label = 0 # DOWN
else:
immediate_label = 1 # SIDEWAYS
# Mid-term (1h)
if future_return_1h > midterm_threshold:
midterm_label = 2 # UP
elif future_return_1h < -midterm_threshold:
midterm_label = 0 # DOWN
else:
midterm_label = 1 # SIDEWAYS
# Long-term (1d)
if future_return_1d > longterm_threshold:
longterm_label = 2 # UP
elif future_return_1d < -longterm_threshold:
longterm_label = 0 # DOWN
else:
longterm_label = 1 # SIDEWAYS
# Store data
X.append(window_1m.values)
y_immediate.append(immediate_label)
y_midterm.append(midterm_label)
y_longterm.append(longterm_label)
y_values.append([future_return_1m, future_return_1h, future_return_1d, future_return_1d * 1.5]) # Add weekly estimate
# Convert to numpy arrays
X = np.array(X)
y_immediate = np.array(y_immediate)
y_midterm = np.array(y_midterm)
y_longterm = np.array(y_longterm)
y_values = np.array(y_values)
logger.info(f"Generated {len(X)} price prediction training samples")
# Log class distribution
for name, y in [("Immediate", y_immediate), ("Mid-term", y_midterm), ("Long-term", y_longterm)]:
down = (y == 0).sum()
sideways = (y == 1).sum()
up = (y == 2).sum()
logger.info(f"{name} direction distribution: DOWN={down} ({down/len(y)*100:.1f}%), "
f"SIDEWAYS={sideways} ({sideways/len(y)*100:.1f}%), "
f"UP={up} ({up/len(y)*100:.1f}%)")
return X, y_immediate, y_midterm, y_longterm, y_values
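Hypothetical shapes for the arrays returned above, assuming 20-bar windows over OHLCV-style DataFrames already loaded as data_1m, data_1h and data_1d; the sample count N depends on the random sampling:
X, y_imm, y_mid, y_long, y_vals = generate_price_prediction_training_data(
    data_1m, data_1h, data_1d, window_size=20
)
print(X.shape)       # (N, 20, n_features)  one 1m window per sample
print(y_imm.shape)   # (N,)   direction class: 0=down, 1=sideways, 2=up
print(y_vals.shape)  # (N, 4) return targets: 1m, 1h, 1d, plus a weekly estimate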
def pretrain_price_prediction(agent, data_interface, n_epochs=10, batch_size=128):
"""
Pre-train the agent's price prediction capability on historical data
Args:
agent: DQNAgent instance to train
data_interface: DataInterface instance for accessing data
n_epochs: Number of epochs for training
batch_size: Batch size for training
Returns:
The agent with pre-trained price prediction capabilities
"""
logger.info("Starting supervised pre-training of price prediction")
try:
# Load data for all required timeframes
data_1m = data_interface.get_historical_data(timeframe='1m', n_candles=10000)
data_1h = data_interface.get_historical_data(timeframe='1h', n_candles=1000)
data_1d = data_interface.get_historical_data(timeframe='1d', n_candles=500)
# Check if data is available
if data_1m is None:
logger.warning("1m data not available for pre-training")
return agent
if data_1h is None:
logger.warning("1h data not available, using synthesized data")
# Create synthetic 1h data from 1m data
data_1h = data_1m.iloc[::60].reset_index(drop=True).copy() # Take every 60th record
if data_1d is None:
logger.warning("1d data not available, using synthesized data")
# Create synthetic 1d data from 1h data
data_1d = data_1h.iloc[::24].reset_index(drop=True).copy() # Take every 24th record
# Add technical indicators to all data
data_1m = data_interface.add_technical_indicators(data_1m)
data_1h = data_interface.add_technical_indicators(data_1h)
data_1d = data_interface.add_technical_indicators(data_1d)
# Generate labeled training data
X, y_immediate, y_midterm, y_longterm, y_values = generate_price_prediction_training_data(
data_1m, data_1h, data_1d, window_size=20
)
# Split data into training and validation sets
from sklearn.model_selection import train_test_split
X_train, X_val, y_imm_train, y_imm_val, y_mid_train, y_mid_val, y_long_train, y_long_val, y_val_train, y_val_val = train_test_split(
X, y_immediate, y_midterm, y_longterm, y_values, test_size=0.2, random_state=42
)
# Convert to torch tensors
X_train_tensor = torch.FloatTensor(X_train).to(agent.device)
y_imm_train_tensor = torch.LongTensor(y_imm_train).to(agent.device)
y_mid_train_tensor = torch.LongTensor(y_mid_train).to(agent.device)
y_long_train_tensor = torch.LongTensor(y_long_train).to(agent.device)
y_val_train_tensor = torch.FloatTensor(y_val_train).to(agent.device)
X_val_tensor = torch.FloatTensor(X_val).to(agent.device)
y_imm_val_tensor = torch.LongTensor(y_imm_val).to(agent.device)
y_mid_val_tensor = torch.LongTensor(y_mid_val).to(agent.device)
y_long_val_tensor = torch.LongTensor(y_long_val).to(agent.device)
y_val_val_tensor = torch.FloatTensor(y_val_val).to(agent.device)
# Calculate class weights for imbalanced data
from torch.nn.functional import one_hot
# Function to calculate class weights
def get_class_weights(labels):
counts = np.bincount(labels)
if len(counts) < 3: # Ensure we have 3 classes
counts = np.append(counts, [0] * (3 - len(counts)))
weights = 1.0 / np.array(counts)
weights = weights / np.sum(weights) # Normalize
return weights
imm_weights = torch.FloatTensor(get_class_weights(y_imm_train)).to(agent.device)
mid_weights = torch.FloatTensor(get_class_weights(y_mid_train)).to(agent.device)
long_weights = torch.FloatTensor(get_class_weights(y_long_train)).to(agent.device)
# Create DataLoader for batch training
from torch.utils.data import TensorDataset, DataLoader
train_dataset = TensorDataset(
X_train_tensor, y_imm_train_tensor, y_mid_train_tensor,
y_long_train_tensor, y_val_train_tensor
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Set up loss functions with class weights
imm_criterion = nn.CrossEntropyLoss(weight=imm_weights)
mid_criterion = nn.CrossEntropyLoss(weight=mid_weights)
long_criterion = nn.CrossEntropyLoss(weight=long_weights)
value_criterion = nn.MSELoss()
# Set up optimizer (separate from agent's optimizer)
pretrain_optimizer = torch.optim.Adam(agent.policy_net.parameters(), lr=0.0002)
pretrain_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
pretrain_optimizer, mode='min', factor=0.5, patience=3, verbose=True
)
# Set model to training mode
agent.policy_net.train()
# Training loop
best_val_loss = float('inf')
patience = 5
patience_counter = 0
for epoch in range(n_epochs):
# Training phase
train_loss = 0.0
imm_correct, mid_correct, long_correct = 0, 0, 0
total = 0
for X_batch, y_imm_batch, y_mid_batch, y_long_batch, y_val_batch in train_loader:
# Zero gradients
pretrain_optimizer.zero_grad()
# Forward pass - we only need the price predictions
with torch.cuda.amp.autocast() if agent.use_mixed_precision else contextlib.nullcontext():
_, _, price_preds = agent.policy_net(X_batch)
# Calculate losses for each prediction head
imm_loss = imm_criterion(price_preds['immediate'], y_imm_batch)
mid_loss = mid_criterion(price_preds['midterm'], y_mid_batch)
long_loss = long_criterion(price_preds['longterm'], y_long_batch)
value_loss = value_criterion(price_preds['values'], y_val_batch)
# Combined loss (weighted by importance)
total_loss = imm_loss + 0.7 * mid_loss + 0.5 * long_loss + 0.3 * value_loss
# Backward pass and optimize
if agent.use_mixed_precision:
agent.scaler.scale(total_loss).backward()
agent.scaler.unscale_(pretrain_optimizer)
torch.nn.utils.clip_grad_norm_(agent.policy_net.parameters(), 1.0)
agent.scaler.step(pretrain_optimizer)
agent.scaler.update()
else:
total_loss.backward()
torch.nn.utils.clip_grad_norm_(agent.policy_net.parameters(), 1.0)
pretrain_optimizer.step()
# Accumulate metrics
train_loss += total_loss.item()
total += X_batch.size(0)
# Calculate accuracy
_, imm_pred = torch.max(price_preds['immediate'], 1)
_, mid_pred = torch.max(price_preds['midterm'], 1)
_, long_pred = torch.max(price_preds['longterm'], 1)
imm_correct += (imm_pred == y_imm_batch).sum().item()
mid_correct += (mid_pred == y_mid_batch).sum().item()
long_correct += (long_pred == y_long_batch).sum().item()
# Calculate epoch metrics
train_loss /= len(train_loader)
imm_acc = imm_correct / total
mid_acc = mid_correct / total
long_acc = long_correct / total
# Validation phase
agent.policy_net.eval()
val_loss = 0.0
imm_val_correct, mid_val_correct, long_val_correct = 0, 0, 0
with torch.no_grad():
# Forward pass on validation data
_, _, val_price_preds = agent.policy_net(X_val_tensor)
# Calculate validation losses
val_imm_loss = imm_criterion(val_price_preds['immediate'], y_imm_val_tensor)
val_mid_loss = mid_criterion(val_price_preds['midterm'], y_mid_val_tensor)
val_long_loss = long_criterion(val_price_preds['longterm'], y_long_val_tensor)
val_value_loss = value_criterion(val_price_preds['values'], y_val_val_tensor)
val_total_loss = val_imm_loss + 0.7 * val_mid_loss + 0.5 * val_long_loss + 0.3 * val_value_loss
val_loss = val_total_loss.item()
# Calculate validation accuracy
_, imm_val_pred = torch.max(val_price_preds['immediate'], 1)
_, mid_val_pred = torch.max(val_price_preds['midterm'], 1)
_, long_val_pred = torch.max(val_price_preds['longterm'], 1)
imm_val_correct = (imm_val_pred == y_imm_val_tensor).sum().item()
mid_val_correct = (mid_val_pred == y_mid_val_tensor).sum().item()
long_val_correct = (long_val_pred == y_long_val_tensor).sum().item()
imm_val_acc = imm_val_correct / len(X_val_tensor)
mid_val_acc = mid_val_correct / len(X_val_tensor)
long_val_acc = long_val_correct / len(X_val_tensor)
# Learning rate scheduling
pretrain_scheduler.step(val_loss)
# Early stopping check
if val_loss < best_val_loss:
best_val_loss = val_loss
patience_counter = 0
# Copy policy_net weights to target_net
agent.target_net.load_state_dict(agent.policy_net.state_dict())
logger.info(f"Saved best model with validation loss: {val_loss:.4f}")
else:
patience_counter += 1
if patience_counter >= patience:
logger.info(f"Early stopping triggered after {epoch+1} epochs")
break
# Log progress
logger.info(f"Epoch {epoch+1}/{n_epochs}: "
f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
f"Imm Acc: {imm_acc:.4f}/{imm_val_acc:.4f}, "
f"Mid Acc: {mid_acc:.4f}/{mid_val_acc:.4f}, "
f"Long Acc: {long_acc:.4f}/{long_val_acc:.4f}")
# Set model back to training mode for next epoch
agent.policy_net.train()
logger.info("Price prediction pre-training complete")
return agent
except Exception as e:
logger.error(f"Error during price prediction pre-training: {str(e)}")
import traceback
logger.error(traceback.format_exc())
return agent
if __name__ == "__main__":
train_rl()