massive cleanup
utils/model_utils.py (new file, 241 lines)
@@ -0,0 +1,241 @@
#!/usr/bin/env python
"""
Model utilities for robust saving and loading of PyTorch models
"""

import os
import logging
import torch
import shutil
import gc
import json
from typing import Any, Dict, Optional, Union

logger = logging.getLogger(__name__)

def robust_save(model: Any, path: str, include_optimizer: bool = True) -> bool:
    """
    Robust model saving with multiple fallback approaches

    Args:
        model: The model object to save (should have policy_net, target_net, optimizer, epsilon attributes)
        path: Path to save the model
        include_optimizer: Whether to include optimizer state in the save

    Returns:
        bool: True if successful, False otherwise
    """
    # Create directory if it doesn't exist
    os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)

    # Backup path in case the main save fails
    backup_path = f"{path}.backup"

    # Clean up GPU memory before saving
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()

    # Prepare checkpoint data
    checkpoint = {
        'policy_net': model.policy_net.state_dict(),
        'target_net': model.target_net.state_dict(),
        'epsilon': getattr(model, 'epsilon', 0.0),
        'state_size': getattr(model, 'state_size', None),
        'action_size': getattr(model, 'action_size', None),
        'hidden_size': getattr(model, 'hidden_size', None),
    }

    # Add optimizer state if requested and available
    if include_optimizer and hasattr(model, 'optimizer') and model.optimizer is not None:
        checkpoint['optimizer'] = model.optimizer.state_dict()

    # Attempt 1: Try with default settings in a separate file first
    try:
        logger.info(f"Saving model to {backup_path} (attempt 1)")
        torch.save(checkpoint, backup_path)
        logger.info(f"Successfully saved to {backup_path}")

        # If backup worked, copy to the actual path
        if os.path.exists(backup_path):
            shutil.copy(backup_path, path)
            logger.info(f"Copied backup to {path}")
            return True
    except Exception as e:
        logger.warning(f"First save attempt failed: {e}")

    # Attempt 2: Try with pickle protocol 2 (more compatible)
    try:
        logger.info(f"Saving model to {path} (attempt 2 - pickle protocol 2)")
        torch.save(checkpoint, path, pickle_protocol=2)
        logger.info(f"Successfully saved to {path} with pickle_protocol=2")
        return True
    except Exception as e:
        logger.warning(f"Second save attempt failed: {e}")

    # Attempt 3: Try without optimizer state (which can be large and cause issues)
    try:
        logger.info(f"Saving model to {path} (attempt 3 - without optimizer)")
        checkpoint_no_opt = {k: v for k, v in checkpoint.items() if k != 'optimizer'}
        torch.save(checkpoint_no_opt, path)
        logger.info(f"Successfully saved to {path} without optimizer state")
        return True
    except Exception as e:
        logger.warning(f"Third save attempt failed: {e}")

    # Attempt 4: Try with torch.jit.save instead
    try:
        logger.info(f"Saving model to {path} (attempt 4 - with jit.save)")
        # Save policy network using jit
        scripted_policy = torch.jit.script(model.policy_net)
        torch.jit.save(scripted_policy, f"{path}.policy.jit")

        # Save target network using jit
        scripted_target = torch.jit.script(model.target_net)
        torch.jit.save(scripted_target, f"{path}.target.jit")

        # Save parameters separately as JSON
        params = {
            'epsilon': float(getattr(model, 'epsilon', 0.0)),
            'state_size': int(getattr(model, 'state_size', 0)),
            'action_size': int(getattr(model, 'action_size', 0)),
            'hidden_size': int(getattr(model, 'hidden_size', 0))
        }
        with open(f"{path}.params.json", "w") as f:
            json.dump(params, f)

        logger.info("Successfully saved model components with jit.save")
        return True
    except Exception as e:
        logger.error(f"All save attempts failed: {e}")
        return False

def robust_load(model: Any, path: str, device: Optional[torch.device] = None) -> bool:
    """
    Robust model loading with fallback approaches

    Args:
        model: The model object to load into
        path: Path to load the model from
        device: Device to load the model on

    Returns:
        bool: True if successful, False otherwise
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Try regular PyTorch load first
    try:
        logger.info(f"Loading model from {path}")
        if os.path.exists(path):
            checkpoint = torch.load(path, map_location=device)

            # Load network states
            if 'policy_net' in checkpoint:
                model.policy_net.load_state_dict(checkpoint['policy_net'])
            if 'target_net' in checkpoint:
                model.target_net.load_state_dict(checkpoint['target_net'])

            # Load other attributes
            if 'epsilon' in checkpoint:
                model.epsilon = checkpoint['epsilon']
            if 'optimizer' in checkpoint and hasattr(model, 'optimizer') and model.optimizer is not None:
                try:
                    model.optimizer.load_state_dict(checkpoint['optimizer'])
                except Exception as e:
                    logger.warning(f"Failed to load optimizer state: {e}")

            logger.info("Successfully loaded model")
            return True
    except Exception as e:
        logger.warning(f"Regular load failed: {e}")

    # Try loading JIT saved components
    try:
        policy_path = f"{path}.policy.jit"
        target_path = f"{path}.target.jit"
        params_path = f"{path}.params.json"

        if all(os.path.exists(p) for p in [policy_path, target_path, params_path]):
            logger.info("Loading JIT model components")

            # Load JIT models (this is more complex and may need model reconstruction)
            # For now, just log that we found JIT files
            logger.info("Found JIT model files, but loading them requires special handling")
            with open(params_path, 'r') as f:
                params = json.load(f)
            logger.info(f"Model parameters: {params}")

            # Note: Actually loading JIT models would require recreating the model architecture
            # This is a placeholder for future implementation
            return False
    except Exception as e:
        logger.error(f"JIT load failed: {e}")

    logger.error(f"All load attempts failed for {path}")
    return False

def get_model_info(path: str) -> Dict[str, Any]:
    """
    Get information about a saved model

    Args:
        path: Path to the model file

    Returns:
        dict: Model information
    """
    info = {
        'exists': False,
        'size_bytes': 0,
        'has_optimizer': False,
        'parameters': {}
    }

    try:
        if os.path.exists(path):
            info['exists'] = True
            info['size_bytes'] = os.path.getsize(path)

            # Try to load and inspect
            checkpoint = torch.load(path, map_location='cpu')
            info['has_optimizer'] = 'optimizer' in checkpoint

            # Extract parameter info
            for key in ['epsilon', 'state_size', 'action_size', 'hidden_size']:
                if key in checkpoint:
                    info['parameters'][key] = checkpoint[key]
    except Exception as e:
        logger.warning(f"Failed to get model info for {path}: {e}")

    return info

def verify_save_load_cycle(model: Any, test_path: str) -> bool:
    """
    Test that a model can be saved and loaded correctly

    Args:
        model: Model to test
        test_path: Path for test file

    Returns:
        bool: True if save/load cycle successful
    """
    try:
        # Save the model
        if not robust_save(model, test_path):
            return False

        # Create a new model instance (this would need model creation logic)
        # For now, just verify the file exists and has content
        if os.path.exists(test_path) and os.path.getsize(test_path) > 0:
            logger.info("Save/load cycle verification successful")
            # Clean up test file
            os.remove(test_path)
            return True
        else:
            return False
    except Exception as e:
        logger.error(f"Save/load cycle verification failed: {e}")
        return False
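
Example usage, as a minimal sketch: it assumes the repository root is on PYTHONPATH so utils.model_utils is importable, and it assumes a DQN-style agent object exposing policy_net, target_net, optimizer, and epsilon. TinyAgent and the checkpoint path below are hypothetical illustrations, not part of this module.

import torch
import torch.nn as nn

from utils.model_utils import robust_save, robust_load, get_model_info


class TinyAgent:
    """Hypothetical stand-in for the DQN-style agent these utilities expect."""

    def __init__(self, state_size=4, action_size=2, hidden_size=16):
        self.state_size = state_size
        self.action_size = action_size
        self.hidden_size = hidden_size
        self.policy_net = nn.Sequential(
            nn.Linear(state_size, hidden_size), nn.ReLU(), nn.Linear(hidden_size, action_size)
        )
        self.target_net = nn.Sequential(
            nn.Linear(state_size, hidden_size), nn.ReLU(), nn.Linear(hidden_size, action_size)
        )
        self.optimizer = torch.optim.Adam(self.policy_net.parameters())
        self.epsilon = 0.1


agent = TinyAgent()
if robust_save(agent, "checkpoints/dqn_latest.pt"):
    # Inspect what was written, then restore it into the same (or a fresh) agent.
    print(get_model_info("checkpoints/dqn_latest.pt"))
    robust_load(agent, "checkpoints/dqn_latest.pt")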