261 lines
11 KiB
Python
261 lines
11 KiB
Python
"""
Trading Executor Fix

This module provides fixes for the trading executor to address:
1. Duplicate entry prices
2. P&L calculation issues
3. Position tracking problems

Apply these fixes by importing this module and applying the patch in main.py
"""
|
|
|
|
import logging
|
|
import time
|
|
from datetime import datetime
|
|
from typing import Dict, Any, Optional
|
|
|
|
logger = logging.getLogger(__name__)


class TradingExecutorFix:
    """Runtime patches for a trading executor instance.

    Each helper below replaces (or adds) bound methods on ONE executor
    instance; the executor's class is never modified.  Use
    :meth:`apply_fixes` as the single entry point.
    """

    # Attribute set on a patched executor so the patches are applied only
    # once; re-applying would wrap the wrappers and reset their state.
    _APPLIED_MARKER = '_trading_executor_fixes_applied'

    @staticmethod
    def apply_fixes(trading_executor):
        """Apply all fixes to the trading executor.

        Args:
            trading_executor: Object exposing at least ``execute_action`` and
                ``get_current_price``; ``_calculate_pnl`` and
                ``close_position`` are patched only if usable.

        Returns:
            The same ``trading_executor`` object, patched in place.  Calling
            this again on an already patched executor is a no-op.
        """
        if getattr(trading_executor, TradingExecutorFix._APPLIED_MARKER, False):
            logger.info("TradingExecutor fixes already applied, skipping")
            return trading_executor

        logger.info("Applying TradingExecutor fixes...")

        # Capture the unpatched methods before any helper replaces them.
        original_execute_action = trading_executor.execute_action
        original_calculate_pnl = getattr(trading_executor, '_calculate_pnl', None)

        TradingExecutorFix._fix_price_caching(trading_executor)
        TradingExecutorFix._fix_pnl_calculation(trading_executor, original_calculate_pnl)
        TradingExecutorFix._fix_execute_action(trading_executor, original_execute_action)
        TradingExecutorFix._add_trade_cooldown(trading_executor)
        TradingExecutorFix._fix_position_tracking(trading_executor)

        setattr(trading_executor, TradingExecutorFix._APPLIED_MARKER, True)
        logger.info("TradingExecutor fixes applied successfully")
        return trading_executor

    @staticmethod
    def _fix_price_caching(trading_executor, max_cache_age=5.0):
        """Wrap ``get_current_price`` so stale cached prices are dropped.

        Args:
            trading_executor: Executor to patch in place.
            max_cache_age: Seconds after which the entry for a symbol in
                ``current_prices`` is deleted before delegating to the
                original ``get_current_price``.
        """
        # symbol -> time.time() of the last (assumed) refresh of its price
        trading_executor.price_cache_timestamp = {}

        original_get_current_price = trading_executor.get_current_price

        def get_current_price_fixed(self, symbol):
            """Return the price for ``symbol``, invalidating a stale cache entry first."""
            now = time.time()
            last_refresh = self.price_cache_timestamp.get(symbol)
            expired = last_refresh is not None and (now - last_refresh) > max_cache_age

            if expired:
                cache_age = now - last_refresh
                # NOTE(review): `current_prices` is presumably the executor's
                # own price cache; it is only touched if it exists.
                cached_prices = getattr(self, 'current_prices', None)
                if cached_prices is not None and symbol in cached_prices:
                    del cached_prices[symbol]
                    logger.debug(f"Price cache for {symbol} invalidated (age: {cache_age:.1f}s)")

            price = original_get_current_price(symbol)

            # Advance the timestamp only on first sight or after an
            # invalidation.  Updating it on every call would measure time
            # since the last *access*, so a frequently polled symbol would
            # never be considered stale.
            if last_refresh is None or expired:
                self.price_cache_timestamp[symbol] = now

            return price

        trading_executor.get_current_price = get_current_price_fixed.__get__(trading_executor)
        logger.info("Price caching fix applied")

    @staticmethod
    def _fix_pnl_calculation(trading_executor, original_calculate_pnl):
        """Install a side-aware ``_calculate_pnl`` on the executor.

        Args:
            trading_executor: Executor to patch in place.
            original_calculate_pnl: The executor's previous ``_calculate_pnl``
                (or ``None``); used only to pick the log message.
        """
        def calculate_pnl_fixed(self, position, current_price=None):
            """Compute gross/net P&L for ``position``.

            Args:
                position: Object with ``entry_price``, ``size``, ``side`` and
                    ``symbol`` attributes; ``fees`` is optional.
                current_price: Price to mark against; fetched through
                    ``self.get_current_price`` when ``None``.

            Returns:
                Dict with ``gross_pnl``, ``fees`` and ``net_pnl``.  On any
                error the error is logged and all three values are ``0.0``.
            """
            try:
                entry_price = position.entry_price
                size = position.size
                side = position.side

                if current_price is None:
                    current_price = self.get_current_price(position.symbol)

                # Accept 'long' / 'Long' as well; previously any value other
                # than the exact string 'LONG' was priced as a short.
                if isinstance(side, str):
                    side = side.strip().upper()

                if side == 'LONG':
                    pnl = (current_price - entry_price) * size
                else:  # SHORT
                    pnl = (entry_price - current_price) * size

                # A missing or None `fees` attribute counts as no fees.
                fees = getattr(position, 'fees', 0.0) or 0.0

                return {
                    'gross_pnl': pnl,
                    'fees': fees,
                    'net_pnl': pnl - fees,
                }

            except Exception as e:
                logger.error(f"Error calculating P&L: {e}", exc_info=True)
                return {'gross_pnl': 0.0, 'fees': 0.0, 'net_pnl': 0.0}

        # The method is installed either way; only the log message differs.
        trading_executor._calculate_pnl = calculate_pnl_fixed.__get__(trading_executor)
        if original_calculate_pnl:
            logger.info("P&L calculation fix applied")
        else:
            logger.info("P&L calculation method added")

    @staticmethod
    def _fix_execute_action(trading_executor, original_execute_action):
        """Wrap ``execute_action`` with duplicate-entry and cooldown guards.

        Args:
            trading_executor: Executor to patch in place.
            original_execute_action: The executor's unpatched bound
                ``execute_action``; called once the guards pass.
        """
        def execute_action_fixed(self, decision):
            """Execute ``decision`` unless it is a duplicate or in cooldown.

            Args:
                decision: Object with ``symbol`` and ``action`` attributes.
                    ``price`` is overwritten with the freshly fetched price;
                    ``blocked`` / ``blocked_reason`` are set when refused.

            Returns:
                The original ``execute_action`` result, or ``False`` when the
                decision was blocked or an exception was raised.
            """
            try:
                symbol = decision.symbol
                action = decision.action
                fresh_price = None

                recent_entries = getattr(self, 'recent_entries', None) or {}
                recent_entry = recent_entries.get(symbol)

                # Duplicate check: same symbol, within 0.1% of the last
                # recorded entry price and within 60 seconds of it.
                # NOTE(review): the recorded 'action' is not compared, so a
                # different action at a similar price is blocked too.
                if recent_entry is not None:
                    fresh_price = self.get_current_price(symbol)
                    reference_price = recent_entry['price']

                    # A zero/None reference or a missing current price cannot
                    # be compared; skip the check instead of raising.
                    if fresh_price is not None and reference_price:
                        price_diff_pct = abs(fresh_price - reference_price) / reference_price * 100
                        time_diff = time.time() - recent_entry['timestamp']

                        if price_diff_pct < 0.1 and time_diff < 60:
                            logger.warning(f"Preventing duplicate entry for {symbol} at ${fresh_price:.2f} "
                                           f"(recent entry: ${reference_price:.2f}, {time_diff:.1f}s ago)")
                            decision.blocked = True
                            decision.blocked_reason = "Duplicate entry prevention"
                            return False

                # Cooldown check (method is added by _add_trade_cooldown).
                if hasattr(self, '_check_trade_cooldown') and not self._check_trade_cooldown(symbol, action):
                    decision.blocked = True
                    decision.blocked_reason = "Trade cooldown active"
                    return False

                # Reuse the price fetched for the duplicate check, if any.
                if fresh_price is None:
                    fresh_price = self.get_current_price(symbol)

                if fresh_price is None:
                    logger.error(f"No price available for {symbol}; action not executed")
                    decision.blocked = True
                    decision.blocked_reason = "No price available"
                    return False

                logger.info(f"Using fresh price for {symbol}: ${fresh_price:.2f}")
                decision.price = fresh_price

                result = original_execute_action(decision)

                # Record successful, non-blocked executions for later guards.
                if result and not getattr(decision, 'blocked', False):
                    now = time.time()

                    if not hasattr(self, 'recent_entries'):
                        self.recent_entries = {}
                    self.recent_entries[symbol] = {
                        'price': fresh_price,
                        'timestamp': now,
                        'action': action,
                    }

                    if not hasattr(self, 'last_trade_time'):
                        self.last_trade_time = {}
                    self.last_trade_time[symbol] = now

                return result

            except Exception as e:
                logger.error(f"Error in execute_action_fixed: {e}", exc_info=True)
                return False

        trading_executor.execute_action = execute_action_fixed.__get__(trading_executor)

        if not hasattr(trading_executor, 'recent_entries'):
            trading_executor.recent_entries = {}

        logger.info("Execute action fix applied")

    @staticmethod
    def _add_trade_cooldown(trading_executor, cooldown_seconds=30):
        """Add a per-symbol cooldown check to the executor.

        Args:
            trading_executor: Executor to patch in place.
            cooldown_seconds: Minimum seconds between recorded trades on the
                same symbol; stored as ``trade_cooldown_seconds``.
        """
        trading_executor.trade_cooldown_seconds = cooldown_seconds

        if not hasattr(trading_executor, 'last_trade_time'):
            trading_executor.last_trade_time = {}

        def check_trade_cooldown(self, symbol, action):
            """Return ``True`` if ``symbol`` may trade now, ``False`` while cooling down.

            ``action`` is accepted for interface compatibility and not used.
            """
            if not hasattr(self, 'last_trade_time'):
                self.last_trade_time = {}
                return True

            if symbol not in self.last_trade_time:
                return True

            time_since_last = time.time() - self.last_trade_time[symbol]

            if time_since_last < self.trade_cooldown_seconds:
                logger.warning(f"Trade cooldown active for {symbol}: {time_since_last:.1f}s elapsed, "
                               f"need {self.trade_cooldown_seconds}s")
                return False

            return True

        trading_executor._check_trade_cooldown = check_trade_cooldown.__get__(trading_executor)
        logger.info("Trade cooldown feature added")

    @staticmethod
    def _fix_position_tracking(trading_executor):
        """Wrap ``close_position`` so per-symbol state is cleared after a close.

        Does nothing (besides a warning) when the executor has no usable
        ``close_position``.
        """
        original_close_position = getattr(trading_executor, 'close_position', None)

        if not original_close_position:
            logger.warning("close_position method not found, skipping position tracking fix")
            return

        def close_position_fixed(self, symbol, price=None):
            """Close ``symbol`` via the original method, then drop its tracked state.

            Returns:
                The original ``close_position`` result, or ``False`` if the
                original call raised.
            """
            try:
                result = original_close_position(symbol, price)
            except Exception as e:
                logger.error(f"Error in close_position_fixed: {e}", exc_info=True)
                return False

            # An explicit False means the close did not happen; keep tracking
            # the position rather than forgetting something still open.
            if result is False:
                return result

            try:
                positions = getattr(self, 'positions', None)
                if positions is not None and symbol in positions:
                    del positions[symbol]

                recent_entries = getattr(self, 'recent_entries', None)
                if recent_entries is not None and symbol in recent_entries:
                    del recent_entries[symbol]
            except Exception as e:
                # The close itself succeeded; report that, not the cleanup error.
                logger.error(f"Error in close_position_fixed: {e}", exc_info=True)
                return result

            logger.info(f"Position for {symbol} fully cleaned up after close")
            return result

        trading_executor.close_position = close_position_fixed.__get__(trading_executor)
        logger.info("Position tracking fix applied")