Files
gogo2/utils/checkpoint_manager.py
Dobromir Popov fb72c93743 stability
2025-07-28 12:10:52 +03:00

547 lines
22 KiB
Python

"""
Checkpoint Manager
This module provides functionality for managing model checkpoints, including:
- Saving checkpoints with metadata
- Loading the best checkpoint based on performance metrics
- Cleaning up old or underperforming checkpoints
- Database-backed metadata storage for efficient access
"""
import os
import json
import glob
import logging
import shutil
import torch
import hashlib
from datetime import datetime
from typing import Dict, List, Optional, Any, Tuple
from .database_manager import get_database_manager, CheckpointMetadata
from .text_logger import get_text_logger
logger = logging.getLogger(__name__)

# Global checkpoint manager instance (created lazily on first request)
_checkpoint_manager_instance = None


def get_checkpoint_manager(checkpoint_dir: str = "models/checkpoints", max_checkpoints: int = 10, metric_name: str = "accuracy") -> 'CheckpointManager':
    """
    Return the process-wide CheckpointManager, creating it on first use.

    The arguments only take effect on the call that creates the instance; every
    later call returns the already-created manager, whatever arguments it passes.

    Args:
        checkpoint_dir: Directory to store checkpoints (honoured on first call only)
        max_checkpoints: Maximum number of checkpoints to keep (first call only)
        metric_name: Metric to use for ranking checkpoints (first call only)

    Returns:
        CheckpointManager: Global checkpoint manager instance
    """
    global _checkpoint_manager_instance
    if _checkpoint_manager_instance is not None:
        return _checkpoint_manager_instance
    manager = CheckpointManager(
        checkpoint_dir=checkpoint_dir,
        max_checkpoints=max_checkpoints,
        metric_name=metric_name,
    )
    _checkpoint_manager_instance = manager
    return manager
def save_checkpoint(model, model_name: str, model_type: str, performance_metrics: Dict[str, float], training_metadata: Dict[str, Any] = None, checkpoint_dir: str = "models/checkpoints") -> Any:
    """
    Save a checkpoint with metadata to both filesystem and database

    Writes the model under ``<checkpoint_dir>/<model_name>/<model_name>_<YYYYmmdd_HHMMSS>``,
    records a CheckpointMetadata row through the database manager, writes a legacy
    ``..._metadata.json`` sidecar next to the model, and finally prunes old
    checkpoints of ``model_name`` inside ``checkpoint_dir``.

    Args:
        model: The model to save (its own ``save`` method is used when present,
            otherwise its ``state_dict`` is written with ``torch.save``)
        model_name: Name of the model
        model_type: Type of the model ('cnn', 'rl', etc.)
        performance_metrics: Performance metrics
        training_metadata: Additional training metadata
        checkpoint_dir: Directory to store checkpoints

    Returns:
        Any: An object exposing the legacy metadata keys as attributes plus
        ``checkpoint_id`` and ``loss``, or None if saving failed (the error is logged).
    """
    try:
        performance_metrics = performance_metrics or {}
        training_metadata = training_metadata or {}

        # Create checkpoint directory
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Create timestamp
        timestamp = datetime.now()
        timestamp_str = timestamp.strftime("%Y%m%d_%H%M%S")

        # Create checkpoint path
        model_dir = os.path.join(checkpoint_dir, model_name)
        os.makedirs(model_dir, exist_ok=True)
        checkpoint_id = f"{model_name}_{timestamp_str}"
        checkpoint_path = os.path.join(model_dir, checkpoint_id)
        torch_path = f"{checkpoint_path}.pt"

        # Save model
        if hasattr(model, 'save'):
            # NOTE(review): the file name produced by model.save() is decided by the
            # model; if it does not create `torch_path`, file_size_mb below is 0.0 and
            # the database row points at a path that may not exist — confirm callers.
            model.save(checkpoint_path)
        else:
            # Otherwise, save state_dict
            torch.save({
                'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else None,
                'model_name': model_name,
                'model_type': model_type,
                'timestamp': timestamp_str,
                'checkpoint_id': checkpoint_id
            }, torch_path)

        # Calculate file size
        file_size_mb = os.path.getsize(torch_path) / (1024 * 1024) if os.path.exists(torch_path) else 0.0

        # Save metadata to database
        db_manager = get_database_manager()
        checkpoint_metadata = CheckpointMetadata(
            checkpoint_id=checkpoint_id,
            model_name=model_name,
            model_type=model_type,
            timestamp=timestamp,
            performance_metrics=performance_metrics,
            training_metadata=training_metadata,
            file_path=torch_path,
            file_size_mb=file_size_mb,
            is_active=True  # New checkpoint is active by default
        )

        if db_manager.save_checkpoint_metadata(checkpoint_metadata):
            # Log checkpoint save event to text file
            text_logger = get_text_logger()
            text_logger.log_checkpoint_event(
                model_name=model_name,
                event_type="SAVED",
                checkpoint_id=checkpoint_id,
                details=f"loss={performance_metrics.get('loss', 'N/A')}, size={file_size_mb:.1f}MB"
            )
            logger.info(f"Checkpoint saved: {checkpoint_id}")
        else:
            logger.warning(f"Failed to save checkpoint metadata to database: {checkpoint_id}")

        # Also save legacy JSON metadata for backward compatibility
        legacy_metadata = {
            'model_name': model_name,
            'model_type': model_type,
            'timestamp': timestamp_str,
            'performance_metrics': performance_metrics,
            'training_metadata': training_metadata,
            'checkpoint_id': checkpoint_id,
            'performance_score': performance_metrics.get('accuracy', performance_metrics.get('reward', 0.0)),
            'created_at': timestamp_str,
            # CheckpointManager ranks and prunes sidecars by their 'metrics' dict;
            # without this alias every checkpoint saved here would score 0.0.
            'metrics': performance_metrics,
        }
        # Serialize before opening the file so a non-JSON-serializable metric
        # cannot leave a truncated sidecar behind.
        legacy_json = json.dumps(legacy_metadata, indent=2)
        with open(f"{checkpoint_path}_metadata.json", 'w') as f:
            f.write(legacy_json)

        # Clean up old checkpoints in the directory we actually wrote to. The shared
        # manager only honours its arguments when first created, so it may be bound
        # to a different directory.
        checkpoint_manager = get_checkpoint_manager(checkpoint_dir=checkpoint_dir)
        if os.path.abspath(checkpoint_manager.checkpoint_dir) != os.path.abspath(checkpoint_dir):
            checkpoint_manager = CheckpointManager(
                checkpoint_dir=checkpoint_dir,
                max_checkpoints=checkpoint_manager.max_checkpoints,
                metric_name=checkpoint_manager.metric_name,
            )
        # Pruning ranks by metric, so the checkpoint just written can itself be
        # removed if it is among the worst.
        checkpoint_manager._cleanup_checkpoints(model_name)

        # Return metadata as an object for backward compatibility
        class CheckpointMetadataObj:
            def __init__(self, metadata):
                for key, value in metadata.items():
                    setattr(self, key, value)
                # Add database fields
                self.checkpoint_id = checkpoint_id
                self.loss = performance_metrics.get('loss', performance_metrics.get('accuracy', 0.0))

        return CheckpointMetadataObj(legacy_metadata)

    except Exception as e:
        logger.error(f"Error saving checkpoint: {e}")
        return None
def load_best_checkpoint(model_name: str, checkpoint_dir: str = "models/checkpoints") -> Optional[Tuple[str, Any]]:
    """
    Load the best checkpoint based on performance metrics using database metadata

    The database is asked for the best checkpoint ranked by ``accuracy``. If it has
    no record for ``model_name``, or the recorded file no longer exists on disk, the
    legacy JSON sidecars under ``checkpoint_dir`` are searched instead.

    Args:
        model_name: Name of the model
        checkpoint_dir: Directory to store checkpoints

    Returns:
        Optional[Tuple[str, Any]]: Path to the best checkpoint file and a metadata
        object (with at least ``checkpoint_id``, ``performance_score``,
        ``created_at`` and ``loss`` attributes), or None if nothing was found or an
        error occurred (the error is logged).
    """
    try:
        # First try to get from database (fast metadata access)
        db_manager = get_database_manager()
        checkpoint_metadata = db_manager.get_best_checkpoint_metadata(model_name, "accuracy")

        if checkpoint_metadata and not os.path.exists(checkpoint_metadata.file_path):
            # A stale row (e.g. its file was pruned from disk) must not hide
            # checkpoints that still exist, so fall through to the file-based search.
            logger.warning(f"Checkpoint file not found: {checkpoint_metadata.file_path}")
            checkpoint_metadata = None

        if not checkpoint_metadata:
            # Silent fallback to the legacy file-based approach. The shared manager
            # only honours its arguments when first created, so make sure we search
            # the requested directory.
            checkpoint_manager = get_checkpoint_manager(checkpoint_dir=checkpoint_dir)
            if os.path.abspath(checkpoint_manager.checkpoint_dir) != os.path.abspath(checkpoint_dir):
                checkpoint_manager = CheckpointManager(
                    checkpoint_dir=checkpoint_dir,
                    max_checkpoints=checkpoint_manager.max_checkpoints,
                    metric_name=checkpoint_manager.metric_name,
                )
            checkpoint_path, legacy_metadata = checkpoint_manager.load_best_checkpoint(model_name)
            if not checkpoint_path:
                return None

            # Convert legacy metadata to object
            class CheckpointMetadataObj:
                def __init__(self, metadata):
                    for key, value in metadata.items():
                        setattr(self, key, value)
                    # Sidecars keep metrics under 'metrics' (CheckpointManager) or
                    # 'performance_metrics' (module-level save_checkpoint); resolve
                    # them unconditionally so `loss` below always has a source.
                    metrics = metadata.get('metrics') or metadata.get('performance_metrics') or {}
                    if not isinstance(metrics, dict):
                        metrics = {}
                    if not hasattr(self, 'metrics'):
                        self.metrics = metrics
                    # Add performance score if not present
                    if not hasattr(self, 'performance_score'):
                        primary_metric = 'accuracy' if 'accuracy' in metrics else 'reward'
                        self.performance_score = metrics.get(primary_metric, 0.0)
                    # Add created_at if not present
                    if not hasattr(self, 'created_at'):
                        self.created_at = getattr(self, 'timestamp', 'unknown')
                    # Add loss for compatibility
                    self.loss = metrics.get('loss', self.performance_score)
                    self.checkpoint_id = getattr(self, 'checkpoint_id', f"{model_name}_unknown")

            return f"{checkpoint_path}.pt", CheckpointMetadataObj(legacy_metadata)

        # Log checkpoint load event to text file
        text_logger = get_text_logger()
        text_logger.log_checkpoint_event(
            model_name=model_name,
            event_type="LOADED",
            checkpoint_id=checkpoint_metadata.checkpoint_id,
            details=f"loss={checkpoint_metadata.performance_metrics.get('loss', 'N/A')}"
        )

        # Convert database metadata to object for backward compatibility
        class CheckpointMetadataObj:
            def __init__(self, db_metadata: CheckpointMetadata):
                self.checkpoint_id = db_metadata.checkpoint_id
                self.model_name = db_metadata.model_name
                self.model_type = db_metadata.model_type
                self.timestamp = db_metadata.timestamp.strftime("%Y%m%d_%H%M%S")
                self.performance_metrics = db_metadata.performance_metrics
                self.training_metadata = db_metadata.training_metadata
                self.file_path = db_metadata.file_path
                self.file_size_mb = db_metadata.file_size_mb
                self.is_active = db_metadata.is_active
                # Backward compatibility fields
                self.metrics = db_metadata.performance_metrics
                self.metadata = db_metadata.training_metadata
                self.created_at = self.timestamp
                self.performance_score = db_metadata.performance_metrics.get('accuracy',
                                                                             db_metadata.performance_metrics.get('reward', 0.0))
                self.loss = db_metadata.performance_metrics.get('loss', self.performance_score)

        return checkpoint_metadata.file_path, CheckpointMetadataObj(checkpoint_metadata)

    except Exception as e:
        logger.error(f"Error loading best checkpoint: {e}")
        return None
class CheckpointManager:
    """
    Manages model checkpoints with performance-based optimization

    Each checkpoint lives under ``<checkpoint_dir>/<model_name>/`` as a model file
    ``<model_name>_<timestamp>.pt`` plus a JSON sidecar
    ``<model_name>_<timestamp>_metadata.json``.

    This class:
    1. Saves checkpoints with metadata
    2. Loads the best checkpoint based on performance metrics
    3. Cleans up old or underperforming checkpoints

    Checkpoints are ranked by ``metric_name`` (higher is better). The value is read
    from the sidecar's ``'metrics'`` dict or, failing that, its
    ``'performance_metrics'`` dict. Missing or non-numeric values score 0.0, and
    equal scores are resolved in favour of the newer timestamp.
    """

    # Suffix of a sidecar file; stripping it from a sidecar path yields the
    # checkpoint base path (the model file is base path + ".pt").
    _METADATA_SUFFIX = "_metadata.json"

    def __init__(self, checkpoint_dir: str, max_checkpoints: int = 10, metric_name: str = "accuracy"):
        """
        Initialize the checkpoint manager

        Args:
            checkpoint_dir: Directory to store checkpoints (created if missing)
            max_checkpoints: Maximum number of checkpoints to keep per model
            metric_name: Metric to use for ranking checkpoints (higher is better)
        """
        self.checkpoint_dir = checkpoint_dir
        self.max_checkpoints = max_checkpoints
        self.metric_name = metric_name

        # Create checkpoint directory if it doesn't exist
        os.makedirs(checkpoint_dir, exist_ok=True)

        logger.info(f"CheckpointManager initialized with checkpoint_dir: {checkpoint_dir}")

    def _metadata_files(self, model_name: str) -> List[str]:
        """Return the sidecar paths for ``model_name`` (in no particular order)."""
        model_dir = os.path.join(self.checkpoint_dir, model_name)
        # Escape both parts so names containing '[', '*' or '?' match literally.
        pattern = os.path.join(glob.escape(model_dir), f"{glob.escape(model_name)}_*{self._METADATA_SUFFIX}")
        return glob.glob(pattern)

    def _read_checkpoints(self, metadata_files: List[str], require_model_file: bool) -> List[Tuple[str, Dict[str, Any]]]:
        """
        Parse sidecar files into ``(checkpoint_path, metadata)`` pairs.

        Unreadable or non-object sidecars are logged and skipped. When
        ``require_model_file`` is true, sidecars without a matching ``.pt`` file are
        skipped with a warning.
        """
        checkpoints = []
        for metadata_file in metadata_files:
            try:
                with open(metadata_file, 'r') as f:
                    metadata = json.load(f)
            except (OSError, ValueError) as e:  # ValueError covers JSON and decoding errors
                logger.error(f"Error loading checkpoint metadata {metadata_file}: {e}")
                continue
            if not isinstance(metadata, dict):
                logger.error(f"Error loading checkpoint metadata {metadata_file}: expected a JSON object")
                continue

            # Get checkpoint path (remove _metadata.json)
            checkpoint_path = metadata_file[:-len(self._METADATA_SUFFIX)]
            if require_model_file and not os.path.exists(f"{checkpoint_path}.pt"):
                logger.warning(f"Model file not found for checkpoint {checkpoint_path}")
                continue
            checkpoints.append((checkpoint_path, metadata))
        return checkpoints

    def _score(self, metadata: Dict[str, Any]) -> float:
        """Return the ``metric_name`` value stored in one sidecar, or 0.0 if missing/non-numeric."""
        for key in ('metrics', 'performance_metrics'):
            metrics = metadata.get(key)
            if isinstance(metrics, dict) and self.metric_name in metrics:
                try:
                    return float(metrics[self.metric_name])
                except (TypeError, ValueError):
                    return 0.0
        return 0.0

    def _rank_key(self, checkpoint: Tuple[str, Dict[str, Any]]) -> Tuple[float, str]:
        """Ranking key: metric score, then timestamp, so the newer one wins ties when ordered descending."""
        metadata = checkpoint[1]
        return self._score(metadata), str(metadata.get('timestamp', ''))

    def save_checkpoint(self, model_name: str, model_path: str, metrics: Dict[str, float], metadata: Dict[str, Any] = None) -> str:
        """
        Save a checkpoint with metadata

        Copies ``model_path`` to ``<checkpoint path>.pt``, writes the JSON sidecar and
        then prunes old checkpoints of ``model_name``.

        Args:
            model_name: Name of the model
            model_path: Path to the model file
            metrics: Performance metrics
            metadata: Additional metadata

        Returns:
            str: Path to the saved checkpoint (without extension), or "" on error
        """
        try:
            # Create timestamp
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

            # Create checkpoint directory
            checkpoint_dir = os.path.join(self.checkpoint_dir, model_name)
            os.makedirs(checkpoint_dir, exist_ok=True)

            # Create checkpoint path
            checkpoint_path = os.path.join(checkpoint_dir, f"{model_name}_{timestamp}")

            checkpoint_metadata = {
                'model_name': model_name,
                'timestamp': timestamp,
                'metrics': metrics,
                'metadata': metadata or {}
            }
            # Serialize first: a non-JSON-serializable metric then fails before any
            # file is written, instead of leaving an orphaned model copy or a
            # truncated sidecar.
            metadata_json = json.dumps(checkpoint_metadata, indent=2)

            # Copy model file to checkpoint path
            shutil.copy2(model_path, f"{checkpoint_path}.pt")

            # Save metadata
            with open(f"{checkpoint_path}{self._METADATA_SUFFIX}", 'w') as f:
                f.write(metadata_json)

            logger.info(f"Saved checkpoint to {checkpoint_path}")

            # Clean up old checkpoints
            self._cleanup_checkpoints(model_name)

            return checkpoint_path
        except Exception as e:
            logger.error(f"Error saving checkpoint: {e}")
            return ""

    def load_best_checkpoint(self, model_name: str) -> Tuple[str, Dict[str, Any]]:
        """
        Load the best checkpoint based on performance metrics

        Only checkpoints whose ``.pt`` file still exists are considered.

        Args:
            model_name: Name of the model

        Returns:
            Tuple[str, Dict[str, Any]]: Path to the best checkpoint (without
            extension) and its metadata, or ("", {}) if none is usable
        """
        try:
            checkpoints = self._read_checkpoints(self._metadata_files(model_name), require_model_file=True)
            if not checkpoints:
                # No more scattered "No checkpoints found" logs - handled by database system
                return "", {}

            best_checkpoint_path, best_checkpoint_metadata = max(checkpoints, key=self._rank_key)
            logger.info(f"Best checkpoint for {model_name}: {best_checkpoint_path}")
            return best_checkpoint_path, best_checkpoint_metadata
        except Exception as e:
            logger.error(f"Error loading best checkpoint: {e}")
            return "", {}

    def _cleanup_checkpoints(self, model_name: str) -> int:
        """
        Clean up old or underperforming checkpoints

        Keeps the ``max_checkpoints`` best-ranked checkpoints of ``model_name`` and
        deletes the model file and sidecar of every other readable checkpoint.

        Args:
            model_name: Name of the model

        Returns:
            int: Number of checkpoints deleted
        """
        try:
            metadata_files = self._metadata_files(model_name)
            if not metadata_files or len(metadata_files) <= self.max_checkpoints:
                return 0

            checkpoints = self._read_checkpoints(metadata_files, require_model_file=False)
            # Best first; ties keep the newer checkpoint instead of depending on
            # filesystem listing order.
            checkpoints.sort(key=self._rank_key, reverse=True)

            deleted_count = 0
            for checkpoint_path, _ in checkpoints[self.max_checkpoints:]:
                try:
                    for path in (f"{checkpoint_path}.pt", f"{checkpoint_path}{self._METADATA_SUFFIX}"):
                        try:
                            os.remove(path)
                        except FileNotFoundError:
                            pass  # already gone; nothing to delete
                    deleted_count += 1
                except OSError as e:
                    logger.error(f"Error deleting checkpoint {checkpoint_path}: {e}")

            logger.info(f"Deleted {deleted_count} old checkpoints for {model_name}")
            return deleted_count
        except Exception as e:
            logger.error(f"Error cleaning up checkpoints: {e}")
            return 0

    def get_all_checkpoints(self, model_name: str) -> List[Tuple[str, Dict[str, Any]]]:
        """
        Get all checkpoints for a model

        Only checkpoints whose ``.pt`` file exists are returned.

        Args:
            model_name: Name of the model

        Returns:
            List[Tuple[str, Dict[str, Any]]]: Checkpoint paths and metadata, newest first
        """
        try:
            checkpoints = self._read_checkpoints(self._metadata_files(model_name), require_model_file=True)
            # Sort by timestamp (newest first)
            checkpoints.sort(key=lambda item: str(item[1].get('timestamp', '')), reverse=True)
            return checkpoints
        except Exception as e:
            logger.error(f"Error getting all checkpoints: {e}")
            return []

    def get_checkpoint_stats(self) -> Dict[str, Any]:
        """
        Get statistics about all checkpoints

        Counts ``<model>_*.pt`` files in every sub-directory of ``checkpoint_dir``.

        Returns:
            Dict[str, Any]: ``total_checkpoints``, ``total_size_mb`` and a per-model
            ``models`` mapping of ``{'checkpoints': int, 'size_mb': float}``
        """
        try:
            stats = {
                'total_checkpoints': 0,
                'total_size_mb': 0.0,
                'models': {}
            }

            # Iterate through all model directories
            for model_dir in os.listdir(self.checkpoint_dir):
                model_path = os.path.join(self.checkpoint_dir, model_dir)
                if not os.path.isdir(model_path):
                    continue

                # Count checkpoints for this model
                checkpoint_files = glob.glob(
                    os.path.join(glob.escape(model_path), f"{glob.escape(model_dir)}_*.pt")
                )
                model_checkpoints = len(checkpoint_files)

                # Calculate total size for this model
                model_size_mb = 0.0
                for checkpoint_file in checkpoint_files:
                    try:
                        model_size_mb += os.path.getsize(checkpoint_file) / (1024 * 1024)  # Convert to MB
                    except OSError:
                        pass  # file vanished between glob and stat

                stats['models'][model_dir] = {
                    'checkpoints': model_checkpoints,
                    'size_mb': round(model_size_mb, 2)
                }
                stats['total_checkpoints'] += model_checkpoints
                stats['total_size_mb'] += model_size_mb

            stats['total_size_mb'] = round(stats['total_size_mb'], 2)
            return stats
        except Exception as e:
            logger.error(f"Error getting checkpoint stats: {e}")
            return {
                'total_checkpoints': 0,
                'total_size_mb': 0.0,
                'models': {}
            }