Files
gogo2/utils/model_selector.py
Dobromir Popov 32d54f0604 model selector
2025-09-08 14:53:46 +03:00

365 lines
14 KiB
Python

#!/usr/bin/env python3
"""
Best Model Selection for Startup
This module provides intelligent model selection logic for choosing the best
available models at system startup based on various criteria.
"""
import os
import logging
import json
from pathlib import Path
from typing import Dict, Any, Optional, List, Tuple
from datetime import datetime, timedelta
import torch
from utils.model_registry import get_model_registry, load_model, load_best_checkpoint
logger = logging.getLogger(__name__)
class ModelSelector:
    """
    Intelligent model selector for startup and runtime model selection.

    Models listed by the registry are grouped by type, filtered by age,
    scored on recency / checkpoint count / save count, and the
    highest-scoring model that can actually be loaded is chosen per type.
    """

    # Format of the 'last_saved' timestamps stored in the registry metadata.
    _TIMESTAMP_FORMAT = '%Y%m%d_%H%M%S'

    def __init__(self):
        """Initialize the model selector with the registry and default criteria."""
        self.registry = get_model_registry()
        # NOTE: only 'max_age_days' and 'fallback_to_any' are read by the
        # methods of this class; the other keys are kept for callers.
        self.selection_criteria = {
            'max_age_days': 30,  # Don't use models older than 30 days
            'min_performance_score': 0.5,  # Minimum acceptable performance
            'prefer_recent': True,  # Prefer recently trained models
            'fallback_to_any': True  # Use any model if no good ones found
        }
        logger.info("Model Selector initialized")

    def select_best_models_for_startup(self) -> Dict[str, Dict[str, Any]]:
        """
        Select the best available models for each type at startup.

        Returns:
            Dictionary mapping model types to selected model info
        """
        logger.info("Selecting best models for startup...")
        available_models = self.registry.list_models()

        # Group models by type; entries without a 'type' land under 'unknown'.
        models_by_type: Dict[str, List[Tuple[str, Dict[str, Any]]]] = {}
        for model_name, model_info in available_models.items():
            model_type = model_info.get('type', 'unknown')
            models_by_type.setdefault(model_type, []).append((model_name, model_info))

        # Select best model for each type
        selected_models: Dict[str, Dict[str, Any]] = {}
        for model_type, models in models_by_type.items():
            logger.info(f"Selecting best {model_type} model from {len(models)} candidates")
            best_model = self._select_best_model_for_type(models, model_type)
            if best_model:
                selected_models[model_type] = best_model
                logger.info(f"Selected {best_model['name']} for {model_type}")
            else:
                logger.warning(f"No suitable {model_type} model found")
        return selected_models

    def _select_best_model_for_type(self, models: List[Tuple[str, Dict]], model_type: str) -> Optional[Dict[str, Any]]:
        """
        Select the best model for a specific type.

        Args:
            models: List of (name, info) tuples
            model_type: Type of model to select

        Returns:
            Selected model information, or None when nothing qualifies or
            no qualifying candidate can be loaded.
        """
        if not models:
            return None

        candidates = []
        for model_name, model_info in models:
            # Check if model meets basic criteria
            if not self._meets_basic_criteria(model_info):
                continue
            candidates.append({
                'name': model_name,
                'info': model_info,
                'score': self._calculate_selection_score(model_name, model_info, model_type),
                'has_checkpoints': model_info.get('checkpoint_count', 0) > 0
            })

        if not candidates:
            if self.selection_criteria['fallback_to_any']:
                # Fallback to most recent model
                logger.info(f"No good {model_type} candidates, using fallback")
                return self._select_fallback_model(models)
            return None

        # Sort by score (highest first)
        candidates.sort(key=lambda entry: entry['score'], reverse=True)

        # Walk down the ranking until a candidate loads, instead of giving
        # up after the runner-up.
        for rank, candidate in enumerate(candidates):
            if self._verify_model_loadable(candidate['name'], model_type):
                if rank == 0:
                    reason = self._get_selection_reason(candidate)
                else:
                    reason = 'fallback_after_verification_failure'
                return {
                    'name': candidate['name'],
                    'type': model_type,
                    'info': candidate['info'],
                    'score': candidate['score'],
                    'selection_reason': reason,
                    'verified': True
                }
            logger.warning(f"Selected model {candidate['name']} failed verification")
        return None

    @classmethod
    def _parse_last_saved(cls, last_saved: Any) -> Optional[datetime]:
        """Parse a 'last_saved' timestamp; return None if missing or malformed."""
        if not last_saved:
            return None
        try:
            return datetime.strptime(last_saved, cls._TIMESTAMP_FORMAT)
        except (ValueError, TypeError):
            # ValueError: wrong format; TypeError: value is not a string.
            return None

    def _meets_basic_criteria(self, model_info: Dict[str, Any]) -> bool:
        """
        Check if model meets basic selection criteria.

        Only the age limit is enforced. Models with a missing or
        unparseable timestamp are accepted (a warning is logged for the
        unparseable case).
        """
        last_saved = model_info.get('last_saved')
        if not last_saved:
            return True
        model_date = self._parse_last_saved(last_saved)
        if model_date is None:
            logger.warning(f"Could not parse timestamp: {last_saved}")
            return True
        age_days = (datetime.now() - model_date).days
        return age_days <= self.selection_criteria['max_age_days']

    def _calculate_selection_score(self, model_name: str, model_info: Dict[str, Any], model_type: str) -> float:
        """
        Calculate selection score for a model.

        The score is a weighted sum in [0, 1]: 40% recency over a 30-day
        window, 30% checkpoint count (saturating at 10), 30% save count
        (saturating at 5, and only counted when there is more than one save).
        """
        score = 0.0

        # Base score from recency (newer is better)
        model_date = self._parse_last_saved(model_info.get('last_saved'))
        if model_date is not None:
            days_old = (datetime.now() - model_date).days
            recency_score = max(0, 30 - days_old) / 30.0  # 0-1 score for last 30 days
            score += recency_score * 0.4

        # Score from checkpoints (having checkpoints is good)
        checkpoint_count = model_info.get('checkpoint_count', 0)
        if checkpoint_count > 0:
            checkpoint_score = min(checkpoint_count / 10.0, 1.0)  # Max score for 10+ checkpoints
            score += checkpoint_score * 0.3

        # Score from save count (more saves might indicate stability)
        save_count = model_info.get('save_count', 0)
        if save_count > 1:
            stability_score = min(save_count / 5.0, 1.0)  # Max score for 5+ saves
            score += stability_score * 0.3
        return score

    def _select_fallback_model(self, models: List[Tuple[str, Dict]]) -> Optional[Dict[str, Any]]:
        """
        Select a fallback model when no good candidates found.

        Picks the entry with the greatest 'last_saved' string (the
        timestamp format sorts chronologically as text). The result is
        marked unverified with a score of 0.0.
        """
        if not models:
            return None

        # Missing or None timestamps sort as the empty string so they lose to
        # any real timestamp instead of breaking the comparison.
        model_name, model_info = max(
            models, key=lambda item: str(item[1].get('last_saved') or '')
        )
        return {
            'name': model_name,
            'type': model_info.get('type', 'unknown'),
            'info': model_info,
            'score': 0.0,
            'selection_reason': 'fallback_most_recent',
            'verified': False
        }

    def _verify_model_loadable(self, model_name: str, model_type: str) -> bool:
        """Verify that a model can be loaded successfully"""
        try:
            model = load_model(model_name, model_type)
            return model is not None
        except Exception as e:
            logger.warning(f"Model verification failed for {model_name}: {e}")
            return False

    def _get_selection_reason(self, candidate: Dict[str, Any]) -> str:
        """Get human-readable selection reason (comma-separated tags)."""
        reasons = []
        if candidate.get('has_checkpoints'):
            reasons.append("has_checkpoints")

        score = candidate.get('score', 0)
        if score > 0.8:
            reasons.append("high_score")
        elif score > 0.6:
            reasons.append("good_score")
        else:
            reasons.append("acceptable_score")
        # A score tag is always present, so the list is never empty.
        return ", ".join(reasons)

    def load_selected_models(self, selected_models: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
        """
        Load the selected models into memory.

        For each selection the best checkpoint is tried first (when the
        registry reports any checkpoints); otherwise the latest saved model
        is loaded. Failures are logged and the type is left out of the result.

        Args:
            selected_models: Dictionary from select_best_models_for_startup

        Returns:
            Dictionary of loaded models keyed by model type
        """
        loaded_models = {}
        for model_type, selection_info in selected_models.items():
            model_name = selection_info['name']
            logger.info(f"Loading {model_type} model: {model_name}")
            try:
                # Try to load best checkpoint first if available
                if selection_info['info'].get('checkpoint_count', 0) > 0:
                    checkpoint_result = load_best_checkpoint(model_name, model_type)
                    if checkpoint_result:
                        checkpoint_path, checkpoint_data = checkpoint_result
                        loaded_models[model_type] = {
                            'model': None,  # Would need proper model class instantiation
                            'checkpoint_data': checkpoint_data,
                            'source': 'checkpoint',
                            'path': checkpoint_path,
                            'performance_score': checkpoint_data.get('performance_score', 0)
                        }
                        logger.info(f"Loaded {model_type} from checkpoint: {checkpoint_path}")
                        continue

                # Fall back to regular model loading
                model = load_model(model_name, model_type)
                # Compare against None (as _verify_model_loadable does): a
                # loaded object may be falsy or may not support truthiness.
                if model is not None:
                    loaded_models[model_type] = {
                        'model': model,
                        'source': 'latest',
                        'path': selection_info['info'].get('latest_path'),
                        'performance_score': None
                    }
                    logger.info(f"Loaded {model_type} from latest: {model_name}")
                else:
                    logger.error(f"Failed to load {model_type} model: {model_name}")
            except Exception as e:
                logger.error(f"Error loading {model_type} model {model_name}: {e}")
        return loaded_models

    def get_startup_report(self, selected_models: Dict[str, Dict[str, Any]],
                           loaded_models: Dict[str, Any]) -> str:
        """
        Generate a startup report.

        Args:
            selected_models: Dictionary from select_best_models_for_startup
            loaded_models: Dictionary from load_selected_models

        Returns:
            Multi-line text listing the selected models, the loaded models
            and summary counts.
        """
        report_lines = [
            "=" * 60,
            "MODEL STARTUP SELECTION REPORT",
            "=" * 60,
            f"Selection Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
            ""
        ]

        if selected_models:
            report_lines.append("SELECTED MODELS:")
            for model_type, selection_info in selected_models.items():
                report_lines.append(f" {model_type.upper()}: {selection_info['name']}")
                report_lines.append(f" - Score: {selection_info.get('score', 0):.3f}")
                report_lines.append(f" - Reason: {selection_info.get('selection_reason', 'unknown')}")
                report_lines.append(f" - Verified: {selection_info.get('verified', False)}")
                report_lines.append(f" - Last Saved: {selection_info['info'].get('last_saved', 'unknown')}")
                report_lines.append("")
        else:
            report_lines.append("NO MODELS SELECTED")
            report_lines.append("")

        if loaded_models:
            report_lines.append("LOADED MODELS:")
            for model_type, model_info in loaded_models.items():
                source = model_info.get('source', 'unknown')
                report_lines.append(f" {model_type.upper()}: Loaded from {source}")
                performance_score = model_info.get('performance_score')
                if performance_score is not None:
                    report_lines.append(f" - Performance Score: {performance_score:.3f}")
                report_lines.append("")
        else:
            report_lines.append("NO MODELS LOADED")
            report_lines.append("")

        # Add summary statistics
        total_models = len(self.registry.list_models())
        selected_count = len(selected_models)
        loaded_count = len(loaded_models)
        report_lines.extend([
            "SUMMARY STATISTICS:",
            f" Total Available Models: {total_models}",
            f" Models Selected: {selected_count}",
            f" Models Loaded: {loaded_count}",
            "=" * 60
        ])
        return "\n".join(report_lines)
# Module-wide selector instance, created lazily on first request.
_model_selector = None


def get_model_selector() -> ModelSelector:
    """Return the shared ModelSelector, constructing it on the first call."""
    global _model_selector
    if _model_selector is not None:
        return _model_selector
    _model_selector = ModelSelector()
    return _model_selector
def select_and_load_best_models() -> Tuple[Dict[str, Dict[str, Any]], Dict[str, Any]]:
    """
    Pick the best model per type, load the picks, and log a startup report.

    Returns:
        Tuple of (selected_models_info, loaded_models)
    """
    chooser = get_model_selector()

    # Selection first, then loading of whatever was selected.
    picks = chooser.select_best_models_for_startup()
    loaded = chooser.load_selected_models(picks)

    # Summarize both stages in a single log entry.
    summary = chooser.get_startup_report(picks, loaded)
    logger.info("Model Startup Report:\n" + summary)

    return picks, loaded