Merge commit 'd49a473ed6f4aef55bfdd47d6370e53582be6b7b' into cleanup
core/unified_model_data_interface.py | 413 insertions(+), new file
@@ -0,0 +1,413 @@
"""
Unified Model Data Interface

CRITICAL POLICY: NO SYNTHETIC DATA ALLOWED
This module provides a standardized data interface for all models.
NEVER use mock/fake/synthetic data or placeholder values.
If data is unavailable: return None, log errors, raise exceptions.

This interface ensures:
- Consistent data format across all models
- Proper feature engineering and normalization
- Real-time data streaming to models
- No data dumps or unnecessary file I/O
"""

import logging
import numpy as np
import pandas as pd
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Any, Union
from dataclasses import dataclass

logger = logging.getLogger(__name__)

@dataclass
class ModelInputData:
    """Standardized input data structure for all models"""
    symbol: str
    timestamp: datetime
    current_price: float

    # Time series data (preserves structure)
    candles_1m: Optional[np.ndarray] = None  # Shape: (window_size, 5) [OHLCV]
    candles_1s: Optional[np.ndarray] = None  # Shape: (window_size, 5) [OHLCV]
    candles_5m: Optional[np.ndarray] = None  # Shape: (window_size, 5) [OHLCV]

    # Technical indicators (flattened for models that need it)
    technical_indicators: Optional[np.ndarray] = None  # Shape: (n_indicators,)

    # Market microstructure (for COB models)
    order_book_features: Optional[np.ndarray] = None  # Shape: (n_features,)

    # Market context
    volume_profile: Optional[np.ndarray] = None  # Shape: (n_levels,)
    volatility_regime: float = 0.0
    trend_strength: float = 0.0

    # Metadata
    data_quality_score: float = 1.0
    feature_count: int = 0
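
# Field-contract sketch: how a caller might populate ModelInputData by hand,
# e.g. when adapting an external feed. The values below are placeholders
# purely to illustrate shapes and types; per the policy above, real usage
# must pass genuine market data or None.
#
#   sample = ModelInputData(
#       symbol='ETH/USDT',
#       timestamp=datetime.now(),
#       current_price=3000.0,          # latest real trade price
#       candles_1m=ohlcv_window,       # real (window_size, 5) OHLCV array
#   )
#   sample.candles_1m.shape    # -> (60, 5) for the CNN default window
#   sample.data_quality_score  # -> 1.0 (default until recalculated)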


class UnifiedModelDataInterface:
    """
    Unified interface for preparing data for all model types

    Features:
    - Standardized data format across models
    - Preserves time series structure for CNN/Transformer
    - Flattened features for DQN/Generic models
    - Real-time data streaming
    - No unnecessary file I/O or dumps
    """

    def __init__(self, data_provider, config):
        self.data_provider = data_provider
        self.config = config

        # Model-specific requirements
        self.model_requirements = {
            'cnn': {
                'input_shape': (60, 5),  # 60 candles, 5 features (OHLCV)
                'requires_sequence': True,
                'normalization': 'pivot_based'
            },
            'dqn': {
                'input_shape': (100,),  # 100-dim state vector
                'requires_sequence': False,
                'normalization': 'min_max'
            },
            'cob_rl': {
                'input_shape': (50,),  # 50-dim order book features
                'requires_sequence': False,
                'normalization': 'z_score'
            },
            'transformer': {
                'input_shape': (150, 5),  # 150 candles, 5 features
                'requires_sequence': True,
                'normalization': 'pivot_based'
            },
            'generic': {
                'input_shape': (200,),  # 200-dim feature vector
                'requires_sequence': False,
                'normalization': 'min_max'
            }
        }

        logger.info("Unified Model Data Interface initialized")
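
    # How these requirements are consumed: for 'transformer', input_shape
    # (150, 5) makes prepare_model_input default window_size to 150 and route
    # through _add_sequence_data; 'dqn' with input_shape (100,) defaults to
    # 100 and receives a flattened vector instead. In both cases
    # input_shape[0] doubles as the default window size.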

    def prepare_model_input(self, symbol: str, model_type: str,
                            window_size: Optional[int] = None) -> Optional[ModelInputData]:
        """
        Prepare standardized input data for any model type

        Args:
            symbol: Trading symbol (e.g., 'ETH/USDT')
            model_type: Type of model ('cnn', 'dqn', 'cob_rl', 'transformer', 'generic')
            window_size: Number of candles to include (model-specific default if None)

        Returns:
            ModelInputData: Standardized input data structure, or None if
            real data is unavailable
        """
        try:
            if model_type not in self.model_requirements:
                logger.error(f"Unknown model type: {model_type}")
                return None

            requirements = self.model_requirements[model_type]
            if window_size is None:
                window_size = requirements['input_shape'][0]

            # Get current market data
            current_price = self.data_provider.get_current_price(symbol)
            if current_price is None:
                logger.warning(f"No current price available for {symbol}")
                return None

            # Prepare base data structure
            model_input = ModelInputData(
                symbol=symbol,
                timestamp=datetime.now(),
                current_price=current_price
            )

            # Get time series data based on model requirements
            if requirements['requires_sequence']:
                model_input = self._add_sequence_data(model_input, symbol, window_size)
            else:
                model_input = self._add_feature_vector(model_input, symbol, window_size)

            # Add model-specific features
            if model_type == 'cob_rl':
                model_input = self._add_order_book_features(model_input, symbol)

            # Apply normalization
            model_input = self._apply_normalization(model_input, model_type, requirements)

            # Calculate data quality score
            model_input.data_quality_score = self._calculate_data_quality(model_input)
            model_input.feature_count = self._count_features(model_input)

            logger.debug(f"Prepared {model_type} input for {symbol}: "
                         f"{model_input.feature_count} features, "
                         f"quality: {model_input.data_quality_score:.2f}")
            return model_input

        except Exception as e:
            logger.error(f"Error preparing {model_type} input for {symbol}: {e}")
            return None
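
    # Usage sketch (hypothetical caller code; the 0.5 quality threshold is an
    # illustrative choice, not something this module defines):
    #
    #   interface = UnifiedModelDataInterface(data_provider, config)
    #   prepared = interface.prepare_model_input('ETH/USDT', 'cnn')
    #   if prepared is not None and prepared.data_quality_score >= 0.5:
    #       features = interface.get_model_specific_input(prepared, 'cnn')
    #       # feed `features` (pivot-normalized, shape (60, 5)) to the CNN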

    def _add_sequence_data(self, model_input: ModelInputData, symbol: str, window_size: int) -> ModelInputData:
        """Add time series sequence data for CNN/Transformer models"""
        try:
            # Get 1m candles
            candles_1m = self.data_provider.get_historical_data_for_inference(symbol, '1m', limit=window_size + 50)
            if candles_1m is not None and not candles_1m.empty:
                # Extract OHLCV sequence
                ohlcv_data = candles_1m[['open', 'high', 'low', 'close', 'volume']].values
                model_input.candles_1m = ohlcv_data[-window_size:]  # Last window_size candles

            # Get 1s candles for high-frequency models
            candles_1s = self.data_provider.get_historical_data_for_inference(symbol, '1s', limit=window_size + 50)
            if candles_1s is not None and not candles_1s.empty:
                ohlcv_data = candles_1s[['open', 'high', 'low', 'close', 'volume']].values
                model_input.candles_1s = ohlcv_data[-window_size:]

            # Get 5m candles for longer-term context
            candles_5m = self.data_provider.get_historical_data_for_inference(symbol, '5m', limit=window_size + 50)
            if candles_5m is not None and not candles_5m.empty:
                ohlcv_data = candles_5m[['open', 'high', 'low', 'close', 'volume']].values
                model_input.candles_5m = ohlcv_data[-window_size:]

            return model_input

        except Exception as e:
            logger.error(f"Error adding sequence data for {symbol}: {e}")
            return model_input

    def _add_feature_vector(self, model_input: ModelInputData, symbol: str, window_size: int) -> ModelInputData:
        """Add flattened feature vector for DQN/Generic models"""
        try:
            # Get feature matrix
            feature_matrix = self.data_provider.get_feature_matrix(symbol, window_size=window_size)
            if feature_matrix is not None:
                # Flatten and limit to expected size
                flattened = feature_matrix.flatten()
                target_size = 200  # Default for generic models

                if len(flattened) > target_size:
                    flattened = flattened[:target_size]
                elif len(flattened) < target_size:
                    # Pad with zeros
                    padded = np.zeros(target_size)
                    padded[:len(flattened)] = flattened
                    flattened = padded

                model_input.technical_indicators = flattened

            return model_input

        except Exception as e:
            logger.error(f"Error adding feature vector for {symbol}: {e}")
            return model_input
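
    # Worked example of the truncate/pad rule above: a 60x5 feature matrix
    # flattens to 300 values and is cut to the first 200; a 30x5 matrix
    # flattens to 150 and is zero-padded to 200. Note that target_size is
    # fixed at 200 regardless of model_type, so a 'dqn' model declared with
    # input_shape (100,) still receives 200 values, and if the matrix rows
    # are ordered oldest-first, truncation keeps the earliest candles.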

    def _add_order_book_features(self, model_input: ModelInputData, symbol: str) -> ModelInputData:
        """Add order book features for COB models"""
        try:
            # Get COB state from data provider
            cob_state = self.data_provider.get_cob_state(symbol)
            if cob_state is not None:
                model_input.order_book_features = cob_state

            return model_input

        except Exception as e:
            logger.error(f"Error adding order book features for {symbol}: {e}")
            return model_input

    def _apply_normalization(self, model_input: ModelInputData, model_type: str, requirements: Dict) -> ModelInputData:
        """Apply model-specific normalization"""
        try:
            norm_type = requirements['normalization']

            if norm_type == 'pivot_based':
                # Use pivot-based normalization for price data
                if model_input.candles_1m is not None:
                    model_input.candles_1m = self._normalize_with_pivot_bounds(model_input.candles_1m, model_input.symbol)
                if model_input.candles_1s is not None:
                    model_input.candles_1s = self._normalize_with_pivot_bounds(model_input.candles_1s, model_input.symbol)
                if model_input.candles_5m is not None:
                    model_input.candles_5m = self._normalize_with_pivot_bounds(model_input.candles_5m, model_input.symbol)

            elif norm_type == 'min_max':
                # Min-max normalization for feature vectors
                if model_input.technical_indicators is not None:
                    model_input.technical_indicators = self._min_max_normalize(model_input.technical_indicators)

            elif norm_type == 'z_score':
                # Z-score normalization for order book features
                if model_input.order_book_features is not None:
                    model_input.order_book_features = self._z_score_normalize(model_input.order_book_features)

            return model_input

        except Exception as e:
            logger.error(f"Error applying normalization for {model_type}: {e}")
            return model_input

    def _normalize_with_pivot_bounds(self, candles: np.ndarray, symbol: str) -> np.ndarray:
        """Normalize candles using pivot bounds"""
        try:
            if symbol not in self.data_provider.pivot_bounds:
                # Fallback to simple normalization
                return self._min_max_normalize(candles)

            bounds = self.data_provider.pivot_bounds[symbol]
            price_range = bounds.get_price_range()
            volume_range = bounds.volume_max - bounds.volume_min

            normalized = candles.copy()

            # Normalize price columns (0-3: OHLC); guard against a degenerate
            # zero-width price range, mirroring the volume handling below
            if price_range > 0:
                for i in range(4):
                    normalized[:, i] = (candles[:, i] - bounds.price_min) / price_range
            else:
                normalized[:, :4] = 0.5  # Default to middle if no price range

            # Normalize volume column (4)
            if volume_range > 0:
                normalized[:, 4] = (candles[:, 4] - bounds.volume_min) / volume_range
            else:
                normalized[:, 4] = 0.5  # Default to middle if no volume range

            return normalized

        except Exception as e:
            logger.error(f"Error normalizing with pivot bounds: {e}")
            return self._min_max_normalize(candles)
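
    # Worked example (bounds values are illustrative): with price_min = 2000
    # and get_price_range() = 1000, a close of 2750 maps to
    # (2750 - 2000) / 1000 = 0.75. Prices inside the pivot bounds land in
    # [0, 1], and because the same symbol-level bounds are reused across
    # timeframes, the 1s/1m/5m windows share a common scale.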

    def _min_max_normalize(self, data: np.ndarray) -> np.ndarray:
        """Min-max normalization"""
        try:
            data_min = np.min(data)
            data_max = np.max(data)

            if data_max - data_min == 0:
                return np.zeros_like(data)

            return (data - data_min) / (data_max - data_min)

        except Exception as e:
            logger.error(f"Error in min-max normalization: {e}")
            return data

    def _z_score_normalize(self, data: np.ndarray) -> np.ndarray:
        """Z-score normalization"""
        try:
            mean = np.mean(data)
            std = np.std(data)

            if std == 0:
                return np.zeros_like(data)

            return (data - mean) / std

        except Exception as e:
            logger.error(f"Error in z-score normalization: {e}")
            return data
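
    # Worked example for the two scalers above, with data = [2.0, 4.0, 6.0]:
    #   min-max: min=2, max=6 -> [(x-2)/4] = [0.0, 0.5, 1.0]
    #   z-score: mean=4, std=sqrt(8/3) ~= 1.633 -> ~[-1.22, 0.0, 1.22]
    # Both collapse to all-zeros for a constant input, matching the
    # zero-range and std == 0 guards.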

    def _calculate_data_quality(self, model_input: ModelInputData) -> float:
        """Calculate data quality score (0.0 to 1.0)"""
        try:
            score = 1.0

            # Check for missing data
            if model_input.candles_1m is None:
                score -= 0.3
            if model_input.technical_indicators is None:
                score -= 0.2
            if model_input.order_book_features is None:
                score -= 0.1

            # Check for NaN values
            if model_input.candles_1m is not None and np.isnan(model_input.candles_1m).any():
                score -= 0.2
            if model_input.technical_indicators is not None and np.isnan(model_input.technical_indicators).any():
                score -= 0.2

            # Check for zero variance (indicating stale data)
            if model_input.candles_1m is not None:
                price_variance = np.var(model_input.candles_1m[:, 3])  # Close price variance
                if price_variance < 1e-8:
                    score -= 0.3

            return max(0.0, score)

        except Exception as e:
            logger.error(f"Error calculating data quality: {e}")
            return 0.0
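
    # Scoring example: a DQN input carrying only technical_indicators starts
    # at 1.0, loses 0.3 for missing candles_1m and 0.1 for missing
    # order_book_features, and scores 0.6 if the vector is clean. Note that
    # prepare_model_input normalizes before scoring, so the 1e-8 variance
    # threshold is applied in normalized units.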

    def _count_features(self, model_input: ModelInputData) -> int:
        """Count total number of features"""
        try:
            count = 0

            if model_input.candles_1m is not None:
                count += model_input.candles_1m.size
            if model_input.candles_1s is not None:
                count += model_input.candles_1s.size
            if model_input.candles_5m is not None:
                count += model_input.candles_5m.size
            if model_input.technical_indicators is not None:
                count += model_input.technical_indicators.size
            if model_input.order_book_features is not None:
                count += model_input.order_book_features.size
            if model_input.volume_profile is not None:
                count += model_input.volume_profile.size

            return count

        except Exception as e:
            logger.error(f"Error counting features: {e}")
            return 0

    def get_model_specific_input(self, model_input: ModelInputData, model_type: str) -> Optional[np.ndarray]:
        """
        Extract model-specific input from standardized ModelInputData

        Args:
            model_input: Standardized input data
            model_type: Type of model

        Returns:
            np.ndarray: Model-specific input data, or None if unavailable
        """
        try:
            if model_type == 'cnn':
                if model_input.candles_1m is not None:
                    return model_input.candles_1m
                return None

            elif model_type == 'dqn':
                if model_input.technical_indicators is not None:
                    return model_input.technical_indicators
                return None

            elif model_type == 'cob_rl':
                if model_input.order_book_features is not None:
                    return model_input.order_book_features
                return None

            elif model_type == 'transformer':
                if model_input.candles_1m is not None:
                    return model_input.candles_1m
                return None

            elif model_type == 'generic':
                if model_input.technical_indicators is not None:
                    return model_input.technical_indicators
                return None

            else:
                logger.error(f"Unknown model type: {model_type}")
                return None

        except Exception as e:
            logger.error(f"Error getting model-specific input for {model_type}: {e}")
            return None
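

# Minimal smoke-test sketch using a stub provider (illustrative only): the
# stub returns None everywhere, so prepare_model_input exercises only its
# unavailable-data paths, consistent with the no-synthetic-data policy above.
# Real usage would pass the application's actual data provider and config.
if __name__ == '__main__':
    import types

    logging.basicConfig(level=logging.DEBUG)

    stub_provider = types.SimpleNamespace(
        get_current_price=lambda symbol: None,  # no live feed in this sketch
        get_historical_data_for_inference=lambda *args, **kwargs: None,
        get_feature_matrix=lambda symbol, window_size: None,
        get_cob_state=lambda symbol: None,
        pivot_bounds={},
    )
    interface = UnifiedModelDataInterface(stub_provider, config=None)

    # With no real data available the interface must return None,
    # never fabricate values
    assert interface.prepare_model_input('ETH/USDT', 'cnn') is None
    print('unavailable-data path verified: interface returned None')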