Show best checkpoint info (epoch, loss, accuracy) in model dropdowns

This commit is contained in:
Dobromir Popov
2025-11-11 11:20:23 +02:00
parent 27039c70a3
commit 6c1ca8baf4
3 changed files with 121 additions and 9 deletions

View File

@@ -189,6 +189,65 @@ class AnnotationDashboard:
logger.info("Annotation Dashboard initialized")
def _get_best_checkpoint_info(self, model_name: str) -> Optional[Dict]:
"""
Get best checkpoint info for a model without loading it
Args:
model_name: Name of the model
Returns:
Dict with checkpoint info or None if no checkpoint found
"""
try:
import torch
import glob
# Map model names to checkpoint directories
checkpoint_dirs = {
'Transformer': 'models/checkpoints/transformer',
'CNN': 'models/checkpoints/enhanced_cnn',
'DQN': 'models/checkpoints/dqn_agent'
}
checkpoint_dir = checkpoint_dirs.get(model_name)
if not checkpoint_dir or not os.path.exists(checkpoint_dir):
return None
# Find all checkpoint files
checkpoint_files = glob.glob(os.path.join(checkpoint_dir, '*.pt'))
if not checkpoint_files:
return None
# Load metadata from each checkpoint and find best
best_checkpoint = None
best_accuracy = -1
for cp_file in checkpoint_files:
try:
# Load only metadata (map_location='cpu' to avoid GPU)
checkpoint = torch.load(cp_file, map_location='cpu')
accuracy = checkpoint.get('accuracy', 0.0)
if accuracy > best_accuracy:
best_accuracy = accuracy
best_checkpoint = {
'filename': os.path.basename(cp_file),
'epoch': checkpoint.get('epoch', 0),
'loss': checkpoint.get('loss', 0.0),
'accuracy': accuracy,
'learning_rate': checkpoint.get('learning_rate', 0.0)
}
except Exception as e:
logger.debug(f"Could not load checkpoint {cp_file}: {e}")
continue
return best_checkpoint
except Exception as e:
logger.debug(f"Error getting checkpoint info for {model_name}: {e}")
return None
def _load_model_lazy(self, model_name: str) -> dict:
"""
Lazy load a specific model on demand
@@ -1220,15 +1279,40 @@ class AnnotationDashboard:
# Use self.available_models which is a simple list of strings
# Don't call training_adapter.get_available_models() as it may return objects
# Build model state dict
# Build model state dict with checkpoint info
model_states = []
for model_name in self.available_models:
is_loaded = model_name in self.loaded_models
# Get checkpoint info (even for unloaded models)
checkpoint_info = None
# If loaded, get from orchestrator
if is_loaded and self.orchestrator:
if model_name == 'Transformer' and hasattr(self.orchestrator, 'transformer_checkpoint_info'):
cp_info = self.orchestrator.transformer_checkpoint_info
if cp_info and cp_info.get('status') == 'loaded':
checkpoint_info = {
'filename': cp_info.get('filename', 'unknown'),
'epoch': cp_info.get('epoch', 0),
'loss': cp_info.get('loss', 0.0),
'accuracy': cp_info.get('accuracy', 0.0),
'loaded_at': cp_info.get('loaded_at', ''),
'source': 'loaded'
}
# If not loaded, try to read best checkpoint from disk
if not checkpoint_info:
checkpoint_info = self._get_best_checkpoint_info(model_name)
if checkpoint_info:
checkpoint_info['source'] = 'disk'
model_states.append({
'name': model_name,
'loaded': is_loaded,
'can_train': is_loaded,
'can_infer': is_loaded
'can_infer': is_loaded,
'checkpoint': checkpoint_info # Checkpoint metadata (loaded or from disk)
})
return jsonify({

View File

@@ -127,20 +127,36 @@
placeholder.textContent = 'Select a model...';
modelSelect.appendChild(placeholder);
// Add model options with load status
// Add model options with load status and checkpoint info
data.models.forEach((model, index) => {
console.log(` Model ${index}:`, model, 'Type:', typeof model);
// Ensure model is an object with name property
const modelName = (model && typeof model === 'object' && model.name) ? model.name : String(model);
const isLoaded = (model && typeof model === 'object' && 'loaded' in model) ? model.loaded : false;
const checkpoint = (model && typeof model === 'object' && model.checkpoint) ? model.checkpoint : null;
console.log(` → Name: "${modelName}", Loaded: ${isLoaded}`);
console.log(` → Name: "${modelName}", Loaded: ${isLoaded}`, checkpoint ? `Checkpoint: epoch ${checkpoint.epoch}, loss ${checkpoint.loss.toFixed(4)}` : '');
const option = document.createElement('option');
option.value = modelName;
option.textContent = modelName + (isLoaded ? ' ✓' : ' (not loaded)');
// Build option text with checkpoint info
let optionText = modelName;
if (isLoaded) {
optionText += ' ✓';
if (checkpoint) {
optionText += ` (E${checkpoint.epoch}, L:${checkpoint.loss.toFixed(3)}, A:${(checkpoint.accuracy * 100).toFixed(1)}%)`;
}
} else {
optionText += ' (not loaded)';
}
option.textContent = optionText;
option.dataset.loaded = isLoaded;
if (checkpoint) {
option.dataset.checkpoint = JSON.stringify(checkpoint);
}
modelSelect.appendChild(option);
});

View File

@@ -785,15 +785,27 @@ class TradingOrchestrator:
checkpoint = torch.load(checkpoint_path, map_location=self.device)
self.primary_transformer.load_state_dict(checkpoint.get('model_state_dict', checkpoint))
# Update checkpoint info
# Extract checkpoint metrics for display
epoch = checkpoint.get('epoch', 0)
loss = checkpoint.get('loss', 0.0)
accuracy = checkpoint.get('accuracy', 0.0)
learning_rate = checkpoint.get('learning_rate', 0.0)
# Update checkpoint info with detailed metrics
self.transformer_checkpoint_info = {
'path': checkpoint_path,
'filename': os.path.basename(checkpoint_path),
'metadata': checkpoint_metadata,
'loaded_at': datetime.now().isoformat()
'loaded_at': datetime.now().isoformat(),
'epoch': epoch,
'loss': loss,
'accuracy': accuracy,
'learning_rate': learning_rate,
'status': 'loaded'
}
logger.info(f"Transformer checkpoint loaded from: {checkpoint_path}")
logger.info(f"Checkpoint metrics: {checkpoint_metadata.get('performance_metrics', {})}")
logger.info(f"✅ Loaded transformer checkpoint: {os.path.basename(checkpoint_path)}")
logger.info(f" Epoch: {epoch}, Loss: {loss:.6f}, Accuracy: {accuracy:.2%}, LR: {learning_rate:.6f}")
checkpoint_loaded = True
else:
logger.info("No transformer checkpoint found - using fresh model")