247 lines
11 KiB
Python
247 lines
11 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Model Checkpoint Saver
|
|
|
|
Utility to ensure all models can save checkpoints properly.
|
|
This will make them show as LOADED instead of FRESH.
|
|
"""
|
|
|
|
import logging
|
|
import os
|
|
from datetime import datetime
|
|
from typing import Dict, Any, Optional
|
|
from pathlib import Path
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
class ModelCheckpointSaver:
|
|
"""Utility to save checkpoints for all models to fix FRESH status"""
|
|
|
|
def __init__(self, orchestrator):
|
|
self.orchestrator = orchestrator
|
|
|
|
def save_all_model_checkpoints(self, force: bool = True) -> Dict[str, bool]:
|
|
"""Save checkpoints for all initialized models"""
|
|
results = {}
|
|
|
|
# Save DQN Agent
|
|
if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
|
|
results['dqn_agent'] = self._save_dqn_checkpoint(force)
|
|
|
|
# Save CNN Model
|
|
if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
|
|
results['enhanced_cnn'] = self._save_cnn_checkpoint(force)
|
|
|
|
# Save Extrema Trainer
|
|
if hasattr(self.orchestrator, 'extrema_trainer') and self.orchestrator.extrema_trainer:
|
|
results['extrema_trainer'] = self._save_extrema_checkpoint(force)
|
|
|
|
# COB RL model removed - see COB_MODEL_ARCHITECTURE_DOCUMENTATION.md
|
|
# Will recreate when COB data quality is improved
|
|
|
|
# Save Transformer
|
|
if hasattr(self.orchestrator, 'transformer_trainer') and self.orchestrator.transformer_trainer:
|
|
results['transformer'] = self._save_transformer_checkpoint(force)
|
|
|
|
# Save Decision Model
|
|
if hasattr(self.orchestrator, 'decision_model') and self.orchestrator.decision_model:
|
|
results['decision'] = self._save_decision_checkpoint(force)
|
|
|
|
return results
|
|
|
|
def _save_dqn_checkpoint(self, force: bool = True) -> bool:
|
|
"""Save DQN agent checkpoint"""
|
|
try:
|
|
if hasattr(self.orchestrator.rl_agent, 'save_checkpoint'):
|
|
success = self.orchestrator.rl_agent.save_checkpoint(force_save=force)
|
|
if success:
|
|
self.orchestrator.model_states['dqn']['checkpoint_loaded'] = True
|
|
self.orchestrator.model_states['dqn']['checkpoint_filename'] = f"dqn_agent_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
|
logger.info("DQN checkpoint saved successfully")
|
|
return True
|
|
|
|
# Fallback: use improved model saver
|
|
from improved_model_saver import get_improved_model_saver
|
|
saver = get_improved_model_saver()
|
|
success = saver.save_model_safely(
|
|
self.orchestrator.rl_agent,
|
|
"dqn_agent",
|
|
"dqn",
|
|
metadata={"saved_by": "checkpoint_saver", "timestamp": datetime.now().isoformat()}
|
|
)
|
|
if success:
|
|
self.orchestrator.model_states['dqn']['checkpoint_loaded'] = True
|
|
self.orchestrator.model_states['dqn']['checkpoint_filename'] = "dqn_agent_latest"
|
|
logger.info("DQN checkpoint saved using fallback method")
|
|
return True
|
|
|
|
return False
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to save DQN checkpoint: {e}")
|
|
return False
|
|
|
|
def _save_cnn_checkpoint(self, force: bool = True) -> bool:
|
|
"""Save CNN model checkpoint"""
|
|
try:
|
|
if hasattr(self.orchestrator.cnn_model, 'save_checkpoint'):
|
|
success = self.orchestrator.cnn_model.save_checkpoint(force_save=force)
|
|
if success:
|
|
self.orchestrator.model_states['cnn']['checkpoint_loaded'] = True
|
|
self.orchestrator.model_states['cnn']['checkpoint_filename'] = f"enhanced_cnn_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
|
logger.info("CNN checkpoint saved successfully")
|
|
return True
|
|
|
|
# Fallback: use improved model saver
|
|
from improved_model_saver import get_improved_model_saver
|
|
saver = get_improved_model_saver()
|
|
success = saver.save_model_safely(
|
|
self.orchestrator.cnn_model,
|
|
"enhanced_cnn",
|
|
"cnn",
|
|
metadata={"saved_by": "checkpoint_saver", "timestamp": datetime.now().isoformat()}
|
|
)
|
|
if success:
|
|
self.orchestrator.model_states['cnn']['checkpoint_loaded'] = True
|
|
self.orchestrator.model_states['cnn']['checkpoint_filename'] = "enhanced_cnn_latest"
|
|
logger.info("CNN checkpoint saved using fallback method")
|
|
return True
|
|
|
|
return False
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to save CNN checkpoint: {e}")
|
|
return False
|
|
|
|
def _save_extrema_checkpoint(self, force: bool = True) -> bool:
|
|
"""Save Extrema Trainer checkpoint"""
|
|
try:
|
|
if hasattr(self.orchestrator.extrema_trainer, 'save_checkpoint'):
|
|
self.orchestrator.extrema_trainer.save_checkpoint(force_save=force)
|
|
self.orchestrator.model_states['extrema_trainer']['checkpoint_loaded'] = True
|
|
self.orchestrator.model_states['extrema_trainer']['checkpoint_filename'] = f"extrema_trainer_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
|
logger.info("Extrema Trainer checkpoint saved successfully")
|
|
return True
|
|
|
|
return False
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to save Extrema Trainer checkpoint: {e}")
|
|
return False
|
|
|
|
def _save_cob_rl_checkpoint(self, force: bool = True) -> bool:
|
|
"""Save COB RL agent checkpoint"""
|
|
try:
|
|
# COB RL may have a different saving mechanism
|
|
from improved_model_saver import get_improved_model_saver
|
|
saver = get_improved_model_saver()
|
|
success = saver.save_model_safely(
|
|
self.orchestrator.cob_rl_agent,
|
|
"cob_rl",
|
|
"cob_rl",
|
|
metadata={"saved_by": "checkpoint_saver", "timestamp": datetime.now().isoformat()}
|
|
)
|
|
if success:
|
|
self.orchestrator.model_states['cob_rl']['checkpoint_loaded'] = True
|
|
self.orchestrator.model_states['cob_rl']['checkpoint_filename'] = "cob_rl_latest"
|
|
logger.info("COB RL checkpoint saved successfully")
|
|
return True
|
|
|
|
return False
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to save COB RL checkpoint: {e}")
|
|
return False
|
|
|
|
def _save_transformer_checkpoint(self, force: bool = True) -> bool:
|
|
"""Save Transformer model checkpoint"""
|
|
try:
|
|
if hasattr(self.orchestrator.transformer_trainer, 'save_model'):
|
|
# Create a checkpoint file path
|
|
checkpoint_dir = Path("models/saved/transformer")
|
|
checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
|
checkpoint_path = checkpoint_dir / f"transformer_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt"
|
|
|
|
self.orchestrator.transformer_trainer.save_model(str(checkpoint_path))
|
|
self.orchestrator.model_states['transformer']['checkpoint_loaded'] = True
|
|
self.orchestrator.model_states['transformer']['checkpoint_filename'] = checkpoint_path.name
|
|
logger.info("Transformer checkpoint saved successfully")
|
|
return True
|
|
|
|
return False
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to save Transformer checkpoint: {e}")
|
|
return False
|
|
|
|
def _save_decision_checkpoint(self, force: bool = True) -> bool:
|
|
"""Save Decision model checkpoint"""
|
|
try:
|
|
from improved_model_saver import get_improved_model_saver
|
|
saver = get_improved_model_saver()
|
|
success = saver.save_model_safely(
|
|
self.orchestrator.decision_model,
|
|
"decision",
|
|
"decision",
|
|
metadata={"saved_by": "checkpoint_saver", "timestamp": datetime.now().isoformat()}
|
|
)
|
|
if success:
|
|
self.orchestrator.model_states['decision']['checkpoint_loaded'] = True
|
|
self.orchestrator.model_states['decision']['checkpoint_filename'] = "decision_latest"
|
|
logger.info("Decision model checkpoint saved successfully")
|
|
return True
|
|
|
|
return False
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to save Decision model checkpoint: {e}")
|
|
return False
|
|
|
|
def update_model_status_to_loaded(self, model_name: str):
|
|
"""Manually update a model's status to LOADED"""
|
|
if model_name in self.orchestrator.model_states:
|
|
self.orchestrator.model_states[model_name]['checkpoint_loaded'] = True
|
|
if not self.orchestrator.model_states[model_name].get('checkpoint_filename'):
|
|
self.orchestrator.model_states[model_name]['checkpoint_filename'] = f"{model_name}_manual_loaded"
|
|
logger.info(f"Updated {model_name} status to LOADED")
|
|
|
|
def force_all_models_to_loaded(self):
|
|
"""Force all existing models to show as LOADED"""
|
|
models_updated = []
|
|
|
|
for model_name in self.orchestrator.model_states.keys():
|
|
# Check if model actually exists
|
|
model_exists = False
|
|
|
|
if model_name == 'dqn' and hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
|
|
model_exists = True
|
|
elif model_name == 'cnn' and hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
|
|
model_exists = True
|
|
elif model_name == 'extrema_trainer' and hasattr(self.orchestrator, 'extrema_trainer') and self.orchestrator.extrema_trainer:
|
|
model_exists = True
|
|
# COB RL model removed - focusing on COB data quality first
|
|
elif model_name == 'transformer' and hasattr(self.orchestrator, 'transformer_model') and self.orchestrator.transformer_model:
|
|
model_exists = True
|
|
elif model_name == 'decision' and hasattr(self.orchestrator, 'decision_model') and self.orchestrator.decision_model:
|
|
model_exists = True
|
|
|
|
if model_exists:
|
|
self.update_model_status_to_loaded(model_name)
|
|
models_updated.append(model_name)
|
|
|
|
logger.info(f"Force-updated {len(models_updated)} models to LOADED status: {models_updated}")
|
|
return models_updated
|
|
|
|
|
|
def save_all_checkpoints_now(orchestrator):
|
|
"""Convenience function to save all checkpoints"""
|
|
saver = ModelCheckpointSaver(orchestrator)
|
|
results = saver.save_all_model_checkpoints(force=True)
|
|
|
|
print("Checkpoint saving results:")
|
|
for model_name, success in results.items():
|
|
status = "✅ SUCCESS" if success else "❌ FAILED"
|
|
print(f" {model_name}: {status}")
|
|
|
|
return results
|