refactoring

Dobromir Popov
2025-09-08 23:57:21 +03:00
parent 98ebbe5089
commit c3a94600c8
50 changed files with 856 additions and 1302 deletions

@@ -1,560 +0,0 @@
#!/usr/bin/env python3
"""
Checkpoint Management System for W&B Training
"""
import os
import json
import logging
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, asdict
from collections import defaultdict
import torch
import random
WANDB_AVAILABLE = False
# Import model registry
from utils.model_registry import get_model_registry
logger = logging.getLogger(__name__)
@dataclass
class CheckpointMetadata:
checkpoint_id: str
model_name: str
model_type: str
file_path: str
created_at: datetime
file_size_mb: float
performance_score: float
accuracy: Optional[float] = None
loss: Optional[float] = None
val_accuracy: Optional[float] = None
val_loss: Optional[float] = None
reward: Optional[float] = None
pnl: Optional[float] = None
epoch: Optional[int] = None
training_time_hours: Optional[float] = None
total_parameters: Optional[int] = None
wandb_run_id: Optional[str] = None
wandb_artifact_name: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
data = asdict(self)
data['created_at'] = self.created_at.isoformat()
return data
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'CheckpointMetadata':
data['created_at'] = datetime.fromisoformat(data['created_at'])
return cls(**data)
class CheckpointManager:
def __init__(self,
base_checkpoint_dir: str = "NN/models/saved",
max_checkpoints_per_model: int = 5,
metadata_file: str = "checkpoint_metadata.json",
enable_wandb: bool = False):
self.base_dir = Path(base_checkpoint_dir)
self.base_dir.mkdir(parents=True, exist_ok=True)
self.max_checkpoints = max_checkpoints_per_model
self.metadata_file = self.base_dir / metadata_file
self.enable_wandb = False  # W&B disabled (WANDB_AVAILABLE is False); the enable_wandb argument is intentionally ignored
self.checkpoints: Dict[str, List[CheckpointMetadata]] = defaultdict(list)
self._warned_models = set() # Track models we've warned about to reduce spam
self._load_metadata()
logger.info(f"Checkpoint Manager initialized - Max checkpoints per model: {self.max_checkpoints}")
def save_checkpoint(self, model, model_name: str, model_type: str,
performance_metrics: Dict[str, float],
training_metadata: Optional[Dict[str, Any]] = None,
force_save: bool = False) -> Optional[CheckpointMetadata]:
"""Save a model checkpoint with improved error handling and validation using unified registry"""
try:
from utils.model_registry import save_checkpoint as registry_save_checkpoint
performance_score = self._calculate_performance_score(performance_metrics)
if not force_save and not self._should_save_checkpoint(model_name, performance_score):
logger.debug(f"Skipping checkpoint save for {model_name} - performance not improved")
return None
# Use unified registry for checkpointing
success = registry_save_checkpoint(
model=model,
model_name=model_name,
model_type=model_type,
performance_score=performance_score,
metadata={
'performance_metrics': performance_metrics,
'training_metadata': training_metadata,
'checkpoint_manager': True
}
)
if not success:
return None
# Get checkpoint info from registry
registry = get_model_registry()
checkpoint_info = registry.metadata['models'][model_name]['checkpoints'][-1]
# Create CheckpointMetadata object
metadata = CheckpointMetadata(
checkpoint_id=checkpoint_info['id'],
model_name=model_name,
model_type=model_type,
file_path=checkpoint_info['path'],
created_at=datetime.strptime(checkpoint_info['timestamp'], '%Y%m%d_%H%M%S'),  # registry timestamps use '%Y%m%d_%H%M%S', not ISO format
file_size_mb=0.0, # Will be calculated by registry
performance_score=performance_score,
accuracy=performance_metrics.get('accuracy'),
loss=performance_metrics.get('loss'),
val_accuracy=performance_metrics.get('val_accuracy'),
val_loss=performance_metrics.get('val_loss'),
reward=performance_metrics.get('reward'),
pnl=performance_metrics.get('pnl'),
epoch=training_metadata.get('epoch') if training_metadata else None,
training_time_hours=training_metadata.get('training_time_hours') if training_metadata else None,
total_parameters=training_metadata.get('total_parameters') if training_metadata else None
)
# Update local checkpoint tracking
self.checkpoints[model_name].append(metadata)
self._rotate_checkpoints(model_name)
self._save_metadata()
logger.debug(f"Saved checkpoint: {checkpoint_id} (score: {performance_score:.4f})")
return metadata
except Exception as e:
logger.error(f"Error saving checkpoint for {model_name}: {e}")
return None
def load_best_checkpoint(self, model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
try:
from utils.model_registry import load_best_checkpoint as registry_load_checkpoint
# First, try the unified registry
registry_result = registry_load_checkpoint(model_name, 'cnn') # Try CNN type first
if registry_result is None:
registry_result = registry_load_checkpoint(model_name, 'dqn') # Try DQN type
if registry_result:
checkpoint_path, checkpoint_data = registry_result
# Create CheckpointMetadata from registry data
metadata = CheckpointMetadata(
checkpoint_id=f"{model_name}_registry",
model_name=model_name,
model_type=checkpoint_data.get('model_type', 'unknown'),
file_path=checkpoint_path,
created_at=datetime.strptime(checkpoint_data['timestamp'], '%Y%m%d_%H%M%S') if 'timestamp' in checkpoint_data else datetime.now(),  # registry timestamps use '%Y%m%d_%H%M%S'
file_size_mb=0.0, # Will be calculated by registry
performance_score=checkpoint_data.get('performance_score', 0.0),
accuracy=checkpoint_data.get('accuracy'),
loss=checkpoint_data.get('loss'),
reward=checkpoint_data.get('reward'),
pnl=checkpoint_data.get('pnl')
)
logger.debug(f"Loading checkpoint from unified registry for {model_name}")
return checkpoint_path, metadata
# Fallback: Try the standard checkpoint system
if model_name in self.checkpoints and self.checkpoints[model_name]:
# Filter out checkpoints with non-existent files
valid_checkpoints = [
cp for cp in self.checkpoints[model_name]
if Path(cp.file_path).exists()
]
if valid_checkpoints:
best_checkpoint = max(valid_checkpoints, key=lambda x: x.performance_score)
logger.debug(f"Loading best checkpoint for {model_name}: {best_checkpoint.checkpoint_id}")
return best_checkpoint.file_path, best_checkpoint
else:
# Clean up invalid metadata entries
invalid_count = len(self.checkpoints[model_name])
logger.warning(f"Found {invalid_count} invalid checkpoint entries for {model_name}, cleaning up metadata")
self.checkpoints[model_name] = []
self._save_metadata()
# Fallback: Look for existing saved models in the legacy format
logger.debug(f"No valid checkpoints found for model: {model_name}, attempting to find legacy saved models")
legacy_model_path = self._find_legacy_model(model_name)
if legacy_model_path:
# Create checkpoint metadata for the legacy model using actual file data
legacy_metadata = self._create_legacy_metadata(model_name, legacy_model_path)
logger.debug(f"Found legacy model for {model_name}: {legacy_model_path}")
return str(legacy_model_path), legacy_metadata
# Only warn once per model to avoid spam
if model_name not in self._warned_models:
logger.info(f"No checkpoints found for {model_name}, starting fresh")
self._warned_models.add(model_name)
return None
except Exception as e:
logger.error(f"Error loading best checkpoint for {model_name}: {e}")
return None
def _calculate_performance_score(self, metrics: Dict[str, float]) -> float:
"""Calculate performance score with improved sensitivity for training models"""
score = 0.0
# Prioritize loss reduction for active training models
if 'loss' in metrics:
# Invert loss so lower loss = higher score, with better scaling
loss_value = metrics['loss']
if loss_value > 0:
score += max(0, 100 / (1 + loss_value)) # More sensitive to loss changes
else:
score += 100 # Perfect loss
# Add other metrics with appropriate weights
if 'accuracy' in metrics:
score += metrics['accuracy'] * 50 # Reduced weight to balance with loss
if 'val_accuracy' in metrics:
score += metrics['val_accuracy'] * 50
if 'val_loss' in metrics:
val_loss = metrics['val_loss']
if val_loss > 0:
score += max(0, 50 / (1 + val_loss))
if 'reward' in metrics:
score += metrics['reward'] * 10
if 'pnl' in metrics:
score += metrics['pnl'] * 5
if 'training_samples' in metrics:
# Bonus for processing more training samples
score += min(10, metrics['training_samples'] / 10)
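# Illustrative example: metrics = {'loss': 0.5, 'accuracy': 0.8, 'training_samples': 50}
# -> 100/(1+0.5) ≈ 66.7 (loss) + 0.8*50 = 40.0 (accuracy) + min(10, 50/10) = 5.0 (samples) ≈ 111.7 total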
# Return actual calculated score - NO SYNTHETIC MINIMUM
return score
def _should_save_checkpoint(self, model_name: str, performance_score: float) -> bool:
"""Improved checkpoint saving logic with more frequent saves during training"""
if model_name not in self.checkpoints or not self.checkpoints[model_name]:
return True # Always save first checkpoint
# Allow more checkpoints during active training
if len(self.checkpoints[model_name]) < self.max_checkpoints:
return True
# Get current best and worst scores
scores = [cp.performance_score for cp in self.checkpoints[model_name]]
best_score = max(scores)
worst_score = min(scores)
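# Illustrative example: with scores [110, 105, 100, 95, 90], a new score of 96 is saved (beats worst=90);
# 109.95 also saves via the 0.1% rule below, since 110 * 0.999 = 109.89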
# Save if better than worst (more frequent saves)
if performance_score > worst_score:
return True
# For high-performing models (score > 100), be more sensitive to small improvements
if best_score > 100:
# Save if within 0.1% of best score (very sensitive for converged models)
if performance_score >= best_score * 0.999:
return True
else:
# Also save if we're within 10% of best score (capture near-optimal models)
if performance_score >= best_score * 0.9:
return True
# Save more frequently during active training (every 5th attempt instead of 10th)
if random.random() < 0.2: # 20% chance to save anyway
logger.debug(f"Saving checkpoint for {model_name} - periodic save during active training")
return True
return False
def _save_model_file(self, model, file_path: Path, model_type: str) -> bool:
try:
if hasattr(model, 'state_dict'):
torch.save({
'model_state_dict': model.state_dict(),
'model_type': model_type,
'saved_at': datetime.now().isoformat()
}, file_path)
else:
torch.save(model, file_path)
return True
except Exception as e:
logger.error(f"Error saving model file {file_path}: {e}")
return False
def _rotate_checkpoints(self, model_name: str):
checkpoint_list = self.checkpoints[model_name]
if len(checkpoint_list) <= self.max_checkpoints:
return
checkpoint_list.sort(key=lambda x: x.performance_score, reverse=True)
to_remove = checkpoint_list[self.max_checkpoints:]
self.checkpoints[model_name] = checkpoint_list[:self.max_checkpoints]
for checkpoint in to_remove:
try:
file_path = Path(checkpoint.file_path)
if file_path.exists():
file_path.unlink()
logger.debug(f"Rotated out checkpoint: {checkpoint.checkpoint_id}")
except Exception as e:
logger.error(f"Error removing rotated checkpoint {checkpoint.checkpoint_id}: {e}")
def _upload_to_wandb(self, file_path: Path, metadata: CheckpointMetadata) -> Optional[str]:
return None
def _load_metadata(self):
try:
if self.metadata_file.exists():
with open(self.metadata_file, 'r') as f:
data = json.load(f)
for model_name, checkpoint_list in data.items():
self.checkpoints[model_name] = [
CheckpointMetadata.from_dict(cp_data)
for cp_data in checkpoint_list
]
logger.info(f"Loaded metadata for {len(self.checkpoints)} models")
except Exception as e:
logger.error(f"Error loading checkpoint metadata: {e}")
def _save_metadata(self):
try:
data = {}
for model_name, checkpoint_list in self.checkpoints.items():
data[model_name] = [cp.to_dict() for cp in checkpoint_list]
with open(self.metadata_file, 'w') as f:
json.dump(data, f, indent=2)
except Exception as e:
logger.error(f"Error saving checkpoint metadata: {e}")
def get_checkpoint_stats(self):
"""Get statistics about managed checkpoints"""
stats = {
'total_models': len(self.checkpoints),
'total_checkpoints': sum(len(checkpoints) for checkpoints in self.checkpoints.values()),
'total_size_mb': 0.0,
'models': {}
}
for model_name, checkpoint_list in self.checkpoints.items():
if not checkpoint_list:
continue
model_size = sum(cp.file_size_mb for cp in checkpoint_list)
best_checkpoint = max(checkpoint_list, key=lambda x: x.performance_score)
stats['models'][model_name] = {
'checkpoint_count': len(checkpoint_list),
'total_size_mb': model_size,
'best_performance': best_checkpoint.performance_score,
'best_checkpoint_id': best_checkpoint.checkpoint_id,
'latest_checkpoint': max(checkpoint_list, key=lambda x: x.created_at).checkpoint_id
}
stats['total_size_mb'] += model_size
return stats
def _find_legacy_model(self, model_name: str) -> Optional[Path]:
"""Find legacy saved models based on model name patterns"""
base_dir = Path(self.base_dir)
# Additional search locations
search_dirs = [
base_dir,
Path("models/saved"),
Path("NN/models/saved"),
Path("models"),
Path("models/archive"),
Path("models/backtest")
]
# Define model name mappings and patterns for legacy files
legacy_patterns = {
'dqn_agent': [
'dqn_agent_session_policy.pt',
'dqn_agent_session_agent_state.pt',
'dqn_agent_best_policy.pt',
'enhanced_dqn_best_policy.pt',
'improved_dqn_agent_best_policy.pt',
'dqn_agent_final_policy.pt',
'trading_agent_best_pnl.pt'
],
'enhanced_cnn': [
'cnn_model_session.pt',
'cnn_model_best.pt',
'optimized_short_term_model_best.pt',
'optimized_short_term_model_realtime_best.pt',
'optimized_short_term_model_ticks_best.pt'
],
'extrema_trainer': [
'supervised_model_best.pt'
],
'cob_rl': [
'best_rl_model.pth_policy.pt',
'rl_agent_best_policy.pt'
],
'decision': [
# Decision models might be in subdirectories, but let's check main dir too
'decision_best.pt',
'decision_model_best.pt',
# Check for transformer models which might be used as decision models
'enhanced_dqn_best_policy.pt',
'improved_dqn_agent_best_policy.pt'
]
}
# Get patterns for this model name
patterns = legacy_patterns.get(model_name, [])
# Also try generic patterns based on model name
patterns.extend([
f'{model_name}_best.pt',
f'{model_name}_best_policy.pt',
f'{model_name}_final.pt',
f'{model_name}_final_policy.pt'
])
# Search for the model files in all search directories
for search_dir in search_dirs:
if not search_dir.exists():
continue
for pattern in patterns:
candidate_path = search_dir / pattern
if candidate_path.exists():
logger.info(f"Found legacy model file: {candidate_path}")
return candidate_path
# Also check subdirectories
for subdir in base_dir.iterdir():
if subdir.is_dir() and subdir.name == model_name:
for pattern in patterns:
candidate_path = subdir / pattern
if candidate_path.exists():
logger.debug(f"Found legacy model file in subdirectory: {candidate_path}")
return candidate_path
# Extended search: scan common project model directories for best checkpoints
try:
# Attempt to infer project root from base_dir (NN/models/saved -> root)
project_root = base_dir.resolve().parent.parent.parent
except Exception:
project_root = Path(".").resolve()
additional_dirs = [
project_root / "models",
project_root / "models" / "archive",
project_root / "models" / "backtest",
]
def _match_legacy_name(candidate: Path, model: str) -> bool:
name = candidate.name.lower()
model_keys = {
'dqn_agent': ['dqn', 'agent', 'policy'],
'enhanced_cnn': ['cnn', 'optimized_short_term'],
'extrema_trainer': ['supervised', 'extrema'],
'cob_rl': ['cob', 'rl', 'policy'],
'decision': ['decision', 'transformer']
}.get(model, [model])
return any(k in name for k in model_keys)
candidates: List[Path] = []
for adir in additional_dirs:
if not adir.exists():
continue
try:
for pt in adir.rglob('*.pt'):
# Prefer files that indicate "best" and match model hints
lname = pt.name.lower()
if 'best' in lname and _match_legacy_name(pt, model_name):
candidates.append(pt)
# Do not add generic fallbacks to avoid mismatched model types
except Exception:
# Ignore directory traversal issues
pass
if candidates:
# Pick the most recently modified candidate
try:
best = max(candidates, key=lambda p: p.stat().st_mtime)
logger.debug(f"Found legacy model file in project models dir: {best}")
return best
except Exception:
# If stat fails, just return the first one deterministically
candidates.sort()
logger.debug(f"Found legacy model file in project models dir: {candidates[0]}")
return candidates[0]
return None
def _create_legacy_metadata(self, model_name: str, file_path: Path) -> CheckpointMetadata:
"""Create metadata for legacy model files using only actual file information"""
try:
file_size_mb = file_path.stat().st_size / (1024 * 1024)
created_time = datetime.fromtimestamp(file_path.stat().st_mtime)
# NO SYNTHETIC DATA - use only actual file information
return CheckpointMetadata(
checkpoint_id=f"legacy_{model_name}_{int(created_time.timestamp())}",
model_name=model_name,
model_type=model_name,
file_path=str(file_path),
created_at=created_time,
file_size_mb=file_size_mb,
performance_score=0.0, # Unknown performance - use 0, not synthetic values
accuracy=None,
loss=None,
val_accuracy=None,
val_loss=None,
reward=None,
pnl=None,
epoch=None,
training_time_hours=None,
total_parameters=None,
wandb_run_id=None,
wandb_artifact_name=None
)
except Exception as e:
logger.error(f"Error creating legacy metadata for {model_name}: {e}")
# Return a basic metadata with minimal info - NO SYNTHETIC VALUES
return CheckpointMetadata(
checkpoint_id=f"legacy_{model_name}",
model_name=model_name,
model_type=model_name,
file_path=str(file_path),
created_at=datetime.now(),
file_size_mb=0.0,
performance_score=0.0 # Unknown - use 0, not synthetic
)
_checkpoint_manager = None
def get_checkpoint_manager() -> CheckpointManager:
global _checkpoint_manager
if _checkpoint_manager is None:
_checkpoint_manager = CheckpointManager()
return _checkpoint_manager
def save_checkpoint(model, model_name: str, model_type: str,
performance_metrics: Dict[str, float],
training_metadata: Optional[Dict[str, Any]] = None,
force_save: bool = False) -> Optional[CheckpointMetadata]:
return get_checkpoint_manager().save_checkpoint(
model, model_name, model_type, performance_metrics, training_metadata, force_save
)
def load_best_checkpoint(model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
return get_checkpoint_manager().load_best_checkpoint(model_name)
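For reference, a minimal usage sketch of the module-level helpers above. This is illustrative only: the import path utils.checkpoint_manager, the placeholder model, and the metric values are assumptions, not part of the commit.

import torch
import torch.nn as nn
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint  # assumed module path

model = nn.Linear(8, 2)  # placeholder; any object exposing state_dict() works

# Save a checkpoint when the score improves (or when force_save=True)
metadata = save_checkpoint(
    model,
    model_name="dqn_agent",
    model_type="dqn",
    performance_metrics={"loss": 0.42, "accuracy": 0.71},
    training_metadata={"epoch": 10},
)

# Restore the best-performing checkpoint, if one exists
result = load_best_checkpoint("dqn_agent")
if result is not None:
    checkpoint_path, checkpoint_metadata = result
    state = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state["model_state_dict"])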

@@ -1,446 +0,0 @@
#!/usr/bin/env python3
"""
Unified Model Registry for Centralized Model Management
This module provides a unified interface for saving, loading, and managing
all machine learning models in the trading system. It consolidates model
storage from multiple locations into a single, organized structure.
"""
import os
import json
import torch
import logging
import pickle
from pathlib import Path
from typing import Dict, Any, Optional, Tuple, List
from datetime import datetime
import hashlib
logger = logging.getLogger(__name__)
class ModelRegistry:
"""
Unified model registry for centralized model management.
Handles saving, loading, and organization of all ML models.
"""
def __init__(self, base_dir: str = "models"):
"""
Initialize the model registry.
Args:
base_dir: Base directory for model storage
"""
self.base_dir = Path(base_dir)
self.saved_dir = self.base_dir / "saved"
self.checkpoint_dir = self.base_dir / "checkpoints"
self.archive_dir = self.base_dir / "archive"
# Model type directories
self.model_dirs = {
'cnn': self.base_dir / "cnn",
'dqn': self.base_dir / "dqn",
'transformer': self.base_dir / "transformer",
'hybrid': self.base_dir / "hybrid"
}
# Ensure all directories exist
self._ensure_directories()
# Metadata tracking
self.metadata_file = self.base_dir / "registry_metadata.json"
self.metadata = self._load_metadata()
logger.info(f"Model Registry initialized at {self.base_dir}")
def _ensure_directories(self):
"""Ensure all required directories exist."""
directories = [
self.saved_dir,
self.checkpoint_dir,
self.archive_dir
]
# Add model type directories
for model_dir in self.model_dirs.values():
directories.extend([
model_dir / "saved",
model_dir / "checkpoints",
model_dir / "archive"
])
for directory in directories:
directory.mkdir(parents=True, exist_ok=True)
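# Resulting layout under base_dir (default "models/"):
#   saved/, checkpoints/, archive/
#   cnn/, dqn/, transformer/, hybrid/ - each with its own saved/, checkpoints/, archive/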
def _load_metadata(self) -> Dict[str, Any]:
"""Load registry metadata."""
if self.metadata_file.exists():
try:
with open(self.metadata_file, 'r') as f:
return json.load(f)
except Exception as e:
logger.warning(f"Failed to load metadata: {e}")
return {'models': {}, 'last_updated': datetime.now().isoformat()}
def _save_metadata(self):
"""Save registry metadata."""
self.metadata['last_updated'] = datetime.now().isoformat()
try:
with open(self.metadata_file, 'w') as f:
json.dump(self.metadata, f, indent=2)
except Exception as e:
logger.error(f"Failed to save metadata: {e}")
def save_model(self, model: Any, model_name: str, model_type: str = 'cnn',
metadata: Optional[Dict[str, Any]] = None) -> bool:
"""
Save a model to the unified storage.
Args:
model: The model to save
model_name: Name of the model
model_type: Type of model (cnn, dqn, transformer, hybrid)
metadata: Additional metadata to save
Returns:
bool: True if successful, False otherwise
"""
try:
model_dir = self.model_dirs.get(model_type, self.saved_dir)
save_dir = model_dir / "saved"
# Generate filename with timestamp
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
filename = f"{model_name}_{timestamp}.pt"
filepath = save_dir / filename
# Also save as latest
latest_filepath = save_dir / f"{model_name}_latest.pt"
# Save model
save_dict = {
'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else {},
'model_class': model.__class__.__name__,
'model_type': model_type,
'timestamp': timestamp,
'metadata': metadata or {}
}
torch.save(save_dict, filepath)
torch.save(save_dict, latest_filepath)
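# e.g. models/cnn/saved/enhanced_cnn_20250908_235721.pt plus enhanced_cnn_latest.pt (illustrative)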
# Update metadata
if model_name not in self.metadata['models']:
self.metadata['models'][model_name] = {}
self.metadata['models'][model_name].update({
'type': model_type,
'latest_path': str(latest_filepath),
'last_saved': timestamp,
'save_count': self.metadata['models'][model_name].get('save_count', 0) + 1
})
self._save_metadata()
logger.info(f"Model {model_name} saved to {filepath}")
return True
except Exception as e:
logger.error(f"Failed to save model {model_name}: {e}")
return False
def load_model(self, model_name: str, model_type: str = 'cnn',
model_class: Optional[Any] = None) -> Optional[Any]:
"""
Load a model from the unified storage.
Args:
model_name: Name of the model to load
model_type: Type of model (cnn, dqn, transformer, hybrid)
model_class: Model class to instantiate (if needed)
Returns:
The loaded model or None if failed
"""
try:
model_dir = self.model_dirs.get(model_type, self.saved_dir)
save_dir = model_dir / "saved"
latest_filepath = save_dir / f"{model_name}_latest.pt"
if not latest_filepath.exists():
logger.warning(f"Model {model_name} not found at {latest_filepath}")
return None
# Load checkpoint
checkpoint = torch.load(latest_filepath, map_location='cpu')
# Instantiate model if class provided
if model_class is not None:
model = model_class()
model.load_state_dict(checkpoint['model_state_dict'])
else:
# No model class provided: return a lightweight stub that only exposes the saved state_dict
model = type('LoadedModel', (), {})()
model.state_dict = lambda: checkpoint['model_state_dict']
model.load_state_dict = lambda state_dict: None
logger.info(f"Model {model_name} loaded from {latest_filepath}")
return model
except Exception as e:
logger.error(f"Failed to load model {model_name}: {e}")
return None
def save_checkpoint(self, model: Any, model_name: str, model_type: str = 'cnn',
performance_score: float = 0.0,
metadata: Optional[Dict[str, Any]] = None) -> bool:
"""
Save a model checkpoint.
Args:
model: The model to checkpoint
model_name: Name of the model
model_type: Type of model
performance_score: Performance score for this checkpoint
metadata: Additional metadata
Returns:
bool: True if successful, False otherwise
"""
try:
model_dir = self.model_dirs.get(model_type, self.checkpoint_dir)
checkpoint_dir = model_dir / "checkpoints"
# Generate checkpoint ID
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
checkpoint_id = f"{model_name}_{timestamp}_{performance_score:.4f}"
filepath = checkpoint_dir / f"{checkpoint_id}.pt"
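# e.g. models/dqn/checkpoints/dqn_agent_20250908_235721_87.5000.pt (illustrative)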
# Save checkpoint
checkpoint_data = {
'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else {},
'model_class': model.__class__.__name__,
'model_type': model_type,
'model_name': model_name,
'performance_score': performance_score,
'timestamp': timestamp,
'metadata': metadata or {}
}
torch.save(checkpoint_data, filepath)
# Update metadata
if model_name not in self.metadata['models']:
self.metadata['models'][model_name] = {}
if 'checkpoints' not in self.metadata['models'][model_name]:
self.metadata['models'][model_name]['checkpoints'] = []
checkpoint_info = {
'id': checkpoint_id,
'path': str(filepath),
'performance_score': performance_score,
'timestamp': timestamp
}
self.metadata['models'][model_name]['checkpoints'].append(checkpoint_info)
# Keep only top 5 checkpoints
checkpoints = self.metadata['models'][model_name]['checkpoints']
if len(checkpoints) > 5:
checkpoints.sort(key=lambda x: x['performance_score'], reverse=True)
checkpoints_to_remove = checkpoints[5:]
for checkpoint in checkpoints_to_remove:
try:
os.remove(checkpoint['path'])
except OSError:
pass  # file may already have been removed
self.metadata['models'][model_name]['checkpoints'] = checkpoints[:5]
self._save_metadata()
logger.info(f"Checkpoint {checkpoint_id} saved with score {performance_score}")
return True
except Exception as e:
logger.error(f"Failed to save checkpoint for {model_name}: {e}")
return False
def load_best_checkpoint(self, model_name: str, model_type: str = 'cnn') -> Optional[Tuple[str, Any]]:
"""
Load the best checkpoint for a model.
Args:
model_name: Name of the model
model_type: Type of model
Returns:
Tuple of (checkpoint_path, checkpoint_data) or None
"""
try:
if model_name not in self.metadata['models']:
logger.warning(f"No metadata found for model {model_name}")
return None
checkpoints = self.metadata['models'][model_name].get('checkpoints', [])
if not checkpoints:
logger.warning(f"No checkpoints found for model {model_name}")
return None
# Find best checkpoint by performance score
best_checkpoint = max(checkpoints, key=lambda x: x['performance_score'])
checkpoint_path = best_checkpoint['path']
if not os.path.exists(checkpoint_path):
logger.warning(f"Checkpoint file not found: {checkpoint_path}")
return None
checkpoint_data = torch.load(checkpoint_path, map_location='cpu')
logger.info(f"Best checkpoint loaded for {model_name}: {best_checkpoint['id']}")
return checkpoint_path, checkpoint_data
except Exception as e:
logger.error(f"Failed to load best checkpoint for {model_name}: {e}")
return None
def archive_model(self, model_name: str, model_type: str = 'cnn') -> bool:
"""
Archive a model by moving it to archive directory.
Args:
model_name: Name of the model to archive
model_type: Type of model
Returns:
bool: True if successful, False otherwise
"""
try:
model_dir = self.model_dirs.get(model_type, self.saved_dir)
save_dir = model_dir / "saved"
archive_dir = model_dir / "archive"
latest_filepath = save_dir / f"{model_name}_latest.pt"
if not latest_filepath.exists():
logger.warning(f"Model {model_name} not found to archive")
return False
# Move to archive with timestamp
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
archive_filepath = archive_dir / f"{model_name}_archived_{timestamp}.pt"
os.rename(latest_filepath, archive_filepath)
logger.info(f"Model {model_name} archived to {archive_filepath}")
return True
except Exception as e:
logger.error(f"Failed to archive model {model_name}: {e}")
return False
def list_models(self, model_type: Optional[str] = None) -> Dict[str, Any]:
"""
List all models in the registry.
Args:
model_type: Filter by model type (optional)
Returns:
Dictionary of model information
"""
models_info = {}
for model_name, model_data in self.metadata['models'].items():
if model_type and model_data.get('type') != model_type:
continue
models_info[model_name] = {
'type': model_data.get('type'),
'last_saved': model_data.get('last_saved'),
'save_count': model_data.get('save_count', 0),
'checkpoint_count': len(model_data.get('checkpoints', [])),
'latest_path': model_data.get('latest_path')
}
return models_info
def cleanup_old_checkpoints(self, model_name: str, keep_count: int = 5) -> int:
"""
Clean up old checkpoints, keeping only the best ones.
Args:
model_name: Name of the model
keep_count: Number of checkpoints to keep
Returns:
Number of checkpoints removed
"""
if model_name not in self.metadata['models']:
return 0
checkpoints = self.metadata['models'][model_name].get('checkpoints', [])
if len(checkpoints) <= keep_count:
return 0
# Sort by performance score (descending)
checkpoints.sort(key=lambda x: x['performance_score'], reverse=True)
# Remove old checkpoints
removed_count = 0
for checkpoint in checkpoints[keep_count:]:
try:
os.remove(checkpoint['path'])
removed_count += 1
except OSError:
pass  # file may already have been removed
# Update metadata
self.metadata['models'][model_name]['checkpoints'] = checkpoints[:keep_count]
self._save_metadata()
logger.info(f"Cleaned up {removed_count} old checkpoints for {model_name}")
return removed_count
# Global registry instance
_registry_instance = None
def get_model_registry() -> ModelRegistry:
"""Get the global model registry instance."""
global _registry_instance
if _registry_instance is None:
_registry_instance = ModelRegistry()
return _registry_instance
def save_model(model: Any, model_name: str, model_type: str = 'cnn',
metadata: Optional[Dict[str, Any]] = None) -> bool:
"""
Convenience function to save a model using the global registry.
"""
return get_model_registry().save_model(model, model_name, model_type, metadata)
def load_model(model_name: str, model_type: str = 'cnn',
model_class: Optional[Any] = None) -> Optional[Any]:
"""
Convenience function to load a model using the global registry.
"""
return get_model_registry().load_model(model_name, model_type, model_class)
def save_checkpoint(model: Any, model_name: str, model_type: str = 'cnn',
performance_score: float = 0.0,
metadata: Optional[Dict[str, Any]] = None) -> bool:
"""
Convenience function to save a checkpoint using the global registry.
"""
return get_model_registry().save_checkpoint(model, model_name, model_type, performance_score, metadata)
def load_best_checkpoint(model_name: str, model_type: str = 'cnn') -> Optional[Tuple[str, Any]]:
"""
Convenience function to load the best checkpoint using the global registry.
"""
return get_model_registry().load_best_checkpoint(model_name, model_type)
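For reference, a minimal usage sketch of the registry's convenience functions. The import path utils.model_registry matches the import used by the checkpoint manager above; the TinyCNN class and the score are illustrative assumptions.

import torch.nn as nn
from utils.model_registry import save_model, load_model, save_checkpoint, load_best_checkpoint

class TinyCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 4)

    def forward(self, x):
        return self.fc(x)

model = TinyCNN()

# Persist the current weights as <name>_<timestamp>.pt and <name>_latest.pt
save_model(model, model_name="enhanced_cnn", model_type="cnn", metadata={"note": "demo"})

# Reload the latest save; passing the class makes load_state_dict apply the weights
restored = load_model("enhanced_cnn", model_type="cnn", model_class=TinyCNN)

# Record a scored checkpoint; only the top 5 per model are kept
save_checkpoint(model, model_name="enhanced_cnn", model_type="cnn", performance_score=87.5)

# Retrieve the best checkpoint (path plus the saved dict)
best = load_best_checkpoint("enhanced_cnn", model_type="cnn")
if best is not None:
    path, checkpoint_data = best
    model.load_state_dict(checkpoint_data["model_state_dict"])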