device fix, TZ fix
This commit is contained in:
@ -1056,7 +1056,55 @@ class DQNAgent:
|
||||
if isinstance(state, torch.Tensor):
|
||||
state = state.detach().cpu().numpy()
|
||||
elif not isinstance(state, np.ndarray):
|
||||
state = np.array(state, dtype=np.float32)
|
||||
# Check if state is a dict or complex object
|
||||
if isinstance(state, dict):
|
||||
logger.error(f"State is a dict: {state}")
|
||||
# Extract numerical values from dict if possible
|
||||
if 'features' in state:
|
||||
state = state['features']
|
||||
elif 'state' in state:
|
||||
state = state['state']
|
||||
else:
|
||||
# Try to extract all numerical values
|
||||
numerical_values = []
|
||||
for key, value in state.items():
|
||||
if isinstance(value, (int, float)):
|
||||
numerical_values.append(float(value))
|
||||
elif isinstance(value, (list, np.ndarray)):
|
||||
try:
|
||||
# Handle nested structures safely
|
||||
flattened = np.array(value).flatten()
|
||||
for x in flattened:
|
||||
if isinstance(x, (int, float)):
|
||||
numerical_values.append(float(x))
|
||||
elif hasattr(x, 'item'): # numpy scalar
|
||||
numerical_values.append(float(x.item()))
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
elif isinstance(value, dict):
|
||||
# Recursively extract from nested dicts
|
||||
try:
|
||||
nested_values = self._extract_numeric_from_dict(value)
|
||||
numerical_values.extend(nested_values)
|
||||
except Exception:
|
||||
continue
|
||||
if numerical_values:
|
||||
state = np.array(numerical_values, dtype=np.float32)
|
||||
else:
|
||||
logger.error("No numerical values found in state dict, using default state")
|
||||
expected_size = getattr(self, 'state_size', 403)
|
||||
if isinstance(expected_size, tuple):
|
||||
expected_size = np.prod(expected_size)
|
||||
return np.zeros(int(expected_size), dtype=np.float32)
|
||||
else:
|
||||
try:
|
||||
state = np.array(state, dtype=np.float32)
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.error(f"Cannot convert state to numpy array: {type(state)}, {e}")
|
||||
expected_size = getattr(self, 'state_size', 403)
|
||||
if isinstance(expected_size, tuple):
|
||||
expected_size = np.prod(expected_size)
|
||||
return np.zeros(int(expected_size), dtype=np.float32)
|
||||
|
||||
# Flatten if multi-dimensional
|
||||
if state.ndim > 1:
|
||||
@ -1761,4 +1809,34 @@ class DQNAgent:
|
||||
return 0.0
|
||||
|
||||
except:
|
||||
return 0.0
|
||||
return 0.0
|
||||
|
||||
def _extract_numeric_from_dict(self, data_dict):
|
||||
"""Recursively extract all numeric values from a dictionary"""
|
||||
numeric_values = []
|
||||
try:
|
||||
for key, value in data_dict.items():
|
||||
if isinstance(value, (int, float)):
|
||||
numeric_values.append(float(value))
|
||||
elif isinstance(value, (list, np.ndarray)):
|
||||
try:
|
||||
flattened = np.array(value).flatten()
|
||||
for x in flattened:
|
||||
if isinstance(x, (int, float)):
|
||||
numeric_values.append(float(x))
|
||||
elif hasattr(x, 'item'): # numpy scalar
|
||||
numeric_values.append(float(x.item()))
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
elif isinstance(value, dict):
|
||||
# Recursively extract from nested dicts
|
||||
nested_values = self._extract_numeric_from_dict(value)
|
||||
numeric_values.extend(nested_values)
|
||||
elif isinstance(value, torch.Tensor):
|
||||
try:
|
||||
numeric_values.append(float(value.item()))
|
||||
except Exception:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.debug(f"Error extracting numeric values from dict: {e}")
|
||||
return numeric_values
|
Reference in New Issue
Block a user