#!/usr/bin/env python3
"""
Unified Model Registry for Centralized Model Management

This module provides a unified interface for saving, loading, and managing
all machine learning models in the trading system. It consolidates model
storage from multiple locations into a single, organized structure.
"""

import json
import logging
import os
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import torch

logger = logging.getLogger(__name__)
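
# On-disk layout created by this registry (a sketch, derived from the
# directory setup in ModelRegistry.__init__ below):
#
#   models/
#     registry_metadata.json
#     saved/  checkpoints/  archive/              (untyped fallbacks)
#     cnn/          saved/  checkpoints/  archive/
#     dqn/          saved/  checkpoints/  archive/
#     transformer/  saved/  checkpoints/  archive/
#     hybrid/       saved/  checkpoints/  archive/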


class ModelRegistry:
    """
    Unified model registry for centralized model management.
    Handles saving, loading, and organization of all ML models.
    """

    def __init__(self, base_dir: str = "models"):
        """
        Initialize the model registry.

        Args:
            base_dir: Base directory for model storage
        """
        self.base_dir = Path(base_dir)
        self.saved_dir = self.base_dir / "saved"
        self.checkpoint_dir = self.base_dir / "checkpoints"
        self.archive_dir = self.base_dir / "archive"

        # Model type directories
        self.model_dirs = {
            'cnn': self.base_dir / "cnn",
            'dqn': self.base_dir / "dqn",
            'transformer': self.base_dir / "transformer",
            'hybrid': self.base_dir / "hybrid"
        }

        # Ensure all directories exist
        self._ensure_directories()

        # Metadata tracking
        self.metadata_file = self.base_dir / "registry_metadata.json"
        self.metadata = self._load_metadata()

        logger.info(f"Model Registry initialized at {self.base_dir}")

    def _ensure_directories(self):
        """Ensure all required directories exist."""
        directories = [
            self.saved_dir,
            self.checkpoint_dir,
            self.archive_dir
        ]

        # Add model type directories
        for model_dir in self.model_dirs.values():
            directories.extend([
                model_dir / "saved",
                model_dir / "checkpoints",
                model_dir / "archive"
            ])

        for directory in directories:
            directory.mkdir(parents=True, exist_ok=True)

    def _load_metadata(self) -> Dict[str, Any]:
        """Load registry metadata."""
        if self.metadata_file.exists():
            try:
                with open(self.metadata_file, 'r') as f:
                    return json.load(f)
            except Exception as e:
                logger.warning(f"Failed to load metadata: {e}")
        return {'models': {}, 'last_updated': datetime.now().isoformat()}
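
    # Shape of registry_metadata.json, as produced by save_model() and
    # save_checkpoint() below (values are illustrative):
    #
    #   {
    #     "models": {
    #       "<model_name>": {
    #         "type": "cnn",
    #         "latest_path": "models/cnn/saved/<model_name>_latest.pt",
    #         "last_saved": "YYYYmmdd_HHMMSS",
    #         "save_count": 3,
    #         "checkpoints": [
    #           {"id": "...", "path": "...",
    #            "performance_score": 0.75, "timestamp": "..."}
    #         ]
    #       }
    #     },
    #     "last_updated": "<ISO-8601 timestamp>"
    #   }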

    def _save_metadata(self):
        """Save registry metadata."""
        self.metadata['last_updated'] = datetime.now().isoformat()
        try:
            with open(self.metadata_file, 'w') as f:
                json.dump(self.metadata, f, indent=2)
        except Exception as e:
            logger.error(f"Failed to save metadata: {e}")

    def save_model(self, model: Any, model_name: str, model_type: str = 'cnn',
                   metadata: Optional[Dict[str, Any]] = None) -> bool:
        """
        Save a model to the unified storage.

        Args:
            model: The model to save
            model_name: Name of the model
            model_type: Type of model (cnn, dqn, transformer, hybrid)
            metadata: Additional metadata to save

        Returns:
            bool: True if successful, False otherwise
        """
        try:
            model_dir = self.model_dirs.get(model_type, self.saved_dir)
            save_dir = model_dir / "saved"

            # Generate filename with timestamp
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            filename = f"{model_name}_{timestamp}.pt"
            filepath = save_dir / filename

            # Also save as latest
            latest_filepath = save_dir / f"{model_name}_latest.pt"

            # Save model (objects without a state_dict are stored with empty weights)
            save_dict = {
                'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else {},
                'model_class': model.__class__.__name__,
                'model_type': model_type,
                'timestamp': timestamp,
                'metadata': metadata or {}
            }

            torch.save(save_dict, filepath)
            torch.save(save_dict, latest_filepath)

            # Update metadata
            if model_name not in self.metadata['models']:
                self.metadata['models'][model_name] = {}

            self.metadata['models'][model_name].update({
                'type': model_type,
                'latest_path': str(latest_filepath),
                'last_saved': timestamp,
                'save_count': self.metadata['models'][model_name].get('save_count', 0) + 1
            })

            self._save_metadata()

            logger.info(f"Model {model_name} saved to {filepath}")
            return True

        except Exception as e:
            logger.error(f"Failed to save model {model_name}: {e}")
            return False
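
    # Example call (a sketch; "price_cnn" and its metadata are hypothetical):
    #
    #   registry.save_model(net, "price_cnn", model_type="cnn",
    #                       metadata={"epoch": 12, "val_loss": 0.031})
    #
    # This writes models/cnn/saved/price_cnn_<timestamp>.pt plus a
    # price_cnn_latest.pt copy, and records both in registry_metadata.json.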

    def load_model(self, model_name: str, model_type: str = 'cnn',
                   model_class: Optional[Any] = None) -> Optional[Any]:
        """
        Load a model from the unified storage.

        Args:
            model_name: Name of the model to load
            model_type: Type of model (cnn, dqn, transformer, hybrid)
            model_class: Model class to instantiate (if needed)

        Returns:
            The loaded model, the raw checkpoint dict if no model_class
            was given, or None on failure
        """
        try:
            model_dir = self.model_dirs.get(model_type, self.saved_dir)
            save_dir = model_dir / "saved"
            latest_filepath = save_dir / f"{model_name}_latest.pt"

            if not latest_filepath.exists():
                logger.warning(f"Model {model_name} not found at {latest_filepath}")
                return None

            # Load checkpoint
            checkpoint = torch.load(latest_filepath, map_location='cpu')

            if model_class is not None:
                # Instantiate the model and restore its weights
                model = model_class()
                model.load_state_dict(checkpoint['model_state_dict'])
            else:
                # Without a class we cannot rebuild the network; return the
                # raw checkpoint dict and let the caller reconstruct it.
                logger.warning(
                    f"No model_class provided for {model_name}; "
                    "returning raw checkpoint dict"
                )
                model = checkpoint

            logger.info(f"Model {model_name} loaded from {latest_filepath}")
            return model

        except Exception as e:
            logger.error(f"Failed to load model {model_name}: {e}")
            return None

    def save_checkpoint(self, model: Any, model_name: str, model_type: str = 'cnn',
                        performance_score: float = 0.0,
                        metadata: Optional[Dict[str, Any]] = None) -> bool:
        """
        Save a model checkpoint.

        Args:
            model: The model to checkpoint
            model_name: Name of the model
            model_type: Type of model
            performance_score: Performance score for this checkpoint
            metadata: Additional metadata

        Returns:
            bool: True if successful, False otherwise
        """
        try:
            model_dir = self.model_dirs.get(model_type, self.checkpoint_dir)
            checkpoint_dir = model_dir / "checkpoints"

            # Generate checkpoint ID
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            checkpoint_id = f"{model_name}_{timestamp}_{performance_score:.4f}"

            filepath = checkpoint_dir / f"{checkpoint_id}.pt"

            # Save checkpoint
            checkpoint_data = {
                'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else {},
                'model_class': model.__class__.__name__,
                'model_type': model_type,
                'model_name': model_name,
                'performance_score': performance_score,
                'timestamp': timestamp,
                'metadata': metadata or {}
            }

            torch.save(checkpoint_data, filepath)

            # Update metadata
            if model_name not in self.metadata['models']:
                self.metadata['models'][model_name] = {}

            if 'checkpoints' not in self.metadata['models'][model_name]:
                self.metadata['models'][model_name]['checkpoints'] = []

            checkpoint_info = {
                'id': checkpoint_id,
                'path': str(filepath),
                'performance_score': performance_score,
                'timestamp': timestamp
            }

            self.metadata['models'][model_name]['checkpoints'].append(checkpoint_info)

            # Keep only the 5 best-scoring checkpoints on disk
            checkpoints = self.metadata['models'][model_name]['checkpoints']
            if len(checkpoints) > 5:
                checkpoints.sort(key=lambda x: x['performance_score'], reverse=True)
                for checkpoint in checkpoints[5:]:
                    try:
                        os.remove(checkpoint['path'])
                    except OSError as e:
                        logger.warning(f"Could not delete checkpoint {checkpoint['path']}: {e}")

                self.metadata['models'][model_name]['checkpoints'] = checkpoints[:5]

            self._save_metadata()

            logger.info(f"Checkpoint {checkpoint_id} saved with score {performance_score}")
            return True

        except Exception as e:
            logger.error(f"Failed to save checkpoint for {model_name}: {e}")
            return False

    def load_best_checkpoint(self, model_name: str, model_type: str = 'cnn') -> Optional[Tuple[str, Any]]:
        """
        Load the best checkpoint for a model.

        Args:
            model_name: Name of the model
            model_type: Type of model

        Returns:
            Tuple of (checkpoint_path, checkpoint_data) or None
        """
        try:
            if model_name not in self.metadata['models']:
                logger.warning(f"No metadata found for model {model_name}")
                return None

            checkpoints = self.metadata['models'][model_name].get('checkpoints', [])
            if not checkpoints:
                logger.warning(f"No checkpoints found for model {model_name}")
                return None

            # Find best checkpoint by performance score
            best_checkpoint = max(checkpoints, key=lambda x: x['performance_score'])
            checkpoint_path = best_checkpoint['path']

            if not os.path.exists(checkpoint_path):
                logger.warning(f"Checkpoint file not found: {checkpoint_path}")
                return None

            checkpoint_data = torch.load(checkpoint_path, map_location='cpu')

            logger.info(f"Best checkpoint loaded for {model_name}: {best_checkpoint['id']}")
            return checkpoint_path, checkpoint_data

        except Exception as e:
            logger.error(f"Failed to load best checkpoint for {model_name}: {e}")
            return None
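
    # Typical use of the returned tuple (a sketch; MyNet is a hypothetical
    # model class, not defined in this module):
    #
    #   result = registry.load_best_checkpoint("price_cnn", model_type="cnn")
    #   if result is not None:
    #       path, data = result
    #       net = MyNet()
    #       net.load_state_dict(data["model_state_dict"])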

    def archive_model(self, model_name: str, model_type: str = 'cnn') -> bool:
        """
        Archive a model by moving it to the archive directory.

        Args:
            model_name: Name of the model to archive
            model_type: Type of model

        Returns:
            bool: True if successful, False otherwise
        """
        try:
            model_dir = self.model_dirs.get(model_type, self.saved_dir)
            save_dir = model_dir / "saved"
            archive_dir = model_dir / "archive"

            latest_filepath = save_dir / f"{model_name}_latest.pt"

            if not latest_filepath.exists():
                logger.warning(f"Model {model_name} not found to archive")
                return False

            # Move to archive with timestamp
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            archive_filepath = archive_dir / f"{model_name}_archived_{timestamp}.pt"

            os.rename(latest_filepath, archive_filepath)

            logger.info(f"Model {model_name} archived to {archive_filepath}")
            return True

        except Exception as e:
            logger.error(f"Failed to archive model {model_name}: {e}")
            return False

    def list_models(self, model_type: Optional[str] = None) -> Dict[str, Any]:
        """
        List all models in the registry.

        Args:
            model_type: Filter by model type (optional)

        Returns:
            Dictionary of model information
        """
        models_info = {}

        for model_name, model_data in self.metadata['models'].items():
            if model_type and model_data.get('type') != model_type:
                continue

            models_info[model_name] = {
                'type': model_data.get('type'),
                'last_saved': model_data.get('last_saved'),
                'save_count': model_data.get('save_count', 0),
                'checkpoint_count': len(model_data.get('checkpoints', [])),
                'latest_path': model_data.get('latest_path')
            }

        return models_info

    def cleanup_old_checkpoints(self, model_name: str, keep_count: int = 5) -> int:
        """
        Clean up old checkpoints, keeping only the best ones.

        Args:
            model_name: Name of the model
            keep_count: Number of checkpoints to keep

        Returns:
            Number of checkpoints removed
        """
        if model_name not in self.metadata['models']:
            return 0

        checkpoints = self.metadata['models'][model_name].get('checkpoints', [])
        if len(checkpoints) <= keep_count:
            return 0

        # Sort by performance score (descending)
        checkpoints.sort(key=lambda x: x['performance_score'], reverse=True)

        # Remove old checkpoints
        removed_count = 0
        for checkpoint in checkpoints[keep_count:]:
            try:
                os.remove(checkpoint['path'])
                removed_count += 1
            except OSError as e:
                logger.warning(f"Could not delete checkpoint {checkpoint['path']}: {e}")

        # Update metadata
        self.metadata['models'][model_name]['checkpoints'] = checkpoints[:keep_count]
        self._save_metadata()

        logger.info(f"Cleaned up {removed_count} old checkpoints for {model_name}")
        return removed_count


# Global registry instance
_registry_instance: Optional[ModelRegistry] = None


def get_model_registry() -> ModelRegistry:
    """Get the global model registry instance."""
    global _registry_instance
    if _registry_instance is None:
        _registry_instance = ModelRegistry()
    return _registry_instance


def save_model(model: Any, model_name: str, model_type: str = 'cnn',
               metadata: Optional[Dict[str, Any]] = None) -> bool:
    """Convenience function to save a model using the global registry."""
    return get_model_registry().save_model(model, model_name, model_type, metadata)


def load_model(model_name: str, model_type: str = 'cnn',
               model_class: Optional[Any] = None) -> Optional[Any]:
    """Convenience function to load a model using the global registry."""
    return get_model_registry().load_model(model_name, model_type, model_class)


def save_checkpoint(model: Any, model_name: str, model_type: str = 'cnn',
                    performance_score: float = 0.0,
                    metadata: Optional[Dict[str, Any]] = None) -> bool:
    """Convenience function to save a checkpoint using the global registry."""
    return get_model_registry().save_checkpoint(model, model_name, model_type,
                                                performance_score, metadata)


def load_best_checkpoint(model_name: str, model_type: str = 'cnn') -> Optional[Tuple[str, Any]]:
    """Convenience function to load the best checkpoint using the global registry."""
    return get_model_registry().load_best_checkpoint(model_name, model_type)
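

if __name__ == "__main__":
    # Minimal smoke-test sketch showing the save/checkpoint/load round trip.
    # Assumes PyTorch is installed; the tiny nn.Linear model below is
    # illustrative only and not part of the trading system.
    import torch.nn as nn

    logging.basicConfig(level=logging.INFO)

    demo_model = nn.Linear(4, 2)
    registry = get_model_registry()

    registry.save_model(demo_model, "demo", model_type="cnn",
                        metadata={"note": "smoke test"})
    registry.save_checkpoint(demo_model, "demo", model_type="cnn",
                             performance_score=0.75)

    # Rebuild the same architecture, then restore the saved weights
    restored = registry.load_model("demo", model_type="cnn",
                                   model_class=lambda: nn.Linear(4, 2))
    print(restored)
    print(registry.list_models())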