708 lines
29 KiB
Python
708 lines
29 KiB
Python
"""
|
|
Enhanced RL State Builder for Comprehensive Market Data Integration
|
|
|
|
This module implements the specification requirements for RL training with:
|
|
- 300s of raw tick data for momentum detection
|
|
- Multi-timeframe OHLCV data (1s, 1m, 1h, 1d) for ETH and BTC
|
|
- CNN hidden layer features integration
|
|
- CNN predictions from all timeframes
|
|
- Pivot point predictions using Williams market structure
|
|
- Market regime analysis
|
|
|
|
State Vector Components:
|
|
- ETH tick data: ~3000 features (300s * 10 features/tick)
|
|
- ETH OHLCV 1s: ~2400 features (300 bars * 8 features)
|
|
- ETH OHLCV 1m: ~2400 features (300 bars * 8 features)
|
|
- ETH OHLCV 1h: ~2400 features (300 bars * 8 features)
|
|
- ETH OHLCV 1d: ~2400 features (300 bars * 8 features)
|
|
- BTC reference: ~2400 features (300 bars * 8 features)
|
|
- CNN features: ~512 features (hidden layer)
|
|
- CNN predictions: ~16 features (4 timeframes * 4 outputs)
|
|
- Pivot points: ~250 features (Williams structure)
|
|
- Market regime: ~20 features
|
|
Total: ~15,800 features (15,798 with the component sizes listed above)
|
|
"""
|
|
|
|
import logging
|
|
import numpy as np
|
|
import pandas as pd
|
|
try:
|
|
import ta
|
|
except ImportError:
|
|
logger = logging.getLogger(__name__)
|
|
logger.warning("TA-Lib not available, using pandas for technical indicators")
|
|
ta = None
|
|
from typing import Dict, List, Optional, Tuple, Any
|
|
from datetime import datetime, timedelta
|
|
from dataclasses import dataclass
|
|
|
|
from core.universal_data_adapter import UniversalDataStream
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
@dataclass
class TickData:
    """One trade tick, optionally carrying a top-of-book quote.

    ``bid`` and ``ask`` default to 0.0; ``spread`` treats any non-positive
    side as "no quote available".
    """

    timestamp: datetime
    price: float
    volume: float
    bid: float = 0.0
    ask: float = 0.0

    @property
    def spread(self) -> float:
        """Return ``ask - bid``, or 0.0 unless both sides are strictly positive."""
        quote_present = self.ask > 0 and self.bid > 0
        if not quote_present:
            return 0.0
        return self.ask - self.bid
|
|
|
|
@dataclass
class OHLCVData:
    """One OHLCV bar plus optional, pre-computed technical indicators.

    The five price/volume fields are required. Every indicator field
    defaults to ``None``, meaning "not supplied".
    """

    # --- required bar data ---
    timestamp: datetime
    open: float
    high: float
    low: float
    close: float
    volume: float

    # --- optional technical indicators ---
    rsi: Optional[float] = None
    macd: Optional[float] = None
    bb_upper: Optional[float] = None
    bb_lower: Optional[float] = None
    sma_20: Optional[float] = None
    ema_12: Optional[float] = None
    atr: Optional[float] = None
|
|
|
|
@dataclass
class StateComponentConfig:
    """Feature counts for each component of the RL state vector."""

    eth_ticks: int = 3000        # 300s * 10 features per tick
    eth_1s_ohlcv: int = 2400     # 300 bars * 8 features (OHLCV + indicators)
    eth_1m_ohlcv: int = 2400     # 300 bars * 8 features
    eth_1h_ohlcv: int = 2400     # 300 bars * 8 features
    eth_1d_ohlcv: int = 2400     # 300 bars * 8 features
    btc_reference: int = 2400    # BTC reference data
    cnn_features: int = 512      # CNN hidden layer features
    cnn_predictions: int = 16    # 4 timeframes * 4 outputs
    pivot_points: int = 250      # recursive pivot points (5 levels * 50 points)
    market_regime: int = 20      # market regime features

    @property
    def total_size(self) -> int:
        """Return the sum of all ten component sizes."""
        component_sizes = (
            self.eth_ticks,
            self.eth_1s_ohlcv,
            self.eth_1m_ohlcv,
            self.eth_1h_ohlcv,
            self.eth_1d_ohlcv,
            self.btc_reference,
            self.cnn_features,
            self.cnn_predictions,
            self.pivot_points,
            self.market_regime,
        )
        return sum(component_sizes)
|
|
|
|
class EnhancedRLStateBuilder:
|
|
"""
|
|
Comprehensive RL state builder implementing specification requirements
|
|
|
|
Features:
|
|
- 300s tick data processing with momentum detection
|
|
- Multi-timeframe OHLCV integration
|
|
- CNN hidden layer feature extraction
|
|
- Pivot point calculation and integration
|
|
- Market regime analysis
|
|
- BTC reference data processing
|
|
"""
|
|
|
|
def __init__(self, config: Dict[str, Any]):
    """Store the config and set up component sizes, buffers and extractors.

    Args:
        config: Arbitrary configuration mapping; it is stored on the
            instance and not otherwise read here.
    """
    self.config = config

    # Look-back windows.
    self.tick_window_seconds = 300   # 5 minutes of tick data
    self.ohlcv_window_bars = 300     # bars kept per timeframe

    # Size of every slice of the state vector, in feature count.
    tick_slice = 300 * 10    # tick data with derived features
    ohlcv_slice = 300 * 8    # OHLCV + indicators
    self.state_components = {
        'eth_ticks': tick_slice,
        'eth_1s_ohlcv': ohlcv_slice,
        'eth_1m_ohlcv': ohlcv_slice,
        'eth_1h_ohlcv': ohlcv_slice,
        'eth_1d_ohlcv': ohlcv_slice,
        'btc_reference': ohlcv_slice,
        'cnn_features': 512,        # CNN hidden layer
        'cnn_predictions': 16,      # 4 timeframes * 4 outputs
        'pivot_points': 250,        # Williams market structure
        'market_regime': 20,        # market regime indicators
    }
    self.total_state_size = sum(self.state_components.values())

    # Rolling data buffers (created empty; not filled in this method).
    self.tick_buffers = {}
    self.ohlcv_buffers = {}

    self.normalization_params = self._initialize_normalization_params()

    # Feature extractors.
    self.momentum_detector = TickMomentumDetector()
    self.indicator_calculator = TechnicalIndicatorCalculator()
    self.regime_analyzer = MarketRegimeAnalyzer()

    logger.info("Enhanced RL State Builder initialized")
    logger.info(f"Total state size: {self.total_state_size} features")
    logger.info(f"State components: {self.state_components}")
|
|
|
|
def build_rl_state(self,
                   eth_ticks: List[TickData],
                   eth_ohlcv: Dict[str, List[OHLCVData]],
                   btc_ohlcv: Dict[str, List[OHLCVData]],
                   cnn_hidden_features: Optional[Dict[str, np.ndarray]] = None,
                   cnn_predictions: Optional[Dict[str, np.ndarray]] = None,
                   pivot_data: Optional[Dict[str, Any]] = None) -> np.ndarray:
    """
    Assemble the full RL state vector from every data source.

    Components are appended in a fixed order: ETH ticks, ETH OHLCV for
    1s/1m/1h/1d, BTC reference, CNN hidden features, CNN predictions,
    pivot points, market regime.

    Args:
        eth_ticks: ETH tick data.
        eth_ohlcv: ETH OHLCV bars keyed by timeframe; a missing timeframe
            contributes a block of zeros.
        btc_ohlcv: BTC OHLCV bars keyed by timeframe.
        cnn_hidden_features: CNN hidden layer features by timeframe.
        cnn_predictions: CNN predictions by timeframe.
        pivot_data: Pivot point data from Williams analysis.

    Returns:
        np.ndarray: float32 vector of length ``self.total_state_size``.
        If the assembled vector has a different length it is zero-padded
        or truncated (with a warning). Any exception is logged and an
        all-zero vector is returned instead.
    """
    try:
        collected = []

        # 1. ETH tick features
        collected.extend(self._process_tick_data(eth_ticks))

        # 2. ETH OHLCV, one block per timeframe
        for timeframe in ['1s', '1m', '1h', '1d']:
            if timeframe not in eth_ohlcv:
                collected.extend(np.zeros(self.state_components[f'eth_{timeframe}_ohlcv']))
                continue
            collected.extend(
                self._process_ohlcv_data(eth_ohlcv[timeframe], timeframe, symbol='ETH')
            )

        # 3. BTC reference data
        collected.extend(self._process_btc_reference_data(btc_ohlcv))

        # 4. CNN hidden layer features
        collected.extend(self._process_cnn_hidden_features(cnn_hidden_features))

        # 5. CNN predictions
        collected.extend(self._process_cnn_predictions(cnn_predictions))

        # 6. Pivot points
        collected.extend(self._process_pivot_points(pivot_data, eth_ohlcv))

        # 7. Market regime features
        collected.extend(self._process_market_regime(eth_ohlcv, btc_ohlcv))

        state_array = np.array(collected, dtype=np.float32)

        # Force the declared length so consumers always get a fixed shape.
        if len(state_array) != self.total_state_size:
            logger.warning(f"State size mismatch: expected {self.total_state_size}, got {len(state_array)}")
            shortfall = self.total_state_size - len(state_array)
            if shortfall > 0:
                state_array = np.concatenate([state_array, np.zeros(shortfall)])
            else:
                state_array = state_array[:self.total_state_size]

        return self._normalize_state(state_array)

    except Exception as e:
        logger.error(f"Error building RL state: {e}")
        # Fall back to a zero state so callers still receive a valid shape.
        return np.zeros(self.total_state_size, dtype=np.float32)
|
|
|
|
def _process_tick_data(self, ticks: List[TickData]) -> List[float]:
|
|
"""Process raw tick data into features for momentum detection"""
|
|
features = []
|
|
|
|
if not ticks or len(ticks) < 10:
|
|
# Return zeros if insufficient data
|
|
return [0.0] * self.state_components['eth_ticks']
|
|
|
|
# Ensure we have exactly 300 data points (pad or sample)
|
|
processed_ticks = self._normalize_tick_window(ticks, 300)
|
|
|
|
for i, tick in enumerate(processed_ticks):
|
|
# Basic tick features
|
|
tick_features = [
|
|
tick.price,
|
|
tick.volume,
|
|
tick.bid,
|
|
tick.ask,
|
|
tick.spread
|
|
]
|
|
|
|
# Derived features
|
|
if i > 0:
|
|
prev_tick = processed_ticks[i-1]
|
|
price_change = (tick.price - prev_tick.price) / prev_tick.price if prev_tick.price > 0 else 0
|
|
volume_change = (tick.volume - prev_tick.volume) / prev_tick.volume if prev_tick.volume > 0 else 0
|
|
|
|
tick_features.extend([
|
|
price_change,
|
|
volume_change,
|
|
tick.price / prev_tick.price - 1.0 if prev_tick.price > 0 else 0, # Price ratio
|
|
np.log(tick.volume / prev_tick.volume) if prev_tick.volume > 0 else 0, # Log volume ratio
|
|
self.momentum_detector.calculate_micro_momentum(processed_ticks[max(0, i-5):i+1])
|
|
])
|
|
else:
|
|
tick_features.extend([0.0, 0.0, 0.0, 0.0, 0.0])
|
|
|
|
features.extend(tick_features)
|
|
|
|
return features[:self.state_components['eth_ticks']]
|
|
|
|
def _process_ohlcv_data(self, ohlcv_data: List[OHLCVData],
|
|
timeframe: str, symbol: str = 'ETH') -> List[float]:
|
|
"""Process OHLCV data with technical indicators"""
|
|
features = []
|
|
|
|
if not ohlcv_data or len(ohlcv_data) < 20:
|
|
component_key = f'{symbol.lower()}_{timeframe}_ohlcv' if symbol == 'ETH' else 'btc_reference'
|
|
return [0.0] * self.state_components[component_key]
|
|
|
|
# Convert to DataFrame for indicator calculation
|
|
df = pd.DataFrame([{
|
|
'timestamp': bar.timestamp,
|
|
'open': bar.open,
|
|
'high': bar.high,
|
|
'low': bar.low,
|
|
'close': bar.close,
|
|
'volume': bar.volume
|
|
} for bar in ohlcv_data[-self.ohlcv_window_bars:]])
|
|
|
|
# Calculate technical indicators
|
|
df = self.indicator_calculator.add_all_indicators(df)
|
|
|
|
# Ensure we have exactly 300 bars
|
|
if len(df) < 300:
|
|
# Pad with last known values
|
|
last_row = df.iloc[-1:].copy()
|
|
padding_rows = []
|
|
for _ in range(300 - len(df)):
|
|
padding_rows.append(last_row)
|
|
if padding_rows:
|
|
df = pd.concat([df] + padding_rows, ignore_index=True)
|
|
else:
|
|
df = df.tail(300)
|
|
|
|
# Extract features for each bar
|
|
feature_columns = ['open', 'high', 'low', 'close', 'volume', 'rsi', 'macd', 'bb_middle']
|
|
|
|
for _, row in df.iterrows():
|
|
bar_features = []
|
|
for col in feature_columns:
|
|
if col in row and not pd.isna(row[col]):
|
|
bar_features.append(float(row[col]))
|
|
else:
|
|
bar_features.append(0.0)
|
|
features.extend(bar_features)
|
|
|
|
component_key = f'{symbol.lower()}_{timeframe}_ohlcv' if symbol == 'ETH' else 'btc_reference'
|
|
return features[:self.state_components[component_key]]
|
|
|
|
def _process_btc_reference_data(self, btc_ohlcv: Dict[str, List[OHLCVData]]) -> List[float]:
|
|
"""Process BTC reference data (using 1h timeframe as primary)"""
|
|
if '1h' in btc_ohlcv and btc_ohlcv['1h']:
|
|
return self._process_ohlcv_data(btc_ohlcv['1h'], '1h', 'BTC')
|
|
elif '1m' in btc_ohlcv and btc_ohlcv['1m']:
|
|
return self._process_ohlcv_data(btc_ohlcv['1m'], '1m', 'BTC')
|
|
else:
|
|
return [0.0] * self.state_components['btc_reference']
|
|
|
|
def _process_cnn_hidden_features(self, cnn_features: Optional[Dict[str, np.ndarray]]) -> List[float]:
|
|
"""Process CNN hidden layer features"""
|
|
if not cnn_features:
|
|
return [0.0] * self.state_components['cnn_features']
|
|
|
|
# Combine features from all timeframes
|
|
combined_features = []
|
|
timeframes = ['1s', '1m', '1h', '1d']
|
|
features_per_timeframe = self.state_components['cnn_features'] // len(timeframes)
|
|
|
|
for tf in timeframes:
|
|
if tf in cnn_features and cnn_features[tf] is not None:
|
|
tf_features = cnn_features[tf].flatten()
|
|
# Truncate or pad to fit allocation
|
|
if len(tf_features) >= features_per_timeframe:
|
|
combined_features.extend(tf_features[:features_per_timeframe])
|
|
else:
|
|
combined_features.extend(tf_features)
|
|
combined_features.extend([0.0] * (features_per_timeframe - len(tf_features)))
|
|
else:
|
|
combined_features.extend([0.0] * features_per_timeframe)
|
|
|
|
return combined_features[:self.state_components['cnn_features']]
|
|
|
|
def _process_cnn_predictions(self, cnn_predictions: Optional[Dict[str, np.ndarray]]) -> List[float]:
|
|
"""Process CNN predictions from all timeframes"""
|
|
if not cnn_predictions:
|
|
return [0.0] * self.state_components['cnn_predictions']
|
|
|
|
predictions = []
|
|
timeframes = ['1s', '1m', '1h', '1d']
|
|
|
|
for tf in timeframes:
|
|
if tf in cnn_predictions and cnn_predictions[tf] is not None:
|
|
pred = cnn_predictions[tf].flatten()
|
|
# Expecting 4 outputs per timeframe (BUY, SELL, HOLD, confidence)
|
|
if len(pred) >= 4:
|
|
predictions.extend(pred[:4])
|
|
else:
|
|
predictions.extend(pred)
|
|
predictions.extend([0.0] * (4 - len(pred)))
|
|
else:
|
|
predictions.extend([0.0, 0.0, 1.0, 0.0]) # Default to HOLD with 0 confidence
|
|
|
|
return predictions[:self.state_components['cnn_predictions']]
|
|
|
|
def _process_pivot_points(self, pivot_data: Optional[Dict[str, Any]],
|
|
eth_ohlcv: Dict[str, List[OHLCVData]]) -> List[float]:
|
|
"""Process pivot points using Williams market structure"""
|
|
if pivot_data:
|
|
# Use provided pivot data
|
|
return self._extract_pivot_features(pivot_data)
|
|
elif '1m' in eth_ohlcv and eth_ohlcv['1m']:
|
|
# Calculate pivot points from 1m data
|
|
from training.williams_market_structure import WilliamsMarketStructure
|
|
williams = WilliamsMarketStructure()
|
|
|
|
# Convert OHLCV to numpy array
|
|
ohlcv_array = self._ohlcv_to_array(eth_ohlcv['1m'])
|
|
pivot_data = williams.calculate_recursive_pivot_points(ohlcv_array)
|
|
return self._extract_pivot_features(pivot_data)
|
|
else:
|
|
return [0.0] * self.state_components['pivot_points']
|
|
|
|
def _process_market_regime(self, eth_ohlcv: Dict[str, List[OHLCVData]],
|
|
btc_ohlcv: Dict[str, List[OHLCVData]]) -> List[float]:
|
|
"""Process market regime indicators"""
|
|
regime_features = []
|
|
|
|
# ETH regime analysis
|
|
if '1h' in eth_ohlcv and eth_ohlcv['1h']:
|
|
eth_regime = self.regime_analyzer.analyze_regime(eth_ohlcv['1h'])
|
|
regime_features.extend([
|
|
eth_regime['volatility'],
|
|
eth_regime['trend_strength'],
|
|
eth_regime['volume_trend'],
|
|
eth_regime['momentum'],
|
|
1.0 if eth_regime['regime'] == 'trending' else 0.0,
|
|
1.0 if eth_regime['regime'] == 'ranging' else 0.0,
|
|
1.0 if eth_regime['regime'] == 'volatile' else 0.0
|
|
])
|
|
else:
|
|
regime_features.extend([0.0] * 7)
|
|
|
|
# BTC regime analysis
|
|
if '1h' in btc_ohlcv and btc_ohlcv['1h']:
|
|
btc_regime = self.regime_analyzer.analyze_regime(btc_ohlcv['1h'])
|
|
regime_features.extend([
|
|
btc_regime['volatility'],
|
|
btc_regime['trend_strength'],
|
|
btc_regime['volume_trend'],
|
|
btc_regime['momentum'],
|
|
1.0 if btc_regime['regime'] == 'trending' else 0.0,
|
|
1.0 if btc_regime['regime'] == 'ranging' else 0.0,
|
|
1.0 if btc_regime['regime'] == 'volatile' else 0.0
|
|
])
|
|
else:
|
|
regime_features.extend([0.0] * 7)
|
|
|
|
# Correlation features
|
|
correlation_features = self._calculate_btc_eth_correlation(eth_ohlcv, btc_ohlcv)
|
|
regime_features.extend(correlation_features)
|
|
|
|
return regime_features[:self.state_components['market_regime']]
|
|
|
|
def _normalize_tick_window(self, ticks: List[TickData], target_size: int) -> List[TickData]:
|
|
"""Normalize tick window to target size"""
|
|
if len(ticks) == target_size:
|
|
return ticks
|
|
elif len(ticks) > target_size:
|
|
# Sample evenly
|
|
step = len(ticks) / target_size
|
|
indices = [int(i * step) for i in range(target_size)]
|
|
return [ticks[i] for i in indices]
|
|
else:
|
|
# Pad with last tick
|
|
result = ticks.copy()
|
|
last_tick = ticks[-1] if ticks else TickData(datetime.now(), 0, 0)
|
|
while len(result) < target_size:
|
|
result.append(last_tick)
|
|
return result
|
|
|
|
def _extract_pivot_features(self, pivot_data: Dict[str, Any]) -> List[float]:
|
|
"""Extract features from pivot point data"""
|
|
features = []
|
|
|
|
for level in range(5): # 5 levels of recursion
|
|
level_key = f'level_{level}'
|
|
if level_key in pivot_data:
|
|
level_data = pivot_data[level_key]
|
|
|
|
# Swing point features
|
|
swing_points = level_data.get('swing_points', [])
|
|
if swing_points:
|
|
# Last 10 swing points
|
|
recent_swings = swing_points[-10:]
|
|
for swing in recent_swings:
|
|
features.extend([
|
|
swing['price'],
|
|
1.0 if swing['type'] == 'swing_high' else 0.0,
|
|
swing['index']
|
|
])
|
|
|
|
# Pad if fewer than 10 swings
|
|
while len(recent_swings) < 10:
|
|
features.extend([0.0, 0.0, 0.0])
|
|
recent_swings.append({'type': 'none'})
|
|
else:
|
|
features.extend([0.0] * 30) # 10 swings * 3 features
|
|
|
|
# Trend features
|
|
features.extend([
|
|
level_data.get('trend_strength', 0.0),
|
|
1.0 if level_data.get('trend_direction') == 'up' else 0.0,
|
|
1.0 if level_data.get('trend_direction') == 'down' else 0.0
|
|
])
|
|
else:
|
|
features.extend([0.0] * 33) # 30 swing + 3 trend features
|
|
|
|
return features[:self.state_components['pivot_points']]
|
|
|
|
def _ohlcv_to_array(self, ohlcv_data: List[OHLCVData]) -> np.ndarray:
|
|
"""Convert OHLCV data to numpy array"""
|
|
return np.array([[
|
|
bar.timestamp.timestamp(),
|
|
bar.open,
|
|
bar.high,
|
|
bar.low,
|
|
bar.close,
|
|
bar.volume
|
|
] for bar in ohlcv_data])
|
|
|
|
def _calculate_btc_eth_correlation(self, eth_ohlcv: Dict[str, List[OHLCVData]],
|
|
btc_ohlcv: Dict[str, List[OHLCVData]]) -> List[float]:
|
|
"""Calculate BTC-ETH correlation features"""
|
|
try:
|
|
# Use 1h data for correlation
|
|
if '1h' not in eth_ohlcv or '1h' not in btc_ohlcv:
|
|
return [0.0] * 6
|
|
|
|
eth_prices = [bar.close for bar in eth_ohlcv['1h'][-50:]] # Last 50 hours
|
|
btc_prices = [bar.close for bar in btc_ohlcv['1h'][-50:]]
|
|
|
|
if len(eth_prices) < 10 or len(btc_prices) < 10:
|
|
return [0.0] * 6
|
|
|
|
# Align lengths
|
|
min_len = min(len(eth_prices), len(btc_prices))
|
|
eth_prices = eth_prices[-min_len:]
|
|
btc_prices = btc_prices[-min_len:]
|
|
|
|
# Calculate returns
|
|
eth_returns = np.diff(eth_prices) / eth_prices[:-1]
|
|
btc_returns = np.diff(btc_prices) / btc_prices[:-1]
|
|
|
|
# Correlation
|
|
correlation = np.corrcoef(eth_returns, btc_returns)[0, 1] if len(eth_returns) > 1 else 0.0
|
|
|
|
# Price ratio
|
|
current_ratio = eth_prices[-1] / btc_prices[-1] if btc_prices[-1] > 0 else 0.0
|
|
avg_ratio = np.mean([e/b for e, b in zip(eth_prices, btc_prices) if b > 0])
|
|
ratio_deviation = (current_ratio - avg_ratio) / avg_ratio if avg_ratio > 0 else 0.0
|
|
|
|
# Volatility comparison
|
|
eth_vol = np.std(eth_returns) if len(eth_returns) > 1 else 0.0
|
|
btc_vol = np.std(btc_returns) if len(btc_returns) > 1 else 0.0
|
|
vol_ratio = eth_vol / btc_vol if btc_vol > 0 else 1.0
|
|
|
|
return [
|
|
correlation,
|
|
current_ratio,
|
|
ratio_deviation,
|
|
vol_ratio,
|
|
eth_vol,
|
|
btc_vol
|
|
]
|
|
|
|
except Exception as e:
|
|
logger.warning(f"Error calculating BTC-ETH correlation: {e}")
|
|
return [0.0] * 6
|
|
|
|
def _initialize_normalization_params(self) -> Dict[str, Dict[str, float]]:
|
|
"""Initialize normalization parameters for different feature types"""
|
|
return {
|
|
'price_features': {'mean': 0.0, 'std': 1.0, 'min': -10.0, 'max': 10.0},
|
|
'volume_features': {'mean': 0.0, 'std': 1.0, 'min': -5.0, 'max': 5.0},
|
|
'indicator_features': {'mean': 0.0, 'std': 1.0, 'min': -3.0, 'max': 3.0},
|
|
'cnn_features': {'mean': 0.0, 'std': 1.0, 'min': -2.0, 'max': 2.0},
|
|
'pivot_features': {'mean': 0.0, 'std': 1.0, 'min': -5.0, 'max': 5.0}
|
|
}
|
|
|
|
def _normalize_state(self, state: np.ndarray) -> np.ndarray:
|
|
"""Apply normalization to state vector"""
|
|
try:
|
|
# Simple clipping and scaling for now
|
|
# More sophisticated normalization can be added based on training data
|
|
normalized_state = np.clip(state, -10.0, 10.0)
|
|
|
|
# Replace any NaN or inf values
|
|
normalized_state = np.nan_to_num(normalized_state, nan=0.0, posinf=10.0, neginf=-10.0)
|
|
|
|
return normalized_state.astype(np.float32)
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error normalizing state: {e}")
|
|
return state.astype(np.float32)
|
|
|
|
class TickMomentumDetector:
    """Detect momentum from tick-level data"""

    def calculate_micro_momentum(self, ticks: List[TickData]) -> float:
        """Return the mean of plain and volume-weighted price momentum.

        Price momentum is the average tick-to-tick price change.
        Volume-weighted momentum weights each price change by the volume
        of the tick that completed it (the later tick of each pair).

        The weighting uses the volumes of ticks 1..n only, so the
        "is there any volume" guard must look at the same slice. The
        guard previously summed all volumes, which gave a 0/0 (NaN)
        result when only the first tick carried volume.

        Args:
            ticks: Tick sequence in time order.

        Returns:
            The momentum value; 0.0 for fewer than two ticks.
        """
        if len(ticks) < 2:
            return 0.0

        prices = [tick.price for tick in ticks]
        price_changes = np.diff(prices)
        price_momentum = np.sum(price_changes) / len(price_changes)

        # Volumes aligned with price_changes: change i ends at tick i + 1.
        trailing_volumes = [tick.volume for tick in ticks[1:]]
        total_volume = sum(trailing_volumes)
        if total_volume > 0:
            weighted_sum = sum(change * volume
                               for change, volume in zip(price_changes, trailing_volumes))
            volume_momentum = weighted_sum / total_volume
        else:
            volume_momentum = 0.0

        return float((price_momentum + volume_momentum) / 2.0)
|
|
|
|
class TechnicalIndicatorCalculator:
    """Calculate technical indicators for OHLCV data"""

    def add_all_indicators(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of ``df`` with indicator columns added.

        Adds ``rsi``, ``macd``, ``bb_middle``, ``bb_std``, ``bb_upper`` and
        ``bb_lower``, all computed from the ``close`` column. Bollinger
        bands use a 20-row rolling window and 2 standard deviations.
        Remaining NaNs are forward-filled, then filled with 0.

        The previous ``fillna(method='forward')`` call used a fill-method
        name pandas does not accept, so it raised ``ValueError`` on every
        call; ``ffill()`` is the supported spelling.

        Args:
            df: Frame with at least a ``close`` column; it is not modified.
        """
        df = df.copy()

        close = df['close']
        df['rsi'] = self.calculate_rsi(close)
        df['macd'] = self.calculate_macd(close)

        # Bollinger Bands
        df['bb_middle'] = close.rolling(20).mean()
        df['bb_std'] = close.rolling(20).std()
        df['bb_upper'] = df['bb_middle'] + (df['bb_std'] * 2)
        df['bb_lower'] = df['bb_middle'] - (df['bb_std'] * 2)

        # Leading rows have no earlier value to carry forward, so they
        # fall through to the 0 fill.
        df = df.ffill().fillna(0)

        return df

    def calculate_rsi(self, prices: pd.Series, period: int = 14) -> pd.Series:
        """Calculate RSI from simple rolling means of gains and losses.

        Rows where the value is undefined (warm-up period, or no gains and
        no losses in the window) are set to 50.
        """
        delta = prices.diff()
        gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
        loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
        rs = gain / loss
        rsi = 100 - (100 / (1 + rs))
        return rsi.fillna(50)

    def calculate_macd(self, prices: pd.Series, fast: int = 12, slow: int = 26) -> pd.Series:
        """Calculate the MACD line: fast EMA minus slow EMA (NaN -> 0)."""
        ema_fast = prices.ewm(span=fast).mean()
        ema_slow = prices.ewm(span=slow).mean()
        macd = ema_fast - ema_slow
        return macd.fillna(0)
|
|
|
|
class MarketRegimeAnalyzer:
    """Analyze market regime from OHLCV data"""

    def analyze_regime(self, ohlcv_data: List[OHLCVData]) -> Dict[str, Any]:
        """Classify the recent market regime and report its metrics.

        Works on the last 50 bars. Metrics:
          - ``volatility``: standard deviation of close-to-close returns,
            times 100.
          - ``trend_strength``: absolute gap between the 10-bar and 30-bar
            mean close, relative to the 30-bar mean.
          - ``volume_trend``: relative gap between the 10-bar and 30-bar
            mean volume.
          - ``momentum``: relative change from the 10th-from-last close to
            the last close.

        ``regime`` is ``'volatile'`` when volatility exceeds 3.0, else
        ``'trending'`` when ``|momentum|`` exceeds 0.02, else ``'ranging'``.
        With fewer than 20 bars, ``regime`` is ``'unknown'`` and every
        metric is 0.0.
        """
        if len(ohlcv_data) < 20:
            return {
                'regime': 'unknown',
                'volatility': 0.0,
                'trend_strength': 0.0,
                'volume_trend': 0.0,
                'momentum': 0.0,
            }

        window = ohlcv_data[-50:]
        prices = [bar.close for bar in window]
        volumes = [bar.volume for bar in window]

        # Volatility, in percent.
        returns = np.diff(prices) / prices[:-1]
        volatility = np.std(returns) * 100

        # Trend strength from short vs. long mean close.
        sma_short = np.mean(prices[-10:])
        sma_long = np.mean(prices[-30:])
        if sma_long > 0:
            trend_strength = abs(sma_short - sma_long) / sma_long
        else:
            trend_strength = 0.0

        # Volume trend from short vs. long mean volume.
        volume_ma_short = np.mean(volumes[-10:])
        volume_ma_long = np.mean(volumes[-30:])
        if volume_ma_long > 0:
            volume_trend = (volume_ma_short - volume_ma_long) / volume_ma_long
        else:
            volume_trend = 0.0

        # Momentum over the last 10 bars.
        if len(prices) >= 10 and prices[-10] > 0:
            momentum = (prices[-1] - prices[-10]) / prices[-10]
        else:
            momentum = 0.0

        # High volatility takes precedence over strong momentum.
        regime = 'ranging'
        if volatility > 3.0:
            regime = 'volatile'
        elif abs(momentum) > 0.02:
            regime = 'trending'

        return {
            'regime': regime,
            'volatility': volatility,
            'trend_strength': trend_strength,
            'volume_trend': volume_trend,
            'momentum': momentum,
        }
|
|
|
|
def get_state_info(self) -> Dict[str, Any]:
    """Describe the state layout: total size, component sizes, data windows.

    Reads the size fields (and ``total_size``) from ``self.config`` by
    attribute access, and the two window lengths from ``self``.
    """
    # NOTE(review): this reads self.config.<field>, self.tick_window_seconds
    # and self.ohlcv_window_bars. The class definition directly above it in
    # this file sets none of these, and EnhancedRLStateBuilder stores a plain
    # dict in self.config — confirm which class this method belongs to and
    # that its config is a StateComponentConfig-like object.
    config = self.config
    total_size = config.total_size

    component_names = (
        'eth_ticks',
        'eth_1s_ohlcv',
        'eth_1m_ohlcv',
        'eth_1h_ohlcv',
        'eth_1d_ohlcv',
        'btc_reference',
        'cnn_features',
        'cnn_predictions',
        'pivot_points',
        'market_regime',
    )
    components = {name: getattr(config, name) for name in component_names}

    data_windows = {
        'tick_window_seconds': self.tick_window_seconds,
        'ohlcv_window_bars': self.ohlcv_window_bars,
    }

    return {
        'total_size': total_size,
        'components': components,
        'data_windows': data_windows,
    }