checkpoint manager
cleanup_checkpoints.py | 186 | Normal file

@@ -0,0 +1,186 @@
#!/usr/bin/env python3
"""
Checkpoint Cleanup and Migration Script

This script helps clean up existing checkpoints and migrate to the new
checkpoint management system with W&B integration.
"""

import logging
from datetime import datetime
from pathlib import Path
from typing import Any, Dict

from utils.checkpoint_manager import get_checkpoint_manager

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
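
# NOTE: utils.checkpoint_manager is project-internal; get_checkpoint_manager()
# presumably returns a shared manager instance. It is constructed in
# CheckpointCleanup.__init__ below but not otherwise used by this script.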


class CheckpointCleanup:
    def __init__(self):
        self.saved_models_dir = Path("NN/models/saved")
        self.checkpoint_manager = get_checkpoint_manager()

    def analyze_existing_checkpoints(self) -> Dict[str, Any]:
        """Scan saved_models_dir for .pt files and summarize them by model type."""
        logger.info("Analyzing existing checkpoint files...")

        analysis = {
            'total_files': 0,
            'total_size_mb': 0.0,
            'model_types': {},
            'file_patterns': {},
            'potential_duplicates': []
        }

        if not self.saved_models_dir.exists():
            logger.warning(f"Saved models directory not found: {self.saved_models_dir}")
            return analysis

        for pt_file in self.saved_models_dir.rglob("*.pt"):
            try:
                file_size_mb = pt_file.stat().st_size / (1024 * 1024)
                analysis['total_files'] += 1
                analysis['total_size_mb'] += file_size_mb

                filename = pt_file.name

                # Infer the model type from the filename.
                if 'cnn' in filename.lower():
                    model_type = 'cnn'
                elif 'dqn' in filename.lower() or 'rl' in filename.lower():
                    model_type = 'rl'
                elif 'agent' in filename.lower():
                    model_type = 'rl'
                else:
                    model_type = 'unknown'

                if model_type not in analysis['model_types']:
                    analysis['model_types'][model_type] = {'count': 0, 'size_mb': 0.0}

                analysis['model_types'][model_type]['count'] += 1
                analysis['model_types'][model_type]['size_mb'] += file_size_mb

                # Group files by base name (the part before the first underscore)
                # so successive checkpoints of the same model cluster together.
                base_name = filename.split('_')[0] if '_' in filename else filename.replace('.pt', '')
                if base_name not in analysis['file_patterns']:
                    analysis['file_patterns'][base_name] = []

                analysis['file_patterns'][base_name].append({
                    'path': str(pt_file),
                    'size_mb': file_size_mb,
                    'modified': datetime.fromtimestamp(pt_file.stat().st_mtime)
                })

            except Exception as e:
                logger.error(f"Error analyzing {pt_file}: {e}")

        # Any base name with more than 5 checkpoints is a cleanup candidate.
        for base_name, files in analysis['file_patterns'].items():
            if len(files) > 5:
                analysis['potential_duplicates'].append({
                    'base_name': base_name,
                    'count': len(files),
                    'total_size_mb': sum(f['size_mb'] for f in files),
                    'files': files
                })

        logger.info("Analysis complete:")
        logger.info(f"  Total files: {analysis['total_files']}")
        logger.info(f"  Total size: {analysis['total_size_mb']:.2f} MB")
        logger.info(f"  Model types: {analysis['model_types']}")
        logger.info(f"  Potential duplicates: {len(analysis['potential_duplicates'])}")

        return analysis
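
    # Illustrative shape of the returned dict (values are hypothetical):
    #
    #     {
    #         'total_files': 42,
    #         'total_size_mb': 1337.5,
    #         'model_types': {'cnn': {'count': 10, 'size_mb': 512.0}},
    #         'file_patterns': {'dqn': [{'path': 'NN/models/saved/dqn_1.pt',
    #                                    'size_mb': 32.1,
    #                                    'modified': datetime(2024, 1, 1)}]},
    #         'potential_duplicates': [{'base_name': 'dqn', 'count': 8,
    #                                   'total_size_mb': 256.8, 'files': [...]}]
    #     }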

    def cleanup_duplicates(self, dry_run: bool = True) -> Dict[str, Any]:
        """Keep the 5 newest checkpoints per base name and remove the rest.

        With dry_run=True, only report what would be removed.
        """
        logger.info(f"Starting duplicate cleanup (dry_run={dry_run})...")

        cleanup_results = {
            'removed': 0,
            'kept': 0,
            'space_saved_mb': 0.0,
            'details': []
        }

        analysis = self.analyze_existing_checkpoints()

        for duplicate_group in analysis['potential_duplicates']:
            base_name = duplicate_group['base_name']
            files = duplicate_group['files']

            # Sort by modification time (newest first).
            files.sort(key=lambda x: x['modified'], reverse=True)

            logger.info(f"Processing {base_name}: {len(files)} files")

            for i, file_info in enumerate(files):
                if i < 5:  # Keep the first 5 (newest).
                    cleanup_results['kept'] += 1
                    cleanup_results['details'].append({
                        'action': 'kept',
                        'file': file_info['path']
                    })
                else:  # Remove the rest.
                    if not dry_run:
                        try:
                            Path(file_info['path']).unlink()
                            logger.info(f"Removed: {file_info['path']}")
                        except Exception as e:
                            logger.error(f"Error removing {file_info['path']}: {e}")
                            continue

                    cleanup_results['removed'] += 1
                    cleanup_results['space_saved_mb'] += file_info['size_mb']
                    cleanup_results['details'].append({
                        'action': 'removed',
                        'file': file_info['path'],
                        'size_mb': file_info['size_mb']
                    })

        logger.info(f"Cleanup {'simulation' if dry_run else 'complete'}:")
        logger.info(f"  Kept: {cleanup_results['kept']}")
        logger.info(f"  Removed: {cleanup_results['removed']}")
        logger.info(f"  Space saved: {cleanup_results['space_saved_mb']:.2f} MB")

        return cleanup_results
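
# A minimal sketch of programmatic use (hypothetical; main() below drives the
# same flow interactively):
#
#     cleanup = CheckpointCleanup()
#     report = cleanup.cleanup_duplicates(dry_run=True)   # simulate only
#     if report['removed'] > 0:
#         cleanup.cleanup_duplicates(dry_run=False)       # actually delete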


def main():
    logger.info("=== Checkpoint Cleanup Tool ===")

    cleanup = CheckpointCleanup()

    # Analyze existing checkpoints.
    logger.info("\n1. Analyzing existing checkpoints...")
    analysis = cleanup.analyze_existing_checkpoints()

    if analysis['total_files'] == 0:
        logger.info("No checkpoint files found.")
        return

    # Show potential space savings.
    total_duplicates = sum(len(group['files']) - 5 for group in analysis['potential_duplicates'] if len(group['files']) > 5)
    if total_duplicates > 0:
        logger.info(f"\nFound {total_duplicates} files that could be cleaned up")

        # Dry run first.
        logger.info("\n2. Simulating cleanup...")
        dry_run_results = cleanup.cleanup_duplicates(dry_run=True)

        if dry_run_results['removed'] > 0:
            proceed = input(f"\nProceed with cleanup? Will remove {dry_run_results['removed']} files "
                            f"and save {dry_run_results['space_saved_mb']:.2f} MB. (y/n): ").lower().strip() == 'y'

            if proceed:
                logger.info("\n3. Performing actual cleanup...")
                cleanup.cleanup_duplicates(dry_run=False)
                logger.info("\n=== Cleanup Complete ===")
            else:
                logger.info("Cleanup cancelled.")
        else:
            logger.info("No files to remove.")
    else:
        logger.info("No duplicate files found that need cleanup.")


if __name__ == "__main__":
    main()
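
# Presumably run from the repository root, since saved_models_dir
# ("NN/models/saved") is a relative path and the script will otherwise
# find no files:
#
#     python cleanup_checkpoints.py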