checkpoint info in dropdowns

@@ -189,6 +189,65 @@ class AnnotationDashboard:
        logger.info("Annotation Dashboard initialized")

    def _get_best_checkpoint_info(self, model_name: str) -> Optional[Dict]:
        """
        Get best checkpoint info for a model without loading it

        Args:
            model_name: Name of the model

        Returns:
            Dict with checkpoint info or None if no checkpoint found
        """
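        # Illustrative example of the returned dict (values made up; the keys
        # match the construction below):
        #   {'filename': 'transformer_best.pt', 'epoch': 12, 'loss': 0.42,
        #    'accuracy': 0.87, 'learning_rate': 0.0001}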
        try:
            import torch
            import glob

            # Map model names to checkpoint directories
            checkpoint_dirs = {
                'Transformer': 'models/checkpoints/transformer',
                'CNN': 'models/checkpoints/enhanced_cnn',
                'DQN': 'models/checkpoints/dqn_agent'
            }

            checkpoint_dir = checkpoint_dirs.get(model_name)
            if not checkpoint_dir or not os.path.exists(checkpoint_dir):
                return None

            # Find all checkpoint files
            checkpoint_files = glob.glob(os.path.join(checkpoint_dir, '*.pt'))
            if not checkpoint_files:
                return None

            # Load metadata from each checkpoint and find the best
            best_checkpoint = None
            best_accuracy = -1

            for cp_file in checkpoint_files:
                try:
                    # torch.load reads the whole checkpoint file; map_location='cpu'
                    # keeps the tensors off the GPU. Only metadata fields are used here.
                    checkpoint = torch.load(cp_file, map_location='cpu')

                    accuracy = checkpoint.get('accuracy', 0.0)
                    if accuracy > best_accuracy:
                        best_accuracy = accuracy
                        best_checkpoint = {
                            'filename': os.path.basename(cp_file),
                            'epoch': checkpoint.get('epoch', 0),
                            'loss': checkpoint.get('loss', 0.0),
                            'accuracy': accuracy,
                            'learning_rate': checkpoint.get('learning_rate', 0.0)
                        }
                except Exception as e:
                    logger.debug(f"Could not load checkpoint {cp_file}: {e}")
                    continue

            return best_checkpoint

        except Exception as e:
            logger.debug(f"Error getting checkpoint info for {model_name}: {e}")
            return None
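
    # A minimal sketch (not part of this commit) of how the returned dict could
    # feed a dropdown label; '_format_checkpoint_label' is a hypothetical name:
    #
    #     def _format_checkpoint_label(self, model_name: str) -> str:
    #         info = self._get_best_checkpoint_info(model_name)
    #         if not info:
    #             return model_name
    #         return (f"{model_name} - {info['filename']} "
    #                 f"(epoch {info['epoch']}, acc {info['accuracy']:.2f})")
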
    def _load_model_lazy(self, model_name: str) -> dict:
        """
        Lazy load a specific model on demand

@@ -1220,15 +1279,40 @@ class AnnotationDashboard:
        # Use self.available_models which is a simple list of strings
        # Don't call training_adapter.get_available_models() as it may return objects

        # Build model state dict with checkpoint info
        model_states = []
        for model_name in self.available_models:
            is_loaded = model_name in self.loaded_models

            # Get checkpoint info (even for unloaded models)
            checkpoint_info = None

            # If loaded, get it from the orchestrator
            if is_loaded and self.orchestrator:
                if model_name == 'Transformer' and hasattr(self.orchestrator, 'transformer_checkpoint_info'):
                    cp_info = self.orchestrator.transformer_checkpoint_info
                    if cp_info and cp_info.get('status') == 'loaded':
                        checkpoint_info = {
                            'filename': cp_info.get('filename', 'unknown'),
                            'epoch': cp_info.get('epoch', 0),
                            'loss': cp_info.get('loss', 0.0),
                            'accuracy': cp_info.get('accuracy', 0.0),
                            'loaded_at': cp_info.get('loaded_at', ''),
                            'source': 'loaded'
                        }

            # If not loaded, try to read the best checkpoint from disk
            if not checkpoint_info:
                checkpoint_info = self._get_best_checkpoint_info(model_name)
                if checkpoint_info:
                    checkpoint_info['source'] = 'disk'
            model_states.append({
                'name': model_name,
                'loaded': is_loaded,
                'can_train': is_loaded,
                'can_infer': is_loaded,
                'checkpoint': checkpoint_info  # Checkpoint metadata (loaded or from disk)
            })

        return jsonify({
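
Note: a sketch of one entry in model_states after this change (values illustrative; the 'checkpoint' block comes from _get_best_checkpoint_info when the model is not loaded):

    {
        'name': 'Transformer',
        'loaded': False,
        'can_train': False,
        'can_infer': False,
        'checkpoint': {
            'filename': 'checkpoint_epoch12.pt',  # hypothetical filename
            'epoch': 12,
            'loss': 0.42,
            'accuracy': 0.87,
            'learning_rate': 0.0001,
            'source': 'disk'  # 'loaded' when the info came from the orchestrator
        }
    }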