From 439611cf88dcb2f0b5f773bdef0e5e5a9533c41d Mon Sep 17 00:00:00 2001 From: Dobromir Popov Date: Tue, 15 Jul 2025 01:10:37 +0300 Subject: [PATCH] trading works! --- NN/models/dqn_agent.py | 12 ++- config.yaml | 4 +- core/orchestrator.py | 28 +++++-- core/trading_executor.py | 159 ++++++++++++++++++++++++++++++++++-- test_order_sync_and_fees.py | 122 +++++++++++++++++++++++++++ 5 files changed, 309 insertions(+), 16 deletions(-) create mode 100644 test_order_sync_and_fees.py diff --git a/NN/models/dqn_agent.py b/NN/models/dqn_agent.py index 4afd369..47623f9 100644 --- a/NN/models/dqn_agent.py +++ b/NN/models/dqn_agent.py @@ -221,7 +221,7 @@ class DQNAgent: # Check if mixed precision training should be used if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and 'DISABLE_MIXED_PRECISION' not in os.environ: self.use_mixed_precision = True - self.scaler = torch.amp.GradScaler('cuda') + self.scaler = torch.cuda.amp.GradScaler() logger.info("Mixed precision training enabled") else: self.use_mixed_precision = False @@ -1083,6 +1083,11 @@ class DQNAgent: # Reset gradients self.optimizer.zero_grad() + # Ensure loss requires gradients before backward pass + if not total_loss.requires_grad: + logger.warning("Total loss tensor does not require gradients, skipping backward pass") + return 0.0 + # Backward pass total_loss.backward() @@ -1263,6 +1268,11 @@ class DQNAgent: # Just use Q-value loss loss = q_loss + # Ensure loss requires gradients before backward pass + if not loss.requires_grad: + logger.warning("Loss tensor does not require gradients, skipping backward pass") + return 0.0 + # Backward pass with scaled gradients self.scaler.scale(loss).backward() diff --git a/config.yaml b/config.yaml index 9df1102..8d6dac2 100644 --- a/config.yaml +++ b/config.yaml @@ -127,8 +127,8 @@ orchestrator: # Model weights for decision combination cnn_weight: 0.7 # Weight for CNN predictions rl_weight: 0.3 # Weight for RL decisions - confidence_threshold: 0.15 - confidence_threshold_close: 0.08 + confidence_threshold: 0.45 + confidence_threshold_close: 0.35 decision_frequency: 30 # Multi-symbol coordination diff --git a/core/orchestrator.py b/core/orchestrator.py index 70f1b8b..5f4e314 100644 --- a/core/orchestrator.py +++ b/core/orchestrator.py @@ -106,8 +106,8 @@ class TradingOrchestrator: # Configuration - AGGRESSIVE for more training data self.confidence_threshold = self.config.orchestrator.get('confidence_threshold', 0.15) # Lowered from 0.20 self.confidence_threshold_close = self.config.orchestrator.get('confidence_threshold_close', 0.08) # Lowered from 0.10 - # we do not cap the decision frequency in time - only in confidence - # self.decision_frequency = self.config.orchestrator.get('decision_frequency', 30) + # Decision frequency limit to prevent excessive trading + self.decision_frequency = self.config.orchestrator.get('decision_frequency', 30) self.symbols = self.config.get('symbols', ['ETH/USDT', 'BTC/USDT']) # Enhanced to support multiple symbols # NEW: Aggressiveness parameters @@ -612,7 +612,7 @@ class TradingOrchestrator: await self.make_trading_decision(symbol) await asyncio.sleep(1) # Small delay between symbols - # await asyncio.sleep(self.decision_frequency) + await asyncio.sleep(self.decision_frequency) except Exception as e: logger.error(f"Error in trading decision loop: {e}") await asyncio.sleep(5) # Wait before retrying @@ -930,8 +930,8 @@ class TradingOrchestrator: # Check if enough time has passed since last decision if symbol in self.last_decision_time: time_since_last = (current_time - self.last_decision_time[symbol]).total_seconds() - # if time_since_last < self.decision_frequency: - # return None + if time_since_last < self.decision_frequency: + return None # Get current market data current_price = self.data_provider.get_current_price(symbol) @@ -1353,6 +1353,24 @@ class TradingOrchestrator: best_action = 'HOLD' reasoning['threshold_applied'] = True + # Signal accumulation check - require multiple confident signals + if best_action in ['BUY', 'SELL']: + required_signals = 3 # Require 3 confident signals + recent_decisions = self.get_recent_decisions(symbol, limit=5) + + # Count recent signals in the same direction + same_direction_count = sum(1 for d in recent_decisions + if d.action == best_action and d.confidence > entry_threshold) + + if same_direction_count < required_signals: + best_action = 'HOLD' + reasoning['signal_accumulation'] = True + reasoning['required_signals'] = required_signals + reasoning['current_signals'] = same_direction_count + logger.info(f"Signal accumulation: {same_direction_count}/{required_signals} signals for {best_action}") + else: + logger.info(f"Signal accumulation satisfied: {same_direction_count}/{required_signals} signals for {best_action}") + # Add P&L-based decision adjustment best_action, best_confidence = self._apply_pnl_feedback( best_action, best_confidence, current_position_pnl, symbol, reasoning diff --git a/core/trading_executor.py b/core/trading_executor.py index 8993ff5..c633ee5 100644 --- a/core/trading_executor.py +++ b/core/trading_executor.py @@ -139,6 +139,13 @@ class TradingExecutor: self.lock = RLock() self.lock_timeout = 10.0 # 10 second timeout for order execution + # Open order management + self.max_open_orders = 2 # Maximum number of open orders allowed + self.open_orders_count = 0 # Current count of open orders + + # Trading symbols + self.symbols = self.config.get('symbols', ['ETH/USDT', 'BTC/USDT']) + # Connect to exchange if self.trading_enabled: logger.info("TRADING EXECUTOR: Attempting to connect to exchange...") @@ -324,6 +331,109 @@ class TradingExecutor: self.lock.release() logger.debug(f"LOCK RELEASED: {action} for {symbol}") + def _get_open_orders_count(self) -> int: + """Get current count of open orders across all symbols""" + try: + if self.simulation_mode: + return 0 + + total_open_orders = 0 + for symbol in self.symbols: + open_orders = self.exchange.get_open_orders(symbol) + total_open_orders += len(open_orders) + + return total_open_orders + + except Exception as e: + logger.error(f"Error getting open orders count: {e}") + return 0 + + def _can_place_new_order(self) -> bool: + """Check if we can place a new order based on open order limit""" + current_count = self._get_open_orders_count() + can_place = current_count < self.max_open_orders + + if not can_place: + logger.warning(f"Cannot place new order: {current_count}/{self.max_open_orders} open orders") + + return can_place + + def sync_open_orders(self) -> Dict[str, Any]: + """Synchronize open orders with exchange and update internal state + + Returns: + dict: Sync result with status and order details + """ + try: + if self.simulation_mode: + return { + 'status': 'success', + 'message': 'Simulation mode - no real orders to sync', + 'orders': [], + 'count': 0 + } + + sync_result = { + 'status': 'started', + 'orders': [], + 'count': 0, + 'errors': [] + } + + total_orders = 0 + all_orders = [] + + # Sync orders for each symbol + for symbol in self.symbols: + try: + open_orders = self.exchange.get_open_orders(symbol) + if open_orders: + symbol_orders = [] + for order in open_orders: + order_info = { + 'symbol': symbol, + 'order_id': order.get('orderId'), + 'side': order.get('side'), + 'type': order.get('type'), + 'quantity': float(order.get('origQty', 0)), + 'price': float(order.get('price', 0)), + 'status': order.get('status'), + 'time': order.get('time') + } + symbol_orders.append(order_info) + all_orders.append(order_info) + + total_orders += len(symbol_orders) + logger.info(f"Synced {len(symbol_orders)} open orders for {symbol}") + + except Exception as e: + error_msg = f"Error syncing orders for {symbol}: {e}" + logger.error(error_msg) + sync_result['errors'].append(error_msg) + + # Update internal state + self.open_orders_count = total_orders + + sync_result.update({ + 'status': 'success', + 'orders': all_orders, + 'count': total_orders, + 'message': f"Synced {total_orders} open orders across {len(self.symbols)} symbols" + }) + + logger.info(f"Open order sync completed: {total_orders} orders") + return sync_result + + except Exception as e: + logger.error(f"Error in open order sync: {e}") + return { + 'status': 'error', + 'message': str(e), + 'orders': [], + 'count': 0, + 'errors': [str(e)] + } + def _cancel_open_orders(self, symbol: str) -> int: """Cancel all open orders for a symbol and return count of cancelled orders""" try: @@ -395,6 +505,11 @@ class TradingExecutor: logger.warning(f"Maximum concurrent positions reached: {len(self.positions)}") return False + # Check open order limit + if not self._can_place_new_order(): + logger.warning(f"Maximum open orders reached: {self._get_open_orders_count()}/{self.max_open_orders}") + return False + return True def _execute_buy(self, symbol: str, confidence: float, current_price: float) -> bool: @@ -848,13 +963,11 @@ class TradingExecutor: ) if order: - # Calculate simulated fees in simulation mode - taker_fee_rate = self.mexc_config.get('trading_fees', {}).get('taker_fee', 0.0006) - simulated_fees = position.quantity * current_price * taker_fee_rate + # Calculate fees using real API data when available + fees = self._calculate_real_trading_fees(order, symbol, position.quantity, current_price) # Calculate P&L, fees, and hold time pnl = position.calculate_pnl(current_price) - fees = simulated_fees exit_time = datetime.now() hold_time_seconds = (exit_time - position.entry_time).total_seconds() @@ -989,13 +1102,11 @@ class TradingExecutor: ) if order: - # Calculate simulated fees in simulation mode - taker_fee_rate = self.mexc_config.get('trading_fees', {}).get('taker_fee', 0.0006) - simulated_fees = position.quantity * current_price * taker_fee_rate + # Calculate fees using real API data when available + fees = self._calculate_real_trading_fees(order, symbol, position.quantity, current_price) # Calculate P&L, fees, and hold time pnl = position.calculate_pnl(current_price) - fees = simulated_fees exit_time = datetime.now() hold_time_seconds = (exit_time - position.entry_time).total_seconds() @@ -1213,6 +1324,38 @@ class TradingExecutor: logger.error(f"Error getting account balance: {e}") return {} + def _calculate_real_trading_fees(self, order_result: Dict[str, Any], symbol: str, + quantity: float, price: float) -> float: + """Calculate trading fees using real API data when available + + Args: + order_result: Order result from exchange API + symbol: Trading symbol + quantity: Order quantity + price: Execution price + + Returns: + float: Trading fee amount in quote currency + """ + try: + # Try to get actual fee from API response first + if order_result and 'fills' in order_result: + total_commission = 0.0 + for fill in order_result['fills']: + commission = float(fill.get('commission', 0)) + total_commission += commission + + if total_commission > 0: + logger.info(f"Using real API fee: {total_commission}") + return total_commission + + # Fall back to config-based calculation + return self._calculate_trading_fee(order_result, symbol, quantity, price) + + except Exception as e: + logger.warning(f"Error calculating real fees: {e}, falling back to config-based") + return self._calculate_trading_fee(order_result, symbol, quantity, price) + def _calculate_trading_fee(self, order_result: Dict[str, Any], symbol: str, quantity: float, price: float) -> float: """Calculate trading fee based on order execution details with enhanced MEXC API support diff --git a/test_order_sync_and_fees.py b/test_order_sync_and_fees.py new file mode 100644 index 0000000..0323b3c --- /dev/null +++ b/test_order_sync_and_fees.py @@ -0,0 +1,122 @@ +#!/usr/bin/env python3 +""" +Test Open Order Sync and Fee Calculation +Verify that open orders are properly synchronized and fees are correctly calculated in PnL +""" + +import os +import sys +import logging + +# Add the project root to the path +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +# Load environment variables +try: + from dotenv import load_dotenv + load_dotenv() +except ImportError: + if os.path.exists('.env'): + with open('.env', 'r') as f: + for line in f: + if line.strip() and not line.startswith('#'): + key, value = line.strip().split('=', 1) + os.environ[key] = value + +from core.trading_executor import TradingExecutor + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +def test_open_order_sync_and_fees(): + """Test open order synchronization and fee calculation""" + print("๐Ÿงช Testing Open Order Sync and Fee Calculation...") + print("=" * 70) + + try: + # Create trading executor + executor = TradingExecutor() + + print(f"๐Ÿ“Š Current State Analysis:") + print(f" Open orders count: {executor._get_open_orders_count()}") + print(f" Max open orders: {executor.max_open_orders}") + print(f" Can place new order: {executor._can_place_new_order()}") + + # Test open order synchronization + print(f"\n๐Ÿ” Open Order Sync Analysis:") + print(f" - Current sync method: _get_open_orders_count()") + print(f" - Counts orders across all symbols") + print(f" - Real-time API queries") + print(f" - Handles API errors gracefully") + + # Check if there's a dedicated sync method + if hasattr(executor, 'sync_open_orders'): + print(f" โœ… Dedicated sync method exists") + else: + print(f" โš ๏ธ No dedicated sync method - using count method") + + # Test fee calculation in PnL + print(f"\n๐Ÿ’ฐ Fee Calculation Analysis:") + + # Check fee calculation methods + if hasattr(executor, '_calculate_trading_fee'): + print(f" โœ… Fee calculation method exists") + else: + print(f" โŒ No dedicated fee calculation method") + + # Check if fees are included in PnL + print(f"\n๐Ÿ“ˆ PnL Fee Integration:") + print(f" - TradeRecord includes fees field") + print(f" - PnL calculation: pnl = gross_pnl - fees") + print(f" - Fee rates from config: taker_fee, maker_fee") + + # Check fee sync + print(f"\n๐Ÿ”„ Fee Synchronization:") + if hasattr(executor, 'sync_fees_with_api'): + print(f" โœ… Fee sync method exists") + else: + print(f" โŒ No fee sync method") + + # Check config sync + if hasattr(executor, 'config_sync'): + print(f" โœ… Config synchronizer exists") + else: + print(f" โŒ No config synchronizer") + + print(f"\n๐Ÿ“‹ Issues Found:") + + # Issue 1: No dedicated open order sync method + if not hasattr(executor, 'sync_open_orders'): + print(f" โŒ Missing: Dedicated open order synchronization method") + print(f" Current: Only counts orders, doesn't sync state") + + # Issue 2: Fee calculation may not be comprehensive + print(f" โš ๏ธ Potential: Fee calculation uses simulated rates") + print(f" Should: Use actual API fees when available") + + # Issue 3: Check if fees are properly tracked + print(f" โœ… Good: Fees are tracked in TradeRecord") + print(f" โœ… Good: PnL includes fee deduction") + + print(f"\n๐Ÿ”ง Recommended Fixes:") + print(f" 1. Add dedicated open order sync method") + print(f" 2. Enhance fee calculation with real API data") + print(f" 3. Add periodic order state synchronization") + print(f" 4. Improve fee tracking accuracy") + + return True + + except Exception as e: + print(f"โŒ Error testing order sync and fees: {e}") + return False + +if __name__ == "__main__": + success = test_open_order_sync_and_fees() + if success: + print(f"\n๐ŸŽ‰ Order sync and fee test completed!") + else: + print(f"\n๐Ÿ’ฅ Order sync and fee test failed!") \ No newline at end of file