stability

Dobromir Popov
2025-07-28 12:10:52 +03:00
parent 9219b78241
commit fb72c93743
8 changed files with 207 additions and 53 deletions

View File

@@ -491,4 +491,57 @@ class CheckpointManager:
except Exception as e:
logger.error(f"Error getting all checkpoints: {e}")
return []
def get_checkpoint_stats(self) -> Dict[str, Any]:
"""
Get statistics about all checkpoints
Returns:
Dict[str, Any]: Statistics about checkpoints
"""
try:
stats = {
'total_checkpoints': 0,
'total_size_mb': 0.0,
'models': {}
}
# Iterate through all model directories
for model_dir in os.listdir(self.checkpoint_dir):
model_path = os.path.join(self.checkpoint_dir, model_dir)
if not os.path.isdir(model_path):
continue
# Count checkpoints for this model
checkpoint_files = glob.glob(os.path.join(model_path, f"{model_dir}_*.pt"))
model_checkpoints = len(checkpoint_files)
# Calculate total size for this model
model_size_mb = 0.0
for checkpoint_file in checkpoint_files:
try:
size_bytes = os.path.getsize(checkpoint_file)
model_size_mb += size_bytes / (1024 * 1024) # Convert to MB
except OSError:
pass
stats['models'][model_dir] = {
'checkpoints': model_checkpoints,
'size_mb': round(model_size_mb, 2)
}
stats['total_checkpoints'] += model_checkpoints
stats['total_size_mb'] += model_size_mb
stats['total_size_mb'] = round(stats['total_size_mb'], 2)
return stats
except Exception as e:
logger.error(f"Error getting checkpoint stats: {e}")
return {
'total_checkpoints': 0,
'total_size_mb': 0.0,
'models': {}
}
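For context, a minimal usage sketch of the new stats method (not part of this commit); it assumes the module-level get_checkpoint_manager() accessor referenced elsewhere in this changeset:
# Hypothetical usage example; the keys match the stats dict built above.
manager = get_checkpoint_manager()
stats = manager.get_checkpoint_stats()
print(f"Total: {stats['total_checkpoints']} checkpoints, {stats['total_size_mb']} MB")
for model_name, info in stats['models'].items():
    print(f"  {model_name}: {info['checkpoints']} checkpoints, {info['size_mb']} MB")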

View File

@@ -15,6 +15,7 @@ logger = logging.getLogger(__name__)
class TrainingIntegration:
def __init__(self, enable_wandb: bool = True):
self.enable_wandb = enable_wandb
self.checkpoint_manager = get_checkpoint_manager()
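get_checkpoint_manager() itself is not shown in this diff; presumably it is a module-level accessor that returns a shared CheckpointManager instance, roughly along these lines (an assumption, not this repository's actual code):
# Hypothetical sketch of a shared-instance accessor; CheckpointManager is the class patched above.
_checkpoint_manager = None

def get_checkpoint_manager() -> "CheckpointManager":
    global _checkpoint_manager
    if _checkpoint_manager is None:
        _checkpoint_manager = CheckpointManager()
    return _checkpoint_manager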
@@ -55,9 +56,13 @@ class TrainingIntegration:
except Exception as e:
logger.warning(f"Error logging to W&B: {e}")
# Save the model first to get the path
model_path = f"models/{model_name}_temp.pt"
torch.save(cnn_model.state_dict(), model_path)
metadata = self.checkpoint_manager.save_checkpoint(
model=cnn_model,
model_name=model_name,
model_path=model_path,
model_type='cnn',
performance_metrics=performance_metrics,
training_metadata=training_metadata
@@ -114,9 +119,13 @@ class TrainingIntegration:
except Exception as e:
logger.warning(f"Error logging to W&B: {e}")
# Save the model first to get the path
model_path = f"models/{model_name}_temp.pt"
torch.save(rl_agent.state_dict() if hasattr(rl_agent, 'state_dict') else rl_agent, model_path)
metadata = self.checkpoint_manager.save_checkpoint(
model=rl_agent,
model_name=model_name,
model_path=model_path,
model_type='rl',
performance_metrics=performance_metrics,
training_metadata=training_metadata
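One caveat with the temp-save pattern added in both hunks: torch.save raises FileNotFoundError if the models/ directory does not exist. A defensive variant of that step (a sketch only, in the same method context, not part of this commit) could create the directory first:
import os
import torch

# Hypothetical hardened version of the temp-save step shown above.
model_path = f"models/{model_name}_temp.pt"
os.makedirs(os.path.dirname(model_path), exist_ok=True)  # ensure models/ exists before saving
state = rl_agent.state_dict() if hasattr(rl_agent, 'state_dict') else rl_agent
torch.save(state, model_path)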