gogo2/training/williams_market_structure.py
2025-05-28 23:42:06 +03:00

640 lines
25 KiB
Python

"""
Williams Market Structure Implementation for RL Training
This module implements Larry Williams market structure analysis methodology for
RL training enhancement with:
- Swing high/low detection with configurable strength
- 5 levels of recursive pivot point calculation
- Trend analysis (higher highs/lows vs lower highs/lows)
- Market bias determination across multiple timeframes
- Feature extraction for RL training (250 features)
Based on Larry Williams' teachings on market structure:
- Markets move in swings between support and resistance
- Higher timeframe structure determines lower timeframe bias
- Recursive analysis reveals fractal patterns
- Trend direction determined by swing point relationships
"""
import numpy as np
import pandas as pd
import logging
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
from enum import Enum
logger = logging.getLogger(__name__)
class TrendDirection(Enum):
    """Trend / bias labels produced by the market structure analysis."""
    UP = "up"
    DOWN = "down"
    SIDEWAYS = "sideways"  # chosen when neither the up nor the down pattern is detected
    UNKNOWN = "unknown"  # used when there are too few swings to analyse
class SwingType(Enum):
    """Kind of pivot a SwingPoint represents."""
    SWING_HIGH = "swing_high"  # bar high strictly above the highs of its neighbours
    SWING_LOW = "swing_low"  # bar low strictly below the lows of its neighbours
@dataclass
class SwingPoint:
    """Represents a swing high or low point"""
    timestamp: datetime  # bar time; the detector substitutes datetime.now() for small raw values
    price: float  # bar high for a swing high, bar low for a swing low
    index: int  # row position of the bar inside the array that was scanned
    swing_type: SwingType
    strength: int  # Number of bars on each side that confirm the swing
    volume: float = 0.0  # bar volume; stays 0.0 when the input has no volume column
@dataclass
class TrendAnalysis:
    """Trend analysis results derived from a sequence of swing points"""
    direction: TrendDirection
    strength: float  # 0.0 to 1.0
    confidence: float  # 0.0 to 1.0
    swing_count: int  # number of swing points the analysis was based on
    last_swing_high: Optional[SwingPoint]  # most recent swing high, None if there is none
    last_swing_low: Optional[SwingPoint]  # most recent swing low, None if there is none
    higher_highs: int  # consecutive swing-high pairs where the later high is higher
    higher_lows: int  # consecutive swing-low pairs where the later low is higher
    lower_highs: int  # consecutive swing-high pairs where the later high is lower
    lower_lows: int  # consecutive swing-low pairs where the later low is lower
@dataclass
class MarketStructureLevel:
    """Market structure analysis for one recursive level"""
    level: int  # 0 for the raw bars; higher levels are built from the previous level's swings
    swing_points: List[SwingPoint]
    trend_analysis: TrendAnalysis
    support_levels: List[float]  # clustered swing-low prices
    resistance_levels: List[float]  # clustered swing-high prices
    current_bias: TrendDirection
    # One dict per detected break, with keys 'type', 'timestamp', 'price',
    # 'previous_high' or 'previous_low', and 'significance'.
    structure_breaks: List[Dict[str, Any]]
class WilliamsMarketStructure:
"""
Implementation of Larry Williams market structure methodology
Features:
- Multi-strength swing detection (2, 3, 5, 8, 13 bar strengths)
- 5 levels of recursive analysis
- Trend direction determination
- Support/resistance level identification
- Market bias calculation
- Structure break detection
"""
def __init__(self, swing_strengths: List[int] = None):
    """
    Set up the analyzer with its swing strengths and fixed limits.

    Args:
        swing_strengths: Number of confirming bars required on each side of a
            pivot, one entry per detection pass. When omitted (or empty) the
            defaults [2, 3, 5, 8, 13] are used.
    """
    default_strengths = [2, 3, 5, 8, 13]  # Fibonacci-based strengths
    self.swing_strengths = swing_strengths if swing_strengths else default_strengths
    self.max_levels = 5
    self.min_swings_for_trend = 3

    # Cache for performance (not read by any method shown in this file).
    self.swing_cache = {}
    self.trend_cache = {}

    logger.info(f"Williams Market Structure initialized with strengths: {self.swing_strengths}")
def calculate_recursive_pivot_points(self, ohlcv_data: np.ndarray) -> Dict[str, MarketStructureLevel]:
    """
    Build up to ``self.max_levels`` recursive levels of market structure.

    Level 0 is computed from the supplied bars. Each further level is computed
    from synthetic bars built out of the previous level's swing points.

    Args:
        ohlcv_data: OHLCV data array with columns
            [timestamp, open, high, low, close, volume]

    Returns:
        Dict keyed 'level_0' .. 'level_{max_levels - 1}'. Levels that could not
        be computed are filled with empty placeholders.
    """
    if len(ohlcv_data) < 20:
        logger.warning("Insufficient data for Williams structure analysis")
        return self._create_empty_structure()

    levels = {}
    working_data = ohlcv_data.copy()

    for depth in range(self.max_levels):
        logger.debug(f"Analyzing level {depth} with {len(working_data)} data points")

        pivots = self._find_swing_points_multi_strength(working_data)
        if len(pivots) < self.min_swings_for_trend:
            logger.debug(f"Not enough swings at level {depth}: {len(pivots)}")
            # This level and all deeper ones are filled by the loop below.
            break

        trend = self._analyze_trend_from_swings(pivots)
        supports, resistances = self._find_support_resistance(pivots, working_data)
        bias = self._determine_market_bias(pivots, trend)
        breaks = self._detect_structure_breaks(pivots, working_data)

        levels[f'level_{depth}'] = MarketStructureLevel(
            level=depth,
            swing_points=pivots,
            trend_analysis=trend,
            support_levels=supports,
            resistance_levels=resistances,
            current_bias=bias,
            structure_breaks=breaks
        )

        # The next level is analysed on bars synthesised from this level's swings.
        if len(pivots) < 5:
            logger.debug(f"Not enough swings to continue to level {depth + 1}")
            break
        working_data = self._convert_swings_to_ohlcv(pivots)
        if len(working_data) < 10:
            logger.debug(f"Insufficient converted data for level {depth + 1}")
            break

    # Every level that was not reached gets an empty placeholder.
    for missing in range(len(levels), self.max_levels):
        levels[f'level_{missing}'] = self._create_empty_level(missing)

    return levels
def _find_swing_points_multi_strength(self, ohlcv_data: np.ndarray) -> List[SwingPoint]:
"""Find swing points using multiple strength criteria"""
all_swings = []
for strength in self.swing_strengths:
swings = self._find_swing_points_single_strength(ohlcv_data, strength)
for swing in swings:
# Avoid duplicates (swings at same index)
if not any(existing.index == swing.index for existing in all_swings):
all_swings.append(swing)
# Sort by timestamp/index
all_swings.sort(key=lambda x: x.index)
# Filter to get the most significant swings
return self._filter_significant_swings(all_swings)
def _find_swing_points_single_strength(self, ohlcv_data: np.ndarray, strength: int) -> List[SwingPoint]:
"""Find swing points with specific strength requirement"""
swings = []
if len(ohlcv_data) < (strength * 2 + 1):
return swings
for i in range(strength, len(ohlcv_data) - strength):
current_high = ohlcv_data[i, 2] # High price
current_low = ohlcv_data[i, 3] # Low price
current_volume = ohlcv_data[i, 5] if ohlcv_data.shape[1] > 5 else 0.0
# Check for swing high (higher than surrounding bars)
is_swing_high = True
for j in range(i - strength, i + strength + 1):
if j != i and ohlcv_data[j, 2] >= current_high:
is_swing_high = False
break
if is_swing_high:
swings.append(SwingPoint(
timestamp=datetime.fromtimestamp(ohlcv_data[i, 0]) if ohlcv_data[i, 0] > 1e9 else datetime.now(),
price=current_high,
index=i,
swing_type=SwingType.SWING_HIGH,
strength=strength,
volume=current_volume
))
# Check for swing low (lower than surrounding bars)
is_swing_low = True
for j in range(i - strength, i + strength + 1):
if j != i and ohlcv_data[j, 3] <= current_low:
is_swing_low = False
break
if is_swing_low:
swings.append(SwingPoint(
timestamp=datetime.fromtimestamp(ohlcv_data[i, 0]) if ohlcv_data[i, 0] > 1e9 else datetime.now(),
price=current_low,
index=i,
swing_type=SwingType.SWING_LOW,
strength=strength,
volume=current_volume
))
return swings
def _filter_significant_swings(self, swings: List[SwingPoint]) -> List[SwingPoint]:
"""Filter to keep only the most significant swings"""
if len(swings) <= 20:
return swings
# Sort by strength (higher strength = more significant)
swings_by_strength = sorted(swings, key=lambda x: x.strength, reverse=True)
# Take top swings but ensure we have alternating highs and lows
significant_swings = []
last_type = None
for swing in swings_by_strength:
if len(significant_swings) >= 20:
break
# Prefer alternating swing types for better structure
if last_type is None or swing.swing_type != last_type:
significant_swings.append(swing)
last_type = swing.swing_type
elif len(significant_swings) < 10: # Still add if we need more swings
significant_swings.append(swing)
# Sort by index again
significant_swings.sort(key=lambda x: x.index)
return significant_swings
def _analyze_trend_from_swings(self, swing_points: List[SwingPoint]) -> TrendAnalysis:
    """
    Analyze trend direction from swing points.

    Swing highs and lows are compared pairwise in sequence to count higher
    highs/lows and lower highs/lows. The direction is:

    - UP when there is at least one higher high and one higher low,
    - DOWN when there is at least one lower high and one lower low,
    - decided by the dominant side when both patterns are present (a tie is
      SIDEWAYS), so a downtrend with a single retracement is no longer
      reported as UP,
    - SIDEWAYS otherwise.

    For UP/DOWN, strength and confidence are both the share of moves that
    agree with the direction. For SIDEWAYS, strength is fixed at 0.5 and
    confidence measures how balanced up and down moves are.

    Args:
        swing_points: Swing points in chronological order.

    Returns:
        TrendAnalysis; direction is UNKNOWN when fewer than two swings are given.
    """
    if len(swing_points) < 2:
        return TrendAnalysis(
            direction=TrendDirection.UNKNOWN,
            strength=0.0,
            confidence=0.0,
            swing_count=0,
            last_swing_high=None,
            last_swing_low=None,
            higher_highs=0,
            higher_lows=0,
            lower_highs=0,
            lower_lows=0
        )

    # Separate highs and lows
    highs = [s for s in swing_points if s.swing_type == SwingType.SWING_HIGH]
    lows = [s for s in swing_points if s.swing_type == SwingType.SWING_LOW]

    # Count higher highs, higher lows, lower highs, lower lows
    higher_highs = self._count_higher_highs(highs)
    higher_lows = self._count_higher_lows(lows)
    lower_highs = self._count_lower_highs(highs)
    lower_lows = self._count_lower_lows(lows)

    up_moves = higher_highs + higher_lows
    down_moves = lower_highs + lower_lows
    total_moves = up_moves + down_moves

    bullish = higher_highs > 0 and higher_lows > 0
    bearish = lower_highs > 0 and lower_lows > 0
    if bullish and bearish:
        # Both patterns occur somewhere in the sequence: let the majority decide
        # instead of always favouring UP.
        if up_moves > down_moves:
            direction = TrendDirection.UP
        elif down_moves > up_moves:
            direction = TrendDirection.DOWN
        else:
            direction = TrendDirection.SIDEWAYS
    elif bullish:
        direction = TrendDirection.UP
    elif bearish:
        direction = TrendDirection.DOWN
    else:
        direction = TrendDirection.SIDEWAYS

    if direction == TrendDirection.UP:
        # total_moves >= 2 here because both up counters are positive.
        strength = confidence = up_moves / total_moves
    elif direction == TrendDirection.DOWN:
        strength = confidence = down_moves / total_moves
    else:
        strength = 0.5  # Neutral for sideways
        # For sideways, confidence is based on how balanced it is
        if total_moves > 0:
            confidence = 1.0 - abs(up_moves - down_moves) / total_moves
        else:
            confidence = 0.0

    return TrendAnalysis(
        direction=direction,
        strength=min(strength, 1.0),
        confidence=min(confidence, 1.0),
        swing_count=len(swing_points),
        last_swing_high=highs[-1] if highs else None,
        last_swing_low=lows[-1] if lows else None,
        higher_highs=higher_highs,
        higher_lows=higher_lows,
        lower_highs=lower_highs,
        lower_lows=lower_lows
    )
def _count_higher_highs(self, highs: List[SwingPoint]) -> int:
"""Count higher highs in sequence"""
if len(highs) < 2:
return 0
count = 0
for i in range(1, len(highs)):
if highs[i].price > highs[i-1].price:
count += 1
return count
def _count_higher_lows(self, lows: List[SwingPoint]) -> int:
"""Count higher lows in sequence"""
if len(lows) < 2:
return 0
count = 0
for i in range(1, len(lows)):
if lows[i].price > lows[i-1].price:
count += 1
return count
def _count_lower_highs(self, highs: List[SwingPoint]) -> int:
"""Count lower highs in sequence"""
if len(highs) < 2:
return 0
count = 0
for i in range(1, len(highs)):
if highs[i].price < highs[i-1].price:
count += 1
return count
def _count_lower_lows(self, lows: List[SwingPoint]) -> int:
"""Count lower lows in sequence"""
if len(lows) < 2:
return 0
count = 0
for i in range(1, len(lows)):
if lows[i].price < lows[i-1].price:
count += 1
return count
def _find_support_resistance(self, swing_points: List[SwingPoint],
ohlcv_data: np.ndarray) -> Tuple[List[float], List[float]]:
"""Find support and resistance levels from swing points"""
highs = [s.price for s in swing_points if s.swing_type == SwingType.SWING_HIGH]
lows = [s.price for s in swing_points if s.swing_type == SwingType.SWING_LOW]
# Cluster similar levels
support_levels = self._cluster_price_levels(lows) if lows else []
resistance_levels = self._cluster_price_levels(highs) if highs else []
return support_levels, resistance_levels
def _cluster_price_levels(self, prices: List[float], tolerance: float = 0.02) -> List[float]:
"""Cluster similar price levels together"""
if not prices:
return []
sorted_prices = sorted(prices)
clusters = []
current_cluster = [sorted_prices[0]]
for price in sorted_prices[1:]:
# If price is within tolerance of cluster average, add to cluster
cluster_avg = np.mean(current_cluster)
if abs(price - cluster_avg) / cluster_avg <= tolerance:
current_cluster.append(price)
else:
# Start new cluster
clusters.append(np.mean(current_cluster))
current_cluster = [price]
# Add last cluster
if current_cluster:
clusters.append(np.mean(current_cluster))
return clusters
def _determine_market_bias(self, swing_points: List[SwingPoint],
                           trend_analysis: TrendAnalysis) -> TrendDirection:
    """
    Determine current market bias.

    The trend direction is used directly when its confidence exceeds 0.6.
    Otherwise the bias comes from the relative price change between the first
    and last of the most recent (up to six) swings, with a 1% threshold in
    either direction.

    Returns:
        UNKNOWN when there are no swings, or only one swing and a
        low-confidence trend; otherwise UP, DOWN or SIDEWAYS.
    """
    if not swing_points:
        return TrendDirection.UNKNOWN

    # Use trend analysis as primary indicator
    if trend_analysis.confidence > 0.6:
        return trend_analysis.direction

    # Look at most recent swings for bias (the whole list when it is shorter).
    window = swing_points[-6:]
    if len(window) < 2:
        return TrendDirection.UNKNOWN

    start_price = window[0].price
    end_price = window[-1].price
    relative_change = (end_price - start_price) / start_price

    if relative_change > 0.01:  # 1% threshold
        return TrendDirection.UP
    if relative_change < -0.01:
        return TrendDirection.DOWN
    return TrendDirection.SIDEWAYS
def _detect_structure_breaks(self, swing_points: List[SwingPoint],
ohlcv_data: np.ndarray) -> List[Dict[str, Any]]:
"""Detect structure breaks (trend changes)"""
structure_breaks = []
if len(swing_points) < 4:
return structure_breaks
# Look for pattern breaks
highs = [s for s in swing_points if s.swing_type == SwingType.SWING_HIGH]
lows = [s for s in swing_points if s.swing_type == SwingType.SWING_LOW]
# Check for break of structure in highs (lower high after higher highs)
if len(highs) >= 3:
for i in range(2, len(highs)):
if (highs[i-2].price < highs[i-1].price and # Previous was higher high
highs[i-1].price > highs[i].price): # Current is lower high
structure_breaks.append({
'type': 'break_of_structure_high',
'timestamp': highs[i].timestamp,
'price': highs[i].price,
'previous_high': highs[i-1].price,
'significance': abs(highs[i].price - highs[i-1].price) / highs[i-1].price
})
# Check for break of structure in lows (higher low after lower lows)
if len(lows) >= 3:
for i in range(2, len(lows)):
if (lows[i-2].price > lows[i-1].price and # Previous was lower low
lows[i-1].price < lows[i].price): # Current is higher low
structure_breaks.append({
'type': 'break_of_structure_low',
'timestamp': lows[i].timestamp,
'price': lows[i].price,
'previous_low': lows[i-1].price,
'significance': abs(lows[i].price - lows[i-1].price) / lows[i-1].price
})
return structure_breaks
def _convert_swings_to_ohlcv(self, swing_points: List[SwingPoint]) -> np.ndarray:
"""Convert swing points to OHLCV format for next level analysis"""
if len(swing_points) < 2:
return np.array([])
ohlcv_data = []
for i in range(len(swing_points) - 1):
current_swing = swing_points[i]
next_swing = swing_points[i + 1]
# Create synthetic OHLCV bar from swing to swing
if current_swing.swing_type == SwingType.SWING_HIGH:
# From high to next point
open_price = current_swing.price
high_price = current_swing.price
low_price = min(current_swing.price, next_swing.price)
close_price = next_swing.price
else:
# From low to next point
open_price = current_swing.price
high_price = max(current_swing.price, next_swing.price)
low_price = current_swing.price
close_price = next_swing.price
ohlcv_data.append([
current_swing.timestamp.timestamp(),
open_price,
high_price,
low_price,
close_price,
current_swing.volume
])
return np.array(ohlcv_data)
def _create_empty_structure(self) -> Dict[str, MarketStructureLevel]:
"""Create empty structure when insufficient data"""
return {f'level_{i}': self._create_empty_level(i) for i in range(self.max_levels)}
def _create_empty_level(self, level: int) -> MarketStructureLevel:
    """
    Build the placeholder used for a level that could not be analysed.

    The placeholder has no swings, no support/resistance, no structure breaks,
    an UNKNOWN bias and a zeroed UNKNOWN trend.
    """
    no_trend = TrendAnalysis(
        direction=TrendDirection.UNKNOWN,
        strength=0.0,
        confidence=0.0,
        swing_count=0,
        last_swing_high=None,
        last_swing_low=None,
        higher_highs=0,
        higher_lows=0,
        lower_highs=0,
        lower_lows=0
    )
    return MarketStructureLevel(
        level=level,
        swing_points=[],
        trend_analysis=no_trend,
        support_levels=[],
        resistance_levels=[],
        current_bias=TrendDirection.UNKNOWN,
        structure_breaks=[]
    )
def extract_features_for_rl(self, structure_levels: Dict[str, MarketStructureLevel]) -> List[float]:
    """
    Flatten the per-level structure into one feature vector for RL training.

    Levels are visited in order 'level_0' .. 'level_{max_levels - 1}'. Each
    contributes the output of ``_extract_level_features``, or 50 zeros when the
    key is absent. The concatenation is truncated to at most 250 values.

    Returns:
        Flat feature list (250 values with the default five levels).
    """
    zero_block = [0.0] * 50  # 50 features per level
    collected = []
    for depth in range(self.max_levels):
        key = f'level_{depth}'
        if key in structure_levels:
            collected.extend(self._extract_level_features(structure_levels[key]))
        else:
            collected.extend(zero_block)
    return collected[:250]
def _extract_level_features(self, level: MarketStructureLevel) -> List[float]:
    """
    Extract the 50 features of a single structure level.

    Layout:
    - 10 trend features: one-hot direction (UP, DOWN, SIDEWAYS), strength,
      confidence, the four higher/lower counters, number of swing points
    - 4 bias features: one-hot current bias (UP, DOWN, SIDEWAYS, UNKNOWN)
    - 20 swing features: (price, is_swing_high) for the last 10 swings
    - 10 level features: first 5 support levels, first 5 resistance levels
    - 6 break features: (significance, is_high_break) for the last 3 breaks

    Missing swings, levels and breaks are zero-padded. The padding is computed
    on copies: ``level`` is never modified, so the same level can be passed in
    repeatedly.

    Args:
        level: Structure level to encode.

    Returns:
        List of exactly 50 values.
    """
    trend = level.trend_analysis

    # Trend features (10 features)
    features = [
        1.0 if trend.direction == TrendDirection.UP else 0.0,
        1.0 if trend.direction == TrendDirection.DOWN else 0.0,
        1.0 if trend.direction == TrendDirection.SIDEWAYS else 0.0,
        trend.strength,
        trend.confidence,
        trend.higher_highs,
        trend.higher_lows,
        trend.lower_highs,
        trend.lower_lows,
        len(level.swing_points)
    ]

    # Current bias features (4 features)
    features.extend([
        1.0 if level.current_bias == TrendDirection.UP else 0.0,
        1.0 if level.current_bias == TrendDirection.DOWN else 0.0,
        1.0 if level.current_bias == TrendDirection.SIDEWAYS else 0.0,
        1.0 if level.current_bias == TrendDirection.UNKNOWN else 0.0
    ])

    # Swing point features (20 features - last 10 swings * 2 features each).
    # Slicing always yields a new list, even when fewer than 10 swings exist.
    recent_swings = level.swing_points[-10:]
    for swing in recent_swings:
        features.extend([
            swing.price,
            1.0 if swing.swing_type == SwingType.SWING_HIGH else 0.0
        ])
    features.extend([0.0] * (2 * (10 - len(recent_swings))))

    # Support/resistance levels (10 features - 5 support + 5 resistance)
    supports = level.support_levels[:5]
    features.extend(supports)
    features.extend([0.0] * (5 - len(supports)))

    resistances = level.resistance_levels[:5]
    features.extend(resistances)
    features.extend([0.0] * (5 - len(resistances)))

    # Structure break features (6 features - last 3 breaks * 2 features each)
    recent_breaks = level.structure_breaks[-3:]
    for break_info in recent_breaks:
        features.extend([
            break_info.get('significance', 0.0),
            1.0 if break_info.get('type', '').endswith('_high') else 0.0
        ])
    features.extend([0.0] * (2 * (3 - len(recent_breaks))))

    return features[:50]  # Ensure exactly 50 features per level