RL module possibly working

training/cnn_rl_bridge.py (new file, 219 lines)
@@ -0,0 +1,219 @@
"""
CNN-RL Bridge Module

This module provides the interface between CNN models and RL training,
extracting hidden features and predictions from CNN models for use in RL state building.
"""

import logging
import numpy as np
import torch
import torch.nn as nn
from typing import Dict, List, Optional, Tuple, Any
from datetime import datetime, timedelta

logger = logging.getLogger(__name__)

class CNNRLBridge:
    """Bridge between CNN models and RL training for feature extraction"""

    def __init__(self, config: Dict):
        """Initialize CNN-RL bridge"""
        self.config = config
        self.cnn_models = {}
        self.feature_cache = {}
        self.cache_timeout = 60  # Cache features for 60 seconds

        # Initialize CNN model registry if available
        self._initialize_cnn_models()

        logger.info("CNN-RL Bridge initialized")

    def _initialize_cnn_models(self):
        """Initialize CNN models from config or model registry"""
        try:
            # Try to load CNN models from config
            if hasattr(self.config, 'cnn_models') and self.config.cnn_models:
                for model_name, model_config in self.config.cnn_models.items():
                    try:
                        # Load CNN model (implementation depends on your CNN architecture)
                        model = self._load_cnn_model(model_name, model_config)
                        if model:
                            self.cnn_models[model_name] = model
                            logger.info(f"Loaded CNN model: {model_name}")
                    except Exception as e:
                        logger.warning(f"Failed to load CNN model {model_name}: {e}")

            if not self.cnn_models:
                logger.info("No CNN models available - RL will train without CNN features")

        except Exception as e:
            logger.warning(f"Error initializing CNN models: {e}")

    def _load_cnn_model(self, model_name: str, model_config: Dict) -> Optional[nn.Module]:
        """Load a CNN model from configuration"""
        try:
            # This would implement actual CNN model loading.
            # For now, return None to indicate no models are available.
            # In your implementation, this would load your specific CNN architecture.

            logger.info(f"CNN model loading framework ready for {model_name}")
            return None

        except Exception as e:
            logger.error(f"Error loading CNN model {model_name}: {e}")
            return None

    def get_latest_features_for_symbol(self, symbol: str) -> Optional[Dict[str, Any]]:
        """Get latest CNN features and predictions for a symbol"""
        try:
            # Check cache first
            cache_key = f"{symbol}_{datetime.now().strftime('%Y%m%d_%H%M')}"
            if cache_key in self.feature_cache:
                cached_data = self.feature_cache[cache_key]
                # total_seconds() rather than .seconds, which ignores the day component
                if (datetime.now() - cached_data['timestamp']).total_seconds() < self.cache_timeout:
                    return cached_data['features']

            # Generate new features if models available
            if self.cnn_models:
                features = self._extract_cnn_features_for_symbol(symbol)

                # Cache the features
                self.feature_cache[cache_key] = {
                    'timestamp': datetime.now(),
                    'features': features
                }

                # Clean old cache entries
                self._cleanup_cache()

                return features

            return None

        except Exception as e:
            logger.warning(f"Error getting CNN features for {symbol}: {e}")
            return None

    def _extract_cnn_features_for_symbol(self, symbol: str) -> Dict[str, Any]:
        """Extract CNN hidden features and predictions for a symbol"""
        try:
            extracted_features = {
                'hidden_features': {},
                'predictions': {}
            }

            for model_name, model in self.cnn_models.items():
                try:
                    # Extract features from each CNN model
                    hidden_features, predictions = self._extract_model_features(model, symbol)

                    if hidden_features is not None:
                        extracted_features['hidden_features'][model_name] = hidden_features

                    if predictions is not None:
                        extracted_features['predictions'][model_name] = predictions

                except Exception as e:
                    logger.warning(f"Error extracting features from {model_name}: {e}")

            return extracted_features

        except Exception as e:
            logger.error(f"Error extracting CNN features for {symbol}: {e}")
            return {'hidden_features': {}, 'predictions': {}}

    def _extract_model_features(self, model: nn.Module, symbol: str) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """Extract hidden features and predictions from a specific CNN model"""
        try:
            # This would implement the actual feature extraction from your CNN models.
            # The implementation depends on your specific CNN architecture.

            # For now, return mock data to show the structure.
            # In a real implementation, this would:
            # 1. Get market data for the model
            # 2. Run a forward pass through the CNN
            # 3. Extract hidden layer activations
            # 4. Get model predictions

            # Mock hidden features (last hidden layer of CNN)
            hidden_features = np.random.random(512).astype(np.float32)

            # Mock predictions, one per timeframe: [1s_pred, 1m_pred, 1h_pred, 1d_pred]
            predictions = np.array([
                0.45,  # 1s prediction (probability of up move)
                0.52,  # 1m prediction
                0.38,  # 1h prediction
                0.61   # 1d prediction
            ]).astype(np.float32)

            logger.debug(f"Extracted CNN features for {symbol}: {len(hidden_features)} hidden, {len(predictions)} predictions")

            return hidden_features, predictions

        except Exception as e:
            logger.warning(f"Error extracting features from model: {e}")
            return None, None
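
    # Illustrative sketch only (not wired into the mock path above): one way to
    # obtain real hidden activations, assuming a PyTorch CNN whose penultimate
    # layer is exposed as `model.fc_hidden` and an input tensor already built
    # from market data. `fc_hidden` and `input_tensor` are assumptions, not
    # project API.
    def _capture_hidden_activations(self, model: nn.Module, input_tensor: torch.Tensor) -> Tuple[np.ndarray, np.ndarray]:
        """Run a forward pass and capture penultimate-layer activations via a hook"""
        captured = {}

        def hook(module, inputs, output):
            # Detach so cached features never hold onto the autograd graph
            captured['hidden'] = output.detach().cpu().numpy().flatten()

        handle = model.fc_hidden.register_forward_hook(hook)
        try:
            model.eval()
            with torch.no_grad():
                predictions = model(input_tensor).detach().cpu().numpy().flatten()
        finally:
            handle.remove()  # Always remove the hook, even if the forward pass fails

        return captured['hidden'].astype(np.float32), predictions.astype(np.float32)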

    def _cleanup_cache(self):
        """Clean up old cache entries"""
        try:
            current_time = datetime.now()
            expired_keys = []

            for key, data in self.feature_cache.items():
                if (current_time - data['timestamp']).total_seconds() > self.cache_timeout * 2:
                    expired_keys.append(key)

            for key in expired_keys:
                del self.feature_cache[key]

        except Exception as e:
            logger.warning(f"Error cleaning up feature cache: {e}")

    def register_cnn_model(self, model_name: str, model: nn.Module):
        """Register a CNN model for feature extraction"""
        try:
            self.cnn_models[model_name] = model
            logger.info(f"Registered CNN model: {model_name}")
        except Exception as e:
            logger.error(f"Error registering CNN model {model_name}: {e}")

    def unregister_cnn_model(self, model_name: str):
        """Unregister a CNN model"""
        try:
            if model_name in self.cnn_models:
                del self.cnn_models[model_name]
                logger.info(f"Unregistered CNN model: {model_name}")
        except Exception as e:
            logger.error(f"Error unregistering CNN model {model_name}: {e}")

    def get_available_models(self) -> List[str]:
        """Get list of available CNN models"""
        return list(self.cnn_models.keys())

    def is_model_available(self, model_name: str) -> bool:
        """Check if a specific CNN model is available"""
        return model_name in self.cnn_models

    def get_feature_dimensions(self) -> Dict[str, int]:
        """Get the dimensions of features extracted from CNN models"""
        return {
            'hidden_features_per_model': 512,
            'predictions_per_model': 4,  # 1s, 1m, 1h, 1d
            'total_models': len(self.cnn_models)
        }

    def validate_cnn_integration(self) -> Dict[str, Any]:
        """Validate CNN integration status"""
        status = {
            'models_available': len(self.cnn_models),
            'models_list': list(self.cnn_models.keys()),
            'cache_entries': len(self.feature_cache),
            'integration_ready': len(self.cnn_models) > 0,
            'expected_feature_size': len(self.cnn_models) * 512,  # hidden features
            'expected_prediction_size': len(self.cnn_models) * 4  # predictions
        }

        return status
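
A minimal usage sketch of the bridge as defined above. The empty config dict and the placeholder nn.Identity module are stand-ins invented for illustration; only register_cnn_model, validate_cnn_integration, and get_latest_features_for_symbol come from the file itself, and the returned features are the mock vectors from _extract_model_features.

    import torch.nn as nn
    from training.cnn_rl_bridge import CNNRLBridge

    bridge = CNNRLBridge(config={})                          # no cnn_models in config -> starts empty
    bridge.register_cnn_model('eth_1m_cnn', nn.Identity())   # placeholder module for the sketch
    print(bridge.validate_cnn_integration())                 # integration_ready: True, 512 expected features
    features = bridge.get_latest_features_for_symbol('ETH/USDT')
    # features['hidden_features']['eth_1m_cnn'] -> 512-dim mock vector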
training/enhanced_rl_state_builder.py (new file, 708 lines)
@@ -0,0 +1,708 @@
"""
Enhanced RL State Builder for Comprehensive Market Data Integration

This module implements the specification requirements for RL training with:
- 300s of raw tick data for momentum detection
- Multi-timeframe OHLCV data (1s, 1m, 1h, 1d) for ETH and BTC
- CNN hidden layer features integration
- CNN predictions from all timeframes
- Pivot point predictions using Williams market structure
- Market regime analysis

State Vector Components:
- ETH tick data: 3000 features (300 ticks * 10 features/tick)
- ETH OHLCV 1s: 2400 features (300 bars * 8 features)
- ETH OHLCV 1m: 2400 features (300 bars * 8 features)
- ETH OHLCV 1h: 2400 features (300 bars * 8 features)
- ETH OHLCV 1d: 2400 features (300 bars * 8 features)
- BTC reference: 2400 features (300 bars * 8 features)
- CNN features: 512 features (hidden layer)
- CNN predictions: 16 features (4 timeframes * 4 outputs)
- Pivot points: 250 features (Williams structure)
- Market regime: 20 features
Total: 15,798 features
"""

import logging
import numpy as np
import pandas as pd

try:
    import ta
except ImportError:
    logger = logging.getLogger(__name__)
    logger.warning("'ta' library not available, using pandas for technical indicators")
    ta = None

from typing import Dict, List, Optional, Tuple, Any
from datetime import datetime, timedelta
from dataclasses import dataclass

from core.universal_data_adapter import UniversalDataStream

logger = logging.getLogger(__name__)

@dataclass
class TickData:
    """Tick data structure"""
    timestamp: datetime
    price: float
    volume: float
    bid: float = 0.0
    ask: float = 0.0

    @property
    def spread(self) -> float:
        return self.ask - self.bid if self.ask > 0 and self.bid > 0 else 0.0

@dataclass
class OHLCVData:
    """OHLCV data structure"""
    timestamp: datetime
    open: float
    high: float
    low: float
    close: float
    volume: float

    # Technical indicators (optional)
    rsi: Optional[float] = None
    macd: Optional[float] = None
    bb_upper: Optional[float] = None
    bb_lower: Optional[float] = None
    sma_20: Optional[float] = None
    ema_12: Optional[float] = None
    atr: Optional[float] = None

@dataclass
class StateComponentConfig:
    """Configuration for state component sizes"""
    eth_ticks: int = 3000      # 300 ticks * 10 features per tick
    eth_1s_ohlcv: int = 2400   # 300 bars * 8 features (OHLCV + indicators)
    eth_1m_ohlcv: int = 2400   # 300 bars * 8 features
    eth_1h_ohlcv: int = 2400   # 300 bars * 8 features
    eth_1d_ohlcv: int = 2400   # 300 bars * 8 features
    btc_reference: int = 2400  # BTC reference data
    cnn_features: int = 512    # CNN hidden layer features
    cnn_predictions: int = 16  # CNN predictions (4 timeframes * 4 outputs)
    pivot_points: int = 250    # Recursive pivot points (5 levels * 50 points)
    market_regime: int = 20    # Market regime features

    @property
    def total_size(self) -> int:
        """Calculate total state size"""
        return (self.eth_ticks + self.eth_1s_ohlcv + self.eth_1m_ohlcv +
                self.eth_1h_ohlcv + self.eth_1d_ohlcv + self.btc_reference +
                self.cnn_features + self.cnn_predictions + self.pivot_points +
                self.market_regime)
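
# Quick sanity check of the layout above (hypothetical snippet, not part of the module):
#   >>> StateComponentConfig().total_size
#   15798    # 3000 + 4*2400 + 2400 + 512 + 16 + 250 + 20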

class EnhancedRLStateBuilder:
    """
    Comprehensive RL state builder implementing specification requirements

    Features:
    - 300s tick data processing with momentum detection
    - Multi-timeframe OHLCV integration
    - CNN hidden layer feature extraction
    - Pivot point calculation and integration
    - Market regime analysis
    - BTC reference data processing
    """

    def __init__(self, config: Dict[str, Any]):
        self.config = config

        # Data windows
        self.tick_window_seconds = 300  # 5 minutes of tick data
        self.ohlcv_window_bars = 300    # 300 bars for each timeframe

        # State component sizes
        self.state_components = {
            'eth_ticks': 300 * 10,     # 3000 features: tick data with derived features
            'eth_1s_ohlcv': 300 * 8,   # 2400 features: OHLCV + indicators
            'eth_1m_ohlcv': 300 * 8,   # 2400 features: OHLCV + indicators
            'eth_1h_ohlcv': 300 * 8,   # 2400 features: OHLCV + indicators
            'eth_1d_ohlcv': 300 * 8,   # 2400 features: OHLCV + indicators
            'btc_reference': 300 * 8,  # 2400 features: BTC reference data
            'cnn_features': 512,       # 512 features: CNN hidden layer
            'cnn_predictions': 16,     # 16 features: CNN predictions (4 timeframes * 4 outputs)
            'pivot_points': 250,       # 250 features: Williams market structure
            'market_regime': 20        # 20 features: market regime indicators
        }

        self.total_state_size = sum(self.state_components.values())

        # Data buffers for maintaining windows
        self.tick_buffers = {}
        self.ohlcv_buffers = {}

        # Normalization parameters
        self.normalization_params = self._initialize_normalization_params()

        # Feature extractors
        self.momentum_detector = TickMomentumDetector()
        self.indicator_calculator = TechnicalIndicatorCalculator()
        self.regime_analyzer = MarketRegimeAnalyzer()

        logger.info("Enhanced RL State Builder initialized")
        logger.info(f"Total state size: {self.total_state_size} features")
        logger.info(f"State components: {self.state_components}")

    def build_rl_state(self,
                       eth_ticks: List[TickData],
                       eth_ohlcv: Dict[str, List[OHLCVData]],
                       btc_ohlcv: Dict[str, List[OHLCVData]],
                       cnn_hidden_features: Optional[Dict[str, np.ndarray]] = None,
                       cnn_predictions: Optional[Dict[str, np.ndarray]] = None,
                       pivot_data: Optional[Dict[str, Any]] = None) -> np.ndarray:
        """
        Build comprehensive RL state vector from all data sources

        Args:
            eth_ticks: List of ETH tick data (last 300s)
            eth_ohlcv: Dict of ETH OHLCV data by timeframe
            btc_ohlcv: Dict of BTC OHLCV data by timeframe
            cnn_hidden_features: CNN hidden layer features by timeframe
            cnn_predictions: CNN predictions by timeframe
            pivot_data: Pivot point data from Williams analysis

        Returns:
            np.ndarray: Comprehensive state vector (15,798 features)
        """
        try:
            state_vector = []

            # 1. Process ETH tick data (3000 features)
            tick_features = self._process_tick_data(eth_ticks)
            state_vector.extend(tick_features)

            # 2. Process ETH multi-timeframe OHLCV (9600 features total)
            for timeframe in ['1s', '1m', '1h', '1d']:
                if timeframe in eth_ohlcv:
                    ohlcv_features = self._process_ohlcv_data(
                        eth_ohlcv[timeframe], timeframe, symbol='ETH'
                    )
                else:
                    ohlcv_features = np.zeros(self.state_components[f'eth_{timeframe}_ohlcv'])
                state_vector.extend(ohlcv_features)

            # 3. Process BTC reference data (2400 features)
            btc_features = self._process_btc_reference_data(btc_ohlcv)
            state_vector.extend(btc_features)

            # 4. Process CNN hidden layer features (512 features)
            cnn_hidden = self._process_cnn_hidden_features(cnn_hidden_features)
            state_vector.extend(cnn_hidden)

            # 5. Process CNN predictions (16 features)
            cnn_pred = self._process_cnn_predictions(cnn_predictions)
            state_vector.extend(cnn_pred)

            # 6. Process pivot points (250 features)
            pivot_features = self._process_pivot_points(pivot_data, eth_ohlcv)
            state_vector.extend(pivot_features)

            # 7. Process market regime features (20 features)
            regime_features = self._process_market_regime(eth_ohlcv, btc_ohlcv)
            state_vector.extend(regime_features)

            # Convert to numpy array and validate size
            state_array = np.array(state_vector, dtype=np.float32)

            if len(state_array) != self.total_state_size:
                logger.warning(f"State size mismatch: expected {self.total_state_size}, got {len(state_array)}")
                # Pad or truncate to expected size
                if len(state_array) < self.total_state_size:
                    padding = np.zeros(self.total_state_size - len(state_array))
                    state_array = np.concatenate([state_array, padding])
                else:
                    state_array = state_array[:self.total_state_size]

            # Apply normalization
            state_array = self._normalize_state(state_array)

            return state_array

        except Exception as e:
            logger.error(f"Error building RL state: {e}")
            # Return zero state on error
            return np.zeros(self.total_state_size, dtype=np.float32)
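
    # Hypothetical call pattern (see also the usage example at the end of this
    # file): every input may be empty, in which case the corresponding block is
    # zero-filled and the returned vector is still total_state_size long.
    #   state = builder.build_rl_state(eth_ticks=[], eth_ohlcv={}, btc_ohlcv={})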
    def _process_tick_data(self, ticks: List[TickData]) -> List[float]:
        """Process raw tick data into features for momentum detection"""
        features = []

        if not ticks or len(ticks) < 10:
            # Return zeros if insufficient data
            return [0.0] * self.state_components['eth_ticks']

        # Ensure we have exactly 300 data points (pad or sample)
        processed_ticks = self._normalize_tick_window(ticks, 300)

        for i, tick in enumerate(processed_ticks):
            # Basic tick features
            tick_features = [
                tick.price,
                tick.volume,
                tick.bid,
                tick.ask,
                tick.spread
            ]

            # Derived features
            if i > 0:
                prev_tick = processed_ticks[i-1]
                price_change = (tick.price - prev_tick.price) / prev_tick.price if prev_tick.price > 0 else 0
                volume_change = (tick.volume - prev_tick.volume) / prev_tick.volume if prev_tick.volume > 0 else 0

                tick_features.extend([
                    price_change,
                    volume_change,
                    tick.price / prev_tick.price - 1.0 if prev_tick.price > 0 else 0,  # Price ratio
                    np.log(tick.volume / prev_tick.volume) if prev_tick.volume > 0 else 0,  # Log volume ratio
                    self.momentum_detector.calculate_micro_momentum(processed_ticks[max(0, i-5):i+1])
                ])
            else:
                tick_features.extend([0.0, 0.0, 0.0, 0.0, 0.0])

            features.extend(tick_features)

        return features[:self.state_components['eth_ticks']]

    def _process_ohlcv_data(self, ohlcv_data: List[OHLCVData],
                            timeframe: str, symbol: str = 'ETH') -> List[float]:
        """Process OHLCV data with technical indicators"""
        features = []

        if not ohlcv_data or len(ohlcv_data) < 20:
            component_key = f'{symbol.lower()}_{timeframe}_ohlcv' if symbol == 'ETH' else 'btc_reference'
            return [0.0] * self.state_components[component_key]

        # Convert to DataFrame for indicator calculation
        df = pd.DataFrame([{
            'timestamp': bar.timestamp,
            'open': bar.open,
            'high': bar.high,
            'low': bar.low,
            'close': bar.close,
            'volume': bar.volume
        } for bar in ohlcv_data[-self.ohlcv_window_bars:]])

        # Calculate technical indicators
        df = self.indicator_calculator.add_all_indicators(df)

        # Ensure we have exactly 300 bars
        if len(df) < 300:
            # Pad with last known values
            last_row = df.iloc[-1:].copy()
            padding_rows = []
            for _ in range(300 - len(df)):
                padding_rows.append(last_row)
            if padding_rows:
                df = pd.concat([df] + padding_rows, ignore_index=True)
        else:
            df = df.tail(300)

        # Extract features for each bar
        feature_columns = ['open', 'high', 'low', 'close', 'volume', 'rsi', 'macd', 'bb_middle']

        for _, row in df.iterrows():
            bar_features = []
            for col in feature_columns:
                if col in row and not pd.isna(row[col]):
                    bar_features.append(float(row[col]))
                else:
                    bar_features.append(0.0)
            features.extend(bar_features)

        component_key = f'{symbol.lower()}_{timeframe}_ohlcv' if symbol == 'ETH' else 'btc_reference'
        return features[:self.state_components[component_key]]

    def _process_btc_reference_data(self, btc_ohlcv: Dict[str, List[OHLCVData]]) -> List[float]:
        """Process BTC reference data (using 1h timeframe as primary)"""
        if '1h' in btc_ohlcv and btc_ohlcv['1h']:
            return self._process_ohlcv_data(btc_ohlcv['1h'], '1h', 'BTC')
        elif '1m' in btc_ohlcv and btc_ohlcv['1m']:
            return self._process_ohlcv_data(btc_ohlcv['1m'], '1m', 'BTC')
        else:
            return [0.0] * self.state_components['btc_reference']

    def _process_cnn_hidden_features(self, cnn_features: Optional[Dict[str, np.ndarray]]) -> List[float]:
        """Process CNN hidden layer features"""
        if not cnn_features:
            return [0.0] * self.state_components['cnn_features']

        # Combine features from all timeframes
        combined_features = []
        timeframes = ['1s', '1m', '1h', '1d']
        features_per_timeframe = self.state_components['cnn_features'] // len(timeframes)

        for tf in timeframes:
            if tf in cnn_features and cnn_features[tf] is not None:
                tf_features = cnn_features[tf].flatten()
                # Truncate or pad to fit allocation
                if len(tf_features) >= features_per_timeframe:
                    combined_features.extend(tf_features[:features_per_timeframe])
                else:
                    combined_features.extend(tf_features)
                    combined_features.extend([0.0] * (features_per_timeframe - len(tf_features)))
            else:
                combined_features.extend([0.0] * features_per_timeframe)

        return combined_features[:self.state_components['cnn_features']]

    def _process_cnn_predictions(self, cnn_predictions: Optional[Dict[str, np.ndarray]]) -> List[float]:
        """Process CNN predictions from all timeframes"""
        if not cnn_predictions:
            return [0.0] * self.state_components['cnn_predictions']

        predictions = []
        timeframes = ['1s', '1m', '1h', '1d']

        for tf in timeframes:
            if tf in cnn_predictions and cnn_predictions[tf] is not None:
                pred = cnn_predictions[tf].flatten()
                # Expecting 4 outputs per timeframe (BUY, SELL, HOLD, confidence)
                if len(pred) >= 4:
                    predictions.extend(pred[:4])
                else:
                    predictions.extend(pred)
                    predictions.extend([0.0] * (4 - len(pred)))
            else:
                predictions.extend([0.0, 0.0, 1.0, 0.0])  # Default to HOLD with 0 confidence

        return predictions[:self.state_components['cnn_predictions']]

    def _process_pivot_points(self, pivot_data: Optional[Dict[str, Any]],
                              eth_ohlcv: Dict[str, List[OHLCVData]]) -> List[float]:
        """Process pivot points using Williams market structure"""
        if pivot_data:
            # Use provided pivot data
            return self._extract_pivot_features(pivot_data)
        elif '1m' in eth_ohlcv and eth_ohlcv['1m']:
            # Calculate pivot points from 1m data
            from training.williams_market_structure import WilliamsMarketStructure
            williams = WilliamsMarketStructure()

            # Convert OHLCV to numpy array
            ohlcv_array = self._ohlcv_to_array(eth_ohlcv['1m'])
            pivot_data = williams.calculate_recursive_pivot_points(ohlcv_array)
            return self._extract_pivot_features(pivot_data)
        else:
            return [0.0] * self.state_components['pivot_points']

    def _process_market_regime(self, eth_ohlcv: Dict[str, List[OHLCVData]],
                               btc_ohlcv: Dict[str, List[OHLCVData]]) -> List[float]:
        """Process market regime indicators"""
        regime_features = []

        # ETH regime analysis
        if '1h' in eth_ohlcv and eth_ohlcv['1h']:
            eth_regime = self.regime_analyzer.analyze_regime(eth_ohlcv['1h'])
            regime_features.extend([
                eth_regime['volatility'],
                eth_regime['trend_strength'],
                eth_regime['volume_trend'],
                eth_regime['momentum'],
                1.0 if eth_regime['regime'] == 'trending' else 0.0,
                1.0 if eth_regime['regime'] == 'ranging' else 0.0,
                1.0 if eth_regime['regime'] == 'volatile' else 0.0
            ])
        else:
            regime_features.extend([0.0] * 7)

        # BTC regime analysis
        if '1h' in btc_ohlcv and btc_ohlcv['1h']:
            btc_regime = self.regime_analyzer.analyze_regime(btc_ohlcv['1h'])
            regime_features.extend([
                btc_regime['volatility'],
                btc_regime['trend_strength'],
                btc_regime['volume_trend'],
                btc_regime['momentum'],
                1.0 if btc_regime['regime'] == 'trending' else 0.0,
                1.0 if btc_regime['regime'] == 'ranging' else 0.0,
                1.0 if btc_regime['regime'] == 'volatile' else 0.0
            ])
        else:
            regime_features.extend([0.0] * 7)

        # Correlation features
        correlation_features = self._calculate_btc_eth_correlation(eth_ohlcv, btc_ohlcv)
        regime_features.extend(correlation_features)

        return regime_features[:self.state_components['market_regime']]

    def _normalize_tick_window(self, ticks: List[TickData], target_size: int) -> List[TickData]:
        """Normalize tick window to target size"""
        if len(ticks) == target_size:
            return ticks
        elif len(ticks) > target_size:
            # Sample evenly
            step = len(ticks) / target_size
            indices = [int(i * step) for i in range(target_size)]
            return [ticks[i] for i in indices]
        else:
            # Pad with last tick
            result = ticks.copy()
            last_tick = ticks[-1] if ticks else TickData(datetime.now(), 0, 0)
            while len(result) < target_size:
                result.append(last_tick)
            return result

    def _extract_pivot_features(self, pivot_data: Dict[str, Any]) -> List[float]:
        """Extract features from pivot point data"""
        # NOTE: this walks dict-style level data, while WilliamsMarketStructure
        # returns MarketStructureLevel dataclasses; those would need converting
        # to dicts before this method can consume them.
        features = []

        for level in range(5):  # 5 levels of recursion
            level_key = f'level_{level}'
            if level_key in pivot_data:
                level_data = pivot_data[level_key]

                # Swing point features
                swing_points = level_data.get('swing_points', [])
                if swing_points:
                    # Last 10 swing points
                    recent_swings = swing_points[-10:]
                    for swing in recent_swings:
                        features.extend([
                            swing['price'],
                            1.0 if swing['type'] == 'swing_high' else 0.0,
                            swing['index']
                        ])

                    # Pad if fewer than 10 swings
                    while len(recent_swings) < 10:
                        features.extend([0.0, 0.0, 0.0])
                        recent_swings.append({'type': 'none'})
                else:
                    features.extend([0.0] * 30)  # 10 swings * 3 features

                # Trend features
                features.extend([
                    level_data.get('trend_strength', 0.0),
                    1.0 if level_data.get('trend_direction') == 'up' else 0.0,
                    1.0 if level_data.get('trend_direction') == 'down' else 0.0
                ])
            else:
                features.extend([0.0] * 33)  # 30 swing + 3 trend features

        return features[:self.state_components['pivot_points']]

    def _ohlcv_to_array(self, ohlcv_data: List[OHLCVData]) -> np.ndarray:
        """Convert OHLCV data to numpy array"""
        return np.array([[
            bar.timestamp.timestamp(),
            bar.open,
            bar.high,
            bar.low,
            bar.close,
            bar.volume
        ] for bar in ohlcv_data])

    def _calculate_btc_eth_correlation(self, eth_ohlcv: Dict[str, List[OHLCVData]],
                                       btc_ohlcv: Dict[str, List[OHLCVData]]) -> List[float]:
        """Calculate BTC-ETH correlation features"""
        try:
            # Use 1h data for correlation
            if '1h' not in eth_ohlcv or '1h' not in btc_ohlcv:
                return [0.0] * 6

            eth_prices = [bar.close for bar in eth_ohlcv['1h'][-50:]]  # Last 50 hours
            btc_prices = [bar.close for bar in btc_ohlcv['1h'][-50:]]

            if len(eth_prices) < 10 or len(btc_prices) < 10:
                return [0.0] * 6

            # Align lengths
            min_len = min(len(eth_prices), len(btc_prices))
            eth_prices = eth_prices[-min_len:]
            btc_prices = btc_prices[-min_len:]

            # Calculate returns
            eth_returns = np.diff(eth_prices) / eth_prices[:-1]
            btc_returns = np.diff(btc_prices) / btc_prices[:-1]

            # Correlation
            correlation = np.corrcoef(eth_returns, btc_returns)[0, 1] if len(eth_returns) > 1 else 0.0

            # Price ratio
            current_ratio = eth_prices[-1] / btc_prices[-1] if btc_prices[-1] > 0 else 0.0
            avg_ratio = np.mean([e/b for e, b in zip(eth_prices, btc_prices) if b > 0])
            ratio_deviation = (current_ratio - avg_ratio) / avg_ratio if avg_ratio > 0 else 0.0

            # Volatility comparison
            eth_vol = np.std(eth_returns) if len(eth_returns) > 1 else 0.0
            btc_vol = np.std(btc_returns) if len(btc_returns) > 1 else 0.0
            vol_ratio = eth_vol / btc_vol if btc_vol > 0 else 1.0

            return [
                correlation,
                current_ratio,
                ratio_deviation,
                vol_ratio,
                eth_vol,
                btc_vol
            ]

        except Exception as e:
            logger.warning(f"Error calculating BTC-ETH correlation: {e}")
            return [0.0] * 6

    def _initialize_normalization_params(self) -> Dict[str, Dict[str, float]]:
        """Initialize normalization parameters for different feature types"""
        return {
            'price_features': {'mean': 0.0, 'std': 1.0, 'min': -10.0, 'max': 10.0},
            'volume_features': {'mean': 0.0, 'std': 1.0, 'min': -5.0, 'max': 5.0},
            'indicator_features': {'mean': 0.0, 'std': 1.0, 'min': -3.0, 'max': 3.0},
            'cnn_features': {'mean': 0.0, 'std': 1.0, 'min': -2.0, 'max': 2.0},
            'pivot_features': {'mean': 0.0, 'std': 1.0, 'min': -5.0, 'max': 5.0}
        }

    def _normalize_state(self, state: np.ndarray) -> np.ndarray:
        """Apply normalization to state vector"""
        try:
            # Simple clipping and scaling for now.
            # More sophisticated normalization can be added based on training data.
            normalized_state = np.clip(state, -10.0, 10.0)

            # Replace any NaN or inf values
            normalized_state = np.nan_to_num(normalized_state, nan=0.0, posinf=10.0, neginf=-10.0)

            return normalized_state.astype(np.float32)

        except Exception as e:
            logger.error(f"Error normalizing state: {e}")
            return state.astype(np.float32)

class TickMomentumDetector:
    """Detect momentum from tick-level data"""

    def calculate_micro_momentum(self, ticks: List[TickData]) -> float:
        """Calculate micro-momentum from a tick sequence"""
        if len(ticks) < 2:
            return 0.0

        # Price momentum
        prices = [tick.price for tick in ticks]
        price_changes = np.diff(prices)
        price_momentum = np.sum(price_changes) / len(price_changes) if len(price_changes) > 0 else 0.0

        # Volume-weighted momentum (guard the trailing-volume sum, which is the divisor)
        volumes = [tick.volume for tick in ticks]
        if sum(volumes[1:]) > 0:
            weighted_changes = [pc * v for pc, v in zip(price_changes, volumes[1:])]
            volume_momentum = sum(weighted_changes) / sum(volumes[1:])
        else:
            volume_momentum = 0.0

        return (price_momentum + volume_momentum) / 2.0

class TechnicalIndicatorCalculator:
    """Calculate technical indicators for OHLCV data"""

    def add_all_indicators(self, df: pd.DataFrame) -> pd.DataFrame:
        """Add all technical indicators to a DataFrame"""
        df = df.copy()

        # RSI
        df['rsi'] = self.calculate_rsi(df['close'])

        # MACD
        df['macd'] = self.calculate_macd(df['close'])

        # Bollinger Bands
        df['bb_middle'] = df['close'].rolling(20).mean()
        df['bb_std'] = df['close'].rolling(20).std()
        df['bb_upper'] = df['bb_middle'] + (df['bb_std'] * 2)
        df['bb_lower'] = df['bb_middle'] - (df['bb_std'] * 2)

        # Forward-fill indicator warm-up gaps, then zero anything still missing
        df = df.ffill().fillna(0)

        return df

    def calculate_rsi(self, prices: pd.Series, period: int = 14) -> pd.Series:
        """Calculate RSI"""
        delta = prices.diff()
        gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
        loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
        rs = gain / loss
        rsi = 100 - (100 / (1 + rs))
        return rsi.fillna(50)

    def calculate_macd(self, prices: pd.Series, fast: int = 12, slow: int = 26) -> pd.Series:
        """Calculate MACD"""
        ema_fast = prices.ewm(span=fast).mean()
        ema_slow = prices.ewm(span=slow).mean()
        macd = ema_fast - ema_slow
        return macd.fillna(0)

class MarketRegimeAnalyzer:
    """Analyze market regime from OHLCV data"""

    def analyze_regime(self, ohlcv_data: List[OHLCVData]) -> Dict[str, Any]:
        """Analyze market regime"""
        if len(ohlcv_data) < 20:
            return {
                'regime': 'unknown',
                'volatility': 0.0,
                'trend_strength': 0.0,
                'volume_trend': 0.0,
                'momentum': 0.0
            }

        prices = [bar.close for bar in ohlcv_data[-50:]]  # Last 50 bars
        volumes = [bar.volume for bar in ohlcv_data[-50:]]

        # Calculate volatility
        returns = np.diff(prices) / prices[:-1]
        volatility = np.std(returns) * 100  # Percentage volatility

        # Calculate trend strength
        sma_short = np.mean(prices[-10:])
        sma_long = np.mean(prices[-30:])
        trend_strength = abs(sma_short - sma_long) / sma_long if sma_long > 0 else 0.0

        # Volume trend
        volume_ma_short = np.mean(volumes[-10:])
        volume_ma_long = np.mean(volumes[-30:])
        volume_trend = (volume_ma_short - volume_ma_long) / volume_ma_long if volume_ma_long > 0 else 0.0

        # Momentum
        momentum = (prices[-1] - prices[-10]) / prices[-10] if len(prices) >= 10 and prices[-10] > 0 else 0.0

        # Determine regime
        if volatility > 3.0:  # High volatility
            regime = 'volatile'
        elif abs(momentum) > 0.02:  # Strong momentum
            regime = 'trending'
        else:
            regime = 'ranging'

        return {
            'regime': regime,
            'volatility': volatility,
            'trend_strength': trend_strength,
            'volume_trend': volume_trend,
            'momentum': momentum
        }

    # NOTE: this accessor reads StateComponentConfig attributes from self.config
    # and window sizes that exist on EnhancedRLStateBuilder, none of which are
    # defined on MarketRegimeAnalyzer; it reads as if it belongs on the builder.
    def get_state_info(self) -> Dict[str, Any]:
        """Get information about the state structure"""
        return {
            'total_size': self.config.total_size,
            'components': {
                'eth_ticks': self.config.eth_ticks,
                'eth_1s_ohlcv': self.config.eth_1s_ohlcv,
                'eth_1m_ohlcv': self.config.eth_1m_ohlcv,
                'eth_1h_ohlcv': self.config.eth_1h_ohlcv,
                'eth_1d_ohlcv': self.config.eth_1d_ohlcv,
                'btc_reference': self.config.btc_reference,
                'cnn_features': self.config.cnn_features,
                'cnn_predictions': self.config.cnn_predictions,
                'pivot_points': self.config.pivot_points,
                'market_regime': self.config.market_regime,
            },
            'data_windows': {
                'tick_window_seconds': self.tick_window_seconds,
                'ohlcv_window_bars': self.ohlcv_window_bars,
            }
        }
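
A minimal end-to-end sketch of the builder above, using synthetic 1h bars; make_bar is a helper invented for this example, and it assumes the project's core modules are importable. Note that build_rl_state expects TickData/OHLCVData objects, while the trainer hunks below pass plain dicts, so a conversion step is implied there.

    from datetime import datetime, timedelta
    from training.enhanced_rl_state_builder import EnhancedRLStateBuilder, OHLCVData

    def make_bar(i):  # synthetic bar, illustration only
        return OHLCVData(timestamp=datetime.now() - timedelta(hours=300 - i),
                         open=100.0 + i, high=101.0 + i, low=99.0 + i,
                         close=100.5 + i, volume=10.0)

    builder = EnhancedRLStateBuilder(config={})
    bars = [make_bar(i) for i in range(300)]
    state = builder.build_rl_state(eth_ticks=[],            # no ticks -> zero-filled block
                                   eth_ohlcv={'1h': bars},
                                   btc_ohlcv={'1h': bars})
    assert state.shape == (builder.total_state_size,)       # (15798,)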
@@ -10,14 +10,16 @@ This module implements sophisticated RL training with:

import asyncio
import logging
import time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque, namedtuple
import random
from datetime import datetime, timedelta
-from typing import Dict, List, Optional, Tuple, Any
+from typing import Dict, List, Optional, Tuple, Any, Union
import matplotlib.pyplot as plt
from pathlib import Path

@@ -26,6 +28,9 @@ from core.data_provider import DataProvider
from core.enhanced_orchestrator import EnhancedTradingOrchestrator, MarketState, TradingAction
from models import RLAgentInterface
import models
+from training.enhanced_rl_state_builder import EnhancedRLStateBuilder
+from training.williams_market_structure import WilliamsMarketStructure
+from training.cnn_rl_bridge import CNNRLBridge

logger = logging.getLogger(__name__)

@@ -318,42 +323,66 @@ class EnhancedDQNAgent(nn.Module, RLAgentInterface):
        return (param_count * 4 + buffer_size) // (1024 * 1024)

class EnhancedRLTrainer:
-    """Enhanced RL trainer with continuous learning from market feedback"""
+    """Enhanced RL trainer with comprehensive state representation and real data integration"""

    def __init__(self, config: Optional[Dict] = None, orchestrator: EnhancedTradingOrchestrator = None):
-        """Initialize the enhanced RL trainer"""
+        """Initialize enhanced RL trainer with comprehensive state building"""
        self.config = config or get_config()
        self.orchestrator = orchestrator
        self.data_provider = DataProvider(self.config)

-        # Create RL agents for each symbol
+        # Initialize comprehensive state builder (replaces mock code)
+        self.state_builder = EnhancedRLStateBuilder(self.config)
+        self.williams_structure = WilliamsMarketStructure()
+        self.cnn_rl_bridge = CNNRLBridge(self.config) if hasattr(self.config, 'cnn_models') else None
+
+        # Enhanced RL agents with much larger state space
+        self.symbols = self.config.symbols  # set before initialize_agents(), which iterates over it
        self.agents = {}
-        for symbol in self.config.symbols:
-            agent_config = self.config.rl.copy()
-            agent_config['name'] = f'RL_{symbol}'
-            self.agents[symbol] = EnhancedDQNAgent(agent_config)
+        self.initialize_agents()

-        # Training parameters
-        self.training_interval = 3600  # Train every hour
-        self.evaluation_window = 24 * 3600  # Evaluate actions after 24 hours
-        self.min_experiences = 100  # Minimum experiences before training
-
-        # Performance tracking
-        self.performance_history = {symbol: [] for symbol in self.config.symbols}
-        self.training_metrics = {
-            'total_episodes': 0,
-            'total_rewards': {symbol: [] for symbol in self.config.symbols},
-            'losses': {symbol: [] for symbol in self.config.symbols},
-            'epsilon_values': {symbol: [] for symbol in self.config.symbols}
-        }
-
-        # Create save directory
-        models_path = self.config.rl.get('model_dir', "models/enhanced_rl")
-        self.save_dir = Path(models_path)
+        # Training configuration
+        self.save_dir = Path(self.config.rl.get('save_dir', 'models/rl/saved'))
        self.save_dir.mkdir(parents=True, exist_ok=True)

-        logger.info(f"Enhanced RL trainer initialized for symbols: {self.config.symbols}")
-
+        # Performance tracking
+        self.training_metrics = {
+            'total_episodes': 0,
+            'total_rewards': {symbol: [] for symbol in self.symbols},
+            'losses': {symbol: [] for symbol in self.symbols},
+            'epsilon_values': {symbol: [] for symbol in self.symbols}
+        }
+
+        self.performance_history = {symbol: [] for symbol in self.symbols}
+
+        # Real-time learning parameters
+        self.learning_active = False
+        self.experience_buffer_size = 1000
+        self.min_experiences_for_training = 100
+
+        logger.info("Enhanced RL Trainer initialized with comprehensive state representation")
+        logger.info(f"State builder total size: {self.state_builder.total_state_size} features")
+        logger.info(f"Symbols: {self.symbols}")
+
+    def initialize_agents(self):
+        """Initialize RL agents with enhanced state size"""
+        for symbol in self.symbols:
+            agent_config = {
+                'state_size': self.state_builder.total_state_size,  # 15,798 features
+                'action_space': 3,  # BUY, SELL, HOLD
+                'hidden_size': 1024,  # Larger hidden layers for the complex state
+                'learning_rate': 0.0001,
+                'gamma': 0.99,
+                'epsilon': 1.0,
+                'epsilon_decay': 0.995,
+                'epsilon_min': 0.01,
+                'buffer_size': 50000,  # Larger replay buffer
+                'batch_size': 128,
+                'target_update_freq': 1000
+            }
+
+            self.agents[symbol] = EnhancedDQNAgent(agent_config)
+            logger.info(f"Initialized {symbol} RL agent with state size: {agent_config['state_size']}")

    async def continuous_learning_loop(self):
        """Main continuous learning loop"""
        logger.info("Starting continuous RL learning loop")
@@ -378,7 +407,7 @@ class EnhancedRLTrainer:
                self._save_all_models()

                # Wait before next training cycle
-                await asyncio.sleep(self.training_interval)
+                await asyncio.sleep(3600)  # Train every hour

            except Exception as e:
                logger.error(f"Error in continuous learning loop: {e}")
@@ -388,7 +417,7 @@ class EnhancedRLTrainer:
        """Train all RL agents with their experiences"""
        for symbol, agent in self.agents.items():
            try:
-                if len(agent.replay_buffer) >= self.min_experiences:
+                if len(agent.replay_buffer) >= self.min_experiences_for_training:
                    # Train for multiple steps
                    losses = []
                    for _ in range(10):  # Train 10 steps per cycle
@@ -411,7 +440,7 @@ class EnhancedRLTrainer:
        if not self.orchestrator:
            return

-        for symbol in self.config.symbols:
+        for symbol in self.symbols:
            try:
                # Get recent market states
                recent_states = list(self.orchestrator.market_states[symbol])[-10:]  # Last 10 states
@@ -471,11 +500,150 @@ class EnhancedRLTrainer:
                logger.error(f"Error adding experience for {symbol}: {e}")

    def _market_state_to_rl_state(self, market_state: MarketState) -> np.ndarray:
-        """Convert market state to RL state vector"""
-        if hasattr(self.orchestrator, '_market_state_to_rl_state'):
-            return self.orchestrator._market_state_to_rl_state(market_state)
+        """Convert market state to comprehensive RL state vector using real data"""
+        try:
+            # Extract data from market state and orchestrator
+            if not self.orchestrator:
+                logger.warning("No orchestrator available for comprehensive state building")
+                return self._fallback_state_conversion(market_state)
+
+            # Get real tick data from orchestrator's data provider
+            symbol = market_state.symbol
+            eth_ticks = self._get_recent_tick_data(symbol, seconds=300)
+
+            # Get multi-timeframe OHLCV data
+            eth_ohlcv = self._get_multiframe_ohlcv_data(symbol)
+            btc_ohlcv = self._get_multiframe_ohlcv_data('BTC/USDT')
+
+            # Get CNN features if available
+            cnn_hidden_features = None
+            cnn_predictions = None
+            if self.cnn_rl_bridge:
+                cnn_data = self.cnn_rl_bridge.get_latest_features_for_symbol(symbol)
+                if cnn_data:
+                    cnn_hidden_features = cnn_data.get('hidden_features', {})
+                    cnn_predictions = cnn_data.get('predictions', {})
+
+            # Get pivot point data
+            pivot_data = self._calculate_pivot_points(eth_ohlcv)
+
+            # Build comprehensive state using enhanced state builder
+            comprehensive_state = self.state_builder.build_rl_state(
+                eth_ticks=eth_ticks,
+                eth_ohlcv=eth_ohlcv,
+                btc_ohlcv=btc_ohlcv,
+                cnn_hidden_features=cnn_hidden_features,
+                cnn_predictions=cnn_predictions,
+                pivot_data=pivot_data
+            )
+
+            logger.debug(f"Built comprehensive RL state: {len(comprehensive_state)} features")
+            return comprehensive_state
+
+        except Exception as e:
+            logger.error(f"Error building comprehensive RL state: {e}")
+            return self._fallback_state_conversion(market_state)
+
+    def _get_recent_tick_data(self, symbol: str, seconds: int = 300) -> List:
+        """Get recent tick data from orchestrator's data provider"""
+        try:
+            if hasattr(self.orchestrator, 'data_provider') and self.orchestrator.data_provider:
+                # Get recent ticks from data provider
+                recent_ticks = self.orchestrator.data_provider.get_recent_ticks(symbol, count=seconds*10)
+
+                # Convert to required format
+                tick_data = []
+                for tick in recent_ticks[-300:]:  # Last 300 ticks max
+                    tick_data.append({
+                        'timestamp': tick.timestamp,
+                        'price': tick.price,
+                        'volume': tick.volume,
+                        'quantity': getattr(tick, 'quantity', tick.volume),
+                        'side': getattr(tick, 'side', 'unknown'),
+                        'trade_id': getattr(tick, 'trade_id', 'unknown')
+                    })

+                return tick_data
+
+            return []
+
+        except Exception as e:
+            logger.warning(f"Error getting tick data for {symbol}: {e}")
+            return []
+
+    def _get_multiframe_ohlcv_data(self, symbol: str) -> Dict[str, List]:
+        """Get multi-timeframe OHLCV data"""
+        try:
+            if hasattr(self.orchestrator, 'data_provider') and self.orchestrator.data_provider:
+                ohlcv_data = {}
+                timeframes = ['1s', '1m', '1h', '1d']
+
+                for tf in timeframes:
+                    try:
+                        # Get historical data for timeframe
+                        df = self.orchestrator.data_provider.get_historical_data(
+                            symbol=symbol,
+                            timeframe=tf,
+                            limit=300,
+                            refresh=True
+                        )
+
+                        if df is not None and not df.empty:
+                            # Convert to list of dictionaries
+                            bars = []
+                            for _, row in df.tail(300).iterrows():
+                                bar = {
+                                    'timestamp': row.name if hasattr(row, 'name') else datetime.now(),
+                                    'open': float(row.get('open', 0)),
+                                    'high': float(row.get('high', 0)),
+                                    'low': float(row.get('low', 0)),
+                                    'close': float(row.get('close', 0)),
+                                    'volume': float(row.get('volume', 0))
+                                }
+                                bars.append(bar)
+
+                            ohlcv_data[tf] = bars
+                        else:
+                            ohlcv_data[tf] = []
+
+                    except Exception as e:
+                        logger.warning(f"Error getting {tf} data for {symbol}: {e}")
+                        ohlcv_data[tf] = []
+
+                return ohlcv_data
+
+            return {}
+
+        except Exception as e:
+            logger.warning(f"Error getting OHLCV data for {symbol}: {e}")
+            return {}
+
+    def _calculate_pivot_points(self, eth_ohlcv: Dict[str, List]) -> Dict[str, Any]:
+        """Calculate Williams pivot points from OHLCV data"""
+        try:
+            if '1m' in eth_ohlcv and eth_ohlcv['1m']:
+                # Convert to numpy array for Williams calculation
+                bars = eth_ohlcv['1m']
+                if len(bars) >= 50:  # Need minimum data for pivot calculation
+                    ohlc_array = np.array([
+                        [bar['timestamp'].timestamp() if hasattr(bar['timestamp'], 'timestamp') else time.time(),
+                         bar['open'], bar['high'], bar['low'], bar['close'], bar['volume']]
+                        for bar in bars[-200:]  # Last 200 bars
+                    ])
+
+                    pivot_data = self.williams_structure.calculate_recursive_pivot_points(ohlc_array)
+                    return pivot_data
+
+            return {}
+
+        except Exception as e:
+            logger.warning(f"Error calculating pivot points: {e}")
+            return {}
+
+    def _fallback_state_conversion(self, market_state: MarketState) -> np.ndarray:
+        """Fallback to basic state conversion if comprehensive state building fails"""
+        logger.warning("Using fallback state conversion - limited features")
+
        # Fallback implementation
        state_components = [
            market_state.volatility,
            market_state.volume,
@@ -486,8 +654,8 @@ class EnhancedRLTrainer:
        for timeframe in sorted(market_state.prices.keys()):
            state_components.append(market_state.prices[timeframe])

-        # Pad or truncate to expected state size
-        expected_size = self.config.rl.get('state_size', 100)
+        # Pad to match expected state size
+        expected_size = self.state_builder.total_state_size
        if len(state_components) < expected_size:
            state_components.extend([0.0] * (expected_size - len(state_components)))
        else:
@@ -545,7 +713,7 @@ class EnhancedRLTrainer:
            timestamp = max(timestamps)

            loaded_count = 0
-            for symbol in self.config.symbols:
+            for symbol in self.symbols:
                filename = f"rl_agent_{symbol}_{timestamp}.pt"
                filepath = self.save_dir / filename
training/williams_market_structure.py (new file, 640 lines)
@@ -0,0 +1,640 @@
"""
Williams Market Structure Implementation for RL Training

This module implements Larry Williams' market structure analysis methodology for
RL training enhancement with:
- Swing high/low detection with configurable strength
- 5 levels of recursive pivot point calculation
- Trend analysis (higher highs/lows vs lower highs/lows)
- Market bias determination across multiple timeframes
- Feature extraction for RL training (250 features)

Based on Larry Williams' teachings on market structure:
- Markets move in swings between support and resistance
- Higher timeframe structure determines lower timeframe bias
- Recursive analysis reveals fractal patterns
- Trend direction is determined by swing point relationships
"""

import numpy as np
import pandas as pd
import logging
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
from enum import Enum

logger = logging.getLogger(__name__)

class TrendDirection(Enum):
    UP = "up"
    DOWN = "down"
    SIDEWAYS = "sideways"
    UNKNOWN = "unknown"

class SwingType(Enum):
    SWING_HIGH = "swing_high"
    SWING_LOW = "swing_low"

@dataclass
class SwingPoint:
    """Represents a swing high or low point"""
    timestamp: datetime
    price: float
    index: int
    swing_type: SwingType
    strength: int  # Number of bars on each side that confirm the swing
    volume: float = 0.0

@dataclass
class TrendAnalysis:
    """Trend analysis results"""
    direction: TrendDirection
    strength: float    # 0.0 to 1.0
    confidence: float  # 0.0 to 1.0
    swing_count: int
    last_swing_high: Optional[SwingPoint]
    last_swing_low: Optional[SwingPoint]
    higher_highs: int
    higher_lows: int
    lower_highs: int
    lower_lows: int

@dataclass
class MarketStructureLevel:
    """Market structure analysis for one recursive level"""
    level: int
    swing_points: List[SwingPoint]
    trend_analysis: TrendAnalysis
    support_levels: List[float]
    resistance_levels: List[float]
    current_bias: TrendDirection
    structure_breaks: List[Dict[str, Any]]

class WilliamsMarketStructure:
    """
    Implementation of Larry Williams' market structure methodology

    Features:
    - Multi-strength swing detection (2, 3, 5, 8, 13 bar strengths)
    - 5 levels of recursive analysis
    - Trend direction determination
    - Support/resistance level identification
    - Market bias calculation
    - Structure break detection
    """

    def __init__(self, swing_strengths: List[int] = None):
        """
        Initialize Williams market structure analyzer

        Args:
            swing_strengths: List of swing detection strengths (bars on each side)
        """
        self.swing_strengths = swing_strengths or [2, 3, 5, 8, 13]  # Fibonacci-based strengths
        self.max_levels = 5
        self.min_swings_for_trend = 3

        # Cache for performance
        self.swing_cache = {}
        self.trend_cache = {}

        logger.info(f"Williams Market Structure initialized with strengths: {self.swing_strengths}")
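
    # Illustrative usage (assumed input shape, not part of the class):
    #   wms = WilliamsMarketStructure()
    #   ohlcv = np.asarray(rows)                     # shape (N, 6): ts, O, H, L, C, V
    #   levels = wms.calculate_recursive_pivot_points(ohlcv)
    #   levels['level_0'].trend_analysis.direction   # TrendDirection of the raw series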
def calculate_recursive_pivot_points(self, ohlcv_data: np.ndarray) -> Dict[str, MarketStructureLevel]:
|
||||
"""
|
||||
Calculate 5 levels of recursive pivot points
|
||||
|
||||
Args:
|
||||
ohlcv_data: OHLCV data array with columns [timestamp, open, high, low, close, volume]
|
||||
|
||||
Returns:
|
||||
Dict of market structure levels with swing points and trend analysis
|
||||
"""
|
||||
if len(ohlcv_data) < 20:
|
||||
logger.warning("Insufficient data for Williams structure analysis")
|
||||
return self._create_empty_structure()
|
||||
|
||||
levels = {}
|
||||
current_data = ohlcv_data.copy()
|
||||
|
||||
for level in range(self.max_levels):
|
||||
logger.debug(f"Analyzing level {level} with {len(current_data)} data points")
|
||||
|
||||
# Find swing points for this level
|
||||
swing_points = self._find_swing_points_multi_strength(current_data)
|
||||
|
||||
if len(swing_points) < self.min_swings_for_trend:
|
||||
logger.debug(f"Not enough swings at level {level}: {len(swing_points)}")
|
||||
# Fill remaining levels with empty data
|
||||
for remaining_level in range(level, self.max_levels):
|
||||
levels[f'level_{remaining_level}'] = self._create_empty_level(remaining_level)
|
||||
break
|
||||
|
||||
# Analyze trend for this level
|
||||
trend_analysis = self._analyze_trend_from_swings(swing_points)
|
||||
|
||||
# Find support/resistance levels
|
||||
support_levels, resistance_levels = self._find_support_resistance(
|
||||
swing_points, current_data
|
||||
)
|
||||
|
||||
# Determine current market bias
|
||||
current_bias = self._determine_market_bias(swing_points, trend_analysis)
|
||||
|
||||
# Detect structure breaks
|
||||
structure_breaks = self._detect_structure_breaks(swing_points, current_data)
|
||||
|
||||
# Create level data
|
||||
levels[f'level_{level}'] = MarketStructureLevel(
|
||||
level=level,
|
||||
swing_points=swing_points,
|
||||
trend_analysis=trend_analysis,
|
||||
support_levels=support_levels,
|
||||
resistance_levels=resistance_levels,
|
||||
current_bias=current_bias,
|
||||
structure_breaks=structure_breaks
|
||||
)
|
||||
|
||||
# Prepare data for next level (use swing points as input)
|
||||
if len(swing_points) >= 5:
|
||||
current_data = self._convert_swings_to_ohlcv(swing_points)
|
||||
if len(current_data) < 10:
|
||||
logger.debug(f"Insufficient converted data for level {level + 1}")
|
||||
break
|
||||
else:
|
||||
logger.debug(f"Not enough swings to continue to level {level + 1}")
|
||||
break
|
||||
|
||||
# Fill any remaining empty levels
|
||||
for remaining_level in range(len(levels), self.max_levels):
|
||||
levels[f'level_{remaining_level}'] = self._create_empty_level(remaining_level)
|
||||
|
||||
return levels
|
||||
|
||||
def _find_swing_points_multi_strength(self, ohlcv_data: np.ndarray) -> List[SwingPoint]:
|
||||
"""Find swing points using multiple strength criteria"""
|
||||
all_swings = []
|
||||
|
||||
for strength in self.swing_strengths:
|
||||
swings = self._find_swing_points_single_strength(ohlcv_data, strength)
|
||||
for swing in swings:
|
||||
# Avoid duplicates (swings at same index)
|
||||
if not any(existing.index == swing.index for existing in all_swings):
|
||||
all_swings.append(swing)
|
||||
|
||||
# Sort by timestamp/index
|
||||
all_swings.sort(key=lambda x: x.index)
|
||||
|
||||
# Filter to get the most significant swings
|
||||
return self._filter_significant_swings(all_swings)
|
||||
|
||||
    def _find_swing_points_single_strength(self, ohlcv_data: np.ndarray, strength: int) -> List[SwingPoint]:
        """Find swing points with a specific strength requirement"""
        swings = []

        if len(ohlcv_data) < (strength * 2 + 1):
            return swings

        for i in range(strength, len(ohlcv_data) - strength):
            current_high = ohlcv_data[i, 2]  # High price
            current_low = ohlcv_data[i, 3]   # Low price
            current_volume = ohlcv_data[i, 5] if ohlcv_data.shape[1] > 5 else 0.0

            # Check for a swing high (higher than all surrounding bars)
            is_swing_high = True
            for j in range(i - strength, i + strength + 1):
                if j != i and ohlcv_data[j, 2] >= current_high:
                    is_swing_high = False
                    break

            if is_swing_high:
                swings.append(SwingPoint(
                    timestamp=datetime.fromtimestamp(ohlcv_data[i, 0]) if ohlcv_data[i, 0] > 1e9 else datetime.now(),
                    price=current_high,
                    index=i,
                    swing_type=SwingType.SWING_HIGH,
                    strength=strength,
                    volume=current_volume
                ))

            # Check for a swing low (lower than all surrounding bars)
            is_swing_low = True
            for j in range(i - strength, i + strength + 1):
                if j != i and ohlcv_data[j, 3] <= current_low:
                    is_swing_low = False
                    break

            if is_swing_low:
                swings.append(SwingPoint(
                    timestamp=datetime.fromtimestamp(ohlcv_data[i, 0]) if ohlcv_data[i, 0] > 1e9 else datetime.now(),
                    price=current_low,
                    index=i,
                    swing_type=SwingType.SWING_LOW,
                    strength=strength,
                    volume=current_volume
                ))

        return swings

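    # Worked example for the swing definition above: with highs
    # [10, 12, 15, 13, 11] and strength=2, only index 2 (high=15) has two
    # bars on each side that are all strictly lower, so it is recorded as a
    # strength-2 SWING_HIGH. Ties break the swing: the comparison uses >=,
    # so an equal high anywhere in the window disqualifies the candidate.
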
    def _filter_significant_swings(self, swings: List[SwingPoint]) -> List[SwingPoint]:
        """Filter to keep only the most significant swings"""
        if len(swings) <= 20:
            return swings

        # Sort by strength (higher strength = more significant)
        swings_by_strength = sorted(swings, key=lambda x: x.strength, reverse=True)

        # Take the top swings, preferring alternating highs and lows
        significant_swings = []
        last_type = None

        for swing in swings_by_strength:
            if len(significant_swings) >= 20:
                break

            # Prefer alternating swing types for cleaner structure
            if last_type is None or swing.swing_type != last_type:
                significant_swings.append(swing)
                last_type = swing.swing_type
            elif len(significant_swings) < 10:  # Still add if we need more swings
                significant_swings.append(swing)

        # Restore chronological order
        significant_swings.sort(key=lambda x: x.index)
        return significant_swings

    def _analyze_trend_from_swings(self, swing_points: List[SwingPoint]) -> TrendAnalysis:
        """Analyze trend direction from swing points"""
        if len(swing_points) < 2:
            return TrendAnalysis(
                direction=TrendDirection.UNKNOWN,
                strength=0.0,
                confidence=0.0,
                swing_count=0,
                last_swing_high=None,
                last_swing_low=None,
                higher_highs=0,
                higher_lows=0,
                lower_highs=0,
                lower_lows=0
            )

        # Separate highs and lows
        highs = [s for s in swing_points if s.swing_type == SwingType.SWING_HIGH]
        lows = [s for s in swing_points if s.swing_type == SwingType.SWING_LOW]

        # Count higher highs, higher lows, lower highs, lower lows
        higher_highs = self._count_higher_highs(highs)
        higher_lows = self._count_higher_lows(lows)
        lower_highs = self._count_lower_highs(highs)
        lower_lows = self._count_lower_lows(lows)

        # Determine trend direction
        if higher_highs > 0 and higher_lows > 0:
            direction = TrendDirection.UP
        elif lower_highs > 0 and lower_lows > 0:
            direction = TrendDirection.DOWN
        else:
            direction = TrendDirection.SIDEWAYS

        # Calculate trend strength
        total_moves = higher_highs + higher_lows + lower_highs + lower_lows
        if direction == TrendDirection.UP:
            strength = (higher_highs + higher_lows) / max(total_moves, 1)
        elif direction == TrendDirection.DOWN:
            strength = (lower_highs + lower_lows) / max(total_moves, 1)
        else:
            strength = 0.5  # Neutral for sideways

        # Calculate confidence based on consistency
        if total_moves > 0:
            if direction == TrendDirection.UP:
                confidence = (higher_highs + higher_lows) / total_moves
            elif direction == TrendDirection.DOWN:
                confidence = (lower_highs + lower_lows) / total_moves
            else:
                # For sideways, confidence reflects how balanced the moves are
                up_moves = higher_highs + higher_lows
                down_moves = lower_highs + lower_lows
                confidence = 1.0 - abs(up_moves - down_moves) / total_moves
        else:
            confidence = 0.0

        return TrendAnalysis(
            direction=direction,
            strength=min(strength, 1.0),
            confidence=min(confidence, 1.0),
            swing_count=len(swing_points),
            last_swing_high=highs[-1] if highs else None,
            last_swing_low=lows[-1] if lows else None,
            higher_highs=higher_highs,
            higher_lows=higher_lows,
            lower_highs=lower_highs,
            lower_lows=lower_lows
        )

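    # Worked example for the strength/confidence math above: with
    # higher_highs=3, higher_lows=2, lower_highs=1, lower_lows=0, both HH
    # and HL counts are positive, so direction=UP; total_moves=6 and
    # strength = confidence = (3 + 2) / 6, about 0.83.
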
    def _count_higher_highs(self, highs: List[SwingPoint]) -> int:
        """Count higher highs in sequence"""
        if len(highs) < 2:
            return 0
        return sum(1 for i in range(1, len(highs)) if highs[i].price > highs[i - 1].price)

    def _count_higher_lows(self, lows: List[SwingPoint]) -> int:
        """Count higher lows in sequence"""
        if len(lows) < 2:
            return 0
        return sum(1 for i in range(1, len(lows)) if lows[i].price > lows[i - 1].price)

    def _count_lower_highs(self, highs: List[SwingPoint]) -> int:
        """Count lower highs in sequence"""
        if len(highs) < 2:
            return 0
        return sum(1 for i in range(1, len(highs)) if highs[i].price < highs[i - 1].price)

    def _count_lower_lows(self, lows: List[SwingPoint]) -> int:
        """Count lower lows in sequence"""
        if len(lows) < 2:
            return 0
        return sum(1 for i in range(1, len(lows)) if lows[i].price < lows[i - 1].price)

    def _find_support_resistance(self, swing_points: List[SwingPoint],
                                 ohlcv_data: np.ndarray) -> Tuple[List[float], List[float]]:
        """Find support and resistance levels from swing points"""
        highs = [s.price for s in swing_points if s.swing_type == SwingType.SWING_HIGH]
        lows = [s.price for s in swing_points if s.swing_type == SwingType.SWING_LOW]

        # Cluster similar levels: swing lows form support, swing highs resistance
        support_levels = self._cluster_price_levels(lows) if lows else []
        resistance_levels = self._cluster_price_levels(highs) if highs else []

        return support_levels, resistance_levels

    def _cluster_price_levels(self, prices: List[float], tolerance: float = 0.02) -> List[float]:
        """Cluster similar price levels together"""
        if not prices:
            return []

        sorted_prices = sorted(prices)
        clusters = []
        current_cluster = [sorted_prices[0]]

        for price in sorted_prices[1:]:
            # If the price is within tolerance of the running cluster average,
            # add it to the cluster; otherwise start a new cluster
            cluster_avg = np.mean(current_cluster)
            if abs(price - cluster_avg) / cluster_avg <= tolerance:
                current_cluster.append(price)
            else:
                clusters.append(float(np.mean(current_cluster)))
                current_cluster = [price]

        # Add the last cluster
        if current_cluster:
            clusters.append(float(np.mean(current_cluster)))

        return clusters

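    # Worked example for the clustering above: prices [100.0, 101.5, 110.0]
    # with the default 2% tolerance. 101.5 is within 1.5% of 100.0, so it
    # joins the first cluster (average 100.75); 110.0 is ~9% away from that
    # average, so it starts a new cluster. Result: [100.75, 110.0].
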
    def _determine_market_bias(self, swing_points: List[SwingPoint],
                               trend_analysis: TrendAnalysis) -> TrendDirection:
        """Determine current market bias"""
        if not swing_points:
            return TrendDirection.UNKNOWN

        # Use trend analysis as the primary indicator
        if trend_analysis.confidence > 0.6:
            return trend_analysis.direction

        # Otherwise infer bias from the most recent swings
        recent_swings = swing_points[-6:]

        if len(recent_swings) >= 2:
            first_price = recent_swings[0].price
            last_price = recent_swings[-1].price
            price_change = (last_price - first_price) / first_price

            if price_change > 0.01:  # 1% threshold
                return TrendDirection.UP
            elif price_change < -0.01:
                return TrendDirection.DOWN
            else:
                return TrendDirection.SIDEWAYS

        return TrendDirection.UNKNOWN

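    # Worked example for the fallback above: if trend confidence is only
    # 0.4 but the last six swings moved from 100.0 to 102.0, then
    # price_change = +2%, which clears the 1% threshold, so the bias is
    # TrendDirection.UP even though the trend analysis was inconclusive.
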
    def _detect_structure_breaks(self, swing_points: List[SwingPoint],
                                 ohlcv_data: np.ndarray) -> List[Dict[str, Any]]:
        """Detect structure breaks (trend changes)"""
        structure_breaks = []

        if len(swing_points) < 4:
            return structure_breaks

        # Look for pattern breaks in the separated highs and lows
        highs = [s for s in swing_points if s.swing_type == SwingType.SWING_HIGH]
        lows = [s for s in swing_points if s.swing_type == SwingType.SWING_LOW]

        # Break of structure in highs: a lower high after a higher high
        if len(highs) >= 3:
            for i in range(2, len(highs)):
                if (highs[i-2].price < highs[i-1].price and  # Previous was a higher high
                        highs[i-1].price > highs[i].price):  # Current is a lower high
                    structure_breaks.append({
                        'type': 'break_of_structure_high',
                        'timestamp': highs[i].timestamp,
                        'price': highs[i].price,
                        'previous_high': highs[i-1].price,
                        'significance': abs(highs[i].price - highs[i-1].price) / highs[i-1].price
                    })

        # Break of structure in lows: a higher low after a lower low
        if len(lows) >= 3:
            for i in range(2, len(lows)):
                if (lows[i-2].price > lows[i-1].price and  # Previous was a lower low
                        lows[i-1].price < lows[i].price):  # Current is a higher low
                    structure_breaks.append({
                        'type': 'break_of_structure_low',
                        'timestamp': lows[i].timestamp,
                        'price': lows[i].price,
                        'previous_low': lows[i-1].price,
                        'significance': abs(lows[i].price - lows[i-1].price) / lows[i-1].price
                    })

        return structure_breaks

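    # Worked example for the detection above: swing highs at 100 -> 105 -> 103
    # match the pattern (100 < 105 and 105 > 103), so a
    # 'break_of_structure_high' is recorded at price 103 with
    # significance |103 - 105| / 105, about 0.019.
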
    def _convert_swings_to_ohlcv(self, swing_points: List[SwingPoint]) -> np.ndarray:
        """Convert swing points to OHLCV format for next-level analysis"""
        if len(swing_points) < 2:
            return np.array([])

        ohlcv_data = []

        for i in range(len(swing_points) - 1):
            current_swing = swing_points[i]
            next_swing = swing_points[i + 1]

            # Create a synthetic OHLCV bar spanning swing to swing
            if current_swing.swing_type == SwingType.SWING_HIGH:
                # From a high down to the next point
                open_price = current_swing.price
                high_price = current_swing.price
                low_price = min(current_swing.price, next_swing.price)
                close_price = next_swing.price
            else:
                # From a low up to the next point
                open_price = current_swing.price
                high_price = max(current_swing.price, next_swing.price)
                low_price = current_swing.price
                close_price = next_swing.price

            ohlcv_data.append([
                current_swing.timestamp.timestamp(),
                open_price,
                high_price,
                low_price,
                close_price,
                current_swing.volume
            ])

        return np.array(ohlcv_data)

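    # Worked example for the conversion above: a SWING_HIGH at 105.0 followed
    # by a SWING_LOW at 98.0 becomes one synthetic bar with open=105.0,
    # high=105.0, low=98.0, close=98.0, stamped with the first swing's
    # timestamp and volume.
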
    def _create_empty_structure(self) -> Dict[str, MarketStructureLevel]:
        """Create an empty structure when there is insufficient data"""
        return {f'level_{i}': self._create_empty_level(i) for i in range(self.max_levels)}

    def _create_empty_level(self, level: int) -> MarketStructureLevel:
        """Create an empty market structure level"""
        return MarketStructureLevel(
            level=level,
            swing_points=[],
            trend_analysis=TrendAnalysis(
                direction=TrendDirection.UNKNOWN,
                strength=0.0,
                confidence=0.0,
                swing_count=0,
                last_swing_high=None,
                last_swing_low=None,
                higher_highs=0,
                higher_lows=0,
                lower_highs=0,
                lower_lows=0
            ),
            support_levels=[],
            resistance_levels=[],
            current_bias=TrendDirection.UNKNOWN,
            structure_breaks=[]
        )

    def extract_features_for_rl(self, structure_levels: Dict[str, MarketStructureLevel]) -> List[float]:
        """
        Extract features from the Williams structure for RL training.

        Returns exactly 250 features: 50 per level across max_levels (5) levels.
        Missing levels contribute 50 zeros.
        """
        features = []

        for level in range(self.max_levels):
            level_key = f'level_{level}'
            if level_key in structure_levels:
                level_data = structure_levels[level_key]
                level_features = self._extract_level_features(level_data)
            else:
                level_features = [0.0] * 50  # 50 zero features for a missing level

            features.extend(level_features)

        return features[:250]  # Safety truncation; each level contributes exactly 50

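    # Per-level feature layout (50 floats), as built in _extract_level_features:
    #   [0:10]   trend one-hot (up/down/sideways), strength, confidence,
    #            HH/HL/LH/LL counts, swing count
    #   [10:14]  current-bias one-hot (up/down/sideways/unknown)
    #   [14:34]  last 10 swings as (price, is_high) pairs, zero-padded
    #   [34:44]  up to 5 support levels, then up to 5 resistance levels, zero-padded
    #   [44:50]  last 3 structure breaks as (significance, is_high_break), zero-padded
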
    def _extract_level_features(self, level: MarketStructureLevel) -> List[float]:
        """Extract 50 features from a single structure level"""
        features = []

        # Trend features (10 features)
        features.extend([
            1.0 if level.trend_analysis.direction == TrendDirection.UP else 0.0,
            1.0 if level.trend_analysis.direction == TrendDirection.DOWN else 0.0,
            1.0 if level.trend_analysis.direction == TrendDirection.SIDEWAYS else 0.0,
            level.trend_analysis.strength,
            level.trend_analysis.confidence,
            level.trend_analysis.higher_highs,
            level.trend_analysis.higher_lows,
            level.trend_analysis.lower_highs,
            level.trend_analysis.lower_lows,
            len(level.swing_points)
        ])

        # Current bias features (4 features)
        features.extend([
            1.0 if level.current_bias == TrendDirection.UP else 0.0,
            1.0 if level.current_bias == TrendDirection.DOWN else 0.0,
            1.0 if level.current_bias == TrendDirection.SIDEWAYS else 0.0,
            1.0 if level.current_bias == TrendDirection.UNKNOWN else 0.0
        ])

        # Swing point features (20 features - last 10 swings * 2 features each)
        # Slicing always yields a fresh list, so level.swing_points is never mutated
        recent_swings = level.swing_points[-10:]
        for swing in recent_swings:
            features.extend([
                swing.price,
                1.0 if swing.swing_type == SwingType.SWING_HIGH else 0.0
            ])

        # Zero-pad if fewer than 10 swings
        for _ in range(10 - len(recent_swings)):
            features.extend([0.0, 0.0])

        # Support/resistance levels (10 features - 5 support + 5 resistance)
        # Copy before padding so the level's own lists are not mutated
        support_levels = list(level.support_levels[:5])
        support_levels.extend([0.0] * (5 - len(support_levels)))
        features.extend(support_levels)

        resistance_levels = list(level.resistance_levels[:5])
        resistance_levels.extend([0.0] * (5 - len(resistance_levels)))
        features.extend(resistance_levels)

        # Structure break features (6 features - last 3 breaks * 2 features each)
        recent_breaks = level.structure_breaks[-3:]
        for break_info in recent_breaks:
            features.extend([
                break_info.get('significance', 0.0),
                1.0 if break_info.get('type', '').endswith('_high') else 0.0
            ])

        # Zero-pad if fewer than 3 breaks (again without mutating level data)
        for _ in range(3 - len(recent_breaks)):
            features.extend([0.0, 0.0])

        return features[:50]  # Exactly 50 features per level
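

if __name__ == "__main__":
    # Hedged usage sketch, not part of the module proper. The analyzer class
    # and its structure-calculation entry point are not named in this excerpt,
    # so "WilliamsMarketStructure" and "calculate_recursive_structure" below
    # are assumed placeholders; only extract_features_for_rl comes from the
    # code above. The OHLCV column layout does match what this module indexes:
    # [timestamp, open, high, low, close, volume].
    rng = np.random.default_rng(42)
    n_bars = 500
    closes = 100.0 + np.cumsum(rng.normal(0.0, 0.5, n_bars))
    highs = closes + rng.uniform(0.0, 0.5, n_bars)
    lows = closes - rng.uniform(0.0, 0.5, n_bars)
    opens = np.roll(closes, 1)
    opens[0] = closes[0]
    timestamps = 1_700_000_000 + np.arange(n_bars) * 60.0  # 1-minute bars
    volumes = rng.uniform(10.0, 100.0, n_bars)
    ohlcv = np.column_stack([timestamps, opens, highs, lows, closes, volumes])

    # analyzer = WilliamsMarketStructure(config)               # assumed name
    # levels = analyzer.calculate_recursive_structure(ohlcv)   # assumed entry point
    # rl_features = analyzer.extract_features_for_rl(levels)   # 250 floats
    print(f"Synthetic OHLCV shape: {ohlcv.shape}")  # (500, 6)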