gogo2/improved_model_saver.py

#!/usr/bin/env python3
"""
Improved Model Saver

A comprehensive model saving utility that handles various model types
and ensures reliable checkpointing with validation.
"""

import logging
import torch
import os
import json
from pathlib import Path
from datetime import datetime
from typing import Dict, Any, Optional, Union
import shutil

logger = logging.getLogger(__name__)


class ImprovedModelSaver:
    """Enhanced model saving with validation and backup strategies"""

    def __init__(self, base_dir: str = "models/saved"):
        self.base_dir = Path(base_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)

    def save_model_safely(self,
                          model: Any,
                          model_name: str,
                          model_type: str = "unknown",
                          metadata: Optional[Dict[str, Any]] = None) -> bool:
        """
        Save a model with multiple fallback strategies

        Args:
            model: The model to save
            model_name: Name identifier for the model
            model_type: Type of model (dqn, cnn, rl, etc.)
            metadata: Additional metadata to save

        Returns:
            bool: True if successful, False otherwise
        """
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        model_dir = self.base_dir / model_name
        model_dir.mkdir(parents=True, exist_ok=True)

        # Create main and timestamped backup file names
        main_path = model_dir / f"{model_name}_latest.pt"
        backup_path = model_dir / f"{model_name}_{timestamp}.pt"

        try:
            # Strategy 1: Save a full checkpoint via torch.save
            if hasattr(model, '__dict__') and hasattr(torch, 'save'):
                success = self._save_pytorch_model(model, main_path, backup_path)
                if success:
                    self._save_metadata(model_dir, model_name, model_type, metadata)
                    logger.info(f"Successfully saved {model_name} using PyTorch save")
                    return True

            # Strategy 2: Try state_dict saving for PyTorch models
            if hasattr(model, 'state_dict'):
                success = self._save_state_dict(model, main_path, backup_path)
                if success:
                    self._save_metadata(model_dir, model_name, model_type, metadata)
                    logger.info(f"Successfully saved {model_name} using state_dict")
                    return True

            # Strategy 3: Try component-based saving for complex models
            if hasattr(model, 'policy_net') or hasattr(model, 'target_net'):
                success = self._save_rl_agent_components(model, model_dir, model_name)
                if success:
                    self._save_metadata(model_dir, model_name, model_type, metadata)
                    logger.info(f"Successfully saved {model_name} using component-based saving")
                    return True

            # Strategy 4: Fallback - try pickle
            success = self._save_with_pickle(model, main_path, backup_path)
            if success:
                self._save_metadata(model_dir, model_name, model_type, metadata)
                logger.info(f"Successfully saved {model_name} using pickle fallback")
                return True

            logger.error(f"All save strategies failed for {model_name}")
            return False

        except Exception as e:
            logger.error(f"Critical error saving {model_name}: {e}")
            return False

    def _save_pytorch_model(self, model, main_path: Path, backup_path: Path) -> bool:
        """Save using standard PyTorch torch.save"""
        try:
            # Create checkpoint data
            if hasattr(model, 'state_dict'):
                checkpoint = {
                    'model_state_dict': model.state_dict(),
                    'model_class': model.__class__.__name__,
                    'timestamp': datetime.now().isoformat()
                }

                # Add additional attributes
                for attr in ['epsilon', 'total_steps', 'current_reward', 'optimizer']:
                    if hasattr(model, attr):
                        try:
                            value = getattr(model, attr)
                            if attr == 'optimizer' and value is not None:
                                checkpoint['optimizer_state_dict'] = value.state_dict()
                            else:
                                checkpoint[attr] = value
                        except Exception:
                            pass  # Skip problematic attributes
            else:
                checkpoint = {
                    'model': model,
                    'timestamp': datetime.now().isoformat()
                }

            # Save to backup location first
            torch.save(checkpoint, backup_path)

            # Verify backup was saved correctly
            torch.load(backup_path, map_location='cpu')

            # Copy to main location
            shutil.copy2(backup_path, main_path)
            return True

        except Exception as e:
            logger.warning(f"PyTorch save failed: {e}")
            return False

    def _save_state_dict(self, model, main_path: Path, backup_path: Path) -> bool:
        """Save using state_dict only"""
        try:
            state_dict = model.state_dict()
            checkpoint = {
                'state_dict': state_dict,
                'model_class': model.__class__.__name__,
                'timestamp': datetime.now().isoformat()
            }

            torch.save(checkpoint, backup_path)
            torch.load(backup_path, map_location='cpu')  # Verify
            shutil.copy2(backup_path, main_path)
            return True

        except Exception as e:
            logger.warning(f"State dict save failed: {e}")
            return False

    def _save_rl_agent_components(self, model, model_dir: Path, model_name: str) -> bool:
        """Save RL agent components separately"""
        try:
            components_saved = 0

            # Save policy network
            if hasattr(model, 'policy_net') and model.policy_net is not None:
                policy_path = model_dir / f"{model_name}_policy.pt"
                torch.save(model.policy_net.state_dict(), policy_path)
                components_saved += 1

            # Save target network
            if hasattr(model, 'target_net') and model.target_net is not None:
                target_path = model_dir / f"{model_name}_target.pt"
                torch.save(model.target_net.state_dict(), target_path)
                components_saved += 1

            # Save agent state
            agent_state = {}
            for attr in ['epsilon', 'total_steps', 'current_reward', 'memory']:
                if hasattr(model, attr):
                    try:
                        value = getattr(model, attr)
                        if attr == 'memory' and hasattr(value, '__len__'):
                            # Don't save large replay buffers, only their size
                            agent_state[attr + '_size'] = len(value)
                        else:
                            agent_state[attr] = value
                    except Exception:
                        pass

            if agent_state:
                state_path = model_dir / f"{model_name}_agent_state.pt"
                torch.save(agent_state, state_path)
                components_saved += 1

            return components_saved > 0

        except Exception as e:
            logger.warning(f"Component-based save failed: {e}")
            return False

    def _save_with_pickle(self, model, main_path: Path, backup_path: Path) -> bool:
        """Fallback: save using pickle"""
        try:
            import pickle

            with open(backup_path.with_suffix('.pkl'), 'wb') as f:
                pickle.dump(model, f)

            # Verify
            with open(backup_path.with_suffix('.pkl'), 'rb') as f:
                pickle.load(f)

            shutil.copy2(backup_path.with_suffix('.pkl'), main_path.with_suffix('.pkl'))
            return True

        except Exception as e:
            logger.warning(f"Pickle save failed: {e}")
            return False

    def _save_metadata(self, model_dir: Path, model_name: str, model_type: str, metadata: Optional[Dict[str, Any]]):
        """Save model metadata"""
        try:
            meta_data = {
                'model_name': model_name,
                'model_type': model_type,
                'saved_at': datetime.now().isoformat(),
                'save_method': 'improved_model_saver'
            }

            if metadata:
                meta_data.update(metadata)

            meta_path = model_dir / f"{model_name}_metadata.json"
            with open(meta_path, 'w') as f:
                json.dump(meta_data, f, indent=2, default=str)

        except Exception as e:
            logger.warning(f"Failed to save metadata: {e}")

    def load_model_safely(self, model_name: str, model_class=None):
        """
        Load a model with multiple strategies

        Args:
            model_name: Name of the model to load
            model_class: Class to instantiate if needed

        Returns:
            Loaded model or None
        """
        model_dir = self.base_dir / model_name
        if not model_dir.exists():
            logger.warning(f"Model directory not found: {model_dir}")
            return None

        # Try different loading strategies
        loaders = [
            self._load_pytorch_checkpoint,
            self._load_state_dict_only,
            self._load_rl_components,
            self._load_pickle_fallback
        ]

        for loader in loaders:
            try:
                result = loader(model_dir, model_name, model_class)
                if result is not None:
                    logger.info(f"Successfully loaded {model_name} using {loader.__name__}")
                    return result
            except Exception as e:
                logger.debug(f"{loader.__name__} failed: {e}")
                continue

        logger.error(f"All load strategies failed for {model_name}")
        return None

    def _load_pytorch_checkpoint(self, model_dir: Path, model_name: str, model_class):
        """Load PyTorch checkpoint"""
        main_path = model_dir / f"{model_name}_latest.pt"
        if main_path.exists():
            checkpoint = torch.load(main_path, map_location='cpu')

            if model_class and 'model_state_dict' in checkpoint:
                model = model_class()
                model.load_state_dict(checkpoint['model_state_dict'])

                # Restore other attributes
                for key, value in checkpoint.items():
                    if key not in ['model_state_dict', 'optimizer_state_dict', 'timestamp', 'model_class']:
                        if hasattr(model, key):
                            setattr(model, key, value)
                return model

            return checkpoint.get('model', checkpoint)
        return None

    def _load_state_dict_only(self, model_dir: Path, model_name: str, model_class):
        """Load state dict only"""
        main_path = model_dir / f"{model_name}_latest.pt"
        if main_path.exists() and model_class:
            checkpoint = torch.load(main_path, map_location='cpu')
            if 'state_dict' in checkpoint:
                model = model_class()
                model.load_state_dict(checkpoint['state_dict'])
                return model
        return None

    def _load_rl_components(self, model_dir: Path, model_name: str, model_class):
        """Load RL agent from components"""
        policy_path = model_dir / f"{model_name}_policy.pt"
        target_path = model_dir / f"{model_name}_target.pt"
        state_path = model_dir / f"{model_name}_agent_state.pt"

        if policy_path.exists() and model_class:
            model = model_class()

            # Load policy network
            if hasattr(model, 'policy_net'):
                model.policy_net.load_state_dict(torch.load(policy_path, map_location='cpu'))

            # Load target network
            if target_path.exists() and hasattr(model, 'target_net'):
                model.target_net.load_state_dict(torch.load(target_path, map_location='cpu'))

            # Load agent state
            if state_path.exists():
                agent_state = torch.load(state_path, map_location='cpu')
                for key, value in agent_state.items():
                    if hasattr(model, key):
                        setattr(model, key, value)

            return model
        return None

    def _load_pickle_fallback(self, model_dir: Path, model_name: str, model_class):
        """Load from pickle"""
        pickle_path = model_dir / f"{model_name}_latest.pkl"
        if pickle_path.exists():
            import pickle
            with open(pickle_path, 'rb') as f:
                return pickle.load(f)
        return None


# Global instance for easy access
_improved_model_saver = None


def get_improved_model_saver() -> ImprovedModelSaver:
    """Get or create the global improved model saver instance"""
    global _improved_model_saver
    if _improved_model_saver is None:
        _improved_model_saver = ImprovedModelSaver()
    return _improved_model_saver
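

# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal example of the intended call pattern. "TinyNet" is a hypothetical
# stand-in model class used only for this demo; real callers would pass their
# own DQN/CNN/RL agent instead. Save falls through the strategies in order,
# and load tries the matching loaders until one succeeds.
if __name__ == "__main__":
    import torch.nn as nn

    class TinyNet(nn.Module):
        """Tiny illustrative network used only for this demo."""

        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 2)

        def forward(self, x):
            return self.fc(x)

    saver = get_improved_model_saver()
    saved = saver.save_model_safely(TinyNet(), "tiny_net", model_type="cnn",
                                    metadata={"note": "demo save"})
    restored = saver.load_model_safely("tiny_net", model_class=TinyNet)
    print(f"saved={saved}, restored={type(restored).__name__}")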