gogo2/utils/inference_logger.py

"""
Inference Logger
Centralized logging system for model inferences with database storage
Eliminates scattered logging throughout the codebase
"""
import time
import hashlib
import logging
import psutil
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Union
from dataclasses import dataclass
import numpy as np
from .database_manager import get_database_manager, InferenceRecord
from .text_logger import get_text_logger

logger = logging.getLogger(__name__)


class InferenceLogger:
    """Centralized inference logging system"""

    def __init__(self):
        self.db_manager = get_database_manager()
        self.text_logger = get_text_logger()
        self._process = psutil.Process()
    def log_inference(self,
                      model_name: str,
                      symbol: str,
                      action: str,
                      confidence: float,
                      probabilities: Dict[str, float],
                      input_features: Union[np.ndarray, Dict, List],
                      processing_time_ms: float,
                      checkpoint_id: Optional[str] = None,
                      metadata: Optional[Dict[str, Any]] = None) -> bool:
        """
        Log a model inference with all relevant details.

        Args:
            model_name: Name of the model making the prediction
            symbol: Trading symbol
            action: Predicted action (BUY/SELL/HOLD)
            confidence: Confidence score (0.0 to 1.0)
            probabilities: Action probabilities dict
            input_features: Input features used for the prediction
            processing_time_ms: Time taken for inference in milliseconds
            checkpoint_id: ID of the checkpoint used
            metadata: Additional metadata

        Returns:
            bool: True only if both the database and text-file logs succeeded
        """
        try:
            # Create feature hash for deduplication
            feature_hash = self._hash_features(input_features)

            # Get current memory usage
            memory_usage_mb = self._get_memory_usage()

            # Convert input features to a numpy array if needed
            features_array = None
            if isinstance(input_features, np.ndarray):
                features_array = input_features.astype(np.float32)
            elif isinstance(input_features, (list, tuple)):
                features_array = np.array(input_features, dtype=np.float32)

            # Create inference record
            record = InferenceRecord(
                model_name=model_name,
                timestamp=datetime.now(),
                symbol=symbol,
                action=action,
                confidence=confidence,
                probabilities=probabilities,
                input_features_hash=feature_hash,
                processing_time_ms=processing_time_ms,
                memory_usage_mb=memory_usage_mb,
                input_features=features_array,
                checkpoint_id=checkpoint_id,
                metadata=metadata
            )

            # Log to database
            db_success = self.db_manager.log_inference(record)

            # Log to text file (human-readable record)
            text_success = self.text_logger.log_inference(
                model_name=model_name,
                symbol=symbol,
                action=action,
                confidence=confidence,
                processing_time_ms=processing_time_ms,
                checkpoint_id=checkpoint_id
            )

            # Intentionally quiet on success: the database and text log are the
            # records of truth, so no per-inference runtime logging is emitted.
            if not db_success:
                logger.error(f"Failed to log inference to database for {model_name}")
            if not text_success:
                logger.error(f"Failed to log inference to text log for {model_name}")

            return db_success and text_success
        except Exception as e:
            logger.error(f"Error logging inference: {e}")
            return False
    def _hash_features(self, features: Union[np.ndarray, Dict, List]) -> str:
        """Create a hash of input features for deduplication"""
        try:
            if isinstance(features, np.ndarray):
                # Hash the raw array bytes
                return hashlib.md5(features.tobytes()).hexdigest()[:16]
            elif isinstance(features, (dict, list)):
                # Hash a dict or list via a canonical string representation
                feature_str = str(sorted(features.items()) if isinstance(features, dict) else features)
                return hashlib.md5(feature_str.encode()).hexdigest()[:16]
            else:
                # Hash the string representation of anything else
                return hashlib.md5(str(features).encode()).hexdigest()[:16]
        except Exception:
            # Fallback to a timestamp-based hash
            return hashlib.md5(str(time.time()).encode()).hexdigest()[:16]
    def _get_memory_usage(self) -> float:
        """Get current memory usage (RSS) in MB"""
        try:
            return self._process.memory_info().rss / (1024 * 1024)
        except Exception:
            return 0.0
    def get_model_stats(self, model_name: str, hours: int = 24) -> Dict[str, Any]:
        """Get inference statistics for a model over the last `hours` hours"""
        empty_stats = {
            'total_inferences': 0,
            'avg_confidence': 0.0,
            'avg_processing_time_ms': 0.0,
            'action_distribution': {},
            'symbol_distribution': {}
        }
        try:
            # Get recent inferences and filter to the requested time window
            recent_inferences = self.db_manager.get_recent_inferences(model_name, limit=1000) or []
            cutoff_time = datetime.now() - timedelta(hours=hours)
            recent_inferences = [r for r in recent_inferences if r.timestamp >= cutoff_time]
            if not recent_inferences:
                return empty_stats

            # Calculate aggregate statistics
            total_inferences = len(recent_inferences)
            avg_confidence = sum(r.confidence for r in recent_inferences) / total_inferences
            avg_processing_time = sum(r.processing_time_ms for r in recent_inferences) / total_inferences

            # Action distribution
            action_counts = {}
            for record in recent_inferences:
                action_counts[record.action] = action_counts.get(record.action, 0) + 1

            # Symbol distribution
            symbol_counts = {}
            for record in recent_inferences:
                symbol_counts[record.symbol] = symbol_counts.get(record.symbol, 0) + 1

            return {
                'total_inferences': total_inferences,
                'avg_confidence': avg_confidence,
                'avg_processing_time_ms': avg_processing_time,
                'action_distribution': action_counts,
                'symbol_distribution': symbol_counts,
                'latest_inference': recent_inferences[0].timestamp.isoformat()
            }
        except Exception as e:
            logger.error(f"Error getting model stats: {e}")
            return {}
    def cleanup_old_logs(self, days_to_keep: int = 30) -> bool:
        """Clean up inference logs older than `days_to_keep` days"""
        return self.db_manager.cleanup_old_records(days_to_keep)


# Global inference logger instance
_inference_logger_instance = None


def get_inference_logger() -> InferenceLogger:
    """Get the global inference logger instance"""
    global _inference_logger_instance
    if _inference_logger_instance is None:
        _inference_logger_instance = InferenceLogger()
    return _inference_logger_instance
def log_model_inference(model_name: str,
                        symbol: str,
                        action: str,
                        confidence: float,
                        probabilities: Dict[str, float],
                        input_features: Union[np.ndarray, Dict, List],
                        processing_time_ms: float,
                        checkpoint_id: Optional[str] = None,
                        metadata: Optional[Dict[str, Any]] = None) -> bool:
    """
    Convenience function to log a model inference.

    This is the main entry point that should be called throughout the codebase
    instead of scattered logger.info() calls.
    """
    inference_logger = get_inference_logger()
    return inference_logger.log_inference(
        model_name=model_name,
        symbol=symbol,
        action=action,
        confidence=confidence,
        probabilities=probabilities,
        input_features=input_features,
        processing_time_ms=processing_time_ms,
        checkpoint_id=checkpoint_id,
        metadata=metadata
    )
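

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): logs one
    # hypothetical inference and prints the resulting stats. Assumes the
    # database_manager and text_logger backends initialize with their
    # defaults; the model name, symbol, checkpoint id, and feature values
    # below are all made up for illustration.
    ok = log_model_inference(
        model_name="example_cnn",                      # hypothetical model name
        symbol="ETH/USDT",
        action="BUY",
        confidence=0.82,
        probabilities={"BUY": 0.82, "SELL": 0.08, "HOLD": 0.10},
        input_features=np.random.rand(64),             # fake feature vector
        processing_time_ms=12.5,
        checkpoint_id="ckpt_001",                      # hypothetical checkpoint id
    )
    print(f"logged: {ok}")
    print(get_inference_logger().get_model_stats("example_cnn", hours=24))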