SQLite for checkpoints, cleanup

Dobromir Popov
2025-07-25 22:34:13 +03:00
parent 130a52fb9b
commit dd9f4b63ba
42 changed files with 2017 additions and 1485 deletions


@@ -5,6 +5,7 @@ This module provides functionality for managing model checkpoints, including:
- Saving checkpoints with metadata
- Loading the best checkpoint based on performance metrics
- Cleaning up old or underperforming checkpoints
- Database-backed metadata storage for efficient access
"""
import os
@@ -13,9 +14,13 @@ import glob
import logging
import shutil
import torch
import hashlib
from datetime import datetime
from typing import Dict, List, Optional, Any, Tuple
from .database_manager import get_database_manager, CheckpointMetadata
from .text_logger import get_text_logger
logger = logging.getLogger(__name__)
# Global checkpoint manager instance
@@ -46,7 +51,7 @@ def get_checkpoint_manager(checkpoint_dir: str = "models/checkpoints", max_check
def save_checkpoint(model, model_name: str, model_type: str, performance_metrics: Dict[str, float], training_metadata: Dict[str, Any] = None, checkpoint_dir: str = "models/checkpoints") -> Any:
"""
Save a checkpoint with metadata
Save a checkpoint with metadata to both filesystem and database
Args:
model: The model to save
@@ -64,57 +69,90 @@ def save_checkpoint(model, model_name: str, model_type: str, performance_metrics
os.makedirs(checkpoint_dir, exist_ok=True)
# Create timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
timestamp = datetime.now()
timestamp_str = timestamp.strftime("%Y%m%d_%H%M%S")
# Create checkpoint path
model_dir = os.path.join(checkpoint_dir, model_name)
os.makedirs(model_dir, exist_ok=True)
checkpoint_path = os.path.join(model_dir, f"{model_name}_{timestamp}")
checkpoint_path = os.path.join(model_dir, f"{model_name}_{timestamp_str}")
checkpoint_id = f"{model_name}_{timestamp_str}"
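With this scheme, a model saved at 2025-07-25 22:34:13 would get checkpoint_id example_dqn_20250725_223413 (model name illustrative); the same id doubles as the database key for the metadata row written below.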
# Save model
torch_path = f"{checkpoint_path}.pt"
if hasattr(model, 'save'):
# Use model's save method if available
model.save(checkpoint_path)
else:
# Otherwise, save state_dict
torch_path = f"{checkpoint_path}.pt"
torch.save({
'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else None,
'model_name': model_name,
'model_type': model_type,
'timestamp': timestamp
'timestamp': timestamp_str,
'checkpoint_id': checkpoint_id
}, torch_path)
# Create metadata
checkpoint_metadata = {
# Calculate file size
file_size_mb = os.path.getsize(torch_path) / (1024 * 1024) if os.path.exists(torch_path) else 0.0
# Save metadata to database
db_manager = get_database_manager()
checkpoint_metadata = CheckpointMetadata(
checkpoint_id=checkpoint_id,
model_name=model_name,
model_type=model_type,
timestamp=timestamp,
performance_metrics=performance_metrics,
training_metadata=training_metadata or {},
file_path=torch_path,
file_size_mb=file_size_mb,
is_active=True # New checkpoint is active by default
)
# Save to database
if db_manager.save_checkpoint_metadata(checkpoint_metadata):
# Log checkpoint save event to text file
text_logger = get_text_logger()
text_logger.log_checkpoint_event(
model_name=model_name,
event_type="SAVED",
checkpoint_id=checkpoint_id,
details=f"loss={performance_metrics.get('loss', 'N/A')}, size={file_size_mb:.1f}MB"
)
logger.info(f"Checkpoint saved: {checkpoint_id}")
else:
logger.warning(f"Failed to save checkpoint metadata to database: {checkpoint_id}")
# Also save legacy JSON metadata for backward compatibility
legacy_metadata = {
'model_name': model_name,
'model_type': model_type,
'timestamp': timestamp,
'timestamp': timestamp_str,
'performance_metrics': performance_metrics,
'training_metadata': training_metadata or {},
'checkpoint_id': f"{model_name}_{timestamp}"
'checkpoint_id': checkpoint_id,
'performance_score': performance_metrics.get('accuracy', performance_metrics.get('reward', 0.0)),
'created_at': timestamp_str
}
# Add performance score for sorting
primary_metric = 'accuracy' if 'accuracy' in performance_metrics else 'reward'
checkpoint_metadata['performance_score'] = performance_metrics.get(primary_metric, 0.0)
checkpoint_metadata['created_at'] = timestamp
# Save metadata
with open(f"{checkpoint_path}_metadata.json", 'w') as f:
json.dump(checkpoint_metadata, f, indent=2)
json.dump(legacy_metadata, f, indent=2)
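For reference, the legacy sidecar file written here would contain something like the following (values illustrative, matching the keys built above):

```json
{
  "model_name": "example_dqn",
  "model_type": "dqn",
  "timestamp": "20250725_223413",
  "performance_metrics": {"loss": 0.12, "accuracy": 0.78},
  "training_metadata": {"epochs": 10},
  "checkpoint_id": "example_dqn_20250725_223413",
  "performance_score": 0.78,
  "created_at": "20250725_223413"
}
```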
# Get checkpoint manager and clean up old checkpoints
checkpoint_manager = get_checkpoint_manager(checkpoint_dir=checkpoint_dir)
checkpoint_manager._cleanup_checkpoints(model_name)
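_cleanup_checkpoints itself is not shown in this diff. A plausible sketch of the retention step, assuming it ranks a model's sidecar-metadata files by performance_score and keeps only the max_checkpoints best (the helper name and policy are assumptions):

```python
import glob
import json
import os

def cleanup_checkpoints_sketch(checkpoint_dir: str, model_name: str,
                               max_checkpoints: int = 5) -> None:
    """Keep the top-N checkpoints by performance_score; delete the rest."""
    model_dir = os.path.join(checkpoint_dir, model_name)
    scored = []
    for meta_file in glob.glob(os.path.join(model_dir, f"{model_name}_*_metadata.json")):
        with open(meta_file) as f:
            meta = json.load(f)
        scored.append((meta.get("performance_score", 0.0), meta_file))
    scored.sort(reverse=True)  # best score first
    for _, meta_file in scored[max_checkpoints:]:
        weights_file = meta_file.replace("_metadata.json", ".pt")
        for path in (meta_file, weights_file):
            if os.path.exists(path):
                os.remove(path)
```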
# Return metadata as an object
class CheckpointMetadata:
# Return metadata as an object for backward compatibility
class CheckpointMetadataObj:
def __init__(self, metadata):
for key, value in metadata.items():
setattr(self, key, value)
# Add database fields
self.checkpoint_id = checkpoint_id
self.loss = performance_metrics.get('loss', performance_metrics.get('accuracy', 0.0))
return CheckpointMetadata(checkpoint_metadata)
return CheckpointMetadataObj(legacy_metadata)
except Exception as e:
logger.error(f"Error saving checkpoint: {e}")
@@ -122,7 +160,7 @@ def save_checkpoint(model, model_name: str, model_type: str, performance_metrics
def load_best_checkpoint(model_name: str, checkpoint_dir: str = "models/checkpoints") -> Optional[Tuple[str, Any]]:
"""
Load the best checkpoint based on performance metrics
Load the best checkpoint based on performance metrics using database metadata
Args:
model_name: Name of the model
@@ -132,29 +170,77 @@ def load_best_checkpoint(model_name: str, checkpoint_dir: str = "models/checkpoi
Optional[Tuple[str, Any]]: Path to the best checkpoint and its metadata, or None if not found
"""
try:
checkpoint_manager = get_checkpoint_manager(checkpoint_dir=checkpoint_dir)
checkpoint_path, checkpoint_metadata = checkpoint_manager.load_best_checkpoint(model_name)
# First try to get from database (fast metadata access)
db_manager = get_database_manager()
checkpoint_metadata = db_manager.get_best_checkpoint_metadata(model_name, "accuracy")
if not checkpoint_path:
if not checkpoint_metadata:
# Fall back to the legacy file-based approach silently (no scattered "No checkpoints found" logs)
checkpoint_manager = get_checkpoint_manager(checkpoint_dir=checkpoint_dir)
checkpoint_path, legacy_metadata = checkpoint_manager.load_best_checkpoint(model_name)
if not checkpoint_path:
return None
# Convert legacy metadata to object
class CheckpointMetadataObj:
def __init__(self, metadata):
for key, value in metadata.items():
setattr(self, key, value)
# Add performance score if not present
if not hasattr(self, 'performance_score'):
metrics = getattr(self, 'metrics', {})
primary_metric = 'accuracy' if 'accuracy' in metrics else 'reward'
self.performance_score = metrics.get(primary_metric, 0.0)
# Add created_at if not present
if not hasattr(self, 'created_at'):
self.created_at = getattr(self, 'timestamp', 'unknown')
# Add loss for compatibility
self.loss = metrics.get('loss', self.performance_score)
self.checkpoint_id = getattr(self, 'checkpoint_id', f"{model_name}_unknown")
return f"{checkpoint_path}.pt", CheckpointMetadataObj(legacy_metadata)
# Check if checkpoint file exists
if not os.path.exists(checkpoint_metadata.file_path):
logger.warning(f"Checkpoint file not found: {checkpoint_metadata.file_path}")
return None
# Convert metadata to object
class CheckpointMetadata:
def __init__(self, metadata):
for key, value in metadata.items():
setattr(self, key, value)
# Add performance score if not present
if not hasattr(self, 'performance_score'):
metrics = getattr(self, 'metrics', {})
primary_metric = 'accuracy' if 'accuracy' in metrics else 'reward'
self.performance_score = metrics.get(primary_metric, 0.0)
# Add created_at if not present
if not hasattr(self, 'created_at'):
self.created_at = getattr(self, 'timestamp', 'unknown')
# Log checkpoint load event to text file
text_logger = get_text_logger()
text_logger.log_checkpoint_event(
model_name=model_name,
event_type="LOADED",
checkpoint_id=checkpoint_metadata.checkpoint_id,
details=f"loss={checkpoint_metadata.performance_metrics.get('loss', 'N/A')}"
)
return f"{checkpoint_path}.pt", CheckpointMetadata(checkpoint_metadata)
# Convert database metadata to object for backward compatibility
class CheckpointMetadataObj:
def __init__(self, db_metadata: CheckpointMetadata):
self.checkpoint_id = db_metadata.checkpoint_id
self.model_name = db_metadata.model_name
self.model_type = db_metadata.model_type
self.timestamp = db_metadata.timestamp.strftime("%Y%m%d_%H%M%S")
self.performance_metrics = db_metadata.performance_metrics
self.training_metadata = db_metadata.training_metadata
self.file_path = db_metadata.file_path
self.file_size_mb = db_metadata.file_size_mb
self.is_active = db_metadata.is_active
# Backward compatibility fields
self.metrics = db_metadata.performance_metrics
self.metadata = db_metadata.training_metadata
self.created_at = self.timestamp
self.performance_score = db_metadata.performance_metrics.get('accuracy',
db_metadata.performance_metrics.get('reward', 0.0))
self.loss = db_metadata.performance_metrics.get('loss', self.performance_score)
return checkpoint_metadata.file_path, CheckpointMetadataObj(checkpoint_metadata)
except Exception as e:
logger.error(f"Error loading best checkpoint: {e}")
@@ -254,7 +340,7 @@ class CheckpointManager:
metadata_files = glob.glob(os.path.join(checkpoint_dir, f"{model_name}_*_metadata.json"))
if not metadata_files:
logger.info(f"No checkpoints found for {model_name}")
# No more scattered "No checkpoints found" logs - handled by database system
return "", {}
# Load metadata for each checkpoint
@@ -278,7 +364,7 @@
logger.error(f"Error loading checkpoint metadata {metadata_file}: {e}")
if not checkpoints:
logger.info(f"No valid checkpoints found for {model_name}")
# No more scattered logs - handled by database system
return "", {}
# Sort by metric (highest first)