fix HTML
This commit is contained in:
@@ -193,6 +193,7 @@ class AnnotationDashboard:
|
||||
def _get_best_checkpoint_info(self, model_name: str) -> Optional[Dict]:
|
||||
"""
|
||||
Get best checkpoint info for a model without loading it
|
||||
Uses filename parsing instead of torch.load to avoid crashes
|
||||
|
||||
Args:
|
||||
model_name: Name of the model
|
||||
@@ -201,8 +202,8 @@ class AnnotationDashboard:
|
||||
Dict with checkpoint info or None if no checkpoint found
|
||||
"""
|
||||
try:
|
||||
import torch
|
||||
import glob
|
||||
import re
|
||||
|
||||
# Map model names to checkpoint directories
|
||||
checkpoint_dirs = {
|
||||
@@ -212,37 +213,51 @@ class AnnotationDashboard:
|
||||
}
|
||||
|
||||
checkpoint_dir = checkpoint_dirs.get(model_name)
|
||||
if not checkpoint_dir or not os.path.exists(checkpoint_dir):
|
||||
if not checkpoint_dir:
|
||||
return None
|
||||
|
||||
if not os.path.exists(checkpoint_dir):
|
||||
logger.debug(f"Checkpoint directory not found: {checkpoint_dir}")
|
||||
return None
|
||||
|
||||
# Find all checkpoint files
|
||||
checkpoint_files = glob.glob(os.path.join(checkpoint_dir, '*.pt'))
|
||||
if not checkpoint_files:
|
||||
logger.debug(f"No checkpoint files found in {checkpoint_dir}")
|
||||
return None
|
||||
|
||||
# Load metadata from each checkpoint and find best
|
||||
logger.debug(f"Found {len(checkpoint_files)} checkpoints for {model_name}")
|
||||
|
||||
# Parse filenames to extract epoch info
|
||||
# Format: transformer_epoch5_20251110_123620.pt
|
||||
best_checkpoint = None
|
||||
best_accuracy = -1
|
||||
best_epoch = -1
|
||||
|
||||
for cp_file in checkpoint_files:
|
||||
try:
|
||||
# Load only metadata (map_location='cpu' to avoid GPU)
|
||||
checkpoint = torch.load(cp_file, map_location='cpu')
|
||||
filename = os.path.basename(cp_file)
|
||||
|
||||
accuracy = checkpoint.get('accuracy', 0.0)
|
||||
if accuracy > best_accuracy:
|
||||
best_accuracy = accuracy
|
||||
best_checkpoint = {
|
||||
'filename': os.path.basename(cp_file),
|
||||
'epoch': checkpoint.get('epoch', 0),
|
||||
'loss': checkpoint.get('loss', 0.0),
|
||||
'accuracy': accuracy,
|
||||
'learning_rate': checkpoint.get('learning_rate', 0.0)
|
||||
}
|
||||
# Extract epoch number from filename
|
||||
match = re.search(r'epoch(\d+)', filename, re.IGNORECASE)
|
||||
if match:
|
||||
epoch = int(match.group(1))
|
||||
if epoch > best_epoch:
|
||||
best_epoch = epoch
|
||||
best_checkpoint = {
|
||||
'filename': filename,
|
||||
'epoch': epoch,
|
||||
'loss': None, # Can't get without loading
|
||||
'accuracy': None, # Can't get without loading
|
||||
'source': 'filename'
|
||||
}
|
||||
logger.debug(f"Found checkpoint: {filename}, epoch {epoch}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not load checkpoint {cp_file}: {e}")
|
||||
logger.debug(f"Could not parse checkpoint {cp_file}: {e}")
|
||||
continue
|
||||
|
||||
if best_checkpoint:
|
||||
logger.info(f"Best checkpoint for {model_name}: {best_checkpoint['filename']} (E{best_checkpoint['epoch']})")
|
||||
|
||||
return best_checkpoint
|
||||
|
||||
except Exception as e:
|
||||
@@ -1305,15 +1320,16 @@ class AnnotationDashboard:
|
||||
'source': 'loaded'
|
||||
}
|
||||
|
||||
# If not loaded, try to read best checkpoint from disk
|
||||
# If not loaded, try to read best checkpoint from disk (filename parsing only)
|
||||
if not checkpoint_info:
|
||||
try:
|
||||
checkpoint_info = self._get_best_checkpoint_info(model_name)
|
||||
if checkpoint_info:
|
||||
cp_info = self._get_best_checkpoint_info(model_name)
|
||||
if cp_info:
|
||||
checkpoint_info = cp_info
|
||||
checkpoint_info['source'] = 'disk'
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading checkpoint for {model_name}: {e}")
|
||||
# Continue without checkpoint info
|
||||
logger.warning(f"Could not read checkpoint for {model_name}: {e}")
|
||||
# Continue without checkpoint info - not critical
|
||||
|
||||
model_states.append({
|
||||
'name': model_name,
|
||||
|
||||
Reference in New Issue
Block a user