model checkpoint manager

This commit is contained in:
Dobromir Popov
2025-09-08 13:31:11 +03:00
parent 060fdd28b4
commit c9fba56622
6 changed files with 838 additions and 142 deletions

View File

@@ -4710,53 +4710,85 @@ class CleanTradingDashboard:
stored_models = []
# Use unified model registry for saving
from utils.model_registry import save_model
# 1. Store DQN model
if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
try:
if hasattr(self.orchestrator.rl_agent, 'save'):
save_path = self.orchestrator.rl_agent.save('models/saved/dqn_agent_session')
stored_models.append(('DQN', save_path))
logger.info(f"Stored DQN model: {save_path}")
success = save_model(
model=self.orchestrator.rl_agent.policy_net, # Save policy network
model_name='dqn_agent_session',
model_type='dqn',
metadata={'session_save': True, 'dashboard_save': True}
)
if success:
stored_models.append(('DQN', 'models/dqn/saved/dqn_agent_session_latest.pt'))
logger.info("Stored DQN model via unified registry")
else:
logger.warning("Failed to store DQN model via unified registry")
except Exception as e:
logger.warning(f"Failed to store DQN model: {e}")
# 2. Store CNN model
if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
try:
if hasattr(self.orchestrator.cnn_model, 'save'):
save_path = self.orchestrator.cnn_model.save('models/saved/cnn_model_session')
stored_models.append(('CNN', save_path))
logger.info(f"Stored CNN model: {save_path}")
success = save_model(
model=self.orchestrator.cnn_model,
model_name='cnn_model_session',
model_type='cnn',
metadata={'session_save': True, 'dashboard_save': True}
)
if success:
stored_models.append(('CNN', 'models/cnn/saved/cnn_model_session_latest.pt'))
logger.info("Stored CNN model via unified registry")
else:
logger.warning("Failed to store CNN model via unified registry")
except Exception as e:
logger.warning(f"Failed to store CNN model: {e}")
# 3. Store Transformer model
if hasattr(self.orchestrator, 'primary_transformer') and self.orchestrator.primary_transformer:
try:
if hasattr(self.orchestrator.primary_transformer, 'save'):
save_path = self.orchestrator.primary_transformer.save('models/saved/transformer_model_session')
stored_models.append(('Transformer', save_path))
logger.info(f"Stored Transformer model: {save_path}")
success = save_model(
model=self.orchestrator.primary_transformer,
model_name='transformer_model_session',
model_type='transformer',
metadata={'session_save': True, 'dashboard_save': True}
)
if success:
stored_models.append(('Transformer', 'models/transformer/saved/transformer_model_session_latest.pt'))
logger.info("Stored Transformer model via unified registry")
else:
logger.warning("Failed to store Transformer model via unified registry")
except Exception as e:
logger.warning(f"Failed to store Transformer model: {e}")
# 4. Store COB RL model
# 4. Store COB RL model (if exists)
if hasattr(self.orchestrator, 'cob_rl_agent') and self.orchestrator.cob_rl_agent:
try:
# COB RL model might have different save method
if hasattr(self.orchestrator.cob_rl_agent, 'save'):
save_path = self.orchestrator.cob_rl_agent.save('models/saved/cob_rl_agent_session')
stored_models.append(('COB RL', save_path))
logger.info(f"Stored COB RL model: {save_path}")
except Exception as e:
logger.warning(f"Failed to store COB RL model: {e}")
# 5. Store Decision Fusion model
# 5. Store Decision model
if hasattr(self.orchestrator, 'decision_model') and self.orchestrator.decision_model:
try:
if hasattr(self.orchestrator.decision_model, 'save'):
save_path = self.orchestrator.decision_model.save('models/saved/decision_fusion_session')
stored_models.append(('Decision Fusion', save_path))
logger.info(f"Stored Decision Fusion model: {save_path}")
success = save_model(
model=self.orchestrator.decision_model,
model_name='decision_fusion_session',
model_type='hybrid',
metadata={'session_save': True, 'dashboard_save': True}
)
if success:
stored_models.append(('Decision Fusion', 'models/hybrid/saved/decision_fusion_session_latest.pt'))
logger.info("Stored Decision Fusion model via unified registry")
else:
logger.warning("Failed to store Decision Fusion model via unified registry")
except Exception as e:
logger.warning(f"Failed to store Decision Fusion model: {e}")
@@ -6706,13 +6738,39 @@ class CleanTradingDashboard:
except Exception as e:
logger.error(f"Error saving transformer checkpoint: {e}")
# Fallback to direct save
# Use unified registry for checkpoint
try:
checkpoint_path = f"NN/models/saved/transformer_checkpoint_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt"
transformer_trainer.save_model(checkpoint_path)
logger.info(f"TRANSFORMER: Fallback checkpoint saved: {checkpoint_path}")
except Exception as fallback_error:
logger.error(f"Fallback checkpoint save also failed: {fallback_error}")
from utils.model_registry import save_checkpoint as registry_save_checkpoint
checkpoint_data = torch.load(checkpoint_path, map_location='cpu') if 'checkpoint_path' in locals() else checkpoint_data
success = registry_save_checkpoint(
model=checkpoint_data,
model_name='transformer',
model_type='transformer',
performance_score=training_metrics['accuracy'],
metadata={
'training_samples': len(training_samples),
'loss': training_metrics['total_loss'],
'accuracy': training_metrics['accuracy'],
'checkpoint_source': 'dashboard_training'
}
)
if success:
logger.info("TRANSFORMER: Checkpoint saved via unified registry")
else:
logger.warning("TRANSFORMER: Failed to save checkpoint via unified registry")
except Exception as registry_error:
logger.warning(f"Unified registry save failed: {registry_error}")
# Fallback to direct save
try:
checkpoint_path = f"NN/models/saved/transformer_checkpoint_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt"
transformer_trainer.save_model(checkpoint_path)
logger.info(f"TRANSFORMER: Fallback checkpoint saved: {checkpoint_path}")
except Exception as fallback_error:
logger.error(f"Fallback checkpoint save also failed: {fallback_error}")
logger.info(f"TRANSFORMER: Trained on {len(training_samples)} samples, loss: {training_metrics['total_loss']:.4f}, accuracy: {training_metrics['accuracy']:.4f}")