362 lines
14 KiB
Python
362 lines
14 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Improved Model Saver
|
|
|
|
A comprehensive model saving utility that handles various model types
|
|
and ensures reliable checkpointing with validation.
|
|
"""
|
|
|
|
import logging
|
|
import torch
|
|
import os
|
|
import json
|
|
from pathlib import Path
|
|
from datetime import datetime
|
|
from typing import Dict, Any, Optional, Union
|
|
import shutil
|
|
|
|
logger = logging.getLogger(__name__)


class ImprovedModelSaver:
    """Enhanced model saving with validation and backup strategies.

    Every model gets its own folder ``<base_dir>/<model_name>/``.

    * The torch-based strategies write a timestamped backup file, verify it by
      reloading it, and then promote it to ``<model_name>_latest.pt``.
    * The pickle strategy does the same with ``<model_name>_latest.pkl``.
    * The RL component strategy writes ``_policy.pt``, ``_target.pt`` and
      ``_agent_state.pt``.
    * After any successful save, ``<model_name>_metadata.json`` is written.

    SECURITY: loading unpickles arbitrary objects (``torch.load`` with
    ``weights_only=False`` and ``pickle.load``). Only load from directories
    you trust.
    """

    # Checkpoint keys that are bookkeeping rather than model attributes, so
    # they are never copied onto a loaded model with setattr. 'optimizer' is
    # listed because older checkpoints may contain ``optimizer: None``, which
    # would otherwise clobber a freshly constructed model's optimizer.
    _NON_ATTRIBUTE_KEYS = frozenset({
        'model_state_dict',
        'state_dict',
        'optimizer_state_dict',
        'timestamp',
        'model_class',
        'optimizer',
    })

    def __init__(self, base_dir: str = "models/saved"):
        """Create the saver and make sure ``base_dir`` exists.

        Args:
            base_dir: Root directory under which per-model folders are created.
        """
        self.base_dir = Path(base_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)

    # ------------------------------------------------------------------ #
    # Low-level file helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _torch_load(path: Path) -> Any:
        """Load a torch file onto the CPU, allowing arbitrary pickled objects.

        PyTorch 2.6 changed the default of ``torch.load(weights_only=...)`` to
        True. That default rejects checkpoints this class writes itself, such
        as whole objects under ``'model'`` or non-tensor extras. The flag is
        therefore passed explicitly when the installed torch supports it.
        """
        import inspect

        kwargs: Dict[str, Any] = {'map_location': 'cpu'}
        try:
            if 'weights_only' in inspect.signature(torch.load).parameters:
                kwargs['weights_only'] = False  # trusted files only, see class doc
        except (TypeError, ValueError):
            pass  # signature not introspectable; fall back to torch defaults
        return torch.load(path, **kwargs)

    @staticmethod
    def _promote(src: Path, dst: Path) -> None:
        """Copy ``src`` over ``dst`` without ever exposing a partial ``dst``.

        The data is copied to a temporary sibling first and then moved into
        place with a single ``os.replace``. An interrupted save therefore
        leaves the previous ``dst`` (if any) intact instead of truncating it.
        """
        tmp = dst.with_name(dst.name + '.tmp')
        try:
            shutil.copy2(src, tmp)
            os.replace(tmp, dst)
        except Exception:
            if tmp.exists():
                tmp.unlink()
            raise

    # ------------------------------------------------------------------ #
    # Saving
    # ------------------------------------------------------------------ #

    def save_model_safely(self,
                          model: Any,
                          model_name: str,
                          model_type: str = "unknown",
                          metadata: Optional[Dict[str, Any]] = None) -> bool:
        """
        Save a model, trying several strategies in order until one succeeds.

        Order tried (each only if its precondition holds):
          1. ``torch.save`` checkpoint: any object with ``__dict__``.
          2. ``state_dict``-only checkpoint: objects with ``state_dict``.
          3. Per-component files: objects with ``policy_net``/``target_net``.
          4. Plain pickle: always attempted last.

        Args:
            model: The model to save
            model_name: Name identifier for the model (also its folder name)
            model_type: Type of model (dqn, cnn, rl, etc.)
            metadata: Additional metadata to merge into the metadata JSON

        Returns:
            bool: True if some strategy succeeded, False otherwise (errors are
            logged, never raised).
        """
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        model_dir = self.base_dir / model_name
        main_path = model_dir / f"{model_name}_latest.pt"
        backup_path = model_dir / f"{model_name}_{timestamp}.pt"

        try:
            model_dir.mkdir(parents=True, exist_ok=True)

            # (precondition, save callable, label used in the success log)
            strategies = [
                (hasattr(model, '__dict__'),
                 lambda: self._save_pytorch_model(model, main_path, backup_path),
                 "PyTorch save"),
                (hasattr(model, 'state_dict'),
                 lambda: self._save_state_dict(model, main_path, backup_path),
                 "state_dict"),
                (hasattr(model, 'policy_net') or hasattr(model, 'target_net'),
                 lambda: self._save_rl_agent_components(model, model_dir, model_name),
                 "component-based saving"),
                (True,
                 lambda: self._save_with_pickle(model, main_path, backup_path),
                 "pickle fallback"),
            ]

            for applies, save, label in strategies:
                if applies and save():
                    self._save_metadata(model_dir, model_name, model_type, metadata)
                    logger.info(f"Successfully saved {model_name} using {label}")
                    return True

            logger.error(f"All save strategies failed for {model_name}")
            return False

        except Exception as e:
            logger.error(f"Critical error saving {model_name}: {e}")
            return False

    def _save_pytorch_model(self, model, main_path: Path, backup_path: Path) -> bool:
        """Save with ``torch.save``, verify by reloading, then promote to latest.

        Objects exposing ``state_dict()`` are stored as a checkpoint dict with
        these keys:

        * ``model_state_dict``, ``model_class`` and ``timestamp``;
        * any of ``epsilon``/``total_steps``/``current_reward`` present on
          the model;
        * ``optimizer_state_dict``, when ``model.optimizer`` is not None.

        Any other object is stored whole under the ``'model'`` key.

        Returns:
            True on success, False if any step raised (the error is logged).
        """
        try:
            if hasattr(model, 'state_dict'):
                checkpoint = {
                    'model_state_dict': model.state_dict(),
                    'model_class': model.__class__.__name__,
                    'timestamp': datetime.now().isoformat(),
                }

                for attr in ('epsilon', 'total_steps', 'current_reward', 'optimizer'):
                    if not hasattr(model, attr):
                        continue
                    try:
                        value = getattr(model, attr)
                        if attr == 'optimizer':
                            # Store only the optimizer's state. A None optimizer
                            # is omitted so loading can't overwrite the fresh
                            # model's optimizer with None.
                            if value is not None:
                                checkpoint['optimizer_state_dict'] = value.state_dict()
                        else:
                            checkpoint[attr] = value
                    except Exception as e:
                        logger.debug(f"Skipping attribute {attr!r} while saving: {e}")
            else:
                checkpoint = {
                    'model': model,
                    'timestamp': datetime.now().isoformat(),
                }

            torch.save(checkpoint, backup_path)
            self._torch_load(backup_path)  # verify the file round-trips
            self._promote(backup_path, main_path)
            return True

        except Exception as e:
            logger.warning(f"PyTorch save failed: {e}")
            return False

    def _save_state_dict(self, model, main_path: Path, backup_path: Path) -> bool:
        """Save ``{'state_dict', 'model_class', 'timestamp'}``, verify, promote.

        Returns:
            True on success, False if any step raised (the error is logged).
        """
        try:
            checkpoint = {
                'state_dict': model.state_dict(),
                'model_class': model.__class__.__name__,
                'timestamp': datetime.now().isoformat(),
            }
            torch.save(checkpoint, backup_path)
            self._torch_load(backup_path)  # verify the file round-trips
            self._promote(backup_path, main_path)
            return True

        except Exception as e:
            logger.warning(f"State dict save failed: {e}")
            return False

    def _save_rl_agent_components(self, model, model_dir: Path, model_name: str) -> bool:
        """Save RL agent pieces as separate files inside ``model_dir``.

        Files written:

        * ``policy_net`` / ``target_net`` state dicts, when present and not
          None;
        * a dict of ``epsilon``/``total_steps``/``current_reward``.

        For ``memory``, only its length is stored (as ``memory_size``) when it
        supports ``len()``; otherwise the object itself is stored.

        Returns:
            True if at least one file was written, False otherwise.
        """
        try:
            components_saved = 0

            if getattr(model, 'policy_net', None) is not None:
                policy_path = model_dir / f"{model_name}_policy.pt"
                torch.save(model.policy_net.state_dict(), policy_path)
                components_saved += 1

            if getattr(model, 'target_net', None) is not None:
                target_path = model_dir / f"{model_name}_target.pt"
                torch.save(model.target_net.state_dict(), target_path)
                components_saved += 1

            agent_state = {}
            for attr in ('epsilon', 'total_steps', 'current_reward', 'memory'):
                if not hasattr(model, attr):
                    continue
                try:
                    value = getattr(model, attr)
                    if attr == 'memory' and hasattr(value, '__len__'):
                        # Don't save large replay buffers, only their size.
                        agent_state[attr + '_size'] = len(value)
                    else:
                        agent_state[attr] = value
                except Exception as e:
                    logger.debug(f"Skipping attribute {attr!r} while saving: {e}")

            if agent_state:
                state_path = model_dir / f"{model_name}_agent_state.pt"
                torch.save(agent_state, state_path)
                components_saved += 1

            return components_saved > 0

        except Exception as e:
            logger.warning(f"Component-based save failed: {e}")
            return False

    def _save_with_pickle(self, model, main_path: Path, backup_path: Path) -> bool:
        """Fallback: pickle the whole object to ``.pkl`` files, verify, promote.

        Returns:
            True on success, False if any step raised (the error is logged).
        """
        import pickle

        pkl_backup = backup_path.with_suffix('.pkl')
        pkl_main = main_path.with_suffix('.pkl')
        try:
            with open(pkl_backup, 'wb') as f:
                pickle.dump(model, f)

            with open(pkl_backup, 'rb') as f:  # verify the file round-trips
                pickle.load(f)

            self._promote(pkl_backup, pkl_main)
            return True

        except Exception as e:
            logger.warning(f"Pickle save failed: {e}")
            return False

    def _save_metadata(self, model_dir: Path, model_name: str, model_type: str,
                       metadata: Optional[Dict[str, Any]]):
        """Write ``<model_name>_metadata.json`` (best-effort; failures are logged).

        Caller-supplied ``metadata`` is merged last, so its keys override the
        defaults. Values JSON can't encode are written via ``str()``.
        """
        try:
            meta_data = {
                'model_name': model_name,
                'model_type': model_type,
                'saved_at': datetime.now().isoformat(),
                'save_method': 'improved_model_saver',
            }
            if metadata:
                meta_data.update(metadata)

            meta_path = model_dir / f"{model_name}_metadata.json"
            with open(meta_path, 'w', encoding='utf-8') as f:
                json.dump(meta_data, f, indent=2, default=str)

        except Exception as e:
            logger.warning(f"Failed to save metadata: {e}")

    # ------------------------------------------------------------------ #
    # Loading
    # ------------------------------------------------------------------ #

    def load_model_safely(self, model_name: str, model_class=None):
        """
        Load a model, trying several strategies in order.

        Args:
            model_name: Name of the model to load
            model_class: Zero-argument callable used to build the model
                instance when loading from a state dict or components

        Returns:
            The first non-None result of the loaders, or None if the model
            directory is missing or every loader failed.
        """
        model_dir = self.base_dir / model_name

        if not model_dir.exists():
            logger.warning(f"Model directory not found: {model_dir}")
            return None

        loaders = [
            self._load_pytorch_checkpoint,
            self._load_state_dict_only,
            self._load_rl_components,
            self._load_pickle_fallback,
        ]

        for loader in loaders:
            try:
                result = loader(model_dir, model_name, model_class)
                if result is not None:
                    logger.info(f"Successfully loaded {model_name} using {loader.__name__}")
                    return result
            except Exception as e:
                logger.debug(f"{loader.__name__} failed: {e}")
                continue

        logger.error(f"All load strategies failed for {model_name}")
        return None

    def _restore_attributes(self, model, checkpoint: Dict[str, Any]) -> None:
        """Copy saved extras from ``checkpoint`` onto an already-built ``model``.

        Scalar entries are set only when the model already has that attribute.
        A saved optimizer state is loaded into ``model.optimizer`` if one
        exists; a failure there is logged and does not abort the load.
        """
        for key, value in checkpoint.items():
            if key not in self._NON_ATTRIBUTE_KEYS and hasattr(model, key):
                setattr(model, key, value)

        opt_state = checkpoint.get('optimizer_state_dict')
        optimizer = getattr(model, 'optimizer', None)
        if opt_state is not None and optimizer is not None:
            try:
                optimizer.load_state_dict(opt_state)
            except Exception as e:
                logger.warning(f"Could not restore optimizer state: {e}")

    def _load_pytorch_checkpoint(self, model_dir: Path, model_name: str, model_class):
        """Load ``<model_name>_latest.pt``.

        Behaviour depends on the file contents:

        * A non-dict payload is returned as-is.
        * With ``model_class``, checkpoints from ``_save_pytorch_model``
          (``model_state_dict``) or ``_save_state_dict`` (``state_dict``) are
          loaded into a fresh ``model_class()``, and saved extras are
          restored onto it.
        * Otherwise, the stored ``'model'`` object is returned, or the raw
          checkpoint dict when there is none.

        Returns None when the file does not exist.
        """
        main_path = model_dir / f"{model_name}_latest.pt"
        if not main_path.exists():
            return None

        checkpoint = self._torch_load(main_path)
        if not isinstance(checkpoint, dict):
            return checkpoint

        state_key = next(
            (k for k in ('model_state_dict', 'state_dict') if k in checkpoint),
            None,
        )
        if model_class and state_key is not None:
            model = model_class()
            model.load_state_dict(checkpoint[state_key])
            self._restore_attributes(model, checkpoint)
            return model

        return checkpoint.get('model', checkpoint)

    def _load_state_dict_only(self, model_dir: Path, model_name: str, model_class):
        """Build ``model_class()`` and load a ``'state_dict'`` checkpoint into it.

        Returns None if the file or ``model_class`` is missing, or if the file
        holds no ``'state_dict'`` entry.
        """
        main_path = model_dir / f"{model_name}_latest.pt"

        if main_path.exists() and model_class:
            checkpoint = self._torch_load(main_path)
            if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
                model = model_class()
                model.load_state_dict(checkpoint['state_dict'])
                return model

        return None

    def _load_rl_components(self, model_dir: Path, model_name: str, model_class):
        """Rebuild an agent from files written by ``_save_rl_agent_components``.

        Requires ``model_class`` and the policy file. The target network and
        agent-state files are applied only if they exist. Agent-state entries
        are set only for attributes the new instance already has.
        """
        policy_path = model_dir / f"{model_name}_policy.pt"
        target_path = model_dir / f"{model_name}_target.pt"
        state_path = model_dir / f"{model_name}_agent_state.pt"

        if not (policy_path.exists() and model_class):
            return None

        model = model_class()

        if hasattr(model, 'policy_net'):
            model.policy_net.load_state_dict(self._torch_load(policy_path))

        if target_path.exists() and hasattr(model, 'target_net'):
            model.target_net.load_state_dict(self._torch_load(target_path))

        if state_path.exists():
            agent_state = self._torch_load(state_path)
            for key, value in agent_state.items():
                if hasattr(model, key):
                    setattr(model, key, value)

        return model

    def _load_pickle_fallback(self, model_dir: Path, model_name: str, model_class):
        """Unpickle ``<model_name>_latest.pkl`` if it exists, else return None.

        SECURITY: ``pickle.load`` can execute arbitrary code; only use on
        trusted files.
        """
        pickle_path = model_dir / f"{model_name}_latest.pkl"

        if pickle_path.exists():
            import pickle
            with open(pickle_path, 'rb') as f:
                return pickle.load(f)

        return None
|
|
|
|
|
# Module-wide saver, created lazily by get_improved_model_saver() and then
# shared by every caller.
_improved_model_saver = None


def get_improved_model_saver() -> ImprovedModelSaver:
    """Return the shared ImprovedModelSaver, building it on first call.

    The instance is constructed with ImprovedModelSaver's default base
    directory; later calls return that same object.
    """
    global _improved_model_saver
    if _improved_model_saver is not None:
        return _improved_model_saver
    _improved_model_saver = ImprovedModelSaver()
    return _improved_model_saver
|