checkpoint manager

This commit is contained in:
Dobromir Popov
2025-07-23 21:40:04 +03:00
parent bab39fa68f
commit 45a62443a0
9 changed files with 1587 additions and 709 deletions

View File

@ -1,34 +1,31 @@
"""
Model Output Manager
This module provides extensible model output storage and management for the multi-modal trading system.
Supports CNN, RL, LSTM, Transformer, and future model types with cross-model feeding capabilities.
This module provides a centralized storage and management system for model outputs,
enabling cross-model feeding and evaluation.
"""
import logging
import os
import json
import pickle
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Union
from collections import deque, defaultdict
import logging
import time
from datetime import datetime
from typing import Dict, List, Optional, Any
from threading import Lock
from pathlib import Path
from .data_models import ModelOutput, create_model_output
from .data_models import ModelOutput
logger = logging.getLogger(__name__)
class ModelOutputManager:
"""
Extensible model output storage and management system
Centralized storage and management system for model outputs
Features:
- Standardized ModelOutput storage for all model types
- Cross-model feeding with hidden states
- Historical output tracking
- Metadata management
- Persistence and recovery
- Performance analytics
This class:
1. Stores model outputs for all models
2. Provides access to current and historical outputs
3. Handles persistence of outputs to disk
4. Supports evaluation of model performance
"""
def __init__(self, cache_dir: str = "cache/model_outputs", max_history: int = 1000):
@ -36,279 +33,226 @@ class ModelOutputManager:
Initialize the model output manager
Args:
cache_dir: Directory for persistent storage
max_history: Maximum number of outputs to keep in memory per model
cache_dir: Directory to store model outputs
max_history: Maximum number of historical outputs to keep per model
"""
self.cache_dir = Path(cache_dir)
self.cache_dir.mkdir(parents=True, exist_ok=True)
self.cache_dir = cache_dir
self.max_history = max_history
self.outputs_lock = Lock()
# In-memory storage
self.current_outputs: Dict[str, Dict[str, ModelOutput]] = defaultdict(dict) # {symbol: {model_name: ModelOutput}}
self.output_history: Dict[str, Dict[str, deque]] = defaultdict(lambda: defaultdict(lambda: deque(maxlen=max_history))) # {symbol: {model_name: deque}}
self.cross_model_states: Dict[str, Dict[str, Dict[str, Any]]] = defaultdict(lambda: defaultdict(dict)) # {symbol: {model_name: hidden_states}}
# Current outputs for each model and symbol
# {symbol: {model_name: ModelOutput}}
self.current_outputs: Dict[str, Dict[str, ModelOutput]] = {}
# Metadata tracking
self.model_metadata: Dict[str, Dict[str, Any]] = defaultdict(dict) # {model_name: metadata}
self.performance_stats: Dict[str, Dict[str, Any]] = defaultdict(lambda: defaultdict(dict)) # {symbol: {model_name: stats}}
# Historical outputs for each model and symbol
# {symbol: {model_name: List[ModelOutput]}}
self.historical_outputs: Dict[str, Dict[str, List[ModelOutput]]] = {}
# Thread safety
self.storage_lock = Lock()
# Performance metrics for each model and symbol
# {symbol: {model_name: Dict[str, float]}}
self.performance_metrics: Dict[str, Dict[str, Dict[str, float]]] = {}
# Supported model types
self.supported_model_types = {
'cnn', 'rl', 'lstm', 'transformer', 'orchestrator',
'ensemble', 'hybrid', 'custom' # Extensible for future types
}
# Create cache directory if it doesn't exist
os.makedirs(cache_dir, exist_ok=True)
logger.info(f"ModelOutputManager initialized with cache dir: {self.cache_dir}")
logger.info(f"Supported model types: {self.supported_model_types}")
logger.info(f"ModelOutputManager initialized with cache_dir: {cache_dir}")
def store_output(self, model_output: "ModelOutput") -> bool:
    """
    Store a model output in memory and persist it to disk.

    The output becomes the current output for its (symbol, model_name) pair
    and is appended to that pair's history. The history is capped at
    ``self.max_history`` entries; the oldest entries are dropped first.

    Args:
        model_output: Model output to store. Its ``symbol`` and
            ``model_name`` attributes select the slot it is stored under.

    Returns:
        bool: True if the output was stored, False if an exception was
        raised while storing it (the exception is logged, not re-raised).
    """
    try:
        symbol = model_output.symbol
        model_name = model_output.model_name

        with self.outputs_lock:
            # Latest output for this (symbol, model) pair
            self.current_outputs.setdefault(symbol, {})[model_name] = model_output

            # Append to the per-pair history, creating it on first use
            history = self.historical_outputs.setdefault(symbol, {}).setdefault(model_name, [])
            history.append(model_output)

            # Trim by explicit overflow count instead of slicing with
            # [-max_history:], because [-0:] selects the whole list and
            # would make max_history == 0 keep everything.
            overflow = len(history) - self.max_history
            if overflow > 0:
                del history[:overflow]

        # Persist after releasing the lock so readers are not blocked on
        # file I/O. Persistence is best-effort: _persist_output logs its
        # own failures and its result does not affect the return value.
        self._persist_output(model_output)

        return True

    except Exception as e:
        logger.error(f"Error storing model output: {e}")
        return False
def get_current_output(self, symbol: str, model_name: str) -> Optional["ModelOutput"]:
    """
    Get the current (most recently stored) output for a model and symbol.

    Args:
        symbol: Symbol to get output for
        model_name: Model name to get output for

    Returns:
        ModelOutput: Current output, or None if nothing is stored for this
        (symbol, model_name) pair or the lookup raised an exception.
    """
    try:
        # Read under the lock so a concurrent store_output cannot be
        # observed half-way through creating the per-symbol dictionary.
        with self.outputs_lock:
            return self.current_outputs.get(symbol, {}).get(model_name)

    except Exception as e:
        logger.error(f"Error getting current output: {e}")
        return None
def get_all_current_outputs(self, symbol: str) -> Dict[str, "ModelOutput"]:
    """
    Get the current outputs of every model for a symbol.

    Args:
        symbol: Symbol to get outputs for

    Returns:
        Dict[str, ModelOutput]: Mapping of model name to its current output.
        This is a shallow copy, so adding or removing keys in the result
        does not affect the manager. Empty if the symbol is unknown or the
        lookup raised an exception.
    """
    try:
        with self.outputs_lock:
            # dict(...) copies the mapping while the lock is still held
            return dict(self.current_outputs.get(symbol, {}))

    except Exception as e:
        logger.error(f"Error getting all current outputs: {e}")
        return {}
def get_historical_outputs(self, symbol: str, model_name: str, limit: Optional[int] = None) -> List["ModelOutput"]:
    """
    Get historical outputs for a model and symbol.

    Args:
        symbol: Symbol to get outputs for
        model_name: Model name to get outputs for
        limit: Maximum number of outputs to return, None for all. When
            given, the most recently stored ``limit`` outputs are returned;
            zero or a negative value yields an empty list.

    Returns:
        List[ModelOutput]: A new list of historical outputs in the order
        they were stored (oldest first). Empty if nothing is stored for
        this pair or the lookup raised an exception.
    """
    try:
        with self.outputs_lock:
            history = self.historical_outputs.get(symbol, {}).get(model_name)
            if not history:
                return []

            if limit is None:
                return list(history)

            # Guard explicitly: history[-0:] is the whole list, and a
            # negative limit would slice from the front instead.
            if limit <= 0:
                return []

            # Slicing already produces a fresh list
            return history[-limit:]

    except Exception as e:
        logger.error(f"Error getting historical outputs: {e}")
        return []
def evaluate_model_performance(self, symbol: str, model_name: str) -> Dict[str, Any]:
    """
    Evaluate model performance based on historical outputs.

    Computes the average confidence over all stored historical outputs of
    the model for the symbol and caches the result in
    ``self.performance_metrics``.

    Args:
        symbol: Symbol to evaluate
        model_name: Model name to evaluate

    Returns:
        Dict[str, Any]: Performance metrics with keys:
            - 'accuracy': always 0.0 for now (no ground truth is available)
            - 'confidence': mean confidence of the historical outputs
            - 'samples': number of historical outputs evaluated
            - 'last_update': ISO timestamp of this evaluation (only present
              when at least one output was evaluated)
        On failure, ``{'error': <message>}`` is returned instead.
    """
    try:
        # Get historical outputs (returned as a copy, safe to iterate unlocked)
        outputs = self.get_historical_outputs(symbol, model_name)

        if not outputs:
            # Nothing to evaluate; this result is not cached
            return {'accuracy': 0.0, 'confidence': 0.0, 'samples': 0}

        sample_count = len(outputs)
        avg_confidence = sum(output.confidence for output in outputs) / sample_count

        metrics = {
            # For now, we don't have ground truth to calculate accuracy.
            # The key is still reported so both return paths expose the
            # same 'accuracy'/'confidence'/'samples' fields. In the future
            # it can be computed by comparing predictions to actual market
            # movements.
            'accuracy': 0.0,
            'confidence': avg_confidence,
            'samples': sample_count,
            'last_update': datetime.now().isoformat()
        }

        # Store metrics
        with self.outputs_lock:
            self.performance_metrics.setdefault(symbol, {})[model_name] = metrics

        # Hand back a copy so callers cannot mutate the cached entry
        return metrics.copy()

    except Exception as e:
        logger.error(f"Error evaluating model performance: {e}")
        return {'error': str(e)}
def _persist_output_async(self, model_output: ModelOutput):
"""Persist model output to disk (simplified version)"""
def get_performance_metrics(self, symbol: str, model_name: str) -> Dict[str, Any]:
    """
    Get performance metrics for a model and symbol.

    Returns the cached metrics if an evaluation has already been stored for
    this pair; otherwise the metrics are calculated on demand.

    Args:
        symbol: Symbol to get metrics for
        model_name: Model name to get metrics for

    Returns:
        Dict[str, Any]: A copy of the cached performance metrics, or the
        result of ``evaluate_model_performance`` when nothing is cached.
        On failure, ``{'error': <message>}`` is returned instead.
    """
    try:
        with self.outputs_lock:
            cached = self.performance_metrics.get(symbol, {}).get(model_name)
            if cached is not None:
                return cached.copy()

        # If no metrics are available, calculate them. This must happen
        # after the lock is released: evaluate_model_performance acquires
        # the same lock, and threading.Lock is not reentrant.
        return self.evaluate_model_performance(symbol, model_name)

    except Exception as e:
        logger.error(f"Error getting performance metrics: {e}")
        return {'error': str(e)}
def _persist_output(self, model_output: ModelOutput) -> bool:
"""
Persist a model output to disk
Args:
model_output: Model output to persist
Returns:
bool: True if successful, False otherwise
"""
try:
# Create directory if it doesn't exist
symbol_dir = os.path.join(self.cache_dir, model_output.symbol.replace('/', '_'))
os.makedirs(symbol_dir, exist_ok=True)
# Create filename with timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"{model_output.model_name}_{model_output.symbol.replace('/', '_')}_{timestamp}.json"
filepath = os.path.join(self.cache_dir, filename)
# Convert ModelOutput to dictionary
output_dict = {
'model_type': model_output.model_type,
'model_name': model_output.model_name,
@ -319,77 +263,120 @@ class ModelOutputManager:
'metadata': model_output.metadata
}
# Save to file (in a real implementation, this would be async)
# Don't store hidden states in file (too large)
# Write to file
with open(filepath, 'w') as f:
json.dump(output_dict, f, indent=2)
return True
except Exception as e:
logger.error(f"Error persisting model output: {e}")
return False
def load_outputs_from_disk(self, symbol: Optional[str] = None, model_name: Optional[str] = None) -> int:
    """
    Load model outputs from disk into memory.

    Reads the JSON files in ``self.cache_dir`` whose names match the
    requested symbol and/or model, rebuilds a ``ModelOutput`` from each and
    adds it to the current and historical outputs. Loaded outputs are NOT
    written back to disk.

    Args:
        symbol: Symbol to load outputs for, None for all
        model_name: Model name to load outputs for, None for all

    Returns:
        int: Number of outputs loaded. Files that cannot be read or parsed
        are logged and skipped; 0 is returned if the load fails as a whole.
    """
    try:
        # Find all output files
        import glob

        # Escape glob metacharacters so names containing '*', '?' or '['
        # are matched literally rather than interpreted as patterns.
        symbol_part = glob.escape(symbol.replace('/', '_')) if symbol else None
        model_part = glob.escape(model_name) if model_name else None

        if symbol_part and model_part:
            name_pattern = f"{model_part}_{symbol_part}*.json"
        elif symbol_part:
            name_pattern = f"*_{symbol_part}*.json"
        elif model_part:
            name_pattern = f"{model_part}_*.json"
        else:
            name_pattern = "*.json"
        pattern = os.path.join(glob.escape(os.fspath(self.cache_dir)), name_pattern)

        # glob returns files in arbitrary order. File names end in a
        # %Y%m%d_%H%M%S timestamp, so sorting by name restores
        # chronological order within each (model, symbol) prefix.
        output_files = sorted(glob.glob(pattern))

        if not output_files:
            logger.info(f"No output files found for pattern: {pattern}")
            return 0

        # Load each file
        loaded_outputs = []
        for filepath in output_files:
            try:
                with open(filepath, 'r') as f:
                    output_dict = json.load(f)

                # The file-name patterns are prefix matches and can
                # over-match (model "cnn" also matches "cnn_v2_..."), so
                # confirm the identity recorded inside the file. Symbols
                # are compared in their '/'-free form.
                if symbol and output_dict['symbol'].replace('/', '_') != symbol.replace('/', '_'):
                    continue
                if model_name and output_dict['model_name'] != model_name:
                    continue

                # Create ModelOutput
                loaded_outputs.append(ModelOutput(
                    model_type=output_dict['model_type'],
                    model_name=output_dict['model_name'],
                    symbol=output_dict['symbol'],
                    timestamp=datetime.fromisoformat(output_dict['timestamp']),
                    confidence=output_dict['confidence'],
                    predictions=output_dict['predictions'],
                    hidden_states={},  # Don't load hidden states from disk
                    metadata=output_dict.get('metadata', {})
                ))

            except Exception as e:
                logger.error(f"Error loading output file {filepath}: {e}")

        # Insert directly into the in-memory stores instead of going
        # through store_output: store_output persists every output it
        # receives, which would write a duplicate file for each output
        # loaded and grow the cache on every load.
        with self.outputs_lock:
            for output in loaded_outputs:
                self.current_outputs.setdefault(output.symbol, {})[output.model_name] = output

                history = self.historical_outputs.setdefault(output.symbol, {}).setdefault(output.model_name, [])
                history.append(output)

                # Keep the history bounded, dropping the oldest entries
                overflow = len(history) - self.max_history
                if overflow > 0:
                    del history[:overflow]

        loaded_count = len(loaded_outputs)
        logger.info(f"Loaded {loaded_count} model outputs from disk")
        return loaded_count

    except Exception as e:
        logger.error(f"Error loading outputs from disk: {e}")
        return 0
def cleanup_old_outputs(self, max_age_days: int = 30) -> int:
    """
    Clean up old output files.

    Deletes the ``*.json`` files directly inside ``self.cache_dir`` whose
    modification time is older than ``max_age_days``. In-memory outputs are
    not touched.

    Args:
        max_age_days: Maximum age of files to keep in days

    Returns:
        int: Number of files deleted. Files that cannot be inspected or
        removed are logged and skipped; 0 is returned if the cleanup fails
        as a whole.
    """
    try:
        # Find all output files
        import glob

        # Escape the directory so glob metacharacters in its path are
        # treated literally
        search_pattern = os.path.join(glob.escape(os.fspath(self.cache_dir)), "*.json")
        output_files = glob.glob(search_pattern)

        if not output_files:
            return 0

        # Files last modified before this epoch timestamp are deleted
        cutoff_time = time.time() - (max_age_days * 24 * 60 * 60)

        # Delete old files
        deleted_count = 0
        for filepath in output_files:
            try:
                if os.path.getmtime(filepath) < cutoff_time:
                    os.remove(filepath)
                    deleted_count += 1
            except Exception as e:
                # One unreadable or undeletable file must not stop the sweep
                logger.error(f"Error deleting file {filepath}: {e}")

        logger.info(f"Deleted {deleted_count} old model output files")
        return deleted_count

    except Exception as e:
        logger.error(f"Error cleaning up old outputs: {e}")
        return 0