#!/usr/bin/env python3
"""
Improved Model Saver

A comprehensive model saving utility that handles various model types
and ensures reliable checkpointing with validation.
"""

import json
import logging
import os
import pickle
import shutil
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, Optional, Union

import torch

logger = logging.getLogger(__name__)


class ImprovedModelSaver:
    """Enhanced model saving with validation and backup strategies.

    Every model gets its own directory ``<base_dir>/<model_name>/`` holding:

    * ``<model_name>_<timestamp>.pt`` / ``.pkl`` - timestamped backup
    * ``<model_name>_latest.pt`` / ``.pkl``      - copy of the newest backup
    * ``<model_name>_policy.pt``, ``_target.pt``, ``_agent_state.pt``
      - component files written by the RL-agent strategy
    * ``<model_name>_metadata.json``             - descriptive metadata
    """

    # Checkpoint keys that are bookkeeping and must not be copied back onto
    # a model as attributes when loading.
    _RESERVED_CHECKPOINT_KEYS = frozenset(
        ['model_state_dict', 'optimizer_state_dict', 'timestamp', 'model_class']
    )

    def __init__(self, base_dir: str = "models/saved"):
        """Create the saver and make sure ``base_dir`` exists.

        Args:
            base_dir: Root directory under which per-model directories are
                created.
        """
        self.base_dir = Path(base_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)

    # ------------------------------------------------------------------
    # Saving
    # ------------------------------------------------------------------
    def save_model_safely(self,
                          model: Any,
                          model_name: str,
                          model_type: str = "unknown",
                          metadata: Optional[Dict[str, Any]] = None) -> bool:
        """
        Save a model with multiple fallback strategies

        Strategies are tried in order and the first one that reports success
        wins; metadata is written only after a successful save.

        Args:
            model: The model to save
            model_name: Name identifier for the model
            model_type: Type of model (dqn, cnn, rl, etc.)
            metadata: Additional metadata to save

        Returns:
            bool: True if successful, False otherwise
        """
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        model_dir = self.base_dir / model_name
        model_dir.mkdir(parents=True, exist_ok=True)

        # Create backup file names
        main_path = model_dir / f"{model_name}_latest.pt"
        backup_path = model_dir / f"{model_name}_{timestamp}.pt"

        # (label used in the log line, applicability check, save action).
        # Both callables are lazy so nothing touches the model before the
        # preceding strategy has had its chance.
        strategies = (
            (
                "PyTorch save",
                lambda: hasattr(model, '__dict__') and hasattr(torch, 'save'),
                lambda: self._save_pytorch_model(model, main_path, backup_path),
            ),
            (
                "state_dict",
                lambda: hasattr(model, 'state_dict'),
                lambda: self._save_state_dict(model, main_path, backup_path),
            ),
            (
                "component-based saving",
                lambda: hasattr(model, 'policy_net') or hasattr(model, 'target_net'),
                lambda: self._save_rl_agent_components(model, model_dir, model_name),
            ),
            (
                "pickle fallback",
                lambda: True,
                lambda: self._save_with_pickle(model, main_path, backup_path),
            ),
        )

        try:
            for label, is_applicable, run in strategies:
                if is_applicable() and run():
                    self._save_metadata(model_dir, model_name, model_type, metadata)
                    logger.info(f"Successfully saved {model_name} using {label}")
                    return True

            logger.error(f"All save strategies failed for {model_name}")
            return False

        except Exception as e:
            logger.error(f"Critical error saving {model_name}: {e}")
            return False

    @staticmethod
    def _promote_to_latest(backup_path: Path, main_path: Path) -> None:
        """Publish ``backup_path`` as ``main_path`` without a half-written file.

        The backup is first copied to a staging file next to ``main_path`` and
        then moved into place with ``os.replace``, so an interrupted copy can
        only damage the staging file, never the existing latest file.

        Raises:
            Exception: whatever ``shutil.copy2`` / ``os.replace`` raise; the
                staging file is removed (best effort) before re-raising.
        """
        staging_path = main_path.with_name(main_path.name + ".tmp")
        try:
            shutil.copy2(backup_path, staging_path)
            os.replace(staging_path, main_path)
        except Exception:
            # Best-effort cleanup only; the original error is what matters.
            try:
                staging_path.unlink()
            except OSError:
                pass
            raise

    def _save_pytorch_model(self, model, main_path: Path, backup_path: Path) -> bool:
        """Save using standard PyTorch torch.save

        For objects exposing ``state_dict`` the checkpoint holds the state
        dict, the class name, a timestamp and any of ``epsilon``,
        ``total_steps``, ``current_reward`` found on the model; a non-None
        ``optimizer`` is stored as ``optimizer_state_dict``. Other objects are
        stored whole under the ``model`` key.

        Returns:
            bool: True if the backup was written, re-loaded and promoted to
            ``main_path``; False on any failure (logged as a warning).
        """
        try:
            # Create checkpoint data
            if hasattr(model, 'state_dict'):
                checkpoint = {
                    'model_state_dict': model.state_dict(),
                    'model_class': model.__class__.__name__,
                    'timestamp': datetime.now().isoformat()
                }

                # Add additional attributes
                for attr in ('epsilon', 'total_steps', 'current_reward', 'optimizer'):
                    if not hasattr(model, attr):
                        continue
                    try:
                        value = getattr(model, attr)
                        if attr == 'optimizer' and value is not None:
                            checkpoint['optimizer_state_dict'] = value.state_dict()
                        else:
                            checkpoint[attr] = value
                    except Exception as e:
                        # Skip problematic attributes, but leave a trace.
                        logger.debug(f"Skipping attribute {attr}: {e}")
            else:
                checkpoint = {
                    'model': model,
                    'timestamp': datetime.now().isoformat()
                }

            # Save to backup location first
            torch.save(checkpoint, backup_path)

            # Verify backup was saved correctly.
            # NOTE(review): on PyTorch versions where torch.load defaults to
            # weights_only=True, a checkpoint holding a whole pickled object
            # may be rejected here - confirm against the deployed version.
            torch.load(backup_path, map_location='cpu')

            # Copy to main location
            self._promote_to_latest(backup_path, main_path)

            return True

        except Exception as e:
            logger.warning(f"PyTorch save failed: {e}")
            return False

    def _save_state_dict(self, model, main_path: Path, backup_path: Path) -> bool:
        """Save using state_dict only

        Writes ``{'state_dict', 'model_class', 'timestamp'}`` to
        ``backup_path``, re-loads it as verification and promotes it to
        ``main_path``.

        Returns:
            bool: True on success, False on any failure (logged as a warning).
        """
        try:
            checkpoint = {
                'state_dict': model.state_dict(),
                'model_class': model.__class__.__name__,
                'timestamp': datetime.now().isoformat()
            }

            torch.save(checkpoint, backup_path)
            torch.load(backup_path, map_location='cpu')  # Verify
            self._promote_to_latest(backup_path, main_path)

            return True

        except Exception as e:
            logger.warning(f"State dict save failed: {e}")
            return False

    def _save_rl_agent_components(self, model, model_dir: Path, model_name: str) -> bool:
        """Save RL agent components separately

        Writes the state dicts of ``policy_net`` and ``target_net`` (when
        present and not None) plus a small agent-state dict to separate files
        in ``model_dir``.

        Returns:
            bool: True if at least one component file was written.
        """
        try:
            components_saved = 0

            # Save policy / target networks
            for attr, suffix in (('policy_net', 'policy'), ('target_net', 'target')):
                network = getattr(model, attr, None)
                if network is not None:
                    torch.save(network.state_dict(),
                               model_dir / f"{model_name}_{suffix}.pt")
                    components_saved += 1

            # Save agent state
            agent_state = {}
            for attr in ('epsilon', 'total_steps', 'current_reward', 'memory'):
                if not hasattr(model, attr):
                    continue
                try:
                    value = getattr(model, attr)
                    if attr == 'memory' and hasattr(value, '__len__'):
                        # Don't save large replay buffers
                        agent_state[attr + '_size'] = len(value)
                    else:
                        agent_state[attr] = value
                except Exception as e:
                    logger.debug(f"Skipping agent attribute {attr}: {e}")

            if agent_state:
                torch.save(agent_state, model_dir / f"{model_name}_agent_state.pt")
                components_saved += 1

            return components_saved > 0

        except Exception as e:
            logger.warning(f"Component-based save failed: {e}")
            return False

    def _save_with_pickle(self, model, main_path: Path, backup_path: Path) -> bool:
        """Fallback: save using pickle

        The ``.pt`` suffix of both paths is swapped for ``.pkl``.

        Returns:
            bool: True if the object was pickled, re-loaded and promoted;
            False on any failure (logged as a warning).
        """
        pickle_backup = backup_path.with_suffix('.pkl')
        pickle_main = main_path.with_suffix('.pkl')
        try:
            with open(pickle_backup, 'wb') as f:
                pickle.dump(model, f)

            # Verify
            with open(pickle_backup, 'rb') as f:
                pickle.load(f)

            self._promote_to_latest(pickle_backup, pickle_main)
            return True

        except Exception as e:
            logger.warning(f"Pickle save failed: {e}")
            return False

    def _save_metadata(self, model_dir: Path, model_name: str, model_type: str,
                       metadata: Optional[Dict[str, Any]]):
        """Save model metadata

        Writes ``<model_name>_metadata.json``. Caller-supplied ``metadata``
        is merged last and may therefore override the built-in keys. Failures
        are logged as warnings and never propagate.
        """
        try:
            meta_data = {
                'model_name': model_name,
                'model_type': model_type,
                'saved_at': datetime.now().isoformat(),
                'save_method': 'improved_model_saver'
            }

            if metadata:
                meta_data.update(metadata)

            meta_path = model_dir / f"{model_name}_metadata.json"
            # Explicit encoding so the file does not depend on the locale.
            with open(meta_path, 'w', encoding='utf-8') as f:
                # default=str stringifies values json cannot encode natively.
                json.dump(meta_data, f, indent=2, default=str)

        except Exception as e:
            logger.warning(f"Failed to save metadata: {e}")

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------
    def load_model_safely(self, model_name: str, model_class=None):
        """
        Load a model with multiple strategies

        Loaders are tried in order; a loader that raises or returns None is
        skipped and the next one is tried.

        Args:
            model_name: Name of the model to load
            model_class: Class to instantiate if needed (called with no
                arguments)

        Returns:
            Loaded model or None
        """
        model_dir = self.base_dir / model_name

        if not model_dir.exists():
            logger.warning(f"Model directory not found: {model_dir}")
            return None

        # Try different loading strategies
        loaders = [
            self._load_pytorch_checkpoint,
            self._load_state_dict_only,
            self._load_rl_components,
            self._load_pickle_fallback
        ]

        for loader in loaders:
            try:
                result = loader(model_dir, model_name, model_class)
                if result is not None:
                    logger.info(f"Successfully loaded {model_name} using {loader.__name__}")
                    return result
            except Exception as e:
                logger.debug(f"{loader.__name__} failed: {e}")
                continue

        logger.error(f"All load strategies failed for {model_name}")
        return None

    def _load_pytorch_checkpoint(self, model_dir: Path, model_name: str, model_class):
        """Load PyTorch checkpoint

        Returns:
            * a ``model_class`` instance when the checkpoint has
              ``model_state_dict`` and ``model_class`` is given;
            * None when the file is missing, or when it is a
              ``state_dict``-only checkpoint and ``model_class`` is given, so
              that ``_load_state_dict_only`` can rebuild the model;
            * otherwise the stored ``model`` entry, or the raw payload.
        """
        main_path = model_dir / f"{model_name}_latest.pt"
        if not main_path.exists():
            return None

        checkpoint = torch.load(main_path, map_location='cpu')

        # Anything that is not a dict has no keys to inspect; hand it back.
        if not isinstance(checkpoint, dict):
            return checkpoint

        if model_class and 'model_state_dict' in checkpoint:
            model = model_class()
            model.load_state_dict(checkpoint['model_state_dict'])

            # Restore other attributes
            for key, value in checkpoint.items():
                if key in self._RESERVED_CHECKPOINT_KEYS:
                    continue
                if hasattr(model, key):
                    setattr(model, key, value)

            # The optimizer state is saved by _save_pytorch_model; restore it
            # when the fresh instance owns an optimizer that can accept it.
            optimizer_state = checkpoint.get('optimizer_state_dict')
            optimizer = getattr(model, 'optimizer', None)
            if optimizer_state is not None and hasattr(optimizer, 'load_state_dict'):
                try:
                    optimizer.load_state_dict(optimizer_state)
                except Exception as e:
                    logger.warning(f"Could not restore optimizer state: {e}")

            return model

        # A checkpoint written by _save_state_dict must not be returned as a
        # raw dict when the caller asked for an instance: defer to the
        # dedicated loader instead of short-circuiting the loader chain.
        if model_class and 'state_dict' in checkpoint and 'model' not in checkpoint:
            return None

        return checkpoint.get('model', checkpoint)

    def _load_state_dict_only(self, model_dir: Path, model_name: str, model_class):
        """Load state dict only

        Returns:
            A ``model_class`` instance populated from the ``state_dict`` entry
            of the latest checkpoint, or None when the file, the class or the
            entry is missing.
        """
        main_path = model_dir / f"{model_name}_latest.pt"
        if not (main_path.exists() and model_class):
            return None

        checkpoint = torch.load(main_path, map_location='cpu')
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            model = model_class()
            model.load_state_dict(checkpoint['state_dict'])
            return model
        return None

    def _load_rl_components(self, model_dir: Path, model_name: str, model_class):
        """Load RL agent from components

        Requires the policy file and ``model_class``; the target network and
        agent-state files are applied only when they exist.

        Returns:
            A ``model_class`` instance, or None when the policy file or the
            class is missing.
        """
        policy_path = model_dir / f"{model_name}_policy.pt"
        target_path = model_dir / f"{model_name}_target.pt"
        state_path = model_dir / f"{model_name}_agent_state.pt"

        if not (policy_path.exists() and model_class):
            return None

        model = model_class()

        # Load policy network
        if hasattr(model, 'policy_net'):
            model.policy_net.load_state_dict(
                torch.load(policy_path, map_location='cpu'))

        # Load target network
        if target_path.exists() and hasattr(model, 'target_net'):
            model.target_net.load_state_dict(
                torch.load(target_path, map_location='cpu'))

        # Load agent state; only attributes the model already has are set.
        if state_path.exists():
            agent_state = torch.load(state_path, map_location='cpu')
            for key, value in agent_state.items():
                if hasattr(model, key):
                    setattr(model, key, value)

        return model

    def _load_pickle_fallback(self, model_dir: Path, model_name: str, model_class):
        """Load from pickle

        SECURITY: unpickling executes code embedded in the file. Only use
        this on checkpoint directories whose contents are trusted.

        Returns:
            The unpickled object, or None when no ``_latest.pkl`` exists.
        """
        pickle_path = model_dir / f"{model_name}_latest.pkl"
        if not pickle_path.exists():
            return None

        with open(pickle_path, 'rb') as f:
            return pickle.load(f)


# Global instance for easy access
_improved_model_saver = None


def get_improved_model_saver() -> ImprovedModelSaver:
    """Get or create the global improved model saver instance"""
    global _improved_model_saver
    if _improved_model_saver is None:
        _improved_model_saver = ImprovedModelSaver()
    return _improved_model_saver