#!/usr/bin/env python3
"""
Checkpoint Cleanup and Migration Script

This script helps clean up existing checkpoints and migrate to the new
checkpoint management system with W&B integration.
"""

import os
import logging
import shutil
from pathlib import Path
from datetime import datetime
from typing import List, Dict, Any

import torch

from utils.checkpoint_manager import get_checkpoint_manager, CheckpointMetadata

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class CheckpointCleanup:
    def __init__(self):
        self.saved_models_dir = Path("NN/models/saved")
        self.checkpoint_manager = get_checkpoint_manager()

    def analyze_existing_checkpoints(self) -> Dict[str, Any]:
        logger.info("Analyzing existing checkpoint files...")

        analysis = {
            'total_files': 0,
            'total_size_mb': 0.0,
            'model_types': {},
            'file_patterns': {},
            'potential_duplicates': []
        }

        if not self.saved_models_dir.exists():
            logger.warning(f"Saved models directory not found: {self.saved_models_dir}")
            return analysis

        for pt_file in self.saved_models_dir.rglob("*.pt"):
            try:
                file_size_mb = pt_file.stat().st_size / (1024 * 1024)
                analysis['total_files'] += 1
                analysis['total_size_mb'] += file_size_mb

                filename = pt_file.name
                if 'cnn' in filename.lower():
                    model_type = 'cnn'
                elif 'dqn' in filename.lower() or 'rl' in filename.lower():
                    model_type = 'rl'
                elif 'agent' in filename.lower():
                    model_type = 'rl'
                else:
                    model_type = 'unknown'

                if model_type not in analysis['model_types']:
                    analysis['model_types'][model_type] = {'count': 0, 'size_mb': 0.0}
                analysis['model_types'][model_type]['count'] += 1
                analysis['model_types'][model_type]['size_mb'] += file_size_mb

                base_name = filename.split('_')[0] if '_' in filename else filename.replace('.pt', '')
                if base_name not in analysis['file_patterns']:
                    analysis['file_patterns'][base_name] = []
                analysis['file_patterns'][base_name].append({
                    'path': str(pt_file),
                    'size_mb': file_size_mb,
                    'modified': datetime.fromtimestamp(pt_file.stat().st_mtime)
                })
            except Exception as e:
                logger.error(f"Error analyzing {pt_file}: {e}")

        for base_name, files in analysis['file_patterns'].items():
            if len(files) > 5:  # More than 5 files with same base name
                analysis['potential_duplicates'].append({
                    'base_name': base_name,
                    'count': len(files),
                    'total_size_mb': sum(f['size_mb'] for f in files),
                    'files': files
                })

        logger.info("Analysis complete:")
        logger.info(f"  Total files: {analysis['total_files']}")
        logger.info(f"  Total size: {analysis['total_size_mb']:.2f} MB")
        logger.info(f"  Model types: {analysis['model_types']}")
        logger.info(f"  Potential duplicates: {len(analysis['potential_duplicates'])}")

        return analysis

    def cleanup_duplicates(self, dry_run: bool = True) -> Dict[str, Any]:
        logger.info(f"Starting duplicate cleanup (dry_run={dry_run})...")

        cleanup_results = {
            'removed': 0,
            'kept': 0,
            'space_saved_mb': 0.0,
            'details': []
        }

        analysis = self.analyze_existing_checkpoints()

        for duplicate_group in analysis['potential_duplicates']:
            base_name = duplicate_group['base_name']
            files = duplicate_group['files']

            # Sort by modification time (newest first)
            files.sort(key=lambda x: x['modified'], reverse=True)

            logger.info(f"Processing {base_name}: {len(files)} files")

            # Keep only the 5 newest files
            for i, file_info in enumerate(files):
                if i < 5:  # Keep first 5 (newest)
                    cleanup_results['kept'] += 1
                    cleanup_results['details'].append({
                        'action': 'kept',
                        'file': file_info['path']
                    })
                else:  # Remove the rest
                    if not dry_run:
                        try:
                            Path(file_info['path']).unlink()
                            logger.info(f"Removed: {file_info['path']}")
                        except Exception as e:
                            logger.error(f"Error removing {file_info['path']}: {e}")
                            continue

                    cleanup_results['removed'] += 1
                    cleanup_results['space_saved_mb'] += file_info['size_mb']
                    cleanup_results['details'].append({
                        'action': 'removed',
                        'file': file_info['path'],
                        'size_mb': file_info['size_mb']
                    })

        logger.info(f"Cleanup {'simulation' if dry_run else 'complete'}:")
        logger.info(f"  Kept: {cleanup_results['kept']}")
        logger.info(f"  Removed: {cleanup_results['removed']}")
        logger.info(f"  Space saved: {cleanup_results['space_saved_mb']:.2f} MB")

        return cleanup_results


def main():
    logger.info("=== Checkpoint Cleanup Tool ===")

    cleanup = CheckpointCleanup()

    # Analyze existing checkpoints
    logger.info("\n1. Analyzing existing checkpoints...")
    analysis = cleanup.analyze_existing_checkpoints()

    if analysis['total_files'] == 0:
        logger.info("No checkpoint files found.")
        return

    # Show potential space savings
    total_duplicates = sum(
        len(group['files']) - 5
        for group in analysis['potential_duplicates']
        if len(group['files']) > 5
    )

    if total_duplicates > 0:
        logger.info(f"\nFound {total_duplicates} files that could be cleaned up")

        # Dry run first
        logger.info("\n2. Simulating cleanup...")
        dry_run_results = cleanup.cleanup_duplicates(dry_run=True)

        if dry_run_results['removed'] > 0:
            proceed = input(
                f"\nProceed with cleanup? Will remove {dry_run_results['removed']} files "
                f"and save {dry_run_results['space_saved_mb']:.2f} MB. (y/n): "
            ).lower().strip() == 'y'

            if proceed:
                logger.info("\n3. Performing actual cleanup...")
                cleanup_results = cleanup.cleanup_duplicates(dry_run=False)
                logger.info("\n=== Cleanup Complete ===")
            else:
                logger.info("Cleanup cancelled.")
        else:
            logger.info("No files to remove.")
    else:
        logger.info("No duplicate files found that need cleanup.")


if __name__ == "__main__":
    main()