# gogo2/training/williams_market_structure.py
"""
Williams Market Structure Implementation for RL Training
This module implements Larry Williams market structure analysis methodology for
RL training enhancement with:
**SINGLE TIMEFRAME RECURSIVE APPROACH:**
- Level 0: 1s OHLCV data → swing points using configurable strength [2, 3, 5]
- Level 1: Level 0 swing points treated as "price bars" → higher-level swing points
- Level 2: Level 1 swing points treated as "price bars" → even higher-level swing points
- Level 3: Level 2 swing points treated as "price bars" → top-level swing points
- Level 4: Level 3 swing points treated as "price bars" → highest-level swing points
**RECURSIVE METHODOLOGY:**
Each level uses the previous level's swing points as input "price data", where:
- Each swing point becomes a "price bar" with OHLC = swing point price
- Swing strength detection applied to find patterns in swing point sequences
- This creates fractal market structure analysis across 5 recursive levels
**NOT MULTI-TIMEFRAME:**
Williams structure uses ONLY 1s data and builds recursively.
Multi-timeframe data (1m, 1h) is used separately for CNN feature enhancement.
Based on Larry Williams' teachings on market structure:
- Markets move in swings between support and resistance
- Higher recursive levels reveal longer-term structure patterns
- Recursive analysis reveals fractal patterns within market movements
- Trend direction determined by swing point relationships across levels
"""
import numpy as np
import pandas as pd
import logging
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
from enum import Enum
try:
    from NN.models.cnn_model import CNNModel
except ImportError:
    CNNModel = None  # Allow running without TF/CNN if not installed or path issue
    print("Warning: CNNModel could not be imported. CNN-based pivot prediction/training will be disabled.")
try:
    from core.unified_data_stream import TrainingDataPacket
except ImportError:
    TrainingDataPacket = None
    print("Warning: TrainingDataPacket could not be imported. Using fallback interface.")
logger = logging.getLogger(__name__)


class TrendDirection(Enum):
    UP = "up"
    DOWN = "down"
    SIDEWAYS = "sideways"
    UNKNOWN = "unknown"


class SwingType(Enum):
    SWING_HIGH = "swing_high"
    SWING_LOW = "swing_low"


@dataclass
class SwingPoint:
    """Represents a swing high or low point"""
    timestamp: datetime
    price: float
    index: int
    swing_type: SwingType
    strength: int  # Number of bars on each side that confirm the swing
    volume: float = 0.0


@dataclass
class TrendAnalysis:
    """Trend analysis results"""
    direction: TrendDirection
    strength: float  # 0.0 to 1.0
    confidence: float  # 0.0 to 1.0
    swing_count: int
    last_swing_high: Optional[SwingPoint]
    last_swing_low: Optional[SwingPoint]
    higher_highs: int
    higher_lows: int
    lower_highs: int
    lower_lows: int


@dataclass
class MarketStructureLevel:
    """Market structure analysis for one recursive level"""
    level: int
    swing_points: List[SwingPoint]
    trend_analysis: TrendAnalysis
    support_levels: List[float]
    resistance_levels: List[float]
    current_bias: TrendDirection
    structure_breaks: List[Dict[str, Any]]


class WilliamsMarketStructure:
    """
    Implementation of Larry Williams market structure methodology

    Features:
    - Configurable multi-strength swing detection (default strengths: 2, 3, 5 bars)
    - 5 levels of recursive analysis
    - Trend direction determination
    - Support/resistance level identification
    - Market bias calculation
    - Structure break detection
    """
    def __init__(self,
                 swing_strengths: List[int] = None,
                 cnn_input_shape: Optional[Tuple[int, int]] = (900, 50),  # 900 timesteps (1s), 50 features
                 cnn_output_size: Optional[int] = 10,  # 5 levels * (type + price) = 10 outputs
                 cnn_model_config: Optional[Dict[str, Any]] = None,  # For build_model params like filters, learning_rate
                 cnn_model_path: Optional[str] = None,
                 enable_cnn_feature: bool = True,  # Master switch for this feature
                 training_data_provider: Optional[Any] = None):  # Provider for TrainingDataPacket access
        """
        Initialize Williams market structure analyzer

        Args:
            swing_strengths: List of swing detection strengths (bars on each side)
            cnn_input_shape: Shape of input data for CNN (sequence_length, features)
            cnn_output_size: Number of output classes for CNN (10 for 5 levels * 2 outputs each)
            cnn_model_config: Dictionary with parameters for CNNModel.build_model()
            cnn_model_path: Path to a pre-trained Keras CNN model (.h5 file)
            enable_cnn_feature: If True, enables CNN prediction and training at pivots.
            training_data_provider: Provider/stream for accessing TrainingDataPacket
        """
        self.swing_strengths = swing_strengths or [2, 3, 5]  # Simplified strengths for better performance
        self.max_levels = 5
        self.min_swings_for_trend = 3

        # Caches for performance
        self.swing_cache = {}
        self.trend_cache = {}

        self.enable_cnn_feature = enable_cnn_feature and CNNModel is not None
        self.cnn_model: Optional[CNNModel] = None
        self.previous_pivot_details_for_cnn: Optional[Dict[str, Any]] = None  # Stores {'features': X, 'pivot': SwingPoint}
        self.training_data_provider = training_data_provider  # Access to TrainingDataPacket

        if self.enable_cnn_feature:
            try:
                logger.info(f"Initializing CNN for multi-timeframe pivot prediction. Input: {cnn_input_shape}, Output: {cnn_output_size}")
                logger.info("CNN will predict next pivot (type + price) for all 5 Williams levels")
                self.cnn_model = CNNModel(input_shape=cnn_input_shape, output_size=cnn_output_size)
                if cnn_model_path:
                    logger.info(f"Loading pre-trained CNN model from: {cnn_model_path}")
                    self.cnn_model.load(cnn_model_path)
                else:
                    logger.info("Building new CNN model.")
                    # Use provided config or defaults for build_model
                    build_params = cnn_model_config or {}
                    self.cnn_model.build_model(**build_params)
                logger.info("CNN Model initialized successfully.")
            except Exception as e:
                logger.error(f"Failed to initialize or load CNN model: {e}. Disabling CNN feature.", exc_info=True)
                self.enable_cnn_feature = False
        else:
            logger.info("CNN feature for pivot prediction/training is disabled.")

        logger.info(f"Williams Market Structure initialized with strengths: {self.swing_strengths}")
    def calculate_recursive_pivot_points(self, ohlcv_data: np.ndarray) -> Dict[str, MarketStructureLevel]:
        """
        Calculate 5 levels of recursive pivot points using SINGLE TIMEFRAME (1s) data

        **RECURSIVE STRUCTURE:**
        - Level 0: Raw 1s OHLCV data → swing points (strength 2, 3, 5)
        - Level 1: Level 0 swing points → treated as "price bars" → higher-level swing points
        - Level 2: Level 1 swing points → treated as "price bars" → even higher-level swing points
        - Level 3: Level 2 swing points → treated as "price bars" → top-level swing points
        - Level 4: Level 3 swing points → treated as "price bars" → highest-level swing points

        **HOW RECURSION WORKS:**
        1. Start with 1s OHLCV data (timestamp, open, high, low, close, volume)
        2. Find Level 0 swing points using configurable strengths [2, 3, 5]
        3. Convert Level 0 swing points to "price bar" format where OHLC = swing price
        4. Apply swing detection to these "price bars" to find Level 1 swing points
        5. Repeat the process: Level N swing points → "price bars" → Level N+1 swing points

        This creates a fractal analysis where each level reveals longer-term structure patterns
        within the same 1s timeframe data, NOT across different timeframes.

        Args:
            ohlcv_data: 1s OHLCV data array [timestamp, open, high, low, close, volume]

        Returns:
            Dict of 5 market structure levels with recursive swing points and analysis
        """
        if len(ohlcv_data) < 20:
            logger.warning("Insufficient data for Williams structure analysis")
            return self._create_empty_structure()

        levels = {}
        current_price_points = ohlcv_data.copy()  # Start with raw 1s OHLCV data

        for level in range(self.max_levels):
            logger.debug(f"Analyzing level {level} with {len(current_price_points)} data points")

            if level == 0:
                # Level 0: Calculate swing points from raw 1s OHLCV data
                swing_points = self._find_swing_points_multi_strength(current_price_points)
            else:
                # Levels 1-4: Calculate swing points from the previous level's swing points,
                # which are treated as "price bars"
                swing_points = self._find_pivot_points_from_pivot_points(current_price_points, level)

            if len(swing_points) < self.min_swings_for_trend:
                logger.debug(f"Not enough swings at level {level}: {len(swing_points)}")
                # Fill remaining levels with empty data
                for remaining_level in range(level, self.max_levels):
                    levels[f'level_{remaining_level}'] = self._create_empty_level(remaining_level)
                break

            # Analyze trend for this level
            trend_analysis = self._analyze_trend_from_swings(swing_points)

            # Find support/resistance levels
            support_levels, resistance_levels = self._find_support_resistance(
                swing_points, current_price_points if level == 0 else None
            )

            # Determine current market bias
            current_bias = self._determine_market_bias(swing_points, trend_analysis)

            # Detect structure breaks
            structure_breaks = self._detect_structure_breaks(swing_points, current_price_points if level == 0 else None)

            # Create level data
            levels[f'level_{level}'] = MarketStructureLevel(
                level=level,
                swing_points=swing_points,
                trend_analysis=trend_analysis,
                support_levels=support_levels,
                resistance_levels=resistance_levels,
                current_bias=current_bias,
                structure_breaks=structure_breaks
            )

            # Prepare data for the next level: convert swing points to "price points"
            if len(swing_points) >= 5:
                current_price_points = self._convert_pivots_to_price_points(swing_points)
                if len(current_price_points) < 10:
                    logger.debug(f"Insufficient pivot data for level {level + 1}")
                    break
            else:
                logger.debug(f"Not enough swings to continue to level {level + 1}")
                break

        # Fill any remaining empty levels
        for remaining_level in range(len(levels), self.max_levels):
            levels[f'level_{remaining_level}'] = self._create_empty_level(remaining_level)

        return levels
    def _find_swing_points_multi_strength(self, ohlcv_data: np.ndarray) -> List[SwingPoint]:
        """Find swing points using multiple strength criteria"""
        all_swings = []

        for strength in self.swing_strengths:
            swings_at_strength = self._find_swing_points_single_strength(ohlcv_data, strength)
            for swing in swings_at_strength:
                # Avoid duplicates (swings at the same index)
                if not any(existing.index == swing.index for existing in all_swings):
                    all_swings.append(swing)

        # Sort by timestamp/index
        all_swings.sort(key=lambda x: x.index)

        # Filter to keep the most significant swings
        return self._filter_significant_swings(all_swings)
    def _find_swing_points_single_strength(self, ohlcv_data: np.ndarray, strength: int) -> List[SwingPoint]:
        """Find swing points with a specific strength requirement"""
        identified_swings_in_this_call = []  # Swings found in this specific call

        if len(ohlcv_data) < (strength * 2 + 1):
            return identified_swings_in_this_call

        for i in range(strength, len(ohlcv_data) - strength):
            current_high = ohlcv_data[i, 2]  # High price
            current_low = ohlcv_data[i, 3]   # Low price
            current_volume = ohlcv_data[i, 5] if ohlcv_data.shape[1] > 5 else 0.0

            # Check for swing high (higher than all surrounding bars)
            is_swing_high = True
            for j in range(i - strength, i + strength + 1):
                if j != i and ohlcv_data[j, 2] >= current_high:
                    is_swing_high = False
                    break

            if is_swing_high:
                new_pivot = SwingPoint(
                    timestamp=datetime.fromtimestamp(ohlcv_data[i, 0]) if ohlcv_data[i, 0] > 1e9 else datetime.now(),
                    price=current_high,
                    index=i,
                    swing_type=SwingType.SWING_HIGH,
                    strength=strength,
                    volume=current_volume
                )
                identified_swings_in_this_call.append(new_pivot)
                self._handle_cnn_at_pivot(new_pivot, ohlcv_data)  # CNN logic call

            # Check for swing low (lower than all surrounding bars)
            is_swing_low = True
            for j in range(i - strength, i + strength + 1):
                if j != i and ohlcv_data[j, 3] <= current_low:
                    is_swing_low = False
                    break

            if is_swing_low:
                new_pivot = SwingPoint(
                    timestamp=datetime.fromtimestamp(ohlcv_data[i, 0]) if ohlcv_data[i, 0] > 1e9 else datetime.now(),
                    price=current_low,
                    index=i,
                    swing_type=SwingType.SWING_LOW,
                    strength=strength,
                    volume=current_volume
                )
                identified_swings_in_this_call.append(new_pivot)
                self._handle_cnn_at_pivot(new_pivot, ohlcv_data)  # CNN logic call

        return identified_swings_in_this_call
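    # Illustration of the strength test above (not executed): with strength=2 and highs
    # [.., 10, 11, 13, 12, 11, ..], the bar with high 13 is a swing high because both bars
    # on each side have strictly lower highs; a tie (>=) rejects the candidate.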
    def _filter_significant_swings(self, swings: List[SwingPoint]) -> List[SwingPoint]:
        """Filter to keep only the most significant swings"""
        if len(swings) <= 20:
            return swings

        # Sort by strength (higher strength = more significant)
        swings_by_strength = sorted(swings, key=lambda x: x.strength, reverse=True)

        # Take top swings, preferring alternating highs and lows
        significant_swings = []
        last_type = None

        for swing in swings_by_strength:
            if len(significant_swings) >= 20:
                break
            # Prefer alternating swing types for better structure
            if last_type is None or swing.swing_type != last_type:
                significant_swings.append(swing)
                last_type = swing.swing_type
            elif len(significant_swings) < 10:  # Still add if we need more swings
                significant_swings.append(swing)

        # Sort by index again
        significant_swings.sort(key=lambda x: x.index)
        return significant_swings
    def _analyze_trend_from_swings(self, swing_points: List[SwingPoint]) -> TrendAnalysis:
        """Analyze trend direction from swing points"""
        if len(swing_points) < 2:
            return TrendAnalysis(
                direction=TrendDirection.UNKNOWN,
                strength=0.0,
                confidence=0.0,
                swing_count=0,
                last_swing_high=None,
                last_swing_low=None,
                higher_highs=0,
                higher_lows=0,
                lower_highs=0,
                lower_lows=0
            )

        # Separate highs and lows
        highs = [s for s in swing_points if s.swing_type == SwingType.SWING_HIGH]
        lows = [s for s in swing_points if s.swing_type == SwingType.SWING_LOW]

        # Count higher highs, higher lows, lower highs, lower lows
        higher_highs = self._count_higher_highs(highs)
        higher_lows = self._count_higher_lows(lows)
        lower_highs = self._count_lower_highs(highs)
        lower_lows = self._count_lower_lows(lows)

        # Determine trend direction
        if higher_highs > 0 and higher_lows > 0:
            direction = TrendDirection.UP
        elif lower_highs > 0 and lower_lows > 0:
            direction = TrendDirection.DOWN
        else:
            direction = TrendDirection.SIDEWAYS

        # Calculate trend strength
        total_moves = higher_highs + higher_lows + lower_highs + lower_lows
        if direction == TrendDirection.UP:
            strength = (higher_highs + higher_lows) / max(total_moves, 1)
        elif direction == TrendDirection.DOWN:
            strength = (lower_highs + lower_lows) / max(total_moves, 1)
        else:
            strength = 0.5  # Neutral for sideways

        # Calculate confidence based on consistency
        if total_moves > 0:
            if direction == TrendDirection.UP:
                confidence = (higher_highs + higher_lows) / total_moves
            elif direction == TrendDirection.DOWN:
                confidence = (lower_highs + lower_lows) / total_moves
            else:
                # For sideways, confidence is based on how balanced the moves are
                up_moves = higher_highs + higher_lows
                down_moves = lower_highs + lower_lows
                confidence = 1.0 - abs(up_moves - down_moves) / total_moves
        else:
            confidence = 0.0

        return TrendAnalysis(
            direction=direction,
            strength=min(strength, 1.0),
            confidence=min(confidence, 1.0),
            swing_count=len(swing_points),
            last_swing_high=highs[-1] if highs else None,
            last_swing_low=lows[-1] if lows else None,
            higher_highs=higher_highs,
            higher_lows=higher_lows,
            lower_highs=lower_highs,
            lower_lows=lower_lows
        )
    def _count_higher_highs(self, highs: List[SwingPoint]) -> int:
        """Count higher highs in sequence"""
        if len(highs) < 2:
            return 0
        count = 0
        for i in range(1, len(highs)):
            if highs[i].price > highs[i - 1].price:
                count += 1
        return count

    def _count_higher_lows(self, lows: List[SwingPoint]) -> int:
        """Count higher lows in sequence"""
        if len(lows) < 2:
            return 0
        count = 0
        for i in range(1, len(lows)):
            if lows[i].price > lows[i - 1].price:
                count += 1
        return count

    def _count_lower_highs(self, highs: List[SwingPoint]) -> int:
        """Count lower highs in sequence"""
        if len(highs) < 2:
            return 0
        count = 0
        for i in range(1, len(highs)):
            if highs[i].price < highs[i - 1].price:
                count += 1
        return count

    def _count_lower_lows(self, lows: List[SwingPoint]) -> int:
        """Count lower lows in sequence"""
        if len(lows) < 2:
            return 0
        count = 0
        for i in range(1, len(lows)):
            if lows[i].price < lows[i - 1].price:
                count += 1
        return count
    def _find_support_resistance(self, swing_points: List[SwingPoint],
                                 ohlcv_data: Optional[np.ndarray]) -> Tuple[List[float], List[float]]:
        """Find support and resistance levels from swing points (ohlcv_data is None above Level 0)"""
        highs = [s.price for s in swing_points if s.swing_type == SwingType.SWING_HIGH]
        lows = [s.price for s in swing_points if s.swing_type == SwingType.SWING_LOW]

        # Cluster similar levels
        support_levels = self._cluster_price_levels(lows) if lows else []
        resistance_levels = self._cluster_price_levels(highs) if highs else []

        return support_levels, resistance_levels
    def _cluster_price_levels(self, prices: List[float], tolerance: float = 0.02) -> List[float]:
        """Cluster similar price levels together"""
        if not prices:
            return []

        sorted_prices = sorted(prices)
        clusters = []
        current_cluster = [sorted_prices[0]]

        for price in sorted_prices[1:]:
            # If the price is within tolerance of the cluster average, add it to the cluster
            cluster_avg = np.mean(current_cluster)
            if abs(price - cluster_avg) / cluster_avg <= tolerance:
                current_cluster.append(price)
            else:
                # Start a new cluster
                clusters.append(np.mean(current_cluster))
                current_cluster = [price]

        # Add the last cluster
        if current_cluster:
            clusters.append(np.mean(current_cluster))

        return clusters
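    # Worked example for the 2% tolerance above (not executed): lows [3000.0, 3040.0, 3200.0]
    # yield clusters [3020.0, 3200.0] -- 3040 is within 2% of the running mean 3000, while
    # 3200 is ~5.9% away from the mean 3020 and starts a new cluster.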
    def _determine_market_bias(self, swing_points: List[SwingPoint],
                               trend_analysis: TrendAnalysis) -> TrendDirection:
        """Determine current market bias"""
        if not swing_points:
            return TrendDirection.UNKNOWN

        # Use trend analysis as the primary indicator
        if trend_analysis.confidence > 0.6:
            return trend_analysis.direction

        # Otherwise look at the most recent swings for bias
        recent_swings = swing_points[-6:] if len(swing_points) >= 6 else swing_points
        if len(recent_swings) >= 2:
            first_price = recent_swings[0].price
            last_price = recent_swings[-1].price
            price_change = (last_price - first_price) / first_price

            if price_change > 0.01:  # 1% threshold
                return TrendDirection.UP
            elif price_change < -0.01:
                return TrendDirection.DOWN
            else:
                return TrendDirection.SIDEWAYS

        return TrendDirection.UNKNOWN
    def _detect_structure_breaks(self, swing_points: List[SwingPoint],
                                 ohlcv_data: Optional[np.ndarray]) -> List[Dict[str, Any]]:
        """Detect structure breaks (trend changes); ohlcv_data is None above Level 0"""
        structure_breaks = []
        if len(swing_points) < 4:
            return structure_breaks

        # Look for pattern breaks
        highs = [s for s in swing_points if s.swing_type == SwingType.SWING_HIGH]
        lows = [s for s in swing_points if s.swing_type == SwingType.SWING_LOW]

        # Check for break of structure in highs (a lower high after a higher high)
        if len(highs) >= 3:
            for i in range(2, len(highs)):
                if (highs[i-2].price < highs[i-1].price and  # Previous was a higher high
                        highs[i-1].price > highs[i].price):  # Current is a lower high
                    structure_breaks.append({
                        'type': 'break_of_structure_high',
                        'timestamp': highs[i].timestamp,
                        'price': highs[i].price,
                        'previous_high': highs[i-1].price,
                        'significance': abs(highs[i].price - highs[i-1].price) / highs[i-1].price
                    })

        # Check for break of structure in lows (a higher low after a lower low)
        if len(lows) >= 3:
            for i in range(2, len(lows)):
                if (lows[i-2].price > lows[i-1].price and  # Previous was a lower low
                        lows[i-1].price < lows[i].price):  # Current is a higher low
                    structure_breaks.append({
                        'type': 'break_of_structure_low',
                        'timestamp': lows[i].timestamp,
                        'price': lows[i].price,
                        'previous_low': lows[i-1].price,
                        'significance': abs(lows[i].price - lows[i-1].price) / lows[i-1].price
                    })

        return structure_breaks
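    # Example of the high-side pattern above (not executed): swing highs 3000 -> 3100 -> 3050
    # register a 'break_of_structure_high' at 3050 with significance |3050 - 3100| / 3100 ~= 0.016.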
    def _find_pivot_points_from_pivot_points(self, pivot_array: np.ndarray, level: int) -> List[SwingPoint]:
        """
        Find swing points from the previous level's swing points (RECURSIVE APPROACH)

        **RECURSIVE SWING DETECTION:**
        For Level N (N > 0): a Level N swing high occurs when a Level N-1 swing point
        is higher than the surrounding Level N-1 swing points (and vice versa for lows).

        This is NOT multi-timeframe analysis - it's recursive fractal analysis where:
        - Level 1 finds patterns in Level 0 swing sequences (from 1s data)
        - Level 2 finds patterns in Level 1 swing sequences
        - Level 3 finds patterns in Level 2 swing sequences
        - Level 4 finds patterns in Level 3 swing sequences
        All based on the original 1s timeframe data, recursively analyzed.

        Args:
            pivot_array: Array of Level N-1 swing points formatted as "price bars"
                         [timestamp, price, price, price, price, 0]
            level: Current recursive level being calculated (1, 2, 3, or 4)
        """
        identified_swings_in_this_call = []
        if len(pivot_array) < 5:  # Coarse floor; the range() below yields nothing if there are still too few bars for this level's strength
            return identified_swings_in_this_call

        # Use a more conservative strength at higher levels
        strength = min(2 + level, 5)  # Bars on each side - Level 1: 3, Level 2: 4, Level 3+: 5

        for i in range(strength, len(pivot_array) - strength):
            current_price = pivot_array[i, 1]  # Price of this pivot point
            current_timestamp = pivot_array[i, 0]

            # Check for swing high (pivot surrounded by lower pivots)
            is_swing_high = True
            for j in range(i - strength, i + strength + 1):
                if j != i and pivot_array[j, 1] >= current_price:  # Compare with prices of other pivots
                    is_swing_high = False
                    break

            if is_swing_high:
                new_pivot = SwingPoint(
                    timestamp=datetime.fromtimestamp(current_timestamp) if current_timestamp > 1e9 else datetime.now(),
                    price=current_price,
                    index=i,
                    swing_type=SwingType.SWING_HIGH,
                    strength=strength,  # Strength is derived from the level: min(2 + level, 5)
                    volume=0.0  # Pivot points don't have volume
                )
                identified_swings_in_this_call.append(new_pivot)
                self._handle_cnn_at_pivot(new_pivot, pivot_array)  # CNN logic call

            # Check for swing low (pivot surrounded by higher pivots)
            is_swing_low = True
            for j in range(i - strength, i + strength + 1):
                if j != i and pivot_array[j, 1] <= current_price:  # Compare with prices of other pivots
                    is_swing_low = False
                    break

            if is_swing_low:
                new_pivot = SwingPoint(
                    timestamp=datetime.fromtimestamp(current_timestamp) if current_timestamp > 1e9 else datetime.now(),
                    price=current_price,
                    index=i,
                    swing_type=SwingType.SWING_LOW,
                    strength=strength,  # Strength is derived from the level
                    volume=0.0  # Pivot points don't have volume
                )
                identified_swings_in_this_call.append(new_pivot)
                self._handle_cnn_at_pivot(new_pivot, pivot_array)  # CNN logic call

        return identified_swings_in_this_call
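    # Recursive example (not executed): Level 0 swing prices [3000, 3050, 3020, 3080, 3010,
    # 3060, 2990] scanned with strength=3 mark 3080 (index 3) as a Level 1 swing high,
    # since the three pivot prices on each side are all lower.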
    def _convert_pivots_to_price_points(self, swing_points: List[SwingPoint]) -> np.ndarray:
        """
        Convert swing points to "price bar" format for the next recursive level

        **RECURSIVE CONVERSION PROCESS:**
        Each swing point from Level N becomes a "price bar" for the Level N+1 calculation:
        - Timestamp = swing point timestamp
        - Open = High = Low = Close = swing point price (since it's a single point)
        - Volume = 0 (not applicable for swing points)

        This allows Level N+1 to treat Level N swing points as if they were regular
        OHLCV price bars, enabling the same swing detection algorithm to find
        higher-level patterns in the swing point sequences.

        Example:
        - Level 0: 1000 x 1s bars → 50 swing points
        - Level 1: 50 "price bars" (from Level 0 swings) → 10 swing points
        - Level 2: 10 "price bars" (from Level 1 swings) → 3 swing points
        """
        if len(swing_points) < 2:
            return np.array([])

        price_points = []
        for swing in swing_points:
            # Each pivot point becomes a price point where OHLC = pivot price
            price_points.append([
                swing.timestamp.timestamp(),
                swing.price,  # Open = pivot price
                swing.price,  # High = pivot price
                swing.price,  # Low = pivot price
                swing.price,  # Close = pivot price
                0.0           # Volume = 0 (not applicable for pivot points)
            ])

        return np.array(price_points)
    def _create_empty_structure(self) -> Dict[str, MarketStructureLevel]:
        """Create an empty structure when there is insufficient data"""
        return {f'level_{i}': self._create_empty_level(i) for i in range(self.max_levels)}

    def _create_empty_level(self, level: int) -> MarketStructureLevel:
        """Create an empty market structure level"""
        return MarketStructureLevel(
            level=level,
            swing_points=[],
            trend_analysis=TrendAnalysis(
                direction=TrendDirection.UNKNOWN,
                strength=0.0,
                confidence=0.0,
                swing_count=0,
                last_swing_high=None,
                last_swing_low=None,
                higher_highs=0,
                higher_lows=0,
                lower_highs=0,
                lower_lows=0
            ),
            support_levels=[],
            resistance_levels=[],
            current_bias=TrendDirection.UNKNOWN,
            structure_breaks=[]
        )
    def extract_features_for_rl(self, structure_levels: Dict[str, MarketStructureLevel]) -> List[float]:
        """
        Extract features from the Williams structure for RL training

        Returns exactly 250 features: 50 features per level across 5 levels.
        """
        features = []

        for level in range(self.max_levels):
            level_key = f'level_{level}'
            if level_key in structure_levels:
                level_data = structure_levels[level_key]
                level_features = self._extract_level_features(level_data)
            else:
                level_features = [0.0] * 50  # 50 features per level
            features.extend(level_features)

        return features[:250]  # Ensure exactly 250 features
    def _extract_level_features(self, level: MarketStructureLevel) -> List[float]:
        """Extract 50 features from a single structure level"""
        features = []

        # Trend features (10 features)
        features.extend([
            1.0 if level.trend_analysis.direction == TrendDirection.UP else 0.0,
            1.0 if level.trend_analysis.direction == TrendDirection.DOWN else 0.0,
            1.0 if level.trend_analysis.direction == TrendDirection.SIDEWAYS else 0.0,
            level.trend_analysis.strength,
            level.trend_analysis.confidence,
            level.trend_analysis.higher_highs,
            level.trend_analysis.higher_lows,
            level.trend_analysis.lower_highs,
            level.trend_analysis.lower_lows,
            len(level.swing_points)
        ])

        # Current bias features (4 features)
        features.extend([
            1.0 if level.current_bias == TrendDirection.UP else 0.0,
            1.0 if level.current_bias == TrendDirection.DOWN else 0.0,
            1.0 if level.current_bias == TrendDirection.SIDEWAYS else 0.0,
            1.0 if level.current_bias == TrendDirection.UNKNOWN else 0.0
        ])

        # Swing point features (20 features - last 10 swings * 2 features each)
        recent_swings = level.swing_points[-10:]
        for swing in recent_swings:
            features.extend([
                swing.price,
                1.0 if swing.swing_type == SwingType.SWING_HIGH else 0.0
            ])
        # Pad if fewer than 10 swings (without mutating level.swing_points)
        for _ in range(10 - len(recent_swings)):
            features.extend([0.0, 0.0])

        # Support/resistance levels (10 features - 5 support + 5 resistance);
        # pad copies so the level's own lists are never mutated
        support_levels = list(level.support_levels[:5])
        while len(support_levels) < 5:
            support_levels.append(0.0)
        features.extend(support_levels)

        resistance_levels = list(level.resistance_levels[:5])
        while len(resistance_levels) < 5:
            resistance_levels.append(0.0)
        features.extend(resistance_levels)

        # Structure break features (6 features - last 3 breaks * 2 features each)
        recent_breaks = level.structure_breaks[-3:]
        for break_info in recent_breaks:
            features.extend([
                break_info.get('significance', 0.0),
                1.0 if break_info.get('type', '').endswith('_high') else 0.0
            ])
        # Pad if fewer than 3 breaks
        for _ in range(3 - len(recent_breaks)):
            features.extend([0.0, 0.0])

        return features[:50]  # Ensure exactly 50 features per level
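    # Per-level feature layout produced above: 10 trend + 4 bias + 20 swing (10 swings x
    # [price, is_high]) + 5 support + 5 resistance + 6 break (3 breaks x [significance,
    # is_high_break]) = 50 features, truncated/padded to exactly 50.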
    def _handle_cnn_at_pivot(self,
                             newly_identified_pivot: SwingPoint,
                             ohlcv_data_context: np.ndarray):
        """
        Handle CNN training for the previous pivot and prediction for the next pivot.
        Called when a new pivot point is identified.

        Args:
            newly_identified_pivot: The SwingPoint object for the just-formed pivot.
            ohlcv_data_context: The OHLCV data (or pivot array for higher levels)
                                relevant to this pivot's formation.
        """
        if not self.enable_cnn_feature or self.cnn_model is None:
            return

        # 1. Train on the *previous* pivot's prediction against the *current* actual outcome
        if self.previous_pivot_details_for_cnn:
            try:
                logger.debug(f"CNN Training: Previous pivot at idx {self.previous_pivot_details_for_cnn['pivot'].index}, "
                             f"Current pivot (ground truth) at idx {newly_identified_pivot.index}")
                X_train = self.previous_pivot_details_for_cnn['features']
                # previous_pivot_details_for_cnn['pivot'] is the SwingPoint of pivot N-1
                y_train = self._get_cnn_ground_truth(self.previous_pivot_details_for_cnn, newly_identified_pivot)

                if X_train is not None and X_train.size > 0 and y_train is not None and y_train.size > 0:
                    # Add a batch dimension if X_train is a single sample
                    if len(X_train.shape) == len(self.cnn_model.input_shape) and X_train.shape == self.cnn_model.input_shape:
                        X_train_batch = np.expand_dims(X_train, axis=0)
                    else:  # Should already be correctly shaped by _prepare_cnn_input
                        X_train_batch = X_train

                    # Reshape y_train into a batch as needed
                    if self.cnn_model.output_size > 1 and len(y_train.shape) == 1:  # e.g. [0., 1.] but the model needs [[0., 1.]]
                        y_train_batch = np.expand_dims(y_train, axis=0)
                    elif self.cnn_model.output_size == 1 and not isinstance(y_train, (list, np.ndarray)):  # plain 0 or 1
                        y_train_batch = np.array([[y_train]], dtype=np.float32)
                    elif self.cnn_model.output_size == 1 and isinstance(y_train, np.ndarray) and y_train.ndim == 1:
                        y_train_batch = y_train.reshape(-1, 1)  # ensure [[0.]] for a single binary output
                    else:
                        y_train_batch = y_train

                    logger.info(f"CNN Training with X_shape: {X_train_batch.shape}, y_shape: {y_train_batch.shape}")
                    # Single online-learning step with minimal callbacks
                    self.cnn_model.model.fit(X_train_batch, y_train_batch, batch_size=1, epochs=1, verbose=0, callbacks=[])
                    logger.info(f"CNN online training step completed for pivot at index {self.previous_pivot_details_for_cnn['pivot'].index}.")
                else:
                    logger.warning("CNN Training: Skipping due to invalid X_train or y_train.")
            except Exception as e:
                logger.error(f"Error during CNN online training: {e}", exc_info=True)

        # 2. Predict the *next* pivot based on the newly identified one
        try:
            logger.debug(f"CNN Prediction: Preparing input for current pivot at idx {newly_identified_pivot.index}")
            # self.previous_pivot_details_for_cnn still refers to pivot N-1 (just trained on above)
            # and provides context for preparing the input of the current pivot N.
            # On the very first pivot it is None.
            X_predict = self._prepare_cnn_input(newly_identified_pivot,
                                                ohlcv_data_context,
                                                self.previous_pivot_details_for_cnn)

            if X_predict is not None and X_predict.size > 0:
                # Add a batch dimension if X_predict is a single sample
                if len(X_predict.shape) == len(self.cnn_model.input_shape) and X_predict.shape == self.cnn_model.input_shape:
                    X_predict_batch = np.expand_dims(X_predict, axis=0)
                else:
                    X_predict_batch = X_predict

                logger.info(f"CNN Predicting with X_shape: {X_predict_batch.shape}")
                pred_class, pred_proba = self.cnn_model.predict(X_predict_batch)  # predict expects a batch

                # For batch_size=1, take the first element of any array outputs
                final_pred_class = pred_class[0] if isinstance(pred_class, np.ndarray) and pred_class.ndim > 0 else pred_class
                final_pred_proba = pred_proba[0] if isinstance(pred_proba, np.ndarray) and pred_proba.ndim > 0 else pred_proba
                logger.info(f"CNN Prediction for pivot after index {newly_identified_pivot.index}: Class={final_pred_class}, Proba/Val={final_pred_proba}")

                # Store the features and the pivot itself for the next training cycle
                self.previous_pivot_details_for_cnn = {'features': X_predict, 'pivot': newly_identified_pivot}
            else:
                logger.warning("CNN Prediction: Skipping due to invalid X_predict.")
                # Clear stale details so no training happens next round unless a new prediction is made
                self.previous_pivot_details_for_cnn = None
        except Exception as e:
            logger.error(f"Error during CNN prediction: {e}", exc_info=True)
            self.previous_pivot_details_for_cnn = None  # Clear on error to prevent bad training
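    # Online train/predict cycle implemented above, sketched over three pivots:
    #   pivot A forms -> predict B from features(A); stash {features(A), A}
    #   pivot B forms -> train on (features(A), ground_truth(B)); predict C; stash {features(B), B}
    #   pivot C forms -> train on (features(B), ground_truth(C)); and so on.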
    def _prepare_cnn_input(self,
                           current_pivot: SwingPoint,
                           ohlcv_data_context: np.ndarray,
                           previous_pivot_details: Optional[Dict[str, Any]]) -> np.ndarray:
        """
        Prepare multi-timeframe, multi-symbol input features for the CNN using TrainingDataPacket.

        Features include:
        - ETH: 5 min of ticks → 300 x 1s bars with tick features (4 features)
        - ETH: 900 x 1s OHLCV + indicators (10 features)
        - ETH: 900 x 1m OHLCV + indicators (10 features)
        - ETH: 900 x 1h OHLCV + indicators (10 features)
        - ETH: All pivot points from all levels (15 features)
        - BTC: 5 min of ticks → 300 x 1s reference (4 features)
        - Chart labels for data identification (7 features)
        Total: ~50 features per timestep over 900 timesteps

        Data is normalized using the 1h min/max to preserve cross-timeframe relationships.

        Args:
            current_pivot: The newly identified SwingPoint
            ohlcv_data_context: The OHLCV data from the Williams calculation (may not be used directly)
            previous_pivot_details: Previous pivot info for context

        Returns:
            A numpy array of shape (900, 50) with normalized features
        """
        if self.cnn_model is None or not self.training_data_provider:
            logger.warning("CNN model or training data provider not available")
            return np.zeros(self.cnn_model.input_shape if self.cnn_model else (900, 50), dtype=np.float32)

        sequence_length, num_features = self.cnn_model.input_shape

        try:
            # Get the latest TrainingDataPacket from the provider
            training_packet = self._get_latest_training_data()
            if not training_packet:
                logger.warning("No TrainingDataPacket available for CNN input")
                return np.zeros((sequence_length, num_features), dtype=np.float32)

            logger.debug(f"CNN Input: Preparing features for pivot at {current_pivot.timestamp}")

            # Prepare feature components (still in actual values)
            eth_features = self._prepare_eth_features(training_packet, sequence_length)
            btc_features = self._prepare_btc_reference_features(training_packet, sequence_length)
            pivot_features = self._prepare_pivot_features(training_packet, current_pivot, sequence_length)
            chart_labels = self._prepare_chart_labels(sequence_length)

            # Combine all features (still in actual values)
            combined_features = np.concatenate([
                eth_features,    # ~40 features
                btc_features,    # ~4 features
                pivot_features,  # ~3 features
                chart_labels     # ~3 features
            ], axis=1)

            # Ensure we match the expected feature count
            if combined_features.shape[1] > num_features:
                combined_features = combined_features[:, :num_features]
            elif combined_features.shape[1] < num_features:
                padding = np.zeros((sequence_length, num_features - combined_features.shape[1]))
                combined_features = np.concatenate([combined_features, padding], axis=1)

            # NORMALIZATION: Apply the 1h timeframe min/max to preserve relationships
            normalized_features = self._normalize_features_by_1h_range(combined_features, training_packet)

            logger.debug(f"CNN Input prepared: shape {normalized_features.shape}, "
                         f"min: {normalized_features.min():.4f}, max: {normalized_features.max():.4f}")
            return normalized_features.astype(np.float32)

        except Exception as e:
            logger.error(f"Error preparing CNN input: {e}", exc_info=True)
            return np.zeros((sequence_length, num_features), dtype=np.float32)
    def _get_latest_training_data(self):
        """Get the latest TrainingDataPacket from the provider"""
        try:
            if hasattr(self.training_data_provider, 'get_latest_training_data'):
                return self.training_data_provider.get_latest_training_data()
            elif hasattr(self.training_data_provider, 'training_data_buffer'):
                return self.training_data_provider.training_data_buffer[-1] if self.training_data_provider.training_data_buffer else None
            else:
                logger.warning("Training data provider does not have the expected interface")
                return None
        except Exception as e:
            logger.error(f"Error getting training data: {e}")
            return None
    def _prepare_eth_features(self, training_packet, sequence_length: int) -> np.ndarray:
        """
        Prepare ETH multi-timeframe features (kept in actual values):
        - 1s bars with indicators (10 features)
        - 1m bars with indicators (10 features)
        - 1h bars with indicators (10 features)
        - Tick-derived 1s features (4 features)
        Total: 34 features per timestep (padded downstream to the CNN's expected width)
        """
        features = []

        # ETH 1s data with indicators
        eth_1s_features = self._extract_timeframe_features(
            training_packet.multi_timeframe_data.get('ETH/USDT', {}).get('1s', []),
            sequence_length, 'ETH_1s'
        )
        features.append(eth_1s_features)

        # ETH 1m data with indicators
        eth_1m_features = self._extract_timeframe_features(
            training_packet.multi_timeframe_data.get('ETH/USDT', {}).get('1m', []),
            sequence_length, 'ETH_1m'
        )
        features.append(eth_1m_features)

        # ETH 1h data with indicators
        eth_1h_features = self._extract_timeframe_features(
            training_packet.multi_timeframe_data.get('ETH/USDT', {}).get('1h', []),
            sequence_length, 'ETH_1h'
        )
        features.append(eth_1h_features)

        # ETH tick-derived features (5 min of ticks aggregated per second to match sequence_length)
        eth_tick_features = self._extract_tick_features(
            training_packet.tick_cache, 'ETH/USDT', sequence_length
        )
        features.append(eth_tick_features)

        return np.concatenate(features, axis=1)
    def _prepare_btc_reference_features(self, training_packet, sequence_length: int) -> np.ndarray:
        """
        Prepare BTC reference features (kept in actual values):
        - Tick-derived features for correlation analysis
        Total: 4 features per timestep
        """
        return self._extract_tick_features(
            training_packet.tick_cache, 'BTC/USDT', sequence_length
        )
    def _prepare_pivot_features(self, training_packet, current_pivot: SwingPoint, sequence_length: int) -> np.ndarray:
        """
        Prepare pivot point features from all Williams levels:
        - Recent pivot characteristics
        - Level-specific trend information
        Total: 3 features per timestep (repeated across the sequence)
        """
        # Extract Williams pivot features using the existing logic if the stream is available
        if hasattr(training_packet, 'universal_stream') and training_packet.universal_stream:
            pivot_feature_vector = [
                current_pivot.price,
                1.0 if current_pivot.swing_type == SwingType.SWING_HIGH else 0.0,
                float(current_pivot.strength)
            ]
        else:
            pivot_feature_vector = [0.0, 0.0, 0.0]

        # Repeat the pivot features for all timesteps in the sequence
        return np.tile(pivot_feature_vector, (sequence_length, 1))
    def _prepare_chart_labels(self, sequence_length: int) -> np.ndarray:
        """
        Prepare chart identification labels:
        - Symbol identifiers
        - Timeframe identifiers
        Total: 3 features per timestep
        """
        # Simple encoding: [is_eth, is_btc, timeframe_mix]
        chart_labels = [1.0, 1.0, 1.0]  # Mixed multi-timeframe ETH+BTC data
        return np.tile(chart_labels, (sequence_length, 1))
    def _extract_timeframe_features(self, ohlcv_data: List[Dict], sequence_length: int, timeframe_label: str) -> np.ndarray:
        """
        Extract OHLCV + indicator features from timeframe data (kept in actual values).
        Returns 10 features per bar: OHLC + volume + 5 indicators.
        """
        if not ohlcv_data:
            return np.zeros((sequence_length, 10))

        # Take the last sequence_length bars, padding below if there are fewer
        data_to_use = ohlcv_data[-sequence_length:] if len(ohlcv_data) >= sequence_length else ohlcv_data
        features = []

        for bar in data_to_use:
            bar_features = [
                bar.get('open', 0.0),
                bar.get('high', 0.0),
                bar.get('low', 0.0),
                bar.get('close', 0.0),
                bar.get('volume', 0.0),
                # TODO: Add 5 calculated indicators (SMA, EMA, RSI, MACD, etc.)
                bar.get('sma_20', bar.get('close', 0.0)),   # Placeholder
                bar.get('ema_20', bar.get('close', 0.0)),   # Placeholder
                bar.get('rsi_14', 50.0),                    # Placeholder
                bar.get('macd', 0.0),                       # Placeholder
                bar.get('bb_upper', bar.get('high', 0.0))   # Placeholder
            ]
            features.append(bar_features)

        # Pad at the front if there is insufficient data
        while len(features) < sequence_length:
            features.insert(0, features[0] if features else [0.0] * 10)

        return np.array(features, dtype=np.float32)
    def _extract_tick_features(self, tick_cache: List[Dict], symbol: str, sequence_length: int) -> np.ndarray:
        """
        Extract tick-derived features aggregated into 1s intervals (kept in actual values).
        Returns 4 features per second: tick_count, total_volume, vwap, price_volatility.
        """
        # Filter ticks for this symbol over roughly the last 5 minutes
        symbol_ticks = [t for t in tick_cache[-1500:] if t.get('symbol') == symbol]  # Assume ~5 ticks/sec
        if not symbol_ticks:
            return np.zeros((sequence_length, 4))

        # Group ticks by second and calculate features
        tick_features = []
        current_time = datetime.now()

        for i in range(sequence_length):
            second_start = current_time - timedelta(seconds=sequence_length - i)
            second_end = second_start + timedelta(seconds=1)
            second_ticks = [
                t for t in symbol_ticks
                if second_start <= t.get('timestamp', datetime.min) < second_end
            ]

            if second_ticks:
                prices = [t.get('price', 0.0) for t in second_ticks]
                volumes = [t.get('volume', 0.0) for t in second_ticks]
                total_volume = sum(volumes)
                tick_count = len(second_ticks)
                vwap = sum(p * v for p, v in zip(prices, volumes)) / total_volume if total_volume > 0 else 0.0
                price_volatility = np.std(prices) if len(prices) > 1 else 0.0
                second_features = [tick_count, total_volume, vwap, price_volatility]
            else:
                second_features = [0.0, 0.0, 0.0, 0.0]

            tick_features.append(second_features)

        return np.array(tick_features, dtype=np.float32)
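    # Per-second aggregation example (not executed): ticks (price, volume) = (3000, 2) and
    # (3001, 1) in one bucket give tick_count=2, total_volume=3,
    # vwap = (3000*2 + 3001*1) / 3 ~= 3000.33, price_volatility = std([3000, 3001]) = 0.5.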
    def _normalize_features_by_1h_range(self, features: np.ndarray, training_packet) -> np.ndarray:
        """
        Normalize all features using the 1h timeframe min/max to preserve cross-timeframe
        relationships. This is the final normalization step before feeding the CNN.
        """
        try:
            # Get 1h ETH data as the normalization reference
            eth_1h_data = training_packet.multi_timeframe_data.get('ETH/USDT', {}).get('1h', [])

            if not eth_1h_data:
                logger.warning("No 1h data available for normalization, using feature-wise normalization")
                # Fallback: normalize each feature independently
                feature_min = np.min(features, axis=0, keepdims=True)
                feature_max = np.max(features, axis=0, keepdims=True)
                feature_range = feature_max - feature_min
                feature_range[feature_range == 0] = 1.0  # Avoid division by zero
                return (features - feature_min) / feature_range

            # Extract the 1h price range for primary normalization
            h1_prices = []
            for bar in eth_1h_data[-24:]:  # Last 24 hours for a robust range
                h1_prices.extend([
                    bar.get('open', 0.0),
                    bar.get('high', 0.0),
                    bar.get('low', 0.0),
                    bar.get('close', 0.0)
                ])

            if h1_prices:
                h1_min = min(h1_prices)
                h1_max = max(h1_prices)
                h1_range = h1_max - h1_min

                if h1_range > 0:
                    logger.debug(f"Normalizing features using 1h range: {h1_min:.2f} - {h1_max:.2f}")
                    # Apply 1h-based normalization to the price-related features (first ~30 columns)
                    normalized_features = features.copy()
                    price_feature_count = min(30, features.shape[1])

                    # Normalize price-related features with the 1h range
                    normalized_features[:, :price_feature_count] = (
                        (features[:, :price_feature_count] - h1_min) / h1_range
                    )

                    # For non-price features (indicators, counts, etc.), use feature-wise normalization
                    if features.shape[1] > price_feature_count:
                        remaining_features = features[:, price_feature_count:]
                        feature_min = np.min(remaining_features, axis=0, keepdims=True)
                        feature_max = np.max(remaining_features, axis=0, keepdims=True)
                        feature_range = feature_max - feature_min
                        feature_range[feature_range == 0] = 1.0
                        normalized_features[:, price_feature_count:] = (
                            (remaining_features - feature_min) / feature_range
                        )

                    return normalized_features

            # Fallback normalization if the 1h range calculation fails
            logger.warning("1h range calculation failed, using min-max normalization")
            feature_min = np.min(features, axis=0, keepdims=True)
            feature_max = np.max(features, axis=0, keepdims=True)
            feature_range = feature_max - feature_min
            feature_range[feature_range == 0] = 1.0
            return (features - feature_min) / feature_range

        except Exception as e:
            logger.error(f"Error in normalization: {e}", exc_info=True)
            # Emergency fallback: scale roughly into [-1, 1]
            return np.clip(features / (np.max(np.abs(features)) + 1e-8), -1.0, 1.0)
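    # Example of the 1h-range normalization above (not executed): with h1_min=2400 and
    # h1_max=2600, a 1s close of 2500 maps to (2500 - 2400) / 200 = 0.5 on every price
    # column, so 1s/1m/1h features stay directly comparable after scaling.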
    def _get_cnn_ground_truth(self,
                              previous_pivot_info: Dict[str, Any],  # Contains 'pivot': SwingPoint of N-1
                              actual_current_pivot: SwingPoint      # This is pivot N
                              ) -> np.ndarray:
        """
        Determine the ground truth for the CNN prediction made at the previous pivot.

        Returns the prediction target for the next pivot across ALL 5 LEVELS:
        - For each level: [type (0=LOW, 1=HIGH), normalized_price_target]
        - Total output: 10 values (5 levels * 2 outputs each)

        Args:
            previous_pivot_info: Dict with 'pivot' = SwingPoint of N-1
            actual_current_pivot: SwingPoint of pivot N (the actual outcome)

        Returns:
            A numpy array of shape (10,) with ground truth for all levels
        """
        if self.cnn_model is None:
            return np.array([])

        # Initialize the ground truth array for all 5 levels
        ground_truth = np.zeros(10, dtype=np.float32)  # 5 levels * 2 outputs

        try:
            # For Level 0 (the current pivot's level), we have actual data
            level_0_type = 1.0 if actual_current_pivot.swing_type == SwingType.SWING_HIGH else 0.0
            level_0_price = actual_current_pivot.price

            # Normalize price (placeholder - proper normalization should use market context,
            # i.e. the same 1h range normalization as the input features)
            normalized_price = level_0_price / 10000.0  # Rough normalization for ETH prices

            ground_truth[0] = level_0_type      # Level 0 type
            ground_truth[1] = normalized_price  # Level 0 price

            # Higher levels (1-4) would need the next pivot from the higher-level Williams
            # calculations; for now, use placeholder logic based on the current pivot
            for level in range(1, 5):
                # Placeholder: higher levels follow a similar pattern with reduced confidence
                confidence_factor = 1.0 / (level + 1)
                ground_truth[level * 2] = level_0_type * confidence_factor          # Level N type
                ground_truth[level * 2 + 1] = normalized_price * confidence_factor  # Level N price

            logger.debug(f"CNN Ground Truth: Level 0 = [{level_0_type}, {normalized_price:.4f}], "
                         f"Current pivot = {actual_current_pivot.swing_type.name} @ {actual_current_pivot.price}")
            return ground_truth

        except Exception as e:
            logger.error(f"Error calculating CNN ground truth: {e}", exc_info=True)
            return np.zeros(10, dtype=np.float32)
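

if __name__ == "__main__":
    # Minimal smoke-test sketch (an assumption of this listing, not part of the training
    # pipeline): feed synthetic 1s bars through the recursive analysis with the CNN
    # feature disabled, then pull the 250-element RL feature vector.
    logging.basicConfig(level=logging.INFO)

    rng = np.random.default_rng(42)
    n_bars = 600
    base_ts = 1_700_000_000  # epoch seconds > 1e9 so real timestamps are used
    closes = 3000.0 + np.cumsum(rng.normal(0.0, 2.0, n_bars))

    bars = []
    for i, close in enumerate(closes):
        spread = abs(rng.normal(0.0, 1.0))
        open_ = close + rng.normal(0.0, 1.0)
        bars.append([base_ts + i, open_, max(open_, close) + spread,
                     min(open_, close) - spread, close, rng.uniform(1.0, 10.0)])
    ohlcv_1s = np.array(bars)

    ws = WilliamsMarketStructure(swing_strengths=[2, 3, 5], enable_cnn_feature=False)
    structure = ws.calculate_recursive_pivot_points(ohlcv_1s)
    for name, lvl in structure.items():
        print(f"{name}: {len(lvl.swing_points)} swings, bias={lvl.current_bias.value}")
    rl_features = ws.extract_features_for_rl(structure)
    print(f"RL feature vector length: {len(rl_features)}")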