Fix model action mappings, dashboard updates, and trading logic
@@ -111,6 +111,9 @@ class SpatialAttentionBlock(nn.Module):
# Avoid in-place operation by creating new tensor
return torch.mul(x, attention)

# TODO:
# 1. Add pivot points array as input
# 2. Change output to be next pivot point (we'll need to adjust training as well)
class EnhancedCNNModel(nn.Module):
"""
Much larger and more sophisticated CNN architecture for trading
@@ -125,7 +128,7 @@ class EnhancedCNNModel(nn.Module):
def __init__(self,
input_size: int = 60,
feature_dim: int = 50,
output_size: int = 2, # BUY/SELL for 2-action system
output_size: int = 3, # BUY/SELL/HOLD for 3-action system
base_channels: int = 256, # Increased from 128 to 256
num_blocks: int = 12, # Increased from 6 to 12
num_attention_heads: int = 16, # Increased from 8 to 16
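The constructor change above widens the classifier head from two outputs to three. As a point of reference, a minimal stand-in head (not the repository's EnhancedCNNModel) with output_size=3 under the new BUY/SELL/HOLD convention:

```python
import torch
import torch.nn as nn

# Stand-in classifier head: 3 logits for the fixed mapping 0=BUY, 1=SELL, 2=HOLD.
feature_dim, output_size = 50, 3
head = nn.Linear(feature_dim, output_size)

features = torch.randn(1, feature_dim)        # illustrative pooled features
probs = torch.softmax(head(features), dim=1)  # one probability per action
print(probs.shape)                            # torch.Size([1, 3])
```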
@@ -479,9 +482,13 @@ class EnhancedCNNModel(nn.Module):
action = int(np.argmax(probs))
action_confidence = float(probs[action])

# FIXED ACTION MAPPING: 0=BUY, 1=SELL, 2=HOLD
action_names = ['BUY', 'SELL', 'HOLD']
action_name = action_names[action] if action < len(action_names) else 'HOLD'

return {
'action': action,
'action_name': 'BUY' if action == 0 else 'SELL',
'action_name': action_name,
'confidence': float(confidence),
'action_confidence': action_confidence,
'probabilities': probs.tolist(),
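The hunk above replaces the hard-coded 'BUY'/'SELL' string with a lookup into a fixed three-entry action list. A self-contained sketch of that mapping; the probability vector here is illustrative:

```python
import numpy as np

ACTION_NAMES = ['BUY', 'SELL', 'HOLD']  # fixed mapping: 0=BUY, 1=SELL, 2=HOLD

def probs_to_action(probs: np.ndarray) -> dict:
    """Map a 3-way probability vector to an action dict, defaulting to HOLD."""
    action = int(np.argmax(probs))
    name = ACTION_NAMES[action] if action < len(ACTION_NAMES) else 'HOLD'
    return {
        'action': action,
        'action_name': name,
        'action_confidence': float(probs[action]),
        'probabilities': probs.tolist(),
    }

# Example: a HOLD-leaning output
print(probs_to_action(np.array([0.2, 0.3, 0.5])))  # -> action 2, 'HOLD'
```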
@@ -965,21 +972,21 @@ class CNNModel:
if len(trend_data) > 1:
trend = (trend_data[-1] - trend_data[0]) / trend_data[0] if trend_data[0] != 0 else 0

# Map trend to action
# Map trend to action - FIXED ACTION MAPPING: 0=BUY, 1=SELL
if trend > 0.001: # Upward trend > 0.1%
action = 1 # BUY
action = 0 # BUY (action 0)
confidence = min(0.9, 0.5 + abs(trend) * 10)
elif trend < -0.001: # Downward trend < -0.1%
action = 0 # SELL
action = 1 # SELL (action 1)
confidence = min(0.9, 0.5 + abs(trend) * 10)
else:
action = 0 # Default to SELL for unclear trend
action = 2 # Default to HOLD for unclear trend
confidence = 0.3
else:
action = 0
action = 2 # HOLD for unknown trend
confidence = 0.3
else:
action = 0
action = 2 # HOLD for insufficient data
confidence = 0.3

# Create probabilities
@@ -1000,7 +1007,7 @@ class CNNModel:
except Exception as e:
logger.error(f"Error in fallback prediction: {e}")
# Final fallback - conservative prediction
pred_class = np.array([0]) # SELL
pred_class = np.array([2]) # HOLD (safe default)
proba = np.ones(self.output_size) / self.output_size # Equal probabilities
pred_proba = np.array([proba])
return pred_class, pred_proba
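For reference, a compact standalone version of the corrected fallback rule, assuming the same ±0.1% trend thresholds shown above; trend_fallback is a hypothetical helper name, not a function in the repository:

```python
import numpy as np

def trend_fallback(trend_data: np.ndarray):
    """Fallback prediction: 0=BUY on an up-trend, 1=SELL on a down-trend, 2=HOLD otherwise."""
    if len(trend_data) > 1 and trend_data[0] != 0:
        trend = (trend_data[-1] - trend_data[0]) / trend_data[0]
        if trend > 0.001:        # upward trend > 0.1%
            return 0, min(0.9, 0.5 + abs(trend) * 10)
        if trend < -0.001:       # downward trend < -0.1%
            return 1, min(0.9, 0.5 + abs(trend) * 10)
    return 2, 0.3                # HOLD when the trend is flat or data is insufficient

print(trend_fallback(np.array([100.0, 100.5])))  # -> (0, 0.55)
```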
@@ -578,7 +578,7 @@ class DQNAgent:
market_context: Additional market context for decision making

Returns:
int: Action (0=SELL, 1=BUY) or None if should hold position
int: Action (0=BUY, 1=SELL, 2=HOLD) or None if should hold position
"""

# Convert state to tensor
@@ -602,8 +602,9 @@ class DQNAgent:
if q_values.dim() == 1:
q_values = q_values.unsqueeze(0)

sell_confidence = torch.softmax(q_values, dim=1)[0, 0].item()
buy_confidence = torch.softmax(q_values, dim=1)[0, 1].item()
# FIXED ACTION MAPPING: 0=BUY, 1=SELL, 2=HOLD
buy_confidence = torch.softmax(q_values, dim=1)[0, 0].item()
sell_confidence = torch.softmax(q_values, dim=1)[0, 1].item()

# Determine action based on current position and confidence thresholds
action = self._determine_action_with_position_management(
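The corrected confidence extraction reads BUY from column 0 and SELL from column 1 of the softmaxed Q-values. A minimal sketch with made-up Q-values standing in for the policy network output:

```python
import torch

# Illustrative Q-values for one state; in the agent these come from the policy network.
q_values = torch.tensor([[0.8, 0.1, 0.3]])  # columns follow the fixed mapping 0=BUY, 1=SELL, 2=HOLD

probs = torch.softmax(q_values, dim=1)
buy_confidence = probs[0, 0].item()   # index 0 -> BUY
sell_confidence = probs[0, 1].item()  # index 1 -> SELL
hold_confidence = probs[0, 2].item()  # index 2 -> HOLD

print(f"BUY={buy_confidence:.3f} SELL={sell_confidence:.3f} HOLD={hold_confidence:.3f}")
```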
@@ -669,68 +670,68 @@ class DQNAgent:
if explore and np.random.random() <= self.epsilon:
return np.random.choice([0, 1])

# Get the dominant signal
dominant_action = 0 if sell_conf > buy_conf else 1
dominant_confidence = max(sell_conf, buy_conf)
# Get the dominant signal - FIXED ACTION MAPPING: 0=BUY, 1=SELL
dominant_action = 0 if buy_conf > sell_conf else 1
dominant_confidence = max(buy_conf, sell_conf)

# Decision logic based on current position
if self.current_position == 0: # No position - need high confidence to enter
if dominant_confidence >= self.entry_confidence_threshold:
# Strong enough signal to enter position
if dominant_action == 1: # BUY signal
if dominant_action == 0: # BUY signal (action 0)
self.current_position = 1.0
self.position_entry_price = current_price
self.position_entry_time = time.time()
logger.info(f"ENTERING LONG position at {current_price:.4f} with confidence {dominant_confidence:.4f}")
return 1
else: # SELL signal
return 0 # Return BUY action (0)
else: # SELL signal (action 1)
self.current_position = -1.0
self.position_entry_price = current_price
self.position_entry_time = time.time()
logger.info(f"ENTERING SHORT position at {current_price:.4f} with confidence {dominant_confidence:.4f}")
return 0
return 1 # Return SELL action (1)
else:
# Not confident enough to enter position
return None

elif self.current_position > 0: # Long position
if dominant_action == 0 and dominant_confidence >= self.exit_confidence_threshold:
# SELL signal with enough confidence to close long position
if dominant_action == 1 and dominant_confidence >= self.exit_confidence_threshold:
# SELL signal (action 1) with enough confidence to close long position
pnl = (current_price - self.position_entry_price) / self.position_entry_price if current_price and self.position_entry_price else 0
logger.info(f"CLOSING LONG position at {current_price:.4f} with confidence {dominant_confidence:.4f}, PnL: {pnl:.4f}")
self.current_position = 0.0
self.position_entry_price = 0.0
self.position_entry_time = None
return 0
elif dominant_action == 0 and dominant_confidence >= self.entry_confidence_threshold:
return 1 # Return SELL action (1)
elif dominant_action == 1 and dominant_confidence >= self.entry_confidence_threshold:
# Very strong SELL signal - close long and enter short
pnl = (current_price - self.position_entry_price) / self.position_entry_price if current_price and self.position_entry_price else 0
logger.info(f"FLIPPING from LONG to SHORT at {current_price:.4f} with confidence {dominant_confidence:.4f}, PnL: {pnl:.4f}")
self.current_position = -1.0
self.position_entry_price = current_price
self.position_entry_time = time.time()
return 0
return 1 # Return SELL action (1)
else:
# Hold the long position
return None

elif self.current_position < 0: # Short position
if dominant_action == 1 and dominant_confidence >= self.exit_confidence_threshold:
# BUY signal with enough confidence to close short position
if dominant_action == 0 and dominant_confidence >= self.exit_confidence_threshold:
# BUY signal (action 0) with enough confidence to close short position
pnl = (self.position_entry_price - current_price) / self.position_entry_price if current_price and self.position_entry_price else 0
logger.info(f"CLOSING SHORT position at {current_price:.4f} with confidence {dominant_confidence:.4f}, PnL: {pnl:.4f}")
self.current_position = 0.0
self.position_entry_price = 0.0
self.position_entry_time = None
return 1
elif dominant_action == 1 and dominant_confidence >= self.entry_confidence_threshold:
return 0 # Return BUY action (0)
elif dominant_action == 0 and dominant_confidence >= self.entry_confidence_threshold:
# Very strong BUY signal - close short and enter long
pnl = (self.position_entry_price - current_price) / self.position_entry_price if current_price and self.position_entry_price else 0
logger.info(f"FLIPPING from SHORT to LONG at {current_price:.4f} with confidence {dominant_confidence:.4f}, PnL: {pnl:.4f}")
self.current_position = 1.0
self.position_entry_price = current_price
self.position_entry_time = time.time()
return 1
return 0 # Return BUY action (0)
else:
# Hold the short position
return None
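A simplified, self-contained sketch of the entry/exit logic above under the corrected mapping (0=BUY, 1=SELL). The decide helper and its thresholds are illustrative stand-ins for the agent's _determine_action_with_position_management and its configured thresholds; position and PnL bookkeeping is omitted:

```python
def decide(position: float, buy_conf: float, sell_conf: float,
           entry_threshold: float = 0.7, exit_threshold: float = 0.3):
    """Return 0 (BUY), 1 (SELL) or None (keep the current position)."""
    dominant_action = 0 if buy_conf > sell_conf else 1
    confidence = max(buy_conf, sell_conf)

    if position == 0:  # flat: need a strong signal to enter
        return dominant_action if confidence >= entry_threshold else None
    if position > 0:   # long: a SELL signal closes (or flips) it
        return 1 if dominant_action == 1 and confidence >= exit_threshold else None
    # short: a BUY signal closes (or flips) it
    return 0 if dominant_action == 0 and confidence >= exit_threshold else None

print(decide(0, 0.8, 0.2))    # -> 0, enter long
print(decide(1.0, 0.6, 0.4))  # -> None, keep holding the long
```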
@@ -792,246 +793,157 @@ class DQNAgent:
indices = np.random.choice(len(self.memory), size=min(self.batch_size, len(self.memory)), replace=False)
experiences = [self.memory[i] for i in indices]

# Sanitize and stack states and next_states
sanitized_states = []
sanitized_next_states = []
sanitized_experiences = []
# Validate experiences before processing
if not experiences or len(experiences) == 0:
logger.warning("No experiences provided for training")
return 0.0

for i, e in enumerate(experiences):
try:
# Extract experience components
state, action, reward, next_state, done = e

# Sanitize state - convert any dict/object to float arrays
state = self._sanitize_state_data(state)
next_state = self._sanitize_state_data(next_state)

# Sanitize action - ensure it's an integer
if isinstance(action, dict):
# If action is a dict, try to extract action value
action = action.get('action', action.get('value', 0))
action = int(action) if not isinstance(action, (int, np.integer)) else action

# Sanitize reward - ensure it's a float
if isinstance(reward, dict):
# If reward is a dict, try to extract reward value
reward = reward.get('reward', reward.get('value', 0.0))
reward = float(reward) if not isinstance(reward, (float, np.floating)) else reward

# Sanitize done - ensure it's a boolean/float
if isinstance(done, dict):
done = done.get('done', done.get('value', False))
done = bool(done) if not isinstance(done, (bool, np.bool_)) else done

# Convert state to proper numpy array
state = np.asarray(state, dtype=np.float32)
next_state = np.asarray(next_state, dtype=np.float32)

# Add to sanitized lists
sanitized_states.append(state)
sanitized_next_states.append(next_state)
sanitized_experiences.append((state, action, reward, next_state, done))

except Exception as ex:
print(f"[DQNAgent] Bad experience at index {i}: {ex}")
continue

if not sanitized_states or not sanitized_next_states:
print("[DQNAgent] No valid states in replay batch.")
return 0.0 # Return float instead of None for consistency

# Validate all states have the same dimensions before stacking
expected_dim = getattr(self, 'state_size', getattr(self, 'state_dim', 403))
if isinstance(expected_dim, tuple):
expected_dim = np.prod(expected_dim)

# Debug: Check what dimensions we're actually seeing
if sanitized_states:
actual_dims = [len(state) for state in sanitized_states[:5]] # Check first 5
logger.debug(f"DQN State dimensions - Expected: {expected_dim}, Actual samples: {actual_dims}")

# If all states have a consistent dimension different from expected, use that
unique_dims = list(set(len(state) for state in sanitized_states))
if len(unique_dims) == 1 and unique_dims[0] != expected_dim:
logger.warning(f"All states have dimension {unique_dims[0]} but expected {expected_dim}. Using actual dimension.")
expected_dim = unique_dims[0]

# Filter out states with wrong dimensions and fix them
valid_states = []
valid_next_states = []
# Sanitize and validate experiences
valid_experiences = []
for i, exp in enumerate(experiences):
try:
if len(exp) != 5:
logger.debug(f"Invalid experience format at index {i}: expected 5 elements, got {len(exp)}")
continue

state, action, reward, next_state, done = exp

# Validate state
state = self._validate_and_fix_state(state)
next_state = self._validate_and_fix_state(next_state)

if state is None or next_state is None:
continue

# Validate action
if isinstance(action, dict):
action = action.get('action', action.get('value', 0))
action = int(action) if action is not None else 0
action = max(0, min(action, self.n_actions - 1)) # Clamp to valid range

# Validate reward
if isinstance(reward, dict):
reward = reward.get('reward', reward.get('value', 0.0))
reward = float(reward) if reward is not None else 0.0

# Validate done flag
done = bool(done) if done is not None else False

valid_experiences.append((state, action, reward, next_state, done))

except Exception as e:
logger.debug(f"Error processing experience {i}: {e}")
continue

for i, (state, next_state, exp) in enumerate(zip(sanitized_states, sanitized_next_states, sanitized_experiences)):
# Ensure states have correct dimensions
if len(state) != expected_dim:
logger.debug(f"Fixing state dimension: {len(state)} -> {expected_dim}")
if len(state) < expected_dim:
# Pad with zeros
padded_state = np.zeros(expected_dim, dtype=np.float32)
padded_state[:len(state)] = state
state = padded_state
else:
# Truncate
state = state[:expected_dim]

if len(next_state) != expected_dim:
logger.debug(f"Fixing next_state dimension: {len(next_state)} -> {expected_dim}")
if len(next_state) < expected_dim:
# Pad with zeros
padded_next_state = np.zeros(expected_dim, dtype=np.float32)
padded_next_state[:len(next_state)] = next_state
next_state = padded_next_state
else:
# Truncate
next_state = next_state[:expected_dim]

valid_states.append(state)
valid_next_states.append(next_state)
valid_experiences.append(exp)

if not valid_states:
print("[DQNAgent] No valid states after dimension fixing.")
if len(valid_experiences) == 0:
logger.warning("No valid experiences after sanitization")
return 0.0

# Use validated experiences for training
experiences = valid_experiences
states = torch.FloatTensor(np.stack(valid_states)).to(self.device)
next_states = torch.FloatTensor(np.stack(valid_next_states)).to(self.device)
# Extract components
states, actions, rewards, next_states, dones = zip(*experiences)

# Choose appropriate replay method
if self.use_mixed_precision:
# Convert experiences to tensors for mixed precision
actions = torch.LongTensor(np.array([e[1] for e in experiences])).to(self.device)
rewards = torch.FloatTensor(np.array([e[2] for e in experiences])).to(self.device)
dones = torch.FloatTensor(np.array([e[4] for e in experiences])).to(self.device)
# Convert to tensors with proper validation
try:
states = torch.FloatTensor(np.array(states)).to(self.device)
actions = torch.LongTensor(np.array(actions)).to(self.device)
rewards = torch.FloatTensor(np.array(rewards)).to(self.device)
next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
dones = torch.FloatTensor(np.array(dones)).to(self.device)

# Use mixed precision replay
# Final validation of tensor shapes
if states.shape[0] == 0 or actions.shape[0] == 0:
logger.warning("Empty tensors after conversion")
return 0.0

# Ensure all tensors have the same batch size
batch_size = states.shape[0]
if not all(tensor.shape[0] == batch_size for tensor in [actions, rewards, next_states, dones]):
logger.warning("Inconsistent batch sizes across tensors")
return 0.0

except Exception as e:
logger.error(f"Error converting experiences to tensors: {e}")
return 0.0

# Choose training method based on precision mode
if self.use_mixed_precision:
loss = self._replay_mixed_precision(states, actions, rewards, next_states, dones)
else:
# Pass experiences directly to standard replay method
loss = self._replay_standard(experiences)

# Store loss for monitoring
loss = self._replay_standard(states, actions, rewards, next_states, dones)

# Update epsilon
if self.epsilon > self.epsilon_min:
self.epsilon *= self.epsilon_decay

# Update statistics
self.losses.append(loss)

# Track and decay epsilon
self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
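The rewritten replay() unpacks the validated experiences with zip(*experiences) and converts each component to a tensor before dispatching to a training step. A minimal sketch of that unpack-and-convert step with a dummy two-element batch; device and shapes are illustrative:

```python
import numpy as np
import torch

device = torch.device('cpu')  # illustrative; the agent uses self.device

# A tiny dummy batch of (state, action, reward, next_state, done) tuples.
experiences = [
    (np.zeros(4, dtype=np.float32), 0, 1.0, np.ones(4, dtype=np.float32), False),
    (np.ones(4, dtype=np.float32), 2, -0.5, np.zeros(4, dtype=np.float32), True),
]

states, actions, rewards, next_states, dones = zip(*experiences)
states = torch.FloatTensor(np.array(states)).to(device)
actions = torch.LongTensor(np.array(actions)).to(device)
rewards = torch.FloatTensor(np.array(rewards)).to(device)
next_states = torch.FloatTensor(np.array(next_states)).to(device)
dones = torch.FloatTensor(np.array(dones, dtype=np.float32)).to(device)

assert states.shape[0] == actions.shape[0] == rewards.shape[0]  # consistent batch size
print(states.shape, actions.shape)  # torch.Size([2, 4]) torch.Size([2])
```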
# Randomly decide if we should train on extrema points from special memory
if random.random() < 0.3 and len(self.extrema_memory) >= self.batch_size:
# Train specifically on extrema memory examples
extrema_indices = np.random.choice(len(self.extrema_memory), size=min(self.batch_size, len(self.extrema_memory)), replace=False)
extrema_batch = [self.extrema_memory[i] for i in extrema_indices]

# Sanitize extrema batch
sanitized_extrema = []
for e in extrema_batch:
try:
state, action, reward, next_state, done = e
state = self._sanitize_state_data(state)
next_state = self._sanitize_state_data(next_state)
state = np.asarray(state, dtype=np.float32)
next_state = np.asarray(next_state, dtype=np.float32)
sanitized_extrema.append((state, action, reward, next_state, done))
except:
continue

if sanitized_extrema:
# Extract tensors from extrema batch
extrema_states = torch.FloatTensor(np.array([e[0] for e in sanitized_extrema])).to(self.device)
extrema_actions = torch.LongTensor(np.array([e[1] for e in sanitized_extrema])).to(self.device)
extrema_rewards = torch.FloatTensor(np.array([e[2] for e in sanitized_extrema])).to(self.device)
extrema_next_states = torch.FloatTensor(np.array([e[3] for e in sanitized_extrema])).to(self.device)
extrema_dones = torch.FloatTensor(np.array([e[4] for e in sanitized_extrema])).to(self.device)

# Use a slightly reduced learning rate for extrema training
old_lr = self.optimizer.param_groups[0]['lr']
self.optimizer.param_groups[0]['lr'] = old_lr * 0.8

# Train on extrema memory
if self.use_mixed_precision:
extrema_loss = self._replay_mixed_precision(extrema_states, extrema_actions, extrema_rewards, extrema_next_states, extrema_dones)
else:
extrema_loss = self._replay_standard(sanitized_extrema)

# Reset learning rate
self.optimizer.param_groups[0]['lr'] = old_lr

# Log extrema loss
logger.info(f"Extra training on extrema points, loss: {extrema_loss:.4f}")

# Randomly train on price movement examples (similar to extrema)
if random.random() < 0.3 and len(self.price_movement_memory) >= self.batch_size:
# Train specifically on price movement memory examples
price_indices = np.random.choice(len(self.price_movement_memory), size=min(self.batch_size, len(self.price_movement_memory)), replace=False)
price_batch = [self.price_movement_memory[i] for i in price_indices]

# Sanitize price movement batch
sanitized_price = []
for e in price_batch:
try:
state, action, reward, next_state, done = e
state = self._sanitize_state_data(state)
next_state = self._sanitize_state_data(next_state)
state = np.asarray(state, dtype=np.float32)
next_state = np.asarray(next_state, dtype=np.float32)
sanitized_price.append((state, action, reward, next_state, done))
except:
continue

if sanitized_price:
# Extract tensors from price movement batch
price_states = torch.FloatTensor(np.array([e[0] for e in sanitized_price])).to(self.device)
price_actions = torch.LongTensor(np.array([e[1] for e in sanitized_price])).to(self.device)
price_rewards = torch.FloatTensor(np.array([e[2] for e in sanitized_price])).to(self.device)
price_next_states = torch.FloatTensor(np.array([e[3] for e in sanitized_price])).to(self.device)
price_dones = torch.FloatTensor(np.array([e[4] for e in sanitized_price])).to(self.device)

# Use a slightly reduced learning rate for price movement training
old_lr = self.optimizer.param_groups[0]['lr']
self.optimizer.param_groups[0]['lr'] = old_lr * 0.75

# Train on price movement memory
if self.use_mixed_precision:
price_loss = self._replay_mixed_precision(price_states, price_actions, price_rewards, price_next_states, price_dones)
else:
price_loss = self._replay_standard(sanitized_price)

# Reset learning rate
self.optimizer.param_groups[0]['lr'] = old_lr

# Log price movement loss
logger.info(f"Extra training on price movement examples, loss: {price_loss:.4f}")
if len(self.losses) > 1000:
self.losses = self.losses[-500:] # Keep only recent losses

return loss
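The extra passes over the extrema and price-movement memories temporarily lower the optimizer's learning rate and restore it afterwards. A self-contained sketch of that pattern; the Linear model, optimizer and batch are stand-ins for the agent's policy network and replay call:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 3)                      # stand-in for the policy network
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def train_step(batch_x, batch_y):
    """One supervised step; the agent would call its replay routine here instead."""
    optimizer.zero_grad()
    loss = nn.functional.cross_entropy(model(batch_x), batch_y)
    loss.backward()
    optimizer.step()
    return loss.item()

# Train on a special buffer (e.g. extrema examples) at 80% of the normal learning rate.
old_lr = optimizer.param_groups[0]['lr']
optimizer.param_groups[0]['lr'] = old_lr * 0.8
extrema_loss = train_step(torch.randn(8, 4), torch.randint(0, 3, (8,)))
optimizer.param_groups[0]['lr'] = old_lr     # always restore the original rate
print(f"extrema loss: {extrema_loss:.4f}")
```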
def _replay_standard(self, *args):
def _validate_and_fix_state(self, state):
"""Validate and fix state to ensure it has correct dimensions and no empty data"""
try:
# Convert to numpy if needed
if isinstance(state, torch.Tensor):
state = state.detach().cpu().numpy()
elif not isinstance(state, np.ndarray):
state = np.array(state, dtype=np.float32)

# Flatten if multi-dimensional
if state.ndim > 1:
state = state.flatten()

# Check for empty or invalid state
if state.size == 0:
logger.warning("Empty state detected, using default")
expected_size = getattr(self, 'state_size', 403)
if isinstance(expected_size, tuple):
expected_size = np.prod(expected_size)
return np.zeros(int(expected_size), dtype=np.float32)

# Check for NaN or infinite values
if np.any(np.isnan(state)) or np.any(np.isinf(state)):
logger.warning("NaN or infinite values in state, replacing with zeros")
state = np.nan_to_num(state, nan=0.0, posinf=1.0, neginf=-1.0)

# Ensure correct dimensions
expected_size = getattr(self, 'state_size', 403)
if isinstance(expected_size, tuple):
expected_size = np.prod(expected_size)
expected_size = int(expected_size)

if len(state) != expected_size:
if len(state) < expected_size:
# Pad with zeros
padded_state = np.zeros(expected_size, dtype=np.float32)
padded_state[:len(state)] = state
state = padded_state
else:
# Truncate
state = state[:expected_size]

return state.astype(np.float32)

except Exception as e:
logger.error(f"Error validating state: {e}")
# Return default state as fallback
expected_size = getattr(self, 'state_size', 403)
if isinstance(expected_size, tuple):
expected_size = np.prod(expected_size)
return np.zeros(int(expected_size), dtype=np.float32)
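A standalone version of the pad-or-truncate rule used by _validate_and_fix_state; the 403-element default mirrors the fallback above, but the helper name fix_state is hypothetical:

```python
import numpy as np

def fix_state(state, expected_size: int = 403) -> np.ndarray:
    """Flatten, sanitize and resize a state vector to a fixed length."""
    state = np.asarray(state, dtype=np.float32).flatten()
    if state.size == 0:
        return np.zeros(expected_size, dtype=np.float32)
    state = np.nan_to_num(state, nan=0.0, posinf=1.0, neginf=-1.0)
    if len(state) < expected_size:                # pad with zeros
        padded = np.zeros(expected_size, dtype=np.float32)
        padded[:len(state)] = state
        return padded
    return state[:expected_size]                  # truncate

print(fix_state([1.0, np.nan, 2.0], expected_size=5))  # -> [1. 0. 2. 0. 0.]
```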
def _replay_standard(self, states, actions, rewards, next_states, dones):
"""Standard training step without mixed precision"""
try:
# Support both (experiences,) and (states, actions, rewards, next_states, dones)
if len(args) == 1:
experiences = args[0]
# Use experiences if provided, otherwise sample from memory
if experiences is None:
# If memory is too small, skip training
if len(self.memory) < self.batch_size:
return 0.0
# Sample random mini-batch from memory
indices = np.random.choice(len(self.memory), size=min(self.batch_size, len(self.memory)), replace=False)
batch = [self.memory[i] for i in indices]
experiences = batch
# Unpack experiences
states, actions, rewards, next_states, dones = zip(*experiences)
states = torch.FloatTensor(np.array(states)).to(self.device)
actions = torch.LongTensor(np.array(actions)).to(self.device)
rewards = torch.FloatTensor(np.array(rewards)).to(self.device)
next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
dones = torch.FloatTensor(np.array(dones)).to(self.device)
elif len(args) == 5:
states, actions, rewards, next_states, dones = args
else:
raise ValueError("Invalid arguments to _replay_standard")
# Validate input tensors
if states.shape[0] == 0:
logger.warning("Empty batch in _replay_standard")
return 0.0

# Get current Q values using safe wrapper
current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self._safe_cnn_forward(self.policy_net, states)
@@ -1047,14 +959,14 @@ class DQNAgent:
next_q_values = target_q_values_all.gather(1, next_actions.unsqueeze(1)).squeeze(1)
else:
# Standard DQN: Use target network for both selection and evaluation
next_q_values, next_extrema_pred, next_price_pred, next_hidden_features, next_advanced_pred = self.target_net(next_states)
next_q_values, _, _, _, _ = self._safe_cnn_forward(self.target_net, next_states)
next_q_values = next_q_values.max(1)[0]

# Check for dimension mismatch between rewards and next_q_values
if rewards.shape[0] != next_q_values.shape[0]:
logger.warning(f"Shape mismatch detected in standard replay: rewards {rewards.shape}, next_q_values {next_q_values.shape}")
# Use the smaller size to prevent index error
min_size = min(rewards.shape[0], next_q_values.shape[0])
# Ensure tensor shapes are consistent
batch_size = states.shape[0]
if rewards.shape[0] != batch_size or next_q_values.shape[0] != batch_size:
logger.warning(f"Shape mismatch in replay: batch_size={batch_size}, rewards={rewards.shape}, next_q_values={next_q_values.shape}")
min_size = min(batch_size, rewards.shape[0], next_q_values.shape[0])
rewards = rewards[:min_size]
dones = dones[:min_size]
next_q_values = next_q_values[:min_size]
@@ -1063,70 +975,82 @@ class DQNAgent:
# Calculate target Q values
target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

# Compute loss for Q value
q_loss = self.criterion(current_q_values, target_q_values)
# Compute loss for Q value - ensure tensors require gradients
if not current_q_values.requires_grad:
logger.warning("Current Q values do not require gradients")
return 0.0

q_loss = self.criterion(current_q_values, target_q_values.detach())

# Try to compute extrema loss if possible
# Initialize total loss with Q loss
total_loss = q_loss

# Add auxiliary losses if available and valid
try:
# Get the target classes from extrema predictions
extrema_targets = torch.argmax(current_extrema_pred, dim=1).long()

# Compute extrema loss using cross-entropy - this is an auxiliary task
extrema_loss = F.cross_entropy(current_extrema_pred, extrema_targets)

# Combined loss with emphasis on Q-learning
total_loss = q_loss + 0.1 * extrema_loss
if current_extrema_pred is not None and current_extrema_pred.shape[0] > 0:
# Create simple extrema targets based on Q-values
with torch.no_grad():
extrema_targets = torch.ones(current_extrema_pred.shape[0], dtype=torch.long, device=current_extrema_pred.device) * 2 # Default to "neither"

extrema_loss = F.cross_entropy(current_extrema_pred, extrema_targets)
total_loss = total_loss + 0.1 * extrema_loss

except Exception as e:
logger.warning(f"Failed to calculate extrema loss: {str(e)}. Using only Q-value loss.")
total_loss = q_loss

logger.debug(f"Could not calculate auxiliary loss: {e}")

# Reset gradients
self.optimizer.zero_grad()

# Ensure loss requires gradients before backward pass
# Ensure total loss requires gradients
if not total_loss.requires_grad:
logger.warning("Total loss tensor does not require gradients, skipping backward pass")
logger.warning("Total loss does not require gradients - policy network may not be in training mode")
self.policy_net.train() # Ensure training mode
return 0.0

# Backward pass
total_loss.backward()

# Enhanced gradient clipping with configurable norm
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), self.gradient_clip_norm)
# Gradient clipping
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0)

# Check if gradients are valid
has_valid_gradients = False
for param in self.policy_net.parameters():
if param.grad is not None and torch.any(torch.isfinite(param.grad)):
has_valid_gradients = True
break

if not has_valid_gradients:
logger.warning("No valid gradients found, skipping optimizer step")
return 0.0

# Update weights
self.optimizer.step()

# Enhanced target network update tracking
# Update target network periodically
self.training_steps += 1
if self.training_steps % self.target_update_freq == 0:
self.target_net.load_state_dict(self.policy_net.state_dict())
logger.debug(f"Target network updated at step {self.training_steps}")

# Enhanced statistics tracking
self.epsilon_history.append(self.epsilon)

# Calculate and store TD error for analysis
with torch.no_grad():
td_error = torch.abs(current_q_values - target_q_values).mean().item()
self.td_errors.append(td_error)

# Return loss
return total_loss.item()

except Exception as e:
logger.error(f"Error in replay standard: {str(e)}")
import traceback
logger.error(traceback.format_exc())
logger.error(f"Error in standard replay: {e}")
return 0.0
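_replay_standard bootstraps its targets with either Double DQN (the policy net selects the next action, the target net evaluates it) or standard DQN. A condensed sketch of that target computation, with random tensors standing in for the two networks' outputs:

```python
import torch

gamma = 0.99
batch, n_actions = 4, 3

# Stand-ins for network outputs; in the agent these come from policy_net / target_net.
policy_next_q = torch.randn(batch, n_actions)
target_next_q = torch.randn(batch, n_actions)
rewards = torch.randn(batch)
dones = torch.zeros(batch)

use_double_dqn = True
if use_double_dqn:
    # Double DQN: the policy net selects the action, the target net evaluates it.
    next_actions = policy_next_q.argmax(1)
    next_q = target_next_q.gather(1, next_actions.unsqueeze(1)).squeeze(1)
else:
    # Standard DQN: the target net both selects and evaluates.
    next_q = target_next_q.max(1)[0]

target_q = rewards + (1 - dones) * gamma * next_q
print(target_q.shape)  # torch.Size([4])
```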
def _replay_mixed_precision(self, states, actions, rewards, next_states, dones):
"""Mixed precision training step for better GPU performance"""
# Check if mixed precision should be explicitly disabled
if 'DISABLE_MIXED_PRECISION' in os.environ:
logger.info("Mixed precision explicitly disabled by environment variable")
"""Mixed precision training step"""
if not self.use_mixed_precision:
logger.warning("Mixed precision not available, falling back to standard replay")
return self._replay_standard(states, actions, rewards, next_states, dones)

try:
# Validate input tensors
if states.shape[0] == 0:
logger.warning("Empty batch in _replay_mixed_precision")
return 0.0

# Zero gradients
self.optimizer.zero_grad()
@@ -1135,21 +1059,28 @@ class DQNAgent:
with warnings.catch_warnings():
warnings.simplefilter("ignore", FutureWarning)
with torch.cuda.amp.autocast():
# Get current Q values and extrema predictions
current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self.policy_net(states)
# Get current Q values and predictions
current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self._safe_cnn_forward(self.policy_net, states)
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

# Get next Q values from target network
with torch.no_grad():
next_q_values, next_extrema_pred, next_price_pred, next_hidden_features, next_advanced_pred = self.target_net(next_states)
next_q_values = next_q_values.max(1)[0]
if self.use_double_dqn:
# Double DQN
policy_q_values, _, _, _, _ = self._safe_cnn_forward(self.policy_net, next_states)
next_actions = policy_q_values.argmax(1)
target_q_values_all, _, _, _, _ = self._safe_cnn_forward(self.target_net, next_states)
next_q_values = target_q_values_all.gather(1, next_actions.unsqueeze(1)).squeeze(1)
else:
# Standard DQN
next_q_values, _, _, _, _ = self._safe_cnn_forward(self.target_net, next_states)
next_q_values = next_q_values.max(1)[0]

# Check for dimension mismatch and fix it
if rewards.shape[0] != next_q_values.shape[0]:
# Log the shape mismatch for debugging
logger.warning(f"Shape mismatch detected: rewards {rewards.shape}, next_q_values {next_q_values.shape}")
# Use the smaller size to prevent index errors
min_size = min(rewards.shape[0], next_q_values.shape[0])
# Ensure consistent shapes
batch_size = states.shape[0]
if rewards.shape[0] != batch_size or next_q_values.shape[0] != batch_size:
logger.warning(f"Shape mismatch in mixed precision replay")
min_size = min(batch_size, rewards.shape[0], next_q_values.shape[0])
rewards = rewards[:min_size]
dones = dones[:min_size]
next_q_values = next_q_values[:min_size]
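The mixed-precision path wraps the forward pass in torch.cuda.amp.autocast and routes backward/step through a GradScaler, unscaling before gradient clipping. A minimal end-to-end sketch of that pattern with a stand-in model; on a CPU-only machine it simply runs in full precision:

```python
import torch
import torch.nn as nn

use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

model = nn.Linear(8, 3).to(device)            # stand-in for the policy network
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

x = torch.randn(16, 8, device=device)
target = torch.randn(16, 3, device=device)

optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=use_cuda):   # forward + loss in reduced precision
    loss = nn.functional.mse_loss(model(x), target)

scaler.scale(loss).backward()                     # scaled backward pass
scaler.unscale_(optimizer)                        # unscale before clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)                            # optimizer step via the scaler
scaler.update()
print(f"loss: {loss.item():.4f}")
```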
@@ -1158,147 +1089,63 @@ class DQNAgent:
target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

# Compute Q-value loss (primary task)
q_loss = nn.MSELoss()(current_q_values, target_q_values)
q_loss = nn.MSELoss()(current_q_values, target_q_values.detach())

# Initialize loss with q_loss
loss = q_loss

# Try to extract price from current and next states
# Add auxiliary losses if available
try:
# Extract price feature from sequence data (if available)
if len(states.shape) == 3: # [batch, seq, features]
current_prices = states[:, -1, -1] # Last timestep, last feature
next_prices = next_states[:, -1, -1]
else: # [batch, features]
current_prices = states[:, -1] # Last feature
next_prices = next_states[:, -1]

# Calculate price change for different timeframes
immediate_changes = (next_prices - current_prices) / current_prices

# Get the actual batch size for this calculation
actual_batch_size = states.shape[0]

# Create price direction labels - simplified for training
# 0 = down, 1 = sideways, 2 = up
immediate_labels = torch.ones(actual_batch_size, dtype=torch.long, device=self.device) * 1 # Default: sideways
midterm_labels = torch.ones(actual_batch_size, dtype=torch.long, device=self.device) * 1
longterm_labels = torch.ones(actual_batch_size, dtype=torch.long, device=self.device) * 1

# Immediate term direction (1s, 1m)
immediate_up = (immediate_changes > 0.0005)
immediate_down = (immediate_changes < -0.0005)
immediate_labels[immediate_up] = 2 # Up
immediate_labels[immediate_down] = 0 # Down

# For mid and long term, we can only approximate during training
# In a real system, we'd need historical data to validate these
# Here we'll use the immediate term with increasing thresholds as approximation

# Mid-term (1h) - use slightly higher threshold
midterm_up = (immediate_changes > 0.001)
midterm_down = (immediate_changes < -0.001)
midterm_labels[midterm_up] = 2 # Up
midterm_labels[midterm_down] = 0 # Down

# Long-term (1d) - use even higher threshold
longterm_up = (immediate_changes > 0.002)
longterm_down = (immediate_changes < -0.002)
longterm_labels[longterm_up] = 2 # Up
longterm_labels[longterm_down] = 0 # Down

# Generate target values for price change regression
# For simplicity, we'll use the immediate change and scaled versions for longer timeframes
price_value_targets = torch.zeros((actual_batch_size, 4), device=self.device)
price_value_targets[:, 0] = immediate_changes
price_value_targets[:, 1] = immediate_changes * 2.0 # Approximate 1h change
price_value_targets[:, 2] = immediate_changes * 4.0 # Approximate 1d change
price_value_targets[:, 3] = immediate_changes * 6.0 # Approximate 1w change

# Calculate loss for price direction prediction (classification)
if len(current_price_pred['immediate'].shape) > 1 and current_price_pred['immediate'].shape[0] >= actual_batch_size:
# Slice predictions to match the adjusted batch size
immediate_pred = current_price_pred['immediate'][:actual_batch_size]
midterm_pred = current_price_pred['midterm'][:actual_batch_size]
longterm_pred = current_price_pred['longterm'][:actual_batch_size]
price_values_pred = current_price_pred['values'][:actual_batch_size]
if current_extrema_pred is not None and current_extrema_pred.shape[0] > 0:
# Simple extrema targets
with torch.no_grad():
extrema_targets = torch.ones(current_extrema_pred.shape[0], dtype=torch.long, device=current_extrema_pred.device) * 2

# Compute losses for each task
immediate_loss = nn.CrossEntropyLoss()(immediate_pred, immediate_labels)
midterm_loss = nn.CrossEntropyLoss()(midterm_pred, midterm_labels)
longterm_loss = nn.CrossEntropyLoss()(longterm_pred, longterm_labels)
extrema_loss = F.cross_entropy(current_extrema_pred, extrema_targets)
loss = loss + 0.1 * extrema_loss

# MSE loss for price value regression
price_value_loss = nn.MSELoss()(price_values_pred, price_value_targets)

# Combine all price prediction losses
price_loss = immediate_loss + 0.7 * midterm_loss + 0.5 * longterm_loss + 0.3 * price_value_loss

# Create extrema labels (same as before)
extrema_labels = torch.ones(actual_batch_size, dtype=torch.long, device=self.device) * 2 # Default: neither

# Identify potential bottoms (significant negative change)
bottoms = (immediate_changes < -0.003)
extrema_labels[bottoms] = 0

# Identify potential tops (significant positive change)
tops = (immediate_changes > 0.003)
extrema_labels[tops] = 1

# Calculate extrema prediction loss
if len(current_extrema_pred.shape) > 1 and current_extrema_pred.shape[0] >= actual_batch_size:
current_extrema_pred = current_extrema_pred[:actual_batch_size]
extrema_loss = nn.CrossEntropyLoss()(current_extrema_pred, extrema_labels)

# Combined loss with all components
# Primary task: Q-value learning (RL objective)
# Secondary tasks: extrema detection and price prediction (supervised objectives)
loss = q_loss + 0.3 * extrema_loss + 0.3 * price_loss

# Log loss components occasionally
if random.random() < 0.01: # Log 1% of the time
logger.info(
f"Mixed precision losses: Q-loss={q_loss.item():.4f}, "
f"Extrema-loss={extrema_loss.item():.4f}, "
f"Price-loss={price_loss.item():.4f}"
)
except Exception as e:
# Fallback if price extraction fails
logger.warning(f"Failed to calculate price prediction loss: {str(e)}. Using only Q-value loss.")
# Just use Q-value loss
loss = q_loss

# Ensure loss requires gradients before backward pass
if not loss.requires_grad:
logger.warning("Loss tensor does not require gradients, skipping backward pass")
return 0.0

# Backward pass with scaled gradients
self.scaler.scale(loss).backward()

# Gradient clipping on scaled gradients
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)

# Update with scaler
self.scaler.step(self.optimizer)
self.scaler.update()

# Update target network if needed
self.update_count += 1
if self.update_count % self.target_update == 0:
self.target_net.load_state_dict(self.policy_net.state_dict())

# Track and decay epsilon
self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

return loss.item()

logger.debug(f"Could not add auxiliary loss in mixed precision: {e}")

# Check if loss requires gradients
if not loss.requires_grad:
logger.warning("Loss does not require gradients in mixed precision training")
return 0.0

# Scale and backward pass
self.scaler.scale(loss).backward()

# Unscale gradients and clip
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0)

# Check for valid gradients
has_valid_gradients = False
for param in self.policy_net.parameters():
if param.grad is not None and torch.any(torch.isfinite(param.grad)):
has_valid_gradients = True
break

if not has_valid_gradients:
logger.warning("No valid gradients in mixed precision training")
self.scaler.update() # Still update scaler
return 0.0

# Optimizer step with scaler
self.scaler.step(self.optimizer)
self.scaler.update()

# Update target network
self.training_steps += 1
if self.training_steps % self.target_update_freq == 0:
self.target_net.load_state_dict(self.policy_net.state_dict())
logger.debug(f"Target network updated at step {self.training_steps}")

return loss.item()

except Exception as e:
logger.error(f"Error in mixed precision training: {str(e)}")
logger.warning("Falling back to standard precision training")
# Fall back to standard training
return self._replay_standard(states, actions, rewards, next_states, dones)
logger.error(f"Error in mixed precision replay: {e}")
return 0.0

def train_on_extrema(self, states, actions, rewards, next_states, dones):
"""