""" Checkpoint Manager This module provides functionality for managing model checkpoints, including: - Saving checkpoints with metadata - Loading the best checkpoint based on performance metrics - Cleaning up old or underperforming checkpoints - Database-backed metadata storage for efficient access """ import os import json import glob import logging import shutil import torch import hashlib from datetime import datetime from typing import Dict, List, Optional, Any, Tuple from .database_manager import get_database_manager, CheckpointMetadata from .text_logger import get_text_logger logger = logging.getLogger(__name__) # Global checkpoint manager instance _checkpoint_manager_instance = None def get_checkpoint_manager(checkpoint_dir: str = "models/checkpoints", max_checkpoints: int = 10, metric_name: str = "accuracy") -> 'CheckpointManager': """ Get the global checkpoint manager instance Args: checkpoint_dir: Directory to store checkpoints max_checkpoints: Maximum number of checkpoints to keep metric_name: Metric to use for ranking checkpoints Returns: CheckpointManager: Global checkpoint manager instance """ global _checkpoint_manager_instance if _checkpoint_manager_instance is None: _checkpoint_manager_instance = CheckpointManager( checkpoint_dir=checkpoint_dir, max_checkpoints=max_checkpoints, metric_name=metric_name ) return _checkpoint_manager_instance def save_checkpoint(model, model_name: str, model_type: str, performance_metrics: Dict[str, float], training_metadata: Dict[str, Any] = None, checkpoint_dir: str = "models/checkpoints") -> Any: """ Save a checkpoint with metadata to both filesystem and database Args: model: The model to save model_name: Name of the model model_type: Type of the model ('cnn', 'rl', etc.) performance_metrics: Performance metrics training_metadata: Additional training metadata checkpoint_dir: Directory to store checkpoints Returns: Any: Checkpoint metadata """ try: # Create checkpoint directory os.makedirs(checkpoint_dir, exist_ok=True) # Create timestamp timestamp = datetime.now() timestamp_str = timestamp.strftime("%Y%m%d_%H%M%S") # Create checkpoint path model_dir = os.path.join(checkpoint_dir, model_name) os.makedirs(model_dir, exist_ok=True) checkpoint_path = os.path.join(model_dir, f"{model_name}_{timestamp_str}") checkpoint_id = f"{model_name}_{timestamp_str}" # Save model torch_path = f"{checkpoint_path}.pt" if hasattr(model, 'save'): # Use model's save method if available model.save(checkpoint_path) else: # Otherwise, save state_dict torch.save({ 'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else None, 'model_name': model_name, 'model_type': model_type, 'timestamp': timestamp_str, 'checkpoint_id': checkpoint_id }, torch_path) # Calculate file size file_size_mb = os.path.getsize(torch_path) / (1024 * 1024) if os.path.exists(torch_path) else 0.0 # Save metadata to database db_manager = get_database_manager() checkpoint_metadata = CheckpointMetadata( checkpoint_id=checkpoint_id, model_name=model_name, model_type=model_type, timestamp=timestamp, performance_metrics=performance_metrics, training_metadata=training_metadata or {}, file_path=torch_path, file_size_mb=file_size_mb, is_active=True # New checkpoint is active by default ) # Save to database if db_manager.save_checkpoint_metadata(checkpoint_metadata): # Log checkpoint save event to text file text_logger = get_text_logger() text_logger.log_checkpoint_event( model_name=model_name, event_type="SAVED", checkpoint_id=checkpoint_id, 
details=f"loss={performance_metrics.get('loss', 'N/A')}, size={file_size_mb:.1f}MB" ) logger.info(f"Checkpoint saved: {checkpoint_id}") else: logger.warning(f"Failed to save checkpoint metadata to database: {checkpoint_id}") # Also save legacy JSON metadata for backward compatibility legacy_metadata = { 'model_name': model_name, 'model_type': model_type, 'timestamp': timestamp_str, 'performance_metrics': performance_metrics, 'training_metadata': training_metadata or {}, 'checkpoint_id': checkpoint_id, 'performance_score': performance_metrics.get('accuracy', performance_metrics.get('reward', 0.0)), 'created_at': timestamp_str } with open(f"{checkpoint_path}_metadata.json", 'w') as f: json.dump(legacy_metadata, f, indent=2) # Get checkpoint manager and clean up old checkpoints checkpoint_manager = get_checkpoint_manager(checkpoint_dir=checkpoint_dir) checkpoint_manager._cleanup_checkpoints(model_name) # Return metadata as an object for backward compatibility class CheckpointMetadataObj: def __init__(self, metadata): for key, value in metadata.items(): setattr(self, key, value) # Add database fields self.checkpoint_id = checkpoint_id self.loss = performance_metrics.get('loss', performance_metrics.get('accuracy', 0.0)) return CheckpointMetadataObj(legacy_metadata) except Exception as e: logger.error(f"Error saving checkpoint: {e}") return None def load_best_checkpoint(model_name: str, checkpoint_dir: str = "models/checkpoints") -> Optional[Tuple[str, Any]]: """ Load the best checkpoint based on performance metrics using database metadata Args: model_name: Name of the model checkpoint_dir: Directory to store checkpoints Returns: Optional[Tuple[str, Any]]: Path to the best checkpoint and its metadata, or None if not found """ try: # First try to get from database (fast metadata access) db_manager = get_database_manager() checkpoint_metadata = db_manager.get_best_checkpoint_metadata(model_name, "accuracy") if not checkpoint_metadata: # Fallback to legacy file-based approach (no more scattered "No checkpoints found" logs) pass # Silent fallback checkpoint_manager = get_checkpoint_manager(checkpoint_dir=checkpoint_dir) checkpoint_path, legacy_metadata = checkpoint_manager.load_best_checkpoint(model_name) if not checkpoint_path: return None # Convert legacy metadata to object class CheckpointMetadataObj: def __init__(self, metadata): for key, value in metadata.items(): setattr(self, key, value) # Add performance score if not present if not hasattr(self, 'performance_score'): metrics = getattr(self, 'metrics', {}) primary_metric = 'accuracy' if 'accuracy' in metrics else 'reward' self.performance_score = metrics.get(primary_metric, 0.0) # Add created_at if not present if not hasattr(self, 'created_at'): self.created_at = getattr(self, 'timestamp', 'unknown') # Add loss for compatibility self.loss = metrics.get('loss', self.performance_score) self.checkpoint_id = getattr(self, 'checkpoint_id', f"{model_name}_unknown") return f"{checkpoint_path}.pt", CheckpointMetadataObj(legacy_metadata) # Check if checkpoint file exists if not os.path.exists(checkpoint_metadata.file_path): logger.warning(f"Checkpoint file not found: {checkpoint_metadata.file_path}") return None # Log checkpoint load event to text file text_logger = get_text_logger() text_logger.log_checkpoint_event( model_name=model_name, event_type="LOADED", checkpoint_id=checkpoint_metadata.checkpoint_id, details=f"loss={checkpoint_metadata.performance_metrics.get('loss', 'N/A')}" ) # Convert database metadata to object for backward compatibility 
        class CheckpointMetadataObj:
            def __init__(self, db_metadata: CheckpointMetadata):
                self.checkpoint_id = db_metadata.checkpoint_id
                self.model_name = db_metadata.model_name
                self.model_type = db_metadata.model_type
                self.timestamp = db_metadata.timestamp.strftime("%Y%m%d_%H%M%S")
                self.performance_metrics = db_metadata.performance_metrics
                self.training_metadata = db_metadata.training_metadata
                self.file_path = db_metadata.file_path
                self.file_size_mb = db_metadata.file_size_mb
                self.is_active = db_metadata.is_active

                # Backward compatibility fields
                self.metrics = db_metadata.performance_metrics
                self.metadata = db_metadata.training_metadata
                self.created_at = self.timestamp
                self.performance_score = db_metadata.performance_metrics.get(
                    'accuracy', db_metadata.performance_metrics.get('reward', 0.0))
                self.loss = db_metadata.performance_metrics.get('loss', self.performance_score)

        return checkpoint_metadata.file_path, CheckpointMetadataObj(checkpoint_metadata)

    except Exception as e:
        logger.error(f"Error loading best checkpoint: {e}")
        return None


class CheckpointManager:
    """
    Manages model checkpoints with performance-based optimization

    This class:
    1. Saves checkpoints with metadata
    2. Loads the best checkpoint based on performance metrics
    3. Cleans up old or underperforming checkpoints
    """

    def __init__(self, checkpoint_dir: str, max_checkpoints: int = 10, metric_name: str = "accuracy"):
        """
        Initialize the checkpoint manager

        Args:
            checkpoint_dir: Directory to store checkpoints
            max_checkpoints: Maximum number of checkpoints to keep
            metric_name: Metric to use for ranking checkpoints
        """
        self.checkpoint_dir = checkpoint_dir
        self.max_checkpoints = max_checkpoints
        self.metric_name = metric_name

        # Create checkpoint directory if it doesn't exist
        os.makedirs(checkpoint_dir, exist_ok=True)

        logger.info(f"CheckpointManager initialized with checkpoint_dir: {checkpoint_dir}")

    def save_checkpoint(self, model_name: str, model_path: str,
                        metrics: Dict[str, float], metadata: Dict[str, Any] = None) -> str:
        """
        Save a checkpoint with metadata

        Args:
            model_name: Name of the model
            model_path: Path to the model file
            metrics: Performance metrics
            metadata: Additional metadata

        Returns:
            str: Path to the saved checkpoint
        """
        try:
            # Create timestamp
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

            # Create checkpoint directory
            checkpoint_dir = os.path.join(self.checkpoint_dir, model_name)
            os.makedirs(checkpoint_dir, exist_ok=True)

            # Create checkpoint path
            checkpoint_path = os.path.join(checkpoint_dir, f"{model_name}_{timestamp}")

            # Copy model file to checkpoint path
            shutil.copy2(model_path, f"{checkpoint_path}.pt")

            # Create metadata
            checkpoint_metadata = {
                'model_name': model_name,
                'timestamp': timestamp,
                'metrics': metrics,
                'metadata': metadata or {}
            }

            # Save metadata
            with open(f"{checkpoint_path}_metadata.json", 'w') as f:
                json.dump(checkpoint_metadata, f, indent=2)

            logger.info(f"Saved checkpoint to {checkpoint_path}")

            # Clean up old checkpoints
            self._cleanup_checkpoints(model_name)

            return checkpoint_path

        except Exception as e:
            logger.error(f"Error saving checkpoint: {e}")
            return ""

    def load_best_checkpoint(self, model_name: str) -> Tuple[str, Dict[str, Any]]:
        """
        Load the best checkpoint based on performance metrics

        Args:
            model_name: Name of the model

        Returns:
            Tuple[str, Dict[str, Any]]: Path to the best checkpoint and its metadata
        """
        try:
            # Find all checkpoint metadata files
            checkpoint_dir = os.path.join(self.checkpoint_dir, model_name)
            metadata_files = glob.glob(os.path.join(checkpoint_dir, f"{model_name}_*_metadata.json"))

            if not metadata_files:
                # No more scattered "No checkpoints found" logs - handled by database system
                return "", {}

            # Load metadata for each checkpoint
            checkpoints = []
            for metadata_file in metadata_files:
                try:
                    with open(metadata_file, 'r') as f:
                        metadata = json.load(f)

                    # Get checkpoint path (remove _metadata.json)
                    checkpoint_path = metadata_file[:-14]

                    # Check if model file exists
                    if not os.path.exists(f"{checkpoint_path}.pt"):
                        logger.warning(f"Model file not found for checkpoint {checkpoint_path}")
                        continue

                    checkpoints.append((checkpoint_path, metadata))
                except Exception as e:
                    logger.error(f"Error loading checkpoint metadata {metadata_file}: {e}")

            if not checkpoints:
                # No more scattered logs - handled by database system
                return "", {}

            # Sort by metric (highest first)
            checkpoints.sort(key=lambda x: x[1].get('metrics', {}).get(self.metric_name, 0.0), reverse=True)

            # Return best checkpoint
            best_checkpoint_path = checkpoints[0][0]
            best_checkpoint_metadata = checkpoints[0][1]

            logger.info(f"Best checkpoint for {model_name}: {best_checkpoint_path}")
            return best_checkpoint_path, best_checkpoint_metadata

        except Exception as e:
            logger.error(f"Error loading best checkpoint: {e}")
            return "", {}

    def _cleanup_checkpoints(self, model_name: str) -> int:
        """
        Clean up old or underperforming checkpoints

        Args:
            model_name: Name of the model

        Returns:
            int: Number of checkpoints deleted
        """
        try:
            # Find all checkpoint metadata files
            checkpoint_dir = os.path.join(self.checkpoint_dir, model_name)
            metadata_files = glob.glob(os.path.join(checkpoint_dir, f"{model_name}_*_metadata.json"))

            if not metadata_files or len(metadata_files) <= self.max_checkpoints:
                return 0

            # Load metadata for each checkpoint
            checkpoints = []
            for metadata_file in metadata_files:
                try:
                    with open(metadata_file, 'r') as f:
                        metadata = json.load(f)

                    # Get checkpoint path (remove _metadata.json)
                    checkpoint_path = metadata_file[:-14]
                    checkpoints.append((checkpoint_path, metadata))
                except Exception as e:
                    logger.error(f"Error loading checkpoint metadata {metadata_file}: {e}")

            # Sort by metric (highest first)
            checkpoints.sort(key=lambda x: x[1].get('metrics', {}).get(self.metric_name, 0.0), reverse=True)

            # Keep only the best checkpoints
            checkpoints_to_delete = checkpoints[self.max_checkpoints:]

            # Delete checkpoints
            deleted_count = 0
            for checkpoint_path, _ in checkpoints_to_delete:
                try:
                    # Delete model file
                    if os.path.exists(f"{checkpoint_path}.pt"):
                        os.remove(f"{checkpoint_path}.pt")

                    # Delete metadata file
                    if os.path.exists(f"{checkpoint_path}_metadata.json"):
                        os.remove(f"{checkpoint_path}_metadata.json")

                    deleted_count += 1
                except Exception as e:
                    logger.error(f"Error deleting checkpoint {checkpoint_path}: {e}")

            logger.info(f"Deleted {deleted_count} old checkpoints for {model_name}")
            return deleted_count

        except Exception as e:
            logger.error(f"Error cleaning up checkpoints: {e}")
            return 0

    def get_all_checkpoints(self, model_name: str) -> List[Tuple[str, Dict[str, Any]]]:
        """
        Get all checkpoints for a model

        Args:
            model_name: Name of the model

        Returns:
            List[Tuple[str, Dict[str, Any]]]: List of checkpoint paths and metadata
        """
        try:
            # Find all checkpoint metadata files
            checkpoint_dir = os.path.join(self.checkpoint_dir, model_name)
            metadata_files = glob.glob(os.path.join(checkpoint_dir, f"{model_name}_*_metadata.json"))

            if not metadata_files:
                return []

            # Load metadata for each checkpoint
            checkpoints = []
            for metadata_file in metadata_files:
                try:
                    with open(metadata_file, 'r') as f:
                        metadata = json.load(f)

                    # Get checkpoint path (remove _metadata.json)
                    checkpoint_path = metadata_file[:-14]
                    # Check if model file exists
                    if not os.path.exists(f"{checkpoint_path}.pt"):
                        logger.warning(f"Model file not found for checkpoint {checkpoint_path}")
                        continue

                    checkpoints.append((checkpoint_path, metadata))
                except Exception as e:
                    logger.error(f"Error loading checkpoint metadata {metadata_file}: {e}")

            # Sort by timestamp (newest first)
            checkpoints.sort(key=lambda x: x[1].get('timestamp', ''), reverse=True)

            return checkpoints

        except Exception as e:
            logger.error(f"Error getting all checkpoints: {e}")
            return []

    def get_checkpoint_stats(self) -> Dict[str, Any]:
        """
        Get statistics about all checkpoints

        Returns:
            Dict[str, Any]: Statistics about checkpoints
        """
        try:
            stats = {
                'total_checkpoints': 0,
                'total_size_mb': 0.0,
                'models': {}
            }

            # Iterate through all model directories
            for model_dir in os.listdir(self.checkpoint_dir):
                model_path = os.path.join(self.checkpoint_dir, model_dir)
                if not os.path.isdir(model_path):
                    continue

                # Count checkpoints for this model
                checkpoint_files = glob.glob(os.path.join(model_path, f"{model_dir}_*.pt"))
                model_checkpoints = len(checkpoint_files)

                # Calculate total size for this model
                model_size_mb = 0.0
                for checkpoint_file in checkpoint_files:
                    try:
                        size_bytes = os.path.getsize(checkpoint_file)
                        model_size_mb += size_bytes / (1024 * 1024)  # Convert to MB
                    except OSError:
                        pass

                stats['models'][model_dir] = {
                    'checkpoints': model_checkpoints,
                    'size_mb': round(model_size_mb, 2)
                }
                stats['total_checkpoints'] += model_checkpoints
                stats['total_size_mb'] += model_size_mb

            stats['total_size_mb'] = round(stats['total_size_mb'], 2)
            return stats

        except Exception as e:
            logger.error(f"Error getting checkpoint stats: {e}")
            return {
                'total_checkpoints': 0,
                'total_size_mb': 0.0,
                'models': {}
            }
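

if __name__ == "__main__":
    # Minimal usage sketch, not part of the production pipeline. It assumes the
    # sibling database_manager and text_logger modules are importable (run via
    # `python -m <package>.checkpoint_manager` rather than as a bare script,
    # because of the relative imports above). The model, names, and metric
    # values below are illustrative placeholders, not real training output.
    logging.basicConfig(level=logging.INFO)

    demo_model = torch.nn.Linear(4, 2)
    demo_metrics = {"accuracy": 0.91, "loss": 0.27}

    # Save a checkpoint: the state_dict branch is used because nn.Linear has no save()
    saved = save_checkpoint(
        demo_model,
        model_name="demo_model",
        model_type="cnn",
        performance_metrics=demo_metrics,
        training_metadata={"epoch": 1},
    )
    print("saved checkpoint:", getattr(saved, "checkpoint_id", None))

    # Load the best checkpoint back (database first, legacy JSON fallback)
    loaded = load_best_checkpoint("demo_model")
    if loaded is not None:
        best_path, best_meta = loaded
        print("best checkpoint:", best_path, "loss:", best_meta.loss)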