model checkpoint manager
@@ -16,6 +16,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(
 # Import checkpoint management
 from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
+from utils.model_registry import get_model_registry
 
 # Configure logger
 logger = logging.getLogger(__name__)
@@ -1329,54 +1330,140 @@ class DQNAgent:
 
         return False  # No improvement
 
-    def save(self, path: str):
-        """Save model and agent state"""
-        os.makedirs(os.path.dirname(path), exist_ok=True)
-
-        # Save policy network
-        self.policy_net.save(f"{path}_policy")
-
-        # Save target network
-        self.target_net.save(f"{path}_target")
-
-        # Save agent state
-        state = {
-            'epsilon': self.epsilon,
-            'update_count': self.update_count,
-            'losses': self.losses,
-            'optimizer_state': self.optimizer.state_dict(),
-            'best_reward': self.best_reward,
-            'avg_reward': self.avg_reward
-        }
-
-        torch.save(state, f"{path}_agent_state.pt")
-        logger.info(f"Agent state saved to {path}_agent_state.pt")
-
-    def load(self, path: str):
-        """Load model and agent state"""
-        # Load policy network
-        self.policy_net.load(f"{path}_policy")
-
-        # Load target network
-        self.target_net.load(f"{path}_target")
-
-        # Load agent state
-        try:
-            agent_state = torch.load(f"{path}_agent_state.pt", map_location=self.device, weights_only=False)
-            self.epsilon = agent_state['epsilon']
-            self.update_count = agent_state['update_count']
-            self.losses = agent_state['losses']
-            self.optimizer.load_state_dict(agent_state['optimizer_state'])
-
-            # Load additional metrics if they exist
-            if 'best_reward' in agent_state:
-                self.best_reward = agent_state['best_reward']
-            if 'avg_reward' in agent_state:
-                self.avg_reward = agent_state['avg_reward']
-
-            logger.info(f"Agent state loaded from {path}_agent_state.pt")
-        except FileNotFoundError:
-            logger.warning(f"Agent state file not found at {path}_agent_state.pt, using default values")
+    def save(self, path: str = None):
+        """Save model and agent state using unified registry"""
+        try:
+            from utils.model_registry import save_model
+
+            # Use unified registry if no path or if it's a models/ path
+            if path is None or path.startswith('models/'):
+                model_name = "dqn_agent"
+                if path:
+                    model_name = path.split('/')[-1].replace('_agent_state', '').replace('.pt', '')
+
+                # Prepare full agent state
+                agent_state = {
+                    'epsilon': self.epsilon,
+                    'update_count': self.update_count,
+                    'losses': self.losses,
+                    'optimizer_state': self.optimizer.state_dict(),
+                    'best_reward': self.best_reward,
+                    'avg_reward': self.avg_reward,
+                    'policy_net_state': self.policy_net.state_dict(),
+                    'target_net_state': self.target_net.state_dict()
+                }
+
+                success = save_model(
+                    model=self.policy_net,  # Save policy net as main model
+                    model_name=model_name,
+                    model_type='dqn',
+                    metadata={'full_agent_state': agent_state}
+                )
+
+                if success:
+                    logger.info(f"DQN agent saved to unified registry: {model_name}")
+                return
+            else:
+                # Legacy direct file save
+                os.makedirs(os.path.dirname(path), exist_ok=True)
+
+                # Save policy network
+                self.policy_net.save(f"{path}_policy")
+
+                # Save target network
+                self.target_net.save(f"{path}_target")
+
+                # Save agent state
+                state = {
+                    'epsilon': self.epsilon,
+                    'update_count': self.update_count,
+                    'losses': self.losses,
+                    'optimizer_state': self.optimizer.state_dict(),
+                    'best_reward': self.best_reward,
+                    'avg_reward': self.avg_reward
+                }
+
+                torch.save(state, f"{path}_agent_state.pt")
+                logger.info(f"Agent state saved to {path}_agent_state.pt (legacy mode)")
+
+        except Exception as e:
+            logger.error(f"Failed to save DQN agent: {e}")
+
+    def load(self, path: str = None):
+        """Load model and agent state from unified registry or file"""
+        try:
+            from utils.model_registry import load_model
+
+            # Use unified registry if no path or if it's a models/ path
+            if path is None or path.startswith('models/'):
+                model_name = "dqn_agent"
+                if path:
+                    model_name = path.split('/')[-1].replace('_agent_state', '').replace('.pt', '')
+
+                model = load_model(model_name, 'dqn')
+                if model is None:
+                    logger.warning(f"Could not load DQN agent {model_name} from unified registry")
+                    return
+
+                # Load full agent state from metadata
+                registry = get_model_registry()
+                if model_name in registry.metadata['models']:
+                    model_data = registry.metadata['models'][model_name]
+                    if 'full_agent_state' in model_data:
+                        agent_state = model_data['full_agent_state']
+
+                        # Restore agent state
+                        self.epsilon = agent_state['epsilon']
+                        self.update_count = agent_state['update_count']
+                        self.losses = agent_state['losses']
+                        self.optimizer.load_state_dict(agent_state['optimizer_state'])
+
+                        # Load additional metrics if they exist
+                        if 'best_reward' in agent_state:
+                            self.best_reward = agent_state['best_reward']
+                        if 'avg_reward' in agent_state:
+                            self.avg_reward = agent_state['avg_reward']
+
+                        # Load network states
+                        if 'policy_net_state' in agent_state:
+                            self.policy_net.load_state_dict(agent_state['policy_net_state'])
+                        if 'target_net_state' in agent_state:
+                            self.target_net.load_state_dict(agent_state['target_net_state'])
+
+                        logger.info(f"DQN agent loaded from unified registry: {model_name}")
+                        return
+
+                return
+            else:
+                # Legacy direct file load
+                # Load policy network
+                self.policy_net.load(f"{path}_policy")
+
+                # Load target network
+                self.target_net.load(f"{path}_target")
+
+                # Load agent state
+                try:
+                    agent_state = torch.load(f"{path}_agent_state.pt", map_location=self.device, weights_only=False)
+                    self.epsilon = agent_state['epsilon']
+                    self.update_count = agent_state['update_count']
+                    self.losses = agent_state['losses']
+                    self.optimizer.load_state_dict(agent_state['optimizer_state'])
+
+                    # Load additional metrics if they exist
+                    if 'best_reward' in agent_state:
+                        self.best_reward = agent_state['best_reward']
+                    if 'avg_reward' in agent_state:
+                        self.avg_reward = agent_state['avg_reward']
+
+                    logger.info(f"Agent state loaded from {path}_agent_state.pt (legacy mode)")
+                except FileNotFoundError:
+                    logger.warning(f"Agent state file not found at {path}_agent_state.pt, using default values")
+
+        except Exception as e:
+            logger.error(f"Failed to load DQN agent: {e}")
 
     def get_position_info(self):
        """Get current position information"""
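Usage note: the reworked save/load pair is path-driven. No path, or a path under models/, routes through the unified registry; any other path falls back to the legacy three-file layout. A minimal sketch of the resulting call patterns (the DQNAgent constructor arguments and paths below are illustrative assumptions, not taken from this commit):

    # Sketch only: constructor arguments and paths are hypothetical.
    agent = DQNAgent(state_dim=64, action_dim=3)

    # No path, or a path under models/, goes through the unified registry:
    agent.save()                      # registered under the default name "dqn_agent"
    agent.save("models/dqn_scalper")  # registered under the name "dqn_scalper"
    agent.load("models/dqn_scalper")  # restores nets, optimizer, epsilon, reward metrics

    # Any other path keeps the legacy direct-file layout:
    # <path>_policy, <path>_target, <path>_agent_state.pt
    agent.save("checkpoints/run1/dqn")
    agent.load("checkpoints/run1/dqn")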