integrating new CNN model

This commit is contained in:
Dobromir Popov
2025-07-23 16:59:35 +03:00
parent 1be270cc5c
commit f1d63f9da6
15 changed files with 1896 additions and 1003 deletions

View File

@ -289,11 +289,9 @@ class TradingOrchestrator:
# Initialize CNN Model
try:
from NN.models.enhanced_cnn import EnhancedCNN
from NN.models.standardized_cnn import StandardizedCNN
cnn_input_shape = self.config.cnn.get('input_shape', 100)
cnn_n_actions = self.config.cnn.get('n_actions', 3)
self.cnn_model = EnhancedCNN(input_shape=cnn_input_shape, n_actions=cnn_n_actions)
self.cnn_model = StandardizedCNN()
self.cnn_model.to(self.device) # Move CNN model to the determined device
self.cnn_optimizer = optim.Adam(self.cnn_model.parameters(), lr=0.001) # Initialize optimizer for CNN
@ -325,8 +323,8 @@ class TradingOrchestrator:
logger.info("Enhanced CNN model initialized")
except ImportError:
try:
from NN.models.cnn_model import CNNModel
self.cnn_model = CNNModel()
from NN.models.standardized_cnn import StandardizedCNN
self.cnn_model = StandardizedCNN()
self.cnn_model.to(self.device) # Move basic CNN model to the determined device
self.cnn_optimizer = optim.Adam(self.cnn_model.parameters(), lr=0.001) # Initialize optimizer for basic CNN

View File

@ -0,0 +1,261 @@
"""
Trading Executor Fix
This module provides fixes for the trading executor to address:
1. Duplicate entry prices
2. P&L calculation issues
3. Position tracking problems
Apply these fixes by importing and applying the patch in main.py
"""
import logging
import time
from datetime import datetime
from typing import Dict, Any, Optional
logger = logging.getLogger(__name__)
class TradingExecutorFix:
"""Fixes for the TradingExecutor class"""
@staticmethod
def apply_fixes(trading_executor):
"""Apply all fixes to the trading executor"""
logger.info("Applying TradingExecutor fixes...")
# Store original methods for patching
original_execute_action = trading_executor.execute_action
original_calculate_pnl = getattr(trading_executor, '_calculate_pnl', None)
# Apply fixes
TradingExecutorFix._fix_price_caching(trading_executor)
TradingExecutorFix._fix_pnl_calculation(trading_executor, original_calculate_pnl)
TradingExecutorFix._fix_execute_action(trading_executor, original_execute_action)
TradingExecutorFix._add_trade_cooldown(trading_executor)
TradingExecutorFix._fix_position_tracking(trading_executor)
logger.info("TradingExecutor fixes applied successfully")
return trading_executor
@staticmethod
def _fix_price_caching(trading_executor):
"""Fix price caching to prevent duplicate entry prices"""
# Add a price cache timestamp to track when prices were last updated
trading_executor.price_cache_timestamp = {}
# Store original get_current_price method
original_get_current_price = trading_executor.get_current_price
def get_current_price_fixed(self, symbol):
"""Fixed get_current_price method with cache invalidation"""
now = time.time()
# Force price refresh if cache is older than 5 seconds
if symbol in self.price_cache_timestamp:
cache_age = now - self.price_cache_timestamp.get(symbol, 0)
if cache_age > 5: # 5 seconds max cache age
# Clear price cache for this symbol
if hasattr(self, 'current_prices') and symbol in self.current_prices:
del self.current_prices[symbol]
logger.debug(f"Price cache for {symbol} invalidated (age: {cache_age:.1f}s)")
# Call original method to get fresh price
price = original_get_current_price(symbol)
# Update cache timestamp
self.price_cache_timestamp[symbol] = now
return price
# Apply the patch
trading_executor.get_current_price = get_current_price_fixed.__get__(trading_executor)
logger.info("Price caching fix applied")
@staticmethod
def _fix_pnl_calculation(trading_executor, original_calculate_pnl):
"""Fix P&L calculation to ensure accuracy"""
def calculate_pnl_fixed(self, position, current_price=None):
"""Fixed P&L calculation with proper handling of position side"""
try:
# Get position details
entry_price = position.entry_price
size = position.size
side = position.side
# Use provided price or get current price
if current_price is None:
current_price = self.get_current_price(position.symbol)
# Calculate P&L based on position side
if side == 'LONG':
pnl = (current_price - entry_price) * size
else: # SHORT
pnl = (entry_price - current_price) * size
# Calculate fees (if available)
fees = getattr(position, 'fees', 0.0)
# Return both gross and net P&L
return {
'gross_pnl': pnl,
'fees': fees,
'net_pnl': pnl - fees
}
except Exception as e:
logger.error(f"Error calculating P&L: {e}")
return {'gross_pnl': 0.0, 'fees': 0.0, 'net_pnl': 0.0}
# Apply the patch if original method exists
if original_calculate_pnl:
trading_executor._calculate_pnl = calculate_pnl_fixed.__get__(trading_executor)
logger.info("P&L calculation fix applied")
else:
# Add the method if it doesn't exist
trading_executor._calculate_pnl = calculate_pnl_fixed.__get__(trading_executor)
logger.info("P&L calculation method added")
@staticmethod
def _fix_execute_action(trading_executor, original_execute_action):
"""Fix execute_action to prevent duplicate entries and ensure proper price updates"""
def execute_action_fixed(self, decision):
"""Fixed execute_action with duplicate entry prevention"""
try:
symbol = decision.symbol
action = decision.action
# Check for duplicate entry (same price as recent entry)
if hasattr(self, 'recent_entries') and symbol in self.recent_entries:
recent_entry = self.recent_entries[symbol]
current_price = self.get_current_price(symbol)
# If price is within 0.1% of recent entry, consider it a duplicate
price_diff_pct = abs(current_price - recent_entry['price']) / recent_entry['price'] * 100
time_diff = time.time() - recent_entry['timestamp']
if price_diff_pct < 0.1 and time_diff < 60: # Within 0.1% and 60 seconds
logger.warning(f"Preventing duplicate entry for {symbol} at ${current_price:.2f} "
f"(recent entry: ${recent_entry['price']:.2f}, {time_diff:.1f}s ago)")
# Mark decision as blocked
decision.blocked = True
decision.blocked_reason = "Duplicate entry prevention"
return False
# Check trade cooldown
if hasattr(self, '_check_trade_cooldown'):
if not self._check_trade_cooldown(symbol, action):
# Mark decision as blocked
decision.blocked = True
decision.blocked_reason = "Trade cooldown active"
return False
# Force price refresh before execution
fresh_price = self.get_current_price(symbol)
logger.info(f"Using fresh price for {symbol}: ${fresh_price:.2f}")
# Update decision price with fresh price
decision.price = fresh_price
# Call original execute_action
result = original_execute_action(decision)
# If execution was successful, record the entry
if result and not getattr(decision, 'blocked', False):
if not hasattr(self, 'recent_entries'):
self.recent_entries = {}
self.recent_entries[symbol] = {
'price': fresh_price,
'timestamp': time.time(),
'action': action
}
# Record last trade time for cooldown
if not hasattr(self, 'last_trade_time'):
self.last_trade_time = {}
self.last_trade_time[symbol] = time.time()
return result
except Exception as e:
logger.error(f"Error in execute_action_fixed: {e}")
return False
# Apply the patch
trading_executor.execute_action = execute_action_fixed.__get__(trading_executor)
# Initialize recent entries dict if it doesn't exist
if not hasattr(trading_executor, 'recent_entries'):
trading_executor.recent_entries = {}
logger.info("Execute action fix applied")
@staticmethod
def _add_trade_cooldown(trading_executor):
"""Add trade cooldown to prevent rapid consecutive trades"""
# Add cooldown settings
trading_executor.trade_cooldown_seconds = 30 # 30 seconds between trades
if not hasattr(trading_executor, 'last_trade_time'):
trading_executor.last_trade_time = {}
def check_trade_cooldown(self, symbol, action):
"""Check if trade cooldown is active for a symbol"""
if not hasattr(self, 'last_trade_time'):
self.last_trade_time = {}
return True
if symbol not in self.last_trade_time:
return True
# Get time since last trade
time_since_last = time.time() - self.last_trade_time[symbol]
# Check if cooldown is still active
if time_since_last < self.trade_cooldown_seconds:
logger.warning(f"Trade cooldown active for {symbol}: {time_since_last:.1f}s elapsed, "
f"need {self.trade_cooldown_seconds}s")
return False
return True
# Add the method
trading_executor._check_trade_cooldown = check_trade_cooldown.__get__(trading_executor)
logger.info("Trade cooldown feature added")
@staticmethod
def _fix_position_tracking(trading_executor):
"""Fix position tracking to ensure proper reset between trades"""
# Store original close_position method
original_close_position = getattr(trading_executor, 'close_position', None)
if original_close_position:
def close_position_fixed(self, symbol, price=None):
"""Fixed close_position with proper position cleanup"""
try:
# Call original close_position
result = original_close_position(symbol, price)
# Ensure position is fully cleaned up
if symbol in self.positions:
del self.positions[symbol]
# Clear recent entry for this symbol
if hasattr(self, 'recent_entries') and symbol in self.recent_entries:
del self.recent_entries[symbol]
logger.info(f"Position for {symbol} fully cleaned up after close")
return result
except Exception as e:
logger.error(f"Error in close_position_fixed: {e}")
return False
# Apply the patch
trading_executor.close_position = close_position_fixed.__get__(trading_executor)
logger.info("Position tracking fix applied")
else:
logger.warning("close_position method not found, skipping position tracking fix")