#!/usr/bin/env python3
"""
Model Loading/Saving Audit Test
This script tests the model registry and saving/loading mechanisms
to identify any issues and provide recommendations.
"""
import os
import sys
import logging
import torch
import torch.nn as nn
from datetime import datetime
from pathlib import Path

# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from utils.model_registry import get_model_registry, save_model, load_model, save_checkpoint

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class SimpleTestModel(nn.Module):
    """Simple neural network for testing"""

    def __init__(self, input_size=10, hidden_size=32, output_size=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size)
        )

    def forward(self, x):
        return self.net(x)


def test_model_registry():
    """Test the model registry functionality"""
    logger.info("=== MODEL REGISTRY AUDIT ===")

    registry = get_model_registry()
    logger.info(f"Registry base directory: {registry.base_dir}")
    logger.info(f"Registry metadata file: {registry.metadata_file}")

    # Check existing models
    existing_models = registry.list_models()
    logger.info(f"Existing models: {existing_models}")

    # Test model creation and saving
    logger.info("Creating test model...")
    test_model = SimpleTestModel()

    # Generate some fake training data
    test_input = torch.randn(32, 10)
    test_output = test_model(test_input)
    logger.info(f"Test model created. Input shape: {test_input.shape}, Output shape: {test_output.shape}")

    # Test saving with different methods
    logger.info("Testing model saving...")

    # Test 1: Save with unified registry
    success = save_model(
        model=test_model,
        model_name="audit_test_model",
        model_type="cnn",
        metadata={
            "test_type": "registry_audit",
            "created_at": datetime.now().isoformat(),
            "input_shape": list(test_input.shape),
            "output_shape": list(test_output.shape)
        }
    )

    if success:
        logger.info("✅ Model saved successfully with unified registry")
    else:
        logger.error("❌ Failed to save model with unified registry")

    # Test 2: Load model back
    logger.info("Testing model loading...")
    loaded_model = load_model("audit_test_model", "cnn")

    if loaded_model is not None:
        logger.info("✅ Model loaded successfully")

        # Test if loaded model has proper structure
        if hasattr(loaded_model, 'state_dict') and callable(loaded_model.state_dict):
            state_dict = loaded_model.state_dict()
            logger.info(f"Loaded model test - State dict keys: {list(state_dict.keys())}")

            # Check if we can create a new instance and load the state
            fresh_model = SimpleTestModel()
            try:
                fresh_model.load_state_dict(state_dict)
                test_output_loaded = fresh_model(test_input)
                logger.info(f"Loaded model test - Output shape: {test_output_loaded.shape}")

                # Compare outputs (should be identical)
                if torch.allclose(test_output, test_output_loaded, atol=1e-6):
                    logger.info("✅ Loaded model produces identical outputs")
                else:
                    logger.warning("⚠️ Loaded model outputs differ - the saved weights may not match the original model")
            except Exception as e:
                logger.warning(f"Could not test loaded model: {e}")
        else:
            logger.warning("Loaded model does not have proper structure")
    else:
        logger.error("❌ Failed to load model")

    # Test 3: Save checkpoint
    logger.info("Testing checkpoint saving...")
    checkpoint_success = save_checkpoint(
        model=test_model,
        model_name="audit_test_model",
        model_type="cnn",
        performance_score=0.85,
        metadata={
            "checkpoint_test": True,
            "performance_metric": "accuracy",
            "epoch": 1
        }
    )

    if checkpoint_success:
        logger.info("✅ Checkpoint saved successfully")
    else:
        logger.error("❌ Failed to save checkpoint")

    # Check registry metadata after operations
    logger.info("Checking registry metadata after operations...")
    updated_models = registry.list_models()
    logger.info(f"Updated models: {updated_models}")

    # Check file system
    logger.info("Checking file system...")
    models_dir = Path("models")
    if models_dir.exists():
        logger.info("Models directory contents:")
        for item in models_dir.rglob("*"):
            if item.is_file():
                logger.info(f" {item.relative_to(models_dir)} ({item.stat().st_size} bytes)")

    return {
        "registry_save_success": success,
        "registry_load_success": loaded_model is not None,
        "checkpoint_success": checkpoint_success,
        "existing_models": existing_models,
        "updated_models": updated_models
    }


def audit_model_metadata():
    """Audit the model metadata structure"""
    logger.info("=== MODEL METADATA AUDIT ===")

    registry = get_model_registry()

    # Check metadata structure
    metadata = registry.metadata
    logger.info(f"Metadata keys: {list(metadata.keys())}")

    if 'models' in metadata:
        models = metadata['models']
        logger.info(f"Number of registered models: {len(models)}")

        for model_name, model_data in models.items():
            logger.info(f"Model '{model_name}':")
            logger.info(f" - Type: {model_data.get('type', 'unknown')}")
            logger.info(f" - Last saved: {model_data.get('last_saved', 'never')}")
            logger.info(f" - Save count: {model_data.get('save_count', 0)}")
            logger.info(f" - Latest path: {model_data.get('latest_path', 'none')}")
            logger.info(f" - Checkpoints: {len(model_data.get('checkpoints', []))}")

    if 'last_updated' in metadata:
        logger.info(f"Last metadata update: {metadata['last_updated']}")

    return metadata


def analyze_model_files():
    """Analyze the model files on disk"""
    logger.info("=== MODEL FILES ANALYSIS ===")

    models_dir = Path("models")
    if not models_dir.exists():
        logger.error("Models directory does not exist")
        return {}

    analysis = {
        'total_files': 0,
        'total_size': 0,
        'by_type': {},
        'by_model': {},
        'orphaned_files': [],
        'missing_files': []
    }

    # Analyze all .pt files
    for pt_file in models_dir.rglob("*.pt"):
        analysis['total_files'] += 1
        analysis['total_size'] += pt_file.stat().st_size

        # Categorize by type
        parts = pt_file.parts
        model_type = "unknown"
        if "cnn" in parts:
            model_type = "cnn"
        elif "dqn" in parts:
            model_type = "dqn"
        elif "transformer" in parts:
            model_type = "transformer"
        elif "hybrid" in parts:
            model_type = "hybrid"

        if model_type not in analysis['by_type']:
            analysis['by_type'][model_type] = []
        analysis['by_type'][model_type].append(str(pt_file))

        # Try to extract model name
        filename = pt_file.name
        if "_latest" in filename:
            model_name = filename.replace("_latest.pt", "")
        elif "_" in filename:
            # Extract timestamp-based names
            parts = filename.split("_")
            if len(parts) >= 2:
                model_name = "_".join(parts[:-1])  # Everything except timestamp
            else:
                model_name = filename.replace(".pt", "")
        else:
            model_name = filename.replace(".pt", "")

        if model_name not in analysis['by_model']:
            analysis['by_model'][model_name] = []
        analysis['by_model'][model_name].append(str(pt_file))

    logger.info(f"Total model files: {analysis['total_files']}")
    logger.info(f"Total size: {analysis['total_size'] / (1024*1024):.2f} MB")

    logger.info("Files by type:")
    for model_type, files in analysis['by_type'].items():
        logger.info(f" {model_type}: {len(files)} files")

    logger.info("Files by model:")
    for model_name, files in analysis['by_model'].items():
        logger.info(f" {model_name}: {len(files)} files")

    return analysis
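

# Note: the 'orphaned_files' and 'missing_files' buckets above are initialized but never
# populated by analyze_model_files(). A minimal sketch of one way to fill them, assuming
# registry.metadata stores a 'latest_path' per model (as logged in audit_model_metadata()):
def find_orphaned_and_missing(analysis, registry):
    """Cross-check files found on disk against paths recorded in the registry metadata."""
    registered_paths = set()
    for model_data in registry.metadata.get('models', {}).values():
        latest = model_data.get('latest_path')
        if latest:
            registered_paths.add(str(Path(latest)))
        for ckpt in model_data.get('checkpoints', []):
            # Checkpoint entries are assumed to be dicts with a 'path' key; adapt to the real schema.
            path = ckpt.get('path') if isinstance(ckpt, dict) else ckpt
            if path:
                registered_paths.add(str(Path(path)))

    disk_paths = {p for files in analysis['by_type'].values() for p in files}
    analysis['orphaned_files'] = sorted(disk_paths - registered_paths)  # on disk, not registered
    analysis['missing_files'] = sorted(registered_paths - disk_paths)   # registered, not on disk
    return analysis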


def recommend_best_model_selection():
    """Provide recommendations for best model selection at startup"""
    logger.info("=== BEST MODEL SELECTION RECOMMENDATIONS ===")

    registry = get_model_registry()
    models = registry.list_models()

    recommendations = {
        'startup_strategy': 'hybrid',
        'fallback_models': [],
        'performance_criteria': [],
        'metadata_requirements': []
    }

    if models:
        logger.info("Available models for selection:")

        # Analyze each model type
        for model_name, model_info in models.items():
            model_type = model_info.get('type', 'unknown')
            logger.info(f" {model_name} ({model_type}) - last saved: {model_info.get('last_saved', 'unknown')}")

            # Check if checkpoints exist
            if 'checkpoint_count' in model_info and model_info['checkpoint_count'] > 0:
                logger.info(f" - Has {model_info['checkpoint_count']} checkpoints")
                recommendations['fallback_models'].append(model_name)

        # Recommendations
        logger.info("RECOMMENDATIONS:")
        logger.info("1. Startup Strategy:")
        logger.info(" - Try to load latest model for each type")
        logger.info(" - Fall back to checkpoints if latest model fails")
        logger.info(" - Use fallback to basic/default model if all else fails")
        logger.info("2. Performance-based Selection:")
        logger.info(" - For models with checkpoints, select highest performance_score")
        logger.info(" - Track model age and prefer recently trained models")
        logger.info(" - Implement model validation on startup")
        logger.info("3. Metadata Requirements:")
        logger.info(" - Store performance metrics in metadata")
        logger.info(" - Track training data quality and size")
        logger.info(" - Include model validation results")
    else:
        logger.info("No models registered - system will need initial training")
        logger.info("RECOMMENDATION: Implement default model initialization")

    return recommendations
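

# Illustrative sketch (not wired into main()): one way to apply the startup strategy
# recommended above. It assumes, as the tests in this script do, that load_model()
# returns None when no usable model is found; default_factory is a hypothetical hook
# for constructing a fresh/basic model as the final fallback.
def load_best_or_default(model_name, model_type, default_factory=None):
    """Try the latest saved model first, then fall back to a default instance."""
    model = load_model(model_name, model_type)
    if model is not None:
        logger.info(f"Loaded latest saved model for {model_name} ({model_type})")
        return model

    # A checkpoint-based fallback would slot in here; it is omitted because this script
    # only imports save_checkpoint and the checkpoint-loading API is not shown.
    logger.warning(f"No usable saved model for {model_name} ({model_type}); falling back to default")
    return default_factory() if default_factory is not None else None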


def main():
    """Main audit function"""
    logger.info("Starting Model Loading/Saving Audit")
    logger.info("=" * 60)

    try:
        # Test model registry
        registry_results = test_model_registry()
        logger.info("-" * 40)

        # Audit metadata
        metadata = audit_model_metadata()
        logger.info("-" * 40)

        # Analyze files
        file_analysis = analyze_model_files()
        logger.info("-" * 40)

        # Recommendations
        recommendations = recommend_best_model_selection()
        logger.info("-" * 40)

        # Summary
        logger.info("=== AUDIT SUMMARY ===")
        logger.info(f"Registry save success: {registry_results.get('registry_save_success', False)}")
        logger.info(f"Registry load success: {registry_results.get('registry_load_success', False)}")
        logger.info(f"Checkpoint success: {registry_results.get('checkpoint_success', False)}")
        logger.info(f"Total model files: {file_analysis.get('total_files', 0)}")
        logger.info(f"Registered models: {len(registry_results.get('existing_models', {}))}")

        logger.info("Audit completed successfully!")

    except Exception as e:
        logger.error(f"Audit failed with error: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()