RL module possibly working

Dobromir Popov
2025-05-28 23:42:06 +03:00
parent de01d3665c
commit 6b7d7aec81
16 changed files with 5118 additions and 580 deletions

training/cnn_rl_bridge.py Normal file

@ -0,0 +1,219 @@
"""
CNN-RL Bridge Module
This module provides the interface between CNN models and RL training,
extracting hidden features and predictions from CNN models for use in RL state building.
"""
import logging
import numpy as np
import torch
import torch.nn as nn
from typing import Dict, List, Optional, Tuple, Any
from datetime import datetime, timedelta
logger = logging.getLogger(__name__)
class CNNRLBridge:
"""Bridge between CNN models and RL training for feature extraction"""
def __init__(self, config: Dict):
"""Initialize CNN-RL bridge"""
self.config = config
self.cnn_models = {}
self.feature_cache = {}
self.cache_timeout = 60 # Cache features for 60 seconds
# Initialize CNN model registry if available
self._initialize_cnn_models()
logger.info("CNN-RL Bridge initialized")
def _initialize_cnn_models(self):
"""Initialize CNN models from config or model registry"""
try:
# Try to load CNN models from config
if hasattr(self.config, 'cnn_models') and self.config.cnn_models:
for model_name, model_config in self.config.cnn_models.items():
try:
# Load CNN model (implementation would depend on your CNN architecture)
model = self._load_cnn_model(model_name, model_config)
if model:
self.cnn_models[model_name] = model
logger.info(f"Loaded CNN model: {model_name}")
except Exception as e:
logger.warning(f"Failed to load CNN model {model_name}: {e}")
if not self.cnn_models:
logger.info("No CNN models available - RL will train without CNN features")
except Exception as e:
logger.warning(f"Error initializing CNN models: {e}")
def _load_cnn_model(self, model_name: str, model_config: Dict) -> Optional[nn.Module]:
"""Load a CNN model from configuration"""
try:
# This would implement actual CNN model loading
# For now, return None to indicate no models available
# In your implementation, this would load your specific CNN architecture
logger.info(f"CNN model loading framework ready for {model_name}")
return None
except Exception as e:
logger.error(f"Error loading CNN model {model_name}: {e}")
return None
def get_latest_features_for_symbol(self, symbol: str) -> Optional[Dict[str, Any]]:
"""Get latest CNN features and predictions for a symbol"""
try:
# Check cache first
cache_key = f"{symbol}_{datetime.now().strftime('%Y%m%d_%H%M')}"
if cache_key in self.feature_cache:
cached_data = self.feature_cache[cache_key]
if (datetime.now() - cached_data['timestamp']).total_seconds() < self.cache_timeout:
return cached_data['features']
# Generate new features if models available
if self.cnn_models:
features = self._extract_cnn_features_for_symbol(symbol)
# Cache the features
self.feature_cache[cache_key] = {
'timestamp': datetime.now(),
'features': features
}
# Clean old cache entries
self._cleanup_cache()
return features
return None
except Exception as e:
logger.warning(f"Error getting CNN features for {symbol}: {e}")
return None
def _extract_cnn_features_for_symbol(self, symbol: str) -> Dict[str, Any]:
"""Extract CNN hidden features and predictions for a symbol"""
try:
extracted_features = {
'hidden_features': {},
'predictions': {}
}
for model_name, model in self.cnn_models.items():
try:
# Extract features from each CNN model
hidden_features, predictions = self._extract_model_features(model, symbol)
if hidden_features is not None:
extracted_features['hidden_features'][model_name] = hidden_features
if predictions is not None:
extracted_features['predictions'][model_name] = predictions
except Exception as e:
logger.warning(f"Error extracting features from {model_name}: {e}")
return extracted_features
except Exception as e:
logger.error(f"Error extracting CNN features for {symbol}: {e}")
return {'hidden_features': {}, 'predictions': {}}
def _extract_model_features(self, model: nn.Module, symbol: str) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
"""Extract hidden features and predictions from a specific CNN model"""
try:
# This would implement the actual feature extraction from your CNN models
# The implementation depends on your specific CNN architecture
# For now, return mock data to show the structure
# In real implementation, this would:
# 1. Get market data for the model
# 2. Run forward pass through CNN
# 3. Extract hidden layer activations
# 4. Get model predictions
# Mock hidden features (last hidden layer of CNN)
hidden_features = np.random.random(512).astype(np.float32)
# Mock predictions for different timeframes
# [1s_pred, 1m_pred, 1h_pred, 1d_pred]: one up-move probability per timeframe
predictions = np.array([
0.45, # 1s prediction (probability of up move)
0.52, # 1m prediction
0.38, # 1h prediction
0.61 # 1d prediction
]).astype(np.float32)
logger.debug(f"Extracted CNN features for {symbol}: {len(hidden_features)} hidden, {len(predictions)} predictions")
return hidden_features, predictions
except Exception as e:
logger.warning(f"Error extracting features from model: {e}")
return None, None
def _cleanup_cache(self):
"""Clean up old cache entries"""
try:
current_time = datetime.now()
expired_keys = []
for key, data in self.feature_cache.items():
if (current_time - data['timestamp']).total_seconds() > self.cache_timeout * 2:
expired_keys.append(key)
for key in expired_keys:
del self.feature_cache[key]
except Exception as e:
logger.warning(f"Error cleaning up feature cache: {e}")
def register_cnn_model(self, model_name: str, model: nn.Module):
"""Register a CNN model for feature extraction"""
try:
self.cnn_models[model_name] = model
logger.info(f"Registered CNN model: {model_name}")
except Exception as e:
logger.error(f"Error registering CNN model {model_name}: {e}")
def unregister_cnn_model(self, model_name: str):
"""Unregister a CNN model"""
try:
if model_name in self.cnn_models:
del self.cnn_models[model_name]
logger.info(f"Unregistered CNN model: {model_name}")
except Exception as e:
logger.error(f"Error unregistering CNN model {model_name}: {e}")
def get_available_models(self) -> List[str]:
"""Get list of available CNN models"""
return list(self.cnn_models.keys())
def is_model_available(self, model_name: str) -> bool:
"""Check if a specific CNN model is available"""
return model_name in self.cnn_models
def get_feature_dimensions(self) -> Dict[str, int]:
"""Get the dimensions of features extracted from CNN models"""
return {
'hidden_features_per_model': 512,
'predictions_per_model': 4, # 1s, 1m, 1h, 1d
'total_models': len(self.cnn_models)
}
def validate_cnn_integration(self) -> Dict[str, Any]:
"""Validate CNN integration status"""
status = {
'models_available': len(self.cnn_models),
'models_list': list(self.cnn_models.keys()),
'cache_entries': len(self.feature_cache),
'integration_ready': len(self.cnn_models) > 0,
'expected_feature_size': len(self.cnn_models) * 512, # hidden features
'expected_prediction_size': len(self.cnn_models) * 4 # predictions
}
return status
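
A minimal usage sketch (illustrative, not part of this commit): the TinyCNN class and the forward hook are assumptions standing in for the project's real CNN architecture and the not-yet-implemented _extract_model_features path; they show one way to capture a 512-dim hidden layer alongside predictions.

import torch
import torch.nn as nn

class TinyCNN(nn.Module):
    """Hypothetical stand-in CNN: conv stack -> 512-dim hidden layer -> 4 outputs."""
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv1d(8, 32, kernel_size=3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool1d(1), nn.Flatten(),
            nn.Linear(32, 512), nn.ReLU()
        )
        self.head = nn.Linear(512, 4)

    def forward(self, x):
        return self.head(self.features(x))

bridge = CNNRLBridge(config={})
model = TinyCNN().eval()
bridge.register_cnn_model('tiny_cnn', model)

# Capture the 512-dim hidden activations with a forward hook -- one way
# _extract_model_features could obtain real (non-mock) features
captured = {}
model.features.register_forward_hook(lambda m, i, o: captured.update(hidden=o.detach()))
with torch.no_grad():
    preds = model(torch.randn(1, 8, 300))  # 300 time steps x 8 channels
print(captured['hidden'].shape, preds.shape)  # torch.Size([1, 512]) torch.Size([1, 4])
print(bridge.validate_cnn_integration())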

training/enhanced_rl_state_builder.py Normal file

@ -0,0 +1,708 @@
"""
Enhanced RL State Builder for Comprehensive Market Data Integration
This module implements the specification requirements for RL training with:
- 300s of raw tick data for momentum detection
- Multi-timeframe OHLCV data (1s, 1m, 1h, 1d) for ETH and BTC
- CNN hidden layer features integration
- CNN predictions from all timeframes
- Pivot point predictions using Williams market structure
- Market regime analysis
State Vector Components:
- ETH tick data: ~3000 features (300s * 10 features/tick)
- ETH OHLCV 1s: ~2400 features (300 bars * 8 features)
- ETH OHLCV 1m: ~2400 features (300 bars * 8 features)
- ETH OHLCV 1h: ~2400 features (300 bars * 8 features)
- ETH OHLCV 1d: ~2400 features (300 bars * 8 features)
- BTC reference: ~2400 features (300 bars * 8 features)
- CNN features: ~512 features (hidden layer)
- CNN predictions: ~16 features (4 timeframes * 4 outputs)
- Pivot points: ~250 features (Williams structure)
- Market regime: ~20 features
Total: ~15,800 features
"""
import logging
import numpy as np
import pandas as pd
try:
import ta
except ImportError:
logger = logging.getLogger(__name__)
logger.warning("TA-Lib not available, using pandas for technical indicators")
ta = None
from typing import Dict, List, Optional, Tuple, Any
from datetime import datetime, timedelta
from dataclasses import dataclass
from core.universal_data_adapter import UniversalDataStream
logger = logging.getLogger(__name__)
@dataclass
class TickData:
"""Tick data structure"""
timestamp: datetime
price: float
volume: float
bid: float = 0.0
ask: float = 0.0
@property
def spread(self) -> float:
return self.ask - self.bid if self.ask > 0 and self.bid > 0 else 0.0
@dataclass
class OHLCVData:
"""OHLCV data structure"""
timestamp: datetime
open: float
high: float
low: float
close: float
volume: float
# Technical indicators (optional)
rsi: Optional[float] = None
macd: Optional[float] = None
bb_upper: Optional[float] = None
bb_lower: Optional[float] = None
sma_20: Optional[float] = None
ema_12: Optional[float] = None
atr: Optional[float] = None
@dataclass
class StateComponentConfig:
"""Configuration for state component sizes"""
eth_ticks: int = 3000 # 300s * 10 features per tick
eth_1s_ohlcv: int = 2400 # 300 bars * 8 features (OHLCV + indicators)
eth_1m_ohlcv: int = 2400 # 300 bars * 8 features
eth_1h_ohlcv: int = 2400 # 300 bars * 8 features
eth_1d_ohlcv: int = 2400 # 300 bars * 8 features
btc_reference: int = 2400 # BTC reference data
cnn_features: int = 512 # CNN hidden layer features
cnn_predictions: int = 16 # CNN predictions (4 timeframes * 4 outputs)
pivot_points: int = 250 # Recursive pivot points (5 levels * 50 points)
market_regime: int = 20 # Market regime features
@property
def total_size(self) -> int:
"""Calculate total state size"""
return (self.eth_ticks + self.eth_1s_ohlcv + self.eth_1m_ohlcv +
self.eth_1h_ohlcv + self.eth_1d_ohlcv + self.btc_reference +
self.cnn_features + self.cnn_predictions + self.pivot_points +
self.market_regime)
class EnhancedRLStateBuilder:
"""
Comprehensive RL state builder implementing specification requirements
Features:
- 300s tick data processing with momentum detection
- Multi-timeframe OHLCV integration
- CNN hidden layer feature extraction
- Pivot point calculation and integration
- Market regime analysis
- BTC reference data processing
"""
def __init__(self, config: Dict[str, Any]):
self.config = config
# Data windows
self.tick_window_seconds = 300 # 5 minutes of tick data
self.ohlcv_window_bars = 300 # 300 bars for each timeframe
# State component sizes
self.state_components = {
'eth_ticks': 300 * 10, # 3000 features: tick data with derived features
'eth_1s_ohlcv': 300 * 8, # 2400 features: OHLCV + indicators
'eth_1m_ohlcv': 300 * 8, # 2400 features: OHLCV + indicators
'eth_1h_ohlcv': 300 * 8, # 2400 features: OHLCV + indicators
'eth_1d_ohlcv': 300 * 8, # 2400 features: OHLCV + indicators
'btc_reference': 300 * 8, # 2400 features: BTC reference data
'cnn_features': 512, # 512 features: CNN hidden layer
'cnn_predictions': 16, # 16 features: CNN predictions (4 timeframes * 4 outputs)
'pivot_points': 250, # 250 features: Williams market structure
'market_regime': 20 # 20 features: Market regime indicators
}
self.total_state_size = sum(self.state_components.values())
# Data buffers for maintaining windows
self.tick_buffers = {}
self.ohlcv_buffers = {}
# Normalization parameters
self.normalization_params = self._initialize_normalization_params()
# Feature extractors
self.momentum_detector = TickMomentumDetector()
self.indicator_calculator = TechnicalIndicatorCalculator()
self.regime_analyzer = MarketRegimeAnalyzer()
logger.info(f"Enhanced RL State Builder initialized")
logger.info(f"Total state size: {self.total_state_size} features")
logger.info(f"State components: {self.state_components}")
def build_rl_state(self,
eth_ticks: List[TickData],
eth_ohlcv: Dict[str, List[OHLCVData]],
btc_ohlcv: Dict[str, List[OHLCVData]],
cnn_hidden_features: Optional[Dict[str, np.ndarray]] = None,
cnn_predictions: Optional[Dict[str, np.ndarray]] = None,
pivot_data: Optional[Dict[str, Any]] = None) -> np.ndarray:
"""
Build comprehensive RL state vector from all data sources
Args:
eth_ticks: List of ETH tick data (last 300s)
eth_ohlcv: Dict of ETH OHLCV data by timeframe
btc_ohlcv: Dict of BTC OHLCV data by timeframe
cnn_hidden_features: CNN hidden layer features by timeframe
cnn_predictions: CNN predictions by timeframe
pivot_data: Pivot point data from Williams analysis
Returns:
np.ndarray: Comprehensive state vector (~15,800 features)
"""
try:
state_vector = []
# 1. Process ETH tick data (3000 features)
tick_features = self._process_tick_data(eth_ticks)
state_vector.extend(tick_features)
# 2. Process ETH multi-timeframe OHLCV (9600 features total)
for timeframe in ['1s', '1m', '1h', '1d']:
if timeframe in eth_ohlcv:
ohlcv_features = self._process_ohlcv_data(
eth_ohlcv[timeframe], timeframe, symbol='ETH'
)
else:
ohlcv_features = np.zeros(self.state_components[f'eth_{timeframe}_ohlcv'])
state_vector.extend(ohlcv_features)
# 3. Process BTC reference data (2400 features)
btc_features = self._process_btc_reference_data(btc_ohlcv)
state_vector.extend(btc_features)
# 4. Process CNN hidden layer features (512 features)
cnn_hidden = self._process_cnn_hidden_features(cnn_hidden_features)
state_vector.extend(cnn_hidden)
# 5. Process CNN predictions (16 features)
cnn_pred = self._process_cnn_predictions(cnn_predictions)
state_vector.extend(cnn_pred)
# 6. Process pivot points (250 features)
pivot_features = self._process_pivot_points(pivot_data, eth_ohlcv)
state_vector.extend(pivot_features)
# 7. Process market regime features (20 features)
regime_features = self._process_market_regime(eth_ohlcv, btc_ohlcv)
state_vector.extend(regime_features)
# Convert to numpy array and validate size
state_array = np.array(state_vector, dtype=np.float32)
if len(state_array) != self.total_state_size:
logger.warning(f"State size mismatch: expected {self.total_state_size}, got {len(state_array)}")
# Pad or truncate to expected size
if len(state_array) < self.total_state_size:
padding = np.zeros(self.total_state_size - len(state_array))
state_array = np.concatenate([state_array, padding])
else:
state_array = state_array[:self.total_state_size]
# Apply normalization
state_array = self._normalize_state(state_array)
return state_array
except Exception as e:
logger.error(f"Error building RL state: {e}")
# Return zero state on error
return np.zeros(self.total_state_size, dtype=np.float32)
def _process_tick_data(self, ticks: List[TickData]) -> List[float]:
"""Process raw tick data into features for momentum detection"""
features = []
if not ticks or len(ticks) < 10:
# Return zeros if insufficient data
return [0.0] * self.state_components['eth_ticks']
# Ensure we have exactly 300 data points (pad or sample)
processed_ticks = self._normalize_tick_window(ticks, 300)
for i, tick in enumerate(processed_ticks):
# Basic tick features
tick_features = [
tick.price,
tick.volume,
tick.bid,
tick.ask,
tick.spread
]
# Derived features
if i > 0:
prev_tick = processed_ticks[i-1]
price_change = (tick.price - prev_tick.price) / prev_tick.price if prev_tick.price > 0 else 0
volume_change = (tick.volume - prev_tick.volume) / prev_tick.volume if prev_tick.volume > 0 else 0
tick_features.extend([
price_change,
volume_change,
tick.price / prev_tick.price - 1.0 if prev_tick.price > 0 else 0, # Price ratio
np.log(tick.volume / prev_tick.volume) if prev_tick.volume > 0 else 0, # Log volume ratio
self.momentum_detector.calculate_micro_momentum(processed_ticks[max(0, i-5):i+1])
])
else:
tick_features.extend([0.0, 0.0, 0.0, 0.0, 0.0])
features.extend(tick_features)
return features[:self.state_components['eth_ticks']]
def _process_ohlcv_data(self, ohlcv_data: List[OHLCVData],
timeframe: str, symbol: str = 'ETH') -> List[float]:
"""Process OHLCV data with technical indicators"""
features = []
if not ohlcv_data or len(ohlcv_data) < 20:
component_key = f'{symbol.lower()}_{timeframe}_ohlcv' if symbol == 'ETH' else 'btc_reference'
return [0.0] * self.state_components[component_key]
# Convert to DataFrame for indicator calculation
df = pd.DataFrame([{
'timestamp': bar.timestamp,
'open': bar.open,
'high': bar.high,
'low': bar.low,
'close': bar.close,
'volume': bar.volume
} for bar in ohlcv_data[-self.ohlcv_window_bars:]])
# Calculate technical indicators
df = self.indicator_calculator.add_all_indicators(df)
# Ensure we have exactly 300 bars
if len(df) < 300:
# Pad with last known values
last_row = df.iloc[-1:].copy()
padding_rows = []
for _ in range(300 - len(df)):
padding_rows.append(last_row)
if padding_rows:
df = pd.concat([df] + padding_rows, ignore_index=True)
else:
df = df.tail(300)
# Extract features for each bar
feature_columns = ['open', 'high', 'low', 'close', 'volume', 'rsi', 'macd', 'bb_middle']
for _, row in df.iterrows():
bar_features = []
for col in feature_columns:
if col in row and not pd.isna(row[col]):
bar_features.append(float(row[col]))
else:
bar_features.append(0.0)
features.extend(bar_features)
component_key = f'{symbol.lower()}_{timeframe}_ohlcv' if symbol == 'ETH' else 'btc_reference'
return features[:self.state_components[component_key]]
def _process_btc_reference_data(self, btc_ohlcv: Dict[str, List[OHLCVData]]) -> List[float]:
"""Process BTC reference data (using 1h timeframe as primary)"""
if '1h' in btc_ohlcv and btc_ohlcv['1h']:
return self._process_ohlcv_data(btc_ohlcv['1h'], '1h', 'BTC')
elif '1m' in btc_ohlcv and btc_ohlcv['1m']:
return self._process_ohlcv_data(btc_ohlcv['1m'], '1m', 'BTC')
else:
return [0.0] * self.state_components['btc_reference']
def _process_cnn_hidden_features(self, cnn_features: Optional[Dict[str, np.ndarray]]) -> List[float]:
"""Process CNN hidden layer features"""
if not cnn_features:
return [0.0] * self.state_components['cnn_features']
# Combine features from all timeframes
combined_features = []
timeframes = ['1s', '1m', '1h', '1d']
features_per_timeframe = self.state_components['cnn_features'] // len(timeframes)
for tf in timeframes:
if tf in cnn_features and cnn_features[tf] is not None:
tf_features = cnn_features[tf].flatten()
# Truncate or pad to fit allocation
if len(tf_features) >= features_per_timeframe:
combined_features.extend(tf_features[:features_per_timeframe])
else:
combined_features.extend(tf_features)
combined_features.extend([0.0] * (features_per_timeframe - len(tf_features)))
else:
combined_features.extend([0.0] * features_per_timeframe)
return combined_features[:self.state_components['cnn_features']]
def _process_cnn_predictions(self, cnn_predictions: Optional[Dict[str, np.ndarray]]) -> List[float]:
"""Process CNN predictions from all timeframes"""
if not cnn_predictions:
return [0.0] * self.state_components['cnn_predictions']
predictions = []
timeframes = ['1s', '1m', '1h', '1d']
for tf in timeframes:
if tf in cnn_predictions and cnn_predictions[tf] is not None:
pred = cnn_predictions[tf].flatten()
# Expecting 4 outputs per timeframe (BUY, SELL, HOLD, confidence)
if len(pred) >= 4:
predictions.extend(pred[:4])
else:
predictions.extend(pred)
predictions.extend([0.0] * (4 - len(pred)))
else:
predictions.extend([0.0, 0.0, 1.0, 0.0]) # Default to HOLD with 0 confidence
return predictions[:self.state_components['cnn_predictions']]
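# Illustrative layout: if only a 1m prediction [0.2, 0.1, 0.6, 0.55] is available,
# the 16-vector becomes [0,0,1,0, 0.2,0.1,0.6,0.55, 0,0,1,0, 0,0,1,0]
# (missing timeframes default to HOLD with zero confidence).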
def _process_pivot_points(self, pivot_data: Optional[Dict[str, Any]],
eth_ohlcv: Dict[str, List[OHLCVData]]) -> List[float]:
"""Process pivot points using Williams market structure"""
if pivot_data:
# Use provided pivot data
return self._extract_pivot_features(pivot_data)
elif '1m' in eth_ohlcv and eth_ohlcv['1m']:
# Calculate pivot points from 1m data
from training.williams_market_structure import WilliamsMarketStructure
williams = WilliamsMarketStructure()
# Convert OHLCV to numpy array
ohlcv_array = self._ohlcv_to_array(eth_ohlcv['1m'])
pivot_data = williams.calculate_recursive_pivot_points(ohlcv_array)
return self._extract_pivot_features(pivot_data)
else:
return [0.0] * self.state_components['pivot_points']
def _process_market_regime(self, eth_ohlcv: Dict[str, List[OHLCVData]],
btc_ohlcv: Dict[str, List[OHLCVData]]) -> List[float]:
"""Process market regime indicators"""
regime_features = []
# ETH regime analysis
if '1h' in eth_ohlcv and eth_ohlcv['1h']:
eth_regime = self.regime_analyzer.analyze_regime(eth_ohlcv['1h'])
regime_features.extend([
eth_regime['volatility'],
eth_regime['trend_strength'],
eth_regime['volume_trend'],
eth_regime['momentum'],
1.0 if eth_regime['regime'] == 'trending' else 0.0,
1.0 if eth_regime['regime'] == 'ranging' else 0.0,
1.0 if eth_regime['regime'] == 'volatile' else 0.0
])
else:
regime_features.extend([0.0] * 7)
# BTC regime analysis
if '1h' in btc_ohlcv and btc_ohlcv['1h']:
btc_regime = self.regime_analyzer.analyze_regime(btc_ohlcv['1h'])
regime_features.extend([
btc_regime['volatility'],
btc_regime['trend_strength'],
btc_regime['volume_trend'],
btc_regime['momentum'],
1.0 if btc_regime['regime'] == 'trending' else 0.0,
1.0 if btc_regime['regime'] == 'ranging' else 0.0,
1.0 if btc_regime['regime'] == 'volatile' else 0.0
])
else:
regime_features.extend([0.0] * 7)
# Correlation features
correlation_features = self._calculate_btc_eth_correlation(eth_ohlcv, btc_ohlcv)
regime_features.extend(correlation_features)
return regime_features[:self.state_components['market_regime']]
def _normalize_tick_window(self, ticks: List[TickData], target_size: int) -> List[TickData]:
"""Normalize tick window to target size"""
if len(ticks) == target_size:
return ticks
elif len(ticks) > target_size:
# Sample evenly
step = len(ticks) / target_size
indices = [int(i * step) for i in range(target_size)]
return [ticks[i] for i in indices]
else:
# Pad with last tick
result = ticks.copy()
last_tick = ticks[-1] if ticks else TickData(datetime.now(), 0, 0)
while len(result) < target_size:
result.append(last_tick)
return result
def _extract_pivot_features(self, pivot_data: Dict[str, Any]) -> List[float]:
    """Extract features from Williams market structure levels"""
    features = []
    for level in range(5):  # 5 levels of recursion
        level_data = pivot_data.get(f'level_{level}')
        if level_data is not None:
            # Swing point features: last 10 swings * 3 features each
            recent_swings = list(level_data.swing_points[-10:])
            for swing in recent_swings:
                features.extend([
                    swing.price,
                    1.0 if swing.swing_type.value == 'swing_high' else 0.0,
                    float(swing.index)
                ])
            # Pad if fewer than 10 swings
            features.extend([0.0] * (3 * (10 - len(recent_swings))))
            # Trend features
            trend = level_data.trend_analysis
            features.extend([
                trend.strength,
                1.0 if trend.direction.value == 'up' else 0.0,
                1.0 if trend.direction.value == 'down' else 0.0
            ])
        else:
            features.extend([0.0] * 33)  # 30 swing + 3 trend features
    # Pad to the full allocation so downstream sizing stays consistent
    features.extend([0.0] * (self.state_components['pivot_points'] - len(features)))
    return features[:self.state_components['pivot_points']]
def _ohlcv_to_array(self, ohlcv_data: List[OHLCVData]) -> np.ndarray:
"""Convert OHLCV data to numpy array"""
return np.array([[
bar.timestamp.timestamp(),
bar.open,
bar.high,
bar.low,
bar.close,
bar.volume
] for bar in ohlcv_data])
def _calculate_btc_eth_correlation(self, eth_ohlcv: Dict[str, List[OHLCVData]],
btc_ohlcv: Dict[str, List[OHLCVData]]) -> List[float]:
"""Calculate BTC-ETH correlation features"""
try:
# Use 1h data for correlation
if '1h' not in eth_ohlcv or '1h' not in btc_ohlcv:
return [0.0] * 6
eth_prices = [bar.close for bar in eth_ohlcv['1h'][-50:]] # Last 50 hours
btc_prices = [bar.close for bar in btc_ohlcv['1h'][-50:]]
if len(eth_prices) < 10 or len(btc_prices) < 10:
return [0.0] * 6
# Align lengths
min_len = min(len(eth_prices), len(btc_prices))
eth_prices = eth_prices[-min_len:]
btc_prices = btc_prices[-min_len:]
# Calculate returns
eth_returns = np.diff(eth_prices) / eth_prices[:-1]
btc_returns = np.diff(btc_prices) / btc_prices[:-1]
# Correlation
correlation = np.corrcoef(eth_returns, btc_returns)[0, 1] if len(eth_returns) > 1 else 0.0
# Price ratio
current_ratio = eth_prices[-1] / btc_prices[-1] if btc_prices[-1] > 0 else 0.0
avg_ratio = np.mean([e/b for e, b in zip(eth_prices, btc_prices) if b > 0])
ratio_deviation = (current_ratio - avg_ratio) / avg_ratio if avg_ratio > 0 else 0.0
# Volatility comparison
eth_vol = np.std(eth_returns) if len(eth_returns) > 1 else 0.0
btc_vol = np.std(btc_returns) if len(btc_returns) > 1 else 0.0
vol_ratio = eth_vol / btc_vol if btc_vol > 0 else 1.0
return [
correlation,
current_ratio,
ratio_deviation,
vol_ratio,
eth_vol,
btc_vol
]
except Exception as e:
logger.warning(f"Error calculating BTC-ETH correlation: {e}")
return [0.0] * 6
def _initialize_normalization_params(self) -> Dict[str, Dict[str, float]]:
"""Initialize normalization parameters for different feature types"""
return {
'price_features': {'mean': 0.0, 'std': 1.0, 'min': -10.0, 'max': 10.0},
'volume_features': {'mean': 0.0, 'std': 1.0, 'min': -5.0, 'max': 5.0},
'indicator_features': {'mean': 0.0, 'std': 1.0, 'min': -3.0, 'max': 3.0},
'cnn_features': {'mean': 0.0, 'std': 1.0, 'min': -2.0, 'max': 2.0},
'pivot_features': {'mean': 0.0, 'std': 1.0, 'min': -5.0, 'max': 5.0}
}
def _normalize_state(self, state: np.ndarray) -> np.ndarray:
"""Apply normalization to state vector"""
try:
# Simple clipping and scaling for now
# More sophisticated normalization can be added based on training data
normalized_state = np.clip(state, -10.0, 10.0)
# Replace any NaN or inf values
normalized_state = np.nan_to_num(normalized_state, nan=0.0, posinf=10.0, neginf=-10.0)
return normalized_state.astype(np.float32)
except Exception as e:
logger.error(f"Error normalizing state: {e}")
return state.astype(np.float32)
class TickMomentumDetector:
"""Detect momentum from tick-level data"""
def calculate_micro_momentum(self, ticks: List[TickData]) -> float:
"""Calculate micro-momentum from tick sequence"""
if len(ticks) < 2:
return 0.0
# Price momentum
prices = [tick.price for tick in ticks]
price_changes = np.diff(prices)
price_momentum = np.sum(price_changes) / len(price_changes) if len(price_changes) > 0 else 0.0
# Volume-weighted momentum
volumes = [tick.volume for tick in ticks]
trailing_volume = sum(volumes[1:])
if trailing_volume > 0:
    weighted_changes = [pc * v for pc, v in zip(price_changes, volumes[1:])]
    volume_momentum = sum(weighted_changes) / trailing_volume
else:
volume_momentum = 0.0
return (price_momentum + volume_momentum) / 2.0
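# Worked example (illustrative): prices [100, 101, 103] -> changes [1, 2],
# price_momentum = 1.5; volumes [10, 10, 10] -> volume_momentum =
# (1*10 + 2*10) / 20 = 1.5; micro-momentum = (1.5 + 1.5) / 2 = 1.5.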
class TechnicalIndicatorCalculator:
"""Calculate technical indicators for OHLCV data"""
def add_all_indicators(self, df: pd.DataFrame) -> pd.DataFrame:
"""Add all technical indicators to DataFrame"""
df = df.copy()
# RSI
df['rsi'] = self.calculate_rsi(df['close'])
# MACD
df['macd'] = self.calculate_macd(df['close'])
# Bollinger Bands
df['bb_middle'] = df['close'].rolling(20).mean()
df['bb_std'] = df['close'].rolling(20).std()
df['bb_upper'] = df['bb_middle'] + (df['bb_std'] * 2)
df['bb_lower'] = df['bb_middle'] - (df['bb_std'] * 2)
# Fill NaN values (forward-fill, then zero any leading gaps)
df = df.ffill().fillna(0)
return df
def calculate_rsi(self, prices: pd.Series, period: int = 14) -> pd.Series:
"""Calculate RSI"""
delta = prices.diff()
gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
rs = gain / loss
rsi = 100 - (100 / (1 + rs))
return rsi.fillna(50)
def calculate_macd(self, prices: pd.Series, fast: int = 12, slow: int = 26) -> pd.Series:
"""Calculate MACD"""
ema_fast = prices.ewm(span=fast).mean()
ema_slow = prices.ewm(span=slow).mean()
macd = ema_fast - ema_slow
return macd.fillna(0)
class MarketRegimeAnalyzer:
"""Analyze market regime from OHLCV data"""
def analyze_regime(self, ohlcv_data: List[OHLCVData]) -> Dict[str, Any]:
"""Analyze market regime"""
if len(ohlcv_data) < 20:
return {
'regime': 'unknown',
'volatility': 0.0,
'trend_strength': 0.0,
'volume_trend': 0.0,
'momentum': 0.0
}
prices = [bar.close for bar in ohlcv_data[-50:]] # Last 50 bars
volumes = [bar.volume for bar in ohlcv_data[-50:]]
# Calculate volatility
returns = np.diff(prices) / prices[:-1]
volatility = np.std(returns) * 100 # Percentage volatility
# Calculate trend strength
sma_short = np.mean(prices[-10:])
sma_long = np.mean(prices[-30:])
trend_strength = abs(sma_short - sma_long) / sma_long if sma_long > 0 else 0.0
# Volume trend
volume_ma_short = np.mean(volumes[-10:])
volume_ma_long = np.mean(volumes[-30:])
volume_trend = (volume_ma_short - volume_ma_long) / volume_ma_long if volume_ma_long > 0 else 0.0
# Momentum
momentum = (prices[-1] - prices[-10]) / prices[-10] if len(prices) >= 10 and prices[-10] > 0 else 0.0
# Determine regime
if volatility > 3.0: # High volatility
regime = 'volatile'
elif abs(momentum) > 0.02: # Strong momentum
regime = 'trending'
else:
regime = 'ranging'
return {
'regime': regime,
'volatility': volatility,
'trend_strength': trend_strength,
'volume_trend': volume_trend,
'momentum': momentum
}
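# Illustrative classification: volatility 3.5% -> 'volatile' regardless of momentum;
# volatility 1.2% with momentum +2.5% -> 'trending'; volatility 1.2% with
# momentum +0.5% -> 'ranging'.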
def get_state_info(self) -> Dict[str, Any]:
    """Get information about the state structure"""
    return {
        'total_size': self.total_state_size,
        'components': dict(self.state_components),
        'data_windows': {
            'tick_window_seconds': self.tick_window_seconds,
            'ohlcv_window_bars': self.ohlcv_window_bars,
        }
    }
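
A smoke-test sketch (illustrative, not part of this commit) that builds one state vector from synthetic data and checks its size. It assumes training.williams_market_structure is importable; the flat synthetic prices and single BTC timeframe are arbitrary stand-ins.

from datetime import datetime, timedelta

builder = EnhancedRLStateBuilder(config={})
now = datetime.now()
ticks = [TickData(now + timedelta(seconds=i), 3000.0 + 0.1 * i, 1.0) for i in range(300)]
bars = [OHLCVData(now + timedelta(minutes=i), 3000.0 + i, 3001.0 + i, 2999.0 + i, 3000.5 + i, 10.0)
        for i in range(60)]
eth_ohlcv = {tf: list(bars) for tf in ['1s', '1m', '1h', '1d']}
state = builder.build_rl_state(ticks, eth_ohlcv, btc_ohlcv={'1h': list(bars)})
assert state.shape == (builder.total_state_size,)  # ~15,800 float32 features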


@ -10,14 +10,16 @@ This module implements sophisticated RL training with:
import asyncio
import logging
import time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque, namedtuple
import random
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any
from typing import Dict, List, Optional, Tuple, Any, Union
import matplotlib.pyplot as plt
from pathlib import Path
@ -26,6 +28,9 @@ from core.data_provider import DataProvider
from core.enhanced_orchestrator import EnhancedTradingOrchestrator, MarketState, TradingAction
from models import RLAgentInterface
import models
from training.enhanced_rl_state_builder import EnhancedRLStateBuilder, TickData, OHLCVData
from training.williams_market_structure import WilliamsMarketStructure
from training.cnn_rl_bridge import CNNRLBridge
logger = logging.getLogger(__name__)
@ -318,42 +323,66 @@ class EnhancedDQNAgent(nn.Module, RLAgentInterface):
return (param_count * 4 + buffer_size) // (1024 * 1024)
class EnhancedRLTrainer:
"""Enhanced RL trainer with continuous learning from market feedback"""
"""Enhanced RL trainer with comprehensive state representation and real data integration"""
def __init__(self, config: Optional[Dict] = None, orchestrator: EnhancedTradingOrchestrator = None):
"""Initialize the enhanced RL trainer"""
"""Initialize enhanced RL trainer with comprehensive state building"""
self.config = config or get_config()
self.orchestrator = orchestrator
self.data_provider = DataProvider(self.config)
# Create RL agents for each symbol
# Initialize comprehensive state builder (replaces mock code)
self.state_builder = EnhancedRLStateBuilder(self.config)
self.williams_structure = WilliamsMarketStructure()
self.cnn_rl_bridge = CNNRLBridge(self.config) if hasattr(self.config, 'cnn_models') else None
# Enhanced RL agents with much larger state space
self.agents = {}
for symbol in self.config.symbols:
agent_config = self.config.rl.copy()
agent_config['name'] = f'RL_{symbol}'
self.agents[symbol] = EnhancedDQNAgent(agent_config)
self.initialize_agents()
# Training parameters
self.training_interval = 3600 # Train every hour
self.evaluation_window = 24 * 3600 # Evaluate actions after 24 hours
self.min_experiences = 100 # Minimum experiences before training
# Performance tracking
self.performance_history = {symbol: [] for symbol in self.config.symbols}
self.training_metrics = {
'total_episodes': 0,
'total_rewards': {symbol: [] for symbol in self.config.symbols},
'losses': {symbol: [] for symbol in self.config.symbols},
'epsilon_values': {symbol: [] for symbol in self.config.symbols}
}
# Create save directory
models_path = self.config.rl.get('model_dir', "models/enhanced_rl")
self.save_dir = Path(models_path)
# Training configuration
self.symbols = self.config.symbols
self.save_dir = Path(self.config.rl.get('save_dir', 'models/rl/saved'))
self.save_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Enhanced RL trainer initialized for symbols: {self.config.symbols}")
# Performance tracking
self.training_metrics = {
'total_episodes': 0,
'total_rewards': {symbol: [] for symbol in self.symbols},
'losses': {symbol: [] for symbol in self.symbols},
'epsilon_values': {symbol: [] for symbol in self.symbols}
}
self.performance_history = {symbol: [] for symbol in self.symbols}
# Real-time learning parameters
self.learning_active = False
self.experience_buffer_size = 1000
self.min_experiences_for_training = 100
logger.info("Enhanced RL Trainer initialized with comprehensive state representation")
logger.info(f"State builder total size: {self.state_builder.total_state_size} features")
logger.info(f"Symbols: {self.symbols}")
def initialize_agents(self):
"""Initialize RL agents with enhanced state size"""
for symbol in self.symbols:
agent_config = {
'state_size': self.state_builder.total_state_size, # ~15,800 features
'action_space': 3, # BUY, SELL, HOLD
'hidden_size': 1024, # Larger hidden layers for complex state
'learning_rate': 0.0001,
'gamma': 0.99,
'epsilon': 1.0,
'epsilon_decay': 0.995,
'epsilon_min': 0.01,
'buffer_size': 50000, # Larger replay buffer
'batch_size': 128,
'target_update_freq': 1000
}
self.agents[symbol] = EnhancedDQNAgent(agent_config)
logger.info(f"Initialized {symbol} RL agent with state size: {agent_config['state_size']}")
async def continuous_learning_loop(self):
"""Main continuous learning loop"""
logger.info("Starting continuous RL learning loop")
@ -378,7 +407,7 @@ class EnhancedRLTrainer:
self._save_all_models()
# Wait before next training cycle
await asyncio.sleep(self.training_interval)
await asyncio.sleep(3600) # Train every hour
except Exception as e:
logger.error(f"Error in continuous learning loop: {e}")
@ -388,7 +417,7 @@ class EnhancedRLTrainer:
"""Train all RL agents with their experiences"""
for symbol, agent in self.agents.items():
try:
if len(agent.replay_buffer) >= self.min_experiences:
if len(agent.replay_buffer) >= self.min_experiences_for_training:
# Train for multiple steps
losses = []
for _ in range(10): # Train 10 steps per cycle
@ -411,7 +440,7 @@ class EnhancedRLTrainer:
if not self.orchestrator:
return
for symbol in self.config.symbols:
for symbol in self.symbols:
try:
# Get recent market states
recent_states = list(self.orchestrator.market_states[symbol])[-10:] # Last 10 states
@ -471,11 +500,150 @@ class EnhancedRLTrainer:
logger.error(f"Error adding experience for {symbol}: {e}")
def _market_state_to_rl_state(self, market_state: MarketState) -> np.ndarray:
"""Convert market state to RL state vector"""
if hasattr(self.orchestrator, '_market_state_to_rl_state'):
return self.orchestrator._market_state_to_rl_state(market_state)
"""Convert market state to comprehensive RL state vector using real data"""
try:
# Extract data from market state and orchestrator
if not self.orchestrator:
logger.warning("No orchestrator available for comprehensive state building")
return self._fallback_state_conversion(market_state)
# Get real tick data from orchestrator's data provider
symbol = market_state.symbol
eth_ticks = self._get_recent_tick_data(symbol, seconds=300)
# Get multi-timeframe OHLCV data
eth_ohlcv = self._get_multiframe_ohlcv_data(symbol)
btc_ohlcv = self._get_multiframe_ohlcv_data('BTC/USDT')
# Get CNN features if available
cnn_hidden_features = None
cnn_predictions = None
if self.cnn_rl_bridge:
cnn_data = self.cnn_rl_bridge.get_latest_features_for_symbol(symbol)
if cnn_data:
cnn_hidden_features = cnn_data.get('hidden_features', {})
cnn_predictions = cnn_data.get('predictions', {})
# Get pivot point data
pivot_data = self._calculate_pivot_points(eth_ohlcv)
# Build comprehensive state using enhanced state builder
comprehensive_state = self.state_builder.build_rl_state(
eth_ticks=eth_ticks,
eth_ohlcv=eth_ohlcv,
btc_ohlcv=btc_ohlcv,
cnn_hidden_features=cnn_hidden_features,
cnn_predictions=cnn_predictions,
pivot_data=pivot_data
)
logger.debug(f"Built comprehensive RL state: {len(comprehensive_state)} features")
return comprehensive_state
except Exception as e:
logger.error(f"Error building comprehensive RL state: {e}")
return self._fallback_state_conversion(market_state)
def _get_recent_tick_data(self, symbol: str, seconds: int = 300) -> List[TickData]:
"""Get recent tick data from orchestrator's data provider"""
try:
if hasattr(self.orchestrator, 'data_provider') and self.orchestrator.data_provider:
# Get recent ticks from data provider
recent_ticks = self.orchestrator.data_provider.get_recent_ticks(symbol, count=seconds*10)
# Convert to the state builder's TickData format (the builder reads
# tick.price / tick.volume attributes, not dict keys)
tick_data = []
for tick in recent_ticks[-300:]:  # Last 300 ticks max
    tick_data.append(TickData(
        timestamp=tick.timestamp,
        price=tick.price,
        volume=tick.volume,
        bid=getattr(tick, 'bid', 0.0),
        ask=getattr(tick, 'ask', 0.0)
    ))
return tick_data
return []
except Exception as e:
logger.warning(f"Error getting tick data for {symbol}: {e}")
return []
def _get_multiframe_ohlcv_data(self, symbol: str) -> Dict[str, List[OHLCVData]]:
"""Get multi-timeframe OHLCV data"""
try:
if hasattr(self.orchestrator, 'data_provider') and self.orchestrator.data_provider:
ohlcv_data = {}
timeframes = ['1s', '1m', '1h', '1d']
for tf in timeframes:
try:
# Get historical data for timeframe
df = self.orchestrator.data_provider.get_historical_data(
symbol=symbol,
timeframe=tf,
limit=300,
refresh=True
)
if df is not None and not df.empty:
    # Convert rows to the state builder's OHLCVData format
    bars = []
    for idx, row in df.tail(300).iterrows():
        bars.append(OHLCVData(
            timestamp=idx if isinstance(idx, datetime) else datetime.now(),
            open=float(row.get('open', 0)),
            high=float(row.get('high', 0)),
            low=float(row.get('low', 0)),
            close=float(row.get('close', 0)),
            volume=float(row.get('volume', 0))
        ))
    ohlcv_data[tf] = bars
else:
ohlcv_data[tf] = []
except Exception as e:
logger.warning(f"Error getting {tf} data for {symbol}: {e}")
ohlcv_data[tf] = []
return ohlcv_data
return {}
except Exception as e:
logger.warning(f"Error getting OHLCV data for {symbol}: {e}")
return {}
def _calculate_pivot_points(self, eth_ohlcv: Dict[str, List]) -> Dict[str, Any]:
"""Calculate Williams pivot points from OHLCV data"""
try:
if '1m' in eth_ohlcv and eth_ohlcv['1m']:
# Convert to numpy array for Williams calculation
bars = eth_ohlcv['1m']
if len(bars) >= 50: # Need minimum data for pivot calculation
ohlc_array = np.array([
    [bar.timestamp.timestamp() if hasattr(bar.timestamp, 'timestamp') else time.time(),
     bar.open, bar.high, bar.low, bar.close, bar.volume]
    for bar in bars[-200:]  # Last 200 bars
])
pivot_data = self.williams_structure.calculate_recursive_pivot_points(ohlc_array)
return pivot_data
return {}
except Exception as e:
logger.warning(f"Error calculating pivot points: {e}")
return {}
def _fallback_state_conversion(self, market_state: MarketState) -> np.ndarray:
"""Fallback to basic state conversion if comprehensive state building fails"""
logger.warning("Using fallback state conversion - limited features")
# Fallback implementation
state_components = [
market_state.volatility,
market_state.volume,
@ -486,8 +654,8 @@ class EnhancedRLTrainer:
for timeframe in sorted(market_state.prices.keys()):
state_components.append(market_state.prices[timeframe])
# Pad or truncate to expected state size
expected_size = self.config.rl.get('state_size', 100)
# Pad to match expected state size
expected_size = self.state_builder.total_state_size
if len(state_components) < expected_size:
state_components.extend([0.0] * (expected_size - len(state_components)))
else:
@ -545,7 +713,7 @@ class EnhancedRLTrainer:
timestamp = max(timestamps)
loaded_count = 0
for symbol in self.config.symbols:
for symbol in self.symbols:
filename = f"rl_agent_{symbol}_{timestamp}.pt"
filepath = self.save_dir / filename

training/williams_market_structure.py Normal file

@ -0,0 +1,640 @@
"""
Williams Market Structure Implementation for RL Training
This module implements Larry Williams market structure analysis methodology for
RL training enhancement with:
- Swing high/low detection with configurable strength
- 5 levels of recursive pivot point calculation
- Trend analysis (higher highs/lows vs lower highs/lows)
- Market bias determination across multiple timeframes
- Feature extraction for RL training (250 features)
Based on Larry Williams' teachings on market structure:
- Markets move in swings between support and resistance
- Higher timeframe structure determines lower timeframe bias
- Recursive analysis reveals fractal patterns
- Trend direction determined by swing point relationships
"""
import numpy as np
import pandas as pd
import logging
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
from enum import Enum
logger = logging.getLogger(__name__)
class TrendDirection(Enum):
UP = "up"
DOWN = "down"
SIDEWAYS = "sideways"
UNKNOWN = "unknown"
class SwingType(Enum):
SWING_HIGH = "swing_high"
SWING_LOW = "swing_low"
@dataclass
class SwingPoint:
"""Represents a swing high or low point"""
timestamp: datetime
price: float
index: int
swing_type: SwingType
strength: int # Number of bars on each side that confirm the swing
volume: float = 0.0
@dataclass
class TrendAnalysis:
"""Trend analysis results"""
direction: TrendDirection
strength: float # 0.0 to 1.0
confidence: float # 0.0 to 1.0
swing_count: int
last_swing_high: Optional[SwingPoint]
last_swing_low: Optional[SwingPoint]
higher_highs: int
higher_lows: int
lower_highs: int
lower_lows: int
@dataclass
class MarketStructureLevel:
"""Market structure analysis for one recursive level"""
level: int
swing_points: List[SwingPoint]
trend_analysis: TrendAnalysis
support_levels: List[float]
resistance_levels: List[float]
current_bias: TrendDirection
structure_breaks: List[Dict[str, Any]]
class WilliamsMarketStructure:
"""
Implementation of Larry Williams market structure methodology
Features:
- Multi-strength swing detection (2, 3, 5, 8, 13 bar strengths)
- 5 levels of recursive analysis
- Trend direction determination
- Support/resistance level identification
- Market bias calculation
- Structure break detection
"""
def __init__(self, swing_strengths: List[int] = None):
"""
Initialize Williams market structure analyzer
Args:
swing_strengths: List of swing detection strengths (bars on each side)
"""
self.swing_strengths = swing_strengths or [2, 3, 5, 8, 13] # Fibonacci-based strengths
self.max_levels = 5
self.min_swings_for_trend = 3
# Cache for performance
self.swing_cache = {}
self.trend_cache = {}
logger.info(f"Williams Market Structure initialized with strengths: {self.swing_strengths}")
def calculate_recursive_pivot_points(self, ohlcv_data: np.ndarray) -> Dict[str, MarketStructureLevel]:
"""
Calculate 5 levels of recursive pivot points
Args:
ohlcv_data: OHLCV data array with columns [timestamp, open, high, low, close, volume]
Returns:
Dict of market structure levels with swing points and trend analysis
"""
if len(ohlcv_data) < 20:
logger.warning("Insufficient data for Williams structure analysis")
return self._create_empty_structure()
levels = {}
current_data = ohlcv_data.copy()
for level in range(self.max_levels):
logger.debug(f"Analyzing level {level} with {len(current_data)} data points")
# Find swing points for this level
swing_points = self._find_swing_points_multi_strength(current_data)
if len(swing_points) < self.min_swings_for_trend:
logger.debug(f"Not enough swings at level {level}: {len(swing_points)}")
# Fill remaining levels with empty data
for remaining_level in range(level, self.max_levels):
levels[f'level_{remaining_level}'] = self._create_empty_level(remaining_level)
break
# Analyze trend for this level
trend_analysis = self._analyze_trend_from_swings(swing_points)
# Find support/resistance levels
support_levels, resistance_levels = self._find_support_resistance(
swing_points, current_data
)
# Determine current market bias
current_bias = self._determine_market_bias(swing_points, trend_analysis)
# Detect structure breaks
structure_breaks = self._detect_structure_breaks(swing_points, current_data)
# Create level data
levels[f'level_{level}'] = MarketStructureLevel(
level=level,
swing_points=swing_points,
trend_analysis=trend_analysis,
support_levels=support_levels,
resistance_levels=resistance_levels,
current_bias=current_bias,
structure_breaks=structure_breaks
)
# Prepare data for next level (use swing points as input)
if len(swing_points) >= 5:
current_data = self._convert_swings_to_ohlcv(swing_points)
if len(current_data) < 10:
logger.debug(f"Insufficient converted data for level {level + 1}")
break
else:
logger.debug(f"Not enough swings to continue to level {level + 1}")
break
# Fill any remaining empty levels
for remaining_level in range(len(levels), self.max_levels):
levels[f'level_{remaining_level}'] = self._create_empty_level(remaining_level)
return levels
def _find_swing_points_multi_strength(self, ohlcv_data: np.ndarray) -> List[SwingPoint]:
"""Find swing points using multiple strength criteria"""
all_swings = []
for strength in self.swing_strengths:
swings = self._find_swing_points_single_strength(ohlcv_data, strength)
for swing in swings:
# Avoid duplicates (swings at same index)
if not any(existing.index == swing.index for existing in all_swings):
all_swings.append(swing)
# Sort by timestamp/index
all_swings.sort(key=lambda x: x.index)
# Filter to get the most significant swings
return self._filter_significant_swings(all_swings)
def _find_swing_points_single_strength(self, ohlcv_data: np.ndarray, strength: int) -> List[SwingPoint]:
"""Find swing points with specific strength requirement"""
swings = []
if len(ohlcv_data) < (strength * 2 + 1):
return swings
for i in range(strength, len(ohlcv_data) - strength):
current_high = ohlcv_data[i, 2] # High price
current_low = ohlcv_data[i, 3] # Low price
current_volume = ohlcv_data[i, 5] if ohlcv_data.shape[1] > 5 else 0.0
# Check for swing high (higher than surrounding bars)
is_swing_high = True
for j in range(i - strength, i + strength + 1):
if j != i and ohlcv_data[j, 2] >= current_high:
is_swing_high = False
break
if is_swing_high:
swings.append(SwingPoint(
timestamp=datetime.fromtimestamp(ohlcv_data[i, 0]) if ohlcv_data[i, 0] > 1e9 else datetime.now(),
price=current_high,
index=i,
swing_type=SwingType.SWING_HIGH,
strength=strength,
volume=current_volume
))
# Check for swing low (lower than surrounding bars)
is_swing_low = True
for j in range(i - strength, i + strength + 1):
if j != i and ohlcv_data[j, 3] <= current_low:
is_swing_low = False
break
if is_swing_low:
swings.append(SwingPoint(
timestamp=datetime.fromtimestamp(ohlcv_data[i, 0]) if ohlcv_data[i, 0] > 1e9 else datetime.now(),
price=current_low,
index=i,
swing_type=SwingType.SWING_LOW,
strength=strength,
volume=current_volume
))
return swings
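# Worked example (illustrative): highs [1, 2, 3, 2, 1] with strength=2 -> the bar
# at index 2 (high=3) exceeds all bars within 2 positions on each side, so it is
# recorded as a SWING_HIGH with strength 2.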
def _filter_significant_swings(self, swings: List[SwingPoint]) -> List[SwingPoint]:
"""Filter to keep only the most significant swings"""
if len(swings) <= 20:
return swings
# Sort by strength (higher strength = more significant)
swings_by_strength = sorted(swings, key=lambda x: x.strength, reverse=True)
# Take top swings but ensure we have alternating highs and lows
significant_swings = []
last_type = None
for swing in swings_by_strength:
if len(significant_swings) >= 20:
break
# Prefer alternating swing types for better structure
if last_type is None or swing.swing_type != last_type:
significant_swings.append(swing)
last_type = swing.swing_type
elif len(significant_swings) < 10: # Still add if we need more swings
significant_swings.append(swing)
# Sort by index again
significant_swings.sort(key=lambda x: x.index)
return significant_swings
def _analyze_trend_from_swings(self, swing_points: List[SwingPoint]) -> TrendAnalysis:
"""Analyze trend direction from swing points"""
if len(swing_points) < 2:
return TrendAnalysis(
direction=TrendDirection.UNKNOWN,
strength=0.0,
confidence=0.0,
swing_count=0,
last_swing_high=None,
last_swing_low=None,
higher_highs=0,
higher_lows=0,
lower_highs=0,
lower_lows=0
)
# Separate highs and lows
highs = [s for s in swing_points if s.swing_type == SwingType.SWING_HIGH]
lows = [s for s in swing_points if s.swing_type == SwingType.SWING_LOW]
# Count higher highs, higher lows, lower highs, lower lows
higher_highs = self._count_higher_highs(highs)
higher_lows = self._count_higher_lows(lows)
lower_highs = self._count_lower_highs(highs)
lower_lows = self._count_lower_lows(lows)
# Determine trend direction
if higher_highs > 0 and higher_lows > 0:
direction = TrendDirection.UP
elif lower_highs > 0 and lower_lows > 0:
direction = TrendDirection.DOWN
else:
direction = TrendDirection.SIDEWAYS
# Calculate trend strength
total_moves = higher_highs + higher_lows + lower_highs + lower_lows
if direction == TrendDirection.UP:
strength = (higher_highs + higher_lows) / max(total_moves, 1)
elif direction == TrendDirection.DOWN:
strength = (lower_highs + lower_lows) / max(total_moves, 1)
else:
strength = 0.5 # Neutral for sideways
# Calculate confidence based on consistency
if total_moves > 0:
if direction == TrendDirection.UP:
confidence = (higher_highs + higher_lows) / total_moves
elif direction == TrendDirection.DOWN:
confidence = (lower_highs + lower_lows) / total_moves
else:
# For sideways, confidence is based on how balanced it is
up_moves = higher_highs + higher_lows
down_moves = lower_highs + lower_lows
balance = 1.0 - abs(up_moves - down_moves) / total_moves
confidence = balance
else:
confidence = 0.0
return TrendAnalysis(
direction=direction,
strength=min(strength, 1.0),
confidence=min(confidence, 1.0),
swing_count=len(swing_points),
last_swing_high=highs[-1] if highs else None,
last_swing_low=lows[-1] if lows else None,
higher_highs=higher_highs,
higher_lows=higher_lows,
lower_highs=lower_highs,
lower_lows=lower_lows
)
def _count_higher_highs(self, highs: List[SwingPoint]) -> int:
"""Count higher highs in sequence"""
if len(highs) < 2:
return 0
count = 0
for i in range(1, len(highs)):
if highs[i].price > highs[i-1].price:
count += 1
return count
def _count_higher_lows(self, lows: List[SwingPoint]) -> int:
"""Count higher lows in sequence"""
if len(lows) < 2:
return 0
count = 0
for i in range(1, len(lows)):
if lows[i].price > lows[i-1].price:
count += 1
return count
def _count_lower_highs(self, highs: List[SwingPoint]) -> int:
"""Count lower highs in sequence"""
if len(highs) < 2:
return 0
count = 0
for i in range(1, len(highs)):
if highs[i].price < highs[i-1].price:
count += 1
return count
def _count_lower_lows(self, lows: List[SwingPoint]) -> int:
"""Count lower lows in sequence"""
if len(lows) < 2:
return 0
count = 0
for i in range(1, len(lows)):
if lows[i].price < lows[i-1].price:
count += 1
return count
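# Worked example (illustrative): swing-high prices [10, 12, 11, 13] give
# higher_highs = 2 (12 > 10, 13 > 11) and lower_highs = 1 (11 < 12).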
def _find_support_resistance(self, swing_points: List[SwingPoint],
ohlcv_data: np.ndarray) -> Tuple[List[float], List[float]]:
"""Find support and resistance levels from swing points"""
highs = [s.price for s in swing_points if s.swing_type == SwingType.SWING_HIGH]
lows = [s.price for s in swing_points if s.swing_type == SwingType.SWING_LOW]
# Cluster similar levels
support_levels = self._cluster_price_levels(lows) if lows else []
resistance_levels = self._cluster_price_levels(highs) if highs else []
return support_levels, resistance_levels
def _cluster_price_levels(self, prices: List[float], tolerance: float = 0.02) -> List[float]:
"""Cluster similar price levels together"""
if not prices:
return []
sorted_prices = sorted(prices)
clusters = []
current_cluster = [sorted_prices[0]]
for price in sorted_prices[1:]:
# If price is within tolerance of cluster average, add to cluster
cluster_avg = np.mean(current_cluster)
if abs(price - cluster_avg) / cluster_avg <= tolerance:
current_cluster.append(price)
else:
# Start new cluster
clusters.append(np.mean(current_cluster))
current_cluster = [price]
# Add last cluster
if current_cluster:
clusters.append(np.mean(current_cluster))
return clusters
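# Worked example (illustrative): prices [100, 101, 110] with 2% tolerance ->
# 101 joins the [100] cluster (1% from its mean) while 110 starts a new one,
# yielding levels [100.5, 110.0].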
def _determine_market_bias(self, swing_points: List[SwingPoint],
trend_analysis: TrendAnalysis) -> TrendDirection:
"""Determine current market bias"""
if not swing_points:
return TrendDirection.UNKNOWN
# Use trend analysis as primary indicator
if trend_analysis.confidence > 0.6:
return trend_analysis.direction
# Look at most recent swings for bias
recent_swings = swing_points[-6:] if len(swing_points) >= 6 else swing_points
if len(recent_swings) >= 2:
first_price = recent_swings[0].price
last_price = recent_swings[-1].price
price_change = (last_price - first_price) / first_price
if price_change > 0.01: # 1% threshold
return TrendDirection.UP
elif price_change < -0.01:
return TrendDirection.DOWN
else:
return TrendDirection.SIDEWAYS
return TrendDirection.UNKNOWN
def _detect_structure_breaks(self, swing_points: List[SwingPoint],
ohlcv_data: np.ndarray) -> List[Dict[str, Any]]:
"""Detect structure breaks (trend changes)"""
structure_breaks = []
if len(swing_points) < 4:
return structure_breaks
# Look for pattern breaks
highs = [s for s in swing_points if s.swing_type == SwingType.SWING_HIGH]
lows = [s for s in swing_points if s.swing_type == SwingType.SWING_LOW]
# Check for break of structure in highs (lower high after higher highs)
if len(highs) >= 3:
for i in range(2, len(highs)):
if (highs[i-2].price < highs[i-1].price and # Previous was higher high
highs[i-1].price > highs[i].price): # Current is lower high
structure_breaks.append({
'type': 'break_of_structure_high',
'timestamp': highs[i].timestamp,
'price': highs[i].price,
'previous_high': highs[i-1].price,
'significance': abs(highs[i].price - highs[i-1].price) / highs[i-1].price
})
# Check for break of structure in lows (higher low after lower lows)
if len(lows) >= 3:
for i in range(2, len(lows)):
if (lows[i-2].price > lows[i-1].price and # Previous was lower low
lows[i-1].price < lows[i].price): # Current is higher low
structure_breaks.append({
'type': 'break_of_structure_low',
'timestamp': lows[i].timestamp,
'price': lows[i].price,
'previous_low': lows[i-1].price,
'significance': abs(lows[i].price - lows[i-1].price) / lows[i-1].price
})
return structure_breaks
def _convert_swings_to_ohlcv(self, swing_points: List[SwingPoint]) -> np.ndarray:
"""Convert swing points to OHLCV format for next level analysis"""
if len(swing_points) < 2:
return np.array([])
ohlcv_data = []
for i in range(len(swing_points) - 1):
current_swing = swing_points[i]
next_swing = swing_points[i + 1]
# Create synthetic OHLCV bar from swing to swing
if current_swing.swing_type == SwingType.SWING_HIGH:
# From high to next point
open_price = current_swing.price
high_price = current_swing.price
low_price = min(current_swing.price, next_swing.price)
close_price = next_swing.price
else:
# From low to next point
open_price = current_swing.price
high_price = max(current_swing.price, next_swing.price)
low_price = current_swing.price
close_price = next_swing.price
ohlcv_data.append([
current_swing.timestamp.timestamp(),
open_price,
high_price,
low_price,
close_price,
current_swing.volume
])
return np.array(ohlcv_data)
def _create_empty_structure(self) -> Dict[str, MarketStructureLevel]:
"""Create empty structure when insufficient data"""
return {f'level_{i}': self._create_empty_level(i) for i in range(self.max_levels)}
def _create_empty_level(self, level: int) -> MarketStructureLevel:
"""Create empty market structure level"""
return MarketStructureLevel(
level=level,
swing_points=[],
trend_analysis=TrendAnalysis(
direction=TrendDirection.UNKNOWN,
strength=0.0,
confidence=0.0,
swing_count=0,
last_swing_high=None,
last_swing_low=None,
higher_highs=0,
higher_lows=0,
lower_highs=0,
lower_lows=0
),
support_levels=[],
resistance_levels=[],
current_bias=TrendDirection.UNKNOWN,
structure_breaks=[]
)
def extract_features_for_rl(self, structure_levels: Dict[str, MarketStructureLevel]) -> List[float]:
"""
Extract features from Williams structure for RL training
Returns ~250 features total:
- 50 features per level (5 levels)
"""
features = []
for level in range(self.max_levels):
level_key = f'level_{level}'
if level_key in structure_levels:
level_data = structure_levels[level_key]
level_features = self._extract_level_features(level_data)
else:
level_features = [0.0] * 50 # 50 features per level
features.extend(level_features)
return features[:250] # Ensure exactly 250 features
def _extract_level_features(self, level: MarketStructureLevel) -> List[float]:
"""Extract features from a single structure level"""
features = []
# Trend features (10 features)
features.extend([
1.0 if level.trend_analysis.direction == TrendDirection.UP else 0.0,
1.0 if level.trend_analysis.direction == TrendDirection.DOWN else 0.0,
1.0 if level.trend_analysis.direction == TrendDirection.SIDEWAYS else 0.0,
level.trend_analysis.strength,
level.trend_analysis.confidence,
level.trend_analysis.higher_highs,
level.trend_analysis.higher_lows,
level.trend_analysis.lower_highs,
level.trend_analysis.lower_lows,
len(level.swing_points)
])
# Current bias features (4 features)
features.extend([
1.0 if level.current_bias == TrendDirection.UP else 0.0,
1.0 if level.current_bias == TrendDirection.DOWN else 0.0,
1.0 if level.current_bias == TrendDirection.SIDEWAYS else 0.0,
1.0 if level.current_bias == TrendDirection.UNKNOWN else 0.0
])
# Swing point features (20 features - last 10 swings * 2 features each)
recent_swings = list(level.swing_points[-10:])  # copy so padding below cannot mutate the level
for swing in recent_swings:
features.extend([
swing.price,
1.0 if swing.swing_type == SwingType.SWING_HIGH else 0.0
])
# Pad if fewer than 10 swings
while len(recent_swings) < 10:
features.extend([0.0, 0.0])
recent_swings.append(None) # Just for counting
# Support/resistance levels (10 features - 5 support + 5 resistance)
support_levels = list(level.support_levels[:5])  # copy to avoid mutating the level
while len(support_levels) < 5:
    support_levels.append(0.0)
features.extend(support_levels)
resistance_levels = list(level.resistance_levels[:5])  # copy to avoid mutating the level
while len(resistance_levels) < 5:
    resistance_levels.append(0.0)
features.extend(resistance_levels)
# Structure break features (6 features)
recent_breaks = list(level.structure_breaks[-3:])  # copy so padding below cannot mutate the level
for break_info in recent_breaks:
features.extend([
break_info.get('significance', 0.0),
1.0 if break_info.get('type', '').endswith('_high') else 0.0
])
# Pad if fewer than 3 breaks
while len(recent_breaks) < 3:
features.extend([0.0, 0.0])
recent_breaks.append({})
return features[:50] # Ensure exactly 50 features per level
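
A minimal end-to-end sketch (illustrative, not part of this commit); the synthetic sine series is an arbitrary stand-in for real OHLCV bars so that swing detection has peaks and troughs to find.

import numpy as np
from datetime import datetime

ws = WilliamsMarketStructure()
t0 = datetime.now().timestamp()
closes = [100.0 + 5.0 * np.sin(i / 3.0) for i in range(120)]  # zig-zag series
ohlcv = np.array([[t0 + 60 * i, c, c + 1.0, c - 1.0, c, 10.0]
                  for i, c in enumerate(closes)])
levels = ws.calculate_recursive_pivot_points(ohlcv)
features = ws.extract_features_for_rl(levels)
print(len(features), levels['level_0'].trend_analysis.direction)  # 250 features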