Show best-checkpoint info (epoch, loss, accuracy) in the model dropdowns

This commit is contained in:
Dobromir Popov
2025-11-11 11:20:23 +02:00
parent 27039c70a3
commit 6c1ca8baf4
3 changed files with 121 additions and 9 deletions

View File

@@ -189,6 +189,65 @@ class AnnotationDashboard:
logger.info("Annotation Dashboard initialized")
def _get_best_checkpoint_info(self, model_name: str) -> Optional[Dict]:
"""
Get best checkpoint info for a model without loading it
Args:
model_name: Name of the model
Returns:
Dict with checkpoint info or None if no checkpoint found
"""
try:
import torch
import glob
# Map model names to checkpoint directories
checkpoint_dirs = {
'Transformer': 'models/checkpoints/transformer',
'CNN': 'models/checkpoints/enhanced_cnn',
'DQN': 'models/checkpoints/dqn_agent'
}
checkpoint_dir = checkpoint_dirs.get(model_name)
if not checkpoint_dir or not os.path.exists(checkpoint_dir):
return None
# Find all checkpoint files
checkpoint_files = glob.glob(os.path.join(checkpoint_dir, '*.pt'))
if not checkpoint_files:
return None
# Load metadata from each checkpoint and find best
best_checkpoint = None
best_accuracy = -1
for cp_file in checkpoint_files:
try:
# Load only metadata (map_location='cpu' to avoid GPU)
checkpoint = torch.load(cp_file, map_location='cpu')
accuracy = checkpoint.get('accuracy', 0.0)
if accuracy > best_accuracy:
best_accuracy = accuracy
best_checkpoint = {
'filename': os.path.basename(cp_file),
'epoch': checkpoint.get('epoch', 0),
'loss': checkpoint.get('loss', 0.0),
'accuracy': accuracy,
'learning_rate': checkpoint.get('learning_rate', 0.0)
}
except Exception as e:
logger.debug(f"Could not load checkpoint {cp_file}: {e}")
continue
return best_checkpoint
except Exception as e:
logger.debug(f"Error getting checkpoint info for {model_name}: {e}")
return None
def _load_model_lazy(self, model_name: str) -> dict:
"""
Lazy load a specific model on demand
@@ -1220,15 +1279,40 @@ class AnnotationDashboard:
# Use self.available_models which is a simple list of strings
# Don't call training_adapter.get_available_models() as it may return objects
# Build model state dict
# Build model state dict with checkpoint info
model_states = []
for model_name in self.available_models:
is_loaded = model_name in self.loaded_models
# Get checkpoint info (even for unloaded models)
checkpoint_info = None
# If loaded, get from orchestrator
if is_loaded and self.orchestrator:
if model_name == 'Transformer' and hasattr(self.orchestrator, 'transformer_checkpoint_info'):
cp_info = self.orchestrator.transformer_checkpoint_info
if cp_info and cp_info.get('status') == 'loaded':
checkpoint_info = {
'filename': cp_info.get('filename', 'unknown'),
'epoch': cp_info.get('epoch', 0),
'loss': cp_info.get('loss', 0.0),
'accuracy': cp_info.get('accuracy', 0.0),
'loaded_at': cp_info.get('loaded_at', ''),
'source': 'loaded'
}
# If not loaded, try to read best checkpoint from disk
if not checkpoint_info:
checkpoint_info = self._get_best_checkpoint_info(model_name)
if checkpoint_info:
checkpoint_info['source'] = 'disk'
model_states.append({
'name': model_name,
'loaded': is_loaded,
'can_train': is_loaded,
'can_infer': is_loaded
'can_infer': is_loaded,
'checkpoint': checkpoint_info # Checkpoint metadata (loaded or from disk)
})
return jsonify({