integrationg COB
This commit is contained in:
@ -1387,4 +1387,246 @@ class WilliamsMarketStructure:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating CNN ground truth: {e}", exc_info=True)
|
||||
return np.zeros(10, dtype=np.float32)
|
||||
return np.zeros(10, dtype=np.float32)
|
||||
|
||||
def extract_pivot_features(df: pd.DataFrame) -> Optional[np.ndarray]:
|
||||
"""
|
||||
Extract pivot-based features for RL state building
|
||||
|
||||
Args:
|
||||
df: Market data DataFrame with OHLCV columns
|
||||
|
||||
Returns:
|
||||
numpy array with pivot features (1000 features)
|
||||
"""
|
||||
try:
|
||||
if df is None or df.empty or len(df) < 50:
|
||||
return None
|
||||
|
||||
features = []
|
||||
|
||||
# === PIVOT DETECTION FEATURES (200) ===
|
||||
highs = df['high'].values
|
||||
lows = df['low'].values
|
||||
closes = df['close'].values
|
||||
|
||||
# Find pivot highs and lows
|
||||
pivot_high_indices = []
|
||||
pivot_low_indices = []
|
||||
window = 5
|
||||
|
||||
for i in range(window, len(highs) - window):
|
||||
# Pivot high: current high is higher than surrounding highs
|
||||
if all(highs[i] > highs[j] for j in range(i-window, i)) and \
|
||||
all(highs[i] > highs[j] for j in range(i+1, i+window+1)):
|
||||
pivot_high_indices.append(i)
|
||||
|
||||
# Pivot low: current low is lower than surrounding lows
|
||||
if all(lows[i] < lows[j] for j in range(i-window, i)) and \
|
||||
all(lows[i] < lows[j] for j in range(i+1, i+window+1)):
|
||||
pivot_low_indices.append(i)
|
||||
|
||||
# Pivot high features (100 features)
|
||||
if pivot_high_indices:
|
||||
recent_pivot_highs = [highs[i] for i in pivot_high_indices[-100:]]
|
||||
features.extend(recent_pivot_highs)
|
||||
features.extend([0.0] * max(0, 100 - len(recent_pivot_highs)))
|
||||
else:
|
||||
features.extend([0.0] * 100)
|
||||
|
||||
# Pivot low features (100 features)
|
||||
if pivot_low_indices:
|
||||
recent_pivot_lows = [lows[i] for i in pivot_low_indices[-100:]]
|
||||
features.extend(recent_pivot_lows)
|
||||
features.extend([0.0] * max(0, 100 - len(recent_pivot_lows)))
|
||||
else:
|
||||
features.extend([0.0] * 100)
|
||||
|
||||
# === PIVOT DISTANCE FEATURES (200) ===
|
||||
current_price = closes[-1]
|
||||
|
||||
# Distance to nearest pivot highs (100 features)
|
||||
if pivot_high_indices:
|
||||
distances_to_highs = [(current_price - highs[i]) / current_price for i in pivot_high_indices[-100:]]
|
||||
features.extend(distances_to_highs)
|
||||
features.extend([0.0] * max(0, 100 - len(distances_to_highs)))
|
||||
else:
|
||||
features.extend([0.0] * 100)
|
||||
|
||||
# Distance to nearest pivot lows (100 features)
|
||||
if pivot_low_indices:
|
||||
distances_to_lows = [(current_price - lows[i]) / current_price for i in pivot_low_indices[-100:]]
|
||||
features.extend(distances_to_lows)
|
||||
features.extend([0.0] * max(0, 100 - len(distances_to_lows)))
|
||||
else:
|
||||
features.extend([0.0] * 100)
|
||||
|
||||
# === MARKET STRUCTURE FEATURES (200) ===
|
||||
# Higher highs and higher lows detection
|
||||
structure_features = []
|
||||
|
||||
if len(pivot_high_indices) >= 2:
|
||||
# Recent pivot high trend
|
||||
recent_highs = [highs[i] for i in pivot_high_indices[-5:]]
|
||||
high_trend = 1.0 if len(recent_highs) >= 2 and recent_highs[-1] > recent_highs[-2] else -1.0
|
||||
structure_features.append(high_trend)
|
||||
else:
|
||||
structure_features.append(0.0)
|
||||
|
||||
if len(pivot_low_indices) >= 2:
|
||||
# Recent pivot low trend
|
||||
recent_lows = [lows[i] for i in pivot_low_indices[-5:]]
|
||||
low_trend = 1.0 if len(recent_lows) >= 2 and recent_lows[-1] > recent_lows[-2] else -1.0
|
||||
structure_features.append(low_trend)
|
||||
else:
|
||||
structure_features.append(0.0)
|
||||
|
||||
# Swing strength
|
||||
if pivot_high_indices and pivot_low_indices:
|
||||
last_high = highs[pivot_high_indices[-1]] if pivot_high_indices else current_price
|
||||
last_low = lows[pivot_low_indices[-1]] if pivot_low_indices else current_price
|
||||
swing_range = (last_high - last_low) / current_price if current_price > 0 else 0
|
||||
structure_features.append(swing_range)
|
||||
else:
|
||||
structure_features.append(0.0)
|
||||
|
||||
# Pad structure features to 200
|
||||
features.extend(structure_features)
|
||||
features.extend([0.0] * (200 - len(structure_features)))
|
||||
|
||||
# === TREND AND MOMENTUM FEATURES (400) ===
|
||||
# Moving averages
|
||||
if len(closes) >= 50:
|
||||
sma_20 = np.mean(closes[-20:])
|
||||
sma_50 = np.mean(closes[-50:])
|
||||
features.extend([sma_20, sma_50, current_price - sma_20, current_price - sma_50])
|
||||
else:
|
||||
features.extend([0.0, 0.0, 0.0, 0.0])
|
||||
|
||||
# Price momentum over different periods
|
||||
momentum_periods = [5, 10, 20, 30, 50]
|
||||
for period in momentum_periods:
|
||||
if len(closes) > period:
|
||||
momentum = (closes[-1] - closes[-period-1]) / closes[-period-1]
|
||||
features.append(momentum)
|
||||
else:
|
||||
features.append(0.0)
|
||||
|
||||
# Volume analysis
|
||||
if 'volume' in df.columns and len(df['volume']) > 20:
|
||||
volume_sma = np.mean(df['volume'].values[-20:])
|
||||
current_volume = df['volume'].values[-1]
|
||||
volume_ratio = current_volume / volume_sma if volume_sma > 0 else 1.0
|
||||
features.append(volume_ratio)
|
||||
else:
|
||||
features.append(1.0)
|
||||
|
||||
# Volatility features
|
||||
if len(closes) > 20:
|
||||
returns = np.diff(np.log(closes[-20:]))
|
||||
volatility = np.std(returns) * np.sqrt(1440) # Daily volatility
|
||||
features.append(volatility)
|
||||
else:
|
||||
features.append(0.02) # Default volatility
|
||||
|
||||
# Pad to 400 features
|
||||
while len(features) < 800:
|
||||
features.append(0.0)
|
||||
|
||||
# Ensure exactly 1000 features
|
||||
features = features[:1000]
|
||||
while len(features) < 1000:
|
||||
features.append(0.0)
|
||||
|
||||
return np.array(features, dtype=np.float32)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting pivot features: {e}")
|
||||
return None
|
||||
|
||||
def analyze_pivot_context(market_data: Dict, trade_timestamp: datetime, trade_action: str) -> Optional[Dict]:
|
||||
"""
|
||||
Analyze pivot context around a specific trade for reward calculation
|
||||
|
||||
Args:
|
||||
market_data: Market data context
|
||||
trade_timestamp: When the trade was made
|
||||
trade_action: BUY/SELL action
|
||||
|
||||
Returns:
|
||||
Dictionary with pivot analysis results
|
||||
"""
|
||||
try:
|
||||
# Extract price data if available
|
||||
if 'ohlcv_data' not in market_data:
|
||||
return None
|
||||
|
||||
df = market_data['ohlcv_data']
|
||||
if df is None or df.empty:
|
||||
return None
|
||||
|
||||
# Find recent pivot points
|
||||
highs = df['high'].values
|
||||
lows = df['low'].values
|
||||
closes = df['close'].values
|
||||
|
||||
if len(closes) < 20:
|
||||
return None
|
||||
|
||||
current_price = closes[-1]
|
||||
|
||||
# Find pivot points
|
||||
pivot_highs = []
|
||||
pivot_lows = []
|
||||
window = 3
|
||||
|
||||
for i in range(window, len(highs) - window):
|
||||
# Pivot high
|
||||
if all(highs[i] >= highs[j] for j in range(i-window, i)) and \
|
||||
all(highs[i] >= highs[j] for j in range(i+1, i+window+1)):
|
||||
pivot_highs.append((i, highs[i]))
|
||||
|
||||
# Pivot low
|
||||
if all(lows[i] <= lows[j] for j in range(i-window, i)) and \
|
||||
all(lows[i] <= lows[j] for j in range(i+1, i+window+1)):
|
||||
pivot_lows.append((i, lows[i]))
|
||||
|
||||
analysis = {
|
||||
'near_pivot': False,
|
||||
'pivot_strength': 0.0,
|
||||
'pivot_break_direction': None,
|
||||
'against_pivot_structure': False
|
||||
}
|
||||
|
||||
# Check if near significant pivot
|
||||
pivot_threshold = current_price * 0.005 # 0.5% threshold
|
||||
|
||||
for idx, price in pivot_highs[-5:]: # Check last 5 pivot highs
|
||||
if abs(current_price - price) < pivot_threshold:
|
||||
analysis['near_pivot'] = True
|
||||
analysis['pivot_strength'] = min(1.0, (current_price - price) / pivot_threshold)
|
||||
|
||||
# Check for breakout
|
||||
if current_price > price * 1.001: # 0.1% breakout
|
||||
analysis['pivot_break_direction'] = 'up'
|
||||
elif trade_action == 'SELL' and current_price < price:
|
||||
analysis['against_pivot_structure'] = True
|
||||
break
|
||||
|
||||
for idx, price in pivot_lows[-5:]: # Check last 5 pivot lows
|
||||
if abs(current_price - price) < pivot_threshold:
|
||||
analysis['near_pivot'] = True
|
||||
analysis['pivot_strength'] = min(1.0, (price - current_price) / pivot_threshold)
|
||||
|
||||
# Check for breakout
|
||||
if current_price < price * 0.999: # 0.1% breakdown
|
||||
analysis['pivot_break_direction'] = 'down'
|
||||
elif trade_action == 'BUY' and current_price > price:
|
||||
analysis['against_pivot_structure'] = True
|
||||
break
|
||||
|
||||
return analysis
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error analyzing pivot context: {e}")
|
||||
return None
|
Reference in New Issue
Block a user