346 lines
14 KiB
Python
346 lines
14 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Test Enhanced Williams Market Structure with CNN Integration
|
|
|
|
This script demonstrates the multi-timeframe, multi-symbol CNN-enabled
|
|
Williams Market Structure that predicts pivot points using TrainingDataPacket.
|
|
|
|
Features tested:
|
|
- Multi-timeframe data integration (1s, 1m, 1h)
|
|
- Multi-symbol support (ETH, BTC)
|
|
- Tick data aggregation
|
|
- 1h-based normalization strategy
|
|
- Multi-level pivot prediction (5 levels, type + price)
|
|
"""
|
|
|
|
import numpy as np
|
|
import logging
|
|
from datetime import datetime, timedelta
|
|
from typing import Dict, List, Optional, Any
|
|
from dataclasses import dataclass
|
|
|
|
# Send this script's log records to the root handler at INFO level, each line
# prefixed with timestamp, logger name and level.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used by the mock provider and the test driver in this file.
logger = logging.getLogger(__name__)
|
|
|
|
# Mock TrainingDataPacket for testing
|
|
@dataclass
class MockTrainingDataPacket:
    """Mock TrainingDataPacket for testing CNN integration.

    Only the fields populated by ``MockTrainingDataProvider`` in this file
    carry data; the CNN/market-state fields default to ``None``.
    """

    # Packet creation time (the provider stamps it with the end of the generated window).
    timestamp: datetime
    # Primary symbol of the packet, e.g. 'ETH/USDT'.
    symbol: str
    # Flat list of tick dicts with keys: symbol, timestamp, price, volume, quantity, side.
    tick_cache: List[Dict[str, Any]]
    # Most recent 1s OHLCV bar dicts for the primary symbol.
    one_second_bars: List[Dict[str, Any]]
    # Nested mapping: symbol -> timeframe ('1s', '1m', '1h', ...) -> list of bar dicts.
    multi_timeframe_data: Dict[str, Dict[str, List[Dict[str, Any]]]]
    # Optional CNN artefacts keyed by name; not populated by the mock provider.
    cnn_features: Optional[Dict[str, np.ndarray]] = None
    cnn_predictions: Optional[Dict[str, np.ndarray]] = None
    # Opaque extra state slots; left as None in this test.
    market_state: Optional[Any] = None
    universal_stream: Optional[Any] = None
|
|
|
|
class MockTrainingDataProvider:
    """Mock provider that supplies a TrainingDataPacket for testing.

    On construction it synthesizes one packet of random ETH/BTC market data
    (OHLCV bars for several timeframes plus raw ticks) and appends it to
    ``training_data_buffer``.

    Args:
        seed: Optional seed for reproducible data. When ``None`` (default) the
            global ``np.random`` state is used, so a caller's
            ``np.random.seed(...)`` still applies as before.
    """

    # Timeframe suffix (e.g. the 'm' in '15m') -> timedelta keyword argument.
    _TIMEFRAME_UNITS = {'s': 'seconds', 'm': 'minutes', 'h': 'hours', 'd': 'days'}

    def __init__(self, seed: Optional[int] = None):
        self._rng = self._make_rng(seed)
        # Generated packets, oldest first; get_latest_training_data() returns the last.
        self.training_data_buffer = []
        self._generate_mock_data()

    @staticmethod
    def _make_rng(seed: Optional[int] = None):
        """Return the random source: global ``np.random`` or a seeded ``RandomState``."""
        return np.random if seed is None else np.random.RandomState(seed)

    @classmethod
    def _timeframe_to_delta(cls, timeframe: str) -> timedelta:
        """Convert a timeframe string such as '1s', '15m', '4h' or '1d' to a timedelta.

        Strings that cannot be parsed fall back to one minute, which was the
        behaviour for unknown timeframes before this parser existed.
        """
        unit = cls._TIMEFRAME_UNITS.get(timeframe[-1:])
        amount = timeframe[:-1]
        if unit is None or not amount.isdecimal() or int(amount) <= 0:
            return timedelta(minutes=1)
        return timedelta(**{unit: int(amount)})

    def _generate_mock_data(self):
        """Build one ETH-centred packet (ETH 1s/1m/1h bars, BTC 1s bars, ticks) and buffer it."""
        current_time = datetime.now()

        # ETH bars for each timeframe, all ending at current_time.
        eth_1s_data = self._generate_ohlcv_data(2400.0, 900, '1s', current_time)  # 15 min of 1s data
        eth_1m_data = self._generate_ohlcv_data(2400.0, 900, '1m', current_time)  # 15 hours of 1m data
        eth_1h_data = self._generate_ohlcv_data(2400.0, 24, '1h', current_time)   # 24 hours of 1h data

        # BTC bars
        btc_1s_data = self._generate_ohlcv_data(45000.0, 300, '1s', current_time)  # 5 min of 1s data

        # Ticks for both symbols share one flat list.
        tick_data = self._generate_tick_data(current_time)

        training_packet = MockTrainingDataPacket(
            timestamp=current_time,
            symbol='ETH/USDT',
            tick_cache=tick_data,
            one_second_bars=eth_1s_data[-300:],  # Last 5 minutes
            multi_timeframe_data={
                'ETH/USDT': {
                    '1s': eth_1s_data,
                    '1m': eth_1m_data,
                    '1h': eth_1h_data,
                },
                'BTC/USDT': {
                    '1s': btc_1s_data,
                },
            },
        )

        self.training_data_buffer.append(training_packet)
        logger.info(f"Generated mock training data: {len(eth_1s_data)} 1s bars, {len(eth_1m_data)} 1m bars, {len(eth_1h_data)} 1h bars")
        logger.info(f"ETH 1h price range: {min(bar['low'] for bar in eth_1h_data):.2f} - {max(bar['high'] for bar in eth_1h_data):.2f}")

    def _generate_ohlcv_data(self, base_price: float, count: int, timeframe: str, end_time: datetime) -> List[Dict[str, Any]]:
        """Generate ``count`` random-walk OHLCV bar dicts ending at ``end_time``.

        Bars are ordered oldest first and spaced by ``timeframe``; the last bar
        is stamped exactly ``end_time``. Each dict has keys timestamp, open,
        high, low, close, volume and placeholder indicators sma_20, ema_20,
        rsi_14, macd, bb_upper (random values, not computed from the series).
        """
        rng = self._rng
        delta = self._timeframe_to_delta(timeframe)
        data = []
        current_price = base_price

        for i in range(count):
            timestamp = end_time - delta * (count - i - 1)

            # Random-walk step with 0.1% volatility, floored at 80% of base price.
            price_change = rng.normal(0, base_price * 0.001)
            current_price = max(current_price + price_change, base_price * 0.8)

            # High/low bracket the open; close falls uniformly inside [low, high).
            open_price = current_price
            high_price = open_price * (1 + abs(rng.normal(0, 0.002)))
            low_price = open_price * (1 - abs(rng.normal(0, 0.002)))
            close_price = low_price + (high_price - low_price) * rng.random()
            volume = rng.exponential(1000)

            # Next bar opens from this bar's close.
            current_price = close_price

            data.append({
                'timestamp': timestamp,
                'open': open_price,
                'high': high_price,
                'low': low_price,
                'close': close_price,
                'volume': volume,
                # Placeholder indicators: random values near the close, not real calculations.
                'sma_20': close_price * (1 + rng.normal(0, 0.001)),
                'ema_20': close_price * (1 + rng.normal(0, 0.0005)),
                'rsi_14': 30 + rng.random() * 40,  # RSI between 30-70
                'macd': rng.normal(0, 0.1),
                'bb_upper': high_price * 1.02,
            })

        return data

    def _generate_tick_data(self, end_time: datetime, seconds: int = 300) -> List[Dict[str, Any]]:
        """Generate random ETH and BTC ticks for the ``seconds`` seconds before ``end_time``.

        ETH ticks come first, then BTC ticks; within a symbol they are in time order.
        """
        ticks = self._generate_symbol_ticks(
            'ETH/USDT', end_time, seconds,
            base_price=2400.0, price_std=5.0,
            volume_scale=0.5, quantity_scale=1.0,
            tick_range=(2, 6), spacing_ms=200,  # 2-5 ticks per second
        )
        ticks += self._generate_symbol_ticks(
            'BTC/USDT', end_time, seconds,
            base_price=45000.0, price_std=100.0,
            volume_scale=0.1, quantity_scale=0.5,
            tick_range=(1, 4), spacing_ms=300,  # 1-3 ticks per second
        )
        return ticks

    def _generate_symbol_ticks(self, symbol: str, end_time: datetime, seconds: int, *,
                               base_price: float, price_std: float,
                               volume_scale: float, quantity_scale: float,
                               tick_range: tuple, spacing_ms: int) -> List[Dict[str, Any]]:
        """Generate ticks for one symbol: a random count per second, spaced ``spacing_ms`` apart."""
        rng = self._rng
        ticks = []
        for i in range(seconds):
            second_start = end_time - timedelta(seconds=seconds - i)
            # randint upper bound is exclusive.
            for j in range(rng.randint(*tick_range)):
                ticks.append({
                    'symbol': symbol,
                    'timestamp': second_start + timedelta(milliseconds=j * spacing_ms),
                    'price': base_price + rng.normal(0, price_std),
                    'volume': rng.exponential(volume_scale),
                    'quantity': rng.exponential(quantity_scale),
                    'side': 'buy' if rng.random() > 0.5 else 'sell',
                })
        return ticks

    def get_latest_training_data(self):
        """Return the most recently buffered TrainingDataPacket, or None if the buffer is empty."""
        return self.training_data_buffer[-1] if self.training_data_buffer else None
|
|
|
|
|
|
def _log_structure_levels(structure_levels) -> None:
    """Log swing-point count, trend direction and bias for every computed pivot level."""
    logger.info("\nWilliams Structure Analysis Results:")
    logger.info(f"Calculated levels: {len(structure_levels)}")

    for level_key, level_data in structure_levels.items():
        swing_count = len(level_data.swing_points)
        logger.info(f"{level_key}: {swing_count} swing points, "
                    f"trend: {level_data.trend_analysis.direction.value}, "
                    f"bias: {level_data.current_bias.value}")

        if swing_count > 0:
            latest_swing = level_data.swing_points[-1]
            logger.info(f" Latest swing: {latest_swing.swing_type.name} @ {latest_swing.price:.2f}")


def _check_cnn_input_preparation(williams, structure_levels, test_ohlcv) -> None:
    """Exercise CNN input and ground-truth preparation on the latest level-0 pivots."""
    logger.info("\n" + "=" * 60)
    logger.info("TESTING CNN INPUT PREPARATION")
    logger.info("=" * 60)

    # Tolerate a missing 'level_0' instead of raising KeyError, which the caller
    # would otherwise report as an unrelated generic failure.
    level_0 = structure_levels.get('level_0')
    swing_points = level_0.swing_points if level_0 is not None else []
    if not (williams.enable_cnn_feature and swing_points):
        logger.info("Skipping CNN input preparation: CNN disabled or no level-0 swing points")
        return

    test_pivot = swing_points[-1]
    logger.info(f"Testing CNN input for pivot: {test_pivot.swing_type.name} @ {test_pivot.price:.2f}")

    cnn_input = williams._prepare_cnn_input(
        current_pivot=test_pivot,
        ohlcv_data_context=test_ohlcv,
        previous_pivot_details=None,
    )
    logger.info(f"CNN input shape: {cnn_input.shape}")
    logger.info(f"CNN input range: [{cnn_input.min():.4f}, {cnn_input.max():.4f}]")
    logger.info(f"CNN input mean: {cnn_input.mean():.4f}, std: {cnn_input.std():.4f}")

    # Ground truth needs a previous pivot to pair with the current one.
    if len(swing_points) >= 2:
        prev_pivot, current_pivot = swing_points[-2], swing_points[-1]
        ground_truth = williams._get_cnn_ground_truth({'pivot': prev_pivot}, current_pivot)

        logger.info(f"Ground truth shape: {ground_truth.shape}")
        logger.info(f"Ground truth (first 4 values): {ground_truth[:4]}")
        logger.info(f"Level 0 prediction: type={ground_truth[0]:.2f}, price={ground_truth[1]:.4f}")


def _check_1h_normalization(williams, data_provider) -> None:
    """Normalize synthetic ETH-like features via the 1h-range strategy and log the ranges."""
    logger.info("\n" + "=" * 60)
    logger.info("TESTING 1H-BASED NORMALIZATION")
    logger.info("=" * 60)

    training_packet = data_provider.get_latest_training_data()
    if training_packet is None:
        return

    sample_features = np.random.normal(2400, 50, (100, 10))  # ETH-like prices
    normalized = williams._normalize_features_by_1h_range(sample_features, training_packet)

    logger.info(f"Original features range: [{sample_features.min():.2f}, {sample_features.max():.2f}]")
    logger.info(f"Normalized features range: [{normalized.min():.4f}, {normalized.max():.4f}]")

    # Recompute the 1h OHLC range locally so it can be compared with the normalization.
    eth_1h = training_packet.multi_timeframe_data.get('ETH/USDT', {}).get('1h', [])
    if eth_1h:
        h1_prices = [bar[key] for bar in eth_1h[-24:] for key in ('open', 'high', 'low', 'close')]
        h1_range = max(h1_prices) - min(h1_prices)
        logger.info(f"1h price range used for normalization: {h1_range:.2f}")


def test_enhanced_williams_cnn():
    """Run the end-to-end Williams Market Structure + CNN integration check.

    Returns:
        True if every stage ran; False if a dependency is missing or any
        stage raised (the error is logged).

    NOTE(review): despite the ``test_`` prefix this returns a bool and swallows
    exceptions, so under pytest it never fails (and pytest warns about a
    non-None return). The bool is kept because the ``__main__`` block uses it.
    """
    try:
        from training.williams_market_structure import WilliamsMarketStructure

        logger.info("=" * 80)
        logger.info("TESTING ENHANCED WILLIAMS MARKET STRUCTURE WITH CNN INTEGRATION")
        logger.info("=" * 80)

        data_provider = MockTrainingDataProvider()

        williams = WilliamsMarketStructure(
            swing_strengths=[2, 3, 5],      # Reduced for testing
            cnn_input_shape=(900, 50),      # 900 timesteps, 50 features
            cnn_output_size=10,             # 5 levels * 2 outputs (type + price)
            enable_cnn_feature=True,        # Enable CNN features
            training_data_provider=data_provider,
        )

        logger.info(f"CNN enabled: {williams.enable_cnn_feature}")
        logger.info(f"Training data provider available: {williams.training_data_provider is not None}")

        test_ohlcv = generate_test_ohlcv_data()
        logger.info(f"Generated test OHLCV data: {len(test_ohlcv)} bars")

        logger.info("\n" + "=" * 60)
        logger.info("RUNNING WILLIAMS PIVOT CALCULATION WITH CNN INTEGRATION")
        logger.info("=" * 60)

        structure_levels = williams.calculate_recursive_pivot_points(test_ohlcv)

        _log_structure_levels(structure_levels)
        _check_cnn_input_preparation(williams, structure_levels, test_ohlcv)
        _check_1h_normalization(williams, data_provider)

        logger.info("\n" + "=" * 80)
        logger.info("ENHANCED WILLIAMS CNN INTEGRATION TEST COMPLETED SUCCESSFULLY")
        logger.info("=" * 80)

        return True

    except ImportError as e:
        # Also covers ImportErrors raised lazily inside WilliamsMarketStructure (e.g. TensorFlow).
        logger.error(f"Import error - some dependencies missing: {e}")
        logger.info("This is expected if TensorFlow or other dependencies are not installed")
        return False

    except Exception as e:
        logger.error(f"Test failed with error: {e}", exc_info=True)
        return False
|
|
|
|
|
|
def generate_test_ohlcv_data(bars=200, base_price=2400.0, seed=None):
    """Generate a random-walk OHLCV array for the Williams calculation.

    Args:
        bars: Number of one-second bars; the newest is stamped one second
            before the current time. Values <= 0 yield an empty array.
        base_price: Starting price; the walk is floored at 90% of it.
        seed: Optional seed for reproducible prices. ``None`` uses the global
            ``np.random`` state.

    Returns:
        float ndarray of shape ``(bars, 6)`` with columns
        ``[unix_timestamp, open, high, low, close, volume]``. An empty result
        keeps the 6-column shape ``(0, 6)`` so column slicing still works.
    """
    rng = np.random if seed is None else np.random.RandomState(seed)
    current_time = datetime.now()
    current_price = base_price
    rows = []

    for i in range(bars):
        timestamp = current_time - timedelta(seconds=bars - i)

        # Random-walk step (0.2% volatility), floored at 90% of base price.
        price_change = rng.normal(0, base_price * 0.002)
        current_price = max(current_price + price_change, base_price * 0.9)

        # High/low bracket the open; close falls uniformly inside [low, high).
        open_price = current_price
        high_price = open_price * (1 + abs(rng.normal(0, 0.003)))
        low_price = open_price * (1 - abs(rng.normal(0, 0.003)))
        close_price = low_price + (high_price - low_price) * rng.random()
        volume = rng.exponential(1000)

        # Next bar opens from this bar's close.
        current_price = close_price

        rows.append([timestamp.timestamp(), open_price, high_price, low_price, close_price, volume])

    # reshape keeps a 2-D (0, 6) result when no bars were generated.
    return np.array(rows, dtype=float).reshape(-1, 6)
|
|
|
|
|
|
if __name__ == "__main__":
    import sys

    success = test_enhanced_williams_cnn()
    if success:
        print("\n✅ All tests passed! Enhanced Williams CNN integration is working.")
    else:
        print("\n❌ Some tests failed. Check logs for details.")
    # Report the outcome through the exit status so CI/shell callers can detect failure.
    sys.exit(0 if success else 1)