checkpoint manager

This commit is contained in:
Dobromir Popov
2025-07-23 21:40:04 +03:00
parent bab39fa68f
commit 45a62443a0
9 changed files with 1587 additions and 709 deletions

View File

@@ -0,0 +1,3 @@
"""
Utils package for the multi-modal trading system
"""

View File

@@ -1,466 +1,408 @@
#!/usr/bin/env python3
"""
Checkpoint Management System for W&B Training
"""
Checkpoint Manager
This module provides functionality for managing model checkpoints, including:
- Saving checkpoints with metadata
- Loading the best checkpoint based on performance metrics
- Cleaning up old or underperforming checkpoints
"""
import os
import json
import glob
import logging
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, asdict
from collections import defaultdict
import shutil
import torch
import random
try:
import wandb
WANDB_AVAILABLE = True
except ImportError:
WANDB_AVAILABLE = False
from datetime import datetime
from typing import Dict, List, Optional, Any, Tuple
logger = logging.getLogger(__name__)
@dataclass
class CheckpointMetadata:
checkpoint_id: str
model_name: str
model_type: str
file_path: str
created_at: datetime
file_size_mb: float
performance_score: float
accuracy: Optional[float] = None
loss: Optional[float] = None
val_accuracy: Optional[float] = None
val_loss: Optional[float] = None
reward: Optional[float] = None
pnl: Optional[float] = None
epoch: Optional[int] = None
training_time_hours: Optional[float] = None
total_parameters: Optional[int] = None
wandb_run_id: Optional[str] = None
wandb_artifact_name: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
data = asdict(self)
data['created_at'] = self.created_at.isoformat()
return data
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'CheckpointMetadata':
data['created_at'] = datetime.fromisoformat(data['created_at'])
return cls(**data)
# Global checkpoint manager instance
_checkpoint_manager_instance = None
def get_checkpoint_manager(checkpoint_dir: str = "models/checkpoints", max_checkpoints: int = 10, metric_name: str = "accuracy") -> 'CheckpointManager':
"""
Get the global checkpoint manager instance
Args:
checkpoint_dir: Directory to store checkpoints
max_checkpoints: Maximum number of checkpoints to keep
metric_name: Metric to use for ranking checkpoints
Returns:
CheckpointManager: Global checkpoint manager instance
"""
global _checkpoint_manager_instance
if _checkpoint_manager_instance is None:
_checkpoint_manager_instance = CheckpointManager(
checkpoint_dir=checkpoint_dir,
max_checkpoints=max_checkpoints,
metric_name=metric_name
)
return _checkpoint_manager_instance
def save_checkpoint(model, model_name: str, model_type: str, performance_metrics: Dict[str, float], training_metadata: Dict[str, Any] = None, checkpoint_dir: str = "models/checkpoints") -> Any:
"""
Save a checkpoint with metadata
Args:
model: The model to save
model_name: Name of the model
model_type: Type of the model ('cnn', 'rl', etc.)
performance_metrics: Performance metrics
training_metadata: Additional training metadata
checkpoint_dir: Directory to store checkpoints
Returns:
Any: Checkpoint metadata
"""
try:
# Create checkpoint directory
os.makedirs(checkpoint_dir, exist_ok=True)
# Create timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# Create checkpoint path
model_dir = os.path.join(checkpoint_dir, model_name)
os.makedirs(model_dir, exist_ok=True)
checkpoint_path = os.path.join(model_dir, f"{model_name}_{timestamp}")
# Save model
if hasattr(model, 'save'):
# Use model's save method if available
model.save(checkpoint_path)
else:
# Otherwise, save state_dict
torch_path = f"{checkpoint_path}.pt"
torch.save({
'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else None,
'model_name': model_name,
'model_type': model_type,
'timestamp': timestamp
}, torch_path)
# Create metadata
checkpoint_metadata = {
'model_name': model_name,
'model_type': model_type,
'timestamp': timestamp,
'performance_metrics': performance_metrics,
'training_metadata': training_metadata or {},
'checkpoint_id': f"{model_name}_{timestamp}"
}
# Add performance score for sorting
primary_metric = 'accuracy' if 'accuracy' in performance_metrics else 'reward'
checkpoint_metadata['performance_score'] = performance_metrics.get(primary_metric, 0.0)
checkpoint_metadata['created_at'] = timestamp
# Save metadata
with open(f"{checkpoint_path}_metadata.json", 'w') as f:
json.dump(checkpoint_metadata, f, indent=2)
# Get checkpoint manager and clean up old checkpoints
checkpoint_manager = get_checkpoint_manager(checkpoint_dir=checkpoint_dir)
checkpoint_manager._cleanup_checkpoints(model_name)
# Return metadata as an object
class CheckpointMetadata:
def __init__(self, metadata):
for key, value in metadata.items():
setattr(self, key, value)
return CheckpointMetadata(checkpoint_metadata)
except Exception as e:
logger.error(f"Error saving checkpoint: {e}")
return None
def load_best_checkpoint(model_name: str, checkpoint_dir: str = "models/checkpoints") -> Optional[Tuple[str, Any]]:
"""
Load the best checkpoint based on performance metrics
Args:
model_name: Name of the model
checkpoint_dir: Directory to store checkpoints
Returns:
Optional[Tuple[str, Any]]: Path to the best checkpoint and its metadata, or None if not found
"""
try:
checkpoint_manager = get_checkpoint_manager(checkpoint_dir=checkpoint_dir)
checkpoint_path, checkpoint_metadata = checkpoint_manager.load_best_checkpoint(model_name)
if not checkpoint_path:
return None
# Convert metadata to object
class CheckpointMetadata:
def __init__(self, metadata):
for key, value in metadata.items():
setattr(self, key, value)
# Add performance score if not present
if not hasattr(self, 'performance_score'):
metrics = getattr(self, 'metrics', {})
primary_metric = 'accuracy' if 'accuracy' in metrics else 'reward'
self.performance_score = metrics.get(primary_metric, 0.0)
# Add created_at if not present
if not hasattr(self, 'created_at'):
self.created_at = getattr(self, 'timestamp', 'unknown')
return f"{checkpoint_path}.pt", CheckpointMetadata(checkpoint_metadata)
except Exception as e:
logger.error(f"Error loading best checkpoint: {e}")
return None
class CheckpointManager:
def __init__(self,
base_checkpoint_dir: str = "NN/models/saved",
max_checkpoints_per_model: int = 5,
metadata_file: str = "checkpoint_metadata.json",
enable_wandb: bool = True):
self.base_dir = Path(base_checkpoint_dir)
self.base_dir.mkdir(parents=True, exist_ok=True)
self.max_checkpoints = max_checkpoints_per_model
self.metadata_file = self.base_dir / metadata_file
self.enable_wandb = enable_wandb and WANDB_AVAILABLE
self.checkpoints: Dict[str, List[CheckpointMetadata]] = defaultdict(list)
self._load_metadata()
logger.info(f"Checkpoint Manager initialized - Max checkpoints per model: {self.max_checkpoints}")
"""
Manages model checkpoints with performance-based optimization
def save_checkpoint(self, model, model_name: str, model_type: str,
performance_metrics: Dict[str, float],
training_metadata: Optional[Dict[str, Any]] = None,
force_save: bool = False) -> Optional[CheckpointMetadata]:
This class:
1. Saves checkpoints with metadata
2. Loads the best checkpoint based on performance metrics
3. Cleans up old or underperforming checkpoints
"""
def __init__(self, checkpoint_dir: str, max_checkpoints: int = 10, metric_name: str = "accuracy"):
"""
Initialize the checkpoint manager
Args:
checkpoint_dir: Directory to store checkpoints
max_checkpoints: Maximum number of checkpoints to keep
metric_name: Metric to use for ranking checkpoints
"""
self.checkpoint_dir = checkpoint_dir
self.max_checkpoints = max_checkpoints
self.metric_name = metric_name
# Create checkpoint directory if it doesn't exist
os.makedirs(checkpoint_dir, exist_ok=True)
logger.info(f"CheckpointManager initialized with checkpoint_dir: {checkpoint_dir}")
def save_checkpoint(self, model_name: str, model_path: str, metrics: Dict[str, float], metadata: Dict[str, Any] = None) -> str:
"""
Save a checkpoint with metadata
Args:
model_name: Name of the model
model_path: Path to the model file
metrics: Performance metrics
metadata: Additional metadata
Returns:
str: Path to the saved checkpoint
"""
try:
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
checkpoint_id = f"{model_name}_{timestamp}"
# Create timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_dir = self.base_dir / model_name
model_dir.mkdir(exist_ok=True)
# Create checkpoint directory
checkpoint_dir = os.path.join(self.checkpoint_dir, model_name)
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = model_dir / f"{checkpoint_id}.pt"
# Create checkpoint path
checkpoint_path = os.path.join(checkpoint_dir, f"{model_name}_{timestamp}")
performance_score = self._calculate_performance_score(performance_metrics)
# Copy model file to checkpoint path
shutil.copy2(model_path, f"{checkpoint_path}.pt")
if not force_save and not self._should_save_checkpoint(model_name, performance_score):
logger.debug(f"Skipping checkpoint save for {model_name} - performance not improved")
return None
success = self._save_model_file(model, checkpoint_path, model_type)
if not success:
return None
file_size_mb = checkpoint_path.stat().st_size / (1024 * 1024)
metadata = CheckpointMetadata(
checkpoint_id=checkpoint_id,
model_name=model_name,
model_type=model_type,
file_path=str(checkpoint_path),
created_at=datetime.now(),
file_size_mb=file_size_mb,
performance_score=performance_score,
accuracy=performance_metrics.get('accuracy'),
loss=performance_metrics.get('loss'),
val_accuracy=performance_metrics.get('val_accuracy'),
val_loss=performance_metrics.get('val_loss'),
reward=performance_metrics.get('reward'),
pnl=performance_metrics.get('pnl'),
epoch=training_metadata.get('epoch') if training_metadata else None,
training_time_hours=training_metadata.get('training_time_hours') if training_metadata else None,
total_parameters=training_metadata.get('total_parameters') if training_metadata else None
)
if self.enable_wandb and wandb.run is not None:
artifact_name = self._upload_to_wandb(checkpoint_path, metadata)
metadata.wandb_run_id = wandb.run.id
metadata.wandb_artifact_name = artifact_name
self.checkpoints[model_name].append(metadata)
self._rotate_checkpoints(model_name)
self._save_metadata()
logger.debug(f"Saved checkpoint: {checkpoint_id} (score: {performance_score:.4f})")
return metadata
except Exception as e:
logger.error(f"Error saving checkpoint for {model_name}: {e}")
return None
def load_best_checkpoint(self, model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
try:
# First, try the standard checkpoint system
if model_name in self.checkpoints and self.checkpoints[model_name]:
# Filter out checkpoints with non-existent files
valid_checkpoints = [
cp for cp in self.checkpoints[model_name]
if Path(cp.file_path).exists()
]
if valid_checkpoints:
best_checkpoint = max(valid_checkpoints, key=lambda x: x.performance_score)
logger.debug(f"Loading best checkpoint for {model_name}: {best_checkpoint.checkpoint_id}")
return best_checkpoint.file_path, best_checkpoint
else:
# Clean up invalid metadata entries
invalid_count = len(self.checkpoints[model_name])
logger.warning(f"Found {invalid_count} invalid checkpoint entries for {model_name}, cleaning up metadata")
self.checkpoints[model_name] = []
self._save_metadata()
# Fallback: Look for existing saved models in the legacy format
logger.debug(f"No valid checkpoints found for model: {model_name}, attempting to find legacy saved models")
legacy_model_path = self._find_legacy_model(model_name)
if legacy_model_path:
# Create checkpoint metadata for the legacy model using actual file data
legacy_metadata = self._create_legacy_metadata(model_name, legacy_model_path)
logger.debug(f"Found legacy model for {model_name}: {legacy_model_path}")
return str(legacy_model_path), legacy_metadata
logger.warning(f"No checkpoints or legacy models found for: {model_name}")
return None
except Exception as e:
logger.error(f"Error loading best checkpoint for {model_name}: {e}")
return None
def _calculate_performance_score(self, metrics: Dict[str, float]) -> float:
"""Calculate performance score with improved sensitivity for training models"""
score = 0.0
# Prioritize loss reduction for active training models
if 'loss' in metrics:
# Invert loss so lower loss = higher score, with better scaling
loss_value = metrics['loss']
if loss_value > 0:
score += max(0, 100 / (1 + loss_value)) # More sensitive to loss changes
else:
score += 100 # Perfect loss
# Add other metrics with appropriate weights
if 'accuracy' in metrics:
score += metrics['accuracy'] * 50 # Reduced weight to balance with loss
if 'val_accuracy' in metrics:
score += metrics['val_accuracy'] * 50
if 'val_loss' in metrics:
val_loss = metrics['val_loss']
if val_loss > 0:
score += max(0, 50 / (1 + val_loss))
if 'reward' in metrics:
score += metrics['reward'] * 10
if 'pnl' in metrics:
score += metrics['pnl'] * 5
if 'training_samples' in metrics:
# Bonus for processing more training samples
score += min(10, metrics['training_samples'] / 10)
# Return actual calculated score - NO SYNTHETIC MINIMUM
return score
def _should_save_checkpoint(self, model_name: str, performance_score: float) -> bool:
"""Improved checkpoint saving logic with more frequent saves during training"""
if model_name not in self.checkpoints or not self.checkpoints[model_name]:
return True # Always save first checkpoint
# Allow more checkpoints during active training
if len(self.checkpoints[model_name]) < self.max_checkpoints:
return True
# Get current best and worst scores
scores = [cp.performance_score for cp in self.checkpoints[model_name]]
best_score = max(scores)
worst_score = min(scores)
# Save if better than worst (more frequent saves)
if performance_score > worst_score:
return True
# For high-performing models (score > 100), be more sensitive to small improvements
if best_score > 100:
# Save if within 0.1% of best score (very sensitive for converged models)
if performance_score >= best_score * 0.999:
return True
else:
# Also save if we're within 10% of best score (capture near-optimal models)
if performance_score >= best_score * 0.9:
return True
# Save more frequently during active training (every 5th attempt instead of 10th)
if random.random() < 0.2: # 20% chance to save anyway
logger.debug(f"Saving checkpoint for {model_name} - periodic save during active training")
return True
return False
def _save_model_file(self, model, file_path: Path, model_type: str) -> bool:
try:
if hasattr(model, 'state_dict'):
torch.save({
'model_state_dict': model.state_dict(),
'model_type': model_type,
'saved_at': datetime.now().isoformat()
}, file_path)
else:
torch.save(model, file_path)
return True
except Exception as e:
logger.error(f"Error saving model file {file_path}: {e}")
return False
def _rotate_checkpoints(self, model_name: str):
checkpoint_list = self.checkpoints[model_name]
if len(checkpoint_list) <= self.max_checkpoints:
return
checkpoint_list.sort(key=lambda x: x.performance_score, reverse=True)
to_remove = checkpoint_list[self.max_checkpoints:]
self.checkpoints[model_name] = checkpoint_list[:self.max_checkpoints]
for checkpoint in to_remove:
try:
file_path = Path(checkpoint.file_path)
if file_path.exists():
file_path.unlink()
logger.debug(f"Rotated out checkpoint: {checkpoint.checkpoint_id}")
except Exception as e:
logger.error(f"Error removing rotated checkpoint {checkpoint.checkpoint_id}: {e}")
def _upload_to_wandb(self, file_path: Path, metadata: CheckpointMetadata) -> Optional[str]:
try:
if not self.enable_wandb or wandb.run is None:
return None
artifact_name = f"{metadata.model_name}_checkpoint"
artifact = wandb.Artifact(artifact_name, type="model")
artifact.add_file(str(file_path))
wandb.log_artifact(artifact)
return artifact_name
except Exception as e:
logger.error(f"Error uploading to W&B: {e}")
return None
def _load_metadata(self):
try:
if self.metadata_file.exists():
with open(self.metadata_file, 'r') as f:
data = json.load(f)
for model_name, checkpoint_list in data.items():
self.checkpoints[model_name] = [
CheckpointMetadata.from_dict(cp_data)
for cp_data in checkpoint_list
]
logger.info(f"Loaded metadata for {len(self.checkpoints)} models")
except Exception as e:
logger.error(f"Error loading checkpoint metadata: {e}")
def _save_metadata(self):
try:
data = {}
for model_name, checkpoint_list in self.checkpoints.items():
data[model_name] = [cp.to_dict() for cp in checkpoint_list]
with open(self.metadata_file, 'w') as f:
json.dump(data, f, indent=2)
except Exception as e:
logger.error(f"Error saving checkpoint metadata: {e}")
def get_checkpoint_stats(self):
"""Get statistics about managed checkpoints"""
stats = {
'total_models': len(self.checkpoints),
'total_checkpoints': sum(len(checkpoints) for checkpoints in self.checkpoints.values()),
'total_size_mb': 0.0,
'models': {}
}
for model_name, checkpoint_list in self.checkpoints.items():
if not checkpoint_list:
continue
model_size = sum(cp.file_size_mb for cp in checkpoint_list)
best_checkpoint = max(checkpoint_list, key=lambda x: x.performance_score)
stats['models'][model_name] = {
'checkpoint_count': len(checkpoint_list),
'total_size_mb': model_size,
'best_performance': best_checkpoint.performance_score,
'best_checkpoint_id': best_checkpoint.checkpoint_id,
'latest_checkpoint': max(checkpoint_list, key=lambda x: x.created_at).checkpoint_id
}
stats['total_size_mb'] += model_size
return stats
# Create metadata
checkpoint_metadata = {
'model_name': model_name,
'timestamp': timestamp,
'metrics': metrics,
'metadata': metadata or {}
}
def _find_legacy_model(self, model_name: str) -> Optional[Path]:
"""Find legacy saved models based on model name patterns"""
base_dir = Path(self.base_dir)
# Define model name mappings and patterns for legacy files
legacy_patterns = {
'dqn_agent': [
'dqn_agent_best_policy.pt',
'enhanced_dqn_best_policy.pt',
'improved_dqn_agent_best_policy.pt',
'dqn_agent_final_policy.pt'
],
'enhanced_cnn': [
'cnn_model_best.pt',
'optimized_short_term_model_best.pt',
'optimized_short_term_model_realtime_best.pt',
'optimized_short_term_model_ticks_best.pt'
],
'extrema_trainer': [
'supervised_model_best.pt'
],
'cob_rl': [
'best_rl_model.pth_policy.pt',
'rl_agent_best_policy.pt'
],
'decision': [
# Decision models might be in subdirectories, but let's check main dir too
'decision_best.pt',
'decision_model_best.pt',
# Check for transformer models which might be used as decision models
'enhanced_dqn_best_policy.pt',
'improved_dqn_agent_best_policy.pt'
]
}
# Get patterns for this model name
patterns = legacy_patterns.get(model_name, [])
# Also try generic patterns based on model name
patterns.extend([
f'{model_name}_best.pt',
f'{model_name}_best_policy.pt',
f'{model_name}_final.pt',
f'{model_name}_final_policy.pt'
])
# Search for the model files
for pattern in patterns:
candidate_path = base_dir / pattern
if candidate_path.exists():
logger.debug(f"Found legacy model file: {candidate_path}")
return candidate_path
# Also check subdirectories
for subdir in base_dir.iterdir():
if subdir.is_dir() and subdir.name == model_name:
for pattern in patterns:
candidate_path = subdir / pattern
if candidate_path.exists():
logger.debug(f"Found legacy model file in subdirectory: {candidate_path}")
return candidate_path
return None
# Save metadata
with open(f"{checkpoint_path}_metadata.json", 'w') as f:
json.dump(checkpoint_metadata, f, indent=2)
logger.info(f"Saved checkpoint to {checkpoint_path}")
# Clean up old checkpoints
self._cleanup_checkpoints(model_name)
return checkpoint_path
def _create_legacy_metadata(self, model_name: str, file_path: Path) -> CheckpointMetadata:
"""Create metadata for legacy model files using only actual file information"""
try:
file_size_mb = file_path.stat().st_size / (1024 * 1024)
created_time = datetime.fromtimestamp(file_path.stat().st_mtime)
# NO SYNTHETIC DATA - use only actual file information
return CheckpointMetadata(
checkpoint_id=f"legacy_{model_name}_{int(created_time.timestamp())}",
model_name=model_name,
model_type=model_name,
file_path=str(file_path),
created_at=created_time,
file_size_mb=file_size_mb,
performance_score=0.0, # Unknown performance - use 0, not synthetic values
accuracy=None,
loss=None,
val_accuracy=None,
val_loss=None,
reward=None,
pnl=None,
epoch=None,
training_time_hours=None,
total_parameters=None,
wandb_run_id=None,
wandb_artifact_name=None
)
except Exception as e:
logger.error(f"Error creating legacy metadata for {model_name}: {e}")
# Return a basic metadata with minimal info - NO SYNTHETIC VALUES
return CheckpointMetadata(
checkpoint_id=f"legacy_{model_name}",
model_name=model_name,
model_type=model_name,
file_path=str(file_path),
created_at=datetime.now(),
file_size_mb=0.0,
performance_score=0.0 # Unknown - use 0, not synthetic
)
_checkpoint_manager = None
def get_checkpoint_manager() -> CheckpointManager:
global _checkpoint_manager
if _checkpoint_manager is None:
_checkpoint_manager = CheckpointManager()
return _checkpoint_manager
def save_checkpoint(model, model_name: str, model_type: str,
performance_metrics: Dict[str, float],
training_metadata: Optional[Dict[str, Any]] = None,
force_save: bool = False) -> Optional[CheckpointMetadata]:
return get_checkpoint_manager().save_checkpoint(
model, model_name, model_type, performance_metrics, training_metadata, force_save
)
def load_best_checkpoint(model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
return get_checkpoint_manager().load_best_checkpoint(model_name)
logger.error(f"Error saving checkpoint: {e}")
return ""
def load_best_checkpoint(self, model_name: str) -> Tuple[str, Dict[str, Any]]:
"""
Load the best checkpoint based on performance metrics
Args:
model_name: Name of the model
Returns:
Tuple[str, Dict[str, Any]]: Path to the best checkpoint and its metadata
"""
try:
# Find all checkpoint metadata files
checkpoint_dir = os.path.join(self.checkpoint_dir, model_name)
metadata_files = glob.glob(os.path.join(checkpoint_dir, f"{model_name}_*_metadata.json"))
if not metadata_files:
logger.info(f"No checkpoints found for {model_name}")
return "", {}
# Load metadata for each checkpoint
checkpoints = []
for metadata_file in metadata_files:
try:
with open(metadata_file, 'r') as f:
metadata = json.load(f)
# Get checkpoint path (remove _metadata.json)
checkpoint_path = metadata_file[:-14]
# Check if model file exists
if not os.path.exists(f"{checkpoint_path}.pt"):
logger.warning(f"Model file not found for checkpoint {checkpoint_path}")
continue
checkpoints.append((checkpoint_path, metadata))
except Exception as e:
logger.error(f"Error loading checkpoint metadata {metadata_file}: {e}")
if not checkpoints:
logger.info(f"No valid checkpoints found for {model_name}")
return "", {}
# Sort by metric (highest first)
checkpoints.sort(key=lambda x: x[1].get('metrics', {}).get(self.metric_name, 0.0), reverse=True)
# Return best checkpoint
best_checkpoint_path = checkpoints[0][0]
best_checkpoint_metadata = checkpoints[0][1]
logger.info(f"Best checkpoint for {model_name}: {best_checkpoint_path}")
return best_checkpoint_path, best_checkpoint_metadata
except Exception as e:
logger.error(f"Error loading best checkpoint: {e}")
return "", {}
def _cleanup_checkpoints(self, model_name: str) -> int:
"""
Clean up old or underperforming checkpoints
Args:
model_name: Name of the model
Returns:
int: Number of checkpoints deleted
"""
try:
# Find all checkpoint metadata files
checkpoint_dir = os.path.join(self.checkpoint_dir, model_name)
metadata_files = glob.glob(os.path.join(checkpoint_dir, f"{model_name}_*_metadata.json"))
if not metadata_files or len(metadata_files) <= self.max_checkpoints:
return 0
# Load metadata for each checkpoint
checkpoints = []
for metadata_file in metadata_files:
try:
with open(metadata_file, 'r') as f:
metadata = json.load(f)
# Get checkpoint path (remove _metadata.json)
checkpoint_path = metadata_file[:-14]
checkpoints.append((checkpoint_path, metadata))
except Exception as e:
logger.error(f"Error loading checkpoint metadata {metadata_file}: {e}")
# Sort by metric (highest first)
checkpoints.sort(key=lambda x: x[1].get('metrics', {}).get(self.metric_name, 0.0), reverse=True)
# Keep only the best checkpoints
checkpoints_to_delete = checkpoints[self.max_checkpoints:]
# Delete checkpoints
deleted_count = 0
for checkpoint_path, _ in checkpoints_to_delete:
try:
# Delete model file
if os.path.exists(f"{checkpoint_path}.pt"):
os.remove(f"{checkpoint_path}.pt")
# Delete metadata file
if os.path.exists(f"{checkpoint_path}_metadata.json"):
os.remove(f"{checkpoint_path}_metadata.json")
deleted_count += 1
except Exception as e:
logger.error(f"Error deleting checkpoint {checkpoint_path}: {e}")
logger.info(f"Deleted {deleted_count} old checkpoints for {model_name}")
return deleted_count
except Exception as e:
logger.error(f"Error cleaning up checkpoints: {e}")
return 0
def get_all_checkpoints(self, model_name: str) -> List[Tuple[str, Dict[str, Any]]]:
"""
Get all checkpoints for a model
Args:
model_name: Name of the model
Returns:
List[Tuple[str, Dict[str, Any]]]: List of checkpoint paths and metadata
"""
try:
# Find all checkpoint metadata files
checkpoint_dir = os.path.join(self.checkpoint_dir, model_name)
metadata_files = glob.glob(os.path.join(checkpoint_dir, f"{model_name}_*_metadata.json"))
if not metadata_files:
return []
# Load metadata for each checkpoint
checkpoints = []
for metadata_file in metadata_files:
try:
with open(metadata_file, 'r') as f:
metadata = json.load(f)
# Get checkpoint path (remove _metadata.json)
checkpoint_path = metadata_file[:-14]
# Check if model file exists
if not os.path.exists(f"{checkpoint_path}.pt"):
logger.warning(f"Model file not found for checkpoint {checkpoint_path}")
continue
checkpoints.append((checkpoint_path, metadata))
except Exception as e:
logger.error(f"Error loading checkpoint metadata {metadata_file}: {e}")
# Sort by timestamp (newest first)
checkpoints.sort(key=lambda x: x[1].get('timestamp', ''), reverse=True)
return checkpoints
except Exception as e:
logger.error(f"Error getting all checkpoints: {e}")
return []
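
For orientation, here is a minimal usage sketch of the reworked file-based checkpoint API shown above. It is not part of the commit: the import path utils.checkpoint_manager, the model name, and the source model file path are assumptions chosen for illustration; everything else follows the signatures in the diff.

# Minimal usage sketch (assumptions noted above; not part of this commit)
from utils.checkpoint_manager import get_checkpoint_manager  # assumed import path

# Singleton manager keyed on a checkpoint directory and a ranking metric
manager = get_checkpoint_manager(
    checkpoint_dir="models/checkpoints",
    max_checkpoints=10,
    metric_name="accuracy",
)

# The instance API copies an already-saved model file into the checkpoint store
# and writes a sibling *_metadata.json describing it.
ckpt_path = manager.save_checkpoint(
    model_name="enhanced_cnn",
    model_path="NN/models/saved/cnn_model_best.pt",  # placeholder source file
    metrics={"accuracy": 0.62, "loss": 0.41},
    metadata={"epoch": 12},
)

# The best checkpoint is selected by the configured metric (highest value wins);
# a miss is signalled by an empty path and an empty metadata dict.
best_path, best_meta = manager.load_best_checkpoint("enhanced_cnn")
if best_path:
    print(best_path, best_meta.get("metrics"))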

View File

@@ -9,7 +9,7 @@ from datetime import datetime
from typing import Dict, Any, Optional
from pathlib import Path
from .checkpoint_manager import get_checkpoint_manager, save_checkpoint, load_best_checkpoint
from .checkpoint_manager import get_checkpoint_manager, load_best_checkpoint
logger = logging.getLogger(__name__)
@@ -78,7 +78,7 @@ class TrainingIntegration:
except Exception as e:
logger.warning(f"Error logging to W&B: {e}")
metadata = save_checkpoint(
metadata = self.checkpoint_manager.save_checkpoint(
model=cnn_model,
model_name=model_name,
model_type='cnn',
@@ -137,7 +137,7 @@ class TrainingIntegration:
except Exception as e:
logger.warning(f"Error logging to W&B: {e}")
metadata = save_checkpoint(
metadata = self.checkpoint_manager.save_checkpoint(
model=rl_agent,
model_name=model_name,
model_type='rl',
@@ -158,7 +158,7 @@ class TrainingIntegration:
def load_best_model(self, model_name: str, model_class=None):
try:
result = load_best_checkpoint(model_name)
result = self.checkpoint_manager.load_best_checkpoint(model_name)
if not result:
logger.warning(f"No checkpoint found for model: {model_name}")
return None
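
Since this hunk switches TrainingIntegration from the module-level helpers to the manager instance's methods, a short sketch of how a caller can consume that instance API follows. It assumes get_checkpoint_manager() is importable from utils.checkpoint_manager; the helper name load_best_model_weights is illustrative and not part of the commit. Note that the instance-level load_best_checkpoint returns ("", {}) rather than None when nothing is found, so the path component is what should be checked.

# Hedged sketch of the adjusted call pattern (not part of this commit)
import logging

import torch

from utils.checkpoint_manager import get_checkpoint_manager  # assumed import path

logger = logging.getLogger(__name__)

def load_best_model_weights(model, model_name: str):
    """Fetch the best checkpoint via the manager instance and load its weights into `model`."""
    manager = get_checkpoint_manager()
    checkpoint_path, metadata = manager.load_best_checkpoint(model_name)
    if not checkpoint_path:  # instance API returns ("", {}) when no checkpoint exists
        logger.warning(f"No checkpoint found for model: {model_name}")
        return None
    state = torch.load(f"{checkpoint_path}.pt", map_location="cpu")
    # The stored file may be a raw state_dict or a dict holding 'model_state_dict'
    # (both save paths appear in the diff above), so handle either shape.
    state_dict = state.get("model_state_dict", state) if isinstance(state, dict) else state
    model.load_state_dict(state_dict)
    return metadata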