integrating new CNN model
This commit is contained in:
@ -289,11 +289,9 @@ class TradingOrchestrator:
|
||||
|
||||
# Initialize CNN Model
|
||||
try:
|
||||
from NN.models.enhanced_cnn import EnhancedCNN
|
||||
from NN.models.standardized_cnn import StandardizedCNN
|
||||
|
||||
cnn_input_shape = self.config.cnn.get('input_shape', 100)
|
||||
cnn_n_actions = self.config.cnn.get('n_actions', 3)
|
||||
self.cnn_model = EnhancedCNN(input_shape=cnn_input_shape, n_actions=cnn_n_actions)
|
||||
self.cnn_model = StandardizedCNN()
|
||||
self.cnn_model.to(self.device) # Move CNN model to the determined device
|
||||
self.cnn_optimizer = optim.Adam(self.cnn_model.parameters(), lr=0.001) # Initialize optimizer for CNN
|
||||
|
||||
@ -325,8 +323,8 @@ class TradingOrchestrator:
|
||||
logger.info("Enhanced CNN model initialized")
|
||||
except ImportError:
|
||||
try:
|
||||
from NN.models.cnn_model import CNNModel
|
||||
self.cnn_model = CNNModel()
|
||||
from NN.models.standardized_cnn import StandardizedCNN
|
||||
self.cnn_model = StandardizedCNN()
|
||||
self.cnn_model.to(self.device) # Move basic CNN model to the determined device
|
||||
self.cnn_optimizer = optim.Adam(self.cnn_model.parameters(), lr=0.001) # Initialize optimizer for basic CNN
|
||||
|
||||
|
261
core/trading_executor_fix.py
Normal file
261
core/trading_executor_fix.py
Normal file
@ -0,0 +1,261 @@
|
||||
"""
|
||||
Trading Executor Fix
|
||||
|
||||
This module provides fixes for the trading executor to address:
|
||||
1. Duplicate entry prices
|
||||
2. P&L calculation issues
|
||||
3. Position tracking problems
|
||||
|
||||
Apply these fixes by importing and applying the patch in main.py
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TradingExecutorFix:
|
||||
"""Fixes for the TradingExecutor class"""
|
||||
|
||||
@staticmethod
|
||||
def apply_fixes(trading_executor):
|
||||
"""Apply all fixes to the trading executor"""
|
||||
logger.info("Applying TradingExecutor fixes...")
|
||||
|
||||
# Store original methods for patching
|
||||
original_execute_action = trading_executor.execute_action
|
||||
original_calculate_pnl = getattr(trading_executor, '_calculate_pnl', None)
|
||||
|
||||
# Apply fixes
|
||||
TradingExecutorFix._fix_price_caching(trading_executor)
|
||||
TradingExecutorFix._fix_pnl_calculation(trading_executor, original_calculate_pnl)
|
||||
TradingExecutorFix._fix_execute_action(trading_executor, original_execute_action)
|
||||
TradingExecutorFix._add_trade_cooldown(trading_executor)
|
||||
TradingExecutorFix._fix_position_tracking(trading_executor)
|
||||
|
||||
logger.info("TradingExecutor fixes applied successfully")
|
||||
return trading_executor
|
||||
|
||||
@staticmethod
|
||||
def _fix_price_caching(trading_executor):
|
||||
"""Fix price caching to prevent duplicate entry prices"""
|
||||
# Add a price cache timestamp to track when prices were last updated
|
||||
trading_executor.price_cache_timestamp = {}
|
||||
|
||||
# Store original get_current_price method
|
||||
original_get_current_price = trading_executor.get_current_price
|
||||
|
||||
def get_current_price_fixed(self, symbol):
|
||||
"""Fixed get_current_price method with cache invalidation"""
|
||||
now = time.time()
|
||||
|
||||
# Force price refresh if cache is older than 5 seconds
|
||||
if symbol in self.price_cache_timestamp:
|
||||
cache_age = now - self.price_cache_timestamp.get(symbol, 0)
|
||||
if cache_age > 5: # 5 seconds max cache age
|
||||
# Clear price cache for this symbol
|
||||
if hasattr(self, 'current_prices') and symbol in self.current_prices:
|
||||
del self.current_prices[symbol]
|
||||
logger.debug(f"Price cache for {symbol} invalidated (age: {cache_age:.1f}s)")
|
||||
|
||||
# Call original method to get fresh price
|
||||
price = original_get_current_price(symbol)
|
||||
|
||||
# Update cache timestamp
|
||||
self.price_cache_timestamp[symbol] = now
|
||||
|
||||
return price
|
||||
|
||||
# Apply the patch
|
||||
trading_executor.get_current_price = get_current_price_fixed.__get__(trading_executor)
|
||||
logger.info("Price caching fix applied")
|
||||
|
||||
@staticmethod
|
||||
def _fix_pnl_calculation(trading_executor, original_calculate_pnl):
|
||||
"""Fix P&L calculation to ensure accuracy"""
|
||||
def calculate_pnl_fixed(self, position, current_price=None):
|
||||
"""Fixed P&L calculation with proper handling of position side"""
|
||||
try:
|
||||
# Get position details
|
||||
entry_price = position.entry_price
|
||||
size = position.size
|
||||
side = position.side
|
||||
|
||||
# Use provided price or get current price
|
||||
if current_price is None:
|
||||
current_price = self.get_current_price(position.symbol)
|
||||
|
||||
# Calculate P&L based on position side
|
||||
if side == 'LONG':
|
||||
pnl = (current_price - entry_price) * size
|
||||
else: # SHORT
|
||||
pnl = (entry_price - current_price) * size
|
||||
|
||||
# Calculate fees (if available)
|
||||
fees = getattr(position, 'fees', 0.0)
|
||||
|
||||
# Return both gross and net P&L
|
||||
return {
|
||||
'gross_pnl': pnl,
|
||||
'fees': fees,
|
||||
'net_pnl': pnl - fees
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating P&L: {e}")
|
||||
return {'gross_pnl': 0.0, 'fees': 0.0, 'net_pnl': 0.0}
|
||||
|
||||
# Apply the patch if original method exists
|
||||
if original_calculate_pnl:
|
||||
trading_executor._calculate_pnl = calculate_pnl_fixed.__get__(trading_executor)
|
||||
logger.info("P&L calculation fix applied")
|
||||
else:
|
||||
# Add the method if it doesn't exist
|
||||
trading_executor._calculate_pnl = calculate_pnl_fixed.__get__(trading_executor)
|
||||
logger.info("P&L calculation method added")
|
||||
|
||||
@staticmethod
|
||||
def _fix_execute_action(trading_executor, original_execute_action):
|
||||
"""Fix execute_action to prevent duplicate entries and ensure proper price updates"""
|
||||
def execute_action_fixed(self, decision):
|
||||
"""Fixed execute_action with duplicate entry prevention"""
|
||||
try:
|
||||
symbol = decision.symbol
|
||||
action = decision.action
|
||||
|
||||
# Check for duplicate entry (same price as recent entry)
|
||||
if hasattr(self, 'recent_entries') and symbol in self.recent_entries:
|
||||
recent_entry = self.recent_entries[symbol]
|
||||
current_price = self.get_current_price(symbol)
|
||||
|
||||
# If price is within 0.1% of recent entry, consider it a duplicate
|
||||
price_diff_pct = abs(current_price - recent_entry['price']) / recent_entry['price'] * 100
|
||||
time_diff = time.time() - recent_entry['timestamp']
|
||||
|
||||
if price_diff_pct < 0.1 and time_diff < 60: # Within 0.1% and 60 seconds
|
||||
logger.warning(f"Preventing duplicate entry for {symbol} at ${current_price:.2f} "
|
||||
f"(recent entry: ${recent_entry['price']:.2f}, {time_diff:.1f}s ago)")
|
||||
|
||||
# Mark decision as blocked
|
||||
decision.blocked = True
|
||||
decision.blocked_reason = "Duplicate entry prevention"
|
||||
return False
|
||||
|
||||
# Check trade cooldown
|
||||
if hasattr(self, '_check_trade_cooldown'):
|
||||
if not self._check_trade_cooldown(symbol, action):
|
||||
# Mark decision as blocked
|
||||
decision.blocked = True
|
||||
decision.blocked_reason = "Trade cooldown active"
|
||||
return False
|
||||
|
||||
# Force price refresh before execution
|
||||
fresh_price = self.get_current_price(symbol)
|
||||
logger.info(f"Using fresh price for {symbol}: ${fresh_price:.2f}")
|
||||
|
||||
# Update decision price with fresh price
|
||||
decision.price = fresh_price
|
||||
|
||||
# Call original execute_action
|
||||
result = original_execute_action(decision)
|
||||
|
||||
# If execution was successful, record the entry
|
||||
if result and not getattr(decision, 'blocked', False):
|
||||
if not hasattr(self, 'recent_entries'):
|
||||
self.recent_entries = {}
|
||||
|
||||
self.recent_entries[symbol] = {
|
||||
'price': fresh_price,
|
||||
'timestamp': time.time(),
|
||||
'action': action
|
||||
}
|
||||
|
||||
# Record last trade time for cooldown
|
||||
if not hasattr(self, 'last_trade_time'):
|
||||
self.last_trade_time = {}
|
||||
|
||||
self.last_trade_time[symbol] = time.time()
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in execute_action_fixed: {e}")
|
||||
return False
|
||||
|
||||
# Apply the patch
|
||||
trading_executor.execute_action = execute_action_fixed.__get__(trading_executor)
|
||||
|
||||
# Initialize recent entries dict if it doesn't exist
|
||||
if not hasattr(trading_executor, 'recent_entries'):
|
||||
trading_executor.recent_entries = {}
|
||||
|
||||
logger.info("Execute action fix applied")
|
||||
|
||||
@staticmethod
|
||||
def _add_trade_cooldown(trading_executor):
|
||||
"""Add trade cooldown to prevent rapid consecutive trades"""
|
||||
# Add cooldown settings
|
||||
trading_executor.trade_cooldown_seconds = 30 # 30 seconds between trades
|
||||
|
||||
if not hasattr(trading_executor, 'last_trade_time'):
|
||||
trading_executor.last_trade_time = {}
|
||||
|
||||
def check_trade_cooldown(self, symbol, action):
|
||||
"""Check if trade cooldown is active for a symbol"""
|
||||
if not hasattr(self, 'last_trade_time'):
|
||||
self.last_trade_time = {}
|
||||
return True
|
||||
|
||||
if symbol not in self.last_trade_time:
|
||||
return True
|
||||
|
||||
# Get time since last trade
|
||||
time_since_last = time.time() - self.last_trade_time[symbol]
|
||||
|
||||
# Check if cooldown is still active
|
||||
if time_since_last < self.trade_cooldown_seconds:
|
||||
logger.warning(f"Trade cooldown active for {symbol}: {time_since_last:.1f}s elapsed, "
|
||||
f"need {self.trade_cooldown_seconds}s")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
# Add the method
|
||||
trading_executor._check_trade_cooldown = check_trade_cooldown.__get__(trading_executor)
|
||||
logger.info("Trade cooldown feature added")
|
||||
|
||||
@staticmethod
|
||||
def _fix_position_tracking(trading_executor):
|
||||
"""Fix position tracking to ensure proper reset between trades"""
|
||||
# Store original close_position method
|
||||
original_close_position = getattr(trading_executor, 'close_position', None)
|
||||
|
||||
if original_close_position:
|
||||
def close_position_fixed(self, symbol, price=None):
|
||||
"""Fixed close_position with proper position cleanup"""
|
||||
try:
|
||||
# Call original close_position
|
||||
result = original_close_position(symbol, price)
|
||||
|
||||
# Ensure position is fully cleaned up
|
||||
if symbol in self.positions:
|
||||
del self.positions[symbol]
|
||||
|
||||
# Clear recent entry for this symbol
|
||||
if hasattr(self, 'recent_entries') and symbol in self.recent_entries:
|
||||
del self.recent_entries[symbol]
|
||||
|
||||
logger.info(f"Position for {symbol} fully cleaned up after close")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in close_position_fixed: {e}")
|
||||
return False
|
||||
|
||||
# Apply the patch
|
||||
trading_executor.close_position = close_position_fixed.__get__(trading_executor)
|
||||
logger.info("Position tracking fix applied")
|
||||
else:
|
||||
logger.warning("close_position method not found, skipping position tracking fix")
|
Reference in New Issue
Block a user