model checkpoint manager

This commit is contained in:
Dobromir Popov
2025-09-08 13:31:11 +03:00
parent 060fdd28b4
commit c9fba56622
6 changed files with 838 additions and 142 deletions

View File

@@ -21,6 +21,7 @@ from typing import Dict, Any, Optional, Tuple
# Import checkpoint management
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
from utils.model_registry import get_model_registry
# Configure logging
logger = logging.getLogger(__name__)
@@ -774,42 +775,107 @@ class CNNModelTrainer:
# Return realistic loss values based on random baseline performance
return {'main_loss': 0.693, 'total_loss': 0.693, 'accuracy': 0.5} # ln(2) for binary cross-entropy at random chance
def save_model(self, filepath: str, metadata: Optional[Dict] = None):
"""Save model with metadata"""
save_dict = {
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'scheduler_state_dict': self.scheduler.state_dict(),
'training_history': self.training_history,
'model_config': {
'input_size': self.model.input_size,
'feature_dim': self.model.feature_dim,
'output_size': self.model.output_size,
'base_channels': self.model.base_channels
def save_model(self, filepath: str = None, metadata: Optional[Dict] = None):
"""Save model with metadata using unified registry"""
try:
from utils.model_registry import save_model
# Prepare model data
model_data = {
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'scheduler_state_dict': self.scheduler.state_dict(),
'training_history': self.training_history,
'model_config': {
'input_size': self.model.input_size,
'feature_dim': self.model.feature_dim,
'output_size': self.model.output_size,
'base_channels': self.model.base_channels
}
}
}
if metadata:
save_dict['metadata'] = metadata
torch.save(save_dict, filepath)
logger.info(f"Enhanced CNN model saved to {filepath}")
if metadata:
model_data['metadata'] = metadata
# Use unified registry if no filepath specified
if filepath is None or filepath.startswith('models/'):
# Extract model name from filepath or use default
model_name = "enhanced_cnn"
if filepath:
model_name = filepath.split('/')[-1].replace('_latest.pt', '').replace('.pt', '')
success = save_model(
model=self.model,
model_name=model_name,
model_type='cnn',
metadata={'full_checkpoint': model_data}
)
if success:
logger.info(f"Enhanced CNN model saved to unified registry: {model_name}")
return success
else:
# Legacy direct file save
torch.save(model_data, filepath)
logger.info(f"Enhanced CNN model saved to {filepath} (legacy mode)")
return True
except Exception as e:
logger.error(f"Failed to save CNN model: {e}")
return False
def load_model(self, filepath: str) -> Dict:
"""Load model from file"""
checkpoint = torch.load(filepath, map_location=self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if 'scheduler_state_dict' in checkpoint:
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
if 'training_history' in checkpoint:
self.training_history = checkpoint['training_history']
logger.info(f"Enhanced CNN model loaded from {filepath}")
return checkpoint.get('metadata', {})
def load_model(self, filepath: str = None) -> Dict:
"""Load model from unified registry or file"""
try:
from utils.model_registry import load_model
# Use unified registry if no filepath or if it's a models/ path
if filepath is None or filepath.startswith('models/'):
model_name = "enhanced_cnn"
if filepath:
model_name = filepath.split('/')[-1].replace('_latest.pt', '').replace('.pt', '')
model = load_model(model_name, 'cnn')
if model is None:
logger.warning(f"Could not load model {model_name} from unified registry")
return {}
# Load full checkpoint data from metadata
registry = get_model_registry()
if model_name in registry.metadata['models']:
model_data = registry.metadata['models'][model_name]
if 'full_checkpoint' in model_data:
checkpoint = model_data['full_checkpoint']
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if 'scheduler_state_dict' in checkpoint:
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
if 'training_history' in checkpoint:
self.training_history = checkpoint['training_history']
logger.info(f"Enhanced CNN model loaded from unified registry: {model_name}")
return checkpoint.get('metadata', {})
return {}
else:
# Legacy direct file load
checkpoint = torch.load(filepath, map_location=self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if 'scheduler_state_dict' in checkpoint:
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
if 'training_history' in checkpoint:
self.training_history = checkpoint['training_history']
logger.info(f"Enhanced CNN model loaded from {filepath} (legacy mode)")
return checkpoint.get('metadata', {})
except Exception as e:
logger.error(f"Failed to load CNN model: {e}")
return {}
def create_enhanced_cnn_model(input_size: int = 60,
feature_dim: int = 50,