gogo2/test_save.py
Dobromir Popov 3871afd4b8 init
2025-03-18 09:23:09 +02:00

182 lines
6.1 KiB
Python

#!/usr/bin/env python
import torch
import torch.nn as nn
import os
import logging
import sys
import platform
# On Windows the default asyncio loop (Python 3.8+) is the Proactor loop;
# aiodns presumably needs a selector-based loop, so the process-wide policy is
# switched here at import time. Nothing visible in this script uses asyncio
# itself -- this is a best-effort, print-only workaround.
# NOTE(review): asyncio event-loop policies are deprecated in Python 3.14.
if platform.system() == 'Windows':
    try:
        import asyncio
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
    except Exception as e:
        print(f"Failed to set WindowsSelectorEventLoopPolicy: {e}")
    else:
        print("Using Windows SelectorEventLoopPolicy to fix aiodns issue")
# Setup logging: INFO and above go both to "test_save.log" (created in the
# current working directory at import time) and to stdout.
# An explicit UTF-8 encoding keeps the file log independent of the platform's
# locale encoding (e.g. cp1252 on Windows), which cannot represent every
# character that may appear in paths or exception messages.
# Note: basicConfig does nothing if the root logger already has handlers.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("test_save.log", encoding="utf-8"),
        logging.StreamHandler(sys.stdout),
    ],
)
logger = logging.getLogger(__name__)
# Define a simple model for testing
class SimpleModel(nn.Module):
    """Small three-layer MLP used as a save/load fixture.

    Takes inputs whose last dimension is 10, applies Linear(10, 50) -> ReLU ->
    Linear(50, 20) -> ReLU -> Linear(20, 5), and returns the final linear
    output (no activation on the last layer).
    """

    def __init__(self):
        super().__init__()
        # The attribute names fc1/fc2/fc3 become the state_dict keys, so they
        # must stay stable for saved checkpoints to remain loadable.
        self.fc1 = nn.Linear(10, 50)
        self.fc2 = nn.Linear(50, 20)
        self.fc3 = nn.Linear(20, 5)

    def forward(self, x):
        hidden = self.fc1(x).relu()
        hidden = self.fc2(hidden).relu()
        return self.fc3(hidden)
# Create a simple Agent class for testing
class TestAgent:
    """Minimal stand-in for an agent holding two networks, an optimizer and epsilon.

    Provides the attributes ``robust_save`` reads (``policy_net``,
    ``target_net``, ``optimizer``, ``epsilon``) plus a plain ``save`` method
    that ``main`` uses as the non-fallback baseline.
    """

    # The class name starts with "Test", so when this module is collected by
    # pytest it would be mistaken for a test class (and trigger a collection
    # warning because it defines __init__). Opt out of collection explicitly.
    __test__ = False

    def __init__(self):
        # Two independently constructed SimpleModel instances; the target
        # network is not synced from the policy network here.
        self.policy_net = SimpleModel()
        self.target_net = SimpleModel()
        # Only the policy network's parameters are optimized.
        self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=0.001)
        self.epsilon = 0.1

    def save(self, path):
        """Save both networks, the optimizer state and epsilon with one torch.save call.

        No fallback: any exception from ``torch.save`` propagates to the
        caller. The parent directory of ``path`` is not created here.

        Args:
            path: Destination file for the checkpoint dict.
        """
        checkpoint = {
            'policy_net': self.policy_net.state_dict(),
            'target_net': self.target_net.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'epsilon': self.epsilon
        }
        torch.save(checkpoint, path)
        logger.info(f"Model saved to {path}")
def _build_checkpoint(model, include_optimizer=True):
    """Collect the agent state persisted by robust_save into a plain dict.

    Args:
        model: Object exposing ``policy_net`` and ``target_net`` (each with
            ``state_dict()``), ``optimizer`` (with ``state_dict()``) and
            ``epsilon``.
        include_optimizer: When False, the optimizer state is left out.

    Returns:
        dict with keys 'policy_net', 'target_net', optionally 'optimizer',
        and 'epsilon', in that order.
    """
    checkpoint = {
        'policy_net': model.policy_net.state_dict(),
        'target_net': model.target_net.state_dict(),
    }
    if include_optimizer:
        checkpoint['optimizer'] = model.optimizer.state_dict()
    checkpoint['epsilon'] = model.epsilon
    return checkpoint


def _atomic_write(path, write_fn):
    """Create ``path`` by calling ``write_fn`` on a sibling temp file, then renaming it.

    ``write_fn(tmp_path)`` must write the complete content to ``tmp_path``.
    Only after it returns is the temp file moved onto ``path`` with
    ``os.replace`` (atomic on POSIX, since both paths share a directory), so
    a failure mid-write never leaves a truncated file at ``path``. Any
    exception from ``write_fn`` or the rename propagates; the temp file is
    removed on failure.
    """
    tmp_path = f"{path}.tmp"
    try:
        write_fn(tmp_path)
        os.replace(tmp_path, path)
    finally:
        # After a successful replace the temp file no longer exists.
        if os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except OSError:
                # Best effort: don't mask the original error with a cleanup error.
                pass


# Robust save function with multiple fallback approaches
def robust_save(model, path):
    """
    Robust model saving with multiple fallback approaches.

    Attempts, in order, stopping at the first that succeeds:
      1. torch.save to ``<path>.backup``, then copy that file to ``path``
         (the ``.backup`` file is kept).
      2. torch.save directly to ``path`` with ``pickle_protocol=2``.
      3. torch.save to ``path`` without the optimizer state.
      4. TorchScript: writes ``<path>.policy.jit``, ``<path>.target.jit`` and
         ``<path>.epsilon.txt``. ``path`` itself is NOT written in this case.

    Every file is written to a temporary sibling and renamed into place, so a
    failed attempt never leaves a truncated checkpoint behind.

    Args:
        model: The Agent model to save (needs ``policy_net``, ``target_net``,
            ``optimizer`` and ``epsilon``).
        path: Path to save the model.

    Returns:
        bool: True if any attempt succeeded, False otherwise (including when
        the target directory cannot be created).
    """
    import shutil

    # Create directory if it doesn't exist; report failure through the return
    # value, as documented, rather than raising.
    try:
        os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
    except OSError as e:
        logger.error(f"Could not create directory for {path}: {e}")
        return False

    backup_path = f"{path}.backup"

    # Attempt 1: default settings into a separate backup file, then copy it over.
    try:
        logger.info(f"Saving model to {backup_path} (attempt 1)")
        checkpoint = _build_checkpoint(model)
        _atomic_write(backup_path, lambda tmp: torch.save(checkpoint, tmp))
        logger.info(f"Successfully saved to {backup_path}")
        # Copy unconditionally. If the backup were somehow missing, shutil.copy
        # raises and we fall through to the next attempt instead of reporting
        # success without ever writing ``path``.
        _atomic_write(path, lambda tmp: shutil.copy(backup_path, tmp))
        logger.info(f"Copied backup to {path}")
        return True
    except Exception as e:
        logger.warning(f"First save attempt failed: {e}")

    # Attempt 2: pickle protocol 2 (more compatible).
    try:
        logger.info(f"Saving model to {path} (attempt 2 - pickle protocol 2)")
        checkpoint = _build_checkpoint(model)
        _atomic_write(
            path, lambda tmp: torch.save(checkpoint, tmp, pickle_protocol=2)
        )
        logger.info(f"Successfully saved to {path} with pickle_protocol=2")
        return True
    except Exception as e:
        logger.warning(f"Second save attempt failed: {e}")

    # Attempt 3: without optimizer state (which can be large and cause issues).
    try:
        logger.info(f"Saving model to {path} (attempt 3 - without optimizer)")
        checkpoint = _build_checkpoint(model, include_optimizer=False)
        _atomic_write(path, lambda tmp: torch.save(checkpoint, tmp))
        logger.info(f"Successfully saved to {path} without optimizer state")
        return True
    except Exception as e:
        logger.warning(f"Third save attempt failed: {e}")

    # Attempt 4: TorchScript each network into its own file, plus epsilon as text.
    try:
        logger.info(f"Saving model to {path} (attempt 4 - with jit.save)")
        scripted_policy = torch.jit.script(model.policy_net)
        _atomic_write(
            f"{path}.policy.jit", lambda tmp: torch.jit.save(scripted_policy, tmp)
        )
        scripted_target = torch.jit.script(model.target_net)
        _atomic_write(
            f"{path}.target.jit", lambda tmp: torch.jit.save(scripted_target, tmp)
        )

        def _write_epsilon(tmp):
            with open(tmp, "w", encoding="utf-8") as f:
                f.write(str(model.epsilon))

        _atomic_write(f"{path}.epsilon.txt", _write_epsilon)
        logger.info("Successfully saved model components with jit.save")
        return True
    except Exception as e:
        logger.error(f"All save attempts failed: {e}")
        return False
def main():
    """Smoke-test the plain and robust save paths and list what was written.

    Creates ``test_models/`` in the current working directory, tries
    ``TestAgent.save`` (a failure is logged, not raised), then
    ``robust_save``. Finally it logs every entry in the directory with its
    size; this includes files left over from earlier runs.

    Returns:
        int: 0 if ``robust_save`` reported success, 1 otherwise, suitable as a
        process exit status.
    """
    save_dir = "test_models"
    os.makedirs(save_dir, exist_ok=True)

    agent = TestAgent()

    # The regular save path is allowed to fail. Log it with the traceback so
    # the cause is visible, then carry on to the robust path.
    try:
        logger.info("Testing regular save method...")
        agent.save(os.path.join(save_dir, "regular_save.pt"))
        logger.info("Regular save succeeded")
    except Exception as e:
        logger.exception(f"Regular save failed: {e}")

    logger.info("Testing robust save method...")
    success = robust_save(agent, os.path.join(save_dir, "robust_save.pt"))
    if success:
        logger.info("Robust save succeeded!")
    else:
        logger.error("Robust save failed!")

    # sorted() makes the listing order deterministic across platforms.
    logger.info("Files created:")
    for name in sorted(os.listdir(save_dir)):
        file_size = os.path.getsize(os.path.join(save_dir, name))
        logger.info(f" - {name} ({file_size} bytes)")

    return 0 if success else 1


if __name__ == "__main__":
    # Propagate the result as the exit status so CI/shell callers see failures.
    sys.exit(main())