dedulicae model storage
This commit is contained in:
@@ -457,6 +457,72 @@ class ModelManager:
|
||||
logger.error(f"Error getting storage stats: {e}")
|
||||
return {'error': str(e)}
|
||||
|
||||
def get_checkpoint_stats(self) -> Dict[str, Any]:
|
||||
"""Get statistics about managed checkpoints (compatible with old checkpoint_manager interface)"""
|
||||
try:
|
||||
stats = {
|
||||
'total_models': 0,
|
||||
'total_checkpoints': 0,
|
||||
'total_size_mb': 0.0,
|
||||
'models': {}
|
||||
}
|
||||
|
||||
# Count files in different directories as "checkpoints"
|
||||
checkpoint_dirs = [
|
||||
self.checkpoints_dir / "cnn",
|
||||
self.checkpoints_dir / "dqn",
|
||||
self.checkpoints_dir / "rl",
|
||||
self.checkpoints_dir / "transformer",
|
||||
self.checkpoints_dir / "hybrid"
|
||||
]
|
||||
|
||||
total_size = 0
|
||||
total_files = 0
|
||||
|
||||
for checkpoint_dir in checkpoint_dirs:
|
||||
if checkpoint_dir.exists():
|
||||
model_files = list(checkpoint_dir.rglob('*.pt'))
|
||||
if model_files:
|
||||
model_name = checkpoint_dir.name
|
||||
stats['total_models'] += 1
|
||||
|
||||
model_size = sum(f.stat().st_size for f in model_files)
|
||||
stats['total_checkpoints'] += len(model_files)
|
||||
stats['total_size_mb'] += model_size / (1024 * 1024)
|
||||
total_size += model_size
|
||||
total_files += len(model_files)
|
||||
|
||||
# Get the most recent file as "latest"
|
||||
latest_file = max(model_files, key=lambda f: f.stat().st_mtime)
|
||||
|
||||
stats['models'][model_name] = {
|
||||
'checkpoint_count': len(model_files),
|
||||
'total_size_mb': model_size / (1024 * 1024),
|
||||
'best_performance': 0.0, # Not tracked in unified system
|
||||
'best_checkpoint_id': latest_file.name,
|
||||
'latest_checkpoint': latest_file.name
|
||||
}
|
||||
|
||||
# Also check saved models directory
|
||||
if self.saved_dir.exists():
|
||||
saved_files = list(self.saved_dir.rglob('*.pt'))
|
||||
if saved_files:
|
||||
stats['total_checkpoints'] += len(saved_files)
|
||||
saved_size = sum(f.stat().st_size for f in saved_files)
|
||||
stats['total_size_mb'] += saved_size / (1024 * 1024)
|
||||
|
||||
return stats
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting checkpoint stats: {e}")
|
||||
return {
|
||||
'total_models': 0,
|
||||
'total_checkpoints': 0,
|
||||
'total_size_mb': 0.0,
|
||||
'models': {},
|
||||
'error': str(e)
|
||||
}
|
||||
|
||||
def get_model_leaderboard(self) -> List[Dict[str, Any]]:
|
||||
"""Get model performance leaderboard"""
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user