LR module possibly working
This commit is contained in:
708
training/enhanced_rl_state_builder.py
Normal file
708
training/enhanced_rl_state_builder.py
Normal file
@@ -0,0 +1,708 @@
|
||||
"""
|
||||
Enhanced RL State Builder for Comprehensive Market Data Integration
|
||||
|
||||
This module implements the specification requirements for RL training with:
|
||||
- 300s of raw tick data for momentum detection
|
||||
- Multi-timeframe OHLCV data (1s, 1m, 1h, 1d) for ETH and BTC
|
||||
- CNN hidden layer features integration
|
||||
- CNN predictions from all timeframes
|
||||
- Pivot point predictions using Williams market structure
|
||||
- Market regime analysis
|
||||
|
||||
State Vector Components:
|
||||
- ETH tick data: ~3000 features (300s * 10 features/tick)
|
||||
- ETH OHLCV 1s: ~2400 features (300 bars * 8 features)
|
||||
- ETH OHLCV 1m: ~2400 features (300 bars * 8 features)
|
||||
- ETH OHLCV 1h: ~2400 features (300 bars * 8 features)
|
||||
- ETH OHLCV 1d: ~2400 features (300 bars * 8 features)
|
||||
- BTC reference: ~2400 features (300 bars * 8 features)
|
||||
- CNN features: ~512 features (hidden layer)
|
||||
- CNN predictions: ~16 features (4 timeframes * 4 outputs)
|
||||
- Pivot points: ~250 features (Williams structure)
|
||||
- Market regime: ~20 features
|
||||
Total: ~8000+ features
|
||||
"""
|
||||
|
||||
import logging
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
try:
|
||||
import ta
|
||||
except ImportError:
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning("TA-Lib not available, using pandas for technical indicators")
|
||||
ta = None
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import dataclass
|
||||
|
||||
from core.universal_data_adapter import UniversalDataStream
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class TickData:
|
||||
"""Tick data structure"""
|
||||
timestamp: datetime
|
||||
price: float
|
||||
volume: float
|
||||
bid: float = 0.0
|
||||
ask: float = 0.0
|
||||
|
||||
@property
|
||||
def spread(self) -> float:
|
||||
return self.ask - self.bid if self.ask > 0 and self.bid > 0 else 0.0
|
||||
|
||||
@dataclass
|
||||
class OHLCVData:
|
||||
"""OHLCV data structure"""
|
||||
timestamp: datetime
|
||||
open: float
|
||||
high: float
|
||||
low: float
|
||||
close: float
|
||||
volume: float
|
||||
|
||||
# Technical indicators (optional)
|
||||
rsi: Optional[float] = None
|
||||
macd: Optional[float] = None
|
||||
bb_upper: Optional[float] = None
|
||||
bb_lower: Optional[float] = None
|
||||
sma_20: Optional[float] = None
|
||||
ema_12: Optional[float] = None
|
||||
atr: Optional[float] = None
|
||||
|
||||
@dataclass
|
||||
class StateComponentConfig:
|
||||
"""Configuration for state component sizes"""
|
||||
eth_ticks: int = 3000 # 300s * 10 features per tick
|
||||
eth_1s_ohlcv: int = 2400 # 300 bars * 8 features (OHLCV + indicators)
|
||||
eth_1m_ohlcv: int = 2400 # 300 bars * 8 features
|
||||
eth_1h_ohlcv: int = 2400 # 300 bars * 8 features
|
||||
eth_1d_ohlcv: int = 2400 # 300 bars * 8 features
|
||||
btc_reference: int = 2400 # BTC reference data
|
||||
cnn_features: int = 512 # CNN hidden layer features
|
||||
cnn_predictions: int = 16 # CNN predictions (4 timeframes * 4 outputs)
|
||||
pivot_points: int = 250 # Recursive pivot points (5 levels * 50 points)
|
||||
market_regime: int = 20 # Market regime features
|
||||
|
||||
@property
|
||||
def total_size(self) -> int:
|
||||
"""Calculate total state size"""
|
||||
return (self.eth_ticks + self.eth_1s_ohlcv + self.eth_1m_ohlcv +
|
||||
self.eth_1h_ohlcv + self.eth_1d_ohlcv + self.btc_reference +
|
||||
self.cnn_features + self.cnn_predictions + self.pivot_points +
|
||||
self.market_regime)
|
||||
|
||||
class EnhancedRLStateBuilder:
|
||||
"""
|
||||
Comprehensive RL state builder implementing specification requirements
|
||||
|
||||
Features:
|
||||
- 300s tick data processing with momentum detection
|
||||
- Multi-timeframe OHLCV integration
|
||||
- CNN hidden layer feature extraction
|
||||
- Pivot point calculation and integration
|
||||
- Market regime analysis
|
||||
- BTC reference data processing
|
||||
"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
self.config = config
|
||||
|
||||
# Data windows
|
||||
self.tick_window_seconds = 300 # 5 minutes of tick data
|
||||
self.ohlcv_window_bars = 300 # 300 bars for each timeframe
|
||||
|
||||
# State component sizes
|
||||
self.state_components = {
|
||||
'eth_ticks': 300 * 10, # 3000 features: tick data with derived features
|
||||
'eth_1s_ohlcv': 300 * 8, # 2400 features: OHLCV + indicators
|
||||
'eth_1m_ohlcv': 300 * 8, # 2400 features: OHLCV + indicators
|
||||
'eth_1h_ohlcv': 300 * 8, # 2400 features: OHLCV + indicators
|
||||
'eth_1d_ohlcv': 300 * 8, # 2400 features: OHLCV + indicators
|
||||
'btc_reference': 300 * 8, # 2400 features: BTC reference data
|
||||
'cnn_features': 512, # 512 features: CNN hidden layer
|
||||
'cnn_predictions': 16, # 16 features: CNN predictions (4 timeframes * 4 outputs)
|
||||
'pivot_points': 250, # 250 features: Williams market structure
|
||||
'market_regime': 20 # 20 features: Market regime indicators
|
||||
}
|
||||
|
||||
self.total_state_size = sum(self.state_components.values())
|
||||
|
||||
# Data buffers for maintaining windows
|
||||
self.tick_buffers = {}
|
||||
self.ohlcv_buffers = {}
|
||||
|
||||
# Normalization parameters
|
||||
self.normalization_params = self._initialize_normalization_params()
|
||||
|
||||
# Feature extractors
|
||||
self.momentum_detector = TickMomentumDetector()
|
||||
self.indicator_calculator = TechnicalIndicatorCalculator()
|
||||
self.regime_analyzer = MarketRegimeAnalyzer()
|
||||
|
||||
logger.info(f"Enhanced RL State Builder initialized")
|
||||
logger.info(f"Total state size: {self.total_state_size} features")
|
||||
logger.info(f"State components: {self.state_components}")
|
||||
|
||||
def build_rl_state(self,
|
||||
eth_ticks: List[TickData],
|
||||
eth_ohlcv: Dict[str, List[OHLCVData]],
|
||||
btc_ohlcv: Dict[str, List[OHLCVData]],
|
||||
cnn_hidden_features: Optional[Dict[str, np.ndarray]] = None,
|
||||
cnn_predictions: Optional[Dict[str, np.ndarray]] = None,
|
||||
pivot_data: Optional[Dict[str, Any]] = None) -> np.ndarray:
|
||||
"""
|
||||
Build comprehensive RL state vector from all data sources
|
||||
|
||||
Args:
|
||||
eth_ticks: List of ETH tick data (last 300s)
|
||||
eth_ohlcv: Dict of ETH OHLCV data by timeframe
|
||||
btc_ohlcv: Dict of BTC OHLCV data by timeframe
|
||||
cnn_hidden_features: CNN hidden layer features by timeframe
|
||||
cnn_predictions: CNN predictions by timeframe
|
||||
pivot_data: Pivot point data from Williams analysis
|
||||
|
||||
Returns:
|
||||
np.ndarray: Comprehensive state vector (~8000+ features)
|
||||
"""
|
||||
try:
|
||||
state_vector = []
|
||||
|
||||
# 1. Process ETH tick data (3000 features)
|
||||
tick_features = self._process_tick_data(eth_ticks)
|
||||
state_vector.extend(tick_features)
|
||||
|
||||
# 2. Process ETH multi-timeframe OHLCV (9600 features total)
|
||||
for timeframe in ['1s', '1m', '1h', '1d']:
|
||||
if timeframe in eth_ohlcv:
|
||||
ohlcv_features = self._process_ohlcv_data(
|
||||
eth_ohlcv[timeframe], timeframe, symbol='ETH'
|
||||
)
|
||||
else:
|
||||
ohlcv_features = np.zeros(self.state_components[f'eth_{timeframe}_ohlcv'])
|
||||
state_vector.extend(ohlcv_features)
|
||||
|
||||
# 3. Process BTC reference data (2400 features)
|
||||
btc_features = self._process_btc_reference_data(btc_ohlcv)
|
||||
state_vector.extend(btc_features)
|
||||
|
||||
# 4. Process CNN hidden layer features (512 features)
|
||||
cnn_hidden = self._process_cnn_hidden_features(cnn_hidden_features)
|
||||
state_vector.extend(cnn_hidden)
|
||||
|
||||
# 5. Process CNN predictions (16 features)
|
||||
cnn_pred = self._process_cnn_predictions(cnn_predictions)
|
||||
state_vector.extend(cnn_pred)
|
||||
|
||||
# 6. Process pivot points (250 features)
|
||||
pivot_features = self._process_pivot_points(pivot_data, eth_ohlcv)
|
||||
state_vector.extend(pivot_features)
|
||||
|
||||
# 7. Process market regime features (20 features)
|
||||
regime_features = self._process_market_regime(eth_ohlcv, btc_ohlcv)
|
||||
state_vector.extend(regime_features)
|
||||
|
||||
# Convert to numpy array and validate size
|
||||
state_array = np.array(state_vector, dtype=np.float32)
|
||||
|
||||
if len(state_array) != self.total_state_size:
|
||||
logger.warning(f"State size mismatch: expected {self.total_state_size}, got {len(state_array)}")
|
||||
# Pad or truncate to expected size
|
||||
if len(state_array) < self.total_state_size:
|
||||
padding = np.zeros(self.total_state_size - len(state_array))
|
||||
state_array = np.concatenate([state_array, padding])
|
||||
else:
|
||||
state_array = state_array[:self.total_state_size]
|
||||
|
||||
# Apply normalization
|
||||
state_array = self._normalize_state(state_array)
|
||||
|
||||
return state_array
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error building RL state: {e}")
|
||||
# Return zero state on error
|
||||
return np.zeros(self.total_state_size, dtype=np.float32)
|
||||
|
||||
def _process_tick_data(self, ticks: List[TickData]) -> List[float]:
|
||||
"""Process raw tick data into features for momentum detection"""
|
||||
features = []
|
||||
|
||||
if not ticks or len(ticks) < 10:
|
||||
# Return zeros if insufficient data
|
||||
return [0.0] * self.state_components['eth_ticks']
|
||||
|
||||
# Ensure we have exactly 300 data points (pad or sample)
|
||||
processed_ticks = self._normalize_tick_window(ticks, 300)
|
||||
|
||||
for i, tick in enumerate(processed_ticks):
|
||||
# Basic tick features
|
||||
tick_features = [
|
||||
tick.price,
|
||||
tick.volume,
|
||||
tick.bid,
|
||||
tick.ask,
|
||||
tick.spread
|
||||
]
|
||||
|
||||
# Derived features
|
||||
if i > 0:
|
||||
prev_tick = processed_ticks[i-1]
|
||||
price_change = (tick.price - prev_tick.price) / prev_tick.price if prev_tick.price > 0 else 0
|
||||
volume_change = (tick.volume - prev_tick.volume) / prev_tick.volume if prev_tick.volume > 0 else 0
|
||||
|
||||
tick_features.extend([
|
||||
price_change,
|
||||
volume_change,
|
||||
tick.price / prev_tick.price - 1.0 if prev_tick.price > 0 else 0, # Price ratio
|
||||
np.log(tick.volume / prev_tick.volume) if prev_tick.volume > 0 else 0, # Log volume ratio
|
||||
self.momentum_detector.calculate_micro_momentum(processed_ticks[max(0, i-5):i+1])
|
||||
])
|
||||
else:
|
||||
tick_features.extend([0.0, 0.0, 0.0, 0.0, 0.0])
|
||||
|
||||
features.extend(tick_features)
|
||||
|
||||
return features[:self.state_components['eth_ticks']]
|
||||
|
||||
def _process_ohlcv_data(self, ohlcv_data: List[OHLCVData],
|
||||
timeframe: str, symbol: str = 'ETH') -> List[float]:
|
||||
"""Process OHLCV data with technical indicators"""
|
||||
features = []
|
||||
|
||||
if not ohlcv_data or len(ohlcv_data) < 20:
|
||||
component_key = f'{symbol.lower()}_{timeframe}_ohlcv' if symbol == 'ETH' else 'btc_reference'
|
||||
return [0.0] * self.state_components[component_key]
|
||||
|
||||
# Convert to DataFrame for indicator calculation
|
||||
df = pd.DataFrame([{
|
||||
'timestamp': bar.timestamp,
|
||||
'open': bar.open,
|
||||
'high': bar.high,
|
||||
'low': bar.low,
|
||||
'close': bar.close,
|
||||
'volume': bar.volume
|
||||
} for bar in ohlcv_data[-self.ohlcv_window_bars:]])
|
||||
|
||||
# Calculate technical indicators
|
||||
df = self.indicator_calculator.add_all_indicators(df)
|
||||
|
||||
# Ensure we have exactly 300 bars
|
||||
if len(df) < 300:
|
||||
# Pad with last known values
|
||||
last_row = df.iloc[-1:].copy()
|
||||
padding_rows = []
|
||||
for _ in range(300 - len(df)):
|
||||
padding_rows.append(last_row)
|
||||
if padding_rows:
|
||||
df = pd.concat([df] + padding_rows, ignore_index=True)
|
||||
else:
|
||||
df = df.tail(300)
|
||||
|
||||
# Extract features for each bar
|
||||
feature_columns = ['open', 'high', 'low', 'close', 'volume', 'rsi', 'macd', 'bb_middle']
|
||||
|
||||
for _, row in df.iterrows():
|
||||
bar_features = []
|
||||
for col in feature_columns:
|
||||
if col in row and not pd.isna(row[col]):
|
||||
bar_features.append(float(row[col]))
|
||||
else:
|
||||
bar_features.append(0.0)
|
||||
features.extend(bar_features)
|
||||
|
||||
component_key = f'{symbol.lower()}_{timeframe}_ohlcv' if symbol == 'ETH' else 'btc_reference'
|
||||
return features[:self.state_components[component_key]]
|
||||
|
||||
def _process_btc_reference_data(self, btc_ohlcv: Dict[str, List[OHLCVData]]) -> List[float]:
|
||||
"""Process BTC reference data (using 1h timeframe as primary)"""
|
||||
if '1h' in btc_ohlcv and btc_ohlcv['1h']:
|
||||
return self._process_ohlcv_data(btc_ohlcv['1h'], '1h', 'BTC')
|
||||
elif '1m' in btc_ohlcv and btc_ohlcv['1m']:
|
||||
return self._process_ohlcv_data(btc_ohlcv['1m'], '1m', 'BTC')
|
||||
else:
|
||||
return [0.0] * self.state_components['btc_reference']
|
||||
|
||||
def _process_cnn_hidden_features(self, cnn_features: Optional[Dict[str, np.ndarray]]) -> List[float]:
|
||||
"""Process CNN hidden layer features"""
|
||||
if not cnn_features:
|
||||
return [0.0] * self.state_components['cnn_features']
|
||||
|
||||
# Combine features from all timeframes
|
||||
combined_features = []
|
||||
timeframes = ['1s', '1m', '1h', '1d']
|
||||
features_per_timeframe = self.state_components['cnn_features'] // len(timeframes)
|
||||
|
||||
for tf in timeframes:
|
||||
if tf in cnn_features and cnn_features[tf] is not None:
|
||||
tf_features = cnn_features[tf].flatten()
|
||||
# Truncate or pad to fit allocation
|
||||
if len(tf_features) >= features_per_timeframe:
|
||||
combined_features.extend(tf_features[:features_per_timeframe])
|
||||
else:
|
||||
combined_features.extend(tf_features)
|
||||
combined_features.extend([0.0] * (features_per_timeframe - len(tf_features)))
|
||||
else:
|
||||
combined_features.extend([0.0] * features_per_timeframe)
|
||||
|
||||
return combined_features[:self.state_components['cnn_features']]
|
||||
|
||||
def _process_cnn_predictions(self, cnn_predictions: Optional[Dict[str, np.ndarray]]) -> List[float]:
|
||||
"""Process CNN predictions from all timeframes"""
|
||||
if not cnn_predictions:
|
||||
return [0.0] * self.state_components['cnn_predictions']
|
||||
|
||||
predictions = []
|
||||
timeframes = ['1s', '1m', '1h', '1d']
|
||||
|
||||
for tf in timeframes:
|
||||
if tf in cnn_predictions and cnn_predictions[tf] is not None:
|
||||
pred = cnn_predictions[tf].flatten()
|
||||
# Expecting 4 outputs per timeframe (BUY, SELL, HOLD, confidence)
|
||||
if len(pred) >= 4:
|
||||
predictions.extend(pred[:4])
|
||||
else:
|
||||
predictions.extend(pred)
|
||||
predictions.extend([0.0] * (4 - len(pred)))
|
||||
else:
|
||||
predictions.extend([0.0, 0.0, 1.0, 0.0]) # Default to HOLD with 0 confidence
|
||||
|
||||
return predictions[:self.state_components['cnn_predictions']]
|
||||
|
||||
def _process_pivot_points(self, pivot_data: Optional[Dict[str, Any]],
|
||||
eth_ohlcv: Dict[str, List[OHLCVData]]) -> List[float]:
|
||||
"""Process pivot points using Williams market structure"""
|
||||
if pivot_data:
|
||||
# Use provided pivot data
|
||||
return self._extract_pivot_features(pivot_data)
|
||||
elif '1m' in eth_ohlcv and eth_ohlcv['1m']:
|
||||
# Calculate pivot points from 1m data
|
||||
from training.williams_market_structure import WilliamsMarketStructure
|
||||
williams = WilliamsMarketStructure()
|
||||
|
||||
# Convert OHLCV to numpy array
|
||||
ohlcv_array = self._ohlcv_to_array(eth_ohlcv['1m'])
|
||||
pivot_data = williams.calculate_recursive_pivot_points(ohlcv_array)
|
||||
return self._extract_pivot_features(pivot_data)
|
||||
else:
|
||||
return [0.0] * self.state_components['pivot_points']
|
||||
|
||||
def _process_market_regime(self, eth_ohlcv: Dict[str, List[OHLCVData]],
|
||||
btc_ohlcv: Dict[str, List[OHLCVData]]) -> List[float]:
|
||||
"""Process market regime indicators"""
|
||||
regime_features = []
|
||||
|
||||
# ETH regime analysis
|
||||
if '1h' in eth_ohlcv and eth_ohlcv['1h']:
|
||||
eth_regime = self.regime_analyzer.analyze_regime(eth_ohlcv['1h'])
|
||||
regime_features.extend([
|
||||
eth_regime['volatility'],
|
||||
eth_regime['trend_strength'],
|
||||
eth_regime['volume_trend'],
|
||||
eth_regime['momentum'],
|
||||
1.0 if eth_regime['regime'] == 'trending' else 0.0,
|
||||
1.0 if eth_regime['regime'] == 'ranging' else 0.0,
|
||||
1.0 if eth_regime['regime'] == 'volatile' else 0.0
|
||||
])
|
||||
else:
|
||||
regime_features.extend([0.0] * 7)
|
||||
|
||||
# BTC regime analysis
|
||||
if '1h' in btc_ohlcv and btc_ohlcv['1h']:
|
||||
btc_regime = self.regime_analyzer.analyze_regime(btc_ohlcv['1h'])
|
||||
regime_features.extend([
|
||||
btc_regime['volatility'],
|
||||
btc_regime['trend_strength'],
|
||||
btc_regime['volume_trend'],
|
||||
btc_regime['momentum'],
|
||||
1.0 if btc_regime['regime'] == 'trending' else 0.0,
|
||||
1.0 if btc_regime['regime'] == 'ranging' else 0.0,
|
||||
1.0 if btc_regime['regime'] == 'volatile' else 0.0
|
||||
])
|
||||
else:
|
||||
regime_features.extend([0.0] * 7)
|
||||
|
||||
# Correlation features
|
||||
correlation_features = self._calculate_btc_eth_correlation(eth_ohlcv, btc_ohlcv)
|
||||
regime_features.extend(correlation_features)
|
||||
|
||||
return regime_features[:self.state_components['market_regime']]
|
||||
|
||||
def _normalize_tick_window(self, ticks: List[TickData], target_size: int) -> List[TickData]:
|
||||
"""Normalize tick window to target size"""
|
||||
if len(ticks) == target_size:
|
||||
return ticks
|
||||
elif len(ticks) > target_size:
|
||||
# Sample evenly
|
||||
step = len(ticks) / target_size
|
||||
indices = [int(i * step) for i in range(target_size)]
|
||||
return [ticks[i] for i in indices]
|
||||
else:
|
||||
# Pad with last tick
|
||||
result = ticks.copy()
|
||||
last_tick = ticks[-1] if ticks else TickData(datetime.now(), 0, 0)
|
||||
while len(result) < target_size:
|
||||
result.append(last_tick)
|
||||
return result
|
||||
|
||||
def _extract_pivot_features(self, pivot_data: Dict[str, Any]) -> List[float]:
|
||||
"""Extract features from pivot point data"""
|
||||
features = []
|
||||
|
||||
for level in range(5): # 5 levels of recursion
|
||||
level_key = f'level_{level}'
|
||||
if level_key in pivot_data:
|
||||
level_data = pivot_data[level_key]
|
||||
|
||||
# Swing point features
|
||||
swing_points = level_data.get('swing_points', [])
|
||||
if swing_points:
|
||||
# Last 10 swing points
|
||||
recent_swings = swing_points[-10:]
|
||||
for swing in recent_swings:
|
||||
features.extend([
|
||||
swing['price'],
|
||||
1.0 if swing['type'] == 'swing_high' else 0.0,
|
||||
swing['index']
|
||||
])
|
||||
|
||||
# Pad if fewer than 10 swings
|
||||
while len(recent_swings) < 10:
|
||||
features.extend([0.0, 0.0, 0.0])
|
||||
recent_swings.append({'type': 'none'})
|
||||
else:
|
||||
features.extend([0.0] * 30) # 10 swings * 3 features
|
||||
|
||||
# Trend features
|
||||
features.extend([
|
||||
level_data.get('trend_strength', 0.0),
|
||||
1.0 if level_data.get('trend_direction') == 'up' else 0.0,
|
||||
1.0 if level_data.get('trend_direction') == 'down' else 0.0
|
||||
])
|
||||
else:
|
||||
features.extend([0.0] * 33) # 30 swing + 3 trend features
|
||||
|
||||
return features[:self.state_components['pivot_points']]
|
||||
|
||||
def _ohlcv_to_array(self, ohlcv_data: List[OHLCVData]) -> np.ndarray:
|
||||
"""Convert OHLCV data to numpy array"""
|
||||
return np.array([[
|
||||
bar.timestamp.timestamp(),
|
||||
bar.open,
|
||||
bar.high,
|
||||
bar.low,
|
||||
bar.close,
|
||||
bar.volume
|
||||
] for bar in ohlcv_data])
|
||||
|
||||
def _calculate_btc_eth_correlation(self, eth_ohlcv: Dict[str, List[OHLCVData]],
|
||||
btc_ohlcv: Dict[str, List[OHLCVData]]) -> List[float]:
|
||||
"""Calculate BTC-ETH correlation features"""
|
||||
try:
|
||||
# Use 1h data for correlation
|
||||
if '1h' not in eth_ohlcv or '1h' not in btc_ohlcv:
|
||||
return [0.0] * 6
|
||||
|
||||
eth_prices = [bar.close for bar in eth_ohlcv['1h'][-50:]] # Last 50 hours
|
||||
btc_prices = [bar.close for bar in btc_ohlcv['1h'][-50:]]
|
||||
|
||||
if len(eth_prices) < 10 or len(btc_prices) < 10:
|
||||
return [0.0] * 6
|
||||
|
||||
# Align lengths
|
||||
min_len = min(len(eth_prices), len(btc_prices))
|
||||
eth_prices = eth_prices[-min_len:]
|
||||
btc_prices = btc_prices[-min_len:]
|
||||
|
||||
# Calculate returns
|
||||
eth_returns = np.diff(eth_prices) / eth_prices[:-1]
|
||||
btc_returns = np.diff(btc_prices) / btc_prices[:-1]
|
||||
|
||||
# Correlation
|
||||
correlation = np.corrcoef(eth_returns, btc_returns)[0, 1] if len(eth_returns) > 1 else 0.0
|
||||
|
||||
# Price ratio
|
||||
current_ratio = eth_prices[-1] / btc_prices[-1] if btc_prices[-1] > 0 else 0.0
|
||||
avg_ratio = np.mean([e/b for e, b in zip(eth_prices, btc_prices) if b > 0])
|
||||
ratio_deviation = (current_ratio - avg_ratio) / avg_ratio if avg_ratio > 0 else 0.0
|
||||
|
||||
# Volatility comparison
|
||||
eth_vol = np.std(eth_returns) if len(eth_returns) > 1 else 0.0
|
||||
btc_vol = np.std(btc_returns) if len(btc_returns) > 1 else 0.0
|
||||
vol_ratio = eth_vol / btc_vol if btc_vol > 0 else 1.0
|
||||
|
||||
return [
|
||||
correlation,
|
||||
current_ratio,
|
||||
ratio_deviation,
|
||||
vol_ratio,
|
||||
eth_vol,
|
||||
btc_vol
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error calculating BTC-ETH correlation: {e}")
|
||||
return [0.0] * 6
|
||||
|
||||
def _initialize_normalization_params(self) -> Dict[str, Dict[str, float]]:
|
||||
"""Initialize normalization parameters for different feature types"""
|
||||
return {
|
||||
'price_features': {'mean': 0.0, 'std': 1.0, 'min': -10.0, 'max': 10.0},
|
||||
'volume_features': {'mean': 0.0, 'std': 1.0, 'min': -5.0, 'max': 5.0},
|
||||
'indicator_features': {'mean': 0.0, 'std': 1.0, 'min': -3.0, 'max': 3.0},
|
||||
'cnn_features': {'mean': 0.0, 'std': 1.0, 'min': -2.0, 'max': 2.0},
|
||||
'pivot_features': {'mean': 0.0, 'std': 1.0, 'min': -5.0, 'max': 5.0}
|
||||
}
|
||||
|
||||
def _normalize_state(self, state: np.ndarray) -> np.ndarray:
|
||||
"""Apply normalization to state vector"""
|
||||
try:
|
||||
# Simple clipping and scaling for now
|
||||
# More sophisticated normalization can be added based on training data
|
||||
normalized_state = np.clip(state, -10.0, 10.0)
|
||||
|
||||
# Replace any NaN or inf values
|
||||
normalized_state = np.nan_to_num(normalized_state, nan=0.0, posinf=10.0, neginf=-10.0)
|
||||
|
||||
return normalized_state.astype(np.float32)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error normalizing state: {e}")
|
||||
return state.astype(np.float32)
|
||||
|
||||
class TickMomentumDetector:
|
||||
"""Detect momentum from tick-level data"""
|
||||
|
||||
def calculate_micro_momentum(self, ticks: List[TickData]) -> float:
|
||||
"""Calculate micro-momentum from tick sequence"""
|
||||
if len(ticks) < 2:
|
||||
return 0.0
|
||||
|
||||
# Price momentum
|
||||
prices = [tick.price for tick in ticks]
|
||||
price_changes = np.diff(prices)
|
||||
price_momentum = np.sum(price_changes) / len(price_changes) if len(price_changes) > 0 else 0.0
|
||||
|
||||
# Volume-weighted momentum
|
||||
volumes = [tick.volume for tick in ticks]
|
||||
if sum(volumes) > 0:
|
||||
weighted_changes = [pc * v for pc, v in zip(price_changes, volumes[1:])]
|
||||
volume_momentum = sum(weighted_changes) / sum(volumes[1:])
|
||||
else:
|
||||
volume_momentum = 0.0
|
||||
|
||||
return (price_momentum + volume_momentum) / 2.0
|
||||
|
||||
class TechnicalIndicatorCalculator:
|
||||
"""Calculate technical indicators for OHLCV data"""
|
||||
|
||||
def add_all_indicators(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""Add all technical indicators to DataFrame"""
|
||||
df = df.copy()
|
||||
|
||||
# RSI
|
||||
df['rsi'] = self.calculate_rsi(df['close'])
|
||||
|
||||
# MACD
|
||||
df['macd'] = self.calculate_macd(df['close'])
|
||||
|
||||
# Bollinger Bands
|
||||
df['bb_middle'] = df['close'].rolling(20).mean()
|
||||
df['bb_std'] = df['close'].rolling(20).std()
|
||||
df['bb_upper'] = df['bb_middle'] + (df['bb_std'] * 2)
|
||||
df['bb_lower'] = df['bb_middle'] - (df['bb_std'] * 2)
|
||||
|
||||
# Fill NaN values
|
||||
df = df.fillna(method='forward').fillna(0)
|
||||
|
||||
return df
|
||||
|
||||
def calculate_rsi(self, prices: pd.Series, period: int = 14) -> pd.Series:
|
||||
"""Calculate RSI"""
|
||||
delta = prices.diff()
|
||||
gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
|
||||
loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
|
||||
rs = gain / loss
|
||||
rsi = 100 - (100 / (1 + rs))
|
||||
return rsi.fillna(50)
|
||||
|
||||
def calculate_macd(self, prices: pd.Series, fast: int = 12, slow: int = 26) -> pd.Series:
|
||||
"""Calculate MACD"""
|
||||
ema_fast = prices.ewm(span=fast).mean()
|
||||
ema_slow = prices.ewm(span=slow).mean()
|
||||
macd = ema_fast - ema_slow
|
||||
return macd.fillna(0)
|
||||
|
||||
class MarketRegimeAnalyzer:
|
||||
"""Analyze market regime from OHLCV data"""
|
||||
|
||||
def analyze_regime(self, ohlcv_data: List[OHLCVData]) -> Dict[str, Any]:
|
||||
"""Analyze market regime"""
|
||||
if len(ohlcv_data) < 20:
|
||||
return {
|
||||
'regime': 'unknown',
|
||||
'volatility': 0.0,
|
||||
'trend_strength': 0.0,
|
||||
'volume_trend': 0.0,
|
||||
'momentum': 0.0
|
||||
}
|
||||
|
||||
prices = [bar.close for bar in ohlcv_data[-50:]] # Last 50 bars
|
||||
volumes = [bar.volume for bar in ohlcv_data[-50:]]
|
||||
|
||||
# Calculate volatility
|
||||
returns = np.diff(prices) / prices[:-1]
|
||||
volatility = np.std(returns) * 100 # Percentage volatility
|
||||
|
||||
# Calculate trend strength
|
||||
sma_short = np.mean(prices[-10:])
|
||||
sma_long = np.mean(prices[-30:])
|
||||
trend_strength = abs(sma_short - sma_long) / sma_long if sma_long > 0 else 0.0
|
||||
|
||||
# Volume trend
|
||||
volume_ma_short = np.mean(volumes[-10:])
|
||||
volume_ma_long = np.mean(volumes[-30:])
|
||||
volume_trend = (volume_ma_short - volume_ma_long) / volume_ma_long if volume_ma_long > 0 else 0.0
|
||||
|
||||
# Momentum
|
||||
momentum = (prices[-1] - prices[-10]) / prices[-10] if len(prices) >= 10 and prices[-10] > 0 else 0.0
|
||||
|
||||
# Determine regime
|
||||
if volatility > 3.0: # High volatility
|
||||
regime = 'volatile'
|
||||
elif abs(momentum) > 0.02: # Strong momentum
|
||||
regime = 'trending'
|
||||
else:
|
||||
regime = 'ranging'
|
||||
|
||||
return {
|
||||
'regime': regime,
|
||||
'volatility': volatility,
|
||||
'trend_strength': trend_strength,
|
||||
'volume_trend': volume_trend,
|
||||
'momentum': momentum
|
||||
}
|
||||
|
||||
def get_state_info(self) -> Dict[str, Any]:
|
||||
"""Get information about the state structure"""
|
||||
return {
|
||||
'total_size': self.config.total_size,
|
||||
'components': {
|
||||
'eth_ticks': self.config.eth_ticks,
|
||||
'eth_1s_ohlcv': self.config.eth_1s_ohlcv,
|
||||
'eth_1m_ohlcv': self.config.eth_1m_ohlcv,
|
||||
'eth_1h_ohlcv': self.config.eth_1h_ohlcv,
|
||||
'eth_1d_ohlcv': self.config.eth_1d_ohlcv,
|
||||
'btc_reference': self.config.btc_reference,
|
||||
'cnn_features': self.config.cnn_features,
|
||||
'cnn_predictions': self.config.cnn_predictions,
|
||||
'pivot_points': self.config.pivot_points,
|
||||
'market_regime': self.config.market_regime,
|
||||
},
|
||||
'data_windows': {
|
||||
'tick_window_seconds': self.tick_window_seconds,
|
||||
'ohlcv_window_bars': self.ohlcv_window_bars,
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user