241 lines
8.5 KiB
Python
241 lines
8.5 KiB
Python
#!/usr/bin/env python
|
|
"""
|
|
Model utilities for robust saving and loading of PyTorch models
|
|
"""
|
|
|
|
import os
|
|
import logging
|
|
import torch
|
|
import shutil
|
|
import gc
|
|
import json
|
|
from typing import Any, Dict, Optional, Union
|
|
|
|
# Module-level logger named after this module; every function below reports
# progress and failures through it instead of printing.
logger = logging.getLogger(__name__)
|
|
|
|
def _atomic_write(path: str, write_fn: Any) -> None:
    """
    Create ``path`` by writing a temporary sibling file and renaming it into place

    Args:
        path: Final destination file
        write_fn: Callable that receives the temporary file name and writes the complete content there

    Raises:
        Exception: Whatever ``write_fn`` or ``os.replace`` raises
    """
    tmp_path = f"{path}.tmp"
    try:
        write_fn(tmp_path)
        # The rename happens only after the content is fully written, so an
        # existing file at ``path`` is never left half-overwritten.
        os.replace(tmp_path, path)
    finally:
        # After a successful replace the temp file no longer exists; this only
        # removes the leftovers of a failed attempt.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def _int_or_zero(value: Any) -> int:
    """
    Convert ``value`` to int, mapping None (size not set on the model) to 0
    """
    return 0 if value is None else int(value)


def robust_save(model: Any, path: str, include_optimizer: bool = True) -> bool:
    """
    Robust model saving with multiple fallback approaches

    The attempts are, in order:
      1. ``torch.save`` to ``<path>.backup``, then copy the backup onto ``path``
      2. ``torch.save`` with ``pickle_protocol=2``
      3. ``torch.save`` without the optimizer state
      4. TorchScript export to ``<path>.policy.jit`` / ``<path>.target.jit`` plus
         a ``<path>.params.json`` file (``path`` itself is not written)

    Attempts 1-3 put the data at ``path`` through a temporary ``<path>.tmp`` file
    and a rename, so a failed attempt does not damage a checkpoint that already
    exists at ``path``.

    Args:
        model: The model object to save (should have policy_net, target_net, optimizer, epsilon attributes)
        path: Path to save the model
        include_optimizer: Whether to include optimizer state in the save

    Returns:
        bool: True if successful, False otherwise
    """
    # Create directory if it doesn't exist
    os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)

    # Backup path in case the main save fails
    backup_path = f"{path}.backup"

    # Clean up GPU memory before saving
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    # Prepare checkpoint data
    checkpoint = {
        'policy_net': model.policy_net.state_dict(),
        'target_net': model.target_net.state_dict(),
        'epsilon': getattr(model, 'epsilon', 0.0),
        'state_size': getattr(model, 'state_size', None),
        'action_size': getattr(model, 'action_size', None),
        'hidden_size': getattr(model, 'hidden_size', None),
    }

    # Add optimizer state if requested and available
    if include_optimizer and hasattr(model, 'optimizer') and model.optimizer is not None:
        checkpoint['optimizer'] = model.optimizer.state_dict()

    # Attempt 1: Try with default settings in a separate file first
    try:
        logger.info(f"Saving model to {backup_path} (attempt 1)")
        torch.save(checkpoint, backup_path)
        logger.info(f"Successfully saved to {backup_path}")

        # The backup is kept on disk; path receives a copy of it
        _atomic_write(path, lambda tmp_path: shutil.copy(backup_path, tmp_path))
        logger.info(f"Copied backup to {path}")
        return True
    except Exception as e:
        logger.warning(f"First save attempt failed: {e}")

    # Attempt 2: Try with pickle protocol 2 (more compatible)
    try:
        logger.info(f"Saving model to {path} (attempt 2 - pickle protocol 2)")
        _atomic_write(path, lambda tmp_path: torch.save(checkpoint, tmp_path, pickle_protocol=2))
        logger.info(f"Successfully saved to {path} with pickle_protocol=2")
        return True
    except Exception as e:
        logger.warning(f"Second save attempt failed: {e}")

    # Attempt 3: Try without optimizer state (which can be large and cause issues)
    try:
        logger.info(f"Saving model to {path} (attempt 3 - without optimizer)")
        checkpoint_no_opt = {k: v for k, v in checkpoint.items() if k != 'optimizer'}
        _atomic_write(path, lambda tmp_path: torch.save(checkpoint_no_opt, tmp_path))
        logger.info(f"Successfully saved to {path} without optimizer state")
        return True
    except Exception as e:
        logger.warning(f"Third save attempt failed: {e}")

    # Attempt 4: Try with torch.jit.save instead
    try:
        logger.info(f"Saving model to {path} (attempt 4 - with jit.save)")
        # Save policy network using jit
        scripted_policy = torch.jit.script(model.policy_net)
        torch.jit.save(scripted_policy, f"{path}.policy.jit")

        # Save target network using jit
        scripted_target = torch.jit.script(model.target_net)
        torch.jit.save(scripted_target, f"{path}.target.jit")

        # Save parameters separately as JSON. The attributes may be missing or
        # explicitly None; both are stored as 0 so this step cannot fail after
        # the network files were already written.
        epsilon = getattr(model, 'epsilon', 0.0)
        params = {
            'epsilon': 0.0 if epsilon is None else float(epsilon),
            'state_size': _int_or_zero(getattr(model, 'state_size', None)),
            'action_size': _int_or_zero(getattr(model, 'action_size', None)),
            'hidden_size': _int_or_zero(getattr(model, 'hidden_size', None)),
        }
        with open(f"{path}.params.json", "w") as f:
            json.dump(params, f)

        logger.info("Successfully saved model components with jit.save")
        return True
    except Exception as e:
        logger.error(f"All save attempts failed: {e}")
        return False
|
|
|
|
def robust_load(model: Any, path: str, device: Optional[torch.device] = None) -> bool:
    """
    Robust model loading with fallback approaches

    A regular ``torch.load`` of ``path`` is tried first. The load only counts as
    successful when the checkpoint contains ``policy_net`` and/or ``target_net``;
    a file with neither is treated as a failed load and leaves ``model`` untouched.
    If the regular load does not succeed, the function looks for the
    ``<path>.policy.jit`` / ``<path>.target.jit`` / ``<path>.params.json`` files;
    restoring from them is not implemented, so that branch only logs what it
    found and returns False.

    Args:
        model: The model object to load into
        path: Path to load the model from
        device: Device to load the model on (defaults to CUDA when available, else CPU)

    Returns:
        bool: True if successful, False otherwise
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Try regular PyTorch load first
    try:
        logger.info(f"Loading model from {path}")
        if os.path.exists(path):
            # NOTE(security): torch.load unpickles the file, which can execute
            # arbitrary code on PyTorch versions/settings without
            # weights_only=True. Only load checkpoints from trusted sources.
            checkpoint = torch.load(path, map_location=device)

            # Validate before touching the model so a useless checkpoint is
            # not reported as a successful load.
            has_policy = 'policy_net' in checkpoint
            has_target = 'target_net' in checkpoint
            if not (has_policy or has_target):
                raise ValueError("checkpoint contains neither 'policy_net' nor 'target_net'")

            # Load network states
            if has_policy:
                model.policy_net.load_state_dict(checkpoint['policy_net'])
            if has_target:
                model.target_net.load_state_dict(checkpoint['target_net'])

            # Load other attributes
            if 'epsilon' in checkpoint:
                model.epsilon = checkpoint['epsilon']
            if 'optimizer' in checkpoint and hasattr(model, 'optimizer') and model.optimizer is not None:
                # Optimizer state is best-effort: the weights are already restored
                try:
                    model.optimizer.load_state_dict(checkpoint['optimizer'])
                except Exception as e:
                    logger.warning(f"Failed to load optimizer state: {e}")

            logger.info("Successfully loaded model")
            return True
    except Exception as e:
        logger.warning(f"Regular load failed: {e}")

    # Try loading JIT saved components
    try:
        policy_path = f"{path}.policy.jit"
        target_path = f"{path}.target.jit"
        params_path = f"{path}.params.json"

        if all(os.path.exists(p) for p in [policy_path, target_path, params_path]):
            logger.info("Loading JIT model components")

            # Load JIT models (this is more complex and may need model reconstruction)
            # For now, just log that we found JIT files
            logger.info("Found JIT model files, but loading them requires special handling")
            with open(params_path, 'r') as f:
                params = json.load(f)
            logger.info(f"Model parameters: {params}")

            # Note: Actually loading JIT models would require recreating the model architecture
            # This is a placeholder for future implementation
            return False
    except Exception as e:
        logger.error(f"JIT load failed: {e}")

    logger.error(f"All load attempts failed for {path}")
    return False
|
|
|
|
def get_model_info(path: str) -> Dict[str, Any]:
    """
    Summarize a saved checkpoint file

    Args:
        path: Path to the model file

    Returns:
        dict: Keys ``exists``, ``size_bytes``, ``has_optimizer`` and ``parameters``
        (whichever of epsilon / state_size / action_size / hidden_size the
        checkpoint holds). Fields that could not be determined keep their
        defaults; problems are logged as a warning rather than raised.
    """
    info = dict(exists=False, size_bytes=0, has_optimizer=False, parameters={})
    tracked_keys = ('epsilon', 'state_size', 'action_size', 'hidden_size')

    try:
        if not os.path.exists(path):
            return info

        info['exists'] = True
        info['size_bytes'] = os.path.getsize(path)

        # Open the checkpoint on CPU and look at what it contains
        contents = torch.load(path, map_location='cpu')
        info['has_optimizer'] = 'optimizer' in contents

        # Copy over the scalar settings that are present
        found = info['parameters']
        for name in tracked_keys:
            if name in contents:
                found[name] = contents[name]
    except Exception as e:
        logger.warning(f"Failed to get model info for {path}: {e}")

    return info
|
|
|
|
def verify_save_load_cycle(model: Any, test_path: str) -> bool:
    """
    Test that a model can be saved and loaded correctly

    The model is saved with robust_save, the file is read back with
    ``torch.load`` and the stored ``policy_net`` / ``target_net`` entries are
    compared against the live model (same parameter names and shapes). All
    files written for ``test_path`` are removed afterwards, whether or not the
    check passed.

    Args:
        model: Model to test
        test_path: Path for test file (it and its companion files are deleted afterwards)

    Returns:
        bool: True if save/load cycle successful
    """
    # Every file robust_save may create for test_path: the checkpoint, its
    # backup copy and the JIT fallback components.
    artifacts = [
        test_path,
        f"{test_path}.backup",
        f"{test_path}.policy.jit",
        f"{test_path}.target.jit",
        f"{test_path}.params.json",
    ]

    try:
        # Save the model
        if not robust_save(model, test_path):
            return False

        # A JIT-only fallback save does not produce test_path and cannot be
        # read back as a regular checkpoint.
        if not (os.path.exists(test_path) and os.path.getsize(test_path) > 0):
            return False

        # Read the checkpoint back and compare it with the live networks
        checkpoint = torch.load(test_path, map_location='cpu')
        for net_name in ('policy_net', 'target_net'):
            expected = getattr(model, net_name).state_dict()
            restored = checkpoint[net_name]
            if set(expected.keys()) != set(restored.keys()):
                logger.error(f"Save/load cycle verification failed: {net_name} keys differ")
                return False
            for key, value in expected.items():
                # Shapes are what load_state_dict needs to match; entries
                # without a shape (non-tensor state) compare as None == None.
                if getattr(value, 'shape', None) != getattr(restored[key], 'shape', None):
                    logger.error(f"Save/load cycle verification failed: shape mismatch for {net_name}.{key}")
                    return False

        logger.info("Save/load cycle verification successful")
        return True
    except Exception as e:
        logger.error(f"Save/load cycle verification failed: {e}")
        return False
    finally:
        # Clean up test files; a cleanup problem must not change the verdict
        for artifact in artifacts:
            try:
                if os.path.exists(artifact):
                    os.remove(artifact)
            except OSError as e:
                logger.warning(f"Could not remove test file {artifact}: {e}")