Files
gogo2/core/trading_executor_fix.py
2025-07-23 17:33:43 +03:00

401 lines
18 KiB
Python

"""
Trading Executor Fix - Addresses issues with entry/exit prices and P&L calculations
This module provides fixes for:
1. Identical entry prices issue
2. Price caching problems
3. Position tracking reset logic
4. Trade cooldown implementation
5. P&L calculation verification
Apply these fixes to the TradingExecutor class to improve trade execution reliability.
"""
import logging
import time
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Union
logger = logging.getLogger(__name__)


class TradingExecutorFix:
    """
    Runtime patches for a TradingExecutor instance.

    The patches wrap ``execute_action``, ``close_position``, ``calculate_pnl``
    and ``update_prices`` on the wrapped executor to add:

    * a per-symbol cooldown between successful trades,
    * a price-freshness check based on ``update_prices`` timestamps,
    * rejection of an entry price identical to a recent entry price,
    * position-reset bookkeeping, and
    * an independent P&L calculation recorded per closed trade.

    Every wrapper calls the wrapped (original) method at most once, so a
    failure in the bookkeeping can never submit an order or a close twice.
    """

    def __init__(self, trading_executor):
        """
        Initialize the fix with a reference to the trading executor.

        Args:
            trading_executor: The TradingExecutor instance to patch. It must
                expose ``execute_action`` and ``close_position``; the
                ``calculate_pnl`` and ``update_prices`` methods and the
                ``positions`` / ``current_prices`` / ``data_provider`` /
                ``latest_cob_data`` attributes are used when present.
        """
        self.trading_executor = trading_executor
        # Cooldown tracking
        self.last_trade_time = {}  # {symbol: datetime of last successful trade}
        self.min_trade_cooldown = 30  # minimum seconds between trades per symbol
        # Entry price history used for the duplicate-price check
        self.recent_entry_prices = {}  # {symbol: [recent_prices]}
        self.max_price_history = 10  # keep the last N entry prices
        # Position reset tracking: True = closed via close_position,
        # False = a trade was executed and no close has been seen since
        self.position_reset_flags = {}  # {symbol: bool}
        # Price update tracking
        self.last_price_update = {}  # {symbol: datetime of last update}
        self.price_update_threshold = 5  # max age in seconds of a usable price
        # P&L verification
        self.trade_history = {}  # {symbol: [trade_records]}
        # Guards against wrapping the executor methods more than once
        self._fixes_applied = False
        logger.info("TradingExecutorFix initialized - addressing entry/exit price issues")

    def apply_fixes(self):
        """
        Apply all fixes to the trading executor.

        Idempotent: a second call is ignored, because wrapping the already
        wrapped methods again would duplicate trade records and checks.
        """
        if getattr(self, '_fixes_applied', False):
            logger.warning("Trading executor fixes already applied - skipping")
            return
        self._patch_execute_action()
        self._patch_close_position()
        self._patch_calculate_pnl()
        self._patch_update_prices()
        self._fixes_applied = True
        logger.info("All trading executor fixes applied successfully")

    # ------------------------------------------------------------------
    # execute_action
    # ------------------------------------------------------------------
    def _patch_execute_action(self):
        """Patch the execute_action method to add price validation and cooldown."""
        original_execute_action = self.trading_executor.execute_action

        def execute_action_with_fixes(decision):
            """
            Validate the decision, then delegate to the original execute_action.

            Returns False when a validation step rejects the trade, otherwise
            whatever the original method returns. Exceptions raised by the
            original method propagate unchanged; it is never retried.

            NOTE(review): the checks run for every action value, including
            exits and holds - confirm that is intended for this executor.
            """
            try:
                symbol = decision.symbol
                action = decision.action
                current_time = datetime.now()
                accepted, current_price = self._validate_trade(symbol, current_time)
            except Exception as e:
                # Validation itself failed before anything was executed, so
                # delegating once keeps the executor usable.
                logger.error(f"Error in execute_action_with_fixes: {e}")
                return original_execute_action(decision)

            if not accepted:
                return False

            # Execute the original action exactly once
            result = original_execute_action(decision)

            if result:
                try:
                    self._record_executed_trade(symbol, action, current_price, current_time)
                except Exception as e:
                    # The trade already went through: only log, never re-execute
                    logger.error(f"Error in execute_action_with_fixes: {e}")
            return result

        # Replace the original method
        self.trading_executor.execute_action = execute_action_with_fixes
        logger.info("Patched execute_action with price validation and cooldown")

    @staticmethod
    def _is_valid_price(price):
        """Return True when ``price`` is a usable (strictly positive) number."""
        try:
            return price is not None and price > 0
        except TypeError:
            return False

    def _validate_trade(self, symbol, current_time):
        """
        Run the pre-trade checks for ``symbol``.

        Args:
            symbol: Symbol the decision refers to.
            current_time: datetime used as "now" for all age calculations.

        Returns:
            ``(accepted, current_price)``. ``current_price`` is None when the
            trade was rejected before a price was looked up, and may be the
            0.0 "unavailable" value returned by ``_get_current_price``.
        """
        # 1. Check cooldown period
        last_trade = self.last_trade_time.get(symbol)
        if last_trade is not None:
            time_since_last_trade = (current_time - last_trade).total_seconds()
            if time_since_last_trade < self.min_trade_cooldown:
                logger.warning(f"Trade rejected: Cooldown period ({time_since_last_trade:.1f}s < {self.min_trade_cooldown}s) for {symbol}")
                return False, None

        # 2. Validate price freshness (only for symbols seen by update_prices)
        last_update = self.last_price_update.get(symbol)
        if last_update is not None:
            time_since_update = (current_time - last_update).total_seconds()
            if time_since_update > self.price_update_threshold:
                logger.warning(f"Trade rejected: Price data stale ({time_since_update:.1f}s > {self.price_update_threshold}s) for {symbol}")
                # Force price refresh so the next attempt sees fresh data
                self._refresh_price(symbol)
                return False, None

        # 3. Validate entry price against recent history. An unavailable
        #    price (0.0) is not compared: it is not a real entry price.
        current_price = self._get_current_price(symbol)
        if self._is_valid_price(current_price):
            if current_price in self.recent_entry_prices.get(symbol, ()):
                logger.warning(f"Trade rejected: Duplicate entry price ${current_price} for {symbol}")
                return False, current_price

        # 4. Ensure position is properly reset before new entry
        if not self._ensure_position_reset(symbol):
            logger.warning(f"Trade rejected: Position not properly reset for {symbol}")
            return False, current_price

        return True, current_price

    def _record_executed_trade(self, symbol, action, current_price, current_time):
        """Update cooldown, entry-price history and reset flag after a successful trade."""
        # Update cooldown timestamp
        self.last_trade_time[symbol] = current_time
        # Update price history; the 0.0 "unavailable" value is never stored,
        # otherwise later trades would be rejected as duplicates of it
        if self._is_valid_price(current_price):
            history = self.recent_entry_prices.setdefault(symbol, [])
            history.append(current_price)
            # Keep only the most recent prices
            excess = len(history) - self.max_price_history
            if excess > 0:
                del history[:excess]
        # Mark position as active
        self.position_reset_flags[symbol] = False
        logger.info(f"Trade executed: {action} {symbol} at ${current_price} with validation")

    # ------------------------------------------------------------------
    # close_position
    # ------------------------------------------------------------------
    def _patch_close_position(self):
        """Patch the close_position method to ensure proper position reset."""
        original_close_position = self.trading_executor.close_position

        def close_position_with_fixes(symbol, **kwargs):
            """
            Close ``symbol`` through the original method and record the trade.

            The position is captured *before* the original method runs,
            because the executor may drop or mutate it while closing.
            Returns whatever the original method returns; exceptions it
            raises propagate unchanged and it is never retried.
            """
            snapshot = None
            try:
                # Get current price for P&L verification
                exit_price = self._get_current_price(symbol)
                snapshot = self._snapshot_position(symbol, exit_price)
            except Exception as e:
                logger.error(f"Error in close_position_with_fixes: {e}")

            # Call original close position exactly once
            result = original_close_position(symbol, **kwargs)

            if result:
                try:
                    # Mark position as reset
                    self.position_reset_flags[symbol] = True
                    if snapshot is not None:
                        self._store_trade_record(symbol, snapshot)
                except Exception as e:
                    # The close already happened: only log, never re-close
                    logger.error(f"Error in close_position_with_fixes: {e}")
            return result

        # Replace the original method
        self.trading_executor.close_position = close_position_with_fixes
        logger.info("Patched close_position with proper reset logic")

    def _snapshot_position(self, symbol, exit_price):
        """
        Capture the fields of the open position for ``symbol``.

        Returns:
            A dict with the position fields and the verified P&L at
            ``exit_price``, or None when the executor has no ``positions``
            attribute or no entry for ``symbol``.
        """
        if not hasattr(self.trading_executor, 'positions'):
            return None
        if symbol not in self.trading_executor.positions:
            return None
        position = self.trading_executor.positions[symbol]
        # An unavailable exit price must not be used as a real price
        price_for_pnl = exit_price if self._is_valid_price(exit_price) else None
        return {
            'entry_time': getattr(position, 'entry_time', datetime.now()),
            'entry_price': getattr(position, 'entry_price', 0),
            'exit_price': exit_price,
            'size': getattr(position, 'size', 0),
            'side': getattr(position, 'side', 'UNKNOWN'),
            'pnl': self._calculate_verified_pnl(position, price_for_pnl),
            'fees': getattr(position, 'fees', 0),
        }

    def _store_trade_record(self, symbol, snapshot):
        """Turn a pre-close snapshot into a trade record and append it to the history."""
        exit_time = datetime.now()
        try:
            hold_time_seconds = (exit_time - snapshot['entry_time']).total_seconds()
        except TypeError:
            # entry_time is not comparable with a naive datetime
            hold_time_seconds = 0.0
        # Create trade record
        trade_record = {
            'symbol': symbol,
            'entry_time': snapshot['entry_time'],
            'exit_time': exit_time,
            'entry_price': snapshot['entry_price'],
            'exit_price': snapshot['exit_price'],
            'size': snapshot['size'],
            'side': snapshot['side'],
            'pnl': snapshot['pnl'],
            'fees': snapshot['fees'],
            'hold_time_seconds': hold_time_seconds,
        }
        # Store trade record
        self.trade_history.setdefault(symbol, []).append(trade_record)
        logger.info(f"Position closed: {symbol} at ${trade_record['exit_price']} with verified P&L: ${trade_record['pnl']:.2f}")

    # ------------------------------------------------------------------
    # calculate_pnl
    # ------------------------------------------------------------------
    def _patch_calculate_pnl(self):
        """Patch (or add) the calculate_pnl method to cross-check the P&L."""
        original_calculate_pnl = getattr(self.trading_executor, 'calculate_pnl', None)

        def calculate_pnl_with_fixes(position, current_price=None):
            """
            Return the P&L of ``position``.

            Without an original method the verified P&L is returned. With
            one, its result is returned unless an independent calculation
            was possible and differs by more than 0.01, in which case the
            verified value is returned and the discrepancy is logged.
            """
            # If no original method, implement our own
            if original_calculate_pnl is None:
                return self._calculate_verified_pnl(position, current_price)

            # Call original method exactly once
            original_pnl = original_calculate_pnl(position, current_price)
            try:
                verified_pnl = self._try_verified_pnl(position, current_price)
                # Only override when the verification could really be done;
                # "cannot verify" must not replace a valid original value
                if verified_pnl is not None and abs(original_pnl - verified_pnl) > 0.01:
                    logger.warning(f"P&L calculation discrepancy: original=${original_pnl:.2f}, verified=${verified_pnl:.2f}")
                    return verified_pnl
            except Exception as e:
                logger.error(f"Error in calculate_pnl_with_fixes: {e}")
            return original_pnl

        self.trading_executor.calculate_pnl = calculate_pnl_with_fixes
        if original_calculate_pnl:
            logger.info("Patched calculate_pnl with verification")
        else:
            logger.info("Added calculate_pnl method with verification")

    # ------------------------------------------------------------------
    # update_prices
    # ------------------------------------------------------------------
    def _patch_update_prices(self):
        """Patch (or add) the update_prices method to track price update times."""
        original_update_prices = getattr(self.trading_executor, 'update_prices', None)

        def update_prices_with_tracking(prices):
            """
            Update prices and remember when each symbol was last updated.

            Args:
                prices: Mapping of symbol to price.

            Returns:
                The original method's result when one exists; otherwise True
                on success and False when the update failed.
            """
            if original_update_prices:
                # Call original method exactly once
                result = original_update_prices(prices)
                try:
                    self._track_price_updates(prices)
                except Exception as e:
                    logger.error(f"Error in update_prices_with_tracking: {e}")
                return result

            try:
                # If no original method, update prices directly
                if hasattr(self.trading_executor, 'current_prices'):
                    self.trading_executor.current_prices.update(prices)
                self._track_price_updates(prices)
                return True
            except Exception as e:
                logger.error(f"Error in update_prices_with_tracking: {e}")
                return False

        self.trading_executor.update_prices = update_prices_with_tracking
        if original_update_prices:
            logger.info("Patched update_prices with timestamp tracking")
        else:
            logger.info("Added update_prices method with timestamp tracking")

    def _track_price_updates(self, prices):
        """Stamp every symbol in ``prices`` with the current time."""
        current_time = datetime.now()
        for symbol in prices:
            self.last_price_update[symbol] = current_time

    # ------------------------------------------------------------------
    # helpers
    # ------------------------------------------------------------------
    def _try_verified_pnl(self, position, current_price=None):
        """
        Compute the net P&L of ``position``, or None when it cannot be verified.

        The position is read through ``entry_price``, ``size``, ``side``,
        ``leverage`` (default 1.0), ``fees`` (default 0.0) and ``symbol``.

        Returns:
            ``(price difference) * size * leverage - fees`` for a LONG or
            SHORT side (case-insensitive), or None when the side is anything
            else or no positive current price is available.
        """
        # Get position details
        entry_price = getattr(position, 'entry_price', 0)
        size = getattr(position, 'size', 0)
        side = getattr(position, 'side', 'UNKNOWN')
        leverage = getattr(position, 'leverage', 1.0)
        fees = getattr(position, 'fees', 0.0)

        # If current_price is not provided, try to get it
        if current_price is None:
            symbol = getattr(position, 'symbol', None)
            if not symbol:
                return None
            current_price = self._get_current_price(symbol)
            # 0.0 is the "unavailable" value of _get_current_price
            if not self._is_valid_price(current_price):
                return None

        # Calculate P&L based on position side
        direction = str(side).upper()
        if direction == 'LONG':
            pnl = (current_price - entry_price) * size * leverage
        elif direction == 'SHORT':
            pnl = (entry_price - current_price) * size * leverage
        else:
            return None

        # Subtract fees for net P&L
        return pnl - fees

    def _calculate_verified_pnl(self, position, current_price=None):
        """
        Calculate verified P&L for a position.

        Args:
            position: Position-like object (see ``_try_verified_pnl``).
            current_price: Price to value the position at; looked up from
                the position's ``symbol`` when None.

        Returns:
            The net P&L, or 0.0 when it cannot be verified or an error occurs.
        """
        try:
            net_pnl = self._try_verified_pnl(position, current_price)
            return 0.0 if net_pnl is None else net_pnl
        except Exception as e:
            logger.error(f"Error calculating verified P&L: {e}")
            return 0.0

    def _get_current_price(self, symbol):
        """
        Get current price for a symbol with fallbacks.

        Sources are tried in order: the executor's ``current_prices``, its
        ``data_provider.get_current_price`` and the ``mid_price`` stat of its
        ``latest_cob_data``. A source is used only when it yields a positive
        price.

        Returns:
            The first usable price, or 0.0 when none is available or an
            error occurs.
        """
        try:
            executor = self.trading_executor
            # Try to get from trading executor
            if hasattr(executor, 'current_prices') and symbol in executor.current_prices:
                price = executor.current_prices[symbol]
                if self._is_valid_price(price):
                    return price
            # Try to get from data provider
            if hasattr(executor, 'data_provider'):
                data_provider = executor.data_provider
                if hasattr(data_provider, 'get_current_price'):
                    price = data_provider.get_current_price(symbol)
                    if self._is_valid_price(price):
                        return price
            # Try to get from COB data
            if hasattr(executor, 'latest_cob_data') and symbol in executor.latest_cob_data:
                cob_data = executor.latest_cob_data[symbol]
                if hasattr(cob_data, 'stats') and 'mid_price' in cob_data.stats:
                    price = cob_data.stats['mid_price']
                    if self._is_valid_price(price):
                        return price
            # Default fallback
            return 0.0
        except Exception as e:
            logger.error(f"Error getting current price for {symbol}: {e}")
            return 0.0

    def _refresh_price(self, symbol):
        """
        Force a price refresh for a symbol.

        Uses ``data_provider.fetch_current_price`` when the executor exposes
        it; a positive result is written to ``current_prices`` (when present)
        and the update timestamp is reset.

        Returns:
            True when a positive price was fetched, False otherwise.
        """
        try:
            # Try to refresh from data provider
            if hasattr(self.trading_executor, 'data_provider'):
                data_provider = self.trading_executor.data_provider
                if hasattr(data_provider, 'fetch_current_price'):
                    price = data_provider.fetch_current_price(symbol)
                    if price and price > 0:
                        # Update trading executor price
                        if hasattr(self.trading_executor, 'current_prices'):
                            self.trading_executor.current_prices[symbol] = price
                        # Update timestamp
                        self.last_price_update[symbol] = datetime.now()
                        logger.info(f"Refreshed price for {symbol}: ${price:.2f}")
                        return True
            logger.warning(f"Failed to refresh price for {symbol}")
            return False
        except Exception as e:
            logger.error(f"Error refreshing price for {symbol}: {e}")
            return False

    def _ensure_position_reset(self, symbol):
        """
        Ensure position is properly reset before new entry.

        Returns:
            False when the executor holds a position for ``symbol`` whose
            ``active`` attribute is truthy, or when an error occurs; True
            otherwise.

        NOTE(review): when the reset flag is False (a trade was executed and
        no close_position call followed), the executor's entry in
        ``positions`` is removed without being closed. For position objects
        that have no ``active`` attribute this drops a live position from the
        executor's bookkeeping - confirm this is the intended behavior.
        """
        try:
            # Check if we have an active position
            if hasattr(self.trading_executor, 'positions') and symbol in self.trading_executor.positions:
                # Position exists, check if it's valid
                position = self.trading_executor.positions[symbol]
                if position and getattr(position, 'active', False):
                    logger.warning(f"Position already active for {symbol}, cannot enter new position")
                    return False
            # Check reset flag
            if symbol in self.position_reset_flags and not self.position_reset_flags[symbol]:
                # Force position cleanup
                if hasattr(self.trading_executor, 'positions'):
                    self.trading_executor.positions.pop(symbol, None)
                logger.info(f"Forced position reset for {symbol}")
                self.position_reset_flags[symbol] = True
            return True
        except Exception as e:
            logger.error(f"Error ensuring position reset for {symbol}: {e}")
            return False

    def get_trade_history(self, symbol=None):
        """
        Get verified trade history.

        Args:
            symbol: When given (truthy), only that symbol's records are returned.

        Returns:
            The list of trade records for ``symbol`` (empty when unknown), or
            the whole ``{symbol: [records]}`` mapping when no symbol is given.
        """
        if symbol:
            return self.trade_history.get(symbol, [])
        return self.trade_history

    def get_price_update_status(self):
        """
        Get price update status for all symbols.

        Returns:
            ``{symbol: {'last_update', 'seconds_ago', 'is_fresh'}}`` where
            ``is_fresh`` means the last update is at most
            ``price_update_threshold`` seconds old.
        """
        status = {}
        current_time = datetime.now()
        for symbol, timestamp in self.last_price_update.items():
            time_since_update = (current_time - timestamp).total_seconds()
            status[symbol] = {
                'last_update': timestamp,
                'seconds_ago': time_since_update,
                'is_fresh': time_since_update <= self.price_update_threshold,
            }
        return status