#!/usr/bin/env python
"""
Model utilities for robust saving and loading of PyTorch models
"""

import os
import logging
import torch
import shutil
import gc
import json
from typing import Any, Dict, Optional, Union

logger = logging.getLogger(__name__)

# Scalar attributes stored alongside the network weights.
_PARAM_KEYS = ('epsilon', 'state_size', 'action_size', 'hidden_size')


def _jit_artifact_paths(path: str) -> tuple:
    """Return the (policy, target, params) sidecar paths used by the JIT fallback."""
    return (f"{path}.policy.jit", f"{path}.target.jit", f"{path}.params.json")


def _as_int(value: Any) -> int:
    """Convert *value* to int, mapping ``None`` to 0 so it stays JSON-serialisable."""
    return 0 if value is None else int(value)


def _as_float(value: Any) -> float:
    """Convert *value* to float, mapping ``None`` to 0.0."""
    return 0.0 if value is None else float(value)


def _apply_checkpoint(model: Any, checkpoint: Any) -> bool:
    """Copy a checkpoint dict into *model*; return True only if weights were restored."""
    if not isinstance(checkpoint, dict):
        return False

    restored_any = False
    for net_name in ('policy_net', 'target_net'):
        if net_name in checkpoint:
            getattr(model, net_name).load_state_dict(checkpoint[net_name])
            restored_any = True
    if not restored_any:
        # A file without either network restores nothing; do not claim success.
        return False

    if 'epsilon' in checkpoint:
        model.epsilon = checkpoint['epsilon']

    if 'optimizer' in checkpoint and getattr(model, 'optimizer', None) is not None:
        # Optimizer state is best-effort: the weights are already restored.
        try:
            model.optimizer.load_state_dict(checkpoint['optimizer'])
        except Exception as e:
            logger.warning(f"Failed to load optimizer state: {e}")
    return True


def _state_dicts_match(saved: Any, live: Dict[str, Any]) -> bool:
    """Return True if *saved* has the same keys and equal values as *live*.

    Tensors are compared with ``torch.equal`` on the CPU, so a tensor that
    contains NaN never matches.
    """
    if not isinstance(saved, dict) or saved.keys() != live.keys():
        return False
    for key, live_value in live.items():
        saved_value = saved[key]
        if torch.is_tensor(live_value) and torch.is_tensor(saved_value):
            if not torch.equal(saved_value.detach().cpu(), live_value.detach().cpu()):
                return False
        elif saved_value != live_value:
            return False
    return True


def _remove_quietly(file_path: str) -> None:
    """Delete *file_path* if present; log instead of raising on failure."""
    try:
        os.remove(file_path)
    except FileNotFoundError:
        pass
    except OSError as e:
        logger.warning(f"Could not remove {file_path}: {e}")


def robust_save(model: Any, path: str, include_optimizer: bool = True) -> bool:
    """
    Robust model saving with multiple fallback approaches

    The attempts, in order, are:
      1. ``torch.save`` to ``<path>.backup``, then copy the result to ``path``
         (the backup file is left in place).
      2. ``torch.save`` straight to ``path`` with ``pickle_protocol=2``.
      3. ``torch.save`` straight to ``path`` without the optimizer state.
      4. TorchScript: ``<path>.policy.jit``, ``<path>.target.jit`` and a
         ``<path>.params.json`` file holding the scalar attributes.

    Args:
        model: The model object to save (should have policy_net, target_net,
            optimizer, epsilon attributes)
        path: Path to save the model
        include_optimizer: Whether to include optimizer state in the save

    Returns:
        bool: True if successful, False otherwise

    Raises:
        AttributeError: If ``model`` has no ``policy_net`` / ``target_net``;
            the checkpoint is assembled before the guarded save attempts.
    """
    # Create directory if it doesn't exist
    os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)

    # The first attempt writes here so a failed write cannot clobber ``path``.
    backup_path = f"{path}.backup"

    # Clean up GPU memory before saving
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    # Prepare checkpoint data
    checkpoint = {
        'policy_net': model.policy_net.state_dict(),
        'target_net': model.target_net.state_dict(),
        'epsilon': getattr(model, 'epsilon', 0.0),
        'state_size': getattr(model, 'state_size', None),
        'action_size': getattr(model, 'action_size', None),
        'hidden_size': getattr(model, 'hidden_size', None),
    }

    # Add optimizer state if requested and available
    if include_optimizer and getattr(model, 'optimizer', None) is not None:
        checkpoint['optimizer'] = model.optimizer.state_dict()

    # Attempt 1: Try with default settings in a separate file first
    try:
        logger.info(f"Saving model to {backup_path} (attempt 1)")
        torch.save(checkpoint, backup_path)
        logger.info(f"Successfully saved to {backup_path}")

        # If backup worked, copy to the actual path
        if os.path.exists(backup_path):
            shutil.copy(backup_path, path)
            logger.info(f"Copied backup to {path}")
            return True
    except Exception as e:
        logger.warning(f"First save attempt failed: {e}")

    # Attempt 2: Try with pickle protocol 2 (more compatible)
    try:
        logger.info(f"Saving model to {path} (attempt 2 - pickle protocol 2)")
        torch.save(checkpoint, path, pickle_protocol=2)
        logger.info(f"Successfully saved to {path} with pickle_protocol=2")
        return True
    except Exception as e:
        logger.warning(f"Second save attempt failed: {e}")

    # Attempt 3: Try without optimizer state (which can be large and cause issues)
    try:
        logger.info(f"Saving model to {path} (attempt 3 - without optimizer)")
        checkpoint_no_opt = {k: v for k, v in checkpoint.items() if k != 'optimizer'}
        torch.save(checkpoint_no_opt, path)
        logger.info(f"Successfully saved to {path} without optimizer state")
        return True
    except Exception as e:
        logger.warning(f"Third save attempt failed: {e}")

    # Attempt 4: Try with torch.jit.save instead
    try:
        logger.info(f"Saving model to {path} (attempt 4 - with jit.save)")
        policy_path, target_path, params_path = _jit_artifact_paths(path)

        # Save policy network using jit
        scripted_policy = torch.jit.script(model.policy_net)
        torch.jit.save(scripted_policy, policy_path)

        # Save target network using jit
        scripted_target = torch.jit.script(model.target_net)
        torch.jit.save(scripted_target, target_path)

        # Save parameters separately as JSON. An attribute that exists but is
        # None is written as 0 instead of aborting the save with a TypeError.
        params = {
            'epsilon': _as_float(getattr(model, 'epsilon', 0.0)),
            'state_size': _as_int(getattr(model, 'state_size', 0)),
            'action_size': _as_int(getattr(model, 'action_size', 0)),
            'hidden_size': _as_int(getattr(model, 'hidden_size', 0)),
        }
        with open(params_path, "w") as f:
            json.dump(params, f)

        logger.info("Successfully saved model components with jit.save")
        return True
    except Exception as e:
        logger.error(f"All save attempts failed: {e}")
        return False


def robust_load(model: Any, path: str, device: Optional[torch.device] = None) -> bool:
    """
    Robust model loading with fallback approaches

    First tries a regular ``torch.load`` of ``path``. If that file is missing,
    unreadable, or contains neither ``policy_net`` nor ``target_net``, falls
    back to the TorchScript sidecar files written by ``robust_save``'s last
    attempt and copies their weights into the existing networks.

    Args:
        model: The model object to load into
        path: Path to load the model from
        device: Device to load the model on (defaults to CUDA when available,
            otherwise CPU)

    Returns:
        bool: True if successful, False otherwise
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Try regular PyTorch load first
    try:
        logger.info(f"Loading model from {path}")
        if os.path.exists(path):
            # NOTE(security): torch.load unpickles the file. Only load
            # checkpoints from trusted sources.
            checkpoint = torch.load(path, map_location=device)
            if _apply_checkpoint(model, checkpoint):
                logger.info("Successfully loaded model")
                return True
            logger.warning(f"No policy_net/target_net weights found in {path}")
    except Exception as e:
        logger.warning(f"Regular load failed: {e}")

    # Try loading JIT saved components
    try:
        policy_path, target_path, params_path = _jit_artifact_paths(path)
        if all(os.path.exists(p) for p in (policy_path, target_path, params_path)):
            logger.info("Loading JIT model components")

            # Read everything before touching the model so an unreadable
            # sidecar file cannot leave it half-updated.
            scripted_policy = torch.jit.load(policy_path, map_location=device)
            scripted_target = torch.jit.load(target_path, map_location=device)
            with open(params_path, 'r') as f:
                params = json.load(f)
            logger.info(f"Model parameters: {params}")

            # Only the weights are taken from the scripted modules; the
            # caller's network objects are kept.
            model.policy_net.load_state_dict(scripted_policy.state_dict())
            model.target_net.load_state_dict(scripted_target.state_dict())
            if 'epsilon' in params:
                model.epsilon = params['epsilon']

            logger.info("Successfully loaded model from JIT components")
            return True
    except Exception as e:
        logger.error(f"JIT load failed: {e}")

    logger.error(f"All load attempts failed for {path}")
    return False


def get_model_info(path: str) -> Dict[str, Any]:
    """
    Get information about a saved model

    Args:
        path: Path to the model file

    Returns:
        dict: Model information with keys ``exists``, ``size_bytes``,
        ``has_optimizer`` and ``parameters``. Fields keep their defaults
        (False / 0 / {}) when the file is missing or cannot be inspected.
    """
    info = {
        'exists': False,
        'size_bytes': 0,
        'has_optimizer': False,
        'parameters': {},
    }

    try:
        if os.path.exists(path):
            info['exists'] = True
            info['size_bytes'] = os.path.getsize(path)

            # Try to load and inspect.
            # NOTE(security): torch.load unpickles the file. Only inspect
            # checkpoints from trusted sources.
            checkpoint = torch.load(path, map_location='cpu')
            if isinstance(checkpoint, dict):
                info['has_optimizer'] = 'optimizer' in checkpoint
                # Extract parameter info
                info['parameters'] = {
                    key: checkpoint[key] for key in _PARAM_KEYS if key in checkpoint
                }
            else:
                logger.warning(
                    f"Failed to get model info for {path}: "
                    f"unexpected checkpoint type {type(checkpoint).__name__}"
                )
    except Exception as e:
        logger.warning(f"Failed to get model info for {path}: {e}")

    return info


def verify_save_load_cycle(model: Any, test_path: str) -> bool:
    """
    Test that a model can be saved and loaded correctly

    Saves ``model`` with ``robust_save``, reads the checkpoint back on the CPU
    and checks that the stored ``policy_net`` / ``target_net`` tensors equal
    the live ones. Scratch files are removed afterwards: ``test_path`` and
    ``<test_path>.backup`` once the save has succeeded, plus any other file
    this call created.

    Args:
        model: Model to test
        test_path: Path for test file

    Returns:
        bool: True if save/load cycle successful
    """
    backup_path = f"{test_path}.backup"
    candidates = [test_path, backup_path, *_jit_artifact_paths(test_path)]
    # A failed run only removes files that did not exist beforehand, so a
    # pre-existing file is never deleted without having been overwritten.
    cleanup = [p for p in candidates if not os.path.exists(p)]

    try:
        # Save the model
        if not robust_save(model, test_path):
            return False

        # The save overwrote these two files (when it wrote them at all), so
        # they are scratch data now regardless of what was there before.
        cleanup.extend([test_path, backup_path])

        if not (os.path.exists(test_path) and os.path.getsize(test_path) > 0):
            return False

        # Read the file back and compare against the live weights.
        checkpoint = torch.load(test_path, map_location='cpu')
        if not isinstance(checkpoint, dict):
            logger.error("Save/load cycle verification failed: checkpoint is not a dict")
            return False
        for net_name in ('policy_net', 'target_net'):
            live_state = getattr(model, net_name).state_dict()
            if not _state_dicts_match(checkpoint.get(net_name), live_state):
                logger.error(
                    f"Save/load cycle verification failed: {net_name} weights differ"
                )
                return False

        logger.info("Save/load cycle verification successful")
        return True
    except Exception as e:
        logger.error(f"Save/load cycle verification failed: {e}")
        return False
    finally:
        # Clean up test files
        for scratch_path in cleanup:
            _remove_quietly(scratch_path)