training wip
This commit is contained in:
@@ -1802,604 +1802,177 @@ class DataProvider:
|
||||
logger.debug(f"Applied pivot-based normalization for {symbol}")
|
||||
|
||||
else:
|
||||
# Fallback to traditional normalization when pivot bounds not available
|
||||
logger.debug("Using traditional normalization (no pivot bounds available)")
|
||||
|
||||
for col in df_norm.columns:
|
||||
if col in ['open', 'high', 'low', 'close', 'sma_10', 'sma_20', 'sma_50',
|
||||
'ema_12', 'ema_26', 'ema_50', 'bb_upper', 'bb_lower', 'bb_middle',
|
||||
'keltner_upper', 'keltner_lower', 'keltner_middle', 'psar', 'vwap']:
|
||||
# Price-based indicators: normalize by close price
|
||||
if 'close' in df_norm.columns:
|
||||
base_price = df_norm['close'].iloc[-1] # Use latest close as reference
|
||||
if base_price > 0:
|
||||
df_norm[col] = df_norm[col] / base_price
|
||||
|
||||
elif col == 'volume':
|
||||
# Volume: normalize by its own rolling mean
|
||||
volume_mean = df_norm[col].rolling(window=min(20, len(df_norm))).mean().iloc[-1]
|
||||
if volume_mean > 0:
|
||||
df_norm[col] = df_norm[col] / volume_mean
|
||||
|
||||
# Normalize indicators that have standard ranges (regardless of pivot bounds)
|
||||
for col in df_norm.columns:
|
||||
if col in ['rsi_14', 'rsi_7', 'rsi_21']:
|
||||
# RSI: already 0-100, normalize to 0-1
|
||||
df_norm[col] = df_norm[col] / 100.0
|
||||
|
||||
elif col in ['stoch_k', 'stoch_d']:
|
||||
# Stochastic: already 0-100, normalize to 0-1
|
||||
df_norm[col] = df_norm[col] / 100.0
|
||||
|
||||
elif col == 'williams_r':
|
||||
# Williams %R: -100 to 0, normalize to 0-1
|
||||
df_norm[col] = (df_norm[col] + 100) / 100.0
|
||||
|
||||
elif col in ['macd', 'macd_signal', 'macd_histogram']:
|
||||
# MACD: normalize by ATR or close price
|
||||
if 'atr' in df_norm.columns and df_norm['atr'].iloc[-1] > 0:
|
||||
df_norm[col] = df_norm[col] / df_norm['atr'].iloc[-1]
|
||||
elif 'close' in df_norm.columns and df_norm['close'].iloc[-1] > 0:
|
||||
df_norm[col] = df_norm[col] / df_norm['close'].iloc[-1]
|
||||
|
||||
elif col in ['bb_width', 'bb_percent', 'price_position', 'trend_strength',
|
||||
'momentum_composite', 'volatility_regime', 'pivot_price_position',
|
||||
'pivot_support_distance', 'pivot_resistance_distance']:
|
||||
# Already normalized indicators: ensure 0-1 range
|
||||
df_norm[col] = np.clip(df_norm[col], 0, 1)
|
||||
|
||||
elif col in ['atr', 'true_range']:
|
||||
# Volatility indicators: normalize by close price or pivot range
|
||||
if symbol and symbol in self.pivot_bounds:
|
||||
bounds = self.pivot_bounds[symbol]
|
||||
df_norm[col] = df_norm[col] / bounds.get_price_range()
|
||||
elif 'close' in df_norm.columns and df_norm['close'].iloc[-1] > 0:
|
||||
df_norm[col] = df_norm[col] / df_norm['close'].iloc[-1]
|
||||
|
||||
elif col not in ['timestamp', 'near_pivot_support', 'near_pivot_resistance']:
|
||||
# Other indicators: z-score normalization
|
||||
col_mean = df_norm[col].rolling(window=min(20, len(df_norm))).mean().iloc[-1]
|
||||
col_std = df_norm[col].rolling(window=min(20, len(df_norm))).std().iloc[-1]
|
||||
if col_std > 0:
|
||||
df_norm[col] = (df_norm[col] - col_mean) / col_std
|
||||
else:
|
||||
df_norm[col] = 0
|
||||
|
||||
# Replace inf/-inf with 0
|
||||
df_norm = df_norm.replace([np.inf, -np.inf], 0)
|
||||
# Use symbol-grouped normalization with consistent ranges
|
||||
df_norm = self._apply_symbol_grouped_normalization(df_norm, symbol)
|
||||
|
||||
# Fill any remaining NaN values
|
||||
df_norm = df_norm.fillna(0.0)
|
||||
|
||||
return df_norm
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error normalizing features for {symbol}: {e}")
|
||||
return df.fillna(0.0) if df is not None else None
|
||||
|
||||
def _apply_symbol_grouped_normalization(self, df: pd.DataFrame, symbol: str) -> pd.DataFrame:
|
||||
"""Apply symbol-grouped normalization with consistent ranges across timeframes"""
|
||||
try:
|
||||
df_norm = df.copy()
|
||||
|
||||
# Get symbol-specific price ranges for consistent normalization
|
||||
symbol_price_ranges = {
|
||||
'ETH/USDT': {'min': 1000, 'max': 5000}, # ETH price range
|
||||
'BTC/USDT': {'min': 90000, 'max': 120000} # BTC price range
|
||||
}
|
||||
|
||||
if symbol in symbol_price_ranges:
|
||||
price_range = symbol_price_ranges[symbol]
|
||||
range_size = price_range['max'] - price_range['min']
|
||||
|
||||
# Normalize price columns to [0, 1] range specific to symbol
|
||||
price_cols = ['open', 'high', 'low', 'close']
|
||||
for col in price_cols:
|
||||
if col in df_norm.columns:
|
||||
df_norm[col] = (df_norm[col] - price_range['min']) / range_size
|
||||
df_norm[col] = np.clip(df_norm[col], 0, 1) # Ensure [0,1] range
|
||||
|
||||
# Normalize volume to [0, 1] using log scale
|
||||
if 'volume' in df_norm.columns:
|
||||
df_norm['volume'] = np.log1p(df_norm['volume'])
|
||||
vol_max = df_norm['volume'].max()
|
||||
if vol_max > 0:
|
||||
df_norm['volume'] = df_norm['volume'] / vol_max
|
||||
|
||||
logger.debug(f"Applied symbol-grouped normalization for {symbol}")
|
||||
|
||||
# Fill any NaN values
|
||||
df_norm = df_norm.fillna(0)
|
||||
|
||||
return df_norm
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error normalizing features: {e}")
|
||||
logger.error(f"Error in symbol-grouped normalization for {symbol}: {e}")
|
||||
return df
|
||||
|
||||
def get_multi_symbol_feature_matrix(self, symbols: List[str] = None,
|
||||
timeframes: List[str] = None,
|
||||
window_size: int = 20) -> Optional[np.ndarray]:
|
||||
"""
|
||||
Get feature matrix for multiple symbols and timeframes
|
||||
|
||||
Returns:
|
||||
np.ndarray: Shape (n_symbols, n_timeframes, window_size, n_features)
|
||||
"""
|
||||
|
||||
def get_historical_data_for_inference(self, symbol: str, timeframe: str, limit: int = 300) -> Optional[pd.DataFrame]:
|
||||
"""Get normalized historical data specifically for model inference"""
|
||||
try:
|
||||
if symbols is None:
|
||||
symbols = self.symbols
|
||||
if timeframes is None:
|
||||
timeframes = self.timeframes
|
||||
# Get raw historical data
|
||||
raw_df = self.get_historical_data(symbol, timeframe, limit)
|
||||
|
||||
symbol_matrices = []
|
||||
if raw_df is None or raw_df.empty:
|
||||
return None
|
||||
|
||||
for symbol in symbols:
|
||||
symbol_matrix = self.get_feature_matrix(symbol, timeframes, window_size)
|
||||
if symbol_matrix is not None:
|
||||
symbol_matrices.append(symbol_matrix)
|
||||
# Apply normalization for inference
|
||||
normalized_df = self._normalize_features(raw_df, symbol)
|
||||
|
||||
logger.debug(f"Retrieved normalized historical data for inference: {symbol} {timeframe} ({len(normalized_df)} records)")
|
||||
return normalized_df
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting normalized historical data for inference: {symbol} {timeframe}: {e}")
|
||||
return None
|
||||
|
||||
def get_multi_symbol_features_for_inference(self, symbols_timeframes: List[Tuple[str, str]], limit: int = 300) -> Dict[str, Dict[str, pd.DataFrame]]:
|
||||
"""Get normalized multi-symbol feature matrices for model inference"""
|
||||
try:
|
||||
logger.info("Preparing normalized multi-symbol features for model inference...")
|
||||
|
||||
symbol_features = {}
|
||||
|
||||
for symbol, timeframe in symbols_timeframes:
|
||||
if symbol not in symbol_features:
|
||||
symbol_features[symbol] = {}
|
||||
|
||||
# Get normalized data for inference
|
||||
normalized_df = self.get_historical_data_for_inference(symbol, timeframe, limit)
|
||||
|
||||
if normalized_df is not None and not normalized_df.empty:
|
||||
symbol_features[symbol][timeframe] = normalized_df
|
||||
logger.debug(f"Prepared normalized features for {symbol} {timeframe}: {len(normalized_df)} records")
|
||||
else:
|
||||
logger.warning(f"Could not create feature matrix for {symbol}")
|
||||
logger.warning(f"No normalized data available for {symbol} {timeframe}")
|
||||
symbol_features[symbol][timeframe] = None
|
||||
|
||||
if symbol_matrices:
|
||||
# Stack all symbol matrices
|
||||
multi_symbol_matrix = np.stack(symbol_matrices, axis=0)
|
||||
logger.info(f"Created multi-symbol feature matrix: {multi_symbol_matrix.shape}")
|
||||
return multi_symbol_matrix
|
||||
|
||||
return None
|
||||
return symbol_features
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating multi-symbol feature matrix: {e}")
|
||||
return None
|
||||
|
||||
def health_check(self) -> Dict[str, Any]:
|
||||
"""Get health status of the data provider"""
|
||||
status = {
|
||||
'streaming': self.is_streaming,
|
||||
'symbols': len(self.symbols),
|
||||
'timeframes': len(self.timeframes),
|
||||
'current_prices': len(self.current_prices),
|
||||
'websocket_tasks': len(self.websocket_tasks),
|
||||
'historical_data_loaded': {}
|
||||
}
|
||||
|
||||
# Check historical data availability
|
||||
for symbol in self.symbols:
|
||||
status['historical_data_loaded'][symbol] = {}
|
||||
for tf in self.timeframes:
|
||||
has_data = (symbol in self.historical_data and
|
||||
tf in self.historical_data[symbol] and
|
||||
not self.historical_data[symbol][tf].empty)
|
||||
status['historical_data_loaded'][symbol][tf] = has_data
|
||||
|
||||
return status
|
||||
|
||||
def subscribe_to_ticks(self, callback: Callable[[MarketTick], None],
|
||||
symbols: List[str] = None,
|
||||
subscriber_name: str = None) -> str:
|
||||
"""Subscribe to real-time tick data updates"""
|
||||
subscriber_id = str(uuid.uuid4())[:8]
|
||||
subscriber_name = subscriber_name or f"subscriber_{subscriber_id}"
|
||||
|
||||
# Convert symbols to Binance format
|
||||
if symbols:
|
||||
binance_symbols = [s.replace('/', '').upper() for s in symbols]
|
||||
else:
|
||||
binance_symbols = [s.replace('/', '').upper() for s in self.symbols]
|
||||
|
||||
subscriber = DataSubscriber(
|
||||
subscriber_id=subscriber_id,
|
||||
callback=callback,
|
||||
symbols=binance_symbols,
|
||||
subscriber_name=subscriber_name
|
||||
)
|
||||
|
||||
with self.subscriber_lock:
|
||||
self.subscribers[subscriber_id] = subscriber
|
||||
|
||||
logger.info(f"New tick subscriber registered: {subscriber_name} ({subscriber_id}) for symbols: {binance_symbols}")
|
||||
|
||||
# Send recent tick data to new subscriber
|
||||
self._send_recent_ticks_to_subscriber(subscriber)
|
||||
|
||||
return subscriber_id
|
||||
|
||||
def unsubscribe_from_ticks(self, subscriber_id: str):
|
||||
"""Unsubscribe from tick data updates"""
|
||||
with self.subscriber_lock:
|
||||
if subscriber_id in self.subscribers:
|
||||
subscriber_name = self.subscribers[subscriber_id].subscriber_name
|
||||
self.subscribers[subscriber_id].active = False
|
||||
del self.subscribers[subscriber_id]
|
||||
logger.info(f"Subscriber {subscriber_name} ({subscriber_id}) unsubscribed")
|
||||
|
||||
def _send_recent_ticks_to_subscriber(self, subscriber: DataSubscriber):
|
||||
"""Send recent tick data to a new subscriber"""
|
||||
logger.error(f"Error preparing multi-symbol features for inference: {e}")
|
||||
return {}
|
||||
|
||||
def get_cnn_features_for_inference(self, symbol: str, timeframe: str = '1m', window_size: int = 60) -> Optional[np.ndarray]:
|
||||
"""Get normalized CNN features for a specific symbol and timeframe"""
|
||||
try:
|
||||
for symbol in subscriber.symbols:
|
||||
if symbol in self.tick_buffers:
|
||||
# Send last 50 ticks to get subscriber up to speed
|
||||
recent_ticks = list(self.tick_buffers[symbol])[-50:]
|
||||
for tick in recent_ticks:
|
||||
try:
|
||||
subscriber.callback(tick)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error sending recent tick to subscriber {subscriber.subscriber_id}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending recent ticks: {e}")
|
||||
|
||||
def _distribute_tick(self, tick: MarketTick):
|
||||
"""Distribute tick to all relevant subscribers"""
|
||||
distributed_count = 0
|
||||
|
||||
with self.subscriber_lock:
|
||||
subscribers_to_remove = []
|
||||
# Get normalized data
|
||||
df = self.get_historical_data_for_inference(symbol, timeframe, limit=300)
|
||||
|
||||
for subscriber_id, subscriber in self.subscribers.items():
|
||||
if not subscriber.active:
|
||||
subscribers_to_remove.append(subscriber_id)
|
||||
continue
|
||||
if df is None or df.empty:
|
||||
return None
|
||||
|
||||
# Extract recent window for CNN
|
||||
recent_data = df.tail(window_size)
|
||||
|
||||
# Extract OHLCV features
|
||||
features = recent_data[['open', 'high', 'low', 'close', 'volume']].values
|
||||
|
||||
logger.debug(f"Extracted CNN features for {symbol} {timeframe}: {features.shape}")
|
||||
return features.flatten()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting CNN features for {symbol} {timeframe}: {e}")
|
||||
return None
|
||||
|
||||
def get_dqn_state_for_inference(self, symbols_timeframes: List[Tuple[str, str]], target_size: int = 100) -> Optional[np.ndarray]:
|
||||
"""Get normalized DQN state vector combining multiple symbols and timeframes"""
|
||||
try:
|
||||
state_components = []
|
||||
|
||||
for symbol, timeframe in symbols_timeframes:
|
||||
df = self.get_historical_data_for_inference(symbol, timeframe, limit=50)
|
||||
|
||||
if tick.symbol in subscriber.symbols:
|
||||
try:
|
||||
# Call subscriber callback in a thread to avoid blocking
|
||||
def call_callback():
|
||||
try:
|
||||
subscriber.callback(tick)
|
||||
subscriber.tick_count += 1
|
||||
subscriber.last_update = datetime.now()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error in subscriber {subscriber_id} callback: {e}")
|
||||
subscriber.active = False
|
||||
|
||||
# Use thread to avoid blocking the main data processing
|
||||
Thread(target=call_callback, daemon=True).start()
|
||||
distributed_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error distributing tick to subscriber {subscriber_id}: {e}")
|
||||
subscriber.active = False
|
||||
if df is not None and not df.empty:
|
||||
# Extract key features for state
|
||||
latest = df.iloc[-1]
|
||||
state_features = [
|
||||
latest['close'], # Current price (normalized)
|
||||
latest['volume'], # Current volume (normalized)
|
||||
df['close'].pct_change().iloc[-1] if len(df) > 1 else 0, # Price change
|
||||
]
|
||||
state_components.extend(state_features)
|
||||
|
||||
# Remove inactive subscribers
|
||||
for subscriber_id in subscribers_to_remove:
|
||||
if subscriber_id in self.subscribers:
|
||||
del self.subscribers[subscriber_id]
|
||||
|
||||
self.distribution_stats['total_ticks_distributed'] += distributed_count
|
||||
|
||||
def _validate_tick_data(self, symbol: str, price: float, volume: float) -> bool:
|
||||
"""Validate incoming tick data for quality"""
|
||||
try:
|
||||
# Basic validation
|
||||
if price <= 0 or volume < 0:
|
||||
return False
|
||||
|
||||
# Price change validation
|
||||
last_price = self.last_prices.get(symbol, 0)
|
||||
if last_price > 0:
|
||||
price_change_pct = abs(price - last_price) / last_price
|
||||
if price_change_pct > self.price_change_threshold:
|
||||
logger.warning(f"Large price change for {symbol}: {price_change_pct:.2%}")
|
||||
# Don't reject, just warn - could be legitimate
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating tick data: {e}")
|
||||
return False
|
||||
|
||||
def get_recent_ticks(self, symbol: str, count: int = 100) -> List[MarketTick]:
|
||||
"""Get recent ticks for a symbol"""
|
||||
binance_symbol = symbol.replace('/', '').upper()
|
||||
if binance_symbol in self.tick_buffers:
|
||||
return list(self.tick_buffers[binance_symbol])[-count:]
|
||||
return []
|
||||
|
||||
def subscribe_to_raw_ticks(self, callback: Callable[[RawTick], None]) -> str:
|
||||
"""Subscribe to raw tick data with timing information"""
|
||||
subscriber_id = str(uuid.uuid4())
|
||||
self.raw_tick_callbacks.append(callback)
|
||||
logger.info(f"Raw tick subscriber added: {subscriber_id}")
|
||||
return subscriber_id
|
||||
|
||||
def subscribe_to_ohlcv_bars(self, callback: Callable[[OHLCVBar], None]) -> str:
|
||||
"""Subscribe to 1s OHLCV bars calculated from ticks"""
|
||||
subscriber_id = str(uuid.uuid4())
|
||||
self.ohlcv_bar_callbacks.append(callback)
|
||||
logger.info(f"OHLCV bar subscriber added: {subscriber_id}")
|
||||
return subscriber_id
|
||||
|
||||
def get_raw_tick_features(self, symbol: str, window_size: int = 50) -> Optional[np.ndarray]:
|
||||
"""Get raw tick features for model consumption"""
|
||||
return self.tick_aggregator.get_tick_features_for_model(symbol, window_size)
|
||||
|
||||
def get_ohlcv_features(self, symbol: str, window_size: int = 60) -> Optional[np.ndarray]:
|
||||
"""Get 1s OHLCV features for model consumption"""
|
||||
return self.tick_aggregator.get_ohlcv_features_for_model(symbol, window_size)
|
||||
|
||||
def get_detected_patterns(self, symbol: str, count: int = 50) -> List:
|
||||
"""Get recently detected tick patterns"""
|
||||
return self.tick_aggregator.get_detected_patterns(symbol, count)
|
||||
|
||||
def get_tick_aggregator_stats(self) -> Dict[str, Any]:
|
||||
"""Get tick aggregator statistics"""
|
||||
return self.tick_aggregator.get_statistics()
|
||||
|
||||
def get_subscriber_stats(self) -> Dict[str, Any]:
|
||||
"""Get subscriber and distribution statistics"""
|
||||
with self.subscriber_lock:
|
||||
active_subscribers = len([s for s in self.subscribers.values() if s.active])
|
||||
subscriber_stats = {
|
||||
sid: {
|
||||
'name': s.subscriber_name,
|
||||
'active': s.active,
|
||||
'symbols': s.symbols,
|
||||
'tick_count': s.tick_count,
|
||||
'last_update': s.last_update.isoformat() if s.last_update else None
|
||||
}
|
||||
for sid, s in self.subscribers.items()
|
||||
}
|
||||
|
||||
# Get tick aggregator stats
|
||||
aggregator_stats = self.get_tick_aggregator_stats()
|
||||
|
||||
return {
|
||||
'active_subscribers': active_subscribers,
|
||||
'total_subscribers': len(self.subscribers),
|
||||
'raw_tick_callbacks': len(self.raw_tick_callbacks),
|
||||
'ohlcv_bar_callbacks': len(self.ohlcv_bar_callbacks),
|
||||
'subscriber_details': subscriber_stats,
|
||||
'distribution_stats': self.distribution_stats.copy(),
|
||||
'buffer_sizes': {symbol: len(buffer) for symbol, buffer in self.tick_buffers.items()},
|
||||
'tick_aggregator': aggregator_stats
|
||||
}
|
||||
|
||||
def update_bom_cache(self, symbol: str, bom_features: List[float], cob_integration=None):
|
||||
"""
|
||||
Update BOM cache with latest features for a symbol
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol (e.g., 'ETH/USDT')
|
||||
bom_features: List of BOM features (should be 120 features)
|
||||
cob_integration: Optional COB integration instance for real BOM data
|
||||
"""
|
||||
try:
|
||||
current_time = datetime.now()
|
||||
|
||||
# Ensure we have exactly 120 features
|
||||
if len(bom_features) != self.bom_feature_count:
|
||||
if len(bom_features) > self.bom_feature_count:
|
||||
bom_features = bom_features[:self.bom_feature_count]
|
||||
if state_components:
|
||||
# Pad or truncate to expected DQN state size
|
||||
if len(state_components) < target_size:
|
||||
state_components.extend([0] * (target_size - len(state_components)))
|
||||
else:
|
||||
bom_features.extend([0.0] * (self.bom_feature_count - len(bom_features)))
|
||||
|
||||
# Convert to numpy array for efficient storage
|
||||
bom_array = np.array(bom_features, dtype=np.float32)
|
||||
|
||||
# Add timestamp and features to cache
|
||||
with self.data_lock:
|
||||
self.bom_data_cache[symbol].append((current_time, bom_array))
|
||||
|
||||
logger.debug(f"Updated BOM cache for {symbol}: {len(self.bom_data_cache[symbol])} timestamps cached")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating BOM cache for {symbol}: {e}")
|
||||
|
||||
def get_bom_matrix_for_cnn(self, symbol: str, sequence_length: int = 50) -> Optional[np.ndarray]:
|
||||
"""
|
||||
Get BOM matrix for CNN input from cached 1s data
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol (e.g., 'ETH/USDT')
|
||||
sequence_length: Required sequence length (default 50)
|
||||
|
||||
Returns:
|
||||
np.ndarray: BOM matrix of shape (sequence_length, 120) or None if insufficient data
|
||||
"""
|
||||
try:
|
||||
with self.data_lock:
|
||||
if symbol not in self.bom_data_cache or len(self.bom_data_cache[symbol]) == 0:
|
||||
logger.warning(f"No BOM data cached for {symbol}")
|
||||
return None
|
||||
state_components = state_components[:target_size]
|
||||
|
||||
# Get recent data
|
||||
cached_data = list(self.bom_data_cache[symbol])
|
||||
|
||||
if len(cached_data) < sequence_length:
|
||||
logger.warning(f"Insufficient BOM data for {symbol}: {len(cached_data)} < {sequence_length}")
|
||||
# Pad with zeros if we don't have enough data
|
||||
bom_matrix = np.zeros((sequence_length, self.bom_feature_count), dtype=np.float32)
|
||||
|
||||
# Fill available data at the end
|
||||
for i, (timestamp, features) in enumerate(cached_data):
|
||||
if i < sequence_length:
|
||||
bom_matrix[sequence_length - len(cached_data) + i] = features
|
||||
|
||||
return bom_matrix
|
||||
|
||||
# Take the most recent sequence_length samples
|
||||
recent_data = cached_data[-sequence_length:]
|
||||
|
||||
# Create matrix
|
||||
bom_matrix = np.zeros((sequence_length, self.bom_feature_count), dtype=np.float32)
|
||||
for i, (timestamp, features) in enumerate(recent_data):
|
||||
bom_matrix[i] = features
|
||||
|
||||
logger.debug(f"Retrieved BOM matrix for {symbol}: shape={bom_matrix.shape}")
|
||||
return bom_matrix
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting BOM matrix for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def get_real_bom_features(self, symbol: str) -> Optional[List[float]]:
|
||||
"""
|
||||
Get REAL BOM features from actual market data ONLY
|
||||
|
||||
NO SYNTHETIC DATA - Returns None if real data is not available
|
||||
"""
|
||||
try:
|
||||
# Try to get real COB data from integration
|
||||
if hasattr(self, 'cob_integration') and self.cob_integration:
|
||||
return self._extract_real_bom_features(symbol, self.cob_integration)
|
||||
state_vector = np.array(state_components, dtype=np.float32)
|
||||
logger.debug(f"Created DQN state vector: {len(state_vector)} dimensions")
|
||||
return state_vector
|
||||
|
||||
# No real data available - return None instead of synthetic
|
||||
logger.warning(f"No real BOM data available for {symbol} - waiting for real market data")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting real BOM features for {symbol}: {e}")
|
||||
logger.error(f"Error creating DQN state for inference: {e}")
|
||||
return None
|
||||
|
||||
def start_bom_cache_updates(self, cob_integration=None):
|
||||
"""
|
||||
Start background updates of BOM cache every second
|
||||
|
||||
Args:
|
||||
cob_integration: Optional COB integration instance for real data
|
||||
"""
|
||||
|
||||
def get_transformer_sequences_for_inference(self, symbols_timeframes: List[Tuple[str, str]], seq_length: int = 150) -> List[np.ndarray]:
|
||||
"""Get normalized sequences for transformer inference"""
|
||||
try:
|
||||
def update_loop():
|
||||
while self.is_streaming:
|
||||
try:
|
||||
for symbol in self.symbols:
|
||||
if cob_integration:
|
||||
# Try to get real BOM features from COB integration
|
||||
try:
|
||||
bom_features = self._extract_real_bom_features(symbol, cob_integration)
|
||||
if bom_features:
|
||||
self.update_bom_cache(symbol, bom_features, cob_integration)
|
||||
else:
|
||||
# NO SYNTHETIC FALLBACK - Wait for real data
|
||||
logger.warning(f"No real BOM features available for {symbol} - waiting for real data")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting real BOM features for {symbol}: {e}")
|
||||
logger.warning(f"Waiting for real data instead of using synthetic")
|
||||
else:
|
||||
# NO SYNTHETIC FEATURES - Wait for real COB integration
|
||||
logger.warning(f"No COB integration available for {symbol} - waiting for real data")
|
||||
|
||||
time.sleep(1.0) # Update every second
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in BOM cache update loop: {e}")
|
||||
time.sleep(5.0) # Wait longer on error
|
||||
sequences = []
|
||||
|
||||
# Start background thread
|
||||
bom_thread = Thread(target=update_loop, daemon=True)
|
||||
bom_thread.start()
|
||||
for symbol, timeframe in symbols_timeframes:
|
||||
df = self.get_historical_data_for_inference(symbol, timeframe, limit=300)
|
||||
|
||||
if df is not None and not df.empty:
|
||||
# Use last seq_length points as sequence
|
||||
sequence = df.tail(seq_length)[['open', 'high', 'low', 'close', 'volume']].values
|
||||
sequences.append(sequence)
|
||||
logger.debug(f"Created transformer sequence for {symbol} {timeframe}: {sequence.shape}")
|
||||
|
||||
logger.info("Started BOM cache updates (1s resolution)")
|
||||
return sequences
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting BOM cache updates: {e}")
|
||||
|
||||
def _extract_real_bom_features(self, symbol: str, cob_integration) -> Optional[List[float]]:
|
||||
"""Extract real BOM features from COB integration"""
|
||||
try:
|
||||
features = []
|
||||
|
||||
# Get consolidated order book
|
||||
if hasattr(cob_integration, 'get_consolidated_orderbook'):
|
||||
cob_snapshot = cob_integration.get_consolidated_orderbook(symbol)
|
||||
if cob_snapshot:
|
||||
# Extract order book features (40 features)
|
||||
features.extend(self._extract_orderbook_features(cob_snapshot))
|
||||
else:
|
||||
features.extend([0.0] * 40)
|
||||
else:
|
||||
features.extend([0.0] * 40)
|
||||
|
||||
# Get volume profile features (30 features)
|
||||
if hasattr(cob_integration, 'get_session_volume_profile'):
|
||||
volume_profile = cob_integration.get_session_volume_profile(symbol)
|
||||
if volume_profile:
|
||||
features.extend(self._extract_volume_profile_features(volume_profile))
|
||||
else:
|
||||
features.extend([0.0] * 30)
|
||||
else:
|
||||
features.extend([0.0] * 30)
|
||||
|
||||
# Add flow and microstructure features (50 features)
|
||||
features.extend(self._extract_flow_microstructure_features(symbol, cob_integration))
|
||||
|
||||
# Ensure exactly 120 features
|
||||
if len(features) > 120:
|
||||
features = features[:120]
|
||||
elif len(features) < 120:
|
||||
features.extend([0.0] * (120 - len(features)))
|
||||
|
||||
return features
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error extracting real BOM features for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def _extract_orderbook_features(self, cob_snapshot) -> List[float]:
|
||||
"""Extract order book features from COB snapshot"""
|
||||
features = []
|
||||
|
||||
try:
|
||||
# Top 10 bid levels
|
||||
for i in range(10):
|
||||
if i < len(cob_snapshot.consolidated_bids):
|
||||
level = cob_snapshot.consolidated_bids[i]
|
||||
price_offset = (level.price - cob_snapshot.volume_weighted_mid) / cob_snapshot.volume_weighted_mid
|
||||
volume_normalized = level.total_volume_usd / 1000000
|
||||
features.extend([price_offset, volume_normalized])
|
||||
else:
|
||||
features.extend([0.0, 0.0])
|
||||
|
||||
# Top 10 ask levels
|
||||
for i in range(10):
|
||||
if i < len(cob_snapshot.consolidated_asks):
|
||||
level = cob_snapshot.consolidated_asks[i]
|
||||
price_offset = (level.price - cob_snapshot.volume_weighted_mid) / cob_snapshot.volume_weighted_mid
|
||||
volume_normalized = level.total_volume_usd / 1000000
|
||||
features.extend([price_offset, volume_normalized])
|
||||
else:
|
||||
features.extend([0.0, 0.0])
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error extracting order book features: {e}")
|
||||
features = [0.0] * 40
|
||||
|
||||
return features[:40]
|
||||
|
||||
def _extract_volume_profile_features(self, volume_profile) -> List[float]:
|
||||
"""Extract volume profile features"""
|
||||
features = []
|
||||
|
||||
try:
|
||||
if 'data' in volume_profile:
|
||||
svp_data = volume_profile['data']
|
||||
top_levels = sorted(svp_data, key=lambda x: x.get('total_volume', 0), reverse=True)[:10]
|
||||
|
||||
for level in top_levels:
|
||||
buy_percent = level.get('buy_percent', 50.0) / 100.0
|
||||
sell_percent = level.get('sell_percent', 50.0) / 100.0
|
||||
total_volume = level.get('total_volume', 0.0) / 1000000
|
||||
features.extend([buy_percent, sell_percent, total_volume])
|
||||
|
||||
# Pad to 30 features
|
||||
while len(features) < 30:
|
||||
features.extend([0.5, 0.5, 0.0])
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error extracting volume profile features: {e}")
|
||||
features = [0.0] * 30
|
||||
|
||||
return features[:30]
|
||||
|
||||
def _extract_flow_microstructure_features(self, symbol: str, cob_integration) -> List[float]:
|
||||
"""Extract flow and microstructure features"""
|
||||
try:
|
||||
# For now, return synthetic features since full implementation would be complex
|
||||
# NO SYNTHETIC DATA - Return None if no real microstructure data
|
||||
logger.warning(f"No real microstructure data available for {symbol}")
|
||||
return None
|
||||
except:
|
||||
return [0.0] * 50
|
||||
|
||||
def _handle_rate_limit(self, url: str):
|
||||
"""Handle rate limiting with exponential backoff"""
|
||||
current_time = time.time()
|
||||
|
||||
# Check if we need to wait
|
||||
if url in self.last_request_time:
|
||||
time_since_last = current_time - self.last_request_time[url]
|
||||
if time_since_last < self.request_interval:
|
||||
sleep_time = self.request_interval - time_since_last
|
||||
logger.info(f"Rate limiting: sleeping {sleep_time:.2f}s")
|
||||
time.sleep(sleep_time)
|
||||
|
||||
self.last_request_time[url] = time.time()
|
||||
|
||||
def _make_request_with_retry(self, url: str, params: dict = None):
|
||||
"""Make HTTP request with retry logic for 451 errors"""
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
self._handle_rate_limit(url)
|
||||
response = requests.get(url, params=params, timeout=30)
|
||||
|
||||
if response.status_code == 451:
|
||||
logger.warning(f"Rate limit hit (451), attempt {attempt + 1}/{self.max_retries}")
|
||||
if attempt < self.max_retries - 1:
|
||||
sleep_time = self.retry_delay * (2 ** attempt) # Exponential backoff
|
||||
logger.info(f"Waiting {sleep_time}s before retry...")
|
||||
time.sleep(sleep_time)
|
||||
continue
|
||||
else:
|
||||
logger.error("Max retries reached, using cached data")
|
||||
return None
|
||||
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Request failed (attempt {attempt + 1}): {e}")
|
||||
if attempt < self.max_retries - 1:
|
||||
time.sleep(5 * (attempt + 1))
|
||||
|
||||
return None
|
||||
logger.error(f"Error creating transformer sequences for inference: {e}")
|
||||
return []
|
||||
|
||||
Reference in New Issue
Block a user