wip

core/data_models.py (new file, 232 lines)
@@ -0,0 +1,232 @@
"""
Standardized Data Models for Multi-Modal Trading System

This module defines the standardized data structures used across all models:
- BaseDataInput: Unified input format for all models (CNN, RL, LSTM, Transformer)
- ModelOutput: Extensible output format supporting all model types
- COBData: Cumulative Order Book data structure
- Enhanced data structures for cross-model feeding and extensibility
"""

import numpy as np
from datetime import datetime
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field


@dataclass
class OHLCVBar:
    """OHLCV bar data structure"""
    symbol: str
    timestamp: datetime
    open: float
    high: float
    low: float
    close: float
    volume: float
    timeframe: str
    indicators: Dict[str, float] = field(default_factory=dict)


@dataclass
class PivotPoint:
    """Pivot point data structure"""
    symbol: str
    timestamp: datetime
    price: float
    type: str  # 'high' or 'low'
    level: int  # Pivot level (1, 2, 3, etc.)
    confidence: float = 1.0

@dataclass
class ModelOutput:
    """Extensible model output format supporting all model types"""
    model_type: str  # 'cnn', 'rl', 'lstm', 'transformer', 'orchestrator'
    model_name: str  # Specific model identifier
    symbol: str
    timestamp: datetime
    confidence: float
    predictions: Dict[str, Any]  # Model-specific predictions
    hidden_states: Optional[Dict[str, Any]] = None  # For cross-model feeding
    metadata: Dict[str, Any] = field(default_factory=dict)  # Additional info

@dataclass
class COBData:
    """Cumulative Order Book data for price buckets"""
    symbol: str
    timestamp: datetime
    current_price: float
    bucket_size: float  # $1 for ETH, $10 for BTC
    price_buckets: Dict[float, Dict[str, float]]  # price -> {bid_volume, ask_volume, etc.}
    bid_ask_imbalance: Dict[float, float]  # price -> imbalance ratio
    volume_weighted_prices: Dict[float, float]  # price -> VWAP within bucket
    order_flow_metrics: Dict[str, float]  # Various order flow indicators

    # Moving averages of COB imbalance for ±5 buckets
    ma_1s_imbalance: Dict[float, float] = field(default_factory=dict)   # 1s MA
    ma_5s_imbalance: Dict[float, float] = field(default_factory=dict)   # 5s MA
    ma_15s_imbalance: Dict[float, float] = field(default_factory=dict)  # 15s MA
    ma_60s_imbalance: Dict[float, float] = field(default_factory=dict)  # 60s MA

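# Illustrative bucket layout (added note; the values are assumptions, not live
# data): for ETH/USDT trading near 3000.0 with bucket_size=1.0, price_buckets
# holds the 41 keys 2980.0 ... 3020.0 (±20 buckets), each mapping to a dict
# such as {'bid_volume': 120.0, 'ask_volume': 95.0, 'total_volume': 215.0,
# 'imbalance': 0.116}, while bid_ask_imbalance maps the same keys to the
# signed imbalance ratio in [-1.0, 1.0].
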
@dataclass
class BaseDataInput:
    """
    Unified base data input for all models

    Standardized format ensures all models receive identical input structure:
    - OHLCV: 300 frames of (1s, 1m, 1h, 1d) ETH + 300s of 1s BTC
    - COB: ±20 buckets of COB amounts in USD for each 1s OHLCV
    - MA: 1s, 5s, 15s, and 60s MA of COB imbalance counting ±5 COB buckets
    """
    symbol: str  # Primary symbol (ETH/USDT)
    timestamp: datetime

    # Multi-timeframe OHLCV data for primary symbol (ETH)
    ohlcv_1s: List[OHLCVBar] = field(default_factory=list)  # 300 frames of 1s data
    ohlcv_1m: List[OHLCVBar] = field(default_factory=list)  # 300 frames of 1m data
    ohlcv_1h: List[OHLCVBar] = field(default_factory=list)  # 300 frames of 1h data
    ohlcv_1d: List[OHLCVBar] = field(default_factory=list)  # 300 frames of 1d data

    # Reference symbol (BTC) 1s data
    btc_ohlcv_1s: List[OHLCVBar] = field(default_factory=list)  # 300s of 1s BTC data

    # COB data for 1s timeframe (±20 buckets around current price)
    cob_data: Optional[COBData] = None

    # Technical indicators
    technical_indicators: Dict[str, float] = field(default_factory=dict)

    # Pivot points from Williams Market Structure
    pivot_points: List[PivotPoint] = field(default_factory=list)

    # Last predictions from all models (for cross-model feeding)
    last_predictions: Dict[str, ModelOutput] = field(default_factory=dict)

    # Market microstructure data
    market_microstructure: Dict[str, Any] = field(default_factory=dict)

    def get_feature_vector(self) -> np.ndarray:
        """
        Convert BaseDataInput to a standardized feature vector for models

        Returns:
            np.ndarray: Standardized feature vector combining all data sources
        """
        features = []

        # OHLCV features for ETH (up to 300 frames x 4 timeframes x 5 features = 6000 features)
        for ohlcv_list in [self.ohlcv_1s, self.ohlcv_1m, self.ohlcv_1h, self.ohlcv_1d]:
            for bar in ohlcv_list[-300:]:  # Cap at the most recent 300 frames
                features.extend([bar.open, bar.high, bar.low, bar.close, bar.volume])

        # BTC OHLCV features (up to 300 frames x 5 features = 1500 features)
        for bar in self.btc_ohlcv_1s[-300:]:  # Cap at the most recent 300 frames
            features.extend([bar.open, bar.high, bar.low, bar.close, bar.volume])

        # COB features (±20 buckets x 4 metrics = 164 features)
        if self.cob_data:
            # Price bucket features
            for price in sorted(self.cob_data.price_buckets.keys()):
                bucket_data = self.cob_data.price_buckets[price]
                features.extend([
                    bucket_data.get('bid_volume', 0.0),
                    bucket_data.get('ask_volume', 0.0),
                    bucket_data.get('total_volume', 0.0),
                    bucket_data.get('imbalance', 0.0)
                ])

            # Moving averages of imbalance (4 MAs x 5 lowest-priced buckets = 20 features)
            for ma_dict in [self.cob_data.ma_1s_imbalance, self.cob_data.ma_5s_imbalance,
                            self.cob_data.ma_15s_imbalance, self.cob_data.ma_60s_imbalance]:
                for price in sorted(ma_dict.keys())[:5]:  # Sort before slicing for determinism
                    features.append(ma_dict[price])

        # Technical indicators (variable, pad to 100 features)
        indicator_values = list(self.technical_indicators.values())
        features.extend(indicator_values[:100])  # Take first 100 indicators
        features.extend([0.0] * max(0, 100 - len(indicator_values)))  # Pad if needed

        # Last predictions from other models (variable, pad to 50 features)
        prediction_features = []
        for model_output in self.last_predictions.values():
            prediction_features.extend([
                model_output.confidence,
                model_output.predictions.get('buy_probability', 0.0),
                model_output.predictions.get('sell_probability', 0.0),
                model_output.predictions.get('hold_probability', 0.0),
                model_output.predictions.get('expected_reward', 0.0)
            ])
        features.extend(prediction_features[:50])  # Take first 50 prediction features
        features.extend([0.0] * max(0, 50 - len(prediction_features)))  # Pad if needed

        return np.array(features, dtype=np.float32)

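    # Size sketch (added note, assuming fully populated inputs; shorter
    # histories or a missing cob_data yield a shorter vector):
    #   ETH OHLCV: 300 * 4 * 5           = 6000
    #   BTC OHLCV: 300 * 5               = 1500
    #   COB buckets: 41 * 4              =  164
    #   COB imbalance MAs: 4 * 5         =   20
    #   Indicators (padded)              =  100
    #   Prior model predictions (padded) =   50
    #   Total                            = 7834 features
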
    def validate(self) -> bool:
        """
        Validate that the BaseDataInput contains the required data

        Returns:
            bool: True if valid, False otherwise
        """
        # Check that we have the required OHLCV data
        if len(self.ohlcv_1s) < 100:  # At least 100 frames
            return False
        if len(self.btc_ohlcv_1s) < 100:  # At least 100 frames of BTC data
            return False

        # Check that a timestamp is present
        if not self.timestamp:
            return False

        # Check symbol format
        if not self.symbol or '/' not in self.symbol:
            return False

        return True


@dataclass
class TradingAction:
    """Trading action output from models"""
    symbol: str
    timestamp: datetime
    action: str  # 'BUY', 'SELL', 'HOLD'
    confidence: float
    source: str  # 'rl', 'cnn', 'orchestrator'
    price: Optional[float] = None
    quantity: Optional[float] = None
    reason: Optional[str] = None

def create_model_output(model_type: str, model_name: str, symbol: str,
                        action: str, confidence: float,
                        hidden_states: Optional[Dict[str, Any]] = None,
                        metadata: Optional[Dict[str, Any]] = None) -> ModelOutput:
    """
    Helper function to create a standardized ModelOutput

    Args:
        model_type: Type of model ('cnn', 'rl', 'lstm', 'transformer', 'orchestrator')
        model_name: Specific model identifier
        symbol: Trading symbol
        action: Trading action ('BUY', 'SELL', 'HOLD')
        confidence: Confidence score (0.0 to 1.0)
        hidden_states: Optional hidden states for cross-model feeding
        metadata: Optional additional metadata

    Returns:
        ModelOutput: Standardized model output
    """
    predictions = {
        'action': action,
        'buy_probability': confidence if action == 'BUY' else 0.0,
        'sell_probability': confidence if action == 'SELL' else 0.0,
        'hold_probability': confidence if action == 'HOLD' else 0.0,
    }

    return ModelOutput(
        model_type=model_type,
        model_name=model_name,
        symbol=symbol,
        timestamp=datetime.now(),
        confidence=confidence,
        predictions=predictions,
        hidden_states=hidden_states or {},
        metadata=metadata or {}
    )
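
A minimal usage sketch of these models (added example; the model name 'cnn_v1' and confidence value are assumptions, not part of the commit):

from datetime import datetime
from core.data_models import BaseDataInput, create_model_output

# Standardized output from a hypothetical CNN model
cnn_output = create_model_output(
    model_type='cnn', model_name='cnn_v1', symbol='ETH/USDT',
    action='BUY', confidence=0.73,
)

# Cross-model feeding: expose the prediction to the next inference cycle
base_input = BaseDataInput(symbol='ETH/USDT', timestamp=datetime.now())
base_input.last_predictions['cnn_v1'] = cnn_output

features = base_input.get_feature_vector()  # 1-D np.float32 vector
# With no OHLCV/COB history attached, only the padded indicator (100) and
# prediction (50) blocks are present, so len(features) == 150 here.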

core/model_output_manager.py (new file, 395 lines)
@@ -0,0 +1,395 @@
"""
Model Output Manager

This module provides extensible model output storage and management for the multi-modal trading system.
It supports CNN, RL, LSTM, Transformer, and future model types, with cross-model feeding capabilities.
"""

import logging
import json
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
from collections import deque, defaultdict
from threading import Lock
from pathlib import Path

from .data_models import ModelOutput, create_model_output

logger = logging.getLogger(__name__)


class ModelOutputManager:
    """
    Extensible model output storage and management system

    Features:
    - Standardized ModelOutput storage for all model types
    - Cross-model feeding with hidden states
    - Historical output tracking
    - Metadata management
    - Persistence and recovery
    - Performance analytics
    """

    def __init__(self, cache_dir: str = "cache/model_outputs", max_history: int = 1000):
        """
        Initialize the model output manager

        Args:
            cache_dir: Directory for persistent storage
            max_history: Maximum number of outputs to keep in memory per model
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.max_history = max_history

        # In-memory storage
        self.current_outputs: Dict[str, Dict[str, ModelOutput]] = defaultdict(dict)  # {symbol: {model_name: ModelOutput}}
        self.output_history: Dict[str, Dict[str, deque]] = defaultdict(lambda: defaultdict(lambda: deque(maxlen=max_history)))  # {symbol: {model_name: deque}}
        self.cross_model_states: Dict[str, Dict[str, Dict[str, Any]]] = defaultdict(lambda: defaultdict(dict))  # {symbol: {model_name: hidden_states}}

        # Metadata tracking
        self.model_metadata: Dict[str, Dict[str, Any]] = defaultdict(dict)  # {model_name: metadata}
        self.performance_stats: Dict[str, Dict[str, Any]] = defaultdict(lambda: defaultdict(dict))  # {symbol: {model_name: stats}}

        # Thread safety
        self.storage_lock = Lock()

        # Supported model types
        self.supported_model_types = {
            'cnn', 'rl', 'lstm', 'transformer', 'orchestrator',
            'ensemble', 'hybrid', 'custom'  # Extensible for future types
        }

        logger.info(f"ModelOutputManager initialized with cache dir: {self.cache_dir}")
        logger.info(f"Supported model types: {self.supported_model_types}")

    def store_output(self, model_output: ModelOutput) -> bool:
        """
        Store model output with full extensibility support

        Args:
            model_output: ModelOutput from any model type

        Returns:
            bool: True if stored successfully, False otherwise
        """
        try:
            with self.storage_lock:
                symbol = model_output.symbol
                model_name = model_output.model_name
                model_type = model_output.model_type

                # Validate model type (extensible)
                if model_type not in self.supported_model_types:
                    logger.warning(f"Unknown model type '{model_type}' - adding to supported types")
                    self.supported_model_types.add(model_type)

                # Store current output
                self.current_outputs[symbol][model_name] = model_output

                # Add to history
                self.output_history[symbol][model_name].append(model_output)

                # Store cross-model states if available
                if model_output.hidden_states:
                    self.cross_model_states[symbol][model_name] = model_output.hidden_states

                # Update model metadata
                self._update_model_metadata(model_name, model_type, model_output.metadata)

                # Update performance statistics
                self._update_performance_stats(symbol, model_name, model_output)

                # Persist to disk (currently synchronous; see _persist_output_async)
                self._persist_output_async(model_output)

                logger.debug(f"Stored output from {model_name} ({model_type}) for {symbol}")
                return True

        except Exception as e:
            logger.error(f"Error storing model output: {e}")
            return False

    def get_current_output(self, symbol: str, model_name: str) -> Optional[ModelOutput]:
        """
        Get the current (latest) output from a specific model

        Args:
            symbol: Trading symbol
            model_name: Name of the model

        Returns:
            ModelOutput: Latest output from the model, or None if not available
        """
        try:
            return self.current_outputs.get(symbol, {}).get(model_name)
        except Exception as e:
            logger.error(f"Error getting current output for {model_name}: {e}")
            return None

    def get_all_current_outputs(self, symbol: str) -> Dict[str, ModelOutput]:
        """
        Get all current outputs for a symbol (for cross-model feeding)

        Args:
            symbol: Trading symbol

        Returns:
            Dict[str, ModelOutput]: Dictionary of current outputs by model name
        """
        try:
            return dict(self.current_outputs.get(symbol, {}))
        except Exception as e:
            logger.error(f"Error getting all current outputs for {symbol}: {e}")
            return {}

    def get_output_history(self, symbol: str, model_name: str, count: int = 10) -> List[ModelOutput]:
        """
        Get historical outputs from a model

        Args:
            symbol: Trading symbol
            model_name: Name of the model
            count: Number of historical outputs to retrieve

        Returns:
            List[ModelOutput]: List of historical outputs (most recent first)
        """
        try:
            history = self.output_history.get(symbol, {}).get(model_name, deque())
            return list(history)[-count:][::-1]  # Most recent first
        except Exception as e:
            logger.error(f"Error getting output history for {model_name}: {e}")
            return []

    def get_cross_model_states(self, symbol: str, requesting_model: str) -> Dict[str, Dict[str, Any]]:
        """
        Get hidden states from other models for cross-model feeding

        Args:
            symbol: Trading symbol
            requesting_model: Name of the model requesting the states

        Returns:
            Dict[str, Dict[str, Any]]: Hidden states from other models
        """
        try:
            all_states = self.cross_model_states.get(symbol, {})
            # Return states from all models except the requesting one
            return {model_name: states for model_name, states in all_states.items()
                    if model_name != requesting_model}
        except Exception as e:
            logger.error(f"Error getting cross-model states for {requesting_model}: {e}")
            return {}

    def get_model_types_active(self, symbol: str) -> List[str]:
        """
        Get the list of active model types for a symbol

        Args:
            symbol: Trading symbol

        Returns:
            List[str]: List of active model types
        """
        try:
            current_outputs = self.current_outputs.get(symbol, {})
            return [output.model_type for output in current_outputs.values()]
        except Exception as e:
            logger.error(f"Error getting active model types for {symbol}: {e}")
            return []

    def get_consensus_prediction(self, symbol: str, confidence_threshold: float = 0.5) -> Optional[Dict[str, Any]]:
        """
        Get a consensus prediction from all active models

        Args:
            symbol: Trading symbol
            confidence_threshold: Minimum confidence threshold for inclusion

        Returns:
            Dict containing the consensus prediction, or None
        """
        try:
            current_outputs = self.current_outputs.get(symbol, {})
            if not current_outputs:
                return None

            # Filter by confidence threshold
            high_confidence_outputs = [
                output for output in current_outputs.values()
                if output.confidence >= confidence_threshold
            ]

            if not high_confidence_outputs:
                return None

            # Calculate consensus
            buy_votes = sum(1 for output in high_confidence_outputs
                            if output.predictions.get('action') == 'BUY')
            sell_votes = sum(1 for output in high_confidence_outputs
                             if output.predictions.get('action') == 'SELL')
            hold_votes = sum(1 for output in high_confidence_outputs
                             if output.predictions.get('action') == 'HOLD')

            total_votes = len(high_confidence_outputs)
            avg_confidence = sum(output.confidence for output in high_confidence_outputs) / total_votes

            # Determine the consensus action (ties fall back to HOLD)
            if buy_votes > sell_votes and buy_votes > hold_votes:
                consensus_action = 'BUY'
            elif sell_votes > buy_votes and sell_votes > hold_votes:
                consensus_action = 'SELL'
            else:
                consensus_action = 'HOLD'

            return {
                'action': consensus_action,
                'confidence': avg_confidence,
                'votes': {'BUY': buy_votes, 'SELL': sell_votes, 'HOLD': hold_votes},
                'total_models': total_votes,
                'model_types': [output.model_type for output in high_confidence_outputs]
            }

        except Exception as e:
            logger.error(f"Error calculating consensus prediction for {symbol}: {e}")
            return None

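    # Illustrative consensus result (added note with assumed values): with
    # three stored outputs above the threshold, the returned dict might be
    #   {'action': 'BUY', 'confidence': 0.72,
    #    'votes': {'BUY': 2, 'SELL': 0, 'HOLD': 1},
    #    'total_models': 3, 'model_types': ['cnn', 'rl', 'transformer']}
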
    def _update_model_metadata(self, model_name: str, model_type: str, metadata: Dict[str, Any]):
        """Update metadata for a model"""
        try:
            if model_name not in self.model_metadata:
                self.model_metadata[model_name] = {
                    'model_type': model_type,
                    'first_seen': datetime.now(),
                    'total_predictions': 0,
                    'custom_metadata': {}
                }

            self.model_metadata[model_name]['total_predictions'] += 1
            self.model_metadata[model_name]['last_seen'] = datetime.now()

            # Merge custom metadata
            if metadata:
                self.model_metadata[model_name]['custom_metadata'].update(metadata)

        except Exception as e:
            logger.error(f"Error updating model metadata: {e}")

    def _update_performance_stats(self, symbol: str, model_name: str, model_output: ModelOutput):
        """Update performance statistics for a model"""
        try:
            stats = self.performance_stats[symbol][model_name]

            if 'prediction_count' not in stats:
                stats['prediction_count'] = 0
                stats['confidence_sum'] = 0.0
                stats['action_counts'] = {'BUY': 0, 'SELL': 0, 'HOLD': 0}
                stats['first_prediction'] = model_output.timestamp

            stats['prediction_count'] += 1
            stats['confidence_sum'] += model_output.confidence
            stats['avg_confidence'] = stats['confidence_sum'] / stats['prediction_count']
            stats['last_prediction'] = model_output.timestamp

            action = model_output.predictions.get('action', 'HOLD')
            if action in stats['action_counts']:
                stats['action_counts'][action] += 1

        except Exception as e:
            logger.error(f"Error updating performance stats: {e}")

    def _persist_output_async(self, model_output: ModelOutput):
        """Persist model output to disk (simplified synchronous version)"""
        try:
            # Create a filename based on model, symbol, and timestamp
            timestamp_str = model_output.timestamp.strftime("%Y%m%d_%H%M%S")
            filename = f"{model_output.model_name}_{model_output.symbol.replace('/', '_')}_{timestamp_str}.json"
            filepath = self.cache_dir / filename

            # Convert to a JSON-serializable format
            output_dict = {
                'model_type': model_output.model_type,
                'model_name': model_output.model_name,
                'symbol': model_output.symbol,
                'timestamp': model_output.timestamp.isoformat(),
                'confidence': model_output.confidence,
                'predictions': model_output.predictions,
                'metadata': model_output.metadata
            }

            # Save to file (despite the name, this runs synchronously; a real
            # implementation would offload it to a background task)
            with open(filepath, 'w') as f:
                json.dump(output_dict, f, indent=2)

        except Exception as e:
            logger.error(f"Error persisting model output: {e}")

    def get_performance_summary(self, symbol: str) -> Dict[str, Any]:
        """
        Get a performance summary for all models for a symbol

        Args:
            symbol: Trading symbol

        Returns:
            Dict containing the performance summary
        """
        try:
            summary = {
                'symbol': symbol,
                'active_models': len(self.current_outputs.get(symbol, {})),
                'model_stats': {}
            }

            for model_name, stats in self.performance_stats.get(symbol, {}).items():
                summary['model_stats'][model_name] = {
                    'predictions': stats.get('prediction_count', 0),
                    'avg_confidence': round(stats.get('avg_confidence', 0.0), 3),
                    'action_distribution': stats.get('action_counts', {}),
                    'model_type': self.model_metadata.get(model_name, {}).get('model_type', 'unknown')
                }

            return summary

        except Exception as e:
            logger.error(f"Error getting performance summary: {e}")
            return {'symbol': symbol, 'error': str(e)}

    def cleanup_old_outputs(self, max_age_hours: int = 24):
        """
        Clean up old outputs to manage memory usage

        Args:
            max_age_hours: Maximum age of outputs to keep, in hours
        """
        try:
            cutoff_time = datetime.now() - timedelta(hours=max_age_hours)

            with self.storage_lock:
                for symbol in self.output_history:
                    for model_name in self.output_history[symbol]:
                        history = self.output_history[symbol][model_name]
                        # Remove old outputs from the front of the deque
                        while history and history[0].timestamp < cutoff_time:
                            history.popleft()

            logger.info(f"Cleaned up outputs older than {max_age_hours} hours")

        except Exception as e:
            logger.error(f"Error cleaning up old outputs: {e}")

    def add_custom_model_type(self, model_type: str):
        """
        Add support for a new custom model type

        Args:
            model_type: Name of the new model type
        """
        self.supported_model_types.add(model_type)
        logger.info(f"Added support for custom model type: {model_type}")

    def get_supported_model_types(self) -> List[str]:
        """Get the list of all supported model types"""
        return list(self.supported_model_types)
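
A short usage sketch for the manager (added example; the model names and confidences are assumptions):

from core.data_models import create_model_output
from core.model_output_manager import ModelOutputManager

manager = ModelOutputManager(cache_dir="cache/model_outputs")

# Store outputs from two hypothetical models
for name, mtype, action, conf in [('cnn_v1', 'cnn', 'BUY', 0.8),
                                  ('dqn_v1', 'rl', 'BUY', 0.6)]:
    manager.store_output(create_model_output(mtype, name, 'ETH/USDT', action, conf))

# Both outputs clear the 0.5 default threshold, so the consensus is BUY
consensus = manager.get_consensus_prediction('ETH/USDT')
print(consensus['action'], consensus['votes'])  # BUY {'BUY': 2, 'SELL': 0, 'HOLD': 0}

# Hidden states from everyone except the requesting model
states = manager.get_cross_model_states('ETH/USDT', requesting_model='dqn_v1')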

core/standardized_data_provider.py (new file, 453 lines)
@@ -0,0 +1,453 @@
"""
Standardized Data Provider Extension

This module extends the existing DataProvider with standardized BaseDataInput functionality
for all models in the multi-modal trading system.
"""

import logging
import numpy as np
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
from collections import deque
from threading import Lock

from .data_provider import DataProvider
from .data_models import BaseDataInput, OHLCVBar, COBData, ModelOutput, PivotPoint
from .multi_exchange_cob_provider import MultiExchangeCOBProvider
from .model_output_manager import ModelOutputManager

logger = logging.getLogger(__name__)


class StandardizedDataProvider(DataProvider):
    """
    Extended DataProvider with standardized BaseDataInput support

    Provides a unified data format for all models:
    - OHLCV: 300 frames of (1s, 1m, 1h, 1d) ETH + 300s of 1s BTC
    - COB: ±20 buckets of COB amounts in USD for each 1s OHLCV
    - MA: 1s, 5s, 15s, and 60s MA of COB imbalance counting ±5 COB buckets
    """

    def __init__(self, symbols: Optional[List[str]] = None, timeframes: Optional[List[str]] = None):
        """Initialize the standardized data provider"""
        super().__init__(symbols, timeframes)

        # Standardized data storage
        self.base_data_cache: Dict[str, Optional[BaseDataInput]] = {}  # {symbol: BaseDataInput}
        self.cob_data_cache: Dict[str, Optional[COBData]] = {}  # {symbol: COBData}

        # Model output management with extensible storage
        self.model_output_manager = ModelOutputManager(
            cache_dir=str(self.cache_dir / "model_outputs"),
            max_history=1000
        )

        # COB moving averages calculation
        self.cob_imbalance_history: Dict[str, deque] = {}  # {symbol: deque of (timestamp, imbalance_data)}
        self.ma_calculation_lock = Lock()

        # Initialize caches for each symbol
        for symbol in self.symbols:
            self.base_data_cache[symbol] = None
            self.cob_data_cache[symbol] = None
            self.cob_imbalance_history[symbol] = deque(maxlen=300)  # 5 minutes of 1s data

        # COB provider integration
        self.cob_provider: Optional[MultiExchangeCOBProvider] = None
        self._initialize_cob_provider()

        logger.info("StandardizedDataProvider initialized with BaseDataInput support")

    def _initialize_cob_provider(self):
        """Initialize the COB provider for order book data"""
        try:
            from .multi_exchange_cob_provider import MultiExchangeCOBProvider, ExchangeConfig, ExchangeType

            # Configure exchanges (focusing on Binance for now)
            exchange_configs = {
                'binance': ExchangeConfig(
                    exchange_type=ExchangeType.BINANCE,
                    weight=1.0,
                    enabled=True,
                    websocket_url="wss://stream.binance.com:9443/ws/",
                    symbols_mapping={symbol: symbol.replace('/', '').lower() for symbol in self.symbols}
                )
            }

            self.cob_provider = MultiExchangeCOBProvider(self.symbols, exchange_configs)
            logger.info("COB provider initialized successfully")

        except Exception as e:
            logger.warning(f"Failed to initialize COB provider: {e}")
            self.cob_provider = None

    def get_base_data_input(self, symbol: str, timestamp: Optional[datetime] = None) -> Optional[BaseDataInput]:
        """
        Get standardized BaseDataInput for a symbol

        Args:
            symbol: Trading symbol (e.g., 'ETH/USDT')
            timestamp: Optional timestamp, defaults to the current time

        Returns:
            BaseDataInput: Standardized input data for models, or None if there is insufficient data
        """
        if timestamp is None:
            timestamp = datetime.now()

        try:
            # Get OHLCV data for all timeframes
            ohlcv_1s = self._get_ohlcv_bars(symbol, '1s', 300)
            ohlcv_1m = self._get_ohlcv_bars(symbol, '1m', 300)
            ohlcv_1h = self._get_ohlcv_bars(symbol, '1h', 300)
            ohlcv_1d = self._get_ohlcv_bars(symbol, '1d', 300)

            # Get BTC reference data
            btc_symbol = 'BTC/USDT'
            btc_ohlcv_1s = self._get_ohlcv_bars(btc_symbol, '1s', 300)

            # Check that we have sufficient data
            if not all([ohlcv_1s, ohlcv_1m, ohlcv_1h, ohlcv_1d, btc_ohlcv_1s]):
                logger.warning(f"Insufficient OHLCV data for {symbol}")
                return None

            if any(len(data) < 100 for data in [ohlcv_1s, ohlcv_1m, ohlcv_1h, ohlcv_1d, btc_ohlcv_1s]):
                logger.warning(f"Insufficient data frames for {symbol}")
                return None

            # Get COB data
            cob_data = self._get_cob_data(symbol, timestamp)

            # Get technical indicators
            technical_indicators = self._get_technical_indicators(symbol)

            # Get pivot points
            pivot_points = self._get_pivot_points(symbol)

            # Get the last predictions from all models
            last_predictions = self.model_output_manager.get_all_current_outputs(symbol)

            # Create BaseDataInput
            base_input = BaseDataInput(
                symbol=symbol,
                timestamp=timestamp,
                ohlcv_1s=ohlcv_1s,
                ohlcv_1m=ohlcv_1m,
                ohlcv_1h=ohlcv_1h,
                ohlcv_1d=ohlcv_1d,
                btc_ohlcv_1s=btc_ohlcv_1s,
                cob_data=cob_data,
                technical_indicators=technical_indicators,
                pivot_points=pivot_points,
                last_predictions=last_predictions
            )

            # Validate the input
            if not base_input.validate():
                logger.warning(f"BaseDataInput validation failed for {symbol}")
                return None

            # Cache the result
            self.base_data_cache[symbol] = base_input

            return base_input

        except Exception as e:
            logger.error(f"Error creating BaseDataInput for {symbol}: {e}")
            return None

    def _get_ohlcv_bars(self, symbol: str, timeframe: str, count: int) -> List[OHLCVBar]:
        """
        Get OHLCV bars for a symbol and timeframe

        Args:
            symbol: Trading symbol
            timeframe: Timeframe ('1s', '1m', '1h', '1d')
            count: Number of bars to retrieve

        Returns:
            List[OHLCVBar]: List of OHLCV bars
        """
        try:
            # Get historical data from the parent class
            df = self.get_historical_data(symbol, timeframe, count)
            if df is None or df.empty:
                return []

            # Convert DataFrame rows to OHLCVBar objects
            bars = []
            for _, row in df.tail(count).iterrows():
                bar = OHLCVBar(
                    symbol=symbol,
                    timestamp=row.name if hasattr(row, 'name') else datetime.now(),
                    open=float(row['open']),
                    high=float(row['high']),
                    low=float(row['low']),
                    close=float(row['close']),
                    volume=float(row['volume']),
                    timeframe=timeframe,
                    indicators={}
                )

                # Add technical indicators if available
                for col in df.columns:
                    if col not in ['open', 'high', 'low', 'close', 'volume']:
                        bar.indicators[col] = float(row[col]) if not np.isnan(row[col]) else 0.0

                bars.append(bar)

            return bars

        except Exception as e:
            logger.error(f"Error getting OHLCV bars for {symbol} {timeframe}: {e}")
            return []

    def _get_cob_data(self, symbol: str, timestamp: datetime) -> Optional[COBData]:
        """
        Get COB data for a symbol

        Args:
            symbol: Trading symbol
            timestamp: Current timestamp

        Returns:
            COBData: COB data with price buckets and moving averages
        """
        try:
            if not self.cob_provider:
                return None

            # Get the current price
            current_price = self.current_prices.get(symbol.replace('/', '').upper(), 0.0)
            if current_price <= 0:
                return None

            # Determine bucket size based on symbol
            bucket_size = 1.0 if 'ETH' in symbol else 10.0  # $1 for ETH, $10 for BTC

            # Calculate the price range (±20 buckets)
            price_range = 20 * bucket_size
            min_price = current_price - price_range
            max_price = current_price + price_range

            # Create price buckets
            price_buckets = {}
            bid_ask_imbalance = {}
            volume_weighted_prices = {}

            # Generate mock COB data for now (to be replaced with real COB provider data)
            for i in range(-20, 21):
                price = current_price + (i * bucket_size)
                if price > 0:
                    # Mock data - replace with real COB provider data
                    bid_volume = max(0, 1000 - abs(i) * 50)  # More volume near the current price
                    ask_volume = max(0, 1000 - abs(i) * 50)
                    total_volume = bid_volume + ask_volume
                    imbalance = (bid_volume - ask_volume) / max(total_volume, 1)

                    price_buckets[price] = {
                        'bid_volume': bid_volume,
                        'ask_volume': ask_volume,
                        'total_volume': total_volume,
                        'imbalance': imbalance
                    }
                    bid_ask_imbalance[price] = imbalance
                    volume_weighted_prices[price] = price  # Simplified VWAP

            # Calculate moving averages of imbalance for ±5 buckets
            ma_data = self._calculate_cob_moving_averages(symbol, bid_ask_imbalance, timestamp)

            cob_data = COBData(
                symbol=symbol,
                timestamp=timestamp,
                current_price=current_price,
                bucket_size=bucket_size,
                price_buckets=price_buckets,
                bid_ask_imbalance=bid_ask_imbalance,
                volume_weighted_prices=volume_weighted_prices,
                order_flow_metrics={},
                ma_1s_imbalance=ma_data.get('1s', {}),
                ma_5s_imbalance=ma_data.get('5s', {}),
                ma_15s_imbalance=ma_data.get('15s', {}),
                ma_60s_imbalance=ma_data.get('60s', {})
            )

            # Cache the COB data
            self.cob_data_cache[symbol] = cob_data

            return cob_data

        except Exception as e:
            logger.error(f"Error getting COB data for {symbol}: {e}")
            return None

    def _calculate_cob_moving_averages(self, symbol: str, bid_ask_imbalance: Dict[float, float],
                                       timestamp: datetime) -> Dict[str, Dict[float, float]]:
        """
        Calculate moving averages of COB imbalance for ±5 buckets

        Args:
            symbol: Trading symbol
            bid_ask_imbalance: Current bid/ask imbalance data
            timestamp: Current timestamp

        Returns:
            Dict containing MA data for the different timeframes
        """
        try:
            with self.ma_calculation_lock:
                # Add the current imbalance data to history
                self.cob_imbalance_history[symbol].append((timestamp, bid_ask_imbalance))

                # Calculate MAs for the different timeframes
                ma_results = {'1s': {}, '5s': {}, '15s': {}, '60s': {}}

                # Get the current price for the ±5 bucket calculation
                current_price = self.current_prices.get(symbol.replace('/', '').upper(), 0.0)
                if current_price <= 0:
                    return ma_results

                bucket_size = 1.0 if 'ETH' in symbol else 10.0

                # Calculate MAs for ±5 buckets around the current price
                for i in range(-5, 6):
                    price = current_price + (i * bucket_size)
                    if price <= 0:
                        continue

                    # Get historical imbalance data for this price bucket
                    history = self.cob_imbalance_history[symbol]

                    # Calculate the different MA periods
                    for period, period_name in [(1, '1s'), (5, '5s'), (15, '15s'), (60, '60s')]:
                        recent_data = []
                        cutoff_time = timestamp - timedelta(seconds=period)

                        for hist_timestamp, hist_imbalance in history:
                            if hist_timestamp >= cutoff_time and price in hist_imbalance:
                                recent_data.append(hist_imbalance[price])

                        # Calculate the moving average
                        if recent_data:
                            ma_results[period_name][price] = sum(recent_data) / len(recent_data)
                        else:
                            ma_results[period_name][price] = 0.0

                return ma_results

        except Exception as e:
            logger.error(f"Error calculating COB moving averages for {symbol}: {e}")
            return {'1s': {}, '5s': {}, '15s': {}, '60s': {}}

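    # Window sketch (added note, assuming one imbalance snapshot is appended
    # per second): the '5s' MA averages the <= 5 snapshots whose timestamps
    # fall within the last 5 seconds, per price bucket, and the '60s' MA uses
    # up to 60 snapshots. Buckets with no history in the window report 0.0.
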
    def _get_technical_indicators(self, symbol: str) -> Dict[str, float]:
        """Get technical indicators for a symbol"""
        try:
            # Get the latest OHLCV data
            df = self.get_historical_data(symbol, '1h', 100)  # Use 1h bars for indicators
            if df is None or df.empty:
                return {}

            indicators = {}

            # Add basic indicators if available in the dataframe
            latest_row = df.iloc[-1]
            for col in df.columns:
                if col not in ['open', 'high', 'low', 'close', 'volume']:
                    indicators[col] = float(latest_row[col]) if not np.isnan(latest_row[col]) else 0.0

            return indicators

        except Exception as e:
            logger.error(f"Error getting technical indicators for {symbol}: {e}")
            return {}

    def _get_pivot_points(self, symbol: str) -> List[PivotPoint]:
        """Get pivot points for a symbol"""
        try:
            pivot_points = []

            # Get pivot points from the Williams Market Structure if available
            if symbol in self.williams_structure:
                williams = self.williams_structure[symbol]
                # This would need to be implemented based on the actual Williams
                # structure; for now, return an empty list
                pass

            return pivot_points

        except Exception as e:
            logger.error(f"Error getting pivot points for {symbol}: {e}")
            return []

    def store_model_output(self, model_output: ModelOutput):
        """
        Store model output for cross-model feeding using the ModelOutputManager

        Args:
            model_output: ModelOutput from any model
        """
        try:
            success = self.model_output_manager.store_output(model_output)
            if success:
                logger.debug(f"Stored model output from {model_output.model_name} for {model_output.symbol}")
            else:
                logger.warning(f"Failed to store model output from {model_output.model_name}")

        except Exception as e:
            logger.error(f"Error storing model output: {e}")

    def get_model_outputs(self, symbol: str) -> Dict[str, ModelOutput]:
        """
        Get all model outputs for a symbol using the ModelOutputManager

        Args:
            symbol: Trading symbol

        Returns:
            Dict[str, ModelOutput]: Dictionary of model outputs by model name
        """
        return self.model_output_manager.get_all_current_outputs(symbol)

    def get_model_output_manager(self) -> ModelOutputManager:
        """
        Get the model output manager for advanced operations

        Returns:
            ModelOutputManager: The model output manager instance
        """
        return self.model_output_manager

    def start_real_time_processing(self):
        """Start real-time processing for standardized data"""
        try:
            # Start the parent class real-time processing
            if hasattr(super(), 'start_real_time_processing'):
                super().start_real_time_processing()

            # Start the COB provider if available
            # (asyncio.create_task requires an already-running event loop)
            if self.cob_provider:
                import asyncio
                asyncio.create_task(self.cob_provider.start_streaming())

            logger.info("Started real-time processing for standardized data")

        except Exception as e:
            logger.error(f"Error starting real-time processing: {e}")

    def stop_real_time_processing(self):
        """Stop real-time processing"""
        try:
            # Stop the COB provider if available
            # (asyncio.create_task requires an already-running event loop)
            if self.cob_provider:
                import asyncio
                asyncio.create_task(self.cob_provider.stop_streaming())

            # Stop the parent class processing
            if hasattr(super(), 'stop_real_time_processing'):
                super().stop_real_time_processing()

            logger.info("Stopped real-time processing for standardized data")

        except Exception as e:
            logger.error(f"Error stopping real-time processing: {e}")
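
End-to-end usage sketch (added example; it assumes the parent DataProvider can be constructed this way, that enough OHLCV history has accumulated, and a hypothetical model name 'cnn_v1'):

from core.standardized_data_provider import StandardizedDataProvider
from core.data_models import create_model_output

provider = StandardizedDataProvider(symbols=['ETH/USDT', 'BTC/USDT'],
                                    timeframes=['1s', '1m', '1h', '1d'])

# Build the unified input; returns None until >= 100 frames exist per timeframe
base_input = provider.get_base_data_input('ETH/USDT')
if base_input is not None:
    features = base_input.get_feature_vector()

    # ... run a model on `features`, then feed its output back ...
    output = create_model_output('cnn', 'cnn_v1', 'ETH/USDT', 'HOLD', 0.55)
    provider.store_model_output(output)

    # The stored output appears in last_predictions on the next call
    assert 'cnn_v1' in provider.get_model_outputs('ETH/USDT')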