#!/usr/bin/env python3
"""
Example: Using the Checkpoint Management System
"""

import logging
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn

from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint, get_checkpoint_manager
from utils.training_integration import get_training_integration

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Number of simulated epochs run by example_cnn_training().
NUM_EXAMPLE_EPOCHS = 5


class ExampleCNN(nn.Module):
    """Small two-layer convolutional classifier used only for the demo.

    Args:
        input_channels: Number of channels passed to the first convolution.
        num_classes: Size of the final linear layer's output.
    """

    def __init__(self, input_channels=5, num_classes=3):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        """Apply conv -> ReLU twice, pool to 1x1, flatten, and classify."""
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = self.pool(x)
        # Collapse everything after the first dimension before the linear layer.
        x = x.view(x.size(0), -1)
        return self.fc(x)


def example_cnn_training():
    """Simulate a short training run and checkpoint it through the integration.

    Metrics are synthetic (a linear trend plus Gaussian noise); no real
    training happens. After the loop, the best checkpoint for
    ``"example_cnn"`` is looked up and, if one is returned, its id and
    performance score are logged.
    """
    logger.info("=== CNN Training Example ===")

    model = ExampleCNN()
    training_integration = get_training_integration()

    for epoch in range(NUM_EXAMPLE_EPOCHS):
        epoch_number = epoch + 1  # 1-based value used for reporting/saving

        # Simulate training metrics
        train_loss = 2.0 - (epoch * 0.15) + np.random.normal(0, 0.1)
        train_acc = 0.3 + (epoch * 0.06) + np.random.normal(0, 0.02)
        val_loss = train_loss + np.random.normal(0, 0.05)
        val_acc = train_acc - 0.05 + np.random.normal(0, 0.02)

        # Clamp values to realistic ranges
        train_acc = max(0.0, min(1.0, train_acc))
        val_acc = max(0.0, min(1.0, val_acc))
        train_loss = max(0.1, train_loss)
        val_loss = max(0.1, val_loss)

        logger.info(f"Epoch {epoch_number}: train_acc={train_acc:.3f}, val_acc={val_acc:.3f}")

        # Save checkpoint
        saved = training_integration.save_cnn_checkpoint(
            cnn_model=model,
            model_name="example_cnn",
            epoch=epoch_number,
            train_accuracy=train_acc,
            val_accuracy=val_acc,
            train_loss=train_loss,
            val_loss=val_loss,
            training_time_hours=0.1 * epoch_number,
        )

        if saved:
            logger.info(f" Checkpoint saved for epoch {epoch_number}")
        else:
            logger.info(" Checkpoint not saved (performance not improved)")

    # Load the best checkpoint
    logger.info("\nLoading best checkpoint...")
    best_result = load_best_checkpoint("example_cnn")
    if best_result:
        # The file path is returned alongside the metadata but is not used here.
        _file_path, metadata = best_result
        logger.info(f"Best checkpoint: {metadata.checkpoint_id}")
        logger.info(f"Performance score: {metadata.performance_score:.4f}")


def example_manual_checkpoint():
    """Save one checkpoint directly via ``save_checkpoint`` with fixed metrics.

    ``force_save=True`` is passed to ``save_checkpoint``. If a truthy
    metadata object comes back, its id and performance score are logged.
    """
    logger.info("\n=== Manual Checkpoint Example ===")

    model = nn.Linear(10, 3)

    performance_metrics = {
        'accuracy': 0.85,
        'val_accuracy': 0.82,
        'loss': 0.45,
        'val_loss': 0.48,
    }

    training_metadata = {
        'epoch': 25,
        'training_time_hours': 2.5,
        'total_parameters': sum(p.numel() for p in model.parameters()),
    }

    logger.info("Saving checkpoint manually...")
    metadata = save_checkpoint(
        model=model,
        model_name="example_manual",
        model_type="cnn",
        performance_metrics=performance_metrics,
        training_metadata=training_metadata,
        force_save=True,
    )

    if metadata:
        logger.info(f" Manual checkpoint saved: {metadata.checkpoint_id}")
        logger.info(f" Performance score: {metadata.performance_score:.4f}")


def show_checkpoint_stats():
    """Log the aggregate and per-model statistics from the checkpoint manager."""
    logger.info("\n=== Checkpoint Statistics ===")

    checkpoint_manager = get_checkpoint_manager()
    stats = checkpoint_manager.get_checkpoint_stats()

    logger.info(f"Total models: {stats['total_models']}")
    logger.info(f"Total checkpoints: {stats['total_checkpoints']}")
    logger.info(f"Total size: {stats['total_size_mb']:.2f} MB")

    for model_name, model_stats in stats['models'].items():
        logger.info(f"\n{model_name}:")
        logger.info(f" Checkpoints: {model_stats['checkpoint_count']}")
        logger.info(f" Size: {model_stats['total_size_mb']:.2f} MB")
        logger.info(f" Best performance: {model_stats['best_performance']:.4f}")


def main():
    """Run every example in order, then log short usage instructions.

    Raises:
        Exception: Any error raised by an example is logged and re-raised.
    """
    logger.info(" Checkpoint Management System Examples")
    logger.info("=" * 50)

    try:
        example_cnn_training()
        example_manual_checkpoint()
        show_checkpoint_stats()

        logger.info("\n All examples completed successfully!")
        logger.info("\nTo use in your training:")
        logger.info("1. Import: from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint")
        logger.info("2. Or use: from utils.training_integration import get_training_integration")
        logger.info("3. Save checkpoints during training with performance metrics")
        logger.info("4. Load best checkpoints for inference or continued training")
    except Exception as e:
        # Top-level boundary: report the failure, then let it propagate.
        logger.error(f"Error in examples: {e}")
        raise


if __name__ == "__main__":
    main()