sqlite for checkpoints, cleanup
@@ -5,6 +5,7 @@ This module provides functionality for managing model checkpoints, including:
- Saving checkpoints with metadata
- Loading the best checkpoint based on performance metrics
- Cleaning up old or underperforming checkpoints
- Database-backed metadata storage for efficient access
"""

import os
@@ -13,9 +14,13 @@ import glob
import logging
import shutil
import torch
import hashlib
from datetime import datetime
from typing import Dict, List, Optional, Any, Tuple

from .database_manager import get_database_manager, CheckpointMetadata
from .text_logger import get_text_logger

logger = logging.getLogger(__name__)

# Global checkpoint manager instance
@@ -46,7 +51,7 @@ def get_checkpoint_manager(checkpoint_dir: str = "models/checkpoints", max_check

def save_checkpoint(model, model_name: str, model_type: str, performance_metrics: Dict[str, float], training_metadata: Dict[str, Any] = None, checkpoint_dir: str = "models/checkpoints") -> Any:
    """
    Save a checkpoint with metadata
    Save a checkpoint with metadata to both filesystem and database

    Args:
        model: The model to save
@@ -64,57 +69,90 @@ def save_checkpoint(model, model_name: str, model_type: str, performance_metrics
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Create timestamp
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        timestamp = datetime.now()
        timestamp_str = timestamp.strftime("%Y%m%d_%H%M%S")

        # Create checkpoint path
        model_dir = os.path.join(checkpoint_dir, model_name)
        os.makedirs(model_dir, exist_ok=True)
        checkpoint_path = os.path.join(model_dir, f"{model_name}_{timestamp}")
        checkpoint_path = os.path.join(model_dir, f"{model_name}_{timestamp_str}")
        checkpoint_id = f"{model_name}_{timestamp_str}"

        # Save model
        torch_path = f"{checkpoint_path}.pt"
        if hasattr(model, 'save'):
            # Use model's save method if available
            model.save(checkpoint_path)
        else:
            # Otherwise, save state_dict
            torch_path = f"{checkpoint_path}.pt"
            torch.save({
                'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else None,
                'model_name': model_name,
                'model_type': model_type,
                'timestamp': timestamp
                'timestamp': timestamp_str,
                'checkpoint_id': checkpoint_id
            }, torch_path)

        # Create metadata
        checkpoint_metadata = {
        # Calculate file size
        file_size_mb = os.path.getsize(torch_path) / (1024 * 1024) if os.path.exists(torch_path) else 0.0

        # Save metadata to database
        db_manager = get_database_manager()
        checkpoint_metadata = CheckpointMetadata(
            checkpoint_id=checkpoint_id,
            model_name=model_name,
            model_type=model_type,
            timestamp=timestamp,
            performance_metrics=performance_metrics,
            training_metadata=training_metadata or {},
            file_path=torch_path,
            file_size_mb=file_size_mb,
            is_active=True # New checkpoint is active by default
        )

        # Save to database
        if db_manager.save_checkpoint_metadata(checkpoint_metadata):
            # Log checkpoint save event to text file
            text_logger = get_text_logger()
            text_logger.log_checkpoint_event(
                model_name=model_name,
                event_type="SAVED",
                checkpoint_id=checkpoint_id,
                details=f"loss={performance_metrics.get('loss', 'N/A')}, size={file_size_mb:.1f}MB"
            )
            logger.info(f"Checkpoint saved: {checkpoint_id}")
        else:
            logger.warning(f"Failed to save checkpoint metadata to database: {checkpoint_id}")

        # Also save legacy JSON metadata for backward compatibility
        legacy_metadata = {
            'model_name': model_name,
            'model_type': model_type,
            'timestamp': timestamp,
            'timestamp': timestamp_str,
            'performance_metrics': performance_metrics,
            'training_metadata': training_metadata or {},
            'checkpoint_id': f"{model_name}_{timestamp}"
            'checkpoint_id': checkpoint_id,
            'performance_score': performance_metrics.get('accuracy', performance_metrics.get('reward', 0.0)),
            'created_at': timestamp_str
        }

        # Add performance score for sorting
        primary_metric = 'accuracy' if 'accuracy' in performance_metrics else 'reward'
        checkpoint_metadata['performance_score'] = performance_metrics.get(primary_metric, 0.0)
        checkpoint_metadata['created_at'] = timestamp

        # Save metadata
        with open(f"{checkpoint_path}_metadata.json", 'w') as f:
            json.dump(checkpoint_metadata, f, indent=2)
            json.dump(legacy_metadata, f, indent=2)

        # Get checkpoint manager and clean up old checkpoints
        checkpoint_manager = get_checkpoint_manager(checkpoint_dir=checkpoint_dir)
        checkpoint_manager._cleanup_checkpoints(model_name)

        # Return metadata as an object
        class CheckpointMetadata:
        # Return metadata as an object for backward compatibility
        class CheckpointMetadataObj:
            def __init__(self, metadata):
                for key, value in metadata.items():
                    setattr(self, key, value)
                # Add database fields
                self.checkpoint_id = checkpoint_id
                self.loss = performance_metrics.get('loss', performance_metrics.get('accuracy', 0.0))

        return CheckpointMetadata(checkpoint_metadata)
        return CheckpointMetadataObj(legacy_metadata)

    except Exception as e:
        logger.error(f"Error saving checkpoint: {e}")
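For reference, a minimal sketch of how the save path above could be called once this commit lands. The package path models.checkpoint_manager and the toy model are assumptions for illustration; they are not part of this diff.

import torch.nn as nn

# Hypothetical import path -- the diff only shows package-relative imports.
from models.checkpoint_manager import save_checkpoint

model = nn.Linear(4, 2)  # stand-in with a state_dict(), no custom .save()

meta = save_checkpoint(
    model,
    model_name="demo_model",
    model_type="classifier",
    performance_metrics={"accuracy": 0.91, "loss": 0.23},
    training_metadata={"epochs": 5},
)

# Per the code above: a .pt file and a *_metadata.json land under
# models/checkpoints/demo_model/, metadata is written through the
# database manager, and the returned object carries the legacy fields.
print(meta.checkpoint_id, meta.performance_score, meta.loss)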
@@ -122,7 +160,7 @@ def save_checkpoint(model, model_name: str, model_type: str, performance_metrics

def load_best_checkpoint(model_name: str, checkpoint_dir: str = "models/checkpoints") -> Optional[Tuple[str, Any]]:
    """
    Load the best checkpoint based on performance metrics
    Load the best checkpoint based on performance metrics using database metadata

    Args:
        model_name: Name of the model
@@ -132,29 +170,77 @@ def load_best_checkpoint(model_name: str, checkpoint_dir: str = "models/checkpoi
        Optional[Tuple[str, Any]]: Path to the best checkpoint and its metadata, or None if not found
    """
    try:
        checkpoint_manager = get_checkpoint_manager(checkpoint_dir=checkpoint_dir)
        checkpoint_path, checkpoint_metadata = checkpoint_manager.load_best_checkpoint(model_name)
        # First try to get from database (fast metadata access)
        db_manager = get_database_manager()
        checkpoint_metadata = db_manager.get_best_checkpoint_metadata(model_name, "accuracy")

        if not checkpoint_path:
        if not checkpoint_metadata:
            # Fallback to legacy file-based approach (no more scattered "No checkpoints found" logs)
            pass # Silent fallback
            checkpoint_manager = get_checkpoint_manager(checkpoint_dir=checkpoint_dir)
            checkpoint_path, legacy_metadata = checkpoint_manager.load_best_checkpoint(model_name)

            if not checkpoint_path:
                return None

            # Convert legacy metadata to object
            class CheckpointMetadataObj:
                def __init__(self, metadata):
                    for key, value in metadata.items():
                        setattr(self, key, value)

                    # Add performance score if not present
                    if not hasattr(self, 'performance_score'):
                        metrics = getattr(self, 'metrics', {})
                        primary_metric = 'accuracy' if 'accuracy' in metrics else 'reward'
                        self.performance_score = metrics.get(primary_metric, 0.0)

                    # Add created_at if not present
                    if not hasattr(self, 'created_at'):
                        self.created_at = getattr(self, 'timestamp', 'unknown')

                    # Add loss for compatibility
                    self.loss = metrics.get('loss', self.performance_score)
                    self.checkpoint_id = getattr(self, 'checkpoint_id', f"{model_name}_unknown")

            return f"{checkpoint_path}.pt", CheckpointMetadataObj(legacy_metadata)

        # Check if checkpoint file exists
        if not os.path.exists(checkpoint_metadata.file_path):
            logger.warning(f"Checkpoint file not found: {checkpoint_metadata.file_path}")
            return None

        # Convert metadata to object
        class CheckpointMetadata:
            def __init__(self, metadata):
                for key, value in metadata.items():
                    setattr(self, key, value)

                # Add performance score if not present
                if not hasattr(self, 'performance_score'):
                    metrics = getattr(self, 'metrics', {})
                    primary_metric = 'accuracy' if 'accuracy' in metrics else 'reward'
                    self.performance_score = metrics.get(primary_metric, 0.0)

                # Add created_at if not present
                if not hasattr(self, 'created_at'):
                    self.created_at = getattr(self, 'timestamp', 'unknown')
        # Log checkpoint load event to text file
        text_logger = get_text_logger()
        text_logger.log_checkpoint_event(
            model_name=model_name,
            event_type="LOADED",
            checkpoint_id=checkpoint_metadata.checkpoint_id,
            details=f"loss={checkpoint_metadata.performance_metrics.get('loss', 'N/A')}"
        )

        return f"{checkpoint_path}.pt", CheckpointMetadata(checkpoint_metadata)
        # Convert database metadata to object for backward compatibility
        class CheckpointMetadataObj:
            def __init__(self, db_metadata: CheckpointMetadata):
                self.checkpoint_id = db_metadata.checkpoint_id
                self.model_name = db_metadata.model_name
                self.model_type = db_metadata.model_type
                self.timestamp = db_metadata.timestamp.strftime("%Y%m%d_%H%M%S")
                self.performance_metrics = db_metadata.performance_metrics
                self.training_metadata = db_metadata.training_metadata
                self.file_path = db_metadata.file_path
                self.file_size_mb = db_metadata.file_size_mb
                self.is_active = db_metadata.is_active

                # Backward compatibility fields
                self.metrics = db_metadata.performance_metrics
                self.metadata = db_metadata.training_metadata
                self.created_at = self.timestamp
                self.performance_score = db_metadata.performance_metrics.get('accuracy',
                                                                              db_metadata.performance_metrics.get('reward', 0.0))
                self.loss = db_metadata.performance_metrics.get('loss', self.performance_score)

        return checkpoint_metadata.file_path, CheckpointMetadataObj(checkpoint_metadata)

    except Exception as e:
        logger.error(f"Error loading best checkpoint: {e}")
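In the same spirit, a sketch of the load path under the same hypothetical import; the returned object exposes both the database fields and the backward-compatibility aliases assigned above.

from models.checkpoint_manager import load_best_checkpoint  # hypothetical path

result = load_best_checkpoint("demo_model")
if result is None:
    print("no checkpoint available for demo_model")
else:
    checkpoint_path, ckpt = result
    # checkpoint_path points at the saved .pt file; ckpt mirrors the stored
    # metadata (checkpoint_id, performance_metrics, file_size_mb, ...) plus
    # the legacy aliases (metrics, created_at, performance_score, loss).
    print(checkpoint_path, ckpt.checkpoint_id, ckpt.performance_score)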
@@ -254,7 +340,7 @@ class CheckpointManager:
        metadata_files = glob.glob(os.path.join(checkpoint_dir, f"{model_name}_*_metadata.json"))

        if not metadata_files:
            logger.info(f"No checkpoints found for {model_name}")
            # No more scattered "No checkpoints found" logs - handled by database system
            return "", {}

        # Load metadata for each checkpoint
@@ -278,7 +364,7 @@
                logger.error(f"Error loading checkpoint metadata {metadata_file}: {e}")

        if not checkpoints:
            logger.info(f"No valid checkpoints found for {model_name}")
            # No more scattered logs - handled by database system
            return "", {}

        # Sort by metric (highest first)
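The two hunks above only replace the "No checkpoints found" logging inside the legacy scan. As a standalone illustration of that fallback (not the class's exact code), the selection amounts to globbing the per-checkpoint *_metadata.json files written by save_checkpoint and keeping the highest performance_score:

import glob
import json
import os

def best_legacy_checkpoint(checkpoint_dir: str, model_name: str):
    """Illustrative stand-in for the file-based fallback sketched above."""
    candidates = []
    for metadata_file in glob.glob(os.path.join(checkpoint_dir, f"{model_name}_*_metadata.json")):
        try:
            with open(metadata_file) as f:
                metadata = json.load(f)
            candidates.append((metadata.get("performance_score", 0.0), metadata_file, metadata))
        except Exception:
            continue  # unreadable metadata is skipped, mirroring the error handling above
    if not candidates:
        return "", {}
    # Sort by metric (highest first), then return the winning checkpoint's base path.
    candidates.sort(key=lambda item: item[0], reverse=True)
    _, metadata_file, metadata = candidates[0]
    return metadata_file.replace("_metadata.json", ""), metadata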