device fix , TZ fix

This commit is contained in:
Dobromir Popov
2025-07-27 22:13:28 +03:00
parent 9e1684f9f8
commit 368c49df50
6 changed files with 194 additions and 163 deletions

View File

@ -1056,7 +1056,55 @@ class DQNAgent:
if isinstance(state, torch.Tensor):
state = state.detach().cpu().numpy()
elif not isinstance(state, np.ndarray):
state = np.array(state, dtype=np.float32)
# Check if state is a dict or complex object
if isinstance(state, dict):
logger.error(f"State is a dict: {state}")
# Extract numerical values from dict if possible
if 'features' in state:
state = state['features']
elif 'state' in state:
state = state['state']
else:
# Try to extract all numerical values
numerical_values = []
for key, value in state.items():
if isinstance(value, (int, float)):
numerical_values.append(float(value))
elif isinstance(value, (list, np.ndarray)):
try:
# Handle nested structures safely
flattened = np.array(value).flatten()
for x in flattened:
if isinstance(x, (int, float)):
numerical_values.append(float(x))
elif hasattr(x, 'item'): # numpy scalar
numerical_values.append(float(x.item()))
except (ValueError, TypeError):
continue
elif isinstance(value, dict):
# Recursively extract from nested dicts
try:
nested_values = self._extract_numeric_from_dict(value)
numerical_values.extend(nested_values)
except Exception:
continue
if numerical_values:
state = np.array(numerical_values, dtype=np.float32)
else:
logger.error("No numerical values found in state dict, using default state")
expected_size = getattr(self, 'state_size', 403)
if isinstance(expected_size, tuple):
expected_size = np.prod(expected_size)
return np.zeros(int(expected_size), dtype=np.float32)
else:
try:
state = np.array(state, dtype=np.float32)
except (ValueError, TypeError) as e:
logger.error(f"Cannot convert state to numpy array: {type(state)}, {e}")
expected_size = getattr(self, 'state_size', 403)
if isinstance(expected_size, tuple):
expected_size = np.prod(expected_size)
return np.zeros(int(expected_size), dtype=np.float32)
# Flatten if multi-dimensional
if state.ndim > 1:
@ -1761,4 +1809,34 @@ class DQNAgent:
return 0.0
except:
return 0.0
return 0.0
def _extract_numeric_from_dict(self, data_dict):
"""Recursively extract all numeric values from a dictionary"""
numeric_values = []
try:
for key, value in data_dict.items():
if isinstance(value, (int, float)):
numeric_values.append(float(value))
elif isinstance(value, (list, np.ndarray)):
try:
flattened = np.array(value).flatten()
for x in flattened:
if isinstance(x, (int, float)):
numeric_values.append(float(x))
elif hasattr(x, 'item'): # numpy scalar
numeric_values.append(float(x.item()))
except (ValueError, TypeError):
continue
elif isinstance(value, dict):
# Recursively extract from nested dicts
nested_values = self._extract_numeric_from_dict(value)
numeric_values.extend(nested_values)
elif isinstance(value, torch.Tensor):
try:
numeric_values.append(float(value.item()))
except Exception:
continue
except Exception as e:
logger.debug(f"Error extracting numeric values from dict: {e}")
return numeric_values