gogo2/model_checkpoint_saver.py

#!/usr/bin/env python3
"""
Model Checkpoint Saver

Utility to ensure all models can save checkpoints properly,
so they show as LOADED instead of FRESH.
"""

import logging
from datetime import datetime
from typing import Dict
from pathlib import Path

logger = logging.getLogger(__name__)


class ModelCheckpointSaver:
    """Utility to save checkpoints for all models to fix FRESH status."""

    def __init__(self, orchestrator):
        self.orchestrator = orchestrator

    def save_all_model_checkpoints(self, force: bool = True) -> Dict[str, bool]:
        """Save checkpoints for all initialized models."""
        results = {}

        # Save DQN agent
        if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
            results['dqn_agent'] = self._save_dqn_checkpoint(force)

        # Save CNN model
        if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
            results['enhanced_cnn'] = self._save_cnn_checkpoint(force)

        # Save extrema trainer
        if hasattr(self.orchestrator, 'extrema_trainer') and self.orchestrator.extrema_trainer:
            results['extrema_trainer'] = self._save_extrema_checkpoint(force)

        # COB RL model removed - see COB_MODEL_ARCHITECTURE_DOCUMENTATION.md.
        # Will be recreated when COB data quality is improved.

        # Save transformer
        if hasattr(self.orchestrator, 'transformer_trainer') and self.orchestrator.transformer_trainer:
            results['transformer'] = self._save_transformer_checkpoint(force)

        # Save decision model
        if hasattr(self.orchestrator, 'decision_model') and self.orchestrator.decision_model:
            results['decision'] = self._save_decision_checkpoint(force)

        return results
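
    # Assumed shape of orchestrator.model_states, inferred from the accesses
    # in the methods below (the real orchestrator may track more fields):
    #
    #   model_states = {
    #       'dqn': {'checkpoint_loaded': bool, 'checkpoint_filename': str},
    #       'cnn': {'checkpoint_loaded': bool, 'checkpoint_filename': str},
    #       'extrema_trainer': {...}, 'transformer': {...}, 'decision': {...},
    #   }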

    def _save_dqn_checkpoint(self, force: bool = True) -> bool:
        """Save DQN agent checkpoint."""
        try:
            if hasattr(self.orchestrator.rl_agent, 'save_checkpoint'):
                success = self.orchestrator.rl_agent.save_checkpoint(force_save=force)
                if success:
                    self.orchestrator.model_states['dqn']['checkpoint_loaded'] = True
                    self.orchestrator.model_states['dqn']['checkpoint_filename'] = f"dqn_agent_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
                    logger.info("DQN checkpoint saved successfully")
                    return True

            # Fallback: use the improved model saver
            from improved_model_saver import get_improved_model_saver
            saver = get_improved_model_saver()
            success = saver.save_model_safely(
                self.orchestrator.rl_agent,
                "dqn_agent",
                "dqn",
                metadata={"saved_by": "checkpoint_saver", "timestamp": datetime.now().isoformat()}
            )
            if success:
                self.orchestrator.model_states['dqn']['checkpoint_loaded'] = True
                self.orchestrator.model_states['dqn']['checkpoint_filename'] = "dqn_agent_latest"
                logger.info("DQN checkpoint saved using fallback method")
                return True
            return False
        except Exception as e:
            logger.error(f"Failed to save DQN checkpoint: {e}")
            return False

    def _save_cnn_checkpoint(self, force: bool = True) -> bool:
        """Save CNN model checkpoint."""
        try:
            if hasattr(self.orchestrator.cnn_model, 'save_checkpoint'):
                success = self.orchestrator.cnn_model.save_checkpoint(force_save=force)
                if success:
                    self.orchestrator.model_states['cnn']['checkpoint_loaded'] = True
                    self.orchestrator.model_states['cnn']['checkpoint_filename'] = f"enhanced_cnn_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
                    logger.info("CNN checkpoint saved successfully")
                    return True

            # Fallback: use the improved model saver
            from improved_model_saver import get_improved_model_saver
            saver = get_improved_model_saver()
            success = saver.save_model_safely(
                self.orchestrator.cnn_model,
                "enhanced_cnn",
                "cnn",
                metadata={"saved_by": "checkpoint_saver", "timestamp": datetime.now().isoformat()}
            )
            if success:
                self.orchestrator.model_states['cnn']['checkpoint_loaded'] = True
                self.orchestrator.model_states['cnn']['checkpoint_filename'] = "enhanced_cnn_latest"
                logger.info("CNN checkpoint saved using fallback method")
                return True
            return False
        except Exception as e:
            logger.error(f"Failed to save CNN checkpoint: {e}")
            return False

    def _save_extrema_checkpoint(self, force: bool = True) -> bool:
        """Save Extrema Trainer checkpoint."""
        try:
            if hasattr(self.orchestrator.extrema_trainer, 'save_checkpoint'):
                # save_checkpoint does not report success here; if it does not
                # raise, the checkpoint is assumed to have been written
                self.orchestrator.extrema_trainer.save_checkpoint(force_save=force)
                self.orchestrator.model_states['extrema_trainer']['checkpoint_loaded'] = True
                self.orchestrator.model_states['extrema_trainer']['checkpoint_filename'] = f"extrema_trainer_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
                logger.info("Extrema Trainer checkpoint saved successfully")
                return True
            return False
        except Exception as e:
            logger.error(f"Failed to save Extrema Trainer checkpoint: {e}")
            return False

    def _save_cob_rl_checkpoint(self, force: bool = True) -> bool:
        """Save COB RL agent checkpoint.

        Currently unused: the COB RL model was removed (see
        COB_MODEL_ARCHITECTURE_DOCUMENTATION.md); this method is kept for
        when it is recreated.
        """
        try:
            # COB RL may have a different saving mechanism
            from improved_model_saver import get_improved_model_saver
            saver = get_improved_model_saver()
            success = saver.save_model_safely(
                self.orchestrator.cob_rl_agent,
                "cob_rl",
                "cob_rl",
                metadata={"saved_by": "checkpoint_saver", "timestamp": datetime.now().isoformat()}
            )
            if success:
                self.orchestrator.model_states['cob_rl']['checkpoint_loaded'] = True
                self.orchestrator.model_states['cob_rl']['checkpoint_filename'] = "cob_rl_latest"
                logger.info("COB RL checkpoint saved successfully")
                return True
            return False
        except Exception as e:
            logger.error(f"Failed to save COB RL checkpoint: {e}")
            return False

    def _save_transformer_checkpoint(self, force: bool = True) -> bool:
        """Save Transformer model checkpoint."""
        try:
            if hasattr(self.orchestrator.transformer_trainer, 'save_model'):
                # Build a timestamped checkpoint file path
                checkpoint_dir = Path("models/saved/transformer")
                checkpoint_dir.mkdir(parents=True, exist_ok=True)
                checkpoint_path = checkpoint_dir / f"transformer_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt"

                self.orchestrator.transformer_trainer.save_model(str(checkpoint_path))

                self.orchestrator.model_states['transformer']['checkpoint_loaded'] = True
                self.orchestrator.model_states['transformer']['checkpoint_filename'] = checkpoint_path.name
                logger.info("Transformer checkpoint saved successfully")
                return True
            return False
        except Exception as e:
            logger.error(f"Failed to save Transformer checkpoint: {e}")
            return False

    def _save_decision_checkpoint(self, force: bool = True) -> bool:
        """Save Decision model checkpoint."""
        try:
            from improved_model_saver import get_improved_model_saver
            saver = get_improved_model_saver()
            success = saver.save_model_safely(
                self.orchestrator.decision_model,
                "decision",
                "decision",
                metadata={"saved_by": "checkpoint_saver", "timestamp": datetime.now().isoformat()}
            )
            if success:
                self.orchestrator.model_states['decision']['checkpoint_loaded'] = True
                self.orchestrator.model_states['decision']['checkpoint_filename'] = "decision_latest"
                logger.info("Decision model checkpoint saved successfully")
                return True
            return False
        except Exception as e:
            logger.error(f"Failed to save Decision model checkpoint: {e}")
            return False

    def update_model_status_to_loaded(self, model_name: str):
        """Manually update a model's status to LOADED."""
        if model_name in self.orchestrator.model_states:
            self.orchestrator.model_states[model_name]['checkpoint_loaded'] = True
            if not self.orchestrator.model_states[model_name].get('checkpoint_filename'):
                self.orchestrator.model_states[model_name]['checkpoint_filename'] = f"{model_name}_manual_loaded"
            logger.info(f"Updated {model_name} status to LOADED")

    def force_all_models_to_loaded(self):
        """Force all existing models to show as LOADED."""
        models_updated = []

        for model_name in self.orchestrator.model_states.keys():
            # Check whether the model actually exists on the orchestrator
            model_exists = False
            if model_name == 'dqn' and hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
                model_exists = True
            elif model_name == 'cnn' and hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
                model_exists = True
            elif model_name == 'extrema_trainer' and hasattr(self.orchestrator, 'extrema_trainer') and self.orchestrator.extrema_trainer:
                model_exists = True
            # COB RL model removed - focusing on COB data quality first
            elif model_name == 'transformer' and hasattr(self.orchestrator, 'transformer_model') and self.orchestrator.transformer_model:
                model_exists = True
            elif model_name == 'decision' and hasattr(self.orchestrator, 'decision_model') and self.orchestrator.decision_model:
                model_exists = True

            if model_exists:
                self.update_model_status_to_loaded(model_name)
                models_updated.append(model_name)

        logger.info(f"Force-updated {len(models_updated)} models to LOADED status: {models_updated}")
        return models_updated


def save_all_checkpoints_now(orchestrator):
    """Convenience function to save all checkpoints."""
    saver = ModelCheckpointSaver(orchestrator)
    results = saver.save_all_model_checkpoints(force=True)

    print("Checkpoint saving results:")
    for model_name, success in results.items():
        status = "✅ SUCCESS" if success else "❌ FAILED"
        print(f"  {model_name}: {status}")

    return results
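

if __name__ == "__main__":
    # Minimal usage sketch. Assumption: an orchestrator exposing the
    # attributes referenced above (rl_agent, cnn_model, extrema_trainer,
    # transformer_trainer, decision_model, model_states) is constructed
    # elsewhere in the project; the import below is hypothetical and stands
    # in for whatever class the codebase actually provides.
    logging.basicConfig(level=logging.INFO)

    # from core.orchestrator import TradingOrchestrator  # hypothetical
    # orchestrator = TradingOrchestrator()
    # save_all_checkpoints_now(orchestrator)
    #
    # If some models lack a save method, their statuses can still be flipped:
    # ModelCheckpointSaver(orchestrator).force_all_models_to_loaded()
    print("Import this module and call save_all_checkpoints_now(orchestrator)")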