stability
This commit is contained in:
@ -1164,35 +1164,23 @@ class DQNAgent:
|
||||
# Check if state is a dict or complex object
|
||||
if isinstance(state, dict):
|
||||
logger.error(f"State is a dict: {state}")
|
||||
|
||||
# Handle empty dictionary case
|
||||
if not state:
|
||||
logger.error("No numerical values found in state dict, using default state")
|
||||
expected_size = getattr(self, 'state_size', 403)
|
||||
if isinstance(expected_size, tuple):
|
||||
expected_size = np.prod(expected_size)
|
||||
return np.zeros(int(expected_size), dtype=np.float32)
|
||||
|
||||
# Extract numerical values from dict if possible
|
||||
if 'features' in state:
|
||||
state = state['features']
|
||||
elif 'state' in state:
|
||||
state = state['state']
|
||||
else:
|
||||
# Try to extract all numerical values
|
||||
numerical_values = []
|
||||
for key, value in state.items():
|
||||
if isinstance(value, (int, float)):
|
||||
numerical_values.append(float(value))
|
||||
elif isinstance(value, (list, np.ndarray)):
|
||||
try:
|
||||
# Handle nested structures safely
|
||||
flattened = np.array(value).flatten()
|
||||
for x in flattened:
|
||||
if isinstance(x, (int, float)):
|
||||
numerical_values.append(float(x))
|
||||
elif hasattr(x, 'item'): # numpy scalar
|
||||
numerical_values.append(float(x.item()))
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
elif isinstance(value, dict):
|
||||
# Recursively extract from nested dicts
|
||||
try:
|
||||
nested_values = self._extract_numeric_from_dict(value)
|
||||
numerical_values.extend(nested_values)
|
||||
except Exception:
|
||||
continue
|
||||
# Try to extract all numerical values using the helper method
|
||||
numerical_values = self._extract_numeric_from_dict(state)
|
||||
if numerical_values:
|
||||
state = np.array(numerical_values, dtype=np.float32)
|
||||
else:
|
||||
@ -1254,6 +1242,31 @@ class DQNAgent:
|
||||
expected_size = np.prod(expected_size)
|
||||
return np.zeros(int(expected_size), dtype=np.float32)
|
||||
|
||||
def _extract_numeric_from_dict(self, data_dict):
    """Recursively collect numeric values from a (possibly nested) dict.

    Values are visited in the dict's iteration order and handled by type:

    * Python ``int``/``float`` and NumPy integer/floating/bool scalars are
      converted with ``float()`` and appended.
    * ``list``, ``tuple`` and ``np.ndarray`` values are flattened with
      NumPy and every numeric element is appended.
    * Nested ``dict`` values are processed recursively.
    * Anything else (strings, ``None``, arbitrary objects) is skipped.

    Args:
        data_dict: Mapping to scan for numeric content.

    Returns:
        A flat ``list`` of Python ``float`` values; empty when nothing
        numeric is found or when ``data_dict`` cannot be iterated.
    """
    # NumPy scalars such as np.float32 / np.int64 are not subclasses of the
    # Python int/float types, so they have to be listed explicitly or they
    # would be silently dropped when they appear as direct dict values.
    scalar_types = (int, float, np.integer, np.floating, np.bool_)
    sequence_types = (list, tuple, np.ndarray)

    numerical_values = []
    try:
        for value in data_dict.values():
            if isinstance(value, dict):
                # Recursively extract from nested dicts
                numerical_values.extend(self._extract_numeric_from_dict(value))
            elif isinstance(value, scalar_types):
                numerical_values.append(float(value))
            elif isinstance(value, sequence_types):
                try:
                    # Flatten so nested lists / N-d arrays become one run
                    for x in np.array(value).flatten():
                        if isinstance(x, (int, float)):
                            numerical_values.append(float(x))
                        elif hasattr(x, 'item'):  # numpy scalar
                            numerical_values.append(float(x.item()))
                except (ValueError, TypeError):
                    # Ragged or non-numeric sequence: skip the rest of it
                    continue
    except Exception as e:
        # Best-effort helper: never raise into the caller, return what we have
        logger.debug(f"Error extracting numeric values from dict: {e}")
    return numerical_values
|
||||
|
||||
def _replay_standard(self, states, actions, rewards, next_states, dones):
|
||||
"""Standard training step without mixed precision"""
|
||||
try:
|
||||
|
Reference in New Issue
Block a user