#!/usr/bin/env python
"""Smoke test for saving an agent checkpoint with several fallback strategies.

Run as a script: it saves a small test agent once with a plain ``torch.save``
call and once with :func:`robust_save`, then lists the files written to
``test_models/``.
"""
import logging
import os
import platform
import shutil
import sys

import torch
import torch.nn as nn

# Fix for Windows asyncio issues with aiodns: switch to the selector-based
# event loop policy.
# NOTE(review): nothing in this script uses asyncio/aiodns; this looks copied
# from another entry point -- confirm it is still needed here.
if platform.system() == 'Windows':
    try:
        import asyncio
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
        print("Using Windows SelectorEventLoopPolicy to fix aiodns issue")
    except Exception as e:
        print(f"Failed to set WindowsSelectorEventLoopPolicy: {e}")

# Setup logging: write to test_save.log (in the current working directory)
# and mirror everything to stdout.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("test_save.log"),
        logging.StreamHandler(sys.stdout),
    ],
)
logger = logging.getLogger(__name__)


class SimpleModel(nn.Module):
    """Small 3-layer MLP (10 -> 50 -> 20 -> 5) used as a stand-in network."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 50)
        self.fc2 = nn.Linear(50, 20)
        self.fc3 = nn.Linear(20, 5)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)


def _build_checkpoint(agent, include_optimizer=True):
    """Return the checkpoint dict for *agent*, optionally without optimizer state.

    Shared by ``TestAgent.save`` and every ``torch.save`` attempt in
    ``robust_save`` so all of them write the same key layout.
    """
    checkpoint = {
        'policy_net': agent.policy_net.state_dict(),
        'target_net': agent.target_net.state_dict(),
    }
    if include_optimizer:
        checkpoint['optimizer'] = agent.optimizer.state_dict()
    checkpoint['epsilon'] = agent.epsilon
    return checkpoint


class TestAgent:
    """Minimal agent exposing the attributes that ``robust_save`` reads."""

    # The class name starts with "Test"; tell pytest it is not a test class
    # (otherwise it warns that it cannot collect a class with an __init__).
    __test__ = False

    def __init__(self):
        self.policy_net = SimpleModel()
        self.target_net = SimpleModel()
        self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=0.001)
        self.epsilon = 0.1

    def save(self, path):
        """Standard save method that might fail: a single plain ``torch.save``."""
        torch.save(_build_checkpoint(self), path)
        logger.info(f"Model saved to {path}")


def robust_save(model, path):
    """
    Robust model saving with multiple fallback approaches.

    Attempts, in order, stopping at the first that succeeds:

    1. ``torch.save`` of the full checkpoint to ``<path>.backup``, then copy
       it to ``path`` (the ``.backup`` file is left in place).
    2. ``torch.save`` of the full checkpoint directly to ``path`` with
       ``pickle_protocol=2``.
    3. ``torch.save`` to ``path`` without the optimizer state.
    4. TorchScript export: ``<path>.policy.jit``, ``<path>.target.jit`` and
       ``<path>.epsilon.txt``. NOTE: this fallback does NOT create ``path``
       itself, so a caller that reloads ``path`` after a True return must
       handle that case.

    Args:
        model: The agent to save; must expose ``policy_net``, ``target_net``,
            ``optimizer`` and ``epsilon`` (like ``TestAgent``).
        path: Destination file path; missing parent directories are created.

    Returns:
        bool: True if any attempt succeeded, False otherwise.
    """
    # Create directory if it doesn't exist
    os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)

    backup_path = f"{path}.backup"

    # Attempt 1: write to a separate file first, so a failing torch.save never
    # truncates an existing checkpoint at `path`.
    try:
        logger.info(f"Saving model to {backup_path} (attempt 1)")
        torch.save(_build_checkpoint(model), backup_path)
        logger.info(f"Successfully saved to {backup_path}")
        # The backup write succeeded, so copy it to the actual path.
        shutil.copy(backup_path, path)
        logger.info(f"Copied backup to {path}")
        return True
    except Exception as e:
        logger.warning(f"First save attempt failed: {e}")

    # Attempt 2: pickle protocol 2 (more compatible)
    try:
        logger.info(f"Saving model to {path} (attempt 2 - pickle protocol 2)")
        torch.save(_build_checkpoint(model), path, pickle_protocol=2)
        logger.info(f"Successfully saved to {path} with pickle_protocol=2")
        return True
    except Exception as e:
        logger.warning(f"Second save attempt failed: {e}")

    # Attempt 3: drop the optimizer state (which can be large and cause issues)
    try:
        logger.info(f"Saving model to {path} (attempt 3 - without optimizer)")
        torch.save(_build_checkpoint(model, include_optimizer=False), path)
        logger.info(f"Successfully saved to {path} without optimizer state")
        return True
    except Exception as e:
        logger.warning(f"Third save attempt failed: {e}")

    # Attempt 4: TorchScript each network separately; epsilon goes to a text file.
    try:
        logger.info(f"Saving model to {path} (attempt 4 - with jit.save)")
        torch.jit.save(torch.jit.script(model.policy_net), f"{path}.policy.jit")
        torch.jit.save(torch.jit.script(model.target_net), f"{path}.target.jit")
        with open(f"{path}.epsilon.txt", "w", encoding="utf-8") as f:
            f.write(str(model.epsilon))
        logger.info("Successfully saved model components with jit.save")
        return True
    except Exception as e:
        # logger.exception keeps the traceback of the last failure.
        logger.exception(f"All save attempts failed: {e}")
        return False


def main():
    """Save a test agent with both save methods and list the files produced."""
    save_dir = "test_models"
    os.makedirs(save_dir, exist_ok=True)

    agent = TestAgent()

    # Test the regular save method (might fail)
    try:
        logger.info("Testing regular save method...")
        save_path = os.path.join(save_dir, "regular_save.pt")
        agent.save(save_path)
        logger.info("Regular save succeeded")
    except Exception as e:
        logger.error(f"Regular save failed: {e}")

    # Test our robust save method
    logger.info("Testing robust save method...")
    save_path = os.path.join(save_dir, "robust_save.pt")
    if robust_save(agent, save_path):
        logger.info("Robust save succeeded!")
    else:
        logger.error("Robust save failed!")

    # Sorted so the listing is deterministic regardless of filesystem order.
    logger.info("Files created:")
    for name in sorted(os.listdir(save_dir)):
        file_size = os.path.getsize(os.path.join(save_dir, name))
        logger.info(f" - {name} ({file_size} bytes)")


if __name__ == "__main__":
    main()