stability
This commit is contained in:
@ -15,6 +15,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class TrainingIntegration:
|
||||
def __init__(self, enable_wandb: bool = True):
|
||||
self.enable_wandb = enable_wandb
|
||||
self.checkpoint_manager = get_checkpoint_manager()
|
||||
|
||||
|
||||
@ -55,9 +56,13 @@ class TrainingIntegration:
|
||||
except Exception as e:
|
||||
logger.warning(f"Error logging to W&B: {e}")
|
||||
|
||||
# Save the model first to get the path
|
||||
model_path = f"models/{model_name}_temp.pt"
|
||||
torch.save(cnn_model.state_dict(), model_path)
|
||||
|
||||
metadata = self.checkpoint_manager.save_checkpoint(
|
||||
model=cnn_model,
|
||||
model_name=model_name,
|
||||
model_path=model_path,
|
||||
model_type='cnn',
|
||||
performance_metrics=performance_metrics,
|
||||
training_metadata=training_metadata
|
||||
@ -114,9 +119,13 @@ class TrainingIntegration:
|
||||
except Exception as e:
|
||||
logger.warning(f"Error logging to W&B: {e}")
|
||||
|
||||
# Save the model first to get the path
|
||||
model_path = f"models/{model_name}_temp.pt"
|
||||
torch.save(rl_agent.state_dict() if hasattr(rl_agent, 'state_dict') else rl_agent, model_path)
|
||||
|
||||
metadata = self.checkpoint_manager.save_checkpoint(
|
||||
model=rl_agent,
|
||||
model_name=model_name,
|
||||
model_path=model_path,
|
||||
model_type='rl',
|
||||
performance_metrics=performance_metrics,
|
||||
training_metadata=training_metadata
|
||||
|
Reference in New Issue
Block a user