gogo2/core/model_output_manager.py

"""
Model Output Manager
This module provides extensible model output storage and management for the multi-modal trading system.
Supports CNN, RL, LSTM, Transformer, and future model types with cross-model feeding capabilities.
"""
import logging
import json
import pickle
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Union
from collections import deque, defaultdict
from threading import Lock
from pathlib import Path
from .data_models import ModelOutput, create_model_output
logger = logging.getLogger(__name__)

class ModelOutputManager:
    """
    Extensible model output storage and management system

    Features:
    - Standardized ModelOutput storage for all model types
    - Cross-model feeding with hidden states
    - Historical output tracking
    - Metadata management
    - Persistence and recovery
    - Performance analytics
    """

    def __init__(self, cache_dir: str = "cache/model_outputs", max_history: int = 1000):
        """
        Initialize the model output manager

        Args:
            cache_dir: Directory for persistent storage
            max_history: Maximum number of outputs to keep in memory per model
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.max_history = max_history

        # In-memory storage
        self.current_outputs: Dict[str, Dict[str, ModelOutput]] = defaultdict(dict)  # {symbol: {model_name: ModelOutput}}
        self.output_history: Dict[str, Dict[str, deque]] = defaultdict(
            lambda: defaultdict(lambda: deque(maxlen=max_history))
        )  # {symbol: {model_name: deque}}
        self.cross_model_states: Dict[str, Dict[str, Dict[str, Any]]] = defaultdict(
            lambda: defaultdict(dict)
        )  # {symbol: {model_name: hidden_states}}

        # Metadata tracking
        self.model_metadata: Dict[str, Dict[str, Any]] = defaultdict(dict)  # {model_name: metadata}
        self.performance_stats: Dict[str, Dict[str, Any]] = defaultdict(lambda: defaultdict(dict))  # {symbol: {model_name: stats}}

        # Thread safety
        self.storage_lock = Lock()

        # Supported model types
        self.supported_model_types = {
            'cnn', 'rl', 'lstm', 'transformer', 'orchestrator',
            'ensemble', 'hybrid', 'custom'  # Extensible for future types
        }

        logger.info(f"ModelOutputManager initialized with cache dir: {self.cache_dir}")
        logger.info(f"Supported model types: {self.supported_model_types}")

    def store_output(self, model_output: ModelOutput) -> bool:
        """
        Store model output with full extensibility support

        Args:
            model_output: ModelOutput from any model type

        Returns:
            bool: True if stored successfully, False otherwise
        """
        try:
            with self.storage_lock:
                symbol = model_output.symbol
                model_name = model_output.model_name
                model_type = model_output.model_type

                # Validate model type (extensible)
                if model_type not in self.supported_model_types:
                    logger.warning(f"Unknown model type '{model_type}' - adding to supported types")
                    self.supported_model_types.add(model_type)

                # Store current output
                self.current_outputs[symbol][model_name] = model_output

                # Add to history
                self.output_history[symbol][model_name].append(model_output)

                # Store cross-model states if available
                if model_output.hidden_states:
                    self.cross_model_states[symbol][model_name] = model_output.hidden_states

                # Update model metadata
                self._update_model_metadata(model_name, model_type, model_output.metadata)

                # Update performance statistics
                self._update_performance_stats(symbol, model_name, model_output)

                # Persist to disk (the helper is currently synchronous despite its name)
                self._persist_output_async(model_output)

                logger.debug(f"Stored output from {model_name} ({model_type}) for {symbol}")
                return True

        except Exception as e:
            logger.error(f"Error storing model output: {e}")
            return False

    def get_current_output(self, symbol: str, model_name: str) -> Optional[ModelOutput]:
        """
        Get the current (latest) output from a specific model

        Args:
            symbol: Trading symbol
            model_name: Name of the model

        Returns:
            ModelOutput: Latest output from the model, or None if not available
        """
        try:
            return self.current_outputs.get(symbol, {}).get(model_name)
        except Exception as e:
            logger.error(f"Error getting current output for {model_name}: {e}")
            return None

    def get_all_current_outputs(self, symbol: str) -> Dict[str, ModelOutput]:
        """
        Get all current outputs for a symbol (for cross-model feeding)

        Args:
            symbol: Trading symbol

        Returns:
            Dict[str, ModelOutput]: Dictionary of current outputs by model name
        """
        try:
            return dict(self.current_outputs.get(symbol, {}))
        except Exception as e:
            logger.error(f"Error getting all current outputs for {symbol}: {e}")
            return {}

    def get_output_history(self, symbol: str, model_name: str, count: int = 10) -> List[ModelOutput]:
        """
        Get historical outputs from a model

        Args:
            symbol: Trading symbol
            model_name: Name of the model
            count: Number of historical outputs to retrieve

        Returns:
            List[ModelOutput]: List of historical outputs (most recent first)
        """
        try:
            history = self.output_history.get(symbol, {}).get(model_name, deque())
            return list(history)[-count:][::-1]  # Most recent first
        except Exception as e:
            logger.error(f"Error getting output history for {model_name}: {e}")
            return []
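
    # Cross-model feeding: the method below returns the latest hidden states keyed by
    # model name, excluding the requesting model so it never feeds on its own states.
    # Illustrative return value (the inner keys depend on what each model publishes):
    #   {'cnn_v1': {'feature_vector': [...]}, 'rl_agent': {'q_values': [...]}}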

    def get_cross_model_states(self, symbol: str, requesting_model: str) -> Dict[str, Dict[str, Any]]:
        """
        Get hidden states from other models for cross-model feeding

        Args:
            symbol: Trading symbol
            requesting_model: Name of the model requesting the states

        Returns:
            Dict[str, Dict[str, Any]]: Hidden states from other models
        """
        try:
            all_states = self.cross_model_states.get(symbol, {})
            # Return states from all models except the requesting one
            return {model_name: states for model_name, states in all_states.items()
                    if model_name != requesting_model}
        except Exception as e:
            logger.error(f"Error getting cross-model states for {requesting_model}: {e}")
            return {}

    def get_model_types_active(self, symbol: str) -> List[str]:
        """
        Get list of active model types for a symbol

        Args:
            symbol: Trading symbol

        Returns:
            List[str]: List of active model types
        """
        try:
            current_outputs = self.current_outputs.get(symbol, {})
            return [output.model_type for output in current_outputs.values()]
        except Exception as e:
            logger.error(f"Error getting active model types for {symbol}: {e}")
            return []
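
    # Worked example for the consensus logic below (illustrative): with
    # confidence_threshold=0.5 and three qualifying outputs voting BUY, BUY, HOLD,
    # buy_votes=2 beats both sell_votes=0 and hold_votes=1, so the consensus action
    # is 'BUY' with the mean confidence of those three outputs. Ties (no strict
    # majority for BUY or SELL) fall back to 'HOLD'.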

    def get_consensus_prediction(self, symbol: str, confidence_threshold: float = 0.5) -> Optional[Dict[str, Any]]:
        """
        Get consensus prediction from all active models

        Args:
            symbol: Trading symbol
            confidence_threshold: Minimum confidence threshold for inclusion

        Returns:
            Dict containing consensus prediction or None
        """
        try:
            current_outputs = self.current_outputs.get(symbol, {})
            if not current_outputs:
                return None

            # Filter by confidence threshold
            high_confidence_outputs = [
                output for output in current_outputs.values()
                if output.confidence >= confidence_threshold
            ]
            if not high_confidence_outputs:
                return None

            # Calculate consensus
            buy_votes = sum(1 for output in high_confidence_outputs
                            if output.predictions.get('action') == 'BUY')
            sell_votes = sum(1 for output in high_confidence_outputs
                             if output.predictions.get('action') == 'SELL')
            hold_votes = sum(1 for output in high_confidence_outputs
                             if output.predictions.get('action') == 'HOLD')

            total_votes = len(high_confidence_outputs)
            avg_confidence = sum(output.confidence for output in high_confidence_outputs) / total_votes

            # Determine consensus action
            if buy_votes > sell_votes and buy_votes > hold_votes:
                consensus_action = 'BUY'
            elif sell_votes > buy_votes and sell_votes > hold_votes:
                consensus_action = 'SELL'
            else:
                consensus_action = 'HOLD'

            return {
                'action': consensus_action,
                'confidence': avg_confidence,
                'votes': {'BUY': buy_votes, 'SELL': sell_votes, 'HOLD': hold_votes},
                'total_models': total_votes,
                'model_types': [output.model_type for output in high_confidence_outputs]
            }
        except Exception as e:
            logger.error(f"Error calculating consensus prediction for {symbol}: {e}")
            return None

    def _update_model_metadata(self, model_name: str, model_type: str, metadata: Dict[str, Any]):
        """Update metadata for a model"""
        try:
            if model_name not in self.model_metadata:
                self.model_metadata[model_name] = {
                    'model_type': model_type,
                    'first_seen': datetime.now(),
                    'total_predictions': 0,
                    'custom_metadata': {}
                }

            self.model_metadata[model_name]['total_predictions'] += 1
            self.model_metadata[model_name]['last_seen'] = datetime.now()

            # Merge custom metadata
            if metadata:
                self.model_metadata[model_name]['custom_metadata'].update(metadata)
        except Exception as e:
            logger.error(f"Error updating model metadata: {e}")

    def _update_performance_stats(self, symbol: str, model_name: str, model_output: ModelOutput):
        """Update performance statistics for a model"""
        try:
            stats = self.performance_stats[symbol][model_name]
            if 'prediction_count' not in stats:
                stats['prediction_count'] = 0
                stats['confidence_sum'] = 0.0
                stats['action_counts'] = {'BUY': 0, 'SELL': 0, 'HOLD': 0}
                stats['first_prediction'] = model_output.timestamp

            stats['prediction_count'] += 1
            stats['confidence_sum'] += model_output.confidence
            stats['avg_confidence'] = stats['confidence_sum'] / stats['prediction_count']
            stats['last_prediction'] = model_output.timestamp

            action = model_output.predictions.get('action', 'HOLD')
            if action in stats['action_counts']:
                stats['action_counts'][action] += 1
        except Exception as e:
            logger.error(f"Error updating performance stats: {e}")

    def _persist_output_async(self, model_output: ModelOutput):
        """Persist model output to disk (simplified version)"""
        try:
            # Create filename based on model and timestamp
            timestamp_str = model_output.timestamp.strftime("%Y%m%d_%H%M%S")
            filename = f"{model_output.model_name}_{model_output.symbol.replace('/', '_')}_{timestamp_str}.json"
            filepath = self.cache_dir / filename

            # Convert to JSON-serializable format
            output_dict = {
                'model_type': model_output.model_type,
                'model_name': model_output.model_name,
                'symbol': model_output.symbol,
                'timestamp': model_output.timestamp.isoformat(),
                'confidence': model_output.confidence,
                'predictions': model_output.predictions,
                'metadata': model_output.metadata
            }

            # Save to file (in a real implementation, this would be async)
            with open(filepath, 'w') as f:
                json.dump(output_dict, f, indent=2)
        except Exception as e:
            logger.error(f"Error persisting model output: {e}")

    def get_performance_summary(self, symbol: str) -> Dict[str, Any]:
        """
        Get performance summary for all models for a symbol

        Args:
            symbol: Trading symbol

        Returns:
            Dict containing performance summary
        """
        try:
            summary = {
                'symbol': symbol,
                'active_models': len(self.current_outputs.get(symbol, {})),
                'model_stats': {}
            }

            for model_name, stats in self.performance_stats.get(symbol, {}).items():
                summary['model_stats'][model_name] = {
                    'predictions': stats.get('prediction_count', 0),
                    'avg_confidence': round(stats.get('avg_confidence', 0.0), 3),
                    'action_distribution': stats.get('action_counts', {}),
                    'model_type': self.model_metadata.get(model_name, {}).get('model_type', 'unknown')
                }

            return summary
        except Exception as e:
            logger.error(f"Error getting performance summary: {e}")
            return {'symbol': symbol, 'error': str(e)}

    def cleanup_old_outputs(self, max_age_hours: int = 24):
        """
        Clean up old outputs to manage memory usage

        Args:
            max_age_hours: Maximum age of outputs to keep in hours
        """
        try:
            cutoff_time = datetime.now() - timedelta(hours=max_age_hours)

            with self.storage_lock:
                for symbol in self.output_history:
                    for model_name in self.output_history[symbol]:
                        history = self.output_history[symbol][model_name]
                        # Remove old outputs
                        while history and history[0].timestamp < cutoff_time:
                            history.popleft()

            logger.info(f"Cleaned up outputs older than {max_age_hours} hours")
        except Exception as e:
            logger.error(f"Error cleaning up old outputs: {e}")

    def add_custom_model_type(self, model_type: str):
        """
        Add support for a new custom model type

        Args:
            model_type: Name of the new model type
        """
        self.supported_model_types.add(model_type)
        logger.info(f"Added support for custom model type: {model_type}")

    def get_supported_model_types(self) -> List[str]:
        """Get list of all supported model types"""
        return list(self.supported_model_types)
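

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). The ModelOutput
# keyword arguments below are assumptions inferred from the fields this manager
# reads (model_type, model_name, symbol, timestamp, confidence, predictions,
# hidden_states, metadata); the real dataclass in .data_models may differ.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    manager = ModelOutputManager(cache_dir="cache/model_outputs", max_history=100)

    # Hypothetical output from a CNN model for ETH/USDT
    output = ModelOutput(
        model_type='cnn',
        model_name='cnn_v1',
        symbol='ETH/USDT',
        timestamp=datetime.now(),
        confidence=0.72,
        predictions={'action': 'BUY'},
        hidden_states={'feature_vector': [0.1, 0.2, 0.3]},
        metadata={'checkpoint': 'demo'}
    )
    manager.store_output(output)

    # Latest output, consensus across models, and per-model stats
    print(manager.get_current_output('ETH/USDT', 'cnn_v1'))
    print(manager.get_consensus_prediction('ETH/USDT', confidence_threshold=0.5))
    print(manager.get_performance_summary('ETH/USDT'))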