337 lines
11 KiB
Python
337 lines
11 KiB
Python
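#!/usr/bin/env python3
"""
Test Trading System Fixes

This script tests the fixes for the trading system by simulating trades
against a mock executor and verifying that the issues are resolved.

Progress is logged to the console and to logs/test_fixes.log, and a JSON
summary is written to logs/test_results.json. The process exits with
status 0 when every test passes and 1 otherwise.

Usage:
    python test_trading_fixes.py
"""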
|
|
|
|
import json
import logging
import os
import random
import sys
import time
from datetime import datetime
from pathlib import Path
|
|
|
|
# Add the directory containing this script to the import path so project
# packages such as ``core.trading_executor_fix`` can be imported.
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))

# Setup logging.
# The FileHandler below is constructed (and its file opened) as soon as this
# module runs, so the ``logs`` directory must exist first; otherwise a fresh
# checkout fails with FileNotFoundError before any test runs.
os.makedirs('logs', exist_ok=True)

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        # Explicit UTF-8: log messages contain emoji (e.g. the pass/fail
        # markers), which the platform default encoding may not support.
        logging.FileHandler('logs/test_fixes.log', encoding='utf-8'),
    ],
)

logger = logging.getLogger(__name__)
|
|
|
|
class MockPosition:
    """Minimal open-position record used by the tests in this script.

    Stores the symbol, the side string supplied by the caller ('LONG' or
    'SHORT' everywhere in this file), the size and the entry price.
    ``fees`` always starts at 0.0.
    """

    def __init__(self, symbol, side, size, entry_price):
        self.symbol, self.side = symbol, side
        self.size, self.entry_price = size, entry_price
        self.fees = 0.0
|
|
|
|
class MockTradingExecutor:
    """In-memory stand-in for the trading executor, used to exercise the fixes.

    ``positions`` maps symbol -> MockPosition and ``current_prices`` maps
    symbol -> last simulated price. No orders leave the process;
    ``simulation_mode`` is always True.
    """

    def __init__(self):
        self.positions = {}
        self.current_prices = {}
        self.simulation_mode = True

    def get_current_price(self, symbol):
        """Return a simulated price for *symbol*.

        The first call for a symbol seeds it at 3600.0. Every later call
        applies a random step in [-10, 10], so uncached successive reads
        generally differ.
        """
        if symbol not in self.current_prices:
            self.current_prices[symbol] = 3600.0
        else:
            self.current_prices[symbol] += random.uniform(-10, 10)

        return self.current_prices[symbol]

    def execute_action(self, decision):
        """Simulate executing *decision* and record the resulting position.

        'BUY'/'LONG' opens a LONG position and 'SELL'/'SHORT' opens a SHORT
        position, replacing any existing position for the symbol. Any other
        action is accepted but leaves the positions unchanged.

        Args:
            decision: Object with ``symbol``, ``action``, ``price`` and
                ``size`` attributes (e.g. MockDecision). When ``price`` is
                None, the current simulated price is used.

        Returns:
            True; the mock itself never rejects an order.
        """
        price = decision.price
        if price is None:
            # MockDecision defaults price to None, and formatting None with
            # ``:.2f`` raises TypeError, so fall back to the market price.
            price = self.get_current_price(decision.symbol)

        logger.info(f"Executing {decision.action} for {decision.symbol} at ${price:.2f}")

        # Simulate execution
        if decision.action in ('BUY', 'LONG'):
            side = 'LONG'
        elif decision.action in ('SELL', 'SHORT'):
            side = 'SHORT'
        else:
            side = None

        if side is not None:
            self.positions[decision.symbol] = MockPosition(
                decision.symbol, side, decision.size, price
            )

        return True

    def close_position(self, symbol, price=None):
        """Close the position for *symbol* and log its gross P&L.

        Args:
            symbol: Symbol whose position should be closed.
            price: Exit price; when None, the current simulated price is used.

        Returns:
            False if there is no position for *symbol*, otherwise True.
        """
        if symbol not in self.positions:
            return False

        if price is None:
            price = self.get_current_price(symbol)

        position = self.positions[symbol]

        # Gross P&L only; position.fees is not subtracted.
        if position.side == 'LONG':
            pnl = (price - position.entry_price) * position.size
        else:  # SHORT
            pnl = (position.entry_price - price) * position.size

        logger.info(f"Closing {position.side} position for {symbol} at ${price:.2f}, P&L: ${pnl:.2f}")

        del self.positions[symbol]

        return True
|
|
|
|
class MockDecision:
    """Stand-in for a trading decision handed to ``execute_action``.

    Carries the order parameters plus execution bookkeeping flags
    (``executed``, ``blocked``, ``blocked_reason``). The tests in this script
    read ``blocked``/``blocked_reason`` after execution; presumably the
    applied fixes set them (verify against core.trading_executor_fix).
    ``timestamp`` is the naive local creation time.
    """

    def __init__(self, symbol, action, price=None, size=10.0, confidence=0.8):
        self.symbol, self.action, self.price = symbol, action, price
        self.size, self.confidence = size, confidence
        self.timestamp = datetime.now()
        # Execution state starts out unset.
        self.executed = self.blocked = False
        self.blocked_reason = None
|
|
|
|
def test_price_caching_fix():
    """Check that the patched ``get_current_price`` caches, then expires.

    Applies ``TradingExecutorFix`` to a fresh MockTradingExecutor. It reads
    the price twice back-to-back, expecting a cache hit. It then waits
    6 seconds and reads again, expecting a new simulated price.

    Returns:
        True only if both the cache-hit and the cache-expiry checks pass;
        False if either check fails or an exception is raised.
    """
    logger.info("Testing price caching fix...")

    # Create mock trading executor
    executor = MockTradingExecutor()

    try:
        # Imported inside the try so a missing/broken fix module is reported
        # as a failed test instead of aborting the whole script.
        from core.trading_executor_fix import TradingExecutorFix
        TradingExecutorFix.apply_fixes(executor)

        symbol = 'ETH/USDT'

        price1 = executor.get_current_price(symbol)
        logger.info(f"Initial price: ${price1:.2f}")

        # Get price again immediately (should be cached)
        price2 = executor.get_current_price(symbol)
        logger.info(f"Immediate second price: ${price2:.2f}")

        # 6 s is assumed to exceed the fix's cache TTL — TODO confirm against
        # core.trading_executor_fix.
        logger.info("Waiting for cache to expire (6 seconds)...")
        time.sleep(6)

        # Get price after cache expiry (should be different)
        price3 = executor.get_current_price(symbol)
        logger.info(f"Price after cache expiry: ${price3:.2f}")

        cache_hit = price1 == price2
        if cache_hit:
            logger.info("✅ Immediate price check uses cache as expected")
        else:
            logger.warning("❌ Immediate price check did not use cache")

        cache_expired = price1 != price3
        if cache_expired:
            logger.info("✅ Price cache expiry working correctly")
        else:
            logger.warning("❌ Price cache expiry not working")

        # The summary in run_all_tests relies on this value, so a failed
        # check must be reported as a failed test.
        return cache_hit and cache_expired

    except Exception as e:
        logger.exception("Error testing price caching fix: %s", e)
        return False
|
|
|
|
def test_duplicate_entry_prevention():
    """Check that a repeated same-side entry is blocked by the fixed executor.

    Executes a SHORT for ETH/USDT and records it in
    ``executor.recent_entries`` as if it had filled. It then submits an
    identical SHORT at the same price and expects the second decision to
    come back with ``blocked`` set.

    Returns:
        True if the second decision was marked blocked; otherwise False,
        including when an exception is raised.
    """
    logger.info("Testing duplicate entry prevention...")

    # Create mock trading executor
    executor = MockTradingExecutor()

    try:
        # Imported inside the try so a missing/broken fix module is reported
        # as a failed test instead of aborting the whole script.
        from core.trading_executor_fix import TradingExecutorFix
        TradingExecutorFix.apply_fixes(executor)

        symbol = 'ETH/USDT'

        decision1 = MockDecision(symbol, 'SHORT')
        decision1.price = executor.get_current_price(symbol)

        result1 = executor.execute_action(decision1)
        logger.info(f"First execution result: {result1}")

        # Seed recent_entries by hand to simulate the first trade having
        # filled. The price/timestamp/action schema is assumed to match what
        # the fix reads — TODO confirm against core.trading_executor_fix.
        if not hasattr(executor, 'recent_entries'):
            executor.recent_entries = {}

        executor.recent_entries[symbol] = {
            'price': decision1.price,
            'timestamp': time.time(),
            'action': decision1.action,
        }

        # Same side, same price, submitted immediately: should be blocked.
        decision2 = MockDecision(symbol, 'SHORT')
        decision2.price = decision1.price

        result2 = executor.execute_action(decision2)
        logger.info(f"Second execution result: {result2}")
        logger.info(f"Second decision blocked: {getattr(decision2, 'blocked', False)}")
        logger.info(f"Block reason: {getattr(decision2, 'blocked_reason', 'None')}")

        # Only ``blocked`` is checked, so a block raised by the trade cooldown
        # also counts; either way the duplicate entry is prevented.
        if getattr(decision2, 'blocked', False):
            logger.info("✅ Trade prevention working correctly (via cooldown)")
            return True

        logger.warning("❌ Trade prevention not working correctly")
        return False

    except Exception as e:
        logger.exception("Error testing duplicate entry prevention: %s", e)
        return False
|
|
|
|
def test_pnl_calculation_fix():
    """Check the fixed executor's gross P&L for both SHORT and LONG positions.

    Opens a position of size 10 at 3600.0 for each side and asks
    ``executor._calculate_pnl`` for the P&L at an exit of 3550.0.
    ``_calculate_pnl`` is not defined on MockTradingExecutor, so it must come
    from the applied fix. Its result is expected to be a mapping with a
    ``'gross_pnl'`` key.

    Returns:
        True if every side's gross P&L is within 0.01 of the expected value;
        otherwise False, including when an exception is raised.
    """
    logger.info("Testing P&L calculation fix...")

    # Create mock trading executor
    executor = MockTradingExecutor()

    try:
        # Imported inside the try so a missing/broken fix module is reported
        # as a failed test instead of aborting the whole script.
        from core.trading_executor_fix import TradingExecutorFix
        TradingExecutorFix.apply_fixes(executor)

        symbol = 'ETH/USDT'
        entry_price = 3600.0
        exit_price = 3550.0
        size = 10.0

        # Expected gross P&L per side: a SHORT gains when price falls, a LONG
        # loses by the same amount.
        expected_by_side = {
            'SHORT': (entry_price - exit_price) * size,
            'LONG': (exit_price - entry_price) * size,
        }

        all_correct = True
        for side, expected_pnl in expected_by_side.items():
            executor.positions[symbol] = MockPosition(symbol, side, size, entry_price)
            pnl_result = executor._calculate_pnl(executor.positions[symbol], exit_price)
            gross_pnl = pnl_result['gross_pnl']

            logger.info(f"Side: {side}")
            logger.info(f"Entry price: ${entry_price:.2f}")
            logger.info(f"Exit price: ${exit_price:.2f}")
            logger.info(f"Size: {size}")
            logger.info(f"Calculated P&L: ${gross_pnl:.2f}")
            logger.info(f"Expected P&L: ${expected_pnl:.2f}")

            # Written as ``not (diff < tol)`` so a NaN result counts as a failure.
            if not abs(gross_pnl - expected_pnl) < 0.01:
                logger.warning(f"❌ {side} P&L does not match expected value")
                all_correct = False

        if all_correct:
            logger.info("✅ P&L calculation fix working correctly")
            return True

        logger.warning("❌ P&L calculation fix not working correctly")
        return False

    except Exception as e:
        logger.exception("Error testing P&L calculation fix: %s", e)
        return False
|
|
|
|
def run_all_tests():
    """Run every fix test, log a summary and write the results to disk.

    Each test is isolated: an exception escaping a test marks it FAILED,
    with its traceback logged, instead of aborting the run. Results are
    written to ``logs/test_results.json``.

    Returns:
        True if every test returned a truthy result, otherwise False.
    """
    logger.info("=" * 70)
    logger.info("TESTING TRADING SYSTEM FIXES")
    logger.info("=" * 70)

    # Create logs directory if it doesn't exist (the results file goes there).
    os.makedirs('logs', exist_ok=True)

    tests = [
        ("Price Caching Fix", test_price_caching_fix),
        ("Duplicate Entry Prevention", test_duplicate_entry_prevention),
        ("P&L Calculation Fix", test_pnl_calculation_fix),
    ]

    results = {}

    for test_name, test_func in tests:
        logger.info(f"\n{'-'*30}")
        logger.info(f"Running test: {test_name}")
        logger.info(f"{'-'*30}")

        try:
            results[test_name] = test_func()
        except Exception as e:
            # logger.exception keeps the traceback, which a plain error
            # message would lose.
            logger.exception("Test %s failed with error: %s", test_name, e)
            results[test_name] = False

    # Print summary
    logger.info("\n" + "=" * 70)
    logger.info("TEST RESULTS SUMMARY")
    logger.info("=" * 70)

    for test_name, result in results.items():
        status = "✅ PASSED" if result else "❌ FAILED"
        logger.info(f"{test_name}: {status}")

    all_passed = all(results.values())

    logger.info("=" * 70)
    logger.info(f"OVERALL RESULT: {'✅ ALL TESTS PASSED' if all_passed else '❌ SOME TESTS FAILED'}")
    logger.info("=" * 70)

    # Save results to file
    with open('logs/test_results.json', 'w', encoding='utf-8') as f:
        json.dump({
            'timestamp': datetime.now().isoformat(),
            'results': {k: "PASSED" if v else "FAILED" for k, v in results.items()},
            'all_passed': all_passed,
        }, f, indent=2)

    return all_passed
|
|
|
|
if __name__ == "__main__":
    # The process exit status mirrors the overall result: 0 on success, 1 otherwise.
    success = run_all_tests()
    exit_code, message = (
        (0, "\nAll tests passed!")
        if success
        else (1, "\nSome tests failed. Check logs for details.")
    )
    print(message)
    sys.exit(exit_code)