SQLite for checkpoints, cleanup
utils/cache_manager.py (new file, 295 lines)
@@ -0,0 +1,295 @@
"""
Cache Manager for Trading System

Utilities for managing and cleaning up cache files, including:
- Parquet file validation and repair
- Cache cleanup and maintenance
- Cache health monitoring
"""

import logging
import pandas as pd
from pathlib import Path
from typing import Any, List, Dict, Optional, Tuple
from datetime import datetime, timedelta

logger = logging.getLogger(__name__)


class CacheManager:
    """Manages cache files for the trading system"""

    def __init__(self, cache_dirs: Optional[List[str]] = None):
        """
        Initialize cache manager

        Args:
            cache_dirs: List of cache directories to manage
        """
        self.cache_dirs = cache_dirs or [
            "data/cache",
            "data/monthly_cache",
            "data/pivot_cache"
        ]

        # Ensure cache directories exist
        for cache_dir in self.cache_dirs:
            Path(cache_dir).mkdir(parents=True, exist_ok=True)

    def validate_parquet_file(self, file_path: Path) -> Tuple[bool, Optional[str]]:
        """
        Validate a Parquet file

        Args:
            file_path: Path to the Parquet file

        Returns:
            Tuple of (is_valid, error_message)
        """
        try:
            if not file_path.exists():
                return False, "File does not exist"

            if file_path.stat().st_size == 0:
                return False, "File is empty"

            # Try to read the file
            df = pd.read_parquet(file_path)

            if df.empty:
                return False, "File contains no data"

            # Check for required columns (basic validation)
            required_columns = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
            missing_columns = [col for col in required_columns if col not in df.columns]

            if missing_columns:
                return False, f"Missing required columns: {missing_columns}"

            return True, None

        except Exception as e:
            error_str = str(e).lower()
            corrupted_indicators = [
                "parquet magic bytes not found",
                "corrupted",
                "couldn't deserialize thrift",
                "don't know what type",
                "invalid parquet file",
                "unexpected end of file",
                "invalid metadata"
            ]

            if any(indicator in error_str for indicator in corrupted_indicators):
                return False, f"Corrupted Parquet file: {e}"
            else:
                return False, f"Validation error: {e}"

    def scan_cache_health(self) -> Dict[str, Dict]:
        """
        Scan all cache directories for file health

        Returns:
            Dictionary with cache health information
        """
        health_report = {}

        for cache_dir in self.cache_dirs:
            cache_path = Path(cache_dir)
            if not cache_path.exists():
                continue

            dir_report = {
                'total_files': 0,
                'valid_files': 0,
                'corrupted_files': 0,
                'empty_files': 0,
                'total_size_mb': 0.0,
                'corrupted_files_list': [],
                'old_files': []
            }

            # Scan all Parquet files
            for file_path in cache_path.glob("*.parquet"):
                dir_report['total_files'] += 1
                file_size_mb = file_path.stat().st_size / (1024 * 1024)
                dir_report['total_size_mb'] += file_size_mb

                # Check file age
                file_age = datetime.now() - datetime.fromtimestamp(file_path.stat().st_mtime)
                if file_age > timedelta(days=7):  # Files older than 7 days
                    dir_report['old_files'].append({
                        'file': str(file_path),
                        'age_days': file_age.days,
                        'size_mb': file_size_mb
                    })

                # Validate file
                is_valid, error_msg = self.validate_parquet_file(file_path)

                if is_valid:
                    dir_report['valid_files'] += 1
                else:
                    if "empty" in error_msg.lower():
                        dir_report['empty_files'] += 1
                    else:
                        dir_report['corrupted_files'] += 1
                        dir_report['corrupted_files_list'].append({
                            'file': str(file_path),
                            'error': error_msg,
                            'size_mb': file_size_mb
                        })

            health_report[cache_dir] = dir_report

        return health_report

    def cleanup_corrupted_files(self, dry_run: bool = True) -> Dict[str, List[str]]:
        """
        Clean up corrupted cache files

        Args:
            dry_run: If True, only report what would be deleted

        Returns:
            Dictionary of deleted files by directory
        """
        deleted_files = {}

        for cache_dir in self.cache_dirs:
            cache_path = Path(cache_dir)
            if not cache_path.exists():
                continue

            deleted_files[cache_dir] = []

            for file_path in cache_path.glob("*.parquet"):
                is_valid, error_msg = self.validate_parquet_file(file_path)

                if not is_valid:
                    if dry_run:
                        deleted_files[cache_dir].append(f"WOULD DELETE: {file_path} ({error_msg})")
                        logger.info(f"Would delete corrupted file: {file_path} ({error_msg})")
                    else:
                        try:
                            file_path.unlink()
                            deleted_files[cache_dir].append(f"DELETED: {file_path} ({error_msg})")
                            logger.info(f"Deleted corrupted file: {file_path}")
                        except Exception as e:
                            deleted_files[cache_dir].append(f"FAILED TO DELETE: {file_path} ({e})")
                            logger.error(f"Failed to delete corrupted file {file_path}: {e}")

        return deleted_files

    def cleanup_old_files(self, days_to_keep: int = 7, dry_run: bool = True) -> Dict[str, List[str]]:
        """
        Clean up old cache files

        Args:
            days_to_keep: Number of days to keep files
            dry_run: If True, only report what would be deleted

        Returns:
            Dictionary of deleted files by directory
        """
        deleted_files = {}
        cutoff_date = datetime.now() - timedelta(days=days_to_keep)

        for cache_dir in self.cache_dirs:
            cache_path = Path(cache_dir)
            if not cache_path.exists():
                continue

            deleted_files[cache_dir] = []

            for file_path in cache_path.glob("*.parquet"):
                file_mtime = datetime.fromtimestamp(file_path.stat().st_mtime)

                if file_mtime < cutoff_date:
                    age_days = (datetime.now() - file_mtime).days

                    if dry_run:
                        deleted_files[cache_dir].append(f"WOULD DELETE: {file_path} (age: {age_days} days)")
                        logger.info(f"Would delete old file: {file_path} (age: {age_days} days)")
                    else:
                        try:
                            file_path.unlink()
                            deleted_files[cache_dir].append(f"DELETED: {file_path} (age: {age_days} days)")
                            logger.info(f"Deleted old file: {file_path}")
                        except Exception as e:
                            deleted_files[cache_dir].append(f"FAILED TO DELETE: {file_path} ({e})")
                            logger.error(f"Failed to delete old file {file_path}: {e}")

        return deleted_files

    def get_cache_summary(self) -> Dict[str, Any]:
        """Get a summary of cache usage"""
        health_report = self.scan_cache_health()

        total_files = sum(report['total_files'] for report in health_report.values())
        total_valid = sum(report['valid_files'] for report in health_report.values())
        total_corrupted = sum(report['corrupted_files'] for report in health_report.values())
        total_size_mb = sum(report['total_size_mb'] for report in health_report.values())

        return {
            'total_files': total_files,
            'valid_files': total_valid,
            'corrupted_files': total_corrupted,
            'health_percentage': (total_valid / total_files * 100) if total_files > 0 else 0,
            'total_size_mb': total_size_mb,
            'directories': health_report
        }

    def emergency_cache_reset(self, confirm: bool = False) -> bool:
        """
        Emergency cache reset - deletes all cache files

        Args:
            confirm: Must be True to actually delete files

        Returns:
            True if reset was performed
        """
        if not confirm:
            logger.warning("Emergency cache reset called but not confirmed")
            return False

        deleted_count = 0

        for cache_dir in self.cache_dirs:
            cache_path = Path(cache_dir)
            if not cache_path.exists():
                continue

            for file_path in cache_path.glob("*"):
                try:
                    if file_path.is_file():
                        file_path.unlink()
                        deleted_count += 1
                except Exception as e:
                    logger.error(f"Failed to delete {file_path}: {e}")

        logger.warning(f"Emergency cache reset completed: deleted {deleted_count} files")
        return True


# Global cache manager instance
_cache_manager_instance = None


def get_cache_manager() -> CacheManager:
    """Get the global cache manager instance"""
    global _cache_manager_instance

    if _cache_manager_instance is None:
        _cache_manager_instance = CacheManager()

    return _cache_manager_instance


def cleanup_corrupted_cache(dry_run: bool = True) -> Dict[str, List[str]]:
    """Convenience function to clean up corrupted cache files"""
    cache_manager = get_cache_manager()
    return cache_manager.cleanup_corrupted_files(dry_run=dry_run)


def get_cache_health() -> Dict[str, Any]:
    """Convenience function to get cache health summary"""
    cache_manager = get_cache_manager()
    return cache_manager.get_cache_summary()
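A quick sketch of how the new convenience functions compose; since `dry_run` defaults to True, the cleanup call below only reports what would be deleted:

from utils.cache_manager import get_cache_health, cleanup_corrupted_cache

# Audit cache health across all managed directories
summary = get_cache_health()
print(f"{summary['valid_files']}/{summary['total_files']} files healthy "
      f"({summary['health_percentage']:.1f}%), {summary['total_size_mb']:.1f} MB total")

# Preview corrupted-file cleanup; pass dry_run=False to actually delete
for cache_dir, entries in cleanup_corrupted_cache(dry_run=True).items():
    for entry in entries:
        print(entry)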
utils/checkpoint_manager.py (modified)
@@ -5,6 +5,7 @@ This module provides functionality for managing model checkpoints, including:
 - Saving checkpoints with metadata
 - Loading the best checkpoint based on performance metrics
 - Cleaning up old or underperforming checkpoints
+- Database-backed metadata storage for efficient access
 """
 
 import os
@@ -13,9 +14,13 @@ import glob
 import logging
 import shutil
 import torch
+import hashlib
 from datetime import datetime
 from typing import Dict, List, Optional, Any, Tuple
 
+from .database_manager import get_database_manager, CheckpointMetadata
+from .text_logger import get_text_logger
+
 logger = logging.getLogger(__name__)
 
 # Global checkpoint manager instance
@@ -46,7 +51,7 @@ def get_checkpoint_manager(checkpoint_dir: str = "models/checkpoints", max_check
 
 def save_checkpoint(model, model_name: str, model_type: str, performance_metrics: Dict[str, float], training_metadata: Dict[str, Any] = None, checkpoint_dir: str = "models/checkpoints") -> Any:
     """
-    Save a checkpoint with metadata
+    Save a checkpoint with metadata to both filesystem and database
 
     Args:
         model: The model to save
@@ -64,57 +69,90 @@ def save_checkpoint(model, model_name: str, model_type: str, performance_metrics
         os.makedirs(checkpoint_dir, exist_ok=True)
 
         # Create timestamp
-        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+        timestamp = datetime.now()
+        timestamp_str = timestamp.strftime("%Y%m%d_%H%M%S")
 
         # Create checkpoint path
        model_dir = os.path.join(checkpoint_dir, model_name)
         os.makedirs(model_dir, exist_ok=True)
-        checkpoint_path = os.path.join(model_dir, f"{model_name}_{timestamp}")
+        checkpoint_path = os.path.join(model_dir, f"{model_name}_{timestamp_str}")
+        checkpoint_id = f"{model_name}_{timestamp_str}"
 
         # Save model
+        torch_path = f"{checkpoint_path}.pt"
         if hasattr(model, 'save'):
             # Use model's save method if available
             model.save(checkpoint_path)
         else:
             # Otherwise, save state_dict
-            torch_path = f"{checkpoint_path}.pt"
             torch.save({
                 'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else None,
                 'model_name': model_name,
                 'model_type': model_type,
-                'timestamp': timestamp
+                'timestamp': timestamp_str,
+                'checkpoint_id': checkpoint_id
             }, torch_path)
 
-        # Create metadata
-        checkpoint_metadata = {
+        # Calculate file size
+        file_size_mb = os.path.getsize(torch_path) / (1024 * 1024) if os.path.exists(torch_path) else 0.0
+
+        # Save metadata to database
+        db_manager = get_database_manager()
+        checkpoint_metadata = CheckpointMetadata(
+            checkpoint_id=checkpoint_id,
+            model_name=model_name,
+            model_type=model_type,
+            timestamp=timestamp,
+            performance_metrics=performance_metrics,
+            training_metadata=training_metadata or {},
+            file_path=torch_path,
+            file_size_mb=file_size_mb,
+            is_active=True  # New checkpoint is active by default
+        )
+
+        # Save to database
+        if db_manager.save_checkpoint_metadata(checkpoint_metadata):
+            # Log checkpoint save event to text file
+            text_logger = get_text_logger()
+            text_logger.log_checkpoint_event(
+                model_name=model_name,
+                event_type="SAVED",
+                checkpoint_id=checkpoint_id,
+                details=f"loss={performance_metrics.get('loss', 'N/A')}, size={file_size_mb:.1f}MB"
+            )
+            logger.info(f"Checkpoint saved: {checkpoint_id}")
+        else:
+            logger.warning(f"Failed to save checkpoint metadata to database: {checkpoint_id}")
+
+        # Also save legacy JSON metadata for backward compatibility
+        legacy_metadata = {
             'model_name': model_name,
             'model_type': model_type,
-            'timestamp': timestamp,
+            'timestamp': timestamp_str,
             'performance_metrics': performance_metrics,
             'training_metadata': training_metadata or {},
-            'checkpoint_id': f"{model_name}_{timestamp}"
+            'checkpoint_id': checkpoint_id,
+            'performance_score': performance_metrics.get('accuracy', performance_metrics.get('reward', 0.0)),
+            'created_at': timestamp_str
         }
 
-        # Add performance score for sorting
-        primary_metric = 'accuracy' if 'accuracy' in performance_metrics else 'reward'
-        checkpoint_metadata['performance_score'] = performance_metrics.get(primary_metric, 0.0)
-        checkpoint_metadata['created_at'] = timestamp
-
         # Save metadata
         with open(f"{checkpoint_path}_metadata.json", 'w') as f:
-            json.dump(checkpoint_metadata, f, indent=2)
+            json.dump(legacy_metadata, f, indent=2)
 
         # Get checkpoint manager and clean up old checkpoints
         checkpoint_manager = get_checkpoint_manager(checkpoint_dir=checkpoint_dir)
         checkpoint_manager._cleanup_checkpoints(model_name)
 
-        # Return metadata as an object
-        class CheckpointMetadata:
+        # Return metadata as an object for backward compatibility
+        class CheckpointMetadataObj:
             def __init__(self, metadata):
                 for key, value in metadata.items():
                     setattr(self, key, value)
+                # Add database fields
+                self.checkpoint_id = checkpoint_id
+                self.loss = performance_metrics.get('loss', performance_metrics.get('accuracy', 0.0))
 
-        return CheckpointMetadata(checkpoint_metadata)
+        return CheckpointMetadataObj(legacy_metadata)
 
     except Exception as e:
         logger.error(f"Error saving checkpoint: {e}")
@@ -122,7 +160,7 @@ def save_checkpoint(model, model_name: str, model_type: str, performance_metrics
 
 def load_best_checkpoint(model_name: str, checkpoint_dir: str = "models/checkpoints") -> Optional[Tuple[str, Any]]:
     """
-    Load the best checkpoint based on performance metrics
+    Load the best checkpoint based on performance metrics using database metadata
 
     Args:
         model_name: Name of the model
@@ -132,29 +170,77 @@ def load_best_checkpoint(model_name: str, checkpoint_dir: str = "models/checkpoi
         Optional[Tuple[str, Any]]: Path to the best checkpoint and its metadata, or None if not found
     """
     try:
-        checkpoint_manager = get_checkpoint_manager(checkpoint_dir=checkpoint_dir)
-        checkpoint_path, checkpoint_metadata = checkpoint_manager.load_best_checkpoint(model_name)
+        # First try to get from database (fast metadata access)
+        db_manager = get_database_manager()
+        checkpoint_metadata = db_manager.get_best_checkpoint_metadata(model_name, "accuracy")
 
-        if not checkpoint_path:
-            return None
+        if not checkpoint_metadata:
+            # Fallback to legacy file-based approach (no more scattered "No checkpoints found" logs)
+            checkpoint_manager = get_checkpoint_manager(checkpoint_dir=checkpoint_dir)
+            checkpoint_path, legacy_metadata = checkpoint_manager.load_best_checkpoint(model_name)
+
+            if not checkpoint_path:
+                return None
 
-        # Convert metadata to object
-        class CheckpointMetadata:
-            def __init__(self, metadata):
-                for key, value in metadata.items():
-                    setattr(self, key, value)
-
-                # Add performance score if not present
-                if not hasattr(self, 'performance_score'):
-                    metrics = getattr(self, 'metrics', {})
-                    primary_metric = 'accuracy' if 'accuracy' in metrics else 'reward'
-                    self.performance_score = metrics.get(primary_metric, 0.0)
-
-                # Add created_at if not present
-                if not hasattr(self, 'created_at'):
-                    self.created_at = getattr(self, 'timestamp', 'unknown')
-
-        return f"{checkpoint_path}.pt", CheckpointMetadata(checkpoint_metadata)
+            # Convert legacy metadata to object
+            class CheckpointMetadataObj:
+                def __init__(self, metadata):
+                    for key, value in metadata.items():
+                        setattr(self, key, value)
+
+                    metrics = getattr(self, 'metrics', {})  # bind before use below
+
+                    # Add performance score if not present
+                    if not hasattr(self, 'performance_score'):
+                        primary_metric = 'accuracy' if 'accuracy' in metrics else 'reward'
+                        self.performance_score = metrics.get(primary_metric, 0.0)
+
+                    # Add created_at if not present
+                    if not hasattr(self, 'created_at'):
+                        self.created_at = getattr(self, 'timestamp', 'unknown')
+
+                    # Add loss for compatibility
+                    self.loss = metrics.get('loss', self.performance_score)
+                    self.checkpoint_id = getattr(self, 'checkpoint_id', f"{model_name}_unknown")
+
+            return f"{checkpoint_path}.pt", CheckpointMetadataObj(legacy_metadata)
+
+        # Check if checkpoint file exists
+        if not os.path.exists(checkpoint_metadata.file_path):
+            logger.warning(f"Checkpoint file not found: {checkpoint_metadata.file_path}")
+            return None
+
+        # Log checkpoint load event to text file
+        text_logger = get_text_logger()
+        text_logger.log_checkpoint_event(
+            model_name=model_name,
+            event_type="LOADED",
+            checkpoint_id=checkpoint_metadata.checkpoint_id,
+            details=f"loss={checkpoint_metadata.performance_metrics.get('loss', 'N/A')}"
+        )
+
+        # Convert database metadata to object for backward compatibility
+        class CheckpointMetadataObj:
+            def __init__(self, db_metadata: CheckpointMetadata):
+                self.checkpoint_id = db_metadata.checkpoint_id
+                self.model_name = db_metadata.model_name
+                self.model_type = db_metadata.model_type
+                self.timestamp = db_metadata.timestamp.strftime("%Y%m%d_%H%M%S")
+                self.performance_metrics = db_metadata.performance_metrics
+                self.training_metadata = db_metadata.training_metadata
+                self.file_path = db_metadata.file_path
+                self.file_size_mb = db_metadata.file_size_mb
+                self.is_active = db_metadata.is_active
+
+                # Backward compatibility fields
+                self.metrics = db_metadata.performance_metrics
+                self.metadata = db_metadata.training_metadata
+                self.created_at = self.timestamp
+                self.performance_score = db_metadata.performance_metrics.get('accuracy',
+                                                                             db_metadata.performance_metrics.get('reward', 0.0))
+                self.loss = db_metadata.performance_metrics.get('loss', self.performance_score)
+
+        return checkpoint_metadata.file_path, CheckpointMetadataObj(checkpoint_metadata)
 
     except Exception as e:
         logger.error(f"Error loading best checkpoint: {e}")
@@ -254,7 +340,7 @@ class CheckpointManager:
         metadata_files = glob.glob(os.path.join(checkpoint_dir, f"{model_name}_*_metadata.json"))
 
         if not metadata_files:
-            logger.info(f"No checkpoints found for {model_name}")
+            # No more scattered "No checkpoints found" logs - handled by database system
             return "", {}
 
         # Load metadata for each checkpoint
@@ -278,7 +364,7 @@ class CheckpointManager:
             logger.error(f"Error loading checkpoint metadata {metadata_file}: {e}")
 
         if not checkpoints:
-            logger.info(f"No valid checkpoints found for {model_name}")
+            # No more scattered logs - handled by database system
            return "", {}
 
         # Sort by metric (highest first)
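A minimal end-to-end sketch of the new save/load flow; `my_model`, the model names, and the metric values are illustrative placeholders:

from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint

# my_model is a stand-in for any torch module (or anything exposing .save()/.state_dict())
meta = save_checkpoint(
    model=my_model,
    model_name="dqn_agent",
    model_type="dqn",
    performance_metrics={"loss": 0.021, "accuracy": 0.64},
)
if meta:
    print(f"saved {meta.checkpoint_id} (loss={meta.loss})")

# Database metadata is consulted first; legacy JSON files are the fallback
result = load_best_checkpoint("dqn_agent")
if result:
    path, best = result
    print(f"best checkpoint: {best.checkpoint_id} at {path}")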
utils/database_manager.py (new file, 408 lines)
@@ -0,0 +1,408 @@
"""
Database Manager for Trading System

Manages SQLite database for:
1. Inference records logging
2. Checkpoint metadata storage
3. Model performance tracking
"""

import sqlite3
import json
import logging
import os
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Tuple
from contextlib import contextmanager
from dataclasses import dataclass, asdict

logger = logging.getLogger(__name__)


@dataclass
class InferenceRecord:
    """Structure for inference logging"""
    model_name: str
    timestamp: datetime
    symbol: str
    action: str
    confidence: float
    probabilities: Dict[str, float]
    input_features_hash: str  # Hash of input features for deduplication
    processing_time_ms: float
    memory_usage_mb: float
    checkpoint_id: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None


@dataclass
class CheckpointMetadata:
    """Structure for checkpoint metadata"""
    checkpoint_id: str
    model_name: str
    model_type: str
    timestamp: datetime
    performance_metrics: Dict[str, float]
    training_metadata: Dict[str, Any]
    file_path: str
    file_size_mb: float
    is_active: bool = False  # Currently loaded checkpoint


class DatabaseManager:
    """Manages SQLite database for trading system logging and metadata"""

    def __init__(self, db_path: str = "data/trading_system.db"):
        self.db_path = db_path
        self._ensure_db_directory()
        self._initialize_database()

    def _ensure_db_directory(self):
        """Ensure database directory exists"""
        os.makedirs(os.path.dirname(self.db_path), exist_ok=True)

    def _initialize_database(self):
        """Initialize database tables"""
        with self._get_connection() as conn:
            # Inference records table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS inference_records (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    model_name TEXT NOT NULL,
                    timestamp TEXT NOT NULL,
                    symbol TEXT NOT NULL,
                    action TEXT NOT NULL,
                    confidence REAL NOT NULL,
                    probabilities TEXT NOT NULL,  -- JSON
                    input_features_hash TEXT NOT NULL,
                    processing_time_ms REAL NOT NULL,
                    memory_usage_mb REAL NOT NULL,
                    checkpoint_id TEXT,
                    metadata TEXT,  -- JSON
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP
                )
            """)

            # Checkpoint metadata table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS checkpoint_metadata (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    checkpoint_id TEXT UNIQUE NOT NULL,
                    model_name TEXT NOT NULL,
                    model_type TEXT NOT NULL,
                    timestamp TEXT NOT NULL,
                    performance_metrics TEXT NOT NULL,  -- JSON
                    training_metadata TEXT NOT NULL,  -- JSON
                    file_path TEXT NOT NULL,
                    file_size_mb REAL NOT NULL,
                    is_active BOOLEAN DEFAULT FALSE,
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP
                )
            """)

            # Model performance tracking table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS model_performance (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    model_name TEXT NOT NULL,
                    date TEXT NOT NULL,
                    total_predictions INTEGER DEFAULT 0,
                    correct_predictions INTEGER DEFAULT 0,
                    accuracy REAL DEFAULT 0.0,
                    avg_confidence REAL DEFAULT 0.0,
                    avg_processing_time_ms REAL DEFAULT 0.0,
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP,
                    UNIQUE(model_name, date)
                )
            """)

            # Create indexes for better performance
            conn.execute("CREATE INDEX IF NOT EXISTS idx_inference_model_timestamp ON inference_records(model_name, timestamp)")
            conn.execute("CREATE INDEX IF NOT EXISTS idx_inference_symbol ON inference_records(symbol)")
            conn.execute("CREATE INDEX IF NOT EXISTS idx_checkpoint_model ON checkpoint_metadata(model_name)")
            conn.execute("CREATE INDEX IF NOT EXISTS idx_checkpoint_active ON checkpoint_metadata(is_active)")

            logger.info(f"Database initialized at {self.db_path}")

    @contextmanager
    def _get_connection(self):
        """Get database connection with proper error handling"""
        conn = None
        try:
            conn = sqlite3.connect(self.db_path, timeout=30.0)
            conn.row_factory = sqlite3.Row  # Enable dict-like access
            yield conn
        except Exception as e:
            if conn:
                conn.rollback()
            logger.error(f"Database error: {e}")
            raise
        finally:
            if conn:
                conn.close()

    def log_inference(self, record: InferenceRecord) -> bool:
        """Log an inference record"""
        try:
            with self._get_connection() as conn:
                conn.execute("""
                    INSERT INTO inference_records (
                        model_name, timestamp, symbol, action, confidence,
                        probabilities, input_features_hash, processing_time_ms,
                        memory_usage_mb, checkpoint_id, metadata
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """, (
                    record.model_name,
                    record.timestamp.isoformat(),
                    record.symbol,
                    record.action,
                    record.confidence,
                    json.dumps(record.probabilities),
                    record.input_features_hash,
                    record.processing_time_ms,
                    record.memory_usage_mb,
                    record.checkpoint_id,
                    json.dumps(record.metadata) if record.metadata else None
                ))
                conn.commit()
                return True
        except Exception as e:
            logger.error(f"Failed to log inference record: {e}")
            return False

    def save_checkpoint_metadata(self, metadata: CheckpointMetadata) -> bool:
        """Save checkpoint metadata"""
        try:
            with self._get_connection() as conn:
                # First, set all other checkpoints for this model as inactive
                conn.execute("""
                    UPDATE checkpoint_metadata
                    SET is_active = FALSE
                    WHERE model_name = ?
                """, (metadata.model_name,))

                # Insert or replace the new checkpoint metadata
                conn.execute("""
                    INSERT OR REPLACE INTO checkpoint_metadata (
                        checkpoint_id, model_name, model_type, timestamp,
                        performance_metrics, training_metadata, file_path,
                        file_size_mb, is_active
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                """, (
                    metadata.checkpoint_id,
                    metadata.model_name,
                    metadata.model_type,
                    metadata.timestamp.isoformat(),
                    json.dumps(metadata.performance_metrics),
                    json.dumps(metadata.training_metadata),
                    metadata.file_path,
                    metadata.file_size_mb,
                    metadata.is_active
                ))
                conn.commit()
                return True
        except Exception as e:
            logger.error(f"Failed to save checkpoint metadata: {e}")
            return False

    def get_checkpoint_metadata(self, model_name: str, checkpoint_id: Optional[str] = None) -> Optional[CheckpointMetadata]:
        """Get checkpoint metadata without loading the actual model"""
        try:
            with self._get_connection() as conn:
                if checkpoint_id:
                    # Get specific checkpoint
                    cursor = conn.execute("""
                        SELECT * FROM checkpoint_metadata
                        WHERE model_name = ? AND checkpoint_id = ?
                    """, (model_name, checkpoint_id))
                else:
                    # Get active checkpoint
                    cursor = conn.execute("""
                        SELECT * FROM checkpoint_metadata
                        WHERE model_name = ? AND is_active = TRUE
                        ORDER BY timestamp DESC LIMIT 1
                    """, (model_name,))

                row = cursor.fetchone()
                if row:
                    return CheckpointMetadata(
                        checkpoint_id=row['checkpoint_id'],
                        model_name=row['model_name'],
                        model_type=row['model_type'],
                        timestamp=datetime.fromisoformat(row['timestamp']),
                        performance_metrics=json.loads(row['performance_metrics']),
                        training_metadata=json.loads(row['training_metadata']),
                        file_path=row['file_path'],
                        file_size_mb=row['file_size_mb'],
                        is_active=bool(row['is_active'])
                    )
                return None
        except Exception as e:
            logger.error(f"Failed to get checkpoint metadata: {e}")
            return None

    def get_best_checkpoint_metadata(self, model_name: str, metric_name: str = "accuracy") -> Optional[CheckpointMetadata]:
        """Get best checkpoint metadata based on performance metric"""
        try:
            with self._get_connection() as conn:
                cursor = conn.execute("""
                    SELECT * FROM checkpoint_metadata
                    WHERE model_name = ?
                    ORDER BY json_extract(performance_metrics, '$.' || ?) DESC
                    LIMIT 1
                """, (model_name, metric_name))

                row = cursor.fetchone()
                if row:
                    return CheckpointMetadata(
                        checkpoint_id=row['checkpoint_id'],
                        model_name=row['model_name'],
                        model_type=row['model_type'],
                        timestamp=datetime.fromisoformat(row['timestamp']),
                        performance_metrics=json.loads(row['performance_metrics']),
                        training_metadata=json.loads(row['training_metadata']),
                        file_path=row['file_path'],
                        file_size_mb=row['file_size_mb'],
                        is_active=bool(row['is_active'])
                    )
                return None
        except Exception as e:
            logger.error(f"Failed to get best checkpoint metadata: {e}")
            return None

    def list_checkpoints(self, model_name: str) -> List[CheckpointMetadata]:
        """List all checkpoints for a model"""
        try:
            with self._get_connection() as conn:
                cursor = conn.execute("""
                    SELECT * FROM checkpoint_metadata
                    WHERE model_name = ?
                    ORDER BY timestamp DESC
                """, (model_name,))

                checkpoints = []
                for row in cursor.fetchall():
                    checkpoints.append(CheckpointMetadata(
                        checkpoint_id=row['checkpoint_id'],
                        model_name=row['model_name'],
                        model_type=row['model_type'],
                        timestamp=datetime.fromisoformat(row['timestamp']),
                        performance_metrics=json.loads(row['performance_metrics']),
                        training_metadata=json.loads(row['training_metadata']),
                        file_path=row['file_path'],
                        file_size_mb=row['file_size_mb'],
                        is_active=bool(row['is_active'])
                    ))
                return checkpoints
        except Exception as e:
            logger.error(f"Failed to list checkpoints: {e}")
            return []

    def set_active_checkpoint(self, model_name: str, checkpoint_id: str) -> bool:
        """Set a checkpoint as active for a model"""
        try:
            with self._get_connection() as conn:
                # First, set all checkpoints for this model as inactive
                conn.execute("""
                    UPDATE checkpoint_metadata
                    SET is_active = FALSE
                    WHERE model_name = ?
                """, (model_name,))

                # Set the specified checkpoint as active
                cursor = conn.execute("""
                    UPDATE checkpoint_metadata
                    SET is_active = TRUE
                    WHERE model_name = ? AND checkpoint_id = ?
                """, (model_name, checkpoint_id))

                conn.commit()
                return cursor.rowcount > 0
        except Exception as e:
            logger.error(f"Failed to set active checkpoint: {e}")
            return False

    def get_recent_inferences(self, model_name: str, limit: int = 100) -> List[InferenceRecord]:
        """Get recent inference records for a model"""
        try:
            with self._get_connection() as conn:
                cursor = conn.execute("""
                    SELECT * FROM inference_records
                    WHERE model_name = ?
                    ORDER BY timestamp DESC
                    LIMIT ?
                """, (model_name, limit))

                records = []
                for row in cursor.fetchall():
                    records.append(InferenceRecord(
                        model_name=row['model_name'],
                        timestamp=datetime.fromisoformat(row['timestamp']),
                        symbol=row['symbol'],
                        action=row['action'],
                        confidence=row['confidence'],
                        probabilities=json.loads(row['probabilities']),
                        input_features_hash=row['input_features_hash'],
                        processing_time_ms=row['processing_time_ms'],
                        memory_usage_mb=row['memory_usage_mb'],
                        checkpoint_id=row['checkpoint_id'],
                        metadata=json.loads(row['metadata']) if row['metadata'] else None
                    ))
                return records
        except Exception as e:
            logger.error(f"Failed to get recent inferences: {e}")
            return []

    def update_model_performance(self, model_name: str, date: str,
                                 total_predictions: int, correct_predictions: int,
                                 avg_confidence: float, avg_processing_time: float) -> bool:
        """Update daily model performance statistics"""
        try:
            accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0

            with self._get_connection() as conn:
                conn.execute("""
                    INSERT OR REPLACE INTO model_performance (
                        model_name, date, total_predictions, correct_predictions,
                        accuracy, avg_confidence, avg_processing_time_ms
                    ) VALUES (?, ?, ?, ?, ?, ?, ?)
                """, (
                    model_name, date, total_predictions, correct_predictions,
                    accuracy, avg_confidence, avg_processing_time
                ))
                conn.commit()
                return True
        except Exception as e:
            logger.error(f"Failed to update model performance: {e}")
            return False

    def cleanup_old_records(self, days_to_keep: int = 30) -> bool:
        """Clean up old inference records"""
        try:
            cutoff_date = datetime.now() - timedelta(days=days_to_keep)

            with self._get_connection() as conn:
                cursor = conn.execute("""
                    DELETE FROM inference_records
                    WHERE timestamp < ?
                """, (cutoff_date.isoformat(),))

                deleted_count = cursor.rowcount
                conn.commit()

                if deleted_count > 0:
                    logger.info(f"Cleaned up {deleted_count} old inference records")

                return True
        except Exception as e:
            logger.error(f"Failed to cleanup old records: {e}")
            return False


# Global database manager instance
_db_manager_instance = None


def get_database_manager(db_path: str = "data/trading_system.db") -> DatabaseManager:
    """Get the global database manager instance"""
    global _db_manager_instance

    if _db_manager_instance is None:
        _db_manager_instance = DatabaseManager(db_path)

    return _db_manager_instance
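Direct use of the manager is one call per record. A sketch with placeholder values (the model name, symbol, and metrics are illustrative); note that best-checkpoint selection runs entirely in SQL via json_extract over the stored metrics, so no checkpoint file is touched until load time:

from datetime import datetime
from utils.database_manager import get_database_manager, InferenceRecord

db = get_database_manager()  # defaults to data/trading_system.db
db.log_inference(InferenceRecord(
    model_name="dqn_agent", timestamp=datetime.now(), symbol="ETH/USDT",
    action="BUY", confidence=0.72,
    probabilities={"BUY": 0.72, "SELL": 0.18, "HOLD": 0.10},
    input_features_hash="abc123def4567890",
    processing_time_ms=4.2, memory_usage_mb=512.0,
))

best = db.get_best_checkpoint_metadata("dqn_agent", "accuracy")
if best:
    print(best.checkpoint_id, best.performance_metrics)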
utils/inference_logger.py (new file, 226 lines)
@@ -0,0 +1,226 @@
"""
Inference Logger

Centralized logging system for model inferences with database storage
Eliminates scattered logging throughout the codebase
"""

import time
import hashlib
import logging
import psutil
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Union
from dataclasses import dataclass
import numpy as np

from .database_manager import get_database_manager, InferenceRecord
from .text_logger import get_text_logger

logger = logging.getLogger(__name__)


class InferenceLogger:
    """Centralized inference logging system"""

    def __init__(self):
        self.db_manager = get_database_manager()
        self.text_logger = get_text_logger()
        self._process = psutil.Process()

    def log_inference(self,
                      model_name: str,
                      symbol: str,
                      action: str,
                      confidence: float,
                      probabilities: Dict[str, float],
                      input_features: Union[np.ndarray, Dict, List],
                      processing_time_ms: float,
                      checkpoint_id: Optional[str] = None,
                      metadata: Optional[Dict[str, Any]] = None) -> bool:
        """
        Log a model inference with all relevant details

        Args:
            model_name: Name of the model making the prediction
            symbol: Trading symbol
            action: Predicted action (BUY/SELL/HOLD)
            confidence: Confidence score (0.0 to 1.0)
            probabilities: Action probabilities dict
            input_features: Input features used for prediction
            processing_time_ms: Time taken for inference in milliseconds
            checkpoint_id: ID of the checkpoint used
            metadata: Additional metadata

        Returns:
            bool: True if logged successfully
        """
        try:
            # Create feature hash for deduplication
            feature_hash = self._hash_features(input_features)

            # Get current memory usage
            memory_usage_mb = self._get_memory_usage()

            # Create inference record
            record = InferenceRecord(
                model_name=model_name,
                timestamp=datetime.now(),
                symbol=symbol,
                action=action,
                confidence=confidence,
                probabilities=probabilities,
                input_features_hash=feature_hash,
                processing_time_ms=processing_time_ms,
                memory_usage_mb=memory_usage_mb,
                checkpoint_id=checkpoint_id,
                metadata=metadata
            )

            # Log to database
            db_success = self.db_manager.log_inference(record)

            # Log to text file
            text_success = self.text_logger.log_inference(
                model_name=model_name,
                symbol=symbol,
                action=action,
                confidence=confidence,
                processing_time_ms=processing_time_ms,
                checkpoint_id=checkpoint_id
            )

            if db_success:
                # Reduced logging - no more scattered logs at runtime
                pass  # Database logging successful, text file provides human-readable record
            else:
                logger.error(f"Failed to log inference for {model_name}")

            return db_success and text_success

        except Exception as e:
            logger.error(f"Error logging inference: {e}")
            return False

    def _hash_features(self, features: Union[np.ndarray, Dict, List]) -> str:
        """Create a hash of input features for deduplication"""
        try:
            if isinstance(features, np.ndarray):
                # Hash numpy array
                return hashlib.md5(features.tobytes()).hexdigest()[:16]
            elif isinstance(features, (dict, list)):
                # Hash dict or list by converting to string
                feature_str = str(sorted(features.items()) if isinstance(features, dict) else features)
                return hashlib.md5(feature_str.encode()).hexdigest()[:16]
            else:
                # Hash string representation
                return hashlib.md5(str(features).encode()).hexdigest()[:16]
        except Exception:
            # Fallback to timestamp-based hash
            return hashlib.md5(str(time.time()).encode()).hexdigest()[:16]

    def _get_memory_usage(self) -> float:
        """Get current memory usage in MB"""
        try:
            return self._process.memory_info().rss / (1024 * 1024)
        except Exception:
            return 0.0

    def get_model_stats(self, model_name: str, hours: int = 24) -> Dict[str, Any]:
        """Get inference statistics for a model"""
        try:
            # Get recent inferences
            recent_inferences = self.db_manager.get_recent_inferences(model_name, limit=1000)

            if not recent_inferences:
                return {
                    'total_inferences': 0,
                    'avg_confidence': 0.0,
                    'avg_processing_time_ms': 0.0,
                    'action_distribution': {},
                    'symbol_distribution': {}
                }

            # Filter by time window
            cutoff_time = datetime.now() - timedelta(hours=hours)
            recent_inferences = [r for r in recent_inferences if r.timestamp >= cutoff_time]

            if not recent_inferences:
                return {
                    'total_inferences': 0,
                    'avg_confidence': 0.0,
                    'avg_processing_time_ms': 0.0,
                    'action_distribution': {},
                    'symbol_distribution': {}
                }

            # Calculate statistics
            total_inferences = len(recent_inferences)
            avg_confidence = sum(r.confidence for r in recent_inferences) / total_inferences
            avg_processing_time = sum(r.processing_time_ms for r in recent_inferences) / total_inferences

            # Action distribution
            action_counts = {}
            for record in recent_inferences:
                action_counts[record.action] = action_counts.get(record.action, 0) + 1

            # Symbol distribution
            symbol_counts = {}
            for record in recent_inferences:
                symbol_counts[record.symbol] = symbol_counts.get(record.symbol, 0) + 1

            return {
                'total_inferences': total_inferences,
                'avg_confidence': avg_confidence,
                'avg_processing_time_ms': avg_processing_time,
                'action_distribution': action_counts,
                'symbol_distribution': symbol_counts,
                'latest_inference': recent_inferences[0].timestamp.isoformat() if recent_inferences else None
            }

        except Exception as e:
            logger.error(f"Error getting model stats: {e}")
            return {}

    def cleanup_old_logs(self, days_to_keep: int = 30) -> bool:
        """Clean up old inference logs"""
        return self.db_manager.cleanup_old_records(days_to_keep)


# Global inference logger instance
_inference_logger_instance = None


def get_inference_logger() -> InferenceLogger:
    """Get the global inference logger instance"""
    global _inference_logger_instance

    if _inference_logger_instance is None:
        _inference_logger_instance = InferenceLogger()

    return _inference_logger_instance


def log_model_inference(model_name: str,
                        symbol: str,
                        action: str,
                        confidence: float,
                        probabilities: Dict[str, float],
                        input_features: Union[np.ndarray, Dict, List],
                        processing_time_ms: float,
                        checkpoint_id: Optional[str] = None,
                        metadata: Optional[Dict[str, Any]] = None) -> bool:
    """
    Convenience function to log model inference

    This is the main function that should be called throughout the codebase
    instead of scattered logger.info() calls
    """
    inference_logger = get_inference_logger()
    return inference_logger.log_inference(
        model_name=model_name,
        symbol=symbol,
        action=action,
        confidence=confidence,
        probabilities=probabilities,
        input_features=input_features,
        processing_time_ms=processing_time_ms,
        checkpoint_id=checkpoint_id,
        metadata=metadata
    )
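The intended call-site pattern, sketched with placeholder values (model name, symbol, and the zero feature vector are illustrative):

import numpy as np
from utils.inference_logger import log_model_inference

# One call replaces the old scattered logger.info() lines; it writes to
# both the SQLite database and the human-readable text log.
log_model_inference(
    model_name="dqn_agent",
    symbol="ETH/USDT",
    action="HOLD",
    confidence=0.55,
    probabilities={"BUY": 0.25, "SELL": 0.20, "HOLD": 0.55},
    input_features=np.zeros(64, dtype=np.float32),  # placeholder features
    processing_time_ms=3.1,
)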
utils/text_logger.py (new file, 156 lines)
@@ -0,0 +1,156 @@
"""
Text File Logger for Trading System

Simple text file logging for tracking inference records and system events
Provides human-readable logs alongside database storage
"""

import os
import logging
from datetime import datetime
from typing import Dict, Any, Optional
from pathlib import Path

logger = logging.getLogger(__name__)


class TextLogger:
    """Simple text file logger for trading system events"""

    def __init__(self, log_dir: str = "logs"):
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(exist_ok=True)

        # Create separate log files for different types of events
        self.inference_log = self.log_dir / "inference_records.txt"
        self.checkpoint_log = self.log_dir / "checkpoint_events.txt"
        self.system_log = self.log_dir / "system_events.txt"

    def log_inference(self, model_name: str, symbol: str, action: str,
                      confidence: float, processing_time_ms: float,
                      checkpoint_id: Optional[str] = None) -> bool:
        """Log inference record to text file"""
        try:
            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            checkpoint_info = f" [checkpoint: {checkpoint_id}]" if checkpoint_id else ""

            log_entry = (
                f"{timestamp} | {model_name:15} | {symbol:10} | "
                f"{action:4} | conf={confidence:.3f} | "
                f"time={processing_time_ms:6.1f}ms{checkpoint_info}\n"
            )

            with open(self.inference_log, 'a', encoding='utf-8') as f:
                f.write(log_entry)
                f.flush()

            return True

        except Exception as e:
            logger.error(f"Failed to log inference to text file: {e}")
            return False

    def log_checkpoint_event(self, model_name: str, event_type: str,
                             checkpoint_id: str, details: str = "") -> bool:
        """Log checkpoint events (save, load, etc.)"""
        try:
            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            details_str = f" - {details}" if details else ""

            log_entry = (
                f"{timestamp} | {model_name:15} | {event_type:10} | "
                f"{checkpoint_id}{details_str}\n"
            )

            with open(self.checkpoint_log, 'a', encoding='utf-8') as f:
                f.write(log_entry)
                f.flush()

            return True

        except Exception as e:
            logger.error(f"Failed to log checkpoint event to text file: {e}")
            return False

    def log_system_event(self, event_type: str, message: str,
                         component: str = "system") -> bool:
        """Log general system events"""
        try:
            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

            log_entry = (
                f"{timestamp} | {component:15} | {event_type:10} | {message}\n"
            )

            with open(self.system_log, 'a', encoding='utf-8') as f:
                f.write(log_entry)
                f.flush()

            return True

        except Exception as e:
            logger.error(f"Failed to log system event to text file: {e}")
            return False

    def get_recent_inferences(self, lines: int = 50) -> str:
        """Get recent inference records from text file"""
        try:
            if not self.inference_log.exists():
                return "No inference records found"

            with open(self.inference_log, 'r', encoding='utf-8') as f:
                all_lines = f.readlines()
                recent_lines = all_lines[-lines:] if len(all_lines) > lines else all_lines
                return ''.join(recent_lines)

        except Exception as e:
            logger.error(f"Failed to read inference log: {e}")
            return f"Error reading log: {e}"

    def get_recent_checkpoint_events(self, lines: int = 20) -> str:
        """Get recent checkpoint events from text file"""
        try:
            if not self.checkpoint_log.exists():
                return "No checkpoint events found"

            with open(self.checkpoint_log, 'r', encoding='utf-8') as f:
                all_lines = f.readlines()
                recent_lines = all_lines[-lines:] if len(all_lines) > lines else all_lines
                return ''.join(recent_lines)

        except Exception as e:
            logger.error(f"Failed to read checkpoint log: {e}")
            return f"Error reading log: {e}"

    def cleanup_old_logs(self, max_lines: int = 10000) -> bool:
        """Keep only the most recent log entries"""
        try:
            for log_file in [self.inference_log, self.checkpoint_log, self.system_log]:
                if log_file.exists():
                    with open(log_file, 'r', encoding='utf-8') as f:
                        lines = f.readlines()

                    if len(lines) > max_lines:
                        # Keep only the most recent lines
                        recent_lines = lines[-max_lines:]
                        with open(log_file, 'w', encoding='utf-8') as f:
                            f.writelines(recent_lines)

                        logger.info(f"Cleaned up {log_file.name}: kept {len(recent_lines)} lines")

            return True

        except Exception as e:
            logger.error(f"Failed to cleanup logs: {e}")
            return False


# Global text logger instance
_text_logger_instance = None


def get_text_logger(log_dir: str = "logs") -> TextLogger:
    """Get the global text logger instance"""
    global _text_logger_instance

    if _text_logger_instance is None:
        _text_logger_instance = TextLogger(log_dir)

    return _text_logger_instance
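A short sketch of direct TextLogger use; the singleton writes under logs/, and the event type, message, and component strings below are illustrative:

from utils.text_logger import get_text_logger

tl = get_text_logger()
tl.log_system_event("STARTUP", "trading loop started", component="orchestrator")

# Tail the human-readable inference log
print(tl.get_recent_inferences(lines=10))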