gogo2/utils/model_utils.py
2025-05-24 10:32:00 +03:00

241 lines
8.5 KiB
Python

#!/usr/bin/env python
"""
Model utilities for robust saving and loading of PyTorch models
"""
import os
import logging
import torch
import shutil
import gc
import json
from typing import Any, Dict, Optional, Union
logger = logging.getLogger(__name__)


def robust_save(model: Any, path: str, include_optimizer: bool = True) -> bool:
    """
    Save a model checkpoint, falling back through several strategies.

    The checkpoint is a dict with ``policy_net`` / ``target_net`` state dicts,
    ``epsilon``, ``state_size``, ``action_size``, ``hidden_size`` and, when
    requested and available, the ``optimizer`` state dict. Strategies are
    tried in order until one succeeds:

    1. ``torch.save`` to ``<path>.backup``, then copy that file to ``path``
       (the ``.backup`` file is left in place as a spare copy).
    2. ``torch.save`` directly to ``path`` with ``pickle_protocol=2``.
    3. ``torch.save`` directly to ``path`` without the optimizer state.
    4. ``torch.jit.script`` + ``torch.jit.save`` of both networks to
       ``<path>.policy.jit`` / ``<path>.target.jit``, with the scalar
       attributes written to ``<path>.params.json``. ``path`` itself is not
       written by this strategy.

    Args:
        model: Object exposing ``policy_net`` and ``target_net`` modules and,
            optionally, ``optimizer``, ``epsilon``, ``state_size``,
            ``action_size`` and ``hidden_size`` attributes.
        path: Destination file; missing parent directories are created.
        include_optimizer: Whether to include optimizer state in the save.

    Returns:
        bool: True if any strategy succeeded, False if all of them failed.
        Errors raised while building the checkpoint (for example a model
        without ``policy_net``) are not caught and propagate to the caller.
    """
    os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
    backup_path = f"{path}.backup"

    # Collect garbage first so tensors kept alive only by reference cycles are
    # freed before empty_cache() returns cached CUDA blocks to the driver.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    checkpoint = {
        'policy_net': model.policy_net.state_dict(),
        'target_net': model.target_net.state_dict(),
        'epsilon': getattr(model, 'epsilon', 0.0),
        'state_size': getattr(model, 'state_size', None),
        'action_size': getattr(model, 'action_size', None),
        'hidden_size': getattr(model, 'hidden_size', None),
    }
    if include_optimizer and getattr(model, 'optimizer', None) is not None:
        checkpoint['optimizer'] = model.optimizer.state_dict()

    # Attempt 1: save to a side file first, then copy it over the real path.
    try:
        logger.info(f"Saving model to {backup_path} (attempt 1)")
        torch.save(checkpoint, backup_path)
        logger.info(f"Successfully saved to {backup_path}")
        if os.path.exists(backup_path):
            shutil.copy(backup_path, path)
            logger.info(f"Copied backup to {path}")
            return True
    except Exception as e:
        logger.warning(f"First save attempt failed: {e}")

    # Attempt 2: pickle protocol 2 (more widely compatible).
    try:
        logger.info(f"Saving model to {path} (attempt 2 - pickle protocol 2)")
        torch.save(checkpoint, path, pickle_protocol=2)
        logger.info(f"Successfully saved to {path} with pickle_protocol=2")
        return True
    except Exception as e:
        logger.warning(f"Second save attempt failed: {e}")

    # Attempt 3: drop the optimizer state, which can be large or unpicklable.
    try:
        logger.info(f"Saving model to {path} (attempt 3 - without optimizer)")
        checkpoint_no_opt = {k: v for k, v in checkpoint.items() if k != 'optimizer'}
        torch.save(checkpoint_no_opt, path)
        logger.info(f"Successfully saved to {path} without optimizer state")
        return True
    except Exception as e:
        logger.warning(f"Third save attempt failed: {e}")

    # Attempt 4: TorchScript networks plus a JSON file of scalar attributes.
    try:
        logger.info(f"Saving model to {path} (attempt 4 - with jit.save)")
        # Build the JSON payload before writing any file so that a bad value
        # cannot leave half of the artifacts behind. Missing attributes fall
        # back to 0 / 0.0; attributes explicitly set to None (the checkpoint
        # above uses None as its default) are stored as JSON null instead of
        # crashing in int(None) / float(None).
        params = {}
        for key, cast, default in (
            ('epsilon', float, 0.0),
            ('state_size', int, 0),
            ('action_size', int, 0),
            ('hidden_size', int, 0),
        ):
            value = getattr(model, key, default)
            params[key] = None if value is None else cast(value)

        torch.jit.save(torch.jit.script(model.policy_net), f"{path}.policy.jit")
        torch.jit.save(torch.jit.script(model.target_net), f"{path}.target.jit")
        with open(f"{path}.params.json", "w", encoding="utf-8") as f:
            json.dump(params, f)
        logger.info("Successfully saved model components with jit.save")
        return True
    except Exception as e:
        logger.error(f"All save attempts failed: {e}")
        return False
def robust_load(model: Any, path: str, device: Optional[torch.device] = None) -> bool:
    """
    Load a checkpoint written by ``robust_save`` into ``model`` in place.

    The regular checkpoint at ``path`` is read with ``torch.load``. Its
    ``policy_net`` / ``target_net`` state dicts, ``epsilon`` and (best effort)
    ``optimizer`` state are applied to ``model``. If that is not possible, the
    TorchScript files written by ``robust_save``'s last-resort strategy
    (``<path>.policy.jit``, ``<path>.target.jit``, ``<path>.params.json``)
    are looked for. Restoring from them is not implemented yet, so their
    presence is only logged.

    Args:
        model: Object with ``policy_net`` / ``target_net`` modules and,
            optionally, an ``optimizer``; it is updated in place.
        path: Path of the checkpoint file.
        device: ``map_location`` passed to ``torch.load``; defaults to CUDA
            when available, otherwise CPU.

    Returns:
        bool: True if the checkpoint at ``path`` contained at least one
        network's state and it was applied to ``model``. Otherwise False,
        including when only the TorchScript files exist.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Attempt 1: regular PyTorch checkpoint dict.
    try:
        if os.path.exists(path):
            logger.info(f"Loading model from {path}")
            # NOTE(review): torch.load unpickles the file, which can execute
            # arbitrary code; only point this at trusted checkpoint files.
            checkpoint = torch.load(path, map_location=device)
            has_networks = isinstance(checkpoint, dict) and (
                'policy_net' in checkpoint or 'target_net' in checkpoint
            )
            if has_networks:
                if 'policy_net' in checkpoint:
                    model.policy_net.load_state_dict(checkpoint['policy_net'])
                if 'target_net' in checkpoint:
                    model.target_net.load_state_dict(checkpoint['target_net'])
                if 'epsilon' in checkpoint:
                    model.epsilon = checkpoint['epsilon']
                optimizer = getattr(model, 'optimizer', None)
                if 'optimizer' in checkpoint and optimizer is not None:
                    try:
                        optimizer.load_state_dict(checkpoint['optimizer'])
                    except Exception as e:
                        # Weights are already restored; a mismatched
                        # optimizer state is reported but not fatal.
                        logger.warning(f"Failed to load optimizer state: {e}")
                logger.info("Successfully loaded model")
                return True
            # A file holding neither network (e.g. a bare state_dict) must not
            # be reported as loaded when nothing was applied to the model.
            logger.warning(f"No 'policy_net' or 'target_net' state found in {path}")
        else:
            logger.info(f"Checkpoint file not found: {path}")
    except Exception as e:
        logger.warning(f"Regular load failed: {e}")

    # Attempt 2: TorchScript components from robust_save's last-resort strategy.
    try:
        policy_path = f"{path}.policy.jit"
        target_path = f"{path}.target.jit"
        params_path = f"{path}.params.json"
        if all(os.path.exists(p) for p in (policy_path, target_path, params_path)):
            logger.info("Found JIT model files, but loading them requires special handling")
            with open(params_path, 'r', encoding='utf-8') as f:
                params = json.load(f)
            logger.info(f"Model parameters: {params}")
            # TODO: restoring from TorchScript requires rebuilding the model
            # architecture; until then these files are only reported.
    except Exception as e:
        logger.error(f"JIT load failed: {e}")

    logger.error(f"All load attempts failed for {path}")
    return False
def get_model_info(path: str) -> Dict[str, Any]:
    """
    Summarise a saved checkpoint without applying it to a model.

    Args:
        path: Location of the checkpoint file.

    Returns:
        dict: ``exists`` (bool), ``size_bytes`` (int), ``has_optimizer``
        (whether the loaded object contains an ``'optimizer'`` entry) and
        ``parameters`` (whichever of ``epsilon``, ``state_size``,
        ``action_size`` and ``hidden_size`` the checkpoint holds). Any error
        is logged as a warning and the fields gathered so far are returned.
    """
    summary: Dict[str, Any] = dict(
        exists=False,
        size_bytes=0,
        has_optimizer=False,
        parameters={},
    )
    try:
        if not os.path.exists(path):
            return summary
        summary['exists'] = True
        summary['size_bytes'] = os.path.getsize(path)

        # NOTE(review): torch.load unpickles the file; only inspect trusted files.
        loaded = torch.load(path, map_location='cpu')
        summary['has_optimizer'] = 'optimizer' in loaded

        found = summary['parameters']
        for name in ('epsilon', 'state_size', 'action_size', 'hidden_size'):
            if name in loaded:
                found[name] = loaded[name]
    except Exception as e:
        logger.warning(f"Failed to get model info for {path}: {e}")
    return summary
def verify_save_load_cycle(model: Any, test_path: str) -> bool:
    """
    Check that ``model`` can be saved with ``robust_save`` to a scratch path.

    Loading into a fresh model instance would need model-construction logic
    that this module does not have. For now the check only confirms that the
    regular checkpoint file at ``test_path`` was written and is non-empty. A
    save that only succeeded through the TorchScript fallback (which does not
    write ``test_path``) therefore counts as a failure.

    Every file ``robust_save`` may create for ``test_path`` (the checkpoint,
    its ``.backup`` copy, and the ``.policy.jit`` / ``.target.jit`` /
    ``.params.json`` fallback files) is removed afterwards, whether or not
    the check passed.

    Args:
        model: Model to test.
        test_path: Scratch path for the test file; files at this path (and
            the suffixed variants above) are deleted.

    Returns:
        bool: True if the save succeeded and produced a non-empty checkpoint.
    """
    artifacts = [
        test_path,
        f"{test_path}.backup",
        f"{test_path}.policy.jit",
        f"{test_path}.target.jit",
        f"{test_path}.params.json",
    ]
    try:
        if not robust_save(model, test_path):
            return False
        if os.path.exists(test_path) and os.path.getsize(test_path) > 0:
            logger.info("Save/load cycle verification successful")
            return True
        return False
    except Exception as e:
        logger.error(f"Save/load cycle verification failed: {e}")
        return False
    finally:
        # Best-effort cleanup: most artifacts usually do not exist, and an
        # invalid path has already been reported by the save attempt.
        for artifact in artifacts:
            try:
                os.remove(artifact)
            except (OSError, TypeError):
                pass