tter pivots
This commit is contained in:
parent
1130e02f35
commit
75dbac1761
160
WILLIAMS_CNN_PIVOT_INTEGRATION_SUMMARY.md
Normal file
160
WILLIAMS_CNN_PIVOT_INTEGRATION_SUMMARY.md
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
# Williams Market Structure CNN Integration Summary
|
||||||
|
|
||||||
|
## 🎯 Overview
|
||||||
|
|
||||||
|
The Williams Market Structure has been enhanced with CNN-based pivot prediction capabilities, enabling real-time training and prediction at each detected pivot point using multi-timeframe, multi-symbol data.
|
||||||
|
|
||||||
|
## ✅ Key Features Implemented
|
||||||
|
|
||||||
|
### 🔄 **Recursive Pivot Structure**
|
||||||
|
- **Level 0**: Raw OHLCV price data → Swing points using multiple strengths [2, 3, 5, 8, 13]
|
||||||
|
- **Level 1**: Level 0 pivot points → Treated as "price bars" for higher-level pivots
|
||||||
|
- **Level 2-4**: Recursive application on previous level's pivots
|
||||||
|
- **True Recursion**: Each level builds on the previous level's pivot points
|
||||||
|
|
||||||
|
### 🧠 **CNN Integration Architecture**
|
||||||
|
```
|
||||||
|
Each Pivot Detection Triggers:
|
||||||
|
1. Train CNN on previous pivot (features) → current pivot (ground truth)
|
||||||
|
2. Predict next pivot using current pivot features
|
||||||
|
3. Store current features for next training cycle
|
||||||
|
```
|
||||||
|
|
||||||
|
### 📊 **Multi-Timeframe Input Features**
|
||||||
|
- **ETH Primary Symbol**:
|
||||||
|
- 900 x 1s bars with indicators (10 features)
|
||||||
|
- 900 x 1m bars with indicators (10 features)
|
||||||
|
- 900 x 1h bars with indicators (10 features)
|
||||||
|
- 5 minutes of tick-derived features (10 features)
|
||||||
|
- **BTC Reference Symbol**:
|
||||||
|
- 5 minutes of tick-derived features (4 features)
|
||||||
|
- **Pivot Context**: Recent pivot characteristics (3 features)
|
||||||
|
- **Chart Labels**: Symbol/timeframe identification (3 features)
|
||||||
|
- **Total**: 900 timesteps × 50 features
|
||||||
|
|
||||||
|
### 🎯 **Multi-Level Output Prediction**
|
||||||
|
- **10 Outputs Total**: 5 Williams levels × (type + price)
|
||||||
|
- Level 0-4: [swing_type (0=LOW, 1=HIGH), normalized_price]
|
||||||
|
- Allows prediction across all recursive levels simultaneously
|
||||||
|
|
||||||
|
### 📐 **Smart Normalization Strategy**
|
||||||
|
- **Data Flow**: Keep actual values throughout pipeline for validation
|
||||||
|
- **Final Step**: Normalize using 1h timeframe min/max range
|
||||||
|
- **Cross-Timeframe Preservation**: Maintains relationships between different timeframes
|
||||||
|
- **Price Features**: Normalized with 1h range
|
||||||
|
- **Non-Price Features**: Feature-wise normalization (indicators, counts, etc.)
|
||||||
|
|
||||||
|
## 🔧 **Integration with TrainingDataPacket**
|
||||||
|
|
||||||
|
Successfully leverages existing `TrainingDataPacket` from `core/unified_data_stream.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@dataclass
|
||||||
|
class TrainingDataPacket:
|
||||||
|
timestamp: datetime
|
||||||
|
symbol: str
|
||||||
|
tick_cache: List[Dict[str, Any]] # ✅ Used for tick features
|
||||||
|
one_second_bars: List[Dict[str, Any]] # ✅ Used for 1s data
|
||||||
|
multi_timeframe_data: Dict[str, List[Dict[str, Any]]] # ✅ Used for 1m, 1h data
|
||||||
|
cnn_features: Optional[Dict[str, np.ndarray]] # ✅ Populated by Williams
|
||||||
|
cnn_predictions: Optional[Dict[str, np.ndarray]] # ✅ Populated by Williams
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🚀 **CNN Training Flow**
|
||||||
|
|
||||||
|
### **At Each Pivot Point Detection:**
|
||||||
|
|
||||||
|
1. **Training Phase** (if previous pivot exists):
|
||||||
|
```python
|
||||||
|
X_train = previous_pivot_features # (900, 50)
|
||||||
|
y_train = current_actual_pivot # (10,) for all levels
|
||||||
|
model.fit(X_train, y_train, epochs=1) # Online learning
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Prediction Phase**:
|
||||||
|
```python
|
||||||
|
X_predict = current_pivot_features # (900, 50)
|
||||||
|
y_predict = model.predict(X_predict) # (10,) predictions for all levels
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **State Management**:
|
||||||
|
```python
|
||||||
|
previous_pivot_details = {
|
||||||
|
'features': X_predict,
|
||||||
|
'pivot': current_pivot_object
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🛠 **Implementation Status**
|
||||||
|
|
||||||
|
### ✅ **Completed Components**
|
||||||
|
- [x] Recursive Williams pivot calculation (5 levels)
|
||||||
|
- [x] CNN integration hooks at each pivot detection
|
||||||
|
- [x] Multi-timeframe feature extraction from TrainingDataPacket
|
||||||
|
- [x] 1h-based normalization strategy
|
||||||
|
- [x] Multi-level output prediction (10 outputs)
|
||||||
|
- [x] Online learning with single-step training
|
||||||
|
- [x] Dashboard integration with proper diagnostics
|
||||||
|
- [x] Comprehensive test suite
|
||||||
|
|
||||||
|
### ⚠ **Current Limitations**
|
||||||
|
- CNN disabled due to TensorFlow dependencies not installed
|
||||||
|
- Placeholder technical indicators (TODO: Add real SMA, EMA, RSI, MACD, etc.)
|
||||||
|
- Higher-level ground truth uses simplified logic (needs full Williams context)
|
||||||
|
|
||||||
|
### 🔄 **Real-Time Dashboard Integration**
|
||||||
|
|
||||||
|
Fixed dashboard Williams integration:
|
||||||
|
- **Reduced data requirement**: 20 bars minimum (from 50)
|
||||||
|
- **Proper configuration**: Uses swing_strengths=[2, 3, 5]
|
||||||
|
- **Enhanced diagnostics**: Data quality validation and pivot detection logging
|
||||||
|
- **Consistent timezone handling**: Proper timestamp conversion for pivot display
|
||||||
|
|
||||||
|
## 📈 **Performance Characteristics**
|
||||||
|
|
||||||
|
### **Pivot Detection Performance** (from diagnostics):
|
||||||
|
- ✅ Clear test patterns: Successfully detects obvious pivot points
|
||||||
|
- ✅ Realistic data: Handles real market volatility and timing
|
||||||
|
- ✅ Multi-level recursion: Properly builds higher levels from lower levels
|
||||||
|
|
||||||
|
### **CNN Training Frequency**:
|
||||||
|
- **Level 0**: Most frequent (every raw price pivot)
|
||||||
|
- **Level 1-4**: Less frequent (requires sufficient lower-level pivots)
|
||||||
|
- **Online Learning**: Single epoch per pivot for real-time adaptation
|
||||||
|
|
||||||
|
## 🎓 **Usage Example**
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Initialize Williams with CNN integration
|
||||||
|
williams = WilliamsMarketStructure(
|
||||||
|
swing_strengths=[2, 3, 5, 8, 13],
|
||||||
|
cnn_input_shape=(900, 50), # 900 timesteps, 50 features
|
||||||
|
cnn_output_size=10, # 5 levels × 2 outputs
|
||||||
|
enable_cnn_feature=True,
|
||||||
|
training_data_provider=data_stream # TrainingDataPacket provider
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate pivots (automatically triggers CNN training/prediction)
|
||||||
|
structure_levels = williams.calculate_recursive_pivot_points(ohlcv_data)
|
||||||
|
|
||||||
|
# Extract RL features (250 features for reinforcement learning)
|
||||||
|
rl_features = williams.extract_features_for_rl(structure_levels)
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🔮 **Next Steps**
|
||||||
|
|
||||||
|
1. **Install TensorFlow**: Enable CNN functionality
|
||||||
|
2. **Add Real Indicators**: Replace placeholder technical indicators
|
||||||
|
3. **Enhanced Ground Truth**: Implement proper multi-level pivot relationships
|
||||||
|
4. **Model Persistence**: Save/load trained CNN models
|
||||||
|
5. **Performance Metrics**: Track CNN prediction accuracy over time
|
||||||
|
|
||||||
|
## 📊 **Key Benefits**
|
||||||
|
|
||||||
|
- **Real-Time Learning**: CNN adapts to market conditions at each pivot
|
||||||
|
- **Multi-Scale Analysis**: Captures patterns across 5 recursive levels
|
||||||
|
- **Rich Context**: 50 features per timestep covering multiple timeframes and symbols
|
||||||
|
- **Consistent Data Flow**: Leverages existing TrainingDataPacket infrastructure
|
||||||
|
- **Market Structure Awareness**: Predictions based on Williams methodology
|
||||||
|
|
||||||
|
This implementation provides a robust foundation for CNN-enhanced pivot prediction while maintaining the proven Williams Market Structure methodology.
|
File diff suppressed because it is too large
Load Diff
@ -34,3 +34,15 @@ we will have 2 types of pivot points:
|
|||||||
theese pivot points will define the trend direction and the trend strength.
|
theese pivot points will define the trend direction and the trend strength.
|
||||||
|
|
||||||
level 2 pivot should not use different (bigger ) price timeframe, but should use the level1 pivot points as candles instead. so a level 2 low pivot is a when a level 1 pivot low is surrownded by higher level 1 pibot lows
|
level 2 pivot should not use different (bigger ) price timeframe, but should use the level1 pivot points as candles instead. so a level 2 low pivot is a when a level 1 pivot low is surrownded by higher level 1 pibot lows
|
||||||
|
----
|
||||||
|
input should be multitiframe and multi symbol timeseries with the label of the "chart" included, so the model knows what th esecondary timeseries is. So
|
||||||
|
primary symbol (that we trade, now ETC):
|
||||||
|
- 5 min of raw ticks data
|
||||||
|
- 900 of 1s timeseries with common indicators
|
||||||
|
- 900 of 1m and 900 of 1h with indicators
|
||||||
|
- all the available pivot points (multiple levels)
|
||||||
|
- one additional reference symbol (BTC) - 5 min ot ticks
|
||||||
|
if there are no ticks, we bstitute them with 1s or lowest ohclv data.
|
||||||
|
this is my idea, but I am open to improvement suggestions.
|
||||||
|
output of the CNN model should be the next pibot point in each level
|
||||||
|
course, data must be normalized to the max and min of the highest timeframe, so the relations between different timeframes stay the same
|
||||||
|
346
test_enhanced_williams_cnn.py
Normal file
346
test_enhanced_williams_cnn.py
Normal file
@ -0,0 +1,346 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Test Enhanced Williams Market Structure with CNN Integration
|
||||||
|
|
||||||
|
This script demonstrates the multi-timeframe, multi-symbol CNN-enabled
|
||||||
|
Williams Market Structure that predicts pivot points using TrainingDataPacket.
|
||||||
|
|
||||||
|
Features tested:
|
||||||
|
- Multi-timeframe data integration (1s, 1m, 1h)
|
||||||
|
- Multi-symbol support (ETH, BTC)
|
||||||
|
- Tick data aggregation
|
||||||
|
- 1h-based normalization strategy
|
||||||
|
- Multi-level pivot prediction (5 levels, type + price)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Dict, List, Optional, Any
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Mock TrainingDataPacket for testing
|
||||||
|
@dataclass
|
||||||
|
class MockTrainingDataPacket:
|
||||||
|
"""Mock TrainingDataPacket for testing CNN integration"""
|
||||||
|
timestamp: datetime
|
||||||
|
symbol: str
|
||||||
|
tick_cache: List[Dict[str, Any]]
|
||||||
|
one_second_bars: List[Dict[str, Any]]
|
||||||
|
multi_timeframe_data: Dict[str, List[Dict[str, Any]]]
|
||||||
|
cnn_features: Optional[Dict[str, np.ndarray]] = None
|
||||||
|
cnn_predictions: Optional[Dict[str, np.ndarray]] = None
|
||||||
|
market_state: Optional[Any] = None
|
||||||
|
universal_stream: Optional[Any] = None
|
||||||
|
|
||||||
|
class MockTrainingDataProvider:
|
||||||
|
"""Mock provider that supplies TrainingDataPacket for testing"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.training_data_buffer = []
|
||||||
|
self._generate_mock_data()
|
||||||
|
|
||||||
|
def _generate_mock_data(self):
|
||||||
|
"""Generate comprehensive mock market data"""
|
||||||
|
current_time = datetime.now()
|
||||||
|
|
||||||
|
# Generate ETH data for different timeframes
|
||||||
|
eth_1s_data = self._generate_ohlcv_data(2400.0, 900, '1s', current_time) # 15 min of 1s data
|
||||||
|
eth_1m_data = self._generate_ohlcv_data(2400.0, 900, '1m', current_time) # 15 hours of 1m data
|
||||||
|
eth_1h_data = self._generate_ohlcv_data(2400.0, 24, '1h', current_time) # 24 hours of 1h data
|
||||||
|
|
||||||
|
# Generate BTC data
|
||||||
|
btc_1s_data = self._generate_ohlcv_data(45000.0, 300, '1s', current_time) # 5 min of 1s data
|
||||||
|
|
||||||
|
# Generate tick data
|
||||||
|
tick_data = self._generate_tick_data(current_time)
|
||||||
|
|
||||||
|
# Create comprehensive TrainingDataPacket
|
||||||
|
training_packet = MockTrainingDataPacket(
|
||||||
|
timestamp=current_time,
|
||||||
|
symbol='ETH/USDT',
|
||||||
|
tick_cache=tick_data,
|
||||||
|
one_second_bars=eth_1s_data[-300:], # Last 5 minutes
|
||||||
|
multi_timeframe_data={
|
||||||
|
'ETH/USDT': {
|
||||||
|
'1s': eth_1s_data,
|
||||||
|
'1m': eth_1m_data,
|
||||||
|
'1h': eth_1h_data
|
||||||
|
},
|
||||||
|
'BTC/USDT': {
|
||||||
|
'1s': btc_1s_data
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
self.training_data_buffer.append(training_packet)
|
||||||
|
logger.info(f"Generated mock training data: {len(eth_1s_data)} 1s bars, {len(eth_1m_data)} 1m bars, {len(eth_1h_data)} 1h bars")
|
||||||
|
logger.info(f"ETH 1h price range: {min(bar['low'] for bar in eth_1h_data):.2f} - {max(bar['high'] for bar in eth_1h_data):.2f}")
|
||||||
|
|
||||||
|
def _generate_ohlcv_data(self, base_price: float, count: int, timeframe: str, end_time: datetime) -> List[Dict[str, Any]]:
|
||||||
|
"""Generate realistic OHLCV data with indicators"""
|
||||||
|
data = []
|
||||||
|
|
||||||
|
# Calculate time delta based on timeframe
|
||||||
|
if timeframe == '1s':
|
||||||
|
delta = timedelta(seconds=1)
|
||||||
|
elif timeframe == '1m':
|
||||||
|
delta = timedelta(minutes=1)
|
||||||
|
elif timeframe == '1h':
|
||||||
|
delta = timedelta(hours=1)
|
||||||
|
else:
|
||||||
|
delta = timedelta(minutes=1)
|
||||||
|
|
||||||
|
current_price = base_price
|
||||||
|
|
||||||
|
for i in range(count):
|
||||||
|
timestamp = end_time - delta * (count - i - 1)
|
||||||
|
|
||||||
|
# Generate realistic price movement
|
||||||
|
price_change = np.random.normal(0, base_price * 0.001) # 0.1% volatility
|
||||||
|
current_price = max(current_price + price_change, base_price * 0.8) # Floor at 80% of base
|
||||||
|
|
||||||
|
# Generate OHLCV
|
||||||
|
open_price = current_price
|
||||||
|
high_price = open_price * (1 + abs(np.random.normal(0, 0.002)))
|
||||||
|
low_price = open_price * (1 - abs(np.random.normal(0, 0.002)))
|
||||||
|
close_price = low_price + (high_price - low_price) * np.random.random()
|
||||||
|
volume = np.random.exponential(1000)
|
||||||
|
|
||||||
|
current_price = close_price
|
||||||
|
|
||||||
|
# Calculate basic indicators (placeholders)
|
||||||
|
sma_20 = close_price * (1 + np.random.normal(0, 0.001))
|
||||||
|
ema_20 = close_price * (1 + np.random.normal(0, 0.0005))
|
||||||
|
rsi_14 = 30 + np.random.random() * 40 # RSI between 30-70
|
||||||
|
macd = np.random.normal(0, 0.1)
|
||||||
|
bb_upper = high_price * 1.02
|
||||||
|
|
||||||
|
bar = {
|
||||||
|
'timestamp': timestamp,
|
||||||
|
'open': open_price,
|
||||||
|
'high': high_price,
|
||||||
|
'low': low_price,
|
||||||
|
'close': close_price,
|
||||||
|
'volume': volume,
|
||||||
|
'sma_20': sma_20,
|
||||||
|
'ema_20': ema_20,
|
||||||
|
'rsi_14': rsi_14,
|
||||||
|
'macd': macd,
|
||||||
|
'bb_upper': bb_upper
|
||||||
|
}
|
||||||
|
data.append(bar)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _generate_tick_data(self, end_time: datetime) -> List[Dict[str, Any]]:
|
||||||
|
"""Generate realistic tick data for last 5 minutes"""
|
||||||
|
ticks = []
|
||||||
|
|
||||||
|
# Generate ETH ticks
|
||||||
|
for i in range(300): # 5 minutes * 60 seconds
|
||||||
|
tick_time = end_time - timedelta(seconds=300 - i)
|
||||||
|
|
||||||
|
# 2-5 ticks per second
|
||||||
|
ticks_per_second = np.random.randint(2, 6)
|
||||||
|
|
||||||
|
for j in range(ticks_per_second):
|
||||||
|
tick = {
|
||||||
|
'symbol': 'ETH/USDT',
|
||||||
|
'timestamp': tick_time + timedelta(milliseconds=j * 200),
|
||||||
|
'price': 2400.0 + np.random.normal(0, 5),
|
||||||
|
'volume': np.random.exponential(0.5),
|
||||||
|
'quantity': np.random.exponential(1.0),
|
||||||
|
'side': 'buy' if np.random.random() > 0.5 else 'sell'
|
||||||
|
}
|
||||||
|
ticks.append(tick)
|
||||||
|
|
||||||
|
# Generate BTC ticks
|
||||||
|
for i in range(300): # 5 minutes * 60 seconds
|
||||||
|
tick_time = end_time - timedelta(seconds=300 - i)
|
||||||
|
|
||||||
|
ticks_per_second = np.random.randint(1, 4)
|
||||||
|
|
||||||
|
for j in range(ticks_per_second):
|
||||||
|
tick = {
|
||||||
|
'symbol': 'BTC/USDT',
|
||||||
|
'timestamp': tick_time + timedelta(milliseconds=j * 300),
|
||||||
|
'price': 45000.0 + np.random.normal(0, 100),
|
||||||
|
'volume': np.random.exponential(0.1),
|
||||||
|
'quantity': np.random.exponential(0.5),
|
||||||
|
'side': 'buy' if np.random.random() > 0.5 else 'sell'
|
||||||
|
}
|
||||||
|
ticks.append(tick)
|
||||||
|
|
||||||
|
return ticks
|
||||||
|
|
||||||
|
def get_latest_training_data(self):
|
||||||
|
"""Return the latest TrainingDataPacket"""
|
||||||
|
return self.training_data_buffer[-1] if self.training_data_buffer else None
|
||||||
|
|
||||||
|
|
||||||
|
def test_enhanced_williams_cnn():
|
||||||
|
"""Test the enhanced Williams Market Structure with CNN integration"""
|
||||||
|
try:
|
||||||
|
from training.williams_market_structure import WilliamsMarketStructure, SwingType
|
||||||
|
|
||||||
|
logger.info("=" * 80)
|
||||||
|
logger.info("TESTING ENHANCED WILLIAMS MARKET STRUCTURE WITH CNN INTEGRATION")
|
||||||
|
logger.info("=" * 80)
|
||||||
|
|
||||||
|
# Create mock data provider
|
||||||
|
data_provider = MockTrainingDataProvider()
|
||||||
|
|
||||||
|
# Initialize Williams Market Structure with CNN
|
||||||
|
williams = WilliamsMarketStructure(
|
||||||
|
swing_strengths=[2, 3, 5], # Reduced for testing
|
||||||
|
cnn_input_shape=(900, 50), # 900 timesteps, 50 features
|
||||||
|
cnn_output_size=10, # 5 levels * 2 outputs (type + price)
|
||||||
|
enable_cnn_feature=True, # Enable CNN features
|
||||||
|
training_data_provider=data_provider
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"CNN enabled: {williams.enable_cnn_feature}")
|
||||||
|
logger.info(f"Training data provider available: {williams.training_data_provider is not None}")
|
||||||
|
|
||||||
|
# Generate test OHLCV data for Williams calculation
|
||||||
|
test_ohlcv = generate_test_ohlcv_data()
|
||||||
|
logger.info(f"Generated test OHLCV data: {len(test_ohlcv)} bars")
|
||||||
|
|
||||||
|
# Test Williams calculation with CNN integration
|
||||||
|
logger.info("\n" + "=" * 60)
|
||||||
|
logger.info("RUNNING WILLIAMS PIVOT CALCULATION WITH CNN INTEGRATION")
|
||||||
|
logger.info("=" * 60)
|
||||||
|
|
||||||
|
structure_levels = williams.calculate_recursive_pivot_points(test_ohlcv)
|
||||||
|
|
||||||
|
# Display results
|
||||||
|
logger.info(f"\nWilliams Structure Analysis Results:")
|
||||||
|
logger.info(f"Calculated levels: {len(structure_levels)}")
|
||||||
|
|
||||||
|
for level_key, level_data in structure_levels.items():
|
||||||
|
swing_count = len(level_data.swing_points)
|
||||||
|
logger.info(f"{level_key}: {swing_count} swing points, "
|
||||||
|
f"trend: {level_data.trend_analysis.direction.value}, "
|
||||||
|
f"bias: {level_data.current_bias.value}")
|
||||||
|
|
||||||
|
if swing_count > 0:
|
||||||
|
latest_swing = level_data.swing_points[-1]
|
||||||
|
logger.info(f" Latest swing: {latest_swing.swing_type.name} @ {latest_swing.price:.2f}")
|
||||||
|
|
||||||
|
# Test CNN input preparation directly
|
||||||
|
logger.info("\n" + "=" * 60)
|
||||||
|
logger.info("TESTING CNN INPUT PREPARATION")
|
||||||
|
logger.info("=" * 60)
|
||||||
|
|
||||||
|
if williams.enable_cnn_feature and structure_levels['level_0'].swing_points:
|
||||||
|
test_pivot = structure_levels['level_0'].swing_points[-1]
|
||||||
|
|
||||||
|
logger.info(f"Testing CNN input for pivot: {test_pivot.swing_type.name} @ {test_pivot.price:.2f}")
|
||||||
|
|
||||||
|
# Test input preparation
|
||||||
|
cnn_input = williams._prepare_cnn_input(
|
||||||
|
current_pivot=test_pivot,
|
||||||
|
ohlcv_data_context=test_ohlcv,
|
||||||
|
previous_pivot_details=None
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"CNN input shape: {cnn_input.shape}")
|
||||||
|
logger.info(f"CNN input range: [{cnn_input.min():.4f}, {cnn_input.max():.4f}]")
|
||||||
|
logger.info(f"CNN input mean: {cnn_input.mean():.4f}, std: {cnn_input.std():.4f}")
|
||||||
|
|
||||||
|
# Test ground truth preparation
|
||||||
|
if len(structure_levels['level_0'].swing_points) >= 2:
|
||||||
|
prev_pivot = structure_levels['level_0'].swing_points[-2]
|
||||||
|
current_pivot = structure_levels['level_0'].swing_points[-1]
|
||||||
|
|
||||||
|
prev_details = {'pivot': prev_pivot}
|
||||||
|
ground_truth = williams._get_cnn_ground_truth(prev_details, current_pivot)
|
||||||
|
|
||||||
|
logger.info(f"Ground truth shape: {ground_truth.shape}")
|
||||||
|
logger.info(f"Ground truth (first 4 values): {ground_truth[:4]}")
|
||||||
|
logger.info(f"Level 0 prediction: type={ground_truth[0]:.2f}, price={ground_truth[1]:.4f}")
|
||||||
|
|
||||||
|
# Test normalization
|
||||||
|
logger.info("\n" + "=" * 60)
|
||||||
|
logger.info("TESTING 1H-BASED NORMALIZATION")
|
||||||
|
logger.info("=" * 60)
|
||||||
|
|
||||||
|
training_packet = data_provider.get_latest_training_data()
|
||||||
|
if training_packet:
|
||||||
|
# Test normalization with sample data
|
||||||
|
sample_features = np.random.normal(2400, 50, (100, 10)) # ETH-like prices
|
||||||
|
|
||||||
|
normalized = williams._normalize_features_by_1h_range(sample_features, training_packet)
|
||||||
|
|
||||||
|
logger.info(f"Original features range: [{sample_features.min():.2f}, {sample_features.max():.2f}]")
|
||||||
|
logger.info(f"Normalized features range: [{normalized.min():.4f}, {normalized.max():.4f}]")
|
||||||
|
|
||||||
|
# Check if 1h data is being used for normalization
|
||||||
|
eth_1h = training_packet.multi_timeframe_data.get('ETH/USDT', {}).get('1h', [])
|
||||||
|
if eth_1h:
|
||||||
|
h1_prices = []
|
||||||
|
for bar in eth_1h[-24:]:
|
||||||
|
h1_prices.extend([bar['open'], bar['high'], bar['low'], bar['close']])
|
||||||
|
h1_range = max(h1_prices) - min(h1_prices)
|
||||||
|
logger.info(f"1h price range used for normalization: {h1_range:.2f}")
|
||||||
|
|
||||||
|
logger.info("\n" + "=" * 80)
|
||||||
|
logger.info("ENHANCED WILLIAMS CNN INTEGRATION TEST COMPLETED SUCCESSFULLY")
|
||||||
|
logger.info("=" * 80)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except ImportError as e:
|
||||||
|
logger.error(f"Import error - some dependencies missing: {e}")
|
||||||
|
logger.info("This is expected if TensorFlow or other dependencies are not installed")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Test failed with error: {e}", exc_info=True)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def generate_test_ohlcv_data(bars=200, base_price=2400.0):
|
||||||
|
"""Generate test OHLCV data for Williams calculation"""
|
||||||
|
data = []
|
||||||
|
current_price = base_price
|
||||||
|
current_time = datetime.now()
|
||||||
|
|
||||||
|
for i in range(bars):
|
||||||
|
timestamp = current_time - timedelta(seconds=bars - i)
|
||||||
|
|
||||||
|
# Generate price movement
|
||||||
|
price_change = np.random.normal(0, base_price * 0.002)
|
||||||
|
current_price = max(current_price + price_change, base_price * 0.9)
|
||||||
|
|
||||||
|
open_price = current_price
|
||||||
|
high_price = open_price * (1 + abs(np.random.normal(0, 0.003)))
|
||||||
|
low_price = open_price * (1 - abs(np.random.normal(0, 0.003)))
|
||||||
|
close_price = low_price + (high_price - low_price) * np.random.random()
|
||||||
|
volume = np.random.exponential(1000)
|
||||||
|
|
||||||
|
current_price = close_price
|
||||||
|
|
||||||
|
bar = [
|
||||||
|
timestamp.timestamp(),
|
||||||
|
open_price,
|
||||||
|
high_price,
|
||||||
|
low_price,
|
||||||
|
close_price,
|
||||||
|
volume
|
||||||
|
]
|
||||||
|
data.append(bar)
|
||||||
|
|
||||||
|
return np.array(data)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
success = test_enhanced_williams_cnn()
|
||||||
|
if success:
|
||||||
|
print("\n✅ All tests passed! Enhanced Williams CNN integration is working.")
|
||||||
|
else:
|
||||||
|
print("\n❌ Some tests failed. Check logs for details.")
|
@ -24,6 +24,18 @@ from typing import Dict, List, Optional, Tuple, Any
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
|
try:
|
||||||
|
from NN.models.cnn_model import CNNModel
|
||||||
|
except ImportError:
|
||||||
|
CNNModel = None # Allow running without TF/CNN if not installed or path issue
|
||||||
|
print("Warning: CNNModel could not be imported. CNN-based pivot prediction/training will be disabled.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from core.unified_data_stream import TrainingDataPacket
|
||||||
|
except ImportError:
|
||||||
|
TrainingDataPacket = None
|
||||||
|
print("Warning: TrainingDataPacket could not be imported. Using fallback interface.")
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class TrendDirection(Enum):
|
class TrendDirection(Enum):
|
||||||
@ -84,12 +96,25 @@ class WilliamsMarketStructure:
|
|||||||
- Structure break detection
|
- Structure break detection
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, swing_strengths: List[int] = None):
|
def __init__(self,
|
||||||
|
swing_strengths: List[int] = None,
|
||||||
|
cnn_input_shape: Optional[Tuple[int, int]] = (900, 50), # Updated: 900 timesteps (1s), 50 features
|
||||||
|
cnn_output_size: Optional[int] = 10, # Updated: 5 levels * (type + price) = 10 outputs
|
||||||
|
cnn_model_config: Optional[Dict[str, Any]] = None, # For build_model params like filters, learning_rate
|
||||||
|
cnn_model_path: Optional[str] = None,
|
||||||
|
enable_cnn_feature: bool = True, # Master switch for this feature
|
||||||
|
training_data_provider: Optional[Any] = None): # Provider for TrainingDataPacket access
|
||||||
"""
|
"""
|
||||||
Initialize Williams market structure analyzer
|
Initialize Williams market structure analyzer
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
swing_strengths: List of swing detection strengths (bars on each side)
|
swing_strengths: List of swing detection strengths (bars on each side)
|
||||||
|
cnn_input_shape: Shape of input data for CNN (sequence_length, features)
|
||||||
|
cnn_output_size: Number of output classes for CNN (10 for 5 levels * 2 outputs each)
|
||||||
|
cnn_model_config: Dictionary with parameters for CNNModel.build_model()
|
||||||
|
cnn_model_path: Path to a pre-trained Keras CNN model (.h5 file)
|
||||||
|
enable_cnn_feature: If True, enables CNN prediction and training at pivots.
|
||||||
|
training_data_provider: Provider/stream for accessing TrainingDataPacket
|
||||||
"""
|
"""
|
||||||
self.swing_strengths = swing_strengths or [2, 3, 5, 8, 13] # Fibonacci-based strengths
|
self.swing_strengths = swing_strengths or [2, 3, 5, 8, 13] # Fibonacci-based strengths
|
||||||
self.max_levels = 5
|
self.max_levels = 5
|
||||||
@ -99,6 +124,32 @@ class WilliamsMarketStructure:
|
|||||||
self.swing_cache = {}
|
self.swing_cache = {}
|
||||||
self.trend_cache = {}
|
self.trend_cache = {}
|
||||||
|
|
||||||
|
self.enable_cnn_feature = enable_cnn_feature and CNNModel is not None
|
||||||
|
self.cnn_model: Optional[CNNModel] = None
|
||||||
|
self.previous_pivot_details_for_cnn: Optional[Dict[str, Any]] = None # Stores {'features': X, 'pivot': SwingPoint}
|
||||||
|
self.training_data_provider = training_data_provider # Access to TrainingDataPacket
|
||||||
|
|
||||||
|
if self.enable_cnn_feature:
|
||||||
|
try:
|
||||||
|
logger.info(f"Initializing CNN for multi-timeframe pivot prediction. Input: {cnn_input_shape}, Output: {cnn_output_size}")
|
||||||
|
logger.info("CNN will predict next pivot (type + price) for all 5 Williams levels")
|
||||||
|
|
||||||
|
self.cnn_model = CNNModel(input_shape=cnn_input_shape, output_size=cnn_output_size)
|
||||||
|
if cnn_model_path:
|
||||||
|
logger.info(f"Loading pre-trained CNN model from: {cnn_model_path}")
|
||||||
|
self.cnn_model.load(cnn_model_path)
|
||||||
|
else:
|
||||||
|
logger.info("Building new CNN model.")
|
||||||
|
# Use provided config or defaults for build_model
|
||||||
|
build_params = cnn_model_config or {}
|
||||||
|
self.cnn_model.build_model(**build_params)
|
||||||
|
logger.info("CNN Model initialized successfully.")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to initialize or load CNN model: {e}. Disabling CNN feature.", exc_info=True)
|
||||||
|
self.enable_cnn_feature = False
|
||||||
|
else:
|
||||||
|
logger.info("CNN feature for pivot prediction/training is disabled.")
|
||||||
|
|
||||||
logger.info(f"Williams Market Structure initialized with strengths: {self.swing_strengths}")
|
logger.info(f"Williams Market Structure initialized with strengths: {self.swing_strengths}")
|
||||||
|
|
||||||
def calculate_recursive_pivot_points(self, ohlcv_data: np.ndarray) -> Dict[str, MarketStructureLevel]:
|
def calculate_recursive_pivot_points(self, ohlcv_data: np.ndarray) -> Dict[str, MarketStructureLevel]:
|
||||||
@ -187,8 +238,8 @@ class WilliamsMarketStructure:
|
|||||||
all_swings = []
|
all_swings = []
|
||||||
|
|
||||||
for strength in self.swing_strengths:
|
for strength in self.swing_strengths:
|
||||||
swings = self._find_swing_points_single_strength(ohlcv_data, strength)
|
swings_at_strength = self._find_swing_points_single_strength(ohlcv_data, strength)
|
||||||
for swing in swings:
|
for swing in swings_at_strength:
|
||||||
# Avoid duplicates (swings at same index)
|
# Avoid duplicates (swings at same index)
|
||||||
if not any(existing.index == swing.index for existing in all_swings):
|
if not any(existing.index == swing.index for existing in all_swings):
|
||||||
all_swings.append(swing)
|
all_swings.append(swing)
|
||||||
@ -201,10 +252,10 @@ class WilliamsMarketStructure:
|
|||||||
|
|
||||||
def _find_swing_points_single_strength(self, ohlcv_data: np.ndarray, strength: int) -> List[SwingPoint]:
|
def _find_swing_points_single_strength(self, ohlcv_data: np.ndarray, strength: int) -> List[SwingPoint]:
|
||||||
"""Find swing points with specific strength requirement"""
|
"""Find swing points with specific strength requirement"""
|
||||||
swings = []
|
identified_swings_in_this_call = [] # Temporary list for swings found in this specific call
|
||||||
|
|
||||||
if len(ohlcv_data) < (strength * 2 + 1):
|
if len(ohlcv_data) < (strength * 2 + 1):
|
||||||
return swings
|
return identified_swings_in_this_call
|
||||||
|
|
||||||
for i in range(strength, len(ohlcv_data) - strength):
|
for i in range(strength, len(ohlcv_data) - strength):
|
||||||
current_high = ohlcv_data[i, 2] # High price
|
current_high = ohlcv_data[i, 2] # High price
|
||||||
@ -219,14 +270,16 @@ class WilliamsMarketStructure:
|
|||||||
break
|
break
|
||||||
|
|
||||||
if is_swing_high:
|
if is_swing_high:
|
||||||
swings.append(SwingPoint(
|
new_pivot = SwingPoint(
|
||||||
timestamp=datetime.fromtimestamp(ohlcv_data[i, 0]) if ohlcv_data[i, 0] > 1e9 else datetime.now(),
|
timestamp=datetime.fromtimestamp(ohlcv_data[i, 0]) if ohlcv_data[i, 0] > 1e9 else datetime.now(),
|
||||||
price=current_high,
|
price=current_high,
|
||||||
index=i,
|
index=i,
|
||||||
swing_type=SwingType.SWING_HIGH,
|
swing_type=SwingType.SWING_HIGH,
|
||||||
strength=strength,
|
strength=strength,
|
||||||
volume=current_volume
|
volume=current_volume
|
||||||
))
|
)
|
||||||
|
identified_swings_in_this_call.append(new_pivot)
|
||||||
|
self._handle_cnn_at_pivot(new_pivot, ohlcv_data) # CNN logic call
|
||||||
|
|
||||||
# Check for swing low (lower than surrounding bars)
|
# Check for swing low (lower than surrounding bars)
|
||||||
is_swing_low = True
|
is_swing_low = True
|
||||||
@ -236,16 +289,18 @@ class WilliamsMarketStructure:
|
|||||||
break
|
break
|
||||||
|
|
||||||
if is_swing_low:
|
if is_swing_low:
|
||||||
swings.append(SwingPoint(
|
new_pivot = SwingPoint(
|
||||||
timestamp=datetime.fromtimestamp(ohlcv_data[i, 0]) if ohlcv_data[i, 0] > 1e9 else datetime.now(),
|
timestamp=datetime.fromtimestamp(ohlcv_data[i, 0]) if ohlcv_data[i, 0] > 1e9 else datetime.now(),
|
||||||
price=current_low,
|
price=current_low,
|
||||||
index=i,
|
index=i,
|
||||||
swing_type=SwingType.SWING_LOW,
|
swing_type=SwingType.SWING_LOW,
|
||||||
strength=strength,
|
strength=strength,
|
||||||
volume=current_volume
|
volume=current_volume
|
||||||
))
|
)
|
||||||
|
identified_swings_in_this_call.append(new_pivot)
|
||||||
|
self._handle_cnn_at_pivot(new_pivot, ohlcv_data) # CNN logic call
|
||||||
|
|
||||||
return swings
|
return identified_swings_in_this_call # Return swings found in this call
|
||||||
|
|
||||||
def _filter_significant_swings(self, swings: List[SwingPoint]) -> List[SwingPoint]:
|
def _filter_significant_swings(self, swings: List[SwingPoint]) -> List[SwingPoint]:
|
||||||
"""Filter to keep only the most significant swings"""
|
"""Filter to keep only the most significant swings"""
|
||||||
@ -511,10 +566,10 @@ class WilliamsMarketStructure:
|
|||||||
pivot_array: Array of pivot points as [timestamp, price, price, price, price, 0] format
|
pivot_array: Array of pivot points as [timestamp, price, price, price, price, 0] format
|
||||||
level: Current level being calculated
|
level: Current level being calculated
|
||||||
"""
|
"""
|
||||||
swings = []
|
identified_swings_in_this_call = [] # Temporary list
|
||||||
|
|
||||||
if len(pivot_array) < 5:
|
if len(pivot_array) < 5: # Min bars for even smallest strength (e.g. strength 2 needs 2+1+2=5)
|
||||||
return swings
|
return identified_swings_in_this_call
|
||||||
|
|
||||||
# Use configurable strength for higher levels (more conservative)
|
# Use configurable strength for higher levels (more conservative)
|
||||||
strength = min(2 + level, 5) # Level 1: 3 bars, Level 2: 4 bars, Level 3+: 5 bars
|
strength = min(2 + level, 5) # Level 1: 3 bars, Level 2: 4 bars, Level 3+: 5 bars
|
||||||
@ -526,38 +581,42 @@ class WilliamsMarketStructure:
|
|||||||
# Check for swing high (pivot high surrounded by lower pivot highs)
|
# Check for swing high (pivot high surrounded by lower pivot highs)
|
||||||
is_swing_high = True
|
is_swing_high = True
|
||||||
for j in range(i - strength, i + strength + 1):
|
for j in range(i - strength, i + strength + 1):
|
||||||
if j != i and pivot_array[j, 1] >= current_price:
|
if j != i and pivot_array[j, 1] >= current_price: # Compare with price of other pivots
|
||||||
is_swing_high = False
|
is_swing_high = False
|
||||||
break
|
break
|
||||||
|
|
||||||
if is_swing_high:
|
if is_swing_high:
|
||||||
swings.append(SwingPoint(
|
new_pivot = SwingPoint(
|
||||||
timestamp=datetime.fromtimestamp(current_timestamp) if current_timestamp > 1e9 else datetime.now(),
|
timestamp=datetime.fromtimestamp(current_timestamp) if current_timestamp > 1e9 else datetime.now(),
|
||||||
price=current_price,
|
price=current_price,
|
||||||
index=i,
|
index=i,
|
||||||
swing_type=SwingType.SWING_HIGH,
|
swing_type=SwingType.SWING_HIGH,
|
||||||
strength=strength,
|
strength=strength, # Strength here is derived from level, e.g., min(2 + level, 5)
|
||||||
volume=0.0 # Pivot points don't have volume
|
volume=0.0 # Pivot points don't have volume
|
||||||
))
|
)
|
||||||
|
identified_swings_in_this_call.append(new_pivot)
|
||||||
|
self._handle_cnn_at_pivot(new_pivot, pivot_array) # CNN logic call
|
||||||
|
|
||||||
# Check for swing low (pivot low surrounded by higher pivot lows)
|
# Check for swing low (pivot low surrounded by higher pivot lows)
|
||||||
is_swing_low = True
|
is_swing_low = True
|
||||||
for j in range(i - strength, i + strength + 1):
|
for j in range(i - strength, i + strength + 1):
|
||||||
if j != i and pivot_array[j, 1] <= current_price:
|
if j != i and pivot_array[j, 1] <= current_price: # Compare with price of other pivots
|
||||||
is_swing_low = False
|
is_swing_low = False
|
||||||
break
|
break
|
||||||
|
|
||||||
if is_swing_low:
|
if is_swing_low:
|
||||||
swings.append(SwingPoint(
|
new_pivot = SwingPoint(
|
||||||
timestamp=datetime.fromtimestamp(current_timestamp) if current_timestamp > 1e9 else datetime.now(),
|
timestamp=datetime.fromtimestamp(current_timestamp) if current_timestamp > 1e9 else datetime.now(),
|
||||||
price=current_price,
|
price=current_price,
|
||||||
index=i,
|
index=i,
|
||||||
swing_type=SwingType.SWING_LOW,
|
swing_type=SwingType.SWING_LOW,
|
||||||
strength=strength,
|
strength=strength, # Strength here is derived from level
|
||||||
volume=0.0 # Pivot points don't have volume
|
volume=0.0 # Pivot points don't have volume
|
||||||
))
|
)
|
||||||
|
identified_swings_in_this_call.append(new_pivot)
|
||||||
|
self._handle_cnn_at_pivot(new_pivot, pivot_array) # CNN logic call
|
||||||
|
|
||||||
return swings
|
return identified_swings_in_this_call # Return swings found in this call
|
||||||
|
|
||||||
def _convert_pivots_to_price_points(self, swing_points: List[SwingPoint]) -> np.ndarray:
|
def _convert_pivots_to_price_points(self, swing_points: List[SwingPoint]) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
@ -695,4 +754,479 @@ class WilliamsMarketStructure:
|
|||||||
features.extend([0.0, 0.0])
|
features.extend([0.0, 0.0])
|
||||||
recent_breaks.append({})
|
recent_breaks.append({})
|
||||||
|
|
||||||
return features[:50] # Ensure exactly 50 features per level
|
return features[:50] # Ensure exactly 50 features per level
|
||||||
|
|
||||||
|
def _handle_cnn_at_pivot(self,
|
||||||
|
newly_identified_pivot: SwingPoint,
|
||||||
|
ohlcv_data_context: np.ndarray):
|
||||||
|
"""
|
||||||
|
Handles CNN training for the previous pivot and prediction for the next pivot.
|
||||||
|
Called when a new pivot point is identified.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
newly_identified_pivot: The SwingPoint object for the just-formed pivot.
|
||||||
|
ohlcv_data_context: The OHLCV data (or pivot array for higher levels)
|
||||||
|
relevant to this pivot's formation.
|
||||||
|
"""
|
||||||
|
if not self.enable_cnn_feature or self.cnn_model is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 1. Train model based on the *previous* pivot's prediction and the *current* actual outcome
|
||||||
|
if self.previous_pivot_details_for_cnn:
|
||||||
|
try:
|
||||||
|
logger.debug(f"CNN Training: Previous pivot at idx {self.previous_pivot_details_for_cnn['pivot'].index}, "
|
||||||
|
f"Current pivot (ground truth) at idx {newly_identified_pivot.index}")
|
||||||
|
|
||||||
|
X_train = self.previous_pivot_details_for_cnn['features']
|
||||||
|
# previous_pivot_info contains 'pivot' which is the SwingPoint object of N-1
|
||||||
|
y_train = self._get_cnn_ground_truth(self.previous_pivot_details_for_cnn, newly_identified_pivot)
|
||||||
|
|
||||||
|
if X_train is not None and X_train.size > 0 and y_train is not None and y_train.size > 0:
|
||||||
|
# Reshape X_train if it's a single sample and model expects batch
|
||||||
|
if len(X_train.shape) == len(self.cnn_model.input_shape) and X_train.shape == self.cnn_model.input_shape :
|
||||||
|
X_train_batch = np.expand_dims(X_train, axis=0)
|
||||||
|
else: # Should already be correctly shaped by _prepare_cnn_input
|
||||||
|
X_train_batch = X_train # Or handle error
|
||||||
|
|
||||||
|
# Reshape y_train if needed
|
||||||
|
if self.cnn_model.output_size > 1 and len(y_train.shape) ==1: # e.g. [0.,1.] but model needs [[0.,1.]]
|
||||||
|
y_train_batch = np.expand_dims(y_train, axis=0)
|
||||||
|
elif self.cnn_model.output_size == 1 and not isinstance(y_train, (list, np.ndarray)): # e.g. plain 0 or 1
|
||||||
|
y_train_batch = np.array([[y_train]], dtype=np.float32)
|
||||||
|
elif self.cnn_model.output_size == 1 and isinstance(y_train, np.ndarray) and y_train.ndim == 1:
|
||||||
|
y_train_batch = y_train.reshape(-1,1) # ensure [[0.]] for single binary output
|
||||||
|
else:
|
||||||
|
y_train_batch = y_train
|
||||||
|
|
||||||
|
|
||||||
|
logger.info(f"CNN Training with X_shape: {X_train_batch.shape}, y_shape: {y_train_batch.shape}")
|
||||||
|
# Perform a single step of training (online learning)
|
||||||
|
# Use minimal callbacks for online learning, or allow configuration
|
||||||
|
self.cnn_model.model.fit(X_train_batch, y_train_batch, batch_size=1, epochs=1, verbose=0, callbacks=[])
|
||||||
|
logger.info(f"CNN online training step completed for pivot at index {self.previous_pivot_details_for_cnn['pivot'].index}.")
|
||||||
|
else:
|
||||||
|
logger.warning("CNN Training: Skipping due to invalid X_train or y_train.")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error during CNN online training: {e}", exc_info=True)
|
||||||
|
|
||||||
|
# 2. Predict for the *next* pivot based on the *current* newly_identified_pivot
|
||||||
|
try:
|
||||||
|
logger.debug(f"CNN Prediction: Preparing input for current pivot at idx {newly_identified_pivot.index}")
|
||||||
|
|
||||||
|
# The 'previous_pivot_details' for _prepare_cnn_input here is the one active *before* this current call
|
||||||
|
# which means it refers to the pivot that just got its ground truth trained on.
|
||||||
|
# If this is the first pivot ever, self.previous_pivot_details_for_cnn would be None.
|
||||||
|
|
||||||
|
# Correct context for _prepare_cnn_input:
|
||||||
|
# current_pivot = newly_identified_pivot
|
||||||
|
# previous_pivot_details = self.previous_pivot_details_for_cnn (this is N-1, which was used for training above)
|
||||||
|
|
||||||
|
X_predict = self._prepare_cnn_input(newly_identified_pivot,
|
||||||
|
ohlcv_data_context,
|
||||||
|
self.previous_pivot_details_for_cnn) # Pass the N-1 pivot details
|
||||||
|
|
||||||
|
if X_predict is not None and X_predict.size > 0:
|
||||||
|
# Reshape X_predict if it's a single sample and model expects batch
|
||||||
|
if len(X_predict.shape) == len(self.cnn_model.input_shape) and X_predict.shape == self.cnn_model.input_shape :
|
||||||
|
X_predict_batch = np.expand_dims(X_predict, axis=0)
|
||||||
|
else:
|
||||||
|
X_predict_batch = X_predict # Or handle error
|
||||||
|
|
||||||
|
logger.info(f"CNN Predicting with X_shape: {X_predict_batch.shape}")
|
||||||
|
pred_class, pred_proba = self.cnn_model.predict(X_predict_batch) # predict expects batch
|
||||||
|
|
||||||
|
# pred_class/pred_proba might be arrays if batch_size > 1, or if output is multi-dim
|
||||||
|
# For batch_size=1, take the first element
|
||||||
|
final_pred_class = pred_class[0] if isinstance(pred_class, np.ndarray) and pred_class.ndim > 0 else pred_class
|
||||||
|
final_pred_proba = pred_proba[0] if isinstance(pred_proba, np.ndarray) and pred_proba.ndim > 0 else pred_proba
|
||||||
|
|
||||||
|
logger.info(f"CNN Prediction for pivot after index {newly_identified_pivot.index}: Class={final_pred_class}, Proba/Val={final_pred_proba}")
|
||||||
|
|
||||||
|
# Store the features (X_predict) and the pivot (newly_identified_pivot) itself for the next training cycle
|
||||||
|
self.previous_pivot_details_for_cnn = {'features': X_predict, 'pivot': newly_identified_pivot}
|
||||||
|
else:
|
||||||
|
logger.warning("CNN Prediction: Skipping due to invalid X_predict.")
|
||||||
|
# If prediction can't be made, ensure we don't carry over stale 'previous_pivot_details_for_cnn'
|
||||||
|
# Or, decide if we should clear it or keep the N-2 details.
|
||||||
|
# For now, if X_predict is None, we clear it so no training happens next round unless a new pred is made.
|
||||||
|
self.previous_pivot_details_for_cnn = None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error during CNN prediction: {e}", exc_info=True)
|
||||||
|
self.previous_pivot_details_for_cnn = None # Clear on error to prevent bad training
|
||||||
|
|
||||||
|
def _prepare_cnn_input(self,
|
||||||
|
current_pivot: SwingPoint,
|
||||||
|
ohlcv_data_context: np.ndarray,
|
||||||
|
previous_pivot_details: Optional[Dict[str, Any]]) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Prepare multi-timeframe, multi-symbol input features for CNN using TrainingDataPacket.
|
||||||
|
|
||||||
|
Features include:
|
||||||
|
- ETH: 5 min ticks → 300 x 1s bars with ticks features (4 features)
|
||||||
|
- ETH: 900 x 1s OHLCV + indicators (10 features)
|
||||||
|
- ETH: 900 x 1m OHLCV + indicators (10 features)
|
||||||
|
- ETH: 900 x 1h OHLCV + indicators (10 features)
|
||||||
|
- ETH: All pivot points from all levels (15 features)
|
||||||
|
- BTC: 5 min ticks → 300 x 1s reference (4 features)
|
||||||
|
- Chart labels for data identification (7 features)
|
||||||
|
|
||||||
|
Total: ~50 features per timestep over 900 timesteps
|
||||||
|
Data normalized using 1h min/max to preserve cross-timeframe relationships.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
current_pivot: The newly identified SwingPoint
|
||||||
|
ohlcv_data_context: The OHLCV data from Williams calculation (may not be used directly)
|
||||||
|
previous_pivot_details: Previous pivot info for context
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A numpy array of shape (900, 50) with normalized features
|
||||||
|
"""
|
||||||
|
if self.cnn_model is None or not self.training_data_provider:
|
||||||
|
logger.warning("CNN model or training data provider not available")
|
||||||
|
return np.zeros(self.cnn_model.input_shape if self.cnn_model else (900, 50), dtype=np.float32)
|
||||||
|
|
||||||
|
sequence_length, num_features = self.cnn_model.input_shape
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get latest TrainingDataPacket from provider
|
||||||
|
training_packet = self._get_latest_training_data()
|
||||||
|
if not training_packet:
|
||||||
|
logger.warning("No TrainingDataPacket available for CNN input")
|
||||||
|
return np.zeros((sequence_length, num_features), dtype=np.float32)
|
||||||
|
|
||||||
|
logger.debug(f"CNN Input: Preparing features for pivot at {current_pivot.timestamp}")
|
||||||
|
|
||||||
|
# Prepare feature components (in actual values)
|
||||||
|
eth_features = self._prepare_eth_features(training_packet, sequence_length)
|
||||||
|
btc_features = self._prepare_btc_reference_features(training_packet, sequence_length)
|
||||||
|
pivot_features = self._prepare_pivot_features(training_packet, current_pivot, sequence_length)
|
||||||
|
chart_labels = self._prepare_chart_labels(sequence_length)
|
||||||
|
|
||||||
|
# Combine all features (still in actual values)
|
||||||
|
combined_features = np.concatenate([
|
||||||
|
eth_features, # ~40 features
|
||||||
|
btc_features, # ~4 features
|
||||||
|
pivot_features, # ~3 features
|
||||||
|
chart_labels # ~3 features
|
||||||
|
], axis=1)
|
||||||
|
|
||||||
|
# Ensure we match expected feature count
|
||||||
|
if combined_features.shape[1] > num_features:
|
||||||
|
combined_features = combined_features[:, :num_features]
|
||||||
|
elif combined_features.shape[1] < num_features:
|
||||||
|
padding = np.zeros((sequence_length, num_features - combined_features.shape[1]))
|
||||||
|
combined_features = np.concatenate([combined_features, padding], axis=1)
|
||||||
|
|
||||||
|
# NORMALIZATION: Apply 1h timeframe min/max to preserve relationships
|
||||||
|
normalized_features = self._normalize_features_by_1h_range(combined_features, training_packet)
|
||||||
|
|
||||||
|
logger.debug(f"CNN Input prepared: shape {normalized_features.shape}, "
|
||||||
|
f"min: {normalized_features.min():.4f}, max: {normalized_features.max():.4f}")
|
||||||
|
|
||||||
|
return normalized_features.astype(np.float32)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error preparing CNN input: {e}", exc_info=True)
|
||||||
|
return np.zeros((sequence_length, num_features), dtype=np.float32)
|
||||||
|
|
||||||
|
def _get_latest_training_data(self):
|
||||||
|
"""Get latest TrainingDataPacket from provider"""
|
||||||
|
try:
|
||||||
|
if hasattr(self.training_data_provider, 'get_latest_training_data'):
|
||||||
|
return self.training_data_provider.get_latest_training_data()
|
||||||
|
elif hasattr(self.training_data_provider, 'training_data_buffer'):
|
||||||
|
return self.training_data_provider.training_data_buffer[-1] if self.training_data_provider.training_data_buffer else None
|
||||||
|
else:
|
||||||
|
logger.warning("Training data provider does not have expected interface")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting training data: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _prepare_eth_features(self, training_packet, sequence_length: int) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Prepare ETH multi-timeframe features (keep in actual values):
|
||||||
|
- 1s bars with indicators (10 features)
|
||||||
|
- 1m bars with indicators (10 features)
|
||||||
|
- 1h bars with indicators (10 features)
|
||||||
|
- Tick-derived 1s features (10 features)
|
||||||
|
Total: 40 features per timestep
|
||||||
|
"""
|
||||||
|
features = []
|
||||||
|
|
||||||
|
# ETH 1s data with indicators
|
||||||
|
eth_1s_features = self._extract_timeframe_features(
|
||||||
|
training_packet.multi_timeframe_data.get('ETH/USDT', {}).get('1s', []),
|
||||||
|
sequence_length, 'ETH_1s'
|
||||||
|
)
|
||||||
|
features.append(eth_1s_features)
|
||||||
|
|
||||||
|
# ETH 1m data with indicators
|
||||||
|
eth_1m_features = self._extract_timeframe_features(
|
||||||
|
training_packet.multi_timeframe_data.get('ETH/USDT', {}).get('1m', []),
|
||||||
|
sequence_length, 'ETH_1m'
|
||||||
|
)
|
||||||
|
features.append(eth_1m_features)
|
||||||
|
|
||||||
|
# ETH 1h data with indicators
|
||||||
|
eth_1h_features = self._extract_timeframe_features(
|
||||||
|
training_packet.multi_timeframe_data.get('ETH/USDT', {}).get('1h', []),
|
||||||
|
sequence_length, 'ETH_1h'
|
||||||
|
)
|
||||||
|
features.append(eth_1h_features)
|
||||||
|
|
||||||
|
# ETH tick-derived features (5 min of ticks → 300 x 1s aggregated to match sequence_length)
|
||||||
|
eth_tick_features = self._extract_tick_features(
|
||||||
|
training_packet.tick_cache, 'ETH/USDT', sequence_length
|
||||||
|
)
|
||||||
|
features.append(eth_tick_features)
|
||||||
|
|
||||||
|
return np.concatenate(features, axis=1)
|
||||||
|
|
||||||
|
def _prepare_btc_reference_features(self, training_packet, sequence_length: int) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Prepare BTC reference features (keep in actual values):
|
||||||
|
- Tick-derived features for correlation analysis
|
||||||
|
Total: 4 features per timestep
|
||||||
|
"""
|
||||||
|
return self._extract_tick_features(
|
||||||
|
training_packet.tick_cache, 'BTC/USDT', sequence_length
|
||||||
|
)
|
||||||
|
|
||||||
|
def _prepare_pivot_features(self, training_packet, current_pivot: SwingPoint, sequence_length: int) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Prepare pivot point features from all Williams levels:
|
||||||
|
- Recent pivot characteristics
|
||||||
|
- Level-specific trend information
|
||||||
|
Total: 3 features per timestep (repeated for sequence)
|
||||||
|
"""
|
||||||
|
# Extract Williams pivot features using existing method if available
|
||||||
|
if hasattr(training_packet, 'universal_stream') and training_packet.universal_stream:
|
||||||
|
# Use existing pivot extraction logic
|
||||||
|
pivot_feature_vector = [
|
||||||
|
current_pivot.price,
|
||||||
|
1.0 if current_pivot.swing_type == SwingType.SWING_HIGH else 0.0,
|
||||||
|
float(current_pivot.strength)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
pivot_feature_vector = [0.0, 0.0, 0.0]
|
||||||
|
|
||||||
|
# Repeat pivot features for all timesteps in sequence
|
||||||
|
return np.tile(pivot_feature_vector, (sequence_length, 1))
|
||||||
|
|
||||||
|
def _prepare_chart_labels(self, sequence_length: int) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Prepare chart identification labels:
|
||||||
|
- Symbol identifiers
|
||||||
|
- Timeframe identifiers
|
||||||
|
Total: 3 features per timestep
|
||||||
|
"""
|
||||||
|
# Simple encoding: [is_eth, is_btc, timeframe_mix]
|
||||||
|
chart_labels = [1.0, 1.0, 1.0] # Mixed multi-timeframe ETH+BTC data
|
||||||
|
return np.tile(chart_labels, (sequence_length, 1))
|
||||||
|
|
||||||
|
def _extract_timeframe_features(self, ohlcv_data: List[Dict], sequence_length: int, timeframe_label: str) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Extract OHLCV + indicator features from timeframe data (keep actual values).
|
||||||
|
Returns 10 features: OHLCV + volume + 5 indicators
|
||||||
|
"""
|
||||||
|
if not ohlcv_data:
|
||||||
|
return np.zeros((sequence_length, 10))
|
||||||
|
|
||||||
|
# Take last sequence_length bars or pad if insufficient
|
||||||
|
data_to_use = ohlcv_data[-sequence_length:] if len(ohlcv_data) >= sequence_length else ohlcv_data
|
||||||
|
|
||||||
|
features = []
|
||||||
|
for bar in data_to_use:
|
||||||
|
bar_features = [
|
||||||
|
bar.get('open', 0.0),
|
||||||
|
bar.get('high', 0.0),
|
||||||
|
bar.get('low', 0.0),
|
||||||
|
bar.get('close', 0.0),
|
||||||
|
bar.get('volume', 0.0),
|
||||||
|
# TODO: Add 5 calculated indicators (SMA, EMA, RSI, MACD, etc.)
|
||||||
|
bar.get('sma_20', bar.get('close', 0.0)), # Placeholder
|
||||||
|
bar.get('ema_20', bar.get('close', 0.0)), # Placeholder
|
||||||
|
bar.get('rsi_14', 50.0), # Placeholder
|
||||||
|
bar.get('macd', 0.0), # Placeholder
|
||||||
|
bar.get('bb_upper', bar.get('high', 0.0)) # Placeholder
|
||||||
|
]
|
||||||
|
features.append(bar_features)
|
||||||
|
|
||||||
|
# Pad if insufficient data
|
||||||
|
while len(features) < sequence_length:
|
||||||
|
features.insert(0, features[0] if features else [0.0] * 10)
|
||||||
|
|
||||||
|
return np.array(features, dtype=np.float32)
|
||||||
|
|
||||||
|
def _extract_tick_features(self, tick_cache: List[Dict], symbol: str, sequence_length: int) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Extract tick-derived features aggregated to 1s intervals (keep actual values).
|
||||||
|
Returns 4 features: tick_count, total_volume, vwap, price_volatility per second
|
||||||
|
"""
|
||||||
|
# Filter ticks for symbol and last 5 minutes
|
||||||
|
symbol_ticks = [t for t in tick_cache[-1500:] if t.get('symbol') == symbol] # Assume ~5 ticks/sec
|
||||||
|
|
||||||
|
if not symbol_ticks:
|
||||||
|
return np.zeros((sequence_length, 4))
|
||||||
|
|
||||||
|
# Group ticks by second and calculate features
|
||||||
|
tick_features = []
|
||||||
|
current_time = datetime.now()
|
||||||
|
|
||||||
|
for i in range(sequence_length):
|
||||||
|
second_start = current_time - timedelta(seconds=sequence_length - i)
|
||||||
|
second_end = second_start + timedelta(seconds=1)
|
||||||
|
|
||||||
|
second_ticks = [
|
||||||
|
t for t in symbol_ticks
|
||||||
|
if second_start <= t.get('timestamp', datetime.min) < second_end
|
||||||
|
]
|
||||||
|
|
||||||
|
if second_ticks:
|
||||||
|
prices = [t.get('price', 0.0) for t in second_ticks]
|
||||||
|
volumes = [t.get('volume', 0.0) for t in second_ticks]
|
||||||
|
total_volume = sum(volumes)
|
||||||
|
|
||||||
|
tick_count = len(second_ticks)
|
||||||
|
vwap = sum(p * v for p, v in zip(prices, volumes)) / total_volume if total_volume > 0 else 0.0
|
||||||
|
price_volatility = np.std(prices) if len(prices) > 1 else 0.0
|
||||||
|
|
||||||
|
second_features = [tick_count, total_volume, vwap, price_volatility]
|
||||||
|
else:
|
||||||
|
second_features = [0.0, 0.0, 0.0, 0.0]
|
||||||
|
|
||||||
|
tick_features.append(second_features)
|
||||||
|
|
||||||
|
return np.array(tick_features, dtype=np.float32)
|
||||||
|
|
||||||
|
def _normalize_features_by_1h_range(self, features: np.ndarray, training_packet) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Normalize all features using 1h timeframe min/max to preserve cross-timeframe relationships.
|
||||||
|
This is the final normalization step before feeding to CNN.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Get 1h ETH data for normalization reference
|
||||||
|
eth_1h_data = training_packet.multi_timeframe_data.get('ETH/USDT', {}).get('1h', [])
|
||||||
|
|
||||||
|
if not eth_1h_data:
|
||||||
|
logger.warning("No 1h data available for normalization, using feature-wise normalization")
|
||||||
|
# Fallback: normalize each feature independently
|
||||||
|
feature_min = np.min(features, axis=0, keepdims=True)
|
||||||
|
feature_max = np.max(features, axis=0, keepdims=True)
|
||||||
|
feature_range = feature_max - feature_min
|
||||||
|
feature_range[feature_range == 0] = 1.0 # Avoid division by zero
|
||||||
|
return (features - feature_min) / feature_range
|
||||||
|
|
||||||
|
# Extract 1h price range for primary normalization
|
||||||
|
h1_prices = []
|
||||||
|
for bar in eth_1h_data[-24:]: # Last 24 hours for robust range
|
||||||
|
h1_prices.extend([
|
||||||
|
bar.get('open', 0.0),
|
||||||
|
bar.get('high', 0.0),
|
||||||
|
bar.get('low', 0.0),
|
||||||
|
bar.get('close', 0.0)
|
||||||
|
])
|
||||||
|
|
||||||
|
if h1_prices:
|
||||||
|
h1_min = min(h1_prices)
|
||||||
|
h1_max = max(h1_prices)
|
||||||
|
h1_range = h1_max - h1_min
|
||||||
|
|
||||||
|
if h1_range > 0:
|
||||||
|
logger.debug(f"Normalizing features using 1h range: {h1_min:.2f} - {h1_max:.2f}")
|
||||||
|
|
||||||
|
# Apply 1h-based normalization to price-related features (first ~30 features)
|
||||||
|
normalized_features = features.copy()
|
||||||
|
price_feature_count = min(30, features.shape[1])
|
||||||
|
|
||||||
|
# Normalize price-related features with 1h range
|
||||||
|
normalized_features[:, :price_feature_count] = (
|
||||||
|
(features[:, :price_feature_count] - h1_min) / h1_range
|
||||||
|
)
|
||||||
|
|
||||||
|
# For non-price features (indicators, counts, etc.), use feature-wise normalization
|
||||||
|
if features.shape[1] > price_feature_count:
|
||||||
|
remaining_features = features[:, price_feature_count:]
|
||||||
|
feature_min = np.min(remaining_features, axis=0, keepdims=True)
|
||||||
|
feature_max = np.max(remaining_features, axis=0, keepdims=True)
|
||||||
|
feature_range = feature_max - feature_min
|
||||||
|
feature_range[feature_range == 0] = 1.0
|
||||||
|
|
||||||
|
normalized_features[:, price_feature_count:] = (
|
||||||
|
(remaining_features - feature_min) / feature_range
|
||||||
|
)
|
||||||
|
|
||||||
|
return normalized_features
|
||||||
|
|
||||||
|
# Fallback normalization if 1h range calculation fails
|
||||||
|
logger.warning("1h range calculation failed, using min-max normalization")
|
||||||
|
feature_min = np.min(features, axis=0, keepdims=True)
|
||||||
|
feature_max = np.max(features, axis=0, keepdims=True)
|
||||||
|
feature_range = feature_max - feature_min
|
||||||
|
feature_range[feature_range == 0] = 1.0
|
||||||
|
return (features - feature_min) / feature_range
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in normalization: {e}", exc_info=True)
|
||||||
|
# Emergency fallback: return features as-is but scaled to [0,1] roughly
|
||||||
|
return np.clip(features / (np.max(np.abs(features)) + 1e-8), -1.0, 1.0)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_cnn_ground_truth(self,
|
||||||
|
previous_pivot_info: Dict[str, Any], # Contains 'pivot': SwingPoint obj of N-1
|
||||||
|
actual_current_pivot: SwingPoint # This is pivot N
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Determine the ground truth for CNN prediction made at previous_pivot.
|
||||||
|
|
||||||
|
Updated to return prediction for next pivot in ALL 5 LEVELS:
|
||||||
|
- For each level: [type (0=LOW, 1=HIGH), normalized_price_target]
|
||||||
|
- Total output: 10 values (5 levels * 2 outputs each)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
previous_pivot_info: Dict with 'pivot' = SwingPoint of N-1
|
||||||
|
actual_current_pivot: SwingPoint of pivot N (actual outcome)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A numpy array of shape (10,) with ground truth for all levels
|
||||||
|
"""
|
||||||
|
if self.cnn_model is None:
|
||||||
|
return np.array([])
|
||||||
|
|
||||||
|
# Initialize ground truth array for all 5 levels
|
||||||
|
ground_truth = np.zeros(10, dtype=np.float32) # 5 levels * 2 outputs
|
||||||
|
|
||||||
|
try:
|
||||||
|
# For Level 0 (current pivot level), we have actual data
|
||||||
|
level_0_type = 1.0 if actual_current_pivot.swing_type == SwingType.SWING_HIGH else 0.0
|
||||||
|
level_0_price = actual_current_pivot.price
|
||||||
|
|
||||||
|
# Normalize price (this is a placeholder - proper normalization should use market context)
|
||||||
|
# In real implementation, use the same 1h range normalization as input features
|
||||||
|
normalized_price = level_0_price / 10000.0 # Rough normalization for ETH prices
|
||||||
|
|
||||||
|
ground_truth[0] = level_0_type # Level 0 type
|
||||||
|
ground_truth[1] = normalized_price # Level 0 price
|
||||||
|
|
||||||
|
# For higher levels (1-4), we would need to calculate what the next pivot would be
|
||||||
|
# This requires access to higher-level Williams calculations
|
||||||
|
# For now, use placeholder logic based on current pivot characteristics
|
||||||
|
|
||||||
|
for level in range(1, 5):
|
||||||
|
# Placeholder: higher levels follow similar pattern but with reduced confidence
|
||||||
|
confidence_factor = 1.0 / (level + 1)
|
||||||
|
|
||||||
|
ground_truth[level * 2] = level_0_type * confidence_factor # Level N type
|
||||||
|
ground_truth[level * 2 + 1] = normalized_price * confidence_factor # Level N price
|
||||||
|
|
||||||
|
logger.debug(f"CNN Ground Truth: Level 0 = [{level_0_type}, {normalized_price:.4f}], "
|
||||||
|
f"Current pivot = {actual_current_pivot.swing_type.name} @ {actual_current_pivot.price}")
|
||||||
|
|
||||||
|
return ground_truth
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error calculating CNN ground truth: {e}", exc_info=True)
|
||||||
|
return np.zeros(10, dtype=np.float32)
|
146
web/dashboard.py
146
web/dashboard.py
@ -321,6 +321,19 @@ class TradingDashboard:
|
|||||||
logger.info("Trading Dashboard initialized with enhanced RL training integration")
|
logger.info("Trading Dashboard initialized with enhanced RL training integration")
|
||||||
logger.info(f"Enhanced RL enabled: {self.enhanced_rl_training_enabled}")
|
logger.info(f"Enhanced RL enabled: {self.enhanced_rl_training_enabled}")
|
||||||
logger.info(f"Stream consumer ID: {self.stream_consumer_id}")
|
logger.info(f"Stream consumer ID: {self.stream_consumer_id}")
|
||||||
|
|
||||||
|
# Initialize Williams Market Structure once
|
||||||
|
try:
|
||||||
|
from training.williams_market_structure import WilliamsMarketStructure
|
||||||
|
self.williams_structure = WilliamsMarketStructure(
|
||||||
|
swing_strengths=[2, 3, 5], # Simplified for better performance
|
||||||
|
enable_cnn_feature=False, # Disable CNN until TensorFlow available
|
||||||
|
training_data_provider=None
|
||||||
|
)
|
||||||
|
logger.info("Williams Market Structure initialized for dashboard")
|
||||||
|
except ImportError:
|
||||||
|
self.williams_structure = None
|
||||||
|
logger.warning("Williams Market Structure not available")
|
||||||
|
|
||||||
def _to_local_timezone(self, dt: datetime) -> datetime:
|
def _to_local_timezone(self, dt: datetime) -> datetime:
|
||||||
"""Convert datetime to configured local timezone"""
|
"""Convert datetime to configured local timezone"""
|
||||||
@ -4532,32 +4545,49 @@ class TradingDashboard:
|
|||||||
logger.error(f"Error stopping streaming: {e}")
|
logger.error(f"Error stopping streaming: {e}")
|
||||||
|
|
||||||
def _get_williams_pivot_features(self, df: pd.DataFrame) -> Optional[List[float]]:
|
def _get_williams_pivot_features(self, df: pd.DataFrame) -> Optional[List[float]]:
|
||||||
"""Calculate Williams Market Structure pivot points features"""
|
"""Get Williams Market Structure pivot features for RL training"""
|
||||||
try:
|
try:
|
||||||
# Import Williams Market Structure
|
# Use reused Williams instance
|
||||||
try:
|
if not self.williams_structure:
|
||||||
from training.williams_market_structure import WilliamsMarketStructure
|
|
||||||
except ImportError:
|
|
||||||
logger.warning("Williams Market Structure not available")
|
logger.warning("Williams Market Structure not available")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Convert DataFrame to numpy array for Williams calculation
|
# Convert DataFrame to numpy array for Williams calculation
|
||||||
if len(df) < 50:
|
if len(df) < 20: # Reduced from 50 to match Williams minimum requirement
|
||||||
|
logger.debug(f"[WILLIAMS] Insufficient data for pivot calculation: {len(df)} bars (need 20+)")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
ohlcv_array = np.array([
|
try:
|
||||||
[self._to_local_timezone(df.index[i]).timestamp() if hasattr(df.index[i], 'timestamp') else time.time(),
|
ohlcv_array = np.array([
|
||||||
df['open'].iloc[i], df['high'].iloc[i], df['low'].iloc[i],
|
[self._to_local_timezone(df.index[i]).timestamp() if hasattr(df.index[i], 'timestamp') else time.time(),
|
||||||
df['close'].iloc[i], df['volume'].iloc[i]]
|
df['open'].iloc[i], df['high'].iloc[i], df['low'].iloc[i],
|
||||||
for i in range(len(df))
|
df['close'].iloc[i], df['volume'].iloc[i]]
|
||||||
])
|
for i in range(len(df))
|
||||||
|
])
|
||||||
|
|
||||||
|
logger.debug(f"[WILLIAMS] Prepared OHLCV array: {ohlcv_array.shape}, price range: {ohlcv_array[:, 4].min():.2f} - {ohlcv_array[:, 4].max():.2f}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[WILLIAMS] Error preparing OHLCV array: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
# Calculate Williams pivot points
|
# Calculate Williams pivot points with reused instance
|
||||||
williams = WilliamsMarketStructure()
|
try:
|
||||||
structure_levels = williams.calculate_recursive_pivot_points(ohlcv_array)
|
structure_levels = self.williams_structure.calculate_recursive_pivot_points(ohlcv_array)
|
||||||
|
|
||||||
|
# Add diagnostics for debugging
|
||||||
|
total_pivots = sum(len(level.swing_points) for level in structure_levels.values())
|
||||||
|
if total_pivots == 0:
|
||||||
|
logger.debug(f"[WILLIAMS] No pivot points detected in {len(ohlcv_array)} bars")
|
||||||
|
else:
|
||||||
|
logger.debug(f"[WILLIAMS] Successfully detected {total_pivots} pivot points across {len([l for l in structure_levels.values() if len(l.swing_points) > 0])} levels")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[WILLIAMS] Error in pivot calculation: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
# Extract features (250 features total)
|
# Extract features (250 features total)
|
||||||
pivot_features = williams.extract_features_for_rl(structure_levels)
|
pivot_features = self.williams_structure.extract_features_for_rl(structure_levels)
|
||||||
|
|
||||||
logger.debug(f"[PIVOT] Calculated {len(pivot_features)} Williams pivot features")
|
logger.debug(f"[PIVOT] Calculated {len(pivot_features)} Williams pivot features")
|
||||||
return pivot_features
|
return pivot_features
|
||||||
@ -4795,40 +4825,66 @@ class TradingDashboard:
|
|||||||
logger.warning("Williams Market Structure not available for chart")
|
logger.warning("Williams Market Structure not available for chart")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Need at least 50 bars for meaningful pivot calculation
|
# Reduced requirement to match Williams minimum
|
||||||
if len(df) < 50:
|
if len(df) < 20:
|
||||||
logger.debug(f"[WILLIAMS] Insufficient data for pivot calculation: {len(df)} bars")
|
logger.debug(f"[WILLIAMS_CHART] Insufficient data for pivot calculation: {len(df)} bars (need 20+)")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Ensure timezone consistency for the chart data
|
# Ensure timezone consistency for the chart data
|
||||||
df = self._ensure_timezone_consistency(df)
|
df = self._ensure_timezone_consistency(df)
|
||||||
|
|
||||||
# Convert DataFrame to numpy array for Williams calculation with proper timezone handling
|
# Convert DataFrame to numpy array for Williams calculation with proper timezone handling
|
||||||
ohlcv_array = []
|
try:
|
||||||
for i in range(len(df)):
|
ohlcv_array = []
|
||||||
timestamp = df.index[i]
|
for i in range(len(df)):
|
||||||
|
timestamp = df.index[i]
|
||||||
|
|
||||||
|
# Convert timestamp to local timezone and then to Unix timestamp
|
||||||
|
if hasattr(timestamp, 'timestamp'):
|
||||||
|
local_time = self._to_local_timezone(timestamp)
|
||||||
|
unix_timestamp = local_time.timestamp()
|
||||||
|
else:
|
||||||
|
unix_timestamp = time.time()
|
||||||
|
|
||||||
|
ohlcv_array.append([
|
||||||
|
unix_timestamp,
|
||||||
|
df['open'].iloc[i],
|
||||||
|
df['high'].iloc[i],
|
||||||
|
df['low'].iloc[i],
|
||||||
|
df['close'].iloc[i],
|
||||||
|
df['volume'].iloc[i]
|
||||||
|
])
|
||||||
|
|
||||||
# Convert timestamp to local timezone and then to Unix timestamp
|
ohlcv_array = np.array(ohlcv_array)
|
||||||
if hasattr(timestamp, 'timestamp'):
|
logger.debug(f"[WILLIAMS_CHART] Prepared OHLCV array: {ohlcv_array.shape}, price range: {ohlcv_array[:, 4].min():.2f} - {ohlcv_array[:, 4].max():.2f}")
|
||||||
local_time = self._to_local_timezone(timestamp)
|
|
||||||
unix_timestamp = local_time.timestamp()
|
except Exception as e:
|
||||||
|
logger.warning(f"[WILLIAMS_CHART] Error preparing OHLCV array: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Calculate Williams pivot points with proper configuration
|
||||||
|
try:
|
||||||
|
williams = WilliamsMarketStructure(
|
||||||
|
swing_strengths=[2, 3, 5], # Start with simpler strengths
|
||||||
|
enable_cnn_feature=False, # Disable CNN for chart display
|
||||||
|
training_data_provider=None # No training data provider needed for chart
|
||||||
|
)
|
||||||
|
|
||||||
|
structure_levels = williams.calculate_recursive_pivot_points(ohlcv_array)
|
||||||
|
|
||||||
|
# Add diagnostics for debugging
|
||||||
|
total_pivots_detected = sum(len(level.swing_points) for level in structure_levels.values())
|
||||||
|
if total_pivots_detected == 0:
|
||||||
|
logger.warning(f"[WILLIAMS_CHART] No pivot points detected in {len(ohlcv_array)} bars for chart display")
|
||||||
|
price_volatility = np.std(ohlcv_array[:, 4]) / np.mean(ohlcv_array[:, 4]) if np.mean(ohlcv_array[:, 4]) > 0 else 0.0
|
||||||
|
logger.debug(f"[WILLIAMS_CHART] Data diagnostics: volatility={price_volatility:.4f}, time_span={ohlcv_array[-1, 0] - ohlcv_array[0, 0]:.0f}s")
|
||||||
|
return None
|
||||||
else:
|
else:
|
||||||
unix_timestamp = time.time()
|
logger.debug(f"[WILLIAMS_CHART] Successfully detected {total_pivots_detected} pivot points for chart")
|
||||||
|
|
||||||
ohlcv_array.append([
|
except Exception as e:
|
||||||
unix_timestamp,
|
logger.warning(f"[WILLIAMS_CHART] Error in pivot calculation: {e}")
|
||||||
df['open'].iloc[i],
|
return None
|
||||||
df['high'].iloc[i],
|
|
||||||
df['low'].iloc[i],
|
|
||||||
df['close'].iloc[i],
|
|
||||||
df['volume'].iloc[i]
|
|
||||||
])
|
|
||||||
|
|
||||||
ohlcv_array = np.array(ohlcv_array)
|
|
||||||
|
|
||||||
# Calculate Williams pivot points
|
|
||||||
williams = WilliamsMarketStructure()
|
|
||||||
structure_levels = williams.calculate_recursive_pivot_points(ohlcv_array)
|
|
||||||
|
|
||||||
# Extract pivot points for chart display
|
# Extract pivot points for chart display
|
||||||
chart_pivots = {}
|
chart_pivots = {}
|
||||||
@ -4846,7 +4902,7 @@ class TradingDashboard:
|
|||||||
# Log swing point details for validation
|
# Log swing point details for validation
|
||||||
highs = [s for s in swing_points if s.swing_type.name == 'SWING_HIGH']
|
highs = [s for s in swing_points if s.swing_type.name == 'SWING_HIGH']
|
||||||
lows = [s for s in swing_points if s.swing_type.name == 'SWING_LOW']
|
lows = [s for s in swing_points if s.swing_type.name == 'SWING_LOW']
|
||||||
logger.debug(f"[WILLIAMS] Level {level}: {len(highs)} highs, {len(lows)} lows, total: {len(swing_points)}")
|
logger.debug(f"[WILLIAMS_CHART] Level {level}: {len(highs)} highs, {len(lows)} lows, total: {len(swing_points)}")
|
||||||
|
|
||||||
# Convert swing points to chart format
|
# Convert swing points to chart format
|
||||||
chart_pivots[f'level_{level}'] = {
|
chart_pivots[f'level_{level}'] = {
|
||||||
@ -4858,11 +4914,11 @@ class TradingDashboard:
|
|||||||
}
|
}
|
||||||
total_pivots += len(swing_points)
|
total_pivots += len(swing_points)
|
||||||
|
|
||||||
logger.info(f"[WILLIAMS] Calculated {total_pivots} total pivot points across {len(chart_pivots)} levels")
|
logger.info(f"[WILLIAMS_CHART] Calculated {total_pivots} total pivot points across {len(chart_pivots)} levels")
|
||||||
return chart_pivots
|
return chart_pivots
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error calculating Williams pivot points for chart: {e}")
|
logger.warning(f"Error calculating Williams pivot points: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _add_williams_pivot_points_to_chart(self, fig, pivot_points: Dict, row: int = 1):
|
def _add_williams_pivot_points_to_chart(self, fig, pivot_points: Dict, row: int = 1):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user