model checkpoint manager
This commit is contained in:
@@ -21,6 +21,7 @@ from typing import Dict, Any, Optional, Tuple
|
||||
|
||||
# Import checkpoint management
|
||||
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
|
||||
from utils.model_registry import get_model_registry
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -774,42 +775,107 @@ class CNNModelTrainer:
|
||||
# Return realistic loss values based on random baseline performance
|
||||
return {'main_loss': 0.693, 'total_loss': 0.693, 'accuracy': 0.5} # ln(2) for binary cross-entropy at random chance
|
||||
|
||||
def save_model(self, filepath: str, metadata: Optional[Dict] = None):
    """Save model with metadata (direct-file version).

    Builds one checkpoint dict and writes it to ``filepath`` with
    ``torch.save``. A second ``save_model`` definition follows immediately
    below; because it has the same name, it is the one that stays bound when
    both appear in the same class body.

    Args:
        filepath: Destination handed straight to ``torch.save``.
        metadata: Optional extra data stored under the ``'metadata'`` key.
            The key is omitted when ``metadata`` is falsy.
    """
    save_dict = {
        'model_state_dict': self.model.state_dict(),
        'optimizer_state_dict': self.optimizer.state_dict(),
        'scheduler_state_dict': self.scheduler.state_dict(),
        'training_history': self.training_history,
        'model_config': {
            'input_size': self.model.input_size,
            'feature_dim': self.model.feature_dim,
            'output_size': self.model.output_size,
            'base_channels': self.model.base_channels
        }
    }

    if metadata:
        save_dict['metadata'] = metadata

    torch.save(save_dict, filepath)
    logger.info(f"Enhanced CNN model saved to {filepath}")

def save_model(self, filepath: Optional[str] = None, metadata: Optional[Dict] = None):
    """Save model with metadata using unified registry

    Two destinations are supported:

    * ``filepath`` is ``None`` or starts with ``'models/'``: the model is
      handed to ``utils.model_registry.save_model`` with the full checkpoint
      dict attached as ``metadata={'full_checkpoint': ...}``.
    * any other ``filepath``: the checkpoint dict is written directly with
      ``torch.save`` (legacy mode).

    Args:
        filepath: Target path, or ``None`` to use the registry under the
            default name ``"enhanced_cnn"``.
        metadata: Optional extra data stored under the checkpoint's
            ``'metadata'`` key. The key is omitted when ``metadata`` is falsy.

    Returns:
        The registry's result in the registry branch, ``True`` after a legacy
        save, and ``False`` when any exception was raised (the exception is
        logged, not propagated).
    """
    try:
        # Imported locally: the registry function has the same name as this method.
        from utils.model_registry import save_model

        # Prepare model data
        model_data = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'training_history': self.training_history,
            'model_config': {
                'input_size': self.model.input_size,
                'feature_dim': self.model.feature_dim,
                'output_size': self.model.output_size,
                'base_channels': self.model.base_channels
            }
        }

        if metadata:
            model_data['metadata'] = metadata

        # Use unified registry if no filepath specified or it is a models/ path
        if filepath is None or filepath.startswith('models/'):
            # Extract model name from filepath or use default:
            # "models/foo_latest.pt" -> "foo", "models/foo.pt" -> "foo"
            model_name = "enhanced_cnn"
            if filepath:
                model_name = filepath.split('/')[-1].replace('_latest.pt', '').replace('.pt', '')

            success = save_model(
                model=self.model,
                model_name=model_name,
                model_type='cnn',
                metadata={'full_checkpoint': model_data}
            )
            if success:
                logger.info(f"Enhanced CNN model saved to unified registry: {model_name}")
            else:
                # Previously a falsy result was returned without any trace in the log.
                logger.error(f"Unified registry reported failure saving CNN model: {model_name}")
            return success
        else:
            # Legacy direct file save
            torch.save(model_data, filepath)
            logger.info(f"Enhanced CNN model saved to {filepath} (legacy mode)")
            return True

    except Exception as e:
        logger.error(f"Failed to save CNN model: {e}")
        return False
|
||||
|
||||
def load_model(self, filepath: str) -> Dict:
    """Restore trainer state from a checkpoint file.

    Reads ``filepath`` with ``torch.load`` (mapped onto ``self.device``) and
    applies the stored model and optimizer state dicts. Scheduler state and
    training history are applied only when their keys are present.

    Args:
        filepath: Checkpoint file to read.

    Returns:
        The checkpoint's ``'metadata'`` entry, or an empty dict when absent.
    """
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(filepath, map_location=self.device)

    # These two entries are indexed unconditionally, so they must exist.
    self.model.load_state_dict(state['model_state_dict'])
    self.optimizer.load_state_dict(state['optimizer_state_dict'])

    # Optional entries: applied only when the checkpoint carries them.
    if 'scheduler_state_dict' in state:
        self.scheduler.load_state_dict(state['scheduler_state_dict'])
    if 'training_history' in state:
        self.training_history = state['training_history']

    logger.info(f"Enhanced CNN model loaded from {filepath}")

    stored_metadata = state.get('metadata', {})
    return stored_metadata
|
||||
def load_model(self, filepath: Optional[str] = None) -> Dict:
    """Load model from unified registry or file

    Two sources are supported:

    * ``filepath`` is ``None`` or starts with ``'models/'``: the model is
      looked up through ``utils.model_registry.load_model`` and the trainer
      state is restored from the ``'full_checkpoint'`` entry that the
      registry metadata holds for that model name.
    * any other ``filepath``: the checkpoint is read directly with
      ``torch.load`` (legacy mode).

    Args:
        filepath: Checkpoint path, or ``None`` to use the registry under the
            default name ``"enhanced_cnn"``.

    Returns:
        The checkpoint's ``'metadata'`` entry, or an empty dict when there is
        none, when nothing could be restored, or when an exception was raised
        (the exception is logged, not propagated).
    """
    try:
        # Imported locally: the registry function has the same name as this method.
        from utils.model_registry import load_model

        # Use unified registry if no filepath or if it's a models/ path
        if filepath is None or filepath.startswith('models/'):
            # "models/foo_latest.pt" -> "foo", "models/foo.pt" -> "foo"
            model_name = "enhanced_cnn"
            if filepath:
                model_name = filepath.split('/')[-1].replace('_latest.pt', '').replace('.pt', '')

            # The returned object is only used as an existence check here.
            model = load_model(model_name, 'cnn')
            if model is None:
                logger.warning(f"Could not load model {model_name} from unified registry")
                return {}

            # Load full checkpoint data from metadata
            registry = get_model_registry()
            if model_name in registry.metadata['models']:
                model_data = registry.metadata['models'][model_name]
                if 'full_checkpoint' in model_data:
                    checkpoint = model_data['full_checkpoint']

                    # Restore the network weights too. Without this the
                    # optimizer/scheduler state would be applied on top of
                    # whatever weights self.model currently holds.
                    if 'model_state_dict' in checkpoint:
                        self.model.load_state_dict(checkpoint['model_state_dict'])

                    self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                    if 'scheduler_state_dict' in checkpoint:
                        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
                    if 'training_history' in checkpoint:
                        self.training_history = checkpoint['training_history']

                    logger.info(f"Enhanced CNN model loaded from unified registry: {model_name}")
                    return checkpoint.get('metadata', {})

            # The registry knows the model but holds no checkpoint for it.
            # Say so, because the trainer state is left untouched.
            logger.warning(f"No full checkpoint in unified registry for {model_name}; trainer state not restored")
            return {}

        else:
            # Legacy direct file load
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(filepath, map_location=self.device)

            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

            if 'scheduler_state_dict' in checkpoint:
                self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

            if 'training_history' in checkpoint:
                self.training_history = checkpoint['training_history']

            logger.info(f"Enhanced CNN model loaded from {filepath} (legacy mode)")
            return checkpoint.get('metadata', {})

    except Exception as e:
        logger.error(f"Failed to load CNN model: {e}")
        return {}
|
||||
|
||||
def create_enhanced_cnn_model(input_size: int = 60,
|
||||
feature_dim: int = 50,
|
||||
|
||||
@@ -16,6 +16,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(
|
||||
|
||||
# Import checkpoint management
|
||||
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
|
||||
from utils.model_registry import get_model_registry
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -1329,54 +1330,140 @@ class DQNAgent:
|
||||
|
||||
return False # No improvement
|
||||
|
||||
def save(self, path: str):
    """Save model and agent state

    Writes three artifacts derived from ``path``: the policy network to
    ``{path}_policy`` and the target network to ``{path}_target`` (both via
    the networks' own ``save`` methods), plus a ``torch.save`` dict of agent
    bookkeeping to ``{path}_agent_state.pt``.

    Args:
        path: Path prefix for the three artifacts. Its parent directory is
            created first when the prefix contains one.
    """
    # os.makedirs('') raises FileNotFoundError, so only create the parent
    # directory when the prefix actually has one (a bare "agent" does not).
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Save policy network
    self.policy_net.save(f"{path}_policy")

    # Save target network
    self.target_net.save(f"{path}_target")

    # Save agent state
    state = {
        'epsilon': self.epsilon,
        'update_count': self.update_count,
        'losses': self.losses,
        'optimizer_state': self.optimizer.state_dict(),
        'best_reward': self.best_reward,
        'avg_reward': self.avg_reward
    }

    torch.save(state, f"{path}_agent_state.pt")
    logger.info(f"Agent state saved to {path}_agent_state.pt")
|
||||
|
||||
def load(self, path: str):
    """Load model and agent state (direct-file version).

    Loads the policy and target networks from ``{path}_policy`` and
    ``{path}_target`` via the networks' own ``load`` methods, then restores
    agent bookkeeping from ``{path}_agent_state.pt``. A missing state file is
    logged as a warning and the current attribute values are kept.

    Args:
        path: Path prefix shared by the three artifacts.
    """
    # Load policy network
    self.policy_net.load(f"{path}_policy")

    # Load target network
    self.target_net.load(f"{path}_target")

    # Load agent state
    try:
        # NOTE: weights_only=False allows arbitrary unpickling; only load trusted files.
        agent_state = torch.load(f"{path}_agent_state.pt", map_location=self.device, weights_only=False)
        self.epsilon = agent_state['epsilon']
        self.update_count = agent_state['update_count']
        self.losses = agent_state['losses']
        self.optimizer.load_state_dict(agent_state['optimizer_state'])

        # Load additional metrics if they exist
        if 'best_reward' in agent_state:
            self.best_reward = agent_state['best_reward']
        if 'avg_reward' in agent_state:
            self.avg_reward = agent_state['avg_reward']

        logger.info(f"Agent state loaded from {path}_agent_state.pt")
    except FileNotFoundError:
        logger.warning(f"Agent state file not found at {path}_agent_state.pt, using default values")

def save(self, path: str = None):
    """Save model and agent state using unified registry

    Two destinations are supported:

    * ``path`` is ``None`` or starts with ``'models/'``: the policy network
      is handed to ``utils.model_registry.save_model`` with the full agent
      state (including both networks' state dicts) attached as
      ``metadata={'full_agent_state': ...}``.
    * any other ``path``: the networks and the agent state are written as
      three separate artifacts prefixed by ``path`` (legacy mode).

    Exceptions are logged and not propagated; the method returns ``None`` on
    every path.

    Args:
        path: Path prefix, or ``None`` to use the registry under the default
            name ``"dqn_agent"``.
    """
    try:
        from utils.model_registry import save_model

        # Use unified registry if no path or if it's a models/ path
        if path is None or path.startswith('models/'):
            # "models/foo_agent_state.pt" -> "foo", "models/foo.pt" -> "foo"
            model_name = "dqn_agent"
            if path:
                model_name = path.split('/')[-1].replace('_agent_state', '').replace('.pt', '')

            # Prepare full agent state
            agent_state = {
                'epsilon': self.epsilon,
                'update_count': self.update_count,
                'losses': self.losses,
                'optimizer_state': self.optimizer.state_dict(),
                'best_reward': self.best_reward,
                'avg_reward': self.avg_reward,
                'policy_net_state': self.policy_net.state_dict(),
                'target_net_state': self.target_net.state_dict()
            }

            success = save_model(
                model=self.policy_net,  # Save policy net as main model
                model_name=model_name,
                model_type='dqn',
                metadata={'full_agent_state': agent_state}
            )

            if success:
                logger.info(f"DQN agent saved to unified registry: {model_name}")
                return

            # Previously a falsy result fell through without any trace in the log.
            logger.error(f"Unified registry reported failure saving DQN agent: {model_name}")

        else:
            # Legacy direct file save
            # os.makedirs('') raises FileNotFoundError, so only create the
            # parent directory when the prefix actually has one.
            parent_dir = os.path.dirname(path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

            # Save policy network
            self.policy_net.save(f"{path}_policy")

            # Save target network
            self.target_net.save(f"{path}_target")

            # Save agent state
            state = {
                'epsilon': self.epsilon,
                'update_count': self.update_count,
                'losses': self.losses,
                'optimizer_state': self.optimizer.state_dict(),
                'best_reward': self.best_reward,
                'avg_reward': self.avg_reward
            }

            torch.save(state, f"{path}_agent_state.pt")
            logger.info(f"Agent state saved to {path}_agent_state.pt (legacy mode)")

    except Exception as e:
        logger.error(f"Failed to save DQN agent: {e}")
|
||||
|
||||
def load(self, path: str = None):
    """Load model and agent state from unified registry or file

    When ``path`` is ``None`` or starts with ``'models/'`` the agent is looked
    up in the unified registry and restored from the ``'full_agent_state'``
    entry of the registry metadata. Any other ``path`` is treated as a legacy
    prefix for the ``_policy`` / ``_target`` / ``_agent_state.pt`` artifacts.

    Exceptions are logged and not propagated; the method returns ``None`` on
    every path.

    Args:
        path: Path prefix, or ``None`` to use the registry under the default
            name ``"dqn_agent"``.
    """
    def _restore_bookkeeping(saved):
        # Fields common to the registry and legacy formats.
        self.epsilon = saved['epsilon']
        self.update_count = saved['update_count']
        self.losses = saved['losses']
        self.optimizer.load_state_dict(saved['optimizer_state'])

        # Load additional metrics if they exist
        if 'best_reward' in saved:
            self.best_reward = saved['best_reward']
        if 'avg_reward' in saved:
            self.avg_reward = saved['avg_reward']

    try:
        from utils.model_registry import load_model

        use_registry = path is None or path.startswith('models/')

        if not use_registry:
            # Legacy direct file load
            self.policy_net.load(f"{path}_policy")
            self.target_net.load(f"{path}_target")

            # A missing state file is tolerated: current values are kept.
            try:
                # NOTE: weights_only=False allows arbitrary unpickling; only load trusted files.
                saved_state = torch.load(f"{path}_agent_state.pt", map_location=self.device, weights_only=False)
                _restore_bookkeeping(saved_state)
                logger.info(f"Agent state loaded from {path}_agent_state.pt (legacy mode)")
            except FileNotFoundError:
                logger.warning(f"Agent state file not found at {path}_agent_state.pt, using default values")
            return

        # Unified registry: "models/foo_agent_state.pt" -> "foo"
        model_name = "dqn_agent"
        if path:
            model_name = path.split('/')[-1].replace('_agent_state', '').replace('.pt', '')

        # The returned object is only used as an existence check here.
        model = load_model(model_name, 'dqn')
        if model is None:
            logger.warning(f"Could not load DQN agent {model_name} from unified registry")
            return

        # Load full agent state from metadata; nothing is restored when the
        # registry has no such entry for this model name.
        registry = get_model_registry()
        if model_name not in registry.metadata['models']:
            return
        model_data = registry.metadata['models'][model_name]
        if 'full_agent_state' not in model_data:
            return

        saved_state = model_data['full_agent_state']
        _restore_bookkeeping(saved_state)

        # Load network states
        if 'policy_net_state' in saved_state:
            self.policy_net.load_state_dict(saved_state['policy_net_state'])
        if 'target_net_state' in saved_state:
            self.target_net.load_state_dict(saved_state['target_net_state'])

        logger.info(f"DQN agent loaded from unified registry: {model_name}")

    except Exception as e:
        logger.error(f"Failed to load DQN agent: {e}")
|
||||
|
||||
def get_position_info(self):
|
||||
"""Get current position information"""
|
||||
|
||||
Reference in New Issue
Block a user