gogo2/crypto/gogo2/test_model_save_load.py
2025-03-18 02:04:58 +02:00

227 lines
9.2 KiB
Python

#!/usr/bin/env python
import os
import logging
import torch
import argparse
import gc
import traceback
import shutil
from main import Agent, robust_save
# Logging configuration: records at INFO level and above are sent both to a
# log file in the current working directory and to the console.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    handlers=[
        logging.FileHandler("test_model_save_load.log"),
        logging.StreamHandler(),
    ],
)

# Module-level logger used by every function in this script.
logger = logging.getLogger(__name__)
def create_test_directory(test_dir="test_models"):
    """Create the directory used for saved test models and return its path.

    Args:
        test_dir: Path of the directory to create. Defaults to
            ``"test_models"`` (relative to the current working directory),
            which is the location the tests in this file use.

    Returns:
        The path that was passed in (or the default), unchanged.
    """
    # exist_ok=True makes repeated calls harmless; missing parent
    # directories are created as well.
    os.makedirs(test_dir, exist_ok=True)
    return test_dir
def test_save_load_cycle(state_size=64, action_size=4, hidden_size=384):
    """Save a new agent to disk, load it into a second agent, and check its sizes.

    Args:
        state_size: ``state_size`` passed to the ``Agent`` constructor.
        action_size: ``action_size`` passed to the ``Agent`` constructor.
        hidden_size: ``hidden_size`` passed to the ``Agent`` constructor.

    Returns:
        True when the save, the load, and the attribute checks all succeed;
        False when the save reports failure or anything raises during the
        load/verification phase.
    """
    checkpoint_path = os.path.join(create_test_directory(), "test_agent.pt")

    logger.info(f"Creating test agent with state_size={state_size}, action_size={action_size}, hidden_size={hidden_size}")
    original_agent = Agent(state_size=state_size, action_size=action_size, hidden_size=hidden_size)

    # Save phase: bail out early if the agent reports a failed save.
    logger.info(f"Testing save to {checkpoint_path}")
    if not original_agent.save(checkpoint_path):
        logger.error("Save failed!")
        return False
    logger.info(f"Save successful, model size: {os.path.getsize(checkpoint_path)} bytes")

    # Release the first agent before building the second one.
    del original_agent
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    # Load phase: any exception (including a failed assertion) counts as failure.
    logger.info(f"Testing load from {checkpoint_path}")
    expected_sizes = (
        ("state_size", state_size),
        ("action_size", action_size),
        ("hidden_size", hidden_size),
    )
    try:
        restored_agent = Agent(state_size=state_size, action_size=action_size, hidden_size=hidden_size)
        restored_agent.load(checkpoint_path)
        logger.info("Load successful")

        logger.info("Verifying model architecture")
        for attr_name, expected in expected_sizes:
            assert getattr(restored_agent, attr_name) == expected, (
                f"Expected {attr_name}={expected}, got {getattr(restored_agent, attr_name)}"
            )
        logger.info("Model architecture verified correctly")
    except Exception as err:
        logger.error(f"Error during load or verification: {err}")
        logger.error(traceback.format_exc())
        return False
    return True
def _build_checkpoint(agent, include_optimizer=True):
    """Collect the agent's network weights and hyperparameters into a checkpoint dict."""
    checkpoint = {
        'policy_net': agent.policy_net.state_dict(),
        'target_net': agent.target_net.state_dict(),
    }
    if include_optimizer:
        checkpoint['optimizer'] = agent.optimizer.state_dict()
    checkpoint.update({
        'epsilon': agent.epsilon,
        'state_size': agent.state_size,
        'action_size': agent.action_size,
        'hidden_size': agent.hidden_size,
    })
    return checkpoint


def _save_jit(agent, save_path):
    """Write TorchScript copies of both networks plus a JSON parameter file; return True if all three exist."""
    import json  # json is not imported at module level in this file

    policy_path = f"{save_path}.policy.jit"
    target_path = f"{save_path}.target.jit"
    params_path = f"{save_path}.params.json"

    torch.jit.save(torch.jit.script(agent.policy_net), policy_path)
    torch.jit.save(torch.jit.script(agent.target_net), target_path)

    # Plain Python scalars so the values are JSON-serializable.
    params = {
        'epsilon': float(agent.epsilon),
        'state_size': int(agent.state_size),
        'action_size': int(agent.action_size),
        'hidden_size': int(agent.hidden_size),
    }
    with open(params_path, "w") as f:
        json.dump(params, f)

    return all(os.path.exists(p) for p in (policy_path, target_path, params_path))


def test_robust_save_methods(state_size=64, action_size=4, hidden_size=384):
    """Exercise several ways of saving an agent, then try to load each result.

    The save strategies are: the agent's own ``save`` ("regular"), saving to
    a ``.backup`` file and copying it into place ("backup"), ``torch.save``
    with pickle protocol 2 ("pickle2"), a checkpoint without optimizer state
    ("no_optimizer"), and TorchScript files plus a JSON parameter file
    ("jit").

    Args:
        state_size: ``state_size`` passed to the ``Agent`` constructor.
        action_size: ``action_size`` passed to the ``Agent`` constructor.
        hidden_size: ``hidden_size`` passed to the ``Agent`` constructor.

    Returns:
        A dict mapping each method name to a bool. A value is True when the
        save succeeded and, for every method except "jit" (whose load is
        skipped), the subsequent load did not raise. It is False when either
        the save or the load failed; the log records which one.
    """
    test_dir = create_test_directory()

    logger.info("Creating test agent for robust save testing")
    agent = Agent(state_size=state_size, action_size=action_size, hidden_size=hidden_size)

    methods = [
        ("regular", os.path.join(test_dir, "regular_save.pt")),
        ("backup", os.path.join(test_dir, "backup_save.pt")),
        ("pickle2", os.path.join(test_dir, "pickle2_save.pt")),
        ("no_optimizer", os.path.join(test_dir, "no_optimizer_save.pt")),
        ("jit", os.path.join(test_dir, "jit_save.pt")),
    ]
    results = {}

    # Phase 1: save with every method; one failing method must not stop the others.
    for method_name, save_path in methods:
        logger.info(f"Testing {method_name} save method to {save_path}")
        try:
            if method_name == "regular":
                success = agent.save(save_path)
            elif method_name == "backup":
                # Write to a side file first, then copy it to the final path.
                backup_path = f"{save_path}.backup"
                torch.save(_build_checkpoint(agent), backup_path)
                shutil.copy(backup_path, save_path)
                success = os.path.exists(save_path)
            elif method_name == "pickle2":
                torch.save(_build_checkpoint(agent), save_path, pickle_protocol=2)
                success = os.path.exists(save_path)
            elif method_name == "no_optimizer":
                torch.save(_build_checkpoint(agent, include_optimizer=False), save_path)
                success = os.path.exists(save_path)
            elif method_name == "jit":
                # Scripting can fail for models TorchScript cannot compile;
                # report that as a failed save with its own message.
                try:
                    success = _save_jit(agent, save_path)
                except Exception as e:
                    logger.error(f"JIT save failed: {e}")
                    success = False

            if success:
                if method_name != "jit":
                    file_size = os.path.getsize(save_path)
                    logger.info(f"{method_name} save successful, size: {file_size} bytes")
                else:
                    # The JIT method writes several side files, not save_path itself.
                    logger.info(f"{method_name} save successful")
                results[method_name] = True
            else:
                logger.error(f"{method_name} save failed")
                results[method_name] = False
        except Exception as e:
            logger.error(f"Error during {method_name} save: {e}")
            logger.error(traceback.format_exc())
            results[method_name] = False

    # Phase 2: load every checkpoint that was saved successfully.
    for method_name, save_path in methods:
        if not results[method_name]:
            logger.info(f"Skipping load test for {method_name} (save failed)")
            continue
        if method_name == "jit":
            logger.info(f"Skipping load test for {method_name} (requires special loading)")
            continue

        logger.info(f"Testing load from {save_path}")
        try:
            new_agent = Agent(state_size=state_size, action_size=action_size, hidden_size=hidden_size)
            new_agent.load(save_path)
            logger.info(f"Load successful for {method_name} save")
        except Exception as e:
            logger.error(f"Error loading from {method_name} save: {e}")
            logger.error(traceback.format_exc())
            # The stored value is a bool, so appending a string to it would
            # raise TypeError; mark the method as failed instead.
            results[method_name] = False

    return results
def main():
    """Parse the command-line options and run the selected save/load test.

    With ``--test_robust`` every robust save method is exercised; otherwise
    a single save/load cycle is run. The outcome is written to the log.
    """
    cli = argparse.ArgumentParser(description='Test model saving and loading')
    for flag, default, help_text in (
        ('--state_size', 64, 'State size for test model'),
        ('--action_size', 4, 'Action size for test model'),
        ('--hidden_size', 384, 'Hidden size for test model'),
    ):
        cli.add_argument(flag, type=int, default=default, help=help_text)
    cli.add_argument('--test_robust', action='store_true', help='Test all robust save methods')
    options = cli.parse_args()
    sizes = (options.state_size, options.action_size, options.hidden_size)

    logger.info("Starting model save/load test")
    if not options.test_robust:
        passed = test_save_load_cycle(*sizes)
        outcome = 'successful' if passed else 'failed'
        logger.info(f"Save/load cycle {outcome}")
    else:
        summary = test_robust_save_methods(*sizes)
        logger.info(f"Robust save method results: {summary}")
    logger.info("Test completed")


# Run the tests only when executed as a script, not when imported.
if __name__ == "__main__":
    main()