sqlite for checkpoints, cleanup
utils/database_manager.py
@@ -0,0 +1,408 @@
"""
Database Manager for Trading System

Manages SQLite database for:
1. Inference records logging
2. Checkpoint metadata storage
3. Model performance tracking
"""

import sqlite3
import json
import logging
import os
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
from contextlib import contextmanager
from dataclasses import dataclass

logger = logging.getLogger(__name__)

@dataclass
class InferenceRecord:
    """Structure for inference logging"""
    model_name: str
    timestamp: datetime
    symbol: str
    action: str
    confidence: float
    probabilities: Dict[str, float]
    input_features_hash: str  # Hash of input features for deduplication
    processing_time_ms: float
    memory_usage_mb: float
    checkpoint_id: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None

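# How callers produce `input_features_hash` is not defined in this module;
# a minimal stdlib sketch (assumed, illustrative only):
#
#     import hashlib
#     def hash_features(features: Dict[str, float]) -> str:
#         payload = json.dumps(features, sort_keys=True)
#         return hashlib.sha256(payload.encode("utf-8")).hexdigest()
#
# Serializing with sort_keys=True makes the hash independent of dict
# ordering, so identical feature sets map to the same hash for deduplication.
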
@dataclass
class CheckpointMetadata:
    """Structure for checkpoint metadata"""
    checkpoint_id: str
    model_name: str
    model_type: str
    timestamp: datetime
    performance_metrics: Dict[str, float]
    training_metadata: Dict[str, Any]
    file_path: str
    file_size_mb: float
    is_active: bool = False  # Currently loaded checkpoint

class DatabaseManager:
    """Manages SQLite database for trading system logging and metadata"""

    def __init__(self, db_path: str = "data/trading_system.db"):
        self.db_path = db_path
        self._ensure_db_directory()
        self._initialize_database()

    def _ensure_db_directory(self):
        """Ensure database directory exists"""
        # os.path.dirname() returns "" for a bare filename, and os.makedirs("")
        # raises, so only create a directory when one is actually given
        db_dir = os.path.dirname(self.db_path)
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)

    def _initialize_database(self):
        """Initialize database tables"""
        with self._get_connection() as conn:
            # Inference records table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS inference_records (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    model_name TEXT NOT NULL,
                    timestamp TEXT NOT NULL,
                    symbol TEXT NOT NULL,
                    action TEXT NOT NULL,
                    confidence REAL NOT NULL,
                    probabilities TEXT NOT NULL,  -- JSON
                    input_features_hash TEXT NOT NULL,
                    processing_time_ms REAL NOT NULL,
                    memory_usage_mb REAL NOT NULL,
                    checkpoint_id TEXT,
                    metadata TEXT,  -- JSON
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP
                )
            """)

            # Checkpoint metadata table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS checkpoint_metadata (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    checkpoint_id TEXT UNIQUE NOT NULL,
                    model_name TEXT NOT NULL,
                    model_type TEXT NOT NULL,
                    timestamp TEXT NOT NULL,
                    performance_metrics TEXT NOT NULL,  -- JSON
                    training_metadata TEXT NOT NULL,  -- JSON
                    file_path TEXT NOT NULL,
                    file_size_mb REAL NOT NULL,
                    is_active BOOLEAN DEFAULT FALSE,
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP
                )
            """)

            # Model performance tracking table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS model_performance (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    model_name TEXT NOT NULL,
                    date TEXT NOT NULL,
                    total_predictions INTEGER DEFAULT 0,
                    correct_predictions INTEGER DEFAULT 0,
                    accuracy REAL DEFAULT 0.0,
                    avg_confidence REAL DEFAULT 0.0,
                    avg_processing_time_ms REAL DEFAULT 0.0,
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP,
                    UNIQUE(model_name, date)
                )
            """)

            # Create indexes for better performance
            conn.execute("CREATE INDEX IF NOT EXISTS idx_inference_model_timestamp ON inference_records(model_name, timestamp)")
            conn.execute("CREATE INDEX IF NOT EXISTS idx_inference_symbol ON inference_records(symbol)")
            conn.execute("CREATE INDEX IF NOT EXISTS idx_checkpoint_model ON checkpoint_metadata(model_name)")
            conn.execute("CREATE INDEX IF NOT EXISTS idx_checkpoint_active ON checkpoint_metadata(is_active)")

        logger.info(f"Database initialized at {self.db_path}")

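    # Note on timestamps: they are stored as ISO-8601 TEXT via
    # datetime.isoformat(), which sorts lexicographically in chronological
    # order. That is why the plain string comparisons in ORDER BY timestamp
    # and in cleanup_old_records() below behave correctly without parsing.
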
    @contextmanager
    def _get_connection(self):
        """Get database connection with proper error handling"""
        conn = None
        try:
            conn = sqlite3.connect(self.db_path, timeout=30.0)
            conn.row_factory = sqlite3.Row  # Enable dict-like access
            yield conn
        except Exception as e:
            if conn:
                conn.rollback()
            logger.error(f"Database error: {e}")
            raise
        finally:
            if conn:
                conn.close()

    def log_inference(self, record: InferenceRecord) -> bool:
        """Log an inference record"""
        try:
            with self._get_connection() as conn:
                conn.execute("""
                    INSERT INTO inference_records (
                        model_name, timestamp, symbol, action, confidence,
                        probabilities, input_features_hash, processing_time_ms,
                        memory_usage_mb, checkpoint_id, metadata
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """, (
                    record.model_name,
                    record.timestamp.isoformat(),
                    record.symbol,
                    record.action,
                    record.confidence,
                    json.dumps(record.probabilities),
                    record.input_features_hash,
                    record.processing_time_ms,
                    record.memory_usage_mb,
                    record.checkpoint_id,
                    json.dumps(record.metadata) if record.metadata else None
                ))
                conn.commit()
                return True
        except Exception as e:
            logger.error(f"Failed to log inference record: {e}")
            return False

    def save_checkpoint_metadata(self, metadata: CheckpointMetadata) -> bool:
        """Save checkpoint metadata"""
        try:
            with self._get_connection() as conn:
                # First, set all other checkpoints for this model as inactive
                conn.execute("""
                    UPDATE checkpoint_metadata
                    SET is_active = FALSE
                    WHERE model_name = ?
                """, (metadata.model_name,))

                # Insert or replace the new checkpoint metadata
                conn.execute("""
                    INSERT OR REPLACE INTO checkpoint_metadata (
                        checkpoint_id, model_name, model_type, timestamp,
                        performance_metrics, training_metadata, file_path,
                        file_size_mb, is_active
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                """, (
                    metadata.checkpoint_id,
                    metadata.model_name,
                    metadata.model_type,
                    metadata.timestamp.isoformat(),
                    json.dumps(metadata.performance_metrics),
                    json.dumps(metadata.training_metadata),
                    metadata.file_path,
                    metadata.file_size_mb,
                    metadata.is_active
                ))
                conn.commit()
                return True
        except Exception as e:
            logger.error(f"Failed to save checkpoint metadata: {e}")
            return False

    @staticmethod
    def _checkpoint_from_row(row: sqlite3.Row) -> CheckpointMetadata:
        """Build a CheckpointMetadata instance from a database row"""
        return CheckpointMetadata(
            checkpoint_id=row['checkpoint_id'],
            model_name=row['model_name'],
            model_type=row['model_type'],
            timestamp=datetime.fromisoformat(row['timestamp']),
            performance_metrics=json.loads(row['performance_metrics']),
            training_metadata=json.loads(row['training_metadata']),
            file_path=row['file_path'],
            file_size_mb=row['file_size_mb'],
            is_active=bool(row['is_active'])
        )

    def get_checkpoint_metadata(self, model_name: str, checkpoint_id: Optional[str] = None) -> Optional[CheckpointMetadata]:
        """Get checkpoint metadata without loading the actual model"""
        try:
            with self._get_connection() as conn:
                if checkpoint_id:
                    # Get specific checkpoint
                    cursor = conn.execute("""
                        SELECT * FROM checkpoint_metadata
                        WHERE model_name = ? AND checkpoint_id = ?
                    """, (model_name, checkpoint_id))
                else:
                    # Get active checkpoint
                    cursor = conn.execute("""
                        SELECT * FROM checkpoint_metadata
                        WHERE model_name = ? AND is_active = TRUE
                        ORDER BY timestamp DESC LIMIT 1
                    """, (model_name,))

                row = cursor.fetchone()
                return self._checkpoint_from_row(row) if row else None
        except Exception as e:
            logger.error(f"Failed to get checkpoint metadata: {e}")
            return None

    def get_best_checkpoint_metadata(self, model_name: str, metric_name: str = "accuracy") -> Optional[CheckpointMetadata]:
        """Get best checkpoint metadata based on performance metric"""
        try:
            with self._get_connection() as conn:
                cursor = conn.execute("""
                    SELECT * FROM checkpoint_metadata
                    WHERE model_name = ?
                    ORDER BY json_extract(performance_metrics, '$.' || ?) DESC
                    LIMIT 1
                """, (model_name, metric_name))

                row = cursor.fetchone()
                return self._checkpoint_from_row(row) if row else None
        except Exception as e:
            logger.error(f"Failed to get best checkpoint metadata: {e}")
            return None

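    # Illustrative note on the json_extract() ordering above: with
    # performance_metrics stored as, say, '{"accuracy": 0.61, "sharpe": 1.2}',
    # ORDER BY json_extract(performance_metrics, '$.accuracy') DESC ranks rows
    # by that value. Rows missing the metric yield NULL, which SQLite treats
    # as smaller than any value, so they sort last under DESC and are never
    # preferred over a checkpoint with a real score.
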
    def list_checkpoints(self, model_name: str) -> List[CheckpointMetadata]:
        """List all checkpoints for a model"""
        try:
            with self._get_connection() as conn:
                cursor = conn.execute("""
                    SELECT * FROM checkpoint_metadata
                    WHERE model_name = ?
                    ORDER BY timestamp DESC
                """, (model_name,))

                return [self._checkpoint_from_row(row) for row in cursor.fetchall()]
        except Exception as e:
            logger.error(f"Failed to list checkpoints: {e}")
            return []

    def set_active_checkpoint(self, model_name: str, checkpoint_id: str) -> bool:
        """Set a checkpoint as active for a model"""
        try:
            with self._get_connection() as conn:
                # First, set all checkpoints for this model as inactive
                conn.execute("""
                    UPDATE checkpoint_metadata
                    SET is_active = FALSE
                    WHERE model_name = ?
                """, (model_name,))

                # Set the specified checkpoint as active
                cursor = conn.execute("""
                    UPDATE checkpoint_metadata
                    SET is_active = TRUE
                    WHERE model_name = ? AND checkpoint_id = ?
                """, (model_name, checkpoint_id))

                conn.commit()
                return cursor.rowcount > 0
        except Exception as e:
            logger.error(f"Failed to set active checkpoint: {e}")
            return False

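    # Hedged usage sketch for rolling back to an earlier checkpoint (the
    # model name and checkpoint id below are illustrative placeholders):
    #
    #     db = get_database_manager()
    #     for ckpt in db.list_checkpoints("dqn_agent"):
    #         print(ckpt.checkpoint_id, ckpt.performance_metrics)
    #     db.set_active_checkpoint("dqn_agent", "dqn_agent_20240101_120000")
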
    def get_recent_inferences(self, model_name: str, limit: int = 100) -> List[InferenceRecord]:
        """Get recent inference records for a model"""
        try:
            with self._get_connection() as conn:
                cursor = conn.execute("""
                    SELECT * FROM inference_records
                    WHERE model_name = ?
                    ORDER BY timestamp DESC
                    LIMIT ?
                """, (model_name, limit))

                records = []
                for row in cursor.fetchall():
                    records.append(InferenceRecord(
                        model_name=row['model_name'],
                        timestamp=datetime.fromisoformat(row['timestamp']),
                        symbol=row['symbol'],
                        action=row['action'],
                        confidence=row['confidence'],
                        probabilities=json.loads(row['probabilities']),
                        input_features_hash=row['input_features_hash'],
                        processing_time_ms=row['processing_time_ms'],
                        memory_usage_mb=row['memory_usage_mb'],
                        checkpoint_id=row['checkpoint_id'],
                        metadata=json.loads(row['metadata']) if row['metadata'] else None
                    ))
                return records
        except Exception as e:
            logger.error(f"Failed to get recent inferences: {e}")
            return []

    def update_model_performance(self, model_name: str, date: str,
                                 total_predictions: int, correct_predictions: int,
                                 avg_confidence: float, avg_processing_time: float) -> bool:
        """Update daily model performance statistics"""
        try:
            accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0

            with self._get_connection() as conn:
                conn.execute("""
                    INSERT OR REPLACE INTO model_performance (
                        model_name, date, total_predictions, correct_predictions,
                        accuracy, avg_confidence, avg_processing_time_ms
                    ) VALUES (?, ?, ?, ?, ?, ?, ?)
                """, (
                    model_name, date, total_predictions, correct_predictions,
                    accuracy, avg_confidence, avg_processing_time
                ))
                conn.commit()
                return True
        except Exception as e:
            logger.error(f"Failed to update model performance: {e}")
            return False

    def cleanup_old_records(self, days_to_keep: int = 30) -> bool:
        """Clean up old inference records"""
        try:
            cutoff_date = datetime.now() - timedelta(days=days_to_keep)

            with self._get_connection() as conn:
                cursor = conn.execute("""
                    DELETE FROM inference_records
                    WHERE timestamp < ?
                """, (cutoff_date.isoformat(),))

                deleted_count = cursor.rowcount
                conn.commit()

                if deleted_count > 0:
                    logger.info(f"Cleaned up {deleted_count} old inference records")

                return True
        except Exception as e:
            logger.error(f"Failed to cleanup old records: {e}")
            return False

# Global database manager instance
_db_manager_instance: Optional[DatabaseManager] = None


def get_database_manager(db_path: str = "data/trading_system.db") -> DatabaseManager:
    """Get the global database manager instance"""
    # Note: db_path is only honored on the first call; later calls return
    # the existing instance regardless of the path argument.
    global _db_manager_instance

    if _db_manager_instance is None:
        _db_manager_instance = DatabaseManager(db_path)

    return _db_manager_instance
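
# A minimal, hedged smoke-test sketch of how this module fits together.
# It uses only the APIs defined above; the model name, metrics, and file
# path are illustrative placeholders, not values from the real system.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    db = get_database_manager("data/trading_system.db")

    # Register a (hypothetical) checkpoint and mark it active
    db.save_checkpoint_metadata(CheckpointMetadata(
        checkpoint_id="demo_model_20240101_120000",
        model_name="demo_model",
        model_type="dqn",
        timestamp=datetime.now(),
        performance_metrics={"accuracy": 0.61},
        training_metadata={"epochs": 10},
        file_path="checkpoints/demo_model.pt",
        file_size_mb=42.0,
        is_active=True,
    ))

    # Log one inference against that checkpoint
    db.log_inference(InferenceRecord(
        model_name="demo_model",
        timestamp=datetime.now(),
        symbol="ETH/USDT",
        action="BUY",
        confidence=0.72,
        probabilities={"BUY": 0.72, "SELL": 0.18, "HOLD": 0.10},
        input_features_hash="deadbeef",
        processing_time_ms=3.1,
        memory_usage_mb=512.0,
        checkpoint_id="demo_model_20240101_120000",
    ))

    active = db.get_checkpoint_metadata("demo_model")
    print("active checkpoint:", active.checkpoint_id if active else None)
    print("recent inferences:", len(db.get_recent_inferences("demo_model")))
    db.cleanup_old_records(days_to_keep=30)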