checkpoint info in dropdowns
This commit is contained in:
@@ -189,6 +189,65 @@ class AnnotationDashboard:
|
|||||||
|
|
||||||
logger.info("Annotation Dashboard initialized")
|
logger.info("Annotation Dashboard initialized")
|
||||||
|
|
||||||
|
def _get_best_checkpoint_info(self, model_name: str) -> Optional[Dict]:
|
||||||
|
"""
|
||||||
|
Get best checkpoint info for a model without loading it
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Name of the model
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with checkpoint info or None if no checkpoint found
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
import glob
|
||||||
|
|
||||||
|
# Map model names to checkpoint directories
|
||||||
|
checkpoint_dirs = {
|
||||||
|
'Transformer': 'models/checkpoints/transformer',
|
||||||
|
'CNN': 'models/checkpoints/enhanced_cnn',
|
||||||
|
'DQN': 'models/checkpoints/dqn_agent'
|
||||||
|
}
|
||||||
|
|
||||||
|
checkpoint_dir = checkpoint_dirs.get(model_name)
|
||||||
|
if not checkpoint_dir or not os.path.exists(checkpoint_dir):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Find all checkpoint files
|
||||||
|
checkpoint_files = glob.glob(os.path.join(checkpoint_dir, '*.pt'))
|
||||||
|
if not checkpoint_files:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Load metadata from each checkpoint and find best
|
||||||
|
best_checkpoint = None
|
||||||
|
best_accuracy = -1
|
||||||
|
|
||||||
|
for cp_file in checkpoint_files:
|
||||||
|
try:
|
||||||
|
# Load only metadata (map_location='cpu' to avoid GPU)
|
||||||
|
checkpoint = torch.load(cp_file, map_location='cpu')
|
||||||
|
|
||||||
|
accuracy = checkpoint.get('accuracy', 0.0)
|
||||||
|
if accuracy > best_accuracy:
|
||||||
|
best_accuracy = accuracy
|
||||||
|
best_checkpoint = {
|
||||||
|
'filename': os.path.basename(cp_file),
|
||||||
|
'epoch': checkpoint.get('epoch', 0),
|
||||||
|
'loss': checkpoint.get('loss', 0.0),
|
||||||
|
'accuracy': accuracy,
|
||||||
|
'learning_rate': checkpoint.get('learning_rate', 0.0)
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Could not load checkpoint {cp_file}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
return best_checkpoint
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error getting checkpoint info for {model_name}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
def _load_model_lazy(self, model_name: str) -> dict:
|
def _load_model_lazy(self, model_name: str) -> dict:
|
||||||
"""
|
"""
|
||||||
Lazy load a specific model on demand
|
Lazy load a specific model on demand
|
||||||
@@ -1220,15 +1279,40 @@ class AnnotationDashboard:
|
|||||||
# Use self.available_models which is a simple list of strings
|
# Use self.available_models which is a simple list of strings
|
||||||
# Don't call training_adapter.get_available_models() as it may return objects
|
# Don't call training_adapter.get_available_models() as it may return objects
|
||||||
|
|
||||||
# Build model state dict
|
# Build model state dict with checkpoint info
|
||||||
model_states = []
|
model_states = []
|
||||||
for model_name in self.available_models:
|
for model_name in self.available_models:
|
||||||
is_loaded = model_name in self.loaded_models
|
is_loaded = model_name in self.loaded_models
|
||||||
|
|
||||||
|
# Get checkpoint info (even for unloaded models)
|
||||||
|
checkpoint_info = None
|
||||||
|
|
||||||
|
# If loaded, get from orchestrator
|
||||||
|
if is_loaded and self.orchestrator:
|
||||||
|
if model_name == 'Transformer' and hasattr(self.orchestrator, 'transformer_checkpoint_info'):
|
||||||
|
cp_info = self.orchestrator.transformer_checkpoint_info
|
||||||
|
if cp_info and cp_info.get('status') == 'loaded':
|
||||||
|
checkpoint_info = {
|
||||||
|
'filename': cp_info.get('filename', 'unknown'),
|
||||||
|
'epoch': cp_info.get('epoch', 0),
|
||||||
|
'loss': cp_info.get('loss', 0.0),
|
||||||
|
'accuracy': cp_info.get('accuracy', 0.0),
|
||||||
|
'loaded_at': cp_info.get('loaded_at', ''),
|
||||||
|
'source': 'loaded'
|
||||||
|
}
|
||||||
|
|
||||||
|
# If not loaded, try to read best checkpoint from disk
|
||||||
|
if not checkpoint_info:
|
||||||
|
checkpoint_info = self._get_best_checkpoint_info(model_name)
|
||||||
|
if checkpoint_info:
|
||||||
|
checkpoint_info['source'] = 'disk'
|
||||||
|
|
||||||
model_states.append({
|
model_states.append({
|
||||||
'name': model_name,
|
'name': model_name,
|
||||||
'loaded': is_loaded,
|
'loaded': is_loaded,
|
||||||
'can_train': is_loaded,
|
'can_train': is_loaded,
|
||||||
'can_infer': is_loaded
|
'can_infer': is_loaded,
|
||||||
|
'checkpoint': checkpoint_info # Checkpoint metadata (loaded or from disk)
|
||||||
})
|
})
|
||||||
|
|
||||||
return jsonify({
|
return jsonify({
|
||||||
|
|||||||
@@ -127,20 +127,36 @@
|
|||||||
placeholder.textContent = 'Select a model...';
|
placeholder.textContent = 'Select a model...';
|
||||||
modelSelect.appendChild(placeholder);
|
modelSelect.appendChild(placeholder);
|
||||||
|
|
||||||
// Add model options with load status
|
// Add model options with load status and checkpoint info
|
||||||
data.models.forEach((model, index) => {
|
data.models.forEach((model, index) => {
|
||||||
console.log(` Model ${index}:`, model, 'Type:', typeof model);
|
console.log(` Model ${index}:`, model, 'Type:', typeof model);
|
||||||
|
|
||||||
// Ensure model is an object with name property
|
// Ensure model is an object with name property
|
||||||
const modelName = (model && typeof model === 'object' && model.name) ? model.name : String(model);
|
const modelName = (model && typeof model === 'object' && model.name) ? model.name : String(model);
|
||||||
const isLoaded = (model && typeof model === 'object' && 'loaded' in model) ? model.loaded : false;
|
const isLoaded = (model && typeof model === 'object' && 'loaded' in model) ? model.loaded : false;
|
||||||
|
const checkpoint = (model && typeof model === 'object' && model.checkpoint) ? model.checkpoint : null;
|
||||||
|
|
||||||
console.log(` → Name: "${modelName}", Loaded: ${isLoaded}`);
|
console.log(` → Name: "${modelName}", Loaded: ${isLoaded}`, checkpoint ? `Checkpoint: epoch ${checkpoint.epoch}, loss ${checkpoint.loss.toFixed(4)}` : '');
|
||||||
|
|
||||||
const option = document.createElement('option');
|
const option = document.createElement('option');
|
||||||
option.value = modelName;
|
option.value = modelName;
|
||||||
option.textContent = modelName + (isLoaded ? ' ✓' : ' (not loaded)');
|
|
||||||
|
// Build option text with checkpoint info
|
||||||
|
let optionText = modelName;
|
||||||
|
if (isLoaded) {
|
||||||
|
optionText += ' ✓';
|
||||||
|
if (checkpoint) {
|
||||||
|
optionText += ` (E${checkpoint.epoch}, L:${checkpoint.loss.toFixed(3)}, A:${(checkpoint.accuracy * 100).toFixed(1)}%)`;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
optionText += ' (not loaded)';
|
||||||
|
}
|
||||||
|
|
||||||
|
option.textContent = optionText;
|
||||||
option.dataset.loaded = isLoaded;
|
option.dataset.loaded = isLoaded;
|
||||||
|
if (checkpoint) {
|
||||||
|
option.dataset.checkpoint = JSON.stringify(checkpoint);
|
||||||
|
}
|
||||||
modelSelect.appendChild(option);
|
modelSelect.appendChild(option);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -785,15 +785,27 @@ class TradingOrchestrator:
|
|||||||
checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
||||||
self.primary_transformer.load_state_dict(checkpoint.get('model_state_dict', checkpoint))
|
self.primary_transformer.load_state_dict(checkpoint.get('model_state_dict', checkpoint))
|
||||||
|
|
||||||
# Update checkpoint info
|
# Extract checkpoint metrics for display
|
||||||
|
epoch = checkpoint.get('epoch', 0)
|
||||||
|
loss = checkpoint.get('loss', 0.0)
|
||||||
|
accuracy = checkpoint.get('accuracy', 0.0)
|
||||||
|
learning_rate = checkpoint.get('learning_rate', 0.0)
|
||||||
|
|
||||||
|
# Update checkpoint info with detailed metrics
|
||||||
self.transformer_checkpoint_info = {
|
self.transformer_checkpoint_info = {
|
||||||
'path': checkpoint_path,
|
'path': checkpoint_path,
|
||||||
|
'filename': os.path.basename(checkpoint_path),
|
||||||
'metadata': checkpoint_metadata,
|
'metadata': checkpoint_metadata,
|
||||||
'loaded_at': datetime.now().isoformat()
|
'loaded_at': datetime.now().isoformat(),
|
||||||
|
'epoch': epoch,
|
||||||
|
'loss': loss,
|
||||||
|
'accuracy': accuracy,
|
||||||
|
'learning_rate': learning_rate,
|
||||||
|
'status': 'loaded'
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(f"Transformer checkpoint loaded from: {checkpoint_path}")
|
logger.info(f"✅ Loaded transformer checkpoint: {os.path.basename(checkpoint_path)}")
|
||||||
logger.info(f"Checkpoint metrics: {checkpoint_metadata.get('performance_metrics', {})}")
|
logger.info(f" Epoch: {epoch}, Loss: {loss:.6f}, Accuracy: {accuracy:.2%}, LR: {learning_rate:.6f}")
|
||||||
checkpoint_loaded = True
|
checkpoint_loaded = True
|
||||||
else:
|
else:
|
||||||
logger.info("No transformer checkpoint found - using fresh model")
|
logger.info("No transformer checkpoint found - using fresh model")
|
||||||
|
|||||||
Reference in New Issue
Block a user