gogo2/cleanup_checkpoints.py

#!/usr/bin/env python3
"""
Checkpoint Cleanup and Migration Script
This script helps clean up existing checkpoints and migrate to the new
checkpoint management system with W&B integration.
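
Usage (assumed invocation; run from the project root so the relative
NN/models/saved path resolves):
    python cleanup_checkpoints.py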
"""
import os
import logging
import shutil
from pathlib import Path
from datetime import datetime
from typing import List, Dict, Any

import torch

from utils.checkpoint_manager import get_checkpoint_manager, CheckpointMetadata

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class CheckpointCleanup:
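    """Analyze and clean up legacy checkpoint files under NN/models/saved."""
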
    def __init__(self):
        self.saved_models_dir = Path("NN/models/saved")
        self.checkpoint_manager = get_checkpoint_manager()

    def analyze_existing_checkpoints(self) -> Dict[str, Any]:
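        """Scan NN/models/saved for .pt files and summarize them.

        Files are grouped by inferred model type (cnn/rl/unknown) and by
        base name; any base name with more than five files is flagged as
        a potential duplicate group.
        """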
logger.info("Analyzing existing checkpoint files...")
analysis = {
'total_files': 0,
'total_size_mb': 0.0,
'model_types': {},
'file_patterns': {},
'potential_duplicates': []
}
if not self.saved_models_dir.exists():
logger.warning(f"Saved models directory not found: {self.saved_models_dir}")
return analysis
for pt_file in self.saved_models_dir.rglob("*.pt"):
try:
file_size_mb = pt_file.stat().st_size / (1024 * 1024)
analysis['total_files'] += 1
analysis['total_size_mb'] += file_size_mb
filename = pt_file.name
if 'cnn' in filename.lower():
model_type = 'cnn'
elif 'dqn' in filename.lower() or 'rl' in filename.lower():
model_type = 'rl'
elif 'agent' in filename.lower():
model_type = 'rl'
else:
model_type = 'unknown'
if model_type not in analysis['model_types']:
analysis['model_types'][model_type] = {'count': 0, 'size_mb': 0.0}
analysis['model_types'][model_type]['count'] += 1
analysis['model_types'][model_type]['size_mb'] += file_size_mb
base_name = filename.split('_')[0] if '_' in filename else filename.replace('.pt', '')
if base_name not in analysis['file_patterns']:
analysis['file_patterns'][base_name] = []
analysis['file_patterns'][base_name].append({
'path': str(pt_file),
'size_mb': file_size_mb,
'modified': datetime.fromtimestamp(pt_file.stat().st_mtime)
})
except Exception as e:
logger.error(f"Error analyzing {pt_file}: {e}")
for base_name, files in analysis['file_patterns'].items():
if len(files) > 5: # More than 5 files with same base name
analysis['potential_duplicates'].append({
'base_name': base_name,
'count': len(files),
'total_size_mb': sum(f['size_mb'] for f in files),
'files': files
})
logger.info(f"Analysis complete:")
logger.info(f" Total files: {analysis['total_files']}")
logger.info(f" Total size: {analysis['total_size_mb']:.2f} MB")
logger.info(f" Model types: {analysis['model_types']}")
logger.info(f" Potential duplicates: {len(analysis['potential_duplicates'])}")
return analysis

    def cleanup_duplicates(self, dry_run: bool = True) -> Dict[str, Any]:
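        """Keep the 5 newest files in each duplicate group and drop the rest.

        With dry_run=True nothing is deleted; the results only report what
        would be removed and how much space would be reclaimed.
        """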
logger.info(f"Starting duplicate cleanup (dry_run={dry_run})...")
cleanup_results = {
'removed': 0,
'kept': 0,
'space_saved_mb': 0.0,
'details': []
}
analysis = self.analyze_existing_checkpoints()
for duplicate_group in analysis['potential_duplicates']:
base_name = duplicate_group['base_name']
files = duplicate_group['files']
# Sort by modification time (newest first)
files.sort(key=lambda x: x['modified'], reverse=True)
logger.info(f"Processing {base_name}: {len(files)} files")
# Keep only the 5 newest files
for i, file_info in enumerate(files):
if i < 5: # Keep first 5 (newest)
cleanup_results['kept'] += 1
cleanup_results['details'].append({
'action': 'kept',
'file': file_info['path']
})
else: # Remove the rest
if not dry_run:
try:
Path(file_info['path']).unlink()
logger.info(f"Removed: {file_info['path']}")
except Exception as e:
logger.error(f"Error removing {file_info['path']}: {e}")
continue
cleanup_results['removed'] += 1
cleanup_results['space_saved_mb'] += file_info['size_mb']
cleanup_results['details'].append({
'action': 'removed',
'file': file_info['path'],
'size_mb': file_info['size_mb']
})
logger.info(f"Cleanup {'simulation' if dry_run else 'complete'}:")
logger.info(f" Kept: {cleanup_results['kept']}")
logger.info(f" Removed: {cleanup_results['removed']}")
logger.info(f" Space saved: {cleanup_results['space_saved_mb']:.2f} MB")
return cleanup_results


def main():
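    """Analyze checkpoints, simulate the cleanup, then ask before deleting."""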
logger.info("=== Checkpoint Cleanup Tool ===")
cleanup = CheckpointCleanup()
# Analyze existing checkpoints
logger.info("\\n1. Analyzing existing checkpoints...")
analysis = cleanup.analyze_existing_checkpoints()
if analysis['total_files'] == 0:
logger.info("No checkpoint files found.")
return
# Show potential space savings
total_duplicates = sum(len(group['files']) - 5 for group in analysis['potential_duplicates'] if len(group['files']) > 5)
if total_duplicates > 0:
logger.info(f"\\nFound {total_duplicates} files that could be cleaned up")
# Dry run first
logger.info("\\n2. Simulating cleanup...")
dry_run_results = cleanup.cleanup_duplicates(dry_run=True)
if dry_run_results['removed'] > 0:
proceed = input(f"\\nProceed with cleanup? Will remove {dry_run_results['removed']} files "
f"and save {dry_run_results['space_saved_mb']:.2f} MB. (y/n): ").lower().strip() == 'y'
if proceed:
logger.info("\\n3. Performing actual cleanup...")
cleanup_results = cleanup.cleanup_duplicates(dry_run=False)
logger.info("\\n=== Cleanup Complete ===")
else:
logger.info("Cleanup cancelled.")
else:
logger.info("No files to remove.")
else:
logger.info("No duplicate files found that need cleanup.")
if __name__ == "__main__":
main()