Files
gogo2/COBY/storage/migrations.py
2025-08-04 17:12:26 +03:00

271 lines
9.3 KiB
Python

"""
Database migration system for schema updates.
"""
from datetime import datetime
from typing import Any, Dict, List, Optional

from ..utils.logging import get_logger
from ..utils.exceptions import StorageError
from .connection_pool import db_pool
# Module-level logger, named after this module.
logger = get_logger(__name__)
class Migration:
    """Base class for database migrations.

    A migration is identified by a ``version`` string and carries a
    human-readable ``description``. Subclasses override ``up`` and ``down``;
    the base implementations only raise ``NotImplementedError``.
    """

    def __init__(self, version: str, description: str):
        """Store the migration's version identifier and description."""
        self.version, self.description = version, description

    async def up(self) -> None:
        """Apply the migration.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError

    async def down(self) -> None:
        """Roll the migration back.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError
class MigrationManager:
    """Manages database schema migrations.

    Registered ``Migration`` objects are kept sorted by their ``version``
    string, and applied versions are recorded in the
    ``market_data.schema_migrations`` table.

    Versions are compared as plain strings everywhere (sorting and target
    filtering), so they must be formatted such that lexicographic order is the
    intended order (for example zero-padded: "001", "002").
    """

    def __init__(self):
        # Registered migrations, kept in ascending version order.
        self.migrations: List["Migration"] = []

    def register_migration(self, migration: "Migration") -> None:
        """Register a migration and keep the list ordered by version."""
        self.migrations.append(migration)
        # Sort by version (string comparison)
        self.migrations.sort(key=lambda m: m.version)

    async def initialize_migration_table(self) -> None:
        """Create the migration tracking table if it does not exist."""
        query = """
            CREATE TABLE IF NOT EXISTS market_data.schema_migrations (
                version VARCHAR(50) PRIMARY KEY,
                description TEXT NOT NULL,
                applied_at TIMESTAMPTZ DEFAULT NOW()
            );
        """
        await db_pool.execute_command(query)
        logger.debug("Migration table initialized")

    async def get_applied_migrations(self) -> List[str]:
        """Return the applied migration versions, ordered by version.

        Best-effort: any failure while querying (for example the tracking
        table not existing yet) yields an empty list instead of raising.
        The failure is logged at debug level so it is not completely silent.
        """
        try:
            query = "SELECT version FROM market_data.schema_migrations ORDER BY version"
            rows = await db_pool.execute_query(query)
            return [row['version'] for row in rows]
        except Exception as e:
            # Table might not exist yet
            logger.debug(f"Could not read applied migrations, assuming none: {e}")
            return []

    async def apply_migration(self, migration: "Migration") -> bool:
        """Apply a single migration and record it in the tracking table.

        Returns:
            True on success, False if anything raised (the error is logged).
        """
        try:
            logger.info(f"Applying migration {migration.version}: {migration.description}")
            async with db_pool.get_transaction() as conn:
                # Apply the migration.
                # NOTE(review): up() is not given `conn`; the migrations in
                # this module issue their statements through
                # db_pool.execute_command. Whether those statements take part
                # in this transaction depends on db_pool - confirm.
                await migration.up()
                # Record the migration
                await conn.execute(
                    "INSERT INTO market_data.schema_migrations (version, description) VALUES ($1, $2)",
                    migration.version,
                    migration.description
                )
            logger.info(f"Migration {migration.version} applied successfully")
            return True
        except Exception as e:
            logger.error(f"Failed to apply migration {migration.version}: {e}")
            return False

    async def rollback_migration(self, migration: "Migration") -> bool:
        """Roll back a single migration and delete its tracking record.

        Returns:
            True on success, False if anything raised (the error is logged).
        """
        try:
            logger.info(f"Rolling back migration {migration.version}: {migration.description}")
            async with db_pool.get_transaction() as conn:
                # Rollback the migration.
                # NOTE(review): as in apply_migration, down() is not given
                # `conn` - confirm it runs inside this transaction.
                await migration.down()
                # Remove the migration record
                await conn.execute(
                    "DELETE FROM market_data.schema_migrations WHERE version = $1",
                    migration.version
                )
            logger.info(f"Migration {migration.version} rolled back successfully")
            return True
        except Exception as e:
            logger.error(f"Failed to rollback migration {migration.version}: {e}")
            return False

    async def migrate_up(self, target_version: Optional[str] = None) -> bool:
        """Apply all pending migrations, optionally only up to a target version.

        Args:
            target_version: If given (and non-empty), only pending migrations
                whose version is <= this value (string comparison) are applied.

        Returns:
            True if there was nothing to do or every migration applied;
            False as soon as one migration fails or an unexpected error occurs.
        """
        try:
            await self.initialize_migration_table()
            # A set makes the membership test below O(1) per migration.
            applied_versions = set(await self.get_applied_migrations())
            pending_migrations = [
                m for m in self.migrations
                if m.version not in applied_versions
            ]
            # Truthiness check kept on purpose: an empty string means "no target".
            if target_version:
                pending_migrations = [
                    m for m in pending_migrations
                    if m.version <= target_version
                ]
            if not pending_migrations:
                logger.info("No pending migrations to apply")
                return True
            logger.info(f"Applying {len(pending_migrations)} pending migrations")
            for migration in pending_migrations:
                # Stop at the first failure; later migrations may depend on it.
                if not await self.apply_migration(migration):
                    return False
            logger.info("All migrations applied successfully")
            return True
        except Exception as e:
            logger.error(f"Migration failed: {e}")
            return False

    async def migrate_down(self, target_version: str) -> bool:
        """Roll back applied migrations whose version is above the target.

        Migrations are rolled back in descending version order. The migration
        at ``target_version`` itself is left applied.

        Returns:
            True if there was nothing to do or every rollback succeeded;
            False as soon as one rollback fails or an unexpected error occurs.
        """
        try:
            applied_versions = set(await self.get_applied_migrations())
            migrations_to_rollback = [
                m for m in reversed(self.migrations)
                if m.version in applied_versions and m.version > target_version
            ]
            if not migrations_to_rollback:
                logger.info("No migrations to rollback")
                return True
            logger.info(f"Rolling back {len(migrations_to_rollback)} migrations")
            for migration in migrations_to_rollback:
                if not await self.rollback_migration(migration):
                    return False
            logger.info("All migrations rolled back successfully")
            return True
        except Exception as e:
            logger.error(f"Migration rollback failed: {e}")
            return False

    async def get_migration_status(self) -> Dict[str, Any]:
        """Summarize registered versus applied migrations.

        Returns:
            A dict with the keys ``total_migrations``, ``applied_migrations``
            (number of versions recorded in the tracking table),
            ``pending_migrations`` (registered migrations not yet applied),
            ``current_version``, ``latest_version`` and ``migrations`` (one
            entry per registered migration). An empty dict is returned if
            building the status raised.
        """
        try:
            applied_migrations = await self.get_applied_migrations()
            applied_versions = set(applied_migrations)
            migrations = [
                {
                    'version': migration.version,
                    'description': migration.description,
                    'applied': migration.version in applied_versions,
                }
                for migration in self.migrations
            ]
            # Count registered-but-unapplied migrations directly. Subtracting
            # the two lengths is wrong (and can go negative) when the tracking
            # table contains versions that are not registered here.
            pending_count = sum(1 for entry in migrations if not entry['applied'])
            return {
                'total_migrations': len(self.migrations),
                'applied_migrations': len(applied_migrations),
                'pending_migrations': pending_count,
                'current_version': applied_migrations[-1] if applied_migrations else None,
                'latest_version': self.migrations[-1].version if self.migrations else None,
                'migrations': migrations,
            }
        except Exception as e:
            logger.error(f"Failed to get migration status: {e}")
            return {}
# Example migrations
class InitialSchemaMigration(Migration):
    """Migration "001": creates the initial schema and tables."""

    def __init__(self):
        super().__init__("001", "Create initial schema and tables")

    async def up(self) -> None:
        """Run every creation query supplied by ``DatabaseSchema``, in order."""
        # Imported at call time rather than at module import.
        from .schema import DatabaseSchema

        for statement in DatabaseSchema.get_all_creation_queries():
            await db_pool.execute_command(statement)

    async def down(self) -> None:
        """Drop the tables of the initial schema, one statement per table."""
        # Drop tables in reverse order
        table_names = (
            'system_metrics',
            'exchange_status',
            'ohlcv_data',
            'heatmap_data',
            'trade_events',
            'order_book_snapshots',
        )
        for table_name in table_names:
            await db_pool.execute_command(
                f"DROP TABLE IF EXISTS market_data.{table_name} CASCADE"
            )
class AddIndexesMigration(Migration):
    """Migration "002": adds the performance indexes."""

    def __init__(self):
        super().__init__("002", "Add performance indexes")

    async def up(self) -> None:
        """Run every index creation query supplied by ``DatabaseSchema``."""
        # Imported at call time rather than at module import.
        from .schema import DatabaseSchema

        for statement in DatabaseSchema.get_index_creation_queries():
            await db_pool.execute_command(statement)

    async def down(self) -> None:
        """Drop each named index in the ``market_data`` schema if it exists."""
        index_names = (
            'idx_order_book_symbol_exchange',
            'idx_order_book_timestamp',
            'idx_trade_events_symbol_exchange',
            'idx_trade_events_timestamp',
            'idx_trade_events_price',
            'idx_heatmap_symbol_bucket',
            'idx_heatmap_timestamp',
            'idx_ohlcv_symbol_timeframe',
            'idx_ohlcv_timestamp',
            'idx_exchange_status_exchange',
            'idx_exchange_status_timestamp',
            'idx_system_metrics_name',
            'idx_system_metrics_timestamp',
        )
        for index_name in index_names:
            await db_pool.execute_command(
                f"DROP INDEX IF EXISTS market_data.{index_name}"
            )
# Global migration manager: a single module-level MigrationManager instance.
migration_manager = MigrationManager()
# Register default migrations when this module is imported. The manager
# re-sorts by version on every registration, so call order here does not
# determine the order in which migrations are applied.
migration_manager.register_migration(InitialSchemaMigration())
migration_manager.register_migration(AddIndexesMigration())