model checkpoint manager
This commit is contained in:
@@ -16,6 +16,9 @@ import random
|
||||
|
||||
WANDB_AVAILABLE = False
|
||||
|
||||
# Import model registry
|
||||
from utils.model_registry import get_model_registry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
@@ -68,39 +71,48 @@ class CheckpointManager:
|
||||
|
||||
logger.info(f"Checkpoint Manager initialized - Max checkpoints per model: {self.max_checkpoints}")
|
||||
|
||||
def save_checkpoint(self, model, model_name: str, model_type: str,
|
||||
def save_checkpoint(self, model, model_name: str, model_type: str,
|
||||
performance_metrics: Dict[str, float],
|
||||
training_metadata: Optional[Dict[str, Any]] = None,
|
||||
force_save: bool = False) -> Optional[CheckpointMetadata]:
|
||||
"""Save a model checkpoint with improved error handling and validation"""
|
||||
"""Save a model checkpoint with improved error handling and validation using unified registry"""
|
||||
try:
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
checkpoint_id = f"{model_name}_{timestamp}"
|
||||
|
||||
model_dir = self.base_dir / model_name
|
||||
model_dir.mkdir(exist_ok=True)
|
||||
|
||||
checkpoint_path = model_dir / f"{checkpoint_id}.pt"
|
||||
|
||||
from utils.model_registry import save_checkpoint as registry_save_checkpoint
|
||||
|
||||
performance_score = self._calculate_performance_score(performance_metrics)
|
||||
|
||||
|
||||
if not force_save and not self._should_save_checkpoint(model_name, performance_score):
|
||||
logger.debug(f"Skipping checkpoint save for {model_name} - performance not improved")
|
||||
return None
|
||||
|
||||
success = self._save_model_file(model, checkpoint_path, model_type)
|
||||
if not success:
|
||||
return None
|
||||
|
||||
file_size_mb = checkpoint_path.stat().st_size / (1024 * 1024)
|
||||
|
||||
metadata = CheckpointMetadata(
|
||||
checkpoint_id=checkpoint_id,
|
||||
|
||||
# Use unified registry for checkpointing
|
||||
success = registry_save_checkpoint(
|
||||
model=model,
|
||||
model_name=model_name,
|
||||
model_type=model_type,
|
||||
file_path=str(checkpoint_path),
|
||||
created_at=datetime.now(),
|
||||
file_size_mb=file_size_mb,
|
||||
performance_score=performance_score,
|
||||
metadata={
|
||||
'performance_metrics': performance_metrics,
|
||||
'training_metadata': training_metadata,
|
||||
'checkpoint_manager': True
|
||||
}
|
||||
)
|
||||
|
||||
if not success:
|
||||
return None
|
||||
|
||||
# Get checkpoint info from registry
|
||||
registry = get_model_registry()
|
||||
checkpoint_info = registry.metadata['models'][model_name]['checkpoints'][-1]
|
||||
|
||||
# Create CheckpointMetadata object
|
||||
metadata = CheckpointMetadata(
|
||||
checkpoint_id=checkpoint_info['id'],
|
||||
model_name=model_name,
|
||||
model_type=model_type,
|
||||
file_path=checkpoint_info['path'],
|
||||
created_at=datetime.fromisoformat(checkpoint_info['timestamp']),
|
||||
file_size_mb=0.0, # Will be calculated by registry
|
||||
performance_score=performance_score,
|
||||
accuracy=performance_metrics.get('accuracy'),
|
||||
loss=performance_metrics.get('loss'),
|
||||
@@ -112,9 +124,8 @@ class CheckpointManager:
|
||||
training_time_hours=training_metadata.get('training_time_hours') if training_metadata else None,
|
||||
total_parameters=training_metadata.get('total_parameters') if training_metadata else None
|
||||
)
|
||||
|
||||
# W&B disabled
|
||||
|
||||
|
||||
# Update local checkpoint tracking
|
||||
self.checkpoints[model_name].append(metadata)
|
||||
self._rotate_checkpoints(model_name)
|
||||
self._save_metadata()
|
||||
@@ -128,14 +139,42 @@ class CheckpointManager:
|
||||
|
||||
def load_best_checkpoint(self, model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
|
||||
try:
|
||||
# First, try the standard checkpoint system
|
||||
from utils.model_registry import load_best_checkpoint as registry_load_checkpoint
|
||||
|
||||
# First, try the unified registry
|
||||
registry_result = registry_load_checkpoint(model_name, 'cnn') # Try CNN type first
|
||||
if registry_result is None:
|
||||
registry_result = registry_load_checkpoint(model_name, 'dqn') # Try DQN type
|
||||
|
||||
if registry_result:
|
||||
checkpoint_path, checkpoint_data = registry_result
|
||||
|
||||
# Create CheckpointMetadata from registry data
|
||||
metadata = CheckpointMetadata(
|
||||
checkpoint_id=f"{model_name}_registry",
|
||||
model_name=model_name,
|
||||
model_type=checkpoint_data.get('model_type', 'unknown'),
|
||||
file_path=checkpoint_path,
|
||||
created_at=datetime.fromisoformat(checkpoint_data.get('timestamp', datetime.now().isoformat())),
|
||||
file_size_mb=0.0, # Will be calculated by registry
|
||||
performance_score=checkpoint_data.get('performance_score', 0.0),
|
||||
accuracy=checkpoint_data.get('accuracy'),
|
||||
loss=checkpoint_data.get('loss'),
|
||||
reward=checkpoint_data.get('reward'),
|
||||
pnl=checkpoint_data.get('pnl')
|
||||
)
|
||||
|
||||
logger.debug(f"Loading checkpoint from unified registry for {model_name}")
|
||||
return checkpoint_path, metadata
|
||||
|
||||
# Fallback: Try the standard checkpoint system
|
||||
if model_name in self.checkpoints and self.checkpoints[model_name]:
|
||||
# Filter out checkpoints with non-existent files
|
||||
valid_checkpoints = [
|
||||
cp for cp in self.checkpoints[model_name]
|
||||
cp for cp in self.checkpoints[model_name]
|
||||
if Path(cp.file_path).exists()
|
||||
]
|
||||
|
||||
|
||||
if valid_checkpoints:
|
||||
best_checkpoint = max(valid_checkpoints, key=lambda x: x.performance_score)
|
||||
logger.debug(f"Loading best checkpoint for {model_name}: {best_checkpoint.checkpoint_id}")
|
||||
@@ -146,22 +185,22 @@ class CheckpointManager:
|
||||
logger.warning(f"Found {invalid_count} invalid checkpoint entries for {model_name}, cleaning up metadata")
|
||||
self.checkpoints[model_name] = []
|
||||
self._save_metadata()
|
||||
|
||||
|
||||
# Fallback: Look for existing saved models in the legacy format
|
||||
logger.debug(f"No valid checkpoints found for model: {model_name}, attempting to find legacy saved models")
|
||||
legacy_model_path = self._find_legacy_model(model_name)
|
||||
|
||||
|
||||
if legacy_model_path:
|
||||
# Create checkpoint metadata for the legacy model using actual file data
|
||||
legacy_metadata = self._create_legacy_metadata(model_name, legacy_model_path)
|
||||
logger.debug(f"Found legacy model for {model_name}: {legacy_model_path}")
|
||||
return str(legacy_model_path), legacy_metadata
|
||||
|
||||
|
||||
# Only warn once per model to avoid spam
|
||||
if model_name not in self._warned_models:
|
||||
logger.info(f"No checkpoints found for {model_name}, starting fresh")
|
||||
self._warned_models.add(model_name)
|
||||
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user