wip....
This commit is contained in:
@ -130,7 +130,7 @@ class DQNAgent:
|
||||
result = load_best_checkpoint(self.model_name)
|
||||
if result:
|
||||
file_path, metadata = result
|
||||
checkpoint = torch.load(file_path, map_location=self.device)
|
||||
checkpoint = torch.load(file_path, map_location=self.device, weights_only=False)
|
||||
|
||||
# Load model states
|
||||
if 'policy_net_state_dict' in checkpoint:
|
||||
@ -1212,7 +1212,7 @@ class DQNAgent:
|
||||
|
||||
# Load agent state
|
||||
try:
|
||||
agent_state = torch.load(f"{path}_agent_state.pt", map_location=self.device)
|
||||
agent_state = torch.load(f"{path}_agent_state.pt", map_location=self.device, weights_only=False)
|
||||
self.epsilon = agent_state['epsilon']
|
||||
self.update_count = agent_state['update_count']
|
||||
self.losses = agent_state['losses']
|
||||
|
Reference in New Issue
Block a user