# Source: gogo2/NN/training/model_manager.py
# Last commit: 317c703ea0 "unify model names" (Dobromir Popov, 2025-09-09 01:10:35 +03:00)
# Original size: 783 lines, 33 KiB, Python

"""
Unified Model Management System for Trading Dashboard
CONSOLIDATED SYSTEM - All model management functionality in one place
This system provides:
- Automatic cleanup of old model checkpoints
- Best model tracking with performance metrics
- Configurable retention policies
- Startup model loading
- Performance-based model selection
- Robust model saving with multiple fallback strategies
- Checkpoint management with W&B integration
- Centralized storage using @checkpoints/ structure
"""
import os
import json
import shutil
import logging
import torch
import glob
import pickle
import hashlib
import random
import numpy as np
from pathlib import Path
from datetime import datetime
from dataclasses import dataclass, asdict
from typing import Dict, Any, Optional, List, Tuple, Union
from collections import defaultdict
# Weights & Biases is an optional dependency: when it is not installed the
# module still imports, `wandb` is None and WANDB_AVAILABLE is False.
try:
    import wandb
except ImportError:
    wandb = None
WANDB_AVAILABLE = wandb is not None

# Module-level logger shared by everything in this file.
logger = logging.getLogger(__name__)
@dataclass
class ModelMetrics:
    """Performance metrics for model evaluation.

    Every field has a default, so an instance can be built from whatever
    subset of metrics is available. ``get_composite_score`` folds the metrics
    into a single value in ``[0, 1]``.
    """

    accuracy: float = 0.0
    profit_factor: float = 0.0
    win_rate: float = 0.0
    sharpe_ratio: float = 0.0
    max_drawdown: float = 0.0
    total_trades: int = 0
    avg_trade_duration: float = 0.0
    confidence_score: float = 0.0
    # Additional metrics from checkpoint_manager
    loss: Optional[float] = None
    val_accuracy: Optional[float] = None
    val_loss: Optional[float] = None
    reward: Optional[float] = None
    pnl: Optional[float] = None
    epoch: Optional[int] = None
    training_time_hours: Optional[float] = None
    total_parameters: Optional[int] = None

    @staticmethod
    def _clamp_unit(value: float) -> float:
        """Clamp *value* into the closed interval [0, 1]."""
        return min(max(value, 0.0), 1.0)

    @staticmethod
    def _loss_factor(loss: Optional[float]) -> float:
        """Map a loss to a score factor in [0.1, 1.0]; lower loss gives a higher factor.

        A missing or non-positive loss is treated as "no penalty" (1.0).
        """
        if loss is None or loss <= 0:
            return 1.0
        return max(0.1, 1 / (1 + loss))

    def get_composite_score(self) -> float:
        """Calculate a weighted composite performance score.

        Each component is first normalized to [0, 1]:

        * ``profit_factor`` is divided by 3 (a profit factor of 3+ maps to 1.0),
        * ``sharpe_ratio`` is mapped linearly from [-2, 2] to [0, 1],
        * ``win_rate``, ``accuracy`` and ``confidence_score`` are used as-is
          and clamped to [0, 1], so an out-of-range input (for example a
          percentage) cannot dominate the score,
        * ``loss`` and ``val_loss`` become ``1 / (1 + loss)`` with a floor
          of 0.1.

        The weighted sum is multiplied by a drawdown factor that falls
        linearly from 1.0 at zero drawdown to 0.0 at a drawdown of 0.2. The
        factor is capped at 1.0, so a negative ``max_drawdown`` never acts as
        a bonus.

        Returns:
            The composite score, clamped to [0, 1].
        """
        # Weights sum to 1.0.
        weights = {
            'profit_factor': 0.25,
            'sharpe_ratio': 0.2,
            'win_rate': 0.15,
            'accuracy': 0.15,
            'confidence_score': 0.1,
            'loss_penalty': 0.1,   # rewards low training loss
            'val_penalty': 0.05,   # rewards low validation loss
        }
        components = {
            'profit_factor': self._clamp_unit(self.profit_factor / 3.0),
            'sharpe_ratio': self._clamp_unit((self.sharpe_ratio + 2) / 4),
            'win_rate': self._clamp_unit(self.win_rate),
            'accuracy': self._clamp_unit(self.accuracy),
            'confidence_score': self._clamp_unit(self.confidence_score),
            'loss_penalty': self._loss_factor(self.loss),
            'val_penalty': self._loss_factor(self.val_loss),
        }
        # Multiplicative drawdown factor, kept inside [0, 1].
        drawdown_factor = self._clamp_unit(1 - self.max_drawdown / 0.2)

        weighted_sum = sum(weights[name] * components[name] for name in weights)
        return self._clamp_unit(weighted_sum * drawdown_factor)
@dataclass
class ModelInfo:
    """Model information tracking."""

    model_type: str  # 'cnn', 'rl', 'transformer'
    model_name: str
    file_path: str
    creation_time: datetime
    last_updated: datetime
    file_size_mb: float
    metrics: ModelMetrics
    training_episodes: int = 0
    model_version: str = "1.0"

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary suitable for JSON serialization.

        ``asdict`` recursively converts the nested ``metrics`` dataclass; the
        two datetime fields are rendered as ISO-8601 strings.
        """
        data = asdict(self)
        data['creation_time'] = self.creation_time.isoformat()
        data['last_updated'] = self.last_updated.isoformat()
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ModelInfo':
        """Create an instance from a dictionary produced by ``to_dict``.

        The input mapping is not modified, so the same dictionary can be
        passed more than once. Timestamp values may be ISO-8601 strings or
        ``datetime`` objects; ``metrics`` may be a dict of ``ModelMetrics``
        fields or an already-built ``ModelMetrics``.

        Raises:
            KeyError: if a timestamp key or ``metrics`` is missing.
            ValueError: if a timestamp string is not valid ISO-8601.
            TypeError: if the dictionary has keys that are not fields.
        """
        # Work on a shallow copy so the caller's dictionary stays intact.
        fields = dict(data)
        for key in ('creation_time', 'last_updated'):
            if isinstance(fields[key], str):
                fields[key] = datetime.fromisoformat(fields[key])
        if isinstance(fields['metrics'], dict):
            fields['metrics'] = ModelMetrics(**fields['metrics'])
        return cls(**fields)
@dataclass
class CheckpointMetadata:
    """Bookkeeping record describing a single saved checkpoint file."""

    checkpoint_id: str
    model_name: str
    model_type: str
    file_path: str
    created_at: datetime
    file_size_mb: float
    performance_score: float
    accuracy: Optional[float] = None
    loss: Optional[float] = None
    val_accuracy: Optional[float] = None
    val_loss: Optional[float] = None
    reward: Optional[float] = None
    pnl: Optional[float] = None
    epoch: Optional[int] = None
    training_time_hours: Optional[float] = None
    total_parameters: Optional[int] = None
    wandb_run_id: Optional[str] = None
    wandb_artifact_name: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable dict; ``created_at`` becomes an ISO-8601 string."""
        data = asdict(self)
        data['created_at'] = self.created_at.isoformat()
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CheckpointMetadata':
        """Rebuild a record from a dictionary produced by ``to_dict``.

        The input mapping is not modified, so the same dictionary can be
        passed more than once. ``created_at`` may be an ISO-8601 string or a
        ``datetime`` object.

        Raises:
            KeyError: if ``created_at`` is missing.
            ValueError: if ``created_at`` is a string that is not ISO-8601.
            TypeError: if the dictionary has keys that are not fields.
        """
        # Work on a shallow copy so the caller's dictionary stays intact.
        fields = dict(data)
        if isinstance(fields['created_at'], str):
            fields['created_at'] = datetime.fromisoformat(fields['created_at'])
        return cls(**fields)
class ModelManager:
    """Unified model management system with @checkpoints/ structure.

    Everything is stored under ``<base_dir>/@checkpoints``. Two JSON files are
    kept there: ``model_metadata.json`` (the model registry) and
    ``checkpoint_metadata.json`` (scored checkpoints per model name). Legacy
    checkpoints under ``<base_dir>/NN/models/checkpoints`` are read for
    backward compatibility but never written.
    """

    # Sub-directories of the legacy checkpoint tree that may hold model files.
    _LEGACY_MODEL_DIRS = ('cnn', 'dqn', 'rl', 'transformer', 'decision')

    def __init__(self, base_dir: str = ".", config: Optional[Dict[str, Any]] = None):
        """Set up paths, create the directory tree and load both metadata files.

        Args:
            base_dir: Project root that contains (or will contain) ``@checkpoints``.
            config: Optional configuration dict. When falsy, the defaults from
                ``_get_default_config`` are used; a supplied dict replaces the
                defaults entirely (it is not merged with them).
        """
        self.base_dir = Path(base_dir)
        self.config = config or self._get_default_config()

        # Unified directory structure using @checkpoints/
        self.checkpoints_dir = self.base_dir / "@checkpoints"
        self.models_dir = self.checkpoints_dir / "models"
        self.saved_dir = self.checkpoints_dir / "saved"
        self.best_models_dir = self.checkpoints_dir / "best_models"
        self.archive_dir = self.checkpoints_dir / "archive"

        # Model type directories within @checkpoints/
        self.model_dirs = {
            'cnn': self.checkpoints_dir / "cnn",
            'dqn': self.checkpoints_dir / "dqn",
            'rl': self.checkpoints_dir / "rl",
            'transformer': self.checkpoints_dir / "transformer",
            'hybrid': self.checkpoints_dir / "hybrid"
        }

        # Legacy directories for backward compatibility
        self.nn_models_dir = self.base_dir / "NN" / "models"
        self.legacy_models_dir = self.base_dir / "models"
        # Legacy checkpoint directories (where existing checkpoints are stored)
        self.legacy_checkpoints_dir = self.nn_models_dir / "checkpoints"
        self.legacy_registry_file = self.legacy_checkpoints_dir / "registry_metadata.json"

        # Metadata and checkpoint management
        self.metadata_file = self.checkpoints_dir / "model_metadata.json"
        self.checkpoint_metadata_file = self.checkpoints_dir / "checkpoint_metadata.json"

        # Initialize storage
        self._initialize_directories()
        self.metadata = self._load_metadata()
        self.checkpoint_metadata = self._load_checkpoint_metadata()
        logger.info(f"ModelManager initialized with @checkpoints/ structure at {self.checkpoints_dir}")

    def _get_default_config(self) -> Dict[str, Any]:
        """Return the default configuration."""
        return {
            'max_checkpoints_per_model': 5,
            'cleanup_old_models': True,
            'auto_archive': True,
            'wandb_enabled': WANDB_AVAILABLE,
            'checkpoint_retention_days': 30
        }

    def _initialize_directories(self):
        """Create the @checkpoints/ directory tree if it does not exist yet."""
        directories = [
            self.checkpoints_dir,
            self.models_dir,
            self.saved_dir,
            self.best_models_dir,
            self.archive_dir
        ] + list(self.model_dirs.values())
        for directory in directories:
            directory.mkdir(parents=True, exist_ok=True)

    def _resolve_legacy_path(self, legacy_path: str) -> str:
        """Turn a path from the legacy registry into the best matching on-disk path.

        Absolute paths are returned unchanged. Relative paths are tried against
        the legacy checkpoint directory, its parent and the project base
        directory, in that order. If none exists, the legacy checkpoint tree is
        searched for a file with the same name. When nothing is found the
        original string is returned.
        """
        if legacy_path.startswith('/') or Path(legacy_path).is_absolute():
            return legacy_path

        candidates = (
            self.legacy_checkpoints_dir / legacy_path,         # NN/models/checkpoints/models/cnn/...
            self.legacy_checkpoints_dir.parent / legacy_path,  # NN/models/models/cnn/...
            self.base_dir / legacy_path,                       # <base_dir>/models/cnn/...
        )
        for candidate in candidates:
            if candidate.exists():
                return str(candidate)

        # Last resort: locate the file by name anywhere in the legacy tree.
        filename = Path(legacy_path).name
        for match in self.legacy_checkpoints_dir.rglob(filename):
            return str(match)
        return legacy_path

    def _load_metadata(self) -> Dict[str, Any]:
        """Load the model registry and merge in entries from the legacy registry.

        Unified entries win: a legacy model is only migrated when its name is
        not already present. Legacy entries without a ``latest_path`` are
        skipped because nothing could be loaded from them. Read or parse
        errors are logged and leave the affected source out.

        Returns:
            A dict that always contains a ``'models'`` mapping.
        """
        metadata: Dict[str, Any] = {'models': {}, 'last_updated': datetime.now().isoformat()}

        # First try to load from the unified metadata file
        if self.metadata_file.exists():
            try:
                with open(self.metadata_file, 'r') as f:
                    loaded = json.load(f)
                if isinstance(loaded, dict):
                    metadata = loaded
                    logger.info(f"Loaded unified metadata from {self.metadata_file}")
                else:
                    logger.error(f"Error loading unified metadata: {self.metadata_file} is not a JSON object")
            except Exception as e:
                logger.error(f"Error loading unified metadata: {e}")

        # A file written without a 'models' key must not break later lookups.
        models = metadata.setdefault('models', {})

        # Also load legacy metadata for backward compatibility
        if self.legacy_registry_file.exists():
            try:
                with open(self.legacy_registry_file, 'r') as f:
                    legacy_data = json.load(f)
                legacy_models = legacy_data.get('models', {}) if isinstance(legacy_data, dict) else {}
                for model_name, model_info in legacy_models.items():
                    if model_name in models or 'latest_path' not in model_info:
                        continue
                    resolved_path = self._resolve_legacy_path(model_info['latest_path'])
                    models[model_name] = {
                        'type': model_info.get('type', 'unknown'),
                        'latest_path': resolved_path,
                        'last_saved': model_info.get('last_saved', 'legacy'),
                        'save_count': model_info.get('save_count', 1),
                        'checkpoints': model_info.get('checkpoints', [])
                    }
                    logger.info(f"Migrated legacy metadata for {model_name}: {resolved_path}")
                logger.info(f"Loaded legacy metadata from {self.legacy_registry_file}")
            except Exception as e:
                logger.error(f"Error loading legacy metadata: {e}")
        return metadata

    def _load_checkpoint_metadata(self) -> Dict[str, List["CheckpointMetadata"]]:
        """Load per-model checkpoint records from ``checkpoint_metadata.json``.

        Returns:
            A ``defaultdict(list)`` keyed by model name, so a model that has no
            checkpoints yet can be appended to directly. It is empty when the
            file is missing or cannot be parsed.
        """
        result: Dict[str, List["CheckpointMetadata"]] = defaultdict(list)
        if not self.checkpoint_metadata_file.exists():
            return result
        try:
            with open(self.checkpoint_metadata_file, 'r') as f:
                data = json.load(f)
            # Convert the stored dicts back to CheckpointMetadata objects
            for key, checkpoints in data.items():
                result[key] = [CheckpointMetadata.from_dict(cp) for cp in checkpoints]
            return result
        except Exception as e:
            logger.error(f"Error loading checkpoint metadata: {e}")
            return defaultdict(list)

    def save_checkpoint(self, model, model_name: str, model_type: str,
                        performance_metrics: Dict[str, float],
                        training_metadata: Optional[Dict[str, Any]] = None,
                        force_save: bool = False) -> Optional["CheckpointMetadata"]:
        """Save a model checkpoint and record it in the checkpoint metadata.

        The file is written with ``torch.save`` to
        ``<model type dir>/checkpoints/<model_name>_<timestamp>.pt``; unknown
        model types go under the ``saved`` directory. After saving, checkpoints
        for the model are rotated down to ``max_checkpoints_per_model`` by
        score, so a low-scoring ``force_save`` checkpoint may be removed again
        immediately.

        Args:
            model: Object to checkpoint; its ``state_dict()`` is stored when it
                has one, otherwise an empty dict.
            model_name: Key under which checkpoints are grouped.
            model_type: Selects the target directory (see ``self.model_dirs``).
            performance_metrics: Metrics used for scoring and copied into the record.
            training_metadata: Optional extra data stored in the file.
            force_save: Save even when the score would not qualify.

        Returns:
            The new ``CheckpointMetadata``, or None when the save was skipped
            or failed (failures are logged).
        """
        try:
            performance_score = self._calculate_performance_score(performance_metrics)
            if not force_save and not self._should_save_checkpoint(model_name, performance_score):
                logger.debug(f"Skipping checkpoint save for {model_name} - performance not improved")
                return None

            # Create checkpoint directory
            checkpoint_dir = self.model_dirs.get(model_type, self.saved_dir) / "checkpoints"
            checkpoint_dir.mkdir(parents=True, exist_ok=True)

            # One timestamp for the id, the file content and the metadata record.
            created_at = datetime.now()
            checkpoint_id = f"{model_name}_{created_at.strftime('%Y%m%d_%H%M%S')}"
            filepath = checkpoint_dir / f"{checkpoint_id}.pt"

            # Save model
            save_dict = {
                'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else {},
                'model_class': model.__class__.__name__,
                'checkpoint_id': checkpoint_id,
                'model_name': model_name,
                'model_type': model_type,
                'performance_score': performance_score,
                'performance_metrics': performance_metrics,
                'training_metadata': training_metadata or {},
                'created_at': created_at.isoformat(),
                'version': '2.0'
            }
            torch.save(save_dict, filepath)

            # Create checkpoint metadata
            metadata = CheckpointMetadata(
                checkpoint_id=checkpoint_id,
                model_name=model_name,
                model_type=model_type,
                file_path=str(filepath),
                created_at=created_at,
                file_size_mb=filepath.stat().st_size / (1024 * 1024),
                performance_score=performance_score,
                accuracy=performance_metrics.get('accuracy'),
                loss=performance_metrics.get('loss'),
                val_accuracy=performance_metrics.get('val_accuracy'),
                val_loss=performance_metrics.get('val_loss'),
                reward=performance_metrics.get('reward'),
                pnl=performance_metrics.get('pnl'),
                epoch=performance_metrics.get('epoch'),
                training_time_hours=performance_metrics.get('training_time_hours'),
                total_parameters=performance_metrics.get('total_parameters')
            )

            # setdefault keeps this working even if the mapping is a plain dict.
            self.checkpoint_metadata.setdefault(model_name, []).append(metadata)
            self._save_checkpoint_metadata()

            # Rotate checkpoints if needed
            self._rotate_checkpoints(model_name)

            # Upload to W&B if enabled
            if self.config.get('wandb_enabled'):
                self._upload_to_wandb(metadata)

            logger.info(f"Checkpoint saved: {checkpoint_id} (score: {performance_score:.4f})")
            return metadata
        except Exception as e:
            logger.error(f"Error saving checkpoint for {model_name}: {e}")
            return None

    def _calculate_performance_score(self, metrics: Dict[str, float]) -> float:
        """Calculate a weighted performance score from *metrics*.

        Only ``accuracy``, ``profit_factor``, ``win_rate`` and ``sharpe_ratio``
        contribute. Metrics that are missing or None are ignored.
        """
        # Simple weighted score - can be enhanced
        weights = {'accuracy': 0.4, 'profit_factor': 0.3, 'win_rate': 0.2, 'sharpe_ratio': 0.1}
        score = 0.0
        for metric, weight in weights.items():
            value = metrics.get(metric)
            if value is not None:
                score += value * weight
        return score

    def _should_save_checkpoint(self, model_name: str, performance_score: float) -> bool:
        """Decide whether a checkpoint with *performance_score* is worth saving.

        True when the model has fewer than ``max_checkpoints_per_model``
        checkpoints, or when the score beats the worst one currently kept.
        """
        existing_checkpoints = self.checkpoint_metadata.get(model_name, [])
        max_checkpoints = self.config.get('max_checkpoints_per_model', 5)
        if not existing_checkpoints or len(existing_checkpoints) < max_checkpoints:
            return True
        worst_score = min(cp.performance_score for cp in existing_checkpoints)
        return performance_score > worst_score

    def _rotate_checkpoints(self, model_name: str):
        """Keep only the best ``max_checkpoints_per_model`` checkpoints of a model.

        Checkpoints are ranked by performance score (descending). The files of
        the dropped ones are deleted and the metadata file is rewritten.
        """
        checkpoints = self.checkpoint_metadata.get(model_name, [])
        max_checkpoints = self.config.get('max_checkpoints_per_model', 5)
        if len(checkpoints) <= max_checkpoints:
            return

        ranked = sorted(checkpoints, key=lambda cp: cp.performance_score, reverse=True)
        kept, dropped = ranked[:max_checkpoints], ranked[max_checkpoints:]

        # Two saves within the same second share one file name; never delete a
        # file that a kept record still points to.
        kept_paths = {cp.file_path for cp in kept}
        for checkpoint in dropped:
            if checkpoint.file_path in kept_paths:
                continue
            try:
                Path(checkpoint.file_path).unlink(missing_ok=True)
                logger.debug(f"Removed old checkpoint: {checkpoint.checkpoint_id}")
            except Exception as e:
                logger.error(f"Error removing checkpoint {checkpoint.checkpoint_id}: {e}")

        # Update metadata
        self.checkpoint_metadata[model_name] = kept
        self._save_checkpoint_metadata()

    def _save_checkpoint_metadata(self):
        """Write all checkpoint records to ``checkpoint_metadata.json`` (errors are logged)."""
        try:
            data = {
                model_name: [cp.to_dict() for cp in checkpoints]
                for model_name, checkpoints in self.checkpoint_metadata.items()
            }
            with open(self.checkpoint_metadata_file, 'w') as f:
                json.dump(data, f, indent=2)
        except Exception as e:
            logger.error(f"Error saving checkpoint metadata: {e}")

    def _upload_to_wandb(self, metadata: "CheckpointMetadata") -> Optional[str]:
        """Upload a checkpoint to W&B.

        Placeholder: nothing is uploaded yet, so the method always returns None.
        """
        if not WANDB_AVAILABLE:
            return None
        try:
            # This would be implemented based on your W&B workflow
            logger.debug(f"W&B upload not implemented yet for {metadata.checkpoint_id}")
            return None
        except Exception as e:
            logger.error(f"Error uploading to W&B: {e}")
            return None

    @staticmethod
    def _parse_timestamp(value: Any, fallback_file: Path) -> datetime:
        """Parse an ISO-8601 string, falling back to the mtime of *fallback_file*.

        Migrated legacy registry entries store the placeholder ``'legacy'``
        instead of a timestamp, so a failed parse is expected and not an error.
        """
        if isinstance(value, str):
            try:
                return datetime.fromisoformat(value)
            except ValueError:
                pass  # placeholder or malformed value: use the file time below
        return datetime.fromtimestamp(fallback_file.stat().st_mtime)

    def load_best_checkpoint(self, model_name: str) -> Optional[Tuple[str, "CheckpointMetadata"]]:
        """Locate the checkpoint to load for *model_name*, with legacy support.

        Lookup order:

        1. the registry entry's ``latest_path``, if that file exists,
        2. the highest-scoring recorded checkpoint, if its file exists,
        3. a matching ``.pt`` file in the legacy checkpoint tree.

        Nothing is deserialized here; only the path and a metadata record are
        returned. Records built for cases 1 and 3 carry a performance score of
        0.0 because the real score is unknown.

        Returns:
            ``(file_path, metadata)``, or None when nothing is found or an
            error occurs (errors are logged).
        """
        try:
            # First, try the unified registry
            model_info = self.metadata.get('models', {}).get(model_name)
            registry_path = model_info.get('latest_path') if model_info else None
            if registry_path and Path(registry_path).exists():
                logger.info(f"Loading checkpoint from unified registry: {registry_path}")
                # Create metadata from model info for compatibility
                registry_metadata = CheckpointMetadata(
                    checkpoint_id=f"{model_name}_registry",
                    model_name=model_name,
                    model_type=model_info.get('type', model_name),
                    file_path=registry_path,
                    created_at=self._parse_timestamp(model_info.get('last_saved'), Path(registry_path)),
                    file_size_mb=0.0,        # Not computed for registry entries
                    performance_score=0.0,   # Unknown from registry
                )
                return registry_path, registry_metadata

            # Fallback to checkpoint metadata
            checkpoints = self.checkpoint_metadata.get(model_name, [])
            if checkpoints:
                best_checkpoint = max(checkpoints, key=lambda cp: cp.performance_score)
                if Path(best_checkpoint.file_path).exists():
                    logger.info(f"Loading checkpoint from unified metadata: {best_checkpoint.file_path}")
                    return best_checkpoint.file_path, best_checkpoint

            # Legacy fallback: look for checkpoints in legacy directories
            logger.info(f"No checkpoint found in unified structure, checking legacy directories for {model_name}")
            legacy_path = self._find_legacy_checkpoint(model_name)
            if legacy_path:
                logger.info(f"Found legacy checkpoint: {legacy_path}")
                file_stat = legacy_path.stat()
                legacy_metadata = CheckpointMetadata(
                    checkpoint_id=f"legacy_{model_name}",
                    model_name=model_name,
                    model_type=model_name,  # Real type is unknown; the name is used as a stand-in
                    file_path=str(legacy_path),
                    created_at=datetime.fromtimestamp(file_stat.st_mtime),
                    file_size_mb=file_stat.st_size / (1024 * 1024),
                    performance_score=0.0,  # Unknown for legacy
                )
                return str(legacy_path), legacy_metadata

            logger.warning(f"No checkpoints found for {model_name} in any location")
            return None
        except Exception as e:
            logger.error(f"Error loading best checkpoint for {model_name}: {e}")
            return None

    def _find_legacy_checkpoint(self, model_name: str) -> Optional[Path]:
        """Find a ``.pt`` file for *model_name* in the legacy checkpoint tree.

        A file matches when its lower-cased name contains the model name or
        one of a few backward-compatibility aliases. Locations are searched
        from most to least specific: ``saved/``, each ``<model type>/saved/``,
        ``archive/``, ``backtest/`` and finally the whole legacy tree. Within
        one location candidates are visited in sorted path order, so the
        result does not depend on filesystem enumeration order.

        Returns:
            The first matching path, or None.
        """
        legacy_root = self.legacy_checkpoints_dir
        if not legacy_root.exists():
            return None

        # Models use consistent short names (dqn, cnn, cob_rl, transformer,
        # decision); only a minimal set of older aliases is still recognized.
        aliases = {
            'dqn': ['dqn_agent', 'agent'],
            'cnn': ['cnn_model', 'enhanced_cnn'],
            'cob_rl': ['rl', 'rl_agent', 'trading_agent'],
        }
        # File names are lower-cased before matching, so the name must be too.
        patterns = [model_name.lower()] + aliases.get(model_name, [])

        search_roots = [legacy_root / "saved"]
        search_roots += [legacy_root / model_type / "saved" for model_type in self._LEGACY_MODEL_DIRS]
        search_roots += [
            legacy_root / "archive",
            legacy_root / "backtest",  # might contain RL or other models
            legacy_root,               # last resort: the entire legacy tree
        ]

        for root in search_roots:
            if not root.exists():
                continue
            for file_path in sorted(root.rglob("*.pt")):
                filename = file_path.name.lower()
                if any(pattern in filename for pattern in patterns):
                    return file_path
        return None

    def get_storage_stats(self) -> Dict[str, Any]:
        """Get storage statistics for the @checkpoints/ tree.

        Each file is counted once, even though the ``models`` and ``saved``
        directories normally live inside the checkpoints directory.

        Returns:
            ``total_size_mb``, ``file_count`` and ``directories`` (number of
            direct entries of the checkpoints directory), or
            ``{'error': ...}`` on failure.
        """
        try:
            counted = set()
            total_size = 0
            for directory in (self.checkpoints_dir, self.models_dir, self.saved_dir):
                if not directory.exists():
                    continue
                for file_path in directory.rglob('*'):
                    if file_path in counted or not file_path.is_file():
                        continue
                    counted.add(file_path)
                    total_size += file_path.stat().st_size
            return {
                'total_size_mb': total_size / (1024 * 1024),
                'file_count': len(counted),
                'directories': len(list(self.checkpoints_dir.iterdir())) if self.checkpoints_dir.exists() else 0
            }
        except Exception as e:
            logger.error(f"Error getting storage stats: {e}")
            return {'error': str(e)}

    @staticmethod
    def _summarize_checkpoint_files(files: List[Path]) -> Dict[str, Any]:
        """Build the per-model stats entry for a non-empty list of ``.pt`` files."""
        latest_file = max(files, key=lambda f: f.stat().st_mtime)
        return {
            'checkpoint_count': len(files),
            'total_size_mb': sum(f.stat().st_size for f in files) / (1024 * 1024),
            'best_performance': 0.0,  # Not tracked from the file system
            # The most recent file stands in for both "best" and "latest".
            'best_checkpoint_id': latest_file.name,
            'latest_checkpoint': latest_file.name
        }

    def get_checkpoint_stats(self) -> Dict[str, Any]:
        """Get statistics about managed checkpoints (compatible with the old checkpoint_manager interface).

        Counts ``.pt`` files in the unified model-type directories, the
        ``saved`` directory and the legacy checkpoint tree. Legacy model
        directories are listed under ``models`` (with ``'location': 'legacy'``)
        only when no unified directory of the same name was listed already.

        Returns:
            A dict with ``total_models``, ``total_checkpoints``,
            ``total_size_mb`` and ``models``; on failure the totals are zero
            and an ``error`` key is added.
        """
        try:
            stats = {
                'total_models': 0,
                'total_checkpoints': 0,
                'total_size_mb': 0.0,
                'models': {}
            }

            # Count files in the unified per-type directories
            for checkpoint_dir in self.model_dirs.values():
                if not checkpoint_dir.exists():
                    continue
                model_files = list(checkpoint_dir.rglob('*.pt'))
                if not model_files:
                    continue
                summary = self._summarize_checkpoint_files(model_files)
                stats['total_models'] += 1
                stats['total_checkpoints'] += summary['checkpoint_count']
                stats['total_size_mb'] += summary['total_size_mb']
                stats['models'][checkpoint_dir.name] = summary

            # Also check saved models directory
            if self.saved_dir.exists():
                saved_files = list(self.saved_dir.rglob('*.pt'))
                if saved_files:
                    stats['total_checkpoints'] += len(saved_files)
                    stats['total_size_mb'] += sum(f.stat().st_size for f in saved_files) / (1024 * 1024)

            # Add legacy checkpoint statistics
            if self.legacy_checkpoints_dir.exists():
                legacy_files = list(self.legacy_checkpoints_dir.rglob('*.pt'))
                if legacy_files:
                    stats['total_checkpoints'] += len(legacy_files)
                    stats['total_size_mb'] += sum(f.stat().st_size for f in legacy_files) / (1024 * 1024)

                    # Add legacy models to stats
                    for model_dir_name in self._LEGACY_MODEL_DIRS:
                        model_dir = self.legacy_checkpoints_dir / model_dir_name
                        if not model_dir.exists():
                            continue
                        model_files = list(model_dir.rglob('*.pt'))
                        if model_files and model_dir_name not in stats['models']:
                            stats['total_models'] += 1
                            summary = self._summarize_checkpoint_files(model_files)
                            summary['location'] = 'legacy'
                            stats['models'][model_dir_name] = summary
            return stats
        except Exception as e:
            logger.error(f"Error getting checkpoint stats: {e}")
            return {
                'total_models': 0,
                'total_checkpoints': 0,
                'total_size_mb': 0.0,
                'models': {},
                'error': str(e)
            }

    def get_model_leaderboard(self) -> List[Dict[str, Any]]:
        """Get the model performance leaderboard.

        Only registry entries that carry a ``metrics`` dict are ranked. An
        entry whose metrics cannot be turned into ``ModelMetrics`` is skipped
        with a warning instead of emptying the whole leaderboard.

        Returns:
            Entries sorted by composite score, best first; an empty list on
            unexpected errors.
        """
        try:
            leaderboard = []
            for model_name, model_info in self.metadata.get('models', {}).items():
                if 'metrics' not in model_info:
                    continue
                try:
                    metrics = ModelMetrics(**model_info['metrics'])
                except TypeError as e:
                    logger.warning(f"Skipping {model_name} in leaderboard: {e}")
                    continue
                leaderboard.append({
                    'model_name': model_name,
                    # Registry entries store the type under 'type'; 'model_type' is accepted too.
                    'model_type': model_info.get('model_type', model_info.get('type', 'unknown')),
                    'composite_score': metrics.get_composite_score(),
                    'accuracy': metrics.accuracy,
                    'profit_factor': metrics.profit_factor,
                    'win_rate': metrics.win_rate,
                    'last_updated': model_info.get('last_saved', 'unknown')
                })
            # Sort by composite score
            leaderboard.sort(key=lambda entry: entry['composite_score'], reverse=True)
            return leaderboard
        except Exception as e:
            logger.error(f"Error getting leaderboard: {e}")
            return []
# ===== LEGACY COMPATIBILITY FUNCTIONS =====
def create_model_manager() -> ModelManager:
    """Build a ModelManager using its default base directory and configuration."""
    manager = ModelManager()
    return manager
def save_model(model: Any, model_name: str, model_type: str = 'cnn',
               metadata: Optional[Dict[str, Any]] = None) -> bool:
    """Legacy compatibility function to save a model.

    Delegates to ``ModelManager.save_model`` when the manager provides it.
    Otherwise the model is stored through ``ModelManager.save_checkpoint``
    with ``force_save=True`` and no performance metrics, instead of failing
    with ``AttributeError``.

    Args:
        model: Object to save.
        model_name: Name under which the model is stored.
        model_type: Model type used to choose the storage directory.
        metadata: Optional extra data, stored as training metadata.

    Returns:
        True when the model was saved, False otherwise.
    """
    manager = create_model_manager()
    native_save = getattr(manager, 'save_model', None)
    if callable(native_save):
        return native_save(model, model_name, model_type, metadata)

    saved = manager.save_checkpoint(
        model, model_name, model_type,
        performance_metrics={},
        training_metadata=metadata,
        force_save=True,
    )
    return saved is not None
def load_model(model_name: str, model_type: str = 'cnn',
               model_class: Optional[Any] = None) -> Optional[Any]:
    """Legacy compatibility function to load a model.

    Delegates to ``ModelManager.load_model`` when the manager provides it.
    Otherwise the checkpoint found by ``ModelManager.load_best_checkpoint`` is
    read with ``torch.load``, instead of failing with ``AttributeError``.

    Args:
        model_name: Name of the model to load.
        model_type: Only passed on to ``ModelManager.load_model``; the
            fallback locates the checkpoint by name alone.
        model_class: Optional class used to rebuild the model in the fallback.
            It is called without arguments and receives the stored state dict.
            NOTE(review): assumes the class has a no-argument constructor and
            a ``load_state_dict`` method - confirm against callers.

    Returns:
        With the fallback: the rebuilt model when ``model_class`` is given,
        the raw checkpoint object otherwise, or None when nothing could be
        loaded.
    """
    manager = create_model_manager()
    native_load = getattr(manager, 'load_model', None)
    if callable(native_load):
        return native_load(model_name, model_type, model_class)

    found = manager.load_best_checkpoint(model_name)
    if found is None:
        return None
    checkpoint_path, _ = found

    try:
        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        if model_class is None:
            return checkpoint
        # Checkpoints written by ModelManager keep the weights under
        # 'model_state_dict'; anything else is treated as a bare state dict.
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
        model = model_class()
        model.load_state_dict(state_dict)
        return model
    except Exception as e:
        logger.error(f"Error loading model {model_name} from {checkpoint_path}: {e}")
        return None
def save_checkpoint(model, model_name: str, model_type: str,
                    performance_metrics: Dict[str, float],
                    training_metadata: Optional[Dict[str, Any]] = None,
                    force_save: bool = False) -> Optional[CheckpointMetadata]:
    """Legacy wrapper: save a checkpoint through a freshly created ModelManager.

    All arguments are passed straight to ``ModelManager.save_checkpoint`` and
    its result is returned unchanged.
    """
    return create_model_manager().save_checkpoint(
        model,
        model_name,
        model_type,
        performance_metrics,
        training_metadata,
        force_save,
    )
def load_best_checkpoint(model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
    """Legacy wrapper: look up the best checkpoint through a freshly created ModelManager.

    Returns whatever ``ModelManager.load_best_checkpoint`` returns for
    *model_name*.
    """
    return create_model_manager().load_best_checkpoint(model_name)
# ===== EXAMPLE USAGE =====
if __name__ == "__main__":
    # Running the module directly creates a manager with default settings and
    # prints a short summary of what it manages.
    demo_manager = create_model_manager()
    print(f"ModelManager initialized at: {demo_manager.checkpoints_dir}")

    # Storage usage of the @checkpoints/ tree
    print(f"Storage stats: {demo_manager.get_storage_stats()}")

    # Number of models that have metrics to rank
    print(f"Models in leaderboard: {len(demo_manager.get_model_leaderboard())}")