320 lines
11 KiB
Python
320 lines
11 KiB
Python
"""
|
|
Test Enhanced Pivot-Based RL System
|
|
|
|
Tests the new system with:
|
|
- Different thresholds for entry vs exit
|
|
- Pivot-based rewards
|
|
- CNN predictions for early pivot detection
|
|
- Uninvested rewards
|
|
"""
|
|
|
|
import logging
|
|
import sys
|
|
import numpy as np
|
|
import pandas as pd
|
|
from datetime import datetime, timedelta
|
|
from typing import Dict, Any
|
|
|
|
# Logging setup: INFO-and-above records go to stdout with a timestamped format.
logging.basicConfig(
    stream=sys.stdout,
    format='%(asctime)s [%(levelname)s] %(name)s: %(message)s',
    level=logging.INFO,
)

# Module-level logger used by every test function in this file.
logger = logging.getLogger(__name__)

# Append the directory containing this file to sys.path so that the project
# packages imported further down (core.*, training.*) can be resolved.
import os

sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
|
|
|
from core.data_provider import DataProvider
|
|
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
|
from training.enhanced_pivot_rl_trainer import EnhancedPivotRLTrainer, create_enhanced_pivot_trainer
|
|
|
|
def test_enhanced_pivot_thresholds():
    """Check that the pivot trainer starts with entry threshold > exit threshold.

    Builds a DataProvider and an EnhancedTradingOrchestrator with
    ``enhanced_rl_training=True``, logs the initial entry / exit / uninvested
    thresholds reported by the orchestrator's pivot RL trainer, and asserts
    that the entry threshold is strictly greater than the exit threshold.

    Returns:
        True if the check passes. False if any exception (including a failed
        assertion) is raised; the exception is logged, not propagated.
    """
    logger.info("=== Testing Enhanced Pivot-Based Thresholds ===")

    try:
        # Wire up the components under test.
        provider = DataProvider()
        orchestrator = EnhancedTradingOrchestrator(
            data_provider=provider,
            enhanced_rl_training=True,
        )

        # Read the thresholds the trainer was initialised with.
        current = orchestrator.pivot_rl_trainer.get_current_thresholds()

        logger.info("Initial thresholds:")
        for label, key in (
            ("Entry", 'entry_threshold'),
            ("Exit", 'exit_threshold'),
            ("Uninvested", 'uninvested_threshold'),
        ):
            logger.info(f" {label}: {current[key]:.3f}")

        # Entering a position must require more confidence than leaving one.
        assert current['entry_threshold'] > current['exit_threshold'], \
            "Entry threshold should be higher than exit"
        logger.info("✅ Entry threshold correctly higher than exit threshold")

        return True

    except Exception as exc:
        logger.error(f"Error testing thresholds: {exc}")
        return False
|
|
|
|
def _make_mock_ohlcv(start, periods=100, seed=42):
    """Build a reproducible, internally consistent 1-minute OHLCV frame.

    Args:
        start: Timestamp of the first bar.
        periods: Number of 1-minute bars to generate.
        seed: Seed for the local random generator, so every run of the test
            sees exactly the same market data.

    Returns:
        A DataFrame with columns open/high/low/close/volume, indexed by a
        1-minute DatetimeIndex starting at ``start``. Every bar satisfies
        ``low <= min(open, close)`` and ``high >= max(open, close)``.
    """
    # A local generator keeps the test deterministic without touching the
    # global numpy random state that other code may rely on.
    rng = np.random.default_rng(seed)

    open_ = rng.normal(2500, 10, periods)
    close = rng.normal(2500, 10, periods)
    # Independent draws for high/low can produce impossible bars (high < low,
    # or open/close outside the range); clamp them around open/close.
    high = np.maximum(rng.normal(2510, 10, periods), np.maximum(open_, close))
    low = np.minimum(rng.normal(2490, 10, periods), np.minimum(open_, close))

    frame = pd.DataFrame({
        'open': open_,
        'high': high,
        'low': low,
        'close': close,
        'volume': rng.normal(1000, 100, periods),
    })
    frame.index = pd.date_range(start=start, periods=periods, freq='1min')
    return frame


def test_pivot_reward_calculation():
    """Test the pivot-based reward calculation and the uninvested reward.

    Creates a pivot trainer via ``create_enhanced_pivot_trainer``, feeds it a
    mock profitable BUY trade together with deterministic mock market data,
    and asserts the resulting reward lies in ``[-15.0, 10.0]``. It then asks
    the trainer for the uninvested reward of a low-confidence HOLD decision
    and asserts that reward is positive.

    Returns:
        True if both checks pass. False if any exception (including a failed
        assertion) is raised; the exception is logged, not propagated.
    """
    logger.info("=== Testing Pivot-Based Reward Calculation ===")

    try:
        # Create enhanced pivot trainer
        data_provider = DataProvider()
        pivot_trainer = create_enhanced_pivot_trainer(data_provider)

        # One reference time so the decision timestamps and the market data
        # window are derived from the same instant.
        now = datetime.now()

        # Mock trade decision and its (profitable) outcome
        trade_decision = {
            'action': 'BUY',
            'confidence': 0.75,
            'price': 2500.0,
            'timestamp': now,
        }

        trade_outcome = {
            'net_pnl': 15.50,  # Profitable trade
            'exit_price': 2515.0,
            'duration': timedelta(minutes=45),
        }

        # Deterministic mock market data: 100 one-minute bars starting 2h ago
        market_data = _make_mock_ohlcv(start=now - timedelta(hours=2), periods=100)

        # Calculate reward
        reward = pivot_trainer.calculate_pivot_based_reward(
            trade_decision, market_data, trade_outcome
        )

        logger.info(f"Calculated pivot-based reward: {reward:.3f}")

        # The reward for this profitable trade must stay within sane bounds
        assert -15.0 <= reward <= 10.0, f"Reward {reward} outside expected range"
        logger.info("✅ Pivot-based reward calculation working")

        # Test uninvested reward
        low_conf_decision = {
            'action': 'HOLD',
            'confidence': 0.35,  # Below uninvested threshold
            'price': 2500.0,
            'timestamp': now,
        }

        uninvested_reward = pivot_trainer._calculate_uninvested_rewards(low_conf_decision, 0.35)
        logger.info(f"Uninvested reward for low confidence: {uninvested_reward:.3f}")

        assert uninvested_reward > 0, "Should get positive reward for staying uninvested with low confidence"
        logger.info("✅ Uninvested rewards working correctly")

        return True

    except Exception as e:
        logger.error(f"Error testing pivot rewards: {e}")
        return False
|
|
|
|
def test_confidence_adjustment():
    """Exercise the trainer's confidence-based reward adjustment.

    Two scenarios are passed to ``_calculate_confidence_adjustment``:

    1. A high-confidence (0.85) BUY that closed with a loss; the adjustment
       is asserted to be negative.
    2. A low-confidence (0.35) BUY that closed with a profit; the adjustment
       is only logged.

    Returns:
        True if the scenarios run and the assertion holds. False if any
        exception (including a failed assertion) is raised; the exception is
        logged, not propagated.
    """
    logger.info("=== Testing Confidence-Based Adjustments ===")

    try:
        trainer = create_enhanced_pivot_trainer()

        # Scenario 1: overconfident decision that ended in a loss.
        overconfident_buy = dict(
            action='BUY',
            confidence=0.85,  # High confidence
            price=2500.0,
            timestamp=datetime.now(),
        )
        losing_result = dict(
            net_pnl=-25.0,  # Loss
            exit_price=2475.0,
            duration=timedelta(hours=3),
        )

        loss_adjustment = trainer._calculate_confidence_adjustment(
            overconfident_buy, losing_result
        )
        logger.info(f"Confidence adjustment for overconfident loss: {loss_adjustment:.3f}")
        assert loss_adjustment < 0, "Should penalize overconfidence on losses"

        # Scenario 2: underconfident decision that ended in a profit.
        hesitant_buy = dict(
            action='BUY',
            confidence=0.35,  # Low confidence
            price=2500.0,
            timestamp=datetime.now(),
        )
        winning_result = dict(
            net_pnl=20.0,  # Profit
            exit_price=2520.0,
            duration=timedelta(minutes=30),
        )

        win_adjustment = trainer._calculate_confidence_adjustment(
            hesitant_buy, winning_result
        )
        logger.info(f"Confidence adjustment for underconfident win: {win_adjustment:.3f}")
        # Expected to be a small penalty or zero.
        # NOTE(review): this value is only logged; nothing asserts on it.

        logger.info("✅ Confidence adjustments working correctly")
        return True

    except Exception as exc:
        logger.error(f"Error testing confidence adjustments: {exc}")
        return False
|
|
|
|
def test_dynamic_threshold_updates():
    """Test dynamic threshold updating based on performance.

    Records the trainer's thresholds, appends 25 simulated trade outcomes
    (20 losses followed by 5 wins, i.e. a 20% win rate) to
    ``pivot_trainer.trade_outcomes``, triggers
    ``update_thresholds_based_on_performance`` and asserts the entry
    threshold did not decrease.

    Returns:
        True if the check passes. False if any exception (including a failed
        assertion) is raised; the exception is logged, not propagated.
    """
    logger.info("=== Testing Dynamic Threshold Updates ===")

    try:
        pivot_trainer = create_enhanced_pivot_trainer()

        # Snapshot the initial thresholds. The copy matters: if the trainer
        # hands back its live mapping and later mutates it in place, an
        # un-copied reference would always compare equal to the updated
        # values and the assertion below could never fail.
        initial_thresholds = dict(pivot_trainer.get_current_thresholds())
        logger.info(f"Initial thresholds: {initial_thresholds}")

        # Simulate some poor performance (low win rate)
        for i in range(25):
            is_loss = i < 20  # first 20 trades lose, last 5 win -> 20% win rate
            outcome = {
                'timestamp': datetime.now(),
                'action': 'BUY',
                'confidence': 0.6,
                'net_pnl': -5.0 if is_loss else 10.0,
                'reward': -1.0 if is_loss else 2.0,
                'duration': timedelta(hours=2),
            }
            pivot_trainer.trade_outcomes.append(outcome)

        # Update thresholds
        pivot_trainer.update_thresholds_based_on_performance()

        # Get updated thresholds
        updated_thresholds = pivot_trainer.get_current_thresholds()
        logger.info(f"Updated thresholds after poor performance: {updated_thresholds}")

        # Entry threshold should increase (more selective) after poor performance
        assert updated_thresholds['entry_threshold'] >= initial_thresholds['entry_threshold'], \
            "Entry threshold should increase after poor performance"

        logger.info("✅ Dynamic threshold updates working correctly")
        return True

    except Exception as e:
        logger.error(f"Error testing dynamic thresholds: {e}")
        return False
|
|
|
|
def test_cnn_integration():
    """Check the orchestrator's CNN-related wiring and threshold adjustment.

    Builds an EnhancedTradingOrchestrator with enhanced RL training enabled,
    logs whether the pivot trainer's Williams structure has its CNN feature
    enabled and a CNN model attached, then requests a CNN threshold
    adjustment for a mock ETH/USDT BUY market state and asserts the result
    lies in ``[0.0, 0.1]``.

    Returns:
        True if the check passes. False if any exception (including a failed
        assertion or import error) is raised; the exception is logged, not
        propagated.
    """
    logger.info("=== Testing CNN Integration ===")

    try:
        provider = DataProvider()
        orchestrator = EnhancedTradingOrchestrator(
            data_provider=provider,
            enhanced_rl_training=True,
        )

        # Report how the Williams structure was initialised with respect to CNN.
        williams_structure = orchestrator.pivot_rl_trainer.williams
        logger.info(f"Williams CNN enabled: {williams_structure.enable_cnn_feature}")
        has_cnn_model = williams_structure.cnn_model is not None
        logger.info(f"Williams CNN model available: {has_cnn_model}")

        # Imports are kept inside the try so an import failure is reported as
        # a failed test instead of aborting the whole run.
        from core.enhanced_orchestrator import MarketState
        from datetime import datetime

        symbol = 'ETH/USDT'
        state_fields = {
            'symbol': symbol,
            'timestamp': datetime.now(),
            'prices': {'1s': 2500.0},
            'features': {'1s': np.array([])},
            'volatility': 0.02,
            'volume': 1000.0,
            'trend_strength': 0.5,
            'market_regime': 'normal',
            'universal_data': None,
        }
        mock_market_state = MarketState(**state_fields)

        cnn_adjustment = orchestrator._get_cnn_threshold_adjustment(
            symbol, 'BUY', mock_market_state
        )

        logger.info(f"CNN threshold adjustment: {cnn_adjustment:.3f}")
        assert 0.0 <= cnn_adjustment <= 0.1, "CNN adjustment should be reasonable"

        logger.info("✅ CNN integration working correctly")
        return True

    except Exception as exc:
        logger.error(f"Error testing CNN integration: {exc}")
        return False
|
|
|
|
def run_all_tests():
    """Execute every enhanced pivot RL test function and summarise the result.

    Each test function is called in turn. A truthy return counts as a pass,
    a falsy return as a failure, and an escaping exception is logged as an
    error (and likewise not counted as a pass).

    Returns:
        True only if every test function passed, otherwise False.
    """
    logger.info("🚀 Starting Enhanced Pivot RL System Tests")

    tests = [
        test_enhanced_pivot_thresholds,
        test_pivot_reward_calculation,
        test_confidence_adjustment,
        test_dynamic_threshold_updates,
        test_cnn_integration,
    ]

    total = len(tests)
    passed = 0

    for test_func in tests:
        try:
            if not test_func():
                logger.error(f"❌ {test_func.__name__} FAILED")
                continue
            passed += 1
            logger.info(f"✅ {test_func.__name__} PASSED")
        except Exception as exc:
            logger.error(f"❌ {test_func.__name__} ERROR: {exc}")

    logger.info(f"\n📊 Test Results: {passed}/{total} tests passed")

    failed = total - passed
    if failed != 0:
        logger.error(f"⚠️ {failed} tests FAILED")
        return False

    logger.info("🎉 All Enhanced Pivot RL System tests PASSED!")
    return True
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the suite, print a summary, and expose the
    # outcome through the process exit code (0 = all passed, 1 = otherwise).
    success = run_all_tests()

    if not success:
        logger.error("\n❌ Enhanced Pivot RL System has issues that need fixing")
    else:
        logger.info("\n🔥 Enhanced Pivot RL System is ready for deployment!")
        logger.info("Key improvements:")
        for improvement in (
            "Higher entry threshold than exit threshold",
            "Pivot-based reward calculation",
            "CNN predictions for early pivot detection",
            "Rewards for staying uninvested when uncertain",
            "Confidence-based reward adjustments",
            "Dynamic threshold learning from performance",
        ):
            logger.info(f" ✅ {improvement}")

    sys.exit(int(not success))