"""
|
|
Inference Logger
|
|
|
|
Centralized logging system for model inferences with database storage
|
|
Eliminates scattered logging throughout the codebase
|
|
"""
|
|
|
|
import hashlib
import logging
import time
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Union

import numpy as np
import psutil

from .database_manager import get_database_manager, InferenceRecord
from .text_logger import get_text_logger

logger = logging.getLogger(__name__)


class InferenceLogger:
    """Centralized inference logging system."""

    def __init__(self):
        self.db_manager = get_database_manager()
        self.text_logger = get_text_logger()
        self._process = psutil.Process()

    def log_inference(self,
                      model_name: str,
                      symbol: str,
                      action: str,
                      confidence: float,
                      probabilities: Dict[str, float],
                      input_features: Union[np.ndarray, Dict, List],
                      processing_time_ms: float,
                      checkpoint_id: Optional[str] = None,
                      metadata: Optional[Dict[str, Any]] = None) -> bool:
        """
        Log a model inference with all relevant details.

        Args:
            model_name: Name of the model making the prediction
            symbol: Trading symbol
            action: Predicted action (BUY/SELL/HOLD)
            confidence: Confidence score (0.0 to 1.0)
            probabilities: Dict mapping each action to its probability
            input_features: Input features used for the prediction
            processing_time_ms: Time taken for inference in milliseconds
            checkpoint_id: ID of the model checkpoint used
            metadata: Additional metadata

        Returns:
            bool: True if both database and text logging succeeded
        """
        try:
            # Create a feature hash for deduplication
            feature_hash = self._hash_features(input_features)

            # Snapshot current process memory usage
            memory_usage_mb = self._get_memory_usage()

            # Create the inference record
            record = InferenceRecord(
                model_name=model_name,
                timestamp=datetime.now(),
                symbol=symbol,
                action=action,
                confidence=confidence,
                probabilities=probabilities,
                input_features_hash=feature_hash,
                processing_time_ms=processing_time_ms,
                memory_usage_mb=memory_usage_mb,
                checkpoint_id=checkpoint_id,
                metadata=metadata
            )

            # Log to the database (structured record) ...
            db_success = self.db_manager.log_inference(record)

            # ... and to the text file (human-readable record). These two
            # sinks replace scattered per-inference runtime logging.
            text_success = self.text_logger.log_inference(
                model_name=model_name,
                symbol=symbol,
                action=action,
                confidence=confidence,
                processing_time_ms=processing_time_ms,
                checkpoint_id=checkpoint_id
            )

            if not db_success:
                logger.error(f"Failed to log inference for {model_name}")

            return db_success and text_success

        except Exception as e:
            logger.error(f"Error logging inference: {e}")
            return False

    def _hash_features(self, features: Union[np.ndarray, Dict, List]) -> str:
        """Create a short hash of the input features for deduplication."""
        try:
            if isinstance(features, np.ndarray):
                # Hash the raw array bytes
                return hashlib.md5(features.tobytes()).hexdigest()[:16]
            elif isinstance(features, (dict, list)):
                # Hash the string form; dict items are sorted so that key
                # order does not change the hash
                feature_str = str(sorted(features.items()) if isinstance(features, dict) else features)
                return hashlib.md5(feature_str.encode()).hexdigest()[:16]
            else:
                # Hash the string representation of anything else
                return hashlib.md5(str(features).encode()).hexdigest()[:16]
        except Exception:
            # Fallback: timestamp-based hash (unique per call, so it defeats
            # deduplication, but logging still proceeds)
            return hashlib.md5(str(time.time()).encode()).hexdigest()[:16]

    def _get_memory_usage(self) -> float:
        """Get current process memory usage (RSS) in MB."""
        try:
            return self._process.memory_info().rss / (1024 * 1024)
        except Exception:
            return 0.0

    def get_model_stats(self, model_name: str, hours: int = 24) -> Dict[str, Any]:
        """
        Get inference statistics for a model over the last `hours` hours.

        Returns a dict with the inference count, average confidence and
        processing time, and action/symbol distributions.
        """
        try:
            # Get recent inferences (assumed newest-first) and keep only
            # those inside the requested time window
            recent_inferences = self.db_manager.get_recent_inferences(model_name, limit=1000)
            cutoff_time = datetime.now() - timedelta(hours=hours)
            recent_inferences = [r for r in recent_inferences if r.timestamp >= cutoff_time]

            if not recent_inferences:
                return {
                    'total_inferences': 0,
                    'avg_confidence': 0.0,
                    'avg_processing_time_ms': 0.0,
                    'action_distribution': {},
                    'symbol_distribution': {}
                }

            # Aggregate statistics
            total_inferences = len(recent_inferences)
            avg_confidence = sum(r.confidence for r in recent_inferences) / total_inferences
            avg_processing_time = sum(r.processing_time_ms for r in recent_inferences) / total_inferences

            # Action distribution
            action_counts: Dict[str, int] = {}
            for record in recent_inferences:
                action_counts[record.action] = action_counts.get(record.action, 0) + 1

            # Symbol distribution
            symbol_counts: Dict[str, int] = {}
            for record in recent_inferences:
                symbol_counts[record.symbol] = symbol_counts.get(record.symbol, 0) + 1

            return {
                'total_inferences': total_inferences,
                'avg_confidence': avg_confidence,
                'avg_processing_time_ms': avg_processing_time,
                'action_distribution': action_counts,
                'symbol_distribution': symbol_counts,
                'latest_inference': recent_inferences[0].timestamp.isoformat()
            }

        except Exception as e:
            logger.error(f"Error getting model stats: {e}")
            return {}

    def cleanup_old_logs(self, days_to_keep: int = 30) -> bool:
        """Delete inference logs older than `days_to_keep` days."""
        return self.db_manager.cleanup_old_records(days_to_keep)


# Global inference logger instance
_inference_logger_instance: Optional[InferenceLogger] = None


def get_inference_logger() -> InferenceLogger:
    """Get the global inference logger instance, creating it on first use."""
    global _inference_logger_instance

    if _inference_logger_instance is None:
        _inference_logger_instance = InferenceLogger()

    return _inference_logger_instance


def log_model_inference(model_name: str,
                        symbol: str,
                        action: str,
                        confidence: float,
                        probabilities: Dict[str, float],
                        input_features: Union[np.ndarray, Dict, List],
                        processing_time_ms: float,
                        checkpoint_id: Optional[str] = None,
                        metadata: Optional[Dict[str, Any]] = None) -> bool:
    """
    Convenience function to log a model inference.

    This is the main entry point that should be called throughout the
    codebase instead of scattered logger.info() calls.
    """
    inference_logger = get_inference_logger()
    return inference_logger.log_inference(
        model_name=model_name,
        symbol=symbol,
        action=action,
        confidence=confidence,
        probabilities=probabilities,
        input_features=input_features,
        processing_time_ms=processing_time_ms,
        checkpoint_id=checkpoint_id,
        metadata=metadata
    )
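

# --- Usage sketch (hypothetical values) -------------------------------------
# A minimal example of how a caller would record one inference via
# log_model_inference(). The model name, symbol, probabilities, and feature
# vector below are placeholder values, not part of the module's runtime path.
# Because this module uses relative imports, run it as a package module
# (e.g. `python -m <package>.inference_logger`), not as a standalone script.
if __name__ == "__main__":
    example_features = np.array([0.12, -0.03, 1.7, 0.44], dtype=np.float32)

    start = time.perf_counter()
    # ... the real model would run here; the prediction below is faked ...
    elapsed_ms = (time.perf_counter() - start) * 1000.0

    logged = log_model_inference(
        model_name="example_model",   # hypothetical model name
        symbol="BTCUSDT",             # hypothetical trading symbol
        action="BUY",
        confidence=0.72,
        probabilities={"BUY": 0.72, "SELL": 0.08, "HOLD": 0.20},
        input_features=example_features,
        processing_time_ms=elapsed_ms,
        metadata={"source": "usage_example"},
    )
    print(f"Inference logged: {logged}")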