fix model CP
@@ -245,7 +245,9 @@ class AnnotationDashboard:
             return best_checkpoint
 
         except Exception as e:
-            logger.debug(f"Error getting checkpoint info for {model_name}: {e}")
+            logger.error(f"Error getting checkpoint info for {model_name}: {e}")
+            import traceback
+            logger.error(traceback.format_exc())
             return None
 
     def _load_model_lazy(self, model_name: str) -> dict:
@@ -1280,6 +1282,7 @@ class AnnotationDashboard:
             # Don't call training_adapter.get_available_models() as it may return objects
 
             # Build model state dict with checkpoint info
+            logger.info(f"Building model states for {len(self.available_models)} models: {self.available_models}")
             model_states = []
             for model_name in self.available_models:
                 is_loaded = model_name in self.loaded_models
@@ -1303,9 +1306,13 @@ class AnnotationDashboard:
 
                 # If not loaded, try to read best checkpoint from disk
                 if not checkpoint_info:
-                    checkpoint_info = self._get_best_checkpoint_info(model_name)
-                    if checkpoint_info:
-                        checkpoint_info['source'] = 'disk'
+                    try:
+                        checkpoint_info = self._get_best_checkpoint_info(model_name)
+                        if checkpoint_info:
+                            checkpoint_info['source'] = 'disk'
+                    except Exception as e:
+                        logger.error(f"Error reading checkpoint for {model_name}: {e}")
+                        # Continue without checkpoint info
 
                 model_states.append({
                     'name': model_name,
@@ -1315,6 +1322,7 @@ class AnnotationDashboard:
                     'checkpoint': checkpoint_info  # Checkpoint metadata (loaded or from disk)
                 })
 
+            logger.info(f"Returning {len(model_states)} model states")
            return jsonify({
                'success': True,
                'models': model_states,
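Note: the body of _get_best_checkpoint_info is not shown in this diff, only its error handling. A minimal sketch of the selection logic it is assumed to perform, mirroring the standalone test script added below (scan a model's checkpoint directory for *.pt files and keep the metadata of the highest-accuracy one), could look like the following. The directory layout and metadata keys ('epoch', 'loss', 'accuracy') are taken from the test script; the actual method body may differ.

    import glob
    import logging
    import os
    from typing import Optional

    import torch

    logger = logging.getLogger(__name__)


    def get_best_checkpoint_info(checkpoint_dir: str) -> Optional[dict]:
        """Return metadata for the highest-accuracy checkpoint in checkpoint_dir, or None."""
        try:
            best, best_accuracy = None, -1.0
            for cp_file in glob.glob(os.path.join(checkpoint_dir, '*.pt')):
                # Checkpoints are assumed to be torch.save() dicts carrying metadata keys.
                checkpoint = torch.load(cp_file, map_location='cpu')
                accuracy = checkpoint.get('accuracy', 0.0)
                if accuracy > best_accuracy:
                    best_accuracy = accuracy
                    best = {
                        'filename': os.path.basename(cp_file),
                        'epoch': checkpoint.get('epoch', 0),
                        'loss': checkpoint.get('loss', 0.0),
                        'accuracy': accuracy,
                    }
            return best
        except Exception as e:
            # Log at error level so failures surface in the dashboard logs.
            logger.error(f"Error getting checkpoint info: {e}")
            return None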
test_checkpoint_reading.py (new file, 67 lines)
@@ -0,0 +1,67 @@
+#!/usr/bin/env python3
+"""Test checkpoint reading without loading models"""
+
+import os
+import glob
+import torch
+
+def test_checkpoint_reading():
+    """Test reading checkpoint info from disk"""
+
+    checkpoint_dirs = {
+        'Transformer': 'models/checkpoints/transformer',
+        'CNN': 'models/checkpoints/enhanced_cnn',
+        'DQN': 'models/checkpoints/dqn_agent'
+    }
+
+    for model_name, checkpoint_dir in checkpoint_dirs.items():
+        print(f"\n{model_name}:")
+        print(f"  Directory: {checkpoint_dir}")
+        print(f"  Exists: {os.path.exists(checkpoint_dir)}")
+
+        if not os.path.exists(checkpoint_dir):
+            print(f"  ❌ Directory not found")
+            continue
+
+        # Find checkpoint files
+        checkpoint_files = glob.glob(os.path.join(checkpoint_dir, '*.pt'))
+        print(f"  Checkpoints found: {len(checkpoint_files)}")
+
+        if not checkpoint_files:
+            print(f"  ❌ No checkpoint files")
+            continue
+
+        # Try to load best checkpoint
+        best_checkpoint = None
+        best_accuracy = -1
+
+        for cp_file in checkpoint_files:
+            try:
+                print(f"  Loading: {os.path.basename(cp_file)}")
+                checkpoint = torch.load(cp_file, map_location='cpu')
+
+                epoch = checkpoint.get('epoch', 0)
+                loss = checkpoint.get('loss', 0.0)
+                accuracy = checkpoint.get('accuracy', 0.0)
+
+                print(f"    Epoch: {epoch}, Loss: {loss:.6f}, Accuracy: {accuracy:.2%}")
+
+                if accuracy > best_accuracy:
+                    best_accuracy = accuracy
+                    best_checkpoint = {
+                        'filename': os.path.basename(cp_file),
+                        'epoch': epoch,
+                        'loss': loss,
+                        'accuracy': accuracy
+                    }
+            except Exception as e:
+                print(f"  ❌ Error: {e}")
+
+        if best_checkpoint:
+            print(f"  ✅ Best: {best_checkpoint['filename']}")
+            print(f"     Epoch: {best_checkpoint['epoch']}, Loss: {best_checkpoint['loss']:.6f}, Accuracy: {best_checkpoint['accuracy']:.2%}")
+        else:
+            print(f"  ❌ No valid checkpoint found")
+
+if __name__ == '__main__':
+    test_checkpoint_reading()
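The script is meant to be run directly, e.g. python test_checkpoint_reading.py, from wherever the models/checkpoints/ directories live (the paths above are relative). It prints one summary block per model and loads checkpoints on CPU only, so it needs neither a GPU nor any of the dashboard code.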