227 lines
9.2 KiB
Python
227 lines
9.2 KiB
Python
#!/usr/bin/env python
|
|
import os
|
|
import logging
|
|
import torch
|
|
import argparse
|
|
import gc
|
|
import traceback
|
|
import shutil
|
|
from main import Agent, robust_save
|
|
|
|
# Configure the root logger once at import time: every record at INFO or
# above is written both to a log file in the working directory and to the
# console stream.
logging.basicConfig(
    handlers=[
        logging.FileHandler("test_model_save_load.log"),
        logging.StreamHandler(),
    ],
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger used by every function in this script.
logger = logging.getLogger(__name__)
|
|
|
|
def create_test_directory():
    """Ensure the ``test_models`` directory exists and return its path.

    The directory is created relative to the current working directory;
    an already existing directory is not treated as an error.

    Returns:
        The relative directory name ``"test_models"``.
    """
    models_dir = "test_models"
    os.makedirs(models_dir, exist_ok=True)
    return models_dir
|
|
|
|
def test_save_load_cycle(state_size=64, action_size=4, hidden_size=384):
    """Save a freshly built agent, reload it into a new agent, and verify it.

    Args:
        state_size: Value passed as ``state_size`` to ``Agent``.
        action_size: Value passed as ``action_size`` to ``Agent``.
        hidden_size: Value passed as ``hidden_size`` to ``Agent``.

    Returns:
        True when ``agent.save`` reports success, loading raises nothing and
        the reloaded agent exposes the requested ``state_size``,
        ``action_size`` and ``hidden_size``; False otherwise. Errors during
        load or verification are logged, not propagated.
    """
    test_dir = create_test_directory()

    # Create a test agent
    logger.info(f"Creating test agent with state_size={state_size}, action_size={action_size}, hidden_size={hidden_size}")
    agent = Agent(state_size=state_size, action_size=action_size, hidden_size=hidden_size)

    # Define paths for testing
    save_path = os.path.join(test_dir, "test_agent.pt")

    # Test saving; the return value of agent.save is used as a success flag.
    logger.info(f"Testing save to {save_path}")
    save_success = agent.save(save_path)

    if not save_success:
        logger.error("Save failed!")
        return False
    logger.info(f"Save successful, model size: {os.path.getsize(save_path)} bytes")

    # Memory cleanup: drop the first agent before building the second one.
    del agent
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    # Test loading
    logger.info(f"Testing load from {save_path}")
    try:
        new_agent = Agent(state_size=state_size, action_size=action_size, hidden_size=hidden_size)
        new_agent.load(save_path)
        logger.info("Load successful")

        # Verify model architecture. Explicit comparisons are used instead of
        # ``assert`` so the check still runs under ``python -O``, which strips
        # assert statements.
        logger.info("Verifying model architecture")
        expected_sizes = (
            ("state_size", state_size),
            ("action_size", action_size),
            ("hidden_size", hidden_size),
        )
        for attr_name, expected in expected_sizes:
            actual = getattr(new_agent, attr_name)
            if actual != expected:
                raise ValueError(f"Expected {attr_name}={expected}, got {actual}")

        logger.info("Model architecture verified correctly")
        return True
    except Exception as e:
        # Top-level boundary of the check: report the failure and its
        # traceback, then signal it through the return value.
        logger.error(f"Error during load or verification: {e}")
        logger.error(traceback.format_exc())
        return False
|
|
|
|
def _build_checkpoint(agent, include_optimizer=True):
    """Gather the agent's network state dicts and sizes into a checkpoint dict."""
    checkpoint = {
        'policy_net': agent.policy_net.state_dict(),
        'target_net': agent.target_net.state_dict(),
    }
    if include_optimizer:
        checkpoint['optimizer'] = agent.optimizer.state_dict()
    checkpoint.update({
        'epsilon': agent.epsilon,
        'state_size': agent.state_size,
        'action_size': agent.action_size,
        'hidden_size': agent.hidden_size,
    })
    return checkpoint


def _save_jit(agent, save_path):
    """Script both networks and write them plus a JSON parameter file; return success."""
    import json

    policy_file = f"{save_path}.policy.jit"
    target_file = f"{save_path}.target.jit"
    params_file = f"{save_path}.params.json"

    try:
        scripted_policy = torch.jit.script(agent.policy_net)
        torch.jit.save(scripted_policy, policy_file)

        scripted_target = torch.jit.script(agent.target_net)
        torch.jit.save(scripted_target, target_file)

        # Save parameters
        params = {
            'epsilon': float(agent.epsilon),
            'state_size': int(agent.state_size),
            'action_size': int(agent.action_size),
            'hidden_size': int(agent.hidden_size),
        }
        with open(params_file, "w") as f:
            json.dump(params, f)
    except Exception as e:
        # Scripting or saving failed; report it as a failed save instead of
        # aborting the remaining methods.
        logger.error(f"JIT save failed: {e}")
        return False

    return all(os.path.exists(path) for path in (policy_file, target_file, params_file))


def _run_save_method(agent, method_name, save_path):
    """Perform the save strategy named ``method_name`` and return whether it succeeded."""
    if method_name == "regular":
        # Use regular save
        return agent.save(save_path)
    if method_name == "jit":
        # Use JIT save
        return _save_jit(agent, save_path)

    if method_name == "backup":
        # Use backup method directly: write to a side file, then copy it over.
        backup_path = f"{save_path}.backup"
        torch.save(_build_checkpoint(agent), backup_path)
        shutil.copy(backup_path, save_path)
    elif method_name == "pickle2":
        # Use pickle protocol 2
        torch.save(_build_checkpoint(agent), save_path, pickle_protocol=2)
    elif method_name == "no_optimizer":
        # Save without optimizer
        torch.save(_build_checkpoint(agent, include_optimizer=False), save_path)
    else:
        raise ValueError(f"Unknown save method: {method_name}")
    return os.path.exists(save_path)


def test_robust_save_methods(state_size=64, action_size=4, hidden_size=384):
    """Test all the robust save methods.

    Each method ("regular", "backup", "pickle2", "no_optimizer", "jit") saves
    the same agent into the test directory. Every successful save except
    "jit" is then loaded into a fresh ``Agent``.

    Args:
        state_size: Value passed as ``state_size`` to ``Agent``.
        action_size: Value passed as ``action_size`` to ``Agent``.
        hidden_size: Value passed as ``hidden_size`` to ``Agent``.

    Returns:
        A dict mapping each method name to a bool. True means the save
        succeeded and, where a load was attempted, the load raised nothing.
        A failed save or a failed load is recorded as False; details are in
        the log.
    """
    test_dir = create_test_directory()

    # Create a test agent
    logger.info("Creating test agent for robust save testing")
    agent = Agent(state_size=state_size, action_size=action_size, hidden_size=hidden_size)

    # Test each robust save method
    methods = [
        ("regular", os.path.join(test_dir, "regular_save.pt")),
        ("backup", os.path.join(test_dir, "backup_save.pt")),
        ("pickle2", os.path.join(test_dir, "pickle2_save.pt")),
        ("no_optimizer", os.path.join(test_dir, "no_optimizer_save.pt")),
        ("jit", os.path.join(test_dir, "jit_save.pt")),
    ]

    results = {}

    for method_name, save_path in methods:
        logger.info(f"Testing {method_name} save method to {save_path}")

        try:
            success = _run_save_method(agent, method_name, save_path)

            if success:
                if method_name != "jit":
                    file_size = os.path.getsize(save_path)
                    logger.info(f"{method_name} save successful, size: {file_size} bytes")
                else:
                    # The JIT method writes side files, not save_path itself,
                    # so there is no single file size to report.
                    logger.info(f"{method_name} save successful")
                results[method_name] = True
            else:
                logger.error(f"{method_name} save failed")
                results[method_name] = False

        except Exception as e:
            logger.error(f"Error during {method_name} save: {e}")
            logger.error(traceback.format_exc())
            results[method_name] = False

    # Test loading each saved model
    for method_name, save_path in methods:
        if not results[method_name]:
            logger.info(f"Skipping load test for {method_name} (save failed)")
            continue

        if method_name == "jit":
            logger.info(f"Skipping load test for {method_name} (requires special loading)")
            continue

        logger.info(f"Testing load from {save_path}")
        try:
            new_agent = Agent(state_size=state_size, action_size=action_size, hidden_size=hidden_size)
            new_agent.load(save_path)
            logger.info(f"Load successful for {method_name} save")
        except Exception as e:
            logger.error(f"Error loading from {method_name} save: {e}")
            logger.error(traceback.format_exc())
            # The entry holds a bool, so the failure is recorded as False;
            # appending a string to it would raise TypeError inside this
            # handler and abort the whole run.
            results[method_name] = False

    # Return summary of results
    return results
|
|
|
|
def main():
    """Parse command-line options and run the selected save/load test.

    Without ``--test_robust`` the basic save/load cycle is run; with it, all
    robust save methods are exercised. The outcome is written to the log.
    """
    parser = argparse.ArgumentParser(description='Test model saving and loading')
    size_options = (
        ('--state_size', 64, 'State size for test model'),
        ('--action_size', 4, 'Action size for test model'),
        ('--hidden_size', 384, 'Hidden size for test model'),
    )
    for flag, default_value, help_text in size_options:
        parser.add_argument(flag, type=int, default=default_value, help=help_text)
    parser.add_argument('--test_robust', action='store_true', help='Test all robust save methods')
    args = parser.parse_args()

    # Both test entry points take the three sizes in this order.
    sizes = (args.state_size, args.action_size, args.hidden_size)

    logger.info("Starting model save/load test")

    if not args.test_robust:
        success = test_save_load_cycle(*sizes)
        outcome = 'successful' if success else 'failed'
        logger.info(f"Save/load cycle {outcome}")
    else:
        results = test_robust_save_methods(*sizes)
        logger.info(f"Robust save method results: {results}")

    logger.info("Test completed")
|
|
|
|
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()