gogo2/test_enhanced_williams_cnn.py
Dobromir Popov 75dbac1761 tter pivots
2025-05-30 03:03:51 +03:00

346 lines
14 KiB
Python

#!/usr/bin/env python3
"""
Test Enhanced Williams Market Structure with CNN Integration
This script demonstrates the multi-timeframe, multi-symbol CNN-enabled
Williams Market Structure that predicts pivot points using TrainingDataPacket.
Features tested:
- Multi-timeframe data integration (1s, 1m, 1h)
- Multi-symbol support (ETH, BTC)
- Tick data aggregation
- 1h-based normalization strategy
- Multi-level pivot prediction (5 levels, type + price)
"""
import numpy as np
import logging
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Mock TrainingDataPacket for testing
@dataclass
class MockTrainingDataPacket:
    """Mock TrainingDataPacket for testing CNN integration.

    The first five fields are required; the remaining ones default to
    ``None``.
    """

    # Time the packet refers to.
    timestamp: datetime
    # Primary symbol of the packet, e.g. 'ETH/USDT'.
    symbol: str
    # Individual tick dicts.
    tick_cache: List[Dict[str, Any]]
    # 1-second bar dicts.
    one_second_bars: List[Dict[str, Any]]
    # Nested mapping: symbol -> timeframe label ('1s', '1m', '1h') -> bar dicts.
    multi_timeframe_data: Dict[str, Dict[str, List[Dict[str, Any]]]]
    # Optional CNN artefacts keyed by name.
    cnn_features: Optional[Dict[str, np.ndarray]] = None
    cnn_predictions: Optional[Dict[str, np.ndarray]] = None
    # Opaque extras; nothing in this file populates them.
    market_state: Optional[Any] = None
    universal_stream: Optional[Any] = None
class MockTrainingDataProvider:
    """Mock provider that supplies a TrainingDataPacket for testing.

    On construction it synthesises ETH/USDT bars (1s, 1m, 1h), BTC/USDT 1s
    bars and ticks for both symbols, wraps them in one
    ``MockTrainingDataPacket`` and appends it to ``training_data_buffer``.
    """

    # Bar spacing for each supported timeframe label.
    _TIMEFRAME_DELTAS = {
        '1s': timedelta(seconds=1),
        '1m': timedelta(minutes=1),
        '1h': timedelta(hours=1),
    }

    def __init__(self, seed: Optional[int] = None):
        """Create the provider and populate it with one mock packet.

        Args:
            seed: Optional seed for the random generator. Pass a value to
                make the generated prices and volumes reproducible; the
                default ``None`` gives different data on every run.
                Timestamps are derived from ``datetime.now()`` either way.
        """
        self._rng = np.random.default_rng(seed)
        self.training_data_buffer = []
        self._generate_mock_data()

    def _generate_mock_data(self):
        """Generate comprehensive mock market data and buffer it as one packet."""
        current_time = datetime.now()

        # ETH data for the three timeframes.
        eth_1s_data = self._generate_ohlcv_data(2400.0, 900, '1s', current_time)  # 15 min of 1s data
        eth_1m_data = self._generate_ohlcv_data(2400.0, 900, '1m', current_time)  # 15 hours of 1m data
        eth_1h_data = self._generate_ohlcv_data(2400.0, 24, '1h', current_time)   # 24 hours of 1h data

        # BTC reference data.
        btc_1s_data = self._generate_ohlcv_data(45000.0, 300, '1s', current_time)  # 5 min of 1s data

        tick_data = self._generate_tick_data(current_time)

        training_packet = MockTrainingDataPacket(
            timestamp=current_time,
            symbol='ETH/USDT',
            tick_cache=tick_data,
            one_second_bars=eth_1s_data[-300:],  # Last 5 minutes
            multi_timeframe_data={
                'ETH/USDT': {
                    '1s': eth_1s_data,
                    '1m': eth_1m_data,
                    '1h': eth_1h_data,
                },
                'BTC/USDT': {
                    '1s': btc_1s_data,
                },
            },
        )
        self.training_data_buffer.append(training_packet)

        logger.info(
            "Generated mock training data: %d 1s bars, %d 1m bars, %d 1h bars",
            len(eth_1s_data), len(eth_1m_data), len(eth_1h_data),
        )
        logger.info(
            "ETH 1h price range: %.2f - %.2f",
            min(bar['low'] for bar in eth_1h_data),
            max(bar['high'] for bar in eth_1h_data),
        )

    def _generate_ohlcv_data(self, base_price: float, count: int, timeframe: str,
                             end_time: datetime) -> List[Dict[str, Any]]:
        """Generate a random-walk OHLCV series with placeholder indicators.

        Args:
            base_price: Starting price; the walk is floored at 80% of it.
            count: Number of bars to produce.
            timeframe: '1s', '1m' or '1h'. Any other label falls back to
                one-minute spacing.
            end_time: Timestamp of the last bar; earlier bars step backwards
                from it by the timeframe spacing.

        Returns:
            ``count`` bar dicts in chronological order with keys
            timestamp, open, high, low, close, volume, sma_20, ema_20,
            rsi_14, macd and bb_upper. The indicator values are random
            placeholders, not computed from the prices.
        """
        rng = self._rng
        delta = self._TIMEFRAME_DELTAS.get(timeframe, timedelta(minutes=1))
        price_floor = base_price * 0.8

        data = []
        current_price = base_price
        for i in range(count):
            # Step between opens: sigma is 0.1% of the base price.
            price_change = rng.normal(0, base_price * 0.001)
            open_price = max(current_price + price_change, price_floor)

            # High/low straddle the open; close lands somewhere between them.
            high_price = open_price * (1 + abs(rng.normal(0, 0.002)))
            low_price = open_price * (1 - abs(rng.normal(0, 0.002)))
            close_price = low_price + (high_price - low_price) * rng.random()

            data.append({
                'timestamp': end_time - delta * (count - i - 1),
                'open': open_price,
                'high': high_price,
                'low': low_price,
                'close': close_price,
                'volume': rng.exponential(1000),
                'sma_20': close_price * (1 + rng.normal(0, 0.001)),
                'ema_20': close_price * (1 + rng.normal(0, 0.0005)),
                'rsi_14': 30 + rng.random() * 40,  # RSI between 30-70
                'macd': rng.normal(0, 0.1),
                'bb_upper': high_price * 1.02,
            })
            # The next bar continues from this bar's close.
            current_price = close_price
        return data

    def _generate_symbol_ticks(self, symbol: str, end_time: datetime, *,
                               base_price: float, price_sigma: float,
                               volume_scale: float, quantity_scale: float,
                               min_ticks: int, max_ticks: int,
                               spacing_ms: int, seconds: int = 300) -> List[Dict[str, Any]]:
        """Generate ticks for one symbol over the ``seconds`` before ``end_time``.

        Each second receives between ``min_ticks`` and ``max_ticks`` ticks
        (inclusive), spaced ``spacing_ms`` milliseconds apart.
        """
        rng = self._rng
        ticks = []
        for i in range(seconds):
            second_start = end_time - timedelta(seconds=seconds - i)
            ticks_this_second = int(rng.integers(min_ticks, max_ticks + 1))
            for j in range(ticks_this_second):
                ticks.append({
                    'symbol': symbol,
                    'timestamp': second_start + timedelta(milliseconds=j * spacing_ms),
                    'price': base_price + rng.normal(0, price_sigma),
                    'volume': rng.exponential(volume_scale),
                    'quantity': rng.exponential(quantity_scale),
                    'side': 'buy' if rng.random() > 0.5 else 'sell',
                })
        return ticks

    def _generate_tick_data(self, end_time: datetime) -> List[Dict[str, Any]]:
        """Generate ETH/USDT and BTC/USDT ticks for the last 5 minutes.

        Returns:
            Tick dicts (symbol, timestamp, price, volume, quantity, side)
            for both symbols merged into one list ordered by timestamp.
        """
        # 2-5 ETH ticks per second, 200 ms apart.
        ticks = self._generate_symbol_ticks(
            'ETH/USDT', end_time,
            base_price=2400.0, price_sigma=5,
            volume_scale=0.5, quantity_scale=1.0,
            min_ticks=2, max_ticks=5, spacing_ms=200,
        )
        # 1-3 BTC ticks per second, 300 ms apart.
        ticks += self._generate_symbol_ticks(
            'BTC/USDT', end_time,
            base_price=45000.0, price_sigma=100,
            volume_scale=0.1, quantity_scale=0.5,
            min_ticks=1, max_ticks=3, spacing_ms=300,
        )
        # Interleave the two symbols chronologically. The sort is stable, so
        # each symbol's own (already chronological) order is preserved.
        ticks.sort(key=lambda tick: tick['timestamp'])
        return ticks

    def get_latest_training_data(self) -> Optional["MockTrainingDataPacket"]:
        """Return the most recently buffered packet, or None if the buffer is empty."""
        return self.training_data_buffer[-1] if self.training_data_buffer else None
def test_enhanced_williams_cnn():
    """Test the enhanced Williams Market Structure with CNN integration.

    Steps: build a MockTrainingDataProvider, construct
    WilliamsMarketStructure with the CNN feature enabled, run
    ``calculate_recursive_pivot_points`` on synthetic bars and log the
    per-level results, then exercise the CNN input / ground-truth helpers
    and the 1h-based normalization.

    Returns:
        bool: True when every step ran without raising. False when an
        import failed or any other exception was raised; the error is
        logged rather than propagated.

    NOTE(review): failures are reported through the return value, so a
    test runner that only looks for raised exceptions will not see them —
    confirm whether that is intended.
    """
    try:
        # Imported lazily so that missing optional dependencies are reported
        # by the ImportError handler below instead of breaking module import.
        from training.williams_market_structure import WilliamsMarketStructure

        _log_banner("TESTING ENHANCED WILLIAMS MARKET STRUCTURE WITH CNN INTEGRATION",
                    80, leading_newline=False)

        # Create mock data provider
        data_provider = MockTrainingDataProvider()

        # Initialize Williams Market Structure with CNN
        williams = WilliamsMarketStructure(
            swing_strengths=[2, 3, 5],   # Reduced for testing
            cnn_input_shape=(900, 50),   # 900 timesteps, 50 features
            cnn_output_size=10,          # 5 levels * 2 outputs (type + price)
            enable_cnn_feature=True,     # Enable CNN features
            training_data_provider=data_provider,
        )
        logger.info(f"CNN enabled: {williams.enable_cnn_feature}")
        logger.info(f"Training data provider available: {williams.training_data_provider is not None}")

        # Generate test OHLCV data for Williams calculation
        test_ohlcv = generate_test_ohlcv_data()
        logger.info(f"Generated test OHLCV data: {len(test_ohlcv)} bars")

        _log_banner("RUNNING WILLIAMS PIVOT CALCULATION WITH CNN INTEGRATION", 60)
        structure_levels = williams.calculate_recursive_pivot_points(test_ohlcv)
        _log_structure_levels(structure_levels)

        _log_banner("TESTING CNN INPUT PREPARATION", 60)
        _check_cnn_input_preparation(williams, structure_levels, test_ohlcv)

        _log_banner("TESTING 1H-BASED NORMALIZATION", 60)
        _check_1h_normalization(williams, data_provider)

        _log_banner("ENHANCED WILLIAMS CNN INTEGRATION TEST COMPLETED SUCCESSFULLY", 80)
        return True
    except ImportError as e:
        logger.error(f"Import error - some dependencies missing: {e}")
        logger.info("This is expected if TensorFlow or other dependencies are not installed")
        return False
    except Exception as e:
        # Top-level boundary of the script: log with traceback, report failure.
        logger.error(f"Test failed with error: {e}", exc_info=True)
        return False


def _log_banner(title, width, leading_newline=True):
    """Log ``title`` between two rules of ``width`` '=' characters."""
    rule = "=" * width
    logger.info(("\n" + rule) if leading_newline else rule)
    logger.info(title)
    logger.info(rule)


def _log_structure_levels(structure_levels):
    """Log swing count, trend, bias and latest swing for every calculated level."""
    logger.info("\nWilliams Structure Analysis Results:")
    logger.info(f"Calculated levels: {len(structure_levels)}")
    for level_key, level_data in structure_levels.items():
        swing_count = len(level_data.swing_points)
        logger.info(f"{level_key}: {swing_count} swing points, "
                    f"trend: {level_data.trend_analysis.direction.value}, "
                    f"bias: {level_data.current_bias.value}")
        if swing_count > 0:
            latest_swing = level_data.swing_points[-1]
            logger.info(f" Latest swing: {latest_swing.swing_type.name} @ {latest_swing.price:.2f}")


def _check_cnn_input_preparation(williams, structure_levels, test_ohlcv):
    """Call the CNN input and ground-truth helpers on the latest level_0 pivots."""
    # .get() instead of indexing: a result without 'level_0' skips this step
    # rather than failing the whole run with a KeyError.
    level_0 = structure_levels.get('level_0')
    if not (williams.enable_cnn_feature and level_0 is not None and level_0.swing_points):
        logger.warning("Skipping CNN input preparation: CNN disabled or no level_0 swing points")
        return

    swing_points = level_0.swing_points
    test_pivot = swing_points[-1]
    logger.info(f"Testing CNN input for pivot: {test_pivot.swing_type.name} @ {test_pivot.price:.2f}")

    # Test input preparation
    cnn_input = williams._prepare_cnn_input(
        current_pivot=test_pivot,
        ohlcv_data_context=test_ohlcv,
        previous_pivot_details=None,
    )
    logger.info(f"CNN input shape: {cnn_input.shape}")
    logger.info(f"CNN input range: [{cnn_input.min():.4f}, {cnn_input.max():.4f}]")
    logger.info(f"CNN input mean: {cnn_input.mean():.4f}, std: {cnn_input.std():.4f}")

    # Ground truth needs a previous pivot as well as the current one.
    if len(swing_points) >= 2:
        prev_details = {'pivot': swing_points[-2]}
        ground_truth = williams._get_cnn_ground_truth(prev_details, swing_points[-1])
        logger.info(f"Ground truth shape: {ground_truth.shape}")
        logger.info(f"Ground truth (first 4 values): {ground_truth[:4]}")
        logger.info(f"Level 0 prediction: type={ground_truth[0]:.2f}, price={ground_truth[1]:.4f}")


def _check_1h_normalization(williams, data_provider):
    """Normalize sample features by the 1h range and log the before/after ranges."""
    training_packet = data_provider.get_latest_training_data()
    if not training_packet:
        logger.warning("Skipping 1h normalization: no training packet available")
        return

    # Test normalization with sample data
    sample_features = np.random.normal(2400, 50, (100, 10))  # ETH-like prices
    normalized = williams._normalize_features_by_1h_range(sample_features, training_packet)
    logger.info(f"Original features range: [{sample_features.min():.2f}, {sample_features.max():.2f}]")
    logger.info(f"Normalized features range: [{normalized.min():.4f}, {normalized.max():.4f}]")

    # Report the 1h range found in the packet's own data for comparison.
    eth_1h = training_packet.multi_timeframe_data.get('ETH/USDT', {}).get('1h', [])
    if eth_1h:
        h1_prices = [
            bar[field]
            for bar in eth_1h[-24:]
            for field in ('open', 'high', 'low', 'close')
        ]
        h1_range = max(h1_prices) - min(h1_prices)
        logger.info(f"1h price range used for normalization: {h1_range:.2f}")
def generate_test_ohlcv_data(bars=200, base_price=2400.0, seed=None):
    """Generate test OHLCV data for Williams calculation.

    Args:
        bars: Number of one-second bars to generate.
        base_price: Starting price; the random walk is floored at 90% of it.
        seed: Optional seed for the random generator. Pass a value for
            reproducible prices and volumes; the default ``None`` gives
            different data on every call.

    Returns:
        numpy.ndarray of shape ``(bars, 6)`` with columns
        [timestamp (epoch seconds), open, high, low, close, volume], in
        chronological order ending one second before the time of the call.
        With ``bars <= 0`` the result is an empty ``(0, 6)`` array.
    """
    rng = np.random.default_rng(seed)
    price_floor = base_price * 0.9
    current_price = base_price
    current_time = datetime.now()

    rows = []
    for i in range(bars):
        timestamp = current_time - timedelta(seconds=bars - i)

        # Step between opens: sigma is 0.2% of the base price.
        price_change = rng.normal(0, base_price * 0.002)
        open_price = max(current_price + price_change, price_floor)

        # High/low straddle the open; close lands somewhere between them.
        high_price = open_price * (1 + abs(rng.normal(0, 0.003)))
        low_price = open_price * (1 - abs(rng.normal(0, 0.003)))
        close_price = low_price + (high_price - low_price) * rng.random()
        volume = rng.exponential(1000)

        rows.append([
            timestamp.timestamp(),
            open_price,
            high_price,
            low_price,
            close_price,
            volume,
        ])
        # The next bar continues from this bar's close.
        current_price = close_price

    if not rows:
        # Keep the column layout even when no bars were requested.
        return np.empty((0, 6))
    return np.array(rows)
if __name__ == "__main__":
    # Script entry point: run the integration test, print a summary and
    # report the outcome through the process exit status.
    import sys

    success = test_enhanced_williams_cnn()
    if success:
        print("\n✅ All tests passed! Enhanced Williams CNN integration is working.")
    else:
        print("\n❌ Some tests failed. Check logs for details.")
    # Non-zero exit on failure so shells and CI can detect it.
    sys.exit(0 if success else 1)