predict price direction

Dobromir Popov
2025-07-27 23:20:47 +03:00
parent dfa18035f1
commit 39267697f3
4 changed files with 572 additions and 101 deletions


@ -4,7 +4,7 @@ import torch.optim as optim
import numpy as np
from collections import deque
import random
from typing import Tuple, List
from typing import Tuple, List, Dict, Any
import os
import sys
import logging
@ -84,8 +84,8 @@ class DQNNetwork(nn.Module):
nn.Linear(512, 4) # trending, ranging, volatile, mixed
)
# Price prediction head
self.price_head = nn.Sequential(
# Price direction prediction head - outputs direction and confidence
self.price_direction_head = nn.Sequential(
nn.Linear(2048, 1024),
nn.LayerNorm(1024),
nn.ReLU(inplace=True),
@ -93,9 +93,14 @@ class DQNNetwork(nn.Module):
nn.Linear(1024, 512),
nn.LayerNorm(512),
nn.ReLU(inplace=True),
nn.Linear(512, 3) # short, medium, long term price direction
nn.Linear(512, 2) # [direction, confidence]
)
# Direction activation (tanh for -1 to 1)
self.direction_activation = nn.Tanh()
# Confidence activation (sigmoid for 0 to 1)
self.confidence_activation = nn.Sigmoid()
# Volatility prediction head
self.volatility_head = nn.Sequential(
nn.Linear(2048, 1024),
@ -105,7 +110,7 @@ class DQNNetwork(nn.Module):
nn.Linear(1024, 256),
nn.LayerNorm(256),
nn.ReLU(inplace=True),
nn.Linear(256, 1) # predicted volatility
nn.Linear(256, 4) # predicted volatility for 4 timeframes
)
# Main Q-value head (dueling architecture)
@ -162,7 +167,13 @@ class DQNNetwork(nn.Module):
# Multiple prediction heads
regime_pred = self.regime_head(features)
price_pred = self.price_head(features)
price_direction_raw = self.price_direction_head(features)
# Apply separate activations to direction and confidence
direction = self.direction_activation(price_direction_raw[:, 0:1]) # -1 to 1
confidence = self.confidence_activation(price_direction_raw[:, 1:2]) # 0 to 1
price_direction_pred = torch.cat([direction, confidence], dim=1) # [batch, 2]
volatility_pred = self.volatility_head(features)
# Dueling Q-network
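The forward pass above splits the two-unit head column-wise and applies separate activations to each column. A minimal, self-contained sketch of that split (illustrative only; the batch size and use of functional activations are placeholders, not taken from the commit):

import torch

raw = torch.randn(8, 2)                     # stand-in for price_direction_head output, [batch, 2]
direction = torch.tanh(raw[:, 0:1])         # column 0 -> direction in [-1, 1]
confidence = torch.sigmoid(raw[:, 1:2])     # column 1 -> confidence in [0, 1]
price_direction_pred = torch.cat([direction, confidence], dim=1)   # back to [batch, 2]
assert price_direction_pred.shape == (8, 2)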
@ -172,7 +183,7 @@ class DQNNetwork(nn.Module):
# Combine value and advantage for Q-values
q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
return q_values, regime_pred, price_pred, volatility_pred, features
return q_values, regime_pred, price_direction_pred, volatility_pred, features
def act(self, state, explore=True):
"""
@ -196,7 +207,11 @@ class DQNNetwork(nn.Module):
state = state.unsqueeze(0)
with torch.no_grad():
q_values, regime_pred, price_pred, volatility_pred, features = self.forward(state)
q_values, regime_pred, price_direction_pred, volatility_pred, features = self.forward(state)
# Process price direction predictions
if price_direction_pred is not None:
self.process_price_direction_predictions(price_direction_pred)
# Get action probabilities using softmax
action_probs = F.softmax(q_values, dim=1)
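Callers still unpack a 5-tuple from forward(); the third element changes meaning from the old 3-way price_pred to the new [direction, confidence] pair, and the volatility head now emits four values. A hedged stand-in showing that return contract (the tiny network and its layer sizes below are illustrative, not the real DQNNetwork):

import torch
import torch.nn as nn

class TinyStub(nn.Module):
    """Illustrative stand-in mimicking the post-commit forward() return contract."""
    def __init__(self, state_dim=16):
        super().__init__()
        self.backbone = nn.Linear(state_dim, 64)
        self.q_head = nn.Linear(64, 3)                 # action count assumed for illustration
        self.regime_head = nn.Linear(64, 4)            # trending, ranging, volatile, mixed
        self.price_direction_head = nn.Linear(64, 2)   # [direction, confidence]
        self.volatility_head = nn.Linear(64, 4)        # one value per timeframe

    def forward(self, x):
        f = torch.relu(self.backbone(x))
        raw = self.price_direction_head(f)
        price_direction = torch.cat([torch.tanh(raw[:, 0:1]), torch.sigmoid(raw[:, 1:2])], dim=1)
        return self.q_head(f), self.regime_head(f), price_direction, self.volatility_head(f), f

net = TinyStub()
state = torch.randn(1, 16)
with torch.no_grad():
    q_values, regime_pred, price_direction_pred, volatility_pred, features = net(state)
print(price_direction_pred)   # [[direction in [-1, 1], confidence in [0, 1]]]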
@ -332,23 +347,10 @@ class DQNAgent:
self.recent_prices = deque(maxlen=20)
self.recent_rewards = deque(maxlen=100)
# Price prediction tracking
self.last_price_pred = {
'immediate': {
'direction': 1, # Default to "sideways"
'confidence': 0.0,
'change': 0.0
},
'midterm': {
'direction': 1, # Default to "sideways"
'confidence': 0.0,
'change': 0.0
},
'longterm': {
'direction': 1, # Default to "sideways"
'confidence': 0.0,
'change': 0.0
}
# Price direction tracking - stores direction and confidence
self.last_price_direction = {
'direction': 0.0, # Single value between -1 and 1
'confidence': 0.0 # Single value between 0 and 1
}
# Store separate memory for price direction examples
@ -521,25 +523,6 @@ class DQNAgent:
logger.error(f"Error saving DQN checkpoint: {e}")
return False
# Price prediction tracking
self.last_price_pred = {
'immediate': {
'direction': 1, # Default to "sideways"
'confidence': 0.0,
'change': 0.0
},
'midterm': {
'direction': 1, # Default to "sideways"
'confidence': 0.0,
'change': 0.0
},
'longterm': {
'direction': 1, # Default to "sideways"
'confidence': 0.0,
'change': 0.0
}
}
# Store separate memory for price direction examples
self.price_movement_memory = [] # For storing examples of clear price movements
@ -811,6 +794,92 @@ class DQNAgent:
logger.error(f"Error in act_with_confidence: {e}")
# Return default action with low confidence
return 1, 0.1, [0.45, 0.55] # Default to HOLD action
def process_price_direction_predictions(self, price_direction_pred: torch.Tensor) -> Dict[str, float]:
"""
Process price direction predictions and convert to standardized format
Args:
price_direction_pred: Tensor of shape (batch_size, 2) containing [direction, confidence]
Returns:
Dict with direction (-1 to 1) and confidence (0 to 1)
"""
try:
if price_direction_pred is None or price_direction_pred.numel() == 0:
return self.last_price_direction
# Extract direction and confidence values
direction_value = float(price_direction_pred[0, 0].item()) # -1 to 1
confidence_value = float(price_direction_pred[0, 1].item()) # 0 to 1
# Update last price direction
self.last_price_direction = {
'direction': direction_value,
'confidence': confidence_value
}
return self.last_price_direction
except Exception as e:
logger.error(f"Error processing price direction predictions: {e}")
return self.last_price_direction
def get_price_direction_vector(self) -> Dict[str, float]:
"""
Get the current price direction and confidence
Returns:
Dict with direction (-1 to 1) and confidence (0 to 1)
"""
return self.last_price_direction
def get_price_direction_summary(self) -> Dict[str, Any]:
"""
Get a summary of price direction prediction
Returns:
Dict containing direction and confidence information
"""
try:
direction_value = self.last_price_direction['direction']
confidence_value = self.last_price_direction['confidence']
# Convert to discrete direction
if direction_value > 0.1:
direction_label = "UP"
discrete_direction = 1
elif direction_value < -0.1:
direction_label = "DOWN"
discrete_direction = -1
else:
direction_label = "SIDEWAYS"
discrete_direction = 0
return {
'direction_value': float(direction_value),
'confidence_value': float(confidence_value),
'direction_label': direction_label,
'discrete_direction': discrete_direction,
'strength': abs(float(direction_value)),
'weighted_strength': abs(float(direction_value)) * float(confidence_value)
}
except Exception as e:
logger.error(f"Error calculating price direction summary: {e}")
return {
'direction_value': 0.0,
'confidence_value': 0.0,
'direction_label': "SIDEWAYS",
'discrete_direction': 0,
'strength': 0.0,
'weighted_strength': 0.0
}
except Exception as e:
logger.error(f"Error in act_with_confidence: {e}")
# Return default action with low confidence
return 1, 0.1, [0.45, 0.55] # Default to HOLD action
def _determine_action_with_position_management(self, sell_conf, buy_conf, current_price, market_context, explore):
"""
@ -1032,11 +1101,8 @@ class DQNAgent:
logger.error(f"Error converting experiences to tensors: {e}")
return 0.0
# Choose training method based on precision mode
if self.use_mixed_precision:
loss = self._replay_mixed_precision(states, actions, rewards, next_states, dones)
else:
loss = self._replay_standard(states, actions, rewards, next_states, dones)
# Always use standard training to fix gradient issues
loss = self._replay_standard(states, actions, rewards, next_states, dones)
# Update epsilon
if self.epsilon > self.epsilon_min:
@ -1208,9 +1274,33 @@ class DQNAgent:
q_loss = self.criterion(current_q_values, target_q_values.detach())
# Use only Q-loss for now to ensure clean gradients
# Calculate auxiliary losses and add to Q-loss
total_loss = q_loss
# Add auxiliary losses if available
try:
# Get additional predictions from forward pass
if isinstance(q_values_output, tuple) and len(q_values_output) >= 5:
current_regime_pred = q_values_output[1]
current_price_pred = q_values_output[2]
current_volatility_pred = q_values_output[3]
current_extrema_pred = current_regime_pred # Use regime as extrema proxy for now
# Price direction loss
if current_price_pred is not None and current_price_pred.shape[0] > 0:
price_direction_loss = self._calculate_price_direction_loss(current_price_pred, rewards, actions)
if price_direction_loss is not None:
total_loss = total_loss + 0.2 * price_direction_loss
# Extrema loss
if current_extrema_pred is not None and current_extrema_pred.shape[0] > 0:
extrema_loss = self._calculate_extrema_loss(current_extrema_pred, rewards, actions)
if extrema_loss is not None:
total_loss = total_loss + 0.1 * extrema_loss
except Exception as e:
logger.debug(f"Could not add auxiliary loss in standard training: {e}")
# Reset gradients
self.optimizer.zero_grad()
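With the auxiliary heads wired in, the training objective becomes the Q-loss plus small fixed-weight terms (0.2 for price direction, 0.1 for extrema). A toy numeric sketch of that weighting, with invented example values:

q_loss = 1.00                 # invented example values
price_direction_loss = 0.40
extrema_loss = 0.90

total_loss = q_loss + 0.2 * price_direction_loss + 0.1 * extrema_loss
print(total_loss)             # 1.17: the auxiliary terms nudge, but do not dominate, the objective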
@ -1309,13 +1399,17 @@ class DQNAgent:
# Add auxiliary losses if available
try:
# Price direction loss
if current_price_pred is not None and current_price_pred.shape[0] > 0:
price_direction_loss = self._calculate_price_direction_loss(current_price_pred, rewards, actions)
if price_direction_loss is not None:
loss = loss + 0.2 * price_direction_loss
# Extrema loss
if current_extrema_pred is not None and current_extrema_pred.shape[0] > 0:
# Simple extrema targets
with torch.no_grad():
extrema_targets = torch.ones(current_extrema_pred.shape[0], dtype=torch.long, device=current_extrema_pred.device) * 2
extrema_loss = F.cross_entropy(current_extrema_pred, extrema_targets)
loss = loss + 0.1 * extrema_loss
extrema_loss = self._calculate_extrema_loss(current_extrema_pred, rewards, actions)
if extrema_loss is not None:
loss = loss + 0.1 * extrema_loss
except Exception as e:
logger.debug(f"Could not add auxiliary loss in mixed precision: {e}")
@ -1649,6 +1743,95 @@ class DQNAgent:
'exit_threshold': self.exit_confidence_threshold
}
def _calculate_price_direction_loss(self, price_direction_pred: torch.Tensor, rewards: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
"""
Calculate loss for price direction predictions
Args:
price_direction_pred: Tensor of shape [batch, 2] containing [direction, confidence]
rewards: Tensor of shape [batch] containing rewards
actions: Tensor of shape [batch] containing actions
Returns:
Price direction loss tensor
"""
try:
if price_direction_pred.size(1) != 2:
return None
batch_size = price_direction_pred.size(0)
# Extract direction and confidence predictions
direction_pred = price_direction_pred[:, 0] # -1 to 1
confidence_pred = price_direction_pred[:, 1] # 0 to 1
# Create targets based on rewards and actions
with torch.no_grad():
# Direction targets: 1 if reward > 0 and action is BUY, -1 if reward > 0 and action is SELL, 0 otherwise
direction_targets = torch.zeros(batch_size, device=price_direction_pred.device)
for i in range(batch_size):
if rewards[i] > 0.01: # Positive reward threshold
if actions[i] == 0: # BUY action
direction_targets[i] = 1.0 # UP
elif actions[i] == 1: # SELL action
direction_targets[i] = -1.0 # DOWN
# else: targets remain 0 (sideways)
# Confidence targets: based on reward magnitude (higher reward = higher confidence)
confidence_targets = torch.abs(rewards).clamp(0, 1)
# Calculate losses for each component
direction_loss = F.mse_loss(direction_pred, direction_targets)
confidence_loss = F.mse_loss(confidence_pred, confidence_targets)
# Combined loss (direction is more important than confidence)
total_loss = direction_loss + 0.3 * confidence_loss
return total_loss
except Exception as e:
logger.debug(f"Error calculating price direction loss: {e}")
return None
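# Illustrative sketch (not part of this commit): a vectorized equivalent of the
# per-sample loop above. Profitable BUYs (action 0) map to +1, profitable SELLs
# (action 1) to -1, everything else to 0; confidence targets clamp |reward| to [0, 1].
import torch
rewards = torch.tensor([0.5, -0.2, 0.3, 0.0])
actions = torch.tensor([0, 0, 1, 1])
profitable = rewards > 0.01
direction_targets = torch.where(profitable & (actions == 0), torch.tensor(1.0),
                    torch.where(profitable & (actions == 1), torch.tensor(-1.0), torch.tensor(0.0)))
confidence_targets = rewards.abs().clamp(0, 1)
print(direction_targets)      # tensor([ 1.,  0., -1.,  0.])
print(confidence_targets)     # tensor([0.5000, 0.2000, 0.3000, 0.0000])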
def _calculate_extrema_loss(self, extrema_pred: torch.Tensor, rewards: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
"""
Calculate loss for extrema predictions
Args:
extrema_pred: Extrema predictions
rewards: Tensor containing rewards
actions: Tensor containing actions
Returns:
Extrema loss tensor
"""
try:
batch_size = extrema_pred.size(0)
# Create targets based on reward patterns
with torch.no_grad():
extrema_targets = torch.ones(batch_size, dtype=torch.long, device=extrema_pred.device) * 2 # Default to "neither"
for i in range(batch_size):
# High positive reward suggests we're at a good entry point (potential bottom for BUY, top for SELL)
if rewards[i] > 0.05:
if actions[i] == 0: # BUY action
extrema_targets[i] = 0 # Bottom
elif actions[i] == 1: # SELL action
extrema_targets[i] = 1 # Top
# Calculate cross-entropy loss
if extrema_pred.size(1) >= 3:
extrema_loss = F.cross_entropy(extrema_pred[:, :3], extrema_targets)
else:
extrema_loss = F.cross_entropy(extrema_pred, extrema_targets)
return extrema_loss
except Exception as e:
logger.debug(f"Error calculating extrema loss: {e}")
return None
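# Toy illustration (not from the diff) of the extrema labeling above: rewards over
# 0.05 tag BUYs as bottoms (class 0) and SELLs as tops (class 1); all else is "neither" (2).
import torch
import torch.nn.functional as F
rewards = torch.tensor([0.10, 0.02, 0.20, -0.30])
actions = torch.tensor([0, 0, 1, 2])
targets = torch.full((4,), 2, dtype=torch.long)
targets[(rewards > 0.05) & (actions == 0)] = 0   # bottom
targets[(rewards > 0.05) & (actions == 1)] = 1   # top
print(targets)                                    # tensor([0, 2, 1, 2])
logits = torch.randn(4, 3)                        # stand-in 3-class extrema predictions
print(F.cross_entropy(logits, targets))           # scalar auxiliary loss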
def get_enhanced_training_stats(self):
"""Get enhanced RL training statistics with detailed metrics (from EnhancedDQNAgent)"""
return {