"""
Model Output Manager

This module provides extensible model output storage and management for the
multi-modal trading system. It supports CNN, RL, LSTM, Transformer, and future
model types, with cross-model feeding capabilities.
"""
|
|
|
|
import logging
|
|
import json
|
|
import pickle
|
|
from datetime import datetime, timedelta
|
|
from typing import Dict, List, Optional, Any, Union
|
|
from collections import deque, defaultdict
|
|
from threading import Lock
|
|
from pathlib import Path
|
|
|
|
from .data_models import ModelOutput, create_model_output
|
|
|
|
# Module-scoped logger; it is registered under this module's import path (__name__).
logger = logging.getLogger(name=__name__)
|
|
|
|
class ModelOutputManager:
    """
    Extensible model output storage and management system.

    Features:
    - Standardized ModelOutput storage for all model types
    - Cross-model feeding with hidden states
    - Historical output tracking (bounded per symbol/model by ``max_history``)
    - Metadata management
    - Persistence of each output as a JSON file under ``cache_dir``
    - Performance analytics

    Thread safety: every public method that reads or mutates the in-memory
    stores acquires ``storage_lock``. The lock is a plain (non-reentrant)
    ``Lock``, so the private ``_update_*`` helpers assume the caller already
    holds it and never acquire it themselves.
    """

    def __init__(self, cache_dir: str = "cache/model_outputs", max_history: int = 1000):
        """
        Initialize the model output manager.

        Args:
            cache_dir: Directory for persistent storage (created if missing).
            max_history: Maximum number of outputs kept in memory per
                (symbol, model) pair; the oldest entries are evicted first.
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.max_history = max_history

        # In-memory storage
        # {symbol: {model_name: ModelOutput}} -- latest output per model
        self.current_outputs: Dict[str, Dict[str, 'ModelOutput']] = defaultdict(dict)
        # {symbol: {model_name: deque}} -- bounded history, oldest first
        self.output_history: Dict[str, Dict[str, deque]] = defaultdict(
            lambda: defaultdict(lambda: deque(maxlen=max_history))
        )
        # {symbol: {model_name: hidden_states}} -- last non-empty hidden states per model
        self.cross_model_states: Dict[str, Dict[str, Dict[str, Any]]] = defaultdict(
            lambda: defaultdict(dict)
        )

        # Metadata tracking
        self.model_metadata: Dict[str, Dict[str, Any]] = defaultdict(dict)  # {model_name: metadata}
        self.performance_stats: Dict[str, Dict[str, Any]] = defaultdict(
            lambda: defaultdict(dict)
        )  # {symbol: {model_name: stats}}

        # Thread safety: guards every store above and supported_model_types.
        self.storage_lock = Lock()

        # Supported model types; unknown types are added on first sight.
        self.supported_model_types = {
            'cnn', 'rl', 'lstm', 'transformer', 'orchestrator',
            'ensemble', 'hybrid', 'custom'  # Extensible for future types
        }

        logger.info(f"ModelOutputManager initialized with cache dir: {self.cache_dir}")
        logger.info(f"Supported model types: {self.supported_model_types}")

    def store_output(self, model_output: 'ModelOutput') -> bool:
        """
        Store a model output in memory and persist it to disk.

        Args:
            model_output: ModelOutput from any model type.

        Returns:
            bool: True if the output was stored in memory, False if an error
            occurred. Disk persistence is best-effort: a failed write is
            logged but does not change the return value.
        """
        try:
            with self.storage_lock:
                symbol = model_output.symbol
                model_name = model_output.model_name
                model_type = model_output.model_type

                # Validate model type (extensible)
                if model_type not in self.supported_model_types:
                    logger.warning(f"Unknown model type '{model_type}' - adding to supported types")
                    self.supported_model_types.add(model_type)

                # Store current output and append to the bounded history
                self.current_outputs[symbol][model_name] = model_output
                self.output_history[symbol][model_name].append(model_output)

                # Only overwrite hidden states when the new output carries some;
                # otherwise the model's previous states stay available to others.
                if model_output.hidden_states:
                    self.cross_model_states[symbol][model_name] = model_output.hidden_states

                self._update_model_metadata(model_name, model_type, model_output.metadata)
                self._update_performance_stats(symbol, model_name, model_output)

            # Disk I/O runs after the lock is released so a slow write does not
            # block other threads reading or storing in-memory state.
            self._persist_output_async(model_output)

            logger.debug(f"Stored output from {model_name} ({model_type}) for {symbol}")
            return True

        except Exception as e:
            logger.error(f"Error storing model output: {e}")
            return False

    def get_current_output(self, symbol: str, model_name: str) -> Optional['ModelOutput']:
        """
        Get the current (latest) output from a specific model.

        Args:
            symbol: Trading symbol
            model_name: Name of the model

        Returns:
            ModelOutput: Latest output from the model, or None if not available
        """
        try:
            with self.storage_lock:
                return self.current_outputs.get(symbol, {}).get(model_name)
        except Exception as e:
            logger.error(f"Error getting current output for {model_name}: {e}")
            return None

    def get_all_current_outputs(self, symbol: str) -> Dict[str, 'ModelOutput']:
        """
        Get all current outputs for a symbol (for cross-model feeding).

        Args:
            symbol: Trading symbol

        Returns:
            Dict[str, ModelOutput]: Shallow copy of the current outputs keyed
            by model name (safe to iterate while other threads store outputs).
        """
        try:
            with self.storage_lock:
                return dict(self.current_outputs.get(symbol, {}))
        except Exception as e:
            logger.error(f"Error getting all current outputs for {symbol}: {e}")
            return {}

    def get_output_history(self, symbol: str, model_name: str, count: int = 10) -> List['ModelOutput']:
        """
        Get historical outputs from a model.

        Args:
            symbol: Trading symbol
            model_name: Name of the model
            count: Maximum number of historical outputs to retrieve; a
                non-positive value returns an empty list.

        Returns:
            List[ModelOutput]: Up to ``count`` outputs, most recent first.
        """
        try:
            # [-0:] would select the whole history, so guard explicitly.
            if count <= 0:
                return []
            with self.storage_lock:
                history = self.output_history.get(symbol, {}).get(model_name)
                # Snapshot under the lock: iterating a deque that another
                # thread appends to raises RuntimeError.
                snapshot = list(history) if history else []
            return snapshot[-count:][::-1]  # Most recent first
        except Exception as e:
            logger.error(f"Error getting output history for {model_name}: {e}")
            return []

    def get_cross_model_states(self, symbol: str, requesting_model: str) -> Dict[str, Dict[str, Any]]:
        """
        Get hidden states from other models for cross-model feeding.

        Args:
            symbol: Trading symbol
            requesting_model: Name of the model requesting the states

        Returns:
            Dict[str, Dict[str, Any]]: Hidden states keyed by model name,
            excluding the requesting model's own states.
        """
        try:
            with self.storage_lock:
                all_states = self.cross_model_states.get(symbol, {})
                return {model_name: states for model_name, states in all_states.items()
                        if model_name != requesting_model}
        except Exception as e:
            logger.error(f"Error getting cross-model states for {requesting_model}: {e}")
            return {}

    def get_model_types_active(self, symbol: str) -> List[str]:
        """
        Get the model types of all models with a current output for a symbol.

        Args:
            symbol: Trading symbol

        Returns:
            List[str]: One entry per model (may contain duplicates when
            several models share a type).
        """
        try:
            with self.storage_lock:
                return [output.model_type for output in self.current_outputs.get(symbol, {}).values()]
        except Exception as e:
            logger.error(f"Error getting active model types for {symbol}: {e}")
            return []

    def get_consensus_prediction(self, symbol: str, confidence_threshold: float = 0.5) -> Optional[Dict[str, Any]]:
        """
        Get a majority-vote consensus prediction from all current outputs.

        Only outputs whose confidence is >= ``confidence_threshold`` vote. An
        action wins only with a strict plurality over both other actions;
        ties fall back to 'HOLD'. Outputs whose action is not BUY/SELL/HOLD
        still count toward ``total_models`` and the average confidence.

        Args:
            symbol: Trading symbol
            confidence_threshold: Minimum confidence threshold for inclusion

        Returns:
            Dict containing the consensus prediction, or None when no output
            meets the threshold (or an error occurs).
        """
        try:
            with self.storage_lock:
                outputs = list(self.current_outputs.get(symbol, {}).values())
            if not outputs:
                return None

            # Filter by confidence threshold
            high_confidence_outputs = [
                output for output in outputs
                if output.confidence >= confidence_threshold
            ]
            if not high_confidence_outputs:
                return None

            votes = {'BUY': 0, 'SELL': 0, 'HOLD': 0}
            for output in high_confidence_outputs:
                action = output.predictions.get('action')
                # Tuple membership compares with ==, so odd action values never raise.
                if action in ('BUY', 'SELL', 'HOLD'):
                    votes[action] += 1
            buy_votes, sell_votes, hold_votes = votes['BUY'], votes['SELL'], votes['HOLD']

            total_votes = len(high_confidence_outputs)
            avg_confidence = sum(output.confidence for output in high_confidence_outputs) / total_votes

            # Determine consensus action
            if buy_votes > sell_votes and buy_votes > hold_votes:
                consensus_action = 'BUY'
            elif sell_votes > buy_votes and sell_votes > hold_votes:
                consensus_action = 'SELL'
            else:
                consensus_action = 'HOLD'

            return {
                'action': consensus_action,
                'confidence': avg_confidence,
                'votes': {'BUY': buy_votes, 'SELL': sell_votes, 'HOLD': hold_votes},
                'total_models': total_votes,
                'model_types': [output.model_type for output in high_confidence_outputs]
            }

        except Exception as e:
            logger.error(f"Error calculating consensus prediction for {symbol}: {e}")
            return None

    def _update_model_metadata(self, model_name: str, model_type: str, metadata: Dict[str, Any]):
        """Update per-model metadata. Caller must hold ``storage_lock``.

        The first call for a model records its type and first-seen time;
        every call bumps the prediction counter and merges ``metadata`` into
        ``custom_metadata`` (later keys overwrite earlier ones).
        """
        try:
            if model_name not in self.model_metadata:
                self.model_metadata[model_name] = {
                    'model_type': model_type,
                    'first_seen': datetime.now(),
                    'total_predictions': 0,
                    'custom_metadata': {}
                }

            entry = self.model_metadata[model_name]
            entry['total_predictions'] += 1
            entry['last_seen'] = datetime.now()

            # Merge custom metadata
            if metadata:
                entry['custom_metadata'].update(metadata)

        except Exception as e:
            logger.error(f"Error updating model metadata: {e}")

    def _update_performance_stats(self, symbol: str, model_name: str, model_output: 'ModelOutput'):
        """Update running statistics for a (symbol, model) pair. Caller must hold ``storage_lock``.

        Tracks the prediction count, a running average confidence, the
        first/last prediction timestamps, and counts of BUY/SELL/HOLD actions
        (a missing action counts as HOLD; other values are not counted).
        """
        try:
            stats = self.performance_stats[symbol][model_name]

            if 'prediction_count' not in stats:
                stats['prediction_count'] = 0
                stats['confidence_sum'] = 0.0
                stats['action_counts'] = {'BUY': 0, 'SELL': 0, 'HOLD': 0}
                stats['first_prediction'] = model_output.timestamp

            stats['prediction_count'] += 1
            stats['confidence_sum'] += model_output.confidence
            stats['avg_confidence'] = stats['confidence_sum'] / stats['prediction_count']
            stats['last_prediction'] = model_output.timestamp

            action = model_output.predictions.get('action', 'HOLD')
            if action in stats['action_counts']:
                stats['action_counts'][action] += 1

        except Exception as e:
            logger.error(f"Error updating performance stats: {e}")

    @staticmethod
    def _filename_safe(text: Any) -> str:
        """Replace path separators so *text* can be embedded in a single file name."""
        return str(text).replace('/', '_').replace('\\', '_')

    def _persist_output_async(self, model_output: 'ModelOutput'):
        """Persist a model output as a JSON file in ``cache_dir``.

        Despite the name, the write is synchronous. Persistence is
        best-effort: any error is logged and swallowed. Hidden states are
        not persisted.
        """
        try:
            # Microsecond resolution keeps outputs produced within the same
            # second from overwriting each other's files.
            timestamp_str = model_output.timestamp.strftime("%Y%m%d_%H%M%S_%f")
            model_part = self._filename_safe(model_output.model_name)
            symbol_part = self._filename_safe(model_output.symbol)
            filepath = self.cache_dir / f"{model_part}_{symbol_part}_{timestamp_str}.json"

            # Convert to JSON-serializable format
            output_dict = {
                'model_type': model_output.model_type,
                'model_name': model_output.model_name,
                'symbol': model_output.symbol,
                'timestamp': model_output.timestamp.isoformat(),
                'confidence': model_output.confidence,
                'predictions': model_output.predictions,
                'metadata': model_output.metadata
            }

            # Serialize before touching the file so a non-serializable value
            # cannot leave a truncated file behind.
            payload = json.dumps(output_dict, indent=2)
            filepath.write_text(payload, encoding='utf-8')

        except Exception as e:
            logger.error(f"Error persisting model output: {e}")

    def get_performance_summary(self, symbol: str) -> Dict[str, Any]:
        """
        Get a performance summary for all models of a symbol.

        Args:
            symbol: Trading symbol

        Returns:
            Dict with 'symbol', 'active_models' and per-model 'model_stats'.
            The returned structures are copies, so mutating them does not
            affect the manager's internal statistics. On error, returns
            {'symbol': symbol, 'error': <message>}.
        """
        try:
            with self.storage_lock:
                summary = {
                    'symbol': symbol,
                    'active_models': len(self.current_outputs.get(symbol, {})),
                    'model_stats': {}
                }

                for model_name, stats in self.performance_stats.get(symbol, {}).items():
                    summary['model_stats'][model_name] = {
                        'predictions': stats.get('prediction_count', 0),
                        'avg_confidence': round(stats.get('avg_confidence', 0.0), 3),
                        # Copy: the live dict is updated on every store_output.
                        'action_distribution': dict(stats.get('action_counts', {})),
                        'model_type': self.model_metadata.get(model_name, {}).get('model_type', 'unknown')
                    }

            return summary

        except Exception as e:
            logger.error(f"Error getting performance summary: {e}")
            return {'symbol': symbol, 'error': str(e)}

    def cleanup_old_outputs(self, max_age_hours: int = 24):
        """
        Drop history entries older than ``max_age_hours`` to manage memory.

        Only ``output_history`` is pruned, from the oldest end; current
        outputs and cross-model hidden states are left untouched.

        Args:
            max_age_hours: Maximum age of outputs to keep in hours
        """
        try:
            cutoff_time = datetime.now() - timedelta(hours=max_age_hours)

            with self.storage_lock:
                for per_model in self.output_history.values():
                    for history in per_model.values():
                        # History is appended in arrival order, so stop at the
                        # first entry that is recent enough.
                        while history and history[0].timestamp < cutoff_time:
                            history.popleft()

            logger.info(f"Cleaned up outputs older than {max_age_hours} hours")

        except Exception as e:
            logger.error(f"Error cleaning up old outputs: {e}")

    def add_custom_model_type(self, model_type: str):
        """
        Add support for a new custom model type.

        Args:
            model_type: Name of the new model type
        """
        with self.storage_lock:
            self.supported_model_types.add(model_type)
        logger.info(f"Added support for custom model type: {model_type}")

    def get_supported_model_types(self) -> List[str]:
        """Get a snapshot list of all supported model types (unordered)."""
        with self.storage_lock:
            return list(self.supported_model_types)