# gogo2/utils/database_manager.py
"""
Database Manager for Trading System
Manages SQLite database for:
1. Inference records logging
2. Checkpoint metadata storage
3. Model performance tracking
"""
import sqlite3
import json
import logging
import os
import numpy as np
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Tuple
from contextlib import contextmanager
from dataclasses import dataclass, asdict

logger = logging.getLogger(__name__)


@dataclass
class InferenceRecord:
    """Structure for inference logging"""
    model_name: str
    timestamp: datetime
    symbol: str
    action: str
    confidence: float
    probabilities: Dict[str, float]
    input_features_hash: str  # Hash of input features for deduplication
    processing_time_ms: float
    memory_usage_mb: float
    input_features: Optional[np.ndarray] = None  # Full input features for training
    checkpoint_id: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None


@dataclass
class CheckpointMetadata:
    """Structure for checkpoint metadata"""
    checkpoint_id: str
    model_name: str
    model_type: str
    timestamp: datetime
    performance_metrics: Dict[str, float]
    training_metadata: Dict[str, Any]
    file_path: str
    file_size_mb: float
    is_active: bool = False  # Currently loaded checkpoint
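

# Illustrative sketch (not part of the original module): one way the
# `input_features_hash` field could be derived before building an
# InferenceRecord. The exact hashing scheme used by upstream callers is not
# shown in this file; SHA-256 over the raw float32 bytes is an assumption.
def _example_feature_hash(features: np.ndarray) -> str:
    import hashlib
    return hashlib.sha256(np.asarray(features, dtype=np.float32).tobytes()).hexdigest()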


class DatabaseManager:
    """Manages SQLite database for trading system logging and metadata"""

    def __init__(self, db_path: str = "data/trading_system.db"):
        self.db_path = db_path
        self._ensure_db_directory()
        self._initialize_database()

    def _ensure_db_directory(self):
        """Ensure database directory exists"""
        db_dir = os.path.dirname(self.db_path)
        if db_dir:  # db_path may be a bare filename with no directory part
            os.makedirs(db_dir, exist_ok=True)

    def _initialize_database(self):
        """Initialize database tables"""
        with self._get_connection() as conn:
            # Inference records table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS inference_records (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    model_name TEXT NOT NULL,
                    timestamp TEXT NOT NULL,
                    symbol TEXT NOT NULL,
                    action TEXT NOT NULL,
                    confidence REAL NOT NULL,
                    probabilities TEXT NOT NULL,  -- JSON
                    input_features_hash TEXT NOT NULL,
                    input_features_blob BLOB,  -- Store full input features for training
                    processing_time_ms REAL NOT NULL,
                    memory_usage_mb REAL NOT NULL,
                    checkpoint_id TEXT,
                    metadata TEXT,  -- JSON
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP
                )
            """)

            # Checkpoint metadata table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS checkpoint_metadata (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    checkpoint_id TEXT UNIQUE NOT NULL,
                    model_name TEXT NOT NULL,
                    model_type TEXT NOT NULL,
                    timestamp TEXT NOT NULL,
                    performance_metrics TEXT NOT NULL,  -- JSON
                    training_metadata TEXT NOT NULL,  -- JSON
                    file_path TEXT NOT NULL,
                    file_size_mb REAL NOT NULL,
                    is_active BOOLEAN DEFAULT FALSE,
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP
                )
            """)

            # Model performance tracking table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS model_performance (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    model_name TEXT NOT NULL,
                    date TEXT NOT NULL,
                    total_predictions INTEGER DEFAULT 0,
                    correct_predictions INTEGER DEFAULT 0,
                    accuracy REAL DEFAULT 0.0,
                    avg_confidence REAL DEFAULT 0.0,
                    avg_processing_time_ms REAL DEFAULT 0.0,
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP,
                    UNIQUE(model_name, date)
                )
            """)

            # Create indexes for better performance
            conn.execute("CREATE INDEX IF NOT EXISTS idx_inference_model_timestamp ON inference_records(model_name, timestamp)")
            conn.execute("CREATE INDEX IF NOT EXISTS idx_inference_symbol ON inference_records(symbol)")
            conn.execute("CREATE INDEX IF NOT EXISTS idx_checkpoint_model ON checkpoint_metadata(model_name)")
            conn.execute("CREATE INDEX IF NOT EXISTS idx_checkpoint_active ON checkpoint_metadata(is_active)")

        logger.info(f"Database initialized at {self.db_path}")

        # Run migrations to handle schema updates
        self._run_migrations()

    def _run_migrations(self):
        """Run database migrations to handle schema updates"""
        try:
            with self._get_connection() as conn:
                # Check if input_features_blob column exists
                cursor = conn.execute("PRAGMA table_info(inference_records)")
                columns = [row[1] for row in cursor.fetchall()]

                if 'input_features_blob' not in columns:
                    logger.info("Adding input_features_blob column to inference_records table")
                    conn.execute("ALTER TABLE inference_records ADD COLUMN input_features_blob BLOB")
                    conn.commit()
                    logger.info("Successfully added input_features_blob column")
                else:
                    logger.debug("input_features_blob column already exists")
        except Exception as e:
            logger.error(f"Error running database migrations: {e}")
            # If migration fails, we can still continue without the blob column

    @contextmanager
    def _get_connection(self):
        """Get database connection with proper error handling"""
        conn = None
        try:
            conn = sqlite3.connect(self.db_path, timeout=30.0)
            conn.row_factory = sqlite3.Row  # Enable dict-like access
            yield conn
        except Exception as e:
            if conn:
                conn.rollback()
            logger.error(f"Database error: {e}")
            raise
        finally:
            if conn:
                conn.close()

    def log_inference(self, record: InferenceRecord) -> bool:
        """Log an inference record"""
        try:
            with self._get_connection() as conn:
                # Check if input_features_blob column exists
                cursor = conn.execute("PRAGMA table_info(inference_records)")
                columns = [row[1] for row in cursor.fetchall()]
                has_blob_column = 'input_features_blob' in columns

                # Serialize input features if provided and column exists
                input_features_blob = None
                if record.input_features is not None and has_blob_column:
                    input_features_blob = record.input_features.tobytes()

                if has_blob_column:
                    # Use full query with blob column
                    conn.execute("""
                        INSERT INTO inference_records (
                            model_name, timestamp, symbol, action, confidence,
                            probabilities, input_features_hash, input_features_blob,
                            processing_time_ms, memory_usage_mb, checkpoint_id, metadata
                        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """, (
                        record.model_name,
                        record.timestamp.isoformat(),
                        record.symbol,
                        record.action,
                        record.confidence,
                        json.dumps(record.probabilities),
                        record.input_features_hash,
                        input_features_blob,
                        record.processing_time_ms,
                        record.memory_usage_mb,
                        record.checkpoint_id,
                        json.dumps(record.metadata) if record.metadata else None
                    ))
                else:
                    # Fallback query without blob column
                    logger.warning("input_features_blob column missing, storing without full features")
                    conn.execute("""
                        INSERT INTO inference_records (
                            model_name, timestamp, symbol, action, confidence,
                            probabilities, input_features_hash,
                            processing_time_ms, memory_usage_mb, checkpoint_id, metadata
                        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """, (
                        record.model_name,
                        record.timestamp.isoformat(),
                        record.symbol,
                        record.action,
                        record.confidence,
                        json.dumps(record.probabilities),
                        record.input_features_hash,
                        record.processing_time_ms,
                        record.memory_usage_mb,
                        record.checkpoint_id,
                        json.dumps(record.metadata) if record.metadata else None
                    ))

                conn.commit()
                return True
        except Exception as e:
            logger.error(f"Failed to log inference record: {e}")
            return False

    def save_checkpoint_metadata(self, metadata: CheckpointMetadata) -> bool:
        """Save checkpoint metadata"""
        try:
            with self._get_connection() as conn:
                # First, set all other checkpoints for this model as inactive
                conn.execute("""
                    UPDATE checkpoint_metadata
                    SET is_active = FALSE
                    WHERE model_name = ?
                """, (metadata.model_name,))

                # Insert or replace the new checkpoint metadata
                conn.execute("""
                    INSERT OR REPLACE INTO checkpoint_metadata (
                        checkpoint_id, model_name, model_type, timestamp,
                        performance_metrics, training_metadata, file_path,
                        file_size_mb, is_active
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                """, (
                    metadata.checkpoint_id,
                    metadata.model_name,
                    metadata.model_type,
                    metadata.timestamp.isoformat(),
                    json.dumps(metadata.performance_metrics),
                    json.dumps(metadata.training_metadata),
                    metadata.file_path,
                    metadata.file_size_mb,
                    metadata.is_active
                ))

                conn.commit()
                return True
        except Exception as e:
            logger.error(f"Failed to save checkpoint metadata: {e}")
            return False

    def get_checkpoint_metadata(self, model_name: str, checkpoint_id: Optional[str] = None) -> Optional[CheckpointMetadata]:
        """Get checkpoint metadata without loading the actual model"""
        try:
            with self._get_connection() as conn:
                if checkpoint_id:
                    # Get specific checkpoint
                    cursor = conn.execute("""
                        SELECT * FROM checkpoint_metadata
                        WHERE model_name = ? AND checkpoint_id = ?
                    """, (model_name, checkpoint_id))
                else:
                    # Get active checkpoint
                    cursor = conn.execute("""
                        SELECT * FROM checkpoint_metadata
                        WHERE model_name = ? AND is_active = TRUE
                        ORDER BY timestamp DESC LIMIT 1
                    """, (model_name,))

                row = cursor.fetchone()
                if row:
                    return CheckpointMetadata(
                        checkpoint_id=row['checkpoint_id'],
                        model_name=row['model_name'],
                        model_type=row['model_type'],
                        timestamp=datetime.fromisoformat(row['timestamp']),
                        performance_metrics=json.loads(row['performance_metrics']),
                        training_metadata=json.loads(row['training_metadata']),
                        file_path=row['file_path'],
                        file_size_mb=row['file_size_mb'],
                        is_active=bool(row['is_active'])
                    )
                return None
        except Exception as e:
            logger.error(f"Failed to get checkpoint metadata: {e}")
            return None

    def get_best_checkpoint_metadata(self, model_name: str, metric_name: str = "accuracy") -> Optional[CheckpointMetadata]:
        """Get best checkpoint metadata based on performance metric"""
        try:
            with self._get_connection() as conn:
                cursor = conn.execute("""
                    SELECT * FROM checkpoint_metadata
                    WHERE model_name = ?
                    ORDER BY json_extract(performance_metrics, '$.' || ?) DESC
                    LIMIT 1
                """, (model_name, metric_name))

                row = cursor.fetchone()
                if row:
                    return CheckpointMetadata(
                        checkpoint_id=row['checkpoint_id'],
                        model_name=row['model_name'],
                        model_type=row['model_type'],
                        timestamp=datetime.fromisoformat(row['timestamp']),
                        performance_metrics=json.loads(row['performance_metrics']),
                        training_metadata=json.loads(row['training_metadata']),
                        file_path=row['file_path'],
                        file_size_mb=row['file_size_mb'],
                        is_active=bool(row['is_active'])
                    )
                return None
        except Exception as e:
            logger.error(f"Failed to get best checkpoint metadata: {e}")
            return None

    def list_checkpoints(self, model_name: str) -> List[CheckpointMetadata]:
        """List all checkpoints for a model"""
        try:
            with self._get_connection() as conn:
                cursor = conn.execute("""
                    SELECT * FROM checkpoint_metadata
                    WHERE model_name = ?
                    ORDER BY timestamp DESC
                """, (model_name,))

                checkpoints = []
                for row in cursor.fetchall():
                    checkpoints.append(CheckpointMetadata(
                        checkpoint_id=row['checkpoint_id'],
                        model_name=row['model_name'],
                        model_type=row['model_type'],
                        timestamp=datetime.fromisoformat(row['timestamp']),
                        performance_metrics=json.loads(row['performance_metrics']),
                        training_metadata=json.loads(row['training_metadata']),
                        file_path=row['file_path'],
                        file_size_mb=row['file_size_mb'],
                        is_active=bool(row['is_active'])
                    ))
                return checkpoints
        except Exception as e:
            logger.error(f"Failed to list checkpoints: {e}")
            return []

    def set_active_checkpoint(self, model_name: str, checkpoint_id: str) -> bool:
        """Set a checkpoint as active for a model"""
        try:
            with self._get_connection() as conn:
                # First, set all checkpoints for this model as inactive
                conn.execute("""
                    UPDATE checkpoint_metadata
                    SET is_active = FALSE
                    WHERE model_name = ?
                """, (model_name,))

                # Set the specified checkpoint as active
                cursor = conn.execute("""
                    UPDATE checkpoint_metadata
                    SET is_active = TRUE
                    WHERE model_name = ? AND checkpoint_id = ?
                """, (model_name, checkpoint_id))

                conn.commit()
                return cursor.rowcount > 0
        except Exception as e:
            logger.error(f"Failed to set active checkpoint: {e}")
            return False

    def get_recent_inferences(self, model_name: str, limit: int = 100) -> List[InferenceRecord]:
        """Get recent inference records for a model"""
        try:
            with self._get_connection() as conn:
                cursor = conn.execute("""
                    SELECT * FROM inference_records
                    WHERE model_name = ?
                    ORDER BY timestamp DESC
                    LIMIT ?
                """, (model_name, limit))

                records = []
                for row in cursor.fetchall():
                    # Deserialize input features if available
                    input_features = None
                    # Check if the column exists in the row (handles missing column gracefully)
                    if 'input_features_blob' in row.keys() and row['input_features_blob']:
                        try:
                            # Reconstruct numpy array from bytes
                            input_features = np.frombuffer(row['input_features_blob'], dtype=np.float32)
                        except Exception as e:
                            logger.warning(f"Failed to deserialize input features: {e}")

                    records.append(InferenceRecord(
                        model_name=row['model_name'],
                        timestamp=datetime.fromisoformat(row['timestamp']),
                        symbol=row['symbol'],
                        action=row['action'],
                        confidence=row['confidence'],
                        probabilities=json.loads(row['probabilities']),
                        input_features_hash=row['input_features_hash'],
                        processing_time_ms=row['processing_time_ms'],
                        memory_usage_mb=row['memory_usage_mb'],
                        input_features=input_features,
                        checkpoint_id=row['checkpoint_id'],
                        metadata=json.loads(row['metadata']) if row['metadata'] else None
                    ))
                return records
        except Exception as e:
            logger.error(f"Failed to get recent inferences: {e}")
            return []

    def update_model_performance(self, model_name: str, date: str,
                                 total_predictions: int, correct_predictions: int,
                                 avg_confidence: float, avg_processing_time: float) -> bool:
        """Update daily model performance statistics"""
        try:
            accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0

            with self._get_connection() as conn:
                conn.execute("""
                    INSERT OR REPLACE INTO model_performance (
                        model_name, date, total_predictions, correct_predictions,
                        accuracy, avg_confidence, avg_processing_time_ms
                    ) VALUES (?, ?, ?, ?, ?, ?, ?)
                """, (
                    model_name, date, total_predictions, correct_predictions,
                    accuracy, avg_confidence, avg_processing_time
                ))

                conn.commit()
                return True
        except Exception as e:
            logger.error(f"Failed to update model performance: {e}")
            return False

    def get_inference_records_for_training(self, model_name: str,
                                           symbol: Optional[str] = None,
                                           hours_back: int = 24,
                                           limit: int = 1000) -> List[InferenceRecord]:
        """
        Get inference records with input features for training feedback

        Args:
            model_name: Name of the model
            symbol: Optional symbol filter
            hours_back: How many hours back to look
            limit: Maximum number of records

        Returns:
            List of InferenceRecord with input_features populated
        """
        try:
            cutoff_time = datetime.now() - timedelta(hours=hours_back)

            with self._get_connection() as conn:
                # Check if input_features_blob column exists before querying
                cursor = conn.execute("PRAGMA table_info(inference_records)")
                columns = [row[1] for row in cursor.fetchall()]
                has_blob_column = 'input_features_blob' in columns

                if not has_blob_column:
                    logger.warning("input_features_blob column not found, returning empty list")
                    return []

                if symbol:
                    cursor = conn.execute("""
                        SELECT * FROM inference_records
                        WHERE model_name = ? AND symbol = ? AND timestamp >= ?
                        AND input_features_blob IS NOT NULL
                        ORDER BY timestamp DESC
                        LIMIT ?
                    """, (model_name, symbol, cutoff_time.isoformat(), limit))
                else:
                    cursor = conn.execute("""
                        SELECT * FROM inference_records
                        WHERE model_name = ? AND timestamp >= ?
                        AND input_features_blob IS NOT NULL
                        ORDER BY timestamp DESC
                        LIMIT ?
                    """, (model_name, cutoff_time.isoformat(), limit))

                records = []
                for row in cursor.fetchall():
                    # Deserialize input features
                    input_features = None
                    if row['input_features_blob']:
                        try:
                            input_features = np.frombuffer(row['input_features_blob'], dtype=np.float32)
                        except Exception as e:
                            logger.warning(f"Failed to deserialize input features: {e}")
                            continue  # Skip records with corrupted features

                    records.append(InferenceRecord(
                        model_name=row['model_name'],
                        timestamp=datetime.fromisoformat(row['timestamp']),
                        symbol=row['symbol'],
                        action=row['action'],
                        confidence=row['confidence'],
                        probabilities=json.loads(row['probabilities']),
                        input_features_hash=row['input_features_hash'],
                        processing_time_ms=row['processing_time_ms'],
                        memory_usage_mb=row['memory_usage_mb'],
                        input_features=input_features,
                        checkpoint_id=row['checkpoint_id'],
                        metadata=json.loads(row['metadata']) if row['metadata'] else None
                    ))
                return records
        except Exception as e:
            logger.error(f"Failed to get inference records for training: {e}")
            return []

    def cleanup_old_records(self, days_to_keep: int = 30) -> bool:
        """Clean up old inference records"""
        try:
            cutoff_date = datetime.now() - timedelta(days=days_to_keep)

            with self._get_connection() as conn:
                cursor = conn.execute("""
                    DELETE FROM inference_records
                    WHERE timestamp < ?
                """, (cutoff_date.isoformat(),))

                deleted_count = cursor.rowcount
                conn.commit()

                if deleted_count > 0:
                    logger.info(f"Cleaned up {deleted_count} old inference records")
                return True
        except Exception as e:
            logger.error(f"Failed to cleanup old records: {e}")
            return False


# Global database manager instance
_db_manager_instance = None


def get_database_manager(db_path: str = "data/trading_system.db") -> DatabaseManager:
    """Get the global database manager instance"""
    global _db_manager_instance
    if _db_manager_instance is None:
        _db_manager_instance = DatabaseManager(db_path)
    return _db_manager_instance


def reset_database_manager():
    """Reset the database manager instance to force re-initialization"""
    global _db_manager_instance
    _db_manager_instance = None
    logger.info("Database manager instance reset - will re-initialize on next access")