RL module possibly working
@@ -20,6 +20,7 @@ from typing import Dict, List, Optional, Tuple, Any, Union
from dataclasses import dataclass, field
from collections import deque
import torch
import ta

from .config import get_config
from .data_provider import DataProvider, RawTick, OHLCVBar, MarketTick
@@ -68,7 +69,7 @@ class TradingAction:

@dataclass
class MarketState:
    """Complete market state for RL evaluation with comprehensive data"""
    symbol: str
    timestamp: datetime
    prices: Dict[str, float]  # {timeframe: current_price}

@@ -78,6 +79,15 @@ class MarketState:
    trend_strength: float
    market_regime: str  # 'trending', 'ranging', 'volatile'
    universal_data: UniversalDataStream  # Universal format data

    # Enhanced data for comprehensive RL state building
    raw_ticks: List[Dict[str, Any]] = field(default_factory=list)  # Last 300s of tick data
    ohlcv_data: Dict[str, List[Dict[str, Any]]] = field(default_factory=dict)  # Multi-timeframe OHLCV
    btc_reference_data: Dict[str, List[Dict[str, Any]]] = field(default_factory=dict)  # BTC correlation data
    cnn_hidden_features: Optional[Dict[str, np.ndarray]] = None  # CNN hidden layer features
    cnn_predictions: Optional[Dict[str, np.ndarray]] = None  # CNN predictions by timeframe
    pivot_points: Optional[Dict[str, Any]] = None  # Williams market structure data
    market_microstructure: Dict[str, Any] = field(default_factory=dict)  # Tick-level patterns

@dataclass
class PerfectMove:
@@ -341,89 +351,328 @@ class EnhancedTradingOrchestrator:
        return decisions

    async def _get_all_market_states_universal(self, universal_stream: UniversalDataStream) -> Dict[str, MarketState]:
        """Get market states for all symbols with comprehensive data for RL"""
        market_states = {}

        try:
            for symbol in self.symbols:
                try:
                    # Basic market state data
                    current_prices = {}
                    for timeframe in self.timeframes:
                        # Get latest price from universal data stream
                        latest_price = self._get_latest_price_from_universal(symbol, timeframe, universal_stream)
                        if latest_price:
                            current_prices[timeframe] = latest_price

                    # Calculate basic metrics
                    volatility = self._calculate_volatility_from_universal(symbol, universal_stream)
                    volume = self._calculate_volume_from_universal(symbol, universal_stream)
                    trend_strength = self._calculate_trend_strength_from_universal(symbol, universal_stream)
                    market_regime = self._determine_market_regime(symbol, universal_stream)

                    # Get comprehensive data for RL state building
                    raw_ticks = self._get_recent_tick_data_for_rl(symbol)
                    ohlcv_data = self._get_multiframe_ohlcv_for_rl(symbol)
                    btc_reference_data = self._get_multiframe_ohlcv_for_rl('BTC/USDT')

                    # Get CNN features if available
                    cnn_hidden_features, cnn_predictions = self._get_cnn_features_for_rl(symbol)

                    # Calculate pivot points
                    pivot_points = self._calculate_pivot_points_for_rl(ohlcv_data)

                    # Analyze market microstructure
                    market_microstructure = self._analyze_market_microstructure(raw_ticks)

                    # Create comprehensive market state
                    market_state = MarketState(
                        symbol=symbol,
                        timestamp=datetime.now(),
                        prices=current_prices,
                        features={},  # Will be populated by feature extraction
                        volatility=volatility,
                        volume=volume,
                        trend_strength=trend_strength,
                        market_regime=market_regime,
                        universal_data=universal_stream,
                        raw_ticks=raw_ticks,
                        ohlcv_data=ohlcv_data,
                        btc_reference_data=btc_reference_data,
                        cnn_hidden_features=cnn_hidden_features,
                        cnn_predictions=cnn_predictions,
                        pivot_points=pivot_points,
                        market_microstructure=market_microstructure
                    )

                    market_states[symbol] = market_state
                    logger.debug(f"Created comprehensive market state for {symbol} with {len(raw_ticks)} ticks")

                except Exception as e:
                    logger.error(f"Error creating market state for {symbol}: {e}")

        except Exception as e:
            logger.error(f"Error creating market states from universal data: {e}")

        return market_states

    def _get_recent_tick_data_for_rl(self, symbol: str, seconds: int = 300) -> List[Dict[str, Any]]:
        """Get recent tick data for RL state building"""
        try:
            # Get ticks from data provider
            recent_ticks = self.data_provider.get_recent_ticks(symbol, count=seconds * 10)

            # Convert to required format
            tick_data = []
            for tick in recent_ticks[-300:]:  # Last 300 ticks max (300s at ~1 tick/sec)
                tick_dict = {
                    'timestamp': tick.timestamp,
                    'price': tick.price,
                    'volume': tick.volume,
                    'quantity': getattr(tick, 'quantity', tick.volume),
                    'side': getattr(tick, 'side', 'unknown'),
                    'trade_id': getattr(tick, 'trade_id', 'unknown'),
                    'is_buyer_maker': getattr(tick, 'is_buyer_maker', False)
                }
                tick_data.append(tick_dict)

            return tick_data

        except Exception as e:
            logger.warning(f"Error getting tick data for {symbol}: {e}")
            return []
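
    # Note: count=seconds * 10 over-fetches and the [-300:] slice then caps the
    # window by tick count rather than wall-clock time. A minimal sketch of a
    # strictly time-based cut (assuming timezone-consistent datetime stamps;
    # timedelta would need to be imported):
    #
    #   cutoff = datetime.now() - timedelta(seconds=seconds)
    #   recent_ticks = [t for t in recent_ticks if t.timestamp >= cutoff]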

    def _get_multiframe_ohlcv_for_rl(self, symbol: str) -> Dict[str, List[Dict[str, Any]]]:
        """Get multi-timeframe OHLCV data for RL state building"""
        try:
            ohlcv_data = {}
            timeframes = ['1s', '1m', '1h', '1d']

            for tf in timeframes:
                try:
                    # Get historical data for timeframe
                    df = self.data_provider.get_historical_data(
                        symbol=symbol,
                        timeframe=tf,
                        limit=300,
                        refresh=True
                    )

                    if df is not None and not df.empty:
                        # Convert to list of dictionaries with technical indicators
                        bars = []

                        # Add technical indicators
                        df_with_indicators = self._add_technical_indicators(df)

                        for idx, row in df_with_indicators.tail(300).iterrows():
                            bar = {
                                'timestamp': idx if hasattr(idx, 'timestamp') else datetime.now(),
                                'open': float(row.get('open', 0)),
                                'high': float(row.get('high', 0)),
                                'low': float(row.get('low', 0)),
                                'close': float(row.get('close', 0)),
                                'volume': float(row.get('volume', 0)),
                                'rsi': float(row.get('rsi', 50)),
                                'macd': float(row.get('macd', 0)),
                                'bb_upper': float(row.get('bb_upper', row.get('close', 0))),
                                'bb_lower': float(row.get('bb_lower', row.get('close', 0))),
                                'sma_20': float(row.get('sma_20', row.get('close', 0))),
                                'ema_12': float(row.get('ema_12', row.get('close', 0))),
                                'atr': float(row.get('atr', 0))
                            }
                            bars.append(bar)

                        ohlcv_data[tf] = bars
                    else:
                        ohlcv_data[tf] = []

                except Exception as e:
                    logger.warning(f"Error getting {tf} data for {symbol}: {e}")
                    ohlcv_data[tf] = []

            return ohlcv_data

        except Exception as e:
            logger.warning(f"Error getting OHLCV data for {symbol}: {e}")
            return {}

    def _add_technical_indicators(self, df: pd.DataFrame) -> pd.DataFrame:
        """Add technical indicators to OHLCV data"""
        try:
            df = df.copy()

            # RSI
            if len(df) >= 14:
                df['rsi'] = ta.momentum.rsi(df['close'], window=14)
            else:
                df['rsi'] = 50

            # MACD
            if len(df) >= 26:
                macd = ta.trend.macd_diff(df['close'])
                df['macd'] = macd
            else:
                df['macd'] = 0

            # Bollinger Bands
            if len(df) >= 20:
                bb = ta.volatility.BollingerBands(df['close'], window=20)
                df['bb_upper'] = bb.bollinger_hband()
                df['bb_lower'] = bb.bollinger_lband()
            else:
                df['bb_upper'] = df['close']
                df['bb_lower'] = df['close']

            # Moving averages
            if len(df) >= 20:
                df['sma_20'] = ta.trend.sma_indicator(df['close'], window=20)
            else:
                df['sma_20'] = df['close']

            if len(df) >= 12:
                df['ema_12'] = ta.trend.ema_indicator(df['close'], window=12)
            else:
                df['ema_12'] = df['close']

            # ATR
            if len(df) >= 14:
                df['atr'] = ta.volatility.average_true_range(df['high'], df['low'], df['close'], window=14)
            else:
                df['atr'] = 0

            return df

        except Exception as e:
            logger.warning(f"Error adding technical indicators: {e}")
            return df

    def _get_cnn_features_for_rl(self, symbol: str) -> Tuple[Optional[Dict[str, np.ndarray]], Optional[Dict[str, np.ndarray]]]:
        """Get CNN hidden features and predictions for RL state building"""
        try:
            # Try to get CNN features from model registry
            if hasattr(self, 'model_registry') and self.model_registry:
                cnn_models = self.model_registry.get_models_by_type('cnn')

                if cnn_models:
                    hidden_features = {}
                    predictions = {}

                    for model_name, model in cnn_models.items():
                        try:
                            # Get recent market data for the model
                            feature_matrix = self.data_provider.get_feature_matrix(
                                symbol=symbol,
                                timeframes=['1s', '1m', '1h', '1d'],
                                window_size=50
                            )

                            if feature_matrix is not None:
                                # Extract hidden features and predictions
                                model_hidden, model_pred = self._extract_cnn_features(model, feature_matrix)
                                if model_hidden is not None:
                                    hidden_features[model_name] = model_hidden
                                if model_pred is not None:
                                    predictions[model_name] = model_pred

                        except Exception as e:
                            logger.warning(f"Error getting features from CNN model {model_name}: {e}")

                    return hidden_features if hidden_features else None, predictions if predictions else None

            return None, None

        except Exception as e:
            logger.warning(f"Error getting CNN features for {symbol}: {e}")
            return None, None

    def _extract_cnn_features(self, model, feature_matrix: np.ndarray) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """Extract hidden features and predictions from CNN model"""
        try:
            # This would need to be implemented based on your specific CNN
            # architecture; for now, return placeholder values.

            # Mock hidden features (would be extracted from the model's hidden layers)
            hidden_features = np.random.random(512).astype(np.float32)

            # Mock predictions (would be the model's output: BUY, SELL, HOLD, confidence)
            predictions = np.array([0.33, 0.33, 0.34, 0.7], dtype=np.float32)

            return hidden_features, predictions

        except Exception as e:
            logger.warning(f"Error extracting CNN features: {e}")
            return None, None
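
    # A minimal sketch of how real hidden features could be captured with a
    # PyTorch forward hook, assuming the model is a torch.nn.Module and that
    # 'penultimate' names its last hidden layer (both are assumptions here):
    #
    #   captured = {}
    #   def _hook(module, inputs, output):
    #       captured['hidden'] = output.detach().cpu().numpy().ravel()
    #   handle = model.penultimate.register_forward_hook(_hook)
    #   with torch.no_grad():
    #       logits = model(torch.from_numpy(feature_matrix).float().unsqueeze(0))
    #   handle.remove()
    #   # captured['hidden'] and a softmax over logits would replace the mocks above.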

    def _calculate_pivot_points_for_rl(self, ohlcv_data: Dict[str, List]) -> Optional[Dict[str, Any]]:
        """Calculate Williams pivot points for RL state building"""
        try:
            if '1m' in ohlcv_data and len(ohlcv_data['1m']) >= 50:
                # Use 1m data for pivot calculation
                bars = ohlcv_data['1m']

                # Convert to numpy array
                ohlc_array = np.array([
                    [bar['timestamp'].timestamp() if hasattr(bar['timestamp'], 'timestamp') else time.time(),
                     bar['open'], bar['high'], bar['low'], bar['close'], bar['volume']]
                    for bar in bars[-200:]  # Last 200 bars
                ])

                # Calculate pivot points using Williams market structure.
                # This would use the WilliamsMarketStructure implementation;
                # placeholder values for now.
                pivot_data = {
                    'swing_highs': [],
                    'swing_lows': [],
                    'trend_levels': [],
                    'market_bias': 'neutral'
                }

                return pivot_data

            return None

        except Exception as e:
            logger.warning(f"Error calculating pivot points: {e}")
            return None
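
    # A minimal sketch of fractal swing detection that could populate the
    # placeholder above (a swing high is a bar whose high exceeds the highs of
    # k bars on either side; swing lows are symmetric), assuming ohlc_array
    # columns are [ts, open, high, low, close, volume]:
    #
    #   k = 2
    #   highs, lows = ohlc_array[:, 2], ohlc_array[:, 3]
    #   swing_highs = [i for i in range(k, len(highs) - k)
    #                  if highs[i] == highs[i - k:i + k + 1].max()]
    #   swing_lows = [i for i in range(k, len(lows) - k)
    #                 if lows[i] == lows[i - k:i + k + 1].min()]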

    def _analyze_market_microstructure(self, raw_ticks: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Analyze market microstructure from tick data"""
        try:
            if not raw_ticks or len(raw_ticks) < 10:
                return {}

            # Calculate microstructure metrics
            prices = [tick['price'] for tick in raw_ticks]
            volumes = [tick['volume'] for tick in raw_ticks]

            # Price momentum
            price_momentum = (prices[-1] - prices[0]) / prices[0] if prices[0] != 0 else 0

            # Volume pattern
            avg_volume = sum(volumes) / len(volumes)
            recent_volume = sum(volumes[-10:]) / 10 if len(volumes) >= 10 else avg_volume
            volume_intensity = recent_volume / avg_volume if avg_volume != 0 else 1.0

            # Tick frequency
            if len(raw_ticks) >= 2:
                time_diffs = []
                for i in range(1, len(raw_ticks)):
                    if hasattr(raw_ticks[i]['timestamp'], 'timestamp') and hasattr(raw_ticks[i-1]['timestamp'], 'timestamp'):
                        diff = raw_ticks[i]['timestamp'].timestamp() - raw_ticks[i-1]['timestamp'].timestamp()
                        time_diffs.append(diff)

                avg_tick_interval = sum(time_diffs) / len(time_diffs) if time_diffs else 1.0
            else:
                avg_tick_interval = 1.0

            return {
                'price_momentum': price_momentum,
                'volume_intensity': volume_intensity,
                'avg_tick_interval': avg_tick_interval,
                'tick_count': len(raw_ticks),
                'price_volatility': np.std(prices) if len(prices) > 1 else 0.0
            }

        except Exception as e:
            logger.warning(f"Error analyzing market microstructure: {e}")
            return {}

    async def _get_enhanced_predictions_universal(self, symbol: str, market_state: MarketState,
                                                  universal_stream: UniversalDataStream) -> List[EnhancedPrediction]:
        """Get enhanced predictions using universal data format"""

core/unified_data_stream.py (new file, 627 lines)
@@ -0,0 +1,627 @@
"""
|
||||
Unified Data Stream Architecture for Dashboard and Enhanced RL Training
|
||||
|
||||
This module provides a centralized data streaming architecture that:
|
||||
1. Serves real-time data to the dashboard UI
|
||||
2. Feeds the enhanced RL training pipeline with comprehensive data
|
||||
3. Maintains data consistency across all consumers
|
||||
4. Provides efficient data distribution without duplication
|
||||
5. Supports multiple data consumers with different requirements
|
||||
|
||||
Key Features:
|
||||
- Single source of truth for all market data
|
||||
- Real-time tick processing and aggregation
|
||||
- Multi-timeframe OHLCV generation
|
||||
- CNN feature extraction and caching
|
||||
- RL state building with comprehensive data
|
||||
- Dashboard-ready formatted data
|
||||
- Training data collection and buffering
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Tuple, Any, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from collections import deque
|
||||
from threading import Thread, Lock
|
||||
import json
|
||||
|
||||
from .config import get_config
|
||||
from .data_provider import DataProvider, MarketTick
|
||||
from .universal_data_adapter import UniversalDataAdapter, UniversalDataStream
|
||||
from .enhanced_orchestrator import MarketState, TradingAction
|
||||
|
||||
logger = logging.getLogger(__name__)

@dataclass
class StreamConsumer:
    """Data stream consumer configuration"""
    consumer_id: str
    consumer_name: str
    callback: Callable[[Dict[str, Any]], None]
    data_types: List[str]  # ['ticks', 'ohlcv', 'training_data', 'ui_data']
    active: bool = True
    last_update: datetime = field(default_factory=datetime.now)
    update_count: int = 0

@dataclass
class TrainingDataPacket:
    """Training data packet for RL pipeline"""
    timestamp: datetime
    symbol: str
    tick_cache: List[Dict[str, Any]]
    one_second_bars: List[Dict[str, Any]]
    multi_timeframe_data: Dict[str, List[Dict[str, Any]]]
    cnn_features: Optional[Dict[str, np.ndarray]]
    cnn_predictions: Optional[Dict[str, np.ndarray]]
    market_state: Optional[MarketState]
    universal_stream: Optional[UniversalDataStream]

@dataclass
class UIDataPacket:
    """UI data packet for dashboard"""
    timestamp: datetime
    current_prices: Dict[str, float]
    tick_cache_size: int
    one_second_bars_count: int
    streaming_status: str
    training_data_available: bool
    model_training_status: Dict[str, Any]
    orchestrator_status: Dict[str, Any]

class UnifiedDataStream:
    """
    Unified data stream manager for dashboard and training pipeline integration
    """

    def __init__(self, data_provider: DataProvider, orchestrator=None):
        """Initialize unified data stream"""
        self.config = get_config()
        self.data_provider = data_provider
        self.orchestrator = orchestrator

        # Initialize universal data adapter
        self.universal_adapter = UniversalDataAdapter(data_provider)

        # Data consumers registry
        self.consumers: Dict[str, StreamConsumer] = {}
        self.consumer_lock = Lock()

        # Data buffers for different consumers
        self.tick_cache = deque(maxlen=5000)           # Raw tick cache
        self.one_second_bars = deque(maxlen=1000)      # 1s OHLCV bars
        self.training_data_buffer = deque(maxlen=100)  # Training data packets
        self.ui_data_buffer = deque(maxlen=50)         # UI data packets

        # Multi-timeframe data storage
        self.multi_timeframe_data = {
            'ETH/USDT': {
                '1s': deque(maxlen=300),
                '1m': deque(maxlen=300),
                '1h': deque(maxlen=300),
                '1d': deque(maxlen=300)
            },
            'BTC/USDT': {
                '1s': deque(maxlen=300),
                '1m': deque(maxlen=300),
                '1h': deque(maxlen=300),
                '1d': deque(maxlen=300)
            }
        }

        # CNN features cache
        self.cnn_features_cache = {}
        self.cnn_predictions_cache = {}

        # Stream status
        self.streaming = False
        self.stream_thread = None

        # Performance tracking
        self.stream_stats = {
            'total_ticks_processed': 0,
            'total_packets_sent': 0,
            'consumers_served': 0,
            'last_tick_time': None,
            'processing_errors': 0,
            'data_quality_score': 1.0
        }

        # Data validation
        self.last_prices = {}
        self.price_change_threshold = 0.1  # 10% change threshold

        logger.info("Unified Data Stream initialized")
        logger.info(f"Symbols: {self.config.symbols}")
        logger.info(f"Timeframes: {self.config.timeframes}")

    def register_consumer(self, consumer_name: str, callback: Callable[[Dict[str, Any]], None],
                          data_types: List[str]) -> str:
        """Register a data consumer"""
        consumer_id = f"{consumer_name}_{int(time.time())}"

        with self.consumer_lock:
            consumer = StreamConsumer(
                consumer_id=consumer_id,
                consumer_name=consumer_name,
                callback=callback,
                data_types=data_types
            )
            self.consumers[consumer_id] = consumer

        logger.info(f"Registered consumer: {consumer_name} ({consumer_id})")
        logger.info(f"Data types: {data_types}")

        return consumer_id
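
    # A minimal usage sketch (the callback and data types are illustrative,
    # not part of this module):
    #
    #   stream = UnifiedDataStream(data_provider)
    #   cid = stream.register_consumer(
    #       consumer_name="dashboard",
    #       callback=lambda packet: print(packet['timestamp']),
    #       data_types=['ui_data', 'ohlcv'],
    #   )
    #   # ... later ...
    #   stream.unregister_consumer(cid)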

    def unregister_consumer(self, consumer_id: str):
        """Unregister a data consumer"""
        with self.consumer_lock:
            if consumer_id in self.consumers:
                consumer = self.consumers.pop(consumer_id)
                logger.info(f"Unregistered consumer: {consumer.consumer_name} ({consumer_id})")

    async def start_streaming(self):
        """Start unified data streaming"""
        if self.streaming:
            logger.warning("Data streaming already active")
            return

        self.streaming = True

        # Subscribe to data provider ticks
        self.data_provider.subscribe_to_ticks(
            callback=self._handle_tick,
            symbols=self.config.symbols,
            subscriber_name="UnifiedDataStream"
        )

        # Start background processing
        self.stream_thread = Thread(target=self._stream_processor, daemon=True)
        self.stream_thread.start()

        logger.info("Unified data streaming started")

    async def stop_streaming(self):
        """Stop unified data streaming"""
        self.streaming = False

        if self.stream_thread:
            self.stream_thread.join(timeout=5)

        logger.info("Unified data streaming stopped")

    def _handle_tick(self, tick: MarketTick):
        """Handle incoming tick data"""
        try:
            # Validate tick data
            if not self._validate_tick(tick):
                return

            # Add to tick cache
            tick_data = {
                'symbol': tick.symbol,
                'timestamp': tick.timestamp,
                'price': tick.price,
                'volume': tick.volume,
                'quantity': tick.quantity,
                'side': tick.side
            }

            self.tick_cache.append(tick_data)

            # Update current prices
            self.last_prices[tick.symbol] = tick.price

            # Generate 1s bars if needed
            self._update_one_second_bars(tick_data)

            # Update multi-timeframe data
            self._update_multi_timeframe_data(tick_data)

            # Update statistics
            self.stream_stats['total_ticks_processed'] += 1
            self.stream_stats['last_tick_time'] = tick.timestamp

        except Exception as e:
            logger.error(f"Error handling tick: {e}")
            self.stream_stats['processing_errors'] += 1

    def _validate_tick(self, tick: MarketTick) -> bool:
        """Validate tick data quality"""
        try:
            # Check for valid price
            if tick.price <= 0:
                return False

            # Check for reasonable price change
            if tick.symbol in self.last_prices:
                last_price = self.last_prices[tick.symbol]
                if last_price > 0:
                    price_change = abs(tick.price - last_price) / last_price
                    if price_change > self.price_change_threshold:
                        logger.warning(f"Large price change detected for {tick.symbol}: {price_change:.2%}")
                        return False

            # Reject timestamps more than 10 seconds in the future
            if tick.timestamp > datetime.now() + timedelta(seconds=10):
                return False

            return True

        except Exception as e:
            logger.error(f"Error validating tick: {e}")
            return False

    def _update_one_second_bars(self, tick_data: Dict[str, Any]):
        """Update 1-second OHLCV bars"""
        try:
            symbol = tick_data['symbol']
            price = tick_data['price']
            volume = tick_data['volume']
            timestamp = tick_data['timestamp']

            # Truncate timestamp down to the second
            bar_timestamp = timestamp.replace(microsecond=0)

            # Check if we need a new bar
            if (not self.one_second_bars or
                    self.one_second_bars[-1]['timestamp'] != bar_timestamp or
                    self.one_second_bars[-1]['symbol'] != symbol):

                # Create new 1s bar
                bar_data = {
                    'symbol': symbol,
                    'timestamp': bar_timestamp,
                    'open': price,
                    'high': price,
                    'low': price,
                    'close': price,
                    'volume': volume
                }
                self.one_second_bars.append(bar_data)
            else:
                # Update existing bar
                bar = self.one_second_bars[-1]
                bar['high'] = max(bar['high'], price)
                bar['low'] = min(bar['low'], price)
                bar['close'] = price
                bar['volume'] += volume

        except Exception as e:
            logger.error(f"Error updating 1s bars: {e}")

    def _update_multi_timeframe_data(self, tick_data: Dict[str, Any]):
        """Update multi-timeframe OHLCV data"""
        try:
            symbol = tick_data['symbol']
            if symbol not in self.multi_timeframe_data:
                return

            # Update each timeframe
            for timeframe in ['1s', '1m', '1h', '1d']:
                self._update_timeframe_bar(symbol, timeframe, tick_data)

        except Exception as e:
            logger.error(f"Error updating multi-timeframe data: {e}")

    def _update_timeframe_bar(self, symbol: str, timeframe: str, tick_data: Dict[str, Any]):
        """Update specific timeframe bar"""
        try:
            price = tick_data['price']
            volume = tick_data['volume']
            timestamp = tick_data['timestamp']

            # Calculate bar timestamp based on timeframe
            if timeframe == '1s':
                bar_timestamp = timestamp.replace(microsecond=0)
            elif timeframe == '1m':
                bar_timestamp = timestamp.replace(second=0, microsecond=0)
            elif timeframe == '1h':
                bar_timestamp = timestamp.replace(minute=0, second=0, microsecond=0)
            elif timeframe == '1d':
                bar_timestamp = timestamp.replace(hour=0, minute=0, second=0, microsecond=0)
            else:
                return

            timeframe_buffer = self.multi_timeframe_data[symbol][timeframe]

            # Check if we need a new bar
            if (not timeframe_buffer or
                    timeframe_buffer[-1]['timestamp'] != bar_timestamp):

                # Create new bar
                bar_data = {
                    'timestamp': bar_timestamp,
                    'open': price,
                    'high': price,
                    'low': price,
                    'close': price,
                    'volume': volume
                }
                timeframe_buffer.append(bar_data)
            else:
                # Update existing bar
                bar = timeframe_buffer[-1]
                bar['high'] = max(bar['high'], price)
                bar['low'] = min(bar['low'], price)
                bar['close'] = price
                bar['volume'] += volume

        except Exception as e:
            logger.error(f"Error updating {timeframe} bar for {symbol}: {e}")

    def _stream_processor(self):
        """Background stream processor"""
        logger.info("Stream processor started")

        while self.streaming:
            try:
                # Process training data packets
                self._process_training_data()

                # Process UI data packets
                self._process_ui_data()

                # Update CNN features if orchestrator available
                if self.orchestrator:
                    self._update_cnn_features()

                # Distribute data to consumers
                self._distribute_data()

                # Sleep briefly
                time.sleep(0.1)  # 100ms processing cycle

            except Exception as e:
                logger.error(f"Error in stream processor: {e}")
                time.sleep(1)

        logger.info("Stream processor stopped")

    def _process_training_data(self):
        """Process and package training data"""
        try:
            if len(self.tick_cache) < 10:  # Need minimum data
                return

            # Create training data packet
            training_packet = TrainingDataPacket(
                timestamp=datetime.now(),
                symbol='ETH/USDT',  # Primary symbol
                tick_cache=list(self.tick_cache)[-300:],  # Last 300 ticks
                one_second_bars=list(self.one_second_bars)[-300:],  # Last 300 1s bars
                multi_timeframe_data=self._get_multi_timeframe_snapshot(),
                cnn_features=self.cnn_features_cache.copy(),
                cnn_predictions=self.cnn_predictions_cache.copy(),
                market_state=self._build_market_state(),
                universal_stream=self._get_universal_stream()
            )

            self.training_data_buffer.append(training_packet)

        except Exception as e:
            logger.error(f"Error processing training data: {e}")

    def _process_ui_data(self):
        """Process and package UI data"""
        try:
            # Create UI data packet
            ui_packet = UIDataPacket(
                timestamp=datetime.now(),
                current_prices=self.last_prices.copy(),
                tick_cache_size=len(self.tick_cache),
                one_second_bars_count=len(self.one_second_bars),
                streaming_status='LIVE' if self.streaming else 'STOPPED',
                training_data_available=len(self.training_data_buffer) > 0,
                model_training_status=self._get_model_training_status(),
                orchestrator_status=self._get_orchestrator_status()
            )

            self.ui_data_buffer.append(ui_packet)

        except Exception as e:
            logger.error(f"Error processing UI data: {e}")

    def _update_cnn_features(self):
        """Update CNN features cache"""
        try:
            if not self.orchestrator:
                return

            # Get CNN features from orchestrator
            for symbol in self.config.symbols:
                if hasattr(self.orchestrator, '_get_cnn_features_for_rl'):
                    hidden_features, predictions = self.orchestrator._get_cnn_features_for_rl(symbol)

                    if hidden_features:
                        self.cnn_features_cache[symbol] = hidden_features

                    if predictions:
                        self.cnn_predictions_cache[symbol] = predictions

        except Exception as e:
            logger.error(f"Error updating CNN features: {e}")

    def _distribute_data(self):
        """Distribute data to registered consumers"""
        try:
            with self.consumer_lock:
                for consumer_id, consumer in self.consumers.items():
                    if not consumer.active:
                        continue

                    try:
                        # Prepare data based on consumer requirements
                        data_packet = self._prepare_consumer_data(consumer)

                        if data_packet:
                            # Send data to consumer
                            consumer.callback(data_packet)
                            consumer.update_count += 1
                            consumer.last_update = datetime.now()

                    except Exception as e:
                        logger.error(f"Error sending data to consumer {consumer.consumer_name}: {e}")
                        consumer.active = False

            self.stream_stats['consumers_served'] = len([c for c in self.consumers.values() if c.active])

        except Exception as e:
            logger.error(f"Error distributing data: {e}")

    def _prepare_consumer_data(self, consumer: StreamConsumer) -> Optional[Dict[str, Any]]:
        """Prepare data packet for specific consumer"""
        try:
            data_packet = {
                'timestamp': datetime.now(),
                'consumer_id': consumer.consumer_id,
                'consumer_name': consumer.consumer_name
            }

            # Add requested data types
            if 'ticks' in consumer.data_types:
                data_packet['ticks'] = list(self.tick_cache)[-100:]  # Last 100 ticks

            if 'ohlcv' in consumer.data_types:
                data_packet['one_second_bars'] = list(self.one_second_bars)[-100:]
                data_packet['multi_timeframe'] = self._get_multi_timeframe_snapshot()

            if 'training_data' in consumer.data_types:
                if self.training_data_buffer:
                    data_packet['training_data'] = self.training_data_buffer[-1]

            if 'ui_data' in consumer.data_types:
                if self.ui_data_buffer:
                    data_packet['ui_data'] = self.ui_data_buffer[-1]

            return data_packet

        except Exception as e:
            logger.error(f"Error preparing data for consumer {consumer.consumer_name}: {e}")
            return None

    def _get_multi_timeframe_snapshot(self) -> Dict[str, Dict[str, List[Dict[str, Any]]]]:
        """Get snapshot of multi-timeframe data"""
        snapshot = {}
        for symbol, timeframes in self.multi_timeframe_data.items():
            snapshot[symbol] = {}
            for timeframe, data in timeframes.items():
                snapshot[symbol][timeframe] = list(data)
        return snapshot

    def _build_market_state(self) -> Optional[MarketState]:
        """Build market state for training"""
        try:
            if not self.orchestrator:
                return None

            # Get universal stream
            universal_stream = self._get_universal_stream()
            if not universal_stream:
                return None

            # Build market state using orchestrator data
            symbol = 'ETH/USDT'
            current_price = self.last_prices.get(symbol, 0.0)

            market_state = MarketState(
                symbol=symbol,
                timestamp=datetime.now(),
                prices={'current': current_price},
                features={},
                volatility=0.0,
                volume=0.0,
                trend_strength=0.0,
                market_regime='unknown',
                universal_data=universal_stream,
                raw_ticks=list(self.tick_cache)[-300:],
                ohlcv_data=self._get_multi_timeframe_snapshot(),
                btc_reference_data=self._get_btc_reference_data(),
                cnn_hidden_features=self.cnn_features_cache.copy(),
                cnn_predictions=self.cnn_predictions_cache.copy()
            )

            return market_state

        except Exception as e:
            logger.error(f"Error building market state: {e}")
            return None

    def _get_universal_stream(self) -> Optional[UniversalDataStream]:
        """Get universal data stream"""
        try:
            if self.universal_adapter:
                return self.universal_adapter.get_universal_stream()
            return None
        except Exception as e:
            logger.error(f"Error getting universal stream: {e}")
            return None

    def _get_btc_reference_data(self) -> Dict[str, List[Dict[str, Any]]]:
        """Get BTC reference data"""
        btc_data = {}
        if 'BTC/USDT' in self.multi_timeframe_data:
            for timeframe, data in self.multi_timeframe_data['BTC/USDT'].items():
                btc_data[timeframe] = list(data)
        return btc_data

    def _get_model_training_status(self) -> Dict[str, Any]:
        """Get model training status"""
        try:
            if self.orchestrator and hasattr(self.orchestrator, 'get_performance_metrics'):
                return self.orchestrator.get_performance_metrics()

            return {
                'cnn_status': 'TRAINING',
                'rl_status': 'TRAINING',
                'data_available': len(self.training_data_buffer) > 0
            }
        except Exception as e:
            logger.error(f"Error getting model training status: {e}")
            return {}

    def _get_orchestrator_status(self) -> Dict[str, Any]:
        """Get orchestrator status"""
        try:
            if self.orchestrator:
                return {
                    'active': True,
                    'symbols': self.config.symbols,
                    'streaming': self.streaming,
                    'tick_processor_active': hasattr(self.orchestrator, 'tick_processor')
                }

            return {'active': False}
        except Exception as e:
            logger.error(f"Error getting orchestrator status: {e}")
            return {'active': False}

    def get_stream_stats(self) -> Dict[str, Any]:
        """Get stream statistics"""
        stats = self.stream_stats.copy()
        stats.update({
            'tick_cache_size': len(self.tick_cache),
            'one_second_bars_count': len(self.one_second_bars),
            'training_data_packets': len(self.training_data_buffer),
            'ui_data_packets': len(self.ui_data_buffer),
            'active_consumers': len([c for c in self.consumers.values() if c.active]),
            'total_consumers': len(self.consumers)
        })
        return stats

    def get_latest_training_data(self) -> Optional[TrainingDataPacket]:
        """Get latest training data packet"""
        if self.training_data_buffer:
            return self.training_data_buffer[-1]
        return None

    def get_latest_ui_data(self) -> Optional[UIDataPacket]:
        """Get latest UI data packet"""
        if self.ui_data_buffer:
            return self.ui_data_buffer[-1]
        return None