"""
Trading Executor Fix - Addresses issues with entry/exit prices and P&L calculations

This module provides fixes for:
1. Identical entry prices issue
2. Price caching problems
3. Position tracking reset logic
4. Trade cooldown implementation
5. P&L calculation verification

Apply these fixes to the TradingExecutor class to improve trade execution reliability.
"""

import logging
import time
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Union

logger = logging.getLogger(__name__)


class TradingExecutorFix:
    """
    Fixes for the TradingExecutor class to address entry/exit price issues
    and improve P&L calculation accuracy.

    The fixes are installed by replacing methods on the given executor
    *instance* (see ``apply_fixes``); the executor's class is not modified.
    """

    def __init__(self, trading_executor, *, min_trade_cooldown=30,
                 max_price_history=10, price_update_threshold=5):
        """
        Initialize the fix with a reference to the trading executor

        Args:
            trading_executor: The TradingExecutor instance to fix
            min_trade_cooldown: Minimum seconds between two successful trades
                on the same symbol (default 30).
            max_price_history: Number of recent entry prices kept per symbol
                for duplicate-entry detection (default 10).
            price_update_threshold: Maximum age in seconds of the last price
                update before a trade is rejected as stale (default 5).
        """
        self.trading_executor = trading_executor

        # Cooldown tracking
        self.last_trade_time = {}  # {symbol: datetime of last successful trade}
        self.min_trade_cooldown = min_trade_cooldown

        # Price history for duplicate-entry validation
        self.recent_entry_prices = {}  # {symbol: [recent entry prices]}
        self.max_price_history = max_price_history

        # Position reset tracking: False after an entry, True after a close
        self.position_reset_flags = {}  # {symbol: bool}

        # Price update tracking
        self.last_price_update = {}  # {symbol: datetime of last price update}
        self.price_update_threshold = price_update_threshold

        # P&L verification
        self.trade_history = {}  # {symbol: [trade_records]}

        # Prevents wrapping the executor's methods more than once
        self._fixes_applied = False

        logger.info("TradingExecutorFix initialized - addressing entry/exit price issues")

    def apply_fixes(self):
        """Apply all fixes to the trading executor.

        Calling this more than once is a no-op after the first call, so the
        validation and bookkeeping wrappers are never stacked on each other.
        """
        if self._fixes_applied:
            logger.warning("TradingExecutorFix fixes already applied - skipping")
            return
        self._patch_execute_action()
        self._patch_close_position()
        self._patch_calculate_pnl()
        self._patch_update_prices()
        self._fixes_applied = True
        logger.info("All trading executor fixes applied successfully")

    def _patch_execute_action(self):
        """Patch the execute_action method to add price validation and cooldown"""
        original_execute_action = self.trading_executor.execute_action

        def execute_action_with_fixes(decision):
            """Enhanced execute_action with price validation and cooldown.

            Runs the pre-trade checks, then delegates to the original
            ``execute_action``. The original is invoked at most once per call:
            its own exceptions propagate unchanged, and a failure in the
            post-trade bookkeeping is logged without re-executing the trade.

            Args:
                decision: Object exposing ``symbol`` and ``action`` attributes.

            Returns:
                ``False`` when a check rejects the trade, otherwise whatever
                the original ``execute_action`` returned.
            """
            try:
                symbol = decision.symbol
                action = decision.action
                current_time = datetime.now()

                # 1. Check cooldown period
                if symbol in self.last_trade_time:
                    time_since_last_trade = (current_time - self.last_trade_time[symbol]).total_seconds()
                    if time_since_last_trade < self.min_trade_cooldown:
                        logger.warning(f"Trade rejected: Cooldown period ({time_since_last_trade:.1f}s < {self.min_trade_cooldown}s) for {symbol}")
                        return False

                # 2. Validate price freshness
                if symbol in self.last_price_update:
                    time_since_update = (current_time - self.last_price_update[symbol]).total_seconds()
                    if time_since_update > self.price_update_threshold:
                        logger.warning(f"Trade rejected: Price data stale ({time_since_update:.1f}s > {self.price_update_threshold}s) for {symbol}")
                        # Force a refresh; this trade is still rejected and a
                        # later call can proceed on the refreshed price.
                        self._refresh_price(symbol)
                        return False

                # 3. Validate entry price against recent history
                current_price = self._get_current_price(symbol)
                if current_price in self.recent_entry_prices.get(symbol, ()):
                    logger.warning(f"Trade rejected: Duplicate entry price ${current_price} for {symbol}")
                    return False

                # 4. Ensure position is properly reset before new entry
                if not self._ensure_position_reset(symbol):
                    logger.warning(f"Trade rejected: Position not properly reset for {symbol}")
                    return False
            except Exception as e:
                logger.error(f"Error in execute_action_with_fixes: {e}")
                # Fail open: the checks themselves broke before anything was
                # executed, so fall back to the unvalidated original action.
                return original_execute_action(decision)

            # Execute the original action exactly once.
            result = original_execute_action(decision)

            if result:
                try:
                    self._record_entry(symbol, current_price, current_time)
                    logger.info(f"Trade executed: {action} {symbol} at ${current_price} with validation")
                except Exception as e:
                    # The trade already went through: log, never re-execute.
                    logger.error(f"Error in execute_action_with_fixes: {e}")

            return result

        # Replace the original method
        self.trading_executor.execute_action = execute_action_with_fixes
        logger.info("Patched execute_action with price validation and cooldown")

    def _record_entry(self, symbol, entry_price, trade_time):
        """Update cooldown, entry-price history and reset flag after a successful trade."""
        self.last_trade_time[symbol] = trade_time

        history = self.recent_entry_prices.setdefault(symbol, [])
        history.append(entry_price)
        # Keep only the most recent prices (also correct when the limit is 0).
        excess = len(history) - self.max_price_history
        if excess > 0:
            del history[:excess]

        # Mark position as active
        self.position_reset_flags[symbol] = False

    def _patch_close_position(self):
        """Patch the close_position method to ensure proper position reset"""
        original_close_position = self.trading_executor.close_position

        def close_position_with_fixes(symbol, *args, **kwargs):
            """Enhanced close_position with proper reset logic.

            Captures the exit price and the position object *before* closing,
            because the original close may remove the position from
            ``positions``. The original close is invoked at most once per call.

            Args:
                symbol: Symbol whose position is closed.
                *args, **kwargs: Forwarded unchanged to the original method.

            Returns:
                Whatever the original ``close_position`` returned.
            """
            try:
                # Get current price for P&L verification
                exit_price = self._get_current_price(symbol)

                # Snapshot the position before it can be removed by the close.
                position = None
                positions = getattr(self.trading_executor, 'positions', None)
                if positions is not None and symbol in positions:
                    position = positions[symbol]
            except Exception as e:
                logger.error(f"Error in close_position_with_fixes: {e}")
                # Nothing has been closed yet; delegate to the original.
                return original_close_position(symbol, *args, **kwargs)

            result = original_close_position(symbol, *args, **kwargs)

            if result:
                # Mark position as reset
                self.position_reset_flags[symbol] = True

                if position is not None:
                    try:
                        self._record_closed_trade(symbol, position, exit_price)
                    except Exception as e:
                        # The position is already closed: log, never close twice.
                        logger.error(f"Error in close_position_with_fixes: {e}")

            return result

        # Replace the original method
        self.trading_executor.close_position = close_position_with_fixes
        logger.info("Patched close_position with proper reset logic")

    def _record_closed_trade(self, symbol, position, exit_price):
        """Build a verified trade record for a closed position, store and return it."""
        now = datetime.now()
        entry_time = getattr(position, 'entry_time', now)
        trade_record = {
            'symbol': symbol,
            'entry_time': entry_time,
            'exit_time': now,
            'entry_price': getattr(position, 'entry_price', 0),
            'exit_price': exit_price,
            'size': getattr(position, 'size', 0),
            'side': getattr(position, 'side', 'UNKNOWN'),
            'pnl': self._calculate_verified_pnl(position, exit_price),
            'fees': getattr(position, 'fees', 0),
            'hold_time_seconds': self._hold_time_seconds(entry_time, now),
        }

        # Store trade record
        self.trade_history.setdefault(symbol, []).append(trade_record)
        logger.info(f"Position closed: {symbol} at ${exit_price} with verified P&L: ${trade_record['pnl']:.2f}")
        return trade_record

    @staticmethod
    def _hold_time_seconds(entry_time, now):
        """Seconds between ``entry_time`` and ``now``.

        Accepts a ``datetime`` or a numeric epoch timestamp; any other type
        yields 0.0. Mixing naive and aware datetimes raises ``TypeError``.
        """
        if isinstance(entry_time, datetime):
            return (now - entry_time).total_seconds()
        if isinstance(entry_time, (int, float)) and not isinstance(entry_time, bool):
            return now.timestamp() - float(entry_time)
        return 0.0

    def _patch_calculate_pnl(self):
        """Patch the calculate_pnl method to ensure accurate P&L calculation"""
        original_calculate_pnl = getattr(self.trading_executor, 'calculate_pnl', None)

        def calculate_pnl_with_fixes(position, current_price=None):
            """Enhanced calculate_pnl with verification.

            Returns the verified P&L when there is no original method or when
            the original differs from it by more than 0.01; otherwise returns
            the original method's value.
            """
            try:
                # If no original method, implement our own
                if original_calculate_pnl is None:
                    return self._calculate_verified_pnl(position, current_price)

                original_pnl = original_calculate_pnl(position, current_price)
                verified_pnl = self._calculate_verified_pnl(position, current_price)

                # If there's a significant difference, log it and prefer ours
                if abs(original_pnl - verified_pnl) > 0.01:
                    logger.warning(f"P&L calculation discrepancy: original=${original_pnl:.2f}, verified=${verified_pnl:.2f}")
                    return verified_pnl

                return original_pnl
            except Exception as e:
                logger.error(f"Error in calculate_pnl_with_fixes: {e}")
                if original_calculate_pnl:
                    return original_calculate_pnl(position, current_price)
                return 0.0

        # Install the wrapper whether or not the executor had the method
        self.trading_executor.calculate_pnl = calculate_pnl_with_fixes
        if original_calculate_pnl:
            logger.info("Patched calculate_pnl with verification")
        else:
            logger.info("Added calculate_pnl method with verification")

    def _patch_update_prices(self):
        """Patch the update_prices method to track price updates"""
        original_update_prices = getattr(self.trading_executor, 'update_prices', None)

        def update_prices_with_tracking(prices):
            """Enhanced update_prices with timestamp tracking.

            Applies the price update once, then records an update timestamp
            for every symbol in ``prices``. A tracking failure is logged and
            does not re-apply the update.
            """
            try:
                # Call original method if it exists
                if original_update_prices:
                    result = original_update_prices(prices)
                else:
                    # If no original method, update prices directly
                    if hasattr(self.trading_executor, 'current_prices'):
                        self.trading_executor.current_prices.update(prices)
                    result = True
            except Exception as e:
                logger.error(f"Error in update_prices_with_tracking: {e}")
                if original_update_prices:
                    return original_update_prices(prices)
                return False

            # Track update timestamps
            try:
                current_time = datetime.now()
                for symbol in prices:
                    self.last_price_update[symbol] = current_time
            except Exception as e:
                logger.error(f"Error in update_prices_with_tracking: {e}")

            return result

        # Install the wrapper whether or not the executor had the method
        self.trading_executor.update_prices = update_prices_with_tracking
        if original_update_prices:
            logger.info("Patched update_prices with timestamp tracking")
        else:
            logger.info("Added update_prices method with timestamp tracking")

    def _calculate_verified_pnl(self, position, current_price=None):
        """Calculate verified net P&L for a position.

        ``(price move) * size * leverage - fees``, where the price move is
        ``current - entry`` for LONG and ``entry - current`` for SHORT (side
        matched case-insensitively). Unknown sides, a missing symbol when no
        price is given, or any error yield 0.0.
        """
        try:
            # Get position details
            entry_price = getattr(position, 'entry_price', 0)
            size = getattr(position, 'size', 0)
            side = str(getattr(position, 'side', 'UNKNOWN')).upper()
            leverage = getattr(position, 'leverage', 1.0)
            fees = getattr(position, 'fees', 0.0)

            # If current_price is not provided, try to get it
            if current_price is None:
                symbol = getattr(position, 'symbol', None)
                if not symbol:
                    return 0.0
                current_price = self._get_current_price(symbol)

            # Calculate P&L based on position side
            if side == 'LONG':
                pnl = (current_price - entry_price) * size * leverage
            elif side == 'SHORT':
                pnl = (entry_price - current_price) * size * leverage
            else:
                pnl = 0.0

            # Subtract fees for net P&L
            return pnl - fees
        except Exception as e:
            logger.error(f"Error calculating verified P&L: {e}")
            return 0.0

    def _get_current_price(self, symbol):
        """Get current price for a symbol with fallbacks.

        Order: executor ``current_prices``, then the data provider's
        ``get_current_price`` (positive values only), then the COB
        ``stats['mid_price']``; 0.0 if none is available or on error.
        """
        try:
            # Try to get from trading executor
            if hasattr(self.trading_executor, 'current_prices') and symbol in self.trading_executor.current_prices:
                return self.trading_executor.current_prices[symbol]

            # Try to get from data provider
            if hasattr(self.trading_executor, 'data_provider'):
                data_provider = self.trading_executor.data_provider
                if hasattr(data_provider, 'get_current_price'):
                    price = data_provider.get_current_price(symbol)
                    if price and price > 0:
                        return price

            # Try to get from COB data
            if hasattr(self.trading_executor, 'latest_cob_data') and symbol in self.trading_executor.latest_cob_data:
                cob_data = self.trading_executor.latest_cob_data[symbol]
                if hasattr(cob_data, 'stats') and 'mid_price' in cob_data.stats:
                    return cob_data.stats['mid_price']

            # Default fallback
            return 0.0
        except Exception as e:
            logger.error(f"Error getting current price for {symbol}: {e}")
            return 0.0

    def _refresh_price(self, symbol):
        """Force a price refresh for a symbol via the data provider.

        Returns True when a positive price was fetched (and stored in the
        executor's ``current_prices`` if present), False otherwise.
        """
        try:
            # Try to refresh from data provider
            if hasattr(self.trading_executor, 'data_provider'):
                data_provider = self.trading_executor.data_provider
                if hasattr(data_provider, 'fetch_current_price'):
                    price = data_provider.fetch_current_price(symbol)
                    if price and price > 0:
                        # Update trading executor price
                        if hasattr(self.trading_executor, 'current_prices'):
                            self.trading_executor.current_prices[symbol] = price
                        # Update timestamp
                        self.last_price_update[symbol] = datetime.now()
                        logger.info(f"Refreshed price for {symbol}: ${price:.2f}")
                        return True

            logger.warning(f"Failed to refresh price for {symbol}")
            return False
        except Exception as e:
            logger.error(f"Error refreshing price for {symbol}: {e}")
            return False

    def _ensure_position_reset(self, symbol):
        """Ensure position is properly reset before new entry.

        Returns False when the executor holds a position for ``symbol`` whose
        ``active`` attribute is truthy. Otherwise, if the last recorded event
        was an entry (reset flag False), the executor's position entry is
        dropped and the flag set back to True.

        NOTE(review): positions without an ``active`` attribute are treated
        as inactive and are removed here - confirm this matches the
        executor's position model.
        """
        try:
            # Check if we have an active position
            if hasattr(self.trading_executor, 'positions') and symbol in self.trading_executor.positions:
                position = self.trading_executor.positions[symbol]
                if position and getattr(position, 'active', False):
                    logger.warning(f"Position already active for {symbol}, cannot enter new position")
                    return False

            # Check reset flag
            if symbol in self.position_reset_flags and not self.position_reset_flags[symbol]:
                # Force position cleanup
                if hasattr(self.trading_executor, 'positions'):
                    self.trading_executor.positions.pop(symbol, None)
                logger.info(f"Forced position reset for {symbol}")
                self.position_reset_flags[symbol] = True

            return True
        except Exception as e:
            logger.error(f"Error ensuring position reset for {symbol}: {e}")
            return False

    def get_trade_history(self, symbol=None):
        """Get verified trade history for one symbol, or the whole mapping."""
        if symbol:
            return self.trade_history.get(symbol, [])
        return self.trade_history

    def get_price_update_status(self):
        """Get price update status (last update, age, freshness) for all symbols"""
        status = {}
        current_time = datetime.now()
        for symbol, timestamp in self.last_price_update.items():
            time_since_update = (current_time - timestamp).total_seconds()
            status[symbol] = {
                'last_update': timestamp,
                'seconds_ago': time_since_update,
                'is_fresh': time_since_update <= self.price_update_threshold
            }
        return status