folder stricture reorganize
This commit is contained in:
346
tests/test_enhanced_williams_cnn.py
Normal file
346
tests/test_enhanced_williams_cnn.py
Normal file
@ -0,0 +1,346 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Enhanced Williams Market Structure with CNN Integration
|
||||
|
||||
This script demonstrates the multi-timeframe, multi-symbol CNN-enabled
|
||||
Williams Market Structure that predicts pivot points using TrainingDataPacket.
|
||||
|
||||
Features tested:
|
||||
- Multi-timeframe data integration (1s, 1m, 1h)
|
||||
- Multi-symbol support (ETH, BTC)
|
||||
- Tick data aggregation
|
||||
- 1h-based normalization strategy
|
||||
- Multi-level pivot prediction (5 levels, type + price)
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Mock TrainingDataPacket for testing
|
||||
@dataclass
|
||||
class MockTrainingDataPacket:
|
||||
"""Mock TrainingDataPacket for testing CNN integration"""
|
||||
timestamp: datetime
|
||||
symbol: str
|
||||
tick_cache: List[Dict[str, Any]]
|
||||
one_second_bars: List[Dict[str, Any]]
|
||||
multi_timeframe_data: Dict[str, List[Dict[str, Any]]]
|
||||
cnn_features: Optional[Dict[str, np.ndarray]] = None
|
||||
cnn_predictions: Optional[Dict[str, np.ndarray]] = None
|
||||
market_state: Optional[Any] = None
|
||||
universal_stream: Optional[Any] = None
|
||||
|
||||
class MockTrainingDataProvider:
|
||||
"""Mock provider that supplies TrainingDataPacket for testing"""
|
||||
|
||||
def __init__(self):
|
||||
self.training_data_buffer = []
|
||||
self._generate_mock_data()
|
||||
|
||||
def _generate_mock_data(self):
|
||||
"""Generate comprehensive mock market data"""
|
||||
current_time = datetime.now()
|
||||
|
||||
# Generate ETH data for different timeframes
|
||||
eth_1s_data = self._generate_ohlcv_data(2400.0, 900, '1s', current_time) # 15 min of 1s data
|
||||
eth_1m_data = self._generate_ohlcv_data(2400.0, 900, '1m', current_time) # 15 hours of 1m data
|
||||
eth_1h_data = self._generate_ohlcv_data(2400.0, 24, '1h', current_time) # 24 hours of 1h data
|
||||
|
||||
# Generate BTC data
|
||||
btc_1s_data = self._generate_ohlcv_data(45000.0, 300, '1s', current_time) # 5 min of 1s data
|
||||
|
||||
# Generate tick data
|
||||
tick_data = self._generate_tick_data(current_time)
|
||||
|
||||
# Create comprehensive TrainingDataPacket
|
||||
training_packet = MockTrainingDataPacket(
|
||||
timestamp=current_time,
|
||||
symbol='ETH/USDT',
|
||||
tick_cache=tick_data,
|
||||
one_second_bars=eth_1s_data[-300:], # Last 5 minutes
|
||||
multi_timeframe_data={
|
||||
'ETH/USDT': {
|
||||
'1s': eth_1s_data,
|
||||
'1m': eth_1m_data,
|
||||
'1h': eth_1h_data
|
||||
},
|
||||
'BTC/USDT': {
|
||||
'1s': btc_1s_data
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
self.training_data_buffer.append(training_packet)
|
||||
logger.info(f"Generated mock training data: {len(eth_1s_data)} 1s bars, {len(eth_1m_data)} 1m bars, {len(eth_1h_data)} 1h bars")
|
||||
logger.info(f"ETH 1h price range: {min(bar['low'] for bar in eth_1h_data):.2f} - {max(bar['high'] for bar in eth_1h_data):.2f}")
|
||||
|
||||
def _generate_ohlcv_data(self, base_price: float, count: int, timeframe: str, end_time: datetime) -> List[Dict[str, Any]]:
|
||||
"""Generate realistic OHLCV data with indicators"""
|
||||
data = []
|
||||
|
||||
# Calculate time delta based on timeframe
|
||||
if timeframe == '1s':
|
||||
delta = timedelta(seconds=1)
|
||||
elif timeframe == '1m':
|
||||
delta = timedelta(minutes=1)
|
||||
elif timeframe == '1h':
|
||||
delta = timedelta(hours=1)
|
||||
else:
|
||||
delta = timedelta(minutes=1)
|
||||
|
||||
current_price = base_price
|
||||
|
||||
for i in range(count):
|
||||
timestamp = end_time - delta * (count - i - 1)
|
||||
|
||||
# Generate realistic price movement
|
||||
price_change = np.random.normal(0, base_price * 0.001) # 0.1% volatility
|
||||
current_price = max(current_price + price_change, base_price * 0.8) # Floor at 80% of base
|
||||
|
||||
# Generate OHLCV
|
||||
open_price = current_price
|
||||
high_price = open_price * (1 + abs(np.random.normal(0, 0.002)))
|
||||
low_price = open_price * (1 - abs(np.random.normal(0, 0.002)))
|
||||
close_price = low_price + (high_price - low_price) * np.random.random()
|
||||
volume = np.random.exponential(1000)
|
||||
|
||||
current_price = close_price
|
||||
|
||||
# Calculate basic indicators (placeholders)
|
||||
sma_20 = close_price * (1 + np.random.normal(0, 0.001))
|
||||
ema_20 = close_price * (1 + np.random.normal(0, 0.0005))
|
||||
rsi_14 = 30 + np.random.random() * 40 # RSI between 30-70
|
||||
macd = np.random.normal(0, 0.1)
|
||||
bb_upper = high_price * 1.02
|
||||
|
||||
bar = {
|
||||
'timestamp': timestamp,
|
||||
'open': open_price,
|
||||
'high': high_price,
|
||||
'low': low_price,
|
||||
'close': close_price,
|
||||
'volume': volume,
|
||||
'sma_20': sma_20,
|
||||
'ema_20': ema_20,
|
||||
'rsi_14': rsi_14,
|
||||
'macd': macd,
|
||||
'bb_upper': bb_upper
|
||||
}
|
||||
data.append(bar)
|
||||
|
||||
return data
|
||||
|
||||
def _generate_tick_data(self, end_time: datetime) -> List[Dict[str, Any]]:
|
||||
"""Generate realistic tick data for last 5 minutes"""
|
||||
ticks = []
|
||||
|
||||
# Generate ETH ticks
|
||||
for i in range(300): # 5 minutes * 60 seconds
|
||||
tick_time = end_time - timedelta(seconds=300 - i)
|
||||
|
||||
# 2-5 ticks per second
|
||||
ticks_per_second = np.random.randint(2, 6)
|
||||
|
||||
for j in range(ticks_per_second):
|
||||
tick = {
|
||||
'symbol': 'ETH/USDT',
|
||||
'timestamp': tick_time + timedelta(milliseconds=j * 200),
|
||||
'price': 2400.0 + np.random.normal(0, 5),
|
||||
'volume': np.random.exponential(0.5),
|
||||
'quantity': np.random.exponential(1.0),
|
||||
'side': 'buy' if np.random.random() > 0.5 else 'sell'
|
||||
}
|
||||
ticks.append(tick)
|
||||
|
||||
# Generate BTC ticks
|
||||
for i in range(300): # 5 minutes * 60 seconds
|
||||
tick_time = end_time - timedelta(seconds=300 - i)
|
||||
|
||||
ticks_per_second = np.random.randint(1, 4)
|
||||
|
||||
for j in range(ticks_per_second):
|
||||
tick = {
|
||||
'symbol': 'BTC/USDT',
|
||||
'timestamp': tick_time + timedelta(milliseconds=j * 300),
|
||||
'price': 45000.0 + np.random.normal(0, 100),
|
||||
'volume': np.random.exponential(0.1),
|
||||
'quantity': np.random.exponential(0.5),
|
||||
'side': 'buy' if np.random.random() > 0.5 else 'sell'
|
||||
}
|
||||
ticks.append(tick)
|
||||
|
||||
return ticks
|
||||
|
||||
def get_latest_training_data(self):
|
||||
"""Return the latest TrainingDataPacket"""
|
||||
return self.training_data_buffer[-1] if self.training_data_buffer else None
|
||||
|
||||
|
||||
def test_enhanced_williams_cnn():
|
||||
"""Test the enhanced Williams Market Structure with CNN integration"""
|
||||
try:
|
||||
from training.williams_market_structure import WilliamsMarketStructure, SwingType
|
||||
|
||||
logger.info("=" * 80)
|
||||
logger.info("TESTING ENHANCED WILLIAMS MARKET STRUCTURE WITH CNN INTEGRATION")
|
||||
logger.info("=" * 80)
|
||||
|
||||
# Create mock data provider
|
||||
data_provider = MockTrainingDataProvider()
|
||||
|
||||
# Initialize Williams Market Structure with CNN
|
||||
williams = WilliamsMarketStructure(
|
||||
swing_strengths=[2, 3, 5], # Reduced for testing
|
||||
cnn_input_shape=(900, 50), # 900 timesteps, 50 features
|
||||
cnn_output_size=10, # 5 levels * 2 outputs (type + price)
|
||||
enable_cnn_feature=True, # Enable CNN features
|
||||
training_data_provider=data_provider
|
||||
)
|
||||
|
||||
logger.info(f"CNN enabled: {williams.enable_cnn_feature}")
|
||||
logger.info(f"Training data provider available: {williams.training_data_provider is not None}")
|
||||
|
||||
# Generate test OHLCV data for Williams calculation
|
||||
test_ohlcv = generate_test_ohlcv_data()
|
||||
logger.info(f"Generated test OHLCV data: {len(test_ohlcv)} bars")
|
||||
|
||||
# Test Williams calculation with CNN integration
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("RUNNING WILLIAMS PIVOT CALCULATION WITH CNN INTEGRATION")
|
||||
logger.info("=" * 60)
|
||||
|
||||
structure_levels = williams.calculate_recursive_pivot_points(test_ohlcv)
|
||||
|
||||
# Display results
|
||||
logger.info(f"\nWilliams Structure Analysis Results:")
|
||||
logger.info(f"Calculated levels: {len(structure_levels)}")
|
||||
|
||||
for level_key, level_data in structure_levels.items():
|
||||
swing_count = len(level_data.swing_points)
|
||||
logger.info(f"{level_key}: {swing_count} swing points, "
|
||||
f"trend: {level_data.trend_analysis.direction.value}, "
|
||||
f"bias: {level_data.current_bias.value}")
|
||||
|
||||
if swing_count > 0:
|
||||
latest_swing = level_data.swing_points[-1]
|
||||
logger.info(f" Latest swing: {latest_swing.swing_type.name} @ {latest_swing.price:.2f}")
|
||||
|
||||
# Test CNN input preparation directly
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("TESTING CNN INPUT PREPARATION")
|
||||
logger.info("=" * 60)
|
||||
|
||||
if williams.enable_cnn_feature and structure_levels['level_0'].swing_points:
|
||||
test_pivot = structure_levels['level_0'].swing_points[-1]
|
||||
|
||||
logger.info(f"Testing CNN input for pivot: {test_pivot.swing_type.name} @ {test_pivot.price:.2f}")
|
||||
|
||||
# Test input preparation
|
||||
cnn_input = williams._prepare_cnn_input(
|
||||
current_pivot=test_pivot,
|
||||
ohlcv_data_context=test_ohlcv,
|
||||
previous_pivot_details=None
|
||||
)
|
||||
|
||||
logger.info(f"CNN input shape: {cnn_input.shape}")
|
||||
logger.info(f"CNN input range: [{cnn_input.min():.4f}, {cnn_input.max():.4f}]")
|
||||
logger.info(f"CNN input mean: {cnn_input.mean():.4f}, std: {cnn_input.std():.4f}")
|
||||
|
||||
# Test ground truth preparation
|
||||
if len(structure_levels['level_0'].swing_points) >= 2:
|
||||
prev_pivot = structure_levels['level_0'].swing_points[-2]
|
||||
current_pivot = structure_levels['level_0'].swing_points[-1]
|
||||
|
||||
prev_details = {'pivot': prev_pivot}
|
||||
ground_truth = williams._get_cnn_ground_truth(prev_details, current_pivot)
|
||||
|
||||
logger.info(f"Ground truth shape: {ground_truth.shape}")
|
||||
logger.info(f"Ground truth (first 4 values): {ground_truth[:4]}")
|
||||
logger.info(f"Level 0 prediction: type={ground_truth[0]:.2f}, price={ground_truth[1]:.4f}")
|
||||
|
||||
# Test normalization
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("TESTING 1H-BASED NORMALIZATION")
|
||||
logger.info("=" * 60)
|
||||
|
||||
training_packet = data_provider.get_latest_training_data()
|
||||
if training_packet:
|
||||
# Test normalization with sample data
|
||||
sample_features = np.random.normal(2400, 50, (100, 10)) # ETH-like prices
|
||||
|
||||
normalized = williams._normalize_features_by_1h_range(sample_features, training_packet)
|
||||
|
||||
logger.info(f"Original features range: [{sample_features.min():.2f}, {sample_features.max():.2f}]")
|
||||
logger.info(f"Normalized features range: [{normalized.min():.4f}, {normalized.max():.4f}]")
|
||||
|
||||
# Check if 1h data is being used for normalization
|
||||
eth_1h = training_packet.multi_timeframe_data.get('ETH/USDT', {}).get('1h', [])
|
||||
if eth_1h:
|
||||
h1_prices = []
|
||||
for bar in eth_1h[-24:]:
|
||||
h1_prices.extend([bar['open'], bar['high'], bar['low'], bar['close']])
|
||||
h1_range = max(h1_prices) - min(h1_prices)
|
||||
logger.info(f"1h price range used for normalization: {h1_range:.2f}")
|
||||
|
||||
logger.info("\n" + "=" * 80)
|
||||
logger.info("ENHANCED WILLIAMS CNN INTEGRATION TEST COMPLETED SUCCESSFULLY")
|
||||
logger.info("=" * 80)
|
||||
|
||||
return True
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(f"Import error - some dependencies missing: {e}")
|
||||
logger.info("This is expected if TensorFlow or other dependencies are not installed")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Test failed with error: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
|
||||
def generate_test_ohlcv_data(bars=200, base_price=2400.0):
|
||||
"""Generate test OHLCV data for Williams calculation"""
|
||||
data = []
|
||||
current_price = base_price
|
||||
current_time = datetime.now()
|
||||
|
||||
for i in range(bars):
|
||||
timestamp = current_time - timedelta(seconds=bars - i)
|
||||
|
||||
# Generate price movement
|
||||
price_change = np.random.normal(0, base_price * 0.002)
|
||||
current_price = max(current_price + price_change, base_price * 0.9)
|
||||
|
||||
open_price = current_price
|
||||
high_price = open_price * (1 + abs(np.random.normal(0, 0.003)))
|
||||
low_price = open_price * (1 - abs(np.random.normal(0, 0.003)))
|
||||
close_price = low_price + (high_price - low_price) * np.random.random()
|
||||
volume = np.random.exponential(1000)
|
||||
|
||||
current_price = close_price
|
||||
|
||||
bar = [
|
||||
timestamp.timestamp(),
|
||||
open_price,
|
||||
high_price,
|
||||
low_price,
|
||||
close_price,
|
||||
volume
|
||||
]
|
||||
data.append(bar)
|
||||
|
||||
return np.array(data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = test_enhanced_williams_cnn()
|
||||
if success:
|
||||
print("\n✅ All tests passed! Enhanced Williams CNN integration is working.")
|
||||
else:
|
||||
print("\n❌ Some tests failed. Check logs for details.")
|
Reference in New Issue
Block a user