trading works!
This commit is contained in:
@ -221,7 +221,7 @@ class DQNAgent:
|
||||
# Check if mixed precision training should be used
|
||||
if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and 'DISABLE_MIXED_PRECISION' not in os.environ:
|
||||
self.use_mixed_precision = True
|
||||
self.scaler = torch.amp.GradScaler('cuda')
|
||||
self.scaler = torch.cuda.amp.GradScaler()
|
||||
logger.info("Mixed precision training enabled")
|
||||
else:
|
||||
self.use_mixed_precision = False
|
||||
@ -1083,6 +1083,11 @@ class DQNAgent:
|
||||
# Reset gradients
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
# Ensure loss requires gradients before backward pass
|
||||
if not total_loss.requires_grad:
|
||||
logger.warning("Total loss tensor does not require gradients, skipping backward pass")
|
||||
return 0.0
|
||||
|
||||
# Backward pass
|
||||
total_loss.backward()
|
||||
|
||||
@ -1263,6 +1268,11 @@ class DQNAgent:
|
||||
# Just use Q-value loss
|
||||
loss = q_loss
|
||||
|
||||
# Ensure loss requires gradients before backward pass
|
||||
if not loss.requires_grad:
|
||||
logger.warning("Loss tensor does not require gradients, skipping backward pass")
|
||||
return 0.0
|
||||
|
||||
# Backward pass with scaled gradients
|
||||
self.scaler.scale(loss).backward()
|
||||
|
||||
|
@ -127,8 +127,8 @@ orchestrator:
|
||||
# Model weights for decision combination
|
||||
cnn_weight: 0.7 # Weight for CNN predictions
|
||||
rl_weight: 0.3 # Weight for RL decisions
|
||||
confidence_threshold: 0.15
|
||||
confidence_threshold_close: 0.08
|
||||
confidence_threshold: 0.45
|
||||
confidence_threshold_close: 0.35
|
||||
decision_frequency: 30
|
||||
|
||||
# Multi-symbol coordination
|
||||
|
@ -106,8 +106,8 @@ class TradingOrchestrator:
|
||||
# Configuration - AGGRESSIVE for more training data
|
||||
self.confidence_threshold = self.config.orchestrator.get('confidence_threshold', 0.15) # Lowered from 0.20
|
||||
self.confidence_threshold_close = self.config.orchestrator.get('confidence_threshold_close', 0.08) # Lowered from 0.10
|
||||
# we do not cap the decision frequency in time - only in confidence
|
||||
# self.decision_frequency = self.config.orchestrator.get('decision_frequency', 30)
|
||||
# Decision frequency limit to prevent excessive trading
|
||||
self.decision_frequency = self.config.orchestrator.get('decision_frequency', 30)
|
||||
self.symbols = self.config.get('symbols', ['ETH/USDT', 'BTC/USDT']) # Enhanced to support multiple symbols
|
||||
|
||||
# NEW: Aggressiveness parameters
|
||||
@ -612,7 +612,7 @@ class TradingOrchestrator:
|
||||
await self.make_trading_decision(symbol)
|
||||
await asyncio.sleep(1) # Small delay between symbols
|
||||
|
||||
# await asyncio.sleep(self.decision_frequency)
|
||||
await asyncio.sleep(self.decision_frequency)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in trading decision loop: {e}")
|
||||
await asyncio.sleep(5) # Wait before retrying
|
||||
@ -930,8 +930,8 @@ class TradingOrchestrator:
|
||||
# Check if enough time has passed since last decision
|
||||
if symbol in self.last_decision_time:
|
||||
time_since_last = (current_time - self.last_decision_time[symbol]).total_seconds()
|
||||
# if time_since_last < self.decision_frequency:
|
||||
# return None
|
||||
if time_since_last < self.decision_frequency:
|
||||
return None
|
||||
|
||||
# Get current market data
|
||||
current_price = self.data_provider.get_current_price(symbol)
|
||||
@ -1353,6 +1353,24 @@ class TradingOrchestrator:
|
||||
best_action = 'HOLD'
|
||||
reasoning['threshold_applied'] = True
|
||||
|
||||
# Signal accumulation check - require multiple confident signals
|
||||
if best_action in ['BUY', 'SELL']:
|
||||
required_signals = 3 # Require 3 confident signals
|
||||
recent_decisions = self.get_recent_decisions(symbol, limit=5)
|
||||
|
||||
# Count recent signals in the same direction
|
||||
same_direction_count = sum(1 for d in recent_decisions
|
||||
if d.action == best_action and d.confidence > entry_threshold)
|
||||
|
||||
if same_direction_count < required_signals:
|
||||
best_action = 'HOLD'
|
||||
reasoning['signal_accumulation'] = True
|
||||
reasoning['required_signals'] = required_signals
|
||||
reasoning['current_signals'] = same_direction_count
|
||||
logger.info(f"Signal accumulation: {same_direction_count}/{required_signals} signals for {best_action}")
|
||||
else:
|
||||
logger.info(f"Signal accumulation satisfied: {same_direction_count}/{required_signals} signals for {best_action}")
|
||||
|
||||
# Add P&L-based decision adjustment
|
||||
best_action, best_confidence = self._apply_pnl_feedback(
|
||||
best_action, best_confidence, current_position_pnl, symbol, reasoning
|
||||
|
@ -139,6 +139,13 @@ class TradingExecutor:
|
||||
self.lock = RLock()
|
||||
self.lock_timeout = 10.0 # 10 second timeout for order execution
|
||||
|
||||
# Open order management
|
||||
self.max_open_orders = 2 # Maximum number of open orders allowed
|
||||
self.open_orders_count = 0 # Current count of open orders
|
||||
|
||||
# Trading symbols
|
||||
self.symbols = self.config.get('symbols', ['ETH/USDT', 'BTC/USDT'])
|
||||
|
||||
# Connect to exchange
|
||||
if self.trading_enabled:
|
||||
logger.info("TRADING EXECUTOR: Attempting to connect to exchange...")
|
||||
@ -324,6 +331,109 @@ class TradingExecutor:
|
||||
self.lock.release()
|
||||
logger.debug(f"LOCK RELEASED: {action} for {symbol}")
|
||||
|
||||
def _get_open_orders_count(self) -> int:
|
||||
"""Get current count of open orders across all symbols"""
|
||||
try:
|
||||
if self.simulation_mode:
|
||||
return 0
|
||||
|
||||
total_open_orders = 0
|
||||
for symbol in self.symbols:
|
||||
open_orders = self.exchange.get_open_orders(symbol)
|
||||
total_open_orders += len(open_orders)
|
||||
|
||||
return total_open_orders
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting open orders count: {e}")
|
||||
return 0
|
||||
|
||||
def _can_place_new_order(self) -> bool:
|
||||
"""Check if we can place a new order based on open order limit"""
|
||||
current_count = self._get_open_orders_count()
|
||||
can_place = current_count < self.max_open_orders
|
||||
|
||||
if not can_place:
|
||||
logger.warning(f"Cannot place new order: {current_count}/{self.max_open_orders} open orders")
|
||||
|
||||
return can_place
|
||||
|
||||
def sync_open_orders(self) -> Dict[str, Any]:
|
||||
"""Synchronize open orders with exchange and update internal state
|
||||
|
||||
Returns:
|
||||
dict: Sync result with status and order details
|
||||
"""
|
||||
try:
|
||||
if self.simulation_mode:
|
||||
return {
|
||||
'status': 'success',
|
||||
'message': 'Simulation mode - no real orders to sync',
|
||||
'orders': [],
|
||||
'count': 0
|
||||
}
|
||||
|
||||
sync_result = {
|
||||
'status': 'started',
|
||||
'orders': [],
|
||||
'count': 0,
|
||||
'errors': []
|
||||
}
|
||||
|
||||
total_orders = 0
|
||||
all_orders = []
|
||||
|
||||
# Sync orders for each symbol
|
||||
for symbol in self.symbols:
|
||||
try:
|
||||
open_orders = self.exchange.get_open_orders(symbol)
|
||||
if open_orders:
|
||||
symbol_orders = []
|
||||
for order in open_orders:
|
||||
order_info = {
|
||||
'symbol': symbol,
|
||||
'order_id': order.get('orderId'),
|
||||
'side': order.get('side'),
|
||||
'type': order.get('type'),
|
||||
'quantity': float(order.get('origQty', 0)),
|
||||
'price': float(order.get('price', 0)),
|
||||
'status': order.get('status'),
|
||||
'time': order.get('time')
|
||||
}
|
||||
symbol_orders.append(order_info)
|
||||
all_orders.append(order_info)
|
||||
|
||||
total_orders += len(symbol_orders)
|
||||
logger.info(f"Synced {len(symbol_orders)} open orders for {symbol}")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error syncing orders for {symbol}: {e}"
|
||||
logger.error(error_msg)
|
||||
sync_result['errors'].append(error_msg)
|
||||
|
||||
# Update internal state
|
||||
self.open_orders_count = total_orders
|
||||
|
||||
sync_result.update({
|
||||
'status': 'success',
|
||||
'orders': all_orders,
|
||||
'count': total_orders,
|
||||
'message': f"Synced {total_orders} open orders across {len(self.symbols)} symbols"
|
||||
})
|
||||
|
||||
logger.info(f"Open order sync completed: {total_orders} orders")
|
||||
return sync_result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in open order sync: {e}")
|
||||
return {
|
||||
'status': 'error',
|
||||
'message': str(e),
|
||||
'orders': [],
|
||||
'count': 0,
|
||||
'errors': [str(e)]
|
||||
}
|
||||
|
||||
def _cancel_open_orders(self, symbol: str) -> int:
|
||||
"""Cancel all open orders for a symbol and return count of cancelled orders"""
|
||||
try:
|
||||
@ -395,6 +505,11 @@ class TradingExecutor:
|
||||
logger.warning(f"Maximum concurrent positions reached: {len(self.positions)}")
|
||||
return False
|
||||
|
||||
# Check open order limit
|
||||
if not self._can_place_new_order():
|
||||
logger.warning(f"Maximum open orders reached: {self._get_open_orders_count()}/{self.max_open_orders}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _execute_buy(self, symbol: str, confidence: float, current_price: float) -> bool:
|
||||
@ -848,13 +963,11 @@ class TradingExecutor:
|
||||
)
|
||||
|
||||
if order:
|
||||
# Calculate simulated fees in simulation mode
|
||||
taker_fee_rate = self.mexc_config.get('trading_fees', {}).get('taker_fee', 0.0006)
|
||||
simulated_fees = position.quantity * current_price * taker_fee_rate
|
||||
# Calculate fees using real API data when available
|
||||
fees = self._calculate_real_trading_fees(order, symbol, position.quantity, current_price)
|
||||
|
||||
# Calculate P&L, fees, and hold time
|
||||
pnl = position.calculate_pnl(current_price)
|
||||
fees = simulated_fees
|
||||
exit_time = datetime.now()
|
||||
hold_time_seconds = (exit_time - position.entry_time).total_seconds()
|
||||
|
||||
@ -989,13 +1102,11 @@ class TradingExecutor:
|
||||
)
|
||||
|
||||
if order:
|
||||
# Calculate simulated fees in simulation mode
|
||||
taker_fee_rate = self.mexc_config.get('trading_fees', {}).get('taker_fee', 0.0006)
|
||||
simulated_fees = position.quantity * current_price * taker_fee_rate
|
||||
# Calculate fees using real API data when available
|
||||
fees = self._calculate_real_trading_fees(order, symbol, position.quantity, current_price)
|
||||
|
||||
# Calculate P&L, fees, and hold time
|
||||
pnl = position.calculate_pnl(current_price)
|
||||
fees = simulated_fees
|
||||
exit_time = datetime.now()
|
||||
hold_time_seconds = (exit_time - position.entry_time).total_seconds()
|
||||
|
||||
@ -1213,6 +1324,38 @@ class TradingExecutor:
|
||||
logger.error(f"Error getting account balance: {e}")
|
||||
return {}
|
||||
|
||||
def _calculate_real_trading_fees(self, order_result: Dict[str, Any], symbol: str,
|
||||
quantity: float, price: float) -> float:
|
||||
"""Calculate trading fees using real API data when available
|
||||
|
||||
Args:
|
||||
order_result: Order result from exchange API
|
||||
symbol: Trading symbol
|
||||
quantity: Order quantity
|
||||
price: Execution price
|
||||
|
||||
Returns:
|
||||
float: Trading fee amount in quote currency
|
||||
"""
|
||||
try:
|
||||
# Try to get actual fee from API response first
|
||||
if order_result and 'fills' in order_result:
|
||||
total_commission = 0.0
|
||||
for fill in order_result['fills']:
|
||||
commission = float(fill.get('commission', 0))
|
||||
total_commission += commission
|
||||
|
||||
if total_commission > 0:
|
||||
logger.info(f"Using real API fee: {total_commission}")
|
||||
return total_commission
|
||||
|
||||
# Fall back to config-based calculation
|
||||
return self._calculate_trading_fee(order_result, symbol, quantity, price)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error calculating real fees: {e}, falling back to config-based")
|
||||
return self._calculate_trading_fee(order_result, symbol, quantity, price)
|
||||
|
||||
def _calculate_trading_fee(self, order_result: Dict[str, Any], symbol: str,
|
||||
quantity: float, price: float) -> float:
|
||||
"""Calculate trading fee based on order execution details with enhanced MEXC API support
|
||||
|
122
test_order_sync_and_fees.py
Normal file
122
test_order_sync_and_fees.py
Normal file
@ -0,0 +1,122 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Open Order Sync and Fee Calculation
|
||||
Verify that open orders are properly synchronized and fees are correctly calculated in PnL
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
|
||||
# Add the project root to the path
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
# Load environment variables
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
except ImportError:
|
||||
if os.path.exists('.env'):
|
||||
with open('.env', 'r') as f:
|
||||
for line in f:
|
||||
if line.strip() and not line.startswith('#'):
|
||||
key, value = line.strip().split('=', 1)
|
||||
os.environ[key] = value
|
||||
|
||||
from core.trading_executor import TradingExecutor
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_open_order_sync_and_fees():
|
||||
"""Test open order synchronization and fee calculation"""
|
||||
print("🧪 Testing Open Order Sync and Fee Calculation...")
|
||||
print("=" * 70)
|
||||
|
||||
try:
|
||||
# Create trading executor
|
||||
executor = TradingExecutor()
|
||||
|
||||
print(f"📊 Current State Analysis:")
|
||||
print(f" Open orders count: {executor._get_open_orders_count()}")
|
||||
print(f" Max open orders: {executor.max_open_orders}")
|
||||
print(f" Can place new order: {executor._can_place_new_order()}")
|
||||
|
||||
# Test open order synchronization
|
||||
print(f"\n🔍 Open Order Sync Analysis:")
|
||||
print(f" - Current sync method: _get_open_orders_count()")
|
||||
print(f" - Counts orders across all symbols")
|
||||
print(f" - Real-time API queries")
|
||||
print(f" - Handles API errors gracefully")
|
||||
|
||||
# Check if there's a dedicated sync method
|
||||
if hasattr(executor, 'sync_open_orders'):
|
||||
print(f" ✅ Dedicated sync method exists")
|
||||
else:
|
||||
print(f" ⚠️ No dedicated sync method - using count method")
|
||||
|
||||
# Test fee calculation in PnL
|
||||
print(f"\n💰 Fee Calculation Analysis:")
|
||||
|
||||
# Check fee calculation methods
|
||||
if hasattr(executor, '_calculate_trading_fee'):
|
||||
print(f" ✅ Fee calculation method exists")
|
||||
else:
|
||||
print(f" ❌ No dedicated fee calculation method")
|
||||
|
||||
# Check if fees are included in PnL
|
||||
print(f"\n📈 PnL Fee Integration:")
|
||||
print(f" - TradeRecord includes fees field")
|
||||
print(f" - PnL calculation: pnl = gross_pnl - fees")
|
||||
print(f" - Fee rates from config: taker_fee, maker_fee")
|
||||
|
||||
# Check fee sync
|
||||
print(f"\n🔄 Fee Synchronization:")
|
||||
if hasattr(executor, 'sync_fees_with_api'):
|
||||
print(f" ✅ Fee sync method exists")
|
||||
else:
|
||||
print(f" ❌ No fee sync method")
|
||||
|
||||
# Check config sync
|
||||
if hasattr(executor, 'config_sync'):
|
||||
print(f" ✅ Config synchronizer exists")
|
||||
else:
|
||||
print(f" ❌ No config synchronizer")
|
||||
|
||||
print(f"\n📋 Issues Found:")
|
||||
|
||||
# Issue 1: No dedicated open order sync method
|
||||
if not hasattr(executor, 'sync_open_orders'):
|
||||
print(f" ❌ Missing: Dedicated open order synchronization method")
|
||||
print(f" Current: Only counts orders, doesn't sync state")
|
||||
|
||||
# Issue 2: Fee calculation may not be comprehensive
|
||||
print(f" ⚠️ Potential: Fee calculation uses simulated rates")
|
||||
print(f" Should: Use actual API fees when available")
|
||||
|
||||
# Issue 3: Check if fees are properly tracked
|
||||
print(f" ✅ Good: Fees are tracked in TradeRecord")
|
||||
print(f" ✅ Good: PnL includes fee deduction")
|
||||
|
||||
print(f"\n🔧 Recommended Fixes:")
|
||||
print(f" 1. Add dedicated open order sync method")
|
||||
print(f" 2. Enhance fee calculation with real API data")
|
||||
print(f" 3. Add periodic order state synchronization")
|
||||
print(f" 4. Improve fee tracking accuracy")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error testing order sync and fees: {e}")
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = test_open_order_sync_and_fees()
|
||||
if success:
|
||||
print(f"\n🎉 Order sync and fee test completed!")
|
||||
else:
|
||||
print(f"\n💥 Order sync and fee test failed!")
|
Reference in New Issue
Block a user