#!/usr/bin/env python3
"""
Best Model Selection for Startup

This module provides intelligent model selection logic for choosing the best
available models at system startup based on various criteria.
"""

import os
import logging
import json
from pathlib import Path
from typing import Dict, Any, Optional, List, Tuple
from datetime import datetime, timedelta
import torch

from utils.model_registry import get_model_registry, load_model, load_best_checkpoint

logger = logging.getLogger(__name__)


class ModelSelector:
    """
    Intelligent model selector for startup and runtime model selection.
    """

    def __init__(self):
        """Initialize the model selector"""
        self.registry = get_model_registry()
        self.selection_criteria = {
            'max_age_days': 30,            # Don't use models older than 30 days
            'min_performance_score': 0.5,  # Minimum acceptable performance
            'prefer_recent': True,         # Prefer recently trained models
            'fallback_to_any': True        # Use any model if no good ones found
        }

        logger.info("Model Selector initialized")

    def select_best_models_for_startup(self) -> Dict[str, Dict[str, Any]]:
        """
        Select the best available models for each type at startup.

        Returns:
            Dictionary mapping model types to selected model info
        """
        logger.info("Selecting best models for startup...")

        available_models = self.registry.list_models()
        selected_models = {}

        # Group models by type
        models_by_type = {}
        for model_name, model_info in available_models.items():
            model_type = model_info.get('type', 'unknown')
            if model_type not in models_by_type:
                models_by_type[model_type] = []
            models_by_type[model_type].append((model_name, model_info))

        # Select best model for each type
        for model_type, models in models_by_type.items():
            if not models:
                continue

            logger.info(f"Selecting best {model_type} model from {len(models)} candidates")

            best_model = self._select_best_model_for_type(models, model_type)
            if best_model:
                selected_models[model_type] = best_model
                logger.info(f"Selected {best_model['name']} for {model_type}")
            else:
                logger.warning(f"No suitable {model_type} model found")

        return selected_models

    def _select_best_model_for_type(self, models: List[Tuple[str, Dict]], model_type: str) -> Optional[Dict[str, Any]]:
        """
        Select the best model for a specific type.

        Args:
            models: List of (name, info) tuples
            model_type: Type of model to select

        Returns:
            Selected model information or None
        """
        if not models:
            return None

        candidates = []

        for model_name, model_info in models:
            # Check if model meets basic criteria
            if not self._meets_basic_criteria(model_info):
                continue

            # Calculate selection score
            score = self._calculate_selection_score(model_name, model_info, model_type)

            candidates.append({
                'name': model_name,
                'info': model_info,
                'score': score,
                'has_checkpoints': model_info.get('checkpoint_count', 0) > 0
            })

        if not candidates:
            if self.selection_criteria['fallback_to_any']:
                # Fallback to most recent model
                logger.info(f"No good {model_type} candidates, using fallback")
                return self._select_fallback_model(models)
            return None

        # Sort by score (highest first)
        candidates.sort(key=lambda x: x['score'], reverse=True)
        best_candidate = candidates[0]

        # Try to load the model to verify it's working
        if self._verify_model_loadable(best_candidate['name'], model_type):
            return {
                'name': best_candidate['name'],
                'type': model_type,
                'info': best_candidate['info'],
                'score': best_candidate['score'],
                'selection_reason': self._get_selection_reason(best_candidate),
                'verified': True
            }
        else:
            logger.warning(f"Selected model {best_candidate['name']} failed verification")
            # Try next candidate
            if len(candidates) > 1:
                next_candidate = candidates[1]
                if self._verify_model_loadable(next_candidate['name'], model_type):
                    return {
                        'name': next_candidate['name'],
                        'type': model_type,
                        'info': next_candidate['info'],
                        'score': next_candidate['score'],
                        'selection_reason': 'fallback_after_verification_failure',
                        'verified': True
                    }

        return None

    def _meets_basic_criteria(self, model_info: Dict[str, Any]) -> bool:
        """Check if model meets basic selection criteria"""
        # Check age
        last_saved = model_info.get('last_saved')
        if last_saved:
            try:
                # Parse timestamp (format: YYYYMMDD_HHMMSS)
                model_date = datetime.strptime(last_saved, '%Y%m%d_%H%M%S')
                age_days = (datetime.now() - model_date).days

                if age_days > self.selection_criteria['max_age_days']:
                    return False
            except ValueError:
                logger.warning(f"Could not parse timestamp: {last_saved}")

        return True

    def _calculate_selection_score(self, model_name: str, model_info: Dict[str, Any], model_type: str) -> float:
        """Calculate selection score for a model"""
        score = 0.0

        # Base score from recency (newer is better)
        last_saved = model_info.get('last_saved')
        if last_saved:
            try:
                model_date = datetime.strptime(last_saved, '%Y%m%d_%H%M%S')
                days_old = (datetime.now() - model_date).days
                recency_score = max(0, 30 - days_old) / 30.0  # 0-1 score for last 30 days
                score += recency_score * 0.4
            except ValueError:
                pass

        # Score from checkpoints (having checkpoints is good)
        checkpoint_count = model_info.get('checkpoint_count', 0)
        if checkpoint_count > 0:
            checkpoint_score = min(checkpoint_count / 10.0, 1.0)  # Max score for 10+ checkpoints
            score += checkpoint_score * 0.3

        # Score from save count (more saves might indicate stability)
        save_count = model_info.get('save_count', 0)
        if save_count > 1:
            stability_score = min(save_count / 5.0, 1.0)  # Max score for 5+ saves
            score += stability_score * 0.3

        return score
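
    # Worked example of the weighting above (hypothetical values, arithmetic only):
    # a model saved 5 days ago with 3 checkpoints and 4 saves scores
    #   recency:     (30 - 5) / 30 * 0.4  ≈ 0.333
    #   checkpoints: min(3 / 10, 1) * 0.3 = 0.090
    #   stability:   min(4 / 5, 1) * 0.3  = 0.240
    #   total                             ≈ 0.663
    # so the maximum possible score is 1.0 (0.4 + 0.3 + 0.3).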

    def _select_fallback_model(self, models: List[Tuple[str, Dict]]) -> Optional[Dict[str, Any]]:
        """Select a fallback model when no good candidates found"""
        if not models:
            return None

        # Sort by recency
        sorted_models = sorted(models, key=lambda x: x[1].get('last_saved', ''), reverse=True)
        model_name, model_info = sorted_models[0]

        return {
            'name': model_name,
            'type': model_info.get('type', 'unknown'),
            'info': model_info,
            'score': 0.0,
            'selection_reason': 'fallback_most_recent',
            'verified': False
        }

    def _verify_model_loadable(self, model_name: str, model_type: str) -> bool:
        """Verify that a model can be loaded successfully"""
        try:
            model = load_model(model_name, model_type)
            return model is not None
        except Exception as e:
            logger.warning(f"Model verification failed for {model_name}: {e}")
            return False

    def _get_selection_reason(self, candidate: Dict[str, Any]) -> str:
        """Get human-readable selection reason"""
        reasons = []

        if candidate.get('has_checkpoints'):
            reasons.append("has_checkpoints")

        score = candidate.get('score', 0)
        if score > 0.8:
            reasons.append("high_score")
        elif score > 0.6:
            reasons.append("good_score")
        else:
            reasons.append("acceptable_score")

        return ", ".join(reasons) if reasons else "default_selection"

    def load_selected_models(self, selected_models: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
        """
        Load the selected models into memory.

        Args:
            selected_models: Dictionary from select_best_models_for_startup

        Returns:
            Dictionary of loaded models
        """
        loaded_models = {}

        for model_type, selection_info in selected_models.items():
            model_name = selection_info['name']

            logger.info(f"Loading {model_type} model: {model_name}")

            try:
                # Try to load best checkpoint first if available
                if selection_info['info'].get('checkpoint_count', 0) > 0:
                    checkpoint_result = load_best_checkpoint(model_name, model_type)
                    if checkpoint_result:
                        checkpoint_path, checkpoint_data = checkpoint_result
                        loaded_models[model_type] = {
                            'model': None,  # Would need proper model class instantiation
                            'checkpoint_data': checkpoint_data,
                            'source': 'checkpoint',
                            'path': checkpoint_path,
                            'performance_score': checkpoint_data.get('performance_score', 0)
                        }
                        logger.info(f"Loaded {model_type} from checkpoint: {checkpoint_path}")
                        continue

                # Fall back to regular model loading
                model = load_model(model_name, model_type)
                if model:
                    loaded_models[model_type] = {
                        'model': model,
                        'source': 'latest',
                        'path': selection_info['info'].get('latest_path'),
                        'performance_score': None
                    }
                    logger.info(f"Loaded {model_type} from latest: {model_name}")
                else:
                    logger.error(f"Failed to load {model_type} model: {model_name}")

            except Exception as e:
                logger.error(f"Error loading {model_type} model {model_name}: {e}")

        return loaded_models
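
    # Note on the checkpoint branch above: when a best checkpoint is used, only the
    # raw checkpoint_data is returned ('model' is None), so callers are expected to
    # instantiate the appropriate model class and restore its state themselves.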

    def get_startup_report(self, selected_models: Dict[str, Dict[str, Any]],
                           loaded_models: Dict[str, Any]) -> str:
        """Generate a startup report"""
        report_lines = [
            "=" * 60,
            "MODEL STARTUP SELECTION REPORT",
            "=" * 60,
            f"Selection Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
            ""
        ]

        if selected_models:
            report_lines.append("SELECTED MODELS:")
            for model_type, selection_info in selected_models.items():
                report_lines.append(f"  {model_type.upper()}: {selection_info['name']}")
                report_lines.append(f"    - Score: {selection_info.get('score', 0):.3f}")
                report_lines.append(f"    - Reason: {selection_info.get('selection_reason', 'unknown')}")
                report_lines.append(f"    - Verified: {selection_info.get('verified', False)}")
                report_lines.append(f"    - Last Saved: {selection_info['info'].get('last_saved', 'unknown')}")
                report_lines.append("")
        else:
            report_lines.append("NO MODELS SELECTED")
            report_lines.append("")

        if loaded_models:
            report_lines.append("LOADED MODELS:")
            for model_type, model_info in loaded_models.items():
                source = model_info.get('source', 'unknown')
                report_lines.append(f"  {model_type.upper()}: Loaded from {source}")
                if 'performance_score' in model_info and model_info['performance_score'] is not None:
                    report_lines.append(f"    - Performance Score: {model_info['performance_score']:.3f}")
                report_lines.append("")
        else:
            report_lines.append("NO MODELS LOADED")
            report_lines.append("")

        # Add summary statistics
        total_models = len(self.registry.list_models())
        selected_count = len(selected_models)
        loaded_count = len(loaded_models)

        report_lines.extend([
            "SUMMARY STATISTICS:",
            f"  Total Available Models: {total_models}",
            f"  Models Selected: {selected_count}",
            f"  Models Loaded: {loaded_count}",
            "=" * 60
        ])

        return "\n".join(report_lines)


# Global instance
_model_selector = None


def get_model_selector() -> ModelSelector:
    """Get the global model selector instance"""
    global _model_selector
    if _model_selector is None:
        _model_selector = ModelSelector()
    return _model_selector


def select_and_load_best_models() -> Tuple[Dict[str, Dict[str, Any]], Dict[str, Any]]:
    """
    Convenience function to select and load best models for startup.

    Returns:
        Tuple of (selected_models_info, loaded_models)
    """
    selector = get_model_selector()

    # Select best models
    selected_models = selector.select_best_models_for_startup()

    # Load selected models
    loaded_models = selector.load_selected_models(selected_models)

    # Generate and log report
    report = selector.get_startup_report(selected_models, loaded_models)
    logger.info("Model Startup Report:\n" + report)

    return selected_models, loaded_models
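

# Minimal usage sketch (an assumption for illustration: the model registry is
# populated and this module is run from the project root so utils.model_registry
# is importable); not part of the original API:
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    selected, loaded = select_and_load_best_models()
    print(f"Selected {len(selected)} model type(s); loaded {len(loaded)} into memory")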