#!/usr/bin/env python
"""Smoke tests for saving and loading ``Agent`` checkpoints.

Run as a script. Without ``--test_robust`` a single save/load cycle through
``Agent.save`` / ``Agent.load`` is exercised. With ``--test_robust`` several
alternative checkpoint formats are written and, where possible, loaded back.
The process exit code is 0 when every exercised step succeeds, 1 otherwise.
"""
import os
import logging
import torch
import argparse
import gc
import json
import sys
import traceback
import shutil

# NOTE(review): robust_save is not referenced in this script; kept in case
# other code relies on this module importing it.
from main import Agent, robust_save

# Log both to a file next to the working directory and to stderr.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[
        logging.FileHandler("test_model_save_load.log"),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)


def create_test_directory():
    """Create (if needed) and return the directory test checkpoints go in."""
    test_dir = "test_models"
    os.makedirs(test_dir, exist_ok=True)
    return test_dir


def _free_memory():
    """Release cached CUDA memory (when CUDA is available) and run the GC."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()


def _build_checkpoint(agent, include_optimizer=True):
    """Return the checkpoint dict written by the manual save methods.

    Args:
        agent: the Agent whose networks, optimizer and sizes are captured.
        include_optimizer: when False the ``'optimizer'`` entry is omitted.
    """
    checkpoint = {
        'policy_net': agent.policy_net.state_dict(),
        'target_net': agent.target_net.state_dict(),
    }
    if include_optimizer:
        checkpoint['optimizer'] = agent.optimizer.state_dict()
    checkpoint.update({
        'epsilon': agent.epsilon,
        'state_size': agent.state_size,
        'action_size': agent.action_size,
        'hidden_size': agent.hidden_size,
    })
    return checkpoint


def _check_architecture(agent, state_size, action_size, hidden_size):
    """Raise AssertionError if the agent's sizes differ from the expected ones.

    Explicit checks are used instead of ``assert`` statements because
    ``python -O`` strips asserts, which would make the test pass vacuously.
    """
    expected = {
        "state_size": state_size,
        "action_size": action_size,
        "hidden_size": hidden_size,
    }
    for attr, want in expected.items():
        got = getattr(agent, attr)
        if got != want:
            raise AssertionError(f"Expected {attr}={want}, got {got}")


def test_save_load_cycle(state_size=64, action_size=4, hidden_size=384):
    """Save a fresh agent, load it into a new agent, and verify its sizes.

    Returns:
        True if save, load and the architecture check all succeed, else False.
    """
    test_dir = create_test_directory()

    logger.info(
        f"Creating test agent with state_size={state_size}, "
        f"action_size={action_size}, hidden_size={hidden_size}"
    )
    agent = Agent(state_size=state_size, action_size=action_size,
                  hidden_size=hidden_size)

    save_path = os.path.join(test_dir, "test_agent.pt")

    logger.info(f"Testing save to {save_path}")
    save_success = agent.save(save_path)
    if not save_success:
        logger.error("Save failed!")
        return False
    logger.info(f"Save successful, model size: {os.path.getsize(save_path)} bytes")

    # Drop the original agent so the load below starts from a fresh instance.
    del agent
    _free_memory()

    logger.info(f"Testing load from {save_path}")
    try:
        new_agent = Agent(state_size=state_size, action_size=action_size,
                          hidden_size=hidden_size)
        new_agent.load(save_path)
        logger.info("Load successful")

        logger.info("Verifying model architecture")
        _check_architecture(new_agent, state_size, action_size, hidden_size)
        logger.info("Model architecture verified correctly")
        return True
    except Exception as e:
        logger.error(f"Error during load or verification: {e}")
        logger.error(traceback.format_exc())
        return False


def _save_regular(agent, save_path):
    """Save through the agent's own ``save`` method."""
    return agent.save(save_path)


def _save_backup(agent, save_path):
    """Write the checkpoint to ``<save_path>.backup`` and copy it into place."""
    backup_path = f"{save_path}.backup"
    torch.save(_build_checkpoint(agent), backup_path)
    shutil.copy(backup_path, save_path)
    return os.path.exists(save_path)


def _save_pickle2(agent, save_path):
    """Write the checkpoint with pickle protocol 2."""
    torch.save(_build_checkpoint(agent), save_path, pickle_protocol=2)
    return os.path.exists(save_path)


def _save_no_optimizer(agent, save_path):
    """Write the checkpoint without the optimizer state."""
    torch.save(_build_checkpoint(agent, include_optimizer=False), save_path)
    return os.path.exists(save_path)


def _save_jit(agent, save_path):
    """Script both networks with TorchScript and dump scalar params as JSON.

    Produces ``<save_path>.policy.jit``, ``<save_path>.target.jit`` and
    ``<save_path>.params.json``. Errors are logged and reported as False.
    """
    try:
        scripted_policy = torch.jit.script(agent.policy_net)
        torch.jit.save(scripted_policy, f"{save_path}.policy.jit")
        scripted_target = torch.jit.script(agent.target_net)
        torch.jit.save(scripted_target, f"{save_path}.target.jit")

        params = {
            'epsilon': float(agent.epsilon),
            'state_size': int(agent.state_size),
            'action_size': int(agent.action_size),
            'hidden_size': int(agent.hidden_size),
        }
        with open(f"{save_path}.params.json", "w", encoding="utf-8") as f:
            json.dump(params, f)

        return (os.path.exists(f"{save_path}.policy.jit")
                and os.path.exists(f"{save_path}.target.jit")
                and os.path.exists(f"{save_path}.params.json"))
    except Exception as e:
        logger.error(f"JIT save failed: {e}")
        return False


# Save methods exercised by test_robust_save_methods, in execution order.
_SAVE_METHODS = {
    "regular": _save_regular,
    "backup": _save_backup,
    "pickle2": _save_pickle2,
    "no_optimizer": _save_no_optimizer,
    "jit": _save_jit,
}


def test_robust_save_methods(state_size=64, action_size=4, hidden_size=384):
    """Exercise every save method in ``_SAVE_METHODS``, then try loading each.

    Returns:
        dict mapping method name to its outcome: True when the save succeeded
        (and the load succeeded or was skipped), False when the save failed,
        or the string "saved (load failed)" when the save succeeded but
        ``Agent.load`` raised.
    """
    test_dir = create_test_directory()

    logger.info("Creating test agent for robust save testing")
    agent = Agent(state_size=state_size, action_size=action_size,
                  hidden_size=hidden_size)

    methods = [
        (name, os.path.join(test_dir, f"{name}_save.pt"))
        for name in _SAVE_METHODS
    ]

    results = {}
    for method_name, save_path in methods:
        logger.info(f"Testing {method_name} save method to {save_path}")
        try:
            success = _SAVE_METHODS[method_name](agent, save_path)
            if success:
                if method_name != "jit":
                    file_size = os.path.getsize(save_path)
                    logger.info(f"{method_name} save successful, size: {file_size} bytes")
                else:
                    # JIT writes several sibling files, not save_path itself.
                    logger.info(f"{method_name} save successful")
                results[method_name] = True
            else:
                logger.error(f"{method_name} save failed")
                results[method_name] = False
        except Exception as e:
            logger.error(f"Error during {method_name} save: {e}")
            logger.error(traceback.format_exc())
            results[method_name] = False

    for method_name, save_path in methods:
        if not results[method_name]:
            logger.info(f"Skipping load test for {method_name} (save failed)")
            continue
        if method_name == "jit":
            logger.info(f"Skipping load test for {method_name} (requires special loading)")
            continue

        logger.info(f"Testing load from {save_path}")
        try:
            new_agent = Agent(state_size=state_size, action_size=action_size,
                              hidden_size=hidden_size)
            new_agent.load(save_path)
            logger.info(f"Load successful for {method_name} save")
        except Exception as e:
            logger.error(f"Error loading from {method_name} save: {e}")
            logger.error(traceback.format_exc())
            # The value here is a bool; the old `+= " (load failed)"` raised
            # TypeError and aborted the whole run. Record an explicit status.
            results[method_name] = "saved (load failed)"

    return results


def main():
    """Parse CLI arguments, run the selected test, and return an exit code.

    Returns:
        0 when every exercised step succeeded, 1 otherwise.
    """
    parser = argparse.ArgumentParser(description='Test model saving and loading')
    parser.add_argument('--state_size', type=int, default=64,
                        help='State size for test model')
    parser.add_argument('--action_size', type=int, default=4,
                        help='Action size for test model')
    parser.add_argument('--hidden_size', type=int, default=384,
                        help='Hidden size for test model')
    parser.add_argument('--test_robust', action='store_true',
                        help='Test all robust save methods')
    args = parser.parse_args()

    logger.info("Starting model save/load test")

    if args.test_robust:
        results = test_robust_save_methods(args.state_size, args.action_size,
                                           args.hidden_size)
        logger.info(f"Robust save method results: {results}")
        passed = all(outcome is True for outcome in results.values())
    else:
        success = test_save_load_cycle(args.state_size, args.action_size,
                                       args.hidden_size)
        logger.info(f"Save/load cycle {'successful' if success else 'failed'}")
        passed = bool(success)

    logger.info("Test completed")
    return 0 if passed else 1


if __name__ == "__main__":
    # Propagate the outcome so CI / shell callers can detect failures.
    sys.exit(main())