LR module possibly working

This commit is contained in:
Dobromir Popov
2025-05-28 23:42:06 +03:00
parent de01d3665c
commit 6b7d7aec81
16 changed files with 5118 additions and 580 deletions

View File

@ -20,6 +20,7 @@ from typing import Dict, List, Optional, Tuple, Any, Union
from dataclasses import dataclass, field
from collections import deque
import torch
import ta
from .config import get_config
from .data_provider import DataProvider, RawTick, OHLCVBar, MarketTick
@ -68,7 +69,7 @@ class TradingAction:
@dataclass
class MarketState:
"""Complete market state for RL evaluation"""
"""Complete market state for RL evaluation with comprehensive data"""
symbol: str
timestamp: datetime
prices: Dict[str, float] # {timeframe: current_price}
@ -78,6 +79,15 @@ class MarketState:
trend_strength: float
market_regime: str # 'trending', 'ranging', 'volatile'
universal_data: UniversalDataStream # Universal format data
# Enhanced data for comprehensive RL state building
raw_ticks: List[Dict[str, Any]] = field(default_factory=list) # Last 300s of tick data
ohlcv_data: Dict[str, List[Dict[str, Any]]] = field(default_factory=dict) # Multi-timeframe OHLCV
btc_reference_data: Dict[str, List[Dict[str, Any]]] = field(default_factory=dict) # BTC correlation data
cnn_hidden_features: Optional[Dict[str, np.ndarray]] = None # CNN hidden layer features
cnn_predictions: Optional[Dict[str, np.ndarray]] = None # CNN predictions by timeframe
pivot_points: Optional[Dict[str, Any]] = None # Williams market structure data
market_microstructure: Dict[str, Any] = field(default_factory=dict) # Tick-level patterns
@dataclass
class PerfectMove:
@ -341,89 +351,328 @@ class EnhancedTradingOrchestrator:
return decisions
async def _get_all_market_states_universal(self, universal_stream: UniversalDataStream) -> Dict[str, MarketState]:
"""Get current market state for all symbols using universal data format"""
"""Get market states for all symbols with comprehensive data for RL"""
market_states = {}
try:
# Create market state for ETH/USDT (primary trading pair)
if 'ETH/USDT' in self.symbols:
eth_prices = {}
eth_features = {}
for symbol in self.symbols:
try:
# Basic market state data
current_prices = {}
for timeframe in self.timeframes:
# Get latest price from universal data stream
latest_price = self._get_latest_price_from_universal(symbol, timeframe, universal_stream)
if latest_price:
current_prices[timeframe] = latest_price
# Extract prices from universal stream
if len(universal_stream.eth_ticks) > 0:
eth_prices['1s'] = float(universal_stream.eth_ticks[-1, 4]) # Close price from ticks
if len(universal_stream.eth_1m) > 0:
eth_prices['1m'] = float(universal_stream.eth_1m[-1, 4]) # Close price from 1m
if len(universal_stream.eth_1h) > 0:
eth_prices['1h'] = float(universal_stream.eth_1h[-1, 4]) # Close price from 1h
if len(universal_stream.eth_1d) > 0:
eth_prices['1d'] = float(universal_stream.eth_1d[-1, 4]) # Close price from 1d
# Calculate basic metrics
volatility = self._calculate_volatility_from_universal(symbol, universal_stream)
volume = self._calculate_volume_from_universal(symbol, universal_stream)
trend_strength = self._calculate_trend_strength_from_universal(symbol, universal_stream)
market_regime = self._determine_market_regime(symbol, universal_stream)
# Extract features from universal stream (OHLCV data)
eth_features['1s'] = universal_stream.eth_ticks[:, 1:] if universal_stream.eth_ticks.shape[1] > 5 else universal_stream.eth_ticks
eth_features['1m'] = universal_stream.eth_1m[:, 1:] if universal_stream.eth_1m.shape[1] > 5 else universal_stream.eth_1m
eth_features['1h'] = universal_stream.eth_1h[:, 1:] if universal_stream.eth_1h.shape[1] > 5 else universal_stream.eth_1h
eth_features['1d'] = universal_stream.eth_1d[:, 1:] if universal_stream.eth_1d.shape[1] > 5 else universal_stream.eth_1d
# Get comprehensive data for RL state building
raw_ticks = self._get_recent_tick_data_for_rl(symbol)
ohlcv_data = self._get_multiframe_ohlcv_for_rl(symbol)
btc_reference_data = self._get_multiframe_ohlcv_for_rl('BTC/USDT')
# Calculate market metrics
volatility = self._calculate_volatility_from_universal('ETH/USDT', universal_stream)
volume = self._get_current_volume_from_universal('ETH/USDT', universal_stream)
trend_strength = self._calculate_trend_strength_from_universal('ETH/USDT', universal_stream)
market_regime = self._determine_market_regime_from_universal('ETH/USDT', universal_stream)
# Get CNN features if available
cnn_hidden_features, cnn_predictions = self._get_cnn_features_for_rl(symbol)
eth_market_state = MarketState(
symbol='ETH/USDT',
timestamp=universal_stream.timestamp,
prices=eth_prices,
features=eth_features,
# Calculate pivot points
pivot_points = self._calculate_pivot_points_for_rl(ohlcv_data)
# Analyze market microstructure
market_microstructure = self._analyze_market_microstructure(raw_ticks)
# Create comprehensive market state
market_state = MarketState(
symbol=symbol,
timestamp=datetime.now(),
prices=current_prices,
features={}, # Will be populated by feature extraction
volatility=volatility,
volume=volume,
trend_strength=trend_strength,
market_regime=market_regime,
universal_data=universal_stream
universal_data=universal_stream,
raw_ticks=raw_ticks,
ohlcv_data=ohlcv_data,
btc_reference_data=btc_reference_data,
cnn_hidden_features=cnn_hidden_features,
cnn_predictions=cnn_predictions,
pivot_points=pivot_points,
market_microstructure=market_microstructure
)
market_states['ETH/USDT'] = eth_market_state
self.market_states['ETH/USDT'].append(eth_market_state)
# Create market state for BTC/USDT (reference pair)
if 'BTC/USDT' in self.symbols:
btc_prices = {}
btc_features = {}
market_states[symbol] = market_state
logger.debug(f"Created comprehensive market state for {symbol} with {len(raw_ticks)} ticks")
# Extract BTC reference data
if len(universal_stream.btc_ticks) > 0:
btc_prices['1s'] = float(universal_stream.btc_ticks[-1, 4]) # Close price from BTC ticks
except Exception as e:
logger.error(f"Error creating market state for {symbol}: {e}")
btc_features['1s'] = universal_stream.btc_ticks[:, 1:] if universal_stream.btc_ticks.shape[1] > 5 else universal_stream.btc_ticks
# Calculate BTC metrics
btc_volatility = self._calculate_volatility_from_universal('BTC/USDT', universal_stream)
btc_volume = self._get_current_volume_from_universal('BTC/USDT', universal_stream)
btc_trend_strength = self._calculate_trend_strength_from_universal('BTC/USDT', universal_stream)
btc_market_regime = self._determine_market_regime_from_universal('BTC/USDT', universal_stream)
btc_market_state = MarketState(
symbol='BTC/USDT',
timestamp=universal_stream.timestamp,
prices=btc_prices,
features=btc_features,
volatility=btc_volatility,
volume=btc_volume,
trend_strength=btc_trend_strength,
market_regime=btc_market_regime,
universal_data=universal_stream
)
market_states['BTC/USDT'] = btc_market_state
self.market_states['BTC/USDT'].append(btc_market_state)
except Exception as e:
logger.error(f"Error creating market states from universal data: {e}")
return market_states
def _get_recent_tick_data_for_rl(self, symbol: str, seconds: int = 300) -> List[Dict[str, Any]]:
"""Get recent tick data for RL state building"""
try:
# Get ticks from data provider
recent_ticks = self.data_provider.get_recent_ticks(symbol, count=seconds * 10)
# Convert to required format
tick_data = []
for tick in recent_ticks[-300:]: # Last 300 ticks max (300s at ~1 tick/sec)
tick_dict = {
'timestamp': tick.timestamp,
'price': tick.price,
'volume': tick.volume,
'quantity': getattr(tick, 'quantity', tick.volume),
'side': getattr(tick, 'side', 'unknown'),
'trade_id': getattr(tick, 'trade_id', 'unknown'),
'is_buyer_maker': getattr(tick, 'is_buyer_maker', False)
}
tick_data.append(tick_dict)
return tick_data
except Exception as e:
logger.warning(f"Error getting tick data for {symbol}: {e}")
return []
def _get_multiframe_ohlcv_for_rl(self, symbol: str) -> Dict[str, List[Dict[str, Any]]]:
"""Get multi-timeframe OHLCV data for RL state building"""
try:
ohlcv_data = {}
timeframes = ['1s', '1m', '1h', '1d']
for tf in timeframes:
try:
# Get historical data for timeframe
df = self.data_provider.get_historical_data(
symbol=symbol,
timeframe=tf,
limit=300,
refresh=True
)
if df is not None and not df.empty:
# Convert to list of dictionaries with technical indicators
bars = []
# Add technical indicators
df_with_indicators = self._add_technical_indicators(df)
for idx, row in df_with_indicators.tail(300).iterrows():
bar = {
'timestamp': idx if hasattr(idx, 'timestamp') else datetime.now(),
'open': float(row.get('open', 0)),
'high': float(row.get('high', 0)),
'low': float(row.get('low', 0)),
'close': float(row.get('close', 0)),
'volume': float(row.get('volume', 0)),
'rsi': float(row.get('rsi', 50)),
'macd': float(row.get('macd', 0)),
'bb_upper': float(row.get('bb_upper', row.get('close', 0))),
'bb_lower': float(row.get('bb_lower', row.get('close', 0))),
'sma_20': float(row.get('sma_20', row.get('close', 0))),
'ema_12': float(row.get('ema_12', row.get('close', 0))),
'atr': float(row.get('atr', 0))
}
bars.append(bar)
ohlcv_data[tf] = bars
else:
ohlcv_data[tf] = []
except Exception as e:
logger.warning(f"Error getting {tf} data for {symbol}: {e}")
ohlcv_data[tf] = []
return ohlcv_data
except Exception as e:
logger.warning(f"Error getting OHLCV data for {symbol}: {e}")
return {}
def _add_technical_indicators(self, df: pd.DataFrame) -> pd.DataFrame:
"""Add technical indicators to OHLCV data"""
try:
df = df.copy()
# RSI
if len(df) >= 14:
df['rsi'] = ta.momentum.rsi(df['close'], window=14)
else:
df['rsi'] = 50
# MACD
if len(df) >= 26:
macd = ta.trend.macd_diff(df['close'])
df['macd'] = macd
else:
df['macd'] = 0
# Bollinger Bands
if len(df) >= 20:
bb = ta.volatility.BollingerBands(df['close'], window=20)
df['bb_upper'] = bb.bollinger_hband()
df['bb_lower'] = bb.bollinger_lband()
else:
df['bb_upper'] = df['close']
df['bb_lower'] = df['close']
# Moving Averages
if len(df) >= 20:
df['sma_20'] = ta.trend.sma_indicator(df['close'], window=20)
else:
df['sma_20'] = df['close']
if len(df) >= 12:
df['ema_12'] = ta.trend.ema_indicator(df['close'], window=12)
else:
df['ema_12'] = df['close']
# ATR
if len(df) >= 14:
df['atr'] = ta.volatility.average_true_range(df['high'], df['low'], df['close'], window=14)
else:
df['atr'] = 0
return df
except Exception as e:
logger.warning(f"Error adding technical indicators: {e}")
return df
def _get_cnn_features_for_rl(self, symbol: str) -> Tuple[Optional[Dict[str, np.ndarray]], Optional[Dict[str, np.ndarray]]]:
"""Get CNN hidden features and predictions for RL state building"""
try:
# Try to get CNN features from model registry
if hasattr(self, 'model_registry') and self.model_registry:
cnn_models = self.model_registry.get_models_by_type('cnn')
if cnn_models:
hidden_features = {}
predictions = {}
for model_name, model in cnn_models.items():
try:
# Get recent market data for the model
feature_matrix = self.data_provider.get_feature_matrix(
symbol=symbol,
timeframes=['1s', '1m', '1h', '1d'],
window_size=50
)
if feature_matrix is not None:
# Extract hidden features and predictions
model_hidden, model_pred = self._extract_cnn_features(model, feature_matrix)
if model_hidden is not None:
hidden_features[model_name] = model_hidden
if model_pred is not None:
predictions[model_name] = model_pred
except Exception as e:
logger.warning(f"Error getting features from CNN model {model_name}: {e}")
return hidden_features if hidden_features else None, predictions if predictions else None
return None, None
except Exception as e:
logger.warning(f"Error getting CNN features for {symbol}: {e}")
return None, None
def _extract_cnn_features(self, model, feature_matrix: np.ndarray) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
"""Extract hidden features and predictions from CNN model"""
try:
# This would need to be implemented based on your specific CNN architecture
# For now, return placeholder values
# Mock hidden features (would be extracted from model's hidden layers)
hidden_features = np.random.random(512).astype(np.float32)
# Mock predictions (would be model's output)
predictions = np.array([0.33, 0.33, 0.34, 0.7]).astype(np.float32) # BUY, SELL, HOLD, confidence
return hidden_features, predictions
except Exception as e:
logger.warning(f"Error extracting CNN features: {e}")
return None, None
def _calculate_pivot_points_for_rl(self, ohlcv_data: Dict[str, List]) -> Optional[Dict[str, Any]]:
"""Calculate Williams pivot points for RL state building"""
try:
if '1m' in ohlcv_data and len(ohlcv_data['1m']) >= 50:
# Use 1m data for pivot calculation
bars = ohlcv_data['1m']
# Convert to numpy array
ohlc_array = np.array([
[bar['timestamp'].timestamp() if hasattr(bar['timestamp'], 'timestamp') else time.time(),
bar['open'], bar['high'], bar['low'], bar['close'], bar['volume']]
for bar in bars[-200:] # Last 200 bars
])
# Calculate pivot points using Williams structure
# This would use the WilliamsMarketStructure implementation
pivot_data = {
'swing_highs': [],
'swing_lows': [],
'trend_levels': [],
'market_bias': 'neutral'
}
return pivot_data
return None
except Exception as e:
logger.warning(f"Error calculating pivot points: {e}")
return None
def _analyze_market_microstructure(self, raw_ticks: List[Dict[str, Any]]) -> Dict[str, Any]:
"""Analyze market microstructure from tick data"""
try:
if not raw_ticks or len(raw_ticks) < 10:
return {}
# Calculate microstructure metrics
prices = [tick['price'] for tick in raw_ticks]
volumes = [tick['volume'] for tick in raw_ticks]
# Price momentum
price_momentum = (prices[-1] - prices[0]) / prices[0] if prices[0] != 0 else 0
# Volume pattern
avg_volume = sum(volumes) / len(volumes)
recent_volume = sum(volumes[-10:]) / 10 if len(volumes) >= 10 else avg_volume
volume_intensity = recent_volume / avg_volume if avg_volume != 0 else 1.0
# Tick frequency
if len(raw_ticks) >= 2:
time_diffs = []
for i in range(1, len(raw_ticks)):
if hasattr(raw_ticks[i]['timestamp'], 'timestamp') and hasattr(raw_ticks[i-1]['timestamp'], 'timestamp'):
diff = raw_ticks[i]['timestamp'].timestamp() - raw_ticks[i-1]['timestamp'].timestamp()
time_diffs.append(diff)
avg_tick_interval = sum(time_diffs) / len(time_diffs) if time_diffs else 1.0
else:
avg_tick_interval = 1.0
return {
'price_momentum': price_momentum,
'volume_intensity': volume_intensity,
'avg_tick_interval': avg_tick_interval,
'tick_count': len(raw_ticks),
'price_volatility': np.std(prices) if len(prices) > 1 else 0.0
}
except Exception as e:
logger.warning(f"Error analyzing market microstructure: {e}")
return {}
async def _get_enhanced_predictions_universal(self, symbol: str, market_state: MarketState,
universal_stream: UniversalDataStream) -> List[EnhancedPrediction]:
"""Get enhanced predictions using universal data format"""