fix model CP
This commit is contained in:
67
test_checkpoint_reading.py
Normal file
67
test_checkpoint_reading.py
Normal file
@@ -0,0 +1,67 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Test checkpoint reading without loading models"""
|
||||
|
||||
import os
|
||||
import glob
|
||||
import torch
|
||||
|
||||
def test_checkpoint_reading():
|
||||
"""Test reading checkpoint info from disk"""
|
||||
|
||||
checkpoint_dirs = {
|
||||
'Transformer': 'models/checkpoints/transformer',
|
||||
'CNN': 'models/checkpoints/enhanced_cnn',
|
||||
'DQN': 'models/checkpoints/dqn_agent'
|
||||
}
|
||||
|
||||
for model_name, checkpoint_dir in checkpoint_dirs.items():
|
||||
print(f"\n{model_name}:")
|
||||
print(f" Directory: {checkpoint_dir}")
|
||||
print(f" Exists: {os.path.exists(checkpoint_dir)}")
|
||||
|
||||
if not os.path.exists(checkpoint_dir):
|
||||
print(f" ❌ Directory not found")
|
||||
continue
|
||||
|
||||
# Find checkpoint files
|
||||
checkpoint_files = glob.glob(os.path.join(checkpoint_dir, '*.pt'))
|
||||
print(f" Checkpoints found: {len(checkpoint_files)}")
|
||||
|
||||
if not checkpoint_files:
|
||||
print(f" ❌ No checkpoint files")
|
||||
continue
|
||||
|
||||
# Try to load best checkpoint
|
||||
best_checkpoint = None
|
||||
best_accuracy = -1
|
||||
|
||||
for cp_file in checkpoint_files:
|
||||
try:
|
||||
print(f" Loading: {os.path.basename(cp_file)}")
|
||||
checkpoint = torch.load(cp_file, map_location='cpu')
|
||||
|
||||
epoch = checkpoint.get('epoch', 0)
|
||||
loss = checkpoint.get('loss', 0.0)
|
||||
accuracy = checkpoint.get('accuracy', 0.0)
|
||||
|
||||
print(f" Epoch: {epoch}, Loss: {loss:.6f}, Accuracy: {accuracy:.2%}")
|
||||
|
||||
if accuracy > best_accuracy:
|
||||
best_accuracy = accuracy
|
||||
best_checkpoint = {
|
||||
'filename': os.path.basename(cp_file),
|
||||
'epoch': epoch,
|
||||
'loss': loss,
|
||||
'accuracy': accuracy
|
||||
}
|
||||
except Exception as e:
|
||||
print(f" ❌ Error: {e}")
|
||||
|
||||
if best_checkpoint:
|
||||
print(f" ✅ Best: {best_checkpoint['filename']}")
|
||||
print(f" Epoch: {best_checkpoint['epoch']}, Loss: {best_checkpoint['loss']:.6f}, Accuracy: {best_checkpoint['accuracy']:.2%}")
|
||||
else:
|
||||
print(f" ❌ No valid checkpoint found")
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_checkpoint_reading()
|
||||
Reference in New Issue
Block a user