"""
|
|
Database Manager for Trading System
|
|
|
|
Manages SQLite database for:
|
|
1. Inference records logging
|
|
2. Checkpoint metadata storage
|
|
3. Model performance tracking
|
|
"""
|
|
|
|
import sqlite3
|
|
import json
|
|
import logging
|
|
import os
|
|
import numpy as np
|
|
from datetime import datetime, timedelta
|
|
from typing import Dict, List, Optional, Any, Tuple
|
|
from contextlib import contextmanager
|
|
from dataclasses import dataclass, asdict
|
|
|
|
logger = logging.getLogger(__name__)
@dataclass
class InferenceRecord:
    """Structure for inference logging"""
    model_name: str
    timestamp: datetime
    symbol: str
    action: str
    confidence: float
    probabilities: Dict[str, float]
    input_features_hash: str  # Hash of input features for deduplication
    processing_time_ms: float
    memory_usage_mb: float
    input_features: Optional[np.ndarray] = None  # Full input features for training
    checkpoint_id: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None


@dataclass
class CheckpointMetadata:
    """Structure for checkpoint metadata"""
    checkpoint_id: str
    model_name: str
    model_type: str
    timestamp: datetime
    performance_metrics: Dict[str, float]
    training_metadata: Dict[str, Any]
    file_path: str
    file_size_mb: float
    is_active: bool = False  # Currently loaded checkpoint
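# A minimal usage sketch (illustrative only, not part of the original module):
# the checkpoint id, metrics, path, and size below are hypothetical placeholders
# showing how a CheckpointMetadata entry might be filled in before saving.
#
#   meta = CheckpointMetadata(
#       checkpoint_id="cnn_v1_20240101_120000",
#       model_name="cnn_v1",
#       model_type="cnn",
#       timestamp=datetime.now(),
#       performance_metrics={"accuracy": 0.61, "loss": 0.92},
#       training_metadata={"epochs": 10, "samples": 50000},
#       file_path="models/checkpoints/cnn_v1_20240101_120000.pt",
#       file_size_mb=48.3,
#       is_active=True,
#   )
#   get_database_manager().save_checkpoint_metadata(meta)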
class DatabaseManager:
    """Manages SQLite database for trading system logging and metadata"""

    def __init__(self, db_path: str = "data/trading_system.db"):
        self.db_path = db_path
        self._ensure_db_directory()
        self._initialize_database()

    def _ensure_db_directory(self):
        """Ensure database directory exists"""
        db_dir = os.path.dirname(self.db_path)
        if db_dir:  # Guard against a bare filename with no directory component
            os.makedirs(db_dir, exist_ok=True)
    def _initialize_database(self):
        """Initialize database tables"""
        with self._get_connection() as conn:
            # Inference records table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS inference_records (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    model_name TEXT NOT NULL,
                    timestamp TEXT NOT NULL,
                    symbol TEXT NOT NULL,
                    action TEXT NOT NULL,
                    confidence REAL NOT NULL,
                    probabilities TEXT NOT NULL,  -- JSON
                    input_features_hash TEXT NOT NULL,
                    input_features_blob BLOB,  -- Store full input features for training
                    processing_time_ms REAL NOT NULL,
                    memory_usage_mb REAL NOT NULL,
                    checkpoint_id TEXT,
                    metadata TEXT,  -- JSON
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP
                )
            """)

            # Checkpoint metadata table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS checkpoint_metadata (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    checkpoint_id TEXT UNIQUE NOT NULL,
                    model_name TEXT NOT NULL,
                    model_type TEXT NOT NULL,
                    timestamp TEXT NOT NULL,
                    performance_metrics TEXT NOT NULL,  -- JSON
                    training_metadata TEXT NOT NULL,  -- JSON
                    file_path TEXT NOT NULL,
                    file_size_mb REAL NOT NULL,
                    is_active BOOLEAN DEFAULT FALSE,
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP
                )
            """)

            # Model performance tracking table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS model_performance (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    model_name TEXT NOT NULL,
                    date TEXT NOT NULL,
                    total_predictions INTEGER DEFAULT 0,
                    correct_predictions INTEGER DEFAULT 0,
                    accuracy REAL DEFAULT 0.0,
                    avg_confidence REAL DEFAULT 0.0,
                    avg_processing_time_ms REAL DEFAULT 0.0,
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP,
                    UNIQUE(model_name, date)
                )
            """)

            # Create indexes for better performance
            conn.execute("CREATE INDEX IF NOT EXISTS idx_inference_model_timestamp ON inference_records(model_name, timestamp)")
            conn.execute("CREATE INDEX IF NOT EXISTS idx_inference_symbol ON inference_records(symbol)")
            conn.execute("CREATE INDEX IF NOT EXISTS idx_checkpoint_model ON checkpoint_metadata(model_name)")
            conn.execute("CREATE INDEX IF NOT EXISTS idx_checkpoint_active ON checkpoint_metadata(is_active)")

        logger.info(f"Database initialized at {self.db_path}")

        # Run migrations to handle schema updates
        self._run_migrations()
    def _run_migrations(self):
        """Run database migrations to handle schema updates"""
        try:
            with self._get_connection() as conn:
                # Check if input_features_blob column exists
                cursor = conn.execute("PRAGMA table_info(inference_records)")
                columns = [row[1] for row in cursor.fetchall()]

                if 'input_features_blob' not in columns:
                    logger.info("Adding input_features_blob column to inference_records table")
                    conn.execute("ALTER TABLE inference_records ADD COLUMN input_features_blob BLOB")
                    conn.commit()
                    logger.info("Successfully added input_features_blob column")
                else:
                    logger.debug("input_features_blob column already exists")

        except Exception as e:
            logger.error(f"Error running database migrations: {e}")
            # If migration fails, we can still continue without the blob column
    @contextmanager
    def _get_connection(self):
        """Get database connection with proper error handling"""
        conn = None
        try:
            conn = sqlite3.connect(self.db_path, timeout=30.0)
            conn.row_factory = sqlite3.Row  # Enable dict-like access
            yield conn
        except Exception as e:
            if conn:
                conn.rollback()
            logger.error(f"Database error: {e}")
            raise
        finally:
            if conn:
                conn.close()
    def log_inference(self, record: InferenceRecord) -> bool:
        """Log an inference record"""
        try:
            with self._get_connection() as conn:
                # Check if input_features_blob column exists
                cursor = conn.execute("PRAGMA table_info(inference_records)")
                columns = [row[1] for row in cursor.fetchall()]
                has_blob_column = 'input_features_blob' in columns

                # Serialize input features if provided and column exists
                input_features_blob = None
                if record.input_features is not None and has_blob_column:
                    input_features_blob = record.input_features.tobytes()

                if has_blob_column:
                    # Use full query with blob column
                    conn.execute("""
                        INSERT INTO inference_records (
                            model_name, timestamp, symbol, action, confidence,
                            probabilities, input_features_hash, input_features_blob,
                            processing_time_ms, memory_usage_mb, checkpoint_id, metadata
                        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """, (
                        record.model_name,
                        record.timestamp.isoformat(),
                        record.symbol,
                        record.action,
                        record.confidence,
                        json.dumps(record.probabilities),
                        record.input_features_hash,
                        input_features_blob,
                        record.processing_time_ms,
                        record.memory_usage_mb,
                        record.checkpoint_id,
                        json.dumps(record.metadata) if record.metadata else None
                    ))
                else:
                    # Fallback query without blob column
                    logger.warning("input_features_blob column missing, storing without full features")
                    conn.execute("""
                        INSERT INTO inference_records (
                            model_name, timestamp, symbol, action, confidence,
                            probabilities, input_features_hash,
                            processing_time_ms, memory_usage_mb, checkpoint_id, metadata
                        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """, (
                        record.model_name,
                        record.timestamp.isoformat(),
                        record.symbol,
                        record.action,
                        record.confidence,
                        json.dumps(record.probabilities),
                        record.input_features_hash,
                        record.processing_time_ms,
                        record.memory_usage_mb,
                        record.checkpoint_id,
                        json.dumps(record.metadata) if record.metadata else None
                    ))

                conn.commit()
                return True
        except Exception as e:
            logger.error(f"Failed to log inference record: {e}")
            return False
    def save_checkpoint_metadata(self, metadata: CheckpointMetadata) -> bool:
        """Save checkpoint metadata"""
        try:
            with self._get_connection() as conn:
                # First, set all other checkpoints for this model as inactive
                conn.execute("""
                    UPDATE checkpoint_metadata
                    SET is_active = FALSE
                    WHERE model_name = ?
                """, (metadata.model_name,))

                # Insert or replace the new checkpoint metadata
                conn.execute("""
                    INSERT OR REPLACE INTO checkpoint_metadata (
                        checkpoint_id, model_name, model_type, timestamp,
                        performance_metrics, training_metadata, file_path,
                        file_size_mb, is_active
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                """, (
                    metadata.checkpoint_id,
                    metadata.model_name,
                    metadata.model_type,
                    metadata.timestamp.isoformat(),
                    json.dumps(metadata.performance_metrics),
                    json.dumps(metadata.training_metadata),
                    metadata.file_path,
                    metadata.file_size_mb,
                    metadata.is_active
                ))
                conn.commit()
                return True
        except Exception as e:
            logger.error(f"Failed to save checkpoint metadata: {e}")
            return False
    def get_checkpoint_metadata(self, model_name: str, checkpoint_id: Optional[str] = None) -> Optional[CheckpointMetadata]:
        """Get checkpoint metadata without loading the actual model"""
        try:
            with self._get_connection() as conn:
                if checkpoint_id:
                    # Get specific checkpoint
                    cursor = conn.execute("""
                        SELECT * FROM checkpoint_metadata
                        WHERE model_name = ? AND checkpoint_id = ?
                    """, (model_name, checkpoint_id))
                else:
                    # Get active checkpoint
                    cursor = conn.execute("""
                        SELECT * FROM checkpoint_metadata
                        WHERE model_name = ? AND is_active = TRUE
                        ORDER BY timestamp DESC LIMIT 1
                    """, (model_name,))

                row = cursor.fetchone()
                if row:
                    return CheckpointMetadata(
                        checkpoint_id=row['checkpoint_id'],
                        model_name=row['model_name'],
                        model_type=row['model_type'],
                        timestamp=datetime.fromisoformat(row['timestamp']),
                        performance_metrics=json.loads(row['performance_metrics']),
                        training_metadata=json.loads(row['training_metadata']),
                        file_path=row['file_path'],
                        file_size_mb=row['file_size_mb'],
                        is_active=bool(row['is_active'])
                    )
                return None
        except Exception as e:
            logger.error(f"Failed to get checkpoint metadata: {e}")
            return None
    def get_best_checkpoint_metadata(self, model_name: str, metric_name: str = "accuracy") -> Optional[CheckpointMetadata]:
        """Get best checkpoint metadata based on performance metric"""
        try:
            with self._get_connection() as conn:
                cursor = conn.execute("""
                    SELECT * FROM checkpoint_metadata
                    WHERE model_name = ?
                    ORDER BY json_extract(performance_metrics, '$.' || ?) DESC
                    LIMIT 1
                """, (model_name, metric_name))

                row = cursor.fetchone()
                if row:
                    return CheckpointMetadata(
                        checkpoint_id=row['checkpoint_id'],
                        model_name=row['model_name'],
                        model_type=row['model_type'],
                        timestamp=datetime.fromisoformat(row['timestamp']),
                        performance_metrics=json.loads(row['performance_metrics']),
                        training_metadata=json.loads(row['training_metadata']),
                        file_path=row['file_path'],
                        file_size_mb=row['file_size_mb'],
                        is_active=bool(row['is_active'])
                    )
                return None
        except Exception as e:
            logger.error(f"Failed to get best checkpoint metadata: {e}")
            return None
    def list_checkpoints(self, model_name: str) -> List[CheckpointMetadata]:
        """List all checkpoints for a model"""
        try:
            with self._get_connection() as conn:
                cursor = conn.execute("""
                    SELECT * FROM checkpoint_metadata
                    WHERE model_name = ?
                    ORDER BY timestamp DESC
                """, (model_name,))

                checkpoints = []
                for row in cursor.fetchall():
                    checkpoints.append(CheckpointMetadata(
                        checkpoint_id=row['checkpoint_id'],
                        model_name=row['model_name'],
                        model_type=row['model_type'],
                        timestamp=datetime.fromisoformat(row['timestamp']),
                        performance_metrics=json.loads(row['performance_metrics']),
                        training_metadata=json.loads(row['training_metadata']),
                        file_path=row['file_path'],
                        file_size_mb=row['file_size_mb'],
                        is_active=bool(row['is_active'])
                    ))
                return checkpoints
        except Exception as e:
            logger.error(f"Failed to list checkpoints: {e}")
            return []
    def set_active_checkpoint(self, model_name: str, checkpoint_id: str) -> bool:
        """Set a checkpoint as active for a model"""
        try:
            with self._get_connection() as conn:
                # First, set all checkpoints for this model as inactive
                conn.execute("""
                    UPDATE checkpoint_metadata
                    SET is_active = FALSE
                    WHERE model_name = ?
                """, (model_name,))

                # Set the specified checkpoint as active
                cursor = conn.execute("""
                    UPDATE checkpoint_metadata
                    SET is_active = TRUE
                    WHERE model_name = ? AND checkpoint_id = ?
                """, (model_name, checkpoint_id))

                conn.commit()
                return cursor.rowcount > 0
        except Exception as e:
            logger.error(f"Failed to set active checkpoint: {e}")
            return False
    def get_recent_inferences(self, model_name: str, limit: int = 100) -> List[InferenceRecord]:
        """Get recent inference records for a model"""
        try:
            with self._get_connection() as conn:
                cursor = conn.execute("""
                    SELECT * FROM inference_records
                    WHERE model_name = ?
                    ORDER BY timestamp DESC
                    LIMIT ?
                """, (model_name, limit))

                records = []
                for row in cursor.fetchall():
                    # Deserialize input features if available
                    input_features = None
                    # Check if the column exists in the row (handles missing column gracefully)
                    if 'input_features_blob' in row.keys() and row['input_features_blob']:
                        try:
                            # Reconstruct numpy array from bytes. Note: this assumes the
                            # features were serialized as float32; the original shape is
                            # not preserved, so a flat 1-D array is returned.
                            input_features = np.frombuffer(row['input_features_blob'], dtype=np.float32)
                        except Exception as e:
                            logger.warning(f"Failed to deserialize input features: {e}")

                    records.append(InferenceRecord(
                        model_name=row['model_name'],
                        timestamp=datetime.fromisoformat(row['timestamp']),
                        symbol=row['symbol'],
                        action=row['action'],
                        confidence=row['confidence'],
                        probabilities=json.loads(row['probabilities']),
                        input_features_hash=row['input_features_hash'],
                        processing_time_ms=row['processing_time_ms'],
                        memory_usage_mb=row['memory_usage_mb'],
                        input_features=input_features,
                        checkpoint_id=row['checkpoint_id'],
                        metadata=json.loads(row['metadata']) if row['metadata'] else None
                    ))
                return records
        except Exception as e:
            logger.error(f"Failed to get recent inferences: {e}")
            return []
    def update_model_performance(self, model_name: str, date: str,
                                 total_predictions: int, correct_predictions: int,
                                 avg_confidence: float, avg_processing_time: float) -> bool:
        """Update daily model performance statistics"""
        try:
            accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0

            with self._get_connection() as conn:
                conn.execute("""
                    INSERT OR REPLACE INTO model_performance (
                        model_name, date, total_predictions, correct_predictions,
                        accuracy, avg_confidence, avg_processing_time_ms
                    ) VALUES (?, ?, ?, ?, ?, ?, ?)
                """, (
                    model_name, date, total_predictions, correct_predictions,
                    accuracy, avg_confidence, avg_processing_time
                ))
                conn.commit()
                return True
        except Exception as e:
            logger.error(f"Failed to update model performance: {e}")
            return False
    def get_inference_records_for_training(self, model_name: str,
                                           symbol: Optional[str] = None,
                                           hours_back: int = 24,
                                           limit: int = 1000) -> List[InferenceRecord]:
        """
        Get inference records with input features for training feedback

        Args:
            model_name: Name of the model
            symbol: Optional symbol filter
            hours_back: How many hours back to look
            limit: Maximum number of records

        Returns:
            List of InferenceRecord with input_features populated
        """
        try:
            cutoff_time = datetime.now() - timedelta(hours=hours_back)

            with self._get_connection() as conn:
                # Check if input_features_blob column exists before querying
                cursor = conn.execute("PRAGMA table_info(inference_records)")
                columns = [row[1] for row in cursor.fetchall()]
                has_blob_column = 'input_features_blob' in columns

                if not has_blob_column:
                    logger.warning("input_features_blob column not found, returning empty list")
                    return []

                if symbol:
                    cursor = conn.execute("""
                        SELECT * FROM inference_records
                        WHERE model_name = ? AND symbol = ? AND timestamp >= ?
                        AND input_features_blob IS NOT NULL
                        ORDER BY timestamp DESC
                        LIMIT ?
                    """, (model_name, symbol, cutoff_time.isoformat(), limit))
                else:
                    cursor = conn.execute("""
                        SELECT * FROM inference_records
                        WHERE model_name = ? AND timestamp >= ?
                        AND input_features_blob IS NOT NULL
                        ORDER BY timestamp DESC
                        LIMIT ?
                    """, (model_name, cutoff_time.isoformat(), limit))

                records = []
                for row in cursor.fetchall():
                    # Deserialize input features (assumes float32 serialization; shape is flattened)
                    input_features = None
                    if row['input_features_blob']:
                        try:
                            input_features = np.frombuffer(row['input_features_blob'], dtype=np.float32)
                        except Exception as e:
                            logger.warning(f"Failed to deserialize input features: {e}")
                            continue  # Skip records with corrupted features

                    records.append(InferenceRecord(
                        model_name=row['model_name'],
                        timestamp=datetime.fromisoformat(row['timestamp']),
                        symbol=row['symbol'],
                        action=row['action'],
                        confidence=row['confidence'],
                        probabilities=json.loads(row['probabilities']),
                        input_features_hash=row['input_features_hash'],
                        processing_time_ms=row['processing_time_ms'],
                        memory_usage_mb=row['memory_usage_mb'],
                        input_features=input_features,
                        checkpoint_id=row['checkpoint_id'],
                        metadata=json.loads(row['metadata']) if row['metadata'] else None
                    ))

                return records

        except Exception as e:
            logger.error(f"Failed to get inference records for training: {e}")
            return []
    def cleanup_old_records(self, days_to_keep: int = 30) -> bool:
        """Clean up old inference records"""
        try:
            cutoff_date = datetime.now() - timedelta(days=days_to_keep)

            with self._get_connection() as conn:
                cursor = conn.execute("""
                    DELETE FROM inference_records
                    WHERE timestamp < ?
                """, (cutoff_date.isoformat(),))

                deleted_count = cursor.rowcount
                conn.commit()

                if deleted_count > 0:
                    logger.info(f"Cleaned up {deleted_count} old inference records")

                return True
        except Exception as e:
            logger.error(f"Failed to cleanup old records: {e}")
            return False
# Global database manager instance
_db_manager_instance = None


def get_database_manager(db_path: str = "data/trading_system.db") -> DatabaseManager:
    """Get the global database manager instance"""
    global _db_manager_instance

    if _db_manager_instance is None:
        _db_manager_instance = DatabaseManager(db_path)

    return _db_manager_instance


def reset_database_manager():
    """Reset the database manager instance to force re-initialization"""
    global _db_manager_instance
    _db_manager_instance = None
    logger.info("Database manager instance reset - will re-initialize on next access")