408 lines
15 KiB
Python
408 lines
15 KiB
Python
"""
|
|
Checkpoint Manager
|
|
|
|
This module provides functionality for managing model checkpoints, including:
|
|
- Saving checkpoints with metadata
|
|
- Loading the best checkpoint based on performance metrics
|
|
- Cleaning up old or underperforming checkpoints
|
|
"""
|
|
|
|
import os
|
|
import json
|
|
import glob
|
|
import logging
|
|
import shutil
|
|
import torch
|
|
from datetime import datetime
|
|
from typing import Dict, List, Optional, Any, Tuple
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Global checkpoint manager instance
|
|
_checkpoint_manager_instance = None
|
|
|
|
def get_checkpoint_manager(checkpoint_dir: str = "models/checkpoints", max_checkpoints: int = 10, metric_name: str = "accuracy") -> 'CheckpointManager':
|
|
"""
|
|
Get the global checkpoint manager instance
|
|
|
|
Args:
|
|
checkpoint_dir: Directory to store checkpoints
|
|
max_checkpoints: Maximum number of checkpoints to keep
|
|
metric_name: Metric to use for ranking checkpoints
|
|
|
|
Returns:
|
|
CheckpointManager: Global checkpoint manager instance
|
|
"""
|
|
global _checkpoint_manager_instance
|
|
|
|
if _checkpoint_manager_instance is None:
|
|
_checkpoint_manager_instance = CheckpointManager(
|
|
checkpoint_dir=checkpoint_dir,
|
|
max_checkpoints=max_checkpoints,
|
|
metric_name=metric_name
|
|
)
|
|
|
|
return _checkpoint_manager_instance
|
|
|
|
def save_checkpoint(model, model_name: str, model_type: str, performance_metrics: Dict[str, float], training_metadata: Dict[str, Any] = None, checkpoint_dir: str = "models/checkpoints") -> Any:
|
|
"""
|
|
Save a checkpoint with metadata
|
|
|
|
Args:
|
|
model: The model to save
|
|
model_name: Name of the model
|
|
model_type: Type of the model ('cnn', 'rl', etc.)
|
|
performance_metrics: Performance metrics
|
|
training_metadata: Additional training metadata
|
|
checkpoint_dir: Directory to store checkpoints
|
|
|
|
Returns:
|
|
Any: Checkpoint metadata
|
|
"""
|
|
try:
|
|
# Create checkpoint directory
|
|
os.makedirs(checkpoint_dir, exist_ok=True)
|
|
|
|
# Create timestamp
|
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
|
|
# Create checkpoint path
|
|
model_dir = os.path.join(checkpoint_dir, model_name)
|
|
os.makedirs(model_dir, exist_ok=True)
|
|
checkpoint_path = os.path.join(model_dir, f"{model_name}_{timestamp}")
|
|
|
|
# Save model
|
|
if hasattr(model, 'save'):
|
|
# Use model's save method if available
|
|
model.save(checkpoint_path)
|
|
else:
|
|
# Otherwise, save state_dict
|
|
torch_path = f"{checkpoint_path}.pt"
|
|
torch.save({
|
|
'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else None,
|
|
'model_name': model_name,
|
|
'model_type': model_type,
|
|
'timestamp': timestamp
|
|
}, torch_path)
|
|
|
|
# Create metadata
|
|
checkpoint_metadata = {
|
|
'model_name': model_name,
|
|
'model_type': model_type,
|
|
'timestamp': timestamp,
|
|
'performance_metrics': performance_metrics,
|
|
'training_metadata': training_metadata or {},
|
|
'checkpoint_id': f"{model_name}_{timestamp}"
|
|
}
|
|
|
|
# Add performance score for sorting
|
|
primary_metric = 'accuracy' if 'accuracy' in performance_metrics else 'reward'
|
|
checkpoint_metadata['performance_score'] = performance_metrics.get(primary_metric, 0.0)
|
|
checkpoint_metadata['created_at'] = timestamp
|
|
|
|
# Save metadata
|
|
with open(f"{checkpoint_path}_metadata.json", 'w') as f:
|
|
json.dump(checkpoint_metadata, f, indent=2)
|
|
|
|
# Get checkpoint manager and clean up old checkpoints
|
|
checkpoint_manager = get_checkpoint_manager(checkpoint_dir=checkpoint_dir)
|
|
checkpoint_manager._cleanup_checkpoints(model_name)
|
|
|
|
# Return metadata as an object
|
|
class CheckpointMetadata:
|
|
def __init__(self, metadata):
|
|
for key, value in metadata.items():
|
|
setattr(self, key, value)
|
|
|
|
return CheckpointMetadata(checkpoint_metadata)
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error saving checkpoint: {e}")
|
|
return None
|
|
|
|
def load_best_checkpoint(model_name: str, checkpoint_dir: str = "models/checkpoints") -> Optional[Tuple[str, Any]]:
|
|
"""
|
|
Load the best checkpoint based on performance metrics
|
|
|
|
Args:
|
|
model_name: Name of the model
|
|
checkpoint_dir: Directory to store checkpoints
|
|
|
|
Returns:
|
|
Optional[Tuple[str, Any]]: Path to the best checkpoint and its metadata, or None if not found
|
|
"""
|
|
try:
|
|
checkpoint_manager = get_checkpoint_manager(checkpoint_dir=checkpoint_dir)
|
|
checkpoint_path, checkpoint_metadata = checkpoint_manager.load_best_checkpoint(model_name)
|
|
|
|
if not checkpoint_path:
|
|
return None
|
|
|
|
# Convert metadata to object
|
|
class CheckpointMetadata:
|
|
def __init__(self, metadata):
|
|
for key, value in metadata.items():
|
|
setattr(self, key, value)
|
|
|
|
# Add performance score if not present
|
|
if not hasattr(self, 'performance_score'):
|
|
metrics = getattr(self, 'metrics', {})
|
|
primary_metric = 'accuracy' if 'accuracy' in metrics else 'reward'
|
|
self.performance_score = metrics.get(primary_metric, 0.0)
|
|
|
|
# Add created_at if not present
|
|
if not hasattr(self, 'created_at'):
|
|
self.created_at = getattr(self, 'timestamp', 'unknown')
|
|
|
|
return f"{checkpoint_path}.pt", CheckpointMetadata(checkpoint_metadata)
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error loading best checkpoint: {e}")
|
|
return None
|
|
|
|
class CheckpointManager:
|
|
"""
|
|
Manages model checkpoints with performance-based optimization
|
|
|
|
This class:
|
|
1. Saves checkpoints with metadata
|
|
2. Loads the best checkpoint based on performance metrics
|
|
3. Cleans up old or underperforming checkpoints
|
|
"""
|
|
|
|
def __init__(self, checkpoint_dir: str, max_checkpoints: int = 10, metric_name: str = "accuracy"):
|
|
"""
|
|
Initialize the checkpoint manager
|
|
|
|
Args:
|
|
checkpoint_dir: Directory to store checkpoints
|
|
max_checkpoints: Maximum number of checkpoints to keep
|
|
metric_name: Metric to use for ranking checkpoints
|
|
"""
|
|
self.checkpoint_dir = checkpoint_dir
|
|
self.max_checkpoints = max_checkpoints
|
|
self.metric_name = metric_name
|
|
|
|
# Create checkpoint directory if it doesn't exist
|
|
os.makedirs(checkpoint_dir, exist_ok=True)
|
|
|
|
logger.info(f"CheckpointManager initialized with checkpoint_dir: {checkpoint_dir}")
|
|
|
|
def save_checkpoint(self, model_name: str, model_path: str, metrics: Dict[str, float], metadata: Dict[str, Any] = None) -> str:
|
|
"""
|
|
Save a checkpoint with metadata
|
|
|
|
Args:
|
|
model_name: Name of the model
|
|
model_path: Path to the model file
|
|
metrics: Performance metrics
|
|
metadata: Additional metadata
|
|
|
|
Returns:
|
|
str: Path to the saved checkpoint
|
|
"""
|
|
try:
|
|
# Create timestamp
|
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
|
|
# Create checkpoint directory
|
|
checkpoint_dir = os.path.join(self.checkpoint_dir, model_name)
|
|
os.makedirs(checkpoint_dir, exist_ok=True)
|
|
|
|
# Create checkpoint path
|
|
checkpoint_path = os.path.join(checkpoint_dir, f"{model_name}_{timestamp}")
|
|
|
|
# Copy model file to checkpoint path
|
|
shutil.copy2(model_path, f"{checkpoint_path}.pt")
|
|
|
|
# Create metadata
|
|
checkpoint_metadata = {
|
|
'model_name': model_name,
|
|
'timestamp': timestamp,
|
|
'metrics': metrics,
|
|
'metadata': metadata or {}
|
|
}
|
|
|
|
# Save metadata
|
|
with open(f"{checkpoint_path}_metadata.json", 'w') as f:
|
|
json.dump(checkpoint_metadata, f, indent=2)
|
|
|
|
logger.info(f"Saved checkpoint to {checkpoint_path}")
|
|
|
|
# Clean up old checkpoints
|
|
self._cleanup_checkpoints(model_name)
|
|
|
|
return checkpoint_path
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error saving checkpoint: {e}")
|
|
return ""
|
|
|
|
def load_best_checkpoint(self, model_name: str) -> Tuple[str, Dict[str, Any]]:
|
|
"""
|
|
Load the best checkpoint based on performance metrics
|
|
|
|
Args:
|
|
model_name: Name of the model
|
|
|
|
Returns:
|
|
Tuple[str, Dict[str, Any]]: Path to the best checkpoint and its metadata
|
|
"""
|
|
try:
|
|
# Find all checkpoint metadata files
|
|
checkpoint_dir = os.path.join(self.checkpoint_dir, model_name)
|
|
metadata_files = glob.glob(os.path.join(checkpoint_dir, f"{model_name}_*_metadata.json"))
|
|
|
|
if not metadata_files:
|
|
logger.info(f"No checkpoints found for {model_name}")
|
|
return "", {}
|
|
|
|
# Load metadata for each checkpoint
|
|
checkpoints = []
|
|
for metadata_file in metadata_files:
|
|
try:
|
|
with open(metadata_file, 'r') as f:
|
|
metadata = json.load(f)
|
|
|
|
# Get checkpoint path (remove _metadata.json)
|
|
checkpoint_path = metadata_file[:-14]
|
|
|
|
# Check if model file exists
|
|
if not os.path.exists(f"{checkpoint_path}.pt"):
|
|
logger.warning(f"Model file not found for checkpoint {checkpoint_path}")
|
|
continue
|
|
|
|
checkpoints.append((checkpoint_path, metadata))
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error loading checkpoint metadata {metadata_file}: {e}")
|
|
|
|
if not checkpoints:
|
|
logger.info(f"No valid checkpoints found for {model_name}")
|
|
return "", {}
|
|
|
|
# Sort by metric (highest first)
|
|
checkpoints.sort(key=lambda x: x[1].get('metrics', {}).get(self.metric_name, 0.0), reverse=True)
|
|
|
|
# Return best checkpoint
|
|
best_checkpoint_path = checkpoints[0][0]
|
|
best_checkpoint_metadata = checkpoints[0][1]
|
|
|
|
logger.info(f"Best checkpoint for {model_name}: {best_checkpoint_path}")
|
|
|
|
return best_checkpoint_path, best_checkpoint_metadata
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error loading best checkpoint: {e}")
|
|
return "", {}
|
|
|
|
def _cleanup_checkpoints(self, model_name: str) -> int:
|
|
"""
|
|
Clean up old or underperforming checkpoints
|
|
|
|
Args:
|
|
model_name: Name of the model
|
|
|
|
Returns:
|
|
int: Number of checkpoints deleted
|
|
"""
|
|
try:
|
|
# Find all checkpoint metadata files
|
|
checkpoint_dir = os.path.join(self.checkpoint_dir, model_name)
|
|
metadata_files = glob.glob(os.path.join(checkpoint_dir, f"{model_name}_*_metadata.json"))
|
|
|
|
if not metadata_files or len(metadata_files) <= self.max_checkpoints:
|
|
return 0
|
|
|
|
# Load metadata for each checkpoint
|
|
checkpoints = []
|
|
for metadata_file in metadata_files:
|
|
try:
|
|
with open(metadata_file, 'r') as f:
|
|
metadata = json.load(f)
|
|
|
|
# Get checkpoint path (remove _metadata.json)
|
|
checkpoint_path = metadata_file[:-14]
|
|
|
|
checkpoints.append((checkpoint_path, metadata))
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error loading checkpoint metadata {metadata_file}: {e}")
|
|
|
|
# Sort by metric (highest first)
|
|
checkpoints.sort(key=lambda x: x[1].get('metrics', {}).get(self.metric_name, 0.0), reverse=True)
|
|
|
|
# Keep only the best checkpoints
|
|
checkpoints_to_delete = checkpoints[self.max_checkpoints:]
|
|
|
|
# Delete checkpoints
|
|
deleted_count = 0
|
|
for checkpoint_path, _ in checkpoints_to_delete:
|
|
try:
|
|
# Delete model file
|
|
if os.path.exists(f"{checkpoint_path}.pt"):
|
|
os.remove(f"{checkpoint_path}.pt")
|
|
|
|
# Delete metadata file
|
|
if os.path.exists(f"{checkpoint_path}_metadata.json"):
|
|
os.remove(f"{checkpoint_path}_metadata.json")
|
|
|
|
deleted_count += 1
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error deleting checkpoint {checkpoint_path}: {e}")
|
|
|
|
logger.info(f"Deleted {deleted_count} old checkpoints for {model_name}")
|
|
|
|
return deleted_count
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error cleaning up checkpoints: {e}")
|
|
return 0
|
|
|
|
def get_all_checkpoints(self, model_name: str) -> List[Tuple[str, Dict[str, Any]]]:
|
|
"""
|
|
Get all checkpoints for a model
|
|
|
|
Args:
|
|
model_name: Name of the model
|
|
|
|
Returns:
|
|
List[Tuple[str, Dict[str, Any]]]: List of checkpoint paths and metadata
|
|
"""
|
|
try:
|
|
# Find all checkpoint metadata files
|
|
checkpoint_dir = os.path.join(self.checkpoint_dir, model_name)
|
|
metadata_files = glob.glob(os.path.join(checkpoint_dir, f"{model_name}_*_metadata.json"))
|
|
|
|
if not metadata_files:
|
|
return []
|
|
|
|
# Load metadata for each checkpoint
|
|
checkpoints = []
|
|
for metadata_file in metadata_files:
|
|
try:
|
|
with open(metadata_file, 'r') as f:
|
|
metadata = json.load(f)
|
|
|
|
# Get checkpoint path (remove _metadata.json)
|
|
checkpoint_path = metadata_file[:-14]
|
|
|
|
# Check if model file exists
|
|
if not os.path.exists(f"{checkpoint_path}.pt"):
|
|
logger.warning(f"Model file not found for checkpoint {checkpoint_path}")
|
|
continue
|
|
|
|
checkpoints.append((checkpoint_path, metadata))
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error loading checkpoint metadata {metadata_file}: {e}")
|
|
|
|
# Sort by timestamp (newest first)
|
|
checkpoints.sort(key=lambda x: x[1].get('timestamp', ''), reverse=True)
|
|
|
|
return checkpoints
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error getting all checkpoints: {e}")
|
|
return [] |