Files
gogo2/utils/model_registry.py
2025-09-08 13:31:11 +03:00

447 lines
16 KiB
Python

#!/usr/bin/env python3
"""
Unified Model Registry for Centralized Model Management
This module provides a unified interface for saving, loading, and managing
all machine learning models in the trading system. It consolidates model
storage from multiple locations into a single, organized structure.
"""
import os
import json
import torch
import logging
import pickle
from pathlib import Path
from typing import Dict, Any, Optional, Tuple, List
from datetime import datetime
import hashlib
logger = logging.getLogger(__name__)


class ModelRegistry:
    """
    Unified model registry for centralized model management.

    Handles saving, loading, checkpointing and archiving of ML models. Model
    files are written with ``torch.save`` under ``base_dir``, and a JSON index
    of what has been saved is kept in ``<base_dir>/registry_metadata.json``.

    Directory layout::

        <base_dir>/<type>/{saved,checkpoints,archive}   for known model types
        <base_dir>/{saved,checkpoints,archive}          for any other type
    """

    # Number of best-scoring checkpoints kept per model by save_checkpoint()
    # and, by default, by cleanup_old_checkpoints().
    MAX_CHECKPOINTS = 5

    def __init__(self, base_dir: str = "models"):
        """
        Initialize the model registry.

        Args:
            base_dir: Base directory for model storage. All subdirectories of
                the layout above are created eagerly.
        """
        self.base_dir = Path(base_dir)
        self.saved_dir = self.base_dir / "saved"
        self.checkpoint_dir = self.base_dir / "checkpoints"
        self.archive_dir = self.base_dir / "archive"

        # Model type directories
        self.model_dirs = {
            'cnn': self.base_dir / "cnn",
            'dqn': self.base_dir / "dqn",
            'transformer': self.base_dir / "transformer",
            'hybrid': self.base_dir / "hybrid",
        }

        # Ensure all directories exist
        self._ensure_directories()

        # Metadata tracking
        self.metadata_file = self.base_dir / "registry_metadata.json"
        self.metadata = self._load_metadata()

        logger.info(f"Model Registry initialized at {self.base_dir}")

    def _ensure_directories(self):
        """Create the shared and per-type saved/checkpoints/archive directories."""
        directories = [
            self.saved_dir,
            self.checkpoint_dir,
            self.archive_dir,
        ]

        # Add model type directories
        for model_dir in self.model_dirs.values():
            directories.extend([
                model_dir / "saved",
                model_dir / "checkpoints",
                model_dir / "archive",
            ])

        for directory in directories:
            directory.mkdir(parents=True, exist_ok=True)

    def _type_subdir(self, model_type: str, kind: str) -> Path:
        """
        Return the ``kind`` subdirectory ("saved", "checkpoints" or "archive")
        used for ``model_type``.

        Known types map to ``<base_dir>/<type>/<kind>``. Unknown types fall back
        to the shared ``<base_dir>/<kind>`` directory, which _ensure_directories
        creates. Nesting under ``saved_dir`` instead would yield paths such as
        ``<base_dir>/saved/saved`` that never exist.
        """
        model_dir = self.model_dirs.get(model_type)
        return (model_dir if model_dir is not None else self.base_dir) / kind

    @staticmethod
    def _score(entry: Dict[str, Any]) -> float:
        """Sort key for checkpoint entries; entries without a score rank last."""
        return entry.get('performance_score', float('-inf'))

    def _load_metadata(self) -> Dict[str, Any]:
        """
        Load registry metadata from ``self.metadata_file``.

        Returns:
            The stored metadata dict, or a fresh one if the file is missing,
            unreadable or malformed. The result always contains a ``'models'``
            dict, because the rest of the class indexes it directly.
        """
        metadata: Dict[str, Any] = {}
        if self.metadata_file.exists():
            try:
                with open(self.metadata_file, 'r', encoding='utf-8') as f:
                    loaded = json.load(f)
                if isinstance(loaded, dict):
                    metadata = loaded
                else:
                    logger.warning(f"Ignoring malformed metadata in {self.metadata_file}")
            except (OSError, ValueError) as e:
                logger.warning(f"Failed to load metadata: {e}")

        if not isinstance(metadata.get('models'), dict):
            metadata['models'] = {}
        metadata.setdefault('last_updated', datetime.now().isoformat())
        return metadata

    def _save_metadata(self):
        """
        Persist registry metadata, stamping ``last_updated``.

        The JSON is first written to a sibling ``.tmp`` file and then moved over
        the real file with ``os.replace``. A failure partway through the write
        (for example an unserialisable value) therefore cannot leave a
        truncated index behind. Errors are logged, not raised.
        """
        self.metadata['last_updated'] = datetime.now().isoformat()
        tmp_file = self.metadata_file.with_name(self.metadata_file.name + '.tmp')
        try:
            with open(tmp_file, 'w', encoding='utf-8') as f:
                json.dump(self.metadata, f, indent=2)
            os.replace(tmp_file, self.metadata_file)
        except Exception as e:
            logger.error(f"Failed to save metadata: {e}")
            try:
                tmp_file.unlink()
            except OSError:
                pass  # best-effort cleanup; the temp file may not exist

    def save_model(self, model: Any, model_name: str, model_type: str = 'cnn',
                   metadata: Optional[Dict[str, Any]] = None) -> bool:
        """
        Save a model to the unified storage.

        Writes both a timestamped file ``<name>_<YYYYmmdd_HHMMSS>.pt`` and
        ``<name>_latest.pt`` into the type's "saved" directory, then updates the
        registry metadata.

        Args:
            model: The model to save (its ``state_dict()`` is stored when present).
            model_name: Name of the model.
            model_type: Type of model (cnn, dqn, transformer, hybrid); other
                values use the shared ``<base_dir>/saved`` directory.
            metadata: Additional metadata stored inside the saved file.

        Returns:
            bool: True if successful, False otherwise (the error is logged).
        """
        try:
            save_dir = self._type_subdir(model_type, "saved")
            save_dir.mkdir(parents=True, exist_ok=True)

            # Generate filename with timestamp
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            filepath = save_dir / f"{model_name}_{timestamp}.pt"
            # Also save as latest
            latest_filepath = save_dir / f"{model_name}_latest.pt"

            save_dict = {
                'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else {},
                'model_class': model.__class__.__name__,
                'model_type': model_type,
                'timestamp': timestamp,
                'metadata': metadata or {},
            }
            torch.save(save_dict, filepath)
            torch.save(save_dict, latest_filepath)

            model_meta = self.metadata['models'].setdefault(model_name, {})
            model_meta.update({
                'type': model_type,
                'latest_path': str(latest_filepath),
                'last_saved': timestamp,
                'save_count': model_meta.get('save_count', 0) + 1,
            })
            self._save_metadata()

            logger.info(f"Model {model_name} saved to {filepath}")
            return True

        except Exception as e:
            logger.error(f"Failed to save model {model_name}: {e}")
            return False

    def load_model(self, model_name: str, model_type: str = 'cnn',
                   model_class: Optional[Any] = None) -> Optional[Any]:
        """
        Load a model from the unified storage.

        Reads ``<name>_latest.pt`` from the type's "saved" directory.

        Args:
            model_name: Name of the model to load.
            model_type: Type of model (cnn, dqn, transformer, hybrid).
            model_class: Zero-argument callable that builds the model; the saved
                state dict is loaded into the result.

        Returns:
            The loaded model, or None if the file is missing or loading fails.
            Without ``model_class``, a minimal stand-in object is returned: its
            ``state_dict()`` returns the saved weights and its
            ``load_state_dict()`` does nothing.
        """
        try:
            latest_filepath = self._type_subdir(model_type, "saved") / f"{model_name}_latest.pt"

            if not latest_filepath.exists():
                logger.warning(f"Model {model_name} not found at {latest_filepath}")
                return None

            checkpoint = torch.load(latest_filepath, map_location='cpu')

            if model_class is not None:
                model = model_class()
                model.load_state_dict(checkpoint['model_state_dict'])
            else:
                model = type('LoadedModel', (), {})()
                model.state_dict = lambda: checkpoint['model_state_dict']
                model.load_state_dict = lambda state_dict: None

            logger.info(f"Model {model_name} loaded from {latest_filepath}")
            return model

        except Exception as e:
            logger.error(f"Failed to load model {model_name}: {e}")
            return None

    def _prune_checkpoints(self, model_name: str, keep_count: int) -> int:
        """
        Keep only the ``keep_count`` best-scoring checkpoints of ``model_name``.

        The in-memory metadata list is replaced by the kept entries, but it is
        not persisted here. The files of dropped entries are deleted, except
        files that a kept entry still references.

        Returns:
            Number of checkpoint files actually deleted.
        """
        model_meta = self.metadata['models'][model_name]
        checkpoints = model_meta.get('checkpoints', [])
        checkpoints.sort(key=self._score, reverse=True)
        kept = checkpoints[:keep_count]
        kept_paths = {c.get('path') for c in kept}

        removed = 0
        for checkpoint in checkpoints[keep_count:]:
            path = checkpoint.get('path')
            if not path or path in kept_paths:
                continue
            try:
                os.remove(path)
                removed += 1
            except FileNotFoundError:
                logger.debug(f"Checkpoint file already gone: {path}")
            except OSError as e:
                logger.warning(f"Failed to remove checkpoint {path}: {e}")

        model_meta['checkpoints'] = kept
        return removed

    def save_checkpoint(self, model: Any, model_name: str, model_type: str = 'cnn',
                        performance_score: float = 0.0,
                        metadata: Optional[Dict[str, Any]] = None) -> bool:
        """
        Save a model checkpoint and keep only the best ``MAX_CHECKPOINTS``.

        The file is ``<name>_<timestamp>_<score:.4f>.pt`` in the type's
        "checkpoints" directory. After saving, lower-scoring checkpoints beyond
        ``MAX_CHECKPOINTS`` are deleted. This can include the one just written,
        in which case True is still returned.

        Args:
            model: The model to checkpoint.
            model_name: Name of the model.
            model_type: Type of model.
            performance_score: Score for this checkpoint (higher is better);
                coerced to ``float``.
            metadata: Additional metadata stored inside the checkpoint file.

        Returns:
            bool: True if successful, False otherwise (the error is logged).
        """
        try:
            # Plain float keeps the registry JSON-serialisable (e.g. numpy
            # float32 or 0-dim tensor scores).
            performance_score = float(performance_score)

            checkpoint_dir = self._type_subdir(model_type, "checkpoints")
            checkpoint_dir.mkdir(parents=True, exist_ok=True)

            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            checkpoint_id = f"{model_name}_{timestamp}_{performance_score:.4f}"
            filepath = checkpoint_dir / f"{checkpoint_id}.pt"

            checkpoint_data = {
                'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else {},
                'model_class': model.__class__.__name__,
                'model_type': model_type,
                'model_name': model_name,
                'performance_score': performance_score,
                'timestamp': timestamp,
                'metadata': metadata or {},
            }
            torch.save(checkpoint_data, filepath)

            model_meta = self.metadata['models'].setdefault(model_name, {})
            # Re-saving the same id overwrote the same file; drop the stale
            # entry so the path is not listed (and later pruned) twice.
            checkpoints = [
                c for c in model_meta.get('checkpoints', [])
                if c.get('path') != str(filepath)
            ]
            checkpoints.append({
                'id': checkpoint_id,
                'path': str(filepath),
                'performance_score': performance_score,
                'timestamp': timestamp,
            })
            model_meta['checkpoints'] = checkpoints

            if len(checkpoints) > self.MAX_CHECKPOINTS:
                self._prune_checkpoints(model_name, self.MAX_CHECKPOINTS)

            self._save_metadata()

            logger.info(f"Checkpoint {checkpoint_id} saved with score {performance_score}")
            return True

        except Exception as e:
            logger.error(f"Failed to save checkpoint for {model_name}: {e}")
            return False

    def load_best_checkpoint(self, model_name: str, model_type: str = 'cnn') -> Optional[Tuple[str, Any]]:
        """
        Load the best existing checkpoint for a model.

        Candidates are tried from highest to lowest score. Entries whose file
        no longer exists are skipped with a warning.

        Args:
            model_name: Name of the model.
            model_type: Accepted for API symmetry; lookup uses the registry
                metadata for ``model_name`` only.

        Returns:
            Tuple of (checkpoint_path, checkpoint_data), or None if nothing
            could be loaded.
        """
        try:
            model_meta = self.metadata['models'].get(model_name)
            if model_meta is None:
                logger.warning(f"No metadata found for model {model_name}")
                return None

            checkpoints = model_meta.get('checkpoints', [])
            if not checkpoints:
                logger.warning(f"No checkpoints found for model {model_name}")
                return None

            for candidate in sorted(checkpoints, key=self._score, reverse=True):
                checkpoint_path = candidate.get('path')
                if not checkpoint_path or not os.path.exists(checkpoint_path):
                    logger.warning(f"Checkpoint file not found: {checkpoint_path}")
                    continue

                checkpoint_data = torch.load(checkpoint_path, map_location='cpu')
                logger.info(f"Best checkpoint loaded for {model_name}: {candidate.get('id')}")
                return checkpoint_path, checkpoint_data

            return None

        except Exception as e:
            logger.error(f"Failed to load best checkpoint for {model_name}: {e}")
            return None

    def archive_model(self, model_name: str, model_type: str = 'cnn') -> bool:
        """
        Archive a model by moving its ``<name>_latest.pt`` to the archive dir.

        The file becomes ``<name>_archived_<timestamp>.pt``. If the registry's
        ``latest_path`` pointed at the moved file, that key is removed and
        ``archived_path`` is recorded instead.

        Args:
            model_name: Name of the model to archive.
            model_type: Type of model.

        Returns:
            bool: True if successful, False otherwise.
        """
        try:
            latest_filepath = self._type_subdir(model_type, "saved") / f"{model_name}_latest.pt"

            if not latest_filepath.exists():
                logger.warning(f"Model {model_name} not found to archive")
                return False

            archive_dir = self._type_subdir(model_type, "archive")
            archive_dir.mkdir(parents=True, exist_ok=True)

            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            archive_filepath = archive_dir / f"{model_name}_archived_{timestamp}.pt"

            # os.replace behaves the same on POSIX and Windows when the target exists.
            os.replace(latest_filepath, archive_filepath)

            model_meta = self.metadata['models'].get(model_name)
            if model_meta is not None and model_meta.get('latest_path') == str(latest_filepath):
                model_meta.pop('latest_path', None)
                model_meta['archived_path'] = str(archive_filepath)
                self._save_metadata()

            logger.info(f"Model {model_name} archived to {archive_filepath}")
            return True

        except Exception as e:
            logger.error(f"Failed to archive model {model_name}: {e}")
            return False

    def list_models(self, model_type: Optional[str] = None) -> Dict[str, Any]:
        """
        List all models in the registry.

        Args:
            model_type: Only include models whose recorded type equals this.

        Returns:
            Mapping of model name to a summary dict with keys ``type``,
            ``last_saved``, ``save_count``, ``checkpoint_count`` and ``latest_path``.
        """
        models_info = {}
        for model_name, model_data in self.metadata['models'].items():
            if model_type and model_data.get('type') != model_type:
                continue
            models_info[model_name] = {
                'type': model_data.get('type'),
                'last_saved': model_data.get('last_saved'),
                'save_count': model_data.get('save_count', 0),
                'checkpoint_count': len(model_data.get('checkpoints', [])),
                'latest_path': model_data.get('latest_path'),
            }
        return models_info

    def cleanup_old_checkpoints(self, model_name: str, keep_count: int = MAX_CHECKPOINTS) -> int:
        """
        Clean up old checkpoints, keeping only the best-scoring ones.

        Args:
            model_name: Name of the model.
            keep_count: Number of checkpoints to keep (must be >= 0).

        Returns:
            Number of checkpoint files removed (0 if nothing needed pruning).

        Raises:
            ValueError: If ``keep_count`` is negative.
        """
        if keep_count < 0:
            raise ValueError(f"keep_count must be non-negative, got {keep_count}")

        if model_name not in self.metadata['models']:
            return 0

        checkpoints = self.metadata['models'][model_name].get('checkpoints', [])
        if len(checkpoints) <= keep_count:
            return 0

        removed_count = self._prune_checkpoints(model_name, keep_count)
        self._save_metadata()

        logger.info(f"Cleaned up {removed_count} old checkpoints for {model_name}")
        return removed_count
# Shared registry used by the module-level convenience functions; created lazily.
_registry_instance = None


def get_model_registry() -> ModelRegistry:
    """Return the shared ModelRegistry, constructing it on first use.

    The first call builds a registry with the default ``base_dir``; every
    later call returns that same object.
    """
    global _registry_instance
    if _registry_instance is not None:
        return _registry_instance
    _registry_instance = ModelRegistry()
    return _registry_instance
def save_model(model: Any, model_name: str, model_type: str = 'cnn',
               metadata: Optional[Dict[str, Any]] = None) -> bool:
    """Save ``model`` through the shared registry.

    Delegates to ``ModelRegistry.save_model`` on the instance returned by
    ``get_model_registry()`` and returns its success flag.
    """
    registry = get_model_registry()
    return registry.save_model(model, model_name,
                               model_type=model_type, metadata=metadata)
def load_model(model_name: str, model_type: str = 'cnn',
               model_class: Optional[Any] = None) -> Optional[Any]:
    """Load a model through the shared registry.

    Delegates to ``ModelRegistry.load_model`` on the instance returned by
    ``get_model_registry()``; returns the model or None.
    """
    registry = get_model_registry()
    return registry.load_model(model_name,
                               model_type=model_type, model_class=model_class)
def save_checkpoint(model: Any, model_name: str, model_type: str = 'cnn',
                    performance_score: float = 0.0,
                    metadata: Optional[Dict[str, Any]] = None) -> bool:
    """Save a scored checkpoint through the shared registry.

    Delegates to ``ModelRegistry.save_checkpoint`` on the instance returned by
    ``get_model_registry()`` and returns its success flag.
    """
    registry = get_model_registry()
    return registry.save_checkpoint(
        model,
        model_name,
        model_type=model_type,
        performance_score=performance_score,
        metadata=metadata,
    )
def load_best_checkpoint(model_name: str, model_type: str = 'cnn') -> Optional[Tuple[str, Any]]:
    """Load the highest-scoring checkpoint through the shared registry.

    Delegates to ``ModelRegistry.load_best_checkpoint`` on the instance
    returned by ``get_model_registry()``; returns ``(path, data)`` or None.
    """
    registry = get_model_registry()
    return registry.load_best_checkpoint(model_name, model_type=model_type)