#!/usr/bin/env python3
"""
Best Model Selection for Startup

This module provides intelligent model selection logic for choosing the best
available models at system startup based on various criteria.
"""

import os
import logging
import json
from pathlib import Path
from typing import Dict, Any, Optional, List, Tuple
from datetime import datetime, timedelta

import torch

from utils.model_registry import get_model_registry, load_model, load_best_checkpoint

logger = logging.getLogger(__name__)


class ModelSelector:
    """
    Intelligent model selector for startup and runtime model selection.

    Candidates come from the model registry (``registry.list_models()``, used
    here as a mapping of model name -> info dict). Each candidate is filtered
    by age, scored on recency / checkpoint count / save count, and the
    top-scoring ones are load-tested before one is returned.
    """

    # Format of the registry's ``last_saved`` timestamps (YYYYMMDD_HHMMSS).
    TIMESTAMP_FORMAT = '%Y%m%d_%H%M%S'
    # Window (in days) over which the recency score decays from 1 to 0.
    RECENCY_WINDOW_DAYS = 30
    # Number of top-scoring candidates that are load-tested when
    # ``selection_criteria`` carries no 'max_verification_attempts' entry.
    DEFAULT_VERIFICATION_ATTEMPTS = 2

    def __init__(self):
        """Initialize the model selector"""
        self.registry = get_model_registry()
        self.selection_criteria = {
            'max_age_days': 30,  # Don't use models older than 30 days
            # NOTE(review): the next two entries are not read anywhere in this
            # module; they are kept for callers that may inspect them.
            'min_performance_score': 0.5,  # Minimum acceptable performance
            'prefer_recent': True,  # Prefer recently trained models
            'fallback_to_any': True,  # Use any model if no good ones found
            # How many candidates (best first) may be load-tested per type;
            # None means "try every candidate".
            'max_verification_attempts': self.DEFAULT_VERIFICATION_ATTEMPTS,
        }
        logger.info("Model Selector initialized")

    @classmethod
    def _parse_timestamp(cls, last_saved: Any) -> Optional[datetime]:
        """
        Parse a registry ``last_saved`` value.

        Args:
            last_saved: Timestamp string in ``TIMESTAMP_FORMAT``, or any
                other value.

        Returns:
            The parsed naive datetime, or None when the value is empty,
            malformed or not a string.
        """
        if not last_saved:
            return None
        try:
            return datetime.strptime(last_saved, cls.TIMESTAMP_FORMAT)
        except (ValueError, TypeError):
            # TypeError covers non-string values (e.g. an int timestamp).
            return None

    def select_best_models_for_startup(self) -> Dict[str, Dict[str, Any]]:
        """
        Select the best available models for each type at startup.

        Returns:
            Dictionary mapping model types to selected model info
        """
        logger.info("Selecting best models for startup...")

        available_models = self.registry.list_models()

        # Group models by type; entries without a 'type' land in 'unknown'.
        models_by_type: Dict[str, List[Tuple[str, Dict]]] = {}
        for model_name, model_info in available_models.items():
            model_type = model_info.get('type', 'unknown')
            models_by_type.setdefault(model_type, []).append((model_name, model_info))

        # Select best model for each type
        selected_models: Dict[str, Dict[str, Any]] = {}
        for model_type, models in models_by_type.items():
            logger.info(
                "Selecting best %s model from %d candidates", model_type, len(models)
            )
            best_model = self._select_best_model_for_type(models, model_type)
            if best_model:
                selected_models[model_type] = best_model
                logger.info("Selected %s for %s", best_model['name'], model_type)
            else:
                logger.warning("No suitable %s model found", model_type)

        return selected_models

    def _select_best_model_for_type(self, models: List[Tuple[str, Dict]],
                                    model_type: str) -> Optional[Dict[str, Any]]:
        """
        Select the best model for a specific type.

        Candidates passing ``_meets_basic_criteria`` are scored and tried in
        descending score order until one can be loaded. At most
        ``selection_criteria['max_verification_attempts']`` candidates are
        load-tested (default 2; None tries them all).

        Args:
            models: List of (name, info) tuples
            model_type: Type of model to select

        Returns:
            Selected model information or None. When no model passes the
            basic criteria and 'fallback_to_any' is set, the unverified
            result of ``_select_fallback_model`` is returned instead.
        """
        if not models:
            return None

        candidates = [
            {
                'name': model_name,
                'info': model_info,
                'score': self._calculate_selection_score(model_name, model_info, model_type),
                'has_checkpoints': (model_info.get('checkpoint_count') or 0) > 0,
            }
            for model_name, model_info in models
            if self._meets_basic_criteria(model_info)
        ]

        if not candidates:
            if self.selection_criteria['fallback_to_any']:
                # Fallback to most recent model
                logger.info("No good %s candidates, using fallback", model_type)
                return self._select_fallback_model(models)
            return None

        # Sort by score (highest first)
        candidates.sort(key=lambda candidate: candidate['score'], reverse=True)

        limit = self.selection_criteria.get(
            'max_verification_attempts', self.DEFAULT_VERIFICATION_ATTEMPTS
        )
        if limit is not None:
            # Always test at least the top candidate.
            limit = max(int(limit), 1)

        # Try to load each candidate in turn to verify it's working.
        for rank, candidate in enumerate(candidates[:limit]):
            if self._verify_model_loadable(candidate['name'], model_type):
                if rank == 0:
                    reason = self._get_selection_reason(candidate)
                else:
                    # A better-scored candidate existed but could not be loaded.
                    reason = 'fallback_after_verification_failure'
                return {
                    'name': candidate['name'],
                    'type': model_type,
                    'info': candidate['info'],
                    'score': candidate['score'],
                    'selection_reason': reason,
                    'verified': True
                }
            logger.warning("Selected model %s failed verification", candidate['name'])

        return None

    def _meets_basic_criteria(self, model_info: Dict[str, Any]) -> bool:
        """
        Check if model meets basic selection criteria.

        Only the age is checked: a model whose ``last_saved`` timestamp is
        more than ``selection_criteria['max_age_days']`` days old is rejected.
        Models with a missing or unparseable timestamp are accepted.
        """
        last_saved = model_info.get('last_saved')
        if not last_saved:
            return True

        model_date = self._parse_timestamp(last_saved)
        if model_date is None:
            logger.warning("Could not parse timestamp: %s", last_saved)
            return True

        age_days = (datetime.now() - model_date).days
        return age_days <= self.selection_criteria['max_age_days']

    def _calculate_selection_score(self, model_name: str, model_info: Dict[str, Any],
                                   model_type: str) -> float:
        """
        Calculate selection score for a model.

        The score is the weighted sum of three components, each in [0, 1]:
        recency (weight 0.4), checkpoint count (0.3) and save count (0.3).
        ``model_name`` and ``model_type`` are accepted but not used in the
        computation.

        Returns:
            Score in the range [0.0, 1.0]
        """
        score = 0.0

        # Base score from recency (newer is better)
        model_date = self._parse_timestamp(model_info.get('last_saved'))
        if model_date is not None:
            # Clamp at 0 so a timestamp slightly in the future (clock skew)
            # cannot push the recency component above 1.
            days_old = max((datetime.now() - model_date).days, 0)
            window = self.RECENCY_WINDOW_DAYS
            recency_score = max(0, window - days_old) / float(window)
            score += recency_score * 0.4

        # Score from checkpoints (having checkpoints is good)
        checkpoint_count = model_info.get('checkpoint_count') or 0
        if checkpoint_count > 0:
            checkpoint_score = min(checkpoint_count / 10.0, 1.0)  # Max score for 10+ checkpoints
            score += checkpoint_score * 0.3

        # Score from save count (more saves might indicate stability)
        save_count = model_info.get('save_count') or 0
        if save_count > 1:
            stability_score = min(save_count / 5.0, 1.0)  # Max score for 5+ saves
            score += stability_score * 0.3

        return score

    def _select_fallback_model(self, models: List[Tuple[str, Dict]]) -> Optional[Dict[str, Any]]:
        """
        Select a fallback model when no good candidates found.

        Picks the entry with the greatest ``last_saved`` string (the
        YYYYMMDD_HHMMSS format sorts chronologically as text). The result is
        not load-tested and is marked ``'verified': False``.
        """
        if not models:
            return None

        # Missing or None timestamps sort as '' (i.e. oldest) instead of
        # raising TypeError when compared against real timestamp strings.
        sorted_models = sorted(
            models,
            key=lambda entry: str(entry[1].get('last_saved') or ''),
            reverse=True,
        )
        model_name, model_info = sorted_models[0]

        return {
            'name': model_name,
            'type': model_info.get('type', 'unknown'),
            'info': model_info,
            'score': 0.0,
            'selection_reason': 'fallback_most_recent',
            'verified': False
        }

    def _verify_model_loadable(self, model_name: str, model_type: str) -> bool:
        """
        Verify that a model can be loaded successfully.

        Returns:
            True when ``load_model`` returns something other than None;
            False when it returns None or raises.
        """
        try:
            model = load_model(model_name, model_type)
            return model is not None
        except Exception as e:
            logger.warning("Model verification failed for %s: %s", model_name, e)
            return False

    def _get_selection_reason(self, candidate: Dict[str, Any]) -> str:
        """
        Get human-readable selection reason.

        Returns:
            Comma-separated tags: optionally "has_checkpoints", followed by
            one of "high_score" (> 0.8), "good_score" (> 0.6) or
            "acceptable_score".
        """
        reasons = []

        if candidate.get('has_checkpoints'):
            reasons.append("has_checkpoints")

        score = candidate.get('score', 0)
        if score > 0.8:
            reasons.append("high_score")
        elif score > 0.6:
            reasons.append("good_score")
        else:
            reasons.append("acceptable_score")

        return ", ".join(reasons)

    def _load_from_checkpoint(self, model_name: str, model_type: str) -> Optional[Dict[str, Any]]:
        """
        Try to build a loaded-model entry from the best checkpoint.

        Returns:
            The entry dict (with 'model' left as None), or None when no
            checkpoint is returned or the checkpoint cannot be read, so the
            caller can fall back to regular model loading.
        """
        try:
            checkpoint_result = load_best_checkpoint(model_name, model_type)
            if not checkpoint_result:
                return None
            # load_best_checkpoint is used here as returning (path, data).
            checkpoint_path, checkpoint_data = checkpoint_result
            return {
                'model': None,  # Would need proper model class instantiation
                'checkpoint_data': checkpoint_data,
                'source': 'checkpoint',
                'path': checkpoint_path,
                'performance_score': checkpoint_data.get('performance_score', 0)
            }
        except Exception as e:
            logger.warning(
                "Checkpoint load failed for %s model %s, falling back to latest: %s",
                model_type, model_name, e,
            )
            return None

    def load_selected_models(self, selected_models: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
        """
        Load the selected models into memory.

        For each type the best checkpoint is preferred when the registry
        reports at least one checkpoint; otherwise, or when the checkpoint
        cannot be loaded, the latest saved model is loaded. Failures are
        logged and the type is simply absent from the result.

        Args:
            selected_models: Dictionary from select_best_models_for_startup

        Returns:
            Dictionary of loaded models
        """
        loaded_models: Dict[str, Any] = {}

        for model_type, selection_info in selected_models.items():
            model_name = selection_info['name']
            logger.info("Loading %s model: %s", model_type, model_name)

            try:
                info = selection_info['info']

                # Try to load best checkpoint first if available
                entry = None
                if (info.get('checkpoint_count') or 0) > 0:
                    entry = self._load_from_checkpoint(model_name, model_type)

                if entry is not None:
                    logger.info("Loaded %s from checkpoint: %s", model_type, entry['path'])
                else:
                    # Fall back to regular model loading
                    model = load_model(model_name, model_type)
                    # Compare against None (as _verify_model_loadable does):
                    # a valid model object may still be falsy.
                    if model is None:
                        logger.error("Failed to load %s model: %s", model_type, model_name)
                        continue
                    entry = {
                        'model': model,
                        'source': 'latest',
                        'path': info.get('latest_path'),
                        'performance_score': None
                    }
                    logger.info("Loaded %s from latest: %s", model_type, model_name)

                loaded_models[model_type] = entry
            except Exception as e:
                # Best effort: one broken model must not prevent the others
                # from loading.
                logger.error("Error loading %s model %s: %s", model_type, model_name, e)

        return loaded_models

    def get_startup_report(self, selected_models: Dict[str, Dict[str, Any]],
                           loaded_models: Dict[str, Any]) -> str:
        """
        Generate a startup report.

        Args:
            selected_models: Dictionary from select_best_models_for_startup
            loaded_models: Dictionary from load_selected_models

        Returns:
            Multi-line, human-readable report text
        """
        separator = "=" * 60
        report_lines = [
            separator,
            "MODEL STARTUP SELECTION REPORT",
            separator,
            f"Selection Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
            ""
        ]

        if selected_models:
            report_lines.append("SELECTED MODELS:")
            for model_type, selection_info in selected_models.items():
                report_lines.extend([
                    f" {model_type.upper()}: {selection_info['name']}",
                    f" - Score: {selection_info.get('score', 0):.3f}",
                    f" - Reason: {selection_info.get('selection_reason', 'unknown')}",
                    f" - Verified: {selection_info.get('verified', False)}",
                    f" - Last Saved: {selection_info['info'].get('last_saved', 'unknown')}",
                    "",
                ])
        else:
            report_lines.append("NO MODELS SELECTED")
            report_lines.append("")

        if loaded_models:
            report_lines.append("LOADED MODELS:")
            for model_type, model_info in loaded_models.items():
                source = model_info.get('source', 'unknown')
                report_lines.append(f" {model_type.upper()}: Loaded from {source}")
                performance_score = model_info.get('performance_score')
                if performance_score is not None:
                    report_lines.append(f" - Performance Score: {performance_score:.3f}")
                report_lines.append("")
        else:
            report_lines.append("NO MODELS LOADED")
            report_lines.append("")

        # Add summary statistics
        total_models = len(self.registry.list_models())
        report_lines.extend([
            "SUMMARY STATISTICS:",
            f" Total Available Models: {total_models}",
            f" Models Selected: {len(selected_models)}",
            f" Models Loaded: {len(loaded_models)}",
            separator
        ])

        return "\n".join(report_lines)


# Global instance
_model_selector = None


def get_model_selector() -> ModelSelector:
    """Get the global model selector instance, creating it on first use"""
    global _model_selector
    if _model_selector is None:
        _model_selector = ModelSelector()
    return _model_selector


def select_and_load_best_models() -> Tuple[Dict[str, Dict[str, Any]], Dict[str, Any]]:
    """
    Convenience function to select and load best models for startup.

    Returns:
        Tuple of (selected_models_info, loaded_models)
    """
    selector = get_model_selector()

    # Select best models
    selected_models = selector.select_best_models_for_startup()

    # Load selected models
    loaded_models = selector.load_selected_models(selected_models)

    # Generate and log report
    report = selector.get_startup_report(selected_models, loaded_models)
    logger.info("Model Startup Report:\n%s", report)

    return selected_models, loaded_models