#!/usr/bin/env python3
"""
Migrate Existing Models to Checkpoint System

This script copies existing model files into the checkpoint directory layout
(``models/checkpoints/<model_name>/``) and creates a database metadata entry
for each copied file.
"""

import os
import shutil
import logging
from datetime import datetime
from pathlib import Path

from utils.database_manager import get_database_manager, CheckpointMetadata
from utils.checkpoint_manager import save_checkpoint
from utils.text_logger import get_text_logger

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def _remove_orphan(checkpoint_file):
    """Delete a copied checkpoint file that has no database entry pointing at it."""
    try:
        checkpoint_file.unlink()
        logger.info(f"Removed orphaned checkpoint file: {checkpoint_file}")
    except FileNotFoundError:
        pass
    except OSError as e:
        # Best-effort cleanup: report but do not abort the remaining migrations.
        logger.warning(f"Could not remove orphaned checkpoint file {checkpoint_file}: {e}")


def _migrate_model(migration, db_manager, text_logger):
    """Copy one legacy model file into the checkpoint layout and register it.

    Args:
        migration: dict with keys 'model_name', 'model_type', 'source_file',
            'performance_metrics' and 'training_metadata'.
        db_manager: object returned by ``get_database_manager()``.
        text_logger: object returned by ``get_text_logger()``.

    Returns:
        True if the file was copied and its metadata saved, otherwise False.
    """
    source_path = Path(migration['source_file'])
    if not source_path.exists():
        logger.warning(f"Source file not found: {source_path}")
        return False

    checkpoint_file = None
    metadata_saved = False
    try:
        # Create checkpoint directory
        checkpoint_dir = Path("models/checkpoints") / migration['model_name']
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Take the clock reading once so the checkpoint ID and the stored
        # metadata timestamp always agree (two now() calls can straddle a second).
        now = datetime.now()
        timestamp = now.strftime("%Y%m%d_%H%M%S")
        checkpoint_id = f"{migration['model_name']}_{timestamp}"
        checkpoint_file = checkpoint_dir / f"{checkpoint_id}.pt"

        # Copy model file to checkpoint location (copy2 also preserves file metadata)
        shutil.copy2(source_path, checkpoint_file)
        logger.info(f"Copied {source_path} -> {checkpoint_file}")

        file_size_mb = checkpoint_file.stat().st_size / (1024 * 1024)

        metadata = CheckpointMetadata(
            checkpoint_id=checkpoint_id,
            model_name=migration['model_name'],
            model_type=migration['model_type'],
            timestamp=now,
            performance_metrics=migration['performance_metrics'],
            training_metadata=migration['training_metadata'],
            file_path=str(checkpoint_file),
            file_size_mb=file_size_mb,
            is_active=True
        )

        if not db_manager.save_checkpoint_metadata(metadata):
            logger.error(f"Failed to save checkpoint metadata: {checkpoint_id}")
            # Nothing in the database references the copy, so don't leave it behind.
            _remove_orphan(checkpoint_file)
            return False
        metadata_saved = True
        logger.info(f"Saved checkpoint metadata: {checkpoint_id}")

        # Log to text file
        text_logger.log_checkpoint_event(
            model_name=migration['model_name'],
            event_type="MIGRATED",
            checkpoint_id=checkpoint_id,
            details=f"from {source_path}, size={file_size_mb:.1f}MB"
        )
        return True
    except Exception:
        logger.exception(f"Failed to migrate {migration['model_name']}")
        # Only delete the copy if no metadata row points at it; otherwise the
        # database would reference a missing file.
        if checkpoint_file is not None and not metadata_saved:
            _remove_orphan(checkpoint_file)
        return False


def _print_checkpoint_status(db_manager, model_names):
    """Print how many checkpoints each model has, listing at most the first two."""
    print("\n=== Current Checkpoint Status ===")
    for model_name in model_names:
        checkpoints = db_manager.list_checkpoints(model_name)
        if checkpoints:
            print(f"{model_name}: {len(checkpoints)} checkpoints")
            for checkpoint in checkpoints[:2]:  # Show first 2
                print(f"  - {checkpoint.checkpoint_id} ({checkpoint.file_size_mb:.1f}MB)")
        else:
            print(f"{model_name}: No checkpoints")


def migrate_existing_models():
    """Migrate the hard-coded list of legacy model files to the checkpoint system.

    Missing source files are skipped with a warning. Per-model failures are
    logged and do not stop the remaining migrations. Prints a summary and the
    resulting checkpoint status.
    """
    print("=== Migrating Existing Models to Checkpoint System ===")

    db_manager = get_database_manager()
    text_logger = get_text_logger()

    # Define model migrations. The performance metrics are fixed values that
    # are stored with the migrated checkpoint; they are not recomputed here.
    migrations = [
        {
            'model_name': 'enhanced_cnn',
            'model_type': 'cnn',
            'source_file': 'models/enhanced_cnn/ETH_USDT_cnn.pth',
            'performance_metrics': {'loss': 0.0187, 'accuracy': 0.92},
            'training_metadata': {'symbol': 'ETH/USDT', 'migrated': True}
        },
        {
            'model_name': 'dqn_agent',
            'model_type': 'rl',
            'source_file': 'models/enhanced_rl/ETH_USDT_dqn_policy.pth',
            'performance_metrics': {'loss': 0.0234, 'reward': 145.2},
            'training_metadata': {'symbol': 'ETH/USDT', 'migrated': True, 'type': 'policy'}
        },
        {
            'model_name': 'dqn_agent_target',
            'model_type': 'rl',
            'source_file': 'models/enhanced_rl/ETH_USDT_dqn_target.pth',
            'performance_metrics': {'loss': 0.0234, 'reward': 145.2},
            'training_metadata': {'symbol': 'ETH/USDT', 'migrated': True, 'type': 'target'}
        }
    ]

    migrated_count = sum(
        1 for migration in migrations
        if _migrate_model(migration, db_manager, text_logger)
    )

    print(f"\nMigration completed: {migrated_count} models migrated")

    # Show current checkpoint status
    _print_checkpoint_status(db_manager, ['dqn_agent', 'enhanced_cnn', 'dqn_agent_target'])


def verify_checkpoint_system():
    """Check that each core model has best-checkpoint metadata whose file exists on disk."""
    print("\n=== Verifying Checkpoint System ===")

    db_manager = get_database_manager()

    # Test loading checkpoints
    for model_name in ['dqn_agent', 'enhanced_cnn']:
        metadata = db_manager.get_best_checkpoint_metadata(model_name)
        if metadata:
            file_exists = Path(metadata.file_path).exists()
            print(f"{model_name}: ✅ Metadata found, File exists: {file_exists}")
            if file_exists:
                print(f"  -> {metadata.checkpoint_id} ({metadata.file_size_mb:.1f}MB)")
            else:
                print(f"  -> ERROR: File missing: {metadata.file_path}")
        else:
            print(f"{model_name}: ❌ No checkpoint metadata found")


def create_test_checkpoint():
    """Save a tiny torch model through ``save_checkpoint`` to smoke-test saving.

    The saved file is deleted afterwards. Any failure, including torch not
    being installed, is reported and not raised.
    """
    print("\n=== Testing Checkpoint Saving ===")

    try:
        import torch
        import torch.nn as nn

        # Create a simple test model
        class TestModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(10, 1)

            def forward(self, x):
                return self.linear(x)

        test_model = TestModel()

        # Save using the checkpoint system (save_checkpoint is imported at module level)
        result = save_checkpoint(
            model=test_model,
            model_name="test_model",
            model_type="test",
            performance_metrics={"loss": 0.1, "accuracy": 0.95},
            training_metadata={"test": True, "created": datetime.now().isoformat()}
        )

        if result:
            print(f"✅ Test checkpoint saved successfully: {result.checkpoint_id}")

            # Verify it exists
            # NOTE(review): get_best_checkpoint_metadata may return a checkpoint
            # other than the one just saved (e.g. a row left by an earlier run).
            # Confirm how "best" is chosen.
            db_manager = get_database_manager()
            metadata = db_manager.get_best_checkpoint_metadata("test_model")
            if metadata and Path(metadata.file_path).exists():
                print(f"✅ Test checkpoint verified: {metadata.file_path}")

                # Clean up test checkpoint.
                # NOTE(review): only the file is removed. The "test_model" metadata
                # row stays in the database; delete it too if db_manager offers a
                # delete API.
                Path(metadata.file_path).unlink()
                print("🧹 Test checkpoint cleaned up")
            else:
                print("❌ Test checkpoint verification failed")
        else:
            print("❌ Test checkpoint saving failed")

    except Exception as e:
        print(f"❌ Test checkpoint creation failed: {e}")


def main():
    """Main migration process"""
    migrate_existing_models()
    verify_checkpoint_system()
    create_test_checkpoint()

    print("\n=== Migration Complete ===")
    print("The checkpoint system should now work properly!")
    print("Existing models have been migrated and the system is ready for use.")


if __name__ == "__main__":
    main()