"""
Database migration system for schema updates.
"""

from typing import List, Dict, Any, Optional
from datetime import datetime

from ..utils.logging import get_logger
from ..utils.exceptions import StorageError
from .connection_pool import db_pool

logger = get_logger(__name__)


class Migration:
    """Base class for database migrations.

    Subclasses supply a ``version`` and implement :meth:`up` and :meth:`down`.
    The version is both the ordering key used by :class:`MigrationManager`
    and the primary key in ``market_data.schema_migrations``.
    """

    def __init__(self, version: str, description: str):
        self.version = version
        self.description = description

    async def up(self) -> None:
        """Apply the migration."""
        raise NotImplementedError

    async def down(self) -> None:
        """Rollback the migration."""
        raise NotImplementedError


class MigrationManager:
    """Manages database schema migrations.

    Keeps an in-memory registry of :class:`Migration` objects sorted by
    version. Applied versions are tracked in ``market_data.schema_migrations``.
    """

    def __init__(self):
        self.migrations: List[Migration] = []

    def register_migration(self, migration: Migration) -> None:
        """Register a migration and keep the registry sorted by version.

        NOTE: versions are compared as plain strings. They must be
        zero-padded (``"001"``, ``"002"``, ...) for the ordering to match
        numeric order.
        """
        self.migrations.append(migration)
        # Sort by version
        self.migrations.sort(key=lambda m: m.version)

    async def initialize_migration_table(self) -> None:
        """Create the migration tracking table if it does not already exist."""
        query = """
            CREATE TABLE IF NOT EXISTS market_data.schema_migrations (
                version VARCHAR(50) PRIMARY KEY,
                description TEXT NOT NULL,
                applied_at TIMESTAMPTZ DEFAULT NOW()
            );
        """
        await db_pool.execute_command(query)
        logger.debug("Migration table initialized")

    async def get_applied_migrations(self) -> List[str]:
        """Return the applied migration versions, ordered by version.

        This is best-effort: if the query fails (for example because the
        tracking table does not exist yet), the error is logged and an empty
        list is returned.
        """
        try:
            query = "SELECT version FROM market_data.schema_migrations ORDER BY version"
            rows = await db_pool.execute_query(query)
            return [row['version'] for row in rows]
        except Exception as e:
            # Table might not exist yet. Log rather than silently swallow, so
            # that connection/permission failures remain visible.
            logger.warning(f"Could not read applied migrations (table may not exist yet): {e}")
            return []

    async def apply_migration(self, migration: Migration) -> bool:
        """Apply a single migration and record it in the tracking table.

        Returns:
            True on success, False if anything raised (the error is logged).
        """
        try:
            logger.info(f"Applying migration {migration.version}: {migration.description}")

            async with db_pool.get_transaction() as conn:
                # NOTE(review): migration.up() issues its statements through
                # db_pool helpers, not through `conn`. They are therefore
                # presumably NOT part of this transaction; only the INSERT
                # below is. Confirm against connection_pool before relying on
                # atomicity.
                await migration.up()

                # Record the migration
                await conn.execute(
                    "INSERT INTO market_data.schema_migrations (version, description) VALUES ($1, $2)",
                    migration.version,
                    migration.description
                )

            logger.info(f"Migration {migration.version} applied successfully")
            return True

        except Exception as e:
            logger.error(f"Failed to apply migration {migration.version}: {e}")
            return False

    async def rollback_migration(self, migration: Migration) -> bool:
        """Roll back a single migration and delete its tracking record.

        Returns:
            True on success, False if anything raised (the error is logged).
        """
        try:
            logger.info(f"Rolling back migration {migration.version}: {migration.description}")

            async with db_pool.get_transaction() as conn:
                # NOTE(review): as in apply_migration, down() does not use
                # `conn`. Presumably only the DELETE is transactional.
                await migration.down()

                # Remove the migration record
                await conn.execute(
                    "DELETE FROM market_data.schema_migrations WHERE version = $1",
                    migration.version
                )

            logger.info(f"Migration {migration.version} rolled back successfully")
            return True

        except Exception as e:
            logger.error(f"Failed to rollback migration {migration.version}: {e}")
            return False

    async def migrate_up(self, target_version: Optional[str] = None) -> bool:
        """Apply all pending migrations, in version order.

        Args:
            target_version: if given, only migrations whose version is
                ``<=`` this value (string comparison) are applied.

        Returns:
            True if every pending migration was applied (or none were
            pending). Returns False, and stops, at the first failure.
        """
        try:
            await self.initialize_migration_table()

            applied = set(await self.get_applied_migrations())
            pending_migrations = [
                m for m in self.migrations
                if m.version not in applied
            ]

            if target_version:
                pending_migrations = [
                    m for m in pending_migrations
                    if m.version <= target_version
                ]

            if not pending_migrations:
                logger.info("No pending migrations to apply")
                return True

            logger.info(f"Applying {len(pending_migrations)} pending migrations")

            for migration in pending_migrations:
                if not await self.apply_migration(migration):
                    return False

            logger.info("All migrations applied successfully")
            return True

        except Exception as e:
            logger.error(f"Migration failed: {e}")
            return False

    async def migrate_down(self, target_version: str) -> bool:
        """Roll back applied migrations newer than ``target_version``.

        Migrations are rolled back newest-first. Only registered migrations
        are considered.

        Returns:
            True if all selected migrations were rolled back (or none needed
            it). Returns False, and stops, at the first failure.
        """
        try:
            applied = set(await self.get_applied_migrations())

            migrations_to_rollback = [
                m for m in reversed(self.migrations)
                if m.version in applied and m.version > target_version
            ]

            if not migrations_to_rollback:
                logger.info("No migrations to rollback")
                return True

            logger.info(f"Rolling back {len(migrations_to_rollback)} migrations")

            for migration in migrations_to_rollback:
                if not await self.rollback_migration(migration):
                    return False

            logger.info("All migrations rolled back successfully")
            return True

        except Exception as e:
            logger.error(f"Migration rollback failed: {e}")
            return False

    async def get_migration_status(self) -> Dict[str, Any]:
        """Summarize registered vs. applied migrations.

        ``pending_migrations`` counts registered migrations whose version is
        not recorded as applied. Versions present in the database but not
        registered here are ignored.

        Returns:
            A status dict, or ``{}`` if building it raised (the error is
            logged).
        """
        try:
            applied_migrations = await self.get_applied_migrations()
            applied = set(applied_migrations)

            pending_count = sum(
                1 for m in self.migrations if m.version not in applied
            )

            status = {
                'total_migrations': len(self.migrations),
                'applied_migrations': len(applied_migrations),
                'pending_migrations': pending_count,
                # Rows come back ORDER BY version, so the last is the highest.
                'current_version': applied_migrations[-1] if applied_migrations else None,
                'latest_version': self.migrations[-1].version if self.migrations else None,
                'migrations': [
                    {
                        'version': migration.version,
                        'description': migration.description,
                        'applied': migration.version in applied,
                    }
                    for migration in self.migrations
                ],
            }

            return status

        except Exception as e:
            logger.error(f"Failed to get migration status: {e}")
            return {}


# Example migrations
class InitialSchemaMigration(Migration):
    """Initial schema creation migration."""

    def __init__(self):
        super().__init__("001", "Create initial schema and tables")

    async def up(self) -> None:
        """Create the initial schema using DatabaseSchema's creation queries."""
        from .schema import DatabaseSchema

        queries = DatabaseSchema.get_all_creation_queries()
        for query in queries:
            await db_pool.execute_command(query)

    async def down(self) -> None:
        """Drop the initial schema's tables (CASCADE)."""
        # Drop tables in reverse order
        tables = [
            'system_metrics',
            'exchange_status',
            'ohlcv_data',
            'heatmap_data',
            'trade_events',
            'order_book_snapshots'
        ]

        for table in tables:
            await db_pool.execute_command(f"DROP TABLE IF EXISTS market_data.{table} CASCADE")


class AddIndexesMigration(Migration):
    """Add performance indexes migration."""

    def __init__(self):
        super().__init__("002", "Add performance indexes")

    async def up(self) -> None:
        """Create indexes using DatabaseSchema's index creation queries."""
        from .schema import DatabaseSchema

        queries = DatabaseSchema.get_index_creation_queries()
        for query in queries:
            await db_pool.execute_command(query)

    async def down(self) -> None:
        """Drop the indexes listed below.

        NOTE(review): this list must be kept in sync with
        DatabaseSchema.get_index_creation_queries().
        """
        indexes = [
            'idx_order_book_symbol_exchange',
            'idx_order_book_timestamp',
            'idx_trade_events_symbol_exchange',
            'idx_trade_events_timestamp',
            'idx_trade_events_price',
            'idx_heatmap_symbol_bucket',
            'idx_heatmap_timestamp',
            'idx_ohlcv_symbol_timeframe',
            'idx_ohlcv_timestamp',
            'idx_exchange_status_exchange',
            'idx_exchange_status_timestamp',
            'idx_system_metrics_name',
            'idx_system_metrics_timestamp'
        ]

        for index in indexes:
            await db_pool.execute_command(f"DROP INDEX IF EXISTS market_data.{index}")


# Global migration manager
migration_manager = MigrationManager()

# Register default migrations
migration_manager.register_migration(InitialSchemaMigration())
migration_manager.register_migration(AddIndexesMigration())