model selector
@@ -1820,14 +1820,19 @@ class TradingOrchestrator:
     def start_enhanced_training(self):
         """Start the enhanced real-time training system"""
         try:
-            if not self.training_enabled or not getattr(self, 'training_manager', None):
+            if not self.training_enabled or not self.enhanced_training_system:
                 logger.warning("Enhanced training system not available")
                 return False

-            self.training_manager.start()
-            logger.info("Enhanced real-time training started")
-            return True
+            # Check if the enhanced training system has a start_training method
+            if hasattr(self.enhanced_training_system, 'start_training'):
+                self.enhanced_training_system.start_training()
+                logger.info("Enhanced real-time training started")
+                return True
+            else:
+                logger.warning("Enhanced training system does not have start_training method")
+                return False

         except Exception as e:
             logger.error(f"Error starting enhanced training: {e}")
             return False
@@ -1835,12 +1840,12 @@ class TradingOrchestrator:
     def stop_enhanced_training(self):
         """Stop the enhanced real-time training system"""
        try:
-            if getattr(self, 'training_manager', None):
-                self.training_manager.stop()
+            if self.enhanced_training_system and hasattr(self.enhanced_training_system, 'stop_training'):
+                self.enhanced_training_system.stop_training()
                 logger.info("Enhanced real-time training stopped")
                 return True
             return False

         except Exception as e:
             logger.error(f"Error stopping enhanced training: {e}")
             return False
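Note: both methods only duck-type the training system; nothing beyond start_training() and stop_training() is required. A minimal sketch of the interface the orchestrator assumes (this class is hypothetical and not part of the commit):

# Hypothetical interface assumed by start/stop_enhanced_training.
class EnhancedTrainingSystem:
    def __init__(self):
        self.running = False

    def start_training(self) -> None:
        # Kick off the real-time training loop (details out of scope here)
        self.running = True

    def stop_training(self) -> None:
        # Halt the real-time training loop
        self.running = False
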
@@ -87,10 +87,20 @@ def main():
     os.environ['ENABLE_NN_MODELS'] = '1'

     try:
+        # Model Selection at Startup
+        logger.info("Performing intelligent model selection...")
+        try:
+            from utils.model_selector import select_and_load_best_models
+            selected_models, loaded_models = select_and_load_best_models()
+            logger.info(f"Selected {len(selected_models)} model types, loaded {len(loaded_models)} models")
+        except Exception as e:
+            logger.warning(f"Model selection failed, using defaults: {e}")
+            selected_models, loaded_models = {}, {}
+
         # Create data provider
         logger.info("Initializing data provider...")
         data_provider = DataProvider(symbols=['ETH/USDT', 'BTC/USDT'])

         # Create orchestrator (with safe CNN handling)
         logger.info("Initializing trading orchestrator...")
         orchestrator = create_safe_orchestrator()
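For reference, select_and_load_best_models() comes from the new utils/model_selector.py below and returns a (selected_models, loaded_models) pair keyed by model type. A small sketch of how a caller might inspect the result (illustrative only, using the keys that module sets):

# selected_models[model_type] -> {'name', 'type', 'info', 'score', 'selection_reason', 'verified'}
# loaded_models[model_type]   -> {'model', 'source', 'path', 'performance_score', ...}
for model_type, selection in selected_models.items():
    logger.info(f"{model_type}: picked {selection['name']} (score={selection['score']:.3f})")
for model_type, entry in loaded_models.items():
    logger.info(f"{model_type}: loaded from {entry['source']}")
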
test_model_audit.py (new file, 344 lines)
@@ -0,0 +1,344 @@
#!/usr/bin/env python3
"""
Model Loading/Saving Audit Test

This script tests the model registry and saving/loading mechanisms
to identify any issues and provide recommendations.
"""

import os
import sys
import logging
import torch
import torch.nn as nn
from datetime import datetime
from pathlib import Path

# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from utils.model_registry import get_model_registry, save_model, load_model, save_checkpoint

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class SimpleTestModel(nn.Module):
    """Simple neural network for testing"""
    def __init__(self, input_size=10, hidden_size=32, output_size=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size)
        )

    def forward(self, x):
        return self.net(x)

def test_model_registry():
    """Test the model registry functionality"""
    logger.info("=== MODEL REGISTRY AUDIT ===")

    registry = get_model_registry()
    logger.info(f"Registry base directory: {registry.base_dir}")
    logger.info(f"Registry metadata file: {registry.metadata_file}")

    # Check existing models
    existing_models = registry.list_models()
    logger.info(f"Existing models: {existing_models}")

    # Test model creation and saving
    logger.info("Creating test model...")
    test_model = SimpleTestModel()

    # Generate some fake training data
    test_input = torch.randn(32, 10)
    test_output = test_model(test_input)

    logger.info(f"Test model created. Input shape: {test_input.shape}, Output shape: {test_output.shape}")

    # Test saving with different methods
    logger.info("Testing model saving...")

    # Test 1: Save with unified registry
    success = save_model(
        model=test_model,
        model_name="audit_test_model",
        model_type="cnn",
        metadata={
            "test_type": "registry_audit",
            "created_at": datetime.now().isoformat(),
            "input_shape": list(test_input.shape),
            "output_shape": list(test_output.shape)
        }
    )

    if success:
        logger.info("✅ Model saved successfully with unified registry")
    else:
        logger.error("❌ Failed to save model with unified registry")

    # Test 2: Load model back
    logger.info("Testing model loading...")
    loaded_model = load_model("audit_test_model", "cnn")

    if loaded_model is not None:
        logger.info("✅ Model loaded successfully")

        # Test if loaded model has proper structure
        if hasattr(loaded_model, 'state_dict') and callable(loaded_model.state_dict):
            state_dict = loaded_model.state_dict()
            logger.info(f"Loaded model test - State dict keys: {list(state_dict.keys())}")

            # Check if we can create a new instance and load the state
            fresh_model = SimpleTestModel()
            try:
                fresh_model.load_state_dict(state_dict)
                test_output_loaded = fresh_model(test_input)
                logger.info(f"Loaded model test - Output shape: {test_output_loaded.shape}")

                # Compare outputs (should be identical)
                if torch.allclose(test_output, test_output_loaded, atol=1e-6):
                    logger.info("✅ Loaded model produces identical outputs")
                else:
                    logger.warning("⚠️ Loaded model outputs differ (this might be expected due to different random states)")
            except Exception as e:
                logger.warning(f"Could not test loaded model: {e}")
        else:
            logger.warning("Loaded model does not have proper structure")
    else:
        logger.error("❌ Failed to load model")

    # Test 3: Save checkpoint
    logger.info("Testing checkpoint saving...")
    checkpoint_success = save_checkpoint(
        model=test_model,
        model_name="audit_test_model",
        model_type="cnn",
        performance_score=0.85,
        metadata={
            "checkpoint_test": True,
            "performance_metric": "accuracy",
            "epoch": 1
        }
    )

    if checkpoint_success:
        logger.info("✅ Checkpoint saved successfully")
    else:
        logger.error("❌ Failed to save checkpoint")

    # Check registry metadata after operations
    logger.info("Checking registry metadata after operations...")
    updated_models = registry.list_models()
    logger.info(f"Updated models: {updated_models}")

    # Check file system
    logger.info("Checking file system...")
    models_dir = Path("models")
    if models_dir.exists():
        logger.info(f"Models directory contents:")
        for item in models_dir.rglob("*"):
            if item.is_file():
                logger.info(f"  {item.relative_to(models_dir)} ({item.stat().st_size} bytes)")

    return {
        "registry_save_success": success,
        "registry_load_success": loaded_model is not None,
        "checkpoint_success": checkpoint_success,
        "existing_models": existing_models,
        "updated_models": updated_models
    }

def audit_model_metadata():
    """Audit the model metadata structure"""
    logger.info("=== MODEL METADATA AUDIT ===")

    registry = get_model_registry()

    # Check metadata structure
    metadata = registry.metadata
    logger.info(f"Metadata keys: {list(metadata.keys())}")

    if 'models' in metadata:
        models = metadata['models']
        logger.info(f"Number of registered models: {len(models)}")

        for model_name, model_data in models.items():
            logger.info(f"Model '{model_name}':")
            logger.info(f"  - Type: {model_data.get('type', 'unknown')}")
            logger.info(f"  - Last saved: {model_data.get('last_saved', 'never')}")
            logger.info(f"  - Save count: {model_data.get('save_count', 0)}")
            logger.info(f"  - Latest path: {model_data.get('latest_path', 'none')}")
            logger.info(f"  - Checkpoints: {len(model_data.get('checkpoints', []))}")

    if 'last_updated' in metadata:
        logger.info(f"Last metadata update: {metadata['last_updated']}")

    return metadata

def analyze_model_files():
    """Analyze the model files on disk"""
    logger.info("=== MODEL FILES ANALYSIS ===")

    models_dir = Path("models")

    if not models_dir.exists():
        logger.error("Models directory does not exist")
        return {}

    analysis = {
        'total_files': 0,
        'total_size': 0,
        'by_type': {},
        'by_model': {},
        'orphaned_files': [],
        'missing_files': []
    }

    # Analyze all .pt files
    for pt_file in models_dir.rglob("*.pt"):
        analysis['total_files'] += 1
        analysis['total_size'] += pt_file.stat().st_size

        # Categorize by type
        parts = pt_file.parts
        model_type = "unknown"
        if "cnn" in parts:
            model_type = "cnn"
        elif "dqn" in parts:
            model_type = "dqn"
        elif "transformer" in parts:
            model_type = "transformer"
        elif "hybrid" in parts:
            model_type = "hybrid"

        if model_type not in analysis['by_type']:
            analysis['by_type'][model_type] = []
        analysis['by_type'][model_type].append(str(pt_file))

        # Try to extract model name
        filename = pt_file.name
        if "_latest" in filename:
            model_name = filename.replace("_latest.pt", "")
        elif "_" in filename:
            # Extract timestamp-based names
            parts = filename.split("_")
            if len(parts) >= 2:
                model_name = "_".join(parts[:-1])  # Everything except timestamp
            else:
                model_name = filename.replace(".pt", "")
        else:
            model_name = filename.replace(".pt", "")

        if model_name not in analysis['by_model']:
            analysis['by_model'][model_name] = []
        analysis['by_model'][model_name].append(str(pt_file))

    logger.info(f"Total model files: {analysis['total_files']}")
    logger.info(f"Total size: {analysis['total_size'] / (1024*1024):.2f} MB")

    logger.info("Files by type:")
    for model_type, files in analysis['by_type'].items():
        logger.info(f"  {model_type}: {len(files)} files")

    logger.info("Files by model:")
    for model_name, files in analysis['by_model'].items():
        logger.info(f"  {model_name}: {len(files)} files")

    return analysis

def recommend_best_model_selection():
    """Provide recommendations for best model selection at startup"""
    logger.info("=== BEST MODEL SELECTION RECOMMENDATIONS ===")

    registry = get_model_registry()
    models = registry.list_models()

    recommendations = {
        'startup_strategy': 'hybrid',
        'fallback_models': [],
        'performance_criteria': [],
        'metadata_requirements': []
    }

    if models:
        logger.info("Available models for selection:")

        # Analyze each model type
        for model_name, model_info in models.items():
            model_type = model_info.get('type', 'unknown')
            logger.info(f"  {model_name} ({model_type}) - last saved: {model_info.get('last_saved', 'unknown')}")

            # Check if checkpoints exist
            if 'checkpoint_count' in model_info and model_info['checkpoint_count'] > 0:
                logger.info(f"    - Has {model_info['checkpoint_count']} checkpoints")
                recommendations['fallback_models'].append(model_name)

        # Recommendations
        logger.info("RECOMMENDATIONS:")
        logger.info("1. Startup Strategy:")
        logger.info("   - Try to load latest model for each type")
        logger.info("   - Fall back to checkpoints if latest model fails")
        logger.info("   - Use fallback to basic/default model if all else fails")

        logger.info("2. Performance-based Selection:")
        logger.info("   - For models with checkpoints, select highest performance_score")
        logger.info("   - Track model age and prefer recently trained models")
        logger.info("   - Implement model validation on startup")

        logger.info("3. Metadata Requirements:")
        logger.info("   - Store performance metrics in metadata")
        logger.info("   - Track training data quality and size")
        logger.info("   - Include model validation results")

    else:
        logger.info("No models registered - system will need initial training")
        logger.info("RECOMMENDATION: Implement default model initialization")

    return recommendations

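# Illustrative note (not executed by this script): the recommended startup strategy
# corresponds roughly to the following fallback chain, using load_model (imported above)
# and load_best_checkpoint (imported in utils/model_selector.py); build_default_model is
# a hypothetical placeholder for step 3:
#
#     model = load_model(name, model_type)                 # 1. try the latest saved model
#     if model is None:
#         result = load_best_checkpoint(name, model_type)  # 2. fall back to the best checkpoint
#         if result is None:
#             model = build_default_model(model_type)      # 3. default initialization
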
def main():
    """Main audit function"""
    logger.info("Starting Model Loading/Saving Audit")
    logger.info("=" * 60)

    try:
        # Test model registry
        registry_results = test_model_registry()
        logger.info("-" * 40)

        # Audit metadata
        metadata = audit_model_metadata()
        logger.info("-" * 40)

        # Analyze files
        file_analysis = analyze_model_files()
        logger.info("-" * 40)

        # Recommendations
        recommendations = recommend_best_model_selection()
        logger.info("-" * 40)

        # Summary
        logger.info("=== AUDIT SUMMARY ===")
        logger.info(f"Registry save success: {registry_results.get('registry_save_success', False)}")
        logger.info(f"Registry load success: {registry_results.get('registry_load_success', False)}")
        logger.info(f"Checkpoint success: {registry_results.get('checkpoint_success', False)}")
        logger.info(f"Total model files: {file_analysis.get('total_files', 0)}")
        logger.info(f"Registered models: {len(registry_results.get('existing_models', {}))}")

        logger.info("Audit completed successfully!")

    except Exception as e:
        logger.error(f"Audit failed with error: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()
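The audit script is self-contained: it adds the project root to sys.path, exercises the registry through save_model/load_model/save_checkpoint, and logs a summary. Run it from the project root with:

python test_model_audit.py
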
utils/model_selector.py (new file, 364 lines)
@@ -0,0 +1,364 @@
#!/usr/bin/env python3
"""
Best Model Selection for Startup

This module provides intelligent model selection logic for choosing the best
available models at system startup based on various criteria.
"""

import os
import logging
import json
from pathlib import Path
from typing import Dict, Any, Optional, List, Tuple
from datetime import datetime, timedelta
import torch

from utils.model_registry import get_model_registry, load_model, load_best_checkpoint

logger = logging.getLogger(__name__)


class ModelSelector:
    """
    Intelligent model selector for startup and runtime model selection.
    """

    def __init__(self):
        """Initialize the model selector"""
        self.registry = get_model_registry()
        self.selection_criteria = {
            'max_age_days': 30,            # Don't use models older than 30 days
            'min_performance_score': 0.5,  # Minimum acceptable performance
            'prefer_recent': True,         # Prefer recently trained models
            'fallback_to_any': True        # Use any model if no good ones found
        }

        logger.info("Model Selector initialized")

    def select_best_models_for_startup(self) -> Dict[str, Dict[str, Any]]:
        """
        Select the best available models for each type at startup.

        Returns:
            Dictionary mapping model types to selected model info
        """
        logger.info("Selecting best models for startup...")

        available_models = self.registry.list_models()
        selected_models = {}

        # Group models by type
        models_by_type = {}
        for model_name, model_info in available_models.items():
            model_type = model_info.get('type', 'unknown')
            if model_type not in models_by_type:
                models_by_type[model_type] = []
            models_by_type[model_type].append((model_name, model_info))

        # Select best model for each type
        for model_type, models in models_by_type.items():
            if not models:
                continue

            logger.info(f"Selecting best {model_type} model from {len(models)} candidates")

            best_model = self._select_best_model_for_type(models, model_type)
            if best_model:
                selected_models[model_type] = best_model
                logger.info(f"Selected {best_model['name']} for {model_type}")
            else:
                logger.warning(f"No suitable {model_type} model found")

        return selected_models

    def _select_best_model_for_type(self, models: List[Tuple[str, Dict]], model_type: str) -> Optional[Dict[str, Any]]:
        """
        Select the best model for a specific type.

        Args:
            models: List of (name, info) tuples
            model_type: Type of model to select

        Returns:
            Selected model information or None
        """
        if not models:
            return None

        candidates = []

        for model_name, model_info in models:
            # Check if model meets basic criteria
            if not self._meets_basic_criteria(model_info):
                continue

            # Calculate selection score
            score = self._calculate_selection_score(model_name, model_info, model_type)

            candidates.append({
                'name': model_name,
                'info': model_info,
                'score': score,
                'has_checkpoints': model_info.get('checkpoint_count', 0) > 0
            })

        if not candidates:
            if self.selection_criteria['fallback_to_any']:
                # Fallback to most recent model
                logger.info(f"No good {model_type} candidates, using fallback")
                return self._select_fallback_model(models)
            return None

        # Sort by score (highest first)
        candidates.sort(key=lambda x: x['score'], reverse=True)
        best_candidate = candidates[0]

        # Try to load the model to verify it's working
        if self._verify_model_loadable(best_candidate['name'], model_type):
            return {
                'name': best_candidate['name'],
                'type': model_type,
                'info': best_candidate['info'],
                'score': best_candidate['score'],
                'selection_reason': self._get_selection_reason(best_candidate),
                'verified': True
            }
        else:
            logger.warning(f"Selected model {best_candidate['name']} failed verification")
            # Try next candidate
            if len(candidates) > 1:
                next_candidate = candidates[1]
                if self._verify_model_loadable(next_candidate['name'], model_type):
                    return {
                        'name': next_candidate['name'],
                        'type': model_type,
                        'info': next_candidate['info'],
                        'score': next_candidate['score'],
                        'selection_reason': 'fallback_after_verification_failure',
                        'verified': True
                    }

            return None

    def _meets_basic_criteria(self, model_info: Dict[str, Any]) -> bool:
        """Check if model meets basic selection criteria"""
        # Check age
        last_saved = model_info.get('last_saved')
        if last_saved:
            try:
                # Parse timestamp (format: YYYYMMDD_HHMMSS)
                model_date = datetime.strptime(last_saved, '%Y%m%d_%H%M%S')
                age_days = (datetime.now() - model_date).days

                if age_days > self.selection_criteria['max_age_days']:
                    return False
            except ValueError:
                logger.warning(f"Could not parse timestamp: {last_saved}")

        return True

    def _calculate_selection_score(self, model_name: str, model_info: Dict[str, Any], model_type: str) -> float:
        """Calculate selection score for a model"""
        score = 0.0

        # Base score from recency (newer is better)
        last_saved = model_info.get('last_saved')
        if last_saved:
            try:
                model_date = datetime.strptime(last_saved, '%Y%m%d_%H%M%S')
                days_old = (datetime.now() - model_date).days
                recency_score = max(0, 30 - days_old) / 30.0  # 0-1 score for last 30 days
                score += recency_score * 0.4
            except ValueError:
                pass

        # Score from checkpoints (having checkpoints is good)
        checkpoint_count = model_info.get('checkpoint_count', 0)
        if checkpoint_count > 0:
            checkpoint_score = min(checkpoint_count / 10.0, 1.0)  # Max score for 10+ checkpoints
            score += checkpoint_score * 0.3

        # Score from save count (more saves might indicate stability)
        save_count = model_info.get('save_count', 0)
        if save_count > 1:
            stability_score = min(save_count / 5.0, 1.0)  # Max score for 5+ saves
            score += stability_score * 0.3

        return score

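    # Worked example with illustrative numbers (not taken from any real registry entry):
    # a model saved 6 days ago with 4 checkpoints and 3 saves scores
    #   recency:    (30 - 6) / 30 * 0.4     = 0.32
    #   checkpoint: min(4 / 10, 1.0) * 0.3  = 0.12
    #   stability:  min(3 / 5, 1.0) * 0.3   = 0.18
    # for a total of 0.62, which _get_selection_reason labels "good_score".
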
    def _select_fallback_model(self, models: List[Tuple[str, Dict]]) -> Optional[Dict[str, Any]]:
        """Select a fallback model when no good candidates found"""
        if not models:
            return None

        # Sort by recency
        sorted_models = sorted(models, key=lambda x: x[1].get('last_saved', ''), reverse=True)
        model_name, model_info = sorted_models[0]

        return {
            'name': model_name,
            'type': model_info.get('type', 'unknown'),
            'info': model_info,
            'score': 0.0,
            'selection_reason': 'fallback_most_recent',
            'verified': False
        }

    def _verify_model_loadable(self, model_name: str, model_type: str) -> bool:
        """Verify that a model can be loaded successfully"""
        try:
            model = load_model(model_name, model_type)
            return model is not None
        except Exception as e:
            logger.warning(f"Model verification failed for {model_name}: {e}")
            return False

    def _get_selection_reason(self, candidate: Dict[str, Any]) -> str:
        """Get human-readable selection reason"""
        reasons = []

        if candidate.get('has_checkpoints'):
            reasons.append("has_checkpoints")

        score = candidate.get('score', 0)
        if score > 0.8:
            reasons.append("high_score")
        elif score > 0.6:
            reasons.append("good_score")
        else:
            reasons.append("acceptable_score")

        return ", ".join(reasons) if reasons else "default_selection"

    def load_selected_models(self, selected_models: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
        """
        Load the selected models into memory.

        Args:
            selected_models: Dictionary from select_best_models_for_startup

        Returns:
            Dictionary of loaded models
        """
        loaded_models = {}

        for model_type, selection_info in selected_models.items():
            model_name = selection_info['name']

            logger.info(f"Loading {model_type} model: {model_name}")

            try:
                # Try to load best checkpoint first if available
                if selection_info['info'].get('checkpoint_count', 0) > 0:
                    checkpoint_result = load_best_checkpoint(model_name, model_type)
                    if checkpoint_result:
                        checkpoint_path, checkpoint_data = checkpoint_result
                        loaded_models[model_type] = {
                            'model': None,  # Would need proper model class instantiation
                            'checkpoint_data': checkpoint_data,
                            'source': 'checkpoint',
                            'path': checkpoint_path,
                            'performance_score': checkpoint_data.get('performance_score', 0)
                        }
                        logger.info(f"Loaded {model_type} from checkpoint: {checkpoint_path}")
                        continue

                # Fall back to regular model loading
                model = load_model(model_name, model_type)
                if model:
                    loaded_models[model_type] = {
                        'model': model,
                        'source': 'latest',
                        'path': selection_info['info'].get('latest_path'),
                        'performance_score': None
                    }
                    logger.info(f"Loaded {model_type} from latest: {model_name}")
                else:
                    logger.error(f"Failed to load {model_type} model: {model_name}")

            except Exception as e:
                logger.error(f"Error loading {model_type} model {model_name}: {e}")

        return loaded_models

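    # Note: the 'model': None placeholder in the checkpoint branch above is deliberate;
    # turning checkpoint_data into a live model needs the concrete model class, which this
    # module does not know. A minimal sketch, assuming the checkpoint dict carries a
    # 'model_state_dict' entry and the caller supplies model_cls (both are assumptions):
    #
    #     model = model_cls()
    #     model.load_state_dict(checkpoint_data['model_state_dict'])
    #     model.eval()
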
    def get_startup_report(self, selected_models: Dict[str, Dict[str, Any]],
                           loaded_models: Dict[str, Any]) -> str:
        """Generate a startup report"""
        report_lines = [
            "=" * 60,
            "MODEL STARTUP SELECTION REPORT",
            "=" * 60,
            f"Selection Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
            ""
        ]

        if selected_models:
            report_lines.append("SELECTED MODELS:")
            for model_type, selection_info in selected_models.items():
                report_lines.append(f"  {model_type.upper()}: {selection_info['name']}")
                report_lines.append(f"    - Score: {selection_info.get('score', 0):.3f}")
                report_lines.append(f"    - Reason: {selection_info.get('selection_reason', 'unknown')}")
                report_lines.append(f"    - Verified: {selection_info.get('verified', False)}")
                report_lines.append(f"    - Last Saved: {selection_info['info'].get('last_saved', 'unknown')}")
            report_lines.append("")
        else:
            report_lines.append("NO MODELS SELECTED")
            report_lines.append("")

        if loaded_models:
            report_lines.append("LOADED MODELS:")
            for model_type, model_info in loaded_models.items():
                source = model_info.get('source', 'unknown')
                report_lines.append(f"  {model_type.upper()}: Loaded from {source}")
                if 'performance_score' in model_info and model_info['performance_score'] is not None:
                    report_lines.append(f"    - Performance Score: {model_info['performance_score']:.3f}")
            report_lines.append("")
        else:
            report_lines.append("NO MODELS LOADED")
            report_lines.append("")

        # Add summary statistics
        total_models = len(self.registry.list_models())
        selected_count = len(selected_models)
        loaded_count = len(loaded_models)

        report_lines.extend([
            "SUMMARY STATISTICS:",
            f"  Total Available Models: {total_models}",
            f"  Models Selected: {selected_count}",
            f"  Models Loaded: {loaded_count}",
            "=" * 60
        ])

        return "\n".join(report_lines)

# Global instance
_model_selector = None

def get_model_selector() -> ModelSelector:
    """Get the global model selector instance"""
    global _model_selector
    if _model_selector is None:
        _model_selector = ModelSelector()
    return _model_selector

def select_and_load_best_models() -> Tuple[Dict[str, Dict[str, Any]], Dict[str, Any]]:
    """
    Convenience function to select and load best models for startup.

    Returns:
        Tuple of (selected_models_info, loaded_models)
    """
    selector = get_model_selector()

    # Select best models
    selected_models = selector.select_best_models_for_startup()

    # Load selected models
    loaded_models = selector.load_selected_models(selected_models)

    # Generate and log report
    report = selector.get_startup_report(selected_models, loaded_models)
    logger.info("Model Startup Report:\n" + report)

    return selected_models, loaded_models
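Typical usage mirrors the main() change above; a minimal sketch, assuming logging is already configured and utils/ is on the import path:

from utils.model_selector import get_model_selector, select_and_load_best_models

# One-call startup path (selects, loads, and logs the startup report)
selected_models, loaded_models = select_and_load_best_models()

# Or drive the selector directly for more control
selector = get_model_selector()
selected = selector.select_best_models_for_startup()
loaded = selector.load_selected_models(selected)
print(selector.get_startup_report(selected, loaded))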