training fixes and enhancements wip

Author: Dobromir Popov
Date: 2025-07-14 10:00:42 +03:00
Parent: e76b1b16dc
Commit: e74f1393c4
6 changed files with 378 additions and 99 deletions


@@ -57,7 +57,10 @@ class DQNAgent:
         else:
             # 1D state
             if isinstance(state_shape, tuple):
-                self.state_dim = state_shape[0]
+                if len(state_shape) == 0:
+                    self.state_dim = 1  # Safe default for empty tuple
+                else:
+                    self.state_dim = state_shape[0]
             else:
                 self.state_dim = state_shape
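
The empty-tuple guard can be exercised in isolation. The sketch below mirrors the patched branching in a standalone helper; the helper name is made up, and only the 1D path shown in this hunk is reproduced:

def resolve_state_dim(state_shape) -> int:
    # Mirrors the patched 1D branch; the multi-dimensional branch above
    # this hunk is not reproduced here.
    if isinstance(state_shape, tuple):
        if len(state_shape) == 0:
            return 1  # safe default for an empty tuple, as in the patch
        return state_shape[0]
    return state_shape  # already a plain int

assert resolve_state_dim(()) == 1
assert resolve_state_dim((64,)) == 64
assert resolve_state_dim(64) == 64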
@@ -615,8 +618,8 @@ class DQNAgent:
             self.recent_actions.append(action)
             return action
         else:
-            # Return None to indicate HOLD (don't change position)
-            return None
+            # Return 1 (HOLD) as a safe default if action is None
+            return 1

     def act_with_confidence(self, state: np.ndarray, market_regime: str = 'trending') -> Tuple[int, float]:
         """Choose action with confidence score adapted to market regime (from Enhanced DQN)"""
@@ -647,7 +650,10 @@ class DQNAgent:
         regime_weight = self.market_regime_weights.get(market_regime, 1.0)
         adapted_confidence = min(base_confidence * regime_weight, 1.0)
-        return action, adapted_confidence
+        # Always return int, float
+        if action is None:
+            return 1, 0.1
+        return int(action), float(adapted_confidence)

     def _determine_action_with_position_management(self, sell_conf, buy_conf, current_price, market_context, explore):
         """
@@ -748,13 +754,29 @@ class DQNAgent:
             indices = np.random.choice(len(self.memory), size=min(self.batch_size, len(self.memory)), replace=False)
             experiences = [self.memory[i] for i in indices]

+        # Sanitize and stack states and next_states
+        sanitized_states = []
+        sanitized_next_states = []
+        for i, e in enumerate(experiences):
+            try:
+                state = np.asarray(e[0], dtype=np.float32)
+                next_state = np.asarray(e[3], dtype=np.float32)
+                sanitized_states.append(state)
+                sanitized_next_states.append(next_state)
+            except Exception as ex:
+                print(f"[DQNAgent] Bad experience at index {i}: {ex}")
+                continue
+        if not sanitized_states or not sanitized_next_states:
+            print("[DQNAgent] No valid states in replay batch.")
+            return 0.0  # Return float instead of None for consistency
+        states = torch.FloatTensor(np.stack(sanitized_states)).to(self.device)
+        next_states = torch.FloatTensor(np.stack(sanitized_next_states)).to(self.device)

         # Choose appropriate replay method
         if self.use_mixed_precision:
             # Convert experiences to tensors for mixed precision
-            states = torch.FloatTensor(np.array([e[0] for e in experiences])).to(self.device)
             actions = torch.LongTensor(np.array([e[1] for e in experiences])).to(self.device)
             rewards = torch.FloatTensor(np.array([e[2] for e in experiences])).to(self.device)
-            next_states = torch.FloatTensor(np.array([e[3] for e in experiences])).to(self.device)
             dones = torch.FloatTensor(np.array([e[4] for e in experiences])).to(self.device)

             # Use mixed precision replay
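
A minimal reproduction of the failure mode this sanitization guards against; the sample data below is made up:

import numpy as np

good = np.zeros(8, dtype=np.float32)
bad = "corrupted"  # e.g. a malformed entry that slipped into the buffer

batch = []
for i, s in enumerate([good, bad, good]):
    try:
        batch.append(np.asarray(s, dtype=np.float32))  # raises for non-numeric entries
    except (ValueError, TypeError) as ex:
        print(f"skipping bad experience {i}: {ex}")

stacked = np.stack(batch)  # only well-formed states remain, so stacking is safe
print(stacked.shape)  # (2, 8)

Note that states which are numeric but have inconsistent shapes would still pass np.asarray and only fail later at np.stack, so this guard covers malformed entries rather than ragged ones.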
@@ -829,29 +851,32 @@ class DQNAgent:

             return loss

-    def _replay_standard(self, experiences=None):
+    def _replay_standard(self, *args):
         """Standard training step without mixed precision"""
         try:
-            # Use experiences if provided, otherwise sample from memory
-            if experiences is None:
-                # If memory is too small, skip training
-                if len(self.memory) < self.batch_size:
-                    return 0.0
-                # Sample random mini-batch from memory
-                indices = np.random.choice(len(self.memory), size=min(self.batch_size, len(self.memory)), replace=False)
-                batch = [self.memory[i] for i in indices]
-                experiences = batch
-            # Unpack experiences
-            states, actions, rewards, next_states, dones = zip(*experiences)
-            # Convert to PyTorch tensors
-            states = torch.FloatTensor(np.array(states)).to(self.device)
-            actions = torch.LongTensor(np.array(actions)).to(self.device)
-            rewards = torch.FloatTensor(np.array(rewards)).to(self.device)
-            next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
-            dones = torch.FloatTensor(np.array(dones)).to(self.device)
+            # Support both (experiences,) and (states, actions, rewards, next_states, dones)
+            if len(args) == 1:
+                experiences = args[0]
+                # Use experiences if provided, otherwise sample from memory
+                if experiences is None:
+                    # If memory is too small, skip training
+                    if len(self.memory) < self.batch_size:
+                        return 0.0
+                    # Sample random mini-batch from memory
+                    indices = np.random.choice(len(self.memory), size=min(self.batch_size, len(self.memory)), replace=False)
+                    batch = [self.memory[i] for i in indices]
+                    experiences = batch
+                # Unpack experiences
+                states, actions, rewards, next_states, dones = zip(*experiences)
+                states = torch.FloatTensor(np.array(states)).to(self.device)
+                actions = torch.LongTensor(np.array(actions)).to(self.device)
+                rewards = torch.FloatTensor(np.array(rewards)).to(self.device)
+                next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
+                dones = torch.FloatTensor(np.array(dones)).to(self.device)
+            elif len(args) == 5:
+                states, actions, rewards, next_states, dones = args
+            else:
+                raise ValueError("Invalid arguments to _replay_standard")

             # Get current Q values
             current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self.policy_net(states)
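
The *args signature keeps the old single-argument call sites working while also accepting pre-unpacked batches. A stubbed-out sketch of just the dispatch; the ReplayStub class and the dummy batch are illustrative only:

class ReplayStub:
    # Stand-in for DQNAgent; only the argument dispatch is reproduced.
    def _replay_standard(self, *args):
        if len(args) == 1:
            states, actions, rewards, next_states, dones = zip(*args[0])
        elif len(args) == 5:
            states, actions, rewards, next_states, dones = args
        else:
            raise ValueError("Invalid arguments to _replay_standard")
        return len(states)

experiences = [([0.0], 1, 0.5, [0.1], False)] * 4
agent = ReplayStub()
print(agent._replay_standard(experiences))         # legacy call: one batch of tuples
print(agent._replay_standard(*zip(*experiences)))  # already-unpacked call: five sequences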