integratoin fixes - COB and CNN

This commit is contained in:
Dobromir Popov
2025-07-23 17:33:43 +03:00
parent f1d63f9da6
commit 2a0f8f5199
8 changed files with 883 additions and 230 deletions

View File

@ -1,261 +1,401 @@
"""
Trading Executor Fix
Trading Executor Fix - Addresses issues with entry/exit prices and P&L calculations
This module provides fixes for the trading executor to address:
1. Duplicate entry prices
2. P&L calculation issues
3. Position tracking problems
This module provides fixes for:
1. Identical entry prices issue
2. Price caching problems
3. Position tracking reset logic
4. Trade cooldown implementation
5. P&L calculation verification
Apply these fixes by importing and applying the patch in main.py
Apply these fixes to the TradingExecutor class to improve trade execution reliability.
"""
import logging
import time
from datetime import datetime
from typing import Dict, Any, Optional
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Union
logger = logging.getLogger(__name__)
class TradingExecutorFix:
"""Fixes for the TradingExecutor class"""
"""
Fixes for the TradingExecutor class to address entry/exit price issues
and improve P&L calculation accuracy.
"""
@staticmethod
def apply_fixes(trading_executor):
def __init__(self, trading_executor):
"""
Initialize the fix with a reference to the trading executor
Args:
trading_executor: The TradingExecutor instance to fix
"""
self.trading_executor = trading_executor
# Add cooldown tracking
self.last_trade_time = {} # {symbol: timestamp}
self.min_trade_cooldown = 30 # 30 seconds minimum between trades
# Add price history for validation
self.recent_entry_prices = {} # {symbol: [recent_prices]}
self.max_price_history = 10 # Keep last 10 entry prices
# Add position reset tracking
self.position_reset_flags = {} # {symbol: bool}
# Add price update tracking
self.last_price_update = {} # {symbol: timestamp}
self.price_update_threshold = 5 # 5 seconds max since last price update
# Add P&L verification
self.trade_history = {} # {symbol: [trade_records]}
logger.info("TradingExecutorFix initialized - addressing entry/exit price issues")
def apply_fixes(self):
"""Apply all fixes to the trading executor"""
logger.info("Applying TradingExecutor fixes...")
self._patch_execute_action()
self._patch_close_position()
self._patch_calculate_pnl()
self._patch_update_prices()
# Store original methods for patching
original_execute_action = trading_executor.execute_action
original_calculate_pnl = getattr(trading_executor, '_calculate_pnl', None)
# Apply fixes
TradingExecutorFix._fix_price_caching(trading_executor)
TradingExecutorFix._fix_pnl_calculation(trading_executor, original_calculate_pnl)
TradingExecutorFix._fix_execute_action(trading_executor, original_execute_action)
TradingExecutorFix._add_trade_cooldown(trading_executor)
TradingExecutorFix._fix_position_tracking(trading_executor)
logger.info("TradingExecutor fixes applied successfully")
return trading_executor
logger.info("All trading executor fixes applied successfully")
@staticmethod
def _fix_price_caching(trading_executor):
"""Fix price caching to prevent duplicate entry prices"""
# Add a price cache timestamp to track when prices were last updated
trading_executor.price_cache_timestamp = {}
def _patch_execute_action(self):
"""Patch the execute_action method to add price validation and cooldown"""
original_execute_action = self.trading_executor.execute_action
# Store original get_current_price method
original_get_current_price = trading_executor.get_current_price
def get_current_price_fixed(self, symbol):
"""Fixed get_current_price method with cache invalidation"""
now = time.time()
# Force price refresh if cache is older than 5 seconds
if symbol in self.price_cache_timestamp:
cache_age = now - self.price_cache_timestamp.get(symbol, 0)
if cache_age > 5: # 5 seconds max cache age
# Clear price cache for this symbol
if hasattr(self, 'current_prices') and symbol in self.current_prices:
del self.current_prices[symbol]
logger.debug(f"Price cache for {symbol} invalidated (age: {cache_age:.1f}s)")
# Call original method to get fresh price
price = original_get_current_price(symbol)
# Update cache timestamp
self.price_cache_timestamp[symbol] = now
return price
# Apply the patch
trading_executor.get_current_price = get_current_price_fixed.__get__(trading_executor)
logger.info("Price caching fix applied")
@staticmethod
def _fix_pnl_calculation(trading_executor, original_calculate_pnl):
"""Fix P&L calculation to ensure accuracy"""
def calculate_pnl_fixed(self, position, current_price=None):
"""Fixed P&L calculation with proper handling of position side"""
try:
# Get position details
entry_price = position.entry_price
size = position.size
side = position.side
# Use provided price or get current price
if current_price is None:
current_price = self.get_current_price(position.symbol)
# Calculate P&L based on position side
if side == 'LONG':
pnl = (current_price - entry_price) * size
else: # SHORT
pnl = (entry_price - current_price) * size
# Calculate fees (if available)
fees = getattr(position, 'fees', 0.0)
# Return both gross and net P&L
return {
'gross_pnl': pnl,
'fees': fees,
'net_pnl': pnl - fees
}
except Exception as e:
logger.error(f"Error calculating P&L: {e}")
return {'gross_pnl': 0.0, 'fees': 0.0, 'net_pnl': 0.0}
# Apply the patch if original method exists
if original_calculate_pnl:
trading_executor._calculate_pnl = calculate_pnl_fixed.__get__(trading_executor)
logger.info("P&L calculation fix applied")
else:
# Add the method if it doesn't exist
trading_executor._calculate_pnl = calculate_pnl_fixed.__get__(trading_executor)
logger.info("P&L calculation method added")
@staticmethod
def _fix_execute_action(trading_executor, original_execute_action):
"""Fix execute_action to prevent duplicate entries and ensure proper price updates"""
def execute_action_fixed(self, decision):
"""Fixed execute_action with duplicate entry prevention"""
def execute_action_with_fixes(decision):
"""Enhanced execute_action with price validation and cooldown"""
try:
symbol = decision.symbol
action = decision.action
current_time = datetime.now()
# Check for duplicate entry (same price as recent entry)
if hasattr(self, 'recent_entries') and symbol in self.recent_entries:
recent_entry = self.recent_entries[symbol]
current_price = self.get_current_price(symbol)
# If price is within 0.1% of recent entry, consider it a duplicate
price_diff_pct = abs(current_price - recent_entry['price']) / recent_entry['price'] * 100
time_diff = time.time() - recent_entry['timestamp']
if price_diff_pct < 0.1 and time_diff < 60: # Within 0.1% and 60 seconds
logger.warning(f"Preventing duplicate entry for {symbol} at ${current_price:.2f} "
f"(recent entry: ${recent_entry['price']:.2f}, {time_diff:.1f}s ago)")
# Mark decision as blocked
decision.blocked = True
decision.blocked_reason = "Duplicate entry prevention"
# 1. Check cooldown period
if symbol in self.last_trade_time:
time_since_last_trade = (current_time - self.last_trade_time[symbol]).total_seconds()
if time_since_last_trade < self.min_trade_cooldown:
logger.warning(f"Trade rejected: Cooldown period ({time_since_last_trade:.1f}s < {self.min_trade_cooldown}s) for {symbol}")
return False
# Check trade cooldown
if hasattr(self, '_check_trade_cooldown'):
if not self._check_trade_cooldown(symbol, action):
# Mark decision as blocked
decision.blocked = True
decision.blocked_reason = "Trade cooldown active"
# 2. Validate price freshness
if symbol in self.last_price_update:
time_since_update = (current_time - self.last_price_update[symbol]).total_seconds()
if time_since_update > self.price_update_threshold:
logger.warning(f"Trade rejected: Price data stale ({time_since_update:.1f}s > {self.price_update_threshold}s) for {symbol}")
# Force price refresh
self._refresh_price(symbol)
return False
# Force price refresh before execution
fresh_price = self.get_current_price(symbol)
logger.info(f"Using fresh price for {symbol}: ${fresh_price:.2f}")
# 3. Validate entry price against recent history
current_price = self._get_current_price(symbol)
if symbol in self.recent_entry_prices and len(self.recent_entry_prices[symbol]) > 0:
# Check if price is identical to any recent entry
if current_price in self.recent_entry_prices[symbol]:
logger.warning(f"Trade rejected: Duplicate entry price ${current_price} for {symbol}")
return False
# Update decision price with fresh price
decision.price = fresh_price
# 4. Ensure position is properly reset before new entry
if not self._ensure_position_reset(symbol):
logger.warning(f"Trade rejected: Position not properly reset for {symbol}")
return False
# Call original execute_action
# Execute the original action
result = original_execute_action(decision)
# If execution was successful, record the entry
if result and not getattr(decision, 'blocked', False):
if not hasattr(self, 'recent_entries'):
self.recent_entries = {}
# If successful, update tracking
if result:
# Update cooldown timestamp
self.last_trade_time[symbol] = current_time
self.recent_entries[symbol] = {
'price': fresh_price,
'timestamp': time.time(),
'action': action
}
# Update price history
if symbol not in self.recent_entry_prices:
self.recent_entry_prices[symbol] = []
# Record last trade time for cooldown
if not hasattr(self, 'last_trade_time'):
self.last_trade_time = {}
self.recent_entry_prices[symbol].append(current_price)
# Keep only the most recent prices
if len(self.recent_entry_prices[symbol]) > self.max_price_history:
self.recent_entry_prices[symbol] = self.recent_entry_prices[symbol][-self.max_price_history:]
self.last_trade_time[symbol] = time.time()
# Mark position as active
self.position_reset_flags[symbol] = False
logger.info(f"Trade executed: {action} {symbol} at ${current_price} with validation")
return result
except Exception as e:
logger.error(f"Error in execute_action_fixed: {e}")
logger.error(f"Error in execute_action_with_fixes: {e}")
return original_execute_action(decision)
# Replace the original method
self.trading_executor.execute_action = execute_action_with_fixes
logger.info("Patched execute_action with price validation and cooldown")
def _patch_close_position(self):
"""Patch the close_position method to ensure proper position reset"""
original_close_position = self.trading_executor.close_position
def close_position_with_fixes(symbol, **kwargs):
"""Enhanced close_position with proper reset logic"""
try:
# Get current price for P&L verification
exit_price = self._get_current_price(symbol)
# Call original close position
result = original_close_position(symbol, **kwargs)
if result:
# Mark position as reset
self.position_reset_flags[symbol] = True
# Record trade for verification
if hasattr(self.trading_executor, 'positions') and symbol in self.trading_executor.positions:
position = self.trading_executor.positions[symbol]
# Create trade record
trade_record = {
'symbol': symbol,
'entry_time': getattr(position, 'entry_time', datetime.now()),
'exit_time': datetime.now(),
'entry_price': getattr(position, 'entry_price', 0),
'exit_price': exit_price,
'size': getattr(position, 'size', 0),
'side': getattr(position, 'side', 'UNKNOWN'),
'pnl': self._calculate_verified_pnl(position, exit_price),
'fees': getattr(position, 'fees', 0),
'hold_time_seconds': (datetime.now() - getattr(position, 'entry_time', datetime.now())).total_seconds()
}
# Store trade record
if symbol not in self.trade_history:
self.trade_history[symbol] = []
self.trade_history[symbol].append(trade_record)
logger.info(f"Position closed: {symbol} at ${exit_price} with verified P&L: ${trade_record['pnl']:.2f}")
return result
except Exception as e:
logger.error(f"Error in close_position_with_fixes: {e}")
return original_close_position(symbol, **kwargs)
# Replace the original method
self.trading_executor.close_position = close_position_with_fixes
logger.info("Patched close_position with proper reset logic")
def _patch_calculate_pnl(self):
"""Patch the calculate_pnl method to ensure accurate P&L calculation"""
original_calculate_pnl = getattr(self.trading_executor, 'calculate_pnl', None)
def calculate_pnl_with_fixes(position, current_price=None):
"""Enhanced calculate_pnl with verification"""
try:
# If no original method, implement our own
if original_calculate_pnl is None:
return self._calculate_verified_pnl(position, current_price)
# Call original method
original_pnl = original_calculate_pnl(position, current_price)
# Calculate our verified P&L
verified_pnl = self._calculate_verified_pnl(position, current_price)
# If there's a significant difference, log it
if abs(original_pnl - verified_pnl) > 0.01:
logger.warning(f"P&L calculation discrepancy: original=${original_pnl:.2f}, verified=${verified_pnl:.2f}")
# Use the verified P&L
return verified_pnl
return original_pnl
except Exception as e:
logger.error(f"Error in calculate_pnl_with_fixes: {e}")
if original_calculate_pnl:
return original_calculate_pnl(position, current_price)
return 0.0
# Replace the original method if it exists
if original_calculate_pnl:
self.trading_executor.calculate_pnl = calculate_pnl_with_fixes
logger.info("Patched calculate_pnl with verification")
else:
# Add the method if it doesn't exist
self.trading_executor.calculate_pnl = calculate_pnl_with_fixes
logger.info("Added calculate_pnl method with verification")
def _patch_update_prices(self):
"""Patch the update_prices method to track price updates"""
original_update_prices = getattr(self.trading_executor, 'update_prices', None)
def update_prices_with_tracking(prices):
"""Enhanced update_prices with timestamp tracking"""
try:
# Call original method if it exists
if original_update_prices:
result = original_update_prices(prices)
else:
# If no original method, update prices directly
if hasattr(self.trading_executor, 'current_prices'):
self.trading_executor.current_prices.update(prices)
result = True
# Track update timestamps
current_time = datetime.now()
for symbol in prices:
self.last_price_update[symbol] = current_time
return result
except Exception as e:
logger.error(f"Error in update_prices_with_tracking: {e}")
if original_update_prices:
return original_update_prices(prices)
return False
# Apply the patch
trading_executor.execute_action = execute_action_fixed.__get__(trading_executor)
# Initialize recent entries dict if it doesn't exist
if not hasattr(trading_executor, 'recent_entries'):
trading_executor.recent_entries = {}
logger.info("Execute action fix applied")
# Replace the original method if it exists
if original_update_prices:
self.trading_executor.update_prices = update_prices_with_tracking
logger.info("Patched update_prices with timestamp tracking")
else:
# Add the method if it doesn't exist
self.trading_executor.update_prices = update_prices_with_tracking
logger.info("Added update_prices method with timestamp tracking")
@staticmethod
def _add_trade_cooldown(trading_executor):
"""Add trade cooldown to prevent rapid consecutive trades"""
# Add cooldown settings
trading_executor.trade_cooldown_seconds = 30 # 30 seconds between trades
if not hasattr(trading_executor, 'last_trade_time'):
trading_executor.last_trade_time = {}
def check_trade_cooldown(self, symbol, action):
"""Check if trade cooldown is active for a symbol"""
if not hasattr(self, 'last_trade_time'):
self.last_trade_time = {}
return True
def _calculate_verified_pnl(self, position, current_price=None):
"""Calculate verified P&L for a position"""
try:
# Get position details
entry_price = getattr(position, 'entry_price', 0)
size = getattr(position, 'size', 0)
side = getattr(position, 'side', 'UNKNOWN')
leverage = getattr(position, 'leverage', 1.0)
fees = getattr(position, 'fees', 0.0)
if symbol not in self.last_trade_time:
return True
# If current_price is not provided, try to get it
if current_price is None:
symbol = getattr(position, 'symbol', None)
if symbol:
current_price = self._get_current_price(symbol)
else:
return 0.0
# Get time since last trade
time_since_last = time.time() - self.last_trade_time[symbol]
# Calculate P&L based on position side
if side == 'LONG':
pnl = (current_price - entry_price) * size * leverage
elif side == 'SHORT':
pnl = (entry_price - current_price) * size * leverage
else:
pnl = 0.0
# Check if cooldown is still active
if time_since_last < self.trade_cooldown_seconds:
logger.warning(f"Trade cooldown active for {symbol}: {time_since_last:.1f}s elapsed, "
f"need {self.trade_cooldown_seconds}s")
return False
# Subtract fees for net P&L
net_pnl = pnl - fees
return True
# Add the method
trading_executor._check_trade_cooldown = check_trade_cooldown.__get__(trading_executor)
logger.info("Trade cooldown feature added")
return net_pnl
except Exception as e:
logger.error(f"Error calculating verified P&L: {e}")
return 0.0
@staticmethod
def _fix_position_tracking(trading_executor):
"""Fix position tracking to ensure proper reset between trades"""
# Store original close_position method
original_close_position = getattr(trading_executor, 'close_position', None)
if original_close_position:
def close_position_fixed(self, symbol, price=None):
"""Fixed close_position with proper position cleanup"""
try:
# Call original close_position
result = original_close_position(symbol, price)
# Ensure position is fully cleaned up
if symbol in self.positions:
del self.positions[symbol]
# Clear recent entry for this symbol
if hasattr(self, 'recent_entries') and symbol in self.recent_entries:
del self.recent_entries[symbol]
logger.info(f"Position for {symbol} fully cleaned up after close")
return result
except Exception as e:
logger.error(f"Error in close_position_fixed: {e}")
def _get_current_price(self, symbol):
"""Get current price for a symbol with fallbacks"""
try:
# Try to get from trading executor
if hasattr(self.trading_executor, 'current_prices') and symbol in self.trading_executor.current_prices:
return self.trading_executor.current_prices[symbol]
# Try to get from data provider
if hasattr(self.trading_executor, 'data_provider'):
data_provider = self.trading_executor.data_provider
if hasattr(data_provider, 'get_current_price'):
price = data_provider.get_current_price(symbol)
if price and price > 0:
return price
# Try to get from COB data
if hasattr(self.trading_executor, 'latest_cob_data') and symbol in self.trading_executor.latest_cob_data:
cob_data = self.trading_executor.latest_cob_data[symbol]
if hasattr(cob_data, 'stats') and 'mid_price' in cob_data.stats:
return cob_data.stats['mid_price']
# Default fallback
return 0.0
except Exception as e:
logger.error(f"Error getting current price for {symbol}: {e}")
return 0.0
def _refresh_price(self, symbol):
"""Force a price refresh for a symbol"""
try:
# Try to refresh from data provider
if hasattr(self.trading_executor, 'data_provider'):
data_provider = self.trading_executor.data_provider
if hasattr(data_provider, 'fetch_current_price'):
price = data_provider.fetch_current_price(symbol)
if price and price > 0:
# Update trading executor price
if hasattr(self.trading_executor, 'current_prices'):
self.trading_executor.current_prices[symbol] = price
# Update timestamp
self.last_price_update[symbol] = datetime.now()
logger.info(f"Refreshed price for {symbol}: ${price:.2f}")
return True
logger.warning(f"Failed to refresh price for {symbol}")
return False
except Exception as e:
logger.error(f"Error refreshing price for {symbol}: {e}")
return False
def _ensure_position_reset(self, symbol):
"""Ensure position is properly reset before new entry"""
try:
# Check if we have an active position
if hasattr(self.trading_executor, 'positions') and symbol in self.trading_executor.positions:
# Position exists, check if it's valid
position = self.trading_executor.positions[symbol]
if position and getattr(position, 'active', False):
logger.warning(f"Position already active for {symbol}, cannot enter new position")
return False
# Apply the patch
trading_executor.close_position = close_position_fixed.__get__(trading_executor)
logger.info("Position tracking fix applied")
else:
logger.warning("close_position method not found, skipping position tracking fix")
# Check reset flag
if symbol in self.position_reset_flags and not self.position_reset_flags[symbol]:
# Force position cleanup
if hasattr(self.trading_executor, 'positions'):
self.trading_executor.positions.pop(symbol, None)
logger.info(f"Forced position reset for {symbol}")
self.position_reset_flags[symbol] = True
return True
except Exception as e:
logger.error(f"Error ensuring position reset for {symbol}: {e}")
return False
def get_trade_history(self, symbol=None):
"""Get verified trade history"""
if symbol:
return self.trade_history.get(symbol, [])
return self.trade_history
def get_price_update_status(self):
"""Get price update status for all symbols"""
status = {}
current_time = datetime.now()
for symbol, timestamp in self.last_price_update.items():
time_since_update = (current_time - timestamp).total_seconds()
status[symbol] = {
'last_update': timestamp,
'seconds_ago': time_since_update,
'is_fresh': time_since_update <= self.price_update_threshold
}
return status