diff --git a/ANNOTATE/web/app.py b/ANNOTATE/web/app.py index 2e44925..e8b0cc3 100644 --- a/ANNOTATE/web/app.py +++ b/ANNOTATE/web/app.py @@ -189,6 +189,65 @@ class AnnotationDashboard: logger.info("Annotation Dashboard initialized") + def _get_best_checkpoint_info(self, model_name: str) -> Optional[Dict]: + """ + Get best checkpoint info for a model without loading it + + Args: + model_name: Name of the model + + Returns: + Dict with checkpoint info or None if no checkpoint found + """ + try: + import torch + import glob + + # Map model names to checkpoint directories + checkpoint_dirs = { + 'Transformer': 'models/checkpoints/transformer', + 'CNN': 'models/checkpoints/enhanced_cnn', + 'DQN': 'models/checkpoints/dqn_agent' + } + + checkpoint_dir = checkpoint_dirs.get(model_name) + if not checkpoint_dir or not os.path.exists(checkpoint_dir): + return None + + # Find all checkpoint files + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, '*.pt')) + if not checkpoint_files: + return None + + # Load metadata from each checkpoint and find best + best_checkpoint = None + best_accuracy = -1 + + for cp_file in checkpoint_files: + try: + # Load only metadata (map_location='cpu' to avoid GPU) + checkpoint = torch.load(cp_file, map_location='cpu') + + accuracy = checkpoint.get('accuracy', 0.0) + if accuracy > best_accuracy: + best_accuracy = accuracy + best_checkpoint = { + 'filename': os.path.basename(cp_file), + 'epoch': checkpoint.get('epoch', 0), + 'loss': checkpoint.get('loss', 0.0), + 'accuracy': accuracy, + 'learning_rate': checkpoint.get('learning_rate', 0.0) + } + except Exception as e: + logger.debug(f"Could not load checkpoint {cp_file}: {e}") + continue + + return best_checkpoint + + except Exception as e: + logger.debug(f"Error getting checkpoint info for {model_name}: {e}") + return None + def _load_model_lazy(self, model_name: str) -> dict: """ Lazy load a specific model on demand @@ -1220,15 +1279,40 @@ class AnnotationDashboard: # Use self.available_models which is a simple list of strings # Don't call training_adapter.get_available_models() as it may return objects - # Build model state dict + # Build model state dict with checkpoint info model_states = [] for model_name in self.available_models: is_loaded = model_name in self.loaded_models + + # Get checkpoint info (even for unloaded models) + checkpoint_info = None + + # If loaded, get from orchestrator + if is_loaded and self.orchestrator: + if model_name == 'Transformer' and hasattr(self.orchestrator, 'transformer_checkpoint_info'): + cp_info = self.orchestrator.transformer_checkpoint_info + if cp_info and cp_info.get('status') == 'loaded': + checkpoint_info = { + 'filename': cp_info.get('filename', 'unknown'), + 'epoch': cp_info.get('epoch', 0), + 'loss': cp_info.get('loss', 0.0), + 'accuracy': cp_info.get('accuracy', 0.0), + 'loaded_at': cp_info.get('loaded_at', ''), + 'source': 'loaded' + } + + # If not loaded, try to read best checkpoint from disk + if not checkpoint_info: + checkpoint_info = self._get_best_checkpoint_info(model_name) + if checkpoint_info: + checkpoint_info['source'] = 'disk' + model_states.append({ 'name': model_name, 'loaded': is_loaded, 'can_train': is_loaded, - 'can_infer': is_loaded + 'can_infer': is_loaded, + 'checkpoint': checkpoint_info # Checkpoint metadata (loaded or from disk) }) return jsonify({ diff --git a/ANNOTATE/web/templates/components/training_panel.html b/ANNOTATE/web/templates/components/training_panel.html index c47074e..ffa2e57 100644 --- a/ANNOTATE/web/templates/components/training_panel.html +++ b/ANNOTATE/web/templates/components/training_panel.html @@ -127,20 +127,36 @@ placeholder.textContent = 'Select a model...'; modelSelect.appendChild(placeholder); - // Add model options with load status + // Add model options with load status and checkpoint info data.models.forEach((model, index) => { console.log(` Model ${index}:`, model, 'Type:', typeof model); // Ensure model is an object with name property const modelName = (model && typeof model === 'object' && model.name) ? model.name : String(model); const isLoaded = (model && typeof model === 'object' && 'loaded' in model) ? model.loaded : false; + const checkpoint = (model && typeof model === 'object' && model.checkpoint) ? model.checkpoint : null; - console.log(` → Name: "${modelName}", Loaded: ${isLoaded}`); + console.log(` → Name: "${modelName}", Loaded: ${isLoaded}`, checkpoint ? `Checkpoint: epoch ${checkpoint.epoch}, loss ${checkpoint.loss.toFixed(4)}` : ''); const option = document.createElement('option'); option.value = modelName; - option.textContent = modelName + (isLoaded ? ' ✓' : ' (not loaded)'); + + // Build option text with checkpoint info + let optionText = modelName; + if (isLoaded) { + optionText += ' ✓'; + if (checkpoint) { + optionText += ` (E${checkpoint.epoch}, L:${checkpoint.loss.toFixed(3)}, A:${(checkpoint.accuracy * 100).toFixed(1)}%)`; + } + } else { + optionText += ' (not loaded)'; + } + + option.textContent = optionText; option.dataset.loaded = isLoaded; + if (checkpoint) { + option.dataset.checkpoint = JSON.stringify(checkpoint); + } modelSelect.appendChild(option); }); diff --git a/core/orchestrator.py b/core/orchestrator.py index c11e438..d177307 100644 --- a/core/orchestrator.py +++ b/core/orchestrator.py @@ -785,15 +785,27 @@ class TradingOrchestrator: checkpoint = torch.load(checkpoint_path, map_location=self.device) self.primary_transformer.load_state_dict(checkpoint.get('model_state_dict', checkpoint)) - # Update checkpoint info + # Extract checkpoint metrics for display + epoch = checkpoint.get('epoch', 0) + loss = checkpoint.get('loss', 0.0) + accuracy = checkpoint.get('accuracy', 0.0) + learning_rate = checkpoint.get('learning_rate', 0.0) + + # Update checkpoint info with detailed metrics self.transformer_checkpoint_info = { 'path': checkpoint_path, + 'filename': os.path.basename(checkpoint_path), 'metadata': checkpoint_metadata, - 'loaded_at': datetime.now().isoformat() + 'loaded_at': datetime.now().isoformat(), + 'epoch': epoch, + 'loss': loss, + 'accuracy': accuracy, + 'learning_rate': learning_rate, + 'status': 'loaded' } - logger.info(f"Transformer checkpoint loaded from: {checkpoint_path}") - logger.info(f"Checkpoint metrics: {checkpoint_metadata.get('performance_metrics', {})}") + logger.info(f"✅ Loaded transformer checkpoint: {os.path.basename(checkpoint_path)}") + logger.info(f" Epoch: {epoch}, Loss: {loss:.6f}, Accuracy: {accuracy:.2%}, LR: {learning_rate:.6f}") checkpoint_loaded = True else: logger.info("No transformer checkpoint found - using fresh model")