gogo2/training/enhanced_rl_state_builder.py
2025-05-28 23:42:06 +03:00

708 lines
29 KiB
Python

"""
Enhanced RL State Builder for Comprehensive Market Data Integration
This module implements the specification requirements for RL training with:
- 300s of raw tick data for momentum detection
- Multi-timeframe OHLCV data (1s, 1m, 1h, 1d) for ETH and BTC
- CNN hidden layer features integration
- CNN predictions from all timeframes
- Pivot point predictions using Williams market structure
- Market regime analysis
State Vector Components:
- ETH tick data: ~3000 features (300s * 10 features/tick)
- ETH OHLCV 1s: ~2400 features (300 bars * 8 features)
- ETH OHLCV 1m: ~2400 features (300 bars * 8 features)
- ETH OHLCV 1h: ~2400 features (300 bars * 8 features)
- ETH OHLCV 1d: ~2400 features (300 bars * 8 features)
- BTC reference: ~2400 features (300 bars * 8 features)
- CNN features: ~512 features (hidden layer)
- CNN predictions: ~16 features (4 timeframes * 4 outputs)
- Pivot points: ~250 features (Williams structure)
- Market regime: ~20 features
Total: ~15,800 features (15,798 with the component sizes listed above)
"""
import logging
import numpy as np
import pandas as pd
try:
import ta
except ImportError:
logger = logging.getLogger(__name__)
logger.warning("TA-Lib not available, using pandas for technical indicators")
ta = None
from typing import Dict, List, Optional, Tuple, Any
from datetime import datetime, timedelta
from dataclasses import dataclass
from core.universal_data_adapter import UniversalDataStream
logger = logging.getLogger(__name__)
@dataclass
class TickData:
    """One tick: timestamp, price, volume and optional bid/ask quotes."""

    timestamp: datetime
    price: float
    volume: float
    # Quotes default to 0.0, which `spread` treats as "no quote".
    bid: float = 0.0
    ask: float = 0.0

    @property
    def spread(self) -> float:
        """Return ask - bid, or 0.0 unless both quotes are strictly positive."""
        if self.bid > 0 and self.ask > 0:
            return self.ask - self.bid
        return 0.0
@dataclass
class OHLCVData:
    """One OHLCV bar, with optional slots for precomputed indicators."""

    # Required bar fields
    timestamp: datetime
    open: float
    high: float
    low: float
    close: float
    volume: float

    # Indicator values; each stays None unless the producer supplies it
    rsi: Optional[float] = None
    macd: Optional[float] = None
    bb_upper: Optional[float] = None
    bb_lower: Optional[float] = None
    sma_20: Optional[float] = None
    ema_12: Optional[float] = None
    atr: Optional[float] = None
@dataclass
class StateComponentConfig:
    """Feature counts for each section of the RL state vector."""

    eth_ticks: int = 3000        # 300s * 10 features per tick
    eth_1s_ohlcv: int = 2400     # 300 bars * 8 features (OHLCV + indicators)
    eth_1m_ohlcv: int = 2400     # 300 bars * 8 features
    eth_1h_ohlcv: int = 2400     # 300 bars * 8 features
    eth_1d_ohlcv: int = 2400     # 300 bars * 8 features
    btc_reference: int = 2400    # BTC reference data
    cnn_features: int = 512      # CNN hidden layer features
    cnn_predictions: int = 16    # CNN predictions (4 timeframes * 4 outputs)
    pivot_points: int = 250      # Recursive pivot points (5 levels * 50 points)
    market_regime: int = 20      # Market regime features

    @property
    def total_size(self) -> int:
        """Sum of all component sizes, i.e. the full state vector length."""
        sections = (
            self.eth_ticks,
            self.eth_1s_ohlcv,
            self.eth_1m_ohlcv,
            self.eth_1h_ohlcv,
            self.eth_1d_ohlcv,
            self.btc_reference,
            self.cnn_features,
            self.cnn_predictions,
            self.pivot_points,
            self.market_regime,
        )
        return sum(sections)
class EnhancedRLStateBuilder:
    """
    Build the fixed-layout RL state vector from market data and CNN outputs.

    The vector is the concatenation, in this order, of:
    - ETH tick features (300 ticks * 10 features)
    - ETH OHLCV features for 1s, 1m, 1h and 1d (300 bars * 8 features each)
    - BTC reference OHLCV features (300 bars * 8 features)
    - CNN hidden-layer features (512)
    - CNN predictions (4 timeframes * 4 outputs)
    - Pivot point features (250)
    - Market regime features (20)

    Every component is truncated or zero-padded to its declared size so the
    offset of each section inside the vector never moves.
    """

    def __init__(self, config: Dict[str, Any]):
        """Store *config* and set up window sizes, layout and helper objects."""
        self.config = config

        # Data windows
        self.tick_window_seconds = 300  # 5 minutes of tick data
        self.ohlcv_window_bars = 300    # 300 bars for each timeframe

        # State component sizes (insertion order documents the vector layout)
        self.state_components = {
            'eth_ticks': 300 * 10,      # 3000 features: tick data with derived features
            'eth_1s_ohlcv': 300 * 8,    # 2400 features: OHLCV + indicators
            'eth_1m_ohlcv': 300 * 8,    # 2400 features: OHLCV + indicators
            'eth_1h_ohlcv': 300 * 8,    # 2400 features: OHLCV + indicators
            'eth_1d_ohlcv': 300 * 8,    # 2400 features: OHLCV + indicators
            'btc_reference': 300 * 8,   # 2400 features: BTC reference data
            'cnn_features': 512,        # 512 features: CNN hidden layer
            'cnn_predictions': 16,      # 16 features: CNN predictions (4 timeframes * 4 outputs)
            'pivot_points': 250,        # 250 features: Williams market structure
            'market_regime': 20         # 20 features: Market regime indicators
        }
        self.total_state_size = sum(self.state_components.values())

        # Data buffers for maintaining windows
        self.tick_buffers = {}
        self.ohlcv_buffers = {}

        # Normalization parameters
        self.normalization_params = self._initialize_normalization_params()

        # Feature extractors
        self.momentum_detector = TickMomentumDetector()
        self.indicator_calculator = TechnicalIndicatorCalculator()
        self.regime_analyzer = MarketRegimeAnalyzer()

        logger.info("Enhanced RL State Builder initialized")
        logger.info(f"Total state size: {self.total_state_size} features")
        logger.info(f"State components: {self.state_components}")

    def _fit_component(self, features: Any, component: str) -> List[float]:
        """Truncate or zero-pad *features* to the declared size of *component*.

        Keeping every section at its declared length is what keeps the offsets
        of the following sections stable inside the state vector.
        """
        size = self.state_components[component]
        fitted = list(features[:size])
        if len(fitted) < size:
            fitted.extend([0.0] * (size - len(fitted)))
        return fitted

    def build_rl_state(self,
                       eth_ticks: List[TickData],
                       eth_ohlcv: Dict[str, List[OHLCVData]],
                       btc_ohlcv: Dict[str, List[OHLCVData]],
                       cnn_hidden_features: Optional[Dict[str, np.ndarray]] = None,
                       cnn_predictions: Optional[Dict[str, np.ndarray]] = None,
                       pivot_data: Optional[Dict[str, Any]] = None) -> np.ndarray:
        """
        Build the full RL state vector from all data sources.

        Args:
            eth_ticks: List of ETH tick data (last 300s)
            eth_ohlcv: Dict of ETH OHLCV data by timeframe
            btc_ohlcv: Dict of BTC OHLCV data by timeframe
            cnn_hidden_features: CNN hidden layer features by timeframe
            cnn_predictions: CNN predictions by timeframe
            pivot_data: Pivot point data from Williams analysis

        Returns:
            np.ndarray: float32 vector of length ``self.total_state_size``
            (15798 with the default component sizes). Any exception raised
            while building is logged and an all-zero vector is returned.
        """
        try:
            state_vector = []

            # 1. ETH tick data (3000 features)
            state_vector.extend(self._process_tick_data(eth_ticks))

            # 2. ETH multi-timeframe OHLCV (4 * 2400 = 9600 features)
            for timeframe in ['1s', '1m', '1h', '1d']:
                if timeframe in eth_ohlcv:
                    ohlcv_features = self._process_ohlcv_data(
                        eth_ohlcv[timeframe], timeframe, symbol='ETH'
                    )
                else:
                    ohlcv_features = np.zeros(self.state_components[f'eth_{timeframe}_ohlcv'])
                state_vector.extend(ohlcv_features)

            # 3. BTC reference data (2400 features)
            state_vector.extend(self._process_btc_reference_data(btc_ohlcv))

            # 4. CNN hidden layer features (512 features)
            state_vector.extend(self._process_cnn_hidden_features(cnn_hidden_features))

            # 5. CNN predictions (16 features)
            state_vector.extend(self._process_cnn_predictions(cnn_predictions))

            # 6. Pivot points (250 features)
            state_vector.extend(self._process_pivot_points(pivot_data, eth_ohlcv))

            # 7. Market regime features (20 features)
            state_vector.extend(self._process_market_regime(eth_ohlcv, btc_ohlcv))

            # Convert to numpy array and validate size
            state_array = np.array(state_vector, dtype=np.float32)
            if len(state_array) != self.total_state_size:
                logger.warning(f"State size mismatch: expected {self.total_state_size}, got {len(state_array)}")
                # Safety net only: each component is already fitted to its size
                if len(state_array) < self.total_state_size:
                    padding = np.zeros(self.total_state_size - len(state_array), dtype=np.float32)
                    state_array = np.concatenate([state_array, padding])
                else:
                    state_array = state_array[:self.total_state_size]

            # Apply normalization
            return self._normalize_state(state_array)
        except Exception as e:
            logger.error(f"Error building RL state: {e}")
            # Return zero state on error
            return np.zeros(self.total_state_size, dtype=np.float32)

    def _process_tick_data(self, ticks: List[TickData]) -> List[float]:
        """Turn raw ticks into 10 features per tick over a 300-tick window.

        Returns all zeros when fewer than 10 ticks are supplied.
        """
        if not ticks or len(ticks) < 10:
            return [0.0] * self.state_components['eth_ticks']

        # Ensure we have exactly 300 data points (pad or sample)
        processed_ticks = self._normalize_tick_window(ticks, 300)

        features = []
        for i, tick in enumerate(processed_ticks):
            # Basic tick features
            tick_features = [tick.price, tick.volume, tick.bid, tick.ask, tick.spread]

            # Derived features (all zero for the first tick: no predecessor)
            if i > 0:
                prev_tick = processed_ticks[i - 1]
                has_prev_price = prev_tick.price > 0
                has_prev_volume = prev_tick.volume > 0
                price_change = (tick.price - prev_tick.price) / prev_tick.price if has_prev_price else 0
                volume_change = (tick.volume - prev_tick.volume) / prev_tick.volume if has_prev_volume else 0
                # log() is only defined for a positive ratio, so both volumes
                # must be positive; otherwise np.log(0) would produce -inf.
                if has_prev_volume and tick.volume > 0:
                    log_volume_ratio = np.log(tick.volume / prev_tick.volume)
                else:
                    log_volume_ratio = 0
                tick_features.extend([
                    price_change,
                    volume_change,
                    tick.price / prev_tick.price - 1.0 if has_prev_price else 0,  # Price ratio
                    log_volume_ratio,  # Log volume ratio
                    self.momentum_detector.calculate_micro_momentum(processed_ticks[max(0, i - 5):i + 1])
                ])
            else:
                tick_features.extend([0.0, 0.0, 0.0, 0.0, 0.0])

            features.extend(tick_features)

        return self._fit_component(features, 'eth_ticks')

    def _process_ohlcv_data(self, ohlcv_data: List[OHLCVData],
                            timeframe: str, symbol: str = 'ETH') -> List[float]:
        """Turn OHLCV bars into 8 features per bar over the bar window.

        Features per bar: open, high, low, close, volume, rsi, macd, bb_middle.
        Returns all zeros when fewer than 20 bars are supplied.
        """
        component_key = f'{symbol.lower()}_{timeframe}_ohlcv' if symbol == 'ETH' else 'btc_reference'
        if not ohlcv_data or len(ohlcv_data) < 20:
            return [0.0] * self.state_components[component_key]

        window = self.ohlcv_window_bars

        # Convert to DataFrame for indicator calculation
        df = pd.DataFrame([{
            'timestamp': bar.timestamp,
            'open': bar.open,
            'high': bar.high,
            'low': bar.low,
            'close': bar.close,
            'volume': bar.volume
        } for bar in ohlcv_data[-window:]])

        # Calculate technical indicators
        df = self.indicator_calculator.add_all_indicators(df)

        # Ensure we have exactly `window` bars
        missing = window - len(df)
        if missing > 0:
            # Pad with copies of the last known bar
            last_row = df.iloc[-1:]
            df = pd.concat([df] + [last_row] * missing, ignore_index=True)
        else:
            df = df.tail(window)

        # Extract features bar by bar; absent columns and NaN become 0.0
        feature_columns = ['open', 'high', 'low', 'close', 'volume', 'rsi', 'macd', 'bb_middle']
        values = df.reindex(columns=feature_columns).to_numpy(dtype=float)
        values = np.where(np.isnan(values), 0.0, values)

        # Row-major flattening keeps the 8 features of each bar adjacent
        return self._fit_component(values.ravel().tolist(), component_key)

    def _process_btc_reference_data(self, btc_ohlcv: Dict[str, List[OHLCVData]]) -> List[float]:
        """Process BTC reference data, preferring 1h bars and falling back to 1m."""
        if '1h' in btc_ohlcv and btc_ohlcv['1h']:
            return self._process_ohlcv_data(btc_ohlcv['1h'], '1h', 'BTC')
        if '1m' in btc_ohlcv and btc_ohlcv['1m']:
            return self._process_ohlcv_data(btc_ohlcv['1m'], '1m', 'BTC')
        return [0.0] * self.state_components['btc_reference']

    def _process_cnn_hidden_features(self, cnn_features: Optional[Dict[str, np.ndarray]]) -> List[float]:
        """Concatenate CNN hidden features, giving each timeframe an equal share."""
        if not cnn_features:
            return [0.0] * self.state_components['cnn_features']

        # Combine features from all timeframes
        combined_features = []
        timeframes = ['1s', '1m', '1h', '1d']
        features_per_timeframe = self.state_components['cnn_features'] // len(timeframes)

        for tf in timeframes:
            if tf in cnn_features and cnn_features[tf] is not None:
                tf_features = cnn_features[tf].flatten()
                # Truncate or pad to fit allocation
                if len(tf_features) >= features_per_timeframe:
                    combined_features.extend(tf_features[:features_per_timeframe])
                else:
                    combined_features.extend(tf_features)
                    combined_features.extend([0.0] * (features_per_timeframe - len(tf_features)))
            else:
                combined_features.extend([0.0] * features_per_timeframe)

        return self._fit_component(combined_features, 'cnn_features')

    def _process_cnn_predictions(self, cnn_predictions: Optional[Dict[str, np.ndarray]]) -> List[float]:
        """Concatenate 4 prediction outputs for each of the 4 timeframes."""
        if not cnn_predictions:
            return [0.0] * self.state_components['cnn_predictions']

        predictions = []
        for tf in ['1s', '1m', '1h', '1d']:
            if tf in cnn_predictions and cnn_predictions[tf] is not None:
                pred = cnn_predictions[tf].flatten()
                # Expecting 4 outputs per timeframe (BUY, SELL, HOLD, confidence)
                if len(pred) >= 4:
                    predictions.extend(pred[:4])
                else:
                    predictions.extend(pred)
                    predictions.extend([0.0] * (4 - len(pred)))
            else:
                predictions.extend([0.0, 0.0, 1.0, 0.0])  # Default to HOLD with 0 confidence

        return self._fit_component(predictions, 'cnn_predictions')

    def _process_pivot_points(self, pivot_data: Optional[Dict[str, Any]],
                              eth_ohlcv: Dict[str, List[OHLCVData]]) -> List[float]:
        """Return pivot features from *pivot_data*, or compute them from 1m bars."""
        if pivot_data:
            # Use provided pivot data
            return self._extract_pivot_features(pivot_data)
        if '1m' in eth_ohlcv and eth_ohlcv['1m']:
            # Calculate pivot points from 1m data
            from training.williams_market_structure import WilliamsMarketStructure
            williams = WilliamsMarketStructure()
            # Convert OHLCV to numpy array
            ohlcv_array = self._ohlcv_to_array(eth_ohlcv['1m'])
            pivot_data = williams.calculate_recursive_pivot_points(ohlcv_array)
            return self._extract_pivot_features(pivot_data)
        return [0.0] * self.state_components['pivot_points']

    @staticmethod
    def _regime_vector(regime: Dict[str, Any]) -> List[float]:
        """Flatten a regime dict into 4 metrics plus a 3-way one-hot label."""
        label = regime['regime']
        return [
            regime['volatility'],
            regime['trend_strength'],
            regime['volume_trend'],
            regime['momentum'],
            1.0 if label == 'trending' else 0.0,
            1.0 if label == 'ranging' else 0.0,
            1.0 if label == 'volatile' else 0.0
        ]

    def _process_market_regime(self, eth_ohlcv: Dict[str, List[OHLCVData]],
                               btc_ohlcv: Dict[str, List[OHLCVData]]) -> List[float]:
        """Build regime features: 7 for ETH, 7 for BTC, 6 for their correlation."""
        regime_features = []

        # ETH then BTC regime analysis, both on 1h bars
        for ohlcv in (eth_ohlcv, btc_ohlcv):
            if '1h' in ohlcv and ohlcv['1h']:
                regime = self.regime_analyzer.analyze_regime(ohlcv['1h'])
                regime_features.extend(self._regime_vector(regime))
            else:
                regime_features.extend([0.0] * 7)

        # Correlation features
        regime_features.extend(self._calculate_btc_eth_correlation(eth_ohlcv, btc_ohlcv))

        return self._fit_component(regime_features, 'market_regime')

    def _normalize_tick_window(self, ticks: List[TickData], target_size: int) -> List[TickData]:
        """Return exactly *target_size* ticks by even sampling or last-tick padding."""
        if len(ticks) == target_size:
            return ticks
        if len(ticks) > target_size:
            # Sample evenly
            step = len(ticks) / target_size
            return [ticks[int(i * step)] for i in range(target_size)]

        # Pad with last tick (a zero tick if there is none)
        result = ticks.copy()
        last_tick = ticks[-1] if ticks else TickData(datetime.now(), 0, 0)
        result.extend([last_tick] * (target_size - len(result)))
        return result

    def _extract_pivot_features(self, pivot_data: Dict[str, Any]) -> List[float]:
        """Extract 33 features for each of 5 pivot levels, fitted to the pivot size.

        Per level: up to 10 recent swing points * (price, is_high, index),
        then trend strength and up/down flags. The 165 values produced are
        zero-padded to the declared 'pivot_points' size so later sections of
        the state vector keep their offsets.
        """
        features = []

        for level in range(5):  # 5 levels of recursion
            level_key = f'level_{level}'
            if level_key not in pivot_data:
                features.extend([0.0] * 33)  # 30 swing + 3 trend features
                continue

            level_data = pivot_data[level_key]

            # Swing point features: last 10 swing points, zero-padded
            recent_swings = level_data.get('swing_points', [])[-10:]
            for swing in recent_swings:
                features.extend([
                    swing['price'],
                    1.0 if swing['type'] == 'swing_high' else 0.0,
                    swing['index']
                ])
            features.extend([0.0] * (3 * (10 - len(recent_swings))))

            # Trend features
            direction = level_data.get('trend_direction')
            features.extend([
                level_data.get('trend_strength', 0.0),
                1.0 if direction == 'up' else 0.0,
                1.0 if direction == 'down' else 0.0
            ])

        return self._fit_component(features, 'pivot_points')

    def _ohlcv_to_array(self, ohlcv_data: List[OHLCVData]) -> np.ndarray:
        """Convert bars to rows of [epoch seconds, open, high, low, close, volume]."""
        return np.array([[
            bar.timestamp.timestamp(),
            bar.open,
            bar.high,
            bar.low,
            bar.close,
            bar.volume
        ] for bar in ohlcv_data])

    def _calculate_btc_eth_correlation(self, eth_ohlcv: Dict[str, List[OHLCVData]],
                                       btc_ohlcv: Dict[str, List[OHLCVData]]) -> List[float]:
        """Return 6 BTC/ETH features from the last 50 1h closes.

        Order: return correlation, current ETH/BTC ratio, ratio deviation from
        its mean, ETH/BTC volatility ratio, ETH volatility, BTC volatility.
        Returns six zeros when data is missing, too short, or an error occurs.
        """
        try:
            # Use 1h data for correlation
            if '1h' not in eth_ohlcv or '1h' not in btc_ohlcv:
                return [0.0] * 6

            eth_prices = [bar.close for bar in eth_ohlcv['1h'][-50:]]  # Last 50 hours
            btc_prices = [bar.close for bar in btc_ohlcv['1h'][-50:]]

            if len(eth_prices) < 10 or len(btc_prices) < 10:
                return [0.0] * 6

            # Align lengths
            min_len = min(len(eth_prices), len(btc_prices))
            eth_prices = eth_prices[-min_len:]
            btc_prices = btc_prices[-min_len:]

            # Calculate returns
            eth_returns = np.diff(eth_prices) / eth_prices[:-1]
            btc_returns = np.diff(btc_prices) / btc_prices[:-1]

            # Correlation
            correlation = np.corrcoef(eth_returns, btc_returns)[0, 1] if len(eth_returns) > 1 else 0.0

            # Price ratio
            current_ratio = eth_prices[-1] / btc_prices[-1] if btc_prices[-1] > 0 else 0.0
            avg_ratio = np.mean([e / b for e, b in zip(eth_prices, btc_prices) if b > 0])
            ratio_deviation = (current_ratio - avg_ratio) / avg_ratio if avg_ratio > 0 else 0.0

            # Volatility comparison
            eth_vol = np.std(eth_returns) if len(eth_returns) > 1 else 0.0
            btc_vol = np.std(btc_returns) if len(btc_returns) > 1 else 0.0
            vol_ratio = eth_vol / btc_vol if btc_vol > 0 else 1.0

            return [
                correlation,
                current_ratio,
                ratio_deviation,
                vol_ratio,
                eth_vol,
                btc_vol
            ]
        except Exception as e:
            logger.warning(f"Error calculating BTC-ETH correlation: {e}")
            return [0.0] * 6

    def _initialize_normalization_params(self) -> Dict[str, Dict[str, float]]:
        """Return the default normalization parameters per feature family."""
        return {
            'price_features': {'mean': 0.0, 'std': 1.0, 'min': -10.0, 'max': 10.0},
            'volume_features': {'mean': 0.0, 'std': 1.0, 'min': -5.0, 'max': 5.0},
            'indicator_features': {'mean': 0.0, 'std': 1.0, 'min': -3.0, 'max': 3.0},
            'cnn_features': {'mean': 0.0, 'std': 1.0, 'min': -2.0, 'max': 2.0},
            'pivot_features': {'mean': 0.0, 'std': 1.0, 'min': -5.0, 'max': 5.0}
        }

    def _normalize_state(self, state: np.ndarray) -> np.ndarray:
        """Clip the state to [-10, 10] and replace NaN/inf with finite values."""
        try:
            # Simple clipping and scaling for now
            # More sophisticated normalization can be added based on training data
            normalized_state = np.clip(state, -10.0, 10.0)
            # Replace any NaN or inf values
            normalized_state = np.nan_to_num(normalized_state, nan=0.0, posinf=10.0, neginf=-10.0)
            return normalized_state.astype(np.float32)
        except Exception as e:
            logger.error(f"Error normalizing state: {e}")
            return state.astype(np.float32)
class TickMomentumDetector:
    """Detect momentum from tick-level data"""

    def calculate_micro_momentum(self, ticks: List[TickData]) -> float:
        """Return the mean of plain and volume-weighted average price change.

        Args:
            ticks: Consecutive ticks, oldest first; each needs ``price`` and
                ``volume`` attributes.

        Returns:
            0.0 for fewer than two ticks. Otherwise the average of
            (a) the mean tick-to-tick price change and (b) the same changes
            weighted by the volume of the tick that produced each change.
            The weighted term is 0.0 when those ticks carry no volume.
        """
        if len(ticks) < 2:
            return 0.0

        # Price momentum
        prices = [tick.price for tick in ticks]
        price_changes = np.diff(prices)
        price_momentum = np.sum(price_changes) / len(price_changes) if len(price_changes) > 0 else 0.0

        # Volume-weighted momentum. Change i belongs to tick i+1, so the first
        # tick's volume is never a weight; the guard must test exactly the
        # volumes used as the divisor or a lone first-tick volume yields 0/0.
        weights = [tick.volume for tick in ticks[1:]]
        total_weight = sum(weights)
        if total_weight > 0:
            weighted_changes = [pc * v for pc, v in zip(price_changes, weights)]
            volume_momentum = sum(weighted_changes) / total_weight
        else:
            volume_momentum = 0.0

        return (price_momentum + volume_momentum) / 2.0
class TechnicalIndicatorCalculator:
    """Calculate technical indicators for OHLCV data"""

    def add_all_indicators(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of *df* with indicator columns added.

        Added columns: ``rsi``, ``macd``, ``bb_middle``, ``bb_std``,
        ``bb_upper`` and ``bb_lower``, all derived from ``df['close']``.
        Remaining gaps are forward-filled, then leading gaps are set to 0.
        The input frame is not modified.
        """
        df = df.copy()

        # RSI
        df['rsi'] = self.calculate_rsi(df['close'])

        # MACD
        df['macd'] = self.calculate_macd(df['close'])

        # Bollinger Bands (20-bar mean +/- 2 standard deviations)
        df['bb_middle'] = df['close'].rolling(20).mean()
        df['bb_std'] = df['close'].rolling(20).std()
        df['bb_upper'] = df['bb_middle'] + (df['bb_std'] * 2)
        df['bb_lower'] = df['bb_middle'] - (df['bb_std'] * 2)

        # Fill NaN values. fillna(method='forward') is rejected by pandas
        # ('forward' is not a recognised fill method), so use ffill() directly.
        return df.ffill().fillna(0)

    def calculate_rsi(self, prices: pd.Series, period: int = 14) -> pd.Series:
        """Calculate RSI from simple rolling means of gains and losses.

        Positions where the value is undefined are filled with 50.
        """
        delta = prices.diff()
        gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
        loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
        rs = gain / loss
        rsi = 100 - (100 / (1 + rs))
        return rsi.fillna(50)

    def calculate_macd(self, prices: pd.Series, fast: int = 12, slow: int = 26) -> pd.Series:
        """Calculate the MACD line: fast EMA minus slow EMA of *prices*."""
        ema_fast = prices.ewm(span=fast).mean()
        ema_slow = prices.ewm(span=slow).mean()
        macd = ema_fast - ema_slow
        return macd.fillna(0)
class MarketRegimeAnalyzer:
    """Analyze market regime from OHLCV data"""

    def analyze_regime(self, ohlcv_data: List[OHLCVData]) -> Dict[str, Any]:
        """Classify the market regime from the last 50 bars.

        Args:
            ohlcv_data: Bars, oldest first; each needs ``close`` and ``volume``.

        Returns:
            Dict with keys ``regime`` ('volatile', 'trending', 'ranging', or
            'unknown' when fewer than 20 bars are given), ``volatility``
            (std of bar-to-bar returns, in percent), ``trend_strength``,
            ``volume_trend`` and ``momentum``.
        """
        if len(ohlcv_data) < 20:
            return {
                'regime': 'unknown',
                'volatility': 0.0,
                'trend_strength': 0.0,
                'volume_trend': 0.0,
                'momentum': 0.0
            }

        recent = ohlcv_data[-50:]  # Last 50 bars
        prices = [bar.close for bar in recent]
        volumes = [bar.volume for bar in recent]

        # Volatility: only use returns whose previous price is positive, so a
        # zero price cannot inject NaN/inf into the standard deviation.
        price_arr = np.asarray(prices, dtype=float)
        prev_prices = price_arr[:-1]
        usable = prev_prices > 0
        returns = np.diff(price_arr)[usable] / prev_prices[usable]
        volatility = float(np.std(returns) * 100) if returns.size else 0.0  # Percentage volatility

        # Trend strength: gap between 10-bar and 30-bar mean, relative to the latter
        sma_short = np.mean(prices[-10:])
        sma_long = np.mean(prices[-30:])
        trend_strength = abs(sma_short - sma_long) / sma_long if sma_long > 0 else 0.0

        # Volume trend
        volume_ma_short = np.mean(volumes[-10:])
        volume_ma_long = np.mean(volumes[-30:])
        volume_trend = (volume_ma_short - volume_ma_long) / volume_ma_long if volume_ma_long > 0 else 0.0

        # Momentum over the last 10 bars
        momentum = (prices[-1] - prices[-10]) / prices[-10] if len(prices) >= 10 and prices[-10] > 0 else 0.0

        # Determine regime
        if volatility > 3.0:  # High volatility
            regime = 'volatile'
        elif abs(momentum) > 0.02:  # Strong momentum
            regime = 'trending'
        else:
            regime = 'ranging'

        return {
            'regime': regime,
            'volatility': volatility,
            'trend_strength': trend_strength,
            'volume_trend': volume_trend,
            'momentum': momentum
        }

    def get_state_info(self) -> Dict[str, Any]:
        """Get information about the state structure"""
        # NOTE(review): this method reads self.config, self.tick_window_seconds
        # and self.ohlcv_window_bars, none of which are assigned in this class.
        # EnhancedRLStateBuilder.__init__ assigns attributes with these names,
        # but its `config` is annotated as a dict, while `.total_size` and the
        # fields below match StateComponentConfig. Looks misplaced - confirm
        # the intended owner before relying on it.
        return {
            'total_size': self.config.total_size,
            'components': {
                'eth_ticks': self.config.eth_ticks,
                'eth_1s_ohlcv': self.config.eth_1s_ohlcv,
                'eth_1m_ohlcv': self.config.eth_1m_ohlcv,
                'eth_1h_ohlcv': self.config.eth_1h_ohlcv,
                'eth_1d_ohlcv': self.config.eth_1d_ohlcv,
                'btc_reference': self.config.btc_reference,
                'cnn_features': self.config.cnn_features,
                'cnn_predictions': self.config.cnn_predictions,
                'pivot_points': self.config.pivot_points,
                'market_regime': self.config.market_regime,
            },
            'data_windows': {
                'tick_window_seconds': self.tick_window_seconds,
                'ohlcv_window_bars': self.ohlcv_window_bars,
            }
        }