integratoin fixes - COB and CNN
This commit is contained in:
@ -1453,17 +1453,37 @@ class DataProvider:
|
||||
async def _on_enhanced_cob_data(self, symbol: str, cob_data: Dict):
|
||||
"""Handle COB data from Enhanced WebSocket"""
|
||||
try:
|
||||
# Ensure cob_websocket_data is initialized
|
||||
if not hasattr(self, 'cob_websocket_data'):
|
||||
self.cob_websocket_data = {}
|
||||
|
||||
# Store the latest COB data
|
||||
self.cob_websocket_data[symbol] = cob_data
|
||||
|
||||
# Ensure cob_data_cache is initialized
|
||||
if not hasattr(self, 'cob_data_cache'):
|
||||
self.cob_data_cache = {}
|
||||
|
||||
# Update COB data cache for distribution
|
||||
binance_symbol = symbol.replace('/', '').upper()
|
||||
if binance_symbol in self.cob_data_cache:
|
||||
self.cob_data_cache[binance_symbol].append({
|
||||
'timestamp': cob_data.get('timestamp', datetime.now()),
|
||||
'data': cob_data,
|
||||
'source': 'enhanced_websocket'
|
||||
})
|
||||
if binance_symbol not in self.cob_data_cache or self.cob_data_cache[binance_symbol] is None:
|
||||
from collections import deque
|
||||
self.cob_data_cache[binance_symbol] = deque(maxlen=300)
|
||||
|
||||
# Ensure the deque is properly initialized
|
||||
if not isinstance(self.cob_data_cache[binance_symbol], deque):
|
||||
from collections import deque
|
||||
self.cob_data_cache[binance_symbol] = deque(maxlen=300)
|
||||
|
||||
self.cob_data_cache[binance_symbol].append({
|
||||
'timestamp': cob_data.get('timestamp', datetime.now()),
|
||||
'data': cob_data,
|
||||
'source': 'enhanced_websocket'
|
||||
})
|
||||
|
||||
# Ensure cob_data_callbacks is initialized
|
||||
if not hasattr(self, 'cob_data_callbacks'):
|
||||
self.cob_data_callbacks = []
|
||||
|
||||
# Distribute to COB data callbacks
|
||||
for callback in self.cob_data_callbacks:
|
||||
@ -1472,6 +1492,13 @@ class DataProvider:
|
||||
except Exception as e:
|
||||
logger.error(f"Error in COB data callback: {e}")
|
||||
|
||||
# Ensure distribution_stats is initialized
|
||||
if not hasattr(self, 'distribution_stats'):
|
||||
self.distribution_stats = {
|
||||
'total_ticks_received': 0,
|
||||
'last_tick_time': {}
|
||||
}
|
||||
|
||||
# Update distribution stats
|
||||
self.distribution_stats['total_ticks_received'] += 1
|
||||
self.distribution_stats['last_tick_time'][symbol] = datetime.now()
|
||||
@ -1479,7 +1506,7 @@ class DataProvider:
|
||||
logger.debug(f"Enhanced COB data received for {symbol}: {len(cob_data.get('bids', []))} bids, {len(cob_data.get('asks', []))} asks")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling enhanced COB data for {symbol}: {e}")
|
||||
logger.error(f"Error handling enhanced COB data for {symbol}: {e}", exc_info=True)
|
||||
|
||||
async def _on_websocket_status_update(self, status_data: Dict):
|
||||
"""Handle WebSocket status updates"""
|
||||
@ -1488,12 +1515,16 @@ class DataProvider:
|
||||
status = status_data.get('status')
|
||||
message = status_data.get('message', '')
|
||||
|
||||
# Ensure cob_websocket_status is initialized
|
||||
if not hasattr(self, 'cob_websocket_status'):
|
||||
self.cob_websocket_status = {}
|
||||
|
||||
if symbol:
|
||||
self.cob_websocket_status[symbol] = status
|
||||
logger.info(f"🔌 Enhanced WebSocket status for {symbol}: {status} - {message}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling WebSocket status update: {e}")
|
||||
logger.error(f"Error handling WebSocket status update: {e}", exc_info=True)
|
||||
|
||||
async def _start_fallback_websocket_streaming(self):
|
||||
"""Fallback to old WebSocket method if Enhanced COB WebSocket fails"""
|
||||
@ -3499,9 +3530,14 @@ class DataProvider:
|
||||
# Convert datetime to timestamp
|
||||
cob_data['timestamp'] = cob_data['timestamp'].timestamp()
|
||||
|
||||
# Store raw tick
|
||||
# Store raw tick - ensure proper initialization
|
||||
if not hasattr(self, 'cob_raw_ticks'):
|
||||
self.cob_raw_ticks = {'ETH/USDT': [], 'BTC/USDT': []}
|
||||
self.cob_raw_ticks = {}
|
||||
|
||||
# Ensure symbol keys exist in the dictionary
|
||||
for sym in ['ETH/USDT', 'BTC/USDT']:
|
||||
if sym not in self.cob_raw_ticks:
|
||||
self.cob_raw_ticks[sym] = []
|
||||
|
||||
# Add to raw ticks with size limit (keep last 10 seconds of data)
|
||||
max_ticks = 1000 # ~10 seconds at 100 updates/sec
|
||||
@ -3514,6 +3550,7 @@ class DataProvider:
|
||||
if not hasattr(self, 'cob_data_cache'):
|
||||
self.cob_data_cache = {}
|
||||
|
||||
# Ensure symbol key exists in the cache
|
||||
if symbol not in self.cob_data_cache:
|
||||
self.cob_data_cache[symbol] = []
|
||||
|
||||
@ -3537,7 +3574,7 @@ class DataProvider:
|
||||
logger.debug(f"Processed WebSocket COB tick for {symbol}: {len(cob_data.get('bids', []))} bids, {len(cob_data.get('asks', []))} asks")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing WebSocket COB data for {symbol}: {e}")
|
||||
logger.error(f"Error processing WebSocket COB data for {symbol}: {e}", exc_info=True)
|
||||
|
||||
def _on_cob_websocket_status(self, status_data: dict):
|
||||
"""Handle WebSocket status updates"""
|
||||
|
@ -1,261 +1,401 @@
|
||||
"""
|
||||
Trading Executor Fix
|
||||
Trading Executor Fix - Addresses issues with entry/exit prices and P&L calculations
|
||||
|
||||
This module provides fixes for the trading executor to address:
|
||||
1. Duplicate entry prices
|
||||
2. P&L calculation issues
|
||||
3. Position tracking problems
|
||||
This module provides fixes for:
|
||||
1. Identical entry prices issue
|
||||
2. Price caching problems
|
||||
3. Position tracking reset logic
|
||||
4. Trade cooldown implementation
|
||||
5. P&L calculation verification
|
||||
|
||||
Apply these fixes by importing and applying the patch in main.py
|
||||
Apply these fixes to the TradingExecutor class to improve trade execution reliability.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any, Union
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TradingExecutorFix:
|
||||
"""Fixes for the TradingExecutor class"""
|
||||
"""
|
||||
Fixes for the TradingExecutor class to address entry/exit price issues
|
||||
and improve P&L calculation accuracy.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def apply_fixes(trading_executor):
|
||||
def __init__(self, trading_executor):
|
||||
"""
|
||||
Initialize the fix with a reference to the trading executor
|
||||
|
||||
Args:
|
||||
trading_executor: The TradingExecutor instance to fix
|
||||
"""
|
||||
self.trading_executor = trading_executor
|
||||
|
||||
# Add cooldown tracking
|
||||
self.last_trade_time = {} # {symbol: timestamp}
|
||||
self.min_trade_cooldown = 30 # 30 seconds minimum between trades
|
||||
|
||||
# Add price history for validation
|
||||
self.recent_entry_prices = {} # {symbol: [recent_prices]}
|
||||
self.max_price_history = 10 # Keep last 10 entry prices
|
||||
|
||||
# Add position reset tracking
|
||||
self.position_reset_flags = {} # {symbol: bool}
|
||||
|
||||
# Add price update tracking
|
||||
self.last_price_update = {} # {symbol: timestamp}
|
||||
self.price_update_threshold = 5 # 5 seconds max since last price update
|
||||
|
||||
# Add P&L verification
|
||||
self.trade_history = {} # {symbol: [trade_records]}
|
||||
|
||||
logger.info("TradingExecutorFix initialized - addressing entry/exit price issues")
|
||||
|
||||
def apply_fixes(self):
|
||||
"""Apply all fixes to the trading executor"""
|
||||
logger.info("Applying TradingExecutor fixes...")
|
||||
self._patch_execute_action()
|
||||
self._patch_close_position()
|
||||
self._patch_calculate_pnl()
|
||||
self._patch_update_prices()
|
||||
|
||||
# Store original methods for patching
|
||||
original_execute_action = trading_executor.execute_action
|
||||
original_calculate_pnl = getattr(trading_executor, '_calculate_pnl', None)
|
||||
|
||||
# Apply fixes
|
||||
TradingExecutorFix._fix_price_caching(trading_executor)
|
||||
TradingExecutorFix._fix_pnl_calculation(trading_executor, original_calculate_pnl)
|
||||
TradingExecutorFix._fix_execute_action(trading_executor, original_execute_action)
|
||||
TradingExecutorFix._add_trade_cooldown(trading_executor)
|
||||
TradingExecutorFix._fix_position_tracking(trading_executor)
|
||||
|
||||
logger.info("TradingExecutor fixes applied successfully")
|
||||
return trading_executor
|
||||
logger.info("All trading executor fixes applied successfully")
|
||||
|
||||
@staticmethod
|
||||
def _fix_price_caching(trading_executor):
|
||||
"""Fix price caching to prevent duplicate entry prices"""
|
||||
# Add a price cache timestamp to track when prices were last updated
|
||||
trading_executor.price_cache_timestamp = {}
|
||||
def _patch_execute_action(self):
|
||||
"""Patch the execute_action method to add price validation and cooldown"""
|
||||
original_execute_action = self.trading_executor.execute_action
|
||||
|
||||
# Store original get_current_price method
|
||||
original_get_current_price = trading_executor.get_current_price
|
||||
|
||||
def get_current_price_fixed(self, symbol):
|
||||
"""Fixed get_current_price method with cache invalidation"""
|
||||
now = time.time()
|
||||
|
||||
# Force price refresh if cache is older than 5 seconds
|
||||
if symbol in self.price_cache_timestamp:
|
||||
cache_age = now - self.price_cache_timestamp.get(symbol, 0)
|
||||
if cache_age > 5: # 5 seconds max cache age
|
||||
# Clear price cache for this symbol
|
||||
if hasattr(self, 'current_prices') and symbol in self.current_prices:
|
||||
del self.current_prices[symbol]
|
||||
logger.debug(f"Price cache for {symbol} invalidated (age: {cache_age:.1f}s)")
|
||||
|
||||
# Call original method to get fresh price
|
||||
price = original_get_current_price(symbol)
|
||||
|
||||
# Update cache timestamp
|
||||
self.price_cache_timestamp[symbol] = now
|
||||
|
||||
return price
|
||||
|
||||
# Apply the patch
|
||||
trading_executor.get_current_price = get_current_price_fixed.__get__(trading_executor)
|
||||
logger.info("Price caching fix applied")
|
||||
|
||||
@staticmethod
|
||||
def _fix_pnl_calculation(trading_executor, original_calculate_pnl):
|
||||
"""Fix P&L calculation to ensure accuracy"""
|
||||
def calculate_pnl_fixed(self, position, current_price=None):
|
||||
"""Fixed P&L calculation with proper handling of position side"""
|
||||
try:
|
||||
# Get position details
|
||||
entry_price = position.entry_price
|
||||
size = position.size
|
||||
side = position.side
|
||||
|
||||
# Use provided price or get current price
|
||||
if current_price is None:
|
||||
current_price = self.get_current_price(position.symbol)
|
||||
|
||||
# Calculate P&L based on position side
|
||||
if side == 'LONG':
|
||||
pnl = (current_price - entry_price) * size
|
||||
else: # SHORT
|
||||
pnl = (entry_price - current_price) * size
|
||||
|
||||
# Calculate fees (if available)
|
||||
fees = getattr(position, 'fees', 0.0)
|
||||
|
||||
# Return both gross and net P&L
|
||||
return {
|
||||
'gross_pnl': pnl,
|
||||
'fees': fees,
|
||||
'net_pnl': pnl - fees
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating P&L: {e}")
|
||||
return {'gross_pnl': 0.0, 'fees': 0.0, 'net_pnl': 0.0}
|
||||
|
||||
# Apply the patch if original method exists
|
||||
if original_calculate_pnl:
|
||||
trading_executor._calculate_pnl = calculate_pnl_fixed.__get__(trading_executor)
|
||||
logger.info("P&L calculation fix applied")
|
||||
else:
|
||||
# Add the method if it doesn't exist
|
||||
trading_executor._calculate_pnl = calculate_pnl_fixed.__get__(trading_executor)
|
||||
logger.info("P&L calculation method added")
|
||||
|
||||
@staticmethod
|
||||
def _fix_execute_action(trading_executor, original_execute_action):
|
||||
"""Fix execute_action to prevent duplicate entries and ensure proper price updates"""
|
||||
def execute_action_fixed(self, decision):
|
||||
"""Fixed execute_action with duplicate entry prevention"""
|
||||
def execute_action_with_fixes(decision):
|
||||
"""Enhanced execute_action with price validation and cooldown"""
|
||||
try:
|
||||
symbol = decision.symbol
|
||||
action = decision.action
|
||||
current_time = datetime.now()
|
||||
|
||||
# Check for duplicate entry (same price as recent entry)
|
||||
if hasattr(self, 'recent_entries') and symbol in self.recent_entries:
|
||||
recent_entry = self.recent_entries[symbol]
|
||||
current_price = self.get_current_price(symbol)
|
||||
|
||||
# If price is within 0.1% of recent entry, consider it a duplicate
|
||||
price_diff_pct = abs(current_price - recent_entry['price']) / recent_entry['price'] * 100
|
||||
time_diff = time.time() - recent_entry['timestamp']
|
||||
|
||||
if price_diff_pct < 0.1 and time_diff < 60: # Within 0.1% and 60 seconds
|
||||
logger.warning(f"Preventing duplicate entry for {symbol} at ${current_price:.2f} "
|
||||
f"(recent entry: ${recent_entry['price']:.2f}, {time_diff:.1f}s ago)")
|
||||
|
||||
# Mark decision as blocked
|
||||
decision.blocked = True
|
||||
decision.blocked_reason = "Duplicate entry prevention"
|
||||
# 1. Check cooldown period
|
||||
if symbol in self.last_trade_time:
|
||||
time_since_last_trade = (current_time - self.last_trade_time[symbol]).total_seconds()
|
||||
if time_since_last_trade < self.min_trade_cooldown:
|
||||
logger.warning(f"Trade rejected: Cooldown period ({time_since_last_trade:.1f}s < {self.min_trade_cooldown}s) for {symbol}")
|
||||
return False
|
||||
|
||||
# Check trade cooldown
|
||||
if hasattr(self, '_check_trade_cooldown'):
|
||||
if not self._check_trade_cooldown(symbol, action):
|
||||
# Mark decision as blocked
|
||||
decision.blocked = True
|
||||
decision.blocked_reason = "Trade cooldown active"
|
||||
# 2. Validate price freshness
|
||||
if symbol in self.last_price_update:
|
||||
time_since_update = (current_time - self.last_price_update[symbol]).total_seconds()
|
||||
if time_since_update > self.price_update_threshold:
|
||||
logger.warning(f"Trade rejected: Price data stale ({time_since_update:.1f}s > {self.price_update_threshold}s) for {symbol}")
|
||||
# Force price refresh
|
||||
self._refresh_price(symbol)
|
||||
return False
|
||||
|
||||
# Force price refresh before execution
|
||||
fresh_price = self.get_current_price(symbol)
|
||||
logger.info(f"Using fresh price for {symbol}: ${fresh_price:.2f}")
|
||||
# 3. Validate entry price against recent history
|
||||
current_price = self._get_current_price(symbol)
|
||||
if symbol in self.recent_entry_prices and len(self.recent_entry_prices[symbol]) > 0:
|
||||
# Check if price is identical to any recent entry
|
||||
if current_price in self.recent_entry_prices[symbol]:
|
||||
logger.warning(f"Trade rejected: Duplicate entry price ${current_price} for {symbol}")
|
||||
return False
|
||||
|
||||
# Update decision price with fresh price
|
||||
decision.price = fresh_price
|
||||
# 4. Ensure position is properly reset before new entry
|
||||
if not self._ensure_position_reset(symbol):
|
||||
logger.warning(f"Trade rejected: Position not properly reset for {symbol}")
|
||||
return False
|
||||
|
||||
# Call original execute_action
|
||||
# Execute the original action
|
||||
result = original_execute_action(decision)
|
||||
|
||||
# If execution was successful, record the entry
|
||||
if result and not getattr(decision, 'blocked', False):
|
||||
if not hasattr(self, 'recent_entries'):
|
||||
self.recent_entries = {}
|
||||
# If successful, update tracking
|
||||
if result:
|
||||
# Update cooldown timestamp
|
||||
self.last_trade_time[symbol] = current_time
|
||||
|
||||
self.recent_entries[symbol] = {
|
||||
'price': fresh_price,
|
||||
'timestamp': time.time(),
|
||||
'action': action
|
||||
}
|
||||
# Update price history
|
||||
if symbol not in self.recent_entry_prices:
|
||||
self.recent_entry_prices[symbol] = []
|
||||
|
||||
# Record last trade time for cooldown
|
||||
if not hasattr(self, 'last_trade_time'):
|
||||
self.last_trade_time = {}
|
||||
self.recent_entry_prices[symbol].append(current_price)
|
||||
# Keep only the most recent prices
|
||||
if len(self.recent_entry_prices[symbol]) > self.max_price_history:
|
||||
self.recent_entry_prices[symbol] = self.recent_entry_prices[symbol][-self.max_price_history:]
|
||||
|
||||
self.last_trade_time[symbol] = time.time()
|
||||
# Mark position as active
|
||||
self.position_reset_flags[symbol] = False
|
||||
|
||||
logger.info(f"Trade executed: {action} {symbol} at ${current_price} with validation")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in execute_action_fixed: {e}")
|
||||
logger.error(f"Error in execute_action_with_fixes: {e}")
|
||||
return original_execute_action(decision)
|
||||
|
||||
# Replace the original method
|
||||
self.trading_executor.execute_action = execute_action_with_fixes
|
||||
logger.info("Patched execute_action with price validation and cooldown")
|
||||
|
||||
def _patch_close_position(self):
|
||||
"""Patch the close_position method to ensure proper position reset"""
|
||||
original_close_position = self.trading_executor.close_position
|
||||
|
||||
def close_position_with_fixes(symbol, **kwargs):
|
||||
"""Enhanced close_position with proper reset logic"""
|
||||
try:
|
||||
# Get current price for P&L verification
|
||||
exit_price = self._get_current_price(symbol)
|
||||
|
||||
# Call original close position
|
||||
result = original_close_position(symbol, **kwargs)
|
||||
|
||||
if result:
|
||||
# Mark position as reset
|
||||
self.position_reset_flags[symbol] = True
|
||||
|
||||
# Record trade for verification
|
||||
if hasattr(self.trading_executor, 'positions') and symbol in self.trading_executor.positions:
|
||||
position = self.trading_executor.positions[symbol]
|
||||
|
||||
# Create trade record
|
||||
trade_record = {
|
||||
'symbol': symbol,
|
||||
'entry_time': getattr(position, 'entry_time', datetime.now()),
|
||||
'exit_time': datetime.now(),
|
||||
'entry_price': getattr(position, 'entry_price', 0),
|
||||
'exit_price': exit_price,
|
||||
'size': getattr(position, 'size', 0),
|
||||
'side': getattr(position, 'side', 'UNKNOWN'),
|
||||
'pnl': self._calculate_verified_pnl(position, exit_price),
|
||||
'fees': getattr(position, 'fees', 0),
|
||||
'hold_time_seconds': (datetime.now() - getattr(position, 'entry_time', datetime.now())).total_seconds()
|
||||
}
|
||||
|
||||
# Store trade record
|
||||
if symbol not in self.trade_history:
|
||||
self.trade_history[symbol] = []
|
||||
self.trade_history[symbol].append(trade_record)
|
||||
|
||||
logger.info(f"Position closed: {symbol} at ${exit_price} with verified P&L: ${trade_record['pnl']:.2f}")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in close_position_with_fixes: {e}")
|
||||
return original_close_position(symbol, **kwargs)
|
||||
|
||||
# Replace the original method
|
||||
self.trading_executor.close_position = close_position_with_fixes
|
||||
logger.info("Patched close_position with proper reset logic")
|
||||
|
||||
def _patch_calculate_pnl(self):
|
||||
"""Patch the calculate_pnl method to ensure accurate P&L calculation"""
|
||||
original_calculate_pnl = getattr(self.trading_executor, 'calculate_pnl', None)
|
||||
|
||||
def calculate_pnl_with_fixes(position, current_price=None):
|
||||
"""Enhanced calculate_pnl with verification"""
|
||||
try:
|
||||
# If no original method, implement our own
|
||||
if original_calculate_pnl is None:
|
||||
return self._calculate_verified_pnl(position, current_price)
|
||||
|
||||
# Call original method
|
||||
original_pnl = original_calculate_pnl(position, current_price)
|
||||
|
||||
# Calculate our verified P&L
|
||||
verified_pnl = self._calculate_verified_pnl(position, current_price)
|
||||
|
||||
# If there's a significant difference, log it
|
||||
if abs(original_pnl - verified_pnl) > 0.01:
|
||||
logger.warning(f"P&L calculation discrepancy: original=${original_pnl:.2f}, verified=${verified_pnl:.2f}")
|
||||
# Use the verified P&L
|
||||
return verified_pnl
|
||||
|
||||
return original_pnl
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in calculate_pnl_with_fixes: {e}")
|
||||
if original_calculate_pnl:
|
||||
return original_calculate_pnl(position, current_price)
|
||||
return 0.0
|
||||
|
||||
# Replace the original method if it exists
|
||||
if original_calculate_pnl:
|
||||
self.trading_executor.calculate_pnl = calculate_pnl_with_fixes
|
||||
logger.info("Patched calculate_pnl with verification")
|
||||
else:
|
||||
# Add the method if it doesn't exist
|
||||
self.trading_executor.calculate_pnl = calculate_pnl_with_fixes
|
||||
logger.info("Added calculate_pnl method with verification")
|
||||
|
||||
def _patch_update_prices(self):
|
||||
"""Patch the update_prices method to track price updates"""
|
||||
original_update_prices = getattr(self.trading_executor, 'update_prices', None)
|
||||
|
||||
def update_prices_with_tracking(prices):
|
||||
"""Enhanced update_prices with timestamp tracking"""
|
||||
try:
|
||||
# Call original method if it exists
|
||||
if original_update_prices:
|
||||
result = original_update_prices(prices)
|
||||
else:
|
||||
# If no original method, update prices directly
|
||||
if hasattr(self.trading_executor, 'current_prices'):
|
||||
self.trading_executor.current_prices.update(prices)
|
||||
result = True
|
||||
|
||||
# Track update timestamps
|
||||
current_time = datetime.now()
|
||||
for symbol in prices:
|
||||
self.last_price_update[symbol] = current_time
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in update_prices_with_tracking: {e}")
|
||||
if original_update_prices:
|
||||
return original_update_prices(prices)
|
||||
return False
|
||||
|
||||
# Apply the patch
|
||||
trading_executor.execute_action = execute_action_fixed.__get__(trading_executor)
|
||||
|
||||
# Initialize recent entries dict if it doesn't exist
|
||||
if not hasattr(trading_executor, 'recent_entries'):
|
||||
trading_executor.recent_entries = {}
|
||||
|
||||
logger.info("Execute action fix applied")
|
||||
# Replace the original method if it exists
|
||||
if original_update_prices:
|
||||
self.trading_executor.update_prices = update_prices_with_tracking
|
||||
logger.info("Patched update_prices with timestamp tracking")
|
||||
else:
|
||||
# Add the method if it doesn't exist
|
||||
self.trading_executor.update_prices = update_prices_with_tracking
|
||||
logger.info("Added update_prices method with timestamp tracking")
|
||||
|
||||
@staticmethod
|
||||
def _add_trade_cooldown(trading_executor):
|
||||
"""Add trade cooldown to prevent rapid consecutive trades"""
|
||||
# Add cooldown settings
|
||||
trading_executor.trade_cooldown_seconds = 30 # 30 seconds between trades
|
||||
|
||||
if not hasattr(trading_executor, 'last_trade_time'):
|
||||
trading_executor.last_trade_time = {}
|
||||
|
||||
def check_trade_cooldown(self, symbol, action):
|
||||
"""Check if trade cooldown is active for a symbol"""
|
||||
if not hasattr(self, 'last_trade_time'):
|
||||
self.last_trade_time = {}
|
||||
return True
|
||||
def _calculate_verified_pnl(self, position, current_price=None):
|
||||
"""Calculate verified P&L for a position"""
|
||||
try:
|
||||
# Get position details
|
||||
entry_price = getattr(position, 'entry_price', 0)
|
||||
size = getattr(position, 'size', 0)
|
||||
side = getattr(position, 'side', 'UNKNOWN')
|
||||
leverage = getattr(position, 'leverage', 1.0)
|
||||
fees = getattr(position, 'fees', 0.0)
|
||||
|
||||
if symbol not in self.last_trade_time:
|
||||
return True
|
||||
# If current_price is not provided, try to get it
|
||||
if current_price is None:
|
||||
symbol = getattr(position, 'symbol', None)
|
||||
if symbol:
|
||||
current_price = self._get_current_price(symbol)
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
# Get time since last trade
|
||||
time_since_last = time.time() - self.last_trade_time[symbol]
|
||||
# Calculate P&L based on position side
|
||||
if side == 'LONG':
|
||||
pnl = (current_price - entry_price) * size * leverage
|
||||
elif side == 'SHORT':
|
||||
pnl = (entry_price - current_price) * size * leverage
|
||||
else:
|
||||
pnl = 0.0
|
||||
|
||||
# Check if cooldown is still active
|
||||
if time_since_last < self.trade_cooldown_seconds:
|
||||
logger.warning(f"Trade cooldown active for {symbol}: {time_since_last:.1f}s elapsed, "
|
||||
f"need {self.trade_cooldown_seconds}s")
|
||||
return False
|
||||
# Subtract fees for net P&L
|
||||
net_pnl = pnl - fees
|
||||
|
||||
return True
|
||||
|
||||
# Add the method
|
||||
trading_executor._check_trade_cooldown = check_trade_cooldown.__get__(trading_executor)
|
||||
logger.info("Trade cooldown feature added")
|
||||
return net_pnl
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating verified P&L: {e}")
|
||||
return 0.0
|
||||
|
||||
@staticmethod
|
||||
def _fix_position_tracking(trading_executor):
|
||||
"""Fix position tracking to ensure proper reset between trades"""
|
||||
# Store original close_position method
|
||||
original_close_position = getattr(trading_executor, 'close_position', None)
|
||||
|
||||
if original_close_position:
|
||||
def close_position_fixed(self, symbol, price=None):
|
||||
"""Fixed close_position with proper position cleanup"""
|
||||
try:
|
||||
# Call original close_position
|
||||
result = original_close_position(symbol, price)
|
||||
|
||||
# Ensure position is fully cleaned up
|
||||
if symbol in self.positions:
|
||||
del self.positions[symbol]
|
||||
|
||||
# Clear recent entry for this symbol
|
||||
if hasattr(self, 'recent_entries') and symbol in self.recent_entries:
|
||||
del self.recent_entries[symbol]
|
||||
|
||||
logger.info(f"Position for {symbol} fully cleaned up after close")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in close_position_fixed: {e}")
|
||||
def _get_current_price(self, symbol):
|
||||
"""Get current price for a symbol with fallbacks"""
|
||||
try:
|
||||
# Try to get from trading executor
|
||||
if hasattr(self.trading_executor, 'current_prices') and symbol in self.trading_executor.current_prices:
|
||||
return self.trading_executor.current_prices[symbol]
|
||||
|
||||
# Try to get from data provider
|
||||
if hasattr(self.trading_executor, 'data_provider'):
|
||||
data_provider = self.trading_executor.data_provider
|
||||
if hasattr(data_provider, 'get_current_price'):
|
||||
price = data_provider.get_current_price(symbol)
|
||||
if price and price > 0:
|
||||
return price
|
||||
|
||||
# Try to get from COB data
|
||||
if hasattr(self.trading_executor, 'latest_cob_data') and symbol in self.trading_executor.latest_cob_data:
|
||||
cob_data = self.trading_executor.latest_cob_data[symbol]
|
||||
if hasattr(cob_data, 'stats') and 'mid_price' in cob_data.stats:
|
||||
return cob_data.stats['mid_price']
|
||||
|
||||
# Default fallback
|
||||
return 0.0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting current price for {symbol}: {e}")
|
||||
return 0.0
|
||||
|
||||
def _refresh_price(self, symbol):
|
||||
"""Force a price refresh for a symbol"""
|
||||
try:
|
||||
# Try to refresh from data provider
|
||||
if hasattr(self.trading_executor, 'data_provider'):
|
||||
data_provider = self.trading_executor.data_provider
|
||||
if hasattr(data_provider, 'fetch_current_price'):
|
||||
price = data_provider.fetch_current_price(symbol)
|
||||
if price and price > 0:
|
||||
# Update trading executor price
|
||||
if hasattr(self.trading_executor, 'current_prices'):
|
||||
self.trading_executor.current_prices[symbol] = price
|
||||
|
||||
# Update timestamp
|
||||
self.last_price_update[symbol] = datetime.now()
|
||||
|
||||
logger.info(f"Refreshed price for {symbol}: ${price:.2f}")
|
||||
return True
|
||||
|
||||
logger.warning(f"Failed to refresh price for {symbol}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error refreshing price for {symbol}: {e}")
|
||||
return False
|
||||
|
||||
def _ensure_position_reset(self, symbol):
|
||||
"""Ensure position is properly reset before new entry"""
|
||||
try:
|
||||
# Check if we have an active position
|
||||
if hasattr(self.trading_executor, 'positions') and symbol in self.trading_executor.positions:
|
||||
# Position exists, check if it's valid
|
||||
position = self.trading_executor.positions[symbol]
|
||||
if position and getattr(position, 'active', False):
|
||||
logger.warning(f"Position already active for {symbol}, cannot enter new position")
|
||||
return False
|
||||
|
||||
# Apply the patch
|
||||
trading_executor.close_position = close_position_fixed.__get__(trading_executor)
|
||||
logger.info("Position tracking fix applied")
|
||||
else:
|
||||
logger.warning("close_position method not found, skipping position tracking fix")
|
||||
# Check reset flag
|
||||
if symbol in self.position_reset_flags and not self.position_reset_flags[symbol]:
|
||||
# Force position cleanup
|
||||
if hasattr(self.trading_executor, 'positions'):
|
||||
self.trading_executor.positions.pop(symbol, None)
|
||||
|
||||
logger.info(f"Forced position reset for {symbol}")
|
||||
self.position_reset_flags[symbol] = True
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error ensuring position reset for {symbol}: {e}")
|
||||
return False
|
||||
|
||||
def get_trade_history(self, symbol=None):
|
||||
"""Get verified trade history"""
|
||||
if symbol:
|
||||
return self.trade_history.get(symbol, [])
|
||||
return self.trade_history
|
||||
|
||||
def get_price_update_status(self):
|
||||
"""Get price update status for all symbols"""
|
||||
status = {}
|
||||
current_time = datetime.now()
|
||||
|
||||
for symbol, timestamp in self.last_price_update.items():
|
||||
time_since_update = (current_time - timestamp).total_seconds()
|
||||
status[symbol] = {
|
||||
'last_update': timestamp,
|
||||
'seconds_ago': time_since_update,
|
||||
'is_fresh': time_since_update <= self.price_update_threshold
|
||||
}
|
||||
|
||||
return status
|
Reference in New Issue
Block a user