#!/usr/bin/env python3 """ Model Loading/Saving Audit Test This script tests the model registry and saving/loading mechanisms to identify any issues and provide recommendations. """ import os import sys import logging import torch import torch.nn as nn from datetime import datetime from pathlib import Path # Add project root to path sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) from utils.model_registry import get_model_registry, save_model, load_model, save_checkpoint # Configure logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) logger = logging.getLogger(__name__) class SimpleTestModel(nn.Module): """Simple neural network for testing""" def __init__(self, input_size=10, hidden_size=32, output_size=2): super().__init__() self.net = nn.Sequential( nn.Linear(input_size, hidden_size), nn.ReLU(), nn.Linear(hidden_size, output_size) ) def forward(self, x): return self.net(x) def test_model_registry(): """Test the model registry functionality""" logger.info("=== MODEL REGISTRY AUDIT ===") registry = get_model_registry() logger.info(f"Registry base directory: {registry.base_dir}") logger.info(f"Registry metadata file: {registry.metadata_file}") # Check existing models existing_models = registry.list_models() logger.info(f"Existing models: {existing_models}") # Test model creation and saving logger.info("Creating test model...") test_model = SimpleTestModel() # Generate some fake training data test_input = torch.randn(32, 10) test_output = test_model(test_input) logger.info(f"Test model created. Input shape: {test_input.shape}, Output shape: {test_output.shape}") # Test saving with different methods logger.info("Testing model saving...") # Test 1: Save with unified registry success = save_model( model=test_model, model_name="audit_test_model", model_type="cnn", metadata={ "test_type": "registry_audit", "created_at": datetime.now().isoformat(), "input_shape": list(test_input.shape), "output_shape": list(test_output.shape) } ) if success: logger.info("✅ Model saved successfully with unified registry") else: logger.error("❌ Failed to save model with unified registry") # Test 2: Load model back logger.info("Testing model loading...") loaded_model = load_model("audit_test_model", "cnn") if loaded_model is not None: logger.info("✅ Model loaded successfully") # Test if loaded model has proper structure if hasattr(loaded_model, 'state_dict') and callable(loaded_model.state_dict): state_dict = loaded_model.state_dict() logger.info(f"Loaded model test - State dict keys: {list(state_dict.keys())}") # Check if we can create a new instance and load the state fresh_model = SimpleTestModel() try: fresh_model.load_state_dict(state_dict) test_output_loaded = fresh_model(test_input) logger.info(f"Loaded model test - Output shape: {test_output_loaded.shape}") # Compare outputs (should be identical) if torch.allclose(test_output, test_output_loaded, atol=1e-6): logger.info("✅ Loaded model produces identical outputs") else: logger.warning("⚠️ Loaded model outputs differ (this might be expected due to different random states)") except Exception as e: logger.warning(f"Could not test loaded model: {e}") else: logger.warning("Loaded model does not have proper structure") else: logger.error("❌ Failed to load model") # Test 3: Save checkpoint logger.info("Testing checkpoint saving...") checkpoint_success = save_checkpoint( model=test_model, model_name="audit_test_model", model_type="cnn", performance_score=0.85, metadata={ "checkpoint_test": True, "performance_metric": "accuracy", "epoch": 1 } ) if checkpoint_success: logger.info("✅ Checkpoint saved successfully") else: logger.error("❌ Failed to save checkpoint") # Check registry metadata after operations logger.info("Checking registry metadata after operations...") updated_models = registry.list_models() logger.info(f"Updated models: {updated_models}") # Check file system logger.info("Checking file system...") models_dir = Path("models") if models_dir.exists(): logger.info(f"Models directory contents:") for item in models_dir.rglob("*"): if item.is_file(): logger.info(f" {item.relative_to(models_dir)} ({item.stat().st_size} bytes)") return { "registry_save_success": success, "registry_load_success": loaded_model is not None, "checkpoint_success": checkpoint_success, "existing_models": existing_models, "updated_models": updated_models } def audit_model_metadata(): """Audit the model metadata structure""" logger.info("=== MODEL METADATA AUDIT ===") registry = get_model_registry() # Check metadata structure metadata = registry.metadata logger.info(f"Metadata keys: {list(metadata.keys())}") if 'models' in metadata: models = metadata['models'] logger.info(f"Number of registered models: {len(models)}") for model_name, model_data in models.items(): logger.info(f"Model '{model_name}':") logger.info(f" - Type: {model_data.get('type', 'unknown')}") logger.info(f" - Last saved: {model_data.get('last_saved', 'never')}") logger.info(f" - Save count: {model_data.get('save_count', 0)}") logger.info(f" - Latest path: {model_data.get('latest_path', 'none')}") logger.info(f" - Checkpoints: {len(model_data.get('checkpoints', []))}") if 'last_updated' in metadata: logger.info(f"Last metadata update: {metadata['last_updated']}") return metadata def analyze_model_files(): """Analyze the model files on disk""" logger.info("=== MODEL FILES ANALYSIS ===") models_dir = Path("models") if not models_dir.exists(): logger.error("Models directory does not exist") return {} analysis = { 'total_files': 0, 'total_size': 0, 'by_type': {}, 'by_model': {}, 'orphaned_files': [], 'missing_files': [] } # Analyze all .pt files for pt_file in models_dir.rglob("*.pt"): analysis['total_files'] += 1 analysis['total_size'] += pt_file.stat().st_size # Categorize by type parts = pt_file.parts model_type = "unknown" if "cnn" in parts: model_type = "cnn" elif "dqn" in parts: model_type = "dqn" elif "transformer" in parts: model_type = "transformer" elif "hybrid" in parts: model_type = "hybrid" if model_type not in analysis['by_type']: analysis['by_type'][model_type] = [] analysis['by_type'][model_type].append(str(pt_file)) # Try to extract model name filename = pt_file.name if "_latest" in filename: model_name = filename.replace("_latest.pt", "") elif "_" in filename: # Extract timestamp-based names parts = filename.split("_") if len(parts) >= 2: model_name = "_".join(parts[:-1]) # Everything except timestamp else: model_name = filename.replace(".pt", "") else: model_name = filename.replace(".pt", "") if model_name not in analysis['by_model']: analysis['by_model'][model_name] = [] analysis['by_model'][model_name].append(str(pt_file)) logger.info(f"Total model files: {analysis['total_files']}") logger.info(f"Total size: {analysis['total_size'] / (1024*1024):.2f} MB") logger.info("Files by type:") for model_type, files in analysis['by_type'].items(): logger.info(f" {model_type}: {len(files)} files") logger.info("Files by model:") for model_name, files in analysis['by_model'].items(): logger.info(f" {model_name}: {len(files)} files") return analysis def recommend_best_model_selection(): """Provide recommendations for best model selection at startup""" logger.info("=== BEST MODEL SELECTION RECOMMENDATIONS ===") registry = get_model_registry() models = registry.list_models() recommendations = { 'startup_strategy': 'hybrid', 'fallback_models': [], 'performance_criteria': [], 'metadata_requirements': [] } if models: logger.info("Available models for selection:") # Analyze each model type for model_name, model_info in models.items(): model_type = model_info.get('type', 'unknown') logger.info(f" {model_name} ({model_type}) - last saved: {model_info.get('last_saved', 'unknown')}") # Check if checkpoints exist if 'checkpoint_count' in model_info and model_info['checkpoint_count'] > 0: logger.info(f" - Has {model_info['checkpoint_count']} checkpoints") recommendations['fallback_models'].append(model_name) # Recommendations logger.info("RECOMMENDATIONS:") logger.info("1. Startup Strategy:") logger.info(" - Try to load latest model for each type") logger.info(" - Fall back to checkpoints if latest model fails") logger.info(" - Use fallback to basic/default model if all else fails") logger.info("2. Performance-based Selection:") logger.info(" - For models with checkpoints, select highest performance_score") logger.info(" - Track model age and prefer recently trained models") logger.info(" - Implement model validation on startup") logger.info("3. Metadata Requirements:") logger.info(" - Store performance metrics in metadata") logger.info(" - Track training data quality and size") logger.info(" - Include model validation results") else: logger.info("No models registered - system will need initial training") logger.info("RECOMMENDATION: Implement default model initialization") return recommendations def main(): """Main audit function""" logger.info("Starting Model Loading/Saving Audit") logger.info("=" * 60) try: # Test model registry registry_results = test_model_registry() logger.info("-" * 40) # Audit metadata metadata = audit_model_metadata() logger.info("-" * 40) # Analyze files file_analysis = analyze_model_files() logger.info("-" * 40) # Recommendations recommendations = recommend_best_model_selection() logger.info("-" * 40) # Summary logger.info("=== AUDIT SUMMARY ===") logger.info(f"Registry save success: {registry_results.get('registry_save_success', False)}") logger.info(f"Registry load success: {registry_results.get('registry_load_success', False)}") logger.info(f"Checkpoint success: {registry_results.get('checkpoint_success', False)}") logger.info(f"Total model files: {file_analysis.get('total_files', 0)}") logger.info(f"Registered models: {len(registry_results.get('existing_models', {}))}") logger.info("Audit completed successfully!") except Exception as e: logger.error(f"Audit failed with error: {e}") import traceback traceback.print_exc() if __name__ == "__main__": main()