"""
Trading Executor Fix - addresses issues with entry/exit prices and P&L calculations.

This module provides fixes for:

1. Identical entry prices
2. Price caching problems
3. Position-tracking reset logic
4. Trade cooldown implementation
5. P&L calculation verification

Apply these fixes to a TradingExecutor instance to improve trade execution
reliability, e.g. ``TradingExecutorFix(executor).apply_fixes()``.
"""

import logging
import time
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Union

logger = logging.getLogger(__name__)


class TradingExecutorFix:
    """
    Runtime fixes for a TradingExecutor instance that address entry/exit price
    issues and improve P&L calculation accuracy.

    ``apply_fixes()`` wraps ``execute_action``, ``close_position``,
    ``calculate_pnl`` and ``update_prices`` on the executor *instance* passed
    to the constructor (``calculate_pnl`` / ``update_prices`` are added if the
    executor does not define them). The executor's class is not modified.

    Example::

        fix = TradingExecutorFix(executor)
        fix.apply_fixes()
    """

    def __init__(self, trading_executor, min_trade_cooldown=30,
                 price_update_threshold=5, max_price_history=10):
        """
        Initialize the fix with a reference to the trading executor.

        Args:
            trading_executor: The TradingExecutor instance to patch.
            min_trade_cooldown: Minimum seconds between successful trades on
                the same symbol (default 30).
            price_update_threshold: Maximum age, in seconds, of the last
                tracked price update before a trade is rejected as stale
                (default 5).
            max_price_history: Number of recent entry prices remembered per
                symbol for duplicate-price detection (default 10).
        """
        self.trading_executor = trading_executor

        # Cooldown tracking
        self.last_trade_time = {}  # {symbol: datetime of last successful trade}
        self.min_trade_cooldown = min_trade_cooldown

        # Entry-price history used to reject duplicate entry prices
        self.recent_entry_prices = {}  # {symbol: [recent_prices]}
        self.max_price_history = max_price_history

        # Position reset tracking: False after a trade executes, True after a close
        self.position_reset_flags = {}  # {symbol: bool}

        # Price freshness tracking
        self.last_price_update = {}  # {symbol: datetime}
        self.price_update_threshold = price_update_threshold

        # Verified trade records written by the patched close_position
        self.trade_history = {}  # {symbol: [trade_records]}

        # Prevents wrapping the executor's methods more than once
        self._fixes_applied = False

        logger.info("TradingExecutorFix initialized - addressing entry/exit price issues")

    def apply_fixes(self):
        """Apply all fixes to the trading executor.

        Calling this more than once is a no-op: re-wrapping would stack the
        validation layers and double every bookkeeping step.
        """
        if getattr(self, '_fixes_applied', False):
            logger.info("Trading executor fixes already applied; skipping")
            return

        self._patch_execute_action()
        self._patch_close_position()
        self._patch_calculate_pnl()
        self._patch_update_prices()
        self._fixes_applied = True

        logger.info("All trading executor fixes applied successfully")

    # ------------------------------------------------------------------
    # Patches
    # ------------------------------------------------------------------

    def _patch_execute_action(self):
        """Wrap execute_action with cooldown, price and position validation."""
        original_execute_action = self.trading_executor.execute_action

        def execute_action_with_fixes(decision, *args, **kwargs):
            """Validate ``decision`` then delegate to the original execute_action.

            Returns False without trading when a validation check fails;
            otherwise returns whatever the original method returns. The
            original method is called at most once per invocation, and any
            exception it raises propagates unchanged, so a trade is never
            retried (and therefore never executed twice) by this wrapper.
            """
            try:
                symbol = decision.symbol
                action = decision.action
                current_time = datetime.now()
                allowed, current_price = self._validate_trade(symbol, current_time)
            except Exception as e:
                # Validation itself failed (e.g. malformed decision): fall back
                # to the unpatched behavior, exactly once.
                logger.error(f"Error in execute_action_with_fixes: {e}")
                return original_execute_action(decision, *args, **kwargs)

            if not allowed:
                return False

            result = original_execute_action(decision, *args, **kwargs)

            if result:
                try:
                    self._record_entry(symbol, current_price, current_time)
                    logger.info(f"Trade executed: {action} {symbol} at ${current_price} with validation")
                except Exception as e:
                    # The trade has already been executed; only bookkeeping failed.
                    logger.error(f"Error in execute_action_with_fixes: {e}")

            return result

        # Replace the original method on the instance
        self.trading_executor.execute_action = execute_action_with_fixes
        logger.info("Patched execute_action with price validation and cooldown")

    def _patch_close_position(self):
        """Wrap close_position to reset tracking and record a verified P&L."""
        original_close_position = self.trading_executor.close_position

        def close_position_with_fixes(symbol, *args, **kwargs):
            """Close ``symbol`` via the original method, then record the trade.

            The original close_position is called exactly once; its return
            value is passed through and its exceptions propagate.
            """
            try:
                # Exit price for P&L verification
                exit_price = self._get_current_price(symbol)
                # Snapshot before closing: the executor may remove the position
                # from ``positions`` as part of the close.
                position = self._get_tracked_position(symbol)
            except Exception as e:
                logger.error(f"Error in close_position_with_fixes: {e}")
                return original_close_position(symbol, *args, **kwargs)

            result = original_close_position(symbol, *args, **kwargs)

            if result:
                try:
                    # Mark position as reset so a new entry is permitted
                    self.position_reset_flags[symbol] = True
                    if position is not None:
                        self._record_closed_trade(symbol, position, exit_price)
                except Exception as e:
                    # The close already happened; only bookkeeping failed.
                    logger.error(f"Error in close_position_with_fixes: {e}")

            return result

        # Replace the original method on the instance
        self.trading_executor.close_position = close_position_with_fixes
        logger.info("Patched close_position with proper reset logic")

    def _patch_calculate_pnl(self):
        """Wrap (or add) calculate_pnl so results are cross-checked."""
        original_calculate_pnl = getattr(self.trading_executor, 'calculate_pnl', None)

        def calculate_pnl_with_fixes(position, current_price=None):
            """Return P&L for ``position``, preferring the verified value on disagreement.

            Without an original method, the verified P&L is returned. Otherwise
            the original result is returned unless the verified P&L can be
            computed and differs by more than 0.01, in which case the verified
            value is returned. If no valid price is available for verification,
            the original result is trusted.
            """
            if original_calculate_pnl is None:
                return self._calculate_verified_pnl(position, current_price)

            # Called once; its exceptions propagate to the caller.
            original_pnl = original_calculate_pnl(position, current_price)

            try:
                verified_pnl = self._compute_pnl(position, current_price)
                if verified_pnl is not None and abs(original_pnl - verified_pnl) > 0.01:
                    logger.warning(f"P&L calculation discrepancy: original=${original_pnl:.2f}, verified=${verified_pnl:.2f}")
                    return verified_pnl
            except Exception as e:
                logger.error(f"Error in calculate_pnl_with_fixes: {e}")

            return original_pnl

        self.trading_executor.calculate_pnl = calculate_pnl_with_fixes
        if original_calculate_pnl is not None:
            logger.info("Patched calculate_pnl with verification")
        else:
            logger.info("Added calculate_pnl method with verification")

    def _patch_update_prices(self):
        """Wrap (or add) update_prices so price-update times are tracked."""
        original_update_prices = getattr(self.trading_executor, 'update_prices', None)

        def update_prices_with_tracking(prices):
            """Update prices, then stamp every symbol in ``prices`` as updated now.

            With an original method, its result is returned (it is called once
            and its exceptions propagate). Without one, ``prices`` is merged
            into the executor's ``current_prices`` if that attribute exists;
            returns True on success and False if the merge fails.
            """
            if original_update_prices is not None:
                result = original_update_prices(prices)
            else:
                try:
                    if hasattr(self.trading_executor, 'current_prices'):
                        self.trading_executor.current_prices.update(prices)
                except Exception as e:
                    logger.error(f"Error in update_prices_with_tracking: {e}")
                    return False
                result = True

            try:
                current_time = datetime.now()
                for symbol in prices:
                    self.last_price_update[symbol] = current_time
            except Exception as e:
                logger.error(f"Error in update_prices_with_tracking: {e}")

            return result

        self.trading_executor.update_prices = update_prices_with_tracking
        if original_update_prices is not None:
            logger.info("Patched update_prices with timestamp tracking")
        else:
            logger.info("Added update_prices method with timestamp tracking")

    # ------------------------------------------------------------------
    # Trade validation and bookkeeping helpers
    # ------------------------------------------------------------------

    def _validate_trade(self, symbol, current_time):
        """Run the pre-trade checks for ``symbol``.

        Returns:
            ``(allowed, current_price)`` where ``current_price`` is the price
            looked up for the entry (0.0 if unavailable, None if the check
            stopped before a lookup).
        """
        # 1. Cooldown period.
        # NOTE(review): this applies to every action on the symbol, including
        # exits routed through execute_action — confirm that is intended.
        last_trade = self.last_trade_time.get(symbol)
        if last_trade is not None:
            time_since_last_trade = (current_time - last_trade).total_seconds()
            if time_since_last_trade < self.min_trade_cooldown:
                logger.warning(f"Trade rejected: Cooldown period ({time_since_last_trade:.1f}s < {self.min_trade_cooldown}s) for {symbol}")
                return False, None

        # 2. Price freshness (only for symbols whose updates are tracked)
        last_update = self.last_price_update.get(symbol)
        if last_update is not None:
            time_since_update = (current_time - last_update).total_seconds()
            if time_since_update > self.price_update_threshold:
                logger.warning(f"Trade rejected: Price data stale ({time_since_update:.1f}s > {self.price_update_threshold}s) for {symbol}")
                # Refresh so a later attempt can pass; this one is still rejected.
                self._refresh_price(symbol)
                return False, None

        # 3. Entry price must differ from recent entries. A failed lookup
        # (0.0) is not a real price and is never compared.
        current_price = self._get_current_price(symbol)
        if self._is_valid_price(current_price) and current_price in self.recent_entry_prices.get(symbol, ()):
            logger.warning(f"Trade rejected: Duplicate entry price ${current_price} for {symbol}")
            return False, current_price

        # 4. Position must be reset before a new entry
        if not self._ensure_position_reset(symbol):
            logger.warning(f"Trade rejected: Position not properly reset for {symbol}")
            return False, current_price

        return True, current_price

    def _record_entry(self, symbol, price, trade_time):
        """Update cooldown, entry-price history and reset flag after a trade."""
        self.last_trade_time[symbol] = trade_time

        # Only real prices are remembered; storing a failed lookup (0.0) would
        # make every later lookup failure look like a duplicate entry price.
        if self._is_valid_price(price):
            history = self.recent_entry_prices.setdefault(symbol, [])
            history.append(price)
            # Keep only the most recent prices
            if len(history) > self.max_price_history:
                del history[:-self.max_price_history]

        # Mark position as active (not reset)
        self.position_reset_flags[symbol] = False

    def _record_closed_trade(self, symbol, position, exit_price):
        """Append a verified trade record for a closed position and return it."""
        exit_time = datetime.now()
        entry_time = getattr(position, 'entry_time', exit_time)
        # Hold time is only computable when entry_time is a datetime.
        if isinstance(entry_time, datetime):
            hold_time_seconds = (exit_time - entry_time).total_seconds()
        else:
            hold_time_seconds = None

        trade_record = {
            'symbol': symbol,
            'entry_time': entry_time,
            'exit_time': exit_time,
            'entry_price': getattr(position, 'entry_price', 0),
            'exit_price': exit_price,
            'size': getattr(position, 'size', 0),
            'side': getattr(position, 'side', 'UNKNOWN'),
            'pnl': self._calculate_verified_pnl(position, exit_price),
            'fees': getattr(position, 'fees', 0),
            'hold_time_seconds': hold_time_seconds,
        }

        self.trade_history.setdefault(symbol, []).append(trade_record)
        logger.info(f"Position closed: {symbol} at ${exit_price} with verified P&L: ${trade_record['pnl']:.2f}")
        return trade_record

    def _get_tracked_position(self, symbol):
        """Return the executor's ``positions[symbol]`` entry, or None if absent."""
        positions = getattr(self.trading_executor, 'positions', None)
        if positions is None or symbol not in positions:
            return None
        return positions[symbol]

    def _ensure_position_reset(self, symbol):
        """Return True if a new entry on ``symbol`` may proceed.

        Rejects when the tracked position is truthy and has a truthy
        ``active`` attribute. If the reset flag shows a trade was executed
        with no close_position seen since, the flag is cleared. An empty
        placeholder entry (e.g. None) is removed from ``positions``, but a
        live position object is left in place for the executor to manage
        rather than being silently discarded.
        Returns False if an unexpected error occurs.
        """
        try:
            position = self._get_tracked_position(symbol)
            if position and getattr(position, 'active', False):
                logger.warning(f"Position already active for {symbol}, cannot enter new position")
                return False

            if symbol in self.position_reset_flags and not self.position_reset_flags[symbol]:
                positions = getattr(self.trading_executor, 'positions', None)
                if position:
                    logger.warning(f"No close_position seen for {symbol} since its last trade; leaving tracked position for the executor to manage")
                elif positions is not None and symbol in positions:
                    # Only an empty placeholder is removed.
                    positions.pop(symbol, None)

                logger.info(f"Forced position reset for {symbol}")
                self.position_reset_flags[symbol] = True

            return True

        except Exception as e:
            logger.error(f"Error ensuring position reset for {symbol}: {e}")
            return False

    # ------------------------------------------------------------------
    # Pricing and P&L helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _is_valid_price(price) -> bool:
        """Return True if ``price`` is a strictly positive number."""
        if price is None or isinstance(price, (str, bytes, bool)):
            return False
        try:
            return float(price) > 0  # NaN compares False, so it is rejected
        except (TypeError, ValueError, OverflowError):
            return False

    def _compute_pnl(self, position, current_price=None) -> Optional[float]:
        """Return net P&L for ``position`` at ``current_price``.

        If ``current_price`` is missing or not positive, the price is looked
        up by the position's ``symbol``. Returns None when no valid price can
        be obtained, so callers can distinguish "unknown" from a real 0.0 P&L.
        """
        entry_price = getattr(position, 'entry_price', 0)
        size = getattr(position, 'size', 0)
        side = getattr(position, 'side', 'UNKNOWN')
        leverage = getattr(position, 'leverage', 1.0)
        fees = getattr(position, 'fees', 0.0)

        # Case-insensitive side match; str-based enums still compare correctly.
        if isinstance(side, str):
            side = side.upper()

        if not self._is_valid_price(current_price):
            symbol = getattr(position, 'symbol', None)
            current_price = self._get_current_price(symbol) if symbol else 0.0
            if not self._is_valid_price(current_price):
                return None

        # NOTE(review): P&L is scaled by leverage; confirm this matches how the
        # executor defines ``size`` (quantity vs. margin), otherwise leverage
        # is counted twice.
        if side == 'LONG':
            pnl = (current_price - entry_price) * size * leverage
        elif side == 'SHORT':
            pnl = (entry_price - current_price) * size * leverage
        else:
            pnl = 0.0

        # Subtract fees for net P&L
        return pnl - fees

    def _calculate_verified_pnl(self, position, current_price=None):
        """Calculate verified net P&L for a position.

        Returns 0.0 when no valid price is available or an error occurs.
        """
        try:
            pnl = self._compute_pnl(position, current_price)
        except Exception as e:
            logger.error(f"Error calculating verified P&L: {e}")
            return 0.0

        if pnl is None:
            logger.warning("Cannot verify P&L: no valid price available")
            return 0.0
        return pnl

    def _get_current_price(self, symbol):
        """Get a positive current price for ``symbol``; 0.0 if none is available.

        Sources tried in order, skipping any that yield a missing or
        non-positive price: the executor's ``current_prices`` cache, its data
        provider's ``get_current_price``, then ``latest_cob_data[symbol].stats['mid_price']``.
        """
        try:
            executor = self.trading_executor

            # Trading executor price cache
            cached_prices = getattr(executor, 'current_prices', None)
            if cached_prices is not None and symbol in cached_prices:
                price = cached_prices[symbol]
                if self._is_valid_price(price):
                    return price

            # Data provider
            data_provider = getattr(executor, 'data_provider', None)
            get_price = getattr(data_provider, 'get_current_price', None)
            if callable(get_price):
                price = get_price(symbol)
                if self._is_valid_price(price):
                    return price

            # COB data
            cob_by_symbol = getattr(executor, 'latest_cob_data', None)
            if cob_by_symbol is not None and symbol in cob_by_symbol:
                stats = getattr(cob_by_symbol[symbol], 'stats', None)
                if stats is not None and 'mid_price' in stats:
                    price = stats['mid_price']
                    if self._is_valid_price(price):
                        return price

            # Default fallback
            return 0.0

        except Exception as e:
            logger.error(f"Error getting current price for {symbol}: {e}")
            return 0.0

    def _refresh_price(self, symbol):
        """Force a price refresh for ``symbol`` via the data provider.

        On success, stores the price in the executor's ``current_prices`` (if
        present), stamps the update time and returns True; otherwise False.
        """
        try:
            data_provider = getattr(self.trading_executor, 'data_provider', None)
            fetch_price = getattr(data_provider, 'fetch_current_price', None)
            if callable(fetch_price):
                price = fetch_price(symbol)
                if self._is_valid_price(price):
                    # Update trading executor price
                    if hasattr(self.trading_executor, 'current_prices'):
                        self.trading_executor.current_prices[symbol] = price

                    # Update timestamp
                    self.last_price_update[symbol] = datetime.now()

                    logger.info(f"Refreshed price for {symbol}: ${price:.2f}")
                    return True

            logger.warning(f"Failed to refresh price for {symbol}")
            return False

        except Exception as e:
            logger.error(f"Error refreshing price for {symbol}: {e}")
            return False

    # ------------------------------------------------------------------
    # Reporting
    # ------------------------------------------------------------------

    def get_trade_history(self, symbol=None):
        """Return verified trade records for ``symbol``, or the full {symbol: [records]} map."""
        if symbol:
            return self.trade_history.get(symbol, [])
        return self.trade_history

    def get_price_update_status(self):
        """Return price-update freshness for every tracked symbol.

        Returns:
            dict mapping symbol -> {'last_update': datetime,
            'seconds_ago': float, 'is_fresh': bool}, where ``is_fresh`` means
            the last update is at most ``price_update_threshold`` seconds old.
        """
        current_time = datetime.now()
        status = {}
        for symbol, timestamp in self.last_price_update.items():
            time_since_update = (current_time - timestamp).total_seconds()
            status[symbol] = {
                'last_update': timestamp,
                'seconds_ago': time_since_update,
                'is_fresh': time_since_update <= self.price_update_threshold,
            }
        return status