149 lines
5.0 KiB
Python
149 lines
5.0 KiB
Python
#!/usr/bin/env python3
|
|
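"""
Example: Using the Checkpoint Management System.

Runs three demonstrations in order: a simulated CNN training loop that
saves a checkpoint each epoch, a manual forced checkpoint save, and a
summary of the stored checkpoint statistics.
"""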
|
|
|
|
import logging
|
|
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
from datetime import datetime
|
|
|
|
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint, get_checkpoint_manager
|
|
from utils.training_integration import get_training_integration
|
|
|
|
# Configure the root logger once so the INFO-level messages emitted by the
# examples are actually shown.
logging.basicConfig(level=logging.INFO)

# Module-scoped logger used by every example function in this file.
logger = logging.getLogger(__name__)
|
|
|
|
class ExampleCNN(nn.Module):
    """Small convolutional classifier used by the checkpoint examples.

    Two 3x3 convolutions (padding 1), each followed by ReLU, feed an
    adaptive average pool that reduces every feature map to 1x1; a single
    linear layer then produces ``num_classes`` outputs per sample.
    """

    def __init__(self, input_channels=5, num_classes=3):
        """Build the layers.

        Args:
            input_channels: Number of channels in the input images.
            num_classes: Size of the output vector produced per sample.
        """
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        """Run a batch through the network.

        Args:
            x: Batched image tensor; the first dimension is treated as the
                batch dimension when flattening.

        Returns:
            Tensor with one row of ``num_classes`` values per sample.
        """
        hidden = self.conv1(x).relu()
        hidden = self.conv2(hidden).relu()
        pooled = self.pool(hidden)
        # Collapse the (64, 1, 1) pooled features into a flat vector per sample.
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)
|
|
|
|
def example_cnn_training():
    """Simulate a short CNN training run, checkpointing after every epoch.

    No real training happens: for each of five epochs, synthetic loss and
    accuracy values are generated (a linear trend plus Gaussian noise),
    clamped to plausible ranges, logged, and handed to the training
    integration's ``save_cnn_checkpoint``. Afterwards the best stored
    checkpoint for ``"example_cnn"`` is looked up and its id and
    performance score are logged.
    """
    logger.info("=== CNN Training Example ===")

    model = ExampleCNN()
    training_integration = get_training_integration()

    for epoch in range(5):  # Simulate 5 epochs
        # Simulate training metrics (the order of the random draws is kept
        # stable so seeded runs stay reproducible).
        train_loss = 2.0 - (epoch * 0.15) + np.random.normal(0, 0.1)
        train_acc = 0.3 + (epoch * 0.06) + np.random.normal(0, 0.02)
        val_loss = train_loss + np.random.normal(0, 0.05)
        val_acc = train_acc - 0.05 + np.random.normal(0, 0.02)

        # Clamp values to realistic ranges: accuracies to [0, 1], losses to
        # at least 0.1.
        train_acc = max(0.0, min(1.0, train_acc))
        val_acc = max(0.0, min(1.0, val_acc))
        train_loss = max(0.1, train_loss)
        val_loss = max(0.1, val_loss)

        logger.info(f"Epoch {epoch+1}: train_acc={train_acc:.3f}, val_acc={val_acc:.3f}")

        # Save checkpoint; the return value is used as a "was it stored?" flag.
        saved = training_integration.save_cnn_checkpoint(
            cnn_model=model,
            model_name="example_cnn",
            epoch=epoch + 1,
            train_accuracy=train_acc,
            val_accuracy=val_acc,
            train_loss=train_loss,
            val_loss=val_loss,
            training_time_hours=0.1 * (epoch + 1),
        )

        if saved:
            logger.info(f" Checkpoint saved for epoch {epoch+1}")
        else:
            logger.info(" Checkpoint not saved (performance not improved)")

    # Load the best checkpoint
    logger.info("\nLoading best checkpoint...")
    best_result = load_best_checkpoint("example_cnn")
    if best_result:
        # The result unpacks to (file path, metadata); only metadata is reported.
        _, metadata = best_result
        logger.info(f"Best checkpoint: {metadata.checkpoint_id}")
        logger.info(f"Performance score: {metadata.performance_score:.4f}")
|
|
|
|
def example_manual_checkpoint():
    """Save one checkpoint directly through ``save_checkpoint``.

    Builds a tiny ``nn.Linear(10, 3)`` model, attaches hard-coded
    performance metrics and training metadata, and calls ``save_checkpoint``
    with ``force_save=True``. If the call returns metadata, its checkpoint
    id and performance score are logged.
    """
    logger.info("\n=== Manual Checkpoint Example ===")

    model = nn.Linear(10, 3)

    performance_metrics = {
        'accuracy': 0.85,
        'val_accuracy': 0.82,
        'loss': 0.45,
        'val_loss': 0.48,
    }

    training_metadata = {
        'epoch': 25,
        'training_time_hours': 2.5,
        # Total number of scalar parameters in the model.
        'total_parameters': sum(p.numel() for p in model.parameters()),
    }

    logger.info("Saving checkpoint manually...")
    metadata = save_checkpoint(
        model=model,
        model_name="example_manual",
        model_type="cnn",
        performance_metrics=performance_metrics,
        training_metadata=training_metadata,
        force_save=True,
    )

    if metadata:
        logger.info(f" Manual checkpoint saved: {metadata.checkpoint_id}")
        logger.info(f" Performance score: {metadata.performance_score:.4f}")
|
|
|
|
def show_checkpoint_stats():
    """Log aggregate and per-model checkpoint statistics.

    Reads the dictionary returned by the checkpoint manager's
    ``get_checkpoint_stats()`` and logs the overall totals
    (``total_models``, ``total_checkpoints``, ``total_size_mb``) followed by
    each entry of ``stats['models']`` with its checkpoint count, size and
    best performance.
    """
    logger.info("\n=== Checkpoint Statistics ===")

    checkpoint_manager = get_checkpoint_manager()
    stats = checkpoint_manager.get_checkpoint_stats()

    logger.info(f"Total models: {stats['total_models']}")
    logger.info(f"Total checkpoints: {stats['total_checkpoints']}")
    logger.info(f"Total size: {stats['total_size_mb']:.2f} MB")

    for model_name, model_stats in stats['models'].items():
        logger.info(f"\n{model_name}:")
        logger.info(f" Checkpoints: {model_stats['checkpoint_count']}")
        logger.info(f" Size: {model_stats['total_size_mb']:.2f} MB")
        logger.info(f" Best performance: {model_stats['best_performance']:.4f}")
|
|
|
|
def main():
    """Run all checkpoint examples in sequence and log usage hints.

    Executes the CNN training example, the manual checkpoint example and the
    statistics report, then logs a short "how to use" list.

    Raises:
        Exception: Any error raised by an example is logged and re-raised
            unchanged.
    """
    logger.info(" Checkpoint Management System Examples")
    logger.info("=" * 50)

    try:
        example_cnn_training()
        example_manual_checkpoint()
        show_checkpoint_stats()

        logger.info("\n All examples completed successfully!")
        logger.info("\nTo use in your training:")
        logger.info("1. Import: from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint")
        logger.info("2. Or use: from utils.training_integration import get_training_integration")
        logger.info("3. Save checkpoints during training with performance metrics")
        logger.info("4. Load best checkpoints for inference or continued training")

    except Exception as e:
        # Top-level boundary: record the failure, then let it propagate so
        # the process still exits with an error.
        logger.error("Error in examples: %s", e)
        raise
|
|
|
|
# Run the examples only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|