# Enhanced Reward System for Reinforcement Learning Training

## Overview

This document describes the implementation of an enhanced reward system for your reinforcement learning trading models. The system uses **mean squared error (MSE) between predictions and empirical outcomes** as the primary reward mechanism, with support for multiple timeframes and comprehensive accuracy tracking.

## Key Features

### ✅ MSE-Based Reward Calculation

- Uses mean squared difference between predicted and actual prices
- Exponential decay function heavily penalizes large prediction errors
- Direction accuracy bonus/penalty system
- Confidence-weighted final rewards

### ✅ Multi-Timeframe Support

- Separate tracking for **1s, 1m, 1h, 1d** timeframes
- Independent accuracy metrics for each timeframe
- Timeframe-specific evaluation timeouts
- Models know which timeframe they're predicting on

### ✅ Prediction History Tracking

- Maintains last **6 predictions per timeframe** per symbol
- Comprehensive prediction records with outcomes
- Historical accuracy analysis
- Memory-efficient with automatic cleanup (see the sketch below)
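
The bounded per-timeframe history maps naturally onto `collections.deque`. A minimal sketch of the idea follows; the structure and names are illustrative assumptions, not the module's actual API:

```python
from collections import deque

# Keep only the documented limit of 6 predictions per timeframe (illustrative).
MAX_PREDICTIONS_PER_TIMEFRAME = 6

prediction_history = {
    'ETH/USDT': {
        tf: deque(maxlen=MAX_PREDICTIONS_PER_TIMEFRAME)
        for tf in ('1s', '1m', '1h', '1d')
    }
}

# Old entries fall off automatically once the deque is full.
prediction_history['ETH/USDT']['1s'].append({
    'predicted_price': 3150.50,
    'direction': 1,          # -1 down, 0 neutral, 1 up
    'confidence': 0.85,
    'outcome': None,         # filled in after the evaluation timeout
})
```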

### ✅ Real-Time Training

- Training triggered at each inference when outcomes are available
- Separate training batches for each model and timeframe
- Automatic evaluation of predictions after appropriate timeouts
- Integration with existing RL training infrastructure

### ✅ Enhanced Inference Scheduling

- **Continuous inference** every 1-5 seconds on primary timeframe
- **Hourly multi-timeframe inference** (4 predictions per hour, one for each timeframe)
- Timeframe-aware inference context
- Proper scheduling and coordination

## Architecture

```mermaid
graph TD
    A[Market Data] --> B[Timeframe Inference Coordinator]
    B --> C[Model Inference]
    C --> D[Enhanced Reward Calculator]
    D --> E[Prediction Tracking]
    E --> F[Outcome Evaluation]
    F --> G[MSE Reward Calculation]
    G --> H[Enhanced RL Training Adapter]
    H --> I[Model Training]
    I --> J[Performance Monitoring]
```

## Core Components

### 1. EnhancedRewardCalculator (`core/enhanced_reward_calculator.py`)

**Purpose**: Central reward calculation engine using MSE methodology

**Key Methods**:
- `add_prediction()` - Track new predictions
- `evaluate_predictions()` - Calculate rewards when outcomes available
- `get_accuracy_summary()` - Comprehensive accuracy metrics
- `get_training_data()` - Extract training samples for models

**Reward Formula**:
```python
from math import exp

# MSE calculation
price_error = actual_price - predicted_price
mse = price_error ** 2

# Normalize to reasonable scale
max_mse = (current_price * 0.1) ** 2  # 10% as max expected error
normalized_mse = min(mse / max_mse, 1.0)

# Exponential decay (heavily penalizes large errors)
mse_reward = exp(-5 * normalized_mse)  # Range: [exp(-5), 1]

# Direction bonus/penalty
direction_bonus = 0.5 if direction_correct else -0.5

# Final reward (confidence weighted)
final_reward = (mse_reward + direction_bonus) * confidence
```
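
Plugging sample numbers through the formula gives a feel for its behavior. A small worked example (the values are illustrative, not taken from the module):

```python
from math import exp

predicted_price, actual_price, current_price = 3150.50, 3153.00, 3150.00
confidence, direction_correct = 0.85, True  # predicted up, price moved up

mse = (actual_price - predicted_price) ** 2          # 6.25
max_mse = (current_price * 0.1) ** 2                 # 99225.0
mse_reward = exp(-5 * min(mse / max_mse, 1.0))       # ~0.9997 (tiny error)
direction_bonus = 0.5 if direction_correct else -0.5
final_reward = (mse_reward + direction_bonus) * confidence
print(round(final_reward, 4))                        # ~1.2747
```

A wildly wrong price drives `mse_reward` toward `exp(-5) ≈ 0.0067`, so a confident, directionally wrong prediction nets a strongly negative reward.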

### 2. TimeframeInferenceCoordinator (`core/timeframe_inference_coordinator.py`)

**Purpose**: Coordinates timeframe-aware model inference with proper scheduling

**Key Features**:
- **Continuous inference loop** for each symbol (every 5 seconds)
- **Hourly multi-timeframe scheduler** (4 predictions per hour)
- **Inference context management** (models know target timeframe)
- **Automatic reward evaluation** and training triggers

**Scheduling**:
- **Every 5 seconds**: Inference on primary timeframe (1s)
- **Every hour**: One inference for each timeframe (1s, 1m, 1h, 1d)
- **Evaluation timeouts**: 5s for 1s predictions, 60s for 1m, 300s for 1h, 900s for 1d

### 3. EnhancedRLTrainingAdapter (`core/enhanced_rl_training_adapter.py`)

**Purpose**: Bridge between new reward system and existing RL training infrastructure

**Key Features**:
- **Model inference wrappers** for DQN, COB RL, and CNN models
- **Training batch creation** from prediction records and rewards
- **Real-time training triggers** based on evaluation results
- **Backward compatibility** with existing training systems

### 4. EnhancedRewardSystemIntegration (`core/enhanced_reward_system_integration.py`)

**Purpose**: Simple integration point for existing systems

**Key Features**:
- **One-line integration** with existing TradingOrchestrator
- **Helper functions** for easy prediction tracking
- **Comprehensive monitoring** and statistics
- **Minimal code changes** required

## Integration Guide

### Step 1: Import Required Components

Add to your `orchestrator.py`:

```python
from core.enhanced_reward_system_integration import (
    integrate_enhanced_rewards,
    add_prediction_to_enhanced_rewards
)
```

### Step 2: Initialize in TradingOrchestrator

In your `TradingOrchestrator.__init__()`:

```python
# Add this line after existing initialization
integrate_enhanced_rewards(self, symbols=['ETH/USDT', 'BTC/USDT'])
```

### Step 3: Start the System

In your `TradingOrchestrator.run()` method:

```python
# Add this line after initialization
await self.enhanced_reward_system.start_integration()
```

### Step 4: Track Predictions

In your model inference methods (CNN, DQN, COB RL):

```python
# Example in CNN inference
prediction_id = add_prediction_to_enhanced_rewards(
    self,             # orchestrator instance
    symbol,           # 'ETH/USDT'
    timeframe,        # '1s', '1m', '1h', '1d'
    predicted_price,  # model's price prediction
    direction,        # -1 (down), 0 (neutral), 1 (up)
    confidence,       # 0.0 to 1.0
    current_price,    # current market price
    'enhanced_cnn'    # model name
)
```

### Step 5: Monitor Performance

```python
# Get comprehensive statistics
stats = self.enhanced_reward_system.get_integration_statistics()
accuracy = self.enhanced_reward_system.get_model_accuracy()

# Force evaluation for testing
self.enhanced_reward_system.force_evaluation_and_training('ETH/USDT', '1s')
```

## Usage Example

See `examples/enhanced_reward_system_example.py` for a complete demonstration.

```bash
python examples/enhanced_reward_system_example.py
```

## Performance Benefits

### 🎯 Better Accuracy Measurement

- **MSE rewards** provide nuanced feedback vs. simple directional accuracy
- **Price prediction accuracy** measured alongside direction accuracy
- **Confidence-weighted rewards** encourage well-calibrated predictions

### 📊 Multi-Timeframe Intelligence

- **Separate tracking** prevents timeframe confusion
- **Timeframe-specific evaluation** accounts for different market dynamics
- **Comprehensive accuracy picture** across all prediction horizons

### ⚡ Real-Time Learning

- **Immediate training** when prediction outcomes are available
- **No batch delays** - models learn from every prediction
- **Adaptive training frequency** based on prediction evaluation

### 🔄 Enhanced Inference Scheduling

- **Optimal prediction frequency** balances real-time response with computational efficiency
- **Hourly multi-timeframe predictions** provide comprehensive market coverage
- **Context-aware models** make better predictions knowing their target timeframe

## Configuration

### Evaluation Timeouts (Configurable in EnhancedRewardCalculator)

```python
evaluation_timeouts = {
    TimeFrame.SECONDS_1: 5,    # Evaluate 1s predictions after 5 seconds
    TimeFrame.MINUTES_1: 60,   # Evaluate 1m predictions after 1 minute
    TimeFrame.HOURS_1: 300,    # Evaluate 1h predictions after 5 minutes
    TimeFrame.DAYS_1: 900      # Evaluate 1d predictions after 15 minutes
}
```
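
The `TimeFrame` enum itself lives in the calculator module; a plausible sketch of it, included here only for readability (the member values are an assumption):

```python
from enum import Enum

class TimeFrame(Enum):
    SECONDS_1 = '1s'
    MINUTES_1 = '1m'
    HOURS_1 = '1h'
    DAYS_1 = '1d'
```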

### Inference Scheduling (Configurable in TimeframeInferenceCoordinator)

```python
schedule = InferenceSchedule(
    continuous_interval_seconds=5.0,  # Continuous inference every 5 seconds
    hourly_timeframes=[TimeFrame.SECONDS_1, TimeFrame.MINUTES_1,
                       TimeFrame.HOURS_1, TimeFrame.DAYS_1]
)
```

### Training Configuration (Configurable in EnhancedRLTrainingAdapter)

```python
min_batch_size = 8               # Minimum samples for training
max_batch_size = 64              # Maximum samples per training batch
training_interval_seconds = 5.0  # Training check frequency
```

## Monitoring and Statistics

### Integration Statistics

```python
stats = enhanced_reward_system.get_integration_statistics()
```

Returns (a hypothetical example of the shape follows below):
- System running status
- Total predictions tracked
- Component status
- Inference and training statistics
- Performance metrics
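
For orientation only, the returned dictionary might look roughly like this; the actual keys are defined by the module:

```python
stats = {
    'running': True,
    'total_predictions': 1423,
    'components': {
        'reward_calculator': 'active',
        'inference_coordinator': 'active',
        'training_adapter': 'active',
    },
    'training': {'batches_run': 87, 'last_loss': 0.0213},
}
```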

### Model Accuracy

```python
accuracy = enhanced_reward_system.get_model_accuracy()
```

Returns for each symbol and timeframe:
- Total predictions made
- Direction accuracy percentage
- Average MSE
- Recent prediction count

### Real-Time Monitoring

The system provides comprehensive logging at different levels:
- `INFO`: Major system events, training results
- `DEBUG`: Detailed prediction tracking, reward calculations
- `ERROR`: System errors and recovery actions

## Backward Compatibility

The enhanced reward system is designed to be **fully backward compatible**:

✅ **Existing models continue to work** without modification
✅ **Existing training systems** remain functional
✅ **Existing reward calculations** can run in parallel
✅ **Gradual migration** - enable for specific models incrementally

## Testing and Validation

### Force Evaluation for Testing

```python
# Force immediate evaluation of all predictions
enhanced_reward_system.force_evaluation_and_training()

# Force evaluation for specific symbol/timeframe
enhanced_reward_system.force_evaluation_and_training('ETH/USDT', '1s')
```

### Manual Prediction Addition

```python
# Add predictions manually for testing
prediction_id = enhanced_reward_system.add_prediction_manually(
    symbol='ETH/USDT',
    timeframe_str='1s',
    predicted_price=3150.50,
    predicted_direction=1,
    confidence=0.85,
    current_price=3150.00,
    model_name='test_model'
)
```

## Memory Management

The system includes automatic memory management:

- **Automatic prediction cleanup** (configurable retention period)
- **Circular buffers** for prediction history (max 100 per timeframe)
- **Price cache management** (max 1000 price points per symbol)
- **Efficient storage** using deques and compressed data structures

## Future Enhancements

The architecture supports easy extension for:

1. **Additional timeframes** (30s, 5m, 15m, etc.)
2. **Custom reward functions** (Sharpe ratio, maximum drawdown, etc.) - sketched below
3. **Multi-symbol correlation** rewards
4. **Advanced statistical metrics** (Sortino ratio, Calmar ratio)
5. **Model ensemble** reward aggregation
6. **A/B testing** framework for reward functions
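
As an illustration of item 2, a reward based on a rolling Sharpe ratio could plug into the same evaluation path. This is a hypothetical sketch under that assumption, not part of the current system:

```python
import numpy as np

def sharpe_reward(recent_returns: list[float], risk_free_rate: float = 0.0) -> float:
    """Hypothetical custom reward: per-period Sharpe ratio (annualization omitted)."""
    r = np.asarray(recent_returns, dtype=np.float64)
    if r.size < 2 or r.std() == 0.0:
        return 0.0  # not enough data, or zero variance
    return float((r.mean() - risk_free_rate) / r.std())
```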

## Conclusion

The Enhanced Reward System provides a comprehensive foundation for improving RL model training through:

- **Precise MSE-based rewards** that accurately measure prediction quality
- **Multi-timeframe intelligence** that prevents confusion between different prediction horizons
- **Real-time learning** that maximizes training opportunities
- **Easy integration** that requires minimal changes to existing code
- **Comprehensive monitoring** that provides insights into model performance

This system addresses the specific requirements you outlined:

✅ MSE-based accuracy calculation
✅ Training at each inference using last prediction vs. current outcome
✅ Separate accuracy tracking for up to 6 last predictions per timeframe
✅ Models know which timeframe they're predicting on
✅ Hourly multi-timeframe inference (4 predictions per hour)
✅ Integration with existing 1-5 second inference frequency

# RL Training Pipeline Audit and Improvements

## Current State Analysis

### 1. Existing RL Training Components

**Current Architecture:**
- **EnhancedDQNAgent**: Main RL agent with dueling DQN architecture
- **EnhancedRLTrainer**: Training coordinator with prioritized experience replay
- **PrioritizedReplayBuffer**: Experience replay with priority sampling
- **RLTrainer**: Basic training pipeline for scalping scenarios

**Current Data Input Structure:**
```python
# Current MarketState in enhanced_orchestrator.py
@dataclass
class MarketState:
    symbol: str
    timestamp: datetime
    prices: Dict[str, float]         # {timeframe: current_price}
    features: Dict[str, np.ndarray]  # {timeframe: feature_matrix}
    volatility: float
    volume: float
    trend_strength: float
    market_regime: str               # 'trending', 'ranging', 'volatile'
    universal_data: UniversalDataStream
```

**Current State Conversion:**
- Limited to basic market metrics (volatility, volume, trend)
- Missing tick-level features
- No multi-symbol correlation data
- No CNN hidden layer integration
- Incomplete implementation of the required data format

## Critical Issues Identified

### 1. **Insufficient Data Input (CRITICAL)**

**Current Problem:** The RL model only receives basic market metrics and is missing required data:
- ❌ 300s of raw tick data for momentum detection
- ❌ Multi-timeframe OHLCV (1s, 1m, 1h, 1d) for both ETH and BTC
- ❌ CNN hidden layer features
- ❌ CNN predictions from all timeframes
- ❌ Pivot point predictions

**Required Input per Specification:**
```
ETH:
- 300s max of raw ticks data (detecting single big moves and momentum)
- 300s of 1s OHLCV data (5 min)
- 300 OHLCV + indicators bars of each 1m 1h 1d and 1s BTC

RL model should have access to:
- Last hidden layers of the CNN model where patterns are learned
- CNN output (predictions) for each timeframe (1s 1m 1h 1d)
- Next expected pivot point predictions
```

### 2. **Inadequate State Representation**

**Current Issues:**
- State size fixed at 100 features (too small)
- No standardization/normalization (see the sketch after this list)
- Missing temporal sequence information
- No multi-symbol context
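
For the normalization gap, a per-state z-score pass is the simplest fix. A minimal sketch; in practice the planned EnhancedRLStateBuilder would maintain running per-feature statistics instead of per-vector ones:

```python
import numpy as np

def standardize_state(state: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Z-score normalize a state vector; eps guards against zero variance."""
    return (state - state.mean()) / (state.std() + eps)
```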

### 3. **Training Pipeline Limitations**
- No real-time tick processing integration
- Missing CNN feature integration
- Limited reward engineering
- No market regime-specific training

### 4. **Missing Pivot Point Integration**
- No pivot point calculation system
- No recursive trend analysis
- Missing Williams market structure implementation

## Comprehensive Improvement Plan

### Phase 1: Enhanced State Representation

#### 1.1 Create Comprehensive State Builder
```python
class EnhancedRLStateBuilder:
    """Build comprehensive RL state from all available data sources"""

    def __init__(self, config):
        self.tick_window = 300   # 300s of ticks
        self.ohlcv_window = 300  # 300 1s bars
        self.state_components = {
            'eth_ticks': 300 * 10,     # ~10 features per tick
            'eth_1s_ohlcv': 300 * 8,   # OHLCV + indicators
            'eth_1m_ohlcv': 300 * 8,   # 300 1m bars
            'eth_1h_ohlcv': 300 * 8,   # 300 1h bars
            'eth_1d_ohlcv': 300 * 8,   # 300 1d bars
            'btc_reference': 300 * 8,  # BTC reference data
            'cnn_features': 512,       # CNN hidden layer features
            'cnn_predictions': 16,     # CNN predictions (4 timeframes * 4 outputs)
            'pivot_points': 50,        # Recursive pivot points
            'market_regime': 10        # Market regime features
        }
        self.total_state_size = sum(self.state_components.values())  # ~8000+ features
```

#### 1.2 Multi-Symbol Data Integration
```python
def build_rl_state(self, universal_stream: UniversalDataStream,
                   cnn_hidden_features: Dict = None,
                   cnn_predictions: Dict = None) -> np.ndarray:
    """Build comprehensive RL state vector"""

    state_vector = []

    # 1. ETH tick data (300s window)
    eth_tick_features = self._process_tick_data(
        universal_stream.eth_ticks, window_size=300
    )
    state_vector.extend(eth_tick_features)

    # 2. ETH multi-timeframe OHLCV
    for timeframe in ['1s', '1m', '1h', '1d']:
        ohlcv_features = self._process_ohlcv_data(
            getattr(universal_stream, f'eth_{timeframe}'),
            timeframe=timeframe,
            window_size=300
        )
        state_vector.extend(ohlcv_features)

    # 3. BTC reference data
    btc_features = self._process_btc_reference(universal_stream.btc_ticks)
    state_vector.extend(btc_features)

    # 4. CNN hidden layer features (zero-padded when unavailable)
    if cnn_hidden_features:
        cnn_hidden = self._process_cnn_hidden_features(cnn_hidden_features)
        state_vector.extend(cnn_hidden)
    else:
        state_vector.extend([0.0] * self.state_components['cnn_features'])

    # 5. CNN predictions (zero-padded when unavailable)
    if cnn_predictions:
        cnn_pred = self._process_cnn_predictions(cnn_predictions)
        state_vector.extend(cnn_pred)
    else:
        state_vector.extend([0.0] * self.state_components['cnn_predictions'])

    # 6. Pivot points
    pivot_features = self._calculate_recursive_pivot_points(universal_stream)
    state_vector.extend(pivot_features)

    # 7. Market regime features
    regime_features = self._extract_market_regime_features(universal_stream)
    state_vector.extend(regime_features)

    return np.array(state_vector, dtype=np.float32)
```

### Phase 2: Pivot Point System Implementation

#### 2.1 Williams Market Structure Pivot Points
```python
class WilliamsMarketStructure:
    """Implementation of Larry Williams market structure analysis"""

    def calculate_recursive_pivot_points(self, ohlcv_data: np.ndarray) -> Dict:
        """Calculate 5 levels of recursive pivot points"""

        levels = {}
        current_data = ohlcv_data

        for level in range(5):
            # Find swing highs and lows
            swing_points = self._find_swing_points(current_data)

            # Determine trend direction
            trend_direction = self._determine_trend_direction(swing_points)

            levels[f'level_{level}'] = {
                'swing_points': swing_points,
                'trend_direction': trend_direction,
                'trend_strength': self._calculate_trend_strength(swing_points)
            }

            # Use swing points as input for the next level
            if len(swing_points) >= 5:
                current_data = self._convert_swings_to_ohlcv(swing_points)
            else:
                break

        return levels

    def _find_swing_points(self, ohlcv_data: np.ndarray) -> List[Dict]:
        """Find swing points: a swing high has lower highs on both sides,
        a swing low has higher lows on both sides"""
        swing_points = []

        for i in range(2, len(ohlcv_data) - 2):
            current_high = ohlcv_data[i, 2]  # High price
            current_low = ohlcv_data[i, 3]   # Low price

            # Check for swing high (lower highs on both sides)
            if (current_high > ohlcv_data[i-1, 2] and
                current_high > ohlcv_data[i-2, 2] and
                current_high > ohlcv_data[i+1, 2] and
                current_high > ohlcv_data[i+2, 2]):

                swing_points.append({
                    'type': 'swing_high',
                    'timestamp': ohlcv_data[i, 0],
                    'price': current_high,
                    'index': i
                })

            # Check for swing low (higher lows on both sides)
            if (current_low < ohlcv_data[i-1, 3] and
                current_low < ohlcv_data[i-2, 3] and
                current_low < ohlcv_data[i+1, 3] and
                current_low < ohlcv_data[i+2, 3]):

                swing_points.append({
                    'type': 'swing_low',
                    'timestamp': ohlcv_data[i, 0],
                    'price': current_low,
                    'index': i
                })

        return swing_points
```

### Phase 3: CNN Integration Layer

#### 3.1 CNN-RL Bridge
```python
class CNNRLBridge:
    """Bridge between CNN and RL models for feature sharing"""

    def __init__(self, cnn_models: Dict, rl_agents: Dict):
        self.cnn_models = cnn_models
        self.rl_agents = rl_agents
        self.feature_cache = {}

    async def extract_cnn_features_for_rl(self, universal_stream: UniversalDataStream) -> Dict:
        """Extract CNN hidden layer features and predictions for RL"""

        cnn_features = {
            'hidden_features': {},
            'predictions': {},
            'confidences': {}
        }

        for timeframe in ['1s', '1m', '1h', '1d']:
            if timeframe in self.cnn_models:
                model = self.cnn_models[timeframe]

                # Get input data for this timeframe
                timeframe_data = getattr(universal_stream, f'eth_{timeframe}')

                if len(timeframe_data) > 0:
                    # Extract hidden layer features
                    hidden_features = await self._extract_hidden_features(
                        model, timeframe_data
                    )
                    cnn_features['hidden_features'][timeframe] = hidden_features

                    # Get predictions
                    predictions, confidence = await model.predict(timeframe_data)
                    cnn_features['predictions'][timeframe] = predictions
                    cnn_features['confidences'][timeframe] = confidence

        return cnn_features

    async def _extract_hidden_features(self, model, data: np.ndarray) -> np.ndarray:
        """Extract hidden layer features from a CNN model"""
        try:
            # Hook into the model's hidden layers
            activation = {}

            def get_activation(name):
                def hook(model, input, output):
                    activation[name] = output.detach()
                return hook

            # Register a hook on the last hidden layer before the output
            handle = model.fc_hidden.register_forward_hook(get_activation('hidden'))

            # Forward pass
            with torch.no_grad():
                _ = model(torch.FloatTensor(data).unsqueeze(0))

            # Remove the hook
            handle.remove()

            # Return flattened hidden features
            if 'hidden' in activation:
                return activation['hidden'].cpu().numpy().flatten()
            else:
                return np.zeros(512)  # Default size

        except Exception as e:
            logger.error(f"Error extracting CNN hidden features: {e}")
            return np.zeros(512)
```

### Phase 4: Enhanced Training Pipeline

#### 4.1 Multi-Modal Training Loop
```python
class EnhancedRLTrainingPipeline:
    """Comprehensive RL training with all required data inputs"""

    def __init__(self, config):
        self.config = config
        self.state_builder = EnhancedRLStateBuilder(config)
        self.pivot_calculator = WilliamsMarketStructure()
        self.cnn_rl_bridge = CNNRLBridge(config.cnn_models, config.rl_agents)

        # Enhanced DQN with larger state space
        self.agent = EnhancedDQNAgent({
            'state_size': self.state_builder.total_state_size,  # ~8000+ features
            'action_space': 3,
            'hidden_size': 1024,      # Larger hidden layers
            'learning_rate': 0.0001,
            'gamma': 0.99,
            'buffer_size': 50000,     # Larger replay buffer
            'batch_size': 128
        })

    async def training_step(self, universal_stream: UniversalDataStream):
        """Single training step with comprehensive data"""

        # 1. Extract CNN features and predictions
        cnn_data = await self.cnn_rl_bridge.extract_cnn_features_for_rl(universal_stream)

        # 2. Build comprehensive RL state
        current_state = self.state_builder.build_rl_state(
            universal_stream=universal_stream,
            cnn_hidden_features=cnn_data['hidden_features'],
            cnn_predictions=cnn_data['predictions']
        )

        # 3. Agent action selection
        action = self.agent.act(current_state)

        # 4. Execute action and get reward
        reward, next_universal_stream = await self._execute_action_and_get_reward(
            action, universal_stream
        )

        # 5. Build next state
        next_cnn_data = await self.cnn_rl_bridge.extract_cnn_features_for_rl(
            next_universal_stream
        )
        next_state = self.state_builder.build_rl_state(
            universal_stream=next_universal_stream,
            cnn_hidden_features=next_cnn_data['hidden_features'],
            cnn_predictions=next_cnn_data['predictions']
        )

        # 6. Store experience
        self.agent.remember(
            state=current_state,
            action=action,
            reward=reward,
            next_state=next_state,
            done=False
        )

        # 7. Train if enough experiences
        if len(self.agent.replay_buffer) > self.agent.batch_size:
            loss = self.agent.replay()
            return {'loss': loss, 'reward': reward, 'action': action}

        return {'reward': reward, 'action': action}
```

#### 4.2 Enhanced Reward Engineering
```python
class EnhancedRewardCalculator:
    """Sophisticated reward calculation considering multiple factors"""

    def calculate_reward(self, action: int, market_data_before: Dict,
                         market_data_after: Dict, trade_outcome: float = None) -> float:
        """Calculate multi-factor reward"""

        base_reward = 0.0

        # 1. Price movement reward
        if trade_outcome is not None:
            # Direct trading outcome
            base_reward += trade_outcome * 10  # Scale P&L
        else:
            # Prediction accuracy reward
            price_change = self._calculate_price_change(market_data_before, market_data_after)
            action_correctness = self._evaluate_action_correctness(action, price_change)
            base_reward += action_correctness * 5

        # 2. Market regime bonus
        regime_bonus = self._calculate_regime_bonus(action, market_data_after)
        base_reward += regime_bonus

        # 3. Volatility penalty/bonus
        volatility_factor = self._calculate_volatility_factor(market_data_after)
        base_reward *= volatility_factor

        # 4. CNN confidence alignment
        cnn_alignment = self._calculate_cnn_alignment_bonus(action, market_data_after)
        base_reward += cnn_alignment

        # 5. Pivot point accuracy
        pivot_accuracy = self._calculate_pivot_accuracy_bonus(action, market_data_after)
        base_reward += pivot_accuracy

        return base_reward
```

### Phase 5: Implementation Timeline

#### Week 1: State Representation Enhancement
- [ ] Implement EnhancedRLStateBuilder
- [ ] Add tick data processing
- [ ] Implement multi-timeframe OHLCV integration
- [ ] Add BTC reference data processing

#### Week 2: Pivot Point System
- [ ] Implement WilliamsMarketStructure class
- [ ] Add recursive pivot point calculation
- [ ] Integrate with state builder
- [ ] Test pivot point accuracy

#### Week 3: CNN-RL Integration
- [ ] Implement CNNRLBridge
- [ ] Add hidden feature extraction
- [ ] Integrate CNN predictions into RL state
- [ ] Test feature consistency

#### Week 4: Enhanced Training Pipeline
- [ ] Implement EnhancedRLTrainingPipeline
- [ ] Add enhanced reward calculator
- [ ] Integrate all components
- [ ] Performance testing and optimization

#### Week 5: Testing and Validation
- [ ] Comprehensive integration testing
- [ ] Performance validation
- [ ] Memory usage optimization
- [ ] Documentation and monitoring

## Expected Improvements

### 1. **State Representation Quality**
- **Current**: ~100 basic features
- **Enhanced**: ~8000+ comprehensive features
- **Improvement**: 80x more information density

### 2. **Decision Making Accuracy**
- **Current**: Limited to basic market metrics
- **Enhanced**: Multi-modal with CNN features + pivot points
- **Expected**: 40-60% improvement in prediction accuracy

### 3. **Market Adaptability**
- **Current**: Basic market regime detection
- **Enhanced**: Multi-timeframe analysis with recursive trends
- **Expected**: Better performance across different market conditions

### 4. **Learning Efficiency**
- **Current**: Simple experience replay
- **Enhanced**: Prioritized replay with sophisticated rewards
- **Expected**: 2-3x faster convergence

## Risk Mitigation

### 1. **Memory Usage**
- **Risk**: Large state vectors (~8000 features) may cause memory issues
- **Mitigation**: Implement state compression and efficient batching

### 2. **Training Stability**
- **Risk**: Complex state space may cause training instability
- **Mitigation**: Gradual state expansion, careful hyperparameter tuning

### 3. **Integration Complexity**
- **Risk**: CNN-RL integration may introduce bugs
- **Mitigation**: Extensive testing, fallback mechanisms

### 4. **Performance Impact**
- **Risk**: Real-time performance degradation
- **Mitigation**: Asynchronous processing, optimized data structures

## Success Metrics

1. **State Quality**: Feature coverage > 95% of required specification
2. **Training Performance**: Convergence time < 50% of current
3. **Decision Accuracy**: Prediction accuracy > 65% (vs. current ~45%)
4. **Market Adaptability**: Consistent performance across 3+ market regimes
5. **Integration Stability**: Uptime > 99.5% with CNN integration

This comprehensive upgrade will transform the RL training pipeline from a basic implementation into a sophisticated multi-modal system that fully meets the specification requirements.

# Trading System Logging Upgrade

## Overview

This upgrade implements a comprehensive logging and metadata management system that addresses the key issues:

1. **Eliminates scattered "No checkpoints found" logs** during runtime
2. **Fast checkpoint metadata access** without loading full models
3. **Centralized inference logging** with database and text file storage
4. **Structured tracking** of model performance and checkpoints

## Key Components

### 1. Database Manager (`utils/database_manager.py`)

**Purpose**: SQLite-based storage for structured data

**Features**:
- Inference records logging with deduplication
- Checkpoint metadata storage (separate from model weights)
- Model performance tracking
- Fast queries without loading model files

**Tables** (an illustrative schema sketch follows this list):
- `inference_records`: All model predictions with metadata
- `checkpoint_metadata`: Checkpoint info without model weights
- `model_performance`: Daily aggregated statistics
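
A minimal sketch of what these tables might look like; the column names here are assumptions for illustration, and the real schema is defined in `utils/database_manager.py`:

```python
import sqlite3

conn = sqlite3.connect("data/trading_system.db")
conn.executescript("""
CREATE TABLE IF NOT EXISTS inference_records (
    id                 INTEGER PRIMARY KEY AUTOINCREMENT,
    model_name         TEXT NOT NULL,
    symbol             TEXT NOT NULL,
    action             TEXT NOT NULL,
    confidence         REAL,
    feature_hash       TEXT,    -- supports deduplication
    processing_time_ms REAL,
    checkpoint_id      TEXT,
    timestamp          TEXT DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE IF NOT EXISTS checkpoint_metadata (
    checkpoint_id       TEXT PRIMARY KEY,
    model_name          TEXT NOT NULL,
    performance_metrics TEXT,   -- JSON, no model weights
    is_active           INTEGER DEFAULT 1
);
""")
conn.close()
```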

### 2. Inference Logger (`utils/inference_logger.py`)

**Purpose**: Centralized logging for all model inferences

**Features**:
- Single function call replaces scattered `logger.info()` calls
- Automatic feature hashing for deduplication (sketched below)
- Memory usage tracking
- Processing time measurement
- Dual storage (database + text files)

**Usage**:
```python
from utils.inference_logger import log_model_inference

log_model_inference(
    model_name="dqn_agent",
    symbol="ETH/USDT",
    action="BUY",
    confidence=0.85,
    probabilities={"BUY": 0.85, "SELL": 0.10, "HOLD": 0.05},
    input_features=features_array,
    processing_time_ms=12.5,
    checkpoint_id="dqn_agent_20250725_143500"
)
```
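
The deduplication mentioned above hinges on a stable hash of the input features. One plausible implementation, assuming numpy feature arrays (a sketch, not the module's actual code):

```python
import hashlib
import numpy as np

def hash_features(features: np.ndarray) -> str:
    """Stable digest of a feature array, usable as a deduplication key."""
    return hashlib.sha256(np.ascontiguousarray(features).tobytes()).hexdigest()[:16]
```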

### 3. Text Logger (`utils/text_logger.py`)

**Purpose**: Human-readable log files for tracking

**Features**:
- Separate files for different event types
- Clean, tabular format
- Automatic cleanup of old entries
- Easy to read and grep

**Files**:
- `logs/inference_records.txt`: All model predictions
- `logs/checkpoint_events.txt`: Save/load events
- `logs/system_events.txt`: General system events

### 4. Enhanced Checkpoint Manager (`utils/checkpoint_manager.py`)

**Purpose**: Improved checkpoint handling with metadata separation

**Features**:
- Database-backed metadata storage
- Fast metadata queries without loading models
- Eliminates "No checkpoints found" spam
- Backward compatibility with existing code

## Benefits

### 1. Performance Improvements

**Before**: Loading a full checkpoint just to get metadata
```python
# Old way - loads the entire model!
checkpoint_path, metadata = load_best_checkpoint("dqn_agent")
loss = metadata.loss  # Expensive operation
```

**After**: Fast metadata access from the database
```python
# New way - database query only
metadata = db_manager.get_best_checkpoint_metadata("dqn_agent")
loss = metadata.performance_metrics['loss']  # Fast!
```

### 2. Cleaner Runtime Logs

**Before**: Scattered logs everywhere
```
2025-07-25 14:34:39,749 - utils.checkpoint_manager - INFO - No checkpoints found for dqn_agent
2025-07-25 14:34:39,754 - utils.checkpoint_manager - INFO - No checkpoints found for enhanced_cnn
2025-07-25 14:34:39,756 - utils.checkpoint_manager - INFO - No checkpoints found for extrema_trainer
```

**After**: Clean, structured logging
```
2025-07-25 14:34:39 | dqn_agent | ETH/USDT | BUY | conf=0.850 | time= 12.5ms [checkpoint: dqn_agent_20250725_143500]
2025-07-25 14:34:40 | enhanced_cnn | ETH/USDT | HOLD | conf=0.720 | time= 8.2ms [checkpoint: enhanced_cnn_20250725_143501]
```

### 3. Structured Data Storage

**Database Schema**:
```sql
-- Fast metadata queries
SELECT * FROM checkpoint_metadata WHERE model_name = 'dqn_agent' AND is_active = TRUE;

-- Performance analysis
SELECT model_name, AVG(confidence), COUNT(*)
FROM inference_records
WHERE timestamp > datetime('now', '-24 hours')
GROUP BY model_name;
```

### 4. Easy Integration

**In Model Code**:
```python
# Replace scattered logging
# OLD: logger.info(f"DQN prediction: {action} confidence={conf}")

# NEW: Centralized logging
self.orchestrator.log_model_inference(
    model_name="dqn_agent",
    symbol=symbol,
    action=action,
    confidence=confidence,
    probabilities=probs,
    input_features=features,
    processing_time_ms=processing_time
)
```

## Implementation Guide

### 1. Update Model Classes

Add inference logging to prediction methods:

```python
class DQNAgent:
    def predict(self, state):
        start_time = time.time()

        # Your prediction logic here
        action = self._predict_action(state)
        confidence = self._calculate_confidence()

        processing_time = (time.time() - start_time) * 1000

        # Log the inference
        self.orchestrator.log_model_inference(
            model_name="dqn_agent",
            symbol=self.symbol,
            action=action,
            confidence=confidence,
            probabilities=self.action_probabilities,
            input_features=state,
            processing_time_ms=processing_time,
            checkpoint_id=self.current_checkpoint_id
        )

        return action
```

### 2. Update Checkpoint Saving

Use the enhanced checkpoint manager:

```python
from utils.checkpoint_manager import save_checkpoint

# Save with metadata
checkpoint_metadata = save_checkpoint(
    model=self.model,
    model_name="dqn_agent",
    model_type="rl",
    performance_metrics={"loss": 0.0234, "accuracy": 0.87},
    training_metadata={"epochs": 100, "lr": 0.001}
)
```

### 3. Fast Metadata Access

Get checkpoint info without loading models:

```python
# Fast metadata access
metadata = orchestrator.get_checkpoint_metadata_fast("dqn_agent")
if metadata:
    current_loss = metadata.performance_metrics['loss']
    checkpoint_id = metadata.checkpoint_id
```

## Migration Steps

1. **Install new dependencies** (if any)
2. **Update model classes** to use centralized logging
3. **Replace checkpoint loading** with database queries where possible
4. **Remove scattered `logger.info()`** calls for inferences
5. **Test with the demo script**: `python demo_logging_system.py`

## File Structure

```
utils/
├── database_manager.py    # SQLite database management
├── inference_logger.py    # Centralized inference logging
├── text_logger.py         # Human-readable text logs
└── checkpoint_manager.py  # Enhanced checkpoint handling

logs/                      # Text log files
├── inference_records.txt
├── checkpoint_events.txt
└── system_events.txt

data/
└── trading_system.db      # SQLite database

demo_logging_system.py     # Demonstration script
```

## Monitoring and Maintenance

### Daily Tasks
- Check `logs/inference_records.txt` for recent activity
- Monitor database size: `ls -lh data/trading_system.db`

### Weekly Tasks
- Run cleanup: `inference_logger.cleanup_old_logs(days_to_keep=30)`
- Check model performance trends in the database

### Monthly Tasks
- Archive old log files
- Analyze model performance statistics
- Review checkpoint storage usage

## Troubleshooting

### Common Issues

1. **Database locked**: Multiple processes accessing SQLite
   - Solution: Use a connection timeout and proper context managers (see the sketch after this list)

2. **Log files growing too large**:
   - Solution: Run `text_logger.cleanup_old_logs(max_lines=10000)`

3. **Missing checkpoint metadata**:
   - Solution: The system falls back to the file-based approach automatically
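
For issue 1, the standard pattern looks roughly like this (a sketch of the general approach, not code from this repository):

```python
import sqlite3
from contextlib import closing

# A busy timeout lets SQLite wait out short lock contention, and the
# context managers guarantee the connection is committed or rolled back
# and then closed, even on error.
with closing(sqlite3.connect("data/trading_system.db", timeout=30.0)) as conn:
    with conn:  # commits on success, rolls back on exception
        conn.execute(
            "UPDATE checkpoint_metadata SET is_active = 0 WHERE model_name = ?",
            ("dqn_agent",),
        )
```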

### Debug Commands

```python
# Check database status
db_manager = get_database_manager()
checkpoints = db_manager.list_checkpoints("dqn_agent")

# Check recent inferences
inference_logger = get_inference_logger()
stats = inference_logger.get_model_stats("dqn_agent", hours=24)

# View text logs
text_logger = get_text_logger()
recent = text_logger.get_recent_inferences(lines=50)
```

This upgrade provides a solid foundation for tracking model performance, eliminating log spam, and enabling fast metadata access without the overhead of loading full model checkpoints.