stability

This commit is contained in:
Dobromir Popov
2025-07-28 12:10:52 +03:00
parent 9219b78241
commit fb72c93743
8 changed files with 207 additions and 53 deletions

View File

@ -15,6 +15,7 @@ logger = logging.getLogger(__name__)
class TrainingIntegration:
def __init__(self, enable_wandb: bool = True):
self.enable_wandb = enable_wandb
self.checkpoint_manager = get_checkpoint_manager()
@ -55,9 +56,13 @@ class TrainingIntegration:
except Exception as e:
logger.warning(f"Error logging to W&B: {e}")
# Save the model first to get the path
model_path = f"models/{model_name}_temp.pt"
torch.save(cnn_model.state_dict(), model_path)
metadata = self.checkpoint_manager.save_checkpoint(
model=cnn_model,
model_name=model_name,
model_path=model_path,
model_type='cnn',
performance_metrics=performance_metrics,
training_metadata=training_metadata
@ -114,9 +119,13 @@ class TrainingIntegration:
except Exception as e:
logger.warning(f"Error logging to W&B: {e}")
# Save the model first to get the path
model_path = f"models/{model_name}_temp.pt"
torch.save(rl_agent.state_dict() if hasattr(rl_agent, 'state_dict') else rl_agent, model_path)
metadata = self.checkpoint_manager.save_checkpoint(
model=rl_agent,
model_name=model_name,
model_path=model_path,
model_type='rl',
performance_metrics=performance_metrics,
training_metadata=training_metadata