SQLite for checkpoints, cleanup

Dobromir Popov
2025-07-25 22:34:13 +03:00
parent 130a52fb9b
commit dd9f4b63ba
42 changed files with 2017 additions and 1485 deletions

utils/cache_manager.py (new file, 295 lines)

@@ -0,0 +1,295 @@
"""
Cache Manager for Trading System
Utilities for managing and cleaning up cache files, including:
- Parquet file validation and repair
- Cache cleanup and maintenance
- Cache health monitoring
"""
import os
import logging
import pandas as pd
from pathlib import Path
from typing import Any, List, Dict, Optional, Tuple
from datetime import datetime, timedelta
logger = logging.getLogger(__name__)
class CacheManager:
"""Manages cache files for the trading system"""
def __init__(self, cache_dirs: Optional[List[str]] = None):
"""
Initialize cache manager
Args:
cache_dirs: List of cache directories to manage
"""
self.cache_dirs = cache_dirs or [
"data/cache",
"data/monthly_cache",
"data/pivot_cache"
]
# Ensure cache directories exist
for cache_dir in self.cache_dirs:
Path(cache_dir).mkdir(parents=True, exist_ok=True)
def validate_parquet_file(self, file_path: Path) -> Tuple[bool, Optional[str]]:
"""
Validate a Parquet file
Args:
file_path: Path to the Parquet file
Returns:
Tuple of (is_valid, error_message)
"""
try:
if not file_path.exists():
return False, "File does not exist"
if file_path.stat().st_size == 0:
return False, "File is empty"
# Try to read the file
df = pd.read_parquet(file_path)
if df.empty:
return False, "File contains no data"
# Check for required columns (basic validation)
required_columns = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
missing_columns = [col for col in required_columns if col not in df.columns]
if missing_columns:
return False, f"Missing required columns: {missing_columns}"
return True, None
except Exception as e:
error_str = str(e).lower()
corrupted_indicators = [
"parquet magic bytes not found",
"corrupted",
"couldn't deserialize thrift",
"don't know what type",
"invalid parquet file",
"unexpected end of file",
"invalid metadata"
]
if any(indicator in error_str for indicator in corrupted_indicators):
return False, f"Corrupted Parquet file: {e}"
else:
return False, f"Validation error: {e}"
def scan_cache_health(self) -> Dict[str, Dict]:
"""
Scan all cache directories for file health
Returns:
Dictionary with cache health information
"""
health_report = {}
for cache_dir in self.cache_dirs:
cache_path = Path(cache_dir)
if not cache_path.exists():
continue
dir_report = {
'total_files': 0,
'valid_files': 0,
'corrupted_files': 0,
'empty_files': 0,
'total_size_mb': 0.0,
'corrupted_files_list': [],
'old_files': []
}
# Scan all Parquet files
for file_path in cache_path.glob("*.parquet"):
dir_report['total_files'] += 1
file_size_mb = file_path.stat().st_size / (1024 * 1024)
dir_report['total_size_mb'] += file_size_mb
# Check file age
file_age = datetime.now() - datetime.fromtimestamp(file_path.stat().st_mtime)
if file_age > timedelta(days=7): # Files older than 7 days
dir_report['old_files'].append({
'file': str(file_path),
'age_days': file_age.days,
'size_mb': file_size_mb
})
# Validate file
is_valid, error_msg = self.validate_parquet_file(file_path)
if is_valid:
dir_report['valid_files'] += 1
else:
if "empty" in error_msg.lower():
dir_report['empty_files'] += 1
else:
dir_report['corrupted_files'] += 1
dir_report['corrupted_files_list'].append({
'file': str(file_path),
'error': error_msg,
'size_mb': file_size_mb
})
health_report[cache_dir] = dir_report
return health_report
def cleanup_corrupted_files(self, dry_run: bool = True) -> Dict[str, List[str]]:
"""
Clean up corrupted cache files
Args:
dry_run: If True, only report what would be deleted
Returns:
Dictionary of deleted files by directory
"""
deleted_files = {}
for cache_dir in self.cache_dirs:
cache_path = Path(cache_dir)
if not cache_path.exists():
continue
deleted_files[cache_dir] = []
for file_path in cache_path.glob("*.parquet"):
is_valid, error_msg = self.validate_parquet_file(file_path)
if not is_valid:
if dry_run:
deleted_files[cache_dir].append(f"WOULD DELETE: {file_path} ({error_msg})")
logger.info(f"Would delete corrupted file: {file_path} ({error_msg})")
else:
try:
file_path.unlink()
deleted_files[cache_dir].append(f"DELETED: {file_path} ({error_msg})")
logger.info(f"Deleted corrupted file: {file_path}")
except Exception as e:
deleted_files[cache_dir].append(f"FAILED TO DELETE: {file_path} ({e})")
logger.error(f"Failed to delete corrupted file {file_path}: {e}")
return deleted_files
def cleanup_old_files(self, days_to_keep: int = 7, dry_run: bool = True) -> Dict[str, List[str]]:
"""
Clean up old cache files
Args:
days_to_keep: Number of days to keep files
dry_run: If True, only report what would be deleted
Returns:
Dictionary of deleted files by directory
"""
deleted_files = {}
cutoff_date = datetime.now() - timedelta(days=days_to_keep)
for cache_dir in self.cache_dirs:
cache_path = Path(cache_dir)
if not cache_path.exists():
continue
deleted_files[cache_dir] = []
for file_path in cache_path.glob("*.parquet"):
file_mtime = datetime.fromtimestamp(file_path.stat().st_mtime)
if file_mtime < cutoff_date:
age_days = (datetime.now() - file_mtime).days
if dry_run:
deleted_files[cache_dir].append(f"WOULD DELETE: {file_path} (age: {age_days} days)")
logger.info(f"Would delete old file: {file_path} (age: {age_days} days)")
else:
try:
file_path.unlink()
deleted_files[cache_dir].append(f"DELETED: {file_path} (age: {age_days} days)")
logger.info(f"Deleted old file: {file_path}")
except Exception as e:
deleted_files[cache_dir].append(f"FAILED TO DELETE: {file_path} ({e})")
logger.error(f"Failed to delete old file {file_path}: {e}")
return deleted_files
def get_cache_summary(self) -> Dict[str, Any]:
"""Get a summary of cache usage"""
health_report = self.scan_cache_health()
total_files = sum(report['total_files'] for report in health_report.values())
total_valid = sum(report['valid_files'] for report in health_report.values())
total_corrupted = sum(report['corrupted_files'] for report in health_report.values())
total_size_mb = sum(report['total_size_mb'] for report in health_report.values())
return {
'total_files': total_files,
'valid_files': total_valid,
'corrupted_files': total_corrupted,
'health_percentage': (total_valid / total_files * 100) if total_files > 0 else 0,
'total_size_mb': total_size_mb,
'directories': health_report
}
def emergency_cache_reset(self, confirm: bool = False) -> bool:
"""
Emergency cache reset - deletes all cache files
Args:
confirm: Must be True to actually delete files
Returns:
True if reset was performed
"""
if not confirm:
logger.warning("Emergency cache reset called but not confirmed")
return False
deleted_count = 0
for cache_dir in self.cache_dirs:
cache_path = Path(cache_dir)
if not cache_path.exists():
continue
for file_path in cache_path.glob("*"):
try:
if file_path.is_file():
file_path.unlink()
deleted_count += 1
except Exception as e:
logger.error(f"Failed to delete {file_path}: {e}")
logger.warning(f"Emergency cache reset completed: deleted {deleted_count} files")
return True
# Global cache manager instance
_cache_manager_instance = None
def get_cache_manager() -> CacheManager:
"""Get the global cache manager instance"""
global _cache_manager_instance
if _cache_manager_instance is None:
_cache_manager_instance = CacheManager()
return _cache_manager_instance
def cleanup_corrupted_cache(dry_run: bool = True) -> Dict[str, List[str]]:
"""Convenience function to clean up corrupted cache files"""
cache_manager = get_cache_manager()
return cache_manager.cleanup_corrupted_files(dry_run=dry_run)
def get_cache_health() -> Dict[str, Any]:
"""Convenience function to get cache health summary"""
cache_manager = get_cache_manager()
return cache_manager.get_cache_summary()
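For reference, a minimal usage sketch of this module, assuming the default cache directories; the summary keys follow get_cache_summary() above:

from utils.cache_manager import get_cache_manager, cleanup_corrupted_cache

manager = get_cache_manager()

# Dry run: report what would be deleted without touching any files.
for cache_dir, entries in cleanup_corrupted_cache(dry_run=True).items():
    for entry in entries:
        print(f"{cache_dir}: {entry}")

# Overall health across all managed cache directories.
summary = manager.get_cache_summary()
print(f"{summary['valid_files']}/{summary['total_files']} files valid "
      f"({summary['health_percentage']:.1f}%), {summary['total_size_mb']:.1f} MB total")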

utils/checkpoint_manager.py (modified)

@@ -5,6 +5,7 @@ This module provides functionality for managing model checkpoints, including:
- Saving checkpoints with metadata
- Loading the best checkpoint based on performance metrics
- Cleaning up old or underperforming checkpoints
- Database-backed metadata storage for efficient access
"""
import os
@@ -13,9 +14,13 @@ import glob
import logging
import shutil
import torch
import hashlib
from datetime import datetime
from typing import Dict, List, Optional, Any, Tuple
from .database_manager import get_database_manager, CheckpointMetadata
from .text_logger import get_text_logger
logger = logging.getLogger(__name__)
# Global checkpoint manager instance
@@ -46,7 +51,7 @@ def get_checkpoint_manager(checkpoint_dir: str = "models/checkpoints", max_check
def save_checkpoint(model, model_name: str, model_type: str, performance_metrics: Dict[str, float], training_metadata: Dict[str, Any] = None, checkpoint_dir: str = "models/checkpoints") -> Any:
"""
Save a checkpoint with metadata
Save a checkpoint with metadata to both filesystem and database
Args:
model: The model to save
@@ -64,57 +69,90 @@ def save_checkpoint(model, model_name: str, model_type: str, performance_metrics
os.makedirs(checkpoint_dir, exist_ok=True)
# Create timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
timestamp = datetime.now()
timestamp_str = timestamp.strftime("%Y%m%d_%H%M%S")
# Create checkpoint path
model_dir = os.path.join(checkpoint_dir, model_name)
os.makedirs(model_dir, exist_ok=True)
checkpoint_path = os.path.join(model_dir, f"{model_name}_{timestamp}")
checkpoint_path = os.path.join(model_dir, f"{model_name}_{timestamp_str}")
checkpoint_id = f"{model_name}_{timestamp_str}"
# Save model
torch_path = f"{checkpoint_path}.pt"
if hasattr(model, 'save'):
# Use model's save method if available
model.save(checkpoint_path)
else:
# Otherwise, save state_dict
torch_path = f"{checkpoint_path}.pt"
torch.save({
'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else None,
'model_name': model_name,
'model_type': model_type,
'timestamp': timestamp
'timestamp': timestamp_str,
'checkpoint_id': checkpoint_id
}, torch_path)
# Create metadata
checkpoint_metadata = {
# Calculate file size
file_size_mb = os.path.getsize(torch_path) / (1024 * 1024) if os.path.exists(torch_path) else 0.0
# Save metadata to database
db_manager = get_database_manager()
checkpoint_metadata = CheckpointMetadata(
checkpoint_id=checkpoint_id,
model_name=model_name,
model_type=model_type,
timestamp=timestamp,
performance_metrics=performance_metrics,
training_metadata=training_metadata or {},
file_path=torch_path,
file_size_mb=file_size_mb,
is_active=True # New checkpoint is active by default
)
# Save to database
if db_manager.save_checkpoint_metadata(checkpoint_metadata):
# Log checkpoint save event to text file
text_logger = get_text_logger()
text_logger.log_checkpoint_event(
model_name=model_name,
event_type="SAVED",
checkpoint_id=checkpoint_id,
details=f"loss={performance_metrics.get('loss', 'N/A')}, size={file_size_mb:.1f}MB"
)
logger.info(f"Checkpoint saved: {checkpoint_id}")
else:
logger.warning(f"Failed to save checkpoint metadata to database: {checkpoint_id}")
# Also save legacy JSON metadata for backward compatibility
legacy_metadata = {
'model_name': model_name,
'model_type': model_type,
'timestamp': timestamp,
'timestamp': timestamp_str,
'performance_metrics': performance_metrics,
'training_metadata': training_metadata or {},
'checkpoint_id': f"{model_name}_{timestamp}"
'checkpoint_id': checkpoint_id,
'performance_score': performance_metrics.get('accuracy', performance_metrics.get('reward', 0.0)),
'created_at': timestamp_str
}
# Add performance score for sorting
primary_metric = 'accuracy' if 'accuracy' in performance_metrics else 'reward'
checkpoint_metadata['performance_score'] = performance_metrics.get(primary_metric, 0.0)
checkpoint_metadata['created_at'] = timestamp
# Save metadata
with open(f"{checkpoint_path}_metadata.json", 'w') as f:
json.dump(checkpoint_metadata, f, indent=2)
json.dump(legacy_metadata, f, indent=2)
# Get checkpoint manager and clean up old checkpoints
checkpoint_manager = get_checkpoint_manager(checkpoint_dir=checkpoint_dir)
checkpoint_manager._cleanup_checkpoints(model_name)
# Return metadata as an object
class CheckpointMetadata:
# Return metadata as an object for backward compatibility
class CheckpointMetadataObj:
def __init__(self, metadata):
for key, value in metadata.items():
setattr(self, key, value)
# Add database fields
self.checkpoint_id = checkpoint_id
self.loss = performance_metrics.get('loss', performance_metrics.get('accuracy', 0.0))
return CheckpointMetadata(checkpoint_metadata)
return CheckpointMetadataObj(legacy_metadata)
except Exception as e:
logger.error(f"Error saving checkpoint: {e}")
@@ -122,7 +160,7 @@ def save_checkpoint(model, model_name: str, model_type: str, performance_metrics
def load_best_checkpoint(model_name: str, checkpoint_dir: str = "models/checkpoints") -> Optional[Tuple[str, Any]]:
"""
Load the best checkpoint based on performance metrics
Load the best checkpoint based on performance metrics using database metadata
Args:
model_name: Name of the model
@@ -132,29 +170,77 @@ def load_best_checkpoint(model_name: str, checkpoint_dir: str = "models/checkpoi
Optional[Tuple[str, Any]]: Path to the best checkpoint and its metadata, or None if not found
"""
try:
checkpoint_manager = get_checkpoint_manager(checkpoint_dir=checkpoint_dir)
checkpoint_path, checkpoint_metadata = checkpoint_manager.load_best_checkpoint(model_name)
# First try to get from database (fast metadata access)
db_manager = get_database_manager()
checkpoint_metadata = db_manager.get_best_checkpoint_metadata(model_name, "accuracy")
if not checkpoint_path:
if not checkpoint_metadata:
# Fallback to legacy file-based approach (no more scattered "No checkpoints found" logs)
pass # Silent fallback
checkpoint_manager = get_checkpoint_manager(checkpoint_dir=checkpoint_dir)
checkpoint_path, legacy_metadata = checkpoint_manager.load_best_checkpoint(model_name)
if not checkpoint_path:
return None
# Convert legacy metadata to object
class CheckpointMetadataObj:
def __init__(self, metadata):
for key, value in metadata.items():
setattr(self, key, value)
# Add performance score if not present
if not hasattr(self, 'performance_score'):
metrics = getattr(self, 'metrics', {})
primary_metric = 'accuracy' if 'accuracy' in metrics else 'reward'
self.performance_score = metrics.get(primary_metric, 0.0)
# Add created_at if not present
if not hasattr(self, 'created_at'):
self.created_at = getattr(self, 'timestamp', 'unknown')
# Add loss for compatibility
self.loss = metrics.get('loss', self.performance_score)
self.checkpoint_id = getattr(self, 'checkpoint_id', f"{model_name}_unknown")
return f"{checkpoint_path}.pt", CheckpointMetadataObj(legacy_metadata)
# Check if checkpoint file exists
if not os.path.exists(checkpoint_metadata.file_path):
logger.warning(f"Checkpoint file not found: {checkpoint_metadata.file_path}")
return None
# Convert metadata to object
class CheckpointMetadata:
def __init__(self, metadata):
for key, value in metadata.items():
setattr(self, key, value)
# Add performance score if not present
if not hasattr(self, 'performance_score'):
metrics = getattr(self, 'metrics', {})
primary_metric = 'accuracy' if 'accuracy' in metrics else 'reward'
self.performance_score = metrics.get(primary_metric, 0.0)
# Add created_at if not present
if not hasattr(self, 'created_at'):
self.created_at = getattr(self, 'timestamp', 'unknown')
# Log checkpoint load event to text file
text_logger = get_text_logger()
text_logger.log_checkpoint_event(
model_name=model_name,
event_type="LOADED",
checkpoint_id=checkpoint_metadata.checkpoint_id,
details=f"loss={checkpoint_metadata.performance_metrics.get('loss', 'N/A')}"
)
return f"{checkpoint_path}.pt", CheckpointMetadata(checkpoint_metadata)
# Convert database metadata to object for backward compatibility
class CheckpointMetadataObj:
def __init__(self, db_metadata: CheckpointMetadata):
self.checkpoint_id = db_metadata.checkpoint_id
self.model_name = db_metadata.model_name
self.model_type = db_metadata.model_type
self.timestamp = db_metadata.timestamp.strftime("%Y%m%d_%H%M%S")
self.performance_metrics = db_metadata.performance_metrics
self.training_metadata = db_metadata.training_metadata
self.file_path = db_metadata.file_path
self.file_size_mb = db_metadata.file_size_mb
self.is_active = db_metadata.is_active
# Backward compatibility fields
self.metrics = db_metadata.performance_metrics
self.metadata = db_metadata.training_metadata
self.created_at = self.timestamp
self.performance_score = db_metadata.performance_metrics.get('accuracy',
db_metadata.performance_metrics.get('reward', 0.0))
self.loss = db_metadata.performance_metrics.get('loss', self.performance_score)
return checkpoint_metadata.file_path, CheckpointMetadataObj(checkpoint_metadata)
except Exception as e:
logger.error(f"Error loading best checkpoint: {e}")
@@ -254,7 +340,7 @@ class CheckpointManager:
metadata_files = glob.glob(os.path.join(checkpoint_dir, f"{model_name}_*_metadata.json"))
if not metadata_files:
logger.info(f"No checkpoints found for {model_name}")
# No more scattered "No checkpoints found" logs - handled by database system
return "", {}
# Load metadata for each checkpoint
@@ -278,7 +364,7 @@ class CheckpointManager:
logger.error(f"Error loading checkpoint metadata {metadata_file}: {e}")
if not checkpoints:
logger.info(f"No valid checkpoints found for {model_name}")
# No more scattered logs - handled by database system
return "", {}
# Sort by metric (highest first)
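A quick sketch of the new save/load round trip; the model and the "example_cnn" name are hypothetical, while the metric keys and the returned attributes follow the code above:

import torch.nn as nn

from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint

model = nn.Linear(8, 3)  # stand-in for a real model

# Writes the .pt file, the legacy JSON metadata, and a checkpoint_metadata row.
meta = save_checkpoint(
    model,
    model_name="example_cnn",
    model_type="cnn",
    performance_metrics={"accuracy": 0.61, "loss": 0.42},
)
print(meta.checkpoint_id)

# Database-first lookup; silently falls back to scanning *_metadata.json files.
result = load_best_checkpoint("example_cnn")
if result:
    path, best = result
    print(path, best.performance_score, best.loss)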

utils/database_manager.py (new file, 408 lines)

@@ -0,0 +1,408 @@
"""
Database Manager for Trading System
Manages SQLite database for:
1. Inference records logging
2. Checkpoint metadata storage
3. Model performance tracking
"""
import sqlite3
import json
import logging
import os
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Tuple
from contextlib import contextmanager
from dataclasses import dataclass, asdict
logger = logging.getLogger(__name__)
@dataclass
class InferenceRecord:
"""Structure for inference logging"""
model_name: str
timestamp: datetime
symbol: str
action: str
confidence: float
probabilities: Dict[str, float]
input_features_hash: str # Hash of input features for deduplication
processing_time_ms: float
memory_usage_mb: float
checkpoint_id: Optional[str] = None
metadata: Optional[Dict[str, Any]] = None
@dataclass
class CheckpointMetadata:
"""Structure for checkpoint metadata"""
checkpoint_id: str
model_name: str
model_type: str
timestamp: datetime
performance_metrics: Dict[str, float]
training_metadata: Dict[str, Any]
file_path: str
file_size_mb: float
is_active: bool = False # Currently loaded checkpoint
class DatabaseManager:
"""Manages SQLite database for trading system logging and metadata"""
def __init__(self, db_path: str = "data/trading_system.db"):
self.db_path = db_path
self._ensure_db_directory()
self._initialize_database()
def _ensure_db_directory(self):
"""Ensure database directory exists"""
os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
def _initialize_database(self):
"""Initialize database tables"""
with self._get_connection() as conn:
# Inference records table
conn.execute("""
CREATE TABLE IF NOT EXISTS inference_records (
id INTEGER PRIMARY KEY AUTOINCREMENT,
model_name TEXT NOT NULL,
timestamp TEXT NOT NULL,
symbol TEXT NOT NULL,
action TEXT NOT NULL,
confidence REAL NOT NULL,
probabilities TEXT NOT NULL, -- JSON
input_features_hash TEXT NOT NULL,
processing_time_ms REAL NOT NULL,
memory_usage_mb REAL NOT NULL,
checkpoint_id TEXT,
metadata TEXT, -- JSON
created_at TEXT DEFAULT CURRENT_TIMESTAMP
)
""")
# Checkpoint metadata table
conn.execute("""
CREATE TABLE IF NOT EXISTS checkpoint_metadata (
id INTEGER PRIMARY KEY AUTOINCREMENT,
checkpoint_id TEXT UNIQUE NOT NULL,
model_name TEXT NOT NULL,
model_type TEXT NOT NULL,
timestamp TEXT NOT NULL,
performance_metrics TEXT NOT NULL, -- JSON
training_metadata TEXT NOT NULL, -- JSON
file_path TEXT NOT NULL,
file_size_mb REAL NOT NULL,
is_active BOOLEAN DEFAULT FALSE,
created_at TEXT DEFAULT CURRENT_TIMESTAMP
)
""")
# Model performance tracking table
conn.execute("""
CREATE TABLE IF NOT EXISTS model_performance (
id INTEGER PRIMARY KEY AUTOINCREMENT,
model_name TEXT NOT NULL,
date TEXT NOT NULL,
total_predictions INTEGER DEFAULT 0,
correct_predictions INTEGER DEFAULT 0,
accuracy REAL DEFAULT 0.0,
avg_confidence REAL DEFAULT 0.0,
avg_processing_time_ms REAL DEFAULT 0.0,
created_at TEXT DEFAULT CURRENT_TIMESTAMP,
UNIQUE(model_name, date)
)
""")
# Create indexes for better performance
conn.execute("CREATE INDEX IF NOT EXISTS idx_inference_model_timestamp ON inference_records(model_name, timestamp)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_inference_symbol ON inference_records(symbol)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_checkpoint_model ON checkpoint_metadata(model_name)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_checkpoint_active ON checkpoint_metadata(is_active)")
logger.info(f"Database initialized at {self.db_path}")
@contextmanager
def _get_connection(self):
"""Get database connection with proper error handling"""
conn = None
try:
conn = sqlite3.connect(self.db_path, timeout=30.0)
conn.row_factory = sqlite3.Row # Enable dict-like access
yield conn
except Exception as e:
if conn:
conn.rollback()
logger.error(f"Database error: {e}")
raise
finally:
if conn:
conn.close()
def log_inference(self, record: InferenceRecord) -> bool:
"""Log an inference record"""
try:
with self._get_connection() as conn:
conn.execute("""
INSERT INTO inference_records (
model_name, timestamp, symbol, action, confidence,
probabilities, input_features_hash, processing_time_ms,
memory_usage_mb, checkpoint_id, metadata
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", (
record.model_name,
record.timestamp.isoformat(),
record.symbol,
record.action,
record.confidence,
json.dumps(record.probabilities),
record.input_features_hash,
record.processing_time_ms,
record.memory_usage_mb,
record.checkpoint_id,
json.dumps(record.metadata) if record.metadata else None
))
conn.commit()
return True
except Exception as e:
logger.error(f"Failed to log inference record: {e}")
return False
def save_checkpoint_metadata(self, metadata: CheckpointMetadata) -> bool:
"""Save checkpoint metadata"""
try:
with self._get_connection() as conn:
# First, set all other checkpoints for this model as inactive
conn.execute("""
UPDATE checkpoint_metadata
SET is_active = FALSE
WHERE model_name = ?
""", (metadata.model_name,))
# Insert or replace the new checkpoint metadata
conn.execute("""
INSERT OR REPLACE INTO checkpoint_metadata (
checkpoint_id, model_name, model_type, timestamp,
performance_metrics, training_metadata, file_path,
file_size_mb, is_active
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""", (
metadata.checkpoint_id,
metadata.model_name,
metadata.model_type,
metadata.timestamp.isoformat(),
json.dumps(metadata.performance_metrics),
json.dumps(metadata.training_metadata),
metadata.file_path,
metadata.file_size_mb,
metadata.is_active
))
conn.commit()
return True
except Exception as e:
logger.error(f"Failed to save checkpoint metadata: {e}")
return False
def get_checkpoint_metadata(self, model_name: str, checkpoint_id: Optional[str] = None) -> Optional[CheckpointMetadata]:
"""Get checkpoint metadata without loading the actual model"""
try:
with self._get_connection() as conn:
if checkpoint_id:
# Get specific checkpoint
cursor = conn.execute("""
SELECT * FROM checkpoint_metadata
WHERE model_name = ? AND checkpoint_id = ?
""", (model_name, checkpoint_id))
else:
# Get active checkpoint
cursor = conn.execute("""
SELECT * FROM checkpoint_metadata
WHERE model_name = ? AND is_active = TRUE
ORDER BY timestamp DESC LIMIT 1
""", (model_name,))
row = cursor.fetchone()
if row:
return CheckpointMetadata(
checkpoint_id=row['checkpoint_id'],
model_name=row['model_name'],
model_type=row['model_type'],
timestamp=datetime.fromisoformat(row['timestamp']),
performance_metrics=json.loads(row['performance_metrics']),
training_metadata=json.loads(row['training_metadata']),
file_path=row['file_path'],
file_size_mb=row['file_size_mb'],
is_active=bool(row['is_active'])
)
return None
except Exception as e:
logger.error(f"Failed to get checkpoint metadata: {e}")
return None
def get_best_checkpoint_metadata(self, model_name: str, metric_name: str = "accuracy") -> Optional[CheckpointMetadata]:
"""Get best checkpoint metadata based on performance metric"""
try:
with self._get_connection() as conn:
cursor = conn.execute("""
SELECT * FROM checkpoint_metadata
WHERE model_name = ?
ORDER BY json_extract(performance_metrics, '$.' || ?) DESC
LIMIT 1
""", (model_name, metric_name))
row = cursor.fetchone()
if row:
return CheckpointMetadata(
checkpoint_id=row['checkpoint_id'],
model_name=row['model_name'],
model_type=row['model_type'],
timestamp=datetime.fromisoformat(row['timestamp']),
performance_metrics=json.loads(row['performance_metrics']),
training_metadata=json.loads(row['training_metadata']),
file_path=row['file_path'],
file_size_mb=row['file_size_mb'],
is_active=bool(row['is_active'])
)
return None
except Exception as e:
logger.error(f"Failed to get best checkpoint metadata: {e}")
return None
def list_checkpoints(self, model_name: str) -> List[CheckpointMetadata]:
"""List all checkpoints for a model"""
try:
with self._get_connection() as conn:
cursor = conn.execute("""
SELECT * FROM checkpoint_metadata
WHERE model_name = ?
ORDER BY timestamp DESC
""", (model_name,))
checkpoints = []
for row in cursor.fetchall():
checkpoints.append(CheckpointMetadata(
checkpoint_id=row['checkpoint_id'],
model_name=row['model_name'],
model_type=row['model_type'],
timestamp=datetime.fromisoformat(row['timestamp']),
performance_metrics=json.loads(row['performance_metrics']),
training_metadata=json.loads(row['training_metadata']),
file_path=row['file_path'],
file_size_mb=row['file_size_mb'],
is_active=bool(row['is_active'])
))
return checkpoints
except Exception as e:
logger.error(f"Failed to list checkpoints: {e}")
return []
def set_active_checkpoint(self, model_name: str, checkpoint_id: str) -> bool:
"""Set a checkpoint as active for a model"""
try:
with self._get_connection() as conn:
# First, set all checkpoints for this model as inactive
conn.execute("""
UPDATE checkpoint_metadata
SET is_active = FALSE
WHERE model_name = ?
""", (model_name,))
# Set the specified checkpoint as active
cursor = conn.execute("""
UPDATE checkpoint_metadata
SET is_active = TRUE
WHERE model_name = ? AND checkpoint_id = ?
""", (model_name, checkpoint_id))
conn.commit()
return cursor.rowcount > 0
except Exception as e:
logger.error(f"Failed to set active checkpoint: {e}")
return False
def get_recent_inferences(self, model_name: str, limit: int = 100) -> List[InferenceRecord]:
"""Get recent inference records for a model"""
try:
with self._get_connection() as conn:
cursor = conn.execute("""
SELECT * FROM inference_records
WHERE model_name = ?
ORDER BY timestamp DESC
LIMIT ?
""", (model_name, limit))
records = []
for row in cursor.fetchall():
records.append(InferenceRecord(
model_name=row['model_name'],
timestamp=datetime.fromisoformat(row['timestamp']),
symbol=row['symbol'],
action=row['action'],
confidence=row['confidence'],
probabilities=json.loads(row['probabilities']),
input_features_hash=row['input_features_hash'],
processing_time_ms=row['processing_time_ms'],
memory_usage_mb=row['memory_usage_mb'],
checkpoint_id=row['checkpoint_id'],
metadata=json.loads(row['metadata']) if row['metadata'] else None
))
return records
except Exception as e:
logger.error(f"Failed to get recent inferences: {e}")
return []
def update_model_performance(self, model_name: str, date: str,
total_predictions: int, correct_predictions: int,
avg_confidence: float, avg_processing_time: float) -> bool:
"""Update daily model performance statistics"""
try:
accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0
with self._get_connection() as conn:
conn.execute("""
INSERT OR REPLACE INTO model_performance (
model_name, date, total_predictions, correct_predictions,
accuracy, avg_confidence, avg_processing_time_ms
) VALUES (?, ?, ?, ?, ?, ?, ?)
""", (
model_name, date, total_predictions, correct_predictions,
accuracy, avg_confidence, avg_processing_time
))
conn.commit()
return True
except Exception as e:
logger.error(f"Failed to update model performance: {e}")
return False
def cleanup_old_records(self, days_to_keep: int = 30) -> bool:
"""Clean up old inference records"""
try:
cutoff_date = datetime.now() - timedelta(days=days_to_keep)
with self._get_connection() as conn:
cursor = conn.execute("""
DELETE FROM inference_records
WHERE timestamp < ?
""", (cutoff_date.isoformat(),))
deleted_count = cursor.rowcount
conn.commit()
if deleted_count > 0:
logger.info(f"Cleaned up {deleted_count} old inference records")
return True
except Exception as e:
logger.error(f"Failed to cleanup old records: {e}")
return False
# Global database manager instance
_db_manager_instance = None
def get_database_manager(db_path: str = "data/trading_system.db") -> DatabaseManager:
"""Get the global database manager instance"""
global _db_manager_instance
if _db_manager_instance is None:
_db_manager_instance = DatabaseManager(db_path)
return _db_manager_instance
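Note that get_best_checkpoint_metadata() ranks rows with SQLite's json_extract() over the stored performance_metrics JSON, so any metric key that was saved can drive selection. A minimal sketch, with hypothetical ids and paths:

from datetime import datetime

from utils.database_manager import CheckpointMetadata, get_database_manager

db = get_database_manager()  # creates data/trading_system.db on first use

db.save_checkpoint_metadata(CheckpointMetadata(
    checkpoint_id="example_cnn_20250725_223413",   # hypothetical id
    model_name="example_cnn",
    model_type="cnn",
    timestamp=datetime.now(),
    performance_metrics={"accuracy": 0.58, "loss": 0.51},
    training_metadata={"epochs": 10},
    file_path="models/checkpoints/example_cnn/example_cnn_20250725_223413.pt",
    file_size_mb=12.3,
    is_active=True,
))

best = db.get_best_checkpoint_metadata("example_cnn", "accuracy")
if best:
    print(best.checkpoint_id, best.performance_metrics["accuracy"])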

utils/inference_logger.py (new file, 226 lines)

@@ -0,0 +1,226 @@
"""
Inference Logger
Centralized logging system for model inferences with database storage
Eliminates scattered logging throughout the codebase
"""
import time
import hashlib
import logging
import psutil
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Union
import numpy as np
from .database_manager import get_database_manager, InferenceRecord
from .text_logger import get_text_logger
logger = logging.getLogger(__name__)
class InferenceLogger:
"""Centralized inference logging system"""
def __init__(self):
self.db_manager = get_database_manager()
self.text_logger = get_text_logger()
self._process = psutil.Process()
def log_inference(self,
model_name: str,
symbol: str,
action: str,
confidence: float,
probabilities: Dict[str, float],
input_features: Union[np.ndarray, Dict, List],
processing_time_ms: float,
checkpoint_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None) -> bool:
"""
Log a model inference with all relevant details
Args:
model_name: Name of the model making the prediction
symbol: Trading symbol
action: Predicted action (BUY/SELL/HOLD)
confidence: Confidence score (0.0 to 1.0)
probabilities: Action probabilities dict
input_features: Input features used for prediction
processing_time_ms: Time taken for inference in milliseconds
checkpoint_id: ID of the checkpoint used
metadata: Additional metadata
Returns:
bool: True if logged successfully
"""
try:
# Create feature hash for deduplication
feature_hash = self._hash_features(input_features)
# Get current memory usage
memory_usage_mb = self._get_memory_usage()
# Create inference record
record = InferenceRecord(
model_name=model_name,
timestamp=datetime.now(),
symbol=symbol,
action=action,
confidence=confidence,
probabilities=probabilities,
input_features_hash=feature_hash,
processing_time_ms=processing_time_ms,
memory_usage_mb=memory_usage_mb,
checkpoint_id=checkpoint_id,
metadata=metadata
)
# Log to database
db_success = self.db_manager.log_inference(record)
# Log to text file
text_success = self.text_logger.log_inference(
model_name=model_name,
symbol=symbol,
action=action,
confidence=confidence,
processing_time_ms=processing_time_ms,
checkpoint_id=checkpoint_id
)
if db_success:
# Reduced logging - no more scattered logs at runtime
pass # Database logging successful, text file provides human-readable record
else:
logger.error(f"Failed to log inference for {model_name}")
return db_success and text_success
except Exception as e:
logger.error(f"Error logging inference: {e}")
return False
def _hash_features(self, features: Union[np.ndarray, Dict, List]) -> str:
"""Create a hash of input features for deduplication"""
try:
if isinstance(features, np.ndarray):
# Hash numpy array
return hashlib.md5(features.tobytes()).hexdigest()[:16]
elif isinstance(features, (dict, list)):
# Hash dict or list by converting to string
feature_str = str(sorted(features.items()) if isinstance(features, dict) else features)
return hashlib.md5(feature_str.encode()).hexdigest()[:16]
else:
# Hash string representation
return hashlib.md5(str(features).encode()).hexdigest()[:16]
except Exception:
# Fallback to timestamp-based hash
return hashlib.md5(str(time.time()).encode()).hexdigest()[:16]
def _get_memory_usage(self) -> float:
"""Get current memory usage in MB"""
try:
return self._process.memory_info().rss / (1024 * 1024)
except Exception:
return 0.0
def get_model_stats(self, model_name: str, hours: int = 24) -> Dict[str, Any]:
"""Get inference statistics for a model"""
try:
# Get recent inferences
recent_inferences = self.db_manager.get_recent_inferences(model_name, limit=1000)
if not recent_inferences:
return {
'total_inferences': 0,
'avg_confidence': 0.0,
'avg_processing_time_ms': 0.0,
'action_distribution': {},
'symbol_distribution': {}
}
# Filter by time window
cutoff_time = datetime.now() - timedelta(hours=hours)
recent_inferences = [r for r in recent_inferences if r.timestamp >= cutoff_time]
if not recent_inferences:
return {
'total_inferences': 0,
'avg_confidence': 0.0,
'avg_processing_time_ms': 0.0,
'action_distribution': {},
'symbol_distribution': {}
}
# Calculate statistics
total_inferences = len(recent_inferences)
avg_confidence = sum(r.confidence for r in recent_inferences) / total_inferences
avg_processing_time = sum(r.processing_time_ms for r in recent_inferences) / total_inferences
# Action distribution
action_counts = {}
for record in recent_inferences:
action_counts[record.action] = action_counts.get(record.action, 0) + 1
# Symbol distribution
symbol_counts = {}
for record in recent_inferences:
symbol_counts[record.symbol] = symbol_counts.get(record.symbol, 0) + 1
return {
'total_inferences': total_inferences,
'avg_confidence': avg_confidence,
'avg_processing_time_ms': avg_processing_time,
'action_distribution': action_counts,
'symbol_distribution': symbol_counts,
'latest_inference': recent_inferences[0].timestamp.isoformat() if recent_inferences else None
}
except Exception as e:
logger.error(f"Error getting model stats: {e}")
return {}
def cleanup_old_logs(self, days_to_keep: int = 30) -> bool:
"""Clean up old inference logs"""
return self.db_manager.cleanup_old_records(days_to_keep)
# Global inference logger instance
_inference_logger_instance = None
def get_inference_logger() -> InferenceLogger:
"""Get the global inference logger instance"""
global _inference_logger_instance
if _inference_logger_instance is None:
_inference_logger_instance = InferenceLogger()
return _inference_logger_instance
def log_model_inference(model_name: str,
symbol: str,
action: str,
confidence: float,
probabilities: Dict[str, float],
input_features: Union[np.ndarray, Dict, List],
processing_time_ms: float,
checkpoint_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None) -> bool:
"""
Convenience function to log model inference
This is the main function that should be called throughout the codebase
instead of scattered logger.info() calls
"""
inference_logger = get_inference_logger()
return inference_logger.log_inference(
model_name=model_name,
symbol=symbol,
action=action,
confidence=confidence,
probabilities=probabilities,
input_features=input_features,
processing_time_ms=processing_time_ms,
checkpoint_id=checkpoint_id,
metadata=metadata
)
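In practice, call sites replace ad-hoc logger.info() lines with a single call; the symbol and checkpoint id below are hypothetical:

import numpy as np

from utils.inference_logger import get_inference_logger, log_model_inference

features = np.random.rand(64).astype(np.float32)  # stand-in feature vector

log_model_inference(
    model_name="example_cnn",
    symbol="ETH/USDT",
    action="BUY",
    confidence=0.72,
    probabilities={"BUY": 0.72, "SELL": 0.08, "HOLD": 0.20},
    input_features=features,
    processing_time_ms=4.8,
    checkpoint_id="example_cnn_20250725_223413",
)

# Aggregate view over the last 24 hours, computed from the database records.
print(get_inference_logger().get_model_stats("example_cnn", hours=24))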

utils/text_logger.py (new file, 156 lines)

@@ -0,0 +1,156 @@
"""
Text File Logger for Trading System
Simple text file logging for tracking inference records and system events
Provides human-readable logs alongside database storage
"""
import os
import logging
from datetime import datetime
from typing import Optional
from pathlib import Path
logger = logging.getLogger(__name__)
class TextLogger:
"""Simple text file logger for trading system events"""
def __init__(self, log_dir: str = "logs"):
self.log_dir = Path(log_dir)
self.log_dir.mkdir(exist_ok=True)
# Create separate log files for different types of events
self.inference_log = self.log_dir / "inference_records.txt"
self.checkpoint_log = self.log_dir / "checkpoint_events.txt"
self.system_log = self.log_dir / "system_events.txt"
def log_inference(self, model_name: str, symbol: str, action: str,
confidence: float, processing_time_ms: float,
checkpoint_id: Optional[str] = None) -> bool:
"""Log inference record to text file"""
try:
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
checkpoint_info = f" [checkpoint: {checkpoint_id}]" if checkpoint_id else ""
log_entry = (
f"{timestamp} | {model_name:15} | {symbol:10} | "
f"{action:4} | conf={confidence:.3f} | "
f"time={processing_time_ms:6.1f}ms{checkpoint_info}\n"
)
with open(self.inference_log, 'a', encoding='utf-8') as f:
f.write(log_entry)
f.flush()
return True
except Exception as e:
logger.error(f"Failed to log inference to text file: {e}")
return False
def log_checkpoint_event(self, model_name: str, event_type: str,
checkpoint_id: str, details: str = "") -> bool:
"""Log checkpoint events (save, load, etc.)"""
try:
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
details_str = f" - {details}" if details else ""
log_entry = (
f"{timestamp} | {model_name:15} | {event_type:10} | "
f"{checkpoint_id}{details_str}\n"
)
with open(self.checkpoint_log, 'a', encoding='utf-8') as f:
f.write(log_entry)
f.flush()
return True
except Exception as e:
logger.error(f"Failed to log checkpoint event to text file: {e}")
return False
def log_system_event(self, event_type: str, message: str,
component: str = "system") -> bool:
"""Log general system events"""
try:
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
log_entry = (
f"{timestamp} | {component:15} | {event_type:10} | {message}\n"
)
with open(self.system_log, 'a', encoding='utf-8') as f:
f.write(log_entry)
f.flush()
return True
except Exception as e:
logger.error(f"Failed to log system event to text file: {e}")
return False
def get_recent_inferences(self, lines: int = 50) -> str:
"""Get recent inference records from text file"""
try:
if not self.inference_log.exists():
return "No inference records found"
with open(self.inference_log, 'r', encoding='utf-8') as f:
all_lines = f.readlines()
recent_lines = all_lines[-lines:] if len(all_lines) > lines else all_lines
return ''.join(recent_lines)
except Exception as e:
logger.error(f"Failed to read inference log: {e}")
return f"Error reading log: {e}"
def get_recent_checkpoint_events(self, lines: int = 20) -> str:
"""Get recent checkpoint events from text file"""
try:
if not self.checkpoint_log.exists():
return "No checkpoint events found"
with open(self.checkpoint_log, 'r', encoding='utf-8') as f:
all_lines = f.readlines()
recent_lines = all_lines[-lines:] if len(all_lines) > lines else all_lines
return ''.join(recent_lines)
except Exception as e:
logger.error(f"Failed to read checkpoint log: {e}")
return f"Error reading log: {e}"
def cleanup_old_logs(self, max_lines: int = 10000) -> bool:
"""Keep only the most recent log entries"""
try:
for log_file in [self.inference_log, self.checkpoint_log, self.system_log]:
if log_file.exists():
with open(log_file, 'r', encoding='utf-8') as f:
lines = f.readlines()
if len(lines) > max_lines:
# Keep only the most recent lines
recent_lines = lines[-max_lines:]
with open(log_file, 'w', encoding='utf-8') as f:
f.writelines(recent_lines)
logger.info(f"Cleaned up {log_file.name}: kept {len(recent_lines)} lines")
return True
except Exception as e:
logger.error(f"Failed to cleanup logs: {e}")
return False
# Global text logger instance
_text_logger_instance = None
def get_text_logger(log_dir: str = "logs") -> TextLogger:
"""Get the global text logger instance"""
global _text_logger_instance
if _text_logger_instance is None:
_text_logger_instance = TextLogger(log_dir)
return _text_logger_instance
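The resulting files are plain pipe-delimited text. A short sketch, with hypothetical component and checkpoint id values:

from utils.text_logger import get_text_logger

tl = get_text_logger()
tl.log_system_event("STARTUP", "dashboard started", component="dashboard")
tl.log_checkpoint_event("example_cnn", "LOADED",
                        "example_cnn_20250725_223413", details="loss=0.51")

# Tail the human-readable record in logs/inference_records.txt.
print(tl.get_recent_inferences(lines=5))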