model checkpoint manager
This commit is contained in:
@@ -21,6 +21,7 @@ from typing import Dict, Any, Optional, Tuple
|
||||
|
||||
# Import checkpoint management
|
||||
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
|
||||
from utils.model_registry import get_model_registry
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -774,42 +775,107 @@ class CNNModelTrainer:
|
||||
# Return realistic loss values based on random baseline performance
|
||||
return {'main_loss': 0.693, 'total_loss': 0.693, 'accuracy': 0.5} # ln(2) for binary cross-entropy at random chance
|
||||
|
||||
def save_model(self, filepath: str, metadata: Optional[Dict] = None):
|
||||
"""Save model with metadata"""
|
||||
save_dict = {
|
||||
'model_state_dict': self.model.state_dict(),
|
||||
'optimizer_state_dict': self.optimizer.state_dict(),
|
||||
'scheduler_state_dict': self.scheduler.state_dict(),
|
||||
'training_history': self.training_history,
|
||||
'model_config': {
|
||||
'input_size': self.model.input_size,
|
||||
'feature_dim': self.model.feature_dim,
|
||||
'output_size': self.model.output_size,
|
||||
'base_channels': self.model.base_channels
|
||||
def save_model(self, filepath: str = None, metadata: Optional[Dict] = None):
|
||||
"""Save model with metadata using unified registry"""
|
||||
try:
|
||||
from utils.model_registry import save_model
|
||||
|
||||
# Prepare model data
|
||||
model_data = {
|
||||
'model_state_dict': self.model.state_dict(),
|
||||
'optimizer_state_dict': self.optimizer.state_dict(),
|
||||
'scheduler_state_dict': self.scheduler.state_dict(),
|
||||
'training_history': self.training_history,
|
||||
'model_config': {
|
||||
'input_size': self.model.input_size,
|
||||
'feature_dim': self.model.feature_dim,
|
||||
'output_size': self.model.output_size,
|
||||
'base_channels': self.model.base_channels
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if metadata:
|
||||
save_dict['metadata'] = metadata
|
||||
|
||||
torch.save(save_dict, filepath)
|
||||
logger.info(f"Enhanced CNN model saved to {filepath}")
|
||||
|
||||
if metadata:
|
||||
model_data['metadata'] = metadata
|
||||
|
||||
# Use unified registry if no filepath specified
|
||||
if filepath is None or filepath.startswith('models/'):
|
||||
# Extract model name from filepath or use default
|
||||
model_name = "enhanced_cnn"
|
||||
if filepath:
|
||||
model_name = filepath.split('/')[-1].replace('_latest.pt', '').replace('.pt', '')
|
||||
|
||||
success = save_model(
|
||||
model=self.model,
|
||||
model_name=model_name,
|
||||
model_type='cnn',
|
||||
metadata={'full_checkpoint': model_data}
|
||||
)
|
||||
if success:
|
||||
logger.info(f"Enhanced CNN model saved to unified registry: {model_name}")
|
||||
return success
|
||||
else:
|
||||
# Legacy direct file save
|
||||
torch.save(model_data, filepath)
|
||||
logger.info(f"Enhanced CNN model saved to {filepath} (legacy mode)")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save CNN model: {e}")
|
||||
return False
|
||||
|
||||
def load_model(self, filepath: str) -> Dict:
|
||||
"""Load model from file"""
|
||||
checkpoint = torch.load(filepath, map_location=self.device)
|
||||
|
||||
self.model.load_state_dict(checkpoint['model_state_dict'])
|
||||
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||
|
||||
if 'scheduler_state_dict' in checkpoint:
|
||||
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
||||
|
||||
if 'training_history' in checkpoint:
|
||||
self.training_history = checkpoint['training_history']
|
||||
|
||||
logger.info(f"Enhanced CNN model loaded from {filepath}")
|
||||
return checkpoint.get('metadata', {})
|
||||
def load_model(self, filepath: str = None) -> Dict:
|
||||
"""Load model from unified registry or file"""
|
||||
try:
|
||||
from utils.model_registry import load_model
|
||||
|
||||
# Use unified registry if no filepath or if it's a models/ path
|
||||
if filepath is None or filepath.startswith('models/'):
|
||||
model_name = "enhanced_cnn"
|
||||
if filepath:
|
||||
model_name = filepath.split('/')[-1].replace('_latest.pt', '').replace('.pt', '')
|
||||
|
||||
model = load_model(model_name, 'cnn')
|
||||
if model is None:
|
||||
logger.warning(f"Could not load model {model_name} from unified registry")
|
||||
return {}
|
||||
|
||||
# Load full checkpoint data from metadata
|
||||
registry = get_model_registry()
|
||||
if model_name in registry.metadata['models']:
|
||||
model_data = registry.metadata['models'][model_name]
|
||||
if 'full_checkpoint' in model_data:
|
||||
checkpoint = model_data['full_checkpoint']
|
||||
|
||||
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||
if 'scheduler_state_dict' in checkpoint:
|
||||
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
||||
if 'training_history' in checkpoint:
|
||||
self.training_history = checkpoint['training_history']
|
||||
|
||||
logger.info(f"Enhanced CNN model loaded from unified registry: {model_name}")
|
||||
return checkpoint.get('metadata', {})
|
||||
|
||||
return {}
|
||||
|
||||
else:
|
||||
# Legacy direct file load
|
||||
checkpoint = torch.load(filepath, map_location=self.device)
|
||||
|
||||
self.model.load_state_dict(checkpoint['model_state_dict'])
|
||||
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||
|
||||
if 'scheduler_state_dict' in checkpoint:
|
||||
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
||||
|
||||
if 'training_history' in checkpoint:
|
||||
self.training_history = checkpoint['training_history']
|
||||
|
||||
logger.info(f"Enhanced CNN model loaded from {filepath} (legacy mode)")
|
||||
return checkpoint.get('metadata', {})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load CNN model: {e}")
|
||||
return {}
|
||||
|
||||
def create_enhanced_cnn_model(input_size: int = 60,
|
||||
feature_dim: int = 50,
|
||||
|
Reference in New Issue
Block a user