271 lines
9.3 KiB
Python
271 lines
9.3 KiB
Python
"""
Database migration system for schema updates.
"""
|
|
|
|
from typing import List, Dict, Any
|
|
from datetime import datetime
|
|
from ..utils.logging import get_logger
|
|
from ..utils.exceptions import StorageError
|
|
from .connection_pool import db_pool
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
class Migration:
    """Base class for database migrations.

    Subclasses override :meth:`up` and :meth:`down`; the base implementations
    only raise ``NotImplementedError``.

    Args:
        version: Identifier of the migration. ``MigrationManager`` sorts and
            compares these as plain strings and stores them in the tracking
            table.
        description: Human-readable summary, stored next to the version.
    """

    def __init__(self, version: str, description: str):
        self.version = version
        self.description = description

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}"
            f"(version={self.version!r}, description={self.description!r})"
        )

    async def up(self) -> None:
        """Apply the migration.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        # Carry a message: callers log ``str(exc)``, which is empty for a
        # bare NotImplementedError and makes the failure impossible to read.
        raise NotImplementedError(
            f"{type(self).__name__} must implement up()"
        )

    async def down(self) -> None:
        """Roll the migration back.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement down()"
        )
|
|
|
|
|
|
class MigrationManager:
    """Registers migrations and applies / rolls them back.

    Applied versions are recorded in ``market_data.schema_migrations``.

    Versions are compared as plain strings everywhere in this class (the
    registry sort, the ``target_version`` filters and the SQL ``ORDER BY``),
    so they need to be fixed-width / zero-padded (``"001"``, ``"002"``, ...)
    to order as intended.
    """

    def __init__(self):
        # Registered migrations, kept sorted by version string.
        self.migrations: List[Migration] = []

    def register_migration(self, migration: "Migration") -> None:
        """Add ``migration`` to the registry and re-sort it by version."""
        self.migrations.append(migration)
        self.migrations.sort(key=lambda m: m.version)

    async def initialize_migration_table(self) -> None:
        """Create the migration tracking table if it does not exist yet."""
        # NOTE(review): the table lives in the ``market_data`` schema, and
        # ``migrate_up`` calls this before any migration has run. Confirm the
        # schema is guaranteed to exist at that point.
        query = """
            CREATE TABLE IF NOT EXISTS market_data.schema_migrations (
                version VARCHAR(50) PRIMARY KEY,
                description TEXT NOT NULL,
                applied_at TIMESTAMPTZ DEFAULT NOW()
            );
        """

        await db_pool.execute_command(query)
        logger.debug("Migration table initialized")

    async def get_applied_migrations(self) -> List[str]:
        """Return the recorded migration versions, ordered by version.

        Best-effort: any failure (for example the tracking table not
        existing yet) yields an empty list instead of raising. The failure
        is logged so that a genuine database problem is not mistaken for
        "nothing applied" without a trace.
        """
        try:
            query = "SELECT version FROM market_data.schema_migrations ORDER BY version"
            rows = await db_pool.execute_query(query)
            return [row['version'] for row in rows]
        except Exception as e:
            # Table might not exist yet; keep the lenient contract but leave
            # evidence of what actually went wrong.
            logger.warning(
                f"Could not read applied migrations, assuming none applied: {e}"
            )
            return []

    async def apply_migration(self, migration: "Migration") -> bool:
        """Run ``migration.up()`` and record its version.

        Returns:
            True when both the migration and the bookkeeping insert succeed,
            False when either raises (the error is logged, not propagated).
        """
        try:
            logger.info(f"Applying migration {migration.version}: {migration.description}")

            async with db_pool.get_transaction() as conn:
                # NOTE(review): ``up()`` is not handed ``conn``; whether its
                # statements take part in this transaction depends on
                # ``db_pool`` -- confirm before relying on atomicity.
                await migration.up()

                # Record the migration on the transaction's connection.
                await conn.execute(
                    "INSERT INTO market_data.schema_migrations (version, description) VALUES ($1, $2)",
                    migration.version,
                    migration.description
                )

            logger.info(f"Migration {migration.version} applied successfully")
            return True

        except Exception as e:
            logger.error(f"Failed to apply migration {migration.version}: {e}")
            return False

    async def rollback_migration(self, migration: "Migration") -> bool:
        """Run ``migration.down()`` and delete its tracking record.

        Returns:
            True on success, False when anything raises (the error is
            logged, not propagated).
        """
        try:
            logger.info(f"Rolling back migration {migration.version}: {migration.description}")

            async with db_pool.get_transaction() as conn:
                # NOTE(review): as in ``apply_migration``, ``down()`` does
                # not receive ``conn`` -- confirm transactional behaviour.
                await migration.down()

                # Remove the migration record.
                await conn.execute(
                    "DELETE FROM market_data.schema_migrations WHERE version = $1",
                    migration.version
                )

            logger.info(f"Migration {migration.version} rolled back successfully")
            return True

        except Exception as e:
            logger.error(f"Failed to rollback migration {migration.version}: {e}")
            return False

    async def migrate_up(self, target_version: str = None) -> bool:
        """Apply every pending migration, optionally capped at a version.

        Args:
            target_version: When given (and non-empty), only pending
                migrations whose version string is ``<=`` this value run.

        Returns:
            True when nothing was pending or everything applied; False as
            soon as one migration fails (later ones are not attempted) or an
            unexpected error occurs.
        """
        try:
            await self.initialize_migration_table()
            applied = set(await self.get_applied_migrations())

            pending_migrations = [
                m for m in self.migrations
                if m.version not in applied
            ]

            if target_version:
                pending_migrations = [
                    m for m in pending_migrations
                    if m.version <= target_version
                ]

            if not pending_migrations:
                logger.info("No pending migrations to apply")
                return True

            logger.info(f"Applying {len(pending_migrations)} pending migrations")

            for migration in pending_migrations:
                if not await self.apply_migration(migration):
                    return False

            logger.info("All migrations applied successfully")
            return True

        except Exception as e:
            logger.error(f"Migration failed: {e}")
            return False

    async def migrate_down(self, target_version: str) -> bool:
        """Roll back applied migrations newer than ``target_version``.

        Migrations are rolled back from the highest version downwards; the
        one equal to ``target_version`` is kept.

        Returns:
            True when nothing needed rolling back or every rollback
            succeeded; False as soon as one fails or an unexpected error
            occurs.
        """
        try:
            applied = set(await self.get_applied_migrations())

            migrations_to_rollback = [
                m for m in reversed(self.migrations)
                if m.version in applied and m.version > target_version
            ]

            if not migrations_to_rollback:
                logger.info("No migrations to rollback")
                return True

            logger.info(f"Rolling back {len(migrations_to_rollback)} migrations")

            for migration in migrations_to_rollback:
                if not await self.rollback_migration(migration):
                    return False

            logger.info("All migrations rolled back successfully")
            return True

        except Exception as e:
            logger.error(f"Migration rollback failed: {e}")
            return False

    async def get_migration_status(self) -> Dict[str, Any]:
        """Summarise registered versus applied migrations.

        Returns:
            A dict with ``total_migrations`` (registered count),
            ``applied_migrations`` (rows in the tracking table),
            ``pending_migrations`` (registered but not applied),
            ``current_version``, ``latest_version`` and a per-migration
            ``migrations`` list. An empty dict is returned on error.
        """
        try:
            applied_migrations = await self.get_applied_migrations()
            applied = set(applied_migrations)

            entries = [
                {
                    'version': migration.version,
                    'description': migration.description,
                    'applied': migration.version in applied,
                }
                for migration in self.migrations
            ]

            # Count what is actually pending instead of subtracting lengths:
            # the tracking table may hold versions that are not registered
            # here, which would make the difference too small or negative.
            pending = sum(1 for entry in entries if not entry['applied'])

            return {
                'total_migrations': len(self.migrations),
                'applied_migrations': len(applied_migrations),
                'pending_migrations': pending,
                'current_version': applied_migrations[-1] if applied_migrations else None,
                'latest_version': self.migrations[-1].version if self.migrations else None,
                'migrations': entries,
            }

        except Exception as e:
            logger.error(f"Failed to get migration status: {e}")
            return {}
|
|
|
|
|
|
# Example migrations
|
|
class InitialSchemaMigration(Migration):
    """Migration "001": runs the schema creation queries; drops the tables on rollback."""

    def __init__(self):
        super().__init__("001", "Create initial schema and tables")

    async def up(self) -> None:
        """Execute, in order, every query from ``DatabaseSchema.get_all_creation_queries()``."""
        # Imported inside the method, exactly where the original imports it.
        from .schema import DatabaseSchema

        for statement in DatabaseSchema.get_all_creation_queries():
            await db_pool.execute_command(statement)

    async def down(self) -> None:
        """Drop each ``market_data`` table listed below with ``CASCADE``."""
        # Listed in the order they are dropped (described by the original
        # author as the reverse of creation order).
        drop_order = (
            'system_metrics',
            'exchange_status',
            'ohlcv_data',
            'heatmap_data',
            'trade_events',
            'order_book_snapshots',
        )

        for table_name in drop_order:
            await db_pool.execute_command(f"DROP TABLE IF EXISTS market_data.{table_name} CASCADE")
|
|
|
|
|
|
class AddIndexesMigration(Migration):
    """Migration "002": creates the performance indexes; drops them on rollback."""

    def __init__(self):
        super().__init__("002", "Add performance indexes")

    async def up(self) -> None:
        """Execute, in order, every query from ``DatabaseSchema.get_index_creation_queries()``."""
        # Imported inside the method, exactly where the original imports it.
        from .schema import DatabaseSchema

        for statement in DatabaseSchema.get_index_creation_queries():
            await db_pool.execute_command(statement)

    async def down(self) -> None:
        """Drop each named index from the ``market_data`` schema if it exists."""
        index_names = (
            'idx_order_book_symbol_exchange',
            'idx_order_book_timestamp',
            'idx_trade_events_symbol_exchange',
            'idx_trade_events_timestamp',
            'idx_trade_events_price',
            'idx_heatmap_symbol_bucket',
            'idx_heatmap_timestamp',
            'idx_ohlcv_symbol_timeframe',
            'idx_ohlcv_timestamp',
            'idx_exchange_status_exchange',
            'idx_exchange_status_timestamp',
            'idx_system_metrics_name',
            'idx_system_metrics_timestamp',
        )

        for index_name in index_names:
            await db_pool.execute_command(f"DROP INDEX IF EXISTS market_data.{index_name}")
|
|
|
|
|
|
# Module-level manager instance shared by importers of this module.
migration_manager = MigrationManager()

# Register the built-in migrations; the manager re-sorts by version on
# every registration, so the order given here does not matter.
for _default_migration in (InitialSchemaMigration(), AddIndexesMigration()):
    migration_manager.register_migration(_default_migration)
del _default_migration  # keep the loop variable out of the module namespace