"""
Checkpoint Manager

This module provides functionality for managing model checkpoints, including:
- Saving checkpoints with metadata
- Loading the best checkpoint based on performance metrics
- Cleaning up old or underperforming checkpoints

Each checkpoint lives in ``<checkpoint_dir>/<model_name>/`` under a base path
``<model_name>_<YYYYmmdd_HHMMSS>`` (with a ``_<n>`` suffix when several
checkpoints are saved within the same second). The model file is
``<base>.pt`` and the metadata file is ``<base>_metadata.json``.

Two metadata layouts exist:

* ``CheckpointManager.save_checkpoint`` stores metrics under ``'metrics'``.
* The module-level ``save_checkpoint`` stores them under
  ``'performance_metrics'`` together with a precomputed ``'performance_score'``.

Ranking (best-checkpoint selection and cleanup) understands both layouts.
"""

import os
import json
import glob
import logging
import shutil
import torch
from datetime import datetime
from typing import Dict, List, Optional, Any, Tuple

logger = logging.getLogger(__name__)

# Appended to a checkpoint's base path to name its metadata file.
_METADATA_SUFFIX = "_metadata.json"

# Global checkpoint manager instance
_checkpoint_manager_instance = None


def get_checkpoint_manager(checkpoint_dir: str = "models/checkpoints",
                           max_checkpoints: int = 10,
                           metric_name: str = "accuracy") -> 'CheckpointManager':
    """
    Get the global checkpoint manager instance

    The instance is created on the first call. Once it exists, the arguments
    of later calls are ignored and the existing instance is returned as-is.

    Args:
        checkpoint_dir: Directory to store checkpoints (used on first call only)
        max_checkpoints: Maximum number of checkpoints to keep (first call only)
        metric_name: Metric to use for ranking checkpoints (first call only)

    Returns:
        CheckpointManager: Global checkpoint manager instance
    """
    global _checkpoint_manager_instance
    if _checkpoint_manager_instance is None:
        _checkpoint_manager_instance = CheckpointManager(
            checkpoint_dir=checkpoint_dir,
            max_checkpoints=max_checkpoints,
            metric_name=metric_name
        )
    return _checkpoint_manager_instance


def _manager_for_dir(checkpoint_dir: str) -> 'CheckpointManager':
    """Return a CheckpointManager rooted at *checkpoint_dir*.

    The global instance is reused when it already points at that directory.
    Otherwise a dedicated manager is built with the global instance's
    ``max_checkpoints`` and ``metric_name``. Without this, an explicit
    ``checkpoint_dir`` would be silently ignored once the singleton exists.
    """
    manager = get_checkpoint_manager(checkpoint_dir=checkpoint_dir)
    if os.path.abspath(manager.checkpoint_dir) == os.path.abspath(checkpoint_dir):
        return manager
    return CheckpointManager(
        checkpoint_dir=checkpoint_dir,
        max_checkpoints=manager.max_checkpoints,
        metric_name=manager.metric_name
    )


def _unique_checkpoint_path(model_dir: str, model_name: str, timestamp: str) -> str:
    """Return an unused checkpoint base path (without extension) in *model_dir*.

    Timestamps have one-second resolution, so a second save within the same
    second would overwrite the first. Such saves get ``_1``, ``_2``, ...
    appended instead. The result still matches ``<model_name>_*_metadata.json``.
    """
    base_path = os.path.join(model_dir, f"{model_name}_{timestamp}")
    candidate = base_path
    suffix = 0
    while any(os.path.exists(path) for path in (candidate,
                                                f"{candidate}.pt",
                                                f"{candidate}{_METADATA_SUFFIX}")):
        suffix += 1
        candidate = f"{base_path}_{suffix}"
    return candidate


def _json_default(value: Any) -> Any:
    """``json`` fallback that unwraps scalar wrappers exposing ``.item()``.

    This lets metrics such as 0-d tensors or numpy scalars be serialized.
    Anything else raises ``TypeError``, as ``json`` does by default.
    """
    item = getattr(value, 'item', None)
    if callable(item):
        return item()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def _serialize_metadata(metadata: Dict[str, Any]) -> str:
    """Render checkpoint metadata as indented JSON text."""
    return json.dumps(metadata, indent=2, default=_json_default)


class _CheckpointMetadata:
    """Attribute-style view of a checkpoint metadata dict.

    Every key becomes an attribute. When absent, ``performance_score`` and
    ``created_at`` are derived, so callers can rely on them for both layouts.
    ``performance_score`` comes from ``metrics['accuracy']`` (else
    ``metrics['reward']``, else 0.0). ``created_at`` comes from ``timestamp``.
    """

    def __init__(self, metadata: Dict[str, Any]):
        for key, value in metadata.items():
            setattr(self, key, value)
        if not hasattr(self, 'performance_score'):
            metrics = getattr(self, 'metrics', None) or {}
            primary_metric = 'accuracy' if 'accuracy' in metrics else 'reward'
            self.performance_score = metrics.get(primary_metric, 0.0)
        if not hasattr(self, 'created_at'):
            self.created_at = getattr(self, 'timestamp', 'unknown')

    def __repr__(self) -> str:
        return f"CheckpointMetadata({vars(self)!r})"


def save_checkpoint(model, model_name: str, model_type: str,
                    performance_metrics: Dict[str, float],
                    training_metadata: Optional[Dict[str, Any]] = None,
                    checkpoint_dir: str = "models/checkpoints") -> Any:
    """
    Save a checkpoint with metadata

    If the model has a ``save`` method, it is called with the checkpoint base
    path (no extension). Otherwise a dict containing the model's
    ``state_dict()`` (or None) is written with ``torch.save`` to
    ``<base>.pt``. Afterwards, old or underperforming checkpoints of
    *model_name* in *checkpoint_dir* are pruned.

    Args:
        model: The model to save
        model_name: Name of the model
        model_type: Type of the model ('cnn', 'rl', etc.)
        performance_metrics: Performance metrics (None is treated as empty)
        training_metadata: Additional training metadata
        checkpoint_dir: Directory to store checkpoints

    Returns:
        Any: Checkpoint metadata as an attribute object, or None on failure
    """
    try:
        performance_metrics = performance_metrics or {}

        os.makedirs(checkpoint_dir, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        model_dir = os.path.join(checkpoint_dir, model_name)
        os.makedirs(model_dir, exist_ok=True)
        checkpoint_path = _unique_checkpoint_path(model_dir, model_name, timestamp)

        # The score is used for sorting: accuracy if reported, else reward.
        primary_metric = 'accuracy' if 'accuracy' in performance_metrics else 'reward'
        checkpoint_metadata = {
            'model_name': model_name,
            'model_type': model_type,
            'timestamp': timestamp,
            'performance_metrics': performance_metrics,
            'training_metadata': training_metadata or {},
            'checkpoint_id': os.path.basename(checkpoint_path),
            'performance_score': performance_metrics.get(primary_metric, 0.0),
            'created_at': timestamp,
        }
        # Serialize before writing any checkpoint file, so unserializable
        # metadata cannot leave an orphaned model file or truncated JSON behind.
        payload = _serialize_metadata(checkpoint_metadata)

        if hasattr(model, 'save'):
            model.save(checkpoint_path)
            # The manager only recognises checkpoints that have a "<base>.pt" file.
            if not os.path.exists(f"{checkpoint_path}.pt"):
                logger.warning(
                    f"Model save() did not create {checkpoint_path}.pt; "
                    "this checkpoint will not be found by load_best_checkpoint"
                )
        else:
            torch.save({
                'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else None,
                'model_name': model_name,
                'model_type': model_type,
                'timestamp': timestamp
            }, f"{checkpoint_path}.pt")

        with open(f"{checkpoint_path}{_METADATA_SUFFIX}", 'w', encoding='utf-8') as f:
            f.write(payload)

        # Prune using a manager bound to *this* checkpoint_dir, not the singleton's.
        _manager_for_dir(checkpoint_dir)._cleanup_checkpoints(model_name)

        return _CheckpointMetadata(checkpoint_metadata)
    except Exception as e:
        logger.error(f"Error saving checkpoint: {e}")
        return None


def load_best_checkpoint(model_name: str,
                         checkpoint_dir: str = "models/checkpoints") -> Optional[Tuple[str, Any]]:
    """
    Locate the best checkpoint based on performance metrics

    No weights are loaded. The caller receives the model file path and
    metadata, and does the actual loading.

    Args:
        model_name: Name of the model
        checkpoint_dir: Directory where checkpoints are stored

    Returns:
        Optional[Tuple[str, Any]]: Path to the best checkpoint's ``.pt`` file and
        its metadata as an attribute object, or None if not found
    """
    try:
        checkpoint_manager = _manager_for_dir(checkpoint_dir)
        checkpoint_path, checkpoint_metadata = checkpoint_manager.load_best_checkpoint(model_name)
        if not checkpoint_path:
            return None
        return f"{checkpoint_path}.pt", _CheckpointMetadata(checkpoint_metadata)
    except Exception as e:
        logger.error(f"Error loading best checkpoint: {e}")
        return None


class CheckpointManager:
    """
    Manages model checkpoints with performance-based optimization

    This class:
    1. Saves checkpoints with metadata
    2. Finds the best checkpoint based on performance metrics
    3. Cleans up old or underperforming checkpoints

    Checkpoints are ranked by ``metric_name``, read from the metadata's
    ``'metrics'`` dict or, failing that, its ``'performance_metrics'`` dict.
    If the metric is missing, a stored ``'performance_score'`` is used (else 0).
    Ties go to the newer checkpoint.
    """

    def __init__(self, checkpoint_dir: str, max_checkpoints: int = 10, metric_name: str = "accuracy"):
        """
        Initialize the checkpoint manager

        Args:
            checkpoint_dir: Directory to store checkpoints (created if missing)
            max_checkpoints: Maximum number of checkpoints to keep per model
            metric_name: Metric to use for ranking checkpoints (higher is better)
        """
        self.checkpoint_dir = checkpoint_dir
        self.max_checkpoints = max_checkpoints
        self.metric_name = metric_name

        os.makedirs(checkpoint_dir, exist_ok=True)

        logger.info(f"CheckpointManager initialized with checkpoint_dir: {checkpoint_dir}")

    def save_checkpoint(self, model_name: str, model_path: str,
                        metrics: Dict[str, float],
                        metadata: Optional[Dict[str, Any]] = None) -> str:
        """
        Save a checkpoint with metadata

        The file at *model_path* is copied (with ``shutil.copy2``) to
        ``<base>.pt``, then old or underperforming checkpoints are pruned.

        Args:
            model_name: Name of the model
            model_path: Path to the model file
            metrics: Performance metrics
            metadata: Additional metadata

        Returns:
            str: Base path of the saved checkpoint (without ``.pt``), or ""
            on failure
        """
        try:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

            checkpoint_dir = os.path.join(self.checkpoint_dir, model_name)
            os.makedirs(checkpoint_dir, exist_ok=True)
            checkpoint_path = _unique_checkpoint_path(checkpoint_dir, model_name, timestamp)

            checkpoint_metadata = {
                'model_name': model_name,
                'timestamp': timestamp,
                'metrics': metrics,
                'metadata': metadata or {}
            }
            # Serialize before writing any checkpoint file, so unserializable
            # metadata cannot leave an orphaned copy or truncated JSON behind.
            payload = _serialize_metadata(checkpoint_metadata)

            shutil.copy2(model_path, f"{checkpoint_path}.pt")
            with open(f"{checkpoint_path}{_METADATA_SUFFIX}", 'w', encoding='utf-8') as f:
                f.write(payload)

            logger.info(f"Saved checkpoint to {checkpoint_path}")

            self._cleanup_checkpoints(model_name)

            return checkpoint_path
        except Exception as e:
            logger.error(f"Error saving checkpoint: {e}")
            return ""

    def _metadata_files(self, model_name: str) -> List[str]:
        """List the metadata files of *model_name* (unordered)."""
        model_dir = os.path.join(self.checkpoint_dir, model_name)
        # Escape user-supplied parts: '[', '*' or '?' would otherwise act as wildcards.
        pattern = os.path.join(glob.escape(model_dir),
                               f"{glob.escape(model_name)}_*{_METADATA_SUFFIX}")
        return glob.glob(pattern)

    @staticmethod
    def _read_checkpoints(metadata_files: List[str],
                          require_model_file: bool) -> List[Tuple[str, Dict[str, Any]]]:
        """Load ``(base_path, metadata)`` pairs from *metadata_files*.

        Unreadable or non-object metadata files are logged and skipped. When
        *require_model_file* is true, checkpoints without a ``<base>.pt`` file
        are also skipped, with a warning.
        """
        checkpoints = []
        for metadata_file in metadata_files:
            try:
                with open(metadata_file, 'r', encoding='utf-8') as f:
                    metadata = json.load(f)
            except Exception as e:
                logger.error(f"Error loading checkpoint metadata {metadata_file}: {e}")
                continue
            if not isinstance(metadata, dict):
                logger.error(f"Error loading checkpoint metadata {metadata_file}: expected a JSON object")
                continue

            checkpoint_path = metadata_file[:-len(_METADATA_SUFFIX)]
            if require_model_file and not os.path.exists(f"{checkpoint_path}.pt"):
                logger.warning(f"Model file not found for checkpoint {checkpoint_path}")
                continue
            checkpoints.append((checkpoint_path, metadata))
        return checkpoints

    def _checkpoint_score(self, metadata: Dict[str, Any]) -> float:
        """Score a checkpoint for ranking (higher is better).

        Supports both metadata layouts. Values that cannot be converted to
        float count as 0.0.
        """
        metrics = metadata.get('metrics') or metadata.get('performance_metrics') or {}
        value = metrics.get(self.metric_name) if isinstance(metrics, dict) else None
        if value is None:
            value = metadata.get('performance_score', 0.0)
        try:
            return float(value)
        except (TypeError, ValueError):
            return 0.0

    def _rank_key(self, checkpoint: Tuple[str, Dict[str, Any]]) -> Tuple[float, str, str]:
        """Sort key: score, then timestamp (newer wins), then path as tie-breaker."""
        checkpoint_path, metadata = checkpoint
        return (self._checkpoint_score(metadata), str(metadata.get('timestamp', '')), checkpoint_path)

    def load_best_checkpoint(self, model_name: str) -> Tuple[str, Dict[str, Any]]:
        """
        Find the best checkpoint based on performance metrics

        Only checkpoints whose ``.pt`` model file exists are considered.

        Args:
            model_name: Name of the model

        Returns:
            Tuple[str, Dict[str, Any]]: Base path of the best checkpoint (append
            ``.pt`` for the model file) and its metadata, or ``("", {})`` when
            none is available
        """
        try:
            metadata_files = self._metadata_files(model_name)
            if not metadata_files:
                logger.info(f"No checkpoints found for {model_name}")
                return "", {}

            checkpoints = self._read_checkpoints(metadata_files, require_model_file=True)
            if not checkpoints:
                logger.info(f"No valid checkpoints found for {model_name}")
                return "", {}

            best_checkpoint_path, best_checkpoint_metadata = max(checkpoints, key=self._rank_key)

            logger.info(f"Best checkpoint for {model_name}: {best_checkpoint_path}")

            return best_checkpoint_path, best_checkpoint_metadata
        except Exception as e:
            logger.error(f"Error loading best checkpoint: {e}")
            return "", {}

    def _cleanup_checkpoints(self, model_name: str) -> int:
        """
        Clean up old or underperforming checkpoints

        Keeps the ``max_checkpoints`` best-ranked checkpoints and deletes the
        ``.pt`` and metadata files of the rest. Metadata files whose model file
        is missing still take part in ranking.
        NOTE(review): files a model's own ``save()`` wrote at the bare base
        path are not removed here.

        Args:
            model_name: Name of the model

        Returns:
            int: Number of checkpoints deleted
        """
        try:
            metadata_files = self._metadata_files(model_name)
            if not metadata_files or len(metadata_files) <= self.max_checkpoints:
                return 0

            checkpoints = self._read_checkpoints(metadata_files, require_model_file=False)
            ranked = sorted(checkpoints, key=self._rank_key, reverse=True)

            deleted_count = 0
            for checkpoint_path, _ in ranked[self.max_checkpoints:]:
                try:
                    for path in (f"{checkpoint_path}.pt", f"{checkpoint_path}{_METADATA_SUFFIX}"):
                        if os.path.exists(path):
                            os.remove(path)
                    deleted_count += 1
                except Exception as e:
                    logger.error(f"Error deleting checkpoint {checkpoint_path}: {e}")

            logger.info(f"Deleted {deleted_count} old checkpoints for {model_name}")

            return deleted_count
        except Exception as e:
            logger.error(f"Error cleaning up checkpoints: {e}")
            return 0

    def get_all_checkpoints(self, model_name: str) -> List[Tuple[str, Dict[str, Any]]]:
        """
        Get all checkpoints for a model

        Only checkpoints whose ``.pt`` model file exists are returned.

        Args:
            model_name: Name of the model

        Returns:
            List[Tuple[str, Dict[str, Any]]]: Checkpoint base paths and metadata,
            newest first
        """
        try:
            metadata_files = self._metadata_files(model_name)
            if not metadata_files:
                return []

            checkpoints = self._read_checkpoints(metadata_files, require_model_file=True)
            checkpoints.sort(key=lambda item: (str(item[1].get('timestamp', '')), item[0]),
                             reverse=True)

            return checkpoints
        except Exception as e:
            logger.error(f"Error getting all checkpoints: {e}")
            return []