#!/usr/bin/env python3
"""
Test script for position-based reward system

This script tests the enhanced reward calculations that incentivize:
1. Holding profitable positions (let winners run)
2. Closing losing positions (cut losses)
3. Taking action when appropriate based on P&L
"""

import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from core.orchestrator import TradingOrchestrator
from NN.models.enhanced_cnn import EnhancedCNN

import numpy as np

def _print_position_enhanced_rewards(scenarios, calc_reward):
    """Print base vs. enhanced reward for each scenario.

    Args:
        scenarios: iterable of (action, position_pnl, has_position,
            price_change_pct, description) tuples; price_change_pct is
            unused by the enhanced-reward calculators.
        calc_reward: callable taking (base_reward, action, position_pnl,
            has_position) keyword arguments and returning the enhanced
            reward (e.g. the CNN or DQN position-enhanced reward method).
    """
    for i, (action, position_pnl, has_position, _, description) in enumerate(scenarios, 1):
        base_reward = 0.5  # Moderate base reward
        enhanced_reward = calc_reward(
            base_reward=base_reward,
            action=action,
            position_pnl=position_pnl,
            has_position=has_position,
        )

        enhancement = enhanced_reward - base_reward
        print(f"{i:2d}. {description}")
        print(f"   Action: {action}, P&L: ${position_pnl:+.1f}")
        print(f"   Base Reward: {base_reward:+.3f} → Enhanced: {enhanced_reward:+.3f} (Δ{enhancement:+.3f})")
        print()


def test_position_reward_scenarios():
    """Test various position-based reward scenarios.

    Runs a fixed set of scenarios through three reward paths and prints
    the results for manual inspection:
      1. the orchestrator's sophisticated reward calculation,
      2. the CNN's position-enhanced reward,
      3. the DQN's position-enhanced reward.
    """
    print("🧪 POSITION-BASED REWARD SYSTEM TEST")
    print("=" * 50)

    # Initialize orchestrator (provides both the sophisticated and the
    # DQN position-enhanced reward calculators).
    orchestrator = TradingOrchestrator()

    # Test scenarios:
    # (action, position_pnl, has_position, price_change_pct, description)
    scenarios = [
        ("HOLD", 50.0, True, 0.5, "Hold profitable position with continued gains"),
        ("HOLD", 50.0, True, -0.3, "Hold profitable position with small pullback"),
        ("HOLD", -30.0, True, 0.8, "Hold losing position that recovers"),
        ("HOLD", -30.0, True, -0.5, "Hold losing position that continues down"),
        ("SELL", 50.0, True, 0.0, "Close profitable position"),
        ("SELL", -30.0, True, 0.0, "Close losing position (good)"),
        ("BUY", 0.0, False, 1.0, "New buy position with immediate gain"),
        ("HOLD", 0.0, False, 0.1, "Hold with no position (stable market)"),
    ]

    print("\n📊 SOPHISTICATED REWARD CALCULATION TESTS:")
    print("-" * 80)

    for i, (action, position_pnl, has_position, price_change_pct, description) in enumerate(scenarios, 1):
        # Test sophisticated reward calculation
        reward, was_correct = orchestrator._calculate_sophisticated_reward(
            predicted_action=action,
            prediction_confidence=0.8,
            price_change_pct=price_change_pct,
            time_diff_minutes=5.0,
            has_price_prediction=False,
            symbol="ETH/USDT",
            has_position=has_position,
            current_position_pnl=position_pnl,
        )

        print(f"{i:2d}. {description}")
        print(f"   Action: {action}, P&L: ${position_pnl:+.1f}, Price Change: {price_change_pct:+.1f}%")
        print(f"   Reward: {reward:+.3f}, Correct: {was_correct}")
        print()

    print("\n🧠 CNN POSITION-ENHANCED REWARD TESTS:")
    print("-" * 80)

    # Initialize CNN model and run the shared scenario printer with its
    # position-enhanced reward method.
    cnn_model = EnhancedCNN(input_shape=100, n_actions=3)
    _print_position_enhanced_rewards(scenarios, cnn_model._calculate_position_enhanced_reward)

    print("\n🤖 DQN POSITION-ENHANCED REWARD TESTS:")
    print("-" * 80)

    # Same scenarios through the DQN reward path for comparison.
    _print_position_enhanced_rewards(scenarios, orchestrator._calculate_position_enhanced_reward_for_dqn)
|
def test_reward_incentives():
    """Test that rewards properly incentivize desired behaviors"""
    print("\n🎯 REWARD INCENTIVE VALIDATION:")
    print("-" * 50)

    orchestrator = TradingOrchestrator()
    cnn_model = EnhancedCNN(input_shape=100, n_actions=3)

    # Bind the CNN reward method once; every check below uses the same
    # base reward and P&L magnitude, varying only action and sign.
    cnn_reward_fn = cnn_model._calculate_position_enhanced_reward

    # Check 1: holding a winner must beat holding a loser.
    print("1. HOLD action comparison:")

    reward_hold_win = cnn_reward_fn(0.5, "HOLD", 100.0, True)
    reward_hold_lose = cnn_reward_fn(0.5, "HOLD", -100.0, True)

    print(f"   Hold profitable position (+$100): {reward_hold_win:+.3f}")
    print(f"   Hold losing position (-$100): {reward_hold_lose:+.3f}")
    print(f"   ✅ Incentive correct: {reward_hold_win > reward_hold_lose}")

    # Check 2: cutting a loser must beat cashing out a winner.
    print("\n2. SELL action comparison:")

    reward_sell_win = cnn_reward_fn(0.5, "SELL", 100.0, True)
    reward_sell_lose = cnn_reward_fn(0.5, "SELL", -100.0, True)

    print(f"   Sell profitable position (+$100): {reward_sell_win:+.3f}")
    print(f"   Sell losing position (-$100): {reward_sell_lose:+.3f}")
    print(f"   ✅ Incentive correct: {reward_sell_lose > reward_sell_win}")

    # Check 3: the DQN path should penalize holding a loser more strongly
    # than the CNN path does.
    print("\n3. DQN vs CNN reward scaling:")

    dqn_penalty = orchestrator._calculate_position_enhanced_reward_for_dqn(0.5, "HOLD", -100.0, True)
    cnn_penalty = cnn_reward_fn(0.5, "HOLD", -100.0, True)

    print(f"   DQN penalty for holding loser: {dqn_penalty:+.3f}")
    print(f"   CNN penalty for holding loser: {cnn_penalty:+.3f}")
    print(f"   ✅ DQN more sensitive: {abs(dqn_penalty) > abs(cnn_penalty)}")
def main():
    """Run all position-based reward tests"""
    try:
        test_position_reward_scenarios()
        test_reward_incentives()

        print("\n🚀 POSITION-BASED REWARD SYSTEM VALIDATION COMPLETE!")
        print("✅ System properly incentivizes:")
        for line in (
            "   • Holding profitable positions (let winners run)",
            "   • Closing losing positions (cut losses)",
            "   • Taking appropriate action based on P&L",
            "   • Different reward scaling for CNN vs DQN models",
        ):
            print(line)

    except Exception as err:
        # Report any failure with a full traceback instead of crashing.
        print(f"❌ Test failed with error: {err}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()