This commit is contained in:
Dobromir Popov
2025-06-27 03:48:48 +03:00
parent 601e44de25
commit fab25ffe6f
4 changed files with 242 additions and 76 deletions

View File

@ -130,7 +130,7 @@ class DQNAgent:
result = load_best_checkpoint(self.model_name)
if result:
file_path, metadata = result
checkpoint = torch.load(file_path, map_location=self.device)
checkpoint = torch.load(file_path, map_location=self.device, weights_only=False)
# Load model states
if 'policy_net_state_dict' in checkpoint:
@ -1212,7 +1212,7 @@ class DQNAgent:
# Load agent state
try:
agent_state = torch.load(f"{path}_agent_state.pt", map_location=self.device)
agent_state = torch.load(f"{path}_agent_state.pt", map_location=self.device, weights_only=False)
self.epsilon = agent_state['epsilon']
self.update_count = agent_state['update_count']
self.losses = agent_state['losses']