model checkpoint manager

This commit is contained in:
Dobromir Popov
2025-09-08 13:31:11 +03:00
parent 060fdd28b4
commit c9fba56622
6 changed files with 838 additions and 142 deletions


@@ -16,6 +16,9 @@ import random
WANDB_AVAILABLE = False
# Import model registry
from utils.model_registry import get_model_registry
logger = logging.getLogger(__name__)
@dataclass
@@ -68,39 +71,48 @@ class CheckpointManager:
logger.info(f"Checkpoint Manager initialized - Max checkpoints per model: {self.max_checkpoints}")
def save_checkpoint(self, model, model_name: str, model_type: str,
performance_metrics: Dict[str, float],
training_metadata: Optional[Dict[str, Any]] = None,
force_save: bool = False) -> Optional[CheckpointMetadata]:
"""Save a model checkpoint with improved error handling and validation"""
"""Save a model checkpoint with improved error handling and validation using unified registry"""
try:
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
checkpoint_id = f"{model_name}_{timestamp}"
model_dir = self.base_dir / model_name
model_dir.mkdir(exist_ok=True)
checkpoint_path = model_dir / f"{checkpoint_id}.pt"
from utils.model_registry import save_checkpoint as registry_save_checkpoint
performance_score = self._calculate_performance_score(performance_metrics)
if not force_save and not self._should_save_checkpoint(model_name, performance_score):
logger.debug(f"Skipping checkpoint save for {model_name} - performance not improved")
return None
success = self._save_model_file(model, checkpoint_path, model_type)
if not success:
return None
file_size_mb = checkpoint_path.stat().st_size / (1024 * 1024)
metadata = CheckpointMetadata(
checkpoint_id=checkpoint_id,
# Use unified registry for checkpointing
success = registry_save_checkpoint(
model=model,
model_name=model_name,
model_type=model_type,
file_path=str(checkpoint_path),
created_at=datetime.now(),
file_size_mb=file_size_mb,
performance_score=performance_score,
metadata={
'performance_metrics': performance_metrics,
'training_metadata': training_metadata,
'checkpoint_manager': True
}
)
if not success:
return None
# Get checkpoint info from registry
registry = get_model_registry()
checkpoint_info = registry.metadata['models'][model_name]['checkpoints'][-1]
# Create CheckpointMetadata object
metadata = CheckpointMetadata(
checkpoint_id=checkpoint_info['id'],
model_name=model_name,
model_type=model_type,
file_path=checkpoint_info['path'],
created_at=datetime.fromisoformat(checkpoint_info['timestamp']),
file_size_mb=0.0, # Will be calculated by registry
performance_score=performance_score,
accuracy=performance_metrics.get('accuracy'),
loss=performance_metrics.get('loss'),
@@ -112,9 +124,8 @@ class CheckpointManager:
training_time_hours=training_metadata.get('training_time_hours') if training_metadata else None,
total_parameters=training_metadata.get('total_parameters') if training_metadata else None
)
# W&B disabled
# Update local checkpoint tracking
self.checkpoints[model_name].append(metadata)
self._rotate_checkpoints(model_name)
self._save_metadata()
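A minimal usage sketch of the updated save path (the model object, names, and metric keys are illustrative rather than part of this commit; assumes the default CheckpointManager constructor):

manager = CheckpointManager()
metadata = manager.save_checkpoint(
    model=my_model,                                   # any object exposing state_dict()
    model_name="dqn_agent",
    model_type="dqn",
    performance_metrics={"loss": 0.042, "reward": 1.7},
    training_metadata={"epoch": 12},
)
if metadata:
    print(metadata.checkpoint_id, metadata.performance_score)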
@@ -128,14 +139,42 @@ class CheckpointManager:
def load_best_checkpoint(self, model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
try:
from utils.model_registry import load_best_checkpoint as registry_load_checkpoint
# First, try the unified registry
registry_result = registry_load_checkpoint(model_name, 'cnn') # Try CNN type first
if registry_result is None:
registry_result = registry_load_checkpoint(model_name, 'dqn') # Try DQN type
if registry_result:
checkpoint_path, checkpoint_data = registry_result
# Create CheckpointMetadata from registry data
metadata = CheckpointMetadata(
checkpoint_id=f"{model_name}_registry",
model_name=model_name,
model_type=checkpoint_data.get('model_type', 'unknown'),
file_path=checkpoint_path,
created_at=datetime.fromisoformat(checkpoint_data.get('timestamp', datetime.now().isoformat())),
file_size_mb=0.0, # Will be calculated by registry
performance_score=checkpoint_data.get('performance_score', 0.0),
accuracy=checkpoint_data.get('accuracy'),
loss=checkpoint_data.get('loss'),
reward=checkpoint_data.get('reward'),
pnl=checkpoint_data.get('pnl')
)
logger.debug(f"Loading checkpoint from unified registry for {model_name}")
return checkpoint_path, metadata
# Fallback: Try the standard checkpoint system
if model_name in self.checkpoints and self.checkpoints[model_name]:
# Filter out checkpoints with non-existent files
valid_checkpoints = [
cp for cp in self.checkpoints[model_name]
if Path(cp.file_path).exists()
]
if valid_checkpoints:
best_checkpoint = max(valid_checkpoints, key=lambda x: x.performance_score)
logger.debug(f"Loading best checkpoint for {model_name}: {best_checkpoint.checkpoint_id}")
@@ -146,22 +185,22 @@ class CheckpointManager:
logger.warning(f"Found {invalid_count} invalid checkpoint entries for {model_name}, cleaning up metadata")
self.checkpoints[model_name] = []
self._save_metadata()
# Fallback: Look for existing saved models in the legacy format
logger.debug(f"No valid checkpoints found for model: {model_name}, attempting to find legacy saved models")
legacy_model_path = self._find_legacy_model(model_name)
if legacy_model_path:
# Create checkpoint metadata for the legacy model using actual file data
legacy_metadata = self._create_legacy_metadata(model_name, legacy_model_path)
logger.debug(f"Found legacy model for {model_name}: {legacy_model_path}")
return str(legacy_model_path), legacy_metadata
# Only warn once per model to avoid spam
if model_name not in self._warned_models:
logger.info(f"No checkpoints found for {model_name}, starting fresh")
self._warned_models.add(model_name)
return None
except Exception as e:
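Lookup order in the updated method: unified registry (cnn, then dqn), then the locally tracked checkpoints, then legacy saved models. A hedged sketch of consuming the result (model name is illustrative; assumes the checkpoint was written by the unified registry, whose files carry a 'model_state_dict' entry):

result = manager.load_best_checkpoint("dqn_agent")
if result:
    checkpoint_path, meta = result
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    my_model.load_state_dict(checkpoint["model_state_dict"])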

utils/model_registry.py (new file, 446 additions)

@@ -0,0 +1,446 @@
#!/usr/bin/env python3
"""
Unified Model Registry for Centralized Model Management
This module provides a unified interface for saving, loading, and managing
all machine learning models in the trading system. It consolidates model
storage from multiple locations into a single, organized structure.
"""
import os
import json
import torch
import logging
import pickle
from pathlib import Path
from typing import Dict, Any, Optional, Tuple, List
from datetime import datetime
import hashlib
logger = logging.getLogger(__name__)
class ModelRegistry:
"""
Unified model registry for centralized model management.
Handles saving, loading, and organization of all ML models.
"""
def __init__(self, base_dir: str = "models"):
"""
Initialize the model registry.
Args:
base_dir: Base directory for model storage
"""
self.base_dir = Path(base_dir)
self.saved_dir = self.base_dir / "saved"
self.checkpoint_dir = self.base_dir / "checkpoints"
self.archive_dir = self.base_dir / "archive"
# Model type directories
self.model_dirs = {
'cnn': self.base_dir / "cnn",
'dqn': self.base_dir / "dqn",
'transformer': self.base_dir / "transformer",
'hybrid': self.base_dir / "hybrid"
}
# Ensure all directories exist
self._ensure_directories()
# Metadata tracking
self.metadata_file = self.base_dir / "registry_metadata.json"
self.metadata = self._load_metadata()
logger.info(f"Model Registry initialized at {self.base_dir}")
def _ensure_directories(self):
"""Ensure all required directories exist."""
directories = [
self.saved_dir,
self.checkpoint_dir,
self.archive_dir
]
# Add model type directories
for model_dir in self.model_dirs.values():
directories.extend([
model_dir / "saved",
model_dir / "checkpoints",
model_dir / "archive"
])
for directory in directories:
directory.mkdir(parents=True, exist_ok=True)
def _load_metadata(self) -> Dict[str, Any]:
"""Load registry metadata."""
if self.metadata_file.exists():
try:
with open(self.metadata_file, 'r') as f:
return json.load(f)
except Exception as e:
logger.warning(f"Failed to load metadata: {e}")
return {'models': {}, 'last_updated': datetime.now().isoformat()}
def _save_metadata(self):
"""Save registry metadata."""
self.metadata['last_updated'] = datetime.now().isoformat()
try:
with open(self.metadata_file, 'w') as f:
json.dump(self.metadata, f, indent=2)
except Exception as e:
logger.error(f"Failed to save metadata: {e}")
def save_model(self, model: Any, model_name: str, model_type: str = 'cnn',
metadata: Optional[Dict[str, Any]] = None) -> bool:
"""
Save a model to the unified storage.
Args:
model: The model to save
model_name: Name of the model
model_type: Type of model (cnn, dqn, transformer, hybrid)
metadata: Additional metadata to save
Returns:
bool: True if successful, False otherwise
"""
try:
model_dir = self.model_dirs.get(model_type, self.saved_dir)
save_dir = model_dir / "saved"
# Generate filename with timestamp
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
filename = f"{model_name}_{timestamp}.pt"
filepath = save_dir / filename
# Also save as latest
latest_filepath = save_dir / f"{model_name}_latest.pt"
# Save model
save_dict = {
'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else {},
'model_class': model.__class__.__name__,
'model_type': model_type,
'timestamp': timestamp,
'metadata': metadata or {}
}
torch.save(save_dict, filepath)
torch.save(save_dict, latest_filepath)
# Update metadata
if model_name not in self.metadata['models']:
self.metadata['models'][model_name] = {}
self.metadata['models'][model_name].update({
'type': model_type,
'latest_path': str(latest_filepath),
'last_saved': timestamp,
'save_count': self.metadata['models'][model_name].get('save_count', 0) + 1
})
self._save_metadata()
logger.info(f"Model {model_name} saved to {filepath}")
return True
except Exception as e:
logger.error(f"Failed to save model {model_name}: {e}")
return False
def load_model(self, model_name: str, model_type: str = 'cnn',
model_class: Optional[Any] = None) -> Optional[Any]:
"""
Load a model from the unified storage.
Args:
model_name: Name of the model to load
model_type: Type of model (cnn, dqn, transformer, hybrid)
model_class: Model class to instantiate (if needed)
Returns:
The loaded model or None if failed
"""
try:
model_dir = self.model_dirs.get(model_type, self.saved_dir)
save_dir = model_dir / "saved"
latest_filepath = save_dir / f"{model_name}_latest.pt"
if not latest_filepath.exists():
logger.warning(f"Model {model_name} not found at {latest_filepath}")
return None
# Load checkpoint
checkpoint = torch.load(latest_filepath, map_location='cpu')
# Instantiate model if class provided
if model_class is not None:
model = model_class()
model.load_state_dict(checkpoint['model_state_dict'])
else:
# Try to reconstruct model from state_dict
model = type('LoadedModel', (), {})()
model.state_dict = lambda: checkpoint['model_state_dict']
model.load_state_dict = lambda state_dict: None
logger.info(f"Model {model_name} loaded from {latest_filepath}")
return model
except Exception as e:
logger.error(f"Failed to load model {model_name}: {e}")
return None
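# Illustrative round trip for the two methods above (hypothetical names; a sketch,
# not part of this module):
#   registry = ModelRegistry(base_dir="models")
#   registry.save_model(my_cnn, "price_cnn", model_type="cnn", metadata={"epochs": 10})
#   restored = registry.load_model("price_cnn", model_type="cnn", model_class=PriceCNN)
# Passing model_class lets load_model instantiate the class and restore the saved
# state_dict; without it, only a stub wrapping the raw state_dict is returned.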
def save_checkpoint(self, model: Any, model_name: str, model_type: str = 'cnn',
performance_score: float = 0.0,
metadata: Optional[Dict[str, Any]] = None) -> bool:
"""
Save a model checkpoint.
Args:
model: The model to checkpoint
model_name: Name of the model
model_type: Type of model
performance_score: Performance score for this checkpoint
metadata: Additional metadata
Returns:
bool: True if successful, False otherwise
"""
try:
model_dir = self.model_dirs.get(model_type, self.checkpoint_dir)
checkpoint_dir = model_dir / "checkpoints"
# Generate checkpoint ID
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
checkpoint_id = f"{model_name}_{timestamp}_{performance_score:.4f}"
filepath = checkpoint_dir / f"{checkpoint_id}.pt"
# Save checkpoint
checkpoint_data = {
'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else {},
'model_class': model.__class__.__name__,
'model_type': model_type,
'model_name': model_name,
'performance_score': performance_score,
'timestamp': timestamp,
'metadata': metadata or {}
}
torch.save(checkpoint_data, filepath)
# Update metadata
if model_name not in self.metadata['models']:
self.metadata['models'][model_name] = {}
if 'checkpoints' not in self.metadata['models'][model_name]:
self.metadata['models'][model_name]['checkpoints'] = []
checkpoint_info = {
'id': checkpoint_id,
'path': str(filepath),
'performance_score': performance_score,
'timestamp': timestamp
}
self.metadata['models'][model_name]['checkpoints'].append(checkpoint_info)
# Keep only top 5 checkpoints
checkpoints = self.metadata['models'][model_name]['checkpoints']
if len(checkpoints) > 5:
checkpoints.sort(key=lambda x: x['performance_score'], reverse=True)
checkpoints_to_remove = checkpoints[5:]
for checkpoint in checkpoints_to_remove:
try:
os.remove(checkpoint['path'])
except OSError:
pass  # checkpoint file may already have been removed
self.metadata['models'][model_name]['checkpoints'] = checkpoints[:5]
self._save_metadata()
logger.info(f"Checkpoint {checkpoint_id} saved with score {performance_score}")
return True
except Exception as e:
logger.error(f"Failed to save checkpoint for {model_name}: {e}")
return False
def load_best_checkpoint(self, model_name: str, model_type: str = 'cnn') -> Optional[Tuple[str, Any]]:
"""
Load the best checkpoint for a model.
Args:
model_name: Name of the model
model_type: Type of model
Returns:
Tuple of (checkpoint_path, checkpoint_data) or None
"""
try:
if model_name not in self.metadata['models']:
logger.warning(f"No metadata found for model {model_name}")
return None
checkpoints = self.metadata['models'][model_name].get('checkpoints', [])
if not checkpoints:
logger.warning(f"No checkpoints found for model {model_name}")
return None
# Find best checkpoint by performance score
best_checkpoint = max(checkpoints, key=lambda x: x['performance_score'])
checkpoint_path = best_checkpoint['path']
if not os.path.exists(checkpoint_path):
logger.warning(f"Checkpoint file not found: {checkpoint_path}")
return None
checkpoint_data = torch.load(checkpoint_path, map_location='cpu')
logger.info(f"Best checkpoint loaded for {model_name}: {best_checkpoint['id']}")
return checkpoint_path, checkpoint_data
except Exception as e:
logger.error(f"Failed to load best checkpoint for {model_name}: {e}")
return None
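# Illustrative checkpoint cycle for the two methods above (hypothetical names; a sketch):
#   registry.save_checkpoint(my_cnn, "price_cnn", model_type="cnn",
#                            performance_score=0.87, metadata={"epoch": 12})
#   best = registry.load_best_checkpoint("price_cnn", model_type="cnn")
#   if best:
#       path, data = best
#       my_cnn.load_state_dict(data["model_state_dict"])
# Note: only the five highest-scoring checkpoints per model are kept on disk.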
def archive_model(self, model_name: str, model_type: str = 'cnn') -> bool:
"""
Archive a model by moving it to archive directory.
Args:
model_name: Name of the model to archive
model_type: Type of model
Returns:
bool: True if successful, False otherwise
"""
try:
model_dir = self.model_dirs.get(model_type, self.saved_dir)
save_dir = model_dir / "saved"
archive_dir = model_dir / "archive"
latest_filepath = save_dir / f"{model_name}_latest.pt"
if not latest_filepath.exists():
logger.warning(f"Model {model_name} not found to archive")
return False
# Move to archive with timestamp
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
archive_filepath = archive_dir / f"{model_name}_archived_{timestamp}.pt"
os.rename(latest_filepath, archive_filepath)
logger.info(f"Model {model_name} archived to {archive_filepath}")
return True
except Exception as e:
logger.error(f"Failed to archive model {model_name}: {e}")
return False
def list_models(self, model_type: Optional[str] = None) -> Dict[str, Any]:
"""
List all models in the registry.
Args:
model_type: Filter by model type (optional)
Returns:
Dictionary of model information
"""
models_info = {}
for model_name, model_data in self.metadata['models'].items():
if model_type and model_data.get('type') != model_type:
continue
models_info[model_name] = {
'type': model_data.get('type'),
'last_saved': model_data.get('last_saved'),
'save_count': model_data.get('save_count', 0),
'checkpoint_count': len(model_data.get('checkpoints', [])),
'latest_path': model_data.get('latest_path')
}
return models_info
def cleanup_old_checkpoints(self, model_name: str, keep_count: int = 5) -> int:
"""
Clean up old checkpoints, keeping only the best ones.
Args:
model_name: Name of the model
keep_count: Number of checkpoints to keep
Returns:
Number of checkpoints removed
"""
if model_name not in self.metadata['models']:
return 0
checkpoints = self.metadata['models'][model_name].get('checkpoints', [])
if len(checkpoints) <= keep_count:
return 0
# Sort by performance score (descending)
checkpoints.sort(key=lambda x: x['performance_score'], reverse=True)
# Remove old checkpoints
removed_count = 0
for checkpoint in checkpoints[keep_count:]:
try:
os.remove(checkpoint['path'])
removed_count += 1
except OSError:
pass  # checkpoint file may already be gone
# Update metadata
self.metadata['models'][model_name]['checkpoints'] = checkpoints[:keep_count]
self._save_metadata()
logger.info(f"Cleaned up {removed_count} old checkpoints for {model_name}")
return removed_count
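# For reference, registry_metadata.json as maintained by the methods above has the shape:
# {
#   "models": {
#     "<model_name>": {
#       "type": "cnn",
#       "latest_path": "models/cnn/saved/<model_name>_latest.pt",
#       "last_saved": "YYYYMMDD_HHMMSS",
#       "save_count": 3,
#       "checkpoints": [
#         {"id": "<model_name>_<timestamp>_<score>", "path": "...",
#          "performance_score": 0.87, "timestamp": "YYYYMMDD_HHMMSS"}
#       ]
#     }
#   },
#   "last_updated": "<ISO-8601 timestamp>"
# }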
# Global registry instance
_registry_instance = None
def get_model_registry() -> ModelRegistry:
"""Get the global model registry instance."""
global _registry_instance
if _registry_instance is None:
_registry_instance = ModelRegistry()
return _registry_instance
def save_model(model: Any, model_name: str, model_type: str = 'cnn',
metadata: Optional[Dict[str, Any]] = None) -> bool:
"""
Convenience function to save a model using the global registry.
"""
return get_model_registry().save_model(model, model_name, model_type, metadata)
def load_model(model_name: str, model_type: str = 'cnn',
model_class: Optional[Any] = None) -> Optional[Any]:
"""
Convenience function to load a model using the global registry.
"""
return get_model_registry().load_model(model_name, model_type, model_class)
def save_checkpoint(model: Any, model_name: str, model_type: str = 'cnn',
performance_score: float = 0.0,
metadata: Optional[Dict[str, Any]] = None) -> bool:
"""
Convenience function to save a checkpoint using the global registry.
"""
return get_model_registry().save_checkpoint(model, model_name, model_type, performance_score, metadata)
def load_best_checkpoint(model_name: str, model_type: str = 'cnn') -> Optional[Tuple[str, Any]]:
"""
Convenience function to load the best checkpoint using the global registry.
"""
return get_model_registry().load_best_checkpoint(model_name, model_type)
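
A minimal end-to-end sketch using the module-level convenience wrappers (the model object and names are hypothetical; assumes a torch.nn.Module so state_dict() is available):

from utils.model_registry import save_model, save_checkpoint, load_best_checkpoint

save_model(my_model, "transformer_v1", model_type="transformer", metadata={"dataset": "ohlcv"})
save_checkpoint(my_model, "transformer_v1", model_type="transformer", performance_score=0.91)

best = load_best_checkpoint("transformer_v1", model_type="transformer")
if best:
    checkpoint_path, checkpoint_data = best
    my_model.load_state_dict(checkpoint_data["model_state_dict"])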