"""
Model Output Manager

This module provides a centralized storage and management system for model outputs,
enabling cross-model feeding and evaluation.
"""
|
|
|
|
import os
|
|
import json
|
|
import logging
|
|
import time
|
|
from datetime import datetime
|
|
from typing import Dict, List, Optional, Any
|
|
from threading import Lock
|
|
|
|
from .data_models import ModelOutput
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
class ModelOutputManager:
    """
    Centralized storage and management system for model outputs

    This class:
    1. Stores model outputs for all models
    2. Provides access to current and historical outputs
    3. Handles persistence of outputs to disk
    4. Supports evaluation of model performance

    Thread-safety: all in-memory dictionaries are guarded by a single
    non-reentrant ``Lock``; disk I/O is deliberately performed outside it.
    """

    def __init__(self, cache_dir: str = "cache/model_outputs", max_history: int = 1000):
        """
        Initialize the model output manager

        Args:
            cache_dir: Directory to store model outputs (flat layout; one
                JSON file per stored output)
            max_history: Maximum number of historical outputs to keep per model
        """
        self.cache_dir = cache_dir
        self.max_history = max_history
        # Guards current_outputs / historical_outputs / performance_metrics.
        # NOTE: this is a plain Lock (not RLock) — never call a public method
        # of this class while already holding it.
        self.outputs_lock = Lock()

        # Latest output per model: {symbol: {model_name: ModelOutput}}
        self.current_outputs: Dict[str, Dict[str, ModelOutput]] = {}

        # Bounded history per model: {symbol: {model_name: List[ModelOutput]}}
        self.historical_outputs: Dict[str, Dict[str, List[ModelOutput]]] = {}

        # Cached evaluation results: {symbol: {model_name: metrics dict}}
        # (values hold floats plus an ISO-8601 'last_update' string)
        self.performance_metrics: Dict[str, Dict[str, Dict[str, float]]] = {}

        # Create cache directory if it doesn't exist
        os.makedirs(cache_dir, exist_ok=True)

        logger.info(f"ModelOutputManager initialized with cache_dir: {cache_dir}")

    def store_output(self, model_output: ModelOutput, persist: bool = True) -> bool:
        """
        Store a model output as the current output and append it to history

        Args:
            model_output: Model output to store
            persist: Also write the output to disk (default True). Pass
                False when re-ingesting outputs that already exist on disk
                (see load_outputs_from_disk) to avoid duplicating files.

        Returns:
            bool: True if successful, False otherwise
        """
        try:
            symbol = model_output.symbol
            model_name = model_output.model_name

            with self.outputs_lock:
                # Initialize nested dictionaries lazily on first use
                self.current_outputs.setdefault(symbol, {})[model_name] = model_output

                history = self.historical_outputs.setdefault(symbol, {}).setdefault(model_name, [])
                history.append(model_output)

                # Trim in place so only the newest max_history entries remain
                if len(history) > self.max_history:
                    del history[:-self.max_history]

            # Persist outside the lock: disk I/O must not block readers.
            if persist:
                self._persist_output(model_output)

            return True

        except Exception as e:
            logger.error(f"Error storing model output: {e}")
            return False

    def get_current_output(self, symbol: str, model_name: str) -> Optional[ModelOutput]:
        """
        Get the current output for a model and symbol

        Args:
            symbol: Symbol to get output for
            model_name: Model name to get output for

        Returns:
            ModelOutput: Current output, or None if not available
        """
        try:
            with self.outputs_lock:
                return self.current_outputs.get(symbol, {}).get(model_name)

        except Exception as e:
            logger.error(f"Error getting current output: {e}")
            return None

    def get_all_current_outputs(self, symbol: str) -> Dict[str, ModelOutput]:
        """
        Get all current outputs for a symbol

        Args:
            symbol: Symbol to get outputs for

        Returns:
            Dict[str, ModelOutput]: Dictionary of model name to output
            (a shallow copy; mutating it does not affect internal state)
        """
        try:
            with self.outputs_lock:
                return dict(self.current_outputs.get(symbol, {}))

        except Exception as e:
            logger.error(f"Error getting all current outputs: {e}")
            return {}

    def get_historical_outputs(self, symbol: str, model_name: str, limit: Optional[int] = None) -> List[ModelOutput]:
        """
        Get historical outputs for a model and symbol

        Args:
            symbol: Symbol to get outputs for
            model_name: Model name to get outputs for
            limit: Maximum number of (most recent) outputs to return,
                None for all. limit=0 returns an empty list.

        Returns:
            List[ModelOutput]: List of historical outputs, oldest first
            (a shallow copy; mutating it does not affect internal state)
        """
        try:
            with self.outputs_lock:
                outputs = self.historical_outputs.get(symbol, {}).get(model_name, [])
                if limit is not None:
                    # Guard limit=0 explicitly: outputs[-0:] would return
                    # the whole list, not an empty one.
                    outputs = outputs[-limit:] if limit > 0 else []
                return list(outputs)

        except Exception as e:
            logger.error(f"Error getting historical outputs: {e}")
            return []

    def evaluate_model_performance(self, symbol: str, model_name: str) -> Dict[str, float]:
        """
        Evaluate model performance based on historical outputs

        Args:
            symbol: Symbol to evaluate
            model_name: Model name to evaluate

        Returns:
            Dict[str, float]: Performance metrics ('confidence', 'samples',
            plus an ISO-8601 'last_update' string). On error, a dict with a
            single 'error' key.
        """
        try:
            # Get historical outputs (acquires the lock internally, so this
            # call must stay outside any `with self.outputs_lock:` block)
            outputs = self.get_historical_outputs(symbol, model_name)

            if not outputs:
                return {'accuracy': 0.0, 'confidence': 0.0, 'samples': 0}

            # Average confidence over all recorded outputs
            total_outputs = len(outputs)
            avg_confidence = sum(output.confidence for output in outputs) / total_outputs

            # For now, we don't have ground truth to calculate accuracy.
            # In the future, we can add this by comparing predictions to
            # actual market movements.
            metrics = {
                'confidence': avg_confidence,
                'samples': total_outputs,
                'last_update': datetime.now().isoformat()
            }

            # Cache the metrics for get_performance_metrics
            with self.outputs_lock:
                self.performance_metrics.setdefault(symbol, {})[model_name] = metrics

            return metrics

        except Exception as e:
            logger.error(f"Error evaluating model performance: {e}")
            return {'error': str(e)}

    def get_performance_metrics(self, symbol: str, model_name: str) -> Dict[str, float]:
        """
        Get performance metrics for a model and symbol, computing them on
        demand if none are cached yet

        Args:
            symbol: Symbol to get metrics for
            model_name: Model name to get metrics for

        Returns:
            Dict[str, float]: Performance metrics (see
            evaluate_model_performance for the key set)
        """
        try:
            with self.outputs_lock:
                cached = self.performance_metrics.get(symbol, {}).get(model_name)
                if cached is not None:
                    return cached.copy()

            # Must run OUTSIDE the lock: evaluate_model_performance
            # re-acquires outputs_lock, which is not reentrant.
            return self.evaluate_model_performance(symbol, model_name)

        except Exception as e:
            logger.error(f"Error getting performance metrics: {e}")
            return {'error': str(e)}

    def _persist_output(self, model_output: ModelOutput) -> bool:
        """
        Persist a model output to disk as a timestamped JSON file

        Files are written directly into cache_dir (flat layout) so that the
        non-recursive glob patterns in load_outputs_from_disk and
        cleanup_old_outputs can find them. (An earlier version also created
        an unused per-symbol subdirectory; that dead code is gone.)

        Args:
            model_output: Model output to persist

        Returns:
            bool: True if successful, False otherwise
        """
        try:
            # Create filename with timestamp; '/' in symbols (e.g. BTC/USD)
            # would otherwise be treated as a path separator
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            safe_symbol = model_output.symbol.replace('/', '_')
            filename = f"{model_output.model_name}_{safe_symbol}_{timestamp}.json"
            filepath = os.path.join(self.cache_dir, filename)

            # Convert ModelOutput to dictionary.
            # hidden_states are intentionally omitted (too large for disk).
            output_dict = {
                'model_type': model_output.model_type,
                'model_name': model_output.model_name,
                'symbol': model_output.symbol,
                'timestamp': model_output.timestamp.isoformat(),
                'confidence': model_output.confidence,
                'predictions': model_output.predictions,
                'metadata': model_output.metadata
            }

            with open(filepath, 'w') as f:
                json.dump(output_dict, f, indent=2)

            return True

        except Exception as e:
            logger.error(f"Error persisting model output: {e}")
            return False

    def load_outputs_from_disk(self, symbol: Optional[str] = None, model_name: Optional[str] = None) -> int:
        """
        Load model outputs from disk into memory

        Args:
            symbol: Symbol to load outputs for, None for all
            model_name: Model name to load outputs for, None for all

        Returns:
            int: Number of outputs successfully loaded
        """
        try:
            import glob

            # Build a glob pattern matching _persist_output's naming scheme
            safe_symbol = symbol.replace('/', '_') if symbol else None
            if safe_symbol and model_name:
                pattern = os.path.join(self.cache_dir, f"{model_name}_{safe_symbol}*.json")
            elif safe_symbol:
                pattern = os.path.join(self.cache_dir, f"*_{safe_symbol}*.json")
            elif model_name:
                pattern = os.path.join(self.cache_dir, f"{model_name}_*.json")
            else:
                pattern = os.path.join(self.cache_dir, "*.json")

            output_files = glob.glob(pattern)

            if not output_files:
                logger.info(f"No output files found for pattern: {pattern}")
                return 0

            loaded_count = 0
            for filepath in output_files:
                try:
                    with open(filepath, 'r') as f:
                        output_dict = json.load(f)

                    model_output = ModelOutput(
                        model_type=output_dict['model_type'],
                        model_name=output_dict['model_name'],
                        symbol=output_dict['symbol'],
                        timestamp=datetime.fromisoformat(output_dict['timestamp']),
                        confidence=output_dict['confidence'],
                        predictions=output_dict['predictions'],
                        hidden_states={},  # hidden states are never persisted
                        metadata=output_dict.get('metadata', {})
                    )

                    # persist=False: this output already exists on disk;
                    # re-persisting here would write a duplicate file with a
                    # fresh timestamp on every load. Only count successes.
                    if self.store_output(model_output, persist=False):
                        loaded_count += 1

                except Exception as e:
                    logger.error(f"Error loading output file {filepath}: {e}")

            logger.info(f"Loaded {loaded_count} model outputs from disk")
            return loaded_count

        except Exception as e:
            logger.error(f"Error loading outputs from disk: {e}")
            return 0

    def cleanup_old_outputs(self, max_age_days: int = 30) -> int:
        """
        Clean up old output files

        Args:
            max_age_days: Maximum age of files to keep in days

        Returns:
            int: Number of files deleted
        """
        try:
            import glob
            output_files = glob.glob(os.path.join(self.cache_dir, "*.json"))

            if not output_files:
                return 0

            # Files with a modification time before this moment get deleted
            cutoff_time = time.time() - (max_age_days * 24 * 60 * 60)

            deleted_count = 0
            for filepath in output_files:
                try:
                    if os.path.getmtime(filepath) < cutoff_time:
                        os.remove(filepath)
                        deleted_count += 1

                except Exception as e:
                    logger.error(f"Error deleting file {filepath}: {e}")

            logger.info(f"Deleted {deleted_count} old model output files")
            return deleted_count

        except Exception as e:
            logger.error(f"Error cleaning up old outputs: {e}")
            return 0