#!/usr/bin/env python
import logging
import os
import platform
import shutil
import sys

import torch
import torch.nn as nn


# Fix for Windows asyncio issues with aiodns
#
# On Windows the selector-based event loop policy is installed at import time.
# print() is used instead of the logger because logging has not been
# configured yet at this point in the module.
# NOTE(review): asyncio's policy classes are documented as deprecated from
# Python 3.14 onwards - revisit this block when upgrading the interpreter.
if platform.system() == 'Windows':
    try:
        # Imported lazily so non-Windows runs never touch asyncio here.
        import asyncio
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
        print("Using Windows SelectorEventLoopPolicy to fix aiodns issue")
    except Exception as e:
        # Best effort: a failure here must not prevent the script from running.
        print(f"Failed to set WindowsSelectorEventLoopPolicy: {e}")


# Setup logging
#
# Records at INFO and above go both to the file "test_save.log" (created in
# the current working directory) and to stdout.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("test_save.log"),
        logging.StreamHandler(sys.stdout)
    ]
)
# Module-level logger used by every function and class in this script.
logger = logging.getLogger(__name__)


# Define a simple model for testing
class SimpleModel(nn.Module):
    """Three-layer fully connected network (10 -> 50 -> 20 -> 5) used as a save target."""

    def __init__(self):
        super().__init__()
        # Layers are created in this order; the state_dict keys follow it.
        self.fc1 = nn.Linear(10, 50)
        self.fc2 = nn.Linear(50, 20)
        self.fc3 = nn.Linear(20, 5)

    def forward(self, x):
        """Apply fc1 and fc2 with ReLU activations, then fc3 without activation.

        Args:
            x: Input tensor whose last dimension is 10 (required by fc1).

        Returns:
            Tensor whose last dimension is 5.
        """
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)


# Create a simple Agent class for testing
class TestAgent:
    """Minimal agent holding two SimpleModel networks, an Adam optimizer and epsilon."""

    # The class name starts with "Test", so pytest would try to collect it if
    # this module is ever imported from a test file; opt out explicitly.
    __test__ = False

    def __init__(self):
        self.policy_net = SimpleModel()
        self.target_net = SimpleModel()
        # The optimizer only tracks the policy network's parameters.
        self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=0.001)
        self.epsilon = 0.1

    def save(self, path):
        """Standard save method that might fail

        Writes one checkpoint dict (both networks' state, optimizer state and
        epsilon) to ``path`` with a single torch.save call. Any exception
        raised by torch.save propagates to the caller.

        Args:
            path: Destination file for the checkpoint.
        """
        checkpoint = {
            'policy_net': self.policy_net.state_dict(),
            'target_net': self.target_net.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'epsilon': self.epsilon
        }
        torch.save(checkpoint, path)
        logger.info(f"Model saved to {path}")


# Robust save function with multiple fallback approaches
def _build_checkpoint(model, include_optimizer=True):
    """Collect the agent's state into a plain dict suitable for torch.save."""
    checkpoint = {
        'policy_net': model.policy_net.state_dict(),
        'target_net': model.target_net.state_dict(),
    }
    if include_optimizer:
        checkpoint['optimizer'] = model.optimizer.state_dict()
    checkpoint['epsilon'] = model.epsilon
    return checkpoint


def robust_save(model, path):
    """
    Robust model saving with multiple fallback approaches

    The strategies below are tried in order until one succeeds:
        1. torch.save the full checkpoint to "<path>.backup", then copy that
           file to ``path`` (the backup file is left in place).
        2. torch.save the full checkpoint directly to ``path`` with
           pickle_protocol=2.
        3. torch.save directly to ``path`` without the optimizer state.
        4. torch.jit.save each network to "<path>.policy.jit" and
           "<path>.target.jit", and write epsilon as text to
           "<path>.epsilon.txt".

    Args:
        model: The Agent model to save; must expose policy_net, target_net,
            optimizer and epsilon attributes
        path: Path to save the model

    Returns:
        bool: True if successful, False otherwise. Strategies 3 and 4 also
        report True even though the optimizer state is not saved, and
        strategy 4 writes nothing to ``path`` itself.
    """
    # Create directory if it doesn't exist. A failure here is reported through
    # the return value so the function honours its bool contract.
    try:
        os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
    except OSError as e:
        logger.error(f"Could not create directory for {path}: {e}")
        return False

    # Backup path in case the main save fails
    backup_path = f"{path}.backup"

    # Attempt 1: Try with default settings in a separate file first
    try:
        logger.info(f"Saving model to {backup_path} (attempt 1)")
        torch.save(_build_checkpoint(model), backup_path)
        logger.info(f"Successfully saved to {backup_path}")

        # If backup worked, copy to the actual path. A missing backup file
        # raises here and is logged below instead of being skipped silently.
        shutil.copy(backup_path, path)
        logger.info(f"Copied backup to {path}")
        return True
    except Exception as e:
        logger.warning(f"First save attempt failed: {e}")

    # Attempt 2: Try with pickle protocol 2 (more compatible)
    # NOTE(review): pickle_protocol=2 appears to be torch.save's documented
    # default already, so this attempt mainly differs from attempt 1 by
    # writing straight to ``path`` - confirm against the installed torch.
    try:
        logger.info(f"Saving model to {path} (attempt 2 - pickle protocol 2)")
        torch.save(_build_checkpoint(model), path, pickle_protocol=2)
        logger.info(f"Successfully saved to {path} with pickle_protocol=2")
        return True
    except Exception as e:
        logger.warning(f"Second save attempt failed: {e}")

    # Attempt 3: Try without optimizer state (which can be large and cause issues)
    try:
        logger.info(f"Saving model to {path} (attempt 3 - without optimizer)")
        torch.save(_build_checkpoint(model, include_optimizer=False), path)
        logger.info(f"Successfully saved to {path} without optimizer state")
        return True
    except Exception as e:
        logger.warning(f"Third save attempt failed: {e}")

    # Attempt 4: Try with torch.jit.save instead
    try:
        logger.info(f"Saving model to {path} (attempt 4 - with jit.save)")
        # Save policy network using jit
        scripted_policy = torch.jit.script(model.policy_net)
        torch.jit.save(scripted_policy, f"{path}.policy.jit")
        # Save target network using jit
        scripted_target = torch.jit.script(model.target_net)
        torch.jit.save(scripted_target, f"{path}.target.jit")
        # Save epsilon value separately
        with open(f"{path}.epsilon.txt", "w") as f:
            f.write(str(model.epsilon))
        logger.info("Successfully saved model components with jit.save")
        return True
    except Exception as e:
        # Only the last attempt's error is included here; earlier failures
        # were already logged as warnings above.
        logger.error(f"All save attempts failed: {e}")
        return False


def main():
    """Exercise TestAgent.save and robust_save, then log the files in the output directory."""
    # Create a test directory
    save_dir = "test_models"
    os.makedirs(save_dir, exist_ok=True)

    # Create a test agent
    agent = TestAgent()

    # Test the regular save method (might fail)
    try:
        logger.info("Testing regular save method...")
        save_path = os.path.join(save_dir, "regular_save.pt")
        agent.save(save_path)
        logger.info("Regular save succeeded")
    except Exception as e:
        # A failure of the plain save is expected to be survivable: log it and
        # carry on to the robust path.
        logger.error(f"Regular save failed: {e}")

    # Test our robust save method
    logger.info("Testing robust save method...")
    save_path = os.path.join(save_dir, "robust_save.pt")
    success = robust_save(agent, save_path)

    if success:
        logger.info("Robust save succeeded!")
    else:
        logger.error("Robust save failed!")

    # Check which files were created. os.listdir returns entries in arbitrary
    # order, so sort them to keep the log output stable between runs.
    logger.info("Files created:")
    for file in sorted(os.listdir(save_dir)):
        file_path = os.path.join(save_dir, file)
        file_size = os.path.getsize(file_path)
        logger.info(f"  - {file} ({file_size} bytes)")


# Run the save tests only when executed as a script, not when imported.
if __name__ == "__main__":
    main()