stability
This commit is contained in:
@ -491,4 +491,57 @@ class CheckpointManager:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting all checkpoints: {e}")
|
||||
return []
|
||||
return []
|
||||
|
||||
def get_checkpoint_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get statistics about all checkpoints
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Statistics about checkpoints
|
||||
"""
|
||||
try:
|
||||
stats = {
|
||||
'total_checkpoints': 0,
|
||||
'total_size_mb': 0.0,
|
||||
'models': {}
|
||||
}
|
||||
|
||||
# Iterate through all model directories
|
||||
for model_dir in os.listdir(self.checkpoint_dir):
|
||||
model_path = os.path.join(self.checkpoint_dir, model_dir)
|
||||
if not os.path.isdir(model_path):
|
||||
continue
|
||||
|
||||
# Count checkpoints for this model
|
||||
checkpoint_files = glob.glob(os.path.join(model_path, f"{model_dir}_*.pt"))
|
||||
model_checkpoints = len(checkpoint_files)
|
||||
|
||||
# Calculate total size for this model
|
||||
model_size_mb = 0.0
|
||||
for checkpoint_file in checkpoint_files:
|
||||
try:
|
||||
size_bytes = os.path.getsize(checkpoint_file)
|
||||
model_size_mb += size_bytes / (1024 * 1024) # Convert to MB
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
stats['models'][model_dir] = {
|
||||
'checkpoints': model_checkpoints,
|
||||
'size_mb': round(model_size_mb, 2)
|
||||
}
|
||||
|
||||
stats['total_checkpoints'] += model_checkpoints
|
||||
stats['total_size_mb'] += model_size_mb
|
||||
|
||||
stats['total_size_mb'] = round(stats['total_size_mb'], 2)
|
||||
|
||||
return stats
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting checkpoint stats: {e}")
|
||||
return {
|
||||
'total_checkpoints': 0,
|
||||
'total_size_mb': 0.0,
|
||||
'models': {}
|
||||
}
|
Reference in New Issue
Block a user