stability

This commit is contained in:
Dobromir Popov
2025-07-28 12:10:52 +03:00
parent 9219b78241
commit fb72c93743
8 changed files with 207 additions and 53 deletions

View File

@ -1164,35 +1164,23 @@ class DQNAgent:
# Check if state is a dict or complex object
if isinstance(state, dict):
logger.error(f"State is a dict: {state}")
# Handle empty dictionary case
if not state:
logger.error("No numerical values found in state dict, using default state")
expected_size = getattr(self, 'state_size', 403)
if isinstance(expected_size, tuple):
expected_size = np.prod(expected_size)
return np.zeros(int(expected_size), dtype=np.float32)
# Extract numerical values from dict if possible
if 'features' in state:
state = state['features']
elif 'state' in state:
state = state['state']
else:
# Try to extract all numerical values
numerical_values = []
for key, value in state.items():
if isinstance(value, (int, float)):
numerical_values.append(float(value))
elif isinstance(value, (list, np.ndarray)):
try:
# Handle nested structures safely
flattened = np.array(value).flatten()
for x in flattened:
if isinstance(x, (int, float)):
numerical_values.append(float(x))
elif hasattr(x, 'item'): # numpy scalar
numerical_values.append(float(x.item()))
except (ValueError, TypeError):
continue
elif isinstance(value, dict):
# Recursively extract from nested dicts
try:
nested_values = self._extract_numeric_from_dict(value)
numerical_values.extend(nested_values)
except Exception:
continue
# Try to extract all numerical values using the helper method
numerical_values = self._extract_numeric_from_dict(state)
if numerical_values:
state = np.array(numerical_values, dtype=np.float32)
else:
@ -1254,6 +1242,31 @@ class DQNAgent:
expected_size = np.prod(expected_size)
return np.zeros(int(expected_size), dtype=np.float32)
def _extract_numeric_from_dict(self, data_dict):
    """Recursively extract numerical values from nested dictionaries.

    Walks ``data_dict`` in iteration order and collects every numeric value
    it finds as a Python ``float``:

    * Python ``int`` / ``float`` values (``bool`` included, since it is an
      ``int`` subclass) and NumPy integer / floating / bool scalars.
    * Elements of ``list`` and ``np.ndarray`` values, flattened.
    * Values found in nested dictionaries, via recursion.

    Top-level values of any other type (strings, ``None``, ...) are skipped.

    Args:
        data_dict: Dictionary to scan. Anything without a ``.values()``
            method ends up on the error path and yields an empty list.

    Returns:
        list: Floats in the order they were encountered; empty when nothing
        numeric was found. Extraction is best-effort: an unexpected error is
        logged at debug level and the values collected so far are returned.
    """
    # NumPy scalars such as np.float32 / np.int64 are NOT subclasses of the
    # Python int/float types, so they must be listed explicitly; otherwise a
    # value like {'price': np.float32(1.5)} would be silently dropped even
    # though the same scalar inside an array is accepted below.
    scalar_types = (int, float, np.integer, np.floating, np.bool_)

    numerical_values = []
    try:
        for value in data_dict.values():
            if isinstance(value, scalar_types):
                numerical_values.append(float(value))
            elif isinstance(value, (list, np.ndarray)):
                try:
                    for x in np.array(value).flatten():
                        if isinstance(x, (int, float)):
                            numerical_values.append(float(x))
                        elif hasattr(x, 'item'):  # numpy scalar
                            numerical_values.append(float(x.item()))
                except (ValueError, TypeError):
                    # Ragged or non-numeric sequence: give up on this value
                    # only. Elements converted before the failure are kept.
                    continue
            elif isinstance(value, dict):
                # Recursively extract from nested dicts
                numerical_values.extend(self._extract_numeric_from_dict(value))
    except Exception as e:
        # Deliberate best-effort: never let malformed state data raise here.
        logger.debug(f"Error extracting numeric values from dict: {e}")
    return numerical_values
def _replay_standard(self, states, actions, rewards, next_states, dones):
"""Standard training step without mixed precision"""
try: