add multi-timeframe price prediction heads and auxiliary training to DQNAgent
@@ -99,6 +99,28 @@ class DQNAgent:
        }
        self.extrema_memory = []  # Special memory for storing extrema points

        # Price prediction tracking
        self.last_price_pred = {
            'immediate': {
                'direction': 1,  # Default to "sideways"
                'confidence': 0.0,
                'change': 0.0
            },
            'midterm': {
                'direction': 1,  # Default to "sideways"
                'confidence': 0.0,
                'change': 0.0
            },
            'longterm': {
                'direction': 1,  # Default to "sideways"
                'confidence': 0.0,
                'change': 0.0
            }
        }

        # Store separate memory for price direction examples
        self.price_movement_memory = []  # For storing examples of clear price movements

        # Performance tracking
        self.losses = []
        self.avg_reward = 0.0
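# Illustrative sketch (assumption, not part of this commit): one way a trading
# environment could consume self.last_price_pred. The helper name and the
# reward-shaping rule are hypothetical; only the dict layout and the action
# encoding (0=BUY, 1=SELL) come from this diff.
def shape_reward_with_prediction(agent, action: int, base_reward: float) -> float:
    """Penalize a trade that disagrees with a confident immediate-term call
    (direction: 0=down, 1=sideways, 2=up)."""
    pred = agent.last_price_pred['immediate']
    if pred['confidence'] > 0.7:
        buys_into_down = (action == 0 and pred['direction'] == 0)
        sells_into_up = (action == 1 and pred['direction'] == 2)
        if buys_into_down or sells_into_up:
            return base_reward - 0.1 * pred['confidence']  # traded against the call
    return base_reward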
@@ -157,6 +179,36 @@ class DQNAgent:
        # Always add to main memory
        self.memory.append(experience)

        # Try to extract the price change to analyze the experience
        try:
            # Extract the price feature from sequence data (if available)
            if len(state.shape) > 1:  # 2D state [timeframes, features]
                current_price = state[-1, -1]  # Last timeframe, last feature
                next_price = next_state[-1, -1]
            else:  # 1D state
                current_price = state[-1]  # Last feature
                next_price = next_state[-1]

            # Calculate the relative price change
            price_change = (next_price - current_price) / current_price

            # Check if this is a significant, finite price movement (a zero current
            # price yields inf/nan here rather than raising under numpy)
            if np.isfinite(price_change) and abs(price_change) > 0.002:  # Significant price change
                # Store in price movement memory
                self.price_movement_memory.append(experience)

                # Log significant price movements
                direction = "UP" if price_change > 0 else "DOWN"
                logger.info(f"Stored significant {direction} price movement: {price_change:.4f}")

                # For very clear price movements, also duplicate the experience in
                # main memory so the agent trains on it more often
                if abs(price_change) > 0.005:  # Very significant movement
                    for _ in range(2):  # Add 2 extra copies
                        self.memory.append(experience)
        except Exception:
            # Skip price movement analysis if it fails
            pass

        # Check if this is an extrema point based on our extrema detection head
        if hasattr(self, 'last_extrema_pred') and self.last_extrema_pred['class'] != 2:
            # Class 0 = bottom, 1 = top, 2 = neither
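# Editor's sketch (not in this commit): duplicating strong-move transitions is
# a crude form of prioritized experience replay. An alternative is to sample
# the batch with explicit weights instead of storing copies; `priorities` here
# is a hypothetical parallel list of |price_change| values.
import random

def sample_weighted(memory, priorities, batch_size, eps=1e-6):
    """Draw a batch where experiences with larger price moves are more likely."""
    weights = [abs(p) + eps for p in priorities]
    return random.choices(memory, weights=weights, k=batch_size)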
@@ -196,6 +248,9 @@ class DQNAgent:

        if len(self.extrema_memory) > self.buffer_size // 4:
            self.extrema_memory = self.extrema_memory[-(self.buffer_size // 4):]

        if len(self.price_movement_memory) > self.buffer_size // 4:
            self.price_movement_memory = self.price_movement_memory[-(self.buffer_size // 4):]

    def act(self, state: np.ndarray, explore=True) -> int:
        """Choose action using epsilon-greedy policy with explore flag"""
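# Editor's sketch (assumption: the diff elides the exploration gate implied by
# the docstring; this is its standard form, with `epsilon` and `n_actions`
# as hypothetical attribute names).
import random

def epsilon_greedy(q_values, epsilon: float, n_actions: int, explore: bool) -> int:
    """Return a random action with probability epsilon, else the greedy one."""
    if explore and random.random() < epsilon:
        return random.randrange(n_actions)
    return max(range(n_actions), key=lambda a: q_values[a])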
@@ -209,7 +264,7 @@ class DQNAgent:

        # Get predictions using the policy network
        self.policy_net.eval()  # Set to evaluation mode for inference
        action_probs, extrema_pred = self.policy_net(state_tensor)
        action_probs, extrema_pred, price_predictions = self.policy_net(state_tensor)
        self.policy_net.train()  # Back to training mode

        # Get the predicted extrema class (0=bottom, 1=top, 2=neither)
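# Editor's sketch (assumption, not in this commit): the three-output forward
# call above implies a head structure roughly like the following. Layer sizes
# and names are hypothetical; only the returned tuple and the dict keys
# ('immediate', 'midterm', 'longterm', 'values') are taken from this diff.
import torch
import torch.nn as nn

class MultiHeadDQN(nn.Module):
    def __init__(self, feature_dim: int = 256, n_actions: int = 3):
        super().__init__()
        self.trunk = nn.Sequential(nn.LazyLinear(feature_dim), nn.ReLU())
        self.q_head = nn.Linear(feature_dim, n_actions)    # action values
        self.extrema_head = nn.Linear(feature_dim, 3)      # bottom / top / neither
        self.immediate_head = nn.Linear(feature_dim, 3)    # down / sideways / up
        self.midterm_head = nn.Linear(feature_dim, 3)
        self.longterm_head = nn.Linear(feature_dim, 3)
        self.value_head = nn.Linear(feature_dim, 4)        # % change per timeframe

    def forward(self, x: torch.Tensor):
        h = self.trunk(torch.flatten(x, start_dim=1))
        price_predictions = {
            'immediate': self.immediate_head(h),
            'midterm': self.midterm_head(h),
            'longterm': self.longterm_head(h),
            'values': self.value_head(h),
        }
        return self.q_head(h), self.extrema_head(h), price_predictions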
@@ -221,17 +276,81 @@ class DQNAgent:
            extrema_type = "BOTTOM" if extrema_class == 0 else "TOP" if extrema_class == 1 else "NEITHER"
            logger.info(f"High confidence {extrema_type} detected! Confidence: {extrema_confidence:.4f}")

        # Store extrema prediction for the environment to use
        # Process price predictions
        price_immediate = torch.softmax(price_predictions['immediate'], dim=1)
        price_midterm = torch.softmax(price_predictions['midterm'], dim=1)
        price_longterm = torch.softmax(price_predictions['longterm'], dim=1)
        price_values = price_predictions['values']

        # Get the predicted direction for each timeframe (0=down, 1=sideways, 2=up)
        immediate_direction = price_immediate.argmax(dim=1).item()
        midterm_direction = price_midterm.argmax(dim=1).item()
        longterm_direction = price_longterm.argmax(dim=1).item()

        # Get confidence levels
        immediate_conf = price_immediate[0, immediate_direction].item()
        midterm_conf = price_midterm[0, midterm_direction].item()
        longterm_conf = price_longterm[0, longterm_direction].item()

        # Get the predicted price change percentages
        price_changes = price_values[0].tolist()

        # Log significant price movement predictions
        timeframes = ["1s/1m", "1h", "1d", "1w"]
        directions = ["DOWN", "SIDEWAYS", "UP"]

        for i, (direction, conf) in enumerate([
            (immediate_direction, immediate_conf),
            (midterm_direction, midterm_conf),
            (longterm_direction, longterm_conf)
        ]):
            if conf > 0.7 and direction != 1:  # Only log high-confidence non-sideways predictions
                logger.info(f"Price prediction: {timeframes[i]} -> {directions[direction]}, "
                            f"Confidence: {conf:.4f}, Expected change: {price_changes[i]:.2f}%")

        # Store predictions for the environment to use
        self.last_extrema_pred = {
            'class': extrema_class,
            'confidence': extrema_confidence,
            'raw': extrema_pred.cpu().numpy()
        }

        self.last_price_pred = {
            'immediate': {
                'direction': immediate_direction,
                'confidence': immediate_conf,
                'change': price_changes[0]
            },
            'midterm': {
                'direction': midterm_direction,
                'confidence': midterm_conf,
                'change': price_changes[1]
            },
            'longterm': {
                'direction': longterm_direction,
                'confidence': longterm_conf,
                'change': price_changes[2]
            }
        }

        # Get the action with the highest Q-value
        action = action_probs.argmax().item()

        # Adjust action based on extrema prediction (with some probability)
        # Adjust action based on extrema and price predictions
        # Prioritize short-term movement for trading decisions
        if immediate_conf > 0.8:  # Only adjust for strong signals
            if immediate_direction == 2:  # UP prediction
                # Bias toward BUY for strong up predictions
                if action != 0 and random.random() < 0.3 * immediate_conf:
                    logger.info("Adjusting action to BUY based on immediate UP prediction")
                    action = 0  # BUY
            elif immediate_direction == 0:  # DOWN prediction
                # Bias toward SELL for strong down predictions
                if action != 1 and random.random() < 0.3 * immediate_conf:
                    logger.info("Adjusting action to SELL based on immediate DOWN prediction")
                    action = 1  # SELL

        # Also consider extrema detection for action adjustment
        if extrema_confidence > 0.8:  # Only adjust for strong signals
            if extrema_class == 0:  # Bottom detected
                # Bias toward BUY at bottoms
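# Worked example (editor's addition): with immediate_conf = 0.9 and an UP call,
# a non-BUY action is flipped to BUY with probability 0.3 * 0.9 = 0.27, so the
# prediction biases the Q-value choice rather than overriding it outright.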
@@ -307,53 +426,102 @@ class DQNAgent:
                next_states_tensor, dones_tensor
            )

        # Occasionally train specifically on extrema points, if we have enough
        if hasattr(self, 'extrema_memory') and len(self.extrema_memory) >= self.batch_size // 2:
            if random.random() < 0.3:  # 30% chance to do extra extrema training
                # Sample from extrema memory
                extrema_batch_size = min(self.batch_size // 2, len(self.extrema_memory))
                extrema_batch = random.sample(self.extrema_memory, extrema_batch_size)
        # Training focus selector - randomly focus on one of the specialized training types
        training_focus = random.random()

        # Occasionally train specifically on extrema points
        if training_focus < 0.3 and hasattr(self, 'extrema_memory') and len(self.extrema_memory) >= self.batch_size // 2:
            # Sample from extrema memory
            extrema_batch_size = min(self.batch_size // 2, len(self.extrema_memory))
            extrema_batch = random.sample(self.extrema_memory, extrema_batch_size)

            # Extract batches with proper tensor conversion
            extrema_states = np.vstack([self._normalize_state(x[0]) for x in extrema_batch])
            extrema_actions = np.array([x[1] for x in extrema_batch])
            extrema_rewards = np.array([x[2] for x in extrema_batch])
            extrema_next_states = np.vstack([self._normalize_state(x[3]) for x in extrema_batch])
            extrema_dones = np.array([x[4] for x in extrema_batch], dtype=np.float32)

            # Convert to torch tensors and move to device
            extrema_states_tensor = torch.FloatTensor(extrema_states).to(self.device)
            extrema_actions_tensor = torch.LongTensor(extrema_actions).to(self.device)
            extrema_rewards_tensor = torch.FloatTensor(extrema_rewards).to(self.device)
            extrema_next_states_tensor = torch.FloatTensor(extrema_next_states).to(self.device)
            extrema_dones_tensor = torch.FloatTensor(extrema_dones).to(self.device)

            # Additional training step focused on extrema points (with a smaller learning rate)
            original_lr = self.optimizer.param_groups[0]['lr']
            # Temporarily reduce the learning rate for fine-tuning on extrema
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = original_lr * 0.5

            # Extract batches with proper tensor conversion
            extrema_states = np.vstack([self._normalize_state(x[0]) for x in extrema_batch])
            extrema_actions = np.array([x[1] for x in extrema_batch])
            extrema_rewards = np.array([x[2] for x in extrema_batch])
            extrema_next_states = np.vstack([self._normalize_state(x[3]) for x in extrema_batch])
            extrema_dones = np.array([x[4] for x in extrema_batch], dtype=np.float32)
            # Train on extrema
            if self.use_mixed_precision:
                extrema_loss = self._replay_mixed_precision(
                    extrema_states_tensor, extrema_actions_tensor, extrema_rewards_tensor,
                    extrema_next_states_tensor, extrema_dones_tensor
                )
            else:
                extrema_loss = self._replay_standard(
                    extrema_states_tensor, extrema_actions_tensor, extrema_rewards_tensor,
                    extrema_next_states_tensor, extrema_dones_tensor
                )

            # Restore the original learning rate
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = original_lr

            logger.info(f"Extra training on extrema points: loss={extrema_loss:.4f}")

            # Average the loss
            loss = (loss + extrema_loss) / 2

        # Occasionally train specifically on price movement data
        elif 0.3 <= training_focus < 0.6 and hasattr(self, 'price_movement_memory') and len(self.price_movement_memory) >= self.batch_size // 2:
            # Sample from price movement memory
            price_batch_size = min(self.batch_size // 2, len(self.price_movement_memory))
            price_batch = random.sample(self.price_movement_memory, price_batch_size)

            # Extract batches with proper tensor conversion
            price_states = np.vstack([self._normalize_state(x[0]) for x in price_batch])
            price_actions = np.array([x[1] for x in price_batch])
            price_rewards = np.array([x[2] for x in price_batch])
            price_next_states = np.vstack([self._normalize_state(x[3]) for x in price_batch])
            price_dones = np.array([x[4] for x in price_batch], dtype=np.float32)

            # Convert to torch tensors and move to device
            price_states_tensor = torch.FloatTensor(price_states).to(self.device)
            price_actions_tensor = torch.LongTensor(price_actions).to(self.device)
            price_rewards_tensor = torch.FloatTensor(price_rewards).to(self.device)
            price_next_states_tensor = torch.FloatTensor(price_next_states).to(self.device)
            price_dones_tensor = torch.FloatTensor(price_dones).to(self.device)

            # Additional training step focused on price movements (with a smaller learning rate)
            original_lr = self.optimizer.param_groups[0]['lr']
            # Temporarily reduce the learning rate
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = original_lr * 0.5

            # Convert to torch tensors and move to device
            extrema_states_tensor = torch.FloatTensor(extrema_states).to(self.device)
            extrema_actions_tensor = torch.LongTensor(extrema_actions).to(self.device)
            extrema_rewards_tensor = torch.FloatTensor(extrema_rewards).to(self.device)
            extrema_next_states_tensor = torch.FloatTensor(extrema_next_states).to(self.device)
            extrema_dones_tensor = torch.FloatTensor(extrema_dones).to(self.device)

            # Additional training step focused on extrema points (with a smaller learning rate)
            original_lr = self.optimizer.param_groups[0]['lr']
            # Temporarily reduce the learning rate for fine-tuning on extrema
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = original_lr * 0.5

            # Train on extrema
            if self.use_mixed_precision:
                extrema_loss = self._replay_mixed_precision(
                    extrema_states_tensor, extrema_actions_tensor, extrema_rewards_tensor,
                    extrema_next_states_tensor, extrema_dones_tensor
                )
            else:
                extrema_loss = self._replay_standard(
                    extrema_states_tensor, extrema_actions_tensor, extrema_rewards_tensor,
                    extrema_next_states_tensor, extrema_dones_tensor
                )

            # Restore the original learning rate
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = original_lr

            logger.info(f"Extra training on extrema points: loss={extrema_loss:.4f}")

            # Average the loss
            loss = (loss + extrema_loss) / 2
            # Train on price movement data
            if self.use_mixed_precision:
                price_loss = self._replay_mixed_precision(
                    price_states_tensor, price_actions_tensor, price_rewards_tensor,
                    price_next_states_tensor, price_dones_tensor
                )
            else:
                price_loss = self._replay_standard(
                    price_states_tensor, price_actions_tensor, price_rewards_tensor,
                    price_next_states_tensor, price_dones_tensor
                )

            # Restore the original learning rate
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = original_lr

            logger.info(f"Extra training on price movement data: loss={price_loss:.4f}")

            # Average the loss
            loss = (loss + price_loss) / 2

        # Store and return loss
        self.losses.append(loss)
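# Editor's sketch (not in this commit): the extrema and price-movement branches
# above are near-duplicates; a DQNAgent helper like this (name hypothetical,
# relying on the module's existing random/np/torch/logger imports) would
# consolidate the sample / convert / half-LR / train / restore sequence.
def _train_on_special_memory(self, memory, tag: str):
    batch = random.sample(memory, min(self.batch_size // 2, len(memory)))
    states = torch.FloatTensor(np.vstack([self._normalize_state(x[0]) for x in batch])).to(self.device)
    actions = torch.LongTensor(np.array([x[1] for x in batch])).to(self.device)
    rewards = torch.FloatTensor(np.array([x[2] for x in batch])).to(self.device)
    next_states = torch.FloatTensor(np.vstack([self._normalize_state(x[3]) for x in batch])).to(self.device)
    dones = torch.FloatTensor(np.array([x[4] for x in batch], dtype=np.float32)).to(self.device)

    original_lr = self.optimizer.param_groups[0]['lr']
    for g in self.optimizer.param_groups:
        g['lr'] = original_lr * 0.5  # fine-tune at half the learning rate
    try:
        replay_fn = self._replay_mixed_precision if self.use_mixed_precision else self._replay_standard
        special_loss = replay_fn(states, actions, rewards, next_states, dones)
    finally:
        # Restore the learning rate even if the replay step raises
        for g in self.optimizer.param_groups:
            g['lr'] = original_lr
    logger.info(f"Extra training on {tag}: loss={special_loss:.4f}")
    return special_loss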
@@ -365,12 +533,12 @@ class DQNAgent:
        self.optimizer.zero_grad()

        # Get current Q values and extrema predictions
        current_q_values, current_extrema_pred = self.policy_net(states)
        current_q_values, current_extrema_pred, current_price_pred = self.policy_net(states)
        current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

        # Get next Q values from target network
        with torch.no_grad():
            next_q_values, next_extrema_pred = self.target_net(next_states)
            next_q_values, next_extrema_pred, next_price_pred = self.target_net(next_states)
            next_q_values = next_q_values.max(1)[0]

        # Check for dimension mismatch and fix it
@@ -389,13 +557,10 @@ class DQNAgent:
        # Compute Q-value loss (primary task)
        q_loss = nn.MSELoss()(current_q_values, target_q_values)

        # Create extrema labels from price movements (crude approximation)
        # If the next state price is higher than current, we might be in an uptrend (not a bottom)
        # If the next state price is lower than current, we might be in a downtrend (not a top)
        # This is a simplified approximation; in real scenarios we'd want to use actual extrema detection
        # Initialize the combined loss with the Q-value loss
        loss = q_loss

        # Try to extract price from current and next states
        # Assuming price is in the last feature
        try:
            # Extract the price feature from sequence data (if available)
            if len(states.shape) == 3:  # [batch, seq, features]
@@ -405,43 +570,99 @@ class DQNAgent:
                current_prices = states[:, -1]  # Last feature
                next_prices = next_states[:, -1]

            # Compute price changes
            price_changes = (next_prices - current_prices) / current_prices
            # Compute price changes for different timeframes
            immediate_changes = (next_prices - current_prices) / current_prices

            # Create crude extrema labels:
            # 0 = bottom: large negative price change followed by a positive change
            # 1 = top: large positive price change followed by a negative change
            # 2 = neither: small or inconsistent changes
            # Create price direction labels - simplified for training
            # 0 = down, 1 = sideways, 2 = up
            immediate_labels = torch.ones(min_size, dtype=torch.long, device=self.device)  # Default: sideways (1)
            midterm_labels = torch.ones(min_size, dtype=torch.long, device=self.device)
            longterm_labels = torch.ones(min_size, dtype=torch.long, device=self.device)

            # Classify based on price change magnitude
            extrema_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 2  # Default: neither
            # Immediate-term direction (1s, 1m)
            immediate_up = (immediate_changes > 0.0005)
            immediate_down = (immediate_changes < -0.0005)
            immediate_labels[immediate_up] = 2  # Up
            immediate_labels[immediate_down] = 0  # Down

            # Identify potential bottoms (significant negative change)
            bottoms = (price_changes < -0.003)
            extrema_labels[bottoms] = 0
            # For the mid and long term we can only approximate during training;
            # a real system would need historical data to validate these labels.
            # Here we use the immediate term with increasing thresholds as an approximation.

            # Identify potential tops (significant positive change)
            tops = (price_changes > 0.003)
            extrema_labels[tops] = 1
            # Mid-term (1h) - use a slightly higher threshold
            midterm_up = (immediate_changes > 0.001)
            midterm_down = (immediate_changes < -0.001)
            midterm_labels[midterm_up] = 2  # Up
            midterm_labels[midterm_down] = 0  # Down

            # Calculate extrema prediction loss (auxiliary task)
            if len(current_extrema_pred.shape) > 1 and current_extrema_pred.shape[0] >= min_size:
                current_extrema_pred = current_extrema_pred[:min_size]
                extrema_loss = nn.CrossEntropyLoss()(current_extrema_pred, extrema_labels)
            # Long-term (1d) - use an even higher threshold
            longterm_up = (immediate_changes > 0.002)
            longterm_down = (immediate_changes < -0.002)
            longterm_labels[longterm_up] = 2  # Up
            longterm_labels[longterm_down] = 0  # Down

            # Generate target values for price change regression
            # For simplicity, use the immediate change and scaled versions for the longer timeframes
            price_value_targets = torch.zeros((min_size, 4), device=self.device)
            price_value_targets[:, 0] = immediate_changes
            price_value_targets[:, 1] = immediate_changes * 2.0  # Approximate 1h change
            price_value_targets[:, 2] = immediate_changes * 4.0  # Approximate 1d change
            price_value_targets[:, 3] = immediate_changes * 6.0  # Approximate 1w change

            # Calculate the loss for price direction prediction (classification)
            if len(current_price_pred['immediate'].shape) > 1 and current_price_pred['immediate'].shape[0] >= min_size:
                # Slice predictions to match the adjusted batch size
                immediate_pred = current_price_pred['immediate'][:min_size]
                midterm_pred = current_price_pred['midterm'][:min_size]
                longterm_pred = current_price_pred['longterm'][:min_size]
                price_values_pred = current_price_pred['values'][:min_size]

                # Combined loss (primary + auxiliary with lower weight)
                # Typically auxiliary tasks get a lower weight so they do not dominate the primary task
                loss = q_loss + 0.3 * extrema_loss
                # Compute losses for each task
                immediate_loss = nn.CrossEntropyLoss()(immediate_pred, immediate_labels)
                midterm_loss = nn.CrossEntropyLoss()(midterm_pred, midterm_labels)
                longterm_loss = nn.CrossEntropyLoss()(longterm_pred, longterm_labels)

                # Log separate loss components occasionally
                if random.random() < 0.01:  # Log 1% of the time to avoid flooding the log
                    logger.info(f"Training losses: Q-loss={q_loss.item():.4f}, Extrema-loss={extrema_loss.item():.4f}")
                else:
                    # Fall back to just the Q-value loss if extrema predictions aren't available
                    loss = q_loss
                # MSE loss for price value regression
                price_value_loss = nn.MSELoss()(price_values_pred, price_value_targets)

                # Combine all price prediction losses
                price_loss = immediate_loss + 0.7 * midterm_loss + 0.5 * longterm_loss + 0.3 * price_value_loss

                # Create extrema labels (same as before)
                extrema_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 2  # Default: neither

                # Identify potential bottoms (significant negative change)
                bottoms = (immediate_changes < -0.003)
                extrema_labels[bottoms] = 0

                # Identify potential tops (significant positive change)
                tops = (immediate_changes > 0.003)
                extrema_labels[tops] = 1

                # Calculate extrema prediction loss
                if len(current_extrema_pred.shape) > 1 and current_extrema_pred.shape[0] >= min_size:
                    current_extrema_pred = current_extrema_pred[:min_size]
                    extrema_loss = nn.CrossEntropyLoss()(current_extrema_pred, extrema_labels)

                # Combined loss with all components
                # Primary task: Q-value learning (RL objective)
                # Secondary tasks: extrema detection and price prediction (supervised objectives)
                loss = q_loss + 0.3 * extrema_loss + 0.3 * price_loss

                # Log loss components occasionally
                if random.random() < 0.01:  # Log 1% of the time
                    logger.info(
                        f"Training losses: Q-loss={q_loss.item():.4f}, "
                        f"Extrema-loss={extrema_loss.item():.4f}, "
                        f"Price-loss={price_loss.item():.4f}, "
                        f"Imm-loss={immediate_loss.item():.4f}, "
                        f"Mid-loss={midterm_loss.item():.4f}, "
                        f"Long-loss={longterm_loss.item():.4f}"
                    )
        except Exception as e:
            # Fallback if price extraction fails
            logger.warning(f"Failed to calculate extrema loss: {str(e)}. Using only Q-value loss.")
            logger.warning(f"Failed to calculate price prediction loss: {str(e)}. Using only Q-value loss.")
            # Just use the Q-value loss
            loss = q_loss

        # Backward pass and optimize
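# Worked example (editor's addition) of the combined objective above, with
# illustrative values q_loss = 0.20, extrema_loss = 0.50, and
# immediate/midterm/longterm/value losses = 0.30 / 0.30 / 0.30 / 0.10:
#   price_loss = immediate + 0.7*midterm + 0.5*longterm + 0.3*value
#              = 0.30 + 0.7*0.30 + 0.5*0.30 + 0.3*0.10 = 0.69
#   loss       = q_loss + 0.3*extrema_loss + 0.3*price_loss
#              = 0.20 + 0.3*0.50 + 0.3*0.69 = 0.557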
@@ -475,12 +696,12 @@ class DQNAgent:
        # Forward pass with amp autocasting
        with torch.cuda.amp.autocast():
            # Get current Q values and extrema predictions
            current_q_values, current_extrema_pred = self.policy_net(states)
            current_q_values, current_extrema_pred, current_price_pred = self.policy_net(states)
            current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

            # Get next Q values from target network
            with torch.no_grad():
                next_q_values, next_extrema_pred = self.target_net(next_states)
                next_q_values, next_extrema_pred, next_price_pred = self.target_net(next_states)
                next_q_values = next_q_values.max(1)[0]

            # Check for dimension mismatch and fix it

@@ -499,7 +720,9 @@ class DQNAgent:
            # Compute Q-value loss (primary task)
            q_loss = nn.MSELoss()(current_q_values, target_q_values)

            # Create extrema labels from price movements (crude approximation)
            # Initialize the loss with q_loss
            loss = q_loss

            # Try to extract price from current and next states
            try:
                # Extract the price feature from sequence data (if available)
@@ -510,42 +733,96 @@ class DQNAgent:
                current_prices = states[:, -1]  # Last feature
                next_prices = next_states[:, -1]

                # Compute price changes
                price_changes = (next_prices - current_prices) / current_prices
                # Compute price changes for different timeframes
                immediate_changes = (next_prices - current_prices) / current_prices

                # Create crude extrema labels:
                # 0 = bottom: large negative price change followed by a positive change
                # 1 = top: large positive price change followed by a negative change
                # 2 = neither: small or inconsistent changes
                # Create price direction labels - simplified for training
                # 0 = down, 1 = sideways, 2 = up
                immediate_labels = torch.ones(min_size, dtype=torch.long, device=self.device)  # Default: sideways (1)
                midterm_labels = torch.ones(min_size, dtype=torch.long, device=self.device)
                longterm_labels = torch.ones(min_size, dtype=torch.long, device=self.device)

                # Classify based on price change magnitude
                extrema_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 2  # Default: neither
                # Immediate-term direction (1s, 1m)
                immediate_up = (immediate_changes > 0.0005)
                immediate_down = (immediate_changes < -0.0005)
                immediate_labels[immediate_up] = 2  # Up
                immediate_labels[immediate_down] = 0  # Down

                # Identify potential bottoms (significant negative change)
                bottoms = (price_changes < -0.003)
                extrema_labels[bottoms] = 0
                # For the mid and long term we can only approximate during training;
                # a real system would need historical data to validate these labels.
                # Here we use the immediate term with increasing thresholds as an approximation.

                # Identify potential tops (significant positive change)
                tops = (price_changes > 0.003)
                extrema_labels[tops] = 1
                # Mid-term (1h) - use a slightly higher threshold
                midterm_up = (immediate_changes > 0.001)
                midterm_down = (immediate_changes < -0.001)
                midterm_labels[midterm_up] = 2  # Up
                midterm_labels[midterm_down] = 0  # Down

                # Calculate extrema prediction loss (auxiliary task)
                if len(current_extrema_pred.shape) > 1 and current_extrema_pred.shape[0] >= min_size:
                    current_extrema_pred = current_extrema_pred[:min_size]
                    extrema_loss = nn.CrossEntropyLoss()(current_extrema_pred, extrema_labels)
                # Long-term (1d) - use an even higher threshold
                longterm_up = (immediate_changes > 0.002)
                longterm_down = (immediate_changes < -0.002)
                longterm_labels[longterm_up] = 2  # Up
                longterm_labels[longterm_down] = 0  # Down

                # Generate target values for price change regression
                # For simplicity, use the immediate change and scaled versions for the longer timeframes
                price_value_targets = torch.zeros((min_size, 4), device=self.device)
                price_value_targets[:, 0] = immediate_changes
                price_value_targets[:, 1] = immediate_changes * 2.0  # Approximate 1h change
                price_value_targets[:, 2] = immediate_changes * 4.0  # Approximate 1d change
                price_value_targets[:, 3] = immediate_changes * 6.0  # Approximate 1w change

                # Calculate the loss for price direction prediction (classification)
                if len(current_price_pred['immediate'].shape) > 1 and current_price_pred['immediate'].shape[0] >= min_size:
                    # Slice predictions to match the adjusted batch size
                    immediate_pred = current_price_pred['immediate'][:min_size]
                    midterm_pred = current_price_pred['midterm'][:min_size]
                    longterm_pred = current_price_pred['longterm'][:min_size]
                    price_values_pred = current_price_pred['values'][:min_size]

                    # Combined loss (primary + auxiliary with lower weight)
                    loss = q_loss + 0.3 * extrema_loss
                    # Compute losses for each task
                    immediate_loss = nn.CrossEntropyLoss()(immediate_pred, immediate_labels)
                    midterm_loss = nn.CrossEntropyLoss()(midterm_pred, midterm_labels)
                    longterm_loss = nn.CrossEntropyLoss()(longterm_pred, longterm_labels)

                    # Log separate loss components occasionally
                    if random.random() < 0.01:  # Log 1% of the time to avoid flooding the log
                        logger.info(f"Mixed precision training losses: Q-loss={q_loss.item():.4f}, Extrema-loss={extrema_loss.item():.4f}")
                    else:
                        # Fall back to just the Q-value loss
                        loss = q_loss
                    # MSE loss for price value regression
                    price_value_loss = nn.MSELoss()(price_values_pred, price_value_targets)

                    # Combine all price prediction losses
                    price_loss = immediate_loss + 0.7 * midterm_loss + 0.5 * longterm_loss + 0.3 * price_value_loss

                    # Create extrema labels (same as before)
                    extrema_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 2  # Default: neither

                    # Identify potential bottoms (significant negative change)
                    bottoms = (immediate_changes < -0.003)
                    extrema_labels[bottoms] = 0

                    # Identify potential tops (significant positive change)
                    tops = (immediate_changes > 0.003)
                    extrema_labels[tops] = 1

                    # Calculate extrema prediction loss
                    if len(current_extrema_pred.shape) > 1 and current_extrema_pred.shape[0] >= min_size:
                        current_extrema_pred = current_extrema_pred[:min_size]
                        extrema_loss = nn.CrossEntropyLoss()(current_extrema_pred, extrema_labels)

                    # Combined loss with all components
                    # Primary task: Q-value learning (RL objective)
                    # Secondary tasks: extrema detection and price prediction (supervised objectives)
                    loss = q_loss + 0.3 * extrema_loss + 0.3 * price_loss

                    # Log loss components occasionally
                    if random.random() < 0.01:  # Log 1% of the time
                        logger.info(
                            f"Mixed precision losses: Q-loss={q_loss.item():.4f}, "
                            f"Extrema-loss={extrema_loss.item():.4f}, "
                            f"Price-loss={price_loss.item():.4f}"
                        )
            except Exception as e:
                # Fallback if price extraction fails
                logger.warning(f"Failed to calculate extrema loss: {str(e)}. Using only Q-value loss.")
                logger.warning(f"Failed to calculate price prediction loss: {str(e)}. Using only Q-value loss.")
                # Just use the Q-value loss
                loss = q_loss

        # Backward pass with scaled gradients
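# Editor's sketch (assumption: the diff truncates here; this is the standard
# torch.cuda.amp pattern that typically follows a scaled forward pass, with
# `self.scaler` assumed to be a torch.cuda.amp.GradScaler created in __init__):
#     self.optimizer.zero_grad()
#     self.scaler.scale(loss).backward()
#     self.scaler.step(self.optimizer)
#     self.scaler.update()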