model checkpoint manager
This commit is contained in:
@@ -4710,53 +4710,85 @@ class CleanTradingDashboard:
|
||||
|
||||
stored_models = []
|
||||
|
||||
# Use unified model registry for saving
|
||||
from utils.model_registry import save_model
|
||||
|
||||
# 1. Store DQN model
|
||||
if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
|
||||
try:
|
||||
if hasattr(self.orchestrator.rl_agent, 'save'):
|
||||
save_path = self.orchestrator.rl_agent.save('models/saved/dqn_agent_session')
|
||||
stored_models.append(('DQN', save_path))
|
||||
logger.info(f"Stored DQN model: {save_path}")
|
||||
success = save_model(
|
||||
model=self.orchestrator.rl_agent.policy_net, # Save policy network
|
||||
model_name='dqn_agent_session',
|
||||
model_type='dqn',
|
||||
metadata={'session_save': True, 'dashboard_save': True}
|
||||
)
|
||||
if success:
|
||||
stored_models.append(('DQN', 'models/dqn/saved/dqn_agent_session_latest.pt'))
|
||||
logger.info("Stored DQN model via unified registry")
|
||||
else:
|
||||
logger.warning("Failed to store DQN model via unified registry")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to store DQN model: {e}")
|
||||
|
||||
|
||||
# 2. Store CNN model
|
||||
if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
|
||||
try:
|
||||
if hasattr(self.orchestrator.cnn_model, 'save'):
|
||||
save_path = self.orchestrator.cnn_model.save('models/saved/cnn_model_session')
|
||||
stored_models.append(('CNN', save_path))
|
||||
logger.info(f"Stored CNN model: {save_path}")
|
||||
success = save_model(
|
||||
model=self.orchestrator.cnn_model,
|
||||
model_name='cnn_model_session',
|
||||
model_type='cnn',
|
||||
metadata={'session_save': True, 'dashboard_save': True}
|
||||
)
|
||||
if success:
|
||||
stored_models.append(('CNN', 'models/cnn/saved/cnn_model_session_latest.pt'))
|
||||
logger.info("Stored CNN model via unified registry")
|
||||
else:
|
||||
logger.warning("Failed to store CNN model via unified registry")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to store CNN model: {e}")
|
||||
|
||||
|
||||
# 3. Store Transformer model
|
||||
if hasattr(self.orchestrator, 'primary_transformer') and self.orchestrator.primary_transformer:
|
||||
try:
|
||||
if hasattr(self.orchestrator.primary_transformer, 'save'):
|
||||
save_path = self.orchestrator.primary_transformer.save('models/saved/transformer_model_session')
|
||||
stored_models.append(('Transformer', save_path))
|
||||
logger.info(f"Stored Transformer model: {save_path}")
|
||||
success = save_model(
|
||||
model=self.orchestrator.primary_transformer,
|
||||
model_name='transformer_model_session',
|
||||
model_type='transformer',
|
||||
metadata={'session_save': True, 'dashboard_save': True}
|
||||
)
|
||||
if success:
|
||||
stored_models.append(('Transformer', 'models/transformer/saved/transformer_model_session_latest.pt'))
|
||||
logger.info("Stored Transformer model via unified registry")
|
||||
else:
|
||||
logger.warning("Failed to store Transformer model via unified registry")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to store Transformer model: {e}")
|
||||
|
||||
# 4. Store COB RL model
|
||||
|
||||
# 4. Store COB RL model (if exists)
|
||||
if hasattr(self.orchestrator, 'cob_rl_agent') and self.orchestrator.cob_rl_agent:
|
||||
try:
|
||||
# COB RL model might have different save method
|
||||
if hasattr(self.orchestrator.cob_rl_agent, 'save'):
|
||||
save_path = self.orchestrator.cob_rl_agent.save('models/saved/cob_rl_agent_session')
|
||||
stored_models.append(('COB RL', save_path))
|
||||
logger.info(f"Stored COB RL model: {save_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to store COB RL model: {e}")
|
||||
|
||||
# 5. Store Decision Fusion model
|
||||
|
||||
# 5. Store Decision model
|
||||
if hasattr(self.orchestrator, 'decision_model') and self.orchestrator.decision_model:
|
||||
try:
|
||||
if hasattr(self.orchestrator.decision_model, 'save'):
|
||||
save_path = self.orchestrator.decision_model.save('models/saved/decision_fusion_session')
|
||||
stored_models.append(('Decision Fusion', save_path))
|
||||
logger.info(f"Stored Decision Fusion model: {save_path}")
|
||||
success = save_model(
|
||||
model=self.orchestrator.decision_model,
|
||||
model_name='decision_fusion_session',
|
||||
model_type='hybrid',
|
||||
metadata={'session_save': True, 'dashboard_save': True}
|
||||
)
|
||||
if success:
|
||||
stored_models.append(('Decision Fusion', 'models/hybrid/saved/decision_fusion_session_latest.pt'))
|
||||
logger.info("Stored Decision Fusion model via unified registry")
|
||||
else:
|
||||
logger.warning("Failed to store Decision Fusion model via unified registry")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to store Decision Fusion model: {e}")
|
||||
|
||||
@@ -6706,13 +6738,39 @@ class CleanTradingDashboard:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving transformer checkpoint: {e}")
|
||||
# Fallback to direct save
|
||||
# Use unified registry for checkpoint
|
||||
try:
|
||||
checkpoint_path = f"NN/models/saved/transformer_checkpoint_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt"
|
||||
transformer_trainer.save_model(checkpoint_path)
|
||||
logger.info(f"TRANSFORMER: Fallback checkpoint saved: {checkpoint_path}")
|
||||
except Exception as fallback_error:
|
||||
logger.error(f"Fallback checkpoint save also failed: {fallback_error}")
|
||||
from utils.model_registry import save_checkpoint as registry_save_checkpoint
|
||||
|
||||
checkpoint_data = torch.load(checkpoint_path, map_location='cpu') if 'checkpoint_path' in locals() else checkpoint_data
|
||||
|
||||
success = registry_save_checkpoint(
|
||||
model=checkpoint_data,
|
||||
model_name='transformer',
|
||||
model_type='transformer',
|
||||
performance_score=training_metrics['accuracy'],
|
||||
metadata={
|
||||
'training_samples': len(training_samples),
|
||||
'loss': training_metrics['total_loss'],
|
||||
'accuracy': training_metrics['accuracy'],
|
||||
'checkpoint_source': 'dashboard_training'
|
||||
}
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info("TRANSFORMER: Checkpoint saved via unified registry")
|
||||
else:
|
||||
logger.warning("TRANSFORMER: Failed to save checkpoint via unified registry")
|
||||
|
||||
except Exception as registry_error:
|
||||
logger.warning(f"Unified registry save failed: {registry_error}")
|
||||
# Fallback to direct save
|
||||
try:
|
||||
checkpoint_path = f"NN/models/saved/transformer_checkpoint_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt"
|
||||
transformer_trainer.save_model(checkpoint_path)
|
||||
logger.info(f"TRANSFORMER: Fallback checkpoint saved: {checkpoint_path}")
|
||||
except Exception as fallback_error:
|
||||
logger.error(f"Fallback checkpoint save also failed: {fallback_error}")
|
||||
|
||||
logger.info(f"TRANSFORMER: Trained on {len(training_samples)} samples, loss: {training_metrics['total_loss']:.4f}, accuracy: {training_metrics['accuracy']:.4f}")
|
||||
|
||||
|
Reference in New Issue
Block a user