#!/usr/bin/env python3
"""
Unified Model Registry for Centralized Model Management

This module provides a unified interface for saving, loading, and managing all
machine learning models in the trading system. It consolidates model storage
from multiple locations into a single, organized structure.
"""

import os
import json
import logging
from pathlib import Path
from typing import Dict, Any, Optional, Tuple
from datetime import datetime

import torch

logger = logging.getLogger(__name__)


class ModelRegistry:
    """
    Unified model registry for centralized model management.
    Handles saving, loading, and organization of all ML models.
    """

    def __init__(self, base_dir: str = "models"):
        """
        Initialize the model registry.

        Args:
            base_dir: Base directory for model storage
        """
        self.base_dir = Path(base_dir)
        self.saved_dir = self.base_dir / "saved"
        self.checkpoint_dir = self.base_dir / "checkpoints"
        self.archive_dir = self.base_dir / "archive"

        # Model type directories
        self.model_dirs = {
            'cnn': self.base_dir / "cnn",
            'dqn': self.base_dir / "dqn",
            'transformer': self.base_dir / "transformer",
            'hybrid': self.base_dir / "hybrid"
        }

        # Ensure all directories exist
        self._ensure_directories()

        # Metadata tracking
        self.metadata_file = self.base_dir / "registry_metadata.json"
        self.metadata = self._load_metadata()

        logger.info(f"Model Registry initialized at {self.base_dir}")

    def _ensure_directories(self):
        """Ensure all required directories exist."""
        directories = [self.saved_dir, self.checkpoint_dir, self.archive_dir]

        # Add per-model-type subdirectories
        for model_dir in self.model_dirs.values():
            directories.extend([
                model_dir / "saved",
                model_dir / "checkpoints",
                model_dir / "archive"
            ])

        for directory in directories:
            directory.mkdir(parents=True, exist_ok=True)

    def _load_metadata(self) -> Dict[str, Any]:
        """Load registry metadata, falling back to an empty registry on failure."""
        if self.metadata_file.exists():
            try:
                with open(self.metadata_file, 'r') as f:
                    return json.load(f)
            except Exception as e:
                logger.warning(f"Failed to load metadata: {e}")
        return {'models': {}, 'last_updated': datetime.now().isoformat()}

    def _save_metadata(self):
        """Save registry metadata."""
        self.metadata['last_updated'] = datetime.now().isoformat()
        try:
            with open(self.metadata_file, 'w') as f:
                json.dump(self.metadata, f, indent=2)
        except Exception as e:
            logger.error(f"Failed to save metadata: {e}")

    def save_model(self, model: Any, model_name: str, model_type: str = 'cnn',
                   metadata: Optional[Dict[str, Any]] = None) -> bool:
        """
        Save a model to the unified storage.

        Args:
            model: The model to save
            model_name: Name of the model
            model_type: Type of model (cnn, dqn, transformer, hybrid)
            metadata: Additional metadata to save

        Returns:
            bool: True if successful, False otherwise
        """
        try:
            model_dir = self.model_dirs.get(model_type, self.saved_dir)
            save_dir = model_dir / "saved"

            # Generate filename with timestamp
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            filename = f"{model_name}_{timestamp}.pt"
            filepath = save_dir / filename

            # Also save as latest
            latest_filepath = save_dir / f"{model_name}_latest.pt"

            # Save model
            save_dict = {
                'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else {},
                'model_class': model.__class__.__name__,
                'model_type': model_type,
                'timestamp': timestamp,
                'metadata': metadata or {}
            }

            torch.save(save_dict, filepath)
            torch.save(save_dict, latest_filepath)

            # Update metadata
            if model_name not in self.metadata['models']:
                self.metadata['models'][model_name] = {}

            self.metadata['models'][model_name].update({
                'type': model_type,
                'latest_path': str(latest_filepath),
                'last_saved': timestamp,
                'save_count': self.metadata['models'][model_name].get('save_count', 0) + 1
            })

            self._save_metadata()

            logger.info(f"Model {model_name} saved to {filepath}")
            return True

        except Exception as e:
            logger.error(f"Failed to save model {model_name}: {e}")
            return False

    def load_model(self, model_name: str, model_type: str = 'cnn',
                   model_class: Optional[Any] = None) -> Optional[Any]:
        """
        Load a model from the unified storage.

        Args:
            model_name: Name of the model to load
            model_type: Type of model (cnn, dqn, transformer, hybrid)
            model_class: Model class to instantiate (if needed)

        Returns:
            The loaded model or None if failed
        """
        try:
            model_dir = self.model_dirs.get(model_type, self.saved_dir)
            save_dir = model_dir / "saved"
            latest_filepath = save_dir / f"{model_name}_latest.pt"

            if not latest_filepath.exists():
                logger.warning(f"Model {model_name} not found at {latest_filepath}")
                return None

            # Load checkpoint
            checkpoint = torch.load(latest_filepath, map_location='cpu')

            # Instantiate model if class provided
            if model_class is not None:
                model = model_class()
                model.load_state_dict(checkpoint['model_state_dict'])
            else:
                # No class provided: return a lightweight placeholder object that
                # exposes the saved state_dict so callers can inspect the weights.
                model = type('LoadedModel', (), {})()
                model.state_dict = lambda: checkpoint['model_state_dict']
                model.load_state_dict = lambda state_dict: None

            logger.info(f"Model {model_name} loaded from {latest_filepath}")
            return model

        except Exception as e:
            logger.error(f"Failed to load model {model_name}: {e}")
            return None

    def save_checkpoint(self, model: Any, model_name: str, model_type: str = 'cnn',
                        performance_score: float = 0.0,
                        metadata: Optional[Dict[str, Any]] = None) -> bool:
        """
        Save a model checkpoint.

        Args:
            model: The model to checkpoint
            model_name: Name of the model
            model_type: Type of model
            performance_score: Performance score for this checkpoint
            metadata: Additional metadata

        Returns:
            bool: True if successful, False otherwise
        """
        try:
            model_dir = self.model_dirs.get(model_type, self.checkpoint_dir)
            checkpoint_dir = model_dir / "checkpoints"

            # Generate checkpoint ID
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            checkpoint_id = f"{model_name}_{timestamp}_{performance_score:.4f}"
            filepath = checkpoint_dir / f"{checkpoint_id}.pt"

            # Save checkpoint
            checkpoint_data = {
                'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else {},
                'model_class': model.__class__.__name__,
                'model_type': model_type,
                'model_name': model_name,
                'performance_score': performance_score,
                'timestamp': timestamp,
                'metadata': metadata or {}
            }

            torch.save(checkpoint_data, filepath)

            # Update metadata
            if model_name not in self.metadata['models']:
                self.metadata['models'][model_name] = {}
            if 'checkpoints' not in self.metadata['models'][model_name]:
                self.metadata['models'][model_name]['checkpoints'] = []

            checkpoint_info = {
                'id': checkpoint_id,
                'path': str(filepath),
                'performance_score': performance_score,
                'timestamp': timestamp
            }

            self.metadata['models'][model_name]['checkpoints'].append(checkpoint_info)

            # Keep only the top 5 checkpoints by performance score
            checkpoints = self.metadata['models'][model_name]['checkpoints']
            if len(checkpoints) > 5:
                checkpoints.sort(key=lambda x: x['performance_score'], reverse=True)
                for checkpoint in checkpoints[5:]:
                    try:
                        os.remove(checkpoint['path'])
                    except OSError:
                        pass
                self.metadata['models'][model_name]['checkpoints'] = checkpoints[:5]

            self._save_metadata()

            logger.info(f"Checkpoint {checkpoint_id} saved with score {performance_score}")
            return True

        except Exception as e:
            logger.error(f"Failed to save checkpoint for {model_name}: {e}")
            return False

    def load_best_checkpoint(self, model_name: str, model_type: str = 'cnn') -> Optional[Tuple[str, Any]]:
        """
        Load the best checkpoint for a model.

        Args:
            model_name: Name of the model
            model_type: Type of model

        Returns:
            Tuple of (checkpoint_path, checkpoint_data) or None
        """
        try:
            if model_name not in self.metadata['models']:
                logger.warning(f"No metadata found for model {model_name}")
                return None

            checkpoints = self.metadata['models'][model_name].get('checkpoints', [])
            if not checkpoints:
                logger.warning(f"No checkpoints found for model {model_name}")
                return None

            # Find best checkpoint by performance score
            best_checkpoint = max(checkpoints, key=lambda x: x['performance_score'])
            checkpoint_path = best_checkpoint['path']

            if not os.path.exists(checkpoint_path):
                logger.warning(f"Checkpoint file not found: {checkpoint_path}")
                return None

            checkpoint_data = torch.load(checkpoint_path, map_location='cpu')

            logger.info(f"Best checkpoint loaded for {model_name}: {best_checkpoint['id']}")
            return checkpoint_path, checkpoint_data

        except Exception as e:
            logger.error(f"Failed to load best checkpoint for {model_name}: {e}")
            return None

    def archive_model(self, model_name: str, model_type: str = 'cnn') -> bool:
        """
        Archive a model by moving it to the archive directory.

        Args:
            model_name: Name of the model to archive
            model_type: Type of model

        Returns:
            bool: True if successful, False otherwise
        """
        try:
            model_dir = self.model_dirs.get(model_type, self.saved_dir)
            save_dir = model_dir / "saved"
            archive_dir = model_dir / "archive"

            latest_filepath = save_dir / f"{model_name}_latest.pt"
            if not latest_filepath.exists():
                logger.warning(f"Model {model_name} not found to archive")
                return False

            # Move to archive with timestamp
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            archive_filepath = archive_dir / f"{model_name}_archived_{timestamp}.pt"

            os.rename(latest_filepath, archive_filepath)

            logger.info(f"Model {model_name} archived to {archive_filepath}")
            return True

        except Exception as e:
            logger.error(f"Failed to archive model {model_name}: {e}")
            return False

    def list_models(self, model_type: Optional[str] = None) -> Dict[str, Any]:
        """
        List all models in the registry.

        Args:
            model_type: Filter by model type (optional)

        Returns:
            Dictionary of model information
        """
        models_info = {}

        for model_name, model_data in self.metadata['models'].items():
            if model_type and model_data.get('type') != model_type:
                continue

            models_info[model_name] = {
                'type': model_data.get('type'),
                'last_saved': model_data.get('last_saved'),
                'save_count': model_data.get('save_count', 0),
                'checkpoint_count': len(model_data.get('checkpoints', [])),
                'latest_path': model_data.get('latest_path')
            }

        return models_info

    def cleanup_old_checkpoints(self, model_name: str, keep_count: int = 5) -> int:
        """
        Clean up old checkpoints, keeping only the best ones.

        Args:
            model_name: Name of the model
            keep_count: Number of checkpoints to keep

        Returns:
            Number of checkpoints removed
        """
        if model_name not in self.metadata['models']:
            return 0

        checkpoints = self.metadata['models'][model_name].get('checkpoints', [])
        if len(checkpoints) <= keep_count:
            return 0

        # Sort by performance score (descending) so the best checkpoints are kept
        checkpoints.sort(key=lambda x: x['performance_score'], reverse=True)

        # Remove checkpoints beyond the keep limit
        removed_count = 0
        for checkpoint in checkpoints[keep_count:]:
            try:
                os.remove(checkpoint['path'])
                removed_count += 1
            except OSError:
                pass

        # Update metadata
        self.metadata['models'][model_name]['checkpoints'] = checkpoints[:keep_count]
        self._save_metadata()

        logger.info(f"Cleaned up {removed_count} old checkpoints for {model_name}")
        return removed_count


# Global registry instance
_registry_instance = None


def get_model_registry() -> ModelRegistry:
    """Get the global model registry instance."""
    global _registry_instance
    if _registry_instance is None:
        _registry_instance = ModelRegistry()
    return _registry_instance


def save_model(model: Any, model_name: str, model_type: str = 'cnn',
               metadata: Optional[Dict[str, Any]] = None) -> bool:
    """Convenience function to save a model using the global registry."""
    return get_model_registry().save_model(model, model_name, model_type, metadata)


def load_model(model_name: str, model_type: str = 'cnn',
               model_class: Optional[Any] = None) -> Optional[Any]:
    """Convenience function to load a model using the global registry."""
    return get_model_registry().load_model(model_name, model_type, model_class)


def save_checkpoint(model: Any, model_name: str, model_type: str = 'cnn',
                    performance_score: float = 0.0,
                    metadata: Optional[Dict[str, Any]] = None) -> bool:
    """Convenience function to save a checkpoint using the global registry."""
""" return get_model_registry().save_checkpoint(model, model_name, model_type, performance_score, metadata) def load_best_checkpoint(model_name: str, model_type: str = 'cnn') -> Optional[Tuple[str, Any]]: """ Convenience function to load the best checkpoint using the global registry. """ return get_model_registry().load_best_checkpoint(model_name, model_type)